Merge pull request #3058 from lbryio/faster-resolve

faster resolve and claim_search
Jack Robison, 2020-11-05 21:09:07 -05:00, committed by GitHub
commit 511a5c3f82
6 changed files with 240 additions and 41 deletions

File 1 of 6

@@ -29,7 +29,6 @@ from .constants import TXO_TYPES, CLAIM_TYPES, COIN, NULL_HASH32
 from .bip32 import PubKey, PrivateKey
 from .coinselection import CoinSelector

 log = logging.getLogger(__name__)

 LedgerType = Type['BaseLedger']
@@ -479,12 +478,14 @@ class Ledger(metaclass=LedgerRegistry):
             for address, remote_status in zip(batch, results):
                 self._update_tasks.add(self.update_history(address, remote_status, address_manager))
             addresses_remaining = addresses_remaining[batch_size:]
-            log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
-                     len(addresses), *self.network.client.server_address_and_port)
-        log.info(
-            "finished subscribing to %i addresses on %s:%i", len(addresses),
-            *self.network.client.server_address_and_port
-        )
+            if self.network.client and self.network.client.server_address_and_port:
+                log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
+                         len(addresses), *self.network.client.server_address_and_port)
+        if self.network.client and self.network.client.server_address_and_port:
+            log.info(
+                "finished subscribing to %i addresses on %s:%i", len(addresses),
+                *self.network.client.server_address_and_port
+            )

     def process_status_update(self, update):
         address, remote_status = update
@@ -687,6 +688,59 @@ class Ledger(metaclass=LedgerRegistry):
             tx.position = merkle['pos']
             tx.is_verified = merkle_root == header['merkle_root']

+    async def request_transactions_for_inflate(self, to_request: Tuple[Tuple[str, int], ...], session_override=None):
+        header_cache = {}
+        batches = [[]]
+        remote_heights = {}
+        transactions = []
+        heights_in_batch = 0
+        last_height = 0
+
+        for txid, height in sorted(to_request, key=lambda x: x[1]):
+            remote_heights[txid] = height
+            if height != last_height:
+                heights_in_batch += 1
+                last_height = height
+            if len(batches[-1]) == 100 or heights_in_batch == 20:
+                batches.append([])
+                heights_in_batch = 1
+            batches[-1].append(txid)
+        if not batches[-1]:
+            batches.pop()
+
+        async def _single_batch(batch):
+            if session_override:
+                batch_result = await self.network.get_transaction_batch(
+                    batch, restricted=False, session=session_override
+                )
+            else:
+                batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch)
+            for txid, (raw, merkle) in batch_result.items():
+                remote_height = remote_heights[txid]
+                merkle_height = merkle['block_height']
+                cache_item = self._tx_cache.get(txid)
+                if cache_item is None:
+                    cache_item = TransactionCacheItem()
+                    self._tx_cache[txid] = cache_item
+                tx = cache_item.tx or Transaction(unhexlify(raw), height=remote_height)
+                tx.height = remote_height
+                cache_item.tx = tx
+                if 'merkle' in merkle and remote_heights[txid] > 0:
+                    merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
+                    try:
+                        header = header_cache.get(remote_heights[txid]) or (await self.headers.get(merkle_height))
+                    except IndexError:
+                        log.warning("failed to verify %s at height %i", tx.id, merkle_height)
+                    else:
+                        header_cache[remote_heights[txid]] = header
+                        tx.position = merkle['pos']
+                        tx.is_verified = merkle_root == header['merkle_root']
+                transactions.append(tx)
+
+        for batch in batches:
+            await _single_batch(batch)
+        return transactions
+
     async def _request_transaction_batch(self, to_request, remote_history_size, address):
         header_cache = {}
         batches = [[]]
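The batching loop above is the core of the speedup: txids are sorted by height, then split into batches that close once they hold 100 txids or span 20 distinct block heights, so each server round trip stays small and header lookups cluster. A standalone sketch of the same rule, with made-up txids:

# Sketch of the batching rule from request_transactions_for_inflate.
# The sample data at the bottom is hypothetical.
from typing import List, Tuple

def batch_txids(to_request: Tuple[Tuple[str, int], ...]) -> List[List[str]]:
    batches: List[List[str]] = [[]]
    heights_in_batch = 0
    last_height = 0
    for txid, height in sorted(to_request, key=lambda x: x[1]):
        if height != last_height:
            heights_in_batch += 1
            last_height = height
        # close the batch at 100 txids or 20 distinct heights
        if len(batches[-1]) == 100 or heights_in_batch == 20:
            batches.append([])
            heights_in_batch = 1
        batches[-1].append(txid)
    if not batches[-1]:
        batches.pop()
    return batches

# 250 txids at one height split as 100 + 100 + 50:
print([len(b) for b in batch_txids(tuple((f"tx{i}", 500) for i in range(250)))])  # [100, 100, 50]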
@@ -844,14 +898,17 @@ class Ledger(metaclass=LedgerRegistry):
                                include_is_my_output=False,
                                include_sent_supports=False,
                                include_sent_tips=False,
-                               include_received_tips=False) -> Tuple[List[Output], dict, int, int]:
+                               include_received_tips=False,
+                               session_override=None) -> Tuple[List[Output], dict, int, int]:
         encoded_outputs = await query
         outputs = Outputs.from_base64(encoded_outputs or b'')  # TODO: why is the server returning None?
         txs = []
         if len(outputs.txs) > 0:
-            txs: List[Transaction] = await asyncio.gather(*(
-                self.cache_transaction(*tx) for tx in outputs.txs
-            ))
+            txs: List[Transaction] = []
+            if session_override:
+                txs.extend((await self.request_transactions_for_inflate(tuple(outputs.txs), session_override)))
+            else:
+                txs.extend((await asyncio.gather(*(self.cache_transaction(*tx) for tx in outputs.txs))))

         _txos, blocked = outputs.inflate(txs)
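As the type hint on request_transactions_for_inflate indicates, each entry of outputs.txs is a (txid, height) pair; with a pinned session the whole set goes through the batched fetch above, otherwise each pair takes the per-transaction cache path. A hedged illustration of the two call shapes ('inflate_txs' itself is a hypothetical wrapper):

import asyncio

async def inflate_txs(ledger, tx_refs, session=None):
    # tx_refs: (txid, height) pairs, per the type hint above
    if session is not None:
        # pinned session: one batched, merkle-verified fetch
        return list(await ledger.request_transactions_for_inflate(tuple(tx_refs), session))
    # no pinned session: per-transaction, cache-aware path
    return list(await asyncio.gather(*(ledger.cache_transaction(*ref) for ref in tx_refs)))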
@@ -924,15 +981,28 @@ class Ledger(metaclass=LedgerRegistry):
         return txos, blocked, outputs.offset, outputs.total

     async def resolve(self, accounts, urls, new_sdk_server=None, **kwargs):
+        txos = []
+        urls_copy = list(urls)
         if new_sdk_server:
             resolve = partial(self.network.new_resolve, new_sdk_server)
+            while urls_copy:
+                batch, urls_copy = urls_copy[:500], urls_copy[500:]
+                txos.extend(
+                    (await self._inflate_outputs(
+                        resolve(batch), accounts, **kwargs
+                    ))[0]
+                )
         else:
-            resolve = partial(self.network.retriable_call, self.network.resolve)
-        urls_copy = list(urls)
-        txos = []
-        while urls_copy:
-            batch, urls_copy = urls_copy[:500], urls_copy[500:]
-            txos.extend((await self._inflate_outputs(resolve(batch), accounts, **kwargs))[0])
+            async with self.network.single_call_context(self.network.resolve) as (resolve, session):
+                while urls_copy:
+                    batch, urls_copy = urls_copy[:500], urls_copy[500:]
+                    txos.extend(
+                        (await self._inflate_outputs(
+                            resolve(batch), accounts, session_override=session, **kwargs
+                        ))[0]
+                    )
         assert len(urls) == len(txos), "Mismatch between urls requested for resolve and responses received."
         result = {}
         for url, txo in zip(urls, txos):
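Both branches drain urls_copy in slices of 500, so a large resolve becomes a handful of bounded requests instead of one huge one. The slicing idiom in isolation, with hypothetical urls:

# Hypothetical urls; shows the 500-per-request chunking used by resolve()
urls = [f"lbry://claim-{i}" for i in range(1200)]
urls_copy, batches = list(urls), []
while urls_copy:
    batch, urls_copy = urls_copy[:500], urls_copy[500:]
    batches.append(batch)
assert [len(b) for b in batches] == [500, 500, 200]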
@@ -953,13 +1023,17 @@ class Ledger(metaclass=LedgerRegistry):
                            new_sdk_server=None, **kwargs) -> Tuple[List[Output], dict, int, int]:
         if new_sdk_server:
             claim_search = partial(self.network.new_claim_search, new_sdk_server)
-        else:
-            claim_search = self.network.claim_search
-        return await self._inflate_outputs(
-            claim_search(**kwargs), accounts,
-            include_purchase_receipt=include_purchase_receipt,
-            include_is_my_output=include_is_my_output
-        )
+            return await self._inflate_outputs(
+                claim_search(**kwargs), accounts,
+                include_purchase_receipt=include_purchase_receipt,
+                include_is_my_output=include_is_my_output,
+            )
+        async with self.network.single_call_context(self.network.claim_search) as (claim_search, session):
+            return await self._inflate_outputs(
+                claim_search(**kwargs), accounts, session_override=session,
+                include_purchase_receipt=include_purchase_receipt,
+                include_is_my_output=include_is_my_output,
+            )

     async def get_claim_by_claim_id(self, accounts, claim_id, **kwargs) -> Output:
         for claim in (await self.claim_search(accounts, claim_id=claim_id, **kwargs))[0]:

File 2 of 6

@@ -3,6 +3,8 @@ import asyncio
 import json
 from time import perf_counter
 from operator import itemgetter
+from contextlib import asynccontextmanager
+from functools import partial
 from typing import Dict, Optional, Tuple

 import aiohttp
@@ -230,8 +232,8 @@ class Network:
     def is_connected(self):
         return self.client and not self.client.is_closing()

-    def rpc(self, list_or_method, args, restricted=True):
-        session = self.client if restricted else self.session_pool.fastest_session
+    def rpc(self, list_or_method, args, restricted=True, session=None):
+        session = session or (self.client if restricted else self.session_pool.fastest_session)
         if session and not session.is_closing():
             return session.send_request(list_or_method, args)
         else:
@@ -253,6 +255,28 @@ class Network:
                 pass
         raise asyncio.CancelledError()  # if we got here, we are shutting down

+    @asynccontextmanager
+    async def single_call_context(self, function, *args, **kwargs):
+        if not self.is_connected:
+            log.warning("Wallet server unavailable, waiting for it to come back and retry.")
+            await self.on_connected.first
+        await self.session_pool.wait_for_fastest_session()
+        server = self.session_pool.fastest_session.server
+        session = ClientSession(network=self, server=server)
+
+        async def call_with_reconnect(*a, **kw):
+            while self.running:
+                if not session.available:
+                    await session.create_connection()
+                try:
+                    return await partial(function, *args, session_override=session, **kwargs)(*a, **kw)
+                except asyncio.TimeoutError:
+                    log.warning("'%s' failed, retrying", function.__name__)
+        try:
+            yield (call_with_reconnect, session)
+        finally:
+            await session.close()
+
     def _update_remote_height(self, header_args):
         self.remote_height = header_args[0]["height"]
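single_call_context is the piece that ties the ledger changes together: it picks the fastest pooled server, opens a dedicated ClientSession to it, and yields a wrapper that reconnects and retries on timeout, closing the session when the block exits. A sketch of a hypothetical consumer (the batching of arguments is made up):

# Hypothetical consumer: every call in the loop reuses one dedicated
# session instead of competing for the shared client.
async def resolve_in_one_session(network, url_batches):
    results = []
    async with network.single_call_context(network.resolve) as (resolve, session):
        for batch in url_batches:
            # call_with_reconnect retries on asyncio.TimeoutError and
            # re-opens the connection if the session dropped
            results.append(await resolve(batch))
    return results  # the session is closed by the context manager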
@@ -261,9 +285,9 @@ class Network:
         restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
         return self.rpc('blockchain.transaction.get', [tx_hash], restricted)

-    def get_transaction_batch(self, txids):
+    def get_transaction_batch(self, txids, restricted=True, session=None):
         # use any server if its old, otherwise restrict to who gave us the history
-        return self.rpc('blockchain.transaction.get_batch', txids, True)
+        return self.rpc('blockchain.transaction.get_batch', txids, restricted, session)

     def get_transaction_and_merkle(self, tx_hash, known_height=None):
         # use any server if its old, otherwise restrict to who gave us the history
@@ -316,11 +340,11 @@ class Network:
     def get_claims_by_ids(self, claim_ids):
         return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)

-    def resolve(self, urls):
-        return self.rpc('blockchain.claimtrie.resolve', urls)
+    def resolve(self, urls, session_override=None):
+        return self.rpc('blockchain.claimtrie.resolve', urls, False, session_override)

-    def claim_search(self, **kwargs):
-        return self.rpc('blockchain.claimtrie.search', kwargs)
+    def claim_search(self, session_override=None, **kwargs):
+        return self.rpc('blockchain.claimtrie.search', kwargs, False, session_override)

     async def new_resolve(self, server, urls):
         message = {"method": "resolve", "params": {"urls": urls, "protobuf": True}}
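Because rpc() takes restricted and session positionally here, the claimtrie calls read a little tersely; spelled out with keywords, resolve() is equivalent to the following sketch (assuming the rpc signature from the hunk above):

    def resolve(self, urls, session_override=None):
        # restricted=False: any pooled server may answer,
        # unless a dedicated session is pinned via session_override
        return self.rpc('blockchain.claimtrie.resolve', urls,
                        restricted=False, session=session_override)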

File 3 of 6

@@ -51,8 +51,8 @@ MOST_USED_TAGS = {
     "tutorial",
     "video game",
     "weapons",
-    "pc",
     "playthrough",
+    "pc",
     "anime",
     "how to",
     "btc",
@@ -80,9 +80,9 @@ MOST_USED_TAGS = {
     "español",
     "money",
     "music video",
-    "nintendo",
     "movie",
     "coronavirus",
+    "nintendo",
     "donald trump",
     "steam",
     "trailer",
@@ -90,10 +90,10 @@ MOST_USED_TAGS = {
     "podcast",
     "xbox one",
     "survival",
-    "audio",
     "linux",
     "travel",
     "funny moments",
+    "audio",
     "litecoin",
     "animation",
     "gamer",
@@ -101,20 +101,120 @@ MOST_USED_TAGS = {
     "playstation",
     "bitcoin news",
     "history",
-    "fox news",
     "xxx",
-    "god",
+    "fox news",
     "dance",
+    "god",
     "adventure",
     "liberal",
-    "2020",
     "horror",
     "government",
     "freedom",
+    "2020",
     "reaction",
     "meme",
     "photography",
-    "truth"
+    "truth",
+    "health",
+    "lbry",
+    "family",
+    "online",
+    "eth",
+    "crypto news",
+    "diy",
+    "trading",
+    "gold",
+    "memes",
+    "world",
+    "space",
+    "lol",
+    "covid-19",
+    "rpg",
+    "humor",
+    "democrat",
+    "film",
+    "call of duty",
+    "tech",
+    "religion",
+    "conspiracy",
+    "rap",
+    "cnn",
+    "hangoutsonair",
+    "unboxing",
+    "fiction",
+    "conservative",
+    "cars",
+    "hoa",
+    "epic",
+    "programming",
+    "progressive",
+    "cryptocurrency news",
+    "classical",
+    "jesus",
+    "movies",
+    "book",
+    "ps3",
+    "republican",
+    "fitness",
+    "books",
+    "multiplayer",
+    "animals",
+    "pokemon",
+    "bitcoin price",
+    "facebook",
+    "sharefactory",
+    "criptomonedas",
+    "cod",
+    "bible",
+    "business",
+    "stream",
+    "comics",
+    "how",
+    "fail",
+    "nsfw",
+    "new music",
+    "satire",
+    "pets & animals",
+    "computer",
+    "classical music",
+    "indie",
+    "musica",
+    "msnbc",
+    "fps",
+    "mod",
+    "sport",
+    "sony",
+    "ripple",
+    "auto",
+    "rock",
+    "marvel",
+    "complete",
+    "mining",
+    "political",
+    "mobile",
+    "pubg",
+    "hip hop",
+    "flat earth",
+    "xbox 360",
+    "reviews",
+    "vlogging",
+    "latest news",
+    "hack",
+    "tarot",
+    "iphone",
+    "media",
+    "cute",
+    "christian",
+    "free speech",
+    "trap",
+    "war",
+    "remix",
+    "ios",
+    "xrp",
+    "spirituality",
+    "song",
+    "league of legends",
+    "cat"
 }

 MATURE_TAGS = [

File 4 of 6

@@ -1,6 +1,7 @@
 import curses
 import time
 import asyncio
+import lbry.wallet

 from lbry.conf import Config
 from lbry.extras.daemon.client import daemon_rpc

File 5 of 6

@@ -34,7 +34,7 @@ setup(
     },
     install_requires=[
         'aiohttp==3.5.4',
-        'aioupnp==0.0.17',
+        'aioupnp==0.0.18',
         'appdirs==1.4.3',
         'certifi>=2018.11.29',
         'colorama==0.3.7',

File 6 of 6

@@ -89,7 +89,7 @@ class ClaimSearchCommand(ClaimTestCase):
         # 23829 claim ids makes the request just large enough
         claim_ids = [
             '0000000000000000000000000000000000000000',
-        ] * 23829
+        ] * 33829
         with self.assertRaises(ConnectionResetError):
             await self.claim_search(claim_ids=claim_ids)
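Back-of-the-envelope, the raised multiplier presumably keeps this test comfortably past the server's request-size limit (note the context comment above still cites the old 23829 figure): 33,829 forty-character claim ids serialize to roughly 1.4 MiB of JSON. A quick check, assuming a plain JSON list of ids:

import json
# each id is 40 hex chars; quotes and separators add ~4 bytes per entry
payload = json.dumps(['0' * 40] * 33829)
print(f"{len(payload) / 1024 / 1024:.2f} MiB")  # ≈ 1.42 MiB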