add protocol version to the dht and migrate old arg format for store

This commit is contained in:
Jack Robison 2018-05-31 10:50:11 -04:00
parent 7d21cc5822
commit 9a63db4ec6
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
7 changed files with 226 additions and 82 deletions

View file

@ -55,3 +55,5 @@ udpDatagramMaxSize = 8192 # 8 KB
key_bits = 384
rpc_id_length = 20
protocolVersion = 1

View file

@ -34,6 +34,7 @@ class _Contact(object):
self.getTime = self._contactManager._get_time
self.lastReplied = None
self.lastRequested = None
self.protocolVersion = constants.protocolVersion
@property
def lastInteracted(self):
@ -120,6 +121,9 @@ class _Contact(object):
failures.append(self.getTime())
self._contactManager._rpc_failures[(self.address, self.port)] = failures
def update_protocol_version(self, version):
self.protocolVersion = version
def __str__(self):
return '<%s.%s object; IP address: %s, UDP port: %d>' % (
self.__module__, self.__class__.__name__, self.address, self.port)
@ -143,7 +147,7 @@ class _Contact(object):
raise AttributeError("unknown command: %s" % name)
def _sendRPC(*args, **kwargs):
return self._networkProtocol.sendRPC(self, name, args, **kwargs)
return self._networkProtocol.sendRPC(self, name, args)
return _sendRPC

View file

@ -53,10 +53,10 @@ class _IterativeFind(object):
def is_find_value_request(self):
return self.rpc == "findValue"
def is_closer(self, responseMsg):
def is_closer(self, contact):
if not self.closest_node:
return True
return self.distance.is_closer(responseMsg.nodeID, self.closest_node.id)
return self.distance.is_closer(contact.id, self.closest_node.id)
def getContactTriples(self, result):
if self.is_find_value_request:
@ -73,16 +73,15 @@ class _IterativeFind(object):
contact_list.sort(key=lambda c: self.distance(c.id))
@defer.inlineCallbacks
def extendShortlist(self, contact, responseTuple):
def extendShortlist(self, contact, result):
# The "raw response" tuple contains the response message and the originating address info
responseMsg = responseTuple[0]
originAddress = responseTuple[1] # tuple: (ip address, udp port)
originAddress = (contact.address, contact.port)
if self.finished_deferred.called:
defer.returnValue(responseMsg.nodeID)
defer.returnValue(contact.id)
if self.node.contact_manager.is_ignored(originAddress):
raise ValueError("contact is ignored")
if responseMsg.nodeID == self.node.node_id:
defer.returnValue(responseMsg.nodeID)
if contact.id == self.node.node_id:
defer.returnValue(contact.id)
yield self._lock.acquire()
@ -92,7 +91,6 @@ class _IterativeFind(object):
self.shortlist.append(contact)
# Now grow extend the (unverified) shortlist with the returned contacts
result = responseMsg.response
# TODO: some validation on the result (for guarding against attacks)
# If we are looking for a value, first see if this result is the value
# we are looking for before treating it as a list of contact triples
@ -107,7 +105,7 @@ class _IterativeFind(object):
# - mark it as the closest "empty" node, if it is
# TODO: store to this peer after finding the value as per the kademlia spec
if 'closestNodeNoValue' in self.find_value_result:
if self.is_closer(responseMsg):
if self.is_closer(contact):
self.find_value_result['closestNodeNoValue'] = contact
else:
self.find_value_result['closestNodeNoValue'] = contact
@ -130,14 +128,14 @@ class _IterativeFind(object):
self.sortByDistance(self.active_contacts)
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
defer.returnValue(responseMsg.nodeID)
defer.returnValue(contact.id)
@defer.inlineCallbacks
def probeContact(self, contact):
fn = getattr(contact, self.rpc)
try:
response_tuple = yield fn(self.key, rawResponse=True)
result = yield self.extendShortlist(contact, response_tuple)
response = yield fn(self.key)
result = yield self.extendShortlist(contact, response)
defer.returnValue(result)
except (TimeoutError, defer.CancelledError, ValueError, IndexError):
defer.returnValue(contact.id)

View file

@ -315,7 +315,7 @@ class Node(MockKademliaHelper):
self_contact = self.contact_manager.make_contact(self.node_id, self.externalIP,
self.port, self._protocol)
token = self.make_token(self_contact.compact_ip())
yield self.store(self_contact, blob_hash, token, self.peerPort)
yield self.store(self_contact, blob_hash, token, self.peerPort, self.node_id, 0)
elif self.externalIP is not None:
pass
else:
@ -327,15 +327,15 @@ class Node(MockKademliaHelper):
def announce_to_contact(contact):
known_nodes[contact.id] = contact
try:
responseMsg, originAddress = yield contact.findValue(blob_hash, rawResponse=True)
res = yield contact.store(blob_hash, responseMsg.response['token'], self.peerPort)
response = yield contact.findValue(blob_hash)
res = yield contact.store(blob_hash, response['token'], self.peerPort, self.node_id, 0)
if res != "OK":
raise ValueError(res)
contacted.append(contact)
log.debug("Stored %s to %s (%s)", blob_hash.encode('hex'), contact.id.encode('hex'), originAddress[0])
log.debug("Stored %s to %s (%s)", binascii.hexlify(blob_hash), contact.log_id(), contact.address)
except protocol.TimeoutError:
log.debug("Timeout while storing blob_hash %s at %s",
blob_hash.encode('hex')[:16], contact.log_id())
binascii.hexlify(blob_hash), contact.log_id())
except ValueError as err:
log.error("Unexpected response: %s" % err.message)
except Exception as err:
@ -348,7 +348,7 @@ class Node(MockKademliaHelper):
yield defer.DeferredList(dl)
log.debug("Stored %s to %i of %i attempted peers", blob_hash.encode('hex')[:16],
log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash),
len(contacted), len(contacts))
contacted_node_ids = [c.id.encode('hex') for c in contacted]
@ -506,7 +506,7 @@ class Node(MockKademliaHelper):
return 'pong'
@rpcmethod
def store(self, rpc_contact, blob_hash, token, port, originalPublisherID=None, age=0):
def store(self, rpc_contact, blob_hash, token, port, originalPublisherID, age):
""" Store the received data in this node's local datastore
@param blob_hash: The hash of the data
@ -589,6 +589,9 @@ class Node(MockKademliaHelper):
'token': self.make_token(rpc_contact.compact_ip()),
}
if self._protocol._protocolVersion:
response['protocolVersion'] = self._protocol._protocolVersion
if self._dataStore.hasPeersForBlob(key):
response[key] = self._dataStore.getPeersForBlob(key)
else:

View file

@ -102,8 +102,37 @@ class KademliaProtocol(protocol.DatagramProtocol):
self._partialMessagesProgress = {}
self._listening = defer.Deferred(None)
self._ping_queue = PingQueue(self._node)
self._protocolVersion = constants.protocolVersion
def sendRPC(self, contact, method, args, rawResponse=False):
def _migrate_incoming_rpc_args(self, contact, method, *args):
if method == 'store' and contact.protocolVersion == 0:
if isinstance(args[1], dict):
blob_hash = args[0]
token = args[1].pop('token', None)
port = args[1].pop('port', -1)
originalPublisherID = args[1].pop('lbryid', None)
age = 0
return (blob_hash, token, port, originalPublisherID, age), {}
return args, {}
def _migrate_outgoing_rpc_args(self, contact, method, *args):
"""
This will reformat protocol version 0 arguments for the store function and will add the
protocol version keyword argument to calls to contacts who will accept it
"""
if contact.protocolVersion == 0:
if method == 'store':
blob_hash, token, port, originalPublisherID, age = args
args = (blob_hash, {'token': token, 'port': port, 'lbryid': originalPublisherID}, originalPublisherID,
False)
return args
return args
if args and isinstance(args[-1], dict):
args[-1]['protocolVersion'] = self._protocolVersion
return args
return args + ({'protocolVersion': self._protocolVersion},)
def sendRPC(self, contact, method, args):
"""
Sends an RPC to the specified contact
@ -114,14 +143,6 @@ class KademliaProtocol(protocol.DatagramProtocol):
@param args: A list of (non-keyword) arguments to pass to the remote
method, in the correct order
@type args: tuple
@param rawResponse: If this is set to C{True}, the caller of this RPC
will receive a tuple containing the actual response
message object and the originating address tuple as
a result; in other words, it will not be
interpreted by this class. Unless something special
needs to be done with the metadata associated with
the message, this should remain C{False}.
@type rawResponse: bool
@return: This immediately returns a deferred object, which will return
the result of the RPC call, or raise the relevant exception
@ -131,7 +152,8 @@ class KademliaProtocol(protocol.DatagramProtocol):
C{ErrorMessage}).
@rtype: twisted.internet.defer.Deferred
"""
msg = msgtypes.RequestMessage(self._node.node_id, method, args)
msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method,
*args))
msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = self._encoder.encode(msgPrimitive)
@ -143,8 +165,6 @@ class KademliaProtocol(protocol.DatagramProtocol):
contact.address, contact.port)
df = defer.Deferred()
if rawResponse:
df._rpcRawResponse = True
def _remove_contact(failure): # remove the contact from the routing table and track the failure
try:
@ -156,6 +176,11 @@ class KademliaProtocol(protocol.DatagramProtocol):
def _update_contact(result): # refresh the contact in the routing table
contact.update_last_replied()
if method == 'findValue':
if 'protocolVersion' not in result:
contact.update_protocol_version(0)
else:
contact.update_protocol_version(result.pop('protocolVersion'))
d = self._node.addContact(contact)
d.addCallback(lambda _: result)
return d
@ -284,14 +309,8 @@ class KademliaProtocol(protocol.DatagramProtocol):
elif not remoteContact.id:
remoteContact.set_id(message.nodeID)
if hasattr(df, '_rpcRawResponse'):
# The RPC requested that the raw response message
# and originating address be returned; do not
# interpret it
df.callback((message, address))
else:
# We got a result from the RPC
df.callback(message.response)
# We got a result from the RPC
df.callback(message.response)
else:
# If the original message isn't found, it must have timed out
# TODO: we should probably do something with this...
@ -395,20 +414,25 @@ class KademliaProtocol(protocol.DatagramProtocol):
func = getattr(self._node, method, None)
if callable(func) and hasattr(func, "rpcmethod"):
# Call the exposed Node method and return the result to the deferred callback chain
if args:
log.debug("%s:%i RECV CALL %s(%s) %s:%i", self._node.externalIP, self._node.port, method,
args[0].encode('hex'), senderContact.address, senderContact.port)
else:
log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method,
# if args:
# log.debug("%s:%i RECV CALL %s(%s) %s:%i", self._node.externalIP, self._node.port, method,
# args[0].encode('hex'), senderContact.address, senderContact.port)
# else:
log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method,
senderContact.address, senderContact.port)
if args and isinstance(args[-1], dict) and 'protocolVersion' in args[-1]: # args don't need reformatting
senderContact.update_protocol_version(int(args[-1].pop('protocolVersion')))
a, kw = tuple(args[:-1]), args[-1]
else:
senderContact.update_protocol_version(0)
a, kw = self._migrate_incoming_rpc_args(senderContact, method, *args)
try:
if method != 'ping':
result = func(senderContact, *args)
result = func(senderContact, *a)
else:
result = func()
except Exception, e:
log.exception("error handling request for %s:%i %s", senderContact.address,
senderContact.port, method)
log.exception("error handling request for %s:%i %s", senderContact.address, senderContact.port, method)
df.errback(e)
else:
df.callback(result)

View file

@ -1,12 +1,10 @@
import time
import unittest
from twisted.trial import unittest
import logging
from twisted.internet.task import Clock
from twisted.internet import defer
import lbrynet.dht.protocol
import lbrynet.dht.contact
import lbrynet.dht.constants
import lbrynet.dht.msgtypes
from lbrynet.dht.error import TimeoutError
from lbrynet.dht.node import Node, rpcmethod
from mock_transport import listenUDP, resolve
@ -23,8 +21,18 @@ class KademliaProtocolTest(unittest.TestCase):
self._reactor = Clock()
self.node = Node(node_id='1' * 48, udpPort=self.udpPort, externalIP="127.0.0.1", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
self.remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
self.remote_contact = self.node.contact_manager.make_contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.us_from_them = self.remote_node.contact_manager.make_contact('1' * 48, '127.0.0.1', 9182,
self.remote_node._protocol)
self.node.start_listening()
self.remote_node.start_listening()
@defer.inlineCallbacks
def tearDown(self):
yield self.node.stop()
yield self.remote_node.stop()
del self._reactor
@defer.inlineCallbacks
@ -37,15 +45,12 @@ class KademliaProtocolTest(unittest.TestCase):
result = yield d
self.assertTrue(result)
@defer.inlineCallbacks
def testRPCTimeout(self):
""" Tests if a RPC message sent to a dead remote node times out correctly """
dead_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
dead_node.start_listening()
dead_node.stop()
yield self.remote_node.stop()
self._reactor.pump([1 for _ in range(10)])
dead_contact = self.node.contact_manager.make_contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(dead_contact)
self.node.addContact(self.remote_contact)
@rpcmethod
def fake_ping(*args, **kwargs):
@ -60,12 +65,12 @@ class KademliaProtocolTest(unittest.TestCase):
self.node.ping = fake_ping
# Make sure the contact was added
self.failIf(dead_contact not in self.node.contacts,
self.failIf(self.remote_contact not in self.node.contacts,
'Contact not added to fake node (error in test code)')
self.node.start_listening()
# Run the PING RPC (which should raise a timeout error)
df = self.node._protocol.sendRPC(dead_contact, 'ping', {})
df = self.remote_contact.ping()
def check_timeout(err):
self.assertEqual(err.type, TimeoutError)
@ -79,7 +84,7 @@ class KademliaProtocolTest(unittest.TestCase):
# See if the contact was removed due to the timeout
def check_removed_contact():
self.failIf(dead_contact in self.node.contacts,
self.failIf(self.remote_contact in self.node.contacts,
'Contact was not removed after RPC timeout; check exception types.')
df.addCallback(lambda _: reset_values())
@ -88,14 +93,11 @@ class KademliaProtocolTest(unittest.TestCase):
df.addCallback(lambda _: check_removed_contact())
self._reactor.pump([1 for _ in range(20)])
@defer.inlineCallbacks
def testRPCRequest(self):
""" Tests if a valid RPC request is executed and responded to correctly """
remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
remote_node.start_listening()
remoteContact = remote_node.contact_manager.make_contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(remoteContact)
yield self.node.addContact(self.remote_contact)
self.error = None
@ -108,15 +110,13 @@ class KademliaProtocolTest(unittest.TestCase):
self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' \
% (expectedResult, result)
# Publish the "local" node on the network
self.node.start_listening()
# Simulate the RPC
df = remoteContact.ping()
df = self.remote_contact.ping()
df.addCallback(handleResult)
df.addErrback(handleError)
for _ in range(10):
self._reactor.advance(1)
self._reactor.advance(2)
yield df
self.failIf(self.error, self.error)
# The list of sent RPC messages should be empty at this stage
@ -129,18 +129,13 @@ class KademliaProtocolTest(unittest.TestCase):
Verifies that a RPC request for an existing but unpublished
method is denied, and that the associated (remote) exception gets
raised locally """
remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
remote_contact = remote_node.contact_manager.make_contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.assertRaises(AttributeError, getattr, remote_contact, "not_a_rpc_function")
self.assertRaises(AttributeError, getattr, self.remote_contact, "not_a_rpc_function")
def testRPCRequestArgs(self):
""" Tests if an RPC requiring arguments is executed correctly """
remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
remote_node.start_listening()
remote_contact = remote_node.contact_manager.make_contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(remote_contact)
self.node.addContact(self.remote_contact)
self.error = None
def handleError(f):
@ -155,7 +150,7 @@ class KademliaProtocolTest(unittest.TestCase):
# Publish the "local" node on the network
self.node.start_listening()
# Simulate the RPC
df = remote_contact.ping()
df = self.remote_contact.ping()
df.addCallback(handleResult)
df.addErrback(handleError)
self._reactor.pump([1 for _ in range(10)])
@ -164,3 +159,121 @@ class KademliaProtocolTest(unittest.TestCase):
self.failUnlessEqual(len(self.node._protocol._sentMessages), 0,
'The protocol is still waiting for a RPC result, '
'but the transaction is already done!')
@defer.inlineCallbacks
def testDetectProtocolVersion(self):
original_findvalue = self.remote_node.findValue
fake_blob = str("AB" * 48).decode('hex')
@rpcmethod
def findValue(contact, key):
result = original_findvalue(contact, key)
result.pop('protocolVersion')
return result
self.assertEquals(self.remote_contact.protocolVersion, 1)
self.remote_node.findValue = findValue
d = self.remote_contact.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response)
self.remote_node.findValue = original_findvalue
d = self.remote_contact.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEquals(self.remote_contact.protocolVersion, 1)
self.assertTrue('protocolVersion' not in find_value_response)
self.remote_node.findValue = findValue
d = self.remote_contact.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response)
@defer.inlineCallbacks
def testStoreToPre_0_20_0_Node(self):
self.remote_node._protocol._protocolVersion = 0
def _dont_migrate(contact, method, *args):
return args, {}
self.remote_node._protocol._migrate_incoming_rpc_args = _dont_migrate
original_findvalue = self.remote_node.findValue
original_store = self.remote_node.store
@rpcmethod
def findValue(contact, key):
result = original_findvalue(contact, key)
if 'protocolVersion' in result:
result.pop('protocolVersion')
return result
@rpcmethod
def store(contact, key, value, originalPublisherID=None, self_store=False, **kwargs):
self.assertTrue(len(key) == 48)
self.assertSetEqual(set(value.keys()), {'token', 'lbryid', 'port'})
self.assertFalse(self_store)
self.assertDictEqual(kwargs, {})
return original_store( # pylint: disable=too-many-function-args
contact, key, value['token'], value['port'], originalPublisherID, 0
)
self.assertEquals(self.remote_contact.protocolVersion, 1)
self.remote_node.findValue = findValue
self.remote_node.store = store
fake_blob = str("AB" * 48).decode('hex')
d = self.remote_contact.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response)
token = find_value_response['token']
d = self.remote_contact.store(fake_blob, token, 3333, self.node.node_id, 0)
self._reactor.advance(3)
response = yield d
self.assertEquals(response, "OK")
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue(self.remote_node._dataStore.hasPeersForBlob(fake_blob))
self.assertEquals(len(self.remote_node._dataStore.getStoringContacts()), 1)
@defer.inlineCallbacks
def testStoreFromPre_0_20_0_Node(self):
self.remote_node._protocol._protocolVersion = 0
def _dont_migrate(contact, method, *args):
return args
self.remote_node._protocol._migrate_outgoing_rpc_args = _dont_migrate
us_from_them = self.remote_node.contact_manager.make_contact('1' * 48, '127.0.0.1', self.udpPort,
self.remote_node._protocol)
fake_blob = str("AB" * 48).decode('hex')
d = us_from_them.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response)
token = find_value_response['token']
us_from_them.update_protocol_version(0)
d = self.remote_node._protocol.sendRPC(
us_from_them, "store", (fake_blob, {'lbryid': self.remote_node.node_id, 'token': token, 'port': 3333})
)
self._reactor.advance(3)
response = yield d
self.assertEquals(response, "OK")
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue(self.node._dataStore.hasPeersForBlob(fake_blob))
self.assertEquals(len(self.node._dataStore.getStoringContacts()), 1)
self.assertIs(self.node._dataStore.getStoringContacts()[0], self.remote_contact)

View file

@ -62,7 +62,7 @@ class NodeDataTest(unittest.TestCase):
def testStore(self):
""" Tests if the node can store (and privately retrieve) some data """
for key, port in self.cases:
yield self.node.store(self.contact, key, self.token, port, self.contact.id)
yield self.node.store(self.contact, key, self.token, port, self.contact.id, 0)
for key, value in self.cases:
expected_result = self.contact.compact_ip() + str(struct.pack('>H', value)) + \
self.contact.id