diff --git a/.travis.yml b/.travis.yml index f91065a07eda0a515a9cf23aadc2c12edb902b0e..cae734902464590869ba15d9f2ea33e8ae344bb6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ env: go: - 1.7.4 - - 1.8rc1 + - 1.8rc2 - tip matrix: diff --git a/README.md b/README.md index 60a196130f1dceb5058117ca247ed830be1ac772..5d294c32030bc48413c5662128e17309c777edd6 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,8 @@ For any option that takes a file path as a parameter (e.g. SSH signing key, TLS - A Google GCS bucket + object path starting with `/gcs/` e.g. `/gcs/my-bucket/ssh_signing_key`. - A [Vault](https://www.vaultproject.io) path + key starting with `/vault/` e.g. `/vault/secret/cashier/ssh_signing_key`. You should add a [vault](#vault) config as needed. +Exception to this: the `http_logfile` option **ONLY** writes to local files. + ## server - `use_tls` : boolean. If this is set then either `tls_key` and `tls_cert` are required, or `letsencrypt_servername` is required. - `tls_key` : string. Path to the TLS key. See the [note](#a-note-on-files) on files above. @@ -110,7 +112,7 @@ For any option that takes a file path as a parameter (e.g. SSH signing key, TLS - `user` : string. User to which the server drops privileges to. - `cookie_secret`: string. Authentication key for the session cookie. This can be a secret stored in a [vault](https://www.vaultproject.io/) using the form `/vault/path/key` e.g. `/vault/secret/cashier/cookie_secret`. - `csrf_secret`: string. Authentication key for CSRF protection. This can be a secret stored in a [vault](https://www.vaultproject.io/) using the form `/vault/path/key` e.g. `/vault/secret/cashier/csrf_secret`. -- `http_logfile`: string. Path to the HTTP request log. Logs are written in the [Common Log Format](https://en.wikipedia.org/wiki/Common_Log_Format). If not set logs are written to stderr. +- `http_logfile`: string. Path to the HTTP request log. Logs are written in the [Common Log Format](https://en.wikipedia.org/wiki/Common_Log_Format). The only valid destination for logs is a local file path. - `datastore`: string. Datastore connection string. See [Datastore](#datastore). ### database diff --git a/client/client.go b/client/client.go index b13c4cbdd9c5ad88246f741077df7d704b298b27..382c53dcb4a6f86b4158becca4756a96f342ed27 100644 --- a/client/client.go +++ b/client/client.go @@ -11,6 +11,7 @@ import ( "time" "github.com/nsheridan/cashier/lib" + "github.com/pkg/errors" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) @@ -27,7 +28,7 @@ func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error { LifetimeSecs: uint32(lifetime), } if err := a.Add(pubcert); err != nil { - return fmt.Errorf("error importing certificate: %s", err) + return errors.Wrap(err, "unable to add cert to ssh agent") } privkey := agent.AddedKey{ PrivateKey: key, @@ -35,7 +36,7 @@ func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error { LifetimeSecs: uint32(lifetime), } if err := a.Add(privkey); err != nil { - return fmt.Errorf("error importing key: %s", err) + return errors.Wrap(err, "unable to add private key to ssh agent") } return nil } @@ -48,7 +49,7 @@ func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignRes client := &http.Client{Transport: transport} u, err := url.Parse(ca) if err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to parse CA url") } u.Path = path.Join(u.Path, "/sign") req, err := http.NewRequest("POST", u.String(), bytes.NewReader(s)) @@ -68,7 +69,7 @@ func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignRes defer resp.Body.Close() c := &lib.SignResponse{} if err := json.NewDecoder(resp.Body).Decode(c); err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to decode server response") } return c, nil } @@ -84,22 +85,22 @@ func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, erro ValidUntil: time.Now().Add(validity), }) if err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to create sign request") } resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate) if err != nil { - return nil, err + return nil, errors.Wrap(err, "error sending request to CA") } if resp.Status != "ok" { - return nil, fmt.Errorf("error: %s", resp.Response) + return nil, fmt.Errorf("bad response from CA: %s", resp.Response) } k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(resp.Response)) if err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to parse response") } cert, ok := k.(*ssh.Certificate) if !ok { - return nil, fmt.Errorf("did not receive a certificate from server") + return nil, fmt.Errorf("did not receive a valid certificate from server") } return cert, nil } diff --git a/client/keys.go b/client/keys.go index 3d2fb312649fea9381e4b81e68f0e0096a834a96..73983a8207af3e0535a70183e775d756a8886ee7 100644 --- a/client/keys.go +++ b/client/keys.go @@ -8,6 +8,8 @@ import ( "crypto/rsa" "fmt" + "github.com/pkg/errors" + "golang.org/x/crypto/ed25519" "golang.org/x/crypto/ssh" ) @@ -68,7 +70,7 @@ func generateECDSAKey(size int) (Key, error) { case 521: curve = elliptic.P521() default: - return nil, fmt.Errorf("Unsupported key size: %d. Valid sizes are '256', '384', '521'", size) + return nil, fmt.Errorf("Unsupported ECDSA key size: %d. Valid sizes are '256', '384', '521'", size) } return ecdsa.GenerateKey(curve, rand.Reader) } @@ -101,8 +103,8 @@ func GenerateKey(options ...func(*options)) (Key, ssh.PublicKey, error) { privkey, err = generateRSAKey(config.size) } if err != nil { - return nil, nil, err + return nil, nil, errors.Wrapf(err, "unable to generate %s key-pair", config.keytype) } pubkey, err = ssh.NewPublicKey(privkey.Public()) - return privkey, pubkey, err + return privkey, pubkey, errors.Wrap(err, "error parsing public key") } diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go index 83627ad6d0566684a75aafa3012aa133e0447764..fc031714a7dbeae318fbe61da585036a36d00530 100644 --- a/cmd/cashierd/main.go +++ b/cmd/cashierd/main.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "encoding/hex" "encoding/json" - "errors" "flag" "fmt" "html/template" @@ -17,6 +16,8 @@ import ( "strconv" "strings" + "github.com/pkg/errors" + "go4.org/wkfs" "golang.org/x/crypto/acme/autocert" "golang.org/x/oauth2" @@ -125,7 +126,7 @@ func (a *appContext) login(w http.ResponseWriter, r *http.Request) (int, error) } // parseKey retrieves and unmarshals the signing request. -func parseKey(r *http.Request) (*lib.SignRequest, error) { +func extractKey(r *http.Request) (*lib.SignRequest, error) { var s lib.SignRequest if err := json.NewDecoder(r.Body).Decode(&s); err != nil { return nil, err @@ -154,23 +155,25 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er } // Sign the pubkey and issue the cert. - req, err := parseKey(r) + req, err := extractKey(r) if err != nil { - return http.StatusInternalServerError, err + return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request") } username := a.authprovider.Username(token) a.authprovider.Revoke(token) // We don't need this anymore. cert, err := a.sshKeySigner.SignUserKey(req, username) if err != nil { - return http.StatusInternalServerError, err + return http.StatusInternalServerError, errors.Wrap(err, "error signing key") } if err := a.certstore.SetCert(cert); err != nil { log.Printf("Error recording cert: %v", err) } - json.NewEncoder(w).Encode(&lib.SignResponse{ + if err := json.NewEncoder(w).Encode(&lib.SignResponse{ Status: "ok", Response: lib.GetPublicKey(cert), - }) + }); err != nil { + return http.StatusInternalServerError, errors.Wrap(err, "error encoding response") + } return http.StatusOK, nil } @@ -219,7 +222,7 @@ func listRevokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Reque } rl, err := a.sshKeySigner.GenerateRevocationList(revoked) if err != nil { - return http.StatusInternalServerError, err + return http.StatusInternalServerError, errors.Wrap(err, "unable to generate KRL") } w.Header().Set("Content-Type", "application/octet-stream") w.Write(rl) @@ -258,7 +261,7 @@ func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (i r.ParseForm() for _, id := range r.Form["cert_id"] { if err := a.certstore.Revoke(id); err != nil { - return http.StatusInternalServerError, err + return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke") } } http.Redirect(w, r, "/admin/certs", http.StatusSeeOther) @@ -292,7 +295,7 @@ func newState() string { func readConfig(filename string) (*config.Config, error) { f, err := os.Open(filename) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to parse config file") } defer f.Close() return config.ReadConfig(f) @@ -301,11 +304,11 @@ func readConfig(filename string) (*config.Config, error) { func loadCerts(certFile, keyFile string) (tls.Certificate, error) { key, err := wkfs.ReadFile(keyFile) if err != nil { - return tls.Certificate{}, err + return tls.Certificate{}, errors.Wrap(err, "error reading TLS private key") } cert, err := wkfs.ReadFile(certFile) if err != nil { - return tls.Certificate{}, err + return tls.Certificate{}, errors.Wrap(err, "error reading TLS certificate") } return tls.X509KeyPair(cert, key) } @@ -338,14 +341,15 @@ func main() { if conf.Server.HTTPLogFile != "" { logfile, err = os.OpenFile(conf.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640) if err != nil { - log.Fatal(err) + log.Printf("unable to open %s for writing. logging to stdout", conf.Server.HTTPLogFile) + logfile = os.Stderr } } laddr := fmt.Sprintf("%s:%d", conf.Server.Addr, conf.Server.Port) l, err := net.Listen("tcp", laddr) if err != nil { - log.Fatal(err) + log.Fatal(errors.Wrapf(err, "unable to listen on %s:%d", conf.Server.Addr, conf.Server.Port)) } tlsConfig := &tls.Config{} @@ -364,7 +368,7 @@ func main() { tlsConfig.Certificates = make([]tls.Certificate, 1) tlsConfig.Certificates[0], err = loadCerts(conf.Server.TLSCert, conf.Server.TLSKey) if err != nil { - log.Fatal(err) + log.Fatal(errors.Wrap(err, "unable to create TLS listener")) } } l = tls.NewListener(l, tlsConfig) @@ -373,7 +377,7 @@ func main() { if conf.Server.User != "" { log.Print("Dropping privileges...") if err := drop.DropPrivileges(conf.Server.User); err != nil { - log.Fatal(err) + log.Fatal(errors.Wrap(err, "unable to drop privileges")) } } @@ -388,7 +392,7 @@ func main() { log.Fatalf("Unknown provider %s\n", conf.Auth.Provider) } if err != nil { - log.Fatal(err) + log.Fatal(errors.Wrapf(err, "unable to use provider '%s'", conf.Auth.Provider)) } certstore, err := store.New(conf.Server.Database) diff --git a/server/auth/github/github.go b/server/auth/github/github.go index 24a4bbfdd0e577c98ecd8a2e4575564d601dab93..46cf76a97665bad865de69e675824b65d924b4c8 100644 --- a/server/auth/github/github.go +++ b/server/auth/github/github.go @@ -25,14 +25,16 @@ type Config struct { whitelist map[string]bool } +var _ auth.Provider = (*Config)(nil) + // New creates a new Github provider from a configuration. -func New(c *config.Auth) (auth.Provider, error) { +func New(c *config.Auth) (*Config, error) { uw := make(map[string]bool) for _, u := range c.UsersWhitelist { uw[u] = true } if c.ProviderOpts["organization"] == "" && len(uw) == 0 { - return nil, errors.New("github_opts organization and the users whitelist must not be both empty") + return nil, errors.New("either GitHub organization or users whitelist must be specified") } return &Config{ config: &oauth2.Config{ diff --git a/server/auth/github/github_test.go b/server/auth/github/github_test.go index c0b26a41e9e980ff5dda8655bf18774e19cd05f5..8c51f4f8f13021342f6a93eb87e56ea3c5ca2b16 100644 --- a/server/auth/github/github_test.go +++ b/server/auth/github/github_test.go @@ -4,7 +4,6 @@ import ( "fmt" "testing" - "github.com/nsheridan/cashier/server/auth" "github.com/nsheridan/cashier/server/config" "github.com/stretchr/testify/assert" ) @@ -14,27 +13,48 @@ var ( oauthClientSecret = "secret" oauthCallbackURL = "url" organization = "exampleorg" + users = []string{"user"} ) func TestNew(t *testing.T) { a := assert.New(t) - p, _ := newGithub() - g := p.(*Config) - a.Equal(g.config.ClientID, oauthClientID) - a.Equal(g.config.ClientSecret, oauthClientSecret) - a.Equal(g.config.RedirectURL, oauthCallbackURL) - a.Equal(g.organization, organization) + p, _ := New(&config.Auth{ + OauthClientID: oauthClientID, + OauthClientSecret: oauthClientSecret, + OauthCallbackURL: oauthCallbackURL, + ProviderOpts: map[string]string{"organization": organization}, + UsersWhitelist: users, + }) + a.Equal(p.config.ClientID, oauthClientID) + a.Equal(p.config.ClientSecret, oauthClientSecret) + a.Equal(p.config.RedirectURL, oauthCallbackURL) + a.Equal(p.organization, organization) + a.Equal(p.whitelist, map[string]bool{"user": true}) } -func TestNewEmptyOrganization(t *testing.T) { - organization = "" - a := assert.New(t) - - _, err := newGithub() - a.EqualError(err, "github_opts organization and the users whitelist must not be both empty") - - organization = "exampleorg" +func TestWhitelist(t *testing.T) { + c := &config.Auth{ + OauthClientID: oauthClientID, + OauthClientSecret: oauthClientSecret, + OauthCallbackURL: oauthCallbackURL, + ProviderOpts: map[string]string{"organization": ""}, + UsersWhitelist: []string{}, + } + if _, err := New(c); err == nil { + t.Error("creating a provider without an organization set should return an error") + } + // Set a user whitelist but no domain + c.UsersWhitelist = users + if _, err := New(c); err != nil { + t.Error("creating a provider with users but no organization should not return an error") + } + // Unset the user whitelist and set a domain + c.UsersWhitelist = []string{} + c.ProviderOpts = map[string]string{"organization": organization} + if _, err := New(c); err != nil { + t.Error("creating a provider with an organization set but without a user whitelist should not return an error") + } } func TestStartSession(t *testing.T) { @@ -47,7 +67,7 @@ func TestStartSession(t *testing.T) { a.Contains(s.AuthURL, fmt.Sprintf("client_id=%s", oauthClientID)) } -func newGithub() (auth.Provider, error) { +func newGithub() (*Config, error) { c := &config.Auth{ OauthClientID: oauthClientID, OauthClientSecret: oauthClientSecret, diff --git a/server/auth/google/google.go b/server/auth/google/google.go index 08a4083cf7bec13cb8c57e6582f7d4f5a8d0d460..8c6f53bda128ff0893a2825eff70819e5b9e1f71 100644 --- a/server/auth/google/google.go +++ b/server/auth/google/google.go @@ -27,14 +27,16 @@ type Config struct { whitelist map[string]bool } +var _ auth.Provider = (*Config)(nil) + // New creates a new Google provider from a configuration. -func New(c *config.Auth) (auth.Provider, error) { +func New(c *config.Auth) (*Config, error) { uw := make(map[string]bool) for _, u := range c.UsersWhitelist { uw[u] = true } if c.ProviderOpts["domain"] == "" && len(uw) == 0 { - return nil, errors.New("google_opts domain and the users whitelist must not be both empty") + return nil, errors.New("either Google Apps domain or users whitelist must be specified") } return &Config{ diff --git a/server/auth/google/google_test.go b/server/auth/google/google_test.go index b80c4bf9bf45a5db1a1271b417ce8d1ddf98cb68..b3d26334a9224b79612ed3cc915e26ece47eded7 100644 --- a/server/auth/google/google_test.go +++ b/server/auth/google/google_test.go @@ -4,7 +4,6 @@ import ( "fmt" "testing" - "github.com/nsheridan/cashier/server/auth" "github.com/nsheridan/cashier/server/config" "github.com/stretchr/testify/assert" ) @@ -14,28 +13,42 @@ var ( oauthClientSecret = "secret" oauthCallbackURL = "url" domain = "example.com" + users = []string{"user"} ) func TestNew(t *testing.T) { a := assert.New(t) - - p, _ := newGoogle() - g := p.(*Config) - a.Equal(g.config.ClientID, oauthClientID) - a.Equal(g.config.ClientSecret, oauthClientSecret) - a.Equal(g.config.RedirectURL, oauthCallbackURL) - a.Equal(g.domain, domain) + p, err := newGoogle() + a.NoError(err) + a.Equal(p.config.ClientID, oauthClientID) + a.Equal(p.config.ClientSecret, oauthClientSecret) + a.Equal(p.config.RedirectURL, oauthCallbackURL) + a.Equal(p.domain, domain) + a.Equal(p.whitelist, map[string]bool{"user": true}) } -func TestNewWithoutDomain(t *testing.T) { - a := assert.New(t) - - domain = "" - - _, err := newGoogle() - a.EqualError(err, "google_opts domain and the users whitelist must not be both empty") - - domain = "example.com" +func TestWhitelist(t *testing.T) { + c := &config.Auth{ + OauthClientID: oauthClientID, + OauthClientSecret: oauthClientSecret, + OauthCallbackURL: oauthCallbackURL, + ProviderOpts: map[string]string{"domain": ""}, + UsersWhitelist: []string{}, + } + if _, err := New(c); err == nil { + t.Error("creating a provider without a domain set should return an error") + } + // Set a user whitelist but no domain + c.UsersWhitelist = users + if _, err := New(c); err != nil { + t.Error("creating a provider with users but no domain should not return an error") + } + // Unset the user whitelist and set a domain + c.UsersWhitelist = []string{} + c.ProviderOpts = map[string]string{"domain": domain} + if _, err := New(c); err != nil { + t.Error("creating a provider with a domain set but without a user whitelist should not return an error") + } } func TestStartSession(t *testing.T) { @@ -50,12 +63,13 @@ func TestStartSession(t *testing.T) { a.Contains(s.AuthURL, fmt.Sprintf("client_id=%s", oauthClientID)) } -func newGoogle() (auth.Provider, error) { +func newGoogle() (*Config, error) { c := &config.Auth{ OauthClientID: oauthClientID, OauthClientSecret: oauthClientSecret, OauthCallbackURL: oauthCallbackURL, ProviderOpts: map[string]string{"domain": domain}, + UsersWhitelist: users, } return New(c) } diff --git a/server/auth/testprovider/testprovider.go b/server/auth/testprovider/testprovider.go index 3d2b13a49845065e3d67a36a696adac5753ea8e8..e30b04aaa58e8cae689cb47b6b992ead3ebe02f4 100644 --- a/server/auth/testprovider/testprovider.go +++ b/server/auth/testprovider/testprovider.go @@ -15,8 +15,10 @@ const ( // Config is an implementation of `auth.Provider` for testing. type Config struct{} +var _ auth.Provider = (*Config)(nil) + // New creates a new provider. -func New() auth.Provider { +func New() *Config { return &Config{} } diff --git a/server/config/config.go b/server/config/config.go index 5f3f458cf240913f8ec674f6c2e670f1f5a520b6..f2598a0c46c458b32cef4ac0f0a5a8b558cb9a9f 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -1,7 +1,6 @@ package config import ( - "errors" "fmt" "io" "log" @@ -12,6 +11,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/mitchellh/mapstructure" "github.com/nsheridan/cashier/server/helpers/vault" + "github.com/pkg/errors" "github.com/spf13/viper" ) @@ -156,14 +156,14 @@ func setFromVault(c *Config) error { } v, err := vault.NewClient(c.Vault.Address, c.Vault.Token) if err != nil { - return err + return errors.Wrap(err, "vault error") } - var errors error + var errs error get := func(value string) string { if strings.HasPrefix(value, "/vault/") { s, err := v.Read(value) if err != nil { - errors = multierror.Append(errors, err) + errs = multierror.Append(errs, err) } return s } @@ -180,12 +180,12 @@ func setFromVault(c *Config) error { c.AWS.AccessKey = get(c.AWS.AccessKey) c.AWS.SecretKey = get(c.AWS.SecretKey) } - return errors + return errors.Wrap(errs, "errors reading from vault") } // Unmarshal the config into a *Config func decode() (*Config, error) { - var errors error + var errs error config := &Config{} configPieces := map[string]interface{}{ "auth": &config.Auth, @@ -200,21 +200,21 @@ func decode() (*Config, error) { continue } if err := mapstructure.WeakDecode(conf[0], val); err != nil { - errors = multierror.Append(errors, err) + errs = multierror.Append(errs, err) } } - return config, errors + return config, errs } // ReadConfig parses a hcl configuration file into a Config struct. func ReadConfig(r io.Reader) (*Config, error) { viper.SetConfigType("hcl") if err := viper.ReadConfig(r); err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to read config") } config, err := decode() if err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to parse config") } if err := setFromVault(config); err != nil { return nil, err @@ -222,7 +222,7 @@ func ReadConfig(r io.Reader) (*Config, error) { setFromEnvironment(config) convertDatastoreConfig(config) if err := verifyConfig(config); err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to verify config") } return config, nil } diff --git a/server/store/mem.go b/server/store/mem.go index 54aa9655d011794bb4446d788129ba7a6740097e..e289b161abbf09a92c9b9c243a1daedeefeb05c7 100644 --- a/server/store/mem.go +++ b/server/store/mem.go @@ -8,12 +8,16 @@ import ( "golang.org/x/crypto/ssh" ) -type memoryStore struct { +var _ CertStorer = (*MemoryStore)(nil) + +// MemoryStore is an in-memory CertStorer +type MemoryStore struct { sync.Mutex certs map[string]*CertRecord } -func (ms *memoryStore) Get(id string) (*CertRecord, error) { +// Get a single *CertRecord +func (ms *MemoryStore) Get(id string) (*CertRecord, error) { ms.Lock() defer ms.Unlock() r, ok := ms.certs[id] @@ -23,18 +27,22 @@ func (ms *memoryStore) Get(id string) (*CertRecord, error) { return r, nil } -func (ms *memoryStore) SetCert(cert *ssh.Certificate) error { +// SetCert parses a *ssh.Certificate and records it +func (ms *MemoryStore) SetCert(cert *ssh.Certificate) error { return ms.SetRecord(parseCertificate(cert)) } -func (ms *memoryStore) SetRecord(record *CertRecord) error { +// SetRecord records a *CertRecord +func (ms *MemoryStore) SetRecord(record *CertRecord) error { ms.Lock() defer ms.Unlock() ms.certs[record.KeyID] = record return nil } -func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) { +// List returns all recorded certs. +// By default only active certs are returned. +func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) { var records []*CertRecord ms.Lock() defer ms.Unlock() @@ -48,7 +56,8 @@ func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) { return records, nil } -func (ms *memoryStore) Revoke(id string) error { +// Revoke an issued cert by id. +func (ms *MemoryStore) Revoke(id string) error { r, err := ms.Get(id) if err != nil { return err @@ -58,7 +67,8 @@ func (ms *memoryStore) Revoke(id string) error { return nil } -func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) { +// GetRevoked returns all revoked certs +func (ms *MemoryStore) GetRevoked() ([]*CertRecord, error) { var revoked []*CertRecord all, _ := ms.List(false) for _, r := range all { @@ -69,22 +79,23 @@ func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) { return revoked, nil } -func (ms *memoryStore) Close() error { +// Close the store. This will clear the contents. +func (ms *MemoryStore) Close() error { ms.Lock() defer ms.Unlock() ms.certs = nil return nil } -func (ms *memoryStore) clear() { +func (ms *MemoryStore) clear() { for k := range ms.certs { delete(ms.certs, k) } } // NewMemoryStore returns an in-memory CertStorer. -func NewMemoryStore() CertStorer { - return &memoryStore{ +func NewMemoryStore() *MemoryStore { + return &MemoryStore{ certs: make(map[string]*CertRecord), } } diff --git a/server/store/mongo.go b/server/store/mongo.go index fc4131f21eb6d515a59a74ea70297f93eb640897..6a23e8936539a63fd74108e48fc52b96f9d0c0bd 100644 --- a/server/store/mongo.go +++ b/server/store/mongo.go @@ -22,7 +22,7 @@ func collection(session *mgo.Session) *mgo.Collection { } // NewMongoStore returns a MongoDB CertStorer. -func NewMongoStore(c config.Database) (CertStorer, error) { +func NewMongoStore(c config.Database) (*MongoStore, error) { m := &mgo.DialInfo{ Addrs: strings.Split(c["address"], ","), Username: c["username"], @@ -34,16 +34,20 @@ func NewMongoStore(c config.Database) (CertStorer, error) { if err != nil { return nil, err } - return &mongoDB{ + return &MongoStore{ session: session, }, nil } -type mongoDB struct { +var _ CertStorer = (*MongoStore)(nil) + +// MongoStore is a MongoDB-based CertStorer +type MongoStore struct { session *mgo.Session } -func (m *mongoDB) Get(id string) (*CertRecord, error) { +// Get a single *CertRecord +func (m *MongoStore) Get(id string) (*CertRecord, error) { s := m.session.Copy() defer s.Close() if err := s.Ping(); err != nil { @@ -54,12 +58,14 @@ func (m *mongoDB) Get(id string) (*CertRecord, error) { return c, err } -func (m *mongoDB) SetCert(cert *ssh.Certificate) error { +// SetCert parses a *ssh.Certificate and records it +func (m *MongoStore) SetCert(cert *ssh.Certificate) error { r := parseCertificate(cert) return m.SetRecord(r) } -func (m *mongoDB) SetRecord(record *CertRecord) error { +// SetRecord records a *CertRecord +func (m *MongoStore) SetRecord(record *CertRecord) error { s := m.session.Copy() defer s.Close() if err := s.Ping(); err != nil { @@ -68,7 +74,9 @@ func (m *mongoDB) SetRecord(record *CertRecord) error { return collection(s).Insert(record) } -func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) { +// List returns all recorded certs. +// By default only active certs are returned. +func (m *MongoStore) List(includeExpired bool) ([]*CertRecord, error) { s := m.session.Copy() defer s.Close() if err := s.Ping(); err != nil { @@ -85,7 +93,8 @@ func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) { return result, err } -func (m *mongoDB) Revoke(id string) error { +// Revoke an issued cert by id. +func (m *MongoStore) Revoke(id string) error { s := m.session.Copy() defer s.Close() if err := s.Ping(); err != nil { @@ -95,7 +104,8 @@ func (m *mongoDB) Revoke(id string) error { return c.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}}) } -func (m *mongoDB) GetRevoked() ([]*CertRecord, error) { +// GetRevoked returns all revoked certs +func (m *MongoStore) GetRevoked() ([]*CertRecord, error) { s := m.session.Copy() defer s.Close() if err := s.Ping(); err != nil { @@ -106,7 +116,8 @@ func (m *mongoDB) GetRevoked() ([]*CertRecord, error) { return result, err } -func (m *mongoDB) Close() error { +// Close the connection to the database +func (m *MongoStore) Close() error { m.session.Close() return nil } diff --git a/server/store/sqldb.go b/server/store/sqldb.go index d7ef878d8d30e39acf13c85bb18d42c154cfc854..a51678ecdef4b1e87c99e25517199f54bdf39fb9 100644 --- a/server/store/sqldb.go +++ b/server/store/sqldb.go @@ -13,7 +13,10 @@ import ( "github.com/nsheridan/cashier/server/config" ) -type sqldb struct { +var _ CertStorer = (*SQLStore)(nil) + +// SQLStore is an sql-based CertStorer +type SQLStore struct { conn *sql.DB get *sql.Stmt @@ -25,7 +28,7 @@ type sqldb struct { } // NewSQLStore returns a *sql.DB CertStorer. -func NewSQLStore(c config.Database) (CertStorer, error) { +func NewSQLStore(c config.Database) (*SQLStore, error) { var driver string var dsn string switch c["type"] { @@ -51,34 +54,34 @@ func NewSQLStore(c config.Database) (CertStorer, error) { } conn, err := sql.Open(driver, dsn) if err != nil { - return nil, fmt.Errorf("sqldb: could not get a connection: %v", err) + return nil, fmt.Errorf("SQLStore: could not get a connection: %v", err) } if err := conn.Ping(); err != nil { conn.Close() - return nil, fmt.Errorf("sqldb: could not establish a good connection: %v", err) + return nil, fmt.Errorf("SQLStore: could not establish a good connection: %v", err) } - db := &sqldb{ + db := &SQLStore{ conn: conn, } if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil { - return nil, fmt.Errorf("sqldb: prepare set: %v", err) + return nil, fmt.Errorf("SQLStore: prepare set: %v", err) } if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { - return nil, fmt.Errorf("sqldb: prepare get: %v", err) + return nil, fmt.Errorf("SQLStore: prepare get: %v", err) } if db.listAll, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil { - return nil, fmt.Errorf("sqldb: prepare listAll: %v", err) + return nil, fmt.Errorf("SQLStore: prepare listAll: %v", err) } if db.listCurrent, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil { - return nil, fmt.Errorf("sqldb: prepare listCurrent: %v", err) + return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err) } if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil { - return nil, fmt.Errorf("sqldb: prepare revoke: %v", err) + return nil, fmt.Errorf("SQLStore: prepare revoke: %v", err) } if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil { - return nil, fmt.Errorf("sqldb: prepare revoked: %v", err) + return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err) } return db, nil } @@ -114,18 +117,21 @@ func scanCert(s rowScanner) (*CertRecord, error) { }, nil } -func (db *sqldb) Get(id string) (*CertRecord, error) { +// Get a single *CertRecord +func (db *SQLStore) Get(id string) (*CertRecord, error) { if err := db.conn.Ping(); err != nil { return nil, err } return scanCert(db.get.QueryRow(id)) } -func (db *sqldb) SetCert(cert *ssh.Certificate) error { +// SetCert parses a *ssh.Certificate and records it +func (db *SQLStore) SetCert(cert *ssh.Certificate) error { return db.SetRecord(parseCertificate(cert)) } -func (db *sqldb) SetRecord(rec *CertRecord) error { +// SetRecord records a *CertRecord +func (db *SQLStore) SetRecord(rec *CertRecord) error { principals, err := json.Marshal(rec.Principals) if err != nil { return err @@ -137,7 +143,9 @@ func (db *sqldb) SetRecord(rec *CertRecord) error { return err } -func (db *sqldb) List(includeExpired bool) ([]*CertRecord, error) { +// List returns all recorded certs. +// By default only active certs are returned. +func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) { if err := db.conn.Ping(); err != nil { return nil, err } @@ -159,7 +167,8 @@ func (db *sqldb) List(includeExpired bool) ([]*CertRecord, error) { return recs, nil } -func (db *sqldb) Revoke(id string) error { +// Revoke an issued cert by id. +func (db *SQLStore) Revoke(id string) error { if err := db.conn.Ping(); err != nil { return err } @@ -170,7 +179,8 @@ func (db *sqldb) Revoke(id string) error { return nil } -func (db *sqldb) GetRevoked() ([]*CertRecord, error) { +// GetRevoked returns all revoked certs +func (db *SQLStore) GetRevoked() ([]*CertRecord, error) { if err := db.conn.Ping(); err != nil { return nil, err } @@ -187,6 +197,7 @@ func (db *sqldb) GetRevoked() ([]*CertRecord, error) { return recs, nil } -func (db *sqldb) Close() error { +// Close the connection to the database +func (db *SQLStore) Close() error { return db.conn.Close() }