From 8ee3c6473f3e2373303b9cb16ab5f059f9e6369e Mon Sep 17 00:00:00 2001
From: Niall Sheridan <nsheridan@gmail.com>
Date: Sat, 15 Apr 2017 18:28:23 +0100
Subject: [PATCH] Revoke multiple certs in a single call

---
 .gitignore                 |  1 +
 server/store/mem.go        | 11 +++++------
 server/store/sqldb.go      | 16 ++++++----------
 server/store/store.go      |  2 +-
 server/store/store_test.go |  2 +-
 server/web.go              |  6 ++----
 6 files changed, 16 insertions(+), 22 deletions(-)

diff --git a/.gitignore b/.gitignore
index 19418512..23800a45 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ tmp
 signing_key*
 http.log
 .idea
+.DS_Store
diff --git a/server/store/mem.go b/server/store/mem.go
index e289b161..68c5a13b 100644
--- a/server/store/mem.go
+++ b/server/store/mem.go
@@ -57,13 +57,12 @@ func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) {
 }
 
 // Revoke an issued cert by id.
-func (ms *MemoryStore) Revoke(id string) error {
-	r, err := ms.Get(id)
-	if err != nil {
-		return err
+func (ms *MemoryStore) Revoke(ids []string) error {
+	ms.Lock()
+	defer ms.Unlock()
+	for _, id := range ids {
+		ms.certs[id].Revoked = true
 	}
-	r.Revoked = true
-	ms.SetRecord(r)
 	return nil
 }
 
diff --git a/server/store/sqldb.go b/server/store/sqldb.go
index bdb88935..c5a0f4ec 100644
--- a/server/store/sqldb.go
+++ b/server/store/sqldb.go
@@ -10,6 +10,7 @@ import (
 	"github.com/go-sql-driver/mysql"
 	"github.com/jmoiron/sqlx"
 	"github.com/nsheridan/cashier/server/config"
+	"github.com/pkg/errors"
 )
 
 var _ CertStorer = (*SQLStore)(nil)
@@ -22,7 +23,6 @@ type SQLStore struct {
 	set         *sqlx.Stmt
 	listAll     *sqlx.Stmt
 	listCurrent *sqlx.Stmt
-	revoke      *sqlx.Stmt
 	revoked     *sqlx.Stmt
 }
 
@@ -76,9 +76,6 @@ func NewSQLStore(c config.Database) (*SQLStore, error) {
 	if db.listCurrent, err = conn.Preparex("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
 		return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err)
 	}
-	if db.revoke, err = conn.Preparex("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
-		return nil, fmt.Errorf("SQLStore: prepare revoke: %v", err)
-	}
 	if db.revoked, err = conn.Preparex("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil {
 		return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err)
 	}
@@ -133,14 +130,13 @@ func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) {
 }
 
 // Revoke an issued cert by id.
-func (db *SQLStore) Revoke(id string) error {
+func (db *SQLStore) Revoke(ids []string) error {
 	if err := db.conn.Ping(); err != nil {
-		return err
+		return errors.Wrap(err, "unable to connect to database")
 	}
-	if _, err := db.revoke.Exec(id); err != nil {
-		return err
-	}
-	return nil
+	q, args, err := sqlx.In("UPDATE issued_certs SET revoked = 1 WHERE key_id IN (?)", ids)
+	_, err = db.conn.Query(q, args...)
+	return err
 }
 
 // GetRevoked returns all revoked certs
diff --git a/server/store/store.go b/server/store/store.go
index b200e814..4edb4461 100644
--- a/server/store/store.go
+++ b/server/store/store.go
@@ -29,7 +29,7 @@ type CertStorer interface {
 	SetCert(cert *ssh.Certificate) error
 	SetRecord(record *CertRecord) error
 	List(includeExpired bool) ([]*CertRecord, error)
-	Revoke(id string) error
+	Revoke(id []string) error
 	GetRevoked() ([]*CertRecord, error)
 	Close() error
 }
diff --git a/server/store/store_test.go b/server/store/store_test.go
index 9a8a4bef..d18d02b3 100644
--- a/server/store/store_test.go
+++ b/server/store/store_test.go
@@ -87,7 +87,7 @@ func testStore(t *testing.T, db CertStorer) {
 	if ret.KeyID != cert.KeyId {
 		t.Error("key mismatch")
 	}
-	if err := db.Revoke("key"); err != nil {
+	if err := db.Revoke([]string{"key"}); err != nil {
 		t.Error(err)
 	}
 
diff --git a/server/web.go b/server/web.go
index edaa394c..08162d51 100644
--- a/server/web.go
+++ b/server/web.go
@@ -240,10 +240,8 @@ func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (i
 		return a.login(w, r)
 	}
 	r.ParseForm()
-	for _, id := range r.Form["cert_id"] {
-		if err := certstore.Revoke(id); err != nil {
-			return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke")
-		}
+	if err := certstore.Revoke(r.Form["cert_id"]); err != nil {
+		return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke certs")
 	}
 	http.Redirect(w, r, "/admin/certs", http.StatusSeeOther)
 	return http.StatusSeeOther, nil
-- 
GitLab