From 5d297b04bb2e19f1416ec08eb9803b4ef16dc0fa Mon Sep 17 00:00:00 2001 From: Daniel Krol Date: Sun, 19 Dec 2021 17:24:43 -0500 Subject: [PATCH] Implement and test a basic sqlite store --- auth.go | 7 +- go.mod | 2 + go.sum | 2 + server_test.go | 12 +- store.go | 138 ++++++++++++++++++++- store_test.go | 325 +++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 473 insertions(+), 13 deletions(-) create mode 100644 store_test.go diff --git a/auth.go b/auth.go index 955b7b2..1ffc925 100644 --- a/auth.go +++ b/auth.go @@ -1,5 +1,7 @@ package main // TODO - make it its own `auth` package later +import "time" + // TODO - Learn how to use https://github.com/golang/oauth2 instead // TODO - Look into jwt, etc. // For now I just want a process that's shaped like what I'm looking for (pubkey signatures, downloadKey, etc) @@ -24,7 +26,10 @@ func (a *Auth) IsValidSignature(pubKey PublicKey, payload string, signature stri } type AuthToken struct { - Token AuthTokenString `json:"token"` + Token AuthTokenString `json:"token"` + DeviceID string `json:"deviceId"` + PubKey PublicKey `json:"publicKey"` + Expiration *time.Time `json:"expiration"` } type TokenRequest struct { diff --git a/go.mod b/go.mod index f68455b..b004959 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module orblivion/lbry-id go 1.17 + +require github.com/mattn/go-sqlite3 v1.14.9 diff --git a/go.sum b/go.sum index e69de29..91694bf 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= +github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= diff --git a/server_test.go b/server_test.go index c999708..c612766 100644 --- a/server_test.go +++ b/server_test.go @@ -47,10 +47,10 @@ func (s *TestStore) SaveToken(token *AuthToken) error { //////////////// -func TestAuthHandlerSuccess(t *testing.T) { +func TestServerAuthHandlerSuccess(t *testing.T) { testAuth := TestAuth{TestToken: AuthTokenString("seekrit")} testStore := TestStore{} - testServer := Server{ + s := Server{ &testAuth, &testStore, } @@ -64,7 +64,7 @@ func TestAuthHandlerSuccess(t *testing.T) { req := httptest.NewRequest(http.MethodPost, PathGetAuthToken, bytes.NewBuffer(requestBody)) w := httptest.NewRecorder() - testServer.getAuthToken(w, req) + s.getAuthToken(w, req) body, _ := ioutil.ReadAll(w.Body) var result AuthToken @@ -84,7 +84,7 @@ func TestAuthHandlerSuccess(t *testing.T) { } } -func TestAuthHandlerErrors(t *testing.T) { +func TestServerAuthHandlerErrors(t *testing.T) { tt := []struct { name string method string @@ -194,13 +194,13 @@ func TestAuthHandlerErrors(t *testing.T) { } } -func TestValidateAuthRequest(t *testing.T) { +func TestServerValidateAuthRequest(t *testing.T) { // also add a basic test case for this in TestAuthHandlerErrors to make sure it's called at all // Maybe 401 specifically for missing signature? t.Fatalf("Implement and test validateAuthRequest") } -func TestValidateTokenRequest(t *testing.T) { +func TestServerValidateTokenRequest(t *testing.T) { // also add a basic test case for this in TestAuthHandlerErrors to make sure it's called at all t.Fatalf("Implement and test validateTokenRequest") } diff --git a/store.go b/store.go index 8c8eac0..8d2e1dd 100644 --- a/store.go +++ b/store.go @@ -1,16 +1,142 @@ package main // TODO - make it its own `store` package later +import ( + "database/sql" + + "errors" + "fmt" + "github.com/mattn/go-sqlite3" + "log" + "time" +) + +var ( + ErrDuplicateToken = fmt.Errorf("Token already exists for this user and device") + ErrNoToken = fmt.Errorf("Token does not exist for this user and device") +) + type StoreInterface interface { SaveToken(*AuthToken) error } -type Store struct{} +type Store struct { + db *sql.DB +} -func (s *Store) SaveToken(token *AuthToken) error { - // params: pubKey PublicKey, DeviceID string? - // or is PubKey part of AuthToken struct? - // Anyway, (pubkey, deviceID) is primary key. we should have one token for each device. - return nil +func (s *Store) Migrate() error { + query := ` + CREATE TABLE IF NOT EXISTS auth_tokens( + token TEXT NOT NULL, + public_key TEXT NOT NULL, + device_id TEXT NOT NULL, + expiration DATETIME NOT NULL, + PRIMARY KEY (public_key, device_id) + ); + ` + + _, err := s.db.Exec(query) + return err +} + +func (s *Store) GetToken(pubKey PublicKey, deviceID string) (*AuthToken, error) { + expirationCutoff := time.Now().UTC() + + rows, err := s.db.Query("SELECT * FROM auth_tokens WHERE public_key=? AND device_id=? AND expiration>?", + pubKey, deviceID, expirationCutoff, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var authToken AuthToken + for rows.Next() { + + err := rows.Scan( + &authToken.Token, + &authToken.PubKey, + &authToken.DeviceID, + &authToken.Expiration, + ) + + if err != nil { + return nil, err + } + return &authToken, nil + } + return nil, nil +} + +func (s *Store) insertToken(authToken *AuthToken, expiration time.Time) (err error) { + _, err = s.db.Exec( + "INSERT INTO auth_tokens (token, public_key, device_id, expiration) values(?,?,?,?)", + authToken.Token, authToken.PubKey, authToken.DeviceID, expiration, + ) + + var sqliteErr sqlite3.Error + if errors.As(err, &sqliteErr) { + // I initially expected to need to check for ErrConstraintUnique. + // Maybe for psql it will be? + if errors.Is(sqliteErr.ExtendedCode, sqlite3.ErrConstraintPrimaryKey) { + err = ErrDuplicateToken + } + } + + return +} + +func (s *Store) updateToken(authToken *AuthToken, experation time.Time) (err error) { + res, err := s.db.Exec( + "UPDATE auth_tokens SET token=?, expiration=? WHERE public_key=? AND device_id=?", + authToken.Token, experation, authToken.PubKey, authToken.DeviceID, + ) + if err != nil { + return + } + + numRows, err := res.RowsAffected() + if err != nil { + return + } + if numRows == 0 { + err = ErrNoToken + } + return +} + +func (s *Store) SaveToken(token *AuthToken) (err error) { + // TODO: For psql, do upsert here instead of separate insertToken and updateToken functions + + // TODO - Should we auto-delete expired tokens? + + expiration := time.Now().UTC().Add(time.Hour * 24 * 14) + + // This is most likely not the first time calling this function for this + // device, so there's probably already a token in there. + err = s.updateToken(token, expiration) + + if err == ErrNoToken { + // If we don't have a token already saved, insert a new one: + err = s.insertToken(token, expiration) + + if err == ErrDuplicateToken { + // By unlikely coincidence, a token was created between trying `updateToken` + // and trying `insertToken`. At this point we can safely `updateToken`. + err = s.updateToken(token, expiration) + } + } + if err == nil { + token.Expiration = &expiration + } + return +} + +func (s *Store) Init(fileName string) { + db, err := sql.Open("sqlite3", fileName) + if err != nil { + log.Fatal(err) + } + s.db = db } /* TODO: diff --git a/store_test.go b/store_test.go new file mode 100644 index 0000000..40761bd --- /dev/null +++ b/store_test.go @@ -0,0 +1,325 @@ +package main + +import ( + "io/ioutil" + "os" + "reflect" + "testing" + "time" +) + +func storeTestInit(t *testing.T) (s Store, tmpFile *os.File) { + s = Store{} + + tmpFile, err := ioutil.TempFile(os.TempDir(), "sqlite-test-") + if err != nil { + t.Fatalf("DB setup failure: %+v", err) + return + } + + s.Init(tmpFile.Name()) + + err = s.Migrate() + if err != nil { + t.Fatalf("DB setup failure: %+v", err) + } + + return +} + +func storeTestCleanup(tmpFile *os.File) { + if tmpFile != nil { + os.Remove(tmpFile.Name()) + } +} + +// Test insertToken, using GetToken as a helper +// Try insertToken twice with the same public key, error the second time +func TestStoreInsertToken(t *testing.T) { + + s, tmpFile := storeTestInit(t) + defer storeTestCleanup(tmpFile) + + authToken1 := AuthToken{ + Token: "seekrit-1", + DeviceID: "dID", + PubKey: "pubKey", + } + + // The value expected when we pull it from the database. + authToken1DB := authToken1 + authToken1DB.Expiration = timePtr(time.Now().Add(time.Hour * 24 * 14).UTC()) + + authToken2 := authToken1 + authToken2.Token = "seekrit-2" + + // Get a token, come back empty + gotToken, err := s.GetToken(authToken1.PubKey, authToken1.DeviceID) + if gotToken != nil || err != nil { + t.Fatalf("Expected no token and no error. token: %+v err: %+v", gotToken, err) + } + + // Put in a token + if err := s.insertToken(&authToken1, *authToken1DB.Expiration); err != nil { + t.Fatalf("Unexpected error in insertToken: %+v", err) + } + + // Get and confirm the token we just put in + gotToken, err = s.GetToken(authToken1.PubKey, authToken1.DeviceID) + if err != nil { + t.Fatalf("Unexpected error in GetToken: %+v", err) + } + if gotToken == nil || !reflect.DeepEqual(*gotToken, authToken1DB) { + t.Fatalf("token: expected %+v, got: %+v", authToken1DB, gotToken) + } + + // Try to put a different token, fail becaues we already have one + if err := s.insertToken(&authToken2, *authToken1DB.Expiration); err != ErrDuplicateToken { + t.Fatalf(`insertToken err: wanted "%+v", got "%+v"`, ErrDuplicateToken, err) + } + + // Get the same *first* token we successfully put in + gotToken, err = s.GetToken(authToken1.PubKey, authToken1.DeviceID) + if err != nil { + t.Fatalf("Unexpected error in GetToken: %+v", err) + } + if gotToken == nil || !reflect.DeepEqual(*gotToken, authToken1DB) { + t.Fatalf("token: expected %+v, got: %+v", authToken1DB, gotToken) + } +} + +// Test updateToken, using GetToken and insertToken as helpers +// Try updateToken with no existing token, err for lack of anything to update +// Try updateToken with a preexisting token, succeed +// Try updateToken again with a new token, succeed +func TestStoreUpdateToken(t *testing.T) { + s, tmpFile := storeTestInit(t) + defer storeTestCleanup(tmpFile) + + authToken1 := AuthToken{ + Token: "seekrit-1", + DeviceID: "dID", + PubKey: "pubKey", + } + authToken2 := authToken1 + authToken2.Token = "seekrit-2" + + // The value expected when we pull it from the database. + authToken2DB := authToken2 + authToken2DB.Expiration = timePtr(time.Now().Add(time.Hour * 24 * 14).UTC()) + + // Try to get a token, come back empty because we're just starting out + gotToken, err := s.GetToken(authToken1.PubKey, authToken1.DeviceID) + if gotToken != nil || err != nil { + t.Fatalf("Expected no token and no error. token: %+v err: %+v", gotToken, err) + } + + // Try to update the token - fail because we don't have an entry there in the first place + if err := s.updateToken(&authToken1, *authToken2DB.Expiration); err != ErrNoToken { + t.Fatalf(`updateToken err: wanted "%+v", got "%+v"`, ErrNoToken, err) + } + + // Try to get a token, come back empty because the update attempt failed to do anything + gotToken, err = s.GetToken(authToken1.PubKey, authToken1.DeviceID) + if gotToken != nil || err != nil { + t.Fatalf("Expected no token and no error. token: %+v err: %+v", gotToken, err) + } + + // Put in a token - just so we have something to test updateToken with + if err := s.insertToken(&authToken1, *authToken2DB.Expiration); err != nil { + t.Fatalf("Unexpected error in insertToken: %+v", err) + } + + // Now successfully update token + if err := s.updateToken(&authToken2, *authToken2DB.Expiration); err != nil { + t.Fatalf("Unexpected error in updateToken: %+v", err) + } + + // Get and confirm the token we just put in + gotToken, err = s.GetToken(authToken2.PubKey, authToken2.DeviceID) + if err != nil { + t.Fatalf("Unexpected error in GetToken: %+v", err) + } + if gotToken == nil || !reflect.DeepEqual(*gotToken, authToken2DB) { + t.Fatalf("token: \n expected %+v\n got: %+v", authToken2DB, *gotToken) + } +} + +// Two different devices. +// Test first and second Save (one for insert, one for update) +// Get fails initially +// Put token1-d1 token1-d2 +// Get token1-d1 token1-d2 +// Put token2-d1 token2-d2 +// Get token2-d1 token2-d2 +func TestStoreSaveToken(t *testing.T) { + s, tmpFile := storeTestInit(t) + defer storeTestCleanup(tmpFile) + + // Version 1 of the token for both devices + authToken_d1_1 := AuthToken{ + Token: "seekrit-d1-1", + DeviceID: "dID-1", + PubKey: "pubKey", + } + + authToken_d2_1 := authToken_d1_1 + authToken_d2_1.DeviceID = "dID-2" + authToken_d2_1.Token = "seekrit-d2-1" + + // Version 2 of the token for both devices + authToken_d1_2 := authToken_d1_1 + authToken_d1_2.Token = "seekrit-d1-2" + + authToken_d2_2 := authToken_d2_1 + authToken_d2_2.Token = "seekrit-d2-2" + + // Try to get the tokens, come back empty because we're just starting out + gotToken, err := s.GetToken(authToken_d1_1.PubKey, authToken_d1_1.DeviceID) + if gotToken != nil || err != nil { + t.Fatalf("Expected no token and no error. token: %+v err: %+v", gotToken, err) + } + gotToken, err = s.GetToken(authToken_d2_1.PubKey, authToken_d2_1.DeviceID) + if gotToken != nil || err != nil { + t.Fatalf("Expected no token and no error. token: %+v err: %+v", gotToken, err) + } + + // Save Version 1 tokens for both devices + if err = s.SaveToken(&authToken_d1_1); err != nil { + t.Fatalf("Unexpected error in SaveToken: %+v", err) + } + if err = s.SaveToken(&authToken_d2_1); err != nil { + t.Fatalf("Unexpected error in SaveToken: %+v", err) + } + + // Check one of the authTokens to make sure expiration was set + if authToken_d1_1.Expiration == nil { + t.Fatalf("Expected SaveToken to set an Expiration") + } + nowDiff := authToken_d1_1.Expiration.Sub(time.Now()) + if time.Hour*24*14+time.Minute < nowDiff || nowDiff < time.Hour*24*14-time.Minute { + t.Fatalf("Expected SaveToken to set a token Expiration 2 weeks in the future.") + } + + // Get and confirm the tokens we just put in + gotToken, err = s.GetToken(authToken_d1_1.PubKey, authToken_d1_1.DeviceID) + if err != nil { + t.Fatalf("Unexpected error in GetToken: %+v", err) + } + if gotToken == nil || !reflect.DeepEqual(*gotToken, authToken_d1_1) { + t.Fatalf("token: \n expected %+v\n got: %+v", authToken_d1_1, gotToken) + } + gotToken, err = s.GetToken(authToken_d2_1.PubKey, authToken_d2_1.DeviceID) + if err != nil { + t.Fatalf("Unexpected error in GetToken: %+v", err) + } + if gotToken == nil || !reflect.DeepEqual(*gotToken, authToken_d2_1) { + t.Fatalf("token: expected %+v, got: %+v", authToken_d2_1, gotToken) + } + + // Save Version 2 tokens for both devices + if err = s.SaveToken(&authToken_d1_2); err != nil { + t.Fatalf("Unexpected error in SaveToken: %+v", err) + } + if err = s.SaveToken(&authToken_d2_2); err != nil { + t.Fatalf("Unexpected error in SaveToken: %+v", err) + } + + // Check that the expiration of this new token is marginally later + if authToken_d1_2.Expiration == nil { + t.Fatalf("Expected SaveToken to set an Expiration") + } + expDiff := authToken_d1_2.Expiration.Sub(*authToken_d1_1.Expiration) + if time.Second < expDiff || expDiff < 0 { + t.Fatalf("Expected new expiration to be slightly later than previous expiration. diff: %+v", expDiff) + } + + // Get and confirm the tokens we just put in + gotToken, err = s.GetToken(authToken_d1_2.PubKey, authToken_d1_2.DeviceID) + if err != nil { + t.Fatalf("Unexpected error in GetToken: %+v", err) + } + if gotToken == nil || !reflect.DeepEqual(*gotToken, authToken_d1_2) { + t.Fatalf("token: \n expected %+v\n got: %+v", authToken_d1_2, gotToken) + } + gotToken, err = s.GetToken(authToken_d2_2.PubKey, authToken_d2_2.DeviceID) + if err != nil { + t.Fatalf("Unexpected error in GetToken: %+v", err) + } + if gotToken == nil || !reflect.DeepEqual(*gotToken, authToken_d2_2) { + t.Fatalf("token: expected %+v, got: %+v", authToken_d2_2, gotToken) + } +} + +func timePtr(t time.Time) *time.Time { + return &t +} + +// test GetToken using insertToken and updateToken as helpers (so we can set expiration timestamps) +// normal +// not found for pubkey +// not found for device (one for another device does exist) +// expired token not returned +func TestStoreGetToken(t *testing.T) { + s, tmpFile := storeTestInit(t) + defer storeTestCleanup(tmpFile) + + // created for addition to the DB (no expiration attached) + authToken := AuthToken{ + Token: "seekrit-d1", + DeviceID: "dID", + PubKey: "pubKey", + } + + // The value expected when we pull it from the database. + authTokenDB := authToken + authTokenDB.Expiration = timePtr(time.Time(time.Now().UTC().Add(time.Hour * 24 * 14))) + + // Not found (nothing saved for this pubkey) + gotToken, err := s.GetToken(authToken.PubKey, authToken.DeviceID) + if gotToken != nil || err != nil { + t.Fatalf("Expected no token and no error. token: %+v err: %+v", gotToken, err) + } + + // Put in a token + if err := s.insertToken(&authToken, *authTokenDB.Expiration); err != nil { + t.Fatalf("Unexpected error in insertToken: %+v", err) + } + + // Confirm it saved + gotToken, err = s.GetToken(authToken.PubKey, authToken.DeviceID) + if err != nil { + t.Fatalf("Unexpected error in GetToken: %+v", err) + } + if gotToken == nil || !reflect.DeepEqual(*gotToken, authTokenDB) { + t.Fatalf("token: \n expected %+v\n got: %+v", authTokenDB, gotToken) + } + + // Fail to get for another device + gotToken, err = s.GetToken(authToken.PubKey, "other-device") + if gotToken != nil || err != nil { + t.Fatalf("Expected no token and no error for nonexistent device. token: %+v err: %+v", gotToken, err) + } + + // Update the token to be expired + expirationOld := time.Now().Add(time.Second * (-1)) + if err := s.updateToken(&authToken, expirationOld); err != nil { + t.Fatalf("Unexpected error in updateToken: %+v", err) + } + + // Fail to get the expired token + gotToken, err = s.GetToken(authToken.PubKey, authToken.DeviceID) + if gotToken != nil || err != nil { + t.Fatalf("Expected no token and no error, for expired token. token: %+v err: %+v", gotToken, err) + } +} + +func TestStoreSanitizeEmptyFields(t *testing.T) { + // Make sure expiration doesn't get set if sanitization fails + t.Fatalf("Test me") +} + +func TestStoreTimeZones(t *testing.T) { + // Make sure the tz situation is as we prefer in the DB. Probably just do UTC. + t.Fatalf("Test me") +}