a lot more stuff

This commit is contained in:
Lex Berezhny 2018-03-27 02:40:44 -04:00 committed by Jack Robison
parent ca8b2dd83e
commit 0fd160e6e6
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
7 changed files with 368 additions and 1592 deletions

View file

@ -13,10 +13,6 @@ def get_key_chain_from_xpub(xpub):
return key, chain return key, chain
def derive_key(parent_key, chain, sequence):
return CKD_pub(parent_key, chain, sequence)[0]
class AddressSequence: class AddressSequence:
def __init__(self, derived_keys, gap, age_checker, pub_key, chain_key): def __init__(self, derived_keys, gap, age_checker, pub_key, chain_key):
@ -31,7 +27,7 @@ class AddressSequence:
] ]
def generate_next_address(self): def generate_next_address(self):
new_key, _ = derive_key(self.pub_key, self.chain_key, len(self.derived_keys)) new_key, _ = CKD_pub(self.pub_key, self.chain_key, len(self.derived_keys))
address = public_key_to_address(new_key) address = public_key_to_address(new_key)
self.derived_keys.append(new_key.encode('hex')) self.derived_keys.append(new_key.encode('hex'))
self.addresses.append(address) self.addresses.append(address)
@ -59,11 +55,11 @@ class Account:
master_key, master_chain = get_key_chain_from_xpub(data['xpub']) master_key, master_chain = get_key_chain_from_xpub(data['xpub'])
self.receiving = AddressSequence( self.receiving = AddressSequence(
data.get('receiving', []), receiving_gap, age_checker, data.get('receiving', []), receiving_gap, age_checker,
*derive_key(master_key, master_chain, 0) *CKD_pub(master_key, master_chain, 0)
) )
self.change = AddressSequence( self.change = AddressSequence(
data.get('change', []), change_gap, age_checker, data.get('change', []), change_gap, age_checker,
*derive_key(master_key, master_chain, 1) *CKD_pub(master_key, master_chain, 1)
) )
self.is_old = age_checker self.is_old = age_checker
@ -74,10 +70,6 @@ class Account:
'xpub': self.xpub 'xpub': self.xpub
} }
def ensure_enough_addresses(self):
return self.receiving.ensure_enough_addresses() + \
self.change.ensure_enough_addresses()
@property @property
def sequences(self): def sequences(self):
return self.receiving, self.change return self.receiving, self.change

View file

@ -1,16 +1,63 @@
import os import os
import logging import logging
import hashlib
from twisted.internet import threads, defer from twisted.internet import threads, defer
from lbryum.util import hex_to_int, int_to_hex, rev_hex from lbryum.util import hex_to_int, int_to_hex, rev_hex
from lbryum.hashing import hash_encode, Hash, PoWHash from lbryum.hashing import hash_encode, Hash, PoWHash
from .stream import StreamController from .stream import StreamController, execute_serially
from .constants import blockchain_params, HEADER_SIZE from .constants import blockchain_params, HEADER_SIZE
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class Transaction:
def __init__(self, tx_hash, raw, height):
self.hash = tx_hash
self.raw = raw
self.height = height
class BlockchainTransactions:
def __init__(self, history):
self.addresses = {}
self.transactions = {}
for address, transactions in history.items():
self.addresses[address] = []
for txid, raw, height in transactions:
tx = Transaction(txid, raw, height)
self.addresses[address].append(tx)
self.transactions[txid] = tx
def has_address(self, address):
return address in self.addresses
def get_transaction(self, tx_hash, *args):
return self.transactions.get(tx_hash, *args)
def get_transactions(self, address, *args):
return self.addresses.get(address, *args)
def get_status(self, address):
hashes = [
'{}:{}:'.format(tx.hash, tx.height)
for tx in self.get_transactions(address, [])
]
if hashes:
return hashlib.sha256(''.join(hashes)).digest().encode('hex')
def has_transaction(self, tx_hash):
return tx_hash in self.transactions
def add_transaction(self, address, transaction):
self.transactions.setdefault(transaction.hash, transaction)
self.addresses.setdefault(address, [])
self.addresses[address].append(transaction)
class BlockchainHeaders: class BlockchainHeaders:
def __init__(self, path, chain='lbrycrd_main'): def __init__(self, path, chain='lbrycrd_main'):
@ -24,39 +71,39 @@ class BlockchainHeaders:
self.on_changed = self._on_change_controller.stream self.on_changed = self._on_change_controller.stream
self._size = None self._size = None
self._write_lock = defer.DeferredLock()
if not os.path.exists(path): if not os.path.exists(path):
with open(path, 'wb'): with open(path, 'wb'):
pass pass
@property
def height(self):
return len(self) - 1
def sync_read_length(self): def sync_read_length(self):
return os.path.getsize(self.path) / HEADER_SIZE return os.path.getsize(self.path) / HEADER_SIZE
def __len__(self):
if self._size is None:
self._size = self.sync_read_length()
return self._size
def sync_read_header(self, height): def sync_read_header(self, height):
if 0 <= height < len(self): if 0 <= height < len(self):
with open(self.path, 'rb') as f: with open(self.path, 'rb') as f:
f.seek(height * HEADER_SIZE) f.seek(height * HEADER_SIZE)
return f.read(HEADER_SIZE) return f.read(HEADER_SIZE)
def __len__(self):
if self._size is None:
self._size = self.sync_read_length()
return self._size
def __getitem__(self, height): def __getitem__(self, height):
assert not isinstance(height, slice),\ assert not isinstance(height, slice),\
"Slicing of header chain has not been implemented yet." "Slicing of header chain has not been implemented yet."
header = self.sync_read_header(height) header = self.sync_read_header(height)
return self._deserialize(height, header) return self._deserialize(height, header)
@execute_serially
@defer.inlineCallbacks @defer.inlineCallbacks
def connect(self, start, headers): def connect(self, start, headers):
yield self._write_lock.acquire()
try:
yield threads.deferToThread(self._sync_connect, start, headers) yield threads.deferToThread(self._sync_connect, start, headers)
finally:
self._write_lock.release()
def _sync_connect(self, start, headers): def _sync_connect(self, start, headers):
previous_header = None previous_header = None

View file

@ -1,22 +1,19 @@
import os import os
import logging import logging
from operator import itemgetter
from twisted.internet import defer from twisted.internet import defer
import lbryschema import lbryschema
from .protocol import Network from .protocol import Network
from .blockchain import BlockchainHeaders from .blockchain import BlockchainHeaders, Transaction
from .wallet import Wallet from .wallet import Wallet
from .stream import execute_serially
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def chunks(l, n):
for i in range(0, len(l), n):
yield l[i:i+n]
class WalletManager: class WalletManager:
def __init__(self, storage, config): def __init__(self, storage, config):
@ -24,11 +21,10 @@ class WalletManager:
self.config = config self.config = config
lbryschema.BLOCKCHAIN_NAME = config['chain'] lbryschema.BLOCKCHAIN_NAME = config['chain']
self.headers = BlockchainHeaders(self.headers_path, config['chain']) self.headers = BlockchainHeaders(self.headers_path, config['chain'])
self.wallet = Wallet(self.wallet_path) self.wallet = Wallet(self.wallet_path, self.headers)
self.network = Network(config) self.network = Network(config)
self.network.on_header.listen(self.process_header) self.network.on_header.listen(self.process_header)
self.network.on_transaction.listen(self.process_transaction) self.network.on_status.listen(self.process_status)
self._downloading_headers = False
@property @property
def headers_path(self): def headers_path(self):
@ -41,48 +37,117 @@ class WalletManager:
def wallet_path(self): def wallet_path(self):
return os.path.join(self.config['wallet_path'], 'wallets', 'default_wallet') return os.path.join(self.config['wallet_path'], 'wallets', 'default_wallet')
def get_least_used_receiving_address(self, max_transactions=1000):
return self._get_least_used_address(
self.wallet.receiving_addresses,
self.wallet.default_account.receiving,
max_transactions
)
def get_least_used_change_address(self, max_transactions=100):
return self._get_least_used_address(
self.wallet.change_addresses,
self.wallet.default_account.change,
max_transactions
)
def _get_least_used_address(self, addresses, sequence, max_transactions):
transaction_counts = []
for address in addresses:
transactions = self.wallet.history.get_transactions(address, [])
tx_count = len(transactions)
if tx_count == 0:
return address
elif tx_count >= max_transactions:
continue
else:
transaction_counts.append((address, tx_count))
if transaction_counts:
transaction_counts.sort(key=itemgetter(1))
return transaction_counts[0]
address = sequence.generate_next_address()
self.subscribe_history(address)
return address
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
self.wallet.load()
self.network.start() self.network.start()
yield self.network.on_connected.first yield self.network.on_connected.first
yield self.download_headers() yield self.update_headers()
yield self.network.headers_subscribe() yield self.network.subscribe_headers()
yield self.download_transactions() yield self.update_wallet()
def stop(self): def stop(self):
return self.network.stop() return self.network.stop()
@execute_serially
@defer.inlineCallbacks @defer.inlineCallbacks
def download_headers(self): def update_headers(self):
self._downloading_headers = True
while True: while True:
sought_height = len(self.headers) height_sought = len(self.headers)
headers = yield self.network.block_headers(sought_height) headers = yield self.network.get_headers(height_sought)
log.info("received {} headers starting at {} height".format(headers['count'], sought_height)) log.info("received {} headers starting at {} height".format(headers['count'], height_sought))
if headers['count'] <= 0: if headers['count'] <= 0:
break break
yield self.headers.connect(sought_height, headers['hex'].decode('hex')) yield self.headers.connect(height_sought, headers['hex'].decode('hex'))
self._downloading_headers = False
@defer.inlineCallbacks @defer.inlineCallbacks
def process_header(self, header): def process_header(self, response):
if self._downloading_headers: header = response[0]
if self.update_headers.is_running:
return return
if header['block_height'] == len(self.headers): if header['height'] == len(self.headers):
# New header from network directly connects after the last local header. # New header from network directly connects after the last local header.
yield self.headers.connect(len(self.headers), header['hex'].decode('hex')) yield self.headers.connect(len(self.headers), header['hex'].decode('hex'))
elif header['block_height'] > len(self.headers): elif header['height'] > len(self.headers):
# New header is several heights ahead of local, do download instead. # New header is several heights ahead of local, do download instead.
yield self.download_headers() yield self.update_headers()
@execute_serially
@defer.inlineCallbacks @defer.inlineCallbacks
def download_transactions(self): def update_wallet(self):
for addresses in chunks(self.wallet.addresses, 500):
self.network.rpc([ if not self.wallet.exists:
('blockchain.address.subscribe', [address]) self.wallet.create()
for address in addresses
# Before subscribing, download history for any addresses that don't have any,
# this avoids situation where we're getting status updates to addresses we know
# need to update anyways. Continue to get history and create more addresses until
# all missing addresses are created and history for them is fully restored.
self.wallet.ensure_enough_addresses()
addresses = list(self.wallet.addresses_without_history)
while addresses:
yield defer.gatherResults([
self.update_history(a) for a in addresses
])
addresses = self.wallet.ensure_enough_addresses()
# By this point all of the addresses should be restored and we
# can now subscribe all of them to receive updates.
yield defer.gatherResults([
self.subscribe_history(address)
for address in self.wallet.addresses
]) ])
def process_transaction(self, tx): @defer.inlineCallbacks
pass def update_history(self, address):
history = yield self.network.get_history(address)
for hash in map(itemgetter('tx_hash'), history):
transaction = self.wallet.history.get_transaction(hash)
if not transaction:
raw = yield self.network.get_transaction(hash)
transaction = Transaction(hash, raw, None)
self.wallet.history.add_transaction(address, transaction)
@defer.inlineCallbacks
def subscribe_history(self, address):
status = yield self.network.subscribe_address(address)
if status != self.wallet.history.get_status(address):
self.update_history(address)
def process_status(self, response):
address, status = response
if status != self.wallet.history.get_status(address):
self.update_history(address)

View file

@ -63,57 +63,37 @@ class StratumClientProtocol(LineOnlyReceiver):
self.on_disconnected_controller.add(True) self.on_disconnected_controller.add(True)
def lineReceived(self, line): def lineReceived(self, line):
try: try:
message = json.loads(line) message = json.loads(line)
except (ValueError, TypeError): except (ValueError, TypeError):
raise ProtocolException("Cannot decode message '%s'" % line.strip()) raise ProtocolException("Cannot decode message '{}'".format(line.strip()))
msg_id = message.get('id', 0)
msg_result = message.get('result') if message.get('id'):
msg_error = message.get('error')
msg_method = message.get('method')
msg_params = message.get('params')
if msg_id:
# It's a RPC response
# Perform lookup to the table of waiting requests.
try: try:
meta = self.lookup_table[msg_id] d = self.lookup_table.pop(message['id'])
del self.lookup_table[msg_id] if message.get('error'):
except KeyError: d.errback(RemoteServiceException(*message['error']))
# When deferred object for given message ID isn't found, it's an error
raise ProtocolException(
"Lookup for deferred object for message ID '%s' failed." % msg_id)
# If there's an error, handle it as errback
# If both result and error are null, handle it as a success with blank result
if msg_error != None:
meta['defer'].errback(
RemoteServiceException(msg_error[0], msg_error[1], msg_error[2])
)
else: else:
meta['defer'].callback(msg_result) d.callback(message.get('result'))
elif msg_method: except KeyError:
if msg_method == 'blockchain.headers.subscribe': raise ProtocolException(
self.network._on_header_controller.add(msg_params[0]) "Lookup for deferred object for message ID '{}' failed.".format(message['id']))
elif msg_method == 'blockchain.address.subscribe': elif message.get('method') in self.network.subscription_controllers:
self.network._on_address_controller.add(msg_params) controller = self.network.subscription_controllers[message['method']]
controller.add(message.get('params'))
else: else:
log.warning("Cannot handle message '%s'" % line) log.warning("Cannot handle message '%s'" % line)
def write_request(self, method, params, is_notification=False): def rpc(self, method, *args):
request_id = None if is_notification else self._get_id() message_id = self._get_id()
serialized = json.dumps({'id': request_id, 'method': method, 'params': params}) message = json.dumps({
self.sendLine(serialized) 'id': message_id,
return request_id
def rpc(self, method, params, is_notification=False):
request_id = self.write_request(method, params, is_notification)
if is_notification:
return
d = defer.Deferred()
self.lookup_table[request_id] = {
'method': method, 'method': method,
'params': params, 'params': args
'defer': d, })
} self.sendLine(message)
d = self.lookup_table[message_id] = defer.Deferred()
return d return d
@ -147,8 +127,13 @@ class Network:
self._on_header_controller = StreamController() self._on_header_controller = StreamController()
self.on_header = self._on_header_controller.stream self.on_header = self._on_header_controller.stream
self._on_transaction_controller = StreamController() self._on_status_controller = StreamController()
self.on_transaction = self._on_transaction_controller.stream self.on_status = self._on_status_controller.stream
self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
@ -182,101 +167,29 @@ class Network:
def is_connected(self): def is_connected(self):
return self.client is not None and self.client.connected return self.client is not None and self.client.connected
def rpc(self, method, params, *args, **kwargs): def rpc(self, list_or_method, *args):
if self.is_connected: if self.is_connected:
return self.client.rpc(method, params, *args, **kwargs) return self.client.rpc(list_or_method, *args)
else: else:
raise TransportException("Attempting to send rpc request when connection is not available.") raise TransportException("Attempting to send rpc request when connection is not available.")
def claimtrie_getvaluesforuris(self, block_hash, *uris): def broadcast(self, raw_transaction):
return self.rpc( return self.rpc('blockchain.transaction.broadcast', raw_transaction)
'blockchain.claimtrie.getvaluesforuris', [block_hash] + list(uris)
)
def claimtrie_getvaluesforuri(self, block_hash, uri): def get_history(self, address):
return self.rpc('blockchain.claimtrie.getvaluesforuri', [block_hash, uri]) return self.rpc('blockchain.address.get_history', address)
def claimtrie_getclaimssignedbynthtoname(self, name, n): def get_transaction(self, tx_hash):
return self.rpc('blockchain.claimtrie.getclaimssignedbynthtoname', [name, n]) return self.rpc('blockchain.transaction.get', tx_hash)
def claimtrie_getclaimssignedbyid(self, certificate_id): def get_merkle(self, tx_hash, height):
return self.rpc('blockchain.claimtrie.getclaimssignedbyid', [certificate_id]) return self.rpc('blockchain.transaction.get_merkle', tx_hash, height)
def claimtrie_getclaimssignedby(self, name): def get_headers(self, height, count=10000):
return self.rpc('blockchain.claimtrie.getclaimssignedby', [name]) return self.rpc('blockchain.block.headers', height, count)
def claimtrie_getnthclaimforname(self, name, n): def subscribe_headers(self):
return self.rpc('blockchain.claimtrie.getnthclaimforname', [name, n]) return self.rpc('blockchain.headers.subscribe')
def claimtrie_getclaimsbyids(self, *claim_ids): def subscribe_address(self, address):
return self.rpc('blockchain.claimtrie.getclaimsbyids', list(claim_ids)) return self.rpc('blockchain.address.subscribe', address)
def claimtrie_getclaimbyid(self, claim_id):
return self.rpc('blockchain.claimtrie.getclaimbyid', [claim_id])
def claimtrie_get(self):
return self.rpc('blockchain.claimtrie.get', [])
def block_get_block(self, block_hash):
return self.rpc('blockchain.block.get_block', [block_hash])
def claimtrie_getclaimsforname(self, name):
return self.rpc('blockchain.claimtrie.getclaimsforname', [name])
def claimtrie_getclaimsintx(self, txid):
return self.rpc('blockchain.claimtrie.getclaimsintx', [txid])
def claimtrie_getvalue(self, name, block_hash=None):
return self.rpc('blockchain.claimtrie.getvalue', [name, block_hash])
def relayfee(self):
return self.rpc('blockchain.relayfee', [])
def estimatefee(self):
return self.rpc('blockchain.estimatefee', [])
def transaction_get(self, txid):
return self.rpc('blockchain.transaction.get', [txid])
def transaction_get_merkle(self, tx_hash, height, cache_only=False):
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height, cache_only])
def transaction_broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
def block_get_chunk(self, index, cache_only=False):
return self.rpc('blockchain.block.get_chunk', [index, cache_only])
def block_get_header(self, height, cache_only=False):
return self.rpc('blockchain.block.get_header', [height, cache_only])
def block_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', [height, count])
def utxo_get_address(self, txid, pos):
return self.rpc('blockchain.utxo.get_address', [txid, pos])
def address_listunspent(self, address):
return self.rpc('blockchain.address.listunspent', [address])
def address_get_proof(self, address):
return self.rpc('blockchain.address.get_proof', [address])
def address_get_balance(self, address):
return self.rpc('blockchain.address.get_balance', [address])
def address_get_mempool(self, address):
return self.rpc('blockchain.address.get_mempool', [address])
def address_get_history(self, address):
return self.rpc('blockchain.address.get_history', [address])
def address_subscribe(self, addresses):
if isinstance(addresses, str):
return self.rpc('blockchain.address.subscribe', [addresses])
else:
msgs = map(lambda addr: ('blockchain.address.subscribe', [addr]), addresses)
self.network.send(msgs, self.addr_subscription_response)
def headers_subscribe(self):
return self.rpc('blockchain.headers.subscribe', [], True)

View file

@ -1,7 +1,24 @@
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred, DeferredLock, maybeDeferred, inlineCallbacks
from twisted.python.failure import Failure from twisted.python.failure import Failure
def execute_serially(f):
_lock = DeferredLock()
@inlineCallbacks
def allow_only_one_at_a_time(*args, **kwargs):
yield _lock.acquire()
allow_only_one_at_a_time.is_running = True
try:
yield maybeDeferred(f, *args, **kwargs)
finally:
allow_only_one_at_a_time.is_running = False
_lock.release()
allow_only_one_at_a_time.is_running = False
return allow_only_one_at_a_time
class BroadcastSubscription: class BroadcastSubscription:
def __init__(self, controller, on_data, on_error, on_done): def __init__(self, controller, on_data, on_error, on_done):

View file

@ -16,7 +16,7 @@ from .lbrycrd import op_push
from .lbrycrd import point_to_ser, MyVerifyingKey, MySigningKey from .lbrycrd import point_to_ser, MyVerifyingKey, MySigningKey
from .lbrycrd import regenerate_key, public_key_from_private_key from .lbrycrd import regenerate_key, public_key_from_private_key
from .lbrycrd import encode_claim_id_hex, claim_id_hash from .lbrycrd import encode_claim_id_hex, claim_id_hash
from .util import profiler, var_int, int_to_hex, parse_sig, rev_hex from .util import var_int, int_to_hex, parse_sig, rev_hex
log = logging.getLogger() log = logging.getLogger()
@ -559,7 +559,6 @@ class Transaction(object):
fee = relay_fee fee = relay_fee
return fee return fee
@profiler
def estimated_size(self): def estimated_size(self):
'''Return an estimated tx size in bytes.''' '''Return an estimated tx size in bytes.'''
return len(self.serialize(-1)) / 2 # ASCII hex string return len(self.serialize(-1)) / 2 # ASCII hex string

File diff suppressed because it is too large Load diff