hub/scribe/common.py

367 lines
9.4 KiB
Python
Raw Normal View History

2022-03-08 17:01:19 +01:00
import hashlib
import hmac
import ipaddress
import logging
import typing
import collections
from asyncio import get_event_loop, Event
from prometheus_client import Counter
log = logging.getLogger(__name__)
_sha256 = hashlib.sha256
_sha512 = hashlib.sha512
_new_hash = hashlib.new
_new_hmac = hmac.new
HASHX_LEN = 11
CLAIM_HASH_LEN = 20
HISTOGRAM_BUCKETS = (
.005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf')
)
2022-03-08 17:01:19 +01:00
# class cachedproperty:
# def __init__(self, f):
# self.f = f
#
# def __get__(self, obj, type):
# obj = obj or type
# value = self.f(obj)
# setattr(obj, self.f.__name__, value)
# return value
def formatted_time(t, sep=' '):
"""Return a number of seconds as a string in days, hours, mins and
maybe secs."""
t = int(t)
fmts = (('{:d}d', 86400), ('{:02d}h', 3600), ('{:02d}m', 60))
parts = []
for fmt, n in fmts:
val = t // n
if parts or val:
parts.append(fmt.format(val))
t %= n
if len(parts) < 3:
parts.append(f'{t:02d}s')
return sep.join(parts)
def protocol_tuple(s):
"""Converts a protocol version number, such as "1.0" to a tuple (1, 0).
If the version number is bad, (0, ) indicating version 0 is returned."""
try:
return tuple(int(part) for part in s.split('.'))
except Exception:
return (0, )
def version_string(ptuple):
"""Convert a version tuple such as (1, 2) to "1.2".
There is always at least one dot, so (1, ) becomes "1.0"."""
while len(ptuple) < 2:
ptuple += (0, )
return '.'.join(str(p) for p in ptuple)
def protocol_version(client_req, min_tuple, max_tuple):
"""Given a client's protocol version string, return a pair of
protocol tuples:
(negotiated version, client min request)
If the request is unsupported, the negotiated protocol tuple is
None.
"""
if client_req is None:
client_min = client_max = min_tuple
else:
if isinstance(client_req, list) and len(client_req) == 2:
client_min, client_max = client_req
else:
client_min = client_max = client_req
client_min = protocol_tuple(client_min)
client_max = protocol_tuple(client_max)
result = min(client_max, max_tuple)
if result < max(client_min, min_tuple) or result == (0, ):
result = None
return result, client_min
class LRUCacheWithMetrics:
__slots__ = [
'capacity',
'cache',
'_track_metrics',
'hits',
'misses'
]
def __init__(self, capacity: int, metric_name: typing.Optional[str] = None, namespace: str = "daemon_cache"):
self.capacity = capacity
self.cache = collections.OrderedDict()
if metric_name is None:
self._track_metrics = False
self.hits = self.misses = None
else:
self._track_metrics = True
try:
self.hits = Counter(
f"{metric_name}_cache_hit_count", "Number of cache hits", namespace=namespace
)
self.misses = Counter(
f"{metric_name}_cache_miss_count", "Number of cache misses", namespace=namespace
)
except ValueError as err:
log.debug("failed to set up prometheus %s_cache_miss_count metric: %s", metric_name, err)
self._track_metrics = False
self.hits = self.misses = None
def get(self, key, default=None):
try:
value = self.cache.pop(key)
if self._track_metrics:
self.hits.inc()
except KeyError:
if self._track_metrics:
self.misses.inc()
return default
self.cache[key] = value
return value
def set(self, key, value):
try:
self.cache.pop(key)
except KeyError:
if len(self.cache) >= self.capacity:
self.cache.popitem(last=False)
self.cache[key] = value
def clear(self):
self.cache.clear()
def pop(self, key):
return self.cache.pop(key)
def __setitem__(self, key, value):
return self.set(key, value)
def __getitem__(self, item):
return self.get(item)
def __contains__(self, item) -> bool:
return item in self.cache
def __len__(self):
return len(self.cache)
def __delitem__(self, key):
self.cache.pop(key)
def __del__(self):
self.clear()
class LRUCache:
__slots__ = [
'capacity',
'cache'
]
def __init__(self, capacity: int):
self.capacity = capacity
self.cache = collections.OrderedDict()
def get(self, key, default=None):
try:
value = self.cache.pop(key)
except KeyError:
return default
self.cache[key] = value
return value
def set(self, key, value):
try:
self.cache.pop(key)
except KeyError:
if len(self.cache) >= self.capacity:
self.cache.popitem(last=False)
self.cache[key] = value
def items(self):
return self.cache.items()
def clear(self):
self.cache.clear()
def pop(self, key, default=None):
return self.cache.pop(key, default)
def __setitem__(self, key, value):
return self.set(key, value)
def __getitem__(self, item):
return self.get(item)
def __contains__(self, item) -> bool:
return item in self.cache
def __len__(self):
return len(self.cache)
def __delitem__(self, key):
self.cache.pop(key)
def __del__(self):
self.clear()
# the ipaddress module does not show these subnets as reserved
CARRIER_GRADE_NAT_SUBNET = ipaddress.ip_network('100.64.0.0/10')
IPV4_TO_6_RELAY_SUBNET = ipaddress.ip_network('192.88.99.0/24')
def is_valid_public_ipv4(address, allow_localhost: bool = False, allow_lan: bool = False):
try:
parsed_ip = ipaddress.ip_address(address)
if parsed_ip.is_loopback and allow_localhost:
return True
if allow_lan and parsed_ip.is_private:
return True
if any((parsed_ip.version != 4, parsed_ip.is_unspecified, parsed_ip.is_link_local, parsed_ip.is_loopback,
parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private)):
return False
else:
return not any((CARRIER_GRADE_NAT_SUBNET.supernet_of(ipaddress.ip_network(f"{address}/32")),
IPV4_TO_6_RELAY_SUBNET.supernet_of(ipaddress.ip_network(f"{address}/32"))))
except (ipaddress.AddressValueError, ValueError):
return False
def sha256(x):
"""Simple wrapper of hashlib sha256."""
return _sha256(x).digest()
def ripemd160(x):
"""Simple wrapper of hashlib ripemd160."""
h = _new_hash('ripemd160')
h.update(x)
return h.digest()
def double_sha256(x):
"""SHA-256 of SHA-256, as used extensively in bitcoin."""
return sha256(sha256(x))
def hmac_sha512(key, msg):
"""Use SHA-512 to provide an HMAC."""
return _new_hmac(key, msg, _sha512).digest()
def hash160(x):
"""RIPEMD-160 of SHA-256. Used to make bitcoin addresses from pubkeys."""
return ripemd160(sha256(x))
def hash_to_hex_str(x: bytes) -> str:
"""Convert a big-endian binary hash to displayed hex string.
Display form of a binary hash is reversed and converted to hex.
"""
return x[::-1].hex()
def hex_str_to_hash(x: str) -> bytes:
"""Convert a displayed hex string to a binary hash."""
return bytes.fromhex(x)[::-1]
INVALID_REQUEST = -32600
INVALID_ARGS = -32602
class CodeMessageError(Exception):
@property
def code(self):
return self.args[0]
@property
def message(self):
return self.args[1]
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.code == other.code and self.message == other.message)
def __hash__(self):
# overridden to make the exception hashable
# see https://bugs.python.org/issue28603
return hash((self.code, self.message))
@classmethod
def invalid_args(cls, message):
return cls(INVALID_ARGS, message)
@classmethod
def invalid_request(cls, message):
return cls(INVALID_REQUEST, message)
@classmethod
def empty_batch(cls):
return cls.invalid_request('batch is empty')
class RPCError(CodeMessageError):
pass
class DaemonError(Exception):
"""Raised when the daemon returns an error in its results."""
class WarmingUpError(Exception):
"""Internal - when the daemon is warming up."""
class WorkQueueFullError(Exception):
"""Internal - when the daemon's work queue is full."""
class TaskGroup:
def __init__(self, loop=None):
self._loop = loop or get_event_loop()
self._tasks = set()
self.done = Event()
self.started = Event()
def __len__(self):
return len(self._tasks)
def add(self, coro):
task = self._loop.create_task(coro)
self._tasks.add(task)
self.started.set()
self.done.clear()
task.add_done_callback(self._remove)
return task
def _remove(self, task):
self._tasks.remove(task)
if len(self._tasks) < 1:
self.done.set()
self.started.clear()
def cancel(self):
for task in self._tasks:
task.cancel()
self.done.set()
self.started.clear()