Skip to content
Snippets Groups Projects
Commit f4567532 authored by Niall Sheridan's avatar Niall Sheridan
Browse files

Save oauth 'state' identifier in the client

parent a52d19e9
Branches
Tags
No related merge requests found
...@@ -40,9 +40,9 @@ type appContext struct { ...@@ -40,9 +40,9 @@ type appContext struct {
sshKeySigner *signer.KeySigner sshKeySigner *signer.KeySigner
} }
// getAuthCookie retrieves a cookie from the request and validates it. // getAuthTokenCookie retrieves a cookie from the request.
func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token { func (a *appContext) getAuthTokenCookie(r *http.Request) *oauth2.Token {
session, _ := a.cookiestore.Get(r, "tok") session, _ := a.cookiestore.Get(r, "session")
t, ok := session.Values["token"] t, ok := session.Values["token"]
if !ok { if !ok {
return nil return nil
...@@ -57,14 +57,31 @@ func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token { ...@@ -57,14 +57,31 @@ func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token {
return &tok return &tok
} }
// setAuthCookie marshals the auth token and stores it as a cookie. // setAuthTokenCookie marshals the auth token and stores it as a cookie.
func (a *appContext) setAuthCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) { func (a *appContext) setAuthTokenCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) {
session, _ := a.cookiestore.Get(r, "tok") session, _ := a.cookiestore.Get(r, "session")
val, _ := json.Marshal(t) val, _ := json.Marshal(t)
session.Values["token"] = val session.Values["token"] = val
session.Save(r, w) session.Save(r, w)
} }
// getAuthStateCookie retrieves the oauth csrf state value from the client request.
func (a *appContext) getAuthStateCookie(r *http.Request) string {
session, _ := a.cookiestore.Get(r, "session")
state, ok := session.Values["state"]
if !ok {
return ""
}
return state.(string)
}
// setAuthStateCookie saves the oauth csrf state value.
func (a *appContext) setAuthStateCookie(w http.ResponseWriter, r *http.Request, state string) {
session, _ := a.cookiestore.Get(r, "session")
session.Values["state"] = state
session.Save(r, w)
}
// parseKey retrieves and unmarshals the signing request. // parseKey retrieves and unmarshals the signing request.
func parseKey(r *http.Request) (*lib.SignRequest, error) { func parseKey(r *http.Request) (*lib.SignRequest, error) {
var s lib.SignRequest var s lib.SignRequest
...@@ -118,28 +135,30 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er ...@@ -118,28 +135,30 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
// loginHandler starts the authentication process with the provider. // loginHandler starts the authentication process with the provider.
func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
a.authsession = a.authprovider.StartSession(newState()) state := newState()
a.setAuthStateCookie(w, r, state)
a.authsession = a.authprovider.StartSession(state)
http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound) http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound)
return http.StatusFound, nil return http.StatusFound, nil
} }
// callbackHandler handles retrieving the access token from the auth provider and saves it for later use. // callbackHandler handles retrieving the access token from the auth provider and saves it for later use.
func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
if r.FormValue("state") != a.authsession.State { if r.FormValue("state") != a.getAuthStateCookie(r) {
return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
} }
code := r.FormValue("code") code := r.FormValue("code")
if err := a.authsession.Authorize(a.authprovider, code); err != nil { if err := a.authsession.Authorize(a.authprovider, code); err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
a.setAuthCookie(w, r, a.authsession.Token) a.setAuthTokenCookie(w, r, a.authsession.Token)
http.Redirect(w, r, "/", http.StatusFound) http.Redirect(w, r, "/", http.StatusFound)
return http.StatusFound, nil return http.StatusFound, nil
} }
// rootHandler starts the auth process. If the client is authenticated it renders the token to the user. // rootHandler starts the auth process. If the client is authenticated it renders the token to the user.
func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
tok := a.getAuthCookie(r) tok := a.getAuthTokenCookie(r)
if !tok.Valid() || !a.authprovider.Valid(tok) { if !tok.Valid() || !a.authprovider.Valid(tok) {
http.Redirect(w, r, "/auth/login", http.StatusSeeOther) http.Redirect(w, r, "/auth/login", http.StatusSeeOther)
return http.StatusSeeOther, nil return http.StatusSeeOther, nil
......
...@@ -78,7 +78,6 @@ func (c *Config) Revoke(token *oauth2.Token) error { ...@@ -78,7 +78,6 @@ func (c *Config) Revoke(token *oauth2.Token) error {
func (c *Config) StartSession(state string) *auth.Session { func (c *Config) StartSession(state string) *auth.Session {
return &auth.Session{ return &auth.Session{
AuthURL: c.config.AuthCodeURL(state), AuthURL: c.config.AuthCodeURL(state),
State: state,
} }
} }
......
...@@ -42,7 +42,6 @@ func TestStartSession(t *testing.T) { ...@@ -42,7 +42,6 @@ func TestStartSession(t *testing.T) {
p, _ := newGithub() p, _ := newGithub()
s := p.StartSession("test_state") s := p.StartSession("test_state")
a.Equal(s.State, "test_state")
a.Contains(s.AuthURL, "github.com/login/oauth/authorize") a.Contains(s.AuthURL, "github.com/login/oauth/authorize")
a.Contains(s.AuthURL, "state=test_state") a.Contains(s.AuthURL, "state=test_state")
a.Contains(s.AuthURL, fmt.Sprintf("client_id=%s", oauthClientID)) a.Contains(s.AuthURL, fmt.Sprintf("client_id=%s", oauthClientID))
......
...@@ -90,7 +90,6 @@ func (c *Config) Revoke(token *oauth2.Token) error { ...@@ -90,7 +90,6 @@ func (c *Config) Revoke(token *oauth2.Token) error {
func (c *Config) StartSession(state string) *auth.Session { func (c *Config) StartSession(state string) *auth.Session {
return &auth.Session{ return &auth.Session{
AuthURL: c.config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", c.domain)), AuthURL: c.config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", c.domain)),
State: state,
} }
} }
......
...@@ -44,7 +44,6 @@ func TestStartSession(t *testing.T) { ...@@ -44,7 +44,6 @@ func TestStartSession(t *testing.T) {
p, err := newGoogle() p, err := newGoogle()
a.NoError(err) a.NoError(err)
s := p.StartSession("test_state") s := p.StartSession("test_state")
a.Equal(s.State, "test_state")
a.Contains(s.AuthURL, "accounts.google.com/o/oauth2/auth") a.Contains(s.AuthURL, "accounts.google.com/o/oauth2/auth")
a.Contains(s.AuthURL, "state=test_state") a.Contains(s.AuthURL, "state=test_state")
a.Contains(s.AuthURL, fmt.Sprintf("hd=%s", domain)) a.Contains(s.AuthURL, fmt.Sprintf("hd=%s", domain))
......
...@@ -16,7 +16,6 @@ type Provider interface { ...@@ -16,7 +16,6 @@ type Provider interface {
type Session struct { type Session struct {
AuthURL string AuthURL string
Token *oauth2.Token Token *oauth2.Token
State string
} }
// Authorize obtains data from the provider and retains an access token that // Authorize obtains data from the provider and retains an access token that
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment