From dcb1b6469631c5ddfcab33e4e4def10f91c93c06 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Wed, 12 Feb 2020 10:31:27 -0500 Subject: [PATCH] SQL generation fix to handle IN operation for one value lists --- lbry/wallet/database.py | 18 +++++++++++------- tests/unit/wallet/test_database.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/lbry/wallet/database.py b/lbry/wallet/database.py index 5cf5e9c8b..ab4b9a402 100644 --- a/lbry/wallet/database.py +++ b/lbry/wallet/database.py @@ -134,16 +134,20 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): col, op = col[:-len('__not_like')], 'NOT LIKE' elif key.endswith('__in') or key.endswith('__not_in'): if key.endswith('__in'): - col, op = col[:-len('__in')], 'IN' + col, op, one_val_op = col[:-len('__in')], 'IN', '=' else: - col, op = col[:-len('__not_in')], 'NOT IN' + col, op, one_val_op = col[:-len('__not_in')], 'NOT IN', '!=' if constraint: if isinstance(constraint, (list, set, tuple)): - keys = [] - for i, val in enumerate(constraint): - keys.append(f':{key}{tag}_{i}') - values[f'{key}{tag}_{i}'] = val - sql.append(f'{col} {op} ({", ".join(keys)})') + if len(constraint) == 1: + values[f'{key}{tag}'] = constraint[0] + sql.append(f'{col} {one_val_op} :{key}{tag}') + else: + keys = [] + for i, val in enumerate(constraint): + keys.append(f':{key}{tag}_{i}') + values[f'{key}{tag}_{i}'] = val + sql.append(f'{col} {op} ({", ".join(keys)})') elif isinstance(constraint, str): sql.append(f'{col} {op} ({constraint})') else: diff --git a/tests/unit/wallet/test_database.py b/tests/unit/wallet/test_database.py index 289d5b447..6b05d239e 100644 --- a/tests/unit/wallet/test_database.py +++ b/tests/unit/wallet/test_database.py @@ -98,6 +98,12 @@ class TestQueryBuilder(unittest.TestCase): 'txo_name__in0_1': 'def456' }) ) + self.assertTupleEqual( + constraints_to_sql({'txo.name__in': ('abc123',)}), + ('txo.name = :txo_name__in0', { + 'txo_name__in0': 'abc123', + }) + ) self.assertTupleEqual( constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}), ('txo.age IN (SELECT age from ages_table)', {}) @@ -118,6 +124,12 @@ class TestQueryBuilder(unittest.TestCase): 'txo_name__not_in0_1': 'def456' }) ) + self.assertTupleEqual( + constraints_to_sql({'txo.name__not_in': ('abc123',)}), + ('txo.name != :txo_name__not_in0', { + 'txo_name__not_in0': 'abc123', + }) + ) self.assertTupleEqual( constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}), ('txo.age NOT IN (SELECT age from ages_table)', {})