retry and batch requests, fix some loose ends
parent c60b658443
commit f567aca532
5 changed files with 47 additions and 36 deletions
@@ -18,6 +18,9 @@ class MockNetwork:
         self.get_transaction_called = []
+        self.is_connected = False

+    def retriable_call(self, function, *args, **kwargs):
+        return function(*args, **kwargs)

     async def get_history(self, address):
         self.get_history_called.append(address)
         self.address = address
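The mock's retriable_call above is a plain pass-through, so unit tests exercise the wrapped coroutine exactly once and skip the real retry loop. A minimal sketch of that test-double pattern (FakeNetwork and the assertions are illustrative, not the project's actual test code):

import asyncio


class FakeNetwork:
    # test double: records calls and bypasses the real retry machinery
    def __init__(self):
        self.get_history_called = []

    def retriable_call(self, function, *args, **kwargs):
        # pass straight through: no retries, no waiting for a reconnect
        return function(*args, **kwargs)

    async def get_history(self, address):
        self.get_history_called.append(address)
        return []


async def main():
    network = FakeNetwork()
    # awaiting the coroutine returned by the pass-through behaves just like
    # awaiting the real retriable_call
    history = await network.retriable_call(network.get_history, 'some_address')
    assert history == [] and network.get_history_called == ['some_address']

asyncio.run(main())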
@@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase):
        )


-class MocHeaderNetwork:
+class MocHeaderNetwork(MockNetwork):
     def __init__(self, responses):
+        super().__init__(None, None)
         self.responses = responses

     async def get_headers(self, height, blocks):
@@ -310,7 +310,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         subscription_update = False

         if not headers:
-            header_response = await self.network.get_headers(height, 2001)
+            header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
             headers = header_response['hex']

         if not headers:
@@ -395,12 +395,8 @@ class BaseLedger(metaclass=LedgerRegistry):

     async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
         if self.network.is_connected and addresses:
-            await asyncio.wait([
-                self.subscribe_address(address_manager, address) for address in addresses
-            ])
-
-    async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
-        remote_status = await self.network.subscribe_address(address)
+            async for address, remote_status in self.network.subscribe_address(*addresses):
+                # subscribe isn't a retriable call as it happens right after a connection is made
+                self._update_tasks.add(self.update_history(address, remote_status, address_manager))

     def process_status_update(self, update):
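The rewritten subscribe_addresses consumes a single batched subscription and schedules each history update as a task, instead of awaiting a separate subscribe_address call per address. A minimal, self-contained sketch of that consumption pattern (subscribe_address and update_history below are stand-ins, not the ledger's real methods):

import asyncio


async def subscribe_address(*addresses):
    # stand-in for the batched server subscription: yields (address, status) pairs
    for address in addresses:
        yield address, f'status-of-{address}'


async def update_history(address, status):
    print('updating', address, 'with', status)


async def main():
    update_tasks = set()
    async for address, status in subscribe_address('addr1', 'addr2'):
        # schedule the follow-up work instead of awaiting it inline,
        # mirroring how the ledger adds update_history() to _update_tasks
        update_tasks.add(asyncio.ensure_future(update_history(address, status)))
    await asyncio.gather(*update_tasks)

asyncio.run(main())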
@@ -417,7 +413,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         if local_status == remote_status:
             return

-        remote_history = await self.network.get_history(address)
+        remote_history = await self.network.retriable_call(self.network.get_history, address)

         cache_tasks = []
         synced_history = StringIO()
@@ -489,7 +485,7 @@ class BaseLedger(metaclass=LedgerRegistry):

         if tx is None:
             # fetch from network
-            _raw = await self.network.get_transaction(txid)
+            _raw = await self.network.retriable_call(self.network.get_transaction, txid)
             if _raw:
                 tx = self.transaction_class(unhexlify(_raw))
                 await self.maybe_verify_transaction(tx, remote_height)
@@ -510,7 +506,7 @@ class BaseLedger(metaclass=LedgerRegistry):
     async def maybe_verify_transaction(self, tx, remote_height):
         tx.height = remote_height
         if 0 < remote_height <= len(self.headers):
-            merkle = await self.network.get_merkle(tx.id, remote_height)
+            merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
             merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
             header = self.headers[remote_height]
             tx.position = merkle['pos']
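The merkle lookup wrapped above feeds get_root_of_merkle_tree, which recomputes the block's merkle root from the returned branch and the transaction's position. A sketch of that standard computation (double SHA-256, with each bit of the position deciding whether the running hash sits on the left or the right); hex decoding and byte-order handling are omitted, and this is the generic algorithm rather than the project's exact implementation:

import hashlib


def double_sha256(data: bytes) -> bytes:
    return hashlib.sha256(hashlib.sha256(data).digest()).digest()


def root_from_merkle_branch(branch, position, tx_hash):
    # fold a merkle branch (sibling hashes, leaf level first) into the root;
    # at each level the low bit of `position` says whether our running hash
    # is the right (1) or left (0) member of the pair being combined
    working = tx_hash
    for sibling in branch:
        if position & 1:
            working = double_sha256(sibling + working)
        else:
            working = double_sha256(working + sibling)
        position >>= 1
    return working

Verification then amounts to comparing the computed root against the merkle root stored in the header at tx.height.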
@@ -524,6 +520,7 @@ class BaseLedger(metaclass=LedgerRegistry):
         return None

     def broadcast(self, tx):
+        # broadcast can't be a retriable call yet
         return self.network.broadcast(hexlify(tx.raw).decode())

     async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
@@ -135,31 +135,33 @@ class BaseNetwork:
         self.running = False
         if self.session_pool:
             self.session_pool.stop()
         if self.is_connected:
             disconnected = self.client.on_disconnected.first
             await self.client.close()
             await disconnected

     @property
     def is_connected(self):
         return self.client and not self.client.is_closing()

-    def rpc(self, list_or_method, args):
-        fastest = self.session_pool.fastest_session
-        if fastest is not None and self.client != fastest:
-            self.switch_event.set()
-        if self.is_connected:
-            return self.client.send_request(list_or_method, args)
+    def rpc(self, list_or_method, args, session=None):
+        session = session or self.session_pool.fastest_session
+        if session:
+            return session.send_request(list_or_method, args)
         else:
             self.session_pool.trigger_nodelay_connect()
             raise ConnectionError("Attempting to send rpc request when connection is not available.")

+    async def retriable_call(self, function, *args, **kwargs):
+        while True:
+            try:
+                return await function(*args, **kwargs)
+            except asyncio.TimeoutError:
+                log.warning("Wallet server call timed out, retrying.")
+            except ConnectionError:
+                if not self.is_connected:
+                    log.warning("Wallet server unavailable, waiting for it to come back and retry.")
+                    await self.on_connected.first

     def _update_remote_height(self, header_args):
         self.remote_height = header_args[0]["height"]

-    def broadcast(self, raw_transaction):
-        return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
-
     def get_history(self, address):
         return self.rpc('blockchain.address.get_history', [address])
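retriable_call is the core of this commit: the ledger's read-only wallet-server calls now go through it, so a timeout is retried immediately and a dropped connection simply parks the caller until the network reconnects. A simplified, self-contained sketch of the pattern (the flaky coroutine and the optional wait_connected event are stand-ins for the real client and its on_connected stream):

import asyncio
import logging

log = logging.getLogger(__name__)


async def retriable_call(function, *args, wait_connected=None, **kwargs):
    # retry timeouts immediately; on connection errors optionally wait for a
    # "connected again" event before trying once more
    while True:
        try:
            return await function(*args, **kwargs)
        except asyncio.TimeoutError:
            log.warning("call timed out, retrying")
        except ConnectionError:
            log.warning("server unavailable, retrying once it is back")
            if wait_connected is not None:
                await wait_connected.wait()


async def main():
    attempts = 0

    async def flaky():
        nonlocal attempts
        attempts += 1
        if attempts < 3:
            raise asyncio.TimeoutError()
        return 'ok'

    assert await retriable_call(flaky) == 'ok' and attempts == 3

asyncio.run(main())

Only idempotent reads are wrapped this way; as the comments in the ledger hunks note, subscriptions and broadcasts stay outside the retry path.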
@@ -175,11 +177,19 @@ class BaseNetwork:
     def get_headers(self, height, count=10000):
         return self.rpc('blockchain.block.headers', [height, count])

-    def subscribe_headers(self):
-        return self.rpc('blockchain.headers.subscribe', [True])
+    # --- Subscribes and broadcasts are always aimed towards the master client directly
+    def broadcast(self, raw_transaction):
+        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)

-    def subscribe_address(self, address):
-        return self.rpc('blockchain.address.subscribe', [address])
+    def subscribe_headers(self):
+        return self.rpc('blockchain.headers.subscribe', [True], session=self.client)

+    async def subscribe_address(self, *addresses):
+        async with self.client.send_batch() as batch:
+            for address in addresses:
+                batch.add_request('blockchain.address.subscribe', [address])
+        for address, status in zip(addresses, batch.results):
+            yield address, status


 class SessionPool:
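subscribe_address now pushes every blockchain.address.subscribe request into one JSON-RPC batch on the master client and yields (address, status) pairs from an async generator, so N addresses cost one round trip instead of N. A sketch of that shape with a stand-in batch object (the real batch comes from the RPC session's send_batch(); FakeBatch here is only illustrative):

import asyncio


class FakeBatch:
    # stand-in for the RPC session's batch context: collects requests,
    # "sends" them on exit and exposes their results in request order
    def __init__(self):
        self.requests, self.results = [], []

    def add_request(self, method, args):
        self.requests.append((method, args))

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc_info):
        # pretend every subscription immediately returns a status string
        self.results = [f'status:{args[0]}' for _, args in self.requests]


async def subscribe_address(*addresses):
    async with FakeBatch() as batch:
        for address in addresses:
            batch.add_request('blockchain.address.subscribe', [address])
    # results only exist after the batch round trip completes
    for address, status in zip(addresses, batch.results):
        yield address, status


async def main():
    async for address, status in subscribe_address('addr1', 'addr2'):
        print(address, status)

asyncio.run(main())

The ledger consumes this with async for, as in the subscribe_addresses hunk above.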
@@ -218,6 +228,7 @@ class SessionPool:
     def stop(self):
         for session, task in self.sessions.items():
             task.cancel()
+            session.connection_lost(asyncio.CancelledError())
             session.abort()
         self.sessions.clear()
@@ -746,10 +746,8 @@ class JSONRPCConnection(object):
            self._protocol = item
        return self.receive_message(message)

-    def time_out_pending_requests(self):
-        """Times out all pending requests."""
-        # this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing?
-        exception = asyncio.TimeoutError()
+    def raise_pending_requests(self, exception):
+        exception = exception or asyncio.TimeoutError()
         for request, event in self._requests.values():
             event.result = exception
             event.set()
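Renaming time_out_pending_requests to raise_pending_requests lets the caller choose which exception in-flight requests should see, defaulting to TimeoutError, so a dropped connection can surface its real error instead of an ambiguous CancelledError. A minimal sketch of the event-based pending-request table this operates on (PendingRequests is illustrative, not the real JSONRPCConnection):

import asyncio


class PendingRequests:
    # each pending request waits on an event whose .result holds either the
    # response or an exception to raise
    def __init__(self):
        self._requests = {}

    async def wait_for(self, request_id):
        event = asyncio.Event()
        self._requests[request_id] = event
        await event.wait()
        if isinstance(event.result, Exception):
            raise event.result
        return event.result

    def raise_pending_requests(self, exception=None):
        exception = exception or asyncio.TimeoutError()
        for event in self._requests.values():
            event.result = exception
            event.set()
        self._requests.clear()


async def main():
    pending = PendingRequests()
    waiter = asyncio.ensure_future(pending.wait_for(1))
    await asyncio.sleep(0)  # let the waiter register itself
    pending.raise_pending_requests(ConnectionError('server went away'))
    try:
        await waiter
    except ConnectionError as error:
        print('pending request failed with:', error)

asyncio.run(main())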
@@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol):
             await asyncio.wait_for(self._can_send.wait(), secs)
         except asyncio.TimeoutError:
             self.abort()
-            raise asyncio.CancelledError(f'task timed out after {secs}s')
+            raise asyncio.TimeoutError(f'task timed out after {secs}s')

     async def _send_message(self, message):
         if not self._can_send.is_set():
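Raising TimeoutError here instead of CancelledError matters because the retriable_call loop shown earlier only catches TimeoutError and ConnectionError; a CancelledError would escape it and tear the caller down. A tiny illustration of the difference (both coroutines are made up for the example):

import asyncio


async def might_retry(call):
    try:
        return await call()
    except asyncio.TimeoutError:
        return 'will retry'


async def main():
    async def times_out():
        raise asyncio.TimeoutError('task timed out after 10s')

    async def got_cancelled():
        raise asyncio.CancelledError('task timed out after 10s')

    print(await might_retry(times_out))   # prints: will retry
    try:
        await might_retry(got_cancelled)  # cancellation escapes the retry handler
    except asyncio.CancelledError:
        print('cancellation is not retried')

asyncio.run(main())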
@@ -215,6 +215,7 @@ class SessionBase(asyncio.Protocol):
         self._address = None
         self.transport = None
         self._task_group.cancel()
+        if self._pm_task:
+            self._pm_task.cancel()
         # Release waiting tasks
         self._can_send.set()
@@ -456,7 +457,7 @@ class RPCSession(SessionBase):

     def connection_lost(self, exc):
         # Cancel pending requests and message processing
-        self.connection.time_out_pending_requests()
+        self.connection.raise_pending_requests(exc)
         super().connection_lost(exc)

     # External API