diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index f22be4769..45cb029cd 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -11,53 +11,66 @@ from .test_transaction import get_output, NULL_HASH class TestConstraintBuilder(unittest.TestCase): + def test_dot(self): + constraints = {'txo.position': 18} + self.assertEqual( + constraints_to_sql(constraints, prepend_sql=''), + 'txo.position = :txo_position' + ) + self.assertEqual(constraints, {'txo_position': 18}) + def test_any(self): constraints = { 'ages__any': { - 'age__gt': 18, - 'age__lt': 38 + 'txo.age__gt': 18, + 'txo.age__lt': 38 } } self.assertEqual( constraints_to_sql(constraints, prepend_sql=''), - '(age > :ages__any_age__gt OR age < :ages__any_age__lt)' + '(txo.age > :ages__any_txo_age__gt OR txo.age < :ages__any_txo_age__lt)' ) self.assertEqual( constraints, { - 'ages__any_age__gt': 18, - 'ages__any_age__lt': 38 + 'ages__any_txo_age__gt': 18, + 'ages__any_txo_age__lt': 38 } ) def test_in_list(self): - constraints = {'ages__in': [18, 38]} + constraints = {'txo.age__in': [18, 38]} self.assertEqual( constraints_to_sql(constraints, prepend_sql=''), - 'ages IN (:ages_1, :ages_2)' + 'txo.age IN (:txo_age_1, :txo_age_2)' ) self.assertEqual( constraints, { - 'ages_1': 18, - 'ages_2': 38 + 'txo_age_1': 18, + 'txo_age_2': 38 } ) def test_in_query(self): - constraints = {'ages__in': 'SELECT age from ages_table'} + constraints = {'txo.age__in': 'SELECT age from ages_table'} self.assertEqual( constraints_to_sql(constraints, prepend_sql=''), - 'ages IN (SELECT age from ages_table)' + 'txo.age IN (SELECT age from ages_table)' ) self.assertEqual(constraints, {}) def test_not_in_query(self): - constraints = {'ages__not_in': 'SELECT age from ages_table'} + constraints = {'txo.age__not_in': 'SELECT age from ages_table'} self.assertEqual( constraints_to_sql(constraints, prepend_sql=''), - 'ages NOT IN (SELECT age from ages_table)' + 'txo.age NOT IN (SELECT age from ages_table)' ) self.assertEqual(constraints, {}) + def test_in_invalid(self): + constraints = {'ages__in': 9} + with self.assertRaisesRegex(ValueError, 'list or string'): + constraints_to_sql(constraints, prepend_sql='') + class TestQueries(unittest.TestCase): diff --git a/torba/__init__.py b/torba/__init__.py index 2123d238e..c5362f6e8 100644 --- a/torba/__init__.py +++ b/torba/__init__.py @@ -1,2 +1,2 @@ __path__: str = __import__('pkgutil').extend_path(__path__, __name__) -__version__ = '0.0.5' +__version__ = '0.0.7' diff --git a/torba/basedatabase.py b/torba/basedatabase.py index 8dd65d8f6..375fa4337 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -11,12 +11,16 @@ from torba.basetransaction import BaseTransaction, TXORefResolvable log = logging.getLogger(__name__) +def clean_arg_name(arg): + return arg.replace('.', '_') + + def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend_key=''): if not constraints: return '' extras = [] for key in list(constraints): - col, op = key, '=' + col, op, constraint = key, '=', constraints.pop(key) if key.endswith('__not'): col, op = key[:-len('__not')], '!=' elif key.endswith('__lt'): @@ -32,24 +36,27 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend col, op = key[:-len('__in')], 'IN' else: col, op = key[:-len('__not_in')], 'NOT IN' - items = constraints.pop(key) - if isinstance(items, list): + if isinstance(constraint, list): placeholders = [] - for item_no, item in enumerate(items, 1): - constraints['{}_{}'.format(col, item_no)] = item - placeholders.append(':{}_{}'.format(col, item_no)) + for item_no, item in enumerate(constraint, 1): + constraints['{}_{}'.format(clean_arg_name(col), item_no)] = item + placeholders.append(':{}_{}'.format(clean_arg_name(col), item_no)) items = ', '.join(placeholders) + elif isinstance(constraint, str): + items = constraint + else: + raise ValueError("{} requires a list or string as constraint value.".format(key)) extras.append('{} {} ({})'.format(col, op, items)) continue elif key.endswith('__any'): - subconstraints = constraints.pop(key) extras.append('({})'.format( - constraints_to_sql(subconstraints, ' OR ', '', key+'_') + constraints_to_sql(constraint, ' OR ', '', key+'_') )) - for subkey, val in subconstraints.items(): - constraints['{}_{}'.format(key, subkey)] = val + for subkey, val in constraint.items(): + constraints['{}_{}'.format(clean_arg_name(key), clean_arg_name(subkey))] = val continue - extras.append('{} {} :{}'.format(col, op, prepend_key+key)) + constraints[clean_arg_name(key)] = constraint + extras.append('{} {} :{}'.format(col, op, prepend_key+clean_arg_name(key))) return prepend_sql + joiner.join(extras) if extras else ''