hdkeychain: Consolidate tests into package.

Putting the test code in the same package makes it easier for forks
since they don't have to change the import paths as much.

Also, address a few style and consistent nits while here:
- Prefer t.Fatalf over t.Errorf followed by a return
- Use the consistent style of starting a test function comments with the
  test name
- Prefix test errors by the function being called instead of the one
  doing the calling since the caller itself is already logged by the
  test framework
- Check err in max depth test before checking the returned key is nil
This commit is contained in:
Dave Collins 2017-05-09 12:03:40 -05:00
parent 4f8b4dbcb2
commit 1fb0120cc6
No known key found for this signature in database
GPG key ID: B8904D9D9C93D1F2
2 changed files with 50 additions and 56 deletions

View file

@ -197,16 +197,17 @@ func (k *ExtendedKey) ParentFingerprint() uint32 {
// returned if this should occur, and the caller is expected to ignore the // returned if this should occur, and the caller is expected to ignore the
// invalid child and simply increment to the next index. // invalid child and simply increment to the next index.
func (k *ExtendedKey) Child(i uint32) (*ExtendedKey, error) { func (k *ExtendedKey) Child(i uint32) (*ExtendedKey, error) {
// Prevent derivation of children beyond the max allowed depth.
if k.depth == maxUint8 {
return nil, ErrDeriveBeyondMaxDepth
}
// There are four scenarios that could happen here: // There are four scenarios that could happen here:
// 1) Private extended key -> Hardened child private extended key // 1) Private extended key -> Hardened child private extended key
// 2) Private extended key -> Non-hardened child private extended key // 2) Private extended key -> Non-hardened child private extended key
// 3) Public extended key -> Non-hardened child public extended key // 3) Public extended key -> Non-hardened child public extended key
// 4) Public extended key -> Hardened child public extended key (INVALID!) // 4) Public extended key -> Hardened child public extended key (INVALID!)
if k.depth == maxUint8 {
return nil, ErrDeriveBeyondMaxDepth
}
// Case #4 is invalid, so error out early. // Case #4 is invalid, so error out early.
// A hardened child extended key may not be created from a public // A hardened child extended key may not be created from a public
// extended key. // extended key.

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by an ISC // Use of this source code is governed by an ISC
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package hdkeychain_test package hdkeychain
// References: // References:
// [BIP32]: BIP0032 - Hierarchical Deterministic Wallets // [BIP32]: BIP0032 - Hierarchical Deterministic Wallets
@ -17,7 +17,6 @@ import (
"testing" "testing"
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcutil/hdkeychain"
) )
// TestBIP0032Vectors tests the vectors provided by [BIP32] to ensure the // TestBIP0032Vectors tests the vectors provided by [BIP32] to ensure the
@ -196,7 +195,7 @@ tests:
continue continue
} }
extKey, err := hdkeychain.NewMaster(masterSeed, test.net) extKey, err := NewMaster(masterSeed, test.net)
if err != nil { if err != nil {
t.Errorf("NewMaster #%d (%s): unexpected error when "+ t.Errorf("NewMaster #%d (%s): unexpected error when "+
"creating new master key: %v", i, test.name, "creating new master key: %v", i, test.name,
@ -347,7 +346,7 @@ func TestPrivateDerivation(t *testing.T) {
tests: tests:
for i, test := range tests { for i, test := range tests {
extKey, err := hdkeychain.NewKeyFromString(test.master) extKey, err := NewKeyFromString(test.master)
if err != nil { if err != nil {
t.Errorf("NewKeyFromString #%d (%s): unexpected error "+ t.Errorf("NewKeyFromString #%d (%s): unexpected error "+
"creating extended key: %v", i, test.name, "creating extended key: %v", i, test.name,
@ -466,7 +465,7 @@ func TestPublicDerivation(t *testing.T) {
tests: tests:
for i, test := range tests { for i, test := range tests {
extKey, err := hdkeychain.NewKeyFromString(test.master) extKey, err := NewKeyFromString(test.master)
if err != nil { if err != nil {
t.Errorf("NewKeyFromString #%d (%s): unexpected error "+ t.Errorf("NewKeyFromString #%d (%s): unexpected error "+
"creating extended key: %v", i, test.name, "creating extended key: %v", i, test.name,
@ -515,7 +514,7 @@ func TestGenenerateSeed(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
seed, err := hdkeychain.GenerateSeed(test.length) seed, err := GenerateSeed(test.length)
if !reflect.DeepEqual(err, test.err) { if !reflect.DeepEqual(err, test.err) {
t.Errorf("GenerateSeed #%d (%s): unexpected error -- "+ t.Errorf("GenerateSeed #%d (%s): unexpected error -- "+
"want %v, got %v", i, test.name, test.err, err) "want %v, got %v", i, test.name, test.err, err)
@ -557,14 +556,14 @@ func TestExtendedKeyAPI(t *testing.T) {
extKey: "xpub6D4BDPcP2GT577Vvch3R8wDkScZWzQzMMUm3PWbmWvVJrZwQY4VUNgqFJPMM3No2dFDFGTsxxpG5uJh7n7epu4trkrX7x7DogT5Uv6fcLW5", extKey: "xpub6D4BDPcP2GT577Vvch3R8wDkScZWzQzMMUm3PWbmWvVJrZwQY4VUNgqFJPMM3No2dFDFGTsxxpG5uJh7n7epu4trkrX7x7DogT5Uv6fcLW5",
isPrivate: false, isPrivate: false,
parentFP: 3203769081, parentFP: 3203769081,
privKeyErr: hdkeychain.ErrNotPrivExtKey, privKeyErr: ErrNotPrivExtKey,
pubKey: "0357bfe1e341d01c69fe5654309956cbea516822fba8a601743a012a7896ee8dc2", pubKey: "0357bfe1e341d01c69fe5654309956cbea516822fba8a601743a012a7896ee8dc2",
address: "1NjxqbA9aZWnh17q1UW3rB4EPu79wDXj7x", address: "1NjxqbA9aZWnh17q1UW3rB4EPu79wDXj7x",
}, },
} }
for i, test := range tests { for i, test := range tests {
key, err := hdkeychain.NewKeyFromString(test.extKey) key, err := NewKeyFromString(test.extKey)
if err != nil { if err != nil {
t.Errorf("NewKeyFromString #%d (%s): unexpected "+ t.Errorf("NewKeyFromString #%d (%s): unexpected "+
"error: %v", i, test.name, err) "error: %v", i, test.name, err)
@ -724,7 +723,7 @@ func TestNet(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
extKey, err := hdkeychain.NewKeyFromString(test.key) extKey, err := NewKeyFromString(test.key)
if err != nil { if err != nil {
t.Errorf("NewKeyFromString #%d (%s): unexpected error "+ t.Errorf("NewKeyFromString #%d (%s): unexpected error "+
"creating extended key: %v", i, test.name, "creating extended key: %v", i, test.name,
@ -778,41 +777,38 @@ func TestNet(t *testing.T) {
func TestErrors(t *testing.T) { func TestErrors(t *testing.T) {
// Should get an error when seed has too few bytes. // Should get an error when seed has too few bytes.
net := &chaincfg.MainNetParams net := &chaincfg.MainNetParams
_, err := hdkeychain.NewMaster(bytes.Repeat([]byte{0x00}, 15), net) _, err := NewMaster(bytes.Repeat([]byte{0x00}, 15), net)
if err != hdkeychain.ErrInvalidSeedLen { if err != ErrInvalidSeedLen {
t.Errorf("NewMaster: mismatched error -- got: %v, want: %v", t.Fatalf("NewMaster: mismatched error -- got: %v, want: %v",
err, hdkeychain.ErrInvalidSeedLen) err, ErrInvalidSeedLen)
} }
// Should get an error when seed has too many bytes. // Should get an error when seed has too many bytes.
_, err = hdkeychain.NewMaster(bytes.Repeat([]byte{0x00}, 65), net) _, err = NewMaster(bytes.Repeat([]byte{0x00}, 65), net)
if err != hdkeychain.ErrInvalidSeedLen { if err != ErrInvalidSeedLen {
t.Errorf("NewMaster: mismatched error -- got: %v, want: %v", t.Fatalf("NewMaster: mismatched error -- got: %v, want: %v",
err, hdkeychain.ErrInvalidSeedLen) err, ErrInvalidSeedLen)
} }
// Generate a new key and neuter it to a public extended key. // Generate a new key and neuter it to a public extended key.
seed, err := hdkeychain.GenerateSeed(hdkeychain.RecommendedSeedLen) seed, err := GenerateSeed(RecommendedSeedLen)
if err != nil { if err != nil {
t.Errorf("GenerateSeed: unexpected error: %v", err) t.Fatalf("GenerateSeed: unexpected error: %v", err)
return
} }
extKey, err := hdkeychain.NewMaster(seed, net) extKey, err := NewMaster(seed, net)
if err != nil { if err != nil {
t.Errorf("NewMaster: unexpected error: %v", err) t.Fatalf("NewMaster: unexpected error: %v", err)
return
} }
pubKey, err := extKey.Neuter() pubKey, err := extKey.Neuter()
if err != nil { if err != nil {
t.Errorf("Neuter: unexpected error: %v", err) t.Fatalf("Neuter: unexpected error: %v", err)
return
} }
// Deriving a hardened child extended key should fail from a public key. // Deriving a hardened child extended key should fail from a public key.
_, err = pubKey.Child(hdkeychain.HardenedKeyStart) _, err = pubKey.Child(HardenedKeyStart)
if err != hdkeychain.ErrDeriveHardFromPublic { if err != ErrDeriveHardFromPublic {
t.Errorf("Child: mismatched error -- got: %v, want: %v", t.Fatalf("Child: mismatched error -- got: %v, want: %v",
err, hdkeychain.ErrDeriveHardFromPublic) err, ErrDeriveHardFromPublic)
} }
// NewKeyFromString failure tests. // NewKeyFromString failure tests.
@ -826,12 +822,12 @@ func TestErrors(t *testing.T) {
{ {
name: "invalid key length", name: "invalid key length",
key: "xpub1234", key: "xpub1234",
err: hdkeychain.ErrInvalidKeyLen, err: ErrInvalidKeyLen,
}, },
{ {
name: "bad checksum", name: "bad checksum",
key: "xpub661MyMwAqRbcFtXgS5sYJABqqG9YLmC4Q1Rdap9gSE8NqtwybGhePY2gZ29ESFjqJoCu1Rupje8YtGqsefD265TMg7usUDFdp6W1EBygr15", key: "xpub661MyMwAqRbcFtXgS5sYJABqqG9YLmC4Q1Rdap9gSE8NqtwybGhePY2gZ29ESFjqJoCu1Rupje8YtGqsefD265TMg7usUDFdp6W1EBygr15",
err: hdkeychain.ErrBadChecksum, err: ErrBadChecksum,
}, },
{ {
name: "pubkey not on curve", name: "pubkey not on curve",
@ -848,7 +844,7 @@ func TestErrors(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
extKey, err := hdkeychain.NewKeyFromString(test.key) extKey, err := NewKeyFromString(test.key)
if !reflect.DeepEqual(err, test.err) { if !reflect.DeepEqual(err, test.err) {
t.Errorf("NewKeyFromString #%d (%s): mismatched error "+ t.Errorf("NewKeyFromString #%d (%s): mismatched error "+
"-- got: %v, want: %v", i, test.name, err, "-- got: %v, want: %v", i, test.name, err,
@ -896,7 +892,7 @@ func TestZero(t *testing.T) {
// Use a closure to test that a key is zeroed since the tests create // Use a closure to test that a key is zeroed since the tests create
// keys in different ways and need to test the same things multiple // keys in different ways and need to test the same things multiple
// times. // times.
testZeroed := func(i int, testName string, key *hdkeychain.ExtendedKey) bool { testZeroed := func(i int, testName string, key *ExtendedKey) bool {
// Zeroing a key should result in it no longer being private // Zeroing a key should result in it no longer being private
if key.IsPrivate() { if key.IsPrivate() {
t.Errorf("IsPrivate #%d (%s): mismatched key type -- "+ t.Errorf("IsPrivate #%d (%s): mismatched key type -- "+
@ -922,7 +918,7 @@ func TestZero(t *testing.T) {
return false return false
} }
wantErr := hdkeychain.ErrNotPrivExtKey wantErr := ErrNotPrivExtKey
_, err := key.ECPrivKey() _, err := key.ECPrivKey()
if !reflect.DeepEqual(err, wantErr) { if !reflect.DeepEqual(err, wantErr) {
t.Errorf("ECPrivKey #%d (%s): mismatched error: want "+ t.Errorf("ECPrivKey #%d (%s): mismatched error: want "+
@ -963,7 +959,7 @@ func TestZero(t *testing.T) {
i, test.name, err) i, test.name, err)
continue continue
} }
key, err := hdkeychain.NewMaster(masterSeed, test.net) key, err := NewMaster(masterSeed, test.net)
if err != nil { if err != nil {
t.Errorf("NewMaster #%d (%s): unexpected error when "+ t.Errorf("NewMaster #%d (%s): unexpected error when "+
"creating new master key: %v", i, test.name, "creating new master key: %v", i, test.name,
@ -989,7 +985,7 @@ func TestZero(t *testing.T) {
} }
// Deserialize key and get the neutered version. // Deserialize key and get the neutered version.
key, err = hdkeychain.NewKeyFromString(test.extKey) key, err = NewKeyFromString(test.extKey)
if err != nil { if err != nil {
t.Errorf("NewKeyFromString #%d (%s): unexpected "+ t.Errorf("NewKeyFromString #%d (%s): unexpected "+
"error: %v", i, test.name, err) "error: %v", i, test.name, err)
@ -1015,34 +1011,31 @@ func TestZero(t *testing.T) {
} }
} }
// The serialization of a BIP32 key uses uint8 to encode the depth. This implicitly // TestMaximumDepth ensures that attempting to retrieve a child key when already
// bounds the depth of the tree to 255 derivations. Here we test that an error is // at the maximum depth is not allowed. The serialization of a BIP32 key uses
// returned after 'max uint8'. // uint8 to encode the depth. This implicitly bounds the depth of the tree to
// 255 derivations. Here we test that an error is returned after 'max uint8'.
func TestMaximumDepth(t *testing.T) { func TestMaximumDepth(t *testing.T) {
net := &chaincfg.MainNetParams
extKey, err := hdkeychain.NewMaster([]byte(`abcd1234abcd1234abcd1234abcd1234`), &chaincfg.MainNetParams) extKey, err := NewMaster([]byte(`abcd1234abcd1234abcd1234abcd1234`), net)
if err != nil { if err != nil {
t.Error("MaxDepthTest: Failed to produce test fixture key from string") t.Fatalf("NewMaster: unexpected error: %v", err)
return
} }
for i := uint8(0); i < math.MaxUint8; i++ { for i := uint8(0); i < math.MaxUint8; i++ {
newKey, err := extKey.Child(1) newKey, err := extKey.Child(1)
if err != nil { if err != nil {
t.Error("MaxDepthTest: Failed to produce key required for test") t.Fatalf("Child: unexpected error: %v", err)
return
} }
extKey = newKey extKey = newKey
} }
noKey, err := extKey.Child(1) noKey, err := extKey.Child(1)
if err != ErrDeriveBeyondMaxDepth {
t.Fatalf("Child: mismatched error: want %v, got %v",
ErrDeriveBeyondMaxDepth, err)
}
if noKey != nil { if noKey != nil {
t.Error("MaxDepthTest: Deriving 256th key should not succeed") t.Fatal("Child: deriving 256th key should not succeed")
return
} }
if err != hdkeychain.ErrDeriveBeyondMaxDepth {
t.Error("MaxDepthTest: Received unexpected error during test")
}
} }