# lbry-sdk/torba/client/baseheader.py
# 2019-06-19 01:49:47 -04:00
#
# 189 lines
# 6.8 KiB
# Python

import os
import asyncio
import logging
from io import BytesIO
from typing import Optional, Iterator, Tuple
from binascii import hexlify
from torba.client.util import ArithUint256
from torba.client.hash import double_sha256
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class InvalidHeader(Exception):
    """Raised when a block header fails validation.

    Attributes:
        height: the height value supplied by the code raising the error.
        message: human-readable description; also used as the exception text.
    """

    def __init__(self, height, message):
        # Record the extra context first, then let Exception store the message.
        self.height = height
        self.message = message
        super().__init__(message)
class BaseHeaders:
    """Store of fixed-size serialized block headers kept back to back in one stream.

    The stream is a binary file at ``path``, or an in-memory ``BytesIO`` when
    ``path`` is ``':memory:'``.  The header at height ``h`` occupies the bytes
    ``[h * header_size, (h + 1) * header_size)`` of the stream.

    Subclasses are expected to set the class-level constants declared below
    and to implement :meth:`serialize` and :meth:`deserialize`.
    """

    header_size: int  # size in bytes of one serialized header
    chunk_size: int  # number of headers per validation/write batch
    max_target: int  # value wrapped by the default get_next_chunk_target()
    genesis_hash: Optional[bytes]  # expected hash of the first header; None skips that check
    target_timespan: int  # declared for subclasses; not referenced in this class
    validate_difficulty: bool = True  # when False, the bits / proof-of-work checks are skipped

    def __init__(self, path) -> None:
        """Remember ``path``; file-backed stores are only opened by :meth:`open`."""
        # Always define ``io`` so that close() before open() is harmless.
        self.io = BytesIO() if path == ':memory:' else None
        self.path = path
        # Cached header count; reset to None whenever the stream is modified.
        self._size: Optional[int] = None
        # Serializes connect() calls so chunks are validated and written in order.
        self._header_connect_lock = asyncio.Lock()

    async def open(self):
        """Open the backing file, creating it if it does not exist yet.

        Does nothing for the ``':memory:'`` store, whose buffer is created
        in ``__init__``.
        """
        if self.path == ':memory:':
            return
        try:
            self.io = open(self.path, 'r+b')
        except FileNotFoundError:
            # First use: create the file ('w+b' would wipe an existing one,
            # hence it is only used when 'r+b' found nothing).
            self.io = open(self.path, 'w+b')

    async def close(self):
        """Close the backing stream, if one has been opened."""
        if self.io is not None:
            self.io.close()

    @staticmethod
    def serialize(header: dict) -> bytes:
        """Convert a header dict into its raw bytes; must be provided by subclasses."""
        raise NotImplementedError

    @staticmethod
    def deserialize(height, header):
        """Convert raw header bytes at ``height`` into a dict; must be provided by subclasses."""
        raise NotImplementedError

    def get_next_chunk_target(self, chunk: int) -> 'ArithUint256':
        """Return the target for ``chunk``; this default always returns ``max_target``."""
        return ArithUint256(self.max_target)

    @staticmethod
    def get_next_block_target(chunk_target: 'ArithUint256', previous: Optional[dict],
                              current: Optional[dict]) -> 'ArithUint256':
        """Return the per-block target; this default just returns ``chunk_target``."""
        return chunk_target

    def __len__(self) -> int:
        """Return the number of complete headers in the stream (cached)."""
        if self._size is None:
            # seek() returns the absolute end position, i.e. the stream size in bytes.
            self._size = self.io.seek(0, os.SEEK_END) // self.header_size
        return self._size

    def __bool__(self):
        """Always truthy, so an instance holding zero headers is not treated as false."""
        return True

    def __getitem__(self, height) -> dict:
        """Return the deserialized header stored at ``height``."""
        assert not isinstance(height, slice), \
            "Slicing of header chain has not been implemented yet."
        return self.deserialize(height, self.get_raw_header(height))

    def get_raw_header(self, height) -> bytes:
        """Read and return the raw bytes of the header stored at ``height``."""
        self.io.seek(height * self.header_size, os.SEEK_SET)
        return self.io.read(self.header_size)

    @property
    def height(self) -> int:
        """Height of the last stored header (``-1`` when the store is empty)."""
        return len(self) - 1

    def hash(self, height=None) -> bytes:
        """Return the hash of the header at ``height`` (default: the last stored header)."""
        if height is None:
            height = self.height
        return self.hash_header(self.get_raw_header(height))

    @staticmethod
    def hash_header(header: bytes) -> bytes:
        """Return the hex-encoded, byte-reversed ``double_sha256`` of ``header``.

        A ``None`` header yields 64 ASCII ``'0'`` characters.
        """
        if header is None:
            return b'0' * 64
        return hexlify(double_sha256(header)[::-1])

    async def connect(self, start: int, headers: bytes) -> int:
        """Validate ``headers`` and write them to the store beginning at height ``start``.

        The data is processed one chunk at a time.  Processing stops at the
        first header that fails validation; the valid headers preceding it in
        the same chunk are still written.

        Returns:
            The number of headers written.
        """
        added = 0
        loop = asyncio.get_running_loop()
        async with self._header_connect_lock:
            for height, chunk in self._iterate_chunks(start, headers):
                invalid = False
                try:
                    # validate_chunk() is CPU bound and reads previous chunks from file system
                    await loop.run_in_executor(None, self.validate_chunk, height, chunk)
                except InvalidHeader as e:
                    invalid = True
                    # e.height is the height of the failing header, so the
                    # headers in [height, e.height) passed validation.  Clamp
                    # at zero in case an override reports a lower height.
                    valid_count = max(e.height - height, 0)
                    chunk = chunk[:valid_count * self.header_size]
                if chunk:
                    self.io.seek(height * self.header_size, os.SEEK_SET)
                    written = self.io.write(chunk) // self.header_size
                    # Drop anything previously stored past the newly written headers.
                    self.io.truncate()
                    # .seek()/.write()/.truncate() might also .flush() when needed
                    # the goal here is mainly to ensure we're definitely flush()'ing
                    await loop.run_in_executor(None, self.io.flush)
                    self._size = None
                    added += written
                if invalid:
                    break
        return added

    def validate_chunk(self, height, chunk):
        """Validate every header in ``chunk``, whose first header sits at ``height``.

        Each header is checked against the one before it; the headers just
        below ``height`` are read from the store to seed that comparison.

        Raises:
            InvalidHeader: from :meth:`validate_header`, carrying the height
                of the first header that failed.
        """
        previous_hash, previous_header, previous_previous_header = None, None, None
        if height > 0:
            previous_header = self[height - 1]
            previous_hash = self.hash(height - 1)
        if height > 1:
            previous_previous_header = self[height - 2]
        # NOTE(review): 2016 is hard-coded here rather than using self.chunk_size;
        # confirm this is intended for subclasses with a different chunk size.
        chunk_target = self.get_next_chunk_target(height // 2016 - 1)
        for offset, (current_hash, current_header) in enumerate(self._iterate_headers(height, chunk)):
            block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header)
            # Pass the header's own height so a failure identifies the exact header.
            self.validate_header(height + offset, current_hash, current_header, previous_hash, block_target)
            previous_previous_header = previous_header
            previous_header = current_header
            previous_hash = current_hash

    def validate_header(self, height: int, current_hash: bytes,
                        header: dict, previous_hash: bytes, target: 'ArithUint256'):
        """Check one header against its predecessor's hash and the expected target.

        With no predecessor (``previous_hash is None``) only the genesis hash
        is compared, and only when ``genesis_hash`` is set.  Otherwise the
        ``prev_block_hash`` link is checked and, when ``validate_difficulty``
        is true, the ``bits`` field and the proof of work as well.

        Raises:
            InvalidHeader: if any check fails; ``height`` is attached to it.
        """
        if previous_hash is None:
            if self.genesis_hash is not None and self.genesis_hash != current_hash:
                raise InvalidHeader(
                    height, "genesis header doesn't match: {} vs expected {}".format(
                        current_hash.decode(), self.genesis_hash.decode())
                )
            return

        if header['prev_block_hash'] != previous_hash:
            raise InvalidHeader(
                height, "previous hash mismatch: {} vs expected {}".format(
                    header['prev_block_hash'].decode(), previous_hash.decode())
            )

        if self.validate_difficulty:
            if header['bits'] != target.compact:
                raise InvalidHeader(
                    height, "bits mismatch: {} vs expected {}".format(
                        header['bits'], target.compact)
                )
            proof_of_work = self.get_proof_of_work(current_hash)
            if proof_of_work > target:
                raise InvalidHeader(
                    height, "insufficient proof of work: {} vs target {}".format(
                        proof_of_work.value, target.value)
                )

    @staticmethod
    def get_proof_of_work(header_hash: bytes) -> 'ArithUint256':
        """Interpret the hex ``header_hash`` as an integer wrapped in ``ArithUint256``."""
        return ArithUint256(int(b'0x' + header_hash, 16))

    def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
        """Split ``headers`` into pieces aligned to ``chunk_size`` boundaries.

        Yields ``(start_height, raw_bytes)`` pairs.  The first piece only runs
        up to the next chunk boundary after ``height``; later pieces hold up
        to ``chunk_size`` headers each.
        """
        assert len(headers) % self.header_size == 0
        start = 0
        # Bytes remaining until the chunk boundary that follows ``height``.
        end = (self.chunk_size - height % self.chunk_size) * self.header_size
        while start < end:
            yield height + (start // self.header_size), headers[start:end]
            start = end
            end = min(len(headers), end + self.chunk_size * self.header_size)

    def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]:
        """Yield ``(hash, deserialized_header)`` for each header in ``headers``.

        ``height`` is the height of the first header; it increases by one for
        each subsequent header passed to :meth:`deserialize`.
        """
        assert len(headers) % self.header_size == 0
        for idx in range(len(headers) // self.header_size):
            start, end = idx * self.header_size, (idx + 1) * self.header_size
            header = headers[start:end]
            yield self.hash_header(header), self.deserialize(height + idx, header)