475 lines
14 KiB
Go
475 lines
14 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"lbryio/wallet-sync-server/auth"
|
|
"lbryio/wallet-sync-server/server/paths"
|
|
"lbryio/wallet-sync-server/store"
|
|
"lbryio/wallet-sync-server/wallet"
|
|
)
|
|
|
|
// TestPort is the port number passed to Init when these tests build a Server.
const TestPort = 8090
|
|
|
|
// Implementing interfaces for stubbed out packages
|
|
|
|
type SendVerificationEmailCall struct {
|
|
Email auth.Email
|
|
Token auth.VerifyTokenString
|
|
}
|
|
|
|
type TestMail struct {
|
|
SendVerificationEmailError error
|
|
SendVerificationEmailCall *SendVerificationEmailCall
|
|
}
|
|
|
|
func (m *TestMail) SendVerificationEmail(email auth.Email, token auth.VerifyTokenString) error {
|
|
m.SendVerificationEmailCall = &SendVerificationEmailCall{email, token}
|
|
return m.SendVerificationEmailError
|
|
}
|
|
|
|
// TestEnv is an in-memory replacement for the process environment. It is backed
// by a map, so each test controls exactly which variables are "set".
type TestEnv struct {
	env map[string]string
}

// Getenv returns the value stored under key. It returns "" when the key is
// absent or the map was never initialized.
func (e *TestEnv) Getenv(key string) string {
	if value, found := e.env[key]; found {
		return value
	}
	return ""
}
|
|
|
|
type TestAuth struct {
|
|
TestNewAuthTokenString auth.AuthTokenString
|
|
TestNewVerifyTokenString auth.VerifyTokenString
|
|
FailGenToken bool
|
|
}
|
|
|
|
func (a *TestAuth) NewAuthToken(userId auth.UserId, deviceId auth.DeviceId, scope auth.AuthScope) (*auth.AuthToken, error) {
|
|
if a.FailGenToken {
|
|
return nil, fmt.Errorf("Test error: fail to generate token")
|
|
}
|
|
return &auth.AuthToken{Token: a.TestNewAuthTokenString, UserId: userId, DeviceId: deviceId, Scope: scope}, nil
|
|
}
|
|
|
|
func (a *TestAuth) NewVerifyTokenString() (auth.VerifyTokenString, error) {
|
|
if a.FailGenToken {
|
|
return "", fmt.Errorf("Test error: fail to generate token")
|
|
}
|
|
return a.TestNewVerifyTokenString, nil
|
|
}
|
|
|
|
type SetWalletCall struct {
|
|
EncryptedWallet wallet.EncryptedWallet
|
|
Sequence wallet.Sequence
|
|
Hmac wallet.WalletHmac
|
|
}
|
|
|
|
type ChangePasswordNoWalletCall struct {
|
|
Email auth.Email
|
|
OldPassword auth.Password
|
|
NewPassword auth.Password
|
|
ClientSaltSeed auth.ClientSaltSeed
|
|
}
|
|
|
|
type ChangePasswordWithWalletCall struct {
|
|
EncryptedWallet wallet.EncryptedWallet
|
|
Sequence wallet.Sequence
|
|
Hmac wallet.WalletHmac
|
|
Email auth.Email
|
|
OldPassword auth.Password
|
|
NewPassword auth.Password
|
|
ClientSaltSeed auth.ClientSaltSeed
|
|
}
|
|
|
|
type CreateAccountCall struct {
|
|
Email auth.Email
|
|
Password auth.Password
|
|
ClientSaltSeed auth.ClientSaltSeed
|
|
VerifyToken *auth.VerifyTokenString
|
|
}
|
|
|
|
// Whether functions are called, and sometimes what they're called with
|
|
type TestStoreFunctionsCalled struct {
|
|
SaveToken auth.AuthTokenString
|
|
GetToken auth.AuthTokenString
|
|
GetUserId bool
|
|
CreateAccount *CreateAccountCall
|
|
UpdateVerifyTokenString bool
|
|
VerifyAccount bool
|
|
SetWallet SetWalletCall
|
|
GetWallet bool
|
|
ChangePasswordWithWallet ChangePasswordWithWalletCall
|
|
ChangePasswordNoWallet ChangePasswordNoWalletCall
|
|
GetClientSaltSeed auth.Email
|
|
}
|
|
|
|
// TestStoreFunctionsErrors holds the error each fake store method returns.
// Leave a field nil to make that method succeed.
type TestStoreFunctionsErrors struct {
	SaveToken, GetToken, GetUserId                        error
	CreateAccount, UpdateVerifyTokenString, VerifyAccount error
	SetWallet, GetWallet                                  error
	ChangePasswordWithWallet, ChangePasswordNoWallet      error
	GetClientSaltSeed                                     error
}
|
|
|
|
type TestStore struct {
|
|
// Fake store functions will set these to `true` as they are called
|
|
Called TestStoreFunctionsCalled
|
|
|
|
// Fake store functions will return the errors (including `nil`) specified in
|
|
// the test setup
|
|
Errors TestStoreFunctionsErrors
|
|
|
|
TestAuthToken auth.AuthToken
|
|
TestUserId auth.UserId
|
|
|
|
TestEncryptedWallet wallet.EncryptedWallet
|
|
TestSequence wallet.Sequence
|
|
TestHmac wallet.WalletHmac
|
|
|
|
TestClientSaltSeed auth.ClientSaltSeed
|
|
}
|
|
|
|
func (s *TestStore) SaveToken(authToken *auth.AuthToken) error {
|
|
s.Called.SaveToken = authToken.Token
|
|
return s.Errors.SaveToken
|
|
}
|
|
|
|
func (s *TestStore) GetToken(token auth.AuthTokenString) (*auth.AuthToken, error) {
|
|
s.Called.GetToken = token
|
|
return &s.TestAuthToken, s.Errors.GetToken
|
|
}
|
|
|
|
func (s *TestStore) GetUserId(auth.Email, auth.Password) (auth.UserId, error) {
|
|
s.Called.GetUserId = true
|
|
return 0, s.Errors.GetUserId
|
|
}
|
|
|
|
func (s *TestStore) CreateAccount(email auth.Email, password auth.Password, seed auth.ClientSaltSeed, verifyToken *auth.VerifyTokenString) error {
|
|
s.Called.CreateAccount = &CreateAccountCall{
|
|
Email: email,
|
|
Password: password,
|
|
ClientSaltSeed: seed,
|
|
VerifyToken: verifyToken,
|
|
}
|
|
return s.Errors.CreateAccount
|
|
}
|
|
|
|
func (s *TestStore) UpdateVerifyTokenString(auth.Email, auth.VerifyTokenString) (err error) {
|
|
s.Called.UpdateVerifyTokenString = true
|
|
return s.Errors.UpdateVerifyTokenString
|
|
}
|
|
|
|
func (s *TestStore) VerifyAccount(auth.VerifyTokenString) (err error) {
|
|
s.Called.VerifyAccount = true
|
|
return s.Errors.VerifyAccount
|
|
}
|
|
|
|
func (s *TestStore) SetWallet(
|
|
UserId auth.UserId,
|
|
encryptedWallet wallet.EncryptedWallet,
|
|
sequence wallet.Sequence,
|
|
hmac wallet.WalletHmac,
|
|
) (err error) {
|
|
s.Called.SetWallet = SetWalletCall{encryptedWallet, sequence, hmac}
|
|
return s.Errors.SetWallet
|
|
}
|
|
|
|
func (s *TestStore) GetWallet(userId auth.UserId) (encryptedWallet wallet.EncryptedWallet, sequence wallet.Sequence, hmac wallet.WalletHmac, err error) {
|
|
s.Called.GetWallet = true
|
|
err = s.Errors.GetWallet
|
|
if err == nil {
|
|
encryptedWallet = s.TestEncryptedWallet
|
|
sequence = s.TestSequence
|
|
hmac = s.TestHmac
|
|
}
|
|
return
|
|
}
|
|
|
|
func (s *TestStore) ChangePasswordWithWallet(
|
|
email auth.Email,
|
|
oldPassword auth.Password,
|
|
newPassword auth.Password,
|
|
clientSaltSeed auth.ClientSaltSeed,
|
|
encryptedWallet wallet.EncryptedWallet,
|
|
sequence wallet.Sequence,
|
|
hmac wallet.WalletHmac,
|
|
) (auth.UserId, error) {
|
|
s.Called.ChangePasswordWithWallet = ChangePasswordWithWalletCall{
|
|
EncryptedWallet: encryptedWallet,
|
|
Sequence: sequence,
|
|
Hmac: hmac,
|
|
Email: email,
|
|
OldPassword: oldPassword,
|
|
NewPassword: newPassword,
|
|
ClientSaltSeed: clientSaltSeed,
|
|
}
|
|
return s.TestUserId, s.Errors.ChangePasswordWithWallet
|
|
}
|
|
|
|
func (s *TestStore) ChangePasswordNoWallet(
|
|
email auth.Email,
|
|
oldPassword auth.Password,
|
|
newPassword auth.Password,
|
|
clientSaltSeed auth.ClientSaltSeed,
|
|
) (auth.UserId, error) {
|
|
s.Called.ChangePasswordNoWallet = ChangePasswordNoWalletCall{
|
|
Email: email,
|
|
OldPassword: oldPassword,
|
|
NewPassword: newPassword,
|
|
ClientSaltSeed: clientSaltSeed,
|
|
}
|
|
return s.TestUserId, s.Errors.ChangePasswordNoWallet
|
|
}
|
|
|
|
func (s *TestStore) GetClientSaltSeed(email auth.Email) (seed auth.ClientSaltSeed, err error) {
|
|
s.Called.GetClientSaltSeed = email
|
|
err = s.Errors.GetClientSaltSeed
|
|
if err == nil {
|
|
seed = s.TestClientSaltSeed
|
|
}
|
|
return
|
|
}
|
|
|
|
// expectStatusCode is a test helper that marks t as failed (without stopping
// it) when the recorded response's status code is not expectedStatusCode.
// It calls t.Helper so a failure is reported at the caller's line rather than
// inside this function, which keeps table-driven test output readable.
func expectStatusCode(t *testing.T, w *httptest.ResponseRecorder, expectedStatusCode int) {
	t.Helper()
	if want, got := expectedStatusCode, w.Result().StatusCode; want != got {
		t.Errorf("StatusCode: expected %s (%d), got %s (%d)", http.StatusText(want), want, http.StatusText(got), got)
	}
}
|
|
|
|
// expectErrorString: A helper to call in functions that test that request
|
|
// handlers failed with a certain error string. Cuts down on noise.
|
|
func expectErrorString(t *testing.T, body []byte, expectedErrorString string) {
|
|
if len(body) == 0 {
|
|
// Nothing to decode
|
|
if expectedErrorString == "" {
|
|
return // Nothing expected either, we're all good
|
|
}
|
|
t.Errorf("Error String: expected %s, got an empty body (no JSON to decode)", expectedErrorString)
|
|
}
|
|
|
|
var result ErrorResponse
|
|
if err := json.Unmarshal(body, &result); err != nil {
|
|
t.Fatalf("Error decoding error message: %s: `%s`", err, body)
|
|
}
|
|
|
|
if want, got := expectedErrorString, result.Error; want != got {
|
|
t.Errorf("Error String: expected %s, got %s", want, got)
|
|
}
|
|
}
|
|
|
|
type wsMockManager struct {
|
|
s *Server
|
|
done chan bool
|
|
|
|
addedClientUserId auth.UserId
|
|
removedClientUserId auth.UserId
|
|
removedUserId auth.UserId
|
|
walletUpdateUserId auth.UserId
|
|
noMessage bool
|
|
}
|
|
|
|
func (m *wsMockManager) getOneMessage(timeout time.Duration) {
|
|
t := time.NewTicker(timeout)
|
|
select {
|
|
case msg := <-m.s.clientAdd:
|
|
m.addedClientUserId = msg.userId
|
|
case msg := <-m.s.clientRemove:
|
|
m.removedClientUserId = msg.userId
|
|
case msg := <-m.s.userRemove:
|
|
m.removedUserId = msg.userId
|
|
case msg := <-m.s.walletUpdates:
|
|
m.walletUpdateUserId = msg.userId
|
|
case <-t.C:
|
|
m.noMessage = true
|
|
}
|
|
t.Stop()
|
|
m.done <- true
|
|
}
|
|
|
|
func TestServerHelperCheckAuth(t *testing.T) {
|
|
tt := []struct {
|
|
name string
|
|
requiredScope auth.AuthScope
|
|
userScope auth.AuthScope
|
|
|
|
tokenExpected bool
|
|
expectedStatusCode int
|
|
expectedErrorString string
|
|
|
|
storeErrors TestStoreFunctionsErrors
|
|
}{
|
|
{
|
|
name: "success",
|
|
// Just check that scope checks exist. The more detailed specific tests
|
|
// go in the auth module
|
|
requiredScope: auth.AuthScope("banana"),
|
|
userScope: auth.AuthScope("*"),
|
|
|
|
// not that it's a full request but as of now no error yet means 200 by default
|
|
expectedStatusCode: 200,
|
|
tokenExpected: true,
|
|
}, {
|
|
name: "auth token not found",
|
|
requiredScope: auth.AuthScope("banana"),
|
|
userScope: auth.AuthScope("*"),
|
|
|
|
expectedStatusCode: http.StatusUnauthorized,
|
|
expectedErrorString: http.StatusText(http.StatusUnauthorized) + ": Token Not Found",
|
|
|
|
storeErrors: TestStoreFunctionsErrors{GetToken: store.ErrNoTokenForUserDevice},
|
|
}, {
|
|
name: "unknown auth token db error",
|
|
requiredScope: auth.AuthScope("banana"),
|
|
userScope: auth.AuthScope("*"),
|
|
|
|
expectedStatusCode: http.StatusInternalServerError,
|
|
expectedErrorString: http.StatusText(http.StatusInternalServerError),
|
|
|
|
storeErrors: TestStoreFunctionsErrors{GetToken: fmt.Errorf("Some random DB Error!")},
|
|
}, {
|
|
name: "auth scope failure",
|
|
requiredScope: auth.AuthScope("banana"),
|
|
userScope: auth.AuthScope("carrot"),
|
|
|
|
expectedStatusCode: http.StatusForbidden,
|
|
expectedErrorString: http.StatusText(http.StatusForbidden) + ": Scope",
|
|
},
|
|
}
|
|
for _, tc := range tt {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
testStore := TestStore{
|
|
Errors: tc.storeErrors,
|
|
TestAuthToken: auth.AuthToken{Token: auth.AuthTokenString("seekrit"), Scope: tc.userScope},
|
|
}
|
|
s := Init(&TestAuth{}, &testStore, &TestEnv{}, &TestMail{}, TestPort)
|
|
|
|
w := httptest.NewRecorder()
|
|
authToken := s.checkAuth(w, testStore.TestAuthToken.Token, tc.requiredScope)
|
|
if tc.tokenExpected && (*authToken != testStore.TestAuthToken) {
|
|
t.Errorf("Expected checkAuth to return a valid AuthToken")
|
|
}
|
|
if !tc.tokenExpected && (authToken != nil) {
|
|
t.Errorf("Expected checkAuth not to return a valid AuthToken")
|
|
}
|
|
body, _ := ioutil.ReadAll(w.Body)
|
|
|
|
expectStatusCode(t, w, tc.expectedStatusCode)
|
|
expectErrorString(t, body, tc.expectedErrorString)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestServerHelperGetGetDataSuccess(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
success := getGetData(w, req)
|
|
if !success {
|
|
t.Errorf("getGetData failed unexpectedly")
|
|
}
|
|
}
|
|
func TestServerHelperGetGetDataErrors(t *testing.T) {
|
|
// Only error right now is if you do a POST request
|
|
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
success := getGetData(w, req)
|
|
if success {
|
|
t.Errorf("getGetData succeeded unexpectedly")
|
|
}
|
|
}
|
|
|
|
// TestReqStruct is a minimal request type for exercising getPostData. Its only
// field is unexported, so JSON decoding never sets it. Tests pre-populate key
// when they want validation to pass.
type TestReqStruct struct{ key string }

// validate reports an error when key is empty and nil otherwise.
func (t *TestReqStruct) validate() error {
	if t.key == "" {
		return fmt.Errorf("TestReq Error")
	}
	return nil
}
|
|
|
|
func TestServerHelperGetPostDataSuccess(t *testing.T) {
|
|
requestBody := []byte(`{}`)
|
|
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewBuffer(requestBody))
|
|
w := httptest.NewRecorder()
|
|
success := getPostData(w, req, &TestReqStruct{key: "hi"})
|
|
if !success {
|
|
t.Errorf("getPostData failed unexpectedly")
|
|
}
|
|
}
|
|
|
|
// Test getPostData, including requestOverhead and any other mini-helpers it calls.
|
|
func TestServerHelperGetPostDataErrors(t *testing.T) {
|
|
tt := []struct {
|
|
name string
|
|
method string
|
|
requestBody string
|
|
expectedStatusCode int
|
|
expectedErrorString string
|
|
}{
|
|
{
|
|
name: "bad method",
|
|
method: http.MethodGet,
|
|
requestBody: "",
|
|
expectedStatusCode: http.StatusMethodNotAllowed,
|
|
expectedErrorString: http.StatusText(http.StatusMethodNotAllowed),
|
|
},
|
|
{
|
|
name: "request body too large",
|
|
method: http.MethodPost,
|
|
requestBody: fmt.Sprintf(`{"key": "%s"}`, strings.Repeat("a", 100000)),
|
|
expectedStatusCode: http.StatusRequestEntityTooLarge,
|
|
expectedErrorString: http.StatusText(http.StatusRequestEntityTooLarge),
|
|
},
|
|
{
|
|
name: "malformed request body JSON",
|
|
method: http.MethodPost,
|
|
requestBody: "{",
|
|
expectedStatusCode: http.StatusBadRequest,
|
|
expectedErrorString: http.StatusText(http.StatusBadRequest) + ": Error parsing JSON",
|
|
},
|
|
{
|
|
name: "body JSON failed validation",
|
|
method: http.MethodPost,
|
|
requestBody: "{}",
|
|
expectedStatusCode: http.StatusBadRequest,
|
|
expectedErrorString: http.StatusText(http.StatusBadRequest) + ": Request failed validation: TestReq Error",
|
|
},
|
|
{
|
|
name: "body JSON has unknown field",
|
|
method: http.MethodPost,
|
|
requestBody: `{"lol": "wut"}`,
|
|
expectedStatusCode: http.StatusBadRequest,
|
|
expectedErrorString: http.StatusText(http.StatusBadRequest) + `: json: unknown field "lol"`,
|
|
},
|
|
}
|
|
for _, tc := range tt {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Make request
|
|
req := httptest.NewRequest(tc.method, paths.PathAuthToken, bytes.NewBuffer([]byte(tc.requestBody)))
|
|
w := httptest.NewRecorder()
|
|
|
|
success := getPostData(w, req, &TestReqStruct{})
|
|
if success {
|
|
t.Errorf("getPostData succeeded unexpectedly")
|
|
}
|
|
body, _ := ioutil.ReadAll(w.Body)
|
|
|
|
expectStatusCode(t, w, tc.expectedStatusCode)
|
|
expectErrorString(t, body, tc.expectedErrorString)
|
|
})
|
|
}
|
|
}
|