diff --git a/example_config.yaml b/example_config.yaml index 2fe4a6b..b6035b8 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -25,9 +25,13 @@ chihaya: # If you do not wish to run this, delete this section. http: # The network interface that will bind to an HTTP server for serving - # BitTorrent traffic. + # BitTorrent traffic. Remove this to disable the non-TLS listener. addr: "0.0.0.0:6969" + # The network interface that will bind to an HTTPS server for serving + # BitTorrent traffic. If set, tls_cert_path and tls_key_path are required. + https_addr: "" + # The path to the required files to listen via HTTPS. tls_cert_path: "" tls_key_path: "" diff --git a/frontend/http/frontend.go b/frontend/http/frontend.go index 6dcd0df..743616c 100644 --- a/frontend/http/frontend.go +++ b/frontend/http/frontend.go @@ -5,6 +5,7 @@ package http import ( "context" "crypto/tls" + "errors" "net" "net/http" "time" @@ -21,6 +22,7 @@ import ( // Frontend. type Config struct { Addr string `yaml:"addr"` + HTTPSAddr string `yaml:"https_addr"` ReadTimeout time.Duration `yaml:"read_timeout"` WriteTimeout time.Duration `yaml:"write_timeout"` IdleTimeout time.Duration `yaml:"idle_timeout"` @@ -36,6 +38,7 @@ type Config struct { func (cfg Config) LogFields() log.Fields { return log.Fields{ "addr": cfg.Addr, + "httpsAddr": cfg.HTTPSAddr, "readTimeout": cfg.ReadTimeout, "writeTimeout": cfg.WriteTimeout, "idleTimeout": cfg.IdleTimeout, @@ -103,6 +106,7 @@ func (cfg Config) Validate() Config { // Frontend represents the state of an HTTP BitTorrent Frontend. type Frontend struct { srv *http.Server + tlsSrv *http.Server tlsCfg *tls.Config logic frontend.TrackerLogic @@ -119,6 +123,10 @@ func NewFrontend(logic frontend.TrackerLogic, provided Config) (*Frontend, error Config: cfg, } + if cfg.Addr == "" && cfg.HTTPSAddr == "" { + return nil, errors.New("must specify addr or https_addr or both") + } + // If TLS is enabled, create a key pair. if cfg.TLSCertPath != "" && cfg.TLSKeyPath != "" { var err error @@ -131,23 +139,54 @@ func NewFrontend(logic frontend.TrackerLogic, provided Config) (*Frontend, error } } - go func() { - if err := f.listenAndServe(); err != nil { - log.Fatal("failed while serving http", log.Err(err)) - } - }() + if cfg.HTTPSAddr != "" && f.tlsCfg == nil { + return nil, errors.New("must specify tls_cert_path and tls_key_path when using https_addr") + } + if cfg.HTTPSAddr == "" && f.tlsCfg != nil { + return nil, errors.New("must specify https_addr when using tls_cert_path and tls_key_path") + } + + if cfg.Addr != "" { + go func() { + if err := f.listenAndServe(); err != nil { + log.Fatal("failed while serving http", log.Err(err)) + } + }() + } + + if cfg.HTTPSAddr != "" { + go func() { + if err := f.listenAndServeTLS(); err != nil { + log.Fatal("failed while serving https", log.Err(err)) + } + }() + } return f, nil } // Stop provides a thread-safe way to shutdown a currently running Frontend. func (f *Frontend) Stop() stop.Result { - c := make(stop.Channel) - go func() { - c.Done(f.srv.Shutdown(context.Background())) - }() + stopGroup := stop.NewGroup() - return c.Result() + if f.srv != nil { + stopGroup.AddFunc(f.makeStopFunc(f.srv)) + } + if f.tlsSrv != nil { + stopGroup.AddFunc(f.makeStopFunc(f.tlsSrv)) + } + + return stopGroup.Stop() +} + +func (f *Frontend) makeStopFunc(stopSrv *http.Server) stop.Func { + return func() stop.Result { + c := make(stop.Channel) + go func() { + c.Done(stopSrv.Shutdown(context.Background())) + }() + return c.Result() + } } func (f *Frontend) handler() http.Handler { @@ -164,12 +203,11 @@ func (f *Frontend) handler() http.Handler { return router } -// listenAndServe blocks while listening and serving HTTP BitTorrent requests -// until Stop() is called or an error is returned. +// listenAndServe blocks while listening and serving non-TLS HTTP BitTorrent +// requests until Stop() is called or an error is returned. func (f *Frontend) listenAndServe() error { f.srv = &http.Server{ Addr: f.Addr, - TLSConfig: f.tlsCfg, Handler: f.handler(), ReadTimeout: f.ReadTimeout, WriteTimeout: f.WriteTimeout, @@ -179,18 +217,30 @@ func (f *Frontend) listenAndServe() error { f.srv.SetKeepAlivesEnabled(f.EnableKeepAlive) // Start the HTTP server. - if f.tlsCfg != nil { - // ... using TLS. - if err := f.srv.ListenAndServeTLS("", ""); err != http.ErrServerClosed { - return err - } - } else { - // ... using plain TCP. - if err := f.srv.ListenAndServe(); err != http.ErrServerClosed { - return err - } + if err := f.srv.ListenAndServe(); err != http.ErrServerClosed { + return err + } + return nil +} + +// listenAndServeTLS blocks while listening and serving TLS HTTP BitTorrent +// requests until Stop() is called or an error is returned. +func (f *Frontend) listenAndServeTLS() error { + f.tlsSrv = &http.Server{ + Addr: f.HTTPSAddr, + TLSConfig: f.tlsCfg, + Handler: f.handler(), + ReadTimeout: f.ReadTimeout, + WriteTimeout: f.WriteTimeout, } + // Disable KeepAlives. + f.tlsSrv.SetKeepAlivesEnabled(f.EnableKeepAlive) + + // Start the HTTP server. + if err := f.tlsSrv.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + return err + } return nil }