diff --git a/peer/peer.go b/peer/peer.go index ae61872a..a4ea1fd4 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -34,9 +34,9 @@ const ( // inv message to a peer. DefaultTrickleInterval = 10 * time.Second - // minAcceptableProtocolVersion is the lowest protocol version that a + // MinAcceptableProtocolVersion is the lowest protocol version that a // connected peer may support. - minAcceptableProtocolVersion = wire.MultipleAddressVersion + MinAcceptableProtocolVersion = wire.MultipleAddressVersion // outputBufferSize is the number of elements the output channels use. outputBufferSize = 50 @@ -1875,26 +1875,42 @@ func (p *Peer) Disconnect() { close(p.quit) } -// handleRemoteVersionMsg is invoked when a version bitcoin message is received -// from the remote peer. It will return an error if the remote peer's version -// is not compatible with ours. -func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { +// readRemoteVersionMsg waits for the next message to arrive from the remote +// peer. If the next message is not a version message or the version is not +// acceptable then return an error. +func (p *Peer) readRemoteVersionMsg() error { + // Read their version message. + remoteMsg, _, err := p.readMessage(wire.LatestEncoding) + if err != nil { + return err + } + + // Notify and disconnect clients if the first message is not a version + // message. + msg, ok := remoteMsg.(*wire.MsgVersion) + if !ok { + reason := "a version message must precede all others" + rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed, + reason) + _ = p.writeMessage(rejectMsg, wire.LatestEncoding) + return errors.New(reason) + } + // Detect self connections. if !allowSelfConns && sentNonces.Exists(msg.Nonce) { return errors.New("disconnecting peer connected to self") } - // Notify and disconnect clients that have a protocol version that is - // too old. - // - // NOTE: If minAcceptableProtocolVersion is raised to be higher than - // wire.RejectVersion, this should send a reject packet before - // disconnecting. - if uint32(msg.ProtocolVersion) < minAcceptableProtocolVersion { - reason := fmt.Sprintf("protocol version must be %d or greater", - minAcceptableProtocolVersion) - return errors.New(reason) - } + // Negotiate the protocol version and set the services to what the remote + // peer advertised. + p.flagsMtx.Lock() + p.advertisedProtoVer = uint32(msg.ProtocolVersion) + p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer) + p.versionKnown = true + p.services = msg.Services + p.flagsMtx.Unlock() + log.Debugf("Negotiated protocol version %d for peer %s", + p.protocolVersion, p) // Updating a bunch of stats including block based stats, and the // peer's time offset. @@ -1904,22 +1920,10 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { p.timeOffset = msg.Timestamp.Unix() - time.Now().Unix() p.statsMtx.Unlock() - // Negotiate the protocol version. + // Set the peer's ID, user agent, and potentially the flag which + // specifies the witness support is enabled. p.flagsMtx.Lock() - p.advertisedProtoVer = uint32(msg.ProtocolVersion) - p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer) - p.versionKnown = true - log.Debugf("Negotiated protocol version %d for peer %s", - p.protocolVersion, p) - - // Set the peer's ID. p.id = atomic.AddInt32(&nodeCount, 1) - - // Set the supported services for the peer to what the remote peer - // advertised. - p.services = msg.Services - - // Set the remote peer's user agent. p.userAgent = msg.UserAgent // Determine if the peer would like to receive witness data with @@ -1938,36 +1942,29 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { p.wireEncoding = wire.WitnessEncoding } - return nil -} - -// readRemoteVersionMsg waits for the next message to arrive from the remote -// peer. If the next message is not a version message or the version is not -// acceptable then return an error. -func (p *Peer) readRemoteVersionMsg() error { - // Read their version message. - msg, _, err := p.readMessage(wire.LatestEncoding) - if err != nil { - return err - } - - remoteVerMsg, ok := msg.(*wire.MsgVersion) - if !ok { - errStr := "A version message must precede all others" - log.Errorf(errStr) - - rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed, - errStr) - return p.writeMessage(rejectMsg, wire.LatestEncoding) - } - - if err := p.handleRemoteVersionMsg(remoteVerMsg); err != nil { - return err - } - + // Invoke the callback if specified. if p.cfg.Listeners.OnVersion != nil { - p.cfg.Listeners.OnVersion(p, remoteVerMsg) + p.cfg.Listeners.OnVersion(p, msg) } + + // Notify and disconnect clients that have a protocol version that is + // too old. + // + // NOTE: If minAcceptableProtocolVersion is raised to be higher than + // wire.RejectVersion, this should send a reject packet before + // disconnecting. + if uint32(msg.ProtocolVersion) < MinAcceptableProtocolVersion { + // Send a reject message indicating the protocol version is + // obsolete and wait for the message to be sent before + // disconnecting. + reason := fmt.Sprintf("protocol version must be %d or greater", + MinAcceptableProtocolVersion) + rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete, + reason) + _ = p.writeMessage(rejectMsg, wire.LatestEncoding) + return errors.New(reason) + } + return nil } @@ -2099,9 +2096,11 @@ func (p *Peer) start() error { select { case err := <-negotiateErr: if err != nil { + p.Disconnect() return err } case <-time.After(negotiateTimeout): + p.Disconnect() return errors.New("protocol negotiation timeout") } log.Debugf("Connected to %s", p.Addr()) diff --git a/server.go b/server.go index 2e91eee3..9c82ffa7 100644 --- a/server.go +++ b/server.go @@ -389,6 +389,12 @@ func (sp *serverPeer) addBanScore(persistent, transient uint32, reason string) { // and is used to negotiate the protocol version details as well as kick start // the communications. func (sp *serverPeer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) { + // Ignore peers that have a protcol version that is too old. The peer + // negotiation logic will disconnect it after this callback returns. + if msg.ProtocolVersion < int32(peer.MinAcceptableProtocolVersion) { + return + } + // Add the remote peer time as a sample for creating an offset against // the local clock to keep the network time in sync. sp.server.timeSource.AddTimeSample(sp.Addr(), msg.Timestamp)