refactoring and added basic test for reading/writing wallet file

This commit is contained in:
Lex Berezhny 2018-06-17 23:22:15 -04:00
parent 43cd9c4100
commit 833ef98ff5
5 changed files with 42 additions and 21 deletions

View file

@ -1,3 +1,4 @@
import tempfile
from twisted.trial import unittest
from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger
@ -36,11 +37,11 @@ class TestWalletCreation(unittest.TestCase):
"h absent",
'encrypted': False,
'private_key':
b'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P'
b'6yz3jMbycrLrRMpeAJxR8qDg8',
'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P'
'6yz3jMbycrLrRMpeAJxR8qDg8',
'public_key':
b'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f'
b'iW44g14WF52fYC5J483wqQ5ZP',
'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f'
'iW44g14WF52fYC5J483wqQ5ZP',
'receiving_gap': 10,
'receiving_maximum_use_per_address': 2,
'change_gap': 10,
@ -57,3 +58,24 @@ class TestWalletCreation(unittest.TestCase):
self.assertIsInstance(account, BTCLedger.account_class)
self.maxDiff = None
self.assertDictEqual(wallet_dict, wallet.to_dict())
def test_read_write(self):
manager = WalletManager()
config = {'wallet_path': '/tmp/wallet'}
ledger = manager.get_or_create_ledger(BTCLedger.get_id(), config)
with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file:
wallet_file.write(b'{}')
wallet_file.seek(0)
# create and write wallet to a file
wallet_storage = WalletStorage(wallet_file.name)
wallet = Wallet.from_storage(wallet_storage, manager)
account = wallet.generate_account(ledger)
wallet.save()
# read wallet from file
wallet_storage = WalletStorage(wallet_file.name)
wallet = Wallet.from_storage(wallet_storage, manager)
self.assertEqual(account.public_key.address, wallet.default_account.public_key.address)

View file

@ -133,8 +133,8 @@ class BaseAccount:
'seed': self.seed,
'encrypted': self.encrypted,
'private_key': self.private_key if self.encrypted else
self.private_key.extended_key_string(),
'public_key': self.public_key.extended_key_string(),
self.private_key.extended_key_string().decode(),
'public_key': self.public_key.extended_key_string().decode(),
'receiving_gap': self.receiving.gap,
'change_gap': self.change.gap,
'receiving_maximum_use_per_address': self.receiving.maximum_use_per_address,

View file

@ -2,7 +2,7 @@ import os
import six
import hashlib
from binascii import hexlify, unhexlify
from typing import Dict, Type
from typing import Dict, Type, Iterable, Generator
from operator import itemgetter
from twisted.internet import defer
@ -126,6 +126,16 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
def get_unspent_outputs(self, account):
return self.db.get_utxos(account, self.transaction_class.output_class)
@defer.inlineCallbacks
def get_effective_amount_estimators(self, funding_accounts):
# type: (Iterable[baseaccount.BaseAccount]) -> defer.Deferred
estimators = []
for account in funding_accounts:
utxos = yield self.get_unspent_outputs(account)
for utxo in utxos:
estimators.append(utxo.get_estimator(self))
defer.returnValue(estimators)
@defer.inlineCallbacks
def get_local_status(self, address):
address_details = yield self.db.get_address(address)

View file

@ -1,6 +1,6 @@
import six
import logging
from typing import List, Iterable, Generator
from typing import List, Iterable
from binascii import hexlify
from twisted.internet import defer
@ -271,17 +271,6 @@ class BaseTransaction:
])
self.locktime = stream.read_uint32()
@classmethod
@defer.inlineCallbacks
def get_effective_amount_estimators(cls, funding_accounts):
# type: (Iterable[torba.baseaccount.BaseAccount]) -> Generator[BaseOutputEffectiveAmountEstimator]
estimators = []
for account in funding_accounts:
utxos = yield account.ledger.get_unspent_outputs(account)
for utxo in utxos:
estimators.append(utxo.get_estimator(account.ledger))
defer.returnValue(estimators)
@classmethod
def ensure_all_have_same_ledger(cls, funding_accounts, change_account=None):
# type: (Iterable[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> torba.baseledger.BaseLedger
@ -306,7 +295,7 @@ class BaseTransaction:
tx = cls().add_outputs(outputs)
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
amount = tx.output_sum + ledger.get_transaction_base_fee(tx)
txos = yield cls.get_effective_amount_estimators(funding_accounts)
txos = yield ledger.get_effective_amount_estimators(funding_accounts)
selector = CoinSelector(
txos, amount,
ledger.get_input_output_fee(

View file

@ -17,7 +17,7 @@ class Wallet:
def __init__(self, name='Wallet', accounts=None, storage=None):
# type: (str, List[torba.baseaccount.BaseAccount], WalletStorage) -> None
self.name = name
self.accounts = accounts or []
self.accounts = accounts or [] # type: List[torba.baseaccount.BaseAccount]
self.storage = storage or WalletStorage()
def generate_account(self, ledger):