Merge pull request #2922 from lbryio/estimate_on_demand

estimate block timestamps on client only when necessary (header hasn't been downloaded yet)
This commit is contained in:
Lex Berezhny 2020-04-27 12:17:54 -04:00 committed by GitHub
commit 94c45cf2a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 59 deletions

View file

@ -5,7 +5,6 @@ import asyncio
import logging import logging
import zlib import zlib
from datetime import date from datetime import date
from concurrent.futures.thread import ThreadPoolExecutor
from io import BytesIO from io import BytesIO
from typing import Optional, Iterator, Tuple, Callable from typing import Optional, Iterator, Tuple, Callable
@ -42,23 +41,22 @@ class Headers:
validate_difficulty: bool = True validate_difficulty: bool = True
def __init__(self, path) -> None: def __init__(self, path) -> None:
if path == ':memory:': self.io = None
self.io = BytesIO()
self.path = path self.path = path
self._size: Optional[int] = None self._size: Optional[int] = None
self.chunk_getter: Optional[Callable] = None self.chunk_getter: Optional[Callable] = None
self.executor = ThreadPoolExecutor(1)
self.known_missing_checkpointed_chunks = set() self.known_missing_checkpointed_chunks = set()
self.check_chunk_lock = asyncio.Lock() self.check_chunk_lock = asyncio.Lock()
async def open(self): async def open(self):
if not self.executor: self.io = BytesIO()
self.executor = ThreadPoolExecutor(1)
if self.path != ':memory:': if self.path != ':memory:':
if not os.path.exists(self.path): def _readit():
self.io = open(self.path, 'w+b') if os.path.exists(self.path):
else: with open(self.path, 'r+b') as header_file:
self.io = open(self.path, 'r+b') self.io.seek(0)
self.io.write(header_file.read())
await asyncio.get_event_loop().run_in_executor(None, _readit)
bytes_size = self.io.seek(0, os.SEEK_END) bytes_size = self.io.seek(0, os.SEEK_END)
self._size = bytes_size // self.header_size self._size = bytes_size // self.header_size
max_checkpointed_height = max(self.checkpoints.keys() or [-1]) + 1000 max_checkpointed_height = max(self.checkpoints.keys() or [-1]) + 1000
@ -72,10 +70,14 @@ class Headers:
await self.get_all_missing_headers() await self.get_all_missing_headers()
async def close(self): async def close(self):
if self.executor: if self.io is not None:
self.executor.shutdown() def _close():
self.executor = None flags = 'r+b' if os.path.exists(self.path) else 'w+b'
self.io.close() with open(self.path, flags) as header_file:
header_file.write(self.io.getbuffer())
await asyncio.get_event_loop().run_in_executor(None, _close)
self.io.close()
self.io = None
@staticmethod @staticmethod
def serialize(header): def serialize(header):
@ -135,28 +137,30 @@ class Headers:
except struct.error: except struct.error:
raise IndexError(f"failed to get {height}, at {len(self)}") raise IndexError(f"failed to get {height}, at {len(self)}")
def estimated_timestamp(self, height): def estimated_timestamp(self, height, try_real_headers=True):
if height <= 0: if height <= 0:
return return
if try_real_headers and self.has_header(height):
offset = height * self.header_size
return struct.unpack('<I', self.io.getbuffer()[offset + 100: offset + 104])[0]
return int(self.first_block_timestamp + (height * self.timestamp_average_offset)) return int(self.first_block_timestamp + (height * self.timestamp_average_offset))
def estimated_julian_day(self, height): def estimated_julian_day(self, height):
return date_to_julian_day(date.fromtimestamp(self.estimated_timestamp(height))) return date_to_julian_day(date.fromtimestamp(self.estimated_timestamp(height, False)))
async def get_raw_header(self, height) -> bytes: async def get_raw_header(self, height) -> bytes:
if self.chunk_getter: if self.chunk_getter:
await self.ensure_chunk_at(height) await self.ensure_chunk_at(height)
if not 0 <= height <= self.height: if not 0 <= height <= self.height:
raise IndexError(f"{height} is out of bounds, current height: {self.height}") raise IndexError(f"{height} is out of bounds, current height: {self.height}")
return await asyncio.get_running_loop().run_in_executor(self.executor, self._read, height) return self._read(height)
def _read(self, height, count=1): def _read(self, height, count=1):
self.io.seek(height * self.header_size, os.SEEK_SET) offset = height * self.header_size
return self.io.read(self.header_size * count) return bytes(self.io.getbuffer()[offset: offset + self.header_size * count])
def chunk_hash(self, start, count): def chunk_hash(self, start, count):
self.io.seek(start * self.header_size, os.SEEK_SET) return self.hash_header(self._read(start, count)).decode()
return self.hash_header(self.io.read(count * self.header_size)).decode()
async def ensure_checkpointed_size(self): async def ensure_checkpointed_size(self):
max_checkpointed_height = max(self.checkpoints.keys() or [-1]) max_checkpointed_height = max(self.checkpoints.keys() or [-1])
@ -165,7 +169,7 @@ class Headers:
async def ensure_chunk_at(self, height): async def ensure_chunk_at(self, height):
async with self.check_chunk_lock: async with self.check_chunk_lock:
if await self.has_header(height): if self.has_header(height):
log.debug("has header %s", height) log.debug("has header %s", height)
return return
return await self.fetch_chunk(height) return await self.fetch_chunk(height)
@ -179,7 +183,7 @@ class Headers:
) )
chunk_hash = self.hash_header(chunk).decode() chunk_hash = self.hash_header(chunk).decode()
if self.checkpoints.get(start) == chunk_hash: if self.checkpoints.get(start) == chunk_hash:
await asyncio.get_running_loop().run_in_executor(self.executor, self._write, start, chunk) self._write(start, chunk)
if start in self.known_missing_checkpointed_chunks: if start in self.known_missing_checkpointed_chunks:
self.known_missing_checkpointed_chunks.remove(start) self.known_missing_checkpointed_chunks.remove(start)
return return
@ -189,27 +193,23 @@ class Headers:
f"Checkpoint mismatch at height {start}. Expected {self.checkpoints[start]}, but got {chunk_hash} instead." f"Checkpoint mismatch at height {start}. Expected {self.checkpoints[start]}, but got {chunk_hash} instead."
) )
async def has_header(self, height): def has_header(self, height):
normalized_height = (height // 1000) * 1000 normalized_height = (height // 1000) * 1000
if normalized_height in self.checkpoints: if normalized_height in self.checkpoints:
return normalized_height not in self.known_missing_checkpointed_chunks return normalized_height not in self.known_missing_checkpointed_chunks
def _has_header(height): empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d'
empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d' all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b'
all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b' return self.chunk_hash(height, 1) not in (empty, all_zeroes)
return self.chunk_hash(height, 1) not in (empty, all_zeroes)
return await asyncio.get_running_loop().run_in_executor(self.executor, _has_header, height)
async def get_all_missing_headers(self): async def get_all_missing_headers(self):
# Heavy operation done in one optimized shot # Heavy operation done in one optimized shot
def _io_checkall(): for chunk_height, expected_hash in reversed(list(self.checkpoints.items())):
for chunk_height, expected_hash in reversed(list(self.checkpoints.items())): if chunk_height in self.known_missing_checkpointed_chunks:
if chunk_height in self.known_missing_checkpointed_chunks: continue
continue if self.chunk_hash(chunk_height, 1000) != expected_hash:
if self.chunk_hash(chunk_height, 1000) != expected_hash: self.known_missing_checkpointed_chunks.add(chunk_height)
self.known_missing_checkpointed_chunks.add(chunk_height) return self.known_missing_checkpointed_chunks
return self.known_missing_checkpointed_chunks
return await asyncio.get_running_loop().run_in_executor(self.executor, _io_checkall)
@property @property
def height(self) -> int: def height(self) -> int:
@ -241,7 +241,7 @@ class Headers:
bail = True bail = True
chunk = chunk[:(height-e.height)*self.header_size] chunk = chunk[:(height-e.height)*self.header_size]
if chunk: if chunk:
added += await asyncio.get_running_loop().run_in_executor(self.executor, self._write, height, chunk) added += self._write(height, chunk)
if bail: if bail:
break break
return added return added
@ -306,9 +306,7 @@ class Headers:
previous_header_hash = fail = None previous_header_hash = fail = None
batch_size = 36 batch_size = 36
for height in range(start_height, self.height, batch_size): for height in range(start_height, self.height, batch_size):
headers = await asyncio.get_running_loop().run_in_executor( headers = self._read(height, batch_size)
self.executor, self._read, height, batch_size
)
if len(headers) % self.header_size != 0: if len(headers) % self.header_size != 0:
headers = headers[:(len(headers) // self.header_size) * self.header_size] headers = headers[:(len(headers) // self.header_size) * self.header_size]
for header_hash, header in self._iterate_headers(height, headers): for header_hash, header in self._iterate_headers(height, headers):
@ -324,12 +322,11 @@ class Headers:
assert start_height > 0 and height == start_height assert start_height > 0 and height == start_height
if fail: if fail:
log.warning("Header file corrupted at height %s, truncating it.", height - 1) log.warning("Header file corrupted at height %s, truncating it.", height - 1)
def __truncate(at_height): self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET)
self.io.seek(max(0, (at_height - 1)) * self.header_size, os.SEEK_SET) self.io.truncate()
self.io.truncate() self.io.flush()
self.io.flush() self._size = self.io.seek(0, os.SEEK_END) // self.header_size
self._size = self.io.seek(0, os.SEEK_END) // self.header_size return
return await asyncio.get_running_loop().run_in_executor(self.executor, __truncate, height)
previous_header_hash = header_hash previous_header_hash = header_hash
@classmethod @classmethod

View file

@ -211,6 +211,7 @@ class TestQueries(AsyncioTestCase):
'db': Database(':memory:'), 'db': Database(':memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.headers.open()
self.wallet = Wallet() self.wallet = Wallet()
await self.ledger.db.open() await self.ledger.db.open()

View file

@ -21,8 +21,8 @@ class TestHeaders(AsyncioTestCase):
async def test_deserialize(self): async def test_deserialize(self):
self.maxDiff = None self.maxDiff = None
h = Headers(':memory:') h = Headers(':memory:')
h.io.write(HEADERS)
await h.open() await h.open()
await h.connect(0, HEADERS)
self.assertEqual(await h.get(0), { self.assertEqual(await h.get(0), {
'bits': 520159231, 'bits': 520159231,
'block_height': 0, 'block_height': 0,
@ -52,8 +52,11 @@ class TestHeaders(AsyncioTestCase):
self.assertEqual(headers.height, 19) self.assertEqual(headers.height, 19)
async def test_connect_from_middle(self): async def test_connect_from_middle(self):
h = Headers(':memory:') headers_temporary_file = tempfile.mktemp()
h.io.write(HEADERS[:block_bytes(10)]) self.addCleanup(os.remove, headers_temporary_file)
with open(headers_temporary_file, 'w+b') as headers_file:
headers_file.write(HEADERS[:block_bytes(10)])
h = Headers(headers_temporary_file)
await h.open() await h.open()
self.assertEqual(h.height, 9) self.assertEqual(h.height, 9)
await h.connect(len(h), HEADERS[block_bytes(10):block_bytes(20)]) await h.connect(len(h), HEADERS[block_bytes(10):block_bytes(20)])
@ -115,6 +118,7 @@ class TestHeaders(AsyncioTestCase):
async def test_bounds(self): async def test_bounds(self):
headers = Headers(':memory:') headers = Headers(':memory:')
await headers.open()
await headers.connect(0, HEADERS) await headers.connect(0, HEADERS)
self.assertEqual(19, headers.height) self.assertEqual(19, headers.height)
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
@ -126,6 +130,7 @@ class TestHeaders(AsyncioTestCase):
async def test_repair(self): async def test_repair(self):
headers = Headers(':memory:') headers = Headers(':memory:')
await headers.open()
await headers.connect(0, HEADERS[:block_bytes(11)]) await headers.connect(0, HEADERS[:block_bytes(11)])
self.assertEqual(10, headers.height) self.assertEqual(10, headers.height)
await headers.repair() await headers.repair()
@ -147,24 +152,39 @@ class TestHeaders(AsyncioTestCase):
await headers.repair(start_height=10) await headers.repair(start_height=10)
self.assertEqual(19, headers.height) self.assertEqual(19, headers.height)
def test_do_not_estimate_unconfirmed(self): async def test_do_not_estimate_unconfirmed(self):
headers = Headers(':memory:') headers = Headers(':memory:')
await headers.open()
self.assertIsNone(headers.estimated_timestamp(-1)) self.assertIsNone(headers.estimated_timestamp(-1))
self.assertIsNone(headers.estimated_timestamp(0)) self.assertIsNone(headers.estimated_timestamp(0))
self.assertIsNotNone(headers.estimated_timestamp(1)) self.assertIsNotNone(headers.estimated_timestamp(1))
async def test_misalignment_triggers_repair_on_open(self): async def test_dont_estimate_whats_there(self):
headers = Headers(':memory:') headers = Headers(':memory:')
headers.io.seek(0) await headers.open()
headers.io.write(HEADERS) estimated = headers.estimated_timestamp(10)
await headers.connect(0, HEADERS)
real_time = (await headers.get(10))['timestamp']
after_downloading_header_estimated = headers.estimated_timestamp(10)
self.assertNotEqual(estimated, after_downloading_header_estimated)
self.assertEqual(after_downloading_header_estimated, real_time)
async def test_misalignment_triggers_repair_on_open(self):
headers_temporary_file = tempfile.mktemp()
self.addCleanup(os.remove, headers_temporary_file)
with open(headers_temporary_file, 'w+b') as headers_file:
headers_file.write(HEADERS)
headers = Headers(headers_temporary_file)
with self.assertLogs(level='WARN') as cm: with self.assertLogs(level='WARN') as cm:
await headers.open() await headers.open()
await headers.close()
self.assertEqual(cm.output, []) self.assertEqual(cm.output, [])
headers.io.seek(0) with open(headers_temporary_file, 'w+b') as headers_file:
headers.io.truncate() headers_file.seek(0)
headers.io.write(HEADERS[:block_bytes(10)]) headers_file.truncate()
headers.io.write(b'ops') headers_file.write(HEADERS[:block_bytes(10)])
headers.io.write(HEADERS[block_bytes(10):]) headers_file.write(b'ops')
headers_file.write(HEADERS[block_bytes(10):])
await headers.open() await headers.open()
self.assertEqual( self.assertEqual(
cm.output, [ cm.output, [
@ -192,6 +212,7 @@ class TestHeaders(AsyncioTestCase):
reader_task = asyncio.create_task(reader()) reader_task = asyncio.create_task(reader())
await writer() await writer()
await reader_task await reader_task
await headers.close()
HEADERS = unhexlify( HEADERS = unhexlify(

View file

@ -48,6 +48,8 @@ class LedgerTestCase(AsyncioTestCase):
'db': Database(':memory:'), 'db': Database(':memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
self.ledger.headers.checkpoints = {}
await self.ledger.headers.open()
self.account = Account.generate(self.ledger, Wallet(), "lbryum") self.account = Account.generate(self.ledger, Wallet(), "lbryum")
await self.ledger.db.open() await self.ledger.db.open()