add unit tests for the auth middleware

Add new unit tests for the auth middleware. Fix minor bug in the auth
middleware where the auth logic was backwards.
This commit is contained in:
Hector 2022-01-13 22:57:53 +00:00
parent 5be7095e87
commit ab7f9d854b
2 changed files with 84 additions and 2 deletions

View File

@ -13,9 +13,9 @@ func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAu
if basicAuthProvider.Enabled() { if basicAuthProvider.Enabled() {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if doesBasicAuthMatch(r, basicAuthProvider) { if doesBasicAuthMatch(r, basicAuthProvider) {
w.WriteHeader(http.StatusUnauthorized)
} else {
handlerFunc.ServeHTTP(w, r) handlerFunc.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusUnauthorized)
} }
} }
} }

View File

@ -0,0 +1,82 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
)
type testAuthProvider struct {
enabled bool
match bool
}
func (p testAuthProvider) Enabled() bool {
return p.enabled
}
func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool {
return p.match
}
func newTestRequest() *http.Request {
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
}
func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
callCount := 0
testHandler := func(w http.ResponseWriter, r *http.Request) {
callCount++
}
handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: false, match: false})
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, newTestRequest())
if recorder.Code != http.StatusOK {
t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK)
}
if callCount != 1 {
t.Errorf("callCount = %v, want %v", callCount, 1)
}
}
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
callCount := 0
testHandler := func(w http.ResponseWriter, r *http.Request) {
callCount++
}
handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: true})
recorder := httptest.NewRecorder()
request := newTestRequest()
request.SetBasicAuth("test", "1234")
handler.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusOK)
}
if callCount != 1 {
t.Errorf("callCount = %v, want %v", callCount, 1)
}
}
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
callCount := 0
testHandler := func(w http.ResponseWriter, r *http.Request) {
callCount++
}
handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: true, match: false})
recorder := httptest.NewRecorder()
request := newTestRequest()
request.SetBasicAuth("user", "pass")
handler.ServeHTTP(recorder, request)
if recorder.Code != http.StatusUnauthorized {
t.Errorf("statusCode = %v, want %v", recorder.Code, http.StatusUnauthorized)
}
if callCount != 0 {
t.Errorf("callCount = %v, want %v", callCount, 0)
}
}