From e3d8c1e0e59117ac35202fceaf561470952b35b3 Mon Sep 17 00:00:00 2001
From: Hector <hector@hjs.dev>
Date: Tue, 11 Jan 2022 21:53:43 +0000
Subject: [PATCH 1/7] feat: add support for basic auth (#16)

Add new CLI parameters to enable protecting the API endpoints with basic
auth authentication.
Wrap the server endpoints in a new authMiddleware that protects it using
the provided basic auth credentials (if set).
---
 src/cfg/cfg.go  |  4 ++++
 src/exporter.go | 26 ++++++++++++++++++++++----
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/src/cfg/cfg.go b/src/cfg/cfg.go
index 9e13dd2..4d7bc64 100644
--- a/src/cfg/cfg.go
+++ b/src/cfg/cfg.go
@@ -18,6 +18,8 @@ type AppSettings struct {
 	Fail2BanSocketPath   string
 	FileCollectorPath    string
 	FileCollectorEnabled bool
+	BasicAuthUsername    string
+	BasicAuthPassword    string
 }
 
 func Parse() *AppSettings {
@@ -28,6 +30,8 @@ func Parse() *AppSettings {
 	flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket")
 	flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector")
 	flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from")
+	flag.StringVar(&appSettings.BasicAuthUsername, "web.basic-auth.username", "", "set username for basic auth")
+	flag.StringVar(&appSettings.BasicAuthPassword, "web.basic-auth.password", "", "set password for basic auth")
 
 	flag.Parse()
 	appSettings.validateFlags()
diff --git a/src/exporter.go b/src/exporter.go
index a62271c..843b861 100644
--- a/src/exporter.go
+++ b/src/exporter.go
@@ -48,6 +48,21 @@ func metricHandler(w http.ResponseWriter, r *http.Request, collector *textfile.C
 	collector.WriteTextFileMetrics(w, r)
 }
 
+func authMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSettings) http.HandlerFunc {
+	authEnabled := len(appSettings.BasicAuthUsername) > 0 && len(appSettings.BasicAuthPassword) > 0
+	if authEnabled {
+		return func(w http.ResponseWriter, r *http.Request) {
+			username, password, ok := r.BasicAuth()
+			if !ok || username != appSettings.BasicAuthUsername || password != appSettings.BasicAuthPassword {
+				w.WriteHeader(http.StatusUnauthorized)
+				return
+			}
+			handlerFunc.ServeHTTP(w, r)
+		}
+	}
+	return handlerFunc
+}
+
 func main() {
 	appSettings := cfg.Parse()
 	if appSettings.VersionMode {
@@ -63,10 +78,13 @@ func main() {
 		textFileCollector := textfile.NewCollector(appSettings)
 		prometheus.MustRegister(textFileCollector)
 
-		http.HandleFunc("/", rootHtmlHandler)
-		http.HandleFunc(metricsPath, func(w http.ResponseWriter, r *http.Request) {
-			metricHandler(w, r, textFileCollector)
-		})
+		http.HandleFunc("/", authMiddleware(rootHtmlHandler, appSettings))
+		http.HandleFunc(metricsPath, authMiddleware(
+			func(w http.ResponseWriter, r *http.Request) {
+				metricHandler(w, r, textFileCollector)
+			},
+			appSettings,
+		))
 		log.Printf("metrics available at '%s'", metricsPath)
 
 		svrErr := make(chan error)

From e176a3ea2245a40f3e1783f4cb87f904d905a77e Mon Sep 17 00:00:00 2001
From: Hector <hector@hjs.dev>
Date: Wed, 12 Jan 2022 22:05:27 +0000
Subject: [PATCH 2/7] check basic auth username and password set

Add check to ensure basic auth username and password are both set or both
unset. It isn't valid to set one without the other.
Update README file to include the new CLI parameters.
---
 README.md      | 16 ++++++++++------
 src/cfg/cfg.go |  8 ++++++--
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index fb56af8..02b3676 100644
--- a/README.md
+++ b/README.md
@@ -39,18 +39,22 @@ See the [releases page](https://gitlab.com/hectorjsmith/fail2ban-prometheus-expo
 ```
 $ fail2ban-prometheus-exporter -h
 
-  -web.listen-address string
-        address to use for metrics server (default 0.0.0.0)
+  -collector.textfile
+        enable the textfile collector
+  -collector.textfile.directory string
+        directory to read text files with metrics from
   -port int
         port to use for the metrics server (default 9191)
   -socket string
         path to the fail2ban server socket
   -version
         show version info and exit
-  -collector.textfile
-        enable the textfile collector
-  -collector.textfile.directory string
-        directory to read text files with metrics from
+  -web.basic-auth.password string
+        password to use to protect endpoints with basic auth
+  -web.basic-auth.username string
+        username to use to protect endpoints with basic auth
+  -web.listen-address string
+        address to use for the metrics server (default "0.0.0.0")
 ```
 
 **Example**
diff --git a/src/cfg/cfg.go b/src/cfg/cfg.go
index 4d7bc64..85d291c 100644
--- a/src/cfg/cfg.go
+++ b/src/cfg/cfg.go
@@ -30,8 +30,8 @@ func Parse() *AppSettings {
 	flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket")
 	flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector")
 	flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from")
-	flag.StringVar(&appSettings.BasicAuthUsername, "web.basic-auth.username", "", "set username for basic auth")
-	flag.StringVar(&appSettings.BasicAuthPassword, "web.basic-auth.password", "", "set password for basic auth")
+	flag.StringVar(&appSettings.BasicAuthUsername, "web.basic-auth.username", "", "username to use to protect endpoints with basic auth")
+	flag.StringVar(&appSettings.BasicAuthPassword, "web.basic-auth.password", "", "password to use to protect endpoints with basic auth")
 
 	flag.Parse()
 	appSettings.validateFlags()
@@ -54,6 +54,10 @@ func (settings *AppSettings) validateFlags() {
 			fmt.Printf("file collector directory path must not be empty if collector enabled\n")
 			flagsValid = false
 		}
+		if (len(settings.BasicAuthUsername) > 0) != (len(settings.BasicAuthPassword) > 0) {
+			fmt.Printf("to enable basic auth both the username and the password must be provided")
+			flagsValid = false
+		}
 	}
 	if !flagsValid {
 		flag.Usage()

From 26f0b0d0cee07f664d13705d010eb3bac35b825f Mon Sep 17 00:00:00 2001
From: Hector <hector@hjs.dev>
Date: Thu, 13 Jan 2022 20:34:42 +0000
Subject: [PATCH 3/7] store basic auth credentials as hash instead of raw value

Update how the basic auth credentials are stored in the application to use
a hashed value instead of the raw value. This prevents the raw value from
being accidentally leaked.
Add new `auth` package for functions related to authentication.
---
 src/auth/hash.go       | 14 ++++++++++++++
 src/auth/middleware.go | 30 ++++++++++++++++++++++++++++++
 src/cfg/cfg.go         | 14 ++++++++++++--
 src/exporter.go        | 20 +++-----------------
 4 files changed, 59 insertions(+), 19 deletions(-)
 create mode 100644 src/auth/hash.go
 create mode 100644 src/auth/middleware.go

diff --git a/src/auth/hash.go b/src/auth/hash.go
new file mode 100644
index 0000000..e1bf8e2
--- /dev/null
+++ b/src/auth/hash.go
@@ -0,0 +1,14 @@
+package auth
+
+import (
+	"crypto/sha256"
+)
+
+func Hash(data []byte) []byte {
+	b := sha256.Sum256(data)
+	return b[:]
+}
+
+func HashString(data string) string {
+	return string(Hash([]byte(data)))
+}
diff --git a/src/auth/middleware.go b/src/auth/middleware.go
new file mode 100644
index 0000000..f83c291
--- /dev/null
+++ b/src/auth/middleware.go
@@ -0,0 +1,30 @@
+package auth
+
+import (
+	"fail2ban-prometheus-exporter/cfg"
+	"net/http"
+)
+
+func BasicAuthMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSettings) http.HandlerFunc {
+	authEnabled := len(appSettings.BasicAuthUsername) > 0 && len(appSettings.BasicAuthPassword) > 0
+	if authEnabled {
+		return func(w http.ResponseWriter, r *http.Request) {
+			if doesBasicAuthMatch(r, appSettings) {
+				w.WriteHeader(http.StatusUnauthorized)
+			} else {
+				handlerFunc.ServeHTTP(w, r)
+			}
+		}
+	}
+	return handlerFunc
+}
+
+func doesBasicAuthMatch(r *http.Request, appSettings *cfg.AppSettings) bool {
+	rawUsername, rawPassword, ok := r.BasicAuth()
+	if ok {
+		username := HashString(rawUsername)
+		password := HashString(rawPassword)
+		return username == appSettings.BasicAuthUsername && password == appSettings.BasicAuthPassword
+	}
+	return false
+}
diff --git a/src/cfg/cfg.go b/src/cfg/cfg.go
index 85d291c..4b8da26 100644
--- a/src/cfg/cfg.go
+++ b/src/cfg/cfg.go
@@ -1,6 +1,7 @@
 package cfg
 
 import (
+	"fail2ban-prometheus-exporter/auth"
 	"flag"
 	"fmt"
 	"os"
@@ -23,6 +24,9 @@ type AppSettings struct {
 }
 
 func Parse() *AppSettings {
+	var rawBasicAuthUsername string
+	var rawBasicAuthPassword string
+
 	appSettings := &AppSettings{}
 	flag.BoolVar(&appSettings.VersionMode, "version", false, "show version info and exit")
 	flag.StringVar(&appSettings.MetricsAddress, "web.listen-address", "0.0.0.0", "address to use for the metrics server")
@@ -30,14 +34,20 @@ func Parse() *AppSettings {
 	flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket")
 	flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector")
 	flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from")
-	flag.StringVar(&appSettings.BasicAuthUsername, "web.basic-auth.username", "", "username to use to protect endpoints with basic auth")
-	flag.StringVar(&appSettings.BasicAuthPassword, "web.basic-auth.password", "", "password to use to protect endpoints with basic auth")
+	flag.StringVar(&rawBasicAuthUsername, "web.basic-auth.username", "", "username to use to protect endpoints with basic auth")
+	flag.StringVar(&rawBasicAuthPassword, "web.basic-auth.password", "", "password to use to protect endpoints with basic auth")
 
 	flag.Parse()
+	appSettings.setBasicAuthValues(rawBasicAuthUsername, rawBasicAuthPassword)
 	appSettings.validateFlags()
 	return appSettings
 }
 
+func (settings *AppSettings) setBasicAuthValues(username, password string) {
+	settings.BasicAuthUsername = auth.HashString(username)
+	settings.BasicAuthPassword = auth.HashString(password)
+}
+
 func (settings *AppSettings) validateFlags() {
 	var flagsValid = true
 	if !settings.VersionMode {
diff --git a/src/exporter.go b/src/exporter.go
index 843b861..743b577 100644
--- a/src/exporter.go
+++ b/src/exporter.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"fail2ban-prometheus-exporter/auth"
 	"fail2ban-prometheus-exporter/cfg"
 	"fail2ban-prometheus-exporter/collector/f2b"
 	"fail2ban-prometheus-exporter/collector/textfile"
@@ -48,21 +49,6 @@ func metricHandler(w http.ResponseWriter, r *http.Request, collector *textfile.C
 	collector.WriteTextFileMetrics(w, r)
 }
 
-func authMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSettings) http.HandlerFunc {
-	authEnabled := len(appSettings.BasicAuthUsername) > 0 && len(appSettings.BasicAuthPassword) > 0
-	if authEnabled {
-		return func(w http.ResponseWriter, r *http.Request) {
-			username, password, ok := r.BasicAuth()
-			if !ok || username != appSettings.BasicAuthUsername || password != appSettings.BasicAuthPassword {
-				w.WriteHeader(http.StatusUnauthorized)
-				return
-			}
-			handlerFunc.ServeHTTP(w, r)
-		}
-	}
-	return handlerFunc
-}
-
 func main() {
 	appSettings := cfg.Parse()
 	if appSettings.VersionMode {
@@ -78,8 +64,8 @@ func main() {
 		textFileCollector := textfile.NewCollector(appSettings)
 		prometheus.MustRegister(textFileCollector)
 
-		http.HandleFunc("/", authMiddleware(rootHtmlHandler, appSettings))
-		http.HandleFunc(metricsPath, authMiddleware(
+		http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings))
+		http.HandleFunc(metricsPath, auth.BasicAuthMiddleware(
 			func(w http.ResponseWriter, r *http.Request) {
 				metricHandler(w, r, textFileCollector)
 			},

From e78d4e72985a014ba65fed9536f10afa0078c8b4 Mon Sep 17 00:00:00 2001
From: Hector <hector@hjs.dev>
Date: Thu, 13 Jan 2022 21:25:37 +0000
Subject: [PATCH 4/7] add new basic auth provider class

Fix the circular dependency by adding a new BasicAuthProvider interface to
wrap the basic auth values.
Add unit tests for the hash functions.
Refactor functions to use the new basic auth provider interface.
---
 src/auth/hash.go       |  6 +++++-
 src/auth/hash_test.go  | 26 ++++++++++++++++++++++++++
 src/auth/middleware.go | 17 ++++++++++-------
 src/cfg/basicAuth.go   | 14 ++++++++++++++
 src/cfg/cfg.go         | 13 +++++++------
 src/exporter.go        |  4 ++--
 6 files changed, 64 insertions(+), 16 deletions(-)
 create mode 100644 src/auth/hash_test.go
 create mode 100644 src/cfg/basicAuth.go

diff --git a/src/auth/hash.go b/src/auth/hash.go
index e1bf8e2..e1b4b3b 100644
--- a/src/auth/hash.go
+++ b/src/auth/hash.go
@@ -2,13 +2,17 @@ package auth
 
 import (
 	"crypto/sha256"
+	"encoding/hex"
 )
 
 func Hash(data []byte) []byte {
+	if len(data) == 0 {
+		return []byte{}
+	}
 	b := sha256.Sum256(data)
 	return b[:]
 }
 
 func HashString(data string) string {
-	return string(Hash([]byte(data)))
+	return hex.EncodeToString(Hash([]byte(data)))
 }
diff --git a/src/auth/hash_test.go b/src/auth/hash_test.go
new file mode 100644
index 0000000..ffe4bea
--- /dev/null
+++ b/src/auth/hash_test.go
@@ -0,0 +1,26 @@
+package auth
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestHashString(t *testing.T) {
+	tests := []struct {
+		name string
+		args string
+		want string
+	}{
+		{"Happy path #1", "123", "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"},
+		{"Happy path #2", "hello world", "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"},
+		{"Happy path #3", "H3Ll0_W0RLD", "d58a27fe9a6e73a1d8a67189fb8acace047e7a1a795276a0056d3717ad61bd0e"},
+		{"Blank string", "", ""},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := HashString(tt.args); !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("HashString() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/src/auth/middleware.go b/src/auth/middleware.go
index f83c291..42e818a 100644
--- a/src/auth/middleware.go
+++ b/src/auth/middleware.go
@@ -1,15 +1,18 @@
 package auth
 
 import (
-	"fail2ban-prometheus-exporter/cfg"
 	"net/http"
 )
 
-func BasicAuthMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSettings) http.HandlerFunc {
-	authEnabled := len(appSettings.BasicAuthUsername) > 0 && len(appSettings.BasicAuthPassword) > 0
-	if authEnabled {
+type BasicAuthProvider interface {
+	Enabled() bool
+	DoesBasicAuthMatch(username, password string) bool
+}
+
+func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
+	if basicAuthProvider.Enabled() {
 		return func(w http.ResponseWriter, r *http.Request) {
-			if doesBasicAuthMatch(r, appSettings) {
+			if doesBasicAuthMatch(r, basicAuthProvider) {
 				w.WriteHeader(http.StatusUnauthorized)
 			} else {
 				handlerFunc.ServeHTTP(w, r)
@@ -19,12 +22,12 @@ func BasicAuthMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSetti
 	return handlerFunc
 }
 
-func doesBasicAuthMatch(r *http.Request, appSettings *cfg.AppSettings) bool {
+func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
 	rawUsername, rawPassword, ok := r.BasicAuth()
 	if ok {
 		username := HashString(rawUsername)
 		password := HashString(rawPassword)
-		return username == appSettings.BasicAuthUsername && password == appSettings.BasicAuthPassword
+		return basicAuthProvider.DoesBasicAuthMatch(username, password)
 	}
 	return false
 }
diff --git a/src/cfg/basicAuth.go b/src/cfg/basicAuth.go
new file mode 100644
index 0000000..c79dc8c
--- /dev/null
+++ b/src/cfg/basicAuth.go
@@ -0,0 +1,14 @@
+package cfg
+
+type basicAuth struct {
+	username string
+	password string
+}
+
+func (p *basicAuth) Enabled() bool {
+	return len(p.username) > 0 && len(p.password) > 0
+}
+
+func (p *basicAuth) DoesBasicAuthMatch(username, password string) bool {
+	return username == p.username && password == p.password
+}
diff --git a/src/cfg/cfg.go b/src/cfg/cfg.go
index 4b8da26..0b33b99 100644
--- a/src/cfg/cfg.go
+++ b/src/cfg/cfg.go
@@ -19,8 +19,7 @@ type AppSettings struct {
 	Fail2BanSocketPath   string
 	FileCollectorPath    string
 	FileCollectorEnabled bool
-	BasicAuthUsername    string
-	BasicAuthPassword    string
+	BasicAuthProvider    *basicAuth
 }
 
 func Parse() *AppSettings {
@@ -43,9 +42,11 @@ func Parse() *AppSettings {
 	return appSettings
 }
 
-func (settings *AppSettings) setBasicAuthValues(username, password string) {
-	settings.BasicAuthUsername = auth.HashString(username)
-	settings.BasicAuthPassword = auth.HashString(password)
+func (settings *AppSettings) setBasicAuthValues(rawUsername, rawPassword string) {
+	settings.BasicAuthProvider = &basicAuth{
+		username: auth.HashString(rawUsername),
+		password: auth.HashString(rawPassword),
+	}
 }
 
 func (settings *AppSettings) validateFlags() {
@@ -64,7 +65,7 @@ func (settings *AppSettings) validateFlags() {
 			fmt.Printf("file collector directory path must not be empty if collector enabled\n")
 			flagsValid = false
 		}
-		if (len(settings.BasicAuthUsername) > 0) != (len(settings.BasicAuthPassword) > 0) {
+		if (len(settings.BasicAuthProvider.username) > 0) != (len(settings.BasicAuthProvider.password) > 0) {
 			fmt.Printf("to enable basic auth both the username and the password must be provided")
 			flagsValid = false
 		}
diff --git a/src/exporter.go b/src/exporter.go
index 743b577..46eb3cc 100644
--- a/src/exporter.go
+++ b/src/exporter.go
@@ -64,12 +64,12 @@ func main() {
 		textFileCollector := textfile.NewCollector(appSettings)
 		prometheus.MustRegister(textFileCollector)
 
-		http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings))
+		http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider))
 		http.HandleFunc(metricsPath, auth.BasicAuthMiddleware(
 			func(w http.ResponseWriter, r *http.Request) {
 				metricHandler(w, r, textFileCollector)
 			},
-			appSettings,
+			appSettings.BasicAuthProvider,
 		))
 		log.Printf("metrics available at '%s'", metricsPath)
 

From 5be7095e879666e5ae8e45b782eda481fcfab176 Mon Sep 17 00:00:00 2001
From: Hector <hector@hjs.dev>
Date: Thu, 13 Jan 2022 22:21:39 +0000
Subject: [PATCH 5/7] refactor config basic auth to handle hashing data

Add unit tests for the basicAuth to ensure it behaves as expected.
Update the basic auth provider struct in the cfg package to handle all the
hashing automatically.
---
 src/auth/middleware.go    |  4 +--
 src/cfg/basicAuth.go      | 17 +++++++++--
 src/cfg/basicAuth_test.go | 60 +++++++++++++++++++++++++++++++++++++++
 src/cfg/cfg.go            |  8 ++----
 4 files changed, 77 insertions(+), 12 deletions(-)
 create mode 100644 src/cfg/basicAuth_test.go

diff --git a/src/auth/middleware.go b/src/auth/middleware.go
index 42e818a..2518932 100644
--- a/src/auth/middleware.go
+++ b/src/auth/middleware.go
@@ -25,9 +25,7 @@ func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAu
 func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
 	rawUsername, rawPassword, ok := r.BasicAuth()
 	if ok {
-		username := HashString(rawUsername)
-		password := HashString(rawPassword)
-		return basicAuthProvider.DoesBasicAuthMatch(username, password)
+		return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword)
 	}
 	return false
 }
diff --git a/src/cfg/basicAuth.go b/src/cfg/basicAuth.go
index c79dc8c..bc5408f 100644
--- a/src/cfg/basicAuth.go
+++ b/src/cfg/basicAuth.go
@@ -1,14 +1,25 @@
 package cfg
 
-type basicAuth struct {
+import "fail2ban-prometheus-exporter/auth"
+
+type hashedBasicAuth struct {
 	username string
 	password string
 }
 
-func (p *basicAuth) Enabled() bool {
+func newHashedBasicAuth(rawUsername, rawPassword string) *hashedBasicAuth {
+	return &hashedBasicAuth{
+		username: auth.HashString(rawUsername),
+		password: auth.HashString(rawPassword),
+	}
+}
+
+func (p *hashedBasicAuth) Enabled() bool {
 	return len(p.username) > 0 && len(p.password) > 0
 }
 
-func (p *basicAuth) DoesBasicAuthMatch(username, password string) bool {
+func (p *hashedBasicAuth) DoesBasicAuthMatch(rawUsername, rawPassword string) bool {
+	username := auth.HashString(rawUsername)
+	password := auth.HashString(rawPassword)
 	return username == p.username && password == p.password
 }
diff --git a/src/cfg/basicAuth_test.go b/src/cfg/basicAuth_test.go
new file mode 100644
index 0000000..85bd8b3
--- /dev/null
+++ b/src/cfg/basicAuth_test.go
@@ -0,0 +1,60 @@
+package cfg
+
+import "testing"
+
+func Test_hashedBasicAuth_DoesBasicAuthMatch(t *testing.T) {
+	type args struct {
+		username string
+		password string
+	}
+	type fields struct {
+		username string
+		password string
+	}
+	tests := []struct {
+		name   string
+		fields fields
+		args   args
+		want   bool
+	}{
+		{"Happy test #1", fields{username: "1234", password: "1234"}, args{username: "1234", password: "1234"}, true},
+		{"Happy test #2", fields{username: "test", password: "1234"}, args{username: "test", password: "1234"}, true},
+		{"Happy test #3", fields{username: "TEST", password: "1234"}, args{username: "TEST", password: "1234"}, true},
+		{"Non match #1", fields{username: "test", password: "1234"}, args{username: "1234", password: "1234"}, false},
+		{"Non match #2", fields{username: "1234", password: "test"}, args{username: "1234", password: "1234"}, false},
+		{"Non match #3", fields{username: "1234", password: "test"}, args{username: "1234", password: "TEST"}, false},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password)
+			if got := basicAuth.DoesBasicAuthMatch(tt.args.username, tt.args.password); got != tt.want {
+				t.Errorf("DoesBasicAuthMatch() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func Test_hashedBasicAuth_Enabled(t *testing.T) {
+	type fields struct {
+		username string
+		password string
+	}
+	tests := []struct {
+		name   string
+		fields fields
+		want   bool
+	}{
+		{"Both blank", fields{username: "", password: ""}, false},
+		{"Single blank #1", fields{username: "test", password: ""}, false},
+		{"Single blank #1", fields{username: "", password: "test"}, false},
+		{"Both populated", fields{username: "test", password: "test"}, true},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password)
+			if got := basicAuth.Enabled(); got != tt.want {
+				t.Errorf("Enabled() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/src/cfg/cfg.go b/src/cfg/cfg.go
index 0b33b99..008fef1 100644
--- a/src/cfg/cfg.go
+++ b/src/cfg/cfg.go
@@ -1,7 +1,6 @@
 package cfg
 
 import (
-	"fail2ban-prometheus-exporter/auth"
 	"flag"
 	"fmt"
 	"os"
@@ -19,7 +18,7 @@ type AppSettings struct {
 	Fail2BanSocketPath   string
 	FileCollectorPath    string
 	FileCollectorEnabled bool
-	BasicAuthProvider    *basicAuth
+	BasicAuthProvider    *hashedBasicAuth
 }
 
 func Parse() *AppSettings {
@@ -43,10 +42,7 @@ func Parse() *AppSettings {
 }
 
 func (settings *AppSettings) setBasicAuthValues(rawUsername, rawPassword string) {
-	settings.BasicAuthProvider = &basicAuth{
-		username: auth.HashString(rawUsername),
-		password: auth.HashString(rawPassword),
-	}
+	settings.BasicAuthProvider = newHashedBasicAuth(rawUsername, rawPassword)
 }
 
 func (settings *AppSettings) validateFlags() {

From ab7f9d854bb6674e8bb3f49d57a5b85dc6619029 Mon Sep 17 00:00:00 2001
From: Hector <hector@hjs.dev>
Date: Thu, 13 Jan 2022 22:57:53 +0000
Subject: [PATCH 6/7] add unit tests for the auth middleware

Add new unit tests for the auth middleware. Fix minor bug in the auth
middleware where the auth logic was backwards.
---
 src/auth/middleware.go      |  4 +-
 src/auth/middleware_test.go | 82 +++++++++++++++++++++++++++++++++++++
 2 files changed, 84 insertions(+), 2 deletions(-)
 create mode 100644 src/auth/middleware_test.go

diff --git a/src/auth/middleware.go b/src/auth/middleware.go
index 2518932..dd16fe3 100644
--- a/src/auth/middleware.go
+++ b/src/auth/middleware.go
@@ -13,9 +13,9 @@ func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAu
 	if basicAuthProvider.Enabled() {
 		return func(w http.ResponseWriter, r *http.Request) {
 			if doesBasicAuthMatch(r, basicAuthProvider) {
-				w.WriteHeader(http.StatusUnauthorized)
-			} else {
 				handlerFunc.ServeHTTP(w, r)
+			} else {
+				w.WriteHeader(http.StatusUnauthorized)
 			}
 		}
 	}
diff --git a/src/auth/middleware_test.go b/src/auth/middleware_test.go
new file mode 100644
index 0000000..3b2381e
--- /dev/null
+++ b/src/auth/middleware_test.go
@@ -0,0 +1,82 @@
+package auth
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+)
+
+type testAuthProvider struct {
+	enabled bool
+	match   bool
+}
+
+func (p testAuthProvider) Enabled() bool {
+	return p.enabled
+}
+
+func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool {
+	return p.match
+}
+
+func newTestRequest() *http.Request {
+	return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
+}
+
+func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
+	callCount := 0
+	testHandler := func(w http.ResponseWriter, r *http.Request) {
+		callCount++
+	}
+
+	handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: false, match: false})
+	recorder := httptest.NewRecorder()
+	handler.ServeHTTP(recorder, newTestRequest())
+
+	if recorder.Code != http.StatusOK {
+		t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK)
+	}
+	if callCount != 1 {
+		t.Errorf("callCount = %v, want %v", callCount, 1)
+	}
+}
+
+func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
+	callCount := 0
+	testHandler := func(w http.ResponseWriter, r *http.Request) {
+		callCount++
+	}
+
+	handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: true})
+	recorder := httptest.NewRecorder()
+	request := newTestRequest()
+	request.SetBasicAuth("test", "1234")
+	handler.ServeHTTP(recorder, request)
+
+	if recorder.Code != http.StatusOK {
+		t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK)
+	}
+	if callCount != 1 {
+		t.Errorf("callCount = %v, want %v", callCount, 1)
+	}
+}
+
+func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
+	callCount := 0
+	testHandler := func(w http.ResponseWriter, r *http.Request) {
+		callCount++
+	}
+
+	handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: false})
+	recorder := httptest.NewRecorder()
+	request := newTestRequest()
+	request.SetBasicAuth("user", "pass")
+	handler.ServeHTTP(recorder, request)
+
+	if recorder.Code != http.StatusUnauthorized {
+		t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusUnauthorized)
+	}
+	if callCount != 0 {
+		t.Errorf("callCount = %v, want %v", callCount, 0)
+	}
+}

From de5d9e4b115a12480990e0b511eb7b0b8634b13b Mon Sep 17 00:00:00 2001
From: Hector <hector@hjs.dev>
Date: Fri, 14 Jan 2022 17:06:46 +0000
Subject: [PATCH 7/7] reduce duplicate code in basic auth test

---
 src/auth/middleware_test.go | 58 +++++++++++--------------------------
 1 file changed, 17 insertions(+), 41 deletions(-)

diff --git a/src/auth/middleware_test.go b/src/auth/middleware_test.go
index 3b2381e..27c8ab6 100644
--- a/src/auth/middleware_test.go
+++ b/src/auth/middleware_test.go
@@ -23,60 +23,36 @@ func newTestRequest() *http.Request {
 	return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
 }
 
-func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
+func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) {
 	callCount := 0
 	testHandler := func(w http.ResponseWriter, r *http.Request) {
 		callCount++
 	}
 
-	handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: false, match: false})
+	handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches})
 	recorder := httptest.NewRecorder()
-	handler.ServeHTTP(recorder, newTestRequest())
+	request := newTestRequest()
+	if authEnabled {
+		request.SetBasicAuth("test", "test")
+	}
+	handler.ServeHTTP(recorder, request)
 
-	if recorder.Code != http.StatusOK {
-		t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK)
+	if recorder.Code != expectedCode {
+		t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode)
 	}
-	if callCount != 1 {
-		t.Errorf("callCount = %v, want %v", callCount, 1)
+	if callCount != expectedCallCount {
+		t.Errorf("callCount = %v, want %v", callCount, expectedCallCount)
 	}
 }
 
+func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
+	executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1)
+}
+
 func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
-	callCount := 0
-	testHandler := func(w http.ResponseWriter, r *http.Request) {
-		callCount++
-	}
-
-	handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: true})
-	recorder := httptest.NewRecorder()
-	request := newTestRequest()
-	request.SetBasicAuth("test", "1234")
-	handler.ServeHTTP(recorder, request)
-
-	if recorder.Code != http.StatusOK {
-		t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK)
-	}
-	if callCount != 1 {
-		t.Errorf("callCount = %v, want %v", callCount, 1)
-	}
+	executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1)
 }
 
 func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
-	callCount := 0
-	testHandler := func(w http.ResponseWriter, r *http.Request) {
-		callCount++
-	}
-
-	handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: false})
-	recorder := httptest.NewRecorder()
-	request := newTestRequest()
-	request.SetBasicAuth("user", "pass")
-	handler.ServeHTTP(recorder, request)
-
-	if recorder.Code != http.StatusUnauthorized {
-		t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusUnauthorized)
-	}
-	if callCount != 0 {
-		t.Errorf("callCount = %v, want %v", callCount, 0)
-	}
+	executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0)
 }