Compare commits

...

5 commits

Author SHA1 Message Date
Jack Robison 55fb76f7f9
fix test 2020-11-20 13:48:04 -05:00
Jack Robison 6d7b94fe1b
pylint 2020-11-20 13:48:04 -05:00
Jack Robison d6cf0e2699
fix 2020-11-20 13:48:04 -05:00
Jack Robison 730a67c8d6
tests 2020-11-20 13:48:04 -05:00
Jack Robison 13b473403e
one spv connection per loaded wallet 2020-11-20 13:48:04 -05:00
9 changed files with 218 additions and 97 deletions

View file

@ -1303,8 +1303,9 @@ class Daemon(metaclass=JSONRPCServerType):
'name': SingleKey.name if single_key else HierarchicalDeterministic.name 'name': SingleKey.name if single_key else HierarchicalDeterministic.name
} }
) )
if self.ledger.network.is_connected: await self.ledger._subscribe_accounts([account])
await self.ledger.subscribe_account(account) elif wallet.accounts:
await self.ledger._subscribe_accounts(wallet.accounts)
wallet.save() wallet.save()
if not skip_on_startup: if not skip_on_startup:
with self.conf.update_config() as c: with self.conf.update_config() as c:
@ -1331,9 +1332,9 @@ class Daemon(metaclass=JSONRPCServerType):
if not os.path.exists(wallet_path): if not os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' was not found.") raise Exception(f"Wallet at path '{wallet_path}' was not found.")
wallet = self.wallet_manager.import_wallet(wallet_path) wallet = self.wallet_manager.import_wallet(wallet_path)
if self.ledger.network.is_connected:
for account in wallet.accounts: if not self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account) await self.ledger._subscribe_accounts(wallet.accounts)
return wallet return wallet
@requires("wallet") @requires("wallet")
@ -1619,7 +1620,7 @@ class Daemon(metaclass=JSONRPCServerType):
} }
) )
wallet.save() wallet.save()
if self.ledger.network.is_connected: if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account) await self.ledger.subscribe_account(account)
return account return account
@ -1647,7 +1648,7 @@ class Daemon(metaclass=JSONRPCServerType):
} }
) )
wallet.save() wallet.save()
if self.ledger.network.is_connected: if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account) await self.ledger.subscribe_account(account)
return account return account
@ -1863,7 +1864,7 @@ class Daemon(metaclass=JSONRPCServerType):
wallet_changed = False wallet_changed = False
if data is not None: if data is not None:
added_accounts = wallet.merge(self.wallet_manager, password, data) added_accounts = wallet.merge(self.wallet_manager, password, data)
if added_accounts and self.ledger.network.is_connected: if added_accounts and self.ledger.network.is_connected(wallet.id):
if blocking: if blocking:
await asyncio.wait([ await asyncio.wait([
a.ledger.subscribe_account(a) for a in added_accounts a.ledger.subscribe_account(a) for a in added_accounts
@ -2957,7 +2958,7 @@ class Daemon(metaclass=JSONRPCServerType):
'public_key': data['holding_public_key'], 'public_key': data['holding_public_key'],
'address_generator': {'name': 'single-address'} 'address_generator': {'name': 'single-address'}
}) })
if self.ledger.network.is_connected: if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account) await self.ledger.subscribe_account(account)
await self.ledger._update_tasks.done.wait() await self.ledger._update_tasks.done.wait()
# Case 3: the holding address has changed and we can't create or find an account for it # Case 3: the holding address has changed and we can't create or find an account for it

View file

@ -122,7 +122,7 @@ class Ledger(metaclass=LedgerRegistry):
self.headers.checkpoints = self.checkpoints self.headers.checkpoints = self.checkpoints
self.network: Network = self.config.get('network') or Network(self) self.network: Network = self.config.get('network') or Network(self)
self.network.on_header.listen(self.receive_header) self.network.on_header.listen(self.receive_header)
self.network.on_status.listen(self.process_status_update) # self.network.on_status.listen(self.process_status_update)
self.network.on_connected.listen(self.join_network) self.network.on_connected.listen(self.join_network)
self.accounts = [] self.accounts = []
@ -324,12 +324,15 @@ class Ledger(metaclass=LedgerRegistry):
self.db.open(), self.db.open(),
self.headers.open() self.headers.open()
]) ])
fully_synced = self.on_ready.first fully_synced = self.on_ready.first
asyncio.create_task(self.network.start()) asyncio.create_task(self.network.start())
await self.network.on_connected.first await self.network.on_connected.first
async with self._header_processing_lock: async with self._header_processing_lock:
await self._update_tasks.add(self.initial_headers_sync()) await self._update_tasks.add(self.initial_headers_sync())
await fully_synced await fully_synced
await self.db.release_all_outputs() await self.db.release_all_outputs()
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts)) await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
await asyncio.gather(*(a.save_max_gap() for a in self.accounts)) await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
@ -448,10 +451,16 @@ class Ledger(metaclass=LedgerRegistry):
) )
async def subscribe_accounts(self): async def subscribe_accounts(self):
if self.network.is_connected and self.accounts: return await self._subscribe_accounts(self.accounts)
log.info("Subscribe to %i accounts", len(self.accounts))
async def _subscribe_accounts(self, accounts):
wallet_ids = {account.wallet.id for account in accounts}
await self.network.connect_wallets(*wallet_ids)
# if self.network.is_connected and self.accounts:
if accounts:
log.info("Subscribe to %i accounts", len(accounts))
await asyncio.wait([ await asyncio.wait([
self.subscribe_account(a) for a in self.accounts self.subscribe_account(a) for a in accounts
]) ])
async def subscribe_account(self, account: Account): async def subscribe_account(self, account: Account):
@ -460,8 +469,10 @@ class Ledger(metaclass=LedgerRegistry):
await account.ensure_address_gap() await account.ensure_address_gap()
async def unsubscribe_account(self, account: Account): async def unsubscribe_account(self, account: Account):
session = self.network.get_wallet_session(account.wallet)
for address in await account.get_addresses(): for address in await account.get_addresses():
await self.network.unsubscribe_address(address) await self.network.unsubscribe_address(address, session=session)
await self.network.close_wallet_session(account.wallet)
async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]): async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]):
await self.subscribe_addresses(address_manager, addresses) await self.subscribe_addresses(address_manager, addresses)
@ -470,28 +481,29 @@ class Ledger(metaclass=LedgerRegistry):
) )
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000): async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
if self.network.is_connected and addresses: if self.network.is_connected(address_manager.account.wallet.id) and addresses:
session = self.network.get_wallet_session(address_manager.account.wallet)
addresses_remaining = list(addresses) addresses_remaining = list(addresses)
while addresses_remaining: while addresses_remaining:
batch = addresses_remaining[:batch_size] batch = addresses_remaining[:batch_size]
results = await self.network.subscribe_address(*batch) results = await self.network.subscribe_address(session, *batch)
for address, remote_status in zip(batch, results): for address, remote_status in zip(batch, results):
self._update_tasks.add(self.update_history(address, remote_status, address_manager)) self._update_tasks.add(self.update_history(session, address, remote_status, address_manager))
addresses_remaining = addresses_remaining[batch_size:] addresses_remaining = addresses_remaining[batch_size:]
if self.network.client and self.network.client.server_address_and_port: if session and session.server_address_and_port:
log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining), log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
len(addresses), *self.network.client.server_address_and_port) len(addresses), *session.server_address_and_port)
if self.network.client and self.network.client.server_address_and_port: if session and session.server_address_and_port:
log.info( log.info(
"finished subscribing to %i addresses on %s:%i", len(addresses), "finished subscribing to %i addresses on %s:%i", len(addresses),
*self.network.client.server_address_and_port *session.server_address_and_port
) )
def process_status_update(self, update): def process_status_update(self, client, update):
address, remote_status = update address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status)) self._update_tasks.add(self.update_history(client, address, remote_status))
async def update_history(self, address, remote_status, address_manager: AddressManager = None, async def update_history(self, client, address, remote_status, address_manager: AddressManager = None,
reattempt_update: bool = True): reattempt_update: bool = True):
async with self._address_update_locks[address]: async with self._address_update_locks[address]:
self._known_addresses_out_of_sync.discard(address) self._known_addresses_out_of_sync.discard(address)
@ -500,7 +512,9 @@ class Ledger(metaclass=LedgerRegistry):
if local_status == remote_status: if local_status == remote_status:
return True return True
remote_history = await self.network.retriable_call(self.network.get_history, address) # remote_history = await self.network.retriable_call(self.network.get_history, address, client)
remote_history = await self.network.get_history(address, session=client)
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
we_need = set(remote_history) - set(local_history) we_need = set(remote_history) - set(local_history)
if not we_need: if not we_need:
@ -563,7 +577,7 @@ class Ledger(metaclass=LedgerRegistry):
"request %i transactions, %i/%i for %s are already synced", len(to_request), len(synced_txs), "request %i transactions, %i/%i for %s are already synced", len(to_request), len(synced_txs),
len(remote_history), address len(remote_history), address
) )
requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address) requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address, client)
for tx in requested_txes: for tx in requested_txes:
pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:" pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:"
synced_txs.append(tx) synced_txs.append(tx)
@ -603,7 +617,7 @@ class Ledger(metaclass=LedgerRegistry):
if self._tx_cache.get(txid) is not cache_item: if self._tx_cache.get(txid) is not cache_item:
log.warning("tx cache corrupted while syncing %s, reattempt sync=%s", address, reattempt_update) log.warning("tx cache corrupted while syncing %s, reattempt sync=%s", address, reattempt_update)
if reattempt_update: if reattempt_update:
return await self.update_history(address, remote_status, address_manager, False) return await self.update_history(client, address, remote_status, address_manager, False)
return False return False
local_status, local_history = \ local_status, local_history = \
@ -741,7 +755,7 @@ class Ledger(metaclass=LedgerRegistry):
await _single_batch(batch) await _single_batch(batch)
return transactions return transactions
async def _request_transaction_batch(self, to_request, remote_history_size, address): async def _request_transaction_batch(self, to_request, remote_history_size, address, session):
header_cache = {} header_cache = {}
batches = [[]] batches = [[]]
remote_heights = {} remote_heights = {}
@ -766,7 +780,8 @@ class Ledger(metaclass=LedgerRegistry):
async def _single_batch(batch): async def _single_batch(batch):
this_batch_synced = [] this_batch_synced = []
batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch) # batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch, session)
batch_result = await self.network.get_transaction_batch(batch, session)
for txid, (raw, merkle) in batch_result.items(): for txid, (raw, merkle) in batch_result.items():
remote_height = remote_heights[txid] remote_height = remote_heights[txid]
merkle_height = merkle['block_height'] merkle_height = merkle['block_height']

View file

@ -6,7 +6,6 @@ from operator import itemgetter
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import partial from functools import partial
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
import aiohttp import aiohttp
from lbry import __version__ from lbry import __version__
@ -14,6 +13,7 @@ from lbry.error import IncompatibleWalletServerError
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
from lbry.wallet.stream import StreamController from lbry.wallet.stream import StreamController
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -33,11 +33,26 @@ class ClientSession(BaseClientSession):
self.pending_amount = 0 self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None) self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event() self.trigger_urgent_reconnect = asyncio.Event()
self._connected = asyncio.Event()
self._disconnected = asyncio.Event()
self._disconnected.set()
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.on_status.listen(partial(self.network.ledger.process_status_update, self))
self.subscription_controllers = {
'blockchain.headers.subscribe': self.network._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
@property @property
def available(self): def available(self):
return not self.is_closing() and self.response_time is not None return not self.is_closing() and self.response_time is not None
@property
def is_connected(self) -> bool:
return self._connected.is_set()
@property @property
def server_address_and_port(self) -> Optional[Tuple[str, int]]: def server_address_and_port(self) -> Optional[Tuple[str, int]]:
if not self.transport: if not self.transport:
@ -98,7 +113,6 @@ class ClientSession(BaseClientSession):
async def ensure_session(self): async def ensure_session(self):
# Handles reconnecting and maintaining a session alive # Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2)
retry_delay = default_delay = 1.0 retry_delay = default_delay = 1.0
while True: while True:
try: try:
@ -144,7 +158,7 @@ class ClientSession(BaseClientSession):
self.connection_latency = perf_counter() - start self.connection_latency = perf_counter() - start
async def handle_request(self, request): async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method] controller = self.subscription_controllers[request.method]
controller.add(request.args) controller.add(request.args)
def connection_lost(self, exc): def connection_lost(self, exc):
@ -154,6 +168,13 @@ class ClientSession(BaseClientSession):
self.connection_latency = None self.connection_latency = None
self._response_samples = 0 self._response_samples = 0
self._on_disconnect_controller.add(True) self._on_disconnect_controller.add(True)
self._connected.clear()
self._disconnected.set()
def connection_made(self, transport):
super(ClientSession, self).connection_made(transport)
self._disconnected.clear()
self._connected.set()
class Network: class Network:
@ -165,6 +186,7 @@ class Network:
self.ledger = ledger self.ledger = ledger
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None self.client: Optional[ClientSession] = None
self.clients: Dict[str, ClientSession] = {}
self.server_features = None self.server_features = None
self._switch_task: Optional[asyncio.Task] = None self._switch_task: Optional[asyncio.Task] = None
self.running = False self.running = False
@ -177,27 +199,34 @@ class Network:
self._on_header_controller = StreamController(merge_repeated_events=True) self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.subscription_controllers = { self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller, 'blockchain.headers.subscribe': self._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
} }
self.aiohttp_session: Optional[aiohttp.ClientSession] = None self.aiohttp_session: Optional[aiohttp.ClientSession] = None
def get_wallet_session(self, wallet):
return self.clients[wallet.id]
async def close_wallet_session(self, wallet):
session = self.clients.pop(wallet.id)
ensure_connect_task = self.session_pool.wallet_session_tasks.pop(session)
if ensure_connect_task and not ensure_connect_task.done():
ensure_connect_task.cancel()
await session.close()
@property @property
def config(self): def config(self):
return self.ledger.config return self.ledger.config
async def switch_forever(self): async def switch_forever(self):
while self.running: while self.running:
if self.is_connected: if self.is_connected():
await self.client.on_disconnected.first await self.client.on_disconnected.first
self.server_features = None self.server_features = None
self.client = None self.client = None
continue continue
self.client = await self.session_pool.wait_for_fastest_session() self.client = await self.session_pool.wait_for_fastest_session()
log.info("Switching to SPV wallet server: %s:%d", *self.client.server) log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
try: try:
@ -220,6 +249,11 @@ class Network:
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
self.session_pool.start(self.config['default_servers']) self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height) self.on_header.listen(self._update_remote_height)
await self.ledger.subscribe_accounts()
async def connect_wallets(self, *wallet_ids):
if wallet_ids:
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])
async def stop(self): async def stop(self):
if self.running: if self.running:
@ -228,9 +262,19 @@ class Network:
self._switch_task.cancel() self._switch_task.cancel()
self.session_pool.stop() self.session_pool.stop()
@property def is_connected(self, wallet_id: str = None):
def is_connected(self): if wallet_id is None:
return self.client and not self.client.is_closing() if not self.client:
return False
return self.client.is_connected
if wallet_id not in self.clients:
return False
return self.clients[wallet_id].is_connected
def connect_wallet_client(self, wallet_id: str, server: Tuple[str, int]):
client = ClientSession(network=self, server=server)
self.clients[wallet_id] = client
return client
def rpc(self, list_or_method, args, restricted=True, session=None): def rpc(self, list_or_method, args, restricted=True, session=None):
session = session or (self.client if restricted else self.session_pool.fastest_session) session = session or (self.client if restricted else self.session_pool.fastest_session)
@ -256,7 +300,8 @@ class Network:
raise asyncio.CancelledError() # if we got here, we are shutting down raise asyncio.CancelledError() # if we got here, we are shutting down
@asynccontextmanager @asynccontextmanager
async def single_call_context(self, function, *args, **kwargs): async def fastest_connection_context(self):
if not self.is_connected: if not self.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.") log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first await self.on_connected.first
@ -264,6 +309,38 @@ class Network:
server = self.session_pool.fastest_session.server server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server) session = ClientSession(network=self, server=server)
async def call_with_reconnect(function, *args, **kwargs):
nonlocal session
while self.running:
if not session.is_connected:
try:
await session.create_connection()
except asyncio.TimeoutError:
if not session.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server)
try:
return await partial(function, *args, session_override=session, **kwargs)(*args, **kwargs)
except asyncio.TimeoutError:
log.warning("'%s' failed, retrying", function.__name__)
try:
yield call_with_reconnect
finally:
await session.close()
@asynccontextmanager
async def single_call_context(self, function, *args, **kwargs):
if not self.is_connected():
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server)
async def call_with_reconnect(*a, **kw): async def call_with_reconnect(*a, **kw):
while self.running: while self.running:
if not session.available: if not session.available:
@ -280,71 +357,71 @@ class Network:
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"] self.remote_height = header_args[0]["height"]
def get_transaction(self, tx_hash, known_height=None): def get_transaction(self, tx_hash, known_height=None, session=None):
# use any server if its old, otherwise restrict to who gave us the history # use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get', [tx_hash], restricted) return self.rpc('blockchain.transaction.get', [tx_hash], restricted, session=session)
def get_transaction_batch(self, txids, restricted=True, session=None): def get_transaction_batch(self, txids, restricted=True, session=None):
# use any server if its old, otherwise restrict to who gave us the history # use any server if its old, otherwise restrict to who gave us the history
return self.rpc('blockchain.transaction.get_batch', txids, restricted, session) return self.rpc('blockchain.transaction.get_batch', txids, restricted, session=session)
def get_transaction_and_merkle(self, tx_hash, known_height=None): def get_transaction_and_merkle(self, tx_hash, known_height=None, session=None):
# use any server if its old, otherwise restrict to who gave us the history # use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.info', [tx_hash], restricted) return self.rpc('blockchain.transaction.info', [tx_hash], restricted, session=session)
def get_transaction_height(self, tx_hash, known_height=None): def get_transaction_height(self, tx_hash, known_height=None, session=None):
restricted = not known_height or 0 > known_height > self.remote_height - 10 restricted = not known_height or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted) return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted, session=session)
def get_merkle(self, tx_hash, height): def get_merkle(self, tx_hash, height, session=None):
restricted = 0 > height > self.remote_height - 10 restricted = 0 > height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted) return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted, session=session)
def get_headers(self, height, count=10000, b64=False): def get_headers(self, height, count=10000, b64=False):
restricted = height >= self.remote_height - 100 restricted = height >= self.remote_height - 100
return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted) return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)
# --- Subscribes, history and broadcasts are always aimed towards the master client directly # --- Subscribes, history and broadcasts are always aimed towards the master client directly
def get_history(self, address): def get_history(self, address, session=None):
return self.rpc('blockchain.address.get_history', [address], True) return self.rpc('blockchain.address.get_history', [address], True, session=session)
def broadcast(self, raw_transaction): def broadcast(self, raw_transaction, session=None):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True) return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True, session=session)
def subscribe_headers(self): def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], True) return self.rpc('blockchain.headers.subscribe', [True], True)
async def subscribe_address(self, address, *addresses): async def subscribe_address(self, session, address, *addresses):
addresses = list((address, ) + addresses) addresses = list((address, ) + addresses)
server_addr_and_port = self.client.server_address_and_port # on disconnect client will be None server_addr_and_port = session.server_address_and_port # on disconnect client will be None
try: try:
return await self.rpc('blockchain.address.subscribe', addresses, True) return await self.rpc('blockchain.address.subscribe', addresses, True, session=session)
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning( log.warning(
"timed out subscribing to addresses from %s:%i", "timed out subscribing to addresses from %s:%i",
*server_addr_and_port *server_addr_and_port
) )
# abort and cancel, we can't lose a subscription, it will happen again on reconnect # abort and cancel, we can't lose a subscription, it will happen again on reconnect
if self.client: if session:
self.client.abort() session.abort()
raise asyncio.CancelledError() raise asyncio.CancelledError()
def unsubscribe_address(self, address): def unsubscribe_address(self, address, session=None):
return self.rpc('blockchain.address.unsubscribe', [address], True) return self.rpc('blockchain.address.unsubscribe', [address], True, session=session)
def get_server_features(self): def get_server_features(self, session=None):
return self.rpc('server.features', (), restricted=True) return self.rpc('server.features', (), restricted=True, session=session)
def get_claims_by_ids(self, claim_ids): def get_claims_by_ids(self, claim_ids):
return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids) return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)
def resolve(self, urls, session_override=None): def resolve(self, urls, session_override=None):
return self.rpc('blockchain.claimtrie.resolve', urls, False, session_override) return self.rpc('blockchain.claimtrie.resolve', urls, False, session=session_override)
def claim_search(self, session_override=None, **kwargs): def claim_search(self, session_override=None, **kwargs):
return self.rpc('blockchain.claimtrie.search', kwargs, False, session_override) return self.rpc('blockchain.claimtrie.search', kwargs, False, session=session_override)
async def new_resolve(self, server, urls): async def new_resolve(self, server, urls):
message = {"method": "resolve", "params": {"urls": urls, "protobuf": True}} message = {"method": "resolve", "params": {"urls": urls, "protobuf": True}}
@ -371,6 +448,7 @@ class SessionPool:
def __init__(self, network: Network, timeout: float): def __init__(self, network: Network, timeout: float):
self.network = network self.network = network
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict() self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.wallet_session_tasks: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout self.timeout = timeout
self.new_connection_event = asyncio.Event() self.new_connection_event = asyncio.Event()
@ -430,6 +508,23 @@ class SessionPool:
task.add_done_callback(lambda _: self.ensure_connections()) task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task self.sessions[session] = task
async def connect_wallet(self, wallet_id: str):
fastest = await self.wait_for_fastest_session()
connected = asyncio.Event()
if not self.network.is_connected(wallet_id):
session = self.network.connect_wallet_client(wallet_id, fastest.server)
session._on_connect_cb = connected.set
# session._on_connect_cb = self._get_session_connect_callback(session)
else:
connected.set()
session = self.network.clients[wallet_id]
task = self.wallet_session_tasks.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
# task.add_done_callback(lambda _: self.ensure_connections())
self.wallet_session_tasks[session] = task
await connected.wait()
def start(self, default_servers): def start(self, default_servers):
for server in default_servers: for server in default_servers:
self._connect_session(server) self._connect_session(server)

View file

@ -1227,10 +1227,7 @@ class LBRYElectrumX(SessionBase):
return await self.address_status(hashX) return await self.address_status(hashX)
async def hashX_unsubscribe(self, hashX, alias): async def hashX_unsubscribe(self, hashX, alias):
try: self.hashX_subs.pop(hashX, None)
del self.hashX_subs[hashX]
except ValueError:
pass
def address_to_hashX(self, address): def address_to_hashX(self, address):
try: try:

View file

@ -75,7 +75,7 @@ class ReconnectTests(IntegrationTestCase):
session.trigger_urgent_reconnect.set() session.trigger_urgent_reconnect.set()
await asyncio.wait_for(self.ledger.network.session_pool.new_connection_event.wait(), timeout=1) await asyncio.wait_for(self.ledger.network.session_pool.new_connection_event.wait(), timeout=1)
self.assertEqual(2, len(list(self.ledger.network.session_pool.available_sessions))) self.assertEqual(2, len(list(self.ledger.network.session_pool.available_sessions)))
self.assertTrue(self.ledger.network.is_connected) self.assertTrue(self.ledger.network.is_connected())
switch_event = self.ledger.network.on_connected.first switch_event = self.ledger.network.on_connected.first
await node2.stop(True) await node2.stop(True)
# secondary down, but primary is ok, do not switch! (switches trigger new on_connected events) # secondary down, but primary is ok, do not switch! (switches trigger new on_connected events)
@ -99,7 +99,7 @@ class ReconnectTests(IntegrationTestCase):
address1 = await self.account.receiving.get_or_create_usable_address() address1 = await self.account.receiving.get_or_create_usable_address()
# disconnect and send a new tx, should reconnect and get it # disconnect and send a new tx, should reconnect and get it
self.ledger.network.client.connection_lost(Exception()) self.ledger.network.client.connection_lost(Exception())
self.assertFalse(self.ledger.network.is_connected) self.assertFalse(self.ledger.network.is_connected())
sendtxid = await self.blockchain.send_to_address(address1, 1.1337) sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0) # mempool await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0) # mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
@ -122,7 +122,7 @@ class ReconnectTests(IntegrationTestCase):
sendtxid = await self.blockchain.send_to_address(address1, 42) sendtxid = await self.blockchain.send_to_address(address1, 42)
await self.blockchain.generate(1) await self.blockchain.generate(1)
# (this is just so the test doesn't hang forever if it doesn't reconnect) # (this is just so the test doesn't hang forever if it doesn't reconnect)
if not self.ledger.network.is_connected: if not self.ledger.network.is_connected():
await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0) await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0)
# omg, the burned cable still works! torba is fire proof! # omg, the burned cable still works! torba is fire proof!
await self.ledger.network.get_transaction(sendtxid) await self.ledger.network.get_transaction(sendtxid)
@ -130,11 +130,11 @@ class ReconnectTests(IntegrationTestCase):
async def test_timeout_then_reconnect(self): async def test_timeout_then_reconnect(self):
# tests that it connects back after some failed attempts # tests that it connects back after some failed attempts
await self.conductor.spv_node.stop() await self.conductor.spv_node.stop()
self.assertFalse(self.ledger.network.is_connected) self.assertFalse(self.ledger.network.is_connected())
await asyncio.sleep(0.2) # let it retry and fail once await asyncio.sleep(0.2) # let it retry and fail once
await self.conductor.spv_node.start(self.conductor.blockchain_node) await self.conductor.spv_node.start(self.conductor.blockchain_node)
await self.ledger.network.on_connected.first await self.ledger.network.on_connected.first
self.assertTrue(self.ledger.network.is_connected) self.assertTrue(self.ledger.network.is_connected())
async def test_online_but_still_unavailable(self): async def test_online_but_still_unavailable(self):
# Edge case. See issue #2445 for context # Edge case. See issue #2445 for context
@ -179,12 +179,18 @@ class ServerPickingTestCase(AsyncioTestCase):
], ],
'connect_timeout': 3 'connect_timeout': 3
}) })
ledger.accounts = []
async def subscribe_accounts():
pass
ledger.subscribe_accounts = subscribe_accounts
network = Network(ledger) network = Network(ledger)
self.addCleanup(network.stop) self.addCleanup(network.stop)
asyncio.ensure_future(network.start()) asyncio.ensure_future(network.start())
await asyncio.wait_for(network.on_connected.first, timeout=1) await asyncio.wait_for(network.on_connected.first, timeout=1)
self.assertTrue(network.is_connected) self.assertTrue(network.is_connected())
self.assertTupleEqual(network.client.server, ('127.0.0.1', 1337)) self.assertTupleEqual(network.client.server, ('127.0.0.1', 1337))
self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions])) self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
# ensure we are connected to all of them after a while # ensure we are connected to all of them after a while

View file

@ -152,8 +152,11 @@ class BasicTransactionTests(IntegrationTestCase):
for batch in range(0, len(sends), 10): for batch in range(0, len(sends), 10):
txids = await asyncio.gather(*sends[batch:batch + 10]) txids = await asyncio.gather(*sends[batch:batch + 10])
await asyncio.wait([self.on_transaction_id(txid) for txid in txids]) await asyncio.wait([self.on_transaction_id(txid) for txid in txids])
remote_status = await self.ledger.network.subscribe_address(address)
self.assertTrue(await self.ledger.update_history(address, remote_status)) client = self.ledger.network.get_wallet_session(self.account.wallet)
remote_status = await self.ledger.network.subscribe_address(client, address)
self.assertTrue(await self.ledger.update_history(client, address, remote_status))
# 20 unconfirmed txs, 10 from blockchain, 10 from local to local # 20 unconfirmed txs, 10 from blockchain, 10 from local to local
utxos = await self.account.get_utxos() utxos = await self.account.get_utxos()
txs = [] txs = []
@ -166,12 +169,12 @@ class BasicTransactionTests(IntegrationTestCase):
await self.broadcast(tx) await self.broadcast(tx)
txs.append(tx) txs.append(tx)
await asyncio.wait([self.on_transaction_address(tx, address) for tx in txs], timeout=1) await asyncio.wait([self.on_transaction_address(tx, address) for tx in txs], timeout=1)
remote_status = await self.ledger.network.subscribe_address(address) remote_status = await self.ledger.network.subscribe_address(client, address)
self.assertTrue(await self.ledger.update_history(address, remote_status)) self.assertTrue(await self.ledger.update_history(client, address, remote_status))
# server history grows unordered # server history grows unordered
txid = await self.blockchain.send_to_address(address, 1) txid = await self.blockchain.send_to_address(address, 1)
await self.on_transaction_id(txid) await self.on_transaction_id(txid)
self.assertTrue(await self.ledger.update_history(address, remote_status)) self.assertTrue(await self.ledger.update_history(client, address, remote_status))
self.assertEqual(21, len((await self.ledger.get_local_status_and_history(address))[1])) self.assertEqual(21, len((await self.ledger.get_local_status_and_history(address))[1]))
self.assertEqual(0, len(self.ledger._known_addresses_out_of_sync)) self.assertEqual(0, len(self.ledger._known_addresses_out_of_sync))

View file

@ -8,16 +8,14 @@ from lbry.wallet.dewies import dict_values_to_lbc
class WalletCommands(CommandTestCase): class WalletCommands(CommandTestCase):
async def test_wallet_create_and_add_subscribe(self): async def test_wallet_create_and_add_subscribe(self):
session = next(iter(self.conductor.spv_node.server.session_mgr.sessions)) self.assertSetEqual({0, 27}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
self.assertEqual(len(session.hashX_subs), 27)
wallet = await self.daemon.jsonrpc_wallet_create('foo', create_account=True, single_key=True) wallet = await self.daemon.jsonrpc_wallet_create('foo', create_account=True, single_key=True)
self.assertEqual(len(session.hashX_subs), 28) self.assertSetEqual({0, 27, 1}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
await self.daemon.jsonrpc_wallet_remove(wallet.id) await self.daemon.jsonrpc_wallet_remove(wallet.id)
self.assertEqual(len(session.hashX_subs), 27) self.assertSetEqual({0, 27}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
await self.daemon.jsonrpc_wallet_add(wallet.id) await self.daemon.jsonrpc_wallet_add(wallet.id)
self.assertEqual(len(session.hashX_subs), 28) self.assertSetEqual({0, 27, 1}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
async def test_wallet_syncing_status(self): async def test_wallet_syncing_status(self):
address = await self.daemon.jsonrpc_address_unused() address = await self.daemon.jsonrpc_address_unused()

View file

@ -4,24 +4,30 @@ import lbry
import lbry.wallet import lbry.wallet
from lbry.error import ServerPaymentFeeAboveMaxAllowedError from lbry.error import ServerPaymentFeeAboveMaxAllowedError
from lbry.wallet.network import ClientSession from lbry.wallet.network import ClientSession
from lbry.testcase import IntegrationTestCase, CommandTestCase from lbry.testcase import CommandTestCase
from lbry.wallet.orchstr8.node import SPVNode from lbry.wallet.orchstr8.node import SPVNode
class TestSessions(IntegrationTestCase): class MockNetwork:
def __init__(self, ledger):
self.ledger = ledger
self._on_header_controller = None
class TestSessions(CommandTestCase):
""" """
Tests that server cleans up stale connections after session timeout and client times out too. Tests that server cleans up stale connections after session timeout and client times out too.
""" """
LEDGER = lbry.wallet
async def test_session_bloat_from_socket_timeout(self): async def test_session_bloat_from_socket_timeout(self):
await self.conductor.stop_spv() await self.conductor.stop_spv()
await self.ledger.stop() await self.ledger.stop()
self.conductor.spv_node.session_timeout = 1 self.conductor.spv_node.session_timeout = 1
await self.conductor.start_spv() await self.conductor.start_spv()
session = ClientSession( session = ClientSession(
network=None, server=(self.conductor.spv_node.hostname, self.conductor.spv_node.port), timeout=0.2 network=MockNetwork(self.ledger), server=(self.conductor.spv_node.hostname, self.conductor.spv_node.port),
timeout=0.2
) )
await session.create_connection() await session.create_connection()
await session.send_request('server.banner', ()) await session.send_request('server.banner', ())

View file

@ -16,12 +16,12 @@ class MockNetwork:
self.address = None self.address = None
self.get_history_called = [] self.get_history_called = []
self.get_transaction_called = [] self.get_transaction_called = []
self.is_connected = False self.is_connected = lambda _: False
def retriable_call(self, function, *args, **kwargs): def retriable_call(self, function, *args, **kwargs):
return function(*args, **kwargs) return function(*args, **kwargs)
async def get_history(self, address): async def get_history(self, address, session=None):
self.get_history_called.append(address) self.get_history_called.append(address)
self.address = address self.address = address
return self.history return self.history
@ -40,7 +40,7 @@ class MockNetwork:
merkle = await self.get_merkle(tx_hash, known_height) merkle = await self.get_merkle(tx_hash, known_height)
return tx, merkle return tx, merkle
async def get_transaction_batch(self, txids): async def get_transaction_batch(self, txids, session=None):
return { return {
txid: await self.get_transaction_and_merkle(txid) txid: await self.get_transaction_and_merkle(txid)
for txid in txids for txid in txids
@ -111,7 +111,7 @@ class TestSynchronization(LedgerTestCase):
txid2: hexlify(get_transaction(get_output(2)).raw), txid2: hexlify(get_transaction(get_output(2)).raw),
txid3: hexlify(get_transaction(get_output(3)).raw), txid3: hexlify(get_transaction(get_output(3)).raw),
}) })
await self.ledger.update_history(address, '') await self.ledger.update_history(None, address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address]) self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, [txid1, txid2, txid3]) self.assertListEqual(self.ledger.network.get_transaction_called, [txid1, txid2, txid3])
@ -129,7 +129,7 @@ class TestSynchronization(LedgerTestCase):
self.assertFalse(self.ledger._tx_cache[txid1].tx.is_verified) self.assertFalse(self.ledger._tx_cache[txid1].tx.is_verified)
self.assertFalse(self.ledger._tx_cache[txid2].tx.is_verified) self.assertFalse(self.ledger._tx_cache[txid2].tx.is_verified)
self.assertFalse(self.ledger._tx_cache[txid3].tx.is_verified) self.assertFalse(self.ledger._tx_cache[txid3].tx.is_verified)
await self.ledger.update_history(address, '') await self.ledger.update_history(None, address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address]) self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, []) self.assertListEqual(self.ledger.network.get_transaction_called, [])
@ -137,7 +137,7 @@ class TestSynchronization(LedgerTestCase):
self.ledger.network.transaction[txid4] = hexlify(get_transaction(get_output(4)).raw) self.ledger.network.transaction[txid4] = hexlify(get_transaction(get_output(4)).raw)
self.ledger.network.get_history_called = [] self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = [] self.ledger.network.get_transaction_called = []
await self.ledger.update_history(address, '') await self.ledger.update_history(None, address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address]) self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, [txid4]) self.assertListEqual(self.ledger.network.get_transaction_called, [txid4])
address_details = await self.ledger.db.get_address(address=address) address_details = await self.ledger.db.get_address(address=address)