forked from LBRYCommunity/lbry-sdk
full node wallet support
This commit is contained in:
parent
6ed2fa20ec
commit
5cdfbcc88e
11 changed files with 116 additions and 73 deletions
|
@ -44,6 +44,7 @@ class BlockchainSync(Sync):
|
||||||
self.pid = os.getpid()
|
self.pid = os.getpid()
|
||||||
self._on_block_controller = EventController()
|
self._on_block_controller = EventController()
|
||||||
self.on_block = self._on_block_controller.stream
|
self.on_block = self._on_block_controller.stream
|
||||||
|
self.conf.events.register("blockchain.block", self.on_block)
|
||||||
self._on_mempool_controller = EventController()
|
self._on_mempool_controller = EventController()
|
||||||
self.on_mempool = self._on_mempool_controller.stream
|
self.on_mempool = self._on_mempool_controller.stream
|
||||||
self.on_block_hash_subscription: Optional[BroadcastSubscription] = None
|
self.on_block_hash_subscription: Optional[BroadcastSubscription] = None
|
||||||
|
@ -109,6 +110,12 @@ class BlockchainSync(Sync):
|
||||||
return
|
return
|
||||||
return done
|
return done
|
||||||
|
|
||||||
|
async def get_block_headers(self, start_height: int, end_height: int = None):
|
||||||
|
return await self.db.get_block_headers(start_height, end_height)
|
||||||
|
|
||||||
|
async def get_best_block_height(self) -> int:
|
||||||
|
return await self.db.get_best_block_height()
|
||||||
|
|
||||||
async def get_best_block_height_for_file(self, file_number) -> int:
|
async def get_best_block_height_for_file(self, file_number) -> int:
|
||||||
return await self.db.run(
|
return await self.db.run(
|
||||||
block_phase.get_best_block_height_for_file, file_number
|
block_phase.get_best_block_height_for_file, file_number
|
||||||
|
|
|
@ -11,6 +11,7 @@ from lbry.utils.dirs import user_data_dir, user_download_dir
|
||||||
from lbry.error import InvalidCurrencyError
|
from lbry.error import InvalidCurrencyError
|
||||||
from lbry.dht import constants
|
from lbry.dht import constants
|
||||||
from lbry.wallet.coinselection import COIN_SELECTION_STRATEGIES
|
from lbry.wallet.coinselection import COIN_SELECTION_STRATEGIES
|
||||||
|
from lbry.event import EventRegistry
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -382,6 +383,7 @@ class BaseConfig:
|
||||||
self.environment = {} # from environment variables
|
self.environment = {} # from environment variables
|
||||||
self.persisted = {} # from config file
|
self.persisted = {} # from config file
|
||||||
self._updating_config = False
|
self._updating_config = False
|
||||||
|
self.events = EventRegistry()
|
||||||
self.set(**kwargs)
|
self.set(**kwargs)
|
||||||
|
|
||||||
def set(self, **kwargs):
|
def set(self, **kwargs):
|
||||||
|
|
|
@ -235,10 +235,10 @@ class Database:
|
||||||
async def process_all_things_after_sync(self):
|
async def process_all_things_after_sync(self):
|
||||||
return await self.run(sync.process_all_things_after_sync)
|
return await self.run(sync.process_all_things_after_sync)
|
||||||
|
|
||||||
async def get_blocks(self, first, last=None):
|
async def get_block_headers(self, start_height: int, end_height: int = None):
|
||||||
return await self.run(q.get_blocks, first, last)
|
return await self.run(q.get_block_headers, start_height, end_height)
|
||||||
|
|
||||||
async def get_filters(self, start_height, end_height=None, granularity=0):
|
async def get_filters(self, start_height: int, end_height: int = None, granularity: int = 0):
|
||||||
return await self.run(q.get_filters, start_height, end_height, granularity)
|
return await self.run(q.get_filters, start_height, end_height, granularity)
|
||||||
|
|
||||||
async def insert_block(self, block):
|
async def insert_block(self, block):
|
||||||
|
|
|
@ -43,7 +43,7 @@ def insert_block(block):
|
||||||
context().get_bulk_loader().add_block(block).flush()
|
context().get_bulk_loader().add_block(block).flush()
|
||||||
|
|
||||||
|
|
||||||
def get_blocks(first, last=None):
|
def get_block_headers(first, last=None):
|
||||||
if last is not None:
|
if last is not None:
|
||||||
query = (
|
query = (
|
||||||
select('*').select_from(Block)
|
select('*').select_from(Block)
|
||||||
|
|
|
@ -184,6 +184,22 @@ class EventStream:
|
||||||
future.set_exception(exception)
|
future.set_exception(exception)
|
||||||
|
|
||||||
|
|
||||||
|
class EventRegistry:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.events = {}
|
||||||
|
|
||||||
|
def register(self, name, stream: EventStream):
|
||||||
|
assert name not in self.events
|
||||||
|
self.events[name] = stream
|
||||||
|
|
||||||
|
def get(self, event_name):
|
||||||
|
return self.events.get(event_name)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.events.clear()
|
||||||
|
|
||||||
|
|
||||||
class EventQueuePublisher(threading.Thread):
|
class EventQueuePublisher(threading.Thread):
|
||||||
|
|
||||||
STOP = 'STOP'
|
STOP = 'STOP'
|
||||||
|
|
|
@ -745,25 +745,11 @@ class API:
|
||||||
[--create_account] [--single_key]
|
[--create_account] [--single_key]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
wallet_path = os.path.join(self.conf.wallet_dir, 'wallets', wallet_id)
|
wallet = await self.wallets.create(
|
||||||
for wallet in self.wallets.wallets:
|
wallet_id, create_account=create_account, single_key=single_key
|
||||||
if wallet.id == wallet_id:
|
|
||||||
raise Exception(f"Wallet at path '{wallet_path}' already exists and is loaded.")
|
|
||||||
if os.path.exists(wallet_path):
|
|
||||||
raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.")
|
|
||||||
|
|
||||||
wallet = self.wallets.import_wallet(wallet_path)
|
|
||||||
if not wallet.accounts and create_account:
|
|
||||||
account = Account.generate(
|
|
||||||
self.ledger, wallet, address_generator={
|
|
||||||
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
if self.ledger.sync.network.is_connected:
|
|
||||||
await self.ledger.subscribe_account(account)
|
|
||||||
wallet.save()
|
|
||||||
if not skip_on_startup:
|
if not skip_on_startup:
|
||||||
with self.conf.update_config() as c:
|
with self.service.conf.update_config() as c:
|
||||||
c.wallets += [wallet_id]
|
c.wallets += [wallet_id]
|
||||||
return wallet
|
return wallet
|
||||||
|
|
||||||
|
@ -820,9 +806,7 @@ class API:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
wallet = self.wallets.get_or_default(wallet_id)
|
wallet = self.wallets.get_or_default(wallet_id)
|
||||||
balance = await self.ledger.get_detailed_balance(
|
balance = await wallet.get_balance()
|
||||||
accounts=wallet.accounts, confirmations=confirmations
|
|
||||||
)
|
|
||||||
return dict_values_to_lbc(balance)
|
return dict_values_to_lbc(balance)
|
||||||
|
|
||||||
async def wallet_status(
|
async def wallet_status(
|
||||||
|
@ -2682,9 +2666,20 @@ class API:
|
||||||
List block info
|
List block info
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
block_list <start_height> [<end_height>]
|
block list <start_height> [<end_height>]
|
||||||
"""
|
"""
|
||||||
return await self.service.get_blocks(start_height=start_height, end_height=end_height)
|
return await self.service.sync.get_block_headers(
|
||||||
|
start_height=start_height, end_height=end_height
|
||||||
|
)
|
||||||
|
|
||||||
|
async def block_tip(self) -> int: # block number at the tip of the blockchain
|
||||||
|
"""
|
||||||
|
Retrieve the last confirmed block (tip) of the blockchain.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
block tip
|
||||||
|
"""
|
||||||
|
return await self.service.sync.get_best_block_height()
|
||||||
|
|
||||||
TRANSACTION_DOC = """
|
TRANSACTION_DOC = """
|
||||||
Transaction management.
|
Transaction management.
|
||||||
|
@ -3603,7 +3598,25 @@ class Client(API):
|
||||||
if events:
|
if events:
|
||||||
await self.ws.send_json({'id': None, 'method': 'subscribe', 'params': events})
|
await self.ws.send_json({'id': None, 'method': 'subscribe', 'params': events})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def first(self) -> 'Client':
|
||||||
|
return ClientReturnsFirstResponse(self)
|
||||||
|
|
||||||
def __getattribute__(self, name):
|
def __getattribute__(self, name):
|
||||||
if name in dir(API):
|
if name in dir(API):
|
||||||
return partial(object.__getattribute__(self, 'send'), name)
|
return partial(object.__getattribute__(self, 'send'), name)
|
||||||
return object.__getattribute__(self, name)
|
return object.__getattribute__(self, name)
|
||||||
|
|
||||||
|
|
||||||
|
class ClientReturnsFirstResponse(Client):
|
||||||
|
|
||||||
|
def __init__(self, client: Client):
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
def __getattribute__(self, name):
|
||||||
|
if name in dir(API):
|
||||||
|
async def return_first(**kwargs):
|
||||||
|
responses = await self.client.send(name, **kwargs)
|
||||||
|
return await responses.first
|
||||||
|
return return_first
|
||||||
|
return object.__getattribute__(self, name)
|
||||||
|
|
|
@ -51,6 +51,12 @@ class Sync:
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_block_headers(self, start_height: int, end_height: int = None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_best_block_height(self) -> int:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class Service:
|
class Service:
|
||||||
"""
|
"""
|
||||||
|
@ -92,9 +98,6 @@ class Service:
|
||||||
async def get_file(self, uri, **kwargs):
|
async def get_file(self, uri, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def get_block_headers(self, first, last=None):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def create_wallet(self, wallet_id):
|
def create_wallet(self, wallet_id):
|
||||||
return self.wallets.create(wallet_id)
|
return self.wallets.create(wallet_id)
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,9 @@ import signal
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from weakref import WeakSet
|
from weakref import WeakSet
|
||||||
|
from functools import partial
|
||||||
from asyncio.runners import _cancel_all_tasks
|
from asyncio.runners import _cancel_all_tasks
|
||||||
from typing import Type
|
from typing import Type, List, Dict, Tuple
|
||||||
|
|
||||||
from aiohttp.web import Application, AppRunner, WebSocketResponse, TCPSite, Response
|
from aiohttp.web import Application, AppRunner, WebSocketResponse, TCPSite, Response
|
||||||
from aiohttp.http_websocket import WSMsgType, WSCloseCode
|
from aiohttp.http_websocket import WSMsgType, WSCloseCode
|
||||||
|
@ -14,6 +15,7 @@ from lbry.console import Console, console_class_from_name
|
||||||
from lbry.service import API, Service
|
from lbry.service import API, Service
|
||||||
from lbry.service.json_encoder import JSONResponseEncoder
|
from lbry.service.json_encoder import JSONResponseEncoder
|
||||||
from lbry.blockchain.ledger import ledger_class_from_name
|
from lbry.blockchain.ledger import ledger_class_from_name
|
||||||
|
from lbry.event import BroadcastSubscription
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -51,24 +53,6 @@ class WebSocketManager(WebSocketResponse):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def subscribe(self, requested: list, subscriptions):
|
|
||||||
for request in requested:
|
|
||||||
if request == '*':
|
|
||||||
for _, component in subscriptions.items():
|
|
||||||
for _, sockets in component.items():
|
|
||||||
sockets.add(self)
|
|
||||||
elif '.' not in request:
|
|
||||||
for _, sockets in subscriptions[request].items():
|
|
||||||
sockets.add(self)
|
|
||||||
elif request.count('.') == 1:
|
|
||||||
component, stream = request.split('.')
|
|
||||||
subscriptions[component][stream].add(self)
|
|
||||||
|
|
||||||
def unsubscribe(self, subscriptions):
|
|
||||||
for _, component in subscriptions.items():
|
|
||||||
for _, sockets in component.items():
|
|
||||||
sockets.discard(self)
|
|
||||||
|
|
||||||
|
|
||||||
class Daemon:
|
class Daemon:
|
||||||
"""
|
"""
|
||||||
|
@ -83,13 +67,7 @@ class Daemon:
|
||||||
self.api = API(service)
|
self.api = API(service)
|
||||||
self.app = Application()
|
self.app = Application()
|
||||||
self.app['websockets'] = WeakSet()
|
self.app['websockets'] = WeakSet()
|
||||||
self.app['subscriptions'] = {}
|
self.app['subscriptions']: Dict[str, Tuple[BroadcastSubscription, WeakSet]] = {}
|
||||||
self.components = {}
|
|
||||||
#for component in components:
|
|
||||||
# streams = self.app['subscriptions'][component.name] = {}
|
|
||||||
# for event_name, event_stream in component.event_streams.items():
|
|
||||||
# streams[event_name] = WeakSet()
|
|
||||||
# event_stream.listen(partial(self.broadcast_event, component.name, event_name))
|
|
||||||
self.app.router.add_get('/ws', self.on_connect)
|
self.app.router.add_get('/ws', self.on_connect)
|
||||||
self.app.router.add_post('/api', self.on_rpc)
|
self.app.router.add_post('/api', self.on_rpc)
|
||||||
self.app.on_shutdown.append(self.on_shutdown)
|
self.app.on_shutdown.append(self.on_shutdown)
|
||||||
|
@ -162,7 +140,6 @@ class Daemon:
|
||||||
print('web socket connection closed with exception %s' %
|
print('web socket connection closed with exception %s' %
|
||||||
web_socket.exception())
|
web_socket.exception())
|
||||||
finally:
|
finally:
|
||||||
web_socket.unsubscribe(self.app['subscriptions'])
|
|
||||||
self.app['websockets'].discard(web_socket)
|
self.app['websockets'].discard(web_socket)
|
||||||
return web_socket
|
return web_socket
|
||||||
|
|
||||||
|
@ -171,7 +148,7 @@ class Daemon:
|
||||||
streams = msg['params']
|
streams = msg['params']
|
||||||
if isinstance(streams, str):
|
if isinstance(streams, str):
|
||||||
streams = [streams]
|
streams = [streams]
|
||||||
web_socket.subscribe(streams, self.app['subscriptions'])
|
await self.on_subscribe(web_socket, streams)
|
||||||
else:
|
else:
|
||||||
params = msg.get('params', {})
|
params = msg.get('params', {})
|
||||||
method = getattr(self.api, msg['method'])
|
method = getattr(self.api, msg['method'])
|
||||||
|
@ -186,22 +163,27 @@ class Daemon:
|
||||||
await web_socket.send_json({'id': msg.get('id', ''), 'result': "unexpected error: " + str(e)})
|
await web_socket.send_json({'id': msg.get('id', ''), 'result': "unexpected error: " + str(e)})
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
async def on_subscribe(self, web_socket: WebSocketManager, events: List[str]):
|
||||||
|
for event_name in events:
|
||||||
|
if event_name not in self.app["subscriptions"]:
|
||||||
|
event_stream = self.conf.events.get(event_name)
|
||||||
|
subscribers = WeakSet()
|
||||||
|
event_stream.listen(partial(self.broadcast_event, event_name, subscribers))
|
||||||
|
self.app["subscriptions"][event_name] = {
|
||||||
|
"stream": event_stream,
|
||||||
|
"subscribers": subscribers
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
subscribers = self.app["subscriptions"][event_name]["subscribers"]
|
||||||
|
subscribers.add(web_socket)
|
||||||
|
|
||||||
|
def broadcast_event(self, event_name, subscribers, payload):
|
||||||
|
for web_socket in subscribers:
|
||||||
|
asyncio.create_task(web_socket.send_json({
|
||||||
|
'event': event_name, 'payload': payload
|
||||||
|
}))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def on_shutdown(app):
|
async def on_shutdown(app):
|
||||||
for web_socket in set(app['websockets']):
|
for web_socket in set(app['websockets']):
|
||||||
await web_socket.close(code=WSCloseCode.GOING_AWAY, message='Server shutdown')
|
await web_socket.close(code=WSCloseCode.GOING_AWAY, message='Server shutdown')
|
||||||
|
|
||||||
def broadcast_event(self, module, stream, payload):
|
|
||||||
for web_socket in self.app['subscriptions'][module][stream]:
|
|
||||||
asyncio.create_task(web_socket.send_json({
|
|
||||||
'module': module,
|
|
||||||
'stream': stream,
|
|
||||||
'payload': payload
|
|
||||||
}))
|
|
||||||
|
|
||||||
def broadcast_message(self, msg):
|
|
||||||
for web_socket in self.app['websockets']:
|
|
||||||
asyncio.create_task(web_socket.send_json({
|
|
||||||
'module': 'blockchain_sync',
|
|
||||||
'payload': msg
|
|
||||||
}))
|
|
||||||
|
|
|
@ -23,6 +23,10 @@ class FullNode(Service):
|
||||||
super().__init__(ledger)
|
super().__init__(ledger)
|
||||||
self.chain = chain or Lbrycrd(ledger)
|
self.chain = chain or Lbrycrd(ledger)
|
||||||
self.sync = BlockchainSync(self.chain, self.db)
|
self.sync = BlockchainSync(self.chain, self.db)
|
||||||
|
self.sync.on_block.listen(self.generate_addresses)
|
||||||
|
|
||||||
|
async def generate_addresses(self, _):
|
||||||
|
await self.wallets.generate_addresses()
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
await self.chain.open()
|
await self.chain.open()
|
||||||
|
|
|
@ -32,6 +32,10 @@ class WalletManager:
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise ValueError(f"Couldn't find wallet: {wallet_id}.")
|
raise ValueError(f"Couldn't find wallet: {wallet_id}.")
|
||||||
|
|
||||||
|
async def generate_addresses(self):
|
||||||
|
for wallet in self.wallets.values():
|
||||||
|
await wallet.generate_addresses()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default(self) -> Optional[Wallet]:
|
def default(self) -> Optional[Wallet]:
|
||||||
for wallet in self.wallets.values():
|
for wallet in self.wallets.values():
|
||||||
|
|
|
@ -59,6 +59,12 @@ class Wallet:
|
||||||
self.purchases = PurchaseListManager(self)
|
self.purchases = PurchaseListManager(self)
|
||||||
self.supports = SupportListManager(self)
|
self.supports = SupportListManager(self)
|
||||||
|
|
||||||
|
async def generate_addresses(self):
|
||||||
|
await asyncio.wait([
|
||||||
|
account.ensure_address_gap()
|
||||||
|
for account in self.accounts
|
||||||
|
])
|
||||||
|
|
||||||
async def notify_change(self, field: str, value=None):
|
async def notify_change(self, field: str, value=None):
|
||||||
await self._on_change_controller.add({
|
await self._on_change_controller.add({
|
||||||
'field': field, 'value': value
|
'field': field, 'value': value
|
||||||
|
@ -382,6 +388,12 @@ class Wallet:
|
||||||
f"Use --allow-duplicate-name flag to override."
|
f"Use --allow-duplicate-name flag to override."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_balance(self):
|
||||||
|
balance = {"total": 0}
|
||||||
|
for account in self.accounts:
|
||||||
|
balance["total"] += await account.get_balance()
|
||||||
|
return balance
|
||||||
|
|
||||||
|
|
||||||
class AccountListManager:
|
class AccountListManager:
|
||||||
__slots__ = 'wallet', '_accounts'
|
__slots__ = 'wallet', '_accounts'
|
||||||
|
|
Loading…
Reference in a new issue