2020-09-02 09:07:10 +00:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bufio"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
2020-09-03 08:25:25 +00:00
|
|
|
"io"
|
2020-09-02 09:07:10 +00:00
|
|
|
"log"
|
|
|
|
"os"
|
|
|
|
"os/user"
|
|
|
|
"path/filepath"
|
2021-04-11 11:37:32 +00:00
|
|
|
"strconv"
|
2020-09-02 09:07:10 +00:00
|
|
|
"strings"
|
|
|
|
|
|
|
|
"github.com/spf13/cobra"
|
|
|
|
)
|
|
|
|
|
|
|
|
// version is passed to the root command as its Version, so cobra reports it
// for --version. It is empty unless set at build time — presumably via
// -ldflags "-X main.version=..." (TODO confirm against the build scripts).
var version string
|
|
|
|
|
|
|
|
// main builds the set-sshkeys root command, registers its flags and runs it.
//
// The process exits with status 1 when the command fails, so scripts and
// configuration-management tools invoking set-sshkeys can detect the failure.
func main() {
	// Named cmd (not rootCmd) so the package-level rootCmd handler is not shadowed.
	cmd := &cobra.Command{
		Use:     "set-sshkeys",
		RunE:    rootCmd,
		Version: version,
	}

	cmd.Flags().String("authorized_keys", "/etc/set-sshkeys/authorized_keys", "Public ssh keys which should be merged with users existing ssh keys")
	cmd.Flags().Bool("remove", false, "Remove public ssh keys")
	cmd.Flags().String("user", "root", "For which user the public SSH keys should be defined")

	// cobra prints the error itself by default (SilenceErrors is not set);
	// what was missing is a non-zero exit status.
	if err := cmd.Execute(); err != nil {
		os.Exit(1)
	}
}
|
|
|
|
|
|
|
|
func addSSHKeys(sshKeys []*sshKey, newSSHKeys []*sshKey) []*sshKey {
|
|
|
|
Label:
|
|
|
|
for i := range newSSHKeys {
|
2020-09-03 08:25:25 +00:00
|
|
|
if err := newSSHKeys[i].Validate(); err != nil {
|
|
|
|
continue Label
|
|
|
|
}
|
|
|
|
|
2020-09-02 09:07:10 +00:00
|
|
|
for j := range sshKeys {
|
|
|
|
if sshKeys[j].Compare(newSSHKeys[i]) {
|
2020-09-03 08:25:25 +00:00
|
|
|
sshKeys[j].SetAlias(newSSHKeys[i].alias)
|
2020-09-02 09:07:10 +00:00
|
|
|
continue Label
|
|
|
|
}
|
|
|
|
}
|
|
|
|
sshKeys = append(sshKeys, newSSHKeys[i])
|
|
|
|
}
|
|
|
|
return sshKeys
|
|
|
|
}
|
|
|
|
|
|
|
|
func createAutorizationFile(authorizedKeyFile string) error {
|
|
|
|
|
2021-04-11 11:37:32 +00:00
|
|
|
err := os.MkdirAll(filepath.Dir(authorizedKeyFile), 0700)
|
2020-09-02 09:07:10 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
f, err := os.Create(authorizedKeyFile)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return f.Close()
|
|
|
|
}
|
|
|
|
|
2020-09-03 08:25:25 +00:00
|
|
|
func readSSHKeysFile(authorizedKeyFile string) ([]*sshKey, error) {
|
2020-09-02 09:07:10 +00:00
|
|
|
f, err := os.Open(authorizedKeyFile)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2020-09-03 08:25:25 +00:00
|
|
|
defer f.Close()
|
|
|
|
return readSSHKeys(f)
|
|
|
|
}
|
2020-09-02 09:07:10 +00:00
|
|
|
|
2020-09-03 08:25:25 +00:00
|
|
|
func readSSHKeys(r io.Reader) ([]*sshKey, error) {
|
2020-09-02 09:07:10 +00:00
|
|
|
|
2020-09-03 08:25:25 +00:00
|
|
|
var (
|
|
|
|
sshKeys = make([]*sshKey, 0)
|
|
|
|
s = bufio.NewScanner(r)
|
|
|
|
)
|
2020-09-02 09:07:10 +00:00
|
|
|
|
|
|
|
Loop:
|
|
|
|
for s.Scan() {
|
|
|
|
line := s.Text()
|
|
|
|
parts := strings.Split(line, " ")
|
|
|
|
|
|
|
|
switch len(parts) {
|
|
|
|
case 2:
|
|
|
|
|
|
|
|
algorithm := parts[0]
|
|
|
|
pubKey := parts[1]
|
|
|
|
|
|
|
|
sshKey, err := newSSHKey(algorithm, pubKey)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
for i := range sshKeys {
|
|
|
|
if sshKeys[i].Compare(sshKey) {
|
|
|
|
continue Loop
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
sshKeys = append(sshKeys, sshKey)
|
|
|
|
|
|
|
|
case 3:
|
|
|
|
|
|
|
|
algorithm := parts[0]
|
|
|
|
pubKey := parts[1]
|
|
|
|
alias := parts[2]
|
|
|
|
|
|
|
|
sshKey, err := newSSHKey(algorithm, pubKey)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
sshKey.SetAlias(alias)
|
|
|
|
|
|
|
|
for i := range sshKeys {
|
|
|
|
if sshKeys[i].Compare(sshKey) {
|
2020-09-03 08:25:25 +00:00
|
|
|
sshKeys[i].SetAlias(sshKey.alias)
|
2020-09-02 09:07:10 +00:00
|
|
|
continue Loop
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
sshKeys = append(sshKeys, sshKey)
|
|
|
|
|
|
|
|
default:
|
2020-09-03 08:25:25 +00:00
|
|
|
log.Printf("WARN: Require two and optional three parts for each line. Get %v parts - SKIP entry", len(parts))
|
2020-09-02 09:07:10 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return sshKeys, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func removeSSHKeys(sshKeys []*sshKey, removeSSHKeys []*sshKey) []*sshKey {
|
2020-09-02 11:28:21 +00:00
|
|
|
|
|
|
|
s := make([]*sshKey, 0)
|
|
|
|
|
|
|
|
Loop:
|
|
|
|
for i := range sshKeys {
|
|
|
|
for j := range removeSSHKeys {
|
|
|
|
if sshKeys[i].Compare(removeSSHKeys[j]) {
|
|
|
|
continue Loop
|
2020-09-02 09:07:10 +00:00
|
|
|
}
|
|
|
|
}
|
2020-09-02 11:28:21 +00:00
|
|
|
s = append(s, sshKeys[i])
|
2020-09-02 09:07:10 +00:00
|
|
|
}
|
2020-09-02 11:28:21 +00:00
|
|
|
|
|
|
|
return s
|
2020-09-02 09:07:10 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func rootCmd(cmd *cobra.Command, args []string) error {
|
|
|
|
|
|
|
|
remove, err := cmd.Flags().GetBool("remove")
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
etcAuthorizedKeyFile, err := cmd.Flags().GetString("authorized_keys")
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err = os.Stat(etcAuthorizedKeyFile)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
username, err := cmd.Flags().GetString("user")
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
user, err := user.Lookup(username)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
userAuthorizedKeyFile := filepath.Join(user.HomeDir, ".ssh", "authorized_keys")
|
|
|
|
|
|
|
|
_, err = os.Stat(userAuthorizedKeyFile)
|
|
|
|
switch {
|
|
|
|
case err == nil:
|
|
|
|
break
|
|
|
|
case errors.Is(err, os.ErrNotExist):
|
|
|
|
if err := createAutorizationFile(userAuthorizedKeyFile); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
default:
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2020-09-03 08:25:25 +00:00
|
|
|
etcAuthorizedKeys, err := readSSHKeysFile(etcAuthorizedKeyFile)
|
2020-09-02 09:07:10 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2020-09-03 08:25:25 +00:00
|
|
|
userAuthorizedKeys, err := readSSHKeysFile(userAuthorizedKeyFile)
|
2020-09-02 09:07:10 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
if remove {
|
|
|
|
userAuthorizedKeys = removeSSHKeys(userAuthorizedKeys, etcAuthorizedKeys)
|
|
|
|
} else {
|
|
|
|
userAuthorizedKeys = addSSHKeys(userAuthorizedKeys, etcAuthorizedKeys)
|
|
|
|
}
|
|
|
|
|
2021-04-11 11:37:32 +00:00
|
|
|
return writeSSHKeysFile(user, userAuthorizedKeyFile, userAuthorizedKeys)
|
2020-09-02 09:07:10 +00:00
|
|
|
}
|
|
|
|
|
2021-04-11 11:37:32 +00:00
|
|
|
func writeSSHKeysFile(u *user.User, authorizedKeyFile string, sshKeys []*sshKey) error {
|
2020-09-02 09:07:10 +00:00
|
|
|
if err := createAutorizationFile(authorizedKeyFile); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
f, err := os.Create(authorizedKeyFile)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2020-09-03 08:25:25 +00:00
|
|
|
defer f.Close()
|
2020-09-02 09:07:10 +00:00
|
|
|
|
2021-04-11 11:37:32 +00:00
|
|
|
err = writeSSHKeys(f, sshKeys)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
uid, err := strconv.Atoi(u.Uid)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
gid, err := strconv.Atoi(u.Gid)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return os.Chown(authorizedKeyFile, uid, gid)
|
2020-09-03 08:25:25 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func writeSSHKeys(w io.Writer, sshKeys []*sshKey) error {
|
2020-09-02 09:07:10 +00:00
|
|
|
for i := range sshKeys {
|
2020-09-03 08:25:25 +00:00
|
|
|
fmt.Fprintln(w, sshKeys[i].String())
|
2020-09-02 09:07:10 +00:00
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
type sshKey struct {
|
|
|
|
algorithm string
|
|
|
|
pubKey string
|
|
|
|
alias string
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *sshKey) Compare(k *sshKey) bool {
|
|
|
|
if s.algorithm == k.algorithm &&
|
|
|
|
s.pubKey == k.pubKey {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *sshKey) SetAlias(alias string) {
|
|
|
|
if len(s.alias) <= 0 {
|
|
|
|
s.alias = alias
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-09-02 11:28:21 +00:00
|
|
|
func (s *sshKey) String() string {
|
|
|
|
l := fmt.Sprintf("%v %v", s.algorithm, s.pubKey)
|
|
|
|
if len(s.alias) > 0 {
|
|
|
|
l = fmt.Sprintf("%v %v", l, s.alias)
|
|
|
|
}
|
|
|
|
return l
|
|
|
|
}
|
|
|
|
|
2020-09-02 09:07:10 +00:00
|
|
|
func (s *sshKey) Validate() error {
|
|
|
|
entries := map[string]string{
|
|
|
|
"algorithm": s.algorithm,
|
|
|
|
"pubKey": s.pubKey,
|
|
|
|
}
|
|
|
|
|
|
|
|
for key, value := range entries {
|
|
|
|
if len(value) <= 0 {
|
|
|
|
return fmt.Errorf("Missing attribute %v", key)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func newSSHKey(algorithm string, pubKey string) (*sshKey, error) {
|
|
|
|
return &sshKey{
|
|
|
|
algorithm: algorithm,
|
|
|
|
pubKey: pubKey,
|
|
|
|
}, nil
|
|
|
|
}
|