This commit is contained in:
Lex Berezhny 2020-09-16 10:37:49 -04:00
parent 999e4209fa
commit fa85558d71
8 changed files with 240 additions and 61 deletions

View file

@ -77,12 +77,21 @@ class Lbrycrd:
self.subscribed = False self.subscribed = False
self.subscription: Optional[asyncio.Task] = None self.subscription: Optional[asyncio.Task] = None
self.default_generate_address = None self.default_generate_address = None
self._on_block_controller = EventController() self._on_block_hash_controller = EventController()
self.on_block = self._on_block_controller.stream self.on_block_hash = self._on_block_hash_controller.stream
self.on_block.listen(lambda e: log.info('%s %s', hexlify(e['hash']), e['msg'])) self.on_block_hash.listen(lambda e: log.info('%s %s', hexlify(e['hash']), e['msg']))
self._on_tx_hash_controller = EventController()
self.on_tx_hash = self._on_tx_hash_controller.stream
self.db = BlockchainDB(self.actual_data_dir) self.db = BlockchainDB(self.actual_data_dir)
self.session: Optional[aiohttp.ClientSession] = None self._session: Optional[aiohttp.ClientSession] = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
@property
def session(self) -> aiohttp.ClientSession:
if self._session is None:
self._session = aiohttp.ClientSession()
return self._session
@classmethod @classmethod
def temp_regtest(cls): def temp_regtest(cls):
@ -91,7 +100,7 @@ class Lbrycrd:
blockchain="regtest", blockchain="regtest",
lbrycrd_rpc_port=9245 + 2, # avoid conflict with default rpc port lbrycrd_rpc_port=9245 + 2, # avoid conflict with default rpc port
lbrycrd_peer_port=9246 + 2, # avoid conflict with default peer port lbrycrd_peer_port=9246 + 2, # avoid conflict with default peer port
lbrycrd_zmq_blocks="tcp://127.0.0.1:29002" # avoid conflict with default port lbrycrd_zmq="tcp://127.0.0.1:29002"
) )
)) ))
@ -161,8 +170,11 @@ class Lbrycrd:
def get_start_command(self, *args): def get_start_command(self, *args):
if self.is_regtest: if self.is_regtest:
args += ('-regtest',) args += ('-regtest',)
if self.conf.lbrycrd_zmq_blocks: if self.conf.lbrycrd_zmq:
args += (f'-zmqpubhashblock={self.conf.lbrycrd_zmq_blocks}',) args += (
f'-zmqpubhashblock={self.conf.lbrycrd_zmq}',
f'-zmqpubhashtx={self.conf.lbrycrd_zmq}',
)
return ( return (
self.daemon_bin, self.daemon_bin,
f'-datadir={self.data_dir}', f'-datadir={self.data_dir}',
@ -175,13 +187,15 @@ class Lbrycrd:
) )
async def open(self): async def open(self):
self.session = aiohttp.ClientSession()
await self.db.open() await self.db.open()
async def close(self): async def close(self):
await self.db.close() await self.db.close()
if self.session is not None: await self.close_session()
await self.session.close()
async def close_session(self):
if self._session is not None:
await self._session.close()
async def start(self, *args): async def start(self, *args):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
@ -213,23 +227,30 @@ class Lbrycrd:
subs = {e['type']: e['address'] for e in zmq_notifications} subs = {e['type']: e['address'] for e in zmq_notifications}
if ZMQ_BLOCK_EVENT not in subs: if ZMQ_BLOCK_EVENT not in subs:
raise LbrycrdEventSubscriptionError(ZMQ_BLOCK_EVENT) raise LbrycrdEventSubscriptionError(ZMQ_BLOCK_EVENT)
if not self.conf.lbrycrd_zmq_blocks: if not self.conf.lbrycrd_zmq:
self.conf.lbrycrd_zmq_blocks = subs[ZMQ_BLOCK_EVENT] self.conf.lbrycrd_zmq = subs[ZMQ_BLOCK_EVENT]
async def subscribe(self): async def subscribe(self):
if not self.subscribed: if not self.subscribed:
self.subscribed = True self.subscribed = True
ctx = zmq.asyncio.Context.instance() ctx = zmq.asyncio.Context.instance()
sock = ctx.socket(zmq.SUB) # pylint: disable=no-member sock = ctx.socket(zmq.SUB) # pylint: disable=no-member
sock.connect(self.conf.lbrycrd_zmq_blocks) sock.connect(self.conf.lbrycrd_zmq)
sock.subscribe("hashblock") sock.subscribe("hashblock")
sock.subscribe("hashtx")
self.subscription = asyncio.create_task(self.subscription_handler(sock)) self.subscription = asyncio.create_task(self.subscription_handler(sock))
async def subscription_handler(self, sock): async def subscription_handler(self, sock):
try: try:
while self.subscribed: while self.subscribed:
msg = await sock.recv_multipart() msg = await sock.recv_multipart()
await self._on_block_controller.add({ if msg[0] == b'hashtx':
await self._on_tx_hash_controller.add({
'hash': msg[1],
'msg': struct.unpack('<I', msg[2])[0]
})
elif msg[0] == b'hashblock':
await self._on_block_hash_controller.add({
'hash': msg[1], 'hash': msg[1],
'msg': struct.unpack('<I', msg[2])[0] 'msg': struct.unpack('<I', msg[2])[0]
}) })
@ -243,8 +264,16 @@ class Lbrycrd:
self.subscription.cancel() self.subscription.cancel()
self.subscription = None self.subscription = None
def sync_run(self, coro):
if self._loop is None:
try:
self._loop = asyncio.get_event_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
return self._loop.run_until_complete(coro)
async def rpc(self, method, params=None): async def rpc(self, method, params=None):
if self.session.closed: if self._session is not None and self._session.closed:
raise Exception("session is closed! RPC attempted during shutting down.") raise Exception("session is closed! RPC attempted during shutting down.")
message = { message = {
"jsonrpc": "1.0", "jsonrpc": "1.0",
@ -285,6 +314,9 @@ class Lbrycrd:
async def get_block(self, block_hash): async def get_block(self, block_hash):
return await self.rpc("getblock", [block_hash]) return await self.rpc("getblock", [block_hash])
async def get_raw_mempool(self):
return await self.rpc("getrawmempool")
async def get_raw_transaction(self, txid): async def get_raw_transaction(self, txid):
return await self.rpc("getrawtransaction", [txid]) return await self.rpc("getrawtransaction", [txid])

View file

@ -1,4 +1,6 @@
import asyncio
import logging import logging
from binascii import hexlify, unhexlify
from typing import Tuple from typing import Tuple
from sqlalchemy import table, text, func, union, between from sqlalchemy import table, text, func, union, between
@ -19,6 +21,7 @@ from lbry.db.tables import (
) )
from lbry.db.query_context import ProgressContext, event_emitter, context from lbry.db.query_context import ProgressContext, event_emitter, context
from lbry.db.sync import set_input_addresses, update_spent_outputs from lbry.db.sync import set_input_addresses, update_spent_outputs
from lbry.blockchain.transaction import Transaction
from lbry.blockchain.block import Block, create_address_filter from lbry.blockchain.block import Block, create_address_filter
from lbry.blockchain.bcd_data_stream import BCDataStream from lbry.blockchain.bcd_data_stream import BCDataStream
@ -177,6 +180,30 @@ def sync_spends(initial_sync: bool, p: ProgressContext):
p.step() p.step()
@event_emitter("blockchain.sync.mempool.clear", "txs")
def clear_mempool(p: ProgressContext):
delete_all_the_things(-1, p)
@event_emitter("blockchain.sync.mempool.main", "txs")
def sync_mempool(p: ProgressContext):
chain = get_or_initialize_lbrycrd(p.ctx)
mempool = chain.sync_run(chain.get_raw_mempool())
current = [hexlify(r['tx_hash'][::-1]) for r in p.ctx.fetchall(
select(TX.c.tx_hash).where(TX.c.height < 0)
)]
loader = p.ctx.get_bulk_loader()
for txid in mempool:
if txid not in current:
raw_tx = chain.sync_run(chain.get_raw_transaction(txid))
loader.add_transaction(
None, Transaction(unhexlify(raw_tx), height=-1)
)
if p.ctx.stop_event.is_set():
return
loader.flush(TX)
@event_emitter("blockchain.sync.filters.generate", "blocks", throttle=100) @event_emitter("blockchain.sync.filters.generate", "blocks", throttle=100)
def sync_filters(start, end, p: ProgressContext): def sync_filters(start, end, p: ProgressContext):
fp = FilterBuilder(start, end) fp = FilterBuilder(start, end)
@ -273,23 +300,35 @@ def get_block_tx_addresses_sql(start_height, end_height):
@event_emitter("blockchain.sync.rewind.main", "steps") @event_emitter("blockchain.sync.rewind.main", "steps")
def rewind(height: int, p: ProgressContext): def rewind(height: int, p: ProgressContext):
delete_all_the_things(height, p)
def delete_all_the_things(height: int, p: ProgressContext):
def constrain(col):
if height >= 0:
return col >= height
return col <= height
deletes = [ deletes = [
BlockTable.delete().where(BlockTable.c.height >= height), BlockTable.delete().where(constrain(BlockTable.c.height)),
TXI.delete().where(TXI.c.height >= height), TXI.delete().where(constrain(TXI.c.height)),
TXO.delete().where(TXO.c.height >= height), TXO.delete().where(constrain(TXO.c.height)),
TX.delete().where(TX.c.height >= height), TX.delete().where(constrain(TX.c.height)),
Tag.delete().where( Tag.delete().where(
Tag.c.claim_hash.in_( Tag.c.claim_hash.in_(
select(Claim.c.claim_hash).where(Claim.c.height >= height) select(Claim.c.claim_hash).where(constrain(Claim.c.height))
) )
), ),
Claim.delete().where(Claim.c.height >= height), Claim.delete().where(constrain(Claim.c.height)),
Support.delete().where(Support.c.height >= height), Support.delete().where(constrain(Support.c.height)),
MempoolFilter.delete(),
]
if height > 0:
deletes.extend([
BlockFilter.delete().where(BlockFilter.c.height >= height), BlockFilter.delete().where(BlockFilter.c.height >= height),
# TODO: group and tx filters need where() clauses (below actually breaks things) # TODO: group and tx filters need where() clauses (below actually breaks things)
BlockGroupFilter.delete(), BlockGroupFilter.delete(),
TXFilter.delete(), TXFilter.delete(),
MempoolFilter.delete() ])
]
for delete in p.iter(deletes): for delete in p.iter(deletes):
p.ctx.execute(delete) p.ctx.execute(delete)

View file

@ -15,3 +15,11 @@ def get_or_initialize_lbrycrd(ctx=None) -> Lbrycrd:
chain.db.sync_open() chain.db.sync_open()
_chain.set(chain) _chain.set(chain)
return chain return chain
def uninitialize():
chain = _chain.get(None)
if chain is not None:
chain.db.sync_close()
chain.sync_run(chain.close_session())
_chain.set(None)

View file

@ -2,17 +2,19 @@ import os
import asyncio import asyncio
import logging import logging
from typing import Optional, Tuple, Set, List, Coroutine from typing import Optional, Tuple, Set, List, Coroutine
from concurrent.futures import ThreadPoolExecutor
from lbry.db import Database from lbry.db import Database
from lbry.db import queries as q from lbry.db import queries as q
from lbry.db.constants import TXO_TYPES, CLAIM_TYPE_CODES from lbry.db.constants import TXO_TYPES, CLAIM_TYPE_CODES
from lbry.db.query_context import Event, Progress from lbry.db.query_context import Event, Progress
from lbry.event import BroadcastSubscription from lbry.event import BroadcastSubscription, EventController
from lbry.service.base import Sync, BlockEvent from lbry.service.base import Sync, BlockEvent
from lbry.blockchain.lbrycrd import Lbrycrd from lbry.blockchain.lbrycrd import Lbrycrd
from lbry.error import LbrycrdEventSubscriptionError from lbry.error import LbrycrdEventSubscriptionError
from . import blocks as block_phase, claims as claim_phase, supports as support_phase from . import blocks as block_phase, claims as claim_phase, supports as support_phase
from .context import uninitialize
from .filter_builder import split_range_into_10k_batches from .filter_builder import split_range_into_10k_batches
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -40,9 +42,13 @@ class BlockchainSync(Sync):
super().__init__(chain.ledger, db) super().__init__(chain.ledger, db)
self.chain = chain self.chain = chain
self.pid = os.getpid() self.pid = os.getpid()
self.on_block_subscription: Optional[BroadcastSubscription] = None self.on_block_hash_subscription: Optional[BroadcastSubscription] = None
self.on_tx_hash_subscription: Optional[BroadcastSubscription] = None
self.advance_loop_task: Optional[asyncio.Task] = None self.advance_loop_task: Optional[asyncio.Task] = None
self.advance_loop_event = asyncio.Event() self.block_hash_event = asyncio.Event()
self.tx_hash_event = asyncio.Event()
self._on_mempool_controller = EventController()
self.on_mempool = self._on_mempool_controller.stream
async def wait_for_chain_ready(self): async def wait_for_chain_ready(self):
while True: while True:
@ -67,17 +73,25 @@ class BlockchainSync(Sync):
await self.advance_loop_task await self.advance_loop_task
await self.chain.subscribe() await self.chain.subscribe()
self.advance_loop_task = asyncio.create_task(self.advance_loop()) self.advance_loop_task = asyncio.create_task(self.advance_loop())
self.on_block_subscription = self.chain.on_block.listen( self.on_block_hash_subscription = self.chain.on_block_hash.listen(
lambda e: self.advance_loop_event.set() lambda e: self.block_hash_event.set()
)
self.on_tx_hash_subscription = self.chain.on_tx_hash.listen(
lambda e: self.tx_hash_event.set()
) )
async def stop(self): async def stop(self):
self.chain.unsubscribe() self.chain.unsubscribe()
if self.on_block_subscription is not None:
self.on_block_subscription.cancel()
self.db.stop_event.set() self.db.stop_event.set()
if self.advance_loop_task is not None: for subscription in (
self.advance_loop_task.cancel() self.on_block_hash_subscription,
self.on_tx_hash_subscription,
self.advance_loop_task
):
if subscription is not None:
subscription.cancel()
if isinstance(self.db.executor, ThreadPoolExecutor):
await self.db.run(uninitialize)
async def run_tasks(self, tasks: List[Coroutine]) -> Optional[Set[asyncio.Future]]: async def run_tasks(self, tasks: List[Coroutine]) -> Optional[Set[asyncio.Future]]:
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
@ -337,12 +351,25 @@ class BlockchainSync(Sync):
if blocks_added: if blocks_added:
await self._on_block_controller.add(BlockEvent(blocks_added[-1])) await self._on_block_controller.add(BlockEvent(blocks_added[-1]))
async def sync_mempool(self):
await self.db.run(block_phase.sync_mempool)
await self.sync_spends([-1])
await self.db.run(claim_phase.claims_insert, [-2, 0], True, self.CLAIM_FLUSH_SIZE)
await self.db.run(claim_phase.claims_vacuum)
async def advance_loop(self): async def advance_loop(self):
while True: while True:
await self.advance_loop_event.wait()
self.advance_loop_event.clear()
try: try:
await asyncio.wait([
self.tx_hash_event.wait(),
self.block_hash_event.wait(),
], return_when=asyncio.FIRST_COMPLETED)
if self.block_hash_event.is_set():
self.block_hash_event.clear()
await self.db.run(block_phase.clear_mempool)
await self.advance() await self.advance()
self.tx_hash_event.clear()
await self.sync_mempool()
except asyncio.CancelledError: except asyncio.CancelledError:
return return
except Exception as e: except Exception as e:

View file

@ -594,7 +594,7 @@ class Config(CLIConfig):
reflector_servers = Servers("Reflector re-hosting servers", [ reflector_servers = Servers("Reflector re-hosting servers", [
('reflector.lbry.com', 5566) ('reflector.lbry.com', 5566)
]) ])
lbryum_servers = Servers("SPV wallet servers", [ known_full_nodes = Servers("Full blockchain nodes", [
('spv11.lbry.com', 50001), ('spv11.lbry.com', 50001),
('spv12.lbry.com', 50001), ('spv12.lbry.com', 50001),
('spv13.lbry.com', 50001), ('spv13.lbry.com', 50001),
@ -621,7 +621,7 @@ class Config(CLIConfig):
lbrycrd_rpc_host = String("Hostname for connecting to lbrycrd.", "localhost") lbrycrd_rpc_host = String("Hostname for connecting to lbrycrd.", "localhost")
lbrycrd_rpc_port = Integer("Port for connecting to lbrycrd.", 9245) lbrycrd_rpc_port = Integer("Port for connecting to lbrycrd.", 9245)
lbrycrd_peer_port = Integer("Peer port for lbrycrd.", 9246) lbrycrd_peer_port = Integer("Peer port for lbrycrd.", 9246)
lbrycrd_zmq_blocks = String("ZMQ block events address.") lbrycrd_zmq = String("ZMQ events address.")
lbrycrd_dir = Path("Directory containing lbrycrd data.", metavar='DIR') lbrycrd_dir = Path("Directory containing lbrycrd data.", metavar='DIR')
spv_address_filters = Toggle( spv_address_filters = Toggle(
"Generate Golomb-Rice coding filters for blocks and transactions. Enables " "Generate Golomb-Rice coding filters for blocks and transactions. Enables "

View file

@ -100,22 +100,31 @@ class Database:
return 1 return 1
@classmethod @classmethod
def temp_from_url_regtest(cls, db_url, lbrycrd_dir=None): def temp_from_url_regtest(cls, db_url, lbrycrd_config=None):
from lbry import Config, RegTestLedger # pylint: disable=import-outside-toplevel from lbry import Config, RegTestLedger # pylint: disable=import-outside-toplevel
directory = tempfile.mkdtemp() directory = tempfile.mkdtemp()
conf = Config.with_same_dir(directory).set(db_url=db_url) if lbrycrd_config:
if lbrycrd_dir is not None: conf = lbrycrd_config
conf.lbrycrd_dir = lbrycrd_dir conf.data_dir = directory
conf.download_dir = directory
conf.wallet_dir = directory
else:
conf = Config.with_same_dir(directory)
conf.set(blockchain="regtest", db_url=db_url)
ledger = RegTestLedger(conf) ledger = RegTestLedger(conf)
return cls(ledger) return cls(ledger)
@classmethod @classmethod
def temp_sqlite_regtest(cls, lbrycrd_dir=None): def temp_sqlite_regtest(cls, lbrycrd_config=None):
from lbry import Config, RegTestLedger # pylint: disable=import-outside-toplevel from lbry import Config, RegTestLedger # pylint: disable=import-outside-toplevel
directory = tempfile.mkdtemp() directory = tempfile.mkdtemp()
if lbrycrd_config:
conf = lbrycrd_config
conf.data_dir = directory
conf.download_dir = directory
conf.wallet_dir = directory
else:
conf = Config.with_same_dir(directory).set(blockchain="regtest") conf = Config.with_same_dir(directory).set(blockchain="regtest")
if lbrycrd_dir is not None:
conf.lbrycrd_dir = lbrycrd_dir
ledger = RegTestLedger(conf) ledger = RegTestLedger(conf)
return cls(ledger) return cls(ledger)

View file

@ -19,7 +19,9 @@ class LightClient(Service):
def __init__(self, ledger: Ledger): def __init__(self, ledger: Ledger):
super().__init__(ledger) super().__init__(ledger)
self.client = Client(Config().api_connection_url) self.client = Client(
f"http://{ledger.conf.full_nodes[0][0]}:{ledger.conf.full_nodes[0][1]}/api"
)
self.sync = SPVSync(self) self.sync = SPVSync(self)
async def search_transactions(self, txids): async def search_transactions(self, txids):

View file

@ -16,6 +16,7 @@ from typing import Optional, List, Union
from binascii import unhexlify, hexlify from binascii import unhexlify, hexlify
import ecdsa import ecdsa
from distutils.dir_util import remove_tree
from lbry.db import Database from lbry.db import Database
from lbry.blockchain import ( from lbry.blockchain import (
@ -26,7 +27,7 @@ from lbry.blockchain.bcd_data_stream import BCDataStream
from lbry.blockchain.lbrycrd import Lbrycrd from lbry.blockchain.lbrycrd import Lbrycrd
from lbry.blockchain.dewies import lbc_to_dewies from lbry.blockchain.dewies import lbc_to_dewies
from lbry.constants import COIN, CENT, NULL_HASH32 from lbry.constants import COIN, CENT, NULL_HASH32
from lbry.service import Daemon, FullNode, jsonrpc_dumps_pretty from lbry.service import Daemon, FullNode, LightClient, jsonrpc_dumps_pretty
from lbry.conf import Config from lbry.conf import Config
from lbry.console import Console from lbry.console import Console
from lbry.wallet import Wallet, Account from lbry.wallet import Wallet, Account
@ -400,6 +401,7 @@ class UnitDBTestCase(AsyncioTestCase):
class IntegrationTestCase(AsyncioTestCase): class IntegrationTestCase(AsyncioTestCase):
SEED = None SEED = None
LBRYCRD_ARGS = '-rpcworkqueue=128',
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -411,6 +413,52 @@ class IntegrationTestCase(AsyncioTestCase):
self.wallet: Optional[Wallet] = None self.wallet: Optional[Wallet] = None
self.account: Optional[Account] = None self.account: Optional[Account] = None
async def asyncSetUp(self):
await super().asyncSetUp()
self.chain = self.make_chain()
await self.chain.ensure()
self.addCleanup(self.chain.stop)
await self.chain.start(*self.LBRYCRD_ARGS)
@staticmethod
def make_chain():
return Lbrycrd.temp_regtest()
async def make_db(self, chain):
db_driver = os.environ.get('TEST_DB', 'sqlite')
if db_driver == 'sqlite':
db = Database.temp_sqlite_regtest(chain.ledger.conf)
elif db_driver.startswith('postgres') or db_driver.startswith('psycopg'):
db_driver = 'postgresql'
db_name = f'lbry_test_chain'
db_connection = 'postgres:postgres@localhost:5432'
meta_db = Database.from_url(f'postgresql://{db_connection}/postgres')
await meta_db.drop(db_name)
await meta_db.create(db_name)
db = Database.temp_from_url_regtest(
f'postgresql://{db_connection}/{db_name}',
chain.ledger.conf
)
else:
raise RuntimeError(f"Unsupported database driver: {db_driver}")
self.addCleanup(remove_tree, db.ledger.conf.data_dir)
await db.open()
self.addCleanup(db.close)
self.db_driver = db_driver
return db
@staticmethod
def find_claim_txo(tx) -> Optional[Output]:
for txo in tx.outputs:
if txo.is_claim:
return txo
@staticmethod
def find_support_txo(tx) -> Optional[Output]:
for txo in tx.outputs:
if txo.is_support:
return txo
async def assertBalance(self, account, expected_balance: str): # pylint: disable=C0103 async def assertBalance(self, account, expected_balance: str): # pylint: disable=C0103
balance = await account.get_balance() balance = await account.get_balance()
self.assertEqual(dewies_to_lbc(balance), expected_balance) self.assertEqual(dewies_to_lbc(balance), expected_balance)
@ -487,14 +535,13 @@ class CommandTestCase(IntegrationTestCase):
self.reflector = None self.reflector = None
async def asyncSetUp(self): async def asyncSetUp(self):
self.chain = Lbrycrd.temp_regtest() await super().asyncSetUp()
await self.chain.ensure()
self.addCleanup(self.chain.stop)
await self.chain.start('-rpcworkqueue=128')
await self.generate(200, wait=False) await self.generate(200, wait=False)
self.daemon = await self.add_daemon() self.full_node = self.daemon = await self.add_full_node()
if os.environ.get('TEST_MODE', 'full-node') == 'client':
self.daemon = await self.add_light_client(self.full_node)
self.service = self.daemon.service self.service = self.daemon.service
self.ledger = self.service.ledger self.ledger = self.service.ledger
self.api = self.daemon.api self.api = self.daemon.api
@ -509,7 +556,7 @@ class CommandTestCase(IntegrationTestCase):
await self.chain.send_to_address(addresses[0], '10.0') await self.chain.send_to_address(addresses[0], '10.0')
await self.generate(5) await self.generate(5)
async def add_daemon(self): async def add_full_node(self):
self.daemon_port += 1 self.daemon_port += 1
path = tempfile.mkdtemp() path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, path, True) self.addCleanup(shutil.rmtree, path, True)
@ -528,6 +575,21 @@ class CommandTestCase(IntegrationTestCase):
await daemon.start() await daemon.start()
return daemon return daemon
async def add_light_client(self, full_node):
self.daemon_port += 1
path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, path, True)
ledger = RegTestLedger(Config.with_same_dir(path).set(
api=f'localhost:{self.daemon_port}',
full_nodes=[(full_node.conf.api_host, full_node.conf.api_port)]
))
service = LightClient(ledger)
console = Console(service)
daemon = Daemon(service, console)
self.addCleanup(daemon.stop)
await daemon.start()
return daemon
async def asyncTearDown(self): async def asyncTearDown(self):
await super().asyncTearDown() await super().asyncTearDown()
for wallet_node in self.extra_wallet_nodes: for wallet_node in self.extra_wallet_nodes: