changes based on comments

This commit is contained in:
Jeffrey Picard 2021-11-15 08:52:32 -05:00
parent 395e1db489
commit 355eab682c
4 changed files with 28 additions and 22 deletions

View file

@ -5,6 +5,7 @@ import (
"context" "context"
"log" "log"
"math" "math"
"net"
"os" "os"
"strings" "strings"
"sync/atomic" "sync/atomic"
@ -27,6 +28,7 @@ var (
"127.0.0.1": true, "127.0.0.1": true,
"0.0.0.0": true, "0.0.0.0": true,
"localhost": true, "localhost": true,
"<nil>": true,
} }
) )
@ -70,6 +72,7 @@ func (s *Server) getNumSubs() int64 {
// getAndSetExternalIp takes the address of a peer running a UDP server and // getAndSetExternalIp takes the address of a peer running a UDP server and
// pings it, so we can determine our own external IP address. // pings it, so we can determine our own external IP address.
func (s *Server) getAndSetExternalIp(msg *pb.ServerMessage) error { func (s *Server) getAndSetExternalIp(msg *pb.ServerMessage) error {
log.Println(msg)
myIp, err := UDPPing(msg.Address, msg.Port) myIp, err := UDPPing(msg.Address, msg.Port)
if err != nil { if err != nil {
return err return err
@ -139,7 +142,7 @@ retry:
// If the peer is us, skip // If the peer is us, skip
log.Println(ipPort) log.Println(ipPort)
if ipPort[1] == port && if ipPort[1] == port &&
(localHosts[ipPort[0]] || ipPort[0] == s.ExternalIP) { (localHosts[ipPort[0]] || ipPort[0] == s.ExternalIP.String()) {
log.Println("Self peer, skipping ...") log.Println("Self peer, skipping ...")
continue continue
} }
@ -177,7 +180,7 @@ func (s *Server) subscribeToPeer(peer *FederatedServer) error {
defer conn.Close() defer conn.Close()
msg := &pb.ServerMessage{ msg := &pb.ServerMessage{
Address: s.ExternalIP, Address: s.ExternalIP.String(),
Port: s.Args.Port, Port: s.Args.Port,
} }
@ -217,7 +220,7 @@ func (s *Server) helloPeer(server *FederatedServer) (*pb.HelloMessage, error) {
msg := &pb.HelloMessage{ msg := &pb.HelloMessage{
Port: s.Args.Port, Port: s.Args.Port,
Host: s.ExternalIP, Host: s.ExternalIP.String(),
Servers: []*pb.ServerMessage{}, Servers: []*pb.ServerMessage{},
} }
@ -327,7 +330,10 @@ func (s *Server) notifyPeerSubs(newServer *FederatedServer) {
func (s *Server) addPeer(msg *pb.ServerMessage, ping bool, subscribe bool) error { func (s *Server) addPeer(msg *pb.ServerMessage, ping bool, subscribe bool) error {
// First thing we get our external ip if we don't have it, otherwise we // First thing we get our external ip if we don't have it, otherwise we
// could end up subscribed to our self, which is silly. // could end up subscribed to our self, which is silly.
if s.ExternalIP == "" { nilIP := net.IP{}
//localIP0 := net.IPv4(0,0,0,0)
localIP1 := net.IPv4(127,0,0,1)
if s.ExternalIP.Equal(nilIP) || s.ExternalIP.Equal(localIP1) {
err := s.getAndSetExternalIp(msg) err := s.getAndSetExternalIp(msg)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
@ -336,7 +342,7 @@ func (s *Server) addPeer(msg *pb.ServerMessage, ping bool, subscribe bool) error
} }
if s.Args.Port == msg.Port && if s.Args.Port == msg.Port &&
(localHosts[msg.Address] || msg.Address == s.ExternalIP) { (localHosts[msg.Address] || msg.Address == s.ExternalIP.String()) {
log.Printf("%s:%s addPeer: Self peer, skipping...\n", s.ExternalIP, s.Args.Port) log.Printf("%s:%s addPeer: Self peer, skipping...\n", s.ExternalIP, s.Args.Port)
return nil return nil
} }
@ -406,7 +412,7 @@ func (s *Server) makeHelloMessage() *pb.HelloMessage {
return &pb.HelloMessage{ return &pb.HelloMessage{
Port: s.Args.Port, Port: s.Args.Port,
Host: s.ExternalIP, Host: s.ExternalIP.String(),
Servers: servers, Servers: servers,
} }
} }

View file

@ -5,6 +5,7 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"net"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -88,7 +89,7 @@ func TestAddPeer(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T){ t.Run(tt.name, func(t *testing.T){
server := MakeHubServer(ctx, args) server := MakeHubServer(ctx, args)
server.ExternalIP = "0.0.0.0" server.ExternalIP = net.IPv4(0,0,0,0)
metrics.PeersKnown.Set(0) metrics.PeersKnown.Set(0)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -147,7 +148,7 @@ func TestPeerWriter(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T){ t.Run(tt.name, func(t *testing.T){
server := MakeHubServer(ctx, args) server := MakeHubServer(ctx, args)
server.ExternalIP = "0.0.0.0" server.ExternalIP = net.IPv4(0,0,0,0)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
var msg *pb.ServerMessage var msg *pb.ServerMessage
@ -466,12 +467,12 @@ func TestUDPServer(t *testing.T) {
server.GrpcServer.GracefulStop() server.GrpcServer.GracefulStop()
server2.GrpcServer.GracefulStop() server2.GrpcServer.GracefulStop()
got1 := server.ExternalIP got1 := server.ExternalIP.String()
if got1 != tt.want { if got1 != tt.want {
t.Errorf("server.ExternalIP = %s, want %s\n", got1, tt.want) t.Errorf("server.ExternalIP = %s, want %s\n", got1, tt.want)
t.Errorf("server.Args.Port = %s\n", server.Args.Port) t.Errorf("server.Args.Port = %s\n", server.Args.Port)
} }
got2 := server2.ExternalIP got2 := server2.ExternalIP.String()
if got2 != tt.want { if got2 != tt.want {
t.Errorf("server2.ExternalIP = %s, want %s\n", got2, tt.want) t.Errorf("server2.ExternalIP = %s, want %s\n", got2, tt.want)
t.Errorf("server2.Args.Port = %s\n", server2.Args.Port) t.Errorf("server2.Args.Port = %s\n", server2.Args.Port)

View file

@ -41,7 +41,7 @@ type Server struct {
PeerSubs map[string]*FederatedServer PeerSubs map[string]*FederatedServer
PeerSubsMut sync.RWMutex PeerSubsMut sync.RWMutex
NumPeerSubs *int64 NumPeerSubs *int64
ExternalIP string ExternalIP net.IP
pb.UnimplementedHubServer pb.UnimplementedHubServer
} }
@ -202,7 +202,7 @@ func MakeHubServer(ctx context.Context, args *Args) *Server {
PeerSubs: make(map[string]*FederatedServer), PeerSubs: make(map[string]*FederatedServer),
PeerSubsMut: sync.RWMutex{}, PeerSubsMut: sync.RWMutex{},
NumPeerSubs: numSubs, NumPeerSubs: numSubs,
ExternalIP: "", ExternalIP: net.IPv4(127,0,0,1),
} }
// Start up our background services // Start up our background services

View file

@ -2,7 +2,6 @@ package server
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"net" "net"
"strconv" "strconv"
"strings" "strings"
@ -137,8 +136,8 @@ func EncodeAddress(addr string) []byte {
} }
// DecodeAddress gets the string ipv4 address from an SPVPong struct. // DecodeAddress gets the string ipv4 address from an SPVPong struct.
func (pong *SPVPong) DecodeAddress() string { func (pong *SPVPong) DecodeAddress() net.IP {
return fmt.Sprintf("%d.%d.%d.%d", return net.IPv4(
pong.srcAddrRaw[0], pong.srcAddrRaw[0],
pong.srcAddrRaw[1], pong.srcAddrRaw[1],
pong.srcAddrRaw[2], pong.srcAddrRaw[2],
@ -148,40 +147,40 @@ func (pong *SPVPong) DecodeAddress() string {
// UDPPing sends a ping over udp to another hub and returns the ip address of // UDPPing sends a ping over udp to another hub and returns the ip address of
// this hub. // this hub.
func UDPPing(ip, port string) (string, error) { func UDPPing(ip, port string) (net.IP, error) {
address := ip + ":" + port address := ip + ":" + port
addr, err := net.ResolveUDPAddr("udp", address) addr, err := net.ResolveUDPAddr("udp", address)
if err != nil { if err != nil {
return "", err return net.IP{}, err
} }
conn, err := net.DialUDP("udp", nil, addr) conn, err := net.DialUDP("udp", nil, addr)
if err != nil { if err != nil {
return "", err return net.IP{}, err
} }
defer conn.Close() defer conn.Close()
_, err = conn.Write(encodeSPVPing()) _, err = conn.Write(encodeSPVPing())
if err != nil { if err != nil {
return "", err return net.IP{}, err
} }
buffer := make([]byte, maxBufferSize) buffer := make([]byte, maxBufferSize)
deadline := time.Now().Add(time.Second) deadline := time.Now().Add(time.Second)
err = conn.SetReadDeadline(deadline) err = conn.SetReadDeadline(deadline)
if err != nil { if err != nil {
return "", err return net.IP{}, err
} }
n, _, err := conn.ReadFromUDP(buffer) n, _, err := conn.ReadFromUDP(buffer)
if err != nil { if err != nil {
return "", err return net.IP{}, err
} }
pong := decodeSPVPong(buffer[:n]) pong := decodeSPVPong(buffer[:n])
if pong == nil { if pong == nil {
return "", errors.Base("Pong decoding failed") return net.IP{}, errors.Base("Pong decoding failed")
} }
myAddr := pong.DecodeAddress() myAddr := pong.DecodeAddress()