diff --git a/btcwallet.go b/btcwallet.go index dd84eb2..4ed5cbd 100644 --- a/btcwallet.go +++ b/btcwallet.go @@ -91,52 +91,58 @@ func walletMain() error { // Shutdown the server if an interrupt signal is received. addInterruptHandler(server.Stop) - // Create channel so that the goroutine which opens the chain server - // connection can pass the conn to the goroutine which opens the wallet. - // Buffer the channel so sends are not blocked, since if the wallet is - // not yet created, the wallet open goroutine does not read this. - chainSvrChan := make(chan *chain.Client, 1) - go func() { - // Read CA certs and create the RPC client. - var certs []byte - if !cfg.DisableClientTLS { - certs, err = ioutil.ReadFile(cfg.CAFile) - if err != nil { - log.Warnf("Cannot open CA file: %v", err) - // If there's an error reading the CA file, continue - // with nil certs and without the client connection - certs = nil + for { + // Read CA certs and create the RPC client. + var certs []byte + if !cfg.DisableClientTLS { + certs, err = ioutil.ReadFile(cfg.CAFile) + if err != nil { + log.Warnf("Cannot open CA file: %v", err) + // If there's an error reading the CA file, continue + // with nil certs and without the client connection + certs = nil + } + } else { + log.Info("Client TLS is disabled") } - } else { - log.Info("Client TLS is disabled") - } - rpcc, err := chain.NewClient(activeNet.Params, cfg.RPCConnect, - cfg.BtcdUsername, cfg.BtcdPassword, certs, cfg.DisableClientTLS) - if err != nil { - log.Errorf("Cannot create chain server RPC client: %v", err) - return - } - err = rpcc.Start() - if err != nil { - log.Warnf("Connection to Bitcoin RPC chain server " + - "unsuccessful -- available RPC methods will be limited") - } - // Even if Start errored, we still add the server disconnected. - // All client methods will then error, so it's obvious to a - // client that the there was a connection problem. - server.SetChainServer(rpcc) + rpcc, err := chain.NewClient(activeNet.Params, cfg.RPCConnect, + cfg.BtcdUsername, cfg.BtcdPassword, certs, cfg.DisableClientTLS) + if err != nil { + log.Errorf("Cannot create chain server RPC client: %v", err) + return + } + err = rpcc.Start() + if err != nil { + log.Warnf("Connection to Bitcoin RPC chain server " + + "unsuccessful -- available RPC methods will be limited") + } + // Even if Start errored, we still add the server disconnected. + // All client methods will then error, so it's obvious to a + // client that the there was a connection problem. + server.SetChainServer(rpcc) - chainSvrChan <- rpcc - }() + // Start wallet goroutines and handle RPC client notifications + // if the server is not shutting down. + select { + case <-server.quit: + return + default: + wallet.Start(rpcc) + } - go func() { - // Start wallet goroutines and handle RPC client notifications - // if the chain server connection was opened. - select { - case chainSvr := <-chainSvrChan: - wallet.Start(chainSvr) - case <-server.quit: + // Block goroutine until the client is finished. + rpcc.WaitForShutdown() + + wallet.SetChainSynced(false) + wallet.Stop() + + // Reconnect only if the server is not shutting down. + select { + case <-server.quit: + return + default: + } } }() diff --git a/chain/chain.go b/chain/chain.go index 4aa91fd..bab7f29 100644 --- a/chain/chain.go +++ b/chain/chain.go @@ -70,13 +70,14 @@ func NewClient(chainParams *chaincfg.Params, connect, user, pass string, certs [ OnRescanProgress: client.onRescanProgress, } conf := btcrpcclient.ConnConfig{ - Host: connect, - Endpoint: "ws", - User: user, - Pass: pass, - Certificates: certs, - DisableConnectOnNew: true, - DisableTLS: disableTLS, + Host: connect, + Endpoint: "ws", + User: user, + Pass: pass, + Certificates: certs, + DisableAutoReconnect: true, + DisableConnectOnNew: true, + DisableTLS: disableTLS, } c, err := btcrpcclient.New(&conf, &ntfnCallbacks) if err != nil { @@ -121,8 +122,6 @@ func (c *Client) Start() error { // started by Start. func (c *Client) Stop() { c.quitMtx.Lock() - defer c.quitMtx.Unlock() - select { case <-c.quit: default: @@ -133,6 +132,7 @@ func (c *Client) Stop() { close(c.dequeueNotification) } } + c.quitMtx.Unlock() } // WaitForShutdown blocks until both the client has finished disconnecting @@ -225,26 +225,35 @@ func parseBlock(block *btcjson.BlockDetails) (*wtxmgr.BlockMeta, error) { func (c *Client) onClientConnect() { log.Info("Established websocket RPC connection to btcd") - c.enqueueNotification <- ClientConnected{} + select { + case c.enqueueNotification <- ClientConnected{}: + case <-c.quit: + } } func (c *Client) onBlockConnected(hash *wire.ShaHash, height int32, time time.Time) { - c.enqueueNotification <- BlockConnected{ + select { + case c.enqueueNotification <- BlockConnected{ Block: wtxmgr.Block{ Hash: *hash, Height: height, }, Time: time, + }: + case <-c.quit: } } func (c *Client) onBlockDisconnected(hash *wire.ShaHash, height int32, time time.Time) { - c.enqueueNotification <- BlockDisconnected{ + select { + case c.enqueueNotification <- BlockDisconnected{ Block: wtxmgr.Block{ Hash: *hash, Height: height, }, Time: time, + }: + case <-c.quit: } } @@ -262,7 +271,10 @@ func (c *Client) onRecvTx(tx *btcutil.Tx, block *btcjson.BlockDetails) { "tx: %v", err) return } - c.enqueueNotification <- RelevantTx{rec, blk} + select { + case c.enqueueNotification <- RelevantTx{rec, blk}: + case <-c.quit: + } } func (c *Client) onRedeemingTx(tx *btcutil.Tx, block *btcjson.BlockDetails) { @@ -271,11 +283,18 @@ func (c *Client) onRedeemingTx(tx *btcutil.Tx, block *btcjson.BlockDetails) { } func (c *Client) onRescanProgress(hash *wire.ShaHash, height int32, blkTime time.Time) { - c.enqueueNotification <- &RescanProgress{hash, height, blkTime} + select { + case c.enqueueNotification <- &RescanProgress{hash, height, blkTime}: + case <-c.quit: + } } func (c *Client) onRescanFinished(hash *wire.ShaHash, height int32, blkTime time.Time) { - c.enqueueNotification <- &RescanFinished{hash, height, blkTime} + select { + case c.enqueueNotification <- &RescanFinished{hash, height, blkTime}: + case <-c.quit: + } + } // handler maintains a queue of notifications and the current state (best @@ -283,8 +302,10 @@ func (c *Client) onRescanFinished(hash *wire.ShaHash, height int32, blkTime time func (c *Client) handler() { hash, height, err := c.GetBestBlock() if err != nil { - close(c.quit) + log.Errorf("Failed to receive best block from chain server: %v", err) + c.Stop() c.wg.Done() + return } bs := &waddrmgr.BlockStamp{Hash: *hash, Height: height} @@ -300,6 +321,7 @@ func (c *Client) handler() { enqueue := c.enqueueNotification var dequeue chan interface{} var next interface{} + pingChan := time.After(time.Minute) out: for { select { @@ -319,6 +341,7 @@ out: dequeue = c.dequeueNotification } notifications = append(notifications, n) + pingChan = time.After(time.Minute) case dequeue <- next: if n, ok := next.(BlockConnected); ok { @@ -341,12 +364,45 @@ out: dequeue = nil } + case <-pingChan: + // No notifications were received in the last 60s. + // Ensure the connection is still active by making a new + // request to the server. + // A 3 second timeout is used to prevent the handler loop + // from blocking here forever. + type sessionResult struct { + err error + } + sessionResponse := make(chan sessionResult, 1) + go func() { + _, err := c.Session() + sessionResponse <- sessionResult{err} + }() + + select { + case resp := <-sessionResponse: + if resp.err != nil { + log.Errorf("Failed to receive session "+ + "result: %v", resp.err) + c.Stop() + break out + } + pingChan = time.After(time.Minute) + + case <-time.After(3 * time.Second): + log.Errorf("Timeout waiting for session RPC") + c.Stop() + break out + } + case c.currentBlock <- bs: case <-c.quit: break out } } + + c.Stop() close(c.dequeueNotification) c.wg.Done() } diff --git a/wallet/rescan.go b/wallet/rescan.go index 780cde8..fa73437 100644 --- a/wallet/rescan.go +++ b/wallet/rescan.go @@ -111,6 +111,7 @@ func (b *rescanBatch) done(err error) { // can be handled by a single rescan after the current one completes. func (w *Wallet) rescanBatchHandler() { var curBatch, nextBatch *rescanBatch + quit := w.quitChan() out: for { @@ -162,18 +163,18 @@ out: panic(n) } - case <-w.quit: + case <-quit: break out } } - close(w.rescanBatch) w.wg.Done() } // rescanProgressHandler handles notifications for partially and fully completed // rescans by marking each rescanned address as partially or fully synced. func (w *Wallet) rescanProgressHandler() { + quit := w.quitChan() out: for { // These can't be processed out of order since both chans are @@ -226,7 +227,7 @@ out: } w.notifyConnectedBlock(b) - case <-w.quit: + case <-quit: break out } } @@ -237,21 +238,30 @@ out: // RPC requests to perform a rescan. New jobs are not read until a rescan // finishes. func (w *Wallet) rescanRPCHandler() { - for batch := range w.rescanBatch { - // Log the newly-started rescan. - numAddrs := len(batch.addrs) - noun := pickNoun(numAddrs, "address", "addresses") - log.Infof("Started rescan from block %v (height %d) for %d %s", - batch.bs.Hash, batch.bs.Height, numAddrs, noun) + quit := w.quitChan() - err := w.chainSvr.Rescan(&batch.bs.Hash, batch.addrs, - batch.outpoints) - if err != nil { - log.Errorf("Rescan for %d %s failed: %v", numAddrs, - noun, err) +out: + for { + select { + case batch := <-w.rescanBatch: + // Log the newly-started rescan. + numAddrs := len(batch.addrs) + noun := pickNoun(numAddrs, "address", "addresses") + log.Infof("Started rescan from block %v (height %d) for %d %s", + batch.bs.Hash, batch.bs.Height, numAddrs, noun) + + err := w.chainSvr.Rescan(&batch.bs.Hash, batch.addrs, + batch.outpoints) + if err != nil { + log.Errorf("Rescan for %d %s failed: %v", numAddrs, + noun, err) + } + batch.done(err) + case <-quit: + break out } - batch.done(err) } + w.wg.Done() } diff --git a/wallet/wallet.go b/wallet/wallet.go index 859b799..690cd29 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -106,7 +106,10 @@ type Wallet struct { chainParams *chaincfg.Params wg sync.WaitGroup - quit chan struct{} + + started bool + quit chan struct{} + quitMu sync.Mutex } // ErrDuplicateListen is returned for any attempts to listen for the same @@ -261,15 +264,25 @@ func (w *Wallet) notifyRelevantTx(relevantTx chain.RelevantTx) { // Start starts the goroutines necessary to manage a wallet. func (w *Wallet) Start(chainServer *chain.Client) { + w.quitMu.Lock() select { case <-w.quit: - return + // Restart the wallet goroutines after shutdown finishes. + w.WaitForShutdown() + w.quit = make(chan struct{}) default: + // Ignore when the wallet is still running. + if w.started { + w.quitMu.Unlock() + return + } + w.started = true } + w.quitMu.Unlock() - defer w.chainSvrLock.Unlock() w.chainSvrLock.Lock() w.chainSvr = chainServer + w.chainSvrLock.Unlock() w.wg.Add(6) go w.handleChainNotifications() @@ -280,12 +293,24 @@ func (w *Wallet) Start(chainServer *chain.Client) { go w.rescanRPCHandler() } +// quitChan atomically reads the quit channel. +func (w *Wallet) quitChan() <-chan struct{} { + w.quitMu.Lock() + c := w.quit + w.quitMu.Unlock() + return c +} + // Stop signals all wallet goroutines to shutdown. func (w *Wallet) Stop() { + w.quitMu.Lock() + quit := w.quit + w.quitMu.Unlock() + select { - case <-w.quit: + case <-quit: default: - close(w.quit) + close(quit) w.chainSvrLock.Lock() if w.chainSvr != nil { w.chainSvr.Stop() @@ -298,7 +323,7 @@ func (w *Wallet) Stop() { // shutting down or not. func (w *Wallet) ShuttingDown() bool { select { - case <-w.quit: + case <-w.quitChan(): return true default: return false @@ -445,6 +470,7 @@ type ( // for both requests, rather than just one, to fail due to not enough available // inputs. func (w *Wallet) txCreator() { + quit := w.quitChan() out: for { select { @@ -452,7 +478,7 @@ out: tx, err := w.txToPairs(txr.pairs, txr.account, txr.minconf) txr.resp <- createTxResponse{tx, err} - case <-w.quit: + case <-quit: break out } } @@ -503,6 +529,7 @@ type ( func (w *Wallet) walletLocker() { var timeout <-chan time.Time holdChan := make(HeldUnlock) + quit := w.quitChan() out: for { select { @@ -551,7 +578,7 @@ out: case w.lockState <- w.Manager.IsLocked(): continue - case <-w.quit: + case <-quit: break out case <-w.lockRequests: