diff --git a/server/account.go b/server/account.go index cd50fb2..57b9699 100644 --- a/server/account.go +++ b/server/account.go @@ -6,6 +6,7 @@ import ( "net/http" "lbryio/lbry-id/auth" + "lbryio/lbry-id/env" "lbryio/lbry-id/store" ) @@ -15,6 +16,10 @@ type RegisterRequest struct { ClientSaltSeed auth.ClientSaltSeed `json:"clientSaltSeed"` } +type RegisterResponse struct { + Verified bool `json:"verified"` +} + func (r *RegisterRequest) validate() error { if !r.Email.Validate() { return fmt.Errorf("Invalid or missing 'email'") @@ -35,7 +40,49 @@ func (s *Server) register(w http.ResponseWriter, req *http.Request) { return } - err := s.store.CreateAccount(registerRequest.Email, registerRequest.Password, registerRequest.ClientSaltSeed) + verificationMode, err := env.GetAccountVerificationMode(s.env) + if err != nil { + internalServiceErrorJson(w, err, "Error getting account verification mode") + return + } + accountWhitelist, err := env.GetAccountWhitelist(s.env, verificationMode) + if err != nil { + internalServiceErrorJson(w, err, "Error getting account whitelist") + return + } + + var registerResponse RegisterResponse + +modes: + switch verificationMode { + case env.AccountVerificationModeAllowAll: + // Always verified (for testers). No need to jump through email verify + // hoops. + registerResponse.Verified = true + case env.AccountVerificationModeWhitelist: + for _, whitelisteEmail := range accountWhitelist { + if whitelisteEmail == registerRequest.Email { + registerResponse.Verified = true + break modes + } + } + // If we have unverified users on whitelist setups, we'd need to create a way + // to verify them. It's easier to just prevent account creation. It also will + // make it easier for self-hosters to figure out that something is wrong + // with their whitelist. + errorJson(w, http.StatusForbidden, "Account not whitelisted") + return + case env.AccountVerificationModeEmailVerify: + // Not verified until they click their email link. + registerResponse.Verified = false + } + + err = s.store.CreateAccount( + registerRequest.Email, + registerRequest.Password, + registerRequest.ClientSaltSeed, + registerResponse.Verified, + ) if err != nil { if err == store.ErrDuplicateEmail || err == store.ErrDuplicateAccount { @@ -46,9 +93,7 @@ func (s *Server) register(w http.ResponseWriter, req *http.Request) { return } - var registerResponse struct{} // no data to respond with, but keep it JSON - var response []byte - response, err = json.Marshal(registerResponse) + response, err := json.Marshal(registerResponse) if err != nil { internalServiceErrorJson(w, err, "Error generating register response") diff --git a/server/account_test.go b/server/account_test.go index 2f59f67..ba71087 100644 --- a/server/account_test.go +++ b/server/account_test.go @@ -2,10 +2,12 @@ package server import ( "bytes" + "encoding/json" "fmt" "io/ioutil" "net/http" "net/http/httptest" + "reflect" "strings" "testing" @@ -14,7 +16,10 @@ import ( func TestServerRegisterSuccess(t *testing.T) { testStore := &TestStore{} - s := Server{&TestAuth{}, testStore, &TestEnv{}} + env := map[string]string{ + "ACCOUNT_VERIFICATION_MODE": "AllowAll", + } + s := Server{&TestAuth{}, testStore, &TestEnv{env}} requestBody := []byte(`{"email": "abc@example.com", "password": "123", "clientSaltSeed": "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234" }`) @@ -26,11 +31,15 @@ func TestServerRegisterSuccess(t *testing.T) { expectStatusCode(t, w, http.StatusCreated) - if string(body) != "{}" { - t.Errorf("Expected register response to be \"{}\": result: %+v", string(body)) + var result RegisterResponse + err := json.Unmarshal(body, &result) + + expectedResponse := RegisterResponse{Verified: true} + if err != nil || !reflect.DeepEqual(&result, &expectedResponse) { + t.Errorf("Unexpected value for register response. Want: %+v Got: %+v Err: %+v", expectedResponse, result, err) } - if !testStore.Called.CreateAccount { + if testStore.Called.CreateAccount == nil { t.Errorf("Expected Store.CreateAccount to be called") } } @@ -74,15 +83,19 @@ func TestServerRegisterErrors(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { + env := map[string]string{ + "ACCOUNT_VERIFICATION_MODE": "AllowAll", + } + // Set this up to fail according to specification - server := Server{&TestAuth{}, &TestStore{Errors: tc.storeErrors}, &TestEnv{}} + s := Server{&TestAuth{}, &TestStore{Errors: tc.storeErrors}, &TestEnv{env}} // Make request requestBody := fmt.Sprintf(`{"email": "%s", "password": "123", "clientSaltSeed": "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234"}`, tc.email) req := httptest.NewRequest(http.MethodPost, PathAuthToken, bytes.NewBuffer([]byte(requestBody))) w := httptest.NewRecorder() - server.register(w, req) + s.register(w, req) body, _ := ioutil.ReadAll(w.Body) @@ -92,6 +105,101 @@ func TestServerRegisterErrors(t *testing.T) { } } +func TestServerRegisterAccountVerification(t *testing.T) { + tt := []struct { + name string + + env map[string]string + expectSuccess bool + expectedVerified bool + expectedStatusCode int + }{ + { + name: "allow all", + + env: map[string]string{ + "ACCOUNT_VERIFICATION_MODE": "AllowAll", + }, + + expectedVerified: true, + expectSuccess: true, + expectedStatusCode: http.StatusCreated, + }, + { + name: "whitelist allowed", + + env: map[string]string{ + "ACCOUNT_VERIFICATION_MODE": "Whitelist", + "ACCOUNT_WHITELIST": "abc@example.com", + }, + + expectedVerified: true, + expectSuccess: true, + expectedStatusCode: http.StatusCreated, + }, + { + name: "whitelist disallowed", + + env: map[string]string{ + "ACCOUNT_VERIFICATION_MODE": "Whitelist", + "ACCOUNT_WHITELIST": "something-else@example.com", + }, + + expectedVerified: false, + expectSuccess: false, + expectedStatusCode: http.StatusForbidden, + }, + { + name: "email verify", + + env: map[string]string{ + "ACCOUNT_VERIFICATION_MODE": "EmailVerify", + }, + + expectedVerified: false, + expectSuccess: true, + expectedStatusCode: http.StatusCreated, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + testStore := &TestStore{} + s := Server{&TestAuth{}, testStore, &TestEnv{tc.env}} + + requestBody := []byte(`{"email": "abc@example.com", "password": "123", "clientSaltSeed": "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234" }`) + + req := httptest.NewRequest(http.MethodPost, PathRegister, bytes.NewBuffer(requestBody)) + w := httptest.NewRecorder() + + s.register(w, req) + body, _ := ioutil.ReadAll(w.Body) + + expectStatusCode(t, w, tc.expectedStatusCode) + + if tc.expectSuccess { + if testStore.Called.CreateAccount == nil { + t.Fatalf("Expected CreateAccount to be called") + } + if tc.expectedVerified != testStore.Called.CreateAccount.Verified { + t.Errorf("Unexpected value in call to CreateAccount for `verified`. Want: %+v Got: %+v", tc.expectedVerified, testStore.Called.CreateAccount.Verified) + } + var result RegisterResponse + err := json.Unmarshal(body, &result) + + if err != nil || tc.expectedVerified != result.Verified { + t.Errorf("Unexpected value in register response for `verified`. Want: %+v Got: %+v Err: %+v", tc.expectedVerified, result.Verified, err) + } + } else { + if testStore.Called.CreateAccount != nil { + t.Errorf("Expected CreateAccount not to be called") + } + } + + }) + } +} + func TestServerValidateRegisterRequest(t *testing.T) { registerRequest := RegisterRequest{Email: "joe@example.com", Password: "aoeu", ClientSaltSeed: "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234"} if registerRequest.validate() != nil { diff --git a/server/server_test.go b/server/server_test.go index 98ef36f..b3402bc 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -60,12 +60,19 @@ type ChangePasswordWithWalletCall struct { ClientSaltSeed auth.ClientSaltSeed } +type CreateAccountCall struct { + Email auth.Email + Password auth.Password + ClientSaltSeed auth.ClientSaltSeed + Verified bool +} + // Whether functions are called, and sometimes what they're called with type TestStoreFunctionsCalled struct { SaveToken auth.TokenString GetToken auth.TokenString GetUserId bool - CreateAccount bool + CreateAccount *CreateAccountCall SetWallet SetWalletCall GetWallet bool ChangePasswordWithWallet ChangePasswordWithWalletCall @@ -117,8 +124,13 @@ func (s *TestStore) GetUserId(auth.Email, auth.Password) (auth.UserId, error) { return 0, s.Errors.GetUserId } -func (s *TestStore) CreateAccount(auth.Email, auth.Password, auth.ClientSaltSeed) error { - s.Called.CreateAccount = true +func (s *TestStore) CreateAccount(email auth.Email, password auth.Password, clientSaltSeed auth.ClientSaltSeed, verified bool) error { + s.Called.CreateAccount = &CreateAccountCall{ + Email: email, + Password: password, + ClientSaltSeed: clientSaltSeed, + Verified: verified, + } return s.Errors.CreateAccount } diff --git a/store/store.go b/store/store.go index 62e3580..cb51146 100644 --- a/store/store.go +++ b/store/store.go @@ -39,7 +39,7 @@ type StoreInterface interface { SetWallet(auth.UserId, wallet.EncryptedWallet, wallet.Sequence, wallet.WalletHmac) error GetWallet(auth.UserId) (wallet.EncryptedWallet, wallet.Sequence, wallet.WalletHmac, error) GetUserId(auth.Email, auth.Password) (auth.UserId, error) - CreateAccount(auth.Email, auth.Password, auth.ClientSaltSeed) error + CreateAccount(auth.Email, auth.Password, auth.ClientSaltSeed, bool) error ChangePasswordWithWallet(auth.Email, auth.Password, auth.Password, auth.ClientSaltSeed, wallet.EncryptedWallet, wallet.Sequence, wallet.WalletHmac) error ChangePasswordNoWallet(auth.Email, auth.Password, auth.Password, auth.ClientSaltSeed) error GetClientSaltSeed(auth.Email) (auth.ClientSaltSeed, error) @@ -359,7 +359,7 @@ func (s *Store) GetUserId(email auth.Email, password auth.Password) (userId auth // Account // ///////////// -func (s *Store) CreateAccount(email auth.Email, password auth.Password, seed auth.ClientSaltSeed) (err error) { +func (s *Store) CreateAccount(email auth.Email, password auth.Password, seed auth.ClientSaltSeed, verified bool) (err error) { key, salt, err := password.Create() if err != nil { return