336 lines
8.2 KiB
Go
336 lines
8.2 KiB
Go
|
package daemon
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"os"
|
||
|
"os/signal"
|
||
|
"regexp"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"syscall"
|
||
|
"time"
|
||
|
|
||
|
"git.cryptic.systems/volker.raschek/dyndns-client/pkg/types"
|
||
|
"git.cryptic.systems/volker.raschek/dyndns-client/pkg/updater"
|
||
|
"github.com/asaskevich/govalidator"
|
||
|
log "github.com/sirupsen/logrus"
|
||
|
"github.com/vishvananda/netlink"
|
||
|
)
|
||
|
|
||
|
func Start(cnf *types.Config) {
|
||
|
addrUpdates := make(chan netlink.AddrUpdate, 1)
|
||
|
done := make(chan struct{}, 1)
|
||
|
err := netlink.AddrSubscribeWithOptions(addrUpdates, done, netlink.AddrSubscribeOptions{
|
||
|
ListExisting: true,
|
||
|
})
|
||
|
if err != nil {
|
||
|
log.Fatalf("failed to subscribe netlink notifications from kernel: %v", err.Error())
|
||
|
}
|
||
|
|
||
|
interuptChannel := make(chan os.Signal, 1)
|
||
|
signal.Notify(interuptChannel, syscall.SIGINT, syscall.SIGTERM)
|
||
|
|
||
|
ctx := context.Background()
|
||
|
daemonCtx, cancle := context.WithCancel(ctx)
|
||
|
defer cancle()
|
||
|
|
||
|
updaters, err := getUpdaterForEachZone(cnf)
|
||
|
if err != nil {
|
||
|
log.Fatalf("%v", err.Error())
|
||
|
}
|
||
|
|
||
|
if err := pruneRecords(daemonCtx, updaters, cnf.Zones); err != nil {
|
||
|
log.Fatalf("%v", err.Error())
|
||
|
}
|
||
|
|
||
|
for {
|
||
|
interfaces, err := netlink.LinkList()
|
||
|
if err != nil {
|
||
|
log.Fatal("%v", err.Error())
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case update := <-addrUpdates:
|
||
|
|
||
|
interfaceLogger := log.WithFields(log.Fields{
|
||
|
"ip": update.LinkAddress.IP.String(),
|
||
|
})
|
||
|
|
||
|
// search interface by index
|
||
|
iface, err := searchInterfaceByIndex(update.LinkIndex, interfaces)
|
||
|
if err != nil {
|
||
|
log.Errorf("%v", err.Error())
|
||
|
continue
|
||
|
}
|
||
|
interfaceLogger = interfaceLogger.WithField("device", iface.Attrs().Name)
|
||
|
|
||
|
var recordType string
|
||
|
switch {
|
||
|
case govalidator.IsIPv4(strings.TrimRight(update.LinkAddress.IP.String(), "/")):
|
||
|
recordType = "A"
|
||
|
case govalidator.IsIPv6(strings.TrimRight(update.LinkAddress.IP.String(), "/")):
|
||
|
recordType = "AAAA"
|
||
|
default:
|
||
|
interfaceLogger.Error("failed to detect record type")
|
||
|
continue
|
||
|
}
|
||
|
interfaceLogger = interfaceLogger.WithField("rr", recordType)
|
||
|
|
||
|
interfaceLogger.Debug("receive kernel notification for interface")
|
||
|
|
||
|
// filter out not configured interfaces
|
||
|
if !matchInterfaces(iface.Attrs().Name, cnf.Ifaces) {
|
||
|
interfaceLogger.Warn("interface is not part of the allowed interface list")
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// filter out notification for a bad interface ip address, for example link-local-addresses
|
||
|
if update.LinkAddress.IP.IsLoopback() || strings.HasPrefix(update.LinkAddress.IP.String(), "fe80") {
|
||
|
interfaceLogger.Warn("interface is a loopback device or part of a loopback network")
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// decide if trigger a add or delete event
|
||
|
if update.NewAddr {
|
||
|
err = addIPRecords(daemonCtx, interfaceLogger, updaters, cnf.Zones, recordType, update.LinkAddress.IP)
|
||
|
if err != nil {
|
||
|
interfaceLogger.Error(err.Error())
|
||
|
}
|
||
|
} else {
|
||
|
err = removeIPRecords(daemonCtx, interfaceLogger, updaters, cnf.Zones, recordType, update.LinkAddress.IP)
|
||
|
if err != nil {
|
||
|
interfaceLogger.Error(err.Error())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
case killSignal := <-interuptChannel:
|
||
|
log.Debugf("got signal: %v", killSignal)
|
||
|
log.Debugf("daemon was killed by: %v", killSignal)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func getUpdaterForEachZone(config *types.Config) (map[string]updater.Updater, error) {
|
||
|
updaterCollection := make(map[string]updater.Updater)
|
||
|
|
||
|
for zoneName, zone := range config.Zones {
|
||
|
nsUpdater, err := updater.NewNSUpdate(zone.DNSServer, config.TSIGKeys[zone.TSIGKeyName])
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
updaterCollection[zoneName] = nsUpdater
|
||
|
}
|
||
|
|
||
|
return updaterCollection, nil
|
||
|
}
|
||
|
|
||
|
// matchInterfaces reports whether iface is an element of ifaces. Names are
// compared exactly (case-sensitive); an empty or nil list never matches.
func matchInterfaces(iface string, ifaces []string) bool {
	for idx := 0; idx < len(ifaces); idx++ {
		if ifaces[idx] != iface {
			continue
		}
		return true
	}
	return false
}
|
||
|
|
||
|
func searchInterfaceByIndex(index int, interfaces []netlink.Link) (netlink.Link, error) {
|
||
|
for _, iface := range interfaces {
|
||
|
if iface.Attrs().Index == index {
|
||
|
return iface, nil
|
||
|
}
|
||
|
}
|
||
|
return nil, fmt.Errorf("can not find interface by index %v", index)
|
||
|
}
|
||
|
|
||
|
func addIPRecords(ctx context.Context, logEntry *log.Entry, updaters map[string]updater.Updater, zones map[string]*types.Zone, recordType string, ip net.IP) error {
|
||
|
|
||
|
var (
|
||
|
errorChannel = make(chan error, len(zones))
|
||
|
wg = new(sync.WaitGroup)
|
||
|
)
|
||
|
|
||
|
hostname, err := os.Hostname()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to get host name from kernel: %w", err)
|
||
|
}
|
||
|
hostname = strings.ToLower(hostname)
|
||
|
|
||
|
if !verifyHostname(hostname) {
|
||
|
return fmt.Errorf("host name not valid: %w", err)
|
||
|
}
|
||
|
|
||
|
for zoneName := range zones {
|
||
|
wg.Add(1)
|
||
|
|
||
|
go func(ctx context.Context, zoneName string, hostname string, recordType string, ip net.IP, wg *sync.WaitGroup) {
|
||
|
|
||
|
zoneLogger := logEntry.WithFields(log.Fields{
|
||
|
"zone": zoneName,
|
||
|
"hostname": hostname,
|
||
|
})
|
||
|
|
||
|
defer wg.Done()
|
||
|
|
||
|
pruneRecordCtx, cancle := context.WithTimeout(ctx, time.Second*15)
|
||
|
defer cancle()
|
||
|
|
||
|
fqdn := fmt.Sprintf("%v.%v", hostname, zoneName)
|
||
|
|
||
|
err := updaters[zoneName].AddRecord(pruneRecordCtx, fqdn, 60, recordType, ip.String())
|
||
|
if err != nil {
|
||
|
errorChannel <- fmt.Errorf("failed to remove record type %v for %v: %v", recordType, fqdn, err.Error())
|
||
|
return
|
||
|
}
|
||
|
|
||
|
zoneLogger.Info("dns-record successfully updated")
|
||
|
|
||
|
}(ctx, zoneName, hostname, recordType, ip, wg)
|
||
|
}
|
||
|
|
||
|
wg.Wait()
|
||
|
close(errorChannel)
|
||
|
|
||
|
for err := range errorChannel {
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func pruneRecords(ctx context.Context, updaters map[string]updater.Updater, zones map[string]*types.Zone) error {
|
||
|
|
||
|
var (
|
||
|
errorChannel = make(chan error, len(zones))
|
||
|
wg = new(sync.WaitGroup)
|
||
|
)
|
||
|
|
||
|
hostname, err := os.Hostname()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to get host name from kernel: %w", err)
|
||
|
}
|
||
|
hostname = strings.ToLower(hostname)
|
||
|
|
||
|
if !verifyHostname(hostname) {
|
||
|
return fmt.Errorf("host name not valid: %w", err)
|
||
|
}
|
||
|
|
||
|
for zoneName := range zones {
|
||
|
wg.Add(1)
|
||
|
|
||
|
go func(zoneName string, hostname string, errorChannel chan<- error, wg *sync.WaitGroup) {
|
||
|
defer wg.Done()
|
||
|
|
||
|
pruneRecordCtx, cancle := context.WithTimeout(ctx, time.Second*15)
|
||
|
defer cancle()
|
||
|
|
||
|
fqdn := fmt.Sprintf("%v.%v", hostname, zoneName)
|
||
|
|
||
|
err := updaters[zoneName].PruneRecords(pruneRecordCtx, fqdn)
|
||
|
if err != nil {
|
||
|
errorChannel <- fmt.Errorf("failed to prune %v: %v", fqdn, err)
|
||
|
return
|
||
|
}
|
||
|
}(zoneName, hostname, errorChannel, wg)
|
||
|
}
|
||
|
|
||
|
wg.Wait()
|
||
|
close(errorChannel)
|
||
|
|
||
|
for err := range errorChannel {
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func removeIPRecords(ctx context.Context, logEntry *log.Entry, updaters map[string]updater.Updater, zones map[string]*types.Zone, recordType string, ip net.IP) error {
|
||
|
|
||
|
var (
|
||
|
errorChannel = make(chan error, len(zones))
|
||
|
wg = new(sync.WaitGroup)
|
||
|
)
|
||
|
|
||
|
hostname, err := os.Hostname()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to get host name from kernel: %w", err)
|
||
|
}
|
||
|
hostname = strings.ToLower(hostname)
|
||
|
|
||
|
if !verifyHostname(hostname) {
|
||
|
return fmt.Errorf("host name not valid: %w", err)
|
||
|
}
|
||
|
|
||
|
for zoneName := range zones {
|
||
|
wg.Add(1)
|
||
|
|
||
|
go func(ctx context.Context, zoneName string, hostname string, recordType string, wg *sync.WaitGroup) {
|
||
|
defer wg.Done()
|
||
|
|
||
|
zoneLogger := logEntry.WithFields(log.Fields{
|
||
|
"zone": zoneName,
|
||
|
"hostname": hostname,
|
||
|
})
|
||
|
|
||
|
pruneRecordCtx, cancle := context.WithTimeout(ctx, time.Second*15)
|
||
|
defer cancle()
|
||
|
|
||
|
fqdn := fmt.Sprintf("%v.%v", hostname, zoneName)
|
||
|
|
||
|
err := updaters[zoneName].DeleteRecord(pruneRecordCtx, fqdn, recordType)
|
||
|
if err != nil {
|
||
|
errorChannel <- fmt.Errorf("failed to remove record type %v for %v: %v", recordType, fqdn, err.Error())
|
||
|
return
|
||
|
}
|
||
|
|
||
|
zoneLogger.Info("dns-record successfully removed")
|
||
|
|
||
|
}(ctx, zoneName, hostname, recordType, wg)
|
||
|
}
|
||
|
|
||
|
wg.Wait()
|
||
|
close(errorChannel)
|
||
|
|
||
|
for err := range errorChannel {
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// validHostname matches one or more alphanumeric groups which are joined by
// single hyphens. Dots, leading or trailing hyphens and the empty string do
// not match.
var validHostname = regexp.MustCompile(`^[a-zA-Z0-9]+([\-][a-zA-Z0-9]+)*$`)

// verifyHostname reports whether hostname is valid. A valid host name matches
// validHostname, which means it contains no dot, and is none of the reserved
// names local, localhost, localdomain or orbisos. The reserved names are
// compared case-sensitively.
func verifyHostname(hostname string) bool {
	switch hostname {
	case "local", "localhost", "localdomain", "orbisos":
		return false
	}

	return validHostname.MatchString(hostname)
}
|