# Forked from LBRYCommunity/lbry-sdk (374 lines, 15 KiB, Python).
import logging
|
|
import asyncio
|
|
from time import perf_counter
|
|
from operator import itemgetter
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
from lbry import __version__
|
|
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
|
|
from lbry.wallet.stream import StreamController
|
|
|
|
# Module-level logger used by every class in this file.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class ClientSession(BaseClientSession):
    """RPC session to a single wallet server, with timing and load bookkeeping.

    On top of the base RPC session this class keeps:

    * ``response_time``: running average of ``server.version`` round trips,
    * ``connection_latency``: how long the last ``create_connection`` took,
    * ``pending_amount``: number of ``send_request`` calls currently in flight,
    * ``on_disconnected``: a stream that receives ``True`` when the connection is lost.
    """

    def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
        """
        :param network: owner object; its ``PROTOCOL_VERSION`` and
            ``subscription_controllers`` attributes are read by this session.
        :param server: ``(host, port)`` pair, unpacked into log messages and the connector.
        :param timeout: seconds allowed for connecting and for silence on a pending request.
        :param on_connect_callback: zero-argument callable run by ``ensure_session`` after
            a (re)connection and version handshake succeed; defaults to a no-op.
        """
        self.network = network
        self.server = server
        super().__init__(*args, **kwargs)
        self._on_disconnect_controller = StreamController()
        self.on_disconnected = self._on_disconnect_controller.stream
        # NOTE(review): presumably lifts the base session's frame-size / error-count limits -- confirm
        self.framer.max_size = self.max_errors = 1 << 32
        # NOTE(review): -1 presumably disables bandwidth limiting in the base session -- confirm
        self.bw_limit = -1
        self.timeout = timeout
        self.max_seconds_idle = timeout * 2
        self.response_time: Optional[float] = None
        self.connection_latency: Optional[float] = None
        self._response_samples = 0
        self.pending_amount = 0
        self._on_connect_cb = on_connect_callback or (lambda: None)
        self.trigger_urgent_reconnect = asyncio.Event()

    @property
    def available(self):
        """True when the session is not closing and has at least one timed response."""
        return not self.is_closing() and self.response_time is not None

    @property
    def server_address_and_port(self) -> Optional[Tuple[str, int]]:
        """The transport's ``peername`` extra info, or ``None`` when there is no transport."""
        if not self.transport:
            return None
        return self.transport.get_extra_info('peername')

    async def send_timed_server_version_request(self, args=(), timeout=None):
        """Send ``server.version``, time it, and fold the timing into ``response_time``.

        :param args: arguments for the ``server.version`` call.
        :param timeout: seconds to wait; falls back to ``self.timeout`` when falsy.
        :return: the server's reply.
        :raises asyncio.TimeoutError: when no reply arrives within *timeout*.
        """
        timeout = timeout or self.timeout
        log.debug("send version request to %s:%i", *self.server)
        start = perf_counter()
        result = await asyncio.wait_for(
            super().send_request('server.version', args), timeout=timeout
        )
        current_response_time = perf_counter() - start
        # incremental mean: old_mean * n + new_sample, divided by n + 1
        response_sum = (self.response_time or 0) * self._response_samples + current_response_time
        self.response_time = response_sum / (self._response_samples + 1)
        self._response_samples += 1
        return result

    async def send_request(self, method, args=()):
        """Send an RPC request and wait for its result, tracking it in ``pending_amount``.

        ``server.version`` is routed through the timed variant. Any other request keeps
        waiting as long as some packet was received within the last ``self.timeout``
        seconds; otherwise it fails with a timeout.

        :raises RPCError, ProtocolError: logged, then re-raised unchanged.
        :raises ConnectionError: the session is closed synchronously, then re-raised.
        :raises asyncio.CancelledError: the session is closed synchronously, then re-raised.
        :raises asyncio.TimeoutError: when the server has been silent for ``self.timeout``.
        """
        self.pending_amount += 1
        log.debug("send %s to %s:%i", method, *self.server)
        try:
            if method == 'server.version':
                return await self.send_timed_server_version_request(args, self.timeout)
            request = asyncio.ensure_future(super().send_request(method, args))
            while not request.done():
                done, pending = await asyncio.wait([request], timeout=self.timeout)
                if pending:
                    log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received)
                    # the server is still sending us data, so keep waiting for this reply
                    if (perf_counter() - self.last_packet_received) < self.timeout:
                        continue
                    log.info("timeout sending %s to %s:%i", method, *self.server)
                    raise asyncio.TimeoutError
                if done:
                    return request.result()
        except (RPCError, ProtocolError) as e:
            log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
                        *self.server, *e.args)
            raise
        except ConnectionError:
            log.warning("connection to %s:%i lost", *self.server)
            self.synchronous_close()
            raise
        except asyncio.CancelledError:
            log.info("cancelled sending %s to %s:%i", method, *self.server)
            self.synchronous_close()
            raise
        finally:
            # connection_lost() resets the counter to 0 while requests may still be
            # unwinding through here; never let it drop below zero, since the session
            # pool multiplies by (pending_amount + 1) when ranking sessions.
            self.pending_amount = max(0, self.pending_amount - 1)

    async def ensure_session(self):
        """Loop forever keeping this session connected and its timing fresh.

        Each round reconnects when the session is closing, and re-sends the version
        request when idle for more than ``max_seconds_idle`` or when no response time
        is known. The wait between rounds is cut short by ``trigger_urgent_reconnect``.
        """
        # Handles reconnecting and maintaining a session alive
        # TODO: change to 'ping' on newer protocol (above 1.2)
        retry_delay = default_delay = 1.0
        while True:
            try:
                if self.is_closing():
                    await self.create_connection(self.timeout)
                    await self.ensure_server_version()
                    self._on_connect_cb()
                if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
                    await self.ensure_server_version()
                retry_delay = default_delay
            except RPCError as e:
                log.warning("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
                retry_delay = 60 * 60
            except (asyncio.TimeoutError, OSError):
                await self.close()
                # exponential back-off, capped at one minute
                retry_delay = min(60, retry_delay * 2)
                log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
            try:
                await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
            except asyncio.TimeoutError:
                pass
            finally:
                self.trigger_urgent_reconnect.clear()

    async def ensure_server_version(self, required=None, timeout=3):
        """Send ``server.version`` with our version and the required protocol version.

        :param required: protocol version to request; defaults to ``network.PROTOCOL_VERSION``.
        :param timeout: seconds to wait for the reply.
        :raises asyncio.TimeoutError: when the reply does not arrive in time.
        """
        required = required or self.network.PROTOCOL_VERSION
        return await asyncio.wait_for(
            self.send_request('server.version', [__version__, required]), timeout=timeout
        )

    async def create_connection(self, timeout=6):
        """Connect to ``self.server`` and record how long it took in ``connection_latency``.

        :raises asyncio.TimeoutError: when the connection is not made within *timeout* seconds.
        """
        connector = Connector(lambda: self, *self.server)
        start = perf_counter()
        await asyncio.wait_for(connector.create_connection(), timeout=timeout)
        self.connection_latency = perf_counter() - start

    async def handle_request(self, request):
        """Forward a server-initiated request's args to the matching subscription controller.

        :raises KeyError: when ``request.method`` has no controller registered on the network.
        """
        controller = self.network.subscription_controllers[request.method]
        controller.add(request.args)

    def connection_lost(self, exc):
        """Reset timing and load counters, then publish the disconnect on ``on_disconnected``."""
        log.debug("Connection lost: %s:%d", *self.server)
        super().connection_lost(exc)
        self.response_time = None
        self.connection_latency = None
        self._response_samples = 0
        self.pending_amount = 0
        self._on_disconnect_controller.add(True)
|
|
|
|
|
|
class Network:
    """Wallet-server client for one ledger.

    Keeps one "main" ``client`` session (used for subscriptions, history, broadcasts
    and anything about recent blocks) chosen from a ``SessionPool``, and exposes the
    wallet-server RPC calls as thin wrappers around ``rpc``.
    """

    PROTOCOL_VERSION = __version__

    def __init__(self, ledger):
        """
        :param ledger: owner ledger; its ``config`` mapping supplies ``connect_timeout``
            (optional, default 6) and ``default_servers``.
        """
        self.ledger = ledger
        self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
        self.client: Optional[ClientSession] = None
        self._switch_task: Optional[asyncio.Task] = None
        self.running = False
        self.remote_height: int = 0
        # caps how many retriable_call() invocations run at the same time
        self._concurrency = asyncio.Semaphore(16)

        self._on_connected_controller = StreamController()
        self.on_connected = self._on_connected_controller.stream

        self._on_header_controller = StreamController(merge_repeated_events=True)
        self.on_header = self._on_header_controller.stream

        self._on_status_controller = StreamController(merge_repeated_events=True)
        self.on_status = self._on_status_controller.stream

        # looked up by ClientSession.handle_request using the incoming request's method name
        self.subscription_controllers = {
            'blockchain.headers.subscribe': self._on_header_controller,
            'blockchain.address.subscribe': self._on_status_controller,
        }

    @property
    def config(self):
        """The ledger's configuration mapping."""
        return self.ledger.config

    async def switch_forever(self):
        """While running, keep ``self.client`` pointed at a connected session.

        When connected, waits for the client to disconnect. Otherwise takes the pool's
        fastest session, subscribes to headers through it and announces the connection
        on ``on_connected``; on timeout or connection error that session is closed and
        the loop tries again.
        """
        while self.running:
            if self.is_connected:
                await self.client.on_disconnected.first
                self.client = None
                continue
            self.client = await self.session_pool.wait_for_fastest_session()
            log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
            try:
                self._update_remote_height((await self.subscribe_headers(),))
                self._on_connected_controller.add(True)
                log.info("Subscribed to headers: %s:%d", *self.client.server)
            except (asyncio.TimeoutError, ConnectionError):
                log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
                self.client.synchronous_close()
                self.client = None

    async def start(self):
        """Start the switching task and the session pool, and follow header updates."""
        self.running = True
        self._switch_task = asyncio.ensure_future(self.switch_forever())
        # this may become unnecessary when there are no more bugs found,
        # but for now it helps understanding log reports
        self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
        self.session_pool.start(self.config['default_servers'])
        self.on_header.listen(self._update_remote_height)

    async def stop(self):
        """Cancel the switching task and stop the session pool; a no-op when not running."""
        if self.running:
            self.running = False
            self._switch_task.cancel()
            self.session_pool.stop()

    @property
    def is_connected(self):
        """Truthy when there is a main client and it is not closing."""
        return self.client and not self.client.is_closing()

    def rpc(self, list_or_method, args, restricted=True):
        """Start an RPC call and return the awaitable produced by the session.

        :param list_or_method: RPC method name passed to the session's ``send_request``.
        :param args: arguments for the call.
        :param restricted: when true use the main client, otherwise the pool's fastest session.
        :raises ConnectionError: when the chosen session is missing or closing; an
            immediate reconnect of every pooled session is triggered first.
        """
        session = self.client if restricted else self.session_pool.fastest_session
        if session and not session.is_closing():
            return session.send_request(list_or_method, args)
        else:
            self.session_pool.trigger_nodelay_connect()
            raise ConnectionError("Attempting to send rpc request when connection is not available.")

    async def retriable_call(self, function, *args, **kwargs):
        """Await ``function(*args, **kwargs)``, retrying on timeouts and connection errors.

        Waits for a connection before each attempt. At most 16 calls run concurrently.

        :raises asyncio.CancelledError: when the network is no longer running.
        """
        async with self._concurrency:
            while self.running:
                if not self.is_connected:
                    log.warning("Wallet server unavailable, waiting for it to come back and retry.")
                    await self.on_connected.first
                await self.session_pool.wait_for_fastest_session()
                try:
                    return await function(*args, **kwargs)
                except asyncio.TimeoutError:
                    log.warning("Wallet server call timed out, retrying.")
                except ConnectionError:
                    pass
            raise asyncio.CancelledError()  # if we got here, we are shutting down

    def _update_remote_height(self, header_args):
        """Store the ``"height"`` of the first element of *header_args* as ``remote_height``."""
        self.remote_height = header_args[0]["height"]

    def _is_recent(self, height) -> bool:
        """Tell whether *height* lies within the last 10 blocks below ``remote_height``.

        The previous inline form ``0 > height > remote_height - 10`` was a chained
        comparison (``height < 0 and height > remote_height - 10``) and so could never
        match a confirmed height. For negative heights both forms give the same answer.
        """
        return height > self.remote_height - 10

    def get_transaction(self, tx_hash, known_height=None):
        """Fetch a raw transaction; unknown or recent heights go through the main client."""
        # use any server if its old, otherwise restrict to who gave us the history
        restricted = known_height in (None, -1, 0) or self._is_recent(known_height)
        return self.rpc('blockchain.transaction.get', [tx_hash], restricted)

    def get_transaction_height(self, tx_hash, known_height=None):
        """Fetch a transaction's height; unknown or recent heights go through the main client."""
        restricted = not known_height or self._is_recent(known_height)
        return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)

    def get_merkle(self, tx_hash, height):
        """Fetch the merkle proof of a transaction; recent heights go through the main client."""
        restricted = self._is_recent(height)
        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)

    def get_headers(self, height, count=10000, b64=False):
        """Fetch *count* headers from *height*; the last 100 blocks go through the main client."""
        restricted = height >= self.remote_height - 100
        return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)

    # --- Subscribes, history and broadcasts are always aimed towards the master client directly
    def get_history(self, address):
        """Fetch the history of *address* from the main client."""
        return self.rpc('blockchain.address.get_history', [address], True)

    def broadcast(self, raw_transaction):
        """Broadcast *raw_transaction* through the main client."""
        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)

    def subscribe_headers(self):
        """Subscribe to header notifications on the main client."""
        return self.rpc('blockchain.headers.subscribe', [True], True)

    async def subscribe_address(self, address):
        """Subscribe to status notifications for *address* on the main client.

        :raises asyncio.CancelledError: on timeout, after aborting the main client.
        """
        try:
            return await self.rpc('blockchain.address.subscribe', [address], True)
        except asyncio.TimeoutError:
            # abort and cancel, we can't lose a subscription, it will happen again on reconnect
            if self.client:
                self.client.abort()
            raise asyncio.CancelledError()

    def unsubscribe_address(self, address):
        """Cancel the status subscription for *address* on the main client."""
        return self.rpc('blockchain.address.unsubscribe', [address], True)

    def get_server_features(self):
        """Ask the main client's server for its feature description."""
        return self.rpc('server.features', (), restricted=True)

    def get_claims_by_ids(self, claim_ids):
        """Look up claims by id; *claim_ids* is sent as the call's argument list."""
        return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)

    def resolve(self, urls):
        """Resolve URLs; *urls* is sent as the call's argument list."""
        return self.rpc('blockchain.claimtrie.resolve', urls)

    def claim_search(self, **kwargs):
        """Run a claim search; the keyword arguments are sent as the call's arguments."""
        return self.rpc('blockchain.claimtrie.search', kwargs)
|
|
|
|
|
|
class SessionPool:
    """Keeps one ``ClientSession`` (and its keep-alive task) per configured server.

    ``sessions`` maps each session to the task running its ``ensure_session`` loop.
    ``new_connection_event`` is set whenever a session finishes connecting.
    """

    def __init__(self, network: 'Network', timeout: float):
        """
        :param network: owner network, handed to every session the pool creates.
        :param timeout: stored on the pool as ``self.timeout``.
        """
        self.network = network
        self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
        self.timeout = timeout
        self.new_connection_event = asyncio.Event()
        # pending "recheck this duplicate server later" timers; cancelled by stop()
        self._recheck_handles = set()

    @property
    def online(self):
        """True when at least one pooled session is not closing."""
        return any(not session.is_closing() for session in self.sessions)

    @property
    def available_sessions(self):
        """Generator over the pooled sessions whose ``available`` flag is true."""
        return (session for session in self.sessions if session.available)

    @property
    def fastest_session(self):
        """The available session with the lowest load-weighted latency, or ``None``.

        The score is ``(response_time + connection_latency) * (pending_amount + 1)``;
        on a tie the session that comes first in ``sessions`` wins.
        """
        if not self.online:
            return None
        return min(
            self.available_sessions,
            key=lambda s: (s.response_time + s.connection_latency) * (s.pending_amount + 1),
            default=None
        )

    def _schedule_recheck(self, loop, server, delay=3600):
        """Reconnect to *server* after *delay* seconds unless ``stop()`` runs first."""
        def recheck():
            self._recheck_handles.discard(handle)
            self._connect_session(server)

        handle = loop.call_later(delay, recheck)
        self._recheck_handles.add(handle)

    def _get_session_connect_callback(self, session: 'ClientSession'):
        """Build the zero-argument callback *session* runs once it has connected.

        If another pooled session already reports the same peer address, the new
        session is dropped and closed and its server is rechecked in an hour.
        Otherwise ``new_connection_event`` is set.
        """
        loop = asyncio.get_event_loop()

        def callback():
            already_connected = next(
                (s for s in self.sessions
                 if s is not session and s.server_address_and_port == session.server_address_and_port),
                None
            )
            if already_connected:
                self.sessions.pop(session).cancel()
                session.synchronous_close()
                log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
                          session.server[0], already_connected.server[0])
                self._schedule_recheck(loop, session.server)
                return
            self.new_connection_event.set()
            log.info("connected to %s:%i", *session.server)

        return callback

    def _connect_session(self, server: Tuple[str, int]):
        """Make sure a session for *server* exists and has a live keep-alive task."""
        session = next((s for s in self.sessions if s.server == server), None)
        if not session:
            session = ClientSession(
                network=self.network, server=server
            )
            session._on_connect_cb = self._get_session_connect_callback(session)
        task = self.sessions.get(session, None)
        if not task or task.done():
            task = asyncio.create_task(session.ensure_session())
            # when any keep-alive task ends, restart whichever tasks are no longer running
            task.add_done_callback(lambda _: self.ensure_connections())
            self.sessions[session] = task

    def start(self, default_servers):
        """Open a session to every ``(host, port)`` in *default_servers*."""
        for server in default_servers:
            self._connect_session(server)

    def stop(self):
        """Cancel pending recheck timers and every keep-alive task, close and forget all sessions."""
        for handle in self._recheck_handles:
            handle.cancel()
        self._recheck_handles.clear()
        for session, task in self.sessions.items():
            task.cancel()
            session.synchronous_close()
        self.sessions.clear()

    def ensure_connections(self):
        """Restart the keep-alive task of every pooled session whose task has finished."""
        for session in self.sessions:
            self._connect_session(session.server)

    def trigger_nodelay_connect(self):
        """Wake every session's keep-alive loop right away."""
        # used when other parts of the system sees we might have internet back
        # bypasses the retry interval
        for session in self.sessions:
            session.trigger_urgent_reconnect.set()

    async def wait_for_fastest_session(self):
        """Wait until some session is available and return the fastest one."""
        while not self.fastest_session:
            self.trigger_nodelay_connect()
            self.new_connection_event.clear()
            await self.new_connection_event.wait()
        return self.fastest_session
|