tracker/frontend/udp/frontend.go

364 lines
9.1 KiB
Go
Raw Normal View History

2016-08-04 20:08:26 +02:00
// Package udp implements a BitTorrent tracker via the UDP protocol as
// described in BEP 15.
package udp
import (
"bytes"
2016-08-17 04:32:15 +02:00
"context"
"encoding/binary"
2022-01-16 05:28:52 +01:00
"errors"
"fmt"
"math/rand"
"net"
2016-08-05 07:47:04 +02:00
"sync"
"time"
2016-08-17 03:42:08 +02:00
"github.com/chihaya/chihaya/bittorrent"
"github.com/chihaya/chihaya/frontend"
"github.com/chihaya/chihaya/frontend/udp/bytepool"
2017-06-20 14:58:44 +02:00
"github.com/chihaya/chihaya/pkg/log"
"github.com/chihaya/chihaya/pkg/stop"
2017-09-29 00:50:20 +02:00
"github.com/chihaya/chihaya/pkg/timecache"
)
// allowedGeneratedPrivateKeyRunes is the alphabet Validate draws from when it
// has to generate a private key: lowercase letters, uppercase letters, digits.
var allowedGeneratedPrivateKeyRunes = []rune(
	"abcdefghijklmnopqrstuvwxyz" +
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
		"1234567890",
)
2016-08-04 20:48:32 +02:00
// Config represents all of the configurable options for a UDP BitTorrent
// Tracker.
type Config struct {
2017-05-12 13:12:35 +02:00
Addr string `yaml:"addr"`
PrivateKey string `yaml:"private_key"`
MaxClockSkew time.Duration `yaml:"max_clock_skew"`
EnableRequestTiming bool `yaml:"enable_request_timing"`
2017-10-08 23:35:50 +02:00
ParseOptions `yaml:",inline"`
}
2017-05-07 00:48:44 +02:00
// LogFields renders the current config as a set of Logrus fields.
func (cfg Config) LogFields() log.Fields {
return log.Fields{
2017-05-12 13:12:35 +02:00
"addr": cfg.Addr,
"privateKey": cfg.PrivateKey,
"maxClockSkew": cfg.MaxClockSkew,
"enableRequestTiming": cfg.EnableRequestTiming,
2017-10-08 23:35:50 +02:00
"allowIPSpoofing": cfg.AllowIPSpoofing,
"maxNumWant": cfg.MaxNumWant,
"defaultNumWant": cfg.DefaultNumWant,
"maxScrapeInfoHashes": cfg.MaxScrapeInfoHashes,
2017-05-07 00:48:44 +02:00
}
}
// Validate sanity checks values set in a config and returns a new config with
// default values replacing anything that is invalid.
//
// This function warns to the logger when a value is changed.
//
// An empty PrivateKey is replaced by a random 64-character key drawn from
// allowedGeneratedPrivateKeyRunes. Non-positive MaxNumWant, DefaultNumWant and
// MaxScrapeInfoHashes are replaced by the package defaults.
func (cfg Config) Validate() Config {
	validcfg := cfg

	// Generate a private key if one isn't provided by the user.
	if cfg.PrivateKey == "" {
		// Use a local source instead of rand.Seed: reseeding the global
		// math/rand source is a process-wide side effect on every other user
		// of the package-level functions, and rand.Seed is deprecated.
		//
		// NOTE(review): math/rand is not cryptographically secure, yet this key
		// is what NewConnectionIDGenerator is built from. crypto/rand would be
		// the safer choice.
		rng := rand.New(rand.NewSource(time.Now().UnixNano()))
		pkeyRunes := make([]rune, 64)
		for i := range pkeyRunes {
			pkeyRunes[i] = allowedGeneratedPrivateKeyRunes[rng.Intn(len(allowedGeneratedPrivateKeyRunes))]
		}
		validcfg.PrivateKey = string(pkeyRunes)
		log.Warn("UDP private key was not provided, using generated key", log.Fields{"key": validcfg.PrivateKey})
	}

	if cfg.MaxNumWant <= 0 {
		validcfg.MaxNumWant = defaultMaxNumWant
		log.Warn("falling back to default configuration", log.Fields{
			"name":     "udp.MaxNumWant",
			"provided": cfg.MaxNumWant,
			"default":  validcfg.MaxNumWant,
		})
	}

	if cfg.DefaultNumWant <= 0 {
		validcfg.DefaultNumWant = defaultDefaultNumWant
		log.Warn("falling back to default configuration", log.Fields{
			"name":     "udp.DefaultNumWant",
			"provided": cfg.DefaultNumWant,
			"default":  validcfg.DefaultNumWant,
		})
	}

	if cfg.MaxScrapeInfoHashes <= 0 {
		validcfg.MaxScrapeInfoHashes = defaultMaxScrapeInfoHashes
		log.Warn("falling back to default configuration", log.Fields{
			"name":     "udp.MaxScrapeInfoHashes",
			"provided": cfg.MaxScrapeInfoHashes,
			"default":  validcfg.MaxScrapeInfoHashes,
		})
	}

	return validcfg
}
2016-08-07 04:41:33 +02:00
// Frontend holds the state of a UDP BitTorrent Frontend.
type Frontend struct {
2016-08-05 07:47:04 +02:00
socket *net.UDPConn
closing chan struct{}
wg sync.WaitGroup
genPool *sync.Pool
2016-08-10 02:08:15 +02:00
logic frontend.TrackerLogic
Config
}
// NewFrontend creates a new instance of an UDP Frontend that asynchronously
// serves requests.
func NewFrontend(logic frontend.TrackerLogic, provided Config) (*Frontend, error) {
cfg := provided.Validate()
f := &Frontend{
closing: make(chan struct{}),
2016-08-10 02:08:15 +02:00
logic: logic,
Config: cfg,
genPool: &sync.Pool{
New: func() interface{} {
return NewConnectionIDGenerator(cfg.PrivateKey)
},
},
}
2022-01-16 05:28:52 +01:00
if err := f.listen(); err != nil {
2018-12-26 16:15:05 +01:00
return nil, err
}
go func() {
2018-12-26 16:15:05 +01:00
if err := f.serve(); err != nil {
2017-06-20 14:58:44 +02:00
log.Fatal("failed while serving udp", log.Err(err))
}
}()
return f, nil
}
2016-08-07 04:41:33 +02:00
// Stop provides a thread-safe way to shutdown a currently running Frontend.
func (t *Frontend) Stop() stop.Result {
select {
case <-t.closing:
return stop.AlreadyStopped
default:
}
c := make(stop.Channel)
go func() {
close(t.closing)
2022-01-15 19:31:14 +01:00
_ = t.socket.SetReadDeadline(time.Now())
t.wg.Wait()
c.Done(t.socket.Close())
}()
return c.Result()
}
2018-12-27 13:17:43 +01:00
// listen resolves the address and binds the server socket.
2018-12-26 16:15:05 +01:00
func (t *Frontend) listen() error {
2016-08-04 06:18:58 +02:00
udpAddr, err := net.ResolveUDPAddr("udp", t.Addr)
if err != nil {
return err
}
2016-08-05 07:47:04 +02:00
t.socket, err = net.ListenUDP("udp", udpAddr)
2018-12-26 16:15:05 +01:00
return err
}
2018-12-27 13:17:43 +01:00
// serve blocks while listening and serving UDP BitTorrent requests
2018-12-26 16:15:05 +01:00
// until Stop() is called or an error is returned.
func (t *Frontend) serve() error {
2016-09-05 18:30:03 +02:00
pool := bytepool.New(2048)
t.wg.Add(1)
defer t.wg.Done()
for {
// Check to see if we need to shutdown.
select {
2016-08-04 06:18:58 +02:00
case <-t.closing:
2018-12-27 13:17:43 +01:00
log.Debug("udp serve() received shutdown signal")
return nil
default:
}
// Read a UDP packet into a reusable buffer.
buffer := pool.Get()
n, addr, err := t.socket.ReadFromUDP(*buffer)
if err != nil {
pool.Put(buffer)
2022-01-16 05:28:52 +01:00
var netErr net.Error
if errors.As(err, &netErr); netErr.Temporary() {
// A temporary failure is not fatal; just pretend it never happened.
continue
}
return err
}
// We got nothin'
if n == 0 {
pool.Put(buffer)
continue
}
2016-08-04 06:18:58 +02:00
t.wg.Add(1)
2016-08-05 07:47:04 +02:00
go func() {
2016-08-04 06:18:58 +02:00
defer t.wg.Done()
defer pool.Put(buffer)
if ip := addr.IP.To4(); ip != nil {
addr.IP = ip
}
2016-08-04 20:48:32 +02:00
// Handle the request.
2017-05-12 13:12:35 +02:00
var start time.Time
if t.EnableRequestTiming {
start = time.Now()
}
2017-02-01 02:58:08 +01:00
action, af, err := t.handleRequest(
// Make sure the IP is copied, not referenced.
Request{(*buffer)[:n], append([]byte{}, addr.IP...)},
2016-08-05 07:47:04 +02:00
ResponseWriter{t.socket, addr},
)
2017-05-12 13:12:35 +02:00
if t.EnableRequestTiming {
recordResponseDuration(action, af, err, time.Since(start))
} else {
recordResponseDuration(action, af, err, time.Duration(0))
}
2016-08-04 20:48:32 +02:00
}()
}
}
// Request represents a UDP payload received by a Tracker.
//
// Field order matters: serve builds Request values with a positional literal.
type Request struct {
	// Packet is the raw datagram payload.
	Packet []byte
	// IP is the sender address. serve passes a private copy, reduced to the
	// 4-byte form for IPv4 senders.
	IP net.IP
}
// ResponseWriter implements the ability to respond to a Request via the
// io.Writer interface.
//
// Field order matters: serve builds ResponseWriter values with a positional
// literal.
type ResponseWriter struct {
	// socket is the listener the reply is sent from.
	socket *net.UDPConn
	// addr is the peer that sent the Request being answered.
	addr *net.UDPAddr
}

// Write implements the io.Writer interface for a ResponseWriter.
//
// Delivery is best effort: the outcome of WriteToUDP is discarded, and Write
// always reports that all of b was written with a nil error.
func (w ResponseWriter) Write(b []byte) (int, error) {
	written := len(b)
	// UDP replies are fire-and-forget; a failed send is deliberately not
	// surfaced to callers.
	_, _ = w.socket.WriteToUDP(b, w.addr)
	return written, nil
}
2016-08-04 20:08:26 +02:00
// handleRequest parses and responds to a UDP Request.
2017-02-01 02:58:08 +01:00
func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string, af *bittorrent.AddressFamily, err error) {
2016-08-05 07:47:04 +02:00
if len(r.Packet) < 16 {
// Malformed, no client packets are less than 16 bytes.
// We explicitly return nothing in case this is a DoS attempt.
err = errMalformedPacket
return
}
// Parse the headers of the UDP packet.
2016-08-05 07:47:04 +02:00
connID := r.Packet[0:8]
actionID := binary.BigEndian.Uint32(r.Packet[8:12])
txID := r.Packet[12:16]
// get a connection ID generator/validator from the pool.
gen := t.genPool.Get().(*ConnectionIDGenerator)
defer t.genPool.Put(gen)
// If this isn't requesting a new connection ID and the connection ID is
// invalid, then fail.
if actionID != connectActionID && !gen.Validate(connID, r.IP, timecache.Now(), t.MaxClockSkew) {
err = errBadConnectionID
WriteError(w, txID, err)
return
}
// Handle the requested action.
switch actionID {
case connectActionID:
actionName = "connect"
if !bytes.Equal(connID, initialConnectionID) {
err = errMalformedPacket
return
}
af = new(bittorrent.AddressFamily)
if r.IP.To4() != nil {
*af = bittorrent.IPv4
} else if len(r.IP) == net.IPv6len { // implies r.IP.To4() == nil
*af = bittorrent.IPv6
} else {
// Should never happen - we got the IP straight from the UDP packet.
panic(fmt.Sprintf("udp: invalid IP: neither v4 nor v6, IP: %#v", r.IP))
}
WriteConnectionID(w, txID, gen.Generate(r.IP, timecache.Now()))
case announceActionID, announceV6ActionID:
actionName = "announce"
var req *bittorrent.AnnounceRequest
2017-10-08 23:35:50 +02:00
req, err = ParseAnnounce(r, actionID == announceV6ActionID, t.ParseOptions)
if err != nil {
WriteError(w, txID, err)
return
}
af = new(bittorrent.AddressFamily)
2017-02-01 02:58:08 +01:00
*af = req.IP.AddressFamily
var ctx context.Context
var resp *bittorrent.AnnounceResponse
ctx, resp, err = t.logic.HandleAnnounce(context.Background(), req)
if err != nil {
WriteError(w, txID, err)
return
}
WriteAnnounce(w, txID, resp, actionID == announceV6ActionID, req.IP.AddressFamily == bittorrent.IPv6)
go t.logic.AfterAnnounce(ctx, req, resp)
case scrapeActionID:
actionName = "scrape"
var req *bittorrent.ScrapeRequest
2017-10-08 23:35:50 +02:00
req, err = ParseScrape(r, t.ParseOptions)
if err != nil {
WriteError(w, txID, err)
return
}
2016-11-28 20:55:04 +01:00
if r.IP.To4() != nil {
req.AddressFamily = bittorrent.IPv4
2016-11-28 20:55:04 +01:00
} else if len(r.IP) == net.IPv6len { // implies r.IP.To4() == nil
req.AddressFamily = bittorrent.IPv6
2016-11-28 20:55:04 +01:00
} else {
// Should never happen - we got the IP straight from the UDP packet.
panic(fmt.Sprintf("udp: invalid IP: neither v4 nor v6, IP: %#v", r.IP))
2016-11-28 20:55:04 +01:00
}
af = new(bittorrent.AddressFamily)
2017-02-01 02:58:08 +01:00
*af = req.AddressFamily
2016-11-28 20:55:04 +01:00
var ctx context.Context
var resp *bittorrent.ScrapeResponse
ctx, resp, err = t.logic.HandleScrape(context.Background(), req)
if err != nil {
WriteError(w, txID, err)
return
}
WriteScrape(w, txID, resp)
go t.logic.AfterScrape(ctx, req, resp)
default:
err = errUnknownAction
WriteError(w, txID, err)
}
return
}