add repair ability to baseheaders

This commit is contained in:
Victor Shyba 2019-07-09 01:07:07 -03:00 committed by Lex Berezhny
parent 14f743c493
commit 9990889cce
2 changed files with 36 additions and 1 deletions

View file

@ -94,6 +94,27 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
await headers.connect(len(headers), remainder) await headers.connect(len(headers), remainder)
self.assertEqual(headers.height, 32259) self.assertEqual(headers.height, 32259)
async def test_repair(self):
headers = MainHeaders(':memory:')
await headers.connect(0, self.get_bytes(block_bytes(3001)))
self.assertEqual(headers.height, 3000)
headers.repair()
self.assertEqual(headers.height, 3000)
# corrupt the middle of it
headers.io.seek(block_bytes(1500))
headers.io.write(b"wtf")
headers.repair()
self.assertEqual(headers.height, 1499)
self.assertEqual(len(headers), 1500)
# corrupt by appending
headers.io.seek(block_bytes(len(headers)))
headers.io.write(b"appending")
headers._size = None
headers.repair()
self.assertEqual(headers.height, 1499)
await headers.connect(len(headers), self.get_bytes(block_bytes(3001 - 1500), after=block_bytes(1500)))
self.assertEqual(headers.height, 3000)
async def test_concurrency(self): async def test_concurrency(self):
BLOCKS = 30 BLOCKS = 30
headers_temporary_file = tempfile.mktemp() headers_temporary_file = tempfile.mktemp()

View file

@ -164,12 +164,26 @@ class BaseHeaders:
proof_of_work.value, target.value) proof_of_work.value, target.value)
) )
def repair(self):
for height in range(self.height):
chunk = self.get_raw_header(height)
try:
# validate_chunk() is CPU bound and reads previous chunks from file system
self.validate_chunk(height, chunk)
except InvalidHeader as e:
log.warning("Header file corrupted at height %s, truncating it.", e.height)
self.io.seek((e.height) * self.header_size, os.SEEK_SET)
self.io.truncate()
self.io.flush()
self._size = None
return
@staticmethod @staticmethod
def get_proof_of_work(header_hash: bytes) -> ArithUint256: def get_proof_of_work(header_hash: bytes) -> ArithUint256:
return ArithUint256(int(b'0x' + header_hash, 16)) return ArithUint256(int(b'0x' + header_hash, 16))
def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]: def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
assert len(headers) % self.header_size == 0 assert len(headers) % self.header_size == 0, f"{len(headers)} {len(headers)%self.header_size}"
start = 0 start = 0
end = (self.chunk_size - height % self.chunk_size) * self.header_size end = (self.chunk_size - height % self.chunk_size) * self.header_size
while start < end: while start < end: