diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py index 883b7877f..e03d4cbec 100644 --- a/lbry/wallet/server/db/elasticsearch/search.py +++ b/lbry/wallet/server/db/elasticsearch/search.py @@ -499,6 +499,7 @@ def expand_query(**kwargs): query['should'].append( {"bool": {"must": [{"match": {"has_source": kwargs['has_source']}}, is_stream_or_repost]}}) query['should'].append({"bool": {"must_not": [is_stream_or_repost]}}) + query['should'].append({"bool": {"must": [{"term": {"reposted_claim_type": CLAIM_TYPES['channel']}}]}}) if kwargs.get('text'): query['must'].append( {"simple_query_string": diff --git a/lbry/wallet/server/db/writer.py b/lbry/wallet/server/db/writer.py index e1a117635..dc1281220 100644 --- a/lbry/wallet/server/db/writer.py +++ b/lbry/wallet/server/db/writer.py @@ -829,6 +829,7 @@ class SQLDB: (select group_concat(tag, ',,') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags, (select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages, (select cr.has_source from claim cr where cr.claim_hash = claim.reposted_claim_hash) as reposted_has_source, + (select cr.claim_type from claim cr where cr.claim_hash = claim.reposted_claim_hash) as reposted_claim_type, claim.* FROM claim LEFT JOIN claimtrie USING (claim_hash) WHERE claim.claim_hash in (SELECT claim_hash FROM changelog) diff --git a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py index 8d62777bd..01d3b4151 100644 --- a/tests/integration/blockchain/test_claim_commands.py +++ b/tests/integration/blockchain/test_claim_commands.py @@ -196,11 +196,11 @@ class ClaimSearchCommand(ClaimTestCase): normal = await self.stream_create('normal', data=b'normal') normal_repost = await self.stream_repost(self.get_claim_id(normal), 'normal-repost') no_source_repost = await self.stream_repost(self.get_claim_id(no_source), 'no-source-repost') - await self.assertFindsClaims([no_source_repost, no_source, channel], has_no_source=True) + await self.assertFindsClaims([channel_repost, no_source_repost, no_source, channel], has_no_source=True) await self.assertListsClaims([no_source, channel], has_no_source=True) - await self.assertFindsClaims([normal_repost, normal, channel], has_source=True) + await self.assertFindsClaims([channel_repost, normal_repost, normal, channel], has_source=True) await self.assertListsClaims([no_source_repost, normal_repost, normal], has_source=True) - await self.assertFindsClaims([no_source_repost, normal_repost, normal, no_source, channel]) + await self.assertFindsClaims([channel_repost, no_source_repost, normal_repost, normal, no_source, channel]) await self.assertListsClaims([no_source_repost, normal_repost, normal, no_source, channel]) async def test_pagination(self):