Merge pull request #1462 from wpaulino/verack-version-handshake
peer: include verack as part of version handshake
This commit is contained in:
commit
130ea5bddd
2 changed files with 82 additions and 27 deletions
99
peer/peer.go
99
peer/peer.go
|
@ -1377,20 +1377,12 @@ out:
|
||||||
break out
|
break out
|
||||||
|
|
||||||
case *wire.MsgVerAck:
|
case *wire.MsgVerAck:
|
||||||
|
// Limit to one verack message per peer.
|
||||||
// No read lock is necessary because verAckReceived is not written
|
p.PushRejectMsg(
|
||||||
// to in any other goroutine.
|
msg.Command(), wire.RejectDuplicate,
|
||||||
if p.verAckReceived {
|
"duplicate verack message", nil, true,
|
||||||
log.Infof("Already received 'verack' from peer %v -- "+
|
)
|
||||||
"disconnecting", p)
|
break out
|
||||||
break out
|
|
||||||
}
|
|
||||||
p.flagsMtx.Lock()
|
|
||||||
p.verAckReceived = true
|
|
||||||
p.flagsMtx.Unlock()
|
|
||||||
if p.cfg.Listeners.OnVerAck != nil {
|
|
||||||
p.cfg.Listeners.OnVerAck(p, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
case *wire.MsgGetAddr:
|
case *wire.MsgGetAddr:
|
||||||
if p.cfg.Listeners.OnGetAddr != nil {
|
if p.cfg.Listeners.OnGetAddr != nil {
|
||||||
|
@ -1974,6 +1966,40 @@ func (p *Peer) readRemoteVersionMsg() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readRemoteVerAckMsg waits for the next message to arrive from the remote
|
||||||
|
// peer. If this message is not a verack message, then an error is returned.
|
||||||
|
// This method is to be used as part of the version negotiation upon a new
|
||||||
|
// connection.
|
||||||
|
func (p *Peer) readRemoteVerAckMsg() error {
|
||||||
|
// Read the next message from the wire.
|
||||||
|
remoteMsg, _, err := p.readMessage(wire.LatestEncoding)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// It should be a verack message, otherwise send a reject message to the
|
||||||
|
// peer explaining why.
|
||||||
|
msg, ok := remoteMsg.(*wire.MsgVerAck)
|
||||||
|
if !ok {
|
||||||
|
reason := "a verack message must follow version"
|
||||||
|
rejectMsg := wire.NewMsgReject(
|
||||||
|
msg.Command(), wire.RejectMalformed, reason,
|
||||||
|
)
|
||||||
|
_ = p.writeMessage(rejectMsg, wire.LatestEncoding)
|
||||||
|
return errors.New(reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.flagsMtx.Lock()
|
||||||
|
p.verAckReceived = true
|
||||||
|
p.flagsMtx.Unlock()
|
||||||
|
|
||||||
|
if p.cfg.Listeners.OnVerAck != nil {
|
||||||
|
p.cfg.Listeners.OnVerAck(p, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// localVersionMsg creates a version message that can be used to send to the
|
// localVersionMsg creates a version message that can be used to send to the
|
||||||
// remote peer.
|
// remote peer.
|
||||||
func (p *Peer) localVersionMsg() (*wire.MsgVersion, error) {
|
func (p *Peer) localVersionMsg() (*wire.MsgVersion, error) {
|
||||||
|
@ -2046,26 +2072,53 @@ func (p *Peer) writeLocalVersionMsg() error {
|
||||||
return p.writeMessage(localVerMsg, wire.LatestEncoding)
|
return p.writeMessage(localVerMsg, wire.LatestEncoding)
|
||||||
}
|
}
|
||||||
|
|
||||||
// negotiateInboundProtocol waits to receive a version message from the peer
|
// negotiateInboundProtocol performs the negotiation protocol for an inbound
|
||||||
// then sends our version message. If the events do not occur in that order then
|
// peer. The events should occur in the following order, otherwise an error is
|
||||||
// it returns an error.
|
// returned:
|
||||||
|
//
|
||||||
|
// 1. Remote peer sends their version.
|
||||||
|
// 2. We send our version.
|
||||||
|
// 3. We send our verack.
|
||||||
|
// 4. Remote peer sends their verack.
|
||||||
func (p *Peer) negotiateInboundProtocol() error {
|
func (p *Peer) negotiateInboundProtocol() error {
|
||||||
if err := p.readRemoteVersionMsg(); err != nil {
|
if err := p.readRemoteVersionMsg(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.writeLocalVersionMsg()
|
if err := p.writeLocalVersionMsg(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err := p.writeMessage(wire.NewMsgVerAck(), wire.LatestEncoding)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.readRemoteVerAckMsg()
|
||||||
}
|
}
|
||||||
|
|
||||||
// negotiateOutboundProtocol sends our version message then waits to receive a
|
// negotiateOutoundProtocol performs the negotiation protocol for an outbound
|
||||||
// version message from the peer. If the events do not occur in that order then
|
// peer. The events should occur in the following order, otherwise an error is
|
||||||
// it returns an error.
|
// returned:
|
||||||
|
//
|
||||||
|
// 1. We send our version.
|
||||||
|
// 2. Remote peer sends their version.
|
||||||
|
// 3. Remote peer sends their verack.
|
||||||
|
// 4. We send our verack.
|
||||||
func (p *Peer) negotiateOutboundProtocol() error {
|
func (p *Peer) negotiateOutboundProtocol() error {
|
||||||
if err := p.writeLocalVersionMsg(); err != nil {
|
if err := p.writeLocalVersionMsg(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.readRemoteVersionMsg()
|
if err := p.readRemoteVersionMsg(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.readRemoteVerAckMsg(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.writeMessage(wire.NewMsgVerAck(), wire.LatestEncoding)
|
||||||
}
|
}
|
||||||
|
|
||||||
// start begins processing input and output messages.
|
// start begins processing input and output messages.
|
||||||
|
@ -2102,8 +2155,6 @@ func (p *Peer) start() error {
|
||||||
go p.outHandler()
|
go p.outHandler()
|
||||||
go p.pingHandler()
|
go p.pingHandler()
|
||||||
|
|
||||||
// Send our verack message now that the IO processing machinery has started.
|
|
||||||
p.QueueMessage(wire.NewMsgVerAck(), nil)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -321,9 +321,13 @@ func WriteMessageWithEncodingN(w io.Writer, msg Message, pver uint32,
|
||||||
return totalBytes, err
|
return totalBytes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write payload.
|
// Only write the payload if there is one, e.g., verack messages don't
|
||||||
n, err = w.Write(payload)
|
// have one.
|
||||||
totalBytes += n
|
if len(payload) > 0 {
|
||||||
|
n, err = w.Write(payload)
|
||||||
|
totalBytes += n
|
||||||
|
}
|
||||||
|
|
||||||
return totalBytes, err
|
return totalBytes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue