fix: add test functions

changes:
- fix: test add, remove, read and write functions
- fix: parsing ssh key aliases
- fix: skip invalid ssh keys
This commit is contained in:
Markus Pesch 2020-09-03 10:25:25 +02:00
parent a50da9bbed
commit 2ef390bd8f
Signed by: volker.raschek
GPG Key ID: 852BCC170D81A982
2 changed files with 182 additions and 18 deletions

45
main.go
View File

@ -4,24 +4,21 @@ import (
"bufio" "bufio"
"errors" "errors"
"fmt" "fmt"
"io"
"log" "log"
"os" "os"
"os/user" "os/user"
"path/filepath" "path/filepath"
"strings" "strings"
"git.cryptic.systems/volker.raschek/go-logger"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var ( var (
flogger logger.Logger
version string version string
) )
func main() { func main() {
flogger = logger.NewLogger(logger.LogLevelDebug)
rootCmd := cobra.Command{ rootCmd := cobra.Command{
Use: "set-sshkeys", Use: "set-sshkeys",
RunE: rootCmd, RunE: rootCmd,
@ -37,8 +34,13 @@ func main() {
func addSSHKeys(sshKeys []*sshKey, newSSHKeys []*sshKey) []*sshKey { func addSSHKeys(sshKeys []*sshKey, newSSHKeys []*sshKey) []*sshKey {
Label: Label:
for i := range newSSHKeys { for i := range newSSHKeys {
if err := newSSHKeys[i].Validate(); err != nil {
continue Label
}
for j := range sshKeys { for j := range sshKeys {
if sshKeys[j].Compare(newSSHKeys[i]) { if sshKeys[j].Compare(newSSHKeys[i]) {
sshKeys[j].SetAlias(newSSHKeys[i].alias)
continue Label continue Label
} }
} }
@ -62,16 +64,21 @@ func createAutorizationFile(authorizedKeyFile string) error {
return f.Close() return f.Close()
} }
func readSSHKeys(authorizedKeyFile string) ([]*sshKey, error) { func readSSHKeysFile(authorizedKeyFile string) ([]*sshKey, error) {
f, err := os.Open(authorizedKeyFile) f, err := os.Open(authorizedKeyFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer f.Close()
return readSSHKeys(f)
}
sshKeys := make([]*sshKey, 0) func readSSHKeys(r io.Reader) ([]*sshKey, error) {
s := bufio.NewScanner(f) var (
sshKeys = make([]*sshKey, 0)
s = bufio.NewScanner(r)
)
Loop: Loop:
for s.Scan() { for s.Scan() {
@ -107,11 +114,11 @@ Loop:
if err != nil { if err != nil {
return nil, err return nil, err
} }
sshKey.SetAlias(alias) sshKey.SetAlias(alias)
for i := range sshKeys { for i := range sshKeys {
if sshKeys[i].Compare(sshKey) { if sshKeys[i].Compare(sshKey) {
sshKeys[i].SetAlias(sshKey.alias)
continue Loop continue Loop
} }
} }
@ -119,7 +126,7 @@ Loop:
sshKeys = append(sshKeys, sshKey) sshKeys = append(sshKeys, sshKey)
default: default:
log.Printf("Require two and optional three parts for each line. Get %v parts", len(parts)) log.Printf("WARN: Require two and optional three parts for each line. Get %v parts - SKIP entry", len(parts))
} }
} }
@ -177,7 +184,6 @@ func rootCmd(cmd *cobra.Command, args []string) error {
case err == nil: case err == nil:
break break
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
flogger.Debug("Create authorization file %v", userAuthorizedKeyFile)
if err := createAutorizationFile(userAuthorizedKeyFile); err != nil { if err := createAutorizationFile(userAuthorizedKeyFile); err != nil {
return err return err
} }
@ -185,12 +191,12 @@ func rootCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
etcAuthorizedKeys, err := readSSHKeys(etcAuthorizedKeyFile) etcAuthorizedKeys, err := readSSHKeysFile(etcAuthorizedKeyFile)
if err != nil { if err != nil {
return err return err
} }
userAuthorizedKeys, err := readSSHKeys(userAuthorizedKeyFile) userAuthorizedKeys, err := readSSHKeysFile(userAuthorizedKeyFile)
if err != nil { if err != nil {
return err return err
} }
@ -201,11 +207,10 @@ func rootCmd(cmd *cobra.Command, args []string) error {
userAuthorizedKeys = addSSHKeys(userAuthorizedKeys, etcAuthorizedKeys) userAuthorizedKeys = addSSHKeys(userAuthorizedKeys, etcAuthorizedKeys)
} }
return writeSSHKeys(userAuthorizedKeyFile, userAuthorizedKeys) return writeSSHKeysFile(userAuthorizedKeyFile, userAuthorizedKeys)
} }
func writeSSHKeys(authorizedKeyFile string, sshKeys []*sshKey) error { func writeSSHKeysFile(authorizedKeyFile string, sshKeys []*sshKey) error {
if err := createAutorizationFile(authorizedKeyFile); err != nil { if err := createAutorizationFile(authorizedKeyFile); err != nil {
return err return err
} }
@ -214,11 +219,15 @@ func writeSSHKeys(authorizedKeyFile string, sshKeys []*sshKey) error {
if err != nil { if err != nil {
return err return err
} }
defer f.Close()
return writeSSHKeys(f, sshKeys)
}
func writeSSHKeys(w io.Writer, sshKeys []*sshKey) error {
for i := range sshKeys { for i := range sshKeys {
fmt.Fprintln(f, sshKeys[i].String()) fmt.Fprintln(w, sshKeys[i].String())
} }
return nil return nil
} }

View File

@ -1,11 +1,91 @@
package main package main
import ( import (
"bytes"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestAdd(t *testing.T) {
require := require.New(t)
sshKeys := addSSHKeys([]*sshKey{}, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "abcdefg"}})
require.Len(sshKeys, 1)
sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "abcdefg"}})
require.Len(sshKeys, 1)
sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "buxdehude"}})
require.Len(sshKeys, 1)
require.Equal("buxdehude", sshKeys[0].alias)
sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "asdsadasd", alias: "hello@world"}})
require.Len(sshKeys, 2)
sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "asdsadasd", alias: "world@hello"}})
require.Len(sshKeys, 2)
require.Equal("hello@world", sshKeys[1].alias)
sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "ssh-ed25519", pubKey: "", alias: "world@hello"}})
require.Len(sshKeys, 2)
sshKeys = addSSHKeys(sshKeys, []*sshKey{{algorithm: "", pubKey: "asdsadasd", alias: "world@hello"}})
require.Len(sshKeys, 2)
}
func TestRead(t *testing.T) {
require := require.New(t)
testCases := []struct {
expectedSSHKeys []*sshKey
b []byte
}{
{
expectedSSHKeys: []*sshKey{
{algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"},
},
b: []byte(`ssh-ed25519 abcdefg hello@world`),
},
{
expectedSSHKeys: []*sshKey{
{algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"},
},
b: []byte(`ssh-ed25519 abcdefg hello@world
ssh-ed25519 abcdefg`),
},
{
expectedSSHKeys: []*sshKey{
{algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"},
},
b: []byte(`ssh-ed25519 abcdefg
ssh-ed25519 abcdefg hello@world`),
},
{
expectedSSHKeys: []*sshKey{
{algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"},
},
b: []byte(`ssh-ed25519 abcdefg hello@world
ssh-ed25519 abcdefg world@hello`),
},
{
expectedSSHKeys: []*sshKey{
{algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"},
},
b: []byte(`ssh-ed25519 abcdefg hello@world
ssh-ed25519
abcdefg test`),
},
}
for i := range testCases {
sshKeys, err := readSSHKeys(bytes.NewReader(testCases[i].b))
require.NoError(err)
require.ElementsMatch(testCases[i].expectedSSHKeys, sshKeys)
}
}
func TestRemove(t *testing.T) { func TestRemove(t *testing.T) {
require := require.New(t) require := require.New(t)
@ -37,6 +117,27 @@ func TestRemove(t *testing.T) {
require.Len(result, 1) require.Len(result, 1)
} }
func TestSSHKeyCompare(t *testing.T) {
require := require.New(t)
testCases := []struct {
s *sshKey
r *sshKey
expectedValue bool
}{
{&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, &sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, true},
{&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, &sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh"}, true},
{&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh"}, &sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, true},
{&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, &sshKey{algorithm: "ssh-ed25519", pubKey: "sdfsdf", alias: "hello@world"}, false},
{&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, &sshKey{algorithm: "ssh-rsa", pubKey: "sdfsdf", alias: "hello@world"}, false},
{&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, &sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "hello@world"}, true},
}
for i := range testCases {
require.Equal(testCases[i].expectedValue, testCases[i].s.Compare(testCases[i].r), "TestCase %v does not match the expected value", i)
}
}
func TestSSHKeyString(t *testing.T) { func TestSSHKeyString(t *testing.T) {
require := require.New(t) require := require.New(t)
@ -56,3 +157,57 @@ func TestSSHKeyString(t *testing.T) {
require.Equal(b[i], s[i].String()) require.Equal(b[i], s[i].String())
} }
} }
func TestSSHKeyValidate(t *testing.T) {
require := require.New(t)
testCases := []struct {
s *sshKey
expectedValue error
}{
{&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh", alias: "world@hello"}, nil},
{&sshKey{algorithm: "ssh-rsa", pubKey: "abcdefgh"}, nil},
{&sshKey{algorithm: "ssh-rsa", pubKey: ""}, fmt.Errorf("Missing attribute pubKey")},
{&sshKey{algorithm: "ssh-rsa"}, fmt.Errorf("Missing attribute pubKey")},
{&sshKey{algorithm: "", pubKey: "abcdefgh"}, fmt.Errorf("Missing attribute algorithm")},
{&sshKey{pubKey: "abcdefgh"}, fmt.Errorf("Missing attribute algorithm")},
{&sshKey{algorithm: "ssh-rsa", alias: "sdfsdf"}, fmt.Errorf("Missing attribute pubKey")},
{&sshKey{pubKey: "abcdefgh", alias: "sdfsdf"}, fmt.Errorf("Missing attribute algorithm")},
}
for i := range testCases {
if testCases[i].expectedValue == nil {
require.NoError(testCases[i].s.Validate())
continue
}
require.EqualError(testCases[i].expectedValue, testCases[i].s.Validate().Error())
}
}
func TestWrite(t *testing.T) {
require := require.New(t)
testCases := []struct {
expectedSSHKeys []*sshKey
b []byte
}{
{
expectedSSHKeys: []*sshKey{
{algorithm: "ssh-ed25519", pubKey: "abcdefg", alias: "hello@world"},
{algorithm: "ssh-rsa", pubKey: "gfedcba"},
{algorithm: "ssh-rsa", pubKey: "1234567", alias: "world@hello"},
},
b: []byte(`ssh-ed25519 abcdefg hello@world
ssh-rsa gfedcba
ssh-rsa 1234567 world@hello
`),
},
}
for i := range testCases {
b := bytes.NewBuffer([]byte{})
err := writeSSHKeys(b, testCases[i].expectedSSHKeys)
require.NoError(err)
require.Equal(testCases[i].b, b.Bytes())
}
}