From bc36cf51c6aec42ddbfc4585107a87ca26ba9815 Mon Sep 17 00:00:00 2001 From: Jim Posen Date: Mon, 14 Aug 2017 14:58:57 -0700 Subject: [PATCH] peer: Don't send unsupported reject to old peers. --- peer/peer.go | 24 +++++----- peer/peer_test.go | 110 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 122 insertions(+), 12 deletions(-) diff --git a/peer/peer.go b/peer/peer.go index 9c77ff1b..dc6f8b4a 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -29,6 +29,10 @@ const ( // MaxProtocolVersion is the max protocol version the peer supports. MaxProtocolVersion = wire.FeeFilterVersion + // minAcceptableProtocolVersion is the lowest protocol version that a + // connected peer may support. + minAcceptableProtocolVersion = wire.MultipleAddressVersion + // outputBufferSize is the number of elements the output channels use. outputBufferSize = 50 @@ -1012,15 +1016,14 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { // Notify and disconnect clients that have a protocol version that is // too old. - if msg.ProtocolVersion < int32(wire.MultipleAddressVersion) { - // Send a reject message indicating the protocol version is - // obsolete and wait for the message to be sent before - // disconnecting. + // + // NOTE: If minAcceptableProtocolVersion is raised to be higher than + // wire.RejectVersion, this should send a reject packet before + // disconnecting. + if uint32(msg.ProtocolVersion) < minAcceptableProtocolVersion { reason := fmt.Sprintf("protocol version must be %d or greater", - wire.MultipleAddressVersion) - rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete, - reason) - return p.writeMessage(rejectMsg, wire.LatestEncoding) + minAcceptableProtocolVersion) + return errors.New(reason) } // Updating a bunch of stats including block based stats, and the @@ -1412,9 +1415,8 @@ cleanup: // inHandler handles all incoming messages for the peer. It must be run as a // goroutine. func (p *Peer) inHandler() { - // Peers must complete the initial version negotiation within a shorter - // timeframe than a general idle timeout. The timer is then reset below - // to idleTimeout for all future messages. + // The timer is stopped when a new message is received and reset after it + // is processed. idleTimer := time.AfterFunc(idleTimeout, func() { log.Warnf("Peer %s no answer for %s -- disconnecting", p, idleTimeout) p.Disconnect() diff --git a/peer/peer_test.go b/peer/peer_test.go index 2db94404..b6dcc83e 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -58,7 +58,10 @@ func (c conn) RemoteAddr() net.Addr { // Close handles closing the connection. func (c conn) Close() error { - return nil + if c.Closer == nil { + return nil + } + return c.Closer.Close() } func (c conn) SetDeadline(t time.Time) error { return nil } @@ -80,9 +83,11 @@ func pipe(c1, c2 *conn) (*conn, *conn) { r2, w2 := io.Pipe() c1.Writer = w1 + c1.Closer = w1 c2.Reader = r1 c1.Reader = r2 c2.Writer = w2 + c2.Closer = w2 return c1, c2 } @@ -673,6 +678,109 @@ func TestOutboundPeer(t *testing.T) { p2.Disconnect() } +// Tests that the node disconnects from peers with an unsupported protocol +// version. +func TestUnsupportedVersionPeer(t *testing.T) { + peerCfg := &peer.Config{ + UserAgentName: "peer", + UserAgentVersion: "1.0", + UserAgentComments: []string{"comment"}, + ChainParams: &chaincfg.MainNetParams, + Services: 0, + } + + localNA := wire.NewNetAddressIPPort( + net.ParseIP("10.0.0.1"), + uint16(8333), + wire.SFNodeNetwork, + ) + remoteNA := wire.NewNetAddressIPPort( + net.ParseIP("10.0.0.2"), + uint16(8333), + wire.SFNodeNetwork, + ) + localConn, remoteConn := pipe( + &conn{laddr: "10.0.0.1:8333", raddr: "10.0.0.2:8333"}, + &conn{laddr: "10.0.0.2:8333", raddr: "10.0.0.1:8333"}, + ) + + p, err := peer.NewOutboundPeer(peerCfg, "10.0.0.1:8333") + if err != nil { + t.Fatalf("NewOutboundPeer: unexpected err - %v\n", err) + } + p.AssociateConnection(localConn) + + // Read outbound messages to peer into a channel + outboundMessages := make(chan wire.Message) + go func() { + for { + _, msg, _, err := wire.ReadMessageN( + remoteConn, + p.ProtocolVersion(), + peerCfg.ChainParams.Net, + ) + if err == io.EOF { + close(outboundMessages) + return + } + if err != nil { + t.Errorf("Error reading message from local node: %v\n", err) + return + } + + outboundMessages <- msg + } + }() + + // Read version message sent to remote peer + select { + case msg := <-outboundMessages: + if _, ok := msg.(*wire.MsgVersion); !ok { + t.Fatalf("Expected version message, got [%s]", msg.Command()) + } + case <-time.After(time.Second): + t.Fatal("Peer did not send version message") + } + + // Remote peer writes version message advertising invalid protocol version 1 + invalidVersionMsg := wire.NewMsgVersion(remoteNA, localNA, 0, 0) + invalidVersionMsg.ProtocolVersion = 1 + + _, err = wire.WriteMessageN( + remoteConn.Writer, + invalidVersionMsg, + uint32(invalidVersionMsg.ProtocolVersion), + peerCfg.ChainParams.Net, + ) + if err != nil { + t.Fatalf("wire.WriteMessageN: unexpected err - %v\n", err) + } + + // Expect peer to disconnect automatically + disconnected := make(chan struct{}) + go func() { + p.WaitForDisconnect() + disconnected <- struct{}{} + }() + + select { + case <-disconnected: + close(disconnected) + case <-time.After(time.Second): + t.Fatal("Peer did not automatically disconnect") + } + + // Expect no further outbound messages from peer + select { + case msg, chanOpen := <-outboundMessages: + if chanOpen { + t.Fatalf("Expected no further messages, received [%s]", msg.Command()) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for remote reader to close") + } +} + func init() { // Allow self connection when running the tests. peer.TstAllowSelfConns()