Skip to content
Snippets Groups Projects
Commit 5166bbbf authored by Niall Sheridan's avatar Niall Sheridan
Browse files

Merge pull request #15 from nsheridan/auth_state

Save oauth 'state' identifier in the client
parents a52d19e9 f4567532
No related branches found
No related tags found
No related merge requests found
...@@ -40,9 +40,9 @@ type appContext struct { ...@@ -40,9 +40,9 @@ type appContext struct {
sshKeySigner *signer.KeySigner sshKeySigner *signer.KeySigner
} }
// getAuthCookie retrieves a cookie from the request and validates it. // getAuthTokenCookie retrieves a cookie from the request.
func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token { func (a *appContext) getAuthTokenCookie(r *http.Request) *oauth2.Token {
session, _ := a.cookiestore.Get(r, "tok") session, _ := a.cookiestore.Get(r, "session")
t, ok := session.Values["token"] t, ok := session.Values["token"]
if !ok { if !ok {
return nil return nil
...@@ -57,14 +57,31 @@ func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token { ...@@ -57,14 +57,31 @@ func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token {
return &tok return &tok
} }
// setAuthCookie marshals the auth token and stores it as a cookie. // setAuthTokenCookie marshals the auth token and stores it as a cookie.
func (a *appContext) setAuthCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) { func (a *appContext) setAuthTokenCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) {
session, _ := a.cookiestore.Get(r, "tok") session, _ := a.cookiestore.Get(r, "session")
val, _ := json.Marshal(t) val, _ := json.Marshal(t)
session.Values["token"] = val session.Values["token"] = val
session.Save(r, w) session.Save(r, w)
} }
// getAuthStateCookie retrieves the oauth csrf state value from the client request.
func (a *appContext) getAuthStateCookie(r *http.Request) string {
session, _ := a.cookiestore.Get(r, "session")
state, ok := session.Values["state"]
if !ok {
return ""
}
return state.(string)
}
// setAuthStateCookie saves the oauth csrf state value.
func (a *appContext) setAuthStateCookie(w http.ResponseWriter, r *http.Request, state string) {
session, _ := a.cookiestore.Get(r, "session")
session.Values["state"] = state
session.Save(r, w)
}
// parseKey retrieves and unmarshals the signing request. // parseKey retrieves and unmarshals the signing request.
func parseKey(r *http.Request) (*lib.SignRequest, error) { func parseKey(r *http.Request) (*lib.SignRequest, error) {
var s lib.SignRequest var s lib.SignRequest
...@@ -118,28 +135,30 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er ...@@ -118,28 +135,30 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
// loginHandler starts the authentication process with the provider. // loginHandler starts the authentication process with the provider.
func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
a.authsession = a.authprovider.StartSession(newState()) state := newState()
a.setAuthStateCookie(w, r, state)
a.authsession = a.authprovider.StartSession(state)
http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound) http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound)
return http.StatusFound, nil return http.StatusFound, nil
} }
// callbackHandler handles retrieving the access token from the auth provider and saves it for later use. // callbackHandler handles retrieving the access token from the auth provider and saves it for later use.
func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
if r.FormValue("state") != a.authsession.State { if r.FormValue("state") != a.getAuthStateCookie(r) {
return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
} }
code := r.FormValue("code") code := r.FormValue("code")
if err := a.authsession.Authorize(a.authprovider, code); err != nil { if err := a.authsession.Authorize(a.authprovider, code); err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
a.setAuthCookie(w, r, a.authsession.Token) a.setAuthTokenCookie(w, r, a.authsession.Token)
http.Redirect(w, r, "/", http.StatusFound) http.Redirect(w, r, "/", http.StatusFound)
return http.StatusFound, nil return http.StatusFound, nil
} }
// rootHandler starts the auth process. If the client is authenticated it renders the token to the user. // rootHandler starts the auth process. If the client is authenticated it renders the token to the user.
func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
tok := a.getAuthCookie(r) tok := a.getAuthTokenCookie(r)
if !tok.Valid() || !a.authprovider.Valid(tok) { if !tok.Valid() || !a.authprovider.Valid(tok) {
http.Redirect(w, r, "/auth/login", http.StatusSeeOther) http.Redirect(w, r, "/auth/login", http.StatusSeeOther)
return http.StatusSeeOther, nil return http.StatusSeeOther, nil
......
...@@ -78,7 +78,6 @@ func (c *Config) Revoke(token *oauth2.Token) error { ...@@ -78,7 +78,6 @@ func (c *Config) Revoke(token *oauth2.Token) error {
func (c *Config) StartSession(state string) *auth.Session { func (c *Config) StartSession(state string) *auth.Session {
return &auth.Session{ return &auth.Session{
AuthURL: c.config.AuthCodeURL(state), AuthURL: c.config.AuthCodeURL(state),
State: state,
} }
} }
......
...@@ -42,7 +42,6 @@ func TestStartSession(t *testing.T) { ...@@ -42,7 +42,6 @@ func TestStartSession(t *testing.T) {
p, _ := newGithub() p, _ := newGithub()
s := p.StartSession("test_state") s := p.StartSession("test_state")
a.Equal(s.State, "test_state")
a.Contains(s.AuthURL, "github.com/login/oauth/authorize") a.Contains(s.AuthURL, "github.com/login/oauth/authorize")
a.Contains(s.AuthURL, "state=test_state") a.Contains(s.AuthURL, "state=test_state")
a.Contains(s.AuthURL, fmt.Sprintf("client_id=%s", oauthClientID)) a.Contains(s.AuthURL, fmt.Sprintf("client_id=%s", oauthClientID))
......
...@@ -90,7 +90,6 @@ func (c *Config) Revoke(token *oauth2.Token) error { ...@@ -90,7 +90,6 @@ func (c *Config) Revoke(token *oauth2.Token) error {
func (c *Config) StartSession(state string) *auth.Session { func (c *Config) StartSession(state string) *auth.Session {
return &auth.Session{ return &auth.Session{
AuthURL: c.config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", c.domain)), AuthURL: c.config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", c.domain)),
State: state,
} }
} }
......
...@@ -44,7 +44,6 @@ func TestStartSession(t *testing.T) { ...@@ -44,7 +44,6 @@ func TestStartSession(t *testing.T) {
p, err := newGoogle() p, err := newGoogle()
a.NoError(err) a.NoError(err)
s := p.StartSession("test_state") s := p.StartSession("test_state")
a.Equal(s.State, "test_state")
a.Contains(s.AuthURL, "accounts.google.com/o/oauth2/auth") a.Contains(s.AuthURL, "accounts.google.com/o/oauth2/auth")
a.Contains(s.AuthURL, "state=test_state") a.Contains(s.AuthURL, "state=test_state")
a.Contains(s.AuthURL, fmt.Sprintf("hd=%s", domain)) a.Contains(s.AuthURL, fmt.Sprintf("hd=%s", domain))
......
...@@ -16,7 +16,6 @@ type Provider interface { ...@@ -16,7 +16,6 @@ type Provider interface {
type Session struct { type Session struct {
AuthURL string AuthURL string
Token *oauth2.Token Token *oauth2.Token
State string
} }
// Authorize obtains data from the provider and retains an access token that // Authorize obtains data from the provider and retains an access token that
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment