From e8f63bc29550705268b533032ccc2ea24f8c86ba Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 3 Nov 2016 14:23:13 -0700 Subject: [PATCH] connmgr: switch to using net.Addr interface throughout for addresses This commit modifies the `ConnManager` to use the `net.Add` interface through the package instead of a plain string to represent and manipulate addresses. This change makes the package much more general as users of the package can possibly utilize custom implementations of the `net.Addr` interface to establish connections. More precisely, the `ConnReq` struct has been modified to use a net.Addr instance explicitly, and the `DialFunc` type has also been modified to take a `net.Addr` directly. This latter change gives functions that adhere to the `DialFunc` type more flexibility as to exactly how the connection is established. Additionally, the `connmgr.Config.GetNewAddress` configuration option now directly returns a `net.Addr. This change allows the `connmgr` to be decoupled from all DNS queries which allows callers to preferentially select more secure methods like performing DNS lookups over a Tor proxy. --- config.go | 8 ++-- connmgr/connmanager.go | 14 ++++-- connmgr/connmanager_test.go | 92 +++++++++++++++++++++++++++---------- server.go | 59 +++++++++++++++++++++--- 4 files changed, 134 insertions(+), 39 deletions(-) diff --git a/config.go b/config.go index e51f2df9..fe4fba81 100644 --- a/config.go +++ b/config.go @@ -1001,11 +1001,11 @@ func createDefaultConfigFile(destinationPath string) error { // example, .onion addresses will be dialed using the onion specific proxy if // one was specified, but will otherwise use the normal dial function (which // could itself use a proxy or not). -func btcdDial(network, address string) (net.Conn, error) { - if strings.Contains(address, ".onion:") { - return cfg.oniondial(network, address) +func btcdDial(addr net.Addr) (net.Conn, error) { + if strings.Contains(addr.String(), ".onion:") { + return cfg.oniondial(addr.Network(), addr.String()) } - return cfg.dial(network, address) + return cfg.dial(addr.Network(), addr.String()) } // btcdLookup returns the correct DNS lookup function to use depending on the diff --git a/connmgr/connmanager.go b/connmgr/connmanager.go index 33133734..5f056b33 100644 --- a/connmgr/connmanager.go +++ b/connmgr/connmanager.go @@ -57,7 +57,7 @@ type ConnReq struct { // The following variables must only be used atomically. id uint64 - Addr string + Addr net.Addr Permanent bool conn net.Conn @@ -88,7 +88,7 @@ func (c *ConnReq) State() ConnState { // String returns a human-readable string for the connection request. func (c *ConnReq) String() string { - if c.Addr == "" { + if c.Addr.String() == "" { return fmt.Sprintf("reqid %d", atomic.LoadUint64(&c.id)) } return fmt.Sprintf("%s (reqid %d)", c.Addr, atomic.LoadUint64(&c.id)) @@ -137,10 +137,10 @@ type Config struct { // GetNewAddress is a way to get an address to make a network connection // to. If nil, no new connections will be made automatically. - GetNewAddress func() (string, error) + GetNewAddress func() (net.Addr, error) // Dial connects to the address on the named network. It cannot be nil. - Dial func(string, string) (net.Conn, error) + Dial func(net.Addr) (net.Conn, error) } // handleConnected is used to queue a successful connection. @@ -281,14 +281,18 @@ func (cm *ConnManager) NewConnReq() { if cm.cfg.GetNewAddress == nil { return } + c := &ConnReq{} atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1)) + addr, err := cm.cfg.GetNewAddress() if err != nil { cm.requests <- handleFailed{c, err} return } + c.Addr = addr + cm.Connect(c) } @@ -302,7 +306,7 @@ func (cm *ConnManager) Connect(c *ConnReq) { atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1)) } log.Debugf("Attempting to connect to %v", c) - conn, err := cm.cfg.Dial("tcp", c.Addr) + conn, err := cm.cfg.Dial(c.Addr) if err != nil { cm.requests <- handleFailed{c, err} } else { diff --git a/connmgr/connmanager_test.go b/connmgr/connmanager_test.go index c3408371..c947b2b7 100644 --- a/connmgr/connmanager_test.go +++ b/connmgr/connmanager_test.go @@ -9,7 +9,6 @@ import ( "errors" "io" "net" - "strconv" "sync/atomic" "testing" "time" @@ -40,7 +39,7 @@ type mockConn struct { lnet, laddr string // remote network, address for the connection. - rnet, raddr string + rAddr net.Addr } // LocalAddr returns the local address for the connection. @@ -50,7 +49,7 @@ func (c mockConn) LocalAddr() net.Addr { // RemoteAddr returns the remote address for the connection. func (c mockConn) RemoteAddr() net.Addr { - return &mockAddr{c.rnet, c.raddr} + return &mockAddr{c.rAddr.Network(), c.rAddr.String()} } // Close handles closing the connection. @@ -64,9 +63,9 @@ func (c mockConn) SetWriteDeadline(t time.Time) error { return nil } // mockDialer mocks the net.Dial interface by returning a mock connection to // the given address. -func mockDialer(network, address string) (net.Conn, error) { +func mockDialer(addr net.Addr) (net.Conn, error) { r, w := io.Pipe() - c := &mockConn{raddr: address} + c := &mockConn{rAddr: addr} c.Reader = r c.Writer = w return c, nil @@ -102,8 +101,13 @@ func TestStartStop(t *testing.T) { disconnected := make(chan *ConnReq) cmgr, err := New(&Config{ TargetOutbound: 1, - GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil }, - Dial: mockDialer, + GetNewAddress: func() (net.Addr, error) { + return &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }, nil + }, + Dial: mockDialer, OnConnection: func(c *ConnReq, conn net.Conn) { connected <- c }, @@ -120,7 +124,13 @@ func TestStartStop(t *testing.T) { // already stopped cmgr.Stop() // ignored - cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true} + cr := &ConnReq{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }, + Permanent: true, + } cmgr.Connect(cr) if cr.ID() != 0 { t.Fatalf("start/stop: got id: %v, want: 0", cr.ID()) @@ -151,7 +161,13 @@ func TestConnectMode(t *testing.T) { if err != nil { t.Fatalf("New error: %v", err) } - cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true} + cr := &ConnReq{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }, + Permanent: true, + } cmgr.Start() cmgr.Connect(cr) gotConnReq := <-connected @@ -184,7 +200,12 @@ func TestTargetOutbound(t *testing.T) { cmgr, err := New(&Config{ TargetOutbound: targetOutbound, Dial: mockDialer, - GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil }, + GetNewAddress: func() (net.Addr, error) { + return &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }, nil + }, OnConnection: func(c *ConnReq, conn net.Conn) { connected <- c }, @@ -228,7 +249,13 @@ func TestRetryPermanent(t *testing.T) { t.Fatalf("New error: %v", err) } - cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true} + cr := &ConnReq{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }, + Permanent: true, + } go cmgr.Connect(cr) cmgr.Start() gotConnReq := <-connected @@ -292,10 +319,10 @@ func TestMaxRetryDuration(t *testing.T) { time.AfterFunc(5*time.Millisecond, func() { close(networkUp) }) - timedDialer := func(network, address string) (net.Conn, error) { + timedDialer := func(addr net.Addr) (net.Conn, error) { select { case <-networkUp: - return mockDialer(network, address) + return mockDialer(addr) default: return nil, errors.New("network down") } @@ -314,7 +341,13 @@ func TestMaxRetryDuration(t *testing.T) { t.Fatalf("New error: %v", err) } - cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true} + cr := &ConnReq{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }, + Permanent: true, + } go cmgr.Connect(cr) cmgr.Start() // retry in 1ms @@ -331,7 +364,7 @@ func TestMaxRetryDuration(t *testing.T) { // failure gracefully. func TestNetworkFailure(t *testing.T) { var dials uint32 - errDialer := func(network, address string) (net.Conn, error) { + errDialer := func(net net.Addr) (net.Conn, error) { atomic.AddUint32(&dials, 1) return nil, errors.New("network down") } @@ -339,7 +372,12 @@ func TestNetworkFailure(t *testing.T) { TargetOutbound: 5, RetryDuration: 5 * time.Millisecond, Dial: errDialer, - GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil }, + GetNewAddress: func() (net.Addr, error) { + return &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }, nil + }, OnConnection: func(c *ConnReq, conn net.Conn) { t.Fatalf("network failure: got unexpected connection - %v", c.Addr) }, @@ -365,7 +403,7 @@ func TestNetworkFailure(t *testing.T) { // the failure. func TestStopFailed(t *testing.T) { done := make(chan struct{}, 1) - waitDialer := func(network, address string) (net.Conn, error) { + waitDialer := func(addr net.Addr) (net.Conn, error) { done <- struct{}{} time.Sleep(time.Millisecond) return nil, errors.New("network down") @@ -384,7 +422,13 @@ func TestStopFailed(t *testing.T) { atomic.StoreInt32(&cmgr.stop, 0) cmgr.Stop() }() - cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true} + cr := &ConnReq{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }, + Permanent: true, + } go cmgr.Connect(cr) cmgr.Wait() } @@ -428,12 +472,14 @@ func (m *mockListener) Addr() net.Addr { // address. It will cause the Accept function to return a mock connection // configured with the provided remote address and the local address for the // mock listener. -func (m *mockListener) Connect(remoteAddr string) { +func (m *mockListener) Connect(ip string, port int) { m.provideConn <- &mockConn{ laddr: m.localAddr, lnet: "tcp", - raddr: remoteAddr, - rnet: "tcp", + rAddr: &net.TCPAddr{ + IP: net.ParseIP(ip), + Port: port, + }, } } @@ -471,8 +517,8 @@ func TestListeners(t *testing.T) { go func() { for i, listener := range listeners { l := listener.(*mockListener) - l.Connect("127.0.0.1:" + strconv.Itoa(10000+i*2)) - l.Connect("127.0.0.1:" + strconv.Itoa(10000+i*2+1)) + l.Connect("127.0.0.1", 10000+i*2) + l.Connect("127.0.0.1", 10000+i*2+1) } }() diff --git a/server.go b/server.go index 8d210db8..e18b373c 100644 --- a/server.go +++ b/server.go @@ -1456,9 +1456,15 @@ func (s *server) handleQuery(state *peerState, querymsg interface{}) { } } + netAddr, err := addrStringToNetAddr(msg.addr) + if err != nil { + msg.reply <- err + return + } + // TODO(oga) if too many, nuke a non-perm peer. go s.connManager.Connect(&connmgr.ConnReq{ - Addr: msg.addr, + Addr: netAddr, Permanent: msg.permanent, }) msg.reply <- nil @@ -1603,7 +1609,7 @@ func (s *server) inboundPeerConnected(conn net.Conn) { // manager of the attempt. func (s *server) outboundPeerConnected(c *connmgr.ConnReq, conn net.Conn) { sp := newServerPeer(s, c.Permanent) - p, err := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr) + p, err := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr.String()) if err != nil { srvrLog.Debugf("Cannot create outbound peer %s: %v", c.Addr, err) s.connManager.Disconnect(c.ID()) @@ -2401,9 +2407,9 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param // specified peers and actively avoid advertising and connecting to // discovered peers in order to prevent it from becoming a public test // network. - var newAddressFunc func() (string, error) + var newAddressFunc func() (net.Addr, error) if !cfg.SimNet && len(cfg.ConnectPeers) == 0 { - newAddressFunc = func() (string, error) { + newAddressFunc = func() (net.Addr, error) { for tries := 0; tries < 100; tries++ { addr := s.addrManager.GetAddress("any") if addr == nil { @@ -2432,9 +2438,12 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param activeNetParams.DefaultPort && tries < 50 { continue } - return addrmgr.NetAddressKey(addr.NetAddress()), nil + + addrString := addrmgr.NetAddressKey(addr.NetAddress()) + return addrStringToNetAddr(addrString) } - return "", errors.New("no valid connect address") + + return nil, errors.New("no valid connect address") } } @@ -2463,7 +2472,15 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param permanentPeers = cfg.AddPeers } for _, addr := range permanentPeers { - go s.connManager.Connect(&connmgr.ConnReq{Addr: addr, Permanent: true}) + tcpAddr, err := addrStringToNetAddr(addr) + if err != nil { + return nil, err + } + + go s.connManager.Connect(&connmgr.ConnReq{ + Addr: tcpAddr, + Permanent: true, + }) } if !cfg.DisableRPC { @@ -2483,6 +2500,34 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param return &s, nil } +// addrStringToNetAddr takes an address in the form of 'host:port' and returns +// a net.Addr which maps to the original address with any host names resolved +// to IP addresses. +func addrStringToNetAddr(addr string) (net.Addr, error) { + host, strPort, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + // Attempt to look up an IP address associated with the parsed host. + // The btcdLookup function will transparently handle performing the + // lookup over Tor if necessary. + ips, err := btcdLookup(host) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(strPort) + if err != nil { + return nil, err + } + + return &net.TCPAddr{ + IP: ips[0], + Port: port, + }, nil +} + // dynamicTickDuration is a convenience function used to dynamically choose a // tick duration based on remaining time. It is primarily used during // server shutdown to make shutdown warnings more frequent as the shutdown time