Remove a (largely harmless) race on p.conn

Add a Connected() member function that checks atomic variables to see if
the peer is connected.
This commit is contained in:
Owain G. Ainsworth 2013-10-02 23:06:29 +01:00
parent 5c1340be8f
commit f333cb4220
3 changed files with 23 additions and 15 deletions

View file

@ -121,7 +121,7 @@ func (b *blockManager) startSync(peers *list.List) {
} }
log.Infof("[BMGR] Syncing to block height %d from peer %v", log.Infof("[BMGR] Syncing to block height %d from peer %v",
bestPeer.lastBlock, bestPeer.conn.RemoteAddr()) bestPeer.lastBlock, bestPeer.addr)
bestPeer.PushGetBlocksMsg(locator, &zeroHash) bestPeer.PushGetBlocksMsg(locator, &zeroHash)
b.syncPeer = bestPeer b.syncPeer = bestPeer
} }

32
peer.go
View file

@ -102,6 +102,7 @@ type peer struct {
na *btcwire.NetAddress na *btcwire.NetAddress
timeConnected time.Time timeConnected time.Time
inbound bool inbound bool
connected int32
disconnect int32 // only to be used atomically disconnect int32 // only to be used atomically
persistent bool persistent bool
versionKnown bool versionKnown bool
@ -202,7 +203,7 @@ func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) {
// Detect self connections. // Detect self connections.
if msg.Nonce == p.server.nonce { if msg.Nonce == p.server.nonce {
log.Debugf("[PEER] Disconnecting peer connected to self %s", log.Debugf("[PEER] Disconnecting peer connected to self %s",
p.conn.RemoteAddr()) p.addr)
p.Disconnect() p.Disconnect()
return return
} }
@ -210,7 +211,7 @@ func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) {
// Limit to one version message per peer. // Limit to one version message per peer.
if p.versionKnown { if p.versionKnown {
p.logError("[PEER] Only one version message per peer is allowed %s.", p.logError("[PEER] Only one version message per peer is allowed %s.",
p.conn.RemoteAddr()) p.addr)
p.Disconnect() p.Disconnect()
return return
} }
@ -219,7 +220,7 @@ func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) {
p.protocolVersion = minUint32(p.protocolVersion, uint32(msg.ProtocolVersion)) p.protocolVersion = minUint32(p.protocolVersion, uint32(msg.ProtocolVersion))
p.versionKnown = true p.versionKnown = true
log.Debugf("[PEER] Negotiated protocol version %d for peer %s", log.Debugf("[PEER] Negotiated protocol version %d for peer %s",
p.protocolVersion, p.conn.RemoteAddr()) p.protocolVersion, p.addr)
p.lastBlock = msg.LastBlock p.lastBlock = msg.LastBlock
// Set the supported services for the peer to what the remote peer // Set the supported services for the peer to what the remote peer
@ -685,7 +686,7 @@ func (p *peer) handleAddrMsg(msg *btcwire.MsgAddr) {
// A message that has no addresses is invalid. // A message that has no addresses is invalid.
if len(msg.AddrList) == 0 { if len(msg.AddrList) == 0 {
p.logError("[PEER] Command [%s] from %s does not contain any addresses", p.logError("[PEER] Command [%s] from %s does not contain any addresses",
msg.Command(), p.conn.RemoteAddr()) msg.Command(), p.addr)
p.Disconnect() p.Disconnect()
return return
} }
@ -735,7 +736,7 @@ func (p *peer) readMessage() (msg btcwire.Message, buf []byte, err error) {
return return
} }
log.Debugf("[PEER] Received command [%v] from %s", msg.Command(), log.Debugf("[PEER] Received command [%v] from %s", msg.Command(),
p.conn.RemoteAddr()) p.addr)
// Use closures to log expensive operations so they are only run when // Use closures to log expensive operations so they are only run when
// the logging level requires it. // the logging level requires it.
@ -757,7 +758,7 @@ func (p *peer) writeMessage(msg btcwire.Message) {
} }
log.Debugf("[PEER] Sending command [%v] to %s", msg.Command(), log.Debugf("[PEER] Sending command [%v] to %s", msg.Command(),
p.conn.RemoteAddr()) p.addr)
// Use closures to log expensive operations so they are only run when the // Use closures to log expensive operations so they are only run when the
// logging level requires it. // logging level requires it.
@ -795,7 +796,7 @@ func (p *peer) isAllowedByRegression(err error) bool {
// Don't allow the error if it's not coming from localhost or the // Don't allow the error if it's not coming from localhost or the
// hostname can't be determined for some reason. // hostname can't be determined for some reason.
host, _, err := net.SplitHostPort(p.conn.RemoteAddr().String()) host, _, err := net.SplitHostPort(p.addr)
if err != nil { if err != nil {
return false return false
} }
@ -909,7 +910,7 @@ out:
p.server.blockManager.DonePeer(p) p.server.blockManager.DonePeer(p)
} }
log.Tracef("[PEER] Peer input handler done for %s", p.conn.RemoteAddr()) log.Tracef("[PEER] Peer input handler done for %s", p.addr)
} }
// outHandler handles all outgoing messages for the peer. It must be run as a // outHandler handles all outgoing messages for the peer. It must be run as a
@ -964,7 +965,7 @@ out:
break out break out
} }
} }
log.Tracef("[PEER] Peer output handler done for %s", p.conn.RemoteAddr()) log.Tracef("[PEER] Peer output handler done for %s", p.addr)
} }
// QueueMessage adds the passed bitcoin message to the peer send queue. It // QueueMessage adds the passed bitcoin message to the peer send queue. It
@ -996,7 +997,7 @@ func (p *peer) Start() error {
return nil return nil
} }
log.Tracef("[PEER] Starting peer %s", p.conn.RemoteAddr()) log.Tracef("[PEER] Starting peer %s", p.addr)
// Send an initial version message if this is an outbound connection. // Send an initial version message if this is an outbound connection.
if !p.inbound { if !p.inbound {
@ -1004,7 +1005,7 @@ func (p *peer) Start() error {
if err != nil { if err != nil {
p.logError("[PEER] Can't send outbound version "+ p.logError("[PEER] Can't send outbound version "+
"message %v", err) "message %v", err)
p.conn.Close() p.Disconnect()
return err return err
} }
} }
@ -1025,7 +1026,7 @@ func (p *peer) Disconnect() {
return return
} }
close(p.quit) close(p.quit)
if p.conn != nil { if p.Connected() {
p.conn.Close() p.conn.Close()
} }
} }
@ -1067,6 +1068,7 @@ func newInboundPeer(s *server, conn net.Conn) *peer {
p := newPeerBase(s, true) p := newPeerBase(s, true)
p.conn = conn p.conn = conn
p.addr = conn.RemoteAddr().String() p.addr = conn.RemoteAddr().String()
atomic.AddInt32(&p.connected, 1)
return p return p
} }
@ -1149,6 +1151,7 @@ func newOutboundPeer(s *server, addr string, persistent bool) *peer {
log.Debugf("[SRVR] Connected to %s", log.Debugf("[SRVR] Connected to %s",
conn.RemoteAddr()) conn.RemoteAddr())
p.conn = conn p.conn = conn
atomic.AddInt32(&p.connected, 1)
p.retrycount = 0 p.retrycount = 0
p.Start() p.Start()
} }
@ -1167,3 +1170,8 @@ func (p *peer) logError(fmt string, args...interface{}) {
log.Debugf(fmt, args...) log.Debugf(fmt, args...)
} }
} }
func (p *peer) Connected() bool {
return atomic.LoadInt32(&p.connected) != 0 &&
atomic.LoadInt32(&p.disconnect) == 0
}

View file

@ -173,7 +173,7 @@ func (s *server) handleRelayInvMsg(peers *list.List, iv *btcwire.InvVect) {
// which are not already known to have it. // which are not already known to have it.
for e := peers.Front(); e != nil; e = e.Next() { for e := peers.Front(); e != nil; e = e.Next() {
p := e.Value.(*peer) p := e.Value.(*peer)
if p.conn == nil { if !p.Connected() {
continue continue
} }
@ -196,7 +196,7 @@ func (s *server) handleBroadcastMsg(peers *list.List, bmsg *broadcastMsg) {
} }
p := e.Value.(*peer) p := e.Value.(*peer)
// Don't broadcast to still connecting outbound peers . // Don't broadcast to still connecting outbound peers .
if p.conn == nil { if !p.Connected() {
excluded = true excluded = true
} }
if !excluded { if !excluded {