From 3cff8ccd648155c7a9aef4ebb2ab074270bb4366 Mon Sep 17 00:00:00 2001 From: Hector Date: Wed, 21 Jun 2023 10:31:33 +0000 Subject: [PATCH] refactor: rewrite auth handler code (!89) * Rewrite the code handling basic auth to make it easier to extend for other types of auth. * The behaviour of the existing code is maintained. * No changes to how basic auth is configured from a user's perspective. https://gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/-/merge_requests/89 --- auth/basic.go | 29 +++++++++++++++++++ auth/basic_test.go | 53 ++++++++++++++++++++++++++++++++++ auth/empty.go | 14 +++++++++ auth/empty_test.go | 36 +++++++++++++++++++++++ auth/hash.go | 4 +-- auth/middleware.go | 31 -------------------- auth/middleware_test.go | 58 ------------------------------------- auth/provider.go | 9 ++++++ cfg/basicAuth.go | 25 ---------------- cfg/basicAuth_test.go | 60 --------------------------------------- cfg/cfg.go | 15 +++++++++- cfg/settings.go | 4 ++- exporter.go | 22 +++++++------- server/middleware.go | 17 +++++++++++ server/middleware_test.go | 46 ++++++++++++++++++++++++++++++ 15 files changed, 233 insertions(+), 190 deletions(-) create mode 100644 auth/basic.go create mode 100644 auth/basic_test.go create mode 100644 auth/empty.go create mode 100644 auth/empty_test.go delete mode 100644 auth/middleware.go delete mode 100644 auth/middleware_test.go create mode 100644 auth/provider.go delete mode 100644 cfg/basicAuth.go delete mode 100644 cfg/basicAuth_test.go create mode 100644 server/middleware.go create mode 100644 server/middleware_test.go diff --git a/auth/basic.go b/auth/basic.go new file mode 100644 index 0000000..243b2a5 --- /dev/null +++ b/auth/basic.go @@ -0,0 +1,29 @@ +package auth + +import ( + "fmt" + "net/http" +) + +func NewBasicAuthProvider(username, password string) AuthProvider { + return &basicAuthProvider{ + hashedAuth: encodeBasicAuth(username, password), + } +} + +type basicAuthProvider struct { + hashedAuth string +} + +func (p *basicAuthProvider) IsAllowed(request *http.Request) bool { + username, password, ok := request.BasicAuth() + if !ok { + return false + } + requestAuth := encodeBasicAuth(username, password) + return p.hashedAuth == requestAuth +} + +func encodeBasicAuth(username, password string) string { + return HashString(fmt.Sprintf("%s:%s", username, password)) +} diff --git a/auth/basic_test.go b/auth/basic_test.go new file mode 100644 index 0000000..e4ca1fd --- /dev/null +++ b/auth/basic_test.go @@ -0,0 +1,53 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func Test_GIVEN_BasicAuthSet_WHEN_CallingIsAllowedWithCorrectCreds_THEN_TrueReturned(t *testing.T) { + // assemble + username := "u1" + password := HashString("abc") + request := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + request.SetBasicAuth(username, password) + provider := NewBasicAuthProvider(username, password) + + // act + result := provider.IsAllowed(request) + + // assert + if !result { + t.Errorf("expected request to be allowed, but failed") + } +} + +func Test_GIVEN_BasicAuthSet_WHEN_CallingIsAllowedWithoutCreds_THEN_FalseReturned(t *testing.T) { + // assemble + request := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + provider := NewBasicAuthProvider("u1", "p1") + + // act + result := provider.IsAllowed(request) + + // assert + if result { + t.Errorf("expected request to be denied, but was allowed") + } +} + +func Test_GIVEN_BasicAuthSet_WHEN_CallingIsAllowedWithWrongCreds_THEN_FalseReturned(t *testing.T) { + // assemble + request := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + request.SetBasicAuth("wrong", "pw") + provider := NewBasicAuthProvider("u1", "p1") + + // act + result := provider.IsAllowed(request) + + // assert + if result { + t.Errorf("expected request to be denied, but was allowed") + } +} diff --git a/auth/empty.go b/auth/empty.go new file mode 100644 index 0000000..01e6775 --- /dev/null +++ b/auth/empty.go @@ -0,0 +1,14 @@ +package auth + +import "net/http" + +func NewEmptyAuthProvider() AuthProvider { + return &emptyAuthProvider{} +} + +type emptyAuthProvider struct { +} + +func (p *emptyAuthProvider) IsAllowed(request *http.Request) bool { + return true +} diff --git a/auth/empty_test.go b/auth/empty_test.go new file mode 100644 index 0000000..048de9d --- /dev/null +++ b/auth/empty_test.go @@ -0,0 +1,36 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func Test_GIVEN_EmptyAuth_WHEN_CallingIsAllowedWithoutAuth_THEN_TrueReturned(t *testing.T) { + // assemble + request := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + provider := NewEmptyAuthProvider() + + // act + response := provider.IsAllowed(request) + + // assert + if !response { + t.Errorf("expected request to be allowed, but failed") + } +} + +func Test_GIVEN_EmptyAuth_WHEN_CallingIsAllowedWithAuth_THEN_TrueReturned(t *testing.T) { + // assemble + request := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + request.SetBasicAuth("user", "pass") + provider := NewEmptyAuthProvider() + + // act + response := provider.IsAllowed(request) + + // assert + if !response { + t.Errorf("expected request to be allowed, but failed") + } +} diff --git a/auth/hash.go b/auth/hash.go index e1b4b3b..cc41e36 100644 --- a/auth/hash.go +++ b/auth/hash.go @@ -5,7 +5,7 @@ import ( "encoding/hex" ) -func Hash(data []byte) []byte { +func hash(data []byte) []byte { if len(data) == 0 { return []byte{} } @@ -14,5 +14,5 @@ func Hash(data []byte) []byte { } func HashString(data string) string { - return hex.EncodeToString(Hash([]byte(data))) + return hex.EncodeToString(hash([]byte(data))) } diff --git a/auth/middleware.go b/auth/middleware.go deleted file mode 100644 index dd16fe3..0000000 --- a/auth/middleware.go +++ /dev/null @@ -1,31 +0,0 @@ -package auth - -import ( - "net/http" -) - -type BasicAuthProvider interface { - Enabled() bool - DoesBasicAuthMatch(username, password string) bool -} - -func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc { - if basicAuthProvider.Enabled() { - return func(w http.ResponseWriter, r *http.Request) { - if doesBasicAuthMatch(r, basicAuthProvider) { - handlerFunc.ServeHTTP(w, r) - } else { - w.WriteHeader(http.StatusUnauthorized) - } - } - } - return handlerFunc -} - -func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool { - rawUsername, rawPassword, ok := r.BasicAuth() - if ok { - return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword) - } - return false -} diff --git a/auth/middleware_test.go b/auth/middleware_test.go deleted file mode 100644 index 27c8ab6..0000000 --- a/auth/middleware_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package auth - -import ( - "net/http" - "net/http/httptest" - "testing" -) - -type testAuthProvider struct { - enabled bool - match bool -} - -func (p testAuthProvider) Enabled() bool { - return p.enabled -} - -func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool { - return p.match -} - -func newTestRequest() *http.Request { - return httptest.NewRequest(http.MethodGet, "http://example.com", nil) -} - -func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) { - callCount := 0 - testHandler := func(w http.ResponseWriter, r *http.Request) { - callCount++ - } - - handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches}) - recorder := httptest.NewRecorder() - request := newTestRequest() - if authEnabled { - request.SetBasicAuth("test", "test") - } - handler.ServeHTTP(recorder, request) - - if recorder.Code != expectedCode { - t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode) - } - if callCount != expectedCallCount { - t.Errorf("callCount = %v, want %v", callCount, expectedCallCount) - } -} - -func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) { - executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1) -} - -func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) { - executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1) -} - -func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) { - executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0) -} diff --git a/auth/provider.go b/auth/provider.go new file mode 100644 index 0000000..a259ea0 --- /dev/null +++ b/auth/provider.go @@ -0,0 +1,9 @@ +package auth + +import ( + "net/http" +) + +type AuthProvider interface { + IsAllowed(*http.Request) bool +} diff --git a/cfg/basicAuth.go b/cfg/basicAuth.go deleted file mode 100644 index 991861e..0000000 --- a/cfg/basicAuth.go +++ /dev/null @@ -1,25 +0,0 @@ -package cfg - -import "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth" - -type hashedBasicAuth struct { - username string - password string -} - -func newHashedBasicAuth(rawUsername, rawPassword string) *hashedBasicAuth { - return &hashedBasicAuth{ - username: auth.HashString(rawUsername), - password: auth.HashString(rawPassword), - } -} - -func (p *hashedBasicAuth) Enabled() bool { - return len(p.username) > 0 && len(p.password) > 0 -} - -func (p *hashedBasicAuth) DoesBasicAuthMatch(rawUsername, rawPassword string) bool { - username := auth.HashString(rawUsername) - password := auth.HashString(rawPassword) - return username == p.username && password == p.password -} diff --git a/cfg/basicAuth_test.go b/cfg/basicAuth_test.go deleted file mode 100644 index 85bd8b3..0000000 --- a/cfg/basicAuth_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package cfg - -import "testing" - -func Test_hashedBasicAuth_DoesBasicAuthMatch(t *testing.T) { - type args struct { - username string - password string - } - type fields struct { - username string - password string - } - tests := []struct { - name string - fields fields - args args - want bool - }{ - {"Happy test #1", fields{username: "1234", password: "1234"}, args{username: "1234", password: "1234"}, true}, - {"Happy test #2", fields{username: "test", password: "1234"}, args{username: "test", password: "1234"}, true}, - {"Happy test #3", fields{username: "TEST", password: "1234"}, args{username: "TEST", password: "1234"}, true}, - {"Non match #1", fields{username: "test", password: "1234"}, args{username: "1234", password: "1234"}, false}, - {"Non match #2", fields{username: "1234", password: "test"}, args{username: "1234", password: "1234"}, false}, - {"Non match #3", fields{username: "1234", password: "test"}, args{username: "1234", password: "TEST"}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password) - if got := basicAuth.DoesBasicAuthMatch(tt.args.username, tt.args.password); got != tt.want { - t.Errorf("DoesBasicAuthMatch() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_hashedBasicAuth_Enabled(t *testing.T) { - type fields struct { - username string - password string - } - tests := []struct { - name string - fields fields - want bool - }{ - {"Both blank", fields{username: "", password: ""}, false}, - {"Single blank #1", fields{username: "test", password: ""}, false}, - {"Single blank #1", fields{username: "", password: "test"}, false}, - {"Both populated", fields{username: "test", password: "test"}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password) - if got := basicAuth.Enabled(); got != tt.want { - t.Errorf("Enabled() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/cfg/cfg.go b/cfg/cfg.go index cee729e..9628632 100644 --- a/cfg/cfg.go +++ b/cfg/cfg.go @@ -2,9 +2,11 @@ package cfg import ( "fmt" + "log" "os" "github.com/alecthomas/kong" + "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth" ) var cliStruct struct { @@ -36,11 +38,22 @@ func Parse() *AppSettings { Fail2BanSocketPath: cliStruct.F2bSocketPath, FileCollectorPath: cliStruct.TextFileExporterPath, ExitOnSocketConnError: cliStruct.ExitOnSocketError, - BasicAuthProvider: newHashedBasicAuth(cliStruct.BasicAuthUser, cliStruct.BasicAuthPass), + AuthProvider: createAuthProvider(), } return settings } +func createAuthProvider() auth.AuthProvider { + username := cliStruct.BasicAuthUser + password := cliStruct.BasicAuthPass + + if len(username) == 0 && len(password) == 0 { + return auth.NewEmptyAuthProvider() + } + log.Print("basic auth enabled") + return auth.NewBasicAuthProvider(username, password) +} + func validateFlags(cliCtx *kong.Context) { var flagsValid = true var messages = []string{} diff --git a/cfg/settings.go b/cfg/settings.go index 160234e..519f8e0 100644 --- a/cfg/settings.go +++ b/cfg/settings.go @@ -1,10 +1,12 @@ package cfg +import "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth" + type AppSettings struct { VersionMode bool MetricsAddress string Fail2BanSocketPath string FileCollectorPath string - BasicAuthProvider *hashedBasicAuth + AuthProvider auth.AuthProvider ExitOnSocketConnError bool } diff --git a/exporter.go b/exporter.go index 449f5ac..6bc0126 100644 --- a/exporter.go +++ b/exporter.go @@ -2,17 +2,18 @@ package main import ( "fmt" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth" - "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/cfg" - "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/f2b" - "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/textfile" "log" "net/http" "os" "os/signal" "syscall" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/cfg" + "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/f2b" + "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/textfile" + "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/server" ) const ( @@ -66,17 +67,14 @@ func main() { textFileCollector := textfile.NewCollector(appSettings) prometheus.MustRegister(textFileCollector) - http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider)) - http.HandleFunc(metricsPath, auth.BasicAuthMiddleware( + http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.AuthProvider)) + http.HandleFunc(metricsPath, server.BasicAuthMiddleware( func(w http.ResponseWriter, r *http.Request) { metricHandler(w, r, textFileCollector) }, - appSettings.BasicAuthProvider, + appSettings.AuthProvider, )) log.Printf("metrics available at '%s'", metricsPath) - if appSettings.BasicAuthProvider.Enabled() { - log.Printf("basic auth enabled") - } svrErr := make(chan error) go func() { diff --git a/server/middleware.go b/server/middleware.go new file mode 100644 index 0000000..7593bbc --- /dev/null +++ b/server/middleware.go @@ -0,0 +1,17 @@ +package server + +import ( + "net/http" + + "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth" +) + +func BasicAuthMiddleware(handlerFunc http.HandlerFunc, authProvider auth.AuthProvider) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if authProvider.IsAllowed(r) { + handlerFunc.ServeHTTP(w, r) + } else { + w.WriteHeader(http.StatusUnauthorized) + } + } +} diff --git a/server/middleware_test.go b/server/middleware_test.go new file mode 100644 index 0000000..3ac564b --- /dev/null +++ b/server/middleware_test.go @@ -0,0 +1,46 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +type testAuthProvider struct { + match bool +} + +func (p testAuthProvider) IsAllowed(request *http.Request) bool { + return p.match +} + +func newTestRequest() *http.Request { + return httptest.NewRequest(http.MethodGet, "http://example.com", nil) +} + +func executeBasicAuthMiddlewareTest(t *testing.T, authMatches bool, expectedCode int, expectedCallCount int) { + callCount := 0 + testHandler := func(w http.ResponseWriter, r *http.Request) { + callCount++ + } + + handler := BasicAuthMiddleware(testHandler, testAuthProvider{match: authMatches}) + recorder := httptest.NewRecorder() + request := newTestRequest() + handler.ServeHTTP(recorder, request) + + if recorder.Code != expectedCode { + t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode) + } + if callCount != expectedCallCount { + t.Errorf("callCount = %v, want %v", callCount, expectedCallCount) + } +} + +func Test_GIVEN_MatchingBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) { + executeBasicAuthMiddlewareTest(t, true, http.StatusOK, 1) +} + +func Test_GIVEN_NonMatchingBasicAuth_WHEN_MethodCalled_THEN_RequestRejected(t *testing.T) { + executeBasicAuthMiddlewareTest(t, false, http.StatusUnauthorized, 0) +}