diff --git a/config.go b/config.go index 87d2df6f..d0a31e66 100644 --- a/config.go +++ b/config.go @@ -20,6 +20,7 @@ import ( "strings" "time" + "github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcd/database" _ "github.com/btcsuite/btcd/database/ffldb" "github.com/btcsuite/btcd/mempool" @@ -860,7 +861,7 @@ func loadConfig() (*config, []string, error) { cfg.dial = proxy.Dial if !cfg.NoOnion { cfg.lookup = func(host string) ([]net.IP, error) { - return torLookupIP(host, cfg.Proxy) + return connmgr.TorLookupIP(host, cfg.Proxy) } } } @@ -900,7 +901,7 @@ func loadConfig() (*config, []string, error) { return proxy.Dial(a, b) } cfg.onionlookup = func(host string) ([]net.IP, error) { - return torLookupIP(host, cfg.OnionProxy) + return connmgr.TorLookupIP(host, cfg.OnionProxy) } } else { cfg.oniondial = cfg.dial diff --git a/connmgr/README.md b/connmgr/README.md new file mode 100644 index 00000000..3cb184fb --- /dev/null +++ b/connmgr/README.md @@ -0,0 +1,39 @@ +connmgr +======= + +[![Build Status](http://img.shields.io/travis/btcsuite/btcd.svg)] +(https://travis-ci.org/btcsuite/btcd) [![ISC License] +(http://img.shields.io/badge/license-ISC-blue.svg)](http://copyfree.org) +[![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg)] +(http://godoc.org/github.com/btcsuite/btcd/connmgr) + +Package connmgr implements a generic Bitcoin network connection manager. + +## Overview + +Connection Manager handles all the general connection concerns such as +maintaining a set number of outbound connections, sourcing peers, banning, +limiting max connections, tor lookup, etc. + +The package provides a generic connection manager which is able to accept +connection requests from a source or a set of given addresses, dial them and +notify the caller on connections. The main intended use is to initialize a pool +of active connections and maintain them to remain connected to the P2P network. + +In addition the connection manager provides the following utilities: + +- Notifications on connections or disconnections +- Handle failures and retry new addresses from the source +- Connect only to specified addresses +- Permanent connections with increasing backoff retry timers +- Disconnect or Remove an established connection + +## Installation and Updating + +```bash +$ go get -u github.com/btcsuite/btcd/connmgr +``` + +## License + +Package connmgr is licensed under the [copyfree](http://copyfree.org) ISC License. diff --git a/connmgr/connmanager.go b/connmgr/connmanager.go new file mode 100644 index 00000000..c7d661b6 --- /dev/null +++ b/connmgr/connmanager.go @@ -0,0 +1,371 @@ +// Copyright (c) 2016 The btcsuite developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" +) + +// maxFailedAttempts is the maximum number of successive failed connection +// attempts after which network failure is assumed and new connections will +// be delayed by the configured retry duration. +const maxFailedAttempts = 25 + +var ( + //ErrDialNil is used to indicate that Dial cannot be nil in the configuration. + ErrDialNil = errors.New("Config: Dial cannot be nil") + + // maxRetryDuration is the max duration of time retrying of a persistent + // connection is allowed to grow to. This is necessary since the retry + // logic uses a backoff mechanism which increases the interval base times + // the number of retries that have been done. + maxRetryDuration = time.Minute * 5 + + // defaultRetryDuration is the default duration of time for retrying + // persistent connections. + defaultRetryDuration = time.Second * 5 + + // defaultMaxOutbound is the default number of maximum outbound connections + // to maintain. + defaultMaxOutbound = uint32(8) +) + +// DialFunc defines a function that dials a connection. +type DialFunc func(string, string) (net.Conn, error) + +// AddressFunc defines a function that returns a network address to connect to. +type AddressFunc func() (string, error) + +// ConnState represents the state of the requested connection. +type ConnState uint8 + +// ConnState can be either pending, established, disconnected or failed. When +// a new connection is requested, it is attempted and categorized as +// established or failed depending on the connection result. An established +// connection which was disconnected is categorized as disconnected. +const ( + ConnPending ConnState = iota + ConnEstablished + ConnDisconnected + ConnFailed +) + +// OnConnectionFunc is the signature of the callback function which is used to +// subscribe to new connections. +type OnConnectionFunc func(*ConnReq, net.Conn) + +// OnDisconnectionFunc is the signature of the callback function which is used to +// notify disconnections. +type OnDisconnectionFunc func(*ConnReq) + +// ConnReq is the connection request to a network address. If permanent, the +// connection will be retried on disconnection. +type ConnReq struct { + // The following variables must only be used atomically. + id uint64 + + Addr string + Permanent bool + + conn net.Conn + state ConnState + stateMtx sync.RWMutex + retryCount uint32 +} + +// updateState updates the state of the connection request. +func (c *ConnReq) updateState(state ConnState) { + c.stateMtx.Lock() + c.state = state + c.stateMtx.Unlock() +} + +// ID returns a unique identifier for the connection request. +func (c *ConnReq) ID() uint64 { + return atomic.LoadUint64(&c.id) +} + +// State is the connection state of the requested connection. +func (c *ConnReq) State() ConnState { + c.stateMtx.RLock() + state := c.state + c.stateMtx.RUnlock() + return state +} + +// String returns a human-readable string for the connection request. +func (c *ConnReq) String() string { + if c.Addr == "" { + return fmt.Sprintf("reqid %d", atomic.LoadUint64(&c.id)) + } + return fmt.Sprintf("%s (reqid %d)", c.Addr, atomic.LoadUint64(&c.id)) +} + +// Config holds the configuration options related to the connection manager. +type Config struct { + // MaxOutbound is the maximum number of outbound network connections to + // maintain. Defaults to 8. + MaxOutbound uint32 + + // RetryDuration is the duration to wait before retrying connection + // requests. Defaults to 5s. + RetryDuration time.Duration + + // OnConnection is a callback that is fired when a new connection is + // established. + OnConnection OnConnectionFunc + + // OnDisconnection is a callback that is fired when a connection is + // disconnected. + OnDisconnection OnDisconnectionFunc + + // GetNewAddress is a way to get an address to make a network connection + // to. If nil, no new connections will be made automatically. + GetNewAddress AddressFunc + + // Dial connects to the address on the named network. It cannot be nil. + Dial DialFunc +} + +// handleConnected is used to queue a successful connection. +type handleConnected struct { + c *ConnReq + conn net.Conn +} + +// handleDisconnected is used to remove a connection. +type handleDisconnected struct { + id uint64 + retry bool +} + +// handleFailed is used to remove a pending connection. +type handleFailed struct { + c *ConnReq + err error +} + +// ConnManager provides a manager to handle network connections. +type ConnManager struct { + // The following variables must only be used atomically. + connReqCount uint64 + start int32 + stop int32 + + cfg Config + wg sync.WaitGroup + failedAttempts uint64 + requests chan interface{} + quit chan struct{} +} + +// handleFailedConn handles a connection failed due to a disconnect or any +// other failure. If permanent, it retries the connection after the configured +// retry duration. Otherwise, if required, it makes a new connection request. +// After maxFailedConnectionAttempts new connections will be retried after the +// configured retry duration. +func (cm *ConnManager) handleFailedConn(c *ConnReq, retry bool) { + if atomic.LoadInt32(&cm.stop) != 0 { + return + } + if retry && c.Permanent { + c.retryCount++ + d := time.Duration(c.retryCount) * cm.cfg.RetryDuration + if d > maxRetryDuration { + d = maxRetryDuration + } + log.Debugf("Retrying connection to %v in %v", c, d) + time.AfterFunc(d, func() { + cm.Connect(c) + }) + } else if cm.cfg.GetNewAddress != nil { + cm.failedAttempts++ + if cm.failedAttempts >= maxFailedAttempts { + log.Debugf("Max failed connection attempts reached: [%d] "+ + "-- retrying connection in: %v", maxFailedAttempts, + cm.cfg.RetryDuration) + time.AfterFunc(cm.cfg.RetryDuration, func() { + cm.NewConnReq() + }) + } else { + go cm.NewConnReq() + } + } +} + +// connHandler handles all connection related requests. It must be run as a +// goroutine. +// +// The connection handler makes sure that we maintain a pool of active outbound +// connections so that we remain connected to the network. Connection requests +// are processed and mapped by their assigned ids. +func (cm *ConnManager) connHandler() { + conns := make(map[uint64]*ConnReq, cm.cfg.MaxOutbound) +out: + for { + select { + case req := <-cm.requests: + switch msg := req.(type) { + + case handleConnected: + connReq := msg.c + connReq.updateState(ConnEstablished) + connReq.conn = msg.conn + conns[connReq.id] = connReq + log.Debugf("Connected to %v", connReq) + connReq.retryCount = 0 + cm.failedAttempts = 0 + + if cm.cfg.OnConnection != nil { + go cm.cfg.OnConnection(connReq, msg.conn) + } + + case handleDisconnected: + if connReq, ok := conns[msg.id]; ok { + connReq.updateState(ConnDisconnected) + if connReq.conn != nil { + connReq.conn.Close() + } + log.Debugf("Disconnected from %v", connReq) + delete(conns, msg.id) + + if cm.cfg.OnDisconnection != nil { + go cm.cfg.OnDisconnection(connReq) + } + + cm.handleFailedConn(connReq, msg.retry) + } else { + log.Errorf("Unknown connection: %d", msg.id) + } + + case handleFailed: + connReq := msg.c + connReq.updateState(ConnFailed) + log.Debugf("Failed to connect to %v: %v", connReq, msg.err) + cm.handleFailedConn(connReq, true) + } + + case <-cm.quit: + break out + } + } + + cm.wg.Done() + log.Trace("Connection handler done") +} + +// NewConnReq creates a new connection request and connects to the +// corresponding address. +func (cm *ConnManager) NewConnReq() { + if atomic.LoadInt32(&cm.stop) != 0 { + return + } + if cm.cfg.GetNewAddress == nil { + return + } + c := &ConnReq{} + atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1)) + addr, err := cm.cfg.GetNewAddress() + if err != nil { + cm.requests <- handleFailed{c, err} + return + } + c.Addr = addr + cm.Connect(c) +} + +// Connect assigns an id and dials a connection to the address of the +// connection request. +func (cm *ConnManager) Connect(c *ConnReq) { + if atomic.LoadInt32(&cm.stop) != 0 { + return + } + if atomic.LoadUint64(&c.id) == 0 { + atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1)) + } + log.Debugf("Attempting to connect to %v", c) + conn, err := cm.cfg.Dial("tcp", c.Addr) + if err != nil { + cm.requests <- handleFailed{c, err} + } else { + cm.requests <- handleConnected{c, conn} + } +} + +// Disconnect disconnects the connection corresponding to the given connection +// id. If permanent, the connection will be retried with an increasing backoff +// duration. +func (cm *ConnManager) Disconnect(id uint64) { + if atomic.LoadInt32(&cm.stop) != 0 { + return + } + cm.requests <- handleDisconnected{id, true} +} + +// Remove removes the connection corresponding to the given connection +// id from known connections. +func (cm *ConnManager) Remove(id uint64) { + if atomic.LoadInt32(&cm.stop) != 0 { + return + } + cm.requests <- handleDisconnected{id, false} +} + +// Start launches the connection manager and begins connecting to the network. +func (cm *ConnManager) Start() { + // Already started? + if atomic.AddInt32(&cm.start, 1) != 1 { + return + } + + log.Trace("Connection manager started") + cm.wg.Add(1) + go cm.connHandler() + + for i := atomic.LoadUint64(&cm.connReqCount); i < uint64(cm.cfg.MaxOutbound); i++ { + go cm.NewConnReq() + } +} + +// Wait blocks until the connection manager halts gracefully. +func (cm *ConnManager) Wait() { + cm.wg.Wait() +} + +// Stop gracefully shuts down the connection manager. +func (cm *ConnManager) Stop() { + if atomic.AddInt32(&cm.stop, 1) != 1 { + log.Warnf("Connection manager already stopped") + return + } + close(cm.quit) + log.Trace("Connection manager stopped") +} + +// New returns a new connection manager. +// Use Start to start connecting to the network. +func New(cfg *Config) (*ConnManager, error) { + if cfg.Dial == nil { + return nil, ErrDialNil + } + // Default to sane values + if cfg.RetryDuration <= 0 { + cfg.RetryDuration = defaultRetryDuration + } + if cfg.MaxOutbound == 0 { + cfg.MaxOutbound = defaultMaxOutbound + } + cm := ConnManager{ + cfg: *cfg, // Copy so caller can't mutate + requests: make(chan interface{}), + quit: make(chan struct{}), + } + return &cm, nil +} diff --git a/connmgr/connmanager_test.go b/connmgr/connmanager_test.go new file mode 100644 index 00000000..adac4b88 --- /dev/null +++ b/connmgr/connmanager_test.go @@ -0,0 +1,389 @@ +// Copyright (c) 2016 The btcsuite developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + "bytes" + "errors" + "io" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/btcsuite/btclog" +) + +func init() { + // Override the max retry duration when running tests. + maxRetryDuration = 2 * time.Millisecond +} + +// mockAddr mocks a network address +type mockAddr struct { + net, address string +} + +func (m mockAddr) Network() string { return m.net } +func (m mockAddr) String() string { return m.address } + +// mockConn mocks a network connection by implementing the net.Conn interface. +type mockConn struct { + io.Reader + io.Writer + io.Closer + + // local network, address for the connection. + lnet, laddr string + + // remote network, address for the connection. + rnet, raddr string +} + +// LocalAddr returns the local address for the connection. +func (c mockConn) LocalAddr() net.Addr { + return &mockAddr{c.lnet, c.laddr} +} + +// RemoteAddr returns the remote address for the connection. +func (c mockConn) RemoteAddr() net.Addr { + return &mockAddr{c.rnet, c.raddr} +} + +// Close handles closing the connection. +func (c mockConn) Close() error { + return nil +} + +func (c mockConn) SetDeadline(t time.Time) error { return nil } +func (c mockConn) SetReadDeadline(t time.Time) error { return nil } +func (c mockConn) SetWriteDeadline(t time.Time) error { return nil } + +// mockDialer mocks the net.Dial interface by returning a mock connection to +// the given address. +func mockDialer(network, address string) (net.Conn, error) { + r, w := io.Pipe() + c := &mockConn{raddr: address} + c.Reader = r + c.Writer = w + return c, nil +} + +// TestNewConfig tests that new ConnManager config is validated as expected. +func TestNewConfig(t *testing.T) { + _, err := New(&Config{}) + if err == nil { + t.Fatalf("New expected error: 'Dial can't be nil', got nil") + } + _, err = New(&Config{ + Dial: mockDialer, + }) + if err != nil { + t.Fatalf("New unexpected error: %v", err) + } +} + +// TestUseLogger tests that a logger can be passed to UseLogger +func TestUseLogger(t *testing.T) { + l, err := btclog.NewLoggerFromWriter(bytes.NewBuffer(nil), btclog.InfoLvl) + if err != nil { + t.Fatal(err) + } + UseLogger(l) +} + +// TestStartStop tests that the connection manager starts and stops as +// expected. +func TestStartStop(t *testing.T) { + connected := make(chan *ConnReq) + disconnected := make(chan *ConnReq) + cmgr, err := New(&Config{ + MaxOutbound: 1, + GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil }, + Dial: mockDialer, + OnConnection: func(c *ConnReq, conn net.Conn) { + connected <- c + }, + OnDisconnection: func(c *ConnReq) { + disconnected <- c + }, + }) + if err != nil { + t.Fatalf("New error: %v", err) + } + cmgr.Start() + gotConnReq := <-connected + cmgr.Stop() + // already stopped + cmgr.Stop() + // ignored + cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true} + cmgr.Connect(cr) + if cr.ID() != 0 { + t.Fatalf("start/stop: got id: %v, want: 0", cr.ID()) + } + cmgr.Disconnect(gotConnReq.ID()) + cmgr.Remove(gotConnReq.ID()) + select { + case <-disconnected: + t.Fatalf("start/stop: unexpected disconnection") + case <-time.Tick(10 * time.Millisecond): + break + } +} + +// TestConnectMode tests that the connection manager works in the connect mode. +// +// In connect mode, automatic connections are disabled, so we test that +// requests using Connect are handled and that no other connections are made. +func TestConnectMode(t *testing.T) { + connected := make(chan *ConnReq) + cmgr, err := New(&Config{ + MaxOutbound: 2, + Dial: mockDialer, + OnConnection: func(c *ConnReq, conn net.Conn) { + connected <- c + }, + }) + if err != nil { + t.Fatalf("New error: %v", err) + } + cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true} + cmgr.Start() + cmgr.Connect(cr) + gotConnReq := <-connected + wantID := cr.ID() + gotID := gotConnReq.ID() + if gotID != wantID { + t.Fatalf("connect mode: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID) + } + gotState := cr.State() + wantState := ConnEstablished + if gotState != wantState { + t.Fatalf("connect mode: %v - want state %v, got state %v", cr.Addr, wantState, gotState) + } + select { + case c := <-connected: + t.Fatalf("connect mode: got unexpected connection - %v", c.Addr) + case <-time.After(time.Millisecond): + break + } + cmgr.Stop() +} + +// TestMaxOutbound tests the maximum number of outbound connections. +// +// We wait until all connections are established, then test they there are the +// only connections made. +func TestMaxOutbound(t *testing.T) { + maxOutbound := uint32(10) + connected := make(chan *ConnReq) + cmgr, err := New(&Config{ + MaxOutbound: maxOutbound, + Dial: mockDialer, + GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil }, + OnConnection: func(c *ConnReq, conn net.Conn) { + connected <- c + }, + }) + if err != nil { + t.Fatalf("New error: %v", err) + } + cmgr.Start() + for i := uint32(0); i < maxOutbound; i++ { + <-connected + } + + select { + case c := <-connected: + t.Fatalf("max outbound: got unexpected connection - %v", c.Addr) + case <-time.After(time.Millisecond): + break + } + cmgr.Stop() +} + +// TestRetryPermanent tests that permanent connection requests are retried. +// +// We make a permanent connection request using Connect, disconnect it using +// Disconnect and we wait for it to be connected back. +func TestRetryPermanent(t *testing.T) { + connected := make(chan *ConnReq) + disconnected := make(chan *ConnReq) + cmgr, err := New(&Config{ + RetryDuration: time.Millisecond, + MaxOutbound: 1, + Dial: mockDialer, + OnConnection: func(c *ConnReq, conn net.Conn) { + connected <- c + }, + OnDisconnection: func(c *ConnReq) { + disconnected <- c + }, + }) + if err != nil { + t.Fatalf("New error: %v", err) + } + + cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true} + go cmgr.Connect(cr) + cmgr.Start() + gotConnReq := <-connected + wantID := cr.ID() + gotID := gotConnReq.ID() + if gotID != wantID { + t.Fatalf("retry: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID) + } + gotState := cr.State() + wantState := ConnEstablished + if gotState != wantState { + t.Fatalf("retry: %v - want state %v, got state %v", cr.Addr, wantState, gotState) + } + + cmgr.Disconnect(cr.ID()) + gotConnReq = <-disconnected + wantID = cr.ID() + gotID = gotConnReq.ID() + if gotID != wantID { + t.Fatalf("retry: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID) + } + gotState = cr.State() + wantState = ConnDisconnected + if gotState != wantState { + t.Fatalf("retry: %v - want state %v, got state %v", cr.Addr, wantState, gotState) + } + + gotConnReq = <-connected + wantID = cr.ID() + gotID = gotConnReq.ID() + if gotID != wantID { + t.Fatalf("retry: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID) + } + gotState = cr.State() + wantState = ConnEstablished + if gotState != wantState { + t.Fatalf("retry: %v - want state %v, got state %v", cr.Addr, wantState, gotState) + } + + cmgr.Remove(cr.ID()) + gotConnReq = <-disconnected + wantID = cr.ID() + gotID = gotConnReq.ID() + if gotID != wantID { + t.Fatalf("retry: %v - want ID %v, got ID %v", cr.Addr, wantID, gotID) + } + gotState = cr.State() + wantState = ConnDisconnected + if gotState != wantState { + t.Fatalf("retry: %v - want state %v, got state %v", cr.Addr, wantState, gotState) + } + cmgr.Stop() +} + +// TestMaxRetryDuration tests the maximum retry duration. +// +// We have a timed dialer which initially returns err but after RetryDuration +// hits maxRetryDuration returns a mock conn. +func TestMaxRetryDuration(t *testing.T) { + networkUp := make(chan struct{}) + time.AfterFunc(5*time.Millisecond, func() { + close(networkUp) + }) + timedDialer := func(network, address string) (net.Conn, error) { + select { + case <-networkUp: + return mockDialer(network, address) + default: + return nil, errors.New("network down") + } + } + + connected := make(chan *ConnReq) + cmgr, err := New(&Config{ + RetryDuration: time.Millisecond, + MaxOutbound: 1, + Dial: timedDialer, + OnConnection: func(c *ConnReq, conn net.Conn) { + connected <- c + }, + }) + if err != nil { + t.Fatalf("New error: %v", err) + } + + cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true} + go cmgr.Connect(cr) + cmgr.Start() + // retry in 1ms + // retry in 2ms - max retry duration reached + // retry in 2ms - timedDialer returns mockDial + select { + case <-connected: + case <-time.Tick(100 * time.Millisecond): + t.Fatalf("max retry duration: connection timeout") + } +} + +// TestNetworkFailure tests that the connection manager handles a network +// failure gracefully. +func TestNetworkFailure(t *testing.T) { + var dials uint32 + errDialer := func(network, address string) (net.Conn, error) { + atomic.AddUint32(&dials, 1) + return nil, errors.New("network down") + } + cmgr, err := New(&Config{ + MaxOutbound: 5, + RetryDuration: 5 * time.Millisecond, + Dial: errDialer, + GetNewAddress: func() (string, error) { return "127.0.0.1:18555", nil }, + OnConnection: func(c *ConnReq, conn net.Conn) { + t.Fatalf("network failure: got unexpected connection - %v", c.Addr) + }, + }) + if err != nil { + t.Fatalf("New error: %v", err) + } + cmgr.Start() + time.AfterFunc(10*time.Millisecond, cmgr.Stop) + cmgr.Wait() + wantMaxDials := uint32(75) + if atomic.LoadUint32(&dials) > wantMaxDials { + t.Fatalf("network failure: unexpected number of dials - got %v, want < %v", + atomic.LoadUint32(&dials), wantMaxDials) + } +} + +// TestStopFailed tests that failed connections are ignored after connmgr is +// stopped. +// +// We have a dailer which sets the stop flag on the conn manager and returns an +// err so that the handler assumes that the conn manager is stopped and ignores +// the failure. +func TestStopFailed(t *testing.T) { + done := make(chan struct{}, 1) + waitDialer := func(network, address string) (net.Conn, error) { + done <- struct{}{} + time.Sleep(time.Millisecond) + return nil, errors.New("network down") + } + cmgr, err := New(&Config{ + Dial: waitDialer, + }) + if err != nil { + t.Fatalf("New error: %v", err) + } + cmgr.Start() + go func() { + <-done + atomic.StoreInt32(&cmgr.stop, 1) + time.Sleep(2 * time.Millisecond) + atomic.StoreInt32(&cmgr.stop, 0) + cmgr.Stop() + }() + cr := &ConnReq{Addr: "127.0.0.1:18555", Permanent: true} + go cmgr.Connect(cr) + cmgr.Wait() +} diff --git a/connmgr/doc.go b/connmgr/doc.go new file mode 100644 index 00000000..acb90c31 --- /dev/null +++ b/connmgr/doc.go @@ -0,0 +1,14 @@ +// Copyright (c) 2016 The btcsuite developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +/* +Package connmgr implements a generic Bitcoin network connection manager. + +Connection Manager Overview + +Connection Manager handles all the general connection concerns such as +maintaining a set number of outbound connections, sourcing peers, banning, +limiting max connections, tor lookup, etc. +*/ +package connmgr diff --git a/dynamicbanscore.go b/connmgr/dynamicbanscore.go similarity index 87% rename from dynamicbanscore.go rename to connmgr/dynamicbanscore.go index 786c7b40..10623944 100644 --- a/dynamicbanscore.go +++ b/connmgr/dynamicbanscore.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. -package main +package connmgr import ( "fmt" @@ -48,7 +48,7 @@ func decayFactor(t int64) float64 { return math.Exp(-1.0 * float64(t) * lambda) } -// dynamicBanScore provides dynamic ban scores consisting of a persistent and a +// DynamicBanScore provides dynamic ban scores consisting of a persistent and a // decaying component. The persistent score could be utilized to create simple // additive banning policies similar to those found in other bitcoin node // implementations. @@ -56,11 +56,11 @@ func decayFactor(t int64) float64 { // The decaying score enables the creation of evasive logic which handles // misbehaving peers (especially application layer DoS attacks) gracefully // by disconnecting and banning peers attempting various kinds of flooding. -// dynamicBanScore allows these two approaches to be used in tandem. +// DynamicBanScore allows these two approaches to be used in tandem. // -// Zero value: Values of type dynamicBanScore are immediately ready for use upon +// Zero value: Values of type DynamicBanScore are immediately ready for use upon // declaration. -type dynamicBanScore struct { +type DynamicBanScore struct { lastUnix int64 transient float64 persistent uint32 @@ -68,7 +68,7 @@ type dynamicBanScore struct { } // String returns the ban score as a human-readable string. -func (s *dynamicBanScore) String() string { +func (s *DynamicBanScore) String() string { s.Lock() r := fmt.Sprintf("persistent %v + transient %v at %v = %v as of now", s.persistent, s.transient, s.lastUnix, s.Int()) @@ -80,7 +80,7 @@ func (s *dynamicBanScore) String() string { // scores. // // This function is safe for concurrent access. -func (s *dynamicBanScore) Int() uint32 { +func (s *DynamicBanScore) Int() uint32 { s.Lock() r := s.int(time.Now()) s.Unlock() @@ -91,7 +91,7 @@ func (s *dynamicBanScore) Int() uint32 { // passed as parameters. The resulting score is returned. // // This function is safe for concurrent access. -func (s *dynamicBanScore) Increase(persistent, transient uint32) uint32 { +func (s *DynamicBanScore) Increase(persistent, transient uint32) uint32 { s.Lock() r := s.increase(persistent, transient, time.Now()) s.Unlock() @@ -101,7 +101,7 @@ func (s *dynamicBanScore) Increase(persistent, transient uint32) uint32 { // Reset set both persistent and decaying scores to zero. // // This function is safe for concurrent access. -func (s *dynamicBanScore) Reset() { +func (s *DynamicBanScore) Reset() { s.Lock() s.persistent = 0 s.transient = 0 @@ -114,7 +114,7 @@ func (s *dynamicBanScore) Reset() { // // This function is not safe for concurrent access. It is intended to be used // internally and during testing. -func (s *dynamicBanScore) int(t time.Time) uint32 { +func (s *DynamicBanScore) int(t time.Time) uint32 { dt := t.Unix() - s.lastUnix if s.transient < 1 || dt < 0 || Lifetime < dt { return s.persistent @@ -128,7 +128,7 @@ func (s *dynamicBanScore) int(t time.Time) uint32 { // resulting score is returned. // // This function is not safe for concurrent access. -func (s *dynamicBanScore) increase(persistent, transient uint32, t time.Time) uint32 { +func (s *DynamicBanScore) increase(persistent, transient uint32, t time.Time) uint32 { s.persistent += persistent tu := t.Unix() dt := tu - s.lastUnix diff --git a/dynamicbanscore_test.go b/connmgr/dynamicbanscore_test.go similarity index 86% rename from dynamicbanscore_test.go rename to connmgr/dynamicbanscore_test.go index 070bc6ae..6dcd64d4 100644 --- a/dynamicbanscore_test.go +++ b/connmgr/dynamicbanscore_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. -package main +package connmgr import ( "math" @@ -11,9 +11,9 @@ import ( ) // TestDynamicBanScoreDecay tests the exponential decay implemented in -// dynamicBanScore. +// DynamicBanScore. func TestDynamicBanScoreDecay(t *testing.T) { - var bs dynamicBanScore + var bs DynamicBanScore base := time.Now() r := bs.increase(100, 50, base) @@ -32,10 +32,10 @@ func TestDynamicBanScoreDecay(t *testing.T) { } } -// TestDynamicBanScoreLifetime tests that dynamicBanScore properly yields zero +// TestDynamicBanScoreLifetime tests that DynamicBanScore properly yields zero // once the maximum age is reached. func TestDynamicBanScoreLifetime(t *testing.T) { - var bs dynamicBanScore + var bs DynamicBanScore base := time.Now() r := bs.increase(0, math.MaxUint32, base) @@ -49,10 +49,10 @@ func TestDynamicBanScoreLifetime(t *testing.T) { } } -// TestDynamicBanScore tests exported functions of dynamicBanScore. Exponential +// TestDynamicBanScore tests exported functions of DynamicBanScore. Exponential // decay or other time based behavior is tested by other functions. func TestDynamicBanScoreReset(t *testing.T) { - var bs dynamicBanScore + var bs DynamicBanScore if bs.Int() != 0 { t.Errorf("Initial state is not zero.") } diff --git a/connmgr/log.go b/connmgr/log.go new file mode 100644 index 00000000..1afa7ee6 --- /dev/null +++ b/connmgr/log.go @@ -0,0 +1,30 @@ +// Copyright (c) 2016 The btcsuite developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import "github.com/btcsuite/btclog" + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + DisableLog() +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until either UseLogger or SetLogWriter are called. +func DisableLog() { + log = btclog.Disabled +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/connmgr/seed.go b/connmgr/seed.go new file mode 100644 index 00000000..e4715372 --- /dev/null +++ b/connmgr/seed.go @@ -0,0 +1,66 @@ +// Copyright (c) 2016 The btcsuite developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + mrand "math/rand" + "net" + "strconv" + "time" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/wire" +) + +const ( + // These constants are used by the DNS seed code to pick a random last + // seen time. + secondsIn3Days int32 = 24 * 60 * 60 * 3 + secondsIn4Days int32 = 24 * 60 * 60 * 4 +) + +// OnSeed is the signature of the callback function which is invoked when DNS +// seeding is succesfull. +type OnSeed func(addrs []*wire.NetAddress) + +// LookupFunc is the signature of the DNS lookup function. +type LookupFunc func(string) ([]net.IP, error) + +// SeedFromDNS uses DNS seeding to populate the address manager with peers. +func SeedFromDNS(chainParams *chaincfg.Params, lookupFn LookupFunc, seedFn OnSeed) { + for _, seeder := range chainParams.DNSSeeds { + go func(seeder string) { + randSource := mrand.New(mrand.NewSource(time.Now().UnixNano())) + + seedpeers, err := lookupFn(seeder) + if err != nil { + log.Infof("DNS discovery failed on seed %s: %v", seeder, err) + return + } + numPeers := len(seedpeers) + + log.Infof("%d addresses found from DNS seed %s", numPeers, seeder) + + if numPeers == 0 { + return + } + addresses := make([]*wire.NetAddress, len(seedpeers)) + // if this errors then we have *real* problems + intPort, _ := strconv.Atoi(chainParams.DefaultPort) + for i, peer := range seedpeers { + addresses[i] = new(wire.NetAddress) + addresses[i].SetAddress(peer, uint16(intPort)) + // bitcoind seeds with addresses from + // a time randomly selected between 3 + // and 7 days ago. + addresses[i].Timestamp = time.Now().Add(-1 * + time.Second * time.Duration(secondsIn3Days+ + randSource.Int31n(secondsIn4Days))) + } + + seedFn(addresses) + }(seeder) + } +} diff --git a/discovery.go b/connmgr/tor.go similarity index 86% rename from discovery.go rename to connmgr/tor.go index 44c58f78..ddd2c4af 100644 --- a/discovery.go +++ b/connmgr/tor.go @@ -1,8 +1,8 @@ -// Copyright (c) 2013-2014 The btcsuite developers +// Copyright (c) 2013-2016 The btcsuite developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. -package main +package connmgr import ( "encoding/binary" @@ -48,10 +48,10 @@ var ( } ) -// torLookupIP uses Tor to resolve DNS via the SOCKS extension they provide for +// TorLookupIP uses Tor to resolve DNS via the SOCKS extension they provide for // resolution over the Tor network. Tor itself doesn't support ipv6 so this // doesn't either. -func torLookupIP(host, proxy string) ([]net.IP, error) { +func TorLookupIP(host, proxy string) ([]net.IP, error) { conn, err := net.Dial("tcp", proxy) if err != nil { return nil, err @@ -130,15 +130,3 @@ func torLookupIP(host, proxy string) ([]net.IP, error) { return addr, nil } - -// dnsDiscover looks up the list of peers resolved by DNS for all hosts in -// seeders. If proxy is not "" then it is used as a tor proxy for the -// resolution. -func dnsDiscover(seeder string) ([]net.IP, error) { - peers, err := btcdLookup(seeder) - if err != nil { - return nil, err - } - - return peers, nil -} diff --git a/docs/README.md b/docs/README.md index f01e3004..486384e2 100644 --- a/docs/README.md +++ b/docs/README.md @@ -259,3 +259,5 @@ information. * [chainhash](https://github.com/btcsuite/btcd/tree/master/chaincfg/chainhash) - Provides a generic hash type and associated functions that allows the specific hash algorithm to be abstracted. + * [connmgr](https://github.com/btcsuite/btcd/tree/master/connmgr) - + Package connmgr implements a generic Bitcoin network connection manager. diff --git a/log.go b/log.go index 5125c1e1..35253d8c 100644 --- a/log.go +++ b/log.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/addrmgr" "github.com/btcsuite/btcd/blockchain" "github.com/btcsuite/btcd/blockchain/indexers" + "github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcd/database" "github.com/btcsuite/btcd/mempool" "github.com/btcsuite/btcd/peer" @@ -33,6 +34,7 @@ var ( backendLog = seelog.Disabled adxrLog = btclog.Disabled amgrLog = btclog.Disabled + cmgrLog = btclog.Disabled bcdbLog = btclog.Disabled bmgrLog = btclog.Disabled btcdLog = btclog.Disabled @@ -51,6 +53,7 @@ var ( var subsystemLoggers = map[string]btclog.Logger{ "ADXR": adxrLog, "AMGR": amgrLog, + "CMGR": cmgrLog, "BCDB": bcdbLog, "BMGR": bmgrLog, "BTCD": btcdLog, @@ -97,6 +100,10 @@ func useLogger(subsystemID string, logger btclog.Logger) { amgrLog = logger addrmgr.UseLogger(logger) + case "CMGR": + cmgrLog = logger + connmgr.UseLogger(logger) + case "BCDB": bcdbLog = logger database.UseLogger(logger) diff --git a/server.go b/server.go index 53db4bc2..bd35e510 100644 --- a/server.go +++ b/server.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "math" - mrand "math/rand" "net" "runtime" "strconv" @@ -25,6 +24,7 @@ import ( "github.com/btcsuite/btcd/blockchain/indexers" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcd/database" "github.com/btcsuite/btcd/mempool" "github.com/btcsuite/btcd/mining" @@ -35,13 +35,6 @@ import ( "github.com/btcsuite/btcutil/bloom" ) -const ( - // These constants are used by the DNS seed code to pick a random last - // seen time. - secondsIn3Days int32 = 24 * 60 * 60 * 3 - secondsIn4Days int32 = 24 * 60 * 60 * 4 -) - const ( // defaultServices describes the default services that are supported by // the server. @@ -54,12 +47,6 @@ const ( // retries when connecting to persistent peers. It is adjusted by the // number of retries such that there is a retry backoff. connectionRetryInterval = time.Second * 5 - - // maxConnectionRetryInterval is the max amount of time retrying of a - // persistent peer is allowed to grow to. This is necessary since the - // retry logic uses a backoff mechanism which increases the interval - // base done the number of retries that have been done. - maxConnectionRetryInterval = time.Minute * 5 ) var ( @@ -109,7 +96,6 @@ type updatePeerHeightsMsg struct { // peerState maintains state of inbound, persistent, outbound peers as well // as banned peers and outbound groups. type peerState struct { - pendingPeers map[string]*serverPeer inboundPeers map[int32]*serverPeer outboundPeers map[int32]*serverPeer persistentPeers map[int32]*serverPeer @@ -124,22 +110,6 @@ func (ps *peerState) Count() int { len(ps.persistentPeers) } -// OutboundCount returns the count of known outbound peers. -func (ps *peerState) OutboundCount() int { - return len(ps.outboundPeers) + len(ps.persistentPeers) -} - -// NeedMoreOutbound returns true if more outbound peers are required. -func (ps *peerState) NeedMoreOutbound() bool { - return ps.OutboundCount() < ps.maxOutboundPeers && - ps.Count() < cfg.MaxPeers -} - -// NeedMoreTries returns true if more outbound peer attempts can be tried. -func (ps *peerState) NeedMoreTries() bool { - return len(ps.pendingPeers) < 2*(ps.maxOutboundPeers-ps.OutboundCount()) -} - // forAllOutboundPeers is a helper function that runs closure on all outbound // peers known to peerState. func (ps *peerState) forAllOutboundPeers(closure func(sp *serverPeer)) { @@ -151,14 +121,6 @@ func (ps *peerState) forAllOutboundPeers(closure func(sp *serverPeer)) { } } -// forPendingPeers is a helper function that runs closure on all pending peers -// known to peerState. -func (ps *peerState) forPendingPeers(closure func(sp *serverPeer)) { - for _, e := range ps.pendingPeers { - closure(e) - } -} - // forAllPeers is a helper function that runs closure on all peers known to // peerState. func (ps *peerState) forAllPeers(closure func(sp *serverPeer)) { @@ -182,17 +144,16 @@ type server struct { listeners []net.Listener chainParams *chaincfg.Params addrManager *addrmgr.AddrManager + connManager *connmgr.ConnManager sigCache *txscript.SigCache rpcServer *rpcServer blockManager *blockManager txMemPool *mempool.TxPool cpuMiner *CPUMiner modifyRebroadcastInv chan interface{} - pendingPeers chan *serverPeer newPeers chan *serverPeer donePeers chan *serverPeer banPeers chan *serverPeer - retryPeers chan *serverPeer wakeup chan struct{} query chan interface{} relayInv chan relayMsg @@ -218,6 +179,7 @@ type server struct { type serverPeer struct { *peer.Peer + connReq *connmgr.ConnReq server *server persistent bool continueHash *chainhash.Hash @@ -228,7 +190,7 @@ type serverPeer struct { requestedBlocks map[chainhash.Hash]struct{} filter *bloom.Filter knownAddresses map[string]struct{} - banScore dynamicBanScore + banScore connmgr.DynamicBanScore quit chan struct{} // The following chans are used to sync blockmanager and server. txProcessed chan struct{} @@ -1248,16 +1210,6 @@ func (s *server) handleAddPeerMsg(state *peerState, sp *serverPeer) bool { // TODO: Check for max peers from a single IP. - // Limit max outbound peers. - if _, ok := state.pendingPeers[sp.Addr()]; ok { - if state.OutboundCount() >= state.maxOutboundPeers { - srvrLog.Infof("Max outbound peers reached [%d] - disconnecting "+ - "peer %s", state.maxOutboundPeers, sp) - sp.Disconnect() - return false - } - } - // Limit max number of total peers. if state.Count() >= cfg.MaxPeers { srvrLog.Infof("Max peers reached [%d] - disconnecting peer %s", @@ -1279,8 +1231,6 @@ func (s *server) handleAddPeerMsg(state *peerState, sp *serverPeer) bool { } else { state.outboundPeers[sp.ID()] = sp } - // Remove from pending peers. - delete(state.pendingPeers, sp.Addr()) } return true @@ -1289,12 +1239,6 @@ func (s *server) handleAddPeerMsg(state *peerState, sp *serverPeer) bool { // handleDonePeerMsg deals with peers that have signalled they are done. It is // invoked from the peerHandler goroutine. func (s *server) handleDonePeerMsg(state *peerState, sp *serverPeer) { - if _, ok := state.pendingPeers[sp.Addr()]; ok { - delete(state.pendingPeers, sp.Addr()) - srvrLog.Debugf("Removed pending peer %s", sp) - return - } - var list map[int32]*serverPeer if sp.persistent { list = state.persistentPeers @@ -1304,23 +1248,21 @@ func (s *server) handleDonePeerMsg(state *peerState, sp *serverPeer) { list = state.outboundPeers } if _, ok := list[sp.ID()]; ok { - // Issue an asynchronous reconnect if the peer was a - // persistent outbound connection. - if !sp.Inbound() && sp.persistent && atomic.LoadInt32(&s.shutdown) == 0 { - // Retry peer - sp2 := s.newOutboundPeer(sp.Addr(), sp.persistent) - if sp2 != nil { - go s.retryConn(sp2, false) - } - } if !sp.Inbound() && sp.VersionKnown() { state.outboundGroups[addrmgr.GroupKey(sp.NA())]-- } + if sp.persistent && sp.connReq != nil { + s.connManager.Disconnect(sp.connReq.ID()) + } delete(list, sp.ID()) srvrLog.Debugf("Removed peer %s", sp) return } + if sp.connReq != nil { + s.connManager.Remove(sp.connReq.ID()) + } + // Update the address' last seen time if the peer has acknowledged // our version and has sent us its version as well. if sp.VerAckReceived() && sp.VersionKnown() && sp.NA() != nil { @@ -1428,6 +1370,11 @@ type getPeersMsg struct { reply chan []*serverPeer } +type getOutboundGroup struct { + key string + reply chan int +} + type getAddedNodesMsg struct { reply chan []*serverPeer } @@ -1485,13 +1432,11 @@ func (s *server) handleQuery(state *peerState, querymsg interface{}) { } // TODO(oga) if too many, nuke a non-perm peer. - sp := s.newOutboundPeer(msg.addr, msg.permanent) - if sp != nil { - go s.peerConnHandler(sp) - msg.reply <- nil - } else { - msg.reply <- errors.New("failed to add peer") - } + go s.connManager.Connect(&connmgr.ConnReq{ + Addr: msg.addr, + Permanent: msg.permanent, + }) + msg.reply <- nil case removeNodeMsg: found := disconnectPeer(state.persistentPeers, msg.cmp, func(sp *serverPeer) { // Keep group counts ok since we remove from @@ -1504,6 +1449,13 @@ func (s *server) handleQuery(state *peerState, querymsg interface{}) { } else { msg.reply <- errors.New("peer not found") } + case getOutboundGroup: + count, ok := state.outboundGroups[msg.key] + if ok { + msg.reply <- count + } else { + msg.reply <- 0 + } // Request a list of the persistent (added) peers. case getAddedNodesMsg: // Respond with a slice of the relavent peers. @@ -1630,53 +1582,6 @@ func (s *server) listenHandler(listener net.Listener) { srvrLog.Tracef("Listener handler done for %s", listener.Addr()) } -// seedFromDNS uses DNS seeding to populate the address manager with peers. -func (s *server) seedFromDNS() { - // Nothing to do if DNS seeding is disabled. - if cfg.DisableDNSSeed { - return - } - - for _, seeder := range activeNetParams.DNSSeeds { - go func(seeder string) { - randSource := mrand.New(mrand.NewSource(time.Now().UnixNano())) - - seedpeers, err := dnsDiscover(seeder) - if err != nil { - discLog.Infof("DNS discovery failed on seed %s: %v", seeder, err) - return - } - numPeers := len(seedpeers) - - discLog.Infof("%d addresses found from DNS seed %s", numPeers, seeder) - - if numPeers == 0 { - return - } - addresses := make([]*wire.NetAddress, len(seedpeers)) - // if this errors then we have *real* problems - intPort, _ := strconv.Atoi(activeNetParams.DefaultPort) - for i, peer := range seedpeers { - addresses[i] = new(wire.NetAddress) - addresses[i].SetAddress(peer, uint16(intPort)) - // bitcoind seeds with addresses from - // a time randomly selected between 3 - // and 7 days ago. - addresses[i].Timestamp = time.Now().Add(-1 * - time.Second * time.Duration(secondsIn3Days+ - randSource.Int31n(secondsIn4Days))) - } - - // Bitcoind uses a lookup of the dns seeder here. This - // is rather strange since the values looked up by the - // DNS seed lookups will vary quite a lot. - // to replicate this behaviour we put all addresses as - // having come from the first one. - s.addrManager.AddAddresses(addresses, addresses[0]) - }(seeder) - } -} - // newOutboundPeer initializes a new outbound peer and setups the message // listeners. func (s *server) newOutboundPeer(addr string, persistent bool) *serverPeer { @@ -1691,15 +1596,6 @@ func (s *server) newOutboundPeer(addr string, persistent bool) *serverPeer { return sp } -// peerConnHandler handles peer connections. It must be run in a goroutine. -func (s *server) peerConnHandler(sp *serverPeer) { - err := s.establishConn(sp) - if err != nil { - srvrLog.Debugf("Failed to connect to %s: %v", sp.Addr(), err) - sp.Disconnect() - } -} - // peerDoneHandler handles peer disconnects by notifiying the server that it's // done. func (s *server) peerDoneHandler(sp *serverPeer) { @@ -1713,51 +1609,6 @@ func (s *server) peerDoneHandler(sp *serverPeer) { close(sp.quit) } -// establishConn establishes a connection to the peer. -func (s *server) establishConn(sp *serverPeer) error { - srvrLog.Debugf("Attempting to connect to %s", sp.Addr()) - conn, err := btcdDial("tcp", sp.Addr()) - if err != nil { - return err - } - sp.AssociateConnection(conn) - s.addrManager.Attempt(sp.NA()) - return nil -} - -// retryConn retries connection to the peer after the given duration. It must -// be run as a goroutine. -func (s *server) retryConn(sp *serverPeer, initialAttempt bool) { - retryDuration := connectionRetryInterval - for { - if initialAttempt { - retryDuration = 0 - initialAttempt = false - } else { - srvrLog.Debugf("Retrying connection to %s in %s", sp.Addr(), - retryDuration) - } - select { - case <-time.After(retryDuration): - err := s.establishConn(sp) - if err != nil { - retryDuration += connectionRetryInterval - if retryDuration > maxConnectionRetryInterval { - retryDuration = maxConnectionRetryInterval - } - continue - } - return - - case <-sp.quit: - return - - case <-s.quit: - return - } - } -} - // peerHandler is used to handle peer operations such as adding and removing // peers to and from the server, banning peers, and broadcasting messages to // peers. It must be run in a goroutine. @@ -1773,7 +1624,6 @@ func (s *server) peerHandler() { srvrLog.Tracef("Starting peer handler") state := &peerState{ - pendingPeers: make(map[string]*serverPeer), inboundPeers: make(map[int32]*serverPeer), persistentPeers: make(map[int32]*serverPeer), outboundPeers: make(map[int32]*serverPeer), @@ -1784,20 +1634,19 @@ func (s *server) peerHandler() { if cfg.MaxPeers < state.maxOutboundPeers { state.maxOutboundPeers = cfg.MaxPeers } - // Add peers discovered through DNS to the address manager. - s.seedFromDNS() - // Start up persistent peers. - permanentPeers := cfg.ConnectPeers - if len(permanentPeers) == 0 { - permanentPeers = cfg.AddPeers - } - for _, addr := range permanentPeers { - sp := s.newOutboundPeer(addr, true) - if sp != nil { - go s.retryConn(sp, true) - } + if !cfg.DisableDNSSeed { + // Add peers discovered through DNS to the address manager. + connmgr.SeedFromDNS(activeNetParams.Params, btcdLookup, func(addrs []*wire.NetAddress) { + // Bitcoind uses a lookup of the dns seeder here. This + // is rather strange since the values looked up by the + // DNS seed lookups will vary quite a lot. + // to replicate this behaviour we put all addresses as + // having come from the first one. + s.addrManager.AddAddresses(addrs, addrs[0]) + }) } + go s.connManager.Start() // if nothing else happens, wake us up soon. time.AfterFunc(10*time.Second, func() { s.wakeup <- struct{}{} }) @@ -1845,86 +1694,9 @@ out: }) break out } - - // Don't try to connect to more peers when running on the - // simulation test network. The simulation network is only - // intended to connect to specified peers and actively avoid - // advertising and connecting to discovered peers. - if cfg.SimNet { - continue - } - - // Only try connect to more peers if we actually need more. - if !state.NeedMoreOutbound() || len(cfg.ConnectPeers) > 0 || - atomic.LoadInt32(&s.shutdown) != 0 { - state.forPendingPeers(func(sp *serverPeer) { - srvrLog.Tracef("Shutdown peer %s", sp) - sp.Disconnect() - }) - continue - } - tries := 0 - for state.NeedMoreOutbound() && - state.NeedMoreTries() && - atomic.LoadInt32(&s.shutdown) == 0 { - addr := s.addrManager.GetAddress("any") - if addr == nil { - break - } - key := addrmgr.GroupKey(addr.NetAddress()) - // Address will not be invalid, local or unroutable - // because addrmanager rejects those on addition. - // Just check that we don't already have an address - // in the same group so that we are not connecting - // to the same network segment at the expense of - // others. - if state.outboundGroups[key] != 0 { - break - } - - tries++ - // After 100 bad tries exit the loop and we'll try again - // later. - if tries > 100 { - break - } - - // Check that we don't have a pending connection to this addr. - addrStr := addrmgr.NetAddressKey(addr.NetAddress()) - if _, ok := state.pendingPeers[addrStr]; ok { - continue - } - - // XXX if we have limited that address skip - - // only allow recent nodes (10mins) after we failed 30 - // times - if tries < 30 && time.Now().Sub(addr.LastAttempt()) < 10*time.Minute { - continue - } - - // allow nondefault ports after 50 failed tries. - if fmt.Sprintf("%d", addr.NetAddress().Port) != - activeNetParams.DefaultPort && tries < 50 { - continue - } - - tries = 0 - sp := s.newOutboundPeer(addrStr, false) - if sp != nil { - go s.peerConnHandler(sp) - state.pendingPeers[sp.Addr()] = sp - } - } - - // We need more peers, wake up in ten seconds and try again. - if state.NeedMoreOutbound() { - time.AfterFunc(10*time.Second, func() { - s.wakeup <- struct{}{} - }) - } } + s.connManager.Stop() s.blockManager.Stop() s.addrManager.Stop() @@ -1982,6 +1754,14 @@ func (s *server) ConnectedCount() int32 { return <-replyChan } +// OutboundGroupCount returns the number of peers connected to the given +// outbound group key. +func (s *server) OutboundGroupCount(key string) int { + replyChan := make(chan int) + s.query <- getOutboundGroup{key: key, reply: replyChan} + return <-replyChan +} + // AddedNodeInfo returns an array of btcjson.GetAddedNodeInfoResult structures // describing the persistent (added) nodes. func (s *server) AddedNodeInfo() []*serverPeer { @@ -2519,7 +2299,6 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param newPeers: make(chan *serverPeer, cfg.MaxPeers), donePeers: make(chan *serverPeer, cfg.MaxPeers), banPeers: make(chan *serverPeer, cfg.MaxPeers), - retryPeers: make(chan *serverPeer, cfg.MaxPeers), wakeup: make(chan struct{}), query: make(chan interface{}), relayInv: make(chan relayMsg, cfg.MaxPeers), @@ -2602,6 +2381,78 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param } s.cpuMiner = newCPUMiner(&policy, &s) + // Only setup a function to return new addresses to connect to when + // not running in connect-only mode. The simulation network is always + // in connect-only mode since it is only intended to connect to + // specified peers and actively avoid advertising and connecting to + // discovered peers in order to prevent it from becoming a public test + // network. + var newAddressFunc connmgr.AddressFunc + if !cfg.SimNet && len(cfg.ConnectPeers) == 0 { + newAddressFunc = func() (string, error) { + for tries := 0; tries < 100; tries++ { + addr := s.addrManager.GetAddress("any") + if addr == nil { + break + } + + // Address will not be invalid, local or unroutable + // because addrmanager rejects those on addition. + // Just check that we don't already have an address + // in the same group so that we are not connecting + // to the same network segment at the expense of + // others. + key := addrmgr.GroupKey(addr.NetAddress()) + if s.OutboundGroupCount(key) != 0 { + continue + } + + // only allow recent nodes (10mins) after we failed 30 + // times + if tries < 30 && time.Now().Sub(addr.LastAttempt()) < 10*time.Minute { + continue + } + + // allow nondefault ports after 50 failed tries. + if fmt.Sprintf("%d", addr.NetAddress().Port) != + activeNetParams.DefaultPort && tries < 50 { + continue + } + return addrmgr.NetAddressKey(addr.NetAddress()), nil + } + return "", errors.New("no valid connect address") + } + } + + // Create a connection manager. + cmgr, err := connmgr.New(&connmgr.Config{ + RetryDuration: connectionRetryInterval, + MaxOutbound: defaultMaxOutbound, + Dial: btcdDial, + OnConnection: func(c *connmgr.ConnReq, conn net.Conn) { + sp := s.newOutboundPeer(c.Addr, c.Permanent) + if sp != nil { + sp.AssociateConnection(conn) + sp.connReq = c + s.addrManager.Attempt(sp.NA()) + } + }, + GetNewAddress: newAddressFunc, + }) + if err != nil { + return nil, err + } + s.connManager = cmgr + + // Start up persistent peers. + permanentPeers := cfg.ConnectPeers + if len(permanentPeers) == 0 { + permanentPeers = cfg.AddPeers + } + for _, addr := range permanentPeers { + go s.connManager.Connect(&connmgr.ConnReq{Addr: addr, Permanent: true}) + } + if !cfg.DisableRPC { s.rpcServer, err = newRPCServer(cfg.RPCListeners, &policy, &s) if err != nil {