add new basic auth provider class

Fix the circular dependency by adding a new BasicAuthProvider interface to
wrap the basic auth values.
Add unit tests for the hash functions.
Refactor functions to use the new basic auth provider interface.
This commit is contained in:
Hector 2022-01-13 21:25:37 +00:00
parent 26f0b0d0ce
commit e78d4e7298
6 changed files with 64 additions and 16 deletions

View File

@ -2,13 +2,17 @@ package auth
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/hex"
) )
func Hash(data []byte) []byte { func Hash(data []byte) []byte {
if len(data) == 0 {
return []byte{}
}
b := sha256.Sum256(data) b := sha256.Sum256(data)
return b[:] return b[:]
} }
func HashString(data string) string { func HashString(data string) string {
return string(Hash([]byte(data))) return hex.EncodeToString(Hash([]byte(data)))
} }

26
src/auth/hash_test.go Normal file
View File

@ -0,0 +1,26 @@
package auth
import (
"reflect"
"testing"
)
func TestHashString(t *testing.T) {
tests := []struct {
name string
args string
want string
}{
{"Happy path #1", "123", "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"},
{"Happy path #2", "hello world", "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"},
{"Happy path #3", "H3Ll0_W0RLD", "d58a27fe9a6e73a1d8a67189fb8acace047e7a1a795276a0056d3717ad61bd0e"},
{"Blank string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := HashString(tt.args); !reflect.DeepEqual(got, tt.want) {
t.Errorf("HashString() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -1,15 +1,18 @@
package auth package auth
import ( import (
"fail2ban-prometheus-exporter/cfg"
"net/http" "net/http"
) )
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSettings) http.HandlerFunc { type BasicAuthProvider interface {
authEnabled := len(appSettings.BasicAuthUsername) > 0 && len(appSettings.BasicAuthPassword) > 0 Enabled() bool
if authEnabled { DoesBasicAuthMatch(username, password string) bool
}
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
if basicAuthProvider.Enabled() {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if doesBasicAuthMatch(r, appSettings) { if doesBasicAuthMatch(r, basicAuthProvider) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
} else { } else {
handlerFunc.ServeHTTP(w, r) handlerFunc.ServeHTTP(w, r)
@ -19,12 +22,12 @@ func BasicAuthMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSetti
return handlerFunc return handlerFunc
} }
func doesBasicAuthMatch(r *http.Request, appSettings *cfg.AppSettings) bool { func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
rawUsername, rawPassword, ok := r.BasicAuth() rawUsername, rawPassword, ok := r.BasicAuth()
if ok { if ok {
username := HashString(rawUsername) username := HashString(rawUsername)
password := HashString(rawPassword) password := HashString(rawPassword)
return username == appSettings.BasicAuthUsername && password == appSettings.BasicAuthPassword return basicAuthProvider.DoesBasicAuthMatch(username, password)
} }
return false return false
} }

14
src/cfg/basicAuth.go Normal file
View File

@ -0,0 +1,14 @@
package cfg
type basicAuth struct {
username string
password string
}
func (p *basicAuth) Enabled() bool {
return len(p.username) > 0 && len(p.password) > 0
}
func (p *basicAuth) DoesBasicAuthMatch(username, password string) bool {
return username == p.username && password == p.password
}

View File

@ -19,8 +19,7 @@ type AppSettings struct {
Fail2BanSocketPath string Fail2BanSocketPath string
FileCollectorPath string FileCollectorPath string
FileCollectorEnabled bool FileCollectorEnabled bool
BasicAuthUsername string BasicAuthProvider *basicAuth
BasicAuthPassword string
} }
func Parse() *AppSettings { func Parse() *AppSettings {
@ -43,9 +42,11 @@ func Parse() *AppSettings {
return appSettings return appSettings
} }
func (settings *AppSettings) setBasicAuthValues(username, password string) { func (settings *AppSettings) setBasicAuthValues(rawUsername, rawPassword string) {
settings.BasicAuthUsername = auth.HashString(username) settings.BasicAuthProvider = &basicAuth{
settings.BasicAuthPassword = auth.HashString(password) username: auth.HashString(rawUsername),
password: auth.HashString(rawPassword),
}
} }
func (settings *AppSettings) validateFlags() { func (settings *AppSettings) validateFlags() {
@ -64,7 +65,7 @@ func (settings *AppSettings) validateFlags() {
fmt.Printf("file collector directory path must not be empty if collector enabled\n") fmt.Printf("file collector directory path must not be empty if collector enabled\n")
flagsValid = false flagsValid = false
} }
if (len(settings.BasicAuthUsername) > 0) != (len(settings.BasicAuthPassword) > 0) { if (len(settings.BasicAuthProvider.username) > 0) != (len(settings.BasicAuthProvider.password) > 0) {
fmt.Printf("to enable basic auth both the username and the password must be provided") fmt.Printf("to enable basic auth both the username and the password must be provided")
flagsValid = false flagsValid = false
} }

View File

@ -64,12 +64,12 @@ func main() {
textFileCollector := textfile.NewCollector(appSettings) textFileCollector := textfile.NewCollector(appSettings)
prometheus.MustRegister(textFileCollector) prometheus.MustRegister(textFileCollector)
http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings)) http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider))
http.HandleFunc(metricsPath, auth.BasicAuthMiddleware( http.HandleFunc(metricsPath, auth.BasicAuthMiddleware(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
metricHandler(w, r, textFileCollector) metricHandler(w, r, textFileCollector)
}, },
appSettings, appSettings.BasicAuthProvider,
)) ))
log.Printf("metrics available at '%s'", metricsPath) log.Printf("metrics available at '%s'", metricsPath)