add mock udp transport layer

This commit is contained in:
Jack Robison 2018-02-20 13:46:17 -05:00
parent 87c69742cd
commit 2e30ce9ae5
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 132 additions and 5 deletions

View file

@ -1,12 +1,15 @@
import struct
import io import io
from Crypto.PublicKey import RSA from Crypto.PublicKey import RSA
from twisted.internet import defer from twisted.internet import defer, threads, error
from lbrynet.core import PTCWallet from lbrynet.core import PTCWallet
from lbrynet.core import BlobAvailability from lbrynet.core import BlobAvailability
from lbrynet.core.utils import generate_id
from lbrynet.daemon import ExchangeRateManager as ERM from lbrynet.daemon import ExchangeRateManager as ERM
from lbrynet import conf from lbrynet import conf
from util import debug_kademlia_packet
KB = 2**10 KB = 2**10
@ -18,15 +21,18 @@ class FakeLBRYFile(object):
self.stream_hash = stream_hash self.stream_hash = stream_hash
self.file_name = 'fake_lbry_file' self.file_name = 'fake_lbry_file'
class Node(object): class Node(object):
def __init__(self, *args, **kwargs): def __init__(self, hash_announcer, peer_finder=None, peer_manager=None, **kwargs):
pass self.hash_announcer = hash_announcer
self.peer_finder = peer_finder
self.peer_manager = peer_manager
def joinNetwork(self, *args): def joinNetwork(self, *args):
pass return defer.succeed(True)
def stop(self): def stop(self):
pass return defer.succeed(None)
class FakeNetwork(object): class FakeNetwork(object):
@ -164,6 +170,9 @@ class Announcer(object):
def stop(self): def stop(self):
pass pass
def get_next_announce_time(self):
return 0
class GenFile(io.RawIOBase): class GenFile(io.RawIOBase):
def __init__(self, size, pattern): def __init__(self, size, pattern):
@ -313,4 +322,87 @@ def mock_conf_settings(obj, settings={}):
obj.addCleanup(_reset_settings) obj.addCleanup(_reset_settings)
MOCK_DHT_NODES = [
"000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
"DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF",
]
MOCK_DHT_SEED_DNS = { # these map to mock nodes 0, 1, and 2
"lbrynet1.lbry.io": "10.42.42.1",
"lbrynet2.lbry.io": "10.42.42.2",
"lbrynet3.lbry.io": "10.42.42.3",
}
def resolve(name, timeout=(1, 3, 11, 45)):
if name not in MOCK_DHT_SEED_DNS:
return defer.fail(error.DNSLookupError(name))
return defer.succeed(MOCK_DHT_SEED_DNS[name])
class MockUDPTransport(object):
def __init__(self, address, port, max_packet_size, protocol):
self.address = address
self.port = port
self.max_packet_size = max_packet_size
self._node = protocol._node
def write(self, data, address):
dest = MockNetwork.protocols[address][0]
debug_kademlia_packet(data, (self.address, self.port), address, self._node)
dest.datagramReceived(data, (self.address, self.port))
class MockUDPPort(object):
def __init__(self, protocol):
self.protocol = protocol
def startListening(self, reason=None):
return self.protocol.startProtocol()
def stopListening(self, reason=None):
return self.protocol.stopProtocol()
class MockNetwork(object):
protocols = {} # (interface, port): (protocol, max_packet_size)
@classmethod
def add_peer(cls, port, protocol, interface, maxPacketSize):
interface = protocol._node.externalIP
protocol.transport = MockUDPTransport(interface, port, maxPacketSize, protocol)
cls.protocols[(interface, port)] = (protocol, maxPacketSize)
def listenUDP(port, protocol, interface='', maxPacketSize=8192):
MockNetwork.add_peer(port, protocol, interface, maxPacketSize)
return MockUDPPort(protocol)
def address_generator(address=(10, 42, 42, 1)):
def increment(addr):
value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]))[0] + 1
new_addr = []
for i in range(4):
new_addr.append(value % 256)
value >>= 8
return tuple(new_addr[::-1])
while True:
yield "{}.{}.{}.{}".format(*address)
address = increment(address)
def mock_node_generator(count=None, mock_node_ids=MOCK_DHT_NODES):
if mock_node_ids is None:
mock_node_ids = MOCK_DHT_NODES
for num, node_ip in enumerate(address_generator()):
if count and num >= count:
break
if num >= len(mock_node_ids):
node_id = generate_id().encode('hex')
else:
node_id = mock_node_ids[num]
yield (node_id, node_ip)

View file

@ -5,24 +5,36 @@ import os
import tempfile import tempfile
import shutil import shutil
import mock import mock
import logging
from lbrynet.dht.encoding import Bencode
from lbrynet.dht.error import DecodeError
from lbrynet.dht.msgformat import DefaultFormat
from lbrynet.dht.msgtypes import ResponseMessage, RequestMessage, ErrorMessage
_encode = Bencode()
_datagram_formatter = DefaultFormat()
DEFAULT_TIMESTAMP = datetime.datetime(2016, 1, 1) DEFAULT_TIMESTAMP = datetime.datetime(2016, 1, 1)
DEFAULT_ISO_TIME = time.mktime(DEFAULT_TIMESTAMP.timetuple()) DEFAULT_ISO_TIME = time.mktime(DEFAULT_TIMESTAMP.timetuple())
log = logging.getLogger("lbrynet.tests.util")
def mk_db_and_blob_dir(): def mk_db_and_blob_dir():
db_dir = tempfile.mkdtemp() db_dir = tempfile.mkdtemp()
blob_dir = tempfile.mkdtemp() blob_dir = tempfile.mkdtemp()
return db_dir, blob_dir return db_dir, blob_dir
def rm_db_and_blob_dir(db_dir, blob_dir): def rm_db_and_blob_dir(db_dir, blob_dir):
shutil.rmtree(db_dir, ignore_errors=True) shutil.rmtree(db_dir, ignore_errors=True)
shutil.rmtree(blob_dir, ignore_errors=True) shutil.rmtree(blob_dir, ignore_errors=True)
def random_lbry_hash(): def random_lbry_hash():
return binascii.b2a_hex(os.urandom(48)) return binascii.b2a_hex(os.urandom(48))
def resetTime(test_case, timestamp=DEFAULT_TIMESTAMP): def resetTime(test_case, timestamp=DEFAULT_TIMESTAMP):
iso_time = time.mktime(timestamp.timetuple()) iso_time = time.mktime(timestamp.timetuple())
patcher = mock.patch('time.time') patcher = mock.patch('time.time')
@ -37,5 +49,28 @@ def resetTime(test_case, timestamp=DEFAULT_TIMESTAMP):
patcher.start().return_value = timestamp patcher.start().return_value = timestamp
test_case.addCleanup(patcher.stop) test_case.addCleanup(patcher.stop)
def is_android(): def is_android():
return 'ANDROID_ARGUMENT' in os.environ # detect Android using the Kivy way return 'ANDROID_ARGUMENT' in os.environ # detect Android using the Kivy way
def debug_kademlia_packet(data, source, destination, node):
if log.level != logging.DEBUG:
return
try:
packet = _datagram_formatter.fromPrimitive(_encode.decode(data))
if isinstance(packet, RequestMessage):
log.debug("request %s --> %s %s (node time %s)", source[0], destination[0], packet.request,
node.clock.seconds())
elif isinstance(packet, ResponseMessage):
if isinstance(packet.response, (str, unicode)):
log.debug("response %s <-- %s %s (node time %s)", destination[0], source[0], packet.response,
node.clock.seconds())
else:
log.debug("response %s <-- %s %i contacts (node time %s)", destination[0], source[0],
len(packet.response), node.clock.seconds())
elif isinstance(packet, ErrorMessage):
log.error("error %s <-- %s %s (node time %s)", destination[0], source[0], packet.exceptionType,
node.clock.seconds())
except DecodeError:
log.exception("decode error %s --> %s (node time %s)", source[0], destination[0], node.clock.seconds())