diff --git a/peer/peer.go b/peer/peer.go index 61bc619f..11306ace 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -1377,20 +1377,12 @@ out: break out case *wire.MsgVerAck: - - // No read lock is necessary because verAckReceived is not written - // to in any other goroutine. - if p.verAckReceived { - log.Infof("Already received 'verack' from peer %v -- "+ - "disconnecting", p) - break out - } - p.flagsMtx.Lock() - p.verAckReceived = true - p.flagsMtx.Unlock() - if p.cfg.Listeners.OnVerAck != nil { - p.cfg.Listeners.OnVerAck(p, msg) - } + // Limit to one verack message per peer. + p.PushRejectMsg( + msg.Command(), wire.RejectDuplicate, + "duplicate verack message", nil, true, + ) + break out case *wire.MsgGetAddr: if p.cfg.Listeners.OnGetAddr != nil { @@ -1974,6 +1966,40 @@ func (p *Peer) readRemoteVersionMsg() error { return nil } +// readRemoteVerAckMsg waits for the next message to arrive from the remote +// peer. If this message is not a verack message, then an error is returned. +// This method is to be used as part of the version negotiation upon a new +// connection. +func (p *Peer) readRemoteVerAckMsg() error { + // Read the next message from the wire. + remoteMsg, _, err := p.readMessage(wire.LatestEncoding) + if err != nil { + return err + } + + // It should be a verack message, otherwise send a reject message to the + // peer explaining why. + msg, ok := remoteMsg.(*wire.MsgVerAck) + if !ok { + reason := "a verack message must follow version" + rejectMsg := wire.NewMsgReject( + msg.Command(), wire.RejectMalformed, reason, + ) + _ = p.writeMessage(rejectMsg, wire.LatestEncoding) + return errors.New(reason) + } + + p.flagsMtx.Lock() + p.verAckReceived = true + p.flagsMtx.Unlock() + + if p.cfg.Listeners.OnVerAck != nil { + p.cfg.Listeners.OnVerAck(p, msg) + } + + return nil +} + // localVersionMsg creates a version message that can be used to send to the // remote peer. func (p *Peer) localVersionMsg() (*wire.MsgVersion, error) { @@ -2046,26 +2072,53 @@ func (p *Peer) writeLocalVersionMsg() error { return p.writeMessage(localVerMsg, wire.LatestEncoding) } -// negotiateInboundProtocol waits to receive a version message from the peer -// then sends our version message. If the events do not occur in that order then -// it returns an error. +// negotiateInboundProtocol performs the negotiation protocol for an inbound +// peer. The events should occur in the following order, otherwise an error is +// returned: +// +// 1. Remote peer sends their version. +// 2. We send our version. +// 3. We send our verack. +// 4. Remote peer sends their verack. func (p *Peer) negotiateInboundProtocol() error { if err := p.readRemoteVersionMsg(); err != nil { return err } - return p.writeLocalVersionMsg() + if err := p.writeLocalVersionMsg(); err != nil { + return err + } + + err := p.writeMessage(wire.NewMsgVerAck(), wire.LatestEncoding) + if err != nil { + return err + } + + return p.readRemoteVerAckMsg() } -// negotiateOutboundProtocol sends our version message then waits to receive a -// version message from the peer. If the events do not occur in that order then -// it returns an error. +// negotiateOutoundProtocol performs the negotiation protocol for an outbound +// peer. The events should occur in the following order, otherwise an error is +// returned: +// +// 1. We send our version. +// 2. Remote peer sends their version. +// 3. Remote peer sends their verack. +// 4. We send our verack. func (p *Peer) negotiateOutboundProtocol() error { if err := p.writeLocalVersionMsg(); err != nil { return err } - return p.readRemoteVersionMsg() + if err := p.readRemoteVersionMsg(); err != nil { + return err + } + + if err := p.readRemoteVerAckMsg(); err != nil { + return err + } + + return p.writeMessage(wire.NewMsgVerAck(), wire.LatestEncoding) } // start begins processing input and output messages. @@ -2102,8 +2155,6 @@ func (p *Peer) start() error { go p.outHandler() go p.pingHandler() - // Send our verack message now that the IO processing machinery has started. - p.QueueMessage(wire.NewMsgVerAck(), nil) return nil } diff --git a/wire/message.go b/wire/message.go index 4f03cf56..e9376477 100644 --- a/wire/message.go +++ b/wire/message.go @@ -321,9 +321,13 @@ func WriteMessageWithEncodingN(w io.Writer, msg Message, pver uint32, return totalBytes, err } - // Write payload. - n, err = w.Write(payload) - totalBytes += n + // Only write the payload if there is one, e.g., verack messages don't + // have one. + if len(payload) > 0 { + n, err = w.Write(payload) + totalBytes += n + } + return totalBytes, err }