Add errors to GetWallet test

This commit is contained in:
Daniel Krol 2022-06-21 11:52:03 -04:00
parent 8ba482521f
commit 1f4bf9da6d
4 changed files with 75 additions and 53 deletions

View file

@ -13,7 +13,7 @@ import (
) )
func TestServerAuthHandlerSuccess(t *testing.T) { func TestServerAuthHandlerSuccess(t *testing.T) {
testAuth := TestAuth{TestToken: auth.TokenString("seekrit")} testAuth := TestAuth{TestNewTokenString: auth.TokenString("seekrit")}
testStore := TestStore{} testStore := TestStore{}
s := Server{&testAuth, &testStore} s := Server{&testAuth, &testStore}
@ -25,19 +25,17 @@ func TestServerAuthHandlerSuccess(t *testing.T) {
s.getAuthToken(w, req) s.getAuthToken(w, req)
body, _ := ioutil.ReadAll(w.Body) body, _ := ioutil.ReadAll(w.Body)
if want, got := http.StatusOK, w.Result().StatusCode; want != got { expectStatusCode(t, w, http.StatusOK)
t.Errorf("StatusCode: expected %s (%d), got %s (%d)", http.StatusText(want), want, http.StatusText(got), got)
}
var result auth.AuthToken var result auth.AuthToken
err := json.Unmarshal(body, &result) err := json.Unmarshal(body, &result)
if err != nil || result.Token != testAuth.TestToken { if err != nil || result.Token != testAuth.TestNewTokenString {
t.Errorf("Expected auth response to contain token: result: %+v err: %+v", string(body), err) t.Errorf("Expected auth response to contain token: result: %+v err: %+v", string(body), err)
} }
if testStore.Called.SaveToken != testAuth.TestToken { if testStore.Called.SaveToken != testAuth.TestNewTokenString {
t.Errorf("Expected Store.SaveToken to be called with %s", testAuth.TestToken) t.Errorf("Expected Store.SaveToken to be called with %s", testAuth.TestNewTokenString)
} }
} }
@ -76,7 +74,7 @@ func TestServerAuthHandlerErrors(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// Set this up to fail according to specification // Set this up to fail according to specification
testAuth := TestAuth{TestToken: auth.TokenString("seekrit")} testAuth := TestAuth{TestNewTokenString: auth.TokenString("seekrit")}
testStore := TestStore{Errors: tc.storeErrors} testStore := TestStore{Errors: tc.storeErrors}
if tc.authFailGenToken { // TODO - TestAuth{Errors:authErrors} if tc.authFailGenToken { // TODO - TestAuth{Errors:authErrors}
testAuth.FailGenToken = true testAuth.FailGenToken = true
@ -91,7 +89,8 @@ func TestServerAuthHandlerErrors(t *testing.T) {
server.getAuthToken(w, req) server.getAuthToken(w, req)
expectErrorResponse(t, w, tc.expectedStatusCode, tc.expectedErrorString) expectStatusCode(t, w, tc.expectedStatusCode)
expectErrorString(t, w, tc.expectedErrorString)
}) })
} }
} }

View file

@ -8,12 +8,11 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"orblivion/lbry-id/auth"
"orblivion/lbry-id/store" "orblivion/lbry-id/store"
) )
func TestServerRegisterSuccess(t *testing.T) { func TestServerRegisterSuccess(t *testing.T) {
testAuth := TestAuth{TestToken: auth.TokenString("seekrit")} testAuth := TestAuth{}
testStore := TestStore{} testStore := TestStore{}
s := Server{&testAuth, &testStore} s := Server{&testAuth, &testStore}
@ -25,9 +24,7 @@ func TestServerRegisterSuccess(t *testing.T) {
s.register(w, req) s.register(w, req)
body, _ := ioutil.ReadAll(w.Body) body, _ := ioutil.ReadAll(w.Body)
if want, got := http.StatusCreated, w.Result().StatusCode; want != got { expectStatusCode(t, w, http.StatusCreated)
t.Errorf("StatusCode: expected %s (%d), got %s (%d)", http.StatusText(want), want, http.StatusText(got), got)
}
if string(body) != "{}" { if string(body) != "{}" {
t.Errorf("Expected register response to be \"{}\": result: %+v", string(body)) t.Errorf("Expected register response to be \"{}\": result: %+v", string(body))
@ -66,7 +63,7 @@ func TestServerRegisterErrors(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// Set this up to fail according to specification // Set this up to fail according to specification
testAuth := TestAuth{TestToken: auth.TokenString("seekrit")} testAuth := TestAuth{}
testStore := TestStore{Errors: tc.storeErrors} testStore := TestStore{Errors: tc.storeErrors}
server := Server{&testAuth, &testStore} server := Server{&testAuth, &testStore}
@ -77,7 +74,8 @@ func TestServerRegisterErrors(t *testing.T) {
server.register(w, req) server.register(w, req)
expectErrorResponse(t, w, tc.expectedStatusCode, tc.expectedErrorString) expectStatusCode(t, w, tc.expectedStatusCode)
expectErrorString(t, w, tc.expectedErrorString)
}) })
} }
} }

View file

@ -16,15 +16,15 @@ import (
// Implementing interfaces for stubbed out packages // Implementing interfaces for stubbed out packages
type TestAuth struct { type TestAuth struct {
TestToken auth.TokenString TestNewTokenString auth.TokenString
FailGenToken bool FailGenToken bool
} }
func (a *TestAuth) NewToken(userId auth.UserId, deviceId auth.DeviceId, scope auth.AuthScope) (*auth.AuthToken, error) { func (a *TestAuth) NewToken(userId auth.UserId, deviceId auth.DeviceId, scope auth.AuthScope) (*auth.AuthToken, error) {
if a.FailGenToken { if a.FailGenToken {
return nil, fmt.Errorf("Test error: fail to generate token") return nil, fmt.Errorf("Test error: fail to generate token")
} }
return &auth.AuthToken{Token: a.TestToken, UserId: userId, DeviceId: deviceId, Scope: scope}, nil return &auth.AuthToken{Token: a.TestNewTokenString, UserId: userId, DeviceId: deviceId, Scope: scope}, nil
} }
// Whether functions are called, and sometimes what they're called with // Whether functions are called, and sometimes what they're called with
@ -103,19 +103,22 @@ func (s *TestStore) GetWallet(userId auth.UserId) (encryptedWallet wallet.Encryp
return return
} }
// expectErrorResponse: A helper to call in functions that test that request // expectStatusCode: A helper to call in functions that test that request
// handlers fail with a certain status code and error string. Cuts down on // handlers responded with a certain status code. Cuts down on noise.
// noise. func expectStatusCode(t *testing.T, w *httptest.ResponseRecorder, expectedStatusCode int) {
func expectErrorResponse(t *testing.T, w *httptest.ResponseRecorder, expectedStatusCode int, expectedErrorString string) {
if want, got := expectedStatusCode, w.Result().StatusCode; want != got { if want, got := expectedStatusCode, w.Result().StatusCode; want != got {
t.Errorf("StatusCode: expected %d, got %d", want, got) t.Errorf("StatusCode: expected %s (%d), got %s (%d)", http.StatusText(want), want, http.StatusText(got), got)
} }
}
// expectErrorString: A helper to call in functions that test that request
// handlers failed with a certain error string. Cuts down on noise.
func expectErrorString(t *testing.T, w *httptest.ResponseRecorder, expectedErrorString string) {
body, _ := ioutil.ReadAll(w.Body) body, _ := ioutil.ReadAll(w.Body)
var result ErrorResponse var result ErrorResponse
if err := json.Unmarshal(body, &result); err != nil { if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Error decoding error message %s: `%s`", err, body) t.Fatalf("Error decoding error message: %s: `%s`", err, body)
} }
if want, got := expectedErrorString, result.Error; want != got { if want, got := expectedErrorString, result.Error; want != got {
@ -201,7 +204,8 @@ func TestServerHelperGetPostDataErrors(t *testing.T) {
t.Errorf("getPostData succeeded unexpectedly") t.Errorf("getPostData succeeded unexpectedly")
} }
expectErrorResponse(t, w, tc.expectedStatusCode, tc.expectedErrorString) expectStatusCode(t, w, tc.expectedStatusCode)
expectErrorString(t, w, tc.expectedErrorString)
}) })
} }
} }

View file

@ -2,29 +2,51 @@ package server
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"orblivion/lbry-id/auth" "orblivion/lbry-id/auth"
"orblivion/lbry-id/store"
"orblivion/lbry-id/wallet" "orblivion/lbry-id/wallet"
) )
func TestServerGetWalletSuccess(t *testing.T) { func TestServerGetWallet(t *testing.T) {
tt := []struct { tt := []struct {
name string name string
expectedStatusCode int
expectedErrorString string
storeErrors TestStoreFunctionsErrors
}{ }{
{ {
name: "success", name: "success",
expectedStatusCode: http.StatusOK,
},
{
name: "auth error",
expectedStatusCode: http.StatusUnauthorized,
expectedErrorString: http.StatusText(http.StatusUnauthorized) + ": Token Not Found",
storeErrors: TestStoreFunctionsErrors{GetToken: store.ErrNoToken},
},
{
name: "db error getting wallet",
expectedStatusCode: http.StatusInternalServerError,
expectedErrorString: http.StatusText(http.StatusInternalServerError),
storeErrors: TestStoreFunctionsErrors{GetWallet: fmt.Errorf("Some random DB Error!")},
}, },
} }
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
testAuth := TestAuth{ testAuth := TestAuth{}
TestToken: auth.TokenString("seekrit"),
}
testStore := TestStore{ testStore := TestStore{
TestAuthToken: auth.AuthToken{ TestAuthToken: auth.AuthToken{
Token: auth.TokenString("seekrit"), Token: auth.TokenString("seekrit"),
@ -34,6 +56,8 @@ func TestServerGetWalletSuccess(t *testing.T) {
TestEncryptedWallet: wallet.EncryptedWallet("my-encrypted-wallet"), TestEncryptedWallet: wallet.EncryptedWallet("my-encrypted-wallet"),
TestSequence: wallet.Sequence(2), TestSequence: wallet.Sequence(2),
TestHmac: wallet.WalletHmac("my-hmac"), TestHmac: wallet.WalletHmac("my-hmac"),
Errors: tc.storeErrors,
} }
s := Server{&testAuth, &testStore} s := Server{&testAuth, &testStore}
@ -47,44 +71,41 @@ func TestServerGetWalletSuccess(t *testing.T) {
// test handleWallet while we're at it, which is a dispatch for get and post // test handleWallet while we're at it, which is a dispatch for get and post
// wallet // wallet
s.handleWallet(w, req) s.handleWallet(w, req)
body, _ := ioutil.ReadAll(w.Body)
if want, got := http.StatusOK, w.Result().StatusCode; want != got { // Make sure we tried to get an auth based on the `token` param (whether or
t.Errorf("StatusCode: expected %s (%d), got %s (%d)", http.StatusText(want), want, http.StatusText(got), got) // not it was a valid `token`)
if testStore.Called.GetToken != testStore.TestAuthToken.Token {
t.Errorf("Expected Store.GetToken to be called with %s. Got %s",
testStore.TestAuthToken.Token,
testStore.Called.GetToken)
} }
expectStatusCode(t, w, tc.expectedStatusCode)
if len(tc.expectedErrorString) > 0 {
// Only check if we're expecting an error, since it reads the body
expectErrorString(t, w, tc.expectedErrorString)
return
}
body, _ := ioutil.ReadAll(w.Body)
var result WalletResponse var result WalletResponse
err := json.Unmarshal(body, &result) err := json.Unmarshal(body, &result)
if err != nil || result.EncryptedWallet != testStore.TestEncryptedWallet || result.Hmac != testStore.TestHmac || result.Sequence != testStore.TestSequence { if err != nil ||
result.EncryptedWallet != testStore.TestEncryptedWallet ||
result.Hmac != testStore.TestHmac ||
result.Sequence != testStore.TestSequence {
t.Errorf("Expected wallet response to have the test wallet values: result: %+v err: %+v", string(body), err) t.Errorf("Expected wallet response to have the test wallet values: result: %+v err: %+v", string(body), err)
} }
if !testStore.Called.GetWallet { if !testStore.Called.GetWallet {
t.Errorf("Expected Store.GetWallet to be called") t.Errorf("Expected Store.GetWallet to be called")
} }
// Make sure the right auth was gotten
if testStore.Called.GetToken != testAuth.TestToken {
t.Errorf("Expected Store.GetToken to be called with %s", testAuth.TestToken)
}
}) })
} }
} }
func TestServerGetWalletErrors(t *testing.T) {
t.Fatalf("Test me: GetWallet fails for various reasons (malformed, auth, db fail)")
}
func TestServerGetWalletParams(t *testing.T) {
t.Fatalf("Test me: getWalletParams")
}
func TestServerPostWalletSuccess(t *testing.T) {
t.Fatalf("Test me: PostWallet succeeds and returns the new wallet, PostWallet succeeds but is preempted")
}
func TestServerPostWalletTooLate(t *testing.T) { func TestServerPostWalletTooLate(t *testing.T) {
t.Fatalf("Test me: PostWallet fails for sequence being too low, returns the latest wallet") t.Fatalf("Test me: PostWallet fails for sequence being too low, returns the latest wallet")
} }