rewrite auth provider logic
This commit is contained in:
parent
d31ea4b23c
commit
17fb995928
29
auth/basic.go
Normal file
29
auth/basic.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewBasicAuthProvider(username, password string) AuthProvider {
|
||||||
|
return &basicAuthProvider{
|
||||||
|
hashedAuth: encodeBasicAuth(username, password),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type basicAuthProvider struct {
|
||||||
|
hashedAuth string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *basicAuthProvider) IsAllowed(request *http.Request) bool {
|
||||||
|
username, password, ok := request.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
requestAuth := encodeBasicAuth(username, password)
|
||||||
|
return p.hashedAuth == requestAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeBasicAuth(username, password string) string {
|
||||||
|
return HashString(fmt.Sprintf("%s:%s", username, password))
|
||||||
|
}
|
@ -5,7 +5,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Hash(data []byte) []byte {
|
func hash(data []byte) []byte {
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return []byte{}
|
return []byte{}
|
||||||
}
|
}
|
||||||
@ -14,5 +14,5 @@ func Hash(data []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func HashString(data string) string {
|
func HashString(data string) string {
|
||||||
return hex.EncodeToString(Hash([]byte(data)))
|
return hex.EncodeToString(hash([]byte(data)))
|
||||||
}
|
}
|
||||||
|
34
auth/provider.go
Normal file
34
auth/provider.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthProvider interface {
|
||||||
|
IsAllowed(*http.Request) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEmptyAuthProvider() AuthProvider {
|
||||||
|
return &emptyAuthProvider{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type emptyAuthProvider struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *emptyAuthProvider) IsAllowed(request *http.Request) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type compositeAuthProvider struct {
|
||||||
|
providers []AuthProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *compositeAuthProvider) IsAllowed(request *http.Request) bool {
|
||||||
|
for i := 0; i < len(p.providers); i++ {
|
||||||
|
provider := p.providers[i]
|
||||||
|
if provider.IsAllowed(request) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
15
cfg/cfg.go
15
cfg/cfg.go
@ -2,9 +2,11 @@ package cfg
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/alecthomas/kong"
|
"github.com/alecthomas/kong"
|
||||||
|
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
var cliStruct struct {
|
var cliStruct struct {
|
||||||
@ -36,11 +38,22 @@ func Parse() *AppSettings {
|
|||||||
Fail2BanSocketPath: cliStruct.F2bSocketPath,
|
Fail2BanSocketPath: cliStruct.F2bSocketPath,
|
||||||
FileCollectorPath: cliStruct.TextFileExporterPath,
|
FileCollectorPath: cliStruct.TextFileExporterPath,
|
||||||
ExitOnSocketConnError: cliStruct.ExitOnSocketError,
|
ExitOnSocketConnError: cliStruct.ExitOnSocketError,
|
||||||
BasicAuthProvider: newHashedBasicAuth(cliStruct.BasicAuthUser, cliStruct.BasicAuthPass),
|
AuthProvider: createAuthProvider(),
|
||||||
}
|
}
|
||||||
return settings
|
return settings
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createAuthProvider() auth.AuthProvider {
|
||||||
|
username := cliStruct.BasicAuthUser
|
||||||
|
password := cliStruct.BasicAuthPass
|
||||||
|
|
||||||
|
if len(username) == 0 && len(password) == 0 {
|
||||||
|
return auth.NewEmptyAuthProvider()
|
||||||
|
}
|
||||||
|
log.Print("basic auth enabled")
|
||||||
|
return auth.NewBasicAuthProvider(username, password)
|
||||||
|
}
|
||||||
|
|
||||||
func validateFlags(cliCtx *kong.Context) {
|
func validateFlags(cliCtx *kong.Context) {
|
||||||
var flagsValid = true
|
var flagsValid = true
|
||||||
var messages = []string{}
|
var messages = []string{}
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package cfg
|
package cfg
|
||||||
|
|
||||||
|
import "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
||||||
|
|
||||||
type AppSettings struct {
|
type AppSettings struct {
|
||||||
VersionMode bool
|
VersionMode bool
|
||||||
MetricsAddress string
|
MetricsAddress string
|
||||||
Fail2BanSocketPath string
|
Fail2BanSocketPath string
|
||||||
FileCollectorPath string
|
FileCollectorPath string
|
||||||
BasicAuthProvider *hashedBasicAuth
|
AuthProvider auth.AuthProvider
|
||||||
ExitOnSocketConnError bool
|
ExitOnSocketConnError bool
|
||||||
}
|
}
|
||||||
|
@ -67,17 +67,14 @@ func main() {
|
|||||||
textFileCollector := textfile.NewCollector(appSettings)
|
textFileCollector := textfile.NewCollector(appSettings)
|
||||||
prometheus.MustRegister(textFileCollector)
|
prometheus.MustRegister(textFileCollector)
|
||||||
|
|
||||||
http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider))
|
http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.AuthProvider))
|
||||||
http.HandleFunc(metricsPath, server.BasicAuthMiddleware(
|
http.HandleFunc(metricsPath, server.BasicAuthMiddleware(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
metricHandler(w, r, textFileCollector)
|
metricHandler(w, r, textFileCollector)
|
||||||
},
|
},
|
||||||
appSettings.BasicAuthProvider,
|
appSettings.AuthProvider,
|
||||||
))
|
))
|
||||||
log.Printf("metrics available at '%s'", metricsPath)
|
log.Printf("metrics available at '%s'", metricsPath)
|
||||||
if appSettings.BasicAuthProvider.Enabled() {
|
|
||||||
log.Printf("basic auth enabled")
|
|
||||||
}
|
|
||||||
|
|
||||||
svrErr := make(chan error)
|
svrErr := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -2,30 +2,16 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BasicAuthProvider interface {
|
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, authProvider auth.AuthProvider) http.HandlerFunc {
|
||||||
Enabled() bool
|
|
||||||
DoesBasicAuthMatch(username, password string) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
|
|
||||||
if basicAuthProvider.Enabled() {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if doesBasicAuthMatch(r, basicAuthProvider) {
|
if authProvider.IsAllowed(r) {
|
||||||
handlerFunc.ServeHTTP(w, r)
|
handlerFunc.ServeHTTP(w, r)
|
||||||
} else {
|
} else {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return handlerFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
|
|
||||||
rawUsername, rawPassword, ok := r.BasicAuth()
|
|
||||||
if ok {
|
|
||||||
return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
@ -7,15 +7,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type testAuthProvider struct {
|
type testAuthProvider struct {
|
||||||
enabled bool
|
|
||||||
match bool
|
match bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p testAuthProvider) Enabled() bool {
|
func (p testAuthProvider) IsAllowed(request *http.Request) bool {
|
||||||
return p.enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool {
|
|
||||||
return p.match
|
return p.match
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -23,18 +18,15 @@ func newTestRequest() *http.Request {
|
|||||||
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) {
|
func executeBasicAuthMiddlewareTest(t *testing.T, authMatches bool, expectedCode int, expectedCallCount int) {
|
||||||
callCount := 0
|
callCount := 0
|
||||||
testHandler := func(w http.ResponseWriter, r *http.Request) {
|
testHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
callCount++
|
callCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches})
|
handler := BasicAuthMiddleware(testHandler, testAuthProvider{match: authMatches})
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
request := newTestRequest()
|
request := newTestRequest()
|
||||||
if authEnabled {
|
|
||||||
request.SetBasicAuth("test", "test")
|
|
||||||
}
|
|
||||||
handler.ServeHTTP(recorder, request)
|
handler.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
if recorder.Code != expectedCode {
|
if recorder.Code != expectedCode {
|
||||||
@ -45,14 +37,10 @@ func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
|
func Test_GIVEN_MatchingBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
|
||||||
executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1)
|
executeBasicAuthMiddlewareTest(t, true, http.StatusOK, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
|
func Test_GIVEN_NonMatchingBasicAuth_WHEN_MethodCalled_THEN_RequestRejected(t *testing.T) {
|
||||||
executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1)
|
executeBasicAuthMiddlewareTest(t, false, http.StatusUnauthorized, 0)
|
||||||
}
|
|
||||||
|
|
||||||
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
|
|
||||||
executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0)
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user