Merge pull request #2398 from lbryio/retriable_batcheable_networking

Improve wallet server selection
This commit is contained in:
Jack Robison 2019-08-20 20:10:05 -04:00 committed by GitHub
commit a11956ece0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 150 additions and 75 deletions

View file

@ -255,7 +255,7 @@ class TestQueries(AsyncioTestCase):
self.ledger.db.db.execute_fetchall = check_parameters_length self.ledger.db.db.execute_fetchall = check_parameters_length
account = await self.create_account() account = await self.create_account()
tx = await self.create_tx_from_nothing(account, 0) tx = await self.create_tx_from_nothing(account, 0)
for height in range(1200): for height in range(1, 1200):
tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height) tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height)
variable_limit = self.ledger.db.MAX_QUERY_VARIABLES variable_limit = self.ledger.db.MAX_QUERY_VARIABLES
for limit in range(variable_limit-2, variable_limit+2): for limit in range(variable_limit-2, variable_limit+2):

View file

@ -18,6 +18,9 @@ class MockNetwork:
self.get_transaction_called = [] self.get_transaction_called = []
self.is_connected = False self.is_connected = False
def retriable_call(self, function, *args, **kwargs):
    """Test stand-in for the network's retry wrapper: invoke ``function`` once.

    No retrying is performed; the positional and keyword arguments are
    forwarded unchanged and whatever ``function`` returns is handed back.
    """
    outcome = function(*args, **kwargs)
    return outcome
async def get_history(self, address): async def get_history(self, address):
self.get_history_called.append(address) self.get_history_called.append(address)
self.address = address self.address = address
@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase):
) )
class MocHeaderNetwork: class MocHeaderNetwork(MockNetwork):
def __init__(self, responses): def __init__(self, responses):
super().__init__(None, None)
self.responses = responses self.responses = responses
async def get_headers(self, height, blocks): async def get_headers(self, height, blocks):

View file

@ -310,7 +310,7 @@ class BaseLedger(metaclass=LedgerRegistry):
subscription_update = False subscription_update = False
if not headers: if not headers:
header_response = await self.network.get_headers(height, 2001) header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
headers = header_response['hex'] headers = header_response['hex']
if not headers: if not headers:
@ -395,13 +395,9 @@ class BaseLedger(metaclass=LedgerRegistry):
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]): async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
if self.network.is_connected and addresses: if self.network.is_connected and addresses:
await asyncio.wait([ async for address, remote_status in self.network.subscribe_address(*addresses):
self.subscribe_address(address_manager, address) for address in addresses # subscribe isnt a retriable call as it happens right after a connection is made
]) self._update_tasks.add(self.update_history(address, remote_status, address_manager))
async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
remote_status = await self.network.subscribe_address(address)
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
def process_status_update(self, update): def process_status_update(self, update):
address, remote_status = update address, remote_status = update
@ -417,7 +413,7 @@ class BaseLedger(metaclass=LedgerRegistry):
if local_status == remote_status: if local_status == remote_status:
return return
remote_history = await self.network.get_history(address) remote_history = await self.network.retriable_call(self.network.get_history, address)
cache_tasks = [] cache_tasks = []
synced_history = StringIO() synced_history = StringIO()
@ -489,7 +485,7 @@ class BaseLedger(metaclass=LedgerRegistry):
if tx is None: if tx is None:
# fetch from network # fetch from network
_raw = await self.network.get_transaction(txid) _raw = await self.network.retriable_call(self.network.get_transaction, txid)
if _raw: if _raw:
tx = self.transaction_class(unhexlify(_raw)) tx = self.transaction_class(unhexlify(_raw))
await self.maybe_verify_transaction(tx, remote_height) await self.maybe_verify_transaction(tx, remote_height)
@ -510,7 +506,7 @@ class BaseLedger(metaclass=LedgerRegistry):
async def maybe_verify_transaction(self, tx, remote_height): async def maybe_verify_transaction(self, tx, remote_height):
tx.height = remote_height tx.height = remote_height
if 0 < remote_height <= len(self.headers): if 0 < remote_height <= len(self.headers):
merkle = await self.network.get_merkle(tx.id, remote_height) merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[remote_height] header = self.headers[remote_height]
tx.position = merkle['pos'] tx.position = merkle['pos']
@ -524,6 +520,7 @@ class BaseLedger(metaclass=LedgerRegistry):
return None return None
def broadcast(self, tx): def broadcast(self, tx):
# broadcast cant be a retriable call yet
return self.network.broadcast(hexlify(tx.raw).decode()) return self.network.broadcast(hexlify(tx.raw).decode())
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None): async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
@ -545,4 +542,4 @@ class BaseLedger(metaclass=LedgerRegistry):
)) for address_record in records )) for address_record in records
], timeout=timeout) ], timeout=timeout)
if pending: if pending:
raise TimeoutError('Timed out waiting for transaction.') raise asyncio.TimeoutError('Timed out waiting for transaction.')

View file

@ -1,8 +1,8 @@
import logging import logging
import asyncio import asyncio
from operator import itemgetter from operator import itemgetter
from typing import Dict, Optional from typing import Dict, Optional, Tuple
from time import time, perf_counter from time import perf_counter
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
@ -13,17 +13,20 @@ log = logging.getLogger(__name__)
class ClientSession(BaseClientSession): class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs): def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
self.network = network self.network = network
self.server = server self.server = server
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController() self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream self.on_disconnected = self._on_disconnect_controller.stream
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32 self.framer.max_size = self.max_errors = 1 << 32
self.bw_limit = -1
self.timeout = timeout self.timeout = timeout
self.max_seconds_idle = timeout * 2 self.max_seconds_idle = timeout * 2
self.response_time: Optional[float] = None self.response_time: Optional[float] = None
self.connection_latency: Optional[float] = None
self._response_samples = 0
self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None) self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event() self.trigger_urgent_reconnect = asyncio.Event()
@ -31,20 +34,38 @@ class ClientSession(BaseClientSession):
def available(self): def available(self):
return not self.is_closing() and self._can_send.is_set() and self.response_time is not None return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
@property
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
    """Return the transport's ``'peername'`` extra info, or None without a transport."""
    transport = self.transport
    return transport.get_extra_info('peername') if transport else None
async def send_timed_server_version_request(self, args=()):
    """Send ``server.version`` and fold its round-trip time into ``response_time``.

    The request goes through the parent class' ``send_request`` and is bounded
    by ``self.timeout``.  ``response_time`` is kept as the arithmetic mean of
    all samples taken so far (``_response_samples`` counts them); a missing
    ``response_time`` is treated as 0 when computing the mean.

    Returns the result of the ``server.version`` request.
    """
    log.debug("send version request to %s:%i", *self.server)
    started_at = perf_counter()
    request = super().send_request('server.version', args)
    response = await asyncio.wait_for(request, timeout=self.timeout)
    elapsed = perf_counter() - started_at
    # Rebuild the previous total from the stored mean, then average in the new sample.
    samples = self._response_samples
    previous_total = (self.response_time or 0) * samples
    self.response_time = (previous_total + elapsed) / (samples + 1)
    self._response_samples = samples + 1
    return response
async def send_request(self, method, args=()): async def send_request(self, method, args=()):
self.pending_amount += 1
try: try:
start = perf_counter() if method == 'server.version':
result = await asyncio.wait_for( return await self.send_timed_server_version_request(args)
return await asyncio.wait_for(
super().send_request(method, args), timeout=self.timeout super().send_request(method, args), timeout=self.timeout
) )
self.response_time = perf_counter() - start
return result
except RPCError as e: except RPCError as e:
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args) log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
*self.server, *e.args)
raise e raise e
except TimeoutError: finally:
self.response_time = None self.pending_amount -= 1
raise
async def ensure_session(self): async def ensure_session(self):
# Handles reconnecting and maintaining a session alive # Handles reconnecting and maintaining a session alive
@ -56,8 +77,8 @@ class ClientSession(BaseClientSession):
await self.create_connection(self.timeout) await self.create_connection(self.timeout)
await self.ensure_server_version() await self.ensure_server_version()
self._on_connect_cb() self._on_connect_cb()
if (time() - self.last_send) > self.max_seconds_idle or self.response_time is None: if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.send_request('server.banner') await self.ensure_server_version()
retry_delay = default_delay retry_delay = default_delay
except (asyncio.TimeoutError, OSError): except (asyncio.TimeoutError, OSError):
await self.close() await self.close()
@ -67,6 +88,9 @@ class ClientSession(BaseClientSession):
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay) await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
except asyncio.CancelledError:
self.synchronous_close()
raise
finally: finally:
self.trigger_urgent_reconnect.clear() self.trigger_urgent_reconnect.clear()
@ -75,7 +99,9 @@ class ClientSession(BaseClientSession):
async def create_connection(self, timeout=6): async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server) connector = Connector(lambda: self, *self.server)
start = perf_counter()
await asyncio.wait_for(connector.create_connection(), timeout=timeout) await asyncio.wait_for(connector.create_connection(), timeout=timeout)
self.connection_latency = perf_counter() - start
async def handle_request(self, request): async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method] controller = self.network.subscription_controllers[request.method]
@ -85,6 +111,9 @@ class ClientSession(BaseClientSession):
log.debug("Connection lost: %s:%d", *self.server) log.debug("Connection lost: %s:%d", *self.server)
super().connection_lost(exc) super().connection_lost(exc)
self.response_time = None self.response_time = None
self.connection_latency = None
self._response_samples = 0
self.pending_amount = 0
self._on_disconnect_controller.add(True) self._on_disconnect_controller.add(True)
@ -133,33 +162,34 @@ class BaseNetwork:
async def stop(self): async def stop(self):
self.running = False self.running = False
if self.session_pool: self.session_pool.stop()
self.session_pool.stop()
if self.is_connected:
disconnected = self.client.on_disconnected.first
await self.client.close()
await disconnected
@property @property
def is_connected(self): def is_connected(self):
return self.client and not self.client.is_closing() return self.client and not self.client.is_closing()
def rpc(self, list_or_method, args): def rpc(self, list_or_method, args, session=None):
fastest = self.session_pool.fastest_session session = session or self.session_pool.fastest_session
if fastest is not None and self.client != fastest: if session:
self.switch_event.set() return session.send_request(list_or_method, args)
if self.is_connected:
return self.client.send_request(list_or_method, args)
else: else:
self.session_pool.trigger_nodelay_connect() self.session_pool.trigger_nodelay_connect()
raise ConnectionError("Attempting to send rpc request when connection is not available.") raise ConnectionError("Attempting to send rpc request when connection is not available.")
async def retriable_call(self, function, *args, **kwargs):
    """Call ``function(*args, **kwargs)`` and await it, retrying on transient failures.

    The call is repeated for as long as ``self.running`` is truthy:

    * ``asyncio.TimeoutError`` -> a warning is logged and the call is retried.
    * ``ConnectionError`` while ``self.is_connected`` is falsy -> a warning is
      logged and the retry waits for ``self.on_connected.first``.
    * ``ConnectionError`` while still connected -> the retry is delayed briefly.

    Returns the awaited result of ``function``.  If ``self.running`` turns
    falsy before a call succeeds, the loop ends and ``None`` is returned.
    """
    while self.running:
        try:
            return await function(*args, **kwargs)
        except asyncio.TimeoutError:
            log.warning("Wallet server call timed out, retrying.")
        except ConnectionError:
            if not self.running:
                break
            if not self.is_connected:
                log.warning("Wallet server unavailable, waiting for it to come back and retry.")
                await self.on_connected.first
            else:
                # ``function`` can raise synchronously (before anything is
                # awaited).  Without a suspension point here the loop would
                # spin and starve the event loop, so nothing else could ever
                # run to make a session available again.
                await asyncio.sleep(0.1)
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"] self.remote_height = header_args[0]["height"]
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
def get_history(self, address): def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address]) return self.rpc('blockchain.address.get_history', [address])
@ -175,11 +205,19 @@ class BaseNetwork:
def get_headers(self, height, count=10000): def get_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', [height, count]) return self.rpc('blockchain.block.headers', [height, count])
def subscribe_headers(self): # --- Subscribes and broadcasts are always aimed towards the master client directly
return self.rpc('blockchain.headers.subscribe', [True]) def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
def subscribe_address(self, address): def subscribe_headers(self):
return self.rpc('blockchain.address.subscribe', [address]) return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
async def subscribe_address(self, *addresses):
    """Subscribe to every address in one batch on ``self.client``.

    One ``blockchain.address.subscribe`` request is queued per address inside
    a single ``send_batch()`` context.  After the context exits, each address
    is paired positionally with the matching entry of ``batch.results``.

    Yields ``(address, status)`` tuples in the order the addresses were given.
    Yields nothing, and opens no batch, when called without addresses.
    """
    if not addresses:
        # Nothing to subscribe to: do not submit a batch with zero requests.
        return
    async with self.client.send_batch() as batch:
        for address in addresses:
            batch.add_request('blockchain.address.subscribe', [address])
    for address, status in zip(addresses, batch.results):
        yield address, status
class SessionPool: class SessionPool:
@ -203,30 +241,61 @@ class SessionPool:
if not self.available_sessions: if not self.available_sessions:
return None return None
return min( return min(
[(session.response_time, session) for session in self.available_sessions], key=itemgetter(0) [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
for session in self.available_sessions],
key=itemgetter(0)
)[1] )[1]
def _get_session_connect_callback(self, session: ClientSession):
    """Build the zero-argument on-connect callback for ``session``.

    When invoked, the callback looks for another pooled session whose
    ``server_address_and_port`` equals this session's.  If one exists, this
    session is removed from the pool, its task cancelled, the session closed
    synchronously, and a reconnect to the same server is scheduled on the
    event loop 3600 seconds later.  Otherwise ``new_connection_event`` is set
    and the connection is logged.
    """
    event_loop = asyncio.get_event_loop()

    def callback():
        twin = next(
            (
                other for other in self.sessions
                if other is not session
                and other.server_address_and_port == session.server_address_and_port
            ),
            None,
        )
        if not twin:
            self.new_connection_event.set()
            log.info("connected to %s:%i", *session.server)
            return
        # Same remote endpoint reached through a different server entry: drop this one.
        self.sessions.pop(session).cancel()
        session.synchronous_close()
        log.info("wallet server %s resolves to the same server as %s, rechecking in an hour",
                 session.server[0], twin.server[0])
        event_loop.call_later(3600, self._connect_session, session.server)

    return callback
def _connect_session(self, server: Tuple[str, int]):
session = None
for s in self.sessions:
if s.server == server:
session = s
break
if not session:
session = ClientSession(
network=self.network, server=server
)
session._on_connect_cb = self._get_session_connect_callback(session)
task = self.sessions.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def start(self, default_servers): def start(self, default_servers):
callback = self.new_connection_event.set for server in default_servers:
self.sessions = { self._connect_session(server)
ClientSession(
network=self.network, server=server, on_connect_callback=callback
): None for server in default_servers
}
self.ensure_connections()
def stop(self): def stop(self):
for session, task in self.sessions.items(): for task in self.sessions.values():
task.cancel() task.cancel()
session.abort()
self.sessions.clear() self.sessions.clear()
def ensure_connections(self): def ensure_connections(self):
for session, task in list(self.sessions.items()): for session in self.sessions:
if not task or task.done(): self._connect_session(session.server)
task = asyncio.create_task(session.ensure_session())
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def trigger_nodelay_connect(self): def trigger_nodelay_connect(self):
# used when other parts of the system sees we might have internet back # used when other parts of the system sees we might have internet back

View file

@ -746,10 +746,8 @@ class JSONRPCConnection(object):
self._protocol = item self._protocol = item
return self.receive_message(message) return self.receive_message(message)
def time_out_pending_requests(self): def raise_pending_requests(self, exception):
"""Times out all pending requests.""" exception = exception or asyncio.TimeoutError()
# this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing?
exception = asyncio.TimeoutError()
for request, event in self._requests.values(): for request, event in self._requests.values():
event.result = exception event.result = exception
event.set() event.set()

View file

@ -103,7 +103,7 @@ class SessionBase(asyncio.Protocol):
# Force-close a connection if a send doesn't succeed in this time # Force-close a connection if a send doesn't succeed in this time
self.max_send_delay = 60 self.max_send_delay = 60
# Statistics. The RPC object also keeps its own statistics. # Statistics. The RPC object also keeps its own statistics.
self.start_time = time.time() self.start_time = time.perf_counter()
self.errors = 0 self.errors = 0
self.send_count = 0 self.send_count = 0
self.send_size = 0 self.send_size = 0
@ -123,7 +123,7 @@ class SessionBase(asyncio.Protocol):
# A non-positive value means not to limit concurrency # A non-positive value means not to limit concurrency
if self.bw_limit <= 0: if self.bw_limit <= 0:
return return
now = time.time() now = time.perf_counter()
# Reduce the recorded usage in proportion to the elapsed time # Reduce the recorded usage in proportion to the elapsed time
refund = (now - self.bw_time) * (self.bw_limit / 3600) refund = (now - self.bw_time) * (self.bw_limit / 3600)
self.bw_charge = max(0, self.bw_charge - int(refund)) self.bw_charge = max(0, self.bw_charge - int(refund))
@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol):
await asyncio.wait_for(self._can_send.wait(), secs) await asyncio.wait_for(self._can_send.wait(), secs)
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.abort() self.abort()
raise asyncio.CancelledError(f'task timed out after {secs}s') raise asyncio.TimeoutError(f'task timed out after {secs}s')
async def _send_message(self, message): async def _send_message(self, message):
if not self._can_send.is_set(): if not self._can_send.is_set():
@ -156,7 +156,7 @@ class SessionBase(asyncio.Protocol):
self.send_size += len(framed_message) self.send_size += len(framed_message)
self._using_bandwidth(len(framed_message)) self._using_bandwidth(len(framed_message))
self.send_count += 1 self.send_count += 1
self.last_send = time.time() self.last_send = time.perf_counter()
if self.verbosity >= 4: if self.verbosity >= 4:
self.logger.debug(f'Sending framed message {framed_message}') self.logger.debug(f'Sending framed message {framed_message}')
self.transport.write(framed_message) self.transport.write(framed_message)
@ -215,7 +215,8 @@ class SessionBase(asyncio.Protocol):
self._address = None self._address = None
self.transport = None self.transport = None
self._task_group.cancel() self._task_group.cancel()
self._pm_task.cancel() if self._pm_task:
self._pm_task.cancel()
# Release waiting tasks # Release waiting tasks
self._can_send.set() self._can_send.set()
@ -253,6 +254,7 @@ class SessionBase(asyncio.Protocol):
if self.transport: if self.transport:
self.transport.abort() self.transport.abort()
# TODO: replace with synchronous_close
async def close(self, *, force_after=30): async def close(self, *, force_after=30):
"""Close the connection and return when closed.""" """Close the connection and return when closed."""
self._close() self._close()
@ -262,6 +264,11 @@ class SessionBase(asyncio.Protocol):
self.abort() self.abort()
await self._pm_task await self._pm_task
def synchronous_close(self):
    """Close the connection without awaiting: call ``_close()`` then cancel ``_pm_task``.

    The task is cancelled only when it exists and has not finished yet.
    """
    self._close()
    pending = self._pm_task
    if not pending or pending.done():
        return
    pending.cancel()
class MessageSession(SessionBase): class MessageSession(SessionBase):
"""Session class for protocols where messages are not tied to responses, """Session class for protocols where messages are not tied to responses,
@ -296,7 +303,7 @@ class MessageSession(SessionBase):
) )
self._bump_errors() self._bump_errors()
else: else:
self.last_recv = time.time() self.last_recv = time.perf_counter()
self.recv_count += 1 self.recv_count += 1
if self.recv_count % 10 == 0: if self.recv_count % 10 == 0:
await self._update_concurrency() await self._update_concurrency()
@ -416,7 +423,7 @@ class RPCSession(SessionBase):
self.logger.warning(f'{e!r}') self.logger.warning(f'{e!r}')
continue continue
self.last_recv = time.time() self.last_recv = time.perf_counter()
self.recv_count += 1 self.recv_count += 1
if self.recv_count % 10 == 0: if self.recv_count % 10 == 0:
await self._update_concurrency() await self._update_concurrency()
@ -456,7 +463,7 @@ class RPCSession(SessionBase):
def connection_lost(self, exc): def connection_lost(self, exc):
# Cancel pending requests and message processing # Cancel pending requests and message processing
self.connection.time_out_pending_requests() self.connection.raise_pending_requests(exc)
super().connection_lost(exc) super().connection_lost(exc)
# External API # External API

View file

@ -258,7 +258,7 @@ class SessionManager:
session_timeout = self.env.session_timeout session_timeout = self.env.session_timeout
while True: while True:
await sleep(session_timeout // 10) await sleep(session_timeout // 10)
stale_cutoff = time.time() - session_timeout stale_cutoff = time.perf_counter() - session_timeout
stale_sessions = [session for session in self.sessions stale_sessions = [session for session in self.sessions
if session.last_recv < stale_cutoff] if session.last_recv < stale_cutoff]
if stale_sessions: if stale_sessions: