diff --git a/store/account_test.go b/store/account_test.go index d98c101..b5aa164 100644 --- a/store/account_test.go +++ b/store/account_test.go @@ -7,17 +7,37 @@ import ( ) func expectAccountExists(t *testing.T, s *Store, email auth.Email, password auth.Password) { - _, err := s.GetUserId(email, password) + rows, err := s.db.Query( + `SELECT 1 from accounts WHERE email=? AND password=?`, + email, password.Obfuscate(), + ) if err != nil { - t.Fatalf("Unexpected error in GetUserId: %+v", err) + t.Fatalf("Error finding account for: %s %s - %+v", email, password, err) } + defer rows.Close() + + for rows.Next() { + return // found something, we're good + } + + t.Fatalf("Expected account for: %s %s", email, password) } func expectAccountNotExists(t *testing.T, s *Store, email auth.Email, password auth.Password) { - _, err := s.GetUserId(email, password) - if err != ErrNoUId { - t.Fatalf("Expected ErrNoUId. err: %+v", err) + rows, err := s.db.Query( + `SELECT 1 from accounts WHERE email=? AND password=?`, + email, password.Obfuscate(), + ) + if err != nil { + t.Fatalf("Error finding account for: %s %s - %+v", email, password, err) } + defer rows.Close() + + for rows.Next() { + t.Fatalf("Expected no account for: %s %s", email, password) + } + + // found nothing, we're good } // Test CreateAccount, using GetUserId as a helper diff --git a/store/store.go b/store/store.go index 228120e..9f3f489 100644 --- a/store/store.go +++ b/store/store.go @@ -4,14 +4,15 @@ package store import ( "database/sql" - "errors" "fmt" - "github.com/mattn/go-sqlite3" "log" + "time" + + "github.com/mattn/go-sqlite3" + "orblivion/lbry-id/auth" "orblivion/lbry-id/wallet" - "time" ) var ( @@ -133,7 +134,7 @@ func (s *Store) GetToken(token auth.TokenString) (*auth.AuthToken, error) { } return &authToken, nil } - return nil, ErrNoToken // TODO - will need to test + return nil, ErrNoToken } func (s *Store) insertToken(authToken *auth.AuthToken, expiration time.Time) (err error) { diff --git a/store/token_test.go b/store/token_test.go index c46c6ad..625c988 100644 --- a/store/token_test.go +++ b/store/token_test.go @@ -1,28 +1,70 @@ package store import ( + "reflect" "strings" "testing" - "reflect" "time" + "orblivion/lbry-id/auth" ) -func expectTokenExists(t *testing.T, s *Store, token auth.TokenString, expectedToken auth.AuthToken) { - gotToken, err := s.GetToken(token) +func expectTokenExists(t *testing.T, s *Store, expectedToken auth.AuthToken) { + rows, err := s.db.Query("SELECT * FROM auth_tokens WHERE token=?", expectedToken.Token) if err != nil { - t.Fatalf("Unexpected error in GetToken: %+v", err) + t.Fatalf("Error finding token for: %s - %+v", expectedToken.Token, err) } - if gotToken == nil || !reflect.DeepEqual(*gotToken, expectedToken) { - t.Fatalf("token: \n expected %+v\n got: %+v", expectedToken, *gotToken) + defer rows.Close() + + var gotToken auth.AuthToken + for rows.Next() { + + err := rows.Scan( + &gotToken.Token, + &gotToken.UserId, + &gotToken.DeviceId, + &gotToken.Scope, + &gotToken.Expiration, + ) + + if err != nil { + t.Fatalf("Error finding token for: %s - %+v", expectedToken.Token, err) + } + + if !reflect.DeepEqual(gotToken, expectedToken) { + t.Fatalf("token: \n expected %+v\n got: %+v", expectedToken, gotToken) + } + + return // found a match, we're good } + t.Fatalf("Expected token for: %s", expectedToken.Token) } func expectTokenNotExists(t *testing.T, s *Store, token auth.TokenString) { - gotToken, err := s.GetToken(token) - if gotToken != nil || err != ErrNoToken { - t.Fatalf("Expected ErrNoToken. token: %+v err: %+v", gotToken, err) + rows, err := s.db.Query("SELECT * FROM auth_tokens WHERE token=?", token) + if err != nil { + t.Fatalf("Error finding (lack of) token for: %s - %+v", token, err) } + defer rows.Close() + + var gotToken auth.AuthToken + for rows.Next() { + + err := rows.Scan( + &gotToken.Token, + &gotToken.UserId, + &gotToken.DeviceId, + &gotToken.Scope, + &gotToken.Expiration, + ) + + if err != nil { + t.Fatalf("Error finding (lack of) token for: %s - %+v", token, err) + } + + t.Fatalf("Expected no token. Got: %+v", gotToken) + } + return // found nothing, we're good } // Test insertToken, using GetToken as a helper @@ -41,7 +83,7 @@ func TestStoreInsertToken(t *testing.T) { } expiration := time.Now().Add(time.Hour * 24 * 14).UTC() - // Get a token, come back empty + // Try to get a token, come back empty because we're just starting out expectTokenNotExists(t, &s, authToken1.Token) // Put in a token @@ -54,7 +96,7 @@ func TestStoreInsertToken(t *testing.T) { authToken1Expected.Expiration = &expiration // Get and confirm the token we just put in - expectTokenExists(t, &s, authToken1.Token, authToken1Expected) + expectTokenExists(t, &s, authToken1Expected) // Try to put a different token, fail because we already have one authToken2 := authToken1 @@ -65,7 +107,7 @@ func TestStoreInsertToken(t *testing.T) { } // Get the same *first* token we successfully put in - expectTokenExists(t, &s, authToken1.Token, authToken1Expected) + expectTokenExists(t, &s, authToken1Expected) } // Test updateToken, using GetToken and insertToken as helpers @@ -115,7 +157,7 @@ func TestStoreUpdateToken(t *testing.T) { authTokenUpdateExpected.Expiration = &expiration // Get and confirm the token we just put in - expectTokenExists(t, &s, authTokenUpdate.Token, authTokenUpdateExpected) + expectTokenExists(t, &s, authTokenUpdateExpected) // Fail to get the token we previously inserted, because it's now been overwritten expectTokenNotExists(t, &s, authTokenInsert.Token) @@ -167,8 +209,8 @@ func TestStoreSaveToken(t *testing.T) { } // Get and confirm the tokens we just put in - expectTokenExists(t, &s, authToken_d1_1.Token, authToken_d1_1) - expectTokenExists(t, &s, authToken_d2_1.Token, authToken_d2_1) + expectTokenExists(t, &s, authToken_d1_1) + expectTokenExists(t, &s, authToken_d2_1) // Version 2 of the token for both devices authToken_d1_2 := authToken_d1_1 @@ -195,8 +237,12 @@ func TestStoreSaveToken(t *testing.T) { } // Get and confirm the tokens we just put in - expectTokenExists(t, &s, authToken_d1_2.Token, authToken_d1_2) - expectTokenExists(t, &s, authToken_d2_2.Token, authToken_d2_2) + expectTokenExists(t, &s, authToken_d1_2) + expectTokenExists(t, &s, authToken_d2_2) + + // Confirm the old ones are gone + expectTokenNotExists(t, &s, authToken_d1_1.Token) + expectTokenNotExists(t, &s, authToken_d2_1.Token) } // test GetToken using insertToken and updateToken as helpers (so we can set expiration timestamps) diff --git a/store/wallet_test.go b/store/wallet_test.go index bba3fbe..ea18e00 100644 --- a/store/wallet_test.go +++ b/store/wallet_test.go @@ -15,17 +15,65 @@ func expectWalletExists( expectedSequence wallet.Sequence, expectedHmac wallet.WalletHmac, ) { - encryptedWallet, sequence, hmac, err := s.GetWallet(userId) - if encryptedWallet != expectedEncryptedWallet || sequence != expectedSequence || hmac != expectedHmac || err != nil { - t.Fatalf("Unexpected values for wallet: encrypted wallet: %+v sequence: %+v hmac: %+v err: %+v", encryptedWallet, sequence, hmac, err) + rows, err := s.db.Query( + "SELECT encrypted_wallet, sequence, hmac FROM wallets WHERE user_id=?", userId) + if err != nil { + t.Fatalf("Error finding wallet for user_id=%d: %+v", userId, err) } + defer rows.Close() + + var encryptedWallet wallet.EncryptedWallet + var sequence wallet.Sequence + var hmac wallet.WalletHmac + + for rows.Next() { + + err := rows.Scan( + &encryptedWallet, + &sequence, + &hmac, + ) + + if err != nil { + t.Fatalf("Error finding wallet for user_id=%d: %+v", userId, err) + } + + if encryptedWallet != expectedEncryptedWallet || sequence != expectedSequence || hmac != expectedHmac || err != nil { + t.Fatalf("Unexpected values for wallet: encrypted wallet: %+v sequence: %+v hmac: %+v err: %+v", encryptedWallet, sequence, hmac, err) + } + + return // found a match, we're good + } + t.Fatalf("Expected wallet for user_id=%d: %+v", userId, err) } func expectWalletNotExists(t *testing.T, s *Store, userId auth.UserId) { - encryptedWallet, sequence, hmac, err := s.GetWallet(userId) - if len(encryptedWallet) != 0 || sequence != 0 || len(hmac) != 0 || err != ErrNoWallet { - t.Fatalf("Expected ErrNoWallet, and no wallet values. Instead got: encrypted wallet: %+v sequence: %+v hmac: %+v err: %+v", encryptedWallet, sequence, hmac, err) + rows, err := s.db.Query( + "SELECT encrypted_wallet, sequence, hmac FROM wallets WHERE user_id=?", userId) + if err != nil { + t.Fatalf("Error finding (lack of) wallet for user_id=%d: %+v", userId, err) } + defer rows.Close() + + var encryptedWallet wallet.EncryptedWallet + var sequence wallet.Sequence + var hmac wallet.WalletHmac + + for rows.Next() { + + err := rows.Scan( + &encryptedWallet, + &sequence, + &hmac, + ) + + if err != nil { + t.Fatalf("Error finding (lack of) wallet for user_id=%d: %+v", userId, err) + } + + t.Fatalf("Expected no wallet. Got: encrypted wallet: %+v sequence: %+v hmac: %+v", encryptedWallet, sequence, hmac) + } + return // found nothing, we're good } func setupWalletTest(s *Store) auth.UserId {