more porting, plus some functional tests working

This commit is contained in:
Victor Shyba 2018-07-24 18:51:41 -03:00 committed by Jack Robison
parent 7335d012ef
commit 78c8c8e64d
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
9 changed files with 38 additions and 30 deletions

View file

@ -182,7 +182,6 @@ class ContactManager:
return contact return contact
def make_contact(self, id, ipAddress, udpPort, networkProtocol, firstComm=0): def make_contact(self, id, ipAddress, udpPort, networkProtocol, firstComm=0):
ipAddress = str(ipAddress)
contact = self.get_contact(id, ipAddress, udpPort) contact = self.get_contact(id, ipAddress, udpPort)
if contact: if contact:
return contact return contact

View file

@ -38,7 +38,7 @@ class DictDataStore(UserDict):
if not unexpired_peers: if not unexpired_peers:
del self[key] del self[key]
else: else:
self[key] = unexpired_peers self[key] = list(unexpired_peers)
def hasPeersForBlob(self, key): def hasPeersForBlob(self, key):
return True if key in self and len(tuple(self.filter_bad_and_expired_peers(key))) else False return True if key in self and len(tuple(self.filter_bad_and_expired_peers(key))) else False

View file

@ -60,6 +60,8 @@ class Bencode(Encoding):
return b'i%de' % data return b'i%de' % data
elif isinstance(data, bytes): elif isinstance(data, bytes):
return b'%d:%s' % (len(data), data) return b'%d:%s' % (len(data), data)
elif isinstance(data, str):
return b'%d:' % (len(data)) + data.encode()
elif isinstance(data, (list, tuple)): elif isinstance(data, (list, tuple)):
encodedListItems = b'' encodedListItems = b''
for item in data: for item in data:

View file

@ -38,7 +38,7 @@ class _IterativeFind:
# Shortlist of contact objects (the k closest known contacts to the key from the routing table) # Shortlist of contact objects (the k closest known contacts to the key from the routing table)
self.shortlist = shortlist self.shortlist = shortlist
# The search key # The search key
self.key = str(key) self.key = key
# The rpc method name (findValue or findNode) # The rpc method name (findValue or findNode)
self.rpc = rpc self.rpc = rpc
# List of active queries; len() indicates number of active probes # List of active queries; len() indicates number of active probes
@ -74,22 +74,22 @@ class _IterativeFind:
for contact_tup in contact_triples: for contact_tup in contact_triples:
if not isinstance(contact_tup, (list, tuple)) or len(contact_tup) != 3: if not isinstance(contact_tup, (list, tuple)) or len(contact_tup) != 3:
raise ValueError("invalid contact triple") raise ValueError("invalid contact triple")
contact_tup[1] = contact_tup[1].decode() # ips are strings
return contact_triples return contact_triples
def sortByDistance(self, contact_list): def sortByDistance(self, contact_list):
"""Sort the list of contacts in order by distance from key""" """Sort the list of contacts in order by distance from key"""
contact_list.sort(key=lambda c: self.distance(c.id)) contact_list.sort(key=lambda c: self.distance(c.id))
@defer.inlineCallbacks
def extendShortlist(self, contact, result): def extendShortlist(self, contact, result):
# The "raw response" tuple contains the response message and the originating address info # The "raw response" tuple contains the response message and the originating address info
originAddress = (contact.address, contact.port) originAddress = (contact.address, contact.port)
if self.finished_deferred.called: if self.finished_deferred.called:
defer.returnValue(contact.id) return contact.id
if self.node.contact_manager.is_ignored(originAddress): if self.node.contact_manager.is_ignored(originAddress):
raise ValueError("contact is ignored") raise ValueError("contact is ignored")
if contact.id == self.node.node_id: if contact.id == self.node.node_id:
defer.returnValue(contact.id) return contact.id
if contact not in self.active_contacts: if contact not in self.active_contacts:
self.active_contacts.append(contact) self.active_contacts.append(contact)
@ -134,14 +134,14 @@ class _IterativeFind:
self.sortByDistance(self.active_contacts) self.sortByDistance(self.active_contacts)
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))]) self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
defer.returnValue(contact.id) return contact.id
@defer.inlineCallbacks @defer.inlineCallbacks
def probeContact(self, contact): def probeContact(self, contact):
fn = getattr(contact, self.rpc) fn = getattr(contact, self.rpc)
try: try:
response = yield fn(self.key) response = yield fn(self.key)
result = yield self.extendShortlist(contact, response) result = self.extendShortlist(contact, response)
defer.returnValue(result) defer.returnValue(result)
except (TimeoutError, defer.CancelledError, ValueError, IndexError): except (TimeoutError, defer.CancelledError, ValueError, IndexError):
defer.returnValue(contact.id) defer.returnValue(contact.id)

View file

@ -162,6 +162,13 @@ class Node(MockKademliaHelper):
# if hasattr(self, "_listeningPort") and self._listeningPort is not None: # if hasattr(self, "_listeningPort") and self._listeningPort is not None:
# self._listeningPort.stopListening() # self._listeningPort.stopListening()
def __str__(self):
return '<%s.%s object; ID: %s, IP address: %s, UDP port: %d>' % (
self.__module__, self.__class__.__name__, binascii.hexlify(self.node_id), self.externalIP, self.port)
def __hash__(self):
return self.node_id.__hash__()
@defer.inlineCallbacks @defer.inlineCallbacks
def stop(self): def stop(self):
# stop LoopingCalls: # stop LoopingCalls:
@ -315,10 +322,10 @@ class Node(MockKademliaHelper):
token = contact.token token = contact.token
if not token: if not token:
find_value_response = yield contact.findValue(blob_hash) find_value_response = yield contact.findValue(blob_hash)
token = find_value_response['token'] token = find_value_response[b'token']
contact.update_token(token) contact.update_token(token)
res = yield contact.store(blob_hash, token, self.peerPort, self.node_id, 0) res = yield contact.store(blob_hash, token, self.peerPort, self.node_id, 0)
if res != "OK": if res != b"OK":
raise ValueError(res) raise ValueError(res)
defer.returnValue(True) defer.returnValue(True)
log.debug("Stored %s to %s (%s)", binascii.hexlify(blob_hash), contact.log_id(), contact.address) log.debug("Stored %s to %s (%s)", binascii.hexlify(blob_hash), contact.log_id(), contact.address)
@ -326,7 +333,7 @@ class Node(MockKademliaHelper):
log.debug("Timeout while storing blob_hash %s at %s", log.debug("Timeout while storing blob_hash %s at %s",
binascii.hexlify(blob_hash), contact.log_id()) binascii.hexlify(blob_hash), contact.log_id())
except ValueError as err: except ValueError as err:
log.error("Unexpected response: %s" % err.message) log.error("Unexpected response: %s" % err)
except Exception as err: except Exception as err:
log.error("Unexpected error while storing blob_hash %s at %s: %s", log.error("Unexpected error while storing blob_hash %s at %s: %s",
binascii.hexlify(blob_hash), contact, err) binascii.hexlify(blob_hash), contact, err)
@ -339,9 +346,7 @@ class Node(MockKademliaHelper):
if not self.externalIP: if not self.externalIP:
raise Exception("Cannot determine external IP: %s" % self.externalIP) raise Exception("Cannot determine external IP: %s" % self.externalIP)
stored_to = yield DeferredDict({contact: self.storeToContact(blob_hash, contact) for contact in contacts}) stored_to = yield DeferredDict({contact: self.storeToContact(blob_hash, contact) for contact in contacts})
contacted_node_ids = map( contacted_node_ids = [binascii.hexlify(contact.id) for contact in stored_to.keys() if stored_to[contact]]
lambda contact: contact.id.encode('hex'), filter(lambda contact: stored_to[contact], stored_to.keys())
)
log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash), log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash),
len(contacted_node_ids), len(contacts)) len(contacted_node_ids), len(contacts))
defer.returnValue(contacted_node_ids) defer.returnValue(contacted_node_ids)
@ -403,7 +408,7 @@ class Node(MockKademliaHelper):
@rtype: twisted.internet.defer.Deferred @rtype: twisted.internet.defer.Deferred
""" """
if len(key) != constants.key_bits / 8: if len(key) != constants.key_bits // 8:
raise ValueError("invalid key length!") raise ValueError("invalid key length!")
# Execute the search # Execute the search
@ -554,7 +559,7 @@ class Node(MockKademliaHelper):
node is returning all of the contacts that it knows of. node is returning all of the contacts that it knows of.
@rtype: list @rtype: list
""" """
if len(key) != constants.key_bits / 8: if len(key) != constants.key_bits // 8:
raise ValueError("invalid contact id length: %i" % len(key)) raise ValueError("invalid contact id length: %i" % len(key))
contacts = self._routingTable.findCloseNodes(key, sender_node_id=rpc_contact.id) contacts = self._routingTable.findCloseNodes(key, sender_node_id=rpc_contact.id)
@ -576,7 +581,7 @@ class Node(MockKademliaHelper):
@rtype: dict or list @rtype: dict or list
""" """
if len(key) != constants.key_bits / 8: if len(key) != constants.key_bits // 8:
raise ValueError("invalid blob hash length: %i" % len(key)) raise ValueError("invalid blob hash length: %i" % len(key))
response = { response = {
@ -645,7 +650,7 @@ class Node(MockKademliaHelper):
@rtype: twisted.internet.defer.Deferred @rtype: twisted.internet.defer.Deferred
""" """
if len(key) != constants.key_bits / 8: if len(key) != constants.key_bits // 8:
raise ValueError("invalid key length: %i" % len(key)) raise ValueError("invalid key length: %i" % len(key))
if startupShortlist is None: if startupShortlist is None:

View file

@ -220,17 +220,17 @@ class KademliaProtocol(protocol.DatagramProtocol):
@note: This is automatically called by Twisted when the protocol @note: This is automatically called by Twisted when the protocol
receives a UDP datagram receives a UDP datagram
""" """
if datagram[0] == b'\x00' and datagram[25] == b'\x00': if chr(datagram[0]) == '\x00' and chr(datagram[25]) == '\x00':
totalPackets = (ord(datagram[1]) << 8) | ord(datagram[2]) totalPackets = (datagram[1] << 8) | datagram[2]
msgID = datagram[5:25] msgID = datagram[5:25]
seqNumber = (ord(datagram[3]) << 8) | ord(datagram[4]) seqNumber = (datagram[3] << 8) | datagram[4]
if msgID not in self._partialMessages: if msgID not in self._partialMessages:
self._partialMessages[msgID] = {} self._partialMessages[msgID] = {}
self._partialMessages[msgID][seqNumber] = datagram[26:] self._partialMessages[msgID][seqNumber] = datagram[26:]
if len(self._partialMessages[msgID]) == totalPackets: if len(self._partialMessages[msgID]) == totalPackets:
keys = self._partialMessages[msgID].keys() keys = self._partialMessages[msgID].keys()
keys.sort() keys.sort()
data = '' data = b''
for key in keys: for key in keys:
data += self._partialMessages[msgID][key] data += self._partialMessages[msgID][key]
datagram = data datagram = data
@ -350,7 +350,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
# 1st byte is transmission type id, bytes 2 & 3 are the # 1st byte is transmission type id, bytes 2 & 3 are the
# total number of packets in this transmission, bytes 4 & # total number of packets in this transmission, bytes 4 &
# 5 are the sequence number for this specific packet # 5 are the sequence number for this specific packet
totalPackets = len(data) / self.msgSizeLimit totalPackets = len(data) // self.msgSizeLimit
if len(data) % self.msgSizeLimit > 0: if len(data) % self.msgSizeLimit > 0:
totalPackets += 1 totalPackets += 1
encTotalPackets = chr(totalPackets >> 8) + chr(totalPackets & 0xff) encTotalPackets = chr(totalPackets >> 8) + chr(totalPackets & 0xff)
@ -375,7 +375,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
if self.transport: if self.transport:
try: try:
self.transport.write(txData, address) self.transport.write(txData, address)
except socket.error as err: except OSError as err:
if err.errno == errno.EWOULDBLOCK: if err.errno == errno.EWOULDBLOCK:
# i'm scared this may swallow important errors, but i get a million of these # i'm scared this may swallow important errors, but i get a million of these
# on Linux and it doesnt seem to affect anything -grin # on Linux and it doesnt seem to affect anything -grin

View file

@ -6,6 +6,8 @@
# may be created by processing this file with epydoc: http://epydoc.sf.net # may be created by processing this file with epydoc: http://epydoc.sf.net
import random import random
from binascii import unhexlify
from twisted.internet import defer from twisted.internet import defer
from . import constants from . import constants
from . import kbucket from . import kbucket
@ -267,8 +269,8 @@ class TreeRoutingTable:
randomID = randomID[:-1] randomID = randomID[:-1]
if len(randomID) % 2 != 0: if len(randomID) % 2 != 0:
randomID = '0' + randomID randomID = '0' + randomID
randomID = randomID.decode('hex') randomID = unhexlify(randomID)
randomID = (constants.key_bits / 8 - len(randomID)) * '\x00' + randomID randomID = ((constants.key_bits // 8) - len(randomID)) * b'\x00' + randomID
return randomID return randomID
def _splitBucket(self, oldBucketIndex): def _splitBucket(self, oldBucketIndex):

View file

@ -173,5 +173,5 @@ class TestKademliaBase(unittest.TestCase):
yield self.run_reactor(2, ping_dl) yield self.run_reactor(2, ping_dl)
node_addresses = {node.externalIP for node in self.nodes}.union({seed.externalIP for seed in self._seeds}) node_addresses = {node.externalIP for node in self.nodes}.union({seed.externalIP for seed in self._seeds})
self.assertSetEqual(node_addresses, contacted) self.assertSetEqual(node_addresses, contacted)
expected = {node: "pong" for node in contacted} expected = {node: b"pong" for node in contacted}
self.assertDictEqual(ping_replies, expected) self.assertDictEqual(ping_replies, expected)

View file

@ -25,9 +25,9 @@ class TestPeerExpiration(TestKademliaBase):
offline_addresses = self.get_routable_addresses().difference(self.get_online_addresses()) offline_addresses = self.get_routable_addresses().difference(self.get_online_addresses())
self.assertSetEqual(offline_addresses, removed_addresses) self.assertSetEqual(offline_addresses, removed_addresses)
get_nodes_with_stale_contacts = lambda: filter(lambda node: any(contact.address in offline_addresses get_nodes_with_stale_contacts = lambda: list(filter(lambda node: any(contact.address in offline_addresses
for contact in node.contacts), for contact in node.contacts),
self.nodes + self._seeds) self.nodes + self._seeds))
self.assertRaises(AssertionError, self.verify_all_nodes_are_routable) self.assertRaises(AssertionError, self.verify_all_nodes_are_routable)
self.assertTrue(len(get_nodes_with_stale_contacts()) > 1) self.assertTrue(len(get_nodes_with_stale_contacts()) > 1)