package store

import (
	"crypto/rand"
	"crypto/rsa"
	"database/sql"
	"encoding/json"
	"io/ioutil"
	"os"
	"os/user"
	"strings"
	"testing"
	"time"

	"github.com/nsheridan/cashier/server/store/types"
	"github.com/nsheridan/cashier/testdata"
	"github.com/stretchr/testify/assert"

	"golang.org/x/crypto/ssh"
)

// TestParseCertificate signs a throwaway user certificate and checks that
// parseCertificate copies its key ID, principals and validity window into the
// resulting record.
func TestParseCertificate(t *testing.T) {
	t.Parallel()
	a := assert.New(t)
	now := uint64(time.Now().Unix())

	// Every setup step below feeds the next one, so a failure must stop the
	// test instead of surfacing later as a nil dereference.
	r, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		t.Fatalf("generating RSA key: %v", err)
	}
	pub, err := ssh.NewPublicKey(r.Public())
	if err != nil {
		t.Fatalf("building SSH public key: %v", err)
	}
	c := &ssh.Certificate{
		KeyId:           "id",
		ValidPrincipals: types.StringSlice{"principal"},
		ValidBefore:     now,
		CertType:        ssh.UserCert,
		Key:             pub,
	}
	s, err := ssh.NewSignerFromKey(r)
	if err != nil {
		t.Fatalf("building signer: %v", err)
	}
	if err := c.SignCert(rand.Reader, s); err != nil {
		t.Fatalf("signing certificate: %v", err)
	}
	rec := parseCertificate(c)

	a.Equal(c.KeyId, rec.KeyID)
	a.Equal(c.ValidPrincipals, []string(rec.Principals))
	a.Equal(c.ValidBefore, uint64(rec.Expires.Unix()))
	a.Equal(c.ValidAfter, uint64(rec.CreatedAt.Unix()))
}

// testStore runs the shared checks against a CertStorer implementation and
// closes the store when it returns.
//
// In order it exercises: SetRecord with an already-expired record, List with
// and without expired entries, SetCert with the certificate parsed from
// testdata.Cert (re-dated to be currently valid), Get, Revoke and GetRevoked.
// Steps whose result is dereferenced afterwards abort the test on failure
// rather than panicking.
func testStore(t *testing.T, db CertStorer) {
	defer db.Close()

	r := &CertRecord{
		KeyID:      "a",
		Principals: []string{"b"},
		CreatedAt:  time.Now().UTC(),
		Expires:    time.Now().UTC().Add(-1 * time.Second),
		Raw:        "AAAAAA",
	}
	if err := db.SetRecord(r); err != nil {
		t.Error(err)
	}

	// includeExpired = false should return 0 results
	recs, err := db.List(false)
	if err != nil {
		t.Error(err)
	}
	if len(recs) > 0 {
		t.Errorf("Expected 0 results, got %d", len(recs))
	}
	// includeExpired = true should return 1 result
	recs, err = db.List(true)
	if err != nil {
		t.Error(err)
	}
	// Guard the index below: an empty list is a test failure, not a panic.
	if len(recs) == 0 {
		t.Fatal("Expected 1 result, got 0")
	}
	if recs[0].KeyID != r.KeyID {
		t.Error("key mismatch")
	}

	c, _, _, _, err := ssh.ParseAuthorizedKey(testdata.Cert)
	if err != nil {
		t.Fatalf("parsing test certificate: %v", err)
	}
	cert, ok := c.(*ssh.Certificate)
	if !ok {
		t.Fatalf("test key is %T, not *ssh.Certificate", c)
	}
	// Move the validity window around "now" so the stored cert is not expired.
	cert.ValidBefore = uint64(time.Now().Add(1 * time.Hour).UTC().Unix())
	cert.ValidAfter = uint64(time.Now().Add(-5 * time.Minute).UTC().Unix())
	if err := db.SetCert(cert); err != nil {
		t.Error(err)
	}

	ret, err := db.Get("key")
	if err != nil {
		// ret cannot be trusted after an error, so stop before reading it.
		t.Fatal(err)
	}
	if ret.KeyID != cert.KeyId {
		t.Error("key mismatch")
	}
	if err := db.Revoke("key"); err != nil {
		t.Error(err)
	}

	revoked, err := db.GetRevoked()
	if err != nil {
		t.Error(err)
	}
	for _, k := range revoked {
		if k.KeyID != "key" {
			t.Errorf("Unexpected key: %s", k.KeyID)
		}
	}
}

func TestMemoryStore(t *testing.T) {
	t.Parallel()
	db := NewMemoryStore()
	testStore(t, db)
}

// TestMySQLStore runs the shared store checks against a MySQL backend.
//
// The test is skipped unless MYSQL_TEST is set. Connection settings come from
// MYSQL_TEST_HOST and MYSQL_TEST_PASS; the username is MYSQL_TEST_USER when
// that variable is present, otherwise the current OS user's name.
func TestMySQLStore(t *testing.T) {
	t.Parallel()
	if os.Getenv("MYSQL_TEST") == "" {
		t.Skip("No MYSQL_TEST environment variable")
	}
	sqlConfig := map[string]string{
		"type":     "mysql",
		"password": os.Getenv("MYSQL_TEST_PASS"),
		"address":  os.Getenv("MYSQL_TEST_HOST"),
	}
	if testUser, ok := os.LookupEnv("MYSQL_TEST_USER"); ok {
		sqlConfig["username"] = testUser
	} else {
		// Only look up the OS user when it is actually needed, and do not
		// dereference the result if the lookup failed.
		u, err := user.Current()
		if err != nil {
			t.Fatalf("looking up current user: %v", err)
		}
		sqlConfig["username"] = u.Username
	}
	db, err := NewSQLStore(sqlConfig)
	if err != nil {
		// testStore immediately defers db.Close(), so a failed store must
		// not be passed on.
		t.Fatal(err)
	}
	testStore(t, db)
}

// TestSQLiteStore runs the shared store checks against a SQLite backend.
//
// It creates a temporary database file, applies every statement from
// ../../db/seed.sql that contains "CREATE TABLE", then opens the file through
// NewSQLStore. The temporary file is removed when the test ends.
func TestSQLiteStore(t *testing.T) {
	t.Parallel()
	f, err := ioutil.TempFile("", "sqlite_test_db")
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(f.Name())
	// Only the path is used from here on; release the descriptor now.
	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	seed, err := ioutil.ReadFile("../../db/seed.sql")
	if err != nil {
		t.Fatal(err)
	}
	d, err := sql.Open("sqlite3", f.Name())
	if err != nil {
		t.Fatal(err)
	}
	for _, stmt := range strings.Split(string(seed), ";") {
		// Other statements in the seed file are skipped on purpose.
		if !strings.Contains(stmt, "CREATE TABLE") {
			continue
		}
		if _, err := d.Exec(stmt); err != nil {
			d.Close()
			t.Fatalf("creating schema: %v", err)
		}
	}
	if err := d.Close(); err != nil {
		t.Fatal(err)
	}

	config := map[string]string{"type": "sqlite", "filename": f.Name()}
	db, err := NewSQLStore(config)
	if err != nil {
		// testStore immediately defers db.Close(), so a failed store must
		// not be passed on.
		t.Fatal(err)
	}
	testStore(t, db)
}

func TestMarshalCert(t *testing.T) {
	a := assert.New(t)
	c := &CertRecord{
		KeyID:      "id",
		Principals: []string{"user"},
		CreatedAt:  time.Date(2017, time.April, 10, 13, 0, 0, 0, time.UTC),
		Expires:    time.Date(2017, time.April, 11, 10, 0, 0, 0, time.UTC),
		Raw:        "ABCDEF",
	}
	b, err := json.Marshal(c)
	if err != nil {
		t.Error(err)
	}
	want := `{"key_id":"id","principals":["user"],"revoked":false,"created_at":"2017-04-10 13:00:00 +0000","expires":"2017-04-11 10:00:00 +0000"}`
	a.JSONEq(want, string(b))
}