forked from LBRYCommunity/lbry-sdk
retry and batch requests, fix some loose ends
This commit is contained in:
parent
c60b658443
commit
f567aca532
5 changed files with 47 additions and 36 deletions
|
@ -18,6 +18,9 @@ class MockNetwork:
|
||||||
self.get_transaction_called = []
|
self.get_transaction_called = []
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
|
|
||||||
|
def retriable_call(self, function, *args, **kwargs):
|
||||||
|
return function(*args, **kwargs)
|
||||||
|
|
||||||
async def get_history(self, address):
|
async def get_history(self, address):
|
||||||
self.get_history_called.append(address)
|
self.get_history_called.append(address)
|
||||||
self.address = address
|
self.address = address
|
||||||
|
@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MocHeaderNetwork:
|
class MocHeaderNetwork(MockNetwork):
|
||||||
def __init__(self, responses):
|
def __init__(self, responses):
|
||||||
|
super().__init__(None, None)
|
||||||
self.responses = responses
|
self.responses = responses
|
||||||
|
|
||||||
async def get_headers(self, height, blocks):
|
async def get_headers(self, height, blocks):
|
||||||
|
|
|
@ -310,7 +310,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
subscription_update = False
|
subscription_update = False
|
||||||
|
|
||||||
if not headers:
|
if not headers:
|
||||||
header_response = await self.network.get_headers(height, 2001)
|
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
|
||||||
headers = header_response['hex']
|
headers = header_response['hex']
|
||||||
|
|
||||||
if not headers:
|
if not headers:
|
||||||
|
@ -395,13 +395,9 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
|
|
||||||
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
|
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
|
||||||
if self.network.is_connected and addresses:
|
if self.network.is_connected and addresses:
|
||||||
await asyncio.wait([
|
async for address, remote_status in self.network.subscribe_address(*addresses):
|
||||||
self.subscribe_address(address_manager, address) for address in addresses
|
# subscribe isnt a retriable call as it happens right after a connection is made
|
||||||
])
|
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
|
||||||
|
|
||||||
async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
|
|
||||||
remote_status = await self.network.subscribe_address(address)
|
|
||||||
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
|
|
||||||
|
|
||||||
def process_status_update(self, update):
|
def process_status_update(self, update):
|
||||||
address, remote_status = update
|
address, remote_status = update
|
||||||
|
@ -417,7 +413,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
if local_status == remote_status:
|
if local_status == remote_status:
|
||||||
return
|
return
|
||||||
|
|
||||||
remote_history = await self.network.get_history(address)
|
remote_history = await self.network.retriable_call(self.network.get_history, address)
|
||||||
|
|
||||||
cache_tasks = []
|
cache_tasks = []
|
||||||
synced_history = StringIO()
|
synced_history = StringIO()
|
||||||
|
@ -489,7 +485,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
|
|
||||||
if tx is None:
|
if tx is None:
|
||||||
# fetch from network
|
# fetch from network
|
||||||
_raw = await self.network.get_transaction(txid)
|
_raw = await self.network.retriable_call(self.network.get_transaction, txid)
|
||||||
if _raw:
|
if _raw:
|
||||||
tx = self.transaction_class(unhexlify(_raw))
|
tx = self.transaction_class(unhexlify(_raw))
|
||||||
await self.maybe_verify_transaction(tx, remote_height)
|
await self.maybe_verify_transaction(tx, remote_height)
|
||||||
|
@ -510,7 +506,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
async def maybe_verify_transaction(self, tx, remote_height):
|
async def maybe_verify_transaction(self, tx, remote_height):
|
||||||
tx.height = remote_height
|
tx.height = remote_height
|
||||||
if 0 < remote_height <= len(self.headers):
|
if 0 < remote_height <= len(self.headers):
|
||||||
merkle = await self.network.get_merkle(tx.id, remote_height)
|
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
|
||||||
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||||
header = self.headers[remote_height]
|
header = self.headers[remote_height]
|
||||||
tx.position = merkle['pos']
|
tx.position = merkle['pos']
|
||||||
|
@ -524,6 +520,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def broadcast(self, tx):
|
def broadcast(self, tx):
|
||||||
|
# broadcast cant be a retriable call yet
|
||||||
return self.network.broadcast(hexlify(tx.raw).decode())
|
return self.network.broadcast(hexlify(tx.raw).decode())
|
||||||
|
|
||||||
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
|
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
|
||||||
|
|
|
@ -135,31 +135,33 @@ class BaseNetwork:
|
||||||
self.running = False
|
self.running = False
|
||||||
if self.session_pool:
|
if self.session_pool:
|
||||||
self.session_pool.stop()
|
self.session_pool.stop()
|
||||||
if self.is_connected:
|
|
||||||
disconnected = self.client.on_disconnected.first
|
|
||||||
await self.client.close()
|
|
||||||
await disconnected
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
return self.client and not self.client.is_closing()
|
return self.client and not self.client.is_closing()
|
||||||
|
|
||||||
def rpc(self, list_or_method, args):
|
def rpc(self, list_or_method, args, session=None):
|
||||||
fastest = self.session_pool.fastest_session
|
session = session or self.session_pool.fastest_session
|
||||||
if fastest is not None and self.client != fastest:
|
if session:
|
||||||
self.switch_event.set()
|
return session.send_request(list_or_method, args)
|
||||||
if self.is_connected:
|
|
||||||
return self.client.send_request(list_or_method, args)
|
|
||||||
else:
|
else:
|
||||||
self.session_pool.trigger_nodelay_connect()
|
self.session_pool.trigger_nodelay_connect()
|
||||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||||
|
|
||||||
|
async def retriable_call(self, function, *args, **kwargs):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return await function(*args, **kwargs)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
log.warning("Wallet server call timed out, retrying.")
|
||||||
|
except ConnectionError:
|
||||||
|
if not self.is_connected:
|
||||||
|
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
|
||||||
|
await self.on_connected.first
|
||||||
|
|
||||||
def _update_remote_height(self, header_args):
|
def _update_remote_height(self, header_args):
|
||||||
self.remote_height = header_args[0]["height"]
|
self.remote_height = header_args[0]["height"]
|
||||||
|
|
||||||
def broadcast(self, raw_transaction):
|
|
||||||
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
|
|
||||||
|
|
||||||
def get_history(self, address):
|
def get_history(self, address):
|
||||||
return self.rpc('blockchain.address.get_history', [address])
|
return self.rpc('blockchain.address.get_history', [address])
|
||||||
|
|
||||||
|
@ -175,11 +177,19 @@ class BaseNetwork:
|
||||||
def get_headers(self, height, count=10000):
|
def get_headers(self, height, count=10000):
|
||||||
return self.rpc('blockchain.block.headers', [height, count])
|
return self.rpc('blockchain.block.headers', [height, count])
|
||||||
|
|
||||||
def subscribe_headers(self):
|
# --- Subscribes and broadcasts are always aimed towards the master client directly
|
||||||
return self.rpc('blockchain.headers.subscribe', [True])
|
def broadcast(self, raw_transaction):
|
||||||
|
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
|
||||||
|
|
||||||
def subscribe_address(self, address):
|
def subscribe_headers(self):
|
||||||
return self.rpc('blockchain.address.subscribe', [address])
|
return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
|
||||||
|
|
||||||
|
async def subscribe_address(self, *addresses):
|
||||||
|
async with self.client.send_batch() as batch:
|
||||||
|
for address in addresses:
|
||||||
|
batch.add_request('blockchain.address.subscribe', [address])
|
||||||
|
for address, status in zip(addresses, batch.results):
|
||||||
|
yield address, status
|
||||||
|
|
||||||
|
|
||||||
class SessionPool:
|
class SessionPool:
|
||||||
|
@ -218,6 +228,7 @@ class SessionPool:
|
||||||
def stop(self):
|
def stop(self):
|
||||||
for session, task in self.sessions.items():
|
for session, task in self.sessions.items():
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
session.connection_lost(asyncio.CancelledError())
|
||||||
session.abort()
|
session.abort()
|
||||||
self.sessions.clear()
|
self.sessions.clear()
|
||||||
|
|
||||||
|
|
|
@ -746,10 +746,8 @@ class JSONRPCConnection(object):
|
||||||
self._protocol = item
|
self._protocol = item
|
||||||
return self.receive_message(message)
|
return self.receive_message(message)
|
||||||
|
|
||||||
def time_out_pending_requests(self):
|
def raise_pending_requests(self, exception):
|
||||||
"""Times out all pending requests."""
|
exception = exception or asyncio.TimeoutError()
|
||||||
# this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing?
|
|
||||||
exception = asyncio.TimeoutError()
|
|
||||||
for request, event in self._requests.values():
|
for request, event in self._requests.values():
|
||||||
event.result = exception
|
event.result = exception
|
||||||
event.set()
|
event.set()
|
||||||
|
|
|
@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol):
|
||||||
await asyncio.wait_for(self._can_send.wait(), secs)
|
await asyncio.wait_for(self._can_send.wait(), secs)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self.abort()
|
self.abort()
|
||||||
raise asyncio.CancelledError(f'task timed out after {secs}s')
|
raise asyncio.TimeoutError(f'task timed out after {secs}s')
|
||||||
|
|
||||||
async def _send_message(self, message):
|
async def _send_message(self, message):
|
||||||
if not self._can_send.is_set():
|
if not self._can_send.is_set():
|
||||||
|
@ -215,7 +215,8 @@ class SessionBase(asyncio.Protocol):
|
||||||
self._address = None
|
self._address = None
|
||||||
self.transport = None
|
self.transport = None
|
||||||
self._task_group.cancel()
|
self._task_group.cancel()
|
||||||
self._pm_task.cancel()
|
if self._pm_task:
|
||||||
|
self._pm_task.cancel()
|
||||||
# Release waiting tasks
|
# Release waiting tasks
|
||||||
self._can_send.set()
|
self._can_send.set()
|
||||||
|
|
||||||
|
@ -456,7 +457,7 @@ class RPCSession(SessionBase):
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
# Cancel pending requests and message processing
|
# Cancel pending requests and message processing
|
||||||
self.connection.time_out_pending_requests()
|
self.connection.raise_pending_requests(exc)
|
||||||
super().connection_lost(exc)
|
super().connection_lost(exc)
|
||||||
|
|
||||||
# External API
|
# External API
|
||||||
|
|
Loading…
Reference in a new issue