diff --git a/server/store/middleware/infohash/blacklist.go b/server/store/middleware/infohash/blacklist.go index 47c4eef..3595748 100644 --- a/server/store/middleware/infohash/blacklist.go +++ b/server/store/middleware/infohash/blacklist.go @@ -13,17 +13,22 @@ import ( func init() { tracker.RegisterAnnounceMiddleware("infohash_blacklist", blacklistAnnounceInfohash) tracker.RegisterScrapeMiddlewareConstructor("infohash_blacklist", blacklistScrapeInfohash) + mustGetStore = func() store.StringStore { + return store.MustGetStore() + } } // ErrBlockedInfohash is returned by a middleware if any of the infohashes // contained in an announce or scrape are disallowed. var ErrBlockedInfohash = tracker.ClientError("disallowed infohash") +var mustGetStore func() store.StringStore + // blacklistAnnounceInfohash provides a middleware that only allows announces // for infohashes that are not stored in a StringStore. func blacklistAnnounceInfohash(next tracker.AnnounceHandler) tracker.AnnounceHandler { return func(cfg *chihaya.TrackerConfig, req *chihaya.AnnounceRequest, resp *chihaya.AnnounceResponse) (err error) { - blacklisted, err := store.MustGetStore().HasString(PrefixInfohash + string(req.InfoHash)) + blacklisted, err := mustGetStore().HasString(PrefixInfohash + string(req.InfoHash)) if err != nil { return err } else if blacklisted { @@ -63,7 +68,7 @@ func blacklistScrapeInfohash(c chihaya.MiddlewareConfig) (tracker.ScrapeMiddlewa func blacklistFilterScrape(next tracker.ScrapeHandler) tracker.ScrapeHandler { return func(cfg *chihaya.TrackerConfig, req *chihaya.ScrapeRequest, resp *chihaya.ScrapeResponse) (err error) { blacklisted := false - storage := store.MustGetStore() + storage := mustGetStore() infohashes := req.InfoHashes for i, ih := range infohashes { @@ -84,7 +89,7 @@ func blacklistFilterScrape(next tracker.ScrapeHandler) tracker.ScrapeHandler { func blacklistBlockScrape(next tracker.ScrapeHandler) tracker.ScrapeHandler { return func(cfg *chihaya.TrackerConfig, req *chihaya.ScrapeRequest, resp *chihaya.ScrapeResponse) (err error) { blacklisted := false - storage := store.MustGetStore() + storage := mustGetStore() for _, ih := range req.InfoHashes { blacklisted, err = storage.HasString(PrefixInfohash + string(ih)) diff --git a/server/store/middleware/infohash/blacklist_test.go b/server/store/middleware/infohash/blacklist_test.go index 32a334a..cc06906 100644 --- a/server/store/middleware/infohash/blacklist_test.go +++ b/server/store/middleware/infohash/blacklist_test.go @@ -10,38 +10,42 @@ import ( "github.com/stretchr/testify/assert" "github.com/chihaya/chihaya" - "github.com/chihaya/chihaya/server" "github.com/chihaya/chihaya/server/store" "github.com/chihaya/chihaya/tracker" - - _ "github.com/chihaya/chihaya/server/store/memory" ) -var srv server.Server +type storeMock struct { + strings map[string]struct{} +} + +func (ss *storeMock) PutString(s string) error { + ss.strings[s] = struct{}{} + + return nil +} + +func (ss *storeMock) HasString(s string) (bool, error) { + _, ok := ss.strings[s] + + return ok, nil +} + +func (ss *storeMock) RemoveString(s string) error { + delete(ss.strings, s) + + return nil +} + +var mock store.StringStore = &storeMock{ + strings: make(map[string]struct{}), +} func TestASetUp(t *testing.T) { - serverConfig := chihaya.ServerConfig{ - Name: "store", - Config: store.Config{ - Addr: "localhost:6880", - StringStore: store.DriverConfig{ - Name: "memory", - }, - IPStore: store.DriverConfig{ - Name: "memory", - }, - PeerStore: store.DriverConfig{ - Name: "memory", - }, - }, + mustGetStore = func() store.StringStore { + return mock } - var err error - srv, err = server.New(&serverConfig, &tracker.Tracker{}) - assert.Nil(t, err) - srv.Start() - - store.MustGetStore().PutString(PrefixInfohash + "abc") + mustGetStore().PutString(PrefixInfohash + "abc") } func TestBlacklistAnnounceMiddleware(t *testing.T) { diff --git a/server/store/middleware/infohash/whitelist.go b/server/store/middleware/infohash/whitelist.go index 53e425f..85dec0c 100644 --- a/server/store/middleware/infohash/whitelist.go +++ b/server/store/middleware/infohash/whitelist.go @@ -6,7 +6,6 @@ package infohash import ( "github.com/chihaya/chihaya" - "github.com/chihaya/chihaya/server/store" "github.com/chihaya/chihaya/tracker" ) @@ -22,7 +21,7 @@ const PrefixInfohash = "ih-" // for infohashes that are not stored in a StringStore func whitelistAnnounceInfohash(next tracker.AnnounceHandler) tracker.AnnounceHandler { return func(cfg *chihaya.TrackerConfig, req *chihaya.AnnounceRequest, resp *chihaya.AnnounceResponse) (err error) { - whitelisted, err := store.MustGetStore().HasString(PrefixInfohash + string(req.InfoHash)) + whitelisted, err := mustGetStore().HasString(PrefixInfohash + string(req.InfoHash)) if err != nil { return err @@ -62,7 +61,7 @@ func whitelistScrapeInfohash(c chihaya.MiddlewareConfig) (tracker.ScrapeMiddlewa func whitelistFilterScrape(next tracker.ScrapeHandler) tracker.ScrapeHandler { return func(cfg *chihaya.TrackerConfig, req *chihaya.ScrapeRequest, resp *chihaya.ScrapeResponse) (err error) { whitelisted := false - storage := store.MustGetStore() + storage := mustGetStore() infohashes := req.InfoHashes for i, ih := range infohashes { @@ -83,7 +82,7 @@ func whitelistFilterScrape(next tracker.ScrapeHandler) tracker.ScrapeHandler { func whitelistBlockScrape(next tracker.ScrapeHandler) tracker.ScrapeHandler { return func(cfg *chihaya.TrackerConfig, req *chihaya.ScrapeRequest, resp *chihaya.ScrapeResponse) (err error) { whitelisted := false - storage := store.MustGetStore() + storage := mustGetStore() for _, ih := range req.InfoHashes { whitelisted, err = storage.HasString(PrefixInfohash + string(ih)) diff --git a/server/store/middleware/infohash/whitelist_test.go b/server/store/middleware/infohash/whitelist_test.go index 728846c..f958638 100644 --- a/server/store/middleware/infohash/whitelist_test.go +++ b/server/store/middleware/infohash/whitelist_test.go @@ -94,7 +94,3 @@ func TestWhitelistScrapeMiddlewareFilter(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []chihaya.InfoHash{chihaya.InfoHash("abc")}, req.InfoHashes) } - -func TestZTearDown(t *testing.T) { - srv.Stop() -}