Actually put in foreign key constraints! Also test wallet and account empty db fields.
This commit is contained in:
parent
fac36a7931
commit
a37b64faad
5 changed files with 135 additions and 39 deletions
|
@ -1,8 +1,11 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
|
||||
"orblivion/lbry-id/auth"
|
||||
)
|
||||
|
||||
|
@ -94,8 +97,36 @@ func TestStoreGetUserId(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO - Tests each db method. Check for missing "NOT NULL" fields. Do the loop thing, and always just check for null error.
|
||||
func TestStoreAccountEmptyFields(t *testing.T) {
|
||||
// Make sure expiration doesn't get set if sanitization fails
|
||||
t.Fatalf("Test me")
|
||||
tt := []struct {
|
||||
name string
|
||||
email auth.Email
|
||||
password auth.Password
|
||||
}{
|
||||
{
|
||||
name: "missing email",
|
||||
email: "",
|
||||
password: "xyz",
|
||||
},
|
||||
// Not testing empty password because it gets obfuscated to something
|
||||
// non-empty in the method
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s, sqliteTmpFile := StoreTestInit(t)
|
||||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
var sqliteErr sqlite3.Error
|
||||
|
||||
err := s.CreateAccount(tc.email, tc.password)
|
||||
if errors.As(err, &sqliteErr) {
|
||||
if errors.Is(sqliteErr.ExtendedCode, sqlite3.ErrConstraintCheck) {
|
||||
return // We got the error we expected
|
||||
}
|
||||
}
|
||||
t.Errorf("Expected check constraint error for empty field. Got %+v", err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ type Store struct {
|
|||
}
|
||||
|
||||
func (s *Store) Init(fileName string) {
|
||||
db, err := sql.Open("sqlite3", fileName)
|
||||
db, err := sql.Open("sqlite3", "file:"+fileName+"?_foreign_keys=on")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -80,9 +80,9 @@ func (s *Store) Migrate() error {
|
|||
CHECK (
|
||||
-- should eventually fail for foreign key constraint instead
|
||||
user_id <> 0 AND
|
||||
device_id <> '' AND
|
||||
|
||||
token <> '' AND
|
||||
device_id <> '' AND
|
||||
scope <> '' AND
|
||||
|
||||
-- Don't know when it uses either format to denote UTC
|
||||
|
@ -91,6 +91,7 @@ func (s *Store) Migrate() error {
|
|||
|
||||
),
|
||||
PRIMARY KEY (user_id, device_id)
|
||||
FOREIGN KEY (user_id) REFERENCES accounts(user_id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS wallets(
|
||||
user_id INTEGER NOT NULL,
|
||||
|
@ -99,11 +100,21 @@ func (s *Store) Migrate() error {
|
|||
hmac TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id)
|
||||
FOREIGN KEY (user_id) REFERENCES accounts(user_id)
|
||||
CHECK (
|
||||
user_id <> 0 AND
|
||||
encrypted_wallet <> '' AND
|
||||
hmac <> '' AND
|
||||
sequence <> 0
|
||||
)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS accounts(
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
password TEXT NOT NULL,
|
||||
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
password TEXT NOT NULL
|
||||
CHECK (
|
||||
email <> '' AND
|
||||
password <> ''
|
||||
)
|
||||
);
|
||||
`
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"orblivion/lbry-id/auth"
|
||||
)
|
||||
|
||||
func StoreTestInit(t *testing.T) (s Store, tmpFile *os.File) {
|
||||
|
@ -30,3 +32,26 @@ func StoreTestCleanup(tmpFile *os.File) {
|
|||
os.Remove(tmpFile.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestUserId(t *testing.T, s *Store) auth.UserId {
|
||||
email, password := auth.Email("abc@example.com"), auth.Password("123")
|
||||
|
||||
rows, err := s.db.Query(
|
||||
"INSERT INTO accounts (email, password) values(?,?) returning user_id",
|
||||
email, password.Obfuscate(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Error setting up account")
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var userId auth.UserId
|
||||
err := rows.Scan(&userId)
|
||||
if err != nil {
|
||||
t.Fatalf("Error setting up account")
|
||||
}
|
||||
return userId
|
||||
}
|
||||
t.Fatalf("Error setting up account")
|
||||
return auth.UserId(0)
|
||||
}
|
||||
|
|
|
@ -77,12 +77,14 @@ func TestStoreInsertToken(t *testing.T) {
|
|||
s, sqliteTmpFile := StoreTestInit(t)
|
||||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
userId := makeTestUserId(t, &s)
|
||||
|
||||
// created for addition to the DB (no expiration attached)
|
||||
authToken1 := auth.AuthToken{
|
||||
Token: "seekrit-1",
|
||||
DeviceId: "dId",
|
||||
Scope: "*",
|
||||
UserId: 123,
|
||||
UserId: userId,
|
||||
}
|
||||
expiration := time.Now().Add(time.Hour * 24 * 14).UTC()
|
||||
|
||||
|
@ -121,12 +123,14 @@ func TestStoreUpdateToken(t *testing.T) {
|
|||
s, sqliteTmpFile := StoreTestInit(t)
|
||||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
userId := makeTestUserId(t, &s)
|
||||
|
||||
// created for addition to the DB (no expiration attached)
|
||||
authTokenUpdate := auth.AuthToken{
|
||||
Token: "seekrit-update",
|
||||
DeviceId: "dId",
|
||||
Scope: "*",
|
||||
UserId: 123,
|
||||
UserId: userId,
|
||||
}
|
||||
expiration := time.Now().Add(time.Hour * 24 * 14).UTC()
|
||||
|
||||
|
@ -177,13 +181,15 @@ func TestStoreSaveToken(t *testing.T) {
|
|||
s, sqliteTmpFile := StoreTestInit(t)
|
||||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
userId := makeTestUserId(t, &s)
|
||||
|
||||
// Version 1 of the token for both devices
|
||||
// created for addition to the DB (no expiration attached)
|
||||
authToken_d1_1 := auth.AuthToken{
|
||||
Token: "seekrit-d1-1",
|
||||
DeviceId: "dId-1",
|
||||
Scope: "*",
|
||||
UserId: 123,
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
authToken_d2_1 := authToken_d1_1
|
||||
|
@ -256,12 +262,14 @@ func TestStoreGetToken(t *testing.T) {
|
|||
s, sqliteTmpFile := StoreTestInit(t)
|
||||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
userId := makeTestUserId(t, &s)
|
||||
|
||||
// created for addition to the DB (no expiration attached)
|
||||
authToken := auth.AuthToken{
|
||||
Token: "seekrit-d1",
|
||||
DeviceId: "dId",
|
||||
Scope: "*",
|
||||
UserId: 123,
|
||||
UserId: userId,
|
||||
}
|
||||
expiration := time.Time(time.Now().UTC().Add(time.Hour * 24 * 14))
|
||||
|
||||
|
@ -307,11 +315,13 @@ func TestStoreTokenUTC(t *testing.T) {
|
|||
s, sqliteTmpFile := StoreTestInit(t)
|
||||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
userId := makeTestUserId(t, &s)
|
||||
|
||||
authToken := auth.AuthToken{
|
||||
Token: "seekrit-1",
|
||||
DeviceId: "dId",
|
||||
Scope: "*",
|
||||
UserId: 123,
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
if err := s.SaveToken(&authToken); err != nil {
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
|
||||
"orblivion/lbry-id/auth"
|
||||
"orblivion/lbry-id/wallet"
|
||||
)
|
||||
|
@ -76,29 +79,6 @@ func expectWalletNotExists(t *testing.T, s *Store, userId auth.UserId) {
|
|||
return // found nothing, we're good
|
||||
}
|
||||
|
||||
func setupWalletTest(t *testing.T, s *Store) auth.UserId {
|
||||
email, password := auth.Email("abc@example.com"), auth.Password("123")
|
||||
|
||||
rows, err := s.db.Query(
|
||||
"INSERT INTO accounts (email, password) values(?,?) returning user_id",
|
||||
email, password.Obfuscate(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Error setting up account")
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var userId auth.UserId
|
||||
err := rows.Scan(&userId)
|
||||
if err != nil {
|
||||
t.Fatalf("Error setting up account")
|
||||
}
|
||||
return userId
|
||||
}
|
||||
t.Fatalf("Error setting up account")
|
||||
return auth.UserId(0)
|
||||
}
|
||||
|
||||
// Test insertFirstWallet, using GetWallet, CreateAccount and GetUserID as a helpers
|
||||
// Try insertFirstWallet twice with the same user id, error the second time
|
||||
func TestStoreInsertWallet(t *testing.T) {
|
||||
|
@ -106,7 +86,7 @@ func TestStoreInsertWallet(t *testing.T) {
|
|||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
// Get a valid userId
|
||||
userId := setupWalletTest(t, &s)
|
||||
userId := makeTestUserId(t, &s)
|
||||
|
||||
// Get a wallet, come back empty
|
||||
expectWalletNotExists(t, &s, userId)
|
||||
|
@ -138,7 +118,7 @@ func TestStoreUpdateWallet(t *testing.T) {
|
|||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
// Get a valid userId
|
||||
userId := setupWalletTest(t, &s)
|
||||
userId := makeTestUserId(t, &s)
|
||||
|
||||
// Try to update a wallet, fail for nothing to update
|
||||
if err := s.updateWalletToSequence(userId, wallet.EncryptedWallet("my-enc-wallet-a"), wallet.Sequence(1), wallet.WalletHmac("my-hmac-a")); err != ErrNoWallet {
|
||||
|
@ -194,7 +174,7 @@ func TestStoreSetWallet(t *testing.T) {
|
|||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
// Get a valid userId
|
||||
userId := setupWalletTest(t, &s)
|
||||
userId := makeTestUserId(t, &s)
|
||||
|
||||
// Sequence 2 - fails - out of sequence (behind the scenes, tries to update but there's nothing there yet)
|
||||
if err := s.SetWallet(userId, wallet.EncryptedWallet("my-enc-wallet-a"), wallet.Sequence(2), wallet.WalletHmac("my-hmac-a")); err != ErrWrongSequence {
|
||||
|
@ -241,7 +221,7 @@ func TestStoreGetWallet(t *testing.T) {
|
|||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
// Get a valid userId
|
||||
userId := setupWalletTest(t, &s)
|
||||
userId := makeTestUserId(t, &s)
|
||||
|
||||
// GetWallet fails when there's no wallet
|
||||
encryptedWallet, sequence, hmac, err := s.GetWallet(userId)
|
||||
|
@ -260,8 +240,47 @@ func TestStoreGetWallet(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO - Tests each db method. Check for missing "NOT NULL" fields. Do the loop thing, and always just check for null error.
|
||||
func TestStoreWalletEmptyFields(t *testing.T) {
|
||||
// Make sure expiration doesn't get set if sanitization fails
|
||||
t.Fatalf("Test me")
|
||||
tt := []struct {
|
||||
name string
|
||||
userId auth.UserId
|
||||
encryptedWallet wallet.EncryptedWallet
|
||||
hmac wallet.WalletHmac
|
||||
}{
|
||||
{
|
||||
name: "missing user id",
|
||||
userId: auth.UserId(0),
|
||||
encryptedWallet: wallet.EncryptedWallet("my-enc-wallet"),
|
||||
hmac: wallet.WalletHmac("my-hmac"),
|
||||
}, {
|
||||
name: "missing encrypted wallet",
|
||||
userId: auth.UserId(1),
|
||||
encryptedWallet: wallet.EncryptedWallet(""),
|
||||
hmac: wallet.WalletHmac("my-hmac"),
|
||||
}, {
|
||||
name: "missing hmac",
|
||||
userId: auth.UserId(1),
|
||||
encryptedWallet: wallet.EncryptedWallet("my-enc-wallet"),
|
||||
hmac: wallet.WalletHmac(""),
|
||||
},
|
||||
// Not testing 0 sequence because the method basically doesn't allow for it.
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s, sqliteTmpFile := StoreTestInit(t)
|
||||
defer StoreTestCleanup(sqliteTmpFile)
|
||||
|
||||
var sqliteErr sqlite3.Error
|
||||
|
||||
err := s.insertFirstWallet(tc.userId, tc.encryptedWallet, tc.hmac)
|
||||
if errors.As(err, &sqliteErr) {
|
||||
if errors.Is(sqliteErr.ExtendedCode, sqlite3.ErrConstraintCheck) {
|
||||
return // We got the error we expected
|
||||
}
|
||||
}
|
||||
t.Errorf("Expected check constraint error for empty field. Got %+v", err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue