a lot more stuff

This commit is contained in:
Lex Berezhny 2018-03-27 02:40:44 -04:00 committed by Jack Robison
parent ca8b2dd83e
commit 0fd160e6e6
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
7 changed files with 368 additions and 1592 deletions

View file

@ -13,10 +13,6 @@ def get_key_chain_from_xpub(xpub):
return key, chain
def derive_key(parent_key, chain, sequence):
return CKD_pub(parent_key, chain, sequence)[0]
class AddressSequence:
def __init__(self, derived_keys, gap, age_checker, pub_key, chain_key):
@ -31,7 +27,7 @@ class AddressSequence:
]
def generate_next_address(self):
new_key, _ = derive_key(self.pub_key, self.chain_key, len(self.derived_keys))
new_key, _ = CKD_pub(self.pub_key, self.chain_key, len(self.derived_keys))
address = public_key_to_address(new_key)
self.derived_keys.append(new_key.encode('hex'))
self.addresses.append(address)
@ -59,11 +55,11 @@ class Account:
master_key, master_chain = get_key_chain_from_xpub(data['xpub'])
self.receiving = AddressSequence(
data.get('receiving', []), receiving_gap, age_checker,
*derive_key(master_key, master_chain, 0)
*CKD_pub(master_key, master_chain, 0)
)
self.change = AddressSequence(
data.get('change', []), change_gap, age_checker,
*derive_key(master_key, master_chain, 1)
*CKD_pub(master_key, master_chain, 1)
)
self.is_old = age_checker
@ -74,10 +70,6 @@ class Account:
'xpub': self.xpub
}
def ensure_enough_addresses(self):
return self.receiving.ensure_enough_addresses() + \
self.change.ensure_enough_addresses()
@property
def sequences(self):
return self.receiving, self.change

View file

@ -1,16 +1,63 @@
import os
import logging
import hashlib
from twisted.internet import threads, defer
from lbryum.util import hex_to_int, int_to_hex, rev_hex
from lbryum.hashing import hash_encode, Hash, PoWHash
from .stream import StreamController
from .stream import StreamController, execute_serially
from .constants import blockchain_params, HEADER_SIZE
log = logging.getLogger(__name__)
class Transaction:
def __init__(self, tx_hash, raw, height):
self.hash = tx_hash
self.raw = raw
self.height = height
class BlockchainTransactions:
def __init__(self, history):
self.addresses = {}
self.transactions = {}
for address, transactions in history.items():
self.addresses[address] = []
for txid, raw, height in transactions:
tx = Transaction(txid, raw, height)
self.addresses[address].append(tx)
self.transactions[txid] = tx
def has_address(self, address):
return address in self.addresses
def get_transaction(self, tx_hash, *args):
return self.transactions.get(tx_hash, *args)
def get_transactions(self, address, *args):
return self.addresses.get(address, *args)
def get_status(self, address):
hashes = [
'{}:{}:'.format(tx.hash, tx.height)
for tx in self.get_transactions(address, [])
]
if hashes:
return hashlib.sha256(''.join(hashes)).digest().encode('hex')
def has_transaction(self, tx_hash):
return tx_hash in self.transactions
def add_transaction(self, address, transaction):
self.transactions.setdefault(transaction.hash, transaction)
self.addresses.setdefault(address, [])
self.addresses[address].append(transaction)
class BlockchainHeaders:
def __init__(self, path, chain='lbrycrd_main'):
@ -24,39 +71,39 @@ class BlockchainHeaders:
self.on_changed = self._on_change_controller.stream
self._size = None
self._write_lock = defer.DeferredLock()
if not os.path.exists(path):
with open(path, 'wb'):
pass
@property
def height(self):
return len(self) - 1
def sync_read_length(self):
return os.path.getsize(self.path) / HEADER_SIZE
def __len__(self):
if self._size is None:
self._size = self.sync_read_length()
return self._size
def sync_read_header(self, height):
if 0 <= height < len(self):
with open(self.path, 'rb') as f:
f.seek(height * HEADER_SIZE)
return f.read(HEADER_SIZE)
def __len__(self):
if self._size is None:
self._size = self.sync_read_length()
return self._size
def __getitem__(self, height):
assert not isinstance(height, slice),\
"Slicing of header chain has not been implemented yet."
header = self.sync_read_header(height)
return self._deserialize(height, header)
@execute_serially
@defer.inlineCallbacks
def connect(self, start, headers):
yield self._write_lock.acquire()
try:
yield threads.deferToThread(self._sync_connect, start, headers)
finally:
self._write_lock.release()
yield threads.deferToThread(self._sync_connect, start, headers)
def _sync_connect(self, start, headers):
previous_header = None

View file

@ -1,22 +1,19 @@
import os
import logging
from operator import itemgetter
from twisted.internet import defer
import lbryschema
from .protocol import Network
from .blockchain import BlockchainHeaders
from .blockchain import BlockchainHeaders, Transaction
from .wallet import Wallet
from .stream import execute_serially
log = logging.getLogger(__name__)
def chunks(l, n):
for i in range(0, len(l), n):
yield l[i:i+n]
class WalletManager:
def __init__(self, storage, config):
@ -24,11 +21,10 @@ class WalletManager:
self.config = config
lbryschema.BLOCKCHAIN_NAME = config['chain']
self.headers = BlockchainHeaders(self.headers_path, config['chain'])
self.wallet = Wallet(self.wallet_path)
self.wallet = Wallet(self.wallet_path, self.headers)
self.network = Network(config)
self.network.on_header.listen(self.process_header)
self.network.on_transaction.listen(self.process_transaction)
self._downloading_headers = False
self.network.on_status.listen(self.process_status)
@property
def headers_path(self):
@ -41,48 +37,117 @@ class WalletManager:
def wallet_path(self):
return os.path.join(self.config['wallet_path'], 'wallets', 'default_wallet')
def get_least_used_receiving_address(self, max_transactions=1000):
return self._get_least_used_address(
self.wallet.receiving_addresses,
self.wallet.default_account.receiving,
max_transactions
)
def get_least_used_change_address(self, max_transactions=100):
return self._get_least_used_address(
self.wallet.change_addresses,
self.wallet.default_account.change,
max_transactions
)
def _get_least_used_address(self, addresses, sequence, max_transactions):
transaction_counts = []
for address in addresses:
transactions = self.wallet.history.get_transactions(address, [])
tx_count = len(transactions)
if tx_count == 0:
return address
elif tx_count >= max_transactions:
continue
else:
transaction_counts.append((address, tx_count))
if transaction_counts:
transaction_counts.sort(key=itemgetter(1))
return transaction_counts[0]
address = sequence.generate_next_address()
self.subscribe_history(address)
return address
@defer.inlineCallbacks
def start(self):
self.wallet.load()
self.network.start()
yield self.network.on_connected.first
yield self.download_headers()
yield self.network.headers_subscribe()
yield self.download_transactions()
yield self.update_headers()
yield self.network.subscribe_headers()
yield self.update_wallet()
def stop(self):
return self.network.stop()
@execute_serially
@defer.inlineCallbacks
def download_headers(self):
self._downloading_headers = True
def update_headers(self):
while True:
sought_height = len(self.headers)
headers = yield self.network.block_headers(sought_height)
log.info("received {} headers starting at {} height".format(headers['count'], sought_height))
height_sought = len(self.headers)
headers = yield self.network.get_headers(height_sought)
log.info("received {} headers starting at {} height".format(headers['count'], height_sought))
if headers['count'] <= 0:
break
yield self.headers.connect(sought_height, headers['hex'].decode('hex'))
self._downloading_headers = False
yield self.headers.connect(height_sought, headers['hex'].decode('hex'))
@defer.inlineCallbacks
def process_header(self, header):
if self._downloading_headers:
def process_header(self, response):
header = response[0]
if self.update_headers.is_running:
return
if header['block_height'] == len(self.headers):
if header['height'] == len(self.headers):
# New header from network directly connects after the last local header.
yield self.headers.connect(len(self.headers), header['hex'].decode('hex'))
elif header['block_height'] > len(self.headers):
elif header['height'] > len(self.headers):
# New header is several heights ahead of local, do download instead.
yield self.download_headers()
yield self.update_headers()
@execute_serially
@defer.inlineCallbacks
def update_wallet(self):
if not self.wallet.exists:
self.wallet.create()
# Before subscribing, download history for any addresses that don't have any,
# this avoids situation where we're getting status updates to addresses we know
# need to update anyways. Continue to get history and create more addresses until
# all missing addresses are created and history for them is fully restored.
self.wallet.ensure_enough_addresses()
addresses = list(self.wallet.addresses_without_history)
while addresses:
yield defer.gatherResults([
self.update_history(a) for a in addresses
])
addresses = self.wallet.ensure_enough_addresses()
# By this point all of the addresses should be restored and we
# can now subscribe all of them to receive updates.
yield defer.gatherResults([
self.subscribe_history(address)
for address in self.wallet.addresses
])
@defer.inlineCallbacks
def download_transactions(self):
for addresses in chunks(self.wallet.addresses, 500):
self.network.rpc([
('blockchain.address.subscribe', [address])
for address in addresses
])
def update_history(self, address):
history = yield self.network.get_history(address)
for hash in map(itemgetter('tx_hash'), history):
transaction = self.wallet.history.get_transaction(hash)
if not transaction:
raw = yield self.network.get_transaction(hash)
transaction = Transaction(hash, raw, None)
self.wallet.history.add_transaction(address, transaction)
def process_transaction(self, tx):
pass
@defer.inlineCallbacks
def subscribe_history(self, address):
status = yield self.network.subscribe_address(address)
if status != self.wallet.history.get_status(address):
self.update_history(address)
def process_status(self, response):
address, status = response
if status != self.wallet.history.get_status(address):
self.update_history(address)

View file

@ -63,57 +63,37 @@ class StratumClientProtocol(LineOnlyReceiver):
self.on_disconnected_controller.add(True)
def lineReceived(self, line):
try:
message = json.loads(line)
except (ValueError, TypeError):
raise ProtocolException("Cannot decode message '%s'" % line.strip())
msg_id = message.get('id', 0)
msg_result = message.get('result')
msg_error = message.get('error')
msg_method = message.get('method')
msg_params = message.get('params')
if msg_id:
# It's a RPC response
# Perform lookup to the table of waiting requests.
raise ProtocolException("Cannot decode message '{}'".format(line.strip()))
if message.get('id'):
try:
meta = self.lookup_table[msg_id]
del self.lookup_table[msg_id]
d = self.lookup_table.pop(message['id'])
if message.get('error'):
d.errback(RemoteServiceException(*message['error']))
else:
d.callback(message.get('result'))
except KeyError:
# When deferred object for given message ID isn't found, it's an error
raise ProtocolException(
"Lookup for deferred object for message ID '%s' failed." % msg_id)
# If there's an error, handle it as errback
# If both result and error are null, handle it as a success with blank result
if msg_error != None:
meta['defer'].errback(
RemoteServiceException(msg_error[0], msg_error[1], msg_error[2])
)
else:
meta['defer'].callback(msg_result)
elif msg_method:
if msg_method == 'blockchain.headers.subscribe':
self.network._on_header_controller.add(msg_params[0])
elif msg_method == 'blockchain.address.subscribe':
self.network._on_address_controller.add(msg_params)
"Lookup for deferred object for message ID '{}' failed.".format(message['id']))
elif message.get('method') in self.network.subscription_controllers:
controller = self.network.subscription_controllers[message['method']]
controller.add(message.get('params'))
else:
log.warning("Cannot handle message '%s'" % line)
def write_request(self, method, params, is_notification=False):
request_id = None if is_notification else self._get_id()
serialized = json.dumps({'id': request_id, 'method': method, 'params': params})
self.sendLine(serialized)
return request_id
def rpc(self, method, params, is_notification=False):
request_id = self.write_request(method, params, is_notification)
if is_notification:
return
d = defer.Deferred()
self.lookup_table[request_id] = {
def rpc(self, method, *args):
message_id = self._get_id()
message = json.dumps({
'id': message_id,
'method': method,
'params': params,
'defer': d,
}
'params': args
})
self.sendLine(message)
d = self.lookup_table[message_id] = defer.Deferred()
return d
@ -147,8 +127,13 @@ class Network:
self._on_header_controller = StreamController()
self.on_header = self._on_header_controller.stream
self._on_transaction_controller = StreamController()
self.on_transaction = self._on_transaction_controller.stream
self._on_status_controller = StreamController()
self.on_status = self._on_status_controller.stream
self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
@defer.inlineCallbacks
def start(self):
@ -182,101 +167,29 @@ class Network:
def is_connected(self):
return self.client is not None and self.client.connected
def rpc(self, method, params, *args, **kwargs):
def rpc(self, list_or_method, *args):
if self.is_connected:
return self.client.rpc(method, params, *args, **kwargs)
return self.client.rpc(list_or_method, *args)
else:
raise TransportException("Attempting to send rpc request when connection is not available.")
def claimtrie_getvaluesforuris(self, block_hash, *uris):
return self.rpc(
'blockchain.claimtrie.getvaluesforuris', [block_hash] + list(uris)
)
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', raw_transaction)
def claimtrie_getvaluesforuri(self, block_hash, uri):
return self.rpc('blockchain.claimtrie.getvaluesforuri', [block_hash, uri])
def get_history(self, address):
return self.rpc('blockchain.address.get_history', address)
def claimtrie_getclaimssignedbynthtoname(self, name, n):
return self.rpc('blockchain.claimtrie.getclaimssignedbynthtoname', [name, n])
def get_transaction(self, tx_hash):
return self.rpc('blockchain.transaction.get', tx_hash)
def claimtrie_getclaimssignedbyid(self, certificate_id):
return self.rpc('blockchain.claimtrie.getclaimssignedbyid', [certificate_id])
def get_merkle(self, tx_hash, height):
return self.rpc('blockchain.transaction.get_merkle', tx_hash, height)
def claimtrie_getclaimssignedby(self, name):
return self.rpc('blockchain.claimtrie.getclaimssignedby', [name])
def get_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', height, count)
def claimtrie_getnthclaimforname(self, name, n):
return self.rpc('blockchain.claimtrie.getnthclaimforname', [name, n])
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe')
def claimtrie_getclaimsbyids(self, *claim_ids):
return self.rpc('blockchain.claimtrie.getclaimsbyids', list(claim_ids))
def claimtrie_getclaimbyid(self, claim_id):
return self.rpc('blockchain.claimtrie.getclaimbyid', [claim_id])
def claimtrie_get(self):
return self.rpc('blockchain.claimtrie.get', [])
def block_get_block(self, block_hash):
return self.rpc('blockchain.block.get_block', [block_hash])
def claimtrie_getclaimsforname(self, name):
return self.rpc('blockchain.claimtrie.getclaimsforname', [name])
def claimtrie_getclaimsintx(self, txid):
return self.rpc('blockchain.claimtrie.getclaimsintx', [txid])
def claimtrie_getvalue(self, name, block_hash=None):
return self.rpc('blockchain.claimtrie.getvalue', [name, block_hash])
def relayfee(self):
return self.rpc('blockchain.relayfee', [])
def estimatefee(self):
return self.rpc('blockchain.estimatefee', [])
def transaction_get(self, txid):
return self.rpc('blockchain.transaction.get', [txid])
def transaction_get_merkle(self, tx_hash, height, cache_only=False):
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height, cache_only])
def transaction_broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
def block_get_chunk(self, index, cache_only=False):
return self.rpc('blockchain.block.get_chunk', [index, cache_only])
def block_get_header(self, height, cache_only=False):
return self.rpc('blockchain.block.get_header', [height, cache_only])
def block_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', [height, count])
def utxo_get_address(self, txid, pos):
return self.rpc('blockchain.utxo.get_address', [txid, pos])
def address_listunspent(self, address):
return self.rpc('blockchain.address.listunspent', [address])
def address_get_proof(self, address):
return self.rpc('blockchain.address.get_proof', [address])
def address_get_balance(self, address):
return self.rpc('blockchain.address.get_balance', [address])
def address_get_mempool(self, address):
return self.rpc('blockchain.address.get_mempool', [address])
def address_get_history(self, address):
return self.rpc('blockchain.address.get_history', [address])
def address_subscribe(self, addresses):
if isinstance(addresses, str):
return self.rpc('blockchain.address.subscribe', [addresses])
else:
msgs = map(lambda addr: ('blockchain.address.subscribe', [addr]), addresses)
self.network.send(msgs, self.addr_subscription_response)
def headers_subscribe(self):
return self.rpc('blockchain.headers.subscribe', [], True)
def subscribe_address(self, address):
return self.rpc('blockchain.address.subscribe', address)

View file

@ -1,7 +1,24 @@
from twisted.internet.defer import Deferred
from twisted.internet.defer import Deferred, DeferredLock, maybeDeferred, inlineCallbacks
from twisted.python.failure import Failure
def execute_serially(f):
_lock = DeferredLock()
@inlineCallbacks
def allow_only_one_at_a_time(*args, **kwargs):
yield _lock.acquire()
allow_only_one_at_a_time.is_running = True
try:
yield maybeDeferred(f, *args, **kwargs)
finally:
allow_only_one_at_a_time.is_running = False
_lock.release()
allow_only_one_at_a_time.is_running = False
return allow_only_one_at_a_time
class BroadcastSubscription:
def __init__(self, controller, on_data, on_error, on_done):

View file

@ -16,7 +16,7 @@ from .lbrycrd import op_push
from .lbrycrd import point_to_ser, MyVerifyingKey, MySigningKey
from .lbrycrd import regenerate_key, public_key_from_private_key
from .lbrycrd import encode_claim_id_hex, claim_id_hash
from .util import profiler, var_int, int_to_hex, parse_sig, rev_hex
from .util import var_int, int_to_hex, parse_sig, rev_hex
log = logging.getLogger()
@ -559,7 +559,6 @@ class Transaction(object):
fee = relay_fee
return fee
@profiler
def estimated_size(self):
'''Return an estimated tx size in bytes.'''
return len(self.serialize(-1)) / 2 # ASCII hex string

File diff suppressed because it is too large Load diff