env package provides config values

This commit is contained in:
Daniel Krol 2022-07-24 16:02:55 -04:00
parent ade526f4f9
commit f792ba5846
8 changed files with 217 additions and 28 deletions

View file

@ -4,6 +4,7 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"net/mail"
"strings"
"time"
@ -113,6 +114,27 @@ func (p Password) Check(checkKey KDFKey, salt ServerSalt) (match bool, err error
return
}
func (e Email) Validate() bool {
email, err := mail.ParseAddress(string(e))
if err != nil {
return false
}
// "Joe <joe@example.com>" is valid according to ParseAddress. Likewise
// " joe@example.com". Etc. We only want the exact address, "joe@example.com"
// to be valid. ParseAddress will extract the exact address as
// parsed.Address. So we'll take the input email, put it through
// ParseAddress, see if it parses successfully, and then compare the input
// email to parsed.Address to make sure that it was an exact address to begin
// with.
return string(e) == email.Address
}
func (c ClientSaltSeed) Validate() bool {
_, err := hex.DecodeString(string(c))
const seedHexLength = ClientSaltSeedLength * 2
return len(c) == seedHexLength && err == nil
}
// TODO consider unicode. Also some providers might be case sensitive, and/or
// may have other ways of having email addresses be equivalent (which we may
// not care about though)

73
env/env.go vendored
View file

@ -1,9 +1,30 @@
package env
import (
"fmt"
"os"
"strings"
"lbryio/lbry-id/auth"
)
const whitelistKey = "ACCOUNT_WHITELIST"
const verificationModeKey = "ACCOUNT_VERIFICATION_MODE"
// Not exported, so that GetAccountVerificationMode is the only way to get one
// of these values from outside of this package.
type accountVerificationMode string
// Everyone can make an account. Only use for dev purposes.
const ModeAllowAll = accountVerificationMode("AllowAll")
// Verify accounts via email. Good for big open servers.
const ModeEmailVerify = accountVerificationMode("EmailVerify")
// Specific email accounts are automatically verified. Good for small
// self-hosting users.
const ModeWhitelist = accountVerificationMode("Whitelist")
// For test stubs
type EnvInterface interface {
Getenv(key string) string
@ -14,3 +35,55 @@ type Env struct{}
func (e *Env) Getenv(key string) string {
return os.Getenv(key)
}
func GetAccountVerificationMode(e EnvInterface) (accountVerificationMode, error) {
return getAccountVerificationMode(e.Getenv(verificationModeKey))
}
func GetAccountWhitelist(e EnvInterface, mode accountVerificationMode) (emails []auth.Email, err error) {
return getAccountWhitelist(e.Getenv(whitelistKey), mode)
}
// Factor out the guts of the functions so we can test them by just passing in
// the env vars
func getAccountVerificationMode(modeStr string) (accountVerificationMode, error) {
mode := accountVerificationMode(modeStr)
switch mode {
case "":
// Whitelist is the least dangerous mode. If you forget to set any env
// vars, it effectively disables all account creation.
return ModeWhitelist, nil
case ModeAllowAll:
case ModeEmailVerify:
case ModeWhitelist:
default:
return "", fmt.Errorf("Invalid account verification mode in %s: %s", verificationModeKey, mode)
}
return mode, nil
}
func getAccountWhitelist(whitelist string, mode accountVerificationMode) (emails []auth.Email, err error) {
if whitelist == "" {
return []auth.Email{}, nil
}
if mode != ModeWhitelist {
return nil, fmt.Errorf("Do not specify ACCOUNT_WHITELIST in env if ACCOUNT_VERIFICATION_MODE is not Whitelist")
}
rawEmails := strings.Split(whitelist, ",")
for _, rawEmail := range rawEmails {
// Give them a specific error here to let them know not to add spaces. It
// could be confusing otherwise to figure out what's invalid.
if strings.TrimSpace(rawEmail) != rawEmail {
return nil, fmt.Errorf("Emails in %s should be comma separated with no spaces.", whitelistKey)
}
email := auth.Email(rawEmail)
if !email.Validate() {
return nil, fmt.Errorf("Invalid email in %s: %s", whitelistKey, email)
}
emails = append(emails, email)
}
return emails, nil
}

115
env/env_test.go vendored Normal file
View file

@ -0,0 +1,115 @@
package env
import (
"fmt"
"reflect"
"testing"
"lbryio/lbry-id/auth"
)
func TestAccountVerificationMode(t *testing.T) {
tt := []struct {
name string
modeStr string
expectedMode accountVerificationMode
expectErr bool
}{
{
name: "allow all",
modeStr: "AllowAll",
expectedMode: ModeAllowAll,
},
{
name: "email verify",
modeStr: "EmailVerify",
expectedMode: ModeEmailVerify,
},
{
name: "whitelist",
modeStr: "Whitelist",
expectedMode: ModeWhitelist,
},
{
name: "blank",
modeStr: "",
expectedMode: ModeWhitelist,
},
{
name: "invalid",
modeStr: "Banana",
expectErr: true,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
mode, err := getAccountVerificationMode(tc.modeStr)
if mode != tc.expectedMode {
t.Errorf("Expected mode %s got %s", tc.expectedMode, mode)
}
if tc.expectErr && err == nil {
t.Errorf("Expected err")
}
if !tc.expectErr && err != nil {
t.Errorf("Unexpected err: %s", err.Error())
}
})
}
}
func TestAccountWhitelist(t *testing.T) {
tt := []struct {
name string
whitelist string
expectedEmails []auth.Email
expectedErr error
mode accountVerificationMode
}{
{
name: "empty",
mode: ModeWhitelist,
whitelist: "",
expectedEmails: []auth.Email{},
},
{
name: "invalid mode",
mode: ModeEmailVerify,
whitelist: "test1@example.com,test2@example.com",
expectedErr: fmt.Errorf("Do not specify ACCOUNT_WHITELIST in env if ACCOUNT_VERIFICATION_MODE is not Whitelist"),
},
{
name: "spaces in email",
mode: ModeWhitelist,
whitelist: "test1@example.com ,test2@example.com",
expectedErr: fmt.Errorf("Emails in ACCOUNT_WHITELIST should be comma separated with no spaces."),
},
{
name: "invalid email",
mode: ModeWhitelist,
whitelist: "test1@example.com,test2-example.com",
expectedErr: fmt.Errorf("Invalid email in ACCOUNT_WHITELIST: test2-example.com"),
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
emails, err := getAccountWhitelist(tc.whitelist, tc.mode)
if !reflect.DeepEqual(emails, tc.expectedEmails) {
t.Errorf("Expected emails %+v got %+v", tc.expectedEmails, emails)
}
if fmt.Sprint(err) != fmt.Sprint(tc.expectedErr) {
t.Errorf("Expected error `%s` got `%s`", tc.expectedErr, err.Error())
}
})
}
}

View file

@ -16,13 +16,14 @@ type RegisterRequest struct {
}
func (r *RegisterRequest) validate() error {
if !validateEmail(r.Email) {
if !r.Email.Validate() {
return fmt.Errorf("Invalid or missing 'email'")
}
if r.Password == "" {
return fmt.Errorf("Missing 'password'")
}
if !validateClientSaltSeed(r.ClientSaltSeed) {
if !r.ClientSaltSeed.Validate() {
return fmt.Errorf("Invalid or missing 'clientSaltSeed'")
}
return nil

View file

@ -18,7 +18,7 @@ type AuthRequest struct {
}
func (r *AuthRequest) validate() error {
if !validateEmail(r.Email) {
if !r.Email.Validate() {
return fmt.Errorf("Invalid 'email'")
}
if r.Password == "" {

View file

@ -40,7 +40,7 @@ func getClientSaltSeedParams(req *http.Request) (email auth.Email, err error) {
}
}
if !validateEmail(email) {
if !email.Validate() {
email = ""
err = fmt.Errorf("Invalid email")
}

View file

@ -25,7 +25,7 @@ func (r *ChangePasswordRequest) validate() error {
walletPresent := (r.EncryptedWallet != "" && r.Hmac != "" && r.Sequence > 0)
walletAbsent := (r.EncryptedWallet == "" && r.Hmac == "" && r.Sequence == 0)
if !validateEmail(r.Email) {
if !r.Email.Validate() {
return fmt.Errorf("Invalid or missing 'email'")
}
if r.OldPassword == "" {
@ -38,7 +38,7 @@ func (r *ChangePasswordRequest) validate() error {
if r.OldPassword == r.NewPassword {
return fmt.Errorf("'oldPassword' and 'newPassword' should not be the same")
}
if !validateClientSaltSeed(r.ClientSaltSeed) {
if !r.ClientSaltSeed.Validate() {
return fmt.Errorf("Invalid or missing 'clientSaltSeed'")
}
if !walletPresent && !walletAbsent {

View file

@ -1,12 +1,10 @@
package server
import (
"encoding/hex"
"encoding/json"
"fmt"
"log"
"net/http"
"net/mail"
"github.com/prometheus/client_golang/prometheus/promhttp"
@ -171,26 +169,6 @@ func (s *Server) checkAuth(
return authToken
}
func validateEmail(email auth.Email) bool {
e, err := mail.ParseAddress(string(email))
if err != nil {
return false
}
// "Joe <joe@example.com>" is valid according to ParseAddress. Likewise
// " joe@example.com". Etc. We only want the exact address, "joe@example.com"
// to be valid. ParseAddress will extract the exact address as e.Address. So
// we'll take the input email, put it through ParseAddress, see if it parses
// successfully, and then compare the input email to e.Address to make sure
// that it was an exact address to begin with.
return string(email) == e.Address
}
func validateClientSaltSeed(clientSaltSeed auth.ClientSaltSeed) bool {
_, err := hex.DecodeString(string(clientSaltSeed))
const seedHexLength = auth.ClientSaltSeedLength * 2
return len(clientSaltSeed) == seedHexLength && err == nil
}
// TODO - both wallet and token requests should be PUT, not POST.
// PUT = "...creates a new resource or replaces a representation of the target resource with the request payload."