diff --git a/hdkeychain/extendedkey.go b/hdkeychain/extendedkey.go index 0912060..e8ebec2 100644 --- a/hdkeychain/extendedkey.go +++ b/hdkeychain/extendedkey.go @@ -197,16 +197,17 @@ func (k *ExtendedKey) ParentFingerprint() uint32 { // returned if this should occur, and the caller is expected to ignore the // invalid child and simply increment to the next index. func (k *ExtendedKey) Child(i uint32) (*ExtendedKey, error) { + // Prevent derivation of children beyond the max allowed depth. + if k.depth == maxUint8 { + return nil, ErrDeriveBeyondMaxDepth + } + // There are four scenarios that could happen here: // 1) Private extended key -> Hardened child private extended key // 2) Private extended key -> Non-hardened child private extended key // 3) Public extended key -> Non-hardened child public extended key // 4) Public extended key -> Hardened child public extended key (INVALID!) - if k.depth == maxUint8 { - return nil, ErrDeriveBeyondMaxDepth - } - // Case #4 is invalid, so error out early. // A hardened child extended key may not be created from a public // extended key. diff --git a/hdkeychain/extendedkey_test.go b/hdkeychain/extendedkey_test.go index 320e242..55eada5 100644 --- a/hdkeychain/extendedkey_test.go +++ b/hdkeychain/extendedkey_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. -package hdkeychain_test +package hdkeychain // References: // [BIP32]: BIP0032 - Hierarchical Deterministic Wallets @@ -17,7 +17,6 @@ import ( "testing" "github.com/btcsuite/btcd/chaincfg" - "github.com/btcsuite/btcutil/hdkeychain" ) // TestBIP0032Vectors tests the vectors provided by [BIP32] to ensure the @@ -196,7 +195,7 @@ tests: continue } - extKey, err := hdkeychain.NewMaster(masterSeed, test.net) + extKey, err := NewMaster(masterSeed, test.net) if err != nil { t.Errorf("NewMaster #%d (%s): unexpected error when "+ "creating new master key: %v", i, test.name, @@ -347,7 +346,7 @@ func TestPrivateDerivation(t *testing.T) { tests: for i, test := range tests { - extKey, err := hdkeychain.NewKeyFromString(test.master) + extKey, err := NewKeyFromString(test.master) if err != nil { t.Errorf("NewKeyFromString #%d (%s): unexpected error "+ "creating extended key: %v", i, test.name, @@ -466,7 +465,7 @@ func TestPublicDerivation(t *testing.T) { tests: for i, test := range tests { - extKey, err := hdkeychain.NewKeyFromString(test.master) + extKey, err := NewKeyFromString(test.master) if err != nil { t.Errorf("NewKeyFromString #%d (%s): unexpected error "+ "creating extended key: %v", i, test.name, @@ -515,7 +514,7 @@ func TestGenenerateSeed(t *testing.T) { } for i, test := range tests { - seed, err := hdkeychain.GenerateSeed(test.length) + seed, err := GenerateSeed(test.length) if !reflect.DeepEqual(err, test.err) { t.Errorf("GenerateSeed #%d (%s): unexpected error -- "+ "want %v, got %v", i, test.name, test.err, err) @@ -557,14 +556,14 @@ func TestExtendedKeyAPI(t *testing.T) { extKey: "xpub6D4BDPcP2GT577Vvch3R8wDkScZWzQzMMUm3PWbmWvVJrZwQY4VUNgqFJPMM3No2dFDFGTsxxpG5uJh7n7epu4trkrX7x7DogT5Uv6fcLW5", isPrivate: false, parentFP: 3203769081, - privKeyErr: hdkeychain.ErrNotPrivExtKey, + privKeyErr: ErrNotPrivExtKey, pubKey: "0357bfe1e341d01c69fe5654309956cbea516822fba8a601743a012a7896ee8dc2", address: "1NjxqbA9aZWnh17q1UW3rB4EPu79wDXj7x", }, } for i, test := range tests { - key, err := hdkeychain.NewKeyFromString(test.extKey) + key, err := NewKeyFromString(test.extKey) if err != nil { t.Errorf("NewKeyFromString #%d (%s): unexpected "+ "error: %v", i, test.name, err) @@ -724,7 +723,7 @@ func TestNet(t *testing.T) { } for i, test := range tests { - extKey, err := hdkeychain.NewKeyFromString(test.key) + extKey, err := NewKeyFromString(test.key) if err != nil { t.Errorf("NewKeyFromString #%d (%s): unexpected error "+ "creating extended key: %v", i, test.name, @@ -778,41 +777,38 @@ func TestNet(t *testing.T) { func TestErrors(t *testing.T) { // Should get an error when seed has too few bytes. net := &chaincfg.MainNetParams - _, err := hdkeychain.NewMaster(bytes.Repeat([]byte{0x00}, 15), net) - if err != hdkeychain.ErrInvalidSeedLen { - t.Errorf("NewMaster: mismatched error -- got: %v, want: %v", - err, hdkeychain.ErrInvalidSeedLen) + _, err := NewMaster(bytes.Repeat([]byte{0x00}, 15), net) + if err != ErrInvalidSeedLen { + t.Fatalf("NewMaster: mismatched error -- got: %v, want: %v", + err, ErrInvalidSeedLen) } // Should get an error when seed has too many bytes. - _, err = hdkeychain.NewMaster(bytes.Repeat([]byte{0x00}, 65), net) - if err != hdkeychain.ErrInvalidSeedLen { - t.Errorf("NewMaster: mismatched error -- got: %v, want: %v", - err, hdkeychain.ErrInvalidSeedLen) + _, err = NewMaster(bytes.Repeat([]byte{0x00}, 65), net) + if err != ErrInvalidSeedLen { + t.Fatalf("NewMaster: mismatched error -- got: %v, want: %v", + err, ErrInvalidSeedLen) } // Generate a new key and neuter it to a public extended key. - seed, err := hdkeychain.GenerateSeed(hdkeychain.RecommendedSeedLen) + seed, err := GenerateSeed(RecommendedSeedLen) if err != nil { - t.Errorf("GenerateSeed: unexpected error: %v", err) - return + t.Fatalf("GenerateSeed: unexpected error: %v", err) } - extKey, err := hdkeychain.NewMaster(seed, net) + extKey, err := NewMaster(seed, net) if err != nil { - t.Errorf("NewMaster: unexpected error: %v", err) - return + t.Fatalf("NewMaster: unexpected error: %v", err) } pubKey, err := extKey.Neuter() if err != nil { - t.Errorf("Neuter: unexpected error: %v", err) - return + t.Fatalf("Neuter: unexpected error: %v", err) } // Deriving a hardened child extended key should fail from a public key. - _, err = pubKey.Child(hdkeychain.HardenedKeyStart) - if err != hdkeychain.ErrDeriveHardFromPublic { - t.Errorf("Child: mismatched error -- got: %v, want: %v", - err, hdkeychain.ErrDeriveHardFromPublic) + _, err = pubKey.Child(HardenedKeyStart) + if err != ErrDeriveHardFromPublic { + t.Fatalf("Child: mismatched error -- got: %v, want: %v", + err, ErrDeriveHardFromPublic) } // NewKeyFromString failure tests. @@ -826,12 +822,12 @@ func TestErrors(t *testing.T) { { name: "invalid key length", key: "xpub1234", - err: hdkeychain.ErrInvalidKeyLen, + err: ErrInvalidKeyLen, }, { name: "bad checksum", key: "xpub661MyMwAqRbcFtXgS5sYJABqqG9YLmC4Q1Rdap9gSE8NqtwybGhePY2gZ29ESFjqJoCu1Rupje8YtGqsefD265TMg7usUDFdp6W1EBygr15", - err: hdkeychain.ErrBadChecksum, + err: ErrBadChecksum, }, { name: "pubkey not on curve", @@ -848,7 +844,7 @@ func TestErrors(t *testing.T) { } for i, test := range tests { - extKey, err := hdkeychain.NewKeyFromString(test.key) + extKey, err := NewKeyFromString(test.key) if !reflect.DeepEqual(err, test.err) { t.Errorf("NewKeyFromString #%d (%s): mismatched error "+ "-- got: %v, want: %v", i, test.name, err, @@ -896,7 +892,7 @@ func TestZero(t *testing.T) { // Use a closure to test that a key is zeroed since the tests create // keys in different ways and need to test the same things multiple // times. - testZeroed := func(i int, testName string, key *hdkeychain.ExtendedKey) bool { + testZeroed := func(i int, testName string, key *ExtendedKey) bool { // Zeroing a key should result in it no longer being private if key.IsPrivate() { t.Errorf("IsPrivate #%d (%s): mismatched key type -- "+ @@ -922,7 +918,7 @@ func TestZero(t *testing.T) { return false } - wantErr := hdkeychain.ErrNotPrivExtKey + wantErr := ErrNotPrivExtKey _, err := key.ECPrivKey() if !reflect.DeepEqual(err, wantErr) { t.Errorf("ECPrivKey #%d (%s): mismatched error: want "+ @@ -963,7 +959,7 @@ func TestZero(t *testing.T) { i, test.name, err) continue } - key, err := hdkeychain.NewMaster(masterSeed, test.net) + key, err := NewMaster(masterSeed, test.net) if err != nil { t.Errorf("NewMaster #%d (%s): unexpected error when "+ "creating new master key: %v", i, test.name, @@ -989,7 +985,7 @@ func TestZero(t *testing.T) { } // Deserialize key and get the neutered version. - key, err = hdkeychain.NewKeyFromString(test.extKey) + key, err = NewKeyFromString(test.extKey) if err != nil { t.Errorf("NewKeyFromString #%d (%s): unexpected "+ "error: %v", i, test.name, err) @@ -1015,34 +1011,31 @@ func TestZero(t *testing.T) { } } -// The serialization of a BIP32 key uses uint8 to encode the depth. This implicitly -// bounds the depth of the tree to 255 derivations. Here we test that an error is -// returned after 'max uint8'. +// TestMaximumDepth ensures that attempting to retrieve a child key when already +// at the maximum depth is not allowed. The serialization of a BIP32 key uses +// uint8 to encode the depth. This implicitly bounds the depth of the tree to +// 255 derivations. Here we test that an error is returned after 'max uint8'. func TestMaximumDepth(t *testing.T) { - - extKey, err := hdkeychain.NewMaster([]byte(`abcd1234abcd1234abcd1234abcd1234`), &chaincfg.MainNetParams) + net := &chaincfg.MainNetParams + extKey, err := NewMaster([]byte(`abcd1234abcd1234abcd1234abcd1234`), net) if err != nil { - t.Error("MaxDepthTest: Failed to produce test fixture key from string") - return + t.Fatalf("NewMaster: unexpected error: %v", err) } for i := uint8(0); i < math.MaxUint8; i++ { newKey, err := extKey.Child(1) if err != nil { - t.Error("MaxDepthTest: Failed to produce key required for test") - return + t.Fatalf("Child: unexpected error: %v", err) } extKey = newKey } noKey, err := extKey.Child(1) + if err != ErrDeriveBeyondMaxDepth { + t.Fatalf("Child: mismatched error: want %v, got %v", + ErrDeriveBeyondMaxDepth, err) + } if noKey != nil { - t.Error("MaxDepthTest: Deriving 256th key should not succeed") - return + t.Fatal("Child: deriving 256th key should not succeed") } - - if err != hdkeychain.ErrDeriveBeyondMaxDepth { - t.Error("MaxDepthTest: Received unexpected error during test") - } - }