package daemon

import (
	"context"
	"fmt"
	"net"
	"os"
	"os/signal"
	"regexp"
	"strings"
	"sync"
	"syscall"
	"time"

	"git.cryptic.systems/volker.raschek/dyndns-client/pkg/types"
	"git.cryptic.systems/volker.raschek/dyndns-client/pkg/updater"
	"github.com/asaskevich/govalidator"
	log "github.com/sirupsen/logrus"
	"github.com/vishvananda/netlink"
)

// Start runs the daemon and blocks until SIGINT or SIGTERM is received.
//
// It subscribes to address notifications from the kernel via netlink
// (including the already existing addresses), creates one updater per
// configured zone and prunes the records of this host. Afterwards every
// address notification of an interface listed in cnf.Ifaces is translated into
// an add or delete of the matching A or AAAA record in all configured zones.
//
// Errors during start-up, a failing link lookup and an unexpectedly closed
// subscription terminate the process via log.Fatal. Errors while adding or
// removing records are only logged.
func Start(cnf *types.Config) {
	addrUpdates := make(chan netlink.AddrUpdate, 1)
	done := make(chan struct{}, 1)
	err := netlink.AddrSubscribeWithOptions(addrUpdates, done, netlink.AddrSubscribeOptions{
		ListExisting: true,
	})
	if err != nil {
		log.Fatalf("failed to subscribe netlink notifications from kernel: %v", err.Error())
	}
	// Closing done asks netlink to stop the subscription once Start returns.
	defer close(done)

	interruptChannel := make(chan os.Signal, 1)
	signal.Notify(interruptChannel, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(interruptChannel)

	daemonCtx, cancel := context.WithCancel(context.Background())
	defer cancel()

	updaters, err := getUpdaterForEachZone(cnf)
	if err != nil {
		log.Fatalf("%v", err.Error())
	}

	if err := pruneRecords(daemonCtx, updaters, cnf.Zones); err != nil {
		log.Fatalf("%v", err.Error())
	}

	for {
		select {
		case update, ok := <-addrUpdates:
			if !ok {
				// A closed channel would deliver zero-value updates forever and
				// turn this loop into a busy loop, so give up instead.
				log.Fatal("netlink address subscription was closed unexpectedly")
			}

			ip := update.LinkAddress.IP
			interfaceLogger := log.WithFields(log.Fields{
				"ip": ip.String(),
			})

			// List the links after the notification arrived, otherwise an
			// interface created while waiting would not be found.
			interfaces, err := netlink.LinkList()
			if err != nil {
				log.Fatalf("%v", err.Error())
			}

			// search interface by index
			iface, err := searchInterfaceByIndex(update.LinkIndex, interfaces)
			if err != nil {
				log.Errorf("%v", err.Error())
				continue
			}
			interfaceLogger = interfaceLogger.WithField("device", iface.Attrs().Name)

			var recordType string
			switch {
			case govalidator.IsIPv4(ip.String()):
				recordType = "A"
			case govalidator.IsIPv6(ip.String()):
				recordType = "AAAA"
			default:
				interfaceLogger.Error("failed to detect record type")
				continue
			}
			interfaceLogger = interfaceLogger.WithField("rr", recordType)

			interfaceLogger.Debug("receive kernel notification for interface")

			// filter out not configured interfaces
			if !matchInterfaces(iface.Attrs().Name, cnf.Ifaces) {
				interfaceLogger.Warn("interface is not part of the allowed interface list")
				continue
			}

			// Loopback and link-local addresses (fe80::/10, 169.254.0.0/16) are
			// not reachable from other networks and must not be published.
			if ip.IsLoopback() || ip.IsLinkLocalUnicast() {
				interfaceLogger.Warn("address is a loopback or link-local address")
				continue
			}

			// decide if trigger a add or delete event
			if update.NewAddr {
				err = addIPRecords(daemonCtx, interfaceLogger, updaters, cnf.Zones, recordType, ip)
			} else {
				err = removeIPRecords(daemonCtx, interfaceLogger, updaters, cnf.Zones, recordType)
			}
			if err != nil {
				interfaceLogger.Error(err.Error())
			}

		case killSignal := <-interruptChannel:
			log.Debugf("got signal: %v", killSignal)
			log.Debugf("daemon was killed by: %v", killSignal)
			return
		}
	}
}

// getUpdaterForEachZone creates one updater per zone of config.Zones, keyed by
// the zone name. Each updater is built by updater.NewNSUpdate from the DNS
// server of the zone and the TSIG key looked up in config.TSIGKeys by the
// zone's TSIGKeyName.
//
// It returns an error naming the affected zone when a zone entry is nil or the
// updater can not be created. No partial result is returned in that case.
func getUpdaterForEachZone(config *types.Config) (map[string]updater.Updater, error) {
	updaterCollection := make(map[string]updater.Updater, len(config.Zones))

	for zoneName, zone := range config.Zones {
		// Guard against an empty zone entry instead of panicking on the
		// field access below.
		if zone == nil {
			return nil, fmt.Errorf("zone %v has no configuration", zoneName)
		}

		nsUpdater, err := updater.NewNSUpdate(zone.DNSServer, config.TSIGKeys[zone.TSIGKeyName])
		if err != nil {
			return nil, fmt.Errorf("failed to create updater for zone %v: %w", zoneName, err)
		}
		updaterCollection[zoneName] = nsUpdater
	}

	return updaterCollection, nil
}

// matchInterfaces reports whether the interface name iface is contained in
// ifaces. Names are compared for exact equality; an empty or nil list never
// matches.
func matchInterfaces(iface string, ifaces []string) bool {
	found := false
	for pos := 0; pos < len(ifaces) && !found; pos++ {
		found = ifaces[pos] == iface
	}
	return found
}

// searchInterfaceByIndex returns the first link of interfaces whose attribute
// index equals index. An error is returned when no link matches.
func searchInterfaceByIndex(index int, interfaces []netlink.Link) (netlink.Link, error) {
	for pos := range interfaces {
		if link := interfaces[pos]; link.Attrs().Index == index {
			return link, nil
		}
	}

	return nil, fmt.Errorf("can not find interface by index %v", index)
}

// addIPRecords adds a record of type recordType with the value ip for this
// host to every zone of zones. The record name is the lower-cased host name of
// the machine joined with the zone name; the record is added with a TTL of 60.
//
// All zones are updated concurrently, each with a timeout of 15 seconds
// derived from ctx. Successful updates are logged via logEntry.
//
// An error is returned when the host name can not be determined or is not
// valid, when a zone has no updater in updaters, or when adding a record
// failed. If several zones fail, only one of the errors is returned.
func addIPRecords(ctx context.Context, logEntry *log.Entry, updaters map[string]updater.Updater, zones map[string]*types.Zone, recordType string, ip net.IP) error {
	hostname, err := os.Hostname()
	if err != nil {
		return fmt.Errorf("failed to get host name from kernel: %w", err)
	}
	hostname = strings.ToLower(hostname)

	if !verifyHostname(hostname) {
		return fmt.Errorf("host name %q not valid", hostname)
	}

	var (
		// Buffered with one slot per zone, so reporting an error never blocks.
		errorChannel = make(chan error, len(zones))
		wg           sync.WaitGroup
	)

	for zoneName := range zones {
		// Calling a method on a missing (nil) updater would panic inside the
		// goroutine and take the whole daemon down.
		zoneUpdater := updaters[zoneName]
		if zoneUpdater == nil {
			errorChannel <- fmt.Errorf("no updater available for zone %v", zoneName)
			continue
		}

		wg.Add(1)

		go func(zoneName string, zoneUpdater updater.Updater) {
			defer wg.Done()

			zoneLogger := logEntry.WithFields(log.Fields{
				"zone":     zoneName,
				"hostname": hostname,
			})

			addRecordCtx, cancel := context.WithTimeout(ctx, time.Second*15)
			defer cancel()

			fqdn := fmt.Sprintf("%v.%v", hostname, zoneName)

			err := zoneUpdater.AddRecord(addRecordCtx, fqdn, 60, recordType, ip.String())
			if err != nil {
				errorChannel <- fmt.Errorf("failed to add record type %v for %v: %w", recordType, fqdn, err)
				return
			}

			zoneLogger.Info("dns-record successfully updated")
		}(zoneName, zoneUpdater)
	}

	wg.Wait()
	close(errorChannel)

	// Yields the first reported error, or nil when the closed channel is empty.
	return <-errorChannel
}

// pruneRecords calls PruneRecords on the updater of every zone of zones for
// the name of this host. The name is the lower-cased host name of the machine
// joined with the zone name.
//
// All zones are processed concurrently, each with a timeout of 15 seconds
// derived from ctx.
//
// An error is returned when the host name can not be determined or is not
// valid, when a zone has no updater in updaters, or when pruning failed. If
// several zones fail, only one of the errors is returned.
func pruneRecords(ctx context.Context, updaters map[string]updater.Updater, zones map[string]*types.Zone) error {
	hostname, err := os.Hostname()
	if err != nil {
		return fmt.Errorf("failed to get host name from kernel: %w", err)
	}
	hostname = strings.ToLower(hostname)

	if !verifyHostname(hostname) {
		return fmt.Errorf("host name %q not valid", hostname)
	}

	var (
		// Buffered with one slot per zone, so reporting an error never blocks.
		errorChannel = make(chan error, len(zones))
		wg           sync.WaitGroup
	)

	for zoneName := range zones {
		// Calling a method on a missing (nil) updater would panic inside the
		// goroutine and take the whole daemon down.
		zoneUpdater := updaters[zoneName]
		if zoneUpdater == nil {
			errorChannel <- fmt.Errorf("no updater available for zone %v", zoneName)
			continue
		}

		wg.Add(1)

		go func(zoneName string, zoneUpdater updater.Updater) {
			defer wg.Done()

			pruneRecordCtx, cancel := context.WithTimeout(ctx, time.Second*15)
			defer cancel()

			fqdn := fmt.Sprintf("%v.%v", hostname, zoneName)

			if err := zoneUpdater.PruneRecords(pruneRecordCtx, fqdn); err != nil {
				errorChannel <- fmt.Errorf("failed to prune %v: %w", fqdn, err)
			}
		}(zoneName, zoneUpdater)
	}

	wg.Wait()
	close(errorChannel)

	// Yields the first reported error, or nil when the closed channel is empty.
	return <-errorChannel
}

// removeIPRecords deletes the record of type recordType of this host from
// every zone of zones. The record name is the lower-cased host name of the
// machine joined with the zone name.
//
// All zones are updated concurrently, each with a timeout of 15 seconds
// derived from ctx. Successful removals are logged via logEntry.
//
// An error is returned when the host name can not be determined or is not
// valid, when a zone has no updater in updaters, or when deleting a record
// failed. If several zones fail, only one of the errors is returned.
func removeIPRecords(ctx context.Context, logEntry *log.Entry, updaters map[string]updater.Updater, zones map[string]*types.Zone, recordType string) error {
	hostname, err := os.Hostname()
	if err != nil {
		return fmt.Errorf("failed to get host name from kernel: %w", err)
	}
	hostname = strings.ToLower(hostname)

	if !verifyHostname(hostname) {
		return fmt.Errorf("host name %q not valid", hostname)
	}

	var (
		// Buffered with one slot per zone, so reporting an error never blocks.
		errorChannel = make(chan error, len(zones))
		wg           sync.WaitGroup
	)

	for zoneName := range zones {
		// Calling a method on a missing (nil) updater would panic inside the
		// goroutine and take the whole daemon down.
		zoneUpdater := updaters[zoneName]
		if zoneUpdater == nil {
			errorChannel <- fmt.Errorf("no updater available for zone %v", zoneName)
			continue
		}

		wg.Add(1)

		go func(zoneName string, zoneUpdater updater.Updater) {
			defer wg.Done()

			zoneLogger := logEntry.WithFields(log.Fields{
				"zone":     zoneName,
				"hostname": hostname,
			})

			deleteRecordCtx, cancel := context.WithTimeout(ctx, time.Second*15)
			defer cancel()

			fqdn := fmt.Sprintf("%v.%v", hostname, zoneName)

			err := zoneUpdater.DeleteRecord(deleteRecordCtx, fqdn, recordType)
			if err != nil {
				errorChannel <- fmt.Errorf("failed to remove record type %v for %v: %w", recordType, fqdn, err)
				return
			}

			zoneLogger.Info("dns-record successfully removed")
		}(zoneName, zoneUpdater)
	}

	wg.Wait()
	close(errorChannel)

	// Yields the first reported error, or nil when the closed channel is empty.
	return <-errorChannel
}

// verifyHostname reports whether hostname is valid. A valid host name matches
// validHostname, which means it is a single label without any dot, and is
// none of the reserved names local, localhost, localdomain or orbisos. The
// reserved names are compared case-insensitively, because the pattern accepts
// upper-case letters as well.
func verifyHostname(hostname string) bool {
	if !validHostname.MatchString(hostname) {
		return false
	}

	switch strings.ToLower(hostname) {
	case "local", "localhost", "localdomain", "orbisos":
		return false
	}

	return true
}

var (
	// validHostname matches one or more groups of ASCII letters and digits
	// which are separated by single hyphens. Empty strings, dots, leading,
	// trailing and consecutive hyphens do not match.
	validHostname = regexp.MustCompile(`^[a-zA-Z0-9]+([\-][a-zA-Z0-9]+)*$`)
)