diff --git a/src/auth/middleware.go b/src/auth/middleware.go index 2518932..dd16fe3 100644 --- a/src/auth/middleware.go +++ b/src/auth/middleware.go @@ -13,9 +13,9 @@ func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAu if basicAuthProvider.Enabled() { return func(w http.ResponseWriter, r *http.Request) { if doesBasicAuthMatch(r, basicAuthProvider) { - w.WriteHeader(http.StatusUnauthorized) - } else { handlerFunc.ServeHTTP(w, r) + } else { + w.WriteHeader(http.StatusUnauthorized) } } } diff --git a/src/auth/middleware_test.go b/src/auth/middleware_test.go new file mode 100644 index 0000000..3b2381e --- /dev/null +++ b/src/auth/middleware_test.go @@ -0,0 +1,82 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +type testAuthProvider struct { + enabled bool + match bool +} + +func (p testAuthProvider) Enabled() bool { + return p.enabled +} + +func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool { + return p.match +} + +func newTestRequest() *http.Request { + return httptest.NewRequest(http.MethodGet, "http://example.com", nil) +} + +func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) { + callCount := 0 + testHandler := func(w http.ResponseWriter, r *http.Request) { + callCount++ + } + + handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: false, match: false}) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, newTestRequest()) + + if recorder.Code != http.StatusOK { + t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK) + } + if callCount != 1 { + t.Errorf("callCount = %v, want %v", callCount, 1) + } +} + +func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) { + callCount := 0 + testHandler := func(w http.ResponseWriter, r *http.Request) { + callCount++ + } + + handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: true}) + recorder := httptest.NewRecorder() + request := newTestRequest() + request.SetBasicAuth("test", "1234") + handler.ServeHTTP(recorder, request) + + if recorder.Code != http.StatusOK { + t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK) + } + if callCount != 1 { + t.Errorf("callCount = %v, want %v", callCount, 1) + } +} + +func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) { + callCount := 0 + testHandler := func(w http.ResponseWriter, r *http.Request) { + callCount++ + } + + handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: false}) + recorder := httptest.NewRecorder() + request := newTestRequest() + request.SetBasicAuth("user", "pass") + handler.ServeHTTP(recorder, request) + + if recorder.Code != http.StatusUnauthorized { + t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusUnauthorized) + } + if callCount != 0 { + t.Errorf("callCount = %v, want %v", callCount, 0) + } +}