lbry-sdk/torba/testcase.py

238 lines
8.4 KiB
Python
Raw Normal View History

2018-11-04 06:55:50 +01:00
import sys
import logging
2019-01-07 09:02:15 +01:00
import functools
2018-12-05 07:13:31 +01:00
import asyncio
from asyncio.runners import _cancel_all_tasks # type: ignore
2018-12-05 13:35:06 +01:00
import unittest
from unittest.case import _Outcome
from typing import Optional
2018-11-04 06:55:50 +01:00
from torba.orchstr8 import Conductor
from torba.orchstr8.node import BlockchainNode, WalletNode
from torba.client.baseledger import BaseLedger
from torba.client.baseaccount import BaseAccount
from torba.client.basemanager import BaseWalletManager
from torba.client.wallet import Wallet
2018-11-19 04:54:00 +01:00
from torba.client.util import satoshis_to_coins
2018-12-05 07:13:31 +01:00
class ColorHandler(logging.StreamHandler):
    """Stream handler that wraps each formatted record in an ANSI color.

    The record's level number is looked up in ``level_color`` to get a color
    name, and that name is looked up in ``color_code`` to get the value placed
    between ``ESC[`` and ``m``. Levels without an entry in ``level_color``
    are rendered with the ``"black"`` fallback.
    """

    # Log level number -> color name (must be a key of ``color_code``).
    level_color = {
        logging.DEBUG: "black",
        logging.INFO: "light_gray",
        logging.WARNING: "yellow",
        logging.ERROR: "red",
        # Without this entry CRITICAL records would take the "black" fallback
        # in emit() and be less visible than plain errors.
        logging.CRITICAL: "red",
    }

    # Color name -> parameter inserted into the ``ESC[<param>m`` sequence.
    color_code = dict(
        black=30,
        red=31,
        green=32,
        yellow=33,
        blue=34,
        magenta=35,
        cyan=36,
        white=37,
        light_gray='0;37',
        dark_gray='1;30',
    )

    def emit(self, record):
        """Format *record* and write it, colored by level, to the stream.

        Any ``Exception`` raised while formatting or writing is passed to
        ``self.handleError(record)`` instead of propagating to the caller.
        """
        try:
            msg = self.format(record)
            color_name = self.level_color.get(record.levelno, "black")
            sgr = self.color_code[color_name]
            # Message and terminator go out in a single write so another
            # writer on the same stream cannot slip in between them.
            self.stream.write(
                '\x1b[%sm%s\x1b[0m%s' % (sgr, msg, self.terminator)
            )
            self.flush()
        except Exception:
            self.handleError(record)
# Module-level colored handler writing to stdout. It is attached to the root
# logger at import time, so importing this module configures logging output.
HANDLER = ColorHandler(sys.stdout)
HANDLER.setFormatter(logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
))
logging.getLogger().addHandler(HANDLER)
class AsyncioTestCase(unittest.TestCase):
    """``unittest.TestCase`` whose fixtures, tests and cleanups may be coroutines.

    Each call to ``run()`` creates a new event loop, stores it on
    ``self.loop``, installs it as the current loop with debug mode enabled,
    and closes it again before the outcome is reported.
    """
    # Implementation inspired by discussion:
    #  https://bugs.python.org/issue32972

    async def asyncSetUp(self):  # pylint: disable=C0103
        """Coroutine hook run on ``self.loop`` right after ``setUp()``."""

    async def asyncTearDown(self):  # pylint: disable=C0103
        """Coroutine hook run on ``self.loop`` right before ``tearDown()``."""

    def run(self, result=None):  # pylint: disable=R0915
        """Run the test, driving coroutine fixtures and tests on ``self.loop``.

        Statement order is significant here; it follows the shape of
        ``unittest.TestCase.run`` with the asyncio steps woven in.

        :param result: result object to report into; when ``None`` a default
            result is created and its ``startTestRun``/``stopTestRun`` hooks
            (if present) are invoked around the test.
        :return: the result object, or ``None`` when the test was skipped via
            the ``__unittest_skip__`` marker.
        """
        created_result = result is None
        if created_result:
            result = self.defaultTestResult()
            start_test_run = getattr(result, 'startTestRun', None)
            if start_test_run is not None:
                start_test_run()

        result.startTest(self)

        test_method = getattr(self, self._testMethodName)
        if (getattr(self.__class__, "__unittest_skip__", False)
                or getattr(test_method, "__unittest_skip__", False)):
            # The class or the method carries a skip marker.
            try:
                skip_why = (
                    getattr(self.__class__, '__unittest_skip_why__', '')
                    or getattr(test_method, '__unittest_skip_why__', '')
                )
                self._addSkip(result, self, skip_why)
            finally:
                result.stopTest(self)
            return

        expecting_failure = (
            getattr(self, "__unittest_expecting_failure__", False)
            or getattr(test_method, "__unittest_expecting_failure__", False)
        )
        outcome = _Outcome(result)

        # Fresh loop per test, installed as the current loop.
        self.loop = asyncio.new_event_loop()  # pylint: disable=W0201
        asyncio.set_event_loop(self.loop)
        self.loop.set_debug(True)

        try:
            self._outcome = outcome

            # Sync setUp first, then the coroutine hook, as one test part.
            with outcome.testPartExecutor(self):
                self.setUp()
                self.loop.run_until_complete(self.asyncSetUp())

            if outcome.success:
                outcome.expecting_failure = expecting_failure
                with outcome.testPartExecutor(self, isTest=True):
                    # The test method may be a plain function or a coroutine
                    # function; only a coroutine needs the loop.
                    test_return = test_method()
                    if asyncio.iscoroutine(test_return):
                        self.loop.run_until_complete(test_return)
                outcome.expecting_failure = False
                # Teardown mirrors setup: coroutine hook first, then sync.
                with outcome.testPartExecutor(self):
                    self.loop.run_until_complete(self.asyncTearDown())
                    self.tearDown()

            self.doAsyncCleanups()

            # Cancel leftover tasks and finish async generators, then always
            # uninstall and close the loop.
            try:
                _cancel_all_tasks(self.loop)
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            finally:
                asyncio.set_event_loop(None)
                self.loop.close()

            for skipped_test, reason in outcome.skipped:
                self._addSkip(result, skipped_test, reason)
            self._feedErrorsToResult(result, outcome.errors)
            if outcome.success:
                if not expecting_failure:
                    result.addSuccess(self)
                elif outcome.expectedFailure:
                    self._addExpectedFailure(result, outcome.expectedFailure)
                else:
                    self._addUnexpectedSuccess(result)
            return result
        finally:
            result.stopTest(self)
            if created_result:
                stop_test_run = getattr(result, 'stopTestRun', None)
                if stop_test_run is not None:
                    stop_test_run()  # pylint: disable=E1102

            # explicitly break reference cycles:
            # outcome.errors -> frame -> outcome -> outcome.errors
            # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
            outcome.errors.clear()
            outcome.expectedFailure = None

            # clear the outcome, no more needed
            self._outcome = None

    def doAsyncCleanups(self):  # pylint: disable=C0103
        """Run the registered cleanups, last registered first.

        Each cleanup is called inside ``outcome.testPartExecutor``; when the
        call returns a coroutine it is run to completion on ``self.loop``.
        A new ``_Outcome`` is used when no outcome is currently set.
        """
        outcome = self._outcome or _Outcome()
        while self._cleanups:
            callback, call_args, call_kwargs = self._cleanups.pop()
            with outcome.testPartExecutor(self):
                pending = callback(*call_args, **call_kwargs)
                if asyncio.iscoroutine(pending):
                    self.loop.run_until_complete(pending)
class AdvanceTimeTestCase(AsyncioTestCase):
    """Async test case whose loop clock is a counter moved only by ``advance``."""

    async def asyncSetUp(self):
        """Start the fake clock at 0 and make ``self.loop.time`` return it."""
        self._time = 0  # pylint: disable=W0201

        @functools.wraps(self.loop.time)
        def fake_time():
            return self._time

        self.loop.time = fake_time
        await super().asyncSetUp()

    async def advance(self, seconds):
        """Move the fake clock forward by *seconds*.

        Yields to the loop until ``self.loop._ready`` is empty, bumps the
        clock, yields once more, then yields again until ``_ready`` is empty.
        """
        # Let callbacks that are already queued run at the old time.
        while self.loop._ready:
            await asyncio.sleep(0)

        self._time += seconds

        # One unconditional yield after the time change (presumably so the
        # loop can notice timers that are now due), then drain again.
        await asyncio.sleep(0)
        while self.loop._ready:
            await asyncio.sleep(0)
class IntegrationTestCase(AsyncioTestCase):
    """Async test case that starts a ``Conductor`` stack before each test.

    Subclasses set ``LEDGER`` and ``MANAGER``; both are handed to the
    ``Conductor`` as ``ledger_module`` and ``manager_module``.
    """

    LEDGER = None
    MANAGER = None
    VERBOSITY = logging.WARN  # passed to Conductor as ``verbosity``

    def __init__(self, *args, **kwargs):
        """Declare the per-test handles; they are filled in by ``asyncSetUp``."""
        super().__init__(*args, **kwargs)
        self.conductor: Optional[Conductor] = None
        self.blockchain: Optional[BlockchainNode] = None
        self.wallet_node: Optional[WalletNode] = None
        self.manager: Optional[BaseWalletManager] = None
        self.ledger: Optional[BaseLedger] = None
        self.wallet: Optional[Wallet] = None
        self.account: Optional[BaseAccount] = None

    async def asyncSetUp(self):
        """Start blockchain, SPV and wallet, then cache the handles on ``self``.

        A stop cleanup is registered right after each successful start, so
        only the services that actually came up get a stop cleanup.
        """
        conductor = Conductor(
            ledger_module=self.LEDGER,
            manager_module=self.MANAGER,
            verbosity=self.VERBOSITY,
        )
        self.conductor = conductor

        await conductor.start_blockchain()
        self.addCleanup(conductor.stop_blockchain)
        await conductor.start_spv()
        self.addCleanup(conductor.stop_spv)
        await conductor.start_wallet()
        self.addCleanup(conductor.stop_wallet)

        self.blockchain = conductor.blockchain_node
        wallet_node = conductor.wallet_node
        self.wallet_node = wallet_node
        self.manager = wallet_node.manager
        self.ledger = wallet_node.ledger
        self.wallet = wallet_node.wallet
        self.account = wallet_node.wallet.default_account

    async def assertBalance(self, account, expected_balance: str):  # pylint: disable=C0103
        """Assert *account*'s balance, converted by ``satoshis_to_coins``, equals *expected_balance*."""
        satoshis = await account.get_balance()
        self.assertEqual(satoshis_to_coins(satoshis), expected_balance)

    def broadcast(self, tx):
        """Return the result of ``self.ledger.broadcast(tx)``."""
        return self.ledger.broadcast(tx)

    async def on_header(self, height):
        """Wait for a header event at *height* unless the ledger is already there.

        Awaits only when the ledger's current header height is below
        *height*; always returns ``True``.
        """
        if height > self.ledger.headers.height:
            await self.ledger.on_header.where(
                lambda event: event.height == height
            )
        return True

    def on_transaction_id(self, txid, ledger=None):
        """Return ``on_transaction.where(...)`` matching events whose tx id is *txid*.

        Uses *ledger* when it is truthy, otherwise ``self.ledger``.
        """
        source = ledger or self.ledger
        return source.on_transaction.where(
            lambda event: event.tx.id == txid
        )

    def on_transaction_address(self, tx, address):
        """Return ``on_transaction.where(...)`` matching *tx*'s id and *address* on ``self.ledger``."""
        return self.ledger.on_transaction.where(
            lambda event: event.tx.id == tx.id and event.address == address
        )