This commit is contained in:
Victor Shyba 2019-08-23 21:43:25 -03:00
parent de12e25450
commit 593422f70c
8 changed files with 300 additions and 175 deletions

View file

@ -18,6 +18,9 @@ class MockNetwork:
self.get_transaction_called = []
self.is_connected = False
def retriable_call(self, function, *args, **kwargs):
return function(*args, **kwargs)
async def get_history(self, address):
self.get_history_called.append(address)
self.address = address
@ -121,8 +124,9 @@ class TestSynchronization(LedgerTestCase):
)
class MocHeaderNetwork:
class MocHeaderNetwork(MockNetwork):
def __init__(self, responses):
super().__init__(None, None)
self.responses = responses
async def get_headers(self, height, blocks):

View file

@ -70,8 +70,10 @@ class BaseHeaders:
return True
def __getitem__(self, height) -> dict:
    """Return the deserialized header stored at ``height``.

    Args:
        height: integer block height to look up.

    Returns:
        The result of ``self.deserialize`` applied to the raw header bytes
        returned by ``self.get_raw_header(height)``.

    Raises:
        NotImplementedError: if ``height`` is a ``slice``.
        IndexError: if ``height`` is outside ``0..self.height`` (inclusive).
    """
    # Raise explicitly instead of asserting: an assert disappears under
    # ``python -O`` and, when active, it pre-empted this documented error type.
    if isinstance(height, slice):
        raise NotImplementedError("Slicing of header chain has not been implemented yet.")
    if not 0 <= height <= self.height:
        raise IndexError(f"{height} is out of bounds, current height: {self.height}")
    return self.deserialize(height, self.get_raw_header(height))
def get_raw_header(self, height) -> bytes:

View file

@ -287,7 +287,7 @@ class BaseLedger(metaclass=LedgerRegistry):
subscription_update = False
if not headers:
header_response = await self.network.get_headers(height, 2001)
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
headers = header_response['hex']
if not headers:
@ -394,7 +394,7 @@ class BaseLedger(metaclass=LedgerRegistry):
if local_status == remote_status:
return
remote_history = await self.network.get_history(address)
remote_history = await self.network.retriable_call(self.network.get_history, address)
cache_tasks = []
synced_history = StringIO()
@ -447,6 +447,17 @@ class BaseLedger(metaclass=LedgerRegistry):
if address_manager is not None:
await address_manager.ensure_address_gap()
local_status, local_history = await self.get_local_status_and_history(address)
if local_status != remote_status:
log.debug(
"Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
remote_status, len(remote_history), local_status, len(local_history)
)
log.debug("local: %s", local_history)
log.debug("remote: %s", remote_history)
else:
log.debug("Sync completed for: %s", address)
async def cache_transaction(self, txid, remote_height):
cache_item = self._tx_cache.get(txid)
if cache_item is None:
@ -466,7 +477,7 @@ class BaseLedger(metaclass=LedgerRegistry):
if tx is None:
# fetch from network
_raw = await self.network.get_transaction(txid)
_raw = await self.network.retriable_call(self.network.get_transaction, txid)
if _raw:
tx = self.transaction_class(unhexlify(_raw))
await self.maybe_verify_transaction(tx, remote_height)
@ -486,8 +497,8 @@ class BaseLedger(metaclass=LedgerRegistry):
async def maybe_verify_transaction(self, tx, remote_height):
tx.height = remote_height
if 0 < remote_height <= len(self.headers):
merkle = await self.network.get_merkle(tx.id, remote_height)
if 0 < remote_height < len(self.headers):
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[remote_height]
tx.position = merkle['pos']
@ -501,6 +512,7 @@ class BaseLedger(metaclass=LedgerRegistry):
return None
def broadcast(self, tx):
# Hex-encodes tx.raw, decodes the result to str and returns whatever
# self.network.broadcast() returns for it.
# broadcast can't be a retriable call yet
return self.network.broadcast(hexlify(tx.raw).decode())
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
@ -522,4 +534,4 @@ class BaseLedger(metaclass=LedgerRegistry):
)) for address_record in records
], timeout=timeout)
if pending:
raise TimeoutError('Timed out waiting for transaction.')
raise asyncio.TimeoutError('Timed out waiting for transaction.')

View file

@ -1,9 +1,8 @@
import logging
import asyncio
from asyncio import CancelledError
from time import time
from typing import List
import socket
from operator import itemgetter
from typing import Dict, Optional, Tuple
from time import perf_counter
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
@ -14,71 +13,146 @@ log = logging.getLogger(__name__)
class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, timeout=30, **kwargs):
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
self.network = network
self.server = server
super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
self.framer.max_size = self.max_errors = 1 << 32
self.bw_limit = -1
self.timeout = timeout
self.max_seconds_idle = timeout * 2
self.ping_task = None
self.response_time: Optional[float] = None
self.connection_latency: Optional[float] = None
self._response_samples = 0
self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event()
self._semaphore = asyncio.Semaphore(int(self.timeout))
@property
def available(self):
return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
@property
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
if not self.transport:
return None
return self.transport.get_extra_info('peername')
async def send_timed_server_version_request(self, args=(), timeout=None):
# Sends 'server.version' through the parent class's send_request, measures the
# round trip with perf_counter() and folds it into self.response_time.
# Returns the reply; asyncio.wait_for enforces `timeout`.
# A falsy timeout (None or 0) falls back to the session-wide self.timeout.
timeout = timeout or self.timeout
log.debug("send version request to %s:%i", *self.server)
start = perf_counter()
result = await asyncio.wait_for(
super().send_request('server.version', args), timeout=timeout
)
current_response_time = perf_counter() - start
# Running mean: rebuild the previous total from mean * sample count (a None
# mean counts as 0), add the new sample, then divide by the new sample count.
response_sum = (self.response_time or 0) * self._response_samples + current_response_time
self.response_time = response_sum / (self._response_samples + 1)
self._response_samples += 1
return result
async def send_request(self, method, args=()):
try:
return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout)
except RPCError as e:
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
raise e
except asyncio.TimeoutError:
self.abort()
raise
self.pending_amount += 1
async with self._semaphore:
return await self._send_request(method, args)
async def ping_forever(self):
# TODO: change to 'ping' on newer protocol (above 1.2)
while not self.is_closing():
if (time() - self.last_send) > self.max_seconds_idle:
async def _send_request(self, method, args=()):
log.debug("send %s to %s:%i", method, *self.server)
try:
await self.send_request('server.banner')
except:
self.abort()
if method == 'server.version':
reply = await self.send_timed_server_version_request(args, self.timeout)
else:
reply = await asyncio.wait_for(
super().send_request(method, args), timeout=self.timeout
)
log.debug("got reply for %s from %s:%i", method, *self.server)
return reply
except RPCError as e:
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
*self.server, *e.args)
raise e
except ConnectionError:
log.warning("connection to %s:%i lost", *self.server)
self.synchronous_close()
raise asyncio.CancelledError(f"connection to {self.server[0]}:{self.server[1]} lost")
except asyncio.TimeoutError:
log.info("timeout sending %s to %s:%i", method, *self.server)
raise
await asyncio.sleep(self.max_seconds_idle//3)
except asyncio.CancelledError:
log.info("cancelled sending %s to %s:%i", method, *self.server)
self.synchronous_close()
raise
finally:
self.pending_amount -= 1
async def ensure_session(self):
# Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2)
# retry_delay starts at default_delay (1.0s); the except branch below doubles
# it up to a cap of 60 seconds.
retry_delay = default_delay = 1.0
while True:
try:
# A closing session is reconnected, version-checked, and then reported
# through the connect callback.
if self.is_closing():
await self.create_connection(self.timeout)
await self.ensure_server_version()
self._on_connect_cb()
# Re-run the version request when nothing was sent for longer than
# max_seconds_idle, or when no response time has been recorded yet.
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.ensure_server_version()
retry_delay = default_delay
except (asyncio.TimeoutError, OSError):
await self.close()
retry_delay = min(60, retry_delay * 2)
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
try:
# Wait up to retry_delay seconds; setting trigger_urgent_reconnect ends
# the wait early.
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
except asyncio.TimeoutError:
pass
finally:
# Always re-arm the event so the next wait blocks again.
self.trigger_urgent_reconnect.clear()
async def ensure_server_version(self, required='1.2', timeout=3):
return await asyncio.wait_for(
self.send_request('server.version', [__version__, required]), timeout=timeout
)
async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server)
start = perf_counter()
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
self.ping_task = asyncio.create_task(self.ping_forever())
self.connection_latency = perf_counter() - start
async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method]
controller.add(request.args)
def connection_lost(self, exc):
log.debug("Connection lost: %s:%d", *self.server)
super().connection_lost(exc)
self.response_time = None
self.connection_latency = None
self._response_samples = 0
self.pending_amount = 0
self._on_disconnect_controller.add(True)
if self.ping_task:
self.ping_task.cancel()
class BaseNetwork:
def __init__(self, ledger):
self.config = ledger.config
self.client: ClientSession = None
self.session_pool: SessionPool = None
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None
self.running = False
self.remote_height: int = 0
self._on_connected_controller = StreamController()
self.on_connected = self._on_connected_controller.stream
self._on_header_controller = StreamController()
self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController()
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.subscription_controllers = {
@ -86,79 +160,72 @@ class BaseNetwork:
'blockchain.address.subscribe': self._on_status_controller,
}
async def switch_to_fastest(self):
# Makes the session pool's fastest session the active client (self.client),
# subscribes it to headers, then waits for the pool's next connection event.
try:
# Give the pool up to 30 seconds to produce a usable session.
client = await asyncio.wait_for(self.session_pool.wait_for_fastest_session(), 30)
except asyncio.TimeoutError:
# No session in time: drop the current client, close every pooled
# session, and return.
if self.client:
await self.client.close()
self.client = None
for session in self.session_pool.sessions:
session.synchronous_close()
log.warning("not connected to any wallet servers")
return
# Keep the previous client so it can be restored if subscribing times out.
current_client = self.client
self.client = client
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True)
try:
# subscribe_headers() result is wrapped in a 1-tuple because
# _update_remote_height reads header_args[0]["height"].
self._update_remote_height((await self.subscribe_headers(),))
log.info("Subscribed to headers: %s:%d", *self.client.server)
except asyncio.TimeoutError:
if self.client:
await self.client.close()
self.client = current_client
return
# Clear the event first so the wait below only returns on a connection
# signalled after this point.
self.session_pool.new_connection_event.clear()
return await self.session_pool.new_connection_event.wait()
async def start(self):
self.running = True
connect_timeout = self.config.get('connect_timeout', 6)
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height)
while True:
try:
self.client = await self.pick_fastest_session()
if self.is_connected:
await self.ensure_server_version()
self._update_remote_height((await self.subscribe_headers(),))
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True)
await self.client.on_disconnected.first
except CancelledError:
self.running = False
except asyncio.TimeoutError:
log.warning("Timed out while trying to find a server!")
except Exception: # pylint: disable=broad-except
log.exception("Exception while trying to find a server!")
if not self.running:
return
elif self.client:
await self.client.close()
self.client.connection.cancel_pending_requests()
while self.running:
await self.switch_to_fastest()
async def stop(self):
self.running = False
if self.session_pool:
self.session_pool.stop()
if self.is_connected:
disconnected = self.client.on_disconnected.first
await self.client.close()
await disconnected
@property
def is_connected(self):
    """Return ``True`` only when a client exists and is not closing.

    The explicit ``is not None`` comparison guarantees a real ``bool``;
    a bare ``self.client and ...`` would hand ``None`` back to callers
    when no client has been selected yet.
    """
    return self.client is not None and not self.client.is_closing()
def rpc(self, list_or_method, args, session=None):
    """Send an RPC request over ``session`` or the pool's fastest session.

    Args:
        list_or_method: forwarded unchanged as the first argument of
            ``session.send_request``.
        args: forwarded unchanged as the second argument.
        session: session to use; when falsy, ``self.session_pool.fastest_session``
            is used instead.

    Returns:
        Whatever ``session.send_request(list_or_method, args)`` returns
        (it is not awaited here).

    Raises:
        ConnectionError: when no session is available. The pool is asked to
            reconnect without delay before the error is raised.
    """
    session = session or self.session_pool.fastest_session
    if not session:
        # Nudge every pooled session to retry right away, then fail this call.
        self.session_pool.trigger_nodelay_connect()
        raise ConnectionError("Attempting to send rpc request when connection is not available.")
    return session.send_request(list_or_method, args)
async def pick_fastest_session(self):
sessions = await self.session_pool.get_online_sessions()
done, pending = await asyncio.wait([
self.probe_session(session)
for session in sessions if not session.is_closing()
], return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for session in done:
return await session
async def probe_session(self, session: ClientSession):
await session.send_request('server.banner')
return session
async def retriable_call(self, function, *args, **kwargs):
# Awaits function(*args, **kwargs) and returns its result, retrying for as
# long as self.running is true.
while self.running:
if not self.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
try:
return await function(*args, **kwargs)
except asyncio.TimeoutError:
log.warning("Wallet server call timed out, retrying.")
except ConnectionError:
# Not logged here; the loop simply tries again.
pass
raise asyncio.CancelledError() # if we got here, we are shutting down
def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"]
def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', [__version__, required])
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address])
def get_transaction(self, tx_hash):
return self.rpc('blockchain.transaction.get', [tx_hash])
@ -171,84 +238,111 @@ class BaseNetwork:
def get_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', [height, count])
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True])
# --- Subscribes, history and broadcasts are always aimed towards the master client directly
def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address], session=self.client)
def subscribe_address(self, address):
return self.rpc('blockchain.address.subscribe', [address])
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], session=self.client)
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], session=self.client)
async def subscribe_address(self, address):
# Sends 'blockchain.address.subscribe' for `address` through self.client and
# returns the awaited reply.
try:
return await self.rpc('blockchain.address.subscribe', [address], session=self.client)
except asyncio.TimeoutError:
# abort and cancel, we can't lose a subscription, it will happen again on reconnect
self.client.abort()
raise asyncio.CancelledError()
class SessionPool:
def __init__(self, network: BaseNetwork, timeout: float):
self.network = network
self.sessions: List[ClientSession] = []
self._dead_servers: List[ClientSession] = []
self.maintain_connections_task = None
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout
# triggered when the master server is out, to speed up reconnect
self._lost_master = asyncio.Event()
self.new_connection_event = asyncio.Event()
@property
def online(self):
    """Return ``True`` when at least one pooled session is not closing.

    Iterates ``self.sessions`` and stops at the first session whose
    ``is_closing()`` is false; an empty pool yields ``False``.
    """
    return any(not session.is_closing() for session in self.sessions)
@property
def available_sessions(self):
return [session for session in self.sessions if session.available]
@property
def fastest_session(self):
    """Return the available session with the lowest load-weighted latency.

    Each session is scored as
    ``(response_time + connection_latency) * (pending_amount + 1)`` and the
    lowest score wins; on a tie the earliest session in iteration order is
    returned. Returns ``None`` when no session is available.
    """
    # Read the property once: it rebuilds its list on every access, and a
    # single snapshot keeps the emptiness check and the selection consistent.
    candidates = self.available_sessions
    return min(
        candidates,
        key=lambda s: (s.response_time + s.connection_latency) * (s.pending_amount + 1),
        default=None,
    )
def _get_session_connect_callback(self, session: ClientSession):
# Builds and returns the zero-argument callback that `session` invokes after
# it connects.
loop = asyncio.get_event_loop()
def callback():
# Other pooled sessions that report the same server_address_and_port as
# this one.
duplicate_connections = [
s for s in self.sessions
if s is not session and s.server_address_and_port == session.server_address_and_port
]
already_connected = None if not duplicate_connections else duplicate_connections[0]
if already_connected:
# Remove this session from the pool, cancel the value stored for it,
# close it, and schedule another _connect_session() for its server
# 3600 seconds from now.
self.sessions.pop(session).cancel()
session.synchronous_close()
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
session.server[0], already_connected.server[0])
loop.call_later(3600, self._connect_session, session.server)
return
# No duplicate found: signal waiters on new_connection_event.
self.new_connection_event.set()
log.info("connected to %s:%i", *session.server)
return callback
def _connect_session(self, server: Tuple[str, int]):
# Ensures a session for `server` is present in self.sessions with an
# ensure_session() task stored as its value.
session = None
# Reuse a pooled session whose .server equals `server`, if there is one.
for s in self.sessions:
if s.server == server:
session = s
break
if not session:
session = ClientSession(
network=self.network, server=server
)
# The callback needs the session object itself, so it is attached after
# construction.
session._on_connect_cb = self._get_session_connect_callback(session)
task = self.sessions.get(session, None)
# Start a new task only when none is stored or the stored one has finished.
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
# When the task finishes, ensure_connections() is called.
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def start(self, default_servers):
self.sessions = [
ClientSession(network=self.network, server=server)
for server in default_servers
]
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
for server in default_servers:
self._connect_session(server)
def stop(self):
if self.maintain_connections_task:
self.maintain_connections_task.cancel()
for task in self.sessions.values():
task.cancel()
self.sessions.clear()
def ensure_connections(self):
for session in self.sessions:
if not session.is_closing():
session.abort()
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
self._connect_session(session.server)
async def ensure_connections(self):
while True:
await asyncio.gather(*[
self.ensure_connection(session)
for session in self.sessions
], return_exceptions=True)
try:
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
except asyncio.TimeoutError:
pass
self._lost_master.clear()
if not self.sessions:
self.sessions.extend(self._dead_servers)
self._dead_servers = []
def trigger_nodelay_connect(self):
# used when other parts of the system sees we might have internet back
# bypasses the retry interval
for session in self.sessions:
session.trigger_urgent_reconnect.set()
async def ensure_connection(self, session):
self._dead_servers.append(session)
self.sessions.remove(session)
try:
if session.is_closing():
await session.create_connection(self.timeout)
await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout)
self.sessions.append(session)
self._dead_servers.remove(session)
except asyncio.TimeoutError:
log.warning("Timeout connecting to %s:%d", *session.server)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except socket.gaierror:
log.warning("Could not resolve IP for %s", session.server[0])
except Exception as err: # pylint: disable=broad-except
if 'Connect call failed' in str(err):
log.warning("Could not connect to %s:%d", *session.server)
else:
log.exception("Connecting to %s:%d raised an exception:", *session.server)
async def get_online_sessions(self):
while not self.online:
self._lost_master.set()
await asyncio.sleep(0.5)
return self.sessions
async def wait_for_fastest_session(self):
# Blocks until self.fastest_session is truthy, then returns it.
while not self.fastest_session:
# Ask every session to skip its retry delay, then wait for the next
# new_connection_event before checking again.
self.trigger_nodelay_connect()
self.new_connection_event.clear()
await self.new_connection_event.wait()
return self.fastest_session

View file

@ -33,6 +33,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
import itertools
import json
import typing
import asyncio
from functools import partial
from numbers import Number
@ -745,9 +746,8 @@ class JSONRPCConnection(object):
self._protocol = item
return self.receive_message(message)
def raise_pending_requests(self, exception):
    """Resolve every pending request with ``exception`` and wake its waiter.

    Args:
        exception: object stored as the ``result`` of each pending request's
            event. ``None`` is replaced by a fresh ``asyncio.TimeoutError``.
    """
    # Test for None explicitly so a caller-supplied exception is never swapped
    # out merely because it happens to evaluate as falsy.
    if exception is None:
        exception = asyncio.TimeoutError()
    for _request, event in self._requests.values():
        event.result = exception
        event.set()

def cancel_pending_requests(self):
    """Cancel all pending requests."""
    # Kept for callers of the older name; same mechanism, with a
    # CancelledError as the result.
    self.raise_pending_requests(asyncio.CancelledError())

View file

@ -103,7 +103,7 @@ class SessionBase(asyncio.Protocol):
# Force-close a connection if a send doesn't succeed in this time
self.max_send_delay = 60
# Statistics. The RPC object also keeps its own statistics.
self.start_time = time.time()
self.start_time = time.perf_counter()
self.errors = 0
self.send_count = 0
self.send_size = 0
@ -123,7 +123,7 @@ class SessionBase(asyncio.Protocol):
# A non-positive value means not to limit concurrency
if self.bw_limit <= 0:
return
now = time.time()
now = time.perf_counter()
# Reduce the recorded usage in proportion to the elapsed time
refund = (now - self.bw_time) * (self.bw_limit / 3600)
self.bw_charge = max(0, self.bw_charge - int(refund))
@ -146,7 +146,7 @@ class SessionBase(asyncio.Protocol):
await asyncio.wait_for(self._can_send.wait(), secs)
except asyncio.TimeoutError:
self.abort()
raise asyncio.CancelledError(f'task timed out after {secs}s')
raise asyncio.TimeoutError(f'task timed out after {secs}s')
async def _send_message(self, message):
if not self._can_send.is_set():
@ -156,7 +156,7 @@ class SessionBase(asyncio.Protocol):
self.send_size += len(framed_message)
self._using_bandwidth(len(framed_message))
self.send_count += 1
self.last_send = time.time()
self.last_send = time.perf_counter()
if self.verbosity >= 4:
self.logger.debug(f'Sending framed message {framed_message}')
self.transport.write(framed_message)
@ -215,6 +215,7 @@ class SessionBase(asyncio.Protocol):
self._address = None
self.transport = None
self._task_group.cancel()
if self._pm_task:
self._pm_task.cancel()
# Release waiting tasks
self._can_send.set()
@ -253,6 +254,7 @@ class SessionBase(asyncio.Protocol):
if self.transport:
self.transport.abort()
# TODO: replace with synchronous_close
async def close(self, *, force_after=30):
"""Close the connection and return when closed."""
self._close()
@ -262,6 +264,11 @@ class SessionBase(asyncio.Protocol):
self.abort()
await self._pm_task
def synchronous_close(self):
# Non-coroutine way to close: calls self._close() and, when a _pm_task exists
# and has not finished, cancels it. Nothing is awaited here.
self._close()
if self._pm_task and not self._pm_task.done():
self._pm_task.cancel()
class MessageSession(SessionBase):
"""Session class for protocols where messages are not tied to responses,
@ -296,7 +303,7 @@ class MessageSession(SessionBase):
)
self._bump_errors()
else:
self.last_recv = time.time()
self.last_recv = time.perf_counter()
self.recv_count += 1
if self.recv_count % 10 == 0:
await self._update_concurrency()
@ -416,7 +423,7 @@ class RPCSession(SessionBase):
self.logger.warning(f'{e!r}')
continue
self.last_recv = time.time()
self.last_recv = time.perf_counter()
self.recv_count += 1
if self.recv_count % 10 == 0:
await self._update_concurrency()
@ -456,7 +463,7 @@ class RPCSession(SessionBase):
def connection_lost(self, exc):
# Cancel pending requests and message processing
self.connection.cancel_pending_requests()
self.connection.raise_pending_requests(exc)
super().connection_lost(exc)
# External API
@ -473,6 +480,8 @@ class RPCSession(SessionBase):
async def send_request(self, method, args=()):
"""Send an RPC request over the network."""
if self.is_closing():
raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.")
message, event = self.connection.send_request(Request(method, args))
await self._send_message(message)
await event.wait()

View file

@ -258,7 +258,7 @@ class SessionManager:
session_timeout = self.env.session_timeout
while True:
await sleep(session_timeout // 10)
stale_cutoff = time.time() - session_timeout
stale_cutoff = time.perf_counter() - session_timeout
stale_sessions = [session for session in self.sessions
if session.last_recv < stale_cutoff]
if stale_sessions:

View file

@ -45,10 +45,12 @@ class BroadcastSubscription:
class StreamController:
def __init__(self):
def __init__(self, merge_repeated_events=False):
self.stream = Stream(self)
self._first_subscription = None
self._last_subscription = None
self._last_event = None
self._merge_repeated = merge_repeated_events
@property
def has_listener(self):
@ -76,8 +78,10 @@ class StreamController:
return f
def add(self, event):
skip = self._merge_repeated and event == self._last_event
self._last_event = event
return self._notify_and_ensure_future(
lambda subscription: subscription._add(event)
lambda subscription: None if skip else subscription._add(event)
)
def add_error(self, exception):
@ -141,8 +145,8 @@ class Stream:
def first(self):
future = asyncio.get_event_loop().create_future()
subscription = self.listen(
lambda value: self._cancel_and_callback(subscription, future, value),
lambda exception: self._cancel_and_error(subscription, future, exception)
lambda value: not future.done() and self._cancel_and_callback(subscription, future, value),
lambda exception: not future.done() and self._cancel_and_error(subscription, future, exception)
)
return future