Skip to content
Snippets Groups Projects
Commit d3d2d538 authored by Niall Sheridan's avatar Niall Sheridan
Browse files

Make CertStorer implementations public

parent 5d7e2397
No related branches found
No related tags found
No related merge requests found
......@@ -8,12 +8,16 @@ import (
"golang.org/x/crypto/ssh"
)
type memoryStore struct {
var _ CertStorer = (*MemoryStore)(nil)
// MemoryStore is an in-memory CertStorer
type MemoryStore struct {
sync.Mutex
certs map[string]*CertRecord
}
func (ms *memoryStore) Get(id string) (*CertRecord, error) {
// Get a single *CertRecord
func (ms *MemoryStore) Get(id string) (*CertRecord, error) {
ms.Lock()
defer ms.Unlock()
r, ok := ms.certs[id]
......@@ -23,18 +27,22 @@ func (ms *memoryStore) Get(id string) (*CertRecord, error) {
return r, nil
}
func (ms *memoryStore) SetCert(cert *ssh.Certificate) error {
// SetCert parses a *ssh.Certificate and records it
func (ms *MemoryStore) SetCert(cert *ssh.Certificate) error {
return ms.SetRecord(parseCertificate(cert))
}
func (ms *memoryStore) SetRecord(record *CertRecord) error {
// SetRecord records a *CertRecord
func (ms *MemoryStore) SetRecord(record *CertRecord) error {
ms.Lock()
defer ms.Unlock()
ms.certs[record.KeyID] = record
return nil
}
func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) {
// List returns all recorded certs.
// By default only active certs are returned.
func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) {
var records []*CertRecord
ms.Lock()
defer ms.Unlock()
......@@ -48,7 +56,8 @@ func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) {
return records, nil
}
func (ms *memoryStore) Revoke(id string) error {
// Revoke an issued cert by id.
func (ms *MemoryStore) Revoke(id string) error {
r, err := ms.Get(id)
if err != nil {
return err
......@@ -58,7 +67,8 @@ func (ms *memoryStore) Revoke(id string) error {
return nil
}
func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) {
// GetRevoked returns all revoked certs
func (ms *MemoryStore) GetRevoked() ([]*CertRecord, error) {
var revoked []*CertRecord
all, _ := ms.List(false)
for _, r := range all {
......@@ -69,22 +79,23 @@ func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) {
return revoked, nil
}
func (ms *memoryStore) Close() error {
// Close the store. This will clear the contents.
func (ms *MemoryStore) Close() error {
ms.Lock()
defer ms.Unlock()
ms.certs = nil
return nil
}
func (ms *memoryStore) clear() {
func (ms *MemoryStore) clear() {
for k := range ms.certs {
delete(ms.certs, k)
}
}
// NewMemoryStore returns an in-memory CertStorer.
func NewMemoryStore() CertStorer {
return &memoryStore{
func NewMemoryStore() *MemoryStore {
return &MemoryStore{
certs: make(map[string]*CertRecord),
}
}
......@@ -22,7 +22,7 @@ func collection(session *mgo.Session) *mgo.Collection {
}
// NewMongoStore returns a MongoDB CertStorer.
func NewMongoStore(c config.Database) (CertStorer, error) {
func NewMongoStore(c config.Database) (*MongoStore, error) {
m := &mgo.DialInfo{
Addrs: strings.Split(c["address"], ","),
Username: c["username"],
......@@ -34,16 +34,20 @@ func NewMongoStore(c config.Database) (CertStorer, error) {
if err != nil {
return nil, err
}
return &mongoDB{
return &MongoStore{
session: session,
}, nil
}
type mongoDB struct {
var _ CertStorer = (*MongoStore)(nil)
// MongoStore is a MongoDB-based CertStorer
type MongoStore struct {
session *mgo.Session
}
func (m *mongoDB) Get(id string) (*CertRecord, error) {
// Get a single *CertRecord
func (m *MongoStore) Get(id string) (*CertRecord, error) {
s := m.session.Copy()
defer s.Close()
if err := s.Ping(); err != nil {
......@@ -54,12 +58,14 @@ func (m *mongoDB) Get(id string) (*CertRecord, error) {
return c, err
}
func (m *mongoDB) SetCert(cert *ssh.Certificate) error {
// SetCert parses a *ssh.Certificate and records it
func (m *MongoStore) SetCert(cert *ssh.Certificate) error {
r := parseCertificate(cert)
return m.SetRecord(r)
}
func (m *mongoDB) SetRecord(record *CertRecord) error {
// SetRecord records a *CertRecord
func (m *MongoStore) SetRecord(record *CertRecord) error {
s := m.session.Copy()
defer s.Close()
if err := s.Ping(); err != nil {
......@@ -68,7 +74,9 @@ func (m *mongoDB) SetRecord(record *CertRecord) error {
return collection(s).Insert(record)
}
func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) {
// List returns all recorded certs.
// By default only active certs are returned.
func (m *MongoStore) List(includeExpired bool) ([]*CertRecord, error) {
s := m.session.Copy()
defer s.Close()
if err := s.Ping(); err != nil {
......@@ -85,7 +93,8 @@ func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) {
return result, err
}
func (m *mongoDB) Revoke(id string) error {
// Revoke an issued cert by id.
func (m *MongoStore) Revoke(id string) error {
s := m.session.Copy()
defer s.Close()
if err := s.Ping(); err != nil {
......@@ -95,7 +104,8 @@ func (m *mongoDB) Revoke(id string) error {
return c.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}})
}
func (m *mongoDB) GetRevoked() ([]*CertRecord, error) {
// GetRevoked returns all revoked certs
func (m *MongoStore) GetRevoked() ([]*CertRecord, error) {
s := m.session.Copy()
defer s.Close()
if err := s.Ping(); err != nil {
......@@ -106,7 +116,8 @@ func (m *mongoDB) GetRevoked() ([]*CertRecord, error) {
return result, err
}
func (m *mongoDB) Close() error {
// Close the connection to the database
func (m *MongoStore) Close() error {
m.session.Close()
return nil
}
......@@ -13,7 +13,10 @@ import (
"github.com/nsheridan/cashier/server/config"
)
type sqldb struct {
var _ CertStorer = (*SQLStore)(nil)
// SQLStore is an sql-based CertStorer
type SQLStore struct {
conn *sql.DB
get *sql.Stmt
......@@ -25,7 +28,7 @@ type sqldb struct {
}
// NewSQLStore returns a *sql.DB CertStorer.
func NewSQLStore(c config.Database) (CertStorer, error) {
func NewSQLStore(c config.Database) (*SQLStore, error) {
var driver string
var dsn string
switch c["type"] {
......@@ -51,34 +54,34 @@ func NewSQLStore(c config.Database) (CertStorer, error) {
}
conn, err := sql.Open(driver, dsn)
if err != nil {
return nil, fmt.Errorf("sqldb: could not get a connection: %v", err)
return nil, fmt.Errorf("SQLStore: could not get a connection: %v", err)
}
if err := conn.Ping(); err != nil {
conn.Close()
return nil, fmt.Errorf("sqldb: could not establish a good connection: %v", err)
return nil, fmt.Errorf("SQLStore: could not establish a good connection: %v", err)
}
db := &sqldb{
db := &SQLStore{
conn: conn,
}
if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil {
return nil, fmt.Errorf("sqldb: prepare set: %v", err)
return nil, fmt.Errorf("SQLStore: prepare set: %v", err)
}
if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
return nil, fmt.Errorf("sqldb: prepare get: %v", err)
return nil, fmt.Errorf("SQLStore: prepare get: %v", err)
}
if db.listAll, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil {
return nil, fmt.Errorf("sqldb: prepare listAll: %v", err)
return nil, fmt.Errorf("SQLStore: prepare listAll: %v", err)
}
if db.listCurrent, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
return nil, fmt.Errorf("sqldb: prepare listCurrent: %v", err)
return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err)
}
if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
return nil, fmt.Errorf("sqldb: prepare revoke: %v", err)
return nil, fmt.Errorf("SQLStore: prepare revoke: %v", err)
}
if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil {
return nil, fmt.Errorf("sqldb: prepare revoked: %v", err)
return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err)
}
return db, nil
}
......@@ -114,18 +117,21 @@ func scanCert(s rowScanner) (*CertRecord, error) {
}, nil
}
func (db *sqldb) Get(id string) (*CertRecord, error) {
// Get a single *CertRecord
func (db *SQLStore) Get(id string) (*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, err
}
return scanCert(db.get.QueryRow(id))
}
func (db *sqldb) SetCert(cert *ssh.Certificate) error {
// SetCert parses a *ssh.Certificate and records it
func (db *SQLStore) SetCert(cert *ssh.Certificate) error {
return db.SetRecord(parseCertificate(cert))
}
func (db *sqldb) SetRecord(rec *CertRecord) error {
// SetRecord records a *CertRecord
func (db *SQLStore) SetRecord(rec *CertRecord) error {
principals, err := json.Marshal(rec.Principals)
if err != nil {
return err
......@@ -137,7 +143,9 @@ func (db *sqldb) SetRecord(rec *CertRecord) error {
return err
}
func (db *sqldb) List(includeExpired bool) ([]*CertRecord, error) {
// List returns all recorded certs.
// By default only active certs are returned.
func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, err
}
......@@ -159,7 +167,8 @@ func (db *sqldb) List(includeExpired bool) ([]*CertRecord, error) {
return recs, nil
}
func (db *sqldb) Revoke(id string) error {
// Revoke an issued cert by id.
func (db *SQLStore) Revoke(id string) error {
if err := db.conn.Ping(); err != nil {
return err
}
......@@ -170,7 +179,8 @@ func (db *sqldb) Revoke(id string) error {
return nil
}
func (db *sqldb) GetRevoked() ([]*CertRecord, error) {
// GetRevoked returns all revoked certs
func (db *SQLStore) GetRevoked() ([]*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, err
}
......@@ -187,6 +197,7 @@ func (db *sqldb) GetRevoked() ([]*CertRecord, error) {
return recs, nil
}
func (db *sqldb) Close() error {
// Close the connection to the database
func (db *SQLStore) Close() error {
return db.conn.Close()
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment