diff --git a/auth/auth.go b/auth/auth.go index c944351..61b5c3b 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -19,7 +19,7 @@ type Password string type KDFKey string // KDF output type ClientSaltSeed string // part of client-side KDF input along with root password type ServerSalt string // server-side KDF input for accounts -type TokenString string +type AuthTokenString string type VerifyTokenString string type AuthScope string @@ -28,30 +28,31 @@ const ScopeFull = AuthScope("*") // For test stubs type AuthInterface interface { // TODO maybe have a "refresh token" thing if the client won't have email available all the time? - NewToken(UserId, DeviceId, AuthScope) (*AuthToken, error) + NewAuthToken(UserId, DeviceId, AuthScope) (*AuthToken, error) + NewVerifyTokenString() (VerifyTokenString, error) } type Auth struct{} type AuthToken struct { - Token TokenString `json:"token"` - DeviceId DeviceId `json:"deviceId"` - Scope AuthScope `json:"scope"` - UserId UserId `json:"userId"` - Expiration *time.Time `json:"expiration"` + Token AuthTokenString `json:"token"` + DeviceId DeviceId `json:"deviceId"` + Scope AuthScope `json:"scope"` + UserId UserId `json:"userId"` + Expiration *time.Time `json:"expiration"` } -const AuthTokenLength = 32 +const TokenLength = 32 -func (a *Auth) NewToken(userId UserId, deviceId DeviceId, scope AuthScope) (*AuthToken, error) { - b := make([]byte, AuthTokenLength) +func (a *Auth) NewAuthToken(userId UserId, deviceId DeviceId, scope AuthScope) (*AuthToken, error) { + b := make([]byte, TokenLength) // TODO - Is this is a secure random function? (Maybe audit) if _, err := rand.Read(b); err != nil { return nil, fmt.Errorf("Error generating token: %+v", err) } return &AuthToken{ - Token: TokenString(hex.EncodeToString(b)), + Token: AuthTokenString(hex.EncodeToString(b)), DeviceId: deviceId, Scope: scope, UserId: userId, @@ -59,6 +60,16 @@ func (a *Auth) NewToken(userId UserId, deviceId DeviceId, scope AuthScope) (*Aut }, nil } +func (a *Auth) NewVerifyTokenString() (VerifyTokenString, error) { + b := make([]byte, TokenLength) + // TODO - Is this is a secure random function? (Maybe audit) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("Error generating token: %+v", err) + } + + return VerifyTokenString(hex.EncodeToString(b)), nil +} + // NOTE - not stubbing methods of structs like this. more convoluted than it's worth right now func (at *AuthToken) ScopeValid(required AuthScope) bool { // So far * is the only scope issued. Used to have more, didn't want to diff --git a/auth/auth_test.go b/auth/auth_test.go index 4f35c81..f48af9c 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -6,9 +6,9 @@ import ( // Test stubs for now -func TestAuthNewToken(t *testing.T) { +func TestAuthNewAuthToken(t *testing.T) { auth := Auth{} - authToken, err := auth.NewToken(234, "dId", "my-scope") + authToken, err := auth.NewAuthToken(234, "dId", "my-scope") if err != nil { t.Fatalf("Error creating new token") @@ -20,8 +20,8 @@ func TestAuthNewToken(t *testing.T) { t.Fatalf("authToken fields don't match expected values") } - // result.Token is in hex, AuthTokenLength is bytes in the original - expectedTokenLength := AuthTokenLength * 2 + // result.Token is in hex, TokenLength is bytes in the original + expectedTokenLength := TokenLength * 2 if len(authToken.Token) != expectedTokenLength { t.Fatalf("authToken token string length isn't the expected length") } diff --git a/server/account.go b/server/account.go index dd5692c..c28aea6 100644 --- a/server/account.go +++ b/server/account.go @@ -53,6 +53,8 @@ func (s *Server) register(w http.ResponseWriter, req *http.Request) { var registerResponse RegisterResponse + var token auth.VerifyTokenString + modes: switch verificationMode { case env.AccountVerificationModeAllowAll: @@ -75,13 +77,19 @@ modes: case env.AccountVerificationModeEmailVerify: // Not verified until they click their email link. registerResponse.Verified = false + token, err = s.auth.NewVerifyTokenString() + + if err != nil { + internalServiceErrorJson(w, err, "Error generating verify token string") + return + } } err = s.store.CreateAccount( registerRequest.Email, registerRequest.Password, registerRequest.ClientSaltSeed, - registerResponse.Verified, + token, // if it's not set, the user is marked as verified ) if err != nil { @@ -93,6 +101,15 @@ modes: return } + if len(token) > 0 { + err = s.mail.SendVerificationEmail(registerRequest.Email, token) + } + + if err != nil { + internalServiceErrorJson(w, err, "Error sending verification email") + return + } + response, err := json.Marshal(registerResponse) if err != nil { diff --git a/server/account_test.go b/server/account_test.go index 1b43bbb..8d8dcaa 100644 --- a/server/account_test.go +++ b/server/account_test.go @@ -7,20 +7,21 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "reflect" "strings" "testing" - "lbryio/lbry-id/auth" "lbryio/lbry-id/store" ) +// TODO - maybe this test could just be one of the TestServerRegisterAccountVerification tests now func TestServerRegisterSuccess(t *testing.T) { testStore := &TestStore{} env := map[string]string{ - "ACCOUNT_VERIFICATION_MODE": "AllowAll", + "ACCOUNT_VERIFICATION_MODE": "EmailVerify", } - s := Server{&TestAuth{}, testStore, &TestEnv{env}} + testMail := TestMail{} + testAuth := TestAuth{TestNewVerifyTokenString: "abcd1234abcd1234abcd1234abcd1234"} + s := Server{&testAuth, testStore, &TestEnv{env}, &testMail} requestBody := []byte(`{"email": "abc@example.com", "password": "123", "clientSaltSeed": "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234" }`) @@ -35,61 +36,99 @@ func TestServerRegisterSuccess(t *testing.T) { var result RegisterResponse err := json.Unmarshal(body, &result) - expectedResponse := RegisterResponse{Verified: true} - if err != nil || !reflect.DeepEqual(&result, &expectedResponse) { + expectedResponse := RegisterResponse{Verified: false} + if err != nil || result != expectedResponse { t.Errorf("Unexpected value for register response. Want: %+v Got: %+v Err: %+v", expectedResponse, result, err) } if testStore.Called.CreateAccount == nil { t.Errorf("Expected Store.CreateAccount to be called") } + + if testMail.SendVerificationEmailCall == nil { + // We're doing EmailVerify for this test. + t.Fatalf("Expected Store.SendVerificationEmail to be called") + } } func TestServerRegisterErrors(t *testing.T) { tt := []struct { - name string - email string - expectedStatusCode int - expectedErrorString string + name string + email string + expectedStatusCode int + expectedErrorString string + expectedCallSendVerificationEmail bool + expectedCallCreateAccount bool - storeErrors TestStoreFunctionsErrors + storeErrors TestStoreFunctionsErrors + mailError error + failGenToken bool }{ { - name: "validation error", // missing email address - email: "", - expectedStatusCode: http.StatusBadRequest, - expectedErrorString: http.StatusText(http.StatusBadRequest) + ": Request failed validation: Invalid or missing 'email'", + name: "validation error", // missing email address + email: "", + expectedStatusCode: http.StatusBadRequest, + expectedErrorString: http.StatusText(http.StatusBadRequest) + ": Request failed validation: Invalid or missing 'email'", + expectedCallSendVerificationEmail: false, + expectedCallCreateAccount: false, // Just check one validation error (missing email address) to make sure the // validate function is called. We'll check the rest of the validation // errors in the other test below. }, { - name: "existing account", - email: "abc@example.com", - expectedStatusCode: http.StatusConflict, - expectedErrorString: http.StatusText(http.StatusConflict) + ": Error registering", + name: "existing account", + email: "abc@example.com", + expectedStatusCode: http.StatusConflict, + expectedErrorString: http.StatusText(http.StatusConflict) + ": Error registering", + expectedCallSendVerificationEmail: false, + expectedCallCreateAccount: true, storeErrors: TestStoreFunctionsErrors{CreateAccount: store.ErrDuplicateEmail}, }, { - name: "unspecified account creation failure", - email: "abc@example.com", - expectedStatusCode: http.StatusInternalServerError, - expectedErrorString: http.StatusText(http.StatusInternalServerError), + name: "unspecified account creation failure", + email: "abc@example.com", + expectedStatusCode: http.StatusInternalServerError, + expectedErrorString: http.StatusText(http.StatusInternalServerError), + expectedCallSendVerificationEmail: false, + expectedCallCreateAccount: true, storeErrors: TestStoreFunctionsErrors{CreateAccount: fmt.Errorf("TestStore.CreateAccount fail")}, }, + { + name: "fail to generate verifiy token", + email: "abc@example.com", + expectedStatusCode: http.StatusInternalServerError, + expectedErrorString: http.StatusText(http.StatusInternalServerError), + expectedCallSendVerificationEmail: false, + expectedCallCreateAccount: false, + + failGenToken: true, + }, + { + name: "fail to generate verification email", + email: "abc@example.com", + expectedStatusCode: http.StatusInternalServerError, + expectedErrorString: http.StatusText(http.StatusInternalServerError), + expectedCallSendVerificationEmail: true, + expectedCallCreateAccount: true, + + mailError: fmt.Errorf("TestEmail.SendVerificationEmail fail"), + }, } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { env := map[string]string{ - "ACCOUNT_VERIFICATION_MODE": "AllowAll", + "ACCOUNT_VERIFICATION_MODE": "EmailVerify", } // Set this up to fail according to specification - s := Server{&TestAuth{}, &TestStore{Errors: tc.storeErrors}, &TestEnv{env}} + testAuth := TestAuth{TestNewVerifyTokenString: "abcd1234abcd1234abcd1234abcd1234", FailGenToken: tc.failGenToken} + testMail := TestMail{SendVerificationEmailError: tc.mailError} + testStore := TestStore{Errors: tc.storeErrors} + s := Server{&testAuth, &testStore, &TestEnv{env}, &testMail} // Make request requestBody := fmt.Sprintf(`{"email": "%s", "password": "123", "clientSaltSeed": "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234"}`, tc.email) @@ -102,6 +141,20 @@ func TestServerRegisterErrors(t *testing.T) { expectStatusCode(t, w, tc.expectedStatusCode) expectErrorString(t, body, tc.expectedErrorString) + + if tc.expectedCallCreateAccount && testStore.Called.CreateAccount == nil { + t.Errorf("Expected Store.CreateAccount to be called") + } + if !tc.expectedCallCreateAccount && testStore.Called.CreateAccount != nil { + t.Errorf("Expected Store.CreateAccount not to be called") + } + + if tc.expectedCallSendVerificationEmail && testMail.SendVerificationEmailCall == nil { + t.Errorf("Expected Store.SendVerificationEmail to be called") + } + if !tc.expectedCallSendVerificationEmail && testMail.SendVerificationEmailCall != nil { + t.Errorf("Expected Store.SendVerificationEmail not to be called") + } }) } } @@ -110,10 +163,11 @@ func TestServerRegisterAccountVerification(t *testing.T) { tt := []struct { name string - env map[string]string - expectSuccess bool - expectedVerified bool - expectedStatusCode int + env map[string]string + expectSuccess bool + expectedVerified bool + expectedStatusCode int + expectedCallSendVerificationEmail bool }{ { name: "allow all", @@ -122,9 +176,10 @@ func TestServerRegisterAccountVerification(t *testing.T) { "ACCOUNT_VERIFICATION_MODE": "AllowAll", }, - expectedVerified: true, - expectSuccess: true, - expectedStatusCode: http.StatusCreated, + expectedVerified: true, + expectSuccess: true, + expectedStatusCode: http.StatusCreated, + expectedCallSendVerificationEmail: false, }, { name: "whitelist allowed", @@ -134,9 +189,10 @@ func TestServerRegisterAccountVerification(t *testing.T) { "ACCOUNT_WHITELIST": "abc@example.com", }, - expectedVerified: true, - expectSuccess: true, - expectedStatusCode: http.StatusCreated, + expectedVerified: true, + expectSuccess: true, + expectedStatusCode: http.StatusCreated, + expectedCallSendVerificationEmail: false, }, { name: "whitelist disallowed", @@ -146,9 +202,10 @@ func TestServerRegisterAccountVerification(t *testing.T) { "ACCOUNT_WHITELIST": "something-else@example.com", }, - expectedVerified: false, - expectSuccess: false, - expectedStatusCode: http.StatusForbidden, + expectedVerified: false, + expectSuccess: false, + expectedStatusCode: http.StatusForbidden, + expectedCallSendVerificationEmail: false, }, { name: "email verify", @@ -157,16 +214,19 @@ func TestServerRegisterAccountVerification(t *testing.T) { "ACCOUNT_VERIFICATION_MODE": "EmailVerify", }, - expectedVerified: false, - expectSuccess: true, - expectedStatusCode: http.StatusCreated, + expectedVerified: false, + expectSuccess: true, + expectedStatusCode: http.StatusCreated, + expectedCallSendVerificationEmail: true, }, } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { testStore := &TestStore{} - s := Server{&TestAuth{}, testStore, &TestEnv{tc.env}} + testAuth := TestAuth{TestNewVerifyTokenString: "abcd1234abcd1234abcd1234abcd1234"} + testMail := TestMail{} + s := Server{&testAuth, testStore, &TestEnv{tc.env}, &testMail} requestBody := []byte(`{"email": "abc@example.com", "password": "123", "clientSaltSeed": "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234" }`) @@ -182,8 +242,12 @@ func TestServerRegisterAccountVerification(t *testing.T) { if testStore.Called.CreateAccount == nil { t.Fatalf("Expected CreateAccount to be called") } - if tc.expectedVerified != testStore.Called.CreateAccount.Verified { - t.Errorf("Unexpected value in call to CreateAccount for `verified`. Want: %+v Got: %+v", tc.expectedVerified, testStore.Called.CreateAccount.Verified) + tokenPassedIn := testStore.Called.CreateAccount.VerifyToken != "" + if tc.expectedVerified && tokenPassedIn { + t.Errorf("Expected new account to be verified, thus expected verifyToken *not to be passed in* to call to CreateAccount.") + } + if !tc.expectedVerified && !tokenPassedIn { + t.Errorf("Expected new account not to be verified, thus expected verifyToken not *to be passed in* to call to CreateAccount.") } var result RegisterResponse err := json.Unmarshal(body, &result) @@ -197,6 +261,13 @@ func TestServerRegisterAccountVerification(t *testing.T) { } } + if tc.expectedCallSendVerificationEmail && testMail.SendVerificationEmailCall == nil { + t.Errorf("Expected Store.SendVerificationEmail to be called") + } + if !tc.expectedCallSendVerificationEmail && testMail.SendVerificationEmailCall != nil { + t.Errorf("Expected Store.SendVerificationEmail not to be called") + } + }) } } @@ -254,12 +325,12 @@ func TestServerValidateRegisterRequest(t *testing.T) { } func TestServerVerifyAccountSuccess(t *testing.T) { - testStore := TestStore{TestVerifyTokenString: "abcd1234abcd1234abcd1234abcd1234"} - s := Server{&TestAuth{}, &testStore, &TestEnv{}} + testStore := TestStore{} + s := Server{&TestAuth{}, &testStore, &TestEnv{}, &TestMail{}} req := httptest.NewRequest(http.MethodGet, PathVerify, nil) q := req.URL.Query() - q.Add("verifyToken", string(testStore.TestVerifyTokenString)) + q.Add("verifyToken", "abcd1234abcd1234abcd1234abcd1234") req.URL.RawQuery = q.Encode() w := httptest.NewRecorder() @@ -280,7 +351,7 @@ func TestServerVerifyAccountSuccess(t *testing.T) { func TestServerVerifyAccountErrors(t *testing.T) { tt := []struct { name string - token auth.VerifyTokenString + token string expectedStatusCode int expectedErrorString string expectedCallVerifyAccount bool @@ -315,13 +386,13 @@ func TestServerVerifyAccountErrors(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Set this up to fail according to specification - testStore := TestStore{Errors: tc.storeErrors, TestVerifyTokenString: tc.token} - s := Server{&TestAuth{}, &testStore, &TestEnv{}} + testStore := TestStore{Errors: tc.storeErrors} + s := Server{&TestAuth{}, &testStore, &TestEnv{}, &TestMail{}} // Make request req := httptest.NewRequest(http.MethodGet, PathVerify, nil) q := req.URL.Query() - q.Add("verifyToken", string(testStore.TestVerifyTokenString)) + q.Add("verifyToken", tc.token) req.URL.RawQuery = q.Encode() w := httptest.NewRecorder() diff --git a/server/auth.go b/server/auth.go index ad3fb92..56d9a28 100644 --- a/server/auth.go +++ b/server/auth.go @@ -50,7 +50,7 @@ func (s *Server) getAuthToken(w http.ResponseWriter, req *http.Request) { return } - authToken, err := s.auth.NewToken(userId, authRequest.DeviceId, auth.ScopeFull) + authToken, err := s.auth.NewAuthToken(userId, authRequest.DeviceId, auth.ScopeFull) if err != nil { internalServiceErrorJson(w, err, "Error generating auth token") diff --git a/server/auth_test.go b/server/auth_test.go index 965ba05..4157c66 100644 --- a/server/auth_test.go +++ b/server/auth_test.go @@ -15,9 +15,9 @@ import ( ) func TestServerAuthHandlerSuccess(t *testing.T) { - testAuth := TestAuth{TestNewTokenString: auth.TokenString("seekrit")} + testAuth := TestAuth{TestNewAuthTokenString: auth.AuthTokenString("seekrit")} testStore := TestStore{} - s := Server{&testAuth, &testStore, &TestEnv{}} + s := Server{&testAuth, &testStore, &TestEnv{}, &TestMail{}} requestBody := []byte(`{"deviceId": "dev-1", "email": "abc@example.com", "password": "123"}`) @@ -32,12 +32,12 @@ func TestServerAuthHandlerSuccess(t *testing.T) { var result auth.AuthToken err := json.Unmarshal(body, &result) - if err != nil || result.Token != testAuth.TestNewTokenString { + if err != nil || result.Token != testAuth.TestNewAuthTokenString { t.Errorf("Expected auth response to contain token: result: %+v err: %+v", string(body), err) } - if testStore.Called.SaveToken != testAuth.TestNewTokenString { - t.Errorf("Expected Store.SaveToken to be called with %s", testAuth.TestNewTokenString) + if testStore.Called.SaveToken != testAuth.TestNewAuthTokenString { + t.Errorf("Expected Store.SaveToken to be called with %s", testAuth.TestNewAuthTokenString) } } @@ -98,12 +98,12 @@ func TestServerAuthHandlerErrors(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Set this up to fail according to specification - testAuth := TestAuth{TestNewTokenString: auth.TokenString("seekrit")} + testAuth := TestAuth{TestNewAuthTokenString: auth.AuthTokenString("seekrit")} testStore := TestStore{Errors: tc.storeErrors} if tc.authFailGenToken { // TODO - TestAuth{Errors:authErrors} testAuth.FailGenToken = true } - server := Server{&testAuth, &testStore, &TestEnv{}} + server := Server{&testAuth, &testStore, &TestEnv{}, &TestMail{}} // Make request // So long as the JSON is well-formed, the content doesn't matter here since the password check will be stubbed out diff --git a/server/client_test.go b/server/client_test.go index 7635737..499d7c9 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -66,7 +66,7 @@ func TestServerGetClientSalt(t *testing.T) { Errors: tc.storeErrors, } - s := Server{&testAuth, &testStore, &TestEnv{}} + s := Server{&testAuth, &testStore, &TestEnv{}, &TestMail{}} req := httptest.NewRequest(http.MethodGet, PathClientSaltSeed, nil) q := req.URL.Query() diff --git a/server/integration_test.go b/server/integration_test.go index ff8845e..7d67046 100644 --- a/server/integration_test.go +++ b/server/integration_test.go @@ -96,7 +96,7 @@ func TestIntegrationWalletUpdates(t *testing.T) { env := map[string]string{ "ACCOUNT_VERIFICATION_MODE": "EmailVerify", } - s := Server{&auth.Auth{}, &st, &TestEnv{env}} + s := Server{&auth.Auth{}, &st, &TestEnv{env}, &TestMail{}} //////////////////// t.Log("Request: Register email address - any device") @@ -130,8 +130,8 @@ func TestIntegrationWalletUpdates(t *testing.T) { checkStatusCode(t, statusCode, responseBody) - // result.Token is in hex, auth.AuthTokenLength is bytes in the original - expectedTokenLength := auth.AuthTokenLength * 2 + // result.Token is in hex, auth.TokenLength is bytes in the original + expectedTokenLength := auth.TokenLength * 2 if len(authToken1.Token) != expectedTokenLength { t.Fatalf("Expected auth response to contain token length 32: result: %+v", string(responseBody)) } @@ -265,7 +265,7 @@ func TestIntegrationChangePassword(t *testing.T) { env := map[string]string{ "ACCOUNT_VERIFICATION_MODE": "EmailVerify", } - s := Server{&auth.Auth{}, &st, &TestEnv{env}} + s := Server{&auth.Auth{}, &st, &TestEnv{env}, &TestMail{}} //////////////////// t.Log("Request: Register email address") @@ -321,8 +321,8 @@ func TestIntegrationChangePassword(t *testing.T) { checkStatusCode(t, statusCode, responseBody) - // result.Token is in hex, auth.AuthTokenLength is bytes in the original - expectedTokenLength := auth.AuthTokenLength * 2 + // result.Token is in hex, auth.TokenLength is bytes in the original + expectedTokenLength := auth.TokenLength * 2 if len(authToken.Token) != expectedTokenLength { t.Fatalf("Expected auth response to contain token length 32: result: %+v", string(responseBody)) } @@ -404,8 +404,8 @@ func TestIntegrationChangePassword(t *testing.T) { checkStatusCode(t, statusCode, responseBody) - // result.Token is in hex, auth.AuthTokenLength is bytes in the original - expectedTokenLength = auth.AuthTokenLength * 2 + // result.Token is in hex, auth.TokenLength is bytes in the original + expectedTokenLength = auth.TokenLength * 2 if len(authToken.Token) != expectedTokenLength { t.Fatalf("Expected auth response to contain token length 32: result: %+v", string(responseBody)) } @@ -509,8 +509,8 @@ func TestIntegrationChangePassword(t *testing.T) { checkStatusCode(t, statusCode, responseBody) - // result.Token is in hex, auth.AuthTokenLength is bytes in the original - expectedTokenLength = auth.AuthTokenLength * 2 + // result.Token is in hex, auth.TokenLength is bytes in the original + expectedTokenLength = auth.TokenLength * 2 if len(authToken.Token) != expectedTokenLength { t.Fatalf("Expected auth response to contain token length 32: result: %+v", string(responseBody)) } diff --git a/server/password_test.go b/server/password_test.go index 52f2cca..c3a9742 100644 --- a/server/password_test.go +++ b/server/password_test.go @@ -168,7 +168,7 @@ func TestServerChangePassword(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { testStore := TestStore{Errors: tc.storeErrors} - s := Server{&TestAuth{}, &testStore, &TestEnv{}} + s := Server{&TestAuth{}, &testStore, &TestEnv{}, &TestMail{}} // Whether we passed in wallet fields (these test cases should be passing // in all of them or none of them, so we only test EncryptedWallet). This diff --git a/server/server.go b/server/server.go index b8ea92f..53c6b6e 100644 --- a/server/server.go +++ b/server/server.go @@ -10,6 +10,7 @@ import ( "lbryio/lbry-id/auth" "lbryio/lbry-id/env" + "lbryio/lbry-id/mail" "lbryio/lbry-id/store" ) @@ -34,6 +35,7 @@ type Server struct { auth auth.AuthInterface store store.StoreInterface env env.EnvInterface + mail mail.MailInterface } // TODO If I capitalize the `auth` `store` and `env` fields of Store{} I can @@ -42,8 +44,9 @@ func Init( auth auth.AuthInterface, store store.StoreInterface, env env.EnvInterface, + mail mail.MailInterface, ) *Server { - return &Server{auth, store, env} + return &Server{auth, store, env, mail} } type ErrorResponse struct { @@ -149,7 +152,7 @@ func getGetData(w http.ResponseWriter, req *http.Request) bool { // deviceId. Also this is apparently not idiomatic go error handling. func (s *Server) checkAuth( w http.ResponseWriter, - token auth.TokenString, + token auth.AuthTokenString, scope auth.AuthScope, ) *auth.AuthToken { authToken, err := s.store.GetToken(token) diff --git a/server/server_test.go b/server/server_test.go index 0b065a4..49c6099 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -17,6 +17,21 @@ import ( // Implementing interfaces for stubbed out packages +type SendVerificationEmailCall struct { + Email auth.Email + Token auth.VerifyTokenString +} + +type TestMail struct { + SendVerificationEmailError error + SendVerificationEmailCall *SendVerificationEmailCall +} + +func (m *TestMail) SendVerificationEmail(email auth.Email, token auth.VerifyTokenString) error { + m.SendVerificationEmailCall = &SendVerificationEmailCall{email, token} + return m.SendVerificationEmailError +} + type TestEnv struct { env map[string]string } @@ -26,15 +41,23 @@ func (e *TestEnv) Getenv(key string) string { } type TestAuth struct { - TestNewTokenString auth.TokenString - FailGenToken bool + TestNewAuthTokenString auth.AuthTokenString + TestNewVerifyTokenString auth.VerifyTokenString + FailGenToken bool } -func (a *TestAuth) NewToken(userId auth.UserId, deviceId auth.DeviceId, scope auth.AuthScope) (*auth.AuthToken, error) { +func (a *TestAuth) NewAuthToken(userId auth.UserId, deviceId auth.DeviceId, scope auth.AuthScope) (*auth.AuthToken, error) { if a.FailGenToken { return nil, fmt.Errorf("Test error: fail to generate token") } - return &auth.AuthToken{Token: a.TestNewTokenString, UserId: userId, DeviceId: deviceId, Scope: scope}, nil + return &auth.AuthToken{Token: a.TestNewAuthTokenString, UserId: userId, DeviceId: deviceId, Scope: scope}, nil +} + +func (a *TestAuth) NewVerifyTokenString() (auth.VerifyTokenString, error) { + if a.FailGenToken { + return "", fmt.Errorf("Test error: fail to generate token") + } + return a.TestNewVerifyTokenString, nil } type SetWalletCall struct { @@ -64,13 +87,13 @@ type CreateAccountCall struct { Email auth.Email Password auth.Password ClientSaltSeed auth.ClientSaltSeed - Verified bool + VerifyToken auth.VerifyTokenString } // Whether functions are called, and sometimes what they're called with type TestStoreFunctionsCalled struct { - SaveToken auth.TokenString - GetToken auth.TokenString + SaveToken auth.AuthTokenString + GetToken auth.AuthTokenString GetUserId bool CreateAccount *CreateAccountCall VerifyAccount bool @@ -102,8 +125,7 @@ type TestStore struct { // the test setup Errors TestStoreFunctionsErrors - TestAuthToken auth.AuthToken - TestVerifyTokenString auth.VerifyTokenString + TestAuthToken auth.AuthToken TestEncryptedWallet wallet.EncryptedWallet TestSequence wallet.Sequence @@ -117,7 +139,7 @@ func (s *TestStore) SaveToken(authToken *auth.AuthToken) error { return s.Errors.SaveToken } -func (s *TestStore) GetToken(token auth.TokenString) (*auth.AuthToken, error) { +func (s *TestStore) GetToken(token auth.AuthTokenString) (*auth.AuthToken, error) { s.Called.GetToken = token return &s.TestAuthToken, s.Errors.GetToken } @@ -127,12 +149,12 @@ func (s *TestStore) GetUserId(auth.Email, auth.Password) (auth.UserId, error) { return 0, s.Errors.GetUserId } -func (s *TestStore) CreateAccount(email auth.Email, password auth.Password, clientSaltSeed auth.ClientSaltSeed, verified bool) error { +func (s *TestStore) CreateAccount(email auth.Email, password auth.Password, seed auth.ClientSaltSeed, verifyToken auth.VerifyTokenString) error { s.Called.CreateAccount = &CreateAccountCall{ Email: email, Password: password, - ClientSaltSeed: clientSaltSeed, - Verified: verified, + ClientSaltSeed: seed, + VerifyToken: verifyToken, } return s.Errors.CreateAccount } @@ -290,9 +312,9 @@ func TestServerHelperCheckAuth(t *testing.T) { t.Run(tc.name, func(t *testing.T) { testStore := TestStore{ Errors: tc.storeErrors, - TestAuthToken: auth.AuthToken{Token: auth.TokenString("seekrit"), Scope: tc.userScope}, + TestAuthToken: auth.AuthToken{Token: auth.AuthTokenString("seekrit"), Scope: tc.userScope}, } - s := Server{&TestAuth{}, &testStore, &TestEnv{}} + s := Server{&TestAuth{}, &testStore, &TestEnv{}, &TestMail{}} w := httptest.NewRecorder() authToken := s.checkAuth(w, testStore.TestAuthToken.Token, tc.requiredScope) diff --git a/server/wallet.go b/server/wallet.go index 3721f6c..f59b42a 100644 --- a/server/wallet.go +++ b/server/wallet.go @@ -14,7 +14,7 @@ import ( ) type WalletRequest struct { - Token auth.TokenString `json:"token"` + Token auth.AuthTokenString `json:"token"` EncryptedWallet wallet.EncryptedWallet `json:"encryptedWallet"` Sequence wallet.Sequence `json:"sequence"` Hmac wallet.WalletHmac `json:"hmac"` @@ -54,7 +54,7 @@ func (s *Server) handleWallet(w http.ResponseWriter, req *http.Request) { // TODO - There's probably a struct-based solution here like with POST/PUT. // We could put that struct up top as well. -func getWalletParams(req *http.Request) (token auth.TokenString, err error) { +func getWalletParams(req *http.Request) (token auth.AuthTokenString, err error) { tokenSlice, hasTokenSlice := req.URL.Query()["token"] if !hasTokenSlice || tokenSlice[0] == "" { @@ -62,7 +62,7 @@ func getWalletParams(req *http.Request) (token auth.TokenString, err error) { } if err == nil { - token = auth.TokenString(tokenSlice[0]) + token = auth.AuthTokenString(tokenSlice[0]) } return diff --git a/server/wallet_test.go b/server/wallet_test.go index 7c5565b..1a27903 100644 --- a/server/wallet_test.go +++ b/server/wallet_test.go @@ -18,7 +18,7 @@ import ( func TestServerGetWallet(t *testing.T) { tt := []struct { name string - tokenString auth.TokenString + tokenString auth.AuthTokenString expectedStatusCode int expectedErrorString string @@ -27,12 +27,12 @@ func TestServerGetWallet(t *testing.T) { }{ { name: "success", - tokenString: auth.TokenString("seekrit"), + tokenString: auth.AuthTokenString("seekrit"), expectedStatusCode: http.StatusOK, }, { name: "validation error", // missing auth token - tokenString: auth.TokenString(""), + tokenString: auth.AuthTokenString(""), expectedStatusCode: http.StatusBadRequest, expectedErrorString: http.StatusText(http.StatusBadRequest) + ": Missing token parameter", @@ -42,7 +42,7 @@ func TestServerGetWallet(t *testing.T) { }, { name: "auth error", - tokenString: auth.TokenString("seekrit"), + tokenString: auth.AuthTokenString("seekrit"), expectedStatusCode: http.StatusUnauthorized, expectedErrorString: http.StatusText(http.StatusUnauthorized) + ": Token Not Found", @@ -51,7 +51,7 @@ func TestServerGetWallet(t *testing.T) { }, { name: "db error getting wallet", - tokenString: auth.TokenString("seekrit"), + tokenString: auth.AuthTokenString("seekrit"), expectedStatusCode: http.StatusInternalServerError, expectedErrorString: http.StatusText(http.StatusInternalServerError), @@ -65,7 +65,7 @@ func TestServerGetWallet(t *testing.T) { testAuth := TestAuth{} testStore := TestStore{ TestAuthToken: auth.AuthToken{ - Token: auth.TokenString(tc.tokenString), + Token: auth.AuthTokenString(tc.tokenString), Scope: auth.ScopeFull, }, @@ -77,7 +77,7 @@ func TestServerGetWallet(t *testing.T) { } testEnv := TestEnv{} - s := Server{&testAuth, &testStore, &testEnv} + s := Server{&testAuth, &testStore, &testEnv, &TestMail{}} req := httptest.NewRequest(http.MethodGet, PathWallet, nil) q := req.URL.Query() @@ -228,14 +228,14 @@ func TestServerPostWallet(t *testing.T) { testAuth := TestAuth{} testStore := TestStore{ TestAuthToken: auth.AuthToken{ - Token: auth.TokenString("seekrit"), + Token: auth.AuthTokenString("seekrit"), Scope: auth.ScopeFull, }, Errors: tc.storeErrors, } - s := Server{&testAuth, &testStore, &TestEnv{}} + s := Server{&testAuth, &testStore, &TestEnv{}, &TestMail{}} requestBody := []byte( fmt.Sprintf(`{ diff --git a/store/auth_test.go b/store/auth_test.go index c4a82b6..9210c01 100644 --- a/store/auth_test.go +++ b/store/auth_test.go @@ -43,7 +43,7 @@ func expectTokenExists(t *testing.T, s *Store, expectedToken auth.AuthToken) { t.Fatalf("Expected token for: %s", expectedToken.Token) } -func expectTokenNotExists(t *testing.T, s *Store, token auth.TokenString) { +func expectTokenNotExists(t *testing.T, s *Store, token auth.AuthTokenString) { rows, err := s.db.Query("SELECT * FROM auth_tokens WHERE token=?", token) if err != nil { t.Fatalf("Error finding (lack of) token for: %s - %+v", token, err) diff --git a/store/password_test.go b/store/password_test.go index e44c33d..7df6af9 100644 --- a/store/password_test.go +++ b/store/password_test.go @@ -17,7 +17,7 @@ func TestStoreChangePasswordSuccess(t *testing.T) { defer StoreTestCleanup(sqliteTmpFile) userId, email, oldPassword, _ := makeTestUser(t, &s) - token := auth.TokenString("my-token") + token := auth.AuthTokenString("my-token") _, err := s.db.Exec( "INSERT INTO auth_tokens (token, user_id, device_id, scope, expiration) VALUES(?,?,?,?,?)", @@ -117,7 +117,7 @@ func TestStoreChangePasswordErrors(t *testing.T) { userId, email, oldPassword, oldSeed := makeTestUser(t, &s) expiration := time.Now().UTC().Add(time.Hour * 24 * 14) authToken := auth.AuthToken{ - Token: auth.TokenString("my-token"), + Token: auth.AuthTokenString("my-token"), DeviceId: auth.DeviceId("my-dev-id"), UserId: userId, Scope: auth.AuthScope("*"), @@ -177,7 +177,7 @@ func TestStoreChangePasswordNoWalletSuccess(t *testing.T) { defer StoreTestCleanup(sqliteTmpFile) userId, email, oldPassword, _ := makeTestUser(t, &s) - token := auth.TokenString("my-token") + token := auth.AuthTokenString("my-token") _, err := s.db.Exec( "INSERT INTO auth_tokens (token, user_id, device_id, scope, expiration) VALUES(?,?,?,?,?)", @@ -249,7 +249,7 @@ func TestStoreChangePasswordNoWalletErrors(t *testing.T) { userId, email, oldPassword, oldSeed := makeTestUser(t, &s) expiration := time.Now().UTC().Add(time.Hour * 24 * 14) authToken := auth.AuthToken{ - Token: auth.TokenString("my-token"), + Token: auth.AuthTokenString("my-token"), DeviceId: auth.DeviceId("my-dev-id"), UserId: userId, Scope: auth.AuthScope("*"), diff --git a/store/store.go b/store/store.go index b95ff14..e403202 100644 --- a/store/store.go +++ b/store/store.go @@ -37,11 +37,11 @@ var ( // For test stubs type StoreInterface interface { SaveToken(*auth.AuthToken) error - GetToken(auth.TokenString) (*auth.AuthToken, error) + GetToken(auth.AuthTokenString) (*auth.AuthToken, error) SetWallet(auth.UserId, wallet.EncryptedWallet, wallet.Sequence, wallet.WalletHmac) error GetWallet(auth.UserId) (wallet.EncryptedWallet, wallet.Sequence, wallet.WalletHmac, error) GetUserId(auth.Email, auth.Password) (auth.UserId, error) - CreateAccount(auth.Email, auth.Password, auth.ClientSaltSeed, bool) error + CreateAccount(auth.Email, auth.Password, auth.ClientSaltSeed, auth.VerifyTokenString) error VerifyAccount(auth.VerifyTokenString) error ChangePasswordWithWallet(auth.Email, auth.Password, auth.Password, auth.ClientSaltSeed, wallet.EncryptedWallet, wallet.Sequence, wallet.WalletHmac) error ChangePasswordNoWallet(auth.Email, auth.Password, auth.Password, auth.ClientSaltSeed) error @@ -143,7 +143,7 @@ func (s *Store) Migrate() error { // (which I did previously)? // // TODO Put the timestamp in the token to avoid duplicates over time. And/or just use a library! Someone solved this already. -func (s *Store) GetToken(token auth.TokenString) (authToken *auth.AuthToken, err error) { +func (s *Store) GetToken(token auth.AuthTokenString) (authToken *auth.AuthToken, err error) { expirationCutoff := time.Now().UTC() authToken = &(auth.AuthToken{}) @@ -362,7 +362,7 @@ func (s *Store) GetUserId(email auth.Email, password auth.Password) (userId auth // Account // ///////////// -func (s *Store) CreateAccount(email auth.Email, password auth.Password, seed auth.ClientSaltSeed, verified bool) (err error) { +func (s *Store) CreateAccount(email auth.Email, password auth.Password, seed auth.ClientSaltSeed, verifyToken auth.VerifyTokenString) (err error) { key, salt, err := password.Create() if err != nil { return