rewrite auth provider logic

This commit is contained in:
Hector 2023-06-20 22:15:12 +01:00
parent d31ea4b23c
commit 17fb995928
8 changed files with 100 additions and 51 deletions

29
auth/basic.go Normal file
View File

@ -0,0 +1,29 @@
package auth
import (
"fmt"
"net/http"
)
func NewBasicAuthProvider(username, password string) AuthProvider {
return &basicAuthProvider{
hashedAuth: encodeBasicAuth(username, password),
}
}
type basicAuthProvider struct {
hashedAuth string
}
func (p *basicAuthProvider) IsAllowed(request *http.Request) bool {
username, password, ok := request.BasicAuth()
if !ok {
return false
}
requestAuth := encodeBasicAuth(username, password)
return p.hashedAuth == requestAuth
}
func encodeBasicAuth(username, password string) string {
return HashString(fmt.Sprintf("%s:%s", username, password))
}

View File

@ -5,7 +5,7 @@ import (
"encoding/hex"
)
func Hash(data []byte) []byte {
func hash(data []byte) []byte {
if len(data) == 0 {
return []byte{}
}
@ -14,5 +14,5 @@ func Hash(data []byte) []byte {
}
func HashString(data string) string {
return hex.EncodeToString(Hash([]byte(data)))
return hex.EncodeToString(hash([]byte(data)))
}

34
auth/provider.go Normal file
View File

@ -0,0 +1,34 @@
package auth
import (
"net/http"
)
type AuthProvider interface {
IsAllowed(*http.Request) bool
}
func NewEmptyAuthProvider() AuthProvider {
return &emptyAuthProvider{}
}
type emptyAuthProvider struct {
}
func (p *emptyAuthProvider) IsAllowed(request *http.Request) bool {
return true
}
type compositeAuthProvider struct {
providers []AuthProvider
}
func (p *compositeAuthProvider) IsAllowed(request *http.Request) bool {
for i := 0; i < len(p.providers); i++ {
provider := p.providers[i]
if provider.IsAllowed(request) {
return true
}
}
return false
}

View File

@ -2,9 +2,11 @@ package cfg
import (
"fmt"
"log"
"os"
"github.com/alecthomas/kong"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
)
var cliStruct struct {
@ -36,11 +38,22 @@ func Parse() *AppSettings {
Fail2BanSocketPath: cliStruct.F2bSocketPath,
FileCollectorPath: cliStruct.TextFileExporterPath,
ExitOnSocketConnError: cliStruct.ExitOnSocketError,
BasicAuthProvider: newHashedBasicAuth(cliStruct.BasicAuthUser, cliStruct.BasicAuthPass),
AuthProvider: createAuthProvider(),
}
return settings
}
func createAuthProvider() auth.AuthProvider {
username := cliStruct.BasicAuthUser
password := cliStruct.BasicAuthPass
if len(username) == 0 && len(password) == 0 {
return auth.NewEmptyAuthProvider()
}
log.Print("basic auth enabled")
return auth.NewBasicAuthProvider(username, password)
}
func validateFlags(cliCtx *kong.Context) {
var flagsValid = true
var messages = []string{}

View File

@ -1,10 +1,12 @@
package cfg
import "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
type AppSettings struct {
VersionMode bool
MetricsAddress string
Fail2BanSocketPath string
FileCollectorPath string
BasicAuthProvider *hashedBasicAuth
AuthProvider auth.AuthProvider
ExitOnSocketConnError bool
}

View File

@ -67,17 +67,14 @@ func main() {
textFileCollector := textfile.NewCollector(appSettings)
prometheus.MustRegister(textFileCollector)
http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider))
http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.AuthProvider))
http.HandleFunc(metricsPath, server.BasicAuthMiddleware(
func(w http.ResponseWriter, r *http.Request) {
metricHandler(w, r, textFileCollector)
},
appSettings.BasicAuthProvider,
appSettings.AuthProvider,
))
log.Printf("metrics available at '%s'", metricsPath)
if appSettings.BasicAuthProvider.Enabled() {
log.Printf("basic auth enabled")
}
svrErr := make(chan error)
go func() {

View File

@ -2,30 +2,16 @@ package server
import (
"net/http"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
)
type BasicAuthProvider interface {
Enabled() bool
DoesBasicAuthMatch(username, password string) bool
}
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
if basicAuthProvider.Enabled() {
return func(w http.ResponseWriter, r *http.Request) {
if doesBasicAuthMatch(r, basicAuthProvider) {
handlerFunc.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusUnauthorized)
}
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, authProvider auth.AuthProvider) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if authProvider.IsAllowed(r) {
handlerFunc.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusUnauthorized)
}
}
return handlerFunc
}
func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
rawUsername, rawPassword, ok := r.BasicAuth()
if ok {
return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword)
}
return false
}

View File

@ -7,15 +7,10 @@ import (
)
type testAuthProvider struct {
enabled bool
match bool
match bool
}
func (p testAuthProvider) Enabled() bool {
return p.enabled
}
func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool {
func (p testAuthProvider) IsAllowed(request *http.Request) bool {
return p.match
}
@ -23,18 +18,15 @@ func newTestRequest() *http.Request {
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
}
func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) {
func executeBasicAuthMiddlewareTest(t *testing.T, authMatches bool, expectedCode int, expectedCallCount int) {
callCount := 0
testHandler := func(w http.ResponseWriter, r *http.Request) {
callCount++
}
handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches})
handler := BasicAuthMiddleware(testHandler, testAuthProvider{match: authMatches})
recorder := httptest.NewRecorder()
request := newTestRequest()
if authEnabled {
request.SetBasicAuth("test", "test")
}
handler.ServeHTTP(recorder, request)
if recorder.Code != expectedCode {
@ -45,14 +37,10 @@ func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches
}
}
func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1)
func Test_GIVEN_MatchingBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
executeBasicAuthMiddlewareTest(t, true, http.StatusOK, 1)
}
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1)
}
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0)
func Test_GIVEN_NonMatchingBasicAuth_WHEN_MethodCalled_THEN_RequestRejected(t *testing.T) {
executeBasicAuthMiddlewareTest(t, false, http.StatusUnauthorized, 0)
}