add new basic auth provider class

Fix the circular dependency by adding a new BasicAuthProvider interface to
wrap the basic auth values.
Add unit tests for the hash functions.
Refactor functions to use the new basic auth provider interface.
This commit is contained in:
Hector 2022-01-13 21:25:37 +00:00
parent 26f0b0d0ce
commit e78d4e7298
6 changed files with 64 additions and 16 deletions

View File

@ -2,13 +2,17 @@ package auth
import (
"crypto/sha256"
"encoding/hex"
)
func Hash(data []byte) []byte {
if len(data) == 0 {
return []byte{}
}
b := sha256.Sum256(data)
return b[:]
}
func HashString(data string) string {
return string(Hash([]byte(data)))
return hex.EncodeToString(Hash([]byte(data)))
}

26
src/auth/hash_test.go Normal file
View File

@ -0,0 +1,26 @@
package auth
import (
"reflect"
"testing"
)
func TestHashString(t *testing.T) {
tests := []struct {
name string
args string
want string
}{
{"Happy path #1", "123", "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"},
{"Happy path #2", "hello world", "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"},
{"Happy path #3", "H3Ll0_W0RLD", "d58a27fe9a6e73a1d8a67189fb8acace047e7a1a795276a0056d3717ad61bd0e"},
{"Blank string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := HashString(tt.args); !reflect.DeepEqual(got, tt.want) {
t.Errorf("HashString() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -1,15 +1,18 @@
package auth
import (
"fail2ban-prometheus-exporter/cfg"
"net/http"
)
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSettings) http.HandlerFunc {
authEnabled := len(appSettings.BasicAuthUsername) > 0 && len(appSettings.BasicAuthPassword) > 0
if authEnabled {
type BasicAuthProvider interface {
Enabled() bool
DoesBasicAuthMatch(username, password string) bool
}
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
if basicAuthProvider.Enabled() {
return func(w http.ResponseWriter, r *http.Request) {
if doesBasicAuthMatch(r, appSettings) {
if doesBasicAuthMatch(r, basicAuthProvider) {
w.WriteHeader(http.StatusUnauthorized)
} else {
handlerFunc.ServeHTTP(w, r)
@ -19,12 +22,12 @@ func BasicAuthMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSetti
return handlerFunc
}
func doesBasicAuthMatch(r *http.Request, appSettings *cfg.AppSettings) bool {
func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
rawUsername, rawPassword, ok := r.BasicAuth()
if ok {
username := HashString(rawUsername)
password := HashString(rawPassword)
return username == appSettings.BasicAuthUsername && password == appSettings.BasicAuthPassword
return basicAuthProvider.DoesBasicAuthMatch(username, password)
}
return false
}

14
src/cfg/basicAuth.go Normal file
View File

@ -0,0 +1,14 @@
package cfg
type basicAuth struct {
username string
password string
}
func (p *basicAuth) Enabled() bool {
return len(p.username) > 0 && len(p.password) > 0
}
func (p *basicAuth) DoesBasicAuthMatch(username, password string) bool {
return username == p.username && password == p.password
}

View File

@ -19,8 +19,7 @@ type AppSettings struct {
Fail2BanSocketPath string
FileCollectorPath string
FileCollectorEnabled bool
BasicAuthUsername string
BasicAuthPassword string
BasicAuthProvider *basicAuth
}
func Parse() *AppSettings {
@ -43,9 +42,11 @@ func Parse() *AppSettings {
return appSettings
}
func (settings *AppSettings) setBasicAuthValues(username, password string) {
settings.BasicAuthUsername = auth.HashString(username)
settings.BasicAuthPassword = auth.HashString(password)
func (settings *AppSettings) setBasicAuthValues(rawUsername, rawPassword string) {
settings.BasicAuthProvider = &basicAuth{
username: auth.HashString(rawUsername),
password: auth.HashString(rawPassword),
}
}
func (settings *AppSettings) validateFlags() {
@ -64,7 +65,7 @@ func (settings *AppSettings) validateFlags() {
fmt.Printf("file collector directory path must not be empty if collector enabled\n")
flagsValid = false
}
if (len(settings.BasicAuthUsername) > 0) != (len(settings.BasicAuthPassword) > 0) {
if (len(settings.BasicAuthProvider.username) > 0) != (len(settings.BasicAuthProvider.password) > 0) {
fmt.Printf("to enable basic auth both the username and the password must be provided")
flagsValid = false
}

View File

@ -64,12 +64,12 @@ func main() {
textFileCollector := textfile.NewCollector(appSettings)
prometheus.MustRegister(textFileCollector)
http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings))
http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider))
http.HandleFunc(metricsPath, auth.BasicAuthMiddleware(
func(w http.ResponseWriter, r *http.Request) {
metricHandler(w, r, textFileCollector)
},
appSettings,
appSettings.BasicAuthProvider,
))
log.Printf("metrics available at '%s'", metricsPath)