import os
import hashlib
from binascii import hexlify, unhexlify
from typing import List, Dict, Type
from operator import itemgetter

from twisted.internet import threads, defer, task, reactor

from torba.account import Account, AccountsView
from torba.basecoin import BaseCoin
from torba.basetransaction import BaseTransaction
from torba.basenetwork import BaseNetwork
from torba.stream import StreamController, execute_serially
from torba.util import hex_to_int, int_to_hex, rev_hex, hash_encode
from torba.hash import double_sha256, pow_hash


class Address:
    """A public-key hash together with the transactions recorded for it.

    Iterating an ``Address`` yields its transactions and ``len()`` returns
    how many are recorded.
    """

    def __init__(self, pubkey_hash):
        self.pubkey_hash = pubkey_hash
        self.transactions = []  # type: List[BaseTransaction]

    def __iter__(self):
        return iter(self.transactions)

    def __len__(self):
        return len(self.transactions)

    def add_transaction(self, transaction):
        """Record ``transaction`` for this address, ignoring repeats.

        ``BaseLedger.update_history`` re-adds every transaction in an
        address's history each time it runs; without this guard the same
        transaction would be stored repeatedly, inflating ``len(self)`` and
        changing the digest computed by ``BaseLedger.get_status``.
        """
        if transaction not in self.transactions:
            self.transactions.append(transaction)

    def get_unspent_utxos(self):
        """Yield pay-to-pubkey-hash outputs of this address that are unspent.

        An output is considered spent when any input of any transaction
        recorded for this address references its ``(transaction hash, index)``.
        """
        inputs, outputs = [], []
        for tx in self:
            for txi in tx.inputs:
                inputs.append((txi.output_txid, txi.output_index))
            for txo in tx.outputs:
                if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == self.pubkey_hash:
                    outputs.append((txo, txo.transaction.hash, txo.index))
        # set() drops outputs collected more than once.
        for output in set(outputs):
            if output[1:] not in inputs:
                yield output[0]


class BaseLedger:
    """Tracks addresses, transactions and block headers for one coin class."""

    # coin_class is automatically set by BaseCoin metaclass
    # when it creates the Coin classes, there is a 1..1 relationship
    # between a coin and a ledger (at the class level) but a 1..* relationship
    # at instance level. Only one Ledger instance should exist per coin class,
    # but many coin instances can exist linking back to the single Ledger instance.
    coin_class = None  # type: Type[BaseCoin]
    network_class = None  # type: Type[BaseNetwork]

    # When True, header verification also range-checks the compact "bits" field.
    verify_bits_to_target = True

    def __init__(self, accounts, config=None, network=None, db=None):
        self.accounts = accounts  # type: AccountsView
        self.config = config or {}
        self.db = db
        self.addresses = {}  # type: Dict[str, Address]
        self.transactions = {}  # type: Dict[str, BaseTransaction]
        self.headers = Headers(self)
        self._on_transaction_controller = StreamController()
        self.on_transaction = self._on_transaction_controller.stream
        # Fall back to the class-level network implementation when none is injected.
        self.network = network or self.network_class(self.config)
        self.network.on_header.listen(self.process_header)
        self.network.on_status.listen(self.process_status)

    @property
    def transaction_class(self):
        """The transaction class declared by ``coin_class``."""
        return self.coin_class.transaction_class

    @classmethod
    def from_json(cls, json_dict):
        """Build a ledger, passing ``json_dict`` as the first constructor argument."""
        return cls(json_dict)

    @defer.inlineCallbacks
    def load(self):
        """Populate ``transactions`` and ``addresses`` from ``self.db``."""
        txs = yield self.db.get_transactions()
        for tx_hash, raw, height in txs:
            self.transactions[tx_hash] = self.transaction_class(raw, height)
        txios = yield self.db.get_transaction_inputs_and_outputs()
        for tx_hash, address_hash, input_output, amount, height in txios:
            tx = self.transactions[tx_hash]
            address = self.addresses.get(address_hash)
            if address is None:
                address = self.addresses[address_hash] = Address(self.coin_class.address_to_hash160(address_hash))
            tx.add_txio(address, input_output, amount)
            address.add_transaction(tx)

    def is_address_old(self, address, age_limit=2):
        """Return True when the oldest transaction of ``address`` is older than ``age_limit``.

        Age is ``headers.height - tx.height + 1``; a transaction with height 0
        counts as age 0. An address without transactions is never old.
        """
        age = -1
        for tx in self.get_transactions(address, []):
            if tx.height == 0:
                tx_age = 0
            else:
                tx_age = self.headers.height - tx.height + 1
            if tx_age > age:
                age = tx_age
        return age > age_limit

    def add_transaction(self, address, transaction):  # type: (str, BaseTransaction) -> None
        """Associate ``transaction`` with ``address`` and publish it on ``on_transaction``."""
        if address not in self.addresses:
            self.addresses[address] = Address(self.coin_class.address_to_hash160(address))
        self.addresses[address].add_transaction(transaction)
        self.transactions.setdefault(hexlify(transaction.id), transaction)
        self._on_transaction_controller.add(transaction)

    def has_address(self, address):
        return address in self.addresses

    def get_transaction(self, tx_hash, *args):
        """``dict.get`` on the known transactions; ``args`` may carry a default."""
        return self.transactions.get(tx_hash, *args)

    def get_transactions(self, address, *args):
        """Return the ``Address`` entry for ``address`` (iterable over its
        transactions), or the optional default in ``args`` when unknown."""
        return self.addresses.get(address, *args)

    def get_status(self, address):
        """Return a hex SHA-256 digest summarising the history of ``address``.

        The digest covers one ``'<hash>:<height>:'`` entry per transaction
        whose height is not None. Returns None when there are no such entries.
        """
        hashes = [
            '{}:{}:'.format(hexlify(tx.hash), tx.height).encode()
            for tx in self.get_transactions(address, []) if tx.height is not None
        ]
        if hashes:
            return hexlify(hashlib.sha256(b''.join(hashes)).digest())

    def has_transaction(self, tx_hash):
        return tx_hash in self.transactions

    def get_least_used_address(self, addresses, max_transactions=100):
        """Return the address from ``addresses`` with the fewest transactions.

        The first address without any transactions is returned immediately.
        Addresses with ``max_transactions`` or more transactions are skipped.
        Returns None when no address qualifies.
        """
        transaction_counts = []
        for address in addresses:
            tx_count = len(self.get_transactions(address, []))
            if tx_count == 0:
                return address
            if tx_count >= max_transactions:
                continue
            transaction_counts.append((address, tx_count))
        if transaction_counts:
            # Return only the address, consistent with the zero-transaction
            # path above. min() keeps the first of equally-used addresses,
            # exactly as a stable sort would.
            return min(transaction_counts, key=itemgetter(1))[0]

    def get_unspent_outputs(self, address):
        """Return a list of unspent outputs for ``address`` (empty if unknown)."""
        if address in self.addresses:
            return list(self.addresses[address].get_unspent_utxos())
        return []

    @defer.inlineCallbacks
    def start(self):
        """Connect to the network, then sync headers and accounts."""
        # Grab the deferred before starting so the connection event is not missed.
        first_connection = self.network.on_connected.first
        self.network.start()
        yield first_connection
        self.headers.touch()
        yield self.update_headers()
        yield self.network.subscribe_headers()
        yield self.update_accounts()

    def stop(self):
        return self.network.stop()

    @execute_serially
    @defer.inlineCallbacks
    def update_headers(self):
        """Download and connect headers until the network returns none."""
        while True:
            height_sought = len(self.headers)
            headers = yield self.network.get_headers(height_sought)
            print("received {} headers starting at {} height".format(headers['count'], height_sought))
            #log.info("received {} headers starting at {} height".format(headers['count'], height_sought))
            if headers['count'] <= 0:
                break
            yield self.headers.connect(height_sought, unhexlify(headers['hex']))

    @defer.inlineCallbacks
    def process_header(self, response):
        """Handle a header notification from the network."""
        header = response[0]
        # A bulk download is already in progress; it will pick this header up.
        if self.update_headers.is_running:
            return
        if header['height'] == len(self.headers):
            # New header from network directly connects after the last local header.
            yield self.headers.connect(len(self.headers), unhexlify(header['hex']))
        elif header['height'] > len(self.headers):
            # New header is several heights ahead of local, do download instead.
            yield self.update_headers()

    @execute_serially
    def update_accounts(self):
        """Update every account; returns a ``DeferredList`` of the updates."""
        return defer.DeferredList([
            self.update_account(a) for a in self.accounts
        ])

    @defer.inlineCallbacks
    def update_account(self, account):  # type: (Account) -> defer.Deferred
        # Before subscribing, download history for any addresses that don't have any,
        # this avoids situation where we're getting status updates to addresses we know
        # need to update anyways. Continue to get history and create more addresses until
        # all missing addresses are created and history for them is fully restored.
        account.ensure_enough_addresses()
        addresses = list(account.addresses_without_history())
        while addresses:
            yield defer.DeferredList([
                self.update_history(a) for a in addresses
            ])
            addresses = account.ensure_enough_addresses()

        # By this point all of the addresses should be restored and we
        # can now subscribe all of them to receive updates.
        yield defer.DeferredList([
            self.subscribe_history(address) for address in account.addresses
        ])

    @defer.inlineCallbacks
    def update_history(self, address):
        """Fetch the history of ``address`` and record each of its transactions.

        Transactions not yet known to the ledger are downloaded first.
        """
        history = yield self.network.get_history(address)
        for tx_hash in map(itemgetter('tx_hash'), history):
            transaction = self.get_transaction(tx_hash)
            if not transaction:
                raw = yield self.network.get_transaction(tx_hash)
                transaction = self.transaction_class(unhexlify(raw))
            self.add_transaction(address, transaction)

    @defer.inlineCallbacks
    def subscribe_history(self, address):
        """Subscribe to ``address`` and refresh its history if the status differs."""
        status = yield self.network.subscribe_address(address)
        if status != self.get_status(address):
            yield self.update_history(address)

    def process_status(self, response):
        """Handle an address status notification; schedule a history refresh on mismatch."""
        address, status = response
        if status != self.get_status(address):
            task.deferLater(reactor, 0, self.update_history, address)

    def broadcast(self, tx):
        return self.network.broadcast(hexlify(tx.raw))


class Headers:
    """Header chain stored as fixed-size records in a flat file.

    The record size is ``ledger.header_size``; record ``n`` holds the header
    of height ``n``.
    """

    def __init__(self, ledger):
        self.ledger = ledger
        self._size = None  # cached number of stored headers, filled lazily by __len__
        self._on_change_controller = StreamController()
        self.on_changed = self._on_change_controller.stream

    @property
    def path(self):
        """Location of the headers file: ``<wallet_path>/<coin id>_headers``."""
        wallet_path = self.ledger.config.get('wallet_path', '')
        filename = '{}_headers'.format(self.ledger.coin_class.get_id())
        return os.path.join(wallet_path, filename)

    def touch(self):
        """Create an empty headers file if none exists yet."""
        if not os.path.exists(self.path):
            with open(self.path, 'wb'):
                pass

    @property
    def height(self):
        """Height of the last stored header (-1 when the file is empty)."""
        return len(self) - 1

    def sync_read_length(self):
        """Return the number of whole headers currently in the file (blocking)."""
        return os.path.getsize(self.path) // self.ledger.header_size

    def sync_read_header(self, height):
        """Return the raw header at ``height``, or None when out of range (blocking)."""
        if 0 <= height < len(self):
            with open(self.path, 'rb') as f:
                f.seek(height * self.ledger.header_size)
                return f.read(self.ledger.header_size)

    def __len__(self):
        if self._size is None:
            self._size = self.sync_read_length()
        return self._size

    def __getitem__(self, height):
        assert not isinstance(height, slice),\
            "Slicing of header chain has not been implemented yet."
        header = self.sync_read_header(height)
        return self._deserialize(height, header)

    @execute_serially
    @defer.inlineCallbacks
    def connect(self, start, headers):
        """Verify and store ``headers`` starting at height ``start`` in a worker thread."""
        yield threads.deferToThread(self._sync_connect, start, headers)

    def _sync_connect(self, start, headers):
        """Verify the raw ``headers`` blob, write it at ``start`` and publish the size change.

        Anything previously stored after the written range is truncated.
        """
        previous_header = None
        for header in self._iterate_headers(start, headers):
            height = header['block_height']
            if previous_header is None and height > 0:
                previous_header = self[height-1]
            self._verify_header(height, header, previous_header)
            previous_header = header

        # Go through len() rather than reading self._size directly: the cache
        # is None until first use, which would break the subtraction below.
        _old_size = len(self)

        with open(self.path, 'r+b') as f:
            f.seek(start * self.ledger.header_size)
            f.write(headers)
            f.truncate()

        self._size = self.sync_read_length()
        change = self._size - _old_size
        #log.info('saved {} header blocks'.format(change))
        self._on_change_controller.add(change)

    def _iterate_headers(self, height, headers):
        """Yield deserialized headers from a raw blob whose first record is at ``height``."""
        assert len(headers) % self.ledger.header_size == 0
        for idx in range(len(headers) // self.ledger.header_size):
            start, end = idx * self.ledger.header_size, (idx + 1) * self.ledger.header_size
            header = headers[start:end]
            yield self._deserialize(height+idx, header)

    def _verify_header(self, height, header, previous_header):
        """Check previous-hash linkage, expected bits and proof of work of ``header``.

        Raises ``AssertionError`` on any mismatch.
        NOTE(review): these checks are ``assert`` statements and therefore
        disappear when Python runs with ``-O``.
        """
        previous_hash = self._hash_header(previous_header)
        assert previous_hash == header['prev_block_hash'], \
            "prev hash mismatch: {} vs {}".format(previous_hash, header['prev_block_hash'])

        bits, target = self._calculate_lbry_next_work_required(height, previous_header, header)
        assert bits == header['bits'], \
            "bits mismatch: {} vs {} (hash: {})".format(
            bits, header['bits'], self._hash_header(header))

        _pow_hash = self._pow_hash_header(header)
        assert int(b'0x' + _pow_hash, 16) <= target, \
            "insufficient proof of work: {} vs target {}".format(
            int(b'0x' + _pow_hash, 16), target)

    @staticmethod
    def _serialize(header):
        """Serialize a header dict to its hex representation (inverse of ``_deserialize``)."""
        return b''.join([
            int_to_hex(header['version'], 4),
            rev_hex(header['prev_block_hash']),
            rev_hex(header['merkle_root']),
            rev_hex(header['claim_trie_root']),
            int_to_hex(int(header['timestamp']), 4),
            int_to_hex(int(header['bits']), 4),
            int_to_hex(int(header['nonce']), 4)
        ])

    @staticmethod
    def _deserialize(height, header):
        """Split a raw 112-byte header into a dict of fields, tagged with ``height``."""
        return {
            'version': hex_to_int(header[0:4]),
            'prev_block_hash': hash_encode(header[4:36]),
            'merkle_root': hash_encode(header[36:68]),
            'claim_trie_root': hash_encode(header[68:100]),
            'timestamp': hex_to_int(header[100:104]),
            'bits': hex_to_int(header[104:108]),
            'nonce': hex_to_int(header[108:112]),
            'block_height': height
        }

    def _hash_header(self, header):
        """Double-SHA256 hash of ``header``; 64 zero characters when ``header`` is None."""
        if header is None:
            return b'0' * 64
        return hash_encode(double_sha256(unhexlify(self._serialize(header))))

    def _pow_hash_header(self, header):
        """Proof-of-work hash of ``header``; 64 zero characters when ``header`` is None."""
        if header is None:
            return b'0' * 64
        return hash_encode(pow_hash(unhexlify(self._serialize(header))))

    def _calculate_lbry_next_work_required(self, height, first, last):
        """ See: lbrycrd/src/lbry.cpp

        Returns a ``(compact bits, target value)`` tuple for the header at
        ``height``. ``first`` is the previous header, ``last`` the header
        being verified. Height 0 yields the ledger's genesis values.
        """

        if height == 0:
            return self.ledger.genesis_bits, self.ledger.max_target

        if self.ledger.verify_bits_to_target:
            bits = last['bits']
            bitsN = (bits >> 24) & 0xff
            assert 0x03 <= bitsN <= 0x1f, \
                "First part of bits should be in [0x03, 0x1f], but it was {}".format(hex(bitsN))
            bitsBase = bits & 0xffffff
            assert 0x8000 <= bitsBase <= 0x7fffff, \
                "Second part of bits should be in [0x8000, 0x7fffff] but it was {}".format(bitsBase)

        # new target
        retargetTimespan = self.ledger.target_timespan
        nActualTimespan = last['timestamp'] - first['timestamp']

        nModulatedTimespan = retargetTimespan + (nActualTimespan - retargetTimespan) // 8

        nMinTimespan = retargetTimespan - (retargetTimespan // 8)
        nMaxTimespan = retargetTimespan + (retargetTimespan // 2)

        # Limit adjustment step
        if nModulatedTimespan < nMinTimespan:
            nModulatedTimespan = nMinTimespan
        elif nModulatedTimespan > nMaxTimespan:
            nModulatedTimespan = nMaxTimespan

        # Retarget
        bnPowLimit = _ArithUint256(self.ledger.max_target)
        bnNew = _ArithUint256.SetCompact(last['bits'])
        # NOTE(review): multiplying and dividing by the same factor leaves the
        # value unchanged except for the 256-bit wrap-around applied by the
        # multiplication. Kept as-is because header validation depends on it;
        # confirm against lbrycrd/src/lbry.cpp before changing the divisor.
        bnNew *= nModulatedTimespan
        bnNew //= nModulatedTimespan
        if bnNew > bnPowLimit:
            bnNew = bnPowLimit

        return bnNew.GetCompact(), bnNew._value


class _ArithUint256:
    """ See: lbrycrd/src/arith_uint256.cpp """

    def __init__(self, value):
        self._value = value

    def __str__(self):
        return hex(self._value)

    @staticmethod
    def fromCompact(nCompact):
        """Convert a compact representation into its value"""
        nSize = nCompact >> 24
        # the lower 23 bits
        nWord = nCompact & 0x007fffff
        if nSize <= 3:
            return nWord >> 8 * (3 - nSize)
        else:
            return nWord << 8 * (nSize - 3)

    @classmethod
    def SetCompact(cls, nCompact):
        """Alternate constructor from a compact representation."""
        return cls(cls.fromCompact(nCompact))

    def bits(self):
        """Returns the position of the highest bit set plus one (0 for zero)."""
        # int.bit_length() is exactly this quantity. The previous character
        # scan treated both '0' and '1' as truthy and so overshot by one.
        return self._value.bit_length()

    def GetLow64(self):
        return self._value & 0xffffffffffffffff

    def GetCompact(self):
        """Convert a value into its compact representation"""
        nSize = (self.bits() + 7) // 8
        nCompact = 0
        if nSize <= 3:
            nCompact = self.GetLow64() << 8 * (3 - nSize)
        else:
            bn = _ArithUint256(self._value >> 8 * (nSize - 3))
            nCompact = bn.GetLow64()
        # The 0x00800000 bit denotes the sign.
        # Thus, if it is already set, divide the mantissa by 256 and increase the exponent.
        if nCompact & 0x00800000:
            nCompact >>= 8
            nSize += 1
        assert (nCompact & ~0x007fffff) == 0
        assert nSize < 256
        nCompact |= nSize << 24
        return nCompact

    def __mul__(self, x):
        # Take the mod because we are limited to an unsigned 256 bit number
        return _ArithUint256((self._value * x) % 2 ** 256)

    def __ifloordiv__(self, x):
        self._value = (self._value // x)
        return self

    def __gt__(self, x):
        return self._value > x._value