Skip to content
Snippets Groups Projects
Commit 49f40a95 authored by Niall Sheridan's avatar Niall Sheridan
Browse files

Add some handlers tests

parent dee5a19d
No related branches found
No related tags found
No related merge requests found
package main
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/oauth2"
"github.com/gorilla/sessions"
"github.com/nsheridan/cashier/lib"
"github.com/nsheridan/cashier/server/auth"
"github.com/nsheridan/cashier/server/auth/testprovider"
"github.com/nsheridan/cashier/server/config"
"github.com/nsheridan/cashier/server/signer"
"github.com/nsheridan/cashier/server/store"
"github.com/nsheridan/cashier/testdata"
)
func newContext(t *testing.T) *appContext {
f, err := ioutil.TempFile(os.TempDir(), "signing_key_")
if err != nil {
t.Error(err)
}
defer os.Remove(f.Name())
f.Write(testdata.Priv)
f.Close()
signer, err := signer.New(&config.SSH{
SigningKey: f.Name(),
MaxAge: "1h",
})
if err != nil {
t.Error(err)
}
return &appContext{
cookiestore: sessions.NewCookieStore([]byte("secret")),
authprovider: testprovider.New(),
certstore: store.NewMemoryStore(),
authsession: &auth.Session{AuthURL: "https://www.example.com/auth"},
sshKeySigner: signer,
}
}
func TestLoginHandler(t *testing.T) {
req, _ := http.NewRequest("GET", "/auth/login", nil)
resp := httptest.NewRecorder()
loginHandler(newContext(t), resp, req)
if resp.Code != http.StatusFound && resp.Header().Get("Location") != "https://www.example.com/auth" {
t.Error("Unexpected response")
}
}
func TestCallbackHandler(t *testing.T) {
req, _ := http.NewRequest("GET", "/auth/callback", nil)
req.Form = url.Values{"state": []string{"state"}, "code": []string{"abcdef"}}
resp := httptest.NewRecorder()
ctx := newContext(t)
ctx.setAuthStateCookie(resp, req, "state")
callbackHandler(ctx, resp, req)
if resp.Code != http.StatusFound && resp.Header().Get("Location") != "/" {
t.Error("Unexpected response")
}
}
func TestRootHandler(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
ctx := newContext(t)
tok := &oauth2.Token{
AccessToken: "XXX_TEST_TOKEN_STRING_XXX",
Expiry: time.Now().Add(1 * time.Hour),
}
ctx.setAuthTokenCookie(resp, req, tok)
rootHandler(ctx, resp, req)
if resp.Code != http.StatusOK && !strings.Contains(resp.Body.String(), "XXX_TEST_TOKEN_STRING_XXX") {
t.Error("Unable to find token in response")
}
}
func TestRootHandlerNoSession(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
ctx := newContext(t)
rootHandler(ctx, resp, req)
if resp.Code != http.StatusSeeOther {
t.Errorf("Unexpected status: %s, wanted %s", http.StatusText(resp.Code), http.StatusText(http.StatusSeeOther))
}
}
func TestSignRevoke(t *testing.T) {
s, _ := json.Marshal(&lib.SignRequest{
Key: string(testdata.Pub),
})
req, _ := http.NewRequest("POST", "/sign", bytes.NewReader(s))
resp := httptest.NewRecorder()
ctx := newContext(t)
req.Header.Set("Authorization", "Bearer abcdef")
signHandler(ctx, resp, req)
if resp.Code != http.StatusOK {
t.Error("Unexpected response")
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Error(err)
}
r := &lib.SignResponse{}
if err := json.Unmarshal(b, r); err != nil {
t.Error(err)
}
if r.Status != "ok" {
t.Error("Unexpected response")
}
k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.Response))
if err != nil {
t.Error(err)
}
cert, ok := k.(*ssh.Certificate)
if !ok {
t.Error("Did not receive a certificate")
}
// Revoke the cert and verify
req, _ = http.NewRequest("POST", "/revoke", nil)
req.Form = url.Values{"cert_id": []string{cert.KeyId}}
revokeCertHandler(ctx, resp, req)
req, _ = http.NewRequest("GET", "/revoked", nil)
revokedCertsHandler(ctx, resp, req)
revoked, _ := ioutil.ReadAll(resp.Body)
if string(revoked[:len(revoked)-1]) != r.Response {
t.Error("omg")
}
}
......@@ -123,11 +123,11 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
// Sign the pubkey and issue the cert.
req, err := parseKey(r)
req.Principal = a.authprovider.Username(token)
a.authprovider.Revoke(token) // We don't need this anymore.
if err != nil {
return http.StatusInternalServerError, err
}
req.Principal = a.authprovider.Username(token)
a.authprovider.Revoke(token) // We don't need this anymore.
cert, err := a.sshKeySigner.SignUserKey(req)
if err != nil {
return http.StatusInternalServerError, err
......@@ -199,9 +199,6 @@ func revokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request)
}
func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
if r.Method == "GET" {
return http.StatusMethodNotAllowed, errors.New(http.StatusText(http.StatusMethodNotAllowed))
}
r.ParseForm()
id := r.FormValue("cert_id")
if id == "" {
......@@ -268,7 +265,7 @@ func main() {
log.Fatal(err)
}
fs.Register(&config.AWS)
signer, err := signer.New(config.SSH)
signer, err := signer.New(&config.SSH)
if err != nil {
log.Fatal(err)
}
......@@ -304,12 +301,12 @@ func main() {
}
r := mux.NewRouter()
r.Handle("/", appHandler{ctx, rootHandler})
r.Handle("/auth/login", appHandler{ctx, loginHandler})
r.Handle("/auth/callback", appHandler{ctx, callbackHandler})
r.Handle("/sign", appHandler{ctx, signHandler})
r.Handle("/revoked", appHandler{ctx, revokedCertsHandler})
r.Handle("/revoke", appHandler{ctx, revokeCertHandler})
r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler})
r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler})
r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler})
r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler})
r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, revokedCertsHandler})
r.Methods("POST").Path("/revoke").Handler(appHandler{ctx, revokeCertHandler})
logfile := os.Stderr
if config.Server.HTTPLogFile != "" {
logfile, err = os.OpenFile(config.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660)
......
......
package testprovider
import (
"time"
"github.com/nsheridan/cashier/server/auth"
"golang.org/x/oauth2"
)
const (
name = "testprovider"
)
// Config is an implementation of `auth.Provider` for testing.
type Config struct{}
// New creates a new provider.
func New() auth.Provider {
return &Config{}
}
// Name returns the name of the provider.
func (c *Config) Name() string {
return name
}
// Valid validates the oauth token.
func (c *Config) Valid(token *oauth2.Token) bool {
return true
}
// Revoke disables the access token.
func (c *Config) Revoke(token *oauth2.Token) error {
return nil
}
// StartSession retrieves an authentication endpoint.
func (c *Config) StartSession(state string) *auth.Session {
return &auth.Session{
AuthURL: "https://www.example.com/auth",
}
}
// Exchange authorizes the session and returns an access token.
func (c *Config) Exchange(code string) (*oauth2.Token, error) {
return &oauth2.Token{
AccessToken: "token",
Expiry: time.Now().Add(1 * time.Hour),
}, nil
}
// Username retrieves the username portion of the user's email address.
func (c *Config) Username(token *oauth2.Token) string {
return "test"
}
......@@ -69,7 +69,7 @@ func makeperms(perms []string) map[string]string {
}
// New creates a new KeySigner from the supplied configuration.
func New(conf config.SSH) (*KeySigner, error) {
func New(conf *config.SSH) (*KeySigner, error) {
data, err := wkfs.ReadFile(conf.SigningKey)
if err != nil {
return nil, fmt.Errorf("unable to read CA key %s: %v", conf.SigningKey, err)
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment