From 713c665588abea30828fb95fd4c743fe9e385fa0 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Fri, 1 May 2020 09:34:34 -0400 Subject: [PATCH] unittests --- tests/unit/blob/test_blob_file.py | 2 +- .../unit/blob_exchange/test_transfer_blob.py | 10 +- tests/unit/comments/test_comment_signing.py | 2 +- tests/unit/crypto/test_bip32.py | 37 ++-- tests/unit/database/test_SQLiteStorage.py | 4 +- tests/unit/lbrynet_daemon/test_Daemon.py | 2 +- tests/unit/schema/test_schema_signing.py | 41 ++--- tests/unit/stream/test_reflector.py | 4 +- tests/unit/stream/test_stream_descriptor.py | 4 +- tests/unit/stream/test_stream_manager.py | 54 ++---- tests/unit/test_event_controller.py | 41 +++-- tests/unit/wallet/test_account.py | 162 +++++------------- tests/unit/wallet/test_coinselection.py | 33 ++-- tests/unit/wallet/test_wallet.py | 116 ++++++++++--- 14 files changed, 224 insertions(+), 288 deletions(-) diff --git a/tests/unit/blob/test_blob_file.py b/tests/unit/blob/test_blob_file.py index ba7923002..77f572d28 100644 --- a/tests/unit/blob/test_blob_file.py +++ b/tests/unit/blob/test_blob_file.py @@ -18,7 +18,7 @@ class TestBlob(AsyncioTestCase): self.tmp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(self.tmp_dir)) self.loop = asyncio.get_running_loop() - self.config = Config() + self.config = Config.with_same_dir(self.tmp_dir) self.storage = SQLiteStorage(self.config, ":memory:", self.loop) self.blob_manager = BlobManager(self.loop, self.tmp_dir, self.storage, self.config) await self.storage.open() diff --git a/tests/unit/blob_exchange/test_transfer_blob.py b/tests/unit/blob_exchange/test_transfer_blob.py index f7c011e3b..9ee46501e 100644 --- a/tests/unit/blob_exchange/test_transfer_blob.py +++ b/tests/unit/blob_exchange/test_transfer_blob.py @@ -33,14 +33,12 @@ class BlobExchangeTestBase(AsyncioTestCase): self.server_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.client_dir) self.addCleanup(shutil.rmtree, self.server_dir) - self.server_config = Config(data_dir=self.server_dir, download_dir=self.server_dir, wallet=self.server_dir, - reflector_servers=[]) + self.server_config = Config.with_same_dir(self.server_dir).set(reflector_servers=[]) self.server_storage = SQLiteStorage(self.server_config, os.path.join(self.server_dir, "lbrynet.sqlite")) self.server_blob_manager = BlobManager(self.loop, self.server_dir, self.server_storage, self.server_config) self.server = BlobServer(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP') - self.client_config = Config(data_dir=self.client_dir, download_dir=self.client_dir, wallet=self.client_dir, - reflector_servers=[]) + self.client_config = Config.with_same_dir(self.client_dir).set(reflector_servers=[]) self.client_storage = SQLiteStorage(self.client_config, os.path.join(self.client_dir, "lbrynet.sqlite")) self.client_blob_manager = BlobManager(self.loop, self.client_dir, self.client_storage, self.client_config) self.client_peer_manager = PeerManager(self.loop) @@ -98,7 +96,7 @@ class TestBlobExchange(BlobExchangeTestBase): second_client_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, second_client_dir) - second_client_conf = Config() + second_client_conf = Config.with_same_dir(second_client_dir) second_client_storage = SQLiteStorage(second_client_conf, os.path.join(second_client_dir, "lbrynet.sqlite")) second_client_blob_manager = BlobManager( self.loop, second_client_dir, second_client_storage, second_client_conf @@ -188,7 +186,7 @@ class TestBlobExchange(BlobExchangeTestBase): second_client_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, second_client_dir) - second_client_conf = Config() + second_client_conf = Config.with_same_dir(second_client_dir) second_client_storage = SQLiteStorage(second_client_conf, os.path.join(second_client_dir, "lbrynet.sqlite")) second_client_blob_manager = BlobManager( diff --git a/tests/unit/comments/test_comment_signing.py b/tests/unit/comments/test_comment_signing.py index 9cdfd3d69..39db92e8d 100644 --- a/tests/unit/comments/test_comment_signing.py +++ b/tests/unit/comments/test_comment_signing.py @@ -3,7 +3,7 @@ import hashlib from lbry.extras.daemon.comment_client import sign_comment from lbry.extras.daemon.comment_client import is_comment_signed_by_channel -from tests.unit.wallet.test_schema_signing import get_stream, get_channel +from unit.schema.test_schema_signing import get_stream, get_channel class TestSigningComments(AsyncioTestCase): diff --git a/tests/unit/crypto/test_bip32.py b/tests/unit/crypto/test_bip32.py index 97d33addc..fadbe96f1 100644 --- a/tests/unit/crypto/test_bip32.py +++ b/tests/unit/crypto/test_bip32.py @@ -1,14 +1,15 @@ +from unittest import TestCase from binascii import unhexlify, hexlify -from lbry.testcase import AsyncioTestCase -from lbry.wallet.bip32 import PubKey, PrivateKey, from_extended_key_string -from lbry.wallet import Ledger, Headers -from lbry.db import Database +from lbry.blockchain.ledger import Ledger +from lbry.crypto.bip32 import PubKey, PrivateKey, from_extended_key_string -from tests.unit.wallet.key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys +from tests.unit.wallet.key_fixtures import ( + expected_ids, expected_privkeys, expected_hardened_privkeys +) -class BIP32Tests(AsyncioTestCase): +class BIP32Tests(TestCase): def test_pubkey_validation(self): with self.assertRaisesRegex(TypeError, 'chain code must be raw bytes'): @@ -41,16 +42,13 @@ class BIP32Tests(AsyncioTestCase): self.assertIsInstance(new_key, PubKey) self.assertEqual(hexlify(new_key.identifier()), expected_ids[i]) - async def test_private_key_validation(self): + def test_private_key_validation(self): with self.assertRaisesRegex(TypeError, 'private key must be raw bytes'): PrivateKey(None, None, b'abcd'*8, 0, 255) with self.assertRaisesRegex(ValueError, 'private key must be 32 bytes'): PrivateKey(None, b'abcd', b'abcd'*8, 0, 255) private_key = PrivateKey( - Ledger({ - 'db': Database('sqlite:///:memory:'), - 'headers': Headers(':memory:'), - }), + Ledger(), unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'), b'abcd'*8, 0, 1 ) @@ -66,12 +64,9 @@ class BIP32Tests(AsyncioTestCase): private_key.child(-1) self.assertIsInstance(private_key.child(PrivateKey.HARDENED), PrivateKey) - async def test_private_key_derivation(self): + def test_private_key_derivation(self): private_key = PrivateKey( - Ledger({ - 'db': Database('sqlite:///:memory:'), - 'headers': Headers(':memory:'), - }), + Ledger(), unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'), b'abcd'*8, 0, 1 ) @@ -84,21 +79,17 @@ class BIP32Tests(AsyncioTestCase): self.assertIsInstance(new_privkey, PrivateKey) self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED]) - async def test_from_extended_keys(self): - ledger = Ledger({ - 'db': Database('sqlite:///:memory:'), - 'headers': Headers(':memory:'), - }) + def test_from_extended_keys(self): self.assertIsInstance( from_extended_key_string( - ledger, + Ledger(), 'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P' '6yz3jMbycrLrRMpeAJxR8qDg8', ), PrivateKey ) self.assertIsInstance( from_extended_key_string( - ledger, + Ledger(), 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' 'iW44g14WF52fYC5J483wqQ5ZP', ), PubKey diff --git a/tests/unit/database/test_SQLiteStorage.py b/tests/unit/database/test_SQLiteStorage.py index 07063aa7f..f0b70d32c 100644 --- a/tests/unit/database/test_SQLiteStorage.py +++ b/tests/unit/database/test_SQLiteStorage.py @@ -70,9 +70,9 @@ fake_claim_info = { class StorageTest(AsyncioTestCase): async def asyncSetUp(self): - self.conf = Config() - self.storage = SQLiteStorage(self.conf, ':memory:') self.blob_dir = tempfile.mkdtemp() + self.conf = Config.with_same_dir(self.blob_dir) + self.storage = SQLiteStorage(self.conf, ':memory:') self.addCleanup(shutil.rmtree, self.blob_dir) self.blob_manager = BlobManager(asyncio.get_event_loop(), self.blob_dir, self.storage, self.conf) await self.storage.open() diff --git a/tests/unit/lbrynet_daemon/test_Daemon.py b/tests/unit/lbrynet_daemon/test_Daemon.py index 6cbd57b04..896dfb824 100644 --- a/tests/unit/lbrynet_daemon/test_Daemon.py +++ b/tests/unit/lbrynet_daemon/test_Daemon.py @@ -10,7 +10,7 @@ from lbry.extras.daemon.components import HASH_ANNOUNCER_COMPONENT from lbry.extras.daemon.components import UPNP_COMPONENT, BLOB_COMPONENT from lbry.extras.daemon.components import PEER_PROTOCOL_SERVER_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT from lbry.extras.daemon.daemon import Daemon as LBRYDaemon -from lbry.wallet import WalletManager, Wallet +from lbry.wallet.manager import WalletManager, Wallet from tests import test_utils # from tests.mocks import mock_conf_settings, FakeNetwork, FakeFileManager diff --git a/tests/unit/schema/test_schema_signing.py b/tests/unit/schema/test_schema_signing.py index 4d4dccf7f..80b3cb949 100644 --- a/tests/unit/schema/test_schema_signing.py +++ b/tests/unit/schema/test_schema_signing.py @@ -1,41 +1,25 @@ +from unittest import TestCase from binascii import unhexlify -from lbry.testcase import AsyncioTestCase -from lbry.wallet.constants import CENT, NULL_HASH32 - -from lbry.wallet import Ledger, Headers, Transaction, Input, Output -from lbry.db import Database +from lbry.testcase import get_transaction +from lbry.blockchain.ledger import Ledger +from lbry.blockchain.transaction import Transaction, Output +from lbry.constants import CENT from lbry.schema.claim import Claim -def get_output(amount=CENT, pubkey_hash=NULL_HASH32): - return Transaction() \ - .add_outputs([Output.pay_pubkey_hash(amount, pubkey_hash)]) \ - .outputs[0] - - -def get_input(): - return Input.spend(get_output()) - - -def get_tx(): - return Transaction().add_inputs([get_input()]) - - def get_channel(claim_name='@foo'): channel_txo = Output.pay_claim_name_pubkey_hash(CENT, claim_name, Claim(), b'abc') channel_txo.generate_channel_private_key() - get_tx().add_outputs([channel_txo]) - return channel_txo + return get_transaction(channel_txo).outputs[0] def get_stream(claim_name='foo'): stream_txo = Output.pay_claim_name_pubkey_hash(CENT, claim_name, Claim(), b'abc') - get_tx().add_outputs([stream_txo]) - return stream_txo + return get_transaction(stream_txo).outputs[0] -class TestSigningAndValidatingClaim(AsyncioTestCase): +class TestSigningAndValidatingClaim(TestCase): def test_successful_create_sign_and_validate(self): channel = get_channel() @@ -65,7 +49,7 @@ class TestSigningAndValidatingClaim(AsyncioTestCase): self.assertFalse(channel.is_channel_private_key(get_channel().private_key)) -class TestValidatingOldSignatures(AsyncioTestCase): +class TestValidatingOldSignatures(TestCase): def test_signed_claim_made_by_ytsync(self): stream_tx = Transaction(unhexlify( @@ -109,9 +93,4 @@ class TestValidatingOldSignatures(AsyncioTestCase): )) channel = channel_tx.outputs[0] - ledger = Ledger({ - 'db': Database(':memory:'), - 'headers': Headers(':memory:') - }) - - self.assertTrue(stream.is_signed_by(channel, ledger)) + self.assertTrue(stream.is_signed_by(channel, Ledger())) diff --git a/tests/unit/stream/test_reflector.py b/tests/unit/stream/test_reflector.py index 8c228f92c..ec3270c44 100644 --- a/tests/unit/stream/test_reflector.py +++ b/tests/unit/stream/test_reflector.py @@ -18,7 +18,7 @@ class TestStreamAssembler(AsyncioTestCase): tmp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(tmp_dir)) - self.conf = Config() + self.conf = Config.with_same_dir(tmp_dir) self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite")) await self.storage.open() self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf) @@ -26,7 +26,7 @@ class TestStreamAssembler(AsyncioTestCase): server_tmp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(server_tmp_dir)) - self.server_conf = Config() + self.server_conf = Config.with_same_dir(server_tmp_dir) self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite")) await self.server_storage.open() self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf) diff --git a/tests/unit/stream/test_stream_descriptor.py b/tests/unit/stream/test_stream_descriptor.py index b46012711..73bcbeb72 100644 --- a/tests/unit/stream/test_stream_descriptor.py +++ b/tests/unit/stream/test_stream_descriptor.py @@ -20,7 +20,7 @@ class TestStreamDescriptor(AsyncioTestCase): self.cleartext = os.urandom(20000000) self.tmp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(self.tmp_dir)) - self.conf = Config() + self.conf = Config.with_same_dir(self.tmp_dir) self.storage = SQLiteStorage(self.conf, ":memory:") await self.storage.open() self.blob_manager = BlobManager(self.loop, self.tmp_dir, self.storage, self.conf) @@ -93,7 +93,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase): loop = asyncio.get_event_loop() tmp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(tmp_dir)) - self.conf = Config() + self.conf = Config.with_same_dir(tmp_dir) storage = SQLiteStorage(self.conf, ":memory:") await storage.open() blob_manager = BlobManager(loop, tmp_dir, storage, self.conf) diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py index aeda3ec90..958ca7739 100644 --- a/tests/unit/stream/test_stream_manager.py +++ b/tests/unit/stream/test_stream_manager.py @@ -10,11 +10,10 @@ from lbry.testcase import get_fake_exchange_rate_manager from lbry.utils import generate_id from lbry.error import InsufficientFundsError from lbry.error import KeyFeeAboveMaxAllowedError, ResolveError, DownloadSDTimeoutError, DownloadDataTimeoutError -from lbry.wallet import WalletManager, Wallet, Ledger, Transaction, Input, Output -from lbry.wallet.constants import CENT, NULL_HASH32 -from lbry.wallet.network import ClientSession -from lbry.db import Database -from lbry.conf import Config +from lbry.blockchain.ledger import Ledger +from lbry.blockchain.transaction import Transaction, Input, Output +from lbry.constants import CENT, NULL_HASH32 +from lbry.service.full_node import FullNode from lbry.extras.daemon.analytics import AnalyticsManager from lbry.stream.stream_manager import StreamManager from lbry.stream.descriptor import StreamDescriptor @@ -64,7 +63,7 @@ def get_claim_transaction(claim_name, claim=b''): ) -async def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None): +async def get_mock_wallet(sd_hash, storage, conf, balance=10.0, fee=None): claim = Claim() if fee: if fee['currency'] == 'LBC': @@ -84,45 +83,26 @@ async def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None): }) - class FakeHeaders: - def estimated_timestamp(self, height): - return 1984 - - def __init__(self, height): - self.height = height - - def __getitem__(self, item): - return {'timestamp': 1984} - - wallet = Wallet() - ledger = Ledger({ - 'db': Database('sqlite:///:memory:'), - 'headers': FakeHeaders(514082) - }) - await ledger.db.open() - wallet.generate_account(ledger) - manager = WalletManager() - manager.config = Config() - manager.wallets.append(wallet) - manager.ledgers[Ledger] = ledger - manager.ledger.network.client = ClientSession( - network=manager.ledger.network, server=('fakespv.lbry.com', 50001) - ) + service = FullNode(Ledger(conf), 'sqlite:///:memory:') + await service.db.open() + await service.wallet_manager.open() async def mock_resolve(*args, **kwargs): result = {txo.meta['permanent_url']: txo} claims = [ - StreamManager._convert_to_old_resolve_output(manager, result)[txo.meta['permanent_url']] + StreamManager._convert_to_old_resolve_output( + service.wallet_manager, result + )[txo.meta['permanent_url']] ] await storage.save_claims(claims) return result - manager.ledger.resolve = mock_resolve + service.resolve = mock_resolve async def get_balance(*_): return balance - manager.get_balance = get_balance + service.get_balance = get_balance - return manager, txo.meta['permanent_url'] + return service, txo.meta['permanent_url'] class TestStreamManager(BlobExchangeTestBase): @@ -138,8 +118,10 @@ class TestStreamManager(BlobExchangeTestBase): self.loop, self.server_blob_manager.blob_dir, file_path, old_sort=old_sort ) self.sd_hash = descriptor.sd_hash - self.mock_wallet, self.uri = await get_mock_wallet(self.sd_hash, self.client_storage, balance, fee) - self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.mock_wallet, + self.service, self.uri = await get_mock_wallet( + self.sd_hash, self.client_storage, self.client_config, balance, fee + ) + self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.service, self.client_storage, get_mock_node(self.server_from_client), AnalyticsManager(self.client_config, binascii.hexlify(generate_id()).decode(), diff --git a/tests/unit/test_event_controller.py b/tests/unit/test_event_controller.py index 4cd11e64a..213899414 100644 --- a/tests/unit/test_event_controller.py +++ b/tests/unit/test_event_controller.py @@ -1,25 +1,42 @@ -from lbry.wallet.stream import StreamController -from lbry.wallet.tasks import TaskGroup from lbry.testcase import AsyncioTestCase +from lbry.event import EventController +from lbry.tasks import TaskGroup class StreamControllerTestCase(AsyncioTestCase): - def test_non_unique_events(self): + + async def test_non_unique_events(self): events = [] - controller = StreamController() - controller.stream.listen(on_data=events.append) - controller.add("yo") - controller.add("yo") + controller = EventController() + controller.stream.listen(events.append) + await controller.add("yo") + await controller.add("yo") self.assertListEqual(events, ["yo", "yo"]) - def test_unique_events(self): + async def test_unique_events(self): events = [] - controller = StreamController(merge_repeated_events=True) - controller.stream.listen(on_data=events.append) - controller.add("yo") - controller.add("yo") + controller = EventController(merge_repeated_events=True) + controller.stream.listen(events.append) + await controller.add("yo") + await controller.add("yo") self.assertListEqual(events, ["yo"]) + async def test_sync_listener_errors(self): + def bad_listener(e): + raise ValueError('bad') + controller = EventController() + controller.stream.listen(bad_listener) + with self.assertRaises(ValueError): + await controller.add("yo") + + async def test_async_listener_errors(self): + async def bad_listener(e): + raise ValueError('bad') + controller = EventController() + controller.stream.listen(bad_listener) + with self.assertRaises(ValueError): + await controller.add("yo") + class TaskGroupTestCase(AsyncioTestCase): diff --git a/tests/unit/wallet/test_account.py b/tests/unit/wallet/test_account.py index a494cfe33..91dbf4564 100644 --- a/tests/unit/wallet/test_account.py +++ b/tests/unit/wallet/test_account.py @@ -1,31 +1,34 @@ -from binascii import hexlify from lbry.testcase import AsyncioTestCase -from lbry.wallet import Wallet, Ledger, Headers, Account, SingleKey, HierarchicalDeterministic -from lbry.db import Database + +from lbry.blockchain.ledger import Ledger +from lbry.wallet.account import Account, SingleKey, HierarchicalDeterministic +from lbry.db import Database, PubkeyAddress -class TestAccount(AsyncioTestCase): +class AccountTestCase(AsyncioTestCase): async def asyncSetUp(self): - self.ledger = Ledger({ - 'db': Database('sqlite:///:memory:'), - 'headers': Headers(':memory:') - }) - await self.ledger.db.open() + self.ledger = Ledger() + self.db = Database(self.ledger, 'sqlite:///:memory:') + await self.db.open() + self.addCleanup(self.db.close) - async def asyncTearDown(self): - await self.ledger.db.close() + async def update_addressed_used(self, address, used): + await self.db.execute( + PubkeyAddress.update() + .where(PubkeyAddress.c.address == address) + .values(used_times=used) + ) + + +class TestAccount(AccountTestCase): async def test_generate_account(self): - account = Account.generate(self.ledger, Wallet(), 'lbryum') - self.assertEqual(account.ledger, self.ledger) + account = Account.generate(self.ledger, self.db, 'lbryum') self.assertIsNotNone(account.seed) self.assertEqual(account.public_key.ledger, self.ledger) self.assertEqual(account.private_key.public_key, account.public_key) - self.assertEqual(account.public_key.ledger, self.ledger) - self.assertEqual(account.private_key.public_key, account.public_key) - addresses = await account.receiving.get_addresses() self.assertEqual(len(addresses), 0) addresses = await account.change.get_addresses() @@ -39,14 +42,14 @@ class TestAccount(AsyncioTestCase): self.assertEqual(len(addresses), 6) async def test_generate_keys_over_batch_threshold_saves_it_properly(self): - account = Account.generate(self.ledger, Wallet(), 'lbryum') + account = Account.generate(self.ledger, self.db, 'lbryum') async with account.receiving.address_generator_lock: await account.receiving._generate_keys(0, 200) records = await account.receiving.get_address_records() self.assertEqual(len(records), 201) async def test_ensure_address_gap(self): - account = Account.generate(self.ledger, Wallet(), 'lbryum') + account = Account.generate(self.ledger, self.db, 'lbryum') self.assertIsInstance(account.receiving, HierarchicalDeterministic) @@ -75,17 +78,17 @@ class TestAccount(AsyncioTestCase): # case #2: only one new addressed needed records = await account.receiving.get_address_records() - await self.ledger.db.set_address_history(records[0]['address'], 'a:1:') + await self.update_addressed_used(records[0]['address'], 1) new_keys = await account.receiving.ensure_address_gap() self.assertEqual(len(new_keys), 1) # case #3: 20 addresses needed - await self.ledger.db.set_address_history(new_keys[0], 'a:1:') + await self.update_addressed_used(new_keys[0], 1) new_keys = await account.receiving.ensure_address_gap() self.assertEqual(len(new_keys), 20) async def test_get_or_create_usable_address(self): - account = Account.generate(self.ledger, Wallet(), 'lbryum') + account = Account.generate(self.ledger, self.db, 'lbryum') keys = await account.receiving.get_addresses() self.assertEqual(len(keys), 0) @@ -98,7 +101,7 @@ class TestAccount(AsyncioTestCase): async def test_generate_account_from_seed(self): account = Account.from_dict( - self.ledger, Wallet(), { + self.ledger, self.db, { "seed": "carbon smart garage balance margin twelve chest sword toas" "t envelope bottom stomach absent" @@ -117,19 +120,6 @@ class TestAccount(AsyncioTestCase): address = await account.receiving.ensure_address_gap() self.assertEqual(address[0], 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx') - private_key = await self.ledger.get_private_key_for_address( - account.wallet, 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx' - ) - self.assertEqual( - private_key.extended_key_string(), - 'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR' - 'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT' - ) - private_key = await self.ledger.get_private_key_for_address( - account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX' - ) - self.assertIsNone(private_key) - async def test_load_and_save_account(self): account_data = { 'name': 'Main Account', @@ -152,7 +142,7 @@ class TestAccount(AsyncioTestCase): } } - account = Account.from_dict(self.ledger, Wallet(), account_data) + account = Account.from_dict(self.ledger, self.db, account_data) await account.ensure_address_gap() @@ -160,27 +150,8 @@ class TestAccount(AsyncioTestCase): self.assertEqual(len(addresses), 17) addresses = await account.change.get_addresses() self.assertEqual(len(addresses), 10) - - account_data['ledger'] = 'lbc_mainnet' self.assertDictEqual(account_data, account.to_dict()) - async def test_save_max_gap(self): - account = Account.generate( - self.ledger, Wallet(), 'lbryum', { - 'name': 'deterministic-chain', - 'receiving': {'gap': 3, 'maximum_uses_per_address': 2}, - 'change': {'gap': 4, 'maximum_uses_per_address': 2} - } - ) - self.assertEqual(account.receiving.gap, 3) - self.assertEqual(account.change.gap, 4) - await account.save_max_gap() - self.assertEqual(account.receiving.gap, 20) - self.assertEqual(account.change.gap, 6) - # doesn't fail for single-address account - account2 = Account.generate(self.ledger, Wallet(), 'lbryum', {'name': 'single-address'}) - await account2.save_max_gap() - def test_merge_diff(self): account_data = { 'name': 'My Account', @@ -201,7 +172,7 @@ class TestAccount(AsyncioTestCase): 'change': {'gap': 5, 'maximum_uses_per_address': 2} } } - account = Account.from_dict(self.ledger, Wallet(), account_data) + account = Account.from_dict(self.ledger, self.db, account_data) self.assertEqual(account.name, 'My Account') self.assertEqual(account.modified_on, 123.456) @@ -230,18 +201,13 @@ class TestAccount(AsyncioTestCase): self.assertEqual(account.receiving.maximum_uses_per_address, 9) -class TestSingleKeyAccount(AsyncioTestCase): +class TestSingleKeyAccount(AccountTestCase): async def asyncSetUp(self): - self.ledger = Ledger({ - 'db': Database('sqlite:///:memory:'), - 'headers': Headers(':memory:') - }) - await self.ledger.db.open() - self.account = Account.generate(self.ledger, Wallet(), "torba", {'name': 'single-address'}) - - async def asyncTearDown(self): - await self.ledger.db.close() + await super().asyncSetUp() + self.account = Account.generate( + self.ledger, self.db, "torba", {'name': 'single-address'} + ) async def test_generate_account(self): account = self.account @@ -286,7 +252,6 @@ class TestSingleKeyAccount(AsyncioTestCase): 'chain': 0, 'account': account.public_key.address, 'address': account.public_key.address, - 'history': None, 'used_times': 0 }]) self.assertEqual( @@ -300,7 +265,7 @@ class TestSingleKeyAccount(AsyncioTestCase): # case #2: after use, still no new address needed records = await account.receiving.get_address_records() - await self.ledger.db.set_address_history(records[0]['address'], 'a:1:') + await self.update_addressed_used(records[0]['address'], 1) empty = await account.receiving.ensure_address_gap() self.assertEqual(len(empty), 0) @@ -313,7 +278,7 @@ class TestSingleKeyAccount(AsyncioTestCase): address1 = await account.receiving.get_or_create_usable_address() self.assertIsNotNone(address1) - await self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:') + await self.update_addressed_used(address1, 3) records = await account.receiving.get_address_records() self.assertEqual(records[0]['used_times'], 3) @@ -323,47 +288,6 @@ class TestSingleKeyAccount(AsyncioTestCase): keys = await account.receiving.get_addresses() self.assertEqual(len(keys), 1) - async def test_generate_account_from_seed(self): - account = Account.from_dict( - self.ledger, Wallet(), { - "seed": - "carbon smart garage balance margin twelve chest sword toas" - "t envelope bottom stomach absent", - 'address_generator': {'name': 'single-address'} - } - ) - self.assertEqual( - account.private_key.extended_key_string(), - 'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7' - 'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe', - ) - self.assertEqual( - account.public_key.extended_key_string(), - 'xpub661MyMwAqRbcGWtPvbWh9sc2BCfw2cTeVDYF23o3N1t6UZ5wv3EM' - 'mDgp66FxHuDtWdft3B5eL5xQtyzAtkdmhhC95gjRjLzSTdkho95asu9', - ) - address = await account.receiving.ensure_address_gap() - self.assertEqual(address[0], account.public_key.address) - - private_key = await self.ledger.get_private_key_for_address( - account.wallet, address[0] - ) - self.assertEqual( - private_key.extended_key_string(), - 'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7' - 'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe', - ) - - invalid_key = await self.ledger.get_private_key_for_address( - account.wallet, 'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX' - ) - self.assertIsNone(invalid_key) - - self.assertEqual( - hexlify(private_key.wif()), - b'1cef6c80310b1bcbcfa3176ea809ac840f48cda634c475d402e6bd68d5bb3827d601' - ) - async def test_load_and_save_account(self): account_data = { 'name': 'My Account', @@ -380,7 +304,7 @@ class TestSingleKeyAccount(AsyncioTestCase): 'certificates': {} } - account = Account.from_dict(self.ledger, Wallet(), account_data) + account = Account.from_dict(self.ledger, self.db, account_data) await account.ensure_address_gap() @@ -390,11 +314,11 @@ class TestSingleKeyAccount(AsyncioTestCase): self.assertEqual(len(addresses), 1) self.maxDiff = None - account_data['ledger'] = 'lbc_mainnet' self.assertDictEqual(account_data, account.to_dict()) -class AccountEncryptionTests(AsyncioTestCase): +class AccountEncryptionTests(AccountTestCase): + password = "password" init_vector = b'0000000000000000' unencrypted_account = { @@ -428,14 +352,8 @@ class AccountEncryptionTests(AsyncioTestCase): 'address_generator': {'name': 'single-address'} } - async def asyncSetUp(self): - self.ledger = Ledger({ - 'db': Database(':memory:'), - 'headers': Headers(':memory:') - }) - def test_encrypt_wallet(self): - account = Account.from_dict(self.ledger, Wallet(), self.unencrypted_account) + account = Account.from_dict(self.ledger, self.db, self.unencrypted_account) account.init_vectors = { 'seed': self.init_vector, 'private_key': self.init_vector @@ -465,7 +383,7 @@ class AccountEncryptionTests(AsyncioTestCase): self.assertFalse(account.encrypted) def test_decrypt_wallet(self): - account = Account.from_dict(self.ledger, Wallet(), self.encrypted_account) + account = Account.from_dict(self.ledger, self.db, self.encrypted_account) self.assertTrue(account.encrypted) account.decrypt(self.password) @@ -486,7 +404,7 @@ class AccountEncryptionTests(AsyncioTestCase): account_data = self.unencrypted_account.copy() del account_data['seed'] del account_data['private_key'] - account = Account.from_dict(self.ledger, Wallet(), account_data) + account = Account.from_dict(self.ledger, self.db, account_data) encrypted = account.to_dict('password') self.assertFalse(encrypted['seed']) self.assertFalse(encrypted['private_key']) diff --git a/tests/unit/wallet/test_coinselection.py b/tests/unit/wallet/test_coinselection.py index 2faceda89..40cbf1684 100644 --- a/tests/unit/wallet/test_coinselection.py +++ b/tests/unit/wallet/test_coinselection.py @@ -1,16 +1,10 @@ +from unittest import TestCase from types import GeneratorType -from lbry.testcase import AsyncioTestCase - -from lbry.wallet import Ledger, Headers -from lbry.db import Database -from lbry.wallet.coinselection import CoinSelector, MAXIMUM_TRIES +from lbry.blockchain.ledger import RegTestLedger +from lbry.wallet.coinselection import CoinSelector, OutputEffectiveAmountEstimator, MAXIMUM_TRIES from lbry.constants import CENT - -from tests.unit.wallet.test_transaction import get_output as utxo - - -NULL_HASH = b'\x00'*32 +from lbry.testcase import get_output as utxo def search(*args, **kwargs): @@ -18,21 +12,14 @@ def search(*args, **kwargs): return [o.txo.amount for o in selection] if selection else selection -class BaseSelectionTestCase(AsyncioTestCase): +class BaseSelectionTestCase(TestCase): - async def asyncSetUp(self): - self.ledger = Ledger({ - 'db': Database('sqlite:///:memory:'), - 'headers': Headers(':memory:'), - }) - await self.ledger.db.open() - - async def asyncTearDown(self): - await self.ledger.db.close() + def setUp(self): + self.ledger = RegTestLedger() def estimates(self, *args): txos = args[0] if isinstance(args[0], (GeneratorType, list)) else args - return [txo.get_estimator(self.ledger) for txo in txos] + return [OutputEffectiveAmountEstimator(self.ledger, txo) for txo in txos] class TestCoinSelectionTests(BaseSelectionTestCase): @@ -41,7 +28,7 @@ class TestCoinSelectionTests(BaseSelectionTestCase): self.assertListEqual(CoinSelector(0, 0).select([]), []) def test_skip_binary_search_if_total_not_enough(self): - fee = utxo(CENT).get_estimator(self.ledger).fee + fee = OutputEffectiveAmountEstimator(self.ledger, utxo(CENT)).fee big_pool = self.estimates(utxo(CENT+fee) for _ in range(100)) selector = CoinSelector(101 * CENT, 0) self.assertListEqual(selector.select(big_pool), []) @@ -52,7 +39,7 @@ class TestCoinSelectionTests(BaseSelectionTestCase): self.assertEqual(selector.tries, 201) def test_exact_match(self): - fee = utxo(CENT).get_estimator(self.ledger).fee + fee = OutputEffectiveAmountEstimator(self.ledger, utxo(CENT)).fee utxo_pool = self.estimates( utxo(CENT + fee), utxo(CENT), diff --git a/tests/unit/wallet/test_wallet.py b/tests/unit/wallet/test_wallet.py index dc6e4d3e8..f2274d162 100644 --- a/tests/unit/wallet/test_wallet.py +++ b/tests/unit/wallet/test_wallet.py @@ -1,30 +1,95 @@ import tempfile from binascii import hexlify - from unittest import TestCase, mock + from lbry.testcase import AsyncioTestCase -from lbry.wallet import ( - Ledger, RegTestLedger, WalletManager, Account, - Wallet, WalletStorage, TimestampedPreferences +from lbry.db import Database +from lbry.blockchain.ledger import Ledger +from lbry.wallet.manager import WalletManager +from lbry.wallet.wallet import ( + Account, Wallet, WalletStorage, TimestampedPreferences ) -class TestWalletCreation(AsyncioTestCase): +class WalletTestCase(AsyncioTestCase): async def asyncSetUp(self): - self.manager = WalletManager() - config = {'data_path': '/tmp/wallet'} - self.main_ledger = self.manager.get_or_create_ledger(Ledger.get_id(), config) - self.test_ledger = self.manager.get_or_create_ledger(RegTestLedger.get_id(), config) + self.ledger = Ledger() + self.db = Database(self.ledger, 'sqlite:///:memory:') + await self.db.open() + self.addCleanup(self.db.close) + + +class WalletAccountTest(WalletTestCase): + + async def test_private_key_for_hierarchical_account(self): + wallet = Wallet(self.ledger, self.db) + account = wallet.add_account({ + "seed": + "carbon smart garage balance margin twelve chest sword toas" + "t envelope bottom stomach absent" + }) + await account.receiving.ensure_address_gap() + private_key = await wallet.get_private_key_for_address( + 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx' + ) + self.assertEqual( + private_key.extended_key_string(), + 'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR' + 'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT' + ) + self.assertIsNone( + await wallet.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') + ) + + async def test_private_key_for_single_address_account(self): + wallet = Wallet(self.ledger, self.db) + account = wallet.add_account({ + "seed": + "carbon smart garage balance margin twelve chest sword toas" + "t envelope bottom stomach absent", + 'address_generator': {'name': 'single-address'} + }) + address = await account.receiving.ensure_address_gap() + private_key = await wallet.get_private_key_for_address(address[0]) + self.assertEqual( + private_key.extended_key_string(), + 'xprv9s21ZrQH143K42ovpZygnjfHdAqSd9jo7zceDfPRogM7bkkoNVv7' + 'DRNLEoB8HoirMgH969NrgL8jNzLEegqFzPRWM37GXd4uE8uuRkx4LAe', + ) + self.assertIsNone( + await wallet.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') + ) + + async def test_save_max_gap(self): + wallet = Wallet(self.ledger, self.db) + account = wallet.generate_account( + 'lbryum', { + 'name': 'deterministic-chain', + 'receiving': {'gap': 3, 'maximum_uses_per_address': 2}, + 'change': {'gap': 4, 'maximum_uses_per_address': 2} + } + ) + self.assertEqual(account.receiving.gap, 3) + self.assertEqual(account.change.gap, 4) + await wallet.save_max_gap() + self.assertEqual(account.receiving.gap, 20) + self.assertEqual(account.change.gap, 6) + # doesn't fail for single-address account + wallet.generate_account('lbryum', {'name': 'single-address'}) + await wallet.save_max_gap() + + +class TestWalletCreation(WalletTestCase): def test_create_wallet_and_accounts(self): - wallet = Wallet() + wallet = Wallet(self.ledger, self.db) self.assertEqual(wallet.name, 'Wallet') self.assertListEqual(wallet.accounts, []) - account1 = wallet.generate_account(self.main_ledger) - wallet.generate_account(self.main_ledger) - wallet.generate_account(self.test_ledger) + account1 = wallet.generate_account() + wallet.generate_account() + wallet.generate_account() self.assertEqual(wallet.default_account, account1) self.assertEqual(len(wallet.accounts), 3) @@ -32,12 +97,12 @@ class TestWalletCreation(AsyncioTestCase): wallet_dict = { 'version': 1, 'name': 'Main Wallet', + 'ledger': 'lbc_mainnet', 'preferences': {}, 'accounts': [ { 'certificates': {}, 'name': 'An Account', - 'ledger': 'lbc_mainnet', 'modified_on': 123.456, 'seed': "carbon smart garage balance margin twelve chest sword toast envelope bottom stomac" @@ -59,10 +124,11 @@ class TestWalletCreation(AsyncioTestCase): } storage = WalletStorage(default=wallet_dict) - wallet = Wallet.from_storage(storage, self.manager) + wallet = Wallet.from_storage(self.ledger, self.db, storage) self.assertEqual(wallet.name, 'Main Wallet') self.assertEqual( - hexlify(wallet.hash), b'a75913d2e7339c1a9ac0c89d621a4e10fd3a40dc3560dc01f4cf4ada0a0b05b8' + hexlify(wallet.hash), + b'3b23aae8cd9b360f4296130b8f7afc5b2437560cdef7237bed245288ce8a5f79' ) self.assertEqual(len(wallet.accounts), 1) account = wallet.default_account @@ -75,9 +141,7 @@ class TestWalletCreation(AsyncioTestCase): self.assertEqual(decrypted['accounts'][0]['name'], 'An Account') def test_read_write(self): - manager = WalletManager() - config = {'data_path': '/tmp/wallet'} - ledger = manager.get_or_create_ledger(Ledger.get_id(), config) + manager = WalletManager(self.ledger, self.db) with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file: wallet_file.write(b'{"version": 1}') @@ -85,29 +149,29 @@ class TestWalletCreation(AsyncioTestCase): # create and write wallet to a file wallet = manager.import_wallet(wallet_file.name) - account = wallet.generate_account(ledger) + account = wallet.generate_account() wallet.save() # read wallet from file wallet_storage = WalletStorage(wallet_file.name) - wallet = Wallet.from_storage(wallet_storage, manager) + wallet = Wallet.from_storage(self.ledger, self.db, wallet_storage) self.assertEqual(account.public_key.address, wallet.default_account.public_key.address) def test_merge(self): - wallet1 = Wallet() + wallet1 = Wallet(self.ledger, self.db) wallet1.preferences['one'] = 1 wallet1.preferences['conflict'] = 1 - wallet1.generate_account(self.main_ledger) - wallet2 = Wallet() + wallet1.generate_account() + wallet2 = Wallet(self.ledger, self.db) wallet2.preferences['two'] = 2 wallet2.preferences['conflict'] = 2 # will be more recent - wallet2.generate_account(self.main_ledger) + wallet2.generate_account() self.assertEqual(len(wallet1.accounts), 1) self.assertEqual(wallet1.preferences, {'one': 1, 'conflict': 1}) - added = wallet1.merge(self.manager, 'password', wallet2.pack('password')) + added = wallet1.merge('password', wallet2.pack('password')) self.assertEqual(added[0].id, wallet2.default_account.id) self.assertEqual(len(wallet1.accounts), 2) self.assertEqual(wallet1.accounts[1].id, wallet2.default_account.id)