From 7092f4070169e47e719d716e9fe0131cd1ad0cc2 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Fri, 14 Dec 2018 16:15:59 -0500 Subject: [PATCH] improved shutdown for torba.server and related test setup code --- torba/orchstr8/node.py | 74 ++++++++++++++++++++------------- torba/server/block_processor.py | 1 + torba/server/db.py | 2 +- torba/server/server.py | 3 +- torba/testcase.py | 66 ++++++++++++++++------------- 5 files changed, 88 insertions(+), 58 deletions(-) diff --git a/torba/orchstr8/node.py b/torba/orchstr8/node.py index 1918a1b01..9d30bdd4b 100644 --- a/torba/orchstr8/node.py +++ b/torba/orchstr8/node.py @@ -19,6 +19,9 @@ from torba.client.basemanager import BaseWalletManager from torba.client.baseaccount import BaseAccount +log = logging.getLogger(__name__) + + def get_manager_from_environment(default_manager=BaseWalletManager): if 'TORBA_MANAGER' not in os.environ: return default_manager @@ -73,40 +76,55 @@ class Conductor: self.spv_started = False self.wallet_started = False + self.log = log.getChild('conductor') + async def start_blockchain(self): - await self.blockchain_node.start() - await self.blockchain_node.generate(200) - self.blockchain_started = True + if not self.blockchain_started: + await self.blockchain_node.start() + await self.blockchain_node.generate(200) + self.blockchain_started = True + + async def stop_blockchain(self): + if self.blockchain_started: + await self.blockchain_node.stop(cleanup=True) + self.blockchain_started = False async def start_spv(self): - await self.spv_node.start() - self.spv_started = True + if not self.spv_started: + await self.spv_node.start() + self.spv_started = True + + async def stop_spv(self): + if self.spv_started: + await self.spv_node.stop(cleanup=True) + self.spv_started = False async def start_wallet(self): - await self.wallet_node.start() - self.wallet_started = True + if not self.wallet_started: + await self.wallet_node.start() + self.wallet_started = True + + async def stop_wallet(self): + if self.wallet_started: + await self.wallet_node.stop(cleanup=True) + self.wallet_started = False async def start(self): - self.blockchain_started or await self.start_blockchain() - self.spv_started or await self.start_spv() - self.wallet_started or await self.start_wallet() + await self.start_blockchain() + await self.start_spv() + await self.start_wallet() async def stop(self): - if self.wallet_started: + all_the_stops = [ + self.wallet_node.stop, + self.spv_node.stop, + self.blockchain_node.stop + ] + for stop in all_the_stops: try: - await self.wallet_node.stop(cleanup=True) + await stop() except Exception as e: - print(e) - if self.spv_started: - try: - await self.spv_node.stop(cleanup=True) - except Exception as e: - print(e) - if self.blockchain_started: - try: - await self.blockchain_node.stop(cleanup=True) - except Exception as e: - print(e) + log.exception('Exception raised while stopping services:', exc_info=e) class WalletNode: @@ -182,7 +200,7 @@ class SPVNode: async def stop(self, cleanup=True): try: - self.server.db.shutdown() + self.server.stop() finally: cleanup and self.cleanup() @@ -198,10 +216,10 @@ class BlockchainProcess(asyncio.SubprocessProtocol): b'keypool return', ] - def __init__(self, log): + def __init__(self): self.ready = asyncio.Event() self.stopped = asyncio.Event() - self.log = log + self.log = log.getChild('blockchain') def pipe_data_received(self, fd, data): if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT): @@ -227,7 +245,7 @@ class BlockchainNode: self.bin_dir = os.path.join(self.project_dir, 'bin') self.daemon_bin = os.path.join(self.bin_dir, daemon) self.cli_bin = os.path.join(self.bin_dir, cli) - self.log = logging.getLogger('blockchain') + self.log = log.getChild('blockchain') self.data_path = None self.protocol = None self.transport = None @@ -289,7 +307,7 @@ class BlockchainNode: ) self.log.info(' '.join(command)) self.transport, self.protocol = await loop.subprocess_exec( - lambda: BlockchainProcess(self.log), *command + lambda: BlockchainProcess(), *command ) await self.protocol.ready.wait() diff --git a/torba/server/block_processor.py b/torba/server/block_processor.py index e544eb7b8..93925b1ed 100644 --- a/torba/server/block_processor.py +++ b/torba/server/block_processor.py @@ -657,6 +657,7 @@ class BlockProcessor: # Shut down block processing self.logger.info('flushing to DB for a clean shutdown...') await self.flush(True) + self.db.close() def force_chain_reorg(self, count): '''Force a reorg of the given number of blocks. diff --git a/torba/server/db.py b/torba/server/db.py index 0721dcef5..16e165bc9 100644 --- a/torba/server/db.py +++ b/torba/server/db.py @@ -134,7 +134,7 @@ class DB: # Read TX counts (requires meta directory) await self._read_tx_counts() - def shutdown(self): + def close(self): self.utxo_db.close() self.history.close_db() diff --git a/torba/server/server.py b/torba/server/server.py index b6822e28b..433df49fa 100644 --- a/torba/server/server.py +++ b/torba/server/server.py @@ -108,9 +108,10 @@ class Server: await _start_cancellable(self.mempool.keep_synchronized) await _start_cancellable(self.session_mgr.serve, self.notifications) - def stop(self): + async def stop(self): for task in reversed(self.cancellable_tasks): task.cancel() + await asyncio.wait(self.cancellable_tasks) self.shutdown_event.set() def run(self): diff --git a/torba/testcase.py b/torba/testcase.py index 8e61b7828..e6074c7e9 100644 --- a/torba/testcase.py +++ b/torba/testcase.py @@ -96,36 +96,35 @@ class AsyncioTestCase(unittest.TestCase): "__unittest_expecting_failure__", False) expecting_failure = expecting_failure_class or expecting_failure_method outcome = _Outcome(result) + loop = asyncio.new_event_loop() try: self._outcome = outcome - loop = asyncio.new_event_loop() - try: - asyncio.set_event_loop(loop) - loop.set_debug(True) + asyncio.set_event_loop(loop) + loop.set_debug(True) + with outcome.testPartExecutor(self): + self.setUp() + loop.run_until_complete(self.asyncSetUp()) + if outcome.success: + outcome.expecting_failure = expecting_failure + with outcome.testPartExecutor(self, isTest=True): + maybe_coroutine = testMethod() + if asyncio.iscoroutine(maybe_coroutine): + loop.run_until_complete(maybe_coroutine) + outcome.expecting_failure = False with outcome.testPartExecutor(self): - self.setUp() - loop.run_until_complete(self.asyncSetUp()) - if outcome.success: - outcome.expecting_failure = expecting_failure - with outcome.testPartExecutor(self, isTest=True): - possible_coroutine = testMethod() - if asyncio.iscoroutine(possible_coroutine): - loop.run_until_complete(possible_coroutine) - outcome.expecting_failure = False - with outcome.testPartExecutor(self): - loop.run_until_complete(self.asyncTearDown()) - self.tearDown() - finally: - try: - _cancel_all_tasks(loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - asyncio.set_event_loop(None) - loop.close() + loop.run_until_complete(self.asyncTearDown()) + self.tearDown() - self.doCleanups() + self.doAsyncCleanups(loop) + + try: + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + loop.close() for test, reason in outcome.skipped: self._addSkip(result, test, reason) @@ -155,6 +154,15 @@ class AsyncioTestCase(unittest.TestCase): # clear the outcome, no more needed self._outcome = None + def doAsyncCleanups(self, loop): + outcome = self._outcome or _Outcome() + while self._cleanups: + function, args, kwargs = self._cleanups.pop() + with outcome.testPartExecutor(self): + maybe_coroutine = function(*args, **kwargs) + if asyncio.iscoroutine(maybe_coroutine): + loop.run_until_complete(maybe_coroutine) + class IntegrationTestCase(AsyncioTestCase): @@ -176,7 +184,12 @@ class IntegrationTestCase(AsyncioTestCase): self.conductor = Conductor( ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY ) - await self.conductor.start() + await self.conductor.start_blockchain() + self.addCleanup(self.conductor.stop_blockchain) + await self.conductor.start_spv() + self.addCleanup(self.conductor.stop_spv) + await self.conductor.start_wallet() + self.addCleanup(self.conductor.stop_wallet) self.blockchain = self.conductor.blockchain_node self.wallet_node = self.conductor.wallet_node self.manager = self.wallet_node.manager @@ -184,9 +197,6 @@ class IntegrationTestCase(AsyncioTestCase): self.wallet = self.wallet_node.wallet self.account = self.wallet_node.wallet.default_account - async def asyncTearDown(self): - await self.conductor.stop() - async def assertBalance(self, account, expected_balance: str): # pylint: disable=C0103 balance = await account.get_balance() self.assertEqual(satoshis_to_coins(balance), expected_balance)