add new basic auth provider class
Fix the circular dependency by adding a new BasicAuthProvider interface to wrap the basic auth values. Add unit tests for the hash functions. Refactor functions to use the new basic auth provider interface.
This commit is contained in:
parent
26f0b0d0ce
commit
e78d4e7298
@ -2,13 +2,17 @@ package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func Hash(data []byte) []byte {
|
||||
if len(data) == 0 {
|
||||
return []byte{}
|
||||
}
|
||||
b := sha256.Sum256(data)
|
||||
return b[:]
|
||||
}
|
||||
|
||||
func HashString(data string) string {
|
||||
return string(Hash([]byte(data)))
|
||||
return hex.EncodeToString(Hash([]byte(data)))
|
||||
}
|
||||
|
26
src/auth/hash_test.go
Normal file
26
src/auth/hash_test.go
Normal file
@ -0,0 +1,26 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHashString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args string
|
||||
want string
|
||||
}{
|
||||
{"Happy path #1", "123", "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"},
|
||||
{"Happy path #2", "hello world", "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"},
|
||||
{"Happy path #3", "H3Ll0_W0RLD", "d58a27fe9a6e73a1d8a67189fb8acace047e7a1a795276a0056d3717ad61bd0e"},
|
||||
{"Blank string", "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := HashString(tt.args); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("HashString() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,15 +1,18 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fail2ban-prometheus-exporter/cfg"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSettings) http.HandlerFunc {
|
||||
authEnabled := len(appSettings.BasicAuthUsername) > 0 && len(appSettings.BasicAuthPassword) > 0
|
||||
if authEnabled {
|
||||
type BasicAuthProvider interface {
|
||||
Enabled() bool
|
||||
DoesBasicAuthMatch(username, password string) bool
|
||||
}
|
||||
|
||||
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
|
||||
if basicAuthProvider.Enabled() {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if doesBasicAuthMatch(r, appSettings) {
|
||||
if doesBasicAuthMatch(r, basicAuthProvider) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else {
|
||||
handlerFunc.ServeHTTP(w, r)
|
||||
@ -19,12 +22,12 @@ func BasicAuthMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSetti
|
||||
return handlerFunc
|
||||
}
|
||||
|
||||
func doesBasicAuthMatch(r *http.Request, appSettings *cfg.AppSettings) bool {
|
||||
func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
|
||||
rawUsername, rawPassword, ok := r.BasicAuth()
|
||||
if ok {
|
||||
username := HashString(rawUsername)
|
||||
password := HashString(rawPassword)
|
||||
return username == appSettings.BasicAuthUsername && password == appSettings.BasicAuthPassword
|
||||
return basicAuthProvider.DoesBasicAuthMatch(username, password)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
14
src/cfg/basicAuth.go
Normal file
14
src/cfg/basicAuth.go
Normal file
@ -0,0 +1,14 @@
|
||||
package cfg
|
||||
|
||||
type basicAuth struct {
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
func (p *basicAuth) Enabled() bool {
|
||||
return len(p.username) > 0 && len(p.password) > 0
|
||||
}
|
||||
|
||||
func (p *basicAuth) DoesBasicAuthMatch(username, password string) bool {
|
||||
return username == p.username && password == p.password
|
||||
}
|
@ -19,8 +19,7 @@ type AppSettings struct {
|
||||
Fail2BanSocketPath string
|
||||
FileCollectorPath string
|
||||
FileCollectorEnabled bool
|
||||
BasicAuthUsername string
|
||||
BasicAuthPassword string
|
||||
BasicAuthProvider *basicAuth
|
||||
}
|
||||
|
||||
func Parse() *AppSettings {
|
||||
@ -43,9 +42,11 @@ func Parse() *AppSettings {
|
||||
return appSettings
|
||||
}
|
||||
|
||||
func (settings *AppSettings) setBasicAuthValues(username, password string) {
|
||||
settings.BasicAuthUsername = auth.HashString(username)
|
||||
settings.BasicAuthPassword = auth.HashString(password)
|
||||
func (settings *AppSettings) setBasicAuthValues(rawUsername, rawPassword string) {
|
||||
settings.BasicAuthProvider = &basicAuth{
|
||||
username: auth.HashString(rawUsername),
|
||||
password: auth.HashString(rawPassword),
|
||||
}
|
||||
}
|
||||
|
||||
func (settings *AppSettings) validateFlags() {
|
||||
@ -64,7 +65,7 @@ func (settings *AppSettings) validateFlags() {
|
||||
fmt.Printf("file collector directory path must not be empty if collector enabled\n")
|
||||
flagsValid = false
|
||||
}
|
||||
if (len(settings.BasicAuthUsername) > 0) != (len(settings.BasicAuthPassword) > 0) {
|
||||
if (len(settings.BasicAuthProvider.username) > 0) != (len(settings.BasicAuthProvider.password) > 0) {
|
||||
fmt.Printf("to enable basic auth both the username and the password must be provided")
|
||||
flagsValid = false
|
||||
}
|
||||
|
@ -64,12 +64,12 @@ func main() {
|
||||
textFileCollector := textfile.NewCollector(appSettings)
|
||||
prometheus.MustRegister(textFileCollector)
|
||||
|
||||
http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings))
|
||||
http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider))
|
||||
http.HandleFunc(metricsPath, auth.BasicAuthMiddleware(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
metricHandler(w, r, textFileCollector)
|
||||
},
|
||||
appSettings,
|
||||
appSettings.BasicAuthProvider,
|
||||
))
|
||||
log.Printf("metrics available at '%s'", metricsPath)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user