diff --git a/src/auth/middleware.go b/src/auth/middleware.go index 42e818a..2518932 100644 --- a/src/auth/middleware.go +++ b/src/auth/middleware.go @@ -25,9 +25,7 @@ func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAu func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool { rawUsername, rawPassword, ok := r.BasicAuth() if ok { - username := HashString(rawUsername) - password := HashString(rawPassword) - return basicAuthProvider.DoesBasicAuthMatch(username, password) + return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword) } return false } diff --git a/src/cfg/basicAuth.go b/src/cfg/basicAuth.go index c79dc8c..bc5408f 100644 --- a/src/cfg/basicAuth.go +++ b/src/cfg/basicAuth.go @@ -1,14 +1,25 @@ package cfg -type basicAuth struct { +import "fail2ban-prometheus-exporter/auth" + +type hashedBasicAuth struct { username string password string } -func (p *basicAuth) Enabled() bool { +func newHashedBasicAuth(rawUsername, rawPassword string) *hashedBasicAuth { + return &hashedBasicAuth{ + username: auth.HashString(rawUsername), + password: auth.HashString(rawPassword), + } +} + +func (p *hashedBasicAuth) Enabled() bool { return len(p.username) > 0 && len(p.password) > 0 } -func (p *basicAuth) DoesBasicAuthMatch(username, password string) bool { +func (p *hashedBasicAuth) DoesBasicAuthMatch(rawUsername, rawPassword string) bool { + username := auth.HashString(rawUsername) + password := auth.HashString(rawPassword) return username == p.username && password == p.password } diff --git a/src/cfg/basicAuth_test.go b/src/cfg/basicAuth_test.go new file mode 100644 index 0000000..85bd8b3 --- /dev/null +++ b/src/cfg/basicAuth_test.go @@ -0,0 +1,60 @@ +package cfg + +import "testing" + +func Test_hashedBasicAuth_DoesBasicAuthMatch(t *testing.T) { + type args struct { + username string + password string + } + type fields struct { + username string + password string + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + {"Happy test #1", fields{username: "1234", password: "1234"}, args{username: "1234", password: "1234"}, true}, + {"Happy test #2", fields{username: "test", password: "1234"}, args{username: "test", password: "1234"}, true}, + {"Happy test #3", fields{username: "TEST", password: "1234"}, args{username: "TEST", password: "1234"}, true}, + {"Non match #1", fields{username: "test", password: "1234"}, args{username: "1234", password: "1234"}, false}, + {"Non match #2", fields{username: "1234", password: "test"}, args{username: "1234", password: "1234"}, false}, + {"Non match #3", fields{username: "1234", password: "test"}, args{username: "1234", password: "TEST"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password) + if got := basicAuth.DoesBasicAuthMatch(tt.args.username, tt.args.password); got != tt.want { + t.Errorf("DoesBasicAuthMatch() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_hashedBasicAuth_Enabled(t *testing.T) { + type fields struct { + username string + password string + } + tests := []struct { + name string + fields fields + want bool + }{ + {"Both blank", fields{username: "", password: ""}, false}, + {"Single blank #1", fields{username: "test", password: ""}, false}, + {"Single blank #1", fields{username: "", password: "test"}, false}, + {"Both populated", fields{username: "test", password: "test"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password) + if got := basicAuth.Enabled(); got != tt.want { + t.Errorf("Enabled() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/src/cfg/cfg.go b/src/cfg/cfg.go index 0b33b99..008fef1 100644 --- a/src/cfg/cfg.go +++ b/src/cfg/cfg.go @@ -1,7 +1,6 @@ package cfg import ( - "fail2ban-prometheus-exporter/auth" "flag" "fmt" "os" @@ -19,7 +18,7 @@ type AppSettings struct { Fail2BanSocketPath string FileCollectorPath string FileCollectorEnabled bool - BasicAuthProvider *basicAuth + BasicAuthProvider *hashedBasicAuth } func Parse() *AppSettings { @@ -43,10 +42,7 @@ func Parse() *AppSettings { } func (settings *AppSettings) setBasicAuthValues(rawUsername, rawPassword string) { - settings.BasicAuthProvider = &basicAuth{ - username: auth.HashString(rawUsername), - password: auth.HashString(rawPassword), - } + settings.BasicAuthProvider = newHashedBasicAuth(rawUsername, rawPassword) } func (settings *AppSettings) validateFlags() {