package server

import (
	"bytes"
	"crypto/tls"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"time"

	"github.com/gorilla/csrf"

	"github.com/gobuffalo/packr"
	"github.com/gorilla/handlers"
	"github.com/prometheus/client_golang/prometheus/promhttp"

	"github.com/gorilla/mux"
	"github.com/gorilla/sessions"
	"github.com/pkg/errors"

	"go4.org/wkfs"
	"golang.org/x/crypto/acme/autocert"
	"golang.org/x/oauth2"

	wkfscache "github.com/nsheridan/autocert-wkfs-cache"
	"github.com/nsheridan/cashier/lib"
	"github.com/nsheridan/cashier/server/auth"
	"github.com/nsheridan/cashier/server/auth/github"
	"github.com/nsheridan/cashier/server/auth/gitlab"
	"github.com/nsheridan/cashier/server/auth/google"
	"github.com/nsheridan/cashier/server/auth/microsoft"
	"github.com/nsheridan/cashier/server/config"
	"github.com/nsheridan/cashier/server/metrics"
	"github.com/nsheridan/cashier/server/signer"
	"github.com/nsheridan/cashier/server/store"
	"github.com/sid77/drop"
)

// loadCerts reads a PEM-encoded certificate and private key through wkfs and
// parses them into a tls.Certificate.
//
// The key file is read before the certificate file, so when both are
// unreadable the private-key error is the one reported.
func loadCerts(certFile, keyFile string) (tls.Certificate, error) {
	var none tls.Certificate

	keyPEM, err := wkfs.ReadFile(keyFile)
	if err != nil {
		return none, errors.Wrap(err, "error reading TLS private key")
	}

	certPEM, err := wkfs.ReadFile(certFile)
	if err != nil {
		return none, errors.Wrap(err, "error reading TLS certificate")
	}

	return tls.X509KeyPair(certPEM, keyPEM)
}

// Run starts the server described by conf and blocks until the HTTP server
// stops serving.
//
// The steps are, in order:
//   - bind the TCP listener (optionally wrapped in TLS, using either
//     Let's Encrypt via autocert or a configured certificate/key pair);
//   - drop privileges to conf.Server.User when one is configured;
//   - register metrics and construct the auth provider, key signer and
//     certificate store;
//   - build the router, middleware and access log, then serve.
//
// Any failure during setup terminates the process through log.Fatal. A
// failure to open the HTTP log file is not fatal: access logs fall back to
// stderr. If serving stops with an error other than http.ErrServerClosed the
// error is logged before Run returns.
func Run(conf *config.Config) {
	var err error

	// The listener is created before privileges are dropped below.
	laddr := fmt.Sprintf("%s:%d", conf.Server.Addr, conf.Server.Port)
	l, err := net.Listen("tcp", laddr)
	if err != nil {
		log.Fatal(errors.Wrapf(err, "unable to listen on %s:%d", conf.Server.Addr, conf.Server.Port))
	}

	tlsConfig := &tls.Config{}
	if conf.Server.UseTLS {
		if conf.Server.LetsEncryptServername != "" {
			// Certificates are obtained automatically, restricted to the
			// single configured server name.
			m := autocert.Manager{
				Prompt:     autocert.AcceptTOS,
				HostPolicy: autocert.HostWhitelist(conf.Server.LetsEncryptServername),
			}
			if conf.Server.LetsEncryptCache != "" {
				m.Cache = wkfscache.Cache(conf.Server.LetsEncryptCache)
			}
			tlsConfig = m.TLSConfig()
		} else {
			// Static certificate: both the cert and the key must be given.
			if conf.Server.TLSCert == "" || conf.Server.TLSKey == "" {
				log.Fatal("TLS cert or key not specified in config")
			}
			tlsConfig.Certificates = make([]tls.Certificate, 1)
			tlsConfig.Certificates[0], err = loadCerts(conf.Server.TLSCert, conf.Server.TLSKey)
			if err != nil {
				log.Fatal(errors.Wrap(err, "unable to create TLS listener"))
			}
		}
		l = tls.NewListener(l, tlsConfig)
	}

	if conf.Server.User != "" {
		log.Print("Dropping privileges...")
		if err := drop.DropPrivileges(conf.Server.User); err != nil {
			log.Fatal(errors.Wrap(err, "unable to drop privileges"))
		}
	}

	// Unprivileged section
	metrics.Register()

	var authprovider auth.Provider
	switch conf.Auth.Provider {
	case "github":
		authprovider, err = github.New(conf.Auth)
	case "gitlab":
		authprovider, err = gitlab.New(conf.Auth)
	case "google":
		authprovider, err = google.New(conf.Auth)
	case "microsoft":
		authprovider, err = microsoft.New(conf.Auth)
	default:
		log.Fatalf("Unknown provider %s\n", conf.Auth.Provider)
	}
	if err != nil {
		log.Fatal(errors.Wrapf(err, "unable to use provider '%s'", conf.Auth.Provider))
	}

	keysigner, err := signer.New(conf.SSH)
	if err != nil {
		log.Fatal(err)
	}

	certstore, err := store.New(conf.Server.Database)
	if err != nil {
		log.Fatal(err)
	}

	ctx := &app{
		cookiestore:   sessions.NewCookieStore([]byte(conf.Server.CookieSecret)),
		requireReason: conf.Server.RequireReason,
		keysigner:     keysigner,
		certstore:     certstore,
		authprovider:  authprovider,
		config:        conf.Server,
		router:        mux.NewRouter(),
	}
	// Session cookies are only marked Secure when the server itself speaks TLS.
	ctx.cookiestore.Options = &sessions.Options{
		MaxAge:   900,
		Path:     "/",
		Secure:   conf.Server.UseTLS,
		HttpOnly: true,
	}

	// HTTP access log destination. Only replace the stderr default once the
	// file has actually been opened; otherwise the logging handler would be
	// handed a nil *os.File and every access-log write would fail.
	logfile := os.Stderr
	if conf.Server.HTTPLogFile != "" {
		f, err := os.OpenFile(conf.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640)
		if err != nil {
			log.Printf("error opening log: %v. logging to stderr", err)
		} else {
			logfile = f
		}
	}

	ctx.routes()
	ctx.router.Use(mwVersion)
	ctx.router.Use(handlers.CompressHandler)
	ctx.router.Use(handlers.RecoveryHandler())
	r := handlers.LoggingHandler(logfile, ctx.router)
	s := &http.Server{
		Handler:      r,
		ReadTimeout:  20 * time.Second,
		WriteTimeout: 20 * time.Second,
		IdleTimeout:  120 * time.Second,
	}

	log.Printf("Starting server on %s", laddr)
	// Serve only returns once the server has stopped, and always with a
	// non-nil error; report anything other than an orderly shutdown instead
	// of returning silently.
	if err := s.Serve(l); err != nil && err != http.ErrServerClosed {
		log.Printf("server stopped: %v", err)
	}
}

// mwVersion returns middleware that stamps every response with an
// X-Cashier-Version header carrying lib.Version, then hands the request on
// to next.
func mwVersion(next http.Handler) http.Handler {
	stamp := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Cashier-Version", lib.Version)
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(stamp)
}

// encodeString base64-encodes s (standard alphabet, with padding) and
// returns the result broken into lines of at most 70 characters. Every line
// ends with a newline, and the output is terminated by a final line holding
// a single period. An empty input yields just ".\n".
func encodeString(s string) string {
	const lineLen = 70

	// Base64 output is pure ASCII, so slicing by byte offset is the same as
	// slicing by character.
	remaining := base64.StdEncoding.EncodeToString([]byte(s))

	var out bytes.Buffer
	for len(remaining) > lineLen {
		out.WriteString(remaining[:lineLen])
		out.WriteByte('\n')
		remaining = remaining[lineLen:]
	}
	if len(remaining) > 0 {
		out.WriteString(remaining)
		out.WriteByte('\n')
	}
	out.WriteString(".\n")
	return out.String()
}

// app holds the per-server state shared by the HTTP handlers: the session
// store, the auth provider, the certificate store and signer, the router and
// the server configuration.
type app struct {
	// cookiestore backs the "session" cookie used to keep the OAuth token
	// and other per-session values.
	cookiestore   *sessions.CookieStore
	// authprovider is the configured OAuth provider; it is asked whether a
	// session's token is still valid.
	authprovider  auth.Provider
	// certstore is the certificate store created from the server's
	// database configuration.
	certstore     store.CertStorer
	// keysigner is the signer created from the SSH configuration.
	keysigner     *signer.KeySigner
	// router is the mux router on which routes() registers every endpoint.
	router        *mux.Router
	// config is the server section of the configuration.
	config        *config.Server
	// requireReason mirrors the RequireReason setting of the server config.
	requireReason bool
}

// routes registers every HTTP endpoint on a.router.
//
// Endpoints wrapped in a.authed require a valid login. /admin/revoke and
// /admin/certs are additionally wrapped in gorilla/csrf protection keyed
// with the configured CSRF secret. The remaining endpoints (auth flow,
// /revoked, /sign, /healthcheck, /metrics and the static file box) are
// served without a login check.
func (a *app) routes() {
	protect := csrf.Protect([]byte(a.config.CSRFSecret), csrf.Secure(a.config.UseTLS))

	// Small helpers so each registration reads as "method path -> handler".
	get := func(path string) *mux.Route { return a.router.Methods("GET").Path(path) }
	post := func(path string) *mux.Route { return a.router.Methods("POST").Path(path) }

	// Login required.
	get("/").Handler(a.authed(http.HandlerFunc(a.index)))
	post("/admin/revoke").Handler(a.authed(protect(http.HandlerFunc(a.revoke))))
	get("/admin/certs").Handler(a.authed(protect(http.HandlerFunc(a.getAllCerts))))
	get("/admin/certs.json").Handler(a.authed(http.HandlerFunc(a.getCertsJSON)))

	// No login required.
	get("/auth/login").HandlerFunc(a.auth)
	get("/auth/callback").HandlerFunc(a.auth)
	get("/revoked").HandlerFunc(a.revoked)
	post("/sign").HandlerFunc(a.sign)

	// Liveness probe: always answers 200 with the body "ok".
	healthcheck := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprintf(w, "ok")
	}
	get("/healthcheck").HandlerFunc(healthcheck)
	get("/metrics").Handler(promhttp.Handler())

	// Static assets are served from the packr box with the /static prefix
	// stripped from the request path.
	box := packr.NewBox("static")
	static := http.StripPrefix("/static", http.FileServer(box))
	a.router.PathPrefix("/static/").Handler(static)
}

// getAuthToken returns the OAuth token stored as JSON under the "token" key
// of the request's session.
//
// It never returns nil. When the session holds no token, or the stored value
// cannot be decoded, a zero-valued token is returned so that a partially
// decoded token is never handed to callers.
func (a *app) getAuthToken(r *http.Request) *oauth2.Token {
	marshalled := a.getSessionVariable(r, "token")
	if marshalled == "" {
		// Nothing stored yet (for example, the user has not logged in).
		return &oauth2.Token{}
	}

	token := &oauth2.Token{}
	if err := json.Unmarshal([]byte(marshalled), token); err != nil {
		log.Printf("error decoding session token: %v", err)
		// Discard whatever fields were filled in before the failure.
		return &oauth2.Token{}
	}
	return token
}

// setAuthToken serialises token to JSON and stores it in the session under
// the "token" key. The marshalling error is discarded; if marshalling fails
// the encoded value is nil and an empty string is stored.
func (a *app) setAuthToken(w http.ResponseWriter, r *http.Request, token *oauth2.Token) {
	encoded, _ := json.Marshal(token)
	payload := string(encoded)
	a.setSessionVariable(w, r, "token", payload)
}

// getSessionVariable looks up key in the request's "session" cookie store
// entry and returns its value. It returns "" when the key is absent or the
// stored value is not a string. The error from fetching the session is
// discarded.
func (a *app) getSessionVariable(r *http.Request, key string) string {
	sess, _ := a.cookiestore.Get(r, "session")
	if str, isString := sess.Values[key].(string); isString {
		return str
	}
	return ""
}

// setSessionVariable stores value under key in the request's "session" and
// writes the updated session to the response.
//
// The error from fetching the session is deliberately discarded so that the
// value is still written into whatever session the store hands back. A
// failure to save is logged rather than dropped silently, because a session
// that was not persisted will not carry the value to the next request.
func (a *app) setSessionVariable(w http.ResponseWriter, r *http.Request, key, value string) {
	session, _ := a.cookiestore.Get(r, "session")
	session.Values[key] = value
	if err := session.Save(r, w); err != nil {
		log.Printf("error saving session variable %q: %v", key, err)
	}
}

// authed wraps next so it is only served to requests whose session carries
// a usable OAuth token.
//
// A token is usable when it passes its own Valid check and is then accepted
// by the configured auth provider; the provider is only consulted when the
// first check passes. Otherwise the requested path is remembered in the
// session under "origin_url" and the client is redirected to /auth/login
// with 303 See Other.
func (a *app) authed(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Printf("Checking auth for %s.", r.URL.EscapedPath())
		t := a.getAuthToken(r)
		// The token itself is never logged: it holds bearer credentials.
		valid := t.Valid()
		if !valid || !a.authprovider.Valid(t) {
			log.Printf("Invalid token t.Valid() = %t.", valid)
			a.setSessionVariable(w, r, "origin_url", r.URL.EscapedPath())
			http.Redirect(w, r, "/auth/login", http.StatusSeeOther)
			return
		}
		next.ServeHTTP(w, r)
	})
}