progress since thursday

This commit is contained in:
Lex Berezhny 2018-06-11 09:33:32 -04:00
parent 78e4135159
commit eb6781481a
18 changed files with 919 additions and 955 deletions

View file

@ -1,43 +0,0 @@
from six import int2byte
from binascii import unhexlify
from torba.baseledger import BaseLedger
from torba.basenetwork import BaseNetwork
from torba.basescript import BaseInputScript, BaseOutputScript
from torba.basetransaction import BaseTransaction, BaseInput, BaseOutput
from torba.basecoin import BaseCoin
class Ledger(BaseLedger):
network_class = BaseNetwork
class Input(BaseInput):
script_class = BaseInputScript
class Output(BaseOutput):
script_class = BaseOutputScript
class Transaction(BaseTransaction):
input_class = Input
output_class = Output
class FTC(BaseCoin):
name = 'Fakecoin'
symbol = 'FTC'
network = 'mainnet'
ledger_class = Ledger
transaction_class = Transaction
pubkey_address_prefix = int2byte(0x00)
script_address_prefix = int2byte(0x05)
extended_public_key_prefix = unhexlify('0488b21e')
extended_private_key_prefix = unhexlify('0488ade4')
default_fee_per_byte = 50
def __init__(self, ledger, fee_per_byte=default_fee_per_byte):
super(FTC, self).__init__(ledger, fee_per_byte)

View file

@ -1,38 +1,48 @@
from binascii import hexlify from binascii import hexlify
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer
from torba.coin.bitcoinsegwit import BTC from torba.coin.bitcoinsegwit import MainNetLedger
from torba.basemanager import WalletManager
from torba.wallet import Account
class TestAccount(unittest.TestCase): class TestAccount(unittest.TestCase):
def setUp(self): def setUp(self):
ledger = WalletManager().get_or_create_ledger(BTC.get_id()) self.ledger = MainNetLedger(db=':memory:')
self.coin = BTC(ledger) return self.ledger.db.start()
@defer.inlineCallbacks
def test_generate_account(self): def test_generate_account(self):
account = Account.generate(self.coin, u"torba") account = self.ledger.account_class.generate(self.ledger, u"torba")
self.assertEqual(account.coin, self.coin) self.assertEqual(account.ledger, self.ledger)
self.assertIsNotNone(account.seed) self.assertIsNotNone(account.seed)
self.assertEqual(account.public_key.coin, self.coin) self.assertEqual(account.public_key.ledger, self.ledger)
self.assertEqual(account.private_key.public_key, account.public_key) self.assertEqual(account.private_key.public_key, account.public_key)
self.assertEqual(len(account.receiving_keys.child_keys), 0) keys = yield account.receiving.get_keys()
self.assertEqual(len(account.receiving_keys.addresses), 0) addresses = yield account.receiving.get_addresses()
self.assertEqual(len(account.change_keys.child_keys), 0) self.assertEqual(len(keys), 0)
self.assertEqual(len(account.change_keys.addresses), 0) self.assertEqual(len(addresses), 0)
keys = yield account.change.get_keys()
addresses = yield account.change.get_addresses()
self.assertEqual(len(keys), 0)
self.assertEqual(len(addresses), 0)
account.ensure_enough_addresses() yield account.ensure_enough_useable_addresses()
self.assertEqual(len(account.receiving_keys.child_keys), 20)
self.assertEqual(len(account.receiving_keys.addresses), 20)
self.assertEqual(len(account.change_keys.child_keys), 6)
self.assertEqual(len(account.change_keys.addresses), 6)
keys = yield account.receiving.get_keys()
addresses = yield account.receiving.get_addresses()
self.assertEqual(len(keys), 20)
self.assertEqual(len(addresses), 20)
keys = yield account.change.get_keys()
addresses = yield account.change.get_addresses()
self.assertEqual(len(keys), 6)
self.assertEqual(len(addresses), 6)
@defer.inlineCallbacks
def test_generate_account_from_seed(self): def test_generate_account_from_seed(self):
account = Account.from_seed( account = self.ledger.account_class.from_seed(
self.coin, self.ledger,
u"carbon smart garage balance margin twelve chest sword toast envelope bottom stomach ab" u"carbon smart garage balance margin twelve chest sword toast envelope bottom stomach ab"
u"sent", u"sent",
u"torba" u"torba"
@ -47,23 +57,22 @@ class TestAccount(unittest.TestCase):
b'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' b'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f'
b'iW44g14WF52fYC5J483wqQ5ZP' b'iW44g14WF52fYC5J483wqQ5ZP'
) )
self.assertEqual( address = yield account.receiving.ensure_enough_useable_addresses()
account.receiving_keys.generate_next_address(), self.assertEqual(address[0], b'1PGDB1CRy8UxPCrkcakRqroVnHxqzvUZhp')
b'1PmX9T3sCiDysNtWszJa44SkKcpGc2NaXP' private_key = yield self.ledger.get_private_key_for_address(b'1PGDB1CRy8UxPCrkcakRqroVnHxqzvUZhp')
)
private_key = account.get_private_key_for_address(b'1PmX9T3sCiDysNtWszJa44SkKcpGc2NaXP')
self.assertEqual( self.assertEqual(
private_key.extended_key_string(), private_key.extended_key_string(),
b'xprv9xNEfQ296VTRaEUDZ8oKq74xw2U6kpj486vFUB4K1wT9U25GX4UwuzFgJN1YuRrqkQ5TTwCpkYnjNpSoHS' b'xprv9xNEfQ296VTRc5QF7AZZ1WTimGzMs54FepRXVxbyypJXCrUKjxsYSyk5EhHYNxU4ApsaBr8AQ4sYo86BbGh2dZSddGXU1CMGwExvnyckjQn'
b'BaEigNHPkoeYbuPMRo6mRUjxg'
) )
self.assertIsNone(account.get_private_key_for_address(b'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')) invalid_key = yield self.ledger.get_private_key_for_address(b'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
self.assertIsNone(invalid_key)
self.assertEqual( self.assertEqual(
hexlify(private_key.wif()), hexlify(private_key.wif()),
b'1cc27be89ad47ef932562af80e95085eb0ab2ae3e5c019b1369b8b05ff2e94512f01' b'1c5664e848772b199644ab390b5c27d2f6664d9cdfdb62e1c7ac25151b00858b7a01'
) )
@defer.inlineCallbacks
def test_load_and_save_account(self): def test_load_and_save_account(self):
account_data = { account_data = {
'seed': 'seed':
@ -77,29 +86,22 @@ class TestAccount(unittest.TestCase):
'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f'
'iW44g14WF52fYC5J483wqQ5ZP', 'iW44g14WF52fYC5J483wqQ5ZP',
'receiving_gap': 10, 'receiving_gap': 10,
'receiving_keys': [
'0222345947a59dca4a3363ffa81ac87dd907d2b2feff57383eaeddbab266ca5f2d',
'03fdc9826d5d00a484188cba8eb7dba5877c0323acb77905b7bcbbab35d94be9f6'
],
'change_gap': 10, 'change_gap': 10,
'change_keys': [
'038836be4147836ed6b4df6a89e0d9f1b1c11cec529b7ff5407de57f2e5b032c83'
]
} }
account = Account.from_dict(self.coin, account_data) account = self.ledger.account_class.from_dict(self.ledger, account_data)
self.assertEqual(len(account.receiving_keys.addresses), 2) yield account.ensure_enough_useable_addresses()
self.assertEqual(
account.receiving_keys.addresses[0], keys = yield account.receiving.get_keys()
b'1PmX9T3sCiDysNtWszJa44SkKcpGc2NaXP' addresses = yield account.receiving.get_addresses()
) self.assertEqual(len(keys), 10)
self.assertEqual(len(account.change_keys.addresses), 1) self.assertEqual(len(addresses), 10)
self.assertEqual( keys = yield account.change.get_keys()
account.change_keys.addresses[0], addresses = yield account.change.get_addresses()
b'1PUbu1D1f3c244JPRSJKBCxRqui5NT6geR' self.assertEqual(len(keys), 10)
) self.assertEqual(len(addresses), 10)
self.maxDiff = None self.maxDiff = None
account_data['coin'] = 'btc_mainnet' account_data['ledger'] = 'btc_mainnet'
self.assertDictEqual(account_data, account.to_dict()) self.assertDictEqual(account_data, account.to_dict())

View file

@ -1,9 +1,9 @@
import unittest import unittest
from torba.coin.bitcoinsegwit import BTC from torba.coin.bitcoinsegwit import MainNetLedger
from torba.coinselection import CoinSelector, MAXIMUM_TRIES from torba.coinselection import CoinSelector, MAXIMUM_TRIES
from torba.constants import CENT from torba.constants import CENT
from torba.basemanager import WalletManager from torba.manager import WalletManager
from .test_transaction import Output, get_output as utxo from .test_transaction import Output, get_output as utxo
@ -19,12 +19,12 @@ def search(*args, **kwargs):
class BaseSelectionTestCase(unittest.TestCase): class BaseSelectionTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
ledger = WalletManager().get_or_create_ledger(BTC.get_id()) self.ledger = MainNetLedger(db=':memory:')
self.coin = BTC(ledger) return self.ledger.db.start()
def estimates(self, *args): def estimates(self, *args):
txos = args if isinstance(args[0], Output) else args[0] txos = args if isinstance(args[0], Output) else args[0]
return [txo.get_estimator(self.coin) for txo in txos] return [txo.get_estimator(self.ledger) for txo in txos]
class TestCoinSelectionTests(BaseSelectionTestCase): class TestCoinSelectionTests(BaseSelectionTestCase):
@ -33,7 +33,7 @@ class TestCoinSelectionTests(BaseSelectionTestCase):
self.assertIsNone(CoinSelector([], 0, 0).select()) self.assertIsNone(CoinSelector([], 0, 0).select())
def test_skip_binary_search_if_total_not_enough(self): def test_skip_binary_search_if_total_not_enough(self):
fee = utxo(CENT).get_estimator(self.coin).fee fee = utxo(CENT).get_estimator(self.ledger).fee
big_pool = self.estimates(utxo(CENT+fee) for _ in range(100)) big_pool = self.estimates(utxo(CENT+fee) for _ in range(100))
selector = CoinSelector(big_pool, 101 * CENT, 0) selector = CoinSelector(big_pool, 101 * CENT, 0)
self.assertIsNone(selector.select()) self.assertIsNone(selector.select())

View file

@ -0,0 +1,15 @@
from twisted.trial import unittest
from twisted.internet import defer
from torba.basedatabase import BaseSQLiteWalletStorage
class TestDatabase(unittest.TestCase):
def setUp(self):
self.db = BaseSQLiteWalletStorage(':memory:')
return self.db.start()
@defer.inlineCallbacks
def test_empty_db(self):
result = yield self.db.

View file

@ -1,10 +1,10 @@
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from twisted.trial import unittest from twisted.trial import unittest
from torba.account import Account from torba.basetransaction import BaseTransaction, BaseInput, BaseOutput
from torba.coin.bitcoinsegwit import BTC, Transaction, Output, Input from torba.coin.bitcoinsegwit import MainNetLedger
from torba.constants import CENT, COIN from torba.constants import CENT, COIN
from torba.basemanager import WalletManager from torba.manager import WalletManager
from torba.wallet import Wallet from torba.wallet import Wallet
@ -14,34 +14,33 @@ FEE_PER_CHAR = 200000
def get_output(amount=CENT, pubkey_hash=NULL_HASH): def get_output(amount=CENT, pubkey_hash=NULL_HASH):
return Transaction() \ return BaseTransaction() \
.add_outputs([Output.pay_pubkey_hash(amount, pubkey_hash)]) \ .add_outputs([BaseTransaction.output_class.pay_pubkey_hash(amount, pubkey_hash)]) \
.outputs[0] .outputs[0]
def get_input(): def get_input():
return Input.spend(get_output()) return BaseInput.spend(get_output())
def get_transaction(txo=None): def get_transaction(txo=None):
return Transaction() \ return BaseTransaction() \
.add_inputs([get_input()]) \ .add_inputs([get_input()]) \
.add_outputs([txo or Output.pay_pubkey_hash(CENT, NULL_HASH)]) .add_outputs([txo or BaseOutput.pay_pubkey_hash(CENT, NULL_HASH)])
def get_wallet_and_coin(): def get_wallet_and_ledger():
ledger = WalletManager().get_or_create_ledger(BTC.get_id()) ledger = WalletManager().get_or_create_ledger(MainNetLedger.get_id())
coin = BTC(ledger) return Wallet('Main', [ledger], [ledger.account_class.generate(ledger, u'torba')]), ledger
return Wallet('Main', [coin], [Account.generate(coin, u'torba')]), coin
class TestSizeAndFeeEstimation(unittest.TestCase): class TestSizeAndFeeEstimation(unittest.TestCase):
def setUp(self): def setUp(self):
self.wallet, self.coin = get_wallet_and_coin() self.wallet, self.ledger = get_wallet_and_ledger()
def io_fee(self, io): def io_fee(self, io):
return self.coin.get_input_output_fee(io) return self.ledger.get_input_output_fee(io)
def test_output_size_and_fee(self): def test_output_size_and_fee(self):
txo = get_output() txo = get_output()
@ -58,7 +57,7 @@ class TestSizeAndFeeEstimation(unittest.TestCase):
base_size = tx.size - 1 - tx.inputs[0].size base_size = tx.size - 1 - tx.inputs[0].size
self.assertEqual(tx.size, 204) self.assertEqual(tx.size, 204)
self.assertEqual(tx.base_size, base_size) self.assertEqual(tx.base_size, base_size)
self.assertEqual(self.coin.get_transaction_base_fee(tx), FEE_PER_BYTE * base_size) self.assertEqual(self.ledger.get_transaction_base_fee(tx), FEE_PER_BYTE * base_size)
class TestTransactionSerialization(unittest.TestCase): class TestTransactionSerialization(unittest.TestCase):
@ -71,20 +70,20 @@ class TestTransactionSerialization(unittest.TestCase):
'000000434104678afdb0fe5548271967f1a67130b7105cd6a828e03909a67962e0ea1f61deb649f6bc3f4c' '000000434104678afdb0fe5548271967f1a67130b7105cd6a828e03909a67962e0ea1f61deb649f6bc3f4c'
'ef38c4f35504e51ec112de5c384df7ba0b8d578a4c702b6bf11d5fac00000000' 'ef38c4f35504e51ec112de5c384df7ba0b8d578a4c702b6bf11d5fac00000000'
) )
tx = Transaction(raw) tx = BaseTransaction(raw)
self.assertEqual(tx.version, 1) self.assertEqual(tx.version, 1)
self.assertEqual(tx.locktime, 0) self.assertEqual(tx.locktime, 0)
self.assertEqual(len(tx.inputs), 1) self.assertEqual(len(tx.inputs), 1)
self.assertEqual(len(tx.outputs), 1) self.assertEqual(len(tx.outputs), 1)
coinbase = tx.inputs[0] ledgerbase = tx.inputs[0]
self.assertEqual(coinbase.output_txid, NULL_HASH) self.assertEqual(ledgerbase.output_txid, NULL_HASH)
self.assertEqual(coinbase.output_index, 0xFFFFFFFF) self.assertEqual(ledgerbase.output_index, 0xFFFFFFFF)
self.assertEqual(coinbase.sequence, 4294967295) self.assertEqual(ledgerbase.sequence, 4294967295)
self.assertTrue(coinbase.is_coinbase) self.assertTrue(ledgerbase.is_ledgerbase)
self.assertEqual(coinbase.script, None) self.assertEqual(ledgerbase.script, None)
self.assertEqual( self.assertEqual(
coinbase.coinbase[8:], ledgerbase.ledgerbase[8:],
b'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks' b'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks'
) )
@ -98,7 +97,7 @@ class TestTransactionSerialization(unittest.TestCase):
tx._reset() tx._reset()
self.assertEqual(tx.raw, raw) self.assertEqual(tx.raw, raw)
def test_coinbase_transaction(self): def test_ledgerbase_transaction(self):
raw = unhexlify( raw = unhexlify(
'01000000010000000000000000000000000000000000000000000000000000000000000000ffffffff4e03' '01000000010000000000000000000000000000000000000000000000000000000000000000ffffffff4e03'
'1f5a070473319e592f4254432e434f4d2f4e59412ffabe6d6dcceb2a9d0444c51cabc4ee97a1a000036ca0' '1f5a070473319e592f4254432e434f4d2f4e59412ffabe6d6dcceb2a9d0444c51cabc4ee97a1a000036ca0'
@ -106,20 +105,20 @@ class TestTransactionSerialization(unittest.TestCase):
'0000000017a914e083685a1097ce1ea9e91987ab9e94eae33d8a13870000000000000000266a24aa21a9ed' '0000000017a914e083685a1097ce1ea9e91987ab9e94eae33d8a13870000000000000000266a24aa21a9ed'
'e6c99265a6b9e1d36c962fda0516b35709c49dc3b8176fa7e5d5f1f6197884b400000000' 'e6c99265a6b9e1d36c962fda0516b35709c49dc3b8176fa7e5d5f1f6197884b400000000'
) )
tx = Transaction(raw) tx = BaseTransaction(raw)
self.assertEqual(tx.version, 1) self.assertEqual(tx.version, 1)
self.assertEqual(tx.locktime, 0) self.assertEqual(tx.locktime, 0)
self.assertEqual(len(tx.inputs), 1) self.assertEqual(len(tx.inputs), 1)
self.assertEqual(len(tx.outputs), 2) self.assertEqual(len(tx.outputs), 2)
coinbase = tx.inputs[0] ledgerbase = tx.inputs[0]
self.assertEqual(coinbase.output_txid, NULL_HASH) self.assertEqual(ledgerbase.output_txid, NULL_HASH)
self.assertEqual(coinbase.output_index, 0xFFFFFFFF) self.assertEqual(ledgerbase.output_index, 0xFFFFFFFF)
self.assertEqual(coinbase.sequence, 4294967295) self.assertEqual(ledgerbase.sequence, 4294967295)
self.assertTrue(coinbase.is_coinbase) self.assertTrue(ledgerbase.is_ledgerbase)
self.assertEqual(coinbase.script, None) self.assertEqual(ledgerbase.script, None)
self.assertEqual( self.assertEqual(
coinbase.coinbase[9:22], ledgerbase.ledgerbase[9:22],
b'/BTC.COM/NYA/' b'/BTC.COM/NYA/'
) )
@ -151,17 +150,17 @@ class TestTransactionSigning(unittest.TestCase):
def test_sign(self): def test_sign(self):
ledger = WalletManager().get_or_create_ledger(BTC.get_id()) ledger = WalletManager().get_or_create_ledger(BTC.get_id())
coin = BTC(ledger) ledger = BTC(ledger)
wallet = Wallet('Main', [coin], [Account.from_seed( wallet = Wallet('Main', [ledger], [Account.from_seed(
coin, u'carbon smart garage balance margin twelve chest sword toast envelope bottom stom' ledger, u'carbon smart garage balance margin twelve chest sword toast envelope bottom stom'
u'ach absent', u'torba' u'ach absent', u'torba'
)]) )])
account = wallet.default_account account = wallet.default_account
address1 = account.receiving_keys.generate_next_address() address1 = account.receiving_keys.generate_next_address()
address2 = account.receiving_keys.generate_next_address() address2 = account.receiving_keys.generate_next_address()
pubkey_hash1 = account.coin.address_to_hash160(address1) pubkey_hash1 = account.ledger.address_to_hash160(address1)
pubkey_hash2 = account.coin.address_to_hash160(address2) pubkey_hash2 = account.ledger.address_to_hash160(address2)
tx = Transaction() \ tx = Transaction() \
.add_inputs([Input.spend(get_output(2*COIN, pubkey_hash1))]) \ .add_inputs([Input.spend(get_output(2*COIN, pubkey_hash1))]) \

View file

@ -1,7 +1,7 @@
from twisted.trial import unittest from twisted.trial import unittest
from torba.coin.bitcoinsegwit import BTC from torba.coin.bitcoinsegwit import BTC
from torba.basemanager import WalletManager from torba.manager import WalletManager
from torba.wallet import Account, Wallet, WalletStorage from torba.wallet import Account, Wallet, WalletStorage
from .ftc import FTC from .ftc import FTC

View file

@ -1,191 +0,0 @@
import itertools
from typing import Dict, Generator
from binascii import hexlify, unhexlify
from twisted.internet import defer
from torba.basecoin import BaseCoin
from torba.mnemonic import Mnemonic
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
from torba.hash import double_sha256, aes_encrypt, aes_decrypt
class KeyChain:
def __init__(self, parent_key, child_keys, gap):
self.coin = parent_key.coin
self.parent_key = parent_key # type: PubKey
self.child_keys = child_keys
self.minimum_gap = gap
self.addresses = [
self.coin.public_key_to_address(key)
for key in child_keys
]
@defer.inlineCallbacks
def has_gap(self):
if len(self.addresses) < self.minimum_gap:
defer.returnValue(False)
for address in self.addresses[-self.minimum_gap:]:
if (yield self.coin.ledger.is_address_old(address)):
defer.returnValue(False)
defer.returnValue(True)
def generate_next_address(self):
child_key = self.parent_key.child(len(self.child_keys))
self.child_keys.append(child_key.pubkey_bytes)
self.addresses.append(child_key.address)
return child_key.address
@defer.inlineCallbacks
def ensure_enough_addresses(self):
starting_length = len(self.addresses)
while not (yield self.has_gap()):
self.generate_next_address()
defer.returnValue(self.addresses[starting_length:])
class Account:
def __init__(self, coin, seed, encrypted, private_key, public_key,
receiving_keys=None, receiving_gap=20,
change_keys=None, change_gap=6):
self.coin = coin # type: BaseCoin
self.seed = seed # type: str
self.encrypted = encrypted # type: bool
self.private_key = private_key # type: PrivateKey
self.public_key = public_key # type: PubKey
self.keychains = (
KeyChain(public_key.child(0), receiving_keys or [], receiving_gap),
KeyChain(public_key.child(1), change_keys or [], change_gap)
)
self.receiving_keys, self.change_keys = self.keychains
@classmethod
def generate(cls, coin, password): # type: (BaseCoin, unicode) -> Account
seed = Mnemonic().make_seed()
return cls.from_seed(coin, seed, password)
@classmethod
def from_seed(cls, coin, seed, password): # type: (BaseCoin, unicode, unicode) -> Account
private_key = cls.get_private_key_from_seed(coin, seed, password)
return cls(
coin=coin, seed=seed, encrypted=False,
private_key=private_key,
public_key=private_key.public_key
)
@staticmethod
def get_private_key_from_seed(coin, seed, password): # type: (BaseCoin, unicode, unicode) -> PrivateKey
return PrivateKey.from_seed(coin, Mnemonic.mnemonic_to_seed(seed, password))
@classmethod
def from_dict(cls, coin, d): # type: (BaseCoin, Dict) -> Account
if not d['encrypted']:
private_key = from_extended_key_string(coin, d['private_key'])
public_key = private_key.public_key
else:
private_key = d['private_key']
public_key = from_extended_key_string(coin, d['public_key'])
return cls(
coin=coin,
seed=d['seed'],
encrypted=d['encrypted'],
private_key=private_key,
public_key=public_key,
receiving_keys=[unhexlify(k) for k in d['receiving_keys']],
receiving_gap=d['receiving_gap'],
change_keys=[unhexlify(k) for k in d['change_keys']],
change_gap=d['change_gap']
)
def to_dict(self):
return {
'coin': self.coin.get_id(),
'seed': self.seed,
'encrypted': self.encrypted,
'private_key': self.private_key if self.encrypted else
self.private_key.extended_key_string().decode(),
'public_key': self.public_key.extended_key_string().decode(),
'receiving_keys': [hexlify(k).decode() for k in self.receiving_keys.child_keys],
'receiving_gap': self.receiving_keys.minimum_gap,
'change_keys': [hexlify(k).decode() for k in self.change_keys.child_keys],
'change_gap': self.change_keys.minimum_gap
}
def decrypt(self, password):
assert self.encrypted, "Key is not encrypted."
secret = double_sha256(password)
self.seed = aes_decrypt(secret, self.seed)
self.private_key = from_extended_key_string(self.coin, aes_decrypt(secret, self.private_key))
self.encrypted = False
def encrypt(self, password):
assert not self.encrypted, "Key is already encrypted."
secret = double_sha256(password)
self.seed = aes_encrypt(secret, self.seed)
self.private_key = aes_encrypt(secret, self.private_key.extended_key_string())
self.encrypted = True
@property
def addresses(self):
return itertools.chain(self.receiving_keys.addresses, self.change_keys.addresses)
def get_private_key_for_address(self, address):
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
for a, keychain in enumerate(self.keychains):
for b, match in enumerate(keychain.addresses):
if address == match:
return self.private_key.child(a).child(b)
@defer.inlineCallbacks
def ensure_enough_addresses(self):
addresses = []
for keychain in self.keychains:
for address in (yield keychain.ensure_enough_addresses()):
addresses.append(address)
defer.returnValue(addresses)
def get_least_used_receiving_address(self, max_transactions=1000):
return self._get_least_used_address(
self.receiving_keys,
max_transactions
)
def get_least_used_change_address(self, max_transactions=100):
return self._get_least_used_address(
self.change_keys,
max_transactions
)
def _get_least_used_address(self, keychain, max_transactions):
ledger = self.coin.ledger
address = ledger.get_least_used_address(self, keychain, max_transactions)
if address:
return address
address = keychain.generate_next_address()
ledger.subscribe_history(address)
return address
@defer.inlineCallbacks
def get_balance(self):
utxos = yield self.coin.ledger.get_unspent_outputs(self)
defer.returnValue(sum(utxo.amount for utxo in utxos))
class AccountsView:
def __init__(self, accounts):
self._accounts_generator = accounts
def __iter__(self): # type: () -> Generator[Account]
return self._accounts_generator()
def addresses(self):
for account in self:
for address in account.addresses:
yield address
def get_account_for_address(self, address):
for account in self:
if address in account.addresses:
return account

184
torba/baseaccount.py Normal file
View file

@ -0,0 +1,184 @@
from typing import Dict
from binascii import unhexlify
from twisted.internet import defer
from torba.mnemonic import Mnemonic
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
from torba.hash import double_sha256, aes_encrypt, aes_decrypt
class KeyChain:
def __init__(self, account, parent_key, chain_number, minimum_usable_addresses):
self.account = account
self.db = account.ledger.db
self.main_key = parent_key.child(chain_number) # type: PubKey
self.chain_number = chain_number
self.minimum_usable_addresses = minimum_usable_addresses
def get_keys(self):
return self.db.get_keys(self.account, self.chain_number)
def get_addresses(self):
return self.db.get_addresses(self.account, self.chain_number)
@defer.inlineCallbacks
def ensure_enough_useable_addresses(self):
usable_address_count = yield self.db.get_usable_address_count(
self.account, self.chain_number
)
if usable_address_count >= self.minimum_usable_addresses:
defer.returnValue([])
new_addresses_needed = self.minimum_usable_addresses - usable_address_count
start = yield self.db.get_last_address_index(
self.account, self.chain_number
)
end = start + new_addresses_needed
new_keys = []
for index in range(start+1, end+1):
new_keys.append((index, self.main_key.child(index)))
yield self.db.add_keys(
self.account, self.chain_number, new_keys
)
defer.returnValue([
key[1].address for key in new_keys
])
@defer.inlineCallbacks
def has_gap(self):
if len(self.addresses) < self.minimum_gap:
defer.returnValue(False)
for address in self.addresses[-self.minimum_gap:]:
if (yield self.ledger.is_address_old(address)):
defer.returnValue(False)
defer.returnValue(True)
class BaseAccount:
mnemonic_class = Mnemonic
private_key_class = PrivateKey
public_key_class = PubKey
def __init__(self, ledger, seed, encrypted, private_key,
public_key, receiving_gap=20, change_gap=6):
self.ledger = ledger # type: baseledger.BaseLedger
self.seed = seed # type: str
self.encrypted = encrypted # type: bool
self.private_key = private_key # type: PrivateKey
self.public_key = public_key # type: PubKey
self.receiving, self.change = self.keychains = (
KeyChain(self, public_key, 0, receiving_gap),
KeyChain(self, public_key, 1, change_gap)
)
ledger.account_created(self)
@classmethod
def generate(cls, ledger, password): # type: (baseledger.BaseLedger, str) -> BaseAccount
seed = cls.mnemonic_class().make_seed()
return cls.from_seed(ledger, seed, password)
@classmethod
def from_seed(cls, ledger, seed, password):
# type: (baseledger.BaseLedger, str, str) -> BaseAccount
private_key = cls.get_private_key_from_seed(ledger, seed, password)
return cls(
ledger=ledger, seed=seed, encrypted=False,
private_key=private_key,
public_key=private_key.public_key
)
@classmethod
def get_private_key_from_seed(cls, ledger, seed, password):
# type: (baseledger.BaseLedger, str, str) -> PrivateKey
return cls.private_key_class.from_seed(
ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password)
)
@classmethod
def from_dict(cls, ledger, d): # type: (baseledger.BaseLedger, Dict) -> BaseAccount
if not d['encrypted']:
private_key = from_extended_key_string(ledger, d['private_key'])
public_key = private_key.public_key
else:
private_key = d['private_key']
public_key = from_extended_key_string(ledger, d['public_key'])
return cls(
ledger=ledger,
seed=d['seed'],
encrypted=d['encrypted'],
private_key=private_key,
public_key=public_key,
receiving_gap=d['receiving_gap'],
change_gap=d['change_gap']
)
def to_dict(self):
return {
'ledger': self.ledger.get_id(),
'seed': self.seed,
'encrypted': self.encrypted,
'private_key': self.private_key if self.encrypted else
self.private_key.extended_key_string(),
'public_key': self.public_key.extended_key_string(),
'receiving_gap': self.receiving.minimum_usable_addresses,
'change_gap': self.change.minimum_usable_addresses,
}
def decrypt(self, password):
assert self.encrypted, "Key is not encrypted."
secret = double_sha256(password)
self.seed = aes_decrypt(secret, self.seed)
self.private_key = from_extended_key_string(self.ledger, aes_decrypt(secret, self.private_key))
self.encrypted = False
def encrypt(self, password):
assert not self.encrypted, "Key is already encrypted."
secret = double_sha256(password)
self.seed = aes_encrypt(secret, self.seed)
self.private_key = aes_encrypt(secret, self.private_key.extended_key_string())
self.encrypted = True
@defer.inlineCallbacks
def ensure_enough_useable_addresses(self):
addresses = []
for keychain in self.keychains:
new_addresses = yield keychain.ensure_enough_useable_addresses()
addresses.extend(new_addresses)
defer.returnValue(addresses)
def get_private_key(self, chain, index):
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
return self.private_key.child(chain).child(index)
def get_least_used_receiving_address(self, max_transactions=1000):
return self._get_least_used_address(
self.receiving_keys,
max_transactions
)
def get_least_used_change_address(self, max_transactions=100):
return self._get_least_used_address(
self.change_keys,
max_transactions
)
def _get_least_used_address(self, keychain, max_transactions):
ledger = self.ledger
address = ledger.get_least_used_address(self, keychain, max_transactions)
if address:
return address
address = keychain.generate_next_address()
ledger.subscribe_history(address)
return address
@defer.inlineCallbacks
def get_balance(self):
utxos = yield self.ledger.get_unspent_outputs(self)
defer.returnValue(sum(utxo.amount for utxo in utxos))

View file

@ -1,70 +0,0 @@
import six
from typing import Dict, Type
from torba.hash import hash160, double_sha256, Base58
class CoinRegistry(type):
coins = {} # type: Dict[str, Type[BaseCoin]]
def __new__(mcs, name, bases, attrs):
cls = super(CoinRegistry, mcs).__new__(mcs, name, bases, attrs) # type: Type[BaseCoin]
if not (name == 'BaseCoin' and not bases):
coin_id = cls.get_id()
assert coin_id not in mcs.coins, 'Coin with id "{}" already registered.'.format(coin_id)
mcs.coins[coin_id] = cls
assert cls.ledger_class.coin_class is None, (
"Ledger ({}) which this coin ({}) references is already referenced by another "
"coin ({}). One to one relationship between a coin and a ledger is strictly and "
"automatically enforced. Make sure that coin_class=None in the ledger and that "
"another Coin isn't already referencing this Ledger."
).format(cls.ledger_class.__name__, name, cls.ledger_class.coin_class.__name__)
# create back reference from ledger to the coin
cls.ledger_class.coin_class = cls
return cls
@classmethod
def get_coin_class(mcs, coin_id): # type: (str) -> Type[BaseCoin]
return mcs.coins[coin_id]
class BaseCoin(six.with_metaclass(CoinRegistry)):
name = None
symbol = None
network = None
ledger_class = None # type: Type[BaseLedger]
transaction_class = None # type: Type[BaseTransaction]
secret_prefix = None
pubkey_address_prefix = None
script_address_prefix = None
extended_public_key_prefix = None
extended_private_key_prefix = None
def __init__(self, ledger, fee_per_byte):
self.ledger = ledger
@classmethod
def get_id(cls):
return '{}_{}'.format(cls.symbol.lower(), cls.network.lower())
def to_dict(self):
return {}
def hash160_to_address(self, h160):
raw_address = self.pubkey_address_prefix + h160
return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4]))
@staticmethod
def address_to_hash160(address):
bytes = Base58.decode(address)
prefix, pubkey_bytes, addr_checksum = bytes[0], bytes[1:21], bytes[21:]
return pubkey_bytes
def public_key_to_address(self, public_key):
return self.hash160_to_address(hash160(public_key))
@staticmethod
def private_key_to_wif(private_key):
return b'\x1c' + private_key + b'\x01'

View file

@ -1,5 +1,4 @@
import logging import logging
import os
import sqlite3 import sqlite3
from twisted.internet import defer from twisted.internet import defer
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
@ -7,50 +6,12 @@ from twisted.enterprise import adbapi
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class BaseSQLiteWalletStorage(object): class SQLiteMixin(object):
CREATE_TX_TABLE = """ CREATE_TABLES_QUERY = None
create table if not exists tx (
txid blob primary key,
raw blob not null,
height integer not null,
is_confirmed boolean not null,
is_verified boolean not null
);
create table if not exists address_status (
address blob not null,
status text not null
);
"""
CREATE_TXO_TABLE = """ def __init__(self, path):
create table if not exists txo ( self._db_path = path
txoid integer primary key,
account blob not null,
address blob not null,
txid blob references tx,
pos integer not null,
amount integer not null,
script blob not null
);
"""
CREATE_TXI_TABLE = """
create table if not exists txi (
account blob not null,
txid blob references tx,
txoid integer references txo
);
"""
CREATE_TABLES_QUERY = (
CREATE_TX_TABLE +
CREATE_TXO_TABLE +
CREATE_TXI_TABLE
)
def __init__(self, ledger):
self._db_path = os.path.join(ledger.path, "blockchain.db")
self.db = None self.db = None
def start(self): def start(self):
@ -66,47 +27,147 @@ class BaseSQLiteWalletStorage(object):
self.db.close() self.db.close()
return defer.succeed(True) return defer.succeed(True)
@defer.inlineCallbacks def _debug_sql(self, sql):
def run_and_return_one_or_none(self, query, *args): """ For use during debugging to execute arbitrary SQL queries without waiting on reactor. """
result = yield self.db.runQuery(query, args) conn = self.db.connectionFactory(self.db)
if result: trans = self.db.transactionFactory(self, conn)
defer.returnValue(result[0][0]) return trans.execute(sql).fetchall()
else:
defer.returnValue(None) def _insert_sql(self, table, data):
columns, values = [], []
for column, value in data.items():
columns.append(column)
values.append(value)
sql = "REPLACE INTO %s (%s) VALUES (%s)".format(
table, ', '.join(columns), ', '.join(['?'] * len(values))
)
return sql, values
@defer.inlineCallbacks @defer.inlineCallbacks
def run_and_return_list(self, query, *args): def query_one_value_list(self, query, params):
result = yield self.db.runQuery(query, args) result = yield self.db.runQuery(query, params)
if result: if result:
defer.returnValue([i[0] for i in result]) defer.returnValue([i[0] for i in result])
else: else:
defer.returnValue([]) defer.returnValue([])
def run_and_return_id(self, query, *args): @defer.inlineCallbacks
def do_save(t): def query_one_value(self, query, params=None, default=None):
t.execute(query, args) result = yield self.db.runQuery(query, params)
return t.lastrowid if result:
return self.db.runInteraction(do_save) defer.returnValue(result[0][0])
else:
def add_transaction(self, tx, height, is_confirmed, is_verified): defer.returnValue(default)
return self.run_and_return_id(
"insert into tx values (?, ?, ?, ?, ?)",
sqlite3.Binary(tx.id),
sqlite3.Binary(tx.raw),
height,
is_confirmed,
is_verified
)
@defer.inlineCallbacks @defer.inlineCallbacks
def has_transaction(self, txid): def query_dict_value_list(self, query, fields, params=None):
result = yield self.db.runQuery( result = yield self.db.runQuery(query.format(', '.join(fields)), params)
"select rowid from tx where txid=?", (txid,) if result:
) defer.returnValue([dict(zip(fields, r)) for r in result])
defer.returnValue(bool(result)) else:
defer.returnValue([])
def add_tx_output(self, account, txo): @defer.inlineCallbacks
return self.db.runOperation( def query_dict_value(self, query, fields, params=None, default=None):
result = yield self.query_dict_value_list(query, fields, params)
if result:
defer.returnValue(result[0])
else:
defer.returnValue(default)
def query_count(self, sql, params):
return self.query_one_value(
"SELECT count(*) FROM ({})".format(sql), params
)
def insert_and_return_id(self, table, data):
def do_insert(t):
t.execute(*self._insert_sql(table, data))
return t.lastrowid
return self.db.runInteraction(do_insert)
class BaseDatabase(SQLiteMixin):
CREATE_TX_TABLE = """
create table if not exists tx (
txid blob primary key,
raw blob not null,
height integer not null,
is_confirmed boolean not null,
is_verified boolean not null
);
"""
CREATE_PUBKEY_ADDRESS_TABLE = """
create table if not exists pubkey_address (
address blob primary key,
account blob not null,
chain integer not null,
position integer not null,
pubkey blob not null,
history text,
used_times integer default 0
);
"""
CREATE_TXO_TABLE = """
create table if not exists txo (
txoid integer primary key,
txid blob references tx,
address blob references pubkey_address,
position integer not null,
amount integer not null,
script blob not null
);
"""
CREATE_TXI_TABLE = """
create table if not exists txi (
txid blob references tx,
address blob references pubkey_address,
txoid integer references txo
);
"""
CREATE_TABLES_QUERY = (
CREATE_TX_TABLE +
CREATE_PUBKEY_ADDRESS_TABLE +
CREATE_TXO_TABLE +
CREATE_TXI_TABLE
)
def get_missing_transactions(self, address, txids):
def _steps(t):
missing = []
chunk_size = 100
for i in range(0, len(txids), chunk_size):
chunk = txids[i:i + chunk_size]
t.execute(
"SELECT 1 FROM tx WHERE txid=?",
(sqlite3.Binary(txid) for txid in chunk)
)
if not t.execute("SELECT 1 FROM tx WHERE txid=?", (sqlite3.Binary(tx.id),)).fetchone():
t.execute(*self._insert_sql('tx', {
'txid': sqlite3.Binary(tx.id),
'raw': sqlite3.Binary(tx.raw),
'height': height,
'is_confirmed': is_confirmed,
'is_verified': is_verified
}))
return self.db.runInteraction(_steps)
def add_transaction(self, address, tx, height, is_confirmed, is_verified):
def _steps(t):
if not t.execute("SELECT 1 FROM tx WHERE txid=?", (sqlite3.Binary(tx.id),)).fetchone():
t.execute(*self._insert_sql('tx', {
'txid': sqlite3.Binary(tx.id),
'raw': sqlite3.Binary(tx.raw),
'height': height,
'is_confirmed': is_confirmed,
'is_verified': is_verified
}))
t.execute(*self._insert_sql(
"insert into txo values (?, ?, ?, ?, ?, ?, ?, ?, ?)", ( "insert into txo values (?, ?, ?, ?, ?, ?, ?, ?, ?)", (
sqlite3.Binary(account.public_key.address), sqlite3.Binary(account.public_key.address),
sqlite3.Binary(txo.script.values['pubkey_hash']), sqlite3.Binary(txo.script.values['pubkey_hash']),
@ -118,10 +179,8 @@ class BaseSQLiteWalletStorage(object):
txo.script.is_support_claim, txo.script.is_support_claim,
txo.script.is_update_claim txo.script.is_update_claim
) )
)
def add_tx_input(self, account, txi): ))
def _ops(t):
txoid = t.execute( txoid = t.execute(
"select rowid from txo where txid=? and pos=?", ( "select rowid from txo where txid=? and pos=?", (
sqlite3.Binary(txi.output_txid), txi.output_index sqlite3.Binary(txi.output_txid), txi.output_index
@ -134,7 +193,15 @@ class BaseSQLiteWalletStorage(object):
txoid txoid
) )
) )
return self.db.runInteraction(_ops)
return self.db.runInteraction(_steps)
@defer.inlineCallbacks
def has_transaction(self, txid):
result = yield self.db.runQuery(
"select rowid from tx where txid=?", (txid,)
)
defer.returnValue(bool(result))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_balance_for_account(self, account): def get_balance_for_account(self, account):
@ -147,41 +214,6 @@ class BaseSQLiteWalletStorage(object):
else: else:
defer.returnValue(0) defer.returnValue(0)
def get_used_addresses(self, account):
return self.db.runQuery(
"""
SELECT
txios.address,
sum(txios.used_count) as total
FROM
(SELECT address, count(*) as used_count FROM txo
WHERE account=:account GROUP BY address
UNION
SELECT address, count(*) as used_count FROM txi NATURAL JOIN txo
WHERE account=:account GROUP BY address) AS txios
GROUP BY txios.address
ORDER BY total
""", {'account': sqlite3.Binary(account.public_key.address)}
)
@defer.inlineCallbacks
def get_earliest_block_height_for_address(self, address):
result = yield self.db.runQuery(
"""
SELECT
height
FROM
(SELECT DISTINCT height FROM txi NATURAL JOIN txo NATURAL JOIN tx WHERE address=:address
UNION
SELECT DISTINCT height FROM txo NATURAL JOIN tx WHERE address=:address) AS txios
ORDER BY height LIMIT 1
""", {'address': sqlite3.Binary(address)}
)
if result:
defer.returnValue(result[0][0])
else:
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_utxos(self, account, output_class): def get_utxos(self, account, output_class):
utxos = yield self.db.runQuery( utxos = yield self.db.runQuery(
@ -203,17 +235,79 @@ class BaseSQLiteWalletStorage(object):
) for values in utxos ) for values in utxos
]) ])
@defer.inlineCallbacks def add_keys(self, account, chain, keys):
def get_address_status(self, address): sql = (
result = yield self.db.runQuery( "insert into pubkey_address "
"select status from address_status where address=?", (address,) "(address, account, chain, position, pubkey) "
"values "
) + ', '.join(['(?, ?, ?, ?, ?)'] * len(keys))
values = []
for position, pubkey in keys:
values.append(sqlite3.Binary(pubkey.address))
values.append(sqlite3.Binary(account.public_key.address))
values.append(chain)
values.append(position)
values.append(sqlite3.Binary(pubkey.pubkey_bytes))
return self.db.runOperation(sql, values)
def get_keys(self, account, chain):
return self.query_one_value_list(
"SELECT pubkey FROM pubkey_address WHERE account = ? AND chain = ?",
(sqlite3.Binary(account.public_key.address), chain)
)
def get_address_details(self, address):
return self.query_dict_value(
"SELECT {} FROM pubkey_address WHERE address = ?",
('account', 'chain', 'position'), (sqlite3.Binary(address),)
)
def get_addresses(self, account, chain):
return self.query_one_value_list(
"SELECT address FROM pubkey_address WHERE account = ? AND chain = ?",
(sqlite3.Binary(account.public_key.address), chain)
)
def get_last_address_index(self, account, chain):
return self.query_one_value(
"""
SELECT position FROM pubkey_address
WHERE account = ? AND chain = ?
ORDER BY position DESC LIMIT 1""",
(sqlite3.Binary(account.public_key.address), chain),
default=0
)
def _usable_address_sql(self, account, chain, exclude_used_times):
return """
SELECT address FROM pubkey_address
WHERE
account = :account AND
chain = :chain AND
used_times <= :exclude_used_times
""", {
'account': sqlite3.Binary(account.public_key.address),
'chain': chain,
'exclude_used_times': exclude_used_times
}
def get_usable_addresses(self, account, chain, exclude_used_times=2):
return self.query_one_value_list(*self._usable_address_sql(
account, chain, exclude_used_times
))
def get_usable_address_count(self, account, chain, exclude_used_times=2):
return self.query_count(*self._usable_address_sql(
account, chain, exclude_used_times
))
def get_address_history(self, address):
return self.query_one_value(
"SELECT history FROM pubkey_address WHERE address = ?", (sqlite3.Binary(address),)
) )
if result:
defer.returnValue(result[0][0])
else:
defer.returnValue(None)
def set_address_status(self, address, status): def set_address_status(self, address, status):
return self.db.runOperation( return self.db.runOperation(
"replace into address_status (address, status) values (?, ?)", (address,status) "replace into address_status (address, status) values (?, ?)", (address,status)
) )

243
torba/baseheader.py Normal file
View file

@ -0,0 +1,243 @@
import os
import struct
from binascii import unhexlify
from twisted.internet import threads, defer
import torba
from torba.stream import StreamController, execute_serially
from torba.util import int_to_hex, rev_hex, hash_encode
from torba.hash import double_sha256, pow_hash
class BaseHeaders:
header_size = 80
verify_bits_to_target = True
def __init__(self, ledger): # type: (baseledger.BaseLedger) -> BaseHeaders
self.ledger = ledger
self._size = None
self._on_change_controller = StreamController()
self.on_changed = self._on_change_controller.stream
@property
def path(self):
return os.path.join(self.ledger.path, 'headers')
def touch(self):
if not os.path.exists(self.path):
with open(self.path, 'wb'):
pass
@property
def height(self):
return len(self) - 1
def sync_read_length(self):
return os.path.getsize(self.path) // self.header_size
def sync_read_header(self, height):
if 0 <= height < len(self):
with open(self.path, 'rb') as f:
f.seek(height * self.header_size)
return f.read(self.header_size)
def __len__(self):
if self._size is None:
self._size = self.sync_read_length()
return self._size
def __getitem__(self, height):
assert not isinstance(height, slice), \
"Slicing of header chain has not been implemented yet."
header = self.sync_read_header(height)
return self._deserialize(height, header)
@execute_serially
@defer.inlineCallbacks
def connect(self, start, headers):
yield threads.deferToThread(self._sync_connect, start, headers)
def _sync_connect(self, start, headers):
previous_header = None
for header in self._iterate_headers(start, headers):
height = header['block_height']
if previous_header is None and height > 0:
previous_header = self[height-1]
self._verify_header(height, header, previous_header)
previous_header = header
with open(self.path, 'r+b') as f:
f.seek(start * self.header_size)
f.write(headers)
f.truncate()
_old_size = self._size
self._size = self.sync_read_length()
change = self._size - _old_size
#log.info('saved {} header blocks'.format(change))
self._on_change_controller.add(change)
def _iterate_headers(self, height, headers):
assert len(headers) % self.header_size == 0
for idx in range(len(headers) // self.header_size):
start, end = idx * self.header_size, (idx + 1) * self.header_size
header = headers[start:end]
yield self._deserialize(height+idx, header)
def _verify_header(self, height, header, previous_header):
previous_hash = self._hash_header(previous_header)
assert previous_hash == header['prev_block_hash'], \
"prev hash mismatch: {} vs {}".format(previous_hash, header['prev_block_hash'])
bits, target = self._calculate_next_work_required(height, previous_header, header)
assert bits == header['bits'], \
"bits mismatch: {} vs {} (hash: {})".format(
bits, header['bits'], self._hash_header(header))
# TODO: FIX ME!!!
#_pow_hash = self._pow_hash_header(header)
#assert int(b'0x' + _pow_hash, 16) <= target, \
# "insufficient proof of work: {} vs target {}".format(
# int(b'0x' + _pow_hash, 16), target)
@staticmethod
def _serialize(header):
return b''.join([
int_to_hex(header['version'], 4),
rev_hex(header['prev_block_hash']),
rev_hex(header['merkle_root']),
int_to_hex(int(header['timestamp']), 4),
int_to_hex(int(header['bits']), 4),
int_to_hex(int(header['nonce']), 4)
])
@staticmethod
def _deserialize(height, header):
version, = struct.unpack('<I', header[:4])
timestamp, bits, nonce = struct.unpack('<III', header[68:80])
return {
'block_height': height,
'version': version,
'prev_block_hash': hash_encode(header[4:36]),
'merkle_root': hash_encode(header[36:68]),
'timestamp': timestamp,
'bits': bits,
'nonce': nonce,
}
def _hash_header(self, header):
if header is None:
return b'0' * 64
return hash_encode(double_sha256(unhexlify(self._serialize(header))))
def _pow_hash_header(self, header):
if header is None:
return b'0' * 64
return hash_encode(pow_hash(unhexlify(self._serialize(header))))
def _calculate_next_work_required(self, height, first, last):
if height == 0:
return self.ledger.genesis_bits, self.ledger.max_target
if self.verify_bits_to_target:
bits = last['bits']
bitsN = (bits >> 24) & 0xff
assert 0x03 <= bitsN <= 0x1d, \
"First part of bits should be in [0x03, 0x1d], but it was {}".format(hex(bitsN))
bitsBase = bits & 0xffffff
assert 0x8000 <= bitsBase <= 0x7fffff, \
"Second part of bits should be in [0x8000, 0x7fffff] but it was {}".format(bitsBase)
# new target
retargetTimespan = self.ledger.target_timespan
nActualTimespan = last['timestamp'] - first['timestamp']
nModulatedTimespan = retargetTimespan + (nActualTimespan - retargetTimespan) // 8
nMinTimespan = retargetTimespan - (retargetTimespan // 8)
nMaxTimespan = retargetTimespan + (retargetTimespan // 2)
# Limit adjustment step
if nModulatedTimespan < nMinTimespan:
nModulatedTimespan = nMinTimespan
elif nModulatedTimespan > nMaxTimespan:
nModulatedTimespan = nMaxTimespan
# Retarget
bnPowLimit = _ArithUint256(self.ledger.max_target)
bnNew = _ArithUint256.SetCompact(last['bits'])
bnNew *= nModulatedTimespan
bnNew //= nModulatedTimespan
if bnNew > bnPowLimit:
bnNew = bnPowLimit
return bnNew.GetCompact(), bnNew._value
class _ArithUint256:
""" See: lbrycrd/src/arith_uint256.cpp """
def __init__(self, value):
self._value = value
def __str__(self):
return hex(self._value)
@staticmethod
def fromCompact(nCompact):
"""Convert a compact representation into its value"""
nSize = nCompact >> 24
# the lower 23 bits
nWord = nCompact & 0x007fffff
if nSize <= 3:
return nWord >> 8 * (3 - nSize)
else:
return nWord << 8 * (nSize - 3)
@classmethod
def SetCompact(cls, nCompact):
return cls(cls.fromCompact(nCompact))
def bits(self):
"""Returns the position of the highest bit set plus one."""
bn = bin(self._value)[2:]
for i, d in enumerate(bn):
if d:
return (len(bn) - i) + 1
return 0
def GetLow64(self):
return self._value & 0xffffffffffffffff
def GetCompact(self):
"""Convert a value into its compact representation"""
nSize = (self.bits() + 7) // 8
nCompact = 0
if nSize <= 3:
nCompact = self.GetLow64() << 8 * (3 - nSize)
else:
bn = _ArithUint256(self._value >> 8 * (nSize - 3))
nCompact = bn.GetLow64()
# The 0x00800000 bit denotes the sign.
# Thus, if it is already set, divide the mantissa by 256 and increase the exponent.
if nCompact & 0x00800000:
nCompact >>= 8
nSize += 1
assert (nCompact & ~0x007fffff) == 0
assert nSize < 256
nCompact |= nSize << 24
return nCompact
def __mul__(self, x):
# Take the mod because we are limited to an unsigned 256 bit number
return _ArithUint256((self._value * x) % 2 ** 256)
def __ifloordiv__(self, x):
self._value = (self._value // x)
return self
def __gt__(self, x):
return self._value > x._value

View file

@ -1,71 +1,100 @@
import os import os
import six
import hashlib import hashlib
import struct
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from typing import List, Dict, Type from typing import Dict, Type
from operator import itemgetter from operator import itemgetter
from twisted.internet import threads, defer, task, reactor from twisted.internet import defer
from torba import basetransaction, basedatabase from torba import baseaccount
from torba.account import Account, AccountsView from torba import basedatabase
from torba.basecoin import BaseCoin from torba import baseheader
from torba.basenetwork import BaseNetwork from torba import basenetwork
from torba import basetransaction
from torba.stream import StreamController, execute_serially from torba.stream import StreamController, execute_serially
from torba.util import int_to_hex, rev_hex, hash_encode from torba.hash import hash160, double_sha256, Base58
from torba.hash import double_sha256, pow_hash
class Address: class LedgerRegistry(type):
ledgers = {} # type: Dict[str, Type[BaseLedger]]
def __init__(self, pubkey_hash): def __new__(mcs, name, bases, attrs):
self.pubkey_hash = pubkey_hash cls = super(LedgerRegistry, mcs).__new__(mcs, name, bases, attrs) # type: Type[BaseLedger]
self.transactions = [] # type: List[BaseTransaction] if not (name == 'BaseLedger' and not bases):
ledger_id = cls.get_id()
assert ledger_id not in mcs.ledgers,\
'Ledger with id "{}" already registered.'.format(ledger_id)
mcs.ledgers[ledger_id] = cls
return cls
def __iter__(self): @classmethod
return iter(self.transactions) def get_ledger_class(mcs, ledger_id): # type: (str) -> Type[BaseLedger]
return mcs.ledgers[ledger_id]
def __len__(self):
return len(self.transactions)
def add_transaction(self, transaction):
if transaction not in self.transactions:
self.transactions.append(transaction)
class BaseLedger(object): class BaseLedger(six.with_metaclass(LedgerRegistry)):
# coin_class is automatically set by BaseCoin metaclass name = None
# when it creates the Coin classes, there is a 1..1 relationship symbol = None
# between a coin and a ledger (at the class level) but a 1..* relationship network_name = None
# at instance level. Only one Ledger instance should exist per coin class,
# but many coin instances can exist linking back to the single Ledger instance. account_class = baseaccount.BaseAccount
coin_class = None # type: Type[BaseCoin] database_class = basedatabase.BaseDatabase
network_class = None # type: Type[BaseNetwork] headers_class = baseheader.BaseHeaders
headers_class = None # type: Type[BaseHeaders] network_class = basenetwork.BaseNetwork
database_class = None # type: Type[basedatabase.BaseSQLiteWalletStorage] transaction_class = basetransaction.BaseTransaction
secret_prefix = None
pubkey_address_prefix = None
script_address_prefix = None
extended_public_key_prefix = None
extended_private_key_prefix = None
default_fee_per_byte = 10 default_fee_per_byte = 10
def __init__(self, accounts, config=None, db=None, network=None, def __init__(self, config=None, db=None, network=None):
fee_per_byte=default_fee_per_byte):
self.accounts = accounts # type: AccountsView
self.config = config or {} self.config = config or {}
self.db = db or self.database_class(self) # type: basedatabase.BaseSQLiteWalletStorage self.db = self.database_class(
db or os.path.join(self.path, "blockchain.db")
) # type: basedatabase.BaseSQLiteWalletStorage
self.network = network or self.network_class(self) self.network = network or self.network_class(self)
self.network.on_header.listen(self.process_header) self.network.on_header.listen(self.process_header)
self.network.on_status.listen(self.process_status) self.network.on_status.listen(self.process_status)
self.accounts = set()
self.headers = self.headers_class(self) self.headers = self.headers_class(self)
self.fee_per_byte = fee_per_byte self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte)
self._on_transaction_controller = StreamController() self._on_transaction_controller = StreamController()
self.on_transaction = self._on_transaction_controller.stream self.on_transaction = self._on_transaction_controller.stream
@classmethod
def get_id(cls):
return '{}_{}'.format(cls.symbol.lower(), cls.network_name.lower())
def hash160_to_address(self, h160):
raw_address = self.pubkey_address_prefix + h160
return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4]))
def account_created(self, account):
self.accounts.add(account)
@staticmethod
def address_to_hash160(address):
bytes = Base58.decode(address)
prefix, pubkey_bytes, addr_checksum = bytes[0], bytes[1:21], bytes[21:]
return pubkey_bytes
def public_key_to_address(self, public_key):
return self.hash160_to_address(hash160(public_key))
@staticmethod
def private_key_to_wif(private_key):
return b'\x1c' + private_key + b'\x01'
@property @property
def path(self): def path(self):
return os.path.join( return os.path.join(self.config['path'], self.get_id())
self.config['wallet_path'], self.coin_class.get_id()
)
def get_input_output_fee(self, io): def get_input_output_fee(self, io):
""" Fee based on size of the input / output. """ """ Fee based on size of the input / output. """
@ -75,21 +104,8 @@ class BaseLedger(object):
""" Fee for the transaction header and all outputs; without inputs. """ """ Fee for the transaction header and all outputs; without inputs. """
return self.fee_per_byte * tx.base_size return self.fee_per_byte * tx.base_size
@property def get_keys(self, account, chain):
def transaction_class(self): return self.db.get_keys(account, chain)
return self.coin_class.transaction_class
@classmethod
def from_json(cls, json_dict):
return cls(json_dict)
@defer.inlineCallbacks
def is_address_old(self, address, age_limit=2):
height = yield self.db.get_earliest_block_height_for_address(address)
if height is None:
return False
age = self.headers.height - height + 1
return age > age_limit
@defer.inlineCallbacks @defer.inlineCallbacks
def add_transaction(self, transaction, height): # type: (basetransaction.BaseTransaction, int) -> None def add_transaction(self, transaction, height): # type: (basetransaction.BaseTransaction, int) -> None
@ -108,6 +124,14 @@ class BaseLedger(object):
if used_addresses and used_addresses[0][1] < max_transactions: if used_addresses and used_addresses[0][1] < max_transactions:
defer.returnValue(used_addresses[0][0]) defer.returnValue(used_addresses[0][0])
@defer.inlineCallbacks
def get_private_key_for_address(self, address):
match = yield self.db.get_address_details(address)
if match:
for account in self.accounts:
if bytes(match['account']) == account.public_key.address:
defer.returnValue(account.get_private_key(match['chain'], match['position']))
def get_unspent_outputs(self, account): def get_unspent_outputs(self, account):
return self.db.get_utxos(account, self.transaction_class.output_class) return self.db.get_utxos(account, self.transaction_class.output_class)
@ -177,8 +201,7 @@ class BaseLedger(object):
# need to update anyways. Continue to get history and create more addresses until # need to update anyways. Continue to get history and create more addresses until
# all missing addresses are created and history for them is fully restored. # all missing addresses are created and history for them is fully restored.
yield account.ensure_enough_addresses() yield account.ensure_enough_addresses()
used_addresses = yield self.db.get_used_addresses(account) addresses = yield account.get_unused_addresses(account)
addresses = set(account.addresses) - set(map(itemgetter(0), used_addresses))
while addresses: while addresses:
yield defer.DeferredList([ yield defer.DeferredList([
self.update_history(a) for a in addresses self.update_history(a) for a in addresses
@ -203,7 +226,9 @@ class BaseLedger(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def update_history(self, address, remote_status=None): def update_history(self, address, remote_status=None):
history = yield self.network.get_history(address) history = yield self.network.get_history(address)
hashes = list(map(itemgetter('tx_hash'), history))
for hash, height in map(itemgetter('tx_hash', 'height'), history): for hash, height in map(itemgetter('tx_hash', 'height'), history):
if not (yield self.db.has_transaction(hash)): if not (yield self.db.has_transaction(hash)):
raw = yield self.network.get_transaction(hash) raw = yield self.network.get_transaction(hash)
transaction = self.transaction_class(unhexlify(raw)) transaction = self.transaction_class(unhexlify(raw))
@ -229,236 +254,3 @@ class BaseLedger(object):
def broadcast(self, tx): def broadcast(self, tx):
return self.network.broadcast(hexlify(tx.raw)) return self.network.broadcast(hexlify(tx.raw))
class BaseHeaders:
header_size = 80
verify_bits_to_target = True
def __init__(self, ledger):
self.ledger = ledger
self._size = None
self._on_change_controller = StreamController()
self.on_changed = self._on_change_controller.stream
@property
def path(self):
return os.path.join(self.ledger.path, 'headers')
def touch(self):
if not os.path.exists(self.path):
with open(self.path, 'wb'):
pass
@property
def height(self):
return len(self) - 1
def sync_read_length(self):
return os.path.getsize(self.path) // self.header_size
def sync_read_header(self, height):
if 0 <= height < len(self):
with open(self.path, 'rb') as f:
f.seek(height * self.header_size)
return f.read(self.header_size)
def __len__(self):
if self._size is None:
self._size = self.sync_read_length()
return self._size
def __getitem__(self, height):
assert not isinstance(height, slice),\
"Slicing of header chain has not been implemented yet."
header = self.sync_read_header(height)
return self._deserialize(height, header)
@execute_serially
@defer.inlineCallbacks
def connect(self, start, headers):
yield threads.deferToThread(self._sync_connect, start, headers)
def _sync_connect(self, start, headers):
previous_header = None
for header in self._iterate_headers(start, headers):
height = header['block_height']
if previous_header is None and height > 0:
previous_header = self[height-1]
self._verify_header(height, header, previous_header)
previous_header = header
with open(self.path, 'r+b') as f:
f.seek(start * self.header_size)
f.write(headers)
f.truncate()
_old_size = self._size
self._size = self.sync_read_length()
change = self._size - _old_size
#log.info('saved {} header blocks'.format(change))
self._on_change_controller.add(change)
def _iterate_headers(self, height, headers):
assert len(headers) % self.header_size == 0
for idx in range(len(headers) // self.header_size):
start, end = idx * self.header_size, (idx + 1) * self.header_size
header = headers[start:end]
yield self._deserialize(height+idx, header)
def _verify_header(self, height, header, previous_header):
previous_hash = self._hash_header(previous_header)
assert previous_hash == header['prev_block_hash'], \
"prev hash mismatch: {} vs {}".format(previous_hash, header['prev_block_hash'])
bits, target = self._calculate_next_work_required(height, previous_header, header)
assert bits == header['bits'], \
"bits mismatch: {} vs {} (hash: {})".format(
bits, header['bits'], self._hash_header(header))
# TODO: FIX ME!!!
#_pow_hash = self._pow_hash_header(header)
#assert int(b'0x' + _pow_hash, 16) <= target, \
# "insufficient proof of work: {} vs target {}".format(
# int(b'0x' + _pow_hash, 16), target)
@staticmethod
def _serialize(header):
return b''.join([
int_to_hex(header['version'], 4),
rev_hex(header['prev_block_hash']),
rev_hex(header['merkle_root']),
int_to_hex(int(header['timestamp']), 4),
int_to_hex(int(header['bits']), 4),
int_to_hex(int(header['nonce']), 4)
])
@staticmethod
def _deserialize(height, header):
version, = struct.unpack('<I', header[:4])
timestamp, bits, nonce = struct.unpack('<III', header[68:80])
return {
'block_height': height,
'version': version,
'prev_block_hash': hash_encode(header[4:36]),
'merkle_root': hash_encode(header[36:68]),
'timestamp': timestamp,
'bits': bits,
'nonce': nonce,
}
def _hash_header(self, header):
if header is None:
return b'0' * 64
return hash_encode(double_sha256(unhexlify(self._serialize(header))))
def _pow_hash_header(self, header):
if header is None:
return b'0' * 64
return hash_encode(pow_hash(unhexlify(self._serialize(header))))
def _calculate_next_work_required(self, height, first, last):
if height == 0:
return self.ledger.genesis_bits, self.ledger.max_target
if self.verify_bits_to_target:
bits = last['bits']
bitsN = (bits >> 24) & 0xff
assert 0x03 <= bitsN <= 0x1d, \
"First part of bits should be in [0x03, 0x1d], but it was {}".format(hex(bitsN))
bitsBase = bits & 0xffffff
assert 0x8000 <= bitsBase <= 0x7fffff, \
"Second part of bits should be in [0x8000, 0x7fffff] but it was {}".format(bitsBase)
# new target
retargetTimespan = self.ledger.target_timespan
nActualTimespan = last['timestamp'] - first['timestamp']
nModulatedTimespan = retargetTimespan + (nActualTimespan - retargetTimespan) // 8
nMinTimespan = retargetTimespan - (retargetTimespan // 8)
nMaxTimespan = retargetTimespan + (retargetTimespan // 2)
# Limit adjustment step
if nModulatedTimespan < nMinTimespan:
nModulatedTimespan = nMinTimespan
elif nModulatedTimespan > nMaxTimespan:
nModulatedTimespan = nMaxTimespan
# Retarget
bnPowLimit = _ArithUint256(self.ledger.max_target)
bnNew = _ArithUint256.SetCompact(last['bits'])
bnNew *= nModulatedTimespan
bnNew //= nModulatedTimespan
if bnNew > bnPowLimit:
bnNew = bnPowLimit
return bnNew.GetCompact(), bnNew._value
class _ArithUint256:
""" See: lbrycrd/src/arith_uint256.cpp """
def __init__(self, value):
self._value = value
def __str__(self):
return hex(self._value)
@staticmethod
def fromCompact(nCompact):
"""Convert a compact representation into its value"""
nSize = nCompact >> 24
# the lower 23 bits
nWord = nCompact & 0x007fffff
if nSize <= 3:
return nWord >> 8 * (3 - nSize)
else:
return nWord << 8 * (nSize - 3)
@classmethod
def SetCompact(cls, nCompact):
return cls(cls.fromCompact(nCompact))
def bits(self):
"""Returns the position of the highest bit set plus one."""
bn = bin(self._value)[2:]
for i, d in enumerate(bn):
if d:
return (len(bn) - i) + 1
return 0
def GetLow64(self):
return self._value & 0xffffffffffffffff
def GetCompact(self):
"""Convert a value into its compact representation"""
nSize = (self.bits() + 7) // 8
nCompact = 0
if nSize <= 3:
nCompact = self.GetLow64() << 8 * (3 - nSize)
else:
bn = _ArithUint256(self._value >> 8 * (nSize - 3))
nCompact = bn.GetLow64()
# The 0x00800000 bit denotes the sign.
# Thus, if it is already set, divide the mantissa by 256 and increase the exponent.
if nCompact & 0x00800000:
nCompact >>= 8
nSize += 1
assert (nCompact & ~0x007fffff) == 0
assert nSize < 256
nCompact |= nSize << 24
return nCompact
def __mul__(self, x):
# Take the mod because we are limited to an unsigned 256 bit number
return _ArithUint256((self._value * x) % 2 ** 256)
def __ifloordiv__(self, x):
self._value = (self._value // x)
return self
def __gt__(self, x):
return self._value > x._value

View file

@ -3,13 +3,12 @@ import logging
from typing import List, Iterable, Generator from typing import List, Iterable, Generator
from binascii import hexlify from binascii import hexlify
from torba import baseledger
from torba.basescript import BaseInputScript, BaseOutputScript from torba.basescript import BaseInputScript, BaseOutputScript
from torba.coinselection import CoinSelector from torba.coinselection import CoinSelector
from torba.constants import COIN from torba.constants import COIN
from torba.bcd_data_stream import BCDataStream from torba.bcd_data_stream import BCDataStream
from torba.hash import sha256 from torba.hash import sha256
from torba.account import Account from torba.baseaccount import BaseAccount
from torba.util import ReadOnlyList from torba.util import ReadOnlyList
@ -45,7 +44,7 @@ class InputOutput(object):
class BaseInput(InputOutput): class BaseInput(InputOutput):
script_class = None script_class = BaseInputScript
NULL_SIGNATURE = b'\x00'*72 NULL_SIGNATURE = b'\x00'*72
NULL_PUBLIC_KEY = b'\x00'*33 NULL_PUBLIC_KEY = b'\x00'*33
@ -113,7 +112,7 @@ class BaseOutputEffectiveAmountEstimator(object):
__slots__ = 'coin', 'txi', 'txo', 'fee', 'effective_amount' __slots__ = 'coin', 'txi', 'txo', 'fee', 'effective_amount'
def __init__(self, ledger, txo): # type: (baseledger.BaseLedger, BaseOutput) -> None def __init__(self, ledger, txo): # type: (BaseLedger, BaseOutput) -> None
self.txo = txo self.txo = txo
self.txi = ledger.transaction_class.input_class.spend(txo) self.txi = ledger.transaction_class.input_class.spend(txo)
self.fee = ledger.get_input_output_fee(self.txi) self.fee = ledger.get_input_output_fee(self.txi)
@ -125,7 +124,7 @@ class BaseOutputEffectiveAmountEstimator(object):
class BaseOutput(InputOutput): class BaseOutput(InputOutput):
script_class = None script_class = BaseOutputScript
estimator_class = BaseOutputEffectiveAmountEstimator estimator_class = BaseOutputEffectiveAmountEstimator
def __init__(self, amount, script, txid=None): def __init__(self, amount, script, txid=None):
@ -154,8 +153,8 @@ class BaseOutput(InputOutput):
class BaseTransaction: class BaseTransaction:
input_class = None input_class = BaseInput
output_class = None output_class = BaseOutput
def __init__(self, raw=None, version=1, locktime=0): def __init__(self, raw=None, version=1, locktime=0):
self._raw = raw self._raw = raw
@ -277,23 +276,23 @@ class BaseTransaction:
@classmethod @classmethod
def get_effective_amount_estimators(cls, funding_accounts): def get_effective_amount_estimators(cls, funding_accounts):
# type: (Iterable[Account]) -> Generator[BaseOutputEffectiveAmountEstimator] # type: (Iterable[BaseAccount]) -> Generator[BaseOutputEffectiveAmountEstimator]
for account in funding_accounts: for account in funding_accounts:
for utxo in account.coin.ledger.get_unspent_outputs(account): for utxo in account.coin.ledger.get_unspent_outputs(account):
yield utxo.get_estimator(account.coin) yield utxo.get_estimator(account.coin)
@classmethod @classmethod
def ensure_all_have_same_ledger(cls, funding_accounts, change_account=None): def ensure_all_have_same_ledger(cls, funding_accounts, change_account=None):
# type: (Iterable[Account], Account) -> baseledger.BaseLedger # type: (Iterable[BaseAccount], BaseAccount) -> baseledger.BaseLedger
ledger = None ledger = None
for account in funding_accounts: for account in funding_accounts:
if ledger is None: if ledger is None:
ledger = account.coin.ledger ledger = account.ledger
if ledger != account.coin.ledger: if ledger != account.ledger:
raise ValueError( raise ValueError(
'All funding accounts used to create a transaction must be on the same ledger.' 'All funding accounts used to create a transaction must be on the same ledger.'
) )
if change_account is not None and change_account.coin.ledger != ledger: if change_account is not None and change_account.ledger != ledger:
raise ValueError('Change account must use same ledger as funding accounts.') raise ValueError('Change account must use same ledger as funding accounts.')
return ledger return ledger
@ -331,14 +330,13 @@ class BaseTransaction:
def liquidate(cls, assets, funding_accounts, change_account): def liquidate(cls, assets, funding_accounts, change_account):
""" Spend assets (utxos) supplementing with funding_accounts if fee is higher than asset value. """ """ Spend assets (utxos) supplementing with funding_accounts if fee is higher than asset value. """
def sign(self, funding_accounts): # type: (Iterable[Account]) -> BaseTransaction def sign(self, funding_accounts): # type: (Iterable[BaseAccount]) -> BaseTransaction
ledger = self.ensure_all_have_same_ledger(funding_accounts) ledger = self.ensure_all_have_same_ledger(funding_accounts)
for i, txi in enumerate(self._inputs): for i, txi in enumerate(self._inputs):
txo_script = txi.output.script txo_script = txi.output.script
if txo_script.is_pay_pubkey_hash: if txo_script.is_pay_pubkey_hash:
address = ledger.coin_class.hash160_to_address(txo_script.values['pubkey_hash']) address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
account = ledger.accounts.get_account_for_address(address) private_key = ledger.get_private_key_for_address(address)
private_key = account.get_private_key_for_address(address)
tx = self._serialize_for_signature(i) tx = self._serialize_for_signature(i)
txi.script.values['signature'] = private_key.sign(tx)+six.int2byte(1) txi.script.values['signature'] = private_key.sign(tx)+six.int2byte(1)
txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes

View file

@ -16,7 +16,6 @@ import ecdsa
import ecdsa.ellipticcurve as EC import ecdsa.ellipticcurve as EC
import ecdsa.numbertheory as NT import ecdsa.numbertheory as NT
from torba.basecoin import BaseCoin
from torba.hash import Base58, hmac_sha512, hash160, double_sha256 from torba.hash import Base58, hmac_sha512, hash160, double_sha256
from torba.util import cachedproperty, bytes_to_int, int_to_bytes from torba.util import cachedproperty, bytes_to_int, int_to_bytes
@ -30,9 +29,7 @@ class _KeyBase(object):
CURVE = ecdsa.SECP256k1 CURVE = ecdsa.SECP256k1
def __init__(self, coin, chain_code, n, depth, parent): def __init__(self, ledger, chain_code, n, depth, parent):
if not isinstance(coin, BaseCoin):
raise TypeError('invalid coin')
if not isinstance(chain_code, (bytes, bytearray)): if not isinstance(chain_code, (bytes, bytearray)):
raise TypeError('chain code must be raw bytes') raise TypeError('chain code must be raw bytes')
if len(chain_code) != 32: if len(chain_code) != 32:
@ -44,7 +41,7 @@ class _KeyBase(object):
if parent is not None: if parent is not None:
if not isinstance(parent, type(self)): if not isinstance(parent, type(self)):
raise TypeError('parent key has bad type') raise TypeError('parent key has bad type')
self.coin = coin self.ledger = ledger
self.chain_code = chain_code self.chain_code = chain_code
self.n = n self.n = n
self.depth = depth self.depth = depth
@ -86,8 +83,8 @@ class _KeyBase(object):
class PubKey(_KeyBase): class PubKey(_KeyBase):
""" A BIP32 public key. """ """ A BIP32 public key. """
def __init__(self, coin, pubkey, chain_code, n, depth, parent=None): def __init__(self, ledger, pubkey, chain_code, n, depth, parent=None):
super(PubKey, self).__init__(coin, chain_code, n, depth, parent) super(PubKey, self).__init__(ledger, chain_code, n, depth, parent)
if isinstance(pubkey, ecdsa.VerifyingKey): if isinstance(pubkey, ecdsa.VerifyingKey):
self.verifying_key = pubkey self.verifying_key = pubkey
else: else:
@ -129,7 +126,7 @@ class PubKey(_KeyBase):
@cachedproperty @cachedproperty
def address(self): def address(self):
""" The public key as a P2PKH address. """ """ The public key as a P2PKH address. """
return self.coin.public_key_to_address(self.pubkey_bytes) return self.ledger.public_key_to_address(self.pubkey_bytes)
def ec_point(self): def ec_point(self):
return self.verifying_key.pubkey.point return self.verifying_key.pubkey.point
@ -153,7 +150,7 @@ class PubKey(_KeyBase):
verkey = ecdsa.VerifyingKey.from_public_point(point, curve=curve) verkey = ecdsa.VerifyingKey.from_public_point(point, curve=curve)
return PubKey(self.coin, verkey, R, n, self.depth + 1, self) return PubKey(self.ledger, verkey, R, n, self.depth + 1, self)
def identifier(self): def identifier(self):
""" Return the key's identifier as 20 bytes. """ """ Return the key's identifier as 20 bytes. """
@ -162,7 +159,7 @@ class PubKey(_KeyBase):
def extended_key(self): def extended_key(self):
""" Return a raw extended public key. """ """ Return a raw extended public key. """
return self._extended_key( return self._extended_key(
self.coin.extended_public_key_prefix, self.ledger.extended_public_key_prefix,
self.pubkey_bytes self.pubkey_bytes
) )
@ -186,8 +183,8 @@ class PrivateKey(_KeyBase):
HARDENED = 1 << 31 HARDENED = 1 << 31
def __init__(self, coin, privkey, chain_code, n, depth, parent=None): def __init__(self, ledger, privkey, chain_code, n, depth, parent=None):
super(PrivateKey, self).__init__(coin, chain_code, n, depth, parent) super(PrivateKey, self).__init__(ledger, chain_code, n, depth, parent)
if isinstance(privkey, ecdsa.SigningKey): if isinstance(privkey, ecdsa.SigningKey):
self.signing_key = privkey self.signing_key = privkey
else: else:
@ -212,11 +209,11 @@ class PrivateKey(_KeyBase):
return exponent return exponent
@classmethod @classmethod
def from_seed(cls, coin, seed): def from_seed(cls, ledger, seed):
# This hard-coded message string seems to be coin-independent... # This hard-coded message string seems to be coin-independent...
hmac = hmac_sha512(b'Bitcoin seed', seed) hmac = hmac_sha512(b'Bitcoin seed', seed)
privkey, chain_code = hmac[:32], hmac[32:] privkey, chain_code = hmac[:32], hmac[32:]
return cls(coin, privkey, chain_code, 0, 0) return cls(ledger, privkey, chain_code, 0, 0)
@cachedproperty @cachedproperty
def private_key_bytes(self): def private_key_bytes(self):
@ -228,7 +225,7 @@ class PrivateKey(_KeyBase):
""" Return the corresponding extended public key. """ """ Return the corresponding extended public key. """
verifying_key = self.signing_key.get_verifying_key() verifying_key = self.signing_key.get_verifying_key()
parent_pubkey = self.parent.public_key if self.parent else None parent_pubkey = self.parent.public_key if self.parent else None
return PubKey(self.coin, verifying_key, self.chain_code, self.n, self.depth, return PubKey(self.ledger, verifying_key, self.chain_code, self.n, self.depth,
parent_pubkey) parent_pubkey)
def ec_point(self): def ec_point(self):
@ -240,7 +237,7 @@ class PrivateKey(_KeyBase):
def wif(self): def wif(self):
""" Return the private key encoded in Wallet Import Format. """ """ Return the private key encoded in Wallet Import Format. """
return self.coin.private_key_to_wif(self.private_key_bytes) return self.ledger.private_key_to_wif(self.private_key_bytes)
def address(self): def address(self):
""" The public key as a P2PKH address. """ """ The public key as a P2PKH address. """
@ -267,7 +264,7 @@ class PrivateKey(_KeyBase):
privkey = _exponent_to_bytes(exponent) privkey = _exponent_to_bytes(exponent)
return PrivateKey(self.coin, privkey, R, n, self.depth + 1, self) return PrivateKey(self.ledger, privkey, R, n, self.depth + 1, self)
def sign(self, data): def sign(self, data):
""" Produce a signature for piece of data by double hashing it and signing the hash. """ """ Produce a signature for piece of data by double hashing it and signing the hash. """
@ -282,7 +279,7 @@ class PrivateKey(_KeyBase):
def extended_key(self): def extended_key(self):
"""Return a raw extended private key.""" """Return a raw extended private key."""
return self._extended_key( return self._extended_key(
self.coin.extended_private_key_prefix, self.ledger.extended_private_key_prefix,
b'\0' + self.private_key_bytes b'\0' + self.private_key_bytes
) )
@ -292,7 +289,7 @@ def _exponent_to_bytes(exponent):
return (int2byte(0)*32 + int_to_bytes(exponent))[-32:] return (int2byte(0)*32 + int_to_bytes(exponent))[-32:]
def _from_extended_key(coin, ekey): def _from_extended_key(ledger, ekey):
"""Return a PubKey or PrivateKey from an extended key raw bytes.""" """Return a PubKey or PrivateKey from an extended key raw bytes."""
if not isinstance(ekey, (bytes, bytearray)): if not isinstance(ekey, (bytes, bytearray)):
raise TypeError('extended key must be raw bytes') raise TypeError('extended key must be raw bytes')
@ -304,21 +301,21 @@ def _from_extended_key(coin, ekey):
n, = struct.unpack('>I', ekey[9:13]) n, = struct.unpack('>I', ekey[9:13])
chain_code = ekey[13:45] chain_code = ekey[13:45]
if ekey[:4] == coin.extended_public_key_prefix: if ekey[:4] == ledger.extended_public_key_prefix:
pubkey = ekey[45:] pubkey = ekey[45:]
key = PubKey(coin, pubkey, chain_code, n, depth) key = PubKey(ledger, pubkey, chain_code, n, depth)
elif ekey[:4] == coin.extended_private_key_prefix: elif ekey[:4] == ledger.extended_private_key_prefix:
if indexbytes(ekey, 45) != 0: if indexbytes(ekey, 45) != 0:
raise ValueError('invalid extended private key prefix byte') raise ValueError('invalid extended private key prefix byte')
privkey = ekey[46:] privkey = ekey[46:]
key = PrivateKey(coin, privkey, chain_code, n, depth) key = PrivateKey(ledger, privkey, chain_code, n, depth)
else: else:
raise ValueError('version bytes unrecognised') raise ValueError('version bytes unrecognised')
return key return key
def from_extended_key_string(coin, ekey_str): def from_extended_key_string(ledger, ekey_str):
"""Given an extended key string, such as """Given an extended key string, such as
xpub6BsnM1W2Y7qLMiuhi7f7dbAwQZ5Cz5gYJCRzTNainXzQXYjFwtuQXHd xpub6BsnM1W2Y7qLMiuhi7f7dbAwQZ5Cz5gYJCRzTNainXzQXYjFwtuQXHd
@ -326,4 +323,4 @@ def from_extended_key_string(coin, ekey_str):
return a PubKey or PrivateKey. return a PubKey or PrivateKey.
""" """
return _from_extended_key(coin, Base58.decode_check(ekey_str)) return _from_extended_key(ledger, Base58.decode_check(ekey_str))

View file

@ -13,7 +13,7 @@ from torba.basescript import BaseInputScript, BaseOutputScript
from torba.basetransaction import BaseTransaction, BaseInput, BaseOutput from torba.basetransaction import BaseTransaction, BaseInput, BaseOutput
from torba.basecoin import BaseCoin from torba.basecoin import BaseCoin
from torba.basedatabase import BaseSQLiteWalletStorage from torba.basedatabase import BaseSQLiteWalletStorage
from torba.basemanager import BaseWalletManager from torba.manager import BaseWalletManager
class WalletManager(BaseWalletManager): class WalletManager(BaseWalletManager):

View file

@ -7,65 +7,14 @@ __node_url__ = (
from six import int2byte from six import int2byte
from binascii import unhexlify from binascii import unhexlify
from torba.baseledger import BaseLedger, BaseHeaders from torba.baseledger import BaseLedger
from torba.basenetwork import BaseNetwork from torba.baseheader import BaseHeaders
from torba.basescript import BaseInputScript, BaseOutputScript
from torba.basetransaction import BaseTransaction, BaseInput, BaseOutput
from torba.basecoin import BaseCoin
from torba.basedatabase import BaseSQLiteWalletStorage
from torba.basemanager import BaseWalletManager
class WalletManager(BaseWalletManager): class MainNetLedger(BaseLedger):
pass
class SQLiteWalletStorage(BaseSQLiteWalletStorage):
pass
class Input(BaseInput):
script_class = BaseInputScript
class Output(BaseOutput):
script_class = BaseOutputScript
class Transaction(BaseTransaction):
input_class = Input
output_class = Output
class BitcoinSegwitLedger(BaseLedger):
network_class = BaseNetwork
headers_class = BaseHeaders
class MainNetLedger(BitcoinSegwitLedger):
pass
class UnverifiedHeaders(BaseHeaders):
verify_bits_to_target = False
class RegTestLedger(BitcoinSegwitLedger):
headers_class = UnverifiedHeaders
max_target = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
genesis_hash = '0f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206'
genesis_bits = 0x207fffff
target_timespan = 1
verify_bits_to_target = False
class BitcoinSegwit(BaseCoin):
name = 'BitcoinSegwit' name = 'BitcoinSegwit'
symbol = 'BTC' symbol = 'BTC'
network = 'mainnet' network_name = 'mainnet'
ledger_class = MainNetLedger
transaction_class = Transaction
pubkey_address_prefix = int2byte(0x00) pubkey_address_prefix = int2byte(0x00)
script_address_prefix = int2byte(0x05) script_address_prefix = int2byte(0x05)
@ -74,10 +23,16 @@ class BitcoinSegwit(BaseCoin):
default_fee_per_byte = 50 default_fee_per_byte = 50
def __init__(self, ledger, fee_per_byte=default_fee_per_byte):
super(BitcoinSegwit, self).__init__(ledger, fee_per_byte) class UnverifiedHeaders(BaseHeaders):
verify_bits_to_target = False
class BitcoinSegwitRegtest(BitcoinSegwit): class RegTestLedger(MainNetLedger):
network = 'regtest' network_name = 'regtest'
ledger_class = RegTestLedger headers_class = UnverifiedHeaders
max_target = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
genesis_hash = '0f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206'
genesis_bits = 0x207fffff
target_timespan = 1
verify_bits_to_target = False

View file

@ -2,8 +2,7 @@ import functools
from typing import List, Dict, Type from typing import List, Dict, Type
from twisted.internet import defer from twisted.internet import defer
from torba.account import AccountsView from torba.baseaccount import AccountsView
from torba.basecoin import CoinRegistry
from torba.baseledger import BaseLedger from torba.baseledger import BaseLedger
from torba.basetransaction import BaseTransaction, NULL_HASH from torba.basetransaction import BaseTransaction, NULL_HASH
from torba.coinselection import CoinSelector from torba.coinselection import CoinSelector
@ -11,7 +10,7 @@ from torba.constants import COIN
from torba.wallet import Wallet, WalletStorage from torba.wallet import Wallet, WalletStorage
class BaseWalletManager(object): class WalletManager(object):
def __init__(self, wallets=None, ledgers=None): def __init__(self, wallets=None, ledgers=None):
self.wallets = wallets or [] # type: List[Wallet] self.wallets = wallets or [] # type: List[Wallet]
@ -35,12 +34,12 @@ class BaseWalletManager(object):
ledger_class = coin_class.ledger_class ledger_class = coin_class.ledger_class
ledger = self.ledgers.get(ledger_class) ledger = self.ledgers.get(ledger_class)
if ledger is None: if ledger is None:
ledger = self.create_ledger(ledger_class, self.get_accounts_view(coin_class), ledger_config or {}) ledger = self.create_ledger(ledger_class, ledger_config or {})
self.ledgers[ledger_class] = ledger self.ledgers[ledger_class] = ledger
return ledger return ledger
def create_ledger(self, ledger_class, accounts, config): def create_ledger(self, ledger_class, config):
return ledger_class(accounts, config) return ledger_class(config)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_balance(self): def get_balance(self):

View file

@ -3,16 +3,15 @@ import json
import os import os
from typing import List, Dict from typing import List, Dict
from torba.account import Account from torba.baseaccount import BaseAccount
from torba.basecoin import CoinRegistry, BaseCoin from torba.baseledger import LedgerRegistry, BaseLedger
from torba.baseledger import BaseLedger
def inflate_coin(manager, coin_id, coin_dict): def inflate_ledger(manager, ledger_id, ledger_dict):
# type: ('WalletManager', str, Dict) -> BaseCoin # type: ('WalletManager', str, Dict) -> BaseLedger
coin_class = CoinRegistry.get_coin_class(coin_id) ledger_class = LedgerRegistry.get_ledger_class(ledger_id)
ledger = manager.get_or_create_ledger(coin_id) ledger = manager.get_or_create_ledger(ledger_id)
return coin_class(ledger, **coin_dict) return ledger_class(ledger, **ledger_dict)
class Wallet: class Wallet:
@ -22,23 +21,14 @@ class Wallet:
by physical files on the filesystem. by physical files on the filesystem.
""" """
def __init__(self, name='Wallet', coins=None, accounts=None, storage=None): def __init__(self, name='Wallet', ledgers=None, accounts=None, storage=None):
self.name = name self.name = name
self.coins = coins or [] # type: List[BaseCoin] self.ledgers = ledgers or [] # type: List[BaseLedger]
self.accounts = accounts or [] # type: List[Account] self.accounts = accounts or [] # type: List[BaseAccount]
self.storage = storage or WalletStorage() self.storage = storage or WalletStorage()
def get_or_create_coin(self, ledger, coin_dict=None): # type: (BaseLedger, Dict) -> BaseCoin
for coin in self.coins:
if coin.__class__ is ledger.coin_class:
return coin
coin = ledger.coin_class(ledger, **(coin_dict or {}))
self.coins.append(coin)
return coin
def generate_account(self, ledger): # type: (BaseLedger) -> Account def generate_account(self, ledger): # type: (BaseLedger) -> Account
coin = self.get_or_create_coin(ledger) account = ledger.account_class.generate(ledger, u'torba')
account = Account.generate(coin, u'torba')
self.accounts.append(account) self.accounts.append(account)
return account return account
@ -46,22 +36,22 @@ class Wallet:
def from_storage(cls, storage, manager): # type: (WalletStorage, 'WalletManager') -> Wallet def from_storage(cls, storage, manager): # type: (WalletStorage, 'WalletManager') -> Wallet
json_dict = storage.read() json_dict = storage.read()
coins = {} ledgers = {}
for coin_id, coin_dict in json_dict.get('coins', {}).items(): for ledger_id, ledger_dict in json_dict.get('ledgers', {}).items():
coins[coin_id] = inflate_coin(manager, coin_id, coin_dict) ledgers[ledger_id] = inflate_ledger(manager, ledger_id, ledger_dict)
accounts = [] accounts = []
for account_dict in json_dict.get('accounts', []): for account_dict in json_dict.get('accounts', []):
coin_id = account_dict['coin'] ledger_id = account_dict['ledger']
coin = coins.get(coin_id) ledger = ledgers.get(ledger_id)
if coin is None: if ledger is None:
coin = coins[coin_id] = inflate_coin(manager, coin_id, {}) ledger = ledgers[ledger_id] = inflate_ledger(manager, ledger_id, {})
account = Account.from_dict(coin, account_dict) account = ledger.account_class.from_dict(ledger, account_dict)
accounts.append(account) accounts.append(account)
return cls( return cls(
name=json_dict.get('name', 'Wallet'), name=json_dict.get('name', 'Wallet'),
coins=list(coins.values()), ledgers=list(ledgers.values()),
accounts=accounts, accounts=accounts,
storage=storage storage=storage
) )
@ -69,7 +59,7 @@ class Wallet:
def to_dict(self): def to_dict(self):
return { return {
'name': self.name, 'name': self.name,
'coins': {c.get_id(): c.to_dict() for c in self.coins}, 'ledgers': {c.get_id(): {} for c in self.ledgers},
'accounts': [a.to_dict() for a in self.accounts] 'accounts': [a.to_dict() for a in self.accounts]
} }