Stop rescans when a websocket client disconnects.

Fixes #66.
This commit is contained in:
Josh Rickmar 2014-01-14 15:59:31 -05:00
parent 6abad1d8ac
commit 051a9013ce

View file

@ -22,7 +22,12 @@ import (
type ntfnChan chan btcjson.Cmd type ntfnChan chan btcjson.Cmd
type wsCommandHandler func(*rpcServer, btcjson.Cmd, ntfnChan) (interface{}, *btcjson.Error) type handlerChans struct {
n ntfnChan // channel to send notifications
disconnected <-chan struct{} // closed when a client has disconnected.
}
type wsCommandHandler func(*rpcServer, btcjson.Cmd, handlerChans) (interface{}, *btcjson.Error)
// wsHandlers maps RPC command strings to appropriate websocket handler // wsHandlers maps RPC command strings to appropriate websocket handler
// functions. // functions.
@ -226,7 +231,7 @@ type requestContexts struct {
// extension JSON-RPC command and runs the proper handler to reply to // extension JSON-RPC command and runs the proper handler to reply to
// the command. Any and all responses are sent to the wallet from // the command. Any and all responses are sent to the wallet from
// this function. // this function.
func respondToAnyCmd(cmd btcjson.Cmd, s *rpcServer, n ntfnChan) *btcjson.Reply { func respondToAnyCmd(cmd btcjson.Cmd, s *rpcServer, c handlerChans) *btcjson.Reply {
// Lookup the websocket extension for the command and if it doesn't // Lookup the websocket extension for the command and if it doesn't
// exist fallback to handling the command as a standard command. // exist fallback to handling the command as a standard command.
wsHandler, ok := wsHandlers[cmd.Method()] wsHandler, ok := wsHandlers[cmd.Method()]
@ -236,7 +241,7 @@ func respondToAnyCmd(cmd btcjson.Cmd, s *rpcServer, n ntfnChan) *btcjson.Reply {
response := standardCmdReply(cmd, s) response := standardCmdReply(cmd, s)
return &response return &response
} }
result, jsonErr := wsHandler(s, cmd, n) result, jsonErr := wsHandler(s, cmd, c)
id := cmd.Id() id := cmd.Id()
response := btcjson.Reply{ response := btcjson.Reply{
Id: &id, Id: &id,
@ -248,7 +253,7 @@ func respondToAnyCmd(cmd btcjson.Cmd, s *rpcServer, n ntfnChan) *btcjson.Reply {
// handleGetCurrentNet implements the getcurrentnet command extension // handleGetCurrentNet implements the getcurrentnet command extension
// for websocket connections. // for websocket connections.
func handleGetCurrentNet(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interface{}, *btcjson.Error) { func handleGetCurrentNet(s *rpcServer, icmd btcjson.Cmd, c handlerChans) (interface{}, *btcjson.Error) {
if cfg.TestNet3 { if cfg.TestNet3 {
return btcwire.TestNet3, nil return btcwire.TestNet3, nil
} }
@ -257,7 +262,7 @@ func handleGetCurrentNet(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interf
// handleGetBestBlock implements the getbestblock command extension // handleGetBestBlock implements the getbestblock command extension
// for websocket connections. // for websocket connections.
func handleGetBestBlock(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interface{}, *btcjson.Error) { func handleGetBestBlock(s *rpcServer, icmd btcjson.Cmd, c handlerChans) (interface{}, *btcjson.Error) {
// All other "get block" commands give either the height, the // All other "get block" commands give either the height, the
// hash, or both but require the block SHA. This gets both for // hash, or both but require the block SHA. This gets both for
// the best block. // the best block.
@ -276,7 +281,7 @@ func handleGetBestBlock(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interfa
// handleNotifyNewTXs implements the notifynewtxs command extension for // handleNotifyNewTXs implements the notifynewtxs command extension for
// websocket connections. // websocket connections.
func handleNotifyNewTXs(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interface{}, *btcjson.Error) { func handleNotifyNewTXs(s *rpcServer, icmd btcjson.Cmd, c handlerChans) (interface{}, *btcjson.Error) {
cmd, ok := icmd.(*btcws.NotifyNewTXsCmd) cmd, ok := icmd.(*btcws.NotifyNewTXsCmd)
if !ok { if !ok {
return nil, &btcjson.ErrInternal return nil, &btcjson.ErrInternal
@ -302,7 +307,7 @@ func handleNotifyNewTXs(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interfa
return nil, &e return nil, &e
} }
s.ws.AddTxRequest(ntfns, addr.EncodeAddress()) s.ws.AddTxRequest(c.n, addr.EncodeAddress())
} }
return nil, nil return nil, nil
@ -310,20 +315,20 @@ func handleNotifyNewTXs(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interfa
// handleNotifySpent implements the notifyspent command extension for // handleNotifySpent implements the notifyspent command extension for
// websocket connections. // websocket connections.
func handleNotifySpent(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interface{}, *btcjson.Error) { func handleNotifySpent(s *rpcServer, icmd btcjson.Cmd, c handlerChans) (interface{}, *btcjson.Error) {
cmd, ok := icmd.(*btcws.NotifySpentCmd) cmd, ok := icmd.(*btcws.NotifySpentCmd)
if !ok { if !ok {
return nil, &btcjson.ErrInternal return nil, &btcjson.ErrInternal
} }
s.ws.AddSpentRequest(ntfns, cmd.OutPoint) s.ws.AddSpentRequest(c.n, cmd.OutPoint)
return nil, nil return nil, nil
} }
// handleRescan implements the rescan command extension for websocket // handleRescan implements the rescan command extension for websocket
// connections. // connections.
func handleRescan(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interface{}, *btcjson.Error) { func handleRescan(s *rpcServer, icmd btcjson.Cmd, c handlerChans) (interface{}, *btcjson.Error) {
cmd, ok := icmd.(*btcws.RescanCmd) cmd, ok := icmd.(*btcws.RescanCmd)
if !ok { if !ok {
return nil, &btcjson.ErrInternal return nil, &btcjson.ErrInternal
@ -357,54 +362,17 @@ func handleRescan(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interface{},
rpcsLog.Errorf("Error looking up block sha: %v", err) rpcsLog.Errorf("Error looking up block sha: %v", err)
return nil, &btcjson.ErrDatabase return nil, &btcjson.ErrDatabase
} }
for _, tx := range blk.Transactions() {
var txReply *btcdb.TxListReply
txouts:
for txOutIdx, txout := range tx.MsgTx().TxOut {
_, addrs, _, err := btcscript.ExtractPkScriptAddrs(
txout.PkScript, s.server.btcnet)
if err != nil {
continue txouts
}
for i, addr := range addrs { // A select statement is used to stop rescans if the
encodedAddr := addr.EncodeAddress() // client requesting the rescan has disconnected.
if _, ok := cmd.Addresses[encodedAddr]; ok { select {
// TODO(jrick): This lookup is expensive and can be avoided case <-c.disconnected:
// if the wallet is sent the previous outpoints for all inputs rpcsLog.Infof("Stopping rescan at height %v for disconnected client",
// of the tx, so any can removed from the utxo set (since blk.Height())
// they are, as of this tx, now spent). return nil, nil
if txReply == nil {
txReplyList, err := s.server.db.FetchTxBySha(tx.Sha())
if err != nil {
rpcsLog.Errorf("Tx Sha %v not found by db.", tx.Sha())
continue txouts
}
for i := range txReplyList {
if txReplyList[i].Height == blk.Height() {
txReply = txReplyList[i]
break
}
}
} default:
rescanBlock(s, cmd, c, blk)
n := &btcws.ProcessedTxNtfn{
Receiver: encodedAddr,
Amount: txout.Value,
TxID: tx.Sha().String(),
TxOutIndex: uint32(txOutIdx),
PkScript: hex.EncodeToString(txout.PkScript),
BlockHash: blkshalist[i].String(),
BlockHeight: int32(blk.Height()),
BlockIndex: tx.Index(),
BlockTime: blk.MsgBlock().Header.Timestamp.Unix(),
Spent: txReply.TxSpent[txOutIdx],
}
ntfns <- n
}
}
}
} }
} }
@ -419,9 +387,74 @@ func handleRescan(s *rpcServer, icmd btcjson.Cmd, ntfns ntfnChan) (interface{},
return nil, nil return nil, nil
} }
// rescanBlock rescans all transactions in a single block. This is a
// helper function for handleRescan.
func rescanBlock(s *rpcServer, cmd *btcws.RescanCmd, c handlerChans, blk *btcutil.Block) {
for _, tx := range blk.Transactions() {
var txReply *btcdb.TxListReply
txouts:
for txOutIdx, txout := range tx.MsgTx().TxOut {
_, addrs, _, err := btcscript.ExtractPkScriptAddrs(
txout.PkScript, s.server.btcnet)
if err != nil {
continue txouts
}
for _, addr := range addrs {
encodedAddr := addr.EncodeAddress()
if _, ok := cmd.Addresses[encodedAddr]; !ok {
continue
}
// TODO(jrick): This lookup is expensive and can be avoided
// if the wallet is sent the previous outpoints for all inputs
// of the tx, so any can removed from the utxo set (since
// they are, as of this tx, now spent).
if txReply == nil {
txReplyList, err := s.server.db.FetchTxBySha(tx.Sha())
if err != nil {
rpcsLog.Errorf("Tx Sha %v not found by db.", tx.Sha())
continue txouts
}
for i := range txReplyList {
if txReplyList[i].Height == blk.Height() {
txReply = txReplyList[i]
break
}
}
}
// Sha never errors.
blksha, _ := blk.Sha()
ntfn := &btcws.ProcessedTxNtfn{
Receiver: encodedAddr,
Amount: txout.Value,
TxID: tx.Sha().String(),
TxOutIndex: uint32(txOutIdx),
PkScript: hex.EncodeToString(txout.PkScript),
BlockHash: blksha.String(),
BlockHeight: int32(blk.Height()),
BlockIndex: tx.Index(),
BlockTime: blk.MsgBlock().Header.Timestamp.Unix(),
Spent: txReply.TxSpent[txOutIdx],
}
select {
case <-c.disconnected:
return
default:
c.n <- ntfn
}
}
}
}
}
// handleWalletSendRawTransaction implements the websocket extended version of // handleWalletSendRawTransaction implements the websocket extended version of
// the sendrawtransaction command. // the sendrawtransaction command.
func handleWalletSendRawTransaction(s *rpcServer, icmd btcjson.Cmd, n ntfnChan) (interface{}, *btcjson.Error) { func handleWalletSendRawTransaction(s *rpcServer, icmd btcjson.Cmd, c handlerChans) (interface{}, *btcjson.Error) {
result, err := handleSendRawTransaction(s, icmd) result, err := handleSendRawTransaction(s, icmd)
// TODO: the standard handlers really should be changed to // TODO: the standard handlers really should be changed to
// return btcjson.Errors which get used directly in the // return btcjson.Errors which get used directly in the
@ -443,7 +476,7 @@ func handleWalletSendRawTransaction(s *rpcServer, icmd btcjson.Cmd, n ntfnChan)
txSha, _ := btcwire.NewShaHashFromStr(result.(string)) txSha, _ := btcwire.NewShaHashFromStr(result.(string))
// Request to be notified when the transaction is mined. // Request to be notified when the transaction is mined.
s.ws.AddMinedTxRequest(n, txSha) s.ws.AddMinedTxRequest(c.n, txSha)
return result, nil return result, nil
} }
@ -513,6 +546,13 @@ func (s *rpcServer) walletReqsNotifications(ws *websocket.Conn) {
// Channel for responses. // Channel for responses.
r := make(chan *btcjson.Reply) r := make(chan *btcjson.Reply)
// Channels for websocket handlers.
disconnected := make(chan struct{})
hc := handlerChans{
n: n,
disconnected: disconnected,
}
// msgs is a channel for all messages received over the websocket. // msgs is a channel for all messages received over the websocket.
msgs := make(chan []byte) msgs := make(chan []byte)
@ -522,13 +562,12 @@ func (s *rpcServer) walletReqsNotifications(ws *websocket.Conn) {
for { for {
select { select {
case <-s.quit: case <-s.quit:
close(msgs)
return return
default: default:
var m []byte var m []byte
if err := websocket.Message.Receive(ws, &m); err != nil { if err := websocket.Message.Receive(ws, &m); err != nil {
close(msgs) close(disconnected)
return return
} }
msgs <- m msgs <- m
@ -538,13 +577,27 @@ func (s *rpcServer) walletReqsNotifications(ws *websocket.Conn) {
for { for {
select { select {
case m, ok := <-msgs: case <-s.quit:
if !ok { // Server closed. Closing disconnected signals handlers to stop
// Wallet disconnected. // and flushes all channels handlers may write to.
return close(disconnected)
case <-disconnected:
for {
select {
case <-msgs:
case <-r:
case <-n:
default:
return
}
} }
// Handle request here.
go s.websocketJSONHandler(r, n, m) case m := <-msgs:
// Spawn new goroutine to handle request. Responses and
// notifications are read by channels in this for-select
// loop.
go s.websocketJSONHandler(r, hc, m)
case response := <-r: case response := <-r:
// Marshal and send response. // Marshal and send response.
@ -555,6 +608,7 @@ func (s *rpcServer) walletReqsNotifications(ws *websocket.Conn) {
} }
if err := websocket.Message.Send(ws, mresp); err != nil { if err := websocket.Message.Send(ws, mresp); err != nil {
// Wallet disconnected. // Wallet disconnected.
close(disconnected)
return return
} }
@ -567,25 +621,24 @@ func (s *rpcServer) walletReqsNotifications(ws *websocket.Conn) {
} }
if err := websocket.Message.Send(ws, mntfn); err != nil { if err := websocket.Message.Send(ws, mntfn); err != nil {
// Wallet disconnected. // Wallet disconnected.
close(disconnected)
return return
} }
case <-s.quit:
// Server closed.
return
} }
} }
} }
// websocketJSONHandler parses and handles a marshalled json message, // websocketJSONHandler parses and handles a marshalled json message,
// sending the marshalled reply to a wallet notification channel. // sending the marshalled reply to a wallet notification channel.
func (s *rpcServer) websocketJSONHandler(r chan *btcjson.Reply, n ntfnChan, msg []byte) { func (s *rpcServer) websocketJSONHandler(r chan *btcjson.Reply, c handlerChans, msg []byte) {
s.wg.Add(1) s.wg.Add(1)
defer s.wg.Done() defer s.wg.Done()
var resp *btcjson.Reply
cmd, jsonErr := parseCmd(msg) cmd, jsonErr := parseCmd(msg)
if jsonErr != nil { if jsonErr != nil {
var resp btcjson.Reply resp = &btcjson.Reply{}
if cmd != nil { if cmd != nil {
// Unmarshaling at least a valid JSON-RPC message succeeded. // Unmarshaling at least a valid JSON-RPC message succeeded.
// Use the provided id for errors. Requests with no IDs // Use the provided id for errors. Requests with no IDs
@ -597,12 +650,19 @@ func (s *rpcServer) websocketJSONHandler(r chan *btcjson.Reply, n ntfnChan, msg
resp.Id = &id resp.Id = &id
} }
resp.Error = jsonErr resp.Error = jsonErr
r <- &resp } else {
return resp = respondToAnyCmd(cmd, s, c)
} }
resp := respondToAnyCmd(cmd, s, n) // Once response has been processed, only send if the client
r <- resp // is still connected.
select {
case <-c.disconnected:
return
default:
r <- resp
}
} }
// NotifyBlockConnected creates and marshalls a JSON message to notify // NotifyBlockConnected creates and marshalls a JSON message to notify