working functional test_contact_rpc + more string bans

This commit is contained in:
Victor Shyba 2018-07-20 16:45:58 -03:00 committed by Jack Robison
parent 1ee682f06f
commit e1314a9d1e
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
8 changed files with 61 additions and 62 deletions

View file

@ -59,7 +59,7 @@ class _Contact(object):
def log_id(self, short=True):
if not self.id:
return "not initialized"
id_hex = self.id.encode('hex')
id_hex = hexlify(self.id)
return id_hex if not short else id_hex[:8]
@property
@ -162,7 +162,7 @@ class _Contact(object):
raise AttributeError("unknown command: %s" % name)
def _sendRPC(*args, **kwargs):
return self._networkProtocol.sendRPC(self, name, args)
return self._networkProtocol.sendRPC(self, name.encode(), args)
return _sendRPC

View file

@ -58,8 +58,6 @@ class Bencode(Encoding):
"""
if isinstance(data, (int, long)):
return b'i%de' % data
elif isinstance(data, str):
return b'%d:%s' % (len(data), data.encode())
elif isinstance(data, bytes):
return b'%d:%s' % (len(data), data)
elif isinstance(data, (list, tuple)):

View file

@ -140,8 +140,7 @@ class KBucket(object):
if not.
@rtype: bool
"""
if isinstance(key, str):
key = long(hexlify(key.encode()), 16)
assert type(key) in [long, bytes], "{} is {}".format(key, type(key)) # fixme: _maybe_ remove this after porting
if isinstance(key, bytes):
key = long(hexlify(key), 16)
return self.rangeMin <= key < self.rangeMax

View file

@ -48,6 +48,5 @@ class ErrorMessage(ResponseMessage):
def __init__(self, rpcID, nodeID, exceptionType, errorMessage):
ResponseMessage.__init__(self, rpcID, nodeID, errorMessage)
if isinstance(exceptionType, type):
self.exceptionType = '%s.%s' % (exceptionType.__module__, exceptionType.__name__)
else:
self.exceptionType = exceptionType
exceptionType = ('%s.%s' % (exceptionType.__module__, exceptionType.__name__)).encode()
self.exceptionType = exceptionType

View file

@ -10,6 +10,8 @@ import binascii
import hashlib
import struct
import logging
from functools import reduce
from twisted.internet import defer, error, task
from lbrynet.core.utils import generate_id, DeferredDict
@ -493,7 +495,7 @@ class Node(MockKademliaHelper):
@rtype: str
"""
return 'pong'
return b'pong'
@rpcmethod
def store(self, rpc_contact, blob_hash, token, port, originalPublisherID, age):
@ -530,13 +532,13 @@ class Node(MockKademliaHelper):
if 0 <= port <= 65536:
compact_port = struct.pack('>H', port)
else:
raise TypeError('Invalid port')
raise TypeError('Invalid port: {}'.format(port))
compact_address = compact_ip + compact_port + rpc_contact.id
now = int(self.clock.seconds())
originallyPublished = now - age
self._dataStore.addPeerToBlob(rpc_contact, blob_hash, compact_address, now, originallyPublished,
originalPublisherID)
return 'OK'
return b'OK'
@rpcmethod
def findNode(self, rpc_contact, key):
@ -578,11 +580,11 @@ class Node(MockKademliaHelper):
raise ValueError("invalid blob hash length: %i" % len(key))
response = {
'token': self.make_token(rpc_contact.compact_ip()),
b'token': self.make_token(rpc_contact.compact_ip()),
}
if self._protocol._protocolVersion:
response['protocolVersion'] = self._protocol._protocolVersion
response[b'protocolVersion'] = self._protocol._protocolVersion
# get peers we have stored for this blob
has_other_peers = self._dataStore.hasPeersForBlob(key)
@ -592,17 +594,15 @@ class Node(MockKademliaHelper):
# if we don't have k storing peers to return and we have this hash locally, include our contact information
if len(peers) < constants.k and key in self._dataStore.completed_blobs:
compact_ip = str(
reduce(lambda buff, x: buff + bytearray([int(x)]), self.externalIP.split('.'), bytearray())
)
compact_port = str(struct.pack('>H', self.peerPort))
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), self.externalIP.split('.'), bytearray())
compact_port = struct.pack('>H', self.peerPort)
compact_address = compact_ip + compact_port + self.node_id
peers.append(compact_address)
if peers:
response[key] = peers
else:
response['contacts'] = self.findNode(rpc_contact, key)
response[b'contacts'] = self.findNode(rpc_contact, key)
return response
def _generateID(self):

View file

@ -1,6 +1,7 @@
import logging
import socket
import errno
from binascii import hexlify
from collections import deque
from twisted.internet import protocol, defer
@ -108,12 +109,12 @@ class KademliaProtocol(protocol.DatagramProtocol):
self.started_listening_time = 0
def _migrate_incoming_rpc_args(self, contact, method, *args):
if method == 'store' and contact.protocolVersion == 0:
if method == b'store' and contact.protocolVersion == 0:
if isinstance(args[1], dict):
blob_hash = args[0]
token = args[1].pop('token', None)
port = args[1].pop('port', -1)
originalPublisherID = args[1].pop('lbryid', None)
token = args[1].pop(b'token', None)
port = args[1].pop(b'port', -1)
originalPublisherID = args[1].pop(b'lbryid', None)
age = 0
return (blob_hash, token, port, originalPublisherID, age), {}
return args, {}
@ -124,16 +125,16 @@ class KademliaProtocol(protocol.DatagramProtocol):
protocol version keyword argument to calls to contacts who will accept it
"""
if contact.protocolVersion == 0:
if method == 'store':
if method == b'store':
blob_hash, token, port, originalPublisherID, age = args
args = (blob_hash, {'token': token, 'port': port, 'lbryid': originalPublisherID}, originalPublisherID,
args = (blob_hash, {b'token': token, b'port': port, b'lbryid': originalPublisherID}, originalPublisherID,
False)
return args
return args
if args and isinstance(args[-1], dict):
args[-1]['protocolVersion'] = self._protocolVersion
args[-1][b'protocolVersion'] = self._protocolVersion
return args
return args + ({'protocolVersion': self._protocolVersion},)
return args + ({b'protocolVersion': self._protocolVersion},)
def sendRPC(self, contact, method, args):
"""
@ -162,7 +163,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
if args:
log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method,
args[0].encode('hex'), contact.address, contact.port)
hexlify(args[0]), contact.address, contact.port)
else:
log.debug("%s:%i SEND CALL %s TO %s:%i", self._node.externalIP, self._node.port, method,
contact.address, contact.port)
@ -179,11 +180,11 @@ class KademliaProtocol(protocol.DatagramProtocol):
def _update_contact(result): # refresh the contact in the routing table
contact.update_last_replied()
if method == 'findValue':
if 'protocolVersion' not in result:
if method == b'findValue':
if b'protocolVersion' not in result:
contact.update_protocol_version(0)
else:
contact.update_protocol_version(result.pop('protocolVersion'))
contact.update_protocol_version(result.pop(b'protocolVersion'))
d = self._node.addContact(contact)
d.addCallback(lambda _: result)
return d
@ -214,8 +215,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
@note: This is automatically called by Twisted when the protocol
receives a UDP datagram
"""
if datagram[0] == '\x00' and datagram[25] == '\x00':
if datagram[0] == b'\x00' and datagram[25] == b'\x00':
totalPackets = (ord(datagram[1]) << 8) | ord(datagram[2])
msgID = datagram[5:25]
seqNumber = (ord(datagram[3]) << 8) | ord(datagram[4])
@ -307,7 +307,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
# the node id of the node we sent a message to (these messages are treated as an error)
if remoteContact.id and remoteContact.id != message.nodeID: # sent_to_id will be None for bootstrap
log.debug("mismatch: (%s) %s:%i (%s vs %s)", method, remoteContact.address, remoteContact.port,
remoteContact.log_id(False), message.nodeID.encode('hex'))
remoteContact.log_id(False), hexlify(message.nodeID))
df.errback(TimeoutError(remoteContact.id))
return
elif not remoteContact.id:
@ -396,6 +396,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
def _sendError(self, contact, rpcID, exceptionType, exceptionMessage):
""" Send an RPC error message to the specified contact
"""
exceptionType, exceptionMessage = exceptionType.encode(), exceptionMessage.encode()
msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage)
msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = self._encoder.encode(msgPrimitive)
@ -416,7 +417,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
df.addErrback(handleError)
# Execute the RPC
func = getattr(self._node, method, None)
func = getattr(self._node, method.decode(), None)
if callable(func) and hasattr(func, "rpcmethod"):
# Call the exposed Node method and return the result to the deferred callback chain
# if args:
@ -425,14 +426,14 @@ class KademliaProtocol(protocol.DatagramProtocol):
# else:
log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method,
senderContact.address, senderContact.port)
if args and isinstance(args[-1], dict) and 'protocolVersion' in args[-1]: # args don't need reformatting
senderContact.update_protocol_version(int(args[-1].pop('protocolVersion')))
if args and isinstance(args[-1], dict) and b'protocolVersion' in args[-1]: # args don't need reformatting
senderContact.update_protocol_version(int(args[-1].pop(b'protocolVersion')))
a, kw = tuple(args[:-1]), args[-1]
else:
senderContact.update_protocol_version(0)
a, kw = self._migrate_incoming_rpc_args(senderContact, method, *args)
try:
if method != 'ping':
if method != b'ping':
result = func(senderContact, *a)
else:
result = func()

View file

@ -1,3 +1,5 @@
from binascii import unhexlify
import time
from twisted.trial import unittest
import logging
@ -19,12 +21,12 @@ class KademliaProtocolTest(unittest.TestCase):
def setUp(self):
self._reactor = Clock()
self.node = Node(node_id='1' * 48, udpPort=self.udpPort, externalIP="127.0.0.1", listenUDP=listenUDP,
self.node = Node(node_id=b'1' * 48, udpPort=self.udpPort, externalIP="127.0.0.1", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
self.remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
self.remote_node = Node(node_id=b'2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
self.remote_contact = self.node.contact_manager.make_contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.us_from_them = self.remote_node.contact_manager.make_contact('1' * 48, '127.0.0.1', 9182,
self.remote_contact = self.node.contact_manager.make_contact(b'2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.us_from_them = self.remote_node.contact_manager.make_contact(b'1' * 48, '127.0.0.1', 9182,
self.remote_node._protocol)
self.node.start_listening()
self.remote_node.start_listening()
@ -105,7 +107,7 @@ class KademliaProtocolTest(unittest.TestCase):
self.error = 'An RPC error occurred: %s' % f.getErrorMessage()
def handleResult(result):
expectedResult = 'pong'
expectedResult = b'pong'
if result != expectedResult:
self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' \
% (expectedResult, result)
@ -142,7 +144,7 @@ class KademliaProtocolTest(unittest.TestCase):
self.error = 'An RPC error occurred: %s' % f.getErrorMessage()
def handleResult(result):
expectedResult = 'pong'
expectedResult = b'pong'
if result != expectedResult:
self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' % \
(expectedResult, result)
@ -163,12 +165,12 @@ class KademliaProtocolTest(unittest.TestCase):
@defer.inlineCallbacks
def testDetectProtocolVersion(self):
original_findvalue = self.remote_node.findValue
fake_blob = str("AB" * 48).decode('hex')
fake_blob = unhexlify("AB" * 48)
@rpcmethod
def findValue(contact, key):
result = original_findvalue(contact, key)
result.pop('protocolVersion')
result.pop(b'protocolVersion')
return result
self.remote_node.findValue = findValue
@ -205,35 +207,35 @@ class KademliaProtocolTest(unittest.TestCase):
@rpcmethod
def findValue(contact, key):
result = original_findvalue(contact, key)
if 'protocolVersion' in result:
result.pop('protocolVersion')
if b'protocolVersion' in result:
result.pop(b'protocolVersion')
return result
@rpcmethod
def store(contact, key, value, originalPublisherID=None, self_store=False, **kwargs):
self.assertTrue(len(key) == 48)
self.assertSetEqual(set(value.keys()), {'token', 'lbryid', 'port'})
self.assertSetEqual(set(value.keys()), {b'token', b'lbryid', b'port'})
self.assertFalse(self_store)
self.assertDictEqual(kwargs, {})
return original_store( # pylint: disable=too-many-function-args
contact, key, value['token'], value['port'], originalPublisherID, 0
contact, key, value[b'token'], value[b'port'], originalPublisherID, 0
)
self.remote_node.findValue = findValue
self.remote_node.store = store
fake_blob = str("AB" * 48).decode('hex')
fake_blob = unhexlify("AB" * 48)
d = self.remote_contact.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEqual(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response)
token = find_value_response['token']
self.assertTrue(b'protocolVersion' not in find_value_response)
token = find_value_response[b'token']
d = self.remote_contact.store(fake_blob, token, 3333, self.node.node_id, 0)
self._reactor.advance(3)
response = yield d
self.assertEqual(response, "OK")
self.assertEqual(response, b'OK')
self.assertEqual(self.remote_contact.protocolVersion, 0)
self.assertTrue(self.remote_node._dataStore.hasPeersForBlob(fake_blob))
self.assertEqual(len(self.remote_node._dataStore.getStoringContacts()), 1)
@ -245,24 +247,24 @@ class KademliaProtocolTest(unittest.TestCase):
self.remote_node._protocol._migrate_outgoing_rpc_args = _dont_migrate
us_from_them = self.remote_node.contact_manager.make_contact('1' * 48, '127.0.0.1', self.udpPort,
us_from_them = self.remote_node.contact_manager.make_contact(b'1' * 48, '127.0.0.1', self.udpPort,
self.remote_node._protocol)
fake_blob = str("AB" * 48).decode('hex')
fake_blob = unhexlify("AB" * 48)
d = us_from_them.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEqual(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response)
token = find_value_response['token']
self.assertTrue(b'protocolVersion' not in find_value_response)
token = find_value_response[b'token']
us_from_them.update_protocol_version(0)
d = self.remote_node._protocol.sendRPC(
us_from_them, "store", (fake_blob, {'lbryid': self.remote_node.node_id, 'token': token, 'port': 3333})
us_from_them, b"store", (fake_blob, {b'lbryid': self.remote_node.node_id, b'token': token, b'port': 3333})
)
self._reactor.advance(3)
response = yield d
self.assertEqual(response, "OK")
self.assertEqual(response, b'OK')
self.assertEqual(self.remote_contact.protocolVersion, 0)
self.assertTrue(self.node._dataStore.hasPeersForBlob(fake_blob))
self.assertEqual(len(self.node._dataStore.getStoringContacts()), 1)

View file

@ -32,7 +32,7 @@ class TreeRoutingTableTest(unittest.TestCase):
""" Test to see if distance method returns correct result"""
# testList holds a couple 3-tuple (variable1, variable2, result)
basicTestList = [(bytes([170] * 48), bytes([85] * 48), long(hexlify(bytes([255] * 48)), 16))]
basicTestList = [(bytes(b'\xaa' * 48), bytes(b'\x55' * 48), long(hexlify(bytes(b'\xff' * 48)), 16))]
for test in basicTestList:
result = Distance(test[0])(test[1])
@ -139,7 +139,7 @@ class TreeRoutingTableTest(unittest.TestCase):
Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact
"""
self.routingTable._parentNodeID = bytes(48 * [255])
self.routingTable._parentNodeID = bytes(48 * b'\xff')
node_ids = [
b"100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",