lbry-sdk/lbry/wallet/network.py
2020-01-10 13:57:52 -05:00

378 lines
16 KiB
Python

import logging
import asyncio
from time import perf_counter
from operator import itemgetter
from typing import Dict, Optional, Tuple
from lbry import __version__
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
from lbry.wallet.stream import StreamController
# Module-level logger used by every class below to report connection/session events.
log = logging.getLogger(__name__)
class ClientSession(BaseClientSession):
    """RPC session to a single SPV wallet server.

    Besides plain request/response, the session keeps a running average of
    ``server.version`` round-trip times (``response_time``) and the TCP connect
    latency (``connection_latency``), which ``SessionPool`` uses to rank sessions.
    :meth:`ensure_session` keeps the session connected.
    """

    def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
        self.network = network
        self.server = server  # (host, port) this session connects to; used in log lines
        super().__init__(*args, **kwargs)
        self._on_disconnect_controller = StreamController()
        self.on_disconnected = self._on_disconnect_controller.stream
        # NOTE(review): sets both base-class limits to 2**32, presumably to make them
        # effectively unlimited -- confirm semantics in lbry.wallet.rpc
        self.framer.max_size = self.max_errors = 1 << 32
        self.timeout = timeout
        self.max_seconds_idle = timeout * 2
        self.response_time: Optional[float] = None
        self.connection_latency: Optional[float] = None
        self._response_samples = 0
        self.pending_amount = 0  # requests currently in flight (weights the pool's ranking)
        self._on_connect_cb = on_connect_callback or (lambda: None)
        self.trigger_urgent_reconnect = asyncio.Event()

    @property
    def available(self):
        """True when the session is open and has a measured response time."""
        return not self.is_closing() and self.response_time is not None

    @property
    def server_address_and_port(self) -> Optional[Tuple[str, int]]:
        """Peer address reported by the transport, or None when there is no transport."""
        if not self.transport:
            return None
        return self.transport.get_extra_info('peername')

    async def send_timed_server_version_request(self, args=(), timeout=None):
        """Send ``server.version`` and fold its round-trip time into ``response_time``.

        Raises asyncio.TimeoutError if no answer arrives within ``timeout``
        (defaults to ``self.timeout``).
        """
        timeout = timeout or self.timeout
        log.debug("send version request to %s:%i", *self.server)
        start = perf_counter()
        result = await asyncio.wait_for(
            super().send_request('server.version', args), timeout=timeout
        )
        current_response_time = perf_counter() - start
        # incremental mean over all samples since the last connection_lost()
        response_sum = (self.response_time or 0) * self._response_samples + current_response_time
        self.response_time = response_sum / (self._response_samples + 1)
        self._response_samples += 1
        return result

    async def send_request(self, method, args=()):
        """Send ``method`` with ``args`` and return the server's result.

        A request is only abandoned once no packet at all has been received from
        the server for ``self.timeout`` seconds, so slow-but-alive servers are
        waited on.

        Raises:
            asyncio.TimeoutError: the server was silent for ``self.timeout`` seconds.
            RPCError, ProtocolError: the server returned an error (logged, re-raised).
            ConnectionError: the connection was lost; the session is closed.
            asyncio.CancelledError: the caller was cancelled; the session is closed.
        """
        self.pending_amount += 1
        log.debug("send %s to %s:%i", method, *self.server)
        request = None
        try:
            if method == 'server.version':
                return await self.send_timed_server_version_request(args, self.timeout)
            request = asyncio.ensure_future(super().send_request(method, args))
            while not request.done():
                done, pending = await asyncio.wait([request], timeout=self.timeout)
                if pending:
                    log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received)
                    # the server is still sending data, keep waiting for our answer
                    if (perf_counter() - self.last_packet_received) < self.timeout:
                        continue
                    log.info("timeout sending %s to %s:%i", method, *self.server)
                    raise asyncio.TimeoutError
                if done:
                    return request.result()
            return request.result()
        except (RPCError, ProtocolError) as e:
            log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
                        *self.server, *e.args)
            raise e
        except ConnectionError:
            log.warning("connection to %s:%i lost", *self.server)
            self.synchronous_close()
            raise
        except asyncio.CancelledError:
            log.info("cancelled sending %s to %s:%i", method, *self.server)
            self.synchronous_close()
            raise
        finally:
            # asyncio.wait() never cancels what it waits on: without this a timed-out
            # or cancelled call would orphan the inner request task (and any exception
            # it raised later would never be retrieved).
            if request is not None and not request.done():
                request.cancel()
            self.pending_amount -= 1

    async def ensure_session(self):
        """Keep this session connected; loops until the task is cancelled.

        Reconnects (and re-runs the version handshake plus the on-connect callback)
        when the transport is closing. Re-checks ``server.version`` when idle for
        more than ``max_seconds_idle`` or when no response time is known. On server
        errors the session is closed and retried after an hour. On timeouts/OS errors
        it is closed and retried with a doubling delay capped at 60 seconds. Setting
        ``trigger_urgent_reconnect`` cuts any wait short.
        """
        # TODO: change to 'ping' on newer protocol (above 1.2)
        retry_delay = default_delay = 1.0
        while True:
            try:
                if self.is_closing():
                    await self.create_connection(self.timeout)
                    await self.ensure_server_version()
                    self._on_connect_cb()
                if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
                    await self.ensure_server_version()
                retry_delay = default_delay
            except RPCError as e:
                log.warning("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
                # close it, otherwise the session stays open (and, with an earlier
                # response_time, still "available") while we claim to ignore it
                await self.close()
                retry_delay = 60 * 60
            except (asyncio.TimeoutError, OSError):
                await self.close()
                retry_delay = min(60, retry_delay * 2)
                log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
            try:
                await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
            except asyncio.TimeoutError:
                pass
            finally:
                self.trigger_urgent_reconnect.clear()

    async def ensure_server_version(self, required=None, timeout=3):
        """Send ``server.version`` advertising ``required`` (default: network protocol version)."""
        required = required or self.network.PROTOCOL_VERSION
        return await asyncio.wait_for(
            self.send_request('server.version', [__version__, required]), timeout=timeout
        )

    async def create_connection(self, timeout=6):
        """Open the transport to ``self.server`` and record how long it took."""
        connector = Connector(lambda: self, *self.server)
        start = perf_counter()
        await asyncio.wait_for(connector.create_connection(), timeout=timeout)
        self.connection_latency = perf_counter() - start

    async def handle_request(self, request):
        """Forward a server notification to the controller registered for its method.

        Raises KeyError for methods missing from ``network.subscription_controllers``.
        """
        controller = self.network.subscription_controllers[request.method]
        controller.add(request.args)

    def connection_lost(self, exc):
        """Reset timing/pending statistics and emit on ``on_disconnected``."""
        log.debug("Connection lost: %s:%d", *self.server)
        super().connection_lost(exc)
        self.response_time = None
        self.connection_latency = None
        self._response_samples = 0
        self.pending_amount = 0
        self._on_disconnect_controller.add(True)
class Network:
    """Connects a ledger to SPV wallet servers.

    A single "master" ``client`` session, chosen from ``session_pool`` by
    :meth:`switch_forever`, receives subscriptions, history lookups and
    broadcasts. Other read-only queries may be routed to the pool's fastest
    session instead.
    """

    PROTOCOL_VERSION = __version__

    def __init__(self, ledger):
        self.ledger = ledger
        self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
        self.client: Optional[ClientSession] = None
        self._switch_task: Optional[asyncio.Task] = None
        self.running = False
        self.remote_height: int = 0
        # caps the number of concurrent retriable_call() invocations
        self._concurrency = asyncio.Semaphore(16)

        self._on_connected_controller = StreamController()
        self.on_connected = self._on_connected_controller.stream

        self._on_header_controller = StreamController(merge_repeated_events=True)
        self.on_header = self._on_header_controller.stream

        self._on_status_controller = StreamController(merge_repeated_events=True)
        self.on_status = self._on_status_controller.stream

        # server push notifications are dispatched by method name (ClientSession.handle_request)
        self.subscription_controllers = {
            'blockchain.headers.subscribe': self._on_header_controller,
            'blockchain.address.subscribe': self._on_status_controller,
        }

    @property
    def config(self):
        return self.ledger.config

    async def switch_forever(self):
        """While running, keep ``client`` pointed at a connected, header-subscribed session."""
        while self.running:
            if self.is_connected:
                await self.client.on_disconnected.first
                self.client = None
                continue
            self.client = await self.session_pool.wait_for_fastest_session()
            log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
            try:
                self._update_remote_height((await self.subscribe_headers(),))
                self._on_connected_controller.add(True)
                log.info("Subscribed to headers: %s:%d", *self.client.server)
            except (asyncio.TimeoutError, ConnectionError):
                log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
                self.client.synchronous_close()
                self.client = None

    async def start(self):
        self.running = True
        self._switch_task = asyncio.ensure_future(self.switch_forever())
        # this may become unnecessary when there are no more bugs found,
        # but for now it helps understanding log reports
        self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
        self.session_pool.start(self.config['default_servers'])
        self.on_header.listen(self._update_remote_height)

    async def stop(self):
        if self.running:
            self.running = False
            self._switch_task.cancel()
            self.session_pool.stop()

    @property
    def is_connected(self):
        """True when there is a master client whose connection is not closing."""
        return bool(self.client) and not self.client.is_closing()

    def rpc(self, list_or_method, args, restricted=True):
        """Send a request to the master client (``restricted``) or the fastest pooled session.

        Raises ConnectionError (after nudging the pool to reconnect) when the
        chosen session is missing or closing.
        """
        session = self.client if restricted else self.session_pool.fastest_session
        if session and not session.is_closing():
            return session.send_request(list_or_method, args)
        else:
            self.session_pool.trigger_nodelay_connect()
            raise ConnectionError("Attempting to send rpc request when connection is not available.")

    async def retriable_call(self, function, *args, **kwargs):
        """Await ``function(*args, **kwargs)``, retrying on timeouts/connection errors.

        Waits for a connection before each attempt. Raises asyncio.CancelledError
        once the network is no longer running.
        """
        async with self._concurrency:
            while self.running:
                if not self.is_connected:
                    log.warning("Wallet server unavailable, waiting for it to come back and retry.")
                    await self.on_connected.first
                await self.session_pool.wait_for_fastest_session()
                try:
                    return await function(*args, **kwargs)
                except asyncio.TimeoutError:
                    log.warning("Wallet server call timed out, retrying.")
                except ConnectionError:
                    pass
        raise asyncio.CancelledError()  # if we got here, we are shutting down

    def _update_remote_height(self, header_args):
        self.remote_height = header_args[0]["height"]

    def _is_recent_height(self, height) -> bool:
        """True for a confirmed (positive) height within 10 blocks of ``remote_height``.

        Old data may come from any server, but data tied to recent blocks must be
        asked of the client that gave us the history. The previous check
        ``0 > height > remote_height - 10`` could only be true for negative heights.
        """
        return height is not None and height > 0 and height > self.remote_height - 10

    def get_transaction(self, tx_hash, known_height=None):
        # use any server if its old, otherwise restrict to who gave us the history
        restricted = known_height in (None, -1, 0) or self._is_recent_height(known_height)
        return self.rpc('blockchain.transaction.get', [tx_hash], restricted)

    def get_transaction_height(self, tx_hash, known_height=None):
        restricted = not known_height or self._is_recent_height(known_height)
        return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)

    def get_merkle(self, tx_hash, height):
        restricted = self._is_recent_height(height)
        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)

    def get_headers(self, height, count=10000, b64=False):
        # headers near the tip (last 100 blocks) only from the master client
        restricted = height >= self.remote_height - 100
        return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)

    # --- Subscribes, history and broadcasts are always aimed towards the master client directly
    def get_history(self, address):
        return self.rpc('blockchain.address.get_history', [address], True)

    def broadcast(self, raw_transaction):
        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)

    def subscribe_headers(self):
        return self.rpc('blockchain.headers.subscribe', [True], True)

    async def subscribe_address(self, address, *addresses):
        """Subscribe to status updates for one or more addresses on the master client.

        On timeout the client that timed out is aborted and asyncio.CancelledError
        is raised: a lost subscription is redone on reconnect.
        """
        addresses = list((address, ) + addresses)
        # capture now: switch_forever() may reset self.client to None while we await
        client = self.client
        try:
            return await self.rpc('blockchain.address.subscribe', addresses, True)
        except asyncio.TimeoutError:
            if client is not None:
                log.warning(
                    "timed out subscribing to addresses from %s:%i",
                    *client.server
                )
                # abort and cancel, we can't lose a subscription, it will happen again on reconnect
                client.abort()
            raise asyncio.CancelledError()

    def unsubscribe_address(self, address):
        return self.rpc('blockchain.address.unsubscribe', [address], True)

    def get_server_features(self):
        return self.rpc('server.features', (), restricted=True)

    def get_claims_by_ids(self, claim_ids):
        return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)

    def resolve(self, urls):
        return self.rpc('blockchain.claimtrie.resolve', urls)

    def claim_search(self, **kwargs):
        return self.rpc('blockchain.claimtrie.search', kwargs)
class SessionPool:
    """Pool of ``ClientSession``s, one per server, each kept alive by its own task."""

    def __init__(self, network: Network, timeout: float):
        self.network = network
        self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
        self.timeout = timeout
        self.new_connection_event = asyncio.Event()
        # timer handles from _reconnect_later(); stop() cancels them so a stopped
        # pool does not silently reconnect an hour later
        self._delayed_reconnects = set()

    @property
    def online(self):
        """True if at least one session is not closing."""
        return any(not session.is_closing() for session in self.sessions)

    @property
    def available_sessions(self):
        """Generator over sessions that are open and have a measured response time."""
        return (session for session in self.sessions if session.available)

    @property
    def fastest_session(self):
        """Best available session, or None if there is none.

        Score = ``(response_time + connection_latency) * (pending_amount + 1)``;
        the lowest score wins, so sessions with requests in flight are penalized.
        """
        if not self.online:
            return None
        return min(
            [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
             for session in self.available_sessions] or [(0, None)],
            key=itemgetter(0)
        )[1]

    def _get_session_connect_callback(self, session: ClientSession):
        """Build the on-connect callback for ``session``.

        If another pooled session is already connected to the same peer address,
        ``session`` is dropped and retried in an hour. Otherwise
        ``new_connection_event`` is set.
        """
        def callback():
            already_connected = next(
                (s for s in self.sessions
                 if s is not session and s.server_address_and_port == session.server_address_and_port),
                None
            )
            if already_connected:
                task = self.sessions.pop(session, None)
                if task is not None:
                    task.cancel()
                session.synchronous_close()
                log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
                          session.server[0], already_connected.server[0])
                self._reconnect_later(session.server, 3600)
                return
            self.new_connection_event.set()
            log.info("connected to %s:%i", *session.server)

        return callback

    def _reconnect_later(self, server: Tuple[str, int], delay: float):
        """Schedule ``_connect_session(server)`` in ``delay`` seconds; cancellable by stop()."""
        loop = asyncio.get_event_loop()
        handle = None

        def reconnect():
            self._delayed_reconnects.discard(handle)
            self._connect_session(server)

        handle = loop.call_later(delay, reconnect)
        self._delayed_reconnects.add(handle)

    def _connect_session(self, server: Tuple[str, int]):
        """Make sure ``server`` has a session with a running ``ensure_session`` task."""
        session = next((s for s in self.sessions if s.server == server), None)
        if not session:
            session = ClientSession(
                network=self.network, server=server
            )
            session._on_connect_cb = self._get_session_connect_callback(session)
        task = self.sessions.get(session, None)
        if not task or task.done():
            task = asyncio.create_task(session.ensure_session())
            # whenever a keep-alive task ends, restart every finished one
            task.add_done_callback(lambda _: self.ensure_connections())
        self.sessions[session] = task

    def start(self, default_servers):
        for server in default_servers:
            self._connect_session(server)

    def stop(self):
        """Cancel pending delayed reconnects and all session tasks, close all sessions."""
        for handle in self._delayed_reconnects:
            handle.cancel()
        self._delayed_reconnects.clear()
        for session, task in self.sessions.items():
            task.cancel()
            session.synchronous_close()
        self.sessions.clear()

    def ensure_connections(self):
        for session in self.sessions:
            self._connect_session(session.server)

    def trigger_nodelay_connect(self):
        # used when other parts of the system sees we might have internet back
        # bypasses the retry interval
        for session in self.sessions:
            session.trigger_urgent_reconnect.set()

    async def wait_for_fastest_session(self):
        """Wait until some session is available and return the fastest one."""
        while not self.fastest_session:
            self.trigger_nodelay_connect()
            self.new_connection_event.clear()
            await self.new_connection_event.wait()
        return self.fastest_session