more porting, plus some functional tests working

This commit is contained in:
Victor Shyba 2018-07-24 18:51:41 -03:00 committed by Jack Robison
parent 7335d012ef
commit 78c8c8e64d
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
9 changed files with 38 additions and 30 deletions

View file

@ -182,7 +182,6 @@ class ContactManager:
return contact
def make_contact(self, id, ipAddress, udpPort, networkProtocol, firstComm=0):
ipAddress = str(ipAddress)
contact = self.get_contact(id, ipAddress, udpPort)
if contact:
return contact

View file

@ -38,7 +38,7 @@ class DictDataStore(UserDict):
if not unexpired_peers:
del self[key]
else:
self[key] = unexpired_peers
self[key] = list(unexpired_peers)
def hasPeersForBlob(self, key):
return True if key in self and len(tuple(self.filter_bad_and_expired_peers(key))) else False

View file

@ -60,6 +60,8 @@ class Bencode(Encoding):
return b'i%de' % data
elif isinstance(data, bytes):
return b'%d:%s' % (len(data), data)
elif isinstance(data, str):
return b'%d:' % (len(data)) + data.encode()
elif isinstance(data, (list, tuple)):
encodedListItems = b''
for item in data:

View file

@ -38,7 +38,7 @@ class _IterativeFind:
# Shortlist of contact objects (the k closest known contacts to the key from the routing table)
self.shortlist = shortlist
# The search key
self.key = str(key)
self.key = key
# The rpc method name (findValue or findNode)
self.rpc = rpc
# List of active queries; len() indicates number of active probes
@ -74,22 +74,22 @@ class _IterativeFind:
for contact_tup in contact_triples:
if not isinstance(contact_tup, (list, tuple)) or len(contact_tup) != 3:
raise ValueError("invalid contact triple")
contact_tup[1] = contact_tup[1].decode() # ips are strings
return contact_triples
def sortByDistance(self, contact_list):
"""Sort the list of contacts in order by distance from key"""
contact_list.sort(key=lambda c: self.distance(c.id))
@defer.inlineCallbacks
def extendShortlist(self, contact, result):
# The "raw response" tuple contains the response message and the originating address info
originAddress = (contact.address, contact.port)
if self.finished_deferred.called:
defer.returnValue(contact.id)
return contact.id
if self.node.contact_manager.is_ignored(originAddress):
raise ValueError("contact is ignored")
if contact.id == self.node.node_id:
defer.returnValue(contact.id)
return contact.id
if contact not in self.active_contacts:
self.active_contacts.append(contact)
@ -134,14 +134,14 @@ class _IterativeFind:
self.sortByDistance(self.active_contacts)
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
defer.returnValue(contact.id)
return contact.id
@defer.inlineCallbacks
def probeContact(self, contact):
fn = getattr(contact, self.rpc)
try:
response = yield fn(self.key)
result = yield self.extendShortlist(contact, response)
result = self.extendShortlist(contact, response)
defer.returnValue(result)
except (TimeoutError, defer.CancelledError, ValueError, IndexError):
defer.returnValue(contact.id)

View file

@ -162,6 +162,13 @@ class Node(MockKademliaHelper):
# if hasattr(self, "_listeningPort") and self._listeningPort is not None:
# self._listeningPort.stopListening()
def __str__(self):
return '<%s.%s object; ID: %s, IP address: %s, UDP port: %d>' % (
self.__module__, self.__class__.__name__, binascii.hexlify(self.node_id), self.externalIP, self.port)
def __hash__(self):
return self.node_id.__hash__()
@defer.inlineCallbacks
def stop(self):
# stop LoopingCalls:
@ -315,10 +322,10 @@ class Node(MockKademliaHelper):
token = contact.token
if not token:
find_value_response = yield contact.findValue(blob_hash)
token = find_value_response['token']
token = find_value_response[b'token']
contact.update_token(token)
res = yield contact.store(blob_hash, token, self.peerPort, self.node_id, 0)
if res != "OK":
if res != b"OK":
raise ValueError(res)
defer.returnValue(True)
log.debug("Stored %s to %s (%s)", binascii.hexlify(blob_hash), contact.log_id(), contact.address)
@ -326,7 +333,7 @@ class Node(MockKademliaHelper):
log.debug("Timeout while storing blob_hash %s at %s",
binascii.hexlify(blob_hash), contact.log_id())
except ValueError as err:
log.error("Unexpected response: %s" % err.message)
log.error("Unexpected response: %s" % err)
except Exception as err:
log.error("Unexpected error while storing blob_hash %s at %s: %s",
binascii.hexlify(blob_hash), contact, err)
@ -339,9 +346,7 @@ class Node(MockKademliaHelper):
if not self.externalIP:
raise Exception("Cannot determine external IP: %s" % self.externalIP)
stored_to = yield DeferredDict({contact: self.storeToContact(blob_hash, contact) for contact in contacts})
contacted_node_ids = map(
lambda contact: contact.id.encode('hex'), filter(lambda contact: stored_to[contact], stored_to.keys())
)
contacted_node_ids = [binascii.hexlify(contact.id) for contact in stored_to.keys() if stored_to[contact]]
log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash),
len(contacted_node_ids), len(contacts))
defer.returnValue(contacted_node_ids)
@ -403,7 +408,7 @@ class Node(MockKademliaHelper):
@rtype: twisted.internet.defer.Deferred
"""
if len(key) != constants.key_bits / 8:
if len(key) != constants.key_bits // 8:
raise ValueError("invalid key length!")
# Execute the search
@ -554,7 +559,7 @@ class Node(MockKademliaHelper):
node is returning all of the contacts that it knows of.
@rtype: list
"""
if len(key) != constants.key_bits / 8:
if len(key) != constants.key_bits // 8:
raise ValueError("invalid contact id length: %i" % len(key))
contacts = self._routingTable.findCloseNodes(key, sender_node_id=rpc_contact.id)
@ -576,7 +581,7 @@ class Node(MockKademliaHelper):
@rtype: dict or list
"""
if len(key) != constants.key_bits / 8:
if len(key) != constants.key_bits // 8:
raise ValueError("invalid blob hash length: %i" % len(key))
response = {
@ -645,7 +650,7 @@ class Node(MockKademliaHelper):
@rtype: twisted.internet.defer.Deferred
"""
if len(key) != constants.key_bits / 8:
if len(key) != constants.key_bits // 8:
raise ValueError("invalid key length: %i" % len(key))
if startupShortlist is None:

View file

@ -220,17 +220,17 @@ class KademliaProtocol(protocol.DatagramProtocol):
@note: This is automatically called by Twisted when the protocol
receives a UDP datagram
"""
if datagram[0] == b'\x00' and datagram[25] == b'\x00':
totalPackets = (ord(datagram[1]) << 8) | ord(datagram[2])
if chr(datagram[0]) == '\x00' and chr(datagram[25]) == '\x00':
totalPackets = (datagram[1] << 8) | datagram[2]
msgID = datagram[5:25]
seqNumber = (ord(datagram[3]) << 8) | ord(datagram[4])
seqNumber = (datagram[3] << 8) | datagram[4]
if msgID not in self._partialMessages:
self._partialMessages[msgID] = {}
self._partialMessages[msgID][seqNumber] = datagram[26:]
if len(self._partialMessages[msgID]) == totalPackets:
keys = self._partialMessages[msgID].keys()
keys.sort()
data = ''
data = b''
for key in keys:
data += self._partialMessages[msgID][key]
datagram = data
@ -350,7 +350,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
# 1st byte is transmission type id, bytes 2 & 3 are the
# total number of packets in this transmission, bytes 4 &
# 5 are the sequence number for this specific packet
totalPackets = len(data) / self.msgSizeLimit
totalPackets = len(data) // self.msgSizeLimit
if len(data) % self.msgSizeLimit > 0:
totalPackets += 1
encTotalPackets = chr(totalPackets >> 8) + chr(totalPackets & 0xff)
@ -375,7 +375,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
if self.transport:
try:
self.transport.write(txData, address)
except socket.error as err:
except OSError as err:
if err.errno == errno.EWOULDBLOCK:
# i'm scared this may swallow important errors, but i get a million of these
# on Linux and it doesnt seem to affect anything -grin

View file

@ -6,6 +6,8 @@
# may be created by processing this file with epydoc: http://epydoc.sf.net
import random
from binascii import unhexlify
from twisted.internet import defer
from . import constants
from . import kbucket
@ -267,8 +269,8 @@ class TreeRoutingTable:
randomID = randomID[:-1]
if len(randomID) % 2 != 0:
randomID = '0' + randomID
randomID = randomID.decode('hex')
randomID = (constants.key_bits / 8 - len(randomID)) * '\x00' + randomID
randomID = unhexlify(randomID)
randomID = ((constants.key_bits // 8) - len(randomID)) * b'\x00' + randomID
return randomID
def _splitBucket(self, oldBucketIndex):

View file

@ -173,5 +173,5 @@ class TestKademliaBase(unittest.TestCase):
yield self.run_reactor(2, ping_dl)
node_addresses = {node.externalIP for node in self.nodes}.union({seed.externalIP for seed in self._seeds})
self.assertSetEqual(node_addresses, contacted)
expected = {node: "pong" for node in contacted}
expected = {node: b"pong" for node in contacted}
self.assertDictEqual(ping_replies, expected)

View file

@ -25,9 +25,9 @@ class TestPeerExpiration(TestKademliaBase):
offline_addresses = self.get_routable_addresses().difference(self.get_online_addresses())
self.assertSetEqual(offline_addresses, removed_addresses)
get_nodes_with_stale_contacts = lambda: filter(lambda node: any(contact.address in offline_addresses
for contact in node.contacts),
self.nodes + self._seeds)
get_nodes_with_stale_contacts = lambda: list(filter(lambda node: any(contact.address in offline_addresses
for contact in node.contacts),
self.nodes + self._seeds))
self.assertRaises(AssertionError, self.verify_all_nodes_are_routable)
self.assertTrue(len(get_nodes_with_stale_contacts()) > 1)