From 2ef390bd8f07923034e70e72685861061bdb5c4a Mon Sep 17 00:00:00 2001 From: Markus Pesch Date: Thu, 3 Sep 2020 10:25:25 +0200 Subject: [PATCH] fix: add test functions changes: - fix: test add, remove, read and write functions - fix: parsing ssh key aliases - fix: skip invalid ssh keys --- main.go | 45 +++++++++------ main_test.go | 155 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+), 18 deletions(-) diff --git a/main.go b/main.go index 6a81205..4d44f09 100644 --- a/main.go +++ b/main.go @@ -4,24 +4,21 @@ import ( "bufio" "errors" "fmt" + "io" "log" "os" "os/user" "path/filepath" "strings" - "git.cryptic.systems/volker.raschek/go-logger" "github.com/spf13/cobra" ) var ( - flogger logger.Logger version string ) func main() { - flogger = logger.NewLogger(logger.LogLevelDebug) - rootCmd := cobra.Command{ Use: "set-sshkeys", RunE: rootCmd, @@ -37,8 +34,13 @@ func main() { func addSSHKeys(sshKeys []*sshKey, newSSHKeys []*sshKey) []*sshKey { Label: for i := range newSSHKeys { + if err := newSSHKeys[i].Validate(); err != nil { + continue Label + } + for j := range sshKeys { if sshKeys[j].Compare(newSSHKeys[i]) { + sshKeys[j].SetAlias(newSSHKeys[i].alias) continue Label } } @@ -62,16 +64,21 @@ func createAutorizationFile(authorizedKeyFile string) error { return f.Close() } -func readSSHKeys(authorizedKeyFile string) ([]*sshKey, error) { - +func readSSHKeysFile(authorizedKeyFile string) ([]*sshKey, error) { f, err := os.Open(authorizedKeyFile) if err != nil { return nil, err } + defer f.Close() + return readSSHKeys(f) +} - sshKeys := make([]*sshKey, 0) +func readSSHKeys(r io.Reader) ([]*sshKey, error) { - s := bufio.NewScanner(f) + var ( + sshKeys = make([]*sshKey, 0) + s = bufio.NewScanner(r) + ) Loop: for s.Scan() { @@ -107,11 +114,11 @@ Loop: if err != nil { return nil, err } - sshKey.SetAlias(alias) for i := range sshKeys { if sshKeys[i].Compare(sshKey) { + sshKeys[i].SetAlias(sshKey.alias) continue Loop } } @@ -119,7 +126,7 @@ Loop: sshKeys = append(sshKeys, sshKey) default: - log.Printf("Require two and optional three parts for each line. Get %v parts", len(parts)) + log.Printf("WARN: Require two and optional three parts for each line. Get %v parts - SKIP entry", len(parts)) } } @@ -177,7 +184,6 @@ func rootCmd(cmd *cobra.Command, args []string) error { case err == nil: break case errors.Is(err, os.ErrNotExist): - flogger.Debug("Create authorization file %v", userAuthorizedKeyFile) if err := createAutorizationFile(userAuthorizedKeyFile); err != nil { return err } @@ -185,12 +191,12 @@ func rootCmd(cmd *cobra.Command, args []string) error { return err } - etcAuthorizedKeys, err := readSSHKeys(etcAuthorizedKeyFile) + etcAuthorizedKeys, err := readSSHKeysFile(etcAuthorizedKeyFile) if err != nil { return err } - userAuthorizedKeys, err := readSSHKeys(userAuthorizedKeyFile) + userAuthorizedKeys, err := readSSHKeysFile(userAuthorizedKeyFile) if err != nil { return err } @@ -201,11 +207,10 @@ func rootCmd(cmd *cobra.Command, args []string) error { userAuthorizedKeys = addSSHKeys(userAuthorizedKeys, etcAuthorizedKeys) } - return writeSSHKeys(userAuthorizedKeyFile, userAuthorizedKeys) + return writeSSHKeysFile(userAuthorizedKeyFile, userAuthorizedKeys) } -func writeSSHKeys(authorizedKeyFile string, sshKeys []*sshKey) error { - +func writeSSHKeysFile(authorizedKeyFile string, sshKeys []*sshKey) error { if err := createAutorizationFile(authorizedKeyFile); err != nil { return err } @@ -214,11 +219,15 @@ func writeSSHKeys(authorizedKeyFile string, sshKeys []*sshKey) error { if err != nil { return err } + defer f.Close() + return writeSSHKeys(f, sshKeys) +} + +func writeSSHKeys(w io.Writer, sshKeys []*sshKey) error { for i := range sshKeys { - fmt.Fprintln(f, sshKeys[i].String()) + fmt.Fprintln(w, sshKeys[i].String()) } - return nil } diff --git a/main_test.go b/main_test.go index 92737e7..cfd6326 100644 --- a/main_test.go +++ b/main_test.go @@ -1,11 +1,91 @@ package main import ( + "bytes" + "fmt" "testing" "github.com/stretchr/testify/require" ) +func TestAdd(t *testing.T) { + require := require.New(t) + + sshKeys := addSSHKeys([]*sshKey{}, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "abcdefg"}}) + require.Len(sshKeys, 1) + + sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "abcdefg"}}) + require.Len(sshKeys, 1) + + sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "buxdehude"}}) + require.Len(sshKeys, 1) + require.Equal("buxdehude", sshKeys[0].alias) + + sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "asdsadasd", alias: "hello@world"}}) + require.Len(sshKeys, 2) + + sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "asdsadasd", alias: "world@hello"}}) + require.Len(sshKeys, 2) + require.Equal("hello@world", sshKeys[1].alias) + + sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "", alias: "world@hello"}}) + require.Len(sshKeys, 2) + + sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "", pubKey: "asdsadasd", alias: "world@hello"}}) + require.Len(sshKeys, 2) +} + +func TestRead(t *testing.T) { + require := require.New(t) + + testCases := []struct { + expectedSSHKeys []*sshKey + b []byte + }{ + { + expectedSSHKeys: []*sshKey{ + {algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"}, + }, + b: []byte(`ssh-ed25519 abcdefg hello@world`), + }, + { + expectedSSHKeys: []*sshKey{ + {algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"}, + }, + b: []byte(`ssh-ed25519 abcdefg hello@world +ssh-ed25519 abcdefg`), + }, + { + expectedSSHKeys: []*sshKey{ + {algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"}, + }, + b: []byte(`ssh-ed25519 abcdefg +ssh-ed25519 abcdefg hello@world`), + }, + { + expectedSSHKeys: []*sshKey{ + {algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"}, + }, + b: []byte(`ssh-ed25519 abcdefg hello@world +ssh-ed25519 abcdefg world@hello`), + }, + { + expectedSSHKeys: []*sshKey{ + {algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"}, + }, + b: []byte(`ssh-ed25519 abcdefg hello@world +ssh-ed25519 + abcdefg test`), + }, + } + + for i := range testCases { + sshKeys, err := readSSHKeys(bytes.NewReader(testCases[i].b)) + require.NoError(err) + require.ElementsMatch(testCases[i].expectedSSHKeys, sshKeys) + } +} + func TestRemove(t *testing.T) { require := require.New(t) @@ -37,6 +117,27 @@ func TestRemove(t *testing.T) { require.Len(result, 1) } +func TestSSHKeyCompare(t *testing.T) { + require := require.New(t) + + testCases := []struct { + s *sshKey + r *sshKey + expectedValue bool + }{ + {&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, &sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, true}, + {&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, &sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh"}, true}, + {&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh"}, &sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, true}, + {&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, &sshKey{algorithm: "ssh-ed25519", pubKey: "sdfsdf", alias: "hello@world"}, false}, + {&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, &sshKey{algorithm: "ssh-rsa", pubKey: "sdfsdf", alias: "hello@world"}, false}, + {&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, &sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "hello@world"}, true}, + } + + for i := range testCases { + require.Equal(testCases[i].expectedValue, testCases[i].s.Compare(testCases[i].r), "TestCase %v does not match the expected value", i) + } +} + func TestSSHKeyString(t *testing.T) { require := require.New(t) @@ -56,3 +157,57 @@ func TestSSHKeyString(t *testing.T) { require.Equal(b[i], s[i].String()) } } + +func TestSSHKeyValidate(t *testing.T) { + require := require.New(t) + + testCases := []struct { + s *sshKey + expectedValue error + }{ + {&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, nil}, + {&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh"}, nil}, + {&sshKey{algorithm: "ssh-rsa", pubKey: ""}, fmt.Errorf("Missing attribute pubKey")}, + {&sshKey{algorithm: "ssh-rsa"}, fmt.Errorf("Missing attribute pubKey")}, + {&sshKey{algorithm: "", pubKey: "abcdefgh"}, fmt.Errorf("Missing attribute algorithm")}, + {&sshKey{pubKey: "abcdefgh"}, fmt.Errorf("Missing attribute algorithm")}, + {&sshKey{algorithm: "ssh-rsa", alias: "sdfsdf"}, fmt.Errorf("Missing attribute pubKey")}, + {&sshKey{pubKey: "abcdefgh", alias: "sdfsdf"}, fmt.Errorf("Missing attribute algorithm")}, + } + + for i := range testCases { + if testCases[i].expectedValue == nil { + require.NoError(testCases[i].s.Validate()) + continue + } + require.EqualError(testCases[i].expectedValue, testCases[i].s.Validate().Error()) + } +} + +func TestWrite(t *testing.T) { + require := require.New(t) + + testCases := []struct { + expectedSSHKeys []*sshKey + b []byte + }{ + { + expectedSSHKeys: []*sshKey{ + {algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"}, + {algorithm: "ssh-rsa", pubKey: "gfedcba"}, + {algorithm: "ssh-rsa", pubKey: "1234567", alias: "world@hello"}, + }, + b: []byte(`ssh-ed25519 abcdefg hello@world +ssh-rsa gfedcba +ssh-rsa 1234567 world@hello +`), + }, + } + + for i := range testCases { + b := bytes.NewBuffer([]byte{}) + err := writeSSHKeys(b, testCases[i].expectedSSHKeys) + require.NoError(err) + require.Equal(testCases[i].b, b.Bytes()) + } +}