feat: add support for basic auth (#16)

Add new CLI parameters to enable protecting the API endpoints with basic
auth authentication.
Wrap the server endpoints in a new authMiddleware that protects it using
the provided basic auth credentials (if set).
This commit is contained in:
Hector 2022-01-11 21:53:43 +00:00
parent 013e8f30c9
commit e3d8c1e0e5
2 changed files with 26 additions and 4 deletions

View File

@ -18,6 +18,8 @@ type AppSettings struct {
Fail2BanSocketPath string Fail2BanSocketPath string
FileCollectorPath string FileCollectorPath string
FileCollectorEnabled bool FileCollectorEnabled bool
BasicAuthUsername string
BasicAuthPassword string
} }
func Parse() *AppSettings { func Parse() *AppSettings {
@ -28,6 +30,8 @@ func Parse() *AppSettings {
flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket") flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket")
flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector") flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector")
flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from") flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from")
flag.StringVar(&appSettings.BasicAuthUsername, "web.basic-auth.username", "", "set username for basic auth")
flag.StringVar(&appSettings.BasicAuthPassword, "web.basic-auth.password", "", "set password for basic auth")
flag.Parse() flag.Parse()
appSettings.validateFlags() appSettings.validateFlags()

View File

@ -48,6 +48,21 @@ func metricHandler(w http.ResponseWriter, r *http.Request, collector *textfile.C
collector.WriteTextFileMetrics(w, r) collector.WriteTextFileMetrics(w, r)
} }
func authMiddleware(handlerFunc http.HandlerFunc, appSettings *cfg.AppSettings) http.HandlerFunc {
authEnabled := len(appSettings.BasicAuthUsername) > 0 && len(appSettings.BasicAuthPassword) > 0
if authEnabled {
return func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok || username != appSettings.BasicAuthUsername || password != appSettings.BasicAuthPassword {
w.WriteHeader(http.StatusUnauthorized)
return
}
handlerFunc.ServeHTTP(w, r)
}
}
return handlerFunc
}
func main() { func main() {
appSettings := cfg.Parse() appSettings := cfg.Parse()
if appSettings.VersionMode { if appSettings.VersionMode {
@ -63,10 +78,13 @@ func main() {
textFileCollector := textfile.NewCollector(appSettings) textFileCollector := textfile.NewCollector(appSettings)
prometheus.MustRegister(textFileCollector) prometheus.MustRegister(textFileCollector)
http.HandleFunc("/", rootHtmlHandler) http.HandleFunc("/", authMiddleware(rootHtmlHandler, appSettings))
http.HandleFunc(metricsPath, func(w http.ResponseWriter, r *http.Request) { http.HandleFunc(metricsPath, authMiddleware(
func(w http.ResponseWriter, r *http.Request) {
metricHandler(w, r, textFileCollector) metricHandler(w, r, textFileCollector)
}) },
appSettings,
))
log.Printf("metrics available at '%s'", metricsPath) log.Printf("metrics available at '%s'", metricsPath)
svrErr := make(chan error) svrErr := make(chan error)