From b0bd0b1fc09331d839fb5ab8e2a81f005c032a57 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Sat, 28 Jul 2018 20:52:54 -0400 Subject: [PATCH] + reserving outpoints should no longer have race conditions + converted all comment type annotations to python 3 style syntax annotations + pylint & mypy --- .travis.yml | 28 +++-- setup.cfg | 23 ++++ setup.py | 3 +- tests/unit/test_coinselection.py | 12 +- tests/unit/test_ledger.py | 8 +- tests/unit/test_wallet.py | 6 +- torba/__init__.py | 2 +- torba/baseaccount.py | 91 +++++++------- torba/basedatabase.py | 39 +++--- torba/baseheader.py | 110 ++++++++--------- torba/baseledger.py | 121 ++++++++++++------- torba/{manager.py => basemanager.py} | 10 +- torba/basenetwork.py | 20 ++-- torba/basescript.py | 21 ++-- torba/basetransaction.py | 173 ++++++++++++++------------- torba/bcd_data_stream.py | 6 +- torba/bip32.py | 45 +++---- torba/coin/__init__.py | 2 +- torba/coin/bitcoincash.py | 9 +- torba/coin/bitcoinsegwit.py | 9 +- torba/coinselection.py | 28 +++-- torba/hash.py | 51 ++------ torba/mnemonic.py | 22 ++-- torba/stream.py | 42 ++----- torba/util.py | 23 ++-- torba/wallet.py | 25 ++-- 26 files changed, 482 insertions(+), 447 deletions(-) rename torba/{manager.py => basemanager.py} (89%) diff --git a/.travis.yml b/.travis.yml index 95224a7be..5b4e00b5a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,13 +5,25 @@ language: python python: - "3.7" -install: - - pip install tox-travis coverage - - pushd .. && git clone https://github.com/lbryio/electrumx.git --branch lbryumx && popd - - pushd .. && git clone https://github.com/lbryio/orchstr8.git && popd +jobs: + include: -script: tox + - stage: code quality + name: "pylint & mypy" + install: + - pip install pylint mypy + - pip install -e . + script: + - pylint --rcfile=setup.cfg torba + - mypy torba -after_success: - - coverage combine tests/ - - bash <(curl -s https://codecov.io/bash) + - stage: test + name: "Unit Tests" + install: + - pip install tox-travis coverage + - pushd .. && git clone https://github.com/lbryio/electrumx.git --branch lbryumx && popd + - pushd .. && git clone https://github.com/lbryio/orchstr8.git && popd + script: tox + after_success: + - coverage combine tests/ + - bash <(curl -s https://codecov.io/bash) diff --git a/setup.cfg b/setup.cfg index 96b5abc5a..66c7c22cf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,3 +5,26 @@ branch = True source = torba .tox/*/lib/python*/site-packages/torba + +[mypy-twisted.*,cryptography.*,ecdsa.*,pbkdf2] +ignore_missing_imports = True + +[pylint] +max-args=10 +max-line-length=110 +good-names=T,t,n,i,j,k,x,y,s,f,d,h,c,e,op,db,tx,io,cachedproperty,log,id +valid-metaclass-classmethod-first-arg=mcs +disable= + fixme, + no-else-return, + cyclic-import, + missing-docstring, + duplicate-code, + expression-not-assigned, + inconsistent-return-statements, + too-few-public-methods, + too-many-locals, + too-many-arguments, + too-many-public-methods, + too-many-instance-attributes, + protected-access diff --git a/setup.py b/setup.py index 49b3d1130..4e1b7b954 100644 --- a/setup.py +++ b/setup.py @@ -32,8 +32,7 @@ setup( 'twisted', 'ecdsa', 'pbkdf2', - 'cryptography', - 'typing' + 'cryptography' ), extras_require={ 'test': ( diff --git a/tests/unit/test_coinselection.py b/tests/unit/test_coinselection.py index b2adbbd92..e607fc3ef 100644 --- a/tests/unit/test_coinselection.py +++ b/tests/unit/test_coinselection.py @@ -30,13 +30,13 @@ class BaseSelectionTestCase(unittest.TestCase): class TestCoinSelectionTests(BaseSelectionTestCase): def test_empty_coins(self): - self.assertIsNone(CoinSelector([], 0, 0).select()) + self.assertEqual(CoinSelector([], 0, 0).select(), []) def test_skip_binary_search_if_total_not_enough(self): fee = utxo(CENT).get_estimator(self.ledger).fee big_pool = self.estimates(utxo(CENT+fee) for _ in range(100)) selector = CoinSelector(big_pool, 101 * CENT, 0) - self.assertIsNone(selector.select()) + self.assertEqual(selector.select(), []) self.assertEqual(selector.tries, 0) # Never tried. # check happy path selector = CoinSelector(big_pool, 100 * CENT, 0) @@ -108,7 +108,7 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase): self.assertEqual([3 * CENT, 2 * CENT], search(utxo_pool, 5 * CENT, 0.5 * CENT)) # Select 11 Cent, not possible - self.assertIsNone(search(utxo_pool, 11 * CENT, 0.5 * CENT)) + self.assertEqual(search(utxo_pool, 11 * CENT, 0.5 * CENT), []) # Select 10 Cent utxo_pool += self.estimates(utxo(5 * CENT)) @@ -126,12 +126,12 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase): ) # Select 0.25 Cent, not possible - self.assertIsNone(search(utxo_pool, 0.25 * CENT, 0.5 * CENT)) + self.assertEqual(search(utxo_pool, 0.25 * CENT, 0.5 * CENT), []) # Iteration exhaustion test utxo_pool, target = self.make_hard_case(17) selector = CoinSelector(utxo_pool, target, 0) - self.assertIsNone(selector.branch_and_bound()) + self.assertEqual(selector.branch_and_bound(), []) self.assertEqual(selector.tries, MAXIMUM_TRIES) # Should exhaust utxo_pool, target = self.make_hard_case(14) self.assertIsNotNone(search(utxo_pool, target, 0)) # Should not exhaust @@ -152,4 +152,4 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase): # Select 1 Cent with pool of only greater than 5 Cent utxo_pool = self.estimates(utxo(i * CENT) for i in range(5, 21)) for _ in range(100): - self.assertIsNone(search(utxo_pool, 1 * CENT, 2 * CENT)) + self.assertEqual(search(utxo_pool, 1 * CENT, 2 * CENT), []) diff --git a/tests/unit/test_ledger.py b/tests/unit/test_ledger.py index afedf6042..6a6aff21b 100644 --- a/tests/unit/test_ledger.py +++ b/tests/unit/test_ledger.py @@ -1,4 +1,3 @@ -import six from binascii import hexlify from twisted.trial import unittest from twisted.internet import defer @@ -7,9 +6,6 @@ from torba.coin.bitcoinsegwit import MainNetLedger from .test_transaction import get_transaction, get_output -if six.PY3: - buffer = memoryview - class MockNetwork: @@ -50,9 +46,7 @@ class MainNetTestLedger(MainNetLedger): network_name = 'unittest' def __init__(self): - super(MainNetLedger, self).__init__({ - 'db': MainNetLedger.database_class(':memory:') - }) + super().__init__({'db': MainNetLedger.database_class(':memory:')}) class LedgerTestCase(unittest.TestCase): diff --git a/tests/unit/test_wallet.py b/tests/unit/test_wallet.py index c52119e4e..b52ee4ebb 100644 --- a/tests/unit/test_wallet.py +++ b/tests/unit/test_wallet.py @@ -3,14 +3,14 @@ from twisted.trial import unittest from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger from torba.coin.bitcoincash import MainNetLedger as BCHLedger -from torba.manager import WalletManager +from torba.basemanager import BaseWalletManager from torba.wallet import Wallet, WalletStorage class TestWalletCreation(unittest.TestCase): def setUp(self): - self.manager = WalletManager() + self.manager = BaseWalletManager() config = {'data_path': '/tmp/wallet'} self.btc_ledger = self.manager.get_or_create_ledger(BTCLedger.get_id(), config) self.bch_ledger = self.manager.get_or_create_ledger(BCHLedger.get_id(), config) @@ -63,7 +63,7 @@ class TestWalletCreation(unittest.TestCase): self.assertDictEqual(wallet_dict, wallet.to_dict()) def test_read_write(self): - manager = WalletManager() + manager = BaseWalletManager() config = {'data_path': '/tmp/wallet'} ledger = manager.get_or_create_ledger(BTCLedger.get_id(), config) diff --git a/torba/__init__.py b/torba/__init__.py index a3b65bdd9..210e2b896 100644 --- a/torba/__init__.py +++ b/torba/__init__.py @@ -1,2 +1,2 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) +__path__: str = __import__('pkgutil').extend_path(__path__, __name__) __version__ = '0.0.4' diff --git a/torba/baseaccount.py b/torba/baseaccount.py index 0f1c0c3f1..532b364da 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -1,12 +1,16 @@ -from typing import Dict +import typing +from typing import Sequence from twisted.internet import defer from torba.mnemonic import Mnemonic from torba.bip32 import PrivateKey, PubKey, from_extended_key_string from torba.hash import double_sha256, aes_encrypt, aes_decrypt +if typing.TYPE_CHECKING: + from torba import baseledger -class KeyManager(object): + +class KeyManager: __slots__ = 'account', 'public_key', 'chain_number' @@ -19,27 +23,27 @@ class KeyManager(object): def db(self): return self.account.ledger.db - def _query_addresses(self, limit=None, max_used_times=None, order_by=None): + def _query_addresses(self, limit: int = None, max_used_times: int = None, order_by=None): return self.db.get_addresses( self.account, self.chain_number, limit, max_used_times, order_by ) - def get_max_gap(self): # type: () -> defer.Deferred + def get_max_gap(self) -> defer.Deferred: raise NotImplementedError - def ensure_address_gap(self): # type: () -> defer.Deferred + def ensure_address_gap(self) -> defer.Deferred: raise NotImplementedError - def get_address_records(self, limit=None, only_usable=False): # type: (int, bool) -> defer.Deferred + def get_address_records(self, limit: int = None, only_usable: bool = False) -> defer.Deferred: raise NotImplementedError @defer.inlineCallbacks - def get_addresses(self, limit=None, only_usable=False): # type: (int, bool) -> defer.Deferred + def get_addresses(self, limit: int = None, only_usable: bool = False) -> defer.Deferred: records = yield self.get_address_records(limit=limit, only_usable=only_usable) defer.returnValue([r['address'] for r in records]) @defer.inlineCallbacks - def get_or_create_usable_address(self): # type: () -> defer.Deferred + def get_or_create_usable_address(self) -> defer.Deferred: addresses = yield self.get_addresses(limit=1, only_usable=True) if addresses: defer.returnValue(addresses[0]) @@ -52,14 +56,14 @@ class KeyChain(KeyManager): __slots__ = 'gap', 'maximum_uses_per_address' - def __init__(self, account, root_public_key, chain_number, gap, maximum_uses_per_address): - # type: ('BaseAccount', PubKey, int, int, int) -> None - super(KeyChain, self).__init__(account, root_public_key.child(chain_number), chain_number) + def __init__(self, account: 'BaseAccount', root_public_key: PubKey, + chain_number: int, gap: int, maximum_uses_per_address: int) -> None: + super().__init__(account, root_public_key.child(chain_number), chain_number) self.gap = gap self.maximum_uses_per_address = maximum_uses_per_address @defer.inlineCallbacks - def generate_keys(self, start, end): + def generate_keys(self, start: int, end: int) -> defer.Deferred: new_keys = [] for index in range(start, end+1): new_keys.append((index, self.public_key.child(index))) @@ -69,7 +73,7 @@ class KeyChain(KeyManager): defer.returnValue([key[1].address for key in new_keys]) @defer.inlineCallbacks - def get_max_gap(self): + def get_max_gap(self) -> defer.Deferred: addresses = yield self._query_addresses(order_by="position ASC") max_gap = 0 current_gap = 0 @@ -82,7 +86,7 @@ class KeyChain(KeyManager): defer.returnValue(max_gap) @defer.inlineCallbacks - def ensure_address_gap(self): + def ensure_address_gap(self) -> defer.Deferred: addresses = yield self._query_addresses(self.gap, None, "position DESC") existing_gap = 0 @@ -100,7 +104,7 @@ class KeyChain(KeyManager): new_keys = yield self.generate_keys(start, end-1) defer.returnValue(new_keys) - def get_address_records(self, limit=None, only_usable=False): + def get_address_records(self, limit: int = None, only_usable: bool = False): return self._query_addresses( limit, self.maximum_uses_per_address if only_usable else None, "used_times ASC, position ASC" @@ -112,15 +116,11 @@ class SingleKey(KeyManager): __slots__ = () - def __init__(self, account, root_public_key, chain_number): - # type: ('BaseAccount', PubKey) -> None - super(SingleKey, self).__init__(account, root_public_key, chain_number) - - def get_max_gap(self): + def get_max_gap(self) -> defer.Deferred: return defer.succeed(0) @defer.inlineCallbacks - def ensure_address_gap(self): + def ensure_address_gap(self) -> defer.Deferred: exists = yield self.get_address_records() if not exists: yield self.db.add_keys( @@ -129,20 +129,20 @@ class SingleKey(KeyManager): defer.returnValue([self.public_key.address]) defer.returnValue([]) - def get_address_records(self, **kwargs): + def get_address_records(self, limit: int = None, only_usable: bool = False) -> defer.Deferred: return self._query_addresses() -class BaseAccount(object): +class BaseAccount: mnemonic_class = Mnemonic private_key_class = PrivateKey public_key_class = PubKey - def __init__(self, ledger, name, seed, encrypted, is_hd, private_key, - public_key, receiving_gap=20, change_gap=6, - receiving_maximum_uses_per_address=2, change_maximum_uses_per_address=2): - # type: (torba.baseledger.BaseLedger, str, str, bool, bool, PrivateKey, PubKey, int, int, int, int) -> None + def __init__(self, ledger: 'baseledger.BaseLedger', name: str, seed: str, encrypted: bool, is_hd: bool, + private_key: PrivateKey, public_key: PubKey, receiving_gap: int = 20, change_gap: int = 6, + receiving_maximum_uses_per_address: int = 2, change_maximum_uses_per_address: int = 2 + ) -> None: self.ledger = ledger self.name = name self.seed = seed @@ -150,25 +150,26 @@ class BaseAccount(object): self.private_key = private_key self.public_key = public_key if is_hd: - receiving, change = self.keychains = ( - KeyChain(self, public_key, 0, receiving_gap, receiving_maximum_uses_per_address), - KeyChain(self, public_key, 1, change_gap, change_maximum_uses_per_address) + self.receiving: KeyManager = KeyChain( + self, public_key, 0, receiving_gap, receiving_maximum_uses_per_address ) + self.change: KeyManager = KeyChain( + self, public_key, 1, change_gap, change_maximum_uses_per_address + ) + self.keychains: Sequence[KeyManager] = (self.receiving, self.change) else: - self.keychains = SingleKey(self, public_key, 0), - receiving = change = self.keychains[0] - self.receiving = receiving # type: KeyManager - self.change = change # type: KeyManager + self.change = self.receiving = SingleKey(self, public_key, 0) + self.keychains = (self.receiving,) ledger.add_account(self) @classmethod - def generate(cls, ledger, password, **kwargs): # type: (torba.baseledger.BaseLedger, str) -> BaseAccount + def generate(cls, ledger: 'baseledger.BaseLedger', password: str, **kwargs): seed = cls.mnemonic_class().make_seed() return cls.from_seed(ledger, seed, password, **kwargs) @classmethod - def from_seed(cls, ledger, seed, password, is_hd=True, **kwargs): - # type: (torba.baseledger.BaseLedger, str, str) -> BaseAccount + def from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str, + is_hd: bool = True, **kwargs): private_key = cls.get_private_key_from_seed(ledger, seed, password) return cls( ledger=ledger, name='Account #{}'.format(private_key.public_key.address), @@ -179,14 +180,13 @@ class BaseAccount(object): ) @classmethod - def get_private_key_from_seed(cls, ledger, seed, password): - # type: (torba.baseledger.BaseLedger, str, str) -> PrivateKey + def get_private_key_from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str): return cls.private_key_class.from_seed( ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password) ) @classmethod - def from_dict(cls, ledger, d): # type: (torba.baseledger.BaseLedger, Dict) -> BaseAccount + def from_dict(cls, ledger: 'baseledger.BaseLedger', d: dict): if not d['encrypted'] and d['private_key']: private_key = from_extended_key_string(ledger, d['private_key']) public_key = private_key.public_key @@ -264,21 +264,20 @@ class BaseAccount(object): defer.returnValue(addresses) @defer.inlineCallbacks - def get_addresses(self, limit=None, max_used_times=None): # type: (int, int) -> defer.Deferred + def get_addresses(self, limit: int = None, max_used_times: int = None) -> defer.Deferred: records = yield self.get_address_records(limit, max_used_times) defer.returnValue([r['address'] for r in records]) - def get_address_records(self, limit=None, max_used_times=None): # type: (int, int) -> defer.Deferred + def get_address_records(self, limit: int = None, max_used_times: int = None) -> defer.Deferred: return self.ledger.db.get_addresses(self, None, limit, max_used_times) - def get_private_key(self, chain, index): + def get_private_key(self, chain: int, index: int) -> PrivateKey: assert not self.encrypted, "Cannot get private key on encrypted wallet account." if isinstance(self.receiving, SingleKey): return self.private_key - else: - return self.private_key.child(chain).child(index) + return self.private_key.child(chain).child(index) - def get_balance(self, confirmations=6, **constraints): + def get_balance(self, confirmations: int = 6, **constraints): if confirmations > 0: height = self.ledger.headers.height - (confirmations-1) constraints.update({'height__lte': height, 'height__gt': 0}) diff --git a/torba/basedatabase.py b/torba/basedatabase.py index 0c8d0a54f..b086f7700 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -1,5 +1,5 @@ import logging -from typing import List, Union +from typing import Tuple, List, Sequence from operator import itemgetter import sqlite3 @@ -11,13 +11,13 @@ from torba.hash import TXRefImmutable log = logging.getLogger(__name__) -class SQLiteMixin(object): +class SQLiteMixin: - CREATE_TABLES_QUERY = None + CREATE_TABLES_QUERY: Sequence[str] = () def __init__(self, path): self._db_path = path - self.db = None + self.db: adbapi.ConnectionPool = None def start(self): log.info("connecting to database: %s", self._db_path) @@ -32,8 +32,8 @@ class SQLiteMixin(object): self.db.close() return defer.succeed(True) - def _insert_sql(self, table, data): - # type: (str, dict) -> tuple[str, List] + @staticmethod + def _insert_sql(table: str, data: dict) -> Tuple[str, List]: columns, values = [], [] for column, value in data.items(): columns.append(column) @@ -43,8 +43,8 @@ class SQLiteMixin(object): ) return sql, values - def _update_sql(self, table, data, where, constraints): - # type: (str, dict) -> tuple[str, List] + @staticmethod + def _update_sql(table: str, data: dict, where: str, constraints: list) -> Tuple[str, list]: columns, values = [], [] for column, value in data.items(): columns.append("{} = ?".format(column)) @@ -146,7 +146,8 @@ class BaseDatabase(SQLiteMixin): CREATE_TXI_TABLE ) - def txo_to_row(self, tx, address, txo): + @staticmethod + def txo_to_row(tx, address, txo): return { 'txid': tx.id, 'txoid': txo.id, @@ -156,7 +157,7 @@ class BaseDatabase(SQLiteMixin): 'script': sqlite3.Binary(txo.script.source) } - def save_transaction_io(self, save_tx, tx, height, is_verified, address, hash, history): + def save_transaction_io(self, save_tx, tx, height, is_verified, address, txhash, history): def _steps(t): if save_tx == 'insert': @@ -168,9 +169,8 @@ class BaseDatabase(SQLiteMixin): })) elif save_tx == 'update': self.execute(t, *self._update_sql("tx", { - 'height': height, 'is_verified': is_verified - }, 'txid = ?', (tx.id,) - )) + 'height': height, 'is_verified': is_verified + }, 'txid = ?', (tx.id,))) existing_txos = list(map(itemgetter(0), self.execute( t, "SELECT position FROM txo WHERE txid = ?", (tx.id,) @@ -179,7 +179,7 @@ class BaseDatabase(SQLiteMixin): for txo in tx.outputs: if txo.position in existing_txos: continue - if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == hash: + if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash: self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo))) elif txo.script.is_pay_script_hash: # TODO: implement script hash payments @@ -202,15 +202,16 @@ class BaseDatabase(SQLiteMixin): return self.db.runInteraction(_steps) - def reserve_spent_outputs(self, txoids, is_reserved=True): + def reserve_outputs(self, txos, is_reserved=True): + txoids = [txo.id for txo in txos] return self.run_operation( "UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format( ', '.join(['?']*len(txoids)) ), [is_reserved]+txoids ) - def release_reserved_outputs(self, txoids): - return self.reserve_spent_outputs(txoids, is_reserved=False) + def release_outputs(self, txos): + return self.reserve_outputs(txos, is_reserved=False) @defer.inlineCallbacks def get_transaction(self, txid): @@ -226,7 +227,7 @@ class BaseDatabase(SQLiteMixin): extra_sql = "" if constraints: extras = [] - for key in constraints.keys(): + for key in constraints: col, op = key, '=' if key.endswith('__not'): col, op = key[:-len('__not')], '!=' @@ -257,7 +258,7 @@ class BaseDatabase(SQLiteMixin): extra_sql = "" if constraints: extra_sql = ' AND ' + ' AND '.join( - '{} = :{}'.format(c, c) for c in constraints.keys() + '{} = :{}'.format(c, c) for c in constraints ) values = {'account': account.public_key.address} values.update(constraints) diff --git a/torba/baseheader.py b/torba/baseheader.py index 8815a6437..e28fde892 100644 --- a/torba/baseheader.py +++ b/torba/baseheader.py @@ -1,13 +1,16 @@ import os import struct import logging +import typing from binascii import unhexlify from twisted.internet import threads, defer -from torba.stream import StreamController, execute_serially +from torba.stream import StreamController from torba.util import int_to_hex, rev_hex, hash_encode from torba.hash import double_sha256, pow_hash +if typing.TYPE_CHECKING: + from torba import baseledger log = logging.getLogger(__name__) @@ -17,7 +20,7 @@ class BaseHeaders: header_size = 80 verify_bits_to_target = True - def __init__(self, ledger): # type: (baseledger.BaseLedger) -> BaseHeaders + def __init__(self, ledger: 'baseledger.BaseLedger') -> None: self.ledger = ledger self._size = None self._on_change_controller = StreamController() @@ -62,7 +65,6 @@ class BaseHeaders: header = self.sync_read_header(height) return self._deserialize(height, header) - @execute_serially @defer.inlineCallbacks def connect(self, start, headers): yield threads.deferToThread(self._sync_connect, start, headers) @@ -84,8 +86,9 @@ class BaseHeaders: _old_size = self._size self._size = self.sync_read_length() change = self._size - _old_size - log.info('{}: added {} header blocks, final height {}'.format( - self.ledger.get_id(), change, self.height) + log.info( + '%s: added %s header blocks, final height %s', + self.ledger.get_id(), change, self.height ) self._on_change_controller.add(change) @@ -101,7 +104,7 @@ class BaseHeaders: assert previous_hash == header['prev_block_hash'], \ "prev hash mismatch: {} vs {}".format(previous_hash, header['prev_block_hash']) - bits, target = self._calculate_next_work_required(height, previous_header, header) + bits, _ = self._calculate_next_work_required(height, previous_header, header) assert bits == header['bits'], \ "bits mismatch: {} vs {} (hash: {})".format( bits, header['bits'], self._hash_header(header)) @@ -154,37 +157,37 @@ class BaseHeaders: if self.verify_bits_to_target: bits = last['bits'] - bitsN = (bits >> 24) & 0xff - assert 0x03 <= bitsN <= 0x1d, \ - "First part of bits should be in [0x03, 0x1d], but it was {}".format(hex(bitsN)) - bitsBase = bits & 0xffffff - assert 0x8000 <= bitsBase <= 0x7fffff, \ - "Second part of bits should be in [0x8000, 0x7fffff] but it was {}".format(bitsBase) + bits_n = (bits >> 24) & 0xff + assert 0x03 <= bits_n <= 0x1d, \ + "First part of bits should be in [0x03, 0x1d], but it was {}".format(hex(bits_n)) + bits_base = bits & 0xffffff + assert 0x8000 <= bits_base <= 0x7fffff, \ + "Second part of bits should be in [0x8000, 0x7fffff] but it was {}".format(bits_base) # new target - retargetTimespan = self.ledger.target_timespan - nActualTimespan = last['timestamp'] - first['timestamp'] + retarget_timespan = self.ledger.target_timespan + n_actual_timespan = last['timestamp'] - first['timestamp'] - nModulatedTimespan = retargetTimespan + (nActualTimespan - retargetTimespan) // 8 + n_modulated_timespan = retarget_timespan + (n_actual_timespan - retarget_timespan) // 8 - nMinTimespan = retargetTimespan - (retargetTimespan // 8) - nMaxTimespan = retargetTimespan + (retargetTimespan // 2) + n_min_timespan = retarget_timespan - (retarget_timespan // 8) + n_max_timespan = retarget_timespan + (retarget_timespan // 2) # Limit adjustment step - if nModulatedTimespan < nMinTimespan: - nModulatedTimespan = nMinTimespan - elif nModulatedTimespan > nMaxTimespan: - nModulatedTimespan = nMaxTimespan + if n_modulated_timespan < n_min_timespan: + n_modulated_timespan = n_min_timespan + elif n_modulated_timespan > n_max_timespan: + n_modulated_timespan = n_max_timespan # Retarget - bnPowLimit = _ArithUint256(self.ledger.max_target) - bnNew = _ArithUint256.SetCompact(last['bits']) - bnNew *= nModulatedTimespan - bnNew //= nModulatedTimespan - if bnNew > bnPowLimit: - bnNew = bnPowLimit + bn_pow_limit = _ArithUint256(self.ledger.max_target) + bn_new = _ArithUint256.set_compact(last['bits']) + bn_new *= n_modulated_timespan + bn_new //= n_modulated_timespan + if bn_new > bn_pow_limit: + bn_new = bn_pow_limit - return bnNew.GetCompact(), bnNew._value + return bn_new.get_compact(), bn_new._value class _ArithUint256: @@ -197,49 +200,48 @@ class _ArithUint256: return hex(self._value) @staticmethod - def fromCompact(nCompact): + def from_compact(n_compact): """Convert a compact representation into its value""" - nSize = nCompact >> 24 + n_size = n_compact >> 24 # the lower 23 bits - nWord = nCompact & 0x007fffff - if nSize <= 3: - return nWord >> 8 * (3 - nSize) + n_word = n_compact & 0x007fffff + if n_size <= 3: + return n_word >> 8 * (3 - n_size) else: - return nWord << 8 * (nSize - 3) + return n_word << 8 * (n_size - 3) @classmethod - def SetCompact(cls, nCompact): - return cls(cls.fromCompact(nCompact)) + def set_compact(cls, n_compact): + return cls(cls.from_compact(n_compact)) def bits(self): """Returns the position of the highest bit set plus one.""" - bn = bin(self._value)[2:] - for i, d in enumerate(bn): + bits = bin(self._value)[2:] + for i, d in enumerate(bits): if d: - return (len(bn) - i) + 1 + return (len(bits) - i) + 1 return 0 - def GetLow64(self): + def get_low64(self): return self._value & 0xffffffffffffffff - def GetCompact(self): + def get_compact(self): """Convert a value into its compact representation""" - nSize = (self.bits() + 7) // 8 - nCompact = 0 - if nSize <= 3: - nCompact = self.GetLow64() << 8 * (3 - nSize) + n_size = (self.bits() + 7) // 8 + if n_size <= 3: + n_compact = self.get_low64() << 8 * (3 - n_size) else: - bn = _ArithUint256(self._value >> 8 * (nSize - 3)) - nCompact = bn.GetLow64() + n = _ArithUint256(self._value >> 8 * (n_size - 3)) + n_compact = n.get_low64() # The 0x00800000 bit denotes the sign. # Thus, if it is already set, divide the mantissa by 256 and increase the exponent. - if nCompact & 0x00800000: - nCompact >>= 8 - nSize += 1 - assert (nCompact & ~0x007fffff) == 0 - assert nSize < 256 - nCompact |= nSize << 24 - return nCompact + if n_compact & 0x00800000: + n_compact >>= 8 + n_size += 1 + assert (n_compact & ~0x007fffff) == 0 + assert n_size < 256 + n_compact |= n_size << 24 + return n_compact def __mul__(self, x): # Take the mod because we are limited to an unsigned 256 bit number diff --git a/torba/baseledger.py b/torba/baseledger.py index b1d857716..bf64bfd01 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -1,6 +1,4 @@ import os -import six -import hashlib import logging from binascii import hexlify, unhexlify from typing import Dict, Type, Iterable @@ -14,17 +12,22 @@ from torba import basedatabase from torba import baseheader from torba import basenetwork from torba import basetransaction -from torba.stream import StreamController, execute_serially -from torba.hash import hash160, double_sha256, Base58 +from torba.coinselection import CoinSelector +from torba.constants import COIN, NULL_HASH32 +from torba.stream import StreamController +from torba.hash import hash160, double_sha256, sha256, Base58 log = logging.getLogger(__name__) +LedgerType = Type['BaseLedger'] + class LedgerRegistry(type): - ledgers = {} # type: Dict[str, Type[BaseLedger]] + + ledgers: Dict[str, LedgerType] = {} def __new__(mcs, name, bases, attrs): - cls = super(LedgerRegistry, mcs).__new__(mcs, name, bases, attrs) # type: Type[BaseLedger] + cls: LedgerType = super().__new__(mcs, name, bases, attrs) if not (name == 'BaseLedger' and not bases): ledger_id = cls.get_id() assert ledger_id not in mcs.ledgers,\ @@ -33,7 +36,7 @@ class LedgerRegistry(type): return cls @classmethod - def get_ledger_class(mcs, ledger_id): # type: (str) -> Type[BaseLedger] + def get_ledger_class(mcs, ledger_id: str) -> LedgerType: return mcs.ledgers[ledger_id] @@ -41,11 +44,11 @@ class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'height' pass -class BaseLedger(six.with_metaclass(LedgerRegistry)): +class BaseLedger(metaclass=LedgerRegistry): - name = None - symbol = None - network_name = None + name: str + symbol: str + network_name: str account_class = baseaccount.BaseAccount database_class = basedatabase.BaseDatabase @@ -54,10 +57,10 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): transaction_class = basetransaction.BaseTransaction secret_prefix = None - pubkey_address_prefix = None - script_address_prefix = None - extended_public_key_prefix = None - extended_private_key_prefix = None + pubkey_address_prefix: bytes + script_address_prefix: bytes + extended_public_key_prefix: bytes + extended_private_key_prefix: bytes default_fee_per_byte = 10 @@ -71,13 +74,14 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): self.network.on_status.listen(self.process_status) self.accounts = [] self.headers = self.config.get('headers') or self.headers_class(self) - self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte) + self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte) self._on_transaction_controller = StreamController() self.on_transaction = self._on_transaction_controller.stream self.on_transaction.listen( - lambda e: log.info('({}) on_transaction: address={}, height={}, is_verified={}, tx.id={}'.format( - self.get_id(), e.address, e.height, e.is_verified, e.tx.id) + lambda e: log.info( + '(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s', + self.get_id(), e.address, e.height, e.is_verified, e.tx.id ) ) @@ -85,6 +89,8 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): self.on_header = self._on_header_controller.stream self._transaction_processing_locks = {} + self._utxo_reservation_lock = defer.DeferredLock() + self._header_processing_lock = defer.DeferredLock() @classmethod def get_id(cls): @@ -97,9 +103,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): @staticmethod def address_to_hash160(address): - bytes = Base58.decode(address) - prefix, pubkey_bytes, addr_checksum = bytes[0], bytes[1:21], bytes[21:] - return pubkey_bytes + return Base58.decode(address)[1:21] @classmethod def public_key_to_address(cls, public_key): @@ -113,7 +117,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): def path(self): return os.path.join(self.config['data_path'], self.get_id()) - def get_input_output_fee(self, io): + def get_input_output_fee(self, io: basetransaction.InputOutput) -> int: """ Fee based on size of the input / output. """ return self.fee_per_byte * io.size @@ -122,14 +126,14 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): return self.fee_per_byte * tx.base_size @defer.inlineCallbacks - def add_account(self, account): # type: (baseaccount.BaseAccount) -> None + def add_account(self, account: baseaccount.BaseAccount) -> defer.Deferred: self.accounts.append(account) if self.network.is_connected: yield self.update_account(account) @defer.inlineCallbacks def get_transaction(self, txhash): - raw, height, is_verified = yield self.db.get_transaction(txhash) + raw, _, _ = yield self.db.get_transaction(txhash) if raw is not None: defer.returnValue(self.transaction_class(raw)) @@ -142,8 +146,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): defer.returnValue(account.get_private_key(match['chain'], match['position'])) @defer.inlineCallbacks - def get_effective_amount_estimators(self, funding_accounts): - # type: (Iterable[baseaccount.BaseAccount]) -> defer.Deferred + def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]): estimators = [] for account in funding_accounts: utxos = yield account.get_unspent_outputs() @@ -151,12 +154,39 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): estimators.append(utxo.get_estimator(self)) defer.returnValue(estimators) + @defer.inlineCallbacks + def get_spendable_utxos(self, amount: int, funding_accounts): + yield self._utxo_reservation_lock.acquire() + try: + txos = yield self.get_effective_amount_estimators(funding_accounts) + selector = CoinSelector( + txos, amount, + self.get_input_output_fee( + self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32) + ) + ) + spendables = selector.select() + if spendables: + yield self.reserve_outputs(s.txo for s in spendables) + except Exception: + log.exception('Failed to get spendable utxos:') + raise + finally: + self._utxo_reservation_lock.release() + defer.returnValue(spendables) + + def reserve_outputs(self, txos): + return self.db.reserve_outputs(txos) + + def release_outputs(self, txos): + return self.db.release_outputs(txos) + @defer.inlineCallbacks def get_local_status(self, address): address_details = yield self.db.get_address(address) history = address_details['history'] or '' - hash = hashlib.sha256(history.encode()).digest() - defer.returnValue(hexlify(hash)) + h = sha256(history.encode()) + defer.returnValue(hexlify(h)) @defer.inlineCallbacks def get_local_history(self, address): @@ -203,7 +233,6 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): yield self.network.stop() yield self.db.stop() - @execute_serially @defer.inlineCallbacks def update_headers(self): while True: @@ -216,18 +245,19 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): @defer.inlineCallbacks def process_header(self, response): - header = response[0] - if self.update_headers.is_running: - return - if header['height'] == len(self.headers): - # New header from network directly connects after the last local header. - yield self.headers.connect(len(self.headers), unhexlify(header['hex'])) - self._on_header_controller.add(self.headers.height) - elif header['height'] > len(self.headers): - # New header is several heights ahead of local, do download instead. - yield self.update_headers() + yield self._header_processing_lock.acquire() + try: + header = response[0] + if header['height'] == len(self.headers): + # New header from network directly connects after the last local header. + yield self.headers.connect(len(self.headers), unhexlify(header['hex'])) + self._on_header_controller.add(self.headers.height) + elif header['height'] > len(self.headers): + # New header is several heights ahead of local, do download instead. + yield self.update_headers() + finally: + self._header_processing_lock.release() - @execute_serially def update_accounts(self): return defer.DeferredList([ self.update_account(a) for a in self.accounts @@ -274,7 +304,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): try: # see if we have a local copy of transaction, otherwise fetch it from server - raw, local_height, is_verified = yield self.db.get_transaction(hex_id) + raw, _, is_verified = yield self.db.get_transaction(hex_id) save_tx = None if raw is None: _raw = yield self.network.get_transaction(hex_id) @@ -294,15 +324,16 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): ''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history) ) - log.debug("{}: sync'ed tx {} for address: {}, height: {}, verified: {}".format( + log.debug( + "%s: sync'ed tx %s for address: %s, height: %s, verified: %s", self.get_id(), hex_id, address, remote_height, is_verified - )) + ) self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified)) - except Exception as e: + except Exception: log.exception('Failed to synchronize transaction:') - raise e + raise finally: lock.release() diff --git a/torba/manager.py b/torba/basemanager.py similarity index 89% rename from torba/manager.py rename to torba/basemanager.py index 4f7f6e823..7c18a16ef 100644 --- a/torba/manager.py +++ b/torba/basemanager.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Type +from typing import Type, MutableSequence, MutableMapping from twisted.internet import defer from torba.baseledger import BaseLedger, LedgerRegistry @@ -6,16 +6,16 @@ from torba.wallet import Wallet, WalletStorage from torba.constants import COIN -class WalletManager(object): +class BaseWalletManager: - def __init__(self, wallets=None, ledgers=None): - # type: (List[Wallet], Dict[Type[BaseLedger],BaseLedger]) -> None + def __init__(self, wallets: MutableSequence[Wallet] = None, + ledgers: MutableMapping[Type[BaseLedger], BaseLedger] = None) -> None: self.wallets = wallets or [] self.ledgers = ledgers or {} self.running = False @classmethod - def from_config(cls, config): # type: (Dict) -> WalletManager + def from_config(cls, config: dict) -> 'BaseWalletManager': manager = cls() for ledger_id, ledger_config in config.get('ledgers', {}).items(): manager.get_or_create_ledger(ledger_id, ledger_config) diff --git a/torba/basenetwork.py b/torba/basenetwork.py index cdc74c61b..c1f3926d9 100644 --- a/torba/basenetwork.py +++ b/torba/basenetwork.py @@ -21,6 +21,7 @@ class StratumClientProtocol(LineOnlyReceiver): self.request_id = 0 self.lookup_table = {} self.session = {} + self.network = None self.on_disconnected_controller = StreamController() self.on_disconnected = self.on_disconnected_controller.stream @@ -52,7 +53,7 @@ class StratumClientProtocol(LineOnlyReceiver): socket.SOL_TCP, socket.TCP_KEEPCNT, 5 # Failed keepalive probles before declaring other end dead ) - except Exception as err: + except Exception as err: # pylint: disable=broad-except # Supported only by the socket transport, # but there's really no better place in code to trigger this. log.warning("Error setting up socket: %s", err) @@ -61,7 +62,7 @@ class StratumClientProtocol(LineOnlyReceiver): self.on_disconnected_controller.add(True) def lineReceived(self, line): - log.debug('received: {}'.format(line)) + log.debug('received: %s', line) try: message = json.loads(line) @@ -82,7 +83,7 @@ class StratumClientProtocol(LineOnlyReceiver): controller = self.network.subscription_controllers[message['method']] controller.add(message.get('params')) else: - log.warning("Cannot handle message '%s'" % line) + log.warning("Cannot handle message '%s'", line) def rpc(self, method, *args): message_id = self._get_id() @@ -91,7 +92,7 @@ class StratumClientProtocol(LineOnlyReceiver): 'method': method, 'params': args }) - log.debug('sent: {}'.format(message)) + log.debug('sent: %s', message) self.sendLine(message.encode('latin-1')) d = self.lookup_table[message_id] = defer.Deferred() return d @@ -138,20 +139,21 @@ class BaseNetwork: @defer.inlineCallbacks def start(self): for server in cycle(self.config['default_servers']): - endpoint = clientFromString(reactor, 'tcp:{}:{}'.format(*server)) - log.debug("Attempting connection to SPV wallet server: {}:{}".format(*server)) + connection_string = 'tcp:{}:{}'.format(*server) + endpoint = clientFromString(reactor, connection_string) + log.debug("Attempting connection to SPV wallet server: %s", connection_string) self.service = ClientService(endpoint, StratumClientFactory(self)) self.service.startService() try: self.client = yield self.service.whenConnected(failAfterFailures=2) yield self.ensure_server_version() - log.info("Successfully connected to SPV wallet server: {}:{}".format(*server)) + log.info("Successfully connected to SPV wallet server: %s", connection_string) self._on_connected_controller.add(True) yield self.client.on_disconnected.first except CancelledError: return - except Exception: - log.exception("Connecting to {}:{} raised an exception:".format(*server)) + except Exception: # pylint: disable=broad-except + log.exception("Connecting to %s raised an exception:", connection_string) finally: self.client = None if not self.running: diff --git a/torba/basescript.py b/torba/basescript.py index 4cacf6e47..da3527477 100644 --- a/torba/basescript.py +++ b/torba/basescript.py @@ -1,6 +1,7 @@ from itertools import chain from binascii import hexlify from collections import namedtuple +from typing import List from torba.bcd_data_stream import BCDataStream from torba.util import subclass_tuple @@ -25,17 +26,21 @@ OP_DROP = 0x75 # template matching opcodes (not real opcodes) # base class for PUSH_DATA related opcodes +# pylint: disable=invalid-name PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name') # opcode for variable length strings +# pylint: disable=invalid-name PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP) # opcode for variable number of variable length strings +# pylint: disable=invalid-name PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP) # opcode with embedded subscript parsing +# pylint: disable=invalid-name PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template') def is_push_data_opcode(opcode): - return isinstance(opcode, PUSH_DATA_OP) or isinstance(opcode, PUSH_SUBSCRIPT) + return isinstance(opcode, (PUSH_DATA_OP, PUSH_SUBSCRIPT)) def is_push_data_token(token): @@ -61,15 +66,15 @@ def push_data(data): def read_data(token, stream): if token < OP_PUSHDATA1: return stream.read(token) - elif token == OP_PUSHDATA1: + if token == OP_PUSHDATA1: return stream.read(stream.read_uint8()) - elif token == OP_PUSHDATA2: + if token == OP_PUSHDATA2: return stream.read(stream.read_uint16()) - else: - return stream.read(stream.read_uint32()) + return stream.read(stream.read_uint32()) # opcode for OP_1 - OP_16 +# pylint: disable=invalid-name SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name') @@ -233,7 +238,7 @@ class Parser: raise ParseError("Not a push single or subscript: {}".format(opcode)) -class Template(object): +class Template: __slots__ = 'name', 'opcodes' @@ -264,11 +269,11 @@ class Template(object): return source.get_bytes() -class Script(object): +class Script: __slots__ = 'source', 'template', 'values' - templates = [] + templates: List[Template] = [] def __init__(self, source=None, template=None, values=None, template_hint=None): self.source = source diff --git a/torba/basetransaction.py b/torba/basetransaction.py index 75c5cb45c..0f326dfa2 100644 --- a/torba/basetransaction.py +++ b/torba/basetransaction.py @@ -1,29 +1,30 @@ -import six import logging -from typing import List, Iterable +import typing +from typing import List, Iterable, Optional from binascii import hexlify from twisted.internet import defer -import torba.baseaccount -import torba.baseledger from torba.basescript import BaseInputScript, BaseOutputScript -from torba.coinselection import CoinSelector +from torba.baseaccount import BaseAccount from torba.constants import COIN, NULL_HASH32 from torba.bcd_data_stream import BCDataStream -from torba.hash import sha256, TXRef, TXRefImmutable, TXORef +from torba.hash import sha256, TXRef, TXRefImmutable from torba.util import ReadOnlyList +if typing.TYPE_CHECKING: + from torba import baseledger + log = logging.getLogger() class TXRefMutable(TXRef): - __slots__ = 'tx', + __slots__ = ('tx',) - def __init__(self, tx): - super(TXRefMutable, self).__init__() + def __init__(self, tx: 'BaseTransaction') -> None: + super().__init__() self.tx = tx @property @@ -43,12 +44,35 @@ class TXRefMutable(TXRef): self._hash = None +class TXORef: + + __slots__ = 'tx_ref', 'position' + + def __init__(self, tx_ref: TXRef, position: int) -> None: + self.tx_ref = tx_ref + self.position = position + + @property + def id(self): + return '{}:{}'.format(self.tx_ref.id, self.position) + + @property + def is_null(self): + return self.tx_ref.is_null + + @property + def txo(self) -> Optional['BaseOutput']: + return None + + class TXORefResolvable(TXORef): - __slots__ = '_txo', + __slots__ = ('_txo',) - def __init__(self, txo): - super(TXORefResolvable, self).__init__(txo.tx_ref, txo.position) + def __init__(self, txo: 'BaseOutput') -> None: + assert txo.tx_ref is not None + assert txo.position is not None + super().__init__(txo.tx_ref, txo.position) self._txo = txo @property @@ -56,23 +80,23 @@ class TXORefResolvable(TXORef): return self._txo -class InputOutput(object): +class InputOutput: __slots__ = 'tx_ref', 'position' - def __init__(self, tx_ref=None, position=None): - self.tx_ref = tx_ref # type: TXRef - self.position = position # type: int + def __init__(self, tx_ref: TXRef = None, position: int = None) -> None: + self.tx_ref = tx_ref + self.position = position @property - def size(self): + def size(self) -> int: """ Size of this input / output in bytes. """ stream = BCDataStream() self.serialize_to(stream) return len(stream.get_bytes()) - def serialize_to(self, stream): - raise NotImplemented + def serialize_to(self, stream, alternate_script=None): + raise NotImplementedError class BaseInput(InputOutput): @@ -84,27 +108,27 @@ class BaseInput(InputOutput): __slots__ = 'txo_ref', 'sequence', 'coinbase', 'script' - def __init__(self, txo_ref, script, sequence=0xFFFFFFFF, tx_ref=None, position=None): - # type: (TXORef, BaseInputScript, int, TXRef, int) -> None - super(BaseInput, self).__init__(tx_ref, position) + def __init__(self, txo_ref: TXORef, script: BaseInputScript, sequence: int = 0xFFFFFFFF, + tx_ref: TXRef = None, position: int = None) -> None: + super().__init__(tx_ref, position) self.txo_ref = txo_ref self.sequence = sequence self.coinbase = script if txo_ref.is_null else None - self.script = script if not txo_ref.is_null else None # type: BaseInputScript + self.script = script if not txo_ref.is_null else None @property def is_coinbase(self): return self.coinbase is not None @classmethod - def spend(cls, txo): # type: (BaseOutput) -> BaseInput + def spend(cls, txo: 'BaseOutput') -> 'BaseInput': """ Create an input to spend the output.""" assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.' script = cls.script_class.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY) return cls(txo.ref, script) @property - def amount(self): + def amount(self) -> int: """ Amount this input adds to the transaction. """ if self.txo_ref.txo is None: raise ValueError('Cannot resolve output to get amount.') @@ -135,15 +159,15 @@ class BaseInput(InputOutput): stream.write_uint32(self.sequence) -class BaseOutputEffectiveAmountEstimator(object): +class BaseOutputEffectiveAmountEstimator: __slots__ = 'txo', 'txi', 'fee', 'effective_amount' - def __init__(self, ledger, txo): # type: (torba.baseledger.BaseLedger, BaseOutput) -> None + def __init__(self, ledger: 'baseledger.BaseLedger', txo: 'BaseOutput') -> None: self.txo = txo self.txi = ledger.transaction_class.input_class.spend(txo) - self.fee = ledger.get_input_output_fee(self.txi) - self.effective_amount = txo.amount - self.fee + self.fee: int = ledger.get_input_output_fee(self.txi) + self.effective_amount: int = txo.amount - self.fee def __lt__(self, other): return self.effective_amount < other.effective_amount @@ -156,9 +180,9 @@ class BaseOutput(InputOutput): __slots__ = 'amount', 'script' - def __init__(self, amount, script, tx_ref=None, position=None): - # type: (int, BaseOutputScript, TXRef, int) -> None - super(BaseOutput, self).__init__(tx_ref, position) + def __init__(self, amount: int, script: BaseOutputScript, + tx_ref: TXRef = None, position: int = None) -> None: + super().__init__(tx_ref, position) self.amount = amount self.script = script @@ -184,7 +208,7 @@ class BaseOutput(InputOutput): script=cls.script_class(stream.read_string()) ) - def serialize_to(self, stream): + def serialize_to(self, stream, alternate_script=None): stream.write_uint64(self.amount) stream.write_string(self.script.source) @@ -194,7 +218,7 @@ class BaseTransaction: input_class = BaseInput output_class = BaseOutput - def __init__(self, raw=None, version=1, locktime=0): + def __init__(self, raw=None, version=1, locktime=0) -> None: self._raw = raw self.ref = TXRefMutable(self) self.version = version # type: int @@ -230,8 +254,7 @@ class BaseTransaction: def outputs(self): # type: () -> ReadOnlyList[BaseOutput] return ReadOnlyList(self._outputs) - def _add(self, new_ios, existing_ios): - # type: (List[InputOutput], List[InputOutput]) -> BaseTransaction + def _add(self, new_ios: Iterable[InputOutput], existing_ios: List) -> 'BaseTransaction': for txio in new_ios: txio.tx_ref = self.ref txio.position = len(existing_ios) @@ -239,28 +262,28 @@ class BaseTransaction: self._reset() return self - def add_inputs(self, inputs): # type: (List[BaseInput]) -> BaseTransaction + def add_inputs(self, inputs: Iterable[BaseInput]) -> 'BaseTransaction': return self._add(inputs, self._inputs) - def add_outputs(self, outputs): # type: (List[BaseOutput]) -> BaseTransaction + def add_outputs(self, outputs: Iterable[BaseOutput]) -> 'BaseTransaction': return self._add(outputs, self._outputs) @property - def fee(self): # type: () -> int + def fee(self) -> int: """ Fee that will actually be paid.""" return self.input_sum - self.output_sum @property - def size(self): # type: () -> int + def size(self) -> int: """ Size in bytes of the entire transaction. """ return len(self.raw) @property - def base_size(self): # type: () -> int + def base_size(self) -> int: """ Size in bytes of transaction meta data and all outputs; without inputs. """ return len(self._serialize(with_inputs=False)) - def _serialize(self, with_inputs=True): # type: (bool) -> bytes + def _serialize(self, with_inputs: bool = True) -> bytes: stream = BCDataStream() stream.write_uint32(self.version) if with_inputs: @@ -273,12 +296,13 @@ class BaseTransaction: stream.write_uint32(self.locktime) return stream.get_bytes() - def _serialize_for_signature(self, signing_input): # type: (int) -> bytes + def _serialize_for_signature(self, signing_input: int) -> bytes: stream = BCDataStream() stream.write_uint32(self.version) stream.write_compact_size(len(self._inputs)) for i, txin in enumerate(self._inputs): if signing_input == i: + assert txin.txo_ref.txo is not None txin.serialize_to(stream, txin.txo_ref.txo.script.source) else: txin.serialize_to(stream, b'') @@ -304,8 +328,9 @@ class BaseTransaction: self.locktime = stream.read_uint32() @classmethod - def ensure_all_have_same_ledger(cls, funding_accounts, change_account=None): - # type: (Iterable[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> torba.baseledger.BaseLedger + def ensure_all_have_same_ledger( + cls, funding_accounts: Iterable[BaseAccount], change_account: BaseAccount = None)\ + -> 'baseledger.BaseLedger': ledger = None for account in funding_accounts: if ledger is None: @@ -316,33 +341,24 @@ class BaseTransaction: ) if change_account is not None and change_account.ledger != ledger: raise ValueError('Change account must use same ledger as funding accounts.') + if ledger is None: + raise ValueError('No ledger found.') return ledger @classmethod @defer.inlineCallbacks - def pay(cls, outputs, funding_accounts, change_account, reserve_outputs=True): - # type: (List[BaseOutput], List[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> defer.Deferred + def pay(cls, outputs: Iterable[BaseOutput], funding_accounts: Iterable[BaseAccount], + change_account: BaseAccount): """ Efficiently spend utxos from funding_accounts to cover the new outputs. """ tx = cls().add_outputs(outputs) ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) amount = tx.output_sum + ledger.get_transaction_base_fee(tx) - txos = yield ledger.get_effective_amount_estimators(funding_accounts) - selector = CoinSelector( - txos, amount, - ledger.get_input_output_fee( - cls.output_class.pay_pubkey_hash(COIN, NULL_HASH32) - ) - ) + spendables = yield ledger.get_spendable_utxos(amount, funding_accounts) - spendables = selector.select() if not spendables: raise ValueError('Not enough funds to cover this transaction.') - reserved_outputs = [s.txo.id for s in spendables] - if reserve_outputs: - yield ledger.db.reserve_spent_outputs(reserved_outputs) - try: spent_sum = sum(s.effective_amount for s in spendables) if spent_sum > amount: @@ -351,30 +367,25 @@ class BaseTransaction: change_amount = spent_sum - amount tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) - tx.add_inputs([s.txi for s in spendables]) + tx.add_inputs(s.txi for s in spendables) yield tx.sign(funding_accounts) - except Exception: - if reserve_outputs: - yield ledger.db.release_reserved_outputs(reserved_outputs) - raise + except Exception as e: + log.exception('Failed to synchronize transaction:') + yield ledger.release_outputs(s.txo for s in spendables) + raise e defer.returnValue(tx) @classmethod @defer.inlineCallbacks - def liquidate(cls, assets, funding_accounts, change_account, reserve_outputs=True): + def liquidate(cls, assets, funding_accounts, change_account): """ Spend assets (utxos) supplementing with funding_accounts if fee is higher than asset value. """ - tx = cls().add_inputs([ cls.input_class.spend(utxo) for utxo in assets ]) ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) - - reserved_outputs = [utxo.id for utxo in assets] - if reserve_outputs: - yield ledger.db.reserve_spent_outputs(reserved_outputs) - + yield ledger.reserve_outputs(assets) try: cost_of_change = ( ledger.get_transaction_base_fee(tx) + @@ -386,41 +397,35 @@ class BaseTransaction: change_hash160 = change_account.ledger.address_to_hash160(change_address) change_amount = liquidated_total - cost_of_change tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) - yield tx.sign(funding_accounts) - except Exception: - if reserve_outputs: - yield ledger.db.release_reserved_outputs(reserved_outputs) + yield ledger.release_outputs(assets) raise - defer.returnValue(tx) - def signature_hash_type(self, hash_type): + @staticmethod + def signature_hash_type(hash_type): return hash_type @defer.inlineCallbacks - def sign(self, funding_accounts): # type: (Iterable[torba.baseaccount.BaseAccount]) -> BaseTransaction + def sign(self, funding_accounts: Iterable[BaseAccount]) -> defer.Deferred: ledger = self.ensure_all_have_same_ledger(funding_accounts) for i, txi in enumerate(self._inputs): + assert txi.script is not None + assert txi.txo_ref.txo is not None txo_script = txi.txo_ref.txo.script if txo_script.is_pay_pubkey_hash: address = ledger.hash160_to_address(txo_script.values['pubkey_hash']) private_key = yield ledger.get_private_key_for_address(address) tx = self._serialize_for_signature(i) txi.script.values['signature'] = \ - private_key.sign(tx) + six.int2byte(self.signature_hash_type(1)) + private_key.sign(tx) + bytes((self.signature_hash_type(1),)) txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes txi.script.generate() else: raise NotImplementedError("Don't know how to spend this output.") self._reset() - def sort(self): - # See https://github.com/kristovatlas/rfc/blob/master/bips/bip-li01.mediawiki - self._inputs.sort(key=lambda i: (i['prevout_hash'], i['prevout_n'])) - self._outputs.sort(key=lambda o: (o[2], pay_script(o[0], o[1]))) - @property def input_sum(self): return sum(i.amount for i in self.inputs) diff --git a/torba/bcd_data_stream.py b/torba/bcd_data_stream.py index ab06cdecc..1c04be005 100644 --- a/torba/bcd_data_stream.py +++ b/torba/bcd_data_stream.py @@ -35,9 +35,9 @@ class BCDataStream: return size if size == 253: return self.read_uint16() - elif size == 254: + if size == 254: return self.read_uint32() - elif size == 255: + if size == 255: return self.read_uint64() def write_compact_size(self, size): @@ -70,7 +70,7 @@ class BCDataStream: def _read_struct(self, fmt): value = self.read(fmt.size) - if len(value) > 0: + if value: return fmt.unpack(value)[0] def read_int8(self): diff --git a/torba/bip32.py b/torba/bip32.py index 5bb06cd1e..7ef76aea6 100644 --- a/torba/bip32.py +++ b/torba/bip32.py @@ -10,7 +10,6 @@ import struct import hashlib -from six import int2byte, byte2int, indexbytes import ecdsa import ecdsa.ellipticcurve as EC @@ -24,7 +23,7 @@ class DerivationError(Exception): """ Raised when an invalid derivation occurs. """ -class _KeyBase(object): +class _KeyBase: """ A BIP32 Key, public or private. """ CURVE = ecdsa.SECP256k1 @@ -63,17 +62,23 @@ class _KeyBase(object): if len(raw_serkey) != 33: raise ValueError('raw_serkey must have length 33') - return (ver_bytes + int2byte(self.depth) + return (ver_bytes + bytes((self.depth,)) + self.parent_fingerprint() + struct.pack('>I', self.n) + self.chain_code + raw_serkey) + def identifier(self): + raise NotImplementedError + + def extended_key(self): + raise NotImplementedError + def fingerprint(self): """ Return the key's fingerprint as 4 bytes. """ return self.identifier()[:4] def parent_fingerprint(self): """ Return the parent key's fingerprint as 4 bytes. """ - return self.parent.fingerprint() if self.parent else int2byte(0)*4 + return self.parent.fingerprint() if self.parent else bytes((0,)*4) def extended_key_string(self): """ Return an extended key as a base58 string. """ @@ -84,7 +89,7 @@ class PubKey(_KeyBase): """ A BIP32 public key. """ def __init__(self, ledger, pubkey, chain_code, n, depth, parent=None): - super(PubKey, self).__init__(ledger, chain_code, n, depth, parent) + super().__init__(ledger, chain_code, n, depth, parent) if isinstance(pubkey, ecdsa.VerifyingKey): self.verifying_key = pubkey else: @@ -97,16 +102,16 @@ class PubKey(_KeyBase): raise TypeError('pubkey must be raw bytes') if len(pubkey) != 33: raise ValueError('pubkey must be 33 bytes') - if indexbytes(pubkey, 0) not in (2, 3): + if pubkey[0] not in (2, 3): raise ValueError('invalid pubkey prefix byte') curve = cls.CURVE.curve - is_odd = indexbytes(pubkey, 0) == 3 + is_odd = pubkey[0] == 3 x = bytes_to_int(pubkey[1:]) # p is the finite field order - a, b, p = curve.a(), curve.b(), curve.p() - y2 = pow(x, 3, p) + b + a, b, p = curve.a(), curve.b(), curve.p() # pylint: disable=invalid-name + y2 = pow(x, 3, p) + b # pylint: disable=invalid-name assert a == 0 # Otherwise y2 += a * pow(x, 2, p) y = NT.square_root_mod_prime(y2 % p, p) if bool(y & 1) != is_odd: @@ -119,7 +124,7 @@ class PubKey(_KeyBase): def pubkey_bytes(self): """ Return the compressed public key as 33 bytes. """ point = self.verifying_key.pubkey.point - prefix = int2byte(2 + (point.y() & 1)) + prefix = bytes((2 + (point.y() & 1),)) padded_bytes = _exponent_to_bytes(point.x()) return prefix + padded_bytes @@ -137,10 +142,10 @@ class PubKey(_KeyBase): raise ValueError('invalid BIP32 public key child number') msg = self.pubkey_bytes + struct.pack('>I', n) - L, R = self._hmac_sha512(msg) + L, R = self._hmac_sha512(msg) # pylint: disable=invalid-name curve = self.CURVE - L = bytes_to_int(L) + L = bytes_to_int(L) # pylint: disable=invalid-name if L >= curve.order: raise DerivationError @@ -172,7 +177,7 @@ class LowSValueSigningKey(ecdsa.SigningKey): def sign_number(self, number, entropy=None, k=None): order = self.privkey.order - r, s = ecdsa.SigningKey.sign_number(self, number, entropy, k) + r, s = ecdsa.SigningKey.sign_number(self, number, entropy, k) # pylint: disable=invalid-name if s > order / 2: s = order - s return r, s @@ -184,7 +189,7 @@ class PrivateKey(_KeyBase): HARDENED = 1 << 31 def __init__(self, ledger, privkey, chain_code, n, depth, parent=None): - super(PrivateKey, self).__init__(ledger, chain_code, n, depth, parent) + super().__init__(ledger, chain_code, n, depth, parent) if isinstance(privkey, ecdsa.SigningKey): self.signing_key = privkey else: @@ -254,10 +259,10 @@ class PrivateKey(_KeyBase): serkey = self.public_key.pubkey_bytes msg = serkey + struct.pack('>I', n) - L, R = self._hmac_sha512(msg) + L, R = self._hmac_sha512(msg) # pylint: disable=invalid-name curve = self.CURVE - L = bytes_to_int(L) + L = bytes_to_int(L) # pylint: disable=invalid-name exponent = (L + bytes_to_int(self.private_key_bytes)) % curve.order if exponent == 0 or L >= curve.order: raise DerivationError @@ -286,7 +291,7 @@ class PrivateKey(_KeyBase): def _exponent_to_bytes(exponent): """Convert an exponent to 32 big-endian bytes""" - return (int2byte(0)*32 + int_to_bytes(exponent))[-32:] + return (bytes((0,)*32) + int_to_bytes(exponent))[-32:] def _from_extended_key(ledger, ekey): @@ -296,8 +301,8 @@ def _from_extended_key(ledger, ekey): if len(ekey) != 78: raise ValueError('extended key must have length 78') - depth = indexbytes(ekey, 4) - fingerprint = ekey[5:9] # Not used + depth = ekey[4] + # fingerprint = ekey[5:9] n, = struct.unpack('>I', ekey[9:13]) chain_code = ekey[13:45] @@ -305,7 +310,7 @@ def _from_extended_key(ledger, ekey): pubkey = ekey[45:] key = PubKey(ledger, pubkey, chain_code, n, depth) elif ekey[:4] == ledger.extended_private_key_prefix: - if indexbytes(ekey, 45) != 0: + if ekey[45] != 0: raise ValueError('invalid extended private key prefix byte') privkey = ekey[46:] key = PrivateKey(ledger, privkey, chain_code, n, depth) diff --git a/torba/coin/__init__.py b/torba/coin/__init__.py index 69e3be50d..97b69ed6e 100644 --- a/torba/coin/__init__.py +++ b/torba/coin/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) +__path__: str = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/torba/coin/bitcoincash.py b/torba/coin/bitcoincash.py index ca9a76a41..6678af616 100644 --- a/torba/coin/bitcoincash.py +++ b/torba/coin/bitcoincash.py @@ -6,7 +6,6 @@ __node_url__ = ( ) __electrumx__ = 'electrumx.lib.coins.BitcoinCashRegtest' -from six import int2byte from binascii import unhexlify from torba.baseledger import BaseLedger from torba.baseheader import BaseHeaders @@ -26,8 +25,8 @@ class MainNetLedger(BaseLedger): transaction_class = Transaction - pubkey_address_prefix = int2byte(0x00) - script_address_prefix = int2byte(0x05) + pubkey_address_prefix = bytes((0,)) + script_address_prefix = bytes((5,)) extended_public_key_prefix = unhexlify('0488b21e') extended_private_key_prefix = unhexlify('0488ade4') @@ -42,8 +41,8 @@ class RegTestLedger(MainNetLedger): headers_class = UnverifiedHeaders network_name = 'regtest' - pubkey_address_prefix = int2byte(111) - script_address_prefix = int2byte(196) + pubkey_address_prefix = bytes((111,)) + script_address_prefix = bytes((196,)) extended_public_key_prefix = unhexlify('043587cf') extended_private_key_prefix = unhexlify('04358394') diff --git a/torba/coin/bitcoinsegwit.py b/torba/coin/bitcoinsegwit.py index 68301a57a..fa293e072 100644 --- a/torba/coin/bitcoinsegwit.py +++ b/torba/coin/bitcoinsegwit.py @@ -6,7 +6,6 @@ __node_url__ = ( ) __electrumx__ = 'electrumx.lib.coins.BitcoinSegwitRegtest' -from six import int2byte from binascii import unhexlify from torba.baseledger import BaseLedger from torba.baseheader import BaseHeaders @@ -17,8 +16,8 @@ class MainNetLedger(BaseLedger): symbol = 'BTC' network_name = 'mainnet' - pubkey_address_prefix = int2byte(0x00) - script_address_prefix = int2byte(0x05) + pubkey_address_prefix = bytes((0,)) + script_address_prefix = bytes((5,)) extended_public_key_prefix = unhexlify('0488b21e') extended_private_key_prefix = unhexlify('0488ade4') @@ -33,8 +32,8 @@ class RegTestLedger(MainNetLedger): headers_class = UnverifiedHeaders network_name = 'regtest' - pubkey_address_prefix = int2byte(111) - script_address_prefix = int2byte(196) + pubkey_address_prefix = bytes((111,)) + script_address_prefix = bytes((196,)) extended_public_key_prefix = unhexlify('043587cf') extended_private_key_prefix = unhexlify('04358394') diff --git a/torba/coinselection.py b/torba/coinselection.py index 0c1adede0..bf3709636 100644 --- a/torba/coinselection.py +++ b/torba/coinselection.py @@ -1,16 +1,15 @@ -import six from random import Random from typing import List -import torba +from torba import basetransaction MAXIMUM_TRIES = 100000 class CoinSelector: - def __init__(self, txos, target, cost_of_change, seed=None): - # type: (List[torba.basetransaction.BaseOutputAmountEstimator], int, int, str) -> None + def __init__(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], + target: int, cost_of_change: int, seed: str = None) -> None: self.txos = txos self.target = target self.cost_of_change = cost_of_change @@ -18,17 +17,17 @@ class CoinSelector: self.tries = 0 self.available = sum(c.effective_amount for c in self.txos) self.random = Random(seed) - if six.PY3 and seed is not None: + if seed is not None: self.random.seed(seed, version=1) - def select(self): # type: () -> List[torba.basetransaction.BaseOutputAmountEstimator] + def select(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: if not self.txos: - return + return [] if self.target > self.available: - return + return [] return self.branch_and_bound() or self.single_random_draw() - def branch_and_bound(self): # type: () -> List[torba.basetransaction.BaseOutputAmountEstimator] + def branch_and_bound(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: # see bitcoin implementation for more info: # https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp @@ -36,9 +35,9 @@ class CoinSelector: current_value = 0 current_available_value = self.available - current_selection = [] + current_selection: List[bool] = [] best_waste = self.cost_of_change - best_selection = [] + best_selection: List[bool] = [] while self.tries < MAXIMUM_TRIES: self.tries += 1 @@ -70,7 +69,7 @@ class CoinSelector: utxo = self.txos[len(current_selection)] current_available_value -= utxo.effective_amount previous_utxo = self.txos[len(current_selection) - 1] if current_selection else None - if current_selection and not current_selection[-1] and \ + if current_selection and not current_selection[-1] and previous_utxo and \ utxo.effective_amount == previous_utxo.effective_amount and \ utxo.fee == previous_utxo.fee: current_selection.append(False) @@ -84,7 +83,9 @@ class CoinSelector: self.txos[i] for i, include in enumerate(best_selection) if include ] - def single_random_draw(self): # type: () -> List[torba.basetransaction.BaseOutputAmountEstimator] + return [] + + def single_random_draw(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: self.random.shuffle(self.txos, self.random.random) selection = [] amount = 0 @@ -93,3 +94,4 @@ class CoinSelector: amount += coin.effective_amount if amount >= self.target+self.cost_of_change: return selection + return [] diff --git a/torba/hash.py b/torba/hash.py index af1f74897..cf8f2f047 100644 --- a/torba/hash.py +++ b/torba/hash.py @@ -9,7 +9,6 @@ """ Cryptography hash functions and related classes. """ import os -import six import base64 import hashlib import hmac @@ -22,13 +21,8 @@ from cryptography.hazmat.backends import default_backend from torba.util import bytes_to_int, int_to_bytes from torba.constants import NULL_HASH32 -_sha256 = hashlib.sha256 -_sha512 = hashlib.sha512 -_new_hash = hashlib.new -_new_hmac = hmac.new - -class TXRef(object): +class TXRef: __slots__ = '_id', '_hash' @@ -68,50 +62,29 @@ class TXRefImmutable(TXRef): return ref -class TXORef(object): - - __slots__ = 'tx_ref', 'position' - - def __init__(self, tx_ref, position): # type: (TXRef, int) -> None - self.tx_ref = tx_ref - self.position = position - - @property - def id(self): - return '{}:{}'.format(self.tx_ref.id, self.position) - - @property - def is_null(self): - return self.tx_ref.is_null - - @property - def txo(self): - return None - - def sha256(x): """ Simple wrapper of hashlib sha256. """ - return _sha256(x).digest() + return hashlib.sha256(x).digest() def sha512(x): """ Simple wrapper of hashlib sha512. """ - return _sha512(x).digest() + return hashlib.sha512(x).digest() def ripemd160(x): """ Simple wrapper of hashlib ripemd160. """ - h = _new_hash('ripemd160') + h = hashlib.new('ripemd160') h.update(x) return h.digest() def pow_hash(x): - r = sha512(double_sha256(x)) - r1 = ripemd160(r[:len(r) // 2]) - r2 = ripemd160(r[len(r) // 2:]) - r3 = double_sha256(r1 + r2) - return r3 + h = sha512(double_sha256(x)) + return double_sha256( + ripemd160(h[:len(h) // 2]) + + ripemd160(h[len(h) // 2:]) + ) def double_sha256(x): @@ -121,7 +94,7 @@ def double_sha256(x): def hmac_sha512(key, msg): """ Use SHA-512 to provide an HMAC. """ - return _new_hmac(key, msg, _sha512).digest() + return hmac.new(key, msg, hashlib.sha512).digest() def hash160(x): @@ -165,7 +138,7 @@ class Base58Error(Exception): """ Exception used for Base58 errors. """ -class Base58(object): +class Base58: """ Class providing base 58 functionality. """ chars = u'123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz' @@ -207,7 +180,7 @@ class Base58(object): break count += 1 if count: - result = six.int2byte(0) * count + result + result = bytes((0,)) * count + result return result diff --git a/torba/mnemonic.py b/torba/mnemonic.py index df1f5a238..3cabcf6cb 100644 --- a/torba/mnemonic.py +++ b/torba/mnemonic.py @@ -56,7 +56,7 @@ CJK_INTERVALS = [ def is_cjk(c): n = ord(c) - for start, end, name in CJK_INTERVALS: + for start, end, _ in CJK_INTERVALS: if start <= n <= end: return True return False @@ -93,7 +93,7 @@ def load_words(filename): return words -file_names = { +FILE_NAMES = { 'en': 'english.txt', 'es': 'spanish.txt', 'ja': 'japanese.txt', @@ -102,20 +102,22 @@ file_names = { } -class Mnemonic(object): +class Mnemonic: # Seed derivation no longer follows BIP39 # Mnemonic phrase uses a hash based checksum, instead of a words-dependent checksum def __init__(self, lang='en'): - filename = file_names.get(lang, 'english.txt') + filename = FILE_NAMES.get(lang, 'english.txt') self.words = load_words(filename) - @classmethod - def mnemonic_to_seed(self, mnemonic, passphrase=u''): - PBKDF2_ROUNDS = 2048 + @staticmethod + def mnemonic_to_seed(mnemonic, passphrase=u''): + pbkdf2_rounds = 2048 mnemonic = normalize_text(mnemonic) passphrase = normalize_text(passphrase) - return pbkdf2.PBKDF2(mnemonic, passphrase, iterations=PBKDF2_ROUNDS, macmodule=hmac, digestmodule=hashlib.sha512).read(64) + return pbkdf2.PBKDF2( + mnemonic, passphrase, iterations=pbkdf2_rounds, macmodule=hmac, digestmodule=hashlib.sha512 + ).read(64) def mnemonic_encode(self, i): n = len(self.words) @@ -131,8 +133,8 @@ class Mnemonic(object): words = seed.split() i = 0 while words: - w = words.pop() - k = self.words.index(w) + word = words.pop() + k = self.words.index(word) i = i*n + k return i diff --git a/torba/stream.py b/torba/stream.py index 6438b7184..79fbba796 100644 --- a/torba/stream.py +++ b/torba/stream.py @@ -1,27 +1,7 @@ -import six -from twisted.internet.defer import Deferred, DeferredLock, maybeDeferred, inlineCallbacks +import asyncio +from twisted.internet.defer import Deferred from twisted.python.failure import Failure -if six.PY3: - import asyncio - - -def execute_serially(f): - _lock = DeferredLock() - - @inlineCallbacks - def allow_only_one_at_a_time(*args, **kwargs): - yield _lock.acquire() - allow_only_one_at_a_time.is_running = True - try: - yield maybeDeferred(f, *args, **kwargs) - finally: - allow_only_one_at_a_time.is_running = False - _lock.release() - - allow_only_one_at_a_time.is_running = False - return allow_only_one_at_a_time - class BroadcastSubscription: @@ -76,10 +56,10 @@ class StreamController: @property def _iterate_subscriptions(self): - next = self._first_subscription - while next is not None: - subscription = next - next = next._next + next_sub = self._first_subscription + while next_sub is not None: + subscription = next_sub + next_sub = next_sub._next yield subscription def add(self, event): @@ -96,15 +76,15 @@ class StreamController: def _cancel(self, subscription): previous = subscription._previous - next = subscription._next + next_sub = subscription._next if previous is None: - self._first_subscription = next + self._first_subscription = next_sub else: - previous._next = next - if next is None: + previous._next = next_sub + if next_sub is None: self._last_subscription = previous else: - next._previous = previous + next_sub._previous = previous subscription._next = subscription._previous = subscription def _listen(self, on_data, on_error, on_done): diff --git a/torba/util.py b/torba/util.py index 9683a8dfe..00ee6ba4d 100644 --- a/torba/util.py +++ b/torba/util.py @@ -1,20 +1,19 @@ from binascii import unhexlify, hexlify -from collections import Sequence -from typing import TypeVar, Generic +from typing import TypeVar, Sequence T = TypeVar('T') -class ReadOnlyList(Sequence, Generic[T]): +class ReadOnlyList(Sequence[T]): def __init__(self, lst): self.lst = lst - def __getitem__(self, key): # type: (int) -> T + def __getitem__(self, key): return self.lst[key] - def __len__(self): + def __len__(self) -> int: return len(self.lst) @@ -22,13 +21,13 @@ def subclass_tuple(name, base): return type(name, (base,), {'__slots__': ()}) -class cachedproperty(object): +class cachedproperty: def __init__(self, f): self.f = f - def __get__(self, obj, type): - obj = obj or type + def __get__(self, obj, objtype): + obj = obj or objtype value = self.f(obj) setattr(obj, self.f.__name__, value) return value @@ -42,8 +41,8 @@ def bytes_to_int(be_bytes): def int_to_bytes(value): """ Converts an integer to a big-endian sequence of bytes. """ length = (value.bit_length() + 7) // 8 - h = '%x' % value - return unhexlify(('0' * (len(h) % 2) + h).zfill(length * 2)) + s = '%x' % value + return unhexlify(('0' * (len(s) % 2) + s).zfill(length * 2)) def rev_hex(s): @@ -56,8 +55,8 @@ def int_to_hex(i, length=1): return rev_hex(s) -def hex_to_int(s): - return int(b'0x' + hexlify(s[::-1]), 16) +def hex_to_int(x): + return int(b'0x' + hexlify(x[::-1]), 16) def hash_encode(x): diff --git a/torba/wallet.py b/torba/wallet.py index 1fe486396..ca6d48fab 100644 --- a/torba/wallet.py +++ b/torba/wallet.py @@ -1,10 +1,13 @@ import stat import json import os -from typing import List +import typing +from typing import Sequence, MutableSequence -import torba.baseaccount -import torba.baseledger +if typing.TYPE_CHECKING: + from torba import baseaccount + from torba import baseledger + from torba import basemanager class Wallet: @@ -14,24 +17,24 @@ class Wallet: by physical files on the filesystem. """ - def __init__(self, name='Wallet', accounts=None, storage=None): - # type: (str, List[torba.baseaccount.BaseAccount], WalletStorage) -> None + def __init__(self, name: str = 'Wallet', accounts: MutableSequence['baseaccount.BaseAccount'] = None, + storage: 'WalletStorage' = None) -> None: self.name = name - self.accounts = accounts or [] # type: List[torba.baseaccount.BaseAccount] + self.accounts = accounts or [] self.storage = storage or WalletStorage() - def generate_account(self, ledger): - # type: (torba.baseledger.BaseLedger) -> torba.baseaccount.BaseAccount + def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount': account = ledger.account_class.generate(ledger, u'torba') self.accounts.append(account) return account @classmethod - def from_storage(cls, storage, manager): # type: (WalletStorage, 'WalletManager') -> Wallet + def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet': json_dict = storage.read() accounts = [] - for account_dict in json_dict.get('accounts', []): + account_dicts: Sequence[dict] = json_dict.get('accounts', []) + for account_dict in account_dicts: ledger = manager.get_or_create_ledger(account_dict['ledger']) account = ledger.account_class.from_dict(ledger, account_dict) accounts.append(account) @@ -110,7 +113,7 @@ class WalletStorage: mode = stat.S_IREAD | stat.S_IWRITE try: os.rename(temp_path, self.path) - except: + except Exception: # pylint: disable=broad-except os.remove(self.path) os.rename(temp_path, self.path) os.chmod(self.path, mode)