import os
import struct
import asyncio
import hashlib
import logging
from io import BytesIO
from contextlib import asynccontextmanager
from typing import Optional, Iterator, Tuple
from binascii import hexlify, unhexlify

from lbry.crypto.hash import sha512, double_sha256, ripemd160
from lbry.wallet.util import ArithUint256


log = logging.getLogger(__name__)


class InvalidHeader(Exception):
    """Raised when a header fails validation; carries the offending height."""

    def __init__(self, height, message):
        super().__init__(message)
        self.message = message
        self.height = height


class Headers:
    """Flat store of fixed-size block headers backed by a file or memory.

    Header ``n`` lives at byte offset ``n * header_size`` of ``self.io``.
    Pass ``':memory:'`` as the path to use an in-memory ``BytesIO`` store.
    """

    header_size = 112
    chunk_size = 10**16

    max_target = 0x0000ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
    genesis_hash = b'9c89283ba0f3227f6c03b70216b9f665f0118d5e0fa729cedf4fb34d6a34f463'
    target_timespan = 150
    # (height, hex sha256 of all raw header bytes below that height)
    checkpoint = (600_000, b'100b33ca3d0b86a48f0d6d6f30458a130ecb89d5affefe4afccb134d5a40f4c2')

    validate_difficulty: bool = True

    def __init__(self, path) -> None:
        # File-backed stores get their handle in open(); initialise to None so
        # that close() before open() does not fail with AttributeError.
        self.io = BytesIO() if path == ':memory:' else None
        self.path = path
        # Cached header count; None means "recompute from the stream size".
        self._size: Optional[int] = None

    async def open(self):
        """Open the backing file (creating it if missing); no-op for ':memory:'."""
        if self.path != ':memory:':
            if not os.path.exists(self.path):
                self.io = open(self.path, 'w+b')
            else:
                self.io = open(self.path, 'r+b')

    async def close(self):
        """Close the backing stream if one has been opened."""
        if self.io is not None:
            self.io.close()

    # NOTE(review): in the reviewed source the text between ``struct.pack('``
    # in serialize() and the return annotation of get_next_chunk_target() was
    # missing (apparently stripped as markup). serialize(), deserialize() and
    # the get_next_chunk_target() signature below are reconstructed from the
    # 112-byte header size and the dict keys the rest of this class reads
    # ('prev_block_hash', 'claim_trie_root', 'timestamp', 'bits',
    # 'block_height'). Confirm against the upstream file before merging.
    @staticmethod
    def serialize(header):
        """Pack a header dict into its raw 112-byte form."""
        return b''.join([
            struct.pack('<I', header['version']),
            unhexlify(header['prev_block_hash'])[::-1],
            unhexlify(header['merkle_root'])[::-1],
            unhexlify(header['claim_trie_root'])[::-1],
            struct.pack('<III', header['timestamp'], header['bits'], header['nonce'])
        ])

    @staticmethod
    def deserialize(height, header):
        """Unpack raw header bytes into a dict, tagging it with ``height``."""
        version, = struct.unpack('<I', header[:4])
        timestamp, bits, nonce = struct.unpack('<III', header[100:112])
        return {
            'version': version,
            'prev_block_hash': hexlify(header[4:36][::-1]),
            'merkle_root': hexlify(header[36:68][::-1]),
            'claim_trie_root': hexlify(header[68:100][::-1]),
            'timestamp': timestamp,
            'bits': bits,
            'nonce': nonce,
            'block_height': height,
        }

    def get_next_chunk_target(self, chunk: int) -> ArithUint256:
        """Return the upper bound for block targets; ``chunk`` is unused here."""
        return ArithUint256(self.max_target)

    def get_next_block_target(self, max_target: ArithUint256, previous: Optional[dict],
                              current: Optional[dict]) -> ArithUint256:
        """Compute the target required of the block that follows ``current``.

        ``previous`` and ``current`` are the two most recent header dicts.
        With neither available ``max_target`` is returned; with only
        ``current`` available the timespan is measured against itself.
        """
        # https://github.com/lbryio/lbrycrd/blob/master/src/lbry.cpp
        if previous is None and current is None:
            return max_target
        if previous is None:
            previous = current
        actual_timespan = current['timestamp'] - previous['timestamp']
        modulated_timespan = self.target_timespan + int((actual_timespan - self.target_timespan) / 8)
        minimum_timespan = self.target_timespan - int(self.target_timespan / 8)  # 150 - 18 = 132
        maximum_timespan = self.target_timespan + int(self.target_timespan / 2)  # 150 + 75 = 225
        clamped_timespan = max(minimum_timespan, min(modulated_timespan, maximum_timespan))
        target = ArithUint256.from_compact(current['bits'])
        new_target = min(max_target, (target * clamped_timespan) / self.target_timespan)
        return new_target

    def __len__(self) -> int:
        """Number of complete headers in the store (cached after first call)."""
        if self._size is None:
            self._size = self.io.seek(0, os.SEEK_END) // self.header_size
        return self._size

    def __bool__(self):
        # Without this an empty store (len 0) would be falsy.
        return True

    def __getitem__(self, height) -> dict:
        """Return the deserialized header at ``height``.

        Raises NotImplementedError for slices and IndexError when ``height``
        is outside ``0..self.height``.
        """
        if isinstance(height, slice):
            raise NotImplementedError("Slicing of header chain has not been implemented yet.")
        if not 0 <= height <= self.height:
            raise IndexError(f"{height} is out of bounds, current height: {self.height}")
        return self.deserialize(height, self.get_raw_header(height))

    def get_raw_header(self, height) -> bytes:
        """Read the raw bytes stored for ``height`` (no bounds checking)."""
        self.io.seek(height * self.header_size, os.SEEK_SET)
        return self.io.read(self.header_size)

    @property
    def height(self) -> int:
        """Height of the last stored header; -1 when the store is empty."""
        return len(self)-1

    @property
    def bytes_size(self):
        """Total size in bytes of all complete stored headers."""
        return len(self) * self.header_size

    def hash(self, height=None) -> bytes:
        """Hex hash of the header at ``height`` (defaults to the tip)."""
        return self.hash_header(
            self.get_raw_header(height if height is not None else self.height)
        )

    @staticmethod
    def hash_header(header: bytes) -> bytes:
        """Return the hex-encoded, byte-reversed double-SHA256 of ``header``.

        ``None`` yields 64 ASCII zeros.
        """
        if header is None:
            return b'0' * 64
        return hexlify(double_sha256(header)[::-1])

    @asynccontextmanager
    async def checkpointed_connector(self):
        """Yield a buffer for incoming raw headers and connect them on exit.

        If the buffered headers reach the checkpoint height and the SHA-256
        of the existing store plus the buffered bytes up to the checkpoint
        matches ``checkpoint[1]``, that part is written without per-header
        validation; whatever remains goes through ``connect()``.
        """
        buf = BytesIO()
        try:
            yield buf
        finally:
            await asyncio.sleep(0)
            final_height = len(self) + buf.tell() // self.header_size
            verifiable_bytes = (self.checkpoint[0] - len(self)) * self.header_size if self.checkpoint else 0
            if verifiable_bytes > 0 and final_height >= self.checkpoint[0]:
                buf.seek(0)
                self.io.seek(0)
                h = hashlib.sha256()
                h.update(self.io.read())
                h.update(buf.read(verifiable_bytes))
                if h.hexdigest().encode() == self.checkpoint[1]:
                    buf.seek(0)
                    self._write(len(self), buf.read(verifiable_bytes))
                    # Keep only the bytes past the checkpoint in the buffer.
                    remaining = buf.read()
                    buf.seek(0)
                    buf.write(remaining)
                    buf.truncate()
                else:
                    log.warning("Checkpoint mismatch, connecting headers through slow method.")
            if buf.tell() > 0:
                await self.connect(len(self), buf.getvalue())

    async def connect(self, start: int, headers: bytes) -> int:
        """Validate ``headers`` and write them starting at height ``start``.

        Processing stops at the first invalid header; headers preceding it
        are still written. Returns the number of headers written.
        """
        added = 0
        bail = False
        for height, chunk in self._iterate_chunks(start, headers):
            try:
                # validate_chunk() is CPU bound and reads previous chunks from file system
                self.validate_chunk(height, chunk)
            except InvalidHeader as e:
                bail = True
                # e.height is the height of the offending header, so the
                # first (e.height - height) headers of this chunk are valid.
                valid_count = max(0, e.height - height)
                chunk = chunk[:valid_count * self.header_size]
            added += self._write(height, chunk) if chunk else 0
            if bail:
                break
        return added

    def _write(self, height, verified_chunk):
        """Write ``verified_chunk`` at ``height`` and drop everything after it.

        Returns the number of headers written.
        """
        self.io.seek(height * self.header_size, os.SEEK_SET)
        written = self.io.write(verified_chunk) // self.header_size
        self.io.truncate()
        # .seek()/.write()/.truncate() might also .flush() when needed
        # the goal here is mainly to ensure we're definitely flush()'ing
        self.io.flush()
        self._size = self.io.tell() // self.header_size
        return written

    def validate_chunk(self, height, chunk):
        """Validate each header in ``chunk``, whose first header is at ``height``.

        Raises InvalidHeader carrying the height of the first bad header.
        """
        previous_hash, previous_header, previous_previous_header = None, None, None
        if height > 0:
            previous_header = self[height-1]
            previous_hash = self.hash(height-1)
        if height > 1:
            previous_previous_header = self[height-2]
        chunk_target = self.get_next_chunk_target(height // 2016 - 1)
        headers = self._iterate_headers(height, chunk)
        for current_height, (current_hash, current_header) in enumerate(headers, start=height):
            block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header)
            # Report the header's own height (not the chunk start) so callers
            # can tell how much of the chunk was valid.
            self.validate_header(current_height, current_hash, current_header, previous_hash, block_target)
            previous_previous_header = previous_header
            previous_header = current_header
            previous_hash = current_hash

    def validate_header(self, height: int, current_hash: bytes,
                        header: dict, previous_hash: bytes, target: ArithUint256):
        """Check one header against its predecessor hash and the target.

        With ``previous_hash`` of None only the genesis hash is checked.
        Difficulty checks are skipped when ``validate_difficulty`` is False.
        Raises InvalidHeader on any mismatch.
        """
        if previous_hash is None:
            if self.genesis_hash is not None and self.genesis_hash != current_hash:
                raise InvalidHeader(
                    height, f"genesis header doesn't match: {current_hash.decode()} "
                            f"vs expected {self.genesis_hash.decode()}")
            return

        if header['prev_block_hash'] != previous_hash:
            raise InvalidHeader(
                height, "previous hash mismatch: {} vs expected {}".format(
                    header['prev_block_hash'].decode(), previous_hash.decode())
            )

        if self.validate_difficulty:
            if header['bits'] != target.compact:
                raise InvalidHeader(
                    height, "bits mismatch: {} vs expected {}".format(
                        header['bits'], target.compact)
                )
            proof_of_work = self.get_proof_of_work(current_hash)
            if proof_of_work > target:
                raise InvalidHeader(
                    height, f"insufficient proof of work: {proof_of_work.value} vs target {target.value}"
                )

    async def repair(self):
        """Scan the store and truncate it at the first broken hash link.

        On a mismatch at height ``h`` the store is cut so that header
        ``h - 1`` and everything after it are removed.
        """
        previous_header_hash = fail = None
        batch_size = 36
        # Iterate over the header count rather than the tip height so the
        # last header is also checked when the tip height is a multiple of
        # batch_size (and a lone genesis header is checked at all).
        for start_height in range(0, len(self), batch_size):
            self.io.seek(self.header_size * start_height)
            headers = self.io.read(self.header_size*batch_size)
            if len(headers) % self.header_size != 0:
                # Ignore a trailing partial header.
                headers = headers[:(len(headers) // self.header_size) * self.header_size]
            for header_hash, header in self._iterate_headers(start_height, headers):
                height = header['block_height']
                if height:
                    if header['prev_block_hash'] != previous_header_hash:
                        fail = True
                else:
                    if header_hash != self.genesis_hash:
                        fail = True
                if fail:
                    log.warning("Header file corrupted at height %s, truncating it.", height - 1)
                    self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET)
                    self.io.truncate()
                    self.io.flush()
                    self._size = None
                    return
                previous_header_hash = header_hash

    @classmethod
    def get_proof_of_work(cls, header_hash: bytes):
        """Return the proof-of-work hash of ``header_hash`` as an ArithUint256."""
        return ArithUint256(int(b'0x' + cls.header_hash_to_pow_hash(header_hash), 16))

    def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
        """Split ``headers`` at ``chunk_size`` boundaries.

        Yields ``(start_height, raw_bytes)`` pairs; the first piece ends at
        the next chunk boundary after ``height``.
        """
        assert len(headers) % self.header_size == 0, f"{len(headers)} {len(headers)%self.header_size}"
        start = 0
        end = (self.chunk_size - height % self.chunk_size) * self.header_size
        while start < end:
            yield height + (start // self.header_size), headers[start:end]
            start = end
            end = min(len(headers), end + self.chunk_size * self.header_size)

    def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]:
        """Yield ``(hash, header_dict)`` for each raw header in ``headers``."""
        assert len(headers) % self.header_size == 0, len(headers)
        for idx in range(len(headers) // self.header_size):
            start, end = idx * self.header_size, (idx + 1) * self.header_size
            header = headers[start:end]
            yield self.hash_header(header), self.deserialize(height+idx, header)

    @property
    def claim_trie_root(self):
        """The 'claim_trie_root' field of the tip header."""
        return self[self.height]['claim_trie_root']

    @staticmethod
    def header_hash_to_pow_hash(header_hash: bytes):
        """Derive the hex proof-of-work hash from a hex header hash.

        The byte-reversed header hash is SHA-512'd, each half of that digest
        is RIPEMD-160'd, the two results are concatenated and
        double-SHA256'd, and the result is byte-reversed and hex-encoded.
        """
        header_hash_bytes = unhexlify(header_hash)[::-1]
        h = sha512(header_hash_bytes)
        pow_hash = double_sha256(
            ripemd160(h[:len(h) // 2]) +
            ripemd160(h[len(h) // 2:])
        )
        return hexlify(pow_hash[::-1])


class UnvalidatedHeaders(Headers):
    """Headers variant with difficulty validation turned off."""

    validate_difficulty = False
    max_target = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
    genesis_hash = b'6e3fcf1299d4ec5d79c3a4c91d624a4acf9e2e173d95a1a0504f677669687556'