diff --git a/server/federation.go b/server/federation.go index c85ff3f..25a0dca 100644 --- a/server/federation.go +++ b/server/federation.go @@ -73,10 +73,11 @@ func (s *Server) getNumSubs() int64 { // pings it, so we can determine our own external IP address. func (s *Server) getAndSetExternalIp(msg *pb.ServerMessage) error { log.Println(msg) - myIp, _, err := UDPPing(msg.Address, msg.Port) + pong, err := UDPPing(msg.Address, msg.Port) if err != nil { return err } + myIp := pong.DecodeAddress() log.Println("my ip: ", myIp) s.ExternalIP = myIp diff --git a/server/udp.go b/server/udp.go index 734025f..55cc14a 100644 --- a/server/udp.go +++ b/server/udp.go @@ -153,48 +153,64 @@ func (pong *SPVPong) DecodeCountry() string { return pb.Location_Country_name[int32(pong.country)] } +func (pong *SPVPong) DecodeProtocolVersion() int { + return int(pong.protocolVersion) +} + +func (pong *SPVPong) DecodeHeight() int { + return int(pong.height) +} + +func (pong *SPVPong) DecodeTip() []byte { + return pong.tip +} + +func (pong *SPVPong) DecodeFlags() byte { + return pong.flags +} + // UDPPing sends a ping over udp to another hub and returns the ip address of // this hub. -func UDPPing(ip, port string) (net.IP, string, error) { +func UDPPing(ip, port string) (*SPVPong, error) { address := ip + ":" + port addr, err := net.ResolveUDPAddr("udp", address) if err != nil { - return net.IP{}, "", err + return nil, err } conn, err := net.DialUDP("udp", nil, addr) if err != nil { - return net.IP{}, "", err + return nil, err } defer conn.Close() _, err = conn.Write(encodeSPVPing()) if err != nil { - return net.IP{}, "", err + return nil, err } buffer := make([]byte, maxBufferSize) deadline := time.Now().Add(time.Second) err = conn.SetReadDeadline(deadline) if err != nil { - return net.IP{}, "", err + return nil, err } n, _, err := conn.ReadFromUDP(buffer) if err != nil { - return net.IP{}, "", err + return nil, err } pong := decodeSPVPong(buffer[:n]) if pong == nil { - return net.IP{}, "", errors.Base("Pong decoding failed") + return nil, errors.Base("Pong decoding failed") } - myAddr := pong.DecodeAddress() - country := pong.DecodeCountry() + // myAddr := pong.DecodeAddress() + // country := pong.DecodeCountry() - return myAddr, country, nil + return pong, nil } // UDPServer is a goroutine that starts an udp server that implements the hubs diff --git a/server/udp_test.go b/server/udp_test.go index 5f1071f..ceb3db3 100644 --- a/server/udp_test.go +++ b/server/udp_test.go @@ -1,59 +1,83 @@ package server import ( - "log" - "os/exec" - "strings" - "testing" + "log" + "os/exec" + "strings" + "testing" ) // TestAddPeer tests the ability to add peers func TestUDPPing(t *testing.T) { - args := makeDefaultArgs() - args.StartUDP = false + args := makeDefaultArgs() + args.StartUDP = false - tests := []struct { - name string - wantIP string - wantCountry string - } { - { - name: "Get the right ip from production server.", - wantIP: "SETME", - wantCountry: "US", - }, - } + tests := []struct { + name string + wantIP string + wantCountry string + wantProtocolVersion int + wantHeightMin int + wantFlags byte + } { + { + name: "Correctly parse information from production server.", + wantIP: "SETME", + wantCountry: "US", + wantProtocolVersion: 1, + wantHeightMin: 1060000, + wantFlags: 1, + }, + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T){ + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T){ - toAddr := "spv16.lbry.com" - toPort := "50001" + toAddr := "spv16.lbry.com" + toPort := "50001" - ip, country, err := UDPPing(toAddr, toPort) - gotCountry := country - if err != nil { - log.Println(err) - } + pong, err := UDPPing(toAddr, toPort) + gotCountry := pong.DecodeCountry() + if err != nil { + log.Println(err) + } - res, err := exec.Command("dig", "@resolver4.opendns.com", "myip.opendns.com", "+short").Output() + res, err := exec.Command("dig", "@resolver4.opendns.com", "myip.opendns.com", "+short").Output() - if err != nil { - log.Println(err) - } + if err != nil { + log.Println(err) + } - digIP := strings.TrimSpace(string(res)) - udpIP := ip.String() - tt.wantIP = digIP + digIP := strings.TrimSpace(string(res)) + udpIP := pong.DecodeAddress().String() + tt.wantIP = digIP - gotIP := udpIP - if gotIP != tt.wantIP { - t.Errorf("got: '%s', want: '%s'\n", gotIP, tt.wantIP) - } - if gotCountry != tt.wantCountry { - t.Errorf("got: '%s', want: '%s'\n", gotCountry, tt.wantCountry) - } - }) - } + log.Println("Height:", pong.DecodeHeight()) + log.Printf("Flags: %x\n", pong.DecodeFlags()) + log.Println("ProtocolVersion:", pong.DecodeProtocolVersion()) + log.Printf("Tip: %x\n", pong.DecodeTip()) + + gotHeight := pong.DecodeHeight() + gotProtocolVersion := pong.DecodeProtocolVersion() + gotFlags := pong.DecodeFlags() + gotIP := udpIP + + if gotIP != tt.wantIP { + t.Errorf("ip: got: '%s', want: '%s'\n", gotIP, tt.wantIP) + } + if gotCountry != tt.wantCountry { + t.Errorf("country: got: '%s', want: '%s'\n", gotCountry, tt.wantCountry) + } + if gotHeight < tt.wantHeightMin { + t.Errorf("height: got: %d, want >=: %d\n", gotHeight, tt.wantHeightMin) + } + if gotProtocolVersion != tt.wantProtocolVersion { + t.Errorf("protocolVersion: got: %d, want: %d\n", gotProtocolVersion, tt.wantProtocolVersion) + } + if gotFlags != tt.wantFlags { + t.Errorf("flags: got: %d, want: %d\n", gotFlags, tt.wantFlags) + } + }) + } }