lbry-sdk/lbry/utils.py

495 lines
15 KiB
Python
Raw Normal View History

import base64
import codecs
import datetime
2015-08-20 17:27:15 +02:00
import random
import socket
2020-05-03 03:23:17 +02:00
import time
2016-12-30 19:35:17 +01:00
import string
import sys
2018-07-21 20:12:29 +02:00
import json
2019-01-22 23:44:17 +01:00
import typing
import asyncio
2019-02-28 18:40:11 +01:00
import ssl
2018-06-07 18:18:07 +02:00
import logging
import ipaddress
2019-02-28 18:40:11 +01:00
import contextlib
import functools
import collections
2020-01-03 06:52:48 +01:00
import hashlib
2020-01-03 05:15:33 +01:00
import pkg_resources
import certifi
import aiohttp
2020-12-23 22:37:31 +01:00
from prometheus_client import Counter
2019-06-21 02:55:47 +02:00
from lbry.schema.claim import Claim
2020-12-23 22:37:31 +01:00
2018-06-07 18:18:07 +02:00
# Module-level logger named after this module; the helpers in this file log through it.
log = logging.getLogger(__name__)
2016-10-05 21:16:20 +02:00
# Time lookups are funnelled through these module-level wrappers so a test can
# replace one name here instead of patching the datetime module.
def now():
    """Current local time as a naive ``datetime``."""
    return datetime.datetime.now()


def utcnow():
    """Current UTC time as a naive ``datetime`` (no tzinfo attached)."""
    return datetime.datetime.utcnow()


def isonow():
    """ISO-8601 string of the current UTC time, suffixed with ``Z``."""
    return f"{utcnow().isoformat()}Z"


def today():
    """Current local date and time via ``datetime.datetime.today()``."""
    return datetime.datetime.today()


def timedelta(**kwargs):
    """Build a ``datetime.timedelta`` from keyword arguments (days=, seconds=, ...)."""
    return datetime.timedelta(**kwargs)


def datetime_obj(*args, **kwargs):
    """Construct a ``datetime.datetime`` with the given arguments."""
    return datetime.datetime(*args, **kwargs)
2020-01-03 05:44:41 +01:00
def get_lbry_hash_obj():
    """Return a new, empty SHA-384 hash object (the hash used by ``generate_id``)."""
    return hashlib.sha384()


def generate_id(num=None):
    """Return a 48-byte SHA-384 digest to use as an identifier.

    With *num* given, the digest is of ``str(num)`` and is therefore
    deterministic. Without it, 512 bits from the ``random`` module are hashed.
    """
    hasher = get_lbry_hash_obj()
    seed = num if num is not None else random.getrandbits(512)
    hasher.update(str(seed).encode())
    return hasher.digest()
2020-01-03 05:15:33 +01:00
def version_is_greater_than(version_a, version_b):
    """True when *version_a* parses as a newer version than *version_b*.

    Both strings are parsed with ``pkg_resources.parse_version`` and compared.
    """
    parse = pkg_resources.parse_version
    newer, older = parse(version_a), parse(version_b)
    return newer > older
def rot13(some_str):
    """Return *some_str* with the ROT13 letter substitution applied."""
    return codecs.encode(some_str, 'rot_13')


def deobfuscate(obfustacated):
    """Undo ``obfuscate``: ROT13 the text, base64-decode it, and decode the bytes as UTF-8."""
    return base64.b64decode(rot13(obfustacated)).decode()


def obfuscate(plain):
    """Lightly scramble *plain* (this is not encryption): base64-encode, then ROT13.

    :param plain: ``bytes``, or ``str`` which is UTF-8 encoded first so that
        ``deobfuscate(obfuscate(text)) == text`` also holds for text input.
    :return: the obfuscated text as ``str``.
    """
    if isinstance(plain, str):
        # b64encode only accepts bytes-like input; previously str raised TypeError.
        plain = plain.encode()
    return rot13(base64.b64encode(plain).decode())
def check_connection(server="lbry.com", port=80, timeout=5) -> bool:
    """Report whether a TCP connection to ``server:port`` can be opened.

    The hostname is resolved first. If resolution fails, DNS is assumed to be
    the culprit and reachability is probed instead by connecting straight to
    8.8.8.8:53. Any other socket error counts as "no connection".
    """
    log.debug('Checking connection to %s:%s', server, port)

    def _probe(host, host_port):
        # Open and immediately close a TCP connection; True if that worked.
        try:
            socket.create_connection((host, host_port), timeout).close()
        except OSError:
            return False
        return True

    try:
        server = socket.gethostbyname(server)
        socket.create_connection((server, port), timeout).close()
    except (socket.gaierror, socket.herror):
        log.debug("Failed to connect to %s:%s. Unable to resolve domain. Trying to bypass DNS",
                  server, port)
        return _probe("8.8.8.8", 53)
    except OSError:
        return False
    return True
2016-12-30 19:35:17 +01:00
def random_string(length=10, chars=string.ascii_lowercase):
    """Return *length* characters, each picked independently from *chars* via ``random.choice``."""
    picked = []
    for _ in range(length):
        picked.append(random.choice(chars))
    return ''.join(picked)


def short_hash(hash_str):
    """Return the first six characters of *hash_str*."""
    return hash_str[0:6]
2017-03-09 16:39:17 +01:00
def get_sd_hash(stream_info):
    """Extract the stream descriptor (sd) hash from *stream_info*.

    :param stream_info: a ``Claim`` (read from ``stream.source.sd_hash``) or a
        dict, read from ``['claim']['value']['stream']['source']['source']``.
    :return: the sd hash. Returns ``None`` for empty input. When the dict path
        yields nothing, a warning is logged and the (falsy) lookup result is returned.
    """
    if not stream_info:
        return None
    if isinstance(stream_info, Claim):
        return stream_info.stream.source.sd_hash
    node = stream_info
    for level in ('claim', 'value', 'stream', 'source'):
        # A missing key and an explicit None/empty value are both "not present".
        # The old chained .get(..., {}) raised AttributeError on a None value.
        node = node.get(level) or {}
    result = node.get('source')
    if not result:
        log.warning("Unable to get sd_hash")
    return result
def json_dumps_pretty(obj, **kwargs):
    """Serialize *obj* as human-readable JSON.

    Keys are sorted, nesting is indented by two spaces, and item separators
    carry no trailing whitespace. Extra keyword arguments go to ``json.dumps``.
    """
    return json.dumps(
        obj,
        indent=2,
        sort_keys=True,
        separators=(',', ': '),
        **kwargs
    )
2018-02-28 20:59:12 +01:00
2019-01-22 23:44:17 +01:00
def cancel_task(task: typing.Optional[asyncio.Task]):
    """Request cancellation of *task* unless it is missing (falsy) or already done."""
    if not task or task.done():
        return
    task.cancel()


def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
    """Cancel each still-running task in *tasks*; the list itself is not modified."""
    for pending in tasks:
        cancel_task(pending)


def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
    """Empty *tasks* in place, popping from the end and cancelling each popped entry."""
    while len(tasks) > 0:
        last = tasks.pop()
        cancel_task(last)
def async_timed_cache(duration: int):
    """Decorator factory that caches an async function's results for *duration* seconds.

    Results are keyed on the positional args plus keyword items, which must all
    be hashable. Each result is timestamped with the running loop's clock. A
    cached value is returned while it is younger than *duration*; otherwise the
    function is awaited again and the entry refreshed. Concurrent misses for one
    key are not de-duplicated.

    Expired entries are pruned each time a value is stored, so keys that are
    never requested again do not accumulate for the life of the process.
    """
    def wrapper(func):
        # key -> (value, loop time when the call started). Every store re-inserts
        # at the end, so the dict is (roughly) ordered oldest-first.
        cache: typing.Dict[typing.Tuple,
                           typing.Tuple[typing.Any, float]] = {}

        def _prune(time_now: float):
            # Drop expired entries from the front; stop at the first fresh one.
            while cache:
                oldest = next(iter(cache))
                if time_now - cache[oldest][1] < duration:
                    break
                del cache[oldest]

        @functools.wraps(func)
        async def _inner(*args, **kwargs) -> typing.Any:
            loop = asyncio.get_running_loop()
            time_now = loop.time()
            key = (args, tuple(kwargs.items()))
            cached = cache.get(key)
            if cached is not None and time_now - cached[1] < duration:
                return cached[0]
            to_cache = await func(*args, **kwargs)
            cache.pop(key, None)  # re-insert at the end to keep oldest-first order
            cache[key] = to_cache, time_now
            _prune(time_now)
            return to_cache
        return _inner
    return wrapper
2019-03-31 03:05:46 +02:00
def cache_concurrent(async_fn):
    """
    When the decorated function has concurrent calls made to it with the same arguments, only run it once

    Overlapping callers with equal ``(args, kwargs)`` share one task and all
    receive its result or exception. Nothing is kept after the task finishes,
    so the next call after completion starts a fresh run.

    Each caller awaits the shared task through ``asyncio.shield``. Cancelling
    one caller therefore does not cancel the work for the others; the task is
    only cancelled when its last remaining caller is cancelled.
    """
    # key -> [shared task, number of callers currently awaiting it]
    in_flight: typing.Dict[typing.Tuple, typing.List] = {}

    @functools.wraps(async_fn)
    async def wrapper(*args, **kwargs):
        key = (args, tuple(kwargs.items()))
        entry = in_flight.get(key)
        if entry is None:
            entry = in_flight[key] = [asyncio.create_task(async_fn(*args, **kwargs)), 0]
        task = entry[0]
        entry[1] += 1
        try:
            return await asyncio.shield(task)
        except asyncio.CancelledError:
            if entry[1] == 1 and not task.done():
                task.cancel()  # the last waiter gave up: stop the shared work
            raise
        finally:
            entry[1] -= 1
            # Only remove our own entry. A newer task for the same key may have
            # replaced it, and that one must stay visible to new callers.
            if (entry[1] == 0 or task.done()) and in_flight.get(key) is entry:
                del in_flight[key]
    return wrapper
@async_timed_cache(300)
async def resolve_host(url: str, port: int, proto: str) -> str:
    """Resolve *url* to an address string for a UDP or TCP endpoint.

    Results are cached for 300 seconds by ``async_timed_cache``. ``localhost``
    (any case) maps to ``127.0.0.1`` and IP literals are returned unchanged.
    Any other name goes through the running loop's ``getaddrinfo``, restricted
    to ``AF_INET``, and the first result's address is returned.

    :raises ValueError: if *proto* is neither ``'udp'`` nor ``'tcp'``. It is a
        subclass of ``Exception``, so existing ``except Exception`` handlers still apply.
    """
    if proto not in ('udp', 'tcp'):
        raise ValueError("invalid protocol")
    if url.lower() == 'localhost':
        return '127.0.0.1'
    try:
        ipaddress.ip_address(url)
    except ValueError:
        pass  # not an IP literal: fall through to a DNS lookup
    else:
        return url
    is_tcp = proto == 'tcp'
    loop = asyncio.get_running_loop()
    infos = await loop.getaddrinfo(
        url, port,
        proto=socket.IPPROTO_TCP if is_tcp else socket.IPPROTO_UDP,
        type=socket.SOCK_STREAM if is_tcp else socket.SOCK_DGRAM,
        family=socket.AF_INET
    )
    # each entry is (family, type, proto, canonname, sockaddr); sockaddr[0] is the host
    return infos[0][4][0]
2019-02-28 18:40:11 +01:00
class LRUCacheWithMetrics:
    """Bounded least-recently-used mapping with optional Prometheus hit/miss counters.

    When *metric_name* is given, ``get`` increments the counters
    ``{metric_name}_cache_hit_count`` / ``{metric_name}_cache_miss_count`` under
    *namespace*. Setting up the metrics is best-effort: a ``ValueError`` from
    ``Counter`` disables tracking instead of failing. Reads through ``get`` or
    ``[]`` mark an entry as most recently used; a miss yields ``None`` (or the
    given default) rather than raising.
    """
    __slots__ = [
        'capacity',
        'cache',
        '_track_metrics',
        'hits',
        'misses'
    ]

    def __init__(self, capacity: int, metric_name: typing.Optional[str] = None, namespace: str = "daemon_cache"):
        self.capacity = capacity
        self.cache = collections.OrderedDict()
        # Start with metrics off; enable them only once both counters exist.
        self._track_metrics = False
        self.hits = self.misses = None
        if metric_name is None:
            return
        try:
            hits = Counter(
                f"{metric_name}_cache_hit_count", "Number of cache hits", namespace=namespace
            )
            misses = Counter(
                f"{metric_name}_cache_miss_count", "Number of cache misses", namespace=namespace
            )
        except ValueError as err:
            log.debug("failed to set up prometheus %s_cache_miss_count metric: %s", metric_name, err)
            return
        self.hits, self.misses = hits, misses
        self._track_metrics = True

    def get(self, key, default=None):
        """Return *key*'s value and mark it most recently used, or *default* on a miss."""
        try:
            self.cache.move_to_end(key)
        except KeyError:
            if self._track_metrics:
                self.misses.inc()
            return default
        if self._track_metrics:
            self.hits.inc()
        return self.cache[key]

    def set(self, key, value):
        """Store *value* as the most recently used entry, evicting the oldest one when full."""
        if key in self.cache:
            self.cache.move_to_end(key)
        elif self.cache and len(self.cache) >= self.capacity:
            # The non-empty guard stops popitem() raising KeyError when capacity <= 0.
            self.cache.popitem(last=False)
        self.cache[key] = value

    def items(self):
        """View of ``(key, value)`` pairs, least recently used first."""
        return self.cache.items()

    def clear(self):
        self.cache.clear()

    def pop(self, key, *default):
        """Remove and return *key*'s value.

        Raises ``KeyError`` if *key* is absent, unless a default is passed
        (same semantics as ``dict.pop``).
        """
        return self.cache.pop(key, *default)

    def __setitem__(self, key, value):
        return self.set(key, value)

    def __getitem__(self, item):
        return self.get(item)

    def __contains__(self, item) -> bool:
        return item in self.cache

    def __len__(self):
        return len(self.cache)

    def __delitem__(self, key):
        self.cache.pop(key)

    def __del__(self):
        self.clear()
class LRUCache:
    """Bounded mapping that evicts the least recently used entry when full.

    Reads through ``get`` or ``[]`` mark an entry as most recently used. A miss
    yields ``None`` (or the given default) rather than raising.
    """
    __slots__ = [
        'capacity',
        'cache'
    ]

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache = collections.OrderedDict()

    def get(self, key, default=None):
        """Return *key*'s value and mark it most recently used, or *default* on a miss."""
        try:
            self.cache.move_to_end(key)
        except KeyError:
            return default
        return self.cache[key]

    def set(self, key, value):
        """Store *value* as the most recently used entry, evicting the oldest one when full."""
        if key in self.cache:
            self.cache.move_to_end(key)
        elif self.cache and len(self.cache) >= self.capacity:
            # The non-empty guard stops popitem() raising KeyError when capacity <= 0.
            self.cache.popitem(last=False)
        self.cache[key] = value

    def items(self):
        """View of ``(key, value)`` pairs, least recently used first."""
        return self.cache.items()

    def clear(self):
        self.cache.clear()

    def pop(self, key, default=None):
        """Remove and return *key*'s value, or *default* if it is absent."""
        return self.cache.pop(key, default)

    def __setitem__(self, key, value):
        return self.set(key, value)

    def __getitem__(self, item):
        return self.get(item)

    def __contains__(self, item) -> bool:
        return item in self.cache

    def __len__(self):
        return len(self.cache)

    def __delitem__(self, key):
        self.cache.pop(key)

    def __del__(self):
        self.clear()
def lru_cache_concurrent(cache_size: typing.Optional[int] = None,
                         override_lru_cache: typing.Optional['LRUCacheWithMetrics'] = None):
    """Decorator factory combining an LRU result cache with concurrent-call de-duplication.

    Finished results are stored in *override_lru_cache* if one is given,
    otherwise in a new ``LRUCacheWithMetrics(cache_size)``. Until a result is
    cached, overlapping calls with equal ``(args, kwargs)`` share a single task.
    Each caller awaits it through ``asyncio.shield``, so cancelling one caller
    does not cancel the work for the others. The task is cancelled only when
    its last remaining caller is cancelled.

    The LRU cache and in-flight table are created once per factory call and are
    shared by every function decorated with the returned decorator.

    :raises ValueError: if neither a truthy *cache_size* nor an override cache is given.
    """
    if not cache_size and override_lru_cache is None:
        raise ValueError("invalid cache size")
    # key -> [shared task, number of callers currently awaiting it]
    concurrent_cache: typing.Dict[typing.Tuple, typing.List] = {}
    lru_cache = override_lru_cache if override_lru_cache is not None else LRUCacheWithMetrics(cache_size)
    missing = object()

    def wrapper(async_fn):
        @functools.wraps(async_fn)
        async def _inner(*args, **kwargs):
            key = (args, tuple(kwargs.items()))
            # A single get() with a sentinel lets a metrics-enabled cache count
            # misses too. An `in` pre-check would only ever record hits.
            cached = lru_cache.get(key, missing)
            if cached is not missing:
                return cached
            entry = concurrent_cache.get(key)
            if entry is None:
                entry = concurrent_cache[key] = [asyncio.create_task(async_fn(*args, **kwargs)), 0]
            task = entry[0]
            entry[1] += 1
            try:
                result = await asyncio.shield(task)
                lru_cache.set(key, result)
                return result
            except asyncio.CancelledError:
                if entry[1] == 1 and not task.done():
                    task.cancel()  # the last waiter gave up: stop the shared work
                raise
            finally:
                entry[1] -= 1
                if (entry[1] == 0 or task.done()) and concurrent_cache.get(key) is entry:
                    del concurrent_cache[key]
        return _inner
    return wrapper
2019-02-28 18:42:23 +01:00
def get_ssl_context() -> ssl.SSLContext:
    """Build a client-side TLS context that verifies servers against certifi's CA bundle.

    ``Purpose.SERVER_AUTH`` is the purpose for *client* connections and turns on
    certificate and hostname verification. (``CLIENT_AUTH`` builds a
    server-side context with verification off.) ``certifi.where()`` returns the
    path of a single PEM bundle file, so it must be passed as ``cafile``; the
    ``capath`` parameter expects a directory of hashed certificates.
    """
    return ssl.create_default_context(
        purpose=ssl.Purpose.SERVER_AUTH, cafile=certifi.where()
    )
@contextlib.asynccontextmanager
async def aiohttp_request(method, url, **kwargs) -> typing.AsyncContextManager[aiohttp.ClientResponse]:
    """Perform one HTTP request inside a throwaway ``aiohttp.ClientSession``.

    Yields the ``ClientResponse``. The response and the session are both closed
    when the caller's ``async with`` block exits. Extra keyword arguments are
    passed to ``ClientSession.request``.
    """
    async with aiohttp.ClientSession() as session, \
            session.request(method, url, **kwargs) as response:
        yield response
2019-03-11 02:55:33 +01:00
# The ipaddress module does not flag these subnets as reserved, so they are
# checked explicitly: carrier-grade NAT shared space and the 6to4 relay prefix.
CARRIER_GRADE_NAT_SUBNET = ipaddress.ip_network('100.64.0.0/10')
IPV4_TO_6_RELAY_SUBNET = ipaddress.ip_network('192.88.99.0/24')


def is_valid_public_ipv4(address, allow_localhost: bool = False, allow_lan: bool = False):
    """Return True if *address* is a publicly routable IPv4 address.

    :param address: anything ``ipaddress.ip_address`` accepts; unparsable input returns False.
    :param allow_localhost: also accept IPv4 loopback addresses.
    :param allow_lan: also accept IPv4 addresses that ``ipaddress`` reports as private.

    IPv6 addresses are always rejected, even when an ``allow_*`` flag is set.
    Previously those flags let IPv6 loopback or private addresses through.
    """
    try:
        parsed_ip = ipaddress.ip_address(address)
    except ValueError:  # includes ipaddress.AddressValueError
        return False
    if parsed_ip.version != 4:
        return False
    if allow_localhost and parsed_ip.is_loopback:
        return True
    if allow_lan and parsed_ip.is_private:
        return True
    if any((parsed_ip.is_unspecified, parsed_ip.is_link_local, parsed_ip.is_loopback,
            parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private)):
        return False
    return parsed_ip not in CARRIER_GRADE_NAT_SUBNET and parsed_ip not in IPV4_TO_6_RELAY_SUBNET
async def fallback_get_external_ip():  # used if spv servers can't be used for ip detection
    """Ask ``https://api.lbry.com/ip`` for this machine's external IP.

    :return: ``(ip, None)`` on success, otherwise ``(None, None)``. This covers
        an unsuccessful API response as well, which previously fell through and
        returned a bare ``None`` instead of a tuple.
    """
    try:
        async with aiohttp_request("get", "https://api.lbry.com/ip") as resp:
            response = await resp.json()
            if response['success']:
                return response['data']['ip'], None
    except Exception:
        # Best-effort lookup: record why it failed instead of swallowing silently.
        log.debug("failed to get external ip from api.lbry.com", exc_info=True)
    return None, None
async def _get_external_ip(default_servers) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
    """Discover the external IPv4 address by pinging SPV servers over UDP.

    Used if UPnP is disabled or not working. Every ``(hostname, port)`` in
    *default_servers* is resolved concurrently. The resolved endpoints are then
    pinged one at a time in random order, with up to one second to wait for each
    reply. The first reply reporting a valid public IPv4 address wins.

    :return: ``(external_ip, hostname_of_pinged_server)``, or ``(None, None)``.
    """
    from lbry.wallet.server.udp import SPVStatusClientProtocol  # pylint: disable=C0415

    # (resolved ip, port) -> every hostname in default_servers that resolved to it
    ip_to_hostnames = collections.defaultdict(list)

    async def resolve_spv(server, port):
        try:
            server_addr = await resolve_host(server, port, 'udp')
            ip_to_hostnames[(server_addr, port)].append(server)
        except Exception:
            log.exception("error looking up dns for spv servers")

    # accumulate the dns results
    await asyncio.gather(*(resolve_spv(server, port) for (server, port) in default_servers))

    loop = asyncio.get_running_loop()
    pong_responses = asyncio.Queue()
    connection = SPVStatusClientProtocol(pong_responses)
    try:
        # could raise OSError if it can't bind
        await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
        randomized_servers = list(ip_to_hostnames.keys())
        random.shuffle(randomized_servers)
        for server in randomized_servers:
            connection.ping(server)
            try:
                # NOTE(review): a late pong from a previously timed-out server
                # would be attributed to this one; confirm whether that matters.
                _, pong = await asyncio.wait_for(pong_responses.get(), 1)
            except asyncio.TimeoutError:
                continue
            if is_valid_public_ipv4(pong.ip_address):
                return pong.ip_address, ip_to_hostnames[server][0]
        return None, None
    finally:
        connection.close()
async def get_external_ip(default_servers) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
    """Return ``(external_ip, spv_hostname)``, preferring detection via SPV servers.

    If no SPV server answered (no hostname returned), the result of
    ``fallback_get_external_ip`` is returned instead.
    """
    external_ip, responding_server = await _get_external_ip(default_servers)
    if responding_server:
        return external_ip, responding_server
    return await fallback_get_external_ip()
def is_running_from_bundle():
    """Tell whether the process runs from a PyInstaller bundle.

    Checks for ``sys.frozen`` and ``sys._MEIPASS``, see
    https://pyinstaller.readthedocs.io/en/stable/runtime-information.html
    Returns ``False`` when ``sys.frozen`` is absent.
    """
    frozen = getattr(sys, 'frozen', False)
    return frozen and hasattr(sys, '_MEIPASS')
2020-05-03 03:23:17 +02:00
class LockWithMetrics(asyncio.Lock):
    """``asyncio.Lock`` that reports wait time and hold time to two metric objects.

    *acquire_metric* and *held_time_metric* only need an ``observe(seconds)``
    method. Presumably these are prometheus histograms/summaries; confirm at call sites.
    """

    def __init__(self, acquire_metric, held_time_metric, loop=None):
        # asyncio.Lock removed its ``loop`` argument in Python 3.10 (passing it,
        # even as None, raises TypeError), so forward it only when supplied.
        if loop is None:
            super().__init__()
        else:
            super().__init__(loop=loop)
        self._acquire_metric = acquire_metric
        self._lock_held_time_metric = held_time_metric
        self._lock_acquired_time = None

    async def acquire(self):
        """Acquire the lock, recording how long the caller waited, even if the wait is cancelled."""
        start = time.perf_counter()
        try:
            acquired = await super().acquire()
        finally:
            waited_until = time.perf_counter()
            self._acquire_metric.observe(waited_until - start)
        # Only a successful acquisition starts the hold timer. A cancelled waiter
        # must not overwrite the timestamp of the task that holds the lock.
        self._lock_acquired_time = waited_until
        return acquired

    def release(self):
        """Release the lock and record how long it was held."""
        super().release()  # raises RuntimeError if the lock is not held
        self._lock_held_time_metric.observe(time.perf_counter() - self._lock_acquired_time)
def get_colliding_prefix_bits(first_value: bytes, second_value: bytes, size: int):
    """
    Count how many leading bits <first_value> and <second_value> share within their first <size> bits.

    The first <size> bits of each value are read as big-endian integers and XORed.
    The count is <size> minus the bit length of that XOR, i.e. the number of
    equal bits before the first differing one.
    :param first_value: first value to compare; must be at least size // 8 bytes long.
    :param second_value: second value to compare; must be at least size // 8 bytes long.
    :param size: prefix size in bits; must be a multiple of 8.
    :return: amount of prefix colliding bits.
    """
    assert size % 8 == 0, "size has to be a multiple of 8"
    prefix_len = size // 8
    assert len(first_value) >= prefix_len, "first_value has to be larger than size parameter"
    first_prefix = int.from_bytes(first_value[:prefix_len], "big")
    assert len(second_value) >= prefix_len, "second_value has to be larger than size parameter"
    second_prefix = int.from_bytes(second_value[:prefix_len], "big")
    differing_bits = first_prefix ^ second_prefix
    return size - differing_bits.bit_length()