better locking, stop corrupting headers, fix some tests

This commit is contained in:
Victor Shyba 2020-03-18 14:17:01 -03:00
parent 241e946d91
commit b04a516063
4 changed files with 49 additions and 33 deletions

View file

@@ -5,6 +5,7 @@ import asyncio
import hashlib import hashlib
import logging import logging
import zlib import zlib
from concurrent.futures.thread import ThreadPoolExecutor
from io import BytesIO from io import BytesIO
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@@ -47,6 +48,7 @@ class Headers:
self.path = path self.path = path
self._size: Optional[int] = None self._size: Optional[int] = None
self.chunk_getter: Optional[Callable] = None self.chunk_getter: Optional[Callable] = None
self.executor = ThreadPoolExecutor(1)
async def open(self): async def open(self):
if self.path != ':memory:': if self.path != ':memory:':
@@ -54,8 +56,10 @@ class Headers:
self.io = open(self.path, 'w+b') self.io = open(self.path, 'w+b')
else: else:
self.io = open(self.path, 'r+b') self.io = open(self.path, 'r+b')
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
async def close(self): async def close(self):
self.executor.shutdown()
self.io.close() self.io.close()
@staticmethod @staticmethod
@@ -103,27 +107,34 @@ class Headers:
return new_target return new_target
def __len__(self) -> int: def __len__(self) -> int:
if self._size is None:
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
return self._size return self._size
def __bool__(self): def __bool__(self):
return True return True
async def get(self, height) -> dict: async def get(self, height) -> dict:
if height < 0:
raise IndexError(f"Height cannot be negative!!")
if isinstance(height, slice): if isinstance(height, slice):
raise NotImplementedError("Slicing of header chain has not been implemented yet.") raise NotImplementedError("Slicing of header chain has not been implemented yet.")
return self.deserialize(height, await self.get_raw_header(height)) try:
return self.deserialize(height, await self.get_raw_header(height))
except struct.error:
raise IndexError(f"failed to get {height}, at {len(self)}")
def estimated_timestamp(self, height): def estimated_timestamp(self, height):
return self.first_block_timestamp + (height * self.timestamp_average_offset) return self.first_block_timestamp + (height * self.timestamp_average_offset)
async def get_raw_header(self, height) -> bytes: async def get_raw_header(self, height) -> bytes:
await self.ensure_chunk_at(height) if self.chunk_getter:
self.io.seek(height * self.header_size, os.SEEK_SET) await self.ensure_chunk_at(height)
return self.io.read(self.header_size) return await asyncio.get_running_loop().run_in_executor(self.executor, self._read, height)
async def chunk_hash(self, start, count): def _read(self, height, count=1):
self.io.seek(height * self.header_size, os.SEEK_SET)
return self.io.read(self.header_size * count)
def chunk_hash(self, start, count):
self.io.seek(start * self.header_size, os.SEEK_SET) self.io.seek(start * self.header_size, os.SEEK_SET)
return self.hash_header(self.io.read(count * self.header_size)).decode() return self.hash_header(self.io.read(count * self.header_size)).decode()
@@ -141,16 +152,20 @@ class Headers:
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000) zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
) )
chunk_hash = self.hash_header(chunk).decode() chunk_hash = self.hash_header(chunk).decode()
if HASHES[start] == chunk_hash: if HASHES.get(start) == chunk_hash:
return self._write(start, chunk) return await asyncio.get_running_loop().run_in_executor(self.executor, self._write, start, chunk)
elif start not in HASHES:
return # todo: fixme
raise Exception( raise Exception(
f"Checkpoint mismatch at height {start}. Expected {HASHES[start]}, but got {chunk_hash} instead." f"Checkpoint mismatch at height {start}. Expected {HASHES[start]}, but got {chunk_hash} instead."
) )
async def has_header(self, height): async def has_header(self, height):
empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d' def _has_header(height):
all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b' empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d'
return await self.chunk_hash(height, 1) not in (empty, all_zeroes) all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b'
return self.chunk_hash(height, 1) not in (empty, all_zeroes)
return await asyncio.get_running_loop().run_in_executor(self.executor, _has_header, height)
@property @property
def height(self) -> int: def height(self) -> int:
@@ -216,11 +231,11 @@ class Headers:
def _write(self, height, verified_chunk): def _write(self, height, verified_chunk):
self.io.seek(height * self.header_size, os.SEEK_SET) self.io.seek(height * self.header_size, os.SEEK_SET)
written = self.io.write(verified_chunk) // self.header_size written = self.io.write(verified_chunk) // self.header_size
self.io.truncate() # self.io.truncate()
# .seek()/.write()/.truncate() might also .flush() when needed # .seek()/.write()/.truncate() might also .flush() when needed
# the goal here is mainly to ensure we're definitely flush()'ing # the goal here is mainly to ensure we're definitely flush()'ing
self.io.flush() self.io.flush()
self._size = self.io.tell() // self.header_size self._size = self.io.seek(0, os.SEEK_END) // self.header_size
return written return written
async def validate_chunk(self, height, chunk): async def validate_chunk(self, height, chunk):
@@ -272,8 +287,9 @@ class Headers:
previous_header_hash = fail = None previous_header_hash = fail = None
batch_size = 36 batch_size = 36
for start_height in range(0, self.height, batch_size): for start_height in range(0, self.height, batch_size):
self.io.seek(self.header_size * start_height) headers = await asyncio.get_running_loop().run_in_executor(
headers = self.io.read(self.header_size*batch_size) self.executor, self._read, start_height, batch_size
)
if len(headers) % self.header_size != 0: if len(headers) % self.header_size != 0:
headers = headers[:(len(headers) // self.header_size) * self.header_size] headers = headers[:(len(headers) // self.header_size) * self.header_size]
for header_hash, header in self._iterate_headers(start_height, headers): for header_hash, header in self._iterate_headers(start_height, headers):
@@ -286,11 +302,12 @@ class Headers:
fail = True fail = True
if fail: if fail:
log.warning("Header file corrupted at height %s, truncating it.", height - 1) log.warning("Header file corrupted at height %s, truncating it.", height - 1)
self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET) def __truncate(at_height):
self.io.truncate() self.io.seek(max(0, (at_height - 1)) * self.header_size, os.SEEK_SET)
self.io.flush() self.io.truncate()
self._size = None self.io.flush()
return self._size = self.io.seek(0, os.SEEK_END) // self.header_size
return await asyncio.get_running_loop().run_in_executor(self.executor, __truncate, height)
previous_header_hash = header_hash previous_header_hash = header_hash
@classmethod @classmethod

View file

@@ -316,9 +316,6 @@ class Ledger(metaclass=LedgerRegistry):
first_connection = self.network.on_connected.first first_connection = self.network.on_connected.first
asyncio.ensure_future(self.network.start()) asyncio.ensure_future(self.network.start())
await first_connection await first_connection
async with self._header_processing_lock:
await self._update_tasks.add(self.initial_headers_sync())
await self._on_ready_controller.stream.first
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts)) await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
await asyncio.gather(*(a.save_max_gap() for a in self.accounts)) await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
if len(self.accounts) > 10: if len(self.accounts) > 10:
@@ -329,8 +326,10 @@ class Ledger(metaclass=LedgerRegistry):
async def join_network(self, *_): async def join_network(self, *_):
log.info("Subscribing and updating accounts.") log.info("Subscribing and updating accounts.")
#async with self._header_processing_lock: self._update_tasks.add(self.initial_headers_sync())
# await self.update_headers() async with self._header_processing_lock:
await self.headers.ensure_tip()
await self.update_headers()
await self.subscribe_accounts() await self.subscribe_accounts()
await self._update_tasks.done.wait() await self._update_tasks.done.wait()
self._on_ready_controller.add(True) self._on_ready_controller.add(True)
@@ -348,16 +347,12 @@ class Ledger(metaclass=LedgerRegistry):
async def initial_headers_sync(self): async def initial_headers_sync(self):
target = self.network.remote_height + 1 target = self.network.remote_height + 1
current = len(self.headers)
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True) get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
self.headers.chunk_getter = get_chunk self.headers.chunk_getter = get_chunk
await self.headers.ensure_tip()
async def doit(): async def doit():
for height in range(current, target, 1000): for height in reversed(range(0, target, 1000)):
await self.headers.ensure_chunk_at(height) await self.headers.ensure_chunk_at(height)
self._download_height = height
log.info("Headers sync: %s / %s", self._download_height, target)
asyncio.ensure_future(doit()) asyncio.ensure_future(doit())
return return
@@ -598,7 +593,7 @@ class Ledger(metaclass=LedgerRegistry):
async def maybe_verify_transaction(self, tx, remote_height): async def maybe_verify_transaction(self, tx, remote_height):
tx.height = remote_height tx.height = remote_height
if 0 < remote_height < self.network.remote_height: if 0 < remote_height < len(self.headers):
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height) merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = await self.headers.get(remote_height) header = await self.headers.get(remote_height)

View file

@@ -42,6 +42,7 @@ class TestHeaders(AsyncioTestCase):
async def test_connect_from_genesis(self): async def test_connect_from_genesis(self):
headers = Headers(':memory:') headers = Headers(':memory:')
await headers.open()
self.assertEqual(headers.height, -1) self.assertEqual(headers.height, -1)
await headers.connect(0, HEADERS) await headers.connect(0, HEADERS)
self.assertEqual(headers.height, 19) self.assertEqual(headers.height, 19)
@@ -49,6 +50,7 @@ class TestHeaders(AsyncioTestCase):
async def test_connect_from_middle(self): async def test_connect_from_middle(self):
h = Headers(':memory:') h = Headers(':memory:')
h.io.write(HEADERS[:block_bytes(10)]) h.io.write(HEADERS[:block_bytes(10)])
await h.open()
self.assertEqual(h.height, 9) self.assertEqual(h.height, 9)
await h.connect(len(h), HEADERS[block_bytes(10):block_bytes(20)]) await h.connect(len(h), HEADERS[block_bytes(10):block_bytes(20)])
self.assertEqual(h.height, 19) self.assertEqual(h.height, 19)
@@ -140,6 +142,7 @@ class TestHeaders(AsyncioTestCase):
async def test_checkpointed_writer(self): async def test_checkpointed_writer(self):
headers = Headers(':memory:') headers = Headers(':memory:')
await headers.open()
getblocks = lambda start, end: HEADERS[block_bytes(start):block_bytes(end)] getblocks = lambda start, end: HEADERS[block_bytes(start):block_bytes(end)]
headers.checkpoint = 10, hexlify(sha256(getblocks(10, 11))) headers.checkpoint = 10, hexlify(sha256(getblocks(10, 11)))
async with headers.checkpointed_connector() as buff: async with headers.checkpointed_connector() as buff:
@@ -149,6 +152,7 @@ class TestHeaders(AsyncioTestCase):
buff.write(getblocks(10, 19)) buff.write(getblocks(10, 19))
self.assertEqual(len(headers), 19) self.assertEqual(len(headers), 19)
headers = Headers(':memory:') headers = Headers(':memory:')
await headers.open()
async with headers.checkpointed_connector() as buff: async with headers.checkpointed_connector() as buff:
buff.write(getblocks(0, 19)) buff.write(getblocks(0, 19))
self.assertEqual(len(headers), 19) self.assertEqual(len(headers), 19)

View file

@@ -67,7 +67,7 @@ class LedgerTestCase(AsyncioTestCase):
serialized = self.make_header(**kwargs) serialized = self.make_header(**kwargs)
self.ledger.headers.io.seek(0, os.SEEK_END) self.ledger.headers.io.seek(0, os.SEEK_END)
self.ledger.headers.io.write(serialized) self.ledger.headers.io.write(serialized)
self.ledger.headers._size = None self.ledger.headers._size = self.ledger.headers.io.seek(0, os.SEEK_END) // self.ledger.headers.header_size
class TestSynchronization(LedgerTestCase): class TestSynchronization(LedgerTestCase):