diff --git a/peer/peer.go b/peer/peer.go index b837a115..ee787e37 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -188,7 +188,7 @@ type MessageListeners struct { // not directly provide a callback. OnRead func(p *Peer, bytesRead int, msg wire.Message, err error) - // OnWrite is invoked when a peer receives a bitcoin message. It + // OnWrite is invoked when we write a bitcoin message to a peer. It // consists of the number of bytes written, the message, and whether or // not an error in the write occurred. This can be useful for // circumstances such as keeping track of server-wide byte counts. @@ -735,15 +735,15 @@ func (p *Peer) WantsHeaders() bool { return p.sendHeadersPreferred } -// pushVersionMsg sends a version message to the connected peer using the -// current state. -func (p *Peer) pushVersionMsg() error { +// localVersionMsg creates a version message that can be used to send to the +// remote peer. +func (p *Peer) localVersionMsg() (*wire.MsgVersion, error) { var blockNum int32 if p.cfg.NewestBlock != nil { var err error _, blockNum, err = p.cfg.NewestBlock() if err != nil { - return err + return nil, err } } @@ -775,7 +775,7 @@ func (p *Peer) pushVersionMsg() error { // recently seen nonces. nonce, err := wire.RandomUint64() if err != nil { - return err + return nil, err } sentNonces.Add(nonce) @@ -810,8 +810,7 @@ func (p *Peer) pushVersionMsg() error { // Advertise if inv messages for transactions are desired. msg.DisableRelayTx = p.cfg.DisableRelayTx - p.QueueMessage(msg, nil) - return nil + return msg, nil } // PushAddrMsg sends an addr message to the connected peer using the provided @@ -913,8 +912,8 @@ func (p *Peer) PushGetHeadersMsg(locator blockchain.BlockLocator, stopHash *wire p.prevGetHdrsMtx.Unlock() if isDuplicate { - log.Tracef("Filtering duplicate [getheaders] with begin "+ - "hash %v", beginHash) + log.Tracef("Filtering duplicate [getheaders] with begin hash %v", + beginHash) return nil } @@ -974,10 +973,10 @@ func (p *Peer) PushRejectMsg(command string, code wire.RejectCode, reason string <-doneChan } -// handleVersionMsg is invoked when a peer receives a version bitcoin message -// and is used to negotiate the protocol version details as well as kick start -// the communications. -func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error { +// handleRemoteVersionMsg is invoked when a version bitcoin message is received +// from the remote peer. It will return an error if the remote peer's version +// is not compatible with ours. +func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { // Detect self connections. if !allowSelfConns && sentNonces.Exists(msg.Nonce) { return errors.New("disconnecting peer connected to self") @@ -991,21 +990,9 @@ func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error { // disconnecting. reason := fmt.Sprintf("protocol version must be %d or greater", wire.MultipleAddressVersion) - p.PushRejectMsg(msg.Command(), wire.RejectObsolete, reason, - nil, true) - return errors.New(reason) - } - - // Limit to one version message per peer. - // No read lock is necessary because versionKnown is not written to in any - // other goroutine - if p.versionKnown { - // Send an reject message indicating the version message was - // incorrectly sent twice and wait for the message to be sent - // before disconnecting. - p.PushRejectMsg(msg.Command(), wire.RejectDuplicate, - "duplicate version message", nil, true) - return errors.New("only one version message per peer is allowed") + rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete, + reason) + return p.writeMessage(rejectMsg) } // Updating a bunch of stats. @@ -1030,27 +1017,6 @@ func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error { // Set the remote peer's user agent. p.userAgent = msg.UserAgent p.flagsMtx.Unlock() - - // Inbound connections. - if p.inbound { - // Set up a NetAddress for the peer to be used with AddrManager. - // We only do this inbound because outbound set this up - // at connection time and no point recomputing. - na, err := newNetAddress(p.conn.RemoteAddr(), p.services) - if err != nil { - return err - } - p.na = na - - // Send version. - err = p.pushVersionMsg() - if err != nil { - return err - } - } - - // Send verack. - p.QueueMessage(wire.NewMsgVerAck(), nil) return nil } @@ -1147,18 +1113,6 @@ func (p *Peer) writeMessage(msg wire.Message) error { if atomic.LoadInt32(&p.disconnect) != 0 { return nil } - if !p.VersionKnown() { - switch msg.(type) { - case *wire.MsgVersion: - // This is OK. - case *wire.MsgReject: - // This is OK. - default: - // Drop all messages other than version and reject if - // the handshake has not already been done. - return nil - } - } // Use closures to log expensive operations so they are only run when // the logging level requires it. @@ -1194,13 +1148,16 @@ func (p *Peer) writeMessage(msg wire.Message) error { return err } -// isAllowedByRegression returns whether or not the passed error is allowed by -// regression tests without disconnecting the peer. In particular, regression -// tests need to be allowed to send malformed messages without the peer being -// disconnected. -func (p *Peer) isAllowedByRegression(err error) bool { - // Don't allow the error if it's not specifically a malformed message - // error. +// isAllowedReadError returns whether or not the passed error is allowed without +// disconnecting the peer. In particular, regression tests need to be allowed +// to send malformed messages without the peer being disconnected. +func (p *Peer) isAllowedReadError(err error) bool { + // Only allow read errors in regression test mode. + if p.cfg.ChainParams.Net != wire.TestNet { + return false + } + + // Don't allow the error if it's not specifically a malformed message error. if _, ok := err.(*wire.MessageError); !ok { return false } @@ -1220,12 +1177,6 @@ func (p *Peer) isAllowedByRegression(err error) bool { return true } -// isRegTestNetwork returns whether or not the peer is running on the regression -// test network. -func (p *Peer) isRegTestNetwork() bool { - return p.cfg.ChainParams.Net == wire.TestNet -} - // shouldHandleReadError returns whether or not the passed error, which is // expected to have come from reading from the remote peer in the inHandler, // should be logged and responded to with a reject message. @@ -1437,14 +1388,8 @@ func (p *Peer) inHandler() { // Peers must complete the initial version negotiation within a shorter // timeframe than a general idle timeout. The timer is then reset below // to idleTimeout for all future messages. - idleTimer := time.AfterFunc(negotiateTimeout, func() { - if p.VersionKnown() { - log.Warnf("Peer %s no answer for %s -- disconnecting", - p, idleTimeout) - } else { - log.Debugf("Peer %s no valid version message for %s -- "+ - "disconnecting", p, negotiateTimeout) - } + idleTimer := time.AfterFunc(idleTimeout, func() { + log.Warnf("Peer %s no answer for %s -- disconnecting", p, idleTimeout) p.Disconnect() }) @@ -1456,13 +1401,11 @@ out: rmsg, buf, err := p.readMessage() idleTimer.Stop() if err != nil { - // In order to allow regression tests with malformed - // messages, don't disconnect the peer when we're in - // regression test mode and the error is one of the - // allowed errors. - if p.isRegTestNetwork() && p.isAllowedByRegression(err) { - log.Errorf("Allowed regression test error "+ - "from %s: %v", p, err) + // In order to allow regression tests with malformed messages, don't + // disconnect the peer when we're in regression test mode and the + // error is one of the allowed errors. + if p.isAllowedReadError(err) { + log.Errorf("Allowed test error from %s: %v", p, err) idleTimer.Reset(idleTimeout) continue } @@ -1471,70 +1414,40 @@ out: // local peer is not forcibly disconnecting and the // remote peer has not disconnected. if p.shouldHandleReadError(err) { - errMsg := fmt.Sprintf("Can't read message "+ - "from %s: %v", p, err) + errMsg := fmt.Sprintf("Can't read message from %s: %v", p, err) log.Errorf(errMsg) - // Push a reject message for the malformed - // message and wait for the message to be sent - // before disconnecting. + // Push a reject message for the malformed message and wait for + // the message to be sent before disconnecting. // - // NOTE: Ideally this would include the command - // in the header if at least that much of the - // message was valid, but that is not currently - // exposed by wire, so just used malformed for - // the command. - p.PushRejectMsg("malformed", - wire.RejectMalformed, errMsg, nil, true) + // NOTE: Ideally this would include the command in the header if + // at least that much of the message was valid, but that is not + // currently exposed by wire, so just used malformed for the + // command. + p.PushRejectMsg("malformed", wire.RejectMalformed, errMsg, nil, + true) } break out } atomic.StoreInt64(&p.lastRecv, time.Now().Unix()) p.stallControl <- stallControlMsg{sccReceiveMessage, rmsg} - // Ensure version message comes first. - if vmsg, ok := rmsg.(*wire.MsgVersion); !ok && !p.VersionKnown() { - errStr := "A version message must precede all others" - log.Errorf(errStr) - - // Push a reject message and wait for the message to be - // sent before disconnecting. - p.PushRejectMsg(vmsg.Command(), wire.RejectMalformed, - errStr, nil, true) - break out - } - // Handle each supported message type. p.stallControl <- stallControlMsg{sccHandlerStart, rmsg} switch msg := rmsg.(type) { case *wire.MsgVersion: - err := p.handleVersionMsg(msg) - if err != nil { - log.Debugf("New peer %v - error negotiating protocol: %v", - p, err) - p.Disconnect() - break out - } - if p.cfg.Listeners.OnVersion != nil { - p.cfg.Listeners.OnVersion(p, msg) - } + + p.PushRejectMsg(msg.Command(), wire.RejectDuplicate, + "duplicate version message", nil, true) + break out case *wire.MsgVerAck: - p.flagsMtx.Lock() - versionSent := p.versionSent - p.flagsMtx.Unlock() - if !versionSent { - log.Infof("Received 'verack' from peer %v "+ - "before version was sent -- "+ - "disconnecting", p) - break out - } - // No read lock is necessary because verAckReceived is - // not written to in any other goroutine. + // No read lock is necessary because verAckReceived is not written + // to in any other goroutine. if p.verAckReceived { - log.Infof("Already received 'verack' from "+ - "peer %v -- disconnecting", p) + log.Infof("Already received 'verack' from peer %v -- "+ + "disconnecting", p) break out } p.flagsMtx.Lock() @@ -1830,13 +1743,6 @@ out: select { case msg := <-p.sendQueue: switch m := msg.msg.(type) { - case *wire.MsgVersion: - // Set the flag which indicates the version has - // been sent. - p.flagsMtx.Lock() - p.versionSent = true - p.flagsMtx.Unlock() - case *wire.MsgPing: // Only expects a pong message in later protocol // versions. Also set up statistics. @@ -1849,8 +1755,7 @@ out: } p.stallControl <- stallControlMsg{sccSendMessage, msg.msg} - err := p.writeMessage(msg.msg) - if err != nil { + if err := p.writeMessage(msg.msg); err != nil { p.Disconnect() if p.shouldLogWriteError(err) { log.Errorf("Failed to send message to "+ @@ -1956,15 +1861,28 @@ func (p *Peer) Connect(conn net.Conn) { return } - if p.inbound { - p.addr = conn.RemoteAddr().String() - } p.conn = conn p.timeConnected = time.Now() + if p.inbound { + p.addr = p.conn.RemoteAddr().String() + + // Set up a NetAddress for the peer to be used with AddrManager. We + // only do this inbound because outbound set this up at connection time + // and no point recomputing. + na, err := newNetAddress(p.conn.RemoteAddr(), p.services) + if err != nil { + log.Errorf("Cannot create remote net address: %v", err) + p.Disconnect() + return + } + p.na = na + } + go func() { if err := p.start(); err != nil { - log.Errorf("Cannot start peer %v: %v", p, err) + log.Warnf("Cannot start peer %v: %v", p, err) + p.Disconnect() } }() } @@ -1992,26 +1910,38 @@ func (p *Peer) Disconnect() { close(p.quit) } -// Start begins processing input and output messages. It also sends the initial -// version message for outbound connections to start the negotiation process. +// start begins processing input and output messages. func (p *Peer) start() error { log.Tracef("Starting peer %s", p) - // Send an initial version message if this is an outbound connection. - if !p.inbound { - if err := p.pushVersionMsg(); err != nil { - log.Errorf("Can't send outbound version message %v", err) - p.Disconnect() + negotiateErr := make(chan error) + go func() { + if p.inbound { + negotiateErr <- p.negotiateInboundProtocol() + } else { + negotiateErr <- p.negotiateOutboundProtocol() + } + }() + + // Negotiate the protocol within the specified negotiateTimeout. + select { + case err := <-negotiateErr: + if err != nil { return err } + case <-time.After(negotiateTimeout): + return errors.New("protocol negotiation timeout") } - // Start processing input and output. + // The protocol has been negotiated successfully so start processing input + // and output messages. go p.stallHandler() go p.inHandler() go p.queueHandler() go p.outHandler() + // Send our verack message now that the IO processing machinery has started. + p.QueueMessage(wire.NewMsgVerAck(), nil) return nil } @@ -2023,6 +1953,79 @@ func (p *Peer) WaitForDisconnect() { <-p.quit } +// readRemoteVersionMsg waits for the next message to arrive from the remote +// peer. If the next message is not a version message or the version is not +// acceptable then return an error. +func (p *Peer) readRemoteVersionMsg() error { + + // Read their version message. + msg, _, err := p.readMessage() + if err != nil { + return err + } + + remoteVerMsg, ok := msg.(*wire.MsgVersion) + if !ok { + errStr := "A version message must precede all others" + log.Errorf(errStr) + + rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed, + errStr) + return p.writeMessage(rejectMsg) + } + + if err := p.handleRemoteVersionMsg(remoteVerMsg); err != nil { + return err + } + + if p.cfg.Listeners.OnVersion != nil { + p.cfg.Listeners.OnVersion(p, remoteVerMsg) + } + return nil +} + +// writeLocalVersionMsg writes our version message to the remote peer. +func (p *Peer) writeLocalVersionMsg() error { + + localVerMsg, err := p.localVersionMsg() + if err != nil { + return err + } + + if err := p.writeMessage(localVerMsg); err != nil { + return err + } + + p.flagsMtx.Lock() + p.versionSent = true + p.flagsMtx.Unlock() + return nil +} + +// negotiateInboundProtocol waits to receive a version message from the peer +// then sends our version message. If the events do not occur in that order then +// it returns an error. +func (p *Peer) negotiateInboundProtocol() error { + + if err := p.readRemoteVersionMsg(); err != nil { + return err + } + + return p.writeLocalVersionMsg() +} + +// negotiateOutboundProtocol sends our version message then waits to receive a +// version message from the peer. If the events do not occur in that order then +// it returns an error. +func (p *Peer) negotiateOutboundProtocol() error { + + if err := p.writeLocalVersionMsg(); err != nil { + return err + } + + return p.readRemoteVersionMsg() +} + // newPeerBase returns a new base bitcoin peer based on the inbound flag. This // is used by the NewInboundPeer and NewOutboundPeer functions to perform base // setup needed by both types of peers. diff --git a/peer/peer_test.go b/peer/peer_test.go index f2475b9c..c05d6898 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -204,12 +204,15 @@ func testPeer(t *testing.T, p *peer.Peer, s peerStats) { // TestPeerConnection tests connection between inbound and outbound peers. func TestPeerConnection(t *testing.T) { - verack := make(chan struct{}, 1) + verack := make(chan struct{}) peerCfg := &peer.Config{ Listeners: peer.MessageListeners{ - OnWrite: func(p *peer.Peer, bytesWritten int, msg wire.Message, err error) { - switch msg.(type) { - case *wire.MsgVerAck: + OnVerAck: func(p *peer.Peer, msg *wire.MsgVerAck) { + verack <- struct{}{} + }, + OnWrite: func(p *peer.Peer, bytesWritten int, msg wire.Message, + err error) { + if _, ok := msg.(*wire.MsgVerAck); ok { verack <- struct{}{} } }, @@ -253,10 +256,10 @@ func TestPeerConnection(t *testing.T) { } outPeer.Connect(outConn) - for i := 0; i < 2; i++ { + for i := 0; i < 4; i++ { select { case <-verack: - case <-time.After(time.Second * 1): + case <-time.After(time.Second): return nil, nil, errors.New("verack timeout") } } @@ -279,10 +282,10 @@ func TestPeerConnection(t *testing.T) { } outPeer.Connect(outConn) - for i := 0; i < 2; i++ { + for i := 0; i < 4; i++ { select { case <-verack: - case <-time.After(time.Second * 1): + case <-time.After(time.Second): return nil, nil, errors.New("verack timeout") } } @@ -294,7 +297,7 @@ func TestPeerConnection(t *testing.T) { for i, test := range tests { inPeer, outPeer, err := test.setup() if err != nil { - t.Errorf("TestPeerConnection setup #%d: unexpected err %v\n", i, err) + t.Errorf("TestPeerConnection setup #%d: unexpected err %v", i, err) return } testPeer(t, inPeer, wantStats) @@ -302,6 +305,8 @@ func TestPeerConnection(t *testing.T) { inPeer.Disconnect() outPeer.Disconnect() + inPeer.WaitForDisconnect() + outPeer.WaitForDisconnect() } } @@ -547,6 +552,7 @@ func TestOutboundPeer(t *testing.T) { select { case <-disconnected: + close(disconnected) case <-time.After(time.Second): t.Fatal("Peer did not automatically disconnect.") } @@ -580,6 +586,7 @@ func TestOutboundPeer(t *testing.T) { } return hash, 234439, nil } + peerCfg.NewestBlock = newestBlock r1, w1 := io.Pipe() c1 := &conn{raddr: "10.0.0.1:8333", Writer: w1, Reader: r1} @@ -638,7 +645,8 @@ func TestOutboundPeer(t *testing.T) { t.Errorf("PushGetHeadersMsg: unexpected err %v\n", err) return } - p2.PushRejectMsg("block", wire.RejectMalformed, "malformed", nil, true) + + p2.PushRejectMsg("block", wire.RejectMalformed, "malformed", nil, false) p2.PushRejectMsg("block", wire.RejectInvalid, "invalid", nil, false) // Test Queue Messages