peer: Rework version negotiation.

This modifies the negotiation logic to ensure the callback has the
opportunity to see the message before the peer is disconnected and
improves the error handling when reading the remote version message.

It also has the side effect of ensuring the protocol version is
negotiated before sending reject messages with the exception of the
first message not being a version message since negotiation is not
possible in that case.

This is being changed because it is useful for the server to see the
message regardless in order to have the opportunity to things such as
update the address manager and reject peers that don't have desired
services.

Backported from Decred.
This commit is contained in:
Dave Collins 2018-08-10 20:44:44 -05:00
parent 25dfda9bd3
commit 118f55233b
No known key found for this signature in database
GPG key ID: B8904D9D9C93D1F2
2 changed files with 64 additions and 59 deletions

View file

@ -34,9 +34,9 @@ const (
// inv message to a peer. // inv message to a peer.
DefaultTrickleInterval = 10 * time.Second DefaultTrickleInterval = 10 * time.Second
// minAcceptableProtocolVersion is the lowest protocol version that a // MinAcceptableProtocolVersion is the lowest protocol version that a
// connected peer may support. // connected peer may support.
minAcceptableProtocolVersion = wire.MultipleAddressVersion MinAcceptableProtocolVersion = wire.MultipleAddressVersion
// outputBufferSize is the number of elements the output channels use. // outputBufferSize is the number of elements the output channels use.
outputBufferSize = 50 outputBufferSize = 50
@ -1875,26 +1875,42 @@ func (p *Peer) Disconnect() {
close(p.quit) close(p.quit)
} }
// handleRemoteVersionMsg is invoked when a version bitcoin message is received // readRemoteVersionMsg waits for the next message to arrive from the remote
// from the remote peer. It will return an error if the remote peer's version // peer. If the next message is not a version message or the version is not
// is not compatible with ours. // acceptable then return an error.
func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { func (p *Peer) readRemoteVersionMsg() error {
// Read their version message.
remoteMsg, _, err := p.readMessage(wire.LatestEncoding)
if err != nil {
return err
}
// Notify and disconnect clients if the first message is not a version
// message.
msg, ok := remoteMsg.(*wire.MsgVersion)
if !ok {
reason := "a version message must precede all others"
rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed,
reason)
_ = p.writeMessage(rejectMsg, wire.LatestEncoding)
return errors.New(reason)
}
// Detect self connections. // Detect self connections.
if !allowSelfConns && sentNonces.Exists(msg.Nonce) { if !allowSelfConns && sentNonces.Exists(msg.Nonce) {
return errors.New("disconnecting peer connected to self") return errors.New("disconnecting peer connected to self")
} }
// Notify and disconnect clients that have a protocol version that is // Negotiate the protocol version and set the services to what the remote
// too old. // peer advertised.
// p.flagsMtx.Lock()
// NOTE: If minAcceptableProtocolVersion is raised to be higher than p.advertisedProtoVer = uint32(msg.ProtocolVersion)
// wire.RejectVersion, this should send a reject packet before p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer)
// disconnecting. p.versionKnown = true
if uint32(msg.ProtocolVersion) < minAcceptableProtocolVersion { p.services = msg.Services
reason := fmt.Sprintf("protocol version must be %d or greater", p.flagsMtx.Unlock()
minAcceptableProtocolVersion) log.Debugf("Negotiated protocol version %d for peer %s",
return errors.New(reason) p.protocolVersion, p)
}
// Updating a bunch of stats including block based stats, and the // Updating a bunch of stats including block based stats, and the
// peer's time offset. // peer's time offset.
@ -1904,22 +1920,10 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error {
p.timeOffset = msg.Timestamp.Unix() - time.Now().Unix() p.timeOffset = msg.Timestamp.Unix() - time.Now().Unix()
p.statsMtx.Unlock() p.statsMtx.Unlock()
// Negotiate the protocol version. // Set the peer's ID, user agent, and potentially the flag which
// specifies the witness support is enabled.
p.flagsMtx.Lock() p.flagsMtx.Lock()
p.advertisedProtoVer = uint32(msg.ProtocolVersion)
p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer)
p.versionKnown = true
log.Debugf("Negotiated protocol version %d for peer %s",
p.protocolVersion, p)
// Set the peer's ID.
p.id = atomic.AddInt32(&nodeCount, 1) p.id = atomic.AddInt32(&nodeCount, 1)
// Set the supported services for the peer to what the remote peer
// advertised.
p.services = msg.Services
// Set the remote peer's user agent.
p.userAgent = msg.UserAgent p.userAgent = msg.UserAgent
// Determine if the peer would like to receive witness data with // Determine if the peer would like to receive witness data with
@ -1938,36 +1942,29 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error {
p.wireEncoding = wire.WitnessEncoding p.wireEncoding = wire.WitnessEncoding
} }
return nil // Invoke the callback if specified.
}
// readRemoteVersionMsg waits for the next message to arrive from the remote
// peer. If the next message is not a version message or the version is not
// acceptable then return an error.
func (p *Peer) readRemoteVersionMsg() error {
// Read their version message.
msg, _, err := p.readMessage(wire.LatestEncoding)
if err != nil {
return err
}
remoteVerMsg, ok := msg.(*wire.MsgVersion)
if !ok {
errStr := "A version message must precede all others"
log.Errorf(errStr)
rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed,
errStr)
return p.writeMessage(rejectMsg, wire.LatestEncoding)
}
if err := p.handleRemoteVersionMsg(remoteVerMsg); err != nil {
return err
}
if p.cfg.Listeners.OnVersion != nil { if p.cfg.Listeners.OnVersion != nil {
p.cfg.Listeners.OnVersion(p, remoteVerMsg) p.cfg.Listeners.OnVersion(p, msg)
} }
// Notify and disconnect clients that have a protocol version that is
// too old.
//
// NOTE: If minAcceptableProtocolVersion is raised to be higher than
// wire.RejectVersion, this should send a reject packet before
// disconnecting.
if uint32(msg.ProtocolVersion) < MinAcceptableProtocolVersion {
// Send a reject message indicating the protocol version is
// obsolete and wait for the message to be sent before
// disconnecting.
reason := fmt.Sprintf("protocol version must be %d or greater",
MinAcceptableProtocolVersion)
rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete,
reason)
_ = p.writeMessage(rejectMsg, wire.LatestEncoding)
return errors.New(reason)
}
return nil return nil
} }
@ -2099,9 +2096,11 @@ func (p *Peer) start() error {
select { select {
case err := <-negotiateErr: case err := <-negotiateErr:
if err != nil { if err != nil {
p.Disconnect()
return err return err
} }
case <-time.After(negotiateTimeout): case <-time.After(negotiateTimeout):
p.Disconnect()
return errors.New("protocol negotiation timeout") return errors.New("protocol negotiation timeout")
} }
log.Debugf("Connected to %s", p.Addr()) log.Debugf("Connected to %s", p.Addr())

View file

@ -389,6 +389,12 @@ func (sp *serverPeer) addBanScore(persistent, transient uint32, reason string) {
// and is used to negotiate the protocol version details as well as kick start // and is used to negotiate the protocol version details as well as kick start
// the communications. // the communications.
func (sp *serverPeer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) { func (sp *serverPeer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) {
// Ignore peers that have a protcol version that is too old. The peer
// negotiation logic will disconnect it after this callback returns.
if msg.ProtocolVersion < int32(peer.MinAcceptableProtocolVersion) {
return
}
// Add the remote peer time as a sample for creating an offset against // Add the remote peer time as a sample for creating an offset against
// the local clock to keep the network time in sync. // the local clock to keep the network time in sync.
sp.server.timeSource.AddTimeSample(sp.Addr(), msg.Timestamp) sp.server.timeSource.AddTimeSample(sp.Addr(), msg.Timestamp)