diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py index 362111489..55e1bd0dd 100644 --- a/lbry/wallet/server/db/elasticsearch/search.py +++ b/lbry/wallet/server/db/elasticsearch/search.py @@ -429,6 +429,11 @@ def expand_query(**kwargs): query["minimum_should_match"] = 1 query['should'].append({"bool": {"must_not": {"exists": {"field": "signature_digest"}}}}) query['should'].append({"term": {"signature_valid": bool(kwargs["signature_valid"])}}) + if 'has_source' in kwargs: + query.setdefault('should', []) + query["minimum_should_match"] = 1 + query['should'].append({"bool": {"must": [{"match": {"has_source": kwargs['has_source']}}, {"match": {"claim_type": CLAIM_TYPES['stream']}}]}}) + query['should'].append({"bool": {"must_not": [{"match": {"claim_type": CLAIM_TYPES['stream']}}]}}) if kwargs.get('text'): query['must'].append( {"simple_query_string": diff --git a/lbry/wallet/server/db/elasticsearch/sync.py b/lbry/wallet/server/db/elasticsearch/sync.py index 645b7e758..0255c0c2a 100644 --- a/lbry/wallet/server/db/elasticsearch/sync.py +++ b/lbry/wallet/server/db/elasticsearch/sync.py @@ -28,12 +28,14 @@ SELECT claimtrie.claim_hash as is_controlling, claimtrie.last_take_over_height, (select group_concat(tag, ',,') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags, (select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages, + (select cr.has_source from claim cr where cr.claim_hash = claim.reposted_claim_hash) as reposted_has_source, claim.* FROM claim LEFT JOIN claimtrie USING (claim_hash) WHERE claim.height % {shards_total} = {shard_num} ORDER BY claim.height desc """)): claim = dict(claim._asdict()) + claim['has_source'] = bool(claim.pop('reposted_has_source') or claim['has_source']) claim['censor_type'] = 0 claim['censoring_channel_hash'] = None claim['tags'] = claim['tags'].split(',,') if claim['tags'] else [] diff --git a/lbry/wallet/server/db/writer.py b/lbry/wallet/server/db/writer.py index 80fc4b556..6f94a3c66 100644 --- a/lbry/wallet/server/db/writer.py +++ b/lbry/wallet/server/db/writer.py @@ -827,6 +827,7 @@ class SQLDB: claimtrie.last_take_over_height, (select group_concat(tag, ',,') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags, (select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages, + (select cr.has_source from claim cr where cr.claim_hash = claim.reposted_claim_hash) as reposted_has_source, claim.* FROM claim LEFT JOIN claimtrie USING (claim_hash) WHERE claim.claim_hash in (SELECT claim_hash FROM changelog) @@ -835,6 +836,7 @@ class SQLDB: id_set = set(filter(None, (claim['claim_hash'], claim['channel_hash'], claim['reposted_claim_hash']))) claim['censor_type'] = 0 claim['censoring_channel_hash'] = None + claim['has_source'] = bool(claim.pop('reposted_has_source') or claim['has_source']) for reason_id in id_set: if reason_id in self.blocked_streams: claim['censor_type'] = 2 diff --git a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py index ab373e702..a1981614f 100644 --- a/tests/integration/blockchain/test_claim_commands.py +++ b/tests/integration/blockchain/test_claim_commands.py @@ -180,11 +180,13 @@ class ClaimSearchCommand(ClaimTestCase): await self.assertFindsClaims([three], claim_id=self.get_claim_id(three), text='*') async def test_source_filter(self): - no_source = await self.stream_create('no_source', data=None) + no_source = await self.stream_create('no-source', data=None) normal = await self.stream_create('normal', data=b'normal') - await self.assertFindsClaims([no_source], has_no_source=True) - await self.assertFindsClaims([normal], has_source=True) - await self.assertFindsClaims([normal, no_source]) + normal_repost = await self.stream_repost(self.get_claim_id(normal), 'normal-repost') + no_source_repost = await self.stream_repost(self.get_claim_id(no_source), 'no-source-repost') + await self.assertFindsClaims([no_source_repost, no_source], has_no_source=True) + await self.assertFindsClaims([normal_repost, normal], has_source=True) + await self.assertFindsClaims([no_source_repost, normal_repost, normal, no_source]) async def test_pagination(self): await self.create_channel()