forked from LBRYCommunity/lbry-sdk
Merge pull request #2790 from lbryio/sql_in_for_single_value
SQL generation fix to handle IN operation for one value lists
This commit is contained in:
commit
20774280b9
7 changed files with 37 additions and 16 deletions
|
@ -134,16 +134,20 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
|||
col, op = col[:-len('__not_like')], 'NOT LIKE'
|
||||
elif key.endswith('__in') or key.endswith('__not_in'):
|
||||
if key.endswith('__in'):
|
||||
col, op = col[:-len('__in')], 'IN'
|
||||
col, op, one_val_op = col[:-len('__in')], 'IN', '='
|
||||
else:
|
||||
col, op = col[:-len('__not_in')], 'NOT IN'
|
||||
col, op, one_val_op = col[:-len('__not_in')], 'NOT IN', '!='
|
||||
if constraint:
|
||||
if isinstance(constraint, (list, set, tuple)):
|
||||
keys = []
|
||||
for i, val in enumerate(constraint):
|
||||
keys.append(f':{key}{tag}_{i}')
|
||||
values[f'{key}{tag}_{i}'] = val
|
||||
sql.append(f'{col} {op} ({", ".join(keys)})')
|
||||
if len(constraint) == 1:
|
||||
values[f'{key}{tag}'] = next(iter(constraint))
|
||||
sql.append(f'{col} {one_val_op} :{key}{tag}')
|
||||
else:
|
||||
keys = []
|
||||
for i, val in enumerate(constraint):
|
||||
keys.append(f':{key}{tag}_{i}')
|
||||
values[f'{key}{tag}_{i}'] = val
|
||||
sql.append(f'{col} {op} ({", ".join(keys)})')
|
||||
elif isinstance(constraint, str):
|
||||
sql.append(f'{col} {op} ({constraint})')
|
||||
else:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
CLAIM_TYPES = {
|
||||
'stream': 1,
|
||||
'channel': 2,
|
||||
'repost': 3
|
||||
}
|
||||
|
||||
STREAM_TYPES = {
|
||||
|
|
|
@ -266,7 +266,7 @@ def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
|
|||
else:
|
||||
constraints['claim.claim_id__like'] = f'{claim_id[:40]}%'
|
||||
elif 'claim_ids' in constraints:
|
||||
constraints['claim.claim_id__in'] = constraints.pop('claim_ids')
|
||||
constraints['claim.claim_id__in'] = set(constraints.pop('claim_ids'))
|
||||
|
||||
if 'reposted_claim_id' in constraints:
|
||||
constraints['claim.reposted_claim_hash'] = unhexlify(constraints.pop('reposted_claim_id'))[::-1]
|
||||
|
@ -282,15 +282,15 @@ def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
|
|||
if 'channel_ids' in constraints:
|
||||
channel_ids = constraints.pop('channel_ids')
|
||||
if channel_ids:
|
||||
constraints['claim.channel_hash__in'] = [
|
||||
constraints['claim.channel_hash__in'] = {
|
||||
unhexlify(cid)[::-1] for cid in channel_ids
|
||||
]
|
||||
}
|
||||
if 'not_channel_ids' in constraints:
|
||||
not_channel_ids = constraints.pop('not_channel_ids')
|
||||
if not_channel_ids:
|
||||
not_channel_ids_binary = [
|
||||
not_channel_ids_binary = {
|
||||
unhexlify(ncid)[::-1] for ncid in not_channel_ids
|
||||
]
|
||||
}
|
||||
if constraints.get('has_channel_signature', False):
|
||||
constraints['claim.channel_hash__not_in'] = not_channel_ids_binary
|
||||
else:
|
||||
|
@ -320,13 +320,13 @@ def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
|
|||
if 'stream_types' in constraints:
|
||||
stream_types = constraints.pop('stream_types')
|
||||
if stream_types:
|
||||
constraints['claim.stream_type__in'] = [
|
||||
constraints['claim.stream_type__in'] = {
|
||||
STREAM_TYPES[stream_type] for stream_type in stream_types
|
||||
]
|
||||
}
|
||||
if 'media_types' in constraints:
|
||||
media_types = constraints.pop('media_types')
|
||||
if media_types:
|
||||
constraints['claim.media_type__in'] = media_types
|
||||
constraints['claim.media_type__in'] = set(media_types)
|
||||
|
||||
if 'fee_currency' in constraints:
|
||||
constraints['claim.fee_currency'] = constraints.pop('fee_currency').lower()
|
||||
|
|
|
@ -360,6 +360,7 @@ class SQLDB:
|
|||
if isinstance(fee.amount, Decimal):
|
||||
claim_record['fee_amount'] = int(fee.amount*1000)
|
||||
elif claim.is_repost:
|
||||
claim_record['claim_type'] = CLAIM_TYPES['repost']
|
||||
claim_record['reposted_claim_hash'] = claim.repost.reference.claim_hash
|
||||
elif claim.is_channel:
|
||||
claim_record['claim_type'] = CLAIM_TYPES['channel']
|
||||
|
|
|
@ -32,6 +32,7 @@ disable=
|
|||
too-many-branches,
|
||||
too-many-arguments,
|
||||
too-many-statements,
|
||||
too-many-nested-blocks,
|
||||
too-many-public-methods,
|
||||
too-many-instance-attributes,
|
||||
protected-access,
|
||||
|
|
|
@ -309,12 +309,14 @@ class ClaimSearchCommand(ClaimTestCase):
|
|||
octet = await self.stream_create()
|
||||
video = await self.stream_create('chrome', file_path=self.video_file_name)
|
||||
image = await self.stream_create('blank-image', data=self.image_data, suffix='.png')
|
||||
repost = await self.stream_repost(self.get_claim_id(image))
|
||||
channel = await self.channel_create()
|
||||
unknown = self.sout(tx)
|
||||
|
||||
# claim_type
|
||||
await self.assertFindsClaims([image, video, octet, unknown], claim_type='stream')
|
||||
await self.assertFindsClaims([channel], claim_type='channel')
|
||||
await self.assertFindsClaims([repost], claim_type='repost')
|
||||
|
||||
# stream_type
|
||||
await self.assertFindsClaims([octet, unknown], stream_types=['binary'])
|
||||
|
@ -322,7 +324,7 @@ class ClaimSearchCommand(ClaimTestCase):
|
|||
await self.assertFindsClaims([image], stream_types=['image'])
|
||||
await self.assertFindsClaims([image, video], stream_types=['video', 'image'])
|
||||
|
||||
# stream_type
|
||||
# media_type
|
||||
await self.assertFindsClaims([octet, unknown], media_types=['application/octet-stream'])
|
||||
await self.assertFindsClaims([video], media_types=['video/mp4'])
|
||||
await self.assertFindsClaims([image], media_types=['image/png'])
|
||||
|
|
|
@ -98,6 +98,12 @@ class TestQueryBuilder(unittest.TestCase):
|
|||
'txo_name__in0_1': 'def456'
|
||||
})
|
||||
)
|
||||
self.assertTupleEqual(
|
||||
constraints_to_sql({'txo.name__in': {'abc123'}}),
|
||||
('txo.name = :txo_name__in0', {
|
||||
'txo_name__in0': 'abc123',
|
||||
})
|
||||
)
|
||||
self.assertTupleEqual(
|
||||
constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}),
|
||||
('txo.age IN (SELECT age from ages_table)', {})
|
||||
|
@ -118,6 +124,12 @@ class TestQueryBuilder(unittest.TestCase):
|
|||
'txo_name__not_in0_1': 'def456'
|
||||
})
|
||||
)
|
||||
self.assertTupleEqual(
|
||||
constraints_to_sql({'txo.name__not_in': ('abc123',)}),
|
||||
('txo.name != :txo_name__not_in0', {
|
||||
'txo_name__not_in0': 'abc123',
|
||||
})
|
||||
)
|
||||
self.assertTupleEqual(
|
||||
constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}),
|
||||
('txo.age NOT IN (SELECT age from ages_table)', {})
|
||||
|
|
Loading…
Reference in a new issue