From a6e42d899cde380f513710d07787ba11dfbe229a Mon Sep 17 00:00:00 2001
From: Niall Sheridan <nsheridan@gmail.com>
Date: Sun, 7 Aug 2016 17:55:05 +0100
Subject: [PATCH] Ping the db before attempting to query it

---
 server/store/mongo.go | 33 +++++++++++++++++++++++++--------
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/server/store/mongo.go b/server/store/mongo.go
index 9773da77..c0561719 100644
--- a/server/store/mongo.go
+++ b/server/store/mongo.go
@@ -16,7 +16,8 @@ var (
 )
 
 type mongoDB struct {
-	conn *mgo.Collection
+	collection *mgo.Collection
+	session    *mgo.Session
 }
 
 func parseMongoConfig(config string) *mgo.DialInfo {
@@ -40,13 +41,17 @@ func NewMongoStore(config string) (CertStorer, error) {
 	}
 	c := session.DB(certsDB).C(issuedTable)
 	return &mongoDB{
-		conn: c,
+		collection: c,
+		session:    session,
 	}, nil
 }
 
 func (m *mongoDB) Get(id string) (*CertRecord, error) {
+	if err := m.session.Ping(); err != nil {
+		return nil, err
+	}
 	c := &CertRecord{}
-	err := m.conn.Find(bson.M{"keyid": id}).One(c)
+	err := m.collection.Find(bson.M{"keyid": id}).One(c)
 	return c, err
 }
 
@@ -56,26 +61,38 @@ func (m *mongoDB) SetCert(cert *ssh.Certificate) error {
 }
 
 func (m *mongoDB) SetRecord(record *CertRecord) error {
-	return m.conn.Insert(record)
+	if err := m.session.Ping(); err != nil {
+		return err
+	}
+	return m.collection.Insert(record)
 }
 
 func (m *mongoDB) List() ([]*CertRecord, error) {
+	if err := m.session.Ping(); err != nil {
+		return nil, err
+	}
 	var result []*CertRecord
-	m.conn.Find(nil).All(&result)
+	m.collection.Find(nil).All(&result)
 	return result, nil
 }
 
 func (m *mongoDB) Revoke(id string) error {
-	return m.conn.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}})
+	if err := m.session.Ping(); err != nil {
+		return err
+	}
+	return m.collection.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}})
 }
 
 func (m *mongoDB) GetRevoked() ([]*CertRecord, error) {
+	if err := m.session.Ping(); err != nil {
+		return nil, err
+	}
 	var result []*CertRecord
-	err := m.conn.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result)
+	err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result)
 	return result, err
 }
 
 func (m *mongoDB) Close() error {
-	m.conn.Database.Session.Close()
+	m.session.Close()
 	return nil
 }
-- 
GitLab