connmgr: switch to using net.Addr interface throughout for addresses
This commit modifies the `ConnManager` to use the `net.Add` interface through the package instead of a plain string to represent and manipulate addresses. This change makes the package much more general as users of the package can possibly utilize custom implementations of the `net.Addr` interface to establish connections. More precisely, the `ConnReq` struct has been modified to use a net.Addr instance explicitly, and the `DialFunc` type has also been modified to take a `net.Addr` directly. This latter change gives functions that adhere to the `DialFunc` type more flexibility as to exactly how the connection is established. Additionally, the `connmgr.Config.GetNewAddress` configuration option now directly returns a `net.Addr. This change allows the `connmgr` to be decoupled from all DNS queries which allows callers to preferentially select more secure methods like performing DNS lookups over a Tor proxy.
This commit is contained in:
parent
df33d4340e
commit
e8f63bc295
4 changed files with 134 additions and 39 deletions
|
@ -1001,11 +1001,11 @@ func createDefaultConfigFile(destinationPath string) error {
|
|||
// example, .onion addresses will be dialed using the onion specific proxy if
|
||||
// one was specified, but will otherwise use the normal dial function (which
|
||||
// could itself use a proxy or not).
|
||||
func btcdDial(network, address string) (net.Conn, error) {
|
||||
if strings.Contains(address, ".onion:") {
|
||||
return cfg.oniondial(network, address)
|
||||
func btcdDial(addr net.Addr) (net.Conn, error) {
|
||||
if strings.Contains(addr.String(), ".onion:") {
|
||||
return cfg.oniondial(addr.Network(), addr.String())
|
||||
}
|
||||
return cfg.dial(network, address)
|
||||
return cfg.dial(addr.Network(), addr.String())
|
||||
}
|
||||
|
||||
// btcdLookup returns the correct DNS lookup function to use depending on the
|
||||
|
|
|
@ -57,7 +57,7 @@ type ConnReq struct {
|
|||
// The following variables must only be used atomically.
|
||||
id uint64
|
||||
|
||||
Addr string
|
||||
Addr net.Addr
|
||||
Permanent bool
|
||||
|
||||
conn net.Conn
|
||||
|
@ -88,7 +88,7 @@ func (c *ConnReq) State() ConnState {
|
|||
|
||||
// String returns a human-readable string for the connection request.
|
||||
func (c *ConnReq) String() string {
|
||||
if c.Addr == "" {
|
||||
if c.Addr.String() == "" {
|
||||
return fmt.Sprintf("reqid %d", atomic.LoadUint64(&c.id))
|
||||
}
|
||||
return fmt.Sprintf("%s (reqid %d)", c.Addr, atomic.LoadUint64(&c.id))
|
||||
|
@ -137,10 +137,10 @@ type Config struct {
|
|||
|
||||
// GetNewAddress is a way to get an address to make a network connection
|
||||
// to. If nil, no new connections will be made automatically.
|
||||
GetNewAddress func() (string, error)
|
||||
GetNewAddress func() (net.Addr, error)
|
||||
|
||||
// Dial connects to the address on the named network. It cannot be nil.
|
||||
Dial func(string, string) (net.Conn, error)
|
||||
Dial func(net.Addr) (net.Conn, error)
|
||||
}
|
||||
|
||||
// handleConnected is used to queue a successful connection.
|
||||
|
@ -281,14 +281,18 @@ func (cm *ConnManager) NewConnReq() {
|
|||
if cm.cfg.GetNewAddress == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c := &ConnReq{}
|
||||
atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1))
|
||||
|
||||
addr, err := cm.cfg.GetNewAddress()
|
||||
if err != nil {
|
||||
cm.requests <- handleFailed{c, err}
|
||||
return
|
||||
}
|
||||
|
||||
c.Addr = addr
|
||||
|
||||
cm.Connect(c)
|
||||
}
|
||||
|
||||
|
@ -302,7 +306,7 @@ func (cm *ConnManager) Connect(c *ConnReq) {
|
|||
atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1))
|
||||
}
|
||||
log.Debugf("Attempting to connect to %v", c)
|
||||
conn, err := cm.cfg.Dial("tcp", c.Addr)
|
||||
conn, err := cm.cfg.Dial(c.Addr)
|
||||
if err != nil {
|
||||
cm.requests <- handleFailed{c, err}
|
||||
} else {
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -40,7 +39,7 @@ type mockConn struct {
|
|||
lnet, laddr string
|
||||
|
||||
// remote network, address for the connection.
|
||||
rnet, raddr string
|
||||
rAddr net.Addr
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address for the connection.
|
||||
|
@ -50,7 +49,7 @@ func (c mockConn) LocalAddr() net.Addr {
|
|||
|
||||
// RemoteAddr returns the remote address for the connection.
|
||||
func (c mockConn) RemoteAddr() net.Addr {
|
||||
return &mockAddr{c.rnet, c.raddr}
|
||||
return &mockAddr{c.rAddr.Network(), c.rAddr.String()}
|
||||
}
|
||||
|
||||
// Close handles closing the connection.
|
||||
|
@ -64,9 +63,9 @@ func (c mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
|||
|
||||
// mockDialer mocks the net.Dial interface by returning a mock connection to
|
||||
// the given address.
|
||||
func mockDialer(network, address string) (net.Conn, error) {
|
||||
func mockDialer(addr net.Addr) (net.Conn, error) {
|
||||
r, w := io.Pipe()
|
||||
c := &mockConn{raddr: address}
|
||||
c := &mockConn{rAddr: addr}
|
||||
c.Reader = r
|
||||
c.Writer = w
|
||||
return c, nil
|
||||
|
@ -102,7 +101,12 @@ func TestStartStop(t *testing.T) {
|
|||
disconnected := make(chan *ConnReq)
|
||||
cmgr, err := New(&Config{
|
||||
TargetOutbound: 1,
|
||||
GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil },
|
||||
GetNewAddress: func() (net.Addr, error) {
|
||||
return &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
}, nil
|
||||
},
|
||||
Dial: mockDialer,
|
||||
OnConnection: func(c *ConnReq, conn net.Conn) {
|
||||
connected <- c
|
||||
|
@ -120,7 +124,13 @@ func TestStartStop(t *testing.T) {
|
|||
// already stopped
|
||||
cmgr.Stop()
|
||||
// ignored
|
||||
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
|
||||
cr := &ConnReq{
|
||||
Addr: &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
},
|
||||
Permanent: true,
|
||||
}
|
||||
cmgr.Connect(cr)
|
||||
if cr.ID() != 0 {
|
||||
t.Fatalf("start/stop: got id: %v, want: 0", cr.ID())
|
||||
|
@ -151,7 +161,13 @@ func TestConnectMode(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("New error: %v", err)
|
||||
}
|
||||
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
|
||||
cr := &ConnReq{
|
||||
Addr: &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
},
|
||||
Permanent: true,
|
||||
}
|
||||
cmgr.Start()
|
||||
cmgr.Connect(cr)
|
||||
gotConnReq := <-connected
|
||||
|
@ -184,7 +200,12 @@ func TestTargetOutbound(t *testing.T) {
|
|||
cmgr, err := New(&Config{
|
||||
TargetOutbound: targetOutbound,
|
||||
Dial: mockDialer,
|
||||
GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil },
|
||||
GetNewAddress: func() (net.Addr, error) {
|
||||
return &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
}, nil
|
||||
},
|
||||
OnConnection: func(c *ConnReq, conn net.Conn) {
|
||||
connected <- c
|
||||
},
|
||||
|
@ -228,7 +249,13 @@ func TestRetryPermanent(t *testing.T) {
|
|||
t.Fatalf("New error: %v", err)
|
||||
}
|
||||
|
||||
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
|
||||
cr := &ConnReq{
|
||||
Addr: &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
},
|
||||
Permanent: true,
|
||||
}
|
||||
go cmgr.Connect(cr)
|
||||
cmgr.Start()
|
||||
gotConnReq := <-connected
|
||||
|
@ -292,10 +319,10 @@ func TestMaxRetryDuration(t *testing.T) {
|
|||
time.AfterFunc(5*time.Millisecond, func() {
|
||||
close(networkUp)
|
||||
})
|
||||
timedDialer := func(network, address string) (net.Conn, error) {
|
||||
timedDialer := func(addr net.Addr) (net.Conn, error) {
|
||||
select {
|
||||
case <-networkUp:
|
||||
return mockDialer(network, address)
|
||||
return mockDialer(addr)
|
||||
default:
|
||||
return nil, errors.New("network down")
|
||||
}
|
||||
|
@ -314,7 +341,13 @@ func TestMaxRetryDuration(t *testing.T) {
|
|||
t.Fatalf("New error: %v", err)
|
||||
}
|
||||
|
||||
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
|
||||
cr := &ConnReq{
|
||||
Addr: &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
},
|
||||
Permanent: true,
|
||||
}
|
||||
go cmgr.Connect(cr)
|
||||
cmgr.Start()
|
||||
// retry in 1ms
|
||||
|
@ -331,7 +364,7 @@ func TestMaxRetryDuration(t *testing.T) {
|
|||
// failure gracefully.
|
||||
func TestNetworkFailure(t *testing.T) {
|
||||
var dials uint32
|
||||
errDialer := func(network, address string) (net.Conn, error) {
|
||||
errDialer := func(net net.Addr) (net.Conn, error) {
|
||||
atomic.AddUint32(&dials, 1)
|
||||
return nil, errors.New("network down")
|
||||
}
|
||||
|
@ -339,7 +372,12 @@ func TestNetworkFailure(t *testing.T) {
|
|||
TargetOutbound: 5,
|
||||
RetryDuration: 5 * time.Millisecond,
|
||||
Dial: errDialer,
|
||||
GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil },
|
||||
GetNewAddress: func() (net.Addr, error) {
|
||||
return &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
}, nil
|
||||
},
|
||||
OnConnection: func(c *ConnReq, conn net.Conn) {
|
||||
t.Fatalf("network failure: got unexpected connection - %v", c.Addr)
|
||||
},
|
||||
|
@ -365,7 +403,7 @@ func TestNetworkFailure(t *testing.T) {
|
|||
// the failure.
|
||||
func TestStopFailed(t *testing.T) {
|
||||
done := make(chan struct{}, 1)
|
||||
waitDialer := func(network, address string) (net.Conn, error) {
|
||||
waitDialer := func(addr net.Addr) (net.Conn, error) {
|
||||
done <- struct{}{}
|
||||
time.Sleep(time.Millisecond)
|
||||
return nil, errors.New("network down")
|
||||
|
@ -384,7 +422,13 @@ func TestStopFailed(t *testing.T) {
|
|||
atomic.StoreInt32(&cmgr.stop, 0)
|
||||
cmgr.Stop()
|
||||
}()
|
||||
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
|
||||
cr := &ConnReq{
|
||||
Addr: &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
},
|
||||
Permanent: true,
|
||||
}
|
||||
go cmgr.Connect(cr)
|
||||
cmgr.Wait()
|
||||
}
|
||||
|
@ -428,12 +472,14 @@ func (m *mockListener) Addr() net.Addr {
|
|||
// address. It will cause the Accept function to return a mock connection
|
||||
// configured with the provided remote address and the local address for the
|
||||
// mock listener.
|
||||
func (m *mockListener) Connect(remoteAddr string) {
|
||||
func (m *mockListener) Connect(ip string, port int) {
|
||||
m.provideConn <- &mockConn{
|
||||
laddr: m.localAddr,
|
||||
lnet: "tcp",
|
||||
raddr: remoteAddr,
|
||||
rnet: "tcp",
|
||||
rAddr: &net.TCPAddr{
|
||||
IP: net.ParseIP(ip),
|
||||
Port: port,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -471,8 +517,8 @@ func TestListeners(t *testing.T) {
|
|||
go func() {
|
||||
for i, listener := range listeners {
|
||||
l := listener.(*mockListener)
|
||||
l.Connect("127.0.0.1:" + strconv.Itoa(10000+i*2))
|
||||
l.Connect("127.0.0.1:" + strconv.Itoa(10000+i*2+1))
|
||||
l.Connect("127.0.0.1", 10000+i*2)
|
||||
l.Connect("127.0.0.1", 10000+i*2+1)
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
59
server.go
59
server.go
|
@ -1456,9 +1456,15 @@ func (s *server) handleQuery(state *peerState, querymsg interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
netAddr, err := addrStringToNetAddr(msg.addr)
|
||||
if err != nil {
|
||||
msg.reply <- err
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(oga) if too many, nuke a non-perm peer.
|
||||
go s.connManager.Connect(&connmgr.ConnReq{
|
||||
Addr: msg.addr,
|
||||
Addr: netAddr,
|
||||
Permanent: msg.permanent,
|
||||
})
|
||||
msg.reply <- nil
|
||||
|
@ -1603,7 +1609,7 @@ func (s *server) inboundPeerConnected(conn net.Conn) {
|
|||
// manager of the attempt.
|
||||
func (s *server) outboundPeerConnected(c *connmgr.ConnReq, conn net.Conn) {
|
||||
sp := newServerPeer(s, c.Permanent)
|
||||
p, err := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr)
|
||||
p, err := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr.String())
|
||||
if err != nil {
|
||||
srvrLog.Debugf("Cannot create outbound peer %s: %v", c.Addr, err)
|
||||
s.connManager.Disconnect(c.ID())
|
||||
|
@ -2401,9 +2407,9 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param
|
|||
// specified peers and actively avoid advertising and connecting to
|
||||
// discovered peers in order to prevent it from becoming a public test
|
||||
// network.
|
||||
var newAddressFunc func() (string, error)
|
||||
var newAddressFunc func() (net.Addr, error)
|
||||
if !cfg.SimNet && len(cfg.ConnectPeers) == 0 {
|
||||
newAddressFunc = func() (string, error) {
|
||||
newAddressFunc = func() (net.Addr, error) {
|
||||
for tries := 0; tries < 100; tries++ {
|
||||
addr := s.addrManager.GetAddress("any")
|
||||
if addr == nil {
|
||||
|
@ -2432,9 +2438,12 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param
|
|||
activeNetParams.DefaultPort && tries < 50 {
|
||||
continue
|
||||
}
|
||||
return addrmgr.NetAddressKey(addr.NetAddress()), nil
|
||||
|
||||
addrString := addrmgr.NetAddressKey(addr.NetAddress())
|
||||
return addrStringToNetAddr(addrString)
|
||||
}
|
||||
return "", errors.New("no valid connect address")
|
||||
|
||||
return nil, errors.New("no valid connect address")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2463,7 +2472,15 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param
|
|||
permanentPeers = cfg.AddPeers
|
||||
}
|
||||
for _, addr := range permanentPeers {
|
||||
go s.connManager.Connect(&connmgr.ConnReq{Addr: addr, Permanent: true})
|
||||
tcpAddr, err := addrStringToNetAddr(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.connManager.Connect(&connmgr.ConnReq{
|
||||
Addr: tcpAddr,
|
||||
Permanent: true,
|
||||
})
|
||||
}
|
||||
|
||||
if !cfg.DisableRPC {
|
||||
|
@ -2483,6 +2500,34 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param
|
|||
return &s, nil
|
||||
}
|
||||
|
||||
// addrStringToNetAddr takes an address in the form of 'host:port' and returns
|
||||
// a net.Addr which maps to the original address with any host names resolved
|
||||
// to IP addresses.
|
||||
func addrStringToNetAddr(addr string) (net.Addr, error) {
|
||||
host, strPort, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Attempt to look up an IP address associated with the parsed host.
|
||||
// The btcdLookup function will transparently handle performing the
|
||||
// lookup over Tor if necessary.
|
||||
ips, err := btcdLookup(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(strPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &net.TCPAddr{
|
||||
IP: ips[0],
|
||||
Port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dynamicTickDuration is a convenience function used to dynamically choose a
|
||||
// tick duration based on remaining time. It is primarily used during
|
||||
// server shutdown to make shutdown warnings more frequent as the shutdown time
|
||||
|
|
Loading…
Add table
Reference in a new issue