KDF for server password. Save salt in DB.
This commit is contained in:
parent
dbfdff167b
commit
4430013bae
13 changed files with 288 additions and 127 deletions
51
auth/auth.go
51
auth/auth.go
|
@ -2,16 +2,19 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/scrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserId int32
|
type UserId int32
|
||||||
type Email string
|
type Email string
|
||||||
type DeviceId string
|
type DeviceId string
|
||||||
type Password string
|
type Password string
|
||||||
|
type KDFKey string // KDF output
|
||||||
|
type Salt string
|
||||||
type TokenString string
|
type TokenString string
|
||||||
type AuthScope string
|
type AuthScope string
|
||||||
|
|
||||||
|
@ -62,8 +65,46 @@ func (at *AuthToken) ScopeValid(required AuthScope) bool {
|
||||||
return at.Scope == ScopeFull || at.Scope == required
|
return at.Scope == ScopeFull || at.Scope == required
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p Password) Obfuscate() string {
|
const SaltLength = 8
|
||||||
// TODO KDF instead
|
|
||||||
hash := sha256.Sum256([]byte(p))
|
// https://words.filippo.io/the-scrypt-parameters/
|
||||||
return hex.EncodeToString(hash[:])
|
func passwordScrypt(p Password, saltBytes []byte) ([]byte, error) {
|
||||||
|
scryptN := 32768
|
||||||
|
scryptR := 8
|
||||||
|
scryptP := 1
|
||||||
|
keyLen := 32
|
||||||
|
return scrypt.Key([]byte(p), saltBytes, scryptN, scryptR, scryptP, keyLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a password (in the same format submitted via request), generate a
|
||||||
|
// random salt, run the password and salt thorugh the KDF, and return the salt
|
||||||
|
// and kdf output. The result generally goes into a database.
|
||||||
|
func (p Password) Create() (key KDFKey, salt Salt, err error) {
|
||||||
|
saltBytes := make([]byte, SaltLength)
|
||||||
|
if _, err := rand.Read(saltBytes); err != nil {
|
||||||
|
return "", "", fmt.Errorf("Error generating salt: %+v", err)
|
||||||
|
}
|
||||||
|
keyBytes, err := passwordScrypt(p, saltBytes)
|
||||||
|
if err == nil {
|
||||||
|
key = KDFKey(hex.EncodeToString(keyBytes[:]))
|
||||||
|
salt = Salt(hex.EncodeToString(saltBytes[:]))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a password (in the same format submitted via request), a salt, and an
|
||||||
|
// expected kdf output, run the password and salt thorugh the KDF, and return
|
||||||
|
// whether the result kdf output matches the kdf test output.
|
||||||
|
// The salt and test kdf output generally come out of the database, and is used
|
||||||
|
// to check a submitted password.
|
||||||
|
func (p Password) Check(checkKey KDFKey, salt Salt) (match bool, err error) {
|
||||||
|
saltBytes, err := hex.DecodeString(string(salt))
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("Error decoding salt from hex: %+v", err)
|
||||||
|
}
|
||||||
|
keyBytes, err := passwordScrypt(p, saltBytes)
|
||||||
|
if err == nil {
|
||||||
|
match = KDFKey(hex.EncodeToString(keyBytes[:])) == checkKey
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,3 +53,73 @@ func TestAuthScopeInvalid(t *testing.T) {
|
||||||
t.Fatalf("Expected banana to be an invalid scope for carrot")
|
t.Fatalf("Expected banana to be an invalid scope for carrot")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreatePassword(t *testing.T) {
|
||||||
|
// Since the salt is randomized, there's really not much we can do to test
|
||||||
|
// the create function other than to check the length of the outputs and that
|
||||||
|
// they're different each time.
|
||||||
|
|
||||||
|
const password = Password("password")
|
||||||
|
|
||||||
|
key1, salt1, err := password.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Error creating password")
|
||||||
|
}
|
||||||
|
if len(key1) != 64 {
|
||||||
|
t.Error("Key has wrong length", key1)
|
||||||
|
}
|
||||||
|
if len(salt1) != 16 {
|
||||||
|
t.Error("Salt has wrong length", salt1)
|
||||||
|
}
|
||||||
|
|
||||||
|
key2, salt2, err := password.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Error creating password")
|
||||||
|
}
|
||||||
|
if key1 == key2 {
|
||||||
|
t.Error("Key is not random", key1)
|
||||||
|
}
|
||||||
|
if salt1 == salt2 {
|
||||||
|
t.Error("Salt is not random", key1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPassword(t *testing.T) {
|
||||||
|
const password = Password("password 1")
|
||||||
|
const key = KDFKey("b9a3669973fcd2da3625e84da9d9a2da87bd280bcb02586851e1cb5bee1efa10")
|
||||||
|
const salt = Salt("080cbdf6d247c665")
|
||||||
|
|
||||||
|
match, err := password.Check(key, salt)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Error checking password")
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
t.Error("Expected password to match correct key and salt")
|
||||||
|
}
|
||||||
|
|
||||||
|
const wrongKey = KDFKey("0000000073fcd2da3625e84da9d9a2da87bd280bcb02586851e1cb5bee1efa10")
|
||||||
|
match, err = password.Check(wrongKey, salt)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Error checking password")
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
t.Error("Expected password to not match incorrect key")
|
||||||
|
}
|
||||||
|
|
||||||
|
const wrongSalt = Salt("00000000d247c665")
|
||||||
|
match, err = password.Check(key, wrongSalt)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Error checking password")
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
t.Error("Expected password to not match incorrect salt")
|
||||||
|
}
|
||||||
|
|
||||||
|
const invalidSalt = Salt("Whoops")
|
||||||
|
match, err = password.Check(key, invalidSalt)
|
||||||
|
if err == nil {
|
||||||
|
// It does a decode of salt inside the function but not the key so we won't
|
||||||
|
// test invalid hex string with that
|
||||||
|
t.Error("Expected password check to fail with invalid salt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
8
go.mod
8
go.mod
|
@ -3,3 +3,11 @@ module lbryio/lbry-id
|
||||||
go 1.17
|
go 1.17
|
||||||
|
|
||||||
require github.com/mattn/go-sqlite3 v1.14.9
|
require github.com/mattn/go-sqlite3 v1.14.9
|
||||||
|
|
||||||
|
require (
|
||||||
|
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
|
||||||
|
golang.org/x/net v0.0.0-20220708220712-1185a9018129 // indirect
|
||||||
|
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e // indirect
|
||||||
|
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 // indirect
|
||||||
|
golang.org/x/text v0.3.7 // indirect
|
||||||
|
)
|
||||||
|
|
10
go.sum
10
go.sum
|
@ -1,2 +1,12 @@
|
||||||
github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA=
|
github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA=
|
||||||
github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||||
|
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY=
|
||||||
|
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
|
golang.org/x/net v0.0.0-20220708220712-1185a9018129 h1:vucSRfWwTsoXro7P+3Cjlr6flUMtzCwzlvkxEQtHHB0=
|
||||||
|
golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
|
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e h1:NHvCuwuS43lGnYhten69ZWqi2QOj/CiDNcKbVqwVoew=
|
||||||
|
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM=
|
||||||
|
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
|
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||||
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
|
|
1
main.go
1
main.go
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"lbryio/lbry-id/auth"
|
"lbryio/lbry-id/auth"
|
||||||
"lbryio/lbry-id/server"
|
"lbryio/lbry-id/server"
|
||||||
"lbryio/lbry-id/store"
|
"lbryio/lbry-id/store"
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"lbryio/lbry-id/auth"
|
"lbryio/lbry-id/auth"
|
||||||
"lbryio/lbry-id/store"
|
"lbryio/lbry-id/store"
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"lbryio/lbry-id/auth"
|
"lbryio/lbry-id/auth"
|
||||||
"lbryio/lbry-id/store"
|
"lbryio/lbry-id/store"
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,10 +7,11 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"lbryio/lbry-id/auth"
|
|
||||||
"lbryio/lbry-id/store"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"lbryio/lbry-id/auth"
|
||||||
|
"lbryio/lbry-id/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServerAuthHandlerSuccess(t *testing.T) {
|
func TestServerAuthHandlerSuccess(t *testing.T) {
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
|
|
||||||
"lbryio/lbry-id/auth"
|
"lbryio/lbry-id/auth"
|
||||||
"lbryio/lbry-id/store"
|
"lbryio/lbry-id/store"
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"lbryio/lbry-id/auth"
|
"lbryio/lbry-id/auth"
|
||||||
"lbryio/lbry-id/store"
|
"lbryio/lbry-id/store"
|
||||||
"lbryio/lbry-id/wallet"
|
"lbryio/lbry-id/wallet"
|
||||||
|
|
|
@ -10,40 +10,44 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func expectAccountMatch(t *testing.T, s *Store, email auth.Email, password auth.Password) {
|
func expectAccountMatch(t *testing.T, s *Store, email auth.Email, password auth.Password) {
|
||||||
rows, err := s.db.Query(
|
var key auth.KDFKey
|
||||||
`SELECT 1 from accounts WHERE email=? AND password=?`,
|
var salt auth.Salt
|
||||||
email, password.Obfuscate(),
|
|
||||||
)
|
err := s.db.QueryRow(
|
||||||
|
`SELECT key, salt from accounts WHERE email=?`,
|
||||||
|
email,
|
||||||
|
).Scan(&key, &salt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error finding account for: %s %s - %+v", email, password, err)
|
t.Fatalf("Error finding account for: %s %s - %+v", email, password, err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
for rows.Next() {
|
match, err := password.Check(key, salt)
|
||||||
return // found something, we're good
|
if err != nil {
|
||||||
|
t.Fatalf("Error checking password for: %s %s - %+v", email, password, err)
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
t.Fatalf("Expected account for: %s %s", email, password)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Fatalf("Expected account for: %s %s", email, password)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func expectAccountNotMatch(t *testing.T, s *Store, email auth.Email, password auth.Password) {
|
func expectAccountNotExists(t *testing.T, s *Store, email auth.Email) {
|
||||||
rows, err := s.db.Query(
|
rows, err := s.db.Query(
|
||||||
`SELECT 1 from accounts WHERE email=? AND password=?`,
|
`SELECT 1 from accounts WHERE email=?`,
|
||||||
email, password.Obfuscate(),
|
email,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error finding account for: %s %s - %+v", email, password, err)
|
t.Fatalf("Error finding account for: %s - %+v", email, err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
t.Fatalf("Expected no account for: %s %s", email, password)
|
t.Fatalf("Expected no account for: %s", email)
|
||||||
}
|
}
|
||||||
|
|
||||||
// found nothing, we're good
|
// found nothing, we're good
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test CreateAccount, using GetUserId as a helper
|
// Test CreateAccount
|
||||||
// Try CreateAccount twice with the same email and different password, error the second time
|
// Try CreateAccount twice with the same email and different password, error the second time
|
||||||
func TestStoreCreateAccount(t *testing.T) {
|
func TestStoreCreateAccount(t *testing.T) {
|
||||||
s, sqliteTmpFile := StoreTestInit(t)
|
s, sqliteTmpFile := StoreTestInit(t)
|
||||||
|
@ -52,7 +56,7 @@ func TestStoreCreateAccount(t *testing.T) {
|
||||||
email, password := auth.Email("abc@example.com"), auth.Password("123")
|
email, password := auth.Email("abc@example.com"), auth.Password("123")
|
||||||
|
|
||||||
// Get an account, come back empty
|
// Get an account, come back empty
|
||||||
expectAccountNotMatch(t, &s, email, password)
|
expectAccountNotExists(t, &s, email)
|
||||||
|
|
||||||
// Create an account
|
// Create an account
|
||||||
if err := s.CreateAccount(email, password); err != nil {
|
if err := s.CreateAccount(email, password); err != nil {
|
||||||
|
@ -70,14 +74,12 @@ func TestStoreCreateAccount(t *testing.T) {
|
||||||
t.Fatalf(`CreateAccount err: wanted "%+v", got "%+v"`, ErrDuplicateAccount, err)
|
t.Fatalf(`CreateAccount err: wanted "%+v", got "%+v"`, ErrDuplicateAccount, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the email and same *first* password we successfully put in, but not the second
|
// Get the email and same *first* password we successfully put in
|
||||||
expectAccountMatch(t, &s, email, password)
|
expectAccountMatch(t, &s, email, password)
|
||||||
expectAccountNotMatch(t, &s, email, newPassword)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test GetUserId, using CreateAccount as a helper
|
// Test GetUserId for nonexisting email
|
||||||
// Try GetUserId before creating an account (fail), and after (succeed)
|
func TestStoreGetUserIdAccountNotExists(t *testing.T) {
|
||||||
func TestStoreGetUserId(t *testing.T) {
|
|
||||||
s, sqliteTmpFile := StoreTestInit(t)
|
s, sqliteTmpFile := StoreTestInit(t)
|
||||||
defer StoreTestCleanup(sqliteTmpFile)
|
defer StoreTestCleanup(sqliteTmpFile)
|
||||||
|
|
||||||
|
@ -85,16 +87,26 @@ func TestStoreGetUserId(t *testing.T) {
|
||||||
|
|
||||||
// Check that there's no user id for email and password first
|
// Check that there's no user id for email and password first
|
||||||
if userId, err := s.GetUserId(email, password); err != ErrWrongCredentials || userId != 0 {
|
if userId, err := s.GetUserId(email, password); err != ErrWrongCredentials || userId != 0 {
|
||||||
t.Fatalf(`CreateAccount err: wanted "%+v", got "%+v. userId: %v"`, ErrWrongCredentials, err, userId)
|
t.Fatalf(`GetUserId error for nonexistant account: wanted "%+v", got "%+v. userId: %v"`, ErrWrongCredentials, err, userId)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create the account
|
// Test GetUserId for existing account, with the correct and incorrect password
|
||||||
_ = s.CreateAccount(email, password)
|
func TestStoreGetUserIdAccountExists(t *testing.T) {
|
||||||
|
s, sqliteTmpFile := StoreTestInit(t)
|
||||||
|
defer StoreTestCleanup(sqliteTmpFile)
|
||||||
|
|
||||||
|
createdUserId, email, password := makeTestUser(t, &s)
|
||||||
|
|
||||||
// Check that there's now a user id for the email and password
|
// Check that there's now a user id for the email and password
|
||||||
if userId, err := s.GetUserId(email, password); err != nil || userId == 0 {
|
if userId, err := s.GetUserId(email, password); err != nil || userId != createdUserId {
|
||||||
t.Fatalf("Unexpected error in GetUserId: err: %+v userId: %v", err, userId)
|
t.Fatalf("Unexpected error in GetUserId: err: %+v userId: %v", err, userId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that it won't return if the wrong password is given
|
||||||
|
if userId, err := s.GetUserId(email, password+auth.Password("_wrong")); err != ErrWrongCredentials || userId != 0 {
|
||||||
|
t.Fatalf(`GetUserId error for wrong password: wanted "%+v", got "%+v. userId: %v"`, ErrWrongCredentials, err, userId)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreAccountEmptyFields(t *testing.T) {
|
func TestStoreAccountEmptyFields(t *testing.T) {
|
||||||
|
@ -109,7 +121,7 @@ func TestStoreAccountEmptyFields(t *testing.T) {
|
||||||
email: "",
|
email: "",
|
||||||
password: "xyz",
|
password: "xyz",
|
||||||
},
|
},
|
||||||
// Not testing empty password because it gets obfuscated to something
|
// Not testing empty key and salt because they get generated to something
|
||||||
// non-empty in the method
|
// non-empty in the method
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
188
store/store.go
188
store/store.go
|
@ -111,11 +111,13 @@ func (s *Store) Migrate() error {
|
||||||
);
|
);
|
||||||
CREATE TABLE IF NOT EXISTS accounts(
|
CREATE TABLE IF NOT EXISTS accounts(
|
||||||
email TEXT NOT NULL UNIQUE,
|
email TEXT NOT NULL UNIQUE,
|
||||||
password TEXT NOT NULL,
|
key TEXT NOT NULL,
|
||||||
|
salt TEXT NOT NULL,
|
||||||
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
CHECK (
|
CHECK (
|
||||||
email <> '' AND
|
email <> '' AND
|
||||||
password <> ''
|
key <> '' AND
|
||||||
|
salt <> ''
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
@ -328,13 +330,23 @@ func (s *Store) SetWallet(userId auth.UserId, encryptedWallet wallet.EncryptedWa
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetUserId(email auth.Email, password auth.Password) (userId auth.UserId, err error) {
|
func (s *Store) GetUserId(email auth.Email, password auth.Password) (userId auth.UserId, err error) {
|
||||||
|
var key auth.KDFKey
|
||||||
|
var salt auth.Salt
|
||||||
err = s.db.QueryRow(
|
err = s.db.QueryRow(
|
||||||
`SELECT user_id from accounts WHERE email=? AND password=?`,
|
`SELECT user_id, key, salt from accounts WHERE email=?`,
|
||||||
email, password.Obfuscate(),
|
email,
|
||||||
).Scan(&userId)
|
).Scan(&userId, &key, &salt)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
err = ErrWrongCredentials
|
err = ErrWrongCredentials
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
match, err := password.Check(key, salt)
|
||||||
|
if err == nil && !match {
|
||||||
|
err = ErrWrongCredentials
|
||||||
|
userId = auth.UserId(0)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,10 +355,14 @@ func (s *Store) GetUserId(email auth.Email, password auth.Password) (userId auth
|
||||||
/////////////
|
/////////////
|
||||||
|
|
||||||
func (s *Store) CreateAccount(email auth.Email, password auth.Password) (err error) {
|
func (s *Store) CreateAccount(email auth.Email, password auth.Password) (err error) {
|
||||||
|
key, salt, err := password.Create()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
// userId auto-increments
|
// userId auto-increments
|
||||||
_, err = s.db.Exec(
|
_, err = s.db.Exec(
|
||||||
"INSERT INTO accounts (email, password) VALUES(?,?)",
|
"INSERT INTO accounts (email, key, salt) VALUES(?,?,?)",
|
||||||
email, password.Obfuscate(),
|
email, key, salt,
|
||||||
)
|
)
|
||||||
|
|
||||||
var sqliteErr sqlite3.Error
|
var sqliteErr sqlite3.Error
|
||||||
|
@ -375,75 +391,14 @@ func (s *Store) ChangePasswordWithWallet(
|
||||||
sequence wallet.Sequence,
|
sequence wallet.Sequence,
|
||||||
hmac wallet.WalletHmac,
|
hmac wallet.WalletHmac,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
var userId auth.UserId
|
return s.changePassword(
|
||||||
|
email,
|
||||||
tx, err := s.db.Begin()
|
oldPassword,
|
||||||
if err != nil {
|
newPassword,
|
||||||
return
|
encryptedWallet,
|
||||||
}
|
sequence,
|
||||||
|
hmac,
|
||||||
// Lots of error conditions. Just defer this. However, we need to make sure to
|
|
||||||
// make sure the variable `err` is set to the error before we return, instead
|
|
||||||
// of doing `return <error>`.
|
|
||||||
endTxn := func() {
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
} else {
|
|
||||||
tx.Commit()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer endTxn()
|
|
||||||
|
|
||||||
err = tx.QueryRow(
|
|
||||||
"SELECT user_id from accounts WHERE email=? AND password=?",
|
|
||||||
email, oldPassword.Obfuscate(),
|
|
||||||
).Scan(&userId)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
err = ErrWrongCredentials
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := tx.Exec(
|
|
||||||
"UPDATE accounts SET password=? WHERE user_id=?",
|
|
||||||
newPassword.Obfuscate(), userId,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
numRows, err := res.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if numRows == 0 {
|
|
||||||
// Very unexpected error!
|
|
||||||
err = fmt.Errorf("Password failed to update")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err = tx.Exec(
|
|
||||||
`UPDATE wallets SET encrypted_wallet=?, sequence=?, hmac=?
|
|
||||||
WHERE user_id=? AND sequence=?`,
|
|
||||||
encryptedWallet, sequence, hmac, userId, sequence-1,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
numRows, err = res.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if numRows == 0 {
|
|
||||||
err = ErrWrongSequence
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Don't care how many I delete here. Might even be zero. No login token while
|
|
||||||
// changing password seems plausible.
|
|
||||||
_, err = tx.Exec("DELETE FROM auth_tokens WHERE user_id=?", userId)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Change password, but with no wallet currently saved. Since there's no
|
// Change password, but with no wallet currently saved. Since there's no
|
||||||
|
@ -456,6 +411,25 @@ func (s *Store) ChangePasswordNoWallet(
|
||||||
email auth.Email,
|
email auth.Email,
|
||||||
oldPassword auth.Password,
|
oldPassword auth.Password,
|
||||||
newPassword auth.Password,
|
newPassword auth.Password,
|
||||||
|
) (err error) {
|
||||||
|
return s.changePassword(
|
||||||
|
email,
|
||||||
|
oldPassword,
|
||||||
|
newPassword,
|
||||||
|
wallet.EncryptedWallet(""),
|
||||||
|
wallet.Sequence(0),
|
||||||
|
wallet.WalletHmac(""),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common code for for WithWallet and WithNoWallet password change functions
|
||||||
|
func (s *Store) changePassword(
|
||||||
|
email auth.Email,
|
||||||
|
oldPassword auth.Password,
|
||||||
|
newPassword auth.Password,
|
||||||
|
encryptedWallet wallet.EncryptedWallet,
|
||||||
|
sequence wallet.Sequence,
|
||||||
|
hmac wallet.WalletHmac,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
var userId auth.UserId
|
var userId auth.UserId
|
||||||
|
|
||||||
|
@ -476,21 +450,35 @@ func (s *Store) ChangePasswordNoWallet(
|
||||||
}
|
}
|
||||||
defer endTxn()
|
defer endTxn()
|
||||||
|
|
||||||
|
var oldKey auth.KDFKey
|
||||||
|
var oldSalt auth.Salt
|
||||||
|
|
||||||
err = tx.QueryRow(
|
err = tx.QueryRow(
|
||||||
"SELECT user_id from accounts WHERE email=? AND password=?",
|
`SELECT user_id, key, salt from accounts WHERE email=?`,
|
||||||
email, oldPassword.Obfuscate(),
|
email,
|
||||||
).Scan(&userId)
|
).Scan(&userId, &oldKey, &oldSalt)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
err = ErrWrongCredentials
|
err = ErrWrongCredentials
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
match, err := oldPassword.Check(oldKey, oldSalt)
|
||||||
|
if err == nil && !match {
|
||||||
|
err = ErrWrongCredentials
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newKey, newSalt, err := newPassword.Create()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := tx.Exec(
|
res, err := tx.Exec(
|
||||||
"UPDATE accounts SET password=? WHERE user_id=?",
|
"UPDATE accounts SET key=?, salt=? WHERE user_id=?",
|
||||||
newPassword.Obfuscate(), userId,
|
newKey, newSalt, userId,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -505,17 +493,39 @@ func (s *Store) ChangePasswordNoWallet(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assert we have no wallet for this version of the password change function.
|
if encryptedWallet != "" {
|
||||||
var dummy string
|
// With a wallet expected: update it.
|
||||||
err = tx.QueryRow("SELECT 1 FROM wallets WHERE user_id=?", userId).Scan(&dummy)
|
|
||||||
if err != sql.ErrNoRows {
|
res, err = tx.Exec(
|
||||||
if err == nil {
|
`UPDATE wallets SET encrypted_wallet=?, sequence=?, hmac=?
|
||||||
// We expected no rows
|
WHERE user_id=? AND sequence=?`,
|
||||||
err = ErrUnexpectedWallet
|
encryptedWallet, sequence, hmac, userId, sequence-1,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
numRows, err = res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if numRows == 0 {
|
||||||
|
err = ErrWrongSequence
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// With no wallet expected: assert we have no wallet.
|
||||||
|
|
||||||
|
var dummy string
|
||||||
|
err = tx.QueryRow("SELECT 1 FROM wallets WHERE user_id=?", userId).Scan(&dummy)
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
if err == nil {
|
||||||
|
// We expected no rows
|
||||||
|
err = ErrUnexpectedWallet
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Some other error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Some other error
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't care how many I delete here. Might even be zero. No login token while
|
// Don't care how many I delete here. Might even be zero. No login token while
|
||||||
|
|
|
@ -35,22 +35,26 @@ func StoreTestCleanup(tmpFile *os.File) {
|
||||||
|
|
||||||
func makeTestUser(t *testing.T, s *Store) (userId auth.UserId, email auth.Email, password auth.Password) {
|
func makeTestUser(t *testing.T, s *Store) (userId auth.UserId, email auth.Email, password auth.Password) {
|
||||||
email, password = auth.Email("abc@example.com"), auth.Password("123")
|
email, password = auth.Email("abc@example.com"), auth.Password("123")
|
||||||
|
key, salt, err := password.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating password")
|
||||||
|
}
|
||||||
|
|
||||||
rows, err := s.db.Query(
|
rows, err := s.db.Query(
|
||||||
"INSERT INTO accounts (email, password) values(?,?) returning user_id",
|
"INSERT INTO accounts (email, key, salt) values(?,?,?) returning user_id",
|
||||||
email, password.Obfuscate(),
|
email, key, salt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error setting up account")
|
t.Fatalf("Error setting up account: %+v", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
err := rows.Scan(&userId)
|
err := rows.Scan(&userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error setting up account")
|
t.Fatalf("Error setting up account: %+v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.Fatalf("Error setting up account")
|
t.Fatalf("Error setting up account - no rows found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue