diff --git a/src/auth/middleware_test.go b/src/auth/middleware_test.go index 3b2381e..27c8ab6 100644 --- a/src/auth/middleware_test.go +++ b/src/auth/middleware_test.go @@ -23,60 +23,36 @@ func newTestRequest() *http.Request { return httptest.NewRequest(http.MethodGet, "http://example.com", nil) } -func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) { +func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) { callCount := 0 testHandler := func(w http.ResponseWriter, r *http.Request) { callCount++ } - handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: false, match: false}) + handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches}) recorder := httptest.NewRecorder() - handler.ServeHTTP(recorder, newTestRequest()) + request := newTestRequest() + if authEnabled { + request.SetBasicAuth("test", "test") + } + handler.ServeHTTP(recorder, request) - if recorder.Code != http.StatusOK { - t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK) + if recorder.Code != expectedCode { + t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode) } - if callCount != 1 { - t.Errorf("callCount = %v, want %v", callCount, 1) + if callCount != expectedCallCount { + t.Errorf("callCount = %v, want %v", callCount, expectedCallCount) } } +func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) { + executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1) +} + func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) { - callCount := 0 - testHandler := func(w http.ResponseWriter, r *http.Request) { - callCount++ - } - - handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: true}) - recorder := httptest.NewRecorder() - request := newTestRequest() - request.SetBasicAuth("test", "1234") - handler.ServeHTTP(recorder, request) - - if recorder.Code != http.StatusOK { - t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK) - } - if callCount != 1 { - t.Errorf("callCount = %v, want %v", callCount, 1) - } + executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1) } func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) { - callCount := 0 - testHandler := func(w http.ResponseWriter, r *http.Request) { - callCount++ - } - - handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: false}) - recorder := httptest.NewRecorder() - request := newTestRequest() - request.SetBasicAuth("user", "pass") - handler.ServeHTTP(recorder, request) - - if recorder.Code != http.StatusUnauthorized { - t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusUnauthorized) - } - if callCount != 0 { - t.Errorf("callCount = %v, want %v", callCount, 0) - } + executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0) }