wire: Change NewNetAddress to accept a *net.TCPConn.
Rather than accepting a net.Addr interface and returning an error when it's not specifically a *net.TCPConn, just accept a *net.TCPConn directly so the compiler will assert it. Also, remove the error return since it can no longer occur.
This commit is contained in:
parent
a041b4349b
commit
2c6f864b55
5 changed files with 9 additions and 53 deletions
|
@ -41,16 +41,10 @@ func TestMessage(t *testing.T) {
|
||||||
|
|
||||||
// MsgVersion.
|
// MsgVersion.
|
||||||
addrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333}
|
addrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333}
|
||||||
you, err := NewNetAddress(addrYou, SFNodeNetwork)
|
you := NewNetAddress(addrYou, SFNodeNetwork)
|
||||||
if err != nil {
|
|
||||||
t.Errorf("NewNetAddress: %v", err)
|
|
||||||
}
|
|
||||||
you.Timestamp = time.Time{} // Version message has zero value timestamp.
|
you.Timestamp = time.Time{} // Version message has zero value timestamp.
|
||||||
addrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
|
addrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
|
||||||
me, err := NewNetAddress(addrMe, SFNodeNetwork)
|
me := NewNetAddress(addrMe, SFNodeNetwork)
|
||||||
if err != nil {
|
|
||||||
t.Errorf("NewNetAddress: %v", err)
|
|
||||||
}
|
|
||||||
me.Timestamp = time.Time{} // Version message has zero value timestamp.
|
me.Timestamp = time.Time{} // Version message has zero value timestamp.
|
||||||
msgVersion := NewMsgVersion(me, you, 123123, 0)
|
msgVersion := NewMsgVersion(me, you, 123123, 0)
|
||||||
|
|
||||||
|
|
|
@ -39,11 +39,8 @@ func TestAddr(t *testing.T) {
|
||||||
|
|
||||||
// Ensure NetAddresses are added properly.
|
// Ensure NetAddresses are added properly.
|
||||||
tcpAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
|
tcpAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
|
||||||
na, err := NewNetAddress(tcpAddr, SFNodeNetwork)
|
na := NewNetAddress(tcpAddr, SFNodeNetwork)
|
||||||
if err != nil {
|
err := msg.AddAddress(na)
|
||||||
t.Errorf("NewNetAddress: %v", err)
|
|
||||||
}
|
|
||||||
err = msg.AddAddress(na)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("AddAddress: %v", err)
|
t.Errorf("AddAddress: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,15 +23,9 @@ func TestVersion(t *testing.T) {
|
||||||
// Create version message data.
|
// Create version message data.
|
||||||
lastBlock := int32(234234)
|
lastBlock := int32(234234)
|
||||||
tcpAddrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
|
tcpAddrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
|
||||||
me, err := NewNetAddress(tcpAddrMe, SFNodeNetwork)
|
me := NewNetAddress(tcpAddrMe, SFNodeNetwork)
|
||||||
if err != nil {
|
|
||||||
t.Errorf("NewNetAddress: %v", err)
|
|
||||||
}
|
|
||||||
tcpAddrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333}
|
tcpAddrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333}
|
||||||
you, err := NewNetAddress(tcpAddrYou, SFNodeNetwork)
|
you := NewNetAddress(tcpAddrYou, SFNodeNetwork)
|
||||||
if err != nil {
|
|
||||||
t.Errorf("NewNetAddress: %v", err)
|
|
||||||
}
|
|
||||||
nonce, err := RandomUint64()
|
nonce, err := RandomUint64()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("RandomUint64: error generating nonce: %v", err)
|
t.Errorf("RandomUint64: error generating nonce: %v", err)
|
||||||
|
|
|
@ -6,16 +6,11 @@ package wire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrInvalidNetAddr describes an error that indicates the caller didn't specify
|
|
||||||
// a TCP address as required.
|
|
||||||
var ErrInvalidNetAddr = errors.New("provided net.Addr is not a net.TCPAddr")
|
|
||||||
|
|
||||||
// maxNetAddressPayload returns the max payload size for a bitcoin NetAddress
|
// maxNetAddressPayload returns the max payload size for a bitcoin NetAddress
|
||||||
// based on the protocol version.
|
// based on the protocol version.
|
||||||
func maxNetAddressPayload(pver uint32) uint32 {
|
func maxNetAddressPayload(pver uint32) uint32 {
|
||||||
|
@ -85,17 +80,8 @@ func NewNetAddressIPPort(ip net.IP, port uint16, services ServiceFlag) *NetAddre
|
||||||
|
|
||||||
// NewNetAddress returns a new NetAddress using the provided TCP address and
|
// NewNetAddress returns a new NetAddress using the provided TCP address and
|
||||||
// supported services with defaults for the remaining fields.
|
// supported services with defaults for the remaining fields.
|
||||||
//
|
func NewNetAddress(addr *net.TCPAddr, services ServiceFlag) *NetAddress {
|
||||||
// Note that addr must be a net.TCPAddr. An ErrInvalidNetAddr is returned
|
return NewNetAddressIPPort(addr.IP, uint16(addr.Port), services)
|
||||||
// if it is not.
|
|
||||||
func NewNetAddress(addr net.Addr, services ServiceFlag) (*NetAddress, error) {
|
|
||||||
tcpAddr, ok := addr.(*net.TCPAddr)
|
|
||||||
if !ok {
|
|
||||||
return nil, ErrInvalidNetAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
na := NewNetAddressIPPort(tcpAddr.IP, uint16(tcpAddr.Port), services)
|
|
||||||
return na, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// readNetAddress reads an encoded NetAddress from r depending on the protocol
|
// readNetAddress reads an encoded NetAddress from r depending on the protocol
|
||||||
|
|
|
@ -21,14 +21,7 @@ func TestNetAddress(t *testing.T) {
|
||||||
port := 8333
|
port := 8333
|
||||||
|
|
||||||
// Test NewNetAddress.
|
// Test NewNetAddress.
|
||||||
tcpAddr := &net.TCPAddr{
|
na := NewNetAddress(&net.TCPAddr{IP: ip, Port: port}, 0)
|
||||||
IP: ip,
|
|
||||||
Port: port,
|
|
||||||
}
|
|
||||||
na, err := NewNetAddress(tcpAddr, 0)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("NewNetAddress: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure we get the same ip, port, and services back out.
|
// Ensure we get the same ip, port, and services back out.
|
||||||
if !na.IP.Equal(ip) {
|
if !na.IP.Equal(ip) {
|
||||||
|
@ -76,14 +69,6 @@ func TestNetAddress(t *testing.T) {
|
||||||
"protocol version %d - got %v, want %v", pver,
|
"protocol version %d - got %v, want %v", pver,
|
||||||
maxPayload, wantPayload)
|
maxPayload, wantPayload)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for expected failure on wrong address type.
|
|
||||||
udpAddr := &net.UDPAddr{}
|
|
||||||
_, err = NewNetAddress(udpAddr, 0)
|
|
||||||
if err != ErrInvalidNetAddr {
|
|
||||||
t.Errorf("NewNetAddress: expected error not received - "+
|
|
||||||
"got %v, want %v", err, ErrInvalidNetAddr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestNetAddressWire tests the NetAddress wire encode and decode for various
|
// TestNetAddressWire tests the NetAddress wire encode and decode for various
|
||||||
|
|
Loading…
Reference in a new issue