diff --git a/account.go b/account.go index 83af455..763194f 100644 --- a/account.go +++ b/account.go @@ -465,7 +465,10 @@ func (a *Account) LockedOutpoints() []btcjson.TransactionInput { locked := make([]btcjson.TransactionInput, len(a.lockedOutpoints)) i := 0 for op := range a.lockedOutpoints { - locked[i] = btcjson.TransactionInput{op.Hash.String(), op.Index} + locked[i] = btcjson.TransactionInput{ + Txid: op.Hash.String(), + Vout: op.Index, + } i++ } return locked diff --git a/acctmgr.go b/acctmgr.go index 8a72add..39d3250 100644 --- a/acctmgr.go +++ b/acctmgr.go @@ -29,6 +29,7 @@ import ( "os" "path/filepath" "strings" + "sync" ) // Errors relating to accounts. @@ -36,6 +37,7 @@ var ( ErrAccountExists = errors.New("account already exists") ErrWalletExists = errors.New("wallet already exists") ErrNotFound = errors.New("not found") + ErrNoAccounts = errors.New("no accounts") ) // AcctMgr is the global account manager for all opened accounts. @@ -70,6 +72,8 @@ type removeAccountCmd struct { a *Account } +type quitCmd struct{} + // AccountManager manages a collection of accounts. type AccountManager struct { // The accounts accessed through the account manager are not safe for @@ -81,6 +85,9 @@ type AccountManager struct { ds *DiskSyncer rm *RescanManager + + wg sync.WaitGroup + quit chan struct{} } // NewAccountManager returns a new AccountManager. @@ -89,6 +96,7 @@ func NewAccountManager() *AccountManager { bsem: make(chan struct{}, 1), cmdChan: make(chan interface{}), rescanMsgs: make(chan RescanMsg, 1), + quit: make(chan struct{}), } am.ds = NewDiskSyncer(am) am.rm = NewRescanManager(am.rescanMsgs) @@ -100,12 +108,29 @@ func (am *AccountManager) Start() { // Ready the semaphore - can't grab unless the manager has started. am.bsem <- struct{}{} + am.wg.Add(4) go am.accountHandler() go am.rescanListener() go am.ds.Start() go am.rm.Start() } +// Stop shuts down the account manager by stoping all signaling all goroutines +// started by Start to close. +func (am *AccountManager) Stop() { + am.rm.Stop() + am.ds.Stop() + close(am.quit) +} + +// WaitForShutdown blocks until all goroutines started by Start and stopped +// with Stop have finished. +func (am *AccountManager) WaitForShutdown() { + am.rm.WaitForShutdown() + am.ds.WaitForShutdown() + am.wg.Wait() +} + // accountData is a helper structure to let us centralise logic for adding // and removing accounts. type accountData struct { @@ -394,49 +419,57 @@ func openAccounts() *accountData { func (am *AccountManager) accountHandler() { ad := openAccounts() - for c := range am.cmdChan { - switch cmd := c.(type) { - case *openAccountsCmd: - // Write all old accounts before proceeding. - for _, a := range ad.nameToAccount { - if err := am.ds.FlushAccount(a); err != nil { - log.Errorf("Cannot write previously "+ - "scheduled account file: %v", err) +out: + for { + select { + case c := <-am.cmdChan: + switch cmd := c.(type) { + case *openAccountsCmd: + // Write all old accounts before proceeding. + for _, a := range ad.nameToAccount { + if err := am.ds.FlushAccount(a); err != nil { + log.Errorf("Cannot write previously "+ + "scheduled account file: %v", err) + } } + + ad = openAccounts() + case *accessAccountRequest: + a, ok := ad.nameToAccount[cmd.name] + if !ok { + a = nil + } + cmd.resp <- a + + case *accessAccountByAddressRequest: + a, ok := ad.addressToAccount[cmd.address] + if !ok { + a = nil + } + cmd.resp <- a + + case *accessAllRequest: + s := make([]*Account, 0, len(ad.nameToAccount)) + for _, a := range ad.nameToAccount { + s = append(s, a) + } + cmd.resp <- s + + case *addAccountCmd: + ad.addAccount(cmd.a) + case *removeAccountCmd: + ad.removeAccount(cmd.a) + + case *markAddressForAccountCmd: + // TODO(oga) make sure we own account + ad.addressToAccount[cmd.address] = cmd.account } - ad = openAccounts() - case *accessAccountRequest: - a, ok := ad.nameToAccount[cmd.name] - if !ok { - a = nil - } - cmd.resp <- a - - case *accessAccountByAddressRequest: - a, ok := ad.addressToAccount[cmd.address] - if !ok { - a = nil - } - cmd.resp <- a - - case *accessAllRequest: - s := make([]*Account, 0, len(ad.nameToAccount)) - for _, a := range ad.nameToAccount { - s = append(s, a) - } - cmd.resp <- s - - case *addAccountCmd: - ad.addAccount(cmd.a) - case *removeAccountCmd: - ad.removeAccount(cmd.a) - - case *markAddressForAccountCmd: - // TODO(oga) make sure we own account - ad.addressToAccount[cmd.address] = cmd.account + case <-am.quit: + break out } } + am.wg.Done() } // rescanListener listens for messages from the rescan manager and marks @@ -540,18 +573,22 @@ func (am *AccountManager) Account(name string) (*Account, error) { // AccountByAddress returns the account specified by address, or // ErrNotFound as an error if the account is not found. -func (am *AccountManager) AccountByAddress(addr btcutil.Address) (*Account, - error) { +func (am *AccountManager) AccountByAddress(addr btcutil.Address) (*Account, error) { respChan := make(chan *Account) - am.cmdChan <- &accessAccountByAddressRequest{ + req := accessAccountByAddressRequest{ address: addr.EncodeAddress(), resp: respChan, } - resp := <-respChan - if resp == nil { - return nil, ErrNotFound + select { + case am.cmdChan <- &req: + resp := <-respChan + if resp == nil { + return nil, ErrNotFound + } + return resp, nil + case <-am.quit: + return nil, ErrNoAccounts } - return resp, nil } // MarkAddressForAccount labels the given account as containing the provided @@ -561,15 +598,18 @@ func (am *AccountManager) MarkAddressForAccount(address btcutil.Address, // TODO(oga) really this entire dance should be carried out implicitly // instead of requiring explicit messaging from the account to the // manager. - am.cmdChan <- &markAddressForAccountCmd{ + req := markAddressForAccountCmd{ address: address.EncodeAddress(), account: account, } + select { + case am.cmdChan <- &req: + case <-am.quit: + } } // Address looks up an address if it is known to wallet at all. -func (am *AccountManager) Address(addr btcutil.Address) (wallet.WalletAddress, - error) { +func (am *AccountManager) Address(addr btcutil.Address) (wallet.WalletAddress, error) { a, err := am.AccountByAddress(addr) if err != nil { return nil, err @@ -581,25 +621,38 @@ func (am *AccountManager) Address(addr btcutil.Address) (wallet.WalletAddress, // AllAccounts returns a slice of all managed accounts. func (am *AccountManager) AllAccounts() []*Account { respChan := make(chan []*Account) - am.cmdChan <- &accessAllRequest{ + req := accessAllRequest{ resp: respChan, } - return <-respChan + select { + case am.cmdChan <- &req: + return <-respChan + case <-am.quit: + return nil + } } // AddAccount adds an account to the collection managed by an AccountManager. func (am *AccountManager) AddAccount(a *Account) { - am.cmdChan <- &addAccountCmd{ + req := addAccountCmd{ a: a, } + select { + case am.cmdChan <- &req: + case <-am.quit: + } } // RemoveAccount removes an account to the collection managed by an // AccountManager. func (am *AccountManager) RemoveAccount(a *Account) { - am.cmdChan <- &removeAccountCmd{ + req := removeAccountCmd{ a: a, } + select { + case am.cmdChan <- &req: + case <-am.quit: + } } // RegisterNewAccount adds a new memory account to the account manager, diff --git a/cmd.go b/cmd.go index 53f796d..2ad885b 100644 --- a/cmd.go +++ b/cmd.go @@ -32,8 +32,9 @@ import ( ) var ( - cfg *config - server *rpcServer + cfg *config + server *rpcServer + shutdownChan = make(chan struct{}) curBlock = struct { sync.RWMutex @@ -102,6 +103,36 @@ func accessClient() (*rpcClient, error) { return c, nil } +func clientConnect(certs []byte, newClient chan<- *rpcClient) { + const initialWait = 5 * time.Second + wait := initialWait + for { + select { + case <-server.quit: + return + default: + } + + client, err := newRPCClient(certs) + if err != nil { + log.Warnf("Unable to open chain server client "+ + "connection: %v", err) + time.Sleep(wait) + wait <<= 1 + if wait > time.Minute { + wait = time.Minute + } + continue + } + + wait = initialWait + client.Start() + newClient <- client + + client.WaitForShutdown() + } +} + func main() { // Work around defer not working after os.Exit. if err := walletMain(); err != nil { @@ -160,30 +191,18 @@ func walletMain() error { // Start HTTP server to serve wallet client connections. server.Start() + // Shutdown the server if an interrupt signal is received. + addInterruptHandler(server.Stop) + // Start client connection to a btcd chain server. Attempt // reconnections if the client could not be successfully connected. clientChan := make(chan *rpcClient) go clientAccess(clientChan) - const initialWait = 5 * time.Second - wait := initialWait - for { - client, err := newRPCClient(certs) - if err != nil { - log.Warnf("Unable to open chain server client "+ - "connection: %v", err) - time.Sleep(wait) - wait <<= 1 - if wait > time.Minute { - wait = time.Minute - } - continue - } + go clientConnect(certs, clientChan) - wait = initialWait - - client.Start() - clientChan <- client - - client.WaitForShutdown() - } + // Wait for the server to shutdown either due to a stop RPC request + // or an interrupt. + server.WaitForShutdown() + log.Info("Shutdown complete") + return nil } diff --git a/disksync.go b/disksync.go index 4ec57cf..1779c96 100644 --- a/disksync.go +++ b/disksync.go @@ -204,6 +204,9 @@ type DiskSyncer struct { // Account manager for this DiskSyncer. This is only // needed to grab the account manager semaphore. am *AccountManager + + quit chan struct{} + shutdown chan struct{} } // NewDiskSyncer creates a new DiskSyncer. @@ -215,6 +218,8 @@ func NewDiskSyncer(am *AccountManager) *DiskSyncer { writeBatch: make(chan *writeBatchRequest), exportAccount: make(chan *exportRequest), am: am, + quit: make(chan struct{}), + shutdown: make(chan struct{}), } } @@ -223,6 +228,14 @@ func (ds *DiskSyncer) Start() { go ds.handler() } +func (ds *DiskSyncer) Stop() { + close(ds.quit) +} + +func (ds *DiskSyncer) WaitForShutdown() { + <-ds.shutdown +} + // handler runs the disk syncer. It manages a set of "dirty" account files // which must be written to disk, and synchronizes all writes in a single // goroutine. Periodic flush operations may be signaled by an AccountManager. @@ -239,6 +252,7 @@ func (ds *DiskSyncer) handler() { var timer <-chan time.Time var sem chan struct{} schedule := newSyncSchedule(netdir) +out: for { select { case <-sem: // Now have exclusive access of the account manager @@ -288,8 +302,16 @@ func (ds *DiskSyncer) handler() { a := er.a dir := er.dir er.err <- a.writeAll(dir) + + case <-ds.quit: + err := schedule.flush() + if err != nil { + log.Errorf("Cannot write accounts: %v", err) + } + break out } } + close(ds.shutdown) } // FlushAccount writes all scheduled account files to disk for a single diff --git a/rescan.go b/rescan.go index a13e3ab..55da21b 100644 --- a/rescan.go +++ b/rescan.go @@ -17,6 +17,8 @@ package main import ( + "sync" + "github.com/conformal/btcutil" "github.com/conformal/btcwire" ) @@ -68,6 +70,8 @@ type RescanManager struct { status chan interface{} // rescanProgress and rescanFinished msgs chan RescanMsg jobCompleteChan chan chan struct{} + wg sync.WaitGroup + quit chan struct{} } // NewRescanManager creates a new RescanManger. If msgChan is non-nil, @@ -80,15 +84,25 @@ func NewRescanManager(msgChan chan RescanMsg) *RescanManager { status: make(chan interface{}, 1), msgs: msgChan, jobCompleteChan: make(chan chan struct{}, 1), + quit: make(chan struct{}), } } // Start starts the goroutines to run the RescanManager. func (m *RescanManager) Start() { + m.wg.Add(2) go m.jobHandler() go m.rpcHandler() } +func (m *RescanManager) Stop() { + close(m.quit) +} + +func (m *RescanManager) WaitForShutdown() { + m.wg.Wait() +} + type rescanBatch struct { addrs map[*Account][]btcutil.Address outpoints map[btcwire.OutPoint]struct{} @@ -146,6 +160,7 @@ func (m *RescanManager) jobHandler() { curBatch := newRescanBatch() nextBatch := newRescanBatch() +out: for { select { case job := <-m.addJob: @@ -205,8 +220,16 @@ func (m *RescanManager) jobHandler() { // Unexpected status message panic(s) } + + case <-m.quit: + break out } } + close(m.sendJob) + if m.msgs != nil { + close(m.msgs) + } + m.wg.Done() } // rpcHandler reads jobs sent by the jobHandler and sends the rpc requests @@ -228,6 +251,7 @@ func (m *RescanManager) rpcHandler() { m.MarkFinished(rescanFinished{err}) } } + m.wg.Done() } // RescanJob is a job to be processed by the RescanManager. The job includes diff --git a/rpcclient.go b/rpcclient.go index d54f7f0..f0d562d 100644 --- a/rpcclient.go +++ b/rpcclient.go @@ -260,6 +260,7 @@ type rpcClient struct { *btcrpcclient.Client // client to btcd enqueueNotification chan notification dequeueNotification chan notification + quit chan struct{} wg sync.WaitGroup } @@ -267,6 +268,7 @@ func newRPCClient(certs []byte) (*rpcClient, error) { client := rpcClient{ enqueueNotification: make(chan notification), dequeueNotification: make(chan notification), + quit: make(chan struct{}), } initializedClient := make(chan struct{}) ntfnCallbacks := btcrpcclient.NotificationHandlers{ @@ -317,7 +319,13 @@ func (c *rpcClient) Stop() { log.Warn("Disconnecting chain server client connection") c.Client.Shutdown() } - close(c.enqueueNotification) + + select { + case <-c.quit: + default: + close(c.quit) + close(c.enqueueNotification) + } } func (c *rpcClient) WaitForShutdown() { diff --git a/rpcserver.go b/rpcserver.go index 8d95b57..c140a72 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -213,11 +213,13 @@ type rpcServer struct { upgrader websocket.Upgrader - requests requestChan + requests chan handlerJob addWSClient chan *websocketClient removeWSClient chan *websocketClient broadcasts chan []byte + + quit chan struct{} } // newRPCServer creates a new server for serving RPC client connections, both @@ -232,10 +234,11 @@ func newRPCServer(listenAddrs []string) (*rpcServer, error) { // Allow all origins. CheckOrigin: func(r *http.Request) bool { return true }, }, - requests: make(requestChan), + requests: make(chan handlerJob), addWSClient: make(chan *websocketClient), removeWSClient: make(chan *websocketClient), broadcasts: make(chan []byte), + quit: make(chan struct{}), } // Check for existence of cert file and key file @@ -292,8 +295,9 @@ func (s *rpcServer) Start() { // A duplicator for notifications intended for all clients runs // in another goroutines. Any such notifications are sent to // the allClients channel and then sent to each connected client. + s.wg.Add(2) go s.NotificationHandler() - go s.requests.handler() + go s.RequestHandler() log.Trace("Starting RPC server") @@ -349,16 +353,56 @@ func (s *rpcServer) Start() { s.wg.Add(1) go func(listener net.Listener) { log.Infof("RPCS: RPC server listening on %s", listener.Addr()) - if err := httpServer.Serve(listener); err != nil { - log.Errorf("Listener for %s exited with error: %v", - listener.Addr(), err) - } + _ = httpServer.Serve(listener) log.Tracef("RPCS: RPC listener done for %s", listener.Addr()) s.wg.Done() }(listener) } } +// Stop gracefully shuts down the rpc server by stopping and disconnecting all +// clients, disconnecting the chain server connection, and closing the wallet's +// account files. +func (s *rpcServer) Stop() { + // If the server is changed to run more than one rpc handler at a time, + // to prevent a double channel close, this should be replaced with an + // atomic test-and-set. + select { + case <-s.quit: + log.Warnf("Server already shutting down") + return + default: + } + + log.Warn("Server shutting down") + + // Stop all the listeners. There will not be any listeners if + // listening is disabled. + for _, listener := range s.listeners { + err := listener.Close() + if err != nil { + log.Errorf("Cannot close listener %s: %v", + listener.Addr(), err) + } + } + + // Disconnect the connected chain server, if any. + client, err := accessClient() + if err == nil { + client.Stop() + } + + // Stop the account manager and finish all pending account file writes. + AcctMgr.Stop() + + // Signal the remaining goroutines to stop. + close(s.quit) +} + +func (s *rpcServer) WaitForShutdown() { + s.wg.Wait() +} + // ErrNoAuth represents an error where authentication could not succeed // due to a missing Authorization HTTP header. var ErrNoAuth = errors.New("no auth") @@ -789,6 +833,7 @@ func (s *rpcServer) NotifyConnectionStatus(wsc *websocketClient) { } func (s *rpcServer) NotificationHandler() { +out: for { select { case c := <-s.addWSClient: @@ -801,8 +846,11 @@ func (s *rpcServer) NotificationHandler() { delete(s.wsClients, wsc) } } + case <-s.quit: + break out } } + s.wg.Done() } // requestHandler is a handler function to handle an unmarshaled and parsed @@ -841,6 +889,7 @@ var rpcHandlers = map[string]requestHandler{ "settxfee": SetTxFee, "signmessage": SignMessage, "signrawtransaction": SignRawTransaction, + "stop": Stop, "validateaddress": ValidateAddress, "verifymessage": VerifyMessage, "walletlock": WalletLock, @@ -859,7 +908,6 @@ var rpcHandlers = map[string]requestHandler{ "listreceivedbyaccount": Unimplemented, "move": Unimplemented, "setaccount": Unimplemented, - "stop": Unimplemented, // Standard bitcoind methods which won't be implemented by btcwallet. "encryptwallet": Unsupported, @@ -901,36 +949,42 @@ type handlerJob struct { response chan<- handlerResponse } -type requestChan chan handlerJob +// RequestHandler reads and processes client requests from the request channel. +// Each request is run with exclusive access to the account manager. +func (s *rpcServer) RequestHandler() { +out: + for { + select { + case r := <-s.requests: + AcctMgr.Grab() + result, err := r.handler(r.request) + AcctMgr.Release() -// handler reads and processes client requests from the channel. Each -// request is run with exclusive access to the account manager. -func (c requestChan) handler() { - for r := range c { - AcctMgr.Grab() - result, err := r.handler(r.request) - AcctMgr.Release() - - var jsonErr *btcjson.Error - if err != nil { - jsonErr = &btcjson.Error{Message: err.Error()} - switch e := err.(type) { - case btcjson.Error: - *jsonErr = e - case DeserializationError: - jsonErr.Code = btcjson.ErrDeserialization.Code - case InvalidParameterError: - jsonErr.Code = btcjson.ErrInvalidParameter.Code - case ParseError: - jsonErr.Code = btcjson.ErrParse.Code - case InvalidAddressOrKeyError: - jsonErr.Code = btcjson.ErrInvalidAddressOrKey.Code - default: // All other errors get the wallet error code. - jsonErr.Code = btcjson.ErrWallet.Code + var jsonErr *btcjson.Error + if err != nil { + jsonErr = &btcjson.Error{Message: err.Error()} + switch e := err.(type) { + case btcjson.Error: + *jsonErr = e + case DeserializationError: + jsonErr.Code = btcjson.ErrDeserialization.Code + case InvalidParameterError: + jsonErr.Code = btcjson.ErrInvalidParameter.Code + case ParseError: + jsonErr.Code = btcjson.ErrParse.Code + case InvalidAddressOrKeyError: + jsonErr.Code = btcjson.ErrInvalidAddressOrKey.Code + default: // All other errors get the wallet error code. + jsonErr.Code = btcjson.ErrWallet.Code + } } + r.response <- handlerResponse{result, jsonErr} + + case <-s.quit: + break out } - r.response <- handlerResponse{result, jsonErr} } + s.wg.Done() } // Unimplemented handles an unimplemented RPC request with the @@ -2492,6 +2546,13 @@ func SignRawTransaction(icmd btcjson.Cmd) (interface{}, error) { }, nil } +// Stop handles the stop command by shutting down the process after the request +// is handled. +func Stop(icmd btcjson.Cmd) (interface{}, error) { + server.Stop() + return "btcwallet stopping.", nil +} + // ValidateAddress handles the validateaddress command. func ValidateAddress(icmd btcjson.Cmd) (interface{}, error) { cmd, ok := icmd.(*btcjson.ValidateAddressCmd) diff --git a/signal.go b/signal.go new file mode 100644 index 0000000..4ed8183 --- /dev/null +++ b/signal.go @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2013, 2014 Conformal Systems LLC + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +package main + +import ( + "os" + "os/signal" +) + +// interruptChannel is used to receive SIGINT (Ctrl+C) signals. +var interruptChannel chan os.Signal + +// addHandlerChannel is used to add an interrupt handler to the list of handlers +// to be invoked on SIGINT (Ctrl+C) signals. +var addHandlerChannel = make(chan func()) + +// mainInterruptHandler listens for SIGINT (Ctrl+C) signals on the +// interruptChannel and invokes the registered interruptCallbacks accordingly. +// It also listens for callback registration. It must be run as a goroutine. +func mainInterruptHandler() { + // interruptCallbacks is a list of callbacks to invoke when a + // SIGINT (Ctrl+C) is received. + var interruptCallbacks []func() + + for { + select { + case <-interruptChannel: + log.Info("Received SIGINT (Ctrl+C). Shutting down...") + // run handlers in LIFO order. + for i := range interruptCallbacks { + idx := len(interruptCallbacks) - 1 - i + interruptCallbacks[idx]() + } + + case handler := <-addHandlerChannel: + interruptCallbacks = append(interruptCallbacks, handler) + } + } +} + +// addInterruptHandler adds a handler to call when a SIGINT (Ctrl+C) is +// received. +func addInterruptHandler(handler func()) { + // Create the channel and start the main interrupt handler which invokes + // all other callbacks and exits if not already done. + if interruptChannel == nil { + interruptChannel = make(chan os.Signal, 1) + signal.Notify(interruptChannel, os.Interrupt) + go mainInterruptHandler() + } + + addHandlerChannel <- handler +}