From 49f40a952943f26494d6407dc608b50b2ec0df7f Mon Sep 17 00:00:00 2001
From: Niall Sheridan <nsheridan@gmail.com>
Date: Sun, 10 Jul 2016 22:35:13 +0100
Subject: [PATCH] Add some handlers tests

---
 cmd/cashierd/handlers_test.go            | 139 +++++++++++++++++++++++
 cmd/cashierd/main.go                     |  21 ++--
 server/auth/testprovider/testprovider.go |  56 +++++++++
 server/signer/signer.go                  |   2 +-
 4 files changed, 205 insertions(+), 13 deletions(-)
 create mode 100644 cmd/cashierd/handlers_test.go
 create mode 100644 server/auth/testprovider/testprovider.go

diff --git a/cmd/cashierd/handlers_test.go b/cmd/cashierd/handlers_test.go
new file mode 100644
index 00000000..a214dfdd
--- /dev/null
+++ b/cmd/cashierd/handlers_test.go
@@ -0,0 +1,139 @@
+package main
+
+import (
+	"bytes"
+	"encoding/json"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"os"
+	"strings"
+	"testing"
+	"time"
+
+	"golang.org/x/crypto/ssh"
+	"golang.org/x/oauth2"
+
+	"github.com/gorilla/sessions"
+	"github.com/nsheridan/cashier/lib"
+	"github.com/nsheridan/cashier/server/auth"
+	"github.com/nsheridan/cashier/server/auth/testprovider"
+	"github.com/nsheridan/cashier/server/config"
+	"github.com/nsheridan/cashier/server/signer"
+	"github.com/nsheridan/cashier/server/store"
+	"github.com/nsheridan/cashier/testdata"
+)
+
+func newContext(t *testing.T) *appContext {
+	f, err := ioutil.TempFile(os.TempDir(), "signing_key_")
+	if err != nil {
+		t.Error(err)
+	}
+	defer os.Remove(f.Name())
+	f.Write(testdata.Priv)
+	f.Close()
+	signer, err := signer.New(&config.SSH{
+		SigningKey: f.Name(),
+		MaxAge:     "1h",
+	})
+	if err != nil {
+		t.Error(err)
+	}
+	return &appContext{
+		cookiestore:  sessions.NewCookieStore([]byte("secret")),
+		authprovider: testprovider.New(),
+		certstore:    store.NewMemoryStore(),
+		authsession:  &auth.Session{AuthURL: "https://www.example.com/auth"},
+		sshKeySigner: signer,
+	}
+}
+
+func TestLoginHandler(t *testing.T) {
+	req, _ := http.NewRequest("GET", "/auth/login", nil)
+	resp := httptest.NewRecorder()
+	loginHandler(newContext(t), resp, req)
+	if resp.Code != http.StatusFound && resp.Header().Get("Location") != "https://www.example.com/auth" {
+		t.Error("Unexpected response")
+	}
+}
+
+func TestCallbackHandler(t *testing.T) {
+	req, _ := http.NewRequest("GET", "/auth/callback", nil)
+	req.Form = url.Values{"state": []string{"state"}, "code": []string{"abcdef"}}
+	resp := httptest.NewRecorder()
+	ctx := newContext(t)
+	ctx.setAuthStateCookie(resp, req, "state")
+	callbackHandler(ctx, resp, req)
+	if resp.Code != http.StatusFound && resp.Header().Get("Location") != "/" {
+		t.Error("Unexpected response")
+	}
+}
+
+func TestRootHandler(t *testing.T) {
+	req, _ := http.NewRequest("GET", "/", nil)
+	resp := httptest.NewRecorder()
+	ctx := newContext(t)
+	tok := &oauth2.Token{
+		AccessToken: "XXX_TEST_TOKEN_STRING_XXX",
+		Expiry:      time.Now().Add(1 * time.Hour),
+	}
+	ctx.setAuthTokenCookie(resp, req, tok)
+	rootHandler(ctx, resp, req)
+	if resp.Code != http.StatusOK && !strings.Contains(resp.Body.String(), "XXX_TEST_TOKEN_STRING_XXX") {
+		t.Error("Unable to find token in response")
+	}
+}
+
+func TestRootHandlerNoSession(t *testing.T) {
+	req, _ := http.NewRequest("GET", "/", nil)
+	resp := httptest.NewRecorder()
+	ctx := newContext(t)
+	rootHandler(ctx, resp, req)
+	if resp.Code != http.StatusSeeOther {
+		t.Errorf("Unexpected status: %s, wanted %s", http.StatusText(resp.Code), http.StatusText(http.StatusSeeOther))
+	}
+}
+
+func TestSignRevoke(t *testing.T) {
+	s, _ := json.Marshal(&lib.SignRequest{
+		Key: string(testdata.Pub),
+	})
+	req, _ := http.NewRequest("POST", "/sign", bytes.NewReader(s))
+	resp := httptest.NewRecorder()
+	ctx := newContext(t)
+	req.Header.Set("Authorization", "Bearer abcdef")
+	signHandler(ctx, resp, req)
+	if resp.Code != http.StatusOK {
+		t.Error("Unexpected response")
+	}
+	b, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Error(err)
+	}
+	r := &lib.SignResponse{}
+	if err := json.Unmarshal(b, r); err != nil {
+		t.Error(err)
+	}
+	if r.Status != "ok" {
+		t.Error("Unexpected response")
+	}
+	k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.Response))
+	if err != nil {
+		t.Error(err)
+	}
+	cert, ok := k.(*ssh.Certificate)
+	if !ok {
+		t.Error("Did not receive a certificate")
+	}
+	// Revoke the cert and verify
+	req, _ = http.NewRequest("POST", "/revoke", nil)
+	req.Form = url.Values{"cert_id": []string{cert.KeyId}}
+	revokeCertHandler(ctx, resp, req)
+	req, _ = http.NewRequest("GET", "/revoked", nil)
+	revokedCertsHandler(ctx, resp, req)
+	revoked, _ := ioutil.ReadAll(resp.Body)
+	if string(revoked[:len(revoked)-1]) != r.Response {
+		t.Error("omg")
+	}
+}
diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go
index 1db7d301..31ba104c 100644
--- a/cmd/cashierd/main.go
+++ b/cmd/cashierd/main.go
@@ -123,11 +123,11 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
 
 	// Sign the pubkey and issue the cert.
 	req, err := parseKey(r)
-	req.Principal = a.authprovider.Username(token)
-	a.authprovider.Revoke(token) // We don't need this anymore.
 	if err != nil {
 		return http.StatusInternalServerError, err
 	}
+	req.Principal = a.authprovider.Username(token)
+	a.authprovider.Revoke(token) // We don't need this anymore.
 	cert, err := a.sshKeySigner.SignUserKey(req)
 	if err != nil {
 		return http.StatusInternalServerError, err
@@ -199,9 +199,6 @@ func revokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request)
 }
 
 func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
-	if r.Method == "GET" {
-		return http.StatusMethodNotAllowed, errors.New(http.StatusText(http.StatusMethodNotAllowed))
-	}
 	r.ParseForm()
 	id := r.FormValue("cert_id")
 	if id == "" {
@@ -268,7 +265,7 @@ func main() {
 		log.Fatal(err)
 	}
 	fs.Register(&config.AWS)
-	signer, err := signer.New(config.SSH)
+	signer, err := signer.New(&config.SSH)
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -304,12 +301,12 @@ func main() {
 	}
 
 	r := mux.NewRouter()
-	r.Handle("/", appHandler{ctx, rootHandler})
-	r.Handle("/auth/login", appHandler{ctx, loginHandler})
-	r.Handle("/auth/callback", appHandler{ctx, callbackHandler})
-	r.Handle("/sign", appHandler{ctx, signHandler})
-	r.Handle("/revoked", appHandler{ctx, revokedCertsHandler})
-	r.Handle("/revoke", appHandler{ctx, revokeCertHandler})
+	r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler})
+	r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler})
+	r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler})
+	r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler})
+	r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, revokedCertsHandler})
+	r.Methods("POST").Path("/revoke").Handler(appHandler{ctx, revokeCertHandler})
 	logfile := os.Stderr
 	if config.Server.HTTPLogFile != "" {
 		logfile, err = os.OpenFile(config.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660)
diff --git a/server/auth/testprovider/testprovider.go b/server/auth/testprovider/testprovider.go
new file mode 100644
index 00000000..3d2b13a4
--- /dev/null
+++ b/server/auth/testprovider/testprovider.go
@@ -0,0 +1,56 @@
+package testprovider
+
+import (
+	"time"
+
+	"github.com/nsheridan/cashier/server/auth"
+
+	"golang.org/x/oauth2"
+)
+
+const (
+	name = "testprovider"
+)
+
+// Config is an implementation of `auth.Provider` for testing.
+type Config struct{}
+
+// New creates a new provider.
+func New() auth.Provider {
+	return &Config{}
+}
+
+// Name returns the name of the provider.
+func (c *Config) Name() string {
+	return name
+}
+
+// Valid validates the oauth token.
+func (c *Config) Valid(token *oauth2.Token) bool {
+	return true
+}
+
+// Revoke disables the access token.
+func (c *Config) Revoke(token *oauth2.Token) error {
+	return nil
+}
+
+// StartSession retrieves an authentication endpoint.
+func (c *Config) StartSession(state string) *auth.Session {
+	return &auth.Session{
+		AuthURL: "https://www.example.com/auth",
+	}
+}
+
+// Exchange authorizes the session and returns an access token.
+func (c *Config) Exchange(code string) (*oauth2.Token, error) {
+	return &oauth2.Token{
+		AccessToken: "token",
+		Expiry:      time.Now().Add(1 * time.Hour),
+	}, nil
+}
+
+// Username retrieves the username portion of the user's email address.
+func (c *Config) Username(token *oauth2.Token) string {
+	return "test"
+}
diff --git a/server/signer/signer.go b/server/signer/signer.go
index a3f056aa..8169c11b 100644
--- a/server/signer/signer.go
+++ b/server/signer/signer.go
@@ -69,7 +69,7 @@ func makeperms(perms []string) map[string]string {
 }
 
 // New creates a new KeySigner from the supplied configuration.
-func New(conf config.SSH) (*KeySigner, error) {
+func New(conf *config.SSH) (*KeySigner, error) {
 	data, err := wkfs.ReadFile(conf.SigningKey)
 	if err != nil {
 		return nil, fmt.Errorf("unable to read CA key %s: %v", conf.SigningKey, err)
-- 
GitLab