Merge pull request #2398 from lbryio/retriable_batcheable_networking

Improve wallet server selection
This commit is contained in:
Jack Robison 2019-08-20 20:10:05 -04:00 committed by GitHub
commit a11956ece0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 150 additions and 75 deletions

View file

@ -255,7 +255,7 @@ class TestQueries(AsyncioTestCase):
self.ledger.db.db.execute_fetchall = check_parameters_length
account = await self.create_account()
tx = await self.create_tx_from_nothing(account, 0)
for height in range(1200):
for height in range(1, 1200):
tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height)
variable_limit = self.ledger.db.MAX_QUERY_VARIABLES
for limit in range(variable_limit-2, variable_limit+2):

View file

@ -18,6 +18,9 @@ class MockNetwork:
self.get_transaction_called = []
self.is_connected = False
def retriable_call(self, function, *args, **kwargs):
return function(*args, **kwargs)
async def get_history(self, address):
self.get_history_called.append(address)
self.address = address
@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase):
)
class MocHeaderNetwork:
class MocHeaderNetwork(MockNetwork):
def __init__(self, responses):
super().__init__(None, None)
self.responses = responses
async def get_headers(self, height, blocks):

View file

@ -310,7 +310,7 @@ class BaseLedger(metaclass=LedgerRegistry):
subscription_update = False
if not headers:
header_response = await self.network.get_headers(height, 2001)
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
headers = header_response['hex']
if not headers:
@ -395,12 +395,8 @@ class BaseLedger(metaclass=LedgerRegistry):
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
if self.network.is_connected and addresses:
await asyncio.wait([
self.subscribe_address(address_manager, address) for address in addresses
])
async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
remote_status = await self.network.subscribe_address(address)
async for address, remote_status in self.network.subscribe_address(*addresses):
# subscribe isnt a retriable call as it happens right after a connection is made
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
def process_status_update(self, update):
@ -417,7 +413,7 @@ class BaseLedger(metaclass=LedgerRegistry):
if local_status == remote_status:
return
remote_history = await self.network.get_history(address)
remote_history = await self.network.retriable_call(self.network.get_history, address)
cache_tasks = []
synced_history = StringIO()
@ -489,7 +485,7 @@ class BaseLedger(metaclass=LedgerRegistry):
if tx is None:
# fetch from network
_raw = await self.network.get_transaction(txid)
_raw = await self.network.retriable_call(self.network.get_transaction, txid)
if _raw:
tx = self.transaction_class(unhexlify(_raw))
await self.maybe_verify_transaction(tx, remote_height)
@ -510,7 +506,7 @@ class BaseLedger(metaclass=LedgerRegistry):
async def maybe_verify_transaction(self, tx, remote_height):
tx.height = remote_height
if 0 < remote_height <= len(self.headers):
merkle = await self.network.get_merkle(tx.id, remote_height)
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[remote_height]
tx.position = merkle['pos']
@ -524,6 +520,7 @@ class BaseLedger(metaclass=LedgerRegistry):
return None
def broadcast(self, tx):
# broadcast cant be a retriable call yet
return self.network.broadcast(hexlify(tx.raw).decode())
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
@ -545,4 +542,4 @@ class BaseLedger(metaclass=LedgerRegistry):
)) for address_record in records
], timeout=timeout)
if pending:
raise TimeoutError('Timed out waiting for transaction.')
raise asyncio.TimeoutError('Timed out waiting for transaction.')

View file

@ -1,8 +1,8 @@
import logging
import asyncio
from operator import itemgetter
from typing import Dict, Optional
from time import time, perf_counter
from typing import Dict, Optional, Tuple
from time import perf_counter
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
@ -13,17 +13,20 @@ log = logging.getLogger(__name__)
class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
self.network = network
self.server = server
super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
self.framer.max_size = self.max_errors = 1 << 32
self.bw_limit = -1
self.timeout = timeout
self.max_seconds_idle = timeout * 2
self.response_time: Optional[float] = None
self.connection_latency: Optional[float] = None
self._response_samples = 0
self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event()
@ -31,20 +34,38 @@ class ClientSession(BaseClientSession):
def available(self):
return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
async def send_request(self, method, args=()):
try:
@property
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
if not self.transport:
return None
return self.transport.get_extra_info('peername')
async def send_timed_server_version_request(self, args=()):
log.debug("send version request to %s:%i", *self.server)
start = perf_counter()
result = await asyncio.wait_for(
super().send_request('server.version', args), timeout=self.timeout
)
current_response_time = perf_counter() - start
response_sum = (self.response_time or 0) * self._response_samples + current_response_time
self.response_time = response_sum / (self._response_samples + 1)
self._response_samples += 1
return result
async def send_request(self, method, args=()):
self.pending_amount += 1
try:
if method == 'server.version':
return await self.send_timed_server_version_request(args)
return await asyncio.wait_for(
super().send_request(method, args), timeout=self.timeout
)
self.response_time = perf_counter() - start
return result
except RPCError as e:
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
*self.server, *e.args)
raise e
except TimeoutError:
self.response_time = None
raise
finally:
self.pending_amount -= 1
async def ensure_session(self):
# Handles reconnecting and maintaining a session alive
@ -56,8 +77,8 @@ class ClientSession(BaseClientSession):
await self.create_connection(self.timeout)
await self.ensure_server_version()
self._on_connect_cb()
if (time() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.send_request('server.banner')
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.ensure_server_version()
retry_delay = default_delay
except (asyncio.TimeoutError, OSError):
await self.close()
@ -67,6 +88,9 @@ class ClientSession(BaseClientSession):
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
except asyncio.TimeoutError:
pass
except asyncio.CancelledError:
self.synchronous_close()
raise
finally:
self.trigger_urgent_reconnect.clear()
@ -75,7 +99,9 @@ class ClientSession(BaseClientSession):
async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server)
start = perf_counter()
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
self.connection_latency = perf_counter() - start
async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method]
@ -85,6 +111,9 @@ class ClientSession(BaseClientSession):
log.debug("Connection lost: %s:%d", *self.server)
super().connection_lost(exc)
self.response_time = None
self.connection_latency = None
self._response_samples = 0
self.pending_amount = 0
self._on_disconnect_controller.add(True)
@ -133,33 +162,34 @@ class BaseNetwork:
async def stop(self):
self.running = False
if self.session_pool:
self.session_pool.stop()
if self.is_connected:
disconnected = self.client.on_disconnected.first
await self.client.close()
await disconnected
@property
def is_connected(self):
return self.client and not self.client.is_closing()
def rpc(self, list_or_method, args):
fastest = self.session_pool.fastest_session
if fastest is not None and self.client != fastest:
self.switch_event.set()
if self.is_connected:
return self.client.send_request(list_or_method, args)
def rpc(self, list_or_method, args, session=None):
session = session or self.session_pool.fastest_session
if session:
return session.send_request(list_or_method, args)
else:
self.session_pool.trigger_nodelay_connect()
raise ConnectionError("Attempting to send rpc request when connection is not available.")
async def retriable_call(self, function, *args, **kwargs):
while self.running:
try:
return await function(*args, **kwargs)
except asyncio.TimeoutError:
log.warning("Wallet server call timed out, retrying.")
except ConnectionError:
if not self.is_connected and self.running:
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"]
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address])
@ -175,11 +205,19 @@ class BaseNetwork:
def get_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', [height, count])
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True])
# --- Subscribes and broadcasts are always aimed towards the master client directly
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
def subscribe_address(self, address):
return self.rpc('blockchain.address.subscribe', [address])
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
async def subscribe_address(self, *addresses):
async with self.client.send_batch() as batch:
for address in addresses:
batch.add_request('blockchain.address.subscribe', [address])
for address, status in zip(addresses, batch.results):
yield address, status
class SessionPool:
@ -203,31 +241,62 @@ class SessionPool:
if not self.available_sessions:
return None
return min(
[(session.response_time, session) for session in self.available_sessions], key=itemgetter(0)
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
for session in self.available_sessions],
key=itemgetter(0)
)[1]
def start(self, default_servers):
callback = self.new_connection_event.set
self.sessions = {
ClientSession(
network=self.network, server=server, on_connect_callback=callback
): None for server in default_servers
}
self.ensure_connections()
def _get_session_connect_callback(self, session: ClientSession):
loop = asyncio.get_event_loop()
def stop(self):
for session, task in self.sessions.items():
task.cancel()
session.abort()
self.sessions.clear()
def callback():
duplicate_connections = [
s for s in self.sessions
if s is not session and s.server_address_and_port == session.server_address_and_port
]
already_connected = None if not duplicate_connections else duplicate_connections[0]
if already_connected:
self.sessions.pop(session).cancel()
session.synchronous_close()
log.info("wallet server %s resolves to the same server as %s, rechecking in an hour",
session.server[0], already_connected.server[0])
loop.call_later(3600, self._connect_session, session.server)
return
self.new_connection_event.set()
log.info("connected to %s:%i", *session.server)
def ensure_connections(self):
for session, task in list(self.sessions.items()):
return callback
def _connect_session(self, server: Tuple[str, int]):
session = None
for s in self.sessions:
if s.server == server:
session = s
break
if not session:
session = ClientSession(
network=self.network, server=server
)
session._on_connect_cb = self._get_session_connect_callback(session)
task = self.sessions.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def start(self, default_servers):
for server in default_servers:
self._connect_session(server)
def stop(self):
for task in self.sessions.values():
task.cancel()
self.sessions.clear()
def ensure_connections(self):
for session in self.sessions:
self._connect_session(session.server)
def trigger_nodelay_connect(self):
# used when other parts of the system sees we might have internet back
# bypasses the retry interval

View file

@ -746,10 +746,8 @@ class JSONRPCConnection(object):
self._protocol = item
return self.receive_message(message)
def time_out_pending_requests(self):
"""Times out all pending requests."""
# this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing?
exception = asyncio.TimeoutError()
def raise_pending_requests(self, exception):
exception = exception or asyncio.TimeoutError()
for request, event in self._requests.values():
event.result = exception
event.set()

View file

@ -103,7 +103,7 @@ class SessionBase(asyncio.Protocol):
# Force-close a connection if a send doesn't succeed in this time
self.max_send_delay = 60
# Statistics. The RPC object also keeps its own statistics.
self.start_time = time.time()
self.start_time = time.perf_counter()
self.errors = 0
self.send_count = 0
self.send_size = 0
@ -123,7 +123,7 @@ class SessionBase(asyncio.Protocol):
# A non-positive value means not to limit concurrency
if self.bw_limit <= 0:
return
now = time.time()
now = time.perf_counter()
# Reduce the recorded usage in proportion to the elapsed time
refund = (now - self.bw_time) * (self.bw_limit / 3600)
self.bw_charge = max(0, self.bw_charge - int(refund))
@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol):
await asyncio.wait_for(self._can_send.wait(), secs)
except asyncio.TimeoutError:
self.abort()
raise asyncio.CancelledError(f'task timed out after {secs}s')
raise asyncio.TimeoutError(f'task timed out after {secs}s')
async def _send_message(self, message):
if not self._can_send.is_set():
@ -156,7 +156,7 @@ class SessionBase(asyncio.Protocol):
self.send_size += len(framed_message)
self._using_bandwidth(len(framed_message))
self.send_count += 1
self.last_send = time.time()
self.last_send = time.perf_counter()
if self.verbosity >= 4:
self.logger.debug(f'Sending framed message {framed_message}')
self.transport.write(framed_message)
@ -215,6 +215,7 @@ class SessionBase(asyncio.Protocol):
self._address = None
self.transport = None
self._task_group.cancel()
if self._pm_task:
self._pm_task.cancel()
# Release waiting tasks
self._can_send.set()
@ -253,6 +254,7 @@ class SessionBase(asyncio.Protocol):
if self.transport:
self.transport.abort()
# TODO: replace with synchronous_close
async def close(self, *, force_after=30):
"""Close the connection and return when closed."""
self._close()
@ -262,6 +264,11 @@ class SessionBase(asyncio.Protocol):
self.abort()
await self._pm_task
def synchronous_close(self):
self._close()
if self._pm_task and not self._pm_task.done():
self._pm_task.cancel()
class MessageSession(SessionBase):
"""Session class for protocols where messages are not tied to responses,
@ -296,7 +303,7 @@ class MessageSession(SessionBase):
)
self._bump_errors()
else:
self.last_recv = time.time()
self.last_recv = time.perf_counter()
self.recv_count += 1
if self.recv_count % 10 == 0:
await self._update_concurrency()
@ -416,7 +423,7 @@ class RPCSession(SessionBase):
self.logger.warning(f'{e!r}')
continue
self.last_recv = time.time()
self.last_recv = time.perf_counter()
self.recv_count += 1
if self.recv_count % 10 == 0:
await self._update_concurrency()
@ -456,7 +463,7 @@ class RPCSession(SessionBase):
def connection_lost(self, exc):
# Cancel pending requests and message processing
self.connection.time_out_pending_requests()
self.connection.raise_pending_requests(exc)
super().connection_lost(exc)
# External API

View file

@ -258,7 +258,7 @@ class SessionManager:
session_timeout = self.env.session_timeout
while True:
await sleep(session_timeout // 10)
stale_cutoff = time.time() - session_timeout
stale_cutoff = time.perf_counter() - session_timeout
stale_sessions = [session for session in self.sessions
if session.last_recv < stale_cutoff]
if stale_sessions: