handle both urls and ip addresses for fixed nodes and reflector servers

This commit is contained in:
Jack Robison 2019-01-31 13:46:19 -05:00
parent f9fd62c214
commit 2b035009ef
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
3 changed files with 23 additions and 21 deletions

View file

@ -1,9 +1,9 @@
import logging import logging
import asyncio import asyncio
import typing import typing
import socket
import binascii import binascii
import contextlib import contextlib
from lbrynet.utils import resolve_host
from lbrynet.dht import constants from lbrynet.dht import constants
from lbrynet.dht.error import RemoteException from lbrynet.dht.error import RemoteException
from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction
@ -116,26 +116,22 @@ class Node:
log.warning("Already bound to port %s", self.listening_port) log.warning("Already bound to port %s", self.listening_port)
async def join_network(self, interface: typing.Optional[str] = '', async def join_network(self, interface: typing.Optional[str] = '',
known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None, known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
known_node_addresses: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
if not self.listening_port: if not self.listening_port:
await self.start_listening(interface) await self.start_listening(interface)
self.protocol.ping_queue.start() self.protocol.ping_queue.start()
self._refresh_task = self.loop.create_task(self.refresh_node()) self._refresh_task = self.loop.create_task(self.refresh_node())
# resolve the known node urls # resolve the known node urls
known_node_addresses = known_node_addresses or [] known_node_addresses = []
url_to_addr = {} url_to_addr = {}
if known_node_urls: if known_node_urls:
for host, port in known_node_urls: for host, port in known_node_urls:
info = await self.loop.getaddrinfo( address = await resolve_host(host)
host, 'https', if (address, port) not in known_node_addresses:
proto=socket.IPPROTO_TCP, known_node_addresses.append((address, port))
) url_to_addr[address] = host
if (info[0][4][0], port) not in known_node_addresses:
known_node_addresses.append((info[0][4][0], port))
url_to_addr[info[0][4][0]] = host
if known_node_addresses: if known_node_addresses:
while not self.protocol.routing_table.get_peers(): while not self.protocol.routing_table.get_peers():

View file

@ -1,7 +1,7 @@
import asyncio import asyncio
import typing import typing
import socket
import logging import logging
from lbrynet.utils import resolve_host
from lbrynet.stream.assembler import StreamAssembler from lbrynet.stream.assembler import StreamAssembler
from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.blob_exchange.downloader import BlobDownloader from lbrynet.blob_exchange.downloader import BlobDownloader
@ -20,14 +20,6 @@ def drain_into(a: list, b: list):
b.append(a.pop()) b.append(a.pop())
async def resolve_host(loop: asyncio.BaseEventLoop, url: str):
info = await loop.getaddrinfo(
url, 'https',
proto=socket.IPPROTO_TCP,
)
return info[0][4][0]
class StreamDownloader(StreamAssembler): class StreamDownloader(StreamAssembler):
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager', sd_hash: str, def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager', sd_hash: str,
output_dir: typing.Optional[str] = None, output_file_name: typing.Optional[str] = None): output_dir: typing.Optional[str] = None, output_file_name: typing.Optional[str] = None):
@ -78,7 +70,7 @@ class StreamDownloader(StreamAssembler):
def add_fixed_peers(self): def add_fixed_peers(self):
async def _add_fixed_peers(): async def _add_fixed_peers():
self.peer_queue.put_nowait([ self.peer_queue.put_nowait([
KademliaPeer(self.loop, address=(await resolve_host(self.loop, url)), tcp_port=port + 1) KademliaPeer(self.loop, address=(await resolve_host(url)), tcp_port=port + 1)
for url, port in self.config.reflector_servers for url, port in self.config.reflector_servers
]) ])
if self.config.reflector_servers: if self.config.reflector_servers:

View file

@ -8,6 +8,7 @@ import json
import typing import typing
import asyncio import asyncio
import logging import logging
import ipaddress
import pkg_resources import pkg_resources
from lbrynet.schema.claim import ClaimDict from lbrynet.schema.claim import ClaimDict
from lbrynet.cryptoutils import get_lbry_hash_obj from lbrynet.cryptoutils import get_lbry_hash_obj
@ -136,3 +137,16 @@ def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
while tasks: while tasks:
cancel_task(tasks.pop()) cancel_task(tasks.pop())
async def resolve_host(url: str) -> str:
try:
if ipaddress.ip_address(url):
return url
except ValueError:
pass
loop = asyncio.get_running_loop()
return (await loop.getaddrinfo(
url, 'https',
proto=socket.IPPROTO_TCP,
))[0][4][0]