diff --git a/lbry/wallet/server/leveldb.py b/lbry/wallet/server/leveldb.py
index bc48be52a..44b38fced 100644
--- a/lbry/wallet/server/leveldb.py
+++ b/lbry/wallet/server/leveldb.py
@@ -58,6 +58,13 @@ TXO_STRUCT_pack = TXO_STRUCT.pack
 OptionalResolveResultOrError = Optional[typing.Union[ResolveResult, ResolveCensoredError, LookupError, ValueError]]
 
 
+class ExpandedResolveResult(typing.NamedTuple):
+    stream: OptionalResolveResultOrError
+    channel: OptionalResolveResultOrError
+    repost: OptionalResolveResultOrError
+    reposted_channel: OptionalResolveResultOrError
+
+
 class DBError(Exception):
     """Raised on general DB errors generally indicating corruption."""
 
@@ -228,8 +235,8 @@ class LevelDB:
             signature_valid=None if not channel_hash else signature_valid
         )
 
-    def _resolve(self, name: str, claim_id: Optional[str] = None,
-                 amount_order: Optional[int] = None) -> Optional[ResolveResult]:
+    def _resolve_parsed_url(self, name: str, claim_id: Optional[str] = None,
+                            amount_order: Optional[int] = None) -> Optional[ResolveResult]:
         """
         :param normalized_name: name
         :param claim_id: partial or complete claim id
@@ -296,12 +303,11 @@ class LevelDB:
                 return
             return list(sorted(candidates, key=lambda item: item[1]))[0]
 
-    def _fs_resolve(self, url) -> typing.Tuple[OptionalResolveResultOrError, OptionalResolveResultOrError,
-                                               OptionalResolveResultOrError]:
+    def _resolve(self, url) -> ExpandedResolveResult:
         try:
             parsed = URL.parse(url)
         except ValueError as e:
-            return e, None, None
+            return ExpandedResolveResult(e, None, None, None)
 
         stream = channel = resolved_channel = resolved_stream = None
         if parsed.has_stream_in_channel:
@@ -312,9 +318,9 @@ class LevelDB:
         elif parsed.has_stream:
             stream = parsed.stream
         if channel:
-            resolved_channel = self._resolve(channel.name, channel.claim_id, channel.amount_order)
+            resolved_channel = self._resolve_parsed_url(channel.name, channel.claim_id, channel.amount_order)
             if not resolved_channel:
-                return None, LookupError(f'Could not find channel in "{url}".'), None
+                return ExpandedResolveResult(None, LookupError(f'Could not find channel in "{url}".'), None, None)
         if stream:
             if resolved_channel:
                 stream_claim = self._resolve_claim_in_channel(resolved_channel.claim_hash, stream.normalized)
@@ -322,13 +328,14 @@ class LevelDB:
                     stream_claim_id, stream_tx_num, stream_tx_pos, effective_amount = stream_claim
                     resolved_stream = self._fs_get_claim_by_hash(stream_claim_id)
             else:
-                resolved_stream = self._resolve(stream.name, stream.claim_id, stream.amount_order)
+                resolved_stream = self._resolve_parsed_url(stream.name, stream.claim_id, stream.amount_order)
                 if not channel and not resolved_channel and resolved_stream and resolved_stream.channel_hash:
                     resolved_channel = self._fs_get_claim_by_hash(resolved_stream.channel_hash)
             if not resolved_stream:
-                return LookupError(f'Could not find claim at "{url}".'), None, None
+                return ExpandedResolveResult(LookupError(f'Could not find claim at "{url}".'), None, None, None)
 
         repost = None
+        reposted_channel = None
         if resolved_stream or resolved_channel:
             claim_hash = resolved_stream.claim_hash if resolved_stream else resolved_channel.claim_hash
             claim = resolved_stream if resolved_stream else resolved_channel
@@ -338,14 +345,17 @@ class LevelDB:
                 reposted_claim_hash) or self.blocked_channels.get(claim.channel_hash)
             if blocker_hash:
                 reason_row = self._fs_get_claim_by_hash(blocker_hash)
-                return None, ResolveCensoredError(url, blocker_hash, censor_row=reason_row), None
+                return ExpandedResolveResult(
+                    None, ResolveCensoredError(url, blocker_hash, censor_row=reason_row), None, None
+                )
             if claim.reposted_claim_hash:
                 repost = self._fs_get_claim_by_hash(claim.reposted_claim_hash)
-        return resolved_stream, resolved_channel, repost
+                if repost and repost.channel_hash and repost.signature_valid:
+                    reposted_channel = self._fs_get_claim_by_hash(repost.channel_hash)
+        return ExpandedResolveResult(resolved_stream, resolved_channel, repost, reposted_channel)
 
-    async def fs_resolve(self, url) -> typing.Tuple[OptionalResolveResultOrError, OptionalResolveResultOrError,
-                                                    OptionalResolveResultOrError]:
-        return await asyncio.get_event_loop().run_in_executor(None, self._fs_resolve, url)
+    async def resolve(self, url) -> ExpandedResolveResult:
+        return await asyncio.get_event_loop().run_in_executor(None, self._resolve, url)
 
     def _fs_get_claim_by_hash(self, claim_hash):
         claim = self.claim_to_txo.get(claim_hash)
diff --git a/lbry/wallet/server/session.py b/lbry/wallet/server/session.py
index 3983756be..c51fc76e4 100644
--- a/lbry/wallet/server/session.py
+++ b/lbry/wallet/server/session.py
@@ -1019,7 +1019,7 @@ class LBRYElectrumX(SessionBase):
         self.session_mgr.pending_query_metric.inc()
         if 'channel' in kwargs:
             channel_url = kwargs.pop('channel')
-            _, channel_claim, _ = await self.db.fs_resolve(channel_url)
+            _, channel_claim, _, _ = await self.db.resolve(channel_url)
             if not channel_claim or isinstance(channel_claim, (ResolveCensoredError, LookupError, ValueError)):
                 return Outputs.to_base64([], [], 0, None, None)
             kwargs['channel_id'] = channel_claim.claim_hash.hex()
@@ -1036,12 +1036,11 @@ class LBRYElectrumX(SessionBase):
             self.session_mgr.pending_query_metric.dec()
             self.session_mgr.executor_time_metric.observe(time.perf_counter() - start)
 
-    async def claimtrie_resolve(self, *urls):
+    def _claimtrie_resolve(self, *urls):
         rows, extra = [], []
         for url in urls:
             self.session_mgr.urls_to_resolve_count_metric.inc()
-            stream, channel, repost = await self.db.fs_resolve(url)
-            self.session_mgr.resolved_url_count_metric.inc()
+            stream, channel, repost, reposted_channel = self.db._resolve(url)
             if isinstance(channel, ResolveCensoredError):
                 rows.append(channel)
                 extra.append(channel.censor_row)
@@ -1053,6 +1052,8 @@ class LBRYElectrumX(SessionBase):
                 # print("resolved channel", channel.name.decode())
                 if repost:
                     extra.append(repost)
+                if reposted_channel:
+                    extra.append(reposted_channel)
             elif stream:
                 # print("resolved stream", stream.name.decode())
                 rows.append(stream)
@@ -1061,9 +1062,16 @@ class LBRYElectrumX(SessionBase):
                     extra.append(channel)
                 if repost:
                     extra.append(repost)
+                if reposted_channel:
+                    extra.append(reposted_channel)
         # print("claimtrie resolve %i rows %i extrat" % (len(rows), len(extra)))
         return Outputs.to_base64(rows, extra, 0, None, None)
 
+    async def claimtrie_resolve(self, *urls):
+        result = await self.loop.run_in_executor(None, self._claimtrie_resolve, *urls)
+        self.session_mgr.resolved_url_count_metric.inc(len(urls))
+        return result
+
     async def get_server_height(self):
         return self.bp.height
 
diff --git a/tests/integration/takeovers/test_resolve_command.py b/tests/integration/takeovers/test_resolve_command.py
index 6899d0974..e1e1d18f7 100644
--- a/tests/integration/takeovers/test_resolve_command.py
+++ b/tests/integration/takeovers/test_resolve_command.py
@@ -35,7 +35,7 @@ class BaseResolveTestCase(CommandTestCase):
 
     async def assertNoClaimForName(self, name: str):
         lbrycrd_winning = json.loads(await self.blockchain._cli_cmnd('getvalueforname', name))
-        stream, channel, _ = await self.conductor.spv_node.server.bp.db.fs_resolve(name)
+        stream, channel, _, _ = await self.conductor.spv_node.server.bp.db.resolve(name)
         self.assertNotIn('claimId', lbrycrd_winning)
         if stream is not None:
             self.assertIsInstance(stream, LookupError)
@@ -55,7 +55,7 @@ class BaseResolveTestCase(CommandTestCase):
 
     async def assertMatchWinningClaim(self, name):
         expected = json.loads(await self.blockchain._cli_cmnd('getvalueforname', name))
-        stream, channel, _ = await self.conductor.spv_node.server.bp.db.fs_resolve(name)
+        stream, channel, _, _ = await self.conductor.spv_node.server.bp.db.resolve(name)
         claim = stream if stream else channel
         claim_from_es = await self.conductor.spv_node.server.bp.db.search_index.search(
             claim_id=claim.claim_hash.hex()
@@ -983,7 +983,7 @@ class ResolveClaimTakeovers(BaseResolveTestCase):
         await self.generate(32 * 10 - 1)
         self.assertEqual(1120, self.conductor.spv_node.server.bp.db.db_height)
         claim_id_B = (await self.stream_create(name, '20.0', allow_duplicate_name=True))['outputs'][0]['claim_id']
-        claim_B, _, _ = await self.conductor.spv_node.server.bp.db.fs_resolve(f"{name}:{claim_id_B}")
+        claim_B, _, _, _ = await self.conductor.spv_node.server.bp.db.resolve(f"{name}:{claim_id_B}")
         self.assertEqual(1121, self.conductor.spv_node.server.bp.db.db_height)
         self.assertEqual(1131, claim_B.activation_height)
         await self.assertMatchClaimIsWinning(name, claim_id_A)
@@ -1000,7 +1000,7 @@ class ResolveClaimTakeovers(BaseResolveTestCase):
         # State: A(10+14) is controlling, B(20) is accepted, C(50) is accepted.
         claim_id_C = (await self.stream_create(name, '50.0', allow_duplicate_name=True))['outputs'][0]['claim_id']
         self.assertEqual(1123, self.conductor.spv_node.server.bp.db.db_height)
-        claim_C, _, _ = await self.conductor.spv_node.server.bp.db.fs_resolve(f"{name}:{claim_id_C}")
+        claim_C, _, _, _ = await self.conductor.spv_node.server.bp.db.resolve(f"{name}:{claim_id_C}")
         self.assertEqual(1133, claim_C.activation_height)
         await self.assertMatchClaimIsWinning(name, claim_id_A)
 
@@ -1018,7 +1018,7 @@ class ResolveClaimTakeovers(BaseResolveTestCase):
         # State: A(10+14) is controlling, B(20) is active, C(50) is accepted, D(300) is accepted.
         claim_id_D = (await self.stream_create(name, '300.0', allow_duplicate_name=True))['outputs'][0]['claim_id']
         self.assertEqual(1132, self.conductor.spv_node.server.bp.db.db_height)
-        claim_D, _, _ = await self.conductor.spv_node.server.bp.db.fs_resolve(f"{name}:{claim_id_D}")
+        claim_D, _, _, _ = await self.conductor.spv_node.server.bp.db.resolve(f"{name}:{claim_id_D}")
         self.assertEqual(False, claim_D.is_controlling)
         self.assertEqual(801, claim_D.last_takeover_height)
         self.assertEqual(1142, claim_D.activation_height)
@@ -1028,7 +1028,7 @@ class ResolveClaimTakeovers(BaseResolveTestCase):
         # State: A(10+14) is active, B(20) is active, C(50) is active, D(300) is controlling
         await self.generate(1)
         self.assertEqual(1133, self.conductor.spv_node.server.bp.db.db_height)
-        claim_D, _, _ = await self.conductor.spv_node.server.bp.db.fs_resolve(f"{name}:{claim_id_D}")
+        claim_D, _, _, _ = await self.conductor.spv_node.server.bp.db.resolve(f"{name}:{claim_id_D}")
         self.assertEqual(True, claim_D.is_controlling)
         self.assertEqual(1133, claim_D.last_takeover_height)
         self.assertEqual(1133, claim_D.activation_height)
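
For reference, the new return type groups the four resolve results into a single named tuple, so the existing tuple-unpacking call sites in session.py and the tests keep working while new code can use attribute access. The standalone sketch below illustrates both styles; it is not part of the patch, and the simplified stand-in for the OptionalResolveResultOrError alias is an assumption made only to keep the example self-contained.

import typing

# Simplified stand-in for the real alias defined in leveldb.py (assumption for this sketch).
OptionalResolveResultOrError = typing.Optional[typing.Any]


class ExpandedResolveResult(typing.NamedTuple):
    stream: OptionalResolveResultOrError
    channel: OptionalResolveResultOrError
    repost: OptionalResolveResultOrError
    reposted_channel: OptionalResolveResultOrError


# Tuple-style unpacking, matching the call sites touched by this diff:
stream, channel, repost, reposted_channel = ExpandedResolveResult(None, None, None, None)

# Attribute access also works, since NamedTuple instances carry field names:
result = ExpandedResolveResult(stream=None, channel=None, repost=None, reposted_channel=None)
assert result.reposted_channel is None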