From f5515e5e77c599a86817d5b34384b7e4670ee27f Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Fri, 3 Aug 2018 21:26:53 -0400 Subject: [PATCH] general purpose constraints builder --- tests/unit/test_database.py | 23 ++++++++++++++++ torba/basedatabase.py | 55 +++++++++++++++++++++---------------- 2 files changed, 55 insertions(+), 23 deletions(-) create mode 100644 tests/unit/test_database.py diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py new file mode 100644 index 000000000..7831f7916 --- /dev/null +++ b/tests/unit/test_database.py @@ -0,0 +1,23 @@ +from unittest import TestCase +from torba.basedatabase import constraints_to_sql + + +class TestConstraintBuilder(TestCase): + + def test_any(self): + constraints = { + 'ages__any': { + 'age__gt': 18, + 'age__lt': 38 + } + } + self.assertEqual( + constraints_to_sql(constraints, prepend_sql=''), + '(age > :ages__any_age__gt OR age < :ages__any_age__lt)' + ) + self.assertEqual( + constraints, { + 'ages__any_age__gt': 18, + 'ages__any_age__lt': 38 + } + ) diff --git a/torba/basedatabase.py b/torba/basedatabase.py index b086f7700..a65e2e029 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -11,6 +11,34 @@ from torba.hash import TXRefImmutable log = logging.getLogger(__name__) +def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend_key=''): + if not constraints: + return '' + extras = [] + for key in list(constraints): + col, op = key, '=' + if key.endswith('__not'): + col, op = key[:-len('__not')], '!=' + elif key.endswith('__lt'): + col, op = key[:-len('__lt')], '<' + elif key.endswith('__lte'): + col, op = key[:-len('__lte')], '<=' + elif key.endswith('__gt'): + col, op = key[:-len('__gt')], '>' + elif key.endswith('__like'): + col, op = key[:-len('__like')], 'LIKE' + elif key.endswith('__any'): + subconstraints = constraints.pop(key) + extras.append('({})'.format( + constraints_to_sql(subconstraints, ' OR ', '', key+'_') + )) + for subkey, val in subconstraints.items(): + constraints['{}_{}'.format(key, subkey)] = val + continue + extras.append('{} {} :{}'.format(col, op, prepend_key+key)) + return prepend_sql + joiner.join(extras) if extras else '' + + class SQLiteMixin: CREATE_TABLES_QUERY: Sequence[str] = () @@ -224,21 +252,8 @@ class BaseDatabase(SQLiteMixin): defer.returnValue((None, None, False)) def get_balance_for_account(self, account, include_reserved=False, **constraints): - extra_sql = "" - if constraints: - extras = [] - for key in constraints: - col, op = key, '=' - if key.endswith('__not'): - col, op = key[:-len('__not')], '!=' - elif key.endswith('__lte'): - col, op = key[:-len('__lte')], '<=' - elif key.endswith('__gt'): - col, op = key[:-len('__gt')], '>' - extras.append('{} {} :{}'.format(col, op, key)) - extra_sql = ' AND ' + ' AND '.join(extras) if not include_reserved: - extra_sql += ' AND is_reserved=0' + constraints['is_reserved'] = 0 values = {'account': account.public_key.address} values.update(constraints) return self.query_one_value( @@ -250,24 +265,18 @@ class BaseDatabase(SQLiteMixin): WHERE pubkey_address.account=:account AND txoid NOT IN (SELECT txoid FROM txi) - """+extra_sql, values, 0 + """+constraints_to_sql(constraints), values, 0 ) @defer.inlineCallbacks def get_utxos_for_account(self, account, **constraints): - extra_sql = "" - if constraints: - extra_sql = ' AND ' + ' AND '.join( - '{} = :{}'.format(c, c) for c in constraints - ) - values = {'account': account.public_key.address} - values.update(constraints) + constraints['account'] = account.public_key.address utxos = yield self.run_query( """ SELECT amount, script, txid, txo.position FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address WHERE account=:account AND txo.is_reserved=0 AND txoid NOT IN (SELECT txoid FROM txi) - """+extra_sql, values + """+constraints_to_sql(constraints), constraints ) output_class = account.ledger.transaction_class.output_class defer.returnValue([