wire: Change NewNetAddress to accept a *net.TCPConn.

Rather than accepting a net.Addr interface and returning an error when
it's not specifically a *net.TCPConn, just accept a *net.TCPConn
directly so the compiler will assert it.  Also, remove the error return
since it can no longer occur.
This commit is contained in:
Dave Collins 2016-11-03 20:16:37 -05:00
parent a041b4349b
commit 2c6f864b55
No known key found for this signature in database
GPG key ID: B8904D9D9C93D1F2
5 changed files with 9 additions and 53 deletions

View file

@ -41,16 +41,10 @@ func TestMessage(t *testing.T) {
// MsgVersion. // MsgVersion.
addrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333} addrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333}
you, err := NewNetAddress(addrYou, SFNodeNetwork) you := NewNetAddress(addrYou, SFNodeNetwork)
if err != nil {
t.Errorf("NewNetAddress: %v", err)
}
you.Timestamp = time.Time{} // Version message has zero value timestamp. you.Timestamp = time.Time{} // Version message has zero value timestamp.
addrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333} addrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
me, err := NewNetAddress(addrMe, SFNodeNetwork) me := NewNetAddress(addrMe, SFNodeNetwork)
if err != nil {
t.Errorf("NewNetAddress: %v", err)
}
me.Timestamp = time.Time{} // Version message has zero value timestamp. me.Timestamp = time.Time{} // Version message has zero value timestamp.
msgVersion := NewMsgVersion(me, you, 123123, 0) msgVersion := NewMsgVersion(me, you, 123123, 0)

View file

@ -39,11 +39,8 @@ func TestAddr(t *testing.T) {
// Ensure NetAddresses are added properly. // Ensure NetAddresses are added properly.
tcpAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333} tcpAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
na, err := NewNetAddress(tcpAddr, SFNodeNetwork) na := NewNetAddress(tcpAddr, SFNodeNetwork)
if err != nil { err := msg.AddAddress(na)
t.Errorf("NewNetAddress: %v", err)
}
err = msg.AddAddress(na)
if err != nil { if err != nil {
t.Errorf("AddAddress: %v", err) t.Errorf("AddAddress: %v", err)
} }

View file

@ -23,15 +23,9 @@ func TestVersion(t *testing.T) {
// Create version message data. // Create version message data.
lastBlock := int32(234234) lastBlock := int32(234234)
tcpAddrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333} tcpAddrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
me, err := NewNetAddress(tcpAddrMe, SFNodeNetwork) me := NewNetAddress(tcpAddrMe, SFNodeNetwork)
if err != nil {
t.Errorf("NewNetAddress: %v", err)
}
tcpAddrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333} tcpAddrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333}
you, err := NewNetAddress(tcpAddrYou, SFNodeNetwork) you := NewNetAddress(tcpAddrYou, SFNodeNetwork)
if err != nil {
t.Errorf("NewNetAddress: %v", err)
}
nonce, err := RandomUint64() nonce, err := RandomUint64()
if err != nil { if err != nil {
t.Errorf("RandomUint64: error generating nonce: %v", err) t.Errorf("RandomUint64: error generating nonce: %v", err)

View file

@ -6,16 +6,11 @@ package wire
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"io" "io"
"net" "net"
"time" "time"
) )
// ErrInvalidNetAddr describes an error that indicates the caller didn't specify
// a TCP address as required.
var ErrInvalidNetAddr = errors.New("provided net.Addr is not a net.TCPAddr")
// maxNetAddressPayload returns the max payload size for a bitcoin NetAddress // maxNetAddressPayload returns the max payload size for a bitcoin NetAddress
// based on the protocol version. // based on the protocol version.
func maxNetAddressPayload(pver uint32) uint32 { func maxNetAddressPayload(pver uint32) uint32 {
@ -85,17 +80,8 @@ func NewNetAddressIPPort(ip net.IP, port uint16, services ServiceFlag) *NetAddre
// NewNetAddress returns a new NetAddress using the provided TCP address and // NewNetAddress returns a new NetAddress using the provided TCP address and
// supported services with defaults for the remaining fields. // supported services with defaults for the remaining fields.
// func NewNetAddress(addr *net.TCPAddr, services ServiceFlag) *NetAddress {
// Note that addr must be a net.TCPAddr. An ErrInvalidNetAddr is returned return NewNetAddressIPPort(addr.IP, uint16(addr.Port), services)
// if it is not.
func NewNetAddress(addr net.Addr, services ServiceFlag) (*NetAddress, error) {
tcpAddr, ok := addr.(*net.TCPAddr)
if !ok {
return nil, ErrInvalidNetAddr
}
na := NewNetAddressIPPort(tcpAddr.IP, uint16(tcpAddr.Port), services)
return na, nil
} }
// readNetAddress reads an encoded NetAddress from r depending on the protocol // readNetAddress reads an encoded NetAddress from r depending on the protocol

View file

@ -21,14 +21,7 @@ func TestNetAddress(t *testing.T) {
port := 8333 port := 8333
// Test NewNetAddress. // Test NewNetAddress.
tcpAddr := &net.TCPAddr{ na := NewNetAddress(&net.TCPAddr{IP: ip, Port: port}, 0)
IP: ip,
Port: port,
}
na, err := NewNetAddress(tcpAddr, 0)
if err != nil {
t.Errorf("NewNetAddress: %v", err)
}
// Ensure we get the same ip, port, and services back out. // Ensure we get the same ip, port, and services back out.
if !na.IP.Equal(ip) { if !na.IP.Equal(ip) {
@ -76,14 +69,6 @@ func TestNetAddress(t *testing.T) {
"protocol version %d - got %v, want %v", pver, "protocol version %d - got %v, want %v", pver,
maxPayload, wantPayload) maxPayload, wantPayload)
} }
// Check for expected failure on wrong address type.
udpAddr := &net.UDPAddr{}
_, err = NewNetAddress(udpAddr, 0)
if err != ErrInvalidNetAddr {
t.Errorf("NewNetAddress: expected error not received - "+
"got %v, want %v", err, ErrInvalidNetAddr)
}
} }
// TestNetAddressWire tests the NetAddress wire encode and decode for various // TestNetAddressWire tests the NetAddress wire encode and decode for various