// set-sshkeys/main.go (264 lines, 4.9 KiB, Go)

package main
import (
"bufio"
"errors"
"fmt"
"log"
"os"
"os/user"
"path/filepath"
"strings"
"git.cryptic.systems/volker.raschek/go-logger"
"github.com/spf13/cobra"
)
var (
	// version is reported by the root command's --version flag.
	// NOTE(review): never assigned in this file; presumably injected at build
	// time (e.g. -ldflags "-X main.version=...") — confirm.
	version string

	// flogger is the package-wide logger; main initialises it before the
	// command runs.
	flogger logger.Logger
)
// main builds the set-sshkeys root command, registers its flags and runs it.
//
// Flags:
//   - --authorized_keys: file with the public keys to merge into (or remove
//     from) the user's authorized_keys
//   - --remove: remove those keys instead of adding them
//   - --user: account whose ~/.ssh/authorized_keys is modified
//
// If the command fails, the process exits with status 1 so scripts and init
// systems can detect the failure. Cobra reports the error text itself unless
// SilenceErrors is set.
func main() {
	flogger = logger.NewLogger(logger.LogLevelDebug)

	// Named cmd (not rootCmd) so the RunE handler function is not shadowed.
	cmd := &cobra.Command{
		Use:     "set-sshkeys",
		RunE:    rootCmd,
		Version: version,
	}
	cmd.Flags().String("authorized_keys", "/etc/set-sshkeys/authorized_keys", "Public ssh keys which should be merged with users existing ssh keys")
	cmd.Flags().Bool("remove", false, "Remove public ssh keys")
	cmd.Flags().String("user", "root", "For which user the public SSH keys should be defined")

	if err := cmd.Execute(); err != nil {
		os.Exit(1)
	}
}
// addSSHKeys appends to sshKeys every key of newSSHKeys that is not already
// present (equality as decided by sshKey.Compare) and returns the result.
//
// Existing entries keep their position and alias. If newSSHKeys contains the
// same key more than once, only its first occurrence is appended. Like any
// append, the returned slice may share storage with sshKeys.
func addSSHKeys(sshKeys []*sshKey, newSSHKeys []*sshKey) []*sshKey {
	for _, candidate := range newSSHKeys {
		known := false
		for _, existing := range sshKeys {
			if existing.Compare(candidate) {
				known = true
				break
			}
		}
		if !known {
			sshKeys = append(sshKeys, candidate)
		}
	}
	return sshKeys
}
// createAutorizationFile makes sure authorizedKeyFile and all of its parent
// directories exist.
//
// Missing directories are created with mode 0700 and a missing file with
// mode 0600 (both still subject to the process umask). An existing file is
// opened without truncation, so its content and mode are left untouched.
//
// NOTE(review): ownership is not adjusted. When this runs as root on behalf
// of another user, the new directory and file are owned by root. With these
// modes, that user cannot access them, so consider chowning them to the
// target user.
func createAutorizationFile(authorizedKeyFile string) error {
	// Must be an octal literal: the previous decimal 700 (== 0o1274) produced
	// a directory the owner could not even list.
	if err := os.MkdirAll(filepath.Dir(authorizedKeyFile), 0700); err != nil {
		return err
	}
	// O_CREATE without O_TRUNC: create if absent, never wipe existing keys.
	f, err := os.OpenFile(authorizedKeyFile, os.O_WRONLY|os.O_CREATE, 0600)
	if err != nil {
		return err
	}
	return f.Close()
}
// readSSHKeys parses authorizedKeyFile and returns its keys in file order.
// A key that appears again later (per sshKey.Compare) is dropped.
//
// See parseSSHKeyLine for how each line is interpreted. It returns an error
// if the file cannot be opened or read, or if a key cannot be constructed.
func readSSHKeys(authorizedKeyFile string) ([]*sshKey, error) {
	f, err := os.Open(authorizedKeyFile)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	sshKeys := make([]*sshKey, 0)
	s := bufio.NewScanner(f)
Loop:
	for s.Scan() {
		key, err := parseSSHKeyLine(s.Text())
		if err != nil {
			return nil, err
		}
		if key == nil {
			continue
		}
		for i := range sshKeys {
			if sshKeys[i].Compare(key) {
				continue Loop
			}
		}
		sshKeys = append(sshKeys, key)
	}
	// Without this check a read error would silently truncate the key list,
	// and those keys would then be lost when the file is rewritten.
	if err := s.Err(); err != nil {
		return nil, fmt.Errorf("reading %v: %w", authorizedKeyFile, err)
	}
	return sshKeys, nil
}

// parseSSHKeyLine converts one authorized_keys line into an sshKey.
//
// Fields are separated by any run of whitespace:
//   - the first field is the algorithm
//   - the second field is the public key
//   - any remaining fields are joined with single spaces into the alias, so
//     comments containing spaces survive
//
// It returns (nil, nil) for lines carrying no key: blank lines, '#'
// comments, and lines with a single field (the latter is logged).
//
// NOTE(review): authorized_keys lines may begin with an options field
// (e.g. `no-pty ssh-ed25519 AAAA...`). Such lines are not recognised
// specially; the options end up in the algorithm field.
func parseSSHKeyLine(line string) (*sshKey, error) {
	line = strings.TrimSpace(line)
	if line == "" || strings.HasPrefix(line, "#") {
		return nil, nil
	}
	fields := strings.Fields(line)
	if len(fields) < 2 {
		log.Printf("Skipping line: require at least two parts, got %v", len(fields))
		return nil, nil
	}
	key, err := newSSHKey(fields[0], fields[1])
	if err != nil {
		return nil, err
	}
	if len(fields) > 2 {
		key.SetAlias(strings.Join(fields[2:], " "))
	}
	return key, nil
}
// removeSSHKeys returns the keys of sshKeys that do not match any key in
// removeSSHKeys (equality as decided by sshKey.Compare), in their original
// order. Every matching occurrence is removed, including adjacent duplicates.
//
// A new slice is returned and sshKeys itself is left unmodified. The previous
// in-place deletion inside `range sshKeys` indexed past the shrunken slice
// and panicked.
func removeSSHKeys(sshKeys []*sshKey, removeSSHKeys []*sshKey) []*sshKey {
	kept := make([]*sshKey, 0, len(sshKeys))
Next:
	for _, key := range sshKeys {
		for _, unwanted := range removeSSHKeys {
			if key.Compare(unwanted) {
				continue Next
			}
		}
		kept = append(kept, key)
	}
	return kept
}
// rootCmd is the RunE handler of the set-sshkeys command.
//
// Steps:
//  1. Fail if the --authorized_keys file is not accessible.
//  2. Resolve the --user account and ensure its ~/.ssh/authorized_keys
//     exists, creating it if missing.
//  3. Read both key files.
//  4. Merge the listed keys into the user's keys, or remove them when
//     --remove is set.
//  5. Rewrite the user's file with the result.
func rootCmd(cmd *cobra.Command, args []string) error {
	flags := cmd.Flags()

	remove, err := flags.GetBool("remove")
	if err != nil {
		return err
	}

	etcAuthorizedKeyFile, err := flags.GetString("authorized_keys")
	if err != nil {
		return err
	}
	if _, err := os.Stat(etcAuthorizedKeyFile); err != nil {
		return err
	}

	username, err := flags.GetString("user")
	if err != nil {
		return err
	}
	// Named account so the os/user package stays reachable by its own name.
	account, err := user.Lookup(username)
	if err != nil {
		return err
	}

	userAuthorizedKeyFile := filepath.Join(account.HomeDir, ".ssh", "authorized_keys")
	if _, err := os.Stat(userAuthorizedKeyFile); err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			return err
		}
		flogger.Debug("Create authorization file %v", userAuthorizedKeyFile)
		if err := createAutorizationFile(userAuthorizedKeyFile); err != nil {
			return err
		}
	}

	etcAuthorizedKeys, err := readSSHKeys(etcAuthorizedKeyFile)
	if err != nil {
		return err
	}
	userAuthorizedKeys, err := readSSHKeys(userAuthorizedKeyFile)
	if err != nil {
		return err
	}

	merge := addSSHKeys
	if remove {
		merge = removeSSHKeys
	}
	return writeSSHKeys(userAuthorizedKeyFile, merge(userAuthorizedKeys, etcAuthorizedKeys))
}
// writeSSHKeys replaces the contents of authorizedKeyFile with sshKeys, one
// key per line as "<algorithm> <pubKey>" followed by " <alias>" when the
// alias is non-empty.
//
// The parent directory and file are created first if necessary; the file is
// then truncated and rewritten. Errors from writing, flushing, and closing
// are all returned, so a short or failed write is never reported as success.
func writeSSHKeys(authorizedKeyFile string, sshKeys []*sshKey) error {
	if err := createAutorizationFile(authorizedKeyFile); err != nil {
		return err
	}
	f, err := os.Create(authorizedKeyFile)
	if err != nil {
		return err
	}

	w := bufio.NewWriter(f)
	for _, key := range sshKeys {
		var werr error
		if len(key.alias) > 0 {
			_, werr = fmt.Fprintf(w, "%v %v %v\n", key.algorithm, key.pubKey, key.alias)
		} else {
			_, werr = fmt.Fprintf(w, "%v %v\n", key.algorithm, key.pubKey)
		}
		if werr != nil {
			f.Close()
			return werr
		}
	}
	if err := w.Flush(); err != nil {
		f.Close()
		return err
	}
	// Close can report a deferred write error, so its result matters.
	return f.Close()
}
// sshKey is one key entry of an authorized_keys file.
//
// Fields:
//   - algorithm: the first field of the line
//   - pubKey: the encoded public key, the second field
//   - alias: an optional trailing comment written after the key
type sshKey struct {
	algorithm string
	pubKey    string
	alias     string
}

// Compare reports whether s and k denote the same key, i.e. whether they
// have equal algorithm and public key. The alias is ignored, so the same key
// with a different comment counts as equal.
func (s *sshKey) Compare(k *sshKey) bool {
	return s.algorithm == k.algorithm && s.pubKey == k.pubKey
}

// SetAlias sets the alias only when none is set yet; an existing alias is
// never overwritten.
func (s *sshKey) SetAlias(alias string) {
	if s.alias == "" {
		s.alias = alias
	}
}

// Validate returns an error naming the first empty required attribute,
// checking algorithm before pubKey, or nil when both are set.
func (s *sshKey) Validate() error {
	// An ordered slice instead of a map: map iteration order is random, which
	// made the reported attribute nondeterministic when both were missing.
	required := []struct{ name, value string }{
		{"algorithm", s.algorithm},
		{"pubKey", s.pubKey},
	}
	for _, attr := range required {
		if attr.value == "" {
			return fmt.Errorf("Missing attribute %v", attr.name)
		}
	}
	return nil
}

// newSSHKey returns a key with the given algorithm and public key and an
// empty alias. The inputs are not validated, and the error result is
// currently always nil.
func newSSHKey(algorithm string, pubKey string) (*sshKey, error) {
	return &sshKey{
		algorithm: algorithm,
		pubKey:    pubKey,
	}, nil
}
}