sql constraints support table.column dot delimited column names

This commit is contained in:
Lex Berezhny 2018-10-03 11:51:42 -04:00
parent 1f4a9cff26
commit 0960762694
3 changed files with 45 additions and 25 deletions

View file

@ -11,53 +11,66 @@ from .test_transaction import get_output, NULL_HASH
class TestConstraintBuilder(unittest.TestCase): class TestConstraintBuilder(unittest.TestCase):
def test_dot(self):
constraints = {'txo.position': 18}
self.assertEqual(
constraints_to_sql(constraints, prepend_sql=''),
'txo.position = :txo_position'
)
self.assertEqual(constraints, {'txo_position': 18})
def test_any(self): def test_any(self):
constraints = { constraints = {
'ages__any': { 'ages__any': {
'age__gt': 18, 'txo.age__gt': 18,
'age__lt': 38 'txo.age__lt': 38
} }
} }
self.assertEqual( self.assertEqual(
constraints_to_sql(constraints, prepend_sql=''), constraints_to_sql(constraints, prepend_sql=''),
'(age > :ages__any_age__gt OR age < :ages__any_age__lt)' '(txo.age > :ages__any_txo_age__gt OR txo.age < :ages__any_txo_age__lt)'
) )
self.assertEqual( self.assertEqual(
constraints, { constraints, {
'ages__any_age__gt': 18, 'ages__any_txo_age__gt': 18,
'ages__any_age__lt': 38 'ages__any_txo_age__lt': 38
} }
) )
def test_in_list(self): def test_in_list(self):
constraints = {'ages__in': [18, 38]} constraints = {'txo.age__in': [18, 38]}
self.assertEqual( self.assertEqual(
constraints_to_sql(constraints, prepend_sql=''), constraints_to_sql(constraints, prepend_sql=''),
'ages IN (:ages_1, :ages_2)' 'txo.age IN (:txo_age_1, :txo_age_2)'
) )
self.assertEqual( self.assertEqual(
constraints, { constraints, {
'ages_1': 18, 'txo_age_1': 18,
'ages_2': 38 'txo_age_2': 38
} }
) )
def test_in_query(self): def test_in_query(self):
constraints = {'ages__in': 'SELECT age from ages_table'} constraints = {'txo.age__in': 'SELECT age from ages_table'}
self.assertEqual( self.assertEqual(
constraints_to_sql(constraints, prepend_sql=''), constraints_to_sql(constraints, prepend_sql=''),
'ages IN (SELECT age from ages_table)' 'txo.age IN (SELECT age from ages_table)'
) )
self.assertEqual(constraints, {}) self.assertEqual(constraints, {})
def test_not_in_query(self): def test_not_in_query(self):
constraints = {'ages__not_in': 'SELECT age from ages_table'} constraints = {'txo.age__not_in': 'SELECT age from ages_table'}
self.assertEqual( self.assertEqual(
constraints_to_sql(constraints, prepend_sql=''), constraints_to_sql(constraints, prepend_sql=''),
'ages NOT IN (SELECT age from ages_table)' 'txo.age NOT IN (SELECT age from ages_table)'
) )
self.assertEqual(constraints, {}) self.assertEqual(constraints, {})
def test_in_invalid(self):
constraints = {'ages__in': 9}
with self.assertRaisesRegex(ValueError, 'list or string'):
constraints_to_sql(constraints, prepend_sql='')
class TestQueries(unittest.TestCase): class TestQueries(unittest.TestCase):

View file

@ -1,2 +1,2 @@
__path__: str = __import__('pkgutil').extend_path(__path__, __name__) __path__: str = __import__('pkgutil').extend_path(__path__, __name__)
__version__ = '0.0.5' __version__ = '0.0.7'

View file

@ -11,12 +11,16 @@ from torba.basetransaction import BaseTransaction, TXORefResolvable
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def clean_arg_name(arg):
return arg.replace('.', '_')
def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend_key=''): def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend_key=''):
if not constraints: if not constraints:
return '' return ''
extras = [] extras = []
for key in list(constraints): for key in list(constraints):
col, op = key, '=' col, op, constraint = key, '=', constraints.pop(key)
if key.endswith('__not'): if key.endswith('__not'):
col, op = key[:-len('__not')], '!=' col, op = key[:-len('__not')], '!='
elif key.endswith('__lt'): elif key.endswith('__lt'):
@ -32,24 +36,27 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend
col, op = key[:-len('__in')], 'IN' col, op = key[:-len('__in')], 'IN'
else: else:
col, op = key[:-len('__not_in')], 'NOT IN' col, op = key[:-len('__not_in')], 'NOT IN'
items = constraints.pop(key) if isinstance(constraint, list):
if isinstance(items, list):
placeholders = [] placeholders = []
for item_no, item in enumerate(items, 1): for item_no, item in enumerate(constraint, 1):
constraints['{}_{}'.format(col, item_no)] = item constraints['{}_{}'.format(clean_arg_name(col), item_no)] = item
placeholders.append(':{}_{}'.format(col, item_no)) placeholders.append(':{}_{}'.format(clean_arg_name(col), item_no))
items = ', '.join(placeholders) items = ', '.join(placeholders)
elif isinstance(constraint, str):
items = constraint
else:
raise ValueError("{} requires a list or string as constraint value.".format(key))
extras.append('{} {} ({})'.format(col, op, items)) extras.append('{} {} ({})'.format(col, op, items))
continue continue
elif key.endswith('__any'): elif key.endswith('__any'):
subconstraints = constraints.pop(key)
extras.append('({})'.format( extras.append('({})'.format(
constraints_to_sql(subconstraints, ' OR ', '', key+'_') constraints_to_sql(constraint, ' OR ', '', key+'_')
)) ))
for subkey, val in subconstraints.items(): for subkey, val in constraint.items():
constraints['{}_{}'.format(key, subkey)] = val constraints['{}_{}'.format(clean_arg_name(key), clean_arg_name(subkey))] = val
continue continue
extras.append('{} {} :{}'.format(col, op, prepend_key+key)) constraints[clean_arg_name(key)] = constraint
extras.append('{} {} :{}'.format(col, op, prepend_key+clean_arg_name(key)))
return prepend_sql + joiner.join(extras) if extras else '' return prepend_sql + joiner.join(extras) if extras else ''