improved shutdown for torba.server and related test setup code

This commit is contained in:
Lex Berezhny 2018-12-14 16:15:59 -05:00
parent 458189366f
commit 7092f40701
5 changed files with 88 additions and 58 deletions

View file

@ -19,6 +19,9 @@ from torba.client.basemanager import BaseWalletManager
from torba.client.baseaccount import BaseAccount from torba.client.baseaccount import BaseAccount
log = logging.getLogger(__name__)
def get_manager_from_environment(default_manager=BaseWalletManager): def get_manager_from_environment(default_manager=BaseWalletManager):
if 'TORBA_MANAGER' not in os.environ: if 'TORBA_MANAGER' not in os.environ:
return default_manager return default_manager
@ -73,40 +76,55 @@ class Conductor:
self.spv_started = False self.spv_started = False
self.wallet_started = False self.wallet_started = False
self.log = log.getChild('conductor')
async def start_blockchain(self): async def start_blockchain(self):
if not self.blockchain_started:
await self.blockchain_node.start() await self.blockchain_node.start()
await self.blockchain_node.generate(200) await self.blockchain_node.generate(200)
self.blockchain_started = True self.blockchain_started = True
async def stop_blockchain(self):
if self.blockchain_started:
await self.blockchain_node.stop(cleanup=True)
self.blockchain_started = False
async def start_spv(self): async def start_spv(self):
if not self.spv_started:
await self.spv_node.start() await self.spv_node.start()
self.spv_started = True self.spv_started = True
async def stop_spv(self):
if self.spv_started:
await self.spv_node.stop(cleanup=True)
self.spv_started = False
async def start_wallet(self): async def start_wallet(self):
if not self.wallet_started:
await self.wallet_node.start() await self.wallet_node.start()
self.wallet_started = True self.wallet_started = True
async def stop_wallet(self):
if self.wallet_started:
await self.wallet_node.stop(cleanup=True)
self.wallet_started = False
async def start(self): async def start(self):
self.blockchain_started or await self.start_blockchain() await self.start_blockchain()
self.spv_started or await self.start_spv() await self.start_spv()
self.wallet_started or await self.start_wallet() await self.start_wallet()
async def stop(self): async def stop(self):
if self.wallet_started: all_the_stops = [
self.wallet_node.stop,
self.spv_node.stop,
self.blockchain_node.stop
]
for stop in all_the_stops:
try: try:
await self.wallet_node.stop(cleanup=True) await stop()
except Exception as e: except Exception as e:
print(e) log.exception('Exception raised while stopping services:', exc_info=e)
if self.spv_started:
try:
await self.spv_node.stop(cleanup=True)
except Exception as e:
print(e)
if self.blockchain_started:
try:
await self.blockchain_node.stop(cleanup=True)
except Exception as e:
print(e)
class WalletNode: class WalletNode:
@ -182,7 +200,7 @@ class SPVNode:
async def stop(self, cleanup=True): async def stop(self, cleanup=True):
try: try:
self.server.db.shutdown() self.server.stop()
finally: finally:
cleanup and self.cleanup() cleanup and self.cleanup()
@ -198,10 +216,10 @@ class BlockchainProcess(asyncio.SubprocessProtocol):
b'keypool return', b'keypool return',
] ]
def __init__(self, log): def __init__(self):
self.ready = asyncio.Event() self.ready = asyncio.Event()
self.stopped = asyncio.Event() self.stopped = asyncio.Event()
self.log = log self.log = log.getChild('blockchain')
def pipe_data_received(self, fd, data): def pipe_data_received(self, fd, data):
if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT): if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT):
@ -227,7 +245,7 @@ class BlockchainNode:
self.bin_dir = os.path.join(self.project_dir, 'bin') self.bin_dir = os.path.join(self.project_dir, 'bin')
self.daemon_bin = os.path.join(self.bin_dir, daemon) self.daemon_bin = os.path.join(self.bin_dir, daemon)
self.cli_bin = os.path.join(self.bin_dir, cli) self.cli_bin = os.path.join(self.bin_dir, cli)
self.log = logging.getLogger('blockchain') self.log = log.getChild('blockchain')
self.data_path = None self.data_path = None
self.protocol = None self.protocol = None
self.transport = None self.transport = None
@ -289,7 +307,7 @@ class BlockchainNode:
) )
self.log.info(' '.join(command)) self.log.info(' '.join(command))
self.transport, self.protocol = await loop.subprocess_exec( self.transport, self.protocol = await loop.subprocess_exec(
lambda: BlockchainProcess(self.log), *command lambda: BlockchainProcess(), *command
) )
await self.protocol.ready.wait() await self.protocol.ready.wait()

View file

@ -657,6 +657,7 @@ class BlockProcessor:
# Shut down block processing # Shut down block processing
self.logger.info('flushing to DB for a clean shutdown...') self.logger.info('flushing to DB for a clean shutdown...')
await self.flush(True) await self.flush(True)
self.db.close()
def force_chain_reorg(self, count): def force_chain_reorg(self, count):
'''Force a reorg of the given number of blocks. '''Force a reorg of the given number of blocks.

View file

@ -134,7 +134,7 @@ class DB:
# Read TX counts (requires meta directory) # Read TX counts (requires meta directory)
await self._read_tx_counts() await self._read_tx_counts()
def shutdown(self): def close(self):
self.utxo_db.close() self.utxo_db.close()
self.history.close_db() self.history.close_db()

View file

@ -108,9 +108,10 @@ class Server:
await _start_cancellable(self.mempool.keep_synchronized) await _start_cancellable(self.mempool.keep_synchronized)
await _start_cancellable(self.session_mgr.serve, self.notifications) await _start_cancellable(self.session_mgr.serve, self.notifications)
def stop(self): async def stop(self):
for task in reversed(self.cancellable_tasks): for task in reversed(self.cancellable_tasks):
task.cancel() task.cancel()
await asyncio.wait(self.cancellable_tasks)
self.shutdown_event.set() self.shutdown_event.set()
def run(self): def run(self):

View file

@ -96,11 +96,10 @@ class AsyncioTestCase(unittest.TestCase):
"__unittest_expecting_failure__", False) "__unittest_expecting_failure__", False)
expecting_failure = expecting_failure_class or expecting_failure_method expecting_failure = expecting_failure_class or expecting_failure_method
outcome = _Outcome(result) outcome = _Outcome(result)
loop = asyncio.new_event_loop()
try: try:
self._outcome = outcome self._outcome = outcome
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
loop.set_debug(True) loop.set_debug(True)
@ -110,14 +109,16 @@ class AsyncioTestCase(unittest.TestCase):
if outcome.success: if outcome.success:
outcome.expecting_failure = expecting_failure outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True): with outcome.testPartExecutor(self, isTest=True):
possible_coroutine = testMethod() maybe_coroutine = testMethod()
if asyncio.iscoroutine(possible_coroutine): if asyncio.iscoroutine(maybe_coroutine):
loop.run_until_complete(possible_coroutine) loop.run_until_complete(maybe_coroutine)
outcome.expecting_failure = False outcome.expecting_failure = False
with outcome.testPartExecutor(self): with outcome.testPartExecutor(self):
loop.run_until_complete(self.asyncTearDown()) loop.run_until_complete(self.asyncTearDown())
self.tearDown() self.tearDown()
finally:
self.doAsyncCleanups(loop)
try: try:
_cancel_all_tasks(loop) _cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens()) loop.run_until_complete(loop.shutdown_asyncgens())
@ -125,8 +126,6 @@ class AsyncioTestCase(unittest.TestCase):
asyncio.set_event_loop(None) asyncio.set_event_loop(None)
loop.close() loop.close()
self.doCleanups()
for test, reason in outcome.skipped: for test, reason in outcome.skipped:
self._addSkip(result, test, reason) self._addSkip(result, test, reason)
self._feedErrorsToResult(result, outcome.errors) self._feedErrorsToResult(result, outcome.errors)
@ -155,6 +154,15 @@ class AsyncioTestCase(unittest.TestCase):
# clear the outcome, no more needed # clear the outcome, no more needed
self._outcome = None self._outcome = None
def doAsyncCleanups(self, loop):
outcome = self._outcome or _Outcome()
while self._cleanups:
function, args, kwargs = self._cleanups.pop()
with outcome.testPartExecutor(self):
maybe_coroutine = function(*args, **kwargs)
if asyncio.iscoroutine(maybe_coroutine):
loop.run_until_complete(maybe_coroutine)
class IntegrationTestCase(AsyncioTestCase): class IntegrationTestCase(AsyncioTestCase):
@ -176,7 +184,12 @@ class IntegrationTestCase(AsyncioTestCase):
self.conductor = Conductor( self.conductor = Conductor(
ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
) )
await self.conductor.start() await self.conductor.start_blockchain()
self.addCleanup(self.conductor.stop_blockchain)
await self.conductor.start_spv()
self.addCleanup(self.conductor.stop_spv)
await self.conductor.start_wallet()
self.addCleanup(self.conductor.stop_wallet)
self.blockchain = self.conductor.blockchain_node self.blockchain = self.conductor.blockchain_node
self.wallet_node = self.conductor.wallet_node self.wallet_node = self.conductor.wallet_node
self.manager = self.wallet_node.manager self.manager = self.wallet_node.manager
@ -184,9 +197,6 @@ class IntegrationTestCase(AsyncioTestCase):
self.wallet = self.wallet_node.wallet self.wallet = self.wallet_node.wallet
self.account = self.wallet_node.wallet.default_account self.account = self.wallet_node.wallet.default_account
async def asyncTearDown(self):
await self.conductor.stop()
async def assertBalance(self, account, expected_balance: str): # pylint: disable=C0103 async def assertBalance(self, account, expected_balance: str): # pylint: disable=C0103
balance = await account.get_balance() balance = await account.get_balance()
self.assertEqual(satoshis_to_coins(balance), expected_balance) self.assertEqual(satoshis_to_coins(balance), expected_balance)