From 0ed23b71115ad2213bf7ea545f9f765052008872 Mon Sep 17 00:00:00 2001
From: Niall Sheridan <nsheridan@gmail.com>
Date: Sat, 24 Sep 2016 23:25:06 +0100
Subject: [PATCH] Use a new session for each request

---
 server/store/mongo.go | 51 +++++++++++++++++++++++++++----------------
 1 file changed, 32 insertions(+), 19 deletions(-)

diff --git a/server/store/mongo.go b/server/store/mongo.go
index 8a3ccda3..1b13d7a8 100644
--- a/server/store/mongo.go
+++ b/server/store/mongo.go
@@ -15,11 +15,6 @@ var (
 	issuedTable = "issued_certs"
 )
 
-type mongoDB struct {
-	collection *mgo.Collection
-	session    *mgo.Session
-}
-
 func parseMongoConfig(config string) *mgo.DialInfo {
 	s := strings.SplitN(config, ":", 4)
 	_, user, passwd, hosts := s[0], s[1], s[2], s[3]
@@ -33,25 +28,33 @@ func parseMongoConfig(config string) *mgo.DialInfo {
 	return d
 }
 
+func collection(session *mgo.Session) *mgo.Collection {
+	return session.DB(certsDB).C(issuedTable)
+}
+
 // NewMongoStore returns a MongoDB CertStorer.
 func NewMongoStore(config string) (CertStorer, error) {
 	session, err := mgo.DialWithInfo(parseMongoConfig(config))
 	if err != nil {
 		return nil, err
 	}
-	c := session.DB(certsDB).C(issuedTable)
 	return &mongoDB{
-		collection: c,
-		session:    session,
+		session: session,
 	}, nil
 }
 
+type mongoDB struct {
+	session *mgo.Session
+}
+
 func (m *mongoDB) Get(id string) (*CertRecord, error) {
-	if err := m.session.Ping(); err != nil {
+	s := m.session.Copy()
+	defer s.Close()
+	if err := s.Ping(); err != nil {
 		return nil, err
 	}
 	c := &CertRecord{}
-	err := m.collection.Find(bson.M{"keyid": id}).One(c)
+	err := collection(s).Find(bson.M{"keyid": id}).One(c)
 	return c, err
 }
 
@@ -61,39 +64,49 @@ func (m *mongoDB) SetCert(cert *ssh.Certificate) error {
 }
 
 func (m *mongoDB) SetRecord(record *CertRecord) error {
-	if err := m.session.Ping(); err != nil {
+	s := m.session.Copy()
+	defer s.Close()
+	if err := s.Ping(); err != nil {
 		return err
 	}
-	return m.collection.Insert(record)
+	return collection(s).Insert(record)
 }
 
 func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) {
-	if err := m.session.Ping(); err != nil {
+	s := m.session.Copy()
+	defer s.Close()
+	if err := s.Ping(); err != nil {
 		return nil, err
 	}
 	var result []*CertRecord
 	var err error
+	c := collection(s)
 	if includeExpired {
-		err = m.collection.Find(nil).All(&result)
+		err = c.Find(nil).All(&result)
 	} else {
-		err = m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result)
+		err = c.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result)
 	}
 	return result, err
 }
 
 func (m *mongoDB) Revoke(id string) error {
-	if err := m.session.Ping(); err != nil {
+	s := m.session.Copy()
+	defer s.Close()
+	if err := s.Ping(); err != nil {
 		return err
 	}
-	return m.collection.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}})
+	c := collection(s)
+	return c.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}})
 }
 
 func (m *mongoDB) GetRevoked() ([]*CertRecord, error) {
-	if err := m.session.Ping(); err != nil {
+	s := m.session.Copy()
+	defer s.Close()
+	if err := s.Ping(); err != nil {
 		return nil, err
 	}
 	var result []*CertRecord
-	err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result)
+	err := collection(s).Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result)
 	return result, err
 }
 
-- 
GitLab