make sure bencoding works for bytes, not strings

This commit is contained in:
Victor Shyba 2018-07-17 23:00:34 -03:00 committed by Jack Robison
parent c312d1b3a6
commit 19211d4417
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
3 changed files with 33 additions and 31 deletions

View file

@ -54,23 +54,24 @@ class Bencode(Encoding):
@rtype: str @rtype: str
""" """
if isinstance(data, (int, long)): if isinstance(data, (int, long)):
return 'i%de' % data return b'i%de' % data
elif isinstance(data, str): elif isinstance(data, str):
return '%d:%s' % (len(data), data) return b'%d:%s' % (len(data), data.encode())
elif isinstance(data, bytes):
return b'%d:%s' % (len(data), data)
elif isinstance(data, (list, tuple)): elif isinstance(data, (list, tuple)):
encodedListItems = '' encodedListItems = b''
for item in data: for item in data:
encodedListItems += self.encode(item) encodedListItems += self.encode(item)
return 'l%se' % encodedListItems return b'l%se' % encodedListItems
elif isinstance(data, dict): elif isinstance(data, dict):
encodedDictItems = '' encodedDictItems = b''
keys = data.keys() keys = data.keys()
for key in sorted(keys): for key in sorted(keys):
encodedDictItems += self.encode(key) # TODO: keys should always be bytestrings encodedDictItems += self.encode(key) # TODO: keys should always be bytestrings
encodedDictItems += self.encode(data[key]) encodedDictItems += self.encode(data[key])
return 'd%se' % encodedDictItems return b'd%se' % encodedDictItems
else: else:
print(data)
raise TypeError("Cannot bencode '%s' object" % type(data)) raise TypeError("Cannot bencode '%s' object" % type(data))
def decode(self, data): def decode(self, data):
@ -85,6 +86,7 @@ class Bencode(Encoding):
@return: The decoded data, as a native Python type @return: The decoded data, as a native Python type
@rtype: int, list, dict or str @rtype: int, list, dict or str
""" """
assert type(data) == bytes
if len(data) == 0: if len(data) == 0:
raise DecodeError('Cannot decode empty string') raise DecodeError('Cannot decode empty string')
try: try:
@ -98,34 +100,34 @@ class Bencode(Encoding):
Do not call this; use C{decode()} instead Do not call this; use C{decode()} instead
""" """
if data[startIndex] == 'i': if data[startIndex] == ord('i'):
endPos = data[startIndex:].find('e') + startIndex endPos = data[startIndex:].find(b'e') + startIndex
return int(data[startIndex + 1:endPos]), endPos + 1 return int(data[startIndex + 1:endPos]), endPos + 1
elif data[startIndex] == 'l': elif data[startIndex] == ord('l'):
startIndex += 1 startIndex += 1
decodedList = [] decodedList = []
while data[startIndex] != 'e': while data[startIndex] != ord('e'):
listData, startIndex = Bencode._decodeRecursive(data, startIndex) listData, startIndex = Bencode._decodeRecursive(data, startIndex)
decodedList.append(listData) decodedList.append(listData)
return decodedList, startIndex + 1 return decodedList, startIndex + 1
elif data[startIndex] == 'd': elif data[startIndex] == ord('d'):
startIndex += 1 startIndex += 1
decodedDict = {} decodedDict = {}
while data[startIndex] != 'e': while data[startIndex] != ord('e'):
key, startIndex = Bencode._decodeRecursive(data, startIndex) key, startIndex = Bencode._decodeRecursive(data, startIndex)
value, startIndex = Bencode._decodeRecursive(data, startIndex) value, startIndex = Bencode._decodeRecursive(data, startIndex)
decodedDict[key] = value decodedDict[key] = value
return decodedDict, startIndex return decodedDict, startIndex
elif data[startIndex] == 'f': elif data[startIndex] == ord('f'):
# This (float data type) is a non-standard extension to the original Bencode algorithm # This (float data type) is a non-standard extension to the original Bencode algorithm
endPos = data[startIndex:].find('e') + startIndex endPos = data[startIndex:].find(ord('e')) + startIndex
return float(data[startIndex + 1:endPos]), endPos + 1 return float(data[startIndex + 1:endPos]), endPos + 1
elif data[startIndex] == 'n': elif data[startIndex] == ord('n'):
# This (None/NULL data type) is a non-standard extension # This (None/NULL data type) is a non-standard extension
# to the original Bencode algorithm # to the original Bencode algorithm
return None, startIndex + 1 return None, startIndex + 1
else: else:
splitPos = data[startIndex:].find(':') + startIndex splitPos = data[startIndex:].find(ord(':')) + startIndex
try: try:
length = int(data[startIndex:splitPos]) length = int(data[startIndex:splitPos])
except ValueError: except ValueError:

View file

@ -155,10 +155,10 @@ class Node(MockKademliaHelper):
self.peer_finder = peer_finder or DHTPeerFinder(self, self.peer_manager) self.peer_finder = peer_finder or DHTPeerFinder(self, self.peer_manager)
self._join_deferred = None self._join_deferred = None
def __del__(self): #def __del__(self):
log.warning("unclean shutdown of the dht node") # log.warning("unclean shutdown of the dht node")
if hasattr(self, "_listeningPort") and self._listeningPort is not None: # if hasattr(self, "_listeningPort") and self._listeningPort is not None:
self._listeningPort.stopListening() # self._listeningPort.stopListening()
@defer.inlineCallbacks @defer.inlineCallbacks
def stop(self): def stop(self):
@ -203,7 +203,7 @@ class Node(MockKademliaHelper):
if not known_node_resolution: if not known_node_resolution:
known_node_resolution = yield _resolve_seeds() known_node_resolution = yield _resolve_seeds()
# we are one of the seed nodes, don't add ourselves # we are one of the seed nodes, don't add ourselves
if (self.externalIP, self.port) in known_node_resolution.itervalues(): if (self.externalIP, self.port) in known_node_resolution.values():
del known_node_resolution[(self.externalIP, self.port)] del known_node_resolution[(self.externalIP, self.port)]
known_node_addresses.remove((self.externalIP, self.port)) known_node_addresses.remove((self.externalIP, self.port))
@ -216,7 +216,7 @@ class Node(MockKademliaHelper):
def _initialize_routing(): def _initialize_routing():
bootstrap_contacts = [] bootstrap_contacts = []
contact_addresses = {(c.address, c.port): c for c in self.contacts} contact_addresses = {(c.address, c.port): c for c in self.contacts}
for (host, port), ip_address in known_node_resolution.iteritems(): for (host, port), ip_address in known_node_resolution.items():
if (host, port) not in contact_addresses: if (host, port) not in contact_addresses:
# Create temporary contact information for the list of addresses of known nodes # Create temporary contact information for the list of addresses of known nodes
# The contact node id will be set with the responding node id when we initialize it to None # The contact node id will be set with the responding node id when we initialize it to None

View file

@ -13,17 +13,17 @@ class BencodeTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.encoding = lbrynet.dht.encoding.Bencode() self.encoding = lbrynet.dht.encoding.Bencode()
# Thanks goes to wikipedia for the initial test cases ;-) # Thanks goes to wikipedia for the initial test cases ;-)
self.cases = ((42, 'i42e'), self.cases = ((42, b'i42e'),
('spam', '4:spam'), (b'spam', b'4:spam'),
(['spam', 42], 'l4:spami42ee'), ([b'spam', 42], b'l4:spami42ee'),
({'foo': 42, 'bar': 'spam'}, 'd3:bar4:spam3:fooi42ee'), ({b'foo': 42, b'bar': b'spam'}, b'd3:bar4:spam3:fooi42ee'),
# ...and now the "real life" tests # ...and now the "real life" tests
([['abc', '127.0.0.1', 1919], ['def', '127.0.0.1', 1921]], ([[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]],
'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee')) b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'))
# The following test cases are "bad"; i.e. sending rubbish into the decoder to test # The following test cases are "bad"; i.e. sending rubbish into the decoder to test
# what exceptions get thrown # what exceptions get thrown
self.badDecoderCases = ('abcdefghijklmnopqrstuvwxyz', self.badDecoderCases = (b'abcdefghijklmnopqrstuvwxyz',
'') b'')
def testEncoder(self): def testEncoder(self):
""" Tests the bencode encoder """ """ Tests the bencode encoder """