diff --git a/cmd/chihaya/config.go b/cmd/chihaya/config.go index 72664ba..23dcc08 100644 --- a/cmd/chihaya/config.go +++ b/cmd/chihaya/config.go @@ -14,6 +14,7 @@ import ( // Imported to register as middleware drivers. _ "github.com/chihaya/chihaya/middleware/clientapproval" _ "github.com/chihaya/chihaya/middleware/jwt" + _ "github.com/chihaya/chihaya/middleware/torrentapproval" _ "github.com/chihaya/chihaya/middleware/varinterval" // Imported to register as storage drivers. diff --git a/example_config.yaml b/example_config.yaml index b6035b8..01e2519 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -149,3 +149,12 @@ chihaya: # modify_response_probability: 0.2 # max_increase_delta: 60 # modify_min_interval: true + + # This block defines configuration used for torrent approval, it requires to be given + # hashes for whitelist or for blacklist. Hashes are hexadecimal-encoaded. + #- name: torrent approval + # options: + # whitelist: + # - "a1b2c3d4e5a1b2c3d4e5a1b2c3d4e5a1b2c3d4e5" + # blacklist: + # - "e1d2c3b4a5e1b2c3b4a5e1d2c3b4e5e1d2c3b4a5" diff --git a/middleware/torrentapproval/torrentapproval.go b/middleware/torrentapproval/torrentapproval.go new file mode 100644 index 0000000..e45ae57 --- /dev/null +++ b/middleware/torrentapproval/torrentapproval.go @@ -0,0 +1,109 @@ +// Package torrentapproval implements a Hook that fails an Announce based on a +// whitelist or blacklist of torrent hash. +package torrentapproval + +import ( + "context" + "encoding/hex" + "fmt" + + "gopkg.in/yaml.v2" + + "github.com/chihaya/chihaya/bittorrent" + "github.com/chihaya/chihaya/middleware" +) + +// Name is the name by which this middleware is registered with Chihaya. +const Name = "torrent approval" + +func init() { + middleware.RegisterDriver(Name, driver{}) +} + +var _ middleware.Driver = driver{} + +type driver struct{} + +func (d driver) NewHook(optionBytes []byte) (middleware.Hook, error) { + var cfg Config + err := yaml.Unmarshal(optionBytes, &cfg) + if err != nil { + return nil, fmt.Errorf("invalid options for middleware %s: %s", Name, err) + } + + return NewHook(cfg) +} + +// ErrTorrentUnapproved is the error returned when a torrent hash is invalid. +var ErrTorrentUnapproved = bittorrent.ClientError("unapproved torrent") + +// Config represents all the values required by this middleware to validate +// torrents based on their hash value. +type Config struct { + Whitelist []string `yaml:"whitelist"` + Blacklist []string `yaml:"blacklist"` +} + +type hook struct { + approved map[bittorrent.InfoHash]struct{} + unapproved map[bittorrent.InfoHash]struct{} +} + +// NewHook returns an instance of the torrent approval middleware. +func NewHook(cfg Config) (middleware.Hook, error) { + h := &hook{ + approved: make(map[bittorrent.InfoHash]struct{}), + unapproved: make(map[bittorrent.InfoHash]struct{}), + } + + if len(cfg.Whitelist) > 0 && len(cfg.Blacklist) > 0 { + return nil, fmt.Errorf("using both whitelist and blacklist is invalid") + } + + for _, hashString := range cfg.Whitelist { + hashinfo, err := hex.DecodeString(hashString) + if err != nil { + return nil, fmt.Errorf("whitelist : invalid hash %s", hashString) + } + if len(hashinfo) != 20 { + return nil, fmt.Errorf("whitelist : hash %s is not 20 byes", hashString) + } + h.approved[bittorrent.InfoHashFromBytes(hashinfo)] = struct{}{} + } + + for _, hashString := range cfg.Blacklist { + hashinfo, err := hex.DecodeString(hashString) + if err != nil { + return nil, fmt.Errorf("blacklist : invalid hash %s", hashString) + } + if len(hashinfo) != 20 { + return nil, fmt.Errorf("blacklist : hash %s is not 20 byes", hashString) + } + h.unapproved[bittorrent.InfoHashFromBytes(hashinfo)] = struct{}{} + } + + return h, nil +} + +func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) { + infohash := req.InfoHash + + if len(h.approved) > 0 { + if _, found := h.approved[infohash]; !found { + return ctx, ErrTorrentUnapproved + } + } + + if len(h.unapproved) > 0 { + if _, found := h.unapproved[infohash]; found { + return ctx, ErrTorrentUnapproved + } + } + + return ctx, nil +} + +func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, resp *bittorrent.ScrapeResponse) (context.Context, error) { + // Scrapes don't require any protection. + return ctx, nil +} diff --git a/middleware/torrentapproval/torrentapproval_test.go b/middleware/torrentapproval/torrentapproval_test.go new file mode 100644 index 0000000..b20966c --- /dev/null +++ b/middleware/torrentapproval/torrentapproval_test.go @@ -0,0 +1,78 @@ +package torrentapproval + +import ( + "context" + "encoding/hex" + "fmt" + "testing" + + "github.com/chihaya/chihaya/bittorrent" + "github.com/stretchr/testify/require" +) + +var cases = []struct { + cfg Config + ih string + approved bool +}{ + // Infohash is whitelisted + { + Config{ + Whitelist: []string{"3532cf2d327fad8448c075b4cb42c8136964a435"}, + }, + "3532cf2d327fad8448c075b4cb42c8136964a435", + true, + }, + // Infohash is not whitelisted + { + Config{ + Whitelist: []string{"3532cf2d327fad8448c075b4cb42c8136964a435"}, + }, + "4532cf2d327fad8448c075b4cb42c8136964a435", + false, + }, + // Infohash is not blacklisted + { + Config{ + Blacklist: []string{"3532cf2d327fad8448c075b4cb42c8136964a435"}, + }, + "4532cf2d327fad8448c075b4cb42c8136964a435", + true, + }, + // Infohash is blacklisted + { + Config{ + Blacklist: []string{"3532cf2d327fad8448c075b4cb42c8136964a435"}, + }, + "3532cf2d327fad8448c075b4cb42c8136964a435", + false, + }, +} + +func TestHandleAnnounce(t *testing.T) { + for _, tt := range cases { + t.Run(fmt.Sprintf("testing hash %s", tt.ih), func(t *testing.T) { + h, err := NewHook(tt.cfg) + require.Nil(t, err) + + ctx := context.Background() + req := &bittorrent.AnnounceRequest{} + resp := &bittorrent.AnnounceResponse{} + + hashbytes, err := hex.DecodeString(tt.ih) + require.Nil(t, err) + + hashinfo := bittorrent.InfoHashFromBytes(hashbytes) + + req.InfoHash = hashinfo + + nctx, err := h.HandleAnnounce(ctx, req, resp) + require.Equal(t, ctx, nctx) + if tt.approved == true { + require.Nil(t, err) + } else { + require.NotNil(t, err) + } + }) + } +}