Merge pull request #258 from mrd0ll4r/sanitation-hook

middleware: add sanitization hook
This commit is contained in:
mrd0ll4r 2017-01-21 21:06:19 +01:00 committed by GitHub
commit 03f0c977d3
15 changed files with 243 additions and 71 deletions

View file

@ -94,8 +94,9 @@ type AnnounceResponse struct {
// ScrapeRequest represents the parsed parameters from a scrape request. // ScrapeRequest represents the parsed parameters from a scrape request.
type ScrapeRequest struct { type ScrapeRequest struct {
InfoHashes []InfoHash AddressFamily AddressFamily
Params Params InfoHashes []InfoHash
Params Params
} }
// ScrapeResponse represents the parameters used to create a scrape response. // ScrapeResponse represents the parameters used to create a scrape response.
@ -110,11 +111,26 @@ type Scrape struct {
Incomplete uint32 Incomplete uint32
} }
// AddressFamily is the address family of an IP address.
type AddressFamily uint8
// AddressFamily constants.
const (
IPv4 AddressFamily = iota
IPv6
)
// IP is a net.IP with an AddressFamily.
type IP struct {
net.IP
AddressFamily
}
// Peer represents the connection details of a peer that is returned in an // Peer represents the connection details of a peer that is returned in an
// announce response. // announce response.
type Peer struct { type Peer struct {
ID PeerID ID PeerID
IP net.IP IP IP
Port uint16 Port uint16
} }
@ -122,7 +138,7 @@ type Peer struct {
func (p Peer) Equal(x Peer) bool { return p.EqualEndpoint(x) && p.ID == x.ID } func (p Peer) Equal(x Peer) bool { return p.EqualEndpoint(x) && p.ID == x.ID }
// EqualEndpoint reports whether p and x have the same endpoint. // EqualEndpoint reports whether p and x have the same endpoint.
func (p Peer) EqualEndpoint(x Peer) bool { return p.Port == x.Port && p.IP.Equal(x.IP) } func (p Peer) EqualEndpoint(x Peer) bool { return p.Port == x.Port && p.IP.Equal(x.IP.IP) }
// ClientError represents an error that should be exposed to the client over // ClientError represents an error that should be exposed to the client over
// the BitTorrent protocol implementation. // the BitTorrent protocol implementation.

View file

@ -1,6 +1,8 @@
chihaya: chihaya:
announce_interval: 15m announce_interval: 15m
prometheus_addr: localhost:6880 prometheus_addr: localhost:6880
max_numwant: 50
default_numwant: 25
http: http:
addr: 0.0.0.0:6881 addr: 0.0.0.0:6881
@ -21,7 +23,6 @@ chihaya:
gc_interval: 14m gc_interval: 14m
peer_lifetime: 15m peer_lifetime: 15m
shards: 1 shards: 1
max_numwant: 100
prehooks: prehooks:
- name: jwt - name: jwt

View file

@ -13,8 +13,8 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/tylerb/graceful" "github.com/tylerb/graceful"
"github.com/chihaya/chihaya/bittorrent"
"github.com/chihaya/chihaya/frontend" "github.com/chihaya/chihaya/frontend"
"github.com/chihaya/chihaya/middleware"
) )
func init() { func init() {
@ -22,6 +22,9 @@ func init() {
recordResponseDuration("action", nil, time.Second) recordResponseDuration("action", nil, time.Second)
} }
// ErrInvalidIP indicates an invalid IP.
var ErrInvalidIP = bittorrent.ClientError("invalid IP")
var promResponseDurationMilliseconds = prometheus.NewHistogramVec( var promResponseDurationMilliseconds = prometheus.NewHistogramVec(
prometheus.HistogramOpts{ prometheus.HistogramOpts{
Name: "chihaya_http_response_duration_milliseconds", Name: "chihaya_http_response_duration_milliseconds",
@ -172,10 +175,18 @@ func (t *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprou
return return
} }
ip := net.ParseIP(host) reqIP := net.ParseIP(host)
ctx := context.WithValue(context.Background(), middleware.ScrapeIsIPv6Key, len(ip) == net.IPv6len) if reqIP.To4() != nil {
req.AddressFamily = bittorrent.IPv4
} else if len(reqIP) == net.IPv6len { // implies reqIP.To4() == nil
req.AddressFamily = bittorrent.IPv6
} else {
log.Errorln("http: invalid IP: neither v4 nor v6, RemoteAddr was", r.RemoteAddr)
WriteError(w, ErrInvalidIP)
return
}
resp, err := t.logic.HandleScrape(ctx, req) resp, err := t.logic.HandleScrape(context.Background(), req)
if err != nil { if err != nil {
WriteError(w, err) WriteError(w, err)
return return

View file

@ -74,16 +74,11 @@ func ParseAnnounce(r *http.Request, realIPHeader string, allowIPSpoofing bool) (
} }
request.Peer.Port = uint16(port) request.Peer.Port = uint16(port)
request.Peer.IP = requestedIP(r, qp, realIPHeader, allowIPSpoofing) request.Peer.IP.IP = requestedIP(r, qp, realIPHeader, allowIPSpoofing)
if request.Peer.IP == nil { if request.Peer.IP.IP == nil {
return nil, bittorrent.ClientError("failed to parse peer IP address") return nil, bittorrent.ClientError("failed to parse peer IP address")
} }
// Sanitize IPv4 addresses to 4 bytes.
if ip := request.Peer.IP.To4(); ip != nil {
request.Peer.IP = ip
}
return request, nil return request, nil
} }

View file

@ -10,12 +10,12 @@ import (
"sync" "sync"
"time" "time"
log "github.com/Sirupsen/logrus"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/chihaya/chihaya/bittorrent" "github.com/chihaya/chihaya/bittorrent"
"github.com/chihaya/chihaya/frontend" "github.com/chihaya/chihaya/frontend"
"github.com/chihaya/chihaya/frontend/udp/bytepool" "github.com/chihaya/chihaya/frontend/udp/bytepool"
"github.com/chihaya/chihaya/middleware"
) )
func init() { func init() {
@ -23,6 +23,9 @@ func init() {
recordResponseDuration("action", nil, time.Second) recordResponseDuration("action", nil, time.Second)
} }
// ErrInvalidIP indicates an invalid IP.
var ErrInvalidIP = bittorrent.ClientError("invalid IP")
var promResponseDurationMilliseconds = prometheus.NewHistogramVec( var promResponseDurationMilliseconds = prometheus.NewHistogramVec(
prometheus.HistogramOpts{ prometheus.HistogramOpts{
Name: "chihaya_udp_response_duration_milliseconds", Name: "chihaya_udp_response_duration_milliseconds",
@ -228,10 +231,18 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string
return return
} }
ctx := context.WithValue(context.Background(), middleware.ScrapeIsIPv6Key, len(r.IP) == net.IPv6len) if r.IP.To4() != nil {
req.AddressFamily = bittorrent.IPv4
} else if len(r.IP) == net.IPv6len { // implies r.IP.To4() == nil
req.AddressFamily = bittorrent.IPv6
} else {
log.Errorln("udp: invalid IP: neither v4 nor v6, IP was", r.IP)
WriteError(w, txID, ErrInvalidIP)
return
}
var resp *bittorrent.ScrapeResponse var resp *bittorrent.ScrapeResponse
resp, err = t.logic.HandleScrape(ctx, req) resp, err = t.logic.HandleScrape(context.Background(), req)
if err != nil { if err != nil {
WriteError(w, txID, err) WriteError(w, txID, err)
return return

View file

@ -29,10 +29,6 @@ var (
// initialConnectionID is the magic initial connection ID specified by BEP 15. // initialConnectionID is the magic initial connection ID specified by BEP 15.
initialConnectionID = []byte{0, 0, 0x04, 0x17, 0x27, 0x10, 0x19, 0x80} initialConnectionID = []byte{0, 0, 0x04, 0x17, 0x27, 0x10, 0x19, 0x80}
// emptyIPs are the value of an IP field that has been left blank.
emptyIPv4 = []byte{0, 0, 0, 0}
emptyIPv6 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// eventIDs map values described in BEP 15 to Events. // eventIDs map values described in BEP 15 to Events.
eventIDs = []bittorrent.Event{ eventIDs = []bittorrent.Event{
bittorrent.None, bittorrent.None,
@ -105,7 +101,7 @@ func ParseAnnounce(r Request, allowIPSpoofing, v6 bool) (*bittorrent.AnnounceReq
Uploaded: uploaded, Uploaded: uploaded,
Peer: bittorrent.Peer{ Peer: bittorrent.Peer{
ID: bittorrent.PeerIDFromBytes(peerID), ID: bittorrent.PeerIDFromBytes(peerID),
IP: ip, IP: bittorrent.IP{IP: ip},
Port: port, Port: port,
}, },
Params: params, Params: params,

View file

@ -46,7 +46,7 @@ func WriteAnnounce(w io.Writer, txID []byte, resp *bittorrent.AnnounceResponse,
} }
for _, peer := range peers { for _, peer := range peers {
buf.Write(peer.IP) buf.Write(peer.IP.IP)
binary.Write(buf, binary.BigEndian, peer.Port) binary.Write(buf, binary.BigEndian, peer.Port)
} }

View file

@ -70,6 +70,48 @@ func (h *swarmInteractionHook) HandleScrape(ctx context.Context, _ *bittorrent.S
// ErrInvalidIP indicates an invalid IP for an Announce. // ErrInvalidIP indicates an invalid IP for an Announce.
var ErrInvalidIP = errors.New("invalid IP") var ErrInvalidIP = errors.New("invalid IP")
// sanitizationHook enforces semantic assumptions about requests that may have
// not been accounted for in a tracker frontend.
//
// The SanitizationHook performs the following checks:
// - maxNumWant: Checks whether the numWant parameter of an announce is below
// a limit. Sets it to the limit if the value is higher.
// - defaultNumWant: Checks whether the numWant parameter of an announce is
// zero. Sets it to the default if it is.
// - IP sanitization: Checks whether the announcing Peer's IP address is either
// IPv4 or IPv6. Returns ErrInvalidIP if the address is neither IPv4 nor
// IPv6. Sets the Peer.AddressFamily field accordingly. Truncates IPv4
// addresses to have a length of 4 bytes.
type sanitizationHook struct {
maxNumWant uint32
defaultNumWant uint32
}
func (h *sanitizationHook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
if req.NumWant > h.maxNumWant {
req.NumWant = h.maxNumWant
}
if req.NumWant == 0 {
req.NumWant = h.defaultNumWant
}
if ip := req.Peer.IP.To4(); ip != nil {
req.Peer.IP.IP = ip
req.Peer.IP.AddressFamily = bittorrent.IPv4
} else if len(req.Peer.IP.IP) == net.IPv6len { // implies req.Peer.IP.To4() == nil
req.Peer.IP.AddressFamily = bittorrent.IPv6
} else {
return ctx, ErrInvalidIP
}
return ctx, nil
}
func (h *sanitizationHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, resp *bittorrent.ScrapeResponse) (context.Context, error) {
return ctx, nil
}
type skipResponseHook struct{} type skipResponseHook struct{}
// SkipResponseHookKey is a key for the context of an Announce or Scrape to // SkipResponseHookKey is a key for the context of an Announce or Scrape to
@ -97,7 +139,7 @@ func (h *responseHook) HandleAnnounce(ctx context.Context, req *bittorrent.Annou
} }
// Add the Scrape data to the response. // Add the Scrape data to the response.
s := h.store.ScrapeSwarm(req.InfoHash, len(req.IP) == net.IPv6len) s := h.store.ScrapeSwarm(req.InfoHash, req.IP.AddressFamily)
resp.Incomplete = s.Incomplete resp.Incomplete = s.Incomplete
resp.Complete = s.Complete resp.Complete = s.Complete
@ -123,13 +165,13 @@ func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittor
peers = append(peers, req.Peer) peers = append(peers, req.Peer)
} }
switch len(req.IP) { switch req.IP.AddressFamily {
case net.IPv4len: case bittorrent.IPv4:
resp.IPv4Peers = peers resp.IPv4Peers = peers
case net.IPv6len: case bittorrent.IPv6:
resp.IPv6Peers = peers resp.IPv6Peers = peers
default: default:
panic("peer IP is not IPv4 or IPv6 length") panic("attempted to append peer that is neither IPv4 nor IPv6")
} }
return nil return nil
@ -140,10 +182,8 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR
return ctx, nil return ctx, nil
} }
v6, _ := ctx.Value(ScrapeIsIPv6Key).(bool)
for _, infoHash := range req.InfoHashes { for _, infoHash := range req.InfoHashes {
resp.Files[infoHash] = h.store.ScrapeSwarm(infoHash, v6) resp.Files[infoHash] = h.store.ScrapeSwarm(infoHash, req.AddressFamily)
} }
return ctx, nil return ctx, nil

View file

@ -16,6 +16,8 @@ import (
// Config holds the configuration common across all middleware. // Config holds the configuration common across all middleware.
type Config struct { type Config struct {
AnnounceInterval time.Duration `yaml:"announce_interval"` AnnounceInterval time.Duration `yaml:"announce_interval"`
MaxNumWant uint32 `yaml:"max_numwant"`
DefaultNumWant uint32 `yaml:"default_numwant"`
} }
var _ frontend.TrackerLogic = &Logic{} var _ frontend.TrackerLogic = &Logic{}
@ -26,10 +28,13 @@ func NewLogic(cfg Config, peerStore storage.PeerStore, preHooks, postHooks []Hoo
l := &Logic{ l := &Logic{
announceInterval: cfg.AnnounceInterval, announceInterval: cfg.AnnounceInterval,
peerStore: peerStore, peerStore: peerStore,
preHooks: append(preHooks, &responseHook{store: peerStore}), preHooks: []Hook{&sanitizationHook{maxNumWant: cfg.MaxNumWant, defaultNumWant: cfg.DefaultNumWant}},
postHooks: append(postHooks, &swarmInteractionHook{store: peerStore}), postHooks: append(postHooks, &swarmInteractionHook{store: peerStore}),
} }
l.preHooks = append(l.preHooks, preHooks...)
l.preHooks = append(l.preHooks, &responseHook{store: peerStore})
return l return l
} }

View file

@ -0,0 +1,94 @@
package middleware
import (
"context"
"fmt"
"net"
"testing"
"github.com/stretchr/testify/require"
"github.com/chihaya/chihaya/bittorrent"
)
// nopHook is a Hook to measure the overhead of a no-operation Hook through
// benchmarks.
type nopHook struct{}
func (h *nopHook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
return ctx, nil
}
func (h *nopHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, resp *bittorrent.ScrapeResponse) (context.Context, error) {
return ctx, nil
}
type hookList []Hook
func (hooks hookList) handleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest) (resp *bittorrent.AnnounceResponse, err error) {
resp = &bittorrent.AnnounceResponse{
Interval: 60,
MinInterval: 60,
Compact: true,
}
for _, h := range []Hook(hooks) {
if ctx, err = h.HandleAnnounce(ctx, req, resp); err != nil {
return nil, err
}
}
return resp, nil
}
func benchHookListV4(b *testing.B, hooks hookList) {
req := &bittorrent.AnnounceRequest{Peer: bittorrent.Peer{IP: bittorrent.IP{IP: net.ParseIP("1.2.3.4"), AddressFamily: bittorrent.IPv4}}}
benchHookList(b, hooks, req)
}
func benchHookListV6(b *testing.B, hooks hookList) {
req := &bittorrent.AnnounceRequest{Peer: bittorrent.Peer{IP: bittorrent.IP{IP: net.ParseIP("fc00::0001"), AddressFamily: bittorrent.IPv6}}}
benchHookList(b, hooks, req)
}
func benchHookList(b *testing.B, hooks hookList, req *bittorrent.AnnounceRequest) {
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := hooks.handleAnnounce(ctx, req)
require.Nil(b, err)
require.NotNil(b, resp)
}
}
func BenchmarkHookOverhead(b *testing.B) {
b.Run("none-v4", func(b *testing.B) {
benchHookListV4(b, hookList{})
})
b.Run("none-v6", func(b *testing.B) {
benchHookListV6(b, hookList{})
})
var nopHooks hookList
for i := 1; i < 4; i++ {
nopHooks = append(nopHooks, &nopHook{})
b.Run(fmt.Sprintf("%dnop-v4", i), func(b *testing.B) {
benchHookListV4(b, nopHooks)
})
b.Run(fmt.Sprintf("%dnop-v6", i), func(b *testing.B) {
benchHookListV6(b, nopHooks)
})
}
var sanHooks hookList
for i := 1; i < 4; i++ {
sanHooks = append(sanHooks, &sanitizationHook{maxNumWant: 50})
b.Run(fmt.Sprintf("%dsanitation-v4", i), func(b *testing.B) {
benchHookListV4(b, sanHooks)
})
b.Run(fmt.Sprintf("%dsanitation-v6", i), func(b *testing.B) {
benchHookListV6(b, sanHooks)
})
}
}

View file

@ -23,7 +23,6 @@ type Config struct {
GarbageCollectionInterval time.Duration `yaml:"gc_interval"` GarbageCollectionInterval time.Duration `yaml:"gc_interval"`
PeerLifetime time.Duration `yaml:"peer_lifetime"` PeerLifetime time.Duration `yaml:"peer_lifetime"`
ShardCount int `yaml:"shard_count"` ShardCount int `yaml:"shard_count"`
MaxNumWant int `yaml:"max_numwant"`
} }
// New creates a new PeerStore backed by memory. // New creates a new PeerStore backed by memory.
@ -38,9 +37,8 @@ func New(cfg Config) (storage.PeerStore, error) {
} }
ps := &peerStore{ ps := &peerStore{
shards: make([]*peerShard, shardCount*2), shards: make([]*peerShard, shardCount*2),
closed: make(chan struct{}), closed: make(chan struct{}),
maxNumWant: cfg.MaxNumWant,
} }
for i := 0; i < shardCount*2; i++ { for i := 0; i < shardCount*2; i++ {
@ -77,39 +75,48 @@ type swarm struct {
} }
type peerStore struct { type peerStore struct {
shards []*peerShard shards []*peerShard
closed chan struct{} closed chan struct{}
maxNumWant int
} }
var _ storage.PeerStore = &peerStore{} var _ storage.PeerStore = &peerStore{}
func (s *peerStore) shardIndex(infoHash bittorrent.InfoHash, v6 bool) uint32 { func (s *peerStore) shardIndex(infoHash bittorrent.InfoHash, af bittorrent.AddressFamily) uint32 {
// There are twice the amount of shards specified by the user, the first // There are twice the amount of shards specified by the user, the first
// half is dedicated to IPv4 swarms and the second half is dedicated to // half is dedicated to IPv4 swarms and the second half is dedicated to
// IPv6 swarms. // IPv6 swarms.
idx := binary.BigEndian.Uint32(infoHash[:4]) % (uint32(len(s.shards)) / 2) idx := binary.BigEndian.Uint32(infoHash[:4]) % (uint32(len(s.shards)) / 2)
if v6 { if af == bittorrent.IPv6 {
idx += uint32(len(s.shards) / 2) idx += uint32(len(s.shards) / 2)
} }
return idx return idx
} }
func newPeerKey(p bittorrent.Peer) serializedPeer { func newPeerKey(p bittorrent.Peer) serializedPeer {
b := make([]byte, 20+2+len(p.IP)) b := make([]byte, 20+2+len(p.IP.IP))
copy(b[:20], p.ID[:]) copy(b[:20], p.ID[:])
binary.BigEndian.PutUint16(b[20:22], p.Port) binary.BigEndian.PutUint16(b[20:22], p.Port)
copy(b[22:], p.IP) copy(b[22:], p.IP.IP)
return serializedPeer(b) return serializedPeer(b)
} }
func decodePeerKey(pk serializedPeer) bittorrent.Peer { func decodePeerKey(pk serializedPeer) bittorrent.Peer {
return bittorrent.Peer{ peer := bittorrent.Peer{
ID: bittorrent.PeerIDFromString(string(pk[:20])), ID: bittorrent.PeerIDFromString(string(pk[:20])),
Port: binary.BigEndian.Uint16([]byte(pk[20:22])), Port: binary.BigEndian.Uint16([]byte(pk[20:22])),
IP: net.IP(pk[22:]), IP: bittorrent.IP{IP: net.IP(pk[22:])}}
if ip := peer.IP.To4(); ip != nil {
peer.IP.IP = ip
peer.IP.AddressFamily = bittorrent.IPv4
} else if len(peer.IP.IP) == net.IPv6len { // implies toReturn.IP.To4() == nil
peer.IP.AddressFamily = bittorrent.IPv6
} else {
panic("IP is neither v4 nor v6")
} }
return peer
} }
func (s *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error { func (s *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error {
@ -121,7 +128,7 @@ func (s *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error {
pk := newPeerKey(p) pk := newPeerKey(p)
shard := s.shards[s.shardIndex(ih, len(p.IP) == net.IPv6len)] shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)]
shard.Lock() shard.Lock()
if _, ok := shard.swarms[ih]; !ok { if _, ok := shard.swarms[ih]; !ok {
@ -146,7 +153,7 @@ func (s *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) erro
pk := newPeerKey(p) pk := newPeerKey(p)
shard := s.shards[s.shardIndex(ih, len(p.IP) == net.IPv6len)] shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)]
shard.Lock() shard.Lock()
if _, ok := shard.swarms[ih]; !ok { if _, ok := shard.swarms[ih]; !ok {
@ -178,7 +185,7 @@ func (s *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error
pk := newPeerKey(p) pk := newPeerKey(p)
shard := s.shards[s.shardIndex(ih, len(p.IP) == net.IPv6len)] shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)]
shard.Lock() shard.Lock()
if _, ok := shard.swarms[ih]; !ok { if _, ok := shard.swarms[ih]; !ok {
@ -203,7 +210,7 @@ func (s *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) err
pk := newPeerKey(p) pk := newPeerKey(p)
shard := s.shards[s.shardIndex(ih, len(p.IP) == net.IPv6len)] shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)]
shard.Lock() shard.Lock()
if _, ok := shard.swarms[ih]; !ok { if _, ok := shard.swarms[ih]; !ok {
@ -235,7 +242,7 @@ func (s *peerStore) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) e
pk := newPeerKey(p) pk := newPeerKey(p)
shard := s.shards[s.shardIndex(ih, len(p.IP) == net.IPv6len)] shard := s.shards[s.shardIndex(ih, p.IP.AddressFamily)]
shard.Lock() shard.Lock()
if _, ok := shard.swarms[ih]; !ok { if _, ok := shard.swarms[ih]; !ok {
@ -260,11 +267,7 @@ func (s *peerStore) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant i
default: default:
} }
if numWant > s.maxNumWant { shard := s.shards[s.shardIndex(ih, announcer.IP.AddressFamily)]
numWant = s.maxNumWant
}
shard := s.shards[s.shardIndex(ih, len(announcer.IP) == net.IPv6len)]
shard.RLock() shard.RLock()
if _, ok := shard.swarms[ih]; !ok { if _, ok := shard.swarms[ih]; !ok {
@ -319,14 +322,14 @@ func (s *peerStore) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant i
return return
} }
func (s *peerStore) ScrapeSwarm(ih bittorrent.InfoHash, v6 bool) (resp bittorrent.Scrape) { func (s *peerStore) ScrapeSwarm(ih bittorrent.InfoHash, addressFamily bittorrent.AddressFamily) (resp bittorrent.Scrape) {
select { select {
case <-s.closed: case <-s.closed:
panic("attempted to interact with stopped memory store") panic("attempted to interact with stopped memory store")
default: default:
} }
shard := s.shards[s.shardIndex(ih, v6)] shard := s.shards[s.shardIndex(ih, addressFamily)]
shard.RLock() shard.RLock()
if _, ok := shard.swarms[ih]; !ok { if _, ok := shard.swarms[ih]; !ok {

View file

@ -9,7 +9,7 @@ import (
) )
func createNew() s.PeerStore { func createNew() s.PeerStore {
ps, err := New(Config{ShardCount: 1024, GarbageCollectionInterval: 10 * time.Minute, MaxNumWant: 50}) ps, err := New(Config{ShardCount: 1024, GarbageCollectionInterval: 10 * time.Minute})
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -57,13 +57,13 @@ type PeerStore interface {
// ScrapeSwarm returns information required to answer a scrape request // ScrapeSwarm returns information required to answer a scrape request
// about a swarm identified by the given infohash. // about a swarm identified by the given infohash.
// The v6 flag indicates whether or not the IPv6 swarm should be // The AddressFamily indicates whether or not the IPv6 swarm should be
// scraped. // scraped.
// The Complete and Incomplete fields of the Scrape must be filled, // The Complete and Incomplete fields of the Scrape must be filled,
// filling the Snatches field is optional. // filling the Snatches field is optional.
// If the infohash is unknown to the PeerStore, an empty Scrape is // If the infohash is unknown to the PeerStore, an empty Scrape is
// returned. // returned.
ScrapeSwarm(infoHash bittorrent.InfoHash, v6 bool) bittorrent.Scrape ScrapeSwarm(infoHash bittorrent.InfoHash, addressFamily bittorrent.AddressFamily) bittorrent.Scrape
// Stopper is an interface that expects a Stop method to stop the // Stopper is an interface that expects a Stop method to stop the
// PeerStore. // PeerStore.

View file

@ -45,7 +45,7 @@ func generatePeers() (a [1000]bittorrent.Peer) {
port := uint16(r.Uint32()) port := uint16(r.Uint32())
a[i] = bittorrent.Peer{ a[i] = bittorrent.Peer{
ID: bittorrent.PeerID(id), ID: bittorrent.PeerID(id),
IP: net.IP(ip), IP: bittorrent.IP{IP: net.IP(ip), AddressFamily: bittorrent.IPv4},
Port: port, Port: port,
} }
} }

View file

@ -20,20 +20,20 @@ func TestPeerStore(t *testing.T, p PeerStore) {
}{ }{
{ {
bittorrent.InfoHashFromString("00000000000000000001"), bittorrent.InfoHashFromString("00000000000000000001"),
bittorrent.Peer{ID: bittorrent.PeerIDFromString("00000000000000000001"), Port: 1, IP: net.ParseIP("1.1.1.1").To4()}, bittorrent.Peer{ID: bittorrent.PeerIDFromString("00000000000000000001"), Port: 1, IP: bittorrent.IP{IP: net.ParseIP("1.1.1.1").To4(), AddressFamily: bittorrent.IPv4}},
}, },
{ {
bittorrent.InfoHashFromString("00000000000000000002"), bittorrent.InfoHashFromString("00000000000000000002"),
bittorrent.Peer{ID: bittorrent.PeerIDFromString("00000000000000000002"), Port: 2, IP: net.ParseIP("abab::0001")}, bittorrent.Peer{ID: bittorrent.PeerIDFromString("00000000000000000002"), Port: 2, IP: bittorrent.IP{IP: net.ParseIP("abab::0001"), AddressFamily: bittorrent.IPv6}},
}, },
} }
v4Peer := bittorrent.Peer{ID: bittorrent.PeerIDFromString("99999999999999999994"), IP: net.ParseIP("99.99.99.99").To4(), Port: 9994} v4Peer := bittorrent.Peer{ID: bittorrent.PeerIDFromString("99999999999999999994"), IP: bittorrent.IP{IP: net.ParseIP("99.99.99.99").To4(), AddressFamily: bittorrent.IPv4}, Port: 9994}
v6Peer := bittorrent.Peer{ID: bittorrent.PeerIDFromString("99999999999999999996"), IP: net.ParseIP("fc00::0001"), Port: 9996} v6Peer := bittorrent.Peer{ID: bittorrent.PeerIDFromString("99999999999999999996"), IP: bittorrent.IP{IP: net.ParseIP("fc00::0001"), AddressFamily: bittorrent.IPv6}, Port: 9996}
for _, c := range testData { for _, c := range testData {
peer := v4Peer peer := v4Peer
if len(c.peer.IP) == net.IPv6len { if c.peer.IP.AddressFamily == bittorrent.IPv6 {
peer = v6Peer peer = v6Peer
} }
@ -48,7 +48,7 @@ func TestPeerStore(t *testing.T, p PeerStore) {
require.Equal(t, ErrResourceDoesNotExist, err) require.Equal(t, ErrResourceDoesNotExist, err)
// Test empty scrape response for non-existent swarms. // Test empty scrape response for non-existent swarms.
scrape := p.ScrapeSwarm(c.ih, len(c.peer.IP) == net.IPv6len) scrape := p.ScrapeSwarm(c.ih, c.peer.IP.AddressFamily)
require.Equal(t, uint32(0), scrape.Complete) require.Equal(t, uint32(0), scrape.Complete)
require.Equal(t, uint32(0), scrape.Incomplete) require.Equal(t, uint32(0), scrape.Incomplete)
require.Equal(t, uint32(0), scrape.Snatches) require.Equal(t, uint32(0), scrape.Snatches)
@ -76,7 +76,7 @@ func TestPeerStore(t *testing.T, p PeerStore) {
require.Nil(t, err) require.Nil(t, err)
require.True(t, containsPeer(peers, c.peer)) require.True(t, containsPeer(peers, c.peer))
scrape = p.ScrapeSwarm(c.ih, len(c.peer.IP) == net.IPv6len) scrape = p.ScrapeSwarm(c.ih, c.peer.IP.AddressFamily)
require.Equal(t, uint32(2), scrape.Incomplete) require.Equal(t, uint32(2), scrape.Incomplete)
require.Equal(t, uint32(0), scrape.Complete) require.Equal(t, uint32(0), scrape.Complete)
@ -97,7 +97,7 @@ func TestPeerStore(t *testing.T, p PeerStore) {
require.Nil(t, err) require.Nil(t, err)
require.True(t, containsPeer(peers, c.peer)) require.True(t, containsPeer(peers, c.peer))
scrape = p.ScrapeSwarm(c.ih, len(c.peer.IP) == net.IPv6len) scrape = p.ScrapeSwarm(c.ih, c.peer.IP.AddressFamily)
require.Equal(t, uint32(1), scrape.Incomplete) require.Equal(t, uint32(1), scrape.Incomplete)
require.Equal(t, uint32(1), scrape.Complete) require.Equal(t, uint32(1), scrape.Complete)