264 lines
4.9 KiB
Go
264 lines
4.9 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"os/user"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"git.cryptic.systems/volker.raschek/go-logger"
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
var (
|
|
flogger logger.Logger
|
|
version string
|
|
)
|
|
|
|
func main() {
|
|
flogger = logger.NewLogger(logger.LogLevelDebug)
|
|
|
|
rootCmd := cobra.Command{
|
|
Use: "set-sshkeys",
|
|
RunE: rootCmd,
|
|
Version: version,
|
|
}
|
|
rootCmd.Flags().String("authorized_keys", "/etc/set-sshkeys/authorized_keys", "Public ssh keys which should be merged with users existing ssh keys")
|
|
rootCmd.Flags().Bool("remove", false, "Remove public ssh keys")
|
|
rootCmd.Flags().String("user", "root", "For which user the public SSH keys should be defined")
|
|
|
|
rootCmd.Execute()
|
|
}
|
|
|
|
func addSSHKeys(sshKeys []*sshKey, newSSHKeys []*sshKey) []*sshKey {
|
|
Label:
|
|
for i := range newSSHKeys {
|
|
for j := range sshKeys {
|
|
if sshKeys[j].Compare(newSSHKeys[i]) {
|
|
continue Label
|
|
}
|
|
}
|
|
sshKeys = append(sshKeys, newSSHKeys[i])
|
|
}
|
|
return sshKeys
|
|
}
|
|
|
|
// createAutorizationFile makes sure authorizedKeyFile exists.
//
// Missing parent directories are created with mode 0700, and a missing file
// is created empty with mode 0600 (both before umask). An already existing
// file is left as it is; in particular its contents are not truncated and
// its mode is not changed.
func createAutorizationFile(authorizedKeyFile string) error {
	// Octal literal: owner-only rwx. A decimal 700 would be 0o1274.
	if err := os.MkdirAll(filepath.Dir(authorizedKeyFile), 0o700); err != nil {
		return err
	}

	// O_CREATE without O_TRUNC: create the file if needed, never wipe existing keys.
	f, err := os.OpenFile(authorizedKeyFile, os.O_WRONLY|os.O_CREATE, 0o600)
	if err != nil {
		return err
	}
	return f.Close()
}
|
|
|
|
func readSSHKeys(authorizedKeyFile string) ([]*sshKey, error) {
|
|
|
|
f, err := os.Open(authorizedKeyFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sshKeys := make([]*sshKey, 0)
|
|
|
|
s := bufio.NewScanner(f)
|
|
|
|
Loop:
|
|
for s.Scan() {
|
|
line := s.Text()
|
|
parts := strings.Split(line, " ")
|
|
|
|
switch len(parts) {
|
|
case 2:
|
|
|
|
algorithm := parts[0]
|
|
pubKey := parts[1]
|
|
|
|
sshKey, err := newSSHKey(algorithm, pubKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i := range sshKeys {
|
|
if sshKeys[i].Compare(sshKey) {
|
|
continue Loop
|
|
}
|
|
}
|
|
|
|
sshKeys = append(sshKeys, sshKey)
|
|
|
|
case 3:
|
|
|
|
algorithm := parts[0]
|
|
pubKey := parts[1]
|
|
alias := parts[2]
|
|
|
|
sshKey, err := newSSHKey(algorithm, pubKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sshKey.SetAlias(alias)
|
|
|
|
for i := range sshKeys {
|
|
if sshKeys[i].Compare(sshKey) {
|
|
continue Loop
|
|
}
|
|
}
|
|
|
|
sshKeys = append(sshKeys, sshKey)
|
|
|
|
default:
|
|
log.Printf("Require two and optional three parts for each line. Get %v parts", len(parts))
|
|
}
|
|
}
|
|
|
|
return sshKeys, nil
|
|
}
|
|
|
|
func removeSSHKeys(sshKeys []*sshKey, removeSSHKeys []*sshKey) []*sshKey {
|
|
for i := range removeSSHKeys {
|
|
for j := range sshKeys {
|
|
if sshKeys[j].Compare(removeSSHKeys[i]) {
|
|
sshKeys = append(sshKeys[:j], sshKeys[j+1:]...)
|
|
}
|
|
}
|
|
}
|
|
return sshKeys
|
|
}
|
|
|
|
func rootCmd(cmd *cobra.Command, args []string) error {
|
|
|
|
remove, err := cmd.Flags().GetBool("remove")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
etcAuthorizedKeyFile, err := cmd.Flags().GetString("authorized_keys")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = os.Stat(etcAuthorizedKeyFile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
username, err := cmd.Flags().GetString("user")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user, err := user.Lookup(username)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
userAuthorizedKeyFile := filepath.Join(user.HomeDir, ".ssh", "authorized_keys")
|
|
|
|
_, err = os.Stat(userAuthorizedKeyFile)
|
|
switch {
|
|
case err == nil:
|
|
break
|
|
case errors.Is(err, os.ErrNotExist):
|
|
flogger.Debug("Create authorization file %v", userAuthorizedKeyFile)
|
|
if err := createAutorizationFile(userAuthorizedKeyFile); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return err
|
|
}
|
|
|
|
etcAuthorizedKeys, err := readSSHKeys(etcAuthorizedKeyFile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
userAuthorizedKeys, err := readSSHKeys(userAuthorizedKeyFile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if remove {
|
|
userAuthorizedKeys = removeSSHKeys(userAuthorizedKeys, etcAuthorizedKeys)
|
|
} else {
|
|
userAuthorizedKeys = addSSHKeys(userAuthorizedKeys, etcAuthorizedKeys)
|
|
}
|
|
|
|
return writeSSHKeys(userAuthorizedKeyFile, userAuthorizedKeys)
|
|
}
|
|
|
|
func writeSSHKeys(authorizedKeyFile string, sshKeys []*sshKey) error {
|
|
|
|
if err := createAutorizationFile(authorizedKeyFile); err != nil {
|
|
return err
|
|
}
|
|
|
|
f, err := os.Create(authorizedKeyFile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for i := range sshKeys {
|
|
if len(sshKeys[i].alias) > 0 {
|
|
fmt.Fprintf(f, "%v %v %v\n", sshKeys[i].algorithm, sshKeys[i].pubKey, sshKeys[i].alias)
|
|
} else {
|
|
fmt.Fprintf(f, "%v %v\n", sshKeys[i].algorithm, sshKeys[i].pubKey)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// sshKey is a single public key entry of an authorized_keys file.
type sshKey struct {
	algorithm string // key type, first field of the line (e.g. "ssh-ed25519")
	pubKey    string // key material, second field (presumably base64 — not checked here)
	alias     string // optional trailing comment; empty when the line has none
}

// Compare reports whether s and k denote the same key, i.e. have equal
// algorithm and pubKey. The alias is not considered, so the same key with a
// different comment counts as a match.
func (s *sshKey) Compare(k *sshKey) bool {
	return s.algorithm == k.algorithm && s.pubKey == k.pubKey
}

// SetAlias sets the alias only if none is set yet; an existing alias is
// never overwritten.
func (s *sshKey) SetAlias(alias string) {
	if s.alias == "" {
		s.alias = alias
	}
}

// Validate returns an error naming the first empty required attribute. It
// checks algorithm before pubKey, in a fixed order, so the error is
// deterministic when both are missing.
func (s *sshKey) Validate() error {
	required := []struct{ name, value string }{
		{"algorithm", s.algorithm},
		{"pubKey", s.pubKey},
	}
	for _, attr := range required {
		if attr.value == "" {
			return fmt.Errorf("Missing attribute %v", attr.name)
		}
	}
	return nil
}

// newSSHKey builds a key without alias from algorithm and pubKey. It returns
// the error from Validate when either value is empty.
func newSSHKey(algorithm string, pubKey string) (*sshKey, error) {
	key := &sshKey{
		algorithm: algorithm,
		pubKey:    pubKey,
	}
	if err := key.Validate(); err != nil {
		return nil, err
	}
	return key, nil
}
|