diff --git a/cmd/chihaya/config.go b/cmd/chihaya/config.go index bc50f1e..4fb60e2 100644 --- a/cmd/chihaya/config.go +++ b/cmd/chihaya/config.go @@ -21,52 +21,21 @@ type hookConfig struct { Config interface{} `yaml:"config"` } -// ConfigFile represents a namespaced YAML configation file. -type ConfigFile struct { - MainConfigBlock struct { - middleware.Config `yaml:",inline"` - PrometheusAddr string `yaml:"prometheus_addr"` - HTTPConfig httpfrontend.Config `yaml:"http"` - UDPConfig udpfrontend.Config `yaml:"udp"` - Storage memory.Config `yaml:"storage"` - PreHooks []hookConfig `yaml:"prehooks"` - PostHooks []hookConfig `yaml:"posthooks"` - } `yaml:"chihaya"` -} - -// ParseConfigFile returns a new ConfigFile given the path to a YAML -// configuration file. -// -// It supports relative and absolute paths and environment variables. -func ParseConfigFile(path string) (*ConfigFile, error) { - if path == "" { - return nil, errors.New("no config path specified") - } - - f, err := os.Open(os.ExpandEnv(path)) - if err != nil { - return nil, err - } - defer f.Close() - - contents, err := ioutil.ReadAll(f) - if err != nil { - return nil, err - } - - var cfgFile ConfigFile - err = yaml.Unmarshal(contents, &cfgFile) - if err != nil { - return nil, err - } - - return &cfgFile, nil +// Config represents the configuration used for executing Chihaya. +type Config struct { + middleware.Config `yaml:",inline"` + PrometheusAddr string `yaml:"prometheus_addr"` + HTTPConfig httpfrontend.Config `yaml:"http"` + UDPConfig udpfrontend.Config `yaml:"udp"` + Storage memory.Config `yaml:"storage"` + PreHooks []hookConfig `yaml:"prehooks"` + PostHooks []hookConfig `yaml:"posthooks"` } // CreateHooks creates instances of Hooks for all of the PreHooks and PostHooks -// configured in a ConfigFile. -func (cfg ConfigFile) CreateHooks() (preHooks, postHooks []middleware.Hook, err error) { - for _, hookCfg := range cfg.MainConfigBlock.PreHooks { +// configured in a Config. +func (cfg Config) CreateHooks() (preHooks, postHooks []middleware.Hook, err error) { + for _, hookCfg := range cfg.PreHooks { cfgBytes, err := yaml.Marshal(hookCfg.Config) if err != nil { panic("failed to remarshal valid YAML") @@ -109,10 +78,44 @@ func (cfg ConfigFile) CreateHooks() (preHooks, postHooks []middleware.Hook, err } } - for _, hookCfg := range cfg.MainConfigBlock.PostHooks { + for _, hookCfg := range cfg.PostHooks { switch hookCfg.Name { } } return } + +// ConfigFile represents a namespaced YAML configation file. +type ConfigFile struct { + Chihaya Config `yaml:"chihaya"` +} + +// ParseConfigFile returns a new ConfigFile given the path to a YAML +// configuration file. +// +// It supports relative and absolute paths and environment variables. +func ParseConfigFile(path string) (*ConfigFile, error) { + if path == "" { + return nil, errors.New("no config path specified") + } + + f, err := os.Open(os.ExpandEnv(path)) + if err != nil { + return nil, err + } + defer f.Close() + + contents, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + + var cfgFile ConfigFile + err = yaml.Unmarshal(contents, &cfgFile) + if err != nil { + return nil, err + } + + return &cfgFile, nil +} diff --git a/cmd/chihaya/main.go b/cmd/chihaya/main.go index 6033f19..3d41148 100644 --- a/cmd/chihaya/main.go +++ b/cmd/chihaya/main.go @@ -2,29 +2,123 @@ package main import ( "errors" - "net/http" "os" "os/signal" "runtime/pprof" + "strings" "syscall" log "github.com/Sirupsen/logrus" - "github.com/prometheus/client_golang/prometheus" "github.com/spf13/cobra" - httpfrontend "github.com/chihaya/chihaya/frontend/http" - udpfrontend "github.com/chihaya/chihaya/frontend/udp" + "github.com/chihaya/chihaya/frontend/http" + "github.com/chihaya/chihaya/frontend/udp" "github.com/chihaya/chihaya/middleware" + "github.com/chihaya/chihaya/pkg/prometheus" + "github.com/chihaya/chihaya/pkg/stop" "github.com/chihaya/chihaya/storage" "github.com/chihaya/chihaya/storage/memory" ) -func rootCmdRun(cmd *cobra.Command, args []string) error { - debugLog, _ := cmd.Flags().GetBool("debug") - if debugLog { - log.SetLevel(log.DebugLevel) - log.Debugln("debug logging enabled") +// Run represents the state of a running instance of Chihaya. +type Run struct { + configFilePath string + peerStore storage.PeerStore + logic *middleware.Logic + sg *stop.Group +} + +// NewRun runs an instance of Chihaya. +func NewRun(configFilePath string) (*Run, error) { + r := &Run{ + configFilePath: configFilePath, } + + return r, r.Start(nil) +} + +// Start begins an instance of Chihaya. +// It is optional to provide an instance of the peer store to avoid the +// creation of a new one. +func (r *Run) Start(ps storage.PeerStore) error { + configFile, err := ParseConfigFile(r.configFilePath) + if err != nil { + return errors.New("failed to read config: " + err.Error()) + } + + cfg := configFile.Chihaya + preHooks, postHooks, err := cfg.CreateHooks() + if err != nil { + return errors.New("failed to validate hook config: " + err.Error()) + } + + r.sg = stop.NewGroup() + r.sg.Add(prometheus.NewServer(cfg.PrometheusAddr)) + + if ps == nil { + ps, err = memory.New(cfg.Storage) + if err != nil { + return errors.New("failed to create memory storage: " + err.Error()) + } + } + r.peerStore = ps + + r.logic = middleware.NewLogic(cfg.Config, r.peerStore, preHooks, postHooks) + + if cfg.HTTPConfig.Addr != "" { + httpfe, err := http.NewFrontend(r.logic, cfg.HTTPConfig) + if err != nil { + return err + } + r.sg.Add(httpfe) + } + + if cfg.UDPConfig.Addr != "" { + udpfe, err := udp.NewFrontend(r.logic, cfg.UDPConfig) + if err != nil { + return err + } + r.sg.Add(udpfe) + } + + return nil +} + +func combineErrors(prefix string, errs []error) error { + var errStrs []string + for _, err := range errs { + errStrs = append(errStrs, err.Error()) + } + + return errors.New(prefix + ": " + strings.Join(errStrs, "; ")) +} + +// Stop shuts down an instance of Chihaya. +func (r *Run) Stop(keepPeerStore bool) (storage.PeerStore, error) { + log.Debug("stopping frontends and prometheus endpoint") + if errs := r.sg.Stop(); len(errs) != 0 { + return nil, combineErrors("failed while shutting down frontends", errs) + } + + log.Debug("stopping logic") + if errs := r.logic.Stop(); len(errs) != 0 { + return nil, combineErrors("failed while shutting down middleware", errs) + } + + if !keepPeerStore { + log.Debug("stopping peer store") + if err, closed := <-r.peerStore.Stop(); !closed { + return nil, err + } + r.peerStore = nil + } + + return r.peerStore, nil +} + +// RunCmdFunc implements a Cobra command that runs an instance of Chihaya and +// handles reloading and shutdown via process signals. +func RunCmdFunc(cmd *cobra.Command, args []string) error { cpuProfilePath, _ := cmd.Flags().GetString("cpuprofile") if cpuProfilePath != "" { log.Infoln("enabled CPU profiling to", cpuProfilePath) @@ -36,162 +130,43 @@ func rootCmdRun(cmd *cobra.Command, args []string) error { defer pprof.StopCPUProfile() } - configFilePath, _ := cmd.Flags().GetString("config") - configFile, err := ParseConfigFile(configFilePath) + configFilePath, err := cmd.Flags().GetString("config") if err != nil { - return errors.New("failed to read config: " + err.Error()) - } - cfg := configFile.MainConfigBlock - - go func() { - promServer := http.Server{ - Addr: cfg.PrometheusAddr, - Handler: prometheus.Handler(), - } - log.Infoln("started serving prometheus stats on", cfg.PrometheusAddr) - if err := promServer.ListenAndServe(); err != nil { - log.Fatalln("failed to start prometheus server:", err.Error()) - } - }() - - peerStore, err := memory.New(cfg.Storage) - if err != nil { - return errors.New("failed to create memory storage: " + err.Error()) + return err } - preHooks, postHooks, err := configFile.CreateHooks() + r, err := NewRun(configFilePath) if err != nil { - return errors.New("failed to create hooks: " + err.Error()) + return err } - logic := middleware.NewLogic(cfg.Config, peerStore, preHooks, postHooks) - - errChan := make(chan error) - - httpFrontend, udpFrontend := startFrontends(cfg.HTTPConfig, cfg.UDPConfig, logic, errChan) - - shutdown := make(chan struct{}) quit := make(chan os.Signal) - restart := make(chan os.Signal) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - signal.Notify(restart, syscall.SIGUSR1) - go func() { - for { - select { - case <-restart: - log.Info("Got signal to restart") + reload := make(chan os.Signal) + signal.Notify(reload, syscall.SIGUSR1) - // Reload config - configFile, err = ParseConfigFile(configFilePath) - if err != nil { - log.Error("failed to read config: " + err.Error()) - } - cfg = configFile.MainConfigBlock - - preHooks, postHooks, err = configFile.CreateHooks() - if err != nil { - log.Error("failed to create hooks: " + err.Error()) - } - - // Stop frontends and logic - stopFrontends(udpFrontend, httpFrontend) - - stopLogic(logic, errChan) - - // Restart - log.Debug("Restarting logic") - logic = middleware.NewLogic(cfg.Config, peerStore, preHooks, postHooks) - - log.Debug("Restarting frontends") - httpFrontend, udpFrontend = startFrontends(cfg.HTTPConfig, cfg.UDPConfig, logic, errChan) - - log.Debug("Successfully restarted") - - case <-quit: - stop(udpFrontend, httpFrontend, logic, errChan, peerStore) - case <-shutdown: - stop(udpFrontend, httpFrontend, logic, errChan, peerStore) + for { + select { + case <-reload: + log.Info("received SIGUSR1") + peerStore, err := r.Stop(true) + if err != nil { + return err } - } - }() - closed := false - var bufErr error - for err = range errChan { - if err != nil { - if !closed { - close(shutdown) - closed = true - } else { - log.Errorln(bufErr) + if err := r.Start(peerStore); err != nil { + return err } - bufErr = err + case <-quit: + log.Info("received SIGINT/SIGTERM") + if _, err := r.Stop(false); err != nil { + return err + } + + return nil } } - - return bufErr -} - -func stopFrontends(udpFrontend *udpfrontend.Frontend, httpFrontend *httpfrontend.Frontend) { - log.Debug("Stopping frontends") - if udpFrontend != nil { - udpFrontend.Stop() - } - - if httpFrontend != nil { - httpFrontend.Stop() - } -} - -func stopLogic(logic *middleware.Logic, errChan chan error) { - log.Debug("Stopping logic") - errs := logic.Stop() - for _, err := range errs { - errChan <- err - } -} - -func stop(udpFrontend *udpfrontend.Frontend, httpFrontend *httpfrontend.Frontend, logic *middleware.Logic, errChan chan error, peerStore storage.PeerStore) { - stopFrontends(udpFrontend, httpFrontend) - - stopLogic(logic, errChan) - - // Stop storage - log.Debug("Stopping storage") - for err := range peerStore.Stop() { - if err != nil { - errChan <- err - } - } - - close(errChan) -} - -func startFrontends(httpConfig httpfrontend.Config, udpConfig udpfrontend.Config, logic *middleware.Logic, errChan chan<- error) (httpFrontend *httpfrontend.Frontend, udpFrontend *udpfrontend.Frontend) { - if httpConfig.Addr != "" { - httpFrontend = httpfrontend.NewFrontend(logic, httpConfig) - - go func() { - log.Infoln("started serving HTTP on", httpConfig.Addr) - if err := httpFrontend.ListenAndServe(); err != nil { - errChan <- err - } - }() - } - - if udpConfig.Addr != "" { - udpFrontend = udpfrontend.NewFrontend(logic, udpConfig) - - go func() { - log.Infoln("started serving UDP on", udpConfig.Addr) - if err := udpFrontend.ListenAndServe(); err != nil { - errChan <- err - } - }() - } - - return } func main() { @@ -199,17 +174,20 @@ func main() { Use: "chihaya", Short: "BitTorrent Tracker", Long: "A customizable, multi-protocol BitTorrent Tracker", - Run: func(cmd *cobra.Command, args []string) { - if err := rootCmdRun(cmd, args); err != nil { - log.Fatal(err) + PersistentPreRun: func(cmd *cobra.Command, args []string) { + debugLog, _ := cmd.Flags().GetBool("debug") + if debugLog { + log.SetLevel(log.DebugLevel) + log.Debugln("debug logging enabled") } }, + RunE: RunCmdFunc, } rootCmd.Flags().String("config", "/etc/chihaya.yaml", "location of configuration file") rootCmd.Flags().String("cpuprofile", "", "location to save a CPU profile") rootCmd.Flags().Bool("debug", false, "enable debug logging") if err := rootCmd.Execute(); err != nil { - log.Fatal(err) + log.Fatal("failed when executing root cobra command: " + err.Error()) } } diff --git a/frontend/http/frontend.go b/frontend/http/frontend.go index 86a4a62..c2dc3e4 100644 --- a/frontend/http/frontend.go +++ b/frontend/http/frontend.go @@ -5,7 +5,6 @@ package http import ( "context" "crypto/tls" - "errors" "net" "net/http" "time" @@ -72,94 +71,95 @@ type Config struct { TLSKeyPath string `yaml:"tls_key_path"` } -// Frontend holds the state of an HTTP BitTorrent Frontend. +// Frontend represents the state of an HTTP BitTorrent Frontend. type Frontend struct { - s *http.Server + srv *http.Server + tlsCfg *tls.Config logic frontend.TrackerLogic Config } -// NewFrontend allocates a new instance of a Frontend. -func NewFrontend(logic frontend.TrackerLogic, cfg Config) *Frontend { - return &Frontend{ +// NewFrontend creates a new instance of an HTTP Frontend that asynchronously +// serves requests. +func NewFrontend(logic frontend.TrackerLogic, cfg Config) (*Frontend, error) { + f := &Frontend{ logic: logic, Config: cfg, } + + // If TLS is enabled, create a key pair. + if cfg.TLSCertPath != "" && cfg.TLSKeyPath != "" { + var err error + f.tlsCfg = &tls.Config{ + Certificates: make([]tls.Certificate, 1), + } + f.tlsCfg.Certificates[0], err = tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath) + if err != nil { + return nil, err + } + } + + go func() { + if err := f.listenAndServe(); err != nil { + log.Fatal("failed while serving http: " + err.Error()) + } + }() + + return f, nil } // Stop provides a thread-safe way to shutdown a currently running Frontend. -func (t *Frontend) Stop() { - if err := t.s.Shutdown(context.Background()); err != nil { - log.Warn("Error shutting down HTTP frontend:", err) - } +func (f *Frontend) Stop() <-chan error { + c := make(chan error) + go func() { + if err := f.srv.Shutdown(context.Background()); err != nil { + c <- err + } else { + close(c) + } + }() + + return c } -func (t *Frontend) handler() http.Handler { +func (f *Frontend) handler() http.Handler { router := httprouter.New() - router.GET("/announce", t.announceRoute) - router.GET("/scrape", t.scrapeRoute) + router.GET("/announce", f.announceRoute) + router.GET("/scrape", f.scrapeRoute) return router } -// ListenAndServe listens on the TCP network address t.Addr and blocks serving -// BitTorrent requests until t.Stop() is called or an error is returned. -func (t *Frontend) ListenAndServe() error { - t.s = &http.Server{ - Addr: t.Addr, - Handler: t.handler(), - ReadTimeout: t.ReadTimeout, - WriteTimeout: t.WriteTimeout, - ConnState: func(conn net.Conn, state http.ConnState) { - switch state { - case http.StateNew: - //stats.RecordEvent(stats.AcceptedConnection) - - case http.StateClosed: - //stats.RecordEvent(stats.ClosedConnection) - - case http.StateHijacked: - panic("http: connection impossibly hijacked") - - // Ignore the following cases. - case http.StateActive, http.StateIdle: - - default: - panic("http: connection transitioned to unknown state") - } - }, - } - t.s.SetKeepAlivesEnabled(false) - - // If TLS is enabled, create a key pair and add it to the HTTP server. - if t.Config.TLSCertPath != "" && t.Config.TLSKeyPath != "" { - var err error - tlsCfg := &tls.Config{ - Certificates: make([]tls.Certificate, 1), - } - tlsCfg.Certificates[0], err = tls.LoadX509KeyPair(t.Config.TLSCertPath, t.Config.TLSKeyPath) - if err != nil { - return err - } - t.s.TLSConfig = tlsCfg +// listenAndServe blocks while listening and serving HTTP BitTorrent requests +// until Stop() is called or an error is returned. +func (f *Frontend) listenAndServe() error { + f.srv = &http.Server{ + Addr: f.Addr, + TLSConfig: f.tlsCfg, + Handler: f.handler(), + ReadTimeout: f.ReadTimeout, + WriteTimeout: f.WriteTimeout, } - // Start the HTTP server and gracefully handle any network errors. - if err := t.s.ListenAndServe(); err != nil && err != http.ErrServerClosed { - return errors.New("http: failed to run HTTP server: " + err.Error()) + // Disable KeepAlives. + f.srv.SetKeepAlivesEnabled(false) + + // Start the HTTP server. + if err := f.srv.ListenAndServe(); err != http.ErrServerClosed { + return err } return nil } -// announceRoute parses and responds to an Announce by using t.TrackerLogic. -func (t *Frontend) announceRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +// announceRoute parses and responds to an Announce. +func (f *Frontend) announceRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { var err error start := time.Now() var af *bittorrent.AddressFamily defer func() { recordResponseDuration("announce", af, err, time.Since(start)) }() - req, err := ParseAnnounce(r, t.RealIPHeader, t.AllowIPSpoofing) + req, err := ParseAnnounce(r, f.RealIPHeader, f.AllowIPSpoofing) if err != nil { WriteError(w, err) return @@ -167,7 +167,7 @@ func (t *Frontend) announceRoute(w http.ResponseWriter, r *http.Request, _ httpr af = new(bittorrent.AddressFamily) *af = req.IP.AddressFamily - resp, err := t.logic.HandleAnnounce(context.Background(), req) + resp, err := f.logic.HandleAnnounce(context.Background(), req) if err != nil { WriteError(w, err) return @@ -179,11 +179,11 @@ func (t *Frontend) announceRoute(w http.ResponseWriter, r *http.Request, _ httpr return } - go t.logic.AfterAnnounce(context.Background(), req, resp) + go f.logic.AfterAnnounce(context.Background(), req, resp) } -// scrapeRoute parses and responds to a Scrape by using t.TrackerLogic. -func (t *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +// scrapeRoute parses and responds to a Scrape. +func (f *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { var err error start := time.Now() var af *bittorrent.AddressFamily @@ -215,7 +215,7 @@ func (t *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprou af = new(bittorrent.AddressFamily) *af = req.AddressFamily - resp, err := t.logic.HandleScrape(context.Background(), req) + resp, err := f.logic.HandleScrape(context.Background(), req) if err != nil { WriteError(w, err) return @@ -227,5 +227,5 @@ func (t *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprou return } - go t.logic.AfterScrape(context.Background(), req, resp) + go f.logic.AfterScrape(context.Background(), req, resp) } diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go index 441af64..6678310 100644 --- a/frontend/udp/frontend.go +++ b/frontend/udp/frontend.go @@ -17,6 +17,7 @@ import ( "github.com/chihaya/chihaya/bittorrent" "github.com/chihaya/chihaya/frontend" "github.com/chihaya/chihaya/frontend/udp/bytepool" + "github.com/chihaya/chihaya/pkg/stop" ) var allowedGeneratedPrivateKeyRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890") @@ -82,8 +83,9 @@ type Frontend struct { Config } -// NewFrontend allocates a new instance of a Frontend. -func NewFrontend(logic frontend.TrackerLogic, cfg Config) *Frontend { +// NewFrontend creates a new instance of an UDP Frontend that asynchronously +// serves requests. +func NewFrontend(logic frontend.TrackerLogic, cfg Config) (*Frontend, error) { // Generate a private key if one isn't provided by the user. if cfg.PrivateKey == "" { rand.Seed(time.Now().UnixNano()) @@ -96,40 +98,68 @@ func NewFrontend(logic frontend.TrackerLogic, cfg Config) *Frontend { log.Warn("UDP private key was not provided, using generated key: ", cfg.PrivateKey) } - return &Frontend{ + f := &Frontend{ closing: make(chan struct{}), logic: logic, Config: cfg, } + + go func() { + if err := f.listenAndServe(); err != nil { + log.Fatal("failed while serving udp: " + err.Error()) + } + }() + + return f, nil } // Stop provides a thread-safe way to shutdown a currently running Frontend. -func (t *Frontend) Stop() { - close(t.closing) - t.socket.SetReadDeadline(time.Now()) - t.wg.Wait() +func (t *Frontend) Stop() <-chan error { + select { + case <-t.closing: + return stop.AlreadyStopped + default: + } + + c := make(chan error) + go func() { + close(t.closing) + t.socket.SetReadDeadline(time.Now()) + t.wg.Wait() + if err := t.socket.Close(); err != nil { + c <- err + } else { + close(c) + } + }() + + return c } -// ListenAndServe listens on the UDP network address t.Addr and blocks serving -// BitTorrent requests until t.Stop() is called or an error is returned. -func (t *Frontend) ListenAndServe() error { +// listenAndServe blocks while listening and serving UDP BitTorrent requests +// until Stop() is called or an error is returned. +func (t *Frontend) listenAndServe() error { udpAddr, err := net.ResolveUDPAddr("udp", t.Addr) if err != nil { return err } + log.Debugf("listening on udp socket") t.socket, err = net.ListenUDP("udp", udpAddr) if err != nil { return err } - defer t.socket.Close() pool := bytepool.New(2048) + t.wg.Add(1) + defer t.wg.Done() + for { // Check to see if we need to shutdown. select { case <-t.closing: + log.Debugf("returning from udp listen&serve") return nil default: } diff --git a/pkg/prometheus/server.go b/pkg/prometheus/server.go new file mode 100644 index 0000000..26e0c95 --- /dev/null +++ b/pkg/prometheus/server.go @@ -0,0 +1,50 @@ +// Package prometheus implements a standalone HTTP server for serving a +// Prometheus metrics endpoint. +package prometheus + +import ( + "context" + "net/http" + + log "github.com/Sirupsen/logrus" + "github.com/prometheus/client_golang/prometheus" +) + +// Server represents a standalone HTTP server for serving a Prometheus metrics +// endpoint. +type Server struct { + srv *http.Server +} + +// Stop shuts down the server. +func (s *Server) Stop() <-chan error { + c := make(chan error) + go func() { + if err := s.srv.Shutdown(context.Background()); err != nil { + c <- err + } else { + close(c) + } + }() + + return c +} + +// NewServer creates a new instance of a Prometheus server that asynchronously +// serves requests. +func NewServer(addr string) *Server { + s := &Server{ + srv: &http.Server{ + Addr: addr, + Handler: prometheus.Handler(), + }, + } + + go func() { + if err := s.srv.ListenAndServe(); err != http.ErrServerClosed { + log.Fatal("failed while serving prometheus: " + err.Error()) + } + }() + + return s +} diff --git a/storage/memory/peer_store.go b/storage/memory/peer_store.go index d5b7482..1b300eb 100644 --- a/storage/memory/peer_store.go +++ b/storage/memory/peer_store.go @@ -64,18 +64,20 @@ func New(cfg Config) (storage.PeerStore, error) { } ps := &peerStore{ - shards: make([]*peerShard, shardCount*2), - closed: make(chan struct{}), + shards: make([]*peerShard, shardCount*2), + closing: make(chan struct{}), } for i := 0; i < shardCount*2; i++ { ps.shards[i] = &peerShard{swarms: make(map[bittorrent.InfoHash]swarm)} } + ps.wg.Add(1) go func() { + defer ps.wg.Done() for { select { - case <-ps.closed: + case <-ps.closing: return case <-time.After(cfg.GarbageCollectionInterval): before := time.Now().Add(-cfg.PeerLifetime) @@ -102,19 +104,20 @@ type swarm struct { } type peerStore struct { - shards []*peerShard - closed chan struct{} + shards []*peerShard + closing chan struct{} + wg sync.WaitGroup } var _ storage.PeerStore = &peerStore{} -func (s *peerStore) shardIndex(infoHash bittorrent.InfoHash, af bittorrent.AddressFamily) uint32 { +func (ps *peerStore) shardIndex(infoHash bittorrent.InfoHash, af bittorrent.AddressFamily) uint32 { // There are twice the amount of shards specified by the user, the first // half is dedicated to IPv4 swarms and the second half is dedicated to // IPv6 swarms. - idx := binary.BigEndian.Uint32(infoHash[:4]) % (uint32(len(s.shards)) / 2) + idx := binary.BigEndian.Uint32(infoHash[:4]) % (uint32(len(ps.shards)) / 2) if af == bittorrent.IPv6 { - idx += uint32(len(s.shards) / 2) + idx += uint32(len(ps.shards) / 2) } return idx } @@ -146,16 +149,16 @@ func decodePeerKey(pk serializedPeer) bittorrent.Peer { return peer } -func (s *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error { +func (ps *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error { select { - case <-s.closed: + case <-ps.closing: panic("attempted to interact with stopped memory store") default: } pk := newPeerKey(p) - shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)] + shard := ps.shards[ps.shardIndex(ih, p.IP.AddressFamily)] shard.Lock() if _, ok := shard.swarms[ih]; !ok { @@ -172,16 +175,16 @@ func (s *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error { return nil } -func (s *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error { +func (ps *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error { select { - case <-s.closed: + case <-ps.closing: panic("attempted to interact with stopped memory store") default: } pk := newPeerKey(p) - shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)] + shard := ps.shards[ps.shardIndex(ih, p.IP.AddressFamily)] shard.Lock() if _, ok := shard.swarms[ih]; !ok { @@ -205,16 +208,16 @@ func (s *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) erro return nil } -func (s *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error { +func (ps *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error { select { - case <-s.closed: + case <-ps.closing: panic("attempted to interact with stopped memory store") default: } pk := newPeerKey(p) - shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)] + shard := ps.shards[ps.shardIndex(ih, p.IP.AddressFamily)] shard.Lock() if _, ok := shard.swarms[ih]; !ok { @@ -231,16 +234,16 @@ func (s *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error return nil } -func (s *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error { +func (ps *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error { select { - case <-s.closed: + case <-ps.closing: panic("attempted to interact with stopped memory store") default: } pk := newPeerKey(p) - shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)] + shard := ps.shards[ps.shardIndex(ih, p.IP.AddressFamily)] shard.Lock() if _, ok := shard.swarms[ih]; !ok { @@ -264,16 +267,16 @@ func (s *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) err return nil } -func (s *peerStore) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error { +func (ps *peerStore) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error { select { - case <-s.closed: + case <-ps.closing: panic("attempted to interact with stopped memory store") default: } pk := newPeerKey(p) - shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)] + shard := ps.shards[ps.shardIndex(ih, p.IP.AddressFamily)] shard.Lock() if _, ok := shard.swarms[ih]; !ok { @@ -292,14 +295,14 @@ func (s *peerStore) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) e return nil } -func (s *peerStore) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, announcer bittorrent.Peer) (peers []bittorrent.Peer, err error) { +func (ps *peerStore) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, announcer bittorrent.Peer) (peers []bittorrent.Peer, err error) { select { - case <-s.closed: + case <-ps.closing: panic("attempted to interact with stopped memory store") default: } - shard := s.shards[s.shardIndex(ih, announcer.IP.AddressFamily)] + shard := ps.shards[ps.shardIndex(ih, announcer.IP.AddressFamily)] shard.RLock() if _, ok := shard.swarms[ih]; !ok { @@ -354,15 +357,15 @@ func (s *peerStore) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant i return } -func (s *peerStore) ScrapeSwarm(ih bittorrent.InfoHash, addressFamily bittorrent.AddressFamily) (resp bittorrent.Scrape) { +func (ps *peerStore) ScrapeSwarm(ih bittorrent.InfoHash, addressFamily bittorrent.AddressFamily) (resp bittorrent.Scrape) { select { - case <-s.closed: + case <-ps.closing: panic("attempted to interact with stopped memory store") default: } resp.InfoHash = ih - shard := s.shards[s.shardIndex(ih, addressFamily)] + shard := ps.shards[ps.shardIndex(ih, addressFamily)] shard.RLock() if _, ok := shard.swarms[ih]; !ok { @@ -382,9 +385,9 @@ func (s *peerStore) ScrapeSwarm(ih bittorrent.InfoHash, addressFamily bittorrent // // This function must be able to execute while other methods on this interface // are being executed in parallel. -func (s *peerStore) collectGarbage(cutoff time.Time) error { +func (ps *peerStore) collectGarbage(cutoff time.Time) error { select { - case <-s.closed: + case <-ps.closing: panic("attempted to interact with stopped memory store") default: } @@ -392,7 +395,7 @@ func (s *peerStore) collectGarbage(cutoff time.Time) error { var ihDelta float64 cutoffUnix := cutoff.UnixNano() start := time.Now() - for _, shard := range s.shards { + for _, shard := range ps.shards { shard.RLock() var infohashes []bittorrent.InfoHash for ih := range shard.swarms { @@ -440,16 +443,21 @@ func (s *peerStore) collectGarbage(cutoff time.Time) error { return nil } -func (s *peerStore) Stop() <-chan error { - toReturn := make(chan error) +func (ps *peerStore) Stop() <-chan error { + c := make(chan error) go func() { - shards := make([]*peerShard, len(s.shards)) - for i := 0; i < len(s.shards); i++ { + close(ps.closing) + ps.wg.Wait() + + // Explicitly deallocate our storage. + shards := make([]*peerShard, len(ps.shards)) + for i := 0; i < len(ps.shards); i++ { shards[i] = &peerShard{swarms: make(map[bittorrent.InfoHash]swarm)} } - s.shards = shards - close(s.closed) - close(toReturn) + ps.shards = shards + + close(c) }() - return toReturn + + return c }