Merge pull request #1462 from wpaulino/verack-version-handshake

peer: include verack as part of version handshake
This commit is contained in:
Olaoluwa Osuntokun 2019-08-23 17:37:49 -07:00 committed by GitHub
commit 130ea5bddd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 27 deletions

View file

@ -1377,20 +1377,12 @@ out:
break out
case *wire.MsgVerAck:
// No read lock is necessary because verAckReceived is not written
// to in any other goroutine.
if p.verAckReceived {
log.Infof("Already received 'verack' from peer %v -- "+
"disconnecting", p)
break out
}
p.flagsMtx.Lock()
p.verAckReceived = true
p.flagsMtx.Unlock()
if p.cfg.Listeners.OnVerAck != nil {
p.cfg.Listeners.OnVerAck(p, msg)
}
// Limit to one verack message per peer.
p.PushRejectMsg(
msg.Command(), wire.RejectDuplicate,
"duplicate verack message", nil, true,
)
break out
case *wire.MsgGetAddr:
if p.cfg.Listeners.OnGetAddr != nil {
@ -1974,6 +1966,40 @@ func (p *Peer) readRemoteVersionMsg() error {
return nil
}
// readRemoteVerAckMsg waits for the next message to arrive from the remote
// peer. If this message is not a verack message, then an error is returned.
// This method is to be used as part of the version negotiation upon a new
// connection.
func (p *Peer) readRemoteVerAckMsg() error {
// Read the next message from the wire.
remoteMsg, _, err := p.readMessage(wire.LatestEncoding)
if err != nil {
return err
}
// It should be a verack message, otherwise send a reject message to the
// peer explaining why.
msg, ok := remoteMsg.(*wire.MsgVerAck)
if !ok {
reason := "a verack message must follow version"
rejectMsg := wire.NewMsgReject(
msg.Command(), wire.RejectMalformed, reason,
)
_ = p.writeMessage(rejectMsg, wire.LatestEncoding)
return errors.New(reason)
}
p.flagsMtx.Lock()
p.verAckReceived = true
p.flagsMtx.Unlock()
if p.cfg.Listeners.OnVerAck != nil {
p.cfg.Listeners.OnVerAck(p, msg)
}
return nil
}
// localVersionMsg creates a version message that can be used to send to the
// remote peer.
func (p *Peer) localVersionMsg() (*wire.MsgVersion, error) {
@ -2046,26 +2072,53 @@ func (p *Peer) writeLocalVersionMsg() error {
return p.writeMessage(localVerMsg, wire.LatestEncoding)
}
// negotiateInboundProtocol waits to receive a version message from the peer
// then sends our version message. If the events do not occur in that order then
// it returns an error.
// negotiateInboundProtocol performs the negotiation protocol for an inbound
// peer. The events should occur in the following order, otherwise an error is
// returned:
//
// 1. Remote peer sends their version.
// 2. We send our version.
// 3. We send our verack.
// 4. Remote peer sends their verack.
func (p *Peer) negotiateInboundProtocol() error {
if err := p.readRemoteVersionMsg(); err != nil {
return err
}
return p.writeLocalVersionMsg()
if err := p.writeLocalVersionMsg(); err != nil {
return err
}
err := p.writeMessage(wire.NewMsgVerAck(), wire.LatestEncoding)
if err != nil {
return err
}
return p.readRemoteVerAckMsg()
}
// negotiateOutboundProtocol sends our version message then waits to receive a
// version message from the peer. If the events do not occur in that order then
// it returns an error.
// negotiateOutoundProtocol performs the negotiation protocol for an outbound
// peer. The events should occur in the following order, otherwise an error is
// returned:
//
// 1. We send our version.
// 2. Remote peer sends their version.
// 3. Remote peer sends their verack.
// 4. We send our verack.
func (p *Peer) negotiateOutboundProtocol() error {
if err := p.writeLocalVersionMsg(); err != nil {
return err
}
return p.readRemoteVersionMsg()
if err := p.readRemoteVersionMsg(); err != nil {
return err
}
if err := p.readRemoteVerAckMsg(); err != nil {
return err
}
return p.writeMessage(wire.NewMsgVerAck(), wire.LatestEncoding)
}
// start begins processing input and output messages.
@ -2102,8 +2155,6 @@ func (p *Peer) start() error {
go p.outHandler()
go p.pingHandler()
// Send our verack message now that the IO processing machinery has started.
p.QueueMessage(wire.NewMsgVerAck(), nil)
return nil
}

View file

@ -321,9 +321,13 @@ func WriteMessageWithEncodingN(w io.Writer, msg Message, pver uint32,
return totalBytes, err
}
// Write payload.
n, err = w.Write(payload)
totalBytes += n
// Only write the payload if there is one, e.g., verack messages don't
// have one.
if len(payload) > 0 {
n, err = w.Write(payload)
totalBytes += n
}
return totalBytes, err
}