From bcd50ea838c984bb1e4756f5b1ed0021b52e9dd2 Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Wed, 6 Feb 2019 13:32:52 -0800 Subject: [PATCH] addrmgr: store address' supported service bits In this commit, we update the serialized version of the AddressManager to also store each address' service bits. This allows an address' service bits to be properly set when read from disk, which allows external callers to reliably filter out addresses which do not support their required services. Since the service bits were not previously stored, the serialization version needed to be bumped. A test has been added to test the behavior of upgrading from version 1 to 2. --- addrmgr/addrmanager.go | 45 ++++++++++++++--- addrmgr/addrmanager_internal_test.go | 75 ++++++++++++++++++++++++++++ addrmgr/addrmanager_test.go | 4 +- 3 files changed, 114 insertions(+), 10 deletions(-) diff --git a/addrmgr/addrmanager.go b/addrmgr/addrmanager.go index 89569aa0..a8a8fb33 100644 --- a/addrmgr/addrmanager.go +++ b/addrmgr/addrmanager.go @@ -46,6 +46,7 @@ type AddrManager struct { nNew int lamtx sync.Mutex localAddresses map[string]*localAddress + version int } type serializedKnownAddress struct { @@ -55,6 +56,8 @@ type serializedKnownAddress struct { TimeStamp int64 LastAttempt int64 LastSuccess int64 + Services wire.ServiceFlag + SrcServices wire.ServiceFlag // no refcount or tried, that is available from context. } @@ -155,7 +158,7 @@ const ( getAddrPercent = 23 // serialisationVersion is the current version of the on-disk format. - serialisationVersion = 1 + serialisationVersion = 2 ) // updateAddress is a helper function to either update an address already known @@ -362,7 +365,7 @@ func (a *AddrManager) savePeers() { // First we make a serialisable datastructure so we can encode it to // json. sam := new(serializedAddrManager) - sam.Version = serialisationVersion + sam.Version = a.version copy(sam.Key[:], a.key[:]) sam.Addresses = make([]*serializedKnownAddress, len(a.addrIndex)) @@ -375,6 +378,10 @@ func (a *AddrManager) savePeers() { ska.Attempts = v.attempts ska.LastAttempt = v.lastattempt.Unix() ska.LastSuccess = v.lastsuccess.Unix() + if a.version > 1 { + ska.Services = v.na.Services + ska.SrcServices = v.srcAddr.Services + } // Tried and refs are implicit in the rest of the structure // and will be worked out from context on unserialisation. sam.Addresses[i] = ska @@ -451,24 +458,43 @@ func (a *AddrManager) deserializePeers(filePath string) error { return fmt.Errorf("error reading %s: %v", filePath, err) } - if sam.Version != serialisationVersion { + // Since decoding JSON is backwards compatible (i.e., only decodes + // fields it understands), we'll only return an error upon seeing a + // version past our latest supported version. + if sam.Version > serialisationVersion { return fmt.Errorf("unknown version %v in serialized "+ "addrmanager", sam.Version) } + copy(a.key[:], sam.Key[:]) for _, v := range sam.Addresses { ka := new(KnownAddress) - ka.na, err = a.DeserializeNetAddress(v.Addr) + + // The first version of the serialized address manager was not + // aware of the service bits associated with this address, so + // we'll assign a default of SFNodeNetwork to it. + if sam.Version == 1 { + v.Services = wire.SFNodeNetwork + } + ka.na, err = a.DeserializeNetAddress(v.Addr, v.Services) if err != nil { return fmt.Errorf("failed to deserialize netaddress "+ "%s: %v", v.Addr, err) } - ka.srcAddr, err = a.DeserializeNetAddress(v.Src) + + // The first version of the serialized address manager was not + // aware of the service bits associated with the source address, + // so we'll assign a default of SFNodeNetwork to it. + if sam.Version == 1 { + v.SrcServices = wire.SFNodeNetwork + } + ka.srcAddr, err = a.DeserializeNetAddress(v.Src, v.SrcServices) if err != nil { return fmt.Errorf("failed to deserialize netaddress "+ "%s: %v", v.Src, err) } + ka.attempts = v.Attempts ka.lastattempt = time.Unix(v.LastAttempt, 0) ka.lastsuccess = time.Unix(v.LastSuccess, 0) @@ -520,8 +546,10 @@ func (a *AddrManager) deserializePeers(filePath string) error { return nil } -// DeserializeNetAddress converts a given address string to a *wire.NetAddress -func (a *AddrManager) DeserializeNetAddress(addr string) (*wire.NetAddress, error) { +// DeserializeNetAddress converts a given address string to a *wire.NetAddress. +func (a *AddrManager) DeserializeNetAddress(addr string, + services wire.ServiceFlag) (*wire.NetAddress, error) { + host, portStr, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -531,7 +559,7 @@ func (a *AddrManager) DeserializeNetAddress(addr string) (*wire.NetAddress, erro return nil, err } - return a.HostToNetAddress(host, uint16(port), wire.SFNodeNetwork) + return a.HostToNetAddress(host, uint16(port), services) } // Start begins the core address handler which manages a pool of known @@ -1116,6 +1144,7 @@ func New(dataDir string, lookupFunc func(string) ([]net.IP, error)) *AddrManager rand: rand.New(rand.NewSource(time.Now().UnixNano())), quit: make(chan struct{}), localAddresses: make(map[string]*localAddress), + version: serialisationVersion, } am.reset() return &am diff --git a/addrmgr/addrmanager_internal_test.go b/addrmgr/addrmanager_internal_test.go index 82036d37..1c19dceb 100644 --- a/addrmgr/addrmanager_internal_test.go +++ b/addrmgr/addrmanager_internal_test.go @@ -117,3 +117,78 @@ func TestAddrManagerSerialization(t *testing.T) { addrMgr.loadPeers() assertAddrs(t, addrMgr, expectedAddrs) } + +// TestAddrManagerV1ToV2 ensures that we can properly upgrade the serialized +// version of the address manager from v1 to v2. +func TestAddrManagerV1ToV2(t *testing.T) { + t.Parallel() + + // We'll start by creating our address manager backed by a temporary + // directory. + tempDir, err := ioutil.TempDir("", "addrmgr") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + addrMgr := New(tempDir, nil) + + // As we're interested in testing the upgrade path from v1 to v2, we'll + // override the manager's current version. + addrMgr.version = 1 + + // We'll be adding 5 random addresses to the manager. Since this is v1, + // each addresses' services will not be stored. + const numAddrs = 5 + + expectedAddrs := make(map[string]*wire.NetAddress, numAddrs) + for i := 0; i < numAddrs; i++ { + addr := randAddr(t) + expectedAddrs[NetAddressKey(addr)] = addr + addrMgr.AddAddress(addr, randAddr(t)) + } + + // Then, we'll persist these addresses to disk and restart the address + // manager - overriding its version back to v1. + addrMgr.savePeers() + addrMgr = New(tempDir, nil) + addrMgr.version = 1 + + // When we read all of the addresses back from disk, we should expect to + // find all of them, but their services will be set to a default of + // SFNodeNetwork since they were not previously stored. After ensuring + // that this default is set, we'll override each addresses' services + // with the original value from when they were created. + addrMgr.loadPeers() + addrs := addrMgr.getAddresses() + if len(addrs) != len(expectedAddrs) { + t.Fatalf("expected to find %d adddresses, found %d", + len(expectedAddrs), len(addrs)) + } + for _, addr := range addrs { + addrStr := NetAddressKey(addr) + expectedAddr, ok := expectedAddrs[addrStr] + if !ok { + t.Fatalf("expected to find address %v", addrStr) + } + + if addr.Services != wire.SFNodeNetwork { + t.Fatalf("expected address services to be %v, got %v", + wire.SFNodeNetwork, addr.Services) + } + + addrMgr.SetServices(addr, expectedAddr.Services) + } + + // We'll also bump up the manager's version to v2, which should signal + // that it should include the address services when persisting its + // state. + addrMgr.version = 2 + addrMgr.savePeers() + + // Finally, we'll recreate the manager and ensure that the services were + // persisted correctly. + addrMgr = New(tempDir, nil) + addrMgr.loadPeers() + assertAddrs(t, addrMgr, expectedAddrs) +} diff --git a/addrmgr/addrmanager_test.go b/addrmgr/addrmanager_test.go index fcfe845f..676913e2 100644 --- a/addrmgr/addrmanager_test.go +++ b/addrmgr/addrmanager_test.go @@ -262,7 +262,7 @@ func TestNeedMoreAddresses(t *testing.T) { var err error for i := 0; i < addrsToAdd; i++ { s := fmt.Sprintf("%d.%d.173.147:8333", i/128+60, i%128+60) - addrs[i], err = n.DeserializeNetAddress(s) + addrs[i], err = n.DeserializeNetAddress(s, wire.SFNodeNetwork) if err != nil { t.Errorf("Failed to turn %s into an address: %v", s, err) } @@ -290,7 +290,7 @@ func TestGood(t *testing.T) { var err error for i := 0; i < addrsToAdd; i++ { s := fmt.Sprintf("%d.173.147.%d:8333", i/64+60, i%64+60) - addrs[i], err = n.DeserializeNetAddress(s) + addrs[i], err = n.DeserializeNetAddress(s, wire.SFNodeNetwork) if err != nil { t.Errorf("Failed to turn %s into an address: %v", s, err) }