+ reserving outpoints should no longer have race conditions

+ converted all comment type annotations to python 3 style syntax annotations
+ pylint & mypy
This commit is contained in:
Lex Berezhny 2018-07-28 20:52:54 -04:00
parent 746aef474a
commit b0bd0b1fc0
26 changed files with 482 additions and 447 deletions

View file

@ -5,13 +5,25 @@ language: python
python: python:
- "3.7" - "3.7"
install: jobs:
include:
- stage: code quality
name: "pylint & mypy"
install:
- pip install pylint mypy
- pip install -e .
script:
- pylint --rcfile=setup.cfg torba
- mypy torba
- stage: test
name: "Unit Tests"
install:
- pip install tox-travis coverage - pip install tox-travis coverage
- pushd .. && git clone https://github.com/lbryio/electrumx.git --branch lbryumx && popd - pushd .. && git clone https://github.com/lbryio/electrumx.git --branch lbryumx && popd
- pushd .. && git clone https://github.com/lbryio/orchstr8.git && popd - pushd .. && git clone https://github.com/lbryio/orchstr8.git && popd
script: tox
script: tox after_success:
after_success:
- coverage combine tests/ - coverage combine tests/
- bash <(curl -s https://codecov.io/bash) - bash <(curl -s https://codecov.io/bash)

View file

@ -5,3 +5,26 @@ branch = True
source = source =
torba torba
.tox/*/lib/python*/site-packages/torba .tox/*/lib/python*/site-packages/torba
[mypy-twisted.*,cryptography.*,ecdsa.*,pbkdf2]
ignore_missing_imports = True
[pylint]
max-args=10
max-line-length=110
good-names=T,t,n,i,j,k,x,y,s,f,d,h,c,e,op,db,tx,io,cachedproperty,log,id
valid-metaclass-classmethod-first-arg=mcs
disable=
fixme,
no-else-return,
cyclic-import,
missing-docstring,
duplicate-code,
expression-not-assigned,
inconsistent-return-statements,
too-few-public-methods,
too-many-locals,
too-many-arguments,
too-many-public-methods,
too-many-instance-attributes,
protected-access

View file

@ -32,8 +32,7 @@ setup(
'twisted', 'twisted',
'ecdsa', 'ecdsa',
'pbkdf2', 'pbkdf2',
'cryptography', 'cryptography'
'typing'
), ),
extras_require={ extras_require={
'test': ( 'test': (

View file

@ -30,13 +30,13 @@ class BaseSelectionTestCase(unittest.TestCase):
class TestCoinSelectionTests(BaseSelectionTestCase): class TestCoinSelectionTests(BaseSelectionTestCase):
def test_empty_coins(self): def test_empty_coins(self):
self.assertIsNone(CoinSelector([], 0, 0).select()) self.assertEqual(CoinSelector([], 0, 0).select(), [])
def test_skip_binary_search_if_total_not_enough(self): def test_skip_binary_search_if_total_not_enough(self):
fee = utxo(CENT).get_estimator(self.ledger).fee fee = utxo(CENT).get_estimator(self.ledger).fee
big_pool = self.estimates(utxo(CENT+fee) for _ in range(100)) big_pool = self.estimates(utxo(CENT+fee) for _ in range(100))
selector = CoinSelector(big_pool, 101 * CENT, 0) selector = CoinSelector(big_pool, 101 * CENT, 0)
self.assertIsNone(selector.select()) self.assertEqual(selector.select(), [])
self.assertEqual(selector.tries, 0) # Never tried. self.assertEqual(selector.tries, 0) # Never tried.
# check happy path # check happy path
selector = CoinSelector(big_pool, 100 * CENT, 0) selector = CoinSelector(big_pool, 100 * CENT, 0)
@ -108,7 +108,7 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
self.assertEqual([3 * CENT, 2 * CENT], search(utxo_pool, 5 * CENT, 0.5 * CENT)) self.assertEqual([3 * CENT, 2 * CENT], search(utxo_pool, 5 * CENT, 0.5 * CENT))
# Select 11 Cent, not possible # Select 11 Cent, not possible
self.assertIsNone(search(utxo_pool, 11 * CENT, 0.5 * CENT)) self.assertEqual(search(utxo_pool, 11 * CENT, 0.5 * CENT), [])
# Select 10 Cent # Select 10 Cent
utxo_pool += self.estimates(utxo(5 * CENT)) utxo_pool += self.estimates(utxo(5 * CENT))
@ -126,12 +126,12 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
) )
# Select 0.25 Cent, not possible # Select 0.25 Cent, not possible
self.assertIsNone(search(utxo_pool, 0.25 * CENT, 0.5 * CENT)) self.assertEqual(search(utxo_pool, 0.25 * CENT, 0.5 * CENT), [])
# Iteration exhaustion test # Iteration exhaustion test
utxo_pool, target = self.make_hard_case(17) utxo_pool, target = self.make_hard_case(17)
selector = CoinSelector(utxo_pool, target, 0) selector = CoinSelector(utxo_pool, target, 0)
self.assertIsNone(selector.branch_and_bound()) self.assertEqual(selector.branch_and_bound(), [])
self.assertEqual(selector.tries, MAXIMUM_TRIES) # Should exhaust self.assertEqual(selector.tries, MAXIMUM_TRIES) # Should exhaust
utxo_pool, target = self.make_hard_case(14) utxo_pool, target = self.make_hard_case(14)
self.assertIsNotNone(search(utxo_pool, target, 0)) # Should not exhaust self.assertIsNotNone(search(utxo_pool, target, 0)) # Should not exhaust
@ -152,4 +152,4 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
# Select 1 Cent with pool of only greater than 5 Cent # Select 1 Cent with pool of only greater than 5 Cent
utxo_pool = self.estimates(utxo(i * CENT) for i in range(5, 21)) utxo_pool = self.estimates(utxo(i * CENT) for i in range(5, 21))
for _ in range(100): for _ in range(100):
self.assertIsNone(search(utxo_pool, 1 * CENT, 2 * CENT)) self.assertEqual(search(utxo_pool, 1 * CENT, 2 * CENT), [])

View file

@ -1,4 +1,3 @@
import six
from binascii import hexlify from binascii import hexlify
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer from twisted.internet import defer
@ -7,9 +6,6 @@ from torba.coin.bitcoinsegwit import MainNetLedger
from .test_transaction import get_transaction, get_output from .test_transaction import get_transaction, get_output
if six.PY3:
buffer = memoryview
class MockNetwork: class MockNetwork:
@ -50,9 +46,7 @@ class MainNetTestLedger(MainNetLedger):
network_name = 'unittest' network_name = 'unittest'
def __init__(self): def __init__(self):
super(MainNetLedger, self).__init__({ super().__init__({'db': MainNetLedger.database_class(':memory:')})
'db': MainNetLedger.database_class(':memory:')
})
class LedgerTestCase(unittest.TestCase): class LedgerTestCase(unittest.TestCase):

View file

@ -3,14 +3,14 @@ from twisted.trial import unittest
from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger
from torba.coin.bitcoincash import MainNetLedger as BCHLedger from torba.coin.bitcoincash import MainNetLedger as BCHLedger
from torba.manager import WalletManager from torba.basemanager import BaseWalletManager
from torba.wallet import Wallet, WalletStorage from torba.wallet import Wallet, WalletStorage
class TestWalletCreation(unittest.TestCase): class TestWalletCreation(unittest.TestCase):
def setUp(self): def setUp(self):
self.manager = WalletManager() self.manager = BaseWalletManager()
config = {'data_path': '/tmp/wallet'} config = {'data_path': '/tmp/wallet'}
self.btc_ledger = self.manager.get_or_create_ledger(BTCLedger.get_id(), config) self.btc_ledger = self.manager.get_or_create_ledger(BTCLedger.get_id(), config)
self.bch_ledger = self.manager.get_or_create_ledger(BCHLedger.get_id(), config) self.bch_ledger = self.manager.get_or_create_ledger(BCHLedger.get_id(), config)
@ -63,7 +63,7 @@ class TestWalletCreation(unittest.TestCase):
self.assertDictEqual(wallet_dict, wallet.to_dict()) self.assertDictEqual(wallet_dict, wallet.to_dict())
def test_read_write(self): def test_read_write(self):
manager = WalletManager() manager = BaseWalletManager()
config = {'data_path': '/tmp/wallet'} config = {'data_path': '/tmp/wallet'}
ledger = manager.get_or_create_ledger(BTCLedger.get_id(), config) ledger = manager.get_or_create_ledger(BTCLedger.get_id(), config)

View file

@ -1,2 +1,2 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__) __path__: str = __import__('pkgutil').extend_path(__path__, __name__)
__version__ = '0.0.4' __version__ = '0.0.4'

View file

@ -1,12 +1,16 @@
from typing import Dict import typing
from typing import Sequence
from twisted.internet import defer from twisted.internet import defer
from torba.mnemonic import Mnemonic from torba.mnemonic import Mnemonic
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
from torba.hash import double_sha256, aes_encrypt, aes_decrypt from torba.hash import double_sha256, aes_encrypt, aes_decrypt
if typing.TYPE_CHECKING:
from torba import baseledger
class KeyManager(object):
class KeyManager:
__slots__ = 'account', 'public_key', 'chain_number' __slots__ = 'account', 'public_key', 'chain_number'
@ -19,27 +23,27 @@ class KeyManager(object):
def db(self): def db(self):
return self.account.ledger.db return self.account.ledger.db
def _query_addresses(self, limit=None, max_used_times=None, order_by=None): def _query_addresses(self, limit: int = None, max_used_times: int = None, order_by=None):
return self.db.get_addresses( return self.db.get_addresses(
self.account, self.chain_number, limit, max_used_times, order_by self.account, self.chain_number, limit, max_used_times, order_by
) )
def get_max_gap(self): # type: () -> defer.Deferred def get_max_gap(self) -> defer.Deferred:
raise NotImplementedError raise NotImplementedError
def ensure_address_gap(self): # type: () -> defer.Deferred def ensure_address_gap(self) -> defer.Deferred:
raise NotImplementedError raise NotImplementedError
def get_address_records(self, limit=None, only_usable=False): # type: (int, bool) -> defer.Deferred def get_address_records(self, limit: int = None, only_usable: bool = False) -> defer.Deferred:
raise NotImplementedError raise NotImplementedError
@defer.inlineCallbacks @defer.inlineCallbacks
def get_addresses(self, limit=None, only_usable=False): # type: (int, bool) -> defer.Deferred def get_addresses(self, limit: int = None, only_usable: bool = False) -> defer.Deferred:
records = yield self.get_address_records(limit=limit, only_usable=only_usable) records = yield self.get_address_records(limit=limit, only_usable=only_usable)
defer.returnValue([r['address'] for r in records]) defer.returnValue([r['address'] for r in records])
@defer.inlineCallbacks @defer.inlineCallbacks
def get_or_create_usable_address(self): # type: () -> defer.Deferred def get_or_create_usable_address(self) -> defer.Deferred:
addresses = yield self.get_addresses(limit=1, only_usable=True) addresses = yield self.get_addresses(limit=1, only_usable=True)
if addresses: if addresses:
defer.returnValue(addresses[0]) defer.returnValue(addresses[0])
@ -52,14 +56,14 @@ class KeyChain(KeyManager):
__slots__ = 'gap', 'maximum_uses_per_address' __slots__ = 'gap', 'maximum_uses_per_address'
def __init__(self, account, root_public_key, chain_number, gap, maximum_uses_per_address): def __init__(self, account: 'BaseAccount', root_public_key: PubKey,
# type: ('BaseAccount', PubKey, int, int, int) -> None chain_number: int, gap: int, maximum_uses_per_address: int) -> None:
super(KeyChain, self).__init__(account, root_public_key.child(chain_number), chain_number) super().__init__(account, root_public_key.child(chain_number), chain_number)
self.gap = gap self.gap = gap
self.maximum_uses_per_address = maximum_uses_per_address self.maximum_uses_per_address = maximum_uses_per_address
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_keys(self, start, end): def generate_keys(self, start: int, end: int) -> defer.Deferred:
new_keys = [] new_keys = []
for index in range(start, end+1): for index in range(start, end+1):
new_keys.append((index, self.public_key.child(index))) new_keys.append((index, self.public_key.child(index)))
@ -69,7 +73,7 @@ class KeyChain(KeyManager):
defer.returnValue([key[1].address for key in new_keys]) defer.returnValue([key[1].address for key in new_keys])
@defer.inlineCallbacks @defer.inlineCallbacks
def get_max_gap(self): def get_max_gap(self) -> defer.Deferred:
addresses = yield self._query_addresses(order_by="position ASC") addresses = yield self._query_addresses(order_by="position ASC")
max_gap = 0 max_gap = 0
current_gap = 0 current_gap = 0
@ -82,7 +86,7 @@ class KeyChain(KeyManager):
defer.returnValue(max_gap) defer.returnValue(max_gap)
@defer.inlineCallbacks @defer.inlineCallbacks
def ensure_address_gap(self): def ensure_address_gap(self) -> defer.Deferred:
addresses = yield self._query_addresses(self.gap, None, "position DESC") addresses = yield self._query_addresses(self.gap, None, "position DESC")
existing_gap = 0 existing_gap = 0
@ -100,7 +104,7 @@ class KeyChain(KeyManager):
new_keys = yield self.generate_keys(start, end-1) new_keys = yield self.generate_keys(start, end-1)
defer.returnValue(new_keys) defer.returnValue(new_keys)
def get_address_records(self, limit=None, only_usable=False): def get_address_records(self, limit: int = None, only_usable: bool = False):
return self._query_addresses( return self._query_addresses(
limit, self.maximum_uses_per_address if only_usable else None, limit, self.maximum_uses_per_address if only_usable else None,
"used_times ASC, position ASC" "used_times ASC, position ASC"
@ -112,15 +116,11 @@ class SingleKey(KeyManager):
__slots__ = () __slots__ = ()
def __init__(self, account, root_public_key, chain_number): def get_max_gap(self) -> defer.Deferred:
# type: ('BaseAccount', PubKey) -> None
super(SingleKey, self).__init__(account, root_public_key, chain_number)
def get_max_gap(self):
return defer.succeed(0) return defer.succeed(0)
@defer.inlineCallbacks @defer.inlineCallbacks
def ensure_address_gap(self): def ensure_address_gap(self) -> defer.Deferred:
exists = yield self.get_address_records() exists = yield self.get_address_records()
if not exists: if not exists:
yield self.db.add_keys( yield self.db.add_keys(
@ -129,20 +129,20 @@ class SingleKey(KeyManager):
defer.returnValue([self.public_key.address]) defer.returnValue([self.public_key.address])
defer.returnValue([]) defer.returnValue([])
def get_address_records(self, **kwargs): def get_address_records(self, limit: int = None, only_usable: bool = False) -> defer.Deferred:
return self._query_addresses() return self._query_addresses()
class BaseAccount(object): class BaseAccount:
mnemonic_class = Mnemonic mnemonic_class = Mnemonic
private_key_class = PrivateKey private_key_class = PrivateKey
public_key_class = PubKey public_key_class = PubKey
def __init__(self, ledger, name, seed, encrypted, is_hd, private_key, def __init__(self, ledger: 'baseledger.BaseLedger', name: str, seed: str, encrypted: bool, is_hd: bool,
public_key, receiving_gap=20, change_gap=6, private_key: PrivateKey, public_key: PubKey, receiving_gap: int = 20, change_gap: int = 6,
receiving_maximum_uses_per_address=2, change_maximum_uses_per_address=2): receiving_maximum_uses_per_address: int = 2, change_maximum_uses_per_address: int = 2
# type: (torba.baseledger.BaseLedger, str, str, bool, bool, PrivateKey, PubKey, int, int, int, int) -> None ) -> None:
self.ledger = ledger self.ledger = ledger
self.name = name self.name = name
self.seed = seed self.seed = seed
@ -150,25 +150,26 @@ class BaseAccount(object):
self.private_key = private_key self.private_key = private_key
self.public_key = public_key self.public_key = public_key
if is_hd: if is_hd:
receiving, change = self.keychains = ( self.receiving: KeyManager = KeyChain(
KeyChain(self, public_key, 0, receiving_gap, receiving_maximum_uses_per_address), self, public_key, 0, receiving_gap, receiving_maximum_uses_per_address
KeyChain(self, public_key, 1, change_gap, change_maximum_uses_per_address)
) )
self.change: KeyManager = KeyChain(
self, public_key, 1, change_gap, change_maximum_uses_per_address
)
self.keychains: Sequence[KeyManager] = (self.receiving, self.change)
else: else:
self.keychains = SingleKey(self, public_key, 0), self.change = self.receiving = SingleKey(self, public_key, 0)
receiving = change = self.keychains[0] self.keychains = (self.receiving,)
self.receiving = receiving # type: KeyManager
self.change = change # type: KeyManager
ledger.add_account(self) ledger.add_account(self)
@classmethod @classmethod
def generate(cls, ledger, password, **kwargs): # type: (torba.baseledger.BaseLedger, str) -> BaseAccount def generate(cls, ledger: 'baseledger.BaseLedger', password: str, **kwargs):
seed = cls.mnemonic_class().make_seed() seed = cls.mnemonic_class().make_seed()
return cls.from_seed(ledger, seed, password, **kwargs) return cls.from_seed(ledger, seed, password, **kwargs)
@classmethod @classmethod
def from_seed(cls, ledger, seed, password, is_hd=True, **kwargs): def from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str,
# type: (torba.baseledger.BaseLedger, str, str) -> BaseAccount is_hd: bool = True, **kwargs):
private_key = cls.get_private_key_from_seed(ledger, seed, password) private_key = cls.get_private_key_from_seed(ledger, seed, password)
return cls( return cls(
ledger=ledger, name='Account #{}'.format(private_key.public_key.address), ledger=ledger, name='Account #{}'.format(private_key.public_key.address),
@ -179,14 +180,13 @@ class BaseAccount(object):
) )
@classmethod @classmethod
def get_private_key_from_seed(cls, ledger, seed, password): def get_private_key_from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str):
# type: (torba.baseledger.BaseLedger, str, str) -> PrivateKey
return cls.private_key_class.from_seed( return cls.private_key_class.from_seed(
ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password) ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password)
) )
@classmethod @classmethod
def from_dict(cls, ledger, d): # type: (torba.baseledger.BaseLedger, Dict) -> BaseAccount def from_dict(cls, ledger: 'baseledger.BaseLedger', d: dict):
if not d['encrypted'] and d['private_key']: if not d['encrypted'] and d['private_key']:
private_key = from_extended_key_string(ledger, d['private_key']) private_key = from_extended_key_string(ledger, d['private_key'])
public_key = private_key.public_key public_key = private_key.public_key
@ -264,21 +264,20 @@ class BaseAccount(object):
defer.returnValue(addresses) defer.returnValue(addresses)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_addresses(self, limit=None, max_used_times=None): # type: (int, int) -> defer.Deferred def get_addresses(self, limit: int = None, max_used_times: int = None) -> defer.Deferred:
records = yield self.get_address_records(limit, max_used_times) records = yield self.get_address_records(limit, max_used_times)
defer.returnValue([r['address'] for r in records]) defer.returnValue([r['address'] for r in records])
def get_address_records(self, limit=None, max_used_times=None): # type: (int, int) -> defer.Deferred def get_address_records(self, limit: int = None, max_used_times: int = None) -> defer.Deferred:
return self.ledger.db.get_addresses(self, None, limit, max_used_times) return self.ledger.db.get_addresses(self, None, limit, max_used_times)
def get_private_key(self, chain, index): def get_private_key(self, chain: int, index: int) -> PrivateKey:
assert not self.encrypted, "Cannot get private key on encrypted wallet account." assert not self.encrypted, "Cannot get private key on encrypted wallet account."
if isinstance(self.receiving, SingleKey): if isinstance(self.receiving, SingleKey):
return self.private_key return self.private_key
else:
return self.private_key.child(chain).child(index) return self.private_key.child(chain).child(index)
def get_balance(self, confirmations=6, **constraints): def get_balance(self, confirmations: int = 6, **constraints):
if confirmations > 0: if confirmations > 0:
height = self.ledger.headers.height - (confirmations-1) height = self.ledger.headers.height - (confirmations-1)
constraints.update({'height__lte': height, 'height__gt': 0}) constraints.update({'height__lte': height, 'height__gt': 0})

View file

@ -1,5 +1,5 @@
import logging import logging
from typing import List, Union from typing import Tuple, List, Sequence
from operator import itemgetter from operator import itemgetter
import sqlite3 import sqlite3
@ -11,13 +11,13 @@ from torba.hash import TXRefImmutable
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class SQLiteMixin(object): class SQLiteMixin:
CREATE_TABLES_QUERY = None CREATE_TABLES_QUERY: Sequence[str] = ()
def __init__(self, path): def __init__(self, path):
self._db_path = path self._db_path = path
self.db = None self.db: adbapi.ConnectionPool = None
def start(self): def start(self):
log.info("connecting to database: %s", self._db_path) log.info("connecting to database: %s", self._db_path)
@ -32,8 +32,8 @@ class SQLiteMixin(object):
self.db.close() self.db.close()
return defer.succeed(True) return defer.succeed(True)
def _insert_sql(self, table, data): @staticmethod
# type: (str, dict) -> tuple[str, List] def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
columns, values = [], [] columns, values = [], []
for column, value in data.items(): for column, value in data.items():
columns.append(column) columns.append(column)
@ -43,8 +43,8 @@ class SQLiteMixin(object):
) )
return sql, values return sql, values
def _update_sql(self, table, data, where, constraints): @staticmethod
# type: (str, dict) -> tuple[str, List] def _update_sql(table: str, data: dict, where: str, constraints: list) -> Tuple[str, list]:
columns, values = [], [] columns, values = [], []
for column, value in data.items(): for column, value in data.items():
columns.append("{} = ?".format(column)) columns.append("{} = ?".format(column))
@ -146,7 +146,8 @@ class BaseDatabase(SQLiteMixin):
CREATE_TXI_TABLE CREATE_TXI_TABLE
) )
def txo_to_row(self, tx, address, txo): @staticmethod
def txo_to_row(tx, address, txo):
return { return {
'txid': tx.id, 'txid': tx.id,
'txoid': txo.id, 'txoid': txo.id,
@ -156,7 +157,7 @@ class BaseDatabase(SQLiteMixin):
'script': sqlite3.Binary(txo.script.source) 'script': sqlite3.Binary(txo.script.source)
} }
def save_transaction_io(self, save_tx, tx, height, is_verified, address, hash, history): def save_transaction_io(self, save_tx, tx, height, is_verified, address, txhash, history):
def _steps(t): def _steps(t):
if save_tx == 'insert': if save_tx == 'insert':
@ -169,8 +170,7 @@ class BaseDatabase(SQLiteMixin):
elif save_tx == 'update': elif save_tx == 'update':
self.execute(t, *self._update_sql("tx", { self.execute(t, *self._update_sql("tx", {
'height': height, 'is_verified': is_verified 'height': height, 'is_verified': is_verified
}, 'txid = ?', (tx.id,) }, 'txid = ?', (tx.id,)))
))
existing_txos = list(map(itemgetter(0), self.execute( existing_txos = list(map(itemgetter(0), self.execute(
t, "SELECT position FROM txo WHERE txid = ?", (tx.id,) t, "SELECT position FROM txo WHERE txid = ?", (tx.id,)
@ -179,7 +179,7 @@ class BaseDatabase(SQLiteMixin):
for txo in tx.outputs: for txo in tx.outputs:
if txo.position in existing_txos: if txo.position in existing_txos:
continue continue
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == hash: if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo))) self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
elif txo.script.is_pay_script_hash: elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments # TODO: implement script hash payments
@ -202,15 +202,16 @@ class BaseDatabase(SQLiteMixin):
return self.db.runInteraction(_steps) return self.db.runInteraction(_steps)
def reserve_spent_outputs(self, txoids, is_reserved=True): def reserve_outputs(self, txos, is_reserved=True):
txoids = [txo.id for txo in txos]
return self.run_operation( return self.run_operation(
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format( "UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
', '.join(['?']*len(txoids)) ', '.join(['?']*len(txoids))
), [is_reserved]+txoids ), [is_reserved]+txoids
) )
def release_reserved_outputs(self, txoids): def release_outputs(self, txos):
return self.reserve_spent_outputs(txoids, is_reserved=False) return self.reserve_outputs(txos, is_reserved=False)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_transaction(self, txid): def get_transaction(self, txid):
@ -226,7 +227,7 @@ class BaseDatabase(SQLiteMixin):
extra_sql = "" extra_sql = ""
if constraints: if constraints:
extras = [] extras = []
for key in constraints.keys(): for key in constraints:
col, op = key, '=' col, op = key, '='
if key.endswith('__not'): if key.endswith('__not'):
col, op = key[:-len('__not')], '!=' col, op = key[:-len('__not')], '!='
@ -257,7 +258,7 @@ class BaseDatabase(SQLiteMixin):
extra_sql = "" extra_sql = ""
if constraints: if constraints:
extra_sql = ' AND ' + ' AND '.join( extra_sql = ' AND ' + ' AND '.join(
'{} = :{}'.format(c, c) for c in constraints.keys() '{} = :{}'.format(c, c) for c in constraints
) )
values = {'account': account.public_key.address} values = {'account': account.public_key.address}
values.update(constraints) values.update(constraints)

View file

@ -1,13 +1,16 @@
import os import os
import struct import struct
import logging import logging
import typing
from binascii import unhexlify from binascii import unhexlify
from twisted.internet import threads, defer from twisted.internet import threads, defer
from torba.stream import StreamController, execute_serially from torba.stream import StreamController
from torba.util import int_to_hex, rev_hex, hash_encode from torba.util import int_to_hex, rev_hex, hash_encode
from torba.hash import double_sha256, pow_hash from torba.hash import double_sha256, pow_hash
if typing.TYPE_CHECKING:
from torba import baseledger
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -17,7 +20,7 @@ class BaseHeaders:
header_size = 80 header_size = 80
verify_bits_to_target = True verify_bits_to_target = True
def __init__(self, ledger): # type: (baseledger.BaseLedger) -> BaseHeaders def __init__(self, ledger: 'baseledger.BaseLedger') -> None:
self.ledger = ledger self.ledger = ledger
self._size = None self._size = None
self._on_change_controller = StreamController() self._on_change_controller = StreamController()
@ -62,7 +65,6 @@ class BaseHeaders:
header = self.sync_read_header(height) header = self.sync_read_header(height)
return self._deserialize(height, header) return self._deserialize(height, header)
@execute_serially
@defer.inlineCallbacks @defer.inlineCallbacks
def connect(self, start, headers): def connect(self, start, headers):
yield threads.deferToThread(self._sync_connect, start, headers) yield threads.deferToThread(self._sync_connect, start, headers)
@ -84,8 +86,9 @@ class BaseHeaders:
_old_size = self._size _old_size = self._size
self._size = self.sync_read_length() self._size = self.sync_read_length()
change = self._size - _old_size change = self._size - _old_size
log.info('{}: added {} header blocks, final height {}'.format( log.info(
self.ledger.get_id(), change, self.height) '%s: added %s header blocks, final height %s',
self.ledger.get_id(), change, self.height
) )
self._on_change_controller.add(change) self._on_change_controller.add(change)
@ -101,7 +104,7 @@ class BaseHeaders:
assert previous_hash == header['prev_block_hash'], \ assert previous_hash == header['prev_block_hash'], \
"prev hash mismatch: {} vs {}".format(previous_hash, header['prev_block_hash']) "prev hash mismatch: {} vs {}".format(previous_hash, header['prev_block_hash'])
bits, target = self._calculate_next_work_required(height, previous_header, header) bits, _ = self._calculate_next_work_required(height, previous_header, header)
assert bits == header['bits'], \ assert bits == header['bits'], \
"bits mismatch: {} vs {} (hash: {})".format( "bits mismatch: {} vs {} (hash: {})".format(
bits, header['bits'], self._hash_header(header)) bits, header['bits'], self._hash_header(header))
@ -154,37 +157,37 @@ class BaseHeaders:
if self.verify_bits_to_target: if self.verify_bits_to_target:
bits = last['bits'] bits = last['bits']
bitsN = (bits >> 24) & 0xff bits_n = (bits >> 24) & 0xff
assert 0x03 <= bitsN <= 0x1d, \ assert 0x03 <= bits_n <= 0x1d, \
"First part of bits should be in [0x03, 0x1d], but it was {}".format(hex(bitsN)) "First part of bits should be in [0x03, 0x1d], but it was {}".format(hex(bits_n))
bitsBase = bits & 0xffffff bits_base = bits & 0xffffff
assert 0x8000 <= bitsBase <= 0x7fffff, \ assert 0x8000 <= bits_base <= 0x7fffff, \
"Second part of bits should be in [0x8000, 0x7fffff] but it was {}".format(bitsBase) "Second part of bits should be in [0x8000, 0x7fffff] but it was {}".format(bits_base)
# new target # new target
retargetTimespan = self.ledger.target_timespan retarget_timespan = self.ledger.target_timespan
nActualTimespan = last['timestamp'] - first['timestamp'] n_actual_timespan = last['timestamp'] - first['timestamp']
nModulatedTimespan = retargetTimespan + (nActualTimespan - retargetTimespan) // 8 n_modulated_timespan = retarget_timespan + (n_actual_timespan - retarget_timespan) // 8
nMinTimespan = retargetTimespan - (retargetTimespan // 8) n_min_timespan = retarget_timespan - (retarget_timespan // 8)
nMaxTimespan = retargetTimespan + (retargetTimespan // 2) n_max_timespan = retarget_timespan + (retarget_timespan // 2)
# Limit adjustment step # Limit adjustment step
if nModulatedTimespan < nMinTimespan: if n_modulated_timespan < n_min_timespan:
nModulatedTimespan = nMinTimespan n_modulated_timespan = n_min_timespan
elif nModulatedTimespan > nMaxTimespan: elif n_modulated_timespan > n_max_timespan:
nModulatedTimespan = nMaxTimespan n_modulated_timespan = n_max_timespan
# Retarget # Retarget
bnPowLimit = _ArithUint256(self.ledger.max_target) bn_pow_limit = _ArithUint256(self.ledger.max_target)
bnNew = _ArithUint256.SetCompact(last['bits']) bn_new = _ArithUint256.set_compact(last['bits'])
bnNew *= nModulatedTimespan bn_new *= n_modulated_timespan
bnNew //= nModulatedTimespan bn_new //= n_modulated_timespan
if bnNew > bnPowLimit: if bn_new > bn_pow_limit:
bnNew = bnPowLimit bn_new = bn_pow_limit
return bnNew.GetCompact(), bnNew._value return bn_new.get_compact(), bn_new._value
class _ArithUint256: class _ArithUint256:
@ -197,49 +200,48 @@ class _ArithUint256:
return hex(self._value) return hex(self._value)
@staticmethod @staticmethod
def fromCompact(nCompact): def from_compact(n_compact):
"""Convert a compact representation into its value""" """Convert a compact representation into its value"""
nSize = nCompact >> 24 n_size = n_compact >> 24
# the lower 23 bits # the lower 23 bits
nWord = nCompact & 0x007fffff n_word = n_compact & 0x007fffff
if nSize <= 3: if n_size <= 3:
return nWord >> 8 * (3 - nSize) return n_word >> 8 * (3 - n_size)
else: else:
return nWord << 8 * (nSize - 3) return n_word << 8 * (n_size - 3)
@classmethod @classmethod
def SetCompact(cls, nCompact): def set_compact(cls, n_compact):
return cls(cls.fromCompact(nCompact)) return cls(cls.from_compact(n_compact))
def bits(self): def bits(self):
"""Returns the position of the highest bit set plus one.""" """Returns the position of the highest bit set plus one."""
bn = bin(self._value)[2:] bits = bin(self._value)[2:]
for i, d in enumerate(bn): for i, d in enumerate(bits):
if d: if d:
return (len(bn) - i) + 1 return (len(bits) - i) + 1
return 0 return 0
def GetLow64(self): def get_low64(self):
return self._value & 0xffffffffffffffff return self._value & 0xffffffffffffffff
def GetCompact(self): def get_compact(self):
"""Convert a value into its compact representation""" """Convert a value into its compact representation"""
nSize = (self.bits() + 7) // 8 n_size = (self.bits() + 7) // 8
nCompact = 0 if n_size <= 3:
if nSize <= 3: n_compact = self.get_low64() << 8 * (3 - n_size)
nCompact = self.GetLow64() << 8 * (3 - nSize)
else: else:
bn = _ArithUint256(self._value >> 8 * (nSize - 3)) n = _ArithUint256(self._value >> 8 * (n_size - 3))
nCompact = bn.GetLow64() n_compact = n.get_low64()
# The 0x00800000 bit denotes the sign. # The 0x00800000 bit denotes the sign.
# Thus, if it is already set, divide the mantissa by 256 and increase the exponent. # Thus, if it is already set, divide the mantissa by 256 and increase the exponent.
if nCompact & 0x00800000: if n_compact & 0x00800000:
nCompact >>= 8 n_compact >>= 8
nSize += 1 n_size += 1
assert (nCompact & ~0x007fffff) == 0 assert (n_compact & ~0x007fffff) == 0
assert nSize < 256 assert n_size < 256
nCompact |= nSize << 24 n_compact |= n_size << 24
return nCompact return n_compact
def __mul__(self, x): def __mul__(self, x):
# Take the mod because we are limited to an unsigned 256 bit number # Take the mod because we are limited to an unsigned 256 bit number

View file

@ -1,6 +1,4 @@
import os import os
import six
import hashlib
import logging import logging
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from typing import Dict, Type, Iterable from typing import Dict, Type, Iterable
@ -14,17 +12,22 @@ from torba import basedatabase
from torba import baseheader from torba import baseheader
from torba import basenetwork from torba import basenetwork
from torba import basetransaction from torba import basetransaction
from torba.stream import StreamController, execute_serially from torba.coinselection import CoinSelector
from torba.hash import hash160, double_sha256, Base58 from torba.constants import COIN, NULL_HASH32
from torba.stream import StreamController
from torba.hash import hash160, double_sha256, sha256, Base58
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
LedgerType = Type['BaseLedger']
class LedgerRegistry(type): class LedgerRegistry(type):
ledgers = {} # type: Dict[str, Type[BaseLedger]]
ledgers: Dict[str, LedgerType] = {}
def __new__(mcs, name, bases, attrs): def __new__(mcs, name, bases, attrs):
cls = super(LedgerRegistry, mcs).__new__(mcs, name, bases, attrs) # type: Type[BaseLedger] cls: LedgerType = super().__new__(mcs, name, bases, attrs)
if not (name == 'BaseLedger' and not bases): if not (name == 'BaseLedger' and not bases):
ledger_id = cls.get_id() ledger_id = cls.get_id()
assert ledger_id not in mcs.ledgers,\ assert ledger_id not in mcs.ledgers,\
@ -33,7 +36,7 @@ class LedgerRegistry(type):
return cls return cls
@classmethod @classmethod
def get_ledger_class(mcs, ledger_id): # type: (str) -> Type[BaseLedger] def get_ledger_class(mcs, ledger_id: str) -> LedgerType:
return mcs.ledgers[ledger_id] return mcs.ledgers[ledger_id]
@ -41,11 +44,11 @@ class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'height'
pass pass
class BaseLedger(six.with_metaclass(LedgerRegistry)): class BaseLedger(metaclass=LedgerRegistry):
name = None name: str
symbol = None symbol: str
network_name = None network_name: str
account_class = baseaccount.BaseAccount account_class = baseaccount.BaseAccount
database_class = basedatabase.BaseDatabase database_class = basedatabase.BaseDatabase
@ -54,10 +57,10 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
transaction_class = basetransaction.BaseTransaction transaction_class = basetransaction.BaseTransaction
secret_prefix = None secret_prefix = None
pubkey_address_prefix = None pubkey_address_prefix: bytes
script_address_prefix = None script_address_prefix: bytes
extended_public_key_prefix = None extended_public_key_prefix: bytes
extended_private_key_prefix = None extended_private_key_prefix: bytes
default_fee_per_byte = 10 default_fee_per_byte = 10
@ -71,13 +74,14 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
self.network.on_status.listen(self.process_status) self.network.on_status.listen(self.process_status)
self.accounts = [] self.accounts = []
self.headers = self.config.get('headers') or self.headers_class(self) self.headers = self.config.get('headers') or self.headers_class(self)
self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte) self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
self._on_transaction_controller = StreamController() self._on_transaction_controller = StreamController()
self.on_transaction = self._on_transaction_controller.stream self.on_transaction = self._on_transaction_controller.stream
self.on_transaction.listen( self.on_transaction.listen(
lambda e: log.info('({}) on_transaction: address={}, height={}, is_verified={}, tx.id={}'.format( lambda e: log.info(
self.get_id(), e.address, e.height, e.is_verified, e.tx.id) '(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
self.get_id(), e.address, e.height, e.is_verified, e.tx.id
) )
) )
@ -85,6 +89,8 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
self.on_header = self._on_header_controller.stream self.on_header = self._on_header_controller.stream
self._transaction_processing_locks = {} self._transaction_processing_locks = {}
self._utxo_reservation_lock = defer.DeferredLock()
self._header_processing_lock = defer.DeferredLock()
@classmethod @classmethod
def get_id(cls): def get_id(cls):
@ -97,9 +103,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@staticmethod @staticmethod
def address_to_hash160(address): def address_to_hash160(address):
bytes = Base58.decode(address) return Base58.decode(address)[1:21]
prefix, pubkey_bytes, addr_checksum = bytes[0], bytes[1:21], bytes[21:]
return pubkey_bytes
@classmethod @classmethod
def public_key_to_address(cls, public_key): def public_key_to_address(cls, public_key):
@ -113,7 +117,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
def path(self): def path(self):
return os.path.join(self.config['data_path'], self.get_id()) return os.path.join(self.config['data_path'], self.get_id())
def get_input_output_fee(self, io): def get_input_output_fee(self, io: basetransaction.InputOutput) -> int:
""" Fee based on size of the input / output. """ """ Fee based on size of the input / output. """
return self.fee_per_byte * io.size return self.fee_per_byte * io.size
@ -122,14 +126,14 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
return self.fee_per_byte * tx.base_size return self.fee_per_byte * tx.base_size
@defer.inlineCallbacks @defer.inlineCallbacks
def add_account(self, account): # type: (baseaccount.BaseAccount) -> None def add_account(self, account: baseaccount.BaseAccount) -> defer.Deferred:
self.accounts.append(account) self.accounts.append(account)
if self.network.is_connected: if self.network.is_connected:
yield self.update_account(account) yield self.update_account(account)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_transaction(self, txhash): def get_transaction(self, txhash):
raw, height, is_verified = yield self.db.get_transaction(txhash) raw, _, _ = yield self.db.get_transaction(txhash)
if raw is not None: if raw is not None:
defer.returnValue(self.transaction_class(raw)) defer.returnValue(self.transaction_class(raw))
@ -142,8 +146,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
defer.returnValue(account.get_private_key(match['chain'], match['position'])) defer.returnValue(account.get_private_key(match['chain'], match['position']))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_effective_amount_estimators(self, funding_accounts): def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
# type: (Iterable[baseaccount.BaseAccount]) -> defer.Deferred
estimators = [] estimators = []
for account in funding_accounts: for account in funding_accounts:
utxos = yield account.get_unspent_outputs() utxos = yield account.get_unspent_outputs()
@ -151,12 +154,39 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
estimators.append(utxo.get_estimator(self)) estimators.append(utxo.get_estimator(self))
defer.returnValue(estimators) defer.returnValue(estimators)
@defer.inlineCallbacks
def get_spendable_utxos(self, amount: int, funding_accounts):
yield self._utxo_reservation_lock.acquire()
try:
txos = yield self.get_effective_amount_estimators(funding_accounts)
selector = CoinSelector(
txos, amount,
self.get_input_output_fee(
self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32)
)
)
spendables = selector.select()
if spendables:
yield self.reserve_outputs(s.txo for s in spendables)
except Exception:
log.exception('Failed to get spendable utxos:')
raise
finally:
self._utxo_reservation_lock.release()
defer.returnValue(spendables)
def reserve_outputs(self, txos):
return self.db.reserve_outputs(txos)
def release_outputs(self, txos):
return self.db.release_outputs(txos)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_local_status(self, address): def get_local_status(self, address):
address_details = yield self.db.get_address(address) address_details = yield self.db.get_address(address)
history = address_details['history'] or '' history = address_details['history'] or ''
hash = hashlib.sha256(history.encode()).digest() h = sha256(history.encode())
defer.returnValue(hexlify(hash)) defer.returnValue(hexlify(h))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_local_history(self, address): def get_local_history(self, address):
@ -203,7 +233,6 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
yield self.network.stop() yield self.network.stop()
yield self.db.stop() yield self.db.stop()
@execute_serially
@defer.inlineCallbacks @defer.inlineCallbacks
def update_headers(self): def update_headers(self):
while True: while True:
@ -216,9 +245,9 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@defer.inlineCallbacks @defer.inlineCallbacks
def process_header(self, response): def process_header(self, response):
yield self._header_processing_lock.acquire()
try:
header = response[0] header = response[0]
if self.update_headers.is_running:
return
if header['height'] == len(self.headers): if header['height'] == len(self.headers):
# New header from network directly connects after the last local header. # New header from network directly connects after the last local header.
yield self.headers.connect(len(self.headers), unhexlify(header['hex'])) yield self.headers.connect(len(self.headers), unhexlify(header['hex']))
@ -226,8 +255,9 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
elif header['height'] > len(self.headers): elif header['height'] > len(self.headers):
# New header is several heights ahead of local, do download instead. # New header is several heights ahead of local, do download instead.
yield self.update_headers() yield self.update_headers()
finally:
self._header_processing_lock.release()
@execute_serially
def update_accounts(self): def update_accounts(self):
return defer.DeferredList([ return defer.DeferredList([
self.update_account(a) for a in self.accounts self.update_account(a) for a in self.accounts
@ -274,7 +304,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
try: try:
# see if we have a local copy of transaction, otherwise fetch it from server # see if we have a local copy of transaction, otherwise fetch it from server
raw, local_height, is_verified = yield self.db.get_transaction(hex_id) raw, _, is_verified = yield self.db.get_transaction(hex_id)
save_tx = None save_tx = None
if raw is None: if raw is None:
_raw = yield self.network.get_transaction(hex_id) _raw = yield self.network.get_transaction(hex_id)
@ -294,15 +324,16 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history) ''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history)
) )
log.debug("{}: sync'ed tx {} for address: {}, height: {}, verified: {}".format( log.debug(
"%s: sync'ed tx %s for address: %s, height: %s, verified: %s",
self.get_id(), hex_id, address, remote_height, is_verified self.get_id(), hex_id, address, remote_height, is_verified
)) )
self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified)) self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified))
except Exception as e: except Exception:
log.exception('Failed to synchronize transaction:') log.exception('Failed to synchronize transaction:')
raise e raise
finally: finally:
lock.release() lock.release()

View file

@ -1,4 +1,4 @@
from typing import List, Dict, Type from typing import Type, MutableSequence, MutableMapping
from twisted.internet import defer from twisted.internet import defer
from torba.baseledger import BaseLedger, LedgerRegistry from torba.baseledger import BaseLedger, LedgerRegistry
@ -6,16 +6,16 @@ from torba.wallet import Wallet, WalletStorage
from torba.constants import COIN from torba.constants import COIN
class WalletManager(object): class BaseWalletManager:
def __init__(self, wallets=None, ledgers=None): def __init__(self, wallets: MutableSequence[Wallet] = None,
# type: (List[Wallet], Dict[Type[BaseLedger],BaseLedger]) -> None ledgers: MutableMapping[Type[BaseLedger], BaseLedger] = None) -> None:
self.wallets = wallets or [] self.wallets = wallets or []
self.ledgers = ledgers or {} self.ledgers = ledgers or {}
self.running = False self.running = False
@classmethod @classmethod
def from_config(cls, config): # type: (Dict) -> WalletManager def from_config(cls, config: dict) -> 'BaseWalletManager':
manager = cls() manager = cls()
for ledger_id, ledger_config in config.get('ledgers', {}).items(): for ledger_id, ledger_config in config.get('ledgers', {}).items():
manager.get_or_create_ledger(ledger_id, ledger_config) manager.get_or_create_ledger(ledger_id, ledger_config)

View file

@ -21,6 +21,7 @@ class StratumClientProtocol(LineOnlyReceiver):
self.request_id = 0 self.request_id = 0
self.lookup_table = {} self.lookup_table = {}
self.session = {} self.session = {}
self.network = None
self.on_disconnected_controller = StreamController() self.on_disconnected_controller = StreamController()
self.on_disconnected = self.on_disconnected_controller.stream self.on_disconnected = self.on_disconnected_controller.stream
@ -52,7 +53,7 @@ class StratumClientProtocol(LineOnlyReceiver):
socket.SOL_TCP, socket.TCP_KEEPCNT, 5 socket.SOL_TCP, socket.TCP_KEEPCNT, 5
# Failed keepalive probles before declaring other end dead # Failed keepalive probles before declaring other end dead
) )
except Exception as err: except Exception as err: # pylint: disable=broad-except
# Supported only by the socket transport, # Supported only by the socket transport,
# but there's really no better place in code to trigger this. # but there's really no better place in code to trigger this.
log.warning("Error setting up socket: %s", err) log.warning("Error setting up socket: %s", err)
@ -61,7 +62,7 @@ class StratumClientProtocol(LineOnlyReceiver):
self.on_disconnected_controller.add(True) self.on_disconnected_controller.add(True)
def lineReceived(self, line): def lineReceived(self, line):
log.debug('received: {}'.format(line)) log.debug('received: %s', line)
try: try:
message = json.loads(line) message = json.loads(line)
@ -82,7 +83,7 @@ class StratumClientProtocol(LineOnlyReceiver):
controller = self.network.subscription_controllers[message['method']] controller = self.network.subscription_controllers[message['method']]
controller.add(message.get('params')) controller.add(message.get('params'))
else: else:
log.warning("Cannot handle message '%s'" % line) log.warning("Cannot handle message '%s'", line)
def rpc(self, method, *args): def rpc(self, method, *args):
message_id = self._get_id() message_id = self._get_id()
@ -91,7 +92,7 @@ class StratumClientProtocol(LineOnlyReceiver):
'method': method, 'method': method,
'params': args 'params': args
}) })
log.debug('sent: {}'.format(message)) log.debug('sent: %s', message)
self.sendLine(message.encode('latin-1')) self.sendLine(message.encode('latin-1'))
d = self.lookup_table[message_id] = defer.Deferred() d = self.lookup_table[message_id] = defer.Deferred()
return d return d
@ -138,20 +139,21 @@ class BaseNetwork:
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
for server in cycle(self.config['default_servers']): for server in cycle(self.config['default_servers']):
endpoint = clientFromString(reactor, 'tcp:{}:{}'.format(*server)) connection_string = 'tcp:{}:{}'.format(*server)
log.debug("Attempting connection to SPV wallet server: {}:{}".format(*server)) endpoint = clientFromString(reactor, connection_string)
log.debug("Attempting connection to SPV wallet server: %s", connection_string)
self.service = ClientService(endpoint, StratumClientFactory(self)) self.service = ClientService(endpoint, StratumClientFactory(self))
self.service.startService() self.service.startService()
try: try:
self.client = yield self.service.whenConnected(failAfterFailures=2) self.client = yield self.service.whenConnected(failAfterFailures=2)
yield self.ensure_server_version() yield self.ensure_server_version()
log.info("Successfully connected to SPV wallet server: {}:{}".format(*server)) log.info("Successfully connected to SPV wallet server: %s", connection_string)
self._on_connected_controller.add(True) self._on_connected_controller.add(True)
yield self.client.on_disconnected.first yield self.client.on_disconnected.first
except CancelledError: except CancelledError:
return return
except Exception: except Exception: # pylint: disable=broad-except
log.exception("Connecting to {}:{} raised an exception:".format(*server)) log.exception("Connecting to %s raised an exception:", connection_string)
finally: finally:
self.client = None self.client = None
if not self.running: if not self.running:

View file

@ -1,6 +1,7 @@
from itertools import chain from itertools import chain
from binascii import hexlify from binascii import hexlify
from collections import namedtuple from collections import namedtuple
from typing import List
from torba.bcd_data_stream import BCDataStream from torba.bcd_data_stream import BCDataStream
from torba.util import subclass_tuple from torba.util import subclass_tuple
@ -25,17 +26,21 @@ OP_DROP = 0x75
# template matching opcodes (not real opcodes) # template matching opcodes (not real opcodes)
# base class for PUSH_DATA related opcodes # base class for PUSH_DATA related opcodes
# pylint: disable=invalid-name
PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name') PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name')
# opcode for variable length strings # opcode for variable length strings
# pylint: disable=invalid-name
PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP) PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP)
# opcode for variable number of variable length strings # opcode for variable number of variable length strings
# pylint: disable=invalid-name
PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP) PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP)
# opcode with embedded subscript parsing # opcode with embedded subscript parsing
# pylint: disable=invalid-name
PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template') PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template')
def is_push_data_opcode(opcode): def is_push_data_opcode(opcode):
return isinstance(opcode, PUSH_DATA_OP) or isinstance(opcode, PUSH_SUBSCRIPT) return isinstance(opcode, (PUSH_DATA_OP, PUSH_SUBSCRIPT))
def is_push_data_token(token): def is_push_data_token(token):
@ -61,15 +66,15 @@ def push_data(data):
def read_data(token, stream): def read_data(token, stream):
if token < OP_PUSHDATA1: if token < OP_PUSHDATA1:
return stream.read(token) return stream.read(token)
elif token == OP_PUSHDATA1: if token == OP_PUSHDATA1:
return stream.read(stream.read_uint8()) return stream.read(stream.read_uint8())
elif token == OP_PUSHDATA2: if token == OP_PUSHDATA2:
return stream.read(stream.read_uint16()) return stream.read(stream.read_uint16())
else:
return stream.read(stream.read_uint32()) return stream.read(stream.read_uint32())
# opcode for OP_1 - OP_16 # opcode for OP_1 - OP_16
# pylint: disable=invalid-name
SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name') SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name')
@ -233,7 +238,7 @@ class Parser:
raise ParseError("Not a push single or subscript: {}".format(opcode)) raise ParseError("Not a push single or subscript: {}".format(opcode))
class Template(object): class Template:
__slots__ = 'name', 'opcodes' __slots__ = 'name', 'opcodes'
@ -264,11 +269,11 @@ class Template(object):
return source.get_bytes() return source.get_bytes()
class Script(object): class Script:
__slots__ = 'source', 'template', 'values' __slots__ = 'source', 'template', 'values'
templates = [] templates: List[Template] = []
def __init__(self, source=None, template=None, values=None, template_hint=None): def __init__(self, source=None, template=None, values=None, template_hint=None):
self.source = source self.source = source

View file

@ -1,29 +1,30 @@
import six
import logging import logging
from typing import List, Iterable import typing
from typing import List, Iterable, Optional
from binascii import hexlify from binascii import hexlify
from twisted.internet import defer from twisted.internet import defer
import torba.baseaccount
import torba.baseledger
from torba.basescript import BaseInputScript, BaseOutputScript from torba.basescript import BaseInputScript, BaseOutputScript
from torba.coinselection import CoinSelector from torba.baseaccount import BaseAccount
from torba.constants import COIN, NULL_HASH32 from torba.constants import COIN, NULL_HASH32
from torba.bcd_data_stream import BCDataStream from torba.bcd_data_stream import BCDataStream
from torba.hash import sha256, TXRef, TXRefImmutable, TXORef from torba.hash import sha256, TXRef, TXRefImmutable
from torba.util import ReadOnlyList from torba.util import ReadOnlyList
if typing.TYPE_CHECKING:
from torba import baseledger
log = logging.getLogger() log = logging.getLogger()
class TXRefMutable(TXRef): class TXRefMutable(TXRef):
__slots__ = 'tx', __slots__ = ('tx',)
def __init__(self, tx): def __init__(self, tx: 'BaseTransaction') -> None:
super(TXRefMutable, self).__init__() super().__init__()
self.tx = tx self.tx = tx
@property @property
@ -43,12 +44,35 @@ class TXRefMutable(TXRef):
self._hash = None self._hash = None
class TXORef:
__slots__ = 'tx_ref', 'position'
def __init__(self, tx_ref: TXRef, position: int) -> None:
self.tx_ref = tx_ref
self.position = position
@property
def id(self):
return '{}:{}'.format(self.tx_ref.id, self.position)
@property
def is_null(self):
return self.tx_ref.is_null
@property
def txo(self) -> Optional['BaseOutput']:
return None
class TXORefResolvable(TXORef): class TXORefResolvable(TXORef):
__slots__ = '_txo', __slots__ = ('_txo',)
def __init__(self, txo): def __init__(self, txo: 'BaseOutput') -> None:
super(TXORefResolvable, self).__init__(txo.tx_ref, txo.position) assert txo.tx_ref is not None
assert txo.position is not None
super().__init__(txo.tx_ref, txo.position)
self._txo = txo self._txo = txo
@property @property
@ -56,23 +80,23 @@ class TXORefResolvable(TXORef):
return self._txo return self._txo
class InputOutput(object): class InputOutput:
__slots__ = 'tx_ref', 'position' __slots__ = 'tx_ref', 'position'
def __init__(self, tx_ref=None, position=None): def __init__(self, tx_ref: TXRef = None, position: int = None) -> None:
self.tx_ref = tx_ref # type: TXRef self.tx_ref = tx_ref
self.position = position # type: int self.position = position
@property @property
def size(self): def size(self) -> int:
""" Size of this input / output in bytes. """ """ Size of this input / output in bytes. """
stream = BCDataStream() stream = BCDataStream()
self.serialize_to(stream) self.serialize_to(stream)
return len(stream.get_bytes()) return len(stream.get_bytes())
def serialize_to(self, stream): def serialize_to(self, stream, alternate_script=None):
raise NotImplemented raise NotImplementedError
class BaseInput(InputOutput): class BaseInput(InputOutput):
@ -84,27 +108,27 @@ class BaseInput(InputOutput):
__slots__ = 'txo_ref', 'sequence', 'coinbase', 'script' __slots__ = 'txo_ref', 'sequence', 'coinbase', 'script'
def __init__(self, txo_ref, script, sequence=0xFFFFFFFF, tx_ref=None, position=None): def __init__(self, txo_ref: TXORef, script: BaseInputScript, sequence: int = 0xFFFFFFFF,
# type: (TXORef, BaseInputScript, int, TXRef, int) -> None tx_ref: TXRef = None, position: int = None) -> None:
super(BaseInput, self).__init__(tx_ref, position) super().__init__(tx_ref, position)
self.txo_ref = txo_ref self.txo_ref = txo_ref
self.sequence = sequence self.sequence = sequence
self.coinbase = script if txo_ref.is_null else None self.coinbase = script if txo_ref.is_null else None
self.script = script if not txo_ref.is_null else None # type: BaseInputScript self.script = script if not txo_ref.is_null else None
@property @property
def is_coinbase(self): def is_coinbase(self):
return self.coinbase is not None return self.coinbase is not None
@classmethod @classmethod
def spend(cls, txo): # type: (BaseOutput) -> BaseInput def spend(cls, txo: 'BaseOutput') -> 'BaseInput':
""" Create an input to spend the output.""" """ Create an input to spend the output."""
assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.' assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.'
script = cls.script_class.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY) script = cls.script_class.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY)
return cls(txo.ref, script) return cls(txo.ref, script)
@property @property
def amount(self): def amount(self) -> int:
""" Amount this input adds to the transaction. """ """ Amount this input adds to the transaction. """
if self.txo_ref.txo is None: if self.txo_ref.txo is None:
raise ValueError('Cannot resolve output to get amount.') raise ValueError('Cannot resolve output to get amount.')
@ -135,15 +159,15 @@ class BaseInput(InputOutput):
stream.write_uint32(self.sequence) stream.write_uint32(self.sequence)
class BaseOutputEffectiveAmountEstimator(object): class BaseOutputEffectiveAmountEstimator:
__slots__ = 'txo', 'txi', 'fee', 'effective_amount' __slots__ = 'txo', 'txi', 'fee', 'effective_amount'
def __init__(self, ledger, txo): # type: (torba.baseledger.BaseLedger, BaseOutput) -> None def __init__(self, ledger: 'baseledger.BaseLedger', txo: 'BaseOutput') -> None:
self.txo = txo self.txo = txo
self.txi = ledger.transaction_class.input_class.spend(txo) self.txi = ledger.transaction_class.input_class.spend(txo)
self.fee = ledger.get_input_output_fee(self.txi) self.fee: int = ledger.get_input_output_fee(self.txi)
self.effective_amount = txo.amount - self.fee self.effective_amount: int = txo.amount - self.fee
def __lt__(self, other): def __lt__(self, other):
return self.effective_amount < other.effective_amount return self.effective_amount < other.effective_amount
@ -156,9 +180,9 @@ class BaseOutput(InputOutput):
__slots__ = 'amount', 'script' __slots__ = 'amount', 'script'
def __init__(self, amount, script, tx_ref=None, position=None): def __init__(self, amount: int, script: BaseOutputScript,
# type: (int, BaseOutputScript, TXRef, int) -> None tx_ref: TXRef = None, position: int = None) -> None:
super(BaseOutput, self).__init__(tx_ref, position) super().__init__(tx_ref, position)
self.amount = amount self.amount = amount
self.script = script self.script = script
@ -184,7 +208,7 @@ class BaseOutput(InputOutput):
script=cls.script_class(stream.read_string()) script=cls.script_class(stream.read_string())
) )
def serialize_to(self, stream): def serialize_to(self, stream, alternate_script=None):
stream.write_uint64(self.amount) stream.write_uint64(self.amount)
stream.write_string(self.script.source) stream.write_string(self.script.source)
@ -194,7 +218,7 @@ class BaseTransaction:
input_class = BaseInput input_class = BaseInput
output_class = BaseOutput output_class = BaseOutput
def __init__(self, raw=None, version=1, locktime=0): def __init__(self, raw=None, version=1, locktime=0) -> None:
self._raw = raw self._raw = raw
self.ref = TXRefMutable(self) self.ref = TXRefMutable(self)
self.version = version # type: int self.version = version # type: int
@ -230,8 +254,7 @@ class BaseTransaction:
def outputs(self): # type: () -> ReadOnlyList[BaseOutput] def outputs(self): # type: () -> ReadOnlyList[BaseOutput]
return ReadOnlyList(self._outputs) return ReadOnlyList(self._outputs)
def _add(self, new_ios, existing_ios): def _add(self, new_ios: Iterable[InputOutput], existing_ios: List) -> 'BaseTransaction':
# type: (List[InputOutput], List[InputOutput]) -> BaseTransaction
for txio in new_ios: for txio in new_ios:
txio.tx_ref = self.ref txio.tx_ref = self.ref
txio.position = len(existing_ios) txio.position = len(existing_ios)
@ -239,28 +262,28 @@ class BaseTransaction:
self._reset() self._reset()
return self return self
def add_inputs(self, inputs): # type: (List[BaseInput]) -> BaseTransaction def add_inputs(self, inputs: Iterable[BaseInput]) -> 'BaseTransaction':
return self._add(inputs, self._inputs) return self._add(inputs, self._inputs)
def add_outputs(self, outputs): # type: (List[BaseOutput]) -> BaseTransaction def add_outputs(self, outputs: Iterable[BaseOutput]) -> 'BaseTransaction':
return self._add(outputs, self._outputs) return self._add(outputs, self._outputs)
@property @property
def fee(self): # type: () -> int def fee(self) -> int:
""" Fee that will actually be paid.""" """ Fee that will actually be paid."""
return self.input_sum - self.output_sum return self.input_sum - self.output_sum
@property @property
def size(self): # type: () -> int def size(self) -> int:
""" Size in bytes of the entire transaction. """ """ Size in bytes of the entire transaction. """
return len(self.raw) return len(self.raw)
@property @property
def base_size(self): # type: () -> int def base_size(self) -> int:
""" Size in bytes of transaction meta data and all outputs; without inputs. """ """ Size in bytes of transaction meta data and all outputs; without inputs. """
return len(self._serialize(with_inputs=False)) return len(self._serialize(with_inputs=False))
def _serialize(self, with_inputs=True): # type: (bool) -> bytes def _serialize(self, with_inputs: bool = True) -> bytes:
stream = BCDataStream() stream = BCDataStream()
stream.write_uint32(self.version) stream.write_uint32(self.version)
if with_inputs: if with_inputs:
@ -273,12 +296,13 @@ class BaseTransaction:
stream.write_uint32(self.locktime) stream.write_uint32(self.locktime)
return stream.get_bytes() return stream.get_bytes()
def _serialize_for_signature(self, signing_input): # type: (int) -> bytes def _serialize_for_signature(self, signing_input: int) -> bytes:
stream = BCDataStream() stream = BCDataStream()
stream.write_uint32(self.version) stream.write_uint32(self.version)
stream.write_compact_size(len(self._inputs)) stream.write_compact_size(len(self._inputs))
for i, txin in enumerate(self._inputs): for i, txin in enumerate(self._inputs):
if signing_input == i: if signing_input == i:
assert txin.txo_ref.txo is not None
txin.serialize_to(stream, txin.txo_ref.txo.script.source) txin.serialize_to(stream, txin.txo_ref.txo.script.source)
else: else:
txin.serialize_to(stream, b'') txin.serialize_to(stream, b'')
@ -304,8 +328,9 @@ class BaseTransaction:
self.locktime = stream.read_uint32() self.locktime = stream.read_uint32()
@classmethod @classmethod
def ensure_all_have_same_ledger(cls, funding_accounts, change_account=None): def ensure_all_have_same_ledger(
# type: (Iterable[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> torba.baseledger.BaseLedger cls, funding_accounts: Iterable[BaseAccount], change_account: BaseAccount = None)\
-> 'baseledger.BaseLedger':
ledger = None ledger = None
for account in funding_accounts: for account in funding_accounts:
if ledger is None: if ledger is None:
@ -316,33 +341,24 @@ class BaseTransaction:
) )
if change_account is not None and change_account.ledger != ledger: if change_account is not None and change_account.ledger != ledger:
raise ValueError('Change account must use same ledger as funding accounts.') raise ValueError('Change account must use same ledger as funding accounts.')
if ledger is None:
raise ValueError('No ledger found.')
return ledger return ledger
@classmethod @classmethod
@defer.inlineCallbacks @defer.inlineCallbacks
def pay(cls, outputs, funding_accounts, change_account, reserve_outputs=True): def pay(cls, outputs: Iterable[BaseOutput], funding_accounts: Iterable[BaseAccount],
# type: (List[BaseOutput], List[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> defer.Deferred change_account: BaseAccount):
""" Efficiently spend utxos from funding_accounts to cover the new outputs. """ """ Efficiently spend utxos from funding_accounts to cover the new outputs. """
tx = cls().add_outputs(outputs) tx = cls().add_outputs(outputs)
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
amount = tx.output_sum + ledger.get_transaction_base_fee(tx) amount = tx.output_sum + ledger.get_transaction_base_fee(tx)
txos = yield ledger.get_effective_amount_estimators(funding_accounts) spendables = yield ledger.get_spendable_utxos(amount, funding_accounts)
selector = CoinSelector(
txos, amount,
ledger.get_input_output_fee(
cls.output_class.pay_pubkey_hash(COIN, NULL_HASH32)
)
)
spendables = selector.select()
if not spendables: if not spendables:
raise ValueError('Not enough funds to cover this transaction.') raise ValueError('Not enough funds to cover this transaction.')
reserved_outputs = [s.txo.id for s in spendables]
if reserve_outputs:
yield ledger.db.reserve_spent_outputs(reserved_outputs)
try: try:
spent_sum = sum(s.effective_amount for s in spendables) spent_sum = sum(s.effective_amount for s in spendables)
if spent_sum > amount: if spent_sum > amount:
@ -351,30 +367,25 @@ class BaseTransaction:
change_amount = spent_sum - amount change_amount = spent_sum - amount
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
tx.add_inputs([s.txi for s in spendables]) tx.add_inputs(s.txi for s in spendables)
yield tx.sign(funding_accounts) yield tx.sign(funding_accounts)
except Exception: except Exception as e:
if reserve_outputs: log.exception('Failed to synchronize transaction:')
yield ledger.db.release_reserved_outputs(reserved_outputs) yield ledger.release_outputs(s.txo for s in spendables)
raise raise e
defer.returnValue(tx) defer.returnValue(tx)
@classmethod @classmethod
@defer.inlineCallbacks @defer.inlineCallbacks
def liquidate(cls, assets, funding_accounts, change_account, reserve_outputs=True): def liquidate(cls, assets, funding_accounts, change_account):
""" Spend assets (utxos) supplementing with funding_accounts if fee is higher than asset value. """ """ Spend assets (utxos) supplementing with funding_accounts if fee is higher than asset value. """
tx = cls().add_inputs([ tx = cls().add_inputs([
cls.input_class.spend(utxo) for utxo in assets cls.input_class.spend(utxo) for utxo in assets
]) ])
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
yield ledger.reserve_outputs(assets)
reserved_outputs = [utxo.id for utxo in assets]
if reserve_outputs:
yield ledger.db.reserve_spent_outputs(reserved_outputs)
try: try:
cost_of_change = ( cost_of_change = (
ledger.get_transaction_base_fee(tx) + ledger.get_transaction_base_fee(tx) +
@ -386,41 +397,35 @@ class BaseTransaction:
change_hash160 = change_account.ledger.address_to_hash160(change_address) change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = liquidated_total - cost_of_change change_amount = liquidated_total - cost_of_change
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
yield tx.sign(funding_accounts) yield tx.sign(funding_accounts)
except Exception: except Exception:
if reserve_outputs: yield ledger.release_outputs(assets)
yield ledger.db.release_reserved_outputs(reserved_outputs)
raise raise
defer.returnValue(tx) defer.returnValue(tx)
def signature_hash_type(self, hash_type): @staticmethod
def signature_hash_type(hash_type):
return hash_type return hash_type
@defer.inlineCallbacks @defer.inlineCallbacks
def sign(self, funding_accounts): # type: (Iterable[torba.baseaccount.BaseAccount]) -> BaseTransaction def sign(self, funding_accounts: Iterable[BaseAccount]) -> defer.Deferred:
ledger = self.ensure_all_have_same_ledger(funding_accounts) ledger = self.ensure_all_have_same_ledger(funding_accounts)
for i, txi in enumerate(self._inputs): for i, txi in enumerate(self._inputs):
assert txi.script is not None
assert txi.txo_ref.txo is not None
txo_script = txi.txo_ref.txo.script txo_script = txi.txo_ref.txo.script
if txo_script.is_pay_pubkey_hash: if txo_script.is_pay_pubkey_hash:
address = ledger.hash160_to_address(txo_script.values['pubkey_hash']) address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
private_key = yield ledger.get_private_key_for_address(address) private_key = yield ledger.get_private_key_for_address(address)
tx = self._serialize_for_signature(i) tx = self._serialize_for_signature(i)
txi.script.values['signature'] = \ txi.script.values['signature'] = \
private_key.sign(tx) + six.int2byte(self.signature_hash_type(1)) private_key.sign(tx) + bytes((self.signature_hash_type(1),))
txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
txi.script.generate() txi.script.generate()
else: else:
raise NotImplementedError("Don't know how to spend this output.") raise NotImplementedError("Don't know how to spend this output.")
self._reset() self._reset()
def sort(self):
# See https://github.com/kristovatlas/rfc/blob/master/bips/bip-li01.mediawiki
self._inputs.sort(key=lambda i: (i['prevout_hash'], i['prevout_n']))
self._outputs.sort(key=lambda o: (o[2], pay_script(o[0], o[1])))
@property @property
def input_sum(self): def input_sum(self):
return sum(i.amount for i in self.inputs) return sum(i.amount for i in self.inputs)

View file

@ -35,9 +35,9 @@ class BCDataStream:
return size return size
if size == 253: if size == 253:
return self.read_uint16() return self.read_uint16()
elif size == 254: if size == 254:
return self.read_uint32() return self.read_uint32()
elif size == 255: if size == 255:
return self.read_uint64() return self.read_uint64()
def write_compact_size(self, size): def write_compact_size(self, size):
@ -70,7 +70,7 @@ class BCDataStream:
def _read_struct(self, fmt): def _read_struct(self, fmt):
value = self.read(fmt.size) value = self.read(fmt.size)
if len(value) > 0: if value:
return fmt.unpack(value)[0] return fmt.unpack(value)[0]
def read_int8(self): def read_int8(self):

View file

@ -10,7 +10,6 @@
import struct import struct
import hashlib import hashlib
from six import int2byte, byte2int, indexbytes
import ecdsa import ecdsa
import ecdsa.ellipticcurve as EC import ecdsa.ellipticcurve as EC
@ -24,7 +23,7 @@ class DerivationError(Exception):
""" Raised when an invalid derivation occurs. """ """ Raised when an invalid derivation occurs. """
class _KeyBase(object): class _KeyBase:
""" A BIP32 Key, public or private. """ """ A BIP32 Key, public or private. """
CURVE = ecdsa.SECP256k1 CURVE = ecdsa.SECP256k1
@ -63,17 +62,23 @@ class _KeyBase(object):
if len(raw_serkey) != 33: if len(raw_serkey) != 33:
raise ValueError('raw_serkey must have length 33') raise ValueError('raw_serkey must have length 33')
return (ver_bytes + int2byte(self.depth) return (ver_bytes + bytes((self.depth,))
+ self.parent_fingerprint() + struct.pack('>I', self.n) + self.parent_fingerprint() + struct.pack('>I', self.n)
+ self.chain_code + raw_serkey) + self.chain_code + raw_serkey)
def identifier(self):
raise NotImplementedError
def extended_key(self):
raise NotImplementedError
def fingerprint(self): def fingerprint(self):
""" Return the key's fingerprint as 4 bytes. """ """ Return the key's fingerprint as 4 bytes. """
return self.identifier()[:4] return self.identifier()[:4]
def parent_fingerprint(self): def parent_fingerprint(self):
""" Return the parent key's fingerprint as 4 bytes. """ """ Return the parent key's fingerprint as 4 bytes. """
return self.parent.fingerprint() if self.parent else int2byte(0)*4 return self.parent.fingerprint() if self.parent else bytes((0,)*4)
def extended_key_string(self): def extended_key_string(self):
""" Return an extended key as a base58 string. """ """ Return an extended key as a base58 string. """
@ -84,7 +89,7 @@ class PubKey(_KeyBase):
""" A BIP32 public key. """ """ A BIP32 public key. """
def __init__(self, ledger, pubkey, chain_code, n, depth, parent=None): def __init__(self, ledger, pubkey, chain_code, n, depth, parent=None):
super(PubKey, self).__init__(ledger, chain_code, n, depth, parent) super().__init__(ledger, chain_code, n, depth, parent)
if isinstance(pubkey, ecdsa.VerifyingKey): if isinstance(pubkey, ecdsa.VerifyingKey):
self.verifying_key = pubkey self.verifying_key = pubkey
else: else:
@ -97,16 +102,16 @@ class PubKey(_KeyBase):
raise TypeError('pubkey must be raw bytes') raise TypeError('pubkey must be raw bytes')
if len(pubkey) != 33: if len(pubkey) != 33:
raise ValueError('pubkey must be 33 bytes') raise ValueError('pubkey must be 33 bytes')
if indexbytes(pubkey, 0) not in (2, 3): if pubkey[0] not in (2, 3):
raise ValueError('invalid pubkey prefix byte') raise ValueError('invalid pubkey prefix byte')
curve = cls.CURVE.curve curve = cls.CURVE.curve
is_odd = indexbytes(pubkey, 0) == 3 is_odd = pubkey[0] == 3
x = bytes_to_int(pubkey[1:]) x = bytes_to_int(pubkey[1:])
# p is the finite field order # p is the finite field order
a, b, p = curve.a(), curve.b(), curve.p() a, b, p = curve.a(), curve.b(), curve.p() # pylint: disable=invalid-name
y2 = pow(x, 3, p) + b y2 = pow(x, 3, p) + b # pylint: disable=invalid-name
assert a == 0 # Otherwise y2 += a * pow(x, 2, p) assert a == 0 # Otherwise y2 += a * pow(x, 2, p)
y = NT.square_root_mod_prime(y2 % p, p) y = NT.square_root_mod_prime(y2 % p, p)
if bool(y & 1) != is_odd: if bool(y & 1) != is_odd:
@ -119,7 +124,7 @@ class PubKey(_KeyBase):
def pubkey_bytes(self): def pubkey_bytes(self):
""" Return the compressed public key as 33 bytes. """ """ Return the compressed public key as 33 bytes. """
point = self.verifying_key.pubkey.point point = self.verifying_key.pubkey.point
prefix = int2byte(2 + (point.y() & 1)) prefix = bytes((2 + (point.y() & 1),))
padded_bytes = _exponent_to_bytes(point.x()) padded_bytes = _exponent_to_bytes(point.x())
return prefix + padded_bytes return prefix + padded_bytes
@ -137,10 +142,10 @@ class PubKey(_KeyBase):
raise ValueError('invalid BIP32 public key child number') raise ValueError('invalid BIP32 public key child number')
msg = self.pubkey_bytes + struct.pack('>I', n) msg = self.pubkey_bytes + struct.pack('>I', n)
L, R = self._hmac_sha512(msg) L, R = self._hmac_sha512(msg) # pylint: disable=invalid-name
curve = self.CURVE curve = self.CURVE
L = bytes_to_int(L) L = bytes_to_int(L) # pylint: disable=invalid-name
if L >= curve.order: if L >= curve.order:
raise DerivationError raise DerivationError
@ -172,7 +177,7 @@ class LowSValueSigningKey(ecdsa.SigningKey):
def sign_number(self, number, entropy=None, k=None): def sign_number(self, number, entropy=None, k=None):
order = self.privkey.order order = self.privkey.order
r, s = ecdsa.SigningKey.sign_number(self, number, entropy, k) r, s = ecdsa.SigningKey.sign_number(self, number, entropy, k) # pylint: disable=invalid-name
if s > order / 2: if s > order / 2:
s = order - s s = order - s
return r, s return r, s
@ -184,7 +189,7 @@ class PrivateKey(_KeyBase):
HARDENED = 1 << 31 HARDENED = 1 << 31
def __init__(self, ledger, privkey, chain_code, n, depth, parent=None): def __init__(self, ledger, privkey, chain_code, n, depth, parent=None):
super(PrivateKey, self).__init__(ledger, chain_code, n, depth, parent) super().__init__(ledger, chain_code, n, depth, parent)
if isinstance(privkey, ecdsa.SigningKey): if isinstance(privkey, ecdsa.SigningKey):
self.signing_key = privkey self.signing_key = privkey
else: else:
@ -254,10 +259,10 @@ class PrivateKey(_KeyBase):
serkey = self.public_key.pubkey_bytes serkey = self.public_key.pubkey_bytes
msg = serkey + struct.pack('>I', n) msg = serkey + struct.pack('>I', n)
L, R = self._hmac_sha512(msg) L, R = self._hmac_sha512(msg) # pylint: disable=invalid-name
curve = self.CURVE curve = self.CURVE
L = bytes_to_int(L) L = bytes_to_int(L) # pylint: disable=invalid-name
exponent = (L + bytes_to_int(self.private_key_bytes)) % curve.order exponent = (L + bytes_to_int(self.private_key_bytes)) % curve.order
if exponent == 0 or L >= curve.order: if exponent == 0 or L >= curve.order:
raise DerivationError raise DerivationError
@ -286,7 +291,7 @@ class PrivateKey(_KeyBase):
def _exponent_to_bytes(exponent): def _exponent_to_bytes(exponent):
"""Convert an exponent to 32 big-endian bytes""" """Convert an exponent to 32 big-endian bytes"""
return (int2byte(0)*32 + int_to_bytes(exponent))[-32:] return (bytes((0,)*32) + int_to_bytes(exponent))[-32:]
def _from_extended_key(ledger, ekey): def _from_extended_key(ledger, ekey):
@ -296,8 +301,8 @@ def _from_extended_key(ledger, ekey):
if len(ekey) != 78: if len(ekey) != 78:
raise ValueError('extended key must have length 78') raise ValueError('extended key must have length 78')
depth = indexbytes(ekey, 4) depth = ekey[4]
fingerprint = ekey[5:9] # Not used # fingerprint = ekey[5:9]
n, = struct.unpack('>I', ekey[9:13]) n, = struct.unpack('>I', ekey[9:13])
chain_code = ekey[13:45] chain_code = ekey[13:45]
@ -305,7 +310,7 @@ def _from_extended_key(ledger, ekey):
pubkey = ekey[45:] pubkey = ekey[45:]
key = PubKey(ledger, pubkey, chain_code, n, depth) key = PubKey(ledger, pubkey, chain_code, n, depth)
elif ekey[:4] == ledger.extended_private_key_prefix: elif ekey[:4] == ledger.extended_private_key_prefix:
if indexbytes(ekey, 45) != 0: if ekey[45] != 0:
raise ValueError('invalid extended private key prefix byte') raise ValueError('invalid extended private key prefix byte')
privkey = ekey[46:] privkey = ekey[46:]
key = PrivateKey(ledger, privkey, chain_code, n, depth) key = PrivateKey(ledger, privkey, chain_code, n, depth)

View file

@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__) __path__: str = __import__('pkgutil').extend_path(__path__, __name__)

View file

@ -6,7 +6,6 @@ __node_url__ = (
) )
__electrumx__ = 'electrumx.lib.coins.BitcoinCashRegtest' __electrumx__ = 'electrumx.lib.coins.BitcoinCashRegtest'
from six import int2byte
from binascii import unhexlify from binascii import unhexlify
from torba.baseledger import BaseLedger from torba.baseledger import BaseLedger
from torba.baseheader import BaseHeaders from torba.baseheader import BaseHeaders
@ -26,8 +25,8 @@ class MainNetLedger(BaseLedger):
transaction_class = Transaction transaction_class = Transaction
pubkey_address_prefix = int2byte(0x00) pubkey_address_prefix = bytes((0,))
script_address_prefix = int2byte(0x05) script_address_prefix = bytes((5,))
extended_public_key_prefix = unhexlify('0488b21e') extended_public_key_prefix = unhexlify('0488b21e')
extended_private_key_prefix = unhexlify('0488ade4') extended_private_key_prefix = unhexlify('0488ade4')
@ -42,8 +41,8 @@ class RegTestLedger(MainNetLedger):
headers_class = UnverifiedHeaders headers_class = UnverifiedHeaders
network_name = 'regtest' network_name = 'regtest'
pubkey_address_prefix = int2byte(111) pubkey_address_prefix = bytes((111,))
script_address_prefix = int2byte(196) script_address_prefix = bytes((196,))
extended_public_key_prefix = unhexlify('043587cf') extended_public_key_prefix = unhexlify('043587cf')
extended_private_key_prefix = unhexlify('04358394') extended_private_key_prefix = unhexlify('04358394')

View file

@ -6,7 +6,6 @@ __node_url__ = (
) )
__electrumx__ = 'electrumx.lib.coins.BitcoinSegwitRegtest' __electrumx__ = 'electrumx.lib.coins.BitcoinSegwitRegtest'
from six import int2byte
from binascii import unhexlify from binascii import unhexlify
from torba.baseledger import BaseLedger from torba.baseledger import BaseLedger
from torba.baseheader import BaseHeaders from torba.baseheader import BaseHeaders
@ -17,8 +16,8 @@ class MainNetLedger(BaseLedger):
symbol = 'BTC' symbol = 'BTC'
network_name = 'mainnet' network_name = 'mainnet'
pubkey_address_prefix = int2byte(0x00) pubkey_address_prefix = bytes((0,))
script_address_prefix = int2byte(0x05) script_address_prefix = bytes((5,))
extended_public_key_prefix = unhexlify('0488b21e') extended_public_key_prefix = unhexlify('0488b21e')
extended_private_key_prefix = unhexlify('0488ade4') extended_private_key_prefix = unhexlify('0488ade4')
@ -33,8 +32,8 @@ class RegTestLedger(MainNetLedger):
headers_class = UnverifiedHeaders headers_class = UnverifiedHeaders
network_name = 'regtest' network_name = 'regtest'
pubkey_address_prefix = int2byte(111) pubkey_address_prefix = bytes((111,))
script_address_prefix = int2byte(196) script_address_prefix = bytes((196,))
extended_public_key_prefix = unhexlify('043587cf') extended_public_key_prefix = unhexlify('043587cf')
extended_private_key_prefix = unhexlify('04358394') extended_private_key_prefix = unhexlify('04358394')

View file

@ -1,16 +1,15 @@
import six
from random import Random from random import Random
from typing import List from typing import List
import torba from torba import basetransaction
MAXIMUM_TRIES = 100000 MAXIMUM_TRIES = 100000
class CoinSelector: class CoinSelector:
def __init__(self, txos, target, cost_of_change, seed=None): def __init__(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
# type: (List[torba.basetransaction.BaseOutputAmountEstimator], int, int, str) -> None target: int, cost_of_change: int, seed: str = None) -> None:
self.txos = txos self.txos = txos
self.target = target self.target = target
self.cost_of_change = cost_of_change self.cost_of_change = cost_of_change
@ -18,17 +17,17 @@ class CoinSelector:
self.tries = 0 self.tries = 0
self.available = sum(c.effective_amount for c in self.txos) self.available = sum(c.effective_amount for c in self.txos)
self.random = Random(seed) self.random = Random(seed)
if six.PY3 and seed is not None: if seed is not None:
self.random.seed(seed, version=1) self.random.seed(seed, version=1)
def select(self): # type: () -> List[torba.basetransaction.BaseOutputAmountEstimator] def select(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
if not self.txos: if not self.txos:
return return []
if self.target > self.available: if self.target > self.available:
return return []
return self.branch_and_bound() or self.single_random_draw() return self.branch_and_bound() or self.single_random_draw()
def branch_and_bound(self): # type: () -> List[torba.basetransaction.BaseOutputAmountEstimator] def branch_and_bound(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
# see bitcoin implementation for more info: # see bitcoin implementation for more info:
# https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp # https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp
@ -36,9 +35,9 @@ class CoinSelector:
current_value = 0 current_value = 0
current_available_value = self.available current_available_value = self.available
current_selection = [] current_selection: List[bool] = []
best_waste = self.cost_of_change best_waste = self.cost_of_change
best_selection = [] best_selection: List[bool] = []
while self.tries < MAXIMUM_TRIES: while self.tries < MAXIMUM_TRIES:
self.tries += 1 self.tries += 1
@ -70,7 +69,7 @@ class CoinSelector:
utxo = self.txos[len(current_selection)] utxo = self.txos[len(current_selection)]
current_available_value -= utxo.effective_amount current_available_value -= utxo.effective_amount
previous_utxo = self.txos[len(current_selection) - 1] if current_selection else None previous_utxo = self.txos[len(current_selection) - 1] if current_selection else None
if current_selection and not current_selection[-1] and \ if current_selection and not current_selection[-1] and previous_utxo and \
utxo.effective_amount == previous_utxo.effective_amount and \ utxo.effective_amount == previous_utxo.effective_amount and \
utxo.fee == previous_utxo.fee: utxo.fee == previous_utxo.fee:
current_selection.append(False) current_selection.append(False)
@ -84,7 +83,9 @@ class CoinSelector:
self.txos[i] for i, include in enumerate(best_selection) if include self.txos[i] for i, include in enumerate(best_selection) if include
] ]
def single_random_draw(self): # type: () -> List[torba.basetransaction.BaseOutputAmountEstimator] return []
def single_random_draw(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
self.random.shuffle(self.txos, self.random.random) self.random.shuffle(self.txos, self.random.random)
selection = [] selection = []
amount = 0 amount = 0
@ -93,3 +94,4 @@ class CoinSelector:
amount += coin.effective_amount amount += coin.effective_amount
if amount >= self.target+self.cost_of_change: if amount >= self.target+self.cost_of_change:
return selection return selection
return []

View file

@ -9,7 +9,6 @@
""" Cryptography hash functions and related classes. """ """ Cryptography hash functions and related classes. """
import os import os
import six
import base64 import base64
import hashlib import hashlib
import hmac import hmac
@ -22,13 +21,8 @@ from cryptography.hazmat.backends import default_backend
from torba.util import bytes_to_int, int_to_bytes from torba.util import bytes_to_int, int_to_bytes
from torba.constants import NULL_HASH32 from torba.constants import NULL_HASH32
_sha256 = hashlib.sha256
_sha512 = hashlib.sha512
_new_hash = hashlib.new
_new_hmac = hmac.new
class TXRef:
class TXRef(object):
__slots__ = '_id', '_hash' __slots__ = '_id', '_hash'
@ -68,50 +62,29 @@ class TXRefImmutable(TXRef):
return ref return ref
class TXORef(object):
__slots__ = 'tx_ref', 'position'
def __init__(self, tx_ref, position): # type: (TXRef, int) -> None
self.tx_ref = tx_ref
self.position = position
@property
def id(self):
return '{}:{}'.format(self.tx_ref.id, self.position)
@property
def is_null(self):
return self.tx_ref.is_null
@property
def txo(self):
return None
def sha256(x): def sha256(x):
""" Simple wrapper of hashlib sha256. """ """ Simple wrapper of hashlib sha256. """
return _sha256(x).digest() return hashlib.sha256(x).digest()
def sha512(x): def sha512(x):
""" Simple wrapper of hashlib sha512. """ """ Simple wrapper of hashlib sha512. """
return _sha512(x).digest() return hashlib.sha512(x).digest()
def ripemd160(x): def ripemd160(x):
""" Simple wrapper of hashlib ripemd160. """ """ Simple wrapper of hashlib ripemd160. """
h = _new_hash('ripemd160') h = hashlib.new('ripemd160')
h.update(x) h.update(x)
return h.digest() return h.digest()
def pow_hash(x): def pow_hash(x):
r = sha512(double_sha256(x)) h = sha512(double_sha256(x))
r1 = ripemd160(r[:len(r) // 2]) return double_sha256(
r2 = ripemd160(r[len(r) // 2:]) ripemd160(h[:len(h) // 2]) +
r3 = double_sha256(r1 + r2) ripemd160(h[len(h) // 2:])
return r3 )
def double_sha256(x): def double_sha256(x):
@ -121,7 +94,7 @@ def double_sha256(x):
def hmac_sha512(key, msg): def hmac_sha512(key, msg):
""" Use SHA-512 to provide an HMAC. """ """ Use SHA-512 to provide an HMAC. """
return _new_hmac(key, msg, _sha512).digest() return hmac.new(key, msg, hashlib.sha512).digest()
def hash160(x): def hash160(x):
@ -165,7 +138,7 @@ class Base58Error(Exception):
""" Exception used for Base58 errors. """ """ Exception used for Base58 errors. """
class Base58(object): class Base58:
""" Class providing base 58 functionality. """ """ Class providing base 58 functionality. """
chars = u'123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz' chars = u'123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'
@ -207,7 +180,7 @@ class Base58(object):
break break
count += 1 count += 1
if count: if count:
result = six.int2byte(0) * count + result result = bytes((0,)) * count + result
return result return result

View file

@ -56,7 +56,7 @@ CJK_INTERVALS = [
def is_cjk(c): def is_cjk(c):
n = ord(c) n = ord(c)
for start, end, name in CJK_INTERVALS: for start, end, _ in CJK_INTERVALS:
if start <= n <= end: if start <= n <= end:
return True return True
return False return False
@ -93,7 +93,7 @@ def load_words(filename):
return words return words
file_names = { FILE_NAMES = {
'en': 'english.txt', 'en': 'english.txt',
'es': 'spanish.txt', 'es': 'spanish.txt',
'ja': 'japanese.txt', 'ja': 'japanese.txt',
@ -102,20 +102,22 @@ file_names = {
} }
class Mnemonic(object): class Mnemonic:
# Seed derivation no longer follows BIP39 # Seed derivation no longer follows BIP39
# Mnemonic phrase uses a hash based checksum, instead of a words-dependent checksum # Mnemonic phrase uses a hash based checksum, instead of a words-dependent checksum
def __init__(self, lang='en'): def __init__(self, lang='en'):
filename = file_names.get(lang, 'english.txt') filename = FILE_NAMES.get(lang, 'english.txt')
self.words = load_words(filename) self.words = load_words(filename)
@classmethod @staticmethod
def mnemonic_to_seed(self, mnemonic, passphrase=u''): def mnemonic_to_seed(mnemonic, passphrase=u''):
PBKDF2_ROUNDS = 2048 pbkdf2_rounds = 2048
mnemonic = normalize_text(mnemonic) mnemonic = normalize_text(mnemonic)
passphrase = normalize_text(passphrase) passphrase = normalize_text(passphrase)
return pbkdf2.PBKDF2(mnemonic, passphrase, iterations=PBKDF2_ROUNDS, macmodule=hmac, digestmodule=hashlib.sha512).read(64) return pbkdf2.PBKDF2(
mnemonic, passphrase, iterations=pbkdf2_rounds, macmodule=hmac, digestmodule=hashlib.sha512
).read(64)
def mnemonic_encode(self, i): def mnemonic_encode(self, i):
n = len(self.words) n = len(self.words)
@ -131,8 +133,8 @@ class Mnemonic(object):
words = seed.split() words = seed.split()
i = 0 i = 0
while words: while words:
w = words.pop() word = words.pop()
k = self.words.index(w) k = self.words.index(word)
i = i*n + k i = i*n + k
return i return i

View file

@ -1,27 +1,7 @@
import six import asyncio
from twisted.internet.defer import Deferred, DeferredLock, maybeDeferred, inlineCallbacks from twisted.internet.defer import Deferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
if six.PY3:
import asyncio
def execute_serially(f):
_lock = DeferredLock()
@inlineCallbacks
def allow_only_one_at_a_time(*args, **kwargs):
yield _lock.acquire()
allow_only_one_at_a_time.is_running = True
try:
yield maybeDeferred(f, *args, **kwargs)
finally:
allow_only_one_at_a_time.is_running = False
_lock.release()
allow_only_one_at_a_time.is_running = False
return allow_only_one_at_a_time
class BroadcastSubscription: class BroadcastSubscription:
@ -76,10 +56,10 @@ class StreamController:
@property @property
def _iterate_subscriptions(self): def _iterate_subscriptions(self):
next = self._first_subscription next_sub = self._first_subscription
while next is not None: while next_sub is not None:
subscription = next subscription = next_sub
next = next._next next_sub = next_sub._next
yield subscription yield subscription
def add(self, event): def add(self, event):
@ -96,15 +76,15 @@ class StreamController:
def _cancel(self, subscription): def _cancel(self, subscription):
previous = subscription._previous previous = subscription._previous
next = subscription._next next_sub = subscription._next
if previous is None: if previous is None:
self._first_subscription = next self._first_subscription = next_sub
else: else:
previous._next = next previous._next = next_sub
if next is None: if next_sub is None:
self._last_subscription = previous self._last_subscription = previous
else: else:
next._previous = previous next_sub._previous = previous
subscription._next = subscription._previous = subscription subscription._next = subscription._previous = subscription
def _listen(self, on_data, on_error, on_done): def _listen(self, on_data, on_error, on_done):

View file

@ -1,20 +1,19 @@
from binascii import unhexlify, hexlify from binascii import unhexlify, hexlify
from collections import Sequence from typing import TypeVar, Sequence
from typing import TypeVar, Generic
T = TypeVar('T') T = TypeVar('T')
class ReadOnlyList(Sequence, Generic[T]): class ReadOnlyList(Sequence[T]):
def __init__(self, lst): def __init__(self, lst):
self.lst = lst self.lst = lst
def __getitem__(self, key): # type: (int) -> T def __getitem__(self, key):
return self.lst[key] return self.lst[key]
def __len__(self): def __len__(self) -> int:
return len(self.lst) return len(self.lst)
@ -22,13 +21,13 @@ def subclass_tuple(name, base):
return type(name, (base,), {'__slots__': ()}) return type(name, (base,), {'__slots__': ()})
class cachedproperty(object): class cachedproperty:
def __init__(self, f): def __init__(self, f):
self.f = f self.f = f
def __get__(self, obj, type): def __get__(self, obj, objtype):
obj = obj or type obj = obj or objtype
value = self.f(obj) value = self.f(obj)
setattr(obj, self.f.__name__, value) setattr(obj, self.f.__name__, value)
return value return value
@ -42,8 +41,8 @@ def bytes_to_int(be_bytes):
def int_to_bytes(value): def int_to_bytes(value):
""" Converts an integer to a big-endian sequence of bytes. """ """ Converts an integer to a big-endian sequence of bytes. """
length = (value.bit_length() + 7) // 8 length = (value.bit_length() + 7) // 8
h = '%x' % value s = '%x' % value
return unhexlify(('0' * (len(h) % 2) + h).zfill(length * 2)) return unhexlify(('0' * (len(s) % 2) + s).zfill(length * 2))
def rev_hex(s): def rev_hex(s):
@ -56,8 +55,8 @@ def int_to_hex(i, length=1):
return rev_hex(s) return rev_hex(s)
def hex_to_int(s): def hex_to_int(x):
return int(b'0x' + hexlify(s[::-1]), 16) return int(b'0x' + hexlify(x[::-1]), 16)
def hash_encode(x): def hash_encode(x):

View file

@ -1,10 +1,13 @@
import stat import stat
import json import json
import os import os
from typing import List import typing
from typing import Sequence, MutableSequence
import torba.baseaccount if typing.TYPE_CHECKING:
import torba.baseledger from torba import baseaccount
from torba import baseledger
from torba import basemanager
class Wallet: class Wallet:
@ -14,24 +17,24 @@ class Wallet:
by physical files on the filesystem. by physical files on the filesystem.
""" """
def __init__(self, name='Wallet', accounts=None, storage=None): def __init__(self, name: str = 'Wallet', accounts: MutableSequence['baseaccount.BaseAccount'] = None,
# type: (str, List[torba.baseaccount.BaseAccount], WalletStorage) -> None storage: 'WalletStorage' = None) -> None:
self.name = name self.name = name
self.accounts = accounts or [] # type: List[torba.baseaccount.BaseAccount] self.accounts = accounts or []
self.storage = storage or WalletStorage() self.storage = storage or WalletStorage()
def generate_account(self, ledger): def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount':
# type: (torba.baseledger.BaseLedger) -> torba.baseaccount.BaseAccount
account = ledger.account_class.generate(ledger, u'torba') account = ledger.account_class.generate(ledger, u'torba')
self.accounts.append(account) self.accounts.append(account)
return account return account
@classmethod @classmethod
def from_storage(cls, storage, manager): # type: (WalletStorage, 'WalletManager') -> Wallet def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet':
json_dict = storage.read() json_dict = storage.read()
accounts = [] accounts = []
for account_dict in json_dict.get('accounts', []): account_dicts: Sequence[dict] = json_dict.get('accounts', [])
for account_dict in account_dicts:
ledger = manager.get_or_create_ledger(account_dict['ledger']) ledger = manager.get_or_create_ledger(account_dict['ledger'])
account = ledger.account_class.from_dict(ledger, account_dict) account = ledger.account_class.from_dict(ledger, account_dict)
accounts.append(account) accounts.append(account)
@ -110,7 +113,7 @@ class WalletStorage:
mode = stat.S_IREAD | stat.S_IWRITE mode = stat.S_IREAD | stat.S_IWRITE
try: try:
os.rename(temp_path, self.path) os.rename(temp_path, self.path)
except: except Exception: # pylint: disable=broad-except
os.remove(self.path) os.remove(self.path)
os.rename(temp_path, self.path) os.rename(temp_path, self.path)
os.chmod(self.path, mode) os.chmod(self.path, mode)