From 2e30ce9ae5af5b5600a0be16a5642b235aa9df85 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Tue, 20 Feb 2018 13:46:17 -0500 Subject: [PATCH] add mock udp transport layer --- lbrynet/tests/mocks.py | 102 +++++++++++++++++++++++++++++++++++++++-- lbrynet/tests/util.py | 35 ++++++++++++++ 2 files changed, 132 insertions(+), 5 deletions(-) diff --git a/lbrynet/tests/mocks.py b/lbrynet/tests/mocks.py index 1d719548b..70539ed83 100644 --- a/lbrynet/tests/mocks.py +++ b/lbrynet/tests/mocks.py @@ -1,12 +1,15 @@ +import struct import io from Crypto.PublicKey import RSA -from twisted.internet import defer +from twisted.internet import defer, threads, error from lbrynet.core import PTCWallet from lbrynet.core import BlobAvailability +from lbrynet.core.utils import generate_id from lbrynet.daemon import ExchangeRateManager as ERM from lbrynet import conf +from util import debug_kademlia_packet KB = 2**10 @@ -18,15 +21,18 @@ class FakeLBRYFile(object): self.stream_hash = stream_hash self.file_name = 'fake_lbry_file' + class Node(object): - def __init__(self, *args, **kwargs): - pass + def __init__(self, hash_announcer, peer_finder=None, peer_manager=None, **kwargs): + self.hash_announcer = hash_announcer + self.peer_finder = peer_finder + self.peer_manager = peer_manager def joinNetwork(self, *args): - pass + return defer.succeed(True) def stop(self): - pass + return defer.succeed(None) class FakeNetwork(object): @@ -164,6 +170,9 @@ class Announcer(object): def stop(self): pass + def get_next_announce_time(self): + return 0 + class GenFile(io.RawIOBase): def __init__(self, size, pattern): @@ -313,4 +322,87 @@ def mock_conf_settings(obj, settings={}): obj.addCleanup(_reset_settings) +MOCK_DHT_NODES = [ + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", + "DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF", +] +MOCK_DHT_SEED_DNS = { # these map to mock nodes 0, 1, and 2 + "lbrynet1.lbry.io": "10.42.42.1", + "lbrynet2.lbry.io": "10.42.42.2", + "lbrynet3.lbry.io": "10.42.42.3", +} + + +def resolve(name, timeout=(1, 3, 11, 45)): + if name not in MOCK_DHT_SEED_DNS: + return defer.fail(error.DNSLookupError(name)) + return defer.succeed(MOCK_DHT_SEED_DNS[name]) + + +class MockUDPTransport(object): + def __init__(self, address, port, max_packet_size, protocol): + self.address = address + self.port = port + self.max_packet_size = max_packet_size + self._node = protocol._node + + def write(self, data, address): + dest = MockNetwork.protocols[address][0] + debug_kademlia_packet(data, (self.address, self.port), address, self._node) + dest.datagramReceived(data, (self.address, self.port)) + + +class MockUDPPort(object): + def __init__(self, protocol): + self.protocol = protocol + + def startListening(self, reason=None): + return self.protocol.startProtocol() + + def stopListening(self, reason=None): + return self.protocol.stopProtocol() + + +class MockNetwork(object): + protocols = {} # (interface, port): (protocol, max_packet_size) + + @classmethod + def add_peer(cls, port, protocol, interface, maxPacketSize): + interface = protocol._node.externalIP + protocol.transport = MockUDPTransport(interface, port, maxPacketSize, protocol) + cls.protocols[(interface, port)] = (protocol, maxPacketSize) + + +def listenUDP(port, protocol, interface='', maxPacketSize=8192): + MockNetwork.add_peer(port, protocol, interface, maxPacketSize) + return MockUDPPort(protocol) + + +def address_generator(address=(10, 42, 42, 1)): + def increment(addr): + value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]))[0] + 1 + new_addr = [] + for i in range(4): + new_addr.append(value % 256) + value >>= 8 + return tuple(new_addr[::-1]) + + while True: + yield "{}.{}.{}.{}".format(*address) + address = increment(address) + + +def mock_node_generator(count=None, mock_node_ids=MOCK_DHT_NODES): + if mock_node_ids is None: + mock_node_ids = MOCK_DHT_NODES + + for num, node_ip in enumerate(address_generator()): + if count and num >= count: + break + if num >= len(mock_node_ids): + node_id = generate_id().encode('hex') + else: + node_id = mock_node_ids[num] + yield (node_id, node_ip) diff --git a/lbrynet/tests/util.py b/lbrynet/tests/util.py index 43cb007ea..cc4c7da78 100644 --- a/lbrynet/tests/util.py +++ b/lbrynet/tests/util.py @@ -5,24 +5,36 @@ import os import tempfile import shutil import mock +import logging +from lbrynet.dht.encoding import Bencode +from lbrynet.dht.error import DecodeError +from lbrynet.dht.msgformat import DefaultFormat +from lbrynet.dht.msgtypes import ResponseMessage, RequestMessage, ErrorMessage +_encode = Bencode() +_datagram_formatter = DefaultFormat() DEFAULT_TIMESTAMP = datetime.datetime(2016, 1, 1) DEFAULT_ISO_TIME = time.mktime(DEFAULT_TIMESTAMP.timetuple()) +log = logging.getLogger("lbrynet.tests.util") + def mk_db_and_blob_dir(): db_dir = tempfile.mkdtemp() blob_dir = tempfile.mkdtemp() return db_dir, blob_dir + def rm_db_and_blob_dir(db_dir, blob_dir): shutil.rmtree(db_dir, ignore_errors=True) shutil.rmtree(blob_dir, ignore_errors=True) + def random_lbry_hash(): return binascii.b2a_hex(os.urandom(48)) + def resetTime(test_case, timestamp=DEFAULT_TIMESTAMP): iso_time = time.mktime(timestamp.timetuple()) patcher = mock.patch('time.time') @@ -37,5 +49,28 @@ def resetTime(test_case, timestamp=DEFAULT_TIMESTAMP): patcher.start().return_value = timestamp test_case.addCleanup(patcher.stop) + def is_android(): return 'ANDROID_ARGUMENT' in os.environ # detect Android using the Kivy way + + +def debug_kademlia_packet(data, source, destination, node): + if log.level != logging.DEBUG: + return + try: + packet = _datagram_formatter.fromPrimitive(_encode.decode(data)) + if isinstance(packet, RequestMessage): + log.debug("request %s --> %s %s (node time %s)", source[0], destination[0], packet.request, + node.clock.seconds()) + elif isinstance(packet, ResponseMessage): + if isinstance(packet.response, (str, unicode)): + log.debug("response %s <-- %s %s (node time %s)", destination[0], source[0], packet.response, + node.clock.seconds()) + else: + log.debug("response %s <-- %s %i contacts (node time %s)", destination[0], source[0], + len(packet.response), node.clock.seconds()) + elif isinstance(packet, ErrorMessage): + log.error("error %s <-- %s %s (node time %s)", destination[0], source[0], packet.exceptionType, + node.clock.seconds()) + except DecodeError: + log.exception("decode error %s --> %s (node time %s)", source[0], destination[0], node.clock.seconds())