diff --git a/chain/pruned_block_dispatcher.go b/chain/pruned_block_dispatcher.go index aec8d02..4a80a5f 100644 --- a/chain/pruned_block_dispatcher.go +++ b/chain/pruned_block_dispatcher.go @@ -417,10 +417,12 @@ func filterNodeAddrs(nodeAddrs []btcjson.GetNodeAddressesResult) []string { } // satisfiesRequiredServices determines whether the services signaled by a peer -// satisfy our requirements for retrieving pruned blocks from them. +// satisfy our requirements for retrieving pruned blocks from them. We need the +// full chain, and witness data as well. Note that we ignore the limited +// (pruned bit) as nodes can have the full data and set that as well. Pure +// pruned nodes won't set the network bit. func satisfiesRequiredServices(services wire.ServiceFlag) bool { - return services&requiredServices == requiredServices && - services&prunedNodeService != prunedNodeService + return services&requiredServices == requiredServices } // newQueryPeer creates a new peer instance configured to relay any received diff --git a/chain/pruned_block_dispatcher_test.go b/chain/pruned_block_dispatcher_test.go index 8b319d7..50af406 100644 --- a/chain/pruned_block_dispatcher_test.go +++ b/chain/pruned_block_dispatcher_test.go @@ -620,3 +620,40 @@ func TestPrunedBlockDispatcherInvalidBlock(t *testing.T) { h.assertPeerQueried() h.assertPeerReplied(blockChan, errChan, true) } + +func TestSatisfiesRequiredServices(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + services wire.ServiceFlag + ok bool + }{ + { + name: "full node, segwit", + services: wire.SFNodeWitness | wire.SFNodeNetwork, + ok: true, + }, + { + name: "full node segwit, signals limited", + services: wire.SFNodeWitness | wire.SFNodeNetwork | prunedNodeService, + ok: true, + }, + { + name: "full node, no segwit", + services: wire.SFNodeNetwork, + ok: false, + }, + { + name: "segwit, pure pruned", + services: wire.SFNodeWitness | prunedNodeService, + ok: false, + }, + } + for _, testCase := range testCases { + ok := satisfiesRequiredServices(testCase.services) + require.Equal( + t, testCase.ok, ok, fmt.Sprintf("test case: %v", testCase.name), + ) + } +}