fix and test has_source for channel reposts

This commit is contained in:
Victor Shyba 2021-05-03 18:40:03 -03:00
parent 0f02906c9b
commit d5f722792f
2 changed files with 13 additions and 8 deletions

View file

@ -822,8 +822,8 @@ class SQLDB:
f"SELECT claim_hash, normalized FROM claim WHERE expiration_height = {height}" f"SELECT claim_hash, normalized FROM claim WHERE expiration_height = {height}"
) )
def enqueue_changes(self): def enqueue_changes(self, shard=None, total_shards=None):
for claim in self.execute(f""" query = """
SELECT claimtrie.claim_hash as is_controlling, SELECT claimtrie.claim_hash as is_controlling,
claimtrie.last_take_over_height, claimtrie.last_take_over_height,
(select group_concat(tag, ',,') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags, (select group_concat(tag, ',,') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags,
@ -832,8 +832,12 @@ class SQLDB:
(select cr.claim_type from claim cr where cr.claim_hash = claim.reposted_claim_hash) as reposted_claim_type, (select cr.claim_type from claim cr where cr.claim_hash = claim.reposted_claim_hash) as reposted_claim_type,
claim.* claim.*
FROM claim LEFT JOIN claimtrie USING (claim_hash) FROM claim LEFT JOIN claimtrie USING (claim_hash)
WHERE claim.claim_hash in (SELECT claim_hash FROM changelog) """
"""): if shard is not None and total_shards is not None:
query += f" WHERE claim.height % {total_shards} = {shard}"
else:
query += " WHERE claim.claim_hash in (SELECT claim_hash FROM changelog)"
for claim in self.execute(query):
claim = claim._asdict() claim = claim._asdict()
id_set = set(filter(None, (claim['claim_hash'], claim['channel_hash'], claim['reposted_claim_hash']))) id_set = set(filter(None, (claim['claim_hash'], claim['channel_hash'], claim['reposted_claim_hash'])))
claim['censor_type'] = 0 claim['censor_type'] = 0
@ -860,11 +864,11 @@ class SQLDB:
def clear_changelog(self): def clear_changelog(self):
self.execute("delete from changelog;") self.execute("delete from changelog;")
def claim_producer(self): def claim_producer(self, shard=None, total_shards=None):
while self.pending_deletes: while self.pending_deletes:
claim_hash = self.pending_deletes.pop() claim_hash = self.pending_deletes.pop()
yield 'delete', hexlify(claim_hash[::-1]).decode() yield 'delete', hexlify(claim_hash[::-1]).decode()
for claim in self.enqueue_changes(): for claim in self.enqueue_changes(shard, total_shards):
yield claim yield claim
self.clear_changelog() self.clear_changelog()

View file

@ -196,12 +196,13 @@ class ClaimSearchCommand(ClaimTestCase):
normal = await self.stream_create('normal', data=b'normal') normal = await self.stream_create('normal', data=b'normal')
normal_repost = await self.stream_repost(self.get_claim_id(normal), 'normal-repost') normal_repost = await self.stream_repost(self.get_claim_id(normal), 'normal-repost')
no_source_repost = await self.stream_repost(self.get_claim_id(no_source), 'no-source-repost') no_source_repost = await self.stream_repost(self.get_claim_id(no_source), 'no-source-repost')
channel_repost = await self.stream_repost(self.get_claim_id(channel), 'channel-repost')
await self.assertFindsClaims([channel_repost, no_source_repost, no_source, channel], has_no_source=True) await self.assertFindsClaims([channel_repost, no_source_repost, no_source, channel], has_no_source=True)
await self.assertListsClaims([no_source, channel], has_no_source=True) await self.assertListsClaims([no_source, channel], has_no_source=True)
await self.assertFindsClaims([channel_repost, normal_repost, normal, channel], has_source=True) await self.assertFindsClaims([channel_repost, normal_repost, normal, channel], has_source=True)
await self.assertListsClaims([no_source_repost, normal_repost, normal], has_source=True) await self.assertListsClaims([channel_repost, no_source_repost, normal_repost, normal], has_source=True)
await self.assertFindsClaims([channel_repost, no_source_repost, normal_repost, normal, no_source, channel]) await self.assertFindsClaims([channel_repost, no_source_repost, normal_repost, normal, no_source, channel])
await self.assertListsClaims([no_source_repost, normal_repost, normal, no_source, channel]) await self.assertListsClaims([channel_repost, no_source_repost, normal_repost, normal, no_source, channel])
async def test_pagination(self): async def test_pagination(self):
await self.create_channel() await self.create_channel()