diff --git a/tests/client_tests/unit/test_database.py b/tests/client_tests/unit/test_database.py index 29aaa684b..d722f62db 100644 --- a/tests/client_tests/unit/test_database.py +++ b/tests/client_tests/unit/test_database.py @@ -76,11 +76,17 @@ class TestQueryBuilder(unittest.TestCase): def test_in(self): self.assertEqual( constraints_to_sql({'txo.age__in': [18, 38]}), - ('txo.age IN (18, 38)', {}) + ('txo.age IN (:txo_age__in0, :txo_age__in1)', { + 'txo_age__in0': 18, + 'txo_age__in1': 38 + }) ) self.assertEqual( - constraints_to_sql({'txo.age__in': ['abc123', 'def456']}), - ("txo.age IN ('abc123', 'def456')", {}) + constraints_to_sql({'txo.name__in': ('abc123', 'def456')}), + ('txo.name IN (:txo_name__in0, :txo_name__in1)', { + 'txo_name__in0': 'abc123', + 'txo_name__in1': 'def456' + }) ) self.assertEqual( constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}), @@ -90,7 +96,17 @@ class TestQueryBuilder(unittest.TestCase): def test_not_in(self): self.assertEqual( constraints_to_sql({'txo.age__not_in': [18, 38]}), - ('txo.age NOT IN (18, 38)', {}) + ('txo.age NOT IN (:txo_age__not_in0, :txo_age__not_in1)', { + 'txo_age__not_in0': 18, + 'txo_age__not_in1': 38 + }) + ) + self.assertEqual( + constraints_to_sql({'txo.name__not_in': ('abc123', 'def456')}), + ('txo.name NOT IN (:txo_name__not_in0, :txo_name__not_in1)', { + 'txo_name__not_in0': 'abc123', + 'txo_name__not_in1': 'def456' + }) ) self.assertEqual( constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}), diff --git a/torba/client/basedatabase.py b/torba/client/basedatabase.py index 0a91d25cd..8e4b9ccdc 100644 --- a/torba/client/basedatabase.py +++ b/torba/client/basedatabase.py @@ -93,6 +93,14 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): continue elif key.endswith('__not'): col, op = col[:-len('__not')], '!=' + elif key.endswith('__is_null'): + col = col[:-len('__is_null')] + sql.append(f'{col} IS NULL') + continue + elif key.endswith('__is_not_null'): + col = col[:-len('__is_not_null')] + sql.append(f'{col} IS NOT NULL') + continue elif key.endswith('__lt'): col, op = col[:-len('__lt')], '<' elif key.endswith('__lte'): @@ -108,23 +116,24 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): col, op = col[:-len('__in')], 'IN' else: col, op = col[:-len('__not_in')], 'NOT IN' - if isinstance(constraint, (list, set)): - items = ', '.join( - "'{}'".format(item) if isinstance(item, str) else str(item) - for item in constraint - ) - elif isinstance(constraint, str): - items = constraint - else: - raise ValueError("{} requires a list, set or string as constraint value.".format(col)) - sql.append('{} {} ({})'.format(col, op, items)) + if constraint: + if isinstance(constraint, (list, set, tuple)): + keys = [] + for i, val in enumerate(constraint): + keys.append(f':{key}{i}') + values[f'{key}{i}'] = val + sql.append(f'{col} {op} ({", ".join(keys)})') + elif isinstance(constraint, str): + sql.append(f'{col} {op} ({constraint})') + else: + raise ValueError(f"{col} requires a list, set or string as constraint value.") continue elif key.endswith('__any'): where, subvalues = constraints_to_sql(constraint, ' OR ', key+'_') - sql.append('({})'.format(where)) + sql.append(f'({where})') values.update(subvalues) continue - sql.append('{} {} :{}'.format(col, op, prepend_key+key)) + sql.append(f'{col} {op} :{prepend_key}{key}') values[prepend_key+key] = constraint return joiner.join(sql) if sql else '', values @@ -382,12 +391,14 @@ class BaseDatabase(SQLiteMixin): if not tx_rows: return [] - txids, txs = [], [] + txids, txs, txi_txoids = [], [], [] for row in tx_rows: txids.append(row[0]) txs.append(self.ledger.transaction_class( raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4]) )) + for txi in txs[-1].inputs: + txi_txoids.append(txi.txo_ref.id) annotated_txos = { txo.id: txo for txo in @@ -401,7 +412,7 @@ class BaseDatabase(SQLiteMixin): txo.id: txo for txo in (await self.get_txos( my_account=my_account, - txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0] + txoid__in=txi_txoids )) } diff --git a/torba/client/basenetwork.py b/torba/client/basenetwork.py index 0915fb056..d6ef84dc2 100644 --- a/torba/client/basenetwork.py +++ b/torba/client/basenetwork.py @@ -114,32 +114,32 @@ class BaseNetwork: def is_connected(self): return self.client is not None and not self.client.is_closing() - def rpc(self, list_or_method, *args): + def rpc(self, list_or_method, args): if self.is_connected: return self.client.send_request(list_or_method, args) else: raise ConnectionError("Attempting to send rpc request when connection is not available.") def ensure_server_version(self, required='1.2'): - return self.rpc('server.version', __version__, required) + return self.rpc('server.version', [__version__, required]) def broadcast(self, raw_transaction): - return self.rpc('blockchain.transaction.broadcast', raw_transaction) + return self.rpc('blockchain.transaction.broadcast', [raw_transaction]) def get_history(self, address): - return self.rpc('blockchain.address.get_history', address) + return self.rpc('blockchain.address.get_history', [address]) def get_transaction(self, tx_hash): - return self.rpc('blockchain.transaction.get', tx_hash) + return self.rpc('blockchain.transaction.get', [tx_hash]) def get_merkle(self, tx_hash, height): - return self.rpc('blockchain.transaction.get_merkle', tx_hash, height) + return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height]) def get_headers(self, height, count=10000): - return self.rpc('blockchain.block.headers', height, count) + return self.rpc('blockchain.block.headers', [height, count]) def subscribe_headers(self): - return self.rpc('blockchain.headers.subscribe', True) + return self.rpc('blockchain.headers.subscribe', [True]) def subscribe_address(self, address): - return self.rpc('blockchain.address.subscribe', address) + return self.rpc('blockchain.address.subscribe', [address]) diff --git a/torba/server/cli.py b/torba/server/cli.py index 2409ec52e..bf32bfb6b 100644 --- a/torba/server/cli.py +++ b/torba/server/cli.py @@ -2,8 +2,8 @@ import logging import traceback import argparse import importlib -from .env import Env -from .server import Server +from torba.server.env import Env +from torba.server.server import Server def get_argument_parser():