Fix skip not being applied to mempool txns in searchrawtransactions

This commit is contained in:
Dario Nieuwenhuis 2015-08-28 13:58:18 +02:00
parent 0f9fc42a06
commit 1806557d14
5 changed files with 76 additions and 24 deletions

View file

@ -134,7 +134,9 @@ type Db interface {
// Additionally, if the caller wishes to skip forward in the results // Additionally, if the caller wishes to skip forward in the results
// some amount, the 'seek' represents how many results to skip. // some amount, the 'seek' represents how many results to skip.
// NOTE: Values for both `seek` and `limit` MUST be positive. // NOTE: Values for both `seek` and `limit` MUST be positive.
FetchTxsForAddr(addr btcutil.Address, skip int, limit int) ([]*TxListReply, error) // It will return the array of fetched transactions, along with the amount
// of transactions that were actually skipped.
FetchTxsForAddr(addr btcutil.Address, skip int, limit int) ([]*TxListReply, int, error)
// DeleteAddrIndex deletes the entire addrindex stored within the DB. // DeleteAddrIndex deletes the entire addrindex stored within the DB.
DeleteAddrIndex() error DeleteAddrIndex() error

View file

@ -91,12 +91,12 @@ func testAddrIndexOperations(t *testing.T, db database.Db, newestBlock *btcutil.
// Test enforcement of constraints for "limit" and "skip" // Test enforcement of constraints for "limit" and "skip"
var fakeAddr btcutil.Address var fakeAddr btcutil.Address
_, err = db.FetchTxsForAddr(fakeAddr, -1, 0) _, _, err = db.FetchTxsForAddr(fakeAddr, -1, 0)
if err == nil { if err == nil {
t.Fatalf("Negative value for skip passed, should return an error") t.Fatalf("Negative value for skip passed, should return an error")
} }
_, err = db.FetchTxsForAddr(fakeAddr, 0, -1) _, _, err = db.FetchTxsForAddr(fakeAddr, 0, -1)
if err == nil { if err == nil {
t.Fatalf("Negative value for limit passed, should return an error") t.Fatalf("Negative value for limit passed, should return an error")
} }
@ -136,7 +136,7 @@ func testAddrIndexOperations(t *testing.T, db database.Db, newestBlock *btcutil.
assertAddrIndexTipIsUpdated(db, t, newestSha, newestBlockIdx) assertAddrIndexTipIsUpdated(db, t, newestSha, newestBlockIdx)
// Check index retrieval. // Check index retrieval.
txReplies, err := db.FetchTxsForAddr(testAddrs[0], 0, 1000) txReplies, _, err := db.FetchTxsForAddr(testAddrs[0], 0, 1000)
if err != nil { if err != nil {
t.Fatalf("FetchTxsForAddr failed to correctly fetch txs for an "+ t.Fatalf("FetchTxsForAddr failed to correctly fetch txs for an "+
"address, err %v", err) "address, err %v", err)
@ -171,7 +171,7 @@ func testAddrIndexOperations(t *testing.T, db database.Db, newestBlock *btcutil.
} }
// Former index should no longer exist. // Former index should no longer exist.
txReplies, err = db.FetchTxsForAddr(testAddrs[0], 0, 1000) txReplies, _, err = db.FetchTxsForAddr(testAddrs[0], 0, 1000)
if err != nil { if err != nil {
t.Fatalf("Unable to fetch transactions for address: %v", err) t.Fatalf("Unable to fetch transactions for address: %v", err)
} }
@ -555,30 +555,42 @@ func TestLimitAndSkipFetchTxsForAddr(t *testing.T) {
} }
// Try skipping the first 4 results, should get 6 in return. // Try skipping the first 4 results, should get 6 in return.
txReply, err := testDb.db.FetchTxsForAddr(targetAddr, 4, 100000) txReply, txSkipped, err := testDb.db.FetchTxsForAddr(targetAddr, 4, 100000)
if err != nil { if err != nil {
t.Fatalf("Unable to fetch transactions for address: %v", err) t.Fatalf("Unable to fetch transactions for address: %v", err)
} }
if txSkipped != 4 {
t.Fatalf("Did not correctly return skipped amount"+
" got %v txs, expected %v", txSkipped, 4)
}
if len(txReply) != 6 { if len(txReply) != 6 {
t.Fatalf("Did not correctly skip forward in txs for address reply"+ t.Fatalf("Did not correctly skip forward in txs for address reply"+
" got %v txs, expected %v", len(txReply), 6) " got %v txs, expected %v", len(txReply), 6)
} }
// Limit the number of results to 3. // Limit the number of results to 3.
txReply, err = testDb.db.FetchTxsForAddr(targetAddr, 0, 3) txReply, txSkipped, err = testDb.db.FetchTxsForAddr(targetAddr, 0, 3)
if err != nil { if err != nil {
t.Fatalf("Unable to fetch transactions for address: %v", err) t.Fatalf("Unable to fetch transactions for address: %v", err)
} }
if txSkipped != 0 {
t.Fatalf("Did not correctly return skipped amount"+
" got %v txs, expected %v", txSkipped, 0)
}
if len(txReply) != 3 { if len(txReply) != 3 {
t.Fatalf("Did not correctly limit in txs for address reply"+ t.Fatalf("Did not correctly limit in txs for address reply"+
" got %v txs, expected %v", len(txReply), 3) " got %v txs, expected %v", len(txReply), 3)
} }
// Skip 1, limit 5. // Skip 1, limit 5.
txReply, err = testDb.db.FetchTxsForAddr(targetAddr, 1, 5) txReply, txSkipped, err = testDb.db.FetchTxsForAddr(targetAddr, 1, 5)
if err != nil { if err != nil {
t.Fatalf("Unable to fetch transactions for address: %v", err) t.Fatalf("Unable to fetch transactions for address: %v", err)
} }
if txSkipped != 1 {
t.Fatalf("Did not correctly return skipped amount"+
" got %v txs, expected %v", txSkipped, 1)
}
if len(txReply) != 5 { if len(txReply) != 5 {
t.Fatalf("Did not correctly limit in txs for address reply"+ t.Fatalf("Did not correctly limit in txs for address reply"+
" got %v txs, expected %v", len(txReply), 5) " got %v txs, expected %v", len(txReply), 5)

View file

@ -430,16 +430,16 @@ func bytesPrefix(prefix []byte) *util.Range {
// caller wishes to seek forward in the results some amount, the 'seek' // caller wishes to seek forward in the results some amount, the 'seek'
// represents how many results to skip. // represents how many results to skip.
func (db *LevelDb) FetchTxsForAddr(addr btcutil.Address, skip int, func (db *LevelDb) FetchTxsForAddr(addr btcutil.Address, skip int,
limit int) ([]*database.TxListReply, error) { limit int) ([]*database.TxListReply, int, error) {
db.dbLock.Lock() db.dbLock.Lock()
defer db.dbLock.Unlock() defer db.dbLock.Unlock()
// Enforce constraints for skip and limit. // Enforce constraints for skip and limit.
if skip < 0 { if skip < 0 {
return nil, errors.New("offset for skip must be positive") return nil, 0, errors.New("offset for skip must be positive")
} }
if limit < 0 { if limit < 0 {
return nil, errors.New("value for limit must be positive") return nil, 0, errors.New("value for limit must be positive")
} }
// Parse address type, bailing on an unknown type. // Parse address type, bailing on an unknown type.
@ -455,7 +455,7 @@ func (db *LevelDb) FetchTxsForAddr(addr btcutil.Address, skip int,
hash160 := addr.AddressPubKeyHash().Hash160() hash160 := addr.AddressPubKeyHash().Hash160()
addrKey = hash160[:] addrKey = hash160[:]
default: default:
return nil, database.ErrUnsupportedAddressType return nil, 0, database.ErrUnsupportedAddressType
} }
// Create the prefix for our search. // Create the prefix for our search.
@ -464,8 +464,10 @@ func (db *LevelDb) FetchTxsForAddr(addr btcutil.Address, skip int,
copy(addrPrefix[3:23], addrKey) copy(addrPrefix[3:23], addrKey)
iter := db.lDb.NewIterator(bytesPrefix(addrPrefix), nil) iter := db.lDb.NewIterator(bytesPrefix(addrPrefix), nil)
skipped := 0
for skip != 0 && iter.Next() { for skip != 0 && iter.Next() {
skip-- skip--
skipped++
} }
// Iterate through all address indexes that match the targeted prefix. // Iterate through all address indexes that match the targeted prefix.
@ -491,10 +493,10 @@ func (db *LevelDb) FetchTxsForAddr(addr btcutil.Address, skip int,
} }
iter.Release() iter.Release()
if err := iter.Error(); err != nil { if err := iter.Error(); err != nil {
return nil, err return nil, 0, err
} }
return replies, nil return replies, skipped, nil
} }
// UpdateAddrIndexForBlock updates the stored addrindex with passed // UpdateAddrIndexForBlock updates the stored addrindex with passed

View file

@ -690,8 +690,8 @@ func (db *MemDb) UpdateAddrIndexForBlock(*wire.ShaHash, int32,
// FetchTxsForAddr isn't currently implemented. This is a part of the database.Db // FetchTxsForAddr isn't currently implemented. This is a part of the database.Db
// interface implementation. // interface implementation.
func (db *MemDb) FetchTxsForAddr(btcutil.Address, int, int) ([]*database.TxListReply, error) { func (db *MemDb) FetchTxsForAddr(btcutil.Address, int, int) ([]*database.TxListReply, int, error) {
return nil, database.ErrNotImplemented return nil, 0, database.ErrNotImplemented
} }
// DeleteAddrIndex isn't currently implemented. This is a part of the database.Db // DeleteAddrIndex isn't currently implemented. This is a part of the database.Db

View file

@ -2939,6 +2939,43 @@ func handlePing(s *rpcServer, cmd interface{}, closeChan <-chan struct{}) (inter
return nil, nil return nil, nil
} }
// getMempoolTxsForAddressRange looks up and returns all transactions from the
// mempool related to the given address. The, `limit` parameter
// should be the max number of transactions to be returned. Additionally, if the
// caller wishes to seek forward in the results some amount, the 'seek'
// represents how many results to skip.
// It will return the array of fetched transactions, along with the amount
// of transactions that were actually skipped.
func getMempoolTxsForAddressRange(s *rpcServer, addr btcutil.Address, skip int,
limit int) ([]*database.TxListReply, int, error) {
memPoolTxs, err := s.server.txMemPool.FilterTransactionsByAddress(addr)
if err != nil {
return nil, 0, err
}
// If we're asked to skip more transactions than we have,
// we skip them all and return an empty slice.
if skip >= len(memPoolTxs) {
return nil, len(memPoolTxs), nil
}
var result []*database.TxListReply
// Otherwise, calculate the range we have to return and return it.
rangeEnd := skip + limit
if rangeEnd > len(memPoolTxs) {
rangeEnd = len(memPoolTxs)
}
for _, tx := range memPoolTxs[skip:rangeEnd] {
txReply := &database.TxListReply{Tx: tx.MsgTx(), Sha: tx.Sha()}
result = append(result, txReply)
}
return result, skip, nil
}
// handleSearchRawTransaction implements the searchrawtransactions command. // handleSearchRawTransaction implements the searchrawtransactions command.
func handleSearchRawTransactions(s *rpcServer, cmd interface{}, closeChan <-chan struct{}) (interface{}, error) { func handleSearchRawTransactions(s *rpcServer, cmd interface{}, closeChan <-chan struct{}) (interface{}, error) {
if !cfg.AddrIndex { if !cfg.AddrIndex {
@ -2968,7 +3005,7 @@ func handleSearchRawTransactions(s *rpcServer, cmd interface{}, closeChan <-chan
var addressTxs []*database.TxListReply var addressTxs []*database.TxListReply
var numRequested, numToSkip int var numRequested, numToSkip, skipped int
if c.Count != nil { if c.Count != nil {
numRequested = *c.Count numRequested = *c.Count
if numRequested < 0 { if numRequested < 0 {
@ -2986,9 +3023,10 @@ func handleSearchRawTransactions(s *rpcServer, cmd interface{}, closeChan <-chan
// first, we want to return results in order of occurrence/dependency so // first, we want to return results in order of occurrence/dependency so
// we'll check the mempool only if there aren't enough results returned // we'll check the mempool only if there aren't enough results returned
// by the database. // by the database.
dbTxs, err := s.server.db.FetchTxsForAddr(addr, numToSkip, dbTxs, dbSkipped, err := s.server.db.FetchTxsForAddr(addr, numToSkip,
numRequested-len(addressTxs)) numRequested-len(addressTxs))
if err == nil { if err == nil {
skipped += dbSkipped
for _, txReply := range dbTxs { for _, txReply := range dbTxs {
addressTxs = append(addressTxs, txReply) addressTxs = append(addressTxs, txReply)
} }
@ -2998,14 +3036,12 @@ func handleSearchRawTransactions(s *rpcServer, cmd interface{}, closeChan <-chan
// dependency. This might be something we want to do in the future when we // dependency. This might be something we want to do in the future when we
// return results for the client's convenience, or leave it to the client. // return results for the client's convenience, or leave it to the client.
if len(addressTxs) < numRequested { if len(addressTxs) < numRequested {
memPoolTxs, err := s.server.txMemPool.FilterTransactionsByAddress(addr) memPoolTxs, memPoolSkipped, err := getMempoolTxsForAddressRange(s, addr,
numToSkip-skipped, numRequested-len(addressTxs))
if err == nil { if err == nil {
for _, tx := range memPoolTxs { skipped += memPoolSkipped
txReply := &database.TxListReply{Tx: tx.MsgTx(), Sha: tx.Sha()} for _, txReply := range memPoolTxs {
addressTxs = append(addressTxs, txReply) addressTxs = append(addressTxs, txReply)
if len(addressTxs) == numRequested {
break
}
} }
} }
} }