diff --git a/dht/bitmap.go b/dht/bitmap.go index aa09166..0da55a0 100644 --- a/dht/bitmap.go +++ b/dht/bitmap.go @@ -6,7 +6,7 @@ import ( "encoding/hex" "strings" - "github.com/lbryio/errors.go" + "github.com/lbryio/lbry.go/errors" "github.com/lyoshenka/bencode" ) diff --git a/dht/bootstrap.go b/dht/bootstrap.go index 34f431d..a4d629a 100644 --- a/dht/bootstrap.go +++ b/dht/bootstrap.go @@ -121,8 +121,8 @@ func (b *BootstrapNode) get(limit int) []Contact { // ping pings a node. if the node responds, it is added to the list. otherwise, it is removed func (b *BootstrapNode) ping(c Contact) { - b.stopWG.Add(1) - defer b.stopWG.Done() + b.stop.Add(1) + defer b.stop.Done() resCh, cancel := b.SendCancelable(c, Request{Method: pingMethod}) diff --git a/dht/dht.go b/dht/dht.go index 67615f4..7595c84 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/lbryio/errors.go" + "github.com/lbryio/lbry.go/errors" "github.com/lbryio/lbry.go/stopOnce" "github.com/spf13/cast" diff --git a/dht/message.go b/dht/message.go index bf43c97..d858621 100644 --- a/dht/message.go +++ b/dht/message.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - "github.com/lbryio/errors.go" + "github.com/lbryio/lbry.go/errors" "github.com/lyoshenka/bencode" "github.com/spf13/cast" diff --git a/dht/node.go b/dht/node.go index f60578a..4bcb4ba 100644 --- a/dht/node.go +++ b/dht/node.go @@ -40,6 +40,8 @@ type Node struct { id Bitmap // UDP connection for sending and receiving data conn UDPConn + // true if we've closed the connection on purpose + connClosed bool // token manager tokens *tokenManager @@ -56,8 +58,7 @@ type Node struct { requestHandler RequestHandlerFunc // stop the node neatly and clean up after itself - stop *stopOnce.Stopper - stopWG *sync.WaitGroup + stop *stopOnce.Stopper } // New returns a Node pointer. @@ -71,7 +72,6 @@ func NewNode(id Bitmap) *Node { transactions: make(map[messageID]*transaction), stop: stopOnce.New(), - stopWG: &sync.WaitGroup{}, tokens: &tokenManager{}, } } @@ -80,43 +80,31 @@ func NewNode(id Bitmap) *Node { func (n *Node) Connect(conn UDPConn) error { n.conn = conn - //if dht.conf.PrintState > 0 { - // go func() { - // t := time.NewTicker(dht.conf.PrintState) - // for { - // dht.PrintState() - // select { - // case <-t.C: - // case <-dht.stop.Ch(): - // return - // } - // } - // }() - //} - n.tokens.Start(tokenSecretRotationInterval) + go func() { + // stop tokens and close the connection when we're shutting down + <-n.stop.Ch() + n.tokens.Stop() + n.connClosed = true + n.conn.Close() + }() + packets := make(chan packet) go func() { - n.stopWG.Add(1) - defer n.stopWG.Done() + n.stop.Add(1) + defer n.stop.Done() buf := make([]byte, udpMaxMessageLength) for { - select { - case <-n.stop.Ch(): - return - default: - } - - n.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) // need this to periodically check shutdown chan bytesRead, raddr, err := n.conn.ReadFromUDP(buf) if err != nil { - if e, ok := err.(net.Error); !ok || !e.Timeout() { - log.Errorf("udp read error: %v", err) + if n.connClosed { + return } + log.Errorf("udp read error: %v", err) continue } else if raddr == nil { log.Errorf("udp read with no raddr") @@ -129,13 +117,14 @@ func (n *Node) Connect(conn UDPConn) error { select { // needs select here because packet consumer can quit and the packets channel gets filled up and blocks case packets <- packet{data: data, raddr: raddr}: case <-n.stop.Ch(): + return } } }() go func() { - n.stopWG.Add(1) - defer n.stopWG.Done() + n.stop.Add(1) + defer n.stop.Done() var pkt packet @@ -157,10 +146,7 @@ func (n *Node) Connect(conn UDPConn) error { // Shutdown shuts down the node func (n *Node) Shutdown() { log.Debugf("[%s] node shutting down", n.id.HexShort()) - n.stop.Stop() - n.stopWG.Wait() - n.tokens.Stop() - n.conn.Close() + n.stop.StopAndWait() log.Debugf("[%s] node stopped", n.id.HexShort()) } @@ -316,7 +302,7 @@ func (n *Node) sendMessage(addr *net.UDPAddr, data Message) error { log.Debugf("[%s] (%d bytes) %s", n.id.HexShort(), len(encoded), spew.Sdump(data)) } - n.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) + n.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) _, err = n.conn.WriteToUDP(encoded, addr) return errors.Err(err) @@ -427,9 +413,9 @@ func (n *Node) CountActiveTransactions() int { } func (n *Node) startRoutingTableGrooming() { - n.stopWG.Add(1) + n.stop.Add(1) go func() { - defer n.stopWG.Done() + defer n.stop.Done() refreshTicker := time.NewTicker(tRefresh / 5) // how often to check for buckets that need to be refreshed for { select { diff --git a/dht/node_finder.go b/dht/node_finder.go index af8910e..d38eec9 100644 --- a/dht/node_finder.go +++ b/dht/node_finder.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/lbryio/errors.go" + "github.com/lbryio/lbry.go/errors" "github.com/lbryio/lbry.go/stopOnce" log "github.com/sirupsen/logrus" diff --git a/dht/routing_table.go b/dht/routing_table.go index b0f1739..29bbb81 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -11,7 +11,7 @@ import ( "sync" "time" - "github.com/lbryio/errors.go" + "github.com/lbryio/lbry.go/errors" "github.com/lyoshenka/bencode" ) @@ -437,9 +437,8 @@ func (rt *routingTable) UnmarshalJSON(b []byte) error { // RoutingTableRefresh refreshes any buckets that need to be refreshed // It returns a channel that will be closed when the refresh is done func RoutingTableRefresh(n *Node, refreshInterval time.Duration, cancel <-chan struct{}) <-chan struct{} { - done := make(chan struct{}) - var wg sync.WaitGroup + done := make(chan struct{}) for _, id := range n.rt.GetIDsForRefresh(refreshInterval) { wg.Add(1) diff --git a/dht/testing.go b/dht/testing.go index 5d496e5..0a69439 100644 --- a/dht/testing.go +++ b/dht/testing.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/lbryio/errors.go" + "github.com/lbryio/lbry.go/errors" ) var testingDHTIP = "127.0.0.1" @@ -107,7 +107,10 @@ func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { } select { - case packet := <-t.toRead: + case packet, ok := <-t.toRead: + if !ok { + return 0, nil, errors.Err("conn closed") + } n := copy(b, packet.data) return n, packet.addr, nil case <-timeoutCh: @@ -130,7 +133,7 @@ func (t *testUDPConn) SetWriteDeadline(tm time.Time) error { } func (t *testUDPConn) Close() error { - t.toRead = nil + close(t.toRead) t.writes = nil return nil } diff --git a/dht/token_manager.go b/dht/token_manager.go index ee1d856..e4dfa6d 100644 --- a/dht/token_manager.go +++ b/dht/token_manager.go @@ -16,28 +16,26 @@ type tokenManager struct { secret []byte prevSecret []byte lock *sync.RWMutex - wg *sync.WaitGroup - done *stopOnce.Stopper + stop *stopOnce.Stopper } func (tm *tokenManager) Start(interval time.Duration) { tm.secret = make([]byte, 64) tm.prevSecret = make([]byte, 64) tm.lock = &sync.RWMutex{} - tm.wg = &sync.WaitGroup{} - tm.done = stopOnce.New() + tm.stop = stopOnce.New() tm.rotateSecret() - tm.wg.Add(1) + tm.stop.Add(1) go func() { - defer tm.wg.Done() + defer tm.stop.Done() tick := time.NewTicker(interval) for { select { case <-tick.C: tm.rotateSecret() - case <-tm.done.Ch(): + case <-tm.stop.Ch(): return } } @@ -45,8 +43,7 @@ func (tm *tokenManager) Start(interval time.Duration) { } func (tm *tokenManager) Stop() { - tm.done.Stop() - tm.wg.Wait() + tm.stop.StopAndWait() } func (tm *tokenManager) Get(nodeID Bitmap, addr *net.UDPAddr) string {