From d21fac6f190c1079ca247658530d465ad5867ff5 Mon Sep 17 00:00:00 2001
From: Niall Sheridan <nsheridan@gmail.com>
Date: Thu, 9 Aug 2018 20:47:50 +0100
Subject: [PATCH] Only request a reason from the client if the server requires
 it

---
 client/client.go        | 57 +++++++++++++++++++++++++++++++----------
 client/client_test.go   |  8 +++---
 cmd/cashier/main.go     | 10 +-------
 server/config/config.go |  1 +
 server/web.go           | 16 +++++++++---
 5 files changed, 61 insertions(+), 31 deletions(-)

diff --git a/client/client.go b/client/client.go
index 58cc6bb6..628783ab 100644
--- a/client/client.go
+++ b/client/client.go
@@ -1,6 +1,7 @@
 package client
 
 import (
+	"bufio"
 	"bytes"
 	"crypto/tls"
 	"encoding/base64"
@@ -10,7 +11,9 @@ import (
 	"io/ioutil"
 	"net/http"
 	"net/url"
+	"os"
 	"path"
+	"strings"
 	"time"
 
 	"github.com/nsheridan/cashier/lib"
@@ -19,6 +22,10 @@ import (
 	"golang.org/x/crypto/ssh/agent"
 )
 
+var (
+	errNeedsReason = errors.New("reason required")
+)
+
 // SavePublicFiles installs the public part of the cert and key.
 func SavePublicFiles(prefix string, cert *ssh.Certificate, pub ssh.PublicKey) error {
 	if prefix == "" {
@@ -77,7 +84,11 @@ func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error {
 }
 
 // send the signing request to the CA.
-func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) {
+func send(sr *lib.SignRequest, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) {
+	s, err := json.Marshal(sr)
+	if err != nil {
+		return nil, errors.Wrap(err, "unable to create sign request")
+	}
 	transport := &http.Transport{
 		TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate},
 	}
@@ -99,33 +110,51 @@ func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignRes
 		return nil, err
 	}
 	defer resp.Body.Close()
+	signResponse := &lib.SignResponse{}
 	if resp.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("Bad response from server: %s", resp.Status)
+		if resp.StatusCode == http.StatusForbidden && strings.HasPrefix(resp.Header.Get("X-Need-Reason"), "required") {
+			return signResponse, errNeedsReason
+		}
+		return signResponse, fmt.Errorf("bad response from server: %s", resp.Status)
 	}
-	c := &lib.SignResponse{}
-	if err := json.NewDecoder(resp.Body).Decode(c); err != nil {
+	if err := json.NewDecoder(resp.Body).Decode(signResponse); err != nil {
 		return nil, errors.Wrap(err, "unable to decode server response")
 	}
-	return c, nil
+	return signResponse, nil
+}
+
+func promptForReason() (message string) {
+	fmt.Print("Enter message: ")
+	scanner := bufio.NewScanner(os.Stdin)
+	if scanner.Scan() {
+		message = scanner.Text()
+	}
+	return message
 }
 
 // Sign sends the public key to the CA to be signed.
-func Sign(pub ssh.PublicKey, token string, message string, conf *Config) (*ssh.Certificate, error) {
+func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, error) {
+	var err error
 	validity, err := time.ParseDuration(conf.Validity)
 	if err != nil {
 		return nil, err
 	}
-	s, err := json.Marshal(&lib.SignRequest{
+	s := &lib.SignRequest{
 		Key:        string(lib.GetPublicKey(pub)),
 		ValidUntil: time.Now().Add(validity),
-		Message:    message,
-	})
-	if err != nil {
-		return nil, errors.Wrap(err, "unable to create sign request")
 	}
-	resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate)
-	if err != nil {
-		return nil, errors.Wrap(err, "error sending request to CA")
+	resp := &lib.SignResponse{}
+	for {
+		resp, err = send(s, token, conf.CA, conf.ValidateTLSCertificate)
+		if err == nil {
+			break
+		}
+		if err != nil && err == errNeedsReason {
+			s.Message = promptForReason()
+			continue
+		} else if err != nil {
+			return nil, errors.Wrap(err, "error sending request to CA")
+		}
 	}
 	if resp.Status != "ok" {
 		return nil, fmt.Errorf("bad response from CA: %s", resp.Response)
diff --git a/client/client_test.go b/client/client_test.go
index fddd543e..2447db3f 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -67,7 +67,7 @@ func TestSignGood(t *testing.T) {
 		fmt.Fprintln(w, string(j))
 	}))
 	defer ts.Close()
-	_, err := send([]byte(`{}`), "token", ts.URL, true)
+	_, err := send(&lib.SignRequest{}, "token", ts.URL, true)
 	if err != nil {
 		t.Error(err)
 	}
@@ -79,7 +79,7 @@ func TestSignGood(t *testing.T) {
 		CA:       ts.URL,
 		Validity: "24h",
 	}
-	cert, err := Sign(k, "token", "message", c)
+	cert, err := Sign(k, "token", c)
 	if cert == nil && err != nil {
 		t.Error(err)
 	}
@@ -95,7 +95,7 @@ func TestSignBad(t *testing.T) {
 		fmt.Fprintln(w, string(j))
 	}))
 	defer ts.Close()
-	_, err := send([]byte(`{}`), "token", ts.URL, true)
+	_, err := send(&lib.SignRequest{}, "token", ts.URL, true)
 	if err != nil {
 		t.Error(err)
 	}
@@ -107,7 +107,7 @@ func TestSignBad(t *testing.T) {
 		CA:       ts.URL,
 		Validity: "24h",
 	}
-	cert, err := Sign(k, "token", "message", c)
+	cert, err := Sign(k, "token", c)
 	if cert != nil && err == nil {
 		t.Error(err)
 	}
diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go
index 7054bef4..f448a252 100644
--- a/cmd/cashier/main.go
+++ b/cmd/cashier/main.go
@@ -1,7 +1,6 @@
 package main
 
 import (
-	"bufio"
 	"fmt"
 	"log"
 	"net"
@@ -50,14 +49,7 @@ func main() {
 	var token string
 	fmt.Scanln(&token)
 
-	var message string
-	fmt.Print("Enter message: ")
-	scanner := bufio.NewScanner(os.Stdin)
-	if scanner.Scan() {
-		message = scanner.Text()
-	}
-
-	cert, err := client.Sign(pub, token, message, c)
+	cert, err := client.Sign(pub, token, c)
 	if err != nil {
 		log.Fatalln(err)
 	}
diff --git a/server/config/config.go b/server/config/config.go
index 422a135e..19858006 100644
--- a/server/config/config.go
+++ b/server/config/config.go
@@ -37,6 +37,7 @@ type Server struct {
 	CSRFSecret            string   `hcl:"csrf_secret"`
 	HTTPLogFile           string   `hcl:"http_logfile"`
 	Database              Database `hcl:"database"`
+	RequireReason         bool     `hcl:"require_reason"`
 }
 
 // Auth holds the configuration specific to the OAuth provider.
diff --git a/server/web.go b/server/web.go
index 5677429b..e238150c 100644
--- a/server/web.go
+++ b/server/web.go
@@ -33,8 +33,9 @@ import (
 
 // appContext contains local context - cookiestore, authsession etc.
 type appContext struct {
-	cookiestore *sessions.CookieStore
-	authsession *auth.Session
+	cookiestore   *sessions.CookieStore
+	authsession   *auth.Session
+	requireReason bool
 }
 
 // getAuthTokenCookie retrieves a cookie from the request.
@@ -141,6 +142,12 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
 	if err != nil {
 		return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request")
 	}
+
+	if a.requireReason && req.Message == "" {
+		w.Header().Add("X-Need-Reason", "required")
+		return http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden))
+	}
+
 	username := authprovider.Username(token)
 	authprovider.Revoke(token) // We don't need this anymore.
 	cert, err := keysigner.SignUserKey(req, username)
@@ -266,7 +273,6 @@ type appHandler struct {
 func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	status, err := ah.h(ah.appContext, w, r)
 	if err != nil {
-		log.Printf("HTTP %d: %q", status, err)
 		http.Error(w, err.Error(), status)
 	}
 }
@@ -283,7 +289,8 @@ func newState() string {
 func runHTTPServer(conf *config.Server, l net.Listener) {
 	var err error
 	ctx := &appContext{
-		cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)),
+		cookiestore:   sessions.NewCookieStore([]byte(conf.CookieSecret)),
+		requireReason: conf.RequireReason,
 	}
 	ctx.cookiestore.Options = &sessions.Options{
 		MaxAge:   900,
@@ -313,6 +320,7 @@ func runHTTPServer(conf *config.Server, l net.Listener) {
 	r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler})
 	r.Methods("GET").Path("/metrics").Handler(promhttp.Handler())
 	r.Methods("GET").Path("/healthcheck").HandlerFunc(healthcheck)
+
 	box := packr.NewBox("static")
 	r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(box)))
 	h := handlers.LoggingHandler(logfile, r)
-- 
GitLab