Skip to content
Snippets Groups Projects
Commit 88401309 authored by Niall Sheridan's avatar Niall Sheridan
Browse files

Initial commit

parents
Branches
Tags
No related merge requests found
package main
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"fmt"
"golang.org/x/crypto/ssh"
)
const (
rsaKey = "rsa"
ecdsaKey = "ecdsa"
)
type key interface{}
func generateRSAKey(bits int) (*rsa.PrivateKey, ssh.PublicKey, error) {
k, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, nil, err
}
pub, err := ssh.NewPublicKey(&k.PublicKey)
if err != nil {
return nil, nil, err
}
return k, pub, nil
}
func generateECDSAKey(bits int) (*ecdsa.PrivateKey, ssh.PublicKey, error) {
var curve elliptic.Curve
switch bits {
case 256:
curve = elliptic.P256()
case 384:
curve = elliptic.P384()
case 521:
curve = elliptic.P521()
default:
return nil, nil, fmt.Errorf("Unsupported key size. Valid sizes are '256', '384', '521'")
}
k, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
return nil, nil, err
}
pub, err := ssh.NewPublicKey(&k.PublicKey)
if err != nil {
return nil, nil, err
}
return k, pub, nil
}
func generateKey(keytype string, bits int) (key, ssh.PublicKey, error) {
switch keytype {
case rsaKey:
return generateRSAKey(bits)
case ecdsaKey:
return generateECDSAKey(bits)
default:
return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are [%s, %s]", keytype, rsaKey, ecdsaKey)
}
}
package main
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"time"
"github.com/nsheridan/cashier/lib"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
var (
url = flag.String("url", "http://localhost:10000/sign", "Signing URL")
keybits = flag.Int("bits", 4096, "Key size")
validity = flag.Duration("validity", time.Hour*24, "Key validity")
keytype = flag.String("key_type", "rsa", "Type of private key to generate - rsa or ecdsa")
)
func installCert(a agent.Agent, cert *ssh.Certificate, key key) error {
pubcert := agent.AddedKey{
PrivateKey: key,
Certificate: cert,
Comment: cert.KeyId,
}
if err := a.Add(pubcert); err != nil {
return fmt.Errorf("error importing certificate: %s", err)
}
return nil
}
func send(s []byte, token string) (*lib.SignResponse, error) {
req, err := http.NewRequest("POST", *url, bytes.NewReader(s))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
client := &http.Client{}
resp, err := client.Do(req)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Bad response from server: %s", resp.Status)
}
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c := &lib.SignResponse{}
if err := json.Unmarshal(body, c); err != nil {
return nil, err
}
return c, nil
}
func sign(pub ssh.PublicKey, token string) (*ssh.Certificate, error) {
marshaled := ssh.MarshalAuthorizedKey(pub)
marshaled = marshaled[:len(marshaled)-1]
s, err := json.Marshal(&lib.SignRequest{
Key: string(marshaled),
ValidUntil: time.Now().Add(*validity),
})
if err != nil {
return nil, err
}
resp, err := send(s, token)
if err != nil {
return nil, err
}
if resp.Status != "ok" {
return nil, fmt.Errorf("error: %s", resp.Response)
}
k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(resp.Response))
if err != nil {
return nil, err
}
cert, ok := k.(*ssh.Certificate)
if !ok {
return nil, fmt.Errorf("did not receive a certificate from server")
}
return cert, nil
}
func main() {
flag.Parse()
priv, pub, err := generateKey(*keytype, *keybits)
if err != nil {
log.Fatalln("Error generating key pair: ", err)
}
fmt.Print("Enter token: ")
var token string
fmt.Scanln(&token)
cert, err := sign(pub, token)
if err != nil {
log.Fatalln(err)
}
sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
log.Fatalln("Error connecting to agent: %s", err)
}
defer sock.Close()
a := agent.NewClient(sock)
if err := installCert(a, cert, priv); err != nil {
log.Fatalln(err)
}
fmt.Println("Certificate added.")
}
{
"server": {
"tls_key": "server.key",
"tls_cert": "server.crt",
"port": 443,
"cookie_secret": "supersecret"
},
"auth": {
"provider": "google",
"oauth_client_id": "nnnnnnnnnnnnnnnn.apps.googleusercontent.com",
"oauth_client_secret": "yyyyyyyyyyyyyyyyyyyyyy",
"oauth_callback_url": "https://sshca.example.com/auth/callback",
"google_opts": {
"domain": "example.com"
},
"jwt_signing_key": "supersecret"
},
"ssh": {
"signing_key": "signing_key",
"additional_principals": ["ec2-user"],
"max_age": "720h",
"permissions": ["permit-pty"]
}
}
package lib
import "time"
// SignRequest represents a signing request sent to the server.
type SignRequest struct {
Key string `json:"key"`
Principal string `json:"principal"`
ValidUntil time.Time `json:"valid_until"`
}
// SignResponse is sent by the server.
// `Status' is "ok" or "error".
// `Response' contains a signed certificate or an error message.
type SignResponse struct {
Status string `json:"status"`
Response string `json:"response"`
}
package google
import (
"fmt"
"net/http"
"strings"
"github.com/nsheridan/cashier/server/auth"
"github.com/nsheridan/cashier/server/config"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
googleapi "google.golang.org/api/oauth2/v2"
)
const (
revokeURL = "https://accounts.google.com/o/oauth2/revoke?token=%s"
name = "google"
)
type Config struct {
config *oauth2.Config
domain string
}
func New(c config.Auth) auth.Provider {
return &Config{
config: &oauth2.Config{
ClientID: c.OauthClientID,
ClientSecret: c.OauthClientSecret,
RedirectURL: c.OauthCallbackURL,
Endpoint: google.Endpoint,
Scopes: []string{googleapi.UserinfoEmailScope, googleapi.UserinfoProfileScope},
},
domain: c.GoogleOpts["domain"].(string),
}
}
func (c *Config) newClient(token *oauth2.Token) *http.Client {
return c.config.Client(oauth2.NoContext, token)
}
func (c *Config) Name() string {
return name
}
func (c *Config) Valid(token *oauth2.Token) bool {
if !token.Valid() {
return false
}
svc, err := googleapi.New(c.newClient(token))
if err != nil {
return false
}
t := svc.Tokeninfo()
t.AccessToken(token.AccessToken)
ti, err := t.Do()
if err != nil {
return false
}
ui, err := svc.Userinfo.Get().Do()
if err != nil {
return false
}
switch {
case ti.Audience != c.config.ClientID:
case ui.Hd != c.domain:
return false
}
return true
}
func (c *Config) Revoke(token *oauth2.Token) error {
h := c.newClient(token)
_, err := h.Get(fmt.Sprintf(revokeURL, token.AccessToken))
return err
}
func (c *Config) StartSession(state string) *auth.Session {
return &auth.Session{
AuthURL: c.config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", c.domain)),
State: state,
}
}
func (c *Config) Exchange(code string) (*oauth2.Token, error) {
return c.config.Exchange(oauth2.NoContext, code)
}
func (c *Config) Username(token *oauth2.Token) string {
svc, err := googleapi.New(c.newClient(token))
if err != nil {
return ""
}
ui, err := svc.Userinfo.Get().Do()
if err != nil {
return ""
}
return strings.Split(ui.Email, "@")[0]
}
package auth
import "golang.org/x/oauth2"
type Provider interface {
Name() string
StartSession(string) *Session
Exchange(string) (*oauth2.Token, error)
Username(*oauth2.Token) string
Valid(*oauth2.Token) bool
Revoke(*oauth2.Token) error
}
type Session struct {
AuthURL string
Token *oauth2.Token
State string
}
func (s *Session) Authorize(provider Provider, code string) error {
t, err := provider.Exchange(code)
if err != nil {
return err
}
s.Token = t
return nil
}
package config
import "github.com/spf13/viper"
// Config holds the values from the json config file.
type Config struct {
Server Server `mapstructure:"server"`
Auth Auth `mapstructure:"auth"`
SSH SSH `mapstructure:"ssh"`
}
// Server holds the configuration specific to the web server and sessions.
type Server struct {
UseTLS bool `mapstructure:"use_tls"`
TLSKey string `mapstructure:"tls_key"`
TLSCert string `mapstructure:"tls_cert"`
Port int `mapstructure:"port"`
CookieSecret string `mapstructure:"cookie_secret"`
}
// Auth holds the configuration specific to the OAuth provider.
type Auth struct {
OauthClientID string `mapstructure:"oauth_client_id"`
OauthClientSecret string `mapstructure:"oauth_client_secret"`
OauthCallbackURL string `mapstructure:"oauth_callback_url"`
Provider string `mapstructure:"provider"`
GoogleOpts map[string]interface{} `mapstructure:"google_opts"`
JWTSigningKey string `mapstructure:"jwt_signing_key"`
}
// SSH holds the configuration specific to signing ssh keys.
type SSH struct {
SigningKey string `mapstructure:"signing_key"`
Principals []string `mapstructure:"principals"`
MaxAge string `mapstructure:"max_age"`
Permissions []string `mapstructure:"permissions"`
}
// ReadConfig parses a JSON configuration file into a Config struct.
func ReadConfig(filename string) (*Config, error) {
config := &Config{}
v := viper.New()
v.SetConfigFile(filename)
if err := v.ReadInConfig(); err != nil {
return nil, err
}
if err := v.Unmarshal(config); err != nil {
return nil, err
}
return config, nil
}
package main
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"flag"
"fmt"
"html/template"
"io"
"io/ioutil"
"log"
"net/http"
"time"
"golang.org/x/oauth2"
"github.com/dgrijalva/jwt-go"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/nsheridan/cashier/lib"
"github.com/nsheridan/cashier/server/auth"
"github.com/nsheridan/cashier/server/auth/google"
"github.com/nsheridan/cashier/server/config"
"github.com/nsheridan/cashier/server/signer"
)
var (
cfg = flag.String("config_file", "config.json", "Path to configuration file.")
)
type appContext struct {
cookiestore *sessions.CookieStore
authprovider auth.Provider
authsession *auth.Session
views *template.Template
sshKeySigner *signer.KeySigner
jwtSigningKey []byte
}
func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token {
session, _ := a.cookiestore.Get(r, "tok")
t, ok := session.Values["token"]
if !ok {
return nil
}
var tok oauth2.Token
if err := json.Unmarshal(t.([]byte), &tok); err != nil {
return nil
}
if !a.authprovider.Valid(&tok) {
return nil
}
return &tok
}
func (a *appContext) setAuthCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) {
session, _ := a.cookiestore.Get(r, "tok")
val, _ := json.Marshal(t)
session.Values["token"] = val
session.Save(r, w)
}
func parseKey(r *http.Request) (*lib.SignRequest, error) {
var s lib.SignRequest
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, err
}
if err := json.Unmarshal(body, &s); err != nil {
return nil, err
}
return &s, nil
}
func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
jwtoken, err := jwt.ParseFromRequest(r, func(t *jwt.Token) (interface{}, error) {
return a.jwtSigningKey, nil
})
if err != nil {
return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
}
if !jwtoken.Valid {
log.Printf("Token %v not valid", jwtoken)
return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
}
expiry := int64(jwtoken.Claims["exp"].(float64))
token := &oauth2.Token{
AccessToken: jwtoken.Claims["token"].(string),
Expiry: time.Unix(expiry, 0),
}
ok := a.authprovider.Valid(token)
if !ok {
return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
}
// finally sign the pubkey and issue the cert.
req, err := parseKey(r)
req.Principal = a.authprovider.Username(token)
if err != nil {
return http.StatusInternalServerError, err
}
signed, err := a.sshKeySigner.Sign(req)
a.authprovider.Revoke(token)
if err != nil {
return http.StatusInternalServerError, err
}
json.NewEncoder(w).Encode(&lib.SignResponse{
Status: "ok",
Response: signed,
})
return http.StatusOK, nil
}
func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
a.authsession = a.authprovider.StartSession(hex.EncodeToString(random(32)))
http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound)
return http.StatusFound, nil
}
func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
if r.FormValue("state") != a.authsession.State {
return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
}
code := r.FormValue("code")
if err := a.authsession.Authorize(a.authprovider, code); err != nil {
return http.StatusInternalServerError, err
}
a.setAuthCookie(w, r, a.authsession.Token)
http.Redirect(w, r, "/", http.StatusFound)
return http.StatusFound, nil
}
func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
tok := a.getAuthCookie(r)
if !tok.Valid() {
http.Redirect(w, r, "/auth/login", http.StatusSeeOther)
return http.StatusSeeOther, nil
}
j := jwt.New(jwt.SigningMethodHS256)
j.Claims["token"] = tok.AccessToken
j.Claims["exp"] = tok.Expiry.Unix()
t, err := j.SignedString(a.jwtSigningKey)
if err != nil {
return http.StatusInternalServerError, err
}
page := struct {
Token string
}{t}
a.views.ExecuteTemplate(w, "token.html", page)
return http.StatusOK, nil
}
type appHandler struct {
*appContext
h func(*appContext, http.ResponseWriter, *http.Request) (int, error)
}
func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
status, err := ah.h(ah.appContext, w, r)
if err != nil {
log.Printf("HTTP %d: %q", status, err)
switch status {
case http.StatusNotFound:
http.NotFound(w, r)
case http.StatusInternalServerError:
http.Error(w, http.StatusText(status), status)
default:
http.Error(w, http.StatusText(status), status)
}
}
}
func random(length int) []byte {
k := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}
func main() {
flag.Parse()
config, err := config.ReadConfig(*cfg)
if err != nil {
log.Fatal(err)
}
signer, err := signer.NewSigner(config.SSH)
if err != nil {
log.Fatal(err)
}
authprovider := google.New(config.Auth)
ctx := &appContext{
cookiestore: sessions.NewCookieStore([]byte(config.Server.CookieSecret)),
authprovider: authprovider,
views: template.Must(template.ParseGlob("templates/*")),
sshKeySigner: signer,
jwtSigningKey: []byte(config.Auth.JWTSigningKey),
}
ctx.cookiestore.Options = &sessions.Options{
MaxAge: 900,
Path: "/",
Secure: config.Server.UseTLS,
HttpOnly: true,
}
m := mux.NewRouter()
m.Handle("/", appHandler{ctx, rootHandler})
m.Handle("/auth/login", appHandler{ctx, loginHandler})
m.Handle("/auth/callback", appHandler{ctx, callbackHandler})
m.Handle("/sign", appHandler{ctx, signHandler})
fmt.Println("Starting server...")
l := fmt.Sprintf(":%d", config.Server.Port)
if config.Server.UseTLS {
log.Fatal(http.ListenAndServeTLS(l, config.Server.TLSCert, config.Server.TLSKey, m))
}
log.Fatal(http.ListenAndServe(l, m))
}
package signer
import (
"crypto/rand"
"fmt"
"io/ioutil"
"time"
"github.com/nsheridan/cashier/lib"
"github.com/nsheridan/cashier/server/config"
"golang.org/x/crypto/ssh"
)
type KeySigner struct {
ca ssh.Signer
validity time.Duration
principals []string
permissions map[string]string
}
func (s *KeySigner) Sign(req *lib.SignRequest) (string, error) {
pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(req.Key))
if err != nil {
return "", err
}
expires := time.Now().Add(s.validity)
if req.ValidUntil.After(expires) {
req.ValidUntil = expires
}
cert := &ssh.Certificate{
CertType: ssh.UserCert,
Key: pubkey,
KeyId: req.Principal,
ValidBefore: uint64(req.ValidUntil.Unix()),
ValidAfter: uint64(time.Now().Add(-5 * time.Minute).Unix()),
}
cert.ValidPrincipals = append(cert.ValidPrincipals, req.Principal)
cert.ValidPrincipals = append(cert.ValidPrincipals, s.principals...)
cert.Extensions = s.permissions
if err := cert.SignCert(rand.Reader, s.ca); err != nil {
return "", err
}
marshaled := ssh.MarshalAuthorizedKey(cert)
// Remove the trailing newline.
marshaled = marshaled[:len(marshaled)-1]
return string(marshaled), nil
}
func makeperms(perms []string) map[string]string {
if len(perms) > 0 {
m := make(map[string]string)
for _, p := range perms {
m[p] = ""
}
return m
}
return map[string]string{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
}
}
func NewSigner(conf config.SSH) (*KeySigner, error) {
data, err := ioutil.ReadFile(conf.SigningKey)
if err != nil {
return nil, fmt.Errorf("unable to read CA key %s: %v", conf.SigningKey, err)
}
key, err := ssh.ParsePrivateKey(data)
if err != nil {
return nil, fmt.Errorf("unable to parse CA key: %v", err)
}
validity, err := time.ParseDuration(conf.MaxAge)
if err != nil {
return nil, fmt.Errorf("error parsing duration '%s': %v", conf.MaxAge, err)
}
return &KeySigner{
ca: key,
validity: validity,
principals: conf.Principals,
permissions: makeperms(conf.Permissions),
}, nil
}
<html>
<head>
<title>YOUR TOKEN!</title>
<style>
<!--
body {
text-align: center;
font-family: sans-serif;
background-color: #edece4;
margin-top: 120px;
}
.code {
background-color: #26292B;
border: none;
color: #fff;
font-family: monospace;
font-size: 13;
font-weight: bold;
height: auto;
margin: 12px 12px 12px 12px;
padding: 12px 12px 12px 12px;
resize: none;
text-align: center;
width: 960px;
}
::selection {
background: #32d0ff;
color: #000;
}
::-moz-selection {
background: #32d0ff;
color: #000;
}
-->
</style>
</head>
<body>
<h2>
This is your token. There are many like it but this one is yours.
</h2>
<textarea class="code" readonly spellcheck="false" onclick="this.focus();this.select();">{{.Token}}</textarea>
<h2>
The token will expire in &lt; 1 hour.
</h2>
</body>
</html>
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment