lbry-sdk/torba/baseheader.py

197 lines
7 KiB
Python
Raw Normal View History

2018-06-11 15:33:32 +02:00
import os
import logging
from io import BytesIO
from typing import Optional, Iterator, Tuple
from binascii import hexlify
2018-06-11 15:33:32 +02:00
from twisted.internet import threads, defer
from torba.stream import StreamController
from torba.util import ArithUint256
from torba.hash import double_sha256
2018-06-11 15:33:32 +02:00
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
2018-06-11 15:33:32 +02:00
class InvalidHeader(Exception):
    """Raised when a block header fails validation.

    Attributes:
        height: chain height of the offending header.
        message: human-readable reason (also the exception's str()).
    """

    def __init__(self, height, message):
        super().__init__(message)
        self.height = height
        self.message = message
2018-06-11 15:33:32 +02:00
class BaseHeaders:
    """Chain of fixed-size block headers stored in a file or in-memory buffer.

    Header ``i`` lives at byte offset ``i * header_size``. Subclasses provide
    the class-level constants declared below and implement
    :meth:`serialize` / :meth:`deserialize`.
    """

    header_size: int  # byte length of one serialized header
    chunk_size: int   # max headers per validate/write batch in connect()

    max_target: int   # target returned by the default get_next_chunk_target()

    genesis_hash: Optional[bytes]  # expected hash of header 0; None skips the check
    target_timespan: int  # not read in this class; presumably used by subclasses
    validate_difficulty: bool = True  # False skips the bits / proof-of-work checks

    def __init__(self, path) -> None:
        """Create a header store backed by ``path``.

        ``':memory:'`` selects an in-memory ``BytesIO`` buffer immediately;
        any other path is only opened by :meth:`open`.
        """
        if path == ':memory:':
            self.io = BytesIO()
        self.path = path

        # cached header count; reset to None whenever the store is rewritten
        self._size: Optional[int] = None

        self._on_change_controller = StreamController()
        # receives the number of headers written for each chunk in connect()
        self.on_changed = self._on_change_controller.stream
        # serializes concurrent connect() calls
        self._header_connect_lock = defer.DeferredLock()

    def open(self):
        """Open the backing file for reading and in-place writing.

        The file is opened with ``r+b`` (or ``w+b`` when it does not exist
        yet) instead of ``a+b``. In append mode every write lands at the end
        of the file regardless of ``seek()`` on POSIX systems, which would
        break :meth:`connect` when it rewrites headers below the current tip.

        Returns an already-fired Deferred.
        """
        if self.path != ':memory:':
            mode = 'r+b' if os.path.exists(self.path) else 'w+b'
            self.io = open(self.path, mode)
        return defer.succeed(True)

    def close(self):
        """Close the underlying buffer/file; returns an already-fired Deferred."""
        self.io.close()
        return defer.succeed(True)

    @staticmethod
    def serialize(header: dict) -> bytes:
        """Encode a header dict to raw bytes; must be implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def deserialize(height, header):
        """Decode raw header bytes at ``height``; must be implemented by subclasses."""
        raise NotImplementedError

    def get_next_chunk_target(self, chunk: int) -> ArithUint256:
        """Return the base target for ``chunk``; by default always ``max_target``."""
        return ArithUint256(self.max_target)

    @staticmethod
    def get_next_block_target(chunk_target: ArithUint256, previous: Optional[dict],
                              current: Optional[dict]) -> ArithUint256:
        """Return the target for the next block; by default ``chunk_target`` unchanged."""
        return chunk_target

    def __len__(self) -> int:
        """Number of complete headers stored (cached until the store changes)."""
        if self._size is None:
            self._size = self.io.seek(0, os.SEEK_END) // self.header_size
        return self._size

    def __bool__(self):
        # Always truthy, even with zero headers, so ``if headers:`` does not
        # fall back to __len__().
        return True

    def __getitem__(self, height) -> dict:
        """Return the deserialized header at ``height`` (slices unsupported)."""
        assert not isinstance(height, slice), \
            "Slicing of header chain has not been implemented yet."
        return self.deserialize(height, self.get_raw_header(height))

    def get_raw_header(self, height) -> bytes:
        """Read the raw bytes of header ``height``.

        Past the end of the store the result is shorter than ``header_size``
        (possibly empty), as with any file read at EOF.
        """
        self.io.seek(height * self.header_size, os.SEEK_SET)
        return self.io.read(self.header_size)

    @property
    def height(self) -> int:
        """Height of the last stored header (-1 when the store is empty)."""
        return len(self)-1

    def hash(self, height=None) -> bytes:
        """Return the hex hash of the header at ``height`` (default: the tip).

        ``None`` is tested explicitly: ``height=0`` must address the genesis
        header, not be treated as "missing" and replaced by the tip.
        """
        if height is None:
            height = self.height
        return self.hash_header(self.get_raw_header(height))

    @staticmethod
    def hash_header(header: bytes) -> bytes:
        """Return the byte-reversed ``double_sha256`` digest of ``header`` as hex bytes.

        ``None`` maps to 64 ASCII ``'0'`` characters.
        """
        if header is None:
            return b'0' * 64
        return hexlify(double_sha256(header)[::-1])

    @defer.inlineCallbacks
    def connect(self, start: int, headers: bytes):
        """Validate ``headers`` and write them starting at height ``start``.

        Headers are processed in chunks (see :meth:`_iterate_chunks`). Each
        chunk is validated in a worker thread. If a header is invalid, only
        the headers before it in that chunk are written and processing stops.
        Writing a chunk truncates everything stored after it. Concurrent
        calls are serialized by a DeferredLock.

        Fires with the total number of headers written.
        """
        added = 0
        bail = False
        yield self._header_connect_lock.acquire()
        try:
            for height, chunk in self._iterate_chunks(start, headers):
                try:
                    # validate_chunk() is CPU bound on large chunks
                    yield threads.deferToThread(self.validate_chunk, height, chunk)
                except InvalidHeader as e:
                    bail = True
                    # keep only the headers that validated before the failing one
                    valid_count = max(0, e.height - height)
                    chunk = chunk[:valid_count * self.header_size]
                written = 0
                if chunk:
                    self.io.seek(height * self.header_size, os.SEEK_SET)
                    written = self.io.write(chunk) // self.header_size
                    # drop anything stored past the newly written headers
                    self.io.truncate()
                    # .seek()/.write()/.truncate() might also .flush() when needed
                    # the goal here is mainly to ensure we're definitely flush()'ing
                    yield threads.deferToThread(self.io.flush)
                    self._size = None
                self._on_change_controller.add(written)
                added += written
                if bail:
                    break
        finally:
            self._header_connect_lock.release()
        defer.returnValue(added)

    def validate_chunk(self, height, chunk):
        """Validate ``chunk``, whose first header is at ``height``.

        The first header is checked against the stored header(s) just before
        it. Each later header is checked against its predecessor in the chunk.

        Raises:
            InvalidHeader: carrying the height of the first header that fails.
        """
        previous_hash, previous_header, previous_previous_header = None, None, None
        if height > 0:
            previous_header = self[height-1]
            previous_hash = self.hash(height-1)
        if height > 1:
            previous_previous_header = self[height-2]
        # NOTE(review): 2016 looks like Bitcoin's retarget interval rather than
        # self.chunk_size - confirm it is intended for every subclass.
        chunk_target = self.get_next_chunk_target(height // 2016 - 1)
        headers = self._iterate_headers(height, chunk)
        # track each header's own height so InvalidHeader reports the real failure point
        for current_height, (current_hash, current_header) in enumerate(headers, start=height):
            block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header)
            self.validate_header(current_height, current_hash, current_header, previous_hash, block_target)
            previous_previous_header = previous_header
            previous_header = current_header
            previous_hash = current_hash

    def validate_header(self, height: int, current_hash: bytes,
                        header: dict, previous_hash: bytes, target: ArithUint256):
        """Check a single header.

        With no ``previous_hash`` (the very first header), only the genesis
        hash is compared, and only when ``genesis_hash`` is set. Otherwise the
        header must link to ``previous_hash``. When ``validate_difficulty`` is
        set, its ``bits`` must equal ``target.compact`` and its proof of work
        must not exceed ``target``.

        Raises:
            InvalidHeader: for the first failed check, tagged with ``height``.
        """
        if previous_hash is None:
            if self.genesis_hash is not None and self.genesis_hash != current_hash:
                raise InvalidHeader(
                    height, "genesis header doesn't match: {} vs expected {}".format(
                        current_hash.decode(), self.genesis_hash.decode())
                )
            return
        if header['prev_block_hash'] != previous_hash:
            raise InvalidHeader(
                height, "previous hash mismatch: {} vs expected {}".format(
                    header['prev_block_hash'].decode(), previous_hash.decode())
            )
        if self.validate_difficulty:
            if header['bits'] != target.compact:
                raise InvalidHeader(
                    height, "bits mismatch: {} vs expected {}".format(
                        header['bits'], target.compact)
                )
            proof_of_work = self.get_proof_of_work(current_hash)
            if proof_of_work > target:
                raise InvalidHeader(
                    height, "insufficient proof of work: {} vs target {}".format(
                        proof_of_work.value, target.value)
                )

    @staticmethod
    def get_proof_of_work(header_hash: bytes) -> ArithUint256:
        """Interpret a hex-encoded header hash as a big integer for target comparison."""
        return ArithUint256(int(b'0x' + header_hash, 16))

    def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
        """Split ``headers`` (starting at ``height``) into ``(height, bytes)`` pieces.

        The first piece runs up to the next multiple of ``chunk_size``; later
        pieces hold at most ``chunk_size`` headers. Empty input yields nothing.
        """
        assert len(headers) % self.header_size == 0
        start = 0
        # clamp to the input length so empty input does not yield an empty chunk
        end = min(len(headers), (self.chunk_size - height % self.chunk_size) * self.header_size)
        while start < end:
            yield height + (start // self.header_size), headers[start:end]
            start = end
            end = min(len(headers), end + self.chunk_size * self.header_size)

    def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]:
        """Yield ``(hash, deserialized header)`` for each header in ``headers``."""
        assert len(headers) % self.header_size == 0
        for idx in range(len(headers) // self.header_size):
            start, end = idx * self.header_size, (idx + 1) * self.header_size
            header = headers[start:end]
            yield self.hash_header(header), self.deserialize(height+idx, header)