diff --git a/peer.go b/peer.go index 6316f7c1..d0e93fb0 100644 --- a/peer.go +++ b/peer.go @@ -155,8 +155,7 @@ func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) { if msg.Nonce == p.server.nonce { log.Debugf("[PEER] Disconnecting peer connected to self %s", p.conn.RemoteAddr()) - p.disconnect = true - p.conn.Close() + p.Disconnect() return } @@ -164,8 +163,7 @@ func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) { if p.versionKnown { log.Errorf("[PEER] Only one version message per peer is allowed %s.", p.conn.RemoteAddr()) - p.disconnect = true - p.conn.Close() + p.Disconnect() return } @@ -186,8 +184,7 @@ func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) { err := p.pushVersionMsg() if err != nil { log.Errorf("[PEER] %v", err) - p.disconnect = true - p.conn.Close() + p.Disconnect() return } @@ -195,8 +192,7 @@ func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) { na, err := btcwire.NewNetAddress(p.conn.RemoteAddr(), p.services) if err != nil { log.Errorf("[PEER] %v", err) - p.disconnect = true - p.conn.Close() + p.Disconnect() return } p.server.addrManager.AddAddress(na) @@ -213,8 +209,7 @@ func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) { na, err := newNetAddress(p.conn.LocalAddr(), p.services) if err != nil { log.Errorf("[PEER] %v", err) - p.disconnect = true - p.conn.Close() + p.Disconnect() return } addresses := map[string]*btcwire.NetAddress{ @@ -236,8 +231,7 @@ func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) { sha, lastBlock, err := p.server.db.NewestSha() if err != nil { log.Errorf("[PEER] %v", err) - p.disconnect = true - p.conn.Close() + p.Disconnect() } // If the peer has blocks we're interested in. if p.lastBlock > int32(lastBlock) { @@ -442,8 +436,7 @@ func (p *peer) handleGetAddrMsg(msg *btcwire.MsgGetAddr) { err := p.pushAddrMsg(addrCache) if err != nil { log.Errorf("[PEER] %v", err) - p.disconnect = true - p.conn.Close() + p.Disconnect() return } } @@ -497,8 +490,7 @@ func (p *peer) handleAddrMsg(msg *btcwire.MsgAddr) { if len(msg.AddrList) == 0 { log.Errorf("[PEER] Command [%s] from %s does not contain any addresses", msg.Command(), p.conn.RemoteAddr()) - p.disconnect = true - p.conn.Close() + p.Disconnect() return } @@ -692,8 +684,7 @@ out: } // Ensure connection is closed and notify server that the peer is done. - p.disconnect = true - p.conn.Close() + p.Disconnect() p.server.donePeers <- p p.quit <- true @@ -715,7 +706,7 @@ out: } err := p.writeMessage(msg) if err != nil { - p.disconnect = true + p.Disconnect() log.Errorf("[PEER] %v", err) } @@ -763,12 +754,18 @@ func (p *peer) Start() error { return nil } -// Shutdown gracefully shuts down the peer by signalling the async input and -// output handler and waiting for them to finish. -func (p *peer) Shutdown() { - log.Tracef("[PEER] Shutdown peer %s", p.conn.RemoteAddr()) +// Disconnect disconnects the peer by closing the connection. It also sets +// a flag so the impending shutdown can be detected. +func (p *peer) Disconnect() { p.disconnect = true p.conn.Close() +} + +// Shutdown gracefully shuts down the peer by disconnecting it and waiting for +// all goroutines to finish. +func (p *peer) Shutdown() { + log.Tracef("[PEER] Shutdown peer %s", p.conn.RemoteAddr()) + p.Disconnect() p.wg.Wait() }