general purpose constraints builder

This commit is contained in:
Lex Berezhny 2018-08-03 21:26:53 -04:00
parent 1eb375ba80
commit f5515e5e77
2 changed files with 55 additions and 23 deletions

View file

@ -0,0 +1,23 @@
from unittest import TestCase
from torba.basedatabase import constraints_to_sql
class TestConstraintBuilder(TestCase):
def test_any(self):
constraints = {
'ages__any': {
'age__gt': 18,
'age__lt': 38
}
}
self.assertEqual(
constraints_to_sql(constraints, prepend_sql=''),
'(age > :ages__any_age__gt OR age < :ages__any_age__lt)'
)
self.assertEqual(
constraints, {
'ages__any_age__gt': 18,
'ages__any_age__lt': 38
}
)

View file

@ -11,6 +11,34 @@ from torba.hash import TXRefImmutable
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend_key=''):
if not constraints:
return ''
extras = []
for key in list(constraints):
col, op = key, '='
if key.endswith('__not'):
col, op = key[:-len('__not')], '!='
elif key.endswith('__lt'):
col, op = key[:-len('__lt')], '<'
elif key.endswith('__lte'):
col, op = key[:-len('__lte')], '<='
elif key.endswith('__gt'):
col, op = key[:-len('__gt')], '>'
elif key.endswith('__like'):
col, op = key[:-len('__like')], 'LIKE'
elif key.endswith('__any'):
subconstraints = constraints.pop(key)
extras.append('({})'.format(
constraints_to_sql(subconstraints, ' OR ', '', key+'_')
))
for subkey, val in subconstraints.items():
constraints['{}_{}'.format(key, subkey)] = val
continue
extras.append('{} {} :{}'.format(col, op, prepend_key+key))
return prepend_sql + joiner.join(extras) if extras else ''
class SQLiteMixin: class SQLiteMixin:
CREATE_TABLES_QUERY: Sequence[str] = () CREATE_TABLES_QUERY: Sequence[str] = ()
@ -224,21 +252,8 @@ class BaseDatabase(SQLiteMixin):
defer.returnValue((None, None, False)) defer.returnValue((None, None, False))
def get_balance_for_account(self, account, include_reserved=False, **constraints): def get_balance_for_account(self, account, include_reserved=False, **constraints):
extra_sql = ""
if constraints:
extras = []
for key in constraints:
col, op = key, '='
if key.endswith('__not'):
col, op = key[:-len('__not')], '!='
elif key.endswith('__lte'):
col, op = key[:-len('__lte')], '<='
elif key.endswith('__gt'):
col, op = key[:-len('__gt')], '>'
extras.append('{} {} :{}'.format(col, op, key))
extra_sql = ' AND ' + ' AND '.join(extras)
if not include_reserved: if not include_reserved:
extra_sql += ' AND is_reserved=0' constraints['is_reserved'] = 0
values = {'account': account.public_key.address} values = {'account': account.public_key.address}
values.update(constraints) values.update(constraints)
return self.query_one_value( return self.query_one_value(
@ -250,24 +265,18 @@ class BaseDatabase(SQLiteMixin):
WHERE WHERE
pubkey_address.account=:account AND pubkey_address.account=:account AND
txoid NOT IN (SELECT txoid FROM txi) txoid NOT IN (SELECT txoid FROM txi)
"""+extra_sql, values, 0 """+constraints_to_sql(constraints), values, 0
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_utxos_for_account(self, account, **constraints): def get_utxos_for_account(self, account, **constraints):
extra_sql = "" constraints['account'] = account.public_key.address
if constraints:
extra_sql = ' AND ' + ' AND '.join(
'{} = :{}'.format(c, c) for c in constraints
)
values = {'account': account.public_key.address}
values.update(constraints)
utxos = yield self.run_query( utxos = yield self.run_query(
""" """
SELECT amount, script, txid, txo.position SELECT amount, script, txid, txo.position
FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address
WHERE account=:account AND txo.is_reserved=0 AND txoid NOT IN (SELECT txoid FROM txi) WHERE account=:account AND txo.is_reserved=0 AND txoid NOT IN (SELECT txoid FROM txi)
"""+extra_sql, values """+constraints_to_sql(constraints), constraints
) )
output_class = account.ledger.transaction_class.output_class output_class = account.ledger.transaction_class.output_class
defer.returnValue([ defer.returnValue([