From 19211d44173288c721b0f5759cdab45b41bdfab1 Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Tue, 17 Jul 2018 23:00:34 -0300 Subject: [PATCH] make sure bencoding works for bytes, not strings --- lbrynet/dht/encoding.py | 36 +++++++++++++++++---------------- lbrynet/dht/node.py | 12 +++++------ tests/unit/dht/test_encoding.py | 16 +++++++-------- 3 files changed, 33 insertions(+), 31 deletions(-) diff --git a/lbrynet/dht/encoding.py b/lbrynet/dht/encoding.py index 43dc64fd4..b6fe37009 100644 --- a/lbrynet/dht/encoding.py +++ b/lbrynet/dht/encoding.py @@ -54,23 +54,24 @@ class Bencode(Encoding): @rtype: str """ if isinstance(data, (int, long)): - return 'i%de' % data + return b'i%de' % data elif isinstance(data, str): - return '%d:%s' % (len(data), data) + return b'%d:%s' % (len(data), data.encode()) + elif isinstance(data, bytes): + return b'%d:%s' % (len(data), data) elif isinstance(data, (list, tuple)): - encodedListItems = '' + encodedListItems = b'' for item in data: encodedListItems += self.encode(item) - return 'l%se' % encodedListItems + return b'l%se' % encodedListItems elif isinstance(data, dict): - encodedDictItems = '' + encodedDictItems = b'' keys = data.keys() for key in sorted(keys): encodedDictItems += self.encode(key) # TODO: keys should always be bytestrings encodedDictItems += self.encode(data[key]) - return 'd%se' % encodedDictItems + return b'd%se' % encodedDictItems else: - print(data) raise TypeError("Cannot bencode '%s' object" % type(data)) def decode(self, data): @@ -85,6 +86,7 @@ class Bencode(Encoding): @return: The decoded data, as a native Python type @rtype: int, list, dict or str """ + assert type(data) == bytes if len(data) == 0: raise DecodeError('Cannot decode empty string') try: @@ -98,34 +100,34 @@ class Bencode(Encoding): Do not call this; use C{decode()} instead """ - if data[startIndex] == 'i': - endPos = data[startIndex:].find('e') + startIndex + if data[startIndex] == ord('i'): + endPos = data[startIndex:].find(b'e') + startIndex return int(data[startIndex + 1:endPos]), endPos + 1 - elif data[startIndex] == 'l': + elif data[startIndex] == ord('l'): startIndex += 1 decodedList = [] - while data[startIndex] != 'e': + while data[startIndex] != ord('e'): listData, startIndex = Bencode._decodeRecursive(data, startIndex) decodedList.append(listData) return decodedList, startIndex + 1 - elif data[startIndex] == 'd': + elif data[startIndex] == ord('d'): startIndex += 1 decodedDict = {} - while data[startIndex] != 'e': + while data[startIndex] != ord('e'): key, startIndex = Bencode._decodeRecursive(data, startIndex) value, startIndex = Bencode._decodeRecursive(data, startIndex) decodedDict[key] = value return decodedDict, startIndex - elif data[startIndex] == 'f': + elif data[startIndex] == ord('f'): # This (float data type) is a non-standard extension to the original Bencode algorithm - endPos = data[startIndex:].find('e') + startIndex + endPos = data[startIndex:].find(ord('e')) + startIndex return float(data[startIndex + 1:endPos]), endPos + 1 - elif data[startIndex] == 'n': + elif data[startIndex] == ord('n'): # This (None/NULL data type) is a non-standard extension # to the original Bencode algorithm return None, startIndex + 1 else: - splitPos = data[startIndex:].find(':') + startIndex + splitPos = data[startIndex:].find(ord(':')) + startIndex try: length = int(data[startIndex:splitPos]) except ValueError: diff --git a/lbrynet/dht/node.py b/lbrynet/dht/node.py index f4937852b..3b37b803f 100644 --- a/lbrynet/dht/node.py +++ b/lbrynet/dht/node.py @@ -155,10 +155,10 @@ class Node(MockKademliaHelper): self.peer_finder = peer_finder or DHTPeerFinder(self, self.peer_manager) self._join_deferred = None - def __del__(self): - log.warning("unclean shutdown of the dht node") - if hasattr(self, "_listeningPort") and self._listeningPort is not None: - self._listeningPort.stopListening() + #def __del__(self): + # log.warning("unclean shutdown of the dht node") + # if hasattr(self, "_listeningPort") and self._listeningPort is not None: + # self._listeningPort.stopListening() @defer.inlineCallbacks def stop(self): @@ -203,7 +203,7 @@ class Node(MockKademliaHelper): if not known_node_resolution: known_node_resolution = yield _resolve_seeds() # we are one of the seed nodes, don't add ourselves - if (self.externalIP, self.port) in known_node_resolution.itervalues(): + if (self.externalIP, self.port) in known_node_resolution.values(): del known_node_resolution[(self.externalIP, self.port)] known_node_addresses.remove((self.externalIP, self.port)) @@ -216,7 +216,7 @@ class Node(MockKademliaHelper): def _initialize_routing(): bootstrap_contacts = [] contact_addresses = {(c.address, c.port): c for c in self.contacts} - for (host, port), ip_address in known_node_resolution.iteritems(): + for (host, port), ip_address in known_node_resolution.items(): if (host, port) not in contact_addresses: # Create temporary contact information for the list of addresses of known nodes # The contact node id will be set with the responding node id when we initialize it to None diff --git a/tests/unit/dht/test_encoding.py b/tests/unit/dht/test_encoding.py index 042a664f3..2f694d849 100644 --- a/tests/unit/dht/test_encoding.py +++ b/tests/unit/dht/test_encoding.py @@ -13,17 +13,17 @@ class BencodeTest(unittest.TestCase): def setUp(self): self.encoding = lbrynet.dht.encoding.Bencode() # Thanks goes to wikipedia for the initial test cases ;-) - self.cases = ((42, 'i42e'), - ('spam', '4:spam'), - (['spam', 42], 'l4:spami42ee'), - ({'foo': 42, 'bar': 'spam'}, 'd3:bar4:spam3:fooi42ee'), + self.cases = ((42, b'i42e'), + (b'spam', b'4:spam'), + ([b'spam', 42], b'l4:spami42ee'), + ({b'foo': 42, b'bar': b'spam'}, b'd3:bar4:spam3:fooi42ee'), # ...and now the "real life" tests - ([['abc', '127.0.0.1', 1919], ['def', '127.0.0.1', 1921]], - 'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee')) + ([[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]], + b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee')) # The following test cases are "bad"; i.e. sending rubbish into the decoder to test # what exceptions get thrown - self.badDecoderCases = ('abcdefghijklmnopqrstuvwxyz', - '') + self.badDecoderCases = (b'abcdefghijklmnopqrstuvwxyz', + b'') def testEncoder(self): """ Tests the bencode encoder """