set-sshkeys/main.go

299 lines
5.4 KiB
Go

package main
import (
"bufio"
"errors"
"fmt"
"io"
"log"
"os"
"os/user"
"path/filepath"
"strconv"
"strings"
"github.com/spf13/cobra"
)
// version is what `set-sshkeys --version` reports (it is wired into the cobra
// command in main). It is empty unless overridden at build time — presumably
// via -ldflags "-X main.version=..." (NOTE(review): confirm against the build).
var version string
// main builds the set-sshkeys command, registers its flags and runs it.
// A failing command makes the process exit with status 1, so shell scripts
// and configuration management can detect the failure.
func main() {
	// Named cmd (not rootCmd) so the rootCmd handler function is not shadowed.
	cmd := &cobra.Command{
		Use:     "set-sshkeys",
		RunE:    rootCmd,
		Version: version,
	}
	cmd.Flags().String("authorized_keys", "/etc/set-sshkeys/authorized_keys", "Public ssh keys which should be merged with users existing ssh keys")
	cmd.Flags().Bool("remove", false, "Remove public ssh keys")
	cmd.Flags().String("user", "root", "For which user the public SSH keys should be defined")
	// cobra prints the error itself by default (SilenceErrors is not set here);
	// only the non-zero exit status was missing.
	if err := cmd.Execute(); err != nil {
		os.Exit(1)
	}
}
// addSSHKeys merges newSSHKeys into sshKeys and returns the resulting slice.
//
// Candidates whose Validate fails are ignored. A candidate that Compare-matches
// a key already present (including one appended earlier in this same call) is
// not appended again. Instead its alias is offered to the first matching
// entry via SetAlias. The returned slice may share its backing array with
// sshKeys.
func addSSHKeys(sshKeys []*sshKey, newSSHKeys []*sshKey) []*sshKey {
	for _, candidate := range newSSHKeys {
		if candidate.Validate() != nil {
			continue
		}
		merged := false
		for _, existing := range sshKeys {
			if existing.Compare(candidate) {
				existing.SetAlias(candidate.alias)
				merged = true
				break
			}
		}
		if !merged {
			sshKeys = append(sshKeys, candidate)
		}
	}
	return sshKeys
}
// createAutorizationFile makes sure authorizedKeyFile exists. Missing parent
// directories are created with mode 0700, and a missing file is created empty
// with mode 0600 (before umask).
//
// Unlike os.Create, the file is opened without O_TRUNC, so calling this on an
// authorized_keys file that already holds keys leaves its contents intact.
func createAutorizationFile(authorizedKeyFile string) error {
	if err := os.MkdirAll(filepath.Dir(authorizedKeyFile), 0700); err != nil {
		return err
	}
	f, err := os.OpenFile(authorizedKeyFile, os.O_WRONLY|os.O_CREATE, 0600)
	if err != nil {
		return err
	}
	return f.Close()
}
// readSSHKeysFile parses the key file at authorizedKeyFile with readSSHKeys.
// An error from opening the file is returned as-is; the file handle is
// released before returning.
func readSSHKeysFile(authorizedKeyFile string) ([]*sshKey, error) {
	file, openErr := os.Open(authorizedKeyFile)
	if openErr != nil {
		return nil, openErr
	}
	// Read-only handle: a Close failure carries nothing worth reporting.
	defer file.Close()

	keys, parseErr := readSSHKeys(file)
	return keys, parseErr
}
// readSSHKeys parses authorized_keys-style lines from r.
//
// Each line is split on single spaces into at most three fields:
//   - the algorithm,
//   - the public key,
//   - an optional alias that holds the rest of the line verbatim.
//
// Keeping the remainder verbatim means multi-word comments survive a
// read/write round trip; previously such lines were skipped and then lost
// when the list was written back.
//
// Lines with fewer than two fields are skipped with a warning. Entries with
// the same algorithm and public key are de-duplicated. The first occurrence
// is kept, and a later alias is applied only if the first had none (see
// SetAlias).
//
// A read error from r is returned rather than yielding a silently truncated
// list.
//
// NOTE(review): lines are not interpreted beyond this split. For example, an
// options prefix (`from="..." ssh-ed25519 ...`) lands in the algorithm field.
// It is written back unchanged only because fields are re-joined with single
// spaces.
func readSSHKeys(r io.Reader) ([]*sshKey, error) {
	sshKeys := make([]*sshKey, 0)
	s := bufio.NewScanner(r)
Loop:
	for s.Scan() {
		parts := strings.SplitN(s.Text(), " ", 3)
		if len(parts) < 2 {
			log.Printf("WARN: Require two and optional three parts for each line. Get %v parts - SKIP entry", len(parts))
			continue
		}
		key, err := newSSHKey(parts[0], parts[1])
		if err != nil {
			return nil, err
		}
		if len(parts) == 3 {
			key.SetAlias(parts[2])
		}
		for _, existing := range sshKeys {
			if existing.Compare(key) {
				existing.SetAlias(key.alias)
				continue Loop
			}
		}
		sshKeys = append(sshKeys, key)
	}
	// Without this check a scan failure (e.g. an over-long line) would end the
	// loop early and the caller would rewrite the file from a partial list.
	if err := s.Err(); err != nil {
		return nil, fmt.Errorf("reading ssh keys: %w", err)
	}
	return sshKeys, nil
}
// removeSSHKeys returns a freshly allocated slice with the entries of sshKeys
// that match none of removeSSHKeys. Matching uses Compare, so aliases are
// ignored. Order is preserved and sshKeys itself is not modified.
func removeSSHKeys(sshKeys []*sshKey, removeSSHKeys []*sshKey) []*sshKey {
	kept := make([]*sshKey, 0)
	for _, candidate := range sshKeys {
		drop := false
		for _, unwanted := range removeSSHKeys {
			if candidate.Compare(unwanted) {
				drop = true
				break
			}
		}
		if !drop {
			kept = append(kept, candidate)
		}
	}
	return kept
}
// rootCmd is the RunE handler of the set-sshkeys command.
//
// It reads the keys listed in the --authorized_keys file. It then merges them
// into, or with --remove strips them from, ~/.ssh/authorized_keys of the
// --user account, creating that file when it does not exist yet. args is
// unused.
func rootCmd(cmd *cobra.Command, args []string) error {
	remove, err := cmd.Flags().GetBool("remove")
	if err != nil {
		return err
	}
	etcAuthorizedKeyFile, err := cmd.Flags().GetString("authorized_keys")
	if err != nil {
		return err
	}
	if _, err := os.Stat(etcAuthorizedKeyFile); err != nil {
		return err
	}
	username, err := cmd.Flags().GetString("user")
	if err != nil {
		return err
	}
	// Named u (not user) so the os/user package is not shadowed.
	u, err := user.Lookup(username)
	if err != nil {
		return err
	}
	// Parse the source keys before touching anything under the user's home,
	// so a malformed source file leaves no stray empty file behind.
	etcAuthorizedKeys, err := readSSHKeysFile(etcAuthorizedKeyFile)
	if err != nil {
		return err
	}
	userAuthorizedKeyFile := filepath.Join(u.HomeDir, ".ssh", "authorized_keys")
	if err := ensureUserAuthorizedKeyFile(u, userAuthorizedKeyFile); err != nil {
		return err
	}
	userAuthorizedKeys, err := readSSHKeysFile(userAuthorizedKeyFile)
	if err != nil {
		return err
	}
	if remove {
		userAuthorizedKeys = removeSSHKeys(userAuthorizedKeys, etcAuthorizedKeys)
	} else {
		userAuthorizedKeys = addSSHKeys(userAuthorizedKeys, etcAuthorizedKeys)
	}
	return writeSSHKeysFile(u, userAuthorizedKeyFile, userAuthorizedKeys)
}

// ensureUserAuthorizedKeyFile creates authorizedKeyFile if it does not exist.
//
// When its parent directory (the user's ~/.ssh) had to be created too, that
// directory is chowned to u. createAutorizationFile creates it with mode 0700
// owned by the invoking process's user (typically root here). Left that way,
// the target user could not enter their own ~/.ssh.
//
// NOTE(review): a missing home directory created along the way is not
// chowned.
func ensureUserAuthorizedKeyFile(u *user.User, authorizedKeyFile string) error {
	_, err := os.Stat(authorizedKeyFile)
	switch {
	case err == nil:
		return nil
	case !errors.Is(err, os.ErrNotExist):
		return err
	}
	sshDir := filepath.Dir(authorizedKeyFile)
	_, dirErr := os.Stat(sshDir)
	dirCreated := errors.Is(dirErr, os.ErrNotExist)
	if err := createAutorizationFile(authorizedKeyFile); err != nil {
		return err
	}
	if !dirCreated {
		return nil
	}
	uid, err := strconv.Atoi(u.Uid)
	if err != nil {
		return err
	}
	gid, err := strconv.Atoi(u.Gid)
	if err != nil {
		return err
	}
	return os.Chown(sshDir, uid, gid)
}
// writeSSHKeysFile replaces the contents of authorizedKeyFile with sshKeys,
// one key per line, and chowns the file to u's numeric uid/gid.
//
// The parent directory and file are created first via createAutorizationFile
// if they are missing.
//
// u's uid/gid are parsed before the file is touched. A non-numeric id
// therefore fails without leaving a rewritten but wrongly-owned file behind.
//
// The Close error is checked because close(2) can report deferred write
// failures (e.g. on NFS). Ignoring it could leave a silently truncated
// authorized_keys.
func writeSSHKeysFile(u *user.User, authorizedKeyFile string, sshKeys []*sshKey) error {
	uid, err := strconv.Atoi(u.Uid)
	if err != nil {
		return err
	}
	gid, err := strconv.Atoi(u.Gid)
	if err != nil {
		return err
	}
	if err := createAutorizationFile(authorizedKeyFile); err != nil {
		return err
	}
	f, err := os.Create(authorizedKeyFile)
	if err != nil {
		return err
	}
	if err := writeSSHKeys(f, sshKeys); err != nil {
		f.Close() // the write error is the one worth reporting
		return err
	}
	if err := f.Close(); err != nil {
		return err
	}
	return os.Chown(authorizedKeyFile, uid, gid)
}
// writeSSHKeys writes each key's String() form to w, one per line, in slice
// order. It stops at and returns the first write error. Ignoring such errors
// would let a caller believe a truncated authorized_keys was written
// successfully.
func writeSSHKeys(w io.Writer, sshKeys []*sshKey) error {
	for _, k := range sshKeys {
		if _, err := fmt.Fprintln(w, k.String()); err != nil {
			return err
		}
	}
	return nil
}
// sshKey is one entry of an authorized_keys-style file.
type sshKey struct {
	algorithm string // first field of the line (normally the key type)
	pubKey    string // second field: the encoded public key
	alias     string // optional trailing text; empty when absent
}

// Compare reports whether s and k denote the same key, i.e. whether both
// algorithm and pubKey are equal. The alias is not considered.
func (s *sshKey) Compare(k *sshKey) bool {
	return s.algorithm == k.algorithm && s.pubKey == k.pubKey
}

// SetAlias sets the alias only if s has none yet, so an existing alias is
// never overwritten.
func (s *sshKey) SetAlias(alias string) {
	if s.alias == "" {
		s.alias = alias
	}
}

// String renders the key as an authorized_keys line: "algorithm pubKey",
// followed by " alias" when an alias is set.
func (s *sshKey) String() string {
	line := s.algorithm + " " + s.pubKey
	if s.alias != "" {
		line += " " + s.alias
	}
	return line
}

// Validate returns an error naming the first empty required attribute, or
// nil. Attributes are checked in a fixed order (algorithm, then pubKey). The
// previous map iteration made the reported attribute random when both were
// empty.
func (s *sshKey) Validate() error {
	required := []struct{ name, value string }{
		{"algorithm", s.algorithm},
		{"pubKey", s.pubKey},
	}
	for _, attr := range required {
		if attr.value == "" {
			return fmt.Errorf("Missing attribute %v", attr.name)
		}
	}
	return nil
}

// newSSHKey returns an sshKey with the given algorithm and public key and no
// alias. It performs no validation (see Validate); the error result is
// always nil.
func newSSHKey(algorithm string, pubKey string) (*sshKey, error) {
	return &sshKey{
		algorithm: algorithm,
		pubKey:    pubKey,
	}, nil
}