# Copyright (c) 2016-2017, Neil Booth # # All rights reserved. # # The MIT License (MIT) # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the # "Software"), to deal in the Software without restriction, including # without limitation the rights to use, copy, modify, merge, publish, # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: # # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # and warranty status of this software. 
"""Miscellaneous utility classes and functions."""

import array
import copy
import inspect
from ipaddress import ip_address
import logging
import re
import sys
from collections.abc import Container, Mapping
from struct import pack, Struct


# Logging utilities


class ConnectionLogger(logging.LoggerAdapter):
    """Prepends a connection identifier to a logging message."""

    def process(self, msg, kwargs):
        # Fall back to 'unknown' when the adapter's extra has no conn_id.
        conn_id = self.extra.get('conn_id', 'unknown')
        return f'[{conn_id}] {msg}', kwargs


class CompactFormatter(logging.Formatter):
    """Strips the module from the logger name to leave the class only."""

    def format(self, record):
        # Work on a shallow copy: the same LogRecord object is handed to
        # every handler, so renaming it in place would also shorten the
        # name seen by any other handler/formatter.
        record = copy.copy(record)
        record.name = record.name.rpartition('.')[-1]
        return super().format(record)


def make_logger(name, *, handler, level):
    """Return the root ElectrumX logger.

    The logger called *name* gets *handler* attached, its level set to
    *level* (anything ``Logger.setLevel`` accepts: an int or a level name
    such as ``'INFO'``), and propagation to ancestor loggers disabled.
    """
    logger = logging.getLogger(name)
    logger.addHandler(handler)
    logger.setLevel(level)
    logger.propagate = False
    return logger


def class_logger(path, classname):
    """Return a hierarchical logger for a class."""
    return logging.getLogger(path).getChild(classname)


# Method decorator.  To be used for calculations that will always
# deliver the same result.  The method cannot take any arguments
# and should be accessed as an attribute.  The first access computes
# the value and stores it under the method's name, shadowing this
# (non-data) descriptor so later accesses are plain attribute lookups.
class cachedproperty:

    def __init__(self, f):
        self.f = f

    def __get__(self, obj, type):
        # obj is None when accessed on the class itself; then the value is
        # computed from and cached on the class.  Test identity rather than
        # truthiness so falsy instances (e.g. __len__ returning 0) still
        # cache on the instance, not the class.
        if obj is None:
            obj = type
        value = self.f(obj)
        setattr(obj, self.f.__name__, value)
        return value


def formatted_time(t, sep=' '):
    """Return a number of seconds as a string in days, hours, mins and
    maybe secs.

    Leading zero units are omitted; seconds are only shown when fewer than
    three larger units were emitted, e.g. 3661 -> '01h 01m 01s' and
    90061 -> '1d 01h 01m'.
    """
    t = int(t)
    fmts = (('{:d}d', 86400), ('{:02d}h', 3600), ('{:02d}m', 60))
    parts = []
    for fmt, n in fmts:
        val = t // n
        # Once a unit has been emitted, every smaller unit is emitted too.
        if parts or val:
            parts.append(fmt.format(val))
        t %= n
    if len(parts) < 3:
        parts.append(f'{t:02d}s')
    return sep.join(parts)


def deep_getsizeof(obj):
    """Find the memory footprint of a Python object.

    Based on code from code.tutsplus.com: http://goo.gl/fZ0DXK

    This is a recursive function that drills down a Python object graph
    like a dictionary holding nested dictionaries with lists of lists
    and tuples and sets.

    The sys.getsizeof function computes only a shallow size: it counts
    each object inside a container as a pointer only, regardless of how
    big it really is.  Objects reachable more than once are counted once.
    """

    ids = set()

    def size(o):
        # Already counted via another reference.
        if id(o) in ids:
            return 0

        r = sys.getsizeof(o)
        ids.add(id(o))

        # Flat buffers: their contents are already included in getsizeof.
        if isinstance(o, (str, bytes, bytearray, array.array)):
            return r

        if isinstance(o, Mapping):
            return r + sum(size(k) + size(v) for k, v in o.items())

        if isinstance(o, Container):
            return r + sum(size(x) for x in o)

        return r

    return size(obj)


def subclasses(base_class, strict=True):
    """Return a list of subclasses of base_class in its module.

    With strict=True (the default) base_class itself is excluded.
    """
    def select(obj):
        return (inspect.isclass(obj) and issubclass(obj, base_class) and
                (not strict or obj != base_class))

    pairs = inspect.getmembers(sys.modules[base_class.__module__], select)
    return [pair[1] for pair in pairs]


def chunks(items, size):
    """Break up items, a sequence supporting len() and slicing, into
    chunks of length size (the last chunk may be shorter)."""
    for i in range(0, len(items), size):
        yield items[i: i + size]


def resolve_limit(limit):
    """Map a limit of None to -1 (no limit); otherwise return the
    non-negative integer limit unchanged."""
    if limit is None:
        return -1
    assert isinstance(limit, int) and limit >= 0
    return limit


def bytes_to_int(be_bytes):
    """Interprets a big-endian sequence of bytes as an integer"""
    return int.from_bytes(be_bytes, 'big')


def int_to_bytes(value):
    """Converts a non-negative integer to a minimal-length big-endian
    sequence of bytes (0 becomes b'')."""
    return value.to_bytes((value.bit_length() + 7) // 8, 'big')


def increment_byte_string(bs):
    """Return the lexicographically next byte string of the same length.

    Return None if there is none (when the input is all 0xff bytes)."""
    for n in range(1, len(bs) + 1):
        if bs[-n] != 0xff:
            # Bump the rightmost non-0xff byte and zero everything after it.
            return bs[:-n] + bytes([bs[-n] + 1]) + bytes(n - 1)
    return None


class LogicalFile:
    """A logical binary file split across several separate files on disk.

    Physical file number k is named prefix followed by k zero-padded to
    `digits` digits, and holds logical bytes [k * file_size, (k+1) * file_size).
    """

    def __init__(self, prefix, digits, file_size):
        digit_fmt = f'{{:0{digits:d}d}}'
        self.filename_fmt = prefix + digit_fmt
        self.file_size = file_size

    def read(self, start, size=-1):
        """Read up to size bytes from the virtual file, starting at offset
        start, and return them.

        If size is -1 all bytes are read."""
        parts = []
        while size != 0:
            try:
                with self.open_file(start, False) as f:
                    part = f.read(size)
                if not part:
                    break
            except FileNotFoundError:
                # Ran past the last physical file.
                break
            parts.append(part)
            start += len(part)
            if size > 0:
                size -= len(part)
        return b''.join(parts)

    def write(self, start, b):
        """Write the bytes-like object, b, to the underlying virtual file."""
        while b:
            # Never write past the end of the current physical file.
            size = min(len(b), self.file_size - (start % self.file_size))
            with self.open_file(start, True) as f:
                f.write(b if size == len(b) else b[:size])
            b = b[size:]
            start += size

    def open_file(self, start, create):
        """Open the virtual file and seek to start.  Return a file handle.
        Raise FileNotFoundError if the file does not exist and create
        is False.
        """
        file_num, offset = divmod(start, self.file_size)
        filename = self.filename_fmt.format(file_num)
        f = open_file(filename, create)
        f.seek(offset)
        return f


def open_file(filename, create=False):
    """Open the file name in binary read/write mode.  Return its handle.

    If the file does not exist it is created when create is true;
    otherwise FileNotFoundError propagates."""
    try:
        return open(filename, 'rb+')
    except FileNotFoundError:
        if create:
            return open(filename, 'wb+')
        raise


def open_truncate(filename):
    """Open the file name, truncating it.  Return its handle."""
    return open(filename, 'wb+')


def address_string(address):
    """Return an address as a correctly formatted string.

    address is a (host, port) pair; IPv6 literals are bracketed."""
    fmt = '{}:{:d}'
    host, port = address
    try:
        host = ip_address(host)
    except ValueError:
        # Not an IP literal (e.g. a hostname): use it verbatim.
        pass
    else:
        if host.version == 6:
            fmt = '[{}]:{:d}'
    return fmt.format(host, port)


# See http://stackoverflow.com/questions/2532053/validate-a-hostname-string
# Note underscores are valid in domain names, but strictly invalid in host
# names.  We ignore that distinction.


SEGMENT_REGEX = re.compile("(?!-)[A-Z_\\d-]{1,63}(?<!-)$", re.IGNORECASE)


def is_valid_hostname(hostname):
    """Return True if hostname is a syntactically valid host name."""
    if len(hostname) > 255:
        return False
    # strip exactly one dot from the right, if present
    if hostname and hostname[-1] == ".":
        hostname = hostname[:-1]
    # fullmatch, not match: the pattern's '$' also matches just before a
    # trailing newline, which would otherwise accept "example.com\n".
    return all(SEGMENT_REGEX.fullmatch(x) for x in hostname.split("."))


def protocol_tuple(s):
    """Converts a protocol version number, such as "1.0" to a tuple (1, 0).

    If the version number is bad, (0, ) indicating version 0 is returned."""
    try:
        return tuple(int(part) for part in s.split('.'))
    except Exception:
        return (0, )


def version_string(ptuple):
    """Convert a version tuple such as (1, 2) to "1.2".
    There is always at least one dot, so (1, ) becomes "1.0"."""
    # Copy into a tuple so padding never extends a caller's list in place.
    ptuple = tuple(ptuple)
    while len(ptuple) < 2:
        ptuple += (0, )
    return '.'.join(str(p) for p in ptuple)


def protocol_version(client_req, min_tuple, max_tuple):
    """Given a client's protocol version string, return a pair of
    protocol tuples:

           (negotiated version, client min request)

    If the request is unsupported, the negotiated protocol tuple is
    None.

    client_req may be None (assume min_tuple), a version string, or a
    two-element list [min_string, max_string].
    """
    if client_req is None:
        client_min = client_max = min_tuple
    else:
        if isinstance(client_req, list) and len(client_req) == 2:
            client_min, client_max = client_req
        else:
            client_min = client_max = client_req
        client_min = protocol_tuple(client_min)
        client_max = protocol_tuple(client_max)

    result = min(client_max, max_tuple)
    if result < max(client_min, min_tuple) or result == (0, ):
        result = None

    return result, client_min


# Pre-compiled struct formats: '<' little-endian, '>' big-endian.
struct_le_i = Struct('<i')
struct_le_q = Struct('<q')
struct_le_H = Struct('<H')
struct_le_I = Struct('<I')
struct_le_Q = Struct('<Q')
struct_be_H = Struct('>H')
struct_be_I = Struct('>I')
structB = Struct('B')

unpack_le_int32_from = struct_le_i.unpack_from
unpack_le_int64_from = struct_le_q.unpack_from
unpack_le_uint16_from = struct_le_H.unpack_from
unpack_le_uint32_from = struct_le_I.unpack_from
unpack_le_uint64_from = struct_le_Q.unpack_from
unpack_be_uint16_from = struct_be_H.unpack_from
unpack_be_uint32_from = struct_be_I.unpack_from

pack_le_int32 = struct_le_i.pack
pack_le_int64 = struct_le_q.pack
pack_le_uint16 = struct_le_H.pack
pack_le_uint32 = struct_le_I.pack
pack_le_uint64 = struct_le_Q.pack
pack_be_uint16 = struct_be_H.pack
pack_be_uint32 = struct_be_I.pack
pack_byte = structB.pack

hex_to_bytes = bytes.fromhex


def pack_varint(n):
    """Encode a non-negative integer as a Bitcoin-style variable-length
    integer: one byte below 253, else a 0xfd/0xfe/0xff marker followed by
    a little-endian uint16/uint32/uint64."""
    if n < 253:
        return pack_byte(n)
    if n < 65536:
        return pack_byte(253) + pack_le_uint16(n)
    if n < 4294967296:
        return pack_byte(254) + pack_le_uint32(n)
    return pack_byte(255) + pack_le_uint64(n)


def pack_varbytes(data):
    """Prefix data with its length encoded by pack_varint."""
    return pack_varint(len(data)) + data