From ab7f9d854bb6674e8bb3f49d57a5b85dc6619029 Mon Sep 17 00:00:00 2001
From: Hector <hector@hjs.dev>
Date: Thu, 13 Jan 2022 22:57:53 +0000
Subject: [PATCH] add unit tests for the auth middleware

Add new unit tests for the auth middleware. Fix minor bug in the auth
middleware where the auth logic was backwards.
---
 src/auth/middleware.go      |  4 +-
 src/auth/middleware_test.go | 82 +++++++++++++++++++++++++++++++++++++
 2 files changed, 84 insertions(+), 2 deletions(-)
 create mode 100644 src/auth/middleware_test.go

diff --git a/src/auth/middleware.go b/src/auth/middleware.go
index 2518932..dd16fe3 100644
--- a/src/auth/middleware.go
+++ b/src/auth/middleware.go
@@ -13,9 +13,9 @@ func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAu
 	if basicAuthProvider.Enabled() {
 		return func(w http.ResponseWriter, r *http.Request) {
 			if doesBasicAuthMatch(r, basicAuthProvider) {
-				w.WriteHeader(http.StatusUnauthorized)
-			} else {
 				handlerFunc.ServeHTTP(w, r)
+			} else {
+				w.WriteHeader(http.StatusUnauthorized)
 			}
 		}
 	}
diff --git a/src/auth/middleware_test.go b/src/auth/middleware_test.go
new file mode 100644
index 0000000..3b2381e
--- /dev/null
+++ b/src/auth/middleware_test.go
@@ -0,0 +1,82 @@
+package auth
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+)
+
+type testAuthProvider struct {
+	enabled bool
+	match   bool
+}
+
+func (p testAuthProvider) Enabled() bool {
+	return p.enabled
+}
+
+func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool {
+	return p.match
+}
+
+func newTestRequest() *http.Request {
+	return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
+}
+
+func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
+	callCount := 0
+	testHandler := func(w http.ResponseWriter, r *http.Request) {
+		callCount++
+	}
+
+	handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: false, match: false})
+	recorder := httptest.NewRecorder()
+	handler.ServeHTTP(recorder, newTestRequest())
+
+	if recorder.Code != http.StatusOK {
+		t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK)
+	}
+	if callCount != 1 {
+		t.Errorf("callCount = %v, want %v", callCount, 1)
+	}
+}
+
+func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
+	callCount := 0
+	testHandler := func(w http.ResponseWriter, r *http.Request) {
+		callCount++
+	}
+
+	handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: true})
+	recorder := httptest.NewRecorder()
+	request := newTestRequest()
+	request.SetBasicAuth("test", "1234")
+	handler.ServeHTTP(recorder, request)
+
+	if recorder.Code != http.StatusOK {
+		t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK)
+	}
+	if callCount != 1 {
+		t.Errorf("callCount = %v, want %v", callCount, 1)
+	}
+}
+
+func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
+	callCount := 0
+	testHandler := func(w http.ResponseWriter, r *http.Request) {
+		callCount++
+	}
+
+	handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: false})
+	recorder := httptest.NewRecorder()
+	request := newTestRequest()
+	request.SetBasicAuth("user", "pass")
+	handler.ServeHTTP(recorder, request)
+
+	if recorder.Code != http.StatusUnauthorized {
+		t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusUnauthorized)
+	}
+	if callCount != 0 {
+		t.Errorf("callCount = %v, want %v", callCount, 0)
+	}
+}