Skip to content
Snippets Groups Projects
Commit 4f2385db authored by Niall Sheridan's avatar Niall Sheridan
Browse files

Unexport store implementations

Return an error if the store isn't known, instead of defaulting to a mem store
parent 162efe88
Branches
Tags
No related merge requests found
......@@ -38,7 +38,7 @@ func init() {
MaxAge: "1h",
})
authprovider = testprovider.New()
certstore = store.NewMemoryStore()
certstore, _ = store.New(map[string]string{"type": "mem"})
ctx = &appContext{
cookiestore: sessions.NewCookieStore([]byte("secret")),
authsession: &auth.Session{AuthURL: "https://www.example.com/auth"},
......
......@@ -2,23 +2,22 @@ package store
import (
"fmt"
"log"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
var _ CertStorer = (*MemoryStore)(nil)
var _ CertStorer = (*memoryStore)(nil)
// MemoryStore is an in-memory CertStorer
type MemoryStore struct {
// memoryStore is an in-memory CertStorer
type memoryStore struct {
sync.Mutex
certs map[string]*CertRecord
}
// Get a single *CertRecord
func (ms *MemoryStore) Get(id string) (*CertRecord, error) {
func (ms *memoryStore) Get(id string) (*CertRecord, error) {
ms.Lock()
defer ms.Unlock()
r, ok := ms.certs[id]
......@@ -29,12 +28,12 @@ func (ms *MemoryStore) Get(id string) (*CertRecord, error) {
}
// SetCert parses a *ssh.Certificate and records it
func (ms *MemoryStore) SetCert(cert *ssh.Certificate) error {
func (ms *memoryStore) SetCert(cert *ssh.Certificate) error {
return ms.SetRecord(parseCertificate(cert))
}
// SetRecord records a *CertRecord
func (ms *MemoryStore) SetRecord(record *CertRecord) error {
func (ms *memoryStore) SetRecord(record *CertRecord) error {
ms.Lock()
defer ms.Unlock()
ms.certs[record.KeyID] = record
......@@ -43,7 +42,7 @@ func (ms *MemoryStore) SetRecord(record *CertRecord) error {
// List returns all recorded certs.
// By default only active certs are returned.
func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) {
func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) {
var records []*CertRecord
ms.Lock()
defer ms.Unlock()
......@@ -58,7 +57,7 @@ func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) {
}
// Revoke an issued cert by id.
func (ms *MemoryStore) Revoke(ids []string) error {
func (ms *memoryStore) Revoke(ids []string) error {
ms.Lock()
defer ms.Unlock()
for _, id := range ids {
......@@ -68,7 +67,7 @@ func (ms *MemoryStore) Revoke(ids []string) error {
}
// GetRevoked returns all revoked certs
func (ms *MemoryStore) GetRevoked() ([]*CertRecord, error) {
func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) {
var revoked []*CertRecord
all, _ := ms.List(false)
for _, r := range all {
......@@ -80,23 +79,22 @@ func (ms *MemoryStore) GetRevoked() ([]*CertRecord, error) {
}
// Close the store. This will clear the contents.
func (ms *MemoryStore) Close() error {
func (ms *memoryStore) Close() error {
ms.Lock()
defer ms.Unlock()
ms.certs = nil
return nil
}
func (ms *MemoryStore) clear() {
func (ms *memoryStore) clear() {
for k := range ms.certs {
delete(ms.certs, k)
}
}
// NewMemoryStore returns an in-memory CertStorer.
func NewMemoryStore() *MemoryStore {
log.Println("WARNING: Using memory store to record issued certs.")
return &MemoryStore{
// newMemoryStore returns an in-memory CertStorer.
func newMemoryStore() *memoryStore {
return &memoryStore{
certs: make(map[string]*CertRecord),
}
}
......@@ -18,10 +18,10 @@ import (
migrate "github.com/rubenv/sql-migrate"
)
var _ CertStorer = (*SQLStore)(nil)
var _ CertStorer = (*sqlStore)(nil)
// SQLStore is an sql-based CertStorer
type SQLStore struct {
// sqlStore is an sql-based CertStorer
type sqlStore struct {
conn *sqlx.DB
get *sqlx.Stmt
......@@ -31,8 +31,8 @@ type SQLStore struct {
revoked *sqlx.Stmt
}
// NewSQLStore returns a *sql.DB CertStorer.
func NewSQLStore(c config.Database) (*SQLStore, error) {
// newSQLStore returns a *sql.DB CertStorer.
func newSQLStore(c config.Database) (*sqlStore, error) {
var driver string
var dsn string
switch c["type"] {
......@@ -61,30 +61,30 @@ func NewSQLStore(c config.Database) (*SQLStore, error) {
conn, err := sqlx.Connect(driver, dsn)
if err != nil {
return nil, fmt.Errorf("SQLStore: could not get a connection: %v", err)
return nil, fmt.Errorf("sqlStore: could not get a connection: %v", err)
}
if err := autoMigrate(driver, conn); err != nil {
return nil, fmt.Errorf("SQLStore: could not update schema: %v", err)
return nil, fmt.Errorf("sqlStore: could not update schema: %v", err)
}
db := &SQLStore{
db := &sqlStore{
conn: conn,
}
if db.set, err = conn.Preparex("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare set: %v", err)
return nil, fmt.Errorf("sqlStore: prepare set: %v", err)
}
if db.get, err = conn.Preparex("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare get: %v", err)
return nil, fmt.Errorf("sqlStore: prepare get: %v", err)
}
if db.listAll, err = conn.Preparex("SELECT * FROM issued_certs"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare listAll: %v", err)
return nil, fmt.Errorf("sqlStore: prepare listAll: %v", err)
}
if db.listCurrent, err = conn.Preparex("SELECT * FROM issued_certs WHERE expires_at >= ?"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err)
return nil, fmt.Errorf("sqlStore: prepare listCurrent: %v", err)
}
if db.revoked, err = conn.Preparex("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err)
return nil, fmt.Errorf("sqlStore: prepare revoked: %v", err)
}
return db, nil
}
......@@ -114,7 +114,7 @@ type rowScanner interface {
}
// Get a single *CertRecord
func (db *SQLStore) Get(id string) (*CertRecord, error) {
func (db *sqlStore) Get(id string) (*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, errors.Wrap(err, "unable to connect to database")
}
......@@ -123,12 +123,12 @@ func (db *SQLStore) Get(id string) (*CertRecord, error) {
}
// SetCert parses a *ssh.Certificate and records it
func (db *SQLStore) SetCert(cert *ssh.Certificate) error {
func (db *sqlStore) SetCert(cert *ssh.Certificate) error {
return db.SetRecord(parseCertificate(cert))
}
// SetRecord records a *CertRecord
func (db *SQLStore) SetRecord(rec *CertRecord) error {
func (db *sqlStore) SetRecord(rec *CertRecord) error {
if err := db.conn.Ping(); err != nil {
return errors.Wrap(err, "unable to connect to database")
}
......@@ -138,7 +138,7 @@ func (db *SQLStore) SetRecord(rec *CertRecord) error {
// List returns all recorded certs.
// By default only active certs are returned.
func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) {
func (db *sqlStore) List(includeExpired bool) ([]*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, errors.Wrap(err, "unable to connect to database")
}
......@@ -156,7 +156,7 @@ func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) {
}
// Revoke an issued cert by id.
func (db *SQLStore) Revoke(ids []string) error {
func (db *sqlStore) Revoke(ids []string) error {
if err := db.conn.Ping(); err != nil {
return errors.Wrap(err, "unable to connect to database")
}
......@@ -166,7 +166,7 @@ func (db *SQLStore) Revoke(ids []string) error {
}
// GetRevoked returns all revoked certs
func (db *SQLStore) GetRevoked() ([]*CertRecord, error) {
func (db *sqlStore) GetRevoked() ([]*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, errors.Wrap(err, "unable to connect to database")
}
......@@ -178,6 +178,6 @@ func (db *SQLStore) GetRevoked() ([]*CertRecord, error) {
}
// Close the connection to the database
func (db *SQLStore) Close() error {
func (db *sqlStore) Close() error {
return db.conn.Close()
}
......@@ -2,6 +2,7 @@ package store
import (
"encoding/json"
"fmt"
"time"
"golang.org/x/crypto/ssh"
......@@ -15,11 +16,11 @@ import (
func New(c config.Database) (CertStorer, error) {
switch c["type"] {
case "mysql", "sqlite":
return NewSQLStore(c)
return newSQLStore(c)
case "mem":
return NewMemoryStore(), nil
return newMemoryStore(), nil
}
return NewMemoryStore(), nil
return nil, fmt.Errorf("unable to create store with driver %s", c["type"])
}
// CertStorer records issued certs in a persistent store for audit and
......
......@@ -101,7 +101,7 @@ func testStore(t *testing.T, db CertStorer) {
}
func TestMemoryStore(t *testing.T) {
db := NewMemoryStore()
db := newMemoryStore()
testStore(t, db)
}
......@@ -120,7 +120,7 @@ func TestMySQLStore(t *testing.T) {
} else {
sqlConfig["username"] = u.Username
}
db, err := NewSQLStore(sqlConfig)
db, err := newSQLStore(sqlConfig)
if err != nil {
t.Error(err)
}
......@@ -134,7 +134,7 @@ func TestSQLiteStore(t *testing.T) {
}
defer os.Remove(f.Name())
config := map[string]string{"type": "sqlite", "filename": f.Name()}
db, err := NewSQLStore(config)
db, err := newSQLStore(config)
if err != nil {
t.Error(err)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment