retry and batch requests, fix some loose ends

Victor Shyba 2019-08-18 15:40:38 -03:00
parent c60b658443
commit f567aca532
5 changed files with 47 additions and 36 deletions

View file

@@ -18,6 +18,9 @@ class MockNetwork:
         self.get_transaction_called = []
         self.is_connected = False
 
+    def retriable_call(self, function, *args, **kwargs):
+        return function(*args, **kwargs)
+
     async def get_history(self, address):
         self.get_history_called.append(address)
         self.address = address
@@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase):
         )
 
-class MocHeaderNetwork:
+class MocHeaderNetwork(MockNetwork):
     def __init__(self, responses):
+        super().__init__(None, None)
         self.responses = responses
 
     async def get_headers(self, height, blocks):
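
MockNetwork gains retriable_call as a deliberate pass-through, so unit tests exercise the ledger's new call sites without the retry loop, and MocHeaderNetwork inherits it by subclassing. A minimal sketch of the call shape the ledger uses (the mock is trimmed down here, not the full test fixture):

import asyncio

class MockNetwork:
    def retriable_call(self, function, *args, **kwargs):
        # no retries in tests: just hand back the coroutine
        return function(*args, **kwargs)

    async def get_history(self, address):
        return []

async def main():
    network = MockNetwork()
    # same call shape the ledger uses against the real network
    history = await network.retriable_call(network.get_history, 'mock_address')
    assert history == []

asyncio.run(main())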

View file

@@ -310,7 +310,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         subscription_update = False
 
         if not headers:
-            header_response = await self.network.get_headers(height, 2001)
+            header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
             headers = header_response['hex']
 
         if not headers:
@@ -395,13 +395,9 @@ class BaseLedger(metaclass=LedgerRegistry):
     async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
         if self.network.is_connected and addresses:
-            await asyncio.wait([
-                self.subscribe_address(address_manager, address) for address in addresses
-            ])
-
-    async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
-        remote_status = await self.network.subscribe_address(address)
-        self._update_tasks.add(self.update_history(address, remote_status, address_manager))
+            async for address, remote_status in self.network.subscribe_address(*addresses):
+                # subscribe isn't a retriable call, as it happens right after a connection is made
+                self._update_tasks.add(self.update_history(address, remote_status, address_manager))
 
     def process_status_update(self, update):
         address, remote_status = update
@@ -417,7 +413,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         if local_status == remote_status:
             return
 
-        remote_history = await self.network.get_history(address)
+        remote_history = await self.network.retriable_call(self.network.get_history, address)
 
         cache_tasks = []
         synced_history = StringIO()
@@ -489,7 +485,7 @@ class BaseLedger(metaclass=LedgerRegistry):
             if tx is None:
                 # fetch from network
-                _raw = await self.network.get_transaction(txid)
+                _raw = await self.network.retriable_call(self.network.get_transaction, txid)
                 if _raw:
                     tx = self.transaction_class(unhexlify(_raw))
                     await self.maybe_verify_transaction(tx, remote_height)
@@ -510,7 +506,7 @@ class BaseLedger(metaclass=LedgerRegistry):
     async def maybe_verify_transaction(self, tx, remote_height):
         tx.height = remote_height
         if 0 < remote_height <= len(self.headers):
-            merkle = await self.network.get_merkle(tx.id, remote_height)
+            merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
             merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
             header = self.headers[remote_height]
             tx.position = merkle['pos']
@@ -524,6 +520,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         return None
 
     def broadcast(self, tx):
+        # broadcast can't be a retriable call yet
        return self.network.broadcast(hexlify(tx.raw).decode())
 
     async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
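
The subscribe_addresses rewrite above replaces one RPC per address with a single batched subscription consumed as an async generator; each yielded (address, status) pair becomes an update_history task. A self-contained sketch of that consumption pattern, with a fake generator in place of the server round-trip and a plain set standing in for the ledger's _update_tasks helper:

import asyncio

async def subscribe_address(*addresses):
    # stand-in for the batched RPC: one round-trip, then per-address results
    statuses = {address: f'status:{address}' for address in addresses}
    for address in addresses:
        yield address, statuses[address]

async def update_history(address, remote_status):
    print(f'updating {address} from status {remote_status}')

async def subscribe_addresses(addresses):
    update_tasks = set()  # the real ledger tracks these in _update_tasks
    async for address, remote_status in subscribe_address(*addresses):
        update_tasks.add(asyncio.ensure_future(update_history(address, remote_status)))
    await asyncio.gather(*update_tasks)

asyncio.run(subscribe_addresses(['addr1', 'addr2', 'addr3']))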

View file

@@ -135,31 +135,33 @@ class BaseNetwork:
         self.running = False
         if self.session_pool:
             self.session_pool.stop()
-        if self.is_connected:
-            disconnected = self.client.on_disconnected.first
-            await self.client.close()
-            await disconnected
 
     @property
     def is_connected(self):
         return self.client and not self.client.is_closing()
 
-    def rpc(self, list_or_method, args):
-        fastest = self.session_pool.fastest_session
-        if fastest is not None and self.client != fastest:
-            self.switch_event.set()
-        if self.is_connected:
-            return self.client.send_request(list_or_method, args)
+    def rpc(self, list_or_method, args, session=None):
+        session = session or self.session_pool.fastest_session
+        if session:
+            return session.send_request(list_or_method, args)
         else:
             self.session_pool.trigger_nodelay_connect()
             raise ConnectionError("Attempting to send rpc request when connection is not available.")
 
+    async def retriable_call(self, function, *args, **kwargs):
+        while True:
+            try:
+                return await function(*args, **kwargs)
+            except asyncio.TimeoutError:
+                log.warning("Wallet server call timed out, retrying.")
+            except ConnectionError:
+                if not self.is_connected:
+                    log.warning("Wallet server unavailable, waiting for it to come back and retry.")
+                    await self.on_connected.first
+
     def _update_remote_height(self, header_args):
         self.remote_height = header_args[0]["height"]
 
-    def broadcast(self, raw_transaction):
-        return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
-
     def get_history(self, address):
         return self.rpc('blockchain.address.get_history', [address])
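
retriable_call is the core of the commit: a TimeoutError logs and retries immediately, while a ConnectionError parks the caller on the on_connected stream until a session comes back. A runnable sketch of the same loop, with a plain asyncio.Event standing in for the on_connected broadcast stream:

import asyncio
import logging

log = logging.getLogger(__name__)

class Network:
    def __init__(self):
        self.on_connected = asyncio.Event()  # stand-in for the on_connected stream
        self.is_connected = True

    async def retriable_call(self, function, *args, **kwargs):
        while True:
            try:
                return await function(*args, **kwargs)
            except asyncio.TimeoutError:
                log.warning("Wallet server call timed out, retrying.")
            except ConnectionError:
                if not self.is_connected:
                    log.warning("Wallet server unavailable, waiting for it to come back.")
                    await self.on_connected.wait()

async def main():
    network = Network()
    calls = 0

    async def flaky_get_history():
        nonlocal calls
        calls += 1
        if calls < 3:
            raise asyncio.TimeoutError()  # first two attempts time out
        return [{'tx_hash': 'abcd', 'height': 1}]

    history = await network.retriable_call(flaky_get_history)
    assert calls == 3 and history[0]['height'] == 1

asyncio.run(main())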
@@ -175,11 +177,19 @@ class BaseNetwork:
     def get_headers(self, height, count=10000):
         return self.rpc('blockchain.block.headers', [height, count])
 
-    def subscribe_headers(self):
-        return self.rpc('blockchain.headers.subscribe', [True])
+    # --- Subscriptions and broadcasts always go directly to the master client
+    def broadcast(self, raw_transaction):
+        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
 
-    def subscribe_address(self, address):
-        return self.rpc('blockchain.address.subscribe', [address])
+    def subscribe_headers(self):
+        return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
+
+    async def subscribe_address(self, *addresses):
+        async with self.client.send_batch() as batch:
+            for address in addresses:
+                batch.add_request('blockchain.address.subscribe', [address])
+        for address, status in zip(addresses, batch.results):
+            yield address, status
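
subscribe_address now relies on the batch API: requests added inside the async with block go out as a single JSON-RPC batch when the block exits, and batch.results then holds the responses in request order, which is what makes zip(addresses, batch.results) line up. A toy model of that contract (the Batch class and send_batch below are illustrative stand-ins, not the real session API):

import asyncio
from contextlib import asynccontextmanager

class Batch:
    def __init__(self):
        self.requests = []
        self.results = []

    def add_request(self, method, args):
        self.requests.append((method, args))

@asynccontextmanager
async def send_batch():
    batch = Batch()
    yield batch
    # one round-trip for the whole batch; results preserve request order
    batch.results = [f'status:{args[0]}' for _, args in batch.requests]

async def subscribe_address(*addresses):
    async with send_batch() as batch:
        for address in addresses:
            batch.add_request('blockchain.address.subscribe', [address])
    for address, status in zip(addresses, batch.results):
        yield address, status

async def main():
    async for address, status in subscribe_address('addr1', 'addr2'):
        print(address, status)

asyncio.run(main())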
@@ -218,6 +228,7 @@ class SessionPool:
     def stop(self):
         for session, task in self.sessions.items():
             task.cancel()
+            session.connection_lost(asyncio.CancelledError())
             session.abort()
         self.sessions.clear()

View file

@@ -746,10 +746,8 @@ class JSONRPCConnection(object):
             self._protocol = item
         return self.receive_message(message)
 
-    def time_out_pending_requests(self):
-        """Times out all pending requests."""
-        # this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing?
-        exception = asyncio.TimeoutError()
+    def raise_pending_requests(self, exception):
+        exception = exception or asyncio.TimeoutError()
         for request, event in self._requests.values():
             event.result = exception
             event.set()
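
The rename from time_out_pending_requests to raise_pending_requests matches the behavior change: the exception that actually closed the connection is forwarded to every pending waiter, with TimeoutError only as a fallback. The distinction matters upstream, because retriable_call retries a TimeoutError immediately but waits for reconnection on a ConnectionError. A sketch of the event/result handshake, with a simplified request table (the real waiters similarly store the result on the event):

import asyncio

async def wait_for_response(event):
    # waiter side: block until the connection sets the event, then re-raise
    await event.wait()
    if isinstance(event.result, Exception):
        raise event.result
    return event.result

def raise_pending_requests(requests, exception):
    # forward the real cause; fall back to a timeout when there is none
    exception = exception or asyncio.TimeoutError()
    for request, event in requests.values():
        event.result = exception
        event.set()

async def main():
    event = asyncio.Event()
    event.result = None
    requests = {1: (('blockchain.address.get_history', ['addr1']), event)}
    waiter = asyncio.ensure_future(wait_for_response(event))
    raise_pending_requests(requests, ConnectionError('server went away'))
    try:
        await waiter
    except ConnectionError as e:
        print('pending request saw:', e)

asyncio.run(main())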

View file

@@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol):
             await asyncio.wait_for(self._can_send.wait(), secs)
         except asyncio.TimeoutError:
             self.abort()
-            raise asyncio.CancelledError(f'task timed out after {secs}s')
+            raise asyncio.TimeoutError(f'task timed out after {secs}s')
 
     async def _send_message(self, message):
         if not self._can_send.is_set():
@@ -215,7 +215,8 @@ class SessionBase(asyncio.Protocol):
         self._address = None
         self.transport = None
         self._task_group.cancel()
-        self._pm_task.cancel()
+        if self._pm_task:
+            self._pm_task.cancel()
         # Release waiting tasks
         self._can_send.set()
@@ -456,7 +457,7 @@ class RPCSession(SessionBase):
     def connection_lost(self, exc):
         # Cancel pending requests and message processing
-        self.connection.time_out_pending_requests()
+        self.connection.raise_pending_requests(exc)
         super().connection_lost(exc)
 
     # External API
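
Raising asyncio.TimeoutError instead of CancelledError, and forwarding exc from connection_lost, both feed the same consumer: retriable_call catches TimeoutError and ConnectionError, while a CancelledError would escape the retry loop as if the caller itself had been cancelled (and from Python 3.8 on it is not even an Exception subclass). The if self._pm_task guard presumably covers connection_lost being invoked before the message-processing task exists, now that SessionPool.stop calls it directly. A minimal demonstration of why the raised type matters:

import asyncio

async def fails_with_timeout():
    raise asyncio.TimeoutError('task timed out after 30s')

async def fails_with_cancelled():
    raise asyncio.CancelledError('task timed out after 30s')

async def retriable_call(function):
    # the real loop retries; one try is enough to show the catch behavior
    try:
        return await function()
    except asyncio.TimeoutError:
        return 'retried'
    # CancelledError is deliberately not caught: it escapes the loop

async def main():
    print(await retriable_call(fails_with_timeout))      # -> retried
    try:
        await retriable_call(fails_with_cancelled)
    except asyncio.CancelledError:
        print('cancellation escaped the retry loop')

asyncio.run(main())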