forked from LBRYCommunity/lbry-sdk
working db based wallet and wallet sync progress
This commit is contained in:
parent
4356d23cc1
commit
6ed2fa20ec
28 changed files with 1315 additions and 2096 deletions
|
@ -42,13 +42,15 @@ class BlockchainSync(Sync):
|
||||||
super().__init__(chain.ledger, db)
|
super().__init__(chain.ledger, db)
|
||||||
self.chain = chain
|
self.chain = chain
|
||||||
self.pid = os.getpid()
|
self.pid = os.getpid()
|
||||||
|
self._on_block_controller = EventController()
|
||||||
|
self.on_block = self._on_block_controller.stream
|
||||||
|
self._on_mempool_controller = EventController()
|
||||||
|
self.on_mempool = self._on_mempool_controller.stream
|
||||||
self.on_block_hash_subscription: Optional[BroadcastSubscription] = None
|
self.on_block_hash_subscription: Optional[BroadcastSubscription] = None
|
||||||
self.on_tx_hash_subscription: Optional[BroadcastSubscription] = None
|
self.on_tx_hash_subscription: Optional[BroadcastSubscription] = None
|
||||||
self.advance_loop_task: Optional[asyncio.Task] = None
|
self.advance_loop_task: Optional[asyncio.Task] = None
|
||||||
self.block_hash_event = asyncio.Event()
|
self.block_hash_event = asyncio.Event()
|
||||||
self.tx_hash_event = asyncio.Event()
|
self.tx_hash_event = asyncio.Event()
|
||||||
self._on_mempool_controller = EventController()
|
|
||||||
self.on_mempool = self._on_mempool_controller.stream
|
|
||||||
self.mempool = []
|
self.mempool = []
|
||||||
|
|
||||||
async def wait_for_chain_ready(self):
|
async def wait_for_chain_ready(self):
|
||||||
|
|
|
@ -519,6 +519,7 @@ class Config(CLIConfig):
|
||||||
data_dir = Path("Main directory containing blobs, wallets and blockchain data.", metavar='DIR')
|
data_dir = Path("Main directory containing blobs, wallets and blockchain data.", metavar='DIR')
|
||||||
blob_dir = Path("Directory to store blobs (default: 'data_dir'/blobs).", metavar='DIR')
|
blob_dir = Path("Directory to store blobs (default: 'data_dir'/blobs).", metavar='DIR')
|
||||||
wallet_dir = Path("Directory to store wallets (default: 'data_dir'/wallets).", metavar='DIR')
|
wallet_dir = Path("Directory to store wallets (default: 'data_dir'/wallets).", metavar='DIR')
|
||||||
|
wallet_storage = StringChoice("Wallet storage mode.", ["file", "database"], "file")
|
||||||
wallets = Strings(
|
wallets = Strings(
|
||||||
"Wallet files in 'wallet_dir' to load at startup.", ['default_wallet']
|
"Wallet files in 'wallet_dir' to load at startup.", ['default_wallet']
|
||||||
)
|
)
|
||||||
|
|
|
@ -205,6 +205,9 @@ class Database:
|
||||||
async def execute(self, sql):
|
async def execute(self, sql):
|
||||||
return await self.run(q.execute, sql)
|
return await self.run(q.execute, sql)
|
||||||
|
|
||||||
|
async def execute_sql_object(self, sql):
|
||||||
|
return await self.run(q.execute_sql_object, sql)
|
||||||
|
|
||||||
async def execute_fetchall(self, sql):
|
async def execute_fetchall(self, sql):
|
||||||
return await self.run(q.execute_fetchall, sql)
|
return await self.run(q.execute_fetchall, sql)
|
||||||
|
|
||||||
|
@ -217,12 +220,27 @@ class Database:
|
||||||
async def has_supports(self):
|
async def has_supports(self):
|
||||||
return await self.run(q.has_supports)
|
return await self.run(q.has_supports)
|
||||||
|
|
||||||
|
async def has_wallet(self, wallet_id):
|
||||||
|
return await self.run(q.has_wallet, wallet_id)
|
||||||
|
|
||||||
|
async def get_wallet(self, wallet_id: str):
|
||||||
|
return await self.run(q.get_wallet, wallet_id)
|
||||||
|
|
||||||
|
async def add_wallet(self, wallet_id: str, data: str):
|
||||||
|
return await self.run(q.add_wallet, wallet_id, data)
|
||||||
|
|
||||||
async def get_best_block_height(self) -> int:
|
async def get_best_block_height(self) -> int:
|
||||||
return await self.run(q.get_best_block_height)
|
return await self.run(q.get_best_block_height)
|
||||||
|
|
||||||
async def process_all_things_after_sync(self):
|
async def process_all_things_after_sync(self):
|
||||||
return await self.run(sync.process_all_things_after_sync)
|
return await self.run(sync.process_all_things_after_sync)
|
||||||
|
|
||||||
|
async def get_blocks(self, first, last=None):
|
||||||
|
return await self.run(q.get_blocks, first, last)
|
||||||
|
|
||||||
|
async def get_filters(self, start_height, end_height=None, granularity=0):
|
||||||
|
return await self.run(q.get_filters, start_height, end_height, granularity)
|
||||||
|
|
||||||
async def insert_block(self, block):
|
async def insert_block(self, block):
|
||||||
return await self.run(q.insert_block, block)
|
return await self.run(q.insert_block, block)
|
||||||
|
|
||||||
|
|
|
@ -3,3 +3,4 @@ from .txio import *
|
||||||
from .search import *
|
from .search import *
|
||||||
from .resolve import *
|
from .resolve import *
|
||||||
from .address import *
|
from .address import *
|
||||||
|
from .wallet import *
|
||||||
|
|
|
@ -62,11 +62,11 @@ def add_keys(account, chain, pubkeys):
|
||||||
c = context()
|
c = context()
|
||||||
c.execute(
|
c.execute(
|
||||||
c.insert_or_ignore(PubkeyAddress)
|
c.insert_or_ignore(PubkeyAddress)
|
||||||
.values([{'address': k.address} for k in pubkeys])
|
.values([{'address': k.address} for k in pubkeys])
|
||||||
)
|
)
|
||||||
c.execute(
|
c.execute(
|
||||||
c.insert_or_ignore(AccountAddress)
|
c.insert_or_ignore(AccountAddress)
|
||||||
.values([{
|
.values([{
|
||||||
'account': account.id,
|
'account': account.id,
|
||||||
'address': k.address,
|
'address': k.address,
|
||||||
'chain': chain,
|
'chain': chain,
|
||||||
|
|
|
@ -1,14 +1,24 @@
|
||||||
from sqlalchemy import text
|
from math import log10
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
from sqlalchemy import text, between
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from ..query_context import context
|
from ..query_context import context
|
||||||
from ..tables import SCHEMA_VERSION, metadata, Version, Claim, Support, Block, BlockFilter, TX
|
from ..tables import (
|
||||||
|
SCHEMA_VERSION, metadata, Version,
|
||||||
|
Claim, Support, Block, BlockFilter, BlockGroupFilter, TX, TXFilter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def execute(sql):
|
def execute(sql):
|
||||||
return context().execute(text(sql))
|
return context().execute(text(sql))
|
||||||
|
|
||||||
|
|
||||||
|
def execute_sql_object(sql):
|
||||||
|
return context().execute(sql)
|
||||||
|
|
||||||
|
|
||||||
def execute_fetchall(sql):
|
def execute_fetchall(sql):
|
||||||
return context().fetchall(text(sql))
|
return context().fetchall(text(sql))
|
||||||
|
|
||||||
|
@ -33,6 +43,53 @@ def insert_block(block):
|
||||||
context().get_bulk_loader().add_block(block).flush()
|
context().get_bulk_loader().add_block(block).flush()
|
||||||
|
|
||||||
|
|
||||||
|
def get_blocks(first, last=None):
|
||||||
|
if last is not None:
|
||||||
|
query = (
|
||||||
|
select('*').select_from(Block)
|
||||||
|
.where(between(Block.c.height, first, last))
|
||||||
|
.order_by(Block.c.height)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = select('*').select_from(Block).where(Block.c.height == first)
|
||||||
|
return context().fetchall(query)
|
||||||
|
|
||||||
|
|
||||||
|
def get_filters(start_height, end_height=None, granularity=0):
|
||||||
|
assert granularity >= 0, "filter granularity must be 0 or positive number"
|
||||||
|
if granularity == 0:
|
||||||
|
query = (
|
||||||
|
select('*').select_from(TXFilter)
|
||||||
|
.where(between(TXFilter.c.height, start_height, end_height))
|
||||||
|
.order_by(TXFilter.c.height)
|
||||||
|
)
|
||||||
|
elif granularity == 1:
|
||||||
|
query = (
|
||||||
|
select('*').select_from(BlockFilter)
|
||||||
|
.where(between(BlockFilter.c.height, start_height, end_height))
|
||||||
|
.order_by(BlockFilter.c.height)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = (
|
||||||
|
select('*').select_from(BlockGroupFilter)
|
||||||
|
.where(
|
||||||
|
(BlockGroupFilter.c.height == start_height) &
|
||||||
|
(BlockGroupFilter.c.factor == log10(granularity))
|
||||||
|
)
|
||||||
|
.order_by(BlockGroupFilter.c.height)
|
||||||
|
)
|
||||||
|
result = []
|
||||||
|
for row in context().fetchall(query):
|
||||||
|
record = {
|
||||||
|
"height": row["height"],
|
||||||
|
"filter": hexlify(row["address_filter"]).decode(),
|
||||||
|
}
|
||||||
|
if granularity == 0:
|
||||||
|
record["txid"] = hexlify(row["tx_hash"][::-1]).decode()
|
||||||
|
result.append(record)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def insert_transaction(block_hash, tx):
|
def insert_transaction(block_hash, tx):
|
||||||
context().get_bulk_loader().add_transaction(block_hash, tx).flush(TX)
|
context().get_bulk_loader().add_transaction(block_hash, tx).flush(TX)
|
||||||
|
|
||||||
|
|
24
lbry/db/queries/wallet.py
Normal file
24
lbry/db/queries/wallet.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
from sqlalchemy import exists
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from ..query_context import context
|
||||||
|
from ..tables import Wallet
|
||||||
|
|
||||||
|
|
||||||
|
def has_wallet(wallet_id: str) -> bool:
|
||||||
|
sql = select(exists(select(Wallet.c.wallet_id).where(Wallet.c.wallet_id == wallet_id)))
|
||||||
|
return context().execute(sql).fetchone()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_wallet(wallet_id: str):
|
||||||
|
return context().fetchone(
|
||||||
|
select(Wallet.c.data).where(Wallet.c.wallet_id == wallet_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def add_wallet(wallet_id: str, data: str):
|
||||||
|
c = context()
|
||||||
|
c.execute(
|
||||||
|
c.insert_or_replace(Wallet, ["data"])
|
||||||
|
.values(wallet_id=wallet_id, data=data)
|
||||||
|
)
|
|
@ -19,6 +19,13 @@ Version = Table(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Wallet = Table(
|
||||||
|
'wallet', metadata,
|
||||||
|
Column('wallet_id', Text, primary_key=True),
|
||||||
|
Column('data', Text),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
PubkeyAddress = Table(
|
PubkeyAddress = Table(
|
||||||
'pubkey_address', metadata,
|
'pubkey_address', metadata,
|
||||||
Column('address', Text, primary_key=True),
|
Column('address', Text, primary_key=True),
|
||||||
|
|
|
@ -1225,7 +1225,7 @@ class API:
|
||||||
wallet = self.wallets.get_or_default(wallet_id)
|
wallet = self.wallets.get_or_default(wallet_id)
|
||||||
wallet_changed = False
|
wallet_changed = False
|
||||||
if data is not None:
|
if data is not None:
|
||||||
added_accounts = await wallet.merge(self.wallets, password, data)
|
added_accounts = await wallet.merge(password, data)
|
||||||
if added_accounts and self.ledger.sync.network.is_connected:
|
if added_accounts and self.ledger.sync.network.is_connected:
|
||||||
if blocking:
|
if blocking:
|
||||||
await asyncio.wait([
|
await asyncio.wait([
|
||||||
|
@ -1318,11 +1318,24 @@ class API:
|
||||||
.receiving.get_or_create_usable_address()
|
.receiving.get_or_create_usable_address()
|
||||||
)
|
)
|
||||||
|
|
||||||
async def address_block_filters(self):
|
async def address_filter(
|
||||||
return await self.service.get_block_address_filters()
|
self,
|
||||||
|
start_height: int, # starting height of block range or just single block
|
||||||
|
end_height: int = None, # return a range of blocks from start_height to end_height
|
||||||
|
granularity: int = None, # 0 - individual tx filters, 1 - block per filter,
|
||||||
|
# 1000, 10000, 100000 blocks per filter
|
||||||
|
) -> list: # blocks
|
||||||
|
"""
|
||||||
|
List address filters
|
||||||
|
|
||||||
async def address_transaction_filters(self, block_hash: str):
|
Usage:
|
||||||
return await self.service.get_transaction_address_filters(block_hash)
|
address_filter <start_height>
|
||||||
|
[--end_height=<end_height>]
|
||||||
|
[--granularity=<granularity>]
|
||||||
|
"""
|
||||||
|
return await self.service.get_address_filters(
|
||||||
|
start_height=start_height, end_height=end_height, granularity=granularity
|
||||||
|
)
|
||||||
|
|
||||||
FILE_DOC = """
|
FILE_DOC = """
|
||||||
File management.
|
File management.
|
||||||
|
@ -2656,6 +2669,23 @@ class API:
|
||||||
await self.service.maybe_broadcast_or_release(tx, blocking, preview)
|
await self.service.maybe_broadcast_or_release(tx, blocking, preview)
|
||||||
return tx
|
return tx
|
||||||
|
|
||||||
|
BLOCK_DOC = """
|
||||||
|
Block information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def block_list(
|
||||||
|
self,
|
||||||
|
start_height: int, # starting height of block range or just single block
|
||||||
|
end_height: int = None, # return a range of blocks from start_height to end_height
|
||||||
|
) -> list: # blocks
|
||||||
|
"""
|
||||||
|
List block info
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
block_list <start_height> [<end_height>]
|
||||||
|
"""
|
||||||
|
return await self.service.get_blocks(start_height=start_height, end_height=end_height)
|
||||||
|
|
||||||
TRANSACTION_DOC = """
|
TRANSACTION_DOC = """
|
||||||
Transaction management.
|
Transaction management.
|
||||||
"""
|
"""
|
||||||
|
@ -3529,8 +3559,12 @@ class Client(API):
|
||||||
self.receive_messages_task = asyncio.create_task(self.receive_messages())
|
self.receive_messages_task = asyncio.create_task(self.receive_messages())
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
await self.session.close()
|
if self.session is not None:
|
||||||
self.receive_messages_task.cancel()
|
await self.session.close()
|
||||||
|
self.session = None
|
||||||
|
if self.receive_messages_task is not None:
|
||||||
|
self.receive_messages_task.cancel()
|
||||||
|
self.receive_messages_task = None
|
||||||
|
|
||||||
async def receive_messages(self):
|
async def receive_messages(self):
|
||||||
async for message in self.ws:
|
async for message in self.ws:
|
||||||
|
@ -3559,12 +3593,16 @@ class Client(API):
|
||||||
await self.ws.send_json({'id': self.message_id, 'method': method, 'params': kwargs})
|
await self.ws.send_json({'id': self.message_id, 'method': method, 'params': kwargs})
|
||||||
return ec.stream
|
return ec.stream
|
||||||
|
|
||||||
async def subscribe(self, event) -> EventStream:
|
def get_event_stream(self, event) -> EventStream:
|
||||||
if event not in self.subscriptions:
|
if event not in self.subscriptions:
|
||||||
self.subscriptions[event] = EventController()
|
self.subscriptions[event] = EventController()
|
||||||
await self.ws.send_json({'id': None, 'method': 'subscribe', 'params': [event]})
|
|
||||||
return self.subscriptions[event].stream
|
return self.subscriptions[event].stream
|
||||||
|
|
||||||
|
async def start_event_streams(self):
|
||||||
|
events = list(self.subscriptions.keys())
|
||||||
|
if events:
|
||||||
|
await self.ws.send_json({'id': None, 'method': 'subscribe', 'params': events})
|
||||||
|
|
||||||
def __getattribute__(self, name):
|
def __getattribute__(self, name):
|
||||||
if name in dir(API):
|
if name in dir(API):
|
||||||
return partial(object.__getattribute__(self, 'send'), name)
|
return partial(object.__getattribute__(self, 'send'), name)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from lbry.schema.result import Censor
|
||||||
from lbry.blockchain.transaction import Transaction, Output
|
from lbry.blockchain.transaction import Transaction, Output
|
||||||
from lbry.blockchain.ledger import Ledger
|
from lbry.blockchain.ledger import Ledger
|
||||||
from lbry.wallet import WalletManager
|
from lbry.wallet import WalletManager
|
||||||
from lbry.event import EventController
|
from lbry.event import EventController, EventStream
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -25,14 +25,14 @@ class Sync:
|
||||||
Server stays synced with lbrycrd
|
Server stays synced with lbrycrd
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
on_block: EventStream
|
||||||
|
on_mempool: EventStream
|
||||||
|
|
||||||
def __init__(self, ledger: Ledger, db: Database):
|
def __init__(self, ledger: Ledger, db: Database):
|
||||||
self.ledger = ledger
|
self.ledger = ledger
|
||||||
self.conf = ledger.conf
|
self.conf = ledger.conf
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
self._on_block_controller = EventController()
|
|
||||||
self.on_block = self._on_block_controller.stream
|
|
||||||
|
|
||||||
self._on_progress_controller = db._on_progress_controller
|
self._on_progress_controller = db._on_progress_controller
|
||||||
self.on_progress = db.on_progress
|
self.on_progress = db.on_progress
|
||||||
|
|
||||||
|
@ -63,13 +63,7 @@ class Service:
|
||||||
def __init__(self, ledger: Ledger):
|
def __init__(self, ledger: Ledger):
|
||||||
self.ledger, self.conf = ledger, ledger.conf
|
self.ledger, self.conf = ledger, ledger.conf
|
||||||
self.db = Database(ledger)
|
self.db = Database(ledger)
|
||||||
self.wallets = WalletManager(ledger, self.db)
|
self.wallets = WalletManager(self.db)
|
||||||
|
|
||||||
#self.on_address = sync.on_address
|
|
||||||
#self.accounts = sync.accounts
|
|
||||||
#self.on_header = sync.on_header
|
|
||||||
#self.on_ready = sync.on_ready
|
|
||||||
#self.on_transaction = sync.on_transaction
|
|
||||||
|
|
||||||
# sync has established connection with a source from which it can synchronize
|
# sync has established connection with a source from which it can synchronize
|
||||||
# for full service this is lbrycrd (or sync service) and for light this is full node
|
# for full service this is lbrycrd (or sync service) and for light this is full node
|
||||||
|
@ -78,8 +72,8 @@ class Service:
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
await self.db.open()
|
await self.db.open()
|
||||||
await self.wallets.ensure_path_exists()
|
await self.wallets.storage.prepare()
|
||||||
await self.wallets.load()
|
await self.wallets.initialize()
|
||||||
await self.sync.start()
|
await self.sync.start()
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
|
@ -95,16 +89,21 @@ class Service:
|
||||||
async def find_ffmpeg(self):
|
async def find_ffmpeg(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def get(self, uri, **kwargs):
|
async def get_file(self, uri, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def create_wallet(self, file_name):
|
async def get_block_headers(self, first, last=None):
|
||||||
path = os.path.join(self.conf.wallet_dir, file_name)
|
pass
|
||||||
return self.wallets.add_from_path(path)
|
|
||||||
|
def create_wallet(self, wallet_id):
|
||||||
|
return self.wallets.create(wallet_id)
|
||||||
|
|
||||||
async def get_addresses(self, **constraints):
|
async def get_addresses(self, **constraints):
|
||||||
return await self.db.get_addresses(**constraints)
|
return await self.db.get_addresses(**constraints)
|
||||||
|
|
||||||
|
async def get_address_filters(self, start_height: int, end_height: int=None, granularity: int=0):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def reserve_outputs(self, txos):
|
def reserve_outputs(self, txos):
|
||||||
return self.db.reserve_outputs(txos)
|
return self.db.reserve_outputs(txos)
|
||||||
|
|
||||||
|
|
|
@ -19,11 +19,13 @@ from lbry.blockchain.ledger import ledger_class_from_name
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def jsonrpc_dumps_pretty(obj, **kwargs):
|
def jsonrpc_dumps_pretty(obj, message_id=None, **kwargs):
|
||||||
#if not isinstance(obj, dict):
|
#if not isinstance(obj, dict):
|
||||||
# data = {"jsonrpc": "2.0", "error": obj.to_dict()}
|
# data = {"jsonrpc": "2.0", "error": obj.to_dict()}
|
||||||
#else:
|
#else:
|
||||||
data = {"jsonrpc": "2.0", "result": obj}
|
data = {"jsonrpc": "2.0", "result": obj}
|
||||||
|
if message_id is not None:
|
||||||
|
data["id"] = message_id
|
||||||
return json.dumps(data, cls=JSONResponseEncoder, sort_keys=True, indent=2, **kwargs) + "\n"
|
return json.dumps(data, cls=JSONResponseEncoder, sort_keys=True, indent=2, **kwargs) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
@ -166,7 +168,7 @@ class Daemon:
|
||||||
|
|
||||||
async def on_message(self, web_socket: WebSocketManager, msg: dict):
|
async def on_message(self, web_socket: WebSocketManager, msg: dict):
|
||||||
if msg['method'] == 'subscribe':
|
if msg['method'] == 'subscribe':
|
||||||
streams = msg['streams']
|
streams = msg['params']
|
||||||
if isinstance(streams, str):
|
if isinstance(streams, str):
|
||||||
streams = [streams]
|
streams = [streams]
|
||||||
web_socket.subscribe(streams, self.app['subscriptions'])
|
web_socket.subscribe(streams, self.app['subscriptions'])
|
||||||
|
@ -175,11 +177,10 @@ class Daemon:
|
||||||
method = getattr(self.api, msg['method'])
|
method = getattr(self.api, msg['method'])
|
||||||
try:
|
try:
|
||||||
result = await method(**params)
|
result = await method(**params)
|
||||||
encoded_result = jsonrpc_dumps_pretty(result, service=self.service)
|
encoded_result = jsonrpc_dumps_pretty(
|
||||||
await web_socket.send_json({
|
result, message_id=msg.get('id', ''), service=self.service
|
||||||
'id': msg.get('id', ''),
|
)
|
||||||
'result': encoded_result
|
await web_socket.send_str(encoded_result)
|
||||||
})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception("RPC error")
|
log.exception("RPC error")
|
||||||
await web_socket.send_json({'id': msg.get('id', ''), 'result': "unexpected error: " + str(e)})
|
await web_socket.send_json({'id': msg.get('id', ''), 'result': "unexpected error: " + str(e)})
|
||||||
|
|
61
lbry/service/full_endpoint.py
Normal file
61
lbry/service/full_endpoint.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
import logging
|
||||||
|
from binascii import hexlify, unhexlify
|
||||||
|
|
||||||
|
from lbry.blockchain.lbrycrd import Lbrycrd
|
||||||
|
from lbry.blockchain.sync import BlockchainSync
|
||||||
|
from lbry.blockchain.ledger import Ledger
|
||||||
|
from lbry.blockchain.transaction import Transaction
|
||||||
|
|
||||||
|
from .base import Service, Sync
|
||||||
|
from .api import Client as APIClient
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NoSync(Sync):
|
||||||
|
|
||||||
|
def __init__(self, service: Service, client: APIClient):
|
||||||
|
super().__init__(service.ledger, service.db)
|
||||||
|
self.service = service
|
||||||
|
self.client = client
|
||||||
|
self.on_block = client.get_event_stream('blockchain.block')
|
||||||
|
self.on_block_subscription: Optional[BroadcastSubscription] = None
|
||||||
|
self.on_mempool = client.get_event_stream('blockchain.mempool')
|
||||||
|
self.on_mempool_subscription: Optional[BroadcastSubscription] = None
|
||||||
|
|
||||||
|
async def wait_for_client_ready(self):
|
||||||
|
await self.client.connect()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
self.db.stop_event.clear()
|
||||||
|
await self.wait_for_client_ready()
|
||||||
|
self.advance_loop_task = asyncio.create_task(self.advance())
|
||||||
|
await self.advance_loop_task
|
||||||
|
await self.client.subscribe()
|
||||||
|
self.advance_loop_task = asyncio.create_task(self.advance_loop())
|
||||||
|
self.on_block_subscription = self.on_block.listen(
|
||||||
|
lambda e: self.on_block_event.set()
|
||||||
|
)
|
||||||
|
self.on_mempool_subscription = self.on_mempool.listen(
|
||||||
|
lambda e: self.on_mempool_event.set()
|
||||||
|
)
|
||||||
|
await self.download_filters()
|
||||||
|
await self.download_headers()
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
await self.client.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
class FullEndpoint(Service):
|
||||||
|
|
||||||
|
name = "endpoint"
|
||||||
|
|
||||||
|
sync: 'NoSync'
|
||||||
|
|
||||||
|
def __init__(self, ledger: Ledger):
|
||||||
|
super().__init__(ledger)
|
||||||
|
self.client = APIClient(
|
||||||
|
f"http://{ledger.conf.full_nodes[0][0]}:{ledger.conf.full_nodes[0][1]}/api"
|
||||||
|
)
|
||||||
|
self.sync = NoSync(self, self.client)
|
|
@ -35,7 +35,15 @@ class FullNode(Service):
|
||||||
async def get_status(self):
|
async def get_status(self):
|
||||||
return 'everything is wonderful'
|
return 'everything is wonderful'
|
||||||
|
|
||||||
# async def get_block_address_filters(self):
|
async def get_block_headers(self, first, last=None):
|
||||||
|
return await self.db.get_blocks(first, last)
|
||||||
|
|
||||||
|
async def get_address_filters(self, start_height: int, end_height: int = None, granularity: int = 0):
|
||||||
|
return await self.db.get_filters(
|
||||||
|
start_height=start_height, end_height=end_height, granularity=granularity
|
||||||
|
)
|
||||||
|
|
||||||
|
# async def get_block_address_filters(self):
|
||||||
# return {
|
# return {
|
||||||
# hexlify(f['block_hash']).decode(): hexlify(f['block_filter']).decode()
|
# hexlify(f['block_hash']).decode(): hexlify(f['block_filter']).decode()
|
||||||
# for f in await self.db.get_block_address_filters()
|
# for f in await self.db.get_block_address_filters()
|
||||||
|
|
|
@ -1,11 +1,25 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Dict, Tuple
|
from typing import List, Dict
|
||||||
|
#from io import StringIO
|
||||||
|
#from functools import partial
|
||||||
|
#from operator import itemgetter
|
||||||
|
from collections import defaultdict
|
||||||
|
#from binascii import hexlify, unhexlify
|
||||||
|
from typing import List, Optional, DefaultDict, NamedTuple
|
||||||
|
|
||||||
|
#from lbry.crypto.hash import double_sha256, sha256
|
||||||
|
|
||||||
|
from lbry.tasks import TaskGroup
|
||||||
|
from lbry.blockchain.transaction import Transaction
|
||||||
|
from lbry.blockchain.block import get_address_filter
|
||||||
|
from lbry.event import BroadcastSubscription, EventController
|
||||||
|
from lbry.wallet.account import AddressManager
|
||||||
from lbry.blockchain import Ledger, Transaction
|
from lbry.blockchain import Ledger, Transaction
|
||||||
from lbry.wallet.sync import SPVSync
|
|
||||||
|
|
||||||
from .base import Service
|
from .base import Service, Sync
|
||||||
from .api import Client
|
from .api import Client as APIClient
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -14,23 +28,24 @@ class LightClient(Service):
|
||||||
|
|
||||||
name = "client"
|
name = "client"
|
||||||
|
|
||||||
sync: SPVSync
|
sync: 'FastSync'
|
||||||
|
|
||||||
def __init__(self, ledger: Ledger):
|
def __init__(self, ledger: Ledger):
|
||||||
super().__init__(ledger)
|
super().__init__(ledger)
|
||||||
self.client = Client(
|
self.client = APIClient(
|
||||||
f"http://{ledger.conf.full_nodes[0][0]}:{ledger.conf.full_nodes[0][1]}/api"
|
f"http://{ledger.conf.full_nodes[0][0]}:{ledger.conf.full_nodes[0][1]}/ws"
|
||||||
)
|
)
|
||||||
self.sync = SPVSync(self)
|
self.sync = FastSync(self, self.client)
|
||||||
|
self.blocks = BlockHeaderManager(self.db, self.client)
|
||||||
|
self.filters = FilterManager(self.db, self.client)
|
||||||
|
|
||||||
async def search_transactions(self, txids):
|
async def search_transactions(self, txids):
|
||||||
return await self.client.transaction_search(txids=txids)
|
return await self.client.transaction_search(txids=txids)
|
||||||
|
|
||||||
async def get_block_address_filters(self):
|
async def get_address_filters(self, start_height: int, end_height: int = None, granularity: int = 0):
|
||||||
return await self.client.address_block_filters()
|
return await self.filters.get_filters(
|
||||||
|
start_height=start_height, end_height=end_height, granularity=granularity
|
||||||
async def get_transaction_address_filters(self, block_hash):
|
)
|
||||||
return await self.client.address_transaction_filters(block_hash=block_hash)
|
|
||||||
|
|
||||||
async def broadcast(self, tx):
|
async def broadcast(self, tx):
|
||||||
pass
|
pass
|
||||||
|
@ -47,6 +62,437 @@ class LightClient(Service):
|
||||||
async def search_supports(self, accounts, **kwargs):
|
async def search_supports(self, accounts, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def sum_supports(self, claim_hash: bytes, include_channel_content=False, exclude_own_supports=False) \
|
async def sum_supports(self, claim_hash: bytes, include_channel_content=False) -> List[Dict]:
|
||||||
-> Tuple[List[Dict], int]:
|
return await self.client.sum_supports(claim_hash, include_channel_content)
|
||||||
return await self.client.sum_supports(claim_hash, include_channel_content, exclude_own_supports)
|
|
||||||
|
|
||||||
|
class TransactionEvent(NamedTuple):
|
||||||
|
address: str
|
||||||
|
tx: Transaction
|
||||||
|
|
||||||
|
|
||||||
|
class AddressesGeneratedEvent(NamedTuple):
|
||||||
|
address_manager: AddressManager
|
||||||
|
addresses: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionCacheItem:
|
||||||
|
__slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications'
|
||||||
|
|
||||||
|
def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
|
||||||
|
self.has_tx = asyncio.Event()
|
||||||
|
self.lock = lock or asyncio.Lock()
|
||||||
|
self._tx = self.tx = tx
|
||||||
|
self.pending_verifications = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tx(self) -> Optional[Transaction]:
|
||||||
|
return self._tx
|
||||||
|
|
||||||
|
@tx.setter
|
||||||
|
def tx(self, tx: Transaction):
|
||||||
|
self._tx = tx
|
||||||
|
if tx is not None:
|
||||||
|
self.has_tx.set()
|
||||||
|
|
||||||
|
|
||||||
|
class FilterManager:
|
||||||
|
"""
|
||||||
|
Efficient on-demand address filter access.
|
||||||
|
Stores and retrieves from local db what it previously downloaded and
|
||||||
|
downloads on-demand what it doesn't have from full node.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db, client):
|
||||||
|
self.db = db
|
||||||
|
self.client = client
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
|
async def get_filters(self, start_height, end_height, granularity):
|
||||||
|
return await self.client.address_filter(
|
||||||
|
start_height=start_height, end_height=end_height, granularity=granularity
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockHeaderManager:
|
||||||
|
"""
|
||||||
|
Efficient on-demand block header access.
|
||||||
|
Stores and retrieves from local db what it previously downloaded and
|
||||||
|
downloads on-demand what it doesn't have from full node.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db, client):
|
||||||
|
self.db = db
|
||||||
|
self.client = client
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
|
async def get_header(self, height):
|
||||||
|
blocks = await self.client.block_list(height)
|
||||||
|
if blocks:
|
||||||
|
return blocks[0]
|
||||||
|
|
||||||
|
async def add(self, header):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def download(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FastSync(Sync):
|
||||||
|
|
||||||
|
def __init__(self, service: Service, client: APIClient):
|
||||||
|
super().__init__(service.ledger, service.db)
|
||||||
|
self.service = service
|
||||||
|
self.client = client
|
||||||
|
self.advance_loop_task: Optional[asyncio.Task] = None
|
||||||
|
self.on_block = client.get_event_stream('blockchain.block')
|
||||||
|
self.on_block_event = asyncio.Event()
|
||||||
|
self.on_block_subscription: Optional[BroadcastSubscription] = None
|
||||||
|
self.on_mempool = client.get_event_stream('blockchain.mempool')
|
||||||
|
self.on_mempool_event = asyncio.Event()
|
||||||
|
self.on_mempool_subscription: Optional[BroadcastSubscription] = None
|
||||||
|
|
||||||
|
async def wait_for_client_ready(self):
|
||||||
|
await self.client.connect()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
return
|
||||||
|
self.db.stop_event.clear()
|
||||||
|
await self.wait_for_client_ready()
|
||||||
|
self.advance_loop_task = asyncio.create_task(self.advance())
|
||||||
|
await self.advance_loop_task
|
||||||
|
await self.client.subscribe()
|
||||||
|
self.advance_loop_task = asyncio.create_task(self.advance_loop())
|
||||||
|
self.on_block_subscription = self.on_block.listen(
|
||||||
|
lambda e: self.on_block_event.set()
|
||||||
|
)
|
||||||
|
self.on_mempool_subscription = self.on_mempool.listen(
|
||||||
|
lambda e: self.on_mempool_event.set()
|
||||||
|
)
|
||||||
|
await self.download_filters()
|
||||||
|
await self.download_headers()
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
await self.client.disconnect()
|
||||||
|
|
||||||
|
async def advance(self):
|
||||||
|
address_array = [
|
||||||
|
bytearray(a['address'].encode())
|
||||||
|
for a in await self.service.db.get_all_addresses()
|
||||||
|
]
|
||||||
|
block_filters = await self.service.get_block_address_filters()
|
||||||
|
for block_hash, block_filter in block_filters.items():
|
||||||
|
bf = get_address_filter(block_filter)
|
||||||
|
if bf.MatchAny(address_array):
|
||||||
|
print(f'match: {block_hash} - {block_filter}')
|
||||||
|
tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash)
|
||||||
|
for txid, tx_filter in tx_filters.items():
|
||||||
|
tf = get_address_filter(tx_filter)
|
||||||
|
if tf.MatchAny(address_array):
|
||||||
|
print(f' match: {txid} - {tx_filter}')
|
||||||
|
txs = await self.service.search_transactions([txid])
|
||||||
|
tx = Transaction(unhexlify(txs[txid]))
|
||||||
|
await self.service.db.insert_transaction(tx)
|
||||||
|
|
||||||
|
# async def get_local_status_and_history(self, address, history=None):
|
||||||
|
# if not history:
|
||||||
|
# address_details = await self.db.get_address(address=address)
|
||||||
|
# history = (address_details['history'] if address_details else '') or ''
|
||||||
|
# parts = history.split(':')[:-1]
|
||||||
|
# return (
|
||||||
|
# hexlify(sha256(history.encode())).decode() if history else None,
|
||||||
|
# list(zip(parts[0::2], map(int, parts[1::2])))
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# def get_root_of_merkle_tree(branches, branch_positions, working_branch):
|
||||||
|
# for i, branch in enumerate(branches):
|
||||||
|
# other_branch = unhexlify(branch)[::-1]
|
||||||
|
# other_branch_on_left = bool((branch_positions >> i) & 1)
|
||||||
|
# if other_branch_on_left:
|
||||||
|
# combined = other_branch + working_branch
|
||||||
|
# else:
|
||||||
|
# combined = working_branch + other_branch
|
||||||
|
# working_branch = double_sha256(combined)
|
||||||
|
# return hexlify(working_branch[::-1])
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# @property
|
||||||
|
# def local_height_including_downloaded_height(self):
|
||||||
|
# return max(self.headers.height, self._download_height)
|
||||||
|
#
|
||||||
|
# async def initial_headers_sync(self):
|
||||||
|
# get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
|
||||||
|
# self.headers.chunk_getter = get_chunk
|
||||||
|
#
|
||||||
|
# async def doit():
|
||||||
|
# for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)):
|
||||||
|
# async with self._header_processing_lock:
|
||||||
|
# await self.headers.ensure_chunk_at(height)
|
||||||
|
# self._other_tasks.add(doit())
|
||||||
|
# await self.update_headers()
|
||||||
|
#
|
||||||
|
# async def update_headers(self, height=None, headers=None, subscription_update=False):
|
||||||
|
# rewound = 0
|
||||||
|
# while True:
|
||||||
|
#
|
||||||
|
# if height is None or height > len(self.headers):
|
||||||
|
# # sometimes header subscription updates are for a header in the future
|
||||||
|
# # which can't be connected, so we do a normal header sync instead
|
||||||
|
# height = len(self.headers)
|
||||||
|
# headers = None
|
||||||
|
# subscription_update = False
|
||||||
|
#
|
||||||
|
# if not headers:
|
||||||
|
# header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
|
||||||
|
# headers = header_response['hex']
|
||||||
|
#
|
||||||
|
# if not headers:
|
||||||
|
# # Nothing to do, network thinks we're already at the latest height.
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
# added = await self.headers.connect(height, unhexlify(headers))
|
||||||
|
# if added > 0:
|
||||||
|
# height += added
|
||||||
|
# self._on_header_controller.add(
|
||||||
|
# BlockHeightEvent(self.headers.height, added))
|
||||||
|
#
|
||||||
|
# if rewound > 0:
|
||||||
|
# # we started rewinding blocks and apparently found
|
||||||
|
# # a new chain
|
||||||
|
# rewound = 0
|
||||||
|
# await self.db.rewind_blockchain(height)
|
||||||
|
#
|
||||||
|
# if subscription_update:
|
||||||
|
# # subscription updates are for latest header already
|
||||||
|
# # so we don't need to check if there are newer / more
|
||||||
|
# # on another loop of update_headers(), just return instead
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
# elif added == 0:
|
||||||
|
# # we had headers to connect but none got connected, probably a reorganization
|
||||||
|
# height -= 1
|
||||||
|
# rewound += 1
|
||||||
|
# log.warning(
|
||||||
|
# "Blockchain Reorganization: attempting rewind to height %s from starting height %s",
|
||||||
|
# height, height+rewound
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# else:
|
||||||
|
# raise IndexError(f"headers.connect() returned negative number ({added})")
|
||||||
|
#
|
||||||
|
# if height < 0:
|
||||||
|
# raise IndexError(
|
||||||
|
# "Blockchain reorganization rewound all the way back to genesis hash. "
|
||||||
|
# "Something is very wrong. Maybe you are on the wrong blockchain?"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# if rewound >= 100:
|
||||||
|
# raise IndexError(
|
||||||
|
# "Blockchain reorganization dropped {} headers. This is highly unusual. "
|
||||||
|
# "Will not continue to attempt reorganizing. Please, delete the ledger "
|
||||||
|
# "synchronization directory inside your wallet directory (folder: '{}') and "
|
||||||
|
# "restart the program to synchronize from scratch."
|
||||||
|
# .format(rewound, self.ledger.get_id())
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# headers = None # ready to download some more headers
|
||||||
|
#
|
||||||
|
# # if we made it this far and this was a subscription_update
|
||||||
|
# # it means something went wrong and now we're doing a more
|
||||||
|
# # robust sync, turn off subscription update shortcut
|
||||||
|
# subscription_update = False
|
||||||
|
#
|
||||||
|
# async def receive_header(self, response):
|
||||||
|
# async with self._header_processing_lock:
|
||||||
|
# header = response[0]
|
||||||
|
# await self.update_headers(
|
||||||
|
# height=header['height'], headers=header['hex'], subscription_update=True
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# async def subscribe_accounts(self):
|
||||||
|
# if self.network.is_connected and self.accounts:
|
||||||
|
# log.info("Subscribe to %i accounts", len(self.accounts))
|
||||||
|
# await asyncio.wait([
|
||||||
|
# self.subscribe_account(a) for a in self.accounts
|
||||||
|
# ])
|
||||||
|
#
|
||||||
|
# async def subscribe_account(self, account: Account):
|
||||||
|
# for address_manager in account.address_managers.values():
|
||||||
|
# await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
|
||||||
|
# await account.ensure_address_gap()
|
||||||
|
#
|
||||||
|
# async def unsubscribe_account(self, account: Account):
|
||||||
|
# for address in await account.get_addresses():
|
||||||
|
# await self.network.unsubscribe_address(address)
|
||||||
|
#
|
||||||
|
# async def subscribe_addresses(
|
||||||
|
# self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
|
||||||
|
# if self.network.is_connected and addresses:
|
||||||
|
# addresses_remaining = list(addresses)
|
||||||
|
# while addresses_remaining:
|
||||||
|
# batch = addresses_remaining[:batch_size]
|
||||||
|
# results = await self.network.subscribe_address(*batch)
|
||||||
|
# for address, remote_status in zip(batch, results):
|
||||||
|
# self._update_tasks.add(self.update_history(address, remote_status, address_manager))
|
||||||
|
# addresses_remaining = addresses_remaining[batch_size:]
|
||||||
|
# log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
|
||||||
|
# len(addresses), *self.network.client.server_address_and_port)
|
||||||
|
# log.info(
|
||||||
|
# "finished subscribing to %i addresses on %s:%i", len(addresses),
|
||||||
|
# *self.network.client.server_address_and_port
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# def process_status_update(self, update):
|
||||||
|
# address, remote_status = update
|
||||||
|
# self._update_tasks.add(self.update_history(address, remote_status))
|
||||||
|
#
|
||||||
|
# async def update_history(self, address, remote_status, address_manager: AddressManager = None):
|
||||||
|
# async with self._address_update_locks[address]:
|
||||||
|
# self._known_addresses_out_of_sync.discard(address)
|
||||||
|
#
|
||||||
|
# local_status, local_history = await self.get_local_status_and_history(address)
|
||||||
|
#
|
||||||
|
# if local_status == remote_status:
|
||||||
|
# return True
|
||||||
|
#
|
||||||
|
# remote_history = await self.network.retriable_call(self.network.get_history, address)
|
||||||
|
# remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
|
||||||
|
# we_need = set(remote_history) - set(local_history)
|
||||||
|
# if not we_need:
|
||||||
|
# return True
|
||||||
|
#
|
||||||
|
# cache_tasks: List[asyncio.Task[Transaction]] = []
|
||||||
|
# synced_history = StringIO()
|
||||||
|
# loop = asyncio.get_running_loop()
|
||||||
|
# for i, (txid, remote_height) in enumerate(remote_history):
|
||||||
|
# if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
|
||||||
|
# synced_history.write(f'{txid}:{remote_height}:')
|
||||||
|
# else:
|
||||||
|
# check_local = (txid, remote_height) not in we_need
|
||||||
|
# cache_tasks.append(loop.create_task(
|
||||||
|
# self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
|
||||||
|
# ))
|
||||||
|
#
|
||||||
|
# synced_txs = []
|
||||||
|
# for task in cache_tasks:
|
||||||
|
# tx = await task
|
||||||
|
#
|
||||||
|
# check_db_for_txos = []
|
||||||
|
# for txi in tx.inputs:
|
||||||
|
# if txi.txo_ref.txo is not None:
|
||||||
|
# continue
|
||||||
|
# cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
|
||||||
|
# if cache_item is not None:
|
||||||
|
# if cache_item.tx is None:
|
||||||
|
# await cache_item.has_tx.wait()
|
||||||
|
# assert cache_item.tx is not None
|
||||||
|
# txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
|
||||||
|
# else:
|
||||||
|
# check_db_for_txos.append(txi.txo_ref.hash)
|
||||||
|
#
|
||||||
|
# referenced_txos = {} if not check_db_for_txos else {
|
||||||
|
# txo.id: txo for txo in await self.db.get_txos(
|
||||||
|
# txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
|
||||||
|
# )
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# for txi in tx.inputs:
|
||||||
|
# if txi.txo_ref.txo is not None:
|
||||||
|
# continue
|
||||||
|
# referenced_txo = referenced_txos.get(txi.txo_ref.id)
|
||||||
|
# if referenced_txo is not None:
|
||||||
|
# txi.txo_ref = referenced_txo.ref
|
||||||
|
#
|
||||||
|
# synced_history.write(f'{tx.id}:{tx.height}:')
|
||||||
|
# synced_txs.append(tx)
|
||||||
|
#
|
||||||
|
# await self.db.save_transaction_io_batch(
|
||||||
|
# synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue()
|
||||||
|
# )
|
||||||
|
# await asyncio.wait([
|
||||||
|
# self.ledger._on_transaction_controller.add(TransactionEvent(address, tx))
|
||||||
|
# for tx in synced_txs
|
||||||
|
# ])
|
||||||
|
#
|
||||||
|
# if address_manager is None:
|
||||||
|
# address_manager = await self.get_address_manager_for_address(address)
|
||||||
|
#
|
||||||
|
# if address_manager is not None:
|
||||||
|
# await address_manager.ensure_address_gap()
|
||||||
|
#
|
||||||
|
# local_status, local_history = \
|
||||||
|
# await self.get_local_status_and_history(address, synced_history.getvalue())
|
||||||
|
# if local_status != remote_status:
|
||||||
|
# if local_history == remote_history:
|
||||||
|
# return True
|
||||||
|
# log.warning(
|
||||||
|
# "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
|
||||||
|
# remote_status, len(remote_history), local_status, len(local_history)
|
||||||
|
# )
|
||||||
|
# log.warning("local: %s", local_history)
|
||||||
|
# log.warning("remote: %s", remote_history)
|
||||||
|
# self._known_addresses_out_of_sync.add(address)
|
||||||
|
# return False
|
||||||
|
# else:
|
||||||
|
# return True
|
||||||
|
#
|
||||||
|
# async def cache_transaction(self, tx_hash, remote_height, check_local=True):
|
||||||
|
# cache_item = self._tx_cache.get(tx_hash)
|
||||||
|
# if cache_item is None:
|
||||||
|
# cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
|
||||||
|
# elif cache_item.tx is not None and \
|
||||||
|
# cache_item.tx.height >= remote_height and \
|
||||||
|
# (cache_item.tx.is_verified or remote_height < 1):
|
||||||
|
# return cache_item.tx # cached tx is already up-to-date
|
||||||
|
#
|
||||||
|
# try:
|
||||||
|
# cache_item.pending_verifications += 1
|
||||||
|
# return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
|
||||||
|
# finally:
|
||||||
|
# cache_item.pending_verifications -= 1
|
||||||
|
#
|
||||||
|
# async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
|
||||||
|
#
|
||||||
|
# async with cache_item.lock:
|
||||||
|
#
|
||||||
|
# tx = cache_item.tx
|
||||||
|
#
|
||||||
|
# if tx is None and check_local:
|
||||||
|
# # check local db
|
||||||
|
# tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)
|
||||||
|
#
|
||||||
|
# merkle = None
|
||||||
|
# if tx is None:
|
||||||
|
# # fetch from network
|
||||||
|
# _raw, merkle = await self.network.retriable_call(
|
||||||
|
# self.network.get_transaction_and_merkle, tx_hash, remote_height
|
||||||
|
# )
|
||||||
|
# tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
|
||||||
|
# cache_item.tx = tx # make sure it's saved before caching it
|
||||||
|
# await self.maybe_verify_transaction(tx, remote_height, merkle)
|
||||||
|
# return tx
|
||||||
|
#
|
||||||
|
# async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
|
||||||
|
# tx.height = remote_height
|
||||||
|
# cached = self._tx_cache.get(tx.hash)
|
||||||
|
# if not cached:
|
||||||
|
# # cache txs looked up by transaction_show too
|
||||||
|
# cached = TransactionCacheItem()
|
||||||
|
# cached.tx = tx
|
||||||
|
# self._tx_cache[tx.hash] = cached
|
||||||
|
# if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
|
||||||
|
# # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
|
||||||
|
# if not merkle:
|
||||||
|
# merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
|
||||||
|
# merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||||
|
# header = await self.headers.get(remote_height)
|
||||||
|
# tx.position = merkle['pos']
|
||||||
|
# tx.is_verified = merkle_root == header['merkle_root']
|
||||||
|
#
|
||||||
|
# async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
|
||||||
|
# details = await self.db.get_address(address=address)
|
||||||
|
# for account in self.accounts:
|
||||||
|
# if account.id == details['account']:
|
||||||
|
# return account.address_managers[details['chain']]
|
||||||
|
# return None
|
|
@ -447,6 +447,38 @@ class IntegrationTestCase(AsyncioTestCase):
|
||||||
self.db_driver = db_driver
|
self.db_driver = db_driver
|
||||||
return db
|
return db
|
||||||
|
|
||||||
|
async def add_full_node(self, port):
|
||||||
|
path = tempfile.mkdtemp()
|
||||||
|
self.addCleanup(shutil.rmtree, path, True)
|
||||||
|
ledger = RegTestLedger(Config.with_same_dir(path).set(
|
||||||
|
api=f'localhost:{port}',
|
||||||
|
lbrycrd_dir=self.chain.ledger.conf.lbrycrd_dir,
|
||||||
|
lbrycrd_rpc_port=self.chain.ledger.conf.lbrycrd_rpc_port,
|
||||||
|
lbrycrd_peer_port=self.chain.ledger.conf.lbrycrd_peer_port,
|
||||||
|
lbrycrd_zmq=self.chain.ledger.conf.lbrycrd_zmq
|
||||||
|
))
|
||||||
|
service = FullNode(ledger)
|
||||||
|
console = Console(service)
|
||||||
|
daemon = Daemon(service, console)
|
||||||
|
self.addCleanup(daemon.stop)
|
||||||
|
await daemon.start()
|
||||||
|
return daemon
|
||||||
|
|
||||||
|
async def add_light_client(self, full_node, port, start=True):
|
||||||
|
path = tempfile.mkdtemp()
|
||||||
|
self.addCleanup(shutil.rmtree, path, True)
|
||||||
|
ledger = RegTestLedger(Config.with_same_dir(path).set(
|
||||||
|
api=f'localhost:{port}',
|
||||||
|
full_nodes=[(full_node.conf.api_host, full_node.conf.api_port)]
|
||||||
|
))
|
||||||
|
service = LightClient(ledger)
|
||||||
|
console = Console(service)
|
||||||
|
daemon = Daemon(service, console)
|
||||||
|
self.addCleanup(daemon.stop)
|
||||||
|
if start:
|
||||||
|
await daemon.start()
|
||||||
|
return daemon
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_claim_txo(tx) -> Optional[Output]:
|
def find_claim_txo(tx) -> Optional[Output]:
|
||||||
for txo in tx.outputs:
|
for txo in tx.outputs:
|
||||||
|
@ -538,9 +570,11 @@ class CommandTestCase(IntegrationTestCase):
|
||||||
await super().asyncSetUp()
|
await super().asyncSetUp()
|
||||||
await self.generate(200, wait=False)
|
await self.generate(200, wait=False)
|
||||||
|
|
||||||
self.full_node = self.daemon = await self.add_full_node()
|
self.daemon_port += 1
|
||||||
if os.environ.get('TEST_MODE', 'full-node') == 'client':
|
self.full_node = self.daemon = await self.add_full_node(self.daemon_port)
|
||||||
self.daemon = await self.add_light_client(self.full_node)
|
if os.environ.get('TEST_MODE', 'node') == 'client':
|
||||||
|
self.daemon_port += 1
|
||||||
|
self.daemon = await self.add_light_client(self.full_node, self.daemon_port)
|
||||||
|
|
||||||
self.service = self.daemon.service
|
self.service = self.daemon.service
|
||||||
self.ledger = self.service.ledger
|
self.ledger = self.service.ledger
|
||||||
|
@ -556,40 +590,6 @@ class CommandTestCase(IntegrationTestCase):
|
||||||
await self.chain.send_to_address(addresses[0], '10.0')
|
await self.chain.send_to_address(addresses[0], '10.0')
|
||||||
await self.generate(5)
|
await self.generate(5)
|
||||||
|
|
||||||
async def add_full_node(self):
|
|
||||||
self.daemon_port += 1
|
|
||||||
path = tempfile.mkdtemp()
|
|
||||||
self.addCleanup(shutil.rmtree, path, True)
|
|
||||||
ledger = RegTestLedger(Config.with_same_dir(path).set(
|
|
||||||
api=f'localhost:{self.daemon_port}',
|
|
||||||
lbrycrd_dir=self.chain.ledger.conf.lbrycrd_dir,
|
|
||||||
lbrycrd_rpc_port=self.chain.ledger.conf.lbrycrd_rpc_port,
|
|
||||||
lbrycrd_peer_port=self.chain.ledger.conf.lbrycrd_peer_port,
|
|
||||||
lbrycrd_zmq=self.chain.ledger.conf.lbrycrd_zmq,
|
|
||||||
spv_address_filters=False
|
|
||||||
))
|
|
||||||
service = FullNode(ledger)
|
|
||||||
console = Console(service)
|
|
||||||
daemon = Daemon(service, console)
|
|
||||||
self.addCleanup(daemon.stop)
|
|
||||||
await daemon.start()
|
|
||||||
return daemon
|
|
||||||
|
|
||||||
async def add_light_client(self, full_node):
|
|
||||||
self.daemon_port += 1
|
|
||||||
path = tempfile.mkdtemp()
|
|
||||||
self.addCleanup(shutil.rmtree, path, True)
|
|
||||||
ledger = RegTestLedger(Config.with_same_dir(path).set(
|
|
||||||
api=f'localhost:{self.daemon_port}',
|
|
||||||
full_nodes=[(full_node.conf.api_host, full_node.conf.api_port)]
|
|
||||||
))
|
|
||||||
service = LightClient(ledger)
|
|
||||||
console = Console(service)
|
|
||||||
daemon = Daemon(service, console)
|
|
||||||
self.addCleanup(daemon.stop)
|
|
||||||
await daemon.start()
|
|
||||||
return daemon
|
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
async def asyncTearDown(self):
|
||||||
await super().asyncTearDown()
|
await super().asyncTearDown()
|
||||||
for wallet_node in self.extra_wallet_nodes:
|
for wallet_node in self.extra_wallet_nodes:
|
||||||
|
|
|
@ -12,6 +12,7 @@ import ecdsa
|
||||||
|
|
||||||
from lbry.constants import COIN
|
from lbry.constants import COIN
|
||||||
from lbry.db import Database, CLAIM_TYPE_CODES, TXO_TYPES
|
from lbry.db import Database, CLAIM_TYPE_CODES, TXO_TYPES
|
||||||
|
from lbry.db.tables import AccountAddress
|
||||||
from lbry.blockchain import Ledger
|
from lbry.blockchain import Ledger
|
||||||
from lbry.error import InvalidPasswordError
|
from lbry.error import InvalidPasswordError
|
||||||
from lbry.crypto.crypt import aes_encrypt, aes_decrypt
|
from lbry.crypto.crypt import aes_encrypt, aes_decrypt
|
||||||
|
@ -214,12 +215,12 @@ class Account:
|
||||||
HierarchicalDeterministic.name: HierarchicalDeterministic,
|
HierarchicalDeterministic.name: HierarchicalDeterministic,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, ledger: Ledger, db: Database, name: str,
|
def __init__(self, db: Database, name: str,
|
||||||
phrase: str, language: str, private_key_string: str,
|
phrase: str, language: str, private_key_string: str,
|
||||||
encrypted: bool, private_key: Optional[PrivateKey], public_key: PubKey,
|
encrypted: bool, private_key: Optional[PrivateKey], public_key: PubKey,
|
||||||
address_generator: dict, modified_on: float, channel_keys: dict) -> None:
|
address_generator: dict, modified_on: float, channel_keys: dict) -> None:
|
||||||
self.ledger = ledger
|
|
||||||
self.db = db
|
self.db = db
|
||||||
|
self.ledger = db.ledger
|
||||||
self.id = public_key.address
|
self.id = public_key.address
|
||||||
self.name = name
|
self.name = name
|
||||||
self.phrase = phrase
|
self.phrase = phrase
|
||||||
|
@ -245,10 +246,10 @@ class Account:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def generate(
|
async def generate(
|
||||||
cls, ledger: Ledger, db: Database,
|
cls, db: Database, name: str = None,
|
||||||
name: str = None, language: str = 'en',
|
language: str = 'en', address_generator: dict = None
|
||||||
address_generator: dict = None):
|
):
|
||||||
return await cls.from_dict(ledger, db, {
|
return await cls.from_dict(db, {
|
||||||
'name': name,
|
'name': name,
|
||||||
'seed': await mnemonic.generate_phrase(language),
|
'seed': await mnemonic.generate_phrase(language),
|
||||||
'language': language,
|
'language': language,
|
||||||
|
@ -276,13 +277,12 @@ class Account:
|
||||||
return phrase, private_key, public_key
|
return phrase, private_key, public_key
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_dict(cls, ledger: Ledger, db: Database, d: dict):
|
async def from_dict(cls, db: Database, d: dict):
|
||||||
phrase, private_key, public_key = await cls.keys_from_dict(ledger, d)
|
phrase, private_key, public_key = await cls.keys_from_dict(db.ledger, d)
|
||||||
name = d.get('name')
|
name = d.get('name')
|
||||||
if not name:
|
if not name:
|
||||||
name = f'Account #{public_key.address}'
|
name = f'Account #{public_key.address}'
|
||||||
return cls(
|
return cls(
|
||||||
ledger=ledger,
|
|
||||||
db=db,
|
db=db,
|
||||||
name=name,
|
name=name,
|
||||||
phrase=phrase,
|
phrase=phrase,
|
||||||
|
@ -415,7 +415,7 @@ class Account:
|
||||||
return await self.db.get_addresses(account=self, **constraints)
|
return await self.db.get_addresses(account=self, **constraints)
|
||||||
|
|
||||||
async def get_addresses(self, **constraints) -> List[str]:
|
async def get_addresses(self, **constraints) -> List[str]:
|
||||||
rows = await self.get_address_records(cols=['account_address.address'], **constraints)
|
rows = await self.get_address_records(cols=[AccountAddress.c.address], **constraints)
|
||||||
return [r['address'] for r in rows]
|
return [r['address'] for r in rows]
|
||||||
|
|
||||||
async def get_valid_receiving_address(self, default_address: str) -> str:
|
async def get_valid_receiving_address(self, default_address: str) -> str:
|
||||||
|
|
|
@ -1,385 +0,0 @@
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
import sqlite3
|
|
||||||
import platform
|
|
||||||
from binascii import hexlify
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from contextvars import ContextVar
|
|
||||||
from concurrent.futures.thread import ThreadPoolExecutor
|
|
||||||
from concurrent.futures.process import ProcessPoolExecutor
|
|
||||||
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
|
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
sqlite3.enable_callback_tracebacks(True)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ReaderProcessState:
|
|
||||||
cursor: sqlite3.Cursor
|
|
||||||
|
|
||||||
|
|
||||||
reader_context: Optional[ContextVar[ReaderProcessState]] = ContextVar('reader_context')
|
|
||||||
|
|
||||||
|
|
||||||
def initializer(path):
|
|
||||||
db = sqlite3.connect(path)
|
|
||||||
db.row_factory = dict_row_factory
|
|
||||||
db.executescript("pragma journal_mode=WAL;")
|
|
||||||
reader = ReaderProcessState(db.cursor())
|
|
||||||
reader_context.set(reader)
|
|
||||||
|
|
||||||
|
|
||||||
def run_read_only_fetchall(sql, params):
|
|
||||||
cursor = reader_context.get().cursor
|
|
||||||
try:
|
|
||||||
return cursor.execute(sql, params).fetchall()
|
|
||||||
except (Exception, OSError) as e:
|
|
||||||
log.exception('Error running transaction:', exc_info=e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def run_read_only_fetchone(sql, params):
|
|
||||||
cursor = reader_context.get().cursor
|
|
||||||
try:
|
|
||||||
return cursor.execute(sql, params).fetchone()
|
|
||||||
except (Exception, OSError) as e:
|
|
||||||
log.exception('Error running transaction:', exc_info=e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
if platform.system() == 'Windows' or 'ANDROID_ARGUMENT' in os.environ:
|
|
||||||
ReaderExecutorClass = ThreadPoolExecutor
|
|
||||||
else:
|
|
||||||
ReaderExecutorClass = ProcessPoolExecutor
|
|
||||||
|
|
||||||
|
|
||||||
class AIOSQLite:
|
|
||||||
reader_executor: ReaderExecutorClass
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# has to be single threaded as there is no mapping of thread:connection
|
|
||||||
self.writer_executor = ThreadPoolExecutor(max_workers=1)
|
|
||||||
self.writer_connection: Optional[sqlite3.Connection] = None
|
|
||||||
self._closing = False
|
|
||||||
self.query_count = 0
|
|
||||||
self.write_lock = asyncio.Lock()
|
|
||||||
self.writers = 0
|
|
||||||
self.read_ready = asyncio.Event()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
|
|
||||||
sqlite3.enable_callback_tracebacks(True)
|
|
||||||
db = cls()
|
|
||||||
|
|
||||||
def _connect_writer():
|
|
||||||
db.writer_connection = sqlite3.connect(path, *args, **kwargs)
|
|
||||||
|
|
||||||
readers = max(os.cpu_count() - 2, 2)
|
|
||||||
db.reader_executor = ReaderExecutorClass(
|
|
||||||
max_workers=readers, initializer=initializer, initargs=(path, )
|
|
||||||
)
|
|
||||||
await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer)
|
|
||||||
db.read_ready.set()
|
|
||||||
return db
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
if self._closing:
|
|
||||||
return
|
|
||||||
self._closing = True
|
|
||||||
await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close)
|
|
||||||
self.writer_executor.shutdown(wait=True)
|
|
||||||
self.reader_executor.shutdown(wait=True)
|
|
||||||
self.read_ready.clear()
|
|
||||||
self.writer_connection = None
|
|
||||||
|
|
||||||
def executemany(self, sql: str, params: Iterable):
|
|
||||||
params = params if params is not None else []
|
|
||||||
# this fetchall is needed to prevent SQLITE_MISUSE
|
|
||||||
return self.run(lambda conn: conn.executemany(sql, params).fetchall())
|
|
||||||
|
|
||||||
def executescript(self, script: str) -> Awaitable:
|
|
||||||
return self.run(lambda conn: conn.executescript(script))
|
|
||||||
|
|
||||||
async def _execute_fetch(self, sql: str, parameters: Iterable = None,
|
|
||||||
read_only=False, fetch_all: bool = False) -> List[dict]:
|
|
||||||
read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone
|
|
||||||
parameters = parameters if parameters is not None else []
|
|
||||||
if read_only:
|
|
||||||
while self.writers:
|
|
||||||
await self.read_ready.wait()
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(
|
|
||||||
self.reader_executor, read_only_fn, sql, parameters
|
|
||||||
)
|
|
||||||
if fetch_all:
|
|
||||||
return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
|
|
||||||
return await self.run(lambda conn: conn.execute(sql, parameters).fetchone())
|
|
||||||
|
|
||||||
async def execute_fetchall(self, sql: str, parameters: Iterable = None,
|
|
||||||
read_only=False) -> List[dict]:
|
|
||||||
return await self._execute_fetch(sql, parameters, read_only, fetch_all=True)
|
|
||||||
|
|
||||||
async def execute_fetchone(self, sql: str, parameters: Iterable = None,
|
|
||||||
read_only=False) -> List[dict]:
|
|
||||||
return await self._execute_fetch(sql, parameters, read_only, fetch_all=False)
|
|
||||||
|
|
||||||
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
|
|
||||||
parameters = parameters if parameters is not None else []
|
|
||||||
return self.run(lambda conn: conn.execute(sql, parameters))
|
|
||||||
|
|
||||||
async def run(self, fun, *args, **kwargs):
|
|
||||||
self.writers += 1
|
|
||||||
self.read_ready.clear()
|
|
||||||
async with self.write_lock:
|
|
||||||
try:
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(
|
|
||||||
self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
self.writers -= 1
|
|
||||||
if not self.writers:
|
|
||||||
self.read_ready.set()
|
|
||||||
|
|
||||||
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
|
|
||||||
self.writer_connection.execute('begin')
|
|
||||||
try:
|
|
||||||
self.query_count += 1
|
|
||||||
result = fun(self.writer_connection, *args, **kwargs) # type: ignore
|
|
||||||
self.writer_connection.commit()
|
|
||||||
return result
|
|
||||||
except (Exception, OSError) as e:
|
|
||||||
log.exception('Error running transaction:', exc_info=e)
|
|
||||||
self.writer_connection.rollback()
|
|
||||||
log.warning("rolled back")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
|
|
||||||
return asyncio.get_event_loop().run_in_executor(
|
|
||||||
self.writer_executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def __run_transaction_with_foreign_keys_disabled(self,
|
|
||||||
fun: Callable[[sqlite3.Connection, Any, Any], Any],
|
|
||||||
args, kwargs):
|
|
||||||
foreign_keys_enabled, = self.writer_connection.execute("pragma foreign_keys").fetchone()
|
|
||||||
if not foreign_keys_enabled:
|
|
||||||
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
|
|
||||||
try:
|
|
||||||
self.writer_connection.execute('pragma foreign_keys=off').fetchone()
|
|
||||||
return self.__run_transaction(fun, *args, **kwargs)
|
|
||||||
finally:
|
|
||||||
self.writer_connection.execute('pragma foreign_keys=on').fetchone()
|
|
||||||
|
|
||||||
|
|
||||||
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
|
||||||
sql, values = [], {}
|
|
||||||
for key, constraint in constraints.items():
|
|
||||||
tag = '0'
|
|
||||||
if '#' in key:
|
|
||||||
key, tag = key[:key.index('#')], key[key.index('#')+1:]
|
|
||||||
col, op, key = key, '=', key.replace('.', '_')
|
|
||||||
if not key:
|
|
||||||
sql.append(constraint)
|
|
||||||
continue
|
|
||||||
if key.startswith('$$'):
|
|
||||||
col, key = col[2:], key[1:]
|
|
||||||
elif key.startswith('$'):
|
|
||||||
values[key] = constraint
|
|
||||||
continue
|
|
||||||
if key.endswith('__not'):
|
|
||||||
col, op = col[:-len('__not')], '!='
|
|
||||||
elif key.endswith('__is_null'):
|
|
||||||
col = col[:-len('__is_null')]
|
|
||||||
sql.append(f'{col} IS NULL')
|
|
||||||
continue
|
|
||||||
if key.endswith('__is_not_null'):
|
|
||||||
col = col[:-len('__is_not_null')]
|
|
||||||
sql.append(f'{col} IS NOT NULL')
|
|
||||||
continue
|
|
||||||
if key.endswith('__lt'):
|
|
||||||
col, op = col[:-len('__lt')], '<'
|
|
||||||
elif key.endswith('__lte'):
|
|
||||||
col, op = col[:-len('__lte')], '<='
|
|
||||||
elif key.endswith('__gt'):
|
|
||||||
col, op = col[:-len('__gt')], '>'
|
|
||||||
elif key.endswith('__gte'):
|
|
||||||
col, op = col[:-len('__gte')], '>='
|
|
||||||
elif key.endswith('__like'):
|
|
||||||
col, op = col[:-len('__like')], 'LIKE'
|
|
||||||
elif key.endswith('__not_like'):
|
|
||||||
col, op = col[:-len('__not_like')], 'NOT LIKE'
|
|
||||||
elif key.endswith('__in') or key.endswith('__not_in'):
|
|
||||||
if key.endswith('__in'):
|
|
||||||
col, op, one_val_op = col[:-len('__in')], 'IN', '='
|
|
||||||
else:
|
|
||||||
col, op, one_val_op = col[:-len('__not_in')], 'NOT IN', '!='
|
|
||||||
if constraint:
|
|
||||||
if isinstance(constraint, (list, set, tuple)):
|
|
||||||
if len(constraint) == 1:
|
|
||||||
values[f'{key}{tag}'] = next(iter(constraint))
|
|
||||||
sql.append(f'{col} {one_val_op} :{key}{tag}')
|
|
||||||
else:
|
|
||||||
keys = []
|
|
||||||
for i, val in enumerate(constraint):
|
|
||||||
keys.append(f':{key}{tag}_{i}')
|
|
||||||
values[f'{key}{tag}_{i}'] = val
|
|
||||||
sql.append(f'{col} {op} ({", ".join(keys)})')
|
|
||||||
elif isinstance(constraint, str):
|
|
||||||
sql.append(f'{col} {op} ({constraint})')
|
|
||||||
else:
|
|
||||||
raise ValueError(f"{col} requires a list, set or string as constraint value.")
|
|
||||||
continue
|
|
||||||
elif key.endswith('__any') or key.endswith('__or'):
|
|
||||||
where, subvalues = constraints_to_sql(constraint, ' OR ', key+tag+'_')
|
|
||||||
sql.append(f'({where})')
|
|
||||||
values.update(subvalues)
|
|
||||||
continue
|
|
||||||
if key.endswith('__and'):
|
|
||||||
where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_')
|
|
||||||
sql.append(f'({where})')
|
|
||||||
values.update(subvalues)
|
|
||||||
continue
|
|
||||||
sql.append(f'{col} {op} :{prepend_key}{key}{tag}')
|
|
||||||
values[prepend_key+key+tag] = constraint
|
|
||||||
return joiner.join(sql) if sql else '', values
|
|
||||||
|
|
||||||
|
|
||||||
def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
|
|
||||||
sql = [select]
|
|
||||||
limit = constraints.pop('limit', None)
|
|
||||||
offset = constraints.pop('offset', None)
|
|
||||||
order_by = constraints.pop('order_by', None)
|
|
||||||
group_by = constraints.pop('group_by', None)
|
|
||||||
|
|
||||||
accounts = constraints.pop('accounts', [])
|
|
||||||
if accounts:
|
|
||||||
constraints['account__in'] = [a.public_key.address for a in accounts]
|
|
||||||
|
|
||||||
where, values = constraints_to_sql(constraints)
|
|
||||||
if where:
|
|
||||||
sql.append('WHERE')
|
|
||||||
sql.append(where)
|
|
||||||
|
|
||||||
if group_by is not None:
|
|
||||||
sql.append(f'GROUP BY {group_by}')
|
|
||||||
|
|
||||||
if order_by:
|
|
||||||
sql.append('ORDER BY')
|
|
||||||
if isinstance(order_by, str):
|
|
||||||
sql.append(order_by)
|
|
||||||
elif isinstance(order_by, list):
|
|
||||||
sql.append(', '.join(order_by))
|
|
||||||
else:
|
|
||||||
raise ValueError("order_by must be string or list")
|
|
||||||
|
|
||||||
if limit is not None:
|
|
||||||
sql.append(f'LIMIT {limit}')
|
|
||||||
|
|
||||||
if offset is not None:
|
|
||||||
sql.append(f'OFFSET {offset}')
|
|
||||||
|
|
||||||
return ' '.join(sql), values
|
|
||||||
|
|
||||||
|
|
||||||
def interpolate(sql, values):
|
|
||||||
for k in sorted(values.keys(), reverse=True):
|
|
||||||
value = values[k]
|
|
||||||
if isinstance(value, bytes):
|
|
||||||
value = f"X'{hexlify(value).decode()}'"
|
|
||||||
elif isinstance(value, str):
|
|
||||||
value = f"'{value}'"
|
|
||||||
else:
|
|
||||||
value = str(value)
|
|
||||||
sql = sql.replace(f":{k}", value)
|
|
||||||
return sql
|
|
||||||
|
|
||||||
|
|
||||||
def constrain_single_or_list(constraints, column, value, convert=lambda x: x):
|
|
||||||
if value is not None:
|
|
||||||
if isinstance(value, list):
|
|
||||||
value = [convert(v) for v in value]
|
|
||||||
if len(value) == 1:
|
|
||||||
constraints[column] = value[0]
|
|
||||||
elif len(value) > 1:
|
|
||||||
constraints[f"{column}__in"] = value
|
|
||||||
else:
|
|
||||||
constraints[column] = convert(value)
|
|
||||||
return constraints
|
|
||||||
|
|
||||||
|
|
||||||
class SQLiteMixin:
|
|
||||||
|
|
||||||
SCHEMA_VERSION: Optional[str] = None
|
|
||||||
CREATE_TABLES_QUERY: str
|
|
||||||
MAX_QUERY_VARIABLES = 900
|
|
||||||
|
|
||||||
CREATE_VERSION_TABLE = """
|
|
||||||
create table if not exists version (
|
|
||||||
version text
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, path):
|
|
||||||
self._db_path = path
|
|
||||||
self.db: AIOSQLite = None
|
|
||||||
self.ledger = None
|
|
||||||
|
|
||||||
async def open(self):
|
|
||||||
log.info("connecting to database: %s", self._db_path)
|
|
||||||
self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
|
|
||||||
if self.SCHEMA_VERSION:
|
|
||||||
tables = [t[0] for t in await self.db.execute_fetchall(
|
|
||||||
"SELECT name FROM sqlite_master WHERE type='table';"
|
|
||||||
)]
|
|
||||||
if tables:
|
|
||||||
if 'version' in tables:
|
|
||||||
version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;")
|
|
||||||
if version == (self.SCHEMA_VERSION,):
|
|
||||||
return
|
|
||||||
await self.db.executescript('\n'.join(
|
|
||||||
f"DROP TABLE {table};" for table in tables
|
|
||||||
))
|
|
||||||
await self.db.execute(self.CREATE_VERSION_TABLE)
|
|
||||||
await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,))
|
|
||||||
await self.db.executescript(self.CREATE_TABLES_QUERY)
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
await self.db.close()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
|
|
||||||
replace: bool = False) -> Tuple[str, List]:
|
|
||||||
columns, values = [], []
|
|
||||||
for column, value in data.items():
|
|
||||||
columns.append(column)
|
|
||||||
values.append(value)
|
|
||||||
policy = ""
|
|
||||||
if ignore_duplicate:
|
|
||||||
policy = " OR IGNORE"
|
|
||||||
if replace:
|
|
||||||
policy = " OR REPLACE"
|
|
||||||
sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
|
|
||||||
policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
|
|
||||||
)
|
|
||||||
return sql, values
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _update_sql(table: str, data: dict, where: str,
|
|
||||||
constraints: Union[list, tuple]) -> Tuple[str, list]:
|
|
||||||
columns, values = [], []
|
|
||||||
for column, value in data.items():
|
|
||||||
columns.append(f"{column} = ?")
|
|
||||||
values.append(value)
|
|
||||||
values.extend(constraints)
|
|
||||||
sql = "UPDATE {} SET {} WHERE {}".format(
|
|
||||||
table, ', '.join(columns), where
|
|
||||||
)
|
|
||||||
return sql, values
|
|
||||||
|
|
||||||
|
|
||||||
def dict_row_factory(cursor, row):
|
|
||||||
d = {}
|
|
||||||
for idx, col in enumerate(cursor.description):
|
|
||||||
d[col[0]] = row[idx]
|
|
||||||
return d
|
|
|
@ -1,22 +1,30 @@
|
||||||
import os
|
import os
|
||||||
|
import stat
|
||||||
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Dict
|
from typing import Optional, Dict
|
||||||
|
|
||||||
from lbry.db import Database
|
from lbry.db import Database
|
||||||
from lbry.blockchain.ledger import Ledger
|
|
||||||
|
|
||||||
from .wallet import Wallet
|
from .wallet import Wallet
|
||||||
|
from .account import SingleKey, HierarchicalDeterministic
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WalletManager:
|
class WalletManager:
|
||||||
|
|
||||||
def __init__(self, ledger: Ledger, db: Database):
|
def __init__(self, db: Database):
|
||||||
self.ledger = ledger
|
|
||||||
self.db = db
|
self.db = db
|
||||||
|
self.ledger = db.ledger
|
||||||
self.wallets: Dict[str, Wallet] = {}
|
self.wallets: Dict[str, Wallet] = {}
|
||||||
|
if self.ledger.conf.wallet_storage == "file":
|
||||||
|
self.storage = FileWallet(self.db, self.ledger.conf.wallet_dir)
|
||||||
|
elif self.ledger.conf.wallet_storage == "database":
|
||||||
|
self.storage = DatabaseWallet(self.db)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown wallet storage format: {self.ledger.conf.wallet_storage}")
|
||||||
|
|
||||||
def __getitem__(self, wallet_id: str) -> Wallet:
|
def __getitem__(self, wallet_id: str) -> Wallet:
|
||||||
try:
|
try:
|
||||||
|
@ -43,32 +51,12 @@ class WalletManager:
|
||||||
raise ValueError("Cannot spend funds with locked wallet, unlock first.")
|
raise ValueError("Cannot spend funds with locked wallet, unlock first.")
|
||||||
return wallet
|
return wallet
|
||||||
|
|
||||||
@property
|
async def initialize(self):
|
||||||
def path(self):
|
|
||||||
return os.path.join(self.ledger.conf.wallet_dir, 'wallets')
|
|
||||||
|
|
||||||
def sync_ensure_path_exists(self):
|
|
||||||
if not os.path.exists(self.path):
|
|
||||||
os.mkdir(self.path)
|
|
||||||
|
|
||||||
async def ensure_path_exists(self):
|
|
||||||
await asyncio.get_running_loop().run_in_executor(
|
|
||||||
None, self.sync_ensure_path_exists
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load(self):
|
|
||||||
wallets_directory = self.path
|
|
||||||
for wallet_id in self.ledger.conf.wallets:
|
for wallet_id in self.ledger.conf.wallets:
|
||||||
if wallet_id in self.wallets:
|
if wallet_id in self.wallets:
|
||||||
log.warning("Ignoring duplicate wallet_id in config: %s", wallet_id)
|
log.warning("Ignoring duplicate wallet_id in config: %s", wallet_id)
|
||||||
continue
|
continue
|
||||||
wallet_path = os.path.join(wallets_directory, wallet_id)
|
await self.load(wallet_id)
|
||||||
if not os.path.exists(wallet_path):
|
|
||||||
if not wallet_id == "default_wallet": # we'll probably generate this wallet, don't show error
|
|
||||||
log.error("Could not load wallet, file does not exist: %s", wallet_path)
|
|
||||||
continue
|
|
||||||
wallet = await Wallet.from_path(self.ledger, self.db, wallet_path)
|
|
||||||
self.add(wallet)
|
|
||||||
default_wallet = self.default
|
default_wallet = self.default
|
||||||
if default_wallet is None:
|
if default_wallet is None:
|
||||||
if self.ledger.conf.create_default_wallet:
|
if self.ledger.conf.create_default_wallet:
|
||||||
|
@ -83,34 +71,153 @@ class WalletManager:
|
||||||
elif not default_wallet.has_accounts and self.ledger.conf.create_default_account:
|
elif not default_wallet.has_accounts and self.ledger.conf.create_default_account:
|
||||||
await default_wallet.accounts.generate()
|
await default_wallet.accounts.generate()
|
||||||
|
|
||||||
def add(self, wallet: Wallet) -> Wallet:
|
async def load(self, wallet_id: str) -> Optional[Wallet]:
|
||||||
self.wallets[wallet.id] = wallet
|
wallet = await self.storage.get(wallet_id)
|
||||||
return wallet
|
if wallet is not None:
|
||||||
|
return self.add(wallet)
|
||||||
async def add_from_path(self, wallet_path) -> Wallet:
|
|
||||||
wallet_id = os.path.basename(wallet_path)
|
|
||||||
if wallet_id in self.wallets:
|
|
||||||
existing = self.wallets.get(wallet_id)
|
|
||||||
if existing.storage.path == wallet_path:
|
|
||||||
raise Exception(f"Wallet '{wallet_id}' is already loaded.")
|
|
||||||
raise Exception(
|
|
||||||
f"Wallet '{wallet_id}' is already loaded from '{existing.storage.path}'"
|
|
||||||
f" and cannot be loaded from '{wallet_path}'. Consider changing the wallet"
|
|
||||||
f" filename to be unique in order to avoid conflicts."
|
|
||||||
)
|
|
||||||
wallet = await Wallet.from_path(self.ledger, self.db, wallet_path)
|
|
||||||
return self.add(wallet)
|
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self, wallet_id: str, name: str,
|
self, wallet_id: str, name: str = "",
|
||||||
create_account=False, language='en', single_key=False) -> Wallet:
|
create_account=False, language="en", single_key=False
|
||||||
|
) -> Wallet:
|
||||||
if wallet_id in self.wallets:
|
if wallet_id in self.wallets:
|
||||||
raise Exception(f"Wallet with id '{wallet_id}' is already loaded and cannot be created.")
|
raise Exception(f"Wallet with id '{wallet_id}' is already loaded and cannot be created.")
|
||||||
wallet_path = os.path.join(self.path, wallet_id)
|
if await self.storage.exists(wallet_id):
|
||||||
if os.path.exists(wallet_path):
|
raise Exception(f"Wallet '{wallet_id}' already exists, use 'wallet_add' to load wallet.")
|
||||||
raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.")
|
wallet = Wallet(wallet_id, self.db, name)
|
||||||
wallet = await Wallet.create(
|
if create_account:
|
||||||
self.ledger, self.db, wallet_path, name,
|
await wallet.accounts.generate(language=language, address_generator={
|
||||||
create_account, language, single_key
|
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
|
||||||
)
|
})
|
||||||
|
await self.storage.save(wallet)
|
||||||
return self.add(wallet)
|
return self.add(wallet)
|
||||||
|
|
||||||
|
def add(self, wallet: Wallet) -> Wallet:
|
||||||
|
self.wallets[wallet.id] = wallet
|
||||||
|
wallet.on_change.listen(lambda _: self.storage.save(wallet))
|
||||||
|
return wallet
|
||||||
|
|
||||||
|
async def _report_state(self):
|
||||||
|
try:
|
||||||
|
for account in self.accounts:
|
||||||
|
balance = dewies_to_lbc(await account.get_balance(include_claims=True))
|
||||||
|
_, channel_count = await account.get_channels(limit=1)
|
||||||
|
claim_count = await account.get_claim_count()
|
||||||
|
if isinstance(account.receiving, SingleKey):
|
||||||
|
log.info("Loaded single key account %s with %s LBC. "
|
||||||
|
"%d channels, %d certificates and %d claims",
|
||||||
|
account.id, balance, channel_count, len(account.channel_keys), claim_count)
|
||||||
|
else:
|
||||||
|
total_receiving = len(await account.receiving.get_addresses())
|
||||||
|
total_change = len(await account.change.get_addresses())
|
||||||
|
log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), "
|
||||||
|
"%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ",
|
||||||
|
account.id, balance, total_receiving, account.receiving.gap, total_change,
|
||||||
|
account.change.gap, channel_count, len(account.channel_keys), claim_count)
|
||||||
|
except Exception as err:
|
||||||
|
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
|
||||||
|
raise
|
||||||
|
log.exception(
|
||||||
|
'Failed to display wallet state, please file issue '
|
||||||
|
'for this bug along with the traceback you see below:'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WalletStorage:
|
||||||
|
|
||||||
|
async def prepare(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def exists(self, walllet_id: str) -> bool:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get(self, wallet_id: str) -> Wallet:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def save(self, wallet: Wallet):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class FileWallet(WalletStorage):
|
||||||
|
|
||||||
|
def __init__(self, db, wallet_dir):
|
||||||
|
self.db = db
|
||||||
|
self.wallet_dir = wallet_dir
|
||||||
|
|
||||||
|
def get_wallet_path(self, wallet_id: str):
|
||||||
|
return os.path.join(self.wallet_dir, wallet_id)
|
||||||
|
|
||||||
|
async def prepare(self):
|
||||||
|
await asyncio.get_running_loop().run_in_executor(
|
||||||
|
None, self.sync_ensure_wallets_directory_exists
|
||||||
|
)
|
||||||
|
|
||||||
|
def sync_ensure_wallets_directory_exists(self):
|
||||||
|
if not os.path.exists(self.wallet_dir):
|
||||||
|
os.mkdir(self.wallet_dir)
|
||||||
|
|
||||||
|
async def exists(self, wallet_id: str) -> bool:
|
||||||
|
return os.path.exists(self.get_wallet_path(wallet_id))
|
||||||
|
|
||||||
|
async def get(self, wallet_id: str) -> Wallet:
|
||||||
|
wallet_dict = await asyncio.get_running_loop().run_in_executor(
|
||||||
|
None, self.sync_read, wallet_id
|
||||||
|
)
|
||||||
|
if wallet_dict is not None:
|
||||||
|
return await Wallet.from_dict(wallet_id, wallet_dict, self.db)
|
||||||
|
|
||||||
|
def sync_read(self, wallet_id):
|
||||||
|
try:
|
||||||
|
with open(self.get_wallet_path(wallet_id), 'r') as f:
|
||||||
|
json_data = f.read()
|
||||||
|
return json.loads(json_data)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def save(self, wallet: Wallet):
|
||||||
|
return await asyncio.get_running_loop().run_in_executor(
|
||||||
|
None, self.sync_write, wallet
|
||||||
|
)
|
||||||
|
|
||||||
|
def sync_write(self, wallet: Wallet):
|
||||||
|
temp_path = os.path.join(self.wallet_dir, f".tmp.{os.path.basename(wallet.id)}")
|
||||||
|
with open(temp_path, "w") as f:
|
||||||
|
f.write(wallet.to_serialized())
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
|
||||||
|
wallet_path = self.get_wallet_path(wallet.id)
|
||||||
|
if os.path.exists(wallet_path):
|
||||||
|
mode = os.stat(wallet_path).st_mode
|
||||||
|
else:
|
||||||
|
mode = stat.S_IREAD | stat.S_IWRITE
|
||||||
|
try:
|
||||||
|
os.rename(temp_path, wallet_path)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
os.remove(wallet_path)
|
||||||
|
os.rename(temp_path, wallet_path)
|
||||||
|
os.chmod(wallet_path, mode)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseWallet(WalletStorage):
|
||||||
|
|
||||||
|
def __init__(self, db: 'Database'):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
async def prepare(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def exists(self, wallet_id: str) -> bool:
|
||||||
|
return await self.db.has_wallet(wallet_id)
|
||||||
|
|
||||||
|
async def get(self, wallet_id: str) -> Wallet:
|
||||||
|
data = await self.db.get_wallet(wallet_id)
|
||||||
|
if data:
|
||||||
|
wallet_dict = json.loads(data['data'])
|
||||||
|
if wallet_dict is not None:
|
||||||
|
return await Wallet.from_dict(wallet_id, wallet_dict, self.db)
|
||||||
|
|
||||||
|
async def save(self, wallet: Wallet):
|
||||||
|
await self.db.add_wallet(
|
||||||
|
wallet.id, wallet.to_serialized()
|
||||||
|
)
|
||||||
|
|
|
@ -1,60 +0,0 @@
|
||||||
import os
|
|
||||||
import stat
|
|
||||||
import json
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
|
||||||
class WalletStorage:
|
|
||||||
VERSION = 1
|
|
||||||
|
|
||||||
def __init__(self, path=None):
|
|
||||||
self.path = path
|
|
||||||
|
|
||||||
def sync_read(self):
|
|
||||||
with open(self.path, 'r') as f:
|
|
||||||
json_data = f.read()
|
|
||||||
json_dict = json.loads(json_data)
|
|
||||||
if json_dict.get('version') == self.VERSION:
|
|
||||||
return json_dict
|
|
||||||
else:
|
|
||||||
return self.upgrade(json_dict)
|
|
||||||
|
|
||||||
async def read(self):
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
None, self.sync_read
|
|
||||||
)
|
|
||||||
|
|
||||||
def upgrade(self, json_dict):
|
|
||||||
version = json_dict.pop('version', -1)
|
|
||||||
if version == -1:
|
|
||||||
pass
|
|
||||||
json_dict['version'] = self.VERSION
|
|
||||||
return json_dict
|
|
||||||
|
|
||||||
def sync_write(self, json_dict):
|
|
||||||
|
|
||||||
json_data = json.dumps(json_dict, indent=4, sort_keys=True)
|
|
||||||
if self.path is None:
|
|
||||||
return json_data
|
|
||||||
|
|
||||||
temp_path = "{}.tmp.{}".format(self.path, os.getpid())
|
|
||||||
with open(temp_path, "w") as f:
|
|
||||||
f.write(json_data)
|
|
||||||
f.flush()
|
|
||||||
os.fsync(f.fileno())
|
|
||||||
|
|
||||||
if os.path.exists(self.path):
|
|
||||||
mode = os.stat(self.path).st_mode
|
|
||||||
else:
|
|
||||||
mode = stat.S_IREAD | stat.S_IWRITE
|
|
||||||
try:
|
|
||||||
os.rename(temp_path, self.path)
|
|
||||||
except Exception: # pylint: disable=broad-except
|
|
||||||
os.remove(self.path)
|
|
||||||
os.rename(temp_path, self.path)
|
|
||||||
os.chmod(self.path, mode)
|
|
||||||
|
|
||||||
async def write(self, json_dict):
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
None, self.sync_write, json_dict
|
|
||||||
)
|
|
|
@ -1,420 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
#from io import StringIO
|
|
||||||
#from functools import partial
|
|
||||||
#from operator import itemgetter
|
|
||||||
from collections import defaultdict
|
|
||||||
#from binascii import hexlify, unhexlify
|
|
||||||
from typing import List, Optional, DefaultDict, NamedTuple
|
|
||||||
|
|
||||||
#from lbry.crypto.hash import double_sha256, sha256
|
|
||||||
|
|
||||||
from lbry.tasks import TaskGroup
|
|
||||||
from lbry.blockchain.transaction import Transaction
|
|
||||||
#from lbry.blockchain.block import get_block_filter
|
|
||||||
from lbry.event import EventController
|
|
||||||
from lbry.service.base import Service, Sync
|
|
||||||
|
|
||||||
from .account import AddressManager
|
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TransactionEvent(NamedTuple):
|
|
||||||
address: str
|
|
||||||
tx: Transaction
|
|
||||||
|
|
||||||
|
|
||||||
class AddressesGeneratedEvent(NamedTuple):
|
|
||||||
address_manager: AddressManager
|
|
||||||
addresses: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
class TransactionCacheItem:
|
|
||||||
__slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications'
|
|
||||||
|
|
||||||
def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
|
|
||||||
self.has_tx = asyncio.Event()
|
|
||||||
self.lock = lock or asyncio.Lock()
|
|
||||||
self._tx = self.tx = tx
|
|
||||||
self.pending_verifications = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tx(self) -> Optional[Transaction]:
|
|
||||||
return self._tx
|
|
||||||
|
|
||||||
@tx.setter
|
|
||||||
def tx(self, tx: Transaction):
|
|
||||||
self._tx = tx
|
|
||||||
if tx is not None:
|
|
||||||
self.has_tx.set()
|
|
||||||
|
|
||||||
|
|
||||||
class SPVSync(Sync):
|
|
||||||
|
|
||||||
def __init__(self, service: Service):
|
|
||||||
super().__init__(service.ledger, service.db)
|
|
||||||
|
|
||||||
self.accounts = []
|
|
||||||
|
|
||||||
self._on_header_controller = EventController()
|
|
||||||
self.on_header = self._on_header_controller.stream
|
|
||||||
self.on_header.listen(
|
|
||||||
lambda change: log.info(
|
|
||||||
'%s: added %s header blocks',
|
|
||||||
self.ledger.get_id(), change
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self._download_height = 0
|
|
||||||
|
|
||||||
self._on_ready_controller = EventController()
|
|
||||||
self.on_ready = self._on_ready_controller.stream
|
|
||||||
|
|
||||||
#self._tx_cache = pylru.lrucache(100000)
|
|
||||||
self._update_tasks = TaskGroup()
|
|
||||||
self._other_tasks = TaskGroup() # that we dont need to start
|
|
||||||
self._header_processing_lock = asyncio.Lock()
|
|
||||||
self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
|
||||||
self._known_addresses_out_of_sync = set()
|
|
||||||
|
|
||||||
# async def advance(self):
|
|
||||||
# address_array = [
|
|
||||||
# bytearray(a['address'].encode())
|
|
||||||
# for a in await self.service.db.get_all_addresses()
|
|
||||||
# ]
|
|
||||||
# block_filters = await self.service.get_block_address_filters()
|
|
||||||
# for block_hash, block_filter in block_filters.items():
|
|
||||||
# bf = get_block_filter(block_filter)
|
|
||||||
# if bf.MatchAny(address_array):
|
|
||||||
# print(f'match: {block_hash} - {block_filter}')
|
|
||||||
# tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash)
|
|
||||||
# for txid, tx_filter in tx_filters.items():
|
|
||||||
# tf = get_block_filter(tx_filter)
|
|
||||||
# if tf.MatchAny(address_array):
|
|
||||||
# print(f' match: {txid} - {tx_filter}')
|
|
||||||
# txs = await self.service.search_transactions([txid])
|
|
||||||
# tx = Transaction(unhexlify(txs[txid]))
|
|
||||||
# await self.service.db.insert_transaction(tx)
|
|
||||||
#
|
|
||||||
# async def get_local_status_and_history(self, address, history=None):
|
|
||||||
# if not history:
|
|
||||||
# address_details = await self.db.get_address(address=address)
|
|
||||||
# history = (address_details['history'] if address_details else '') or ''
|
|
||||||
# parts = history.split(':')[:-1]
|
|
||||||
# return (
|
|
||||||
# hexlify(sha256(history.encode())).decode() if history else None,
|
|
||||||
# list(zip(parts[0::2], map(int, parts[1::2])))
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# @staticmethod
|
|
||||||
# def get_root_of_merkle_tree(branches, branch_positions, working_branch):
|
|
||||||
# for i, branch in enumerate(branches):
|
|
||||||
# other_branch = unhexlify(branch)[::-1]
|
|
||||||
# other_branch_on_left = bool((branch_positions >> i) & 1)
|
|
||||||
# if other_branch_on_left:
|
|
||||||
# combined = other_branch + working_branch
|
|
||||||
# else:
|
|
||||||
# combined = working_branch + other_branch
|
|
||||||
# working_branch = double_sha256(combined)
|
|
||||||
# return hexlify(working_branch[::-1])
|
|
||||||
#
|
|
||||||
async def start(self):
|
|
||||||
fully_synced = self.on_ready.first
|
|
||||||
#asyncio.create_task(self.network.start())
|
|
||||||
#await self.network.on_connected.first
|
|
||||||
#async with self._header_processing_lock:
|
|
||||||
# await self._update_tasks.add(self.initial_headers_sync())
|
|
||||||
await fully_synced
|
|
||||||
#
|
|
||||||
# async def join_network(self, *_):
|
|
||||||
# log.info("Subscribing and updating accounts.")
|
|
||||||
# await self._update_tasks.add(self.subscribe_accounts())
|
|
||||||
# await self._update_tasks.done.wait()
|
|
||||||
# self._on_ready_controller.add(True)
|
|
||||||
#
|
|
||||||
async def stop(self):
|
|
||||||
self._update_tasks.cancel()
|
|
||||||
self._other_tasks.cancel()
|
|
||||||
await self._update_tasks.done.wait()
|
|
||||||
await self._other_tasks.done.wait()
|
|
||||||
#
|
|
||||||
# @property
|
|
||||||
# def local_height_including_downloaded_height(self):
|
|
||||||
# return max(self.headers.height, self._download_height)
|
|
||||||
#
|
|
||||||
# async def initial_headers_sync(self):
|
|
||||||
# get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
|
|
||||||
# self.headers.chunk_getter = get_chunk
|
|
||||||
#
|
|
||||||
# async def doit():
|
|
||||||
# for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)):
|
|
||||||
# async with self._header_processing_lock:
|
|
||||||
# await self.headers.ensure_chunk_at(height)
|
|
||||||
# self._other_tasks.add(doit())
|
|
||||||
# await self.update_headers()
|
|
||||||
#
|
|
||||||
# async def update_headers(self, height=None, headers=None, subscription_update=False):
|
|
||||||
# rewound = 0
|
|
||||||
# while True:
|
|
||||||
#
|
|
||||||
# if height is None or height > len(self.headers):
|
|
||||||
# # sometimes header subscription updates are for a header in the future
|
|
||||||
# # which can't be connected, so we do a normal header sync instead
|
|
||||||
# height = len(self.headers)
|
|
||||||
# headers = None
|
|
||||||
# subscription_update = False
|
|
||||||
#
|
|
||||||
# if not headers:
|
|
||||||
# header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
|
|
||||||
# headers = header_response['hex']
|
|
||||||
#
|
|
||||||
# if not headers:
|
|
||||||
# # Nothing to do, network thinks we're already at the latest height.
|
|
||||||
# return
|
|
||||||
#
|
|
||||||
# added = await self.headers.connect(height, unhexlify(headers))
|
|
||||||
# if added > 0:
|
|
||||||
# height += added
|
|
||||||
# self._on_header_controller.add(
|
|
||||||
# BlockHeightEvent(self.headers.height, added))
|
|
||||||
#
|
|
||||||
# if rewound > 0:
|
|
||||||
# # we started rewinding blocks and apparently found
|
|
||||||
# # a new chain
|
|
||||||
# rewound = 0
|
|
||||||
# await self.db.rewind_blockchain(height)
|
|
||||||
#
|
|
||||||
# if subscription_update:
|
|
||||||
# # subscription updates are for latest header already
|
|
||||||
# # so we don't need to check if there are newer / more
|
|
||||||
# # on another loop of update_headers(), just return instead
|
|
||||||
# return
|
|
||||||
#
|
|
||||||
# elif added == 0:
|
|
||||||
# # we had headers to connect but none got connected, probably a reorganization
|
|
||||||
# height -= 1
|
|
||||||
# rewound += 1
|
|
||||||
# log.warning(
|
|
||||||
# "Blockchain Reorganization: attempting rewind to height %s from starting height %s",
|
|
||||||
# height, height+rewound
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# else:
|
|
||||||
# raise IndexError(f"headers.connect() returned negative number ({added})")
|
|
||||||
#
|
|
||||||
# if height < 0:
|
|
||||||
# raise IndexError(
|
|
||||||
# "Blockchain reorganization rewound all the way back to genesis hash. "
|
|
||||||
# "Something is very wrong. Maybe you are on the wrong blockchain?"
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# if rewound >= 100:
|
|
||||||
# raise IndexError(
|
|
||||||
# "Blockchain reorganization dropped {} headers. This is highly unusual. "
|
|
||||||
# "Will not continue to attempt reorganizing. Please, delete the ledger "
|
|
||||||
# "synchronization directory inside your wallet directory (folder: '{}') and "
|
|
||||||
# "restart the program to synchronize from scratch."
|
|
||||||
# .format(rewound, self.ledger.get_id())
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# headers = None # ready to download some more headers
|
|
||||||
#
|
|
||||||
# # if we made it this far and this was a subscription_update
|
|
||||||
# # it means something went wrong and now we're doing a more
|
|
||||||
# # robust sync, turn off subscription update shortcut
|
|
||||||
# subscription_update = False
|
|
||||||
#
|
|
||||||
# async def receive_header(self, response):
|
|
||||||
# async with self._header_processing_lock:
|
|
||||||
# header = response[0]
|
|
||||||
# await self.update_headers(
|
|
||||||
# height=header['height'], headers=header['hex'], subscription_update=True
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# async def subscribe_accounts(self):
|
|
||||||
# if self.network.is_connected and self.accounts:
|
|
||||||
# log.info("Subscribe to %i accounts", len(self.accounts))
|
|
||||||
# await asyncio.wait([
|
|
||||||
# self.subscribe_account(a) for a in self.accounts
|
|
||||||
# ])
|
|
||||||
#
|
|
||||||
# async def subscribe_account(self, account: Account):
|
|
||||||
# for address_manager in account.address_managers.values():
|
|
||||||
# await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
|
|
||||||
# await account.ensure_address_gap()
|
|
||||||
#
|
|
||||||
# async def unsubscribe_account(self, account: Account):
|
|
||||||
# for address in await account.get_addresses():
|
|
||||||
# await self.network.unsubscribe_address(address)
|
|
||||||
#
|
|
||||||
# async def subscribe_addresses(
|
|
||||||
# self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
|
|
||||||
# if self.network.is_connected and addresses:
|
|
||||||
# addresses_remaining = list(addresses)
|
|
||||||
# while addresses_remaining:
|
|
||||||
# batch = addresses_remaining[:batch_size]
|
|
||||||
# results = await self.network.subscribe_address(*batch)
|
|
||||||
# for address, remote_status in zip(batch, results):
|
|
||||||
# self._update_tasks.add(self.update_history(address, remote_status, address_manager))
|
|
||||||
# addresses_remaining = addresses_remaining[batch_size:]
|
|
||||||
# log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
|
|
||||||
# len(addresses), *self.network.client.server_address_and_port)
|
|
||||||
# log.info(
|
|
||||||
# "finished subscribing to %i addresses on %s:%i", len(addresses),
|
|
||||||
# *self.network.client.server_address_and_port
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# def process_status_update(self, update):
|
|
||||||
# address, remote_status = update
|
|
||||||
# self._update_tasks.add(self.update_history(address, remote_status))
|
|
||||||
#
|
|
||||||
# async def update_history(self, address, remote_status, address_manager: AddressManager = None):
|
|
||||||
# async with self._address_update_locks[address]:
|
|
||||||
# self._known_addresses_out_of_sync.discard(address)
|
|
||||||
#
|
|
||||||
# local_status, local_history = await self.get_local_status_and_history(address)
|
|
||||||
#
|
|
||||||
# if local_status == remote_status:
|
|
||||||
# return True
|
|
||||||
#
|
|
||||||
# remote_history = await self.network.retriable_call(self.network.get_history, address)
|
|
||||||
# remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
|
|
||||||
# we_need = set(remote_history) - set(local_history)
|
|
||||||
# if not we_need:
|
|
||||||
# return True
|
|
||||||
#
|
|
||||||
# cache_tasks: List[asyncio.Task[Transaction]] = []
|
|
||||||
# synced_history = StringIO()
|
|
||||||
# loop = asyncio.get_running_loop()
|
|
||||||
# for i, (txid, remote_height) in enumerate(remote_history):
|
|
||||||
# if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
|
|
||||||
# synced_history.write(f'{txid}:{remote_height}:')
|
|
||||||
# else:
|
|
||||||
# check_local = (txid, remote_height) not in we_need
|
|
||||||
# cache_tasks.append(loop.create_task(
|
|
||||||
# self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
|
|
||||||
# ))
|
|
||||||
#
|
|
||||||
# synced_txs = []
|
|
||||||
# for task in cache_tasks:
|
|
||||||
# tx = await task
|
|
||||||
#
|
|
||||||
# check_db_for_txos = []
|
|
||||||
# for txi in tx.inputs:
|
|
||||||
# if txi.txo_ref.txo is not None:
|
|
||||||
# continue
|
|
||||||
# cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
|
|
||||||
# if cache_item is not None:
|
|
||||||
# if cache_item.tx is None:
|
|
||||||
# await cache_item.has_tx.wait()
|
|
||||||
# assert cache_item.tx is not None
|
|
||||||
# txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
|
|
||||||
# else:
|
|
||||||
# check_db_for_txos.append(txi.txo_ref.hash)
|
|
||||||
#
|
|
||||||
# referenced_txos = {} if not check_db_for_txos else {
|
|
||||||
# txo.id: txo for txo in await self.db.get_txos(
|
|
||||||
# txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
|
|
||||||
# )
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# for txi in tx.inputs:
|
|
||||||
# if txi.txo_ref.txo is not None:
|
|
||||||
# continue
|
|
||||||
# referenced_txo = referenced_txos.get(txi.txo_ref.id)
|
|
||||||
# if referenced_txo is not None:
|
|
||||||
# txi.txo_ref = referenced_txo.ref
|
|
||||||
#
|
|
||||||
# synced_history.write(f'{tx.id}:{tx.height}:')
|
|
||||||
# synced_txs.append(tx)
|
|
||||||
#
|
|
||||||
# await self.db.save_transaction_io_batch(
|
|
||||||
# synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue()
|
|
||||||
# )
|
|
||||||
# await asyncio.wait([
|
|
||||||
# self.ledger._on_transaction_controller.add(TransactionEvent(address, tx))
|
|
||||||
# for tx in synced_txs
|
|
||||||
# ])
|
|
||||||
#
|
|
||||||
# if address_manager is None:
|
|
||||||
# address_manager = await self.get_address_manager_for_address(address)
|
|
||||||
#
|
|
||||||
# if address_manager is not None:
|
|
||||||
# await address_manager.ensure_address_gap()
|
|
||||||
#
|
|
||||||
# local_status, local_history = \
|
|
||||||
# await self.get_local_status_and_history(address, synced_history.getvalue())
|
|
||||||
# if local_status != remote_status:
|
|
||||||
# if local_history == remote_history:
|
|
||||||
# return True
|
|
||||||
# log.warning(
|
|
||||||
# "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
|
|
||||||
# remote_status, len(remote_history), local_status, len(local_history)
|
|
||||||
# )
|
|
||||||
# log.warning("local: %s", local_history)
|
|
||||||
# log.warning("remote: %s", remote_history)
|
|
||||||
# self._known_addresses_out_of_sync.add(address)
|
|
||||||
# return False
|
|
||||||
# else:
|
|
||||||
# return True
|
|
||||||
#
|
|
||||||
# async def cache_transaction(self, tx_hash, remote_height, check_local=True):
|
|
||||||
# cache_item = self._tx_cache.get(tx_hash)
|
|
||||||
# if cache_item is None:
|
|
||||||
# cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
|
|
||||||
# elif cache_item.tx is not None and \
|
|
||||||
# cache_item.tx.height >= remote_height and \
|
|
||||||
# (cache_item.tx.is_verified or remote_height < 1):
|
|
||||||
# return cache_item.tx # cached tx is already up-to-date
|
|
||||||
#
|
|
||||||
# try:
|
|
||||||
# cache_item.pending_verifications += 1
|
|
||||||
# return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
|
|
||||||
# finally:
|
|
||||||
# cache_item.pending_verifications -= 1
|
|
||||||
#
|
|
||||||
# async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
|
|
||||||
#
|
|
||||||
# async with cache_item.lock:
|
|
||||||
#
|
|
||||||
# tx = cache_item.tx
|
|
||||||
#
|
|
||||||
# if tx is None and check_local:
|
|
||||||
# # check local db
|
|
||||||
# tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)
|
|
||||||
#
|
|
||||||
# merkle = None
|
|
||||||
# if tx is None:
|
|
||||||
# # fetch from network
|
|
||||||
# _raw, merkle = await self.network.retriable_call(
|
|
||||||
# self.network.get_transaction_and_merkle, tx_hash, remote_height
|
|
||||||
# )
|
|
||||||
# tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
|
|
||||||
# cache_item.tx = tx # make sure it's saved before caching it
|
|
||||||
# await self.maybe_verify_transaction(tx, remote_height, merkle)
|
|
||||||
# return tx
|
|
||||||
#
|
|
||||||
# async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
|
|
||||||
# tx.height = remote_height
|
|
||||||
# cached = self._tx_cache.get(tx.hash)
|
|
||||||
# if not cached:
|
|
||||||
# # cache txs looked up by transaction_show too
|
|
||||||
# cached = TransactionCacheItem()
|
|
||||||
# cached.tx = tx
|
|
||||||
# self._tx_cache[tx.hash] = cached
|
|
||||||
# if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
|
|
||||||
# # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
|
|
||||||
# if not merkle:
|
|
||||||
# merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
|
|
||||||
# merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
|
||||||
# header = await self.headers.get(remote_height)
|
|
||||||
# tx.position = merkle['pos']
|
|
||||||
# tx.is_verified = merkle_root == header['merkle_root']
|
|
||||||
#
|
|
||||||
# async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
|
|
||||||
# details = await self.db.get_address(address=address)
|
|
||||||
# for account in self.accounts:
|
|
||||||
# if account.id == details['account']:
|
|
||||||
# return account.address_managers[details['chain']]
|
|
||||||
# return None
|
|
|
@ -1,5 +1,4 @@
|
||||||
# pylint: disable=arguments-differ
|
# pylint: disable=arguments-differ
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import zlib
|
import zlib
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -10,9 +9,8 @@ from hashlib import sha256
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
|
|
||||||
from lbry.db import Database, SPENDABLE_TYPE_CODES, Result
|
from lbry.db import Database, SPENDABLE_TYPE_CODES, Result
|
||||||
from lbry.blockchain.ledger import Ledger
|
from lbry.event import EventController
|
||||||
from lbry.constants import COIN, NULL_HASH32
|
from lbry.constants import COIN, NULL_HASH32
|
||||||
from lbry.blockchain.transaction import Transaction, Input, Output
|
from lbry.blockchain.transaction import Transaction, Input, Output
|
||||||
from lbry.blockchain.dewies import dewies_to_lbc
|
from lbry.blockchain.dewies import dewies_to_lbc
|
||||||
|
@ -23,9 +21,8 @@ from lbry.schema.purchase import Purchase
|
||||||
from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError
|
from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError
|
||||||
from lbry.stream.managed_stream import ManagedStream
|
from lbry.stream.managed_stream import ManagedStream
|
||||||
|
|
||||||
from .account import Account, SingleKey, HierarchicalDeterministic
|
from .account import Account
|
||||||
from .coinselection import CoinSelector, OutputEffectiveAmountEstimator
|
from .coinselection import CoinSelector, OutputEffectiveAmountEstimator
|
||||||
from .storage import WalletStorage
|
|
||||||
from .preferences import TimestampedPreferences
|
from .preferences import TimestampedPreferences
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,20 +34,22 @@ ENCRYPT_ON_DISK = 'encrypt-on-disk'
|
||||||
class Wallet:
|
class Wallet:
|
||||||
""" The primary role of Wallet is to encapsulate a collection
|
""" The primary role of Wallet is to encapsulate a collection
|
||||||
of accounts (seed/private keys) and the spending rules / settings
|
of accounts (seed/private keys) and the spending rules / settings
|
||||||
for the coins attached to those accounts. Wallets are represented
|
for the coins attached to those accounts.
|
||||||
by physical files on the filesystem.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ledger: Ledger, db: Database, name: str, storage: WalletStorage, preferences: dict):
|
VERSION = 1
|
||||||
self.ledger = ledger
|
|
||||||
|
def __init__(self, wallet_id: str, db: Database, name: str = "", preferences: dict = None):
|
||||||
|
self.id = wallet_id
|
||||||
self.db = db
|
self.db = db
|
||||||
self.name = name
|
self.name = name
|
||||||
self.storage = storage
|
self.ledger = db.ledger
|
||||||
self.preferences = TimestampedPreferences(preferences or {})
|
self.preferences = TimestampedPreferences(preferences or {})
|
||||||
self.encryption_password: Optional[str] = None
|
self.encryption_password: Optional[str] = None
|
||||||
self.id = self.get_id()
|
|
||||||
|
|
||||||
self.utxo_lock = asyncio.Lock()
|
self.utxo_lock = asyncio.Lock()
|
||||||
|
self._on_change_controller = EventController()
|
||||||
|
self.on_change = self._on_change_controller.stream
|
||||||
|
|
||||||
self.accounts = AccountListManager(self)
|
self.accounts = AccountListManager(self)
|
||||||
self.claims = ClaimListManager(self)
|
self.claims = ClaimListManager(self)
|
||||||
|
@ -60,61 +59,55 @@ class Wallet:
|
||||||
self.purchases = PurchaseListManager(self)
|
self.purchases = PurchaseListManager(self)
|
||||||
self.supports = SupportListManager(self)
|
self.supports = SupportListManager(self)
|
||||||
|
|
||||||
def get_id(self):
|
async def notify_change(self, field: str, value=None):
|
||||||
return os.path.basename(self.storage.path) if self.storage.path else self.name
|
await self._on_change_controller.add({
|
||||||
|
'field': field, 'value': value
|
||||||
|
})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(
|
async def from_dict(cls, wallet_id: str, wallet_dict, db: Database) -> 'Wallet':
|
||||||
cls, ledger: Ledger, db: Database, path: str, name: str,
|
if 'ledger' in wallet_dict and wallet_dict['ledger'] != db.ledger.get_id():
|
||||||
create_account=False, language='en', single_key=False):
|
|
||||||
wallet = cls(ledger, db, name, WalletStorage(path), {})
|
|
||||||
if create_account:
|
|
||||||
await wallet.accounts.generate(language=language, address_generator={
|
|
||||||
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
|
|
||||||
})
|
|
||||||
await wallet.save()
|
|
||||||
return wallet
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_path(cls, ledger: Ledger, db: Database, path: str):
|
|
||||||
return await cls.from_storage(ledger, db, WalletStorage(path))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_storage(cls, ledger: Ledger, db: Database, storage: WalletStorage) -> 'Wallet':
|
|
||||||
json_dict = await storage.read()
|
|
||||||
if 'ledger' in json_dict and json_dict['ledger'] != ledger.get_id():
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Using ledger {ledger.get_id()} but wallet is {json_dict['ledger']}."
|
f"Using ledger {db.ledger.get_id()} but wallet is {wallet_dict['ledger']}."
|
||||||
)
|
)
|
||||||
|
version = wallet_dict.get('version')
|
||||||
|
if version == 1:
|
||||||
|
pass
|
||||||
wallet = cls(
|
wallet = cls(
|
||||||
ledger, db,
|
wallet_id, db,
|
||||||
name=json_dict.get('name', 'Wallet'),
|
name=wallet_dict.get('name', 'Wallet'),
|
||||||
storage=storage,
|
preferences=wallet_dict.get('preferences', {}),
|
||||||
preferences=json_dict.get('preferences', {}),
|
|
||||||
)
|
)
|
||||||
for account_dict in json_dict.get('accounts', []):
|
for account_dict in wallet_dict.get('accounts', []):
|
||||||
await wallet.accounts.add_from_dict(account_dict)
|
await wallet.accounts.add_from_dict(account_dict)
|
||||||
return wallet
|
return wallet
|
||||||
|
|
||||||
def to_dict(self, encrypt_password: str = None):
|
def to_dict(self, encrypt_password: str = None) -> dict:
|
||||||
return {
|
return {
|
||||||
'version': WalletStorage.VERSION,
|
'version': self.VERSION,
|
||||||
'ledger': self.ledger.get_id(),
|
'ledger': self.ledger.get_id(),
|
||||||
'name': self.name,
|
'name': self.name,
|
||||||
'preferences': self.preferences.data,
|
'preferences': self.preferences.data,
|
||||||
'accounts': [a.to_dict(encrypt_password) for a in self.accounts]
|
'accounts': [a.to_dict(encrypt_password) for a in self.accounts]
|
||||||
}
|
}
|
||||||
|
|
||||||
async def save(self):
|
@classmethod
|
||||||
|
async def from_serialized(cls, wallet_id: str, json_data: str, db: Database) -> 'Wallet':
|
||||||
|
return await cls.from_dict(wallet_id, json.loads(json_data), db)
|
||||||
|
|
||||||
|
def to_serialized(self) -> str:
|
||||||
|
wallet_dict = None
|
||||||
if self.preferences.get(ENCRYPT_ON_DISK, False):
|
if self.preferences.get(ENCRYPT_ON_DISK, False):
|
||||||
if self.encryption_password is not None:
|
if self.encryption_password is not None:
|
||||||
return await self.storage.write(self.to_dict(encrypt_password=self.encryption_password))
|
wallet_dict = self.to_dict(encrypt_password=self.encryption_password)
|
||||||
elif not self.is_locked:
|
elif not self.is_locked:
|
||||||
log.warning(
|
log.warning(
|
||||||
"Disk encryption requested but no password available for encryption. "
|
"Disk encryption requested but no password available for encryption. "
|
||||||
"Saving wallet in an unencrypted state."
|
"Saving wallet in an unencrypted state."
|
||||||
)
|
)
|
||||||
return await self.storage.write(self.to_dict())
|
if wallet_dict is None:
|
||||||
|
wallet_dict = self.to_dict()
|
||||||
|
return json.dumps(wallet_dict, indent=4, sort_keys=True)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hash(self) -> bytes:
|
def hash(self) -> bytes:
|
||||||
|
@ -157,8 +150,9 @@ class Wallet:
|
||||||
local_match.merge(account_dict)
|
local_match.merge(account_dict)
|
||||||
else:
|
else:
|
||||||
added_accounts.append(
|
added_accounts.append(
|
||||||
self.accounts.add_from_dict(account_dict)
|
await self.accounts.add_from_dict(account_dict, notify=False)
|
||||||
)
|
)
|
||||||
|
await self.notify_change('wallet.merge')
|
||||||
return added_accounts
|
return added_accounts
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -190,7 +184,7 @@ class Wallet:
|
||||||
async def decrypt(self):
|
async def decrypt(self):
|
||||||
assert not self.is_locked, "Cannot decrypt a locked wallet, unlock first."
|
assert not self.is_locked, "Cannot decrypt a locked wallet, unlock first."
|
||||||
self.preferences[ENCRYPT_ON_DISK] = False
|
self.preferences[ENCRYPT_ON_DISK] = False
|
||||||
await self.save()
|
await self.notify_change(ENCRYPT_ON_DISK, False)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def encrypt(self, password):
|
async def encrypt(self, password):
|
||||||
|
@ -198,7 +192,7 @@ class Wallet:
|
||||||
assert password, "Cannot encrypt with blank password."
|
assert password, "Cannot encrypt with blank password."
|
||||||
self.encryption_password = password
|
self.encryption_password = password
|
||||||
self.preferences[ENCRYPT_ON_DISK] = True
|
self.preferences[ENCRYPT_ON_DISK] = True
|
||||||
await self.save()
|
await self.notify_change(ENCRYPT_ON_DISK, True)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -237,7 +231,7 @@ class Wallet:
|
||||||
if await account.save_max_gap():
|
if await account.save_max_gap():
|
||||||
gap_changed = True
|
gap_changed = True
|
||||||
if gap_changed:
|
if gap_changed:
|
||||||
await self.save()
|
await self.notify_change('address-max-gap')
|
||||||
|
|
||||||
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
|
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
|
||||||
estimators = []
|
estimators = []
|
||||||
|
@ -379,30 +373,6 @@ class Wallet:
|
||||||
|
|
||||||
return tx
|
return tx
|
||||||
|
|
||||||
async def _report_state(self):
|
|
||||||
try:
|
|
||||||
for account in self.accounts:
|
|
||||||
balance = dewies_to_lbc(await account.get_balance(include_claims=True))
|
|
||||||
_, channel_count = await account.get_channels(limit=1)
|
|
||||||
claim_count = await account.get_claim_count()
|
|
||||||
if isinstance(account.receiving, SingleKey):
|
|
||||||
log.info("Loaded single key account %s with %s LBC. "
|
|
||||||
"%d channels, %d certificates and %d claims",
|
|
||||||
account.id, balance, channel_count, len(account.channel_keys), claim_count)
|
|
||||||
else:
|
|
||||||
total_receiving = len(await account.receiving.get_addresses())
|
|
||||||
total_change = len(await account.change.get_addresses())
|
|
||||||
log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), "
|
|
||||||
"%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ",
|
|
||||||
account.id, balance, total_receiving, account.receiving.gap, total_change,
|
|
||||||
account.change.gap, channel_count, len(account.channel_keys), claim_count)
|
|
||||||
except Exception as err:
|
|
||||||
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
|
|
||||||
raise
|
|
||||||
log.exception(
|
|
||||||
'Failed to display wallet state, please file issue '
|
|
||||||
'for this bug along with the traceback you see below:')
|
|
||||||
|
|
||||||
async def verify_duplicate(self, name: str, allow_duplicate: bool):
|
async def verify_duplicate(self, name: str, allow_duplicate: bool):
|
||||||
if not allow_duplicate:
|
if not allow_duplicate:
|
||||||
claims = await self.claims.list(claim_name=name)
|
claims = await self.claims.list(claim_name=name)
|
||||||
|
@ -438,21 +408,22 @@ class AccountListManager:
|
||||||
return account
|
return account
|
||||||
|
|
||||||
async def generate(self, name: str = None, language: str = 'en', address_generator: dict = None) -> Account:
|
async def generate(self, name: str = None, language: str = 'en', address_generator: dict = None) -> Account:
|
||||||
account = await Account.generate(
|
account = await Account.generate(self.wallet.db, name, language, address_generator)
|
||||||
self.wallet.ledger, self.wallet.db, name, language, address_generator
|
|
||||||
)
|
|
||||||
self._accounts.append(account)
|
self._accounts.append(account)
|
||||||
|
await self.wallet.notify_change('account.added')
|
||||||
return account
|
return account
|
||||||
|
|
||||||
async def add_from_dict(self, account_dict: dict) -> Account:
|
async def add_from_dict(self, account_dict: dict, notify=True) -> Account:
|
||||||
account = await Account.from_dict(self.wallet.ledger, self.wallet.db, account_dict)
|
account = await Account.from_dict(self.wallet.db, account_dict)
|
||||||
self._accounts.append(account)
|
self._accounts.append(account)
|
||||||
|
if notify:
|
||||||
|
await self.wallet.notify_change('account.added')
|
||||||
return account
|
return account
|
||||||
|
|
||||||
async def remove(self, account_id: str) -> Account:
|
async def remove(self, account_id: str) -> Account:
|
||||||
account = self[account_id]
|
account = self[account_id]
|
||||||
self._accounts.remove(account)
|
self._accounts.remove(account)
|
||||||
await self.wallet.save()
|
await self.wallet.notify_change('account.removed')
|
||||||
return account
|
return account
|
||||||
|
|
||||||
def get_or_none(self, account_id: str) -> Optional[Account]:
|
def get_or_none(self, account_id: str) -> Optional[Account]:
|
||||||
|
@ -608,7 +579,7 @@ class ChannelListManager(ClaimListManager):
|
||||||
|
|
||||||
if save_key:
|
if save_key:
|
||||||
holding_account.add_channel_private_key(txo.private_key)
|
holding_account.add_channel_private_key(txo.private_key)
|
||||||
await self.wallet.save()
|
await self.wallet.notify_change('channel.added')
|
||||||
|
|
||||||
return tx
|
return tx
|
||||||
|
|
||||||
|
@ -652,7 +623,7 @@ class ChannelListManager(ClaimListManager):
|
||||||
|
|
||||||
if any((new_signing_key, moving_accounts)) and save_key:
|
if any((new_signing_key, moving_accounts)) and save_key:
|
||||||
holding_account.add_channel_private_key(txo.private_key)
|
holding_account.add_channel_private_key(txo.private_key)
|
||||||
await self.wallet.save()
|
await self.wallet.notify_change('channel.added')
|
||||||
|
|
||||||
return tx
|
return tx
|
||||||
|
|
||||||
|
|
|
@ -1,2 +0,0 @@
|
||||||
from unittest import SkipTest
|
|
||||||
raise SkipTest("WIP")
|
|
|
@ -13,7 +13,7 @@ class AccountTestCase(AsyncioTestCase):
|
||||||
self.addCleanup(self.db.close)
|
self.addCleanup(self.db.close)
|
||||||
|
|
||||||
async def update_addressed_used(self, address, used):
|
async def update_addressed_used(self, address, used):
|
||||||
await self.db.execute(
|
await self.db.execute_sql_object(
|
||||||
tables.PubkeyAddress.update()
|
tables.PubkeyAddress.update()
|
||||||
.where(tables.PubkeyAddress.c.address == address)
|
.where(tables.PubkeyAddress.c.address == address)
|
||||||
.values(used_times=used)
|
.values(used_times=used)
|
||||||
|
@ -23,7 +23,7 @@ class AccountTestCase(AsyncioTestCase):
|
||||||
class TestHierarchicalDeterministicAccount(AccountTestCase):
|
class TestHierarchicalDeterministicAccount(AccountTestCase):
|
||||||
|
|
||||||
async def test_generate_account(self):
|
async def test_generate_account(self):
|
||||||
account = await Account.generate(self.ledger, self.db)
|
account = await Account.generate(self.db)
|
||||||
self.assertEqual(account.ledger, self.ledger)
|
self.assertEqual(account.ledger, self.ledger)
|
||||||
self.assertEqual(account.db, self.db)
|
self.assertEqual(account.db, self.db)
|
||||||
self.assertEqual(account.name, f'Account #{account.public_key.address}')
|
self.assertEqual(account.name, f'Account #{account.public_key.address}')
|
||||||
|
@ -36,18 +36,19 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
|
||||||
self.assertIsInstance(account.receiving, HierarchicalDeterministic)
|
self.assertIsInstance(account.receiving, HierarchicalDeterministic)
|
||||||
self.assertIsInstance(account.change, HierarchicalDeterministic)
|
self.assertIsInstance(account.change, HierarchicalDeterministic)
|
||||||
|
|
||||||
async def test_ensure_address_gap(self):
|
|
||||||
account = await Account.generate(self.ledger, self.db, 'lbryum')
|
|
||||||
self.assertEqual(len(await account.receiving.get_addresses()), 0)
|
self.assertEqual(len(await account.receiving.get_addresses()), 0)
|
||||||
self.assertEqual(len(await account.change.get_addresses()), 0)
|
self.assertEqual(len(await account.change.get_addresses()), 0)
|
||||||
await account.ensure_address_gap()
|
await account.ensure_address_gap()
|
||||||
self.assertEqual(len(await account.receiving.get_addresses()), 20)
|
self.assertEqual(len(await account.receiving.get_addresses()), 20)
|
||||||
self.assertEqual(len(await account.change.get_addresses()), 6)
|
self.assertEqual(len(await account.change.get_addresses()), 6)
|
||||||
|
|
||||||
|
async def test_ensure_address_gap(self):
|
||||||
|
account = await Account.generate(self.db)
|
||||||
async with account.receiving.address_generator_lock:
|
async with account.receiving.address_generator_lock:
|
||||||
await account.receiving._generate_keys(4, 7)
|
await account.receiving._generate_keys(4, 7)
|
||||||
await account.receiving._generate_keys(0, 3)
|
await account.receiving._generate_keys(0, 3)
|
||||||
await account.receiving._generate_keys(8, 11)
|
await account.receiving._generate_keys(8, 11)
|
||||||
|
|
||||||
records = await account.receiving.get_address_records()
|
records = await account.receiving.get_address_records()
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[r['pubkey'].n for r in records],
|
[r['pubkey'].n for r in records],
|
||||||
|
@ -79,14 +80,14 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
|
||||||
self.assertEqual(len(new_keys), 20)
|
self.assertEqual(len(new_keys), 20)
|
||||||
|
|
||||||
async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
|
async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
|
||||||
account = Account.generate(self.ledger, self.db, 'lbryum')
|
account = await Account.generate(self.db)
|
||||||
async with account.receiving.address_generator_lock:
|
async with account.receiving.address_generator_lock:
|
||||||
await account.receiving._generate_keys(0, 200)
|
await account.receiving._generate_keys(0, 200)
|
||||||
records = await account.receiving.get_address_records()
|
records = await account.receiving.get_address_records()
|
||||||
self.assertEqual(len(records), 201)
|
self.assertEqual(len(records), 201)
|
||||||
|
|
||||||
async def test_get_or_create_usable_address(self):
|
async def test_get_or_create_usable_address(self):
|
||||||
account = Account.generate(self.ledger, self.db, 'lbryum')
|
account = await Account.generate(self.db)
|
||||||
|
|
||||||
keys = await account.receiving.get_addresses()
|
keys = await account.receiving.get_addresses()
|
||||||
self.assertEqual(len(keys), 0)
|
self.assertEqual(len(keys), 0)
|
||||||
|
@ -98,13 +99,11 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
|
||||||
self.assertEqual(len(keys), 20)
|
self.assertEqual(len(keys), 20)
|
||||||
|
|
||||||
async def test_generate_account_from_seed(self):
|
async def test_generate_account_from_seed(self):
|
||||||
account = await Account.from_dict(
|
account = await Account.from_dict(self.db, {
|
||||||
self.ledger, self.db, {
|
"seed":
|
||||||
"seed":
|
"carbon smart garage balance margin twelve chest sword toas"
|
||||||
"carbon smart garage balance margin twelve chest sword toas"
|
"t envelope bottom stomach absent"
|
||||||
"t envelope bottom stomach absent"
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
account.private_key.extended_key_string(),
|
account.private_key.extended_key_string(),
|
||||||
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7DRNLEoB8'
|
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7DRNLEoB8'
|
||||||
|
@ -126,6 +125,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
|
||||||
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
||||||
"h absent",
|
"h absent",
|
||||||
'encrypted': False,
|
'encrypted': False,
|
||||||
|
'lang': 'en',
|
||||||
'private_key':
|
'private_key':
|
||||||
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7DRNLEoB8'
|
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7DRNLEoB8'
|
||||||
'HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
|
'HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
|
||||||
|
@ -140,7 +140,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
account = Account.from_dict(self.ledger, self.db, account_data)
|
account = await Account.from_dict(self.db, account_data)
|
||||||
|
|
||||||
await account.ensure_address_gap()
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
|
||||||
self.assertEqual(len(addresses), 10)
|
self.assertEqual(len(addresses), 10)
|
||||||
self.assertDictEqual(account_data, account.to_dict())
|
self.assertDictEqual(account_data, account.to_dict())
|
||||||
|
|
||||||
def test_merge_diff(self):
|
async def test_merge_diff(self):
|
||||||
account_data = {
|
account_data = {
|
||||||
'name': 'My Account',
|
'name': 'My Account',
|
||||||
'modified_on': 123.456,
|
'modified_on': 123.456,
|
||||||
|
@ -158,6 +158,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
|
||||||
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
||||||
"h absent",
|
"h absent",
|
||||||
'encrypted': False,
|
'encrypted': False,
|
||||||
|
'lang': 'en',
|
||||||
'private_key':
|
'private_key':
|
||||||
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
||||||
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||||
|
@ -170,7 +171,7 @@ class TestHierarchicalDeterministicAccount(AccountTestCase):
|
||||||
'change': {'gap': 5, 'maximum_uses_per_address': 2}
|
'change': {'gap': 5, 'maximum_uses_per_address': 2}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
account = Account.from_dict(self.ledger, self.db, account_data)
|
account = await Account.from_dict(self.db, account_data)
|
||||||
|
|
||||||
self.assertEqual(account.name, 'My Account')
|
self.assertEqual(account.name, 'My Account')
|
||||||
self.assertEqual(account.modified_on, 123.456)
|
self.assertEqual(account.modified_on, 123.456)
|
||||||
|
@ -203,15 +204,15 @@ class TestSingleKeyAccount(AccountTestCase):
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
await super().asyncSetUp()
|
await super().asyncSetUp()
|
||||||
self.account = Account.generate(
|
self.account = await Account.generate(
|
||||||
self.ledger, self.db, "torba", {'name': 'single-address'}
|
self.db, address_generator={"name": "single-address"}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_generate_account(self):
|
async def test_generate_account(self):
|
||||||
account = self.account
|
account = self.account
|
||||||
|
|
||||||
self.assertEqual(account.ledger, self.ledger)
|
self.assertEqual(account.ledger, self.ledger)
|
||||||
self.assertIsNotNone(account.seed)
|
self.assertIsNotNone(account.phrase)
|
||||||
self.assertEqual(account.public_key.ledger, self.ledger)
|
self.assertEqual(account.public_key.ledger, self.ledger)
|
||||||
self.assertEqual(account.private_key.public_key, account.public_key)
|
self.assertEqual(account.private_key.public_key, account.public_key)
|
||||||
|
|
||||||
|
@ -246,7 +247,7 @@ class TestSingleKeyAccount(AccountTestCase):
|
||||||
self.assertEqual(new_keys[0], account.public_key.address)
|
self.assertEqual(new_keys[0], account.public_key.address)
|
||||||
records = await account.receiving.get_address_records()
|
records = await account.receiving.get_address_records()
|
||||||
pubkey = records[0].pop('pubkey')
|
pubkey = records[0].pop('pubkey')
|
||||||
self.assertListEqual(records, [{
|
self.assertEqual(records.rows, [{
|
||||||
'chain': 0,
|
'chain': 0,
|
||||||
'account': account.public_key.address,
|
'account': account.public_key.address,
|
||||||
'address': account.public_key.address,
|
'address': account.public_key.address,
|
||||||
|
@ -294,6 +295,7 @@ class TestSingleKeyAccount(AccountTestCase):
|
||||||
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
||||||
"h absent",
|
"h absent",
|
||||||
'encrypted': False,
|
'encrypted': False,
|
||||||
|
'lang': 'en',
|
||||||
'private_key': 'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7'
|
'private_key': 'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7'
|
||||||
'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
|
'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
|
||||||
'public_key': 'xpub661MyMwAqRbcGWtPvbWh9sc2BCfw2cTeVDYF23o3N1t6UZ5wv3EM'
|
'public_key': 'xpub661MyMwAqRbcGWtPvbWh9sc2BCfw2cTeVDYF23o3N1t6UZ5wv3EM'
|
||||||
|
@ -302,7 +304,7 @@ class TestSingleKeyAccount(AccountTestCase):
|
||||||
'certificates': {}
|
'certificates': {}
|
||||||
}
|
}
|
||||||
|
|
||||||
account = Account.from_dict(self.ledger, self.db, account_data)
|
account = await Account.from_dict(self.db, account_data)
|
||||||
|
|
||||||
await account.ensure_address_gap()
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
|
@ -351,9 +353,9 @@ class AccountEncryptionTests(AccountTestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def test_encrypt_wallet(self):
|
async def test_encrypt_wallet(self):
|
||||||
account = await Account.from_dict(self.ledger, self.db, self.unencrypted_account)
|
account = await Account.from_dict(self.db, self.unencrypted_account)
|
||||||
account.init_vectors = {
|
account.init_vectors = {
|
||||||
'seed': self.init_vector,
|
'phrase': self.init_vector,
|
||||||
'private_key': self.init_vector
|
'private_key': self.init_vector
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -361,7 +363,7 @@ class AccountEncryptionTests(AccountTestCase):
|
||||||
self.assertIsNotNone(account.private_key)
|
self.assertIsNotNone(account.private_key)
|
||||||
account.encrypt(self.password)
|
account.encrypt(self.password)
|
||||||
self.assertTrue(account.encrypted)
|
self.assertTrue(account.encrypted)
|
||||||
self.assertEqual(account.seed, self.encrypted_account['seed'])
|
self.assertEqual(account.phrase, self.encrypted_account['seed'])
|
||||||
self.assertEqual(account.private_key_string, self.encrypted_account['private_key'])
|
self.assertEqual(account.private_key_string, self.encrypted_account['private_key'])
|
||||||
self.assertIsNone(account.private_key)
|
self.assertIsNone(account.private_key)
|
||||||
|
|
||||||
|
@ -370,9 +372,9 @@ class AccountEncryptionTests(AccountTestCase):
|
||||||
|
|
||||||
account.decrypt(self.password)
|
account.decrypt(self.password)
|
||||||
self.assertEqual(account.init_vectors['private_key'], self.init_vector)
|
self.assertEqual(account.init_vectors['private_key'], self.init_vector)
|
||||||
self.assertEqual(account.init_vectors['seed'], self.init_vector)
|
self.assertEqual(account.init_vectors['phrase'], self.init_vector)
|
||||||
|
|
||||||
self.assertEqual(account.seed, self.unencrypted_account['seed'])
|
self.assertEqual(account.phrase, self.unencrypted_account['seed'])
|
||||||
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
|
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
|
||||||
|
|
||||||
self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed'])
|
self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed'])
|
||||||
|
@ -381,16 +383,16 @@ class AccountEncryptionTests(AccountTestCase):
|
||||||
self.assertFalse(account.encrypted)
|
self.assertFalse(account.encrypted)
|
||||||
|
|
||||||
async def test_decrypt_wallet(self):
|
async def test_decrypt_wallet(self):
|
||||||
account = await Account.from_dict(self.ledger, self.db, self.encrypted_account)
|
account = await Account.from_dict(self.db, self.encrypted_account)
|
||||||
|
|
||||||
self.assertTrue(account.encrypted)
|
self.assertTrue(account.encrypted)
|
||||||
account.decrypt(self.password)
|
account.decrypt(self.password)
|
||||||
self.assertEqual(account.init_vectors['private_key'], self.init_vector)
|
self.assertEqual(account.init_vectors['private_key'], self.init_vector)
|
||||||
self.assertEqual(account.init_vectors['seed'], self.init_vector)
|
self.assertEqual(account.init_vectors['phrase'], self.init_vector)
|
||||||
|
|
||||||
self.assertFalse(account.encrypted)
|
self.assertFalse(account.encrypted)
|
||||||
|
|
||||||
self.assertEqual(account.seed, self.unencrypted_account['seed'])
|
self.assertEqual(account.phrase, self.unencrypted_account['seed'])
|
||||||
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
|
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
|
||||||
|
|
||||||
self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed'])
|
self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed'])
|
||||||
|
@ -402,7 +404,7 @@ class AccountEncryptionTests(AccountTestCase):
|
||||||
account_data = self.unencrypted_account.copy()
|
account_data = self.unencrypted_account.copy()
|
||||||
del account_data['seed']
|
del account_data['seed']
|
||||||
del account_data['private_key']
|
del account_data['private_key']
|
||||||
account = await Account.from_dict(self.ledger, self.db, account_data)
|
account = await Account.from_dict(self.db, account_data)
|
||||||
encrypted = account.to_dict('password')
|
encrypted = account.to_dict('password')
|
||||||
self.assertFalse(encrypted['seed'])
|
self.assertFalse(encrypted['seed'])
|
||||||
self.assertFalse(encrypted['private_key'])
|
self.assertFalse(encrypted['private_key'])
|
||||||
|
|
|
@ -1,566 +0,0 @@
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import unittest
|
|
||||||
import sqlite3
|
|
||||||
import tempfile
|
|
||||||
import asyncio
|
|
||||||
from concurrent.futures.thread import ThreadPoolExecutor
|
|
||||||
|
|
||||||
from sqlalchemy import Column, Text
|
|
||||||
|
|
||||||
from lbry.wallet import (
|
|
||||||
Wallet, Account, Ledger, Headers, Transaction, Input
|
|
||||||
)
|
|
||||||
from lbry.db import Table, Version, Database, metadata
|
|
||||||
from lbry.wallet.constants import COIN
|
|
||||||
from lbry.wallet.database import query, interpolate, constraints_to_sql, AIOSQLite
|
|
||||||
from lbry.crypto.hash import sha256
|
|
||||||
from lbry.testcase import AsyncioTestCase
|
|
||||||
|
|
||||||
from tests.unit.wallet.test_transaction import get_output, NULL_HASH
|
|
||||||
|
|
||||||
|
|
||||||
class TestAIOSQLite(AsyncioTestCase):
|
|
||||||
async def asyncSetUp(self):
|
|
||||||
self.db = await AIOSQLite.connect(':memory:')
|
|
||||||
await self.db.executescript("""
|
|
||||||
pragma foreign_keys=on;
|
|
||||||
create table parent (id integer primary key, name);
|
|
||||||
create table child (id integer primary key, parent_id references parent);
|
|
||||||
""")
|
|
||||||
await self.db.execute("insert into parent values (1, 'test')")
|
|
||||||
await self.db.execute("insert into child values (2, 1)")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def delete_item(transaction):
|
|
||||||
transaction.execute('delete from parent where id=1')
|
|
||||||
|
|
||||||
async def test_foreign_keys_integrity_error(self):
|
|
||||||
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
|
|
||||||
|
|
||||||
with self.assertRaises(sqlite3.IntegrityError):
|
|
||||||
await self.db.run(self.delete_item)
|
|
||||||
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
|
|
||||||
|
|
||||||
await self.db.executescript("pragma foreign_keys=off;")
|
|
||||||
|
|
||||||
await self.db.run(self.delete_item)
|
|
||||||
self.assertListEqual([], await self.db.execute_fetchall("select * from parent"))
|
|
||||||
|
|
||||||
async def test_run_without_foreign_keys(self):
|
|
||||||
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
|
|
||||||
await self.db.run_with_foreign_keys_disabled(self.delete_item)
|
|
||||||
self.assertListEqual([], await self.db.execute_fetchall("select * from parent"))
|
|
||||||
|
|
||||||
async def test_integrity_error_when_foreign_keys_disabled_and_skipped(self):
|
|
||||||
await self.db.executescript("pragma foreign_keys=off;")
|
|
||||||
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
|
|
||||||
with self.assertRaises(sqlite3.IntegrityError):
|
|
||||||
await self.db.run_with_foreign_keys_disabled(self.delete_item)
|
|
||||||
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
|
|
||||||
|
|
||||||
|
|
||||||
class TestQueryBuilder(unittest.TestCase):
|
|
||||||
|
|
||||||
def test_dot(self):
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({'txo.position': 18}),
|
|
||||||
('txo.position = :txo_position0', {'txo_position0': 18})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({'txo.position#6': 18}),
|
|
||||||
('txo.position = :txo_position6', {'txo_position6': 18})
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_any(self):
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({
|
|
||||||
'ages__any': {
|
|
||||||
'txo.age__gt': 18,
|
|
||||||
'txo.age__lt': 38
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
('(txo.age > :ages__any0_txo_age__gt0 OR txo.age < :ages__any0_txo_age__lt0)', {
|
|
||||||
'ages__any0_txo_age__gt0': 18,
|
|
||||||
'ages__any0_txo_age__lt0': 38
|
|
||||||
})
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_in(self):
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({'txo.age__in#2': [18, 38]}),
|
|
||||||
('txo.age IN (:txo_age__in2_0, :txo_age__in2_1)', {
|
|
||||||
'txo_age__in2_0': 18,
|
|
||||||
'txo_age__in2_1': 38
|
|
||||||
})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({'txo.name__in': ('abc123', 'def456')}),
|
|
||||||
('txo.name IN (:txo_name__in0_0, :txo_name__in0_1)', {
|
|
||||||
'txo_name__in0_0': 'abc123',
|
|
||||||
'txo_name__in0_1': 'def456'
|
|
||||||
})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({'txo.name__in': {'abc123'}}),
|
|
||||||
('txo.name = :txo_name__in0', {
|
|
||||||
'txo_name__in0': 'abc123',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}),
|
|
||||||
('txo.age IN (SELECT age from ages_table)', {})
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_not_in(self):
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({'txo.age__not_in': [18, 38]}),
|
|
||||||
('txo.age NOT IN (:txo_age__not_in0_0, :txo_age__not_in0_1)', {
|
|
||||||
'txo_age__not_in0_0': 18,
|
|
||||||
'txo_age__not_in0_1': 38
|
|
||||||
})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({'txo.name__not_in': ('abc123', 'def456')}),
|
|
||||||
('txo.name NOT IN (:txo_name__not_in0_0, :txo_name__not_in0_1)', {
|
|
||||||
'txo_name__not_in0_0': 'abc123',
|
|
||||||
'txo_name__not_in0_1': 'def456'
|
|
||||||
})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({'txo.name__not_in': ('abc123',)}),
|
|
||||||
('txo.name != :txo_name__not_in0', {
|
|
||||||
'txo_name__not_in0': 'abc123',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}),
|
|
||||||
('txo.age NOT IN (SELECT age from ages_table)', {})
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_in_invalid(self):
|
|
||||||
with self.assertRaisesRegex(ValueError, 'list, set or string'):
|
|
||||||
constraints_to_sql({'ages__in': 9})
|
|
||||||
|
|
||||||
def test_query(self):
|
|
||||||
self.assertTupleEqual(
|
|
||||||
query("select * from foo"),
|
|
||||||
("select * from foo", {})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
query(
|
|
||||||
"select * from foo",
|
|
||||||
a__not='b', b__in='select * from blah where c=:$c',
|
|
||||||
d__any={'one__like': 'o', 'two': 2}, limit=10, order_by='b', **{'$c': 3}),
|
|
||||||
(
|
|
||||||
"select * from foo WHERE a != :a__not0 AND "
|
|
||||||
"b IN (select * from blah where c=:$c) AND "
|
|
||||||
"(one LIKE :d__any0_one__like0 OR two = :d__any0_two0) ORDER BY b LIMIT 10",
|
|
||||||
{'a__not0': 'b', 'd__any0_one__like0': 'o', 'd__any0_two0': 2, '$c': 3}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_query_order_by(self):
|
|
||||||
self.assertTupleEqual(
|
|
||||||
query("select * from foo", order_by='foo'),
|
|
||||||
("select * from foo ORDER BY foo", {})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
query("select * from foo", order_by=['foo', 'bar']),
|
|
||||||
("select * from foo ORDER BY foo, bar", {})
|
|
||||||
)
|
|
||||||
with self.assertRaisesRegex(ValueError, 'order_by must be string or list'):
|
|
||||||
query("select * from foo", order_by={'foo': 'bar'})
|
|
||||||
|
|
||||||
def test_query_limit_offset(self):
|
|
||||||
self.assertTupleEqual(
|
|
||||||
query("select * from foo", limit=10),
|
|
||||||
("select * from foo LIMIT 10", {})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
query("select * from foo", offset=10),
|
|
||||||
("select * from foo OFFSET 10", {})
|
|
||||||
)
|
|
||||||
self.assertTupleEqual(
|
|
||||||
query("select * from foo", limit=20, offset=10),
|
|
||||||
("select * from foo LIMIT 20 OFFSET 10", {})
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_query_interpolation(self):
|
|
||||||
self.maxDiff = None
|
|
||||||
# tests that interpolation replaces longer keys first
|
|
||||||
self.assertEqual(
|
|
||||||
interpolate(*query(
|
|
||||||
"select * from foo",
|
|
||||||
a__not='b', b__in='select * from blah where c=:$c',
|
|
||||||
d__any={'one__like': 'o', 'two': 2},
|
|
||||||
a0=3, a00=1, a00a=2, a00aa=4, # <-- breaks without correct interpolation key order
|
|
||||||
ahash=sha256(b'hello world'),
|
|
||||||
limit=10, order_by='b', **{'$c': 3})
|
|
||||||
),
|
|
||||||
"select * from foo WHERE a != 'b' AND "
|
|
||||||
"b IN (select * from blah where c=3) AND "
|
|
||||||
"(one LIKE 'o' OR two = 2) AND "
|
|
||||||
"a0 = 3 AND a00 = 1 AND a00a = 2 AND a00aa = 4 "
|
|
||||||
"AND ahash = X'b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9' "
|
|
||||||
"ORDER BY b LIMIT 10"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestQueries(AsyncioTestCase):
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
|
||||||
self.ledger = Ledger({
|
|
||||||
'db': Database('sqlite:///:memory:'),
|
|
||||||
'headers': Headers(':memory:')
|
|
||||||
})
|
|
||||||
self.wallet = Wallet()
|
|
||||||
await self.ledger.db.open()
|
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
|
||||||
await self.ledger.db.close()
|
|
||||||
|
|
||||||
async def create_account(self, wallet=None):
|
|
||||||
account = Account.generate(self.ledger, wallet or self.wallet)
|
|
||||||
await account.ensure_address_gap()
|
|
||||||
return account
|
|
||||||
|
|
||||||
async def create_tx_from_nothing(self, my_account, height):
|
|
||||||
to_address = await my_account.receiving.get_or_create_usable_address()
|
|
||||||
to_hash = Ledger.address_to_hash160(to_address)
|
|
||||||
tx = Transaction(height=height, is_verified=True) \
|
|
||||||
.add_inputs([self.txi(self.txo(1, sha256(str(height).encode())))]) \
|
|
||||||
.add_outputs([self.txo(1, to_hash)])
|
|
||||||
await self.ledger.db.insert_transaction(tx)
|
|
||||||
await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
|
|
||||||
return tx
|
|
||||||
|
|
||||||
async def create_tx_from_txo(self, txo, to_account, height):
|
|
||||||
from_hash = txo.script.values['pubkey_hash']
|
|
||||||
from_address = self.ledger.pubkey_hash_to_address(from_hash)
|
|
||||||
to_address = await to_account.receiving.get_or_create_usable_address()
|
|
||||||
to_hash = Ledger.address_to_hash160(to_address)
|
|
||||||
tx = Transaction(height=height, is_verified=True) \
|
|
||||||
.add_inputs([self.txi(txo)]) \
|
|
||||||
.add_outputs([self.txo(1, to_hash)])
|
|
||||||
await self.ledger.db.insert_transaction(tx)
|
|
||||||
await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
|
|
||||||
await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
|
|
||||||
return tx
|
|
||||||
|
|
||||||
async def create_tx_to_nowhere(self, txo, height):
|
|
||||||
from_hash = txo.script.values['pubkey_hash']
|
|
||||||
from_address = self.ledger.pubkey_hash_to_address(from_hash)
|
|
||||||
to_hash = NULL_HASH
|
|
||||||
tx = Transaction(height=height, is_verified=True) \
|
|
||||||
.add_inputs([self.txi(txo)]) \
|
|
||||||
.add_outputs([self.txo(1, to_hash)])
|
|
||||||
await self.ledger.db.insert_transaction(tx)
|
|
||||||
await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
|
|
||||||
return tx
|
|
||||||
|
|
||||||
def txo(self, amount, address):
|
|
||||||
return get_output(int(amount*COIN), address)
|
|
||||||
|
|
||||||
def txi(self, txo):
|
|
||||||
return Input.spend(txo)
|
|
||||||
|
|
||||||
async def test_large_tx_doesnt_hit_variable_limits(self):
|
|
||||||
# SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html
|
|
||||||
# This can be removed when there is a better way. See: https://github.com/lbryio/lbry-sdk/issues/2281
|
|
||||||
fetchall = self.ledger.db.execute_fetchall
|
|
||||||
|
|
||||||
def check_parameters_length(sql, parameters=None):
|
|
||||||
self.assertLess(len(parameters or []), 999)
|
|
||||||
return fetchall(sql, parameters)
|
|
||||||
|
|
||||||
self.ledger.db.execute_fetchall = check_parameters_length
|
|
||||||
account = await self.create_account()
|
|
||||||
tx = await self.create_tx_from_nothing(account, 0)
|
|
||||||
for height in range(1, 1200):
|
|
||||||
tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height)
|
|
||||||
variable_limit = self.ledger.db.MAX_QUERY_VARIABLES
|
|
||||||
for limit in range(variable_limit - 2, variable_limit + 2):
|
|
||||||
txs = await self.ledger.get_transactions(
|
|
||||||
accounts=self.wallet.accounts, limit=limit, order_by='height asc')
|
|
||||||
self.assertEqual(len(txs), limit)
|
|
||||||
inputs, outputs, last_tx = set(), set(), txs[0]
|
|
||||||
for tx in txs[1:]:
|
|
||||||
self.assertEqual(len(tx.inputs), 1)
|
|
||||||
self.assertEqual(tx.inputs[0].txo_ref.tx_ref.id, last_tx.id)
|
|
||||||
self.assertEqual(len(tx.outputs), 1)
|
|
||||||
last_tx = tx
|
|
||||||
|
|
||||||
async def test_queries(self):
|
|
||||||
wallet1 = Wallet()
|
|
||||||
account1 = await self.create_account(wallet1)
|
|
||||||
self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account1]))
|
|
||||||
wallet2 = Wallet()
|
|
||||||
account2 = await self.create_account(wallet2)
|
|
||||||
account3 = await self.create_account(wallet2)
|
|
||||||
self.assertEqual(26, await self.ledger.db.get_address_count(accounts=[account2]))
|
|
||||||
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2, account3]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_utxo_count())
|
|
||||||
self.assertListEqual([], await self.ledger.db.get_utxos())
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_txo_count())
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3]))
|
|
||||||
|
|
||||||
tx1 = await self.create_tx_from_nothing(account1, 1)
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account1]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account2]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account1]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_txo_count(accounts=[account2]))
|
|
||||||
self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet1))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet2))
|
|
||||||
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account1]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account3]))
|
|
||||||
|
|
||||||
tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
|
|
||||||
tx2b = await self.create_tx_from_nothing(account3, 2)
|
|
||||||
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account2]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account3]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account2]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account3]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account3]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
|
|
||||||
self.assertEqual(10**8+10**8, await self.ledger.db.get_balance(wallet=wallet2))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
|
|
||||||
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account2]))
|
|
||||||
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3]))
|
|
||||||
|
|
||||||
tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
|
|
||||||
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
|
|
||||||
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account2]))
|
|
||||||
self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(wallet=wallet1))
|
|
||||||
self.assertEqual(10**8, await self.ledger.db.get_balance(wallet=wallet2))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
|
|
||||||
self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
|
|
||||||
self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account3]))
|
|
||||||
|
|
||||||
txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
|
|
||||||
self.assertListEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
|
|
||||||
self.assertListEqual([3, 2, 1], [tx.height for tx in txs])
|
|
||||||
|
|
||||||
txs = await self.ledger.db.get_transactions(wallet=wallet1, accounts=wallet1.accounts, include_is_my_output=True)
|
|
||||||
self.assertListEqual([tx2.id, tx1.id], [tx.id for tx in txs])
|
|
||||||
self.assertEqual(txs[0].inputs[0].is_my_input, True)
|
|
||||||
self.assertEqual(txs[0].outputs[0].is_my_output, False)
|
|
||||||
self.assertEqual(txs[1].inputs[0].is_my_input, False)
|
|
||||||
self.assertEqual(txs[1].outputs[0].is_my_output, True)
|
|
||||||
|
|
||||||
txs = await self.ledger.db.get_transactions(wallet=wallet2, accounts=[account2], include_is_my_output=True)
|
|
||||||
self.assertListEqual([tx3.id, tx2.id], [tx.id for tx in txs])
|
|
||||||
self.assertEqual(txs[0].inputs[0].is_my_input, True)
|
|
||||||
self.assertEqual(txs[0].outputs[0].is_my_output, False)
|
|
||||||
self.assertEqual(txs[1].inputs[0].is_my_input, False)
|
|
||||||
self.assertEqual(txs[1].outputs[0].is_my_output, True)
|
|
||||||
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2]))
|
|
||||||
|
|
||||||
tx = await self.ledger.db.get_transaction(tx_hash=tx2.hash)
|
|
||||||
self.assertEqual(tx.id, tx2.id)
|
|
||||||
self.assertIsNone(tx.inputs[0].is_my_input)
|
|
||||||
self.assertIsNone(tx.outputs[0].is_my_output)
|
|
||||||
tx = await self.ledger.db.get_transaction(wallet=wallet1, tx_hash=tx2.hash, include_is_my_output=True)
|
|
||||||
self.assertTrue(tx.inputs[0].is_my_input)
|
|
||||||
self.assertFalse(tx.outputs[0].is_my_output)
|
|
||||||
tx = await self.ledger.db.get_transaction(wallet=wallet2, tx_hash=tx2.hash, include_is_my_output=True)
|
|
||||||
self.assertFalse(tx.inputs[0].is_my_input)
|
|
||||||
self.assertTrue(tx.outputs[0].is_my_output)
|
|
||||||
|
|
||||||
# height 0 sorted to the top with the rest in descending order
|
|
||||||
tx4 = await self.create_tx_from_nothing(account1, 0)
|
|
||||||
txos = await self.ledger.db.get_txos()
|
|
||||||
self.assertListEqual([0, 3, 2, 2, 1], [txo.tx_ref.height for txo in txos])
|
|
||||||
self.assertListEqual([tx4.id, tx3.id, tx2.id, tx2b.id, tx1.id], [txo.tx_ref.id for txo in txos])
|
|
||||||
txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
|
|
||||||
self.assertListEqual([0, 3, 2, 1], [tx.height for tx in txs])
|
|
||||||
self.assertListEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
|
|
||||||
|
|
||||||
async def test_empty_history(self):
|
|
||||||
self.assertEqual((None, []), await self.ledger.get_local_status_and_history(''))
|
|
||||||
|
|
||||||
|
|
||||||
class TestUpgrade(AsyncioTestCase):
|
|
||||||
|
|
||||||
def setUp(self) -> None:
|
|
||||||
self.path = tempfile.mktemp()
|
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
|
||||||
os.remove(self.path)
|
|
||||||
|
|
||||||
def get_version(self):
|
|
||||||
with sqlite3.connect(self.path) as conn:
|
|
||||||
versions = conn.execute('select version from version').fetchall()
|
|
||||||
assert len(versions) == 1
|
|
||||||
return versions[0][0]
|
|
||||||
|
|
||||||
def get_tables(self):
|
|
||||||
with sqlite3.connect(self.path) as conn:
|
|
||||||
sql = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
|
|
||||||
return [col[0] for col in conn.execute(sql).fetchall()]
|
|
||||||
|
|
||||||
def add_address(self, address):
|
|
||||||
with sqlite3.connect(self.path) as conn:
|
|
||||||
conn.execute("""
|
|
||||||
INSERT INTO account_address (address, account, chain, n, pubkey, chain_code, depth)
|
|
||||||
VALUES (?, 'account1', 0, 0, 'pubkey', 'chain_code', 0)
|
|
||||||
""", (address,))
|
|
||||||
|
|
||||||
def get_addresses(self):
|
|
||||||
with sqlite3.connect(self.path) as conn:
|
|
||||||
sql = "SELECT address FROM account_address ORDER BY address;"
|
|
||||||
return [col[0] for col in conn.execute(sql).fetchall()]
|
|
||||||
|
|
||||||
async def test_reset_on_version_change(self):
|
|
||||||
self.ledger = Ledger({
|
|
||||||
'db': Database('sqlite:///'+self.path),
|
|
||||||
'headers': Headers(':memory:')
|
|
||||||
})
|
|
||||||
|
|
||||||
# initial open, pre-version enabled db
|
|
||||||
self.ledger.db.SCHEMA_VERSION = None
|
|
||||||
self.assertListEqual(self.get_tables(), [])
|
|
||||||
await self.ledger.db.open()
|
|
||||||
metadata.drop_all(self.ledger.db.engine, [Version]) # simulate pre-version table db
|
|
||||||
self.assertEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo'])
|
|
||||||
self.assertListEqual(self.get_addresses(), [])
|
|
||||||
self.add_address('address1')
|
|
||||||
await self.ledger.db.close()
|
|
||||||
|
|
||||||
# initial open after version enabled
|
|
||||||
self.ledger.db.SCHEMA_VERSION = '1.0'
|
|
||||||
await self.ledger.db.open()
|
|
||||||
self.assertEqual(self.get_version(), '1.0')
|
|
||||||
self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
|
||||||
self.assertListEqual(self.get_addresses(), []) # address1 deleted during version upgrade
|
|
||||||
self.add_address('address2')
|
|
||||||
await self.ledger.db.close()
|
|
||||||
|
|
||||||
# nothing changes
|
|
||||||
self.assertEqual(self.get_version(), '1.0')
|
|
||||||
self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
|
||||||
await self.ledger.db.open()
|
|
||||||
self.assertEqual(self.get_version(), '1.0')
|
|
||||||
self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
|
||||||
self.assertListEqual(self.get_addresses(), ['address2'])
|
|
||||||
await self.ledger.db.close()
|
|
||||||
|
|
||||||
# upgrade version, database reset
|
|
||||||
foo = Table('foo', metadata, Column('bar', Text, primary_key=True))
|
|
||||||
self.addCleanup(metadata.remove, foo)
|
|
||||||
self.ledger.db.SCHEMA_VERSION = '1.1'
|
|
||||||
await self.ledger.db.open()
|
|
||||||
self.assertEqual(self.get_version(), '1.1')
|
|
||||||
self.assertListEqual(self.get_tables(), ['account_address', 'block', 'foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
|
||||||
self.assertListEqual(self.get_addresses(), []) # all tables got reset
|
|
||||||
await self.ledger.db.close()
|
|
||||||
|
|
||||||
|
|
||||||
class TestSQLiteRace(AsyncioTestCase):
|
|
||||||
max_misuse_attempts = 40000
|
|
||||||
|
|
||||||
def setup_db(self):
|
|
||||||
self.db = sqlite3.connect(":memory:", isolation_level=None)
|
|
||||||
self.db.executescript(
|
|
||||||
"create table test1 (id text primary key not null, val text);\n" +
|
|
||||||
"create table test2 (id text primary key not null, val text);\n" +
|
|
||||||
"\n".join(f"insert into test1 values ({v}, NULL);" for v in range(1000))
|
|
||||||
)
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
|
||||||
self.executor = ThreadPoolExecutor(1)
|
|
||||||
await self.loop.run_in_executor(self.executor, self.setup_db)
|
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
|
||||||
await self.loop.run_in_executor(self.executor, self.db.close)
|
|
||||||
self.executor.shutdown()
|
|
||||||
|
|
||||||
async def test_binding_param_0_error(self):
|
|
||||||
# test real param 0 binding errors
|
|
||||||
|
|
||||||
for supported_type in [str, int, bytes]:
|
|
||||||
await self.loop.run_in_executor(
|
|
||||||
self.executor, self.db.executemany, "insert into test2 values (?, NULL)",
|
|
||||||
[(supported_type(1), ), (supported_type(2), )]
|
|
||||||
)
|
|
||||||
await self.loop.run_in_executor(
|
|
||||||
self.executor, self.db.execute, "delete from test2 where id in (1, 2)"
|
|
||||||
)
|
|
||||||
for unsupported_type in [lambda x: (x, ), lambda x: [x], lambda x: {x}]:
|
|
||||||
try:
|
|
||||||
await self.loop.run_in_executor(
|
|
||||||
self.executor, self.db.executemany, "insert into test2 (id, val) values (?, NULL)",
|
|
||||||
[(unsupported_type(1), ), (unsupported_type(2), )]
|
|
||||||
)
|
|
||||||
self.assertTrue(False)
|
|
||||||
except sqlite3.InterfaceError as err:
|
|
||||||
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
|
|
||||||
|
|
||||||
async def test_unhandled_sqlite_misuse(self):
|
|
||||||
# test SQLITE_MISUSE being incorrectly raised as a param 0 binding error
|
|
||||||
attempts = 0
|
|
||||||
python_version = sys.version.split('\n')[0].rstrip(' ')
|
|
||||||
|
|
||||||
try:
|
|
||||||
while attempts < self.max_misuse_attempts:
|
|
||||||
f1 = asyncio.wrap_future(
|
|
||||||
self.loop.run_in_executor(
|
|
||||||
self.executor, self.db.executemany, "update test1 set val='derp' where id=?",
|
|
||||||
((str(i),) for i in range(2))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
f2 = asyncio.wrap_future(
|
|
||||||
self.loop.run_in_executor(
|
|
||||||
self.executor, self.db.executemany, "update test2 set val='derp' where id=?",
|
|
||||||
((str(i),) for i in range(2))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
attempts += 1
|
|
||||||
await asyncio.gather(f1, f2)
|
|
||||||
print(f"\nsqlite3 {sqlite3.version}/python {python_version} "
|
|
||||||
f"did not raise SQLITE_MISUSE within {attempts} attempts of the race condition")
|
|
||||||
self.assertTrue(False, 'this test failing means either the sqlite race conditions '
|
|
||||||
'have been fixed in cpython or the test max_attempts needs to be increased')
|
|
||||||
except sqlite3.InterfaceError as err:
|
|
||||||
self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
|
|
||||||
print(f"\nsqlite3 {sqlite3.version}/python {python_version} raised SQLITE_MISUSE "
|
|
||||||
f"after {attempts} attempts of the race condition")
|
|
||||||
|
|
||||||
@unittest.SkipTest
|
|
||||||
async def test_fetchall_prevents_sqlite_misuse(self):
|
|
||||||
# test that calling fetchall sufficiently avoids the race
|
|
||||||
attempts = 0
|
|
||||||
|
|
||||||
def executemany_fetchall(query, params):
|
|
||||||
self.db.executemany(query, params).fetchall()
|
|
||||||
|
|
||||||
while attempts < self.max_misuse_attempts:
|
|
||||||
f1 = asyncio.wrap_future(
|
|
||||||
self.loop.run_in_executor(
|
|
||||||
self.executor, executemany_fetchall, "update test1 set val='derp' where id=?",
|
|
||||||
((str(i),) for i in range(2))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
f2 = asyncio.wrap_future(
|
|
||||||
self.loop.run_in_executor(
|
|
||||||
self.executor, executemany_fetchall, "update test2 set val='derp' where id=?",
|
|
||||||
((str(i),) for i in range(2))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
attempts += 1
|
|
||||||
await asyncio.gather(f1, f2)
|
|
|
@ -1,242 +0,0 @@
|
||||||
import os
|
|
||||||
from binascii import hexlify, unhexlify
|
|
||||||
|
|
||||||
from lbry.testcase import AsyncioTestCase
|
|
||||||
from lbry.wallet import Wallet, Account, Transaction, Output, Input, Ledger, Headers
|
|
||||||
from lbry.db import Database
|
|
||||||
|
|
||||||
from tests.unit.wallet.test_transaction import get_transaction, get_output
|
|
||||||
from tests.unit.wallet.test_headers import HEADERS, block_bytes
|
|
||||||
|
|
||||||
|
|
||||||
class MockNetwork:
|
|
||||||
|
|
||||||
def __init__(self, history, transaction):
|
|
||||||
self.history = history
|
|
||||||
self.transaction = transaction
|
|
||||||
self.address = None
|
|
||||||
self.get_history_called = []
|
|
||||||
self.get_transaction_called = []
|
|
||||||
self.is_connected = False
|
|
||||||
|
|
||||||
def retriable_call(self, function, *args, **kwargs):
|
|
||||||
return function(*args, **kwargs)
|
|
||||||
|
|
||||||
async def get_history(self, address):
|
|
||||||
self.get_history_called.append(address)
|
|
||||||
self.address = address
|
|
||||||
return self.history
|
|
||||||
|
|
||||||
async def get_merkle(self, txid, height):
|
|
||||||
return {'merkle': ['abcd01'], 'pos': 1}
|
|
||||||
|
|
||||||
async def get_transaction(self, tx_hash, _=None):
|
|
||||||
self.get_transaction_called.append(tx_hash)
|
|
||||||
return self.transaction[tx_hash]
|
|
||||||
|
|
||||||
async def get_transaction_and_merkle(self, tx_hash, known_height=None):
|
|
||||||
tx = await self.get_transaction(tx_hash)
|
|
||||||
merkle = {}
|
|
||||||
if known_height:
|
|
||||||
merkle = await self.get_merkle(tx_hash, known_height)
|
|
||||||
return tx, merkle
|
|
||||||
|
|
||||||
|
|
||||||
class LedgerTestCase(AsyncioTestCase):
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
|
||||||
self.ledger = Ledger({
|
|
||||||
'db': Database('sqlite:///:memory:'),
|
|
||||||
'headers': Headers(':memory:')
|
|
||||||
})
|
|
||||||
self.account = Account.generate(self.ledger, Wallet(), "lbryum")
|
|
||||||
await self.ledger.db.open()
|
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
|
||||||
await self.ledger.db.close()
|
|
||||||
|
|
||||||
def make_header(self, **kwargs):
|
|
||||||
header = {
|
|
||||||
'bits': 486604799,
|
|
||||||
'block_height': 0,
|
|
||||||
'merkle_root': b'4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b',
|
|
||||||
'nonce': 2083236893,
|
|
||||||
'prev_block_hash': b'0000000000000000000000000000000000000000000000000000000000000000',
|
|
||||||
'timestamp': 1231006505,
|
|
||||||
'claim_trie_root': b'4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b',
|
|
||||||
'version': 1
|
|
||||||
}
|
|
||||||
header.update(kwargs)
|
|
||||||
header['merkle_root'] = header['merkle_root'].ljust(64, b'a')
|
|
||||||
header['prev_block_hash'] = header['prev_block_hash'].ljust(64, b'0')
|
|
||||||
return self.ledger.headers.serialize(header)
|
|
||||||
|
|
||||||
def add_header(self, **kwargs):
|
|
||||||
serialized = self.make_header(**kwargs)
|
|
||||||
self.ledger.headers.io.seek(0, os.SEEK_END)
|
|
||||||
self.ledger.headers.io.write(serialized)
|
|
||||||
self.ledger.headers._size = self.ledger.headers.io.seek(0, os.SEEK_END) // self.ledger.headers.header_size
|
|
||||||
|
|
||||||
|
|
||||||
class TestSynchronization(LedgerTestCase):
|
|
||||||
|
|
||||||
async def test_update_history(self):
|
|
||||||
txid1 = '252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792'
|
|
||||||
txid2 = 'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9'
|
|
||||||
txid3 = 'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0'
|
|
||||||
txid4 = '047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828'
|
|
||||||
txhash1 = unhexlify(txid1)[::-1]
|
|
||||||
txhash2 = unhexlify(txid2)[::-1]
|
|
||||||
txhash3 = unhexlify(txid3)[::-1]
|
|
||||||
txhash4 = unhexlify(txid4)[::-1]
|
|
||||||
|
|
||||||
account = Account.generate(self.ledger, Wallet(), "torba")
|
|
||||||
address = await account.receiving.get_or_create_usable_address()
|
|
||||||
address_details = await self.ledger.db.get_address(address=address)
|
|
||||||
self.assertIsNone(address_details['history'])
|
|
||||||
|
|
||||||
self.add_header(block_height=0, merkle_root=b'abcd04')
|
|
||||||
self.add_header(block_height=1, merkle_root=b'abcd04')
|
|
||||||
self.add_header(block_height=2, merkle_root=b'abcd04')
|
|
||||||
self.add_header(block_height=3, merkle_root=b'abcd04')
|
|
||||||
self.ledger.network = MockNetwork([
|
|
||||||
{'tx_hash': txid1, 'height': 0},
|
|
||||||
{'tx_hash': txid2, 'height': 1},
|
|
||||||
{'tx_hash': txid3, 'height': 2},
|
|
||||||
], {
|
|
||||||
txhash1: hexlify(get_transaction(get_output(1)).raw),
|
|
||||||
txhash2: hexlify(get_transaction(get_output(2)).raw),
|
|
||||||
txhash3: hexlify(get_transaction(get_output(3)).raw),
|
|
||||||
})
|
|
||||||
await self.ledger.update_history(address, '')
|
|
||||||
self.assertListEqual(self.ledger.network.get_history_called, [address])
|
|
||||||
self.assertListEqual(self.ledger.network.get_transaction_called, [txhash1, txhash2, txhash3])
|
|
||||||
|
|
||||||
address_details = await self.ledger.db.get_address(address=address)
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
address_details['history'],
|
|
||||||
f'{txid1}:0:'
|
|
||||||
f'{txid2}:1:'
|
|
||||||
f'{txid3}:2:'
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ledger.network.get_history_called = []
|
|
||||||
self.ledger.network.get_transaction_called = []
|
|
||||||
for cache_item in self.ledger._tx_cache.values():
|
|
||||||
cache_item.tx.is_verified = True
|
|
||||||
await self.ledger.update_history(address, '')
|
|
||||||
self.assertListEqual(self.ledger.network.get_history_called, [address])
|
|
||||||
self.assertListEqual(self.ledger.network.get_transaction_called, [])
|
|
||||||
|
|
||||||
self.ledger.network.history.append({'tx_hash': txid4, 'height': 3})
|
|
||||||
self.ledger.network.transaction[txhash4] = hexlify(get_transaction(get_output(4)).raw)
|
|
||||||
self.ledger.network.get_history_called = []
|
|
||||||
self.ledger.network.get_transaction_called = []
|
|
||||||
await self.ledger.update_history(address, '')
|
|
||||||
self.assertListEqual(self.ledger.network.get_history_called, [address])
|
|
||||||
self.assertListEqual(self.ledger.network.get_transaction_called, [txhash4])
|
|
||||||
address_details = await self.ledger.db.get_address(address=address)
|
|
||||||
self.assertEqual(
|
|
||||||
address_details['history'],
|
|
||||||
f'{txid1}:0:'
|
|
||||||
f'{txid2}:1:'
|
|
||||||
f'{txid3}:2:'
|
|
||||||
f'{txid4}:3:'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MocHeaderNetwork(MockNetwork):
|
|
||||||
def __init__(self, responses):
|
|
||||||
super().__init__(None, None)
|
|
||||||
self.responses = responses
|
|
||||||
|
|
||||||
async def get_headers(self, height, blocks):
|
|
||||||
return self.responses[height]
|
|
||||||
|
|
||||||
|
|
||||||
class BlockchainReorganizationTests(LedgerTestCase):
|
|
||||||
|
|
||||||
async def test_1_block_reorganization(self):
|
|
||||||
self.ledger.network = MocHeaderNetwork({
|
|
||||||
10: {'height': 10, 'count': 5, 'hex': hexlify(
|
|
||||||
HEADERS[block_bytes(10):block_bytes(15)]
|
|
||||||
)},
|
|
||||||
15: {'height': 15, 'count': 0, 'hex': b''}
|
|
||||||
})
|
|
||||||
headers = self.ledger.headers
|
|
||||||
await headers.connect(0, HEADERS[:block_bytes(10)])
|
|
||||||
self.add_header(block_height=len(headers))
|
|
||||||
self.assertEqual(10, headers.height)
|
|
||||||
await self.ledger.receive_header([{
|
|
||||||
'height': 11, 'hex': hexlify(self.make_header(block_height=11))
|
|
||||||
}])
|
|
||||||
|
|
||||||
async def test_3_block_reorganization(self):
|
|
||||||
self.ledger.network = MocHeaderNetwork({
|
|
||||||
10: {'height': 10, 'count': 5, 'hex': hexlify(
|
|
||||||
HEADERS[block_bytes(10):block_bytes(15)]
|
|
||||||
)},
|
|
||||||
11: {'height': 11, 'count': 1, 'hex': hexlify(self.make_header(block_height=11))},
|
|
||||||
12: {'height': 12, 'count': 1, 'hex': hexlify(self.make_header(block_height=12))},
|
|
||||||
15: {'height': 15, 'count': 0, 'hex': b''}
|
|
||||||
})
|
|
||||||
headers = self.ledger.headers
|
|
||||||
await headers.connect(0, HEADERS[:block_bytes(10)])
|
|
||||||
self.add_header(block_height=len(headers))
|
|
||||||
self.add_header(block_height=len(headers))
|
|
||||||
self.add_header(block_height=len(headers))
|
|
||||||
self.assertEqual(headers.height, 12)
|
|
||||||
await self.ledger.receive_header([{
|
|
||||||
'height': 13, 'hex': hexlify(self.make_header(block_height=13))
|
|
||||||
}])
|
|
||||||
|
|
||||||
|
|
||||||
class BasicAccountingTests(LedgerTestCase):
|
|
||||||
|
|
||||||
async def test_empty_state(self):
|
|
||||||
self.assertEqual(await self.account.get_balance(), 0)
|
|
||||||
|
|
||||||
async def test_balance(self):
|
|
||||||
address = await self.account.receiving.get_or_create_usable_address()
|
|
||||||
hash160 = self.ledger.address_to_hash160(address)
|
|
||||||
|
|
||||||
tx = Transaction(is_verified=True)\
|
|
||||||
.add_outputs([Output.pay_pubkey_hash(100, hash160)])
|
|
||||||
await self.ledger.db.insert_transaction(tx)
|
|
||||||
await self.ledger.db.save_transaction_io(
|
|
||||||
tx, address, hash160, f'{tx.id}:1:'
|
|
||||||
)
|
|
||||||
self.assertEqual(await self.account.get_balance(), 100)
|
|
||||||
|
|
||||||
tx = Transaction(is_verified=True)\
|
|
||||||
.add_outputs([Output.pay_claim_name_pubkey_hash(100, 'foo', b'', hash160)])
|
|
||||||
await self.ledger.db.insert_transaction(tx)
|
|
||||||
await self.ledger.db.save_transaction_io(
|
|
||||||
tx, address, hash160, f'{tx.id}:1:'
|
|
||||||
)
|
|
||||||
self.assertEqual(await self.account.get_balance(), 100) # claim names don't count towards balance
|
|
||||||
self.assertEqual(await self.account.get_balance(include_claims=True), 200)
|
|
||||||
|
|
||||||
async def test_get_utxo(self):
|
|
||||||
address = yield self.account.receiving.get_or_create_usable_address()
|
|
||||||
hash160 = self.ledger.address_to_hash160(address)
|
|
||||||
|
|
||||||
tx = Transaction(is_verified=True)\
|
|
||||||
.add_outputs([Output.pay_pubkey_hash(100, hash160)])
|
|
||||||
await self.ledger.db.save_transaction_io(
|
|
||||||
'insert', tx, address, hash160, f'{tx.id}:1:'
|
|
||||||
)
|
|
||||||
|
|
||||||
utxos = await self.account.get_utxos()
|
|
||||||
self.assertEqual(len(utxos), 1)
|
|
||||||
|
|
||||||
tx = Transaction(is_verified=True)\
|
|
||||||
.add_inputs([Input.spend(utxos[0])])
|
|
||||||
await self.ledger.db.save_transaction_io(
|
|
||||||
'insert', tx, address, hash160, f'{tx.id}:1:'
|
|
||||||
)
|
|
||||||
self.assertEqual(await self.account.get_balance(include_claims=True), 0)
|
|
||||||
|
|
||||||
utxos = await self.account.get_utxos()
|
|
||||||
self.assertEqual(len(utxos), 0)
|
|
|
@ -4,69 +4,70 @@ import tempfile
|
||||||
|
|
||||||
from lbry import Config, Ledger, Database, WalletManager, Wallet, Account
|
from lbry import Config, Ledger, Database, WalletManager, Wallet, Account
|
||||||
from lbry.testcase import AsyncioTestCase
|
from lbry.testcase import AsyncioTestCase
|
||||||
|
from lbry.wallet.manager import FileWallet, DatabaseWallet
|
||||||
|
|
||||||
|
|
||||||
class TestWalletManager(AsyncioTestCase):
|
class DBBasedWalletManagerTestCase(AsyncioTestCase):
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
self.temp_dir = tempfile.mkdtemp()
|
self.ledger = Ledger(Config.with_null_dir().set(
|
||||||
self.addCleanup(shutil.rmtree, self.temp_dir)
|
db_url="sqlite:///:memory:",
|
||||||
self.ledger = Ledger(Config.with_same_dir(self.temp_dir).set(
|
wallet_storage="database"
|
||||||
db_url="sqlite:///:memory:"
|
|
||||||
))
|
))
|
||||||
self.db = Database(self.ledger)
|
self.db = Database(self.ledger)
|
||||||
|
await self.db.open()
|
||||||
|
self.addCleanup(self.db.close)
|
||||||
|
|
||||||
async def test_ensure_path_exists(self):
|
|
||||||
wm = WalletManager(self.ledger, self.db)
|
|
||||||
self.assertFalse(os.path.exists(wm.path))
|
|
||||||
await wm.ensure_path_exists()
|
|
||||||
self.assertTrue(os.path.exists(wm.path))
|
|
||||||
|
|
||||||
async def test_load_with_default_wallet_account_progression(self):
|
class TestDatabaseWalletManager(DBBasedWalletManagerTestCase):
|
||||||
wm = WalletManager(self.ledger, self.db)
|
|
||||||
await wm.ensure_path_exists()
|
async def test_initialize_with_default_wallet_account_progression(self):
|
||||||
|
wm = WalletManager(self.db)
|
||||||
|
self.assertIsInstance(wm.storage, DatabaseWallet)
|
||||||
|
storage: DatabaseWallet = wm.storage
|
||||||
|
await storage.prepare()
|
||||||
|
|
||||||
# first, no defaults
|
# first, no defaults
|
||||||
self.ledger.conf.create_default_wallet = False
|
self.ledger.conf.create_default_wallet = False
|
||||||
self.ledger.conf.create_default_account = False
|
self.ledger.conf.create_default_account = False
|
||||||
await wm.load()
|
await wm.initialize()
|
||||||
self.assertIsNone(wm.default)
|
self.assertIsNone(wm.default)
|
||||||
|
|
||||||
# then, yes to default wallet but no to default account
|
# then, yes to default wallet but no to default account
|
||||||
self.ledger.conf.create_default_wallet = True
|
self.ledger.conf.create_default_wallet = True
|
||||||
self.ledger.conf.create_default_account = False
|
self.ledger.conf.create_default_account = False
|
||||||
await wm.load()
|
await wm.initialize()
|
||||||
self.assertIsInstance(wm.default, Wallet)
|
self.assertIsInstance(wm.default, Wallet)
|
||||||
self.assertTrue(os.path.exists(wm.default.storage.path))
|
self.assertTrue(await storage.exists(wm.default.id))
|
||||||
self.assertIsNone(wm.default.accounts.default)
|
self.assertIsNone(wm.default.accounts.default)
|
||||||
|
|
||||||
# finally, yes to all the things
|
# finally, yes to all the things
|
||||||
self.ledger.conf.create_default_wallet = True
|
self.ledger.conf.create_default_wallet = True
|
||||||
self.ledger.conf.create_default_account = True
|
self.ledger.conf.create_default_account = True
|
||||||
await wm.load()
|
await wm.initialize()
|
||||||
self.assertIsInstance(wm.default, Wallet)
|
self.assertIsInstance(wm.default, Wallet)
|
||||||
self.assertIsInstance(wm.default.accounts.default, Account)
|
self.assertIsInstance(wm.default.accounts.default, Account)
|
||||||
|
|
||||||
async def test_load_with_create_default_everything_upfront(self):
|
async def test_load_with_create_default_everything_upfront(self):
|
||||||
wm = WalletManager(self.ledger, self.db)
|
wm = WalletManager(self.db)
|
||||||
await wm.ensure_path_exists()
|
await wm.storage.prepare()
|
||||||
self.ledger.conf.create_default_wallet = True
|
self.ledger.conf.create_default_wallet = True
|
||||||
self.ledger.conf.create_default_account = True
|
self.ledger.conf.create_default_account = True
|
||||||
await wm.load()
|
await wm.initialize()
|
||||||
self.assertIsInstance(wm.default, Wallet)
|
self.assertIsInstance(wm.default, Wallet)
|
||||||
self.assertIsInstance(wm.default.accounts.default, Account)
|
self.assertIsInstance(wm.default.accounts.default, Account)
|
||||||
self.assertTrue(os.path.exists(wm.default.storage.path))
|
self.assertTrue(await wm.storage.exists(wm.default.id))
|
||||||
|
|
||||||
async def test_load_errors(self):
|
async def test_load_errors(self):
|
||||||
_wm = WalletManager(self.ledger, self.db)
|
_wm = WalletManager(self.db)
|
||||||
await _wm.ensure_path_exists()
|
await _wm.storage.prepare()
|
||||||
await _wm.create('bar', '')
|
await _wm.create('bar', '')
|
||||||
await _wm.create('foo', '')
|
await _wm.create('foo', '')
|
||||||
|
|
||||||
wm = WalletManager(self.ledger, self.db)
|
wm = WalletManager(self.db)
|
||||||
self.ledger.conf.wallets = ['bar', 'foo', 'foo']
|
self.ledger.conf.wallets = ['bar', 'foo', 'foo']
|
||||||
with self.assertLogs(level='WARN') as cm:
|
with self.assertLogs(level='WARN') as cm:
|
||||||
await wm.load()
|
await wm.initialize()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
cm.output, [
|
cm.output, [
|
||||||
'WARNING:lbry.wallet.manager:Ignoring duplicate wallet_id in config: foo',
|
'WARNING:lbry.wallet.manager:Ignoring duplicate wallet_id in config: foo',
|
||||||
|
@ -75,9 +76,9 @@ class TestWalletManager(AsyncioTestCase):
|
||||||
self.assertEqual({'bar', 'foo'}, set(wm.wallets))
|
self.assertEqual({'bar', 'foo'}, set(wm.wallets))
|
||||||
|
|
||||||
async def test_creating_and_accessing_wallets(self):
|
async def test_creating_and_accessing_wallets(self):
|
||||||
wm = WalletManager(self.ledger, self.db)
|
wm = WalletManager(self.db)
|
||||||
await wm.ensure_path_exists()
|
await wm.storage.prepare()
|
||||||
await wm.load()
|
await wm.initialize()
|
||||||
default_wallet = wm.default
|
default_wallet = wm.default
|
||||||
self.assertEqual(default_wallet, wm['default_wallet'])
|
self.assertEqual(default_wallet, wm['default_wallet'])
|
||||||
self.assertEqual(default_wallet, wm.get_or_default(None))
|
self.assertEqual(default_wallet, wm.get_or_default(None))
|
||||||
|
@ -90,3 +91,114 @@ class TestWalletManager(AsyncioTestCase):
|
||||||
_ = wm['invalid']
|
_ = wm['invalid']
|
||||||
with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"):
|
with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"):
|
||||||
wm.get_or_default('invalid')
|
wm.get_or_default('invalid')
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileBasedWalletManager(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.temp_dir = tempfile.mkdtemp()
|
||||||
|
self.addCleanup(shutil.rmtree, self.temp_dir)
|
||||||
|
self.ledger = Ledger(Config(
|
||||||
|
data_dir=self.temp_dir,
|
||||||
|
db_url="sqlite:///:memory:"
|
||||||
|
))
|
||||||
|
self.ledger.conf.set_default_paths()
|
||||||
|
self.db = Database(self.ledger)
|
||||||
|
await self.db.open()
|
||||||
|
self.addCleanup(self.db.close)
|
||||||
|
|
||||||
|
async def test_ensure_path_exists(self):
|
||||||
|
wm = WalletManager(self.db)
|
||||||
|
self.assertIsInstance(wm.storage, FileWallet)
|
||||||
|
storage: FileWallet = wm.storage
|
||||||
|
self.assertFalse(os.path.exists(storage.wallet_dir))
|
||||||
|
await storage.prepare()
|
||||||
|
self.assertTrue(os.path.exists(storage.wallet_dir))
|
||||||
|
|
||||||
|
async def test_initialize_with_default_wallet_account_progression(self):
|
||||||
|
wm = WalletManager(self.db)
|
||||||
|
storage: FileWallet = wm.storage
|
||||||
|
await storage.prepare()
|
||||||
|
|
||||||
|
# first, no defaults
|
||||||
|
self.ledger.conf.create_default_wallet = False
|
||||||
|
self.ledger.conf.create_default_account = False
|
||||||
|
await wm.initialize()
|
||||||
|
self.assertIsNone(wm.default)
|
||||||
|
|
||||||
|
# then, yes to default wallet but no to default account
|
||||||
|
self.ledger.conf.create_default_wallet = True
|
||||||
|
self.ledger.conf.create_default_account = False
|
||||||
|
await wm.initialize()
|
||||||
|
self.assertIsInstance(wm.default, Wallet)
|
||||||
|
self.assertTrue(os.path.exists(storage.get_wallet_path(wm.default.id)))
|
||||||
|
self.assertIsNone(wm.default.accounts.default)
|
||||||
|
|
||||||
|
# finally, yes to all the things
|
||||||
|
self.ledger.conf.create_default_wallet = True
|
||||||
|
self.ledger.conf.create_default_account = True
|
||||||
|
await wm.initialize()
|
||||||
|
self.assertIsInstance(wm.default, Wallet)
|
||||||
|
self.assertIsInstance(wm.default.accounts.default, Account)
|
||||||
|
|
||||||
|
async def test_load_with_create_default_everything_upfront(self):
|
||||||
|
wm = WalletManager(self.db)
|
||||||
|
await wm.storage.prepare()
|
||||||
|
self.ledger.conf.create_default_wallet = True
|
||||||
|
self.ledger.conf.create_default_account = True
|
||||||
|
await wm.initialize()
|
||||||
|
self.assertIsInstance(wm.default, Wallet)
|
||||||
|
self.assertIsInstance(wm.default.accounts.default, Account)
|
||||||
|
self.assertTrue(os.path.exists(wm.storage.get_wallet_path(wm.default.id)))
|
||||||
|
|
||||||
|
async def test_load_errors(self):
|
||||||
|
_wm = WalletManager(self.db)
|
||||||
|
await _wm.storage.prepare()
|
||||||
|
await _wm.create('bar', '')
|
||||||
|
await _wm.create('foo', '')
|
||||||
|
|
||||||
|
wm = WalletManager(self.db)
|
||||||
|
self.ledger.conf.wallets = ['bar', 'foo', 'foo']
|
||||||
|
with self.assertLogs(level='WARN') as cm:
|
||||||
|
await wm.initialize()
|
||||||
|
self.assertEqual(
|
||||||
|
cm.output, [
|
||||||
|
'WARNING:lbry.wallet.manager:Ignoring duplicate wallet_id in config: foo',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual({'bar', 'foo'}, set(wm.wallets))
|
||||||
|
|
||||||
|
async def test_creating_and_accessing_wallets(self):
|
||||||
|
wm = WalletManager(self.db)
|
||||||
|
await wm.storage.prepare()
|
||||||
|
await wm.initialize()
|
||||||
|
default_wallet = wm.default
|
||||||
|
self.assertEqual(default_wallet, wm['default_wallet'])
|
||||||
|
self.assertEqual(default_wallet, wm.get_or_default(None))
|
||||||
|
new_wallet = await wm.create('second', 'Second Wallet')
|
||||||
|
self.assertEqual(default_wallet, wm.default)
|
||||||
|
self.assertEqual(new_wallet, wm['second'])
|
||||||
|
self.assertEqual(new_wallet, wm.get_or_default('second'))
|
||||||
|
self.assertEqual(default_wallet, wm.get_or_default(None))
|
||||||
|
with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"):
|
||||||
|
_ = wm['invalid']
|
||||||
|
with self.assertRaisesRegex(ValueError, "Couldn't find wallet: invalid"):
|
||||||
|
wm.get_or_default('invalid')
|
||||||
|
|
||||||
|
async def test_read_write(self):
|
||||||
|
manager = WalletManager(self.db)
|
||||||
|
await manager.storage.prepare()
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file:
|
||||||
|
wallet_file.write(b'{"version": 1}')
|
||||||
|
wallet_file.seek(0)
|
||||||
|
|
||||||
|
# create and write wallet to a file
|
||||||
|
wallet = await manager.load(wallet_file.name)
|
||||||
|
account = await wallet.accounts.generate()
|
||||||
|
await manager.storage.save(wallet)
|
||||||
|
|
||||||
|
# read wallet from file
|
||||||
|
wallet = await manager.load(wallet_file.name)
|
||||||
|
|
||||||
|
self.assertEqual(account.public_key.address, wallet.accounts.default.public_key.address)
|
||||||
|
|
|
@ -1,78 +0,0 @@
|
||||||
from lbry.testcase import UnitDBTestCase
|
|
||||||
|
|
||||||
|
|
||||||
class TestClientDBSync(UnitDBTestCase):
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
|
||||||
await super().asyncSetUp()
|
|
||||||
await self.add(self.coinbase())
|
|
||||||
|
|
||||||
async def test_process_inputs(self):
|
|
||||||
await self.add(self.tx())
|
|
||||||
await self.add(self.tx())
|
|
||||||
txo1, txo2a, txo2b, txo3a, txo3b = self.outputs
|
|
||||||
self.assertEqual([
|
|
||||||
(txo1.id, None),
|
|
||||||
(txo2b.id, None),
|
|
||||||
], await self.get_txis())
|
|
||||||
self.assertEqual([
|
|
||||||
(txo1.id, False),
|
|
||||||
(txo2a.id, False),
|
|
||||||
(txo2b.id, False),
|
|
||||||
(txo3a.id, False),
|
|
||||||
(txo3b.id, False),
|
|
||||||
], await self.get_txos())
|
|
||||||
await self.db.process_all_things_after_sync()
|
|
||||||
self.assertEqual([
|
|
||||||
(txo1.id, txo1.get_address(self.ledger)),
|
|
||||||
(txo2b.id, txo2b.get_address(self.ledger)),
|
|
||||||
], await self.get_txis())
|
|
||||||
self.assertEqual([
|
|
||||||
(txo1.id, True),
|
|
||||||
(txo2a.id, False),
|
|
||||||
(txo2b.id, True),
|
|
||||||
(txo3a.id, False),
|
|
||||||
(txo3b.id, False),
|
|
||||||
], await self.get_txos())
|
|
||||||
|
|
||||||
async def test_process_claims(self):
|
|
||||||
claim1 = await self.add(self.create_claim())
|
|
||||||
await self.db.process_all_things_after_sync()
|
|
||||||
self.assertEqual([claim1.claim_id], await self.get_claims())
|
|
||||||
|
|
||||||
claim2 = await self.add(self.create_claim())
|
|
||||||
self.assertEqual([claim1.claim_id], await self.get_claims())
|
|
||||||
await self.db.process_all_things_after_sync()
|
|
||||||
self.assertEqual([claim1.claim_id, claim2.claim_id], await self.get_claims())
|
|
||||||
|
|
||||||
claim3 = await self.add(self.create_claim())
|
|
||||||
claim4 = await self.add(self.create_claim())
|
|
||||||
await self.db.process_all_things_after_sync()
|
|
||||||
self.assertEqual([
|
|
||||||
claim1.claim_id,
|
|
||||||
claim2.claim_id,
|
|
||||||
claim3.claim_id,
|
|
||||||
claim4.claim_id,
|
|
||||||
], await self.get_claims())
|
|
||||||
|
|
||||||
await self.add(self.abandon_claim(claim4))
|
|
||||||
await self.db.process_all_things_after_sync()
|
|
||||||
self.assertEqual([
|
|
||||||
claim1.claim_id, claim2.claim_id, claim3.claim_id
|
|
||||||
], await self.get_claims())
|
|
||||||
|
|
||||||
# create and abandon in same block
|
|
||||||
claim5 = await self.add(self.create_claim())
|
|
||||||
await self.add(self.abandon_claim(claim5))
|
|
||||||
await self.db.process_all_things_after_sync()
|
|
||||||
self.assertEqual([
|
|
||||||
claim1.claim_id, claim2.claim_id, claim3.claim_id
|
|
||||||
], await self.get_claims())
|
|
||||||
|
|
||||||
# create and abandon in different blocks but with bulk sync
|
|
||||||
claim6 = await self.add(self.create_claim())
|
|
||||||
await self.add(self.abandon_claim(claim6))
|
|
||||||
await self.db.process_all_things_after_sync()
|
|
||||||
self.assertEqual([
|
|
||||||
claim1.claim_id, claim2.claim_id, claim3.claim_id
|
|
||||||
], await self.get_claims())
|
|
|
@ -1,18 +1,17 @@
|
||||||
import tempfile
|
from itertools import cycle
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
from unittest import TestCase, mock
|
from unittest import TestCase, mock
|
||||||
|
|
||||||
from lbry import Config, Database, Ledger, Account, Wallet, WalletManager
|
from lbry import Config, Database, Ledger, Account, Wallet, Transaction, Output, Input
|
||||||
from lbry.testcase import AsyncioTestCase
|
from lbry.testcase import AsyncioTestCase, get_output, COIN, CENT
|
||||||
from lbry.wallet.storage import WalletStorage
|
|
||||||
from lbry.wallet.preferences import TimestampedPreferences
|
from lbry.wallet.preferences import TimestampedPreferences
|
||||||
|
|
||||||
|
|
||||||
class WalletTestCase(AsyncioTestCase):
|
class WalletTestCase(AsyncioTestCase):
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
self.ledger = Ledger(Config.with_null_dir())
|
self.ledger = Ledger(Config.with_null_dir().set(db_url='sqlite:///:memory:'))
|
||||||
self.db = Database(self.ledger, "sqlite:///:memory:")
|
self.db = Database(self.ledger)
|
||||||
await self.db.open()
|
await self.db.open()
|
||||||
self.addCleanup(self.db.close)
|
self.addCleanup(self.db.close)
|
||||||
|
|
||||||
|
@ -20,8 +19,8 @@ class WalletTestCase(AsyncioTestCase):
|
||||||
class WalletAccountTest(WalletTestCase):
|
class WalletAccountTest(WalletTestCase):
|
||||||
|
|
||||||
async def test_private_key_for_hierarchical_account(self):
|
async def test_private_key_for_hierarchical_account(self):
|
||||||
wallet = Wallet(self.ledger, self.db)
|
wallet = Wallet("wallet1", self.db)
|
||||||
account = wallet.add_account({
|
account = await wallet.accounts.add_from_dict({
|
||||||
"seed":
|
"seed":
|
||||||
"carbon smart garage balance margin twelve chest sword toas"
|
"carbon smart garage balance margin twelve chest sword toas"
|
||||||
"t envelope bottom stomach absent"
|
"t envelope bottom stomach absent"
|
||||||
|
@ -40,8 +39,8 @@ class WalletAccountTest(WalletTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_private_key_for_single_address_account(self):
|
async def test_private_key_for_single_address_account(self):
|
||||||
wallet = Wallet(self.ledger, self.db)
|
wallet = Wallet("wallet1", self.db)
|
||||||
account = wallet.add_account({
|
account = await wallet.accounts.add_from_dict({
|
||||||
"seed":
|
"seed":
|
||||||
"carbon smart garage balance margin twelve chest sword toas"
|
"carbon smart garage balance margin twelve chest sword toas"
|
||||||
"t envelope bottom stomach absent",
|
"t envelope bottom stomach absent",
|
||||||
|
@ -59,9 +58,9 @@ class WalletAccountTest(WalletTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_save_max_gap(self):
|
async def test_save_max_gap(self):
|
||||||
wallet = Wallet(self.ledger, self.db)
|
wallet = Wallet("wallet1", self.db)
|
||||||
account = wallet.generate_account(
|
account = await wallet.accounts.generate(
|
||||||
'lbryum', {
|
address_generator={
|
||||||
'name': 'deterministic-chain',
|
'name': 'deterministic-chain',
|
||||||
'receiving': {'gap': 3, 'maximum_uses_per_address': 2},
|
'receiving': {'gap': 3, 'maximum_uses_per_address': 2},
|
||||||
'change': {'gap': 4, 'maximum_uses_per_address': 2}
|
'change': {'gap': 4, 'maximum_uses_per_address': 2}
|
||||||
|
@ -73,24 +72,25 @@ class WalletAccountTest(WalletTestCase):
|
||||||
self.assertEqual(account.receiving.gap, 20)
|
self.assertEqual(account.receiving.gap, 20)
|
||||||
self.assertEqual(account.change.gap, 6)
|
self.assertEqual(account.change.gap, 6)
|
||||||
# doesn't fail for single-address account
|
# doesn't fail for single-address account
|
||||||
wallet.generate_account('lbryum', {'name': 'single-address'})
|
await wallet.accounts.generate(address_generator={'name': 'single-address'})
|
||||||
await wallet.save_max_gap()
|
await wallet.save_max_gap()
|
||||||
|
|
||||||
|
|
||||||
class TestWalletCreation(WalletTestCase):
|
class TestWalletCreation(WalletTestCase):
|
||||||
|
|
||||||
def test_create_wallet_and_accounts(self):
|
async def test_create_wallet_and_accounts(self):
|
||||||
wallet = Wallet(self.ledger, self.db)
|
wallet = Wallet("wallet1", self.db)
|
||||||
self.assertEqual(wallet.name, 'Wallet')
|
self.assertEqual(wallet.id, "wallet1")
|
||||||
self.assertListEqual(wallet.accounts, [])
|
self.assertEqual(wallet.name, "")
|
||||||
|
self.assertEqual(list(wallet.accounts), [])
|
||||||
|
|
||||||
account1 = wallet.generate_account()
|
account1 = await wallet.accounts.generate()
|
||||||
wallet.generate_account()
|
await wallet.accounts.generate()
|
||||||
wallet.generate_account()
|
await wallet.accounts.generate()
|
||||||
self.assertEqual(wallet.default_account, account1)
|
self.assertEqual(wallet.accounts.default, account1)
|
||||||
self.assertEqual(len(wallet.accounts), 3)
|
self.assertEqual(len(wallet.accounts), 3)
|
||||||
|
|
||||||
def test_load_and_save_wallet(self):
|
async def test_load_and_save_wallet(self):
|
||||||
wallet_dict = {
|
wallet_dict = {
|
||||||
'version': 1,
|
'version': 1,
|
||||||
'name': 'Main Wallet',
|
'name': 'Main Wallet',
|
||||||
|
@ -105,6 +105,7 @@ class TestWalletCreation(WalletTestCase):
|
||||||
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
||||||
"h absent",
|
"h absent",
|
||||||
'encrypted': False,
|
'encrypted': False,
|
||||||
|
'lang': 'en',
|
||||||
'private_key':
|
'private_key':
|
||||||
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7'
|
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7'
|
||||||
'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
|
'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
|
||||||
|
@ -120,15 +121,14 @@ class TestWalletCreation(WalletTestCase):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
storage = WalletStorage(default=wallet_dict)
|
wallet = await Wallet.from_dict('wallet1', wallet_dict, self.db)
|
||||||
wallet = Wallet.from_storage(self.ledger, self.db, storage)
|
|
||||||
self.assertEqual(wallet.name, 'Main Wallet')
|
self.assertEqual(wallet.name, 'Main Wallet')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
hexlify(wallet.hash),
|
hexlify(wallet.hash),
|
||||||
b'3b23aae8cd9b360f4296130b8f7afc5b2437560cdef7237bed245288ce8a5f79'
|
b'64a32cf8434a59c547abf61b4691a8189ac24272678b52ced2310fbf93eac974'
|
||||||
)
|
)
|
||||||
self.assertEqual(len(wallet.accounts), 1)
|
self.assertEqual(len(wallet.accounts), 1)
|
||||||
account = wallet.default_account
|
account = wallet.accounts.default
|
||||||
self.assertIsInstance(account, Account)
|
self.assertIsInstance(account, Account)
|
||||||
self.maxDiff = None
|
self.maxDiff = None
|
||||||
self.assertDictEqual(wallet_dict, wallet.to_dict())
|
self.assertDictEqual(wallet_dict, wallet.to_dict())
|
||||||
|
@ -137,41 +137,23 @@ class TestWalletCreation(WalletTestCase):
|
||||||
decrypted = Wallet.unpack('password', encrypted)
|
decrypted = Wallet.unpack('password', encrypted)
|
||||||
self.assertEqual(decrypted['accounts'][0]['name'], 'An Account')
|
self.assertEqual(decrypted['accounts'][0]['name'], 'An Account')
|
||||||
|
|
||||||
def test_read_write(self):
|
async def test_merge(self):
|
||||||
manager = WalletManager(self.ledger, self.db)
|
wallet1 = Wallet('wallet1', self.db)
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file:
|
|
||||||
wallet_file.write(b'{"version": 1}')
|
|
||||||
wallet_file.seek(0)
|
|
||||||
|
|
||||||
# create and write wallet to a file
|
|
||||||
wallet = manager.import_wallet(wallet_file.name)
|
|
||||||
account = wallet.generate_account()
|
|
||||||
wallet.save()
|
|
||||||
|
|
||||||
# read wallet from file
|
|
||||||
wallet_storage = WalletStorage(wallet_file.name)
|
|
||||||
wallet = Wallet.from_storage(self.ledger, self.db, wallet_storage)
|
|
||||||
|
|
||||||
self.assertEqual(account.public_key.address, wallet.default_account.public_key.address)
|
|
||||||
|
|
||||||
def test_merge(self):
|
|
||||||
wallet1 = Wallet(self.ledger, self.db)
|
|
||||||
wallet1.preferences['one'] = 1
|
wallet1.preferences['one'] = 1
|
||||||
wallet1.preferences['conflict'] = 1
|
wallet1.preferences['conflict'] = 1
|
||||||
wallet1.generate_account()
|
await wallet1.accounts.generate()
|
||||||
wallet2 = Wallet(self.ledger, self.db)
|
wallet2 = Wallet('wallet2', self.db)
|
||||||
wallet2.preferences['two'] = 2
|
wallet2.preferences['two'] = 2
|
||||||
wallet2.preferences['conflict'] = 2 # will be more recent
|
wallet2.preferences['conflict'] = 2 # will be more recent
|
||||||
wallet2.generate_account()
|
await wallet2.accounts.generate()
|
||||||
|
|
||||||
self.assertEqual(len(wallet1.accounts), 1)
|
self.assertEqual(len(wallet1.accounts), 1)
|
||||||
self.assertEqual(wallet1.preferences, {'one': 1, 'conflict': 1})
|
self.assertEqual(wallet1.preferences, {'one': 1, 'conflict': 1})
|
||||||
|
|
||||||
added = await wallet1.merge('password', wallet2.pack('password'))
|
added = await wallet1.merge('password', wallet2.pack('password'))
|
||||||
self.assertEqual(added[0].id, wallet2.default_account.id)
|
self.assertEqual(added[0].id, wallet2.accounts.default.id)
|
||||||
self.assertEqual(len(wallet1.accounts), 2)
|
self.assertEqual(len(wallet1.accounts), 2)
|
||||||
self.assertEqual(wallet1.accounts[1].id, wallet2.default_account.id)
|
self.assertEqual(list(wallet1.accounts)[1].id, wallet2.accounts.default.id)
|
||||||
self.assertEqual(wallet1.preferences, {'one': 1, 'two': 2, 'conflict': 2})
|
self.assertEqual(wallet1.preferences, {'one': 1, 'two': 2, 'conflict': 2})
|
||||||
|
|
||||||
|
|
||||||
|
@ -213,3 +195,138 @@ class TestTimestampedPreferences(TestCase):
|
||||||
p1['conflict'] = 1
|
p1['conflict'] = 1
|
||||||
p1.merge(p2.data)
|
p1.merge(p2.data)
|
||||||
self.assertEqual(p1, {'one': 1, 'two': 2, 'conflict': 1})
|
self.assertEqual(p1, {'one': 1, 'two': 2, 'conflict': 1})
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransactionSigning(WalletTestCase):
|
||||||
|
|
||||||
|
async def test_sign(self):
|
||||||
|
wallet = Wallet('wallet1', self.db)
|
||||||
|
account = await wallet.accounts.add_from_dict({
|
||||||
|
"seed":
|
||||||
|
"carbon smart garage balance margin twelve chest sword toas"
|
||||||
|
"t envelope bottom stomach absent"
|
||||||
|
})
|
||||||
|
|
||||||
|
await account.ensure_address_gap()
|
||||||
|
address1, address2 = await account.receiving.get_addresses(limit=2)
|
||||||
|
pubkey_hash1 = self.ledger.address_to_hash160(address1)
|
||||||
|
pubkey_hash2 = self.ledger.address_to_hash160(address2)
|
||||||
|
|
||||||
|
tx = Transaction() \
|
||||||
|
.add_inputs([Input.spend(get_output(int(2*COIN), pubkey_hash1))]) \
|
||||||
|
.add_outputs([Output.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)])
|
||||||
|
|
||||||
|
await wallet.sign(tx)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
hexlify(tx.inputs[0].script.values['signature']),
|
||||||
|
b'304402200dafa26ad7cf38c5a971c8a25ce7d85a076235f146126762296b1223c42ae21e022020ef9eeb8'
|
||||||
|
b'398327891008c5c0be4357683f12cb22346691ff23914f457bf679601'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionIOBalancing(WalletTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
await super().asyncSetUp()
|
||||||
|
self.wallet = Wallet('wallet1', self.db)
|
||||||
|
self.account = await self.wallet.accounts.add_from_dict({
|
||||||
|
"seed":
|
||||||
|
"carbon smart garage balance margin twelve chest sword toas"
|
||||||
|
"t envelope bottom stomach absent"
|
||||||
|
})
|
||||||
|
addresses = await self.account.ensure_address_gap()
|
||||||
|
self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses]
|
||||||
|
self.hash_cycler = cycle(self.pubkey_hash)
|
||||||
|
|
||||||
|
def txo(self, amount, address=None):
|
||||||
|
return get_output(int(amount*COIN), address or next(self.hash_cycler))
|
||||||
|
|
||||||
|
def txi(self, txo):
|
||||||
|
return Input.spend(txo)
|
||||||
|
|
||||||
|
def tx(self, inputs, outputs):
|
||||||
|
return self.wallet.create_transaction(inputs, outputs, [self.account], self.account)
|
||||||
|
|
||||||
|
async def create_utxos(self, amounts):
|
||||||
|
utxos = [self.txo(amount) for amount in amounts]
|
||||||
|
|
||||||
|
self.funding_tx = Transaction(is_verified=True) \
|
||||||
|
.add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \
|
||||||
|
.add_outputs(utxos)
|
||||||
|
|
||||||
|
await self.db.insert_transaction(b'beef', self.funding_tx)
|
||||||
|
|
||||||
|
return utxos
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def inputs(tx):
|
||||||
|
return [round(i.amount/COIN, 2) for i in tx.inputs]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def outputs(tx):
|
||||||
|
return [round(o.amount/COIN, 2) for o in tx.outputs]
|
||||||
|
|
||||||
|
async def test_basic_use_cases(self):
|
||||||
|
self.ledger.fee_per_byte = int(.01*CENT)
|
||||||
|
|
||||||
|
# available UTXOs for filling missing inputs
|
||||||
|
utxos = await self.create_utxos([
|
||||||
|
1, 1, 3, 5, 10
|
||||||
|
])
|
||||||
|
|
||||||
|
# pay 3 coins (3.02 w/ fees)
|
||||||
|
tx = await self.tx(
|
||||||
|
[], # inputs
|
||||||
|
[self.txo(3)] # outputs
|
||||||
|
)
|
||||||
|
# best UTXO match is 5 (as UTXO 3 will be short 0.02 to cover fees)
|
||||||
|
self.assertListEqual(self.inputs(tx), [5])
|
||||||
|
# a change of 1.98 is added to reach balance
|
||||||
|
self.assertListEqual(self.outputs(tx), [3, 1.98])
|
||||||
|
|
||||||
|
await self.db.release_outputs(utxos)
|
||||||
|
|
||||||
|
# pay 2.98 coins (3.00 w/ fees)
|
||||||
|
tx = await self.tx(
|
||||||
|
[], # inputs
|
||||||
|
[self.txo(2.98)] # outputs
|
||||||
|
)
|
||||||
|
# best UTXO match is 3 and no change is needed
|
||||||
|
self.assertListEqual(self.inputs(tx), [3])
|
||||||
|
self.assertListEqual(self.outputs(tx), [2.98])
|
||||||
|
|
||||||
|
await self.db.release_outputs(utxos)
|
||||||
|
|
||||||
|
# supplied input and output, but input is not enough to cover output
|
||||||
|
tx = await self.tx(
|
||||||
|
[self.txi(self.txo(10))], # inputs
|
||||||
|
[self.txo(11)] # outputs
|
||||||
|
)
|
||||||
|
# additional input is chosen (UTXO 3)
|
||||||
|
self.assertListEqual([10, 3], self.inputs(tx))
|
||||||
|
# change is now needed to consume extra input
|
||||||
|
self.assertListEqual([11, 1.96], self.outputs(tx))
|
||||||
|
|
||||||
|
await self.db.release_outputs(utxos)
|
||||||
|
|
||||||
|
# liquidating a UTXO
|
||||||
|
tx = await self.tx(
|
||||||
|
[self.txi(self.txo(10))], # inputs
|
||||||
|
[] # outputs
|
||||||
|
)
|
||||||
|
self.assertListEqual([10], self.inputs(tx))
|
||||||
|
# missing change added to consume the amount
|
||||||
|
self.assertListEqual([9.98], self.outputs(tx))
|
||||||
|
|
||||||
|
await self.db.release_outputs(utxos)
|
||||||
|
|
||||||
|
# liquidating at a loss, requires adding extra inputs
|
||||||
|
tx = await self.tx(
|
||||||
|
[self.txi(self.txo(0.01))], # inputs
|
||||||
|
[] # outputs
|
||||||
|
)
|
||||||
|
# UTXO 1 is added to cover some of the fee
|
||||||
|
self.assertListEqual([0.01, 1], self.inputs(tx))
|
||||||
|
# change is now needed to consume extra input
|
||||||
|
self.assertListEqual([0.97], self.outputs(tx))
|
||||||
|
|
Loading…
Reference in a new issue