improved shutdown for torba.server and related test setup code
This commit is contained in:
parent
458189366f
commit
7092f40701
5 changed files with 88 additions and 58 deletions
|
@ -19,6 +19,9 @@ from torba.client.basemanager import BaseWalletManager
|
|||
from torba.client.baseaccount import BaseAccount
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_manager_from_environment(default_manager=BaseWalletManager):
|
||||
if 'TORBA_MANAGER' not in os.environ:
|
||||
return default_manager
|
||||
|
@ -73,40 +76,55 @@ class Conductor:
|
|||
self.spv_started = False
|
||||
self.wallet_started = False
|
||||
|
||||
self.log = log.getChild('conductor')
|
||||
|
||||
async def start_blockchain(self):
|
||||
if not self.blockchain_started:
|
||||
await self.blockchain_node.start()
|
||||
await self.blockchain_node.generate(200)
|
||||
self.blockchain_started = True
|
||||
|
||||
async def stop_blockchain(self):
|
||||
if self.blockchain_started:
|
||||
await self.blockchain_node.stop(cleanup=True)
|
||||
self.blockchain_started = False
|
||||
|
||||
async def start_spv(self):
|
||||
if not self.spv_started:
|
||||
await self.spv_node.start()
|
||||
self.spv_started = True
|
||||
|
||||
async def stop_spv(self):
|
||||
if self.spv_started:
|
||||
await self.spv_node.stop(cleanup=True)
|
||||
self.spv_started = False
|
||||
|
||||
async def start_wallet(self):
|
||||
if not self.wallet_started:
|
||||
await self.wallet_node.start()
|
||||
self.wallet_started = True
|
||||
|
||||
async def stop_wallet(self):
|
||||
if self.wallet_started:
|
||||
await self.wallet_node.stop(cleanup=True)
|
||||
self.wallet_started = False
|
||||
|
||||
async def start(self):
|
||||
self.blockchain_started or await self.start_blockchain()
|
||||
self.spv_started or await self.start_spv()
|
||||
self.wallet_started or await self.start_wallet()
|
||||
await self.start_blockchain()
|
||||
await self.start_spv()
|
||||
await self.start_wallet()
|
||||
|
||||
async def stop(self):
|
||||
if self.wallet_started:
|
||||
all_the_stops = [
|
||||
self.wallet_node.stop,
|
||||
self.spv_node.stop,
|
||||
self.blockchain_node.stop
|
||||
]
|
||||
for stop in all_the_stops:
|
||||
try:
|
||||
await self.wallet_node.stop(cleanup=True)
|
||||
await stop()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if self.spv_started:
|
||||
try:
|
||||
await self.spv_node.stop(cleanup=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if self.blockchain_started:
|
||||
try:
|
||||
await self.blockchain_node.stop(cleanup=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception('Exception raised while stopping services:', exc_info=e)
|
||||
|
||||
|
||||
class WalletNode:
|
||||
|
@ -182,7 +200,7 @@ class SPVNode:
|
|||
|
||||
async def stop(self, cleanup=True):
|
||||
try:
|
||||
self.server.db.shutdown()
|
||||
self.server.stop()
|
||||
finally:
|
||||
cleanup and self.cleanup()
|
||||
|
||||
|
@ -198,10 +216,10 @@ class BlockchainProcess(asyncio.SubprocessProtocol):
|
|||
b'keypool return',
|
||||
]
|
||||
|
||||
def __init__(self, log):
|
||||
def __init__(self):
|
||||
self.ready = asyncio.Event()
|
||||
self.stopped = asyncio.Event()
|
||||
self.log = log
|
||||
self.log = log.getChild('blockchain')
|
||||
|
||||
def pipe_data_received(self, fd, data):
|
||||
if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT):
|
||||
|
@ -227,7 +245,7 @@ class BlockchainNode:
|
|||
self.bin_dir = os.path.join(self.project_dir, 'bin')
|
||||
self.daemon_bin = os.path.join(self.bin_dir, daemon)
|
||||
self.cli_bin = os.path.join(self.bin_dir, cli)
|
||||
self.log = logging.getLogger('blockchain')
|
||||
self.log = log.getChild('blockchain')
|
||||
self.data_path = None
|
||||
self.protocol = None
|
||||
self.transport = None
|
||||
|
@ -289,7 +307,7 @@ class BlockchainNode:
|
|||
)
|
||||
self.log.info(' '.join(command))
|
||||
self.transport, self.protocol = await loop.subprocess_exec(
|
||||
lambda: BlockchainProcess(self.log), *command
|
||||
lambda: BlockchainProcess(), *command
|
||||
)
|
||||
await self.protocol.ready.wait()
|
||||
|
||||
|
|
|
@ -657,6 +657,7 @@ class BlockProcessor:
|
|||
# Shut down block processing
|
||||
self.logger.info('flushing to DB for a clean shutdown...')
|
||||
await self.flush(True)
|
||||
self.db.close()
|
||||
|
||||
def force_chain_reorg(self, count):
|
||||
'''Force a reorg of the given number of blocks.
|
||||
|
|
|
@ -134,7 +134,7 @@ class DB:
|
|||
# Read TX counts (requires meta directory)
|
||||
await self._read_tx_counts()
|
||||
|
||||
def shutdown(self):
|
||||
def close(self):
|
||||
self.utxo_db.close()
|
||||
self.history.close_db()
|
||||
|
||||
|
|
|
@ -108,9 +108,10 @@ class Server:
|
|||
await _start_cancellable(self.mempool.keep_synchronized)
|
||||
await _start_cancellable(self.session_mgr.serve, self.notifications)
|
||||
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
for task in reversed(self.cancellable_tasks):
|
||||
task.cancel()
|
||||
await asyncio.wait(self.cancellable_tasks)
|
||||
self.shutdown_event.set()
|
||||
|
||||
def run(self):
|
||||
|
|
|
@ -96,11 +96,10 @@ class AsyncioTestCase(unittest.TestCase):
|
|||
"__unittest_expecting_failure__", False)
|
||||
expecting_failure = expecting_failure_class or expecting_failure_method
|
||||
outcome = _Outcome(result)
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
self._outcome = outcome
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.set_debug(True)
|
||||
|
||||
|
@ -110,14 +109,16 @@ class AsyncioTestCase(unittest.TestCase):
|
|||
if outcome.success:
|
||||
outcome.expecting_failure = expecting_failure
|
||||
with outcome.testPartExecutor(self, isTest=True):
|
||||
possible_coroutine = testMethod()
|
||||
if asyncio.iscoroutine(possible_coroutine):
|
||||
loop.run_until_complete(possible_coroutine)
|
||||
maybe_coroutine = testMethod()
|
||||
if asyncio.iscoroutine(maybe_coroutine):
|
||||
loop.run_until_complete(maybe_coroutine)
|
||||
outcome.expecting_failure = False
|
||||
with outcome.testPartExecutor(self):
|
||||
loop.run_until_complete(self.asyncTearDown())
|
||||
self.tearDown()
|
||||
finally:
|
||||
|
||||
self.doAsyncCleanups(loop)
|
||||
|
||||
try:
|
||||
_cancel_all_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
|
@ -125,8 +126,6 @@ class AsyncioTestCase(unittest.TestCase):
|
|||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
self.doCleanups()
|
||||
|
||||
for test, reason in outcome.skipped:
|
||||
self._addSkip(result, test, reason)
|
||||
self._feedErrorsToResult(result, outcome.errors)
|
||||
|
@ -155,6 +154,15 @@ class AsyncioTestCase(unittest.TestCase):
|
|||
# clear the outcome, no more needed
|
||||
self._outcome = None
|
||||
|
||||
def doAsyncCleanups(self, loop):
|
||||
outcome = self._outcome or _Outcome()
|
||||
while self._cleanups:
|
||||
function, args, kwargs = self._cleanups.pop()
|
||||
with outcome.testPartExecutor(self):
|
||||
maybe_coroutine = function(*args, **kwargs)
|
||||
if asyncio.iscoroutine(maybe_coroutine):
|
||||
loop.run_until_complete(maybe_coroutine)
|
||||
|
||||
|
||||
class IntegrationTestCase(AsyncioTestCase):
|
||||
|
||||
|
@ -176,7 +184,12 @@ class IntegrationTestCase(AsyncioTestCase):
|
|||
self.conductor = Conductor(
|
||||
ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
|
||||
)
|
||||
await self.conductor.start()
|
||||
await self.conductor.start_blockchain()
|
||||
self.addCleanup(self.conductor.stop_blockchain)
|
||||
await self.conductor.start_spv()
|
||||
self.addCleanup(self.conductor.stop_spv)
|
||||
await self.conductor.start_wallet()
|
||||
self.addCleanup(self.conductor.stop_wallet)
|
||||
self.blockchain = self.conductor.blockchain_node
|
||||
self.wallet_node = self.conductor.wallet_node
|
||||
self.manager = self.wallet_node.manager
|
||||
|
@ -184,9 +197,6 @@ class IntegrationTestCase(AsyncioTestCase):
|
|||
self.wallet = self.wallet_node.wallet
|
||||
self.account = self.wallet_node.wallet.default_account
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.conductor.stop()
|
||||
|
||||
async def assertBalance(self, account, expected_balance: str): # pylint: disable=C0103
|
||||
balance = await account.get_balance()
|
||||
self.assertEqual(satoshis_to_coins(balance), expected_balance)
|
||||
|
|
Loading…
Reference in a new issue