lbry-sdk/torba/baseledger.py

312 lines
12 KiB
Python
Raw Normal View History

2018-05-25 08:03:25 +02:00
import os
2018-06-11 15:33:32 +02:00
import six
2018-05-25 08:03:25 +02:00
import hashlib
import logging
2018-05-25 08:03:25 +02:00
from binascii import hexlify, unhexlify
from typing import Dict, Type, Iterable, Generator
2018-05-25 08:03:25 +02:00
from operator import itemgetter
from collections import namedtuple
2018-05-25 08:03:25 +02:00
2018-06-11 15:33:32 +02:00
from twisted.internet import defer
2018-05-25 08:03:25 +02:00
2018-06-11 15:33:32 +02:00
from torba import baseaccount
from torba import basedatabase
from torba import baseheader
from torba import basenetwork
from torba import basetransaction
2018-05-25 08:03:25 +02:00
from torba.stream import StreamController, execute_serially
2018-06-11 15:33:32 +02:00
from torba.hash import hash160, double_sha256, Base58
2018-05-25 08:03:25 +02:00
log = logging.getLogger(__name__)
2018-05-25 08:03:25 +02:00
2018-06-11 15:33:32 +02:00
class LedgerRegistry(type):
    """Metaclass that records every ledger class it creates, keyed by ledger id.

    The id is obtained by calling ``get_id()`` on the freshly created class.
    A class named ``BaseLedger`` that has no bases is treated as the abstract
    root and is not registered.
    """

    # Registry shared by all classes built through this metaclass.
    ledgers = {}  # type: Dict[str, Type[BaseLedger]]

    def __new__(mcs, name, bases, attrs):
        """Create the class and register it under its ledger id.

        Raises:
            AssertionError: if another class is already registered under the
                same id.
        """
        cls = super(LedgerRegistry, mcs).__new__(mcs, name, bases, attrs)  # type: Type[BaseLedger]
        is_abstract_root = name == 'BaseLedger' and not bases
        if not is_abstract_root:
            ledger_id = cls.get_id()
            # Explicit check instead of an ``assert`` statement: asserts are
            # removed under ``python -O``, which would let a duplicate id
            # silently replace the registered class.  The exception type is
            # kept as AssertionError so existing callers see no difference.
            if ledger_id in mcs.ledgers:
                raise AssertionError(
                    'Ledger with id "{}" already registered.'.format(ledger_id)
                )
            mcs.ledgers[ledger_id] = cls
        return cls

    @classmethod
    def get_ledger_class(mcs, ledger_id):  # type: (str) -> Type[BaseLedger]
        """Return the class registered under ``ledger_id`` (KeyError if unknown)."""
        return mcs.ledgers[ledger_id]
2018-05-25 08:03:25 +02:00
class TransactionEvent(namedtuple('TransactionEvent', 'address tx height is_verified')):
    """Immutable record with the fields ``address``, ``tx``, ``height`` and ``is_verified``."""
2018-06-11 15:33:32 +02:00
class BaseLedger(six.with_metaclass(LedgerRegistry)):
    """Base class for ledgers, created through the ``LedgerRegistry`` metaclass.

    The attributes below are class-level defaults.  ``get_id()`` calls
    ``.lower()`` on ``symbol`` and ``network_name``, so both need to be
    strings on any class whose id is computed.
    """

    # -- identification ------------------------------------------------------
    name = None
    symbol = None
    network_name = None

    # -- collaborator classes ------------------------------------------------
    account_class = baseaccount.BaseAccount
    database_class = basedatabase.BaseDatabase
    headers_class = baseheader.BaseHeaders
    network_class = basenetwork.BaseNetwork
    transaction_class = basetransaction.BaseTransaction

    # -- key / address prefixes ----------------------------------------------
    secret_prefix = None
    pubkey_address_prefix = None
    script_address_prefix = None
    extended_public_key_prefix = None
    extended_private_key_prefix = None

    # Fee rate used when the config carries no 'fee_per_byte' entry.
    default_fee_per_byte = 10
2018-05-25 08:03:25 +02:00
2018-06-27 00:31:42 +02:00
def __init__(self, config=None, db=None, network=None, headers_class=None):
    """Wire up storage, networking, headers and the event streams.

    Args:
        config: settings mapping; a falsy value is replaced by an empty dict.
            ``'fee_per_byte'`` is read from it, and ``'wallet_path'`` is
            needed (through ``self.path``) whenever ``db`` is not supplied.
        db: database instance; when falsy, ``database_class`` is instantiated
            on ``blockchain.db`` inside ``self.path``.
        network: network instance; when falsy, ``network_class(self)`` is used.
        headers_class: optional replacement for the ``headers_class``
            class attribute.
    """
    self.config = config or {}

    if not db:
        db = self.database_class(os.path.join(self.path, "blockchain.db"))
    self.db = db  # type: basedatabase.BaseDatabase

    if not network:
        network = self.network_class(self)
    self.network = network
    self.network.on_header.listen(self.process_header)
    self.network.on_status.listen(self.process_status)

    self.accounts = set()

    headers_factory = headers_class or self.headers_class
    self.headers = headers_factory(self)
    self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte)

    self._on_transaction_controller = StreamController()
    self.on_transaction = self._on_transaction_controller.stream

    def log_transaction(event):
        # Every transaction event is written to the module logger.
        log.info('({}) on_transaction: address={}, height={}, is_verified={}, tx.id={}'.format(
            self.get_id(), event.address, event.height, event.is_verified, event.tx.hex_id
        ))

    self.on_transaction.listen(log_transaction)

    self._on_header_controller = StreamController()
    self.on_header = self._on_header_controller.stream

    # Per-transaction locks, keyed by the transaction id, filled in by update_history().
    self._transaction_processing_locks = {}
2018-06-11 15:33:32 +02:00
@classmethod
def get_id(cls):
    """Return the ledger id: lower-cased ``symbol`` and ``network_name`` joined by ``_``."""
    symbol = cls.symbol.lower()
    network = cls.network_name.lower()
    return '{}_{}'.format(symbol, network)
def hash160_to_address(self, h160):
    """Base58-encode ``pubkey_address_prefix + h160`` followed by a 4-byte checksum.

    The checksum is the first four bytes of ``double_sha256`` over the
    prefixed payload.
    """
    payload = self.pubkey_address_prefix + h160
    checksum = double_sha256(payload)[0:4]
    return Base58.encode(bytearray(payload + checksum))
@staticmethod
def address_to_hash160(address):
    """Return bytes 1 through 20 of the Base58-decoded ``address``.

    The decoded value is laid out as one leading prefix byte, the hash
    (returned here) and the remaining trailing bytes.  Neither the prefix
    nor the trailing bytes are inspected or validated by this method.
    """
    # Named ``decoded`` rather than ``bytes`` so the builtin is not shadowed;
    # the prefix and checksum slices were never used and are no longer taken.
    decoded = Base58.decode(address)
    return decoded[1:21]
def public_key_to_address(self, public_key):
    """Return the address for ``public_key``: its ``hash160`` run through ``hash160_to_address``."""
    key_hash = hash160(public_key)
    return self.hash160_to_address(key_hash)
@staticmethod
def private_key_to_wif(private_key):
    """Return ``private_key`` wrapped in a leading ``0x1c`` and a trailing ``0x01`` byte.

    NOTE(review): only the byte concatenation happens here; no Base58 or
    checksum step is applied by this method.
    """
    leading, trailing = b'\x1c', b'\x01'
    return leading + private_key + trailing
2018-06-08 05:47:46 +02:00
@property
def path(self):
    """Directory of this ledger: ``config['wallet_path']`` joined with the ledger id."""
    wallet_root = self.config['wallet_path']
    return os.path.join(wallet_root, self.get_id())
2018-06-08 05:47:46 +02:00
def get_input_output_fee(self, io):
    """Return the fee for one input or output: ``fee_per_byte`` times ``io.size``."""
    rate = self.fee_per_byte
    return rate * io.size
def get_transaction_base_fee(self, tx):
    """Return the fee for the transaction header and all outputs (inputs excluded).

    Computed as ``fee_per_byte`` times ``tx.base_size``.
    """
    rate = self.fee_per_byte
    return rate * tx.base_size
2018-05-25 08:03:25 +02:00
2018-06-14 02:57:57 +02:00
@defer.inlineCallbacks
def add_account(self, account):  # type: (baseaccount.BaseAccount) -> defer.Deferred
    """Register ``account`` with this ledger.

    When the network is already connected the account is also brought up
    to date right away via ``update_account``.
    """
    self.accounts.add(account)
    if not self.network.is_connected:
        return
    yield self.update_account(account)
2018-06-11 15:33:32 +02:00
@defer.inlineCallbacks
def get_private_key_for_address(self, address):
    """Look up the private key that belongs to ``address``.

    The address record from the database names an account, a chain and a
    position; the key is requested from the first registered account whose
    ``public_key.address`` equals the record's account bytes.  The Deferred
    fires with ``None`` when there is no record or no matching account.
    """
    record = yield self.db.get_address(address)
    if not record:
        return
    for candidate in self.accounts:
        if bytes(record['account']) == candidate.public_key.address:
            defer.returnValue(candidate.get_private_key(record['chain'], record['position']))
2018-06-08 05:47:46 +02:00
def get_unspent_outputs(self, account):
    """Return the database's ``get_utxos`` result for ``account``.

    The output class of ``transaction_class`` is passed along as the
    second argument.
    """
    output_class = self.transaction_class.output_class
    return self.db.get_utxos(account, output_class)
@defer.inlineCallbacks
def get_effective_amount_estimators(self, funding_accounts):
    # type: (Iterable[baseaccount.BaseAccount]) -> defer.Deferred
    """Collect one estimator per unspent output across ``funding_accounts``.

    Accounts are processed in iteration order; the Deferred fires with a
    list of ``utxo.get_estimator(self)`` results.
    """
    collected = []
    for funding_account in funding_accounts:
        unspent = yield self.get_unspent_outputs(funding_account)
        collected.extend(utxo.get_estimator(self) for utxo in unspent)
    defer.returnValue(collected)
2018-06-12 16:02:04 +02:00
@defer.inlineCallbacks
def get_local_status(self, address):
    """Fire with the hex-encoded SHA-256 of the locally stored history string.

    A falsy stored history is hashed as the empty string.
    """
    details = yield self.db.get_address(address)
    serialized = details['history'] or ''
    digest = hashlib.sha256(serialized.encode()).digest()
    defer.returnValue(hexlify(digest))
@defer.inlineCallbacks
def get_local_history(self, address):
    """Fire with the stored history of ``address`` as ``(tx id, height)`` pairs.

    The stored string is split on ``:`` and the last piece is dropped;
    even-indexed pieces are ids and odd-indexed pieces are heights,
    converted to ``int``.
    """
    details = yield self.db.get_address(address)
    serialized = details['history'] or ''
    fields = serialized.split(':')[:-1]
    tx_ids = fields[0::2]
    heights = [int(value) for value in fields[1::2]]
    defer.returnValue(list(zip(tx_ids, heights)))
2018-06-25 15:54:35 +02:00
@staticmethod
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
    """Fold a merkle branch into its root and return the root hex-encoded.

    Args:
        branches: hex-encoded sibling hashes; each is un-hexlified and
            byte-reversed before use.
        branch_positions: integer bit field; bit ``i`` set means sibling
            ``i`` is concatenated on the left of the running hash.
        working_branch: raw hash to start from.

    Returns:
        ``hexlify`` of the byte-reversed final hash.
    """
    node = working_branch
    for depth, sibling_hex in enumerate(branches):
        sibling = unhexlify(sibling_hex)[::-1]
        if (branch_positions >> depth) & 1:
            pair = sibling + node
        else:
            pair = node + sibling
        node = double_sha256(pair)
    return hexlify(node[::-1])
@defer.inlineCallbacks
def is_valid_transaction(self, tx, height):
    """Check ``tx`` against the merkle root of the local header at ``height``.

    The merkle branch is requested from the network, folded together with
    ``tx.hash`` and compared with the header's ``merkle_root``.

    Returns:
        Deferred firing with ``True`` when the roots match, ``False`` when
        they differ or when no local header exists yet for ``height``.
    """
    # Heights are zero-based: ``len(self.headers)`` is the first height that
    # has NOT been stored yet (update_headers() requests headers starting
    # from it), so ``height == len(self.headers)`` must be rejected as well.
    if height >= len(self.headers):
        defer.returnValue(False)
    merkle = yield self.network.get_merkle(tx.hex_id.decode(), height)
    merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
    header = self.headers[height]
    defer.returnValue(merkle_root == header['merkle_root'])
2018-05-25 08:03:25 +02:00
@defer.inlineCallbacks
def start(self):
    """Bring the ledger online.

    Steps, in order: make sure the ledger directory exists, start the
    database, start the network and wait for its first connection, touch
    the headers store, download missing headers, subscribe to new headers
    and finally update all registered accounts.
    """
    import errno  # only needed for the directory creation just below

    # Try to create and tolerate "already exists" instead of checking first;
    # a separate exists() check can race with another creator of the path.
    try:
        os.mkdir(self.path)
    except OSError as error:
        if error.errno != errno.EEXIST:
            raise
    yield self.db.start()
    # Grab the "first connection" handle before starting the network so the
    # connection event cannot be missed.
    first_connection = self.network.on_connected.first
    self.network.start()
    yield first_connection
    self.headers.touch()
    yield self.update_headers()
    yield self.network.subscribe_headers()
    yield self.update_accounts()
def stop(self):
    """Stop the network and return whatever ``network.stop()`` returns."""
    network = self.network
    return network.stop()
@execute_serially
@defer.inlineCallbacks
def update_headers(self):
    """Download headers from the network until it reports none left.

    Each round asks for headers starting at the current local header count,
    connects the returned chunk at that height and publishes that height on
    the header stream.  The loop ends when a response has ``count <= 0``.
    """
    while True:
        next_height = len(self.headers)
        batch = yield self.network.get_headers(next_height)
        if batch['count'] <= 0:
            break
        yield self.headers.connect(next_height, unhexlify(batch['hex']))
        self._on_header_controller.add(next_height)
2018-05-25 08:03:25 +02:00
@defer.inlineCallbacks
def process_header(self, response):
    """Handle a header announcement; ``response[0]`` is the announced header.

    Announcements are ignored while ``update_headers`` is running.
    """
    announced = response[0]
    if self.update_headers.is_running:
        return
    local_count = len(self.headers)
    if announced['height'] == local_count:
        # The announced header directly follows the last local one: attach it.
        yield self.headers.connect(local_count, unhexlify(announced['hex']))
        # Re-read the length here: it is published after the connect above.
        self._on_header_controller.add(len(self.headers))
    elif announced['height'] > local_count:
        # The announcement is further ahead than one header: download the gap.
        yield self.update_headers()
@execute_serially
def update_accounts(self):
    """Start ``update_account`` for every registered account.

    Returns a ``DeferredList`` over the individual updates.
    """
    pending = [self.update_account(account) for account in self.accounts]
    return defer.DeferredList(pending)
@defer.inlineCallbacks
def update_account(self, account):  # type: (baseaccount.BaseAccount) -> defer.Deferred
    """Restore the history of ``account`` and subscribe all its addresses."""
    # History is downloaded before subscribing so that we do not receive
    # status updates for addresses we already know have to be updated.
    # Keep fetching history and generating addresses until every missing
    # address exists and its history is restored.
    yield account.ensure_address_gap()
    pending_addresses = yield account.get_unused_addresses()
    while pending_addresses:
        history_updates = [self.update_history(a) for a in pending_addresses]
        yield defer.DeferredList(history_updates)
        pending_addresses = yield account.ensure_address_gap()

    # All addresses should be restored by now; subscribe every one of them
    # to receive updates.
    known_addresses = yield account.get_addresses()
    subscriptions = [self.subscribe_history(a) for a in known_addresses]
    yield defer.DeferredList(subscriptions)
2018-06-08 05:47:46 +02:00
2018-05-25 08:03:25 +02:00
@defer.inlineCallbacks
def update_history(self, address):
    """Synchronise the locally stored history of ``address`` with the server.

    Remote entries are walked in order.  An entry that already matches the
    local history at the same index is skipped; any other entry has its
    transaction loaded from the database (or fetched from the network),
    verified when it has a positive height and is not yet verified, then
    saved together with the history string built so far
    (``'<tx id>:<height>:'`` per entry).  A ``TransactionEvent`` is
    published for every processed entry.
    """
    remote_history = yield self.network.get_history(address)
    local_history = yield self.get_local_history(address)

    synced_history = []
    for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):

        synced_history.append((hex_id, remote_height))

        # Same transaction at the same position and height locally: nothing to do.
        if i < len(local_history) and local_history[i] == (hex_id.decode(), remote_height):
            continue

        # One lock per transaction id serialises concurrent processing of
        # the same transaction.
        lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock())

        yield lock.acquire()

        try:

            # see if we have a local copy of transaction, otherwise fetch it from server
            raw, _local_height, is_verified = yield self.db.get_transaction(unhexlify(hex_id)[::-1])
            save_tx = None
            if raw is None:
                _raw = yield self.network.get_transaction(hex_id)
                tx = self.transaction_class(unhexlify(_raw))
                save_tx = 'insert'
            else:
                tx = self.transaction_class(raw)

            if remote_height > 0 and not is_verified:
                is_verified = yield self.is_valid_transaction(tx, remote_height)
                is_verified = 1 if is_verified else 0
                # NOTE(review): this block was reconstructed from a listing
                # without indentation; 'update' is assumed to apply only
                # when verification was attempted -- confirm.
                if save_tx is None:
                    save_tx = 'update'

            yield self.db.save_transaction_io(
                save_tx, tx, remote_height, is_verified, address, self.address_to_hash160(address),
                ''.join('{}:{}:'.format(tx_id.decode(), tx_height) for tx_id, tx_height in synced_history)
            )

            self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified))

        finally:
            lock.release()
            if not lock.locked:
                # release() may run a waiting task synchronously, and that
                # task may already have removed this entry; pop() with a
                # default avoids a KeyError escaping from this cleanup.
                self._transaction_processing_locks.pop(hex_id, None)
2018-05-25 08:03:25 +02:00
@defer.inlineCallbacks
def subscribe_history(self, address):
    """Subscribe to ``address`` and refresh its history if the statuses differ.

    The status returned by the subscription is compared with the locally
    computed one; ``update_history`` runs only when they are not equal.
    """
    server_status = yield self.network.subscribe_address(address)
    stored_status = yield self.get_local_status(address)
    out_of_sync = stored_status != server_status
    if out_of_sync:
        yield self.update_history(address)
2018-05-25 08:03:25 +02:00
2018-06-08 05:47:46 +02:00
@defer.inlineCallbacks
def process_status(self, response):
    """Handle a status notification given as an ``(address, remote_status)`` pair.

    The history of the address is refreshed when the locally computed
    status is not equal to the announced one.
    """
    address, server_status = response
    stored_status = yield self.get_local_status(address)
    out_of_sync = stored_status != server_status
    if out_of_sync:
        yield self.update_history(address)
2018-05-25 08:03:25 +02:00
def broadcast(self, tx):
    """Send ``tx`` to the network as the hex encoding of ``tx.raw``.

    Returns whatever ``network.broadcast`` returns.
    """
    payload = hexlify(tx.raw)
    return self.network.broadcast(payload)