working db based wallet and wallet sync progress

This commit is contained in:
Lex Berezhny 2020-10-09 10:47:44 -04:00
parent 4356d23cc1
commit 6ed2fa20ec
28 changed files with 1315 additions and 2096 deletions

View file

@ -42,13 +42,15 @@ class BlockchainSync(Sync):
super().__init__(chain.ledger, db)
self.chain = chain
self.pid = os.getpid()
self._on_block_controller = EventController()
self.on_block = self._on_block_controller.stream
self._on_mempool_controller = EventController()
self.on_mempool = self._on_mempool_controller.stream
self.on_block_hash_subscription: Optional[BroadcastSubscription] = None
self.on_tx_hash_subscription: Optional[BroadcastSubscription] = None
self.advance_loop_task: Optional[asyncio.Task] = None
self.block_hash_event = asyncio.Event()
self.tx_hash_event = asyncio.Event()
self._on_mempool_controller = EventController()
self.on_mempool = self._on_mempool_controller.stream
self.mempool = []
async def wait_for_chain_ready(self):

View file

@ -519,6 +519,7 @@ class Config(CLIConfig):
data_dir = Path("Main directory containing blobs, wallets and blockchain data.", metavar='DIR')
blob_dir = Path("Directory to store blobs (default: 'data_dir'/blobs).", metavar='DIR')
wallet_dir = Path("Directory to store wallets (default: 'data_dir'/wallets).", metavar='DIR')
wallet_storage = StringChoice("Wallet storage mode.", ["file", "database"], "file")
wallets = Strings(
"Wallet files in 'wallet_dir' to load at startup.", ['default_wallet']
)

View file

@ -205,6 +205,9 @@ class Database:
async def execute(self, sql):
return await self.run(q.execute, sql)
async def execute_sql_object(self, sql):
return await self.run(q.execute_sql_object, sql)
async def execute_fetchall(self, sql):
return await self.run(q.execute_fetchall, sql)
@ -217,12 +220,27 @@ class Database:
async def has_supports(self):
return await self.run(q.has_supports)
async def has_wallet(self, wallet_id):
return await self.run(q.has_wallet, wallet_id)
async def get_wallet(self, wallet_id: str):
return await self.run(q.get_wallet, wallet_id)
async def add_wallet(self, wallet_id: str, data: str):
return await self.run(q.add_wallet, wallet_id, data)
async def get_best_block_height(self) -> int:
return await self.run(q.get_best_block_height)
async def process_all_things_after_sync(self):
return await self.run(sync.process_all_things_after_sync)
async def get_blocks(self, first, last=None):
return await self.run(q.get_blocks, first, last)
async def get_filters(self, start_height, end_height=None, granularity=0):
return await self.run(q.get_filters, start_height, end_height, granularity)
async def insert_block(self, block):
return await self.run(q.insert_block, block)

View file

@ -3,3 +3,4 @@ from .txio import *
from .search import *
from .resolve import *
from .address import *
from .wallet import *

View file

@ -1,14 +1,24 @@
from sqlalchemy import text
from math import log10
from binascii import hexlify
from sqlalchemy import text, between
from sqlalchemy.future import select
from ..query_context import context
from ..tables import SCHEMA_VERSION, metadata, Version, Claim, Support, Block, BlockFilter, TX
from ..tables import (
SCHEMA_VERSION, metadata, Version,
Claim, Support, Block, BlockFilter, BlockGroupFilter, TX, TXFilter,
)
def execute(sql):
return context().execute(text(sql))
def execute_sql_object(sql):
return context().execute(sql)
def execute_fetchall(sql):
return context().fetchall(text(sql))
@ -33,6 +43,53 @@ def insert_block(block):
context().get_bulk_loader().add_block(block).flush()
def get_blocks(first, last=None):
if last is not None:
query = (
select('*').select_from(Block)
.where(between(Block.c.height, first, last))
.order_by(Block.c.height)
)
else:
query = select('*').select_from(Block).where(Block.c.height == first)
return context().fetchall(query)
def get_filters(start_height, end_height=None, granularity=0):
assert granularity >= 0, "filter granularity must be 0 or positive number"
if granularity == 0:
query = (
select('*').select_from(TXFilter)
.where(between(TXFilter.c.height, start_height, end_height))
.order_by(TXFilter.c.height)
)
elif granularity == 1:
query = (
select('*').select_from(BlockFilter)
.where(between(BlockFilter.c.height, start_height, end_height))
.order_by(BlockFilter.c.height)
)
else:
query = (
select('*').select_from(BlockGroupFilter)
.where(
(BlockGroupFilter.c.height == start_height) &
(BlockGroupFilter.c.factor == log10(granularity))
)
.order_by(BlockGroupFilter.c.height)
)
result = []
for row in context().fetchall(query):
record = {
"height": row["height"],
"filter": hexlify(row["address_filter"]).decode(),
}
if granularity == 0:
record["txid"] = hexlify(row["tx_hash"][::-1]).decode()
result.append(record)
return result
def insert_transaction(block_hash, tx):
context().get_bulk_loader().add_transaction(block_hash, tx).flush(TX)

24
lbry/db/queries/wallet.py Normal file
View file

@ -0,0 +1,24 @@
from sqlalchemy import exists
from sqlalchemy.future import select
from ..query_context import context
from ..tables import Wallet
def has_wallet(wallet_id: str) -> bool:
sql = select(exists(select(Wallet.c.wallet_id).where(Wallet.c.wallet_id == wallet_id)))
return context().execute(sql).fetchone()[0]
def get_wallet(wallet_id: str):
return context().fetchone(
select(Wallet.c.data).where(Wallet.c.wallet_id == wallet_id)
)
def add_wallet(wallet_id: str, data: str):
c = context()
c.execute(
c.insert_or_replace(Wallet, ["data"])
.values(wallet_id=wallet_id, data=data)
)

View file

@ -19,6 +19,13 @@ Version = Table(
)
Wallet = Table(
'wallet', metadata,
Column('wallet_id', Text, primary_key=True),
Column('data', Text),
)
PubkeyAddress = Table(
'pubkey_address', metadata,
Column('address', Text, primary_key=True),

View file

@ -1225,7 +1225,7 @@ class API:
wallet = self.wallets.get_or_default(wallet_id)
wallet_changed = False
if data is not None:
added_accounts = await wallet.merge(self.wallets, password, data)
added_accounts = await wallet.merge(password, data)
if added_accounts and self.ledger.sync.network.is_connected:
if blocking:
await asyncio.wait([
@ -1318,11 +1318,24 @@ class API:
.receiving.get_or_create_usable_address()
)
async def address_block_filters(self):
return await self.service.get_block_address_filters()
async def address_filter(
self,
start_height: int, # starting height of block range or just single block
end_height: int = None, # return a range of blocks from start_height to end_height
granularity: int = None, # 0 - individual tx filters, 1 - block per filter,
# 1000, 10000, 100000 blocks per filter
) -> list: # blocks
"""
List address filters
async def address_transaction_filters(self, block_hash: str):
return await self.service.get_transaction_address_filters(block_hash)
Usage:
address_filter <start_height>
[--end_height=<end_height>]
[--granularity=<granularity>]
"""
return await self.service.get_address_filters(
start_height=start_height, end_height=end_height, granularity=granularity
)
FILE_DOC = """
File management.
@ -2656,6 +2669,23 @@ class API:
await self.service.maybe_broadcast_or_release(tx, blocking, preview)
return tx
BLOCK_DOC = """
Block information.
"""
async def block_list(
self,
start_height: int, # starting height of block range or just single block
end_height: int = None, # return a range of blocks from start_height to end_height
) -> list: # blocks
"""
List block info
Usage:
block_list <start_height> [<end_height>]
"""
return await self.service.get_blocks(start_height=start_height, end_height=end_height)
TRANSACTION_DOC = """
Transaction management.
"""
@ -3529,8 +3559,12 @@ class Client(API):
self.receive_messages_task = asyncio.create_task(self.receive_messages())
async def disconnect(self):
if self.session is not None:
await self.session.close()
self.session = None
if self.receive_messages_task is not None:
self.receive_messages_task.cancel()
self.receive_messages_task = None
async def receive_messages(self):
async for message in self.ws:
@ -3559,12 +3593,16 @@ class Client(API):
await self.ws.send_json({'id': self.message_id, 'method': method, 'params': kwargs})
return ec.stream
async def subscribe(self, event) -> EventStream:
def get_event_stream(self, event) -> EventStream:
if event not in self.subscriptions:
self.subscriptions[event] = EventController()
await self.ws.send_json({'id': None, 'method': 'subscribe', 'params': [event]})
return self.subscriptions[event].stream
async def start_event_streams(self):
events = list(self.subscriptions.keys())
if events:
await self.ws.send_json({'id': None, 'method': 'subscribe', 'params': events})
def __getattribute__(self, name):
if name in dir(API):
return partial(object.__getattribute__(self, 'send'), name)

View file

@ -9,7 +9,7 @@ from lbry.schema.result import Censor
from lbry.blockchain.transaction import Transaction, Output
from lbry.blockchain.ledger import Ledger
from lbry.wallet import WalletManager
from lbry.event import EventController
from lbry.event import EventController, EventStream
log = logging.getLogger(__name__)
@ -25,14 +25,14 @@ class Sync:
Server stays synced with lbrycrd
"""
on_block: EventStream
on_mempool: EventStream
def __init__(self, ledger: Ledger, db: Database):
self.ledger = ledger
self.conf = ledger.conf
self.db = db
self._on_block_controller = EventController()
self.on_block = self._on_block_controller.stream
self._on_progress_controller = db._on_progress_controller
self.on_progress = db.on_progress
@ -63,13 +63,7 @@ class Service:
def __init__(self, ledger: Ledger):
self.ledger, self.conf = ledger, ledger.conf
self.db = Database(ledger)
self.wallets = WalletManager(ledger, self.db)
#self.on_address = sync.on_address
#self.accounts = sync.accounts
#self.on_header = sync.on_header
#self.on_ready = sync.on_ready
#self.on_transaction = sync.on_transaction
self.wallets = WalletManager(self.db)
# sync has established connection with a source from which it can synchronize
# for full service this is lbrycrd (or sync service) and for light this is full node
@ -78,8 +72,8 @@ class Service:
async def start(self):
await self.db.open()
await self.wallets.ensure_path_exists()
await self.wallets.load()
await self.wallets.storage.prepare()
await self.wallets.initialize()
await self.sync.start()
async def stop(self):
@ -95,16 +89,21 @@ class Service:
async def find_ffmpeg(self):
pass
async def get(self, uri, **kwargs):
async def get_file(self, uri, **kwargs):
pass
def create_wallet(self, file_name):
path = os.path.join(self.conf.wallet_dir, file_name)
return self.wallets.add_from_path(path)
async def get_block_headers(self, first, last=None):
pass
def create_wallet(self, wallet_id):
return self.wallets.create(wallet_id)
async def get_addresses(self, **constraints):
return await self.db.get_addresses(**constraints)
async def get_address_filters(self, start_height: int, end_height: int=None, granularity: int=0):
raise NotImplementedError
def reserve_outputs(self, txos):
return self.db.reserve_outputs(txos)

View file

@ -19,11 +19,13 @@ from lbry.blockchain.ledger import ledger_class_from_name
log = logging.getLogger(__name__)
def jsonrpc_dumps_pretty(obj, **kwargs):
def jsonrpc_dumps_pretty(obj, message_id=None, **kwargs):
#if not isinstance(obj, dict):
# data = {"jsonrpc": "2.0", "error": obj.to_dict()}
#else:
data = {"jsonrpc": "2.0", "result": obj}
if message_id is not None:
data["id"] = message_id
return json.dumps(data, cls=JSONResponseEncoder, sort_keys=True, indent=2, **kwargs) + "\n"
@ -166,7 +168,7 @@ class Daemon:
async def on_message(self, web_socket: WebSocketManager, msg: dict):
if msg['method'] == 'subscribe':
streams = msg['streams']
streams = msg['params']
if isinstance(streams, str):
streams = [streams]
web_socket.subscribe(streams, self.app['subscriptions'])
@ -175,11 +177,10 @@ class Daemon:
method = getattr(self.api, msg['method'])
try:
result = await method(**params)
encoded_result = jsonrpc_dumps_pretty(result, service=self.service)
await web_socket.send_json({
'id': msg.get('id', ''),
'result': encoded_result
})
encoded_result = jsonrpc_dumps_pretty(
result, message_id=msg.get('id', ''), service=self.service
)
await web_socket.send_str(encoded_result)
except Exception as e:
log.exception("RPC error")
await web_socket.send_json({'id': msg.get('id', ''), 'result': "unexpected error: " + str(e)})

View file

@ -0,0 +1,61 @@
import logging
from binascii import hexlify, unhexlify
from lbry.blockchain.lbrycrd import Lbrycrd
from lbry.blockchain.sync import BlockchainSync
from lbry.blockchain.ledger import Ledger
from lbry.blockchain.transaction import Transaction
from .base import Service, Sync
from .api import Client as APIClient
log = logging.getLogger(__name__)
class NoSync(Sync):
def __init__(self, service: Service, client: APIClient):
super().__init__(service.ledger, service.db)
self.service = service
self.client = client
self.on_block = client.get_event_stream('blockchain.block')
self.on_block_subscription: Optional[BroadcastSubscription] = None
self.on_mempool = client.get_event_stream('blockchain.mempool')
self.on_mempool_subscription: Optional[BroadcastSubscription] = None
async def wait_for_client_ready(self):
await self.client.connect()
async def start(self):
self.db.stop_event.clear()
await self.wait_for_client_ready()
self.advance_loop_task = asyncio.create_task(self.advance())
await self.advance_loop_task
await self.client.subscribe()
self.advance_loop_task = asyncio.create_task(self.advance_loop())
self.on_block_subscription = self.on_block.listen(
lambda e: self.on_block_event.set()
)
self.on_mempool_subscription = self.on_mempool.listen(
lambda e: self.on_mempool_event.set()
)
await self.download_filters()
await self.download_headers()
async def stop(self):
await self.client.disconnect()
class FullEndpoint(Service):
name = "endpoint"
sync: 'NoSync'
def __init__(self, ledger: Ledger):
super().__init__(ledger)
self.client = APIClient(
f"http://{ledger.conf.full_nodes[0][0]}:{ledger.conf.full_nodes[0][1]}/api"
)
self.sync = NoSync(self, self.client)

View file

@ -35,6 +35,14 @@ class FullNode(Service):
async def get_status(self):
return 'everything is wonderful'
async def get_block_headers(self, first, last=None):
return await self.db.get_blocks(first, last)
async def get_address_filters(self, start_height: int, end_height: int = None, granularity: int = 0):
return await self.db.get_filters(
start_height=start_height, end_height=end_height, granularity=granularity
)
# async def get_block_address_filters(self):
# return {
# hexlify(f['block_hash']).decode(): hexlify(f['block_filter']).decode()

View file

@ -1,11 +1,25 @@
import asyncio
import logging
from typing import List, Dict, Tuple
from typing import List, Dict
#from io import StringIO
#from functools import partial
#from operator import itemgetter
from collections import defaultdict
#from binascii import hexlify, unhexlify
from typing import List, Optional, DefaultDict, NamedTuple
#from lbry.crypto.hash import double_sha256, sha256
from lbry.tasks import TaskGroup
from lbry.blockchain.transaction import Transaction
from lbry.blockchain.block import get_address_filter
from lbry.event import BroadcastSubscription, EventController
from lbry.wallet.account import AddressManager
from lbry.blockchain import Ledger, Transaction
from lbry.wallet.sync import SPVSync
from .base import Service
from .api import Client
from .base import Service, Sync
from .api import Client as APIClient
log = logging.getLogger(__name__)
@ -14,23 +28,24 @@ class LightClient(Service):
name = "client"
sync: SPVSync
sync: 'FastSync'
def __init__(self, ledger: Ledger):
super().__init__(ledger)
self.client = Client(
f"http://{ledger.conf.full_nodes[0][0]}:{ledger.conf.full_nodes[0][1]}/api"
self.client = APIClient(
f"http://{ledger.conf.full_nodes[0][0]}:{ledger.conf.full_nodes[0][1]}/ws"
)
self.sync = SPVSync(self)
self.sync = FastSync(self, self.client)
self.blocks = BlockHeaderManager(self.db, self.client)
self.filters = FilterManager(self.db, self.client)
async def search_transactions(self, txids):
return await self.client.transaction_search(txids=txids)
async def get_block_address_filters(self):
return await self.client.address_block_filters()
async def get_transaction_address_filters(self, block_hash):
return await self.client.address_transaction_filters(block_hash=block_hash)
async def get_address_filters(self, start_height: int, end_height: int = None, granularity: int = 0):
return await self.filters.get_filters(
start_height=start_height, end_height=end_height, granularity=granularity
)
async def broadcast(self, tx):
pass
@ -47,6 +62,437 @@ class LightClient(Service):
async def search_supports(self, accounts, **kwargs):
pass
async def sum_supports(self, claim_hash: bytes, include_channel_content=False, exclude_own_supports=False) \
-> Tuple[List[Dict], int]:
return await self.client.sum_supports(claim_hash, include_channel_content, exclude_own_supports)
async def sum_supports(self, claim_hash: bytes, include_channel_content=False) -> List[Dict]:
return await self.client.sum_supports(claim_hash, include_channel_content)
class TransactionEvent(NamedTuple):
address: str
tx: Transaction
class AddressesGeneratedEvent(NamedTuple):
address_manager: AddressManager
addresses: List[str]
class TransactionCacheItem:
__slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications'
def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
self.has_tx = asyncio.Event()
self.lock = lock or asyncio.Lock()
self._tx = self.tx = tx
self.pending_verifications = 0
@property
def tx(self) -> Optional[Transaction]:
return self._tx
@tx.setter
def tx(self, tx: Transaction):
self._tx = tx
if tx is not None:
self.has_tx.set()
class FilterManager:
"""
Efficient on-demand address filter access.
Stores and retrieves from local db what it previously downloaded and
downloads on-demand what it doesn't have from full node.
"""
def __init__(self, db, client):
self.db = db
self.client = client
self.cache = {}
async def get_filters(self, start_height, end_height, granularity):
return await self.client.address_filter(
start_height=start_height, end_height=end_height, granularity=granularity
)
class BlockHeaderManager:
"""
Efficient on-demand block header access.
Stores and retrieves from local db what it previously downloaded and
downloads on-demand what it doesn't have from full node.
"""
def __init__(self, db, client):
self.db = db
self.client = client
self.cache = {}
async def get_header(self, height):
blocks = await self.client.block_list(height)
if blocks:
return blocks[0]
async def add(self, header):
pass
async def download(self):
pass
class FastSync(Sync):
def __init__(self, service: Service, client: APIClient):
super().__init__(service.ledger, service.db)
self.service = service
self.client = client
self.advance_loop_task: Optional[asyncio.Task] = None
self.on_block = client.get_event_stream('blockchain.block')
self.on_block_event = asyncio.Event()
self.on_block_subscription: Optional[BroadcastSubscription] = None
self.on_mempool = client.get_event_stream('blockchain.mempool')
self.on_mempool_event = asyncio.Event()
self.on_mempool_subscription: Optional[BroadcastSubscription] = None
async def wait_for_client_ready(self):
await self.client.connect()
async def start(self):
return
self.db.stop_event.clear()
await self.wait_for_client_ready()
self.advance_loop_task = asyncio.create_task(self.advance())
await self.advance_loop_task
await self.client.subscribe()
self.advance_loop_task = asyncio.create_task(self.advance_loop())
self.on_block_subscription = self.on_block.listen(
lambda e: self.on_block_event.set()
)
self.on_mempool_subscription = self.on_mempool.listen(
lambda e: self.on_mempool_event.set()
)
await self.download_filters()
await self.download_headers()
async def stop(self):
await self.client.disconnect()
async def advance(self):
address_array = [
bytearray(a['address'].encode())
for a in await self.service.db.get_all_addresses()
]
block_filters = await self.service.get_block_address_filters()
for block_hash, block_filter in block_filters.items():
bf = get_address_filter(block_filter)
if bf.MatchAny(address_array):
print(f'match: {block_hash} - {block_filter}')
tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash)
for txid, tx_filter in tx_filters.items():
tf = get_address_filter(tx_filter)
if tf.MatchAny(address_array):
print(f' match: {txid} - {tx_filter}')
txs = await self.service.search_transactions([txid])
tx = Transaction(unhexlify(txs[txid]))
await self.service.db.insert_transaction(tx)
# async def get_local_status_and_history(self, address, history=None):
# if not history:
# address_details = await self.db.get_address(address=address)
# history = (address_details['history'] if address_details else '') or ''
# parts = history.split(':')[:-1]
# return (
# hexlify(sha256(history.encode())).decode() if history else None,
# list(zip(parts[0::2], map(int, parts[1::2])))
# )
#
# @staticmethod
# def get_root_of_merkle_tree(branches, branch_positions, working_branch):
# for i, branch in enumerate(branches):
# other_branch = unhexlify(branch)[::-1]
# other_branch_on_left = bool((branch_positions >> i) & 1)
# if other_branch_on_left:
# combined = other_branch + working_branch
# else:
# combined = working_branch + other_branch
# working_branch = double_sha256(combined)
# return hexlify(working_branch[::-1])
#
#
# @property
# def local_height_including_downloaded_height(self):
# return max(self.headers.height, self._download_height)
#
# async def initial_headers_sync(self):
# get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
# self.headers.chunk_getter = get_chunk
#
# async def doit():
# for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)):
# async with self._header_processing_lock:
# await self.headers.ensure_chunk_at(height)
# self._other_tasks.add(doit())
# await self.update_headers()
#
# async def update_headers(self, height=None, headers=None, subscription_update=False):
# rewound = 0
# while True:
#
# if height is None or height > len(self.headers):
# # sometimes header subscription updates are for a header in the future
# # which can't be connected, so we do a normal header sync instead
# height = len(self.headers)
# headers = None
# subscription_update = False
#
# if not headers:
# header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
# headers = header_response['hex']
#
# if not headers:
# # Nothing to do, network thinks we're already at the latest height.
# return
#
# added = await self.headers.connect(height, unhexlify(headers))
# if added > 0:
# height += added
# self._on_header_controller.add(
# BlockHeightEvent(self.headers.height, added))
#
# if rewound > 0:
# # we started rewinding blocks and apparently found
# # a new chain
# rewound = 0
# await self.db.rewind_blockchain(height)
#
# if subscription_update:
# # subscription updates are for latest header already
# # so we don't need to check if there are newer / more
# # on another loop of update_headers(), just return instead
# return
#
# elif added == 0:
# # we had headers to connect but none got connected, probably a reorganization
# height -= 1
# rewound += 1
# log.warning(
# "Blockchain Reorganization: attempting rewind to height %s from starting height %s",
# height, height+rewound
# )
#
# else:
# raise IndexError(f"headers.connect() returned negative number ({added})")
#
# if height < 0:
# raise IndexError(
# "Blockchain reorganization rewound all the way back to genesis hash. "
# "Something is very wrong. Maybe you are on the wrong blockchain?"
# )
#
# if rewound >= 100:
# raise IndexError(
# "Blockchain reorganization dropped {} headers. This is highly unusual. "
# "Will not continue to attempt reorganizing. Please, delete the ledger "
# "synchronization directory inside your wallet directory (folder: '{}') and "
# "restart the program to synchronize from scratch."
# .format(rewound, self.ledger.get_id())
# )
#
# headers = None # ready to download some more headers
#
# # if we made it this far and this was a subscription_update
# # it means something went wrong and now we're doing a more
# # robust sync, turn off subscription update shortcut
# subscription_update = False
#
# async def receive_header(self, response):
# async with self._header_processing_lock:
# header = response[0]
# await self.update_headers(
# height=header['height'], headers=header['hex'], subscription_update=True
# )
#
# async def subscribe_accounts(self):
# if self.network.is_connected and self.accounts:
# log.info("Subscribe to %i accounts", len(self.accounts))
# await asyncio.wait([
# self.subscribe_account(a) for a in self.accounts
# ])
#
# async def subscribe_account(self, account: Account):
# for address_manager in account.address_managers.values():
# await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
# await account.ensure_address_gap()
#
# async def unsubscribe_account(self, account: Account):
# for address in await account.get_addresses():
# await self.network.unsubscribe_address(address)
#
# async def subscribe_addresses(
# self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
# if self.network.is_connected and addresses:
# addresses_remaining = list(addresses)
# while addresses_remaining:
# batch = addresses_remaining[:batch_size]
# results = await self.network.subscribe_address(*batch)
# for address, remote_status in zip(batch, results):
# self._update_tasks.add(self.update_history(address, remote_status, address_manager))
# addresses_remaining = addresses_remaining[batch_size:]
# log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
# len(addresses), *self.network.client.server_address_and_port)
# log.info(
# "finished subscribing to %i addresses on %s:%i", len(addresses),
# *self.network.client.server_address_and_port
# )
#
# def process_status_update(self, update):
# address, remote_status = update
# self._update_tasks.add(self.update_history(address, remote_status))
#
# async def update_history(self, address, remote_status, address_manager: AddressManager = None):
# async with self._address_update_locks[address]:
# self._known_addresses_out_of_sync.discard(address)
#
# local_status, local_history = await self.get_local_status_and_history(address)
#
# if local_status == remote_status:
# return True
#
# remote_history = await self.network.retriable_call(self.network.get_history, address)
# remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
# we_need = set(remote_history) - set(local_history)
# if not we_need:
# return True
#
# cache_tasks: List[asyncio.Task[Transaction]] = []
# synced_history = StringIO()
# loop = asyncio.get_running_loop()
# for i, (txid, remote_height) in enumerate(remote_history):
# if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
# synced_history.write(f'{txid}:{remote_height}:')
# else:
# check_local = (txid, remote_height) not in we_need
# cache_tasks.append(loop.create_task(
# self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
# ))
#
# synced_txs = []
# for task in cache_tasks:
# tx = await task
#
# check_db_for_txos = []
# for txi in tx.inputs:
# if txi.txo_ref.txo is not None:
# continue
# cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
# if cache_item is not None:
# if cache_item.tx is None:
# await cache_item.has_tx.wait()
# assert cache_item.tx is not None
# txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
# else:
# check_db_for_txos.append(txi.txo_ref.hash)
#
# referenced_txos = {} if not check_db_for_txos else {
# txo.id: txo for txo in await self.db.get_txos(
# txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
# )
# }
#
# for txi in tx.inputs:
# if txi.txo_ref.txo is not None:
# continue
# referenced_txo = referenced_txos.get(txi.txo_ref.id)
# if referenced_txo is not None:
# txi.txo_ref = referenced_txo.ref
#
# synced_history.write(f'{tx.id}:{tx.height}:')
# synced_txs.append(tx)
#
# await self.db.save_transaction_io_batch(
# synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue()
# )
# await asyncio.wait([
# self.ledger._on_transaction_controller.add(TransactionEvent(address, tx))
# for tx in synced_txs
# ])
#
# if address_manager is None:
# address_manager = await self.get_address_manager_for_address(address)
#
# if address_manager is not None:
# await address_manager.ensure_address_gap()
#
# local_status, local_history = \
# await self.get_local_status_and_history(address, synced_history.getvalue())
# if local_status != remote_status:
# if local_history == remote_history:
# return True
# log.warning(
# "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
# remote_status, len(remote_history), local_status, len(local_history)
# )
# log.warning("local: %s", local_history)
# log.warning("remote: %s", remote_history)
# self._known_addresses_out_of_sync.add(address)
# return False
# else:
# return True
#
# async def cache_transaction(self, tx_hash, remote_height, check_local=True):
# cache_item = self._tx_cache.get(tx_hash)
# if cache_item is None:
# cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
# elif cache_item.tx is not None and \
# cache_item.tx.height >= remote_height and \
# (cache_item.tx.is_verified or remote_height < 1):
# return cache_item.tx # cached tx is already up-to-date
#
# try:
# cache_item.pending_verifications += 1
# return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
# finally:
# cache_item.pending_verifications -= 1
#
# async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
#
# async with cache_item.lock:
#
# tx = cache_item.tx
#
# if tx is None and check_local:
# # check local db
# tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)
#
# merkle = None
# if tx is None:
# # fetch from network
# _raw, merkle = await self.network.retriable_call(
# self.network.get_transaction_and_merkle, tx_hash, remote_height
# )
# tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
# cache_item.tx = tx # make sure it's saved before caching it
# await self.maybe_verify_transaction(tx, remote_height, merkle)
# return tx
#
# async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
# tx.height = remote_height
# cached = self._tx_cache.get(tx.hash)
# if not cached:
# # cache txs looked up by transaction_show too
# cached = TransactionCacheItem()
# cached.tx = tx
# self._tx_cache[tx.hash] = cached
# if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
# # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
# if not merkle:
# merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
# merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
# header = await self.headers.get(remote_height)
# tx.position = merkle['pos']
# tx.is_verified = merkle_root == header['merkle_root']
#
# async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
# details = await self.db.get_address(address=address)
# for account in self.accounts:
# if account.id == details['account']:
# return account.address_managers[details['chain']]
# return None

View file

@ -447,6 +447,38 @@ class IntegrationTestCase(AsyncioTestCase):
self.db_driver = db_driver
return db
async def add_full_node(self, port):
path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, path, True)
ledger = RegTestLedger(Config.with_same_dir(path).set(
api=f'localhost:{port}',
lbrycrd_dir=self.chain.ledger.conf.lbrycrd_dir,
lbrycrd_rpc_port=self.chain.ledger.conf.lbrycrd_rpc_port,
lbrycrd_peer_port=self.chain.ledger.conf.lbrycrd_peer_port,
lbrycrd_zmq=self.chain.ledger.conf.lbrycrd_zmq
))
service = FullNode(ledger)
console = Console(service)
daemon = Daemon(service, console)
self.addCleanup(daemon.stop)
await daemon.start()
return daemon
async def add_light_client(self, full_node, port, start=True):
path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, path, True)
ledger = RegTestLedger(Config.with_same_dir(path).set(
api=f'localhost:{port}',
full_nodes=[(full_node.conf.api_host, full_node.conf.api_port)]
))
service = LightClient(ledger)
console = Console(service)
daemon = Daemon(service, console)
self.addCleanup(daemon.stop)
if start:
await daemon.start()
return daemon
@staticmethod
def find_claim_txo(tx) -> Optional[Output]:
for txo in tx.outputs:
@ -538,9 +570,11 @@ class CommandTestCase(IntegrationTestCase):
await super().asyncSetUp()
await self.generate(200, wait=False)
self.full_node = self.daemon = await self.add_full_node()
if os.environ.get('TEST_MODE', 'full-node') == 'client':
self.daemon = await self.add_light_client(self.full_node)
self.daemon_port += 1
self.full_node = self.daemon = await self.add_full_node(self.daemon_port)
if os.environ.get('TEST_MODE', 'node') == 'client':
self.daemon_port += 1
self.daemon = await self.add_light_client(self.full_node, self.daemon_port)
self.service = self.daemon.service
self.ledger = self.service.ledger
@ -556,40 +590,6 @@ class CommandTestCase(IntegrationTestCase):
await self.chain.send_to_address(addresses[0], '10.0')
await self.generate(5)
async def add_full_node(self):
self.daemon_port += 1
path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, path, True)
ledger = RegTestLedger(Config.with_same_dir(path).set(
api=f'localhost:{self.daemon_port}',
lbrycrd_dir=self.chain.ledger.conf.lbrycrd_dir,
lbrycrd_rpc_port=self.chain.ledger.conf.lbrycrd_rpc_port,
lbrycrd_peer_port=self.chain.ledger.conf.lbrycrd_peer_port,
lbrycrd_zmq=self.chain.ledger.conf.lbrycrd_zmq,
spv_address_filters=False
))
service = FullNode(ledger)
console = Console(service)
daemon = Daemon(service, console)
self.addCleanup(daemon.stop)
await daemon.start()
return daemon
async def add_light_client(self, full_node):
self.daemon_port += 1
path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, path, True)
ledger = RegTestLedger(Config.with_same_dir(path).set(
api=f'localhost:{self.daemon_port}',
full_nodes=[(full_node.conf.api_host, full_node.conf.api_port)]
))
service = LightClient(ledger)
console = Console(service)
daemon = Daemon(service, console)
self.addCleanup(daemon.stop)
await daemon.start()
return daemon
async def asyncTearDown(self):
await super().asyncTearDown()
for wallet_node in self.extra_wallet_nodes:

View file

@ -12,6 +12,7 @@ import ecdsa
from lbry.constants import COIN
from lbry.db import Database, CLAIM_TYPE_CODES, TXO_TYPES
from lbry.db.tables import AccountAddress
from lbry.blockchain import Ledger
from lbry.error import InvalidPasswordError
from lbry.crypto.crypt import aes_encrypt, aes_decrypt
@ -214,12 +215,12 @@ class Account:
HierarchicalDeterministic.name: HierarchicalDeterministic,
}
def __init__(self, ledger: Ledger, db: Database, name: str,
def __init__(self, db: Database, name: str,
phrase: str, language: str, private_key_string: str,
encrypted: bool, private_key: Optional[PrivateKey], public_key: PubKey,
address_generator: dict, modified_on: float, channel_keys: dict) -> None:
self.ledger = ledger
self.db = db
self.ledger = db.ledger
self.id = public_key.address
self.name = name
self.phrase = phrase
@ -245,10 +246,10 @@ class Account:
@classmethod
async def generate(
cls, ledger: Ledger, db: Database,
name: str = None, language: str = 'en',
address_generator: dict = None):
return await cls.from_dict(ledger, db, {
cls, db: Database, name: str = None,
language: str = 'en', address_generator: dict = None
):
return await cls.from_dict(db, {
'name': name,
'seed': await mnemonic.generate_phrase(language),
'language': language,
@ -276,13 +277,12 @@ class Account:
return phrase, private_key, public_key
@classmethod
async def from_dict(cls, ledger: Ledger, db: Database, d: dict):
phrase, private_key, public_key = await cls.keys_from_dict(ledger, d)
async def from_dict(cls, db: Database, d: dict):
phrase, private_key, public_key = await cls.keys_from_dict(db.ledger, d)
name = d.get('name')
if not name:
name = f'Account #{public_key.address}'
return cls(
ledger=ledger,
db=db,
name=name,
phrase=phrase,
@ -415,7 +415,7 @@ class Account:
return await self.db.get_addresses(account=self, **constraints)
async def get_addresses(self, **constraints) -> List[str]:
rows = await self.get_address_records(cols=['account_address.address'], **constraints)
rows = await self.get_address_records(cols=[AccountAddress.c.address], **constraints)
return [r['address'] for r in rows]
async def get_valid_receiving_address(self, default_address: str) -> str:

View file

@ -1,385 +0,0 @@
import os
import logging
import asyncio
import sqlite3
import platform
from binascii import hexlify
from dataclasses import dataclass
from contextvars import ContextVar
from concurrent.futures.thread import ThreadPoolExecutor
from concurrent.futures.process import ProcessPoolExecutor
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
log = logging.getLogger(__name__)
sqlite3.enable_callback_tracebacks(True)
@dataclass
class ReaderProcessState:
cursor: sqlite3.Cursor
reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context')
def initializer(path):
db = sqlite3.connect(path)
db.row_factory = dict_row_factory
db.executescript("pragma journal_mode=WAL;")
reader = ReaderProcessState(db.cursor())
reader_context.set(reader)
def run_read_only_fetchall(sql, params):
cursor = reader_context.get().cursor
try:
return cursor.execute(sql, params).fetchall()
except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e)
raise
def run_read_only_fetchone(sql, params):
cursor = reader_context.get().cursor
try:
return cursor.execute(sql, params).fetchone()
except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e)
raise
if platform.system() == 'Windows' or 'ANDROID_ARGUMENT' in os.environ:
ReaderExecutorClass = ThreadPoolExecutor
else:
ReaderExecutorClass = ProcessPoolExecutor
class AIOSQLite:
reader_executor: ReaderExecutorClass
def __init__(self):
# has to be single threaded as there is no mapping of thread:connection
self.writer_executor = ThreadPoolExecutor(max_workers=1)
self.writer_connection: Optional[sqlite3.Connection] = None
self._closing = False
self.query_count = 0
self.write_lock = asyncio.Lock()
self.writers = 0
self.read_ready = asyncio.Event()
@classmethod
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
sqlite3.enable_callback_tracebacks(True)
db = cls()
def _connect_writer():
db.writer_connection = sqlite3.connect(path, *args, **kwargs)
readers = max(os.cpu_count() - 2, 2)
db.reader_executor = ReaderExecutorClass(
max_workers=readers, initializer=initializer, initargs=(path, )
)
await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer)
db.read_ready.set()
return db
async def close(self):
if self._closing:
return
self._closing = True
await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close)
self.writer_executor.shutdown(wait=True)
self.reader_executor.shutdown(wait=True)
self.read_ready.clear()
self.writer_connection = None
def executemany(self, sql: str, params: Iterable):
params = params if params is not None else []
# this fetchall is needed to prevent SQLITE_MISUSE
return self.run(lambda conn: conn.executemany(sql, params).fetchall())
def executescript(self, script: str) -> Awaitable:
return self.run(lambda conn: conn.executescript(script))
async def _execute_fetch(self, sql: str, parameters: Iterable = None,
read_only=False, fetch_all: bool = False) -> List[dict]:
read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone
parameters = parameters if parameters is not None else []
if read_only:
while self.writers:
await self.read_ready.wait()
return await asyncio.get_event_loop().run_in_executor(
self.reader_executor, read_only_fn, sql, parameters
)
if fetch_all:
return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
return await self.run(lambda conn: conn.execute(sql, parameters).fetchone())
async def execute_fetchall(self, sql: str, parameters: Iterable = None,
read_only=False) -> List[dict]:
return await self._execute_fetch(sql, parameters, read_only, fetch_all=True)
async def execute_fetchone(self, sql: str, parameters: Iterable = None,
read_only=False) -> List[dict]:
return await self._execute_fetch(sql, parameters, read_only, fetch_all=False)
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters))
async def run(self, fun, *args, **kwargs):
self.writers += 1
self.read_ready.clear()
async with self.write_lock:
try:
return await asyncio.get_event_loop().run_in_executor(
self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
)
finally:
self.writers -= 1
if not self.writers:
self.read_ready.set()
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
self.writer_connection.execute('begin')
try:
self.query_count += 1
result = fun(self.writer_connection, *args, **kwargs) # type: ignore
self.writer_connection.commit()
return result
except (Exception, OSError) as e:
log.exception('Error running transaction:', exc_info=e)
self.writer_connection.rollback()
log.warning("rolled back")
raise
def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
return asyncio.get_event_loop().run_in_executor(
self.writer_executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
)
def __run_transaction_with_foreign_keys_disabled(self,
fun: Callable[[sqlite3.Connection, Any, Any], Any],
args, kwargs):
foreign_keys_enabled, = self.writer_connection.execute("pragma foreign_keys").fetchone()
if not foreign_keys_enabled:
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
try:
self.writer_connection.execute('pragma foreign_keys=off').fetchone()
return self.__run_transaction(fun, *args, **kwargs)
finally:
self.writer_connection.execute('pragma foreign_keys=on').fetchone()
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
sql, values = [], {}
for key, constraint in constraints.items():
tag = '0'
if '#' in key:
key, tag = key[:key.index('#')], key[key.index('#')+1:]
col, op, key = key, '=', key.replace('.', '_')
if not key:
sql.append(constraint)
continue
if key.startswith('$$'):
col, key = col[2:], key[1:]
elif key.startswith('$'):
values[key] = constraint
continue
if key.endswith('__not'):
col, op = col[:-len('__not')], '!='
elif key.endswith('__is_null'):
col = col[:-len('__is_null')]
sql.append(f'{col} IS NULL')
continue
if key.endswith('__is_not_null'):
col = col[:-len('__is_not_null')]
sql.append(f'{col} IS NOT NULL')
continue
if key.endswith('__lt'):
col, op = col[:-len('__lt')], '<'
elif key.endswith('__lte'):
col, op = col[:-len('__lte')], '<='
elif key.endswith('__gt'):
col, op = col[:-len('__gt')], '>'
elif key.endswith('__gte'):
col, op = col[:-len('__gte')], '>='
elif key.endswith('__like'):
col, op = col[:-len('__like')], 'LIKE'
elif key.endswith('__not_like'):
col, op = col[:-len('__not_like')], 'NOT LIKE'
elif key.endswith('__in') or key.endswith('__not_in'):
if key.endswith('__in'):
col, op, one_val_op = col[:-len('__in')], 'IN', '='
else:
col, op, one_val_op = col[:-len('__not_in')], 'NOT IN', '!='
if constraint:
if isinstance(constraint, (list, set, tuple)):
if len(constraint) == 1:
values[f'{key}{tag}'] = next(iter(constraint))
sql.append(f'{col} {one_val_op} :{key}{tag}')
else:
keys = []
for i, val in enumerate(constraint):
keys.append(f':{key}{tag}_{i}')
values[f'{key}{tag}_{i}'] = val
sql.append(f'{col} {op} ({", ".join(keys)})')
elif isinstance(constraint, str):
sql.append(f'{col} {op} ({constraint})')
else:
raise ValueError(f"{col} requires a list, set or string as constraint value.")
continue
elif key.endswith('__any') or key.endswith('__or'):
where, subvalues = constraints_to_sql(constraint, ' OR ', key+tag+'_')
sql.append(f'({where})')
values.update(subvalues)
continue
if key.endswith('__and'):
where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_')
sql.append(f'({where})')
values.update(subvalues)
continue
sql.append(f'{col} {op} :{prepend_key}{key}{tag}')
values[prepend_key+key+tag] = constraint
return joiner.join(sql) if sql else '', values
def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
sql = [select]
limit = constraints.pop('limit', None)
offset = constraints.pop('offset', None)
order_by = constraints.pop('order_by', None)
group_by = constraints.pop('group_by', None)
accounts = constraints.pop('accounts', [])
if accounts:
constraints['account__in'] = [a.public_key.address for a in accounts]
where, values = constraints_to_sql(constraints)
if where:
sql.append('WHERE')
sql.append(where)
if group_by is not None:
sql.append(f'GROUP BY {group_by}')
if order_by:
sql.append('ORDER BY')
if isinstance(order_by, str):
sql.append(order_by)
elif isinstance(order_by, list):
sql.append(', '.join(order_by))
else:
raise ValueError("order_by must be string or list")
if limit is not None:
sql.append(f'LIMIT {limit}')
if offset is not None:
sql.append(f'OFFSET {offset}')
return ' '.join(sql), values
def interpolate(sql, values):
for k in sorted(values.keys(), reverse=True):
value = values[k]
if isinstance(value, bytes):
value = f"X'{hexlify(value).decode()}'"
elif isinstance(value, str):
value = f"'{value}'"
else:
value = str(value)
sql = sql.replace(f":{k}", value)
return sql
def constrain_single_or_list(constraints, column, value, convert=lambda x: x):
if value is not None:
if isinstance(value, list):
value = [convert(v) for v in value]
if len(value) == 1:
constraints[column] = value[0]
elif len(value) > 1:
constraints[f"{column}__in"] = value
else:
constraints[column] = convert(value)
return constraints
class SQLiteMixin:
SCHEMA_VERSION: Optional[str] = None
CREATE_TABLES_QUERY: str
MAX_QUERY_VARIABLES = 900
CREATE_VERSION_TABLE = """
create table if not exists version (
version text
);
"""
def __init__(self, path):
self._db_path = path
self.db: AIOSQLite = None
self.ledger = None
async def open(self):
log.info("connecting to database: %s", self._db_path)
self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
if self.SCHEMA_VERSION:
tables = [t[0] for t in await self.db.execute_fetchall(
"SELECT name FROM sqlite_master WHERE type='table';"
)]
if tables:
if 'version' in tables:
version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;")
if version == (self.SCHEMA_VERSION,):
return
await self.db.executescript('\n'.join(
f"DROP TABLE {table};" for table in tables
))
await self.db.execute(self.CREATE_VERSION_TABLE)
await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,))
await self.db.executescript(self.CREATE_TABLES_QUERY)
async def close(self):
await self.db.close()
@staticmethod
def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
replace: bool = False) -> Tuple[str, List]:
columns, values = [], []
for column, value in data.items():
columns.append(column)
values.append(value)
policy = ""
if ignore_duplicate:
policy = " OR IGNORE"
if replace:
policy = " OR REPLACE"
sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
)
return sql, values
@staticmethod
def _update_sql(table: str, data: dict, where: str,
constraints: Union[list, tuple]) -> Tuple[str, list]:
columns, values = [], []
for column, value in data.items():
columns.append(f"{column} = ?")
values.append(value)
values.extend(constraints)
sql = "UPDATE {} SET {} WHERE {}".format(
table, ', '.join(columns), where
)
return sql, values
def dict_row_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d

View file

@ -1,22 +1,30 @@
import os
import stat
import json
import asyncio
import logging
from typing import Optional, Dict
from lbry.db import Database
from lbry.blockchain.ledger import Ledger
from .wallet import Wallet
from .account import SingleKey, HierarchicalDeterministic
log = logging.getLogger(__name__)
class WalletManager:
def __init__(self, ledger: Ledger, db: Database):
self.ledger = ledger
def __init__(self, db: Database):
self.db = db
self.ledger = db.ledger
self.wallets: Dict[str, Wallet] = {}
if self.ledger.conf.wallet_storage == "file":
self.storage = FileWallet(self.db, self.ledger.conf.wallet_dir)
elif self.ledger.conf.wallet_storage == "database":
self.storage = DatabaseWallet(self.db)
else:
raise Exception(f"Unknown wallet storage format: {self.ledger.conf.wallet_storage}")
def __getitem__(self, wallet_id: str) -> Wallet:
try:
@ -43,32 +51,12 @@ class WalletManager:
raise ValueError("Cannot spend funds with locked wallet, unlock first.")
return wallet
@property
def path(self):
return os.path.join(self.ledger.conf.wallet_dir, 'wallets')
def sync_ensure_path_exists(self):
if not os.path.exists(self.path):
os.mkdir(self.path)
async def ensure_path_exists(self):
await asyncio.get_running_loop().run_in_executor(
None, self.sync_ensure_path_exists
)
async def load(self):
wallets_directory = self.path
async def initialize(self):
for wallet_id in self.ledger.conf.wallets:
if wallet_id in self.wallets:
log.warning("Ignoring duplicate wallet_id in config: %s", wallet_id)
continue
wallet_path = os.path.join(wallets_directory, wallet_id)
if not os.path.exists(wallet_path):
if not wallet_id == "default_wallet": # we'll probably generate this wallet, don't show error
log.error("Could not load wallet, file does not exist: %s", wallet_path)
continue
wallet = await Wallet.from_path(self.ledger, self.db, wallet_path)
self.add(wallet)
await self.load(wallet_id)
default_wallet = self.default
if default_wallet is None:
if self.ledger.conf.create_default_wallet:
@ -83,34 +71,153 @@ class WalletManager:
elif not default_wallet.has_accounts and self.ledger.conf.create_default_account:
await default_wallet.accounts.generate()
def add(self, wallet: Wallet) -> Wallet:
self.wallets[wallet.id] = wallet
return wallet
async def add_from_path(self, wallet_path) -> Wallet:
wallet_id = os.path.basename(wallet_path)
if wallet_id in self.wallets:
existing = self.wallets.get(wallet_id)
if existing.storage.path == wallet_path:
raise Exception(f"Wallet '{wallet_id}' is already loaded.")
raise Exception(
f"Wallet '{wallet_id}' is already loaded from '{existing.storage.path}'"
f" and cannot be loaded from '{wallet_path}'. Consider changing the wallet"
f" filename to be unique in order to avoid conflicts."
)
wallet = await Wallet.from_path(self.ledger, self.db, wallet_path)
async def load(self, wallet_id: str) -> Optional[Wallet]:
wallet = await self.storage.get(wallet_id)
if wallet is not None:
return self.add(wallet)
async def create(
self, wallet_id: str, name: str,
create_account=False, language='en', single_key=False) -> Wallet:
self, wallet_id: str, name: str = "",
create_account=False, language="en", single_key=False
) -> Wallet:
if wallet_id in self.wallets:
raise Exception(f"Wallet with id '{wallet_id}' is already loaded and cannot be created.")
wallet_path = os.path.join(self.path, wallet_id)
if os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.")
wallet = await Wallet.create(
self.ledger, self.db, wallet_path, name,
create_account, language, single_key
)
if await self.storage.exists(wallet_id):
raise Exception(f"Wallet '{wallet_id}' already exists, use 'wallet_add' to load wallet.")
wallet = Wallet(wallet_id, self.db, name)
if create_account:
await wallet.accounts.generate(language=language, address_generator={
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
})
await self.storage.save(wallet)
return self.add(wallet)
def add(self, wallet: Wallet) -> Wallet:
self.wallets[wallet.id] = wallet
wallet.on_change.listen(lambda _: self.storage.save(wallet))
return wallet
async def _report_state(self):
try:
for account in self.accounts:
balance = dewies_to_lbc(await account.get_balance(include_claims=True))
_, channel_count = await account.get_channels(limit=1)
claim_count = await account.get_claim_count()
if isinstance(account.receiving, SingleKey):
log.info("Loaded single key account %s with %s LBC. "
"%d channels, %d certificates and %d claims",
account.id, balance, channel_count, len(account.channel_keys), claim_count)
else:
total_receiving = len(await account.receiving.get_addresses())
total_change = len(await account.change.get_addresses())
log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), "
"%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ",
account.id, balance, total_receiving, account.receiving.gap, total_change,
account.change.gap, channel_count, len(account.channel_keys), claim_count)
except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception(
'Failed to display wallet state, please file issue '
'for this bug along with the traceback you see below:'
)
class WalletStorage:
async def prepare(self):
raise NotImplementedError
async def exists(self, walllet_id: str) -> bool:
raise NotImplementedError
async def get(self, wallet_id: str) -> Wallet:
raise NotImplementedError
async def save(self, wallet: Wallet):
raise NotImplementedError
class FileWallet(WalletStorage):
def __init__(self, db, wallet_dir):
self.db = db
self.wallet_dir = wallet_dir
def get_wallet_path(self, wallet_id: str):
return os.path.join(self.wallet_dir, wallet_id)
async def prepare(self):
await asyncio.get_running_loop().run_in_executor(
None, self.sync_ensure_wallets_directory_exists
)
def sync_ensure_wallets_directory_exists(self):
if not os.path.exists(self.wallet_dir):
os.mkdir(self.wallet_dir)
async def exists(self, wallet_id: str) -> bool:
return os.path.exists(self.get_wallet_path(wallet_id))
async def get(self, wallet_id: str) -> Wallet:
wallet_dict = await asyncio.get_running_loop().run_in_executor(
None, self.sync_read, wallet_id
)
if wallet_dict is not None:
return await Wallet.from_dict(wallet_id, wallet_dict, self.db)
def sync_read(self, wallet_id):
try:
with open(self.get_wallet_path(wallet_id), 'r') as f:
json_data = f.read()
return json.loads(json_data)
except FileNotFoundError:
return None
async def save(self, wallet: Wallet):
return await asyncio.get_running_loop().run_in_executor(
None, self.sync_write, wallet
)
def sync_write(self, wallet: Wallet):
temp_path = os.path.join(self.wallet_dir, f".tmp.{os.path.basename(wallet.id)}")
with open(temp_path, "w") as f:
f.write(wallet.to_serialized())
f.flush()
os.fsync(f.fileno())
wallet_path = self.get_wallet_path(wallet.id)
if os.path.exists(wallet_path):
mode = os.stat(wallet_path).st_mode
else:
mode = stat.S_IREAD | stat.S_IWRITE
try:
os.rename(temp_path, wallet_path)
except Exception: # pylint: disable=broad-except
os.remove(wallet_path)
os.rename(temp_path, wallet_path)
os.chmod(wallet_path, mode)
class DatabaseWallet(WalletStorage):
def __init__(self, db: 'Database'):
self.db = db
async def prepare(self):
pass
async def exists(self, wallet_id: str) -> bool:
return await self.db.has_wallet(wallet_id)
async def get(self, wallet_id: str) -> Wallet:
data = await self.db.get_wallet(wallet_id)
if data:
wallet_dict = json.loads(data['data'])
if wallet_dict is not None:
return await Wallet.from_dict(wallet_id, wallet_dict, self.db)
async def save(self, wallet: Wallet):
await self.db.add_wallet(
wallet.id, wallet.to_serialized()
)

View file

@ -1,60 +0,0 @@
import os
import stat
import json
import asyncio
class WalletStorage:
VERSION = 1
def __init__(self, path=None):
self.path = path
def sync_read(self):
with open(self.path, 'r') as f:
json_data = f.read()
json_dict = json.loads(json_data)
if json_dict.get('version') == self.VERSION:
return json_dict
else:
return self.upgrade(json_dict)
async def read(self):
return await asyncio.get_running_loop().run_in_executor(
None, self.sync_read
)
def upgrade(self, json_dict):
version = json_dict.pop('version', -1)
if version == -1:
pass
json_dict['version'] = self.VERSION
return json_dict
def sync_write(self, json_dict):
json_data = json.dumps(json_dict, indent=4, sort_keys=True)
if self.path is None:
return json_data
temp_path = "{}.tmp.{}".format(self.path, os.getpid())
with open(temp_path, "w") as f:
f.write(json_data)
f.flush()
os.fsync(f.fileno())
if os.path.exists(self.path):
mode = os.stat(self.path).st_mode
else:
mode = stat.S_IREAD | stat.S_IWRITE
try:
os.rename(temp_path, self.path)
except Exception: # pylint: disable=broad-except
os.remove(self.path)
os.rename(temp_path, self.path)
os.chmod(self.path, mode)
async def write(self, json_dict):
return await asyncio.get_running_loop().run_in_executor(
None, self.sync_write, json_dict
)

View file

@ -1,420 +0,0 @@
import asyncio
import logging
#from io import StringIO
#from functools import partial
#from operator import itemgetter
from collections import defaultdict
#from binascii import hexlify, unhexlify
from typing import List, Optional, DefaultDict, NamedTuple
#from lbry.crypto.hash import double_sha256, sha256
from lbry.tasks import TaskGroup
from lbry.blockchain.transaction import Transaction
#from lbry.blockchain.block import get_block_filter
from lbry.event import EventController
from lbry.service.base import Service, Sync
from .account import AddressManager
log = logging.getLogger(__name__)
class TransactionEvent(NamedTuple):
address: str
tx: Transaction
class AddressesGeneratedEvent(NamedTuple):
address_manager: AddressManager
addresses: List[str]
class TransactionCacheItem:
__slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications'
def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
self.has_tx = asyncio.Event()
self.lock = lock or asyncio.Lock()
self._tx = self.tx = tx
self.pending_verifications = 0
@property
def tx(self) -> Optional[Transaction]:
return self._tx
@tx.setter
def tx(self, tx: Transaction):
self._tx = tx
if tx is not None:
self.has_tx.set()
class SPVSync(Sync):
def __init__(self, service: Service):
super().__init__(service.ledger, service.db)
self.accounts = []
self._on_header_controller = EventController()
self.on_header = self._on_header_controller.stream
self.on_header.listen(
lambda change: log.info(
'%s: added %s header blocks',
self.ledger.get_id(), change
)
)
self._download_height = 0
self._on_ready_controller = EventController()
self.on_ready = self._on_ready_controller.stream
#self._tx_cache = pylru.lrucache(100000)
self._update_tasks = TaskGroup()
self._other_tasks = TaskGroup() # that we dont need to start
self._header_processing_lock = asyncio.Lock()
self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._known_addresses_out_of_sync = set()
# async def advance(self):
# address_array = [
# bytearray(a['address'].encode())
# for a in await self.service.db.get_all_addresses()
# ]
# block_filters = await self.service.get_block_address_filters()
# for block_hash, block_filter in block_filters.items():
# bf = get_block_filter(block_filter)
# if bf.MatchAny(address_array):
# print(f'match: {block_hash} - {block_filter}')
# tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash)
# for txid, tx_filter in tx_filters.items():
# tf = get_block_filter(tx_filter)
# if tf.MatchAny(address_array):
# print(f' match: {txid} - {tx_filter}')
# txs = await self.service.search_transactions([txid])
# tx = Transaction(unhexlify(txs[txid]))
# await self.service.db.insert_transaction(tx)
#
# async def get_local_status_and_history(self, address, history=None):
# if not history:
# address_details = await self.db.get_address(address=address)
# history = (address_details['history'] if address_details else '') or ''
# parts = history.split(':')[:-1]
# return (
# hexlify(sha256(history.encode())).decode() if history else None,
# list(zip(parts[0::2], map(int, parts[1::2])))
# )
#
# @staticmethod
# def get_root_of_merkle_tree(branches, branch_positions, working_branch):
# for i, branch in enumerate(branches):
# other_branch = unhexlify(branch)[::-1]
# other_branch_on_left = bool((branch_positions >> i) & 1)
# if other_branch_on_left:
# combined = other_branch + working_branch
# else:
# combined = working_branch + other_branch
# working_branch = double_sha256(combined)
# return hexlify(working_branch[::-1])
#
async def start(self):
fully_synced = self.on_ready.first
#asyncio.create_task(self.network.start())
#await self.network.on_connected.first
#async with self._header_processing_lock:
# await self._update_tasks.add(self.initial_headers_sync())
await fully_synced
#
# async def join_network(self, *_):
# log.info("Subscribing and updating accounts.")
# await self._update_tasks.add(self.subscribe_accounts())
# await self._update_tasks.done.wait()
# self._on_ready_controller.add(True)
#
async def stop(self):
self._update_tasks.cancel()
self._other_tasks.cancel()
await self._update_tasks.done.wait()
await self._other_tasks.done.wait()
#
# @property
# def local_height_including_downloaded_height(self):
# return max(self.headers.height, self._download_height)
#
# async def initial_headers_sync(self):
# get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
# self.headers.chunk_getter = get_chunk
#
# async def doit():
# for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)):
# async with self._header_processing_lock:
# await self.headers.ensure_chunk_at(height)
# self._other_tasks.add(doit())
# await self.update_headers()
#
# async def update_headers(self, height=None, headers=None, subscription_update=False):
# rewound = 0
# while True:
#
# if height is None or height > len(self.headers):
# # sometimes header subscription updates are for a header in the future
# # which can't be connected, so we do a normal header sync instead
# height = len(self.headers)
# headers = None
# subscription_update = False
#
# if not headers:
# header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
# headers = header_response['hex']
#
# if not headers:
# # Nothing to do, network thinks we're already at the latest height.
# return
#
# added = await self.headers.connect(height, unhexlify(headers))
# if added > 0:
# height += added
# self._on_header_controller.add(
# BlockHeightEvent(self.headers.height, added))
#
# if rewound > 0:
# # we started rewinding blocks and apparently found
# # a new chain
# rewound = 0
# await self.db.rewind_blockchain(height)
#
# if subscription_update:
# # subscription updates are for latest header already
# # so we don't need to check if there are newer / more
# # on another loop of update_headers(), just return instead
# return
#
# elif added == 0:
# # we had headers to connect but none got connected, probably a reorganization
# height -= 1
# rewound += 1
# log.warning(
# "Blockchain Reorganization: attempting rewind to height %s from starting height %s",
# height, height+rewound
# )
#
# else:
# raise IndexError(f"headers.connect() returned negative number ({added})")
#
# if height < 0:
# raise IndexError(
# "Blockchain reorganization rewound all the way back to genesis hash. "
# "Something is very wrong. Maybe you are on the wrong blockchain?"
# )
#
# if rewound >= 100:
# raise IndexError(
# "Blockchain reorganization dropped {} headers. This is highly unusual. "
# "Will not continue to attempt reorganizing. Please, delete the ledger "
# "synchronization directory inside your wallet directory (folder: '{}') and "
# "restart the program to synchronize from scratch."
# .format(rewound, self.ledger.get_id())
# )
#
# headers = None # ready to download some more headers
#
# # if we made it this far and this was a subscription_update
# # it means something went wrong and now we're doing a more
# # robust sync, turn off subscription update shortcut
# subscription_update = False
#
# async def receive_header(self, response):
# async with self._header_processing_lock:
# header = response[0]
# await self.update_headers(
# height=header['height'], headers=header['hex'], subscription_update=True
# )
#
# async def subscribe_accounts(self):
# if self.network.is_connected and self.accounts:
# log.info("Subscribe to %i accounts", len(self.accounts))
# await asyncio.wait([
# self.subscribe_account(a) for a in self.accounts
# ])
#
# async def subscribe_account(self, account: Account):
# for address_manager in account.address_managers.values():
# await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
# await account.ensure_address_gap()
#
# async def unsubscribe_account(self, account: Account):
# for address in await account.get_addresses():
# await self.network.unsubscribe_address(address)
#
# async def subscribe_addresses(
# self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
# if self.network.is_connected and addresses:
# addresses_remaining = list(addresses)
# while addresses_remaining:
# batch = addresses_remaining[:batch_size]
# results = await self.network.subscribe_address(*batch)
# for address, remote_status in zip(batch, results):
# self._update_tasks.add(self.update_history(address, remote_status, address_manager))
# addresses_remaining = addresses_remaining[batch_size:]
# log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
# len(addresses), *self.network.client.server_address_and_port)
# log.info(
# "finished subscribing to %i addresses on %s:%i", len(addresses),
# *self.network.client.server_address_and_port
# )
#
# def process_status_update(self, update):
# address, remote_status = update
# self._update_tasks.add(self.update_history(address, remote_status))
#
# async def update_history(self, address, remote_status, address_manager: AddressManager = None):
# async with self._address_update_locks[address]:
# self._known_addresses_out_of_sync.discard(address)
#
# local_status, local_history = await self.get_local_status_and_history(address)
#
# if local_status == remote_status:
# return True
#
# remote_history = await self.network.retriable_call(self.network.get_history, address)
# remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
# we_need = set(remote_history) - set(local_history)
# if not we_need:
# return True
#
# cache_tasks: List[asyncio.Task[Transaction]] = []
# synced_history = StringIO()
# loop = asyncio.get_running_loop()
# for i, (txid, remote_height) in enumerate(remote_history):
# if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
# synced_history.write(f'{txid}:{remote_height}:')
# else:
# check_local = (txid, remote_height) not in we_need
# cache_tasks.append(loop.create_task(
# self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
# ))
#
# synced_txs = []
# for task in cache_tasks:
# tx = await task
#
# check_db_for_txos = []
# for txi in tx.inputs:
# if txi.txo_ref.txo is not None:
# continue
# cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
# if cache_item is not None:
# if cache_item.tx is None:
# await cache_item.has_tx.wait()
# assert cache_item.tx is not None
# txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
# else:
# check_db_for_txos.append(txi.txo_ref.hash)
#
# referenced_txos = {} if not check_db_for_txos else {
# txo.id: txo for txo in await self.db.get_txos(
# txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
# )
# }
#
# for txi in tx.inputs:
# if txi.txo_ref.txo is not None:
# continue
# referenced_txo = referenced_txos.get(txi.txo_ref.id)
# if referenced_txo is not None:
# txi.txo_ref = referenced_txo.ref
#
# synced_history.write(f'{tx.id}:{tx.height}:')
# synced_txs.append(tx)
#
# await self.db.save_transaction_io_batch(
# synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue()
# )
# await asyncio.wait([
# self.ledger._on_transaction_controller.add(TransactionEvent(address, tx))
# for tx in synced_txs
# ])
#
# if address_manager is None:
# address_manager = await self.get_address_manager_for_address(address)
#
# if address_manager is not None:
# await address_manager.ensure_address_gap()
#
# local_status, local_history = \
# await self.get_local_status_and_history(address, synced_history.getvalue())
# if local_status != remote_status:
# if local_history == remote_history:
# return True
# log.warning(
# "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
# remote_status, len(remote_history), local_status, len(local_history)
# )
# log.warning("local: %s", local_history)
# log.warning("remote: %s", remote_history)
# self._known_addresses_out_of_sync.add(address)
# return False
# else:
# return True
#
# async def cache_transaction(self, tx_hash, remote_height, check_local=True):
# cache_item = self._tx_cache.get(tx_hash)
# if cache_item is None:
# cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
# elif cache_item.tx is not None and \
# cache_item.tx.height >= remote_height and \
# (cache_item.tx.is_verified or remote_height < 1):
# return cache_item.tx # cached tx is already up-to-date
#
# try:
# cache_item.pending_verifications += 1
# return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
# finally:
# cache_item.pending_verifications -= 1
#
# async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
#
# async with cache_item.lock:
#
# tx = cache_item.tx
#
# if tx is None and check_local:
# # check local db
# tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)
#
# merkle = None
# if tx is None:
# # fetch from network
# _raw, merkle = await self.network.retriable_call(
# self.network.get_transaction_and_merkle, tx_hash, remote_height
# )
# tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
# cache_item.tx = tx # make sure it's saved before caching it
# await self.maybe_verify_transaction(tx, remote_height, merkle)
# return tx
#
# async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
# tx.height = remote_height
# cached = self._tx_cache.get(tx.hash)
# if not cached:
# # cache txs looked up by transaction_show too
# cached = TransactionCacheItem()
# cached.tx = tx
# self._tx_cache[tx.hash] = cached
# if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
# # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
# if not merkle:
# merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
# merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
# header = await self.headers.get(remote_height)
# tx.position = merkle['pos']
# tx.is_verified = merkle_root == header['merkle_root']
#
# async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
# details = await self.db.get_address(address=address)
# for account in self.accounts:
# if account.id == details['account']:
# return account.address_managers[details['chain']]
# return None

View file

@ -1,5 +1,4 @@
# pylint: disable=arguments-differ
import os
import json
import zlib
import asyncio
@ -10,9 +9,8 @@ from hashlib import sha256
from operator import attrgetter
from decimal import Decimal
from lbry.db import Database, SPENDABLE_TYPE_CODES, Result
from lbry.blockchain.ledger import Ledger
from lbry.event import EventController
from lbry.constants import COIN, NULL_HASH32
from lbry.blockchain.transaction import Transaction, Input, Output
from lbry.blockchain.dewies import dewies_to_lbc
@ -23,9 +21,8 @@ from lbry.schema.purchase import Purchase
from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError
from lbry.stream.managed_stream import ManagedStream
from .account import Account, SingleKey, HierarchicalDeterministic
from .account import Account
from .coinselection import CoinSelector, OutputEffectiveAmountEstimator
from .storage import WalletStorage
from .preferences import TimestampedPreferences
@ -37,20 +34,22 @@ ENCRYPT_ON_DISK = 'encrypt-on-disk'
class Wallet:
""" The primary role of Wallet is to encapsulate a collection
of accounts (seed/private keys) and the spending rules / settings
for the coins attached to those accounts. Wallets are represented
by physical files on the filesystem.
for the coins attached to those accounts.
"""
def __init__(self, ledger: Ledger, db: Database, name: str, storage: WalletStorage, preferences: dict):
self.ledger = ledger
VERSION = 1
def __init__(self, wallet_id: str, db: Database, name: str = "", preferences: dict = None):
self.id = wallet_id
self.db = db
self.name = name
self.storage = storage
self.ledger = db.ledger
self.preferences = TimestampedPreferences(preferences or {})
self.encryption_password: Optional[str] = None
self.id = self.get_id()
self.utxo_lock = asyncio.Lock()
self._on_change_controller = EventController()
self.on_change = self._on_change_controller.stream
self.accounts = AccountListManager(self)
self.claims = ClaimListManager(self)
@ -60,61 +59,55 @@ class Wallet:
self.purchases = PurchaseListManager(self)
self.supports = SupportListManager(self)
def get_id(self):
return os.path.basename(self.storage.path) if self.storage.path else self.name
@classmethod
async def create(
cls, ledger: Ledger, db: Database, path: str, name: str,
create_account=False, language='en', single_key=False):
wallet = cls(ledger, db, name, WalletStorage(path), {})
if create_account:
await wallet.accounts.generate(language=language, address_generator={
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
async def notify_change(self, field: str, value=None):
await self._on_change_controller.add({
'field': field, 'value': value
})
await wallet.save()
return wallet
@classmethod
async def from_path(cls, ledger: Ledger, db: Database, path: str):
return await cls.from_storage(ledger, db, WalletStorage(path))
@classmethod
async def from_storage(cls, ledger: Ledger, db: Database, storage: WalletStorage) -> 'Wallet':
json_dict = await storage.read()
if 'ledger' in json_dict and json_dict['ledger'] != ledger.get_id():
async def from_dict(cls, wallet_id: str, wallet_dict, db: Database) -> 'Wallet':
if 'ledger' in wallet_dict and wallet_dict['ledger'] != db.ledger.get_id():
raise ValueError(
f"Using ledger {ledger.get_id()} but wallet is {json_dict['ledger']}."
f"Using ledger {db.ledger.get_id()} but wallet is {wallet_dict['ledger']}."
)
version = wallet_dict.get('version')
if version == 1:
pass
wallet = cls(
ledger, db,
name=json_dict.get('name', 'Wallet'),
storage=storage,
preferences=json_dict.get('preferences', {}),
wallet_id, db,
name=wallet_dict.get('name', 'Wallet'),
preferences=wallet_dict.get('preferences', {}),
)
for account_dict in json_dict.get('accounts', []):
for account_dict in wallet_dict.get('accounts', []):
await wallet.accounts.add_from_dict(account_dict)
return wallet
def to_dict(self, encrypt_password: str = None):
def to_dict(self, encrypt_password: str = None) -> dict:
return {
'version': WalletStorage.VERSION,
'version': self.VERSION,
'ledger': self.ledger.get_id(),
'name': self.name,
'preferences': self.preferences.data,
'accounts': [a.to_dict(encrypt_password) for a in self.accounts]
}
async def save(self):
@classmethod
async def from_serialized(cls, wallet_id: str, json_data: str, db: Database) -> 'Wallet':
return await cls.from_dict(wallet_id, json.loads(json_data), db)
def to_serialized(self) -> str:
wallet_dict = None
if self.preferences.get(ENCRYPT_ON_DISK, False):
if self.encryption_password is not None:
return await self.storage.write(self.to_dict(encrypt_password=self.encryption_password))
wallet_dict = self.to_dict(encrypt_password=self.encryption_password)
elif not self.is_locked:
log.warning(
"Disk encryption requested but no password available for encryption. "
"Saving wallet in an unencrypted state."
)
return await self.storage.write(self.to_dict())
if wallet_dict is None:
wallet_dict = self.to_dict()
return json.dumps(wallet_dict, indent=4, sort_keys=True)
@property
def hash(self) -> bytes:
@ -157,8 +150,9 @@ class Wallet:
local_match.merge(account_dict)
else:
added_accounts.append(
self.accounts.add_from_dict(account_dict)
await self.accounts.add_from_dict(account_dict, notify=False)
)
await self.notify_change('wallet.merge')
return added_accounts
@property
@ -190,7 +184,7 @@ class Wallet:
async def decrypt(self):
assert not self.is_locked, "Cannot decrypt a locked wallet, unlock first."
self.preferences[ENCRYPT_ON_DISK] = False
await self.save()
await self.notify_change(ENCRYPT_ON_DISK, False)
return True
async def encrypt(self, password):
@ -198,7 +192,7 @@ class Wallet:
assert password, "Cannot encrypt with blank password."
self.encryption_password = password
self.preferences[ENCRYPT_ON_DISK] = True
await self.save()
await self.notify_change(ENCRYPT_ON_DISK, True)
return True
@property
@ -237,7 +231,7 @@ class Wallet:
if await account.save_max_gap():
gap_changed = True
if gap_changed:
await self.save()
await self.notify_change('address-max-gap')
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
estimators = []
@ -379,30 +373,6 @@ class Wallet:
return tx
async def _report_state(self):
try:
for account in self.accounts:
balance = dewies_to_lbc(await account.get_balance(include_claims=True))
_, channel_count = await account.get_channels(limit=1)
claim_count = await account.get_claim_count()
if isinstance(account.receiving, SingleKey):
log.info("Loaded single key account %s with %s LBC. "
"%d channels, %d certificates and %d claims",
account.id, balance, channel_count, len(account.channel_keys), claim_count)
else:
total_receiving = len(await account.receiving.get_addresses())
total_change = len(await account.change.get_addresses())
log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), "
"%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ",
account.id, balance, total_receiving, account.receiving.gap, total_change,
account.change.gap, channel_count, len(account.channel_keys), claim_count)
except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception(
'Failed to display wallet state, please file issue '
'for this bug along with the traceback you see below:')
async def verify_duplicate(self, name: str, allow_duplicate: bool):
if not allow_duplicate:
claims = await self.claims.list(claim_name=name)
@ -438,21 +408,22 @@ class AccountListManager:
return account
async def generate(self, name: str = None, language: str = 'en', address_generator: dict = None) -> Account:
account = await Account.generate(
self.wallet.ledger, self.wallet.db, name, language, address_generator
)
account = await Account.generate(self.wallet.db, name, language, address_generator)
self._accounts.append(account)
await self.wallet.notify_change('account.added')
return account
async def add_from_dict(self, account_dict: dict) -> Account:
account = await Account.from_dict(self.wallet.ledger, self.wallet.db, account_dict)
async def add_from_dict(self, account_dict: dict, notify=True) -> Account:
account = await Account.from_dict(self.wallet.db, account_dict)
self._accounts.append(account)
if notify:
await self.wallet.notify_change('account.added')
return account
async def remove(self, account_id: str) -> Account:
account = self[account_id]
self._accounts.remove(account)
await self.wallet.save()
await self.wallet.notify_change('account.removed')
return account
def get_or_none(self, account_id: str) -> Optional[Account]:
@ -608,7 +579,7 @@ class ChannelListManager(ClaimListManager):
if save_key:
holding_account.add_channel_private_key(txo.private_key)
await self.wallet.save()
await self.wallet.notify_change('channel.added')
return tx
@ -652,7 +623,7 @@ class ChannelListManager(ClaimListManager):
if any((new_signing_key, moving_accounts)) and save_key:
holding_account.add_channel_private_key(txo.private_key)
await self.wallet.save()
await self.wallet.notify_change('channel.added')
return tx

View file

@ -1,2 +0,0 @@
from unittest import SkipTest
raise SkipTest("WIP")

View file

@ -13,7 +13,7 @@ class AccountTestCase(AsyncioTestCase):
self.addCleanup(self.db.close)
async def update_addressed_used(self, address, used):
await self.db.execute(
await self.db.execute_sql_object(
tables.PubkeyAddress.update()
.where(tables.PubkeyAddress.c.address == address)
.values(used_times=used)
@ -23,7 +23,7 @@ class AccountTestCase(AsyncioTestCase):
class TestHierarchicalDeterministicAccount(AccountTestCase):
async def test_generate_account(self):
account = await Account.generate(self.ledger, self.db)
account = await Account.generate(self.db)
self.assertEqual(account.ledger, self.ledger)
self.assertEqual(account.db, self.db)
self.assertEqual(account.name, f'Account #{account.public_key.address}')
@ -36,18 +36,19 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
self.assertIsInstance(account.receiving, HierarchicalDeterministic)
self.assertIsInstance(account.change, HierarchicalDeterministic)
async def test_ensure_address_gap(self):
account = await Account.generate(self.ledger, self.db, 'lbryum')
self.assertEqual(len(await account.receiving.get_addresses()), 0)
self.assertEqual(len(await account.change.get_addresses()), 0)
await account.ensure_address_gap()
self.assertEqual(len(await account.receiving.get_addresses()), 20)
self.assertEqual(len(await account.change.get_addresses()), 6)
async def test_ensure_address_gap(self):
account = await Account.generate(self.db)
async with account.receiving.address_generator_lock:
await account.receiving._generate_keys(4, 7)
await account.receiving._generate_keys(0, 3)
await account.receiving._generate_keys(8, 11)
records = await account.receiving.get_address_records()
self.assertListEqual(
[r['pubkey'].n for r in records],
@ -79,14 +80,14 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
self.assertEqual(len(new_keys), 20)
async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
account = Account.generate(self.ledger, self.db, 'lbryum')
account = await Account.generate(self.db)
async with account.receiving.address_generator_lock:
await account.receiving._generate_keys(0, 200)
records = await account.receiving.get_address_records()
self.assertEqual(len(records), 201)
async def test_get_or_create_usable_address(self):
account = Account.generate(self.ledger, self.db, 'lbryum')
account = await Account.generate(self.db)
keys = await account.receiving.get_addresses()
self.assertEqual(len(keys), 0)
@ -98,13 +99,11 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
self.assertEqual(len(keys), 20)
async def test_generate_account_from_seed(self):
account = await Account.from_dict(
self.ledger, self.db, {
account = await Account.from_dict(self.db, {
"seed":
"carbon smart garage balance margin twelve chest sword toas"
"t envelope bottom stomach absent"
}
)
})
self.assertEqual(
account.private_key.extended_key_string(),
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7DRNLEoB8'
@ -126,6 +125,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
"h absent",
'encrypted': False,
'lang': 'en',
'private_key':
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7DRNLEoB8'
'HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
@ -140,7 +140,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
}
}
account = Account.from_dict(self.ledger, self.db, account_data)
account = await Account.from_dict(self.db, account_data)
await account.ensure_address_gap()
@ -150,7 +150,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
self.assertEqual(len(addresses), 10)
self.assertDictEqual(account_data, account.to_dict())
def test_merge_diff(self):
async def test_merge_diff(self):
account_data = {
'name': 'My Account',
'modified_on': 123.456,
@ -158,6 +158,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
"h absent",
'encrypted': False,
'lang': 'en',
'private_key':
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
@ -170,7 +171,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
'change': {'gap': 5, 'maximum_uses_per_address': 2}
}
}
account = Account.from_dict(self.ledger, self.db, account_data)
account = await Account.from_dict(self.db, account_data)
self.assertEqual(account.name, 'My Account')
self.assertEqual(account.modified_on, 123.456)
@ -203,15 +204,15 @@ class TestSingleKeyAccount(AccountTestCase):
async def asyncSetUp(self):
await super().asyncSetUp()
self.account = Account.generate(
self.ledger, self.db, "torba", {'name': 'single-address'}
self.account = await Account.generate(
self.db, address_generator={"name": "single-address"}
)
async def test_generate_account(self):
account = self.account
self.assertEqual(account.ledger, self.ledger)
self.assertIsNotNone(account.seed)
self.assertIsNotNone(account.phrase)
self.assertEqual(account.public_key.ledger, self.ledger)
self.assertEqual(account.private_key.public_key, account.public_key)
@ -246,7 +247,7 @@ class TestSingleKeyAccount(AccountTestCase):
self.assertEqual(new_keys[0], account.public_key.address)
records = await account.receiving.get_address_records()
pubkey = records[0].pop('pubkey')
self.assertListEqual(records, [{
self.assertEqual(records.rows, [{
'chain': 0,
'account': account.public_key.address,
'address': account.public_key.address,
@ -294,6 +295,7 @@ class TestSingleKeyAccount(AccountTestCase):
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
"h absent",
'encrypted': False,
'lang': 'en',
'private_key': 'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7'
'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
'public_key': 'xpub661MyMwAqRbcGWtPvbWh9sc2BCfw2cTeVDYF23o3N1t6UZ5wv3EM'
@ -302,7 +304,7 @@ class TestSingleKeyAccount(AccountTestCase):
'certificates': {}
}
account = Account.from_dict(self.ledger, self.db, account_data)
account = await Account.from_dict(self.db, account_data)
await account.ensure_address_gap()
@ -351,9 +353,9 @@ class AccountEncryptionTests(AccountTestCase):
}
async def test_encrypt_wallet(self):
account = await Account.from_dict(self.ledger, self.db, self.unencrypted_account)
account = await Account.from_dict(self.db, self.unencrypted_account)
account.init_vectors = {
'seed': self.init_vector,
'phrase': self.init_vector,
'private_key': self.init_vector
}
@ -361,7 +363,7 @@ class AccountEncryptionTests(AccountTestCase):
self.assertIsNotNone(account.private_key)
account.encrypt(self.password)
self.assertTrue(account.encrypted)
self.assertEqual(account.seed, self.encrypted_account['seed'])
self.assertEqual(account.phrase, self.encrypted_account['seed'])
self.assertEqual(account.private_key_string, self.encrypted_account['private_key'])
self.assertIsNone(account.private_key)
@ -370,9 +372,9 @@ class AccountEncryptionTests(AccountTestCase):
account.decrypt(self.password)
self.assertEqual(account.init_vectors['private_key'], self.init_vector)
self.assertEqual(account.init_vectors['seed'], self.init_vector)
self.assertEqual(account.init_vectors['phrase'], self.init_vector)
self.assertEqual(account.seed, self.unencrypted_account['seed'])
self.assertEqual(account.phrase, self.unencrypted_account['seed'])
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed'])
@ -381,16 +383,16 @@ class AccountEncryptionTests(AccountTestCase):
self.assertFalse(account.encrypted)
async def test_decrypt_wallet(self):
account = await Account.from_dict(self.ledger, self.db, self.encrypted_account)
account = await Account.from_dict(self.db, self.encrypted_account)
self.assertTrue(account.encrypted)
account.decrypt(self.password)
self.assertEqual(account.init_vectors['private_key'], self.init_vector)
self.assertEqual(account.init_vectors['seed'], self.init_vector)
self.assertEqual(account.init_vectors['phrase'], self.init_vector)
self.assertFalse(account.encrypted)
self.assertEqual(account.seed, self.unencrypted_account['seed'])
self.assertEqual(account.phrase, self.unencrypted_account['seed'])
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed'])
@ -402,7 +404,7 @@ class AccountEncryptionTests(AccountTestCase):
account_data = self.unencrypted_account.copy()
del account_data['seed']
del account_data['private_key']
account = await Account.from_dict(self.ledger, self.db, account_data)
account = await Account.from_dict(self.db, account_data)
encrypted = account.to_dict('password')
self.assertFalse(encrypted['seed'])
self.assertFalse(encrypted['private_key'])

View file

@ -1,566 +0,0 @@
import sys
import os
import unittest
import sqlite3
import tempfile
import asyncio
from concurrent.futures.thread import ThreadPoolExecutor
from sqlalchemy import Column, Text
from lbry.wallet import (
Wallet, Account, Ledger, Headers, Transaction, Input
)
from lbry.db import Table, Version, Database, metadata
from lbry.wallet.constants import COIN
from lbry.wallet.database import query, interpolate, constraints_to_sql, AIOSQLite
from lbry.crypto.hash import sha256
from lbry.testcase import AsyncioTestCase
from tests.unit.wallet.test_transaction import get_output, NULL_HASH
class TestAIOSQLite(AsyncioTestCase):
async def asyncSetUp(self):
self.db = await AIOSQLite.connect(':memory:')
await self.db.executescript("""
pragma foreign_keys=on;
create table parent (id integer primary key, name);
create table child (id integer primary key, parent_id references parent);
""")
await self.db.execute("insert into parent values (1, 'test')")
await self.db.execute("insert into child values (2, 1)")
@staticmethod
def delete_item(transaction):
transaction.execute('delete from parent where id=1')
async def test_foreign_keys_integrity_error(self):
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
with self.assertRaises(sqlite3.IntegrityError):
await self.db.run(self.delete_item)
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
await self.db.executescript("pragma foreign_keys=off;")
await self.db.run(self.delete_item)
self.assertListEqual([], await self.db.execute_fetchall("select * from parent"))
async def test_run_without_foreign_keys(self):
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
await self.db.run_with_foreign_keys_disabled(self.delete_item)
self.assertListEqual([], await self.db.execute_fetchall("select * from parent"))
async def test_integrity_error_when_foreign_keys_disabled_and_skipped(self):
await self.db.executescript("pragma foreign_keys=off;")
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
with self.assertRaises(sqlite3.IntegrityError):
await self.db.run_with_foreign_keys_disabled(self.delete_item)
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
class TestQueryBuilder(unittest.TestCase):
def test_dot(self):
self.assertTupleEqual(
constraints_to_sql({'txo.position': 18}),
('txo.position = :txo_position0', {'txo_position0': 18})
)
self.assertTupleEqual(
constraints_to_sql({'txo.position#6': 18}),
('txo.position = :txo_position6', {'txo_position6': 18})
)
def test_any(self):
self.assertTupleEqual(
constraints_to_sql({
'ages__any': {
'txo.age__gt': 18,
'txo.age__lt': 38
}
}),
('(txo.age > :ages__any0_txo_age__gt0 OR txo.age < :ages__any0_txo_age__lt0)', {
'ages__any0_txo_age__gt0': 18,
'ages__any0_txo_age__lt0': 38
})
)
def test_in(self):
self.assertTupleEqual(
constraints_to_sql({'txo.age__in#2': [18, 38]}),
('txo.age IN (:txo_age__in2_0, :txo_age__in2_1)', {
'txo_age__in2_0': 18,
'txo_age__in2_1': 38
})
)
self.assertTupleEqual(
constraints_to_sql({'txo.name__in': ('abc123', 'def456')}),
('txo.name IN (:txo_name__in0_0, :txo_name__in0_1)', {
'txo_name__in0_0': 'abc123',
'txo_name__in0_1': 'def456'
})
)
self.assertTupleEqual(
constraints_to_sql({'txo.name__in': {'abc123'}}),
('txo.name = :txo_name__in0', {
'txo_name__in0': 'abc123',
})
)
self.assertTupleEqual(
constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}),
('txo.age IN (SELECT age from ages_table)', {})
)
def test_not_in(self):
self.assertTupleEqual(
constraints_to_sql({'txo.age__not_in': [18, 38]}),
('txo.age NOT IN (:txo_age__not_in0_0, :txo_age__not_in0_1)', {
'txo_age__not_in0_0': 18,
'txo_age__not_in0_1': 38
})
)
self.assertTupleEqual(
constraints_to_sql({'txo.name__not_in': ('abc123', 'def456')}),
('txo.name NOT IN (:txo_name__not_in0_0, :txo_name__not_in0_1)', {
'txo_name__not_in0_0': 'abc123',
'txo_name__not_in0_1': 'def456'
})
)
self.assertTupleEqual(
constraints_to_sql({'txo.name__not_in': ('abc123',)}),
('txo.name != :txo_name__not_in0', {
'txo_name__not_in0': 'abc123',
})
)
self.assertTupleEqual(
constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}),
('txo.age NOT IN (SELECT age from ages_table)', {})
)
def test_in_invalid(self):
with self.assertRaisesRegex(ValueError, 'list, set or string'):
constraints_to_sql({'ages__in': 9})
def test_query(self):
self.assertTupleEqual(
query("select * from foo"),
("select * from foo", {})
)
self.assertTupleEqual(
query(
"select * from foo",
a__not='b', b__in='select * from blah where c=:$c',
d__any={'one__like': 'o', 'two': 2}, limit=10, order_by='b', **{'$c': 3}),
(
"select * from foo WHERE a != :a__not0 AND "
"b IN (select * from blah where c=:$c) AND "
"(one LIKE :d__any0_one__like0 OR two = :d__any0_two0) ORDER BY b LIMIT 10",
{'a__not0': 'b', 'd__any0_one__like0': 'o', 'd__any0_two0': 2, '$c': 3}
)
)
def test_query_order_by(self):
self.assertTupleEqual(
query("select * from foo", order_by='foo'),
("select * from foo ORDER BY foo", {})
)
self.assertTupleEqual(
query("select * from foo", order_by=['foo', 'bar']),
("select * from foo ORDER BY foo, bar", {})
)
with self.assertRaisesRegex(ValueError, 'order_by must be string or list'):
query("select * from foo", order_by={'foo': 'bar'})
def test_query_limit_offset(self):
self.assertTupleEqual(
query("select * from foo", limit=10),
("select * from foo LIMIT 10", {})
)
self.assertTupleEqual(
query("select * from foo", offset=10),
("select * from foo OFFSET 10", {})
)
self.assertTupleEqual(
query("select * from foo", limit=20, offset=10),
("select * from foo LIMIT 20 OFFSET 10", {})
)
def test_query_interpolation(self):
self.maxDiff = None
# tests that interpolation replaces longer keys first
self.assertEqual(
interpolate(*query(
"select * from foo",
a__not='b', b__in='select * from blah where c=:$c',
d__any={'one__like': 'o', 'two': 2},
a0=3, a00=1, a00a=2, a00aa=4, # <-- breaks without correct interpolation key order
ahash=sha256(b'hello world'),
limit=10, order_by='b', **{'$c': 3})
),
"select * from foo WHERE a != 'b' AND "
"b IN (select * from blah where c=3) AND "
"(one LIKE 'o' OR two = 2) AND "
"a0 = 3 AND a00 = 1 AND a00a = 2 AND a00aa = 4 "
"AND ahash = X'b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9' "
"ORDER BY b LIMIT 10"
)
class TestQueries(AsyncioTestCase):
async def asyncSetUp(self):
self.ledger = Ledger({
'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:')
})
self.wallet = Wallet()
await self.ledger.db.open()
async def asyncTearDown(self):
await self.ledger.db.close()
async def create_account(self, wallet=None):
account = Account.generate(self.ledger, wallet or self.wallet)
await account.ensure_address_gap()
return account
async def create_tx_from_nothing(self, my_account, height):
to_address = await my_account.receiving.get_or_create_usable_address()
to_hash = Ledger.address_to_hash160(to_address)
tx = Transaction(height=height, is_verified=True) \
.add_inputs([self.txi(self.txo(1, sha256(str(height).encode())))]) \
.add_outputs([self.txo(1, to_hash)])
await self.ledger.db.insert_transaction(tx)
await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
return tx
async def create_tx_from_txo(self, txo, to_account, height):
from_hash = txo.script.values['pubkey_hash']
from_address = self.ledger.pubkey_hash_to_address(from_hash)
to_address = await to_account.receiving.get_or_create_usable_address()
to_hash = Ledger.address_to_hash160(to_address)
tx = Transaction(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)])
await self.ledger.db.insert_transaction(tx)
await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
return tx
async def create_tx_to_nowhere(self, txo, height):
from_hash = txo.script.values['pubkey_hash']
from_address = self.ledger.pubkey_hash_to_address(from_hash)
to_hash = NULL_HASH
tx = Transaction(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)])
await self.ledger.db.insert_transaction(tx)
await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
return tx
def txo(self, amount, address):
return get_output(int(amount*COIN), address)
def txi(self, txo):
return Input.spend(txo)
async def test_large_tx_doesnt_hit_variable_limits(self):
# SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html
# This can be removed when there is a better way. See: https://github.com/lbryio/lbry-sdk/issues/2281
fetchall = self.ledger.db.execute_fetchall
def check_parameters_length(sql, parameters=None):
self.assertLess(len(parameters or []), 999)
return fetchall(sql, parameters)
self.ledger.db.execute_fetchall = check_parameters_length
account = await self.create_account()
tx = await self.create_tx_from_nothing(account, 0)
for height in range(1, 1200):
tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height)
variable_limit = self.ledger.db.MAX_QUERY_VARIABLES
for limit in range(variable_limit - 2, variable_limit + 2):
txs = await self.ledger.get_transactions(
accounts=self.wallet.accounts, limit=limit, order_by='height asc')
self.assertEqual(len(txs), limit)
inputs, outputs, last_tx = set(), set(), txs[0]
for tx in txs[1:]:
self.assertEqual(len(tx.inputs), 1)
self.assertEqual(tx.inputs[0].txo_ref.tx_ref.id, last_tx.id)
self.assertEqual(len(tx.outputs), 1)
last_tx = tx
async def test_queries(self):
wallet1 = Wallet()
account1 = await self.create_account(wallet1)
self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account1]))
wallet2 = Wallet()
account2 = await self.create_account(wallet2)
account3 = await self.create_account(wallet2)
self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2, account3]))
self.assertEqual(0, await self.ledger.db.get_utxo_count())
self.assertListEqual([], await self.ledger.db.get_utxos())
self.assertEqual(0, await self.ledger.db.get_txo_count())
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3]))
tx1 = await self.create_tx_from_nothing(account1, 1)
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3]))
tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
tx2b = await self.create_tx_from_nothing(account3, 2)
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account3]))
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account3]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account3]))
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(10**8+10**8, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3]))
tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet2))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3]))
txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
self.assertListEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
self.assertListEqual([3, 2, 1], [tx.height for tx in txs])
txs = await self.ledger.db.get_transactions(wallet=wallet1, accounts=wallet1.accounts, include_is_my_output=True)
self.assertListEqual([tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_input, True)
self.assertEqual(txs[0].outputs[0].is_my_output, False)
self.assertEqual(txs[1].inputs[0].is_my_input, False)
self.assertEqual(txs[1].outputs[0].is_my_output, True)
txs = await self.ledger.db.get_transactions(wallet=wallet2, accounts=[account2], include_is_my_output=True)
self.assertListEqual([tx3.id, tx2.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_input, True)
self.assertEqual(txs[0].outputs[0].is_my_output, False)
self.assertEqual(txs[1].inputs[0].is_my_input, False)
self.assertEqual(txs[1].outputs[0].is_my_output, True)
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2]))
tx = await self.ledger.db.get_transaction(tx_hash=tx2.hash)
self.assertEqual(tx.id, tx2.id)
self.assertIsNone(tx.inputs[0].is_my_input)
self.assertIsNone(tx.outputs[0].is_my_output)
tx = await self.ledger.db.get_transaction(wallet=wallet1, tx_hash=tx2.hash, include_is_my_output=True)
self.assertTrue(tx.inputs[0].is_my_input)
self.assertFalse(tx.outputs[0].is_my_output)
tx = await self.ledger.db.get_transaction(wallet=wallet2, tx_hash=tx2.hash, include_is_my_output=True)
self.assertFalse(tx.inputs[0].is_my_input)
self.assertTrue(tx.outputs[0].is_my_output)
# height 0 sorted to the top with the rest in descending order
tx4 = await self.create_tx_from_nothing(account1, 0)
txos = await self.ledger.db.get_txos()
self.assertListEqual([0, 3, 2, 2, 1], [txo.tx_ref.height for txo in txos])
self.assertListEqual([tx4.id, tx3.id, tx2.id, tx2b.id, tx1.id], [txo.tx_ref.id for txo in txos])
txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
self.assertListEqual([0, 3, 2, 1], [tx.height for tx in txs])
self.assertListEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
async def test_empty_history(self):
self.assertEqual((None, []), await self.ledger.get_local_status_and_history(''))
class TestUpgrade(AsyncioTestCase):
def setUp(self) -> None:
self.path = tempfile.mktemp()
def tearDown(self) -> None:
os.remove(self.path)
def get_version(self):
with sqlite3.connect(self.path) as conn:
versions = conn.execute('select version from version').fetchall()
assert len(versions) == 1
return versions[0][0]
def get_tables(self):
with sqlite3.connect(self.path) as conn:
sql = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
return [col[0] for col in conn.execute(sql).fetchall()]
def add_address(self, address):
with sqlite3.connect(self.path) as conn:
conn.execute("""
INSERT INTO account_address (address, account, chain, n, pubkey, chain_code, depth)
VALUES (?, 'account1', 0, 0, 'pubkey', 'chain_code', 0)
""", (address,))
def get_addresses(self):
with sqlite3.connect(self.path) as conn:
sql = "SELECT address FROM account_address ORDER BY address;"
return [col[0] for col in conn.execute(sql).fetchall()]
async def test_reset_on_version_change(self):
self.ledger = Ledger({
'db': Database('sqlite:///'+self.path),
'headers': Headers(':memory:')
})
# initial open, pre-version enabled db
self.ledger.db.SCHEMA_VERSION = None
self.assertListEqual(self.get_tables(), [])
await self.ledger.db.open()
metadata.drop_all(self.ledger.db.engine, [Version]) # simulate pre-version table db
self.assertEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo'])
self.assertListEqual(self.get_addresses(), [])
self.add_address('address1')
await self.ledger.db.close()
# initial open after version enabled
self.ledger.db.SCHEMA_VERSION = '1.0'
await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.0')
self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertListEqual(self.get_addresses(), []) # address1 deleted during version upgrade
self.add_address('address2')
await self.ledger.db.close()
# nothing changes
self.assertEqual(self.get_version(), '1.0')
self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.0')
self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertListEqual(self.get_addresses(), ['address2'])
await self.ledger.db.close()
# upgrade version, database reset
foo = Table('foo', metadata, Column('bar', Text, primary_key=True))
self.addCleanup(metadata.remove, foo)
self.ledger.db.SCHEMA_VERSION = '1.1'
await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.1')
self.assertListEqual(self.get_tables(), ['account_address', 'block', 'foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertListEqual(self.get_addresses(), []) # all tables got reset
await self.ledger.db.close()
class TestSQLiteRace(AsyncioTestCase):
max_misuse_attempts = 40000
def setup_db(self):
self.db = sqlite3.connect(":memory:", isolation_level=None)
self.db.executescript(
"create table test1 (id text primary key not null, val text);\n" +
"create table test2 (id text primary key not null, val text);\n" +
"\n".join(f"insert into test1 values ({v}, NULL);" for v in range(1000))
)
async def asyncSetUp(self):
self.executor = ThreadPoolExecutor(1)
await self.loop.run_in_executor(self.executor, self.setup_db)
async def asyncTearDown(self):
await self.loop.run_in_executor(self.executor, self.db.close)
self.executor.shutdown()
async def test_binding_param_0_error(self):
# test real param 0 binding errors
for supported_type in [str, int, bytes]:
await self.loop.run_in_executor(
self.executor, self.db.executemany, "insert into test2 values (?, NULL)",
[(supported_type(1), ), (supported_type(2), )]
)
await self.loop.run_in_executor(
self.executor, self.db.execute, "delete from test2 where id in (1, 2)"
)
for unsupported_type in [lambda x: (x, ), lambda x: [x], lambda x: {x}]:
try:
await self.loop.run_in_executor(
self.executor, self.db.executemany, "insert into test2 (id, val) values (?, NULL)",
[(unsupported_type(1), ), (unsupported_type(2), )]
)
self.assertTrue(False)
except sqlite3.InterfaceError as err:
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
async def test_unhandled_sqlite_misuse(self):
# test SQLITE_MISUSE being incorrectly raised as a param 0 binding error
attempts = 0
python_version = sys.version.split('\n')[0].rstrip(' ')
try:
while attempts < self.max_misuse_attempts:
f1 = asyncio.wrap_future(
self.loop.run_in_executor(
self.executor, self.db.executemany, "update test1 set val='derp' where id=?",
((str(i),) for i in range(2))
)
)
f2 = asyncio.wrap_future(
self.loop.run_in_executor(
self.executor, self.db.executemany, "update test2 set val='derp' where id=?",
((str(i),) for i in range(2))
)
)
attempts += 1
await asyncio.gather(f1, f2)
print(f"\nsqlite3 {sqlite3.version}/python {python_version} "
f"did not raise SQLITE_MISUSE within {attempts} attempts of the race condition")
self.assertTrue(False, 'this test failing means either the sqlite race conditions '
'have been fixed in cpython or the test max_attempts needs to be increased')
except sqlite3.InterfaceError as err:
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
print(f"\nsqlite3 {sqlite3.version}/python {python_version} raised SQLITE_MISUSE "
f"after {attempts} attempts of the race condition")
@unittest.SkipTest
async def test_fetchall_prevents_sqlite_misuse(self):
# test that calling fetchall sufficiently avoids the race
attempts = 0
def executemany_fetchall(query, params):
self.db.executemany(query, params).fetchall()
while attempts < self.max_misuse_attempts:
f1 = asyncio.wrap_future(
self.loop.run_in_executor(
self.executor, executemany_fetchall, "update test1 set val='derp' where id=?",
((str(i),) for i in range(2))
)
)
f2 = asyncio.wrap_future(
self.loop.run_in_executor(
self.executor, executemany_fetchall, "update test2 set val='derp' where id=?",
((str(i),) for i in range(2))
)
)
attempts += 1
await asyncio.gather(f1, f2)

View file

@ -1,242 +0,0 @@
import os
from binascii import hexlify, unhexlify
from lbry.testcase import AsyncioTestCase
from lbry.wallet import Wallet, Account, Transaction, Output, Input, Ledger, Headers
from lbry.db import Database
from tests.unit.wallet.test_transaction import get_transaction, get_output
from tests.unit.wallet.test_headers import HEADERS, block_bytes
class MockNetwork:
def __init__(self, history, transaction):
self.history = history
self.transaction = transaction
self.address = None
self.get_history_called = []
self.get_transaction_called = []
self.is_connected = False
def retriable_call(self, function, *args, **kwargs):
return function(*args, **kwargs)
async def get_history(self, address):
self.get_history_called.append(address)
self.address = address
return self.history
async def get_merkle(self, txid, height):
return {'merkle': ['abcd01'], 'pos': 1}
async def get_transaction(self, tx_hash, _=None):
self.get_transaction_called.append(tx_hash)
return self.transaction[tx_hash]
async def get_transaction_and_merkle(self, tx_hash, known_height=None):
tx = await self.get_transaction(tx_hash)
merkle = {}
if known_height:
merkle = await self.get_merkle(tx_hash, known_height)
return tx, merkle
class LedgerTestCase(AsyncioTestCase):
async def asyncSetUp(self):
self.ledger = Ledger({
'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:')
})
self.account = Account.generate(self.ledger, Wallet(), "lbryum")
await self.ledger.db.open()
async def asyncTearDown(self):
await self.ledger.db.close()
def make_header(self, **kwargs):
header = {
'bits': 486604799,
'block_height': 0,
'merkle_root': b'4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b',
'nonce': 2083236893,
'prev_block_hash': b'0000000000000000000000000000000000000000000000000000000000000000',
'timestamp': 1231006505,
'claim_trie_root': b'4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b',
'version': 1
}
header.update(kwargs)
header['merkle_root'] = header['merkle_root'].ljust(64, b'a')
header['prev_block_hash'] = header['prev_block_hash'].ljust(64, b'0')
return self.ledger.headers.serialize(header)
def add_header(self, **kwargs):
serialized = self.make_header(**kwargs)
self.ledger.headers.io.seek(0, os.SEEK_END)
self.ledger.headers.io.write(serialized)
self.ledger.headers._size = self.ledger.headers.io.seek(0, os.SEEK_END) // self.ledger.headers.header_size
class TestSynchronization(LedgerTestCase):
async def test_update_history(self):
txid1 = '252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792'
txid2 = 'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9'
txid3 = 'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0'
txid4 = '047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828'
txhash1 = unhexlify(txid1)[::-1]
txhash2 = unhexlify(txid2)[::-1]
txhash3 = unhexlify(txid3)[::-1]
txhash4 = unhexlify(txid4)[::-1]
account = Account.generate(self.ledger, Wallet(), "torba")
address = await account.receiving.get_or_create_usable_address()
address_details = await self.ledger.db.get_address(address=address)
self.assertIsNone(address_details['history'])
self.add_header(block_height=0, merkle_root=b'abcd04')
self.add_header(block_height=1, merkle_root=b'abcd04')
self.add_header(block_height=2, merkle_root=b'abcd04')
self.add_header(block_height=3, merkle_root=b'abcd04')
self.ledger.network = MockNetwork([
{'tx_hash': txid1, 'height': 0},
{'tx_hash': txid2, 'height': 1},
{'tx_hash': txid3, 'height': 2},
], {
txhash1: hexlify(get_transaction(get_output(1)).raw),
txhash2: hexlify(get_transaction(get_output(2)).raw),
txhash3: hexlify(get_transaction(get_output(3)).raw),
})
await self.ledger.update_history(address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, [txhash1, txhash2, txhash3])
address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(
address_details['history'],
f'{txid1}:0:'
f'{txid2}:1:'
f'{txid3}:2:'
)
self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = []
for cache_item in self.ledger._tx_cache.values():
cache_item.tx.is_verified = True
await self.ledger.update_history(address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, [])
self.ledger.network.history.append({'tx_hash': txid4, 'height': 3})
self.ledger.network.transaction[txhash4] = hexlify(get_transaction(get_output(4)).raw)
self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = []
await self.ledger.update_history(address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, [txhash4])
address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(
address_details['history'],
f'{txid1}:0:'
f'{txid2}:1:'
f'{txid3}:2:'
f'{txid4}:3:'
)
class MocHeaderNetwork(MockNetwork):
def __init__(self, responses):
super().__init__(None, None)
self.responses = responses
async def get_headers(self, height, blocks):
return self.responses[height]
class BlockchainReorganizationTests(LedgerTestCase):
async def test_1_block_reorganization(self):
self.ledger.network = MocHeaderNetwork({
10: {'height': 10, 'count': 5, 'hex': hexlify(
HEADERS[block_bytes(10):block_bytes(15)]
)},
15: {'height': 15, 'count': 0, 'hex': b''}
})
headers = self.ledger.headers
await headers.connect(0, HEADERS[:block_bytes(10)])
self.add_header(block_height=len(headers))
self.assertEqual(10, headers.height)
await self.ledger.receive_header([{
'height': 11, 'hex': hexlify(self.make_header(block_height=11))
}])
async def test_3_block_reorganization(self):
self.ledger.network = MocHeaderNetwork({
10: {'height': 10, 'count': 5, 'hex': hexlify(
HEADERS[block_bytes(10):block_bytes(15)]
)},
11: {'height': 11, 'count': 1, 'hex': hexlify(self.make_header(block_height=11))},
12: {'height': 12, 'count': 1, 'hex': hexlify(self.make_header(block_height=12))},
15: {'height': 15, 'count': 0, 'hex': b''}
})
headers = self.ledger.headers
await headers.connect(0, HEADERS[:block_bytes(10)])
self.add_header(block_height=len(headers))
self.add_header(block_height=len(headers))
self.add_header(block_height=len(headers))
self.assertEqual(headers.height, 12)
await self.ledger.receive_header([{
'height': 13, 'hex': hexlify(self.make_header(block_height=13))
}])
class BasicAccountingTests(LedgerTestCase):
async def test_empty_state(self):
self.assertEqual(await self.account.get_balance(), 0)
async def test_balance(self):
address = await self.account.receiving.get_or_create_usable_address()
hash160 = self.ledger.address_to_hash160(address)
tx = Transaction(is_verified=True)\
.add_outputs([Output.pay_pubkey_hash(100, hash160)])
await self.ledger.db.insert_transaction(tx)
await self.ledger.db.save_transaction_io(
tx, address, hash160, f'{tx.id}:1:'
)
self.assertEqual(await self.account.get_balance(), 100)
tx = Transaction(is_verified=True)\
.add_outputs([Output.pay_claim_name_pubkey_hash(100, 'foo', b'', hash160)])
await self.ledger.db.insert_transaction(tx)
await self.ledger.db.save_transaction_io(
tx, address, hash160, f'{tx.id}:1:'
)
self.assertEqual(await self.account.get_balance(), 100) # claim names don't count towards balance
self.assertEqual(await self.account.get_balance(include_claims=True), 200)
async def test_get_utxo(self):
address = yield self.account.receiving.get_or_create_usable_address()
hash160 = self.ledger.address_to_hash160(address)
tx = Transaction(is_verified=True)\
.add_outputs([Output.pay_pubkey_hash(100, hash160)])
await self.ledger.db.save_transaction_io(
'insert', tx, address, hash160, f'{tx.id}:1:'
)
utxos = await self.account.get_utxos()
self.assertEqual(len(utxos), 1)
tx = Transaction(is_verified=True)\
.add_inputs([Input.spend(utxos[0])])
await self.ledger.db.save_transaction_io(
'insert', tx, address, hash160, f'{tx.id}:1:'
)
self.assertEqual(await self.account.get_balance(include_claims=True), 0)
utxos = await self.account.get_utxos()
self.assertEqual(len(utxos), 0)

View file

@ -4,69 +4,70 @@ import tempfile
from lbry import Config, Ledger, Database, WalletManager, Wallet, Account
from lbry.testcase import AsyncioTestCase
from lbry.wallet.manager import FileWallet, DatabaseWallet
class TestWalletManager(AsyncioTestCase):
class DBBasedWalletManagerTestCase(AsyncioTestCase):
async def asyncSetUp(self):
self.temp_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, self.temp_dir)
self.ledger = Ledger(Config.with_same_dir(self.temp_dir).set(
db_url="sqlite:///:memory:"
self.ledger = Ledger(Config.with_null_dir().set(
db_url="sqlite:///:memory:",
wallet_storage="database"
))
self.db = Database(self.ledger)
await self.db.open()
self.addCleanup(self.db.close)
async def test_ensure_path_exists(self):
wm = WalletManager(self.ledger, self.db)
self.assertFalse(os.path.exists(wm.path))
await wm.ensure_path_exists()
self.assertTrue(os.path.exists(wm.path))
async def test_load_with_default_wallet_account_progression(self):
wm = WalletManager(self.ledger, self.db)
await wm.ensure_path_exists()
class TestDatabaseWalletManager(DBBasedWalletManagerTestCase):
async def test_initialize_with_default_wallet_account_progression(self):
wm = WalletManager(self.db)
self.assertIsInstance(wm.storage, DatabaseWallet)
storage: DatabaseWallet = wm.storage
await storage.prepare()
# first, no defaults
self.ledger.conf.create_default_wallet = False
self.ledger.conf.create_default_account = False
await wm.load()
await wm.initialize()
self.assertIsNone(wm.default)
# then, yes to default wallet but no to default account
self.ledger.conf.create_default_wallet = True
self.ledger.conf.create_default_account = False
await wm.load()
await wm.initialize()
self.assertIsInstance(wm.default, Wallet)
self.assertTrue(os.path.exists(wm.default.storage.path))
self.assertTrue(await storage.exists(wm.default.id))
self.assertIsNone(wm.default.accounts.default)
# finally, yes to all the things
self.ledger.conf.create_default_wallet = True
self.ledger.conf.create_default_account = True
await wm.load()
await wm.initialize()
self.assertIsInstance(wm.default, Wallet)
self.assertIsInstance(wm.default.accounts.default, Account)
async def test_load_with_create_default_everything_upfront(self):
wm = WalletManager(self.ledger, self.db)
await wm.ensure_path_exists()
wm = WalletManager(self.db)
await wm.storage.prepare()
self.ledger.conf.create_default_wallet = True
self.ledger.conf.create_default_account = True
await wm.load()
await wm.initialize()
self.assertIsInstance(wm.default, Wallet)
self.assertIsInstance(wm.default.accounts.default, Account)
self.assertTrue(os.path.exists(wm.default.storage.path))
self.assertTrue(await wm.storage.exists(wm.default.id))
async def test_load_errors(self):
_wm = WalletManager(self.ledger, self.db)
await _wm.ensure_path_exists()
_wm = WalletManager(self.db)
await _wm.storage.prepare()
await _wm.create('bar', '')
await _wm.create('foo', '')
wm = WalletManager(self.ledger, self.db)
wm = WalletManager(self.db)
self.ledger.conf.wallets = ['bar', 'foo', 'foo']
with self.assertLogs(level='WARN') as cm:
await wm.load()
await wm.initialize()
self.assertEqual(
cm.output, [
'WARNING:lbry.wallet.manager:Ignoring duplicate wallet_id in config: foo',
@ -75,9 +76,9 @@ class TestWalletManager(AsyncioTestCase):
self.assertEqual({'bar', 'foo'}, set(wm.wallets))
async def test_creating_and_accessing_wallets(self):
wm = WalletManager(self.ledger, self.db)
await wm.ensure_path_exists()
await wm.load()
wm = WalletManager(self.db)
await wm.storage.prepare()
await wm.initialize()
default_wallet = wm.default
self.assertEqual(default_wallet, wm['default_wallet'])
self.assertEqual(default_wallet, wm.get_or_default(None))
@ -90,3 +91,114 @@ class TestWalletManager(AsyncioTestCase):
_ = wm['invalid']
with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"):
wm.get_or_default('invalid')
class TestFileBasedWalletManager(AsyncioTestCase):
async def asyncSetUp(self):
self.temp_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, self.temp_dir)
self.ledger = Ledger(Config(
data_dir=self.temp_dir,
db_url="sqlite:///:memory:"
))
self.ledger.conf.set_default_paths()
self.db = Database(self.ledger)
await self.db.open()
self.addCleanup(self.db.close)
async def test_ensure_path_exists(self):
wm = WalletManager(self.db)
self.assertIsInstance(wm.storage, FileWallet)
storage: FileWallet = wm.storage
self.assertFalse(os.path.exists(storage.wallet_dir))
await storage.prepare()
self.assertTrue(os.path.exists(storage.wallet_dir))
async def test_initialize_with_default_wallet_account_progression(self):
wm = WalletManager(self.db)
storage: FileWallet = wm.storage
await storage.prepare()
# first, no defaults
self.ledger.conf.create_default_wallet = False
self.ledger.conf.create_default_account = False
await wm.initialize()
self.assertIsNone(wm.default)
# then, yes to default wallet but no to default account
self.ledger.conf.create_default_wallet = True
self.ledger.conf.create_default_account = False
await wm.initialize()
self.assertIsInstance(wm.default, Wallet)
self.assertTrue(os.path.exists(storage.get_wallet_path(wm.default.id)))
self.assertIsNone(wm.default.accounts.default)
# finally, yes to all the things
self.ledger.conf.create_default_wallet = True
self.ledger.conf.create_default_account = True
await wm.initialize()
self.assertIsInstance(wm.default, Wallet)
self.assertIsInstance(wm.default.accounts.default, Account)
async def test_load_with_create_default_everything_upfront(self):
wm = WalletManager(self.db)
await wm.storage.prepare()
self.ledger.conf.create_default_wallet = True
self.ledger.conf.create_default_account = True
await wm.initialize()
self.assertIsInstance(wm.default, Wallet)
self.assertIsInstance(wm.default.accounts.default, Account)
self.assertTrue(os.path.exists(wm.storage.get_wallet_path(wm.default.id)))
async def test_load_errors(self):
_wm = WalletManager(self.db)
await _wm.storage.prepare()
await _wm.create('bar', '')
await _wm.create('foo', '')
wm = WalletManager(self.db)
self.ledger.conf.wallets = ['bar', 'foo', 'foo']
with self.assertLogs(level='WARN') as cm:
await wm.initialize()
self.assertEqual(
cm.output, [
'WARNING:lbry.wallet.manager:Ignoring duplicate wallet_id in config: foo',
]
)
self.assertEqual({'bar', 'foo'}, set(wm.wallets))
async def test_creating_and_accessing_wallets(self):
wm = WalletManager(self.db)
await wm.storage.prepare()
await wm.initialize()
default_wallet = wm.default
self.assertEqual(default_wallet, wm['default_wallet'])
self.assertEqual(default_wallet, wm.get_or_default(None))
new_wallet = await wm.create('second', 'Second Wallet')
self.assertEqual(default_wallet, wm.default)
self.assertEqual(new_wallet, wm['second'])
self.assertEqual(new_wallet, wm.get_or_default('second'))
self.assertEqual(default_wallet, wm.get_or_default(None))
with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"):
_ = wm['invalid']
with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"):
wm.get_or_default('invalid')
async def test_read_write(self):
manager = WalletManager(self.db)
await manager.storage.prepare()
with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file:
wallet_file.write(b'{"version": 1}')
wallet_file.seek(0)
# create and write wallet to a file
wallet = await manager.load(wallet_file.name)
account = await wallet.accounts.generate()
await manager.storage.save(wallet)
# read wallet from file
wallet = await manager.load(wallet_file.name)
self.assertEqual(account.public_key.address, wallet.accounts.default.public_key.address)

View file

@ -1,78 +0,0 @@
from lbry.testcase import UnitDBTestCase
class TestClientDBSync(UnitDBTestCase):
async def asyncSetUp(self):
await super().asyncSetUp()
await self.add(self.coinbase())
async def test_process_inputs(self):
await self.add(self.tx())
await self.add(self.tx())
txo1, txo2a, txo2b, txo3a, txo3b = self.outputs
self.assertEqual([
(txo1.id, None),
(txo2b.id, None),
], await self.get_txis())
self.assertEqual([
(txo1.id, False),
(txo2a.id, False),
(txo2b.id, False),
(txo3a.id, False),
(txo3b.id, False),
], await self.get_txos())
await self.db.process_all_things_after_sync()
self.assertEqual([
(txo1.id, txo1.get_address(self.ledger)),
(txo2b.id, txo2b.get_address(self.ledger)),
], await self.get_txis())
self.assertEqual([
(txo1.id, True),
(txo2a.id, False),
(txo2b.id, True),
(txo3a.id, False),
(txo3b.id, False),
], await self.get_txos())
async def test_process_claims(self):
claim1 = await self.add(self.create_claim())
await self.db.process_all_things_after_sync()
self.assertEqual([claim1.claim_id], await self.get_claims())
claim2 = await self.add(self.create_claim())
self.assertEqual([claim1.claim_id], await self.get_claims())
await self.db.process_all_things_after_sync()
self.assertEqual([claim1.claim_id, claim2.claim_id], await self.get_claims())
claim3 = await self.add(self.create_claim())
claim4 = await self.add(self.create_claim())
await self.db.process_all_things_after_sync()
self.assertEqual([
claim1.claim_id,
claim2.claim_id,
claim3.claim_id,
claim4.claim_id,
], await self.get_claims())
await self.add(self.abandon_claim(claim4))
await self.db.process_all_things_after_sync()
self.assertEqual([
claim1.claim_id, claim2.claim_id, claim3.claim_id
], await self.get_claims())
# create and abandon in same block
claim5 = await self.add(self.create_claim())
await self.add(self.abandon_claim(claim5))
await self.db.process_all_things_after_sync()
self.assertEqual([
claim1.claim_id, claim2.claim_id, claim3.claim_id
], await self.get_claims())
# create and abandon in different blocks but with bulk sync
claim6 = await self.add(self.create_claim())
await self.add(self.abandon_claim(claim6))
await self.db.process_all_things_after_sync()
self.assertEqual([
claim1.claim_id, claim2.claim_id, claim3.claim_id
], await self.get_claims())

View file

@ -1,18 +1,17 @@
import tempfile
from itertools import cycle
from binascii import hexlify
from unittest import TestCase, mock
from lbry import Config, Database, Ledger, Account, Wallet, WalletManager
from lbry.testcase import AsyncioTestCase
from lbry.wallet.storage import WalletStorage
from lbry import Config, Database, Ledger, Account, Wallet, Transaction, Output, Input
from lbry.testcase import AsyncioTestCase, get_output, COIN, CENT
from lbry.wallet.preferences import TimestampedPreferences
class WalletTestCase(AsyncioTestCase):
async def asyncSetUp(self):
self.ledger = Ledger(Config.with_null_dir())
self.db = Database(self.ledger, "sqlite:///:memory:")
self.ledger = Ledger(Config.with_null_dir().set(db_url='sqlite:///:memory:'))
self.db = Database(self.ledger)
await self.db.open()
self.addCleanup(self.db.close)
@ -20,8 +19,8 @@ class WalletTestCase(AsyncioTestCase):
class WalletAccountTest(WalletTestCase):
async def test_private_key_for_hierarchical_account(self):
wallet = Wallet(self.ledger, self.db)
account = wallet.add_account({
wallet = Wallet("wallet1", self.db)
account = await wallet.accounts.add_from_dict({
"seed":
"carbon smart garage balance margin twelve chest sword toas"
"t envelope bottom stomach absent"
@ -40,8 +39,8 @@ class WalletAccountTest(WalletTestCase):
)
async def test_private_key_for_single_address_account(self):
wallet = Wallet(self.ledger, self.db)
account = wallet.add_account({
wallet = Wallet("wallet1", self.db)
account = await wallet.accounts.add_from_dict({
"seed":
"carbon smart garage balance margin twelve chest sword toas"
"t envelope bottom stomach absent",
@ -59,9 +58,9 @@ class WalletAccountTest(WalletTestCase):
)
async def test_save_max_gap(self):
wallet = Wallet(self.ledger, self.db)
account = wallet.generate_account(
'lbryum', {
wallet = Wallet("wallet1", self.db)
account = await wallet.accounts.generate(
address_generator={
'name': 'deterministic-chain',
'receiving': {'gap': 3, 'maximum_uses_per_address': 2},
'change': {'gap': 4, 'maximum_uses_per_address': 2}
@ -73,24 +72,25 @@ class WalletAccountTest(WalletTestCase):
self.assertEqual(account.receiving.gap, 20)
self.assertEqual(account.change.gap, 6)
# doesn't fail for single-address account
wallet.generate_account('lbryum', {'name': 'single-address'})
await wallet.accounts.generate(address_generator={'name': 'single-address'})
await wallet.save_max_gap()
class TestWalletCreation(WalletTestCase):
def test_create_wallet_and_accounts(self):
wallet = Wallet(self.ledger, self.db)
self.assertEqual(wallet.name, 'Wallet')
self.assertListEqual(wallet.accounts, [])
async def test_create_wallet_and_accounts(self):
wallet = Wallet("wallet1", self.db)
self.assertEqual(wallet.id, "wallet1")
self.assertEqual(wallet.name, "")
self.assertEqual(list(wallet.accounts), [])
account1 = wallet.generate_account()
wallet.generate_account()
wallet.generate_account()
self.assertEqual(wallet.default_account, account1)
account1 = await wallet.accounts.generate()
await wallet.accounts.generate()
await wallet.accounts.generate()
self.assertEqual(wallet.accounts.default, account1)
self.assertEqual(len(wallet.accounts), 3)
def test_load_and_save_wallet(self):
async def test_load_and_save_wallet(self):
wallet_dict = {
'version': 1,
'name': 'Main Wallet',
@ -105,6 +105,7 @@ class TestWalletCreation(WalletTestCase):
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
"h absent",
'encrypted': False,
'lang': 'en',
'private_key':
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7'
'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
@ -120,15 +121,14 @@ class TestWalletCreation(WalletTestCase):
]
}
storage = WalletStorage(default=wallet_dict)
wallet = Wallet.from_storage(self.ledger, self.db, storage)
wallet = await Wallet.from_dict('wallet1', wallet_dict, self.db)
self.assertEqual(wallet.name, 'Main Wallet')
self.assertEqual(
hexlify(wallet.hash),
b'3b23aae8cd9b360f4296130b8f7afc5b2437560cdef7237bed245288ce8a5f79'
b'64a32cf8434a59c547abf61b4691a8189ac24272678b52ced2310fbf93eac974'
)
self.assertEqual(len(wallet.accounts), 1)
account = wallet.default_account
account = wallet.accounts.default
self.assertIsInstance(account, Account)
self.maxDiff = None
self.assertDictEqual(wallet_dict, wallet.to_dict())
@ -137,41 +137,23 @@ class TestWalletCreation(WalletTestCase):
decrypted = Wallet.unpack('password', encrypted)
self.assertEqual(decrypted['accounts'][0]['name'], 'An Account')
def test_read_write(self):
manager = WalletManager(self.ledger, self.db)
with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file:
wallet_file.write(b'{"version": 1}')
wallet_file.seek(0)
# create and write wallet to a file
wallet = manager.import_wallet(wallet_file.name)
account = wallet.generate_account()
wallet.save()
# read wallet from file
wallet_storage = WalletStorage(wallet_file.name)
wallet = Wallet.from_storage(self.ledger, self.db, wallet_storage)
self.assertEqual(account.public_key.address, wallet.default_account.public_key.address)
def test_merge(self):
wallet1 = Wallet(self.ledger, self.db)
async def test_merge(self):
wallet1 = Wallet('wallet1', self.db)
wallet1.preferences['one'] = 1
wallet1.preferences['conflict'] = 1
wallet1.generate_account()
wallet2 = Wallet(self.ledger, self.db)
await wallet1.accounts.generate()
wallet2 = Wallet('wallet2', self.db)
wallet2.preferences['two'] = 2
wallet2.preferences['conflict'] = 2 # will be more recent
wallet2.generate_account()
await wallet2.accounts.generate()
self.assertEqual(len(wallet1.accounts), 1)
self.assertEqual(wallet1.preferences, {'one': 1, 'conflict': 1})
added = await wallet1.merge('password', wallet2.pack('password'))
self.assertEqual(added[0].id, wallet2.default_account.id)
self.assertEqual(added[0].id, wallet2.accounts.default.id)
self.assertEqual(len(wallet1.accounts), 2)
self.assertEqual(wallet1.accounts[1].id, wallet2.default_account.id)
self.assertEqual(list(wallet1.accounts)[1].id, wallet2.accounts.default.id)
self.assertEqual(wallet1.preferences, {'one': 1, 'two': 2, 'conflict': 2})
@ -213,3 +195,138 @@ class TestTimestampedPreferences(TestCase):
p1['conflict'] = 1
p1.merge(p2.data)
self.assertEqual(p1, {'one': 1, 'two': 2, 'conflict': 1})
class TestTransactionSigning(WalletTestCase):
async def test_sign(self):
wallet = Wallet('wallet1', self.db)
account = await wallet.accounts.add_from_dict({
"seed":
"carbon smart garage balance margin twelve chest sword toas"
"t envelope bottom stomach absent"
})
await account.ensure_address_gap()
address1, address2 = await account.receiving.get_addresses(limit=2)
pubkey_hash1 = self.ledger.address_to_hash160(address1)
pubkey_hash2 = self.ledger.address_to_hash160(address2)
tx = Transaction() \
.add_inputs([Input.spend(get_output(int(2*COIN), pubkey_hash1))]) \
.add_outputs([Output.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)])
await wallet.sign(tx)
self.assertEqual(
hexlify(tx.inputs[0].script.values['signature']),
b'304402200dafa26ad7cf38c5a971c8a25ce7d85a076235f146126762296b1223c42ae21e022020ef9eeb8'
b'398327891008c5c0be4357683f12cb22346691ff23914f457bf679601'
)
class TransactionIOBalancing(WalletTestCase):
async def asyncSetUp(self):
await super().asyncSetUp()
self.wallet = Wallet('wallet1', self.db)
self.account = await self.wallet.accounts.add_from_dict({
"seed":
"carbon smart garage balance margin twelve chest sword toas"
"t envelope bottom stomach absent"
})
addresses = await self.account.ensure_address_gap()
self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses]
self.hash_cycler = cycle(self.pubkey_hash)
def txo(self, amount, address=None):
return get_output(int(amount*COIN), address or next(self.hash_cycler))
def txi(self, txo):
return Input.spend(txo)
def tx(self, inputs, outputs):
return self.wallet.create_transaction(inputs, outputs, [self.account], self.account)
async def create_utxos(self, amounts):
utxos = [self.txo(amount) for amount in amounts]
self.funding_tx = Transaction(is_verified=True) \
.add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \
.add_outputs(utxos)
await self.db.insert_transaction(b'beef', self.funding_tx)
return utxos
@staticmethod
def inputs(tx):
return [round(i.amount/COIN, 2) for i in tx.inputs]
@staticmethod
def outputs(tx):
return [round(o.amount/COIN, 2) for o in tx.outputs]
async def test_basic_use_cases(self):
self.ledger.fee_per_byte = int(.01*CENT)
# available UTXOs for filling missing inputs
utxos = await self.create_utxos([
1, 1, 3, 5, 10
])
# pay 3 coins (3.02 w/ fees)
tx = await self.tx(
[], # inputs
[self.txo(3)] # outputs
)
# best UTXO match is 5 (as UTXO 3 will be short 0.02 to cover fees)
self.assertListEqual(self.inputs(tx), [5])
# a change of 1.98 is added to reach balance
self.assertListEqual(self.outputs(tx), [3, 1.98])
await self.db.release_outputs(utxos)
# pay 2.98 coins (3.00 w/ fees)
tx = await self.tx(
[], # inputs
[self.txo(2.98)] # outputs
)
# best UTXO match is 3 and no change is needed
self.assertListEqual(self.inputs(tx), [3])
self.assertListEqual(self.outputs(tx), [2.98])
await self.db.release_outputs(utxos)
# supplied input and output, but input is not enough to cover output
tx = await self.tx(
[self.txi(self.txo(10))], # inputs
[self.txo(11)] # outputs
)
# additional input is chosen (UTXO 3)
self.assertListEqual([10, 3], self.inputs(tx))
# change is now needed to consume extra input
self.assertListEqual([11, 1.96], self.outputs(tx))
await self.db.release_outputs(utxos)
# liquidating a UTXO
tx = await self.tx(
[self.txi(self.txo(10))], # inputs
[] # outputs
)
self.assertListEqual([10], self.inputs(tx))
# missing change added to consume the amount
self.assertListEqual([9.98], self.outputs(tx))
await self.db.release_outputs(utxos)
# liquidating at a loss, requires adding extra inputs
tx = await self.tx(
[self.txi(self.txo(0.01))], # inputs
[] # outputs
)
# UTXO 1 is added to cover some of the fee
self.assertListEqual([0.01, 1], self.inputs(tx))
# change is now needed to consume extra input
self.assertListEqual([0.97], self.outputs(tx))