From f5661c0b4a07934cde304a9a0768fce9aec739a6 Mon Sep 17 00:00:00 2001
From: Kevin Lyda <kevin@ie.suberic.net>
Date: Mon, 17 Sep 2018 12:35:08 +0100
Subject: [PATCH] Much simpler version thanks to Niall's feedback.

---
 client/client.go    | 72 +++++++++++++++------------------------------
 cmd/cashier/main.go |  6 ++--
 2 files changed, 27 insertions(+), 51 deletions(-)

diff --git a/client/client.go b/client/client.go
index e46fdd65..3116ab8b 100644
--- a/client/client.go
+++ b/client/client.go
@@ -3,12 +3,14 @@ package client
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"crypto/tls"
 	"encoding/base64"
 	"encoding/json"
 	"encoding/pem"
 	"fmt"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"net/url"
 	"os"
@@ -175,23 +177,18 @@ func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, erro
 
 // Listener type contains information for the client listener.
 type Listener struct {
-	Srv       *http.Server
-	TargetURL string
-	Token     chan string
+	srv         *http.Server
+	ReceiverURL string
+	Token       chan string
 }
 
 // StartHTTPServer starts an http server in the background.
 func StartHTTPServer() *Listener {
-	listening := make(chan bool)
 	listener := &Listener{
-		Srv:   &http.Server{},
+		srv:   &http.Server{},
 		Token: make(chan string),
 	}
-	timeout := 5 * time.Second          // TODO: Configurable?
-	portStart := 8050                   // TODO: Configurable?
-	portCheck := []byte("OK")           // TODO: Random?
 	authCallbackURL := "/auth/callback" // TODO: Random?
-	portCheckURL := "/port/check"       // TODO: Random?
 
 	http.HandleFunc(authCallbackURL,
 		func(w http.ResponseWriter, r *http.Request) {
@@ -201,50 +198,27 @@ func StartHTTPServer() *Listener {
 			listener.Token <- r.FormValue("token")
 		})
 
-	http.HandleFunc(portCheckURL,
-		func(w http.ResponseWriter, r *http.Request) {
-			listening <- true
-			w.Write(portCheck)
-		})
-
 	// Create the http server.
-	go func() {
-		for port := portStart; port < 65535; port++ {
-			listener.Srv.Addr = fmt.Sprintf("localhost:%d", port)
-			if err := listener.Srv.ListenAndServe(); err != nil {
-				if strings.Contains(err.Error(), "Server closed") {
-					return // Shutdown was called.
-				} else if !strings.Contains(err.Error(), "address already in use") {
-					fmt.Printf("Httpserver: ListenAndServe() error: %s", err)
-					return // Some other error.
-				}
-			}
-		}
-	}()
+	l, err := net.Listen("tcp", ":0")
+	if err != nil {
+		return nil
+	}
+	port := l.Addr().(*net.TCPAddr).Port
+	listener.ReceiverURL = fmt.Sprintf("http://localhost:%d%s",
+		port, authCallbackURL)
 
-	// Make sure http server is up.
 	go func() {
-		for i := 0 * time.Second; i < timeout; i += time.Second {
-			time.Sleep(1)
-			resp, err := http.Get(
-				fmt.Sprintf("http://%s%s", listener.Srv.Addr, portCheckURL))
-			if err != nil {
-				continue
-			}
-			defer resp.Body.Close()
-			body, err := ioutil.ReadAll(resp.Body)
-			if bytes.Equal(body, portCheck) {
-				return
-			}
+		err := listener.srv.Serve(l)
+		if err == http.ErrServerClosed {
+			fmt.Printf("Httpserver: Server() error: %s", err)
 		}
+		return
 	}()
 
-	select {
-	case <-listening:
-		listener.TargetURL =
-			fmt.Sprintf("http://%s%s", listener.Srv.Addr, authCallbackURL)
-		return listener
-	case <-time.After(timeout):
-		return nil
-	}
+	return listener
+}
+
+// Shutdown stops the server created in StartHTTPServer.
+func (l *Listener) Shutdown() {
+	l.srv.Shutdown(context.Background())
 }
diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go
index da4ad54f..4400e7c5 100644
--- a/cmd/cashier/main.go
+++ b/cmd/cashier/main.go
@@ -51,11 +51,12 @@ func main() {
 		log.Fatalln("Error generating key pair: ", err)
 	}
 	authURL := c.CA
-	var listener *client.Listener
+	listener := &client.Listener{}
 	if c.AutoToken {
 		listener = client.StartHTTPServer()
 		if listener != nil {
-			authURL = fmt.Sprintf("%s?auto_token=%s", c.CA, url.PathEscape(listener.TargetURL))
+			authURL = fmt.Sprintf("%s?auto_token=%s",
+				c.CA, url.PathEscape(listener.ReceiverURL))
 		}
 	}
 	fmt.Printf("Your browser has been opened to visit %s\n", authURL)
@@ -67,6 +68,7 @@ func main() {
 	if listener != nil {
 		// TODO: Timeout?
 		token = <-listener.Token
+		listener.Shutdown()
 	}
 	if token == "" {
 		fmt.Print("Enter token: ")
-- 
GitLab