From 6f76a03118afbf90d35a16e8ee3c863f0718803a Mon Sep 17 00:00:00 2001 From: Hector Date: Fri, 14 Jan 2022 21:36:49 +0000 Subject: [PATCH] feat: add support for basic auth (#16) Add new CLI parameters to enable protecting the API endpoints with basic auth authentication. Wrap the server endpoints in a new auth middleware that protects it using the provided basic auth credentials (if set). Store the provided basic auth credentials as hashed values to prevent them from being accidentally leaked. Add unit tests to ensure the new functionality works as expected. --- README.md | 16 ++++++---- src/auth/hash.go | 18 +++++++++++ src/auth/hash_test.go | 26 ++++++++++++++++ src/auth/middleware.go | 31 +++++++++++++++++++ src/auth/middleware_test.go | 58 +++++++++++++++++++++++++++++++++++ src/cfg/basicAuth.go | 25 ++++++++++++++++ src/cfg/basicAuth_test.go | 60 +++++++++++++++++++++++++++++++++++++ src/cfg/cfg.go | 15 ++++++++++ src/exporter.go | 12 +++++--- 9 files changed, 251 insertions(+), 10 deletions(-) create mode 100644 src/auth/hash.go create mode 100644 src/auth/hash_test.go create mode 100644 src/auth/middleware.go create mode 100644 src/auth/middleware_test.go create mode 100644 src/cfg/basicAuth.go create mode 100644 src/cfg/basicAuth_test.go diff --git a/README.md b/README.md index fb56af8..02b3676 100644 --- a/README.md +++ b/README.md @@ -39,18 +39,22 @@ See the [releases page](https://gitlab.com/hectorjsmith/fail2ban-prometheus-expo ``` $ fail2ban-prometheus-exporter -h - -web.listen-address string - address to use for metrics server (default 0.0.0.0) + -collector.textfile + enable the textfile collector + -collector.textfile.directory string + directory to read text files with metrics from -port int port to use for the metrics server (default 9191) -socket string path to the fail2ban server socket -version show version info and exit - -collector.textfile - enable the textfile collector - -collector.textfile.directory string - directory to read text files with metrics from + -web.basic-auth.password string + password to use to protect endpoints with basic auth + -web.basic-auth.username string + username to use to protect endpoints with basic auth + -web.listen-address string + address to use for the metrics server (default "0.0.0.0") ``` **Example** diff --git a/src/auth/hash.go b/src/auth/hash.go new file mode 100644 index 0000000..e1b4b3b --- /dev/null +++ b/src/auth/hash.go @@ -0,0 +1,18 @@ +package auth + +import ( + "crypto/sha256" + "encoding/hex" +) + +func Hash(data []byte) []byte { + if len(data) == 0 { + return []byte{} + } + b := sha256.Sum256(data) + return b[:] +} + +func HashString(data string) string { + return hex.EncodeToString(Hash([]byte(data))) +} diff --git a/src/auth/hash_test.go b/src/auth/hash_test.go new file mode 100644 index 0000000..ffe4bea --- /dev/null +++ b/src/auth/hash_test.go @@ -0,0 +1,26 @@ +package auth + +import ( + "reflect" + "testing" +) + +func TestHashString(t *testing.T) { + tests := []struct { + name string + args string + want string + }{ + {"Happy path #1", "123", "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"}, + {"Happy path #2", "hello world", "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"}, + {"Happy path #3", "H3Ll0_W0RLD", "d58a27fe9a6e73a1d8a67189fb8acace047e7a1a795276a0056d3717ad61bd0e"}, + {"Blank string", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := HashString(tt.args); !reflect.DeepEqual(got, tt.want) { + t.Errorf("HashString() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/src/auth/middleware.go b/src/auth/middleware.go new file mode 100644 index 0000000..dd16fe3 --- /dev/null +++ b/src/auth/middleware.go @@ -0,0 +1,31 @@ +package auth + +import ( + "net/http" +) + +type BasicAuthProvider interface { + Enabled() bool + DoesBasicAuthMatch(username, password string) bool +} + +func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc { + if basicAuthProvider.Enabled() { + return func(w http.ResponseWriter, r *http.Request) { + if doesBasicAuthMatch(r, basicAuthProvider) { + handlerFunc.ServeHTTP(w, r) + } else { + w.WriteHeader(http.StatusUnauthorized) + } + } + } + return handlerFunc +} + +func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool { + rawUsername, rawPassword, ok := r.BasicAuth() + if ok { + return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword) + } + return false +} diff --git a/src/auth/middleware_test.go b/src/auth/middleware_test.go new file mode 100644 index 0000000..27c8ab6 --- /dev/null +++ b/src/auth/middleware_test.go @@ -0,0 +1,58 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +type testAuthProvider struct { + enabled bool + match bool +} + +func (p testAuthProvider) Enabled() bool { + return p.enabled +} + +func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool { + return p.match +} + +func newTestRequest() *http.Request { + return httptest.NewRequest(http.MethodGet, "http://example.com", nil) +} + +func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) { + callCount := 0 + testHandler := func(w http.ResponseWriter, r *http.Request) { + callCount++ + } + + handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches}) + recorder := httptest.NewRecorder() + request := newTestRequest() + if authEnabled { + request.SetBasicAuth("test", "test") + } + handler.ServeHTTP(recorder, request) + + if recorder.Code != expectedCode { + t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode) + } + if callCount != expectedCallCount { + t.Errorf("callCount = %v, want %v", callCount, expectedCallCount) + } +} + +func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) { + executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1) +} + +func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) { + executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1) +} + +func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) { + executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0) +} diff --git a/src/cfg/basicAuth.go b/src/cfg/basicAuth.go new file mode 100644 index 0000000..bc5408f --- /dev/null +++ b/src/cfg/basicAuth.go @@ -0,0 +1,25 @@ +package cfg + +import "fail2ban-prometheus-exporter/auth" + +type hashedBasicAuth struct { + username string + password string +} + +func newHashedBasicAuth(rawUsername, rawPassword string) *hashedBasicAuth { + return &hashedBasicAuth{ + username: auth.HashString(rawUsername), + password: auth.HashString(rawPassword), + } +} + +func (p *hashedBasicAuth) Enabled() bool { + return len(p.username) > 0 && len(p.password) > 0 +} + +func (p *hashedBasicAuth) DoesBasicAuthMatch(rawUsername, rawPassword string) bool { + username := auth.HashString(rawUsername) + password := auth.HashString(rawPassword) + return username == p.username && password == p.password +} diff --git a/src/cfg/basicAuth_test.go b/src/cfg/basicAuth_test.go new file mode 100644 index 0000000..85bd8b3 --- /dev/null +++ b/src/cfg/basicAuth_test.go @@ -0,0 +1,60 @@ +package cfg + +import "testing" + +func Test_hashedBasicAuth_DoesBasicAuthMatch(t *testing.T) { + type args struct { + username string + password string + } + type fields struct { + username string + password string + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + {"Happy test #1", fields{username: "1234", password: "1234"}, args{username: "1234", password: "1234"}, true}, + {"Happy test #2", fields{username: "test", password: "1234"}, args{username: "test", password: "1234"}, true}, + {"Happy test #3", fields{username: "TEST", password: "1234"}, args{username: "TEST", password: "1234"}, true}, + {"Non match #1", fields{username: "test", password: "1234"}, args{username: "1234", password: "1234"}, false}, + {"Non match #2", fields{username: "1234", password: "test"}, args{username: "1234", password: "1234"}, false}, + {"Non match #3", fields{username: "1234", password: "test"}, args{username: "1234", password: "TEST"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password) + if got := basicAuth.DoesBasicAuthMatch(tt.args.username, tt.args.password); got != tt.want { + t.Errorf("DoesBasicAuthMatch() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_hashedBasicAuth_Enabled(t *testing.T) { + type fields struct { + username string + password string + } + tests := []struct { + name string + fields fields + want bool + }{ + {"Both blank", fields{username: "", password: ""}, false}, + {"Single blank #1", fields{username: "test", password: ""}, false}, + {"Single blank #1", fields{username: "", password: "test"}, false}, + {"Both populated", fields{username: "test", password: "test"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password) + if got := basicAuth.Enabled(); got != tt.want { + t.Errorf("Enabled() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/src/cfg/cfg.go b/src/cfg/cfg.go index 9e13dd2..008fef1 100644 --- a/src/cfg/cfg.go +++ b/src/cfg/cfg.go @@ -18,9 +18,13 @@ type AppSettings struct { Fail2BanSocketPath string FileCollectorPath string FileCollectorEnabled bool + BasicAuthProvider *hashedBasicAuth } func Parse() *AppSettings { + var rawBasicAuthUsername string + var rawBasicAuthPassword string + appSettings := &AppSettings{} flag.BoolVar(&appSettings.VersionMode, "version", false, "show version info and exit") flag.StringVar(&appSettings.MetricsAddress, "web.listen-address", "0.0.0.0", "address to use for the metrics server") @@ -28,12 +32,19 @@ func Parse() *AppSettings { flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket") flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector") flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from") + flag.StringVar(&rawBasicAuthUsername, "web.basic-auth.username", "", "username to use to protect endpoints with basic auth") + flag.StringVar(&rawBasicAuthPassword, "web.basic-auth.password", "", "password to use to protect endpoints with basic auth") flag.Parse() + appSettings.setBasicAuthValues(rawBasicAuthUsername, rawBasicAuthPassword) appSettings.validateFlags() return appSettings } +func (settings *AppSettings) setBasicAuthValues(rawUsername, rawPassword string) { + settings.BasicAuthProvider = newHashedBasicAuth(rawUsername, rawPassword) +} + func (settings *AppSettings) validateFlags() { var flagsValid = true if !settings.VersionMode { @@ -50,6 +61,10 @@ func (settings *AppSettings) validateFlags() { fmt.Printf("file collector directory path must not be empty if collector enabled\n") flagsValid = false } + if (len(settings.BasicAuthProvider.username) > 0) != (len(settings.BasicAuthProvider.password) > 0) { + fmt.Printf("to enable basic auth both the username and the password must be provided") + flagsValid = false + } } if !flagsValid { flag.Usage() diff --git a/src/exporter.go b/src/exporter.go index a62271c..46eb3cc 100644 --- a/src/exporter.go +++ b/src/exporter.go @@ -1,6 +1,7 @@ package main import ( + "fail2ban-prometheus-exporter/auth" "fail2ban-prometheus-exporter/cfg" "fail2ban-prometheus-exporter/collector/f2b" "fail2ban-prometheus-exporter/collector/textfile" @@ -63,10 +64,13 @@ func main() { textFileCollector := textfile.NewCollector(appSettings) prometheus.MustRegister(textFileCollector) - http.HandleFunc("/", rootHtmlHandler) - http.HandleFunc(metricsPath, func(w http.ResponseWriter, r *http.Request) { - metricHandler(w, r, textFileCollector) - }) + http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider)) + http.HandleFunc(metricsPath, auth.BasicAuthMiddleware( + func(w http.ResponseWriter, r *http.Request) { + metricHandler(w, r, textFileCollector) + }, + appSettings.BasicAuthProvider, + )) log.Printf("metrics available at '%s'", metricsPath) svrErr := make(chan error)