# lbry-sdk/torba/testcase.py
# 2018-11-04 01:55:50 -04:00
#
# 196 lines
# 6.9 KiB
# Python
#
import sys
import logging
import unittest
from unittest.case import _Outcome
from torba.orchstr8 import Conductor
import asyncio
try:
    from asyncio.runners import _cancel_all_tasks  # type: ignore
except ImportError:
    # asyncio.runners._cancel_all_tasks only exists on Python 3.7+.  Provide an
    # equivalent for older interpreters so that closing the loop afterwards
    # does not leave tasks pending (the previous fallback was a silent no-op).
    def _cancel_all_tasks(loop):
        """Cancel every unfinished task on *loop* and run the loop until they settle.

        Exceptions (other than cancellation) raised by those tasks are passed
        to ``loop.call_exception_handler`` instead of being lost.
        """
        to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
        if not to_cancel:
            return
        for task in to_cancel:
            task.cancel()
        # return_exceptions=True so one failing task cannot abort the wait.
        loop.run_until_complete(
            asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
        )
        for task in to_cancel:
            if task.cancelled():
                continue
            if task.exception() is not None:
                loop.call_exception_handler({
                    'message': 'unhandled exception during event loop shutdown',
                    'exception': task.exception(),
                    'task': task,
                })
# Module-level side effect: attach a stdout handler to the root logger, so
# records propagated to it are printed as "time - logger - level - message".
HANDLER = logging.StreamHandler(sys.stdout)
HANDLER.setFormatter(
    logging.Formatter(fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)
logging.root.addHandler(HANDLER)
class AsyncioTestCase(unittest.TestCase):
    """``unittest.TestCase`` that runs every test on its own fresh asyncio loop.

    Subclasses may override the coroutine hooks ``asyncSetUp``,
    ``asyncTearDown`` and ``doAsyncCleanups``.  Test methods may be plain
    functions or coroutine functions; a returned coroutine is run to
    completion on the test's loop.
    """
    # Implementation inspired by discussion:
    # https://bugs.python.org/issue32972

    async def asyncSetUp(self):  # pylint: disable=C0103
        """Coroutine awaited on the test loop right after ``setUp()``."""

    async def asyncTearDown(self):  # pylint: disable=C0103
        """Coroutine awaited on the test loop right before ``tearDown()``."""

    async def doAsyncCleanups(self):  # pylint: disable=C0103
        """Coroutine awaited on the test loop after tear-down, before the loop closes."""

    @staticmethod
    def _finish_test_run(result, orig_result):
        """Call ``result.stopTestRun()`` when ``run()`` created *result* itself."""
        if orig_result is None:
            stopTestRun = getattr(result, 'stopTestRun', None)  # pylint: disable=C0103
            if stopTestRun is not None:
                stopTestRun()  # pylint: disable=E1102

    def run(self, result=None):  # pylint: disable=R0915
        """Run this test, reporting to *result*, and return the result used.

        This mirrors ``unittest.TestCase.run`` but creates a new event loop.
        ``setUp``/``asyncSetUp``, the test method,
        ``asyncTearDown``/``tearDown`` and ``doAsyncCleanups`` all run on that
        loop. Leftover tasks are then cancelled, async generators are shut
        down and the loop is closed before the synchronous ``doCleanups()``.

        NOTE(review): relies on private ``unittest`` internals (``_Outcome``,
        ``_addSkip``, ``_feedErrorsToResult``, ``testPartExecutor(isTest=...)``)
        in their Python 3.7-3.10 form; verify against the interpreter in use.
        """
        orig_result = result
        if result is None:
            result = self.defaultTestResult()
            startTestRun = getattr(result, 'startTestRun', None)  # pylint: disable=C0103
            if startTestRun is not None:
                startTestRun()

        result.startTest(self)

        testMethod = getattr(self, self._testMethodName)  # pylint: disable=C0103
        if (getattr(self.__class__, "__unittest_skip__", False) or
                getattr(testMethod, "__unittest_skip__", False)):
            # If the class or method was skipped.
            try:
                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
                            or getattr(testMethod, '__unittest_skip_why__', ''))
                self._addSkip(result, self, skip_why)
            finally:
                result.stopTest(self)
                # Previously skipped tests never finalised a result created here.
                self._finish_test_run(result, orig_result)
            return result

        expecting_failure_method = getattr(testMethod,
                                           "__unittest_expecting_failure__", False)
        expecting_failure_class = getattr(self,
                                          "__unittest_expecting_failure__", False)
        expecting_failure = expecting_failure_class or expecting_failure_method
        outcome = _Outcome(result)
        try:
            self._outcome = outcome

            loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(loop)
                loop.set_debug(True)

                with outcome.testPartExecutor(self):
                    self.setUp()
                    loop.run_until_complete(self.asyncSetUp())
                if outcome.success:
                    outcome.expecting_failure = expecting_failure
                    with outcome.testPartExecutor(self, isTest=True):
                        possible_coroutine = testMethod()
                        if asyncio.iscoroutine(possible_coroutine):
                            loop.run_until_complete(possible_coroutine)
                    outcome.expecting_failure = False
                    with outcome.testPartExecutor(self):
                        loop.run_until_complete(self.asyncTearDown())
                        self.tearDown()
                # The async cleanup hook was previously never invoked.  It must
                # run while the loop is still open and, like doCleanups(), even
                # when set-up failed.
                with outcome.testPartExecutor(self):
                    loop.run_until_complete(self.doAsyncCleanups())
            finally:
                try:
                    _cancel_all_tasks(loop)
                    loop.run_until_complete(loop.shutdown_asyncgens())
                finally:
                    asyncio.set_event_loop(None)
                    loop.close()

            self.doCleanups()

            for test, reason in outcome.skipped:
                self._addSkip(result, test, reason)
            self._feedErrorsToResult(result, outcome.errors)
            if outcome.success:
                if expecting_failure:
                    if outcome.expectedFailure:
                        self._addExpectedFailure(result, outcome.expectedFailure)
                    else:
                        self._addUnexpectedSuccess(result)
                else:
                    result.addSuccess(self)
            return result
        finally:
            result.stopTest(self)
            self._finish_test_run(result, orig_result)

            # explicitly break reference cycles:
            # outcome.errors -> frame -> outcome -> outcome.errors
            # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
            outcome.errors.clear()
            outcome.expectedFailure = None

            # clear the outcome, no more needed
            self._outcome = None
class IntegrationTestCase(AsyncioTestCase):
    """AsyncioTestCase that starts a fresh ``Conductor`` for every test.

    Subclasses configure the conductor through the ``LEDGER``, ``MANAGER`` and
    ``VERBOSITY`` class attributes. ``asyncSetUp`` starts it and exposes its
    blockchain node, wallet node, manager, ledger, wallet and default account
    as attributes; ``asyncTearDown`` stops it.
    """

    LEDGER = None  # passed to Conductor as ledger_module
    MANAGER = None  # passed to Conductor as manager_module
    VERBOSITY = logging.WARN  # passed to Conductor as verbosity

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # All of these are populated by asyncSetUp().
        self.conductor = None
        self.blockchain = None
        self.wallet_node = None
        self.manager = None
        self.ledger = None
        self.wallet = None
        self.account = None

    async def asyncSetUp(self):
        """Start the conductor and cache handles to its nodes and wallet objects."""
        self.conductor = Conductor(
            ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
        )
        await self.conductor.start()
        self.blockchain = self.conductor.blockchain_node
        self.wallet_node = self.conductor.wallet_node
        self.manager = self.wallet_node.manager
        self.ledger = self.wallet_node.ledger
        self.wallet = self.wallet_node.wallet
        self.account = self.wallet_node.wallet.default_account

    async def asyncTearDown(self):
        """Stop the conductor started by asyncSetUp()."""
        await self.conductor.stop()

    def broadcast(self, tx):
        """Broadcast *tx* via the ledger and return ``ledger.broadcast``'s result."""
        return self.ledger.broadcast(tx)

    def get_balance(self, account=None, confirmations=0):
        """Return *account*'s balance, or ``manager.get_balance`` when no account is given."""
        if account is None:
            return self.manager.get_balance(confirmations=confirmations)
        return account.get_balance(confirmations=confirmations)

    async def on_header(self, height):
        """Wait until the ledger has the header at *height*; return True.

        Returns immediately if ``ledger.headers.height`` is already >= *height*.
        Otherwise waits for an ``on_header`` event whose height equals *height*.
        """
        if self.ledger.headers.height < height:
            await self.ledger.on_header.where(
                lambda e: e.height == height
            )
        return True

    def on_transaction_id(self, txid):
        """Return ``ledger.on_transaction.where`` filtered to events for *txid*.

        NOTE(review): presumably an awaitable that resolves on the first
        matching event; confirm against the ledger's stream implementation.
        """
        return self.ledger.on_transaction.where(
            lambda e: e.tx.id == txid
        )

    def on_transaction_address(self, tx, address):
        """Like on_transaction_id(), but also requiring the event's address to be *address*."""
        return self.ledger.on_transaction.where(
            lambda e: e.tx.id == tx.id and e.address == address
        )

    async def on_transaction(self, tx):
        """Wait for an ``on_transaction`` event for each known address paid by *tx*.

        Returns immediately when none of *tx*'s outputs pay an address that
        the ledger's database knows about.
        """
        addresses = await self.get_tx_addresses(tx, self.ledger)
        if not addresses:
            # asyncio.wait() raises ValueError when given an empty collection.
            return
        await asyncio.wait([
            # Bind ``address`` as a default argument: the predicate runs later,
            # after the comprehension has finished, so a plain closure would
            # compare every event against the *last* address only.
            self.ledger.on_transaction.where(
                lambda e, address=address: e.address == address
            )
            for address in addresses
        ])

    async def get_tx_addresses(self, tx, ledger):
        """Return the de-duplicated addresses paid by *tx* that exist in *ledger*'s db.

        NOTE(review): assumes every output script exposes a 'pubkey_hash'
        value; an output without one would raise here.
        """
        addresses = set()
        for txo in tx.outputs:
            address = ledger.hash160_to_address(txo.script.values['pubkey_hash'])
            record = await ledger.db.get_address(address=address)
            if record is not None:
                addresses.add(address)
        return list(addresses)