From 3c098c07039dbd2e1a46c84b4af8ffc5be20d25b Mon Sep 17 00:00:00 2001
From: Leo Balduf <balduf@hm.edu>
Date: Mon, 28 Nov 2016 20:55:04 +0100
Subject: [PATCH 1/2] middleware: add sanitization hook

---
 bittorrent/bittorrent.go          | 19 ++++++-
 example_config.yaml               |  3 +-
 frontend/http/frontend.go         | 19 ++++++-
 frontend/http/parser.go           |  9 +--
 frontend/udp/frontend.go          | 17 +++++-
 frontend/udp/parser.go            |  6 +-
 frontend/udp/writer.go            |  2 +-
 middleware/hooks.go               | 58 +++++++++++++++++--
 middleware/middleware.go          |  7 ++-
 middleware/middleware_test.go     | 94 +++++++++++++++++++++++++++++++
 storage/memory/peer_store.go      | 53 +++++++++--------
 storage/memory/peer_store_test.go |  2 +-
 storage/storage.go                |  4 +-
 storage/storage_bench.go          |  2 +-
 storage/storage_tests.go          | 16 +++---
 15 files changed, 248 insertions(+), 63 deletions(-)
 create mode 100644 middleware/middleware_test.go

diff --git a/bittorrent/bittorrent.go b/bittorrent/bittorrent.go
index 241c08a..c7a7b0b 100644
--- a/bittorrent/bittorrent.go
+++ b/bittorrent/bittorrent.go
@@ -110,11 +110,26 @@ type Scrape struct {
 	Incomplete uint32
 }
 
+// AddressFamily is the address family of an IP address.
+type AddressFamily uint8
+
+// AddressFamily constants.
+const (
+	IPv4 AddressFamily = iota
+	IPv6
+)
+
+// IP is a net.IP with an AddressFamily.
+type IP struct {
+	net.IP
+	AddressFamily
+}
+
 // Peer represents the connection details of a peer that is returned in an
 // announce response.
 type Peer struct {
 	ID   PeerID
-	IP   net.IP
+	IP   IP
 	Port uint16
 }
 
@@ -122,7 +137,7 @@ type Peer struct {
 func (p Peer) Equal(x Peer) bool { return p.EqualEndpoint(x) && p.ID == x.ID }
 
 // EqualEndpoint reports whether p and x have the same endpoint.
-func (p Peer) EqualEndpoint(x Peer) bool { return p.Port == x.Port && p.IP.Equal(x.IP) }
+func (p Peer) EqualEndpoint(x Peer) bool { return p.Port == x.Port && p.IP.Equal(x.IP.IP) }
 
 // ClientError represents an error that should be exposed to the client over
 // the BitTorrent protocol implementation.
diff --git a/example_config.yaml b/example_config.yaml
index 9d95d45..48fd651 100644
--- a/example_config.yaml
+++ b/example_config.yaml
@@ -1,6 +1,8 @@
 chihaya:
   announce_interval: 15m
   prometheus_addr: localhost:6880
+  max_numwant: 50
+  default_numwant: 25
 
   http:
     addr: 0.0.0.0:6881
@@ -21,7 +23,6 @@ chihaya:
     gc_interval: 14m
     peer_lifetime: 15m
     shards: 1
-    max_numwant: 100
 
   prehooks:
   - name: jwt
diff --git a/frontend/http/frontend.go b/frontend/http/frontend.go
index 849d6f7..a2f0bdd 100644
--- a/frontend/http/frontend.go
+++ b/frontend/http/frontend.go
@@ -13,6 +13,7 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/tylerb/graceful"
 
+	"github.com/chihaya/chihaya/bittorrent"
 	"github.com/chihaya/chihaya/frontend"
 	"github.com/chihaya/chihaya/middleware"
 )
@@ -22,6 +23,9 @@ func init() {
 	recordResponseDuration("action", nil, time.Second)
 }
 
+// ErrInvalidIP indicates an invalid IP.
+var ErrInvalidIP = bittorrent.ClientError("invalid IP")
+
 var promResponseDurationMilliseconds = prometheus.NewHistogramVec(
 	prometheus.HistogramOpts{
 		Name:    "chihaya_http_response_duration_milliseconds",
@@ -172,8 +176,19 @@ func (t *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprou
 		return
 	}
 
-	ip := net.ParseIP(host)
-	ctx := context.WithValue(context.Background(), middleware.ScrapeIsIPv6Key, len(ip) == net.IPv6len)
+	reqIP := net.ParseIP(host)
+	af := bittorrent.IPv4
+	if reqIP.To4() != nil {
+		af = bittorrent.IPv4
+	} else if len(reqIP) == net.IPv6len { // implies reqIP.To4() == nil
+		af = bittorrent.IPv6
+	} else {
+		log.Errorln("http: invalid IP: neither v4 nor v6, RemoteAddr was", r.RemoteAddr)
+		WriteError(w, ErrInvalidIP)
+		return
+	}
+
+	ctx := context.WithValue(context.Background(), middleware.ScrapeIsIPv6Key, af == bittorrent.IPv6)
 
 	resp, err := t.logic.HandleScrape(ctx, req)
 	if err != nil {
diff --git a/frontend/http/parser.go b/frontend/http/parser.go
index 33a499d..39e13be 100644
--- a/frontend/http/parser.go
+++ b/frontend/http/parser.go
@@ -74,16 +74,11 @@ func ParseAnnounce(r *http.Request, realIPHeader string, allowIPSpoofing bool) (
 	}
 	request.Peer.Port = uint16(port)
 
-	request.Peer.IP = requestedIP(r, qp, realIPHeader, allowIPSpoofing)
-	if request.Peer.IP == nil {
+	request.Peer.IP.IP = requestedIP(r, qp, realIPHeader, allowIPSpoofing)
+	if request.Peer.IP.IP == nil {
 		return nil, bittorrent.ClientError("failed to parse peer IP address")
 	}
 
-	// Sanitize IPv4 addresses to 4 bytes.
-	if ip := request.Peer.IP.To4(); ip != nil {
-		request.Peer.IP = ip
-	}
-
 	return request, nil
 }
 
diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go
index 3a74341..69975b8 100644
--- a/frontend/udp/frontend.go
+++ b/frontend/udp/frontend.go
@@ -10,6 +10,7 @@ import (
 	"sync"
 	"time"
 
+	log "github.com/Sirupsen/logrus"
 	"github.com/prometheus/client_golang/prometheus"
 
 	"github.com/chihaya/chihaya/bittorrent"
@@ -23,6 +24,9 @@ func init() {
 	recordResponseDuration("action", nil, time.Second)
 }
 
+// ErrInvalidIP indicates an invalid IP.
+var ErrInvalidIP = bittorrent.ClientError("invalid IP")
+
 var promResponseDurationMilliseconds = prometheus.NewHistogramVec(
 	prometheus.HistogramOpts{
 		Name:    "chihaya_udp_response_duration_milliseconds",
@@ -228,7 +232,18 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string
 			return
 		}
 
-		ctx := context.WithValue(context.Background(), middleware.ScrapeIsIPv6Key, len(r.IP) == net.IPv6len)
+		af := bittorrent.IPv4
+		if r.IP.To4() != nil {
+			af = bittorrent.IPv4
+		} else if len(r.IP) == net.IPv6len { // implies r.IP.To4() == nil
+			af = bittorrent.IPv6
+		} else {
+			log.Errorln("http: invalid IP: neither v4 nor v6, IP was", r.IP)
+			WriteError(w, txID, ErrInvalidIP)
+			return
+		}
+
+		ctx := context.WithValue(context.Background(), middleware.ScrapeIsIPv6Key, af == bittorrent.IPv6)
 
 		var resp *bittorrent.ScrapeResponse
 		resp, err = t.logic.HandleScrape(ctx, req)
diff --git a/frontend/udp/parser.go b/frontend/udp/parser.go
index 4474a01..9bc453c 100644
--- a/frontend/udp/parser.go
+++ b/frontend/udp/parser.go
@@ -29,10 +29,6 @@ var (
 	// initialConnectionID is the magic initial connection ID specified by BEP 15.
 	initialConnectionID = []byte{0, 0, 0x04, 0x17, 0x27, 0x10, 0x19, 0x80}
 
-	// emptyIPs are the value of an IP field that has been left blank.
-	emptyIPv4 = []byte{0, 0, 0, 0}
-	emptyIPv6 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
-
 	// eventIDs map values described in BEP 15 to Events.
 	eventIDs = []bittorrent.Event{
 		bittorrent.None,
@@ -105,7 +101,7 @@ func ParseAnnounce(r Request, allowIPSpoofing, v6 bool) (*bittorrent.AnnounceReq
 		Uploaded:   uploaded,
 		Peer: bittorrent.Peer{
 			ID:   bittorrent.PeerIDFromBytes(peerID),
-			IP:   ip,
+			IP:   bittorrent.IP{IP: ip},
 			Port: port,
 		},
 		Params: params,
diff --git a/frontend/udp/writer.go b/frontend/udp/writer.go
index f0800ed..7fc882e 100644
--- a/frontend/udp/writer.go
+++ b/frontend/udp/writer.go
@@ -46,7 +46,7 @@ func WriteAnnounce(w io.Writer, txID []byte, resp *bittorrent.AnnounceResponse,
 	}
 
 	for _, peer := range peers {
-		buf.Write(peer.IP)
+		buf.Write(peer.IP.IP)
 		binary.Write(buf, binary.BigEndian, peer.Port)
 	}
 
diff --git a/middleware/hooks.go b/middleware/hooks.go
index 4fb65dd..bba0aa0 100644
--- a/middleware/hooks.go
+++ b/middleware/hooks.go
@@ -70,6 +70,48 @@ func (h *swarmInteractionHook) HandleScrape(ctx context.Context, _ *bittorrent.S
 // ErrInvalidIP indicates an invalid IP for an Announce.
 var ErrInvalidIP = errors.New("invalid IP")
 
+// sanitizationHook enforces semantic assumptions about requests that may have
+// not been accounted for in a tracker frontend.
+//
+// The SanitizationHook performs the following checks:
+// - maxNumWant: Checks whether the numWant parameter of an announce is below
+//     a limit. Sets it to the limit if the value is higher.
+// - defaultNumWant: Checks whether the numWant parameter of an announce is
+//     zero. Sets it to the default if it is.
+// - IP sanitization: Checks whether the announcing Peer's IP address is either
+//     IPv4 or IPv6. Returns ErrInvalidIP if the address is neither IPv4 nor
+//     IPv6. Sets the Peer.AddressFamily field accordingly. Truncates IPv4
+//     addresses to have a length of 4 bytes.
+type sanitizationHook struct {
+	maxNumWant     uint32
+	defaultNumWant uint32
+}
+
+func (h *sanitizationHook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
+	if req.NumWant > h.maxNumWant {
+		req.NumWant = h.maxNumWant
+	}
+
+	if req.NumWant == 0 {
+		req.NumWant = h.defaultNumWant
+	}
+
+	if ip := req.Peer.IP.To4(); ip != nil {
+		req.Peer.IP.IP = ip
+		req.Peer.IP.AddressFamily = bittorrent.IPv4
+	} else if len(req.Peer.IP.IP) == net.IPv6len { // implies req.Peer.IP.To4() == nil
+		req.Peer.IP.AddressFamily = bittorrent.IPv6
+	} else {
+		return ctx, ErrInvalidIP
+	}
+
+	return ctx, nil
+}
+
+func (h *sanitizationHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, resp *bittorrent.ScrapeResponse) (context.Context, error) {
+	return ctx, nil
+}
+
 type skipResponseHook struct{}
 
 // SkipResponseHookKey is a key for the context of an Announce or Scrape to
@@ -97,7 +139,7 @@ func (h *responseHook) HandleAnnounce(ctx context.Context, req *bittorrent.Annou
 	}
 
 	// Add the Scrape data to the response.
-	s := h.store.ScrapeSwarm(req.InfoHash, len(req.IP) == net.IPv6len)
+	s := h.store.ScrapeSwarm(req.InfoHash, req.IP.AddressFamily)
 	resp.Incomplete = s.Incomplete
 	resp.Complete = s.Complete
 
@@ -123,13 +165,13 @@ func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittor
 		peers = append(peers, req.Peer)
 	}
 
-	switch len(req.IP) {
-	case net.IPv4len:
+	switch req.IP.AddressFamily {
+	case bittorrent.IPv4:
 		resp.IPv4Peers = peers
-	case net.IPv6len:
+	case bittorrent.IPv6:
 		resp.IPv6Peers = peers
 	default:
-		panic("peer IP is not IPv4 or IPv6 length")
+		panic("attempted to append peer that is neither IPv4 nor IPv6")
 	}
 
 	return nil
@@ -143,7 +185,11 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR
 	v6, _ := ctx.Value(ScrapeIsIPv6Key).(bool)
 
 	for _, infoHash := range req.InfoHashes {
-		resp.Files[infoHash] = h.store.ScrapeSwarm(infoHash, v6)
+		if v6 {
+			resp.Files[infoHash] = h.store.ScrapeSwarm(infoHash, bittorrent.IPv6)
+		} else {
+			resp.Files[infoHash] = h.store.ScrapeSwarm(infoHash, bittorrent.IPv4)
+		}
 	}
 
 	return ctx, nil
diff --git a/middleware/middleware.go b/middleware/middleware.go
index 2911799..9c8ba38 100644
--- a/middleware/middleware.go
+++ b/middleware/middleware.go
@@ -16,6 +16,8 @@ import (
 // Config holds the configuration common across all middleware.
 type Config struct {
 	AnnounceInterval time.Duration `yaml:"announce_interval"`
+	MaxNumWant       uint32        `yaml:"max_numwant"`
+	DefaultNumWant   uint32        `yaml:"default_numwant"`
 }
 
 var _ frontend.TrackerLogic = &Logic{}
@@ -26,10 +28,13 @@ func NewLogic(cfg Config, peerStore storage.PeerStore, preHooks, postHooks []Hoo
 	l := &Logic{
 		announceInterval: cfg.AnnounceInterval,
 		peerStore:        peerStore,
-		preHooks:         append(preHooks, &responseHook{store: peerStore}),
+		preHooks:         []Hook{&sanitizationHook{maxNumWant: cfg.MaxNumWant, defaultNumWant: cfg.DefaultNumWant}},
 		postHooks:        append(postHooks, &swarmInteractionHook{store: peerStore}),
 	}
 
+	l.preHooks = append(l.preHooks, preHooks...)
+	l.preHooks = append(l.preHooks, &responseHook{store: peerStore})
+
 	return l
 }
 
diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go
new file mode 100644
index 0000000..7821d61
--- /dev/null
+++ b/middleware/middleware_test.go
@@ -0,0 +1,94 @@
+package middleware
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/chihaya/chihaya/bittorrent"
+)
+
+// nopHook is a Hook to measure the overhead of a no-operation Hook through
+// benchmarks.
+type nopHook struct{}
+
+func (h *nopHook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
+	return ctx, nil
+}
+
+func (h *nopHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, resp *bittorrent.ScrapeResponse) (context.Context, error) {
+	return ctx, nil
+}
+
+type hookList []Hook
+
+func (hooks hookList) handleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest) (resp *bittorrent.AnnounceResponse, err error) {
+	resp = &bittorrent.AnnounceResponse{
+		Interval:    60,
+		MinInterval: 60,
+		Compact:     true,
+	}
+
+	for _, h := range []Hook(hooks) {
+		if ctx, err = h.HandleAnnounce(ctx, req, resp); err != nil {
+			return nil, err
+		}
+	}
+
+	return resp, nil
+}
+
+func benchHookListV4(b *testing.B, hooks hookList) {
+	req := &bittorrent.AnnounceRequest{Peer: bittorrent.Peer{IP: bittorrent.IP{IP: net.ParseIP("1.2.3.4"), AddressFamily: bittorrent.IPv4}}}
+	benchHookList(b, hooks, req)
+}
+
+func benchHookListV6(b *testing.B, hooks hookList) {
+	req := &bittorrent.AnnounceRequest{Peer: bittorrent.Peer{IP: bittorrent.IP{IP: net.ParseIP("fc00::0001"), AddressFamily: bittorrent.IPv6}}}
+	benchHookList(b, hooks, req)
+}
+
+func benchHookList(b *testing.B, hooks hookList, req *bittorrent.AnnounceRequest) {
+	ctx := context.Background()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		resp, err := hooks.handleAnnounce(ctx, req)
+		require.Nil(b, err)
+		require.NotNil(b, resp)
+	}
+}
+
+func BenchmarkHookOverhead(b *testing.B) {
+	b.Run("none-v4", func(b *testing.B) {
+		benchHookListV4(b, hookList{})
+	})
+
+	b.Run("none-v6", func(b *testing.B) {
+		benchHookListV6(b, hookList{})
+	})
+
+	var nopHooks hookList
+	for i := 1; i < 4; i++ {
+		nopHooks = append(nopHooks, &nopHook{})
+		b.Run(fmt.Sprintf("%dnop-v4", i), func(b *testing.B) {
+			benchHookListV4(b, nopHooks)
+		})
+		b.Run(fmt.Sprintf("%dnop-v6", i), func(b *testing.B) {
+			benchHookListV6(b, nopHooks)
+		})
+	}
+
+	var sanHooks hookList
+	for i := 1; i < 4; i++ {
+		sanHooks = append(sanHooks, &sanitizationHook{maxNumWant: 50})
+		b.Run(fmt.Sprintf("%dsanitation-v4", i), func(b *testing.B) {
+			benchHookListV4(b, sanHooks)
+		})
+		b.Run(fmt.Sprintf("%dsanitation-v6", i), func(b *testing.B) {
+			benchHookListV6(b, sanHooks)
+		})
+	}
+}
diff --git a/storage/memory/peer_store.go b/storage/memory/peer_store.go
index 0ea4ddd..0c78a42 100644
--- a/storage/memory/peer_store.go
+++ b/storage/memory/peer_store.go
@@ -23,7 +23,6 @@ type Config struct {
 	GarbageCollectionInterval time.Duration `yaml:"gc_interval"`
 	PeerLifetime              time.Duration `yaml:"peer_lifetime"`
 	ShardCount                int           `yaml:"shard_count"`
-	MaxNumWant                int           `yaml:"max_numwant"`
 }
 
 // New creates a new PeerStore backed by memory.
@@ -38,9 +37,8 @@ func New(cfg Config) (storage.PeerStore, error) {
 	}
 
 	ps := &peerStore{
-		shards:     make([]*peerShard, shardCount*2),
-		closed:     make(chan struct{}),
-		maxNumWant: cfg.MaxNumWant,
+		shards: make([]*peerShard, shardCount*2),
+		closed: make(chan struct{}),
 	}
 
 	for i := 0; i < shardCount*2; i++ {
@@ -77,39 +75,48 @@ type swarm struct {
 }
 
 type peerStore struct {
-	shards     []*peerShard
-	closed     chan struct{}
-	maxNumWant int
+	shards []*peerShard
+	closed chan struct{}
 }
 
 var _ storage.PeerStore = &peerStore{}
 
-func (s *peerStore) shardIndex(infoHash bittorrent.InfoHash, v6 bool) uint32 {
+func (s *peerStore) shardIndex(infoHash bittorrent.InfoHash, af bittorrent.AddressFamily) uint32 {
 	// There are twice the amount of shards specified by the user, the first
 	// half is dedicated to IPv4 swarms and the second half is dedicated to
 	// IPv6 swarms.
 	idx := binary.BigEndian.Uint32(infoHash[:4]) % (uint32(len(s.shards)) / 2)
-	if v6 {
+	if af == bittorrent.IPv6 {
 		idx += uint32(len(s.shards) / 2)
 	}
 	return idx
 }
 
 func newPeerKey(p bittorrent.Peer) serializedPeer {
-	b := make([]byte, 20+2+len(p.IP))
+	b := make([]byte, 20+2+len(p.IP.IP))
 	copy(b[:20], p.ID[:])
 	binary.BigEndian.PutUint16(b[20:22], p.Port)
-	copy(b[22:], p.IP)
+	copy(b[22:], p.IP.IP)
 
 	return serializedPeer(b)
 }
 
 func decodePeerKey(pk serializedPeer) bittorrent.Peer {
-	return bittorrent.Peer{
+	peer := bittorrent.Peer{
 		ID:   bittorrent.PeerIDFromString(string(pk[:20])),
 		Port: binary.BigEndian.Uint16([]byte(pk[20:22])),
-		IP:   net.IP(pk[22:]),
+		IP:   bittorrent.IP{IP: net.IP(pk[22:])}}
+
+	if ip := peer.IP.To4(); ip != nil {
+		peer.IP.IP = ip
+		peer.IP.AddressFamily = bittorrent.IPv4
+	} else if len(peer.IP.IP) == net.IPv6len { // implies toReturn.IP.To4() == nil
+		peer.IP.AddressFamily = bittorrent.IPv6
+	} else {
+		panic("IP is neither v4 nor v6")
 	}
+
+	return peer
 }
 
 func (s *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error {
@@ -121,7 +128,7 @@ func (s *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error {
 
 	pk := newPeerKey(p)
 
-	shard := s.shards[s.shardIndex(ih, len(p.IP) == net.IPv6len)]
+	shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)]
 	shard.Lock()
 
 	if _, ok := shard.swarms[ih]; !ok {
@@ -146,7 +153,7 @@ func (s *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) erro
 
 	pk := newPeerKey(p)
 
-	shard := s.shards[s.shardIndex(ih, len(p.IP) == net.IPv6len)]
+	shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)]
 	shard.Lock()
 
 	if _, ok := shard.swarms[ih]; !ok {
@@ -178,7 +185,7 @@ func (s *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error
 
 	pk := newPeerKey(p)
 
-	shard := s.shards[s.shardIndex(ih, len(p.IP) == net.IPv6len)]
+	shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)]
 	shard.Lock()
 
 	if _, ok := shard.swarms[ih]; !ok {
@@ -203,7 +210,7 @@ func (s *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) err
 
 	pk := newPeerKey(p)
 
-	shard := s.shards[s.shardIndex(ih, len(p.IP) == net.IPv6len)]
+	shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)]
 	shard.Lock()
 
 	if _, ok := shard.swarms[ih]; !ok {
@@ -235,7 +242,7 @@ func (s *peerStore) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) e
 
 	pk := newPeerKey(p)
 
-	shard := s.shards[s.shardIndex(ih, len(p.IP) == net.IPv6len)]
+	shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)]
 	shard.Lock()
 
 	if _, ok := shard.swarms[ih]; !ok {
@@ -260,11 +267,7 @@ func (s *peerStore) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant i
 	default:
 	}
 
-	if numWant > s.maxNumWant {
-		numWant = s.maxNumWant
-	}
-
-	shard := s.shards[s.shardIndex(ih, len(announcer.IP) == net.IPv6len)]
+	shard := s.shards[s.shardIndex(ih, announcer.IP.AddressFamily)]
 	shard.RLock()
 
 	if _, ok := shard.swarms[ih]; !ok {
@@ -319,14 +322,14 @@ func (s *peerStore) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant i
 	return
 }
 
-func (s *peerStore) ScrapeSwarm(ih bittorrent.InfoHash, v6 bool) (resp bittorrent.Scrape) {
+func (s *peerStore) ScrapeSwarm(ih bittorrent.InfoHash, addressFamily bittorrent.AddressFamily) (resp bittorrent.Scrape) {
 	select {
 	case <-s.closed:
 		panic("attempted to interact with stopped memory store")
 	default:
 	}
 
-	shard := s.shards[s.shardIndex(ih, v6)]
+	shard := s.shards[s.shardIndex(ih, addressFamily)]
 	shard.RLock()
 
 	if _, ok := shard.swarms[ih]; !ok {
diff --git a/storage/memory/peer_store_test.go b/storage/memory/peer_store_test.go
index c1cab14..f2478b1 100644
--- a/storage/memory/peer_store_test.go
+++ b/storage/memory/peer_store_test.go
@@ -9,7 +9,7 @@ import (
 )
 
 func createNew() s.PeerStore {
-	ps, err := New(Config{ShardCount: 1024, GarbageCollectionInterval: 10 * time.Minute, MaxNumWant: 50})
+	ps, err := New(Config{ShardCount: 1024, GarbageCollectionInterval: 10 * time.Minute})
 	if err != nil {
 		panic(err)
 	}
diff --git a/storage/storage.go b/storage/storage.go
index 87c8782..78b31cc 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -57,13 +57,13 @@ type PeerStore interface {
 
 	// ScrapeSwarm returns information required to answer a scrape request
 	// about a swarm identified by the given infohash.
-	// The v6 flag indicates whether or not the IPv6 swarm should be
+	// The AddressFamily indicates whether or not the IPv6 swarm should be
 	// scraped.
 	// The Complete and Incomplete fields of the Scrape must be filled,
 	// filling the Snatches field is optional.
 	// If the infohash is unknown to the PeerStore, an empty Scrape is
 	// returned.
-	ScrapeSwarm(infoHash bittorrent.InfoHash, v6 bool) bittorrent.Scrape
+	ScrapeSwarm(infoHash bittorrent.InfoHash, addressFamily bittorrent.AddressFamily) bittorrent.Scrape
 
 	// Stopper is an interface that expects a Stop method to stop the
 	// PeerStore.
diff --git a/storage/storage_bench.go b/storage/storage_bench.go
index babab7c..6dc1e93 100644
--- a/storage/storage_bench.go
+++ b/storage/storage_bench.go
@@ -45,7 +45,7 @@ func generatePeers() (a [1000]bittorrent.Peer) {
 		port := uint16(r.Uint32())
 		a[i] = bittorrent.Peer{
 			ID:   bittorrent.PeerID(id),
-			IP:   net.IP(ip),
+			IP:   bittorrent.IP{IP: net.IP(ip), AddressFamily: bittorrent.IPv4},
 			Port: port,
 		}
 	}
diff --git a/storage/storage_tests.go b/storage/storage_tests.go
index d3fb4d1..a90e499 100644
--- a/storage/storage_tests.go
+++ b/storage/storage_tests.go
@@ -20,20 +20,20 @@ func TestPeerStore(t *testing.T, p PeerStore) {
 	}{
 		{
 			bittorrent.InfoHashFromString("00000000000000000001"),
-			bittorrent.Peer{ID: bittorrent.PeerIDFromString("00000000000000000001"), Port: 1, IP: net.ParseIP("1.1.1.1").To4()},
+			bittorrent.Peer{ID: bittorrent.PeerIDFromString("00000000000000000001"), Port: 1, IP: bittorrent.IP{IP: net.ParseIP("1.1.1.1").To4(), AddressFamily: bittorrent.IPv4}},
 		},
 		{
 			bittorrent.InfoHashFromString("00000000000000000002"),
-			bittorrent.Peer{ID: bittorrent.PeerIDFromString("00000000000000000002"), Port: 2, IP: net.ParseIP("abab::0001")},
+			bittorrent.Peer{ID: bittorrent.PeerIDFromString("00000000000000000002"), Port: 2, IP: bittorrent.IP{IP: net.ParseIP("abab::0001"), AddressFamily: bittorrent.IPv6}},
 		},
 	}
 
-	v4Peer := bittorrent.Peer{ID: bittorrent.PeerIDFromString("99999999999999999994"), IP: net.ParseIP("99.99.99.99").To4(), Port: 9994}
-	v6Peer := bittorrent.Peer{ID: bittorrent.PeerIDFromString("99999999999999999996"), IP: net.ParseIP("fc00::0001"), Port: 9996}
+	v4Peer := bittorrent.Peer{ID: bittorrent.PeerIDFromString("99999999999999999994"), IP: bittorrent.IP{IP: net.ParseIP("99.99.99.99").To4(), AddressFamily: bittorrent.IPv4}, Port: 9994}
+	v6Peer := bittorrent.Peer{ID: bittorrent.PeerIDFromString("99999999999999999996"), IP: bittorrent.IP{IP: net.ParseIP("fc00::0001"), AddressFamily: bittorrent.IPv6}, Port: 9996}
 
 	for _, c := range testData {
 		peer := v4Peer
-		if len(c.peer.IP) == net.IPv6len {
+		if c.peer.IP.AddressFamily == bittorrent.IPv6 {
 			peer = v6Peer
 		}
 
@@ -48,7 +48,7 @@ func TestPeerStore(t *testing.T, p PeerStore) {
 		require.Equal(t, ErrResourceDoesNotExist, err)
 
 		// Test empty scrape response for non-existent swarms.
-		scrape := p.ScrapeSwarm(c.ih, len(c.peer.IP) == net.IPv6len)
+		scrape := p.ScrapeSwarm(c.ih, c.peer.IP.AddressFamily)
 		require.Equal(t, uint32(0), scrape.Complete)
 		require.Equal(t, uint32(0), scrape.Incomplete)
 		require.Equal(t, uint32(0), scrape.Snatches)
@@ -76,7 +76,7 @@ func TestPeerStore(t *testing.T, p PeerStore) {
 		require.Nil(t, err)
 		require.True(t, containsPeer(peers, c.peer))
 
-		scrape = p.ScrapeSwarm(c.ih, len(c.peer.IP) == net.IPv6len)
+		scrape = p.ScrapeSwarm(c.ih, c.peer.IP.AddressFamily)
 		require.Equal(t, uint32(2), scrape.Incomplete)
 		require.Equal(t, uint32(0), scrape.Complete)
 
@@ -97,7 +97,7 @@ func TestPeerStore(t *testing.T, p PeerStore) {
 		require.Nil(t, err)
 		require.True(t, containsPeer(peers, c.peer))
 
-		scrape = p.ScrapeSwarm(c.ih, len(c.peer.IP) == net.IPv6len)
+		scrape = p.ScrapeSwarm(c.ih, c.peer.IP.AddressFamily)
 		require.Equal(t, uint32(1), scrape.Incomplete)
 		require.Equal(t, uint32(1), scrape.Complete)
 

From 3ae384394466990fbe0f78064a65fd68b76fdada Mon Sep 17 00:00:00 2001
From: Leo Balduf <balduf@hm.edu>
Date: Fri, 20 Jan 2017 20:34:39 +0100
Subject: [PATCH 2/2] bittorrent: add AddressField to ScrapeRequest

---
 bittorrent/bittorrent.go  |  5 +++--
 frontend/http/frontend.go | 10 +++-------
 frontend/udp/frontend.go  | 12 ++++--------
 middleware/hooks.go       |  8 +-------
 4 files changed, 11 insertions(+), 24 deletions(-)

diff --git a/bittorrent/bittorrent.go b/bittorrent/bittorrent.go
index c7a7b0b..14971d3 100644
--- a/bittorrent/bittorrent.go
+++ b/bittorrent/bittorrent.go
@@ -94,8 +94,9 @@ type AnnounceResponse struct {
 
 // ScrapeRequest represents the parsed parameters from a scrape request.
 type ScrapeRequest struct {
-	InfoHashes []InfoHash
-	Params     Params
+	AddressFamily AddressFamily
+	InfoHashes    []InfoHash
+	Params        Params
 }
 
 // ScrapeResponse represents the parameters used to create a scrape response.
diff --git a/frontend/http/frontend.go b/frontend/http/frontend.go
index a2f0bdd..f4f0d34 100644
--- a/frontend/http/frontend.go
+++ b/frontend/http/frontend.go
@@ -15,7 +15,6 @@ import (
 
 	"github.com/chihaya/chihaya/bittorrent"
 	"github.com/chihaya/chihaya/frontend"
-	"github.com/chihaya/chihaya/middleware"
 )
 
 func init() {
@@ -177,20 +176,17 @@ func (t *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprou
 	}
 
 	reqIP := net.ParseIP(host)
-	af := bittorrent.IPv4
 	if reqIP.To4() != nil {
-		af = bittorrent.IPv4
+		req.AddressFamily = bittorrent.IPv4
 	} else if len(reqIP) == net.IPv6len { // implies reqIP.To4() == nil
-		af = bittorrent.IPv6
+		req.AddressFamily = bittorrent.IPv6
 	} else {
 		log.Errorln("http: invalid IP: neither v4 nor v6, RemoteAddr was", r.RemoteAddr)
 		WriteError(w, ErrInvalidIP)
 		return
 	}
 
-	ctx := context.WithValue(context.Background(), middleware.ScrapeIsIPv6Key, af == bittorrent.IPv6)
-
-	resp, err := t.logic.HandleScrape(ctx, req)
+	resp, err := t.logic.HandleScrape(context.Background(), req)
 	if err != nil {
 		WriteError(w, err)
 		return
diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go
index 69975b8..419e36d 100644
--- a/frontend/udp/frontend.go
+++ b/frontend/udp/frontend.go
@@ -16,7 +16,6 @@ import (
 	"github.com/chihaya/chihaya/bittorrent"
 	"github.com/chihaya/chihaya/frontend"
 	"github.com/chihaya/chihaya/frontend/udp/bytepool"
-	"github.com/chihaya/chihaya/middleware"
 )
 
 func init() {
@@ -232,21 +231,18 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string
 			return
 		}
 
-		af := bittorrent.IPv4
 		if r.IP.To4() != nil {
-			af = bittorrent.IPv4
+			req.AddressFamily = bittorrent.IPv4
 		} else if len(r.IP) == net.IPv6len { // implies r.IP.To4() == nil
-			af = bittorrent.IPv6
+			req.AddressFamily = bittorrent.IPv6
 		} else {
-			log.Errorln("http: invalid IP: neither v4 nor v6, IP was", r.IP)
+			log.Errorln("udp: invalid IP: neither v4 nor v6, IP was", r.IP)
 			WriteError(w, txID, ErrInvalidIP)
 			return
 		}
 
-		ctx := context.WithValue(context.Background(), middleware.ScrapeIsIPv6Key, af == bittorrent.IPv6)
-
 		var resp *bittorrent.ScrapeResponse
-		resp, err = t.logic.HandleScrape(ctx, req)
+		resp, err = t.logic.HandleScrape(context.Background(), req)
 		if err != nil {
 			WriteError(w, txID, err)
 			return
diff --git a/middleware/hooks.go b/middleware/hooks.go
index bba0aa0..5ac7525 100644
--- a/middleware/hooks.go
+++ b/middleware/hooks.go
@@ -182,14 +182,8 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR
 		return ctx, nil
 	}
 
-	v6, _ := ctx.Value(ScrapeIsIPv6Key).(bool)
-
 	for _, infoHash := range req.InfoHashes {
-		if v6 {
-			resp.Files[infoHash] = h.store.ScrapeSwarm(infoHash, bittorrent.IPv6)
-		} else {
-			resp.Files[infoHash] = h.store.ScrapeSwarm(infoHash, bittorrent.IPv4)
-		}
+		resp.Files[infoHash] = h.store.ScrapeSwarm(infoHash, req.AddressFamily)
 	}
 
 	return ctx, nil