improved db.get_transactions()

This commit is contained in:
Lex Berezhny 2018-09-21 22:18:30 -04:00
parent 8a87195f55
commit 8ed0791b26
4 changed files with 65 additions and 46 deletions

View file

@ -9,7 +9,7 @@ class BasicTransactionTests(IntegrationTestCase):
async def test_sending_and_receiving(self): async def test_sending_and_receiving(self):
account1, account2 = self.account, self.wallet.generate_account(self.ledger) account1, account2 = self.account, self.wallet.generate_account(self.ledger)
yield self.ledger.update_account(account2) await d2f(self.ledger.update_account(account2))
self.assertEqual(await self.get_balance(account1), 0) self.assertEqual(await self.get_balance(account1), 0)
self.assertEqual(await self.get_balance(account2), 0) self.assertEqual(await self.get_balance(account2), 0)
@ -53,3 +53,12 @@ class BasicTransactionTests(IntegrationTestCase):
await self.on_transaction(tx) # mempool await self.on_transaction(tx) # mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction(tx) # confirmed await self.on_transaction(tx) # confirmed
txs = await d2f(account1.get_transactions())
tx = txs[1]
self.assertEqual(round(tx.inputs[0].txo_ref.txo.amount/COIN, 1), 1.1)
self.assertEqual(round(tx.inputs[1].txo_ref.txo.amount/COIN, 1), 1.1)
self.assertEqual(round(tx.outputs[0].amount/COIN, 1), 2.0)
self.assertEqual(tx.outputs[0].get_address(self.ledger), address2)
self.assertEqual(tx.outputs[0].is_change, False)
self.assertEqual(tx.outputs[1].is_change, True)

View file

@ -342,9 +342,6 @@ class BaseAccount:
def get_unspent_outputs(self, **constraints): def get_unspent_outputs(self, **constraints):
return self.ledger.db.get_utxos_for_account(self, **constraints) return self.ledger.db.get_utxos_for_account(self, **constraints)
def get_inputs_outputs(self, **constraints):
return self.ledger.db.get_txios_for_account(self, **constraints)
def get_transactions(self): def get_transactions(self):
return self.ledger.db.get_transactions(self) return self.ledger.db.get_transactions(self)

View file

@ -6,6 +6,7 @@ from twisted.internet import defer
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from torba.hash import TXRefImmutable from torba.hash import TXRefImmutable
from torba.basetransaction import TXORefResolvable
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -267,24 +268,62 @@ class BaseDatabase(SQLiteMixin):
return None, None, False return None, None, False
@defer.inlineCallbacks @defer.inlineCallbacks
def get_transactions(self, account): def get_transactions(self, account, offset=0, limit=100):
txs = self.run_query( offset, limit = min(offset, 0), max(limit, 100)
tx_records = yield self.run_query(
""" """
SELECT raw FROM tx where txid in ( SELECT txid, raw, height FROM tx WHERE txid IN (
SELECT txo.txid SELECT txo.txid FROM txo
FROM txo
JOIN pubkey_address USING (address) JOIN pubkey_address USING (address)
WHERE pubkey_address.account = :account WHERE pubkey_address.account = :account
UNION UNION
SELECT txo.txid SELECT txo.txid FROM txi
FROM txi
JOIN txo USING (txoid) JOIN txo USING (txoid)
JOIN pubkey_address USING (address) JOIN pubkey_address USING (address)
WHERE pubkey_address.account = :account WHERE pubkey_address.account = :account
) ) ORDER BY height DESC LIMIT :offset, :limit
""", {'account': account.public_key.address} """, {'account': account.public_key.address, 'offset': offset, 'limit': limit}
) )
return [account.ledger.transaction_class(values[0]) for values in txs] txids, txs = [], []
for r in tx_records:
txids.append(r[0])
txs.append(account.ledger.transaction_class(raw=r[1], height=r[2]))
txo_records = yield self.run_query(
"""
SELECT txoid, pubkey_address.chain
FROM txo JOIN pubkey_address USING (address)
WHERE txid IN ({})
""".format(', '.join(['?']*len(txids))), txids
)
txos = dict(txo_records)
txi_records = yield self.run_query(
"""
SELECT txoid, txo.amount, txo.script, txo.txid, txo.position
FROM txi JOIN txo USING (txoid)
WHERE txi.txid IN ({})
""".format(', '.join(['?']*len(txids))), txids
)
txis = {}
output_class = account.ledger.transaction_class.output_class
for r in txi_records:
txis[r[0]] = output_class(
r[1],
output_class.script_class(r[2]),
TXRefImmutable.from_id(r[3]),
position=r[4]
)
for tx in txs:
for txi in tx.inputs:
if txi.txo_ref.id in txis:
txi.txo_ref = TXORefResolvable(txis[txi.txo_ref.id])
for txo in tx.outputs:
if txo.id in txos:
txo.is_change = txos[txo.id] == 1
return txs
def get_balance_for_account(self, account, include_reserved=False, **constraints): def get_balance_for_account(self, account, include_reserved=False, **constraints):
if not include_reserved: if not include_reserved:
@ -323,26 +362,6 @@ class BaseDatabase(SQLiteMixin):
) for values in utxos ) for values in utxos
] ]
@defer.inlineCallbacks
def get_txios_for_account(self, account, **constraints):
constraints['account'] = account.public_key.address
utxos = yield self.run_query(
"""
SELECT amount, script, txid, txo.position
FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address
WHERE account=:account AND txo.is_reserved=0 AND txoid NOT IN (SELECT txoid FROM txi)
"""+constraints_to_sql(constraints), constraints
)
output_class = account.ledger.transaction_class.output_class
return [
output_class(
values[0],
output_class.script_class(values[1]),
TXRefImmutable.from_id(values[2]),
position=values[3]
) for values in utxos
]
def add_keys(self, account, chain, keys): def add_keys(self, account, chain, keys):
sql = ( sql = (
"insert into pubkey_address " "insert into pubkey_address "

View file

@ -181,13 +181,14 @@ class BaseOutput(InputOutput):
script_class = BaseOutputScript script_class = BaseOutputScript
estimator_class = BaseOutputEffectiveAmountEstimator estimator_class = BaseOutputEffectiveAmountEstimator
__slots__ = 'amount', 'script' __slots__ = 'amount', 'script', 'is_change'
def __init__(self, amount: int, script: BaseOutputScript, def __init__(self, amount: int, script: BaseOutputScript,
tx_ref: TXRef = None, position: int = None) -> None: tx_ref: TXRef = None, position: int = None) -> None:
super().__init__(tx_ref, position) super().__init__(tx_ref, position)
self.amount = amount self.amount = amount
self.script = script self.script = script
self.is_change = None
@property @property
def ref(self): def ref(self):
@ -226,13 +227,14 @@ class BaseTransaction:
input_class = BaseInput input_class = BaseInput
output_class = BaseOutput output_class = BaseOutput
def __init__(self, raw=None, version=1, locktime=0) -> None: def __init__(self, raw=None, version=1, locktime=0, height=None) -> None:
self._raw = raw self._raw = raw
self.ref = TXRefMutable(self) self.ref = TXRefMutable(self)
self.version = version # type: int self.version = version # type: int
self.locktime = locktime # type: int self.locktime = locktime # type: int
self._inputs = [] # type: List[BaseInput] self._inputs = [] # type: List[BaseInput]
self._outputs = [] # type: List[BaseOutput] self._outputs = [] # type: List[BaseOutput]
self.height = height
if raw is not None: if raw is not None:
self._deserialize() self._deserialize()
@ -416,6 +418,8 @@ class BaseTransaction:
change_address = yield change_account.change.get_or_create_usable_address() change_address = yield change_account.change.get_or_create_usable_address()
change_hash160 = change_account.ledger.address_to_hash160(change_address) change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = change - cost_of_change change_amount = change - cost_of_change
change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160)
change_output.is_change = True
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
yield tx.sign(funding_accounts) yield tx.sign(funding_accounts)
@ -449,13 +453,3 @@ class BaseTransaction:
else: else:
raise NotImplementedError("Don't know how to spend this output.") raise NotImplementedError("Don't know how to spend this output.")
self._reset() self._reset()
@defer.inlineCallbacks
def get_my_addresses(self, ledger):
addresses = set()
for txo in self.outputs:
address = ledger.hash160_to_address(txo.script.values['pubkey_hash'])
record = yield ledger.db.get_address(address)
if record is not None:
addresses.add(address)
defer.returnValue(list(addresses))