forked from LBRYCommunity/lbry-sdk
merged torba base classes with lbry sub-classes
This commit is contained in:
parent
c8d72b59c0
commit
c9e410a6f4
23 changed files with 0 additions and 3475 deletions
|
@ -1,485 +0,0 @@
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import asyncio
|
|
||||||
import random
|
|
||||||
import typing
|
|
||||||
from typing import Dict, Tuple, Type, Optional, Any, List
|
|
||||||
|
|
||||||
from lbry.crypto.hash import sha256
|
|
||||||
from lbry.crypto.crypt import aes_encrypt, aes_decrypt
|
|
||||||
from lbry.wallet.client.bip32 import PrivateKey, PubKey, from_extended_key_string
|
|
||||||
from lbry.wallet.client.mnemonic import Mnemonic
|
|
||||||
from lbry.wallet.client.constants import COIN
|
|
||||||
from lbry.error import InvalidPasswordError
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from lbry.wallet.client import baseledger, wallet as basewallet
|
|
||||||
|
|
||||||
|
|
||||||
class AddressManager:
    """ Base class for an account's address chain (receiving or change).

    Subclasses decide how public/private keys are derived and how the pool
    of unused addresses (the "gap") is maintained; this class provides the
    shared plumbing for querying addresses from the ledger's database.
    """

    # discriminator used when (de)serializing to/from account dicts
    name: str

    __slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock'

    def __init__(self, account, public_key, chain_number):
        self.account = account
        self.public_key = public_key
        self.chain_number = chain_number
        # serializes address generation so concurrent callers cannot
        # create overlapping key ranges
        self.address_generator_lock = asyncio.Lock()

    @classmethod
    def from_dict(cls, account: 'BaseAccount', d: dict) \
            -> Tuple['AddressManager', 'AddressManager']:
        """ Build the (receiving, change) manager pair from a serialized dict. """
        raise NotImplementedError

    @classmethod
    def to_dict(cls, receiving: 'AddressManager', change: 'AddressManager') -> Dict:
        """ Serialize a (receiving, change) pair, omitting empty chain configs. """
        d: Dict[str, Any] = {'name': cls.name}
        receiving_dict = receiving.to_dict_instance()
        if receiving_dict:
            d['receiving'] = receiving_dict
        change_dict = change.to_dict_instance()
        if change_dict:
            d['change'] = change_dict
        return d

    def merge(self, d: dict):
        # subclasses may absorb updated settings from a newer serialized dict
        pass

    def to_dict_instance(self) -> Optional[dict]:
        """ Serialize this chain's own settings; None when nothing to persist. """
        raise NotImplementedError

    def _query_addresses(self, **constraints):
        # restrict every query to this account and this chain
        return self.account.ledger.db.get_addresses(
            accounts=[self.account],
            chain=self.chain_number,
            **constraints
        )

    def get_private_key(self, index: int) -> PrivateKey:
        raise NotImplementedError

    def get_public_key(self, index: int) -> PubKey:
        raise NotImplementedError

    async def get_max_gap(self):
        raise NotImplementedError

    async def ensure_address_gap(self):
        raise NotImplementedError

    def get_address_records(self, only_usable: bool = False, **constraints):
        raise NotImplementedError

    async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]:
        """ Return just the address strings from matching address records. """
        records = await self.get_address_records(only_usable=only_usable, **constraints)
        return [r['address'] for r in records]

    async def get_or_create_usable_address(self) -> str:
        """ Pick a random usable address, generating new ones if none exist. """
        addresses = await self.get_addresses(only_usable=True, limit=10)
        if addresses:
            return random.choice(addresses)
        addresses = await self.ensure_address_gap()
        return addresses[0]
||||||
|
|
||||||
class HierarchicalDeterministic(AddressManager):
    """ Implements simple version of Bitcoin Hierarchical Deterministic key management. """

    name: str = "deterministic-chain"

    __slots__ = 'gap', 'maximum_uses_per_address'

    def __init__(self, account: 'BaseAccount', chain: int, gap: int, maximum_uses_per_address: int) -> None:
        super().__init__(account, account.public_key.child(chain), chain)
        # number of consecutive unused addresses to keep available
        self.gap = gap
        # how many times an address may be used before it stops being "usable"
        self.maximum_uses_per_address = maximum_uses_per_address

    @classmethod
    def from_dict(cls, account: 'BaseAccount', d: dict) -> Tuple[AddressManager, AddressManager]:
        # chain 0 = receiving (gap 20), chain 1 = change (gap 6)
        return (
            cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
            cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
        )

    def merge(self, d: dict):
        """ Absorb gap/usage settings from a newer serialized dict. """
        self.gap = d.get('gap', self.gap)
        self.maximum_uses_per_address = d.get('maximum_uses_per_address', self.maximum_uses_per_address)

    def to_dict_instance(self):
        return {'gap': self.gap, 'maximum_uses_per_address': self.maximum_uses_per_address}

    def get_private_key(self, index: int) -> PrivateKey:
        return self.account.private_key.child(self.chain_number).child(index)

    def get_public_key(self, index: int) -> PubKey:
        return self.account.public_key.child(self.chain_number).child(index)

    async def get_max_gap(self) -> int:
        """ Largest run of consecutive unused addresses that is followed by a
        used address (a trailing unused run is intentionally not counted). """
        addresses = await self._query_addresses(order_by="n asc")
        max_gap = 0
        current_gap = 0
        for address in addresses:
            if address['used_times'] == 0:
                current_gap += 1
            else:
                max_gap = max(max_gap, current_gap)
                current_gap = 0
        return max_gap

    async def ensure_address_gap(self) -> List[str]:
        """ Generate addresses until `gap` consecutive unused ones exist;
        returns the newly created addresses (empty list when none needed). """
        async with self.address_generator_lock:
            # newest addresses first; count how many at the tip are unused
            addresses = await self._query_addresses(limit=self.gap, order_by="n desc")

            existing_gap = 0
            for address in addresses:
                if address['used_times'] == 0:
                    existing_gap += 1
                else:
                    break

            if existing_gap == self.gap:
                return []

            # continue numbering after the newest existing key
            start = addresses[0]['pubkey'].n+1 if addresses else 0
            end = start + (self.gap - existing_gap)
            new_keys = await self._generate_keys(start, end-1)
            await self.account.ledger.announce_addresses(self, new_keys)
            return new_keys

    async def _generate_keys(self, start: int, end: int) -> List[str]:
        """ Derive and persist keys for the inclusive index range [start, end].
        Caller must hold address_generator_lock. """
        if not self.address_generator_lock.locked():
            raise RuntimeError('Should not be called outside of address_generator_lock.')
        keys = [self.public_key.child(index) for index in range(start, end+1)]
        await self.account.ledger.db.add_keys(self.account, self.chain_number, keys)
        return [key.address for key in keys]

    def get_address_records(self, only_usable: bool = False, **constraints):
        if only_usable:
            constraints['used_times__lt'] = self.maximum_uses_per_address
        if 'order_by' not in constraints:
            # least-used, oldest addresses first
            constraints['order_by'] = "used_times asc, n asc"
        return self._query_addresses(**constraints)
|
|
||||||
class SingleKey(AddressManager):
    """ Single Key address manager always returns the same address for all operations. """

    name: str = "single-address"

    __slots__ = ()

    @classmethod
    def from_dict(cls, account: 'BaseAccount', d: dict)\
            -> Tuple[AddressManager, AddressManager]:
        # one manager serves as both the receiving and the change chain
        manager = cls(account, account.public_key, 0)
        return manager, manager

    def to_dict_instance(self):
        # nothing configurable to persist for a single fixed address
        return None

    def get_private_key(self, index: int) -> PrivateKey:
        return self.account.private_key

    def get_public_key(self, index: int) -> PubKey:
        return self.account.public_key

    async def get_max_gap(self) -> int:
        # a single fixed address can never have an address gap
        return 0

    async def ensure_address_gap(self) -> List[str]:
        """ Register the account's one address on first use; no-op afterwards. """
        async with self.address_generator_lock:
            if await self.get_address_records():
                return []
            await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key])
            new_keys = [self.public_key.address]
            await self.account.ledger.announce_addresses(self, new_keys)
            return new_keys

    def get_address_records(self, only_usable: bool = False, **constraints):
        # the single address is always considered usable, so the flag is ignored
        return self._query_addresses(**constraints)
||||||
|
|
||||||
class BaseAccount:
    """ A wallet account: key material, address management and coin queries.

    Holds a key pair (possibly encrypted at rest), delegates address
    generation to an AddressManager strategy and registers itself with its
    ledger and wallet on construction.
    """

    mnemonic_class = Mnemonic
    private_key_class = PrivateKey
    public_key_class = PubKey
    # maps serialized address_generator['name'] -> strategy class
    address_generators: Dict[str, Type[AddressManager]] = {
        SingleKey.name: SingleKey,
        HierarchicalDeterministic.name: HierarchicalDeterministic,
    }

    def __init__(self, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str,
                 seed: str, private_key_string: str, encrypted: bool,
                 private_key: Optional[PrivateKey], public_key: PubKey,
                 address_generator: dict, modified_on: float) -> None:
        self.ledger = ledger
        self.wallet = wallet
        # the account id is its public key's address
        self.id = public_key.address
        self.name = name
        self.seed = seed
        self.modified_on = modified_on
        self.private_key_string = private_key_string
        # AES init vectors keyed by 'seed'/'private_key'; reused so repeated
        # encryption of the same material is stable
        self.init_vectors: Dict[str, bytes] = {}
        self.encrypted = encrypted
        self.private_key = private_key
        self.public_key = public_key
        generator_name = address_generator.get('name', HierarchicalDeterministic.name)
        self.address_generator = self.address_generators[generator_name]
        self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
        # SingleKey returns one manager for both chains; the set dedupes it
        self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
        ledger.add_account(self)
        wallet.add_account(self)

    def get_init_vector(self, key) -> Optional[bytes]:
        """ Return the AES IV for `key`, creating a random one on first use. """
        init_vector = self.init_vectors.get(key, None)
        if init_vector is None:
            init_vector = self.init_vectors[key] = os.urandom(16)
        return init_vector

    @classmethod
    def generate(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet',
                 name: str = None, address_generator: dict = None):
        """ Create a brand new account with a freshly generated seed. """
        return cls.from_dict(ledger, wallet, {
            'name': name,
            'seed': cls.mnemonic_class().make_seed(),
            'address_generator': address_generator or {}
        })

    @classmethod
    def get_private_key_from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str):
        """ Derive the master private key from a mnemonic seed. """
        return cls.private_key_class.from_seed(
            ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password)
        )

    @classmethod
    def keys_from_dict(cls, ledger: 'baseledger.BaseLedger', d: dict) \
            -> Tuple[str, Optional[PrivateKey], PubKey]:
        """ Extract (seed, private_key, public_key) from a serialized account.

        The private key is only derived when the account is not encrypted;
        the public key (always stored in the clear) is loaded either way.
        """
        seed = d.get('seed', '')
        private_key_string = d.get('private_key', '')
        private_key = None
        public_key = None
        encrypted = d.get('encrypted', False)
        if not encrypted:
            if seed:
                # a seed takes precedence over a stored private key string
                private_key = cls.get_private_key_from_seed(ledger, seed, '')
                public_key = private_key.public_key
            elif private_key_string:
                private_key = from_extended_key_string(ledger, private_key_string)
                public_key = private_key.public_key
        if public_key is None:
            public_key = from_extended_key_string(ledger, d['public_key'])
        return seed, private_key, public_key

    @classmethod
    def from_dict(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', d: dict):
        """ Deserialize an account; missing fields get sensible defaults. """
        seed, private_key, public_key = cls.keys_from_dict(ledger, d)
        name = d.get('name')
        if not name:
            name = f'Account #{public_key.address}'
        return cls(
            ledger=ledger,
            wallet=wallet,
            name=name,
            seed=seed,
            private_key_string=d.get('private_key', ''),
            encrypted=d.get('encrypted', False),
            private_key=private_key,
            public_key=public_key,
            address_generator=d.get('address_generator', {}),
            modified_on=d.get('modified_on', time.time())
        )

    def to_dict(self, encrypt_password: str = None):
        """ Serialize the account; when `encrypt_password` is given, the seed
        and private key are encrypted in the output only — the in-memory
        account stays decrypted. """
        private_key_string, seed = self.private_key_string, self.seed
        if not self.encrypted and self.private_key:
            private_key_string = self.private_key.extended_key_string()
        if not self.encrypted and encrypt_password:
            if private_key_string:
                private_key_string = aes_encrypt(
                    encrypt_password, private_key_string, self.get_init_vector('private_key')
                )
            if seed:
                seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed'))
        return {
            'ledger': self.ledger.get_id(),
            'name': self.name,
            'seed': seed,
            'encrypted': bool(self.encrypted or encrypt_password),
            'private_key': private_key_string,
            'public_key': self.public_key.extended_key_string(),
            'address_generator': self.address_generator.to_dict(self.receiving, self.change),
            'modified_on': self.modified_on
        }

    def merge(self, d: dict):
        """ Absorb changes from a serialized copy that is newer than ours. """
        if d.get('modified_on', 0) > self.modified_on:
            self.name = d['name']
            self.modified_on = d.get('modified_on', time.time())
        # the generation strategy itself must never change in a merge
        assert self.address_generator.name == d['address_generator']['name']
        for chain_name in ('change', 'receiving'):
            if chain_name in d['address_generator']:
                chain_object = getattr(self, chain_name)
                chain_object.merge(d['address_generator'][chain_name])

    @property
    def hash(self) -> bytes:
        """ Stable hash of the serialized account (decrypted accounts only). """
        assert not self.encrypted, "Cannot hash an encrypted account."
        return sha256(json.dumps(self.to_dict()).encode())

    async def get_details(self, show_seed=False, **kwargs):
        """ Summary dict (balance, keys, generator config) for display/API. """
        satoshis = await self.get_balance(**kwargs)
        details = {
            'id': self.id,
            'name': self.name,
            'ledger': self.ledger.get_id(),
            'coins': round(satoshis/COIN, 2),
            'satoshis': satoshis,
            'encrypted': self.encrypted,
            'public_key': self.public_key.extended_key_string(),
            'address_generator': self.address_generator.to_dict(self.receiving, self.change)
        }
        if show_seed:
            details['seed'] = self.seed
        return details

    def decrypt(self, password: str) -> bool:
        """ Decrypt seed and private key in place; returns False (leaving
        state untouched) when the password is wrong. """
        assert self.encrypted, "Key is not encrypted."
        try:
            seed = self._decrypt_seed(password)
        except (ValueError, InvalidPasswordError):
            return False
        try:
            private_key = self._decrypt_private_key_string(password)
        except (TypeError, ValueError, InvalidPasswordError):
            return False
        self.seed = seed
        self.private_key = private_key
        self.private_key_string = ""
        self.encrypted = False
        return True

    def _decrypt_private_key_string(self, password: str) -> Optional[PrivateKey]:
        if not self.private_key_string:
            return None
        # keep the recovered IV so re-encrypting reproduces the same ciphertext
        private_key_string, self.init_vectors['private_key'] = aes_decrypt(password, self.private_key_string)
        if not private_key_string:
            return None
        return from_extended_key_string(
            self.ledger, private_key_string
        )

    def _decrypt_seed(self, password: str) -> str:
        if not self.seed:
            return ""
        seed, self.init_vectors['seed'] = aes_decrypt(password, self.seed)
        if not seed:
            return ""
        try:
            # validate that the decrypted bytes really are a mnemonic
            Mnemonic().mnemonic_decode(seed)
        except IndexError:
            # failed to decode the seed, this either means it decrypted and is invalid
            # or that we hit an edge case where an incorrect password gave valid padding
            raise ValueError("Failed to decode seed.")
        return seed

    def encrypt(self, password: str) -> bool:
        """ Encrypt seed and private key in place; drops the plaintext key. """
        assert not self.encrypted, "Key is already encrypted."
        if self.seed:
            self.seed = aes_encrypt(password, self.seed, self.get_init_vector('seed'))
        if isinstance(self.private_key, PrivateKey):
            self.private_key_string = aes_encrypt(
                password, self.private_key.extended_key_string(), self.get_init_vector('private_key')
            )
            self.private_key = None
        self.encrypted = True
        return True

    async def ensure_address_gap(self):
        """ Top up the unused-address pool on every chain; returns the new addresses. """
        addresses = []
        for address_manager in self.address_managers.values():
            new_addresses = await address_manager.ensure_address_gap()
            addresses.extend(new_addresses)
        return addresses

    async def get_addresses(self, **constraints) -> List[str]:
        rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints)
        return [r[0] for r in rows]

    def get_address_records(self, **constraints):
        return self.ledger.db.get_addresses(accounts=[self], **constraints)

    def get_address_count(self, **constraints):
        return self.ledger.db.get_address_count(accounts=[self], **constraints)

    def get_private_key(self, chain: int, index: int) -> PrivateKey:
        assert not self.encrypted, "Cannot get private key on encrypted wallet account."
        return self.address_managers[chain].get_private_key(index)

    def get_public_key(self, chain: int, index: int) -> PubKey:
        return self.address_managers[chain].get_public_key(index)

    def get_balance(self, confirmations: int = 0, **constraints):
        if confirmations > 0:
            # only count txos buried at least `confirmations` blocks deep
            height = self.ledger.headers.height - (confirmations-1)
            constraints.update({'height__lte': height, 'height__gt': 0})
        return self.ledger.db.get_balance(accounts=[self], **constraints)

    async def get_max_gap(self):
        change_gap = await self.change.get_max_gap()
        receiving_gap = await self.receiving.get_max_gap()
        return {
            'max_change_gap': change_gap,
            'max_receiving_gap': receiving_gap,
        }

    def get_utxos(self, **constraints):
        return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints)

    def get_utxo_count(self, **constraints):
        return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints)

    def get_transactions(self, **constraints):
        return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints)

    def get_transaction_count(self, **constraints):
        return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints)

    async def fund(self, to_account, amount=None, everything=False,
                   outputs=1, broadcast=False, **constraints):
        """ Move funds to another account on the same ledger.

        Either sweep all utxos (`everything=True`) or send a positive
        `amount` split across `outputs` equal txos. The transaction is
        broadcast only when requested; otherwise its reserved outputs are
        released again.
        """
        assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
        tx_class = self.ledger.transaction_class
        if everything:
            utxos = await self.get_utxos(**constraints)
            await self.ledger.reserve_outputs(utxos)
            tx = await tx_class.create(
                inputs=[tx_class.input_class.spend(txo) for txo in utxos],
                outputs=[],
                funding_accounts=[self],
                change_account=to_account
            )
        elif amount > 0:
            to_address = await to_account.change.get_or_create_usable_address()
            to_hash160 = to_account.ledger.address_to_hash160(to_address)
            tx = await tx_class.create(
                inputs=[],
                outputs=[
                    tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160)
                    for _ in range(outputs)
                ],
                funding_accounts=[self],
                change_account=self
            )
        else:
            raise ValueError('An amount is required.')

        if broadcast:
            await self.ledger.broadcast(tx)
        else:
            await self.ledger.release_tx(tx)

        return tx
|
@ -1,652 +0,0 @@
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
from binascii import hexlify
|
|
||||||
from concurrent.futures.thread import ThreadPoolExecutor
|
|
||||||
|
|
||||||
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
|
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
from lbry.wallet.client.basetransaction import BaseTransaction, TXRefImmutable
|
|
||||||
from lbry.wallet.client.bip32 import PubKey
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
sqlite3.enable_callback_tracebacks(True)
|
|
||||||
|
|
||||||
|
|
||||||
class AIOSQLite:
    """ Async facade over sqlite3 that funnels every operation through a
    single-threaded executor, since one sqlite connection must only be
    used from the thread that created it. """

    def __init__(self):
        # has to be single threaded as there is no mapping of thread:connection
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.connection: sqlite3.Connection = None
        self._closing = False
        # number of transactions run; useful for tests and metrics
        self.query_count = 0

    @classmethod
    async def connect(cls, path: Union[bytes, str], *args, **kwargs):
        """ Open the connection on the db thread and return a ready instance. """
        sqlite3.enable_callback_tracebacks(True)

        def _connect():
            return sqlite3.connect(path, *args, **kwargs)
        db = cls()
        db.connection = await asyncio.get_event_loop().run_in_executor(db.executor, _connect)
        return db

    async def close(self):
        """ Close the connection and shut down the executor; idempotent. """
        if self._closing:
            return
        self._closing = True
        await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close)
        self.executor.shutdown(wait=True)
        self.connection = None

    def executemany(self, sql: str, params: Iterable):
        params = params if params is not None else []
        # this fetchall is needed to prevent SQLITE_MISUSE
        return self.run(lambda conn: conn.executemany(sql, params).fetchall())

    def executescript(self, script: str) -> Awaitable:
        return self.run(lambda conn: conn.executescript(script))

    def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
        parameters = parameters if parameters is not None else []
        return self.run(lambda conn: conn.execute(sql, parameters).fetchall())

    def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
        parameters = parameters if parameters is not None else []
        return self.run(lambda conn: conn.execute(sql, parameters).fetchone())

    def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
        parameters = parameters if parameters is not None else []
        return self.run(lambda conn: conn.execute(sql, parameters))

    def run(self, fun, *args, **kwargs) -> Awaitable:
        """ Run `fun(connection, *args, **kwargs)` as one transaction on the db thread. """
        return asyncio.get_event_loop().run_in_executor(
            self.executor, lambda: self.__run_transaction(fun, *args, **kwargs)
        )

    def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
        # explicit begin/commit so `fun` executes atomically; any failure
        # rolls back and re-raises to the awaiting caller
        self.connection.execute('begin')
        try:
            self.query_count += 1
            result = fun(self.connection, *args, **kwargs)  # type: ignore
            self.connection.commit()
            return result
        except (Exception, OSError) as e:
            log.exception('Error running transaction:', exc_info=e)
            self.connection.rollback()
            log.warning("rolled back")
            raise

    def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
        """ Like `run`, but with `pragma foreign_keys=off` for the duration. """
        return asyncio.get_event_loop().run_in_executor(
            self.executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
        )

    def __run_transaction_with_foreign_keys_disabled(self,
                                                     fun: Callable[[sqlite3.Connection, Any, Any], Any],
                                                     args, kwargs):
        foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone()
        if not foreign_keys_enabled:
            # guard against nested use: the finally below would re-enable
            # foreign keys while an outer caller still expects them off
            raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
        try:
            self.connection.execute('pragma foreign_keys=off').fetchone()
            return self.__run_transaction(fun, *args, **kwargs)
        finally:
            self.connection.execute('pragma foreign_keys=on').fetchone()
|
|
||||||
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
    """ Compile a dict of django-style `column__op` constraints into a SQL
    fragment and its named-parameter values.

    Supported key suffixes: __not, __is_null, __is_not_null, __lt, __lte,
    __gt, __gte, __like, __not_like, __in/__not_in (list/set/tuple expands
    to placeholders; a string is inlined as a subquery), __any/__or
    (OR-joined sub-dict) and __and (AND-joined sub-dict).

    Special keys: 'col#tag' disambiguates the same column used twice;
    a key starting with '$' supplies a bare named parameter; an empty key
    (e.g. just '#tag') passes the constraint through as raw SQL.

    Returns (sql_fragment, values_dict).
    """
    sql, values = [], {}
    for key, constraint in constraints.items():
        tag = '0'
        if '#' in key:
            # 'column#tag' lets the same column appear more than once
            key, tag = key[:key.index('#')], key[key.index('#')+1:]
        # note: `col` keeps the op suffix until it is stripped below, while
        # `key` (dots normalized) keeps it and becomes the parameter name
        col, op, key = key, '=', key.replace('.', '_')
        if not key:
            # empty key: constraint is a raw SQL snippet
            sql.append(constraint)
            continue
        if key.startswith('$'):
            # '$name' only contributes a named parameter, no SQL
            values[key] = constraint
            continue
        if key.endswith('__not'):
            col, op = col[:-len('__not')], '!='
        elif key.endswith('__is_null'):
            col = col[:-len('__is_null')]
            sql.append(f'{col} IS NULL')
            continue
        if key.endswith('__is_not_null'):
            col = col[:-len('__is_not_null')]
            sql.append(f'{col} IS NOT NULL')
            continue
        if key.endswith('__lt'):
            col, op = col[:-len('__lt')], '<'
        elif key.endswith('__lte'):
            col, op = col[:-len('__lte')], '<='
        elif key.endswith('__gt'):
            col, op = col[:-len('__gt')], '>'
        elif key.endswith('__gte'):
            col, op = col[:-len('__gte')], '>='
        elif key.endswith('__like'):
            col, op = col[:-len('__like')], 'LIKE'
        elif key.endswith('__not_like'):
            col, op = col[:-len('__not_like')], 'NOT LIKE'
        elif key.endswith('__in') or key.endswith('__not_in'):
            if key.endswith('__in'):
                col, op = col[:-len('__in')], 'IN'
            else:
                col, op = col[:-len('__not_in')], 'NOT IN'
            if constraint:
                if isinstance(constraint, (list, set, tuple)):
                    # expand each element into its own named placeholder
                    keys = []
                    for i, val in enumerate(constraint):
                        keys.append(f':{key}{tag}_{i}')
                        values[f'{key}{tag}_{i}'] = val
                    sql.append(f'{col} {op} ({", ".join(keys)})')
                elif isinstance(constraint, str):
                    # a string is treated as a subquery and inlined verbatim
                    sql.append(f'{col} {op} ({constraint})')
                else:
                    raise ValueError(f"{col} requires a list, set or string as constraint value.")
            continue
        elif key.endswith('__any') or key.endswith('__or'):
            # recurse with OR joiner; prefix keeps sub-params unique
            where, subvalues = constraints_to_sql(constraint, ' OR ', key+tag+'_')
            sql.append(f'({where})')
            values.update(subvalues)
            continue
        if key.endswith('__and'):
            # recurse with AND joiner inside a parenthesized group
            where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_')
            sql.append(f'({where})')
            values.update(subvalues)
            continue
        sql.append(f'{col} {op} :{prepend_key}{key}{tag}')
        values[prepend_key+key+tag] = constraint
    return joiner.join(sql) if sql else '', values
||||||
|
|
||||||
def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
    """ Assemble a full SELECT statement from a base `select` string plus
    constraint kwargs; returns (sql, named_parameter_values).

    Recognized meta-constraints: `limit`, `offset`, `order_by` (str or
    list) and `accounts` (translated into an `account__in` constraint on
    each account's public key address).
    """
    parts = [select]

    limit = constraints.pop('limit', None)
    offset = constraints.pop('offset', None)
    order_by = constraints.pop('order_by', None)
    accounts = constraints.pop('accounts', [])

    if accounts:
        constraints['account__in'] = [account.public_key.address for account in accounts]

    where, values = constraints_to_sql(constraints)
    if where:
        parts.extend(('WHERE', where))

    if order_by:
        parts.append('ORDER BY')
        if isinstance(order_by, str):
            parts.append(order_by)
        elif isinstance(order_by, list):
            parts.append(', '.join(order_by))
        else:
            raise ValueError("order_by must be string or list")

    if limit is not None:
        parts.append(f'LIMIT {limit}')
    if offset is not None:
        parts.append(f'OFFSET {offset}')

    return ' '.join(parts), values
||||||
|
|
||||||
def interpolate(sql, values):
    """ Substitute named parameters directly into `sql`, for logging and
    debugging only.

    Keys are processed in reverse-sorted order so a longer name like ':ab'
    is replaced before its prefix ':a' can clobber it. bytes become sqlite
    blob literals (X'..'), strings are single-quoted (not escaped — do not
    execute the result), everything else is str()'d.
    """
    for name in sorted(values, reverse=True):
        val = values[name]
        if isinstance(val, bytes):
            literal = f"X'{hexlify(val).decode()}'"
        elif isinstance(val, str):
            literal = f"'{val}'"
        else:
            literal = str(val)
        sql = sql.replace(f":{name}", literal)
    return sql
||||||
|
|
||||||
def rows_to_dict(rows, fields):
    """ Zip each row tuple with `fields` into a dict.

    Falsy `rows` (empty list or None) yields an empty list.
    """
    return [dict(zip(fields, row)) for row in rows] if rows else []
||||||
|
|
||||||
class SQLiteMixin:
    """ Shared sqlite plumbing: schema-versioned open/close plus static
    INSERT/UPDATE statement builders. """

    # subclasses set these; open() rebuilds the db when versions mismatch
    SCHEMA_VERSION: Optional[str] = None
    CREATE_TABLES_QUERY: str
    # sqlite's practical bound-variable limit used when batching queries
    MAX_QUERY_VARIABLES = 900

    CREATE_VERSION_TABLE = """
        create table if not exists version (
            version text
        );
    """

    def __init__(self, path):
        self._db_path = path
        self.db: AIOSQLite = None
        self.ledger = None

    async def open(self):
        """ Connect and ensure the schema is at SCHEMA_VERSION, dropping and
        recreating every table when the stored version does not match. """
        log.info("connecting to database: %s", self._db_path)
        self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
        if self.SCHEMA_VERSION:
            tables = [t[0] for t in await self.db.execute_fetchall(
                "SELECT name FROM sqlite_master WHERE type='table';"
            )]
            if tables:
                if 'version' in tables:
                    version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;")
                    if version == (self.SCHEMA_VERSION,):
                        # schema already current; keep existing data
                        return
                # version table missing or stale version: wipe and rebuild
                await self.db.executescript('\n'.join(
                    f"DROP TABLE {table};" for table in tables
                ))
            await self.db.execute(self.CREATE_VERSION_TABLE)
            await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,))
        await self.db.executescript(self.CREATE_TABLES_QUERY)

    async def close(self):
        await self.db.close()

    @staticmethod
    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
                    replace: bool = False) -> Tuple[str, List]:
        """ Build an INSERT statement and its value list from a column dict.
        `replace` wins over `ignore_duplicate` when both are set. """
        columns, values = [], []
        for column, value in data.items():
            columns.append(column)
            values.append(value)
        policy = ""
        if ignore_duplicate:
            policy = " OR IGNORE"
        if replace:
            policy = " OR REPLACE"
        sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
            policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
        )
        return sql, values

    @staticmethod
    def _update_sql(table: str, data: dict, where: str,
                    constraints: Union[list, tuple]) -> Tuple[str, list]:
        """ Build an UPDATE statement; `constraints` supplies the parameters
        referenced by the caller-provided `where` clause. """
        columns, values = [], []
        for column, value in data.items():
            columns.append(f"{column} = ?")
            values.append(value)
        values.extend(constraints)
        sql = "UPDATE {} SET {} WHERE {}".format(
            table, ', '.join(columns), where
        )
        return sql, values
|
||||||
|
|
||||||
class BaseDatabase(SQLiteMixin):
    """Wallet database: account addresses, transactions and their inputs/outputs.

    Tables: account_address (keys per account/chain), pubkey_address (address
    usage history), tx (raw transactions), txo (outputs), txi (inputs).
    """

    SCHEMA_VERSION = "1.1"

    PRAGMAS = """
        pragma journal_mode=WAL;
    """

    CREATE_ACCOUNT_TABLE = """
        create table if not exists account_address (
            account text not null,
            address text not null,
            chain integer not null,
            pubkey blob not null,
            chain_code blob not null,
            n integer not null,
            depth integer not null,
            primary key (account, address)
        );
        create index if not exists address_account_idx on account_address (address, account);
    """

    CREATE_PUBKEY_ADDRESS_TABLE = """
        create table if not exists pubkey_address (
            address text primary key,
            history text,
            used_times integer not null default 0
        );
    """

    CREATE_TX_TABLE = """
        create table if not exists tx (
            txid text primary key,
            raw blob not null,
            height integer not null,
            position integer not null,
            is_verified boolean not null default 0
        );
    """

    CREATE_TXO_TABLE = """
        create table if not exists txo (
            txid text references tx,
            txoid text primary key,
            address text references pubkey_address,
            position integer not null,
            amount integer not null,
            script blob not null,
            is_reserved boolean not null default 0
        );
        create index if not exists txo_address_idx on txo (address);
    """

    CREATE_TXI_TABLE = """
        create table if not exists txi (
            txid text references tx,
            txoid text references txo,
            address text references pubkey_address
        );
        create index if not exists txi_address_idx on txi (address);
        create index if not exists txi_txoid_idx on txi (txoid);
    """

    CREATE_TABLES_QUERY = (
        PRAGMAS +
        CREATE_ACCOUNT_TABLE +
        CREATE_PUBKEY_ADDRESS_TABLE +
        CREATE_TX_TABLE +
        CREATE_TXO_TABLE +
        CREATE_TXI_TABLE
    )

    @staticmethod
    def txo_to_row(tx, address, txo):
        """Map a transaction output to a dict matching the txo table columns."""
        return {
            'txid': tx.id,
            'txoid': txo.id,
            'address': address,
            'position': txo.position,
            'amount': txo.amount,
            'script': sqlite3.Binary(txo.script.source)
        }

    @staticmethod
    def tx_to_row(tx):
        """Map a transaction to a dict matching the tx table columns."""
        return {
            'txid': tx.id,
            'raw': sqlite3.Binary(tx.raw),
            'height': tx.height,
            'position': tx.position,
            'is_verified': tx.is_verified
        }

    async def insert_transaction(self, tx):
        """Insert a new row into the tx table (fails on duplicate txid)."""
        await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx)))

    async def update_transaction(self, tx):
        """Update height/position/verification status of an existing transaction."""
        await self.db.execute_fetchall(*self._update_sql("tx", {
            'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
        }, 'txid = ?', (tx.id,)))

    def _transaction_io(self, conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
        """Synchronously persist a tx plus the txo/txi rows relevant to ``address``.

        Runs on the database thread (see save_transaction_io); ``history`` is the
        server-reported status string for the address.
        """
        # upsert the transaction itself
        conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True))

        for txo in tx.outputs:
            # only outputs paying directly to this address's pubkey hash are stored
            if txo.script.is_pay_pubkey_hash and txo.pubkey_hash == txhash:
                conn.execute(*self._insert_sql(
                    "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
                )).fetchall()
            elif txo.script.is_pay_script_hash:
                # TODO: implement script hash payments
                log.warning('Database.save_transaction_io: pay script hash is not implemented!')

        for txi in tx.inputs:
            # only inputs whose source txo is already known locally can be recorded
            if txi.txo_ref.txo is not None:
                txo = txi.txo_ref.txo
                if txo.has_address and txo.get_address(self.ledger) == address:
                    conn.execute(*self._insert_sql("txi", {
                        'txid': tx.id,
                        'txoid': txo.id,
                        'address': address,
                    }, ignore_duplicate=True)).fetchall()

        # history appears to be a colon-delimited txid:height sequence, so each
        # use contributes two ':' characters — TODO confirm against the caller
        conn.execute(
            "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
            (history, history.count(':') // 2, address)
        )

    def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
        """Persist one transaction's rows for ``address`` on the db thread."""
        return self.db.run(self._transaction_io, tx, address, txhash, history)

    def save_transaction_io_batch(self, txs: Iterable[BaseTransaction], address, txhash, history):
        """Persist several transactions for ``address`` in a single db-thread call."""
        def __many(conn):
            for tx in txs:
                self._transaction_io(conn, tx, address, txhash, history)
        return self.db.run(__many)

    async def reserve_outputs(self, txos, is_reserved=True):
        """Mark outputs as reserved (or unreserved) so they are excluded from UTXO selection."""
        txoids = ((is_reserved, txo.id) for txo in txos)
        await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids)

    async def release_outputs(self, txos):
        """Clear the reserved flag on the given outputs."""
        await self.reserve_outputs(txos, is_reserved=False)

    async def rewind_blockchain(self, above_height):  # pylint: disable=no-self-use
        # TODO:
        # 1. delete transactions above_height
        # 2. update address histories removing deleted TXs
        return True

    async def select_transactions(self, cols, accounts=None, **constraints):
        """Fetch rows from the tx table.

        Without an explicit txid constraint the query is scoped to transactions
        touching the given accounts' addresses (via txo/txi joins).
        """
        if not {'txid', 'txid__in'}.intersection(constraints):
            assert accounts, "'accounts' argument required when no 'txid' constraint is present"
            constraints.update({
                f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
            })
            account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
            where = f" WHERE account_address.account IN ({account_values})"
            # restrict to txids that appear as either an output to or an input
            # from one of the accounts' addresses
            constraints['txid__in'] = f"""
                SELECT txo.txid FROM txo JOIN account_address USING (address) {where}
                UNION
                SELECT txi.txid FROM txi JOIN account_address USING (address) {where}
            """
        return await self.db.execute_fetchall(
            *query(f"SELECT {cols} FROM tx", **constraints)
        )

    async def get_transactions(self, wallet=None, **constraints):
        """Load full transaction objects, annotating inputs/outputs with wallet info.

        Returns newest-first by default (unconfirmed height=0 first).
        """
        tx_rows = await self.select_transactions(
            'txid, raw, height, position, is_verified',
            order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]),
            **constraints
        )

        if not tx_rows:
            return []

        txids, txs, txi_txoids = [], [], []
        for row in tx_rows:
            txids.append(row[0])
            txs.append(self.ledger.transaction_class(
                raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4])
            ))
            for txi in txs[-1].inputs:
                txi_txoids.append(txi.txo_ref.id)

        # fetch annotation txos in batches to stay under SQLite's variable limit
        step = self.MAX_QUERY_VARIABLES
        annotated_txos = {}
        for offset in range(0, len(txids), step):
            annotated_txos.update({
                txo.id: txo for txo in
                (await self.get_txos(
                    wallet=wallet,
                    txid__in=txids[offset:offset+step],
                ))
            })

        referenced_txos = {}
        for offset in range(0, len(txi_txoids), step):
            referenced_txos.update({
                txo.id: txo for txo in
                (await self.get_txos(
                    wallet=wallet,
                    txoid__in=txi_txoids[offset:offset+step],
                ))
            })

        # attach resolved txo refs to inputs and copy annotations onto outputs
        for tx in txs:
            for txi in tx.inputs:
                txo = referenced_txos.get(txi.txo_ref.id)
                if txo:
                    txi.txo_ref = txo.ref
            for txo in tx.outputs:
                _txo = annotated_txos.get(txo.id)
                if _txo:
                    txo.update_annotations(_txo)
                else:
                    txo.update_annotations(None)

        return txs

    async def get_transaction_count(self, **constraints):
        """Count transactions matching ``constraints`` (pagination keys are ignored)."""
        constraints.pop('wallet', None)
        constraints.pop('offset', None)
        constraints.pop('limit', None)
        constraints.pop('order_by', None)
        count = await self.select_transactions('count(*)', **constraints)
        return count[0][0]

    async def get_transaction(self, **constraints):
        """Return the first matching transaction, or None (implicitly) if none match."""
        txs = await self.get_transactions(limit=1, **constraints)
        if txs:
            return txs[0]

    async def select_txos(self, cols, **constraints):
        """Fetch rows from txo joined with tx (and account_address when filtering by accounts)."""
        sql = f"SELECT {cols} FROM txo JOIN tx USING (txid)"
        if 'accounts' in constraints:
            sql += " JOIN account_address USING (address)"
        return await self.db.execute_fetchall(*query(sql, **constraints))

    async def get_txos(self, wallet=None, no_tx=False, **constraints):
        """Load transaction outputs, flagging is_my_account/is_change per wallet.

        With ``no_tx=True`` a lightweight output is built from the row alone
        instead of deserializing the whole raw transaction.
        """
        my_accounts = {a.public_key.address for a in wallet.accounts} if wallet else set()
        if 'order_by' not in constraints:
            constraints['order_by'] = [
                "tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"
            ]
        # final subselect yields "account|chain" pairs for every account owning
        # this txo's address, concatenated with commas
        rows = await self.select_txos(
            """
            tx.txid, raw, tx.height, tx.position, tx.is_verified, txo.position, amount, script, (
                select group_concat(account||"|"||chain) from account_address
                where account_address.address=txo.address
            )
            """,
            **constraints
        )
        txos = []
        txs = {}
        output_class = self.ledger.transaction_class.output_class
        for row in rows:
            if no_tx:
                txo = output_class(
                    amount=row[6],
                    script=output_class.script_class(row[7]),
                    tx_ref=TXRefImmutable.from_id(row[0], row[2]),
                    position=row[5]
                )
            else:
                # deserialize each raw transaction only once, then index outputs
                if row[0] not in txs:
                    txs[row[0]] = self.ledger.transaction_class(
                        row[1], height=row[2], position=row[3], is_verified=row[4]
                    )
                txo = txs[row[0]].outputs[row[5]]
            # map account -> chain; chain '1' is the change chain
            row_accounts = dict(a.split('|') for a in row[8].split(','))
            account_match = set(row_accounts) & my_accounts
            if account_match:
                txo.is_my_account = True
                txo.is_change = row_accounts[account_match.pop()] == '1'
            else:
                txo.is_change = txo.is_my_account = False
            txos.append(txo)
        return txos

    async def get_txo_count(self, **constraints):
        """Count outputs matching ``constraints`` (pagination keys are ignored)."""
        constraints.pop('wallet', None)
        constraints.pop('offset', None)
        constraints.pop('limit', None)
        constraints.pop('order_by', None)
        count = await self.select_txos('count(*)', **constraints)
        return count[0][0]

    @staticmethod
    def constrain_utxo(constraints):
        # a UTXO is an unreserved output that no known input spends
        constraints['is_reserved'] = False
        constraints['txoid__not_in'] = "SELECT txoid FROM txi"

    def get_utxos(self, **constraints):
        """Like get_txos but restricted to unspent, unreserved outputs."""
        self.constrain_utxo(constraints)
        return self.get_txos(**constraints)

    def get_utxo_count(self, **constraints):
        """Like get_txo_count but restricted to unspent, unreserved outputs."""
        self.constrain_utxo(constraints)
        return self.get_txo_count(**constraints)

    async def get_balance(self, wallet=None, accounts=None, **constraints):
        """Sum the amounts of spendable UTXOs for the given wallet or accounts."""
        assert wallet or accounts, \
            "'wallet' or 'accounts' constraints required to calculate balance"
        constraints['accounts'] = accounts or wallet.accounts
        self.constrain_utxo(constraints)
        balance = await self.select_txos('SUM(amount)', **constraints)
        return balance[0][0] or 0

    async def select_addresses(self, cols, **constraints):
        """Fetch rows from pubkey_address joined with account_address."""
        return await self.db.execute_fetchall(*query(
            f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)",
            **constraints
        ))

    async def get_addresses(self, cols=None, **constraints):
        """Return address rows as dicts; raw key columns are folded into a PubKey object."""
        cols = cols or (
            'address', 'account', 'chain', 'history', 'used_times',
            'pubkey', 'chain_code', 'n', 'depth'
        )
        addresses = rows_to_dict(await self.select_addresses(', '.join(cols), **constraints), cols)
        if 'pubkey' in cols:
            for address in addresses:
                # replace the four raw key columns with a single PubKey instance
                address['pubkey'] = PubKey(
                    self.ledger, address.pop('pubkey'), address.pop('chain_code'),
                    address.pop('n'), address.pop('depth')
                )
        return addresses

    async def get_address_count(self, cols=None, **constraints):
        # NOTE(review): ``cols`` is accepted but unused here — kept for signature
        # symmetry with get_addresses
        count = await self.select_addresses('count(*)', **constraints)
        return count[0][0]

    async def get_address(self, **constraints):
        """Return the first matching address row, or None (implicitly) if none match."""
        addresses = await self.get_addresses(limit=1, **constraints)
        if addresses:
            return addresses[0]

    async def add_keys(self, account, chain, pubkeys):
        """Insert derived public keys for an account chain, ignoring duplicates."""
        await self.db.executemany(
            "insert or ignore into account_address "
            "(account, address, chain, pubkey, chain_code, n, depth) values "
            "(?, ?, ?, ?, ?, ?, ?)", ((
                account.id, k.address, chain,
                sqlite3.Binary(k.pubkey_bytes),
                sqlite3.Binary(k.chain_code),
                k.n, k.depth
            ) for k in pubkeys)
        )
        # ensure every address also has a pubkey_address row for history tracking
        await self.db.executemany(
            "insert or ignore into pubkey_address (address) values (?)",
            ((pubkey.address,) for pubkey in pubkeys)
        )

    async def _set_address_history(self, address, history):
        # same used_times heuristic as _transaction_io: two ':' per history entry
        await self.db.execute_fetchall(
            "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
            (history, history.count(':')//2, address)
        )

    async def set_address_history(self, address, history):
        """Public wrapper for updating an address's status history string."""
        await self._set_address_history(address, history)
|
|
|
@ -1,246 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from io import BytesIO
|
|
||||||
from typing import Optional, Iterator, Tuple
|
|
||||||
from binascii import hexlify
|
|
||||||
|
|
||||||
from lbry.wallet.client.util import ArithUint256
|
|
||||||
from lbry.crypto.hash import double_sha256
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidHeader(Exception):
    """Raised when a block header fails validation.

    Carries the chain ``height`` at which validation failed alongside the
    human-readable ``message``.
    """

    def __init__(self, height, message):
        super().__init__(message)
        # keep both pieces of context as attributes so callers can recover
        # where in the chain the failure occurred
        self.height = height
        self.message = message
|
|
||||||
|
|
||||||
|
|
||||||
class BaseHeaders:
    """Flat-file store of fixed-size serialized block headers, indexed by height.

    Subclasses define the header format (serialize/deserialize, header_size)
    and consensus parameters; this base handles file I/O, chunked download,
    proof-of-work validation and checkpoint-accelerated syncing.
    """

    header_size: int  # bytes per serialized header
    chunk_size: int   # headers per download/validation chunk

    max_target: int
    genesis_hash: Optional[bytes]
    target_timespan: int

    validate_difficulty: bool = True
    # optional (height, sha256-hex-digest) pair; headers up to this height can
    # be accepted by hashing the whole file instead of per-header validation
    checkpoint = None

    def __init__(self, path) -> None:
        # in-memory stores get a BytesIO immediately; file-backed stores open lazily
        if path == ':memory:':
            self.io = BytesIO()
        self.path = path
        self._size: Optional[int] = None  # cached header count; None = unknown

    async def open(self):
        if self.path != ':memory:':
            # create the file on first use, otherwise open read/write in place
            if not os.path.exists(self.path):
                self.io = open(self.path, 'w+b')
            else:
                self.io = open(self.path, 'r+b')

    async def close(self):
        self.io.close()

    @staticmethod
    def serialize(header: dict) -> bytes:
        """Encode a header dict into its fixed-size binary form (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def deserialize(height, header):
        """Decode fixed-size binary header bytes into a dict (subclass hook)."""
        raise NotImplementedError

    def get_next_chunk_target(self, chunk: int) -> ArithUint256:
        # default: constant difficulty at max_target; subclasses override for
        # real retargeting
        return ArithUint256(self.max_target)

    @staticmethod
    def get_next_block_target(chunk_target: ArithUint256, previous: Optional[dict],
                              current: Optional[dict]) -> ArithUint256:
        # default: per-block target equals the chunk target; subclass hook
        return chunk_target

    def __len__(self) -> int:
        # number of headers currently stored; derived from file size and cached
        if self._size is None:
            self._size = self.io.seek(0, os.SEEK_END) // self.header_size
        return self._size

    def __bool__(self):
        # always truthy, even when empty — avoids len() being used for truth tests
        return True

    def __getitem__(self, height) -> dict:
        if isinstance(height, slice):
            raise NotImplementedError("Slicing of header chain has not been implemented yet.")
        if not 0 <= height <= self.height:
            raise IndexError(f"{height} is out of bounds, current height: {self.height}")
        return self.deserialize(height, self.get_raw_header(height))

    def get_raw_header(self, height) -> bytes:
        """Read the raw serialized header at ``height`` from the backing file."""
        self.io.seek(height * self.header_size, os.SEEK_SET)
        return self.io.read(self.header_size)

    @property
    def height(self) -> int:
        # height of the tip; -1 when the store is empty
        return len(self)-1

    @property
    def bytes_size(self):
        return len(self) * self.header_size

    def hash(self, height=None) -> bytes:
        """Hex block hash at ``height`` (tip when omitted)."""
        return self.hash_header(
            self.get_raw_header(height if height is not None else self.height)
        )

    @staticmethod
    def hash_header(header: bytes) -> bytes:
        """Hex-encoded, byte-reversed double-sha256 of a raw header."""
        if header is None:
            return b'0' * 64
        return hexlify(double_sha256(header)[::-1])

    @asynccontextmanager
    async def checkpointed_connector(self):
        """Collect downloaded header bytes in a buffer, then connect them.

        On exit, bytes below the checkpoint height are verified in bulk by
        hashing the existing file plus the new bytes against the checkpoint
        digest; anything remaining (or everything, on mismatch) goes through
        the per-header connect() path.
        """
        buf = BytesIO()
        try:
            yield buf
        finally:
            await asyncio.sleep(0)
            final_height = len(self) + buf.tell() // self.header_size
            verifiable_bytes = (self.checkpoint[0] - len(self)) * self.header_size if self.checkpoint else 0
            if verifiable_bytes > 0 and final_height >= self.checkpoint[0]:
                buf.seek(0)
                self.io.seek(0)
                h = hashlib.sha256()
                h.update(self.io.read())
                h.update(buf.read(verifiable_bytes))
                if h.hexdigest().encode() == self.checkpoint[1]:
                    # checkpoint matched: write verified bytes directly, then
                    # compact the buffer to just the unverified remainder
                    buf.seek(0)
                    self._write(len(self), buf.read(verifiable_bytes))
                    remaining = buf.read()
                    buf.seek(0)
                    buf.write(remaining)
                    buf.truncate()
                else:
                    log.warning("Checkpoint mismatch, connecting headers through slow method.")
            if buf.tell() > 0:
                await self.connect(len(self), buf.getvalue())

    async def connect(self, start: int, headers: bytes) -> int:
        """Validate and append ``headers`` starting at ``start``; returns count added.

        On an invalid header, everything before the failure is still written
        and connection stops there.
        """
        added = 0
        bail = False
        for height, chunk in self._iterate_chunks(start, headers):
            try:
                # validate_chunk() is CPU bound and reads previous chunks from file system
                self.validate_chunk(height, chunk)
            except InvalidHeader as e:
                bail = True
                # keep only the headers preceding the invalid one
                chunk = chunk[:(height-e.height)*self.header_size]
            added += self._write(height, chunk) if chunk else 0
            if bail:
                break
        return added

    def _write(self, height, verified_chunk):
        """Write already-validated header bytes at ``height`` and truncate after them."""
        self.io.seek(height * self.header_size, os.SEEK_SET)
        written = self.io.write(verified_chunk) // self.header_size
        self.io.truncate()
        # .seek()/.write()/.truncate() might also .flush() when needed
        # the goal here is mainly to ensure we're definitely flush()'ing
        self.io.flush()
        self._size = self.io.tell() // self.header_size
        return written

    def validate_chunk(self, height, chunk):
        """Validate every header in ``chunk`` against its predecessor and difficulty target."""
        previous_hash, previous_header, previous_previous_header = None, None, None
        if height > 0:
            previous_header = self[height-1]
            previous_hash = self.hash(height-1)
        if height > 1:
            previous_previous_header = self[height-2]
        # 2016 is the retarget interval used for chunk target lookup
        chunk_target = self.get_next_chunk_target(height // 2016 - 1)
        for current_hash, current_header in self._iterate_headers(height, chunk):
            block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header)
            self.validate_header(height, current_hash, current_header, previous_hash, block_target)
            previous_previous_header = previous_header
            previous_header = current_header
            previous_hash = current_hash

    def validate_header(self, height: int, current_hash: bytes,
                        header: dict, previous_hash: bytes, target: ArithUint256):
        """Check genesis/previous-hash linkage, difficulty bits and proof of work.

        Raises InvalidHeader on any failure.
        """
        if previous_hash is None:
            # first header in the chain must match the known genesis hash, if any
            if self.genesis_hash is not None and self.genesis_hash != current_hash:
                raise InvalidHeader(
                    height, f"genesis header doesn't match: {current_hash.decode()} "
                            f"vs expected {self.genesis_hash.decode()}")
            return

        if header['prev_block_hash'] != previous_hash:
            raise InvalidHeader(
                height, "previous hash mismatch: {} vs expected {}".format(
                    header['prev_block_hash'].decode(), previous_hash.decode())
            )

        if self.validate_difficulty:

            if header['bits'] != target.compact:
                raise InvalidHeader(
                    height, "bits mismatch: {} vs expected {}".format(
                        header['bits'], target.compact)
                )

            # proof of work is valid when the hash, as a number, is <= target
            proof_of_work = self.get_proof_of_work(current_hash)
            if proof_of_work > target:
                raise InvalidHeader(
                    height, f"insufficient proof of work: {proof_of_work.value} vs target {target.value}"
                )

    async def repair(self):
        """Scan the file for a break in prev-hash linkage and truncate at the corruption."""
        previous_header_hash = fail = None
        batch_size = 36
        for start_height in range(0, self.height, batch_size):
            self.io.seek(self.header_size * start_height)
            headers = self.io.read(self.header_size*batch_size)
            if len(headers) % self.header_size != 0:
                # drop a trailing partial header before iterating
                headers = headers[:(len(headers) // self.header_size) * self.header_size]
            for header_hash, header in self._iterate_headers(start_height, headers):
                height = header['block_height']
                if height:
                    if header['prev_block_hash'] != previous_header_hash:
                        fail = True
                else:
                    if header_hash != self.genesis_hash:
                        fail = True
                if fail:
                    log.warning("Header file corrupted at height %s, truncating it.", height - 1)
                    self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET)
                    self.io.truncate()
                    self.io.flush()
                    self._size = None  # invalidate cached length
                    return
                previous_header_hash = header_hash

    @staticmethod
    def get_proof_of_work(header_hash: bytes) -> ArithUint256:
        # header_hash is hex bytes; parse as a base-16 integer
        return ArithUint256(int(b'0x' + header_hash, 16))

    def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
        """Yield (height, bytes) chunks aligned to chunk_size boundaries.

        The first chunk may be short so subsequent chunks start on a boundary.
        """
        assert len(headers) % self.header_size == 0, f"{len(headers)} {len(headers)%self.header_size}"
        start = 0
        end = (self.chunk_size - height % self.chunk_size) * self.header_size
        while start < end:
            yield height + (start // self.header_size), headers[start:end]
            start = end
            end = min(len(headers), end + self.chunk_size * self.header_size)

    def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]:
        """Yield (hash, deserialized-header) for each fixed-size header in ``headers``."""
        assert len(headers) % self.header_size == 0, len(headers)
        for idx in range(len(headers) // self.header_size):
            start, end = idx * self.header_size, (idx + 1) * self.header_size
            header = headers[start:end]
            yield self.hash_header(header), self.deserialize(height+idx, header)
|
|
|
@ -1,605 +0,0 @@
|
||||||
import base64
|
|
||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import zlib
|
|
||||||
from functools import partial
|
|
||||||
from binascii import hexlify, unhexlify
|
|
||||||
from io import StringIO
|
|
||||||
|
|
||||||
from typing import Dict, Type, Iterable, List, Optional
|
|
||||||
from operator import itemgetter
|
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
import pylru
|
|
||||||
from lbry.wallet.client.basetransaction import BaseTransaction
|
|
||||||
from lbry.wallet.tasks import TaskGroup
|
|
||||||
from lbry.wallet.client import baseaccount, basenetwork, basetransaction
|
|
||||||
from lbry.wallet.client.basedatabase import BaseDatabase
|
|
||||||
from lbry.wallet.client.baseheader import BaseHeaders
|
|
||||||
from lbry.wallet.client.coinselection import CoinSelector
|
|
||||||
from lbry.wallet.client.constants import COIN, NULL_HASH32
|
|
||||||
from lbry.wallet.stream import StreamController
|
|
||||||
from lbry.crypto.hash import hash160, double_sha256, sha256
|
|
||||||
from lbry.crypto.base58 import Base58
|
|
||||||
from lbry.wallet.client.bip32 import PubKey, PrivateKey
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Forward-referenced alias: any concrete ledger class (subclass of BaseLedger).
LedgerType = Type['BaseLedger']
|
|
||||||
|
|
||||||
|
|
||||||
class LedgerRegistry(type):
    """Metaclass that records every concrete ledger class under its ledger id."""

    ledgers: Dict[str, LedgerType] = {}

    def __new__(mcs, name, bases, attrs):
        cls: LedgerType = super().__new__(mcs, name, bases, attrs)
        # the abstract root — BaseLedger itself, created with no bases — is
        # deliberately kept out of the registry
        is_abstract_root = name == 'BaseLedger' and not bases
        if not is_abstract_root:
            ledger_id = cls.get_id()
            assert ledger_id not in mcs.ledgers, \
                f'Ledger with id "{ledger_id}" already registered.'
            mcs.ledgers[ledger_id] = cls
        return cls

    @classmethod
    def get_ledger_class(mcs, ledger_id: str) -> LedgerType:
        """Look up a previously registered ledger class; raises KeyError if unknown."""
        return mcs.ledgers[ledger_id]
|
|
||||||
|
|
||||||
|
|
||||||
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))):
    """Event payload pairing an address with the transaction that touched it."""
|
|
||||||
|
|
||||||
|
|
||||||
class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))):
    """Event payload describing newly generated addresses and their manager."""
|
|
||||||
|
|
||||||
|
|
||||||
class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
    """Event payload carrying the new chain height and the number of headers added."""
|
|
||||||
|
|
||||||
|
|
||||||
class TransactionCacheItem:
    """Cache slot for a transaction: the tx itself, a fetch lock, and a
    'tx is available' event for waiters."""

    __slots__ = '_tx', 'lock', 'has_tx'

    def __init__(self,
                 tx: Optional[basetransaction.BaseTransaction] = None,
                 lock: Optional[asyncio.Lock] = None):
        self.has_tx = asyncio.Event()
        self.lock = lock or asyncio.Lock()
        # chained assignment routes through the property setter, which also
        # sets has_tx when tx is not None
        self._tx = self.tx = tx

    @property
    def tx(self) -> Optional[basetransaction.BaseTransaction]:
        return self._tx

    @tx.setter
    def tx(self, tx: basetransaction.BaseTransaction):
        self._tx = tx
        if tx is not None:
            # wake any coroutines blocked on has_tx.wait()
            self.has_tx.set()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLedger(metaclass=LedgerRegistry):
|
|
||||||
|
|
||||||
name: str
|
|
||||||
symbol: str
|
|
||||||
network_name: str
|
|
||||||
|
|
||||||
database_class = BaseDatabase
|
|
||||||
account_class = baseaccount.BaseAccount
|
|
||||||
network_class = basenetwork.BaseNetwork
|
|
||||||
transaction_class = basetransaction.BaseTransaction
|
|
||||||
|
|
||||||
headers_class: Type[BaseHeaders]
|
|
||||||
|
|
||||||
pubkey_address_prefix: bytes
|
|
||||||
script_address_prefix: bytes
|
|
||||||
extended_public_key_prefix: bytes
|
|
||||||
extended_private_key_prefix: bytes
|
|
||||||
|
|
||||||
default_fee_per_byte = 10
|
|
||||||
|
|
||||||
def __init__(self, config=None):
|
|
||||||
self.config = config or {}
|
|
||||||
self.db: BaseDatabase = self.config.get('db') or self.database_class(
|
|
||||||
os.path.join(self.path, "blockchain.db")
|
|
||||||
)
|
|
||||||
self.db.ledger = self
|
|
||||||
self.headers: BaseHeaders = self.config.get('headers') or self.headers_class(
|
|
||||||
os.path.join(self.path, "headers")
|
|
||||||
)
|
|
||||||
self.network = self.config.get('network') or self.network_class(self)
|
|
||||||
self.network.on_header.listen(self.receive_header)
|
|
||||||
self.network.on_status.listen(self.process_status_update)
|
|
||||||
self.network.on_connected.listen(self.join_network)
|
|
||||||
|
|
||||||
self.accounts = []
|
|
||||||
self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
|
|
||||||
|
|
||||||
self._on_transaction_controller = StreamController()
|
|
||||||
self.on_transaction = self._on_transaction_controller.stream
|
|
||||||
self.on_transaction.listen(
|
|
||||||
lambda e: log.info(
|
|
||||||
'(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
|
|
||||||
self.get_id(), e.address, e.tx.height, e.tx.is_verified, e.tx.id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._on_address_controller = StreamController()
|
|
||||||
self.on_address = self._on_address_controller.stream
|
|
||||||
self.on_address.listen(
|
|
||||||
lambda e: log.info('(%s) on_address: %s', self.get_id(), e.addresses)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._on_header_controller = StreamController()
|
|
||||||
self.on_header = self._on_header_controller.stream
|
|
||||||
self.on_header.listen(
|
|
||||||
lambda change: log.info(
|
|
||||||
'%s: added %s header blocks, final height %s',
|
|
||||||
self.get_id(), change, self.headers.height
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self._download_height = 0
|
|
||||||
|
|
||||||
self._on_ready_controller = StreamController()
|
|
||||||
self.on_ready = self._on_ready_controller.stream
|
|
||||||
|
|
||||||
self._tx_cache = pylru.lrucache(100000)
|
|
||||||
self._update_tasks = TaskGroup()
|
|
||||||
self._utxo_reservation_lock = asyncio.Lock()
|
|
||||||
self._header_processing_lock = asyncio.Lock()
|
|
||||||
self._address_update_locks: Dict[str, asyncio.Lock] = {}
|
|
||||||
|
|
||||||
self.coin_selection_strategy = None
|
|
||||||
self._known_addresses_out_of_sync = set()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_id(cls):
|
|
||||||
return '{}_{}'.format(cls.symbol.lower(), cls.network_name.lower())
|
|
||||||
|
|
||||||
@classmethod
def hash160_to_address(cls, h160):
    """Encode a hash160 payload as a Base58Check address for this ledger."""
    payload = cls.pubkey_address_prefix + h160
    checksum = double_sha256(payload)[0:4]
    return Base58.encode(bytearray(payload + checksum))
|
|
||||||
|
|
||||||
@staticmethod
def address_to_hash160(address):
    """Decode an address and return the 20-byte hash160, dropping version byte and checksum."""
    decoded = Base58.decode(address)
    return decoded[1:21]
|
|
||||||
|
|
||||||
@classmethod
def is_valid_address(cls, address):
    """True when *address* checksum-decodes and carries this ledger's version byte."""
    version_byte = Base58.decode_check(address)[0]
    return version_byte == cls.pubkey_address_prefix[0]
|
|
||||||
|
|
||||||
@classmethod
def public_key_to_address(cls, public_key):
    """Derive the Base58Check address for a raw public key."""
    h160 = hash160(public_key)
    return cls.hash160_to_address(h160)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def private_key_to_wif(private_key):
|
|
||||||
return b'\x1c' + private_key + b'\x01'
|
|
||||||
|
|
||||||
@property
def path(self):
    # Ledger-specific data directory: <data_path>/<ledger id>.
    return os.path.join(self.config['data_path'], self.get_id())
|
|
||||||
|
|
||||||
def add_account(self, account: baseaccount.BaseAccount):
    # Register an account with this ledger so it participates in syncing and queries.
    self.accounts.append(account)
|
|
||||||
|
|
||||||
async def _get_account_and_address_info_for_address(self, wallet, address):
    """Look up *address* in the db and pair it with the owning account.

    Returns ``(account, address_row)`` or ``None`` when the address is
    unknown or no account in *wallet* owns it.
    """
    row = await self.db.get_address(accounts=wallet.accounts, address=address)
    if not row:
        return None
    for candidate in wallet.accounts:
        if candidate.public_key.address == row['account']:
            return candidate, row
|
|
||||||
|
|
||||||
async def get_private_key_for_address(self, wallet, address) -> Optional[PrivateKey]:
    """Return the private key controlling *address*, or None if not ours."""
    found = await self._get_account_and_address_info_for_address(wallet, address)
    if not found:
        return None
    account, info = found
    return account.get_private_key(info['chain'], info['pubkey'].n)
|
|
||||||
|
|
||||||
async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]:
    """Return the public key for *address*, or None if not ours."""
    found = await self._get_account_and_address_info_for_address(wallet, address)
    if not found:
        return None
    _, info = found
    return info['pubkey']
|
|
||||||
|
|
||||||
async def get_account_for_address(self, wallet, address):
    """Return the account owning *address*, or None."""
    found = await self._get_account_and_address_info_for_address(wallet, address)
    return found[0] if found else None
|
|
||||||
|
|
||||||
async def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
    """Collect a spend estimator for every UTXO across *funding_accounts*."""
    estimators = []
    for account in funding_accounts:
        utxos = await account.get_utxos()
        estimators.extend(utxo.get_estimator(self) for utxo in utxos)
    return estimators
|
|
||||||
|
|
||||||
async def get_addresses(self, **constraints):
    # Thin pass-through to the database layer.
    return await self.db.get_addresses(**constraints)
|
|
||||||
|
|
||||||
def get_address_count(self, **constraints):
    # Returns the db coroutine; caller awaits it.
    return self.db.get_address_count(**constraints)
|
|
||||||
|
|
||||||
async def get_spendable_utxos(self, amount: int, funding_accounts):
    """Select and reserve UTXOs covering *amount* (plus fee) from *funding_accounts*.

    Serialized by the reservation lock so concurrent spends cannot pick
    the same outputs. Returns the selected estimators (possibly empty).
    """
    async with self._utxo_reservation_lock:
        txos = await self.get_effective_amount_estimators(funding_accounts)
        # fee of one standard pay-to-pubkey-hash output, used as selection padding
        fee = self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
        selector = CoinSelector(amount, fee)
        spendables = selector.select(txos, self.coin_selection_strategy)
        if spendables:
            # mark the chosen outputs as reserved before releasing the lock
            await self.reserve_outputs(s.txo for s in spendables)
        return spendables
|
|
||||||
|
|
||||||
def reserve_outputs(self, txos):
    # Mark outputs reserved in the db so they aren't double-spent.
    return self.db.reserve_outputs(txos)
|
|
||||||
|
|
||||||
def release_outputs(self, txos):
    # Undo a prior reservation.
    return self.db.release_outputs(txos)
|
|
||||||
|
|
||||||
def release_tx(self, tx):
    # Release every output that *tx* spends (e.g. after a failed broadcast).
    return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])
|
|
||||||
|
|
||||||
def get_utxos(self, **constraints):
    # Thin pass-through to the database layer.
    return self.db.get_utxos(**constraints)
|
|
||||||
|
|
||||||
def get_utxo_count(self, **constraints):
    # Thin pass-through to the database layer.
    return self.db.get_utxo_count(**constraints)
|
|
||||||
|
|
||||||
def get_transactions(self, **constraints):
    # Thin pass-through to the database layer.
    return self.db.get_transactions(**constraints)
|
|
||||||
|
|
||||||
def get_transaction_count(self, **constraints):
    # Thin pass-through to the database layer.
    return self.db.get_transaction_count(**constraints)
|
|
||||||
|
|
||||||
async def get_local_status_and_history(self, address, history=None):
    """Return ``(status, history)`` for *address* from local state.

    *status* is the hex sha256 of the raw history string, or None when empty —
    this is the value compared against the server's address status.
    *history* is a list of ``(txid, height)`` tuples parsed from the
    ``txid:height:``-repeated record format.
    """
    if not history:
        address_details = await self.db.get_address(address=address)
        history = address_details['history'] or ''
    # records are 'txid:height:...' — trailing ':' yields an empty last split element
    parts = history.split(':')[:-1]
    return (
        hexlify(sha256(history.encode())).decode() if history else None,
        list(zip(parts[0::2], map(int, parts[1::2])))
    )
|
|
||||||
|
|
||||||
@staticmethod
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
    """Fold a merkle proof up to the root.

    Bit i of *branch_positions* says whether sibling i sits on the left.
    Sibling hashes arrive hex-encoded in display (reversed) byte order;
    the result is returned the same way.
    """
    for i, branch in enumerate(branches):
        other_branch = unhexlify(branch)[::-1]
        other_branch_on_left = bool((branch_positions >> i) & 1)
        if other_branch_on_left:
            combined = other_branch + working_branch
        else:
            combined = working_branch + other_branch
        working_branch = double_sha256(combined)
    return hexlify(working_branch[::-1])
|
|
||||||
|
|
||||||
async def start(self):
    """Open storage, connect to the network and perform the initial header sync.

    Returns once the first full sync round has completed (on_ready fires).
    """
    # makedirs with exist_ok avoids the exists()/mkdir() race and also
    # creates missing parent directories, unlike the old os.mkdir call
    os.makedirs(self.path, exist_ok=True)
    await asyncio.wait([
        self.db.open(),
        self.headers.open()
    ])
    first_connection = self.network.on_connected.first
    asyncio.ensure_future(self.network.start())
    await first_connection
    # hold the header lock so subscription updates can't interleave with
    # the bulk initial download
    async with self._header_processing_lock:
        await self._update_tasks.add(self.initial_headers_sync())
    await self._on_ready_controller.stream.first
|
|
||||||
|
|
||||||
async def join_network(self, *_):
    """Resync headers and re-subscribe all accounts after (re)connecting."""
    log.info("Subscribing and updating accounts.")
    async with self._header_processing_lock:
        await self.update_headers()
    await self.subscribe_accounts()
    # wait for all spawned address-update tasks before signalling readiness
    await self._update_tasks.done.wait()
    self._on_ready_controller.add(True)
|
|
||||||
|
|
||||||
async def stop(self):
    """Cancel pending sync work, then shut down network, db and headers."""
    self._update_tasks.cancel()
    await self._update_tasks.done.wait()
    await self.network.stop()
    await self.db.close()
    await self.headers.close()
|
|
||||||
|
|
||||||
@property
def local_height_including_downloaded_height(self):
    # Connected header height, or the in-progress bulk-download height if larger.
    return max(self.headers.height, self._download_height)
|
|
||||||
|
|
||||||
async def initial_headers_sync(self):
    """Bulk-download headers in parallel 4096-header chunks up to the remote tip.

    Chunks are fetched concurrently but written to the checkpointed buffer in
    order, so the header file stays contiguous.
    """
    target = self.network.remote_height + 1
    current = len(self.headers)
    get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True)
    chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)]
    total = 0
    async with self.headers.checkpointed_connector() as buffer:
        for chunk in chunks:
            headers = await chunk
            # server sends zlib-deflated base64; wbits=-15 means raw deflate stream
            total += buffer.write(
                zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
            )
            self._download_height = current + total // self.headers.header_size
            log.info("Headers sync: %s / %s", self._download_height, target)
|
|
||||||
|
|
||||||
async def update_headers(self, height=None, headers=None, subscription_update=False):
    """Connect new headers to the local chain, handling reorganizations.

    Loops fetching/connecting batches until caught up. When a batch fails to
    connect (a reorg), rewinds one header at a time, up to 100, before giving up.
    *subscription_update* short-circuits the loop for single-header pushes.
    """
    rewound = 0
    while True:

        if height is None or height > len(self.headers):
            # sometimes header subscription updates are for a header in the future
            # which can't be connected, so we do a normal header sync instead
            height = len(self.headers)
            headers = None
            subscription_update = False

        if not headers:
            header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
            headers = header_response['hex']

        if not headers:
            # Nothing to do, network thinks we're already at the latest height.
            return

        added = await self.headers.connect(height, unhexlify(headers))
        if added > 0:
            height += added
            self._on_header_controller.add(
                BlockHeightEvent(self.headers.height, added))

            if rewound > 0:
                # we started rewinding blocks and apparently found
                # a new chain
                rewound = 0
                await self.db.rewind_blockchain(height)

            if subscription_update:
                # subscription updates are for latest header already
                # so we don't need to check if there are newer / more
                # on another loop of update_headers(), just return instead
                return

        elif added == 0:
            # we had headers to connect but none got connected, probably a reorganization
            height -= 1
            rewound += 1
            log.warning(
                "Blockchain Reorganization: attempting rewind to height %s from starting height %s",
                height, height+rewound
            )

        else:
            raise IndexError(f"headers.connect() returned negative number ({added})")

        if height < 0:
            raise IndexError(
                "Blockchain reorganization rewound all the way back to genesis hash. "
                "Something is very wrong. Maybe you are on the wrong blockchain?"
            )

        if rewound >= 100:
            raise IndexError(
                "Blockchain reorganization dropped {} headers. This is highly unusual. "
                "Will not continue to attempt reorganizing. Please, delete the ledger "
                "synchronization directory inside your wallet directory (folder: '{}') and "
                "restart the program to synchronize from scratch."
                .format(rewound, self.get_id())
            )

        headers = None  # ready to download some more headers

        # if we made it this far and this was a subscription_update
        # it means something went wrong and now we're doing a more
        # robust sync, turn off subscription update shortcut
        subscription_update = False
|
|
||||||
|
|
||||||
async def receive_header(self, response):
    """Handle a pushed header notification from the server (serialized by lock)."""
    async with self._header_processing_lock:
        header = response[0]
        await self.update_headers(
            height=header['height'], headers=header['hex'], subscription_update=True
        )
|
|
||||||
|
|
||||||
async def subscribe_accounts(self):
    """Subscribe all registered accounts, in parallel, when connected."""
    if self.network.is_connected and self.accounts:
        await asyncio.wait([
            self.subscribe_account(a) for a in self.accounts
        ])
|
|
||||||
|
|
||||||
async def subscribe_account(self, account: baseaccount.BaseAccount):
    """Subscribe every known address of *account*, then top up the address gap."""
    for address_manager in account.address_managers.values():
        await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
    await account.ensure_address_gap()
|
|
||||||
|
|
||||||
async def unsubscribe_account(self, account: baseaccount.BaseAccount):
    """Drop server-side status subscriptions for all of *account*'s addresses."""
    for address in await account.get_addresses():
        await self.network.unsubscribe_address(address)
|
|
||||||
|
|
||||||
async def announce_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
    """Subscribe newly generated addresses and emit an on_address event for them."""
    await self.subscribe_addresses(address_manager, addresses)
    await self._on_address_controller.add(
        AddressesGeneratedEvent(address_manager, addresses)
    )
|
|
||||||
|
|
||||||
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
    """Subscribe a batch of addresses concurrently, when connected."""
    if self.network.is_connected and addresses:
        await asyncio.wait([
            self.subscribe_address(address_manager, address) for address in addresses
        ])
|
|
||||||
|
|
||||||
async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
    """Subscribe one address and schedule a history sync from its initial status."""
    remote_status = await self.network.subscribe_address(address)
    self._update_tasks.add(self.update_history(address, remote_status, address_manager))
|
|
||||||
|
|
||||||
def process_status_update(self, update):
    # Server pushed a new (address, status) pair; schedule a history sync for it.
    address, remote_status = update
    self._update_tasks.add(self.update_history(address, remote_status))
|
|
||||||
|
|
||||||
async def update_history(self, address, remote_status,
                         address_manager: baseaccount.AddressManager = None):
    """Reconcile local transaction history for *address* against *remote_status*.

    Fetches the server history, downloads missing transactions, resolves each
    input's txo reference (from in-flight cache or db), persists the batch, and
    re-checks the status. Returns True when in sync, False when still out of
    sync after the attempt (address is then tracked in the out-of-sync set).
    """
    # per-address lock serializes concurrent updates for the same address
    # NOTE(review): setdefault constructs a throwaway Lock even when one exists
    async with self._address_update_locks.setdefault(address, asyncio.Lock()):
        self._known_addresses_out_of_sync.discard(address)

        local_status, local_history = await self.get_local_status_and_history(address)

        if local_status == remote_status:
            return True

        remote_history = await self.network.retriable_call(self.network.get_history, address)
        remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
        we_need = set(remote_history) - set(local_history)
        if not we_need:
            return True

        cache_tasks: List[asyncio.Future[BaseTransaction]] = []
        synced_history = StringIO()
        for i, (txid, remote_height) in enumerate(remote_history):
            # already-matching prefix entries are copied through verbatim;
            # once any fetch is scheduled, everything after is re-fetched/verified
            if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
                synced_history.write(f'{txid}:{remote_height}:')
            else:
                check_local = (txid, remote_height) not in we_need
                cache_tasks.append(asyncio.ensure_future(
                    self.cache_transaction(txid, remote_height, check_local=check_local)
                ))

        synced_txs = []
        for task in cache_tasks:
            tx = await task

            # resolve each input's source txo: first from the in-memory tx cache...
            check_db_for_txos = []
            for txi in tx.inputs:
                if txi.txo_ref.txo is not None:
                    continue
                cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
                if cache_item is not None:
                    if cache_item.tx is None:
                        await cache_item.has_tx.wait()
                    assert cache_item.tx is not None
                    txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
                else:
                    check_db_for_txos.append(txi.txo_ref.id)

            # ...then, for the rest, from the database in one batched query
            referenced_txos = {} if not check_db_for_txos else {
                txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos, no_tx=True)
            }

            for txi in tx.inputs:
                if txi.txo_ref.txo is not None:
                    continue
                referenced_txo = referenced_txos.get(txi.txo_ref.id)
                if referenced_txo is not None:
                    txi.txo_ref = referenced_txo.ref

            synced_history.write(f'{tx.id}:{tx.height}:')
            synced_txs.append(tx)

        await self.db.save_transaction_io_batch(
            synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
        )
        await asyncio.wait([
            self._on_transaction_controller.add(TransactionEvent(address, tx))
            for tx in synced_txs
        ])

        if address_manager is None:
            address_manager = await self.get_address_manager_for_address(address)

        if address_manager is not None:
            await address_manager.ensure_address_gap()

        # verify we actually converged on the server's status
        local_status, local_history = \
            await self.get_local_status_and_history(address, synced_history.getvalue())
        if local_status != remote_status:
            if local_history == remote_history:
                # histories agree even though the hashes differ; treat as synced
                return True
            log.warning(
                "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
                remote_status, len(remote_history), local_status, len(local_history)
            )
            log.warning("local: %s", local_history)
            log.warning("remote: %s", remote_history)
            self._known_addresses_out_of_sync.add(address)
            return False
        else:
            return True
|
|
||||||
|
|
||||||
async def cache_transaction(self, txid, remote_height, check_local=True):
    """Return the transaction *txid*, using the LRU cache, local db, or network.

    A per-item lock prevents duplicate concurrent fetches of the same txid.
    The cached tx is re-verified when *remote_height* is newer than what we hold.
    """
    cache_item = self._tx_cache.get(txid)
    if cache_item is None:
        cache_item = self._tx_cache[txid] = TransactionCacheItem()
    elif cache_item.tx is not None and \
            cache_item.tx.height >= remote_height and \
            (cache_item.tx.is_verified or remote_height < 1):
        return cache_item.tx  # cached tx is already up-to-date

    async with cache_item.lock:

        tx = cache_item.tx

        if tx is None and check_local:
            # check local db
            tx = cache_item.tx = await self.db.get_transaction(txid=txid)

        if tx is None:
            # fetch from network
            _raw = await self.network.retriable_call(self.network.get_transaction, txid, remote_height)
            tx = self.transaction_class(unhexlify(_raw))
            cache_item.tx = tx  # make sure it's saved before caching it

        await self.maybe_verify_transaction(tx, remote_height)
        return tx
|
|
||||||
|
|
||||||
async def maybe_verify_transaction(self, tx, remote_height):
    """Set tx.height and, when the header is locally available, SPV-verify it.

    Verification fetches the merkle proof and compares the computed root with
    the block header's merkle_root.
    """
    tx.height = remote_height
    if 0 < remote_height < len(self.headers):
        merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
        merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
        header = self.headers[remote_height]
        tx.position = merkle['pos']
        tx.is_verified = merkle_root == header['merkle_root']
|
|
||||||
|
|
||||||
async def get_address_manager_for_address(self, address) -> Optional[baseaccount.AddressManager]:
    """Find the chain-specific address manager that generated *address*, or None."""
    details = await self.db.get_address(address=address)
    for account in self.accounts:
        if account.id == details['account']:
            return account.address_managers[details['chain']]
    return None
|
|
||||||
|
|
||||||
def broadcast(self, tx):
    # broadcast can't be a retriable call yet
    return self.network.broadcast(hexlify(tx.raw).decode())
|
|
||||||
|
|
||||||
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=1):
    """Wait until *tx* is observed (at or above *height*) on every address it touches.

    Raises asyncio.TimeoutError if any involved address has not recorded the
    transaction locally within *timeout* seconds.
    """
    addresses = set()
    for txi in tx.inputs:
        if txi.txo_ref.txo is not None:
            addresses.add(
                self.hash160_to_address(txi.txo_ref.txo.pubkey_hash)
            )
    for txo in tx.outputs:
        if txo.has_address:
            addresses.add(self.hash160_to_address(txo.pubkey_hash))
    records = await self.db.get_addresses(address__in=addresses)
    _, pending = await asyncio.wait([
        self.on_transaction.where(partial(
            lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id,
            address_record['address']
        )) for address_record in records
    ], timeout=timeout)
    if pending:
        # event wait timed out; double-check the local db before declaring failure
        for record in records:
            found = False
            _, local_history = await self.get_local_status_and_history(None, history=record['history'])
            for txid, local_height in local_history:
                if txid == tx.id and local_height >= height:
                    found = True
                    break  # no need to scan the rest of this history
            if not found:
                # was a bare print(); route diagnostics through the module logger
                log.warning(
                    "Timed out waiting for transaction. history=%s addresses=%s tx=%s",
                    record['history'], addresses, tx.id
                )
                raise asyncio.TimeoutError('Timed out waiting for transaction.')
|
|
|
@ -1,87 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from typing import Type, MutableSequence, MutableMapping, Optional
|
|
||||||
|
|
||||||
from lbry.wallet.client.baseledger import BaseLedger, LedgerRegistry
|
|
||||||
from lbry.wallet.client.wallet import Wallet, WalletStorage
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseWalletManager:
    """Owns the set of wallets and the ledgers they operate against."""

    def __init__(self, wallets: MutableSequence[Wallet] = None,
                 ledgers: MutableMapping[Type[BaseLedger], BaseLedger] = None) -> None:
        self.wallets = wallets or []
        self.ledgers = ledgers or {}
        self.running = False

    @classmethod
    def from_config(cls, config: dict) -> 'BaseWalletManager':
        """Build a manager from a config dict with 'ledgers' and 'wallets' entries."""
        manager = cls()
        for ledger_id, ledger_settings in config.get('ledgers', {}).items():
            manager.get_or_create_ledger(ledger_id, ledger_settings)
        for wallet_path in config.get('wallets', []):
            storage = WalletStorage(wallet_path)
            loaded = Wallet.from_storage(storage, manager)
            manager.wallets.append(loaded)
        return manager

    def get_or_create_ledger(self, ledger_id, ledger_config=None):
        """Return the ledger registered for *ledger_id*, instantiating it on first use."""
        ledger_class = LedgerRegistry.get_ledger_class(ledger_id)
        existing = self.ledgers.get(ledger_class)
        if existing is not None:
            return existing
        created = ledger_class(ledger_config or {})
        self.ledgers[ledger_class] = created
        return created

    def import_wallet(self, path):
        """Load the wallet stored at *path* and register it."""
        loaded = Wallet.from_storage(WalletStorage(path), self)
        self.wallets.append(loaded)
        return loaded

    @property
    def default_wallet(self):
        # First wallet, or None when there are none.
        return self.wallets[0] if self.wallets else None

    @property
    def default_account(self):
        # Default account of the first wallet, or None when there are no wallets.
        wallet = self.default_wallet
        return wallet.default_account if wallet is not None else None

    @property
    def accounts(self):
        # Flat iterator over every account in every wallet.
        for wallet in self.wallets:
            yield from wallet.accounts

    async def start(self):
        """Start every ledger concurrently."""
        self.running = True
        await asyncio.gather(*(
            ledger.start() for ledger in self.ledgers.values()
        ))

    async def stop(self):
        """Stop every ledger concurrently, then mark the manager stopped."""
        await asyncio.gather(*(
            ledger.stop() for ledger in self.ledgers.values()
        ))
        self.running = False

    def get_wallet_or_default(self, wallet_id: Optional[str]) -> Wallet:
        """Return the wallet with *wallet_id*, or the default wallet when id is None."""
        return self.default_wallet if wallet_id is None else self.get_wallet_or_error(wallet_id)

    def get_wallet_or_error(self, wallet_id: str) -> Wallet:
        """Return the wallet with *wallet_id* or raise ValueError."""
        match = next((w for w in self.wallets if w.id == wallet_id), None)
        if match is None:
            raise ValueError(f"Couldn't find wallet: {wallet_id}.")
        return match

    @staticmethod
    def get_balance(wallet):
        # Returns 0 immediately for an empty wallet, otherwise the db coroutine.
        accounts = wallet.accounts
        if not accounts:
            return 0
        return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts)
|
|
|
@ -1,364 +0,0 @@
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
from operator import itemgetter
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
from time import perf_counter
|
|
||||||
|
|
||||||
import lbry
|
|
||||||
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
|
|
||||||
from lbry.wallet.stream import StreamController
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ClientSession(BaseClientSession):
    """RPC session to one SPV wallet server, with latency tracking and keep-alive."""

    def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
        self.network = network
        self.server = server  # (host, port) tuple; used in log messages throughout
        super().__init__(*args, **kwargs)
        self._on_disconnect_controller = StreamController()
        self.on_disconnected = self._on_disconnect_controller.stream
        # effectively unlimited framing size / error tolerance
        self.framer.max_size = self.max_errors = 1 << 32
        self.bw_limit = -1
        self.timeout = timeout
        self.max_seconds_idle = timeout * 2
        # running average of server.version round-trip time; None until measured
        self.response_time: Optional[float] = None
        self.connection_latency: Optional[float] = None
        self._response_samples = 0
        self.pending_amount = 0  # in-flight request count
        self._on_connect_cb = on_connect_callback or (lambda: None)
        self.trigger_urgent_reconnect = asyncio.Event()

    @property
    def available(self):
        # Usable only once connected AND at least one timing sample exists.
        return not self.is_closing() and self.response_time is not None

    @property
    def server_address_and_port(self) -> Optional[Tuple[str, int]]:
        if not self.transport:
            return None
        return self.transport.get_extra_info('peername')

    async def send_timed_server_version_request(self, args=(), timeout=None):
        """Send server.version and fold the round-trip time into the running average."""
        timeout = timeout or self.timeout
        log.debug("send version request to %s:%i", *self.server)
        start = perf_counter()
        result = await asyncio.wait_for(
            super().send_request('server.version', args), timeout=timeout
        )
        current_response_time = perf_counter() - start
        response_sum = (self.response_time or 0) * self._response_samples + current_response_time
        self.response_time = response_sum / (self._response_samples + 1)
        self._response_samples += 1
        return result

    async def send_request(self, method, args=()):
        """Send an RPC request, timing out only when the link has gone quiet.

        The timeout is measured against last_packet_received, so a slow but
        live server is not dropped.
        """
        self.pending_amount += 1
        log.debug("send %s to %s:%i", method, *self.server)
        try:
            if method == 'server.version':
                return await self.send_timed_server_version_request(args, self.timeout)
            request = asyncio.ensure_future(super().send_request(method, args))
            while not request.done():
                done, pending = await asyncio.wait([request], timeout=self.timeout)
                if pending:
                    log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received)
                    # any recent traffic counts as liveness; keep waiting
                    if (perf_counter() - self.last_packet_received) < self.timeout:
                        continue
                    log.info("timeout sending %s to %s:%i", method, *self.server)
                    raise asyncio.TimeoutError
                if done:
                    return request.result()
        except (RPCError, ProtocolError) as e:
            log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
                        *self.server, *e.args)
            raise e
        except ConnectionError:
            log.warning("connection to %s:%i lost", *self.server)
            self.synchronous_close()
            raise
        except asyncio.CancelledError:
            log.info("cancelled sending %s to %s:%i", method, *self.server)
            self.synchronous_close()
            raise
        finally:
            self.pending_amount -= 1

    async def ensure_session(self):
        # Handles reconnecting and maintaining a session alive
        # TODO: change to 'ping' on newer protocol (above 1.2)
        retry_delay = default_delay = 1.0
        while True:
            try:
                if self.is_closing():
                    await self.create_connection(self.timeout)
                    await self.ensure_server_version()
                    self._on_connect_cb()
                # re-ping when idle too long or no timing sample yet
                if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
                    await self.ensure_server_version()
                retry_delay = default_delay
            except RPCError as e:
                log.warning("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
                retry_delay = 60 * 60
            except (asyncio.TimeoutError, OSError):
                await self.close()
                # exponential backoff capped at one minute
                retry_delay = min(60, retry_delay * 2)
                log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
            try:
                # sleep for the backoff, but wake immediately on urgent reconnect
                await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
            except asyncio.TimeoutError:
                pass
            finally:
                self.trigger_urgent_reconnect.clear()

    async def ensure_server_version(self, required=None, timeout=3):
        """Negotiate/confirm protocol version; doubles as a keep-alive ping."""
        required = required or self.network.PROTOCOL_VERSION
        return await asyncio.wait_for(
            self.send_request('server.version', [lbry.__version__, required]), timeout=timeout
        )

    async def create_connection(self, timeout=6):
        """Open the transport and record how long the connect took."""
        connector = Connector(lambda: self, *self.server)
        start = perf_counter()
        await asyncio.wait_for(connector.create_connection(), timeout=timeout)
        self.connection_latency = perf_counter() - start

    async def handle_request(self, request):
        # Server push notification: route to the matching subscription stream.
        controller = self.network.subscription_controllers[request.method]
        controller.add(request.args)

    def connection_lost(self, exc):
        log.debug("Connection lost: %s:%d", *self.server)
        super().connection_lost(exc)
        # reset all timing/throughput state so this session isn't picked as "fast"
        self.response_time = None
        self.connection_latency = None
        self._response_samples = 0
        self.pending_amount = 0
        self._on_disconnect_controller.add(True)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseNetwork:
|
|
||||||
PROTOCOL_VERSION = '1.2'
|
|
||||||
|
|
||||||
def __init__(self, ledger):
    """Set up session pool, event streams and server-push routing for *ledger*."""
    self.ledger = ledger
    self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
    self.client: Optional[ClientSession] = None
    self._switch_task: Optional[asyncio.Task] = None
    self.running = False
    self.remote_height: int = 0
    # caps concurrent retriable_call invocations
    self._concurrency = asyncio.Semaphore(16)

    self._on_connected_controller = StreamController()
    self.on_connected = self._on_connected_controller.stream

    self._on_header_controller = StreamController(merge_repeated_events=True)
    self.on_header = self._on_header_controller.stream

    self._on_status_controller = StreamController(merge_repeated_events=True)
    self.on_status = self._on_status_controller.stream

    # maps server push methods to the stream that should receive them
    self.subscription_controllers = {
        'blockchain.headers.subscribe': self._on_header_controller,
        'blockchain.address.subscribe': self._on_status_controller,
    }
|
|
||||||
|
|
||||||
@property
def config(self):
    # Network settings live on the owning ledger's config.
    return self.ledger.config
|
|
||||||
|
|
||||||
async def switch_forever(self):
    """Background loop: keep self.client pointed at the fastest live session.

    Waits for disconnects, then picks the fastest pooled session, subscribes
    to headers and announces the connection; on failure, closes and retries.
    """
    while self.running:
        if self.is_connected:
            # block until the current client drops, then pick a new one
            await self.client.on_disconnected.first
            self.client = None
            continue
        self.client = await self.session_pool.wait_for_fastest_session()
        log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
        try:
            self._update_remote_height((await self.subscribe_headers(),))
            self._on_connected_controller.add(True)
            log.info("Subscribed to headers: %s:%d", *self.client.server)
        except (asyncio.TimeoutError, ConnectionError):
            log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
            self.client.synchronous_close()
            self.client = None
|
|
||||||
|
|
||||||
async def start(self):
    """Start the server-switching task and the session pool."""
    self.running = True
    self._switch_task = asyncio.ensure_future(self.switch_forever())
    # this may become unnecessary when there are no more bugs found,
    # but for now it helps understanding log reports
    self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
    self.session_pool.start(self.config['default_servers'])
    self.on_header.listen(self._update_remote_height)
|
|
||||||
|
|
||||||
async def stop(self):
    """Cancel the switching task and tear down the session pool (idempotent)."""
    if self.running:
        self.running = False
        self._switch_task.cancel()
        self.session_pool.stop()
|
|
||||||
|
|
||||||
    @property
    def is_connected(self):
        # Truthy only when a master client exists and its transport is open.
        # NOTE: returns ``self.client`` itself (i.e. None) when no client is
        # set, not a strict bool.
        return self.client and not self.client.is_closing()
|
|
||||||
|
|
||||||
def rpc(self, list_or_method, args, restricted=True):
|
|
||||||
session = self.client if restricted else self.session_pool.fastest_session
|
|
||||||
if session and not session.is_closing():
|
|
||||||
return session.send_request(list_or_method, args)
|
|
||||||
else:
|
|
||||||
self.session_pool.trigger_nodelay_connect()
|
|
||||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
|
||||||
|
|
||||||
    async def retriable_call(self, function, *args, **kwargs):
        """Invoke *function* and keep retrying on timeout or connection loss.

        Concurrency is bounded by ``self._concurrency``.  Before every attempt
        it waits for a live connection and the fastest pooled session.  Loops
        until the call succeeds or the network is stopped, at which point
        CancelledError is raised.
        """
        async with self._concurrency:
            while self.running:
                if not self.is_connected:
                    log.warning("Wallet server unavailable, waiting for it to come back and retry.")
                    await self.on_connected.first
                await self.session_pool.wait_for_fastest_session()
                try:
                    return await function(*args, **kwargs)
                except asyncio.TimeoutError:
                    log.warning("Wallet server call timed out, retrying.")
                except ConnectionError:
                    # server dropped mid-call; loop back and wait for reconnect
                    pass
            raise asyncio.CancelledError()  # if we got here, we are shutting down
|
|
||||||
|
|
||||||
def _update_remote_height(self, header_args):
|
|
||||||
self.remote_height = header_args[0]["height"]
|
|
||||||
|
|
||||||
    def get_transaction(self, tx_hash, known_height=None):
        # use any server if its old, otherwise restrict to who gave us the history
        # NOTE(review): `0 > known_height > self.remote_height - 10` can only be
        # true for negative heights (mempool/unconfirmed markers) — confirm
        # that is the intended guard rather than `0 < known_height > ...`.
        restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
        return self.rpc('blockchain.transaction.get', [tx_hash], restricted)
|
|
||||||
|
|
||||||
    def get_transaction_height(self, tx_hash, known_height=None):
        # Unknown/unconfirmed heights are queried on the master client (the
        # server that announced the tx); otherwise any pooled session may answer.
        restricted = not known_height or 0 > known_height > self.remote_height - 10
        return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
|
|
||||||
|
|
||||||
    def get_merkle(self, tx_hash, height):
        # NOTE(review): `0 > height > self.remote_height - 10` is only true for
        # negative (unconfirmed) heights — confirm the intended condition.
        restricted = 0 > height > self.remote_height - 10
        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)
|
|
||||||
|
|
||||||
    def get_headers(self, height, count=10000, b64=False):
        # Headers near the tip (within 100 blocks) must come from the master
        # client; deep history may be fetched from any pooled session.
        restricted = height >= self.remote_height - 100
        return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)
|
|
||||||
|
|
||||||
    # --- Subscribes, history and broadcasts are always aimed towards the master client directly
    def get_history(self, address):
        # Must use the master client so history matches its status notifications.
        return self.rpc('blockchain.address.get_history', [address], True)
|
|
||||||
|
|
||||||
    def broadcast(self, raw_transaction):
        # Publish a raw (hex) transaction via the master client.
        return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
|
|
||||||
|
|
||||||
    def subscribe_headers(self):
        # Subscribe to new-block notifications on the master client; the reply
        # carries the current tip (used to seed remote_height).
        return self.rpc('blockchain.headers.subscribe', [True], True)
|
|
||||||
|
|
||||||
    async def subscribe_address(self, address):
        """Subscribe to status updates for *address* on the master client.

        A timeout aborts the whole connection: a silently lost subscription
        would leave the wallet blind to address activity, so it is safer to
        drop the client and resubscribe everything on reconnect.
        """
        try:
            return await self.rpc('blockchain.address.subscribe', [address], True)
        except asyncio.TimeoutError:
            # abort and cancel, we can't lose a subscription, it will happen again on reconnect
            if self.client:
                self.client.abort()
            raise asyncio.CancelledError()
|
|
||||||
|
|
||||||
    def unsubscribe_address(self, address):
        # Cancel status notifications for *address* on the master client.
        return self.rpc('blockchain.address.unsubscribe', [address], True)
|
|
||||||
|
|
||||||
    def get_server_features(self):
        # Query the master client's advertised capabilities/version info.
        return self.rpc('server.features', (), restricted=True)
|
|
||||||
|
|
||||||
|
|
||||||
class SessionPool:
    """Maintains one ClientSession (plus its keep-alive task) per wallet server
    and picks the fastest available one for queries."""

    def __init__(self, network: BaseNetwork, timeout: float):
        self.network = network
        # session -> the asyncio task running that session's ensure_session loop
        self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
        # NOTE(review): timeout is stored but not used within this class —
        # presumably consumed by ClientSession elsewhere; confirm.
        self.timeout = timeout
        # set whenever a session connects; wakes wait_for_fastest_session()
        self.new_connection_event = asyncio.Event()

    @property
    def online(self):
        # True when at least one session transport is open.
        return any(not session.is_closing() for session in self.sessions)

    @property
    def available_sessions(self):
        # Lazy generator over sessions that report themselves usable.
        return (session for session in self.sessions if session.available)

    @property
    def fastest_session(self):
        """Session with the lowest (latency * load) score, or None when offline."""
        if not self.online:
            return None
        return min(
            [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
             for session in self.available_sessions] or [(0, None)],
            key=itemgetter(0)
        )[1]

    def _get_session_connect_callback(self, session: ClientSession):
        """Build the on-connect hook for *session*: de-duplicate servers that
        resolve to the same address, otherwise announce the new connection."""
        loop = asyncio.get_event_loop()

        def callback():
            duplicate_connections = [
                s for s in self.sessions
                if s is not session and s.server_address_and_port == session.server_address_and_port
            ]
            already_connected = None if not duplicate_connections else duplicate_connections[0]
            if already_connected:
                # drop the duplicate and recheck its hostname much later
                self.sessions.pop(session).cancel()
                session.synchronous_close()
                log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
                          session.server[0], already_connected.server[0])
                loop.call_later(3600, self._connect_session, session.server)
                return
            self.new_connection_event.set()
            log.info("connected to %s:%i", *session.server)

        return callback

    def _connect_session(self, server: Tuple[str, int]):
        """Ensure a session and a live ensure_session task exist for *server*."""
        session = None
        for s in self.sessions:
            if s.server == server:
                session = s
                break
        if not session:
            session = ClientSession(
                network=self.network, server=server
            )
            session._on_connect_cb = self._get_session_connect_callback(session)
        task = self.sessions.get(session, None)
        if not task or task.done():
            task = asyncio.create_task(session.ensure_session())
            # when a session's task ends for any reason, re-check the pool
            task.add_done_callback(lambda _: self.ensure_connections())
            self.sessions[session] = task

    def start(self, default_servers):
        """Open a session for every configured server."""
        for server in default_servers:
            self._connect_session(server)

    def stop(self):
        """Cancel all session tasks, close their transports and empty the pool."""
        for session, task in self.sessions.items():
            task.cancel()
            session.synchronous_close()
        self.sessions.clear()

    def ensure_connections(self):
        # Restart any session whose keep-alive task has finished.
        for session in self.sessions:
            self._connect_session(session.server)

    def trigger_nodelay_connect(self):
        # used when other parts of the system sees we might have internet back
        # bypasses the retry interval
        for session in self.sessions:
            session.trigger_urgent_reconnect.set()

    async def wait_for_fastest_session(self):
        """Block until some session is usable, then return the fastest one."""
        while not self.fastest_session:
            self.trigger_nodelay_connect()
            self.new_connection_event.clear()
            await self.new_connection_event.wait()
        return self.fastest_session
|
|
|
@ -1,450 +0,0 @@
|
||||||
from itertools import chain
|
|
||||||
from binascii import hexlify
|
|
||||||
from collections import namedtuple
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from lbry.wallet.client.bcd_data_stream import BCDataStream
|
|
||||||
from lbry.wallet.client.util import subclass_tuple
|
|
||||||
|
|
||||||
# bitcoin opcodes
OP_0 = 0x00
OP_1 = 0x51          # OP_1..OP_16 push the small integers 1..16
OP_16 = 0x60
OP_VERIFY = 0x69
OP_DUP = 0x76
OP_HASH160 = 0xa9
OP_EQUALVERIFY = 0x88
OP_CHECKSIG = 0xac
OP_CHECKMULTISIG = 0xae
OP_EQUAL = 0x87
OP_PUSHDATA1 = 0x4c  # next 1 byte gives the push length
OP_PUSHDATA2 = 0x4d  # next 2 bytes give the push length
OP_PUSHDATA4 = 0x4e  # next 4 bytes give the push length
OP_RETURN = 0x6a
OP_2DROP = 0x6d
OP_DROP = 0x75
|
|
||||||
|
|
||||||
|
|
||||||
# template matching opcodes (not real opcodes)
# These are placeholders used in Template opcode lists; Parser matches
# them against script tokens and stores the payload under their .name.
# base class for PUSH_DATA related opcodes
# pylint: disable=invalid-name
PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name')
# opcode for variable length strings
# pylint: disable=invalid-name
PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP)
# opcode for variable size integers
# pylint: disable=invalid-name
PUSH_INTEGER = subclass_tuple('PUSH_INTEGER', PUSH_DATA_OP)
# opcode for variable number of variable length strings
# pylint: disable=invalid-name
PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP)
# opcode with embedded subscript parsing
# pylint: disable=invalid-name
PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template')
|
|
||||||
|
|
||||||
|
|
||||||
def is_push_data_opcode(opcode):
    """True for template opcodes that carry a data payload."""
    return isinstance(opcode, PUSH_DATA_OP) or isinstance(opcode, PUSH_SUBSCRIPT)
|
|
||||||
|
|
||||||
|
|
||||||
def is_push_data_token(token):
    """True when the raw byte *token* introduces push data (0x01..OP_PUSHDATA4)."""
    return 0 < token <= OP_PUSHDATA4
|
|
||||||
|
|
||||||
|
|
||||||
def push_data(data):
    """Yield the byte pieces that push *data* onto the script stack.

    Chooses the smallest push encoding: a direct length byte (< 0x4c),
    then OP_PUSHDATA1 / OP_PUSHDATA2 / OP_PUSHDATA4 by payload size.
    """
    size = len(data)
    if size < OP_PUSHDATA1:
        # lengths below 0x4c are encoded as the opcode byte itself
        yield BCDataStream.uint8.pack(size)
    elif size <= 0xFF:
        yield BCDataStream.uint8.pack(OP_PUSHDATA1)
        yield BCDataStream.uint8.pack(size)
    elif size <= 0xFFFF:
        yield BCDataStream.uint8.pack(OP_PUSHDATA2)
        yield BCDataStream.uint16.pack(size)
    else:
        yield BCDataStream.uint8.pack(OP_PUSHDATA4)
        yield BCDataStream.uint32.pack(size)
    yield bytes(data)
|
|
||||||
|
|
||||||
|
|
||||||
def read_data(token, stream):
    """Read the push-data payload that *token* introduces from *stream*.

    The payload length is the token itself for small pushes, otherwise it is
    read from the stream as a 1, 2 or 4 byte integer (OP_PUSHDATA1/2/4).
    """
    if token < OP_PUSHDATA1:
        length = token
    elif token == OP_PUSHDATA1:
        length = stream.read_uint8()
    elif token == OP_PUSHDATA2:
        length = stream.read_uint16()
    else:
        length = stream.read_uint32()
    return stream.read(length)
|
|
||||||
|
|
||||||
|
|
||||||
# opcode for OP_1 - OP_16
# Template placeholder: captures a small integer (1..16) under .name.
# pylint: disable=invalid-name
SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name')
|
|
||||||
|
|
||||||
|
|
||||||
def is_small_integer(token):
    """True when *token* is one of the OP_1..OP_16 small-integer opcodes."""
    return OP_1 <= token <= OP_16
|
|
||||||
|
|
||||||
|
|
||||||
def push_small_integer(num):
    """Yield the single opcode byte that pushes the integer *num* (1..16)."""
    assert 1 <= num <= 16
    yield BCDataStream.uint8.pack(OP_1 + (num - 1))
|
|
||||||
|
|
||||||
|
|
||||||
def read_small_integer(token):
    """Inverse of push_small_integer: map an OP_1..OP_16 byte back to 1..16."""
    return token - OP_1 + 1
|
|
||||||
|
|
||||||
|
|
||||||
class Token(namedtuple('Token', 'value')):
    """A single parsed script token holding a raw opcode byte."""
    __slots__ = ()

    def __repr__(self):
        # Show the symbolic OP_* name when one matches this value.
        for var_name, var_value in globals().items():
            if var_name.startswith('OP_') and var_value == self.value:
                return var_name
        # Bug fix: __repr__ must return a str.  Previously the raw int value
        # was returned for unknown opcodes, making repr() raise TypeError.
        return str(self.value)
|
|
||||||
|
|
||||||
|
|
||||||
class DataToken(Token):
    """Token carrying a pushed data payload (bytes)."""
    __slots__ = ()

    def __repr__(self):
        # Bug fix: hexlify() returns bytes, so the old repr embedded the
        # bytes literal (e.g. "b'6162'"); decode to show plain hex digits.
        return f'"{hexlify(self.value).decode()}"'
|
|
||||||
|
|
||||||
|
|
||||||
class SmallIntegerToken(Token):
    """Token whose value is a decoded small integer (1..16) rather than a raw opcode."""
    __slots__ = ()

    def __repr__(self):
        return f'SmallIntegerToken({self.value})'
|
|
||||||
|
|
||||||
|
|
||||||
def token_producer(source):
    """Lazily yield Token/DataToken/SmallIntegerToken objects from *source*
    until the stream is exhausted (read_uint8 returns None)."""
    for raw in iter(source.read_uint8, None):
        if is_push_data_token(raw):
            yield DataToken(read_data(raw, source))
        elif is_small_integer(raw):
            yield SmallIntegerToken(read_small_integer(raw))
        else:
            yield Token(raw)
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize(source):
    """Fully tokenize a script byte stream into a list of Token objects."""
    return list(token_producer(source))
|
|
||||||
|
|
||||||
|
|
||||||
class ScriptError(Exception):
    """ General script handling error. """


class ParseError(ScriptError):
    """ Script parsing error. """
|
|
||||||
|
|
||||||
|
|
||||||
class Parser:
    """Matches a stream of script tokens against a template's opcode list,
    collecting named values (pushed data, integers) into ``self.values``."""

    def __init__(self, opcodes, tokens):
        self.opcodes = opcodes
        self.tokens = tokens
        self.values = {}       # opcode name -> extracted value
        self.token_index = 0
        self.opcode_index = 0

    def parse(self):
        """Walk tokens and opcodes in lockstep.

        Raises ParseError on any mismatch or leftover tokens/opcodes.
        Returns self so callers can read ``.values``.
        """
        while self.token_index < len(self.tokens) and self.opcode_index < len(self.opcodes):
            token = self.tokens[self.token_index]
            opcode = self.opcodes[self.opcode_index]
            if token.value == 0 and isinstance(opcode, PUSH_SINGLE):
                # OP_0 in a push-single position stands for an empty payload
                token = DataToken(b'')
            if isinstance(token, DataToken):
                if isinstance(opcode, (PUSH_SINGLE, PUSH_INTEGER, PUSH_SUBSCRIPT)):
                    self.push_single(opcode, token.value)
                elif isinstance(opcode, PUSH_MANY):
                    self.consume_many_non_greedy()
                else:
                    raise ParseError(f"DataToken found but opcode was '{opcode}'.")
            elif isinstance(token, SmallIntegerToken):
                if isinstance(opcode, SMALL_INTEGER):
                    self.values[opcode.name] = token.value
                else:
                    raise ParseError(f"SmallIntegerToken found but opcode was '{opcode}'.")
            elif token.value == opcode:
                # literal opcode matched; nothing to record
                pass
            else:
                raise ParseError(f"Token is '{token.value}' and opcode is '{opcode}'.")
            self.token_index += 1
            self.opcode_index += 1

        if self.token_index < len(self.tokens):
            raise ParseError("Parse completed without all tokens being consumed.")

        if self.opcode_index < len(self.opcodes):
            raise ParseError("Parse completed without all opcodes being consumed.")

        return self

    def consume_many_non_greedy(self):
        """ Allows PUSH_MANY to consume data without being greedy
        in cases when one or more PUSH_SINGLEs follow a PUSH_MANY. This will
        prioritize giving all PUSH_SINGLEs some data and only after that
        subsume the rest into PUSH_MANY.
        """

        # gather the run of consecutive data tokens
        token_values = []
        while self.token_index < len(self.tokens):
            token = self.tokens[self.token_index]
            if not isinstance(token, DataToken):
                self.token_index -= 1
                break
            token_values.append(token.value)
            self.token_index += 1

        # gather the run of consecutive push opcodes (at most one PUSH_MANY)
        push_opcodes = []
        push_many_count = 0
        while self.opcode_index < len(self.opcodes):
            opcode = self.opcodes[self.opcode_index]
            if not is_push_data_opcode(opcode):
                self.opcode_index -= 1
                break
            if isinstance(opcode, PUSH_MANY):
                push_many_count += 1
            push_opcodes.append(opcode)
            self.opcode_index += 1

        if push_many_count > 1:
            raise ParseError(
                "Cannot have more than one consecutive PUSH_MANY, as there is no way to tell which"
                " token value should go into which PUSH_MANY."
            )

        if len(push_opcodes) > len(token_values):
            raise ParseError(
                "Not enough token values to match all of the PUSH_MANY and PUSH_SINGLE opcodes."
            )

        many_opcode = push_opcodes.pop(0)

        # consume data into PUSH_SINGLE opcodes, working backwards
        for opcode in reversed(push_opcodes):
            self.push_single(opcode, token_values.pop())

        # finally PUSH_MANY gets everything that's left
        self.values[many_opcode.name] = token_values

    def push_single(self, opcode, value):
        # Record one pushed value, decoding it according to the opcode kind.
        if isinstance(opcode, PUSH_SINGLE):
            self.values[opcode.name] = value
        elif isinstance(opcode, PUSH_INTEGER):
            self.values[opcode.name] = int.from_bytes(value, 'little')
        elif isinstance(opcode, PUSH_SUBSCRIPT):
            self.values[opcode.name] = Script.from_source_with_template(value, opcode.template)
        else:
            raise ParseError(f"Not a push single or subscript: {opcode}")
|
|
||||||
|
|
||||||
|
|
||||||
class Template:
    """A named, ordered list of opcodes/placeholders describing one script
    shape; can both parse tokens into values and generate bytes from values."""

    __slots__ = 'name', 'opcodes'

    def __init__(self, name, opcodes):
        self.name = name
        self.opcodes = opcodes

    def parse(self, tokens):
        """Match *tokens* against this template; returns the extracted value
        dict (empty for the placeholder template with no opcodes)."""
        return Parser(self.opcodes, tokens).parse().values if self.opcodes else {}

    def generate(self, values):
        """Serialize *values* into script bytes following this template."""
        source = BCDataStream()
        for opcode in self.opcodes:
            if isinstance(opcode, PUSH_SINGLE):
                data = values[opcode.name]
                source.write_many(push_data(data))
            elif isinstance(opcode, PUSH_INTEGER):
                data = values[opcode.name]
                # integers are pushed little-endian with minimal byte length
                source.write_many(push_data(
                    data.to_bytes((data.bit_length() + 7) // 8, byteorder='little')
                ))
            elif isinstance(opcode, PUSH_SUBSCRIPT):
                data = values[opcode.name]
                source.write_many(push_data(data.source))
            elif isinstance(opcode, PUSH_MANY):
                for data in values[opcode.name]:
                    source.write_many(push_data(data))
            elif isinstance(opcode, SMALL_INTEGER):
                data = values[opcode.name]
                source.write_many(push_small_integer(data))
            else:
                # literal opcode byte
                source.write_uint8(opcode)
        return source.get_bytes()
|
|
||||||
|
|
||||||
|
|
||||||
class Script:
    """A bitcoin script: raw bytes plus a lazily resolved (template, values)
    pair.  Either side can be derived from the other on demand."""

    __slots__ = 'source', '_template', '_values', '_template_hint'

    # subclasses register their known script shapes here
    templates: List[Template] = []

    NO_SCRIPT = Template('no_script', None)  # special case

    def __init__(self, source=None, template=None, values=None, template_hint=None):
        self.source = source
        self._template = template
        self._values = values
        self._template_hint = template_hint
        # constructed from (template, values) only: serialize immediately
        if source is None and template and values:
            self.generate()

    @property
    def template(self):
        # lazily parsed from source on first access
        if self._template is None:
            self.parse(self._template_hint)
        return self._template

    @property
    def values(self):
        # lazily parsed from source on first access
        if self._values is None:
            self.parse(self._template_hint)
        return self._values

    @property
    def tokens(self):
        return tokenize(BCDataStream(self.source))

    @classmethod
    def from_source_with_template(cls, source, template):
        """Alternate constructor: raw bytes plus the template to try first."""
        return cls(source, template_hint=template)

    def parse(self, template_hint=None):
        """Resolve template and values by trying the hint first, then every
        registered template; raises ValueError when nothing matches."""
        tokens = self.tokens
        if not tokens and not template_hint:
            template_hint = self.NO_SCRIPT
        for template in chain((template_hint,), self.templates):
            if not template:
                continue
            try:
                self._values = template.parse(tokens)
                self._template = template
                return
            except ParseError:
                continue
        raise ValueError(f'No matching templates for source: {hexlify(self.source)}')

    def generate(self):
        # Serialize (template, values) into self.source.
        self.source = self.template.generate(self._values)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseInputScript(Script):
    """ Input / redeem script templates (aka scriptSig) """

    __slots__ = ()

    REDEEM_PUBKEY = Template('pubkey', (
        PUSH_SINGLE('signature'),
    ))
    REDEEM_PUBKEY_HASH = Template('pubkey_hash', (
        PUSH_SINGLE('signature'), PUSH_SINGLE('pubkey')
    ))
    REDEEM_SCRIPT = Template('script', (
        SMALL_INTEGER('signatures_count'), PUSH_MANY('pubkeys'), SMALL_INTEGER('pubkeys_count'),
        OP_CHECKMULTISIG
    ))
    REDEEM_SCRIPT_HASH = Template('script_hash', (
        OP_0, PUSH_MANY('signatures'), PUSH_SUBSCRIPT('script', REDEEM_SCRIPT)
    ))

    # order matters: parse() tries these in sequence
    templates = [
        REDEEM_PUBKEY,
        REDEEM_PUBKEY_HASH,
        REDEEM_SCRIPT_HASH,
        REDEEM_SCRIPT
    ]

    @classmethod
    def redeem_pubkey_hash(cls, signature, pubkey):
        """Build the scriptSig spending a pay-to-pubkey-hash output."""
        return cls(template=cls.REDEEM_PUBKEY_HASH, values={
            'signature': signature,
            'pubkey': pubkey
        })

    @classmethod
    def redeem_script_hash(cls, signatures, pubkeys):
        """Build the scriptSig spending a pay-to-script-hash multisig output."""
        return cls(template=cls.REDEEM_SCRIPT_HASH, values={
            'signatures': signatures,
            'script': cls.redeem_script(signatures, pubkeys)
        })

    @classmethod
    def redeem_script(cls, signatures, pubkeys):
        """Build the embedded multisig redeem script itself."""
        return cls(template=cls.REDEEM_SCRIPT, values={
            'signatures_count': len(signatures),
            'pubkeys': pubkeys,
            'pubkeys_count': len(pubkeys)
        })
|
|
||||||
|
|
||||||
|
|
||||||
class BaseOutputScript(Script):
    """Output / payment script templates (aka scriptPubKey)."""

    __slots__ = ()

    # output / payment script templates (aka scriptPubKey)
    PAY_PUBKEY_FULL = Template('pay_pubkey_full', (
        PUSH_SINGLE('pubkey'), OP_CHECKSIG
    ))
    PAY_PUBKEY_HASH = Template('pay_pubkey_hash', (
        OP_DUP, OP_HASH160, PUSH_SINGLE('pubkey_hash'), OP_EQUALVERIFY, OP_CHECKSIG
    ))
    PAY_SCRIPT_HASH = Template('pay_script_hash', (
        OP_HASH160, PUSH_SINGLE('script_hash'), OP_EQUAL
    ))
    PAY_SEGWIT = Template('pay_script_hash+segwit', (
        OP_0, PUSH_SINGLE('script_hash')
    ))
    RETURN_DATA = Template('return_data', (
        OP_RETURN, PUSH_SINGLE('data')
    ))

    # order matters: parse() tries these in sequence
    templates = [
        PAY_PUBKEY_FULL,
        PAY_PUBKEY_HASH,
        PAY_SCRIPT_HASH,
        PAY_SEGWIT,
        RETURN_DATA
    ]

    @classmethod
    def pay_pubkey_hash(cls, pubkey_hash):
        """Build a standard P2PKH payment script."""
        return cls(template=cls.PAY_PUBKEY_HASH, values={
            'pubkey_hash': pubkey_hash
        })

    @classmethod
    def pay_script_hash(cls, script_hash):
        """Build a standard P2SH payment script."""
        return cls(template=cls.PAY_SCRIPT_HASH, values={
            'script_hash': script_hash
        })

    @classmethod
    def return_data(cls, data):
        """Build an unspendable OP_RETURN data-carrier script."""
        return cls(template=cls.RETURN_DATA, values={
            'data': data
        })

    # endswith() is used so subclass templates with prefixed names still match
    @property
    def is_pay_pubkey(self):
        return self.template.name.endswith('pay_pubkey_full')

    @property
    def is_pay_pubkey_hash(self):
        return self.template.name.endswith('pay_pubkey_hash')

    @property
    def is_pay_script_hash(self):
        return self.template.name.endswith('pay_script_hash')

    @property
    def is_return_data(self):
        return self.template.name.endswith('return_data')
|
|
|
@ -1,580 +0,0 @@
|
||||||
import logging
|
|
||||||
import typing
|
|
||||||
from typing import List, Iterable, Optional, Tuple
|
|
||||||
from binascii import hexlify
|
|
||||||
|
|
||||||
from lbry.crypto.hash import sha256
|
|
||||||
from lbry.wallet.client.basescript import BaseInputScript, BaseOutputScript
|
|
||||||
from lbry.wallet.client.baseaccount import BaseAccount
|
|
||||||
from lbry.wallet.client.constants import COIN, NULL_HASH32
|
|
||||||
from lbry.wallet.client.bcd_data_stream import BCDataStream
|
|
||||||
from lbry.wallet.client.hash import TXRef, TXRefImmutable
|
|
||||||
from lbry.wallet.client.util import ReadOnlyList
|
|
||||||
from lbry.error import InsufficientFundsError
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from lbry.wallet.client import baseledger, wallet as basewallet
|
|
||||||
|
|
||||||
# NOTE(review): getLogger() with no argument returns the ROOT logger; most
# modules use logging.getLogger(__name__) — confirm this is intentional.
log = logging.getLogger()
|
|
||||||
|
|
||||||
|
|
||||||
class TXRefMutable(TXRef):
    """Reference to an in-memory transaction whose id/hash are computed from
    the live tx object and can be invalidated via reset() when it changes."""

    __slots__ = ('tx',)

    def __init__(self, tx: 'BaseTransaction') -> None:
        super().__init__()
        self.tx = tx

    @property
    def id(self):
        # hex txid: double-SHA256 hash displayed byte-reversed, cached
        if self._id is None:
            self._id = hexlify(self.hash[::-1]).decode()
        return self._id

    @property
    def hash(self):
        # txid hashes the serialization WITHOUT segwit data, cached
        if self._hash is None:
            self._hash = sha256(sha256(self.tx.raw_sans_segwit))
        return self._hash

    @property
    def height(self):
        return self.tx.height

    def reset(self):
        # Drop cached id/hash after the underlying transaction is modified.
        self._id = None
        self._hash = None
|
|
||||||
|
|
||||||
|
|
||||||
class TXORef:
    """Reference to one output of a transaction: (tx reference, output index)."""

    __slots__ = 'tx_ref', 'position'

    def __init__(self, tx_ref: TXRef, position: int) -> None:
        self.tx_ref = tx_ref
        self.position = position

    @property
    def id(self):
        # canonical "txid:vout" form
        return f'{self.tx_ref.id}:{self.position}'

    @property
    def hash(self):
        # binary outpoint: 32-byte tx hash + little-endian uint32 index
        return self.tx_ref.hash + BCDataStream.uint32.pack(self.position)

    @property
    def is_null(self):
        return self.tx_ref.is_null

    @property
    def txo(self) -> Optional['BaseOutput']:
        # Bare references cannot resolve the output; see TXORefResolvable.
        return None
|
|
||||||
|
|
||||||
|
|
||||||
class TXORefResolvable(TXORef):
    """TXORef that holds the actual output object, so ``txo`` resolves."""

    __slots__ = ('_txo',)

    def __init__(self, txo: 'BaseOutput') -> None:
        assert txo.tx_ref is not None
        assert txo.position is not None
        super().__init__(txo.tx_ref, txo.position)
        self._txo = txo

    @property
    def txo(self):
        return self._txo
|
|
||||||
|
|
||||||
|
|
||||||
class InputOutput:
    """Common base for transaction inputs and outputs: knows its parent
    transaction, its position within it, and how to measure itself."""

    __slots__ = 'tx_ref', 'position'

    def __init__(self, tx_ref: TXRef = None, position: int = None) -> None:
        self.tx_ref = tx_ref
        self.position = position

    @property
    def size(self) -> int:
        """ Size of this input / output in bytes. """
        # measured by serializing into a throwaway stream
        stream = BCDataStream()
        self.serialize_to(stream)
        return len(stream.get_bytes())

    def get_fee(self, ledger):
        # Fee contribution of this input/output at the ledger's byte rate.
        return self.size * ledger.fee_per_byte

    def serialize_to(self, stream, alternate_script=None):
        # Subclasses implement the wire format.
        raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class BaseInput(InputOutput):
    """A transaction input: the outpoint it spends plus its unlocking script."""

    script_class = BaseInputScript

    # max-size placeholders used when estimating fees before signing
    NULL_SIGNATURE = b'\x00'*72
    NULL_PUBLIC_KEY = b'\x00'*33

    __slots__ = 'txo_ref', 'sequence', 'coinbase', 'script'

    def __init__(self, txo_ref: TXORef, script: BaseInputScript, sequence: int = 0xFFFFFFFF,
                 tx_ref: TXRef = None, position: int = None) -> None:
        super().__init__(tx_ref, position)
        self.txo_ref = txo_ref
        self.sequence = sequence
        # a null outpoint marks a coinbase; its "script" is opaque bytes
        self.coinbase = script if txo_ref.is_null else None
        self.script = script if not txo_ref.is_null else None

    @property
    def is_coinbase(self):
        return self.coinbase is not None

    @classmethod
    def spend(cls, txo: 'BaseOutput') -> 'BaseInput':
        """ Create an input to spend the output."""
        assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.'
        # placeholder redeem script; real signature/pubkey filled in at signing
        script = cls.script_class.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY)
        return cls(txo.ref, script)

    @property
    def amount(self) -> int:
        """ Amount this input adds to the transaction. """
        if self.txo_ref.txo is None:
            raise ValueError('Cannot resolve output to get amount.')
        return self.txo_ref.txo.amount

    @property
    def is_my_account(self) -> Optional[bool]:
        """ True if the output this input spends is yours. """
        if self.txo_ref.txo is None:
            return False
        return self.txo_ref.txo.is_my_account

    @classmethod
    def deserialize_from(cls, stream):
        """Read one input in bitcoin wire format from *stream*."""
        tx_ref = TXRefImmutable.from_hash(stream.read(32), -1)
        position = stream.read_uint32()
        script = stream.read_string()
        sequence = stream.read_uint32()
        return cls(
            TXORef(tx_ref, position),
            # coinbase scripts are kept as raw bytes, not parsed
            cls.script_class(script) if not tx_ref.is_null else script,
            sequence
        )

    def serialize_to(self, stream, alternate_script=None):
        """Write this input in wire format; *alternate_script* substitutes the
        scriptSig (used when computing signature hashes)."""
        stream.write(self.txo_ref.tx_ref.hash)
        stream.write_uint32(self.txo_ref.position)
        if alternate_script is not None:
            stream.write_string(alternate_script)
        else:
            if self.is_coinbase:
                stream.write_string(self.coinbase)
            else:
                stream.write_string(self.script.source)
        stream.write_uint32(self.sequence)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseOutputEffectiveAmountEstimator:
    """Pairs an output with the cost of spending it, so coin selection can
    sort candidates by what they are actually worth after fees."""

    __slots__ = 'txo', 'txi', 'fee', 'effective_amount'

    def __init__(self, ledger: 'baseledger.BaseLedger', txo: 'BaseOutput') -> None:
        self.txo = txo
        # hypothetical input that would spend this output
        self.txi = ledger.transaction_class.input_class.spend(txo)
        self.fee: int = self.txi.get_fee(ledger)
        # value remaining after paying to spend it
        self.effective_amount: int = txo.amount - self.fee

    def __lt__(self, other):
        # NOTE: only __lt__ is defined — enough for sorting, but other
        # comparison operators fall back to identity semantics.
        return self.effective_amount < other.effective_amount
|
|
||||||
|
|
||||||
|
|
||||||
class BaseOutput(InputOutput):
    """A transaction output: an amount locked by a script, plus wallet
    annotations (is_change / is_my_account) resolved after chain lookups."""

    script_class = BaseOutputScript
    estimator_class = BaseOutputEffectiveAmountEstimator

    __slots__ = 'amount', 'script', 'is_change', 'is_my_account'

    def __init__(self, amount: int, script: BaseOutputScript,
                 tx_ref: TXRef = None, position: int = None,
                 is_change: Optional[bool] = None, is_my_account: Optional[bool] = None
                 ) -> None:
        super().__init__(tx_ref, position)
        self.amount = amount
        self.script = script
        # None means "not yet determined" as opposed to known False
        self.is_change = is_change
        self.is_my_account = is_my_account

    def update_annotations(self, annotated):
        """Copy wallet annotations from a previously-annotated copy of this
        output, or mark it as not-ours when there is none."""
        if annotated is None:
            self.is_change = False
            self.is_my_account = False
        else:
            self.is_change = annotated.is_change
            self.is_my_account = annotated.is_my_account

    @property
    def ref(self):
        # Fresh resolvable reference (txid:vout) back to this output.
        return TXORefResolvable(self)

    @property
    def id(self):
        return self.ref.id

    @property
    def pubkey_hash(self):
        # raises KeyError for scripts without a pubkey_hash value
        return self.script.values['pubkey_hash']

    @property
    def has_address(self):
        return 'pubkey_hash' in self.script.values

    def get_address(self, ledger):
        """Render this output's pubkey hash as an address on *ledger*."""
        return ledger.hash160_to_address(self.pubkey_hash)

    def get_estimator(self, ledger):
        """Wrap this output in a spend-cost estimator for coin selection."""
        return self.estimator_class(ledger, self)

    @classmethod
    def pay_pubkey_hash(cls, amount, pubkey_hash):
        """Construct a standard P2PKH output."""
        return cls(amount, cls.script_class.pay_pubkey_hash(pubkey_hash))

    @classmethod
    def deserialize_from(cls, stream):
        """Read one output in bitcoin wire format from *stream*."""
        return cls(
            amount=stream.read_uint64(),
            script=cls.script_class(stream.read_string())
        )

    def serialize_to(self, stream, alternate_script=None):
        # alternate_script is accepted for signature parity with inputs but unused
        stream.write_uint64(self.amount)
        stream.write_string(self.script.source)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTransaction:
    """A mutable transaction: version, locktime, inputs and outputs, plus
    chain-position metadata (height / position / is_verified)."""

    # Concrete input/output types used for (de)serialization;
    # ledger-specific subclasses override these.
    input_class = BaseInput
    output_class = BaseOutput
|
|
||||||
|
|
||||||
    def __init__(self, raw=None, version: int = 1, locktime: int = 0, is_verified: bool = False,
                 height: int = -2, position: int = -1) -> None:
        """Create a transaction, optionally parsing it from *raw* bytes.

        When *raw* is given, `_deserialize()` runs immediately and fills in
        version, locktime, inputs, outputs and segwit data from it.
        """
        self._raw = raw                        # cached serialized form (lazily rebuilt)
        self._raw_sans_segwit = None           # cached serialization without witness data
        self.is_segwit_flag = 0
        self.witnesses: List[bytes] = []
        # ref must exist before _deserialize(): _add() links txios back to it.
        self.ref = TXRefMutable(self)
        self.version = version
        self.locktime = locktime
        self._inputs: List[BaseInput] = []
        self._outputs: List[BaseOutput] = []
        self.is_verified = is_verified
        # Height Progression
        # -2: not broadcast
        # -1: in mempool but has unconfirmed inputs
        # 0: in mempool and all inputs confirmed
        # +num: confirmed in a specific block (height)
        self.height = height
        self.position = position               # index of tx within its block
        if raw is not None:
            self._deserialize()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_broadcast(self):
|
|
||||||
return self.height > -2
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_mempool(self):
|
|
||||||
return self.height in (-1, 0)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_confirmed(self):
|
|
||||||
return self.height > 0
|
|
||||||
|
|
||||||
    @property
    def id(self):
        # Transaction id, as exposed by the mutable ref.
        return self.ref.id
|
|
||||||
|
|
||||||
    @property
    def hash(self):
        # Binary transaction hash, as exposed by the mutable ref.
        return self.ref.hash
|
|
||||||
|
|
||||||
@property
|
|
||||||
def raw(self):
|
|
||||||
if self._raw is None:
|
|
||||||
self._raw = self._serialize()
|
|
||||||
return self._raw
|
|
||||||
|
|
||||||
@property
|
|
||||||
def raw_sans_segwit(self):
|
|
||||||
if self.is_segwit_flag:
|
|
||||||
if self._raw_sans_segwit is None:
|
|
||||||
self._raw_sans_segwit = self._serialize(sans_segwit=True)
|
|
||||||
return self._raw_sans_segwit
|
|
||||||
return self.raw
|
|
||||||
|
|
||||||
    def _reset(self):
        """Drop cached serialized forms and reset the ref after mutation,
        so raw/id/hash are recomputed on next access."""
        self._raw = None
        self._raw_sans_segwit = None
        self.ref.reset()
|
|
||||||
|
|
||||||
    @property
    def inputs(self) -> ReadOnlyList[BaseInput]:
        # Read-only view; mutate via add_inputs().
        return ReadOnlyList(self._inputs)
|
|
||||||
|
|
||||||
    @property
    def outputs(self) -> ReadOnlyList[BaseOutput]:
        # Read-only view; mutate via add_outputs().
        return ReadOnlyList(self._outputs)
|
|
||||||
|
|
||||||
def _add(self, existing_ios: List, new_ios: Iterable[InputOutput], reset=False) -> 'BaseTransaction':
|
|
||||||
for txio in new_ios:
|
|
||||||
txio.tx_ref = self.ref
|
|
||||||
txio.position = len(existing_ios)
|
|
||||||
existing_ios.append(txio)
|
|
||||||
if reset:
|
|
||||||
self._reset()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_inputs(self, inputs: Iterable[BaseInput]) -> 'BaseTransaction':
|
|
||||||
return self._add(self._inputs, inputs, True)
|
|
||||||
|
|
||||||
def add_outputs(self, outputs: Iterable[BaseOutput]) -> 'BaseTransaction':
|
|
||||||
return self._add(self._outputs, outputs, True)
|
|
||||||
|
|
||||||
    @property
    def size(self) -> int:
        """ Size in bytes of the entire transaction. """
        # Serializes the tx if no cached raw form exists yet.
        return len(self.raw)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def base_size(self) -> int:
|
|
||||||
""" Size of transaction without inputs or outputs in bytes. """
|
|
||||||
return (
|
|
||||||
self.size
|
|
||||||
- sum(txi.size for txi in self._inputs)
|
|
||||||
- sum(txo.size for txo in self._outputs)
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_sum(self):
|
|
||||||
return sum(i.amount for i in self.inputs if i.txo_ref.txo is not None)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_sum(self):
|
|
||||||
return sum(o.amount for o in self.outputs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def net_account_balance(self) -> int:
|
|
||||||
balance = 0
|
|
||||||
for txi in self.inputs:
|
|
||||||
if txi.txo_ref.txo is None:
|
|
||||||
continue
|
|
||||||
if txi.is_my_account is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot access net_account_balance if inputs/outputs do not "
|
|
||||||
"have is_my_account set properly."
|
|
||||||
)
|
|
||||||
if txi.is_my_account:
|
|
||||||
balance -= txi.amount
|
|
||||||
for txo in self.outputs:
|
|
||||||
if txo.is_my_account is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot access net_account_balance if inputs/outputs do not "
|
|
||||||
"have is_my_account set properly."
|
|
||||||
)
|
|
||||||
if txo.is_my_account:
|
|
||||||
balance += txo.amount
|
|
||||||
return balance
|
|
||||||
|
|
||||||
    @property
    def fee(self) -> int:
        """Implicit miner fee: total input value minus total output value."""
        return self.input_sum - self.output_sum
|
|
||||||
|
|
||||||
    def get_base_fee(self, ledger) -> int:
        """ Fee for base tx excluding inputs and outputs. """
        # base_size * the ledger's per-byte fee rate
        return self.base_size * ledger.fee_per_byte
|
|
||||||
|
|
||||||
def get_effective_input_sum(self, ledger) -> int:
|
|
||||||
""" Sum of input values *minus* the cost involved to spend them. """
|
|
||||||
return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs)
|
|
||||||
|
|
||||||
def get_total_output_sum(self, ledger) -> int:
|
|
||||||
""" Sum of output values *plus* the cost involved to spend them. """
|
|
||||||
return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs)
|
|
||||||
|
|
||||||
    def _serialize(self, with_inputs: bool = True, sans_segwit: bool = False) -> bytes:
        """Serialize to wire format: version, inputs, outputs, locktime.

        ``with_inputs=False`` omits the input list entirely.
        NOTE(review): ``sans_segwit`` is accepted but never consulted here —
        this body writes no witness data in either case; confirm intended.
        """
        stream = BCDataStream()
        stream.write_uint32(self.version)
        if with_inputs:
            stream.write_compact_size(len(self._inputs))
            for txin in self._inputs:
                txin.serialize_to(stream)
        stream.write_compact_size(len(self._outputs))
        for txout in self._outputs:
            txout.serialize_to(stream)
        stream.write_uint32(self.locktime)
        return stream.get_bytes()
|
|
||||||
|
|
||||||
    def _serialize_for_signature(self, signing_input: int) -> bytes:
        """Serialize for signing input number *signing_input*: the input
        being signed carries its previous output's locking script, all
        other inputs carry an empty script, and the SIGHASH type is
        appended as a trailing uint32."""
        stream = BCDataStream()
        stream.write_uint32(self.version)
        stream.write_compact_size(len(self._inputs))
        for i, txin in enumerate(self._inputs):
            if signing_input == i:
                # substitute the utxo's locking script for the unsigned input script
                assert txin.txo_ref.txo is not None
                txin.serialize_to(stream, txin.txo_ref.txo.script.source)
            else:
                txin.serialize_to(stream, b'')
        stream.write_compact_size(len(self._outputs))
        for txout in self._outputs:
            txout.serialize_to(stream)
        stream.write_uint32(self.locktime)
        stream.write_uint32(self.signature_hash_type(1))  # signature hash type: SIGHASH_ALL
        return stream.get_bytes()
|
|
||||||
|
|
||||||
    def _deserialize(self):
        """Parse ``self._raw`` into version, inputs, outputs, optional
        witness data and locktime. A zero input count is treated as the
        segwit marker: a flag byte and the real input count follow."""
        if self._raw is not None:
            stream = BCDataStream(self._raw)
            self.version = stream.read_uint32()
            input_count = stream.read_compact_size()
            if input_count == 0:
                # segwit marker: read flag byte, then the actual input count
                self.is_segwit_flag = stream.read_uint8()
                input_count = stream.read_compact_size()
            self._add(self._inputs, [
                self.input_class.deserialize_from(stream) for _ in range(input_count)
            ])
            output_count = stream.read_compact_size()
            self._add(self._outputs, [
                self.output_class.deserialize_from(stream) for _ in range(output_count)
            ])
            if self.is_segwit_flag:
                # drain witness portion of transaction: per input, a counted
                # list of counted byte strings
                # too many witnesses for no crime
                self.witnesses = []
                for _ in range(input_count):
                    for _ in range(stream.read_compact_size()):
                        self.witnesses.append(stream.read(stream.read_compact_size()))
            self.locktime = stream.read_uint32()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def ensure_all_have_same_ledger_and_wallet(
|
|
||||||
cls, funding_accounts: Iterable[BaseAccount],
|
|
||||||
change_account: BaseAccount = None) -> Tuple['baseledger.BaseLedger', 'basewallet.Wallet']:
|
|
||||||
ledger = wallet = None
|
|
||||||
for account in funding_accounts:
|
|
||||||
if ledger is None:
|
|
||||||
ledger = account.ledger
|
|
||||||
wallet = account.wallet
|
|
||||||
if ledger != account.ledger:
|
|
||||||
raise ValueError(
|
|
||||||
'All funding accounts used to create a transaction must be on the same ledger.'
|
|
||||||
)
|
|
||||||
if wallet != account.wallet:
|
|
||||||
raise ValueError(
|
|
||||||
'All funding accounts used to create a transaction must be from the same wallet.'
|
|
||||||
)
|
|
||||||
if change_account is not None:
|
|
||||||
if change_account.ledger != ledger:
|
|
||||||
raise ValueError('Change account must use same ledger as funding accounts.')
|
|
||||||
if change_account.wallet != wallet:
|
|
||||||
raise ValueError('Change account must use same wallet as funding accounts.')
|
|
||||||
if ledger is None:
|
|
||||||
raise ValueError('No ledger found.')
|
|
||||||
if wallet is None:
|
|
||||||
raise ValueError('No wallet found.')
|
|
||||||
return ledger, wallet
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
|
|
||||||
funding_accounts: Iterable[BaseAccount], change_account: BaseAccount,
|
|
||||||
sign: bool = True):
|
|
||||||
""" Find optimal set of inputs when only outputs are provided; add change
|
|
||||||
outputs if only inputs are provided or if inputs are greater than outputs. """
|
|
||||||
|
|
||||||
tx = cls() \
|
|
||||||
.add_inputs(inputs) \
|
|
||||||
.add_outputs(outputs)
|
|
||||||
|
|
||||||
ledger, _ = cls.ensure_all_have_same_ledger_and_wallet(funding_accounts, change_account)
|
|
||||||
|
|
||||||
# value of the outputs plus associated fees
|
|
||||||
cost = (
|
|
||||||
tx.get_base_fee(ledger) +
|
|
||||||
tx.get_total_output_sum(ledger)
|
|
||||||
)
|
|
||||||
# value of the inputs less the cost to spend those inputs
|
|
||||||
payment = tx.get_effective_input_sum(ledger)
|
|
||||||
|
|
||||||
try:
|
|
||||||
|
|
||||||
for _ in range(5):
|
|
||||||
|
|
||||||
if payment < cost:
|
|
||||||
deficit = cost - payment
|
|
||||||
spendables = await ledger.get_spendable_utxos(deficit, funding_accounts)
|
|
||||||
if not spendables:
|
|
||||||
raise InsufficientFundsError()
|
|
||||||
payment += sum(s.effective_amount for s in spendables)
|
|
||||||
tx.add_inputs(s.txi for s in spendables)
|
|
||||||
|
|
||||||
cost_of_change = (
|
|
||||||
tx.get_base_fee(ledger) +
|
|
||||||
cls.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(ledger)
|
|
||||||
)
|
|
||||||
if payment > cost:
|
|
||||||
change = payment - cost
|
|
||||||
if change > cost_of_change:
|
|
||||||
change_address = await change_account.change.get_or_create_usable_address()
|
|
||||||
change_hash160 = change_account.ledger.address_to_hash160(change_address)
|
|
||||||
change_amount = change - cost_of_change
|
|
||||||
change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160)
|
|
||||||
change_output.is_change = True
|
|
||||||
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
|
|
||||||
|
|
||||||
if tx._outputs:
|
|
||||||
break
|
|
||||||
# this condition and the outer range(5) loop cover an edge case
|
|
||||||
# whereby a single input is just enough to cover the fee and
|
|
||||||
# has some change left over, but the change left over is less
|
|
||||||
# than the cost_of_change: thus the input is completely
|
|
||||||
# consumed and no output is added, which is an invalid tx.
|
|
||||||
# to be able to spend this input we must increase the cost
|
|
||||||
# of the TX and run through the balance algorithm a second time
|
|
||||||
# adding an extra input and change output, making tx valid.
|
|
||||||
# we do this 5 times in case the other UTXOs added are also
|
|
||||||
# less than the fee, after 5 attempts we give up and go home
|
|
||||||
cost += cost_of_change + 1
|
|
||||||
|
|
||||||
if sign:
|
|
||||||
await tx.sign(funding_accounts)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log.exception('Failed to create transaction:')
|
|
||||||
await ledger.release_tx(tx)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return tx
|
|
||||||
|
|
||||||
    @staticmethod
    def signature_hash_type(hash_type):
        """Identity hook: subclasses may transform the sighash type value."""
        return hash_type
|
|
||||||
|
|
||||||
    async def sign(self, funding_accounts: Iterable[BaseAccount]):
        """Sign every input in place using private keys looked up through
        *funding_accounts*' shared ledger and wallet.

        Only pay-to-pubkey-hash source outputs are supported; anything else
        raises NotImplementedError. Cached serialization is reset at the end
        so raw/id/hash reflect the signed scripts.
        """
        ledger, wallet = self.ensure_all_have_same_ledger_and_wallet(funding_accounts)
        for i, txi in enumerate(self._inputs):
            assert txi.script is not None
            assert txi.txo_ref.txo is not None
            txo_script = txi.txo_ref.txo.script
            if txo_script.is_pay_pubkey_hash:
                address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
                private_key = await ledger.get_private_key_for_address(wallet, address)
                assert private_key is not None, 'Cannot find private key for signing output.'
                # sign the tx serialized specifically for this input, then
                # append the SIGHASH_ALL type byte to the signature
                tx = self._serialize_for_signature(i)
                txi.script.values['signature'] = \
                    private_key.sign(tx) + bytes((self.signature_hash_type(1),))
                txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
                txi.script.generate()
            else:
                raise NotImplementedError("Don't know how to spend this output.")
        self._reset()
|
|
|
@ -1,6 +0,0 @@
|
||||||
# Shared wallet/transaction constants.

NULL_HASH32 = b'\x00'*32  # 32-byte all-zero hash, used as a placeholder

CENT = 1000000  # one hundredth of a coin, in base units
COIN = 100*CENT  # one whole coin, in base units

TIMEOUT = 30.0  # default timeout (presumably seconds — confirm at call sites)
|
|
Loading…
Reference in a new issue