unittests

This commit is contained in:
Lex Berezhny 2020-05-01 09:34:34 -04:00
parent 9554b66a37
commit 713c665588
14 changed files with 224 additions and 288 deletions

View file

@ -18,7 +18,7 @@ class TestBlob(AsyncioTestCase):
self.tmp_dir = tempfile.mkdtemp() self.tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(self.tmp_dir)) self.addCleanup(lambda: shutil.rmtree(self.tmp_dir))
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.config = Config() self.config = Config.with_same_dir(self.tmp_dir)
self.storage = SQLiteStorage(self.config, ":memory:", self.loop) self.storage = SQLiteStorage(self.config, ":memory:", self.loop)
self.blob_manager = BlobManager(self.loop, self.tmp_dir, self.storage, self.config) self.blob_manager = BlobManager(self.loop, self.tmp_dir, self.storage, self.config)
await self.storage.open() await self.storage.open()

View file

@ -33,14 +33,12 @@ class BlobExchangeTestBase(AsyncioTestCase):
self.server_dir = tempfile.mkdtemp() self.server_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, self.client_dir) self.addCleanup(shutil.rmtree, self.client_dir)
self.addCleanup(shutil.rmtree, self.server_dir) self.addCleanup(shutil.rmtree, self.server_dir)
self.server_config = Config(data_dir=self.server_dir, download_dir=self.server_dir, wallet=self.server_dir, self.server_config = Config.with_same_dir(self.server_dir).set(reflector_servers=[])
reflector_servers=[])
self.server_storage = SQLiteStorage(self.server_config, os.path.join(self.server_dir, "lbrynet.sqlite")) self.server_storage = SQLiteStorage(self.server_config, os.path.join(self.server_dir, "lbrynet.sqlite"))
self.server_blob_manager = BlobManager(self.loop, self.server_dir, self.server_storage, self.server_config) self.server_blob_manager = BlobManager(self.loop, self.server_dir, self.server_storage, self.server_config)
self.server = BlobServer(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP') self.server = BlobServer(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP')
self.client_config = Config(data_dir=self.client_dir, download_dir=self.client_dir, wallet=self.client_dir, self.client_config = Config.with_same_dir(self.client_dir).set(reflector_servers=[])
reflector_servers=[])
self.client_storage = SQLiteStorage(self.client_config, os.path.join(self.client_dir, "lbrynet.sqlite")) self.client_storage = SQLiteStorage(self.client_config, os.path.join(self.client_dir, "lbrynet.sqlite"))
self.client_blob_manager = BlobManager(self.loop, self.client_dir, self.client_storage, self.client_config) self.client_blob_manager = BlobManager(self.loop, self.client_dir, self.client_storage, self.client_config)
self.client_peer_manager = PeerManager(self.loop) self.client_peer_manager = PeerManager(self.loop)
@ -98,7 +96,7 @@ class TestBlobExchange(BlobExchangeTestBase):
second_client_dir = tempfile.mkdtemp() second_client_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, second_client_dir) self.addCleanup(shutil.rmtree, second_client_dir)
second_client_conf = Config() second_client_conf = Config.with_same_dir(second_client_dir)
second_client_storage = SQLiteStorage(second_client_conf, os.path.join(second_client_dir, "lbrynet.sqlite")) second_client_storage = SQLiteStorage(second_client_conf, os.path.join(second_client_dir, "lbrynet.sqlite"))
second_client_blob_manager = BlobManager( second_client_blob_manager = BlobManager(
self.loop, second_client_dir, second_client_storage, second_client_conf self.loop, second_client_dir, second_client_storage, second_client_conf
@ -188,7 +186,7 @@ class TestBlobExchange(BlobExchangeTestBase):
second_client_dir = tempfile.mkdtemp() second_client_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, second_client_dir) self.addCleanup(shutil.rmtree, second_client_dir)
second_client_conf = Config() second_client_conf = Config.with_same_dir(second_client_dir)
second_client_storage = SQLiteStorage(second_client_conf, os.path.join(second_client_dir, "lbrynet.sqlite")) second_client_storage = SQLiteStorage(second_client_conf, os.path.join(second_client_dir, "lbrynet.sqlite"))
second_client_blob_manager = BlobManager( second_client_blob_manager = BlobManager(

View file

@ -3,7 +3,7 @@ import hashlib
from lbry.extras.daemon.comment_client import sign_comment from lbry.extras.daemon.comment_client import sign_comment
from lbry.extras.daemon.comment_client import is_comment_signed_by_channel from lbry.extras.daemon.comment_client import is_comment_signed_by_channel
from tests.unit.wallet.test_schema_signing import get_stream, get_channel from unit.schema.test_schema_signing import get_stream, get_channel
class TestSigningComments(AsyncioTestCase): class TestSigningComments(AsyncioTestCase):

View file

@ -1,14 +1,15 @@
from unittest import TestCase
from binascii import unhexlify, hexlify from binascii import unhexlify, hexlify
from lbry.testcase import AsyncioTestCase from lbry.blockchain.ledger import Ledger
from lbry.wallet.bip32 import PubKey, PrivateKey, from_extended_key_string from lbry.crypto.bip32 import PubKey, PrivateKey, from_extended_key_string
from lbry.wallet import Ledger, Headers
from lbry.db import Database
from tests.unit.wallet.key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys from tests.unit.wallet.key_fixtures import (
expected_ids, expected_privkeys, expected_hardened_privkeys
)
class BIP32Tests(AsyncioTestCase): class BIP32Tests(TestCase):
def test_pubkey_validation(self): def test_pubkey_validation(self):
with self.assertRaisesRegex(TypeError, 'chain code must be raw bytes'): with self.assertRaisesRegex(TypeError, 'chain code must be raw bytes'):
@ -41,16 +42,13 @@ class BIP32Tests(AsyncioTestCase):
self.assertIsInstance(new_key, PubKey) self.assertIsInstance(new_key, PubKey)
self.assertEqual(hexlify(new_key.identifier()), expected_ids[i]) self.assertEqual(hexlify(new_key.identifier()), expected_ids[i])
async def test_private_key_validation(self): def test_private_key_validation(self):
with self.assertRaisesRegex(TypeError, 'private key must be raw bytes'): with self.assertRaisesRegex(TypeError, 'private key must be raw bytes'):
PrivateKey(None, None, b'abcd'*8, 0, 255) PrivateKey(None, None, b'abcd'*8, 0, 255)
with self.assertRaisesRegex(ValueError, 'private key must be 32 bytes'): with self.assertRaisesRegex(ValueError, 'private key must be 32 bytes'):
PrivateKey(None, b'abcd', b'abcd'*8, 0, 255) PrivateKey(None, b'abcd', b'abcd'*8, 0, 255)
private_key = PrivateKey( private_key = PrivateKey(
Ledger({ Ledger(),
'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:'),
}),
unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'), unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
b'abcd'*8, 0, 1 b'abcd'*8, 0, 1
) )
@ -66,12 +64,9 @@ class BIP32Tests(AsyncioTestCase):
private_key.child(-1) private_key.child(-1)
self.assertIsInstance(private_key.child(PrivateKey.HARDENED), PrivateKey) self.assertIsInstance(private_key.child(PrivateKey.HARDENED), PrivateKey)
async def test_private_key_derivation(self): def test_private_key_derivation(self):
private_key = PrivateKey( private_key = PrivateKey(
Ledger({ Ledger(),
'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:'),
}),
unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'), unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
b'abcd'*8, 0, 1 b'abcd'*8, 0, 1
) )
@ -84,21 +79,17 @@ class BIP32Tests(AsyncioTestCase):
self.assertIsInstance(new_privkey, PrivateKey) self.assertIsInstance(new_privkey, PrivateKey)
self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED]) self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED])
async def test_from_extended_keys(self): def test_from_extended_keys(self):
ledger = Ledger({
'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:'),
})
self.assertIsInstance( self.assertIsInstance(
from_extended_key_string( from_extended_key_string(
ledger, Ledger(),
'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P' 'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P'
'6yz3jMbycrLrRMpeAJxR8qDg8', '6yz3jMbycrLrRMpeAJxR8qDg8',
), PrivateKey ), PrivateKey
) )
self.assertIsInstance( self.assertIsInstance(
from_extended_key_string( from_extended_key_string(
ledger, Ledger(),
'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f'
'iW44g14WF52fYC5J483wqQ5ZP', 'iW44g14WF52fYC5J483wqQ5ZP',
), PubKey ), PubKey

View file

@ -70,9 +70,9 @@ fake_claim_info = {
class StorageTest(AsyncioTestCase): class StorageTest(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.conf = Config()
self.storage = SQLiteStorage(self.conf, ':memory:')
self.blob_dir = tempfile.mkdtemp() self.blob_dir = tempfile.mkdtemp()
self.conf = Config.with_same_dir(self.blob_dir)
self.storage = SQLiteStorage(self.conf, ':memory:')
self.addCleanup(shutil.rmtree, self.blob_dir) self.addCleanup(shutil.rmtree, self.blob_dir)
self.blob_manager = BlobManager(asyncio.get_event_loop(), self.blob_dir, self.storage, self.conf) self.blob_manager = BlobManager(asyncio.get_event_loop(), self.blob_dir, self.storage, self.conf)
await self.storage.open() await self.storage.open()

View file

@ -10,7 +10,7 @@ from lbry.extras.daemon.components import HASH_ANNOUNCER_COMPONENT
from lbry.extras.daemon.components import UPNP_COMPONENT, BLOB_COMPONENT from lbry.extras.daemon.components import UPNP_COMPONENT, BLOB_COMPONENT
from lbry.extras.daemon.components import PEER_PROTOCOL_SERVER_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT from lbry.extras.daemon.components import PEER_PROTOCOL_SERVER_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT
from lbry.extras.daemon.daemon import Daemon as LBRYDaemon from lbry.extras.daemon.daemon import Daemon as LBRYDaemon
from lbry.wallet import WalletManager, Wallet from lbry.wallet.manager import WalletManager, Wallet
from tests import test_utils from tests import test_utils
# from tests.mocks import mock_conf_settings, FakeNetwork, FakeFileManager # from tests.mocks import mock_conf_settings, FakeNetwork, FakeFileManager

View file

@ -1,41 +1,25 @@
from unittest import TestCase
from binascii import unhexlify from binascii import unhexlify
from lbry.testcase import AsyncioTestCase from lbry.testcase import get_transaction
from lbry.wallet.constants import CENT, NULL_HASH32 from lbry.blockchain.ledger import Ledger
from lbry.blockchain.transaction import Transaction, Output
from lbry.wallet import Ledger, Headers, Transaction, Input, Output from lbry.constants import CENT
from lbry.db import Database
from lbry.schema.claim import Claim from lbry.schema.claim import Claim
def get_output(amount=CENT, pubkey_hash=NULL_HASH32):
return Transaction() \
.add_outputs([Output.pay_pubkey_hash(amount, pubkey_hash)]) \
.outputs[0]
def get_input():
return Input.spend(get_output())
def get_tx():
return Transaction().add_inputs([get_input()])
def get_channel(claim_name='@foo'): def get_channel(claim_name='@foo'):
channel_txo = Output.pay_claim_name_pubkey_hash(CENT, claim_name, Claim(), b'abc') channel_txo = Output.pay_claim_name_pubkey_hash(CENT, claim_name, Claim(), b'abc')
channel_txo.generate_channel_private_key() channel_txo.generate_channel_private_key()
get_tx().add_outputs([channel_txo]) return get_transaction(channel_txo).outputs[0]
return channel_txo
def get_stream(claim_name='foo'): def get_stream(claim_name='foo'):
stream_txo = Output.pay_claim_name_pubkey_hash(CENT, claim_name, Claim(), b'abc') stream_txo = Output.pay_claim_name_pubkey_hash(CENT, claim_name, Claim(), b'abc')
get_tx().add_outputs([stream_txo]) return get_transaction(stream_txo).outputs[0]
return stream_txo
class TestSigningAndValidatingClaim(AsyncioTestCase): class TestSigningAndValidatingClaim(TestCase):
def test_successful_create_sign_and_validate(self): def test_successful_create_sign_and_validate(self):
channel = get_channel() channel = get_channel()
@ -65,7 +49,7 @@ class TestSigningAndValidatingClaim(AsyncioTestCase):
self.assertFalse(channel.is_channel_private_key(get_channel().private_key)) self.assertFalse(channel.is_channel_private_key(get_channel().private_key))
class TestValidatingOldSignatures(AsyncioTestCase): class TestValidatingOldSignatures(TestCase):
def test_signed_claim_made_by_ytsync(self): def test_signed_claim_made_by_ytsync(self):
stream_tx = Transaction(unhexlify( stream_tx = Transaction(unhexlify(
@ -109,9 +93,4 @@ class TestValidatingOldSignatures(AsyncioTestCase):
)) ))
channel = channel_tx.outputs[0] channel = channel_tx.outputs[0]
ledger = Ledger({ self.assertTrue(stream.is_signed_by(channel, Ledger()))
'db': Database(':memory:'),
'headers': Headers(':memory:')
})
self.assertTrue(stream.is_signed_by(channel, ledger))

View file

@ -18,7 +18,7 @@ class TestStreamAssembler(AsyncioTestCase):
tmp_dir = tempfile.mkdtemp() tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir)) self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.conf = Config() self.conf = Config.with_same_dir(tmp_dir)
self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite")) self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite"))
await self.storage.open() await self.storage.open()
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf) self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf)
@ -26,7 +26,7 @@ class TestStreamAssembler(AsyncioTestCase):
server_tmp_dir = tempfile.mkdtemp() server_tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(server_tmp_dir)) self.addCleanup(lambda: shutil.rmtree(server_tmp_dir))
self.server_conf = Config() self.server_conf = Config.with_same_dir(server_tmp_dir)
self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite")) self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite"))
await self.server_storage.open() await self.server_storage.open()
self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf) self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf)

View file

@ -20,7 +20,7 @@ class TestStreamDescriptor(AsyncioTestCase):
self.cleartext = os.urandom(20000000) self.cleartext = os.urandom(20000000)
self.tmp_dir = tempfile.mkdtemp() self.tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(self.tmp_dir)) self.addCleanup(lambda: shutil.rmtree(self.tmp_dir))
self.conf = Config() self.conf = Config.with_same_dir(self.tmp_dir)
self.storage = SQLiteStorage(self.conf, ":memory:") self.storage = SQLiteStorage(self.conf, ":memory:")
await self.storage.open() await self.storage.open()
self.blob_manager = BlobManager(self.loop, self.tmp_dir, self.storage, self.conf) self.blob_manager = BlobManager(self.loop, self.tmp_dir, self.storage, self.conf)
@ -93,7 +93,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
tmp_dir = tempfile.mkdtemp() tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir)) self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.conf = Config() self.conf = Config.with_same_dir(tmp_dir)
storage = SQLiteStorage(self.conf, ":memory:") storage = SQLiteStorage(self.conf, ":memory:")
await storage.open() await storage.open()
blob_manager = BlobManager(loop, tmp_dir, storage, self.conf) blob_manager = BlobManager(loop, tmp_dir, storage, self.conf)

View file

@ -10,11 +10,10 @@ from lbry.testcase import get_fake_exchange_rate_manager
from lbry.utils import generate_id from lbry.utils import generate_id
from lbry.error import InsufficientFundsError from lbry.error import InsufficientFundsError
from lbry.error import KeyFeeAboveMaxAllowedError, ResolveError, DownloadSDTimeoutError, DownloadDataTimeoutError from lbry.error import KeyFeeAboveMaxAllowedError, ResolveError, DownloadSDTimeoutError, DownloadDataTimeoutError
from lbry.wallet import WalletManager, Wallet, Ledger, Transaction, Input, Output from lbry.blockchain.ledger import Ledger
from lbry.wallet.constants import CENT, NULL_HASH32 from lbry.blockchain.transaction import Transaction, Input, Output
from lbry.wallet.network import ClientSession from lbry.constants import CENT, NULL_HASH32
from lbry.db import Database from lbry.service.full_node import FullNode
from lbry.conf import Config
from lbry.extras.daemon.analytics import AnalyticsManager from lbry.extras.daemon.analytics import AnalyticsManager
from lbry.stream.stream_manager import StreamManager from lbry.stream.stream_manager import StreamManager
from lbry.stream.descriptor import StreamDescriptor from lbry.stream.descriptor import StreamDescriptor
@ -64,7 +63,7 @@ def get_claim_transaction(claim_name, claim=b''):
) )
async def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None): async def get_mock_wallet(sd_hash, storage, conf, balance=10.0, fee=None):
claim = Claim() claim = Claim()
if fee: if fee:
if fee['currency'] == 'LBC': if fee['currency'] == 'LBC':
@ -84,45 +83,26 @@ async def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None):
}) })
class FakeHeaders: service = FullNode(Ledger(conf), 'sqlite:///:memory:')
def estimated_timestamp(self, height): await service.db.open()
return 1984 await service.wallet_manager.open()
def __init__(self, height):
self.height = height
def __getitem__(self, item):
return {'timestamp': 1984}
wallet = Wallet()
ledger = Ledger({
'db': Database('sqlite:///:memory:'),
'headers': FakeHeaders(514082)
})
await ledger.db.open()
wallet.generate_account(ledger)
manager = WalletManager()
manager.config = Config()
manager.wallets.append(wallet)
manager.ledgers[Ledger] = ledger
manager.ledger.network.client = ClientSession(
network=manager.ledger.network, server=('fakespv.lbry.com', 50001)
)
async def mock_resolve(*args, **kwargs): async def mock_resolve(*args, **kwargs):
result = {txo.meta['permanent_url']: txo} result = {txo.meta['permanent_url']: txo}
claims = [ claims = [
StreamManager._convert_to_old_resolve_output(manager, result)[txo.meta['permanent_url']] StreamManager._convert_to_old_resolve_output(
service.wallet_manager, result
)[txo.meta['permanent_url']]
] ]
await storage.save_claims(claims) await storage.save_claims(claims)
return result return result
manager.ledger.resolve = mock_resolve service.resolve = mock_resolve
async def get_balance(*_): async def get_balance(*_):
return balance return balance
manager.get_balance = get_balance service.get_balance = get_balance
return manager, txo.meta['permanent_url'] return service, txo.meta['permanent_url']
class TestStreamManager(BlobExchangeTestBase): class TestStreamManager(BlobExchangeTestBase):
@ -138,8 +118,10 @@ class TestStreamManager(BlobExchangeTestBase):
self.loop, self.server_blob_manager.blob_dir, file_path, old_sort=old_sort self.loop, self.server_blob_manager.blob_dir, file_path, old_sort=old_sort
) )
self.sd_hash = descriptor.sd_hash self.sd_hash = descriptor.sd_hash
self.mock_wallet, self.uri = await get_mock_wallet(self.sd_hash, self.client_storage, balance, fee) self.service, self.uri = await get_mock_wallet(
self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.mock_wallet, self.sd_hash, self.client_storage, self.client_config, balance, fee
)
self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.service,
self.client_storage, get_mock_node(self.server_from_client), self.client_storage, get_mock_node(self.server_from_client),
AnalyticsManager(self.client_config, AnalyticsManager(self.client_config,
binascii.hexlify(generate_id()).decode(), binascii.hexlify(generate_id()).decode(),

View file

@ -1,25 +1,42 @@
from lbry.wallet.stream import StreamController
from lbry.wallet.tasks import TaskGroup
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.event import EventController
from lbry.tasks import TaskGroup
class StreamControllerTestCase(AsyncioTestCase): class StreamControllerTestCase(AsyncioTestCase):
def test_non_unique_events(self):
async def test_non_unique_events(self):
events = [] events = []
controller = StreamController() controller = EventController()
controller.stream.listen(on_data=events.append) controller.stream.listen(events.append)
controller.add("yo") await controller.add("yo")
controller.add("yo") await controller.add("yo")
self.assertListEqual(events, ["yo", "yo"]) self.assertListEqual(events, ["yo", "yo"])
def test_unique_events(self): async def test_unique_events(self):
events = [] events = []
controller = StreamController(merge_repeated_events=True) controller = EventController(merge_repeated_events=True)
controller.stream.listen(on_data=events.append) controller.stream.listen(events.append)
controller.add("yo") await controller.add("yo")
controller.add("yo") await controller.add("yo")
self.assertListEqual(events, ["yo"]) self.assertListEqual(events, ["yo"])
async def test_sync_listener_errors(self):
def bad_listener(e):
raise ValueError('bad')
controller = EventController()
controller.stream.listen(bad_listener)
with self.assertRaises(ValueError):
await controller.add("yo")
async def test_async_listener_errors(self):
async def bad_listener(e):
raise ValueError('bad')
controller = EventController()
controller.stream.listen(bad_listener)
with self.assertRaises(ValueError):
await controller.add("yo")
class TaskGroupTestCase(AsyncioTestCase): class TaskGroupTestCase(AsyncioTestCase):

View file

@ -1,31 +1,34 @@
from binascii import hexlify
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet import Wallet, Ledger, Headers, Account, SingleKey, HierarchicalDeterministic
from lbry.db import Database from lbry.blockchain.ledger import Ledger
from lbry.wallet.account import Account, SingleKey, HierarchicalDeterministic
from lbry.db import Database, PubkeyAddress
class TestAccount(AsyncioTestCase): class AccountTestCase(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = Ledger({ self.ledger = Ledger()
'db': Database('sqlite:///:memory:'), self.db = Database(self.ledger, 'sqlite:///:memory:')
'headers': Headers(':memory:') await self.db.open()
}) self.addCleanup(self.db.close)
await self.ledger.db.open()
async def asyncTearDown(self): async def update_addressed_used(self, address, used):
await self.ledger.db.close() await self.db.execute(
PubkeyAddress.update()
.where(PubkeyAddress.c.address == address)
.values(used_times=used)
)
class TestAccount(AccountTestCase):
async def test_generate_account(self): async def test_generate_account(self):
account = Account.generate(self.ledger, Wallet(), 'lbryum') account = Account.generate(self.ledger, self.db, 'lbryum')
self.assertEqual(account.ledger, self.ledger)
self.assertIsNotNone(account.seed) self.assertIsNotNone(account.seed)
self.assertEqual(account.public_key.ledger, self.ledger) self.assertEqual(account.public_key.ledger, self.ledger)
self.assertEqual(account.private_key.public_key, account.public_key) self.assertEqual(account.private_key.public_key, account.public_key)
self.assertEqual(account.public_key.ledger, self.ledger)
self.assertEqual(account.private_key.public_key, account.public_key)
addresses = await account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 0) self.assertEqual(len(addresses), 0)
addresses = await account.change.get_addresses() addresses = await account.change.get_addresses()
@ -39,14 +42,14 @@ class TestAccount(AsyncioTestCase):
self.assertEqual(len(addresses), 6) self.assertEqual(len(addresses), 6)
async def test_generate_keys_over_batch_threshold_saves_it_properly(self): async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
account = Account.generate(self.ledger, Wallet(), 'lbryum') account = Account.generate(self.ledger, self.db, 'lbryum')
async with account.receiving.address_generator_lock: async with account.receiving.address_generator_lock:
await account.receiving._generate_keys(0, 200) await account.receiving._generate_keys(0, 200)
records = await account.receiving.get_address_records() records = await account.receiving.get_address_records()
self.assertEqual(len(records), 201) self.assertEqual(len(records), 201)
async def test_ensure_address_gap(self): async def test_ensure_address_gap(self):
account = Account.generate(self.ledger, Wallet(), 'lbryum') account = Account.generate(self.ledger, self.db, 'lbryum')
self.assertIsInstance(account.receiving, HierarchicalDeterministic) self.assertIsInstance(account.receiving, HierarchicalDeterministic)
@ -75,17 +78,17 @@ class TestAccount(AsyncioTestCase):
# case #2: only one new addressed needed # case #2: only one new addressed needed
records = await account.receiving.get_address_records() records = await account.receiving.get_address_records()
await self.ledger.db.set_address_history(records[0]['address'], 'a:1:') await self.update_addressed_used(records[0]['address'], 1)
new_keys = await account.receiving.ensure_address_gap() new_keys = await account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 1) self.assertEqual(len(new_keys), 1)
# case #3: 20 addresses needed # case #3: 20 addresses needed
await self.ledger.db.set_address_history(new_keys[0], 'a:1:') await self.update_addressed_used(new_keys[0], 1)
new_keys = await account.receiving.ensure_address_gap() new_keys = await account.receiving.ensure_address_gap()
self.assertEqual(len(new_keys), 20) self.assertEqual(len(new_keys), 20)
async def test_get_or_create_usable_address(self): async def test_get_or_create_usable_address(self):
account = Account.generate(self.ledger, Wallet(), 'lbryum') account = Account.generate(self.ledger, self.db, 'lbryum')
keys = await account.receiving.get_addresses() keys = await account.receiving.get_addresses()
self.assertEqual(len(keys), 0) self.assertEqual(len(keys), 0)
@ -98,7 +101,7 @@ class TestAccount(AsyncioTestCase):
async def test_generate_account_from_seed(self): async def test_generate_account_from_seed(self):
account = Account.from_dict( account = Account.from_dict(
self.ledger, Wallet(), { self.ledger, self.db, {
"seed": "seed":
"carbon smart garage balance margin twelve chest sword toas" "carbon smart garage balance margin twelve chest sword toas"
"t envelope bottom stomach absent" "t envelope bottom stomach absent"
@ -117,19 +120,6 @@ class TestAccount(AsyncioTestCase):
address = await account.receiving.ensure_address_gap() address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx') self.assertEqual(address[0], 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx')
private_key = await self.ledger.get_private_key_for_address(
account.wallet, 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx'
)
self.assertEqual(
private_key.extended_key_string(),
'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR'
'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT'
)
private_key = await self.ledger.get_private_key_for_address(
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
)
self.assertIsNone(private_key)
async def test_load_and_save_account(self): async def test_load_and_save_account(self):
account_data = { account_data = {
'name': 'Main Account', 'name': 'Main Account',
@ -152,7 +142,7 @@ class TestAccount(AsyncioTestCase):
} }
} }
account = Account.from_dict(self.ledger, Wallet(), account_data) account = Account.from_dict(self.ledger, self.db, account_data)
await account.ensure_address_gap() await account.ensure_address_gap()
@ -160,27 +150,8 @@ class TestAccount(AsyncioTestCase):
self.assertEqual(len(addresses), 17) self.assertEqual(len(addresses), 17)
addresses = await account.change.get_addresses() addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 10) self.assertEqual(len(addresses), 10)
account_data['ledger'] = 'lbc_mainnet'
self.assertDictEqual(account_data, account.to_dict()) self.assertDictEqual(account_data, account.to_dict())
async def test_save_max_gap(self):
account = Account.generate(
self.ledger, Wallet(), 'lbryum', {
'name': 'deterministic-chain',
'receiving': {'gap': 3, 'maximum_uses_per_address': 2},
'change': {'gap': 4, 'maximum_uses_per_address': 2}
}
)
self.assertEqual(account.receiving.gap, 3)
self.assertEqual(account.change.gap, 4)
await account.save_max_gap()
self.assertEqual(account.receiving.gap, 20)
self.assertEqual(account.change.gap, 6)
# doesn't fail for single-address account
account2 = Account.generate(self.ledger, Wallet(), 'lbryum', {'name': 'single-address'})
await account2.save_max_gap()
def test_merge_diff(self): def test_merge_diff(self):
account_data = { account_data = {
'name': 'My Account', 'name': 'My Account',
@ -201,7 +172,7 @@ class TestAccount(AsyncioTestCase):
'change': {'gap': 5, 'maximum_uses_per_address': 2} 'change': {'gap': 5, 'maximum_uses_per_address': 2}
} }
} }
account = Account.from_dict(self.ledger, Wallet(), account_data) account = Account.from_dict(self.ledger, self.db, account_data)
self.assertEqual(account.name, 'My Account') self.assertEqual(account.name, 'My Account')
self.assertEqual(account.modified_on, 123.456) self.assertEqual(account.modified_on, 123.456)
@ -230,18 +201,13 @@ class TestAccount(AsyncioTestCase):
self.assertEqual(account.receiving.maximum_uses_per_address, 9) self.assertEqual(account.receiving.maximum_uses_per_address, 9)
class TestSingleKeyAccount(AsyncioTestCase): class TestSingleKeyAccount(AccountTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = Ledger({ await super().asyncSetUp()
'db': Database('sqlite:///:memory:'), self.account = Account.generate(
'headers': Headers(':memory:') self.ledger, self.db, "torba", {'name': 'single-address'}
}) )
await self.ledger.db.open()
self.account = Account.generate(self.ledger, Wallet(), "torba", {'name': 'single-address'})
async def asyncTearDown(self):
await self.ledger.db.close()
async def test_generate_account(self): async def test_generate_account(self):
account = self.account account = self.account
@ -286,7 +252,6 @@ class TestSingleKeyAccount(AsyncioTestCase):
'chain': 0, 'chain': 0,
'account': account.public_key.address, 'account': account.public_key.address,
'address': account.public_key.address, 'address': account.public_key.address,
'history': None,
'used_times': 0 'used_times': 0
}]) }])
self.assertEqual( self.assertEqual(
@ -300,7 +265,7 @@ class TestSingleKeyAccount(AsyncioTestCase):
# case #2: after use, still no new address needed # case #2: after use, still no new address needed
records = await account.receiving.get_address_records() records = await account.receiving.get_address_records()
await self.ledger.db.set_address_history(records[0]['address'], 'a:1:') await self.update_addressed_used(records[0]['address'], 1)
empty = await account.receiving.ensure_address_gap() empty = await account.receiving.ensure_address_gap()
self.assertEqual(len(empty), 0) self.assertEqual(len(empty), 0)
@ -313,7 +278,7 @@ class TestSingleKeyAccount(AsyncioTestCase):
address1 = await account.receiving.get_or_create_usable_address() address1 = await account.receiving.get_or_create_usable_address()
self.assertIsNotNone(address1) self.assertIsNotNone(address1)
await self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:') await self.update_addressed_used(address1, 3)
records = await account.receiving.get_address_records() records = await account.receiving.get_address_records()
self.assertEqual(records[0]['used_times'], 3) self.assertEqual(records[0]['used_times'], 3)
@ -323,47 +288,6 @@ class TestSingleKeyAccount(AsyncioTestCase):
keys = await account.receiving.get_addresses() keys = await account.receiving.get_addresses()
self.assertEqual(len(keys), 1) self.assertEqual(len(keys), 1)
async def test_generate_account_from_seed(self):
account = Account.from_dict(
self.ledger, Wallet(), {
"seed":
"carbon smart garage balance margin twelve chest sword toas"
"t envelope bottom stomach absent",
'address_generator': {'name': 'single-address'}
}
)
self.assertEqual(
account.private_key.extended_key_string(),
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7'
'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
)
self.assertEqual(
account.public_key.extended_key_string(),
'xpub661MyMwAqRbcGWtPvbWh9sc2BCfw2cTeVDYF23o3N1t6UZ5wv3EM'
'mDgp66FxHuDtWdft3B5eL5xQtyzAtkdmhhC95gjRjLzSTdkho95asu9',
)
address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], account.public_key.address)
private_key = await self.ledger.get_private_key_for_address(
account.wallet, address[0]
)
self.assertEqual(
private_key.extended_key_string(),
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7'
'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
)
invalid_key = await self.ledger.get_private_key_for_address(
account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX'
)
self.assertIsNone(invalid_key)
self.assertEqual(
hexlify(private_key.wif()),
b'1cef6c80310b1bcbcfa3176ea809ac840f48cda634c475d402e6bd68d5bb3827d601'
)
async def test_load_and_save_account(self): async def test_load_and_save_account(self):
account_data = { account_data = {
'name': 'My Account', 'name': 'My Account',
@ -380,7 +304,7 @@ class TestSingleKeyAccount(AsyncioTestCase):
'certificates': {} 'certificates': {}
} }
account = Account.from_dict(self.ledger, Wallet(), account_data) account = Account.from_dict(self.ledger, self.db, account_data)
await account.ensure_address_gap() await account.ensure_address_gap()
@ -390,11 +314,11 @@ class TestSingleKeyAccount(AsyncioTestCase):
self.assertEqual(len(addresses), 1) self.assertEqual(len(addresses), 1)
self.maxDiff = None self.maxDiff = None
account_data['ledger'] = 'lbc_mainnet'
self.assertDictEqual(account_data, account.to_dict()) self.assertDictEqual(account_data, account.to_dict())
class AccountEncryptionTests(AsyncioTestCase): class AccountEncryptionTests(AccountTestCase):
password = "password" password = "password"
init_vector = b'0000000000000000' init_vector = b'0000000000000000'
unencrypted_account = { unencrypted_account = {
@ -428,14 +352,8 @@ class AccountEncryptionTests(AsyncioTestCase):
'address_generator': {'name': 'single-address'} 'address_generator': {'name': 'single-address'}
} }
async def asyncSetUp(self):
self.ledger = Ledger({
'db': Database(':memory:'),
'headers': Headers(':memory:')
})
def test_encrypt_wallet(self): def test_encrypt_wallet(self):
account = Account.from_dict(self.ledger, Wallet(), self.unencrypted_account) account = Account.from_dict(self.ledger, self.db, self.unencrypted_account)
account.init_vectors = { account.init_vectors = {
'seed': self.init_vector, 'seed': self.init_vector,
'private_key': self.init_vector 'private_key': self.init_vector
@ -465,7 +383,7 @@ class AccountEncryptionTests(AsyncioTestCase):
self.assertFalse(account.encrypted) self.assertFalse(account.encrypted)
def test_decrypt_wallet(self): def test_decrypt_wallet(self):
account = Account.from_dict(self.ledger, Wallet(), self.encrypted_account) account = Account.from_dict(self.ledger, self.db, self.encrypted_account)
self.assertTrue(account.encrypted) self.assertTrue(account.encrypted)
account.decrypt(self.password) account.decrypt(self.password)
@ -486,7 +404,7 @@ class AccountEncryptionTests(AsyncioTestCase):
account_data = self.unencrypted_account.copy() account_data = self.unencrypted_account.copy()
del account_data['seed'] del account_data['seed']
del account_data['private_key'] del account_data['private_key']
account = Account.from_dict(self.ledger, Wallet(), account_data) account = Account.from_dict(self.ledger, self.db, account_data)
encrypted = account.to_dict('password') encrypted = account.to_dict('password')
self.assertFalse(encrypted['seed']) self.assertFalse(encrypted['seed'])
self.assertFalse(encrypted['private_key']) self.assertFalse(encrypted['private_key'])

View file

@ -1,16 +1,10 @@
from unittest import TestCase
from types import GeneratorType from types import GeneratorType
from lbry.testcase import AsyncioTestCase from lbry.blockchain.ledger import RegTestLedger
from lbry.wallet.coinselection import CoinSelector, OutputEffectiveAmountEstimator, MAXIMUM_TRIES
from lbry.wallet import Ledger, Headers
from lbry.db import Database
from lbry.wallet.coinselection import CoinSelector, MAXIMUM_TRIES
from lbry.constants import CENT from lbry.constants import CENT
from lbry.testcase import get_output as utxo
from tests.unit.wallet.test_transaction import get_output as utxo
NULL_HASH = b'\x00'*32
def search(*args, **kwargs): def search(*args, **kwargs):
@ -18,21 +12,14 @@ def search(*args, **kwargs):
return [o.txo.amount for o in selection] if selection else selection return [o.txo.amount for o in selection] if selection else selection
class BaseSelectionTestCase(AsyncioTestCase): class BaseSelectionTestCase(TestCase):
async def asyncSetUp(self): def setUp(self):
self.ledger = Ledger({ self.ledger = RegTestLedger()
'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:'),
})
await self.ledger.db.open()
async def asyncTearDown(self):
await self.ledger.db.close()
def estimates(self, *args): def estimates(self, *args):
txos = args[0] if isinstance(args[0], (GeneratorType, list)) else args txos = args[0] if isinstance(args[0], (GeneratorType, list)) else args
return [txo.get_estimator(self.ledger) for txo in txos] return [OutputEffectiveAmountEstimator(self.ledger, txo) for txo in txos]
class TestCoinSelectionTests(BaseSelectionTestCase): class TestCoinSelectionTests(BaseSelectionTestCase):
@ -41,7 +28,7 @@ class TestCoinSelectionTests(BaseSelectionTestCase):
self.assertListEqual(CoinSelector(0, 0).select([]), []) self.assertListEqual(CoinSelector(0, 0).select([]), [])
def test_skip_binary_search_if_total_not_enough(self): def test_skip_binary_search_if_total_not_enough(self):
fee = utxo(CENT).get_estimator(self.ledger).fee fee = OutputEffectiveAmountEstimator(self.ledger, utxo(CENT)).fee
big_pool = self.estimates(utxo(CENT+fee) for _ in range(100)) big_pool = self.estimates(utxo(CENT+fee) for _ in range(100))
selector = CoinSelector(101 * CENT, 0) selector = CoinSelector(101 * CENT, 0)
self.assertListEqual(selector.select(big_pool), []) self.assertListEqual(selector.select(big_pool), [])
@ -52,7 +39,7 @@ class TestCoinSelectionTests(BaseSelectionTestCase):
self.assertEqual(selector.tries, 201) self.assertEqual(selector.tries, 201)
def test_exact_match(self): def test_exact_match(self):
fee = utxo(CENT).get_estimator(self.ledger).fee fee = OutputEffectiveAmountEstimator(self.ledger, utxo(CENT)).fee
utxo_pool = self.estimates( utxo_pool = self.estimates(
utxo(CENT + fee), utxo(CENT + fee),
utxo(CENT), utxo(CENT),

View file

@ -1,30 +1,95 @@
import tempfile import tempfile
from binascii import hexlify from binascii import hexlify
from unittest import TestCase, mock from unittest import TestCase, mock
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet import ( from lbry.db import Database
Ledger, RegTestLedger, WalletManager, Account, from lbry.blockchain.ledger import Ledger
Wallet, WalletStorage, TimestampedPreferences from lbry.wallet.manager import WalletManager
from lbry.wallet.wallet import (
Account, Wallet, WalletStorage, TimestampedPreferences
) )
class TestWalletCreation(AsyncioTestCase): class WalletTestCase(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.manager = WalletManager() self.ledger = Ledger()
config = {'data_path': '/tmp/wallet'} self.db = Database(self.ledger, 'sqlite:///:memory:')
self.main_ledger = self.manager.get_or_create_ledger(Ledger.get_id(), config) await self.db.open()
self.test_ledger = self.manager.get_or_create_ledger(RegTestLedger.get_id(), config) self.addCleanup(self.db.close)
class WalletAccountTest(WalletTestCase):
async def test_private_key_for_hierarchical_account(self):
wallet = Wallet(self.ledger, self.db)
account = wallet.add_account({
"seed":
"carbon smart garage balance margin twelve chest sword toas"
"t envelope bottom stomach absent"
})
await account.receiving.ensure_address_gap()
private_key = await wallet.get_private_key_for_address(
'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx'
)
self.assertEqual(
private_key.extended_key_string(),
'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR'
'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT'
)
self.assertIsNone(
await wallet.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
)
async def test_private_key_for_single_address_account(self):
wallet = Wallet(self.ledger, self.db)
account = wallet.add_account({
"seed":
"carbon smart garage balance margin twelve chest sword toas"
"t envelope bottom stomach absent",
'address_generator': {'name': 'single-address'}
})
address = await account.receiving.ensure_address_gap()
private_key = await wallet.get_private_key_for_address(address[0])
self.assertEqual(
private_key.extended_key_string(),
'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7'
'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe',
)
self.assertIsNone(
await wallet.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
)
async def test_save_max_gap(self):
wallet = Wallet(self.ledger, self.db)
account = wallet.generate_account(
'lbryum', {
'name': 'deterministic-chain',
'receiving': {'gap': 3, 'maximum_uses_per_address': 2},
'change': {'gap': 4, 'maximum_uses_per_address': 2}
}
)
self.assertEqual(account.receiving.gap, 3)
self.assertEqual(account.change.gap, 4)
await wallet.save_max_gap()
self.assertEqual(account.receiving.gap, 20)
self.assertEqual(account.change.gap, 6)
# doesn't fail for single-address account
wallet.generate_account('lbryum', {'name': 'single-address'})
await wallet.save_max_gap()
class TestWalletCreation(WalletTestCase):
def test_create_wallet_and_accounts(self): def test_create_wallet_and_accounts(self):
wallet = Wallet() wallet = Wallet(self.ledger, self.db)
self.assertEqual(wallet.name, 'Wallet') self.assertEqual(wallet.name, 'Wallet')
self.assertListEqual(wallet.accounts, []) self.assertListEqual(wallet.accounts, [])
account1 = wallet.generate_account(self.main_ledger) account1 = wallet.generate_account()
wallet.generate_account(self.main_ledger) wallet.generate_account()
wallet.generate_account(self.test_ledger) wallet.generate_account()
self.assertEqual(wallet.default_account, account1) self.assertEqual(wallet.default_account, account1)
self.assertEqual(len(wallet.accounts), 3) self.assertEqual(len(wallet.accounts), 3)
@ -32,12 +97,12 @@ class TestWalletCreation(AsyncioTestCase):
wallet_dict = { wallet_dict = {
'version': 1, 'version': 1,
'name': 'Main Wallet', 'name': 'Main Wallet',
'ledger': 'lbc_mainnet',
'preferences': {}, 'preferences': {},
'accounts': [ 'accounts': [
{ {
'certificates': {}, 'certificates': {},
'name': 'An Account', 'name': 'An Account',
'ledger': 'lbc_mainnet',
'modified_on': 123.456, 'modified_on': 123.456,
'seed': 'seed':
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac" "carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
@ -59,10 +124,11 @@ class TestWalletCreation(AsyncioTestCase):
} }
storage = WalletStorage(default=wallet_dict) storage = WalletStorage(default=wallet_dict)
wallet = Wallet.from_storage(storage, self.manager) wallet = Wallet.from_storage(self.ledger, self.db, storage)
self.assertEqual(wallet.name, 'Main Wallet') self.assertEqual(wallet.name, 'Main Wallet')
self.assertEqual( self.assertEqual(
hexlify(wallet.hash), b'a75913d2e7339c1a9ac0c89d621a4e10fd3a40dc3560dc01f4cf4ada0a0b05b8' hexlify(wallet.hash),
b'3b23aae8cd9b360f4296130b8f7afc5b2437560cdef7237bed245288ce8a5f79'
) )
self.assertEqual(len(wallet.accounts), 1) self.assertEqual(len(wallet.accounts), 1)
account = wallet.default_account account = wallet.default_account
@ -75,9 +141,7 @@ class TestWalletCreation(AsyncioTestCase):
self.assertEqual(decrypted['accounts'][0]['name'], 'An Account') self.assertEqual(decrypted['accounts'][0]['name'], 'An Account')
def test_read_write(self): def test_read_write(self):
manager = WalletManager() manager = WalletManager(self.ledger, self.db)
config = {'data_path': '/tmp/wallet'}
ledger = manager.get_or_create_ledger(Ledger.get_id(), config)
with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file: with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file:
wallet_file.write(b'{"version": 1}') wallet_file.write(b'{"version": 1}')
@ -85,29 +149,29 @@ class TestWalletCreation(AsyncioTestCase):
# create and write wallet to a file # create and write wallet to a file
wallet = manager.import_wallet(wallet_file.name) wallet = manager.import_wallet(wallet_file.name)
account = wallet.generate_account(ledger) account = wallet.generate_account()
wallet.save() wallet.save()
# read wallet from file # read wallet from file
wallet_storage = WalletStorage(wallet_file.name) wallet_storage = WalletStorage(wallet_file.name)
wallet = Wallet.from_storage(wallet_storage, manager) wallet = Wallet.from_storage(self.ledger, self.db, wallet_storage)
self.assertEqual(account.public_key.address, wallet.default_account.public_key.address) self.assertEqual(account.public_key.address, wallet.default_account.public_key.address)
def test_merge(self): def test_merge(self):
wallet1 = Wallet() wallet1 = Wallet(self.ledger, self.db)
wallet1.preferences['one'] = 1 wallet1.preferences['one'] = 1
wallet1.preferences['conflict'] = 1 wallet1.preferences['conflict'] = 1
wallet1.generate_account(self.main_ledger) wallet1.generate_account()
wallet2 = Wallet() wallet2 = Wallet(self.ledger, self.db)
wallet2.preferences['two'] = 2 wallet2.preferences['two'] = 2
wallet2.preferences['conflict'] = 2 # will be more recent wallet2.preferences['conflict'] = 2 # will be more recent
wallet2.generate_account(self.main_ledger) wallet2.generate_account()
self.assertEqual(len(wallet1.accounts), 1) self.assertEqual(len(wallet1.accounts), 1)
self.assertEqual(wallet1.preferences, {'one': 1, 'conflict': 1}) self.assertEqual(wallet1.preferences, {'one': 1, 'conflict': 1})
added = wallet1.merge(self.manager, 'password', wallet2.pack('password')) added = wallet1.merge('password', wallet2.pack('password'))
self.assertEqual(added[0].id, wallet2.default_account.id) self.assertEqual(added[0].id, wallet2.default_account.id)
self.assertEqual(len(wallet1.accounts), 2) self.assertEqual(len(wallet1.accounts), 2)
self.assertEqual(wallet1.accounts[1].id, wallet2.default_account.id) self.assertEqual(wallet1.accounts[1].id, wallet2.default_account.id)