improved shutdown for torba.server and related test setup code
This commit is contained in:
parent
458189366f
commit
7092f40701
5 changed files with 88 additions and 58 deletions
|
@ -19,6 +19,9 @@ from torba.client.basemanager import BaseWalletManager
|
||||||
from torba.client.baseaccount import BaseAccount
|
from torba.client.baseaccount import BaseAccount
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_manager_from_environment(default_manager=BaseWalletManager):
|
def get_manager_from_environment(default_manager=BaseWalletManager):
|
||||||
if 'TORBA_MANAGER' not in os.environ:
|
if 'TORBA_MANAGER' not in os.environ:
|
||||||
return default_manager
|
return default_manager
|
||||||
|
@ -73,40 +76,55 @@ class Conductor:
|
||||||
self.spv_started = False
|
self.spv_started = False
|
||||||
self.wallet_started = False
|
self.wallet_started = False
|
||||||
|
|
||||||
|
self.log = log.getChild('conductor')
|
||||||
|
|
||||||
async def start_blockchain(self):
|
async def start_blockchain(self):
|
||||||
await self.blockchain_node.start()
|
if not self.blockchain_started:
|
||||||
await self.blockchain_node.generate(200)
|
await self.blockchain_node.start()
|
||||||
self.blockchain_started = True
|
await self.blockchain_node.generate(200)
|
||||||
|
self.blockchain_started = True
|
||||||
|
|
||||||
|
async def stop_blockchain(self):
|
||||||
|
if self.blockchain_started:
|
||||||
|
await self.blockchain_node.stop(cleanup=True)
|
||||||
|
self.blockchain_started = False
|
||||||
|
|
||||||
async def start_spv(self):
|
async def start_spv(self):
|
||||||
await self.spv_node.start()
|
if not self.spv_started:
|
||||||
self.spv_started = True
|
await self.spv_node.start()
|
||||||
|
self.spv_started = True
|
||||||
|
|
||||||
|
async def stop_spv(self):
|
||||||
|
if self.spv_started:
|
||||||
|
await self.spv_node.stop(cleanup=True)
|
||||||
|
self.spv_started = False
|
||||||
|
|
||||||
async def start_wallet(self):
|
async def start_wallet(self):
|
||||||
await self.wallet_node.start()
|
if not self.wallet_started:
|
||||||
self.wallet_started = True
|
await self.wallet_node.start()
|
||||||
|
self.wallet_started = True
|
||||||
|
|
||||||
|
async def stop_wallet(self):
|
||||||
|
if self.wallet_started:
|
||||||
|
await self.wallet_node.stop(cleanup=True)
|
||||||
|
self.wallet_started = False
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.blockchain_started or await self.start_blockchain()
|
await self.start_blockchain()
|
||||||
self.spv_started or await self.start_spv()
|
await self.start_spv()
|
||||||
self.wallet_started or await self.start_wallet()
|
await self.start_wallet()
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
if self.wallet_started:
|
all_the_stops = [
|
||||||
|
self.wallet_node.stop,
|
||||||
|
self.spv_node.stop,
|
||||||
|
self.blockchain_node.stop
|
||||||
|
]
|
||||||
|
for stop in all_the_stops:
|
||||||
try:
|
try:
|
||||||
await self.wallet_node.stop(cleanup=True)
|
await stop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
log.exception('Exception raised while stopping services:', exc_info=e)
|
||||||
if self.spv_started:
|
|
||||||
try:
|
|
||||||
await self.spv_node.stop(cleanup=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
if self.blockchain_started:
|
|
||||||
try:
|
|
||||||
await self.blockchain_node.stop(cleanup=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
|
|
||||||
class WalletNode:
|
class WalletNode:
|
||||||
|
@ -182,7 +200,7 @@ class SPVNode:
|
||||||
|
|
||||||
async def stop(self, cleanup=True):
|
async def stop(self, cleanup=True):
|
||||||
try:
|
try:
|
||||||
self.server.db.shutdown()
|
self.server.stop()
|
||||||
finally:
|
finally:
|
||||||
cleanup and self.cleanup()
|
cleanup and self.cleanup()
|
||||||
|
|
||||||
|
@ -198,10 +216,10 @@ class BlockchainProcess(asyncio.SubprocessProtocol):
|
||||||
b'keypool return',
|
b'keypool return',
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, log):
|
def __init__(self):
|
||||||
self.ready = asyncio.Event()
|
self.ready = asyncio.Event()
|
||||||
self.stopped = asyncio.Event()
|
self.stopped = asyncio.Event()
|
||||||
self.log = log
|
self.log = log.getChild('blockchain')
|
||||||
|
|
||||||
def pipe_data_received(self, fd, data):
|
def pipe_data_received(self, fd, data):
|
||||||
if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT):
|
if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT):
|
||||||
|
@ -227,7 +245,7 @@ class BlockchainNode:
|
||||||
self.bin_dir = os.path.join(self.project_dir, 'bin')
|
self.bin_dir = os.path.join(self.project_dir, 'bin')
|
||||||
self.daemon_bin = os.path.join(self.bin_dir, daemon)
|
self.daemon_bin = os.path.join(self.bin_dir, daemon)
|
||||||
self.cli_bin = os.path.join(self.bin_dir, cli)
|
self.cli_bin = os.path.join(self.bin_dir, cli)
|
||||||
self.log = logging.getLogger('blockchain')
|
self.log = log.getChild('blockchain')
|
||||||
self.data_path = None
|
self.data_path = None
|
||||||
self.protocol = None
|
self.protocol = None
|
||||||
self.transport = None
|
self.transport = None
|
||||||
|
@ -289,7 +307,7 @@ class BlockchainNode:
|
||||||
)
|
)
|
||||||
self.log.info(' '.join(command))
|
self.log.info(' '.join(command))
|
||||||
self.transport, self.protocol = await loop.subprocess_exec(
|
self.transport, self.protocol = await loop.subprocess_exec(
|
||||||
lambda: BlockchainProcess(self.log), *command
|
lambda: BlockchainProcess(), *command
|
||||||
)
|
)
|
||||||
await self.protocol.ready.wait()
|
await self.protocol.ready.wait()
|
||||||
|
|
||||||
|
|
|
@ -657,6 +657,7 @@ class BlockProcessor:
|
||||||
# Shut down block processing
|
# Shut down block processing
|
||||||
self.logger.info('flushing to DB for a clean shutdown...')
|
self.logger.info('flushing to DB for a clean shutdown...')
|
||||||
await self.flush(True)
|
await self.flush(True)
|
||||||
|
self.db.close()
|
||||||
|
|
||||||
def force_chain_reorg(self, count):
|
def force_chain_reorg(self, count):
|
||||||
'''Force a reorg of the given number of blocks.
|
'''Force a reorg of the given number of blocks.
|
||||||
|
|
|
@ -134,7 +134,7 @@ class DB:
|
||||||
# Read TX counts (requires meta directory)
|
# Read TX counts (requires meta directory)
|
||||||
await self._read_tx_counts()
|
await self._read_tx_counts()
|
||||||
|
|
||||||
def shutdown(self):
|
def close(self):
|
||||||
self.utxo_db.close()
|
self.utxo_db.close()
|
||||||
self.history.close_db()
|
self.history.close_db()
|
||||||
|
|
||||||
|
|
|
@ -108,9 +108,10 @@ class Server:
|
||||||
await _start_cancellable(self.mempool.keep_synchronized)
|
await _start_cancellable(self.mempool.keep_synchronized)
|
||||||
await _start_cancellable(self.session_mgr.serve, self.notifications)
|
await _start_cancellable(self.session_mgr.serve, self.notifications)
|
||||||
|
|
||||||
def stop(self):
|
async def stop(self):
|
||||||
for task in reversed(self.cancellable_tasks):
|
for task in reversed(self.cancellable_tasks):
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
await asyncio.wait(self.cancellable_tasks)
|
||||||
self.shutdown_event.set()
|
self.shutdown_event.set()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
|
@ -96,36 +96,35 @@ class AsyncioTestCase(unittest.TestCase):
|
||||||
"__unittest_expecting_failure__", False)
|
"__unittest_expecting_failure__", False)
|
||||||
expecting_failure = expecting_failure_class or expecting_failure_method
|
expecting_failure = expecting_failure_class or expecting_failure_method
|
||||||
outcome = _Outcome(result)
|
outcome = _Outcome(result)
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
try:
|
try:
|
||||||
self._outcome = outcome
|
self._outcome = outcome
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
asyncio.set_event_loop(loop)
|
||||||
try:
|
loop.set_debug(True)
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
loop.set_debug(True)
|
|
||||||
|
|
||||||
|
with outcome.testPartExecutor(self):
|
||||||
|
self.setUp()
|
||||||
|
loop.run_until_complete(self.asyncSetUp())
|
||||||
|
if outcome.success:
|
||||||
|
outcome.expecting_failure = expecting_failure
|
||||||
|
with outcome.testPartExecutor(self, isTest=True):
|
||||||
|
maybe_coroutine = testMethod()
|
||||||
|
if asyncio.iscoroutine(maybe_coroutine):
|
||||||
|
loop.run_until_complete(maybe_coroutine)
|
||||||
|
outcome.expecting_failure = False
|
||||||
with outcome.testPartExecutor(self):
|
with outcome.testPartExecutor(self):
|
||||||
self.setUp()
|
loop.run_until_complete(self.asyncTearDown())
|
||||||
loop.run_until_complete(self.asyncSetUp())
|
self.tearDown()
|
||||||
if outcome.success:
|
|
||||||
outcome.expecting_failure = expecting_failure
|
|
||||||
with outcome.testPartExecutor(self, isTest=True):
|
|
||||||
possible_coroutine = testMethod()
|
|
||||||
if asyncio.iscoroutine(possible_coroutine):
|
|
||||||
loop.run_until_complete(possible_coroutine)
|
|
||||||
outcome.expecting_failure = False
|
|
||||||
with outcome.testPartExecutor(self):
|
|
||||||
loop.run_until_complete(self.asyncTearDown())
|
|
||||||
self.tearDown()
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
_cancel_all_tasks(loop)
|
|
||||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
|
||||||
finally:
|
|
||||||
asyncio.set_event_loop(None)
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
self.doCleanups()
|
self.doAsyncCleanups(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_cancel_all_tasks(loop)
|
||||||
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
|
finally:
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
loop.close()
|
||||||
|
|
||||||
for test, reason in outcome.skipped:
|
for test, reason in outcome.skipped:
|
||||||
self._addSkip(result, test, reason)
|
self._addSkip(result, test, reason)
|
||||||
|
@ -155,6 +154,15 @@ class AsyncioTestCase(unittest.TestCase):
|
||||||
# clear the outcome, no more needed
|
# clear the outcome, no more needed
|
||||||
self._outcome = None
|
self._outcome = None
|
||||||
|
|
||||||
|
def doAsyncCleanups(self, loop):
|
||||||
|
outcome = self._outcome or _Outcome()
|
||||||
|
while self._cleanups:
|
||||||
|
function, args, kwargs = self._cleanups.pop()
|
||||||
|
with outcome.testPartExecutor(self):
|
||||||
|
maybe_coroutine = function(*args, **kwargs)
|
||||||
|
if asyncio.iscoroutine(maybe_coroutine):
|
||||||
|
loop.run_until_complete(maybe_coroutine)
|
||||||
|
|
||||||
|
|
||||||
class IntegrationTestCase(AsyncioTestCase):
|
class IntegrationTestCase(AsyncioTestCase):
|
||||||
|
|
||||||
|
@ -176,7 +184,12 @@ class IntegrationTestCase(AsyncioTestCase):
|
||||||
self.conductor = Conductor(
|
self.conductor = Conductor(
|
||||||
ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
|
ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
|
||||||
)
|
)
|
||||||
await self.conductor.start()
|
await self.conductor.start_blockchain()
|
||||||
|
self.addCleanup(self.conductor.stop_blockchain)
|
||||||
|
await self.conductor.start_spv()
|
||||||
|
self.addCleanup(self.conductor.stop_spv)
|
||||||
|
await self.conductor.start_wallet()
|
||||||
|
self.addCleanup(self.conductor.stop_wallet)
|
||||||
self.blockchain = self.conductor.blockchain_node
|
self.blockchain = self.conductor.blockchain_node
|
||||||
self.wallet_node = self.conductor.wallet_node
|
self.wallet_node = self.conductor.wallet_node
|
||||||
self.manager = self.wallet_node.manager
|
self.manager = self.wallet_node.manager
|
||||||
|
@ -184,9 +197,6 @@ class IntegrationTestCase(AsyncioTestCase):
|
||||||
self.wallet = self.wallet_node.wallet
|
self.wallet = self.wallet_node.wallet
|
||||||
self.account = self.wallet_node.wallet.default_account
|
self.account = self.wallet_node.wallet.default_account
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
|
||||||
await self.conductor.stop()
|
|
||||||
|
|
||||||
async def assertBalance(self, account, expected_balance: str): # pylint: disable=C0103
|
async def assertBalance(self, account, expected_balance: str): # pylint: disable=C0103
|
||||||
balance = await account.get_balance()
|
balance = await account.get_balance()
|
||||||
self.assertEqual(satoshis_to_coins(balance), expected_balance)
|
self.assertEqual(satoshis_to_coins(balance), expected_balance)
|
||||||
|
|
Loading…
Reference in a new issue