diff --git a/lbry/wallet/header.py b/lbry/wallet/header.py index fc9018562..6094aa017 100644 --- a/lbry/wallet/header.py +++ b/lbry/wallet/header.py @@ -5,7 +5,6 @@ import asyncio import logging import zlib from datetime import date -from concurrent.futures.thread import ThreadPoolExecutor from io import BytesIO from typing import Optional, Iterator, Tuple, Callable @@ -42,7 +41,7 @@ class Headers: validate_difficulty: bool = True def __init__(self, path) -> None: - self.io = BytesIO() + self.io = None self.path = path self._size: Optional[int] = None self.chunk_getter: Optional[Callable] = None @@ -50,6 +49,7 @@ class Headers: self.check_chunk_lock = asyncio.Lock() async def open(self): + self.io = BytesIO() if self.path != ':memory:': if os.path.exists(self.path): with open(self.path, 'r+b') as header_file: @@ -133,16 +133,16 @@ class Headers: except struct.error: raise IndexError(f"failed to get {height}, at {len(self)}") - def estimated_timestamp(self, height): + def estimated_timestamp(self, height, try_real_headers=True): if height <= 0: return - if self.has_header(height): + if try_real_headers and self.has_header(height): offset = height * self.header_size return struct.unpack(' bytes: if self.chunk_getter: diff --git a/tests/unit/wallet/test_database.py b/tests/unit/wallet/test_database.py index 7ceed23d7..b796a9c1a 100644 --- a/tests/unit/wallet/test_database.py +++ b/tests/unit/wallet/test_database.py @@ -211,6 +211,7 @@ class TestQueries(AsyncioTestCase): 'db': Database(':memory:'), 'headers': Headers(':memory:') }) + await self.ledger.headers.open() self.wallet = Wallet() await self.ledger.db.open() diff --git a/tests/unit/wallet/test_headers.py b/tests/unit/wallet/test_headers.py index 578b85512..da433724d 100644 --- a/tests/unit/wallet/test_headers.py +++ b/tests/unit/wallet/test_headers.py @@ -21,8 +21,8 @@ class TestHeaders(AsyncioTestCase): async def test_deserialize(self): self.maxDiff = None h = Headers(':memory:') - h.io.write(HEADERS) await h.open() + await h.connect(0, HEADERS) self.assertEqual(await h.get(0), { 'bits': 520159231, 'block_height': 0, @@ -52,8 +52,11 @@ class TestHeaders(AsyncioTestCase): self.assertEqual(headers.height, 19) async def test_connect_from_middle(self): - h = Headers(':memory:') - h.io.write(HEADERS[:block_bytes(10)]) + headers_temporary_file = tempfile.mktemp() + self.addCleanup(os.remove, headers_temporary_file) + with open(headers_temporary_file, 'w+b') as headers_file: + headers_file.write(HEADERS[:block_bytes(10)]) + h = Headers(headers_temporary_file) await h.open() self.assertEqual(h.height, 9) await h.connect(len(h), HEADERS[block_bytes(10):block_bytes(20)]) @@ -115,6 +118,7 @@ class TestHeaders(AsyncioTestCase): async def test_bounds(self): headers = Headers(':memory:') + await headers.open() await headers.connect(0, HEADERS) self.assertEqual(19, headers.height) with self.assertRaises(IndexError): @@ -126,6 +130,7 @@ class TestHeaders(AsyncioTestCase): async def test_repair(self): headers = Headers(':memory:') + await headers.open() await headers.connect(0, HEADERS[:block_bytes(11)]) self.assertEqual(10, headers.height) await headers.repair() @@ -147,8 +152,9 @@ class TestHeaders(AsyncioTestCase): await headers.repair(start_height=10) self.assertEqual(19, headers.height) - def test_do_not_estimate_unconfirmed(self): + async def test_do_not_estimate_unconfirmed(self): headers = Headers(':memory:') + await headers.open() self.assertIsNone(headers.estimated_timestamp(-1)) self.assertIsNone(headers.estimated_timestamp(0)) self.assertIsNotNone(headers.estimated_timestamp(1)) @@ -164,17 +170,21 @@ class TestHeaders(AsyncioTestCase): self.assertEqual(after_downloading_header_estimated, real_time) async def test_misalignment_triggers_repair_on_open(self): - headers = Headers(':memory:') - headers.io.seek(0) - headers.io.write(HEADERS) + headers_temporary_file = tempfile.mktemp() + self.addCleanup(os.remove, headers_temporary_file) + with open(headers_temporary_file, 'w+b') as headers_file: + headers_file.write(HEADERS) + headers = Headers(headers_temporary_file) with self.assertLogs(level='WARN') as cm: await headers.open() + await headers.close() self.assertEqual(cm.output, []) - headers.io.seek(0) - headers.io.truncate() - headers.io.write(HEADERS[:block_bytes(10)]) - headers.io.write(b'ops') - headers.io.write(HEADERS[block_bytes(10):]) + with open(headers_temporary_file, 'w+b') as headers_file: + headers_file.seek(0) + headers_file.truncate() + headers_file.write(HEADERS[:block_bytes(10)]) + headers_file.write(b'ops') + headers_file.write(HEADERS[block_bytes(10):]) await headers.open() self.assertEqual( cm.output, [ diff --git a/tests/unit/wallet/test_ledger.py b/tests/unit/wallet/test_ledger.py index 0244de987..bfe5cc71b 100644 --- a/tests/unit/wallet/test_ledger.py +++ b/tests/unit/wallet/test_ledger.py @@ -48,6 +48,8 @@ class LedgerTestCase(AsyncioTestCase): 'db': Database(':memory:'), 'headers': Headers(':memory:') }) + self.ledger.headers.checkpoints = {} + await self.ledger.headers.open() self.account = Account.generate(self.ledger, Wallet(), "lbryum") await self.ledger.db.open()