diff --git a/auth/basic.go b/auth/basic.go new file mode 100644 index 0000000..243b2a5 --- /dev/null +++ b/auth/basic.go @@ -0,0 +1,29 @@ +package auth + +import ( + "fmt" + "net/http" +) + +func NewBasicAuthProvider(username, password string) AuthProvider { + return &basicAuthProvider{ + hashedAuth: encodeBasicAuth(username, password), + } +} + +type basicAuthProvider struct { + hashedAuth string +} + +func (p *basicAuthProvider) IsAllowed(request *http.Request) bool { + username, password, ok := request.BasicAuth() + if !ok { + return false + } + requestAuth := encodeBasicAuth(username, password) + return p.hashedAuth == requestAuth +} + +func encodeBasicAuth(username, password string) string { + return HashString(fmt.Sprintf("%s:%s", username, password)) +} diff --git a/auth/hash.go b/auth/hash.go index e1b4b3b..cc41e36 100644 --- a/auth/hash.go +++ b/auth/hash.go @@ -5,7 +5,7 @@ import ( "encoding/hex" ) -func Hash(data []byte) []byte { +func hash(data []byte) []byte { if len(data) == 0 { return []byte{} } @@ -14,5 +14,5 @@ func Hash(data []byte) []byte { } func HashString(data string) string { - return hex.EncodeToString(Hash([]byte(data))) + return hex.EncodeToString(hash([]byte(data))) } diff --git a/auth/provider.go b/auth/provider.go new file mode 100644 index 0000000..4d8167e --- /dev/null +++ b/auth/provider.go @@ -0,0 +1,34 @@ +package auth + +import ( + "net/http" +) + +type AuthProvider interface { + IsAllowed(*http.Request) bool +} + +func NewEmptyAuthProvider() AuthProvider { + return &emptyAuthProvider{} +} + +type emptyAuthProvider struct { +} + +func (p *emptyAuthProvider) IsAllowed(request *http.Request) bool { + return true +} + +type compositeAuthProvider struct { + providers []AuthProvider +} + +func (p *compositeAuthProvider) IsAllowed(request *http.Request) bool { + for i := 0; i < len(p.providers); i++ { + provider := p.providers[i] + if provider.IsAllowed(request) { + return true + } + } + return false +} diff --git a/cfg/cfg.go b/cfg/cfg.go index cee729e..9628632 100644 --- a/cfg/cfg.go +++ b/cfg/cfg.go @@ -2,9 +2,11 @@ package cfg import ( "fmt" + "log" "os" "github.com/alecthomas/kong" + "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth" ) var cliStruct struct { @@ -36,11 +38,22 @@ func Parse() *AppSettings { Fail2BanSocketPath: cliStruct.F2bSocketPath, FileCollectorPath: cliStruct.TextFileExporterPath, ExitOnSocketConnError: cliStruct.ExitOnSocketError, - BasicAuthProvider: newHashedBasicAuth(cliStruct.BasicAuthUser, cliStruct.BasicAuthPass), + AuthProvider: createAuthProvider(), } return settings } +func createAuthProvider() auth.AuthProvider { + username := cliStruct.BasicAuthUser + password := cliStruct.BasicAuthPass + + if len(username) == 0 && len(password) == 0 { + return auth.NewEmptyAuthProvider() + } + log.Print("basic auth enabled") + return auth.NewBasicAuthProvider(username, password) +} + func validateFlags(cliCtx *kong.Context) { var flagsValid = true var messages = []string{} diff --git a/cfg/settings.go b/cfg/settings.go index 160234e..519f8e0 100644 --- a/cfg/settings.go +++ b/cfg/settings.go @@ -1,10 +1,12 @@ package cfg +import "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth" + type AppSettings struct { VersionMode bool MetricsAddress string Fail2BanSocketPath string FileCollectorPath string - BasicAuthProvider *hashedBasicAuth + AuthProvider auth.AuthProvider ExitOnSocketConnError bool } diff --git a/exporter.go b/exporter.go index 65857ba..6bc0126 100644 --- a/exporter.go +++ b/exporter.go @@ -67,17 +67,14 @@ func main() { textFileCollector := textfile.NewCollector(appSettings) prometheus.MustRegister(textFileCollector) - http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider)) + http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.AuthProvider)) http.HandleFunc(metricsPath, server.BasicAuthMiddleware( func(w http.ResponseWriter, r *http.Request) { metricHandler(w, r, textFileCollector) }, - appSettings.BasicAuthProvider, + appSettings.AuthProvider, )) log.Printf("metrics available at '%s'", metricsPath) - if appSettings.BasicAuthProvider.Enabled() { - log.Printf("basic auth enabled") - } svrErr := make(chan error) go func() { diff --git a/server/middleware.go b/server/middleware.go index 92dec38..7593bbc 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -2,30 +2,16 @@ package server import ( "net/http" + + "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth" ) -type BasicAuthProvider interface { - Enabled() bool - DoesBasicAuthMatch(username, password string) bool -} - -func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc { - if basicAuthProvider.Enabled() { - return func(w http.ResponseWriter, r *http.Request) { - if doesBasicAuthMatch(r, basicAuthProvider) { - handlerFunc.ServeHTTP(w, r) - } else { - w.WriteHeader(http.StatusUnauthorized) - } +func BasicAuthMiddleware(handlerFunc http.HandlerFunc, authProvider auth.AuthProvider) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if authProvider.IsAllowed(r) { + handlerFunc.ServeHTTP(w, r) + } else { + w.WriteHeader(http.StatusUnauthorized) } } - return handlerFunc -} - -func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool { - rawUsername, rawPassword, ok := r.BasicAuth() - if ok { - return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword) - } - return false } diff --git a/server/middleware_test.go b/server/middleware_test.go index 7c7e428..3ac564b 100644 --- a/server/middleware_test.go +++ b/server/middleware_test.go @@ -7,15 +7,10 @@ import ( ) type testAuthProvider struct { - enabled bool - match bool + match bool } -func (p testAuthProvider) Enabled() bool { - return p.enabled -} - -func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool { +func (p testAuthProvider) IsAllowed(request *http.Request) bool { return p.match } @@ -23,18 +18,15 @@ func newTestRequest() *http.Request { return httptest.NewRequest(http.MethodGet, "http://example.com", nil) } -func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) { +func executeBasicAuthMiddlewareTest(t *testing.T, authMatches bool, expectedCode int, expectedCallCount int) { callCount := 0 testHandler := func(w http.ResponseWriter, r *http.Request) { callCount++ } - handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches}) + handler := BasicAuthMiddleware(testHandler, testAuthProvider{match: authMatches}) recorder := httptest.NewRecorder() request := newTestRequest() - if authEnabled { - request.SetBasicAuth("test", "test") - } handler.ServeHTTP(recorder, request) if recorder.Code != expectedCode { @@ -45,14 +37,10 @@ func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches } } -func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) { - executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1) +func Test_GIVEN_MatchingBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) { + executeBasicAuthMiddlewareTest(t, true, http.StatusOK, 1) } -func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) { - executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1) -} - -func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) { - executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0) +func Test_GIVEN_NonMatchingBasicAuth_WHEN_MethodCalled_THEN_RequestRejected(t *testing.T) { + executeBasicAuthMiddlewareTest(t, false, http.StatusUnauthorized, 0) }