From f792ba5846db06630482d2e223cc5066c92faa2b Mon Sep 17 00:00:00 2001 From: Daniel Krol Date: Sun, 24 Jul 2022 16:02:55 -0400 Subject: [PATCH] `env` package provides config values --- auth/auth.go | 22 +++++++++ env/env.go | 73 ++++++++++++++++++++++++++++ env/env_test.go | 115 +++++++++++++++++++++++++++++++++++++++++++++ server/account.go | 5 +- server/auth.go | 2 +- server/client.go | 2 +- server/password.go | 4 +- server/server.go | 22 --------- 8 files changed, 217 insertions(+), 28 deletions(-) create mode 100644 env/env_test.go diff --git a/auth/auth.go b/auth/auth.go index fd30f34..7cc6f7f 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "net/mail" "strings" "time" @@ -113,6 +114,27 @@ func (p Password) Check(checkKey KDFKey, salt ServerSalt) (match bool, err error return } +func (e Email) Validate() bool { + email, err := mail.ParseAddress(string(e)) + if err != nil { + return false + } + // "Joe " is valid according to ParseAddress. Likewise + // " joe@example.com". Etc. We only want the exact address, "joe@example.com" + // to be valid. ParseAddress will extract the exact address as + // parsed.Address. So we'll take the input email, put it through + // ParseAddress, see if it parses successfully, and then compare the input + // email to parsed.Address to make sure that it was an exact address to begin + // with. + return string(e) == email.Address +} + +func (c ClientSaltSeed) Validate() bool { + _, err := hex.DecodeString(string(c)) + const seedHexLength = ClientSaltSeedLength * 2 + return len(c) == seedHexLength && err == nil +} + // TODO consider unicode. Also some providers might be case sensitive, and/or // may have other ways of having email addresses be equivalent (which we may // not care about though) diff --git a/env/env.go b/env/env.go index 83c3034..e140a3d 100644 --- a/env/env.go +++ b/env/env.go @@ -1,9 +1,30 @@ package env import ( + "fmt" "os" + "strings" + + "lbryio/lbry-id/auth" ) +const whitelistKey = "ACCOUNT_WHITELIST" +const verificationModeKey = "ACCOUNT_VERIFICATION_MODE" + +// Not exported, so that GetAccountVerificationMode is the only way to get one +// of these values from outside of this package. +type accountVerificationMode string + +// Everyone can make an account. Only use for dev purposes. +const ModeAllowAll = accountVerificationMode("AllowAll") + +// Verify accounts via email. Good for big open servers. +const ModeEmailVerify = accountVerificationMode("EmailVerify") + +// Specific email accounts are automatically verified. Good for small +// self-hosting users. +const ModeWhitelist = accountVerificationMode("Whitelist") + // For test stubs type EnvInterface interface { Getenv(key string) string @@ -14,3 +35,55 @@ type Env struct{} func (e *Env) Getenv(key string) string { return os.Getenv(key) } + +func GetAccountVerificationMode(e EnvInterface) (accountVerificationMode, error) { + return getAccountVerificationMode(e.Getenv(verificationModeKey)) +} + +func GetAccountWhitelist(e EnvInterface, mode accountVerificationMode) (emails []auth.Email, err error) { + return getAccountWhitelist(e.Getenv(whitelistKey), mode) +} + +// Factor out the guts of the functions so we can test them by just passing in +// the env vars + +func getAccountVerificationMode(modeStr string) (accountVerificationMode, error) { + mode := accountVerificationMode(modeStr) + switch mode { + case "": + // Whitelist is the least dangerous mode. If you forget to set any env + // vars, it effectively disables all account creation. + return ModeWhitelist, nil + case ModeAllowAll: + case ModeEmailVerify: + case ModeWhitelist: + default: + return "", fmt.Errorf("Invalid account verification mode in %s: %s", verificationModeKey, mode) + } + return mode, nil +} + +func getAccountWhitelist(whitelist string, mode accountVerificationMode) (emails []auth.Email, err error) { + if whitelist == "" { + return []auth.Email{}, nil + } + + if mode != ModeWhitelist { + return nil, fmt.Errorf("Do not specify ACCOUNT_WHITELIST in env if ACCOUNT_VERIFICATION_MODE is not Whitelist") + } + + rawEmails := strings.Split(whitelist, ",") + for _, rawEmail := range rawEmails { + // Give them a specific error here to let them know not to add spaces. It + // could be confusing otherwise to figure out what's invalid. + if strings.TrimSpace(rawEmail) != rawEmail { + return nil, fmt.Errorf("Emails in %s should be comma separated with no spaces.", whitelistKey) + } + email := auth.Email(rawEmail) + if !email.Validate() { + return nil, fmt.Errorf("Invalid email in %s: %s", whitelistKey, email) + } + emails = append(emails, email) + } + return emails, nil +} diff --git a/env/env_test.go b/env/env_test.go new file mode 100644 index 0000000..7634f81 --- /dev/null +++ b/env/env_test.go @@ -0,0 +1,115 @@ +package env + +import ( + "fmt" + "reflect" + "testing" + + "lbryio/lbry-id/auth" +) + +func TestAccountVerificationMode(t *testing.T) { + tt := []struct { + name string + + modeStr string + expectedMode accountVerificationMode + expectErr bool + }{ + { + name: "allow all", + + modeStr: "AllowAll", + expectedMode: ModeAllowAll, + }, + { + name: "email verify", + + modeStr: "EmailVerify", + expectedMode: ModeEmailVerify, + }, + { + name: "whitelist", + + modeStr: "Whitelist", + expectedMode: ModeWhitelist, + }, + { + name: "blank", + + modeStr: "", + expectedMode: ModeWhitelist, + }, + { + name: "invalid", + + modeStr: "Banana", + expectErr: true, + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + mode, err := getAccountVerificationMode(tc.modeStr) + if mode != tc.expectedMode { + t.Errorf("Expected mode %s got %s", tc.expectedMode, mode) + } + if tc.expectErr && err == nil { + t.Errorf("Expected err") + } + if !tc.expectErr && err != nil { + t.Errorf("Unexpected err: %s", err.Error()) + } + }) + } +} + +func TestAccountWhitelist(t *testing.T) { + tt := []struct { + name string + + whitelist string + expectedEmails []auth.Email + expectedErr error + mode accountVerificationMode + }{ + { + name: "empty", + + mode: ModeWhitelist, + whitelist: "", + expectedEmails: []auth.Email{}, + }, + { + name: "invalid mode", + + mode: ModeEmailVerify, + whitelist: "test1@example.com,test2@example.com", + expectedErr: fmt.Errorf("Do not specify ACCOUNT_WHITELIST in env if ACCOUNT_VERIFICATION_MODE is not Whitelist"), + }, + { + name: "spaces in email", + + mode: ModeWhitelist, + whitelist: "test1@example.com ,test2@example.com", + expectedErr: fmt.Errorf("Emails in ACCOUNT_WHITELIST should be comma separated with no spaces."), + }, + { + name: "invalid email", + + mode: ModeWhitelist, + whitelist: "test1@example.com,test2-example.com", + expectedErr: fmt.Errorf("Invalid email in ACCOUNT_WHITELIST: test2-example.com"), + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + emails, err := getAccountWhitelist(tc.whitelist, tc.mode) + if !reflect.DeepEqual(emails, tc.expectedEmails) { + t.Errorf("Expected emails %+v got %+v", tc.expectedEmails, emails) + } + if fmt.Sprint(err) != fmt.Sprint(tc.expectedErr) { + t.Errorf("Expected error `%s` got `%s`", tc.expectedErr, err.Error()) + } + }) + } +} diff --git a/server/account.go b/server/account.go index 7108c8d..cd50fb2 100644 --- a/server/account.go +++ b/server/account.go @@ -16,13 +16,14 @@ type RegisterRequest struct { } func (r *RegisterRequest) validate() error { - if !validateEmail(r.Email) { + if !r.Email.Validate() { return fmt.Errorf("Invalid or missing 'email'") } if r.Password == "" { return fmt.Errorf("Missing 'password'") } - if !validateClientSaltSeed(r.ClientSaltSeed) { + + if !r.ClientSaltSeed.Validate() { return fmt.Errorf("Invalid or missing 'clientSaltSeed'") } return nil diff --git a/server/auth.go b/server/auth.go index 8e90e1e..d7e4704 100644 --- a/server/auth.go +++ b/server/auth.go @@ -18,7 +18,7 @@ type AuthRequest struct { } func (r *AuthRequest) validate() error { - if !validateEmail(r.Email) { + if !r.Email.Validate() { return fmt.Errorf("Invalid 'email'") } if r.Password == "" { diff --git a/server/client.go b/server/client.go index 22e0a2b..12efb2c 100644 --- a/server/client.go +++ b/server/client.go @@ -40,7 +40,7 @@ func getClientSaltSeedParams(req *http.Request) (email auth.Email, err error) { } } - if !validateEmail(email) { + if !email.Validate() { email = "" err = fmt.Errorf("Invalid email") } diff --git a/server/password.go b/server/password.go index 5f1a308..5c3ee78 100644 --- a/server/password.go +++ b/server/password.go @@ -25,7 +25,7 @@ func (r *ChangePasswordRequest) validate() error { walletPresent := (r.EncryptedWallet != "" && r.Hmac != "" && r.Sequence > 0) walletAbsent := (r.EncryptedWallet == "" && r.Hmac == "" && r.Sequence == 0) - if !validateEmail(r.Email) { + if !r.Email.Validate() { return fmt.Errorf("Invalid or missing 'email'") } if r.OldPassword == "" { @@ -38,7 +38,7 @@ func (r *ChangePasswordRequest) validate() error { if r.OldPassword == r.NewPassword { return fmt.Errorf("'oldPassword' and 'newPassword' should not be the same") } - if !validateClientSaltSeed(r.ClientSaltSeed) { + if !r.ClientSaltSeed.Validate() { return fmt.Errorf("Invalid or missing 'clientSaltSeed'") } if !walletPresent && !walletAbsent { diff --git a/server/server.go b/server/server.go index 0616433..58b2497 100644 --- a/server/server.go +++ b/server/server.go @@ -1,12 +1,10 @@ package server import ( - "encoding/hex" "encoding/json" "fmt" "log" "net/http" - "net/mail" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -171,26 +169,6 @@ func (s *Server) checkAuth( return authToken } -func validateEmail(email auth.Email) bool { - e, err := mail.ParseAddress(string(email)) - if err != nil { - return false - } - // "Joe " is valid according to ParseAddress. Likewise - // " joe@example.com". Etc. We only want the exact address, "joe@example.com" - // to be valid. ParseAddress will extract the exact address as e.Address. So - // we'll take the input email, put it through ParseAddress, see if it parses - // successfully, and then compare the input email to e.Address to make sure - // that it was an exact address to begin with. - return string(email) == e.Address -} - -func validateClientSaltSeed(clientSaltSeed auth.ClientSaltSeed) bool { - _, err := hex.DecodeString(string(clientSaltSeed)) - const seedHexLength = auth.ClientSaltSeedLength * 2 - return len(clientSaltSeed) == seedHexLength && err == nil -} - // TODO - both wallet and token requests should be PUT, not POST. // PUT = "...creates a new resource or replaces a representation of the target resource with the request payload."