Skip to content
Snippets Groups Projects
Commit d21fac6f authored by Niall Sheridan's avatar Niall Sheridan
Browse files

Only request a reason from the client if the server requires it

parent 347c11ec
No related branches found
No related tags found
No related merge requests found
package client package client
import ( import (
"bufio"
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
...@@ -10,7 +11,9 @@ import ( ...@@ -10,7 +11,9 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os"
"path" "path"
"strings"
"time" "time"
"github.com/nsheridan/cashier/lib" "github.com/nsheridan/cashier/lib"
...@@ -19,6 +22,10 @@ import ( ...@@ -19,6 +22,10 @@ import (
"golang.org/x/crypto/ssh/agent" "golang.org/x/crypto/ssh/agent"
) )
var (
errNeedsReason = errors.New("reason required")
)
// SavePublicFiles installs the public part of the cert and key. // SavePublicFiles installs the public part of the cert and key.
func SavePublicFiles(prefix string, cert *ssh.Certificate, pub ssh.PublicKey) error { func SavePublicFiles(prefix string, cert *ssh.Certificate, pub ssh.PublicKey) error {
if prefix == "" { if prefix == "" {
...@@ -77,7 +84,11 @@ func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error { ...@@ -77,7 +84,11 @@ func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error {
} }
// send the signing request to the CA. // send the signing request to the CA.
func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) { func send(sr *lib.SignRequest, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) {
s, err := json.Marshal(sr)
if err != nil {
return nil, errors.Wrap(err, "unable to create sign request")
}
transport := &http.Transport{ transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate}, TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate},
} }
...@@ -99,34 +110,52 @@ func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignRes ...@@ -99,34 +110,52 @@ func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignRes
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
signResponse := &lib.SignResponse{}
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Bad response from server: %s", resp.Status) if resp.StatusCode == http.StatusForbidden && strings.HasPrefix(resp.Header.Get("X-Need-Reason"), "required") {
return signResponse, errNeedsReason
}
return signResponse, fmt.Errorf("bad response from server: %s", resp.Status)
} }
c := &lib.SignResponse{} if err := json.NewDecoder(resp.Body).Decode(signResponse); err != nil {
if err := json.NewDecoder(resp.Body).Decode(c); err != nil {
return nil, errors.Wrap(err, "unable to decode server response") return nil, errors.Wrap(err, "unable to decode server response")
} }
return c, nil return signResponse, nil
}
func promptForReason() (message string) {
fmt.Print("Enter message: ")
scanner := bufio.NewScanner(os.Stdin)
if scanner.Scan() {
message = scanner.Text()
}
return message
} }
// Sign sends the public key to the CA to be signed. // Sign sends the public key to the CA to be signed.
func Sign(pub ssh.PublicKey, token string, message string, conf *Config) (*ssh.Certificate, error) { func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, error) {
var err error
validity, err := time.ParseDuration(conf.Validity) validity, err := time.ParseDuration(conf.Validity)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s, err := json.Marshal(&lib.SignRequest{ s := &lib.SignRequest{
Key: string(lib.GetPublicKey(pub)), Key: string(lib.GetPublicKey(pub)),
ValidUntil: time.Now().Add(validity), ValidUntil: time.Now().Add(validity),
Message: message,
})
if err != nil {
return nil, errors.Wrap(err, "unable to create sign request")
} }
resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate) resp := &lib.SignResponse{}
if err != nil { for {
resp, err = send(s, token, conf.CA, conf.ValidateTLSCertificate)
if err == nil {
break
}
if err != nil && err == errNeedsReason {
s.Message = promptForReason()
continue
} else if err != nil {
return nil, errors.Wrap(err, "error sending request to CA") return nil, errors.Wrap(err, "error sending request to CA")
} }
}
if resp.Status != "ok" { if resp.Status != "ok" {
return nil, fmt.Errorf("bad response from CA: %s", resp.Response) return nil, fmt.Errorf("bad response from CA: %s", resp.Response)
} }
......
...@@ -67,7 +67,7 @@ func TestSignGood(t *testing.T) { ...@@ -67,7 +67,7 @@ func TestSignGood(t *testing.T) {
fmt.Fprintln(w, string(j)) fmt.Fprintln(w, string(j))
})) }))
defer ts.Close() defer ts.Close()
_, err := send([]byte(`{}`), "token", ts.URL, true) _, err := send(&lib.SignRequest{}, "token", ts.URL, true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -79,7 +79,7 @@ func TestSignGood(t *testing.T) { ...@@ -79,7 +79,7 @@ func TestSignGood(t *testing.T) {
CA: ts.URL, CA: ts.URL,
Validity: "24h", Validity: "24h",
} }
cert, err := Sign(k, "token", "message", c) cert, err := Sign(k, "token", c)
if cert == nil && err != nil { if cert == nil && err != nil {
t.Error(err) t.Error(err)
} }
...@@ -95,7 +95,7 @@ func TestSignBad(t *testing.T) { ...@@ -95,7 +95,7 @@ func TestSignBad(t *testing.T) {
fmt.Fprintln(w, string(j)) fmt.Fprintln(w, string(j))
})) }))
defer ts.Close() defer ts.Close()
_, err := send([]byte(`{}`), "token", ts.URL, true) _, err := send(&lib.SignRequest{}, "token", ts.URL, true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -107,7 +107,7 @@ func TestSignBad(t *testing.T) { ...@@ -107,7 +107,7 @@ func TestSignBad(t *testing.T) {
CA: ts.URL, CA: ts.URL,
Validity: "24h", Validity: "24h",
} }
cert, err := Sign(k, "token", "message", c) cert, err := Sign(k, "token", c)
if cert != nil && err == nil { if cert != nil && err == nil {
t.Error(err) t.Error(err)
} }
......
package main package main
import ( import (
"bufio"
"fmt" "fmt"
"log" "log"
"net" "net"
...@@ -50,14 +49,7 @@ func main() { ...@@ -50,14 +49,7 @@ func main() {
var token string var token string
fmt.Scanln(&token) fmt.Scanln(&token)
var message string cert, err := client.Sign(pub, token, c)
fmt.Print("Enter message: ")
scanner := bufio.NewScanner(os.Stdin)
if scanner.Scan() {
message = scanner.Text()
}
cert, err := client.Sign(pub, token, message, c)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
} }
......
...@@ -37,6 +37,7 @@ type Server struct { ...@@ -37,6 +37,7 @@ type Server struct {
CSRFSecret string `hcl:"csrf_secret"` CSRFSecret string `hcl:"csrf_secret"`
HTTPLogFile string `hcl:"http_logfile"` HTTPLogFile string `hcl:"http_logfile"`
Database Database `hcl:"database"` Database Database `hcl:"database"`
RequireReason bool `hcl:"require_reason"`
} }
// Auth holds the configuration specific to the OAuth provider. // Auth holds the configuration specific to the OAuth provider.
......
...@@ -35,6 +35,7 @@ import ( ...@@ -35,6 +35,7 @@ import (
type appContext struct { type appContext struct {
cookiestore *sessions.CookieStore cookiestore *sessions.CookieStore
authsession *auth.Session authsession *auth.Session
requireReason bool
} }
// getAuthTokenCookie retrieves a cookie from the request. // getAuthTokenCookie retrieves a cookie from the request.
...@@ -141,6 +142,12 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er ...@@ -141,6 +142,12 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
if err != nil { if err != nil {
return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request") return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request")
} }
if a.requireReason && req.Message == "" {
w.Header().Add("X-Need-Reason", "required")
return http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden))
}
username := authprovider.Username(token) username := authprovider.Username(token)
authprovider.Revoke(token) // We don't need this anymore. authprovider.Revoke(token) // We don't need this anymore.
cert, err := keysigner.SignUserKey(req, username) cert, err := keysigner.SignUserKey(req, username)
...@@ -266,7 +273,6 @@ type appHandler struct { ...@@ -266,7 +273,6 @@ type appHandler struct {
func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
status, err := ah.h(ah.appContext, w, r) status, err := ah.h(ah.appContext, w, r)
if err != nil { if err != nil {
log.Printf("HTTP %d: %q", status, err)
http.Error(w, err.Error(), status) http.Error(w, err.Error(), status)
} }
} }
...@@ -284,6 +290,7 @@ func runHTTPServer(conf *config.Server, l net.Listener) { ...@@ -284,6 +290,7 @@ func runHTTPServer(conf *config.Server, l net.Listener) {
var err error var err error
ctx := &appContext{ ctx := &appContext{
cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)), cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)),
requireReason: conf.RequireReason,
} }
ctx.cookiestore.Options = &sessions.Options{ ctx.cookiestore.Options = &sessions.Options{
MaxAge: 900, MaxAge: 900,
...@@ -313,6 +320,7 @@ func runHTTPServer(conf *config.Server, l net.Listener) { ...@@ -313,6 +320,7 @@ func runHTTPServer(conf *config.Server, l net.Listener) {
r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler}) r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler})
r.Methods("GET").Path("/metrics").Handler(promhttp.Handler()) r.Methods("GET").Path("/metrics").Handler(promhttp.Handler())
r.Methods("GET").Path("/healthcheck").HandlerFunc(healthcheck) r.Methods("GET").Path("/healthcheck").HandlerFunc(healthcheck)
box := packr.NewBox("static") box := packr.NewBox("static")
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(box))) r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(box)))
h := handlers.LoggingHandler(logfile, r) h := handlers.LoggingHandler(logfile, r)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment