diff --git a/connmgr/connmanager.go b/connmgr/connmanager.go index 5f056b33..26c8e59c 100644 --- a/connmgr/connmanager.go +++ b/connmgr/connmanager.go @@ -143,6 +143,15 @@ type Config struct { Dial func(net.Addr) (net.Conn, error) } +// registerPending is used to register a pending connection attempt. By +// registering pending connection attempts we allow callers to cancel pending +// connection attempts before their successful or in the case they're not +// longer wanted. +type registerPending struct { + c *ConnReq + done chan struct{} +} + // handleConnected is used to queue a successful connection. type handleConnected struct { c *ConnReq @@ -217,12 +226,17 @@ func (cm *ConnManager) handleFailedConn(c *ConnReq) { // are processed and mapped by their assigned ids. func (cm *ConnManager) connHandler() { conns := make(map[uint64]*ConnReq, cm.cfg.TargetOutbound) + pendingConns := make(map[uint64]*ConnReq) out: for { select { case req := <-cm.requests: switch msg := req.(type) { + case registerPending: + pendingConns[msg.c.id] = msg.c + close(msg.done) + case handleConnected: connReq := msg.c connReq.updateState(ConnEstablished) @@ -232,12 +246,26 @@ out: connReq.retryCount = 0 cm.failedAttempts = 0 + delete(pendingConns, connReq.id) + if cm.cfg.OnConnection != nil { go cm.cfg.OnConnection(connReq, msg.conn) } case handleDisconnected: - if connReq, ok := conns[msg.id]; ok { + connReq, ok := conns[msg.id] + if !ok { + connReq, ok = pendingConns[msg.id] + if ok && !msg.retry { + connReq.updateState(ConnFailed) + + log.Debugf("Cancelling: %v", connReq) + delete(pendingConns, msg.id) + return + } + } + + if connReq != nil { connReq.updateState(ConnDisconnected) if connReq.conn != nil { connReq.conn.Close() @@ -304,8 +332,18 @@ func (cm *ConnManager) Connect(c *ConnReq) { } if atomic.LoadUint64(&c.id) == 0 { atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1)) + + // Submit a request of a pending connection attempt to the + // connection manager. By registering the id before the + // connection is even established, we'll be able to later + // cancel the connection via the Remove method. + done := make(chan struct{}) + cm.requests <- registerPending{c, done} + <-done } + log.Debugf("Attempting to connect to %v", c) + conn, err := cm.cfg.Dial(c.Addr) if err != nil { cm.requests <- handleFailed{c, err} @@ -324,8 +362,11 @@ func (cm *ConnManager) Disconnect(id uint64) { cm.requests <- handleDisconnected{id, true} } -// Remove removes the connection corresponding to the given connection -// id from known connections. +// Remove removes the connection corresponding to the given connection id from +// known connections. +// +// NOTE: This method can also be used to cancel a lingering connection attempt +// that hasn't yet succeeded. func (cm *ConnManager) Remove(id uint64) { if atomic.LoadInt32(&cm.stop) != 0 { return diff --git a/connmgr/connmanager_test.go b/connmgr/connmanager_test.go index 03b6dd2e..99928931 100644 --- a/connmgr/connmanager_test.go +++ b/connmgr/connmanager_test.go @@ -6,8 +6,10 @@ package connmgr import ( "errors" + "fmt" "io" "net" + "runtime" "sync/atomic" "testing" "time" @@ -421,6 +423,52 @@ func TestStopFailed(t *testing.T) { cmgr.Wait() } +// TestRemovePendingConnection tests that it's possible to cancel a pending +// connection, removing its internal state from the ConnMgr. +func TestRemovePendingConnection(t *testing.T) { + // Create a ConnMgr instance with an instance of a dialer that'll never + // succeed. + wait := make(chan struct{}) + indefiniteDialer := func(addr net.Addr) (net.Conn, error) { + <-wait + return nil, fmt.Errorf("error") + } + cmgr, err := New(&Config{ + Dial: indefiniteDialer, + }) + if err != nil { + t.Fatalf("New error: %v", err) + } + cmgr.Start() + + // Establish a connection request to a random IP we've chosen. + cr := &ConnReq{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }, + Permanent: true, + } + go cmgr.Connect(cr) + + runtime.Gosched() + + // The request launched above will actually never be able to establish + // a connection. So we'll cancel it _before_ it's able to be completed. + cmgr.Remove(cr.ID()) + + runtime.Gosched() + + // Now examine the status of the connection request, it should read a + // status of failed. + if cr.State() != ConnFailed { + t.Fatalf("request wasn't cancelled, status is: %v", cr.State()) + } + + close(wait) + cmgr.Stop() +} + // mockListener implements the net.Listener interface and is used to test // code that deals with net.Listeners without having to actually make any real // connections.