Compare commits
1 commit
Author | SHA1 | Date | |
---|---|---|---|
|
593422f70c |
8 changed files with 300 additions and 175 deletions
|
@ -18,6 +18,9 @@ class MockNetwork:
|
|||
self.get_transaction_called = []
|
||||
self.is_connected = False
|
||||
|
||||
def retriable_call(self, function, *args, **kwargs):
|
||||
return function(*args, **kwargs)
|
||||
|
||||
async def get_history(self, address):
|
||||
self.get_history_called.append(address)
|
||||
self.address = address
|
||||
|
@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase):
|
|||
)
|
||||
|
||||
|
||||
class MocHeaderNetwork:
|
||||
class MocHeaderNetwork(MockNetwork):
|
||||
def __init__(self, responses):
|
||||
super().__init__(None, None)
|
||||
self.responses = responses
|
||||
|
||||
async def get_headers(self, height, blocks):
|
||||
|
|
|
@ -70,8 +70,10 @@ class BaseHeaders:
|
|||
return True
|
||||
|
||||
def __getitem__(self, height) -> dict:
|
||||
assert not isinstance(height, slice), \
|
||||
"Slicing of header chain has not been implemented yet."
|
||||
if isinstance(height, slice):
|
||||
raise NotImplementedError("Slicing of header chain has not been implemented yet.")
|
||||
if not 0 <= height <= self.height:
|
||||
raise IndexError(f"{height} is out of bounds, current height: {self.height}")
|
||||
return self.deserialize(height, self.get_raw_header(height))
|
||||
|
||||
def get_raw_header(self, height) -> bytes:
|
||||
|
|
|
@ -287,7 +287,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
subscription_update = False
|
||||
|
||||
if not headers:
|
||||
header_response = await self.network.get_headers(height, 2001)
|
||||
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
|
||||
headers = header_response['hex']
|
||||
|
||||
if not headers:
|
||||
|
@ -394,7 +394,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
if local_status == remote_status:
|
||||
return
|
||||
|
||||
remote_history = await self.network.get_history(address)
|
||||
remote_history = await self.network.retriable_call(self.network.get_history, address)
|
||||
|
||||
cache_tasks = []
|
||||
synced_history = StringIO()
|
||||
|
@ -447,6 +447,17 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
if address_manager is not None:
|
||||
await address_manager.ensure_address_gap()
|
||||
|
||||
local_status, local_history = await self.get_local_status_and_history(address)
|
||||
if local_status != remote_status:
|
||||
log.debug(
|
||||
"Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
|
||||
remote_status, len(remote_history), local_status, len(local_history)
|
||||
)
|
||||
log.debug("local: %s", local_history)
|
||||
log.debug("remote: %s", remote_history)
|
||||
else:
|
||||
log.debug("Sync completed for: %s", address)
|
||||
|
||||
async def cache_transaction(self, txid, remote_height):
|
||||
cache_item = self._tx_cache.get(txid)
|
||||
if cache_item is None:
|
||||
|
@ -466,7 +477,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
|
||||
if tx is None:
|
||||
# fetch from network
|
||||
_raw = await self.network.get_transaction(txid)
|
||||
_raw = await self.network.retriable_call(self.network.get_transaction, txid)
|
||||
if _raw:
|
||||
tx = self.transaction_class(unhexlify(_raw))
|
||||
await self.maybe_verify_transaction(tx, remote_height)
|
||||
|
@ -486,8 +497,8 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
|
||||
async def maybe_verify_transaction(self, tx, remote_height):
|
||||
tx.height = remote_height
|
||||
if 0 < remote_height <= len(self.headers):
|
||||
merkle = await self.network.get_merkle(tx.id, remote_height)
|
||||
if 0 < remote_height < len(self.headers):
|
||||
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
|
||||
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||
header = self.headers[remote_height]
|
||||
tx.position = merkle['pos']
|
||||
|
@ -501,6 +512,7 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
return None
|
||||
|
||||
def broadcast(self, tx):
|
||||
# broadcast cant be a retriable call yet
|
||||
return self.network.broadcast(hexlify(tx.raw).decode())
|
||||
|
||||
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
|
||||
|
@ -522,4 +534,4 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
)) for address_record in records
|
||||
], timeout=timeout)
|
||||
if pending:
|
||||
raise TimeoutError('Timed out waiting for transaction.')
|
||||
raise asyncio.TimeoutError('Timed out waiting for transaction.')
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import logging
|
||||
import asyncio
|
||||
from asyncio import CancelledError
|
||||
from time import time
|
||||
from typing import List
|
||||
import socket
|
||||
from operator import itemgetter
|
||||
from typing import Dict, Optional, Tuple
|
||||
from time import perf_counter
|
||||
|
||||
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
||||
|
||||
|
@ -14,71 +13,146 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ClientSession(BaseClientSession):
|
||||
|
||||
def __init__(self, *args, network, server, timeout=30, **kwargs):
|
||||
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
|
||||
self.network = network
|
||||
self.server = server
|
||||
super().__init__(*args, **kwargs)
|
||||
self._on_disconnect_controller = StreamController()
|
||||
self.on_disconnected = self._on_disconnect_controller.stream
|
||||
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
|
||||
self.framer.max_size = self.max_errors = 1 << 32
|
||||
self.bw_limit = -1
|
||||
self.timeout = timeout
|
||||
self.max_seconds_idle = timeout * 2
|
||||
self.ping_task = None
|
||||
self.response_time: Optional[float] = None
|
||||
self.connection_latency: Optional[float] = None
|
||||
self._response_samples = 0
|
||||
self.pending_amount = 0
|
||||
self._on_connect_cb = on_connect_callback or (lambda: None)
|
||||
self.trigger_urgent_reconnect = asyncio.Event()
|
||||
self._semaphore = asyncio.Semaphore(int(self.timeout))
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
|
||||
|
||||
@property
|
||||
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
|
||||
if not self.transport:
|
||||
return None
|
||||
return self.transport.get_extra_info('peername')
|
||||
|
||||
async def send_timed_server_version_request(self, args=(), timeout=None):
|
||||
timeout = timeout or self.timeout
|
||||
log.debug("send version request to %s:%i", *self.server)
|
||||
start = perf_counter()
|
||||
result = await asyncio.wait_for(
|
||||
super().send_request('server.version', args), timeout=timeout
|
||||
)
|
||||
current_response_time = perf_counter() - start
|
||||
response_sum = (self.response_time or 0) * self._response_samples + current_response_time
|
||||
self.response_time = response_sum / (self._response_samples + 1)
|
||||
self._response_samples += 1
|
||||
return result
|
||||
|
||||
async def send_request(self, method, args=()):
|
||||
try:
|
||||
return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout)
|
||||
except RPCError as e:
|
||||
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
|
||||
raise e
|
||||
except asyncio.TimeoutError:
|
||||
self.abort()
|
||||
raise
|
||||
self.pending_amount += 1
|
||||
async with self._semaphore:
|
||||
return await self._send_request(method, args)
|
||||
|
||||
async def ping_forever(self):
|
||||
async def _send_request(self, method, args=()):
|
||||
log.debug("send %s to %s:%i", method, *self.server)
|
||||
try:
|
||||
if method == 'server.version':
|
||||
reply = await self.send_timed_server_version_request(args, self.timeout)
|
||||
else:
|
||||
reply = await asyncio.wait_for(
|
||||
super().send_request(method, args), timeout=self.timeout
|
||||
)
|
||||
log.debug("got reply for %s from %s:%i", method, *self.server)
|
||||
return reply
|
||||
except RPCError as e:
|
||||
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
|
||||
*self.server, *e.args)
|
||||
raise e
|
||||
except ConnectionError:
|
||||
log.warning("connection to %s:%i lost", *self.server)
|
||||
self.synchronous_close()
|
||||
raise asyncio.CancelledError(f"connection to {self.server[0]}:{self.server[1]} lost")
|
||||
except asyncio.TimeoutError:
|
||||
log.info("timeout sending %s to %s:%i", method, *self.server)
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
log.info("cancelled sending %s to %s:%i", method, *self.server)
|
||||
self.synchronous_close()
|
||||
raise
|
||||
finally:
|
||||
self.pending_amount -= 1
|
||||
|
||||
async def ensure_session(self):
|
||||
# Handles reconnecting and maintaining a session alive
|
||||
# TODO: change to 'ping' on newer protocol (above 1.2)
|
||||
while not self.is_closing():
|
||||
if (time() - self.last_send) > self.max_seconds_idle:
|
||||
try:
|
||||
await self.send_request('server.banner')
|
||||
except:
|
||||
self.abort()
|
||||
raise
|
||||
await asyncio.sleep(self.max_seconds_idle//3)
|
||||
retry_delay = default_delay = 1.0
|
||||
while True:
|
||||
try:
|
||||
if self.is_closing():
|
||||
await self.create_connection(self.timeout)
|
||||
await self.ensure_server_version()
|
||||
self._on_connect_cb()
|
||||
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
|
||||
await self.ensure_server_version()
|
||||
retry_delay = default_delay
|
||||
except (asyncio.TimeoutError, OSError):
|
||||
await self.close()
|
||||
retry_delay = min(60, retry_delay * 2)
|
||||
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
|
||||
try:
|
||||
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
self.trigger_urgent_reconnect.clear()
|
||||
|
||||
async def ensure_server_version(self, required='1.2', timeout=3):
|
||||
return await asyncio.wait_for(
|
||||
self.send_request('server.version', [__version__, required]), timeout=timeout
|
||||
)
|
||||
|
||||
async def create_connection(self, timeout=6):
|
||||
connector = Connector(lambda: self, *self.server)
|
||||
start = perf_counter()
|
||||
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
|
||||
self.ping_task = asyncio.create_task(self.ping_forever())
|
||||
self.connection_latency = perf_counter() - start
|
||||
|
||||
async def handle_request(self, request):
|
||||
controller = self.network.subscription_controllers[request.method]
|
||||
controller.add(request.args)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
log.debug("Connection lost: %s:%d", *self.server)
|
||||
super().connection_lost(exc)
|
||||
self.response_time = None
|
||||
self.connection_latency = None
|
||||
self._response_samples = 0
|
||||
self.pending_amount = 0
|
||||
self._on_disconnect_controller.add(True)
|
||||
if self.ping_task:
|
||||
self.ping_task.cancel()
|
||||
|
||||
|
||||
class BaseNetwork:
|
||||
|
||||
def __init__(self, ledger):
|
||||
self.config = ledger.config
|
||||
self.client: ClientSession = None
|
||||
self.session_pool: SessionPool = None
|
||||
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
|
||||
self.client: Optional[ClientSession] = None
|
||||
self.running = False
|
||||
self.remote_height: int = 0
|
||||
|
||||
self._on_connected_controller = StreamController()
|
||||
self.on_connected = self._on_connected_controller.stream
|
||||
|
||||
self._on_header_controller = StreamController()
|
||||
self._on_header_controller = StreamController(merge_repeated_events=True)
|
||||
self.on_header = self._on_header_controller.stream
|
||||
|
||||
self._on_status_controller = StreamController()
|
||||
self._on_status_controller = StreamController(merge_repeated_events=True)
|
||||
self.on_status = self._on_status_controller.stream
|
||||
|
||||
self.subscription_controllers = {
|
||||
|
@ -86,79 +160,72 @@ class BaseNetwork:
|
|||
'blockchain.address.subscribe': self._on_status_controller,
|
||||
}
|
||||
|
||||
async def switch_to_fastest(self):
|
||||
try:
|
||||
client = await asyncio.wait_for(self.session_pool.wait_for_fastest_session(), 30)
|
||||
except asyncio.TimeoutError:
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
self.client = None
|
||||
for session in self.session_pool.sessions:
|
||||
session.synchronous_close()
|
||||
log.warning("not connected to any wallet servers")
|
||||
return
|
||||
current_client = self.client
|
||||
self.client = client
|
||||
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
|
||||
self._on_connected_controller.add(True)
|
||||
try:
|
||||
self._update_remote_height((await self.subscribe_headers(),))
|
||||
log.info("Subscribed to headers: %s:%d", *self.client.server)
|
||||
except asyncio.TimeoutError:
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
self.client = current_client
|
||||
return
|
||||
self.session_pool.new_connection_event.clear()
|
||||
return await self.session_pool.new_connection_event.wait()
|
||||
|
||||
async def start(self):
|
||||
self.running = True
|
||||
connect_timeout = self.config.get('connect_timeout', 6)
|
||||
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
|
||||
self.session_pool.start(self.config['default_servers'])
|
||||
self.on_header.listen(self._update_remote_height)
|
||||
while True:
|
||||
try:
|
||||
self.client = await self.pick_fastest_session()
|
||||
if self.is_connected:
|
||||
await self.ensure_server_version()
|
||||
self._update_remote_height((await self.subscribe_headers(),))
|
||||
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
|
||||
self._on_connected_controller.add(True)
|
||||
await self.client.on_disconnected.first
|
||||
except CancelledError:
|
||||
self.running = False
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("Timed out while trying to find a server!")
|
||||
except Exception: # pylint: disable=broad-except
|
||||
log.exception("Exception while trying to find a server!")
|
||||
if not self.running:
|
||||
return
|
||||
elif self.client:
|
||||
await self.client.close()
|
||||
self.client.connection.cancel_pending_requests()
|
||||
while self.running:
|
||||
await self.switch_to_fastest()
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
if self.session_pool:
|
||||
self.session_pool.stop()
|
||||
if self.is_connected:
|
||||
disconnected = self.client.on_disconnected.first
|
||||
await self.client.close()
|
||||
await disconnected
|
||||
self.session_pool.stop()
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
return self.client is not None and not self.client.is_closing()
|
||||
return self.client and not self.client.is_closing()
|
||||
|
||||
def rpc(self, list_or_method, args):
|
||||
if self.is_connected:
|
||||
return self.client.send_request(list_or_method, args)
|
||||
def rpc(self, list_or_method, args, session=None):
|
||||
session = session or self.session_pool.fastest_session
|
||||
if session:
|
||||
return session.send_request(list_or_method, args)
|
||||
else:
|
||||
self.session_pool.trigger_nodelay_connect()
|
||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||
|
||||
async def pick_fastest_session(self):
|
||||
sessions = await self.session_pool.get_online_sessions()
|
||||
done, pending = await asyncio.wait([
|
||||
self.probe_session(session)
|
||||
for session in sessions if not session.is_closing()
|
||||
], return_when='FIRST_COMPLETED')
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
for session in done:
|
||||
return await session
|
||||
|
||||
async def probe_session(self, session: ClientSession):
|
||||
await session.send_request('server.banner')
|
||||
return session
|
||||
async def retriable_call(self, function, *args, **kwargs):
|
||||
while self.running:
|
||||
if not self.is_connected:
|
||||
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
|
||||
await self.on_connected.first
|
||||
await self.session_pool.wait_for_fastest_session()
|
||||
try:
|
||||
return await function(*args, **kwargs)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("Wallet server call timed out, retrying.")
|
||||
except ConnectionError:
|
||||
pass
|
||||
raise asyncio.CancelledError() # if we got here, we are shutting down
|
||||
|
||||
def _update_remote_height(self, header_args):
|
||||
self.remote_height = header_args[0]["height"]
|
||||
|
||||
def ensure_server_version(self, required='1.2'):
|
||||
return self.rpc('server.version', [__version__, required])
|
||||
|
||||
def broadcast(self, raw_transaction):
|
||||
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
|
||||
|
||||
def get_history(self, address):
|
||||
return self.rpc('blockchain.address.get_history', [address])
|
||||
|
||||
def get_transaction(self, tx_hash):
|
||||
return self.rpc('blockchain.transaction.get', [tx_hash])
|
||||
|
||||
|
@ -171,84 +238,111 @@ class BaseNetwork:
|
|||
def get_headers(self, height, count=10000):
|
||||
return self.rpc('blockchain.block.headers', [height, count])
|
||||
|
||||
def subscribe_headers(self):
|
||||
return self.rpc('blockchain.headers.subscribe', [True])
|
||||
# --- Subscribes, history and broadcasts are always aimed towards the master client directly
|
||||
def get_history(self, address):
|
||||
return self.rpc('blockchain.address.get_history', [address], session=self.client)
|
||||
|
||||
def subscribe_address(self, address):
|
||||
return self.rpc('blockchain.address.subscribe', [address])
|
||||
def broadcast(self, raw_transaction):
|
||||
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
|
||||
|
||||
def subscribe_headers(self):
|
||||
return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
|
||||
|
||||
async def subscribe_address(self, address):
|
||||
try:
|
||||
return await self.rpc('blockchain.address.subscribe', [address], session=self.client)
|
||||
except asyncio.TimeoutError:
|
||||
# abort and cancel, we cant lose a subscription, it will happen again on reconnect
|
||||
self.client.abort()
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
|
||||
class SessionPool:
|
||||
|
||||
def __init__(self, network: BaseNetwork, timeout: float):
|
||||
self.network = network
|
||||
self.sessions: List[ClientSession] = []
|
||||
self._dead_servers: List[ClientSession] = []
|
||||
self.maintain_connections_task = None
|
||||
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
|
||||
self.timeout = timeout
|
||||
# triggered when the master server is out, to speed up reconnect
|
||||
self._lost_master = asyncio.Event()
|
||||
self.new_connection_event = asyncio.Event()
|
||||
|
||||
@property
|
||||
def online(self):
|
||||
for session in self.sessions:
|
||||
if not session.is_closing():
|
||||
return True
|
||||
return False
|
||||
return any(not session.is_closing() for session in self.sessions)
|
||||
|
||||
@property
|
||||
def available_sessions(self):
|
||||
return [session for session in self.sessions if session.available]
|
||||
|
||||
@property
|
||||
def fastest_session(self):
|
||||
if not self.available_sessions:
|
||||
return None
|
||||
return min(
|
||||
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
|
||||
for session in self.available_sessions],
|
||||
key=itemgetter(0)
|
||||
)[1]
|
||||
|
||||
def _get_session_connect_callback(self, session: ClientSession):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def callback():
|
||||
duplicate_connections = [
|
||||
s for s in self.sessions
|
||||
if s is not session and s.server_address_and_port == session.server_address_and_port
|
||||
]
|
||||
already_connected = None if not duplicate_connections else duplicate_connections[0]
|
||||
if already_connected:
|
||||
self.sessions.pop(session).cancel()
|
||||
session.synchronous_close()
|
||||
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
|
||||
session.server[0], already_connected.server[0])
|
||||
loop.call_later(3600, self._connect_session, session.server)
|
||||
return
|
||||
self.new_connection_event.set()
|
||||
log.info("connected to %s:%i", *session.server)
|
||||
|
||||
return callback
|
||||
|
||||
def _connect_session(self, server: Tuple[str, int]):
|
||||
session = None
|
||||
for s in self.sessions:
|
||||
if s.server == server:
|
||||
session = s
|
||||
break
|
||||
if not session:
|
||||
session = ClientSession(
|
||||
network=self.network, server=server
|
||||
)
|
||||
session._on_connect_cb = self._get_session_connect_callback(session)
|
||||
task = self.sessions.get(session, None)
|
||||
if not task or task.done():
|
||||
task = asyncio.create_task(session.ensure_session())
|
||||
task.add_done_callback(lambda _: self.ensure_connections())
|
||||
self.sessions[session] = task
|
||||
|
||||
def start(self, default_servers):
|
||||
self.sessions = [
|
||||
ClientSession(network=self.network, server=server)
|
||||
for server in default_servers
|
||||
]
|
||||
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
|
||||
for server in default_servers:
|
||||
self._connect_session(server)
|
||||
|
||||
def stop(self):
|
||||
if self.maintain_connections_task:
|
||||
self.maintain_connections_task.cancel()
|
||||
for task in self.sessions.values():
|
||||
task.cancel()
|
||||
self.sessions.clear()
|
||||
|
||||
def ensure_connections(self):
|
||||
for session in self.sessions:
|
||||
if not session.is_closing():
|
||||
session.abort()
|
||||
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
|
||||
self._connect_session(session.server)
|
||||
|
||||
async def ensure_connections(self):
|
||||
while True:
|
||||
await asyncio.gather(*[
|
||||
self.ensure_connection(session)
|
||||
for session in self.sessions
|
||||
], return_exceptions=True)
|
||||
try:
|
||||
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
self._lost_master.clear()
|
||||
if not self.sessions:
|
||||
self.sessions.extend(self._dead_servers)
|
||||
self._dead_servers = []
|
||||
def trigger_nodelay_connect(self):
|
||||
# used when other parts of the system sees we might have internet back
|
||||
# bypasses the retry interval
|
||||
for session in self.sessions:
|
||||
session.trigger_urgent_reconnect.set()
|
||||
|
||||
async def ensure_connection(self, session):
|
||||
self._dead_servers.append(session)
|
||||
self.sessions.remove(session)
|
||||
try:
|
||||
if session.is_closing():
|
||||
await session.create_connection(self.timeout)
|
||||
await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout)
|
||||
self.sessions.append(session)
|
||||
self._dead_servers.remove(session)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("Timeout connecting to %s:%d", *session.server)
|
||||
except asyncio.CancelledError: # pylint: disable=try-except-raise
|
||||
raise
|
||||
except socket.gaierror:
|
||||
log.warning("Could not resolve IP for %s", session.server[0])
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
if 'Connect call failed' in str(err):
|
||||
log.warning("Could not connect to %s:%d", *session.server)
|
||||
else:
|
||||
log.exception("Connecting to %s:%d raised an exception:", *session.server)
|
||||
|
||||
async def get_online_sessions(self):
|
||||
while not self.online:
|
||||
self._lost_master.set()
|
||||
await asyncio.sleep(0.5)
|
||||
return self.sessions
|
||||
async def wait_for_fastest_session(self):
|
||||
while not self.fastest_session:
|
||||
self.trigger_nodelay_connect()
|
||||
self.new_connection_event.clear()
|
||||
await self.new_connection_event.wait()
|
||||
return self.fastest_session
|
||||
|
|
|
@ -33,6 +33,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
|
|||
import itertools
|
||||
import json
|
||||
import typing
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from numbers import Number
|
||||
|
||||
|
@ -745,9 +746,8 @@ class JSONRPCConnection(object):
|
|||
self._protocol = item
|
||||
return self.receive_message(message)
|
||||
|
||||
def cancel_pending_requests(self):
|
||||
"""Cancel all pending requests."""
|
||||
exception = CancelledError()
|
||||
def raise_pending_requests(self, exception):
|
||||
exception = exception or asyncio.TimeoutError()
|
||||
for request, event in self._requests.values():
|
||||
event.result = exception
|
||||
event.set()
|
||||
|
|
|
@ -103,7 +103,7 @@ class SessionBase(asyncio.Protocol):
|
|||
# Force-close a connection if a send doesn't succeed in this time
|
||||
self.max_send_delay = 60
|
||||
# Statistics. The RPC object also keeps its own statistics.
|
||||
self.start_time = time.time()
|
||||
self.start_time = time.perf_counter()
|
||||
self.errors = 0
|
||||
self.send_count = 0
|
||||
self.send_size = 0
|
||||
|
@ -123,7 +123,7 @@ class SessionBase(asyncio.Protocol):
|
|||
# A non-positive value means not to limit concurrency
|
||||
if self.bw_limit <= 0:
|
||||
return
|
||||
now = time.time()
|
||||
now = time.perf_counter()
|
||||
# Reduce the recorded usage in proportion to the elapsed time
|
||||
refund = (now - self.bw_time) * (self.bw_limit / 3600)
|
||||
self.bw_charge = max(0, self.bw_charge - int(refund))
|
||||
|
@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol):
|
|||
await asyncio.wait_for(self._can_send.wait(), secs)
|
||||
except asyncio.TimeoutError:
|
||||
self.abort()
|
||||
raise asyncio.CancelledError(f'task timed out after {secs}s')
|
||||
raise asyncio.TimeoutError(f'task timed out after {secs}s')
|
||||
|
||||
async def _send_message(self, message):
|
||||
if not self._can_send.is_set():
|
||||
|
@ -156,7 +156,7 @@ class SessionBase(asyncio.Protocol):
|
|||
self.send_size += len(framed_message)
|
||||
self._using_bandwidth(len(framed_message))
|
||||
self.send_count += 1
|
||||
self.last_send = time.time()
|
||||
self.last_send = time.perf_counter()
|
||||
if self.verbosity >= 4:
|
||||
self.logger.debug(f'Sending framed message {framed_message}')
|
||||
self.transport.write(framed_message)
|
||||
|
@ -215,7 +215,8 @@ class SessionBase(asyncio.Protocol):
|
|||
self._address = None
|
||||
self.transport = None
|
||||
self._task_group.cancel()
|
||||
self._pm_task.cancel()
|
||||
if self._pm_task:
|
||||
self._pm_task.cancel()
|
||||
# Release waiting tasks
|
||||
self._can_send.set()
|
||||
|
||||
|
@ -253,6 +254,7 @@ class SessionBase(asyncio.Protocol):
|
|||
if self.transport:
|
||||
self.transport.abort()
|
||||
|
||||
# TODO: replace with synchronous_close
|
||||
async def close(self, *, force_after=30):
|
||||
"""Close the connection and return when closed."""
|
||||
self._close()
|
||||
|
@ -262,6 +264,11 @@ class SessionBase(asyncio.Protocol):
|
|||
self.abort()
|
||||
await self._pm_task
|
||||
|
||||
def synchronous_close(self):
|
||||
self._close()
|
||||
if self._pm_task and not self._pm_task.done():
|
||||
self._pm_task.cancel()
|
||||
|
||||
|
||||
class MessageSession(SessionBase):
|
||||
"""Session class for protocols where messages are not tied to responses,
|
||||
|
@ -296,7 +303,7 @@ class MessageSession(SessionBase):
|
|||
)
|
||||
self._bump_errors()
|
||||
else:
|
||||
self.last_recv = time.time()
|
||||
self.last_recv = time.perf_counter()
|
||||
self.recv_count += 1
|
||||
if self.recv_count % 10 == 0:
|
||||
await self._update_concurrency()
|
||||
|
@ -416,7 +423,7 @@ class RPCSession(SessionBase):
|
|||
self.logger.warning(f'{e!r}')
|
||||
continue
|
||||
|
||||
self.last_recv = time.time()
|
||||
self.last_recv = time.perf_counter()
|
||||
self.recv_count += 1
|
||||
if self.recv_count % 10 == 0:
|
||||
await self._update_concurrency()
|
||||
|
@ -456,7 +463,7 @@ class RPCSession(SessionBase):
|
|||
|
||||
def connection_lost(self, exc):
|
||||
# Cancel pending requests and message processing
|
||||
self.connection.cancel_pending_requests()
|
||||
self.connection.raise_pending_requests(exc)
|
||||
super().connection_lost(exc)
|
||||
|
||||
# External API
|
||||
|
@ -473,6 +480,8 @@ class RPCSession(SessionBase):
|
|||
|
||||
async def send_request(self, method, args=()):
|
||||
"""Send an RPC request over the network."""
|
||||
if self.is_closing():
|
||||
raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.")
|
||||
message, event = self.connection.send_request(Request(method, args))
|
||||
await self._send_message(message)
|
||||
await event.wait()
|
||||
|
|
|
@ -258,7 +258,7 @@ class SessionManager:
|
|||
session_timeout = self.env.session_timeout
|
||||
while True:
|
||||
await sleep(session_timeout // 10)
|
||||
stale_cutoff = time.time() - session_timeout
|
||||
stale_cutoff = time.perf_counter() - session_timeout
|
||||
stale_sessions = [session for session in self.sessions
|
||||
if session.last_recv < stale_cutoff]
|
||||
if stale_sessions:
|
||||
|
|
|
@ -45,10 +45,12 @@ class BroadcastSubscription:
|
|||
|
||||
class StreamController:
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, merge_repeated_events=False):
|
||||
self.stream = Stream(self)
|
||||
self._first_subscription = None
|
||||
self._last_subscription = None
|
||||
self._last_event = None
|
||||
self._merge_repeated = merge_repeated_events
|
||||
|
||||
@property
|
||||
def has_listener(self):
|
||||
|
@ -76,8 +78,10 @@ class StreamController:
|
|||
return f
|
||||
|
||||
def add(self, event):
|
||||
skip = self._merge_repeated and event == self._last_event
|
||||
self._last_event = event
|
||||
return self._notify_and_ensure_future(
|
||||
lambda subscription: subscription._add(event)
|
||||
lambda subscription: None if skip else subscription._add(event)
|
||||
)
|
||||
|
||||
def add_error(self, exception):
|
||||
|
@ -141,8 +145,8 @@ class Stream:
|
|||
def first(self):
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
subscription = self.listen(
|
||||
lambda value: self._cancel_and_callback(subscription, future, value),
|
||||
lambda exception: self._cancel_and_error(subscription, future, exception)
|
||||
lambda value: not future.done() and self._cancel_and_callback(subscription, future, value),
|
||||
lambda exception: not future.done() and self._cancel_and_error(subscription, future, exception)
|
||||
)
|
||||
return future
|
||||
|
||||
|
|
Loading…
Reference in a new issue