use single_call_context for claim_search and resolve

commit 76946c447f
parent 2faa29b1c4
Author: Jack Robison
Date:   2020-10-18 21:02:19 -04:00
GPG signature: no known key found for this signature in database (key ID: DF25C68FE0239BB2)

2 changed files with 124 additions and 28 deletions

diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py

@@ -29,7 +29,6 @@ from .constants import TXO_TYPES, CLAIM_TYPES, COIN, NULL_HASH32
 from .bip32 import PubKey, PrivateKey
 from .coinselection import CoinSelector

 log = logging.getLogger(__name__)

 LedgerType = Type['BaseLedger']
@@ -687,6 +686,59 @@ class Ledger(metaclass=LedgerRegistry):
                         tx.position = merkle['pos']
                         tx.is_verified = merkle_root == header['merkle_root']
+
+    async def request_transactions_for_inflate(self, to_request: Tuple[Tuple[str, int], ...], session_override=None):
+        header_cache = {}
+        batches = [[]]
+        remote_heights = {}
+        transactions = []
+        heights_in_batch = 0
+        last_height = 0
+
+        for txid, height in sorted(to_request, key=lambda x: x[1]):
+            remote_heights[txid] = height
+            if height != last_height:
+                heights_in_batch += 1
+                last_height = height
+            if len(batches[-1]) == 100 or heights_in_batch == 20:
+                batches.append([])
+                heights_in_batch = 1
+            batches[-1].append(txid)
+        if not batches[-1]:
+            batches.pop()
+
+        async def _single_batch(batch):
+            if session_override:
+                batch_result = await self.network.get_transaction_batch(
+                    batch, restricted=False, session=session_override
+                )
+            else:
+                batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch)
+            for txid, (raw, merkle) in batch_result.items():
+                remote_height = remote_heights[txid]
+                merkle_height = merkle['block_height']
+                cache_item = self._tx_cache.get(txid)
+                if cache_item is None:
+                    cache_item = TransactionCacheItem()
+                    self._tx_cache[txid] = cache_item
+                tx = cache_item.tx or Transaction(unhexlify(raw), height=remote_height)
+                tx.height = remote_height
+                cache_item.tx = tx
+                if 'merkle' in merkle and remote_heights[txid] > 0:
+                    merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
+                    try:
+                        header = header_cache.get(remote_heights[txid]) or (await self.headers.get(merkle_height))
+                    except IndexError:
+                        log.warning("failed to verify %s at height %i", tx.id, merkle_height)
+                    else:
+                        header_cache[remote_heights[txid]] = header
+                        tx.position = merkle['pos']
+                        tx.is_verified = merkle_root == header['merkle_root']
+                transactions.append(tx)
+
+        for batch in batches:
+            await _single_batch(batch)
+        return transactions
     async def _request_transaction_batch(self, to_request, remote_history_size, address):
         header_cache = {}
         batches = [[]]
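
The new request_transactions_for_inflate mirrors _request_transaction_batch, but lets a caller pin every batch to one session. Its grouping rule caps each RPC at 100 txids or 20 distinct block heights per batch. A minimal standalone sketch of just that rule, with hypothetical sample data:

    # Sketch of the grouping rule above: start a new batch once the current one
    # holds 100 txids or spans 20 distinct block heights.
    def group_into_batches(to_request, max_txids=100, max_heights=20):
        batches, heights_in_batch, last_height = [[]], 0, 0
        for txid, height in sorted(to_request, key=lambda x: x[1]):
            if height != last_height:
                heights_in_batch += 1
                last_height = height
            if len(batches[-1]) == max_txids or heights_in_batch == max_heights:
                batches.append([])
                heights_in_batch = 1
            batches[-1].append(txid)
        if not batches[-1]:
            batches.pop()
        return batches

    print(len(group_into_batches([(f"tx{i}", i) for i in range(50)])))  # hypothetical data -> 3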
@ -844,14 +896,17 @@ class Ledger(metaclass=LedgerRegistry):
include_is_my_output=False, include_is_my_output=False,
include_sent_supports=False, include_sent_supports=False,
include_sent_tips=False, include_sent_tips=False,
include_received_tips=False) -> Tuple[List[Output], dict, int, int]: include_received_tips=False,
session_override=None) -> Tuple[List[Output], dict, int, int]:
encoded_outputs = await query encoded_outputs = await query
outputs = Outputs.from_base64(encoded_outputs or b'') # TODO: why is the server returning None? outputs = Outputs.from_base64(encoded_outputs or b'') # TODO: why is the server returning None?
txs = [] txs = []
if len(outputs.txs) > 0: if len(outputs.txs) > 0:
txs: List[Transaction] = await asyncio.gather(*( txs: List[Transaction] = []
self.cache_transaction(*tx) for tx in outputs.txs if session_override:
)) txs.extend((await self.request_transactions_for_inflate(tuple(outputs.txs), session_override)))
else:
txs.extend((await asyncio.gather(*(self.cache_transaction(*tx) for tx in outputs.txs))))
_txos, blocked = outputs.inflate(txs) _txos, blocked = outputs.inflate(txs)
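
_inflate_outputs now dispatches on session_override: with a pinned session the whole set of (txid, height) pairs goes through the new bulk path, otherwise each transaction is cached individually and gathered concurrently, as before. A minimal sketch of the dispatch, with hypothetical names:

    import asyncio

    # `fetch_txs` and `tx_refs` are hypothetical; tx_refs is a sequence of
    # (txid, height) pairs like outputs.txs above.
    async def fetch_txs(ledger, tx_refs, session=None):
        if session:  # pinned session: one bulk, batched pass over all txids
            return await ledger.request_transactions_for_inflate(tuple(tx_refs), session)
        # default: cache each transaction individually, gathered concurrently
        return await asyncio.gather(*(ledger.cache_transaction(*tx) for tx in tx_refs))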
@@ -924,15 +979,28 @@ class Ledger(metaclass=LedgerRegistry):
         return txos, blocked, outputs.offset, outputs.total

     async def resolve(self, accounts, urls, new_sdk_server=None, **kwargs):
-        if new_sdk_server:
-            resolve = partial(self.network.new_resolve, new_sdk_server)
-        else:
-            resolve = partial(self.network.retriable_call, self.network.resolve)
-        urls_copy = list(urls)
         txos = []
-        while urls_copy:
-            batch, urls_copy = urls_copy[:500], urls_copy[500:]
-            txos.extend((await self._inflate_outputs(resolve(batch), accounts, **kwargs))[0])
+        urls_copy = list(urls)
+        if new_sdk_server:
+            resolve = partial(self.network.new_resolve, new_sdk_server)
+            while urls_copy:
+                batch, urls_copy = urls_copy[:500], urls_copy[500:]
+                txos.extend(
+                    (await self._inflate_outputs(
+                        resolve(batch), accounts, **kwargs
+                    ))[0]
+                )
+        else:
+            async with self.network.single_call_context(self.network.resolve) as (resolve, session):
+                while urls_copy:
+                    batch, urls_copy = urls_copy[:500], urls_copy[500:]
+                    txos.extend(
+                        (await self._inflate_outputs(
+                            resolve(batch), accounts, session_override=session, **kwargs
+                        ))[0]
+                    )
+
         assert len(urls) == len(txos), "Mismatch between urls requested for resolve and responses received."
         result = {}
         for url, txo in zip(urls, txos):
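
resolve() still pages through URLs 500 at a time in both branches; the slicing idiom, shown standalone with hypothetical data:

    # The 500-per-request paging used by resolve(), isolated:
    urls_copy = [f"lbry://claim-{i}" for i in range(1200)]  # hypothetical URLs
    while urls_copy:
        batch, urls_copy = urls_copy[:500], urls_copy[500:]
        print(len(batch))  # prints 500, 500, 200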
@@ -953,13 +1021,17 @@ class Ledger(metaclass=LedgerRegistry):
                            new_sdk_server=None, **kwargs) -> Tuple[List[Output], dict, int, int]:
         if new_sdk_server:
             claim_search = partial(self.network.new_claim_search, new_sdk_server)
-        else:
-            claim_search = self.network.claim_search
-        return await self._inflate_outputs(
-            claim_search(**kwargs), accounts,
-            include_purchase_receipt=include_purchase_receipt,
-            include_is_my_output=include_is_my_output
-        )
+            return await self._inflate_outputs(
+                claim_search(**kwargs), accounts,
+                include_purchase_receipt=include_purchase_receipt,
+                include_is_my_output=include_is_my_output,
+            )
+        async with self.network.single_call_context(self.network.claim_search) as (claim_search, session):
+            return await self._inflate_outputs(
+                claim_search(**kwargs), accounts, session_override=session,
+                include_purchase_receipt=include_purchase_receipt,
+                include_is_my_output=include_is_my_output,
+            )

     async def get_claim_by_claim_id(self, accounts, claim_id, **kwargs) -> Output:
         for claim in (await self.claim_search(accounts, claim_id=claim_id, **kwargs))[0]:
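
Caller-side nothing changes: kwargs still flow through to the search RPC, and the dedicated session is now managed inside claim_search itself. A hypothetical usage sketch (parameter names are illustrative, not taken from this diff):

    async def search_example(ledger, accounts):
        # `channel` and `order_by` are example claim-search parameters
        txos, blocked, offset, total = await ledger.claim_search(
            accounts, channel="@example", order_by=["height"]
        )
        return txos, total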

diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py

@@ -3,6 +3,8 @@ import asyncio
 import json
 from time import perf_counter
 from operator import itemgetter
+from contextlib import asynccontextmanager
+from functools import partial
 from typing import Dict, Optional, Tuple

 import aiohttp
@@ -230,8 +232,8 @@ class Network:
     def is_connected(self):
         return self.client and not self.client.is_closing()

-    def rpc(self, list_or_method, args, restricted=True):
-        session = self.client if restricted else self.session_pool.fastest_session
+    def rpc(self, list_or_method, args, restricted=True, session=None):
+        session = session or (self.client if restricted else self.session_pool.fastest_session)
         if session and not session.is_closing():
             return session.send_request(list_or_method, args)
         else:
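
The one-line change to rpc() puts an explicit-session override ahead of the old fallback chain. Isolated as a sketch with hypothetical names:

    # Session selection in rpc(), isolated (hypothetical helper):
    def pick_session(explicit, client, fastest_pooled, restricted=True):
        # an explicitly pinned session wins; otherwise restricted calls use the
        # wallet's own client, unrestricted ones the fastest pooled session
        return explicit or (client if restricted else fastest_pooled)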
@@ -253,6 +255,28 @@ class Network:
             pass
         raise asyncio.CancelledError()  # if we got here, we are shutting down

+    @asynccontextmanager
+    async def single_call_context(self, function, *args, **kwargs):
+        if not self.is_connected:
+            log.warning("Wallet server unavailable, waiting for it to come back and retry.")
+            await self.on_connected.first
+        await self.session_pool.wait_for_fastest_session()
+        server = self.session_pool.fastest_session.server
+        session = ClientSession(network=self, server=server)
+
+        async def call_with_reconnect(*a, **kw):
+            while self.running:
+                if not session.available:
+                    await session.create_connection()
+                try:
+                    return await partial(function, *args, session_override=session, **kwargs)(*a, **kw)
+                except asyncio.TimeoutError:
+                    log.warning("'%s' failed, retrying", function.__name__)
+
+        try:
+            yield (call_with_reconnect, session)
+        finally:
+            await session.close()
     def _update_remote_height(self, header_args):
         self.remote_height = header_args[0]["height"]
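
single_call_context pins one dedicated ClientSession to the fastest known server for the lifetime of the block, retries timed-out calls after reconnecting, and closes the session even if the caller raises. A hypothetical usage sketch:

    # Hypothetical caller: every call through `resolve` reuses one session.
    async def resolve_in_batches(network, url_batches):
        results = []
        async with network.single_call_context(network.resolve) as (resolve, session):
            for batch in url_batches:
                results.append(await resolve(batch))
        return results  # the dedicated session is closed on exit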
@@ -261,9 +285,9 @@ class Network:
         restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
         return self.rpc('blockchain.transaction.get', [tx_hash], restricted)

-    def get_transaction_batch(self, txids):
+    def get_transaction_batch(self, txids, restricted=True, session=None):
         # use any server if its old, otherwise restrict to who gave us the history
-        return self.rpc('blockchain.transaction.get_batch', txids, True)
+        return self.rpc('blockchain.transaction.get_batch', txids, restricted, session)

     def get_transaction_and_merkle(self, tx_hash, known_height=None):
         # use any server if its old, otherwise restrict to who gave us the history
@@ -316,11 +340,11 @@ class Network:
     def get_claims_by_ids(self, claim_ids):
         return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)

-    def resolve(self, urls):
-        return self.rpc('blockchain.claimtrie.resolve', urls)
+    def resolve(self, urls, session_override=None):
+        return self.rpc('blockchain.claimtrie.resolve', urls, False, session_override)

-    def claim_search(self, **kwargs):
-        return self.rpc('blockchain.claimtrie.search', kwargs)
+    def claim_search(self, session_override=None, **kwargs):
+        return self.rpc('blockchain.claimtrie.search', kwargs, False, session_override)

     async def new_resolve(self, server, urls):
         message = {"method": "resolve", "params": {"urls": urls, "protobuf": True}}
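
For readability, the positional arguments these wrappers now pass map onto rpc()'s extended signature as follows (equivalent keyword form, sketch):

    # inside Network, equivalent to the calls above:
    self.rpc('blockchain.claimtrie.resolve', urls, restricted=False, session=session_override)
    self.rpc('blockchain.claimtrie.search', kwargs, restricted=False, session=session_override)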