From 8b63d26271a9250e262962c2998a36bae36f8d20 Mon Sep 17 00:00:00 2001
From: Niall Sheridan <nsheridan@gmail.com>
Date: Fri, 7 Apr 2017 19:16:58 +0100
Subject: [PATCH] fix behaviour of SQLStore.List

---
 server/store/sqldb.go      | 10 ++++++++--
 server/store/store_test.go | 24 +++++++++++++++++++++---
 2 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/server/store/sqldb.go b/server/store/sqldb.go
index 2efca0e0..bdb88935 100644
--- a/server/store/sqldb.go
+++ b/server/store/sqldb.go
@@ -120,8 +120,14 @@ func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) {
 		return nil, err
 	}
 	recs := []*CertRecord{}
-	if err := db.listAll.Select(&recs); err != nil {
-		return nil, err
+	if includeExpired {
+		if err := db.listAll.Select(&recs); err != nil {
+			return nil, err
+		}
+	} else {
+		if err := db.listCurrent.Select(&recs, time.Now()); err != nil {
+			return nil, err
+		}
 	}
 	return recs, nil
 }
diff --git a/server/store/store_test.go b/server/store/store_test.go
index e8bebafa..cd58ccd7 100644
--- a/server/store/store_test.go
+++ b/server/store/store_test.go
@@ -48,15 +48,29 @@ func testStore(t *testing.T, db CertStorer) {
 		KeyID:      "a",
 		Principals: []string{"b"},
 		CreatedAt:  time.Now().UTC(),
-		Expires:    time.Now().UTC().Add(1 * time.Minute),
+		Expires:    time.Now().UTC().Add(-1 * time.Second),
 		Raw:        "AAAAAA",
 	}
 	if err := db.SetRecord(r); err != nil {
 		t.Error(err)
 	}
-	if _, err := db.List(true); err != nil {
+
+	// includeExpired = false should return 0 results
+	recs, err := db.List(false)
+	if err != nil {
 		t.Error(err)
 	}
+	if len(recs) > 0 {
+		t.Errorf("Expected 0 results, got %d", len(recs))
+	}
+	// includeExpired = false should return 1 result
+	recs, err = db.List(true)
+	if err != nil {
+		t.Error(err)
+	}
+	if recs[0].KeyID != r.KeyID {
+		t.Error("key mismatch")
+	}
 
 	c, _, _, _, _ := ssh.ParseAuthorizedKey(testdata.Cert)
 	cert := c.(*ssh.Certificate)
@@ -66,9 +80,13 @@ func testStore(t *testing.T, db CertStorer) {
 		t.Error(err)
 	}
 
-	if _, err := db.Get("key"); err != nil {
+	ret, err := db.Get("key")
+	if err != nil {
 		t.Error(err)
 	}
+	if ret.KeyID != cert.KeyId {
+		t.Error("key mismatch")
+	}
 	if err := db.Revoke("key"); err != nil {
 		t.Error(err)
 	}
-- 
GitLab