From 5cdfbcc88e9e32cc461d48b16af4abca8f7374d8 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Fri, 16 Oct 2020 12:52:57 -0400 Subject: [PATCH] full node wallet support --- lbry/blockchain/sync/synchronizer.py | 7 +++ lbry/conf.py | 2 + lbry/db/database.py | 6 +-- lbry/db/queries/base.py | 2 +- lbry/event.py | 16 +++++++ lbry/service/api.py | 59 ++++++++++++++---------- lbry/service/base.py | 9 ++-- lbry/service/daemon.py | 68 ++++++++++------------------ lbry/service/full_node.py | 4 ++ lbry/wallet/manager.py | 4 ++ lbry/wallet/wallet.py | 12 +++++ 11 files changed, 116 insertions(+), 73 deletions(-) diff --git a/lbry/blockchain/sync/synchronizer.py b/lbry/blockchain/sync/synchronizer.py index 8824c9c9e..8bb390235 100644 --- a/lbry/blockchain/sync/synchronizer.py +++ b/lbry/blockchain/sync/synchronizer.py @@ -44,6 +44,7 @@ class BlockchainSync(Sync): self.pid = os.getpid() self._on_block_controller = EventController() self.on_block = self._on_block_controller.stream + self.conf.events.register("blockchain.block", self.on_block) self._on_mempool_controller = EventController() self.on_mempool = self._on_mempool_controller.stream self.on_block_hash_subscription: Optional[BroadcastSubscription] = None @@ -109,6 +110,12 @@ class BlockchainSync(Sync): return return done + async def get_block_headers(self, start_height: int, end_height: int = None): + return await self.db.get_block_headers(start_height, end_height) + + async def get_best_block_height(self) -> int: + return await self.db.get_best_block_height() + async def get_best_block_height_for_file(self, file_number) -> int: return await self.db.run( block_phase.get_best_block_height_for_file, file_number diff --git a/lbry/conf.py b/lbry/conf.py index 5c65ce41d..d60bc214b 100644 --- a/lbry/conf.py +++ b/lbry/conf.py @@ -11,6 +11,7 @@ from lbry.utils.dirs import user_data_dir, user_download_dir from lbry.error import InvalidCurrencyError from lbry.dht import constants from lbry.wallet.coinselection import COIN_SELECTION_STRATEGIES +from lbry.event import EventRegistry log = logging.getLogger(__name__) @@ -382,6 +383,7 @@ class BaseConfig: self.environment = {} # from environment variables self.persisted = {} # from config file self._updating_config = False + self.events = EventRegistry() self.set(**kwargs) def set(self, **kwargs): diff --git a/lbry/db/database.py b/lbry/db/database.py index 498f82a44..03d4685d7 100644 --- a/lbry/db/database.py +++ b/lbry/db/database.py @@ -235,10 +235,10 @@ class Database: async def process_all_things_after_sync(self): return await self.run(sync.process_all_things_after_sync) - async def get_blocks(self, first, last=None): - return await self.run(q.get_blocks, first, last) + async def get_block_headers(self, start_height: int, end_height: int = None): + return await self.run(q.get_block_headers, start_height, end_height) - async def get_filters(self, start_height, end_height=None, granularity=0): + async def get_filters(self, start_height: int, end_height: int = None, granularity: int = 0): return await self.run(q.get_filters, start_height, end_height, granularity) async def insert_block(self, block): diff --git a/lbry/db/queries/base.py b/lbry/db/queries/base.py index 75727c825..35e35673f 100644 --- a/lbry/db/queries/base.py +++ b/lbry/db/queries/base.py @@ -43,7 +43,7 @@ def insert_block(block): context().get_bulk_loader().add_block(block).flush() -def get_blocks(first, last=None): +def get_block_headers(first, last=None): if last is not None: query = ( select('*').select_from(Block) diff --git a/lbry/event.py b/lbry/event.py index e0b15c129..7c39d6642 100644 --- a/lbry/event.py +++ b/lbry/event.py @@ -184,6 +184,22 @@ class EventStream: future.set_exception(exception) +class EventRegistry: + + def __init__(self): + self.events = {} + + def register(self, name, stream: EventStream): + assert name not in self.events + self.events[name] = stream + + def get(self, event_name): + return self.events.get(event_name) + + def clear(self): + self.events.clear() + + class EventQueuePublisher(threading.Thread): STOP = 'STOP' diff --git a/lbry/service/api.py b/lbry/service/api.py index bfa59aa0f..79f5d6c74 100644 --- a/lbry/service/api.py +++ b/lbry/service/api.py @@ -745,25 +745,11 @@ class API: [--create_account] [--single_key] """ - wallet_path = os.path.join(self.conf.wallet_dir, 'wallets', wallet_id) - for wallet in self.wallets.wallets: - if wallet.id == wallet_id: - raise Exception(f"Wallet at path '{wallet_path}' already exists and is loaded.") - if os.path.exists(wallet_path): - raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.") - - wallet = self.wallets.import_wallet(wallet_path) - if not wallet.accounts and create_account: - account = Account.generate( - self.ledger, wallet, address_generator={ - 'name': SingleKey.name if single_key else HierarchicalDeterministic.name - } - ) - if self.ledger.sync.network.is_connected: - await self.ledger.subscribe_account(account) - wallet.save() + wallet = await self.wallets.create( + wallet_id, create_account=create_account, single_key=single_key + ) if not skip_on_startup: - with self.conf.update_config() as c: + with self.service.conf.update_config() as c: c.wallets += [wallet_id] return wallet @@ -820,9 +806,7 @@ class API: """ wallet = self.wallets.get_or_default(wallet_id) - balance = await self.ledger.get_detailed_balance( - accounts=wallet.accounts, confirmations=confirmations - ) + balance = await wallet.get_balance() return dict_values_to_lbc(balance) async def wallet_status( @@ -2682,9 +2666,20 @@ class API: List block info Usage: - block_list [] + block list [] """ - return await self.service.get_blocks(start_height=start_height, end_height=end_height) + return await self.service.sync.get_block_headers( + start_height=start_height, end_height=end_height + ) + + async def block_tip(self) -> int: # block number at the tip of the blockchain + """ + Retrieve the last confirmed block (tip) of the blockchain. + + Usage: + block tip + """ + return await self.service.sync.get_best_block_height() TRANSACTION_DOC = """ Transaction management. @@ -3603,7 +3598,25 @@ class Client(API): if events: await self.ws.send_json({'id': None, 'method': 'subscribe', 'params': events}) + @property + def first(self) -> 'Client': + return ClientReturnsFirstResponse(self) + def __getattribute__(self, name): if name in dir(API): return partial(object.__getattribute__(self, 'send'), name) return object.__getattribute__(self, name) + + +class ClientReturnsFirstResponse(Client): + + def __init__(self, client: Client): + self.client = client + + def __getattribute__(self, name): + if name in dir(API): + async def return_first(**kwargs): + responses = await self.client.send(name, **kwargs) + return await responses.first + return return_first + return object.__getattribute__(self, name) diff --git a/lbry/service/base.py b/lbry/service/base.py index c65ff605d..b447e4822 100644 --- a/lbry/service/base.py +++ b/lbry/service/base.py @@ -51,6 +51,12 @@ class Sync: async def stop(self): raise NotImplementedError + async def get_block_headers(self, start_height: int, end_height: int = None): + raise NotImplementedError + + async def get_best_block_height(self) -> int: + raise NotImplementedError + class Service: """ @@ -92,9 +98,6 @@ class Service: async def get_file(self, uri, **kwargs): pass - async def get_block_headers(self, first, last=None): - pass - def create_wallet(self, wallet_id): return self.wallets.create(wallet_id) diff --git a/lbry/service/daemon.py b/lbry/service/daemon.py index d681b7b7d..140a55a6e 100644 --- a/lbry/service/daemon.py +++ b/lbry/service/daemon.py @@ -3,8 +3,9 @@ import signal import asyncio import logging from weakref import WeakSet +from functools import partial from asyncio.runners import _cancel_all_tasks -from typing import Type +from typing import Type, List, Dict, Tuple from aiohttp.web import Application, AppRunner, WebSocketResponse, TCPSite, Response from aiohttp.http_websocket import WSMsgType, WSCloseCode @@ -14,6 +15,7 @@ from lbry.console import Console, console_class_from_name from lbry.service import API, Service from lbry.service.json_encoder import JSONResponseEncoder from lbry.blockchain.ledger import ledger_class_from_name +from lbry.event import BroadcastSubscription log = logging.getLogger(__name__) @@ -51,24 +53,6 @@ class WebSocketManager(WebSocketResponse): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def subscribe(self, requested: list, subscriptions): - for request in requested: - if request == '*': - for _, component in subscriptions.items(): - for _, sockets in component.items(): - sockets.add(self) - elif '.' not in request: - for _, sockets in subscriptions[request].items(): - sockets.add(self) - elif request.count('.') == 1: - component, stream = request.split('.') - subscriptions[component][stream].add(self) - - def unsubscribe(self, subscriptions): - for _, component in subscriptions.items(): - for _, sockets in component.items(): - sockets.discard(self) - class Daemon: """ @@ -83,13 +67,7 @@ class Daemon: self.api = API(service) self.app = Application() self.app['websockets'] = WeakSet() - self.app['subscriptions'] = {} - self.components = {} - #for component in components: - # streams = self.app['subscriptions'][component.name] = {} - # for event_name, event_stream in component.event_streams.items(): - # streams[event_name] = WeakSet() - # event_stream.listen(partial(self.broadcast_event, component.name, event_name)) + self.app['subscriptions']: Dict[str, Tuple[BroadcastSubscription, WeakSet]] = {} self.app.router.add_get('/ws', self.on_connect) self.app.router.add_post('/api', self.on_rpc) self.app.on_shutdown.append(self.on_shutdown) @@ -162,7 +140,6 @@ class Daemon: print('web socket connection closed with exception %s' % web_socket.exception()) finally: - web_socket.unsubscribe(self.app['subscriptions']) self.app['websockets'].discard(web_socket) return web_socket @@ -171,7 +148,7 @@ class Daemon: streams = msg['params'] if isinstance(streams, str): streams = [streams] - web_socket.subscribe(streams, self.app['subscriptions']) + await self.on_subscribe(web_socket, streams) else: params = msg.get('params', {}) method = getattr(self.api, msg['method']) @@ -186,22 +163,27 @@ class Daemon: await web_socket.send_json({'id': msg.get('id', ''), 'result': "unexpected error: " + str(e)}) raise e + async def on_subscribe(self, web_socket: WebSocketManager, events: List[str]): + for event_name in events: + if event_name not in self.app["subscriptions"]: + event_stream = self.conf.events.get(event_name) + subscribers = WeakSet() + event_stream.listen(partial(self.broadcast_event, event_name, subscribers)) + self.app["subscriptions"][event_name] = { + "stream": event_stream, + "subscribers": subscribers + } + else: + subscribers = self.app["subscriptions"][event_name]["subscribers"] + subscribers.add(web_socket) + + def broadcast_event(self, event_name, subscribers, payload): + for web_socket in subscribers: + asyncio.create_task(web_socket.send_json({ + 'event': event_name, 'payload': payload + })) + @staticmethod async def on_shutdown(app): for web_socket in set(app['websockets']): await web_socket.close(code=WSCloseCode.GOING_AWAY, message='Server shutdown') - - def broadcast_event(self, module, stream, payload): - for web_socket in self.app['subscriptions'][module][stream]: - asyncio.create_task(web_socket.send_json({ - 'module': module, - 'stream': stream, - 'payload': payload - })) - - def broadcast_message(self, msg): - for web_socket in self.app['websockets']: - asyncio.create_task(web_socket.send_json({ - 'module': 'blockchain_sync', - 'payload': msg - })) diff --git a/lbry/service/full_node.py b/lbry/service/full_node.py index 69a61177a..3eb341260 100644 --- a/lbry/service/full_node.py +++ b/lbry/service/full_node.py @@ -23,6 +23,10 @@ class FullNode(Service): super().__init__(ledger) self.chain = chain or Lbrycrd(ledger) self.sync = BlockchainSync(self.chain, self.db) + self.sync.on_block.listen(self.generate_addresses) + + async def generate_addresses(self, _): + await self.wallets.generate_addresses() async def start(self): await self.chain.open() diff --git a/lbry/wallet/manager.py b/lbry/wallet/manager.py index 69818b633..a8554e494 100644 --- a/lbry/wallet/manager.py +++ b/lbry/wallet/manager.py @@ -32,6 +32,10 @@ class WalletManager: except KeyError: raise ValueError(f"Couldn't find wallet: {wallet_id}.") + async def generate_addresses(self): + for wallet in self.wallets.values(): + await wallet.generate_addresses() + @property def default(self) -> Optional[Wallet]: for wallet in self.wallets.values(): diff --git a/lbry/wallet/wallet.py b/lbry/wallet/wallet.py index f53717e4e..52a1b9760 100644 --- a/lbry/wallet/wallet.py +++ b/lbry/wallet/wallet.py @@ -59,6 +59,12 @@ class Wallet: self.purchases = PurchaseListManager(self) self.supports = SupportListManager(self) + async def generate_addresses(self): + await asyncio.wait([ + account.ensure_address_gap() + for account in self.accounts + ]) + async def notify_change(self, field: str, value=None): await self._on_change_controller.add({ 'field': field, 'value': value @@ -382,6 +388,12 @@ class Wallet: f"Use --allow-duplicate-name flag to override." ) + async def get_balance(self): + balance = {"total": 0} + for account in self.accounts: + balance["total"] += await account.get_balance() + return balance + class AccountListManager: __slots__ = 'wallet', '_accounts'