env
package provides config values
This commit is contained in:
parent
ade526f4f9
commit
f792ba5846
8 changed files with 217 additions and 28 deletions
22
auth/auth.go
22
auth/auth.go
|
@ -4,6 +4,7 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -113,6 +114,27 @@ func (p Password) Check(checkKey KDFKey, salt ServerSalt) (match bool, err error
|
|||
return
|
||||
}
|
||||
|
||||
func (e Email) Validate() bool {
|
||||
email, err := mail.ParseAddress(string(e))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// "Joe <joe@example.com>" is valid according to ParseAddress. Likewise
|
||||
// " joe@example.com". Etc. We only want the exact address, "joe@example.com"
|
||||
// to be valid. ParseAddress will extract the exact address as
|
||||
// parsed.Address. So we'll take the input email, put it through
|
||||
// ParseAddress, see if it parses successfully, and then compare the input
|
||||
// email to parsed.Address to make sure that it was an exact address to begin
|
||||
// with.
|
||||
return string(e) == email.Address
|
||||
}
|
||||
|
||||
func (c ClientSaltSeed) Validate() bool {
|
||||
_, err := hex.DecodeString(string(c))
|
||||
const seedHexLength = ClientSaltSeedLength * 2
|
||||
return len(c) == seedHexLength && err == nil
|
||||
}
|
||||
|
||||
// TODO consider unicode. Also some providers might be case sensitive, and/or
|
||||
// may have other ways of having email addresses be equivalent (which we may
|
||||
// not care about though)
|
||||
|
|
73
env/env.go
vendored
73
env/env.go
vendored
|
@ -1,9 +1,30 @@
|
|||
package env
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"lbryio/lbry-id/auth"
|
||||
)
|
||||
|
||||
const whitelistKey = "ACCOUNT_WHITELIST"
|
||||
const verificationModeKey = "ACCOUNT_VERIFICATION_MODE"
|
||||
|
||||
// Not exported, so that GetAccountVerificationMode is the only way to get one
|
||||
// of these values from outside of this package.
|
||||
type accountVerificationMode string
|
||||
|
||||
// Everyone can make an account. Only use for dev purposes.
|
||||
const ModeAllowAll = accountVerificationMode("AllowAll")
|
||||
|
||||
// Verify accounts via email. Good for big open servers.
|
||||
const ModeEmailVerify = accountVerificationMode("EmailVerify")
|
||||
|
||||
// Specific email accounts are automatically verified. Good for small
|
||||
// self-hosting users.
|
||||
const ModeWhitelist = accountVerificationMode("Whitelist")
|
||||
|
||||
// For test stubs
|
||||
type EnvInterface interface {
|
||||
Getenv(key string) string
|
||||
|
@ -14,3 +35,55 @@ type Env struct{}
|
|||
func (e *Env) Getenv(key string) string {
|
||||
return os.Getenv(key)
|
||||
}
|
||||
|
||||
func GetAccountVerificationMode(e EnvInterface) (accountVerificationMode, error) {
|
||||
return getAccountVerificationMode(e.Getenv(verificationModeKey))
|
||||
}
|
||||
|
||||
func GetAccountWhitelist(e EnvInterface, mode accountVerificationMode) (emails []auth.Email, err error) {
|
||||
return getAccountWhitelist(e.Getenv(whitelistKey), mode)
|
||||
}
|
||||
|
||||
// Factor out the guts of the functions so we can test them by just passing in
|
||||
// the env vars
|
||||
|
||||
func getAccountVerificationMode(modeStr string) (accountVerificationMode, error) {
|
||||
mode := accountVerificationMode(modeStr)
|
||||
switch mode {
|
||||
case "":
|
||||
// Whitelist is the least dangerous mode. If you forget to set any env
|
||||
// vars, it effectively disables all account creation.
|
||||
return ModeWhitelist, nil
|
||||
case ModeAllowAll:
|
||||
case ModeEmailVerify:
|
||||
case ModeWhitelist:
|
||||
default:
|
||||
return "", fmt.Errorf("Invalid account verification mode in %s: %s", verificationModeKey, mode)
|
||||
}
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
func getAccountWhitelist(whitelist string, mode accountVerificationMode) (emails []auth.Email, err error) {
|
||||
if whitelist == "" {
|
||||
return []auth.Email{}, nil
|
||||
}
|
||||
|
||||
if mode != ModeWhitelist {
|
||||
return nil, fmt.Errorf("Do not specify ACCOUNT_WHITELIST in env if ACCOUNT_VERIFICATION_MODE is not Whitelist")
|
||||
}
|
||||
|
||||
rawEmails := strings.Split(whitelist, ",")
|
||||
for _, rawEmail := range rawEmails {
|
||||
// Give them a specific error here to let them know not to add spaces. It
|
||||
// could be confusing otherwise to figure out what's invalid.
|
||||
if strings.TrimSpace(rawEmail) != rawEmail {
|
||||
return nil, fmt.Errorf("Emails in %s should be comma separated with no spaces.", whitelistKey)
|
||||
}
|
||||
email := auth.Email(rawEmail)
|
||||
if !email.Validate() {
|
||||
return nil, fmt.Errorf("Invalid email in %s: %s", whitelistKey, email)
|
||||
}
|
||||
emails = append(emails, email)
|
||||
}
|
||||
return emails, nil
|
||||
}
|
||||
|
|
115
env/env_test.go
vendored
Normal file
115
env/env_test.go
vendored
Normal file
|
@ -0,0 +1,115 @@
|
|||
package env
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"lbryio/lbry-id/auth"
|
||||
)
|
||||
|
||||
func TestAccountVerificationMode(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
|
||||
modeStr string
|
||||
expectedMode accountVerificationMode
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "allow all",
|
||||
|
||||
modeStr: "AllowAll",
|
||||
expectedMode: ModeAllowAll,
|
||||
},
|
||||
{
|
||||
name: "email verify",
|
||||
|
||||
modeStr: "EmailVerify",
|
||||
expectedMode: ModeEmailVerify,
|
||||
},
|
||||
{
|
||||
name: "whitelist",
|
||||
|
||||
modeStr: "Whitelist",
|
||||
expectedMode: ModeWhitelist,
|
||||
},
|
||||
{
|
||||
name: "blank",
|
||||
|
||||
modeStr: "",
|
||||
expectedMode: ModeWhitelist,
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
|
||||
modeStr: "Banana",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mode, err := getAccountVerificationMode(tc.modeStr)
|
||||
if mode != tc.expectedMode {
|
||||
t.Errorf("Expected mode %s got %s", tc.expectedMode, mode)
|
||||
}
|
||||
if tc.expectErr && err == nil {
|
||||
t.Errorf("Expected err")
|
||||
}
|
||||
if !tc.expectErr && err != nil {
|
||||
t.Errorf("Unexpected err: %s", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountWhitelist(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
|
||||
whitelist string
|
||||
expectedEmails []auth.Email
|
||||
expectedErr error
|
||||
mode accountVerificationMode
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
|
||||
mode: ModeWhitelist,
|
||||
whitelist: "",
|
||||
expectedEmails: []auth.Email{},
|
||||
},
|
||||
{
|
||||
name: "invalid mode",
|
||||
|
||||
mode: ModeEmailVerify,
|
||||
whitelist: "test1@example.com,test2@example.com",
|
||||
expectedErr: fmt.Errorf("Do not specify ACCOUNT_WHITELIST in env if ACCOUNT_VERIFICATION_MODE is not Whitelist"),
|
||||
},
|
||||
{
|
||||
name: "spaces in email",
|
||||
|
||||
mode: ModeWhitelist,
|
||||
whitelist: "test1@example.com ,test2@example.com",
|
||||
expectedErr: fmt.Errorf("Emails in ACCOUNT_WHITELIST should be comma separated with no spaces."),
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
|
||||
mode: ModeWhitelist,
|
||||
whitelist: "test1@example.com,test2-example.com",
|
||||
expectedErr: fmt.Errorf("Invalid email in ACCOUNT_WHITELIST: test2-example.com"),
|
||||
},
|
||||
}
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
emails, err := getAccountWhitelist(tc.whitelist, tc.mode)
|
||||
if !reflect.DeepEqual(emails, tc.expectedEmails) {
|
||||
t.Errorf("Expected emails %+v got %+v", tc.expectedEmails, emails)
|
||||
}
|
||||
if fmt.Sprint(err) != fmt.Sprint(tc.expectedErr) {
|
||||
t.Errorf("Expected error `%s` got `%s`", tc.expectedErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -16,13 +16,14 @@ type RegisterRequest struct {
|
|||
}
|
||||
|
||||
func (r *RegisterRequest) validate() error {
|
||||
if !validateEmail(r.Email) {
|
||||
if !r.Email.Validate() {
|
||||
return fmt.Errorf("Invalid or missing 'email'")
|
||||
}
|
||||
if r.Password == "" {
|
||||
return fmt.Errorf("Missing 'password'")
|
||||
}
|
||||
if !validateClientSaltSeed(r.ClientSaltSeed) {
|
||||
|
||||
if !r.ClientSaltSeed.Validate() {
|
||||
return fmt.Errorf("Invalid or missing 'clientSaltSeed'")
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -18,7 +18,7 @@ type AuthRequest struct {
|
|||
}
|
||||
|
||||
func (r *AuthRequest) validate() error {
|
||||
if !validateEmail(r.Email) {
|
||||
if !r.Email.Validate() {
|
||||
return fmt.Errorf("Invalid 'email'")
|
||||
}
|
||||
if r.Password == "" {
|
||||
|
|
|
@ -40,7 +40,7 @@ func getClientSaltSeedParams(req *http.Request) (email auth.Email, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
if !validateEmail(email) {
|
||||
if !email.Validate() {
|
||||
email = ""
|
||||
err = fmt.Errorf("Invalid email")
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ func (r *ChangePasswordRequest) validate() error {
|
|||
walletPresent := (r.EncryptedWallet != "" && r.Hmac != "" && r.Sequence > 0)
|
||||
walletAbsent := (r.EncryptedWallet == "" && r.Hmac == "" && r.Sequence == 0)
|
||||
|
||||
if !validateEmail(r.Email) {
|
||||
if !r.Email.Validate() {
|
||||
return fmt.Errorf("Invalid or missing 'email'")
|
||||
}
|
||||
if r.OldPassword == "" {
|
||||
|
@ -38,7 +38,7 @@ func (r *ChangePasswordRequest) validate() error {
|
|||
if r.OldPassword == r.NewPassword {
|
||||
return fmt.Errorf("'oldPassword' and 'newPassword' should not be the same")
|
||||
}
|
||||
if !validateClientSaltSeed(r.ClientSaltSeed) {
|
||||
if !r.ClientSaltSeed.Validate() {
|
||||
return fmt.Errorf("Invalid or missing 'clientSaltSeed'")
|
||||
}
|
||||
if !walletPresent && !walletAbsent {
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
||||
|
@ -171,26 +169,6 @@ func (s *Server) checkAuth(
|
|||
return authToken
|
||||
}
|
||||
|
||||
func validateEmail(email auth.Email) bool {
|
||||
e, err := mail.ParseAddress(string(email))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// "Joe <joe@example.com>" is valid according to ParseAddress. Likewise
|
||||
// " joe@example.com". Etc. We only want the exact address, "joe@example.com"
|
||||
// to be valid. ParseAddress will extract the exact address as e.Address. So
|
||||
// we'll take the input email, put it through ParseAddress, see if it parses
|
||||
// successfully, and then compare the input email to e.Address to make sure
|
||||
// that it was an exact address to begin with.
|
||||
return string(email) == e.Address
|
||||
}
|
||||
|
||||
func validateClientSaltSeed(clientSaltSeed auth.ClientSaltSeed) bool {
|
||||
_, err := hex.DecodeString(string(clientSaltSeed))
|
||||
const seedHexLength = auth.ClientSaltSeedLength * 2
|
||||
return len(clientSaltSeed) == seedHexLength && err == nil
|
||||
}
|
||||
|
||||
// TODO - both wallet and token requests should be PUT, not POST.
|
||||
// PUT = "...creates a new resource or replaces a representation of the target resource with the request payload."
|
||||
|
||||
|
|
Loading…
Reference in a new issue