390 lines
10 KiB
Go
390 lines
10 KiB
Go
|
// Copyright (c) 2016 The btcsuite developers
|
||
|
// Use of this source code is governed by an ISC
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package connmgr
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"net"
|
||
|
"sync/atomic"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/btcsuite/btclog"
|
||
|
)
|
||
|
|
||
|
func init() {
|
||
|
// Override the max retry duration when running tests.
|
||
|
maxRetryDuration = 2 * time.Millisecond
|
||
|
}
|
||
|
|
||
|
// mockAddr is a minimal net.Addr implementation that simply reports the
// network name and address string it was constructed with.
type mockAddr struct {
	net, address string
}

// Network returns the stored network name.
func (a mockAddr) Network() string {
	return a.net
}

// String returns the stored address string.
func (a mockAddr) String() string {
	return a.address
}
|
||
|
|
||
|
// mockConn mocks a network connection by implementing the net.Conn interface.
|
||
|
type mockConn struct {
|
||
|
io.Reader
|
||
|
io.Writer
|
||
|
io.Closer
|
||
|
|
||
|
// local network, address for the connection.
|
||
|
lnet, laddr string
|
||
|
|
||
|
// remote network, address for the connection.
|
||
|
rnet, raddr string
|
||
|
}
|
||
|
|
||
|
// LocalAddr returns the local address for the connection.
|
||
|
func (c mockConn) LocalAddr() net.Addr {
|
||
|
return &mockAddr{c.lnet, c.laddr}
|
||
|
}
|
||
|
|
||
|
// RemoteAddr returns the remote address for the connection.
|
||
|
func (c mockConn) RemoteAddr() net.Addr {
|
||
|
return &mockAddr{c.rnet, c.raddr}
|
||
|
}
|
||
|
|
||
|
// Close handles closing the connection.
|
||
|
func (c mockConn) Close() error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c mockConn) SetDeadline(t time.Time) error { return nil }
|
||
|
func (c mockConn) SetReadDeadline(t time.Time) error { return nil }
|
||
|
func (c mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||
|
|
||
|
// mockDialer mocks the net.Dial interface by returning a mock connection to
|
||
|
// the given address.
|
||
|
func mockDialer(network, address string) (net.Conn, error) {
|
||
|
r, w := io.Pipe()
|
||
|
c := &mockConn{raddr: address}
|
||
|
c.Reader = r
|
||
|
c.Writer = w
|
||
|
return c, nil
|
||
|
}
|
||
|
|
||
|
// TestNewConfig tests that new ConnManager config is validated as expected.
|
||
|
func TestNewConfig(t *testing.T) {
|
||
|
_, err := New(&Config{})
|
||
|
if err == nil {
|
||
|
t.Fatalf("New expected error: 'Dial can't be nil', got nil")
|
||
|
}
|
||
|
_, err = New(&Config{
|
||
|
Dial: mockDialer,
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatalf("New unexpected error: %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// TestUseLogger tests that a logger can be passed to UseLogger
|
||
|
func TestUseLogger(t *testing.T) {
|
||
|
l, err := btclog.NewLoggerFromWriter(bytes.NewBuffer(nil), btclog.InfoLvl)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
UseLogger(l)
|
||
|
}
|
||
|
|
||
|
// TestStartStop tests that the connection manager starts and stops as
|
||
|
// expected.
|
||
|
func TestStartStop(t *testing.T) {
|
||
|
connected := make(chan *ConnReq)
|
||
|
disconnected := make(chan *ConnReq)
|
||
|
cmgr, err := New(&Config{
|
||
|
MaxOutbound: 1,
|
||
|
GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil },
|
||
|
Dial: mockDialer,
|
||
|
OnConnection: func(c *ConnReq, conn net.Conn) {
|
||
|
connected <- c
|
||
|
},
|
||
|
OnDisconnection: func(c *ConnReq) {
|
||
|
disconnected <- c
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatalf("New error: %v", err)
|
||
|
}
|
||
|
cmgr.Start()
|
||
|
gotConnReq := <-connected
|
||
|
cmgr.Stop()
|
||
|
// already stopped
|
||
|
cmgr.Stop()
|
||
|
// ignored
|
||
|
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
|
||
|
cmgr.Connect(cr)
|
||
|
if cr.ID() != 0 {
|
||
|
t.Fatalf("start/stop: got id: %v, want: 0", cr.ID())
|
||
|
}
|
||
|
cmgr.Disconnect(gotConnReq.ID())
|
||
|
cmgr.Remove(gotConnReq.ID())
|
||
|
select {
|
||
|
case <-disconnected:
|
||
|
t.Fatalf("start/stop: unexpected disconnection")
|
||
|
case <-time.Tick(10 * time.Millisecond):
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// TestConnectMode tests that the connection manager works in the connect mode.
|
||
|
//
|
||
|
// In connect mode, automatic connections are disabled, so we test that
|
||
|
// requests using Connect are handled and that no other connections are made.
|
||
|
func TestConnectMode(t *testing.T) {
|
||
|
connected := make(chan *ConnReq)
|
||
|
cmgr, err := New(&Config{
|
||
|
MaxOutbound: 2,
|
||
|
Dial: mockDialer,
|
||
|
OnConnection: func(c *ConnReq, conn net.Conn) {
|
||
|
connected <- c
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatalf("New error: %v", err)
|
||
|
}
|
||
|
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
|
||
|
cmgr.Start()
|
||
|
cmgr.Connect(cr)
|
||
|
gotConnReq := <-connected
|
||
|
wantID := cr.ID()
|
||
|
gotID := gotConnReq.ID()
|
||
|
if gotID != wantID {
|
||
|
t.Fatalf("connect mode: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID)
|
||
|
}
|
||
|
gotState := cr.State()
|
||
|
wantState := ConnEstablished
|
||
|
if gotState != wantState {
|
||
|
t.Fatalf("connect mode: %v - want state %v, got state %v", cr.Addr, wantState, gotState)
|
||
|
}
|
||
|
select {
|
||
|
case c := <-connected:
|
||
|
t.Fatalf("connect mode: got unexpected connection - %v", c.Addr)
|
||
|
case <-time.After(time.Millisecond):
|
||
|
break
|
||
|
}
|
||
|
cmgr.Stop()
|
||
|
}
|
||
|
|
||
|
// TestMaxOutbound tests the maximum number of outbound connections.
|
||
|
//
|
||
|
// We wait until all connections are established, then test they there are the
|
||
|
// only connections made.
|
||
|
func TestMaxOutbound(t *testing.T) {
|
||
|
maxOutbound := uint32(10)
|
||
|
connected := make(chan *ConnReq)
|
||
|
cmgr, err := New(&Config{
|
||
|
MaxOutbound: maxOutbound,
|
||
|
Dial: mockDialer,
|
||
|
GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil },
|
||
|
OnConnection: func(c *ConnReq, conn net.Conn) {
|
||
|
connected <- c
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatalf("New error: %v", err)
|
||
|
}
|
||
|
cmgr.Start()
|
||
|
for i := uint32(0); i < maxOutbound; i++ {
|
||
|
<-connected
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case c := <-connected:
|
||
|
t.Fatalf("max outbound: got unexpected connection - %v", c.Addr)
|
||
|
case <-time.After(time.Millisecond):
|
||
|
break
|
||
|
}
|
||
|
cmgr.Stop()
|
||
|
}
|
||
|
|
||
|
// TestRetryPermanent tests that permanent connection requests are retried.
|
||
|
//
|
||
|
// We make a permanent connection request using Connect, disconnect it using
|
||
|
// Disconnect and we wait for it to be connected back.
|
||
|
func TestRetryPermanent(t *testing.T) {
|
||
|
connected := make(chan *ConnReq)
|
||
|
disconnected := make(chan *ConnReq)
|
||
|
cmgr, err := New(&Config{
|
||
|
RetryDuration: time.Millisecond,
|
||
|
MaxOutbound: 1,
|
||
|
Dial: mockDialer,
|
||
|
OnConnection: func(c *ConnReq, conn net.Conn) {
|
||
|
connected <- c
|
||
|
},
|
||
|
OnDisconnection: func(c *ConnReq) {
|
||
|
disconnected <- c
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatalf("New error: %v", err)
|
||
|
}
|
||
|
|
||
|
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
|
||
|
go cmgr.Connect(cr)
|
||
|
cmgr.Start()
|
||
|
gotConnReq := <-connected
|
||
|
wantID := cr.ID()
|
||
|
gotID := gotConnReq.ID()
|
||
|
if gotID != wantID {
|
||
|
t.Fatalf("retry: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID)
|
||
|
}
|
||
|
gotState := cr.State()
|
||
|
wantState := ConnEstablished
|
||
|
if gotState != wantState {
|
||
|
t.Fatalf("retry: %v - want state %v, got state %v", cr.Addr, wantState, gotState)
|
||
|
}
|
||
|
|
||
|
cmgr.Disconnect(cr.ID())
|
||
|
gotConnReq = <-disconnected
|
||
|
wantID = cr.ID()
|
||
|
gotID = gotConnReq.ID()
|
||
|
if gotID != wantID {
|
||
|
t.Fatalf("retry: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID)
|
||
|
}
|
||
|
gotState = cr.State()
|
||
|
wantState = ConnDisconnected
|
||
|
if gotState != wantState {
|
||
|
t.Fatalf("retry: %v - want state %v, got state %v", cr.Addr, wantState, gotState)
|
||
|
}
|
||
|
|
||
|
gotConnReq = <-connected
|
||
|
wantID = cr.ID()
|
||
|
gotID = gotConnReq.ID()
|
||
|
if gotID != wantID {
|
||
|
t.Fatalf("retry: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID)
|
||
|
}
|
||
|
gotState = cr.State()
|
||
|
wantState = ConnEstablished
|
||
|
if gotState != wantState {
|
||
|
t.Fatalf("retry: %v - want state %v, got state %v", cr.Addr, wantState, gotState)
|
||
|
}
|
||
|
|
||
|
cmgr.Remove(cr.ID())
|
||
|
gotConnReq = <-disconnected
|
||
|
wantID = cr.ID()
|
||
|
gotID = gotConnReq.ID()
|
||
|
if gotID != wantID {
|
||
|
t.Fatalf("retry: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID)
|
||
|
}
|
||
|
gotState = cr.State()
|
||
|
wantState = ConnDisconnected
|
||
|
if gotState != wantState {
|
||
|
t.Fatalf("retry: %v - want state %v, got state %v", cr.Addr, wantState, gotState)
|
||
|
}
|
||
|
cmgr.Stop()
|
||
|
}
|
||
|
|
||
|
// TestMaxRetryDuration tests the maximum retry duration.
|
||
|
//
|
||
|
// We have a timed dialer which initially returns err but after RetryDuration
|
||
|
// hits maxRetryDuration returns a mock conn.
|
||
|
func TestMaxRetryDuration(t *testing.T) {
|
||
|
networkUp := make(chan struct{})
|
||
|
time.AfterFunc(5*time.Millisecond, func() {
|
||
|
close(networkUp)
|
||
|
})
|
||
|
timedDialer := func(network, address string) (net.Conn, error) {
|
||
|
select {
|
||
|
case <-networkUp:
|
||
|
return mockDialer(network, address)
|
||
|
default:
|
||
|
return nil, errors.New("network down")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
connected := make(chan *ConnReq)
|
||
|
cmgr, err := New(&Config{
|
||
|
RetryDuration: time.Millisecond,
|
||
|
MaxOutbound: 1,
|
||
|
Dial: timedDialer,
|
||
|
OnConnection: func(c *ConnReq, conn net.Conn) {
|
||
|
connected <- c
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatalf("New error: %v", err)
|
||
|
}
|
||
|
|
||
|
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
|
||
|
go cmgr.Connect(cr)
|
||
|
cmgr.Start()
|
||
|
// retry in 1ms
|
||
|
// retry in 2ms - max retry duration reached
|
||
|
// retry in 2ms - timedDialer returns mockDial
|
||
|
select {
|
||
|
case <-connected:
|
||
|
case <-time.Tick(100 * time.Millisecond):
|
||
|
t.Fatalf("max retry duration: connection timeout")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// TestNetworkFailure tests that the connection manager handles a network
|
||
|
// failure gracefully.
|
||
|
func TestNetworkFailure(t *testing.T) {
|
||
|
var dials uint32
|
||
|
errDialer := func(network, address string) (net.Conn, error) {
|
||
|
atomic.AddUint32(&dials, 1)
|
||
|
return nil, errors.New("network down")
|
||
|
}
|
||
|
cmgr, err := New(&Config{
|
||
|
MaxOutbound: 5,
|
||
|
RetryDuration: 5 * time.Millisecond,
|
||
|
Dial: errDialer,
|
||
|
GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil },
|
||
|
OnConnection: func(c *ConnReq, conn net.Conn) {
|
||
|
t.Fatalf("network failure: got unexpected connection - %v", c.Addr)
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatalf("New error: %v", err)
|
||
|
}
|
||
|
cmgr.Start()
|
||
|
time.AfterFunc(10*time.Millisecond, cmgr.Stop)
|
||
|
cmgr.Wait()
|
||
|
wantMaxDials := uint32(75)
|
||
|
if atomic.LoadUint32(&dials) > wantMaxDials {
|
||
|
t.Fatalf("network failure: unexpected number of dials - got %v, want < %v",
|
||
|
atomic.LoadUint32(&dials), wantMaxDials)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// TestStopFailed tests that failed connections are ignored after connmgr is
|
||
|
// stopped.
|
||
|
//
|
||
|
// We have a dailer which sets the stop flag on the conn manager and returns an
|
||
|
// err so that the handler assumes that the conn manager is stopped and ignores
|
||
|
// the failure.
|
||
|
func TestStopFailed(t *testing.T) {
|
||
|
done := make(chan struct{}, 1)
|
||
|
waitDialer := func(network, address string) (net.Conn, error) {
|
||
|
done <- struct{}{}
|
||
|
time.Sleep(time.Millisecond)
|
||
|
return nil, errors.New("network down")
|
||
|
}
|
||
|
cmgr, err := New(&Config{
|
||
|
Dial: waitDialer,
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatalf("New error: %v", err)
|
||
|
}
|
||
|
cmgr.Start()
|
||
|
go func() {
|
||
|
<-done
|
||
|
atomic.StoreInt32(&cmgr.stop, 1)
|
||
|
time.Sleep(2 * time.Millisecond)
|
||
|
atomic.StoreInt32(&cmgr.stop, 0)
|
||
|
cmgr.Stop()
|
||
|
}()
|
||
|
cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true}
|
||
|
go cmgr.Connect(cr)
|
||
|
cmgr.Wait()
|
||
|
}
|