Merge pull request #2398 from lbryio/retriable_batcheable_networking
Improve wallet server selection
This commit is contained in:
commit
a11956ece0
7 changed files with 150 additions and 75 deletions
|
@ -255,7 +255,7 @@ class TestQueries(AsyncioTestCase):
|
|||
self.ledger.db.db.execute_fetchall = check_parameters_length
|
||||
account = await self.create_account()
|
||||
tx = await self.create_tx_from_nothing(account, 0)
|
||||
for height in range(1200):
|
||||
for height in range(1, 1200):
|
||||
tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height)
|
||||
variable_limit = self.ledger.db.MAX_QUERY_VARIABLES
|
||||
for limit in range(variable_limit-2, variable_limit+2):
|
||||
|
|
|
@ -18,6 +18,9 @@ class MockNetwork:
|
|||
self.get_transaction_called = []
|
||||
self.is_connected = False
|
||||
|
||||
def retriable_call(self, function, *args, **kwargs):
|
||||
return function(*args, **kwargs)
|
||||
|
||||
async def get_history(self, address):
|
||||
self.get_history_called.append(address)
|
||||
self.address = address
|
||||
|
@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase):
|
|||
)
|
||||
|
||||
|
||||
class MocHeaderNetwork:
|
||||
class MocHeaderNetwork(MockNetwork):
|
||||
def __init__(self, responses):
|
||||
super().__init__(None, None)
|
||||
self.responses = responses
|
||||
|
||||
async def get_headers(self, height, blocks):
|
||||
|
|
|
@ -310,7 +310,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
subscription_update = False
|
||||
|
||||
if not headers:
|
||||
header_response = await self.network.get_headers(height, 2001)
|
||||
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
|
||||
headers = header_response['hex']
|
||||
|
||||
if not headers:
|
||||
|
@ -395,12 +395,8 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
|
||||
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
|
||||
if self.network.is_connected and addresses:
|
||||
await asyncio.wait([
|
||||
self.subscribe_address(address_manager, address) for address in addresses
|
||||
])
|
||||
|
||||
async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
|
||||
remote_status = await self.network.subscribe_address(address)
|
||||
async for address, remote_status in self.network.subscribe_address(*addresses):
|
||||
# subscribe isnt a retriable call as it happens right after a connection is made
|
||||
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
|
||||
|
||||
def process_status_update(self, update):
|
||||
|
@ -417,7 +413,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
if local_status == remote_status:
|
||||
return
|
||||
|
||||
remote_history = await self.network.get_history(address)
|
||||
remote_history = await self.network.retriable_call(self.network.get_history, address)
|
||||
|
||||
cache_tasks = []
|
||||
synced_history = StringIO()
|
||||
|
@ -489,7 +485,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
|
||||
if tx is None:
|
||||
# fetch from network
|
||||
_raw = await self.network.get_transaction(txid)
|
||||
_raw = await self.network.retriable_call(self.network.get_transaction, txid)
|
||||
if _raw:
|
||||
tx = self.transaction_class(unhexlify(_raw))
|
||||
await self.maybe_verify_transaction(tx, remote_height)
|
||||
|
@ -510,7 +506,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
async def maybe_verify_transaction(self, tx, remote_height):
|
||||
tx.height = remote_height
|
||||
if 0 < remote_height <= len(self.headers):
|
||||
merkle = await self.network.get_merkle(tx.id, remote_height)
|
||||
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
|
||||
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||
header = self.headers[remote_height]
|
||||
tx.position = merkle['pos']
|
||||
|
@ -524,6 +520,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
return None
|
||||
|
||||
def broadcast(self, tx):
|
||||
# broadcast cant be a retriable call yet
|
||||
return self.network.broadcast(hexlify(tx.raw).decode())
|
||||
|
||||
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
|
||||
|
@ -545,4 +542,4 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
)) for address_record in records
|
||||
], timeout=timeout)
|
||||
if pending:
|
||||
raise TimeoutError('Timed out waiting for transaction.')
|
||||
raise asyncio.TimeoutError('Timed out waiting for transaction.')
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import logging
|
||||
import asyncio
|
||||
from operator import itemgetter
|
||||
from typing import Dict, Optional
|
||||
from time import time, perf_counter
|
||||
from typing import Dict, Optional, Tuple
|
||||
from time import perf_counter
|
||||
|
||||
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
||||
|
||||
|
@ -13,17 +13,20 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ClientSession(BaseClientSession):
|
||||
|
||||
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
|
||||
self.network = network
|
||||
self.server = server
|
||||
super().__init__(*args, **kwargs)
|
||||
self._on_disconnect_controller = StreamController()
|
||||
self.on_disconnected = self._on_disconnect_controller.stream
|
||||
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
|
||||
self.framer.max_size = self.max_errors = 1 << 32
|
||||
self.bw_limit = -1
|
||||
self.timeout = timeout
|
||||
self.max_seconds_idle = timeout * 2
|
||||
self.response_time: Optional[float] = None
|
||||
self.connection_latency: Optional[float] = None
|
||||
self._response_samples = 0
|
||||
self.pending_amount = 0
|
||||
self._on_connect_cb = on_connect_callback or (lambda: None)
|
||||
self.trigger_urgent_reconnect = asyncio.Event()
|
||||
|
||||
|
@ -31,20 +34,38 @@ class ClientSession(BaseClientSession):
|
|||
def available(self):
|
||||
return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
|
||||
|
||||
async def send_request(self, method, args=()):
|
||||
try:
|
||||
@property
|
||||
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
|
||||
if not self.transport:
|
||||
return None
|
||||
return self.transport.get_extra_info('peername')
|
||||
|
||||
async def send_timed_server_version_request(self, args=()):
|
||||
log.debug("send version request to %s:%i", *self.server)
|
||||
start = perf_counter()
|
||||
result = await asyncio.wait_for(
|
||||
super().send_request('server.version', args), timeout=self.timeout
|
||||
)
|
||||
current_response_time = perf_counter() - start
|
||||
response_sum = (self.response_time or 0) * self._response_samples + current_response_time
|
||||
self.response_time = response_sum / (self._response_samples + 1)
|
||||
self._response_samples += 1
|
||||
return result
|
||||
|
||||
async def send_request(self, method, args=()):
|
||||
self.pending_amount += 1
|
||||
try:
|
||||
if method == 'server.version':
|
||||
return await self.send_timed_server_version_request(args)
|
||||
return await asyncio.wait_for(
|
||||
super().send_request(method, args), timeout=self.timeout
|
||||
)
|
||||
self.response_time = perf_counter() - start
|
||||
return result
|
||||
except RPCError as e:
|
||||
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
|
||||
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
|
||||
*self.server, *e.args)
|
||||
raise e
|
||||
except TimeoutError:
|
||||
self.response_time = None
|
||||
raise
|
||||
finally:
|
||||
self.pending_amount -= 1
|
||||
|
||||
async def ensure_session(self):
|
||||
# Handles reconnecting and maintaining a session alive
|
||||
|
@ -56,8 +77,8 @@ class ClientSession(BaseClientSession):
|
|||
await self.create_connection(self.timeout)
|
||||
await self.ensure_server_version()
|
||||
self._on_connect_cb()
|
||||
if (time() - self.last_send) > self.max_seconds_idle or self.response_time is None:
|
||||
await self.send_request('server.banner')
|
||||
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
|
||||
await self.ensure_server_version()
|
||||
retry_delay = default_delay
|
||||
except (asyncio.TimeoutError, OSError):
|
||||
await self.close()
|
||||
|
@ -67,6 +88,9 @@ class ClientSession(BaseClientSession):
|
|||
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
self.synchronous_close()
|
||||
raise
|
||||
finally:
|
||||
self.trigger_urgent_reconnect.clear()
|
||||
|
||||
|
@ -75,7 +99,9 @@ class ClientSession(BaseClientSession):
|
|||
|
||||
async def create_connection(self, timeout=6):
|
||||
connector = Connector(lambda: self, *self.server)
|
||||
start = perf_counter()
|
||||
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
|
||||
self.connection_latency = perf_counter() - start
|
||||
|
||||
async def handle_request(self, request):
|
||||
controller = self.network.subscription_controllers[request.method]
|
||||
|
@ -85,6 +111,9 @@ class ClientSession(BaseClientSession):
|
|||
log.debug("Connection lost: %s:%d", *self.server)
|
||||
super().connection_lost(exc)
|
||||
self.response_time = None
|
||||
self.connection_latency = None
|
||||
self._response_samples = 0
|
||||
self.pending_amount = 0
|
||||
self._on_disconnect_controller.add(True)
|
||||
|
||||
|
||||
|
@ -133,33 +162,34 @@ class BaseNetwork:
|
|||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
if self.session_pool:
|
||||
self.session_pool.stop()
|
||||
if self.is_connected:
|
||||
disconnected = self.client.on_disconnected.first
|
||||
await self.client.close()
|
||||
await disconnected
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
return self.client and not self.client.is_closing()
|
||||
|
||||
def rpc(self, list_or_method, args):
|
||||
fastest = self.session_pool.fastest_session
|
||||
if fastest is not None and self.client != fastest:
|
||||
self.switch_event.set()
|
||||
if self.is_connected:
|
||||
return self.client.send_request(list_or_method, args)
|
||||
def rpc(self, list_or_method, args, session=None):
|
||||
session = session or self.session_pool.fastest_session
|
||||
if session:
|
||||
return session.send_request(list_or_method, args)
|
||||
else:
|
||||
self.session_pool.trigger_nodelay_connect()
|
||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||
|
||||
async def retriable_call(self, function, *args, **kwargs):
|
||||
while self.running:
|
||||
try:
|
||||
return await function(*args, **kwargs)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("Wallet server call timed out, retrying.")
|
||||
except ConnectionError:
|
||||
if not self.is_connected and self.running:
|
||||
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
|
||||
await self.on_connected.first
|
||||
|
||||
def _update_remote_height(self, header_args):
|
||||
self.remote_height = header_args[0]["height"]
|
||||
|
||||
def broadcast(self, raw_transaction):
|
||||
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
|
||||
|
||||
def get_history(self, address):
|
||||
return self.rpc('blockchain.address.get_history', [address])
|
||||
|
||||
|
@ -175,11 +205,19 @@ class BaseNetwork:
|
|||
def get_headers(self, height, count=10000):
|
||||
return self.rpc('blockchain.block.headers', [height, count])
|
||||
|
||||
def subscribe_headers(self):
|
||||
return self.rpc('blockchain.headers.subscribe', [True])
|
||||
# --- Subscribes and broadcasts are always aimed towards the master client directly
|
||||
def broadcast(self, raw_transaction):
|
||||
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
|
||||
|
||||
def subscribe_address(self, address):
|
||||
return self.rpc('blockchain.address.subscribe', [address])
|
||||
def subscribe_headers(self):
|
||||
return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
|
||||
|
||||
async def subscribe_address(self, *addresses):
|
||||
async with self.client.send_batch() as batch:
|
||||
for address in addresses:
|
||||
batch.add_request('blockchain.address.subscribe', [address])
|
||||
for address, status in zip(addresses, batch.results):
|
||||
yield address, status
|
||||
|
||||
|
||||
class SessionPool:
|
||||
|
@ -203,31 +241,62 @@ class SessionPool:
|
|||
if not self.available_sessions:
|
||||
return None
|
||||
return min(
|
||||
[(session.response_time, session) for session in self.available_sessions], key=itemgetter(0)
|
||||
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
|
||||
for session in self.available_sessions],
|
||||
key=itemgetter(0)
|
||||
)[1]
|
||||
|
||||
def start(self, default_servers):
|
||||
callback = self.new_connection_event.set
|
||||
self.sessions = {
|
||||
ClientSession(
|
||||
network=self.network, server=server, on_connect_callback=callback
|
||||
): None for server in default_servers
|
||||
}
|
||||
self.ensure_connections()
|
||||
def _get_session_connect_callback(self, session: ClientSession):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def stop(self):
|
||||
for session, task in self.sessions.items():
|
||||
task.cancel()
|
||||
session.abort()
|
||||
self.sessions.clear()
|
||||
def callback():
|
||||
duplicate_connections = [
|
||||
s for s in self.sessions
|
||||
if s is not session and s.server_address_and_port == session.server_address_and_port
|
||||
]
|
||||
already_connected = None if not duplicate_connections else duplicate_connections[0]
|
||||
if already_connected:
|
||||
self.sessions.pop(session).cancel()
|
||||
session.synchronous_close()
|
||||
log.info("wallet server %s resolves to the same server as %s, rechecking in an hour",
|
||||
session.server[0], already_connected.server[0])
|
||||
loop.call_later(3600, self._connect_session, session.server)
|
||||
return
|
||||
self.new_connection_event.set()
|
||||
log.info("connected to %s:%i", *session.server)
|
||||
|
||||
def ensure_connections(self):
|
||||
for session, task in list(self.sessions.items()):
|
||||
return callback
|
||||
|
||||
def _connect_session(self, server: Tuple[str, int]):
|
||||
session = None
|
||||
for s in self.sessions:
|
||||
if s.server == server:
|
||||
session = s
|
||||
break
|
||||
if not session:
|
||||
session = ClientSession(
|
||||
network=self.network, server=server
|
||||
)
|
||||
session._on_connect_cb = self._get_session_connect_callback(session)
|
||||
task = self.sessions.get(session, None)
|
||||
if not task or task.done():
|
||||
task = asyncio.create_task(session.ensure_session())
|
||||
task.add_done_callback(lambda _: self.ensure_connections())
|
||||
self.sessions[session] = task
|
||||
|
||||
def start(self, default_servers):
|
||||
for server in default_servers:
|
||||
self._connect_session(server)
|
||||
|
||||
def stop(self):
|
||||
for task in self.sessions.values():
|
||||
task.cancel()
|
||||
self.sessions.clear()
|
||||
|
||||
def ensure_connections(self):
|
||||
for session in self.sessions:
|
||||
self._connect_session(session.server)
|
||||
|
||||
def trigger_nodelay_connect(self):
|
||||
# used when other parts of the system sees we might have internet back
|
||||
# bypasses the retry interval
|
||||
|
|
|
@ -746,10 +746,8 @@ class JSONRPCConnection(object):
|
|||
self._protocol = item
|
||||
return self.receive_message(message)
|
||||
|
||||
def time_out_pending_requests(self):
|
||||
"""Times out all pending requests."""
|
||||
# this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing?
|
||||
exception = asyncio.TimeoutError()
|
||||
def raise_pending_requests(self, exception):
|
||||
exception = exception or asyncio.TimeoutError()
|
||||
for request, event in self._requests.values():
|
||||
event.result = exception
|
||||
event.set()
|
||||
|
|
|
@ -103,7 +103,7 @@ class SessionBase(asyncio.Protocol):
|
|||
# Force-close a connection if a send doesn't succeed in this time
|
||||
self.max_send_delay = 60
|
||||
# Statistics. The RPC object also keeps its own statistics.
|
||||
self.start_time = time.time()
|
||||
self.start_time = time.perf_counter()
|
||||
self.errors = 0
|
||||
self.send_count = 0
|
||||
self.send_size = 0
|
||||
|
@ -123,7 +123,7 @@ class SessionBase(asyncio.Protocol):
|
|||
# A non-positive value means not to limit concurrency
|
||||
if self.bw_limit <= 0:
|
||||
return
|
||||
now = time.time()
|
||||
now = time.perf_counter()
|
||||
# Reduce the recorded usage in proportion to the elapsed time
|
||||
refund = (now - self.bw_time) * (self.bw_limit / 3600)
|
||||
self.bw_charge = max(0, self.bw_charge - int(refund))
|
||||
|
@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol):
|
|||
await asyncio.wait_for(self._can_send.wait(), secs)
|
||||
except asyncio.TimeoutError:
|
||||
self.abort()
|
||||
raise asyncio.CancelledError(f'task timed out after {secs}s')
|
||||
raise asyncio.TimeoutError(f'task timed out after {secs}s')
|
||||
|
||||
async def _send_message(self, message):
|
||||
if not self._can_send.is_set():
|
||||
|
@ -156,7 +156,7 @@ class SessionBase(asyncio.Protocol):
|
|||
self.send_size += len(framed_message)
|
||||
self._using_bandwidth(len(framed_message))
|
||||
self.send_count += 1
|
||||
self.last_send = time.time()
|
||||
self.last_send = time.perf_counter()
|
||||
if self.verbosity >= 4:
|
||||
self.logger.debug(f'Sending framed message {framed_message}')
|
||||
self.transport.write(framed_message)
|
||||
|
@ -215,6 +215,7 @@ class SessionBase(asyncio.Protocol):
|
|||
self._address = None
|
||||
self.transport = None
|
||||
self._task_group.cancel()
|
||||
if self._pm_task:
|
||||
self._pm_task.cancel()
|
||||
# Release waiting tasks
|
||||
self._can_send.set()
|
||||
|
@ -253,6 +254,7 @@ class SessionBase(asyncio.Protocol):
|
|||
if self.transport:
|
||||
self.transport.abort()
|
||||
|
||||
# TODO: replace with synchronous_close
|
||||
async def close(self, *, force_after=30):
|
||||
"""Close the connection and return when closed."""
|
||||
self._close()
|
||||
|
@ -262,6 +264,11 @@ class SessionBase(asyncio.Protocol):
|
|||
self.abort()
|
||||
await self._pm_task
|
||||
|
||||
def synchronous_close(self):
|
||||
self._close()
|
||||
if self._pm_task and not self._pm_task.done():
|
||||
self._pm_task.cancel()
|
||||
|
||||
|
||||
class MessageSession(SessionBase):
|
||||
"""Session class for protocols where messages are not tied to responses,
|
||||
|
@ -296,7 +303,7 @@ class MessageSession(SessionBase):
|
|||
)
|
||||
self._bump_errors()
|
||||
else:
|
||||
self.last_recv = time.time()
|
||||
self.last_recv = time.perf_counter()
|
||||
self.recv_count += 1
|
||||
if self.recv_count % 10 == 0:
|
||||
await self._update_concurrency()
|
||||
|
@ -416,7 +423,7 @@ class RPCSession(SessionBase):
|
|||
self.logger.warning(f'{e!r}')
|
||||
continue
|
||||
|
||||
self.last_recv = time.time()
|
||||
self.last_recv = time.perf_counter()
|
||||
self.recv_count += 1
|
||||
if self.recv_count % 10 == 0:
|
||||
await self._update_concurrency()
|
||||
|
@ -456,7 +463,7 @@ class RPCSession(SessionBase):
|
|||
|
||||
def connection_lost(self, exc):
|
||||
# Cancel pending requests and message processing
|
||||
self.connection.time_out_pending_requests()
|
||||
self.connection.raise_pending_requests(exc)
|
||||
super().connection_lost(exc)
|
||||
|
||||
# External API
|
||||
|
|
|
@ -258,7 +258,7 @@ class SessionManager:
|
|||
session_timeout = self.env.session_timeout
|
||||
while True:
|
||||
await sleep(session_timeout // 10)
|
||||
stale_cutoff = time.time() - session_timeout
|
||||
stale_cutoff = time.perf_counter() - session_timeout
|
||||
stale_sessions = [session for session in self.sessions
|
||||
if session.last_recv < stale_cutoff]
|
||||
if stale_sessions:
|
||||
|
|
Loading…
Reference in a new issue