connmgr: add ability to remove pending connections
This commit adds the ability for callers to remove pending connections via a call to the Remove() method. With this change, upstream users of this package can use the connmgr for more elaborate connectivity needs as they can now cancel pending connections that are no longer needed.
This commit is contained in:
parent
ffe4c2f0ad
commit
548c0f499b
2 changed files with 92 additions and 3 deletions
|
@ -143,6 +143,15 @@ type Config struct {
|
||||||
Dial func(net.Addr) (net.Conn, error)
|
Dial func(net.Addr) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// registerPending is used to register a pending connection attempt. By
|
||||||
|
// registering pending connection attempts we allow callers to cancel pending
|
||||||
|
// connection attempts before their successful or in the case they're not
|
||||||
|
// longer wanted.
|
||||||
|
type registerPending struct {
|
||||||
|
c *ConnReq
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
// handleConnected is used to queue a successful connection.
|
// handleConnected is used to queue a successful connection.
|
||||||
type handleConnected struct {
|
type handleConnected struct {
|
||||||
c *ConnReq
|
c *ConnReq
|
||||||
|
@ -217,12 +226,17 @@ func (cm *ConnManager) handleFailedConn(c *ConnReq) {
|
||||||
// are processed and mapped by their assigned ids.
|
// are processed and mapped by their assigned ids.
|
||||||
func (cm *ConnManager) connHandler() {
|
func (cm *ConnManager) connHandler() {
|
||||||
conns := make(map[uint64]*ConnReq, cm.cfg.TargetOutbound)
|
conns := make(map[uint64]*ConnReq, cm.cfg.TargetOutbound)
|
||||||
|
pendingConns := make(map[uint64]*ConnReq)
|
||||||
out:
|
out:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case req := <-cm.requests:
|
case req := <-cm.requests:
|
||||||
switch msg := req.(type) {
|
switch msg := req.(type) {
|
||||||
|
|
||||||
|
case registerPending:
|
||||||
|
pendingConns[msg.c.id] = msg.c
|
||||||
|
close(msg.done)
|
||||||
|
|
||||||
case handleConnected:
|
case handleConnected:
|
||||||
connReq := msg.c
|
connReq := msg.c
|
||||||
connReq.updateState(ConnEstablished)
|
connReq.updateState(ConnEstablished)
|
||||||
|
@ -232,12 +246,26 @@ out:
|
||||||
connReq.retryCount = 0
|
connReq.retryCount = 0
|
||||||
cm.failedAttempts = 0
|
cm.failedAttempts = 0
|
||||||
|
|
||||||
|
delete(pendingConns, connReq.id)
|
||||||
|
|
||||||
if cm.cfg.OnConnection != nil {
|
if cm.cfg.OnConnection != nil {
|
||||||
go cm.cfg.OnConnection(connReq, msg.conn)
|
go cm.cfg.OnConnection(connReq, msg.conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
case handleDisconnected:
|
case handleDisconnected:
|
||||||
if connReq, ok := conns[msg.id]; ok {
|
connReq, ok := conns[msg.id]
|
||||||
|
if !ok {
|
||||||
|
connReq, ok = pendingConns[msg.id]
|
||||||
|
if ok && !msg.retry {
|
||||||
|
connReq.updateState(ConnFailed)
|
||||||
|
|
||||||
|
log.Debugf("Cancelling: %v", connReq)
|
||||||
|
delete(pendingConns, msg.id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if connReq != nil {
|
||||||
connReq.updateState(ConnDisconnected)
|
connReq.updateState(ConnDisconnected)
|
||||||
if connReq.conn != nil {
|
if connReq.conn != nil {
|
||||||
connReq.conn.Close()
|
connReq.conn.Close()
|
||||||
|
@ -304,8 +332,18 @@ func (cm *ConnManager) Connect(c *ConnReq) {
|
||||||
}
|
}
|
||||||
if atomic.LoadUint64(&c.id) == 0 {
|
if atomic.LoadUint64(&c.id) == 0 {
|
||||||
atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1))
|
atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1))
|
||||||
|
|
||||||
|
// Submit a request of a pending connection attempt to the
|
||||||
|
// connection manager. By registering the id before the
|
||||||
|
// connection is even established, we'll be able to later
|
||||||
|
// cancel the connection via the Remove method.
|
||||||
|
done := make(chan struct{})
|
||||||
|
cm.requests <- registerPending{c, done}
|
||||||
|
<-done
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Attempting to connect to %v", c)
|
log.Debugf("Attempting to connect to %v", c)
|
||||||
|
|
||||||
conn, err := cm.cfg.Dial(c.Addr)
|
conn, err := cm.cfg.Dial(c.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cm.requests <- handleFailed{c, err}
|
cm.requests <- handleFailed{c, err}
|
||||||
|
@ -324,8 +362,11 @@ func (cm *ConnManager) Disconnect(id uint64) {
|
||||||
cm.requests <- handleDisconnected{id, true}
|
cm.requests <- handleDisconnected{id, true}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove removes the connection corresponding to the given connection
|
// Remove removes the connection corresponding to the given connection id from
|
||||||
// id from known connections.
|
// known connections.
|
||||||
|
//
|
||||||
|
// NOTE: This method can also be used to cancel a lingering connection attempt
|
||||||
|
// that hasn't yet succeeded.
|
||||||
func (cm *ConnManager) Remove(id uint64) {
|
func (cm *ConnManager) Remove(id uint64) {
|
||||||
if atomic.LoadInt32(&cm.stop) != 0 {
|
if atomic.LoadInt32(&cm.stop) != 0 {
|
||||||
return
|
return
|
||||||
|
|
|
@ -6,8 +6,10 @@ package connmgr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"runtime"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -421,6 +423,52 @@ func TestStopFailed(t *testing.T) {
|
||||||
cmgr.Wait()
|
cmgr.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRemovePendingConnection tests that it's possible to cancel a pending
|
||||||
|
// connection, removing its internal state from the ConnMgr.
|
||||||
|
func TestRemovePendingConnection(t *testing.T) {
|
||||||
|
// Create a ConnMgr instance with an instance of a dialer that'll never
|
||||||
|
// succeed.
|
||||||
|
wait := make(chan struct{})
|
||||||
|
indefiniteDialer := func(addr net.Addr) (net.Conn, error) {
|
||||||
|
<-wait
|
||||||
|
return nil, fmt.Errorf("error")
|
||||||
|
}
|
||||||
|
cmgr, err := New(&Config{
|
||||||
|
Dial: indefiniteDialer,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New error: %v", err)
|
||||||
|
}
|
||||||
|
cmgr.Start()
|
||||||
|
|
||||||
|
// Establish a connection request to a random IP we've chosen.
|
||||||
|
cr := &ConnReq{
|
||||||
|
Addr: &net.TCPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 18555,
|
||||||
|
},
|
||||||
|
Permanent: true,
|
||||||
|
}
|
||||||
|
go cmgr.Connect(cr)
|
||||||
|
|
||||||
|
runtime.Gosched()
|
||||||
|
|
||||||
|
// The request launched above will actually never be able to establish
|
||||||
|
// a connection. So we'll cancel it _before_ it's able to be completed.
|
||||||
|
cmgr.Remove(cr.ID())
|
||||||
|
|
||||||
|
runtime.Gosched()
|
||||||
|
|
||||||
|
// Now examine the status of the connection request, it should read a
|
||||||
|
// status of failed.
|
||||||
|
if cr.State() != ConnFailed {
|
||||||
|
t.Fatalf("request wasn't cancelled, status is: %v", cr.State())
|
||||||
|
}
|
||||||
|
|
||||||
|
close(wait)
|
||||||
|
cmgr.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
// mockListener implements the net.Listener interface and is used to test
|
// mockListener implements the net.Listener interface and is used to test
|
||||||
// code that deals with net.Listeners without having to actually make any real
|
// code that deals with net.Listeners without having to actually make any real
|
||||||
// connections.
|
// connections.
|
||||||
|
|
Loading…
Reference in a new issue