482 lines
14 KiB
Python
482 lines
14 KiB
Python
import base64
|
|
import codecs
|
|
import datetime
|
|
import random
|
|
import socket
|
|
import time
|
|
import string
|
|
import sys
|
|
import json
|
|
import typing
|
|
import asyncio
|
|
import ssl
|
|
import logging
|
|
import ipaddress
|
|
import contextlib
|
|
import functools
|
|
import collections
|
|
import hashlib
|
|
import pkg_resources
|
|
|
|
import certifi
|
|
import aiohttp
|
|
from prometheus_client import Counter
|
|
from prometheus_client.registry import REGISTRY
|
|
from lbry.schema.claim import Claim
|
|
|
|
|
|
# module-wide logger, named after this module
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
# defining these time functions here allows for easier overriding in testing
def now():
    """Return the current local date and time as a naive ``datetime``."""
    current = datetime.datetime.now()
    return current
|
|
|
|
|
|
def utcnow():
    """Return the current UTC time as a naive ``datetime`` (``tzinfo`` is None).

    Produces the same value as ``datetime.datetime.utcnow()``, which is
    deprecated since Python 3.12. It is built from an aware
    ``now(timezone.utc)`` so no DeprecationWarning is emitted. The result is
    kept naive on purpose: ``isonow`` appends a literal ``'Z'`` to its
    ``isoformat()``, which would be wrong if an offset were already present.
    """
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
|
|
|
|
|
def isonow():
    """Return the current UTC time as an ISO-8601 string with a trailing ``Z``."""
    return f"{utcnow().isoformat()}Z"
|
|
|
|
|
|
def today():
    """Return ``datetime.datetime.today()`` (current local date and time, naive)."""
    moment = datetime.datetime.today()
    return moment
|
|
|
|
|
|
def timedelta(**parts):
    """Build a ``datetime.timedelta`` from keyword components (days=, hours=, ...)."""
    span = datetime.timedelta(**parts)
    return span
|
|
|
|
|
|
def datetime_obj(*fields, **options):
    """Construct a ``datetime.datetime``; all arguments are forwarded unchanged."""
    built = datetime.datetime(*fields, **options)
    return built
|
|
|
|
|
|
def get_lbry_hash_obj():
    """Return a fresh SHA-384 hash object (the hash used for LBRY ids)."""
    return hashlib.new("sha384")
|
|
|
|
|
|
def generate_id(num=None):
    """Return a 48-byte SHA-384 digest usable as an id.

    When *num* is given the id is derived deterministically from ``str(num)``;
    otherwise it is derived from 512 bits taken from the ``random`` module.
    """
    seed = num if num is not None else random.getrandbits(512)
    hasher = get_lbry_hash_obj()
    hasher.update(str(seed).encode())
    return hasher.digest()
|
|
|
|
|
|
def version_is_greater_than(version_a, version_b):
    """Return True when *version_a* is strictly newer than *version_b*.

    Both strings are parsed with ``pkg_resources.parse_version``.
    """
    newer = pkg_resources.parse_version(version_a)
    older = pkg_resources.parse_version(version_b)
    return newer > older
|
|
|
|
|
|
def rot13(some_str):
    """Apply the ROT13 letter substitution to *some_str*."""
    encoder = codecs.getencoder('rot_13')
    rotated, _consumed = encoder(some_str)
    return rotated
|
|
|
|
|
|
def deobfuscate(obfustacated):
    """Reverse ``obfuscate``: undo ROT13, base64-decode, and return UTF-8 text."""
    unrotated = rot13(obfustacated)
    raw = base64.b64decode(unrotated)
    return raw.decode()
|
|
|
|
|
|
def obfuscate(plain):
    """Base64-encode the bytes *plain* and ROT13 the resulting text."""
    encoded = base64.b64encode(plain).decode()
    return rot13(encoded)
|
|
|
|
|
|
def check_connection(server="lbry.com", port=80, timeout=5) -> bool:
    """Report whether a TCP connection to ``server:port`` can be opened.

    The hostname is resolved with ``socket.gethostbyname`` first. If a
    resolution error (``socket.gaierror`` / ``socket.herror``) is raised, DNS
    is assumed to be the problem, and reachability of 8.8.8.8:53 is tested
    instead. Any other ``OSError`` yields False. *timeout* (seconds) is passed
    to each connection attempt.
    """
    def _open_and_close(host, host_port):
        # only reachability matters, so the socket is closed immediately
        socket.create_connection((host, host_port), timeout).close()
        return True

    log.debug('Checking connection to %s:%s', server, port)
    try:
        server = socket.gethostbyname(server)
        return _open_and_close(server, port)
    except (socket.gaierror, socket.herror):
        log.debug("Failed to connect to %s:%s. Unable to resolve domain. Trying to bypass DNS",
                  server, port)
    except OSError:
        return False

    # DNS bypass: probe a well-known public resolver by its literal address
    try:
        return _open_and_close("8.8.8.8", 53)
    except OSError:
        return False
|
|
|
|
|
|
async def async_check_connection(server="lbry.com", port=80, timeout=1) -> bool:
    """Run the blocking ``check_connection`` in the default executor.

    The event loop is not blocked while the socket probe runs. Note that the
    default *timeout* here is 1 second, while ``check_connection`` defaults
    to 5.
    """
    # inside a coroutine the running loop is the correct one; get_event_loop()
    # is the legacy/deprecated way of fetching it
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, check_connection, server, port, timeout)
|
|
|
|
|
|
def random_string(length=10, chars=string.ascii_lowercase):
    """Return *length* characters drawn independently from *chars* using ``random.choice``.

    Not suitable for secrets (uses the non-cryptographic ``random`` module).
    """
    return "".join(random.choice(chars) for _ in range(length))
|
|
|
|
|
|
def short_hash(hash_str):
    """Return the first six characters of *hash_str* (for display/logging)."""
    prefix_length = 6
    return hash_str[:prefix_length]
|
|
|
|
|
|
def get_sd_hash(stream_info):
    """Extract the stream descriptor hash from a claim.

    Accepts a ``Claim`` (reads ``stream.source.sd_hash``) or a dict shaped like
    ``{'claim': {'value': {'stream': {'source': {'source': <sd_hash>}}}}}``.
    Falsy input returns None; a dict missing the path logs a warning and
    returns the (falsy) value found.
    """
    if not stream_info:
        return None
    if isinstance(stream_info, Claim):
        return stream_info.stream.source.sd_hash
    node = stream_info.get('claim', {})
    for field in ('value', 'stream', 'source'):
        node = node.get(field, {})
    result = node.get('source')
    if not result:
        log.warning("Unable to get sd_hash")
    return result
|
|
|
|
|
|
def json_dumps_pretty(obj, **kwargs):
    """Serialize *obj* to JSON with sorted keys and two-space indentation.

    Extra keyword arguments are forwarded to ``json.dumps``; passing one of the
    fixed options again raises TypeError.
    """
    fixed_options = {'sort_keys': True, 'indent': 2, 'separators': (',', ': ')}
    return json.dumps(obj, **fixed_options, **kwargs)
|
|
|
|
|
|
def cancel_task(task: typing.Optional[asyncio.Task]):
    """Request cancellation of *task* unless it is None/falsy or already done."""
    if not task or task.done():
        return
    task.cancel()
|
|
|
|
|
|
def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
    """Apply ``cancel_task`` to every entry of *tasks*; the list is not modified."""
    for pending in tasks:
        cancel_task(pending)
|
|
|
|
|
|
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
    """Empty *tasks* in place, passing each popped entry to ``cancel_task``.

    Entries are popped from the end, so they are handled in reverse order.
    """
    while tasks:
        victim = tasks.pop()
        cancel_task(victim)
|
|
|
|
|
|
def async_timed_cache(duration: int):
    """Decorator factory caching a coroutine function's results for *duration* seconds.

    Results are keyed on the call's positional arguments plus its keyword
    arguments sorted by name, so ``f(a=1, b=2)`` and ``f(b=2, a=1)`` share an
    entry. All arguments must be hashable. Age is measured with the running
    loop's clock from the moment the call started. Exceptions are not cached.
    Concurrent calls with the same key that both miss each invoke the wrapped
    function.

    Whenever a call misses, every expired entry is dropped first. This stops
    keys that are never requested again from accumulating forever.
    """
    def wrapper(func):
        cache: typing.Dict[typing.Tuple,
                           typing.Tuple[typing.Any, float]] = {}

        @functools.wraps(func)
        async def _inner(*args, **kwargs) -> typing.Any:
            time_now = asyncio.get_running_loop().time()
            # kwargs keys are unique strings, so sorting never compares values
            key = (args, tuple(sorted(kwargs.items())))
            cached = cache.get(key)
            if cached is not None and time_now - cached[1] < duration:
                return cached[0]
            # the list is built before deleting, so the dict isn't mutated while iterated
            for stale_key in [k for k, (_, stamp) in cache.items() if time_now - stamp >= duration]:
                del cache[stale_key]
            to_cache = await func(*args, **kwargs)
            cache[key] = to_cache, time_now
            return to_cache
        return _inner
    return wrapper
|
|
|
|
|
|
def cache_concurrent(async_fn):
    """
    Deduplicate concurrent calls of a coroutine function that use the same arguments.

    While a call is in flight, further calls with an equal key await the same
    task instead of starting another. The key is the positional args plus the
    keyword args in call order, so all must be hashable. The entry is removed
    as soon as any waiter finishes, so nothing is cached after completion.

    NOTE: waiters await the shared task directly, so cancelling one waiter
    also cancels the task the others are waiting on.
    """
    in_flight: typing.Dict = {}

    @functools.wraps(async_fn)
    async def wrapper(*args, **kwargs):
        call_key = (args, tuple((name, kwargs[name]) for name in kwargs))
        task = in_flight.get(call_key)
        if not task:
            task = asyncio.create_task(async_fn(*args, **kwargs))
        in_flight[call_key] = task
        try:
            return await task
        finally:
            in_flight.pop(call_key, None)

    return wrapper
|
|
|
|
|
|
@async_timed_cache(300)
async def resolve_host(url: str, port: int, proto: str) -> str:
    """Resolve *url* to a single address string; results are cached for 300 s.

    ``'localhost'`` (any case) maps to ``'127.0.0.1'``. Any literal IP address,
    IPv4 or IPv6, is returned unchanged. Other names go through the event
    loop's ``getaddrinfo``, restricted to ``AF_INET`` and to the socket type
    and protocol matching *proto*; the first result's address is returned.

    Raises ValueError (an ``Exception`` subclass, so broad handlers still catch
    it) when *proto* is not ``'udp'`` or ``'tcp'``. Errors from
    ``getaddrinfo`` propagate to the caller.
    """
    if proto not in ('udp', 'tcp'):
        raise ValueError("invalid protocol")
    if url.lower() == 'localhost':
        return '127.0.0.1'
    try:
        ipaddress.ip_address(url)
    except ValueError:
        pass  # not an IP literal: fall through to a DNS lookup
    else:
        return url
    is_tcp = proto == 'tcp'
    loop = asyncio.get_running_loop()
    addr_info = await loop.getaddrinfo(
        url, port,
        proto=socket.IPPROTO_TCP if is_tcp else socket.IPPROTO_UDP,
        type=socket.SOCK_STREAM if is_tcp else socket.SOCK_DGRAM,
        family=socket.AF_INET
    )
    # each entry is (family, type, proto, canonname, sockaddr); sockaddr[0] is the host
    return addr_info[0][4][0]
|
|
|
|
|
|
class LRUCacheWithMetrics:
    """Bounded least-recently-used cache with optional Prometheus hit/miss counters.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. Reading (``get`` / ``[]``) or writing (``set`` / ``[]=``) a key marks
    it most recent. Inserting a new key while ``capacity`` entries are held
    evicts the least recently used one. A ``capacity`` of zero or less stores
    nothing. ``cache[missing]`` returns None rather than raising KeyError.

    When ``metric_name`` is given, two ``prometheus_client.Counter`` objects
    are created in ``namespace``: ``f"{metric_name}_cache_hit_count"`` and
    ``f"{metric_name}_cache_miss_count"``. ``get`` increments them. If either
    counter cannot be created (ValueError), metrics are disabled and any
    counter that was already registered is unregistered again.
    """
    __slots__ = [
        'capacity',
        'cache',
        '_track_metrics',
        'hits',
        'misses'
    ]

    def __init__(self, capacity: int, metric_name: typing.Optional[str] = None, namespace: str = "daemon_cache"):
        self.capacity = capacity
        self.cache = collections.OrderedDict()
        # start in the "no metrics" state so every slot is set even if Counter() raises
        self._track_metrics = False
        self.hits = self.misses = None
        if metric_name is None:
            return
        try:
            self.hits = Counter(
                f"{metric_name}_cache_hit_count", "Number of cache hits", namespace=namespace
            )
            self.misses = Counter(
                f"{metric_name}_cache_miss_count", "Number of cache misses", namespace=namespace
            )
        except ValueError as err:
            # presumably a duplicate metric name in the registry; if the hit counter
            # got registered before the miss counter failed, undo it so it is not leaked
            log.debug("failed to set up prometheus %s cache metrics: %s", metric_name, err)
            if self.hits is not None:
                self._unregister(self.hits)
            self.hits = self.misses = None
        else:
            self._track_metrics = True

    @staticmethod
    def _unregister(counter):
        """Best-effort removal of *counter* from the default Prometheus registry."""
        try:
            REGISTRY.unregister(counter)
        except (AttributeError, KeyError):
            # KeyError: the collector is not (or no longer) registered
            pass

    def get(self, key, default=None):
        """Return the value for *key*, marking it most recently used, or *default* if absent."""
        try:
            value = self.cache[key]
        except KeyError:
            if self._track_metrics:
                self.misses.inc()
            return default
        self.cache.move_to_end(key)
        if self._track_metrics:
            self.hits.inc()
        return value

    def set(self, key, value):
        """Store *value* under *key* as the most recently used entry, evicting if full."""
        if key in self.cache:
            # refreshing an existing key never requires an eviction
            self.cache.move_to_end(key)
        elif self.capacity <= 0:
            # nothing can be stored; previously popitem() on an empty dict raised KeyError
            return
        else:
            while len(self.cache) >= self.capacity:
                self.cache.popitem(last=False)
        self.cache[key] = value

    def clear(self):
        """Remove every entry."""
        self.cache.clear()

    def pop(self, key):
        """Remove *key* and return its value; raises KeyError if absent."""
        return self.cache.pop(key)

    def __setitem__(self, key, value):
        return self.set(key, value)

    def __getitem__(self, item):
        return self.get(item)

    def __contains__(self, item) -> bool:
        # membership test does not touch recency or metrics
        return item in self.cache

    def __len__(self):
        return len(self.cache)

    def __delitem__(self, key):
        self.cache.pop(key)

    def __del__(self):
        self.clear()
        if self._track_metrics:  # needed for tests
            # unregister each counter independently so one failure doesn't skip the other
            self._unregister(self.hits)
            self._unregister(self.misses)
|
|
|
|
|
|
class LRUCache:
    """Bounded least-recently-used cache backed by an ``OrderedDict``.

    Entries are ordered from least to most recently used. Reading
    (``get`` / ``[]``) or writing (``set`` / ``[]=``) a key marks it most
    recent. Inserting a new key while ``capacity`` entries are held evicts the
    least recently used one. A ``capacity`` of zero or less stores nothing.
    Unlike ``dict``, ``cache[missing]`` returns None instead of raising
    KeyError.
    """
    __slots__ = [
        'capacity',
        'cache'
    ]

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache = collections.OrderedDict()

    def get(self, key, default=None):
        """Return the value for *key*, marking it most recently used, or *default* if absent."""
        try:
            value = self.cache[key]
        except KeyError:
            return default
        self.cache.move_to_end(key)
        return value

    def set(self, key, value):
        """Store *value* under *key* as the most recently used entry, evicting if full."""
        if key in self.cache:
            # refreshing an existing key never requires an eviction
            self.cache.move_to_end(key)
        elif self.capacity <= 0:
            # nothing can be stored; previously popitem() on an empty dict raised KeyError
            return
        else:
            while len(self.cache) >= self.capacity:
                self.cache.popitem(last=False)
        self.cache[key] = value

    def clear(self):
        """Remove every entry."""
        self.cache.clear()

    def pop(self, key):
        """Remove *key* and return its value; raises KeyError if absent."""
        return self.cache.pop(key)

    def __setitem__(self, key, value):
        return self.set(key, value)

    def __getitem__(self, item):
        return self.get(item)

    def __contains__(self, item) -> bool:
        # membership test does not touch recency
        return item in self.cache

    def __len__(self):
        return len(self.cache)

    def __delitem__(self, key):
        self.cache.pop(key)

    def __del__(self):
        self.clear()
|
|
|
|
|
|
def lru_cache_concurrent(cache_size: typing.Optional[int] = None,
                         override_lru_cache: typing.Optional[LRUCacheWithMetrics] = None):
    """Decorator factory: LRU-cache coroutine results and deduplicate concurrent calls.

    A completed result is stored in *override_lru_cache* if given, otherwise
    in a new ``LRUCacheWithMetrics(cache_size)``. While a call is in flight,
    other calls with an equal key (positional args plus keyword args in call
    order) await the same task. Raises ValueError when neither a truthy
    *cache_size* nor an override cache is provided.
    """
    if not cache_size and override_lru_cache is None:
        raise ValueError("invalid cache size")
    in_flight = {}
    if override_lru_cache is None:
        results = LRUCacheWithMetrics(cache_size)
    else:
        results = override_lru_cache

    def wrapper(async_fn):

        @functools.wraps(async_fn)
        async def _inner(*args, **kwargs):
            call_key = (args, tuple((name, kwargs[name]) for name in kwargs))
            if call_key in results:
                return results.get(call_key)
            pending = in_flight.get(call_key) or asyncio.create_task(async_fn(*args, **kwargs))
            in_flight[call_key] = pending
            try:
                outcome = await pending
                results.set(call_key, outcome)
                return outcome
            finally:
                in_flight.pop(call_key, None)
        return _inner

    return wrapper
|
|
|
|
|
|
def get_ssl_context() -> ssl.SSLContext:
    """Return a client-side TLS context that verifies servers against certifi's CA bundle.

    ``Purpose.SERVER_AUTH`` is the purpose for contexts that connect *to*
    servers. It enables hostname checking and ``CERT_REQUIRED``.
    ``Purpose.CLIENT_AUTH`` instead builds a server-side context, which does
    not verify the peer and, on Python 3.10+, cannot be used for client
    sockets at all.

    ``certifi.where()`` is the path of a single CA bundle *file*, so it is
    passed as ``cafile``; ``capath`` expects a directory of hashed
    certificates.
    """
    return ssl.create_default_context(
        purpose=ssl.Purpose.SERVER_AUTH, cafile=certifi.where()
    )
|
|
|
|
|
|
@contextlib.asynccontextmanager
async def aiohttp_request(method, url, **kwargs) -> typing.AsyncContextManager[aiohttp.ClientResponse]:
    """Open a one-off ``aiohttp.ClientSession``, send a single request and yield the response.

    Extra keyword arguments go to ``session.request``. Both the response and
    the session are closed when the caller's ``async with`` block exits.
    """
    async with aiohttp.ClientSession() as session, session.request(method, url, **kwargs) as response:
        yield response
|
|
|
|
|
|
# Address ranges that ipaddress does not report as reserved/private, but which
# are still not publicly routable unicast space:
# shared address space for carrier-grade NAT (RFC 6598)
CARRIER_GRADE_NAT_SUBNET = ipaddress.IPv4Network('100.64.0.0/10')
# 6to4 relay anycast prefix
IPV4_TO_6_RELAY_SUBNET = ipaddress.IPv4Network('192.88.99.0/24')
|
|
|
|
|
|
def is_valid_public_ipv4(address, allow_localhost: bool = False):
    """Return True if *address* is a publicly routable IPv4 address.

    *address* may be anything ``ipaddress.ip_address`` accepts (str, int,
    packed bytes, address objects); unparseable input returns False. IPv6
    addresses are always rejected. Also rejected are unspecified, link-local,
    loopback, multicast, reserved and private addresses, plus the
    carrier-grade NAT and 6to4 relay ranges (``CARRIER_GRADE_NAT_SUBNET``,
    ``IPV4_TO_6_RELAY_SUBNET``). With ``allow_localhost=True``, IPv4 loopback
    addresses are accepted.
    """
    try:
        parsed_ip = ipaddress.ip_address(address)
    except ValueError:  # AddressValueError is a subclass of ValueError
        return False
    # check the version first so allow_localhost cannot accept IPv6 ::1
    if parsed_ip.version != 4:
        return False
    if parsed_ip.is_loopback and allow_localhost:
        return True
    if any((parsed_ip.is_unspecified, parsed_ip.is_link_local, parsed_ip.is_loopback,
            parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private)):
        return False
    # test membership of the parsed address instead of re-parsing f"{address}/32",
    # which failed for non-string inputs such as ints
    return parsed_ip not in CARRIER_GRADE_NAT_SUBNET and parsed_ip not in IPV4_TO_6_RELAY_SUBNET
|
|
|
|
|
|
async def fallback_get_external_ip():  # used if spv servers can't be used for ip detection
    """Ask https://api.lbry.com/ip for this host's external IP address.

    Always returns a 2-tuple, so it has the same shape as
    ``_get_external_ip``'s ``(ip, spv_hostname)``: ``(ip, None)`` on success,
    and ``(None, None)`` when the request fails, the response is malformed,
    or its ``success`` flag is false.
    """
    try:
        async with aiohttp_request("get", "https://api.lbry.com/ip") as resp:
            response = await resp.json()
            if response['success']:
                return response['data']['ip'], None
    except Exception:
        # best-effort lookup: any network/decoding/shape error just means "unknown"
        log.debug("failed to get external ip from api.lbry.com", exc_info=True)
    return None, None
|
|
|
|
|
|
async def _get_external_ip(default_servers) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
    """Detect this host's public IPv4 address by pinging SPV servers over UDP.

    Used if UPnP is disabled or non-functioning. Every ``(hostname, port)`` in
    *default_servers* is resolved with ``resolve_host(..., 'udp')``;
    resolution failures are logged and skipped. The distinct resolved
    addresses are then pinged one at a time in random order, waiting up to
    1 second each for a pong. The first pong whose ``ip_address`` passes
    ``is_valid_public_ipv4`` wins.

    Returns ``(external_ip, spv_hostname)``, or ``(None, None)`` if no server
    produced a valid public address. OSError from binding the local UDP
    socket is not caught.
    """
    from lbry.wallet.server.udp import SPVStatusClientProtocol  # pylint: disable=C0415

    ip_to_hostnames = collections.defaultdict(list)

    async def resolve_spv(server, port):
        try:
            server_addr = await resolve_host(server, port, 'udp')
            ip_to_hostnames[(server_addr, port)].append(server)
        except Exception:
            log.exception("error looking up dns for spv servers")

    # accumulate the dns results
    await asyncio.gather(*(resolve_spv(server, port) for (server, port) in default_servers))

    loop = asyncio.get_running_loop()
    pong_responses = asyncio.Queue()
    connection = SPVStatusClientProtocol(pong_responses)
    try:
        # could raise OSError if it cant bind
        await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
        randomized_servers = list(ip_to_hostnames.keys())
        random.shuffle(randomized_servers)
        for server in randomized_servers:
            connection.ping(server)
            try:
                # NOTE(review): the first element of the queued item is discarded, so a late
                # pong from a server that previously timed out would be credited to `server`;
                # confirm the queue item format before matching on it
                _, pong = await asyncio.wait_for(pong_responses.get(), 1)
                if is_valid_public_ipv4(pong.ip_address):
                    return pong.ip_address, ip_to_hostnames[server][0]
            except asyncio.TimeoutError:
                pass
        return None, None
    finally:
        connection.close()
|
|
|
|
|
|
async def get_external_ip(default_servers) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
    """Return ``(external_ip, spv_hostname)``, preferring detection via SPV servers.

    If SPV detection yields no hostname, the result of
    ``fallback_get_external_ip()`` is returned instead.
    """
    spv_result = await _get_external_ip(default_servers)
    if spv_result[1]:
        return spv_result
    return await fallback_get_external_ip()
|
|
|
|
|
|
def is_running_from_bundle():
    """Tell whether the process runs from a PyInstaller bundle.

    PyInstaller sets ``sys.frozen`` and ``sys._MEIPASS``; see
    https://pyinstaller.readthedocs.io/en/stable/runtime-information.html
    If ``sys.frozen`` is missing or falsy, that value (False by default) is
    returned as-is.
    """
    frozen = getattr(sys, 'frozen', False)
    if not frozen:
        return frozen
    return hasattr(sys, '_MEIPASS')
|
|
|
|
|
|
class LockWithMetrics(asyncio.Lock):
    """``asyncio.Lock`` that reports how long callers wait for it and how long it is held.

    *acquire_metric* and *held_time_metric* only need an ``observe(seconds)``
    method (presumably Prometheus histograms/summaries; confirm at call
    sites). Durations are measured with ``time.perf_counter``.
    """

    def __init__(self, acquire_metric, held_time_metric, loop=None):
        # asyncio.Lock's ``loop`` argument was deprecated in 3.8 and removed in
        # 3.10, where passing it at all (even None) raises TypeError; forward
        # it only when a caller explicitly supplies a loop
        if loop is None:
            super().__init__()
        else:
            super().__init__(loop=loop)
        self._acquire_metric = acquire_metric
        self._lock_held_time_metric = held_time_metric
        self._lock_acquired_time = None

    async def acquire(self):
        """Acquire the lock and observe the time spent waiting for it.

        The wait time is observed even if acquisition is cancelled. The
        hold-time start is recorded only when the lock was actually obtained,
        so a cancelled waiter cannot overwrite the current holder's timestamp.
        """
        start = time.perf_counter()
        acquired = False
        try:
            acquired = await super().acquire()
            return acquired
        finally:
            finished = time.perf_counter()
            if acquired:
                self._lock_acquired_time = finished
            self._acquire_metric.observe(finished - start)

    def release(self):
        """Release the lock and observe how long it was held.

        ``asyncio.Lock.release`` raises RuntimeError when the lock is not
        held. That error propagates unchanged, and nothing is observed.
        """
        super().release()
        acquired_at, self._lock_acquired_time = self._lock_acquired_time, None
        if acquired_at is not None:
            self._lock_held_time_metric.observe(time.perf_counter() - acquired_at)
|