delete single_call_context, use session pool

Jack Robison 2020-11-20 15:52:11 -05:00
parent f6b396ae64
commit fa63bf758d
3 changed files with 18 additions and 99 deletions
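
Summary of the change: the one-off session machinery is removed in favor of the shared session pool. Ledger.cache_transaction and Network.single_call_context are deleted, the session_override parameter disappears from the Ledger call chain, and resolve/claim_search now pick a plain callable and send every batch through Network.retriable_call. Distilled from the diffs below (illustrative fragments, not complete code):

    # Before: a burst of calls pinned one dedicated session for its lifetime.
    async with self.network.single_call_context(self.network.resolve) as (resolve, session):
        txs = await resolve(batch)

    # After: each call goes through the session pool, with built-in retry.
    resolve = partial(self.network.retriable_call, self.network.resolve)
    txs = await resolve(batch)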

File 1 of 3 (class Ledger):

@@ -637,40 +637,6 @@ class Ledger(metaclass=LedgerRegistry):
         log.debug("finished syncing transaction history for %s, %i known txs", address, len(local_history))
         return True
 
-    async def cache_transaction(self, txid, remote_height, check_local=True):
-        cache_item = self._tx_cache.get(txid)
-        if cache_item is None:
-            cache_item = self._tx_cache[txid] = TransactionCacheItem()
-        elif cache_item.tx is not None and \
-                cache_item.tx.height >= remote_height and \
-                (cache_item.tx.is_verified or remote_height < 1):
-            return cache_item.tx  # cached tx is already up-to-date
-        cache_item.pending_verifications += 1
-        try:
-            async with cache_item.lock:
-                tx = cache_item.tx
-                if tx is None and check_local:
-                    # check local db
-                    tx = cache_item.tx = await self.db.get_transaction(txid=txid)
-                merkle = None
-                if tx is None:
-                    # fetch from network
-                    _raw, merkle = await self.network.retriable_call(
-                        self.network.get_transaction_and_merkle, txid, remote_height
-                    )
-                    tx = Transaction(unhexlify(_raw), height=merkle['block_height'])
-                    cache_item.tx = tx  # make sure it's saved before caching it
-                tx.height = remote_height
-                if merkle and 0 < remote_height < len(self.headers):
-                    merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
-                    header = await self.headers.get(remote_height)
-                    tx.position = merkle['pos']
-                    tx.is_verified = merkle_root == header['merkle_root']
-                return tx
-        finally:
-            cache_item.pending_verifications -= 1
-
     async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
         tx.height = remote_height
         cached = self._tx_cache.get(tx.id)
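
The deleted method verified each fetched transaction by recomputing the merkle root from the server-supplied branch and comparing it to the block header; that verification survives in maybe_verify_transaction (next hunk). For reference, a minimal sketch of what a get_root_of_merkle_tree-style computation does, assuming the standard double-SHA256 branch walk; the real helper lives elsewhere in the ledger and may differ in encoding details (hex strings, byte order):

    from hashlib import sha256

    def double_sha256(data: bytes) -> bytes:
        return sha256(sha256(data).digest()).digest()

    def root_from_merkle_branch(branch: list, position: int, tx_hash: bytes) -> bytes:
        # Walk up the tree; the low bit of `position` says whether our node
        # is the right (1) or left (0) child at the current level.
        node = tx_hash
        for sibling in branch:
            node = double_sha256(sibling + node if position & 1 else node + sibling)
            position >>= 1
        return node

    # tx.is_verified then reduces to:
    #   root_from_merkle_branch(merkle['merkle'], merkle['pos'], tx.hash) == header['merkle_root']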
@@ -688,7 +654,7 @@ class Ledger(metaclass=LedgerRegistry):
             tx.position = merkle['pos']
             tx.is_verified = merkle_root == header['merkle_root']
 
-    async def request_transactions_for_inflate(self, to_request: Tuple[Tuple[str, int], ...], session_override=None):
+    async def request_transactions_for_inflate(self, to_request: Tuple[Tuple[str, int], ...]):
         header_cache = {}
         batches = [[]]
         remote_heights = {}
@@ -709,12 +675,9 @@ class Ledger(metaclass=LedgerRegistry):
                 batches.pop()
 
         async def _single_batch(batch):
-            if session_override:
-                batch_result = await self.network.get_transaction_batch(
-                    batch, restricted=False, session=session_override
-                )
-            else:
-                batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch)
+            batch_result = await self.network.retriable_call(
+                self.network.get_transaction_batch, batch, restricted=False
+            )
             for txid, (raw, merkle) in batch_result.items():
                 remote_height = remote_heights[txid]
                 merkle_height = merkle['block_height']
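
retriable_call is what makes the session override unnecessary: it waits for a usable session from the pool and retries the call until it succeeds or the network shuts down (the raise asyncio.CancelledError() context line in the network.py hunk below likely belongs to its tail). A hedged sketch of such a wrapper, reconstructed from the surrounding context rather than copied from the source:

    async def retriable_call(self, function, *args, **kwargs):
        # Sketch only; the real Network.retriable_call may differ.
        while self.running:
            if not self.is_connected:
                log.warning("Wallet server unavailable, waiting for it to come back and retry.")
                await self.on_connected.first
            await self.session_pool.wait_for_fastest_session()
            try:
                return await function(*args, **kwargs)
            except asyncio.TimeoutError:
                log.warning("'%s' failed, retrying", function.__name__)
        raise asyncio.CancelledError()  # if we got here, we are shutting down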
@@ -898,17 +861,13 @@ class Ledger(metaclass=LedgerRegistry):
             include_is_my_output=False,
             include_sent_supports=False,
             include_sent_tips=False,
-            include_received_tips=False,
-            session_override=None) -> Tuple[List[Output], dict, int, int]:
+            include_received_tips=False) -> Tuple[List[Output], dict, int, int]:
         encoded_outputs = await query
         outputs = Outputs.from_base64(encoded_outputs or b'')  # TODO: why is the server returning None?
         txs = []
         if len(outputs.txs) > 0:
             txs: List[Transaction] = []
-            if session_override:
-                txs.extend((await self.request_transactions_for_inflate(tuple(outputs.txs), session_override)))
-            else:
-                txs.extend((await asyncio.gather(*(self.cache_transaction(*tx) for tx in outputs.txs))))
+            txs.extend((await self.request_transactions_for_inflate(tuple(outputs.txs))))
 
         _txos, blocked = outputs.inflate(txs)
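
With session_override gone, _inflate_outputs has a single path: await the query, decode the base64-encoded Outputs payload, batch-fetch the referenced transactions, and inflate. An illustrative call; the variable names for the four returned values are guesses based on the annotated return type Tuple[List[Output], dict, int, int]:

    txos, blocked, offset, total = await ledger._inflate_outputs(
        ledger.network.claim_search(name='example'),  # any awaitable returning encoded Outputs
        accounts,
    )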
@@ -983,25 +942,17 @@ class Ledger(metaclass=LedgerRegistry):
     async def resolve(self, accounts, urls, new_sdk_server=None, **kwargs):
         txos = []
         urls_copy = list(urls)
         if new_sdk_server:
             resolve = partial(self.network.new_resolve, new_sdk_server)
-            while urls_copy:
-                batch, urls_copy = urls_copy[:500], urls_copy[500:]
-                txos.extend(
-                    (await self._inflate_outputs(
-                        resolve(batch), accounts, **kwargs
-                    ))[0]
-                )
         else:
-            async with self.network.single_call_context(self.network.resolve) as (resolve, session):
-                while urls_copy:
-                    batch, urls_copy = urls_copy[:500], urls_copy[500:]
-                    txos.extend(
-                        (await self._inflate_outputs(
-                            resolve(batch), accounts, session_override=session, **kwargs
-                        ))[0]
-                    )
+            resolve = partial(self.network.retriable_call, self.network.resolve)
+        while urls_copy:
+            batch, urls_copy = urls_copy[:100], urls_copy[100:]
+            txos.extend(
+                (await self._inflate_outputs(
+                    resolve(batch), accounts, **kwargs
+                ))[0]
+            )
         assert len(urls) == len(txos), "Mismatch between urls requested for resolve and responses received."
         result = {}
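
Note that the per-call batch size also drops from 500 to 100 URLs. The slicing idiom in the loop above is worth reading in isolation; a small generic version:

    def batched(items, size=100):
        # Same consume-by-slices pattern as the resolve() loop above.
        items = list(items)
        while items:
            batch, items = items[:size], items[size:]
            yield batch

    # e.g.: for batch in batched(urls, 100): ...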
@@ -1023,17 +974,13 @@ class Ledger(metaclass=LedgerRegistry):
             new_sdk_server=None, **kwargs) -> Tuple[List[Output], dict, int, int]:
         if new_sdk_server:
             claim_search = partial(self.network.new_claim_search, new_sdk_server)
-            return await self._inflate_outputs(
-                claim_search(**kwargs), accounts,
-                include_purchase_receipt=include_purchase_receipt,
-                include_is_my_output=include_is_my_output,
-            )
-        async with self.network.single_call_context(self.network.claim_search) as (claim_search, session):
-            return await self._inflate_outputs(
-                claim_search(**kwargs), accounts, session_override=session,
-                include_purchase_receipt=include_purchase_receipt,
-                include_is_my_output=include_is_my_output,
-            )
+        else:
+            claim_search = self.network.claim_search
+        return await self._inflate_outputs(
+            claim_search(**kwargs), accounts,
+            include_purchase_receipt=include_purchase_receipt,
+            include_is_my_output=include_is_my_output,
+        )
 
     async def get_claim_by_claim_id(self, accounts, claim_id, **kwargs) -> Output:
         for claim in (await self.claim_search(accounts, claim_id=claim_id, **kwargs))[0]:
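
Same simplification as in resolve: the branches now only choose a callable, and one shared _inflate_outputs call does the work. functools.partial is what lets both backends share the call site; a self-contained illustration (the server name is hypothetical):

    from functools import partial

    def search(server, **kwargs):
        return server, kwargs

    claim_search = partial(search, 'example.sdk.server')   # hypothetical server
    claim_search(name='example')  # == search('example.sdk.server', name='example')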

File 2 of 3 (class Network):

@@ -3,10 +3,7 @@ import asyncio
 import json
 from time import perf_counter
 from operator import itemgetter
-from contextlib import asynccontextmanager
-from functools import partial
 from typing import Dict, Optional, Tuple
 import aiohttp
 from lbry import __version__
@@ -255,28 +252,6 @@ class Network:
                 pass
         raise asyncio.CancelledError()  # if we got here, we are shutting down
 
-    @asynccontextmanager
-    async def single_call_context(self, function, *args, **kwargs):
-        if not self.is_connected:
-            log.warning("Wallet server unavailable, waiting for it to come back and retry.")
-            await self.on_connected.first
-        await self.session_pool.wait_for_fastest_session()
-        server = self.session_pool.fastest_session.server
-        session = ClientSession(network=self, server=server)
-
-        async def call_with_reconnect(*a, **kw):
-            while self.running:
-                if not session.available:
-                    await session.create_connection()
-                try:
-                    return await partial(function, *args, session_override=session, **kwargs)(*a, **kw)
-                except asyncio.TimeoutError:
-                    log.warning("'%s' failed, retrying", function.__name__)
-
-        try:
-            yield (call_with_reconnect, session)
-        finally:
-            await session.close()
-
     def _update_remote_height(self, header_args):
         self.remote_height = header_args[0]["height"]
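
What the deleted context manager did: pick the pool's current fastest server, open a private ClientSession to it, and yield a closure that recreated that one connection and retried on every timeout. That duplicated the retry logic the pool already provides and forced the session_override parameter through the Ledger call chain. After this commit, call sites reduce to the pool path, as in the ledger diff above:

    batch_result = await self.network.retriable_call(
        self.network.get_transaction_batch, batch, restricted=False
    )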

File 3 of 3 (class LBRYElectrumX):

@@ -1227,10 +1227,7 @@ class LBRYElectrumX(SessionBase):
         return await self.address_status(hashX)
 
     async def hashX_unsubscribe(self, hashX, alias):
-        try:
-            del self.hashX_subs[hashX]
-        except ValueError:
-            pass
+        self.hashX_subs.pop(hashX, None)
 
     def address_to_hashX(self, address):
         try:
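
The old hashX_unsubscribe body had a latent bug: del on a missing dict key raises KeyError, not ValueError, so the except clause could never match and the error escaped. dict.pop with a default is the idiomatic fix; a self-contained illustration:

    subs = {}

    # Old pattern: `del` raises KeyError for a missing key, but the handler
    # caught ValueError, so a KeyError would have escaped anyway.
    try:
        del subs['missing']
    except KeyError:      # the exception the old code should have caught
        pass

    # New pattern: one line, never raises; returns None when the key is absent.
    subs.pop('missing', None)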