diff --git a/config.go b/config.go index 6bc1fad..f11e008 100644 --- a/config.go +++ b/config.go @@ -69,6 +69,7 @@ type config struct { RPCKey string `long:"rpckey" description:"File containing the certificate key"` RPCMaxClients int64 `long:"rpcmaxclients" description:"Max number of RPC clients for standard connections"` RPCMaxWebsockets int64 `long:"rpcmaxwebsockets" description:"Max number of RPC websocket connections"` + DisableServerTLS bool `long:"noservertls" description:"Disable TLS for the RPC server -- NOTE: This is only allowed if the RPC server is bound to localhost"` MainNet bool `long:"mainnet" description:"Use the main Bitcoin network (default testnet3)"` SimNet bool `long:"simnet" description:"Use the simulation test network (default testnet3)"` KeypoolSize uint `short:"k" long:"keypoolsize" description:"DEPRECATED -- Maximum number of addresses in keypool"` @@ -265,9 +266,11 @@ func loadConfig() (*config, []string, error) { } // Show the version and exit if the version flag was specified. + funcName := "loadConfig" + appName := filepath.Base(os.Args[0]) + appName = strings.TrimSuffix(appName, filepath.Ext(appName)) + usageMessage := fmt.Sprintf("Use %s -h to show usage", appName) if preCfg.ShowVersion { - appName := filepath.Base(os.Args[0]) - appName = strings.TrimSuffix(appName, filepath.Ext(appName)) fmt.Println(appName, "version", version()) os.Exit(0) } @@ -363,6 +366,11 @@ func loadConfig() (*config, []string, error) { // Add default port to connect flag if missing. cfg.RPCConnect = normalizeAddress(cfg.RPCConnect, activeNet.btcdPort) + localhostListeners := map[string]struct{}{ + "localhost": struct{}{}, + "127.0.0.1": struct{}{}, + "::1": struct{}{}, + } // If CAFile is unset, choose either the copy or local btcd cert. if cfg.CAFile == "" { cfg.CAFile = filepath.Join(cfg.DataDir, defaultCAFilename) @@ -406,6 +414,31 @@ func loadConfig() (*config, []string, error) { cfg.SvrListeners = normalizeAddresses(cfg.SvrListeners, activeNet.svrPort) + // Only allow server TLS to be disabled if the RPC is bound to localhost + // addresses. + if cfg.DisableServerTLS { + for _, addr := range cfg.SvrListeners { + host, _, err := net.SplitHostPort(addr) + if err != nil { + str := "%s: RPC listen interface '%s' is " + + "invalid: %v" + err := fmt.Errorf(str, funcName, addr, err) + fmt.Fprintln(os.Stderr, err) + fmt.Fprintln(os.Stderr, usageMessage) + return nil, nil, err + } + if _, ok := localhostListeners[host]; !ok { + str := "%s: the --noservertls option may not be used " + + "when binding RPC to non localhost " + + "addresses: %s" + err := fmt.Errorf(str, funcName, addr) + fmt.Fprintln(os.Stderr, err) + fmt.Fprintln(os.Stderr, usageMessage) + return nil, nil, err + } + } + } + // Expand environment variable and leading ~ for filepaths. cfg.CAFile = cleanAndExpandPath(cfg.CAFile) diff --git a/rpcserver.go b/rpcserver.go index 3c5864a..aee7636 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -318,29 +318,41 @@ func newRPCServer(listenAddrs []string, maxPost, maxWebsockets int64) (*rpcServe quit: make(chan struct{}), } - // Check for existence of cert file and key file - if !fileExists(cfg.RPCKey) && !fileExists(cfg.RPCCert) { - // if both files do not exist, we generate them. - err := genCertPair(cfg.RPCCert, cfg.RPCKey) + // Setup TLS if not disabled. + listenFunc := net.Listen + if !cfg.DisableServerTLS { + // Check for existence of cert file and key file + if !fileExists(cfg.RPCKey) && !fileExists(cfg.RPCCert) { + // if both files do not exist, we generate them. + err := genCertPair(cfg.RPCCert, cfg.RPCKey) + if err != nil { + return nil, err + } + } + keypair, err := tls.LoadX509KeyPair(cfg.RPCCert, cfg.RPCKey) if err != nil { return nil, err } - } - keypair, err := tls.LoadX509KeyPair(cfg.RPCCert, cfg.RPCKey) - if err != nil { - return nil, err - } - tlsConfig := tls.Config{ - Certificates: []tls.Certificate{keypair}, - MinVersion: tls.VersionTLS12, + tlsConfig := tls.Config{ + Certificates: []tls.Certificate{keypair}, + MinVersion: tls.VersionTLS12, + } + + // Change the standard net.Listen function to the tls one. + listenFunc = func(net string, laddr string) (net.Listener, error) { + return tls.Listen(net, laddr, &tlsConfig) + } } ipv4ListenAddrs, ipv6ListenAddrs, err := parseListeners(listenAddrs) + if err != nil { + return nil, err + } listeners := make([]net.Listener, 0, len(ipv6ListenAddrs)+len(ipv4ListenAddrs)) for _, addr := range ipv4ListenAddrs { - listener, err := tls.Listen("tcp4", addr, &tlsConfig) + listener, err := listenFunc("tcp4", addr) if err != nil { log.Warnf("RPCS: Can't listen on %s: %v", addr, err) @@ -350,7 +362,7 @@ func newRPCServer(listenAddrs []string, maxPost, maxWebsockets int64) (*rpcServe } for _, addr := range ipv6ListenAddrs { - listener, err := tls.Listen("tcp6", addr, &tlsConfig) + listener, err := listenFunc("tcp6", addr) if err != nil { log.Warnf("RPCS: Can't listen on %s: %v", addr, err)