Merge pull request #2790 from lbryio/sql_in_for_single_value

SQL generation fix to handle IN operation for one value lists
This commit is contained in:
Lex Berezhny 2020-02-12 14:27:36 -05:00 committed by GitHub
commit 20774280b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 37 additions and 16 deletions

View file

@ -134,16 +134,20 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
col, op = col[:-len('__not_like')], 'NOT LIKE'
elif key.endswith('__in') or key.endswith('__not_in'):
if key.endswith('__in'):
col, op = col[:-len('__in')], 'IN'
col, op, one_val_op = col[:-len('__in')], 'IN', '='
else:
col, op = col[:-len('__not_in')], 'NOT IN'
col, op, one_val_op = col[:-len('__not_in')], 'NOT IN', '!='
if constraint:
if isinstance(constraint, (list, set, tuple)):
keys = []
for i, val in enumerate(constraint):
keys.append(f':{key}{tag}_{i}')
values[f'{key}{tag}_{i}'] = val
sql.append(f'{col} {op} ({", ".join(keys)})')
if len(constraint) == 1:
values[f'{key}{tag}'] = next(iter(constraint))
sql.append(f'{col} {one_val_op} :{key}{tag}')
else:
keys = []
for i, val in enumerate(constraint):
keys.append(f':{key}{tag}_{i}')
values[f'{key}{tag}_{i}'] = val
sql.append(f'{col} {op} ({", ".join(keys)})')
elif isinstance(constraint, str):
sql.append(f'{col} {op} ({constraint})')
else:

View file

@ -1,6 +1,7 @@
CLAIM_TYPES = {
'stream': 1,
'channel': 2,
'repost': 3
}
STREAM_TYPES = {

View file

@ -266,7 +266,7 @@ def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
else:
constraints['claim.claim_id__like'] = f'{claim_id[:40]}%'
elif 'claim_ids' in constraints:
constraints['claim.claim_id__in'] = constraints.pop('claim_ids')
constraints['claim.claim_id__in'] = set(constraints.pop('claim_ids'))
if 'reposted_claim_id' in constraints:
constraints['claim.reposted_claim_hash'] = unhexlify(constraints.pop('reposted_claim_id'))[::-1]
@ -282,15 +282,15 @@ def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
if 'channel_ids' in constraints:
channel_ids = constraints.pop('channel_ids')
if channel_ids:
constraints['claim.channel_hash__in'] = [
constraints['claim.channel_hash__in'] = {
unhexlify(cid)[::-1] for cid in channel_ids
]
}
if 'not_channel_ids' in constraints:
not_channel_ids = constraints.pop('not_channel_ids')
if not_channel_ids:
not_channel_ids_binary = [
not_channel_ids_binary = {
unhexlify(ncid)[::-1] for ncid in not_channel_ids
]
}
if constraints.get('has_channel_signature', False):
constraints['claim.channel_hash__not_in'] = not_channel_ids_binary
else:
@ -320,13 +320,13 @@ def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
if 'stream_types' in constraints:
stream_types = constraints.pop('stream_types')
if stream_types:
constraints['claim.stream_type__in'] = [
constraints['claim.stream_type__in'] = {
STREAM_TYPES[stream_type] for stream_type in stream_types
]
}
if 'media_types' in constraints:
media_types = constraints.pop('media_types')
if media_types:
constraints['claim.media_type__in'] = media_types
constraints['claim.media_type__in'] = set(media_types)
if 'fee_currency' in constraints:
constraints['claim.fee_currency'] = constraints.pop('fee_currency').lower()

View file

@ -360,6 +360,7 @@ class SQLDB:
if isinstance(fee.amount, Decimal):
claim_record['fee_amount'] = int(fee.amount*1000)
elif claim.is_repost:
claim_record['claim_type'] = CLAIM_TYPES['repost']
claim_record['reposted_claim_hash'] = claim.repost.reference.claim_hash
elif claim.is_channel:
claim_record['claim_type'] = CLAIM_TYPES['channel']

View file

@ -32,6 +32,7 @@ disable=
too-many-branches,
too-many-arguments,
too-many-statements,
too-many-nested-blocks,
too-many-public-methods,
too-many-instance-attributes,
protected-access,

View file

@ -309,12 +309,14 @@ class ClaimSearchCommand(ClaimTestCase):
octet = await self.stream_create()
video = await self.stream_create('chrome', file_path=self.video_file_name)
image = await self.stream_create('blank-image', data=self.image_data, suffix='.png')
repost = await self.stream_repost(self.get_claim_id(image))
channel = await self.channel_create()
unknown = self.sout(tx)
# claim_type
await self.assertFindsClaims([image, video, octet, unknown], claim_type='stream')
await self.assertFindsClaims([channel], claim_type='channel')
await self.assertFindsClaims([repost], claim_type='repost')
# stream_type
await self.assertFindsClaims([octet, unknown], stream_types=['binary'])
@ -322,7 +324,7 @@ class ClaimSearchCommand(ClaimTestCase):
await self.assertFindsClaims([image], stream_types=['image'])
await self.assertFindsClaims([image, video], stream_types=['video', 'image'])
# stream_type
# media_type
await self.assertFindsClaims([octet, unknown], media_types=['application/octet-stream'])
await self.assertFindsClaims([video], media_types=['video/mp4'])
await self.assertFindsClaims([image], media_types=['image/png'])

View file

@ -98,6 +98,12 @@ class TestQueryBuilder(unittest.TestCase):
'txo_name__in0_1': 'def456'
})
)
self.assertTupleEqual(
constraints_to_sql({'txo.name__in': {'abc123'}}),
('txo.name = :txo_name__in0', {
'txo_name__in0': 'abc123',
})
)
self.assertTupleEqual(
constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}),
('txo.age IN (SELECT age from ages_table)', {})
@ -118,6 +124,12 @@ class TestQueryBuilder(unittest.TestCase):
'txo_name__not_in0_1': 'def456'
})
)
self.assertTupleEqual(
constraints_to_sql({'txo.name__not_in': ('abc123',)}),
('txo.name != :txo_name__not_in0', {
'txo_name__not_in0': 'abc123',
})
)
self.assertTupleEqual(
constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}),
('txo.age NOT IN (SELECT age from ages_table)', {})