peer: Extract protocol negotiation from main read and write code paths.

This allows cleaner separation of the half-duplex version negotiation from the fully duplex message passing between peers.
This commit is contained in:
Jonathan Gillham 2016-02-10 18:01:55 +00:00
parent 777ccdade3
commit f3d759d783
2 changed files with 179 additions and 168 deletions

View file

@ -188,7 +188,7 @@ type MessageListeners struct {
// not directly provide a callback.
OnRead func(p *Peer, bytesRead int, msg wire.Message, err error)
// OnWrite is invoked when a peer receives a bitcoin message. It
// OnWrite is invoked when we write a bitcoin message to a peer. It
// consists of the number of bytes written, the message, and whether or
// not an error in the write occurred. This can be useful for
// circumstances such as keeping track of server-wide byte counts.
@ -735,15 +735,15 @@ func (p *Peer) WantsHeaders() bool {
return p.sendHeadersPreferred
}
// pushVersionMsg sends a version message to the connected peer using the
// current state.
func (p *Peer) pushVersionMsg() error {
// localVersionMsg creates a version message that can be used to send to the
// remote peer.
func (p *Peer) localVersionMsg() (*wire.MsgVersion, error) {
var blockNum int32
if p.cfg.NewestBlock != nil {
var err error
_, blockNum, err = p.cfg.NewestBlock()
if err != nil {
return err
return nil, err
}
}
@ -775,7 +775,7 @@ func (p *Peer) pushVersionMsg() error {
// recently seen nonces.
nonce, err := wire.RandomUint64()
if err != nil {
return err
return nil, err
}
sentNonces.Add(nonce)
@ -810,8 +810,7 @@ func (p *Peer) pushVersionMsg() error {
// Advertise if inv messages for transactions are desired.
msg.DisableRelayTx = p.cfg.DisableRelayTx
p.QueueMessage(msg, nil)
return nil
return msg, nil
}
// PushAddrMsg sends an addr message to the connected peer using the provided
@ -913,8 +912,8 @@ func (p *Peer) PushGetHeadersMsg(locator blockchain.BlockLocator, stopHash *wire
p.prevGetHdrsMtx.Unlock()
if isDuplicate {
log.Tracef("Filtering duplicate [getheaders] with begin "+
"hash %v", beginHash)
log.Tracef("Filtering duplicate [getheaders] with begin hash %v",
beginHash)
return nil
}
@ -974,10 +973,10 @@ func (p *Peer) PushRejectMsg(command string, code wire.RejectCode, reason string
<-doneChan
}
// handleVersionMsg is invoked when a peer receives a version bitcoin message
// and is used to negotiate the protocol version details as well as kick start
// the communications.
func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error {
// handleRemoteVersionMsg is invoked when a version bitcoin message is received
// from the remote peer. It will return an error if the remote peer's version
// is not compatible with ours.
func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error {
// Detect self connections.
if !allowSelfConns && sentNonces.Exists(msg.Nonce) {
return errors.New("disconnecting peer connected to self")
@ -991,21 +990,9 @@ func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error {
// disconnecting.
reason := fmt.Sprintf("protocol version must be %d or greater",
wire.MultipleAddressVersion)
p.PushRejectMsg(msg.Command(), wire.RejectObsolete, reason,
nil, true)
return errors.New(reason)
}
// Limit to one version message per peer.
// No read lock is necessary because versionKnown is not written to in any
// other goroutine
if p.versionKnown {
// Send an reject message indicating the version message was
// incorrectly sent twice and wait for the message to be sent
// before disconnecting.
p.PushRejectMsg(msg.Command(), wire.RejectDuplicate,
"duplicate version message", nil, true)
return errors.New("only one version message per peer is allowed")
rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete,
reason)
return p.writeMessage(rejectMsg)
}
// Updating a bunch of stats.
@ -1030,27 +1017,6 @@ func (p *Peer) handleVersionMsg(msg *wire.MsgVersion) error {
// Set the remote peer's user agent.
p.userAgent = msg.UserAgent
p.flagsMtx.Unlock()
// Inbound connections.
if p.inbound {
// Set up a NetAddress for the peer to be used with AddrManager.
// We only do this inbound because outbound set this up
// at connection time and no point recomputing.
na, err := newNetAddress(p.conn.RemoteAddr(), p.services)
if err != nil {
return err
}
p.na = na
// Send version.
err = p.pushVersionMsg()
if err != nil {
return err
}
}
// Send verack.
p.QueueMessage(wire.NewMsgVerAck(), nil)
return nil
}
@ -1147,18 +1113,6 @@ func (p *Peer) writeMessage(msg wire.Message) error {
if atomic.LoadInt32(&p.disconnect) != 0 {
return nil
}
if !p.VersionKnown() {
switch msg.(type) {
case *wire.MsgVersion:
// This is OK.
case *wire.MsgReject:
// This is OK.
default:
// Drop all messages other than version and reject if
// the handshake has not already been done.
return nil
}
}
// Use closures to log expensive operations so they are only run when
// the logging level requires it.
@ -1194,13 +1148,16 @@ func (p *Peer) writeMessage(msg wire.Message) error {
return err
}
// isAllowedByRegression returns whether or not the passed error is allowed by
// regression tests without disconnecting the peer. In particular, regression
// tests need to be allowed to send malformed messages without the peer being
// disconnected.
func (p *Peer) isAllowedByRegression(err error) bool {
// Don't allow the error if it's not specifically a malformed message
// error.
// isAllowedReadError returns whether or not the passed error is allowed without
// disconnecting the peer. In particular, regression tests need to be allowed
// to send malformed messages without the peer being disconnected.
func (p *Peer) isAllowedReadError(err error) bool {
// Only allow read errors in regression test mode.
if p.cfg.ChainParams.Net != wire.TestNet {
return false
}
// Don't allow the error if it's not specifically a malformed message error.
if _, ok := err.(*wire.MessageError); !ok {
return false
}
@ -1220,12 +1177,6 @@ func (p *Peer) isAllowedByRegression(err error) bool {
return true
}
// isRegTestNetwork returns whether or not the peer is running on the regression
// test network.
func (p *Peer) isRegTestNetwork() bool {
return p.cfg.ChainParams.Net == wire.TestNet
}
// shouldHandleReadError returns whether or not the passed error, which is
// expected to have come from reading from the remote peer in the inHandler,
// should be logged and responded to with a reject message.
@ -1437,14 +1388,8 @@ func (p *Peer) inHandler() {
// Peers must complete the initial version negotiation within a shorter
// timeframe than a general idle timeout. The timer is then reset below
// to idleTimeout for all future messages.
idleTimer := time.AfterFunc(negotiateTimeout, func() {
if p.VersionKnown() {
log.Warnf("Peer %s no answer for %s -- disconnecting",
p, idleTimeout)
} else {
log.Debugf("Peer %s no valid version message for %s -- "+
"disconnecting", p, negotiateTimeout)
}
idleTimer := time.AfterFunc(idleTimeout, func() {
log.Warnf("Peer %s no answer for %s -- disconnecting", p, idleTimeout)
p.Disconnect()
})
@ -1456,13 +1401,11 @@ out:
rmsg, buf, err := p.readMessage()
idleTimer.Stop()
if err != nil {
// In order to allow regression tests with malformed
// messages, don't disconnect the peer when we're in
// regression test mode and the error is one of the
// allowed errors.
if p.isRegTestNetwork() && p.isAllowedByRegression(err) {
log.Errorf("Allowed regression test error "+
"from %s: %v", p, err)
// In order to allow regression tests with malformed messages, don't
// disconnect the peer when we're in regression test mode and the
// error is one of the allowed errors.
if p.isAllowedReadError(err) {
log.Errorf("Allowed test error from %s: %v", p, err)
idleTimer.Reset(idleTimeout)
continue
}
@ -1471,70 +1414,40 @@ out:
// local peer is not forcibly disconnecting and the
// remote peer has not disconnected.
if p.shouldHandleReadError(err) {
errMsg := fmt.Sprintf("Can't read message "+
"from %s: %v", p, err)
errMsg := fmt.Sprintf("Can't read message from %s: %v", p, err)
log.Errorf(errMsg)
// Push a reject message for the malformed
// message and wait for the message to be sent
// before disconnecting.
// Push a reject message for the malformed message and wait for
// the message to be sent before disconnecting.
//
// NOTE: Ideally this would include the command
// in the header if at least that much of the
// message was valid, but that is not currently
// exposed by wire, so just used malformed for
// the command.
p.PushRejectMsg("malformed",
wire.RejectMalformed, errMsg, nil, true)
// NOTE: Ideally this would include the command in the header if
// at least that much of the message was valid, but that is not
// currently exposed by wire, so just used malformed for the
// command.
p.PushRejectMsg("malformed", wire.RejectMalformed, errMsg, nil,
true)
}
break out
}
atomic.StoreInt64(&p.lastRecv, time.Now().Unix())
p.stallControl <- stallControlMsg{sccReceiveMessage, rmsg}
// Ensure version message comes first.
if vmsg, ok := rmsg.(*wire.MsgVersion); !ok && !p.VersionKnown() {
errStr := "A version message must precede all others"
log.Errorf(errStr)
// Push a reject message and wait for the message to be
// sent before disconnecting.
p.PushRejectMsg(vmsg.Command(), wire.RejectMalformed,
errStr, nil, true)
break out
}
// Handle each supported message type.
p.stallControl <- stallControlMsg{sccHandlerStart, rmsg}
switch msg := rmsg.(type) {
case *wire.MsgVersion:
err := p.handleVersionMsg(msg)
if err != nil {
log.Debugf("New peer %v - error negotiating protocol: %v",
p, err)
p.Disconnect()
p.PushRejectMsg(msg.Command(), wire.RejectDuplicate,
"duplicate version message", nil, true)
break out
}
if p.cfg.Listeners.OnVersion != nil {
p.cfg.Listeners.OnVersion(p, msg)
}
case *wire.MsgVerAck:
p.flagsMtx.Lock()
versionSent := p.versionSent
p.flagsMtx.Unlock()
if !versionSent {
log.Infof("Received 'verack' from peer %v "+
"before version was sent -- "+
"disconnecting", p)
break out
}
// No read lock is necessary because verAckReceived is
// not written to in any other goroutine.
// No read lock is necessary because verAckReceived is not written
// to in any other goroutine.
if p.verAckReceived {
log.Infof("Already received 'verack' from "+
"peer %v -- disconnecting", p)
log.Infof("Already received 'verack' from peer %v -- "+
"disconnecting", p)
break out
}
p.flagsMtx.Lock()
@ -1830,13 +1743,6 @@ out:
select {
case msg := <-p.sendQueue:
switch m := msg.msg.(type) {
case *wire.MsgVersion:
// Set the flag which indicates the version has
// been sent.
p.flagsMtx.Lock()
p.versionSent = true
p.flagsMtx.Unlock()
case *wire.MsgPing:
// Only expects a pong message in later protocol
// versions. Also set up statistics.
@ -1849,8 +1755,7 @@ out:
}
p.stallControl <- stallControlMsg{sccSendMessage, msg.msg}
err := p.writeMessage(msg.msg)
if err != nil {
if err := p.writeMessage(msg.msg); err != nil {
p.Disconnect()
if p.shouldLogWriteError(err) {
log.Errorf("Failed to send message to "+
@ -1956,15 +1861,28 @@ func (p *Peer) Connect(conn net.Conn) {
return
}
if p.inbound {
p.addr = conn.RemoteAddr().String()
}
p.conn = conn
p.timeConnected = time.Now()
if p.inbound {
p.addr = p.conn.RemoteAddr().String()
// Set up a NetAddress for the peer to be used with AddrManager. We
// only do this inbound because outbound set this up at connection time
// and no point recomputing.
na, err := newNetAddress(p.conn.RemoteAddr(), p.services)
if err != nil {
log.Errorf("Cannot create remote net address: %v", err)
p.Disconnect()
return
}
p.na = na
}
go func() {
if err := p.start(); err != nil {
log.Errorf("Cannot start peer %v: %v", p, err)
log.Warnf("Cannot start peer %v: %v", p, err)
p.Disconnect()
}
}()
}
@ -1992,26 +1910,38 @@ func (p *Peer) Disconnect() {
close(p.quit)
}
// Start begins processing input and output messages. It also sends the initial
// version message for outbound connections to start the negotiation process.
// start begins processing input and output messages.
func (p *Peer) start() error {
log.Tracef("Starting peer %s", p)
// Send an initial version message if this is an outbound connection.
if !p.inbound {
if err := p.pushVersionMsg(); err != nil {
log.Errorf("Can't send outbound version message %v", err)
p.Disconnect()
negotiateErr := make(chan error)
go func() {
if p.inbound {
negotiateErr <- p.negotiateInboundProtocol()
} else {
negotiateErr <- p.negotiateOutboundProtocol()
}
}()
// Negotiate the protocol within the specified negotiateTimeout.
select {
case err := <-negotiateErr:
if err != nil {
return err
}
case <-time.After(negotiateTimeout):
return errors.New("protocol negotiation timeout")
}
// Start processing input and output.
// The protocol has been negotiated successfully so start processing input
// and output messages.
go p.stallHandler()
go p.inHandler()
go p.queueHandler()
go p.outHandler()
// Send our verack message now that the IO processing machinery has started.
p.QueueMessage(wire.NewMsgVerAck(), nil)
return nil
}
@ -2023,6 +1953,79 @@ func (p *Peer) WaitForDisconnect() {
<-p.quit
}
// readRemoteVersionMsg waits for the next message to arrive from the remote
// peer. If the next message is not a version message or the version is not
// acceptable then return an error.
func (p *Peer) readRemoteVersionMsg() error {
// Read their version message.
msg, _, err := p.readMessage()
if err != nil {
return err
}
remoteVerMsg, ok := msg.(*wire.MsgVersion)
if !ok {
errStr := "A version message must precede all others"
log.Errorf(errStr)
rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed,
errStr)
return p.writeMessage(rejectMsg)
}
if err := p.handleRemoteVersionMsg(remoteVerMsg); err != nil {
return err
}
if p.cfg.Listeners.OnVersion != nil {
p.cfg.Listeners.OnVersion(p, remoteVerMsg)
}
return nil
}
// writeLocalVersionMsg writes our version message to the remote peer.
func (p *Peer) writeLocalVersionMsg() error {
localVerMsg, err := p.localVersionMsg()
if err != nil {
return err
}
if err := p.writeMessage(localVerMsg); err != nil {
return err
}
p.flagsMtx.Lock()
p.versionSent = true
p.flagsMtx.Unlock()
return nil
}
// negotiateInboundProtocol waits to receive a version message from the peer
// then sends our version message. If the events do not occur in that order then
// it returns an error.
func (p *Peer) negotiateInboundProtocol() error {
if err := p.readRemoteVersionMsg(); err != nil {
return err
}
return p.writeLocalVersionMsg()
}
// negotiateOutboundProtocol sends our version message then waits to receive a
// version message from the peer. If the events do not occur in that order then
// it returns an error.
func (p *Peer) negotiateOutboundProtocol() error {
if err := p.writeLocalVersionMsg(); err != nil {
return err
}
return p.readRemoteVersionMsg()
}
// newPeerBase returns a new base bitcoin peer based on the inbound flag. This
// is used by the NewInboundPeer and NewOutboundPeer functions to perform base
// setup needed by both types of peers.

View file

@ -204,12 +204,15 @@ func testPeer(t *testing.T, p *peer.Peer, s peerStats) {
// TestPeerConnection tests connection between inbound and outbound peers.
func TestPeerConnection(t *testing.T) {
verack := make(chan struct{}, 1)
verack := make(chan struct{})
peerCfg := &peer.Config{
Listeners: peer.MessageListeners{
OnWrite: func(p *peer.Peer, bytesWritten int, msg wire.Message, err error) {
switch msg.(type) {
case *wire.MsgVerAck:
OnVerAck: func(p *peer.Peer, msg *wire.MsgVerAck) {
verack <- struct{}{}
},
OnWrite: func(p *peer.Peer, bytesWritten int, msg wire.Message,
err error) {
if _, ok := msg.(*wire.MsgVerAck); ok {
verack <- struct{}{}
}
},
@ -253,10 +256,10 @@ func TestPeerConnection(t *testing.T) {
}
outPeer.Connect(outConn)
for i := 0; i < 2; i++ {
for i := 0; i < 4; i++ {
select {
case <-verack:
case <-time.After(time.Second * 1):
case <-time.After(time.Second):
return nil, nil, errors.New("verack timeout")
}
}
@ -279,10 +282,10 @@ func TestPeerConnection(t *testing.T) {
}
outPeer.Connect(outConn)
for i := 0; i < 2; i++ {
for i := 0; i < 4; i++ {
select {
case <-verack:
case <-time.After(time.Second * 1):
case <-time.After(time.Second):
return nil, nil, errors.New("verack timeout")
}
}
@ -294,7 +297,7 @@ func TestPeerConnection(t *testing.T) {
for i, test := range tests {
inPeer, outPeer, err := test.setup()
if err != nil {
t.Errorf("TestPeerConnection setup #%d: unexpected err %v\n", i, err)
t.Errorf("TestPeerConnection setup #%d: unexpected err %v", i, err)
return
}
testPeer(t, inPeer, wantStats)
@ -302,6 +305,8 @@ func TestPeerConnection(t *testing.T) {
inPeer.Disconnect()
outPeer.Disconnect()
inPeer.WaitForDisconnect()
outPeer.WaitForDisconnect()
}
}
@ -547,6 +552,7 @@ func TestOutboundPeer(t *testing.T) {
select {
case <-disconnected:
close(disconnected)
case <-time.After(time.Second):
t.Fatal("Peer did not automatically disconnect.")
}
@ -580,6 +586,7 @@ func TestOutboundPeer(t *testing.T) {
}
return hash, 234439, nil
}
peerCfg.NewestBlock = newestBlock
r1, w1 := io.Pipe()
c1 := &conn{raddr: "10.0.0.1:8333", Writer: w1, Reader: r1}
@ -638,7 +645,8 @@ func TestOutboundPeer(t *testing.T) {
t.Errorf("PushGetHeadersMsg: unexpected err %v\n", err)
return
}
p2.PushRejectMsg("block", wire.RejectMalformed, "malformed", nil, true)
p2.PushRejectMsg("block", wire.RejectMalformed, "malformed", nil, false)
p2.PushRejectMsg("block", wire.RejectInvalid, "invalid", nil, false)
// Test Queue Messages