Actually put in foreign key constraints! Also test wallet and account empty db fields.

This commit is contained in:
Daniel Krol 2022-06-29 00:06:43 -04:00
parent fac36a7931
commit a37b64faad
5 changed files with 135 additions and 39 deletions

View file

@ -1,8 +1,11 @@
package store package store
import ( import (
"errors"
"testing" "testing"
"github.com/mattn/go-sqlite3"
"orblivion/lbry-id/auth" "orblivion/lbry-id/auth"
) )
@ -94,8 +97,36 @@ func TestStoreGetUserId(t *testing.T) {
} }
} }
// TODO - Tests each db method. Check for missing "NOT NULL" fields. Do the loop thing, and always just check for null error.
func TestStoreAccountEmptyFields(t *testing.T) { func TestStoreAccountEmptyFields(t *testing.T) {
// Make sure expiration doesn't get set if sanitization fails // Make sure expiration doesn't get set if sanitization fails
t.Fatalf("Test me") tt := []struct {
name string
email auth.Email
password auth.Password
}{
{
name: "missing email",
email: "",
password: "xyz",
},
// Not testing empty password because it gets obfuscated to something
// non-empty in the method
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
s, sqliteTmpFile := StoreTestInit(t)
defer StoreTestCleanup(sqliteTmpFile)
var sqliteErr sqlite3.Error
err := s.CreateAccount(tc.email, tc.password)
if errors.As(err, &sqliteErr) {
if errors.Is(sqliteErr.ExtendedCode, sqlite3.ErrConstraintCheck) {
return // We got the error we expected
}
}
t.Errorf("Expected check constraint error for empty field. Got %+v", err)
})
}
} }

View file

@ -45,7 +45,7 @@ type Store struct {
} }
func (s *Store) Init(fileName string) { func (s *Store) Init(fileName string) {
db, err := sql.Open("sqlite3", fileName) db, err := sql.Open("sqlite3", "file:"+fileName+"?_foreign_keys=on")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -80,9 +80,9 @@ func (s *Store) Migrate() error {
CHECK ( CHECK (
-- should eventually fail for foreign key constraint instead -- should eventually fail for foreign key constraint instead
user_id <> 0 AND user_id <> 0 AND
device_id <> '' AND
token <> '' AND token <> '' AND
device_id <> '' AND
scope <> '' AND scope <> '' AND
-- Don't know when it uses either format to denote UTC -- Don't know when it uses either format to denote UTC
@ -91,6 +91,7 @@ func (s *Store) Migrate() error {
), ),
PRIMARY KEY (user_id, device_id) PRIMARY KEY (user_id, device_id)
FOREIGN KEY (user_id) REFERENCES accounts(user_id)
); );
CREATE TABLE IF NOT EXISTS wallets( CREATE TABLE IF NOT EXISTS wallets(
user_id INTEGER NOT NULL, user_id INTEGER NOT NULL,
@ -99,11 +100,21 @@ func (s *Store) Migrate() error {
hmac TEXT NOT NULL, hmac TEXT NOT NULL,
PRIMARY KEY (user_id) PRIMARY KEY (user_id)
FOREIGN KEY (user_id) REFERENCES accounts(user_id) FOREIGN KEY (user_id) REFERENCES accounts(user_id)
CHECK (
user_id <> 0 AND
encrypted_wallet <> '' AND
hmac <> '' AND
sequence <> 0
)
); );
CREATE TABLE IF NOT EXISTS accounts( CREATE TABLE IF NOT EXISTS accounts(
email TEXT NOT NULL UNIQUE, email TEXT NOT NULL UNIQUE,
password TEXT NOT NULL,
user_id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER PRIMARY KEY AUTOINCREMENT,
password TEXT NOT NULL CHECK (
email <> '' AND
password <> ''
)
); );
` `

View file

@ -4,6 +4,8 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"orblivion/lbry-id/auth"
) )
func StoreTestInit(t *testing.T) (s Store, tmpFile *os.File) { func StoreTestInit(t *testing.T) (s Store, tmpFile *os.File) {
@ -30,3 +32,26 @@ func StoreTestCleanup(tmpFile *os.File) {
os.Remove(tmpFile.Name()) os.Remove(tmpFile.Name())
} }
} }
func makeTestUserId(t *testing.T, s *Store) auth.UserId {
email, password := auth.Email("abc@example.com"), auth.Password("123")
rows, err := s.db.Query(
"INSERT INTO accounts (email, password) values(?,?) returning user_id",
email, password.Obfuscate(),
)
if err != nil {
t.Fatalf("Error setting up account")
}
defer rows.Close()
for rows.Next() {
var userId auth.UserId
err := rows.Scan(&userId)
if err != nil {
t.Fatalf("Error setting up account")
}
return userId
}
t.Fatalf("Error setting up account")
return auth.UserId(0)
}

View file

@ -77,12 +77,14 @@ func TestStoreInsertToken(t *testing.T) {
s, sqliteTmpFile := StoreTestInit(t) s, sqliteTmpFile := StoreTestInit(t)
defer StoreTestCleanup(sqliteTmpFile) defer StoreTestCleanup(sqliteTmpFile)
userId := makeTestUserId(t, &s)
// created for addition to the DB (no expiration attached) // created for addition to the DB (no expiration attached)
authToken1 := auth.AuthToken{ authToken1 := auth.AuthToken{
Token: "seekrit-1", Token: "seekrit-1",
DeviceId: "dId", DeviceId: "dId",
Scope: "*", Scope: "*",
UserId: 123, UserId: userId,
} }
expiration := time.Now().Add(time.Hour * 24 * 14).UTC() expiration := time.Now().Add(time.Hour * 24 * 14).UTC()
@ -121,12 +123,14 @@ func TestStoreUpdateToken(t *testing.T) {
s, sqliteTmpFile := StoreTestInit(t) s, sqliteTmpFile := StoreTestInit(t)
defer StoreTestCleanup(sqliteTmpFile) defer StoreTestCleanup(sqliteTmpFile)
userId := makeTestUserId(t, &s)
// created for addition to the DB (no expiration attached) // created for addition to the DB (no expiration attached)
authTokenUpdate := auth.AuthToken{ authTokenUpdate := auth.AuthToken{
Token: "seekrit-update", Token: "seekrit-update",
DeviceId: "dId", DeviceId: "dId",
Scope: "*", Scope: "*",
UserId: 123, UserId: userId,
} }
expiration := time.Now().Add(time.Hour * 24 * 14).UTC() expiration := time.Now().Add(time.Hour * 24 * 14).UTC()
@ -177,13 +181,15 @@ func TestStoreSaveToken(t *testing.T) {
s, sqliteTmpFile := StoreTestInit(t) s, sqliteTmpFile := StoreTestInit(t)
defer StoreTestCleanup(sqliteTmpFile) defer StoreTestCleanup(sqliteTmpFile)
userId := makeTestUserId(t, &s)
// Version 1 of the token for both devices // Version 1 of the token for both devices
// created for addition to the DB (no expiration attached) // created for addition to the DB (no expiration attached)
authToken_d1_1 := auth.AuthToken{ authToken_d1_1 := auth.AuthToken{
Token: "seekrit-d1-1", Token: "seekrit-d1-1",
DeviceId: "dId-1", DeviceId: "dId-1",
Scope: "*", Scope: "*",
UserId: 123, UserId: userId,
} }
authToken_d2_1 := authToken_d1_1 authToken_d2_1 := authToken_d1_1
@ -256,12 +262,14 @@ func TestStoreGetToken(t *testing.T) {
s, sqliteTmpFile := StoreTestInit(t) s, sqliteTmpFile := StoreTestInit(t)
defer StoreTestCleanup(sqliteTmpFile) defer StoreTestCleanup(sqliteTmpFile)
userId := makeTestUserId(t, &s)
// created for addition to the DB (no expiration attached) // created for addition to the DB (no expiration attached)
authToken := auth.AuthToken{ authToken := auth.AuthToken{
Token: "seekrit-d1", Token: "seekrit-d1",
DeviceId: "dId", DeviceId: "dId",
Scope: "*", Scope: "*",
UserId: 123, UserId: userId,
} }
expiration := time.Time(time.Now().UTC().Add(time.Hour * 24 * 14)) expiration := time.Time(time.Now().UTC().Add(time.Hour * 24 * 14))
@ -307,11 +315,13 @@ func TestStoreTokenUTC(t *testing.T) {
s, sqliteTmpFile := StoreTestInit(t) s, sqliteTmpFile := StoreTestInit(t)
defer StoreTestCleanup(sqliteTmpFile) defer StoreTestCleanup(sqliteTmpFile)
userId := makeTestUserId(t, &s)
authToken := auth.AuthToken{ authToken := auth.AuthToken{
Token: "seekrit-1", Token: "seekrit-1",
DeviceId: "dId", DeviceId: "dId",
Scope: "*", Scope: "*",
UserId: 123, UserId: userId,
} }
if err := s.SaveToken(&authToken); err != nil { if err := s.SaveToken(&authToken); err != nil {

View file

@ -1,8 +1,11 @@
package store package store
import ( import (
"errors"
"testing" "testing"
"github.com/mattn/go-sqlite3"
"orblivion/lbry-id/auth" "orblivion/lbry-id/auth"
"orblivion/lbry-id/wallet" "orblivion/lbry-id/wallet"
) )
@ -76,29 +79,6 @@ func expectWalletNotExists(t *testing.T, s *Store, userId auth.UserId) {
return // found nothing, we're good return // found nothing, we're good
} }
func setupWalletTest(t *testing.T, s *Store) auth.UserId {
email, password := auth.Email("abc@example.com"), auth.Password("123")
rows, err := s.db.Query(
"INSERT INTO accounts (email, password) values(?,?) returning user_id",
email, password.Obfuscate(),
)
if err != nil {
t.Fatalf("Error setting up account")
}
defer rows.Close()
for rows.Next() {
var userId auth.UserId
err := rows.Scan(&userId)
if err != nil {
t.Fatalf("Error setting up account")
}
return userId
}
t.Fatalf("Error setting up account")
return auth.UserId(0)
}
// Test insertFirstWallet, using GetWallet, CreateAccount and GetUserID as a helpers // Test insertFirstWallet, using GetWallet, CreateAccount and GetUserID as a helpers
// Try insertFirstWallet twice with the same user id, error the second time // Try insertFirstWallet twice with the same user id, error the second time
func TestStoreInsertWallet(t *testing.T) { func TestStoreInsertWallet(t *testing.T) {
@ -106,7 +86,7 @@ func TestStoreInsertWallet(t *testing.T) {
defer StoreTestCleanup(sqliteTmpFile) defer StoreTestCleanup(sqliteTmpFile)
// Get a valid userId // Get a valid userId
userId := setupWalletTest(t, &s) userId := makeTestUserId(t, &s)
// Get a wallet, come back empty // Get a wallet, come back empty
expectWalletNotExists(t, &s, userId) expectWalletNotExists(t, &s, userId)
@ -138,7 +118,7 @@ func TestStoreUpdateWallet(t *testing.T) {
defer StoreTestCleanup(sqliteTmpFile) defer StoreTestCleanup(sqliteTmpFile)
// Get a valid userId // Get a valid userId
userId := setupWalletTest(t, &s) userId := makeTestUserId(t, &s)
// Try to update a wallet, fail for nothing to update // Try to update a wallet, fail for nothing to update
if err := s.updateWalletToSequence(userId, wallet.EncryptedWallet("my-enc-wallet-a"), wallet.Sequence(1), wallet.WalletHmac("my-hmac-a")); err != ErrNoWallet { if err := s.updateWalletToSequence(userId, wallet.EncryptedWallet("my-enc-wallet-a"), wallet.Sequence(1), wallet.WalletHmac("my-hmac-a")); err != ErrNoWallet {
@ -194,7 +174,7 @@ func TestStoreSetWallet(t *testing.T) {
defer StoreTestCleanup(sqliteTmpFile) defer StoreTestCleanup(sqliteTmpFile)
// Get a valid userId // Get a valid userId
userId := setupWalletTest(t, &s) userId := makeTestUserId(t, &s)
// Sequence 2 - fails - out of sequence (behind the scenes, tries to update but there's nothing there yet) // Sequence 2 - fails - out of sequence (behind the scenes, tries to update but there's nothing there yet)
if err := s.SetWallet(userId, wallet.EncryptedWallet("my-enc-wallet-a"), wallet.Sequence(2), wallet.WalletHmac("my-hmac-a")); err != ErrWrongSequence { if err := s.SetWallet(userId, wallet.EncryptedWallet("my-enc-wallet-a"), wallet.Sequence(2), wallet.WalletHmac("my-hmac-a")); err != ErrWrongSequence {
@ -241,7 +221,7 @@ func TestStoreGetWallet(t *testing.T) {
defer StoreTestCleanup(sqliteTmpFile) defer StoreTestCleanup(sqliteTmpFile)
// Get a valid userId // Get a valid userId
userId := setupWalletTest(t, &s) userId := makeTestUserId(t, &s)
// GetWallet fails when there's no wallet // GetWallet fails when there's no wallet
encryptedWallet, sequence, hmac, err := s.GetWallet(userId) encryptedWallet, sequence, hmac, err := s.GetWallet(userId)
@ -260,8 +240,47 @@ func TestStoreGetWallet(t *testing.T) {
} }
} }
// TODO - Tests each db method. Check for missing "NOT NULL" fields. Do the loop thing, and always just check for null error.
func TestStoreWalletEmptyFields(t *testing.T) { func TestStoreWalletEmptyFields(t *testing.T) {
// Make sure expiration doesn't get set if sanitization fails // Make sure expiration doesn't get set if sanitization fails
t.Fatalf("Test me") tt := []struct {
name string
userId auth.UserId
encryptedWallet wallet.EncryptedWallet
hmac wallet.WalletHmac
}{
{
name: "missing user id",
userId: auth.UserId(0),
encryptedWallet: wallet.EncryptedWallet("my-enc-wallet"),
hmac: wallet.WalletHmac("my-hmac"),
}, {
name: "missing encrypted wallet",
userId: auth.UserId(1),
encryptedWallet: wallet.EncryptedWallet(""),
hmac: wallet.WalletHmac("my-hmac"),
}, {
name: "missing hmac",
userId: auth.UserId(1),
encryptedWallet: wallet.EncryptedWallet("my-enc-wallet"),
hmac: wallet.WalletHmac(""),
},
// Not testing 0 sequence because the method basically doesn't allow for it.
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
s, sqliteTmpFile := StoreTestInit(t)
defer StoreTestCleanup(sqliteTmpFile)
var sqliteErr sqlite3.Error
err := s.insertFirstWallet(tc.userId, tc.encryptedWallet, tc.hmac)
if errors.As(err, &sqliteErr) {
if errors.Is(sqliteErr.ExtendedCode, sqlite3.ErrConstraintCheck) {
return // We got the error we expected
}
}
t.Errorf("Expected check constraint error for empty field. Got %+v", err)
})
}
} }