diff --git a/cmd.go b/cmd.go index 2ad885b..29d804a 100644 --- a/cmd.go +++ b/cmd.go @@ -182,7 +182,11 @@ func walletMain() error { // Start account manager and open accounts. AcctMgr.Start() - server, err = newRPCServer(cfg.SvrListeners) + server, err = newRPCServer( + cfg.SvrListeners, + cfg.RPCMaxClients, + cfg.RPCMaxWebsockets, + ) if err != nil { log.Errorf("Unable to create HTTP server: %v", err) return err diff --git a/config.go b/config.go index acf09fe..66e9a69 100644 --- a/config.go +++ b/config.go @@ -30,14 +30,16 @@ import ( ) const ( - defaultCAFilename = "btcd.cert" - defaultConfigFilename = "btcwallet.conf" - defaultBtcNet = btcwire.TestNet3 - defaultLogLevel = "info" - defaultLogDirname = "logs" - defaultLogFilename = "btcwallet.log" - defaultKeypoolSize = 100 - defaultDisallowFree = false + defaultCAFilename = "btcd.cert" + defaultConfigFilename = "btcwallet.conf" + defaultBtcNet = btcwire.TestNet3 + defaultLogLevel = "info" + defaultLogDirname = "logs" + defaultLogFilename = "btcwallet.log" + defaultKeypoolSize = 100 + defaultDisallowFree = false + defaultRPCMaxClients = 10 + defaultRPCMaxWebsockets = 25 ) var ( @@ -52,28 +54,30 @@ var ( ) type config struct { - ShowVersion bool `short:"V" long:"version" description:"Display version information and exit"` - CAFile string `long:"cafile" description:"File containing root certificates to authenticate a TLS connections with btcd"` - RPCConnect string `short:"c" long:"rpcconnect" description:"Hostname/IP and port of btcd RPC server to connect to (default localhost:18334, mainnet: localhost:8334, simnet: localhost:18556)"` - DebugLevel string `short:"d" long:"debuglevel" description:"Logging level {trace, debug, info, warn, error, critical}"` - ConfigFile string `short:"C" long:"configfile" description:"Path to configuration file"` - SvrListeners []string `long:"rpclisten" description:"Listen for RPC/websocket connections on this interface/port (default port: 18332, mainnet: 8332, simnet: 18554)"` - DataDir string `short:"D" long:"datadir" description:"Directory to store wallets and transactions"` - LogDir string `long:"logdir" description:"Directory to log output."` - Username string `short:"u" long:"username" description:"Username for client and btcd authorization"` - Password string `short:"P" long:"password" default-mask:"-" description:"Password for client and btcd authorization"` - BtcdUsername string `long:"btcdusername" description:"Alternative username for btcd authorization"` - BtcdPassword string `long:"btcdpassword" default-mask:"-" description:"Alternative password for btcd authorization"` - RPCCert string `long:"rpccert" description:"File containing the certificate file"` - RPCKey string `long:"rpckey" description:"File containing the certificate key"` - MainNet bool `long:"mainnet" description:"Use the main Bitcoin network (default testnet3)"` - SimNet bool `long:"simnet" description:"Use the simulation test network (default testnet3)"` - KeypoolSize uint `short:"k" long:"keypoolsize" description:"Maximum number of addresses in keypool"` - DisallowFree bool `long:"disallowfree" description:"Force transactions to always include a fee"` - Proxy string `long:"proxy" description:"Connect via SOCKS5 proxy (eg. 127.0.0.1:9050)"` - ProxyUser string `long:"proxyuser" description:"Username for proxy server"` - ProxyPass string `long:"proxypass" default-mask:"-" description:"Password for proxy server"` - Profile string `long:"profile" description:"Enable HTTP profiling on given port -- NOTE port must be between 1024 and 65536"` + ShowVersion bool `short:"V" long:"version" description:"Display version information and exit"` + CAFile string `long:"cafile" description:"File containing root certificates to authenticate a TLS connections with btcd"` + RPCConnect string `short:"c" long:"rpcconnect" description:"Hostname/IP and port of btcd RPC server to connect to (default localhost:18334, mainnet: localhost:8334, simnet: localhost:18556)"` + DebugLevel string `short:"d" long:"debuglevel" description:"Logging level {trace, debug, info, warn, error, critical}"` + ConfigFile string `short:"C" long:"configfile" description:"Path to configuration file"` + SvrListeners []string `long:"rpclisten" description:"Listen for RPC/websocket connections on this interface/port (default port: 18332, mainnet: 8332, simnet: 18554)"` + DataDir string `short:"D" long:"datadir" description:"Directory to store wallets and transactions"` + LogDir string `long:"logdir" description:"Directory to log output."` + Username string `short:"u" long:"username" description:"Username for client and btcd authorization"` + Password string `short:"P" long:"password" default-mask:"-" description:"Password for client and btcd authorization"` + BtcdUsername string `long:"btcdusername" description:"Alternative username for btcd authorization"` + BtcdPassword string `long:"btcdpassword" default-mask:"-" description:"Alternative password for btcd authorization"` + RPCCert string `long:"rpccert" description:"File containing the certificate file"` + RPCKey string `long:"rpckey" description:"File containing the certificate key"` + RPCMaxClients int64 `long:"rpcmaxclients" description:"Max number of RPC clients for standard connections"` + RPCMaxWebsockets int64 `long:"rpcmaxwebsockets" description:"Max number of RPC websocket connections"` + MainNet bool `long:"mainnet" description:"Use the main Bitcoin network (default testnet3)"` + SimNet bool `long:"simnet" description:"Use the simulation test network (default testnet3)"` + KeypoolSize uint `short:"k" long:"keypoolsize" description:"Maximum number of addresses in keypool"` + DisallowFree bool `long:"disallowfree" description:"Force transactions to always include a fee"` + Proxy string `long:"proxy" description:"Connect via SOCKS5 proxy (eg. 127.0.0.1:9050)"` + ProxyUser string `long:"proxyuser" description:"Username for proxy server"` + ProxyPass string `long:"proxypass" default-mask:"-" description:"Password for proxy server"` + Profile string `long:"profile" description:"Enable HTTP profiling on given port -- NOTE port must be between 1024 and 65536"` } // cleanAndExpandPath expands environement variables and leading ~ in the @@ -233,14 +237,16 @@ func normalizeAddress(addr, defaultPort string) string { func loadConfig() (*config, []string, error) { // Default config. cfg := config{ - DebugLevel: defaultLogLevel, - ConfigFile: defaultConfigFile, - DataDir: defaultDataDir, - LogDir: defaultLogDir, - RPCKey: defaultRPCKeyFile, - RPCCert: defaultRPCCertFile, - KeypoolSize: defaultKeypoolSize, - DisallowFree: defaultDisallowFree, + DebugLevel: defaultLogLevel, + ConfigFile: defaultConfigFile, + DataDir: defaultDataDir, + LogDir: defaultLogDir, + RPCKey: defaultRPCKeyFile, + RPCCert: defaultRPCCertFile, + KeypoolSize: defaultKeypoolSize, + DisallowFree: defaultDisallowFree, + RPCMaxClients: defaultRPCMaxClients, + RPCMaxWebsockets: defaultRPCMaxWebsockets, } // A config file in the current directory takes precedence. diff --git a/rpcserver.go b/rpcserver.go index d1bcb02..e8b1e39 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -35,6 +35,7 @@ import ( "path/filepath" "runtime" "sync" + "sync/atomic" "time" "github.com/conformal/btcec" @@ -206,10 +207,12 @@ func genCertPair(certFile, keyFile string) error { // rpcServer holds the items the RPC server may need to access (auth, // config, shutdown, etc.) type rpcServer struct { - wg sync.WaitGroup - listeners []net.Listener - authsha [sha256.Size]byte - wsClients map[*websocketClient]struct{} + wg sync.WaitGroup + maxClients int64 // Maximum number of concurrent active RPC HTTP clients + maxWebsockets int64 // Maximum number of concurrent active RPC WS clients + listeners []net.Listener + authsha [sha256.Size]byte + wsClients map[*websocketClient]struct{} upgrader websocket.Upgrader @@ -224,12 +227,14 @@ type rpcServer struct { // newRPCServer creates a new server for serving RPC client connections, both // HTTP POST and websocket. -func newRPCServer(listenAddrs []string) (*rpcServer, error) { +func newRPCServer(listenAddrs []string, maxClients, maxWebsockets int64) (*rpcServer, error) { login := cfg.Username + ":" + cfg.Password auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(login)) s := rpcServer{ - authsha: sha256.Sum256([]byte(auth)), - wsClients: map[*websocketClient]struct{}{}, + authsha: sha256.Sum256([]byte(auth)), + maxClients: maxClients, + maxWebsockets: maxWebsockets, + wsClients: map[*websocketClient]struct{}{}, upgrader: websocket.Upgrader{ // Allow all origins. CheckOrigin: func(r *http.Request) bool { return true }, @@ -303,6 +308,7 @@ func (s *rpcServer) Start() { serveMux := http.NewServeMux() const rpcAuthTimeoutSeconds = 10 + httpServer := &http.Server{ Handler: serveMux, @@ -310,45 +316,50 @@ func (s *rpcServer) Start() { // handshake within the allowed timeframe. ReadTimeout: time.Second * rpcAuthTimeoutSeconds, } - serveMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Connection", "close") - w.Header().Set("Content-Type", "application/json") - r.Close = true - // TODO: Limit number of active connections. + serveMux.Handle("/", + throttledFn(s.maxClients, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Connection", "close") + w.Header().Set("Content-Type", "application/json") + r.Close = true - if err := s.checkAuthHeader(r); err != nil { - log.Warnf("Unauthorized client connection attempt") - http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) - return - } - s.PostClientRPC(w, r) - }) - serveMux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { - authenticated := false - switch s.checkAuthHeader(r) { - case nil: - authenticated = true - case ErrNoAuth: - // nothing - default: - // If auth was supplied but incorrect, rather than simply - // being missing, immediately terminate the connection. - log.Warnf("Disconnecting improperly authorized " + - "websocket client") - http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) - return - } + if err := s.checkAuthHeader(r); err != nil { + log.Warnf("Unauthorized client connection attempt") + http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) + return + } + s.PostClientRPC(w, r) + }), + ) + + serveMux.Handle("/ws", + throttledFn(s.maxWebsockets, func(w http.ResponseWriter, r *http.Request) { + authenticated := false + switch s.checkAuthHeader(r) { + case nil: + authenticated = true + case ErrNoAuth: + // nothing + default: + // If auth was supplied but incorrect, rather than simply + // being missing, immediately terminate the connection. + log.Warnf("Disconnecting improperly authorized " + + "websocket client") + http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) + return + } + + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Warnf("Cannot websocket upgrade client %s: %v", + r.RemoteAddr, err) + return + } + wsc := newWebsocketClient(conn, authenticated, r.RemoteAddr) + s.WebsocketClientRPC(wsc) + }), + ) - conn, err := s.upgrader.Upgrade(w, r, nil) - if err != nil { - log.Warnf("Cannot websocket upgrade client %s: %v", - r.RemoteAddr, err) - return - } - wsc := newWebsocketClient(conn, authenticated, r.RemoteAddr) - s.WebsocketClientRPC(wsc) - }) for _, listener := range s.listeners { s.wg.Add(1) go func(listener net.Listener) { @@ -428,6 +439,31 @@ func (s *rpcServer) checkAuthHeader(r *http.Request) error { return nil } +// throttledFn wraps an http.HandlerFunc with throttling of concurrent active +// clients by responding with an HTTP 429 when the threshold is crossed. +func throttledFn(threshold int64, f http.HandlerFunc) http.Handler { + return throttled(threshold, f) +} + +// throttled wraps an http.Handler with throttling of concurrent active +// clients by responding with an HTTP 429 when the threshold is crossed. +func throttled(threshold int64, h http.Handler) http.Handler { + var active int64 + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + current := atomic.AddInt64(&active, 1) + defer atomic.AddInt64(&active, -1) + + if current-1 >= threshold { + log.Warnf("Reached threshold of %d concurrent active clients", threshold) + http.Error(w, "429 Too Many Requests", 429) + return + } + + h.ServeHTTP(w, r) + }) +} + func (s *rpcServer) WebsocketClientRead(wsc *websocketClient) { for { _, request, err := wsc.conn.ReadMessage() @@ -746,6 +782,8 @@ func (s *rpcServer) WebsocketClientRPC(wsc *websocketClient) { // Send initial unsolicited notifications. // TODO: these should be requested by the client first. s.NotifyConnectionStatus(wsc) + + <-wsc.quit } // maxRequestSize specifies the maximum number of bytes in the request body diff --git a/rpcserver_test.go b/rpcserver_test.go new file mode 100644 index 0000000..5a86fee --- /dev/null +++ b/rpcserver_test.go @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2013, 2014 Conformal Systems LLC + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +package main + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" +) + +func TestThrottle(t *testing.T) { + const threshold = 1 + + srv := httptest.NewServer(throttledFn(threshold, + func(w http.ResponseWriter, r *http.Request) { + time.Sleep(20 * time.Millisecond) + }), + ) + + codes := make(chan int, 2) + for i := 0; i < cap(codes); i++ { + go func() { + res, err := http.Get(srv.URL) + if err != nil { + t.Fatal(err) + } + codes <- res.StatusCode + }() + } + + got := make(map[int]int, cap(codes)) + for i := 0; i < cap(codes); i++ { + got[<-codes]++ + } + + want := map[int]int{200: 1, 429: 1} + if !reflect.DeepEqual(want, got) { + t.Fatalf("status codes: want: %v, got: %v", want, got) + } +}