merged torba base classes with lbry sub-classes
This commit is contained in:
parent
c8d72b59c0
commit
c9e410a6f4
23 changed files with 0 additions and 3475 deletions
|
@ -1,485 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import random
|
||||
import typing
|
||||
from typing import Dict, Tuple, Type, Optional, Any, List
|
||||
|
||||
from lbry.crypto.hash import sha256
|
||||
from lbry.crypto.crypt import aes_encrypt, aes_decrypt
|
||||
from lbry.wallet.client.bip32 import PrivateKey, PubKey, from_extended_key_string
|
||||
from lbry.wallet.client.mnemonic import Mnemonic
|
||||
from lbry.wallet.client.constants import COIN
|
||||
from lbry.error import InvalidPasswordError
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbry.wallet.client import baseledger, wallet as basewallet
|
||||
|
||||
|
||||
class AddressManager:
|
||||
|
||||
name: str
|
||||
|
||||
__slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock'
|
||||
|
||||
def __init__(self, account, public_key, chain_number):
|
||||
self.account = account
|
||||
self.public_key = public_key
|
||||
self.chain_number = chain_number
|
||||
self.address_generator_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, account: 'BaseAccount', d: dict) \
|
||||
-> Tuple['AddressManager', 'AddressManager']:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def to_dict(cls, receiving: 'AddressManager', change: 'AddressManager') -> Dict:
|
||||
d: Dict[str, Any] = {'name': cls.name}
|
||||
receiving_dict = receiving.to_dict_instance()
|
||||
if receiving_dict:
|
||||
d['receiving'] = receiving_dict
|
||||
change_dict = change.to_dict_instance()
|
||||
if change_dict:
|
||||
d['change'] = change_dict
|
||||
return d
|
||||
|
||||
def merge(self, d: dict):
|
||||
pass
|
||||
|
||||
def to_dict_instance(self) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _query_addresses(self, **constraints):
|
||||
return self.account.ledger.db.get_addresses(
|
||||
accounts=[self.account],
|
||||
chain=self.chain_number,
|
||||
**constraints
|
||||
)
|
||||
|
||||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_public_key(self, index: int) -> PubKey:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_max_gap(self):
|
||||
raise NotImplementedError
|
||||
|
||||
async def ensure_address_gap(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]:
|
||||
records = await self.get_address_records(only_usable=only_usable, **constraints)
|
||||
return [r['address'] for r in records]
|
||||
|
||||
async def get_or_create_usable_address(self) -> str:
|
||||
addresses = await self.get_addresses(only_usable=True, limit=10)
|
||||
if addresses:
|
||||
return random.choice(addresses)
|
||||
addresses = await self.ensure_address_gap()
|
||||
return addresses[0]
|
||||
|
||||
|
||||
class HierarchicalDeterministic(AddressManager):
|
||||
""" Implements simple version of Bitcoin Hierarchical Deterministic key management. """
|
||||
|
||||
name: str = "deterministic-chain"
|
||||
|
||||
__slots__ = 'gap', 'maximum_uses_per_address'
|
||||
|
||||
def __init__(self, account: 'BaseAccount', chain: int, gap: int, maximum_uses_per_address: int) -> None:
|
||||
super().__init__(account, account.public_key.child(chain), chain)
|
||||
self.gap = gap
|
||||
self.maximum_uses_per_address = maximum_uses_per_address
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, account: 'BaseAccount', d: dict) -> Tuple[AddressManager, AddressManager]:
|
||||
return (
|
||||
cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
|
||||
cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
|
||||
)
|
||||
|
||||
def merge(self, d: dict):
|
||||
self.gap = d.get('gap', self.gap)
|
||||
self.maximum_uses_per_address = d.get('maximum_uses_per_address', self.maximum_uses_per_address)
|
||||
|
||||
def to_dict_instance(self):
|
||||
return {'gap': self.gap, 'maximum_uses_per_address': self.maximum_uses_per_address}
|
||||
|
||||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
return self.account.private_key.child(self.chain_number).child(index)
|
||||
|
||||
def get_public_key(self, index: int) -> PubKey:
|
||||
return self.account.public_key.child(self.chain_number).child(index)
|
||||
|
||||
async def get_max_gap(self) -> int:
|
||||
addresses = await self._query_addresses(order_by="n asc")
|
||||
max_gap = 0
|
||||
current_gap = 0
|
||||
for address in addresses:
|
||||
if address['used_times'] == 0:
|
||||
current_gap += 1
|
||||
else:
|
||||
max_gap = max(max_gap, current_gap)
|
||||
current_gap = 0
|
||||
return max_gap
|
||||
|
||||
async def ensure_address_gap(self) -> List[str]:
|
||||
async with self.address_generator_lock:
|
||||
addresses = await self._query_addresses(limit=self.gap, order_by="n desc")
|
||||
|
||||
existing_gap = 0
|
||||
for address in addresses:
|
||||
if address['used_times'] == 0:
|
||||
existing_gap += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if existing_gap == self.gap:
|
||||
return []
|
||||
|
||||
start = addresses[0]['pubkey'].n+1 if addresses else 0
|
||||
end = start + (self.gap - existing_gap)
|
||||
new_keys = await self._generate_keys(start, end-1)
|
||||
await self.account.ledger.announce_addresses(self, new_keys)
|
||||
return new_keys
|
||||
|
||||
async def _generate_keys(self, start: int, end: int) -> List[str]:
|
||||
if not self.address_generator_lock.locked():
|
||||
raise RuntimeError('Should not be called outside of address_generator_lock.')
|
||||
keys = [self.public_key.child(index) for index in range(start, end+1)]
|
||||
await self.account.ledger.db.add_keys(self.account, self.chain_number, keys)
|
||||
return [key.address for key in keys]
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
if only_usable:
|
||||
constraints['used_times__lt'] = self.maximum_uses_per_address
|
||||
if 'order_by' not in constraints:
|
||||
constraints['order_by'] = "used_times asc, n asc"
|
||||
return self._query_addresses(**constraints)
|
||||
|
||||
|
||||
class SingleKey(AddressManager):
|
||||
""" Single Key address manager always returns the same address for all operations. """
|
||||
|
||||
name: str = "single-address"
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, account: 'BaseAccount', d: dict)\
|
||||
-> Tuple[AddressManager, AddressManager]:
|
||||
same_address_manager = cls(account, account.public_key, 0)
|
||||
return same_address_manager, same_address_manager
|
||||
|
||||
def to_dict_instance(self):
|
||||
return None
|
||||
|
||||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
return self.account.private_key
|
||||
|
||||
def get_public_key(self, index: int) -> PubKey:
|
||||
return self.account.public_key
|
||||
|
||||
async def get_max_gap(self) -> int:
|
||||
return 0
|
||||
|
||||
async def ensure_address_gap(self) -> List[str]:
|
||||
async with self.address_generator_lock:
|
||||
exists = await self.get_address_records()
|
||||
if not exists:
|
||||
await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key])
|
||||
new_keys = [self.public_key.address]
|
||||
await self.account.ledger.announce_addresses(self, new_keys)
|
||||
return new_keys
|
||||
return []
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
return self._query_addresses(**constraints)
|
||||
|
||||
|
||||
class BaseAccount:
|
||||
|
||||
mnemonic_class = Mnemonic
|
||||
private_key_class = PrivateKey
|
||||
public_key_class = PubKey
|
||||
address_generators: Dict[str, Type[AddressManager]] = {
|
||||
SingleKey.name: SingleKey,
|
||||
HierarchicalDeterministic.name: HierarchicalDeterministic,
|
||||
}
|
||||
|
||||
def __init__(self, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str,
|
||||
seed: str, private_key_string: str, encrypted: bool,
|
||||
private_key: Optional[PrivateKey], public_key: PubKey,
|
||||
address_generator: dict, modified_on: float) -> None:
|
||||
self.ledger = ledger
|
||||
self.wallet = wallet
|
||||
self.id = public_key.address
|
||||
self.name = name
|
||||
self.seed = seed
|
||||
self.modified_on = modified_on
|
||||
self.private_key_string = private_key_string
|
||||
self.init_vectors: Dict[str, bytes] = {}
|
||||
self.encrypted = encrypted
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
generator_name = address_generator.get('name', HierarchicalDeterministic.name)
|
||||
self.address_generator = self.address_generators[generator_name]
|
||||
self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
|
||||
self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
|
||||
ledger.add_account(self)
|
||||
wallet.add_account(self)
|
||||
|
||||
def get_init_vector(self, key) -> Optional[bytes]:
|
||||
init_vector = self.init_vectors.get(key, None)
|
||||
if init_vector is None:
|
||||
init_vector = self.init_vectors[key] = os.urandom(16)
|
||||
return init_vector
|
||||
|
||||
@classmethod
|
||||
def generate(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet',
|
||||
name: str = None, address_generator: dict = None):
|
||||
return cls.from_dict(ledger, wallet, {
|
||||
'name': name,
|
||||
'seed': cls.mnemonic_class().make_seed(),
|
||||
'address_generator': address_generator or {}
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_private_key_from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str):
|
||||
return cls.private_key_class.from_seed(
|
||||
ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def keys_from_dict(cls, ledger: 'baseledger.BaseLedger', d: dict) \
|
||||
-> Tuple[str, Optional[PrivateKey], PubKey]:
|
||||
seed = d.get('seed', '')
|
||||
private_key_string = d.get('private_key', '')
|
||||
private_key = None
|
||||
public_key = None
|
||||
encrypted = d.get('encrypted', False)
|
||||
if not encrypted:
|
||||
if seed:
|
||||
private_key = cls.get_private_key_from_seed(ledger, seed, '')
|
||||
public_key = private_key.public_key
|
||||
elif private_key_string:
|
||||
private_key = from_extended_key_string(ledger, private_key_string)
|
||||
public_key = private_key.public_key
|
||||
if public_key is None:
|
||||
public_key = from_extended_key_string(ledger, d['public_key'])
|
||||
return seed, private_key, public_key
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', d: dict):
|
||||
seed, private_key, public_key = cls.keys_from_dict(ledger, d)
|
||||
name = d.get('name')
|
||||
if not name:
|
||||
name = f'Account #{public_key.address}'
|
||||
return cls(
|
||||
ledger=ledger,
|
||||
wallet=wallet,
|
||||
name=name,
|
||||
seed=seed,
|
||||
private_key_string=d.get('private_key', ''),
|
||||
encrypted=d.get('encrypted', False),
|
||||
private_key=private_key,
|
||||
public_key=public_key,
|
||||
address_generator=d.get('address_generator', {}),
|
||||
modified_on=d.get('modified_on', time.time())
|
||||
)
|
||||
|
||||
def to_dict(self, encrypt_password: str = None):
|
||||
private_key_string, seed = self.private_key_string, self.seed
|
||||
if not self.encrypted and self.private_key:
|
||||
private_key_string = self.private_key.extended_key_string()
|
||||
if not self.encrypted and encrypt_password:
|
||||
if private_key_string:
|
||||
private_key_string = aes_encrypt(
|
||||
encrypt_password, private_key_string, self.get_init_vector('private_key')
|
||||
)
|
||||
if seed:
|
||||
seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed'))
|
||||
return {
|
||||
'ledger': self.ledger.get_id(),
|
||||
'name': self.name,
|
||||
'seed': seed,
|
||||
'encrypted': bool(self.encrypted or encrypt_password),
|
||||
'private_key': private_key_string,
|
||||
'public_key': self.public_key.extended_key_string(),
|
||||
'address_generator': self.address_generator.to_dict(self.receiving, self.change),
|
||||
'modified_on': self.modified_on
|
||||
}
|
||||
|
||||
def merge(self, d: dict):
|
||||
if d.get('modified_on', 0) > self.modified_on:
|
||||
self.name = d['name']
|
||||
self.modified_on = d.get('modified_on', time.time())
|
||||
assert self.address_generator.name == d['address_generator']['name']
|
||||
for chain_name in ('change', 'receiving'):
|
||||
if chain_name in d['address_generator']:
|
||||
chain_object = getattr(self, chain_name)
|
||||
chain_object.merge(d['address_generator'][chain_name])
|
||||
|
||||
@property
|
||||
def hash(self) -> bytes:
|
||||
assert not self.encrypted, "Cannot hash an encrypted account."
|
||||
return sha256(json.dumps(self.to_dict()).encode())
|
||||
|
||||
async def get_details(self, show_seed=False, **kwargs):
|
||||
satoshis = await self.get_balance(**kwargs)
|
||||
details = {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
'ledger': self.ledger.get_id(),
|
||||
'coins': round(satoshis/COIN, 2),
|
||||
'satoshis': satoshis,
|
||||
'encrypted': self.encrypted,
|
||||
'public_key': self.public_key.extended_key_string(),
|
||||
'address_generator': self.address_generator.to_dict(self.receiving, self.change)
|
||||
}
|
||||
if show_seed:
|
||||
details['seed'] = self.seed
|
||||
return details
|
||||
|
||||
def decrypt(self, password: str) -> bool:
|
||||
assert self.encrypted, "Key is not encrypted."
|
||||
try:
|
||||
seed = self._decrypt_seed(password)
|
||||
except (ValueError, InvalidPasswordError):
|
||||
return False
|
||||
try:
|
||||
private_key = self._decrypt_private_key_string(password)
|
||||
except (TypeError, ValueError, InvalidPasswordError):
|
||||
return False
|
||||
self.seed = seed
|
||||
self.private_key = private_key
|
||||
self.private_key_string = ""
|
||||
self.encrypted = False
|
||||
return True
|
||||
|
||||
def _decrypt_private_key_string(self, password: str) -> Optional[PrivateKey]:
|
||||
if not self.private_key_string:
|
||||
return None
|
||||
private_key_string, self.init_vectors['private_key'] = aes_decrypt(password, self.private_key_string)
|
||||
if not private_key_string:
|
||||
return None
|
||||
return from_extended_key_string(
|
||||
self.ledger, private_key_string
|
||||
)
|
||||
|
||||
def _decrypt_seed(self, password: str) -> str:
|
||||
if not self.seed:
|
||||
return ""
|
||||
seed, self.init_vectors['seed'] = aes_decrypt(password, self.seed)
|
||||
if not seed:
|
||||
return ""
|
||||
try:
|
||||
Mnemonic().mnemonic_decode(seed)
|
||||
except IndexError:
|
||||
# failed to decode the seed, this either means it decrypted and is invalid
|
||||
# or that we hit an edge case where an incorrect password gave valid padding
|
||||
raise ValueError("Failed to decode seed.")
|
||||
return seed
|
||||
|
||||
def encrypt(self, password: str) -> bool:
|
||||
assert not self.encrypted, "Key is already encrypted."
|
||||
if self.seed:
|
||||
self.seed = aes_encrypt(password, self.seed, self.get_init_vector('seed'))
|
||||
if isinstance(self.private_key, PrivateKey):
|
||||
self.private_key_string = aes_encrypt(
|
||||
password, self.private_key.extended_key_string(), self.get_init_vector('private_key')
|
||||
)
|
||||
self.private_key = None
|
||||
self.encrypted = True
|
||||
return True
|
||||
|
||||
async def ensure_address_gap(self):
|
||||
addresses = []
|
||||
for address_manager in self.address_managers.values():
|
||||
new_addresses = await address_manager.ensure_address_gap()
|
||||
addresses.extend(new_addresses)
|
||||
return addresses
|
||||
|
||||
async def get_addresses(self, **constraints) -> List[str]:
|
||||
rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints)
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def get_address_records(self, **constraints):
|
||||
return self.ledger.db.get_addresses(accounts=[self], **constraints)
|
||||
|
||||
def get_address_count(self, **constraints):
|
||||
return self.ledger.db.get_address_count(accounts=[self], **constraints)
|
||||
|
||||
def get_private_key(self, chain: int, index: int) -> PrivateKey:
|
||||
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
|
||||
return self.address_managers[chain].get_private_key(index)
|
||||
|
||||
def get_public_key(self, chain: int, index: int) -> PubKey:
|
||||
return self.address_managers[chain].get_public_key(index)
|
||||
|
||||
def get_balance(self, confirmations: int = 0, **constraints):
|
||||
if confirmations > 0:
|
||||
height = self.ledger.headers.height - (confirmations-1)
|
||||
constraints.update({'height__lte': height, 'height__gt': 0})
|
||||
return self.ledger.db.get_balance(accounts=[self], **constraints)
|
||||
|
||||
async def get_max_gap(self):
|
||||
change_gap = await self.change.get_max_gap()
|
||||
receiving_gap = await self.receiving.get_max_gap()
|
||||
return {
|
||||
'max_change_gap': change_gap,
|
||||
'max_receiving_gap': receiving_gap,
|
||||
}
|
||||
|
||||
def get_utxos(self, **constraints):
|
||||
return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_utxo_count(self, **constraints):
|
||||
return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_transactions(self, **constraints):
|
||||
return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
def get_transaction_count(self, **constraints):
|
||||
return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints)
|
||||
|
||||
async def fund(self, to_account, amount=None, everything=False,
|
||||
outputs=1, broadcast=False, **constraints):
|
||||
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
|
||||
tx_class = self.ledger.transaction_class
|
||||
if everything:
|
||||
utxos = await self.get_utxos(**constraints)
|
||||
await self.ledger.reserve_outputs(utxos)
|
||||
tx = await tx_class.create(
|
||||
inputs=[tx_class.input_class.spend(txo) for txo in utxos],
|
||||
outputs=[],
|
||||
funding_accounts=[self],
|
||||
change_account=to_account
|
||||
)
|
||||
elif amount > 0:
|
||||
to_address = await to_account.change.get_or_create_usable_address()
|
||||
to_hash160 = to_account.ledger.address_to_hash160(to_address)
|
||||
tx = await tx_class.create(
|
||||
inputs=[],
|
||||
outputs=[
|
||||
tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160)
|
||||
for _ in range(outputs)
|
||||
],
|
||||
funding_accounts=[self],
|
||||
change_account=self
|
||||
)
|
||||
else:
|
||||
raise ValueError('An amount is required.')
|
||||
|
||||
if broadcast:
|
||||
await self.ledger.broadcast(tx)
|
||||
else:
|
||||
await self.ledger.release_tx(tx)
|
||||
|
||||
return tx
|
|
@ -1,652 +0,0 @@
|
|||
import logging
|
||||
import asyncio
|
||||
from binascii import hexlify
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
|
||||
|
||||
import sqlite3
|
||||
|
||||
from lbry.wallet.client.basetransaction import BaseTransaction, TXRefImmutable
|
||||
from lbry.wallet.client.bip32 import PubKey
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
sqlite3.enable_callback_tracebacks(True)
|
||||
|
||||
|
||||
class AIOSQLite:
|
||||
|
||||
def __init__(self):
|
||||
# has to be single threaded as there is no mapping of thread:connection
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.connection: sqlite3.Connection = None
|
||||
self._closing = False
|
||||
self.query_count = 0
|
||||
|
||||
@classmethod
|
||||
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
|
||||
sqlite3.enable_callback_tracebacks(True)
|
||||
def _connect():
|
||||
return sqlite3.connect(path, *args, **kwargs)
|
||||
db = cls()
|
||||
db.connection = await asyncio.get_event_loop().run_in_executor(db.executor, _connect)
|
||||
return db
|
||||
|
||||
async def close(self):
|
||||
if self._closing:
|
||||
return
|
||||
self._closing = True
|
||||
await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close)
|
||||
self.executor.shutdown(wait=True)
|
||||
self.connection = None
|
||||
|
||||
def executemany(self, sql: str, params: Iterable):
|
||||
params = params if params is not None else []
|
||||
# this fetchall is needed to prevent SQLITE_MISUSE
|
||||
return self.run(lambda conn: conn.executemany(sql, params).fetchall())
|
||||
|
||||
def executescript(self, script: str) -> Awaitable:
|
||||
return self.run(lambda conn: conn.executescript(script))
|
||||
|
||||
def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn: conn.execute(sql, parameters).fetchall())
|
||||
|
||||
def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn: conn.execute(sql, parameters).fetchone())
|
||||
|
||||
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn: conn.execute(sql, parameters))
|
||||
|
||||
def run(self, fun, *args, **kwargs) -> Awaitable:
|
||||
return asyncio.get_event_loop().run_in_executor(
|
||||
self.executor, lambda: self.__run_transaction(fun, *args, **kwargs)
|
||||
)
|
||||
|
||||
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
|
||||
self.connection.execute('begin')
|
||||
try:
|
||||
self.query_count += 1
|
||||
result = fun(self.connection, *args, **kwargs) # type: ignore
|
||||
self.connection.commit()
|
||||
return result
|
||||
except (Exception, OSError) as e:
|
||||
log.exception('Error running transaction:', exc_info=e)
|
||||
self.connection.rollback()
|
||||
log.warning("rolled back")
|
||||
raise
|
||||
|
||||
def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
|
||||
return asyncio.get_event_loop().run_in_executor(
|
||||
self.executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
|
||||
)
|
||||
|
||||
def __run_transaction_with_foreign_keys_disabled(self,
|
||||
fun: Callable[[sqlite3.Connection, Any, Any], Any],
|
||||
args, kwargs):
|
||||
foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone()
|
||||
if not foreign_keys_enabled:
|
||||
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
|
||||
try:
|
||||
self.connection.execute('pragma foreign_keys=off').fetchone()
|
||||
return self.__run_transaction(fun, *args, **kwargs)
|
||||
finally:
|
||||
self.connection.execute('pragma foreign_keys=on').fetchone()
|
||||
|
||||
|
||||
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
||||
sql, values = [], {}
|
||||
for key, constraint in constraints.items():
|
||||
tag = '0'
|
||||
if '#' in key:
|
||||
key, tag = key[:key.index('#')], key[key.index('#')+1:]
|
||||
col, op, key = key, '=', key.replace('.', '_')
|
||||
if not key:
|
||||
sql.append(constraint)
|
||||
continue
|
||||
if key.startswith('$'):
|
||||
values[key] = constraint
|
||||
continue
|
||||
if key.endswith('__not'):
|
||||
col, op = col[:-len('__not')], '!='
|
||||
elif key.endswith('__is_null'):
|
||||
col = col[:-len('__is_null')]
|
||||
sql.append(f'{col} IS NULL')
|
||||
continue
|
||||
if key.endswith('__is_not_null'):
|
||||
col = col[:-len('__is_not_null')]
|
||||
sql.append(f'{col} IS NOT NULL')
|
||||
continue
|
||||
if key.endswith('__lt'):
|
||||
col, op = col[:-len('__lt')], '<'
|
||||
elif key.endswith('__lte'):
|
||||
col, op = col[:-len('__lte')], '<='
|
||||
elif key.endswith('__gt'):
|
||||
col, op = col[:-len('__gt')], '>'
|
||||
elif key.endswith('__gte'):
|
||||
col, op = col[:-len('__gte')], '>='
|
||||
elif key.endswith('__like'):
|
||||
col, op = col[:-len('__like')], 'LIKE'
|
||||
elif key.endswith('__not_like'):
|
||||
col, op = col[:-len('__not_like')], 'NOT LIKE'
|
||||
elif key.endswith('__in') or key.endswith('__not_in'):
|
||||
if key.endswith('__in'):
|
||||
col, op = col[:-len('__in')], 'IN'
|
||||
else:
|
||||
col, op = col[:-len('__not_in')], 'NOT IN'
|
||||
if constraint:
|
||||
if isinstance(constraint, (list, set, tuple)):
|
||||
keys = []
|
||||
for i, val in enumerate(constraint):
|
||||
keys.append(f':{key}{tag}_{i}')
|
||||
values[f'{key}{tag}_{i}'] = val
|
||||
sql.append(f'{col} {op} ({", ".join(keys)})')
|
||||
elif isinstance(constraint, str):
|
||||
sql.append(f'{col} {op} ({constraint})')
|
||||
else:
|
||||
raise ValueError(f"{col} requires a list, set or string as constraint value.")
|
||||
continue
|
||||
elif key.endswith('__any') or key.endswith('__or'):
|
||||
where, subvalues = constraints_to_sql(constraint, ' OR ', key+tag+'_')
|
||||
sql.append(f'({where})')
|
||||
values.update(subvalues)
|
||||
continue
|
||||
if key.endswith('__and'):
|
||||
where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_')
|
||||
sql.append(f'({where})')
|
||||
values.update(subvalues)
|
||||
continue
|
||||
sql.append(f'{col} {op} :{prepend_key}{key}{tag}')
|
||||
values[prepend_key+key+tag] = constraint
|
||||
return joiner.join(sql) if sql else '', values
|
||||
|
||||
|
||||
def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
|
||||
sql = [select]
|
||||
limit = constraints.pop('limit', None)
|
||||
offset = constraints.pop('offset', None)
|
||||
order_by = constraints.pop('order_by', None)
|
||||
|
||||
accounts = constraints.pop('accounts', [])
|
||||
if accounts:
|
||||
constraints['account__in'] = [a.public_key.address for a in accounts]
|
||||
|
||||
where, values = constraints_to_sql(constraints)
|
||||
if where:
|
||||
sql.append('WHERE')
|
||||
sql.append(where)
|
||||
|
||||
if order_by:
|
||||
sql.append('ORDER BY')
|
||||
if isinstance(order_by, str):
|
||||
sql.append(order_by)
|
||||
elif isinstance(order_by, list):
|
||||
sql.append(', '.join(order_by))
|
||||
else:
|
||||
raise ValueError("order_by must be string or list")
|
||||
|
||||
if limit is not None:
|
||||
sql.append(f'LIMIT {limit}')
|
||||
|
||||
if offset is not None:
|
||||
sql.append(f'OFFSET {offset}')
|
||||
|
||||
return ' '.join(sql), values
|
||||
|
||||
|
||||
def interpolate(sql, values):
|
||||
for k in sorted(values.keys(), reverse=True):
|
||||
value = values[k]
|
||||
if isinstance(value, bytes):
|
||||
value = f"X'{hexlify(value).decode()}'"
|
||||
elif isinstance(value, str):
|
||||
value = f"'{value}'"
|
||||
else:
|
||||
value = str(value)
|
||||
sql = sql.replace(f":{k}", value)
|
||||
return sql
|
||||
|
||||
|
||||
def rows_to_dict(rows, fields):
|
||||
if rows:
|
||||
return [dict(zip(fields, r)) for r in rows]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
class SQLiteMixin:
|
||||
|
||||
SCHEMA_VERSION: Optional[str] = None
|
||||
CREATE_TABLES_QUERY: str
|
||||
MAX_QUERY_VARIABLES = 900
|
||||
|
||||
CREATE_VERSION_TABLE = """
|
||||
create table if not exists version (
|
||||
version text
|
||||
);
|
||||
"""
|
||||
|
||||
def __init__(self, path):
|
||||
self._db_path = path
|
||||
self.db: AIOSQLite = None
|
||||
self.ledger = None
|
||||
|
||||
async def open(self):
|
||||
log.info("connecting to database: %s", self._db_path)
|
||||
self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
|
||||
if self.SCHEMA_VERSION:
|
||||
tables = [t[0] for t in await self.db.execute_fetchall(
|
||||
"SELECT name FROM sqlite_master WHERE type='table';"
|
||||
)]
|
||||
if tables:
|
||||
if 'version' in tables:
|
||||
version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;")
|
||||
if version == (self.SCHEMA_VERSION,):
|
||||
return
|
||||
await self.db.executescript('\n'.join(
|
||||
f"DROP TABLE {table};" for table in tables
|
||||
))
|
||||
await self.db.execute(self.CREATE_VERSION_TABLE)
|
||||
await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,))
|
||||
await self.db.executescript(self.CREATE_TABLES_QUERY)
|
||||
|
||||
async def close(self):
|
||||
await self.db.close()
|
||||
|
||||
@staticmethod
|
||||
def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
|
||||
replace: bool = False) -> Tuple[str, List]:
|
||||
columns, values = [], []
|
||||
for column, value in data.items():
|
||||
columns.append(column)
|
||||
values.append(value)
|
||||
policy = ""
|
||||
if ignore_duplicate:
|
||||
policy = " OR IGNORE"
|
||||
if replace:
|
||||
policy = " OR REPLACE"
|
||||
sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
|
||||
policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
|
||||
)
|
||||
return sql, values
|
||||
|
||||
@staticmethod
|
||||
def _update_sql(table: str, data: dict, where: str,
|
||||
constraints: Union[list, tuple]) -> Tuple[str, list]:
|
||||
columns, values = [], []
|
||||
for column, value in data.items():
|
||||
columns.append(f"{column} = ?")
|
||||
values.append(value)
|
||||
values.extend(constraints)
|
||||
sql = "UPDATE {} SET {} WHERE {}".format(
|
||||
table, ', '.join(columns), where
|
||||
)
|
||||
return sql, values
|
||||
|
||||
|
||||
class BaseDatabase(SQLiteMixin):
|
||||
|
||||
SCHEMA_VERSION = "1.1"
|
||||
|
||||
PRAGMAS = """
|
||||
pragma journal_mode=WAL;
|
||||
"""
|
||||
|
||||
CREATE_ACCOUNT_TABLE = """
|
||||
create table if not exists account_address (
|
||||
account text not null,
|
||||
address text not null,
|
||||
chain integer not null,
|
||||
pubkey blob not null,
|
||||
chain_code blob not null,
|
||||
n integer not null,
|
||||
depth integer not null,
|
||||
primary key (account, address)
|
||||
);
|
||||
create index if not exists address_account_idx on account_address (address, account);
|
||||
"""
|
||||
|
||||
CREATE_PUBKEY_ADDRESS_TABLE = """
|
||||
create table if not exists pubkey_address (
|
||||
address text primary key,
|
||||
history text,
|
||||
used_times integer not null default 0
|
||||
);
|
||||
"""
|
||||
|
||||
CREATE_TX_TABLE = """
|
||||
create table if not exists tx (
|
||||
txid text primary key,
|
||||
raw blob not null,
|
||||
height integer not null,
|
||||
position integer not null,
|
||||
is_verified boolean not null default 0
|
||||
);
|
||||
"""
|
||||
|
||||
CREATE_TXO_TABLE = """
|
||||
create table if not exists txo (
|
||||
txid text references tx,
|
||||
txoid text primary key,
|
||||
address text references pubkey_address,
|
||||
position integer not null,
|
||||
amount integer not null,
|
||||
script blob not null,
|
||||
is_reserved boolean not null default 0
|
||||
);
|
||||
create index if not exists txo_address_idx on txo (address);
|
||||
"""
|
||||
|
||||
CREATE_TXI_TABLE = """
|
||||
create table if not exists txi (
|
||||
txid text references tx,
|
||||
txoid text references txo,
|
||||
address text references pubkey_address
|
||||
);
|
||||
create index if not exists txi_address_idx on txi (address);
|
||||
create index if not exists txi_txoid_idx on txi (txoid);
|
||||
"""
|
||||
|
||||
CREATE_TABLES_QUERY = (
|
||||
PRAGMAS +
|
||||
CREATE_ACCOUNT_TABLE +
|
||||
CREATE_PUBKEY_ADDRESS_TABLE +
|
||||
CREATE_TX_TABLE +
|
||||
CREATE_TXO_TABLE +
|
||||
CREATE_TXI_TABLE
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def txo_to_row(tx, address, txo):
|
||||
return {
|
||||
'txid': tx.id,
|
||||
'txoid': txo.id,
|
||||
'address': address,
|
||||
'position': txo.position,
|
||||
'amount': txo.amount,
|
||||
'script': sqlite3.Binary(txo.script.source)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def tx_to_row(tx):
|
||||
return {
|
||||
'txid': tx.id,
|
||||
'raw': sqlite3.Binary(tx.raw),
|
||||
'height': tx.height,
|
||||
'position': tx.position,
|
||||
'is_verified': tx.is_verified
|
||||
}
|
||||
|
||||
async def insert_transaction(self, tx):
|
||||
await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx)))
|
||||
|
||||
async def update_transaction(self, tx):
|
||||
await self.db.execute_fetchall(*self._update_sql("tx", {
|
||||
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
|
||||
}, 'txid = ?', (tx.id,)))
|
||||
|
||||
def _transaction_io(self, conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
|
||||
conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True))
|
||||
|
||||
for txo in tx.outputs:
|
||||
if txo.script.is_pay_pubkey_hash and txo.pubkey_hash == txhash:
|
||||
conn.execute(*self._insert_sql(
|
||||
"txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
|
||||
)).fetchall()
|
||||
elif txo.script.is_pay_script_hash:
|
||||
# TODO: implement script hash payments
|
||||
log.warning('Database.save_transaction_io: pay script hash is not implemented!')
|
||||
|
||||
for txi in tx.inputs:
|
||||
if txi.txo_ref.txo is not None:
|
||||
txo = txi.txo_ref.txo
|
||||
if txo.has_address and txo.get_address(self.ledger) == address:
|
||||
conn.execute(*self._insert_sql("txi", {
|
||||
'txid': tx.id,
|
||||
'txoid': txo.id,
|
||||
'address': address,
|
||||
}, ignore_duplicate=True)).fetchall()
|
||||
|
||||
conn.execute(
|
||||
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
(history, history.count(':') // 2, address)
|
||||
)
|
||||
|
||||
def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
|
||||
return self.db.run(self._transaction_io, tx, address, txhash, history)
|
||||
|
||||
def save_transaction_io_batch(self, txs: Iterable[BaseTransaction], address, txhash, history):
|
||||
def __many(conn):
|
||||
for tx in txs:
|
||||
self._transaction_io(conn, tx, address, txhash, history)
|
||||
return self.db.run(__many)
|
||||
|
||||
async def reserve_outputs(self, txos, is_reserved=True):
|
||||
txoids = ((is_reserved, txo.id) for txo in txos)
|
||||
await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids)
|
||||
|
||||
async def release_outputs(self, txos):
|
||||
await self.reserve_outputs(txos, is_reserved=False)
|
||||
|
||||
async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
|
||||
# TODO:
|
||||
# 1. delete transactions above_height
|
||||
# 2. update address histories removing deleted TXs
|
||||
return True
|
||||
|
||||
async def select_transactions(self, cols, accounts=None, **constraints):
|
||||
if not {'txid', 'txid__in'}.intersection(constraints):
|
||||
assert accounts, "'accounts' argument required when no 'txid' constraint is present"
|
||||
constraints.update({
|
||||
f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
|
||||
})
|
||||
account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
|
||||
where = f" WHERE account_address.account IN ({account_values})"
|
||||
constraints['txid__in'] = f"""
|
||||
SELECT txo.txid FROM txo JOIN account_address USING (address) {where}
|
||||
UNION
|
||||
SELECT txi.txid FROM txi JOIN account_address USING (address) {where}
|
||||
"""
|
||||
return await self.db.execute_fetchall(
|
||||
*query(f"SELECT {cols} FROM tx", **constraints)
|
||||
)
|
||||
|
||||
async def get_transactions(self, wallet=None, **constraints):
|
||||
tx_rows = await self.select_transactions(
|
||||
'txid, raw, height, position, is_verified',
|
||||
order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]),
|
||||
**constraints
|
||||
)
|
||||
|
||||
if not tx_rows:
|
||||
return []
|
||||
|
||||
txids, txs, txi_txoids = [], [], []
|
||||
for row in tx_rows:
|
||||
txids.append(row[0])
|
||||
txs.append(self.ledger.transaction_class(
|
||||
raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4])
|
||||
))
|
||||
for txi in txs[-1].inputs:
|
||||
txi_txoids.append(txi.txo_ref.id)
|
||||
|
||||
step = self.MAX_QUERY_VARIABLES
|
||||
annotated_txos = {}
|
||||
for offset in range(0, len(txids), step):
|
||||
annotated_txos.update({
|
||||
txo.id: txo for txo in
|
||||
(await self.get_txos(
|
||||
wallet=wallet,
|
||||
txid__in=txids[offset:offset+step],
|
||||
))
|
||||
})
|
||||
|
||||
referenced_txos = {}
|
||||
for offset in range(0, len(txi_txoids), step):
|
||||
referenced_txos.update({
|
||||
txo.id: txo for txo in
|
||||
(await self.get_txos(
|
||||
wallet=wallet,
|
||||
txoid__in=txi_txoids[offset:offset+step],
|
||||
))
|
||||
})
|
||||
|
||||
for tx in txs:
|
||||
for txi in tx.inputs:
|
||||
txo = referenced_txos.get(txi.txo_ref.id)
|
||||
if txo:
|
||||
txi.txo_ref = txo.ref
|
||||
for txo in tx.outputs:
|
||||
_txo = annotated_txos.get(txo.id)
|
||||
if _txo:
|
||||
txo.update_annotations(_txo)
|
||||
else:
|
||||
txo.update_annotations(None)
|
||||
|
||||
return txs
|
||||
|
||||
async def get_transaction_count(self, **constraints):
|
||||
constraints.pop('wallet', None)
|
||||
constraints.pop('offset', None)
|
||||
constraints.pop('limit', None)
|
||||
constraints.pop('order_by', None)
|
||||
count = await self.select_transactions('count(*)', **constraints)
|
||||
return count[0][0]
|
||||
|
||||
async def get_transaction(self, **constraints):
|
||||
txs = await self.get_transactions(limit=1, **constraints)
|
||||
if txs:
|
||||
return txs[0]
|
||||
|
||||
async def select_txos(self, cols, **constraints):
|
||||
sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)"
|
||||
if 'accounts' in constraints:
|
||||
sql += " JOIN account_address USING (address)"
|
||||
return await self.db.execute_fetchall(*query(sql, **constraints))
|
||||
|
||||
async def get_txos(self, wallet=None, no_tx=False, **constraints):
|
||||
my_accounts = {a.public_key.address for a in wallet.accounts} if wallet else set()
|
||||
if 'order_by' not in constraints:
|
||||
constraints['order_by'] = [
|
||||
"tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"
|
||||
]
|
||||
rows = await self.select_txos(
|
||||
"""
|
||||
tx.txid, raw, tx.height, tx.position, tx.is_verified, txo.position, amount, script, (
|
||||
select group_concat(account||"|"||chain) from account_address
|
||||
where account_address.address=txo.address
|
||||
)
|
||||
""",
|
||||
**constraints
|
||||
)
|
||||
txos = []
|
||||
txs = {}
|
||||
output_class = self.ledger.transaction_class.output_class
|
||||
for row in rows:
|
||||
if no_tx:
|
||||
txo = output_class(
|
||||
amount=row[6],
|
||||
script=output_class.script_class(row[7]),
|
||||
tx_ref=TXRefImmutable.from_id(row[0], row[2]),
|
||||
position=row[5]
|
||||
)
|
||||
else:
|
||||
if row[0] not in txs:
|
||||
txs[row[0]] = self.ledger.transaction_class(
|
||||
row[1], height=row[2], position=row[3], is_verified=row[4]
|
||||
)
|
||||
txo = txs[row[0]].outputs[row[5]]
|
||||
row_accounts = dict(a.split('|') for a in row[8].split(','))
|
||||
account_match = set(row_accounts) & my_accounts
|
||||
if account_match:
|
||||
txo.is_my_account = True
|
||||
txo.is_change = row_accounts[account_match.pop()] == '1'
|
||||
else:
|
||||
txo.is_change = txo.is_my_account = False
|
||||
txos.append(txo)
|
||||
return txos
|
||||
|
||||
async def get_txo_count(self, **constraints):
|
||||
constraints.pop('wallet', None)
|
||||
constraints.pop('offset', None)
|
||||
constraints.pop('limit', None)
|
||||
constraints.pop('order_by', None)
|
||||
count = await self.select_txos('count(*)', **constraints)
|
||||
return count[0][0]
|
||||
|
||||
@staticmethod
|
||||
def constrain_utxo(constraints):
|
||||
constraints['is_reserved'] = False
|
||||
constraints['txoid__not_in'] = "SELECT txoid FROM txi"
|
||||
|
||||
def get_utxos(self, **constraints):
|
||||
self.constrain_utxo(constraints)
|
||||
return self.get_txos(**constraints)
|
||||
|
||||
def get_utxo_count(self, **constraints):
|
||||
self.constrain_utxo(constraints)
|
||||
return self.get_txo_count(**constraints)
|
||||
|
||||
async def get_balance(self, wallet=None, accounts=None, **constraints):
|
||||
assert wallet or accounts, \
|
||||
"'wallet' or 'accounts' constraints required to calculate balance"
|
||||
constraints['accounts'] = accounts or wallet.accounts
|
||||
self.constrain_utxo(constraints)
|
||||
balance = await self.select_txos('SUM(amount)', **constraints)
|
||||
return balance[0][0] or 0
|
||||
|
||||
async def select_addresses(self, cols, **constraints):
|
||||
return await self.db.execute_fetchall(*query(
|
||||
f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)",
|
||||
**constraints
|
||||
))
|
||||
|
||||
async def get_addresses(self, cols=None, **constraints):
|
||||
cols = cols or (
|
||||
'address', 'account', 'chain', 'history', 'used_times',
|
||||
'pubkey', 'chain_code', 'n', 'depth'
|
||||
)
|
||||
addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols)
|
||||
if 'pubkey' in cols:
|
||||
for address in addresses:
|
||||
address['pubkey'] = PubKey(
|
||||
self.ledger, address.pop('pubkey'), address.pop('chain_code'),
|
||||
address.pop('n'), address.pop('depth')
|
||||
)
|
||||
return addresses
|
||||
|
||||
async def get_address_count(self, cols=None, **constraints):
|
||||
count = await self.select_addresses('count(*)', **constraints)
|
||||
return count[0][0]
|
||||
|
||||
async def get_address(self, **constraints):
|
||||
addresses = await self.get_addresses(limit=1, **constraints)
|
||||
if addresses:
|
||||
return addresses[0]
|
||||
|
||||
async def add_keys(self, account, chain, pubkeys):
|
||||
await self.db.executemany(
|
||||
"insert or ignore into account_address "
|
||||
"(account, address, chain, pubkey, chain_code, n, depth) values "
|
||||
"(?, ?, ?, ?, ?, ?, ?)", ((
|
||||
account.id, k.address, chain,
|
||||
sqlite3.Binary(k.pubkey_bytes),
|
||||
sqlite3.Binary(k.chain_code),
|
||||
k.n, k.depth
|
||||
) for k in pubkeys)
|
||||
)
|
||||
await self.db.executemany(
|
||||
"insert or ignore into pubkey_address (address) values (?)",
|
||||
((pubkey.address,) for pubkey in pubkeys)
|
||||
)
|
||||
|
||||
async def _set_address_history(self, address, history):
|
||||
await self.db.execute_fetchall(
|
||||
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
(history, history.count(':')//2, address)
|
||||
)
|
||||
|
||||
async def set_address_history(self, address, history):
|
||||
await self._set_address_history(address, history)
|
|
@ -1,246 +0,0 @@
|
|||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from io import BytesIO
|
||||
from typing import Optional, Iterator, Tuple
|
||||
from binascii import hexlify
|
||||
|
||||
from lbry.wallet.client.util import ArithUint256
|
||||
from lbry.crypto.hash import double_sha256
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvalidHeader(Exception):
|
||||
|
||||
def __init__(self, height, message):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.height = height
|
||||
|
||||
|
||||
class BaseHeaders:
|
||||
|
||||
header_size: int
|
||||
chunk_size: int
|
||||
|
||||
max_target: int
|
||||
genesis_hash: Optional[bytes]
|
||||
target_timespan: int
|
||||
|
||||
validate_difficulty: bool = True
|
||||
checkpoint = None
|
||||
|
||||
def __init__(self, path) -> None:
|
||||
if path == ':memory:':
|
||||
self.io = BytesIO()
|
||||
self.path = path
|
||||
self._size: Optional[int] = None
|
||||
|
||||
async def open(self):
|
||||
if self.path != ':memory:':
|
||||
if not os.path.exists(self.path):
|
||||
self.io = open(self.path, 'w+b')
|
||||
else:
|
||||
self.io = open(self.path, 'r+b')
|
||||
|
||||
async def close(self):
|
||||
self.io.close()
|
||||
|
||||
@staticmethod
|
||||
def serialize(header: dict) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def deserialize(height, header):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_next_chunk_target(self, chunk: int) -> ArithUint256:
|
||||
return ArithUint256(self.max_target)
|
||||
|
||||
@staticmethod
|
||||
def get_next_block_target(chunk_target: ArithUint256, previous: Optional[dict],
|
||||
current: Optional[dict]) -> ArithUint256:
|
||||
return chunk_target
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self._size is None:
|
||||
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
|
||||
return self._size
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
def __getitem__(self, height) -> dict:
|
||||
if isinstance(height, slice):
|
||||
raise NotImplementedError("Slicing of header chain has not been implemented yet.")
|
||||
if not 0 <= height <= self.height:
|
||||
raise IndexError(f"{height} is out of bounds, current height: {self.height}")
|
||||
return self.deserialize(height, self.get_raw_header(height))
|
||||
|
||||
def get_raw_header(self, height) -> bytes:
|
||||
self.io.seek(height * self.header_size, os.SEEK_SET)
|
||||
return self.io.read(self.header_size)
|
||||
|
||||
@property
|
||||
def height(self) -> int:
|
||||
return len(self)-1
|
||||
|
||||
@property
|
||||
def bytes_size(self):
|
||||
return len(self) * self.header_size
|
||||
|
||||
def hash(self, height=None) -> bytes:
|
||||
return self.hash_header(
|
||||
self.get_raw_header(height if height is not None else self.height)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def hash_header(header: bytes) -> bytes:
|
||||
if header is None:
|
||||
return b'0' * 64
|
||||
return hexlify(double_sha256(header)[::-1])
|
||||
|
||||
@asynccontextmanager
|
||||
async def checkpointed_connector(self):
|
||||
buf = BytesIO()
|
||||
try:
|
||||
yield buf
|
||||
finally:
|
||||
await asyncio.sleep(0)
|
||||
final_height = len(self) + buf.tell() // self.header_size
|
||||
verifiable_bytes = (self.checkpoint[0] - len(self)) * self.header_size if self.checkpoint else 0
|
||||
if verifiable_bytes > 0 and final_height >= self.checkpoint[0]:
|
||||
buf.seek(0)
|
||||
self.io.seek(0)
|
||||
h = hashlib.sha256()
|
||||
h.update(self.io.read())
|
||||
h.update(buf.read(verifiable_bytes))
|
||||
if h.hexdigest().encode() == self.checkpoint[1]:
|
||||
buf.seek(0)
|
||||
self._write(len(self), buf.read(verifiable_bytes))
|
||||
remaining = buf.read()
|
||||
buf.seek(0)
|
||||
buf.write(remaining)
|
||||
buf.truncate()
|
||||
else:
|
||||
log.warning("Checkpoint mismatch, connecting headers through slow method.")
|
||||
if buf.tell() > 0:
|
||||
await self.connect(len(self), buf.getvalue())
|
||||
|
||||
async def connect(self, start: int, headers: bytes) -> int:
|
||||
added = 0
|
||||
bail = False
|
||||
for height, chunk in self._iterate_chunks(start, headers):
|
||||
try:
|
||||
# validate_chunk() is CPU bound and reads previous chunks from file system
|
||||
self.validate_chunk(height, chunk)
|
||||
except InvalidHeader as e:
|
||||
bail = True
|
||||
chunk = chunk[:(height-e.height)*self.header_size]
|
||||
added += self._write(height, chunk) if chunk else 0
|
||||
if bail:
|
||||
break
|
||||
return added
|
||||
|
||||
def _write(self, height, verified_chunk):
|
||||
self.io.seek(height * self.header_size, os.SEEK_SET)
|
||||
written = self.io.write(verified_chunk) // self.header_size
|
||||
self.io.truncate()
|
||||
# .seek()/.write()/.truncate() might also .flush() when needed
|
||||
# the goal here is mainly to ensure we're definitely flush()'ing
|
||||
self.io.flush()
|
||||
self._size = self.io.tell() // self.header_size
|
||||
return written
|
||||
|
||||
def validate_chunk(self, height, chunk):
|
||||
previous_hash, previous_header, previous_previous_header = None, None, None
|
||||
if height > 0:
|
||||
previous_header = self[height-1]
|
||||
previous_hash = self.hash(height-1)
|
||||
if height > 1:
|
||||
previous_previous_header = self[height-2]
|
||||
chunk_target = self.get_next_chunk_target(height // 2016 - 1)
|
||||
for current_hash, current_header in self._iterate_headers(height, chunk):
|
||||
block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header)
|
||||
self.validate_header(height, current_hash, current_header, previous_hash, block_target)
|
||||
previous_previous_header = previous_header
|
||||
previous_header = current_header
|
||||
previous_hash = current_hash
|
||||
|
||||
def validate_header(self, height: int, current_hash: bytes,
|
||||
header: dict, previous_hash: bytes, target: ArithUint256):
|
||||
|
||||
if previous_hash is None:
|
||||
if self.genesis_hash is not None and self.genesis_hash != current_hash:
|
||||
raise InvalidHeader(
|
||||
height, f"genesis header doesn't match: {current_hash.decode()} "
|
||||
f"vs expected {self.genesis_hash.decode()}")
|
||||
return
|
||||
|
||||
if header['prev_block_hash'] != previous_hash:
|
||||
raise InvalidHeader(
|
||||
height, "previous hash mismatch: {} vs expected {}".format(
|
||||
header['prev_block_hash'].decode(), previous_hash.decode())
|
||||
)
|
||||
|
||||
if self.validate_difficulty:
|
||||
|
||||
if header['bits'] != target.compact:
|
||||
raise InvalidHeader(
|
||||
height, "bits mismatch: {} vs expected {}".format(
|
||||
header['bits'], target.compact)
|
||||
)
|
||||
|
||||
proof_of_work = self.get_proof_of_work(current_hash)
|
||||
if proof_of_work > target:
|
||||
raise InvalidHeader(
|
||||
height, f"insufficient proof of work: {proof_of_work.value} vs target {target.value}"
|
||||
)
|
||||
|
||||
async def repair(self):
|
||||
previous_header_hash = fail = None
|
||||
batch_size = 36
|
||||
for start_height in range(0, self.height, batch_size):
|
||||
self.io.seek(self.header_size * start_height)
|
||||
headers = self.io.read(self.header_size*batch_size)
|
||||
if len(headers) % self.header_size != 0:
|
||||
headers = headers[:(len(headers) // self.header_size) * self.header_size]
|
||||
for header_hash, header in self._iterate_headers(start_height, headers):
|
||||
height = header['block_height']
|
||||
if height:
|
||||
if header['prev_block_hash'] != previous_header_hash:
|
||||
fail = True
|
||||
else:
|
||||
if header_hash != self.genesis_hash:
|
||||
fail = True
|
||||
if fail:
|
||||
log.warning("Header file corrupted at height %s, truncating it.", height - 1)
|
||||
self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET)
|
||||
self.io.truncate()
|
||||
self.io.flush()
|
||||
self._size = None
|
||||
return
|
||||
previous_header_hash = header_hash
|
||||
|
||||
@staticmethod
|
||||
def get_proof_of_work(header_hash: bytes) -> ArithUint256:
|
||||
return ArithUint256(int(b'0x' + header_hash, 16))
|
||||
|
||||
def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
|
||||
assert len(headers) % self.header_size == 0, f"{len(headers)} {len(headers)%self.header_size}"
|
||||
start = 0
|
||||
end = (self.chunk_size - height % self.chunk_size) * self.header_size
|
||||
while start < end:
|
||||
yield height + (start // self.header_size), headers[start:end]
|
||||
start = end
|
||||
end = min(len(headers), end + self.chunk_size * self.header_size)
|
||||
|
||||
def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]:
|
||||
assert len(headers) % self.header_size == 0, len(headers)
|
||||
for idx in range(len(headers) // self.header_size):
|
||||
start, end = idx * self.header_size, (idx + 1) * self.header_size
|
||||
header = headers[start:end]
|
||||
yield self.hash_header(header), self.deserialize(height+idx, header)
|
|
@ -1,605 +0,0 @@
|
|||
import base64
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
import zlib
|
||||
from functools import partial
|
||||
from binascii import hexlify, unhexlify
|
||||
from io import StringIO
|
||||
|
||||
from typing import Dict, Type, Iterable, List, Optional
|
||||
from operator import itemgetter
|
||||
from collections import namedtuple
|
||||
|
||||
import pylru
|
||||
from lbry.wallet.client.basetransaction import BaseTransaction
|
||||
from lbry.wallet.tasks import TaskGroup
|
||||
from lbry.wallet.client import baseaccount, basenetwork, basetransaction
|
||||
from lbry.wallet.client.basedatabase import BaseDatabase
|
||||
from lbry.wallet.client.baseheader import BaseHeaders
|
||||
from lbry.wallet.client.coinselection import CoinSelector
|
||||
from lbry.wallet.client.constants import COIN, NULL_HASH32
|
||||
from lbry.wallet.stream import StreamController
|
||||
from lbry.crypto.hash import hash160, double_sha256, sha256
|
||||
from lbry.crypto.base58 import Base58
|
||||
from lbry.wallet.client.bip32 import PubKey, PrivateKey
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LedgerType = Type['BaseLedger']
|
||||
|
||||
|
||||
class LedgerRegistry(type):
|
||||
|
||||
ledgers: Dict[str, LedgerType] = {}
|
||||
|
||||
def __new__(mcs, name, bases, attrs):
|
||||
cls: LedgerType = super().__new__(mcs, name, bases, attrs)
|
||||
if not (name == 'BaseLedger' and not bases):
|
||||
ledger_id = cls.get_id()
|
||||
assert ledger_id not in mcs.ledgers,\
|
||||
f'Ledger with id "{ledger_id}" already registered.'
|
||||
mcs.ledgers[ledger_id] = cls
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def get_ledger_class(mcs, ledger_id: str) -> LedgerType:
|
||||
return mcs.ledgers[ledger_id]
|
||||
|
||||
|
||||
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))):
|
||||
pass
|
||||
|
||||
|
||||
class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))):
|
||||
pass
|
||||
|
||||
|
||||
class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
|
||||
pass
|
||||
|
||||
|
||||
class TransactionCacheItem:
|
||||
__slots__ = '_tx', 'lock', 'has_tx'
|
||||
|
||||
def __init__(self,
|
||||
tx: Optional[basetransaction.BaseTransaction] = None,
|
||||
lock: Optional[asyncio.Lock] = None):
|
||||
self.has_tx = asyncio.Event()
|
||||
self.lock = lock or asyncio.Lock()
|
||||
self._tx = self.tx = tx
|
||||
|
||||
@property
|
||||
def tx(self) -> Optional[basetransaction.BaseTransaction]:
|
||||
return self._tx
|
||||
|
||||
@tx.setter
|
||||
def tx(self, tx: basetransaction.BaseTransaction):
|
||||
self._tx = tx
|
||||
if tx is not None:
|
||||
self.has_tx.set()
|
||||
|
||||
|
||||
class BaseLedger(metaclass=LedgerRegistry):
|
||||
|
||||
name: str
|
||||
symbol: str
|
||||
network_name: str
|
||||
|
||||
database_class = BaseDatabase
|
||||
account_class = baseaccount.BaseAccount
|
||||
network_class = basenetwork.BaseNetwork
|
||||
transaction_class = basetransaction.BaseTransaction
|
||||
|
||||
headers_class: Type[BaseHeaders]
|
||||
|
||||
pubkey_address_prefix: bytes
|
||||
script_address_prefix: bytes
|
||||
extended_public_key_prefix: bytes
|
||||
extended_private_key_prefix: bytes
|
||||
|
||||
default_fee_per_byte = 10
|
||||
|
||||
def __init__(self, config=None):
|
||||
self.config = config or {}
|
||||
self.db: BaseDatabase = self.config.get('db') or self.database_class(
|
||||
os.path.join(self.path, "blockchain.db")
|
||||
)
|
||||
self.db.ledger = self
|
||||
self.headers: BaseHeaders = self.config.get('headers') or self.headers_class(
|
||||
os.path.join(self.path, "headers")
|
||||
)
|
||||
self.network = self.config.get('network') or self.network_class(self)
|
||||
self.network.on_header.listen(self.receive_header)
|
||||
self.network.on_status.listen(self.process_status_update)
|
||||
self.network.on_connected.listen(self.join_network)
|
||||
|
||||
self.accounts = []
|
||||
self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
|
||||
|
||||
self._on_transaction_controller = StreamController()
|
||||
self.on_transaction = self._on_transaction_controller.stream
|
||||
self.on_transaction.listen(
|
||||
lambda e: log.info(
|
||||
'(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
|
||||
self.get_id(), e.address, e.tx.height, e.tx.is_verified, e.tx.id
|
||||
)
|
||||
)
|
||||
|
||||
self._on_address_controller = StreamController()
|
||||
self.on_address = self._on_address_controller.stream
|
||||
self.on_address.listen(
|
||||
lambda e: log.info('(%s) on_address: %s', self.get_id(), e.addresses)
|
||||
)
|
||||
|
||||
self._on_header_controller = StreamController()
|
||||
self.on_header = self._on_header_controller.stream
|
||||
self.on_header.listen(
|
||||
lambda change: log.info(
|
||||
'%s: added %s header blocks, final height %s',
|
||||
self.get_id(), change, self.headers.height
|
||||
)
|
||||
)
|
||||
self._download_height = 0
|
||||
|
||||
self._on_ready_controller = StreamController()
|
||||
self.on_ready = self._on_ready_controller.stream
|
||||
|
||||
self._tx_cache = pylru.lrucache(100000)
|
||||
self._update_tasks = TaskGroup()
|
||||
self._utxo_reservation_lock = asyncio.Lock()
|
||||
self._header_processing_lock = asyncio.Lock()
|
||||
self._address_update_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
self.coin_selection_strategy = None
|
||||
self._known_addresses_out_of_sync = set()
|
||||
|
||||
@classmethod
|
||||
def get_id(cls):
|
||||
return '{}_{}'.format(cls.symbol.lower(), cls.network_name.lower())
|
||||
|
||||
@classmethod
|
||||
def hash160_to_address(cls, h160):
|
||||
raw_address = cls.pubkey_address_prefix + h160
|
||||
return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4]))
|
||||
|
||||
@staticmethod
|
||||
def address_to_hash160(address):
|
||||
return Base58.decode(address)[1:21]
|
||||
|
||||
@classmethod
|
||||
def is_valid_address(cls, address):
|
||||
decoded = Base58.decode_check(address)
|
||||
return decoded[0] == cls.pubkey_address_prefix[0]
|
||||
|
||||
@classmethod
|
||||
def public_key_to_address(cls, public_key):
|
||||
return cls.hash160_to_address(hash160(public_key))
|
||||
|
||||
@staticmethod
|
||||
def private_key_to_wif(private_key):
|
||||
return b'\x1c' + private_key + b'\x01'
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return os.path.join(self.config['data_path'], self.get_id())
|
||||
|
||||
def add_account(self, account: baseaccount.BaseAccount):
|
||||
self.accounts.append(account)
|
||||
|
||||
async def _get_account_and_address_info_for_address(self, wallet, address):
|
||||
match = await self.db.get_address(accounts=wallet.accounts, address=address)
|
||||
if match:
|
||||
for account in wallet.accounts:
|
||||
if match['account'] == account.public_key.address:
|
||||
return account, match
|
||||
|
||||
async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]:
|
||||
match = await self._get_account_and_address_info_for_address(wallet, address)
|
||||
if match:
|
||||
account, address_info = match
|
||||
return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
|
||||
return None
|
||||
|
||||
async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]:
|
||||
match = await self._get_account_and_address_info_for_address(wallet, address)
|
||||
if match:
|
||||
_, address_info = match
|
||||
return address_info['pubkey']
|
||||
return None
|
||||
|
||||
async def get_account_for_address(self, wallet, address):
|
||||
match = await self._get_account_and_address_info_for_address(wallet, address)
|
||||
if match:
|
||||
return match[0]
|
||||
|
||||
async def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
|
||||
estimators = []
|
||||
for account in funding_accounts:
|
||||
utxos = await account.get_utxos()
|
||||
for utxo in utxos:
|
||||
estimators.append(utxo.get_estimator(self))
|
||||
return estimators
|
||||
|
||||
async def get_addresses(self, **constraints):
|
||||
return await self.db.get_addresses(**constraints)
|
||||
|
||||
def get_address_count(self, **constraints):
|
||||
return self.db.get_address_count(**constraints)
|
||||
|
||||
async def get_spendable_utxos(self, amount: int, funding_accounts):
|
||||
async with self._utxo_reservation_lock:
|
||||
txos = await self.get_effective_amount_estimators(funding_accounts)
|
||||
fee = self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
|
||||
selector = CoinSelector(amount, fee)
|
||||
spendables = selector.select(txos, self.coin_selection_strategy)
|
||||
if spendables:
|
||||
await self.reserve_outputs(s.txo for s in spendables)
|
||||
return spendables
|
||||
|
||||
def reserve_outputs(self, txos):
|
||||
return self.db.reserve_outputs(txos)
|
||||
|
||||
def release_outputs(self, txos):
|
||||
return self.db.release_outputs(txos)
|
||||
|
||||
def release_tx(self, tx):
|
||||
return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])
|
||||
|
||||
def get_utxos(self, **constraints):
|
||||
return self.db.get_utxos(**constraints)
|
||||
|
||||
def get_utxo_count(self, **constraints):
|
||||
return self.db.get_utxo_count(**constraints)
|
||||
|
||||
def get_transactions(self, **constraints):
|
||||
return self.db.get_transactions(**constraints)
|
||||
|
||||
def get_transaction_count(self, **constraints):
|
||||
return self.db.get_transaction_count(**constraints)
|
||||
|
||||
async def get_local_status_and_history(self, address, history=None):
|
||||
if not history:
|
||||
address_details = await self.db.get_address(address=address)
|
||||
history = address_details['history'] or ''
|
||||
parts = history.split(':')[:-1]
|
||||
return (
|
||||
hexlify(sha256(history.encode())).decode() if history else None,
|
||||
list(zip(parts[0::2], map(int, parts[1::2])))
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
|
||||
for i, branch in enumerate(branches):
|
||||
other_branch = unhexlify(branch)[::-1]
|
||||
other_branch_on_left = bool((branch_positions >> i) & 1)
|
||||
if other_branch_on_left:
|
||||
combined = other_branch + working_branch
|
||||
else:
|
||||
combined = working_branch + other_branch
|
||||
working_branch = double_sha256(combined)
|
||||
return hexlify(working_branch[::-1])
|
||||
|
||||
async def start(self):
|
||||
if not os.path.exists(self.path):
|
||||
os.mkdir(self.path)
|
||||
await asyncio.wait([
|
||||
self.db.open(),
|
||||
self.headers.open()
|
||||
])
|
||||
first_connection = self.network.on_connected.first
|
||||
asyncio.ensure_future(self.network.start())
|
||||
await first_connection
|
||||
async with self._header_processing_lock:
|
||||
await self._update_tasks.add(self.initial_headers_sync())
|
||||
await self._on_ready_controller.stream.first
|
||||
|
||||
async def join_network(self, *_):
|
||||
log.info("Subscribing and updating accounts.")
|
||||
async with self._header_processing_lock:
|
||||
await self.update_headers()
|
||||
await self.subscribe_accounts()
|
||||
await self._update_tasks.done.wait()
|
||||
self._on_ready_controller.add(True)
|
||||
|
||||
async def stop(self):
|
||||
self._update_tasks.cancel()
|
||||
await self._update_tasks.done.wait()
|
||||
await self.network.stop()
|
||||
await self.db.close()
|
||||
await self.headers.close()
|
||||
|
||||
@property
|
||||
def local_height_including_downloaded_height(self):
|
||||
return max(self.headers.height, self._download_height)
|
||||
|
||||
async def initial_headers_sync(self):
|
||||
target = self.network.remote_height + 1
|
||||
current = len(self.headers)
|
||||
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True)
|
||||
chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)]
|
||||
total = 0
|
||||
async with self.headers.checkpointed_connector() as buffer:
|
||||
for chunk in chunks:
|
||||
headers = await chunk
|
||||
total += buffer.write(
|
||||
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
|
||||
)
|
||||
self._download_height = current + total // self.headers.header_size
|
||||
log.info("Headers sync: %s / %s", self._download_height, target)
|
||||
|
||||
async def update_headers(self, height=None, headers=None, subscription_update=False):
|
||||
rewound = 0
|
||||
while True:
|
||||
|
||||
if height is None or height > len(self.headers):
|
||||
# sometimes header subscription updates are for a header in the future
|
||||
# which can't be connected, so we do a normal header sync instead
|
||||
height = len(self.headers)
|
||||
headers = None
|
||||
subscription_update = False
|
||||
|
||||
if not headers:
|
||||
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
|
||||
headers = header_response['hex']
|
||||
|
||||
if not headers:
|
||||
# Nothing to do, network thinks we're already at the latest height.
|
||||
return
|
||||
|
||||
added = await self.headers.connect(height, unhexlify(headers))
|
||||
if added > 0:
|
||||
height += added
|
||||
self._on_header_controller.add(
|
||||
BlockHeightEvent(self.headers.height, added))
|
||||
|
||||
if rewound > 0:
|
||||
# we started rewinding blocks and apparently found
|
||||
# a new chain
|
||||
rewound = 0
|
||||
await self.db.rewind_blockchain(height)
|
||||
|
||||
if subscription_update:
|
||||
# subscription updates are for latest header already
|
||||
# so we don't need to check if there are newer / more
|
||||
# on another loop of update_headers(), just return instead
|
||||
return
|
||||
|
||||
elif added == 0:
|
||||
# we had headers to connect but none got connected, probably a reorganization
|
||||
height -= 1
|
||||
rewound += 1
|
||||
log.warning(
|
||||
"Blockchain Reorganization: attempting rewind to height %s from starting height %s",
|
||||
height, height+rewound
|
||||
)
|
||||
|
||||
else:
|
||||
raise IndexError(f"headers.connect() returned negative number ({added})")
|
||||
|
||||
if height < 0:
|
||||
raise IndexError(
|
||||
"Blockchain reorganization rewound all the way back to genesis hash. "
|
||||
"Something is very wrong. Maybe you are on the wrong blockchain?"
|
||||
)
|
||||
|
||||
if rewound >= 100:
|
||||
raise IndexError(
|
||||
"Blockchain reorganization dropped {} headers. This is highly unusual. "
|
||||
"Will not continue to attempt reorganizing. Please, delete the ledger "
|
||||
"synchronization directory inside your wallet directory (folder: '{}') and "
|
||||
"restart the program to synchronize from scratch."
|
||||
.format(rewound, self.get_id())
|
||||
)
|
||||
|
||||
headers = None # ready to download some more headers
|
||||
|
||||
# if we made it this far and this was a subscription_update
|
||||
# it means something went wrong and now we're doing a more
|
||||
# robust sync, turn off subscription update shortcut
|
||||
subscription_update = False
|
||||
|
||||
async def receive_header(self, response):
|
||||
async with self._header_processing_lock:
|
||||
header = response[0]
|
||||
await self.update_headers(
|
||||
height=header['height'], headers=header['hex'], subscription_update=True
|
||||
)
|
||||
|
||||
async def subscribe_accounts(self):
|
||||
if self.network.is_connected and self.accounts:
|
||||
await asyncio.wait([
|
||||
self.subscribe_account(a) for a in self.accounts
|
||||
])
|
||||
|
||||
async def subscribe_account(self, account: baseaccount.BaseAccount):
|
||||
for address_manager in account.address_managers.values():
|
||||
await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
|
||||
await account.ensure_address_gap()
|
||||
|
||||
async def unsubscribe_account(self, account: baseaccount.BaseAccount):
|
||||
for address in await account.get_addresses():
|
||||
await self.network.unsubscribe_address(address)
|
||||
|
||||
async def announce_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
|
||||
await self.subscribe_addresses(address_manager, addresses)
|
||||
await self._on_address_controller.add(
|
||||
AddressesGeneratedEvent(address_manager, addresses)
|
||||
)
|
||||
|
||||
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
|
||||
if self.network.is_connected and addresses:
|
||||
await asyncio.wait([
|
||||
self.subscribe_address(address_manager, address) for address in addresses
|
||||
])
|
||||
|
||||
async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
|
||||
remote_status = await self.network.subscribe_address(address)
|
||||
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
|
||||
|
||||
def process_status_update(self, update):
|
||||
address, remote_status = update
|
||||
self._update_tasks.add(self.update_history(address, remote_status))
|
||||
|
||||
async def update_history(self, address, remote_status,
|
||||
address_manager: baseaccount.AddressManager = None):
|
||||
|
||||
async with self._address_update_locks.setdefault(address, asyncio.Lock()):
|
||||
self._known_addresses_out_of_sync.discard(address)
|
||||
|
||||
local_status, local_history = await self.get_local_status_and_history(address)
|
||||
|
||||
if local_status == remote_status:
|
||||
return True
|
||||
|
||||
remote_history = await self.network.retriable_call(self.network.get_history, address)
|
||||
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
|
||||
we_need = set(remote_history) - set(local_history)
|
||||
if not we_need:
|
||||
return True
|
||||
|
||||
cache_tasks: List[asyncio.Future[BaseTransaction]] = []
|
||||
synced_history = StringIO()
|
||||
for i, (txid, remote_height) in enumerate(remote_history):
|
||||
if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
|
||||
synced_history.write(f'{txid}:{remote_height}:')
|
||||
else:
|
||||
check_local = (txid, remote_height) not in we_need
|
||||
cache_tasks.append(asyncio.ensure_future(
|
||||
self.cache_transaction(txid, remote_height, check_local=check_local)
|
||||
))
|
||||
|
||||
synced_txs = []
|
||||
for task in cache_tasks:
|
||||
tx = await task
|
||||
|
||||
check_db_for_txos = []
|
||||
for txi in tx.inputs:
|
||||
if txi.txo_ref.txo is not None:
|
||||
continue
|
||||
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
|
||||
if cache_item is not None:
|
||||
if cache_item.tx is None:
|
||||
await cache_item.has_tx.wait()
|
||||
assert cache_item.tx is not None
|
||||
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
|
||||
else:
|
||||
check_db_for_txos.append(txi.txo_ref.id)
|
||||
|
||||
referenced_txos = {} if not check_db_for_txos else {
|
||||
txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos, no_tx=True)
|
||||
}
|
||||
|
||||
for txi in tx.inputs:
|
||||
if txi.txo_ref.txo is not None:
|
||||
continue
|
||||
referenced_txo = referenced_txos.get(txi.txo_ref.id)
|
||||
if referenced_txo is not None:
|
||||
txi.txo_ref = referenced_txo.ref
|
||||
|
||||
synced_history.write(f'{tx.id}:{tx.height}:')
|
||||
synced_txs.append(tx)
|
||||
|
||||
await self.db.save_transaction_io_batch(
|
||||
synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
|
||||
)
|
||||
await asyncio.wait([
|
||||
self._on_transaction_controller.add(TransactionEvent(address, tx))
|
||||
for tx in synced_txs
|
||||
])
|
||||
|
||||
if address_manager is None:
|
||||
address_manager = await self.get_address_manager_for_address(address)
|
||||
|
||||
if address_manager is not None:
|
||||
await address_manager.ensure_address_gap()
|
||||
|
||||
local_status, local_history = \
|
||||
await self.get_local_status_and_history(address, synced_history.getvalue())
|
||||
if local_status != remote_status:
|
||||
if local_history == remote_history:
|
||||
return True
|
||||
log.warning(
|
||||
"Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
|
||||
remote_status, len(remote_history), local_status, len(local_history)
|
||||
)
|
||||
log.warning("local: %s", local_history)
|
||||
log.warning("remote: %s", remote_history)
|
||||
self._known_addresses_out_of_sync.add(address)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
async def cache_transaction(self, txid, remote_height, check_local=True):
|
||||
cache_item = self._tx_cache.get(txid)
|
||||
if cache_item is None:
|
||||
cache_item = self._tx_cache[txid] = TransactionCacheItem()
|
||||
elif cache_item.tx is not None and \
|
||||
cache_item.tx.height >= remote_height and \
|
||||
(cache_item.tx.is_verified or remote_height < 1):
|
||||
return cache_item.tx # cached tx is already up-to-date
|
||||
|
||||
async with cache_item.lock:
|
||||
|
||||
tx = cache_item.tx
|
||||
|
||||
if tx is None and check_local:
|
||||
# check local db
|
||||
tx = cache_item.tx = await self.db.get_transaction(txid=txid)
|
||||
|
||||
if tx is None:
|
||||
# fetch from network
|
||||
_raw = await self.network.retriable_call(self.network.get_transaction, txid, remote_height)
|
||||
tx = self.transaction_class(unhexlify(_raw))
|
||||
cache_item.tx = tx # make sure it's saved before caching it
|
||||
|
||||
await self.maybe_verify_transaction(tx, remote_height)
|
||||
return tx
|
||||
|
||||
async def maybe_verify_transaction(self, tx, remote_height):
|
||||
tx.height = remote_height
|
||||
if 0 < remote_height < len(self.headers):
|
||||
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
|
||||
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||
header = self.headers[remote_height]
|
||||
tx.position = merkle['pos']
|
||||
tx.is_verified = merkle_root == header['merkle_root']
|
||||
|
||||
async def get_address_manager_for_address(self, address) -> Optional[baseaccount.AddressManager]:
|
||||
details = await self.db.get_address(address=address)
|
||||
for account in self.accounts:
|
||||
if account.id == details['account']:
|
||||
return account.address_managers[details['chain']]
|
||||
return None
|
||||
|
||||
def broadcast(self, tx):
|
||||
# broadcast can't be a retriable call yet
|
||||
return self.network.broadcast(hexlify(tx.raw).decode())
|
||||
|
||||
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=1):
|
||||
addresses = set()
|
||||
for txi in tx.inputs:
|
||||
if txi.txo_ref.txo is not None:
|
||||
addresses.add(
|
||||
self.hash160_to_address(txi.txo_ref.txo.pubkey_hash)
|
||||
)
|
||||
for txo in tx.outputs:
|
||||
if txo.has_address:
|
||||
addresses.add(self.hash160_to_address(txo.pubkey_hash))
|
||||
records = await self.db.get_addresses(address__in=addresses)
|
||||
_, pending = await asyncio.wait([
|
||||
self.on_transaction.where(partial(
|
||||
lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id,
|
||||
address_record['address']
|
||||
)) for address_record in records
|
||||
], timeout=timeout)
|
||||
if pending:
|
||||
for record in records:
|
||||
found = False
|
||||
_, local_history = await self.get_local_status_and_history(None, history=record['history'])
|
||||
for txid, local_height in local_history:
|
||||
if txid == tx.id and local_height >= height:
|
||||
found = True
|
||||
if not found:
|
||||
print(record['history'], addresses, tx.id)
|
||||
raise asyncio.TimeoutError('Timed out waiting for transaction.')
|
|
@ -1,87 +0,0 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Type, MutableSequence, MutableMapping, Optional
|
||||
|
||||
from lbry.wallet.client.baseledger import BaseLedger, LedgerRegistry
|
||||
from lbry.wallet.client.wallet import Wallet, WalletStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseWalletManager:
|
||||
|
||||
def __init__(self, wallets: MutableSequence[Wallet] = None,
|
||||
ledgers: MutableMapping[Type[BaseLedger], BaseLedger] = None) -> None:
|
||||
self.wallets = wallets or []
|
||||
self.ledgers = ledgers or {}
|
||||
self.running = False
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> 'BaseWalletManager':
|
||||
manager = cls()
|
||||
for ledger_id, ledger_config in config.get('ledgers', {}).items():
|
||||
manager.get_or_create_ledger(ledger_id, ledger_config)
|
||||
for wallet_path in config.get('wallets', []):
|
||||
wallet_storage = WalletStorage(wallet_path)
|
||||
wallet = Wallet.from_storage(wallet_storage, manager)
|
||||
manager.wallets.append(wallet)
|
||||
return manager
|
||||
|
||||
def get_or_create_ledger(self, ledger_id, ledger_config=None):
|
||||
ledger_class = LedgerRegistry.get_ledger_class(ledger_id)
|
||||
ledger = self.ledgers.get(ledger_class)
|
||||
if ledger is None:
|
||||
ledger = ledger_class(ledger_config or {})
|
||||
self.ledgers[ledger_class] = ledger
|
||||
return ledger
|
||||
|
||||
def import_wallet(self, path):
|
||||
storage = WalletStorage(path)
|
||||
wallet = Wallet.from_storage(storage, self)
|
||||
self.wallets.append(wallet)
|
||||
return wallet
|
||||
|
||||
@property
|
||||
def default_wallet(self):
|
||||
for wallet in self.wallets:
|
||||
return wallet
|
||||
|
||||
@property
|
||||
def default_account(self):
|
||||
for wallet in self.wallets:
|
||||
return wallet.default_account
|
||||
|
||||
@property
|
||||
def accounts(self):
|
||||
for wallet in self.wallets:
|
||||
yield from wallet.accounts
|
||||
|
||||
async def start(self):
|
||||
self.running = True
|
||||
await asyncio.gather(*(
|
||||
l.start() for l in self.ledgers.values()
|
||||
))
|
||||
|
||||
async def stop(self):
|
||||
await asyncio.gather(*(
|
||||
l.stop() for l in self.ledgers.values()
|
||||
))
|
||||
self.running = False
|
||||
|
||||
def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet:
|
||||
if wallet_id is None:
|
||||
return self.default_wallet
|
||||
return self.get_wallet_or_error(wallet_id)
|
||||
|
||||
def get_wallet_or_error(self, wallet_id: str) -> Wallet:
|
||||
for wallet in self.wallets:
|
||||
if wallet.id == wallet_id:
|
||||
return wallet
|
||||
raise ValueError(f"Couldn't find wallet: {wallet_id}.")
|
||||
|
||||
@staticmethod
|
||||
def get_balance(wallet):
|
||||
accounts = wallet.accounts
|
||||
if not accounts:
|
||||
return 0
|
||||
return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts)
|
|
@ -1,364 +0,0 @@
|
|||
import logging
|
||||
import asyncio
|
||||
from operator import itemgetter
|
||||
from typing import Dict, Optional, Tuple
|
||||
from time import perf_counter
|
||||
|
||||
import lbry
|
||||
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
|
||||
from lbry.wallet.stream import StreamController
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClientSession(BaseClientSession):
|
||||
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
|
||||
self.network = network
|
||||
self.server = server
|
||||
super().__init__(*args, **kwargs)
|
||||
self._on_disconnect_controller = StreamController()
|
||||
self.on_disconnected = self._on_disconnect_controller.stream
|
||||
self.framer.max_size = self.max_errors = 1 << 32
|
||||
self.bw_limit = -1
|
||||
self.timeout = timeout
|
||||
self.max_seconds_idle = timeout * 2
|
||||
self.response_time: Optional[float] = None
|
||||
self.connection_latency: Optional[float] = None
|
||||
self._response_samples = 0
|
||||
self.pending_amount = 0
|
||||
self._on_connect_cb = on_connect_callback or (lambda: None)
|
||||
self.trigger_urgent_reconnect = asyncio.Event()
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
return not self.is_closing() and self.response_time is not None
|
||||
|
||||
@property
|
||||
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
|
||||
if not self.transport:
|
||||
return None
|
||||
return self.transport.get_extra_info('peername')
|
||||
|
||||
async def send_timed_server_version_request(self, args=(), timeout=None):
|
||||
timeout = timeout or self.timeout
|
||||
log.debug("send version request to %s:%i", *self.server)
|
||||
start = perf_counter()
|
||||
result = await asyncio.wait_for(
|
||||
super().send_request('server.version', args), timeout=timeout
|
||||
)
|
||||
current_response_time = perf_counter() - start
|
||||
response_sum = (self.response_time or 0) * self._response_samples + current_response_time
|
||||
self.response_time = response_sum / (self._response_samples + 1)
|
||||
self._response_samples += 1
|
||||
return result
|
||||
|
||||
async def send_request(self, method, args=()):
|
||||
self.pending_amount += 1
|
||||
log.debug("send %s to %s:%i", method, *self.server)
|
||||
try:
|
||||
if method == 'server.version':
|
||||
return await self.send_timed_server_version_request(args, self.timeout)
|
||||
request = asyncio.ensure_future(super().send_request(method, args))
|
||||
while not request.done():
|
||||
done, pending = await asyncio.wait([request], timeout=self.timeout)
|
||||
if pending:
|
||||
log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received)
|
||||
if (perf_counter() - self.last_packet_received) < self.timeout:
|
||||
continue
|
||||
log.info("timeout sending %s to %s:%i", method, *self.server)
|
||||
raise asyncio.TimeoutError
|
||||
if done:
|
||||
return request.result()
|
||||
except (RPCError, ProtocolError) as e:
|
||||
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
|
||||
*self.server, *e.args)
|
||||
raise e
|
||||
except ConnectionError:
|
||||
log.warning("connection to %s:%i lost", *self.server)
|
||||
self.synchronous_close()
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
log.info("cancelled sending %s to %s:%i", method, *self.server)
|
||||
self.synchronous_close()
|
||||
raise
|
||||
finally:
|
||||
self.pending_amount -= 1
|
||||
|
||||
async def ensure_session(self):
|
||||
# Handles reconnecting and maintaining a session alive
|
||||
# TODO: change to 'ping' on newer protocol (above 1.2)
|
||||
retry_delay = default_delay = 1.0
|
||||
while True:
|
||||
try:
|
||||
if self.is_closing():
|
||||
await self.create_connection(self.timeout)
|
||||
await self.ensure_server_version()
|
||||
self._on_connect_cb()
|
||||
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
|
||||
await self.ensure_server_version()
|
||||
retry_delay = default_delay
|
||||
except RPCError as e:
|
||||
log.warning("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
|
||||
retry_delay = 60 * 60
|
||||
except (asyncio.TimeoutError, OSError):
|
||||
await self.close()
|
||||
retry_delay = min(60, retry_delay * 2)
|
||||
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
|
||||
try:
|
||||
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
self.trigger_urgent_reconnect.clear()
|
||||
|
||||
async def ensure_server_version(self, required=None, timeout=3):
|
||||
required = required or self.network.PROTOCOL_VERSION
|
||||
return await asyncio.wait_for(
|
||||
self.send_request('server.version', [lbry.__version__, required]), timeout=timeout
|
||||
)
|
||||
|
||||
async def create_connection(self, timeout=6):
|
||||
connector = Connector(lambda: self, *self.server)
|
||||
start = perf_counter()
|
||||
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
|
||||
self.connection_latency = perf_counter() - start
|
||||
|
||||
async def handle_request(self, request):
|
||||
controller = self.network.subscription_controllers[request.method]
|
||||
controller.add(request.args)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
log.debug("Connection lost: %s:%d", *self.server)
|
||||
super().connection_lost(exc)
|
||||
self.response_time = None
|
||||
self.connection_latency = None
|
||||
self._response_samples = 0
|
||||
self.pending_amount = 0
|
||||
self._on_disconnect_controller.add(True)
|
||||
|
||||
|
||||
class BaseNetwork:
|
||||
PROTOCOL_VERSION = '1.2'
|
||||
|
||||
def __init__(self, ledger):
|
||||
self.ledger = ledger
|
||||
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
|
||||
self.client: Optional[ClientSession] = None
|
||||
self._switch_task: Optional[asyncio.Task] = None
|
||||
self.running = False
|
||||
self.remote_height: int = 0
|
||||
self._concurrency = asyncio.Semaphore(16)
|
||||
|
||||
self._on_connected_controller = StreamController()
|
||||
self.on_connected = self._on_connected_controller.stream
|
||||
|
||||
self._on_header_controller = StreamController(merge_repeated_events=True)
|
||||
self.on_header = self._on_header_controller.stream
|
||||
|
||||
self._on_status_controller = StreamController(merge_repeated_events=True)
|
||||
self.on_status = self._on_status_controller.stream
|
||||
|
||||
self.subscription_controllers = {
|
||||
'blockchain.headers.subscribe': self._on_header_controller,
|
||||
'blockchain.address.subscribe': self._on_status_controller,
|
||||
}
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self.ledger.config
|
||||
|
||||
async def switch_forever(self):
|
||||
while self.running:
|
||||
if self.is_connected:
|
||||
await self.client.on_disconnected.first
|
||||
self.client = None
|
||||
continue
|
||||
self.client = await self.session_pool.wait_for_fastest_session()
|
||||
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
|
||||
try:
|
||||
self._update_remote_height((await self.subscribe_headers(),))
|
||||
self._on_connected_controller.add(True)
|
||||
log.info("Subscribed to headers: %s:%d", *self.client.server)
|
||||
except (asyncio.TimeoutError, ConnectionError):
|
||||
log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
|
||||
self.client.synchronous_close()
|
||||
self.client = None
|
||||
|
||||
async def start(self):
|
||||
self.running = True
|
||||
self._switch_task = asyncio.ensure_future(self.switch_forever())
|
||||
# this may become unnecessary when there are no more bugs found,
|
||||
# but for now it helps understanding log reports
|
||||
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
|
||||
self.session_pool.start(self.config['default_servers'])
|
||||
self.on_header.listen(self._update_remote_height)
|
||||
|
||||
async def stop(self):
|
||||
if self.running:
|
||||
self.running = False
|
||||
self._switch_task.cancel()
|
||||
self.session_pool.stop()
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
return self.client and not self.client.is_closing()
|
||||
|
||||
def rpc(self, list_or_method, args, restricted=True):
|
||||
session = self.client if restricted else self.session_pool.fastest_session
|
||||
if session and not session.is_closing():
|
||||
return session.send_request(list_or_method, args)
|
||||
else:
|
||||
self.session_pool.trigger_nodelay_connect()
|
||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||
|
||||
async def retriable_call(self, function, *args, **kwargs):
|
||||
async with self._concurrency:
|
||||
while self.running:
|
||||
if not self.is_connected:
|
||||
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
|
||||
await self.on_connected.first
|
||||
await self.session_pool.wait_for_fastest_session()
|
||||
try:
|
||||
return await function(*args, **kwargs)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("Wallet server call timed out, retrying.")
|
||||
except ConnectionError:
|
||||
pass
|
||||
raise asyncio.CancelledError() # if we got here, we are shutting down
|
||||
|
||||
def _update_remote_height(self, header_args):
|
||||
self.remote_height = header_args[0]["height"]
|
||||
|
||||
def get_transaction(self, tx_hash, known_height=None):
|
||||
# use any server if its old, otherwise restrict to who gave us the history
|
||||
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
|
||||
return self.rpc('blockchain.transaction.get', [tx_hash], restricted)
|
||||
|
||||
def get_transaction_height(self, tx_hash, known_height=None):
|
||||
restricted = not known_height or 0 > known_height > self.remote_height - 10
|
||||
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
|
||||
|
||||
def get_merkle(self, tx_hash, height):
|
||||
restricted = 0 > height > self.remote_height - 10
|
||||
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)
|
||||
|
||||
def get_headers(self, height, count=10000, b64=False):
|
||||
restricted = height >= self.remote_height - 100
|
||||
return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)
|
||||
|
||||
# --- Subscribes, history and broadcasts are always aimed towards the master client directly
|
||||
def get_history(self, address):
|
||||
return self.rpc('blockchain.address.get_history', [address], True)
|
||||
|
||||
def broadcast(self, raw_transaction):
|
||||
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
|
||||
|
||||
def subscribe_headers(self):
|
||||
return self.rpc('blockchain.headers.subscribe', [True], True)
|
||||
|
||||
async def subscribe_address(self, address):
|
||||
try:
|
||||
return await self.rpc('blockchain.address.subscribe', [address], True)
|
||||
except asyncio.TimeoutError:
|
||||
# abort and cancel, we can't lose a subscription, it will happen again on reconnect
|
||||
if self.client:
|
||||
self.client.abort()
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
def unsubscribe_address(self, address):
|
||||
return self.rpc('blockchain.address.unsubscribe', [address], True)
|
||||
|
||||
def get_server_features(self):
|
||||
return self.rpc('server.features', (), restricted=True)
|
||||
|
||||
|
||||
class SessionPool:
|
||||
|
||||
def __init__(self, network: BaseNetwork, timeout: float):
|
||||
self.network = network
|
||||
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
|
||||
self.timeout = timeout
|
||||
self.new_connection_event = asyncio.Event()
|
||||
|
||||
@property
|
||||
def online(self):
|
||||
return any(not session.is_closing() for session in self.sessions)
|
||||
|
||||
@property
|
||||
def available_sessions(self):
|
||||
return (session for session in self.sessions if session.available)
|
||||
|
||||
@property
|
||||
def fastest_session(self):
|
||||
if not self.online:
|
||||
return None
|
||||
return min(
|
||||
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
|
||||
for session in self.available_sessions] or [(0, None)],
|
||||
key=itemgetter(0)
|
||||
)[1]
|
||||
|
||||
def _get_session_connect_callback(self, session: ClientSession):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def callback():
|
||||
duplicate_connections = [
|
||||
s for s in self.sessions
|
||||
if s is not session and s.server_address_and_port == session.server_address_and_port
|
||||
]
|
||||
already_connected = None if not duplicate_connections else duplicate_connections[0]
|
||||
if already_connected:
|
||||
self.sessions.pop(session).cancel()
|
||||
session.synchronous_close()
|
||||
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
|
||||
session.server[0], already_connected.server[0])
|
||||
loop.call_later(3600, self._connect_session, session.server)
|
||||
return
|
||||
self.new_connection_event.set()
|
||||
log.info("connected to %s:%i", *session.server)
|
||||
|
||||
return callback
|
||||
|
||||
def _connect_session(self, server: Tuple[str, int]):
|
||||
session = None
|
||||
for s in self.sessions:
|
||||
if s.server == server:
|
||||
session = s
|
||||
break
|
||||
if not session:
|
||||
session = ClientSession(
|
||||
network=self.network, server=server
|
||||
)
|
||||
session._on_connect_cb = self._get_session_connect_callback(session)
|
||||
task = self.sessions.get(session, None)
|
||||
if not task or task.done():
|
||||
task = asyncio.create_task(session.ensure_session())
|
||||
task.add_done_callback(lambda _: self.ensure_connections())
|
||||
self.sessions[session] = task
|
||||
|
||||
def start(self, default_servers):
|
||||
for server in default_servers:
|
||||
self._connect_session(server)
|
||||
|
||||
def stop(self):
|
||||
for session, task in self.sessions.items():
|
||||
task.cancel()
|
||||
session.synchronous_close()
|
||||
self.sessions.clear()
|
||||
|
||||
def ensure_connections(self):
|
||||
for session in self.sessions:
|
||||
self._connect_session(session.server)
|
||||
|
||||
def trigger_nodelay_connect(self):
|
||||
# used when other parts of the system sees we might have internet back
|
||||
# bypasses the retry interval
|
||||
for session in self.sessions:
|
||||
session.trigger_urgent_reconnect.set()
|
||||
|
||||
async def wait_for_fastest_session(self):
|
||||
while not self.fastest_session:
|
||||
self.trigger_nodelay_connect()
|
||||
self.new_connection_event.clear()
|
||||
await self.new_connection_event.wait()
|
||||
return self.fastest_session
|
|
@ -1,450 +0,0 @@
|
|||
from itertools import chain
|
||||
from binascii import hexlify
|
||||
from collections import namedtuple
|
||||
from typing import List
|
||||
|
||||
from lbry.wallet.client.bcd_data_stream import BCDataStream
|
||||
from lbry.wallet.client.util import subclass_tuple
|
||||
|
||||
# bitcoin opcodes
|
||||
OP_0 = 0x00
|
||||
OP_1 = 0x51
|
||||
OP_16 = 0x60
|
||||
OP_VERIFY = 0x69
|
||||
OP_DUP = 0x76
|
||||
OP_HASH160 = 0xa9
|
||||
OP_EQUALVERIFY = 0x88
|
||||
OP_CHECKSIG = 0xac
|
||||
OP_CHECKMULTISIG = 0xae
|
||||
OP_EQUAL = 0x87
|
||||
OP_PUSHDATA1 = 0x4c
|
||||
OP_PUSHDATA2 = 0x4d
|
||||
OP_PUSHDATA4 = 0x4e
|
||||
OP_RETURN = 0x6a
|
||||
OP_2DROP = 0x6d
|
||||
OP_DROP = 0x75
|
||||
|
||||
|
||||
# template matching opcodes (not real opcodes)
|
||||
# base class for PUSH_DATA related opcodes
|
||||
# pylint: disable=invalid-name
|
||||
PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name')
|
||||
# opcode for variable length strings
|
||||
# pylint: disable=invalid-name
|
||||
PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP)
|
||||
# opcode for variable size integers
|
||||
# pylint: disable=invalid-name
|
||||
PUSH_INTEGER = subclass_tuple('PUSH_INTEGER', PUSH_DATA_OP)
|
||||
# opcode for variable number of variable length strings
|
||||
# pylint: disable=invalid-name
|
||||
PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP)
|
||||
# opcode with embedded subscript parsing
|
||||
# pylint: disable=invalid-name
|
||||
PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template')
|
||||
|
||||
|
||||
def is_push_data_opcode(opcode):
|
||||
return isinstance(opcode, (PUSH_DATA_OP, PUSH_SUBSCRIPT))
|
||||
|
||||
|
||||
def is_push_data_token(token):
|
||||
return 1 <= token <= OP_PUSHDATA4
|
||||
|
||||
|
||||
def push_data(data):
|
||||
size = len(data)
|
||||
if size < OP_PUSHDATA1:
|
||||
yield BCDataStream.uint8.pack(size)
|
||||
elif size <= 0xFF:
|
||||
yield BCDataStream.uint8.pack(OP_PUSHDATA1)
|
||||
yield BCDataStream.uint8.pack(size)
|
||||
elif size <= 0xFFFF:
|
||||
yield BCDataStream.uint8.pack(OP_PUSHDATA2)
|
||||
yield BCDataStream.uint16.pack(size)
|
||||
else:
|
||||
yield BCDataStream.uint8.pack(OP_PUSHDATA4)
|
||||
yield BCDataStream.uint32.pack(size)
|
||||
yield bytes(data)
|
||||
|
||||
|
||||
def read_data(token, stream):
|
||||
if token < OP_PUSHDATA1:
|
||||
return stream.read(token)
|
||||
if token == OP_PUSHDATA1:
|
||||
return stream.read(stream.read_uint8())
|
||||
if token == OP_PUSHDATA2:
|
||||
return stream.read(stream.read_uint16())
|
||||
return stream.read(stream.read_uint32())
|
||||
|
||||
|
||||
# opcode for OP_1 - OP_16
|
||||
# pylint: disable=invalid-name
|
||||
SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name')
|
||||
|
||||
|
||||
def is_small_integer(token):
|
||||
return OP_1 <= token <= OP_16
|
||||
|
||||
|
||||
def push_small_integer(num):
|
||||
assert 1 <= num <= 16
|
||||
yield BCDataStream.uint8.pack(OP_1 + (num - 1))
|
||||
|
||||
|
||||
def read_small_integer(token):
|
||||
return (token - OP_1) + 1
|
||||
|
||||
|
||||
class Token(namedtuple('Token', 'value')):
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
name = None
|
||||
for var_name, var_value in globals().items():
|
||||
if var_name.startswith('OP_') and var_value == self.value:
|
||||
name = var_name
|
||||
break
|
||||
return name or self.value
|
||||
|
||||
|
||||
class DataToken(Token):
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
return f'"{hexlify(self.value)}"'
|
||||
|
||||
|
||||
class SmallIntegerToken(Token):
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
return f'SmallIntegerToken({self.value})'
|
||||
|
||||
|
||||
def token_producer(source):
|
||||
token = source.read_uint8()
|
||||
while token is not None:
|
||||
if is_push_data_token(token):
|
||||
yield DataToken(read_data(token, source))
|
||||
elif is_small_integer(token):
|
||||
yield SmallIntegerToken(read_small_integer(token))
|
||||
else:
|
||||
yield Token(token)
|
||||
token = source.read_uint8()
|
||||
|
||||
|
||||
def tokenize(source):
|
||||
return list(token_producer(source))
|
||||
|
||||
|
||||
class ScriptError(Exception):
|
||||
""" General script handling error. """
|
||||
|
||||
|
||||
class ParseError(ScriptError):
|
||||
""" Script parsing error. """
|
||||
|
||||
|
||||
class Parser:
|
||||
|
||||
def __init__(self, opcodes, tokens):
|
||||
self.opcodes = opcodes
|
||||
self.tokens = tokens
|
||||
self.values = {}
|
||||
self.token_index = 0
|
||||
self.opcode_index = 0
|
||||
|
||||
def parse(self):
|
||||
while self.token_index < len(self.tokens) and self.opcode_index < len(self.opcodes):
|
||||
token = self.tokens[self.token_index]
|
||||
opcode = self.opcodes[self.opcode_index]
|
||||
if token.value == 0 and isinstance(opcode, PUSH_SINGLE):
|
||||
token = DataToken(b'')
|
||||
if isinstance(token, DataToken):
|
||||
if isinstance(opcode, (PUSH_SINGLE, PUSH_INTEGER, PUSH_SUBSCRIPT)):
|
||||
self.push_single(opcode, token.value)
|
||||
elif isinstance(opcode, PUSH_MANY):
|
||||
self.consume_many_non_greedy()
|
||||
else:
|
||||
raise ParseError(f"DataToken found but opcode was '{opcode}'.")
|
||||
elif isinstance(token, SmallIntegerToken):
|
||||
if isinstance(opcode, SMALL_INTEGER):
|
||||
self.values[opcode.name] = token.value
|
||||
else:
|
||||
raise ParseError(f"SmallIntegerToken found but opcode was '{opcode}'.")
|
||||
elif token.value == opcode:
|
||||
pass
|
||||
else:
|
||||
raise ParseError(f"Token is '{token.value}' and opcode is '{opcode}'.")
|
||||
self.token_index += 1
|
||||
self.opcode_index += 1
|
||||
|
||||
if self.token_index < len(self.tokens):
|
||||
raise ParseError("Parse completed without all tokens being consumed.")
|
||||
|
||||
if self.opcode_index < len(self.opcodes):
|
||||
raise ParseError("Parse completed without all opcodes being consumed.")
|
||||
|
||||
return self
|
||||
|
||||
def consume_many_non_greedy(self):
|
||||
""" Allows PUSH_MANY to consume data without being greedy
|
||||
in cases when one or more PUSH_SINGLEs follow a PUSH_MANY. This will
|
||||
prioritize giving all PUSH_SINGLEs some data and only after that
|
||||
subsume the rest into PUSH_MANY.
|
||||
"""
|
||||
|
||||
token_values = []
|
||||
while self.token_index < len(self.tokens):
|
||||
token = self.tokens[self.token_index]
|
||||
if not isinstance(token, DataToken):
|
||||
self.token_index -= 1
|
||||
break
|
||||
token_values.append(token.value)
|
||||
self.token_index += 1
|
||||
|
||||
push_opcodes = []
|
||||
push_many_count = 0
|
||||
while self.opcode_index < len(self.opcodes):
|
||||
opcode = self.opcodes[self.opcode_index]
|
||||
if not is_push_data_opcode(opcode):
|
||||
self.opcode_index -= 1
|
||||
break
|
||||
if isinstance(opcode, PUSH_MANY):
|
||||
push_many_count += 1
|
||||
push_opcodes.append(opcode)
|
||||
self.opcode_index += 1
|
||||
|
||||
if push_many_count > 1:
|
||||
raise ParseError(
|
||||
"Cannot have more than one consecutive PUSH_MANY, as there is no way to tell which"
|
||||
" token value should go into which PUSH_MANY."
|
||||
)
|
||||
|
||||
if len(push_opcodes) > len(token_values):
|
||||
raise ParseError(
|
||||
"Not enough token values to match all of the PUSH_MANY and PUSH_SINGLE opcodes."
|
||||
)
|
||||
|
||||
many_opcode = push_opcodes.pop(0)
|
||||
|
||||
# consume data into PUSH_SINGLE opcodes, working backwards
|
||||
for opcode in reversed(push_opcodes):
|
||||
self.push_single(opcode, token_values.pop())
|
||||
|
||||
# finally PUSH_MANY gets everything that's left
|
||||
self.values[many_opcode.name] = token_values
|
||||
|
||||
def push_single(self, opcode, value):
|
||||
if isinstance(opcode, PUSH_SINGLE):
|
||||
self.values[opcode.name] = value
|
||||
elif isinstance(opcode, PUSH_INTEGER):
|
||||
self.values[opcode.name] = int.from_bytes(value, 'little')
|
||||
elif isinstance(opcode, PUSH_SUBSCRIPT):
|
||||
self.values[opcode.name] = Script.from_source_with_template(value, opcode.template)
|
||||
else:
|
||||
raise ParseError(f"Not a push single or subscript: {opcode}")
|
||||
|
||||
|
||||
class Template:
|
||||
|
||||
__slots__ = 'name', 'opcodes'
|
||||
|
||||
def __init__(self, name, opcodes):
|
||||
self.name = name
|
||||
self.opcodes = opcodes
|
||||
|
||||
def parse(self, tokens):
|
||||
return Parser(self.opcodes, tokens).parse().values if self.opcodes else {}
|
||||
|
||||
def generate(self, values):
|
||||
source = BCDataStream()
|
||||
for opcode in self.opcodes:
|
||||
if isinstance(opcode, PUSH_SINGLE):
|
||||
data = values[opcode.name]
|
||||
source.write_many(push_data(data))
|
||||
elif isinstance(opcode, PUSH_INTEGER):
|
||||
data = values[opcode.name]
|
||||
source.write_many(push_data(
|
||||
data.to_bytes((data.bit_length() + 7) // 8, byteorder='little')
|
||||
))
|
||||
elif isinstance(opcode, PUSH_SUBSCRIPT):
|
||||
data = values[opcode.name]
|
||||
source.write_many(push_data(data.source))
|
||||
elif isinstance(opcode, PUSH_MANY):
|
||||
for data in values[opcode.name]:
|
||||
source.write_many(push_data(data))
|
||||
elif isinstance(opcode, SMALL_INTEGER):
|
||||
data = values[opcode.name]
|
||||
source.write_many(push_small_integer(data))
|
||||
else:
|
||||
source.write_uint8(opcode)
|
||||
return source.get_bytes()
|
||||
|
||||
|
||||
class Script:
|
||||
|
||||
__slots__ = 'source', '_template', '_values', '_template_hint'
|
||||
|
||||
templates: List[Template] = []
|
||||
|
||||
NO_SCRIPT = Template('no_script', None) # special case
|
||||
|
||||
def __init__(self, source=None, template=None, values=None, template_hint=None):
|
||||
self.source = source
|
||||
self._template = template
|
||||
self._values = values
|
||||
self._template_hint = template_hint
|
||||
if source is None and template and values:
|
||||
self.generate()
|
||||
|
||||
@property
|
||||
def template(self):
|
||||
if self._template is None:
|
||||
self.parse(self._template_hint)
|
||||
return self._template
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
if self._values is None:
|
||||
self.parse(self._template_hint)
|
||||
return self._values
|
||||
|
||||
@property
|
||||
def tokens(self):
|
||||
return tokenize(BCDataStream(self.source))
|
||||
|
||||
@classmethod
|
||||
def from_source_with_template(cls, source, template):
|
||||
return cls(source, template_hint=template)
|
||||
|
||||
def parse(self, template_hint=None):
|
||||
tokens = self.tokens
|
||||
if not tokens and not template_hint:
|
||||
template_hint = self.NO_SCRIPT
|
||||
for template in chain((template_hint,), self.templates):
|
||||
if not template:
|
||||
continue
|
||||
try:
|
||||
self._values = template.parse(tokens)
|
||||
self._template = template
|
||||
return
|
||||
except ParseError:
|
||||
continue
|
||||
raise ValueError(f'No matching templates for source: {hexlify(self.source)}')
|
||||
|
||||
def generate(self):
|
||||
self.source = self.template.generate(self._values)
|
||||
|
||||
|
||||
class BaseInputScript(Script):
|
||||
""" Input / redeem script templates (aka scriptSig) """
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
REDEEM_PUBKEY = Template('pubkey', (
|
||||
PUSH_SINGLE('signature'),
|
||||
))
|
||||
REDEEM_PUBKEY_HASH = Template('pubkey_hash', (
|
||||
PUSH_SINGLE('signature'), PUSH_SINGLE('pubkey')
|
||||
))
|
||||
REDEEM_SCRIPT = Template('script', (
|
||||
SMALL_INTEGER('signatures_count'), PUSH_MANY('pubkeys'), SMALL_INTEGER('pubkeys_count'),
|
||||
OP_CHECKMULTISIG
|
||||
))
|
||||
REDEEM_SCRIPT_HASH = Template('script_hash', (
|
||||
OP_0, PUSH_MANY('signatures'), PUSH_SUBSCRIPT('script', REDEEM_SCRIPT)
|
||||
))
|
||||
|
||||
templates = [
|
||||
REDEEM_PUBKEY,
|
||||
REDEEM_PUBKEY_HASH,
|
||||
REDEEM_SCRIPT_HASH,
|
||||
REDEEM_SCRIPT
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def redeem_pubkey_hash(cls, signature, pubkey):
|
||||
return cls(template=cls.REDEEM_PUBKEY_HASH, values={
|
||||
'signature': signature,
|
||||
'pubkey': pubkey
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def redeem_script_hash(cls, signatures, pubkeys):
|
||||
return cls(template=cls.REDEEM_SCRIPT_HASH, values={
|
||||
'signatures': signatures,
|
||||
'script': cls.redeem_script(signatures, pubkeys)
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def redeem_script(cls, signatures, pubkeys):
|
||||
return cls(template=cls.REDEEM_SCRIPT, values={
|
||||
'signatures_count': len(signatures),
|
||||
'pubkeys': pubkeys,
|
||||
'pubkeys_count': len(pubkeys)
|
||||
})
|
||||
|
||||
|
||||
class BaseOutputScript(Script):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
# output / payment script templates (aka scriptPubKey)
|
||||
PAY_PUBKEY_FULL = Template('pay_pubkey_full', (
|
||||
PUSH_SINGLE('pubkey'), OP_CHECKSIG
|
||||
))
|
||||
PAY_PUBKEY_HASH = Template('pay_pubkey_hash', (
|
||||
OP_DUP, OP_HASH160, PUSH_SINGLE('pubkey_hash'), OP_EQUALVERIFY, OP_CHECKSIG
|
||||
))
|
||||
PAY_SCRIPT_HASH = Template('pay_script_hash', (
|
||||
OP_HASH160, PUSH_SINGLE('script_hash'), OP_EQUAL
|
||||
))
|
||||
PAY_SEGWIT = Template('pay_script_hash+segwit', (
|
||||
OP_0, PUSH_SINGLE('script_hash')
|
||||
))
|
||||
RETURN_DATA = Template('return_data', (
|
||||
OP_RETURN, PUSH_SINGLE('data')
|
||||
))
|
||||
|
||||
templates = [
|
||||
PAY_PUBKEY_FULL,
|
||||
PAY_PUBKEY_HASH,
|
||||
PAY_SCRIPT_HASH,
|
||||
PAY_SEGWIT,
|
||||
RETURN_DATA
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def pay_pubkey_hash(cls, pubkey_hash):
|
||||
return cls(template=cls.PAY_PUBKEY_HASH, values={
|
||||
'pubkey_hash': pubkey_hash
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def pay_script_hash(cls, script_hash):
|
||||
return cls(template=cls.PAY_SCRIPT_HASH, values={
|
||||
'script_hash': script_hash
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def return_data(cls, data):
|
||||
return cls(template=cls.RETURN_DATA, values={
|
||||
'data': data
|
||||
})
|
||||
|
||||
@property
|
||||
def is_pay_pubkey(self):
|
||||
return self.template.name.endswith('pay_pubkey_full')
|
||||
|
||||
@property
|
||||
def is_pay_pubkey_hash(self):
|
||||
return self.template.name.endswith('pay_pubkey_hash')
|
||||
|
||||
@property
|
||||
def is_pay_script_hash(self):
|
||||
return self.template.name.endswith('pay_script_hash')
|
||||
|
||||
@property
|
||||
def is_return_data(self):
|
||||
return self.template.name.endswith('return_data')
|
|
@ -1,580 +0,0 @@
|
|||
import logging
|
||||
import typing
|
||||
from typing import List, Iterable, Optional, Tuple
|
||||
from binascii import hexlify
|
||||
|
||||
from lbry.crypto.hash import sha256
|
||||
from lbry.wallet.client.basescript import BaseInputScript, BaseOutputScript
|
||||
from lbry.wallet.client.baseaccount import BaseAccount
|
||||
from lbry.wallet.client.constants import COIN, NULL_HASH32
|
||||
from lbry.wallet.client.bcd_data_stream import BCDataStream
|
||||
from lbry.wallet.client.hash import TXRef, TXRefImmutable
|
||||
from lbry.wallet.client.util import ReadOnlyList
|
||||
from lbry.error import InsufficientFundsError
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbry.wallet.client import baseledger, wallet as basewallet
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
class TXRefMutable(TXRef):
|
||||
|
||||
__slots__ = ('tx',)
|
||||
|
||||
def __init__(self, tx: 'BaseTransaction') -> None:
|
||||
super().__init__()
|
||||
self.tx = tx
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
if self._id is None:
|
||||
self._id = hexlify(self.hash[::-1]).decode()
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def hash(self):
|
||||
if self._hash is None:
|
||||
self._hash = sha256(sha256(self.tx.raw_sans_segwit))
|
||||
return self._hash
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.tx.height
|
||||
|
||||
def reset(self):
|
||||
self._id = None
|
||||
self._hash = None
|
||||
|
||||
|
||||
class TXORef:
|
||||
|
||||
__slots__ = 'tx_ref', 'position'
|
||||
|
||||
def __init__(self, tx_ref: TXRef, position: int) -> None:
|
||||
self.tx_ref = tx_ref
|
||||
self.position = position
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return f'{self.tx_ref.id}:{self.position}'
|
||||
|
||||
@property
|
||||
def hash(self):
|
||||
return self.tx_ref.hash + BCDataStream.uint32.pack(self.position)
|
||||
|
||||
@property
|
||||
def is_null(self):
|
||||
return self.tx_ref.is_null
|
||||
|
||||
@property
|
||||
def txo(self) -> Optional['BaseOutput']:
|
||||
return None
|
||||
|
||||
|
||||
class TXORefResolvable(TXORef):
|
||||
|
||||
__slots__ = ('_txo',)
|
||||
|
||||
def __init__(self, txo: 'BaseOutput') -> None:
|
||||
assert txo.tx_ref is not None
|
||||
assert txo.position is not None
|
||||
super().__init__(txo.tx_ref, txo.position)
|
||||
self._txo = txo
|
||||
|
||||
@property
|
||||
def txo(self):
|
||||
return self._txo
|
||||
|
||||
|
||||
class InputOutput:
|
||||
|
||||
__slots__ = 'tx_ref', 'position'
|
||||
|
||||
def __init__(self, tx_ref: TXRef = None, position: int = None) -> None:
|
||||
self.tx_ref = tx_ref
|
||||
self.position = position
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
""" Size of this input / output in bytes. """
|
||||
stream = BCDataStream()
|
||||
self.serialize_to(stream)
|
||||
return len(stream.get_bytes())
|
||||
|
||||
def get_fee(self, ledger):
|
||||
return self.size * ledger.fee_per_byte
|
||||
|
||||
def serialize_to(self, stream, alternate_script=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseInput(InputOutput):
|
||||
|
||||
script_class = BaseInputScript
|
||||
|
||||
NULL_SIGNATURE = b'\x00'*72
|
||||
NULL_PUBLIC_KEY = b'\x00'*33
|
||||
|
||||
__slots__ = 'txo_ref', 'sequence', 'coinbase', 'script'
|
||||
|
||||
def __init__(self, txo_ref: TXORef, script: BaseInputScript, sequence: int = 0xFFFFFFFF,
|
||||
tx_ref: TXRef = None, position: int = None) -> None:
|
||||
super().__init__(tx_ref, position)
|
||||
self.txo_ref = txo_ref
|
||||
self.sequence = sequence
|
||||
self.coinbase = script if txo_ref.is_null else None
|
||||
self.script = script if not txo_ref.is_null else None
|
||||
|
||||
@property
|
||||
def is_coinbase(self):
|
||||
return self.coinbase is not None
|
||||
|
||||
@classmethod
|
||||
def spend(cls, txo: 'BaseOutput') -> 'BaseInput':
|
||||
""" Create an input to spend the output."""
|
||||
assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.'
|
||||
script = cls.script_class.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY)
|
||||
return cls(txo.ref, script)
|
||||
|
||||
@property
|
||||
def amount(self) -> int:
|
||||
""" Amount this input adds to the transaction. """
|
||||
if self.txo_ref.txo is None:
|
||||
raise ValueError('Cannot resolve output to get amount.')
|
||||
return self.txo_ref.txo.amount
|
||||
|
||||
@property
|
||||
def is_my_account(self) -> Optional[bool]:
|
||||
""" True if the output this input spends is yours. """
|
||||
if self.txo_ref.txo is None:
|
||||
return False
|
||||
return self.txo_ref.txo.is_my_account
|
||||
|
||||
@classmethod
|
||||
def deserialize_from(cls, stream):
|
||||
tx_ref = TXRefImmutable.from_hash(stream.read(32), -1)
|
||||
position = stream.read_uint32()
|
||||
script = stream.read_string()
|
||||
sequence = stream.read_uint32()
|
||||
return cls(
|
||||
TXORef(tx_ref, position),
|
||||
cls.script_class(script) if not tx_ref.is_null else script,
|
||||
sequence
|
||||
)
|
||||
|
||||
def serialize_to(self, stream, alternate_script=None):
|
||||
stream.write(self.txo_ref.tx_ref.hash)
|
||||
stream.write_uint32(self.txo_ref.position)
|
||||
if alternate_script is not None:
|
||||
stream.write_string(alternate_script)
|
||||
else:
|
||||
if self.is_coinbase:
|
||||
stream.write_string(self.coinbase)
|
||||
else:
|
||||
stream.write_string(self.script.source)
|
||||
stream.write_uint32(self.sequence)
|
||||
|
||||
|
||||
class BaseOutputEffectiveAmountEstimator:
|
||||
|
||||
__slots__ = 'txo', 'txi', 'fee', 'effective_amount'
|
||||
|
||||
def __init__(self, ledger: 'baseledger.BaseLedger', txo: 'BaseOutput') -> None:
|
||||
self.txo = txo
|
||||
self.txi = ledger.transaction_class.input_class.spend(txo)
|
||||
self.fee: int = self.txi.get_fee(ledger)
|
||||
self.effective_amount: int = txo.amount - self.fee
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.effective_amount < other.effective_amount
|
||||
|
||||
|
||||
class BaseOutput(InputOutput):
|
||||
|
||||
script_class = BaseOutputScript
|
||||
estimator_class = BaseOutputEffectiveAmountEstimator
|
||||
|
||||
__slots__ = 'amount', 'script', 'is_change', 'is_my_account'
|
||||
|
||||
def __init__(self, amount: int, script: BaseOutputScript,
|
||||
tx_ref: TXRef = None, position: int = None,
|
||||
is_change: Optional[bool] = None, is_my_account: Optional[bool] = None
|
||||
) -> None:
|
||||
super().__init__(tx_ref, position)
|
||||
self.amount = amount
|
||||
self.script = script
|
||||
self.is_change = is_change
|
||||
self.is_my_account = is_my_account
|
||||
|
||||
def update_annotations(self, annotated):
|
||||
if annotated is None:
|
||||
self.is_change = False
|
||||
self.is_my_account = False
|
||||
else:
|
||||
self.is_change = annotated.is_change
|
||||
self.is_my_account = annotated.is_my_account
|
||||
|
||||
@property
|
||||
def ref(self):
|
||||
return TXORefResolvable(self)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.ref.id
|
||||
|
||||
@property
|
||||
def pubkey_hash(self):
|
||||
return self.script.values['pubkey_hash']
|
||||
|
||||
@property
|
||||
def has_address(self):
|
||||
return 'pubkey_hash' in self.script.values
|
||||
|
||||
def get_address(self, ledger):
|
||||
return ledger.hash160_to_address(self.pubkey_hash)
|
||||
|
||||
def get_estimator(self, ledger):
|
||||
return self.estimator_class(ledger, self)
|
||||
|
||||
@classmethod
|
||||
def pay_pubkey_hash(cls, amount, pubkey_hash):
|
||||
return cls(amount, cls.script_class.pay_pubkey_hash(pubkey_hash))
|
||||
|
||||
@classmethod
|
||||
def deserialize_from(cls, stream):
|
||||
return cls(
|
||||
amount=stream.read_uint64(),
|
||||
script=cls.script_class(stream.read_string())
|
||||
)
|
||||
|
||||
def serialize_to(self, stream, alternate_script=None):
|
||||
stream.write_uint64(self.amount)
|
||||
stream.write_string(self.script.source)
|
||||
|
||||
|
||||
class BaseTransaction:
|
||||
|
||||
input_class = BaseInput
|
||||
output_class = BaseOutput
|
||||
|
||||
def __init__(self, raw=None, version: int = 1, locktime: int = 0, is_verified: bool = False,
|
||||
height: int = -2, position: int = -1) -> None:
|
||||
self._raw = raw
|
||||
self._raw_sans_segwit = None
|
||||
self.is_segwit_flag = 0
|
||||
self.witnesses: List[bytes] = []
|
||||
self.ref = TXRefMutable(self)
|
||||
self.version = version
|
||||
self.locktime = locktime
|
||||
self._inputs: List[BaseInput] = []
|
||||
self._outputs: List[BaseOutput] = []
|
||||
self.is_verified = is_verified
|
||||
# Height Progression
|
||||
# -2: not broadcast
|
||||
# -1: in mempool but has unconfirmed inputs
|
||||
# 0: in mempool and all inputs confirmed
|
||||
# +num: confirmed in a specific block (height)
|
||||
self.height = height
|
||||
self.position = position
|
||||
if raw is not None:
|
||||
self._deserialize()
|
||||
|
||||
@property
|
||||
def is_broadcast(self):
|
||||
return self.height > -2
|
||||
|
||||
@property
|
||||
def is_mempool(self):
|
||||
return self.height in (-1, 0)
|
||||
|
||||
@property
|
||||
def is_confirmed(self):
|
||||
return self.height > 0
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.ref.id
|
||||
|
||||
@property
|
||||
def hash(self):
|
||||
return self.ref.hash
|
||||
|
||||
@property
|
||||
def raw(self):
|
||||
if self._raw is None:
|
||||
self._raw = self._serialize()
|
||||
return self._raw
|
||||
|
||||
@property
|
||||
def raw_sans_segwit(self):
|
||||
if self.is_segwit_flag:
|
||||
if self._raw_sans_segwit is None:
|
||||
self._raw_sans_segwit = self._serialize(sans_segwit=True)
|
||||
return self._raw_sans_segwit
|
||||
return self.raw
|
||||
|
||||
def _reset(self):
|
||||
self._raw = None
|
||||
self._raw_sans_segwit = None
|
||||
self.ref.reset()
|
||||
|
||||
@property
|
||||
def inputs(self) -> ReadOnlyList[BaseInput]:
|
||||
return ReadOnlyList(self._inputs)
|
||||
|
||||
@property
|
||||
def outputs(self) -> ReadOnlyList[BaseOutput]:
|
||||
return ReadOnlyList(self._outputs)
|
||||
|
||||
def _add(self, existing_ios: List, new_ios: Iterable[InputOutput], reset=False) -> 'BaseTransaction':
|
||||
for txio in new_ios:
|
||||
txio.tx_ref = self.ref
|
||||
txio.position = len(existing_ios)
|
||||
existing_ios.append(txio)
|
||||
if reset:
|
||||
self._reset()
|
||||
return self
|
||||
|
||||
def add_inputs(self, inputs: Iterable[BaseInput]) -> 'BaseTransaction':
|
||||
return self._add(self._inputs, inputs, True)
|
||||
|
||||
def add_outputs(self, outputs: Iterable[BaseOutput]) -> 'BaseTransaction':
|
||||
return self._add(self._outputs, outputs, True)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
""" Size in bytes of the entire transaction. """
|
||||
return len(self.raw)
|
||||
|
||||
@property
|
||||
def base_size(self) -> int:
|
||||
""" Size of transaction without inputs or outputs in bytes. """
|
||||
return (
|
||||
self.size
|
||||
- sum(txi.size for txi in self._inputs)
|
||||
- sum(txo.size for txo in self._outputs)
|
||||
)
|
||||
|
||||
@property
|
||||
def input_sum(self):
|
||||
return sum(i.amount for i in self.inputs if i.txo_ref.txo is not None)
|
||||
|
||||
@property
|
||||
def output_sum(self):
|
||||
return sum(o.amount for o in self.outputs)
|
||||
|
||||
@property
|
||||
def net_account_balance(self) -> int:
|
||||
balance = 0
|
||||
for txi in self.inputs:
|
||||
if txi.txo_ref.txo is None:
|
||||
continue
|
||||
if txi.is_my_account is None:
|
||||
raise ValueError(
|
||||
"Cannot access net_account_balance if inputs/outputs do not "
|
||||
"have is_my_account set properly."
|
||||
)
|
||||
if txi.is_my_account:
|
||||
balance -= txi.amount
|
||||
for txo in self.outputs:
|
||||
if txo.is_my_account is None:
|
||||
raise ValueError(
|
||||
"Cannot access net_account_balance if inputs/outputs do not "
|
||||
"have is_my_account set properly."
|
||||
)
|
||||
if txo.is_my_account:
|
||||
balance += txo.amount
|
||||
return balance
|
||||
|
||||
@property
|
||||
def fee(self) -> int:
|
||||
return self.input_sum - self.output_sum
|
||||
|
||||
def get_base_fee(self, ledger) -> int:
|
||||
""" Fee for base tx excluding inputs and outputs. """
|
||||
return self.base_size * ledger.fee_per_byte
|
||||
|
||||
def get_effective_input_sum(self, ledger) -> int:
|
||||
""" Sum of input values *minus* the cost involved to spend them. """
|
||||
return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs)
|
||||
|
||||
def get_total_output_sum(self, ledger) -> int:
|
||||
""" Sum of output values *plus* the cost involved to spend them. """
|
||||
return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs)
|
||||
|
||||
def _serialize(self, with_inputs: bool = True, sans_segwit: bool = False) -> bytes:
|
||||
stream = BCDataStream()
|
||||
stream.write_uint32(self.version)
|
||||
if with_inputs:
|
||||
stream.write_compact_size(len(self._inputs))
|
||||
for txin in self._inputs:
|
||||
txin.serialize_to(stream)
|
||||
stream.write_compact_size(len(self._outputs))
|
||||
for txout in self._outputs:
|
||||
txout.serialize_to(stream)
|
||||
stream.write_uint32(self.locktime)
|
||||
return stream.get_bytes()
|
||||
|
||||
def _serialize_for_signature(self, signing_input: int) -> bytes:
|
||||
stream = BCDataStream()
|
||||
stream.write_uint32(self.version)
|
||||
stream.write_compact_size(len(self._inputs))
|
||||
for i, txin in enumerate(self._inputs):
|
||||
if signing_input == i:
|
||||
assert txin.txo_ref.txo is not None
|
||||
txin.serialize_to(stream, txin.txo_ref.txo.script.source)
|
||||
else:
|
||||
txin.serialize_to(stream, b'')
|
||||
stream.write_compact_size(len(self._outputs))
|
||||
for txout in self._outputs:
|
||||
txout.serialize_to(stream)
|
||||
stream.write_uint32(self.locktime)
|
||||
stream.write_uint32(self.signature_hash_type(1)) # signature hash type: SIGHASH_ALL
|
||||
return stream.get_bytes()
|
||||
|
||||
def _deserialize(self):
|
||||
if self._raw is not None:
|
||||
stream = BCDataStream(self._raw)
|
||||
self.version = stream.read_uint32()
|
||||
input_count = stream.read_compact_size()
|
||||
if input_count == 0:
|
||||
self.is_segwit_flag = stream.read_uint8()
|
||||
input_count = stream.read_compact_size()
|
||||
self._add(self._inputs, [
|
||||
self.input_class.deserialize_from(stream) for _ in range(input_count)
|
||||
])
|
||||
output_count = stream.read_compact_size()
|
||||
self._add(self._outputs, [
|
||||
self.output_class.deserialize_from(stream) for _ in range(output_count)
|
||||
])
|
||||
if self.is_segwit_flag:
|
||||
# drain witness portion of transaction
|
||||
# too many witnesses for no crime
|
||||
self.witnesses = []
|
||||
for _ in range(input_count):
|
||||
for _ in range(stream.read_compact_size()):
|
||||
self.witnesses.append(stream.read(stream.read_compact_size()))
|
||||
self.locktime = stream.read_uint32()
|
||||
|
||||
@classmethod
|
||||
def ensure_all_have_same_ledger_and_wallet(
|
||||
cls, funding_accounts: Iterable[BaseAccount],
|
||||
change_account: BaseAccount = None) -> Tuple['baseledger.BaseLedger', 'basewallet.Wallet']:
|
||||
ledger = wallet = None
|
||||
for account in funding_accounts:
|
||||
if ledger is None:
|
||||
ledger = account.ledger
|
||||
wallet = account.wallet
|
||||
if ledger != account.ledger:
|
||||
raise ValueError(
|
||||
'All funding accounts used to create a transaction must be on the same ledger.'
|
||||
)
|
||||
if wallet != account.wallet:
|
||||
raise ValueError(
|
||||
'All funding accounts used to create a transaction must be from the same wallet.'
|
||||
)
|
||||
if change_account is not None:
|
||||
if change_account.ledger != ledger:
|
||||
raise ValueError('Change account must use same ledger as funding accounts.')
|
||||
if change_account.wallet != wallet:
|
||||
raise ValueError('Change account must use same wallet as funding accounts.')
|
||||
if ledger is None:
|
||||
raise ValueError('No ledger found.')
|
||||
if wallet is None:
|
||||
raise ValueError('No wallet found.')
|
||||
return ledger, wallet
|
||||
|
||||
@classmethod
|
||||
async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
|
||||
funding_accounts: Iterable[BaseAccount], change_account: BaseAccount,
|
||||
sign: bool = True):
|
||||
""" Find optimal set of inputs when only outputs are provided; add change
|
||||
outputs if only inputs are provided or if inputs are greater than outputs. """
|
||||
|
||||
tx = cls() \
|
||||
.add_inputs(inputs) \
|
||||
.add_outputs(outputs)
|
||||
|
||||
ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
||||
|
||||
# value of the outputs plus associated fees
|
||||
cost = (
|
||||
tx.get_base_fee(ledger) +
|
||||
tx.get_total_output_sum(ledger)
|
||||
)
|
||||
# value of the inputs less the cost to spend those inputs
|
||||
payment = tx.get_effective_input_sum(ledger)
|
||||
|
||||
try:
|
||||
|
||||
for _ in range(5):
|
||||
|
||||
if payment < cost:
|
||||
deficit = cost - payment
|
||||
spendables = await ledger.get_spendable_utxos(deficit, funding_accounts)
|
||||
if not spendables:
|
||||
raise InsufficientFundsError()
|
||||
payment += sum(s.effective_amount for s in spendables)
|
||||
tx.add_inputs(s.txi for s in spendables)
|
||||
|
||||
cost_of_change = (
|
||||
tx.get_base_fee(ledger) +
|
||||
cls.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(ledger)
|
||||
)
|
||||
if payment > cost:
|
||||
change = payment - cost
|
||||
if change > cost_of_change:
|
||||
change_address = await change_account.change.get_or_create_usable_address()
|
||||
change_hash160 = change_account.ledger.address_to_hash160(change_address)
|
||||
change_amount = change - cost_of_change
|
||||
change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160)
|
||||
change_output.is_change = True
|
||||
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
|
||||
|
||||
if tx._outputs:
|
||||
break
|
||||
# this condition and the outer range(5) loop cover an edge case
|
||||
# whereby a single input is just enough to cover the fee and
|
||||
# has some change left over, but the change left over is less
|
||||
# than the cost_of_change: thus the input is completely
|
||||
# consumed and no output is added, which is an invalid tx.
|
||||
# to be able to spend this input we must increase the cost
|
||||
# of the TX and run through the balance algorithm a second time
|
||||
# adding an extra input and change output, making tx valid.
|
||||
# we do this 5 times in case the other UTXOs added are also
|
||||
# less than the fee, after 5 attempts we give up and go home
|
||||
cost += cost_of_change + 1
|
||||
|
||||
if sign:
|
||||
await tx.sign(funding_accounts)
|
||||
|
||||
except Exception as e:
|
||||
log.exception('Failed to create transaction:')
|
||||
await ledger.release_tx(tx)
|
||||
raise e
|
||||
|
||||
return tx
|
||||
|
||||
@staticmethod
|
||||
def signature_hash_type(hash_type):
|
||||
return hash_type
|
||||
|
||||
async def sign(self, funding_accounts: Iterable[BaseAccount]):
|
||||
ledger, wallet = self.ensure_all_have_same_ledger_and_wallet(funding_accounts)
|
||||
for i, txi in enumerate(self._inputs):
|
||||
assert txi.script is not None
|
||||
assert txi.txo_ref.txo is not None
|
||||
txo_script = txi.txo_ref.txo.script
|
||||
if txo_script.is_pay_pubkey_hash:
|
||||
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
|
||||
private_key = await ledger.get_private_key_for_address(wallet, address)
|
||||
assert private_key is not None, 'Cannot find private key for signing output.'
|
||||
tx = self._serialize_for_signature(i)
|
||||
txi.script.values['signature'] = \
|
||||
private_key.sign(tx) + bytes((self.signature_hash_type(1),))
|
||||
txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
|
||||
txi.script.generate()
|
||||
else:
|
||||
raise NotImplementedError("Don't know how to spend this output.")
|
||||
self._reset()
|
|
@ -1,6 +0,0 @@
|
|||
NULL_HASH32 = b'\x00'*32
|
||||
|
||||
CENT = 1000000
|
||||
COIN = 100*CENT
|
||||
|
||||
TIMEOUT = 30.0
|
Loading…
Add table
Reference in a new issue