Compare commits

5 commits: master ... one-sessio

| SHA1 |
|---|
| 55fb76f7f9 |
| 6d7b94fe1b |
| d6cf0e2699 |
| 730a67c8d6 |
| 13b473403e |

9 changed files with 218 additions and 97 deletions
@@ -1303,8 +1303,9 @@ class Daemon(metaclass=JSONRPCServerType):
                     'name': SingleKey.name if single_key else HierarchicalDeterministic.name
                 }
             )
-            if self.ledger.network.is_connected:
-                await self.ledger.subscribe_account(account)
+            await self.ledger._subscribe_accounts([account])
+        elif wallet.accounts:
+            await self.ledger._subscribe_accounts(wallet.accounts)
         wallet.save()
         if not skip_on_startup:
             with self.conf.update_config() as c:
@@ -1331,9 +1332,9 @@ class Daemon(metaclass=JSONRPCServerType):
         if not os.path.exists(wallet_path):
             raise Exception(f"Wallet at path '{wallet_path}' was not found.")
         wallet = self.wallet_manager.import_wallet(wallet_path)
-        if self.ledger.network.is_connected:
-            for account in wallet.accounts:
-                await self.ledger.subscribe_account(account)
+        if not self.ledger.network.is_connected(wallet.id):
+            await self.ledger._subscribe_accounts(wallet.accounts)
         return wallet

     @requires("wallet")
@@ -1619,7 +1620,7 @@ class Daemon(metaclass=JSONRPCServerType):
             }
         )
         wallet.save()
-        if self.ledger.network.is_connected:
+        if self.ledger.network.is_connected(wallet.id):
             await self.ledger.subscribe_account(account)
         return account

@@ -1647,7 +1648,7 @@ class Daemon(metaclass=JSONRPCServerType):
             }
         )
         wallet.save()
-        if self.ledger.network.is_connected:
+        if self.ledger.network.is_connected(wallet.id):
             await self.ledger.subscribe_account(account)
         return account

@@ -1863,7 +1864,7 @@ class Daemon(metaclass=JSONRPCServerType):
         wallet_changed = False
         if data is not None:
             added_accounts = wallet.merge(self.wallet_manager, password, data)
-            if added_accounts and self.ledger.network.is_connected:
+            if added_accounts and self.ledger.network.is_connected(wallet.id):
                 if blocking:
                     await asyncio.wait([
                         a.ledger.subscribe_account(a) for a in added_accounts
@@ -2957,7 +2958,7 @@ class Daemon(metaclass=JSONRPCServerType):
             'public_key': data['holding_public_key'],
             'address_generator': {'name': 'single-address'}
         })
-        if self.ledger.network.is_connected:
+        if self.ledger.network.is_connected(wallet.id):
             await self.ledger.subscribe_account(account)
             await self.ledger._update_tasks.done.wait()
         # Case 3: the holding address has changed and we can't create or find an account for it

@@ -122,7 +122,7 @@ class Ledger(metaclass=LedgerRegistry):
         self.headers.checkpoints = self.checkpoints
         self.network: Network = self.config.get('network') or Network(self)
         self.network.on_header.listen(self.receive_header)
-        self.network.on_status.listen(self.process_status_update)
+        # self.network.on_status.listen(self.process_status_update)
         self.network.on_connected.listen(self.join_network)

         self.accounts = []
@@ -324,12 +324,15 @@ class Ledger(metaclass=LedgerRegistry):
             self.db.open(),
             self.headers.open()
         ])
+
         fully_synced = self.on_ready.first
+
         asyncio.create_task(self.network.start())
         await self.network.on_connected.first
         async with self._header_processing_lock:
             await self._update_tasks.add(self.initial_headers_sync())
         await fully_synced
+
         await self.db.release_all_outputs()
         await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
         await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
@@ -448,10 +451,16 @@ class Ledger(metaclass=LedgerRegistry):
         )

     async def subscribe_accounts(self):
-        if self.network.is_connected and self.accounts:
-            log.info("Subscribe to %i accounts", len(self.accounts))
+        return await self._subscribe_accounts(self.accounts)
+
+    async def _subscribe_accounts(self, accounts):
+        wallet_ids = {account.wallet.id for account in accounts}
+        await self.network.connect_wallets(*wallet_ids)
+        # if self.network.is_connected and self.accounts:
+        if accounts:
+            log.info("Subscribe to %i accounts", len(accounts))
             await asyncio.wait([
-                self.subscribe_account(a) for a in self.accounts
+                self.subscribe_account(a) for a in accounts
             ])

     async def subscribe_account(self, account: Account):
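The Ledger hunks above replace the single global subscription pass with a per-wallet flow: `_subscribe_accounts()` first opens one network session per wallet id, then subscribes each account over its wallet's session. A minimal sketch of the intended call path, not part of the diff (`ledger` and `account` are assumed stand-ins for live objects from a running wallet):

```python
# Illustrative sketch only, under the assumption that the ledger's Network is running.
async def subscribe_new_account(ledger, account):
    # _subscribe_accounts() opens (or reuses) one ClientSession per wallet id
    # via network.connect_wallets(), then subscribes the account's addresses
    # over the session that belongs to its wallet.
    await ledger._subscribe_accounts([account])
    # Connection state is now queried per wallet instead of through a global flag.
    return ledger.network.is_connected(account.wallet.id)
```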
@@ -460,8 +469,10 @@ class Ledger(metaclass=LedgerRegistry):
         await account.ensure_address_gap()

     async def unsubscribe_account(self, account: Account):
+        session = self.network.get_wallet_session(account.wallet)
         for address in await account.get_addresses():
-            await self.network.unsubscribe_address(address)
+            await self.network.unsubscribe_address(address, session=session)
+        await self.network.close_wallet_session(account.wallet)

     async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]):
         await self.subscribe_addresses(address_manager, addresses)
@@ -470,28 +481,29 @@ class Ledger(metaclass=LedgerRegistry):
         )

     async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
-        if self.network.is_connected and addresses:
+        if self.network.is_connected(address_manager.account.wallet.id) and addresses:
+            session = self.network.get_wallet_session(address_manager.account.wallet)
            addresses_remaining = list(addresses)
             while addresses_remaining:
                 batch = addresses_remaining[:batch_size]
-                results = await self.network.subscribe_address(*batch)
+                results = await self.network.subscribe_address(session, *batch)
                 for address, remote_status in zip(batch, results):
-                    self._update_tasks.add(self.update_history(address, remote_status, address_manager))
+                    self._update_tasks.add(self.update_history(session, address, remote_status, address_manager))
                 addresses_remaining = addresses_remaining[batch_size:]
-                if self.network.client and self.network.client.server_address_and_port:
+                if session and session.server_address_and_port:
                     log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
-                             len(addresses), *self.network.client.server_address_and_port)
+                             len(addresses), *session.server_address_and_port)
-            if self.network.client and self.network.client.server_address_and_port:
+            if session and session.server_address_and_port:
                 log.info(
                     "finished subscribing to %i addresses on %s:%i", len(addresses),
-                    *self.network.client.server_address_and_port
+                    *session.server_address_and_port
                 )

-    def process_status_update(self, update):
+    def process_status_update(self, client, update):
         address, remote_status = update
-        self._update_tasks.add(self.update_history(address, remote_status))
+        self._update_tasks.add(self.update_history(client, address, remote_status))

-    async def update_history(self, address, remote_status, address_manager: AddressManager = None,
+    async def update_history(self, client, address, remote_status, address_manager: AddressManager = None,
                              reattempt_update: bool = True):
         async with self._address_update_locks[address]:
             self._known_addresses_out_of_sync.discard(address)
@@ -500,7 +512,9 @@ class Ledger(metaclass=LedgerRegistry):
             if local_status == remote_status:
                 return True

-            remote_history = await self.network.retriable_call(self.network.get_history, address)
+            # remote_history = await self.network.retriable_call(self.network.get_history, address, client)
+            remote_history = await self.network.get_history(address, session=client)
+
             remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
             we_need = set(remote_history) - set(local_history)
             if not we_need:
@@ -563,7 +577,7 @@ class Ledger(metaclass=LedgerRegistry):
                 "request %i transactions, %i/%i for %s are already synced", len(to_request), len(synced_txs),
                 len(remote_history), address
             )
-            requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address)
+            requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address, client)
             for tx in requested_txes:
                 pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:"
                 synced_txs.append(tx)
@@ -603,7 +617,7 @@ class Ledger(metaclass=LedgerRegistry):
                 if self._tx_cache.get(txid) is not cache_item:
                     log.warning("tx cache corrupted while syncing %s, reattempt sync=%s", address, reattempt_update)
                     if reattempt_update:
-                        return await self.update_history(address, remote_status, address_manager, False)
+                        return await self.update_history(client, address, remote_status, address_manager, False)
                     return False

             local_status, local_history = \
@@ -741,7 +755,7 @@ class Ledger(metaclass=LedgerRegistry):
             await _single_batch(batch)
         return transactions

-    async def _request_transaction_batch(self, to_request, remote_history_size, address):
+    async def _request_transaction_batch(self, to_request, remote_history_size, address, session):
         header_cache = {}
         batches = [[]]
         remote_heights = {}
@@ -766,7 +780,8 @@ class Ledger(metaclass=LedgerRegistry):

         async def _single_batch(batch):
             this_batch_synced = []
-            batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch)
+            # batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch, session)
+            batch_result = await self.network.get_transaction_batch(batch, session)
             for txid, (raw, merkle) in batch_result.items():
                 remote_height = remote_heights[txid]
                 merkle_height = merkle['block_height']
@@ -6,7 +6,6 @@ from operator import itemgetter
 from contextlib import asynccontextmanager
 from functools import partial
 from typing import Dict, Optional, Tuple

 import aiohttp

 from lbry import __version__
@@ -14,6 +13,7 @@ from lbry.error import IncompatibleWalletServerError
 from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
 from lbry.wallet.stream import StreamController


 log = logging.getLogger(__name__)

@@ -33,11 +33,26 @@ class ClientSession(BaseClientSession):
         self.pending_amount = 0
         self._on_connect_cb = on_connect_callback or (lambda: None)
         self.trigger_urgent_reconnect = asyncio.Event()
+        self._connected = asyncio.Event()
+        self._disconnected = asyncio.Event()
+        self._disconnected.set()
+
+        self._on_status_controller = StreamController(merge_repeated_events=True)
+        self.on_status = self._on_status_controller.stream
+        self.on_status.listen(partial(self.network.ledger.process_status_update, self))
+        self.subscription_controllers = {
+            'blockchain.headers.subscribe': self.network._on_header_controller,
+            'blockchain.address.subscribe': self._on_status_controller,
+        }

     @property
     def available(self):
         return not self.is_closing() and self.response_time is not None

+    @property
+    def is_connected(self) -> bool:
+        return self._connected.is_set()
+
     @property
     def server_address_and_port(self) -> Optional[Tuple[str, int]]:
         if not self.transport:
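With the wiring above, each `ClientSession` owns its own status stream and forwards pushes to the ledger together with the session that received them. A hedged sketch of how an address-status push appears to flow after this change (`session`, `address` and `status` are assumed names, not part of the diff):

```python
# Illustrative only; `session` is assumed to be a connected ClientSession.
def route_status_push(session, address, status):
    # handle_request() now resolves the controller on the session itself
    # instead of on the shared Network object:
    controller = session.subscription_controllers['blockchain.address.subscribe']
    # The controller feeds session.on_status, which __init__ wired to
    # ledger.process_status_update(session, (address, status)).
    controller.add((address, status))
```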
@@ -98,7 +113,6 @@ class ClientSession(BaseClientSession):

     async def ensure_session(self):
         # Handles reconnecting and maintaining a session alive
-        # TODO: change to 'ping' on newer protocol (above 1.2)
         retry_delay = default_delay = 1.0
         while True:
             try:
@@ -144,7 +158,7 @@ class ClientSession(BaseClientSession):
         self.connection_latency = perf_counter() - start

     async def handle_request(self, request):
-        controller = self.network.subscription_controllers[request.method]
+        controller = self.subscription_controllers[request.method]
         controller.add(request.args)

     def connection_lost(self, exc):
@@ -154,6 +168,13 @@ class ClientSession(BaseClientSession):
         self.connection_latency = None
         self._response_samples = 0
         self._on_disconnect_controller.add(True)
+        self._connected.clear()
+        self._disconnected.set()
+
+    def connection_made(self, transport):
+        super(ClientSession, self).connection_made(transport)
+        self._disconnected.clear()
+        self._connected.set()


 class Network:
@@ -165,6 +186,7 @@ class Network:
         self.ledger = ledger
         self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
         self.client: Optional[ClientSession] = None
+        self.clients: Dict[str, ClientSession] = {}
         self.server_features = None
         self._switch_task: Optional[asyncio.Task] = None
         self.running = False
@@ -177,27 +199,34 @@ class Network:
         self._on_header_controller = StreamController(merge_repeated_events=True)
         self.on_header = self._on_header_controller.stream

-        self._on_status_controller = StreamController(merge_repeated_events=True)
-        self.on_status = self._on_status_controller.stream
-
         self.subscription_controllers = {
             'blockchain.headers.subscribe': self._on_header_controller,
-            'blockchain.address.subscribe': self._on_status_controller,
         }

         self.aiohttp_session: Optional[aiohttp.ClientSession] = None

+    def get_wallet_session(self, wallet):
+        return self.clients[wallet.id]
+
+    async def close_wallet_session(self, wallet):
+        session = self.clients.pop(wallet.id)
+        ensure_connect_task = self.session_pool.wallet_session_tasks.pop(session)
+        if ensure_connect_task and not ensure_connect_task.done():
+            ensure_connect_task.cancel()
+        await session.close()
+
     @property
     def config(self):
         return self.ledger.config

     async def switch_forever(self):
         while self.running:
-            if self.is_connected:
+            if self.is_connected():
                 await self.client.on_disconnected.first
                 self.server_features = None
                 self.client = None
                 continue

             self.client = await self.session_pool.wait_for_fastest_session()
             log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
             try:
@@ -220,6 +249,11 @@ class Network:
         self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
         self.session_pool.start(self.config['default_servers'])
         self.on_header.listen(self._update_remote_height)
+        await self.ledger.subscribe_accounts()
+
+    async def connect_wallets(self, *wallet_ids):
+        if wallet_ids:
+            await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])

     async def stop(self):
         if self.running:
@@ -228,9 +262,19 @@ class Network:
         self._switch_task.cancel()
         self.session_pool.stop()

-    @property
-    def is_connected(self):
-        return self.client and not self.client.is_closing()
+    def is_connected(self, wallet_id: str = None):
+        if wallet_id is None:
+            if not self.client:
+                return False
+            return self.client.is_connected
+        if wallet_id not in self.clients:
+            return False
+        return self.clients[wallet_id].is_connected
+
+    def connect_wallet_client(self, wallet_id: str, server: Tuple[str, int]):
+        client = ClientSession(network=self, server=server)
+        self.clients[wallet_id] = client
+        return client

     def rpc(self, list_or_method, args, restricted=True, session=None):
         session = session or (self.client if restricted else self.session_pool.fastest_session)
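`Network.is_connected` changes from a property into a method that can optionally be keyed by wallet id, while `connect_wallet_client()` registers one `ClientSession` per wallet. A hedged sketch of how a caller might query both forms (`network` and `wallet_ids` are assumed names, not part of the diff):

```python
# Illustrative only; `network` is assumed to be a started Network instance.
def connection_report(network, wallet_ids):
    report = {
        # No argument: falls back to the shared master client, as before.
        'master': network.is_connected(),
    }
    for wallet_id in wallet_ids:
        # With an argument: checks the per-wallet ClientSession in network.clients.
        report[wallet_id] = network.is_connected(wallet_id)
    return report
```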
@@ -256,7 +300,8 @@ class Network:
             raise asyncio.CancelledError()  # if we got here, we are shutting down

     @asynccontextmanager
-    async def single_call_context(self, function, *args, **kwargs):
+    async def fastest_connection_context(self):
+
         if not self.is_connected:
             log.warning("Wallet server unavailable, waiting for it to come back and retry.")
             await self.on_connected.first
@@ -264,6 +309,38 @@ class Network:
         server = self.session_pool.fastest_session.server
         session = ClientSession(network=self, server=server)

+        async def call_with_reconnect(function, *args, **kwargs):
+            nonlocal session
+
+            while self.running:
+                if not session.is_connected:
+                    try:
+                        await session.create_connection()
+                    except asyncio.TimeoutError:
+                        if not session.is_connected:
+                            log.warning("Wallet server unavailable, waiting for it to come back and retry.")
+                            await self.on_connected.first
+                            await self.session_pool.wait_for_fastest_session()
+                            server = self.session_pool.fastest_session.server
+                            session = ClientSession(network=self, server=server)
+                try:
+                    return await partial(function, *args, session_override=session, **kwargs)(*args, **kwargs)
+                except asyncio.TimeoutError:
+                    log.warning("'%s' failed, retrying", function.__name__)
+        try:
+            yield call_with_reconnect
+        finally:
+            await session.close()
+
+    @asynccontextmanager
+    async def single_call_context(self, function, *args, **kwargs):
+        if not self.is_connected():
+            log.warning("Wallet server unavailable, waiting for it to come back and retry.")
+            await self.on_connected.first
+        await self.session_pool.wait_for_fastest_session()
+        server = self.session_pool.fastest_session.server
+        session = ClientSession(network=self, server=server)

         async def call_with_reconnect(*a, **kw):
             while self.running:
                 if not session.available:
@@ -280,71 +357,71 @@ class Network:
     def _update_remote_height(self, header_args):
         self.remote_height = header_args[0]["height"]

-    def get_transaction(self, tx_hash, known_height=None):
+    def get_transaction(self, tx_hash, known_height=None, session=None):
         # use any server if its old, otherwise restrict to who gave us the history
         restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.get', [tx_hash], restricted)
+        return self.rpc('blockchain.transaction.get', [tx_hash], restricted, session=session)

     def get_transaction_batch(self, txids, restricted=True, session=None):
         # use any server if its old, otherwise restrict to who gave us the history
-        return self.rpc('blockchain.transaction.get_batch', txids, restricted, session)
+        return self.rpc('blockchain.transaction.get_batch', txids, restricted, session=session)

-    def get_transaction_and_merkle(self, tx_hash, known_height=None):
+    def get_transaction_and_merkle(self, tx_hash, known_height=None, session=None):
         # use any server if its old, otherwise restrict to who gave us the history
         restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.info', [tx_hash], restricted)
+        return self.rpc('blockchain.transaction.info', [tx_hash], restricted, session=session)

-    def get_transaction_height(self, tx_hash, known_height=None):
+    def get_transaction_height(self, tx_hash, known_height=None, session=None):
         restricted = not known_height or 0 > known_height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
+        return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted, session=session)

-    def get_merkle(self, tx_hash, height):
+    def get_merkle(self, tx_hash, height, session=None):
         restricted = 0 > height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)
+        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted, session=session)

     def get_headers(self, height, count=10000, b64=False):
         restricted = height >= self.remote_height - 100
         return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)

     # --- Subscribes, history and broadcasts are always aimed towards the master client directly
-    def get_history(self, address):
-        return self.rpc('blockchain.address.get_history', [address], True)
+    def get_history(self, address, session=None):
+        return self.rpc('blockchain.address.get_history', [address], True, session=session)

-    def broadcast(self, raw_transaction):
-        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
+    def broadcast(self, raw_transaction, session=None):
+        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True, session=session)

     def subscribe_headers(self):
         return self.rpc('blockchain.headers.subscribe', [True], True)

-    async def subscribe_address(self, address, *addresses):
+    async def subscribe_address(self, session, address, *addresses):
         addresses = list((address, ) + addresses)
-        server_addr_and_port = self.client.server_address_and_port  # on disconnect client will be None
+        server_addr_and_port = session.server_address_and_port  # on disconnect client will be None
         try:
-            return await self.rpc('blockchain.address.subscribe', addresses, True)
+            return await self.rpc('blockchain.address.subscribe', addresses, True, session=session)
         except asyncio.TimeoutError:
             log.warning(
                 "timed out subscribing to addresses from %s:%i",
                 *server_addr_and_port
             )
             # abort and cancel, we can't lose a subscription, it will happen again on reconnect
-            if self.client:
-                self.client.abort()
+            if session:
+                session.abort()
             raise asyncio.CancelledError()

-    def unsubscribe_address(self, address):
-        return self.rpc('blockchain.address.unsubscribe', [address], True)
+    def unsubscribe_address(self, address, session=None):
+        return self.rpc('blockchain.address.unsubscribe', [address], True, session=session)

-    def get_server_features(self):
-        return self.rpc('server.features', (), restricted=True)
+    def get_server_features(self, session=None):
+        return self.rpc('server.features', (), restricted=True, session=session)

     def get_claims_by_ids(self, claim_ids):
         return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)

     def resolve(self, urls, session_override=None):
-        return self.rpc('blockchain.claimtrie.resolve', urls, False, session_override)
+        return self.rpc('blockchain.claimtrie.resolve', urls, False, session=session_override)

     def claim_search(self, session_override=None, **kwargs):
-        return self.rpc('blockchain.claimtrie.search', kwargs, False, session_override)
+        return self.rpc('blockchain.claimtrie.search', kwargs, False, session=session_override)

     async def new_resolve(self, server, urls):
         message = {"method": "resolve", "params": {"urls": urls, "protobuf": True}}
@@ -371,6 +448,7 @@ class SessionPool:
     def __init__(self, network: Network, timeout: float):
         self.network = network
         self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
+        self.wallet_session_tasks: Dict[ClientSession, Optional[asyncio.Task]] = dict()
         self.timeout = timeout
         self.new_connection_event = asyncio.Event()

@@ -430,6 +508,23 @@ class SessionPool:
         task.add_done_callback(lambda _: self.ensure_connections())
         self.sessions[session] = task

+    async def connect_wallet(self, wallet_id: str):
+        fastest = await self.wait_for_fastest_session()
+        connected = asyncio.Event()
+        if not self.network.is_connected(wallet_id):
+            session = self.network.connect_wallet_client(wallet_id, fastest.server)
+            session._on_connect_cb = connected.set
+            # session._on_connect_cb = self._get_session_connect_callback(session)
+        else:
+            connected.set()
+            session = self.network.clients[wallet_id]
+        task = self.wallet_session_tasks.get(session, None)
+        if not task or task.done():
+            task = asyncio.create_task(session.ensure_session())
+            # task.add_done_callback(lambda _: self.ensure_connections())
+            self.wallet_session_tasks[session] = task
+        await connected.wait()
+
     def start(self, default_servers):
         for server in default_servers:
             self._connect_session(server)
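`SessionPool.connect_wallet()` reuses the fastest known server to open (or re-use) the per-wallet session and waits until it reports connected. A rough usage sketch with assumed names (`network`, `wallets`) that are not part of the diff:

```python
# Illustrative only; assumes a started Network whose SessionPool is running.
async def ensure_wallet_sessions(network, wallets):
    # Network.connect_wallets() fans out to SessionPool.connect_wallet(),
    # one awaitable per wallet id.
    await network.connect_wallets(*[wallet.id for wallet in wallets])
    for wallet in wallets:
        session = network.get_wallet_session(wallet)  # ClientSession keyed by wallet.id
        print(wallet.id, session.server_address_and_port)
```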
@@ -1227,10 +1227,7 @@ class LBRYElectrumX(SessionBase):
             return await self.address_status(hashX)

     async def hashX_unsubscribe(self, hashX, alias):
-        try:
-            del self.hashX_subs[hashX]
-        except ValueError:
-            pass
+        self.hashX_subs.pop(hashX, None)

     def address_to_hashX(self, address):
         try:
@@ -75,7 +75,7 @@ class ReconnectTests(IntegrationTestCase):
         session.trigger_urgent_reconnect.set()
         await asyncio.wait_for(self.ledger.network.session_pool.new_connection_event.wait(), timeout=1)
         self.assertEqual(2, len(list(self.ledger.network.session_pool.available_sessions)))
-        self.assertTrue(self.ledger.network.is_connected)
+        self.assertTrue(self.ledger.network.is_connected())
         switch_event = self.ledger.network.on_connected.first
         await node2.stop(True)
         # secondary down, but primary is ok, do not switch! (switches trigger new on_connected events)
@@ -99,7 +99,7 @@ class ReconnectTests(IntegrationTestCase):
         address1 = await self.account.receiving.get_or_create_usable_address()
         # disconnect and send a new tx, should reconnect and get it
         self.ledger.network.client.connection_lost(Exception())
-        self.assertFalse(self.ledger.network.is_connected)
+        self.assertFalse(self.ledger.network.is_connected())
         sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
         await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0)  # mempool
         await self.blockchain.generate(1)
@@ -122,7 +122,7 @@ class ReconnectTests(IntegrationTestCase):
         sendtxid = await self.blockchain.send_to_address(address1, 42)
         await self.blockchain.generate(1)
         # (this is just so the test doesn't hang forever if it doesn't reconnect)
-        if not self.ledger.network.is_connected:
+        if not self.ledger.network.is_connected():
             await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0)
         # omg, the burned cable still works! torba is fire proof!
         await self.ledger.network.get_transaction(sendtxid)
@@ -130,11 +130,11 @@ class ReconnectTests(IntegrationTestCase):
     async def test_timeout_then_reconnect(self):
         # tests that it connects back after some failed attempts
         await self.conductor.spv_node.stop()
-        self.assertFalse(self.ledger.network.is_connected)
+        self.assertFalse(self.ledger.network.is_connected())
         await asyncio.sleep(0.2)  # let it retry and fail once
         await self.conductor.spv_node.start(self.conductor.blockchain_node)
         await self.ledger.network.on_connected.first
-        self.assertTrue(self.ledger.network.is_connected)
+        self.assertTrue(self.ledger.network.is_connected())

     async def test_online_but_still_unavailable(self):
         # Edge case. See issue #2445 for context
@@ -179,12 +179,18 @@ class ServerPickingTestCase(AsyncioTestCase):
             ],
             'connect_timeout': 3
         })
+        ledger.accounts = []
+
+        async def subscribe_accounts():
+            pass
+
+        ledger.subscribe_accounts = subscribe_accounts

         network = Network(ledger)
         self.addCleanup(network.stop)
         asyncio.ensure_future(network.start())
         await asyncio.wait_for(network.on_connected.first, timeout=1)
-        self.assertTrue(network.is_connected)
+        self.assertTrue(network.is_connected())
         self.assertTupleEqual(network.client.server, ('127.0.0.1', 1337))
         self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
         # ensure we are connected to all of them after a while
@@ -152,8 +152,11 @@ class BasicTransactionTests(IntegrationTestCase):
         for batch in range(0, len(sends), 10):
             txids = await asyncio.gather(*sends[batch:batch + 10])
             await asyncio.wait([self.on_transaction_id(txid) for txid in txids])
-        remote_status = await self.ledger.network.subscribe_address(address)
-        self.assertTrue(await self.ledger.update_history(address, remote_status))
+        client = self.ledger.network.get_wallet_session(self.account.wallet)
+
+        remote_status = await self.ledger.network.subscribe_address(client, address)
+        self.assertTrue(await self.ledger.update_history(client, address, remote_status))
         # 20 unconfirmed txs, 10 from blockchain, 10 from local to local
         utxos = await self.account.get_utxos()
         txs = []
@@ -166,12 +169,12 @@ class BasicTransactionTests(IntegrationTestCase):
             await self.broadcast(tx)
             txs.append(tx)
         await asyncio.wait([self.on_transaction_address(tx, address) for tx in txs], timeout=1)
-        remote_status = await self.ledger.network.subscribe_address(address)
-        self.assertTrue(await self.ledger.update_history(address, remote_status))
+        remote_status = await self.ledger.network.subscribe_address(client, address)
+        self.assertTrue(await self.ledger.update_history(client, address, remote_status))
         # server history grows unordered
         txid = await self.blockchain.send_to_address(address, 1)
         await self.on_transaction_id(txid)
-        self.assertTrue(await self.ledger.update_history(address, remote_status))
+        self.assertTrue(await self.ledger.update_history(client, address, remote_status))
         self.assertEqual(21, len((await self.ledger.get_local_status_and_history(address))[1]))
         self.assertEqual(0, len(self.ledger._known_addresses_out_of_sync))

@@ -8,16 +8,14 @@ from lbry.wallet.dewies import dict_values_to_lbc


 class WalletCommands(CommandTestCase):

     async def test_wallet_create_and_add_subscribe(self):
-        session = next(iter(self.conductor.spv_node.server.session_mgr.sessions))
-        self.assertEqual(len(session.hashX_subs), 27)
+        self.assertSetEqual({0, 27}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
         wallet = await self.daemon.jsonrpc_wallet_create('foo', create_account=True, single_key=True)
-        self.assertEqual(len(session.hashX_subs), 28)
+        self.assertSetEqual({0, 27, 1}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
         await self.daemon.jsonrpc_wallet_remove(wallet.id)
-        self.assertEqual(len(session.hashX_subs), 27)
+        self.assertSetEqual({0, 27}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})
         await self.daemon.jsonrpc_wallet_add(wallet.id)
-        self.assertEqual(len(session.hashX_subs), 28)
+        self.assertSetEqual({0, 27, 1}, {len(session.hashX_subs) for session in self.conductor.spv_node.server.session_mgr.sessions})

     async def test_wallet_syncing_status(self):
         address = await self.daemon.jsonrpc_address_unused()
@@ -4,24 +4,30 @@ import lbry
 import lbry.wallet
 from lbry.error import ServerPaymentFeeAboveMaxAllowedError
 from lbry.wallet.network import ClientSession
-from lbry.testcase import IntegrationTestCase, CommandTestCase
+from lbry.testcase import CommandTestCase
 from lbry.wallet.orchstr8.node import SPVNode


-class TestSessions(IntegrationTestCase):
+class MockNetwork:
+    def __init__(self, ledger):
+        self.ledger = ledger
+        self._on_header_controller = None
+
+
+class TestSessions(CommandTestCase):
     """
     Tests that server cleans up stale connections after session timeout and client times out too.
     """
-    LEDGER = lbry.wallet

     async def test_session_bloat_from_socket_timeout(self):
         await self.conductor.stop_spv()
         await self.ledger.stop()
         self.conductor.spv_node.session_timeout = 1
         await self.conductor.start_spv()

         session = ClientSession(
-            network=None, server=(self.conductor.spv_node.hostname, self.conductor.spv_node.port), timeout=0.2
+            network=MockNetwork(self.ledger), server=(self.conductor.spv_node.hostname, self.conductor.spv_node.port),
+            timeout=0.2
         )
         await session.create_connection()
         await session.send_request('server.banner', ())
@@ -16,12 +16,12 @@ class MockNetwork:
         self.address = None
         self.get_history_called = []
         self.get_transaction_called = []
-        self.is_connected = False
+        self.is_connected = lambda _: False

     def retriable_call(self, function, *args, **kwargs):
         return function(*args, **kwargs)

-    async def get_history(self, address):
+    async def get_history(self, address, session=None):
         self.get_history_called.append(address)
         self.address = address
         return self.history
@@ -40,7 +40,7 @@ class MockNetwork:
         merkle = await self.get_merkle(tx_hash, known_height)
         return tx, merkle

-    async def get_transaction_batch(self, txids):
+    async def get_transaction_batch(self, txids, session=None):
         return {
             txid: await self.get_transaction_and_merkle(txid)
             for txid in txids
@@ -111,7 +111,7 @@ class TestSynchronization(LedgerTestCase):
             txid2: hexlify(get_transaction(get_output(2)).raw),
             txid3: hexlify(get_transaction(get_output(3)).raw),
         })
-        await self.ledger.update_history(address, '')
+        await self.ledger.update_history(None, address, '')
         self.assertListEqual(self.ledger.network.get_history_called, [address])
         self.assertListEqual(self.ledger.network.get_transaction_called, [txid1, txid2, txid3])

@@ -129,7 +129,7 @@ class TestSynchronization(LedgerTestCase):
         self.assertFalse(self.ledger._tx_cache[txid1].tx.is_verified)
         self.assertFalse(self.ledger._tx_cache[txid2].tx.is_verified)
         self.assertFalse(self.ledger._tx_cache[txid3].tx.is_verified)
-        await self.ledger.update_history(address, '')
+        await self.ledger.update_history(None, address, '')
         self.assertListEqual(self.ledger.network.get_history_called, [address])
         self.assertListEqual(self.ledger.network.get_transaction_called, [])

@@ -137,7 +137,7 @@ class TestSynchronization(LedgerTestCase):
         self.ledger.network.transaction[txid4] = hexlify(get_transaction(get_output(4)).raw)
         self.ledger.network.get_history_called = []
         self.ledger.network.get_transaction_called = []
-        await self.ledger.update_history(address, '')
+        await self.ledger.update_history(None, address, '')
         self.assertListEqual(self.ledger.network.get_history_called, [address])
         self.assertListEqual(self.ledger.network.get_transaction_called, [txid4])
         address_details = await self.ledger.db.get_address(address=address)