diff --git a/README.md b/README.md index fb56af8..02b3676 100644 --- a/README.md +++ b/README.md @@ -39,18 +39,22 @@ See the [releases page](https://gitlab.com/hectorjsmith/fail2ban-prometheus-expo ``` $ fail2ban-prometheus-exporter -h - -web.listen-address string - address to use for metrics server (default 0.0.0.0) + -collector.textfile + enable the textfile collector + -collector.textfile.directory string + directory to read text files with metrics from -port int port to use for the metrics server (default 9191) -socket string path to the fail2ban server socket -version show version info and exit - -collector.textfile - enable the textfile collector - -collector.textfile.directory string - directory to read text files with metrics from + -web.basic-auth.password string + password to use to protect endpoints with basic auth + -web.basic-auth.username string + username to use to protect endpoints with basic auth + -web.listen-address string + address to use for the metrics server (default "0.0.0.0") ``` **Example** diff --git a/src/auth/hash.go b/src/auth/hash.go new file mode 100644 index 0000000..e1b4b3b --- /dev/null +++ b/src/auth/hash.go @@ -0,0 +1,18 @@ +package auth + +import ( + "crypto/sha256" + "encoding/hex" +) + +func Hash(data []byte) []byte { + if len(data) == 0 { + return []byte{} + } + b := sha256.Sum256(data) + return b[:] +} + +func HashString(data string) string { + return hex.EncodeToString(Hash([]byte(data))) +} diff --git a/src/auth/hash_test.go b/src/auth/hash_test.go new file mode 100644 index 0000000..ffe4bea --- /dev/null +++ b/src/auth/hash_test.go @@ -0,0 +1,26 @@ +package auth + +import ( + "reflect" + "testing" +) + +func TestHashString(t *testing.T) { + tests := []struct { + name string + args string + want string + }{ + {"Happy path #1", "123", "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"}, + {"Happy path #2", "hello world", "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"}, + {"Happy path #3", "H3Ll0_W0RLD", "d58a27fe9a6e73a1d8a67189fb8acace047e7a1a795276a0056d3717ad61bd0e"}, + {"Blank string", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := HashString(tt.args); !reflect.DeepEqual(got, tt.want) { + t.Errorf("HashString() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/src/auth/middleware.go b/src/auth/middleware.go new file mode 100644 index 0000000..dd16fe3 --- /dev/null +++ b/src/auth/middleware.go @@ -0,0 +1,31 @@ +package auth + +import ( + "net/http" +) + +type BasicAuthProvider interface { + Enabled() bool + DoesBasicAuthMatch(username, password string) bool +} + +func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc { + if basicAuthProvider.Enabled() { + return func(w http.ResponseWriter, r *http.Request) { + if doesBasicAuthMatch(r, basicAuthProvider) { + handlerFunc.ServeHTTP(w, r) + } else { + w.WriteHeader(http.StatusUnauthorized) + } + } + } + return handlerFunc +} + +func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool { + rawUsername, rawPassword, ok := r.BasicAuth() + if ok { + return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword) + } + return false +} diff --git a/src/auth/middleware_test.go b/src/auth/middleware_test.go new file mode 100644 index 0000000..27c8ab6 --- /dev/null +++ b/src/auth/middleware_test.go @@ -0,0 +1,58 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +type testAuthProvider struct { + enabled bool + match bool +} + +func (p testAuthProvider) Enabled() bool { + return p.enabled +} + +func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool { + return p.match +} + +func newTestRequest() *http.Request { + return httptest.NewRequest(http.MethodGet, "http://example.com", nil) +} + +func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) { + callCount := 0 + testHandler := func(w http.ResponseWriter, r *http.Request) { + callCount++ + } + + handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches}) + recorder := httptest.NewRecorder() + request := newTestRequest() + if authEnabled { + request.SetBasicAuth("test", "test") + } + handler.ServeHTTP(recorder, request) + + if recorder.Code != expectedCode { + t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode) + } + if callCount != expectedCallCount { + t.Errorf("callCount = %v, want %v", callCount, expectedCallCount) + } +} + +func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) { + executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1) +} + +func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) { + executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1) +} + +func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) { + executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0) +} diff --git a/src/cfg/basicAuth.go b/src/cfg/basicAuth.go new file mode 100644 index 0000000..bc5408f --- /dev/null +++ b/src/cfg/basicAuth.go @@ -0,0 +1,25 @@ +package cfg + +import "fail2ban-prometheus-exporter/auth" + +type hashedBasicAuth struct { + username string + password string +} + +func newHashedBasicAuth(rawUsername, rawPassword string) *hashedBasicAuth { + return &hashedBasicAuth{ + username: auth.HashString(rawUsername), + password: auth.HashString(rawPassword), + } +} + +func (p *hashedBasicAuth) Enabled() bool { + return len(p.username) > 0 && len(p.password) > 0 +} + +func (p *hashedBasicAuth) DoesBasicAuthMatch(rawUsername, rawPassword string) bool { + username := auth.HashString(rawUsername) + password := auth.HashString(rawPassword) + return username == p.username && password == p.password +} diff --git a/src/cfg/basicAuth_test.go b/src/cfg/basicAuth_test.go new file mode 100644 index 0000000..85bd8b3 --- /dev/null +++ b/src/cfg/basicAuth_test.go @@ -0,0 +1,60 @@ +package cfg + +import "testing" + +func Test_hashedBasicAuth_DoesBasicAuthMatch(t *testing.T) { + type args struct { + username string + password string + } + type fields struct { + username string + password string + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + {"Happy test #1", fields{username: "1234", password: "1234"}, args{username: "1234", password: "1234"}, true}, + {"Happy test #2", fields{username: "test", password: "1234"}, args{username: "test", password: "1234"}, true}, + {"Happy test #3", fields{username: "TEST", password: "1234"}, args{username: "TEST", password: "1234"}, true}, + {"Non match #1", fields{username: "test", password: "1234"}, args{username: "1234", password: "1234"}, false}, + {"Non match #2", fields{username: "1234", password: "test"}, args{username: "1234", password: "1234"}, false}, + {"Non match #3", fields{username: "1234", password: "test"}, args{username: "1234", password: "TEST"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password) + if got := basicAuth.DoesBasicAuthMatch(tt.args.username, tt.args.password); got != tt.want { + t.Errorf("DoesBasicAuthMatch() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_hashedBasicAuth_Enabled(t *testing.T) { + type fields struct { + username string + password string + } + tests := []struct { + name string + fields fields + want bool + }{ + {"Both blank", fields{username: "", password: ""}, false}, + {"Single blank #1", fields{username: "test", password: ""}, false}, + {"Single blank #1", fields{username: "", password: "test"}, false}, + {"Both populated", fields{username: "test", password: "test"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password) + if got := basicAuth.Enabled(); got != tt.want { + t.Errorf("Enabled() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/src/cfg/cfg.go b/src/cfg/cfg.go index 9e13dd2..008fef1 100644 --- a/src/cfg/cfg.go +++ b/src/cfg/cfg.go @@ -18,9 +18,13 @@ type AppSettings struct { Fail2BanSocketPath string FileCollectorPath string FileCollectorEnabled bool + BasicAuthProvider *hashedBasicAuth } func Parse() *AppSettings { + var rawBasicAuthUsername string + var rawBasicAuthPassword string + appSettings := &AppSettings{} flag.BoolVar(&appSettings.VersionMode, "version", false, "show version info and exit") flag.StringVar(&appSettings.MetricsAddress, "web.listen-address", "0.0.0.0", "address to use for the metrics server") @@ -28,12 +32,19 @@ func Parse() *AppSettings { flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket") flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector") flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from") + flag.StringVar(&rawBasicAuthUsername, "web.basic-auth.username", "", "username to use to protect endpoints with basic auth") + flag.StringVar(&rawBasicAuthPassword, "web.basic-auth.password", "", "password to use to protect endpoints with basic auth") flag.Parse() + appSettings.setBasicAuthValues(rawBasicAuthUsername, rawBasicAuthPassword) appSettings.validateFlags() return appSettings } +func (settings *AppSettings) setBasicAuthValues(rawUsername, rawPassword string) { + settings.BasicAuthProvider = newHashedBasicAuth(rawUsername, rawPassword) +} + func (settings *AppSettings) validateFlags() { var flagsValid = true if !settings.VersionMode { @@ -50,6 +61,10 @@ func (settings *AppSettings) validateFlags() { fmt.Printf("file collector directory path must not be empty if collector enabled\n") flagsValid = false } + if (len(settings.BasicAuthProvider.username) > 0) != (len(settings.BasicAuthProvider.password) > 0) { + fmt.Printf("to enable basic auth both the username and the password must be provided") + flagsValid = false + } } if !flagsValid { flag.Usage() diff --git a/src/exporter.go b/src/exporter.go index a62271c..46eb3cc 100644 --- a/src/exporter.go +++ b/src/exporter.go @@ -1,6 +1,7 @@ package main import ( + "fail2ban-prometheus-exporter/auth" "fail2ban-prometheus-exporter/cfg" "fail2ban-prometheus-exporter/collector/f2b" "fail2ban-prometheus-exporter/collector/textfile" @@ -63,10 +64,13 @@ func main() { textFileCollector := textfile.NewCollector(appSettings) prometheus.MustRegister(textFileCollector) - http.HandleFunc("/", rootHtmlHandler) - http.HandleFunc(metricsPath, func(w http.ResponseWriter, r *http.Request) { - metricHandler(w, r, textFileCollector) - }) + http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider)) + http.HandleFunc(metricsPath, auth.BasicAuthMiddleware( + func(w http.ResponseWriter, r *http.Request) { + metricHandler(w, r, textFileCollector) + }, + appSettings.BasicAuthProvider, + )) log.Printf("metrics available at '%s'", metricsPath) svrErr := make(chan error)