diff --git a/client/client.go b/client/client.go index b84e09a2b04a57a6bffd51116c7c200b591630c0..e46fdd658d4debd899af79c2ad938a0a2b959ed4 100644 --- a/client/client.go +++ b/client/client.go @@ -172,3 +172,79 @@ func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, erro } return cert, nil } + +// Listener type contains information for the client listener. +type Listener struct { + Srv *http.Server + TargetURL string + Token chan string +} + +// StartHTTPServer starts an http server in the background. +func StartHTTPServer() *Listener { + listening := make(chan bool) + listener := &Listener{ + Srv: &http.Server{}, + Token: make(chan string), + } + timeout := 5 * time.Second // TODO: Configurable? + portStart := 8050 // TODO: Configurable? + portCheck := []byte("OK") // TODO: Random? + authCallbackURL := "/auth/callback" // TODO: Random? + portCheckURL := "/port/check" // TODO: Random? + + http.HandleFunc(authCallbackURL, + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Write([]byte("<html><head><title>Authorized</title></head><body>Authorized. You can now close this window.</body></html>")) + defer r.Body.Close() + listener.Token <- r.FormValue("token") + }) + + http.HandleFunc(portCheckURL, + func(w http.ResponseWriter, r *http.Request) { + listening <- true + w.Write(portCheck) + }) + + // Create the http server. + go func() { + for port := portStart; port < 65535; port++ { + listener.Srv.Addr = fmt.Sprintf("localhost:%d", port) + if err := listener.Srv.ListenAndServe(); err != nil { + if strings.Contains(err.Error(), "Server closed") { + return // Shutdown was called. + } else if !strings.Contains(err.Error(), "address already in use") { + fmt.Printf("Httpserver: ListenAndServe() error: %s", err) + return // Some other error. + } + } + } + }() + + // Make sure http server is up. + go func() { + for i := 0 * time.Second; i < timeout; i += time.Second { + time.Sleep(1) + resp, err := http.Get( + fmt.Sprintf("http://%s%s", listener.Srv.Addr, portCheckURL)) + if err != nil { + continue + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if bytes.Equal(body, portCheck) { + return + } + } + }() + + select { + case <-listening: + listener.TargetURL = + fmt.Sprintf("http://%s%s", listener.Srv.Addr, authCallbackURL) + return listener + case <-time.After(timeout): + return nil + } +} diff --git a/client/config.go b/client/config.go index 4536994b7ac820265956fbdc9375b2df1493a89a..22e3b09e76969d6e3138890cecf235dd0dd6b68a 100644 --- a/client/config.go +++ b/client/config.go @@ -16,6 +16,7 @@ type Config struct { Validity string `mapstructure:"validity"` ValidateTLSCertificate bool `mapstructure:"validate_tls_certificate"` PublicFilePrefix string `mapstructure:"key_file_prefix"` + AutoToken bool `mapstructure:"auto_token"` } func setDefaults() { @@ -25,6 +26,7 @@ func setDefaults() { viper.BindPFlag("validity", pflag.Lookup("validity")) viper.BindPFlag("key_file_prefix", pflag.Lookup("key_file_prefix")) viper.SetDefault("validateTLSCertificate", true) + viper.SetDefault("auto_token", false) } // ReadConfig reads the client configuration from a file into a Config struct. diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go index 54fad82d9563f3fd1cc36d06efbde0fdaede25a5..da4ad54f321a1650f39007286abc8207a708f55b 100644 --- a/cmd/cashier/main.go +++ b/cmd/cashier/main.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "net/url" "os" "os/user" "path" @@ -49,22 +50,37 @@ func main() { if err != nil { log.Fatalln("Error generating key pair: ", err) } - fmt.Printf("Your browser has been opened to visit %s\n", c.CA) - if err := browser.OpenURL(c.CA); err != nil { + authURL := c.CA + var listener *client.Listener + if c.AutoToken { + listener = client.StartHTTPServer() + if listener != nil { + authURL = fmt.Sprintf("%s?auto_token=%s", c.CA, url.PathEscape(listener.TargetURL)) + } + } + fmt.Printf("Your browser has been opened to visit %s\n", authURL) + if err := browser.OpenURL(authURL); err != nil { fmt.Println("Error launching web browser. Go to the link in your web browser") } - fmt.Print("Enter token: ") - scanner := bufio.NewScanner(os.Stdin) - var buffer bytes.Buffer - for scanner.Scan(); scanner.Text() != "."; scanner.Scan() { - buffer.WriteString(scanner.Text()) + var token string + if listener != nil { + // TODO: Timeout? + token = <-listener.Token } - tokenBytes, err := base64.StdEncoding.DecodeString(buffer.String()) - if err != nil { - log.Fatalln(err) + if token == "" { + fmt.Print("Enter token: ") + scanner := bufio.NewScanner(os.Stdin) + var buffer bytes.Buffer + for scanner.Scan(); scanner.Text() != "."; scanner.Scan() { + buffer.WriteString(scanner.Text()) + } + tokenBytes, err := base64.StdEncoding.DecodeString(buffer.String()) + if err != nil { + log.Fatalln(err) + } + token = string(tokenBytes) } - token := string(tokenBytes) cert, err := client.Sign(pub, token, c) if err != nil { diff --git a/server/handlers.go b/server/handlers.go index 0ade8ad7ecc5c088ee5ab6bee7d3e01ed8ded616..0b56cb5a89f70f89ecba01fdc29662b41943fcf7 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -115,15 +115,19 @@ func (a *app) auth(w http.ResponseWriter, r *http.Request) { } func (a *app) index(w http.ResponseWriter, r *http.Request) { - log.Printf("Entering index handler.") tok := a.getAuthToken(r) - log.Printf("Token found: %v\n", tok) - page := struct { - Token string - }{tok.AccessToken} - page.Token = encodeString(page.Token) - tmpl := template.Must(template.New("token.html").Parse(templates.Token)) - tmpl.Execute(w, page) + autoTokenURL := a.getSessionVariable(r, "auto_token") + if autoTokenURL != "" { + http.Redirect(w, r, fmt.Sprintf("%s?token=%s", + autoTokenURL, tok.AccessToken), http.StatusSeeOther) + } else { + page := struct { + Token string + }{tok.AccessToken} + page.Token = encodeString(page.Token) + tmpl := template.Must(template.New("token.html").Parse(templates.Token)) + tmpl.Execute(w, page) + } } func (a *app) revoked(w http.ResponseWriter, r *http.Request) { diff --git a/server/server.go b/server/server.go index 2a6af15b8dea9299e7a303fff7723e7b9dd806b4..9357239a115ad4cb914f1e9bbd26e1731f564d34 100644 --- a/server/server.go +++ b/server/server.go @@ -253,6 +253,7 @@ func (a *app) authed(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t := a.getAuthToken(r) if !t.Valid() || !a.authprovider.Valid(t) { + a.setSessionVariable(w, r, "auto_token", r.FormValue("auto_token")) a.setSessionVariable(w, r, "origin_url", r.URL.EscapedPath()) http.Redirect(w, r, "/auth/login", http.StatusSeeOther) return