import os
import logging
from io import BytesIO
from typing import Optional, Iterator, Tuple
from binascii import hexlify

from twisted.internet import threads, defer

from torba.util import ArithUint256
from torba.hash import double_sha256

log = logging.getLogger(__name__)


class InvalidHeader(Exception):
    """Raised when a header fails validation.

    Attributes:
        message: human readable description of the failure.
        height: height that was passed to the validator for the failing header.
    """

    def __init__(self, height, message):
        super().__init__(message)
        self.message = message
        self.height = height


class BaseHeaders:
    """Fixed-size header chain stored in a file-like object.

    Headers are stored back to back, ``header_size`` bytes each, so the header
    at a given height lives at byte offset ``height * header_size``.

    Subclasses are expected to provide the class attributes declared below and
    to implement :meth:`serialize` and :meth:`deserialize`.
    """

    # Size in bytes of one serialized header.
    header_size: int
    # Number of headers validated and written per chunk in connect().
    chunk_size: int

    # Value wrapped in ArithUint256 by the default get_next_chunk_target().
    max_target: int
    # Expected hash of the first header; None disables the genesis check.
    genesis_hash: Optional[bytes]
    # Not used by this base class. NOTE(review): presumably consumed by
    # subclasses that implement difficulty retargeting -- confirm.
    target_timespan: int

    # When False, validate_header() skips the bits and proof-of-work checks.
    validate_difficulty: bool = True

    def __init__(self, path) -> None:
        """Prepare the chain for *path*.

        The special path ``':memory:'`` uses an in-memory ``BytesIO``
        immediately; for any other path ``self.io`` is only created by
        :meth:`open`.
        """
        if path == ':memory:':
            self.io = BytesIO()
        self.path = path
        # Cached number of headers; None means "recompute on next len()".
        self._size: Optional[int] = None
        # Serializes concurrent connect() calls.
        self._header_connect_lock = defer.DeferredLock()

    def open(self):
        """Open the backing file (no-op for ``':memory:'``).

        Returns a Deferred that has already fired with True.
        """
        if self.path != ':memory:':
            # 'a+b': create the file if missing, never truncate existing data.
            self.io = open(self.path, 'a+b')
        return defer.succeed(True)

    def close(self):
        """Close the backing storage; returns an already fired Deferred."""
        self.io.close()
        return defer.succeed(True)

    @staticmethod
    def serialize(header: dict) -> bytes:
        """Convert a header dict to its raw bytes (subclass responsibility)."""
        raise NotImplementedError

    @staticmethod
    def deserialize(height, header):
        """Convert raw header bytes to a dict (subclass responsibility)."""
        raise NotImplementedError

    def get_next_chunk_target(self, chunk: int) -> ArithUint256:
        """Return the target for *chunk*; this default always uses max_target."""
        return ArithUint256(self.max_target)

    @staticmethod
    def get_next_block_target(chunk_target: ArithUint256, previous: Optional[dict],
                              current: Optional[dict]) -> ArithUint256:
        """Return the target for the next block.

        This default ignores *previous* and *current* and returns
        *chunk_target* unchanged.
        """
        return chunk_target

    def __len__(self) -> int:
        """Number of complete headers currently stored (cached in _size)."""
        if self._size is None:
            self._size = self.io.seek(0, os.SEEK_END) // self.header_size
        return self._size

    def __bool__(self):
        # Without this, an empty chain (len() == 0) would evaluate as falsy.
        return True

    def __getitem__(self, height) -> dict:
        """Return the deserialized header stored at *height*."""
        assert not isinstance(height, slice), \
            "Slicing of header chain has not been implemented yet."
        return self.deserialize(height, self.get_raw_header(height))

    def get_raw_header(self, height) -> bytes:
        """Read the ``header_size`` bytes stored at *height*."""
        self.io.seek(height * self.header_size, os.SEEK_SET)
        return self.io.read(self.header_size)

    @property
    def height(self) -> int:
        """Height of the last stored header (-1 when nothing is stored)."""
        return len(self)-1

    def hash(self, height=None) -> bytes:
        """Return the hash of the header at *height* (default: the last one)."""
        return self.hash_header(
            self.get_raw_header(height if height is not None else self.height)
        )

    @staticmethod
    def hash_header(header: bytes) -> bytes:
        """Return the hex-encoded, byte-reversed double SHA-256 of *header*.

        A ``None`` header yields 64 ASCII ``'0'`` characters.
        """
        if header is None:
            return b'0' * 64
        return hexlify(double_sha256(header)[::-1])

    @defer.inlineCallbacks
    def connect(self, start: int, headers: bytes):
        """Validate *headers* and write them to storage beginning at *start*.

        The data is processed chunk by chunk. When a chunk contains an invalid
        header, only the headers of that chunk that precede the invalid one are
        written and processing stops. Anything previously stored past the last
        written header is truncated.

        Returns a Deferred firing with the number of headers written.
        """
        added = 0
        bail = False
        yield self._header_connect_lock.acquire()
        try:
            for height, chunk in self._iterate_chunks(start, headers):
                try:
                    # validate_chunk() is CPU bound and reads previous chunks from file system
                    yield threads.deferToThread(self.validate_chunk, height, chunk)
                except InvalidHeader as e:
                    bail = True
                    log.warning(
                        "rejected header at height %s: %s", e.height, e.message
                    )
                    # e.height is the height of the first invalid header, so
                    # exactly (e.height - height) headers of this chunk passed
                    # validation. Clamp at zero in case an overriding validator
                    # reports a height below the start of the chunk.
                    valid_count = max(e.height - height, 0)
                    chunk = chunk[:valid_count * self.header_size]
                written = 0
                if chunk:
                    self.io.seek(height * self.header_size, os.SEEK_SET)
                    written = self.io.write(chunk) // self.header_size
                    self.io.truncate()
                    # .seek()/.write()/.truncate() might also .flush() when needed
                    # the goal here is mainly to ensure we're definitely flush()'ing
                    yield threads.deferToThread(self.io.flush)
                    # Force __len__ to recompute from the file size.
                    self._size = None
                added += written
                if bail:
                    break
        finally:
            self._header_connect_lock.release()
        defer.returnValue(added)

    def validate_chunk(self, height, chunk):
        """Validate every header in *chunk*, whose first header is at *height*.

        Raises:
            InvalidHeader: for the first header that fails validation; its
                ``height`` attribute is the height of that header.
        """
        previous_hash, previous_header, previous_previous_header = None, None, None
        if height > 0:
            previous_header = self[height-1]
            previous_hash = self.hash(height-1)
        if height > 1:
            previous_previous_header = self[height-2]
        # NOTE(review): 2016 is hard-coded here instead of self.chunk_size --
        # confirm this is intended before changing it.
        chunk_target = self.get_next_chunk_target(height // 2016 - 1)
        numbered_headers = enumerate(self._iterate_headers(height, chunk), height)
        for current_height, (current_hash, current_header) in numbered_headers:
            block_target = self.get_next_block_target(
                chunk_target, previous_previous_header, previous_header
            )
            # Pass the header's own height so a failure identifies the exact
            # header and connect() can keep the valid headers before it.
            self.validate_header(
                current_height, current_hash, current_header, previous_hash, block_target
            )
            previous_previous_header = previous_header
            previous_header = current_header
            previous_hash = current_hash

    def validate_header(self, height: int, current_hash: bytes,
                        header: dict, previous_hash: bytes, target: ArithUint256):
        """Validate a single header against its predecessor and *target*.

        When *previous_hash* is None the header is treated as the genesis
        header and only compared with ``genesis_hash`` (if one is set).

        Raises:
            InvalidHeader: on genesis mismatch, previous-hash mismatch, or,
                when ``validate_difficulty`` is set, on a bits mismatch or
                insufficient proof of work.
        """
        if previous_hash is None:
            if self.genesis_hash is not None and self.genesis_hash != current_hash:
                raise InvalidHeader(
                    height, "genesis header doesn't match: {} vs expected {}".format(
                        current_hash.decode(), self.genesis_hash.decode())
                )
            return

        if header['prev_block_hash'] != previous_hash:
            raise InvalidHeader(
                height, "previous hash mismatch: {} vs expected {}".format(
                    header['prev_block_hash'].decode(), previous_hash.decode())
            )

        if self.validate_difficulty:

            if header['bits'] != target.compact:
                raise InvalidHeader(
                    height, "bits mismatch: {} vs expected {}".format(
                        header['bits'], target.compact)
                )

            proof_of_work = self.get_proof_of_work(current_hash)
            if proof_of_work > target:
                raise InvalidHeader(
                    height, "insufficient proof of work: {} vs target {}".format(
                        proof_of_work.value, target.value)
                )

    @staticmethod
    def get_proof_of_work(header_hash: bytes) -> ArithUint256:
        """Interpret the hex-encoded *header_hash* as an ArithUint256."""
        return ArithUint256(int(b'0x' + header_hash, 16))

    def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
        """Split *headers* into ``(start_height, raw_bytes)`` chunks.

        The first chunk ends at the next multiple of ``chunk_size`` after
        *height*; each following chunk holds at most ``chunk_size`` headers.
        Empty *headers* yields nothing.
        """
        assert len(headers) % self.header_size == 0
        total = len(headers)
        start = 0
        # Clamp to the available data so that empty input produces no chunk.
        end = min(total, (self.chunk_size - height % self.chunk_size) * self.header_size)
        while start < end:
            yield height + (start // self.header_size), headers[start:end]
            start = end
            end = min(total, end + self.chunk_size * self.header_size)

    def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]:
        """Yield ``(hash, deserialized_header)`` for each header in *headers*.

        *height* is the height of the first header; it is incremented for each
        following header when calling :meth:`deserialize`.
        """
        assert len(headers) % self.header_size == 0
        for idx in range(len(headers) // self.header_size):
            start, end = idx * self.header_size, (idx + 1) * self.header_size
            header = headers[start:end]
            yield self.hash_header(header), self.deserialize(height+idx, header)