From 2e7c8c2f521c9e50bb3aea4df16771c22fe70e58 Mon Sep 17 00:00:00 2001
From: Niall Sheridan <nsheridan@gmail.com>
Date: Sat, 10 Sep 2016 20:14:20 +0100
Subject: [PATCH] Allow filtering results

---
 client/keys.go        |  3 ++-
 server/store/mem.go   | 14 +++++++-------
 server/store/mongo.go |  9 +++++++--
 server/store/sqldb.go | 27 ++++++++++++++++++---------
 server/store/store.go | 14 +++++++-------
 5 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/client/keys.go b/client/keys.go
index 4b3b69e2..0ec0f1dd 100644
--- a/client/keys.go
+++ b/client/keys.go
@@ -6,6 +6,7 @@ import (
 	"crypto/rand"
 	"crypto/rsa"
 	"fmt"
+	"strings"
 
 	"golang.org/x/crypto/ed25519"
 	"golang.org/x/crypto/ssh"
@@ -78,7 +79,7 @@ func GenerateKey(keytype string, bits int) (Key, ssh.PublicKey, error) {
 		for k := range keytypes {
 			valid = append(valid, k)
 		}
-		return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are %s", keytype, valid)
+		return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are %s", keytype, strings.Join(valid, "|"))
 	}
 	return f(bits)
 }
diff --git a/server/store/mem.go b/server/store/mem.go
index 92167a96..e63d00a2 100644
--- a/server/store/mem.go
+++ b/server/store/mem.go
@@ -34,14 +34,16 @@ func (ms *memoryStore) SetRecord(record *CertRecord) error {
 	return nil
 }
 
-func (ms *memoryStore) List() ([]*CertRecord, error) {
+func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) {
 	var records []*CertRecord
 	ms.Lock()
 	defer ms.Unlock()
+
 	for _, value := range ms.certs {
-		if value.Expires.After(time.Now().UTC()) {
-			records = append(records, value)
+		if !includeExpired && value.Expires.After(time.Now().UTC()) {
+			continue
 		}
+		records = append(records, value)
 	}
 	return records, nil
 }
@@ -58,11 +60,9 @@ func (ms *memoryStore) Revoke(id string) error {
 
 func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) {
 	var revoked []*CertRecord
-	all, _ := ms.List()
+	all, _ := ms.List(false)
 	for _, r := range all {
-		if r.Revoked && time.Now().UTC().Unix() <= r.Expires.UTC().Unix() {
-			revoked = append(revoked, r)
-		}
+		revoked = append(revoked, r)
 	}
 	return revoked, nil
 }
diff --git a/server/store/mongo.go b/server/store/mongo.go
index 79df69d5..8a3ccda3 100644
--- a/server/store/mongo.go
+++ b/server/store/mongo.go
@@ -67,12 +67,17 @@ func (m *mongoDB) SetRecord(record *CertRecord) error {
 	return m.collection.Insert(record)
 }
 
-func (m *mongoDB) List() ([]*CertRecord, error) {
+func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) {
 	if err := m.session.Ping(); err != nil {
 		return nil, err
 	}
 	var result []*CertRecord
-	err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result)
+	var err error
+	if includeExpired {
+		err = m.collection.Find(nil).All(&result)
+	} else {
+		err = m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result)
+	}
 	return result, err
 }
 
diff --git a/server/store/sqldb.go b/server/store/sqldb.go
index 54a52c63..81784b02 100644
--- a/server/store/sqldb.go
+++ b/server/store/sqldb.go
@@ -16,11 +16,12 @@ import (
 type sqldb struct {
 	conn *sql.DB
 
-	get     *sql.Stmt
-	set     *sql.Stmt
-	list    *sql.Stmt
-	revoke  *sql.Stmt
-	revoked *sql.Stmt
+	get         *sql.Stmt
+	set         *sql.Stmt
+	listAll     *sql.Stmt
+	listCurrent *sql.Stmt
+	revoke      *sql.Stmt
+	revoked     *sql.Stmt
 }
 
 func parse(config string) []string {
@@ -66,8 +67,11 @@ func NewSQLStore(config string) (CertStorer, error) {
 	if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
 		return nil, fmt.Errorf("sqldb: prepare get: %v", err)
 	}
-	if db.list, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
-		return nil, fmt.Errorf("sqldb: prepare list: %v", err)
+	if db.listAll, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil {
+		return nil, fmt.Errorf("sqldb: prepare listAll: %v", err)
+	}
+	if db.listCurrent, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
+		return nil, fmt.Errorf("sqldb: prepare listCurrent: %v", err)
 	}
 	if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
 		return nil, fmt.Errorf("sqldb: prepare revoke: %v", err)
@@ -132,12 +136,17 @@ func (db *sqldb) SetRecord(rec *CertRecord) error {
 	return err
 }
 
-func (db *sqldb) List() ([]*CertRecord, error) {
+func (db *sqldb) List(includeExpired bool) ([]*CertRecord, error) {
 	if err := db.conn.Ping(); err != nil {
 		return nil, err
 	}
 	var recs []*CertRecord
-	rows, _ := db.revoked.Query(time.Now().UTC())
+	var rows *sql.Rows
+	if includeExpired {
+		rows, _ = db.listAll.Query()
+	} else {
+		rows, _ = db.listCurrent.Query(time.Now().UTC())
+	}
 	defer rows.Close()
 	for rows.Next() {
 		cert, err := scanCert(rows)
diff --git a/server/store/store.go b/server/store/store.go
index f6ac66e1..a846bdaa 100644
--- a/server/store/store.go
+++ b/server/store/store.go
@@ -14,7 +14,7 @@ type CertStorer interface {
 	Get(id string) (*CertRecord, error)
 	SetCert(cert *ssh.Certificate) error
 	SetRecord(record *CertRecord) error
-	List() ([]*CertRecord, error)
+	List(includeExpired bool) ([]*CertRecord, error)
 	Revoke(id string) error
 	GetRevoked() ([]*CertRecord, error)
 	Close() error
@@ -22,12 +22,12 @@ type CertStorer interface {
 
 // A CertRecord is a representation of a ssh certificate used by a CertStorer.
 type CertRecord struct {
-	KeyID      string
-	Principals []string
-	CreatedAt  time.Time
-	Expires    time.Time
-	Revoked    bool
-	Raw        string
+	KeyID      string    `json:"key_id"`
+	Principals []string  `json:"principals"`
+	CreatedAt  time.Time `json:"created_at"`
+	Expires    time.Time `json:"expires"`
+	Revoked    bool      `json:"revoked"`
+	Raw        string    `json:"-"`
 }
 
 func parseTime(t uint64) time.Time {
-- 
GitLab