diff --git a/lbry/dht/node.py b/lbry/dht/node.py
index 745789e89..74adbc99f 100644
--- a/lbry/dht/node.py
+++ b/lbry/dht/node.py
@@ -17,6 +17,16 @@ if typing.TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 
+class NodeState:
+    def __init__(self,
+                 routing_table_peers: typing.List[typing.Tuple[bytes, str, int, int]],
+                 datastore: typing.List[typing.Tuple[bytes, str, int, int, bytes]]):
+        # List[Tuple[node_id, address, udp_port, tcp_port]]
+        self.routing_table_peers = routing_table_peers
+        # List[Tuple[node_id, address, udp_port, tcp_port, blob_hash]]
+        self.datastore = datastore
+
+
 class Node:
     def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', node_id: bytes, udp_port: int,
                  internal_udp_port: int, peer_port: int, external_ip: str, rpc_timeout: float = constants.RPC_TIMEOUT,
@@ -31,6 +41,7 @@ class Node:
         self._join_task: asyncio.Task = None
         self._refresh_task: asyncio.Task = None
         self._storage = storage
+        self.started_listening = asyncio.Event()
 
     async def refresh_node(self, force_once=False):
         while True:
@@ -103,6 +114,7 @@ class Node:
         return stored_to
 
     def stop(self) -> None:
+        self.started_listening.clear()
         if self.joined.is_set():
             self.joined.clear()
         if self._join_task:
@@ -118,18 +130,32 @@ class Node:
             self.listening_port = None
             log.info("Stopped DHT node")
 
+    def get_state(self) -> NodeState:
+        return NodeState(
+            routing_table_peers=[(p.node_id, p.address, p.udp_port, p.tcp_port)
+                                 for p in self.protocol.routing_table.get_peers()],
+            datastore=self.protocol.data_store.dump()
+        )
+
+    def load_state(self, state: NodeState):
+        for node_id, address, udp_port, tcp_port, blob_hash in state.datastore:
+            p = make_kademlia_peer(node_id, address, udp_port, tcp_port)
+            self.protocol.data_store.add_peer_to_blob(p, blob_hash)
+
     async def start_listening(self, interface: str = '0.0.0.0') -> None:
         if not self.listening_port:
             self.listening_port, _ = await self.loop.create_datagram_endpoint(
                 lambda: self.protocol, (interface, self.internal_udp_port)
             )
+            self.started_listening.set()
             log.info("DHT node listening on UDP %s:%i", interface, self.internal_udp_port)
             self.protocol.start()
         else:
             log.warning("Already bound to port %s", self.listening_port)
 
     async def join_network(self, interface: str = '0.0.0.0',
-                           known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
+                           known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
+                           persisted_peers: typing.Optional[typing.List[typing.Tuple[bytes, str, int, int]]] = None):
         def peers_from_urls(urls: typing.Optional[typing.List[typing.Tuple[bytes, str, int, int]]]):
             peer_addresses = []
             for node_id, address, udp_port, tcp_port in urls:
@@ -154,9 +180,7 @@ class Node:
         else:
             if self.joined.is_set():
                 self.joined.clear()
-            seed_peers = peers_from_urls(
-                await self._storage.get_persisted_kademlia_peers()
-            ) if self._storage else []
+            seed_peers = peers_from_urls(persisted_peers) if persisted_peers else []
             if not seed_peers:
                 try:
                     seed_peers.extend(peers_from_urls([
@@ -173,8 +197,11 @@ class Node:
 
             await asyncio.sleep(1, loop=self.loop)
 
-    def start(self, interface: str, known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
-        self._join_task = self.loop.create_task(self.join_network(interface, known_node_urls))
+    def start(self, interface: str,
+              known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
+              persisted_peers: typing.Optional[typing.List[typing.Tuple[bytes, str, int, int]]] = None):
+
+        self._join_task = self.loop.create_task(self.join_network(interface, known_node_urls, persisted_peers))
 
     def get_iterative_node_finder(self, key: bytes, shortlist: typing.Optional[typing.List['KademliaPeer']] = None,
                                   bottom_out_limit: int = constants.BOTTOM_OUT_LIMIT,
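The NodeState added above is deliberately a plain picklable snapshot, and its two halves are restored through different paths: load_state() re-seeds only the blob datastore, while routing_table_peers is meant to be passed back into start()/join_network() as persisted_peers. A minimal sketch of the intended round-trip (the helper names and state file path are illustrative, not part of the patch):

    import pickle

    def save_node_state(n, state_path="nodestate"):
        # snapshot the routing table peers and datastore entries added by this patch
        with open(state_path, "wb") as f:
            pickle.dump(n.get_state(), f)

    def restore_node_state(n, state_path="nodestate"):
        with open(state_path, "rb") as f:
            state = pickle.load(f)
        n.load_state(state)  # re-populates the datastore via add_peer_to_blob
        # the routing table half goes back in through Node.start(..., persisted_peers=...)
        return state.routing_table_peers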
diff --git a/lbry/dht/protocol/data_store.py b/lbry/dht/protocol/data_store.py
index 6a614680f..62673f547 100644
--- a/lbry/dht/protocol/data_store.py
+++ b/lbry/dht/protocol/data_store.py
@@ -68,3 +68,9 @@ class DictDataStore:
         for _, stored in self._data_store.items():
             peers.update(set(map(lambda tup: tup[0], stored)))
         return list(peers)
+
+    def dump(self) -> typing.List[typing.Tuple[bytes, str, int, int, bytes]]:
+        data = []
+        for k, peers in self._data_store.items():
+            data.extend([(p.node_id, p.address, p.udp_port, p.tcp_port, k) for p in map(lambda t: t[0], peers)])
+        return data
diff --git a/lbry/dht/protocol/protocol.py b/lbry/dht/protocol/protocol.py
index 7b90b5644..4f5f94e00 100644
--- a/lbry/dht/protocol/protocol.py
+++ b/lbry/dht/protocol/protocol.py
@@ -287,6 +287,7 @@ class KademliaProtocol(DatagramProtocol):
         self._to_add: typing.Set['KademliaPeer'] = set()
         self._wakeup_routing_task = asyncio.Event(loop=self.loop)
         self.maintaing_routing_task: typing.Optional[asyncio.Task] = None
+        self.event_queue = asyncio.Queue(maxsize=100)
 
     @functools.lru_cache(128)
     def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC:
@@ -428,6 +429,9 @@ class KademliaProtocol(DatagramProtocol):
         log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip, self.udp_port, message.method.decode(),
                   sender_contact.address, sender_contact.udp_port)
 
+        if not self.event_queue.full():
+            self.event_queue.put_nowait((sender_contact.node_id, sender_contact.address, method, args))
+
         if method == b'ping':
             result = self.node_rpc.ping()
         elif method == b'store':
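With this change every inbound RPC is mirrored onto a bounded queue; when the queue is full, put_nowait is skipped and the event is dropped rather than blocking the datagram handler, so a slow consumer loses events instead of stalling the DHT. A consumer drains it roughly like this (a sketch assuming a started Node n; scripts/tracker.py below is the real consumer):

    async def consume_events(n):
        while True:
            # each item is (node_id: bytes, address: str, method: bytes, args: list),
            # exactly as enqueued in the request-handling path above
            node_id, address, method, args = await n.protocol.event_queue.get()
            if method == b'store':
                print("store from", node_id.hex(), "at", address)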
diff --git a/scripts/tracker.py b/scripts/tracker.py
new file mode 100644
index 000000000..3df2da8d6
--- /dev/null
+++ b/scripts/tracker.py
@@ -0,0 +1,114 @@
+import asyncio
+import logging
+import signal
+import sqlite3
+import pickle
+import time
+from os import path
+from pprint import pprint
+
+from aioupnp import upnp
+
+from lbry.dht import node, peer
+
+log = logging.getLogger("lbry")
+log.addHandler(logging.StreamHandler())
+log.setLevel(logging.INFO)
+
+
+async def main():
+    data_dir = "/home/grin/code/lbry/sdk"
+    state_file = data_dir + '/nodestate'
+    loop = asyncio.get_event_loop()
+
+    try:
+        loop.add_signal_handler(signal.SIGINT, shutdown)
+        loop.add_signal_handler(signal.SIGTERM, shutdown)
+    except NotImplementedError:
+        pass  # signal handlers are not implemented on Windows
+
+    peer_manager = peer.PeerManager(loop)
+    u = await upnp.UPnP.discover()
+    await u.get_next_mapping(4444, "UDP", "lbry dht tracker", 4444)
+    my_node_id = "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b"
+    n = node.Node(loop, peer_manager, node_id=bytes.fromhex(my_node_id), external_ip=(await u.get_external_ip()),
+                  udp_port=4444, internal_udp_port=4444, peer_port=4444)
+
+    db = sqlite3.connect(data_dir + "/tracker.sqlite3")
+    db.execute(
+        '''CREATE TABLE IF NOT EXISTS log (hash TEXT, node_id TEXT, ip TEXT, port INT, timestamp INT)'''
+    )
+
+    try:
+        known_node_urls = [("lbrynet1.lbry.com", 4444), ("lbrynet2.lbry.com", 4444), ("lbrynet3.lbry.com", 4444)]
+        persisted_peers = []
+        if path.exists(state_file):
+            with open(state_file, 'rb') as f:
+                state = pickle.load(f)
+                # pprint(state.routing_table_peers)
+                # pprint(state.datastore)
+                print(f'loaded {len(state.routing_table_peers)} rt peers, {len(state.datastore)} in store')
+                n.load_state(state)
+                persisted_peers = state.routing_table_peers
+
+        n.start("0.0.0.0", known_node_urls, persisted_peers)
+        await n.started_listening.wait()
+        print("listening")
+        # jack = peer.make_kademlia_peer(
+        #     bytes.fromhex("38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95c"),
+        #     "216.19.244.226", udp_port=4444,
+        # )
+        # print(await n.protocol.get_rpc_peer(jack).ping())
+
+        await dostuff(n, db)
+    finally:
+        print("shutting down")
+        n.stop()
+        state = n.get_state()
+        with open(state_file, 'wb') as f:
+            # pprint(state.routing_table_peers)
+            # pprint(state.datastore)
+            print(f'saved {len(state.routing_table_peers)} rt peers, {len(state.datastore)} in store')
+            pickle.dump(state, f)
+        db.close()
+        await u.delete_port_mapping(4444, "UDP")
+
+
+async def dostuff(n, db):
+    # drain the node's event queue and log every incoming `store` request
+    # TODO: to track with more than one node, create a task per node that
+    # drains each node's event_queue into one combined queue, keep the task
+    # handles, and tear them down on shutdown
+    while True:
+        (node_id, ip, method, args) = await n.protocol.event_queue.get()
+        if method == b'store':
+            blob_hash, token, port, original_publisher_id, age = args[:5]
+            print(f"STORE from {bytes.hex(node_id)} ({ip}) for blob {bytes.hex(blob_hash)}")
+
+            try:
+                cur = db.cursor()
+                cur.execute('INSERT INTO log (hash, node_id, ip, port, timestamp) VALUES (?,?,?,?,?)',
+                            (bytes.hex(blob_hash), bytes.hex(node_id), ip, port, int(time.time())))
+                db.commit()
+                cur.close()
+            except sqlite3.Error as err:
+                print("failed insert", err)
+        else:
+            pass
+            # print(f"{method} from {bytes.hex(node_id)} ({ip})")
+
+
+class ShutdownErr(BaseException):
+    pass
+
+
+def shutdown():
+    print("got interrupt signal...")
+    raise ShutdownErr()
+
+
+if __name__ == "__main__":
+    try:
+        asyncio.run(main())
+    except ShutdownErr:
+        pass
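Once the tracker has been running, the announce log can be inspected straight from sqlite3. For example, a hypothetical query against the log table created in main(), listing the blobs announced by the most distinct peers over the last hour:

    import sqlite3

    db = sqlite3.connect("tracker.sqlite3")  # the file the script creates in data_dir
    rows = db.execute(
        "SELECT hash, COUNT(DISTINCT node_id) FROM log "
        "WHERE timestamp > strftime('%s', 'now') - 3600 "
        "GROUP BY hash ORDER BY 2 DESC LIMIT 10")
    for blob_hash, announcers in rows:
        print(blob_hash, announcers)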