forked from LBRYCommunity/lbry-sdk
better locking, stop corrupting headers, fix some tests
This commit is contained in:
parent
241e946d91
commit
b04a516063
|
@ -5,6 +5,7 @@ import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import zlib
|
import zlib
|
||||||
|
from concurrent.futures.thread import ThreadPoolExecutor
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
@ -47,6 +48,7 @@ class Headers:
|
||||||
self.path = path
|
self.path = path
|
||||||
self._size: Optional[int] = None
|
self._size: Optional[int] = None
|
||||||
self.chunk_getter: Optional[Callable] = None
|
self.chunk_getter: Optional[Callable] = None
|
||||||
|
self.executor = ThreadPoolExecutor(1)
|
||||||
|
|
||||||
async def open(self):
|
async def open(self):
|
||||||
if self.path != ':memory:':
|
if self.path != ':memory:':
|
||||||
|
@ -54,8 +56,10 @@ class Headers:
|
||||||
self.io = open(self.path, 'w+b')
|
self.io = open(self.path, 'w+b')
|
||||||
else:
|
else:
|
||||||
self.io = open(self.path, 'r+b')
|
self.io = open(self.path, 'r+b')
|
||||||
|
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
|
self.executor.shutdown()
|
||||||
self.io.close()
|
self.io.close()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -103,27 +107,34 @@ class Headers:
|
||||||
return new_target
|
return new_target
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
if self._size is None:
|
|
||||||
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
|
|
||||||
return self._size
|
return self._size
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def get(self, height) -> dict:
|
async def get(self, height) -> dict:
|
||||||
|
if height < 0:
|
||||||
|
raise IndexError(f"Height cannot be negative!!")
|
||||||
if isinstance(height, slice):
|
if isinstance(height, slice):
|
||||||
raise NotImplementedError("Slicing of header chain has not been implemented yet.")
|
raise NotImplementedError("Slicing of header chain has not been implemented yet.")
|
||||||
return self.deserialize(height, await self.get_raw_header(height))
|
try:
|
||||||
|
return self.deserialize(height, await self.get_raw_header(height))
|
||||||
|
except struct.error:
|
||||||
|
raise IndexError(f"failed to get {height}, at {len(self)}")
|
||||||
|
|
||||||
def estimated_timestamp(self, height):
|
def estimated_timestamp(self, height):
|
||||||
return self.first_block_timestamp + (height * self.timestamp_average_offset)
|
return self.first_block_timestamp + (height * self.timestamp_average_offset)
|
||||||
|
|
||||||
async def get_raw_header(self, height) -> bytes:
|
async def get_raw_header(self, height) -> bytes:
|
||||||
await self.ensure_chunk_at(height)
|
if self.chunk_getter:
|
||||||
self.io.seek(height * self.header_size, os.SEEK_SET)
|
await self.ensure_chunk_at(height)
|
||||||
return self.io.read(self.header_size)
|
return await asyncio.get_running_loop().run_in_executor(self.executor, self._read, height)
|
||||||
|
|
||||||
async def chunk_hash(self, start, count):
|
def _read(self, height, count=1):
|
||||||
|
self.io.seek(height * self.header_size, os.SEEK_SET)
|
||||||
|
return self.io.read(self.header_size * count)
|
||||||
|
|
||||||
|
def chunk_hash(self, start, count):
|
||||||
self.io.seek(start * self.header_size, os.SEEK_SET)
|
self.io.seek(start * self.header_size, os.SEEK_SET)
|
||||||
return self.hash_header(self.io.read(count * self.header_size)).decode()
|
return self.hash_header(self.io.read(count * self.header_size)).decode()
|
||||||
|
|
||||||
|
@ -141,16 +152,20 @@ class Headers:
|
||||||
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
|
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
|
||||||
)
|
)
|
||||||
chunk_hash = self.hash_header(chunk).decode()
|
chunk_hash = self.hash_header(chunk).decode()
|
||||||
if HASHES[start] == chunk_hash:
|
if HASHES.get(start) == chunk_hash:
|
||||||
return self._write(start, chunk)
|
return await asyncio.get_running_loop().run_in_executor(self.executor, self._write, start, chunk)
|
||||||
|
elif start not in HASHES:
|
||||||
|
return # todo: fixme
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Checkpoint mismatch at height {start}. Expected {HASHES[start]}, but got {chunk_hash} instead."
|
f"Checkpoint mismatch at height {start}. Expected {HASHES[start]}, but got {chunk_hash} instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def has_header(self, height):
|
async def has_header(self, height):
|
||||||
empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d'
|
def _has_header(height):
|
||||||
all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b'
|
empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d'
|
||||||
return await self.chunk_hash(height, 1) not in (empty, all_zeroes)
|
all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b'
|
||||||
|
return self.chunk_hash(height, 1) not in (empty, all_zeroes)
|
||||||
|
return await asyncio.get_running_loop().run_in_executor(self.executor, _has_header, height)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def height(self) -> int:
|
def height(self) -> int:
|
||||||
|
@ -216,11 +231,11 @@ class Headers:
|
||||||
def _write(self, height, verified_chunk):
|
def _write(self, height, verified_chunk):
|
||||||
self.io.seek(height * self.header_size, os.SEEK_SET)
|
self.io.seek(height * self.header_size, os.SEEK_SET)
|
||||||
written = self.io.write(verified_chunk) // self.header_size
|
written = self.io.write(verified_chunk) // self.header_size
|
||||||
self.io.truncate()
|
# self.io.truncate()
|
||||||
# .seek()/.write()/.truncate() might also .flush() when needed
|
# .seek()/.write()/.truncate() might also .flush() when needed
|
||||||
# the goal here is mainly to ensure we're definitely flush()'ing
|
# the goal here is mainly to ensure we're definitely flush()'ing
|
||||||
self.io.flush()
|
self.io.flush()
|
||||||
self._size = self.io.tell() // self.header_size
|
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
|
||||||
return written
|
return written
|
||||||
|
|
||||||
async def validate_chunk(self, height, chunk):
|
async def validate_chunk(self, height, chunk):
|
||||||
|
@ -272,8 +287,9 @@ class Headers:
|
||||||
previous_header_hash = fail = None
|
previous_header_hash = fail = None
|
||||||
batch_size = 36
|
batch_size = 36
|
||||||
for start_height in range(0, self.height, batch_size):
|
for start_height in range(0, self.height, batch_size):
|
||||||
self.io.seek(self.header_size * start_height)
|
headers = await asyncio.get_running_loop().run_in_executor(
|
||||||
headers = self.io.read(self.header_size*batch_size)
|
self.executor, self._read, start_height, batch_size
|
||||||
|
)
|
||||||
if len(headers) % self.header_size != 0:
|
if len(headers) % self.header_size != 0:
|
||||||
headers = headers[:(len(headers) // self.header_size) * self.header_size]
|
headers = headers[:(len(headers) // self.header_size) * self.header_size]
|
||||||
for header_hash, header in self._iterate_headers(start_height, headers):
|
for header_hash, header in self._iterate_headers(start_height, headers):
|
||||||
|
@ -286,11 +302,12 @@ class Headers:
|
||||||
fail = True
|
fail = True
|
||||||
if fail:
|
if fail:
|
||||||
log.warning("Header file corrupted at height %s, truncating it.", height - 1)
|
log.warning("Header file corrupted at height %s, truncating it.", height - 1)
|
||||||
self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET)
|
def __truncate(at_height):
|
||||||
self.io.truncate()
|
self.io.seek(max(0, (at_height - 1)) * self.header_size, os.SEEK_SET)
|
||||||
self.io.flush()
|
self.io.truncate()
|
||||||
self._size = None
|
self.io.flush()
|
||||||
return
|
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
|
||||||
|
return await asyncio.get_running_loop().run_in_executor(self.executor, __truncate, height)
|
||||||
previous_header_hash = header_hash
|
previous_header_hash = header_hash
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -316,9 +316,6 @@ class Ledger(metaclass=LedgerRegistry):
|
||||||
first_connection = self.network.on_connected.first
|
first_connection = self.network.on_connected.first
|
||||||
asyncio.ensure_future(self.network.start())
|
asyncio.ensure_future(self.network.start())
|
||||||
await first_connection
|
await first_connection
|
||||||
async with self._header_processing_lock:
|
|
||||||
await self._update_tasks.add(self.initial_headers_sync())
|
|
||||||
await self._on_ready_controller.stream.first
|
|
||||||
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
|
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
|
||||||
await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
|
await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
|
||||||
if len(self.accounts) > 10:
|
if len(self.accounts) > 10:
|
||||||
|
@ -329,8 +326,10 @@ class Ledger(metaclass=LedgerRegistry):
|
||||||
|
|
||||||
async def join_network(self, *_):
|
async def join_network(self, *_):
|
||||||
log.info("Subscribing and updating accounts.")
|
log.info("Subscribing and updating accounts.")
|
||||||
#async with self._header_processing_lock:
|
self._update_tasks.add(self.initial_headers_sync())
|
||||||
# await self.update_headers()
|
async with self._header_processing_lock:
|
||||||
|
await self.headers.ensure_tip()
|
||||||
|
await self.update_headers()
|
||||||
await self.subscribe_accounts()
|
await self.subscribe_accounts()
|
||||||
await self._update_tasks.done.wait()
|
await self._update_tasks.done.wait()
|
||||||
self._on_ready_controller.add(True)
|
self._on_ready_controller.add(True)
|
||||||
|
@ -348,16 +347,12 @@ class Ledger(metaclass=LedgerRegistry):
|
||||||
|
|
||||||
async def initial_headers_sync(self):
|
async def initial_headers_sync(self):
|
||||||
target = self.network.remote_height + 1
|
target = self.network.remote_height + 1
|
||||||
current = len(self.headers)
|
|
||||||
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
|
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
|
||||||
self.headers.chunk_getter = get_chunk
|
self.headers.chunk_getter = get_chunk
|
||||||
await self.headers.ensure_tip()
|
|
||||||
|
|
||||||
async def doit():
|
async def doit():
|
||||||
for height in range(current, target, 1000):
|
for height in reversed(range(0, target, 1000)):
|
||||||
await self.headers.ensure_chunk_at(height)
|
await self.headers.ensure_chunk_at(height)
|
||||||
self._download_height = height
|
|
||||||
log.info("Headers sync: %s / %s", self._download_height, target)
|
|
||||||
asyncio.ensure_future(doit())
|
asyncio.ensure_future(doit())
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -598,7 +593,7 @@ class Ledger(metaclass=LedgerRegistry):
|
||||||
|
|
||||||
async def maybe_verify_transaction(self, tx, remote_height):
|
async def maybe_verify_transaction(self, tx, remote_height):
|
||||||
tx.height = remote_height
|
tx.height = remote_height
|
||||||
if 0 < remote_height < self.network.remote_height:
|
if 0 < remote_height < len(self.headers):
|
||||||
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
|
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
|
||||||
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||||
header = await self.headers.get(remote_height)
|
header = await self.headers.get(remote_height)
|
||||||
|
|
|
@ -42,6 +42,7 @@ class TestHeaders(AsyncioTestCase):
|
||||||
|
|
||||||
async def test_connect_from_genesis(self):
|
async def test_connect_from_genesis(self):
|
||||||
headers = Headers(':memory:')
|
headers = Headers(':memory:')
|
||||||
|
await headers.open()
|
||||||
self.assertEqual(headers.height, -1)
|
self.assertEqual(headers.height, -1)
|
||||||
await headers.connect(0, HEADERS)
|
await headers.connect(0, HEADERS)
|
||||||
self.assertEqual(headers.height, 19)
|
self.assertEqual(headers.height, 19)
|
||||||
|
@ -49,6 +50,7 @@ class TestHeaders(AsyncioTestCase):
|
||||||
async def test_connect_from_middle(self):
|
async def test_connect_from_middle(self):
|
||||||
h = Headers(':memory:')
|
h = Headers(':memory:')
|
||||||
h.io.write(HEADERS[:block_bytes(10)])
|
h.io.write(HEADERS[:block_bytes(10)])
|
||||||
|
await h.open()
|
||||||
self.assertEqual(h.height, 9)
|
self.assertEqual(h.height, 9)
|
||||||
await h.connect(len(h), HEADERS[block_bytes(10):block_bytes(20)])
|
await h.connect(len(h), HEADERS[block_bytes(10):block_bytes(20)])
|
||||||
self.assertEqual(h.height, 19)
|
self.assertEqual(h.height, 19)
|
||||||
|
@ -140,6 +142,7 @@ class TestHeaders(AsyncioTestCase):
|
||||||
|
|
||||||
async def test_checkpointed_writer(self):
|
async def test_checkpointed_writer(self):
|
||||||
headers = Headers(':memory:')
|
headers = Headers(':memory:')
|
||||||
|
await headers.open()
|
||||||
getblocks = lambda start, end: HEADERS[block_bytes(start):block_bytes(end)]
|
getblocks = lambda start, end: HEADERS[block_bytes(start):block_bytes(end)]
|
||||||
headers.checkpoint = 10, hexlify(sha256(getblocks(10, 11)))
|
headers.checkpoint = 10, hexlify(sha256(getblocks(10, 11)))
|
||||||
async with headers.checkpointed_connector() as buff:
|
async with headers.checkpointed_connector() as buff:
|
||||||
|
@ -149,6 +152,7 @@ class TestHeaders(AsyncioTestCase):
|
||||||
buff.write(getblocks(10, 19))
|
buff.write(getblocks(10, 19))
|
||||||
self.assertEqual(len(headers), 19)
|
self.assertEqual(len(headers), 19)
|
||||||
headers = Headers(':memory:')
|
headers = Headers(':memory:')
|
||||||
|
await headers.open()
|
||||||
async with headers.checkpointed_connector() as buff:
|
async with headers.checkpointed_connector() as buff:
|
||||||
buff.write(getblocks(0, 19))
|
buff.write(getblocks(0, 19))
|
||||||
self.assertEqual(len(headers), 19)
|
self.assertEqual(len(headers), 19)
|
||||||
|
|
|
@ -67,7 +67,7 @@ class LedgerTestCase(AsyncioTestCase):
|
||||||
serialized = self.make_header(**kwargs)
|
serialized = self.make_header(**kwargs)
|
||||||
self.ledger.headers.io.seek(0, os.SEEK_END)
|
self.ledger.headers.io.seek(0, os.SEEK_END)
|
||||||
self.ledger.headers.io.write(serialized)
|
self.ledger.headers.io.write(serialized)
|
||||||
self.ledger.headers._size = None
|
self.ledger.headers._size = self.ledger.headers.io.seek(0, os.SEEK_END) // self.ledger.headers.header_size
|
||||||
|
|
||||||
|
|
||||||
class TestSynchronization(LedgerTestCase):
|
class TestSynchronization(LedgerTestCase):
|
||||||
|
|
Loading…
Reference in a new issue