connmgr: switch to using net.Addr interface throughout for addresses

This commit modifies the `ConnManager` to use the `net.Add` interface
through the package instead of a plain string to represent and
manipulate addresses. This change makes the package much more general as
users of the package can possibly utilize custom implementations of the
`net.Addr` interface to establish connections.

More precisely, the `ConnReq` struct has been modified to use a net.Addr
instance explicitly, and the `DialFunc` type has also been modified to
take a `net.Addr` directly. This latter change gives functions that
adhere to the `DialFunc` type more flexibility as to exactly how the
connection is established.

Additionally, the `connmgr.Config.GetNewAddress` configuration option
now directly returns a `net.Addr. This change allows the `connmgr` to be
decoupled from all DNS queries which allows callers to preferentially
select more secure methods like performing DNS lookups over a Tor proxy.
This commit is contained in:
Olaoluwa Osuntokun 2016-11-03 14:23:13 -07:00
parent df33d4340e
commit e8f63bc295
No known key found for this signature in database
GPG key ID: 9CC5B105D03521A2
4 changed files with 134 additions and 39 deletions

View file

@ -1001,11 +1001,11 @@ func createDefaultConfigFile(destinationPath string) error {
// example, .onion addresses will be dialed using the onion specific proxy if
// one was specified, but will otherwise use the normal dial function (which
// could itself use a proxy or not).
func btcdDial(network, address string) (net.Conn, error) {
if strings.Contains(address, ".onion:") {
return cfg.oniondial(network, address)
func btcdDial(addr net.Addr) (net.Conn, error) {
if strings.Contains(addr.String(), ".onion:") {
return cfg.oniondial(addr.Network(), addr.String())
}
return cfg.dial(network, address)
return cfg.dial(addr.Network(), addr.String())
}
// btcdLookup returns the correct DNS lookup function to use depending on the

View file

@ -57,7 +57,7 @@ type ConnReq struct {
// The following variables must only be used atomically.
id uint64
Addr string
Addr net.Addr
Permanent bool
conn net.Conn
@ -88,7 +88,7 @@ func (c *ConnReq) State() ConnState {
// String returns a human-readable string for the connection request.
func (c *ConnReq) String() string {
if c.Addr == "" {
if c.Addr.String() == "" {
return fmt.Sprintf("reqid %d", atomic.LoadUint64(&c.id))
}
return fmt.Sprintf("%s (reqid %d)", c.Addr, atomic.LoadUint64(&c.id))
@ -137,10 +137,10 @@ type Config struct {
// GetNewAddress is a way to get an address to make a network connection
// to. If nil, no new connections will be made automatically.
GetNewAddress func() (string, error)
GetNewAddress func() (net.Addr, error)
// Dial connects to the address on the named network. It cannot be nil.
Dial func(string, string) (net.Conn, error)
Dial func(net.Addr) (net.Conn, error)
}
// handleConnected is used to queue a successful connection.
@ -281,14 +281,18 @@ func (cm *ConnManager) NewConnReq() {
if cm.cfg.GetNewAddress == nil {
return
}
c := &ConnReq{}
atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1))
addr, err := cm.cfg.GetNewAddress()
if err != nil {
cm.requests <- handleFailed{c, err}
return
}
c.Addr = addr
cm.Connect(c)
}
@ -302,7 +306,7 @@ func (cm *ConnManager) Connect(c *ConnReq) {
atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1))
}
log.Debugf("Attempting to connect to %v", c)
conn, err := cm.cfg.Dial("tcp", c.Addr)
conn, err := cm.cfg.Dial(c.Addr)
if err != nil {
cm.requests <- handleFailed{c, err}
} else {

View file

@ -9,7 +9,6 @@ import (
"errors"
"io"
"net"
"strconv"
"sync/atomic"
"testing"
"time"
@ -40,7 +39,7 @@ type mockConn struct {
lnet, laddr string
// remote network, address for the connection.
rnet, raddr string
rAddr net.Addr
}
// LocalAddr returns the local address for the connection.
@ -50,7 +49,7 @@ func (c mockConn) LocalAddr() net.Addr {
// RemoteAddr returns the remote address for the connection.
func (c mockConn) RemoteAddr() net.Addr {
return &mockAddr{c.rnet, c.raddr}
return &mockAddr{c.rAddr.Network(), c.rAddr.String()}
}
// Close handles closing the connection.
@ -64,9 +63,9 @@ func (c mockConn) SetWriteDeadline(t time.Time) error { return nil }
// mockDialer mocks the net.Dial interface by returning a mock connection to
// the given address.
func mockDialer(network, address string) (net.Conn, error) {
func mockDialer(addr net.Addr) (net.Conn, error) {
r, w := io.Pipe()
c := &mockConn{raddr: address}
c := &mockConn{rAddr: addr}
c.Reader = r
c.Writer = w
return c, nil
@ -102,7 +101,12 @@ func TestStartStop(t *testing.T) {
disconnected := make(chan *ConnReq)
cmgr, err := New(&Config{
TargetOutbound: 1,
GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil },
GetNewAddress: func() (net.Addr, error) {
return &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}, nil
},
Dial: mockDialer,
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
@ -120,7 +124,13 @@ func TestStartStop(t *testing.T) {
// already stopped
cmgr.Stop()
// ignored
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
cr := &ConnReq{
Addr: &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
},
Permanent: true,
}
cmgr.Connect(cr)
if cr.ID() != 0 {
t.Fatalf("start/stop: got id: %v, want: 0", cr.ID())
@ -151,7 +161,13 @@ func TestConnectMode(t *testing.T) {
if err != nil {
t.Fatalf("New error: %v", err)
}
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
cr := &ConnReq{
Addr: &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
},
Permanent: true,
}
cmgr.Start()
cmgr.Connect(cr)
gotConnReq := <-connected
@ -184,7 +200,12 @@ func TestTargetOutbound(t *testing.T) {
cmgr, err := New(&Config{
TargetOutbound: targetOutbound,
Dial: mockDialer,
GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil },
GetNewAddress: func() (net.Addr, error) {
return &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}, nil
},
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
},
@ -228,7 +249,13 @@ func TestRetryPermanent(t *testing.T) {
t.Fatalf("New error: %v", err)
}
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
cr := &ConnReq{
Addr: &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
},
Permanent: true,
}
go cmgr.Connect(cr)
cmgr.Start()
gotConnReq := <-connected
@ -292,10 +319,10 @@ func TestMaxRetryDuration(t *testing.T) {
time.AfterFunc(5*time.Millisecond, func() {
close(networkUp)
})
timedDialer := func(network, address string) (net.Conn, error) {
timedDialer := func(addr net.Addr) (net.Conn, error) {
select {
case <-networkUp:
return mockDialer(network, address)
return mockDialer(addr)
default:
return nil, errors.New("network down")
}
@ -314,7 +341,13 @@ func TestMaxRetryDuration(t *testing.T) {
t.Fatalf("New error: %v", err)
}
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
cr := &ConnReq{
Addr: &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
},
Permanent: true,
}
go cmgr.Connect(cr)
cmgr.Start()
// retry in 1ms
@ -331,7 +364,7 @@ func TestMaxRetryDuration(t *testing.T) {
// failure gracefully.
func TestNetworkFailure(t *testing.T) {
var dials uint32
errDialer := func(network, address string) (net.Conn, error) {
errDialer := func(net net.Addr) (net.Conn, error) {
atomic.AddUint32(&dials, 1)
return nil, errors.New("network down")
}
@ -339,7 +372,12 @@ func TestNetworkFailure(t *testing.T) {
TargetOutbound: 5,
RetryDuration: 5 * time.Millisecond,
Dial: errDialer,
GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil },
GetNewAddress: func() (net.Addr, error) {
return &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}, nil
},
OnConnection: func(c *ConnReq, conn net.Conn) {
t.Fatalf("network failure: got unexpected connection - %v", c.Addr)
},
@ -365,7 +403,7 @@ func TestNetworkFailure(t *testing.T) {
// the failure.
func TestStopFailed(t *testing.T) {
done := make(chan struct{}, 1)
waitDialer := func(network, address string) (net.Conn, error) {
waitDialer := func(addr net.Addr) (net.Conn, error) {
done <- struct{}{}
time.Sleep(time.Millisecond)
return nil, errors.New("network down")
@ -384,7 +422,13 @@ func TestStopFailed(t *testing.T) {
atomic.StoreInt32(&cmgr.stop, 0)
cmgr.Stop()
}()
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
cr := &ConnReq{
Addr: &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
},
Permanent: true,
}
go cmgr.Connect(cr)
cmgr.Wait()
}
@ -428,12 +472,14 @@ func (m *mockListener) Addr() net.Addr {
// address. It will cause the Accept function to return a mock connection
// configured with the provided remote address and the local address for the
// mock listener.
func (m *mockListener) Connect(remoteAddr string) {
func (m *mockListener) Connect(ip string, port int) {
m.provideConn <- &mockConn{
laddr: m.localAddr,
lnet: "tcp",
raddr: remoteAddr,
rnet: "tcp",
rAddr: &net.TCPAddr{
IP: net.ParseIP(ip),
Port: port,
},
}
}
@ -471,8 +517,8 @@ func TestListeners(t *testing.T) {
go func() {
for i, listener := range listeners {
l := listener.(*mockListener)
l.Connect("127.0.0.1:" + strconv.Itoa(10000+i*2))
l.Connect("127.0.0.1:" + strconv.Itoa(10000+i*2+1))
l.Connect("127.0.0.1", 10000+i*2)
l.Connect("127.0.0.1", 10000+i*2+1)
}
}()

View file

@ -1456,9 +1456,15 @@ func (s *server) handleQuery(state *peerState, querymsg interface{}) {
}
}
netAddr, err := addrStringToNetAddr(msg.addr)
if err != nil {
msg.reply <- err
return
}
// TODO(oga) if too many, nuke a non-perm peer.
go s.connManager.Connect(&connmgr.ConnReq{
Addr: msg.addr,
Addr: netAddr,
Permanent: msg.permanent,
})
msg.reply <- nil
@ -1603,7 +1609,7 @@ func (s *server) inboundPeerConnected(conn net.Conn) {
// manager of the attempt.
func (s *server) outboundPeerConnected(c *connmgr.ConnReq, conn net.Conn) {
sp := newServerPeer(s, c.Permanent)
p, err := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr)
p, err := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr.String())
if err != nil {
srvrLog.Debugf("Cannot create outbound peer %s: %v", c.Addr, err)
s.connManager.Disconnect(c.ID())
@ -2401,9 +2407,9 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param
// specified peers and actively avoid advertising and connecting to
// discovered peers in order to prevent it from becoming a public test
// network.
var newAddressFunc func() (string, error)
var newAddressFunc func() (net.Addr, error)
if !cfg.SimNet && len(cfg.ConnectPeers) == 0 {
newAddressFunc = func() (string, error) {
newAddressFunc = func() (net.Addr, error) {
for tries := 0; tries < 100; tries++ {
addr := s.addrManager.GetAddress("any")
if addr == nil {
@ -2432,9 +2438,12 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param
activeNetParams.DefaultPort && tries < 50 {
continue
}
return addrmgr.NetAddressKey(addr.NetAddress()), nil
addrString := addrmgr.NetAddressKey(addr.NetAddress())
return addrStringToNetAddr(addrString)
}
return "", errors.New("no valid connect address")
return nil, errors.New("no valid connect address")
}
}
@ -2463,7 +2472,15 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param
permanentPeers = cfg.AddPeers
}
for _, addr := range permanentPeers {
go s.connManager.Connect(&connmgr.ConnReq{Addr: addr, Permanent: true})
tcpAddr, err := addrStringToNetAddr(addr)
if err != nil {
return nil, err
}
go s.connManager.Connect(&connmgr.ConnReq{
Addr: tcpAddr,
Permanent: true,
})
}
if !cfg.DisableRPC {
@ -2483,6 +2500,34 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param
return &s, nil
}
// addrStringToNetAddr takes an address in the form of 'host:port' and returns
// a net.Addr which maps to the original address with any host names resolved
// to IP addresses.
func addrStringToNetAddr(addr string) (net.Addr, error) {
host, strPort, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
// Attempt to look up an IP address associated with the parsed host.
// The btcdLookup function will transparently handle performing the
// lookup over Tor if necessary.
ips, err := btcdLookup(host)
if err != nil {
return nil, err
}
port, err := strconv.Atoi(strPort)
if err != nil {
return nil, err
}
return &net.TCPAddr{
IP: ips[0],
Port: port,
}, nil
}
// dynamicTickDuration is a convenience function used to dynamically choose a
// tick duration based on remaining time. It is primarily used during
// server shutdown to make shutdown warnings more frequent as the shutdown time