rewrite auth provider logic
This commit is contained in:
parent
d31ea4b23c
commit
17fb995928
29
auth/basic.go
Normal file
29
auth/basic.go
Normal file
@ -0,0 +1,29 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func NewBasicAuthProvider(username, password string) AuthProvider {
|
||||
return &basicAuthProvider{
|
||||
hashedAuth: encodeBasicAuth(username, password),
|
||||
}
|
||||
}
|
||||
|
||||
type basicAuthProvider struct {
|
||||
hashedAuth string
|
||||
}
|
||||
|
||||
func (p *basicAuthProvider) IsAllowed(request *http.Request) bool {
|
||||
username, password, ok := request.BasicAuth()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
requestAuth := encodeBasicAuth(username, password)
|
||||
return p.hashedAuth == requestAuth
|
||||
}
|
||||
|
||||
func encodeBasicAuth(username, password string) string {
|
||||
return HashString(fmt.Sprintf("%s:%s", username, password))
|
||||
}
|
@ -5,7 +5,7 @@ import (
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func Hash(data []byte) []byte {
|
||||
func hash(data []byte) []byte {
|
||||
if len(data) == 0 {
|
||||
return []byte{}
|
||||
}
|
||||
@ -14,5 +14,5 @@ func Hash(data []byte) []byte {
|
||||
}
|
||||
|
||||
func HashString(data string) string {
|
||||
return hex.EncodeToString(Hash([]byte(data)))
|
||||
return hex.EncodeToString(hash([]byte(data)))
|
||||
}
|
||||
|
34
auth/provider.go
Normal file
34
auth/provider.go
Normal file
@ -0,0 +1,34 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type AuthProvider interface {
|
||||
IsAllowed(*http.Request) bool
|
||||
}
|
||||
|
||||
func NewEmptyAuthProvider() AuthProvider {
|
||||
return &emptyAuthProvider{}
|
||||
}
|
||||
|
||||
type emptyAuthProvider struct {
|
||||
}
|
||||
|
||||
func (p *emptyAuthProvider) IsAllowed(request *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type compositeAuthProvider struct {
|
||||
providers []AuthProvider
|
||||
}
|
||||
|
||||
func (p *compositeAuthProvider) IsAllowed(request *http.Request) bool {
|
||||
for i := 0; i < len(p.providers); i++ {
|
||||
provider := p.providers[i]
|
||||
if provider.IsAllowed(request) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
15
cfg/cfg.go
15
cfg/cfg.go
@ -2,9 +2,11 @@ package cfg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/alecthomas/kong"
|
||||
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
||||
)
|
||||
|
||||
var cliStruct struct {
|
||||
@ -36,11 +38,22 @@ func Parse() *AppSettings {
|
||||
Fail2BanSocketPath: cliStruct.F2bSocketPath,
|
||||
FileCollectorPath: cliStruct.TextFileExporterPath,
|
||||
ExitOnSocketConnError: cliStruct.ExitOnSocketError,
|
||||
BasicAuthProvider: newHashedBasicAuth(cliStruct.BasicAuthUser, cliStruct.BasicAuthPass),
|
||||
AuthProvider: createAuthProvider(),
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
func createAuthProvider() auth.AuthProvider {
|
||||
username := cliStruct.BasicAuthUser
|
||||
password := cliStruct.BasicAuthPass
|
||||
|
||||
if len(username) == 0 && len(password) == 0 {
|
||||
return auth.NewEmptyAuthProvider()
|
||||
}
|
||||
log.Print("basic auth enabled")
|
||||
return auth.NewBasicAuthProvider(username, password)
|
||||
}
|
||||
|
||||
func validateFlags(cliCtx *kong.Context) {
|
||||
var flagsValid = true
|
||||
var messages = []string{}
|
||||
|
@ -1,10 +1,12 @@
|
||||
package cfg
|
||||
|
||||
import "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
||||
|
||||
type AppSettings struct {
|
||||
VersionMode bool
|
||||
MetricsAddress string
|
||||
Fail2BanSocketPath string
|
||||
FileCollectorPath string
|
||||
BasicAuthProvider *hashedBasicAuth
|
||||
AuthProvider auth.AuthProvider
|
||||
ExitOnSocketConnError bool
|
||||
}
|
||||
|
@ -67,17 +67,14 @@ func main() {
|
||||
textFileCollector := textfile.NewCollector(appSettings)
|
||||
prometheus.MustRegister(textFileCollector)
|
||||
|
||||
http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider))
|
||||
http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.AuthProvider))
|
||||
http.HandleFunc(metricsPath, server.BasicAuthMiddleware(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
metricHandler(w, r, textFileCollector)
|
||||
},
|
||||
appSettings.BasicAuthProvider,
|
||||
appSettings.AuthProvider,
|
||||
))
|
||||
log.Printf("metrics available at '%s'", metricsPath)
|
||||
if appSettings.BasicAuthProvider.Enabled() {
|
||||
log.Printf("basic auth enabled")
|
||||
}
|
||||
|
||||
svrErr := make(chan error)
|
||||
go func() {
|
||||
|
@ -2,30 +2,16 @@ package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
||||
)
|
||||
|
||||
type BasicAuthProvider interface {
|
||||
Enabled() bool
|
||||
DoesBasicAuthMatch(username, password string) bool
|
||||
}
|
||||
|
||||
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
|
||||
if basicAuthProvider.Enabled() {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if doesBasicAuthMatch(r, basicAuthProvider) {
|
||||
handlerFunc.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, authProvider auth.AuthProvider) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if authProvider.IsAllowed(r) {
|
||||
handlerFunc.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
return handlerFunc
|
||||
}
|
||||
|
||||
func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
|
||||
rawUsername, rawPassword, ok := r.BasicAuth()
|
||||
if ok {
|
||||
return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -7,15 +7,10 @@ import (
|
||||
)
|
||||
|
||||
type testAuthProvider struct {
|
||||
enabled bool
|
||||
match bool
|
||||
match bool
|
||||
}
|
||||
|
||||
func (p testAuthProvider) Enabled() bool {
|
||||
return p.enabled
|
||||
}
|
||||
|
||||
func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool {
|
||||
func (p testAuthProvider) IsAllowed(request *http.Request) bool {
|
||||
return p.match
|
||||
}
|
||||
|
||||
@ -23,18 +18,15 @@ func newTestRequest() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
}
|
||||
|
||||
func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) {
|
||||
func executeBasicAuthMiddlewareTest(t *testing.T, authMatches bool, expectedCode int, expectedCallCount int) {
|
||||
callCount := 0
|
||||
testHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
}
|
||||
|
||||
handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches})
|
||||
handler := BasicAuthMiddleware(testHandler, testAuthProvider{match: authMatches})
|
||||
recorder := httptest.NewRecorder()
|
||||
request := newTestRequest()
|
||||
if authEnabled {
|
||||
request.SetBasicAuth("test", "test")
|
||||
}
|
||||
handler.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != expectedCode {
|
||||
@ -45,14 +37,10 @@ func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
|
||||
executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1)
|
||||
func Test_GIVEN_MatchingBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
|
||||
executeBasicAuthMiddlewareTest(t, true, http.StatusOK, 1)
|
||||
}
|
||||
|
||||
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
|
||||
executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1)
|
||||
}
|
||||
|
||||
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
|
||||
executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0)
|
||||
func Test_GIVEN_NonMatchingBasicAuth_WHEN_MethodCalled_THEN_RequestRejected(t *testing.T) {
|
||||
executeBasicAuthMiddlewareTest(t, false, http.StatusUnauthorized, 0)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user