+ reserving outpoints should no longer have race conditions

+ converted all comment type annotations to python 3 style syntax annotations
+ pylint & mypy
This commit is contained in:
Lex Berezhny 2018-07-28 20:52:54 -04:00
parent 746aef474a
commit b0bd0b1fc0
26 changed files with 482 additions and 447 deletions

View file

@ -5,13 +5,25 @@ language: python
python:
- "3.7"
install:
- pip install tox-travis coverage
- pushd .. && git clone https://github.com/lbryio/electrumx.git --branch lbryumx && popd
- pushd .. && git clone https://github.com/lbryio/orchstr8.git && popd
jobs:
include:
script: tox
- stage: code quality
name: "pylint & mypy"
install:
- pip install pylint mypy
- pip install -e .
script:
- pylint --rcfile=setup.cfg torba
- mypy torba
after_success:
- coverage combine tests/
- bash <(curl -s https://codecov.io/bash)
- stage: test
name: "Unit Tests"
install:
- pip install tox-travis coverage
- pushd .. && git clone https://github.com/lbryio/electrumx.git --branch lbryumx && popd
- pushd .. && git clone https://github.com/lbryio/orchstr8.git && popd
script: tox
after_success:
- coverage combine tests/
- bash <(curl -s https://codecov.io/bash)

View file

@ -5,3 +5,26 @@ branch = True
source =
torba
.tox/*/lib/python*/site-packages/torba
[mypy-twisted.*,cryptography.*,ecdsa.*,pbkdf2]
ignore_missing_imports = True
[pylint]
max-args=10
max-line-length=110
good-names=T,t,n,i,j,k,x,y,s,f,d,h,c,e,op,db,tx,io,cachedproperty,log,id
valid-metaclass-classmethod-first-arg=mcs
disable=
fixme,
no-else-return,
cyclic-import,
missing-docstring,
duplicate-code,
expression-not-assigned,
inconsistent-return-statements,
too-few-public-methods,
too-many-locals,
too-many-arguments,
too-many-public-methods,
too-many-instance-attributes,
protected-access

View file

@ -32,8 +32,7 @@ setup(
'twisted',
'ecdsa',
'pbkdf2',
'cryptography',
'typing'
'cryptography'
),
extras_require={
'test': (

View file

@ -30,13 +30,13 @@ class BaseSelectionTestCase(unittest.TestCase):
class TestCoinSelectionTests(BaseSelectionTestCase):
def test_empty_coins(self):
self.assertIsNone(CoinSelector([], 0, 0).select())
self.assertEqual(CoinSelector([], 0, 0).select(), [])
def test_skip_binary_search_if_total_not_enough(self):
fee = utxo(CENT).get_estimator(self.ledger).fee
big_pool = self.estimates(utxo(CENT+fee) for _ in range(100))
selector = CoinSelector(big_pool, 101 * CENT, 0)
self.assertIsNone(selector.select())
self.assertEqual(selector.select(), [])
self.assertEqual(selector.tries, 0) # Never tried.
# check happy path
selector = CoinSelector(big_pool, 100 * CENT, 0)
@ -108,7 +108,7 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
self.assertEqual([3 * CENT, 2 * CENT], search(utxo_pool, 5 * CENT, 0.5 * CENT))
# Select 11 Cent, not possible
self.assertIsNone(search(utxo_pool, 11 * CENT, 0.5 * CENT))
self.assertEqual(search(utxo_pool, 11 * CENT, 0.5 * CENT), [])
# Select 10 Cent
utxo_pool += self.estimates(utxo(5 * CENT))
@ -126,12 +126,12 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
)
# Select 0.25 Cent, not possible
self.assertIsNone(search(utxo_pool, 0.25 * CENT, 0.5 * CENT))
self.assertEqual(search(utxo_pool, 0.25 * CENT, 0.5 * CENT), [])
# Iteration exhaustion test
utxo_pool, target = self.make_hard_case(17)
selector = CoinSelector(utxo_pool, target, 0)
self.assertIsNone(selector.branch_and_bound())
self.assertEqual(selector.branch_and_bound(), [])
self.assertEqual(selector.tries, MAXIMUM_TRIES) # Should exhaust
utxo_pool, target = self.make_hard_case(14)
self.assertIsNotNone(search(utxo_pool, target, 0)) # Should not exhaust
@ -152,4 +152,4 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
# Select 1 Cent with pool of only greater than 5 Cent
utxo_pool = self.estimates(utxo(i * CENT) for i in range(5, 21))
for _ in range(100):
self.assertIsNone(search(utxo_pool, 1 * CENT, 2 * CENT))
self.assertEqual(search(utxo_pool, 1 * CENT, 2 * CENT), [])

View file

@ -1,4 +1,3 @@
import six
from binascii import hexlify
from twisted.trial import unittest
from twisted.internet import defer
@ -7,9 +6,6 @@ from torba.coin.bitcoinsegwit import MainNetLedger
from .test_transaction import get_transaction, get_output
if six.PY3:
buffer = memoryview
class MockNetwork:
@ -50,9 +46,7 @@ class MainNetTestLedger(MainNetLedger):
network_name = 'unittest'
def __init__(self):
super(MainNetLedger, self).__init__({
'db': MainNetLedger.database_class(':memory:')
})
super().__init__({'db': MainNetLedger.database_class(':memory:')})
class LedgerTestCase(unittest.TestCase):

View file

@ -3,14 +3,14 @@ from twisted.trial import unittest
from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger
from torba.coin.bitcoincash import MainNetLedger as BCHLedger
from torba.manager import WalletManager
from torba.basemanager import BaseWalletManager
from torba.wallet import Wallet, WalletStorage
class TestWalletCreation(unittest.TestCase):
def setUp(self):
self.manager = WalletManager()
self.manager = BaseWalletManager()
config = {'data_path': '/tmp/wallet'}
self.btc_ledger = self.manager.get_or_create_ledger(BTCLedger.get_id(), config)
self.bch_ledger = self.manager.get_or_create_ledger(BCHLedger.get_id(), config)
@ -63,7 +63,7 @@ class TestWalletCreation(unittest.TestCase):
self.assertDictEqual(wallet_dict, wallet.to_dict())
def test_read_write(self):
manager = WalletManager()
manager = BaseWalletManager()
config = {'data_path': '/tmp/wallet'}
ledger = manager.get_or_create_ledger(BTCLedger.get_id(), config)

View file

@ -1,2 +1,2 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
__path__: str = __import__('pkgutil').extend_path(__path__, __name__)
__version__ = '0.0.4'

View file

@ -1,12 +1,16 @@
from typing import Dict
import typing
from typing import Sequence
from twisted.internet import defer
from torba.mnemonic import Mnemonic
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
from torba.hash import double_sha256, aes_encrypt, aes_decrypt
if typing.TYPE_CHECKING:
from torba import baseledger
class KeyManager(object):
class KeyManager:
__slots__ = 'account', 'public_key', 'chain_number'
@ -19,27 +23,27 @@ class KeyManager(object):
def db(self):
return self.account.ledger.db
def _query_addresses(self, limit=None, max_used_times=None, order_by=None):
def _query_addresses(self, limit: int = None, max_used_times: int = None, order_by=None):
return self.db.get_addresses(
self.account, self.chain_number, limit, max_used_times, order_by
)
def get_max_gap(self): # type: () -> defer.Deferred
def get_max_gap(self) -> defer.Deferred:
raise NotImplementedError
def ensure_address_gap(self): # type: () -> defer.Deferred
def ensure_address_gap(self) -> defer.Deferred:
raise NotImplementedError
def get_address_records(self, limit=None, only_usable=False): # type: (int, bool) -> defer.Deferred
def get_address_records(self, limit: int = None, only_usable: bool = False) -> defer.Deferred:
raise NotImplementedError
@defer.inlineCallbacks
def get_addresses(self, limit=None, only_usable=False): # type: (int, bool) -> defer.Deferred
def get_addresses(self, limit: int = None, only_usable: bool = False) -> defer.Deferred:
records = yield self.get_address_records(limit=limit, only_usable=only_usable)
defer.returnValue([r['address'] for r in records])
@defer.inlineCallbacks
def get_or_create_usable_address(self): # type: () -> defer.Deferred
def get_or_create_usable_address(self) -> defer.Deferred:
addresses = yield self.get_addresses(limit=1, only_usable=True)
if addresses:
defer.returnValue(addresses[0])
@ -52,14 +56,14 @@ class KeyChain(KeyManager):
__slots__ = 'gap', 'maximum_uses_per_address'
def __init__(self, account, root_public_key, chain_number, gap, maximum_uses_per_address):
# type: ('BaseAccount', PubKey, int, int, int) -> None
super(KeyChain, self).__init__(account, root_public_key.child(chain_number), chain_number)
def __init__(self, account: 'BaseAccount', root_public_key: PubKey,
chain_number: int, gap: int, maximum_uses_per_address: int) -> None:
super().__init__(account, root_public_key.child(chain_number), chain_number)
self.gap = gap
self.maximum_uses_per_address = maximum_uses_per_address
@defer.inlineCallbacks
def generate_keys(self, start, end):
def generate_keys(self, start: int, end: int) -> defer.Deferred:
new_keys = []
for index in range(start, end+1):
new_keys.append((index, self.public_key.child(index)))
@ -69,7 +73,7 @@ class KeyChain(KeyManager):
defer.returnValue([key[1].address for key in new_keys])
@defer.inlineCallbacks
def get_max_gap(self):
def get_max_gap(self) -> defer.Deferred:
addresses = yield self._query_addresses(order_by="position ASC")
max_gap = 0
current_gap = 0
@ -82,7 +86,7 @@ class KeyChain(KeyManager):
defer.returnValue(max_gap)
@defer.inlineCallbacks
def ensure_address_gap(self):
def ensure_address_gap(self) -> defer.Deferred:
addresses = yield self._query_addresses(self.gap, None, "position DESC")
existing_gap = 0
@ -100,7 +104,7 @@ class KeyChain(KeyManager):
new_keys = yield self.generate_keys(start, end-1)
defer.returnValue(new_keys)
def get_address_records(self, limit=None, only_usable=False):
def get_address_records(self, limit: int = None, only_usable: bool = False):
return self._query_addresses(
limit, self.maximum_uses_per_address if only_usable else None,
"used_times ASC, position ASC"
@ -112,15 +116,11 @@ class SingleKey(KeyManager):
__slots__ = ()
def __init__(self, account, root_public_key, chain_number):
# type: ('BaseAccount', PubKey) -> None
super(SingleKey, self).__init__(account, root_public_key, chain_number)
def get_max_gap(self):
def get_max_gap(self) -> defer.Deferred:
return defer.succeed(0)
@defer.inlineCallbacks
def ensure_address_gap(self):
def ensure_address_gap(self) -> defer.Deferred:
exists = yield self.get_address_records()
if not exists:
yield self.db.add_keys(
@ -129,20 +129,20 @@ class SingleKey(KeyManager):
defer.returnValue([self.public_key.address])
defer.returnValue([])
def get_address_records(self, **kwargs):
def get_address_records(self, limit: int = None, only_usable: bool = False) -> defer.Deferred:
return self._query_addresses()
class BaseAccount(object):
class BaseAccount:
mnemonic_class = Mnemonic
private_key_class = PrivateKey
public_key_class = PubKey
def __init__(self, ledger, name, seed, encrypted, is_hd, private_key,
public_key, receiving_gap=20, change_gap=6,
receiving_maximum_uses_per_address=2, change_maximum_uses_per_address=2):
# type: (torba.baseledger.BaseLedger, str, str, bool, bool, PrivateKey, PubKey, int, int, int, int) -> None
def __init__(self, ledger: 'baseledger.BaseLedger', name: str, seed: str, encrypted: bool, is_hd: bool,
private_key: PrivateKey, public_key: PubKey, receiving_gap: int = 20, change_gap: int = 6,
receiving_maximum_uses_per_address: int = 2, change_maximum_uses_per_address: int = 2
) -> None:
self.ledger = ledger
self.name = name
self.seed = seed
@ -150,25 +150,26 @@ class BaseAccount(object):
self.private_key = private_key
self.public_key = public_key
if is_hd:
receiving, change = self.keychains = (
KeyChain(self, public_key, 0, receiving_gap, receiving_maximum_uses_per_address),
KeyChain(self, public_key, 1, change_gap, change_maximum_uses_per_address)
self.receiving: KeyManager = KeyChain(
self, public_key, 0, receiving_gap, receiving_maximum_uses_per_address
)
self.change: KeyManager = KeyChain(
self, public_key, 1, change_gap, change_maximum_uses_per_address
)
self.keychains: Sequence[KeyManager] = (self.receiving, self.change)
else:
self.keychains = SingleKey(self, public_key, 0),
receiving = change = self.keychains[0]
self.receiving = receiving # type: KeyManager
self.change = change # type: KeyManager
self.change = self.receiving = SingleKey(self, public_key, 0)
self.keychains = (self.receiving,)
ledger.add_account(self)
@classmethod
def generate(cls, ledger, password, **kwargs): # type: (torba.baseledger.BaseLedger, str) -> BaseAccount
def generate(cls, ledger: 'baseledger.BaseLedger', password: str, **kwargs):
seed = cls.mnemonic_class().make_seed()
return cls.from_seed(ledger, seed, password, **kwargs)
@classmethod
def from_seed(cls, ledger, seed, password, is_hd=True, **kwargs):
# type: (torba.baseledger.BaseLedger, str, str) -> BaseAccount
def from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str,
is_hd: bool = True, **kwargs):
private_key = cls.get_private_key_from_seed(ledger, seed, password)
return cls(
ledger=ledger, name='Account #{}'.format(private_key.public_key.address),
@ -179,14 +180,13 @@ class BaseAccount(object):
)
@classmethod
def get_private_key_from_seed(cls, ledger, seed, password):
# type: (torba.baseledger.BaseLedger, str, str) -> PrivateKey
def get_private_key_from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str):
return cls.private_key_class.from_seed(
ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password)
)
@classmethod
def from_dict(cls, ledger, d): # type: (torba.baseledger.BaseLedger, Dict) -> BaseAccount
def from_dict(cls, ledger: 'baseledger.BaseLedger', d: dict):
if not d['encrypted'] and d['private_key']:
private_key = from_extended_key_string(ledger, d['private_key'])
public_key = private_key.public_key
@ -264,21 +264,20 @@ class BaseAccount(object):
defer.returnValue(addresses)
@defer.inlineCallbacks
def get_addresses(self, limit=None, max_used_times=None): # type: (int, int) -> defer.Deferred
def get_addresses(self, limit: int = None, max_used_times: int = None) -> defer.Deferred:
records = yield self.get_address_records(limit, max_used_times)
defer.returnValue([r['address'] for r in records])
def get_address_records(self, limit=None, max_used_times=None): # type: (int, int) -> defer.Deferred
def get_address_records(self, limit: int = None, max_used_times: int = None) -> defer.Deferred:
return self.ledger.db.get_addresses(self, None, limit, max_used_times)
def get_private_key(self, chain, index):
def get_private_key(self, chain: int, index: int) -> PrivateKey:
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
if isinstance(self.receiving, SingleKey):
return self.private_key
else:
return self.private_key.child(chain).child(index)
return self.private_key.child(chain).child(index)
def get_balance(self, confirmations=6, **constraints):
def get_balance(self, confirmations: int = 6, **constraints):
if confirmations > 0:
height = self.ledger.headers.height - (confirmations-1)
constraints.update({'height__lte': height, 'height__gt': 0})

View file

@ -1,5 +1,5 @@
import logging
from typing import List, Union
from typing import Tuple, List, Sequence
from operator import itemgetter
import sqlite3
@ -11,13 +11,13 @@ from torba.hash import TXRefImmutable
log = logging.getLogger(__name__)
class SQLiteMixin(object):
class SQLiteMixin:
CREATE_TABLES_QUERY = None
CREATE_TABLES_QUERY: Sequence[str] = ()
def __init__(self, path):
self._db_path = path
self.db = None
self.db: adbapi.ConnectionPool = None
def start(self):
log.info("connecting to database: %s", self._db_path)
@ -32,8 +32,8 @@ class SQLiteMixin(object):
self.db.close()
return defer.succeed(True)
def _insert_sql(self, table, data):
# type: (str, dict) -> tuple[str, List]
@staticmethod
def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
columns, values = [], []
for column, value in data.items():
columns.append(column)
@ -43,8 +43,8 @@ class SQLiteMixin(object):
)
return sql, values
def _update_sql(self, table, data, where, constraints):
# type: (str, dict) -> tuple[str, List]
@staticmethod
def _update_sql(table: str, data: dict, where: str, constraints: list) -> Tuple[str, list]:
columns, values = [], []
for column, value in data.items():
columns.append("{} = ?".format(column))
@ -146,7 +146,8 @@ class BaseDatabase(SQLiteMixin):
CREATE_TXI_TABLE
)
def txo_to_row(self, tx, address, txo):
@staticmethod
def txo_to_row(tx, address, txo):
return {
'txid': tx.id,
'txoid': txo.id,
@ -156,7 +157,7 @@ class BaseDatabase(SQLiteMixin):
'script': sqlite3.Binary(txo.script.source)
}
def save_transaction_io(self, save_tx, tx, height, is_verified, address, hash, history):
def save_transaction_io(self, save_tx, tx, height, is_verified, address, txhash, history):
def _steps(t):
if save_tx == 'insert':
@ -168,9 +169,8 @@ class BaseDatabase(SQLiteMixin):
}))
elif save_tx == 'update':
self.execute(t, *self._update_sql("tx", {
'height': height, 'is_verified': is_verified
}, 'txid = ?', (tx.id,)
))
'height': height, 'is_verified': is_verified
}, 'txid = ?', (tx.id,)))
existing_txos = list(map(itemgetter(0), self.execute(
t, "SELECT position FROM txo WHERE txid = ?", (tx.id,)
@ -179,7 +179,7 @@ class BaseDatabase(SQLiteMixin):
for txo in tx.outputs:
if txo.position in existing_txos:
continue
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == hash:
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments
@ -202,15 +202,16 @@ class BaseDatabase(SQLiteMixin):
return self.db.runInteraction(_steps)
def reserve_spent_outputs(self, txoids, is_reserved=True):
def reserve_outputs(self, txos, is_reserved=True):
txoids = [txo.id for txo in txos]
return self.run_operation(
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
', '.join(['?']*len(txoids))
), [is_reserved]+txoids
)
def release_reserved_outputs(self, txoids):
return self.reserve_spent_outputs(txoids, is_reserved=False)
def release_outputs(self, txos):
return self.reserve_outputs(txos, is_reserved=False)
@defer.inlineCallbacks
def get_transaction(self, txid):
@ -226,7 +227,7 @@ class BaseDatabase(SQLiteMixin):
extra_sql = ""
if constraints:
extras = []
for key in constraints.keys():
for key in constraints:
col, op = key, '='
if key.endswith('__not'):
col, op = key[:-len('__not')], '!='
@ -257,7 +258,7 @@ class BaseDatabase(SQLiteMixin):
extra_sql = ""
if constraints:
extra_sql = ' AND ' + ' AND '.join(
'{} = :{}'.format(c, c) for c in constraints.keys()
'{} = :{}'.format(c, c) for c in constraints
)
values = {'account': account.public_key.address}
values.update(constraints)

View file

@ -1,13 +1,16 @@
import os
import struct
import logging
import typing
from binascii import unhexlify
from twisted.internet import threads, defer
from torba.stream import StreamController, execute_serially
from torba.stream import StreamController
from torba.util import int_to_hex, rev_hex, hash_encode
from torba.hash import double_sha256, pow_hash
if typing.TYPE_CHECKING:
from torba import baseledger
log = logging.getLogger(__name__)
@ -17,7 +20,7 @@ class BaseHeaders:
header_size = 80
verify_bits_to_target = True
def __init__(self, ledger): # type: (baseledger.BaseLedger) -> BaseHeaders
def __init__(self, ledger: 'baseledger.BaseLedger') -> None:
self.ledger = ledger
self._size = None
self._on_change_controller = StreamController()
@ -62,7 +65,6 @@ class BaseHeaders:
header = self.sync_read_header(height)
return self._deserialize(height, header)
@execute_serially
@defer.inlineCallbacks
def connect(self, start, headers):
yield threads.deferToThread(self._sync_connect, start, headers)
@ -84,8 +86,9 @@ class BaseHeaders:
_old_size = self._size
self._size = self.sync_read_length()
change = self._size - _old_size
log.info('{}: added {} header blocks, final height {}'.format(
self.ledger.get_id(), change, self.height)
log.info(
'%s: added %s header blocks, final height %s',
self.ledger.get_id(), change, self.height
)
self._on_change_controller.add(change)
@ -101,7 +104,7 @@ class BaseHeaders:
assert previous_hash == header['prev_block_hash'], \
"prev hash mismatch: {} vs {}".format(previous_hash, header['prev_block_hash'])
bits, target = self._calculate_next_work_required(height, previous_header, header)
bits, _ = self._calculate_next_work_required(height, previous_header, header)
assert bits == header['bits'], \
"bits mismatch: {} vs {} (hash: {})".format(
bits, header['bits'], self._hash_header(header))
@ -154,37 +157,37 @@ class BaseHeaders:
if self.verify_bits_to_target:
bits = last['bits']
bitsN = (bits >> 24) & 0xff
assert 0x03 <= bitsN <= 0x1d, \
"First part of bits should be in [0x03, 0x1d], but it was {}".format(hex(bitsN))
bitsBase = bits & 0xffffff
assert 0x8000 <= bitsBase <= 0x7fffff, \
"Second part of bits should be in [0x8000, 0x7fffff] but it was {}".format(bitsBase)
bits_n = (bits >> 24) & 0xff
assert 0x03 <= bits_n <= 0x1d, \
"First part of bits should be in [0x03, 0x1d], but it was {}".format(hex(bits_n))
bits_base = bits & 0xffffff
assert 0x8000 <= bits_base <= 0x7fffff, \
"Second part of bits should be in [0x8000, 0x7fffff] but it was {}".format(bits_base)
# new target
retargetTimespan = self.ledger.target_timespan
nActualTimespan = last['timestamp'] - first['timestamp']
retarget_timespan = self.ledger.target_timespan
n_actual_timespan = last['timestamp'] - first['timestamp']
nModulatedTimespan = retargetTimespan + (nActualTimespan - retargetTimespan) // 8
n_modulated_timespan = retarget_timespan + (n_actual_timespan - retarget_timespan) // 8
nMinTimespan = retargetTimespan - (retargetTimespan // 8)
nMaxTimespan = retargetTimespan + (retargetTimespan // 2)
n_min_timespan = retarget_timespan - (retarget_timespan // 8)
n_max_timespan = retarget_timespan + (retarget_timespan // 2)
# Limit adjustment step
if nModulatedTimespan < nMinTimespan:
nModulatedTimespan = nMinTimespan
elif nModulatedTimespan > nMaxTimespan:
nModulatedTimespan = nMaxTimespan
if n_modulated_timespan < n_min_timespan:
n_modulated_timespan = n_min_timespan
elif n_modulated_timespan > n_max_timespan:
n_modulated_timespan = n_max_timespan
# Retarget
bnPowLimit = _ArithUint256(self.ledger.max_target)
bnNew = _ArithUint256.SetCompact(last['bits'])
bnNew *= nModulatedTimespan
bnNew //= nModulatedTimespan
if bnNew > bnPowLimit:
bnNew = bnPowLimit
bn_pow_limit = _ArithUint256(self.ledger.max_target)
bn_new = _ArithUint256.set_compact(last['bits'])
bn_new *= n_modulated_timespan
bn_new //= n_modulated_timespan
if bn_new > bn_pow_limit:
bn_new = bn_pow_limit
return bnNew.GetCompact(), bnNew._value
return bn_new.get_compact(), bn_new._value
class _ArithUint256:
@ -197,49 +200,48 @@ class _ArithUint256:
return hex(self._value)
@staticmethod
def fromCompact(nCompact):
def from_compact(n_compact):
"""Convert a compact representation into its value"""
nSize = nCompact >> 24
n_size = n_compact >> 24
# the lower 23 bits
nWord = nCompact & 0x007fffff
if nSize <= 3:
return nWord >> 8 * (3 - nSize)
n_word = n_compact & 0x007fffff
if n_size <= 3:
return n_word >> 8 * (3 - n_size)
else:
return nWord << 8 * (nSize - 3)
return n_word << 8 * (n_size - 3)
@classmethod
def SetCompact(cls, nCompact):
return cls(cls.fromCompact(nCompact))
def set_compact(cls, n_compact):
return cls(cls.from_compact(n_compact))
def bits(self):
"""Returns the position of the highest bit set plus one."""
bn = bin(self._value)[2:]
for i, d in enumerate(bn):
bits = bin(self._value)[2:]
for i, d in enumerate(bits):
if d:
return (len(bn) - i) + 1
return (len(bits) - i) + 1
return 0
def GetLow64(self):
def get_low64(self):
return self._value & 0xffffffffffffffff
def GetCompact(self):
def get_compact(self):
"""Convert a value into its compact representation"""
nSize = (self.bits() + 7) // 8
nCompact = 0
if nSize <= 3:
nCompact = self.GetLow64() << 8 * (3 - nSize)
n_size = (self.bits() + 7) // 8
if n_size <= 3:
n_compact = self.get_low64() << 8 * (3 - n_size)
else:
bn = _ArithUint256(self._value >> 8 * (nSize - 3))
nCompact = bn.GetLow64()
n = _ArithUint256(self._value >> 8 * (n_size - 3))
n_compact = n.get_low64()
# The 0x00800000 bit denotes the sign.
# Thus, if it is already set, divide the mantissa by 256 and increase the exponent.
if nCompact & 0x00800000:
nCompact >>= 8
nSize += 1
assert (nCompact & ~0x007fffff) == 0
assert nSize < 256
nCompact |= nSize << 24
return nCompact
if n_compact & 0x00800000:
n_compact >>= 8
n_size += 1
assert (n_compact & ~0x007fffff) == 0
assert n_size < 256
n_compact |= n_size << 24
return n_compact
def __mul__(self, x):
# Take the mod because we are limited to an unsigned 256 bit number

View file

@ -1,6 +1,4 @@
import os
import six
import hashlib
import logging
from binascii import hexlify, unhexlify
from typing import Dict, Type, Iterable
@ -14,17 +12,22 @@ from torba import basedatabase
from torba import baseheader
from torba import basenetwork
from torba import basetransaction
from torba.stream import StreamController, execute_serially
from torba.hash import hash160, double_sha256, Base58
from torba.coinselection import CoinSelector
from torba.constants import COIN, NULL_HASH32
from torba.stream import StreamController
from torba.hash import hash160, double_sha256, sha256, Base58
log = logging.getLogger(__name__)
LedgerType = Type['BaseLedger']
class LedgerRegistry(type):
ledgers = {} # type: Dict[str, Type[BaseLedger]]
ledgers: Dict[str, LedgerType] = {}
def __new__(mcs, name, bases, attrs):
cls = super(LedgerRegistry, mcs).__new__(mcs, name, bases, attrs) # type: Type[BaseLedger]
cls: LedgerType = super().__new__(mcs, name, bases, attrs)
if not (name == 'BaseLedger' and not bases):
ledger_id = cls.get_id()
assert ledger_id not in mcs.ledgers,\
@ -33,7 +36,7 @@ class LedgerRegistry(type):
return cls
@classmethod
def get_ledger_class(mcs, ledger_id): # type: (str) -> Type[BaseLedger]
def get_ledger_class(mcs, ledger_id: str) -> LedgerType:
return mcs.ledgers[ledger_id]
@ -41,11 +44,11 @@ class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'height'
pass
class BaseLedger(six.with_metaclass(LedgerRegistry)):
class BaseLedger(metaclass=LedgerRegistry):
name = None
symbol = None
network_name = None
name: str
symbol: str
network_name: str
account_class = baseaccount.BaseAccount
database_class = basedatabase.BaseDatabase
@ -54,10 +57,10 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
transaction_class = basetransaction.BaseTransaction
secret_prefix = None
pubkey_address_prefix = None
script_address_prefix = None
extended_public_key_prefix = None
extended_private_key_prefix = None
pubkey_address_prefix: bytes
script_address_prefix: bytes
extended_public_key_prefix: bytes
extended_private_key_prefix: bytes
default_fee_per_byte = 10
@ -71,13 +74,14 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
self.network.on_status.listen(self.process_status)
self.accounts = []
self.headers = self.config.get('headers') or self.headers_class(self)
self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte)
self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
self._on_transaction_controller = StreamController()
self.on_transaction = self._on_transaction_controller.stream
self.on_transaction.listen(
lambda e: log.info('({}) on_transaction: address={}, height={}, is_verified={}, tx.id={}'.format(
self.get_id(), e.address, e.height, e.is_verified, e.tx.id)
lambda e: log.info(
'(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
self.get_id(), e.address, e.height, e.is_verified, e.tx.id
)
)
@ -85,6 +89,8 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
self.on_header = self._on_header_controller.stream
self._transaction_processing_locks = {}
self._utxo_reservation_lock = defer.DeferredLock()
self._header_processing_lock = defer.DeferredLock()
@classmethod
def get_id(cls):
@ -97,9 +103,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@staticmethod
def address_to_hash160(address):
bytes = Base58.decode(address)
prefix, pubkey_bytes, addr_checksum = bytes[0], bytes[1:21], bytes[21:]
return pubkey_bytes
return Base58.decode(address)[1:21]
@classmethod
def public_key_to_address(cls, public_key):
@ -113,7 +117,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
def path(self):
return os.path.join(self.config['data_path'], self.get_id())
def get_input_output_fee(self, io):
def get_input_output_fee(self, io: basetransaction.InputOutput) -> int:
""" Fee based on size of the input / output. """
return self.fee_per_byte * io.size
@ -122,14 +126,14 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
return self.fee_per_byte * tx.base_size
@defer.inlineCallbacks
def add_account(self, account): # type: (baseaccount.BaseAccount) -> None
def add_account(self, account: baseaccount.BaseAccount) -> defer.Deferred:
self.accounts.append(account)
if self.network.is_connected:
yield self.update_account(account)
@defer.inlineCallbacks
def get_transaction(self, txhash):
raw, height, is_verified = yield self.db.get_transaction(txhash)
raw, _, _ = yield self.db.get_transaction(txhash)
if raw is not None:
defer.returnValue(self.transaction_class(raw))
@ -142,8 +146,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
defer.returnValue(account.get_private_key(match['chain'], match['position']))
@defer.inlineCallbacks
def get_effective_amount_estimators(self, funding_accounts):
# type: (Iterable[baseaccount.BaseAccount]) -> defer.Deferred
def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
estimators = []
for account in funding_accounts:
utxos = yield account.get_unspent_outputs()
@ -151,12 +154,39 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
estimators.append(utxo.get_estimator(self))
defer.returnValue(estimators)
@defer.inlineCallbacks
def get_spendable_utxos(self, amount: int, funding_accounts):
yield self._utxo_reservation_lock.acquire()
try:
txos = yield self.get_effective_amount_estimators(funding_accounts)
selector = CoinSelector(
txos, amount,
self.get_input_output_fee(
self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32)
)
)
spendables = selector.select()
if spendables:
yield self.reserve_outputs(s.txo for s in spendables)
except Exception:
log.exception('Failed to get spendable utxos:')
raise
finally:
self._utxo_reservation_lock.release()
defer.returnValue(spendables)
def reserve_outputs(self, txos):
return self.db.reserve_outputs(txos)
def release_outputs(self, txos):
return self.db.release_outputs(txos)
@defer.inlineCallbacks
def get_local_status(self, address):
address_details = yield self.db.get_address(address)
history = address_details['history'] or ''
hash = hashlib.sha256(history.encode()).digest()
defer.returnValue(hexlify(hash))
h = sha256(history.encode())
defer.returnValue(hexlify(h))
@defer.inlineCallbacks
def get_local_history(self, address):
@ -203,7 +233,6 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
yield self.network.stop()
yield self.db.stop()
@execute_serially
@defer.inlineCallbacks
def update_headers(self):
while True:
@ -216,18 +245,19 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@defer.inlineCallbacks
def process_header(self, response):
header = response[0]
if self.update_headers.is_running:
return
if header['height'] == len(self.headers):
# New header from network directly connects after the last local header.
yield self.headers.connect(len(self.headers), unhexlify(header['hex']))
self._on_header_controller.add(self.headers.height)
elif header['height'] > len(self.headers):
# New header is several heights ahead of local, do download instead.
yield self.update_headers()
yield self._header_processing_lock.acquire()
try:
header = response[0]
if header['height'] == len(self.headers):
# New header from network directly connects after the last local header.
yield self.headers.connect(len(self.headers), unhexlify(header['hex']))
self._on_header_controller.add(self.headers.height)
elif header['height'] > len(self.headers):
# New header is several heights ahead of local, do download instead.
yield self.update_headers()
finally:
self._header_processing_lock.release()
@execute_serially
def update_accounts(self):
return defer.DeferredList([
self.update_account(a) for a in self.accounts
@ -274,7 +304,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
try:
# see if we have a local copy of transaction, otherwise fetch it from server
raw, local_height, is_verified = yield self.db.get_transaction(hex_id)
raw, _, is_verified = yield self.db.get_transaction(hex_id)
save_tx = None
if raw is None:
_raw = yield self.network.get_transaction(hex_id)
@ -294,15 +324,16 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history)
)
log.debug("{}: sync'ed tx {} for address: {}, height: {}, verified: {}".format(
log.debug(
"%s: sync'ed tx %s for address: %s, height: %s, verified: %s",
self.get_id(), hex_id, address, remote_height, is_verified
))
)
self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified))
except Exception as e:
except Exception:
log.exception('Failed to synchronize transaction:')
raise e
raise
finally:
lock.release()

View file

@ -1,4 +1,4 @@
from typing import List, Dict, Type
from typing import Type, MutableSequence, MutableMapping
from twisted.internet import defer
from torba.baseledger import BaseLedger, LedgerRegistry
@ -6,16 +6,16 @@ from torba.wallet import Wallet, WalletStorage
from torba.constants import COIN
class WalletManager(object):
class BaseWalletManager:
def __init__(self, wallets=None, ledgers=None):
# type: (List[Wallet], Dict[Type[BaseLedger],BaseLedger]) -> None
def __init__(self, wallets: MutableSequence[Wallet] = None,
ledgers: MutableMapping[Type[BaseLedger], BaseLedger] = None) -> None:
self.wallets = wallets or []
self.ledgers = ledgers or {}
self.running = False
@classmethod
def from_config(cls, config): # type: (Dict) -> WalletManager
def from_config(cls, config: dict) -> 'BaseWalletManager':
manager = cls()
for ledger_id, ledger_config in config.get('ledgers', {}).items():
manager.get_or_create_ledger(ledger_id, ledger_config)

View file

@ -21,6 +21,7 @@ class StratumClientProtocol(LineOnlyReceiver):
self.request_id = 0
self.lookup_table = {}
self.session = {}
self.network = None
self.on_disconnected_controller = StreamController()
self.on_disconnected = self.on_disconnected_controller.stream
@ -52,7 +53,7 @@ class StratumClientProtocol(LineOnlyReceiver):
socket.SOL_TCP, socket.TCP_KEEPCNT, 5
# Failed keepalive probles before declaring other end dead
)
except Exception as err:
except Exception as err: # pylint: disable=broad-except
# Supported only by the socket transport,
# but there's really no better place in code to trigger this.
log.warning("Error setting up socket: %s", err)
@ -61,7 +62,7 @@ class StratumClientProtocol(LineOnlyReceiver):
self.on_disconnected_controller.add(True)
def lineReceived(self, line):
log.debug('received: {}'.format(line))
log.debug('received: %s', line)
try:
message = json.loads(line)
@ -82,7 +83,7 @@ class StratumClientProtocol(LineOnlyReceiver):
controller = self.network.subscription_controllers[message['method']]
controller.add(message.get('params'))
else:
log.warning("Cannot handle message '%s'" % line)
log.warning("Cannot handle message '%s'", line)
def rpc(self, method, *args):
message_id = self._get_id()
@ -91,7 +92,7 @@ class StratumClientProtocol(LineOnlyReceiver):
'method': method,
'params': args
})
log.debug('sent: {}'.format(message))
log.debug('sent: %s', message)
self.sendLine(message.encode('latin-1'))
d = self.lookup_table[message_id] = defer.Deferred()
return d
@ -138,20 +139,21 @@ class BaseNetwork:
@defer.inlineCallbacks
def start(self):
for server in cycle(self.config['default_servers']):
endpoint = clientFromString(reactor, 'tcp:{}:{}'.format(*server))
log.debug("Attempting connection to SPV wallet server: {}:{}".format(*server))
connection_string = 'tcp:{}:{}'.format(*server)
endpoint = clientFromString(reactor, connection_string)
log.debug("Attempting connection to SPV wallet server: %s", connection_string)
self.service = ClientService(endpoint, StratumClientFactory(self))
self.service.startService()
try:
self.client = yield self.service.whenConnected(failAfterFailures=2)
yield self.ensure_server_version()
log.info("Successfully connected to SPV wallet server: {}:{}".format(*server))
log.info("Successfully connected to SPV wallet server: %s", connection_string)
self._on_connected_controller.add(True)
yield self.client.on_disconnected.first
except CancelledError:
return
except Exception:
log.exception("Connecting to {}:{} raised an exception:".format(*server))
except Exception: # pylint: disable=broad-except
log.exception("Connecting to %s raised an exception:", connection_string)
finally:
self.client = None
if not self.running:

View file

@ -1,6 +1,7 @@
from itertools import chain
from binascii import hexlify
from collections import namedtuple
from typing import List
from torba.bcd_data_stream import BCDataStream
from torba.util import subclass_tuple
@ -25,17 +26,21 @@ OP_DROP = 0x75
# template matching opcodes (not real opcodes)
# base class for PUSH_DATA related opcodes
# pylint: disable=invalid-name
PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name')
# opcode for variable length strings
# pylint: disable=invalid-name
PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP)
# opcode for variable number of variable length strings
# pylint: disable=invalid-name
PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP)
# opcode with embedded subscript parsing
# pylint: disable=invalid-name
PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template')
def is_push_data_opcode(opcode):
return isinstance(opcode, PUSH_DATA_OP) or isinstance(opcode, PUSH_SUBSCRIPT)
return isinstance(opcode, (PUSH_DATA_OP, PUSH_SUBSCRIPT))
def is_push_data_token(token):
@ -61,15 +66,15 @@ def push_data(data):
def read_data(token, stream):
if token < OP_PUSHDATA1:
return stream.read(token)
elif token == OP_PUSHDATA1:
if token == OP_PUSHDATA1:
return stream.read(stream.read_uint8())
elif token == OP_PUSHDATA2:
if token == OP_PUSHDATA2:
return stream.read(stream.read_uint16())
else:
return stream.read(stream.read_uint32())
return stream.read(stream.read_uint32())
# opcode for OP_1 - OP_16
# pylint: disable=invalid-name
SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name')
@ -233,7 +238,7 @@ class Parser:
raise ParseError("Not a push single or subscript: {}".format(opcode))
class Template(object):
class Template:
__slots__ = 'name', 'opcodes'
@ -264,11 +269,11 @@ class Template(object):
return source.get_bytes()
class Script(object):
class Script:
__slots__ = 'source', 'template', 'values'
templates = []
templates: List[Template] = []
def __init__(self, source=None, template=None, values=None, template_hint=None):
self.source = source

View file

@ -1,29 +1,30 @@
import six
import logging
from typing import List, Iterable
import typing
from typing import List, Iterable, Optional
from binascii import hexlify
from twisted.internet import defer
import torba.baseaccount
import torba.baseledger
from torba.basescript import BaseInputScript, BaseOutputScript
from torba.coinselection import CoinSelector
from torba.baseaccount import BaseAccount
from torba.constants import COIN, NULL_HASH32
from torba.bcd_data_stream import BCDataStream
from torba.hash import sha256, TXRef, TXRefImmutable, TXORef
from torba.hash import sha256, TXRef, TXRefImmutable
from torba.util import ReadOnlyList
if typing.TYPE_CHECKING:
from torba import baseledger
log = logging.getLogger()
class TXRefMutable(TXRef):
__slots__ = 'tx',
__slots__ = ('tx',)
def __init__(self, tx):
super(TXRefMutable, self).__init__()
def __init__(self, tx: 'BaseTransaction') -> None:
super().__init__()
self.tx = tx
@property
@ -43,12 +44,35 @@ class TXRefMutable(TXRef):
self._hash = None
class TXORef:
__slots__ = 'tx_ref', 'position'
def __init__(self, tx_ref: TXRef, position: int) -> None:
self.tx_ref = tx_ref
self.position = position
@property
def id(self):
return '{}:{}'.format(self.tx_ref.id, self.position)
@property
def is_null(self):
return self.tx_ref.is_null
@property
def txo(self) -> Optional['BaseOutput']:
return None
class TXORefResolvable(TXORef):
__slots__ = '_txo',
__slots__ = ('_txo',)
def __init__(self, txo):
super(TXORefResolvable, self).__init__(txo.tx_ref, txo.position)
def __init__(self, txo: 'BaseOutput') -> None:
assert txo.tx_ref is not None
assert txo.position is not None
super().__init__(txo.tx_ref, txo.position)
self._txo = txo
@property
@ -56,23 +80,23 @@ class TXORefResolvable(TXORef):
return self._txo
class InputOutput(object):
class InputOutput:
__slots__ = 'tx_ref', 'position'
def __init__(self, tx_ref=None, position=None):
self.tx_ref = tx_ref # type: TXRef
self.position = position # type: int
def __init__(self, tx_ref: TXRef = None, position: int = None) -> None:
self.tx_ref = tx_ref
self.position = position
@property
def size(self):
def size(self) -> int:
""" Size of this input / output in bytes. """
stream = BCDataStream()
self.serialize_to(stream)
return len(stream.get_bytes())
def serialize_to(self, stream):
raise NotImplemented
def serialize_to(self, stream, alternate_script=None):
raise NotImplementedError
class BaseInput(InputOutput):
@ -84,27 +108,27 @@ class BaseInput(InputOutput):
__slots__ = 'txo_ref', 'sequence', 'coinbase', 'script'
def __init__(self, txo_ref, script, sequence=0xFFFFFFFF, tx_ref=None, position=None):
# type: (TXORef, BaseInputScript, int, TXRef, int) -> None
super(BaseInput, self).__init__(tx_ref, position)
def __init__(self, txo_ref: TXORef, script: BaseInputScript, sequence: int = 0xFFFFFFFF,
tx_ref: TXRef = None, position: int = None) -> None:
super().__init__(tx_ref, position)
self.txo_ref = txo_ref
self.sequence = sequence
self.coinbase = script if txo_ref.is_null else None
self.script = script if not txo_ref.is_null else None # type: BaseInputScript
self.script = script if not txo_ref.is_null else None
@property
def is_coinbase(self):
return self.coinbase is not None
@classmethod
def spend(cls, txo): # type: (BaseOutput) -> BaseInput
def spend(cls, txo: 'BaseOutput') -> 'BaseInput':
""" Create an input to spend the output."""
assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.'
script = cls.script_class.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY)
return cls(txo.ref, script)
@property
def amount(self):
def amount(self) -> int:
""" Amount this input adds to the transaction. """
if self.txo_ref.txo is None:
raise ValueError('Cannot resolve output to get amount.')
@ -135,15 +159,15 @@ class BaseInput(InputOutput):
stream.write_uint32(self.sequence)
class BaseOutputEffectiveAmountEstimator(object):
class BaseOutputEffectiveAmountEstimator:
__slots__ = 'txo', 'txi', 'fee', 'effective_amount'
def __init__(self, ledger, txo): # type: (torba.baseledger.BaseLedger, BaseOutput) -> None
def __init__(self, ledger: 'baseledger.BaseLedger', txo: 'BaseOutput') -> None:
self.txo = txo
self.txi = ledger.transaction_class.input_class.spend(txo)
self.fee = ledger.get_input_output_fee(self.txi)
self.effective_amount = txo.amount - self.fee
self.fee: int = ledger.get_input_output_fee(self.txi)
self.effective_amount: int = txo.amount - self.fee
def __lt__(self, other):
return self.effective_amount < other.effective_amount
@ -156,9 +180,9 @@ class BaseOutput(InputOutput):
__slots__ = 'amount', 'script'
def __init__(self, amount, script, tx_ref=None, position=None):
# type: (int, BaseOutputScript, TXRef, int) -> None
super(BaseOutput, self).__init__(tx_ref, position)
def __init__(self, amount: int, script: BaseOutputScript,
tx_ref: TXRef = None, position: int = None) -> None:
super().__init__(tx_ref, position)
self.amount = amount
self.script = script
@ -184,7 +208,7 @@ class BaseOutput(InputOutput):
script=cls.script_class(stream.read_string())
)
def serialize_to(self, stream):
def serialize_to(self, stream, alternate_script=None):
stream.write_uint64(self.amount)
stream.write_string(self.script.source)
@ -194,7 +218,7 @@ class BaseTransaction:
input_class = BaseInput
output_class = BaseOutput
def __init__(self, raw=None, version=1, locktime=0):
def __init__(self, raw=None, version=1, locktime=0) -> None:
self._raw = raw
self.ref = TXRefMutable(self)
self.version = version # type: int
@ -230,8 +254,7 @@ class BaseTransaction:
def outputs(self): # type: () -> ReadOnlyList[BaseOutput]
return ReadOnlyList(self._outputs)
def _add(self, new_ios, existing_ios):
# type: (List[InputOutput], List[InputOutput]) -> BaseTransaction
def _add(self, new_ios: Iterable[InputOutput], existing_ios: List) -> 'BaseTransaction':
for txio in new_ios:
txio.tx_ref = self.ref
txio.position = len(existing_ios)
@ -239,28 +262,28 @@ class BaseTransaction:
self._reset()
return self
def add_inputs(self, inputs): # type: (List[BaseInput]) -> BaseTransaction
def add_inputs(self, inputs: Iterable[BaseInput]) -> 'BaseTransaction':
return self._add(inputs, self._inputs)
def add_outputs(self, outputs): # type: (List[BaseOutput]) -> BaseTransaction
def add_outputs(self, outputs: Iterable[BaseOutput]) -> 'BaseTransaction':
return self._add(outputs, self._outputs)
@property
def fee(self): # type: () -> int
def fee(self) -> int:
""" Fee that will actually be paid."""
return self.input_sum - self.output_sum
@property
def size(self): # type: () -> int
def size(self) -> int:
""" Size in bytes of the entire transaction. """
return len(self.raw)
@property
def base_size(self): # type: () -> int
def base_size(self) -> int:
""" Size in bytes of transaction meta data and all outputs; without inputs. """
return len(self._serialize(with_inputs=False))
def _serialize(self, with_inputs=True): # type: (bool) -> bytes
def _serialize(self, with_inputs: bool = True) -> bytes:
stream = BCDataStream()
stream.write_uint32(self.version)
if with_inputs:
@ -273,12 +296,13 @@ class BaseTransaction:
stream.write_uint32(self.locktime)
return stream.get_bytes()
def _serialize_for_signature(self, signing_input): # type: (int) -> bytes
def _serialize_for_signature(self, signing_input: int) -> bytes:
stream = BCDataStream()
stream.write_uint32(self.version)
stream.write_compact_size(len(self._inputs))
for i, txin in enumerate(self._inputs):
if signing_input == i:
assert txin.txo_ref.txo is not None
txin.serialize_to(stream, txin.txo_ref.txo.script.source)
else:
txin.serialize_to(stream, b'')
@ -304,8 +328,9 @@ class BaseTransaction:
self.locktime = stream.read_uint32()
@classmethod
def ensure_all_have_same_ledger(cls, funding_accounts, change_account=None):
# type: (Iterable[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> torba.baseledger.BaseLedger
def ensure_all_have_same_ledger(
cls, funding_accounts: Iterable[BaseAccount], change_account: BaseAccount = None)\
-> 'baseledger.BaseLedger':
ledger = None
for account in funding_accounts:
if ledger is None:
@ -316,33 +341,24 @@ class BaseTransaction:
)
if change_account is not None and change_account.ledger != ledger:
raise ValueError('Change account must use same ledger as funding accounts.')
if ledger is None:
raise ValueError('No ledger found.')
return ledger
@classmethod
@defer.inlineCallbacks
def pay(cls, outputs, funding_accounts, change_account, reserve_outputs=True):
# type: (List[BaseOutput], List[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> defer.Deferred
def pay(cls, outputs: Iterable[BaseOutput], funding_accounts: Iterable[BaseAccount],
change_account: BaseAccount):
""" Efficiently spend utxos from funding_accounts to cover the new outputs. """
tx = cls().add_outputs(outputs)
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
amount = tx.output_sum + ledger.get_transaction_base_fee(tx)
txos = yield ledger.get_effective_amount_estimators(funding_accounts)
selector = CoinSelector(
txos, amount,
ledger.get_input_output_fee(
cls.output_class.pay_pubkey_hash(COIN, NULL_HASH32)
)
)
spendables = yield ledger.get_spendable_utxos(amount, funding_accounts)
spendables = selector.select()
if not spendables:
raise ValueError('Not enough funds to cover this transaction.')
reserved_outputs = [s.txo.id for s in spendables]
if reserve_outputs:
yield ledger.db.reserve_spent_outputs(reserved_outputs)
try:
spent_sum = sum(s.effective_amount for s in spendables)
if spent_sum > amount:
@ -351,30 +367,25 @@ class BaseTransaction:
change_amount = spent_sum - amount
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
tx.add_inputs([s.txi for s in spendables])
tx.add_inputs(s.txi for s in spendables)
yield tx.sign(funding_accounts)
except Exception:
if reserve_outputs:
yield ledger.db.release_reserved_outputs(reserved_outputs)
raise
except Exception as e:
log.exception('Failed to synchronize transaction:')
yield ledger.release_outputs(s.txo for s in spendables)
raise e
defer.returnValue(tx)
@classmethod
@defer.inlineCallbacks
def liquidate(cls, assets, funding_accounts, change_account, reserve_outputs=True):
def liquidate(cls, assets, funding_accounts, change_account):
""" Spend assets (utxos) supplementing with funding_accounts if fee is higher than asset value. """
tx = cls().add_inputs([
cls.input_class.spend(utxo) for utxo in assets
])
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
reserved_outputs = [utxo.id for utxo in assets]
if reserve_outputs:
yield ledger.db.reserve_spent_outputs(reserved_outputs)
yield ledger.reserve_outputs(assets)
try:
cost_of_change = (
ledger.get_transaction_base_fee(tx) +
@ -386,41 +397,35 @@ class BaseTransaction:
change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = liquidated_total - cost_of_change
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
yield tx.sign(funding_accounts)
except Exception:
if reserve_outputs:
yield ledger.db.release_reserved_outputs(reserved_outputs)
yield ledger.release_outputs(assets)
raise
defer.returnValue(tx)
def signature_hash_type(self, hash_type):
@staticmethod
def signature_hash_type(hash_type):
return hash_type
@defer.inlineCallbacks
def sign(self, funding_accounts): # type: (Iterable[torba.baseaccount.BaseAccount]) -> BaseTransaction
def sign(self, funding_accounts: Iterable[BaseAccount]) -> defer.Deferred:
ledger = self.ensure_all_have_same_ledger(funding_accounts)
for i, txi in enumerate(self._inputs):
assert txi.script is not None
assert txi.txo_ref.txo is not None
txo_script = txi.txo_ref.txo.script
if txo_script.is_pay_pubkey_hash:
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
private_key = yield ledger.get_private_key_for_address(address)
tx = self._serialize_for_signature(i)
txi.script.values['signature'] = \
private_key.sign(tx) + six.int2byte(self.signature_hash_type(1))
private_key.sign(tx) + bytes((self.signature_hash_type(1),))
txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
txi.script.generate()
else:
raise NotImplementedError("Don't know how to spend this output.")
self._reset()
def sort(self):
# See https://github.com/kristovatlas/rfc/blob/master/bips/bip-li01.mediawiki
self._inputs.sort(key=lambda i: (i['prevout_hash'], i['prevout_n']))
self._outputs.sort(key=lambda o: (o[2], pay_script(o[0], o[1])))
@property
def input_sum(self):
return sum(i.amount for i in self.inputs)

View file

@ -35,9 +35,9 @@ class BCDataStream:
return size
if size == 253:
return self.read_uint16()
elif size == 254:
if size == 254:
return self.read_uint32()
elif size == 255:
if size == 255:
return self.read_uint64()
def write_compact_size(self, size):
@ -70,7 +70,7 @@ class BCDataStream:
def _read_struct(self, fmt):
value = self.read(fmt.size)
if len(value) > 0:
if value:
return fmt.unpack(value)[0]
def read_int8(self):

View file

@ -10,7 +10,6 @@
import struct
import hashlib
from six import int2byte, byte2int, indexbytes
import ecdsa
import ecdsa.ellipticcurve as EC
@ -24,7 +23,7 @@ class DerivationError(Exception):
""" Raised when an invalid derivation occurs. """
class _KeyBase(object):
class _KeyBase:
""" A BIP32 Key, public or private. """
CURVE = ecdsa.SECP256k1
@ -63,17 +62,23 @@ class _KeyBase(object):
if len(raw_serkey) != 33:
raise ValueError('raw_serkey must have length 33')
return (ver_bytes + int2byte(self.depth)
return (ver_bytes + bytes((self.depth,))
+ self.parent_fingerprint() + struct.pack('>I', self.n)
+ self.chain_code + raw_serkey)
def identifier(self):
raise NotImplementedError
def extended_key(self):
raise NotImplementedError
def fingerprint(self):
""" Return the key's fingerprint as 4 bytes. """
return self.identifier()[:4]
def parent_fingerprint(self):
""" Return the parent key's fingerprint as 4 bytes. """
return self.parent.fingerprint() if self.parent else int2byte(0)*4
return self.parent.fingerprint() if self.parent else bytes((0,)*4)
def extended_key_string(self):
""" Return an extended key as a base58 string. """
@ -84,7 +89,7 @@ class PubKey(_KeyBase):
""" A BIP32 public key. """
def __init__(self, ledger, pubkey, chain_code, n, depth, parent=None):
super(PubKey, self).__init__(ledger, chain_code, n, depth, parent)
super().__init__(ledger, chain_code, n, depth, parent)
if isinstance(pubkey, ecdsa.VerifyingKey):
self.verifying_key = pubkey
else:
@ -97,16 +102,16 @@ class PubKey(_KeyBase):
raise TypeError('pubkey must be raw bytes')
if len(pubkey) != 33:
raise ValueError('pubkey must be 33 bytes')
if indexbytes(pubkey, 0) not in (2, 3):
if pubkey[0] not in (2, 3):
raise ValueError('invalid pubkey prefix byte')
curve = cls.CURVE.curve
is_odd = indexbytes(pubkey, 0) == 3
is_odd = pubkey[0] == 3
x = bytes_to_int(pubkey[1:])
# p is the finite field order
a, b, p = curve.a(), curve.b(), curve.p()
y2 = pow(x, 3, p) + b
a, b, p = curve.a(), curve.b(), curve.p() # pylint: disable=invalid-name
y2 = pow(x, 3, p) + b # pylint: disable=invalid-name
assert a == 0 # Otherwise y2 += a * pow(x, 2, p)
y = NT.square_root_mod_prime(y2 % p, p)
if bool(y & 1) != is_odd:
@ -119,7 +124,7 @@ class PubKey(_KeyBase):
def pubkey_bytes(self):
""" Return the compressed public key as 33 bytes. """
point = self.verifying_key.pubkey.point
prefix = int2byte(2 + (point.y() & 1))
prefix = bytes((2 + (point.y() & 1),))
padded_bytes = _exponent_to_bytes(point.x())
return prefix + padded_bytes
@ -137,10 +142,10 @@ class PubKey(_KeyBase):
raise ValueError('invalid BIP32 public key child number')
msg = self.pubkey_bytes + struct.pack('>I', n)
L, R = self._hmac_sha512(msg)
L, R = self._hmac_sha512(msg) # pylint: disable=invalid-name
curve = self.CURVE
L = bytes_to_int(L)
L = bytes_to_int(L) # pylint: disable=invalid-name
if L >= curve.order:
raise DerivationError
@ -172,7 +177,7 @@ class LowSValueSigningKey(ecdsa.SigningKey):
def sign_number(self, number, entropy=None, k=None):
order = self.privkey.order
r, s = ecdsa.SigningKey.sign_number(self, number, entropy, k)
r, s = ecdsa.SigningKey.sign_number(self, number, entropy, k) # pylint: disable=invalid-name
if s > order / 2:
s = order - s
return r, s
@ -184,7 +189,7 @@ class PrivateKey(_KeyBase):
HARDENED = 1 << 31
def __init__(self, ledger, privkey, chain_code, n, depth, parent=None):
super(PrivateKey, self).__init__(ledger, chain_code, n, depth, parent)
super().__init__(ledger, chain_code, n, depth, parent)
if isinstance(privkey, ecdsa.SigningKey):
self.signing_key = privkey
else:
@ -254,10 +259,10 @@ class PrivateKey(_KeyBase):
serkey = self.public_key.pubkey_bytes
msg = serkey + struct.pack('>I', n)
L, R = self._hmac_sha512(msg)
L, R = self._hmac_sha512(msg) # pylint: disable=invalid-name
curve = self.CURVE
L = bytes_to_int(L)
L = bytes_to_int(L) # pylint: disable=invalid-name
exponent = (L + bytes_to_int(self.private_key_bytes)) % curve.order
if exponent == 0 or L >= curve.order:
raise DerivationError
@ -286,7 +291,7 @@ class PrivateKey(_KeyBase):
def _exponent_to_bytes(exponent):
"""Convert an exponent to 32 big-endian bytes"""
return (int2byte(0)*32 + int_to_bytes(exponent))[-32:]
return (bytes((0,)*32) + int_to_bytes(exponent))[-32:]
def _from_extended_key(ledger, ekey):
@ -296,8 +301,8 @@ def _from_extended_key(ledger, ekey):
if len(ekey) != 78:
raise ValueError('extended key must have length 78')
depth = indexbytes(ekey, 4)
fingerprint = ekey[5:9] # Not used
depth = ekey[4]
# fingerprint = ekey[5:9]
n, = struct.unpack('>I', ekey[9:13])
chain_code = ekey[13:45]
@ -305,7 +310,7 @@ def _from_extended_key(ledger, ekey):
pubkey = ekey[45:]
key = PubKey(ledger, pubkey, chain_code, n, depth)
elif ekey[:4] == ledger.extended_private_key_prefix:
if indexbytes(ekey, 45) != 0:
if ekey[45] != 0:
raise ValueError('invalid extended private key prefix byte')
privkey = ekey[46:]
key = PrivateKey(ledger, privkey, chain_code, n, depth)

View file

@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
__path__: str = __import__('pkgutil').extend_path(__path__, __name__)

View file

@ -6,7 +6,6 @@ __node_url__ = (
)
__electrumx__ = 'electrumx.lib.coins.BitcoinCashRegtest'
from six import int2byte
from binascii import unhexlify
from torba.baseledger import BaseLedger
from torba.baseheader import BaseHeaders
@ -26,8 +25,8 @@ class MainNetLedger(BaseLedger):
transaction_class = Transaction
pubkey_address_prefix = int2byte(0x00)
script_address_prefix = int2byte(0x05)
pubkey_address_prefix = bytes((0,))
script_address_prefix = bytes((5,))
extended_public_key_prefix = unhexlify('0488b21e')
extended_private_key_prefix = unhexlify('0488ade4')
@ -42,8 +41,8 @@ class RegTestLedger(MainNetLedger):
headers_class = UnverifiedHeaders
network_name = 'regtest'
pubkey_address_prefix = int2byte(111)
script_address_prefix = int2byte(196)
pubkey_address_prefix = bytes((111,))
script_address_prefix = bytes((196,))
extended_public_key_prefix = unhexlify('043587cf')
extended_private_key_prefix = unhexlify('04358394')

View file

@ -6,7 +6,6 @@ __node_url__ = (
)
__electrumx__ = 'electrumx.lib.coins.BitcoinSegwitRegtest'
from six import int2byte
from binascii import unhexlify
from torba.baseledger import BaseLedger
from torba.baseheader import BaseHeaders
@ -17,8 +16,8 @@ class MainNetLedger(BaseLedger):
symbol = 'BTC'
network_name = 'mainnet'
pubkey_address_prefix = int2byte(0x00)
script_address_prefix = int2byte(0x05)
pubkey_address_prefix = bytes((0,))
script_address_prefix = bytes((5,))
extended_public_key_prefix = unhexlify('0488b21e')
extended_private_key_prefix = unhexlify('0488ade4')
@ -33,8 +32,8 @@ class RegTestLedger(MainNetLedger):
headers_class = UnverifiedHeaders
network_name = 'regtest'
pubkey_address_prefix = int2byte(111)
script_address_prefix = int2byte(196)
pubkey_address_prefix = bytes((111,))
script_address_prefix = bytes((196,))
extended_public_key_prefix = unhexlify('043587cf')
extended_private_key_prefix = unhexlify('04358394')

View file

@ -1,16 +1,15 @@
import six
from random import Random
from typing import List
import torba
from torba import basetransaction
MAXIMUM_TRIES = 100000
class CoinSelector:
def __init__(self, txos, target, cost_of_change, seed=None):
# type: (List[torba.basetransaction.BaseOutputAmountEstimator], int, int, str) -> None
def __init__(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
target: int, cost_of_change: int, seed: str = None) -> None:
self.txos = txos
self.target = target
self.cost_of_change = cost_of_change
@ -18,17 +17,17 @@ class CoinSelector:
self.tries = 0
self.available = sum(c.effective_amount for c in self.txos)
self.random = Random(seed)
if six.PY3 and seed is not None:
if seed is not None:
self.random.seed(seed, version=1)
def select(self): # type: () -> List[torba.basetransaction.BaseOutputAmountEstimator]
def select(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
if not self.txos:
return
return []
if self.target > self.available:
return
return []
return self.branch_and_bound() or self.single_random_draw()
def branch_and_bound(self): # type: () -> List[torba.basetransaction.BaseOutputAmountEstimator]
def branch_and_bound(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
# see bitcoin implementation for more info:
# https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp
@ -36,9 +35,9 @@ class CoinSelector:
current_value = 0
current_available_value = self.available
current_selection = []
current_selection: List[bool] = []
best_waste = self.cost_of_change
best_selection = []
best_selection: List[bool] = []
while self.tries < MAXIMUM_TRIES:
self.tries += 1
@ -70,7 +69,7 @@ class CoinSelector:
utxo = self.txos[len(current_selection)]
current_available_value -= utxo.effective_amount
previous_utxo = self.txos[len(current_selection) - 1] if current_selection else None
if current_selection and not current_selection[-1] and \
if current_selection and not current_selection[-1] and previous_utxo and \
utxo.effective_amount == previous_utxo.effective_amount and \
utxo.fee == previous_utxo.fee:
current_selection.append(False)
@ -84,7 +83,9 @@ class CoinSelector:
self.txos[i] for i, include in enumerate(best_selection) if include
]
def single_random_draw(self): # type: () -> List[torba.basetransaction.BaseOutputAmountEstimator]
return []
def single_random_draw(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
self.random.shuffle(self.txos, self.random.random)
selection = []
amount = 0
@ -93,3 +94,4 @@ class CoinSelector:
amount += coin.effective_amount
if amount >= self.target+self.cost_of_change:
return selection
return []

View file

@ -9,7 +9,6 @@
""" Cryptography hash functions and related classes. """
import os
import six
import base64
import hashlib
import hmac
@ -22,13 +21,8 @@ from cryptography.hazmat.backends import default_backend
from torba.util import bytes_to_int, int_to_bytes
from torba.constants import NULL_HASH32
_sha256 = hashlib.sha256
_sha512 = hashlib.sha512
_new_hash = hashlib.new
_new_hmac = hmac.new
class TXRef(object):
class TXRef:
__slots__ = '_id', '_hash'
@ -68,50 +62,29 @@ class TXRefImmutable(TXRef):
return ref
class TXORef(object):
__slots__ = 'tx_ref', 'position'
def __init__(self, tx_ref, position): # type: (TXRef, int) -> None
self.tx_ref = tx_ref
self.position = position
@property
def id(self):
return '{}:{}'.format(self.tx_ref.id, self.position)
@property
def is_null(self):
return self.tx_ref.is_null
@property
def txo(self):
return None
def sha256(x):
""" Simple wrapper of hashlib sha256. """
return _sha256(x).digest()
return hashlib.sha256(x).digest()
def sha512(x):
""" Simple wrapper of hashlib sha512. """
return _sha512(x).digest()
return hashlib.sha512(x).digest()
def ripemd160(x):
""" Simple wrapper of hashlib ripemd160. """
h = _new_hash('ripemd160')
h = hashlib.new('ripemd160')
h.update(x)
return h.digest()
def pow_hash(x):
r = sha512(double_sha256(x))
r1 = ripemd160(r[:len(r) // 2])
r2 = ripemd160(r[len(r) // 2:])
r3 = double_sha256(r1 + r2)
return r3
h = sha512(double_sha256(x))
return double_sha256(
ripemd160(h[:len(h) // 2]) +
ripemd160(h[len(h) // 2:])
)
def double_sha256(x):
@ -121,7 +94,7 @@ def double_sha256(x):
def hmac_sha512(key, msg):
""" Use SHA-512 to provide an HMAC. """
return _new_hmac(key, msg, _sha512).digest()
return hmac.new(key, msg, hashlib.sha512).digest()
def hash160(x):
@ -165,7 +138,7 @@ class Base58Error(Exception):
""" Exception used for Base58 errors. """
class Base58(object):
class Base58:
""" Class providing base 58 functionality. """
chars = u'123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'
@ -207,7 +180,7 @@ class Base58(object):
break
count += 1
if count:
result = six.int2byte(0) * count + result
result = bytes((0,)) * count + result
return result

View file

@ -56,7 +56,7 @@ CJK_INTERVALS = [
def is_cjk(c):
n = ord(c)
for start, end, name in CJK_INTERVALS:
for start, end, _ in CJK_INTERVALS:
if start <= n <= end:
return True
return False
@ -93,7 +93,7 @@ def load_words(filename):
return words
file_names = {
FILE_NAMES = {
'en': 'english.txt',
'es': 'spanish.txt',
'ja': 'japanese.txt',
@ -102,20 +102,22 @@ file_names = {
}
class Mnemonic(object):
class Mnemonic:
# Seed derivation no longer follows BIP39
# Mnemonic phrase uses a hash based checksum, instead of a words-dependent checksum
def __init__(self, lang='en'):
filename = file_names.get(lang, 'english.txt')
filename = FILE_NAMES.get(lang, 'english.txt')
self.words = load_words(filename)
@classmethod
def mnemonic_to_seed(self, mnemonic, passphrase=u''):
PBKDF2_ROUNDS = 2048
@staticmethod
def mnemonic_to_seed(mnemonic, passphrase=u''):
pbkdf2_rounds = 2048
mnemonic = normalize_text(mnemonic)
passphrase = normalize_text(passphrase)
return pbkdf2.PBKDF2(mnemonic, passphrase, iterations=PBKDF2_ROUNDS, macmodule=hmac, digestmodule=hashlib.sha512).read(64)
return pbkdf2.PBKDF2(
mnemonic, passphrase, iterations=pbkdf2_rounds, macmodule=hmac, digestmodule=hashlib.sha512
).read(64)
def mnemonic_encode(self, i):
n = len(self.words)
@ -131,8 +133,8 @@ class Mnemonic(object):
words = seed.split()
i = 0
while words:
w = words.pop()
k = self.words.index(w)
word = words.pop()
k = self.words.index(word)
i = i*n + k
return i

View file

@ -1,27 +1,7 @@
import six
from twisted.internet.defer import Deferred, DeferredLock, maybeDeferred, inlineCallbacks
import asyncio
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
if six.PY3:
import asyncio
def execute_serially(f):
_lock = DeferredLock()
@inlineCallbacks
def allow_only_one_at_a_time(*args, **kwargs):
yield _lock.acquire()
allow_only_one_at_a_time.is_running = True
try:
yield maybeDeferred(f, *args, **kwargs)
finally:
allow_only_one_at_a_time.is_running = False
_lock.release()
allow_only_one_at_a_time.is_running = False
return allow_only_one_at_a_time
class BroadcastSubscription:
@ -76,10 +56,10 @@ class StreamController:
@property
def _iterate_subscriptions(self):
next = self._first_subscription
while next is not None:
subscription = next
next = next._next
next_sub = self._first_subscription
while next_sub is not None:
subscription = next_sub
next_sub = next_sub._next
yield subscription
def add(self, event):
@ -96,15 +76,15 @@ class StreamController:
def _cancel(self, subscription):
previous = subscription._previous
next = subscription._next
next_sub = subscription._next
if previous is None:
self._first_subscription = next
self._first_subscription = next_sub
else:
previous._next = next
if next is None:
previous._next = next_sub
if next_sub is None:
self._last_subscription = previous
else:
next._previous = previous
next_sub._previous = previous
subscription._next = subscription._previous = subscription
def _listen(self, on_data, on_error, on_done):

View file

@ -1,20 +1,19 @@
from binascii import unhexlify, hexlify
from collections import Sequence
from typing import TypeVar, Generic
from typing import TypeVar, Sequence
T = TypeVar('T')
class ReadOnlyList(Sequence, Generic[T]):
class ReadOnlyList(Sequence[T]):
def __init__(self, lst):
self.lst = lst
def __getitem__(self, key): # type: (int) -> T
def __getitem__(self, key):
return self.lst[key]
def __len__(self):
def __len__(self) -> int:
return len(self.lst)
@ -22,13 +21,13 @@ def subclass_tuple(name, base):
return type(name, (base,), {'__slots__': ()})
class cachedproperty(object):
class cachedproperty:
def __init__(self, f):
self.f = f
def __get__(self, obj, type):
obj = obj or type
def __get__(self, obj, objtype):
obj = obj or objtype
value = self.f(obj)
setattr(obj, self.f.__name__, value)
return value
@ -42,8 +41,8 @@ def bytes_to_int(be_bytes):
def int_to_bytes(value):
""" Converts an integer to a big-endian sequence of bytes. """
length = (value.bit_length() + 7) // 8
h = '%x' % value
return unhexlify(('0' * (len(h) % 2) + h).zfill(length * 2))
s = '%x' % value
return unhexlify(('0' * (len(s) % 2) + s).zfill(length * 2))
def rev_hex(s):
@ -56,8 +55,8 @@ def int_to_hex(i, length=1):
return rev_hex(s)
def hex_to_int(s):
return int(b'0x' + hexlify(s[::-1]), 16)
def hex_to_int(x):
return int(b'0x' + hexlify(x[::-1]), 16)
def hash_encode(x):

View file

@ -1,10 +1,13 @@
import stat
import json
import os
from typing import List
import typing
from typing import Sequence, MutableSequence
import torba.baseaccount
import torba.baseledger
if typing.TYPE_CHECKING:
from torba import baseaccount
from torba import baseledger
from torba import basemanager
class Wallet:
@ -14,24 +17,24 @@ class Wallet:
by physical files on the filesystem.
"""
def __init__(self, name='Wallet', accounts=None, storage=None):
# type: (str, List[torba.baseaccount.BaseAccount], WalletStorage) -> None
def __init__(self, name: str = 'Wallet', accounts: MutableSequence['baseaccount.BaseAccount'] = None,
storage: 'WalletStorage' = None) -> None:
self.name = name
self.accounts = accounts or [] # type: List[torba.baseaccount.BaseAccount]
self.accounts = accounts or []
self.storage = storage or WalletStorage()
def generate_account(self, ledger):
# type: (torba.baseledger.BaseLedger) -> torba.baseaccount.BaseAccount
def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount':
account = ledger.account_class.generate(ledger, u'torba')
self.accounts.append(account)
return account
@classmethod
def from_storage(cls, storage, manager): # type: (WalletStorage, 'WalletManager') -> Wallet
def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet':
json_dict = storage.read()
accounts = []
for account_dict in json_dict.get('accounts', []):
account_dicts: Sequence[dict] = json_dict.get('accounts', [])
for account_dict in account_dicts:
ledger = manager.get_or_create_ledger(account_dict['ledger'])
account = ledger.account_class.from_dict(ledger, account_dict)
accounts.append(account)
@ -110,7 +113,7 @@ class WalletStorage:
mode = stat.S_IREAD | stat.S_IWRITE
try:
os.rename(temp_path, self.path)
except:
except Exception: # pylint: disable=broad-except
os.remove(self.path)
os.rename(temp_path, self.path)
os.chmod(self.path, mode)