Fix hdkeychain to avoid zeroing net version bytes.
This commit corrects the Zero function in hdkeychain to nil the version instead of zeroing the bytes. This is necessary because the keys are holding onto a reference into the specific version bytes for the network as provided by the btcnet package. Zeroing them causes the bytes in the btcnet package to be zeroed which then leads to issues later when trying to use them. Also, to prevent regressions, new tests have been added to exercise this scenario. Pointed out by @jimmysong.
This commit is contained in:
parent
d4a2dd199b
commit
2539ca9860
2 changed files with 89 additions and 24 deletions
|
@ -424,7 +424,7 @@ func (k *ExtendedKey) Zero() {
|
|||
zero(k.pubKey)
|
||||
zero(k.chainCode)
|
||||
zero(k.parentFP)
|
||||
zero(k.version)
|
||||
k.version = nil
|
||||
k.key = nil
|
||||
k.depth = 0
|
||||
k.childNum = 0
|
||||
|
|
|
@ -679,76 +679,141 @@ func TestErrors(t *testing.T) {
|
|||
func TestZero(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
master string
|
||||
extKey string
|
||||
}{
|
||||
// Test vector 1
|
||||
{
|
||||
name: "test vector 1 chain m",
|
||||
master: "000102030405060708090a0b0c0d0e0f",
|
||||
extKey: "xprv9s21ZrQH143K3QTDL4LXw2F7HEK3wJUD2nW2nRk4stbPy6cq3jPPqjiChkVvvNKmPGJxWUtg6LnF5kejMRNNU3TGtRBeJgk33yuGBxrMPHi",
|
||||
},
|
||||
|
||||
// Test vector 2
|
||||
{
|
||||
name: "test vector 2 chain m",
|
||||
master: "fffcf9f6f3f0edeae7e4e1dedbd8d5d2cfccc9c6c3c0bdbab7b4b1aeaba8a5a29f9c999693908d8a8784817e7b7875726f6c696663605d5a5754514e4b484542",
|
||||
extKey: "xprv9s21ZrQH143K31xYSDQpPDxsXRTUcvj2iNHm5NUtrGiGG5e2DtALGdso3pGz6ssrdK4PFmM8NSpSBHNqPqm55Qn3LqFtT2emdEXVYsCzC2U",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
key, err := hdkeychain.NewKeyFromString(test.extKey)
|
||||
if err != nil {
|
||||
t.Errorf("NewKeyFromString #%d (%s): unexpected "+
|
||||
"error: %v", i, test.name, err)
|
||||
continue
|
||||
}
|
||||
key.Zero()
|
||||
|
||||
// Use a closure to test that a key is zeroed since the tests create
|
||||
// keys in different ways and need to test the same things multiple
|
||||
// times.
|
||||
testZeroed := func(i int, testName string, key *hdkeychain.ExtendedKey) bool {
|
||||
// Zeroing a key should result in it no longer being private
|
||||
if key.IsPrivate() != false {
|
||||
t.Errorf("IsPrivate #%d (%s): mismatched key type -- "+
|
||||
"want private %v, got private %v", i, test.name,
|
||||
"want private %v, got private %v", i, testName,
|
||||
false, key.IsPrivate())
|
||||
continue
|
||||
return false
|
||||
}
|
||||
|
||||
parentFP := key.ParentFingerprint()
|
||||
if parentFP != 0 {
|
||||
t.Errorf("ParentFingerprint #%d (%s): mismatched "+
|
||||
"parent fingerprint -- want %d, got %d", i,
|
||||
test.name, 0, parentFP)
|
||||
continue
|
||||
testName, 0, parentFP)
|
||||
return false
|
||||
}
|
||||
|
||||
wantKey := "zeroed extended key"
|
||||
serializedKey := key.String()
|
||||
if serializedKey != wantKey {
|
||||
t.Errorf("String #%d (%s): mismatched serialized key "+
|
||||
"-- want %s, got %s", i, test.name, wantKey,
|
||||
"-- want %s, got %s", i, testName, wantKey,
|
||||
serializedKey)
|
||||
continue
|
||||
return false
|
||||
}
|
||||
|
||||
wantErr := hdkeychain.ErrNotPrivExtKey
|
||||
_, err = key.ECPrivKey()
|
||||
_, err := key.ECPrivKey()
|
||||
if !reflect.DeepEqual(err, wantErr) {
|
||||
t.Errorf("ECPrivKey #%d (%s): mismatched error: want "+
|
||||
"%v, got %v", i, test.name, wantErr, err)
|
||||
continue
|
||||
"%v, got %v", i, testName, wantErr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
wantErr = errors.New("pubkey string is empty")
|
||||
_, err = key.ECPubKey()
|
||||
if !reflect.DeepEqual(err, wantErr) {
|
||||
t.Errorf("ECPubKey #%d (%s): mismatched error: want "+
|
||||
"%v, got %v", i, test.name, wantErr, err)
|
||||
continue
|
||||
"%v, got %v", i, testName, wantErr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
wantAddr := "1HT7xU2Ngenf7D4yocz2SAcnNLW7rK8d4E"
|
||||
addr, err := key.Address(&btcnet.MainNetParams)
|
||||
if err != nil {
|
||||
t.Errorf("Addres s #%d (%s): unexpected error: %v", i,
|
||||
test.name, err)
|
||||
continue
|
||||
testName, err)
|
||||
return false
|
||||
}
|
||||
if addr.EncodeAddress() != wantAddr {
|
||||
t.Errorf("Address #%d (%s): mismatched address -- want "+
|
||||
"%s, got %s", i, test.name, wantAddr,
|
||||
"%s, got %s", i, testName, wantAddr,
|
||||
addr.EncodeAddress())
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
// Create new key from seed and get the neutered version.
|
||||
masterSeed, err := hex.DecodeString(test.master)
|
||||
if err != nil {
|
||||
t.Errorf("DecodeString #%d (%s): unexpected error: %v",
|
||||
i, test.name, err)
|
||||
continue
|
||||
}
|
||||
key, err := hdkeychain.NewMaster(masterSeed)
|
||||
if err != nil {
|
||||
t.Errorf("NewMaster #%d (%s): unexpected error when "+
|
||||
"creating new master key: %v", i, test.name,
|
||||
err)
|
||||
continue
|
||||
}
|
||||
neuteredKey, err := key.Neuter()
|
||||
if err != nil {
|
||||
t.Errorf("Neuter #%d (%s): unexpected error: %v", i,
|
||||
test.name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure both non-neutered and neutered keys are zeroed
|
||||
// properly.
|
||||
key.Zero()
|
||||
if !testZeroed(i, test.name+" from seed not neutered", key) {
|
||||
continue
|
||||
}
|
||||
neuteredKey.Zero()
|
||||
if !testZeroed(i, test.name+" from seed neutered", key) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Deserialize key and get the neutered version.
|
||||
key, err = hdkeychain.NewKeyFromString(test.extKey)
|
||||
if err != nil {
|
||||
t.Errorf("NewKeyFromString #%d (%s): unexpected "+
|
||||
"error: %v", i, test.name, err)
|
||||
continue
|
||||
}
|
||||
neuteredKey, err = key.Neuter()
|
||||
if err != nil {
|
||||
t.Errorf("Neuter #%d (%s): unexpected error: %v", i,
|
||||
test.name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure both non-neutered and neutered keys are zeroed
|
||||
// properly.
|
||||
key.Zero()
|
||||
if !testZeroed(i, test.name+" deserialized not neutered", key) {
|
||||
continue
|
||||
}
|
||||
neuteredKey.Zero()
|
||||
if !testZeroed(i, test.name+" deserialized neutered", key) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue