From 9924b7b43830b61d5f05e3ba474e368473545ef4 Mon Sep 17 00:00:00 2001
From: Victor Shyba <victor.shyba@gmail.com>
Date: Tue, 19 Jan 2021 20:38:03 -0300
Subject: [PATCH] reposts and tag inheritance

---
 lbry/wallet/orchstr8/node.py                      |  1 +
 lbry/wallet/server/db/elastic_search.py           | 15 ++++++++-------
 lbry/wallet/server/db/writer.py                   |  7 ++++---
 .../integration/blockchain/test_claim_commands.py |  3 ++-
 4 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/lbry/wallet/orchstr8/node.py b/lbry/wallet/orchstr8/node.py
index 53323a6df..d15e32d5d 100644
--- a/lbry/wallet/orchstr8/node.py
+++ b/lbry/wallet/orchstr8/node.py
@@ -202,6 +202,7 @@ class SPVNode:
     async def stop(self, cleanup=True):
         try:
             await self.server.db.search_index.delete_index()
+            await self.server.db.search_index.stop()
             await self.server.stop()
         finally:
             cleanup and self.cleanup()
diff --git a/lbry/wallet/server/db/elastic_search.py b/lbry/wallet/server/db/elastic_search.py
index dbf752dab..15e0bcce4 100644
--- a/lbry/wallet/server/db/elastic_search.py
+++ b/lbry/wallet/server/db/elastic_search.py
@@ -187,11 +187,12 @@ FIELDS = ['is_controlling', 'last_take_over_height', 'claim_id', 'claim_name', '
           'stream_type', 'media_type', 'fee_amount', 'fee_currency', 'duration', 'reposted_claim_hash',
           'claims_in_channel', 'channel_join', 'signature_valid', 'effective_amount', 'support_amount',
           'trending_group', 'trending_mixed', 'trending_local', 'trending_global', 'channel_id', 'tx_id', 'tx_nout',
-          'signature', 'signature_digest', 'public_key_bytes', 'public_key_hash', 'public_key_id', '_id', 'tags']
+          'signature', 'signature_digest', 'public_key_bytes', 'public_key_hash', 'public_key_id', '_id', 'tags',
+          'reposted_claim_id']
 TEXT_FIELDS = ['author', 'canonical_url', 'channel_id', 'claim_id', 'claim_name', 'description',
                'media_type', 'normalized', 'public_key_bytes', 'public_key_hash', 'short_url', 'signature',
-               'signature_digest', 'stream_type', 'title', 'tx_id', 'fee_currency']
-RANGE_FIELDS = ['height', 'fee_amount', 'duration']
+               'signature_digest', 'stream_type', 'title', 'tx_id', 'fee_currency', 'reposted_claim_id', 'tags']
+RANGE_FIELDS = ['height', 'fee_amount', 'duration', 'reposted']
 REPLACEMENTS = {
     'name': 'claim_name',
     'txid': 'tx_id',
@@ -204,7 +205,7 @@ def expand_query(**kwargs):
     collapse = None
     for key, value in kwargs.items():
         key = key.replace('claim.', '')
-        many = key.endswith('__in')
+        many = key.endswith('__in') or isinstance(value, list)
         if many:
             key = key.replace('__in', '')
         key = REPLACEMENTS.get(key, key)
@@ -256,11 +257,11 @@ def expand_query(**kwargs):
         elif key == 'all_languages':
             query['must'].extend([{"term": {'languages': tag}} for tag in value])
         elif key == 'any_tags':
-            query['must'].append({"terms": {'tags': clean_tags(value)}})
+            query['must'].append({"terms": {'tags.keyword': clean_tags(value)}})
         elif key == 'all_tags':
-            query['must'].extend([{"term": {'tags': tag}} for tag in clean_tags(value)])
+            query['must'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)])
         elif key == 'not_tags':
-            query['must_not'].extend([{"term": {'tags': tag}} for tag in clean_tags(value)])
+            query['must_not'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)])
         elif key == 'limit_claims_per_channel':
             collapse = ('channel_id.keyword', value)
     if kwargs.get('has_channel_signature'):
diff --git a/lbry/wallet/server/db/writer.py b/lbry/wallet/server/db/writer.py
index 9794865f9..3f144613d 100644
--- a/lbry/wallet/server/db/writer.py
+++ b/lbry/wallet/server/db/writer.py
@@ -532,6 +532,7 @@ class SQLDB:
                 WHERE claim_hash = ?
                 """, targets
             )
+        return set(target[0] for target in targets)
 
     def validate_channel_signatures(self, height, new_claims, updated_claims, spent_claims, affected_channels, timer):
         if not new_claims and not updated_claims and not spent_claims:
@@ -828,7 +829,7 @@ class SQLDB:
         WHERE claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})
         """, changed_claim_hashes):
             claim = dict(claim._asdict())
-            claim['tags'] = tags.get(claim['claim_hash'], [])
+            claim['tags'] = tags.get(claim['claim_hash']) or tags.get(claim['reposted_claim_hash'])
             claim['languages'] = langs.get(claim['claim_hash'], [])
             if not self.claim_queue.full():
                 self.claim_queue.put_nowait(('update', claim))
@@ -914,7 +915,7 @@ class SQLDB:
         affected_channels = r(self.delete_claims, delete_claim_hashes)
         r(self.delete_supports, delete_support_txo_hashes)
         r(self.insert_claims, insert_claims, header)
-        r(self.calculate_reposts, insert_claims)
+        reposted = r(self.calculate_reposts, insert_claims)
         r(update_full_text_search, 'after-insert',
           [txo.claim_hash for txo in insert_claims], self.db.cursor(), self.main.first_sync)
         r(update_full_text_search, 'before-update',
@@ -931,7 +932,7 @@ class SQLDB:
         if not self._fts_synced and self.main.first_sync and height == daemon_height:
             r(first_sync_finished, self.db.cursor())
             self._fts_synced = True
-        r(self.enqueue_changes, recalculate_claim_hashes | affected_channels, delete_claim_hashes)
+        r(self.enqueue_changes, recalculate_claim_hashes | affected_channels | reposted, delete_claim_hashes)
 
 
 class LBRYLevelDB(LevelDB):
diff --git a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py
index 1bad0e0d8..47728dc66 100644
--- a/tests/integration/blockchain/test_claim_commands.py
+++ b/tests/integration/blockchain/test_claim_commands.py
@@ -73,7 +73,8 @@ class ClaimSearchCommand(ClaimTestCase):
         for claim, result in zip(claims, results):
             self.assertEqual(
                 (claim['txid'], self.get_claim_id(claim)),
-                (result['txid'], result['claim_id'])
+                (result['txid'], result['claim_id']),
+                f"{claim['outputs'][0]['name']} != {result['name']}"
             )
 
     @skip("doesnt happen on ES...?")