From 6ed2fa20ecde7efd2c4a60c1a5b19bc395666037 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Fri, 9 Oct 2020 10:47:44 -0400 Subject: [PATCH] working db based wallet and wallet sync progress --- lbry/blockchain/sync/synchronizer.py | 6 +- lbry/conf.py | 1 + lbry/db/database.py | 18 + lbry/db/queries/__init__.py | 1 + lbry/db/queries/address.py | 4 +- lbry/db/queries/base.py | 61 ++- lbry/db/queries/wallet.py | 24 ++ lbry/db/tables.py | 7 + lbry/service/api.py | 56 ++- lbry/service/base.py | 33 +- lbry/service/daemon.py | 15 +- lbry/service/full_endpoint.py | 61 +++ lbry/service/full_node.py | 10 +- lbry/service/light_client.py | 478 +++++++++++++++++++++- lbry/testcase.py | 74 ++-- lbry/wallet/account.py | 20 +- lbry/wallet/database.py | 385 ------------------ lbry/wallet/manager.py | 209 +++++++--- lbry/wallet/storage.py | 60 --- lbry/wallet/sync.py | 420 -------------------- lbry/wallet/wallet.py | 129 +++--- tests/unit/wallet/__init__.py | 2 - tests/unit/wallet/test_account.py | 62 +-- tests/unit/wallet/test_database.py | 566 --------------------------- tests/unit/wallet/test_ledger.py | 242 ------------ tests/unit/wallet/test_manager.py | 168 ++++++-- tests/unit/wallet/test_sync.py | 78 ---- tests/unit/wallet/test_wallet.py | 221 ++++++++--- 28 files changed, 1315 insertions(+), 2096 deletions(-) create mode 100644 lbry/db/queries/wallet.py create mode 100644 lbry/service/full_endpoint.py delete mode 100644 lbry/wallet/database.py delete mode 100644 lbry/wallet/storage.py delete mode 100644 lbry/wallet/sync.py delete mode 100644 tests/unit/wallet/test_database.py delete mode 100644 tests/unit/wallet/test_ledger.py delete mode 100644 tests/unit/wallet/test_sync.py diff --git a/lbry/blockchain/sync/synchronizer.py b/lbry/blockchain/sync/synchronizer.py index 4c1652550..8824c9c9e 100644 --- a/lbry/blockchain/sync/synchronizer.py +++ b/lbry/blockchain/sync/synchronizer.py @@ -42,13 +42,15 @@ class BlockchainSync(Sync): super().__init__(chain.ledger, db) self.chain = chain self.pid = os.getpid() + self._on_block_controller = EventController() + self.on_block = self._on_block_controller.stream + self._on_mempool_controller = EventController() + self.on_mempool = self._on_mempool_controller.stream self.on_block_hash_subscription: Optional[BroadcastSubscription] = None self.on_tx_hash_subscription: Optional[BroadcastSubscription] = None self.advance_loop_task: Optional[asyncio.Task] = None self.block_hash_event = asyncio.Event() self.tx_hash_event = asyncio.Event() - self._on_mempool_controller = EventController() - self.on_mempool = self._on_mempool_controller.stream self.mempool = [] async def wait_for_chain_ready(self): diff --git a/lbry/conf.py b/lbry/conf.py index 90deaaccb..5c65ce41d 100644 --- a/lbry/conf.py +++ b/lbry/conf.py @@ -519,6 +519,7 @@ class Config(CLIConfig): data_dir = Path("Main directory containing blobs, wallets and blockchain data.", metavar='DIR') blob_dir = Path("Directory to store blobs (default: 'data_dir'/blobs).", metavar='DIR') wallet_dir = Path("Directory to store wallets (default: 'data_dir'/wallets).", metavar='DIR') + wallet_storage = StringChoice("Wallet storage mode.", ["file", "database"], "file") wallets = Strings( "Wallet files in 'wallet_dir' to load at startup.", ['default_wallet'] ) diff --git a/lbry/db/database.py b/lbry/db/database.py index af6efeac2..498f82a44 100644 --- a/lbry/db/database.py +++ b/lbry/db/database.py @@ -205,6 +205,9 @@ class Database: async def execute(self, sql): return await self.run(q.execute, sql) + async def execute_sql_object(self, sql): + return await self.run(q.execute_sql_object, sql) + async def execute_fetchall(self, sql): return await self.run(q.execute_fetchall, sql) @@ -217,12 +220,27 @@ class Database: async def has_supports(self): return await self.run(q.has_supports) + async def has_wallet(self, wallet_id): + return await self.run(q.has_wallet, wallet_id) + + async def get_wallet(self, wallet_id: str): + return await self.run(q.get_wallet, wallet_id) + + async def add_wallet(self, wallet_id: str, data: str): + return await self.run(q.add_wallet, wallet_id, data) + async def get_best_block_height(self) -> int: return await self.run(q.get_best_block_height) async def process_all_things_after_sync(self): return await self.run(sync.process_all_things_after_sync) + async def get_blocks(self, first, last=None): + return await self.run(q.get_blocks, first, last) + + async def get_filters(self, start_height, end_height=None, granularity=0): + return await self.run(q.get_filters, start_height, end_height, granularity) + async def insert_block(self, block): return await self.run(q.insert_block, block) diff --git a/lbry/db/queries/__init__.py b/lbry/db/queries/__init__.py index 4de770595..87d9b167e 100644 --- a/lbry/db/queries/__init__.py +++ b/lbry/db/queries/__init__.py @@ -3,3 +3,4 @@ from .txio import * from .search import * from .resolve import * from .address import * +from .wallet import * diff --git a/lbry/db/queries/address.py b/lbry/db/queries/address.py index 35ec51daf..c2f194cb1 100644 --- a/lbry/db/queries/address.py +++ b/lbry/db/queries/address.py @@ -62,11 +62,11 @@ def add_keys(account, chain, pubkeys): c = context() c.execute( c.insert_or_ignore(PubkeyAddress) - .values([{'address': k.address} for k in pubkeys]) + .values([{'address': k.address} for k in pubkeys]) ) c.execute( c.insert_or_ignore(AccountAddress) - .values([{ + .values([{ 'account': account.id, 'address': k.address, 'chain': chain, diff --git a/lbry/db/queries/base.py b/lbry/db/queries/base.py index f758adb23..75727c825 100644 --- a/lbry/db/queries/base.py +++ b/lbry/db/queries/base.py @@ -1,14 +1,24 @@ -from sqlalchemy import text +from math import log10 +from binascii import hexlify + +from sqlalchemy import text, between from sqlalchemy.future import select from ..query_context import context -from ..tables import SCHEMA_VERSION, metadata, Version, Claim, Support, Block, BlockFilter, TX +from ..tables import ( + SCHEMA_VERSION, metadata, Version, + Claim, Support, Block, BlockFilter, BlockGroupFilter, TX, TXFilter, +) def execute(sql): return context().execute(text(sql)) +def execute_sql_object(sql): + return context().execute(sql) + + def execute_fetchall(sql): return context().fetchall(text(sql)) @@ -33,6 +43,53 @@ def insert_block(block): context().get_bulk_loader().add_block(block).flush() +def get_blocks(first, last=None): + if last is not None: + query = ( + select('*').select_from(Block) + .where(between(Block.c.height, first, last)) + .order_by(Block.c.height) + ) + else: + query = select('*').select_from(Block).where(Block.c.height == first) + return context().fetchall(query) + + +def get_filters(start_height, end_height=None, granularity=0): + assert granularity >= 0, "filter granularity must be 0 or positive number" + if granularity == 0: + query = ( + select('*').select_from(TXFilter) + .where(between(TXFilter.c.height, start_height, end_height)) + .order_by(TXFilter.c.height) + ) + elif granularity == 1: + query = ( + select('*').select_from(BlockFilter) + .where(between(BlockFilter.c.height, start_height, end_height)) + .order_by(BlockFilter.c.height) + ) + else: + query = ( + select('*').select_from(BlockGroupFilter) + .where( + (BlockGroupFilter.c.height == start_height) & + (BlockGroupFilter.c.factor == log10(granularity)) + ) + .order_by(BlockGroupFilter.c.height) + ) + result = [] + for row in context().fetchall(query): + record = { + "height": row["height"], + "filter": hexlify(row["address_filter"]).decode(), + } + if granularity == 0: + record["txid"] = hexlify(row["tx_hash"][::-1]).decode() + result.append(record) + return result + + def insert_transaction(block_hash, tx): context().get_bulk_loader().add_transaction(block_hash, tx).flush(TX) diff --git a/lbry/db/queries/wallet.py b/lbry/db/queries/wallet.py new file mode 100644 index 000000000..b1b7f3a98 --- /dev/null +++ b/lbry/db/queries/wallet.py @@ -0,0 +1,24 @@ +from sqlalchemy import exists +from sqlalchemy.future import select + +from ..query_context import context +from ..tables import Wallet + + +def has_wallet(wallet_id: str) -> bool: + sql = select(exists(select(Wallet.c.wallet_id).where(Wallet.c.wallet_id == wallet_id))) + return context().execute(sql).fetchone()[0] + + +def get_wallet(wallet_id: str): + return context().fetchone( + select(Wallet.c.data).where(Wallet.c.wallet_id == wallet_id) + ) + + +def add_wallet(wallet_id: str, data: str): + c = context() + c.execute( + c.insert_or_replace(Wallet, ["data"]) + .values(wallet_id=wallet_id, data=data) + ) diff --git a/lbry/db/tables.py b/lbry/db/tables.py index 5a0fa3bf7..22aa14b1e 100644 --- a/lbry/db/tables.py +++ b/lbry/db/tables.py @@ -19,6 +19,13 @@ Version = Table( ) +Wallet = Table( + 'wallet', metadata, + Column('wallet_id', Text, primary_key=True), + Column('data', Text), +) + + PubkeyAddress = Table( 'pubkey_address', metadata, Column('address', Text, primary_key=True), diff --git a/lbry/service/api.py b/lbry/service/api.py index bb94c5d78..bfa59aa0f 100644 --- a/lbry/service/api.py +++ b/lbry/service/api.py @@ -1225,7 +1225,7 @@ class API: wallet = self.wallets.get_or_default(wallet_id) wallet_changed = False if data is not None: - added_accounts = await wallet.merge(self.wallets, password, data) + added_accounts = await wallet.merge(password, data) if added_accounts and self.ledger.sync.network.is_connected: if blocking: await asyncio.wait([ @@ -1318,11 +1318,24 @@ class API: .receiving.get_or_create_usable_address() ) - async def address_block_filters(self): - return await self.service.get_block_address_filters() + async def address_filter( + self, + start_height: int, # starting height of block range or just single block + end_height: int = None, # return a range of blocks from start_height to end_height + granularity: int = None, # 0 - individual tx filters, 1 - block per filter, + # 1000, 10000, 100000 blocks per filter + ) -> list: # blocks + """ + List address filters - async def address_transaction_filters(self, block_hash: str): - return await self.service.get_transaction_address_filters(block_hash) + Usage: + address_filter + [--end_height=] + [--granularity=] + """ + return await self.service.get_address_filters( + start_height=start_height, end_height=end_height, granularity=granularity + ) FILE_DOC = """ File management. @@ -2656,6 +2669,23 @@ class API: await self.service.maybe_broadcast_or_release(tx, blocking, preview) return tx + BLOCK_DOC = """ + Block information. + """ + + async def block_list( + self, + start_height: int, # starting height of block range or just single block + end_height: int = None, # return a range of blocks from start_height to end_height + ) -> list: # blocks + """ + List block info + + Usage: + block_list [] + """ + return await self.service.get_blocks(start_height=start_height, end_height=end_height) + TRANSACTION_DOC = """ Transaction management. """ @@ -3529,8 +3559,12 @@ class Client(API): self.receive_messages_task = asyncio.create_task(self.receive_messages()) async def disconnect(self): - await self.session.close() - self.receive_messages_task.cancel() + if self.session is not None: + await self.session.close() + self.session = None + if self.receive_messages_task is not None: + self.receive_messages_task.cancel() + self.receive_messages_task = None async def receive_messages(self): async for message in self.ws: @@ -3559,12 +3593,16 @@ class Client(API): await self.ws.send_json({'id': self.message_id, 'method': method, 'params': kwargs}) return ec.stream - async def subscribe(self, event) -> EventStream: + def get_event_stream(self, event) -> EventStream: if event not in self.subscriptions: self.subscriptions[event] = EventController() - await self.ws.send_json({'id': None, 'method': 'subscribe', 'params': [event]}) return self.subscriptions[event].stream + async def start_event_streams(self): + events = list(self.subscriptions.keys()) + if events: + await self.ws.send_json({'id': None, 'method': 'subscribe', 'params': events}) + def __getattribute__(self, name): if name in dir(API): return partial(object.__getattribute__(self, 'send'), name) diff --git a/lbry/service/base.py b/lbry/service/base.py index 2f9c86f05..c65ff605d 100644 --- a/lbry/service/base.py +++ b/lbry/service/base.py @@ -9,7 +9,7 @@ from lbry.schema.result import Censor from lbry.blockchain.transaction import Transaction, Output from lbry.blockchain.ledger import Ledger from lbry.wallet import WalletManager -from lbry.event import EventController +from lbry.event import EventController, EventStream log = logging.getLogger(__name__) @@ -25,14 +25,14 @@ class Sync: Server stays synced with lbrycrd """ + on_block: EventStream + on_mempool: EventStream + def __init__(self, ledger: Ledger, db: Database): self.ledger = ledger self.conf = ledger.conf self.db = db - self._on_block_controller = EventController() - self.on_block = self._on_block_controller.stream - self._on_progress_controller = db._on_progress_controller self.on_progress = db.on_progress @@ -63,13 +63,7 @@ class Service: def __init__(self, ledger: Ledger): self.ledger, self.conf = ledger, ledger.conf self.db = Database(ledger) - self.wallets = WalletManager(ledger, self.db) - - #self.on_address = sync.on_address - #self.accounts = sync.accounts - #self.on_header = sync.on_header - #self.on_ready = sync.on_ready - #self.on_transaction = sync.on_transaction + self.wallets = WalletManager(self.db) # sync has established connection with a source from which it can synchronize # for full service this is lbrycrd (or sync service) and for light this is full node @@ -78,8 +72,8 @@ class Service: async def start(self): await self.db.open() - await self.wallets.ensure_path_exists() - await self.wallets.load() + await self.wallets.storage.prepare() + await self.wallets.initialize() await self.sync.start() async def stop(self): @@ -95,16 +89,21 @@ class Service: async def find_ffmpeg(self): pass - async def get(self, uri, **kwargs): + async def get_file(self, uri, **kwargs): pass - def create_wallet(self, file_name): - path = os.path.join(self.conf.wallet_dir, file_name) - return self.wallets.add_from_path(path) + async def get_block_headers(self, first, last=None): + pass + + def create_wallet(self, wallet_id): + return self.wallets.create(wallet_id) async def get_addresses(self, **constraints): return await self.db.get_addresses(**constraints) + async def get_address_filters(self, start_height: int, end_height: int=None, granularity: int=0): + raise NotImplementedError + def reserve_outputs(self, txos): return self.db.reserve_outputs(txos) diff --git a/lbry/service/daemon.py b/lbry/service/daemon.py index 55b67a5e2..d681b7b7d 100644 --- a/lbry/service/daemon.py +++ b/lbry/service/daemon.py @@ -19,11 +19,13 @@ from lbry.blockchain.ledger import ledger_class_from_name log = logging.getLogger(__name__) -def jsonrpc_dumps_pretty(obj, **kwargs): +def jsonrpc_dumps_pretty(obj, message_id=None, **kwargs): #if not isinstance(obj, dict): # data = {"jsonrpc": "2.0", "error": obj.to_dict()} #else: data = {"jsonrpc": "2.0", "result": obj} + if message_id is not None: + data["id"] = message_id return json.dumps(data, cls=JSONResponseEncoder, sort_keys=True, indent=2, **kwargs) + "\n" @@ -166,7 +168,7 @@ class Daemon: async def on_message(self, web_socket: WebSocketManager, msg: dict): if msg['method'] == 'subscribe': - streams = msg['streams'] + streams = msg['params'] if isinstance(streams, str): streams = [streams] web_socket.subscribe(streams, self.app['subscriptions']) @@ -175,11 +177,10 @@ class Daemon: method = getattr(self.api, msg['method']) try: result = await method(**params) - encoded_result = jsonrpc_dumps_pretty(result, service=self.service) - await web_socket.send_json({ - 'id': msg.get('id', ''), - 'result': encoded_result - }) + encoded_result = jsonrpc_dumps_pretty( + result, message_id=msg.get('id', ''), service=self.service + ) + await web_socket.send_str(encoded_result) except Exception as e: log.exception("RPC error") await web_socket.send_json({'id': msg.get('id', ''), 'result': "unexpected error: " + str(e)}) diff --git a/lbry/service/full_endpoint.py b/lbry/service/full_endpoint.py new file mode 100644 index 000000000..811899b24 --- /dev/null +++ b/lbry/service/full_endpoint.py @@ -0,0 +1,61 @@ +import logging +from binascii import hexlify, unhexlify + +from lbry.blockchain.lbrycrd import Lbrycrd +from lbry.blockchain.sync import BlockchainSync +from lbry.blockchain.ledger import Ledger +from lbry.blockchain.transaction import Transaction + +from .base import Service, Sync +from .api import Client as APIClient + + +log = logging.getLogger(__name__) + + +class NoSync(Sync): + + def __init__(self, service: Service, client: APIClient): + super().__init__(service.ledger, service.db) + self.service = service + self.client = client + self.on_block = client.get_event_stream('blockchain.block') + self.on_block_subscription: Optional[BroadcastSubscription] = None + self.on_mempool = client.get_event_stream('blockchain.mempool') + self.on_mempool_subscription: Optional[BroadcastSubscription] = None + + async def wait_for_client_ready(self): + await self.client.connect() + + async def start(self): + self.db.stop_event.clear() + await self.wait_for_client_ready() + self.advance_loop_task = asyncio.create_task(self.advance()) + await self.advance_loop_task + await self.client.subscribe() + self.advance_loop_task = asyncio.create_task(self.advance_loop()) + self.on_block_subscription = self.on_block.listen( + lambda e: self.on_block_event.set() + ) + self.on_mempool_subscription = self.on_mempool.listen( + lambda e: self.on_mempool_event.set() + ) + await self.download_filters() + await self.download_headers() + + async def stop(self): + await self.client.disconnect() + + +class FullEndpoint(Service): + + name = "endpoint" + + sync: 'NoSync' + + def __init__(self, ledger: Ledger): + super().__init__(ledger) + self.client = APIClient( + f"http://{ledger.conf.full_nodes[0][0]}:{ledger.conf.full_nodes[0][1]}/api" + ) + self.sync = NoSync(self, self.client) diff --git a/lbry/service/full_node.py b/lbry/service/full_node.py index fb75e513d..69a61177a 100644 --- a/lbry/service/full_node.py +++ b/lbry/service/full_node.py @@ -35,7 +35,15 @@ class FullNode(Service): async def get_status(self): return 'everything is wonderful' -# async def get_block_address_filters(self): + async def get_block_headers(self, first, last=None): + return await self.db.get_blocks(first, last) + + async def get_address_filters(self, start_height: int, end_height: int = None, granularity: int = 0): + return await self.db.get_filters( + start_height=start_height, end_height=end_height, granularity=granularity + ) + + # async def get_block_address_filters(self): # return { # hexlify(f['block_hash']).decode(): hexlify(f['block_filter']).decode() # for f in await self.db.get_block_address_filters() diff --git a/lbry/service/light_client.py b/lbry/service/light_client.py index cb027b02b..64b30c945 100644 --- a/lbry/service/light_client.py +++ b/lbry/service/light_client.py @@ -1,11 +1,25 @@ +import asyncio import logging -from typing import List, Dict, Tuple +from typing import List, Dict +#from io import StringIO +#from functools import partial +#from operator import itemgetter +from collections import defaultdict +#from binascii import hexlify, unhexlify +from typing import List, Optional, DefaultDict, NamedTuple +#from lbry.crypto.hash import double_sha256, sha256 + +from lbry.tasks import TaskGroup +from lbry.blockchain.transaction import Transaction +from lbry.blockchain.block import get_address_filter +from lbry.event import BroadcastSubscription, EventController +from lbry.wallet.account import AddressManager from lbry.blockchain import Ledger, Transaction -from lbry.wallet.sync import SPVSync -from .base import Service -from .api import Client +from .base import Service, Sync +from .api import Client as APIClient + log = logging.getLogger(__name__) @@ -14,23 +28,24 @@ class LightClient(Service): name = "client" - sync: SPVSync + sync: 'FastSync' def __init__(self, ledger: Ledger): super().__init__(ledger) - self.client = Client( - f"http://{ledger.conf.full_nodes[0][0]}:{ledger.conf.full_nodes[0][1]}/api" + self.client = APIClient( + f"http://{ledger.conf.full_nodes[0][0]}:{ledger.conf.full_nodes[0][1]}/ws" ) - self.sync = SPVSync(self) + self.sync = FastSync(self, self.client) + self.blocks = BlockHeaderManager(self.db, self.client) + self.filters = FilterManager(self.db, self.client) async def search_transactions(self, txids): return await self.client.transaction_search(txids=txids) - async def get_block_address_filters(self): - return await self.client.address_block_filters() - - async def get_transaction_address_filters(self, block_hash): - return await self.client.address_transaction_filters(block_hash=block_hash) + async def get_address_filters(self, start_height: int, end_height: int = None, granularity: int = 0): + return await self.filters.get_filters( + start_height=start_height, end_height=end_height, granularity=granularity + ) async def broadcast(self, tx): pass @@ -47,6 +62,437 @@ class LightClient(Service): async def search_supports(self, accounts, **kwargs): pass - async def sum_supports(self, claim_hash: bytes, include_channel_content=False, exclude_own_supports=False) \ - -> Tuple[List[Dict], int]: - return await self.client.sum_supports(claim_hash, include_channel_content, exclude_own_supports) + async def sum_supports(self, claim_hash: bytes, include_channel_content=False) -> List[Dict]: + return await self.client.sum_supports(claim_hash, include_channel_content) + + +class TransactionEvent(NamedTuple): + address: str + tx: Transaction + + +class AddressesGeneratedEvent(NamedTuple): + address_manager: AddressManager + addresses: List[str] + + +class TransactionCacheItem: + __slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications' + + def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None): + self.has_tx = asyncio.Event() + self.lock = lock or asyncio.Lock() + self._tx = self.tx = tx + self.pending_verifications = 0 + + @property + def tx(self) -> Optional[Transaction]: + return self._tx + + @tx.setter + def tx(self, tx: Transaction): + self._tx = tx + if tx is not None: + self.has_tx.set() + + +class FilterManager: + """ + Efficient on-demand address filter access. + Stores and retrieves from local db what it previously downloaded and + downloads on-demand what it doesn't have from full node. + """ + + def __init__(self, db, client): + self.db = db + self.client = client + self.cache = {} + + async def get_filters(self, start_height, end_height, granularity): + return await self.client.address_filter( + start_height=start_height, end_height=end_height, granularity=granularity + ) + + +class BlockHeaderManager: + """ + Efficient on-demand block header access. + Stores and retrieves from local db what it previously downloaded and + downloads on-demand what it doesn't have from full node. + """ + + def __init__(self, db, client): + self.db = db + self.client = client + self.cache = {} + + async def get_header(self, height): + blocks = await self.client.block_list(height) + if blocks: + return blocks[0] + + async def add(self, header): + pass + + async def download(self): + pass + + +class FastSync(Sync): + + def __init__(self, service: Service, client: APIClient): + super().__init__(service.ledger, service.db) + self.service = service + self.client = client + self.advance_loop_task: Optional[asyncio.Task] = None + self.on_block = client.get_event_stream('blockchain.block') + self.on_block_event = asyncio.Event() + self.on_block_subscription: Optional[BroadcastSubscription] = None + self.on_mempool = client.get_event_stream('blockchain.mempool') + self.on_mempool_event = asyncio.Event() + self.on_mempool_subscription: Optional[BroadcastSubscription] = None + + async def wait_for_client_ready(self): + await self.client.connect() + + async def start(self): + return + self.db.stop_event.clear() + await self.wait_for_client_ready() + self.advance_loop_task = asyncio.create_task(self.advance()) + await self.advance_loop_task + await self.client.subscribe() + self.advance_loop_task = asyncio.create_task(self.advance_loop()) + self.on_block_subscription = self.on_block.listen( + lambda e: self.on_block_event.set() + ) + self.on_mempool_subscription = self.on_mempool.listen( + lambda e: self.on_mempool_event.set() + ) + await self.download_filters() + await self.download_headers() + + async def stop(self): + await self.client.disconnect() + + async def advance(self): + address_array = [ + bytearray(a['address'].encode()) + for a in await self.service.db.get_all_addresses() + ] + block_filters = await self.service.get_block_address_filters() + for block_hash, block_filter in block_filters.items(): + bf = get_address_filter(block_filter) + if bf.MatchAny(address_array): + print(f'match: {block_hash} - {block_filter}') + tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash) + for txid, tx_filter in tx_filters.items(): + tf = get_address_filter(tx_filter) + if tf.MatchAny(address_array): + print(f' match: {txid} - {tx_filter}') + txs = await self.service.search_transactions([txid]) + tx = Transaction(unhexlify(txs[txid])) + await self.service.db.insert_transaction(tx) + + # async def get_local_status_and_history(self, address, history=None): + # if not history: + # address_details = await self.db.get_address(address=address) + # history = (address_details['history'] if address_details else '') or '' + # parts = history.split(':')[:-1] + # return ( + # hexlify(sha256(history.encode())).decode() if history else None, + # list(zip(parts[0::2], map(int, parts[1::2]))) + # ) + # + # @staticmethod + # def get_root_of_merkle_tree(branches, branch_positions, working_branch): + # for i, branch in enumerate(branches): + # other_branch = unhexlify(branch)[::-1] + # other_branch_on_left = bool((branch_positions >> i) & 1) + # if other_branch_on_left: + # combined = other_branch + working_branch + # else: + # combined = working_branch + other_branch + # working_branch = double_sha256(combined) + # return hexlify(working_branch[::-1]) + # + # + # @property + # def local_height_including_downloaded_height(self): + # return max(self.headers.height, self._download_height) + # + # async def initial_headers_sync(self): + # get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True) + # self.headers.chunk_getter = get_chunk + # + # async def doit(): + # for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)): + # async with self._header_processing_lock: + # await self.headers.ensure_chunk_at(height) + # self._other_tasks.add(doit()) + # await self.update_headers() + # + # async def update_headers(self, height=None, headers=None, subscription_update=False): + # rewound = 0 + # while True: + # + # if height is None or height > len(self.headers): + # # sometimes header subscription updates are for a header in the future + # # which can't be connected, so we do a normal header sync instead + # height = len(self.headers) + # headers = None + # subscription_update = False + # + # if not headers: + # header_response = await self.network.retriable_call(self.network.get_headers, height, 2001) + # headers = header_response['hex'] + # + # if not headers: + # # Nothing to do, network thinks we're already at the latest height. + # return + # + # added = await self.headers.connect(height, unhexlify(headers)) + # if added > 0: + # height += added + # self._on_header_controller.add( + # BlockHeightEvent(self.headers.height, added)) + # + # if rewound > 0: + # # we started rewinding blocks and apparently found + # # a new chain + # rewound = 0 + # await self.db.rewind_blockchain(height) + # + # if subscription_update: + # # subscription updates are for latest header already + # # so we don't need to check if there are newer / more + # # on another loop of update_headers(), just return instead + # return + # + # elif added == 0: + # # we had headers to connect but none got connected, probably a reorganization + # height -= 1 + # rewound += 1 + # log.warning( + # "Blockchain Reorganization: attempting rewind to height %s from starting height %s", + # height, height+rewound + # ) + # + # else: + # raise IndexError(f"headers.connect() returned negative number ({added})") + # + # if height < 0: + # raise IndexError( + # "Blockchain reorganization rewound all the way back to genesis hash. " + # "Something is very wrong. Maybe you are on the wrong blockchain?" + # ) + # + # if rewound >= 100: + # raise IndexError( + # "Blockchain reorganization dropped {} headers. This is highly unusual. " + # "Will not continue to attempt reorganizing. Please, delete the ledger " + # "synchronization directory inside your wallet directory (folder: '{}') and " + # "restart the program to synchronize from scratch." + # .format(rewound, self.ledger.get_id()) + # ) + # + # headers = None # ready to download some more headers + # + # # if we made it this far and this was a subscription_update + # # it means something went wrong and now we're doing a more + # # robust sync, turn off subscription update shortcut + # subscription_update = False + # + # async def receive_header(self, response): + # async with self._header_processing_lock: + # header = response[0] + # await self.update_headers( + # height=header['height'], headers=header['hex'], subscription_update=True + # ) + # + # async def subscribe_accounts(self): + # if self.network.is_connected and self.accounts: + # log.info("Subscribe to %i accounts", len(self.accounts)) + # await asyncio.wait([ + # self.subscribe_account(a) for a in self.accounts + # ]) + # + # async def subscribe_account(self, account: Account): + # for address_manager in account.address_managers.values(): + # await self.subscribe_addresses(address_manager, await address_manager.get_addresses()) + # await account.ensure_address_gap() + # + # async def unsubscribe_account(self, account: Account): + # for address in await account.get_addresses(): + # await self.network.unsubscribe_address(address) + # + # async def subscribe_addresses( + # self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000): + # if self.network.is_connected and addresses: + # addresses_remaining = list(addresses) + # while addresses_remaining: + # batch = addresses_remaining[:batch_size] + # results = await self.network.subscribe_address(*batch) + # for address, remote_status in zip(batch, results): + # self._update_tasks.add(self.update_history(address, remote_status, address_manager)) + # addresses_remaining = addresses_remaining[batch_size:] + # log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining), + # len(addresses), *self.network.client.server_address_and_port) + # log.info( + # "finished subscribing to %i addresses on %s:%i", len(addresses), + # *self.network.client.server_address_and_port + # ) + # + # def process_status_update(self, update): + # address, remote_status = update + # self._update_tasks.add(self.update_history(address, remote_status)) + # + # async def update_history(self, address, remote_status, address_manager: AddressManager = None): + # async with self._address_update_locks[address]: + # self._known_addresses_out_of_sync.discard(address) + # + # local_status, local_history = await self.get_local_status_and_history(address) + # + # if local_status == remote_status: + # return True + # + # remote_history = await self.network.retriable_call(self.network.get_history, address) + # remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) + # we_need = set(remote_history) - set(local_history) + # if not we_need: + # return True + # + # cache_tasks: List[asyncio.Task[Transaction]] = [] + # synced_history = StringIO() + # loop = asyncio.get_running_loop() + # for i, (txid, remote_height) in enumerate(remote_history): + # if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks: + # synced_history.write(f'{txid}:{remote_height}:') + # else: + # check_local = (txid, remote_height) not in we_need + # cache_tasks.append(loop.create_task( + # self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local) + # )) + # + # synced_txs = [] + # for task in cache_tasks: + # tx = await task + # + # check_db_for_txos = [] + # for txi in tx.inputs: + # if txi.txo_ref.txo is not None: + # continue + # cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash) + # if cache_item is not None: + # if cache_item.tx is None: + # await cache_item.has_tx.wait() + # assert cache_item.tx is not None + # txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref + # else: + # check_db_for_txos.append(txi.txo_ref.hash) + # + # referenced_txos = {} if not check_db_for_txos else { + # txo.id: txo for txo in await self.db.get_txos( + # txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True + # ) + # } + # + # for txi in tx.inputs: + # if txi.txo_ref.txo is not None: + # continue + # referenced_txo = referenced_txos.get(txi.txo_ref.id) + # if referenced_txo is not None: + # txi.txo_ref = referenced_txo.ref + # + # synced_history.write(f'{tx.id}:{tx.height}:') + # synced_txs.append(tx) + # + # await self.db.save_transaction_io_batch( + # synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue() + # ) + # await asyncio.wait([ + # self.ledger._on_transaction_controller.add(TransactionEvent(address, tx)) + # for tx in synced_txs + # ]) + # + # if address_manager is None: + # address_manager = await self.get_address_manager_for_address(address) + # + # if address_manager is not None: + # await address_manager.ensure_address_gap() + # + # local_status, local_history = \ + # await self.get_local_status_and_history(address, synced_history.getvalue()) + # if local_status != remote_status: + # if local_history == remote_history: + # return True + # log.warning( + # "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items", + # remote_status, len(remote_history), local_status, len(local_history) + # ) + # log.warning("local: %s", local_history) + # log.warning("remote: %s", remote_history) + # self._known_addresses_out_of_sync.add(address) + # return False + # else: + # return True + # + # async def cache_transaction(self, tx_hash, remote_height, check_local=True): + # cache_item = self._tx_cache.get(tx_hash) + # if cache_item is None: + # cache_item = self._tx_cache[tx_hash] = TransactionCacheItem() + # elif cache_item.tx is not None and \ + # cache_item.tx.height >= remote_height and \ + # (cache_item.tx.is_verified or remote_height < 1): + # return cache_item.tx # cached tx is already up-to-date + # + # try: + # cache_item.pending_verifications += 1 + # return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local) + # finally: + # cache_item.pending_verifications -= 1 + # + # async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True): + # + # async with cache_item.lock: + # + # tx = cache_item.tx + # + # if tx is None and check_local: + # # check local db + # tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash) + # + # merkle = None + # if tx is None: + # # fetch from network + # _raw, merkle = await self.network.retriable_call( + # self.network.get_transaction_and_merkle, tx_hash, remote_height + # ) + # tx = Transaction(unhexlify(_raw), height=merkle.get('block_height')) + # cache_item.tx = tx # make sure it's saved before caching it + # await self.maybe_verify_transaction(tx, remote_height, merkle) + # return tx + # + # async def maybe_verify_transaction(self, tx, remote_height, merkle=None): + # tx.height = remote_height + # cached = self._tx_cache.get(tx.hash) + # if not cached: + # # cache txs looked up by transaction_show too + # cached = TransactionCacheItem() + # cached.tx = tx + # self._tx_cache[tx.hash] = cached + # if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1: + # # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case + # if not merkle: + # merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height) + # merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) + # header = await self.headers.get(remote_height) + # tx.position = merkle['pos'] + # tx.is_verified = merkle_root == header['merkle_root'] + # + # async def get_address_manager_for_address(self, address) -> Optional[AddressManager]: + # details = await self.db.get_address(address=address) + # for account in self.accounts: + # if account.id == details['account']: + # return account.address_managers[details['chain']] + # return None \ No newline at end of file diff --git a/lbry/testcase.py b/lbry/testcase.py index 1f414e8cc..8ec498ad8 100644 --- a/lbry/testcase.py +++ b/lbry/testcase.py @@ -447,6 +447,38 @@ class IntegrationTestCase(AsyncioTestCase): self.db_driver = db_driver return db + async def add_full_node(self, port): + path = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, path, True) + ledger = RegTestLedger(Config.with_same_dir(path).set( + api=f'localhost:{port}', + lbrycrd_dir=self.chain.ledger.conf.lbrycrd_dir, + lbrycrd_rpc_port=self.chain.ledger.conf.lbrycrd_rpc_port, + lbrycrd_peer_port=self.chain.ledger.conf.lbrycrd_peer_port, + lbrycrd_zmq=self.chain.ledger.conf.lbrycrd_zmq + )) + service = FullNode(ledger) + console = Console(service) + daemon = Daemon(service, console) + self.addCleanup(daemon.stop) + await daemon.start() + return daemon + + async def add_light_client(self, full_node, port, start=True): + path = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, path, True) + ledger = RegTestLedger(Config.with_same_dir(path).set( + api=f'localhost:{port}', + full_nodes=[(full_node.conf.api_host, full_node.conf.api_port)] + )) + service = LightClient(ledger) + console = Console(service) + daemon = Daemon(service, console) + self.addCleanup(daemon.stop) + if start: + await daemon.start() + return daemon + @staticmethod def find_claim_txo(tx) -> Optional[Output]: for txo in tx.outputs: @@ -538,9 +570,11 @@ class CommandTestCase(IntegrationTestCase): await super().asyncSetUp() await self.generate(200, wait=False) - self.full_node = self.daemon = await self.add_full_node() - if os.environ.get('TEST_MODE', 'full-node') == 'client': - self.daemon = await self.add_light_client(self.full_node) + self.daemon_port += 1 + self.full_node = self.daemon = await self.add_full_node(self.daemon_port) + if os.environ.get('TEST_MODE', 'node') == 'client': + self.daemon_port += 1 + self.daemon = await self.add_light_client(self.full_node, self.daemon_port) self.service = self.daemon.service self.ledger = self.service.ledger @@ -556,40 +590,6 @@ class CommandTestCase(IntegrationTestCase): await self.chain.send_to_address(addresses[0], '10.0') await self.generate(5) - async def add_full_node(self): - self.daemon_port += 1 - path = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, path, True) - ledger = RegTestLedger(Config.with_same_dir(path).set( - api=f'localhost:{self.daemon_port}', - lbrycrd_dir=self.chain.ledger.conf.lbrycrd_dir, - lbrycrd_rpc_port=self.chain.ledger.conf.lbrycrd_rpc_port, - lbrycrd_peer_port=self.chain.ledger.conf.lbrycrd_peer_port, - lbrycrd_zmq=self.chain.ledger.conf.lbrycrd_zmq, - spv_address_filters=False - )) - service = FullNode(ledger) - console = Console(service) - daemon = Daemon(service, console) - self.addCleanup(daemon.stop) - await daemon.start() - return daemon - - async def add_light_client(self, full_node): - self.daemon_port += 1 - path = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, path, True) - ledger = RegTestLedger(Config.with_same_dir(path).set( - api=f'localhost:{self.daemon_port}', - full_nodes=[(full_node.conf.api_host, full_node.conf.api_port)] - )) - service = LightClient(ledger) - console = Console(service) - daemon = Daemon(service, console) - self.addCleanup(daemon.stop) - await daemon.start() - return daemon - async def asyncTearDown(self): await super().asyncTearDown() for wallet_node in self.extra_wallet_nodes: diff --git a/lbry/wallet/account.py b/lbry/wallet/account.py index f8b996321..1a7b065b1 100644 --- a/lbry/wallet/account.py +++ b/lbry/wallet/account.py @@ -12,6 +12,7 @@ import ecdsa from lbry.constants import COIN from lbry.db import Database, CLAIM_TYPE_CODES, TXO_TYPES +from lbry.db.tables import AccountAddress from lbry.blockchain import Ledger from lbry.error import InvalidPasswordError from lbry.crypto.crypt import aes_encrypt, aes_decrypt @@ -214,12 +215,12 @@ class Account: HierarchicalDeterministic.name: HierarchicalDeterministic, } - def __init__(self, ledger: Ledger, db: Database, name: str, + def __init__(self, db: Database, name: str, phrase: str, language: str, private_key_string: str, encrypted: bool, private_key: Optional[PrivateKey], public_key: PubKey, address_generator: dict, modified_on: float, channel_keys: dict) -> None: - self.ledger = ledger self.db = db + self.ledger = db.ledger self.id = public_key.address self.name = name self.phrase = phrase @@ -245,10 +246,10 @@ class Account: @classmethod async def generate( - cls, ledger: Ledger, db: Database, - name: str = None, language: str = 'en', - address_generator: dict = None): - return await cls.from_dict(ledger, db, { + cls, db: Database, name: str = None, + language: str = 'en', address_generator: dict = None + ): + return await cls.from_dict(db, { 'name': name, 'seed': await mnemonic.generate_phrase(language), 'language': language, @@ -276,13 +277,12 @@ class Account: return phrase, private_key, public_key @classmethod - async def from_dict(cls, ledger: Ledger, db: Database, d: dict): - phrase, private_key, public_key = await cls.keys_from_dict(ledger, d) + async def from_dict(cls, db: Database, d: dict): + phrase, private_key, public_key = await cls.keys_from_dict(db.ledger, d) name = d.get('name') if not name: name = f'Account #{public_key.address}' return cls( - ledger=ledger, db=db, name=name, phrase=phrase, @@ -415,7 +415,7 @@ class Account: return await self.db.get_addresses(account=self, **constraints) async def get_addresses(self, **constraints) -> List[str]: - rows = await self.get_address_records(cols=['account_address.address'], **constraints) + rows = await self.get_address_records(cols=[AccountAddress.c.address], **constraints) return [r['address'] for r in rows] async def get_valid_receiving_address(self, default_address: str) -> str: diff --git a/lbry/wallet/database.py b/lbry/wallet/database.py deleted file mode 100644 index 9f1371bd2..000000000 --- a/lbry/wallet/database.py +++ /dev/null @@ -1,385 +0,0 @@ -import os -import logging -import asyncio -import sqlite3 -import platform -from binascii import hexlify -from dataclasses import dataclass -from contextvars import ContextVar -from concurrent.futures.thread import ThreadPoolExecutor -from concurrent.futures.process import ProcessPoolExecutor -from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional - - -log = logging.getLogger(__name__) -sqlite3.enable_callback_tracebacks(True) - - -@dataclass -class ReaderProcessState: - cursor: sqlite3.Cursor - - -reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context') - - -def initializer(path): - db = sqlite3.connect(path) - db.row_factory = dict_row_factory - db.executescript("pragma journal_mode=WAL;") - reader = ReaderProcessState(db.cursor()) - reader_context.set(reader) - - -def run_read_only_fetchall(sql, params): - cursor = reader_context.get().cursor - try: - return cursor.execute(sql, params).fetchall() - except (Exception, OSError) as e: - log.exception('Error running transaction:', exc_info=e) - raise - - -def run_read_only_fetchone(sql, params): - cursor = reader_context.get().cursor - try: - return cursor.execute(sql, params).fetchone() - except (Exception, OSError) as e: - log.exception('Error running transaction:', exc_info=e) - raise - - -if platform.system() == 'Windows' or 'ANDROID_ARGUMENT' in os.environ: - ReaderExecutorClass = ThreadPoolExecutor -else: - ReaderExecutorClass = ProcessPoolExecutor - - -class AIOSQLite: - reader_executor: ReaderExecutorClass - - def __init__(self): - # has to be single threaded as there is no mapping of thread:connection - self.writer_executor = ThreadPoolExecutor(max_workers=1) - self.writer_connection: Optional[sqlite3.Connection] = None - self._closing = False - self.query_count = 0 - self.write_lock = asyncio.Lock() - self.writers = 0 - self.read_ready = asyncio.Event() - - @classmethod - async def connect(cls, path: Union[bytes, str], *args, **kwargs): - sqlite3.enable_callback_tracebacks(True) - db = cls() - - def _connect_writer(): - db.writer_connection = sqlite3.connect(path, *args, **kwargs) - - readers = max(os.cpu_count() - 2, 2) - db.reader_executor = ReaderExecutorClass( - max_workers=readers, initializer=initializer, initargs=(path, ) - ) - await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer) - db.read_ready.set() - return db - - async def close(self): - if self._closing: - return - self._closing = True - await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close) - self.writer_executor.shutdown(wait=True) - self.reader_executor.shutdown(wait=True) - self.read_ready.clear() - self.writer_connection = None - - def executemany(self, sql: str, params: Iterable): - params = params if params is not None else [] - # this fetchall is needed to prevent SQLITE_MISUSE - return self.run(lambda conn: conn.executemany(sql, params).fetchall()) - - def executescript(self, script: str) -> Awaitable: - return self.run(lambda conn: conn.executescript(script)) - - async def _execute_fetch(self, sql: str, parameters: Iterable = None, - read_only=False, fetch_all: bool = False) -> List[dict]: - read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone - parameters = parameters if parameters is not None else [] - if read_only: - while self.writers: - await self.read_ready.wait() - return await asyncio.get_event_loop().run_in_executor( - self.reader_executor, read_only_fn, sql, parameters - ) - if fetch_all: - return await self.run(lambda conn: conn.execute(sql, parameters).fetchall()) - return await self.run(lambda conn: conn.execute(sql, parameters).fetchone()) - - async def execute_fetchall(self, sql: str, parameters: Iterable = None, - read_only=False) -> List[dict]: - return await self._execute_fetch(sql, parameters, read_only, fetch_all=True) - - async def execute_fetchone(self, sql: str, parameters: Iterable = None, - read_only=False) -> List[dict]: - return await self._execute_fetch(sql, parameters, read_only, fetch_all=False) - - def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: - parameters = parameters if parameters is not None else [] - return self.run(lambda conn: conn.execute(sql, parameters)) - - async def run(self, fun, *args, **kwargs): - self.writers += 1 - self.read_ready.clear() - async with self.write_lock: - try: - return await asyncio.get_event_loop().run_in_executor( - self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs) - ) - finally: - self.writers -= 1 - if not self.writers: - self.read_ready.set() - - def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): - self.writer_connection.execute('begin') - try: - self.query_count += 1 - result = fun(self.writer_connection, *args, **kwargs) # type: ignore - self.writer_connection.commit() - return result - except (Exception, OSError) as e: - log.exception('Error running transaction:', exc_info=e) - self.writer_connection.rollback() - log.warning("rolled back") - raise - - def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable: - return asyncio.get_event_loop().run_in_executor( - self.writer_executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs - ) - - def __run_transaction_with_foreign_keys_disabled(self, - fun: Callable[[sqlite3.Connection, Any, Any], Any], - args, kwargs): - foreign_keys_enabled, = self.writer_connection.execute("pragma foreign_keys").fetchone() - if not foreign_keys_enabled: - raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead") - try: - self.writer_connection.execute('pragma foreign_keys=off').fetchone() - return self.__run_transaction(fun, *args, **kwargs) - finally: - self.writer_connection.execute('pragma foreign_keys=on').fetchone() - - -def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): - sql, values = [], {} - for key, constraint in constraints.items(): - tag = '0' - if '#' in key: - key, tag = key[:key.index('#')], key[key.index('#')+1:] - col, op, key = key, '=', key.replace('.', '_') - if not key: - sql.append(constraint) - continue - if key.startswith('$$'): - col, key = col[2:], key[1:] - elif key.startswith('$'): - values[key] = constraint - continue - if key.endswith('__not'): - col, op = col[:-len('__not')], '!=' - elif key.endswith('__is_null'): - col = col[:-len('__is_null')] - sql.append(f'{col} IS NULL') - continue - if key.endswith('__is_not_null'): - col = col[:-len('__is_not_null')] - sql.append(f'{col} IS NOT NULL') - continue - if key.endswith('__lt'): - col, op = col[:-len('__lt')], '<' - elif key.endswith('__lte'): - col, op = col[:-len('__lte')], '<=' - elif key.endswith('__gt'): - col, op = col[:-len('__gt')], '>' - elif key.endswith('__gte'): - col, op = col[:-len('__gte')], '>=' - elif key.endswith('__like'): - col, op = col[:-len('__like')], 'LIKE' - elif key.endswith('__not_like'): - col, op = col[:-len('__not_like')], 'NOT LIKE' - elif key.endswith('__in') or key.endswith('__not_in'): - if key.endswith('__in'): - col, op, one_val_op = col[:-len('__in')], 'IN', '=' - else: - col, op, one_val_op = col[:-len('__not_in')], 'NOT IN', '!=' - if constraint: - if isinstance(constraint, (list, set, tuple)): - if len(constraint) == 1: - values[f'{key}{tag}'] = next(iter(constraint)) - sql.append(f'{col} {one_val_op} :{key}{tag}') - else: - keys = [] - for i, val in enumerate(constraint): - keys.append(f':{key}{tag}_{i}') - values[f'{key}{tag}_{i}'] = val - sql.append(f'{col} {op} ({", ".join(keys)})') - elif isinstance(constraint, str): - sql.append(f'{col} {op} ({constraint})') - else: - raise ValueError(f"{col} requires a list, set or string as constraint value.") - continue - elif key.endswith('__any') or key.endswith('__or'): - where, subvalues = constraints_to_sql(constraint, ' OR ', key+tag+'_') - sql.append(f'({where})') - values.update(subvalues) - continue - if key.endswith('__and'): - where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_') - sql.append(f'({where})') - values.update(subvalues) - continue - sql.append(f'{col} {op} :{prepend_key}{key}{tag}') - values[prepend_key+key+tag] = constraint - return joiner.join(sql) if sql else '', values - - -def query(select, **constraints) -> Tuple[str, Dict[str, Any]]: - sql = [select] - limit = constraints.pop('limit', None) - offset = constraints.pop('offset', None) - order_by = constraints.pop('order_by', None) - group_by = constraints.pop('group_by', None) - - accounts = constraints.pop('accounts', []) - if accounts: - constraints['account__in'] = [a.public_key.address for a in accounts] - - where, values = constraints_to_sql(constraints) - if where: - sql.append('WHERE') - sql.append(where) - - if group_by is not None: - sql.append(f'GROUP BY {group_by}') - - if order_by: - sql.append('ORDER BY') - if isinstance(order_by, str): - sql.append(order_by) - elif isinstance(order_by, list): - sql.append(', '.join(order_by)) - else: - raise ValueError("order_by must be string or list") - - if limit is not None: - sql.append(f'LIMIT {limit}') - - if offset is not None: - sql.append(f'OFFSET {offset}') - - return ' '.join(sql), values - - -def interpolate(sql, values): - for k in sorted(values.keys(), reverse=True): - value = values[k] - if isinstance(value, bytes): - value = f"X'{hexlify(value).decode()}'" - elif isinstance(value, str): - value = f"'{value}'" - else: - value = str(value) - sql = sql.replace(f":{k}", value) - return sql - - -def constrain_single_or_list(constraints, column, value, convert=lambda x: x): - if value is not None: - if isinstance(value, list): - value = [convert(v) for v in value] - if len(value) == 1: - constraints[column] = value[0] - elif len(value) > 1: - constraints[f"{column}__in"] = value - else: - constraints[column] = convert(value) - return constraints - - -class SQLiteMixin: - - SCHEMA_VERSION: Optional[str] = None - CREATE_TABLES_QUERY: str - MAX_QUERY_VARIABLES = 900 - - CREATE_VERSION_TABLE = """ - create table if not exists version ( - version text - ); - """ - - def __init__(self, path): - self._db_path = path - self.db: AIOSQLite = None - self.ledger = None - - async def open(self): - log.info("connecting to database: %s", self._db_path) - self.db = await AIOSQLite.connect(self._db_path, isolation_level=None) - if self.SCHEMA_VERSION: - tables = [t[0] for t in await self.db.execute_fetchall( - "SELECT name FROM sqlite_master WHERE type='table';" - )] - if tables: - if 'version' in tables: - version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;") - if version == (self.SCHEMA_VERSION,): - return - await self.db.executescript('\n'.join( - f"DROP TABLE {table};" for table in tables - )) - await self.db.execute(self.CREATE_VERSION_TABLE) - await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,)) - await self.db.executescript(self.CREATE_TABLES_QUERY) - - async def close(self): - await self.db.close() - - @staticmethod - def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False, - replace: bool = False) -> Tuple[str, List]: - columns, values = [], [] - for column, value in data.items(): - columns.append(column) - values.append(value) - policy = "" - if ignore_duplicate: - policy = " OR IGNORE" - if replace: - policy = " OR REPLACE" - sql = "INSERT{} INTO {} ({}) VALUES ({})".format( - policy, table, ', '.join(columns), ', '.join(['?'] * len(values)) - ) - return sql, values - - @staticmethod - def _update_sql(table: str, data: dict, where: str, - constraints: Union[list, tuple]) -> Tuple[str, list]: - columns, values = [], [] - for column, value in data.items(): - columns.append(f"{column} = ?") - values.append(value) - values.extend(constraints) - sql = "UPDATE {} SET {} WHERE {}".format( - table, ', '.join(columns), where - ) - return sql, values - - -def dict_row_factory(cursor, row): - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d diff --git a/lbry/wallet/manager.py b/lbry/wallet/manager.py index 965365fef..69818b633 100644 --- a/lbry/wallet/manager.py +++ b/lbry/wallet/manager.py @@ -1,22 +1,30 @@ import os +import stat +import json import asyncio import logging from typing import Optional, Dict from lbry.db import Database -from lbry.blockchain.ledger import Ledger from .wallet import Wallet +from .account import SingleKey, HierarchicalDeterministic log = logging.getLogger(__name__) class WalletManager: - def __init__(self, ledger: Ledger, db: Database): - self.ledger = ledger + def __init__(self, db: Database): self.db = db + self.ledger = db.ledger self.wallets: Dict[str, Wallet] = {} + if self.ledger.conf.wallet_storage == "file": + self.storage = FileWallet(self.db, self.ledger.conf.wallet_dir) + elif self.ledger.conf.wallet_storage == "database": + self.storage = DatabaseWallet(self.db) + else: + raise Exception(f"Unknown wallet storage format: {self.ledger.conf.wallet_storage}") def __getitem__(self, wallet_id: str) -> Wallet: try: @@ -43,32 +51,12 @@ class WalletManager: raise ValueError("Cannot spend funds with locked wallet, unlock first.") return wallet - @property - def path(self): - return os.path.join(self.ledger.conf.wallet_dir, 'wallets') - - def sync_ensure_path_exists(self): - if not os.path.exists(self.path): - os.mkdir(self.path) - - async def ensure_path_exists(self): - await asyncio.get_running_loop().run_in_executor( - None, self.sync_ensure_path_exists - ) - - async def load(self): - wallets_directory = self.path + async def initialize(self): for wallet_id in self.ledger.conf.wallets: if wallet_id in self.wallets: log.warning("Ignoring duplicate wallet_id in config: %s", wallet_id) continue - wallet_path = os.path.join(wallets_directory, wallet_id) - if not os.path.exists(wallet_path): - if not wallet_id == "default_wallet": # we'll probably generate this wallet, don't show error - log.error("Could not load wallet, file does not exist: %s", wallet_path) - continue - wallet = await Wallet.from_path(self.ledger, self.db, wallet_path) - self.add(wallet) + await self.load(wallet_id) default_wallet = self.default if default_wallet is None: if self.ledger.conf.create_default_wallet: @@ -83,34 +71,153 @@ class WalletManager: elif not default_wallet.has_accounts and self.ledger.conf.create_default_account: await default_wallet.accounts.generate() - def add(self, wallet: Wallet) -> Wallet: - self.wallets[wallet.id] = wallet - return wallet - - async def add_from_path(self, wallet_path) -> Wallet: - wallet_id = os.path.basename(wallet_path) - if wallet_id in self.wallets: - existing = self.wallets.get(wallet_id) - if existing.storage.path == wallet_path: - raise Exception(f"Wallet '{wallet_id}' is already loaded.") - raise Exception( - f"Wallet '{wallet_id}' is already loaded from '{existing.storage.path}'" - f" and cannot be loaded from '{wallet_path}'. Consider changing the wallet" - f" filename to be unique in order to avoid conflicts." - ) - wallet = await Wallet.from_path(self.ledger, self.db, wallet_path) - return self.add(wallet) + async def load(self, wallet_id: str) -> Optional[Wallet]: + wallet = await self.storage.get(wallet_id) + if wallet is not None: + return self.add(wallet) async def create( - self, wallet_id: str, name: str, - create_account=False, language='en', single_key=False) -> Wallet: + self, wallet_id: str, name: str = "", + create_account=False, language="en", single_key=False + ) -> Wallet: if wallet_id in self.wallets: raise Exception(f"Wallet with id '{wallet_id}' is already loaded and cannot be created.") - wallet_path = os.path.join(self.path, wallet_id) - if os.path.exists(wallet_path): - raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.") - wallet = await Wallet.create( - self.ledger, self.db, wallet_path, name, - create_account, language, single_key - ) + if await self.storage.exists(wallet_id): + raise Exception(f"Wallet '{wallet_id}' already exists, use 'wallet_add' to load wallet.") + wallet = Wallet(wallet_id, self.db, name) + if create_account: + await wallet.accounts.generate(language=language, address_generator={ + 'name': SingleKey.name if single_key else HierarchicalDeterministic.name + }) + await self.storage.save(wallet) return self.add(wallet) + + def add(self, wallet: Wallet) -> Wallet: + self.wallets[wallet.id] = wallet + wallet.on_change.listen(lambda _: self.storage.save(wallet)) + return wallet + + async def _report_state(self): + try: + for account in self.accounts: + balance = dewies_to_lbc(await account.get_balance(include_claims=True)) + _, channel_count = await account.get_channels(limit=1) + claim_count = await account.get_claim_count() + if isinstance(account.receiving, SingleKey): + log.info("Loaded single key account %s with %s LBC. " + "%d channels, %d certificates and %d claims", + account.id, balance, channel_count, len(account.channel_keys), claim_count) + else: + total_receiving = len(await account.receiving.get_addresses()) + total_change = len(await account.change.get_addresses()) + log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), " + "%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ", + account.id, balance, total_receiving, account.receiving.gap, total_change, + account.change.gap, channel_count, len(account.channel_keys), claim_count) + except Exception as err: + if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8 + raise + log.exception( + 'Failed to display wallet state, please file issue ' + 'for this bug along with the traceback you see below:' + ) + + +class WalletStorage: + + async def prepare(self): + raise NotImplementedError + + async def exists(self, walllet_id: str) -> bool: + raise NotImplementedError + + async def get(self, wallet_id: str) -> Wallet: + raise NotImplementedError + + async def save(self, wallet: Wallet): + raise NotImplementedError + + +class FileWallet(WalletStorage): + + def __init__(self, db, wallet_dir): + self.db = db + self.wallet_dir = wallet_dir + + def get_wallet_path(self, wallet_id: str): + return os.path.join(self.wallet_dir, wallet_id) + + async def prepare(self): + await asyncio.get_running_loop().run_in_executor( + None, self.sync_ensure_wallets_directory_exists + ) + + def sync_ensure_wallets_directory_exists(self): + if not os.path.exists(self.wallet_dir): + os.mkdir(self.wallet_dir) + + async def exists(self, wallet_id: str) -> bool: + return os.path.exists(self.get_wallet_path(wallet_id)) + + async def get(self, wallet_id: str) -> Wallet: + wallet_dict = await asyncio.get_running_loop().run_in_executor( + None, self.sync_read, wallet_id + ) + if wallet_dict is not None: + return await Wallet.from_dict(wallet_id, wallet_dict, self.db) + + def sync_read(self, wallet_id): + try: + with open(self.get_wallet_path(wallet_id), 'r') as f: + json_data = f.read() + return json.loads(json_data) + except FileNotFoundError: + return None + + async def save(self, wallet: Wallet): + return await asyncio.get_running_loop().run_in_executor( + None, self.sync_write, wallet + ) + + def sync_write(self, wallet: Wallet): + temp_path = os.path.join(self.wallet_dir, f".tmp.{os.path.basename(wallet.id)}") + with open(temp_path, "w") as f: + f.write(wallet.to_serialized()) + f.flush() + os.fsync(f.fileno()) + + wallet_path = self.get_wallet_path(wallet.id) + if os.path.exists(wallet_path): + mode = os.stat(wallet_path).st_mode + else: + mode = stat.S_IREAD | stat.S_IWRITE + try: + os.rename(temp_path, wallet_path) + except Exception: # pylint: disable=broad-except + os.remove(wallet_path) + os.rename(temp_path, wallet_path) + os.chmod(wallet_path, mode) + + +class DatabaseWallet(WalletStorage): + + def __init__(self, db: 'Database'): + self.db = db + + async def prepare(self): + pass + + async def exists(self, wallet_id: str) -> bool: + return await self.db.has_wallet(wallet_id) + + async def get(self, wallet_id: str) -> Wallet: + data = await self.db.get_wallet(wallet_id) + if data: + wallet_dict = json.loads(data['data']) + if wallet_dict is not None: + return await Wallet.from_dict(wallet_id, wallet_dict, self.db) + + async def save(self, wallet: Wallet): + await self.db.add_wallet( + wallet.id, wallet.to_serialized() + ) diff --git a/lbry/wallet/storage.py b/lbry/wallet/storage.py deleted file mode 100644 index 3859bb204..000000000 --- a/lbry/wallet/storage.py +++ /dev/null @@ -1,60 +0,0 @@ -import os -import stat -import json -import asyncio - - -class WalletStorage: - VERSION = 1 - - def __init__(self, path=None): - self.path = path - - def sync_read(self): - with open(self.path, 'r') as f: - json_data = f.read() - json_dict = json.loads(json_data) - if json_dict.get('version') == self.VERSION: - return json_dict - else: - return self.upgrade(json_dict) - - async def read(self): - return await asyncio.get_running_loop().run_in_executor( - None, self.sync_read - ) - - def upgrade(self, json_dict): - version = json_dict.pop('version', -1) - if version == -1: - pass - json_dict['version'] = self.VERSION - return json_dict - - def sync_write(self, json_dict): - - json_data = json.dumps(json_dict, indent=4, sort_keys=True) - if self.path is None: - return json_data - - temp_path = "{}.tmp.{}".format(self.path, os.getpid()) - with open(temp_path, "w") as f: - f.write(json_data) - f.flush() - os.fsync(f.fileno()) - - if os.path.exists(self.path): - mode = os.stat(self.path).st_mode - else: - mode = stat.S_IREAD | stat.S_IWRITE - try: - os.rename(temp_path, self.path) - except Exception: # pylint: disable=broad-except - os.remove(self.path) - os.rename(temp_path, self.path) - os.chmod(self.path, mode) - - async def write(self, json_dict): - return await asyncio.get_running_loop().run_in_executor( - None, self.sync_write, json_dict - ) diff --git a/lbry/wallet/sync.py b/lbry/wallet/sync.py deleted file mode 100644 index 991d597de..000000000 --- a/lbry/wallet/sync.py +++ /dev/null @@ -1,420 +0,0 @@ -import asyncio -import logging -#from io import StringIO -#from functools import partial -#from operator import itemgetter -from collections import defaultdict -#from binascii import hexlify, unhexlify -from typing import List, Optional, DefaultDict, NamedTuple - -#from lbry.crypto.hash import double_sha256, sha256 - -from lbry.tasks import TaskGroup -from lbry.blockchain.transaction import Transaction -#from lbry.blockchain.block import get_block_filter -from lbry.event import EventController -from lbry.service.base import Service, Sync - -from .account import AddressManager - - -log = logging.getLogger(__name__) - - -class TransactionEvent(NamedTuple): - address: str - tx: Transaction - - -class AddressesGeneratedEvent(NamedTuple): - address_manager: AddressManager - addresses: List[str] - - -class TransactionCacheItem: - __slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications' - - def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None): - self.has_tx = asyncio.Event() - self.lock = lock or asyncio.Lock() - self._tx = self.tx = tx - self.pending_verifications = 0 - - @property - def tx(self) -> Optional[Transaction]: - return self._tx - - @tx.setter - def tx(self, tx: Transaction): - self._tx = tx - if tx is not None: - self.has_tx.set() - - -class SPVSync(Sync): - - def __init__(self, service: Service): - super().__init__(service.ledger, service.db) - - self.accounts = [] - - self._on_header_controller = EventController() - self.on_header = self._on_header_controller.stream - self.on_header.listen( - lambda change: log.info( - '%s: added %s header blocks', - self.ledger.get_id(), change - ) - ) - self._download_height = 0 - - self._on_ready_controller = EventController() - self.on_ready = self._on_ready_controller.stream - - #self._tx_cache = pylru.lrucache(100000) - self._update_tasks = TaskGroup() - self._other_tasks = TaskGroup() # that we dont need to start - self._header_processing_lock = asyncio.Lock() - self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock) - self._known_addresses_out_of_sync = set() - - # async def advance(self): - # address_array = [ - # bytearray(a['address'].encode()) - # for a in await self.service.db.get_all_addresses() - # ] - # block_filters = await self.service.get_block_address_filters() - # for block_hash, block_filter in block_filters.items(): - # bf = get_block_filter(block_filter) - # if bf.MatchAny(address_array): - # print(f'match: {block_hash} - {block_filter}') - # tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash) - # for txid, tx_filter in tx_filters.items(): - # tf = get_block_filter(tx_filter) - # if tf.MatchAny(address_array): - # print(f' match: {txid} - {tx_filter}') - # txs = await self.service.search_transactions([txid]) - # tx = Transaction(unhexlify(txs[txid])) - # await self.service.db.insert_transaction(tx) - # - # async def get_local_status_and_history(self, address, history=None): - # if not history: - # address_details = await self.db.get_address(address=address) - # history = (address_details['history'] if address_details else '') or '' - # parts = history.split(':')[:-1] - # return ( - # hexlify(sha256(history.encode())).decode() if history else None, - # list(zip(parts[0::2], map(int, parts[1::2]))) - # ) - # - # @staticmethod - # def get_root_of_merkle_tree(branches, branch_positions, working_branch): - # for i, branch in enumerate(branches): - # other_branch = unhexlify(branch)[::-1] - # other_branch_on_left = bool((branch_positions >> i) & 1) - # if other_branch_on_left: - # combined = other_branch + working_branch - # else: - # combined = working_branch + other_branch - # working_branch = double_sha256(combined) - # return hexlify(working_branch[::-1]) - # - async def start(self): - fully_synced = self.on_ready.first - #asyncio.create_task(self.network.start()) - #await self.network.on_connected.first - #async with self._header_processing_lock: - # await self._update_tasks.add(self.initial_headers_sync()) - await fully_synced - # - # async def join_network(self, *_): - # log.info("Subscribing and updating accounts.") - # await self._update_tasks.add(self.subscribe_accounts()) - # await self._update_tasks.done.wait() - # self._on_ready_controller.add(True) - # - async def stop(self): - self._update_tasks.cancel() - self._other_tasks.cancel() - await self._update_tasks.done.wait() - await self._other_tasks.done.wait() - # - # @property - # def local_height_including_downloaded_height(self): - # return max(self.headers.height, self._download_height) - # - # async def initial_headers_sync(self): - # get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True) - # self.headers.chunk_getter = get_chunk - # - # async def doit(): - # for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)): - # async with self._header_processing_lock: - # await self.headers.ensure_chunk_at(height) - # self._other_tasks.add(doit()) - # await self.update_headers() - # - # async def update_headers(self, height=None, headers=None, subscription_update=False): - # rewound = 0 - # while True: - # - # if height is None or height > len(self.headers): - # # sometimes header subscription updates are for a header in the future - # # which can't be connected, so we do a normal header sync instead - # height = len(self.headers) - # headers = None - # subscription_update = False - # - # if not headers: - # header_response = await self.network.retriable_call(self.network.get_headers, height, 2001) - # headers = header_response['hex'] - # - # if not headers: - # # Nothing to do, network thinks we're already at the latest height. - # return - # - # added = await self.headers.connect(height, unhexlify(headers)) - # if added > 0: - # height += added - # self._on_header_controller.add( - # BlockHeightEvent(self.headers.height, added)) - # - # if rewound > 0: - # # we started rewinding blocks and apparently found - # # a new chain - # rewound = 0 - # await self.db.rewind_blockchain(height) - # - # if subscription_update: - # # subscription updates are for latest header already - # # so we don't need to check if there are newer / more - # # on another loop of update_headers(), just return instead - # return - # - # elif added == 0: - # # we had headers to connect but none got connected, probably a reorganization - # height -= 1 - # rewound += 1 - # log.warning( - # "Blockchain Reorganization: attempting rewind to height %s from starting height %s", - # height, height+rewound - # ) - # - # else: - # raise IndexError(f"headers.connect() returned negative number ({added})") - # - # if height < 0: - # raise IndexError( - # "Blockchain reorganization rewound all the way back to genesis hash. " - # "Something is very wrong. Maybe you are on the wrong blockchain?" - # ) - # - # if rewound >= 100: - # raise IndexError( - # "Blockchain reorganization dropped {} headers. This is highly unusual. " - # "Will not continue to attempt reorganizing. Please, delete the ledger " - # "synchronization directory inside your wallet directory (folder: '{}') and " - # "restart the program to synchronize from scratch." - # .format(rewound, self.ledger.get_id()) - # ) - # - # headers = None # ready to download some more headers - # - # # if we made it this far and this was a subscription_update - # # it means something went wrong and now we're doing a more - # # robust sync, turn off subscription update shortcut - # subscription_update = False - # - # async def receive_header(self, response): - # async with self._header_processing_lock: - # header = response[0] - # await self.update_headers( - # height=header['height'], headers=header['hex'], subscription_update=True - # ) - # - # async def subscribe_accounts(self): - # if self.network.is_connected and self.accounts: - # log.info("Subscribe to %i accounts", len(self.accounts)) - # await asyncio.wait([ - # self.subscribe_account(a) for a in self.accounts - # ]) - # - # async def subscribe_account(self, account: Account): - # for address_manager in account.address_managers.values(): - # await self.subscribe_addresses(address_manager, await address_manager.get_addresses()) - # await account.ensure_address_gap() - # - # async def unsubscribe_account(self, account: Account): - # for address in await account.get_addresses(): - # await self.network.unsubscribe_address(address) - # - # async def subscribe_addresses( - # self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000): - # if self.network.is_connected and addresses: - # addresses_remaining = list(addresses) - # while addresses_remaining: - # batch = addresses_remaining[:batch_size] - # results = await self.network.subscribe_address(*batch) - # for address, remote_status in zip(batch, results): - # self._update_tasks.add(self.update_history(address, remote_status, address_manager)) - # addresses_remaining = addresses_remaining[batch_size:] - # log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining), - # len(addresses), *self.network.client.server_address_and_port) - # log.info( - # "finished subscribing to %i addresses on %s:%i", len(addresses), - # *self.network.client.server_address_and_port - # ) - # - # def process_status_update(self, update): - # address, remote_status = update - # self._update_tasks.add(self.update_history(address, remote_status)) - # - # async def update_history(self, address, remote_status, address_manager: AddressManager = None): - # async with self._address_update_locks[address]: - # self._known_addresses_out_of_sync.discard(address) - # - # local_status, local_history = await self.get_local_status_and_history(address) - # - # if local_status == remote_status: - # return True - # - # remote_history = await self.network.retriable_call(self.network.get_history, address) - # remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) - # we_need = set(remote_history) - set(local_history) - # if not we_need: - # return True - # - # cache_tasks: List[asyncio.Task[Transaction]] = [] - # synced_history = StringIO() - # loop = asyncio.get_running_loop() - # for i, (txid, remote_height) in enumerate(remote_history): - # if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks: - # synced_history.write(f'{txid}:{remote_height}:') - # else: - # check_local = (txid, remote_height) not in we_need - # cache_tasks.append(loop.create_task( - # self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local) - # )) - # - # synced_txs = [] - # for task in cache_tasks: - # tx = await task - # - # check_db_for_txos = [] - # for txi in tx.inputs: - # if txi.txo_ref.txo is not None: - # continue - # cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash) - # if cache_item is not None: - # if cache_item.tx is None: - # await cache_item.has_tx.wait() - # assert cache_item.tx is not None - # txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref - # else: - # check_db_for_txos.append(txi.txo_ref.hash) - # - # referenced_txos = {} if not check_db_for_txos else { - # txo.id: txo for txo in await self.db.get_txos( - # txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True - # ) - # } - # - # for txi in tx.inputs: - # if txi.txo_ref.txo is not None: - # continue - # referenced_txo = referenced_txos.get(txi.txo_ref.id) - # if referenced_txo is not None: - # txi.txo_ref = referenced_txo.ref - # - # synced_history.write(f'{tx.id}:{tx.height}:') - # synced_txs.append(tx) - # - # await self.db.save_transaction_io_batch( - # synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue() - # ) - # await asyncio.wait([ - # self.ledger._on_transaction_controller.add(TransactionEvent(address, tx)) - # for tx in synced_txs - # ]) - # - # if address_manager is None: - # address_manager = await self.get_address_manager_for_address(address) - # - # if address_manager is not None: - # await address_manager.ensure_address_gap() - # - # local_status, local_history = \ - # await self.get_local_status_and_history(address, synced_history.getvalue()) - # if local_status != remote_status: - # if local_history == remote_history: - # return True - # log.warning( - # "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items", - # remote_status, len(remote_history), local_status, len(local_history) - # ) - # log.warning("local: %s", local_history) - # log.warning("remote: %s", remote_history) - # self._known_addresses_out_of_sync.add(address) - # return False - # else: - # return True - # - # async def cache_transaction(self, tx_hash, remote_height, check_local=True): - # cache_item = self._tx_cache.get(tx_hash) - # if cache_item is None: - # cache_item = self._tx_cache[tx_hash] = TransactionCacheItem() - # elif cache_item.tx is not None and \ - # cache_item.tx.height >= remote_height and \ - # (cache_item.tx.is_verified or remote_height < 1): - # return cache_item.tx # cached tx is already up-to-date - # - # try: - # cache_item.pending_verifications += 1 - # return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local) - # finally: - # cache_item.pending_verifications -= 1 - # - # async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True): - # - # async with cache_item.lock: - # - # tx = cache_item.tx - # - # if tx is None and check_local: - # # check local db - # tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash) - # - # merkle = None - # if tx is None: - # # fetch from network - # _raw, merkle = await self.network.retriable_call( - # self.network.get_transaction_and_merkle, tx_hash, remote_height - # ) - # tx = Transaction(unhexlify(_raw), height=merkle.get('block_height')) - # cache_item.tx = tx # make sure it's saved before caching it - # await self.maybe_verify_transaction(tx, remote_height, merkle) - # return tx - # - # async def maybe_verify_transaction(self, tx, remote_height, merkle=None): - # tx.height = remote_height - # cached = self._tx_cache.get(tx.hash) - # if not cached: - # # cache txs looked up by transaction_show too - # cached = TransactionCacheItem() - # cached.tx = tx - # self._tx_cache[tx.hash] = cached - # if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1: - # # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case - # if not merkle: - # merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height) - # merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) - # header = await self.headers.get(remote_height) - # tx.position = merkle['pos'] - # tx.is_verified = merkle_root == header['merkle_root'] - # - # async def get_address_manager_for_address(self, address) -> Optional[AddressManager]: - # details = await self.db.get_address(address=address) - # for account in self.accounts: - # if account.id == details['account']: - # return account.address_managers[details['chain']] - # return None diff --git a/lbry/wallet/wallet.py b/lbry/wallet/wallet.py index 9b1cb29fb..f53717e4e 100644 --- a/lbry/wallet/wallet.py +++ b/lbry/wallet/wallet.py @@ -1,5 +1,4 @@ # pylint: disable=arguments-differ -import os import json import zlib import asyncio @@ -10,9 +9,8 @@ from hashlib import sha256 from operator import attrgetter from decimal import Decimal - from lbry.db import Database, SPENDABLE_TYPE_CODES, Result -from lbry.blockchain.ledger import Ledger +from lbry.event import EventController from lbry.constants import COIN, NULL_HASH32 from lbry.blockchain.transaction import Transaction, Input, Output from lbry.blockchain.dewies import dewies_to_lbc @@ -23,9 +21,8 @@ from lbry.schema.purchase import Purchase from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError from lbry.stream.managed_stream import ManagedStream -from .account import Account, SingleKey, HierarchicalDeterministic +from .account import Account from .coinselection import CoinSelector, OutputEffectiveAmountEstimator -from .storage import WalletStorage from .preferences import TimestampedPreferences @@ -37,20 +34,22 @@ ENCRYPT_ON_DISK = 'encrypt-on-disk' class Wallet: """ The primary role of Wallet is to encapsulate a collection of accounts (seed/private keys) and the spending rules / settings - for the coins attached to those accounts. Wallets are represented - by physical files on the filesystem. + for the coins attached to those accounts. """ - def __init__(self, ledger: Ledger, db: Database, name: str, storage: WalletStorage, preferences: dict): - self.ledger = ledger + VERSION = 1 + + def __init__(self, wallet_id: str, db: Database, name: str = "", preferences: dict = None): + self.id = wallet_id self.db = db self.name = name - self.storage = storage + self.ledger = db.ledger self.preferences = TimestampedPreferences(preferences or {}) self.encryption_password: Optional[str] = None - self.id = self.get_id() self.utxo_lock = asyncio.Lock() + self._on_change_controller = EventController() + self.on_change = self._on_change_controller.stream self.accounts = AccountListManager(self) self.claims = ClaimListManager(self) @@ -60,61 +59,55 @@ class Wallet: self.purchases = PurchaseListManager(self) self.supports = SupportListManager(self) - def get_id(self): - return os.path.basename(self.storage.path) if self.storage.path else self.name + async def notify_change(self, field: str, value=None): + await self._on_change_controller.add({ + 'field': field, 'value': value + }) @classmethod - async def create( - cls, ledger: Ledger, db: Database, path: str, name: str, - create_account=False, language='en', single_key=False): - wallet = cls(ledger, db, name, WalletStorage(path), {}) - if create_account: - await wallet.accounts.generate(language=language, address_generator={ - 'name': SingleKey.name if single_key else HierarchicalDeterministic.name - }) - await wallet.save() - return wallet - - @classmethod - async def from_path(cls, ledger: Ledger, db: Database, path: str): - return await cls.from_storage(ledger, db, WalletStorage(path)) - - @classmethod - async def from_storage(cls, ledger: Ledger, db: Database, storage: WalletStorage) -> 'Wallet': - json_dict = await storage.read() - if 'ledger' in json_dict and json_dict['ledger'] != ledger.get_id(): + async def from_dict(cls, wallet_id: str, wallet_dict, db: Database) -> 'Wallet': + if 'ledger' in wallet_dict and wallet_dict['ledger'] != db.ledger.get_id(): raise ValueError( - f"Using ledger {ledger.get_id()} but wallet is {json_dict['ledger']}." + f"Using ledger {db.ledger.get_id()} but wallet is {wallet_dict['ledger']}." ) + version = wallet_dict.get('version') + if version == 1: + pass wallet = cls( - ledger, db, - name=json_dict.get('name', 'Wallet'), - storage=storage, - preferences=json_dict.get('preferences', {}), + wallet_id, db, + name=wallet_dict.get('name', 'Wallet'), + preferences=wallet_dict.get('preferences', {}), ) - for account_dict in json_dict.get('accounts', []): + for account_dict in wallet_dict.get('accounts', []): await wallet.accounts.add_from_dict(account_dict) return wallet - def to_dict(self, encrypt_password: str = None): + def to_dict(self, encrypt_password: str = None) -> dict: return { - 'version': WalletStorage.VERSION, + 'version': self.VERSION, 'ledger': self.ledger.get_id(), 'name': self.name, 'preferences': self.preferences.data, 'accounts': [a.to_dict(encrypt_password) for a in self.accounts] } - async def save(self): + @classmethod + async def from_serialized(cls, wallet_id: str, json_data: str, db: Database) -> 'Wallet': + return await cls.from_dict(wallet_id, json.loads(json_data), db) + + def to_serialized(self) -> str: + wallet_dict = None if self.preferences.get(ENCRYPT_ON_DISK, False): if self.encryption_password is not None: - return await self.storage.write(self.to_dict(encrypt_password=self.encryption_password)) + wallet_dict = self.to_dict(encrypt_password=self.encryption_password) elif not self.is_locked: log.warning( "Disk encryption requested but no password available for encryption. " "Saving wallet in an unencrypted state." ) - return await self.storage.write(self.to_dict()) + if wallet_dict is None: + wallet_dict = self.to_dict() + return json.dumps(wallet_dict, indent=4, sort_keys=True) @property def hash(self) -> bytes: @@ -157,8 +150,9 @@ class Wallet: local_match.merge(account_dict) else: added_accounts.append( - self.accounts.add_from_dict(account_dict) + await self.accounts.add_from_dict(account_dict, notify=False) ) + await self.notify_change('wallet.merge') return added_accounts @property @@ -190,7 +184,7 @@ class Wallet: async def decrypt(self): assert not self.is_locked, "Cannot decrypt a locked wallet, unlock first." self.preferences[ENCRYPT_ON_DISK] = False - await self.save() + await self.notify_change(ENCRYPT_ON_DISK, False) return True async def encrypt(self, password): @@ -198,7 +192,7 @@ class Wallet: assert password, "Cannot encrypt with blank password." self.encryption_password = password self.preferences[ENCRYPT_ON_DISK] = True - await self.save() + await self.notify_change(ENCRYPT_ON_DISK, True) return True @property @@ -237,7 +231,7 @@ class Wallet: if await account.save_max_gap(): gap_changed = True if gap_changed: - await self.save() + await self.notify_change('address-max-gap') async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]): estimators = [] @@ -379,30 +373,6 @@ class Wallet: return tx - async def _report_state(self): - try: - for account in self.accounts: - balance = dewies_to_lbc(await account.get_balance(include_claims=True)) - _, channel_count = await account.get_channels(limit=1) - claim_count = await account.get_claim_count() - if isinstance(account.receiving, SingleKey): - log.info("Loaded single key account %s with %s LBC. " - "%d channels, %d certificates and %d claims", - account.id, balance, channel_count, len(account.channel_keys), claim_count) - else: - total_receiving = len(await account.receiving.get_addresses()) - total_change = len(await account.change.get_addresses()) - log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), " - "%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ", - account.id, balance, total_receiving, account.receiving.gap, total_change, - account.change.gap, channel_count, len(account.channel_keys), claim_count) - except Exception as err: - if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8 - raise - log.exception( - 'Failed to display wallet state, please file issue ' - 'for this bug along with the traceback you see below:') - async def verify_duplicate(self, name: str, allow_duplicate: bool): if not allow_duplicate: claims = await self.claims.list(claim_name=name) @@ -438,21 +408,22 @@ class AccountListManager: return account async def generate(self, name: str = None, language: str = 'en', address_generator: dict = None) -> Account: - account = await Account.generate( - self.wallet.ledger, self.wallet.db, name, language, address_generator - ) + account = await Account.generate(self.wallet.db, name, language, address_generator) self._accounts.append(account) + await self.wallet.notify_change('account.added') return account - async def add_from_dict(self, account_dict: dict) -> Account: - account = await Account.from_dict(self.wallet.ledger, self.wallet.db, account_dict) + async def add_from_dict(self, account_dict: dict, notify=True) -> Account: + account = await Account.from_dict(self.wallet.db, account_dict) self._accounts.append(account) + if notify: + await self.wallet.notify_change('account.added') return account async def remove(self, account_id: str) -> Account: account = self[account_id] self._accounts.remove(account) - await self.wallet.save() + await self.wallet.notify_change('account.removed') return account def get_or_none(self, account_id: str) -> Optional[Account]: @@ -608,7 +579,7 @@ class ChannelListManager(ClaimListManager): if save_key: holding_account.add_channel_private_key(txo.private_key) - await self.wallet.save() + await self.wallet.notify_change('channel.added') return tx @@ -652,7 +623,7 @@ class ChannelListManager(ClaimListManager): if any((new_signing_key, moving_accounts)) and save_key: holding_account.add_channel_private_key(txo.private_key) - await self.wallet.save() + await self.wallet.notify_change('channel.added') return tx diff --git a/tests/unit/wallet/__init__.py b/tests/unit/wallet/__init__.py index efc3df099..e69de29bb 100644 --- a/tests/unit/wallet/__init__.py +++ b/tests/unit/wallet/__init__.py @@ -1,2 +0,0 @@ -from unittest import SkipTest -raise SkipTest("WIP") diff --git a/tests/unit/wallet/test_account.py b/tests/unit/wallet/test_account.py index d754071f3..4c0fc1c71 100644 --- a/tests/unit/wallet/test_account.py +++ b/tests/unit/wallet/test_account.py @@ -13,7 +13,7 @@ class AccountTestCase(AsyncioTestCase): self.addCleanup(self.db.close) async def update_addressed_used(self, address, used): - await self.db.execute( + await self.db.execute_sql_object( tables.PubkeyAddress.update() .where(tables.PubkeyAddress.c.address == address) .values(used_times=used) @@ -23,7 +23,7 @@ class AccountTestCase(AsyncioTestCase): class TestHierarchicalDeterministicAccount(AccountTestCase): async def test_generate_account(self): - account = await Account.generate(self.ledger, self.db) + account = await Account.generate(self.db) self.assertEqual(account.ledger, self.ledger) self.assertEqual(account.db, self.db) self.assertEqual(account.name, f'Account #{account.public_key.address}') @@ -36,18 +36,19 @@ class TestHierarchicalDeterministicAccount(AccountTestCase): self.assertIsInstance(account.receiving, HierarchicalDeterministic) self.assertIsInstance(account.change, HierarchicalDeterministic) - async def test_ensure_address_gap(self): - account = await Account.generate(self.ledger, self.db, 'lbryum') self.assertEqual(len(await account.receiving.get_addresses()), 0) self.assertEqual(len(await account.change.get_addresses()), 0) await account.ensure_address_gap() self.assertEqual(len(await account.receiving.get_addresses()), 20) self.assertEqual(len(await account.change.get_addresses()), 6) + async def test_ensure_address_gap(self): + account = await Account.generate(self.db) async with account.receiving.address_generator_lock: await account.receiving._generate_keys(4, 7) await account.receiving._generate_keys(0, 3) await account.receiving._generate_keys(8, 11) + records = await account.receiving.get_address_records() self.assertListEqual( [r['pubkey'].n for r in records], @@ -79,14 +80,14 @@ class TestHierarchicalDeterministicAccount(AccountTestCase): self.assertEqual(len(new_keys), 20) async def test_generate_keys_over_batch_threshold_saves_it_properly(self): - account = Account.generate(self.ledger, self.db, 'lbryum') + account = await Account.generate(self.db) async with account.receiving.address_generator_lock: await account.receiving._generate_keys(0, 200) records = await account.receiving.get_address_records() self.assertEqual(len(records), 201) async def test_get_or_create_usable_address(self): - account = Account.generate(self.ledger, self.db, 'lbryum') + account = await Account.generate(self.db) keys = await account.receiving.get_addresses() self.assertEqual(len(keys), 0) @@ -98,13 +99,11 @@ class TestHierarchicalDeterministicAccount(AccountTestCase): self.assertEqual(len(keys), 20) async def test_generate_account_from_seed(self): - account = await Account.from_dict( - self.ledger, self.db, { - "seed": - "carbon smart garage balance margin twelve chest sword toas" - "t envelope bottom stomach absent" - } - ) + account = await Account.from_dict(self.db, { + "seed": + "carbon smart garage balance margin twelve chest sword toas" + "t envelope bottom stomach absent" + }) self.assertEqual( account.private_key.extended_key_string(), 'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7DRNLEoB8' @@ -126,6 +125,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase): "carbon smart garage balance margin twelve chest sword toast envelope bottom stomac" "h absent", 'encrypted': False, + 'lang': 'en', 'private_key': 'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7DRNLEoB8' 'HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe', @@ -140,7 +140,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase): } } - account = Account.from_dict(self.ledger, self.db, account_data) + account = await Account.from_dict(self.db, account_data) await account.ensure_address_gap() @@ -150,7 +150,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase): self.assertEqual(len(addresses), 10) self.assertDictEqual(account_data, account.to_dict()) - def test_merge_diff(self): + async def test_merge_diff(self): account_data = { 'name': 'My Account', 'modified_on': 123.456, @@ -158,6 +158,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase): "carbon smart garage balance margin twelve chest sword toast envelope bottom stomac" "h absent", 'encrypted': False, + 'lang': 'en', 'private_key': 'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp' '5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna', @@ -170,7 +171,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase): 'change': {'gap': 5, 'maximum_uses_per_address': 2} } } - account = Account.from_dict(self.ledger, self.db, account_data) + account = await Account.from_dict(self.db, account_data) self.assertEqual(account.name, 'My Account') self.assertEqual(account.modified_on, 123.456) @@ -203,15 +204,15 @@ class TestSingleKeyAccount(AccountTestCase): async def asyncSetUp(self): await super().asyncSetUp() - self.account = Account.generate( - self.ledger, self.db, "torba", {'name': 'single-address'} + self.account = await Account.generate( + self.db, address_generator={"name": "single-address"} ) async def test_generate_account(self): account = self.account self.assertEqual(account.ledger, self.ledger) - self.assertIsNotNone(account.seed) + self.assertIsNotNone(account.phrase) self.assertEqual(account.public_key.ledger, self.ledger) self.assertEqual(account.private_key.public_key, account.public_key) @@ -246,7 +247,7 @@ class TestSingleKeyAccount(AccountTestCase): self.assertEqual(new_keys[0], account.public_key.address) records = await account.receiving.get_address_records() pubkey = records[0].pop('pubkey') - self.assertListEqual(records, [{ + self.assertEqual(records.rows, [{ 'chain': 0, 'account': account.public_key.address, 'address': account.public_key.address, @@ -294,6 +295,7 @@ class TestSingleKeyAccount(AccountTestCase): "carbon smart garage balance margin twelve chest sword toast envelope bottom stomac" "h absent", 'encrypted': False, + 'lang': 'en', 'private_key': 'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7' 'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe', 'public_key': 'xpub661MyMwAqRbcGWtPvbWh9sc2BCfw2cTeVDYF23o3N1t6UZ5wv3EM' @@ -302,7 +304,7 @@ class TestSingleKeyAccount(AccountTestCase): 'certificates': {} } - account = Account.from_dict(self.ledger, self.db, account_data) + account = await Account.from_dict(self.db, account_data) await account.ensure_address_gap() @@ -351,9 +353,9 @@ class AccountEncryptionTests(AccountTestCase): } async def test_encrypt_wallet(self): - account = await Account.from_dict(self.ledger, self.db, self.unencrypted_account) + account = await Account.from_dict(self.db, self.unencrypted_account) account.init_vectors = { - 'seed': self.init_vector, + 'phrase': self.init_vector, 'private_key': self.init_vector } @@ -361,7 +363,7 @@ class AccountEncryptionTests(AccountTestCase): self.assertIsNotNone(account.private_key) account.encrypt(self.password) self.assertTrue(account.encrypted) - self.assertEqual(account.seed, self.encrypted_account['seed']) + self.assertEqual(account.phrase, self.encrypted_account['seed']) self.assertEqual(account.private_key_string, self.encrypted_account['private_key']) self.assertIsNone(account.private_key) @@ -370,9 +372,9 @@ class AccountEncryptionTests(AccountTestCase): account.decrypt(self.password) self.assertEqual(account.init_vectors['private_key'], self.init_vector) - self.assertEqual(account.init_vectors['seed'], self.init_vector) + self.assertEqual(account.init_vectors['phrase'], self.init_vector) - self.assertEqual(account.seed, self.unencrypted_account['seed']) + self.assertEqual(account.phrase, self.unencrypted_account['seed']) self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key']) self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed']) @@ -381,16 +383,16 @@ class AccountEncryptionTests(AccountTestCase): self.assertFalse(account.encrypted) async def test_decrypt_wallet(self): - account = await Account.from_dict(self.ledger, self.db, self.encrypted_account) + account = await Account.from_dict(self.db, self.encrypted_account) self.assertTrue(account.encrypted) account.decrypt(self.password) self.assertEqual(account.init_vectors['private_key'], self.init_vector) - self.assertEqual(account.init_vectors['seed'], self.init_vector) + self.assertEqual(account.init_vectors['phrase'], self.init_vector) self.assertFalse(account.encrypted) - self.assertEqual(account.seed, self.unencrypted_account['seed']) + self.assertEqual(account.phrase, self.unencrypted_account['seed']) self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key']) self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed']) @@ -402,7 +404,7 @@ class AccountEncryptionTests(AccountTestCase): account_data = self.unencrypted_account.copy() del account_data['seed'] del account_data['private_key'] - account = await Account.from_dict(self.ledger, self.db, account_data) + account = await Account.from_dict(self.db, account_data) encrypted = account.to_dict('password') self.assertFalse(encrypted['seed']) self.assertFalse(encrypted['private_key']) diff --git a/tests/unit/wallet/test_database.py b/tests/unit/wallet/test_database.py deleted file mode 100644 index 3bd6b0660..000000000 --- a/tests/unit/wallet/test_database.py +++ /dev/null @@ -1,566 +0,0 @@ -import sys -import os -import unittest -import sqlite3 -import tempfile -import asyncio -from concurrent.futures.thread import ThreadPoolExecutor - -from sqlalchemy import Column, Text - -from lbry.wallet import ( - Wallet, Account, Ledger, Headers, Transaction, Input -) -from lbry.db import Table, Version, Database, metadata -from lbry.wallet.constants import COIN -from lbry.wallet.database import query, interpolate, constraints_to_sql, AIOSQLite -from lbry.crypto.hash import sha256 -from lbry.testcase import AsyncioTestCase - -from tests.unit.wallet.test_transaction import get_output, NULL_HASH - - -class TestAIOSQLite(AsyncioTestCase): - async def asyncSetUp(self): - self.db = await AIOSQLite.connect(':memory:') - await self.db.executescript(""" - pragma foreign_keys=on; - create table parent (id integer primary key, name); - create table child (id integer primary key, parent_id references parent); - """) - await self.db.execute("insert into parent values (1, 'test')") - await self.db.execute("insert into child values (2, 1)") - - @staticmethod - def delete_item(transaction): - transaction.execute('delete from parent where id=1') - - async def test_foreign_keys_integrity_error(self): - self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) - - with self.assertRaises(sqlite3.IntegrityError): - await self.db.run(self.delete_item) - self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) - - await self.db.executescript("pragma foreign_keys=off;") - - await self.db.run(self.delete_item) - self.assertListEqual([], await self.db.execute_fetchall("select * from parent")) - - async def test_run_without_foreign_keys(self): - self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) - await self.db.run_with_foreign_keys_disabled(self.delete_item) - self.assertListEqual([], await self.db.execute_fetchall("select * from parent")) - - async def test_integrity_error_when_foreign_keys_disabled_and_skipped(self): - await self.db.executescript("pragma foreign_keys=off;") - self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) - with self.assertRaises(sqlite3.IntegrityError): - await self.db.run_with_foreign_keys_disabled(self.delete_item) - self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent")) - - -class TestQueryBuilder(unittest.TestCase): - - def test_dot(self): - self.assertTupleEqual( - constraints_to_sql({'txo.position': 18}), - ('txo.position = :txo_position0', {'txo_position0': 18}) - ) - self.assertTupleEqual( - constraints_to_sql({'txo.position#6': 18}), - ('txo.position = :txo_position6', {'txo_position6': 18}) - ) - - def test_any(self): - self.assertTupleEqual( - constraints_to_sql({ - 'ages__any': { - 'txo.age__gt': 18, - 'txo.age__lt': 38 - } - }), - ('(txo.age > :ages__any0_txo_age__gt0 OR txo.age < :ages__any0_txo_age__lt0)', { - 'ages__any0_txo_age__gt0': 18, - 'ages__any0_txo_age__lt0': 38 - }) - ) - - def test_in(self): - self.assertTupleEqual( - constraints_to_sql({'txo.age__in#2': [18, 38]}), - ('txo.age IN (:txo_age__in2_0, :txo_age__in2_1)', { - 'txo_age__in2_0': 18, - 'txo_age__in2_1': 38 - }) - ) - self.assertTupleEqual( - constraints_to_sql({'txo.name__in': ('abc123', 'def456')}), - ('txo.name IN (:txo_name__in0_0, :txo_name__in0_1)', { - 'txo_name__in0_0': 'abc123', - 'txo_name__in0_1': 'def456' - }) - ) - self.assertTupleEqual( - constraints_to_sql({'txo.name__in': {'abc123'}}), - ('txo.name = :txo_name__in0', { - 'txo_name__in0': 'abc123', - }) - ) - self.assertTupleEqual( - constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}), - ('txo.age IN (SELECT age from ages_table)', {}) - ) - - def test_not_in(self): - self.assertTupleEqual( - constraints_to_sql({'txo.age__not_in': [18, 38]}), - ('txo.age NOT IN (:txo_age__not_in0_0, :txo_age__not_in0_1)', { - 'txo_age__not_in0_0': 18, - 'txo_age__not_in0_1': 38 - }) - ) - self.assertTupleEqual( - constraints_to_sql({'txo.name__not_in': ('abc123', 'def456')}), - ('txo.name NOT IN (:txo_name__not_in0_0, :txo_name__not_in0_1)', { - 'txo_name__not_in0_0': 'abc123', - 'txo_name__not_in0_1': 'def456' - }) - ) - self.assertTupleEqual( - constraints_to_sql({'txo.name__not_in': ('abc123',)}), - ('txo.name != :txo_name__not_in0', { - 'txo_name__not_in0': 'abc123', - }) - ) - self.assertTupleEqual( - constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}), - ('txo.age NOT IN (SELECT age from ages_table)', {}) - ) - - def test_in_invalid(self): - with self.assertRaisesRegex(ValueError, 'list, set or string'): - constraints_to_sql({'ages__in': 9}) - - def test_query(self): - self.assertTupleEqual( - query("select * from foo"), - ("select * from foo", {}) - ) - self.assertTupleEqual( - query( - "select * from foo", - a__not='b', b__in='select * from blah where c=:$c', - d__any={'one__like': 'o', 'two': 2}, limit=10, order_by='b', **{'$c': 3}), - ( - "select * from foo WHERE a != :a__not0 AND " - "b IN (select * from blah where c=:$c) AND " - "(one LIKE :d__any0_one__like0 OR two = :d__any0_two0) ORDER BY b LIMIT 10", - {'a__not0': 'b', 'd__any0_one__like0': 'o', 'd__any0_two0': 2, '$c': 3} - ) - ) - - def test_query_order_by(self): - self.assertTupleEqual( - query("select * from foo", order_by='foo'), - ("select * from foo ORDER BY foo", {}) - ) - self.assertTupleEqual( - query("select * from foo", order_by=['foo', 'bar']), - ("select * from foo ORDER BY foo, bar", {}) - ) - with self.assertRaisesRegex(ValueError, 'order_by must be string or list'): - query("select * from foo", order_by={'foo': 'bar'}) - - def test_query_limit_offset(self): - self.assertTupleEqual( - query("select * from foo", limit=10), - ("select * from foo LIMIT 10", {}) - ) - self.assertTupleEqual( - query("select * from foo", offset=10), - ("select * from foo OFFSET 10", {}) - ) - self.assertTupleEqual( - query("select * from foo", limit=20, offset=10), - ("select * from foo LIMIT 20 OFFSET 10", {}) - ) - - def test_query_interpolation(self): - self.maxDiff = None - # tests that interpolation replaces longer keys first - self.assertEqual( - interpolate(*query( - "select * from foo", - a__not='b', b__in='select * from blah where c=:$c', - d__any={'one__like': 'o', 'two': 2}, - a0=3, a00=1, a00a=2, a00aa=4, # <-- breaks without correct interpolation key order - ahash=sha256(b'hello world'), - limit=10, order_by='b', **{'$c': 3}) - ), - "select * from foo WHERE a != 'b' AND " - "b IN (select * from blah where c=3) AND " - "(one LIKE 'o' OR two = 2) AND " - "a0 = 3 AND a00 = 1 AND a00a = 2 AND a00aa = 4 " - "AND ahash = X'b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9' " - "ORDER BY b LIMIT 10" - ) - - -class TestQueries(AsyncioTestCase): - - async def asyncSetUp(self): - self.ledger = Ledger({ - 'db': Database('sqlite:///:memory:'), - 'headers': Headers(':memory:') - }) - self.wallet = Wallet() - await self.ledger.db.open() - - async def asyncTearDown(self): - await self.ledger.db.close() - - async def create_account(self, wallet=None): - account = Account.generate(self.ledger, wallet or self.wallet) - await account.ensure_address_gap() - return account - - async def create_tx_from_nothing(self, my_account, height): - to_address = await my_account.receiving.get_or_create_usable_address() - to_hash = Ledger.address_to_hash160(to_address) - tx = Transaction(height=height, is_verified=True) \ - .add_inputs([self.txi(self.txo(1, sha256(str(height).encode())))]) \ - .add_outputs([self.txo(1, to_hash)]) - await self.ledger.db.insert_transaction(tx) - await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '') - return tx - - async def create_tx_from_txo(self, txo, to_account, height): - from_hash = txo.script.values['pubkey_hash'] - from_address = self.ledger.pubkey_hash_to_address(from_hash) - to_address = await to_account.receiving.get_or_create_usable_address() - to_hash = Ledger.address_to_hash160(to_address) - tx = Transaction(height=height, is_verified=True) \ - .add_inputs([self.txi(txo)]) \ - .add_outputs([self.txo(1, to_hash)]) - await self.ledger.db.insert_transaction(tx) - await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '') - await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '') - return tx - - async def create_tx_to_nowhere(self, txo, height): - from_hash = txo.script.values['pubkey_hash'] - from_address = self.ledger.pubkey_hash_to_address(from_hash) - to_hash = NULL_HASH - tx = Transaction(height=height, is_verified=True) \ - .add_inputs([self.txi(txo)]) \ - .add_outputs([self.txo(1, to_hash)]) - await self.ledger.db.insert_transaction(tx) - await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '') - return tx - - def txo(self, amount, address): - return get_output(int(amount*COIN), address) - - def txi(self, txo): - return Input.spend(txo) - - async def test_large_tx_doesnt_hit_variable_limits(self): - # SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html - # This can be removed when there is a better way. See: https://github.com/lbryio/lbry-sdk/issues/2281 - fetchall = self.ledger.db.execute_fetchall - - def check_parameters_length(sql, parameters=None): - self.assertLess(len(parameters or []), 999) - return fetchall(sql, parameters) - - self.ledger.db.execute_fetchall = check_parameters_length - account = await self.create_account() - tx = await self.create_tx_from_nothing(account, 0) - for height in range(1, 1200): - tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height) - variable_limit = self.ledger.db.MAX_QUERY_VARIABLES - for limit in range(variable_limit - 2, variable_limit + 2): - txs = await self.ledger.get_transactions( - accounts=self.wallet.accounts, limit=limit, order_by='height asc') - self.assertEqual(len(txs), limit) - inputs, outputs, last_tx = set(), set(), txs[0] - for tx in txs[1:]: - self.assertEqual(len(tx.inputs), 1) - self.assertEqual(tx.inputs[0].txo_ref.tx_ref.id, last_tx.id) - self.assertEqual(len(tx.outputs), 1) - last_tx = tx - - async def test_queries(self): - wallet1 = Wallet() - account1 = await self.create_account(wallet1) - self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account1])) - wallet2 = Wallet() - account2 = await self.create_account(wallet2) - account3 = await self.create_account(wallet2) - self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account2])) - - self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2, account3])) - self.assertEqual(0, await self.ledger.db.get_utxo_count()) - self.assertListEqual([], await self.ledger.db.get_utxos()) - self.assertEqual(0, await self.ledger.db.get_txo_count()) - self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1)) - self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2)) - self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1])) - self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2])) - self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3])) - - tx1 = await self.create_tx_from_nothing(account1, 1) - self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account1])) - self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account2])) - self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account1])) - self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1])) - self.assertEqual(0, await self.ledger.db.get_txo_count(accounts=[account2])) - self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet1)) - self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2)) - self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account1])) - self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2])) - self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3])) - - tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2) - tx2b = await self.create_tx_from_nothing(account3, 2) - self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1])) - self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account2])) - self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account3])) - self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1])) - self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1])) - self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account2])) - self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2])) - self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account3])) - self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account3])) - self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1)) - self.assertEqual(10**8+10**8, await self.ledger.db.get_balance(wallet=wallet2)) - self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1])) - self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account2])) - self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3])) - - tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3) - self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1])) - self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2])) - self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1])) - self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1])) - self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account2])) - self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2])) - self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1)) - self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet2)) - self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1])) - self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2])) - self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3])) - - txs = await self.ledger.db.get_transactions(accounts=[account1, account2]) - self.assertListEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) - self.assertListEqual([3, 2, 1], [tx.height for tx in txs]) - - txs = await self.ledger.db.get_transactions(wallet=wallet1, accounts=wallet1.accounts, include_is_my_output=True) - self.assertListEqual([tx2.id, tx1.id], [tx.id for tx in txs]) - self.assertEqual(txs[0].inputs[0].is_my_input, True) - self.assertEqual(txs[0].outputs[0].is_my_output, False) - self.assertEqual(txs[1].inputs[0].is_my_input, False) - self.assertEqual(txs[1].outputs[0].is_my_output, True) - - txs = await self.ledger.db.get_transactions(wallet=wallet2, accounts=[account2], include_is_my_output=True) - self.assertListEqual([tx3.id, tx2.id], [tx.id for tx in txs]) - self.assertEqual(txs[0].inputs[0].is_my_input, True) - self.assertEqual(txs[0].outputs[0].is_my_output, False) - self.assertEqual(txs[1].inputs[0].is_my_input, False) - self.assertEqual(txs[1].outputs[0].is_my_output, True) - self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2])) - - tx = await self.ledger.db.get_transaction(tx_hash=tx2.hash) - self.assertEqual(tx.id, tx2.id) - self.assertIsNone(tx.inputs[0].is_my_input) - self.assertIsNone(tx.outputs[0].is_my_output) - tx = await self.ledger.db.get_transaction(wallet=wallet1, tx_hash=tx2.hash, include_is_my_output=True) - self.assertTrue(tx.inputs[0].is_my_input) - self.assertFalse(tx.outputs[0].is_my_output) - tx = await self.ledger.db.get_transaction(wallet=wallet2, tx_hash=tx2.hash, include_is_my_output=True) - self.assertFalse(tx.inputs[0].is_my_input) - self.assertTrue(tx.outputs[0].is_my_output) - - # height 0 sorted to the top with the rest in descending order - tx4 = await self.create_tx_from_nothing(account1, 0) - txos = await self.ledger.db.get_txos() - self.assertListEqual([0, 3, 2, 2, 1], [txo.tx_ref.height for txo in txos]) - self.assertListEqual([tx4.id, tx3.id, tx2.id, tx2b.id, tx1.id], [txo.tx_ref.id for txo in txos]) - txs = await self.ledger.db.get_transactions(accounts=[account1, account2]) - self.assertListEqual([0, 3, 2, 1], [tx.height for tx in txs]) - self.assertListEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) - - async def test_empty_history(self): - self.assertEqual((None, []), await self.ledger.get_local_status_and_history('')) - - -class TestUpgrade(AsyncioTestCase): - - def setUp(self) -> None: - self.path = tempfile.mktemp() - - def tearDown(self) -> None: - os.remove(self.path) - - def get_version(self): - with sqlite3.connect(self.path) as conn: - versions = conn.execute('select version from version').fetchall() - assert len(versions) == 1 - return versions[0][0] - - def get_tables(self): - with sqlite3.connect(self.path) as conn: - sql = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;" - return [col[0] for col in conn.execute(sql).fetchall()] - - def add_address(self, address): - with sqlite3.connect(self.path) as conn: - conn.execute(""" - INSERT INTO account_address (address, account, chain, n, pubkey, chain_code, depth) - VALUES (?, 'account1', 0, 0, 'pubkey', 'chain_code', 0) - """, (address,)) - - def get_addresses(self): - with sqlite3.connect(self.path) as conn: - sql = "SELECT address FROM account_address ORDER BY address;" - return [col[0] for col in conn.execute(sql).fetchall()] - - async def test_reset_on_version_change(self): - self.ledger = Ledger({ - 'db': Database('sqlite:///'+self.path), - 'headers': Headers(':memory:') - }) - - # initial open, pre-version enabled db - self.ledger.db.SCHEMA_VERSION = None - self.assertListEqual(self.get_tables(), []) - await self.ledger.db.open() - metadata.drop_all(self.ledger.db.engine, [Version]) # simulate pre-version table db - self.assertEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo']) - self.assertListEqual(self.get_addresses(), []) - self.add_address('address1') - await self.ledger.db.close() - - # initial open after version enabled - self.ledger.db.SCHEMA_VERSION = '1.0' - await self.ledger.db.open() - self.assertEqual(self.get_version(), '1.0') - self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version']) - self.assertListEqual(self.get_addresses(), []) # address1 deleted during version upgrade - self.add_address('address2') - await self.ledger.db.close() - - # nothing changes - self.assertEqual(self.get_version(), '1.0') - self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version']) - await self.ledger.db.open() - self.assertEqual(self.get_version(), '1.0') - self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version']) - self.assertListEqual(self.get_addresses(), ['address2']) - await self.ledger.db.close() - - # upgrade version, database reset - foo = Table('foo', metadata, Column('bar', Text, primary_key=True)) - self.addCleanup(metadata.remove, foo) - self.ledger.db.SCHEMA_VERSION = '1.1' - await self.ledger.db.open() - self.assertEqual(self.get_version(), '1.1') - self.assertListEqual(self.get_tables(), ['account_address', 'block', 'foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version']) - self.assertListEqual(self.get_addresses(), []) # all tables got reset - await self.ledger.db.close() - - -class TestSQLiteRace(AsyncioTestCase): - max_misuse_attempts = 40000 - - def setup_db(self): - self.db = sqlite3.connect(":memory:", isolation_level=None) - self.db.executescript( - "create table test1 (id text primary key not null, val text);\n" + - "create table test2 (id text primary key not null, val text);\n" + - "\n".join(f"insert into test1 values ({v}, NULL);" for v in range(1000)) - ) - - async def asyncSetUp(self): - self.executor = ThreadPoolExecutor(1) - await self.loop.run_in_executor(self.executor, self.setup_db) - - async def asyncTearDown(self): - await self.loop.run_in_executor(self.executor, self.db.close) - self.executor.shutdown() - - async def test_binding_param_0_error(self): - # test real param 0 binding errors - - for supported_type in [str, int, bytes]: - await self.loop.run_in_executor( - self.executor, self.db.executemany, "insert into test2 values (?, NULL)", - [(supported_type(1), ), (supported_type(2), )] - ) - await self.loop.run_in_executor( - self.executor, self.db.execute, "delete from test2 where id in (1, 2)" - ) - for unsupported_type in [lambda x: (x, ), lambda x: [x], lambda x: {x}]: - try: - await self.loop.run_in_executor( - self.executor, self.db.executemany, "insert into test2 (id, val) values (?, NULL)", - [(unsupported_type(1), ), (unsupported_type(2), )] - ) - self.assertTrue(False) - except sqlite3.InterfaceError as err: - self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.") - - async def test_unhandled_sqlite_misuse(self): - # test SQLITE_MISUSE being incorrectly raised as a param 0 binding error - attempts = 0 - python_version = sys.version.split('\n')[0].rstrip(' ') - - try: - while attempts < self.max_misuse_attempts: - f1 = asyncio.wrap_future( - self.loop.run_in_executor( - self.executor, self.db.executemany, "update test1 set val='derp' where id=?", - ((str(i),) for i in range(2)) - ) - ) - f2 = asyncio.wrap_future( - self.loop.run_in_executor( - self.executor, self.db.executemany, "update test2 set val='derp' where id=?", - ((str(i),) for i in range(2)) - ) - ) - attempts += 1 - await asyncio.gather(f1, f2) - print(f"\nsqlite3 {sqlite3.version}/python {python_version} " - f"did not raise SQLITE_MISUSE within {attempts} attempts of the race condition") - self.assertTrue(False, 'this test failing means either the sqlite race conditions ' - 'have been fixed in cpython or the test max_attempts needs to be increased') - except sqlite3.InterfaceError as err: - self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.") - print(f"\nsqlite3 {sqlite3.version}/python {python_version} raised SQLITE_MISUSE " - f"after {attempts} attempts of the race condition") - - @unittest.SkipTest - async def test_fetchall_prevents_sqlite_misuse(self): - # test that calling fetchall sufficiently avoids the race - attempts = 0 - - def executemany_fetchall(query, params): - self.db.executemany(query, params).fetchall() - - while attempts < self.max_misuse_attempts: - f1 = asyncio.wrap_future( - self.loop.run_in_executor( - self.executor, executemany_fetchall, "update test1 set val='derp' where id=?", - ((str(i),) for i in range(2)) - ) - ) - f2 = asyncio.wrap_future( - self.loop.run_in_executor( - self.executor, executemany_fetchall, "update test2 set val='derp' where id=?", - ((str(i),) for i in range(2)) - ) - ) - attempts += 1 - await asyncio.gather(f1, f2) \ No newline at end of file diff --git a/tests/unit/wallet/test_ledger.py b/tests/unit/wallet/test_ledger.py deleted file mode 100644 index e6e4aaf14..000000000 --- a/tests/unit/wallet/test_ledger.py +++ /dev/null @@ -1,242 +0,0 @@ -import os -from binascii import hexlify, unhexlify - -from lbry.testcase import AsyncioTestCase -from lbry.wallet import Wallet, Account, Transaction, Output, Input, Ledger, Headers -from lbry.db import Database - -from tests.unit.wallet.test_transaction import get_transaction, get_output -from tests.unit.wallet.test_headers import HEADERS, block_bytes - - -class MockNetwork: - - def __init__(self, history, transaction): - self.history = history - self.transaction = transaction - self.address = None - self.get_history_called = [] - self.get_transaction_called = [] - self.is_connected = False - - def retriable_call(self, function, *args, **kwargs): - return function(*args, **kwargs) - - async def get_history(self, address): - self.get_history_called.append(address) - self.address = address - return self.history - - async def get_merkle(self, txid, height): - return {'merkle': ['abcd01'], 'pos': 1} - - async def get_transaction(self, tx_hash, _=None): - self.get_transaction_called.append(tx_hash) - return self.transaction[tx_hash] - - async def get_transaction_and_merkle(self, tx_hash, known_height=None): - tx = await self.get_transaction(tx_hash) - merkle = {} - if known_height: - merkle = await self.get_merkle(tx_hash, known_height) - return tx, merkle - - -class LedgerTestCase(AsyncioTestCase): - - async def asyncSetUp(self): - self.ledger = Ledger({ - 'db': Database('sqlite:///:memory:'), - 'headers': Headers(':memory:') - }) - self.account = Account.generate(self.ledger, Wallet(), "lbryum") - await self.ledger.db.open() - - async def asyncTearDown(self): - await self.ledger.db.close() - - def make_header(self, **kwargs): - header = { - 'bits': 486604799, - 'block_height': 0, - 'merkle_root': b'4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b', - 'nonce': 2083236893, - 'prev_block_hash': b'0000000000000000000000000000000000000000000000000000000000000000', - 'timestamp': 1231006505, - 'claim_trie_root': b'4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b', - 'version': 1 - } - header.update(kwargs) - header['merkle_root'] = header['merkle_root'].ljust(64, b'a') - header['prev_block_hash'] = header['prev_block_hash'].ljust(64, b'0') - return self.ledger.headers.serialize(header) - - def add_header(self, **kwargs): - serialized = self.make_header(**kwargs) - self.ledger.headers.io.seek(0, os.SEEK_END) - self.ledger.headers.io.write(serialized) - self.ledger.headers._size = self.ledger.headers.io.seek(0, os.SEEK_END) // self.ledger.headers.header_size - - -class TestSynchronization(LedgerTestCase): - - async def test_update_history(self): - txid1 = '252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792' - txid2 = 'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9' - txid3 = 'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0' - txid4 = '047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828' - txhash1 = unhexlify(txid1)[::-1] - txhash2 = unhexlify(txid2)[::-1] - txhash3 = unhexlify(txid3)[::-1] - txhash4 = unhexlify(txid4)[::-1] - - account = Account.generate(self.ledger, Wallet(), "torba") - address = await account.receiving.get_or_create_usable_address() - address_details = await self.ledger.db.get_address(address=address) - self.assertIsNone(address_details['history']) - - self.add_header(block_height=0, merkle_root=b'abcd04') - self.add_header(block_height=1, merkle_root=b'abcd04') - self.add_header(block_height=2, merkle_root=b'abcd04') - self.add_header(block_height=3, merkle_root=b'abcd04') - self.ledger.network = MockNetwork([ - {'tx_hash': txid1, 'height': 0}, - {'tx_hash': txid2, 'height': 1}, - {'tx_hash': txid3, 'height': 2}, - ], { - txhash1: hexlify(get_transaction(get_output(1)).raw), - txhash2: hexlify(get_transaction(get_output(2)).raw), - txhash3: hexlify(get_transaction(get_output(3)).raw), - }) - await self.ledger.update_history(address, '') - self.assertListEqual(self.ledger.network.get_history_called, [address]) - self.assertListEqual(self.ledger.network.get_transaction_called, [txhash1, txhash2, txhash3]) - - address_details = await self.ledger.db.get_address(address=address) - - self.assertEqual( - address_details['history'], - f'{txid1}:0:' - f'{txid2}:1:' - f'{txid3}:2:' - ) - - self.ledger.network.get_history_called = [] - self.ledger.network.get_transaction_called = [] - for cache_item in self.ledger._tx_cache.values(): - cache_item.tx.is_verified = True - await self.ledger.update_history(address, '') - self.assertListEqual(self.ledger.network.get_history_called, [address]) - self.assertListEqual(self.ledger.network.get_transaction_called, []) - - self.ledger.network.history.append({'tx_hash': txid4, 'height': 3}) - self.ledger.network.transaction[txhash4] = hexlify(get_transaction(get_output(4)).raw) - self.ledger.network.get_history_called = [] - self.ledger.network.get_transaction_called = [] - await self.ledger.update_history(address, '') - self.assertListEqual(self.ledger.network.get_history_called, [address]) - self.assertListEqual(self.ledger.network.get_transaction_called, [txhash4]) - address_details = await self.ledger.db.get_address(address=address) - self.assertEqual( - address_details['history'], - f'{txid1}:0:' - f'{txid2}:1:' - f'{txid3}:2:' - f'{txid4}:3:' - ) - - -class MocHeaderNetwork(MockNetwork): - def __init__(self, responses): - super().__init__(None, None) - self.responses = responses - - async def get_headers(self, height, blocks): - return self.responses[height] - - -class BlockchainReorganizationTests(LedgerTestCase): - - async def test_1_block_reorganization(self): - self.ledger.network = MocHeaderNetwork({ - 10: {'height': 10, 'count': 5, 'hex': hexlify( - HEADERS[block_bytes(10):block_bytes(15)] - )}, - 15: {'height': 15, 'count': 0, 'hex': b''} - }) - headers = self.ledger.headers - await headers.connect(0, HEADERS[:block_bytes(10)]) - self.add_header(block_height=len(headers)) - self.assertEqual(10, headers.height) - await self.ledger.receive_header([{ - 'height': 11, 'hex': hexlify(self.make_header(block_height=11)) - }]) - - async def test_3_block_reorganization(self): - self.ledger.network = MocHeaderNetwork({ - 10: {'height': 10, 'count': 5, 'hex': hexlify( - HEADERS[block_bytes(10):block_bytes(15)] - )}, - 11: {'height': 11, 'count': 1, 'hex': hexlify(self.make_header(block_height=11))}, - 12: {'height': 12, 'count': 1, 'hex': hexlify(self.make_header(block_height=12))}, - 15: {'height': 15, 'count': 0, 'hex': b''} - }) - headers = self.ledger.headers - await headers.connect(0, HEADERS[:block_bytes(10)]) - self.add_header(block_height=len(headers)) - self.add_header(block_height=len(headers)) - self.add_header(block_height=len(headers)) - self.assertEqual(headers.height, 12) - await self.ledger.receive_header([{ - 'height': 13, 'hex': hexlify(self.make_header(block_height=13)) - }]) - - -class BasicAccountingTests(LedgerTestCase): - - async def test_empty_state(self): - self.assertEqual(await self.account.get_balance(), 0) - - async def test_balance(self): - address = await self.account.receiving.get_or_create_usable_address() - hash160 = self.ledger.address_to_hash160(address) - - tx = Transaction(is_verified=True)\ - .add_outputs([Output.pay_pubkey_hash(100, hash160)]) - await self.ledger.db.insert_transaction(tx) - await self.ledger.db.save_transaction_io( - tx, address, hash160, f'{tx.id}:1:' - ) - self.assertEqual(await self.account.get_balance(), 100) - - tx = Transaction(is_verified=True)\ - .add_outputs([Output.pay_claim_name_pubkey_hash(100, 'foo', b'', hash160)]) - await self.ledger.db.insert_transaction(tx) - await self.ledger.db.save_transaction_io( - tx, address, hash160, f'{tx.id}:1:' - ) - self.assertEqual(await self.account.get_balance(), 100) # claim names don't count towards balance - self.assertEqual(await self.account.get_balance(include_claims=True), 200) - - async def test_get_utxo(self): - address = yield self.account.receiving.get_or_create_usable_address() - hash160 = self.ledger.address_to_hash160(address) - - tx = Transaction(is_verified=True)\ - .add_outputs([Output.pay_pubkey_hash(100, hash160)]) - await self.ledger.db.save_transaction_io( - 'insert', tx, address, hash160, f'{tx.id}:1:' - ) - - utxos = await self.account.get_utxos() - self.assertEqual(len(utxos), 1) - - tx = Transaction(is_verified=True)\ - .add_inputs([Input.spend(utxos[0])]) - await self.ledger.db.save_transaction_io( - 'insert', tx, address, hash160, f'{tx.id}:1:' - ) - self.assertEqual(await self.account.get_balance(include_claims=True), 0) - - utxos = await self.account.get_utxos() - self.assertEqual(len(utxos), 0) diff --git a/tests/unit/wallet/test_manager.py b/tests/unit/wallet/test_manager.py index f699dd6e8..e775814fe 100644 --- a/tests/unit/wallet/test_manager.py +++ b/tests/unit/wallet/test_manager.py @@ -4,69 +4,70 @@ import tempfile from lbry import Config, Ledger, Database, WalletManager, Wallet, Account from lbry.testcase import AsyncioTestCase +from lbry.wallet.manager import FileWallet, DatabaseWallet -class TestWalletManager(AsyncioTestCase): +class DBBasedWalletManagerTestCase(AsyncioTestCase): async def asyncSetUp(self): - self.temp_dir = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, self.temp_dir) - self.ledger = Ledger(Config.with_same_dir(self.temp_dir).set( - db_url="sqlite:///:memory:" + self.ledger = Ledger(Config.with_null_dir().set( + db_url="sqlite:///:memory:", + wallet_storage="database" )) self.db = Database(self.ledger) + await self.db.open() + self.addCleanup(self.db.close) - async def test_ensure_path_exists(self): - wm = WalletManager(self.ledger, self.db) - self.assertFalse(os.path.exists(wm.path)) - await wm.ensure_path_exists() - self.assertTrue(os.path.exists(wm.path)) - async def test_load_with_default_wallet_account_progression(self): - wm = WalletManager(self.ledger, self.db) - await wm.ensure_path_exists() +class TestDatabaseWalletManager(DBBasedWalletManagerTestCase): + + async def test_initialize_with_default_wallet_account_progression(self): + wm = WalletManager(self.db) + self.assertIsInstance(wm.storage, DatabaseWallet) + storage: DatabaseWallet = wm.storage + await storage.prepare() # first, no defaults self.ledger.conf.create_default_wallet = False self.ledger.conf.create_default_account = False - await wm.load() + await wm.initialize() self.assertIsNone(wm.default) # then, yes to default wallet but no to default account self.ledger.conf.create_default_wallet = True self.ledger.conf.create_default_account = False - await wm.load() + await wm.initialize() self.assertIsInstance(wm.default, Wallet) - self.assertTrue(os.path.exists(wm.default.storage.path)) + self.assertTrue(await storage.exists(wm.default.id)) self.assertIsNone(wm.default.accounts.default) # finally, yes to all the things self.ledger.conf.create_default_wallet = True self.ledger.conf.create_default_account = True - await wm.load() + await wm.initialize() self.assertIsInstance(wm.default, Wallet) self.assertIsInstance(wm.default.accounts.default, Account) async def test_load_with_create_default_everything_upfront(self): - wm = WalletManager(self.ledger, self.db) - await wm.ensure_path_exists() + wm = WalletManager(self.db) + await wm.storage.prepare() self.ledger.conf.create_default_wallet = True self.ledger.conf.create_default_account = True - await wm.load() + await wm.initialize() self.assertIsInstance(wm.default, Wallet) self.assertIsInstance(wm.default.accounts.default, Account) - self.assertTrue(os.path.exists(wm.default.storage.path)) + self.assertTrue(await wm.storage.exists(wm.default.id)) async def test_load_errors(self): - _wm = WalletManager(self.ledger, self.db) - await _wm.ensure_path_exists() + _wm = WalletManager(self.db) + await _wm.storage.prepare() await _wm.create('bar', '') await _wm.create('foo', '') - wm = WalletManager(self.ledger, self.db) + wm = WalletManager(self.db) self.ledger.conf.wallets = ['bar', 'foo', 'foo'] with self.assertLogs(level='WARN') as cm: - await wm.load() + await wm.initialize() self.assertEqual( cm.output, [ 'WARNING:lbry.wallet.manager:Ignoring duplicate wallet_id in config: foo', @@ -75,9 +76,9 @@ class TestWalletManager(AsyncioTestCase): self.assertEqual({'bar', 'foo'}, set(wm.wallets)) async def test_creating_and_accessing_wallets(self): - wm = WalletManager(self.ledger, self.db) - await wm.ensure_path_exists() - await wm.load() + wm = WalletManager(self.db) + await wm.storage.prepare() + await wm.initialize() default_wallet = wm.default self.assertEqual(default_wallet, wm['default_wallet']) self.assertEqual(default_wallet, wm.get_or_default(None)) @@ -90,3 +91,114 @@ class TestWalletManager(AsyncioTestCase): _ = wm['invalid'] with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"): wm.get_or_default('invalid') + + +class TestFileBasedWalletManager(AsyncioTestCase): + + async def asyncSetUp(self): + self.temp_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.temp_dir) + self.ledger = Ledger(Config( + data_dir=self.temp_dir, + db_url="sqlite:///:memory:" + )) + self.ledger.conf.set_default_paths() + self.db = Database(self.ledger) + await self.db.open() + self.addCleanup(self.db.close) + + async def test_ensure_path_exists(self): + wm = WalletManager(self.db) + self.assertIsInstance(wm.storage, FileWallet) + storage: FileWallet = wm.storage + self.assertFalse(os.path.exists(storage.wallet_dir)) + await storage.prepare() + self.assertTrue(os.path.exists(storage.wallet_dir)) + + async def test_initialize_with_default_wallet_account_progression(self): + wm = WalletManager(self.db) + storage: FileWallet = wm.storage + await storage.prepare() + + # first, no defaults + self.ledger.conf.create_default_wallet = False + self.ledger.conf.create_default_account = False + await wm.initialize() + self.assertIsNone(wm.default) + + # then, yes to default wallet but no to default account + self.ledger.conf.create_default_wallet = True + self.ledger.conf.create_default_account = False + await wm.initialize() + self.assertIsInstance(wm.default, Wallet) + self.assertTrue(os.path.exists(storage.get_wallet_path(wm.default.id))) + self.assertIsNone(wm.default.accounts.default) + + # finally, yes to all the things + self.ledger.conf.create_default_wallet = True + self.ledger.conf.create_default_account = True + await wm.initialize() + self.assertIsInstance(wm.default, Wallet) + self.assertIsInstance(wm.default.accounts.default, Account) + + async def test_load_with_create_default_everything_upfront(self): + wm = WalletManager(self.db) + await wm.storage.prepare() + self.ledger.conf.create_default_wallet = True + self.ledger.conf.create_default_account = True + await wm.initialize() + self.assertIsInstance(wm.default, Wallet) + self.assertIsInstance(wm.default.accounts.default, Account) + self.assertTrue(os.path.exists(wm.storage.get_wallet_path(wm.default.id))) + + async def test_load_errors(self): + _wm = WalletManager(self.db) + await _wm.storage.prepare() + await _wm.create('bar', '') + await _wm.create('foo', '') + + wm = WalletManager(self.db) + self.ledger.conf.wallets = ['bar', 'foo', 'foo'] + with self.assertLogs(level='WARN') as cm: + await wm.initialize() + self.assertEqual( + cm.output, [ + 'WARNING:lbry.wallet.manager:Ignoring duplicate wallet_id in config: foo', + ] + ) + self.assertEqual({'bar', 'foo'}, set(wm.wallets)) + + async def test_creating_and_accessing_wallets(self): + wm = WalletManager(self.db) + await wm.storage.prepare() + await wm.initialize() + default_wallet = wm.default + self.assertEqual(default_wallet, wm['default_wallet']) + self.assertEqual(default_wallet, wm.get_or_default(None)) + new_wallet = await wm.create('second', 'Second Wallet') + self.assertEqual(default_wallet, wm.default) + self.assertEqual(new_wallet, wm['second']) + self.assertEqual(new_wallet, wm.get_or_default('second')) + self.assertEqual(default_wallet, wm.get_or_default(None)) + with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"): + _ = wm['invalid'] + with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"): + wm.get_or_default('invalid') + + async def test_read_write(self): + manager = WalletManager(self.db) + await manager.storage.prepare() + + with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file: + wallet_file.write(b'{"version": 1}') + wallet_file.seek(0) + + # create and write wallet to a file + wallet = await manager.load(wallet_file.name) + account = await wallet.accounts.generate() + await manager.storage.save(wallet) + + # read wallet from file + wallet = await manager.load(wallet_file.name) + + self.assertEqual(account.public_key.address, wallet.accounts.default.public_key.address) diff --git a/tests/unit/wallet/test_sync.py b/tests/unit/wallet/test_sync.py deleted file mode 100644 index 51e7248c9..000000000 --- a/tests/unit/wallet/test_sync.py +++ /dev/null @@ -1,78 +0,0 @@ -from lbry.testcase import UnitDBTestCase - - -class TestClientDBSync(UnitDBTestCase): - - async def asyncSetUp(self): - await super().asyncSetUp() - await self.add(self.coinbase()) - - async def test_process_inputs(self): - await self.add(self.tx()) - await self.add(self.tx()) - txo1, txo2a, txo2b, txo3a, txo3b = self.outputs - self.assertEqual([ - (txo1.id, None), - (txo2b.id, None), - ], await self.get_txis()) - self.assertEqual([ - (txo1.id, False), - (txo2a.id, False), - (txo2b.id, False), - (txo3a.id, False), - (txo3b.id, False), - ], await self.get_txos()) - await self.db.process_all_things_after_sync() - self.assertEqual([ - (txo1.id, txo1.get_address(self.ledger)), - (txo2b.id, txo2b.get_address(self.ledger)), - ], await self.get_txis()) - self.assertEqual([ - (txo1.id, True), - (txo2a.id, False), - (txo2b.id, True), - (txo3a.id, False), - (txo3b.id, False), - ], await self.get_txos()) - - async def test_process_claims(self): - claim1 = await self.add(self.create_claim()) - await self.db.process_all_things_after_sync() - self.assertEqual([claim1.claim_id], await self.get_claims()) - - claim2 = await self.add(self.create_claim()) - self.assertEqual([claim1.claim_id], await self.get_claims()) - await self.db.process_all_things_after_sync() - self.assertEqual([claim1.claim_id, claim2.claim_id], await self.get_claims()) - - claim3 = await self.add(self.create_claim()) - claim4 = await self.add(self.create_claim()) - await self.db.process_all_things_after_sync() - self.assertEqual([ - claim1.claim_id, - claim2.claim_id, - claim3.claim_id, - claim4.claim_id, - ], await self.get_claims()) - - await self.add(self.abandon_claim(claim4)) - await self.db.process_all_things_after_sync() - self.assertEqual([ - claim1.claim_id, claim2.claim_id, claim3.claim_id - ], await self.get_claims()) - - # create and abandon in same block - claim5 = await self.add(self.create_claim()) - await self.add(self.abandon_claim(claim5)) - await self.db.process_all_things_after_sync() - self.assertEqual([ - claim1.claim_id, claim2.claim_id, claim3.claim_id - ], await self.get_claims()) - - # create and abandon in different blocks but with bulk sync - claim6 = await self.add(self.create_claim()) - await self.add(self.abandon_claim(claim6)) - await self.db.process_all_things_after_sync() - self.assertEqual([ - claim1.claim_id, claim2.claim_id, claim3.claim_id - ], await self.get_claims()) diff --git a/tests/unit/wallet/test_wallet.py b/tests/unit/wallet/test_wallet.py index 087004877..9fd2c2e01 100644 --- a/tests/unit/wallet/test_wallet.py +++ b/tests/unit/wallet/test_wallet.py @@ -1,18 +1,17 @@ -import tempfile +from itertools import cycle from binascii import hexlify from unittest import TestCase, mock -from lbry import Config, Database, Ledger, Account, Wallet, WalletManager -from lbry.testcase import AsyncioTestCase -from lbry.wallet.storage import WalletStorage +from lbry import Config, Database, Ledger, Account, Wallet, Transaction, Output, Input +from lbry.testcase import AsyncioTestCase, get_output, COIN, CENT from lbry.wallet.preferences import TimestampedPreferences class WalletTestCase(AsyncioTestCase): async def asyncSetUp(self): - self.ledger = Ledger(Config.with_null_dir()) - self.db = Database(self.ledger, "sqlite:///:memory:") + self.ledger = Ledger(Config.with_null_dir().set(db_url='sqlite:///:memory:')) + self.db = Database(self.ledger) await self.db.open() self.addCleanup(self.db.close) @@ -20,8 +19,8 @@ class WalletTestCase(AsyncioTestCase): class WalletAccountTest(WalletTestCase): async def test_private_key_for_hierarchical_account(self): - wallet = Wallet(self.ledger, self.db) - account = wallet.add_account({ + wallet = Wallet("wallet1", self.db) + account = await wallet.accounts.add_from_dict({ "seed": "carbon smart garage balance margin twelve chest sword toas" "t envelope bottom stomach absent" @@ -40,8 +39,8 @@ class WalletAccountTest(WalletTestCase): ) async def test_private_key_for_single_address_account(self): - wallet = Wallet(self.ledger, self.db) - account = wallet.add_account({ + wallet = Wallet("wallet1", self.db) + account = await wallet.accounts.add_from_dict({ "seed": "carbon smart garage balance margin twelve chest sword toas" "t envelope bottom stomach absent", @@ -59,9 +58,9 @@ class WalletAccountTest(WalletTestCase): ) async def test_save_max_gap(self): - wallet = Wallet(self.ledger, self.db) - account = wallet.generate_account( - 'lbryum', { + wallet = Wallet("wallet1", self.db) + account = await wallet.accounts.generate( + address_generator={ 'name': 'deterministic-chain', 'receiving': {'gap': 3, 'maximum_uses_per_address': 2}, 'change': {'gap': 4, 'maximum_uses_per_address': 2} @@ -73,24 +72,25 @@ class WalletAccountTest(WalletTestCase): self.assertEqual(account.receiving.gap, 20) self.assertEqual(account.change.gap, 6) # doesn't fail for single-address account - wallet.generate_account('lbryum', {'name': 'single-address'}) + await wallet.accounts.generate(address_generator={'name': 'single-address'}) await wallet.save_max_gap() class TestWalletCreation(WalletTestCase): - def test_create_wallet_and_accounts(self): - wallet = Wallet(self.ledger, self.db) - self.assertEqual(wallet.name, 'Wallet') - self.assertListEqual(wallet.accounts, []) + async def test_create_wallet_and_accounts(self): + wallet = Wallet("wallet1", self.db) + self.assertEqual(wallet.id, "wallet1") + self.assertEqual(wallet.name, "") + self.assertEqual(list(wallet.accounts), []) - account1 = wallet.generate_account() - wallet.generate_account() - wallet.generate_account() - self.assertEqual(wallet.default_account, account1) + account1 = await wallet.accounts.generate() + await wallet.accounts.generate() + await wallet.accounts.generate() + self.assertEqual(wallet.accounts.default, account1) self.assertEqual(len(wallet.accounts), 3) - def test_load_and_save_wallet(self): + async def test_load_and_save_wallet(self): wallet_dict = { 'version': 1, 'name': 'Main Wallet', @@ -105,6 +105,7 @@ class TestWalletCreation(WalletTestCase): "carbon smart garage balance margin twelve chest sword toast envelope bottom stomac" "h absent", 'encrypted': False, + 'lang': 'en', 'private_key': 'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7' 'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe', @@ -120,15 +121,14 @@ class TestWalletCreation(WalletTestCase): ] } - storage = WalletStorage(default=wallet_dict) - wallet = Wallet.from_storage(self.ledger, self.db, storage) + wallet = await Wallet.from_dict('wallet1', wallet_dict, self.db) self.assertEqual(wallet.name, 'Main Wallet') self.assertEqual( hexlify(wallet.hash), - b'3b23aae8cd9b360f4296130b8f7afc5b2437560cdef7237bed245288ce8a5f79' + b'64a32cf8434a59c547abf61b4691a8189ac24272678b52ced2310fbf93eac974' ) self.assertEqual(len(wallet.accounts), 1) - account = wallet.default_account + account = wallet.accounts.default self.assertIsInstance(account, Account) self.maxDiff = None self.assertDictEqual(wallet_dict, wallet.to_dict()) @@ -137,41 +137,23 @@ class TestWalletCreation(WalletTestCase): decrypted = Wallet.unpack('password', encrypted) self.assertEqual(decrypted['accounts'][0]['name'], 'An Account') - def test_read_write(self): - manager = WalletManager(self.ledger, self.db) - - with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file: - wallet_file.write(b'{"version": 1}') - wallet_file.seek(0) - - # create and write wallet to a file - wallet = manager.import_wallet(wallet_file.name) - account = wallet.generate_account() - wallet.save() - - # read wallet from file - wallet_storage = WalletStorage(wallet_file.name) - wallet = Wallet.from_storage(self.ledger, self.db, wallet_storage) - - self.assertEqual(account.public_key.address, wallet.default_account.public_key.address) - - def test_merge(self): - wallet1 = Wallet(self.ledger, self.db) + async def test_merge(self): + wallet1 = Wallet('wallet1', self.db) wallet1.preferences['one'] = 1 wallet1.preferences['conflict'] = 1 - wallet1.generate_account() - wallet2 = Wallet(self.ledger, self.db) + await wallet1.accounts.generate() + wallet2 = Wallet('wallet2', self.db) wallet2.preferences['two'] = 2 wallet2.preferences['conflict'] = 2 # will be more recent - wallet2.generate_account() + await wallet2.accounts.generate() self.assertEqual(len(wallet1.accounts), 1) self.assertEqual(wallet1.preferences, {'one': 1, 'conflict': 1}) added = await wallet1.merge('password', wallet2.pack('password')) - self.assertEqual(added[0].id, wallet2.default_account.id) + self.assertEqual(added[0].id, wallet2.accounts.default.id) self.assertEqual(len(wallet1.accounts), 2) - self.assertEqual(wallet1.accounts[1].id, wallet2.default_account.id) + self.assertEqual(list(wallet1.accounts)[1].id, wallet2.accounts.default.id) self.assertEqual(wallet1.preferences, {'one': 1, 'two': 2, 'conflict': 2}) @@ -213,3 +195,138 @@ class TestTimestampedPreferences(TestCase): p1['conflict'] = 1 p1.merge(p2.data) self.assertEqual(p1, {'one': 1, 'two': 2, 'conflict': 1}) + + +class TestTransactionSigning(WalletTestCase): + + async def test_sign(self): + wallet = Wallet('wallet1', self.db) + account = await wallet.accounts.add_from_dict({ + "seed": + "carbon smart garage balance margin twelve chest sword toas" + "t envelope bottom stomach absent" + }) + + await account.ensure_address_gap() + address1, address2 = await account.receiving.get_addresses(limit=2) + pubkey_hash1 = self.ledger.address_to_hash160(address1) + pubkey_hash2 = self.ledger.address_to_hash160(address2) + + tx = Transaction() \ + .add_inputs([Input.spend(get_output(int(2*COIN), pubkey_hash1))]) \ + .add_outputs([Output.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) + + await wallet.sign(tx) + + self.assertEqual( + hexlify(tx.inputs[0].script.values['signature']), + b'304402200dafa26ad7cf38c5a971c8a25ce7d85a076235f146126762296b1223c42ae21e022020ef9eeb8' + b'398327891008c5c0be4357683f12cb22346691ff23914f457bf679601' + ) + + +class TransactionIOBalancing(WalletTestCase): + + async def asyncSetUp(self): + await super().asyncSetUp() + self.wallet = Wallet('wallet1', self.db) + self.account = await self.wallet.accounts.add_from_dict({ + "seed": + "carbon smart garage balance margin twelve chest sword toas" + "t envelope bottom stomach absent" + }) + addresses = await self.account.ensure_address_gap() + self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses] + self.hash_cycler = cycle(self.pubkey_hash) + + def txo(self, amount, address=None): + return get_output(int(amount*COIN), address or next(self.hash_cycler)) + + def txi(self, txo): + return Input.spend(txo) + + def tx(self, inputs, outputs): + return self.wallet.create_transaction(inputs, outputs, [self.account], self.account) + + async def create_utxos(self, amounts): + utxos = [self.txo(amount) for amount in amounts] + + self.funding_tx = Transaction(is_verified=True) \ + .add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \ + .add_outputs(utxos) + + await self.db.insert_transaction(b'beef', self.funding_tx) + + return utxos + + @staticmethod + def inputs(tx): + return [round(i.amount/COIN, 2) for i in tx.inputs] + + @staticmethod + def outputs(tx): + return [round(o.amount/COIN, 2) for o in tx.outputs] + + async def test_basic_use_cases(self): + self.ledger.fee_per_byte = int(.01*CENT) + + # available UTXOs for filling missing inputs + utxos = await self.create_utxos([ + 1, 1, 3, 5, 10 + ]) + + # pay 3 coins (3.02 w/ fees) + tx = await self.tx( + [], # inputs + [self.txo(3)] # outputs + ) + # best UTXO match is 5 (as UTXO 3 will be short 0.02 to cover fees) + self.assertListEqual(self.inputs(tx), [5]) + # a change of 1.98 is added to reach balance + self.assertListEqual(self.outputs(tx), [3, 1.98]) + + await self.db.release_outputs(utxos) + + # pay 2.98 coins (3.00 w/ fees) + tx = await self.tx( + [], # inputs + [self.txo(2.98)] # outputs + ) + # best UTXO match is 3 and no change is needed + self.assertListEqual(self.inputs(tx), [3]) + self.assertListEqual(self.outputs(tx), [2.98]) + + await self.db.release_outputs(utxos) + + # supplied input and output, but input is not enough to cover output + tx = await self.tx( + [self.txi(self.txo(10))], # inputs + [self.txo(11)] # outputs + ) + # additional input is chosen (UTXO 3) + self.assertListEqual([10, 3], self.inputs(tx)) + # change is now needed to consume extra input + self.assertListEqual([11, 1.96], self.outputs(tx)) + + await self.db.release_outputs(utxos) + + # liquidating a UTXO + tx = await self.tx( + [self.txi(self.txo(10))], # inputs + [] # outputs + ) + self.assertListEqual([10], self.inputs(tx)) + # missing change added to consume the amount + self.assertListEqual([9.98], self.outputs(tx)) + + await self.db.release_outputs(utxos) + + # liquidating at a loss, requires adding extra inputs + tx = await self.tx( + [self.txi(self.txo(0.01))], # inputs + [] # outputs + ) + # UTXO 1 is added to cover some of the fee + self.assertListEqual([0.01, 1], self.inputs(tx)) + # change is now needed to consume extra input + self.assertListEqual([0.97], self.outputs(tx))