make it compile!

This commit is contained in:
Jimmy Zelinskie 2016-08-05 01:47:04 -04:00
parent b5de90345e
commit 5c99738b7f
17 changed files with 361 additions and 185 deletions

View file

@ -20,6 +20,8 @@ package bittorrent
import ( import (
"net" "net"
"time" "time"
"golang.org/x/net/context"
) )
// PeerID represents a peer ID. // PeerID represents a peer ID.
@ -107,7 +109,7 @@ type AnnounceResponse struct {
} }
// AnnounceHandler is a function that generates a response for an Announce. // AnnounceHandler is a function that generates a response for an Announce.
type AnnounceHandler func(*AnnounceRequest) *AnnounceResponse type AnnounceHandler func(context.Context, *AnnounceRequest) (*AnnounceResponse, error)
// AnnounceCallback is a function that does something with the results of an // AnnounceCallback is a function that does something with the results of an
// Announce after it has been completed. // Announce after it has been completed.
@ -132,7 +134,7 @@ type Scrape struct {
} }
// ScrapeHandler is a function that generates a response for a Scrape. // ScrapeHandler is a function that generates a response for a Scrape.
type ScrapeHandler func(*ScrapeRequest) *ScrapeResponse type ScrapeHandler func(context.Context, *ScrapeRequest) (*ScrapeResponse, error)
// ScrapeCallback is a function that does something with the results of a // ScrapeCallback is a function that does something with the results of a
// Scrape after it has been completed. // Scrape after it has been completed.
@ -152,9 +154,9 @@ func (p Peer) Equal(x Peer) bool { return p.EqualEndpoint(x) && p.ID == x.ID }
// EqualEndpoint reports whether p and x have the same endpoint. // EqualEndpoint reports whether p and x have the same endpoint.
func (p Peer) EqualEndpoint(x Peer) bool { return p.Port == x.Port && p.IP.Equal(x.IP) } func (p Peer) EqualEndpoint(x Peer) bool { return p.Port == x.Port && p.IP.Equal(x.IP) }
// Params is used to fetch request optional parameters. // Params is used to fetch request optional parameters from an Announce.
type Params interface { type Params interface {
String(key string) (string, error) String(key string) (string, bool)
} }
// ClientError represents an error that should be exposed to the client over // ClientError represents an error that should be exposed to the client over

View file

@ -32,14 +32,9 @@ func ParseAnnounce(r *http.Request, realIPHeader string, allowIPSpoofing bool) (
return nil, err return nil, err
} }
request := &bittorrent.AnnounceRequest{Params: q} request := &bittorrent.AnnounceRequest{Params: qp}
eventStr, err := qp.String("event") eventStr, _ := qp.String("event")
if err == query.ErrKeyNotFound {
eventStr = ""
} else if err != nil {
return nil, bittorrent.ClientError("failed to parse parameter: event")
}
request.Event, err = bittorrent.NewEvent(eventStr) request.Event, err = bittorrent.NewEvent(eventStr)
if err != nil { if err != nil {
return nil, bittorrent.ClientError("failed to provide valid client event") return nil, bittorrent.ClientError("failed to provide valid client event")
@ -57,14 +52,14 @@ func ParseAnnounce(r *http.Request, realIPHeader string, allowIPSpoofing bool) (
} }
request.InfoHash = infoHashes[0] request.InfoHash = infoHashes[0]
peerID, err := qp.String("peer_id") peerID, ok := qp.String("peer_id")
if err != nil { if !ok {
return nil, bittorrent.ClientError("failed to parse parameter: peer_id") return nil, bittorrent.ClientError("failed to parse parameter: peer_id")
} }
if len(peerID) != 20 { if len(peerID) != 20 {
return nil, bittorrent.ClientError("failed to provide valid peer_id") return nil, bittorrent.ClientError("failed to provide valid peer_id")
} }
request.PeerID = bittorrent.PeerIDFromString(peerID) request.Peer.ID = bittorrent.PeerIDFromString(peerID)
request.Left, err = qp.Uint64("left") request.Left, err = qp.Uint64("left")
if err != nil { if err != nil {
@ -85,24 +80,24 @@ func ParseAnnounce(r *http.Request, realIPHeader string, allowIPSpoofing bool) (
if err != nil { if err != nil {
return nil, bittorrent.ClientError("failed to parse parameter: numwant") return nil, bittorrent.ClientError("failed to parse parameter: numwant")
} }
request.NumWant = int32(numwant) request.NumWant = uint32(numwant)
port, err := qp.Uint64("port") port, err := qp.Uint64("port")
if err != nil { if err != nil {
return nil, bittorrent.ClientError("failed to parse parameter: port") return nil, bittorrent.ClientError("failed to parse parameter: port")
} }
request.Port = uint16(port) request.Peer.Port = uint16(port)
request.IP, err = requestedIP(q, r, realIPHeader, allowIPSpoofing) request.Peer.IP = requestedIP(r, qp, realIPHeader, allowIPSpoofing)
if err != nil { if request.Peer.IP == nil {
return nil, bittorrent.ClientError("failed to parse peer IP address: " + err.Error()) return nil, bittorrent.ClientError("failed to parse peer IP address")
} }
return request, nil return request, nil
} }
// ParseScrape parses an bittorrent.ScrapeRequest from an http.Request. // ParseScrape parses an bittorrent.ScrapeRequest from an http.Request.
func ParseScrape(r *http.Request) (*bittorent.ScrapeRequest, error) { func ParseScrape(r *http.Request) (*bittorrent.ScrapeRequest, error) {
qp, err := NewQueryParams(r.URL.RawQuery) qp, err := NewQueryParams(r.URL.RawQuery)
if err != nil { if err != nil {
return nil, err return nil, err
@ -115,7 +110,7 @@ func ParseScrape(r *http.Request) (*bittorent.ScrapeRequest, error) {
request := &bittorrent.ScrapeRequest{ request := &bittorrent.ScrapeRequest{
InfoHashes: infoHashes, InfoHashes: infoHashes,
Params: q, Params: qp,
} }
return request, nil return request, nil
@ -126,46 +121,31 @@ func ParseScrape(r *http.Request) (*bittorent.ScrapeRequest, error) {
// If allowIPSpoofing is true, IPs provided via params will be used. // If allowIPSpoofing is true, IPs provided via params will be used.
// If realIPHeader is not empty string, the first value of the HTTP Header with // If realIPHeader is not empty string, the first value of the HTTP Header with
// that name will be used. // that name will be used.
func requestedIP(r *http.Request, p bittorent.Params, realIPHeader string, allowIPSpoofing bool) (net.IP, error) { func requestedIP(r *http.Request, p bittorrent.Params, realIPHeader string, allowIPSpoofing bool) net.IP {
if allowIPSpoofing { if allowIPSpoofing {
if ipstr, err := p.String("ip"); err == nil { if ipstr, ok := p.String("ip"); ok {
ip, err := net.ParseIP(str) ip := net.ParseIP(ipstr)
if err != nil { return ip
return nil, err
} }
return ip, nil if ipstr, ok := p.String("ipv4"); ok {
ip := net.ParseIP(ipstr)
return ip
} }
if ipstr, err := p.String("ipv4"); err == nil { if ipstr, ok := p.String("ipv6"); ok {
ip, err := net.ParseIP(str) ip := net.ParseIP(ipstr)
if err != nil { return ip
return nil, err
}
return ip, nil
}
if ipstr, err := p.String("ipv6"); err == nil {
ip, err := net.ParseIP(str)
if err != nil {
return nil, err
}
return ip, nil
} }
} }
if realIPHeader != "" { if realIPHeader != "" {
if ips, ok := r.Header[realIPHeader]; ok && len(ips) > 0 { if ips, ok := r.Header[realIPHeader]; ok && len(ips) > 0 {
ip, err := net.ParseIP(ips[0]) ip := net.ParseIP(ips[0])
if err != nil { return ip
return nil, err
}
return ip, nil
} }
} }
return r.RemoteAddr host, _, _ := net.SplitHostPort(r.RemoteAddr)
return net.ParseIP(host)
} }

View file

@ -40,14 +40,14 @@ type QueryParams struct {
} }
// NewQueryParams parses a raw URL query. // NewQueryParams parses a raw URL query.
func NewQueryParams(query string) (*Query, error) { func NewQueryParams(query string) (*QueryParams, error) {
var ( var (
keyStart, keyEnd int keyStart, keyEnd int
valStart, valEnd int valStart, valEnd int
onKey = true onKey = true
q = &Query{ q = &QueryParams{
query: query, query: query,
infoHashes: nil, infoHashes: nil,
params: make(map[string]string), params: make(map[string]string),
@ -111,18 +111,15 @@ func NewQueryParams(query string) (*Query, error) {
// String returns a string parsed from a query. Every key can be returned as a // String returns a string parsed from a query. Every key can be returned as a
// string because they are encoded in the URL as strings. // string because they are encoded in the URL as strings.
func (q *Query) String(key string) (string, error) { func (qp *QueryParams) String(key string) (string, bool) {
val, exists := q.params[key] value, ok := qp.params[key]
if !exists { return value, ok
return "", ErrKeyNotFound
}
return val, nil
} }
// Uint64 returns a uint parsed from a query. After being called, it is safe to // Uint64 returns a uint parsed from a query. After being called, it is safe to
// cast the uint64 to your desired length. // cast the uint64 to your desired length.
func (q *Query) Uint64(key string) (uint64, error) { func (qp *QueryParams) Uint64(key string) (uint64, error) {
str, exists := q.params[key] str, exists := qp.params[key]
if !exists { if !exists {
return 0, ErrKeyNotFound return 0, ErrKeyNotFound
} }
@ -136,6 +133,6 @@ func (q *Query) Uint64(key string) (uint64, error) {
} }
// InfoHashes returns a list of requested infohashes. // InfoHashes returns a list of requested infohashes.
func (q *Query) InfoHashes() []bittorrent.InfoHash { func (qp *QueryParams) InfoHashes() []bittorrent.InfoHash {
return q.infoHashes return qp.infoHashes
} }

View file

@ -16,6 +16,19 @@
// described in BEP 3 and BEP 23. // described in BEP 3 and BEP 23.
package http package http
import (
"net"
"net/http"
"time"
"github.com/julienschmidt/httprouter"
"github.com/prometheus/client_golang/prometheus"
"github.com/tylerb/graceful"
"golang.org/x/net/context"
"github.com/jzelinskie/trakr/bittorrent"
)
var promResponseDurationMilliseconds = prometheus.NewHistogramVec( var promResponseDurationMilliseconds = prometheus.NewHistogramVec(
prometheus.HistogramOpts{ prometheus.HistogramOpts{
Name: "trakr_http_response_duration_milliseconds", Name: "trakr_http_response_duration_milliseconds",
@ -27,9 +40,14 @@ var promResponseDurationMilliseconds = prometheus.NewHistogramVec(
// recordResponseDuration records the duration of time to respond to a UDP // recordResponseDuration records the duration of time to respond to a UDP
// Request in milliseconds . // Request in milliseconds .
func recordResponseDuration(action, err error, duration time.Duration) { func recordResponseDuration(action string, err error, duration time.Duration) {
var errString string
if err != nil {
errString = err.Error()
}
promResponseDurationMilliseconds. promResponseDurationMilliseconds.
WithLabelValues(action, err.Error()). WithLabelValues(action, errString).
Observe(float64(duration.Nanoseconds()) / float64(time.Millisecond)) Observe(float64(duration.Nanoseconds()) / float64(time.Millisecond))
} }
@ -53,8 +71,8 @@ type Tracker struct {
} }
// NewTracker allocates a new instance of a Tracker. // NewTracker allocates a new instance of a Tracker.
func NewTracker(funcs bittorrent.TrackerFuncs, cfg Config) { func NewTracker(funcs bittorrent.TrackerFuncs, cfg Config) *Tracker {
return &Server{ return &Tracker{
TrackerFuncs: funcs, TrackerFuncs: funcs,
Config: cfg, Config: cfg,
} }
@ -66,11 +84,11 @@ func (t *Tracker) Stop() {
<-t.grace.StopChan() <-t.grace.StopChan()
} }
func (t *Tracker) handler() { func (t *Tracker) handler() http.Handler {
router := httprouter.New() router := httprouter.New()
router.GET("/announce", t.announceRoute) router.GET("/announce", t.announceRoute)
router.GET("/scrape", t.scrapeRoute) router.GET("/scrape", t.scrapeRoute)
return server return router
} }
// ListenAndServe listens on the TCP network address t.Addr and blocks serving // ListenAndServe listens on the TCP network address t.Addr and blocks serving
@ -111,18 +129,15 @@ func (t *Tracker) ListenAndServe() error {
panic("http: failed to gracefully run HTTP server: " + err.Error()) panic("http: failed to gracefully run HTTP server: " + err.Error())
} }
} }
return nil
} }
// announceRoute parses and responds to an Announce by using t.TrackerFuncs. // announceRoute parses and responds to an Announce by using t.TrackerFuncs.
func (t *Tracker) announceRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { func (t *Tracker) announceRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
var err error
start := time.Now() start := time.Now()
defer func() { defer recordResponseDuration("announce", err, time.Since(start))
var errString string
if err != nil {
errString = err.Error()
}
recordResponseDuration("announce", errString, time.Since(start))
}()
req, err := ParseAnnounce(r, t.RealIPHeader, t.AllowIPSpoofing) req, err := ParseAnnounce(r, t.RealIPHeader, t.AllowIPSpoofing)
if err != nil { if err != nil {
@ -130,7 +145,7 @@ func (t *Tracker) announceRoute(w http.ResponseWriter, r *http.Request, _ httpro
return return
} }
resp, err := t.HandleAnnounce(req) resp, err := t.HandleAnnounce(context.TODO(), req)
if err != nil { if err != nil {
WriteError(w, err) WriteError(w, err)
return return
@ -145,19 +160,13 @@ func (t *Tracker) announceRoute(w http.ResponseWriter, r *http.Request, _ httpro
if t.AfterAnnounce != nil { if t.AfterAnnounce != nil {
go t.AfterAnnounce(req, resp) go t.AfterAnnounce(req, resp)
} }
recordResponseDuration("announce")
} }
// scrapeRoute parses and responds to a Scrape by using t.TrackerFuncs. // scrapeRoute parses and responds to a Scrape by using t.TrackerFuncs.
func (t *Tracker) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { func (t *Tracker) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
var err error
start := time.Now() start := time.Now()
defer func() { defer recordResponseDuration("scrape", err, time.Since(start))
var errString string
if err != nil {
errString = err.Error()
}
recordResponseDuration("scrape", errString, time.Since(start))
}()
req, err := ParseScrape(r) req, err := ParseScrape(r)
if err != nil { if err != nil {
@ -165,7 +174,7 @@ func (t *Tracker) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprout
return return
} }
resp, err := t.HandleScrape(req) resp, err := t.HandleScrape(context.TODO(), req)
if err != nil { if err != nil {
WriteError(w, err) WriteError(w, err)
return return

View file

@ -18,6 +18,7 @@ import (
"net/http" "net/http"
"github.com/jzelinskie/trakr/bittorrent" "github.com/jzelinskie/trakr/bittorrent"
"github.com/jzelinskie/trakr/bittorrent/http/bencode"
) )
// WriteError communicates an error to a BitTorrent client over HTTP. // WriteError communicates an error to a BitTorrent client over HTTP.

View file

@ -18,8 +18,9 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/jzelinskie/trakr/bittorrent"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/jzelinskie/trakr/bittorrent"
) )
func TestWriteError(t *testing.T) { func TestWriteError(t *testing.T) {

View file

@ -0,0 +1,35 @@
// Copyright 2016 The Chihaya Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
package bytepool
import "sync"
// BytePool is a cached pool of reusable byte slices.
type BytePool struct {
sync.Pool
}
// New allocates a new BytePool with slices of the provided capacity.
func New(length, capacity int) *BytePool {
var bp BytePool
bp.Pool.New = func() interface{} {
return make([]byte, length, capacity)
}
return &bp
}
// Get returns a byte slice from the pool.
func (bp *BytePool) Get() []byte {
return bp.Pool.Get().([]byte)
}
// Put returns a byte slice to the pool.
func (bp *BytePool) Put(b []byte) {
// Zero out the bytes.
for i := 0; i < cap(b); i++ {
b[i] = 0x0
}
bp.Pool.Put(b)
}

View file

@ -63,35 +63,35 @@ var (
// //
// If allowIPSpoofing is true, IPs provided via params will be used. // If allowIPSpoofing is true, IPs provided via params will be used.
func ParseAnnounce(r Request, allowIPSpoofing bool) (*bittorrent.AnnounceRequest, error) { func ParseAnnounce(r Request, allowIPSpoofing bool) (*bittorrent.AnnounceRequest, error) {
if len(r.packet) < 98 { if len(r.Packet) < 98 {
return nil, errMalformedPacket return nil, errMalformedPacket
} }
infohash := r.packet[16:36] infohash := r.Packet[16:36]
peerID := r.packet[36:56] peerID := r.Packet[36:56]
downloaded := binary.BigEndian.Uint64(r.packet[56:64]) downloaded := binary.BigEndian.Uint64(r.Packet[56:64])
left := binary.BigEndian.Uint64(r.packet[64:72]) left := binary.BigEndian.Uint64(r.Packet[64:72])
uploaded := binary.BigEndian.Uint64(r.packet[72:80]) uploaded := binary.BigEndian.Uint64(r.Packet[72:80])
eventID := int(r.packet[83]) eventID := int(r.Packet[83])
if eventID >= len(eventIDs) { if eventID >= len(eventIDs) {
return nil, errMalformedEvent return nil, errMalformedEvent
} }
ip := r.IP ip := r.IP
ipbytes := r.packet[84:88] ipbytes := r.Packet[84:88]
if allowIPSpoofing { if allowIPSpoofing {
ip = net.IP(ipbytes) ip = net.IP(ipbytes)
} }
if !allowIPSpoofing && r.ip == nil { if !allowIPSpoofing && r.IP == nil {
// We have no IP address to fallback on. // We have no IP address to fallback on.
return nil, errMalformedIP return nil, errMalformedIP
} }
numWant := binary.BigEndian.Uint32(r.packet[92:96]) numWant := binary.BigEndian.Uint32(r.Packet[92:96])
port := binary.BigEndian.Uint16(r.packet[96:98]) port := binary.BigEndian.Uint16(r.Packet[96:98])
params, err := handleOptionalParameters(r.packet) params, err := handleOptionalParameters(r.Packet)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -152,24 +152,24 @@ func handleOptionalParameters(packet []byte) (params bittorrent.Params, err erro
} }
// ParseScrape parses a ScrapeRequest from a UDP request. // ParseScrape parses a ScrapeRequest from a UDP request.
func parseScrape(r Request) (*bittorrent.ScrapeRequest, error) { func ParseScrape(r Request) (*bittorrent.ScrapeRequest, error) {
// If a scrape isn't at least 36 bytes long, it's malformed. // If a scrape isn't at least 36 bytes long, it's malformed.
if len(r.packet) < 36 { if len(r.Packet) < 36 {
return nil, errMalformedPacket return nil, errMalformedPacket
} }
// Skip past the initial headers and check that the bytes left equal the // Skip past the initial headers and check that the bytes left equal the
// length of a valid list of infohashes. // length of a valid list of infohashes.
r.packet = r.packet[16:] r.Packet = r.Packet[16:]
if len(r.packet)%20 != 0 { if len(r.Packet)%20 != 0 {
return nil, errMalformedPacket return nil, errMalformedPacket
} }
// Allocate a list of infohashes and append it to the list until we're out. // Allocate a list of infohashes and append it to the list until we're out.
var infohashes []bittorrent.InfoHash var infohashes []bittorrent.InfoHash
for len(r.packet) >= 20 { for len(r.Packet) >= 20 {
infohashes = append(infohashes, bittorrent.InfoHashFromBytes(r.packet[:20])) infohashes = append(infohashes, bittorrent.InfoHashFromBytes(r.Packet[:20]))
r.packet = r.packet[20:] r.Packet = r.Packet[20:]
} }
return &bittorrent.ScrapeRequest{ return &bittorrent.ScrapeRequest{

View file

@ -19,10 +19,16 @@ package udp
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"log"
"net" "net"
"sync"
"time" "time"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/context"
"github.com/jzelinskie/trakr/bittorrent" "github.com/jzelinskie/trakr/bittorrent"
"github.com/jzelinskie/trakr/bittorrent/udp/bytepool"
) )
var promResponseDurationMilliseconds = prometheus.NewHistogramVec( var promResponseDurationMilliseconds = prometheus.NewHistogramVec(
@ -36,9 +42,14 @@ var promResponseDurationMilliseconds = prometheus.NewHistogramVec(
// recordResponseDuration records the duration of time to respond to a UDP // recordResponseDuration records the duration of time to respond to a UDP
// Request in milliseconds . // Request in milliseconds .
func recordResponseDuration(action, err error, duration time.Duration) { func recordResponseDuration(action string, err error, duration time.Duration) {
var errString string
if err != nil {
errString = err.Error()
}
promResponseDurationMilliseconds. promResponseDurationMilliseconds.
WithLabelValues(action, err.Error()). WithLabelValues(action, errString).
Observe(float64(duration.Nanoseconds()) / float64(time.Millisecond)) Observe(float64(duration.Nanoseconds()) / float64(time.Millisecond))
} }
@ -47,12 +58,13 @@ func recordResponseDuration(action, err error, duration time.Duration) {
type Config struct { type Config struct {
Addr string Addr string
PrivateKey string PrivateKey string
MaxClockSkew time.Duration
AllowIPSpoofing bool AllowIPSpoofing bool
} }
// Tracker holds the state of a UDP BitTorrent Tracker. // Tracker holds the state of a UDP BitTorrent Tracker.
type Tracker struct { type Tracker struct {
sock *net.UDPConn socket *net.UDPConn
closing chan struct{} closing chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
@ -61,7 +73,7 @@ type Tracker struct {
} }
// NewTracker allocates a new instance of a Tracker. // NewTracker allocates a new instance of a Tracker.
func NewTracker(funcs bittorrent.TrackerFuncs, cfg Config) { func NewTracker(funcs bittorrent.TrackerFuncs, cfg Config) *Tracker {
return &Tracker{ return &Tracker{
closing: make(chan struct{}), closing: make(chan struct{}),
TrackerFuncs: funcs, TrackerFuncs: funcs,
@ -72,7 +84,7 @@ func NewTracker(funcs bittorrent.TrackerFuncs, cfg Config) {
// Stop provides a thread-safe way to shutdown a currently running Tracker. // Stop provides a thread-safe way to shutdown a currently running Tracker.
func (t *Tracker) Stop() { func (t *Tracker) Stop() {
close(t.closing) close(t.closing)
t.sock.SetReadDeadline(time.Now()) t.socket.SetReadDeadline(time.Now())
t.wg.Wait() t.wg.Wait()
} }
@ -84,11 +96,11 @@ func (t *Tracker) ListenAndServe() error {
return err return err
} }
t.sock, err = net.ListenUDP("udp", udpAddr) t.socket, err = net.ListenUDP("udp", udpAddr)
if err != nil { if err != nil {
return err return err
} }
defer t.sock.Close() defer t.socket.Close()
pool := bytepool.New(256, 2048) pool := bytepool.New(256, 2048)
@ -103,8 +115,8 @@ func (t *Tracker) ListenAndServe() error {
// Read a UDP packet into a reusable buffer. // Read a UDP packet into a reusable buffer.
buffer := pool.Get() buffer := pool.Get()
t.sock.SetReadDeadline(time.Now().Add(time.Second)) t.socket.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := t.sock.ReadFromUDP(buffer) n, addr, err := t.socket.ReadFromUDP(buffer)
if err != nil { if err != nil {
pool.Put(buffer) pool.Put(buffer)
if netErr, ok := err.(net.Error); ok && netErr.Temporary() { if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
@ -122,21 +134,18 @@ func (t *Tracker) ListenAndServe() error {
log.Println("Got UDP Request") log.Println("Got UDP Request")
t.wg.Add(1) t.wg.Add(1)
go func(start time.Time) { go func() {
defer t.wg.Done() defer t.wg.Done()
defer pool.Put(buffer) defer pool.Put(buffer)
// Handle the request. // Handle the request.
start := time.Now() start := time.Now()
response, action, err := t.handleRequest(&Request{buffer[:n], addr.IP}) response, action, err := t.handleRequest(
Request{buffer[:n], addr.IP},
ResponseWriter{t.socket, addr},
)
log.Printf("Handled UDP Request: %s, %s, %s\n", response, action, err) log.Printf("Handled UDP Request: %s, %s, %s\n", response, action, err)
recordResponseDuration(action, err, time.Since(start))
// Record to the duration of time used to respond to the request.
var errString string
if err != nil {
errString = err.Error()
}
recordResponseDuration(action, errString, time.Since(start))
}() }()
} }
} }
@ -150,19 +159,19 @@ type Request struct {
// ResponseWriter implements the ability to respond to a Request via the // ResponseWriter implements the ability to respond to a Request via the
// io.Writer interface. // io.Writer interface.
type ResponseWriter struct { type ResponseWriter struct {
socket net.UDPConn socket *net.UDPConn
addr net.UDPAddr addr *net.UDPAddr
} }
// Write implements the io.Writer interface for a ResponseWriter. // Write implements the io.Writer interface for a ResponseWriter.
func (w *ResponseWriter) Write(b []byte) (int, error) { func (w ResponseWriter) Write(b []byte) (int, error) {
w.socket.WriteToUDP(b, w.addr) w.socket.WriteToUDP(b, w.addr)
return len(b), nil return len(b), nil
} }
// handleRequest parses and responds to a UDP Request. // handleRequest parses and responds to a UDP Request.
func (t *Tracker) handleRequest(r *Request, w *ResponseWriter) (response []byte, actionName string, err error) { func (t *Tracker) handleRequest(r Request, w ResponseWriter) (response []byte, actionName string, err error) {
if len(r.packet) < 16 { if len(r.Packet) < 16 {
// Malformed, no client packets are less than 16 bytes. // Malformed, no client packets are less than 16 bytes.
// We explicitly return nothing in case this is a DoS attempt. // We explicitly return nothing in case this is a DoS attempt.
err = errMalformedPacket err = errMalformedPacket
@ -170,13 +179,13 @@ func (t *Tracker) handleRequest(r *Request, w *ResponseWriter) (response []byte,
} }
// Parse the headers of the UDP packet. // Parse the headers of the UDP packet.
connID := r.packet[0:8] connID := r.Packet[0:8]
actionID := binary.BigEndian.Uint32(r.packet[8:12]) actionID := binary.BigEndian.Uint32(r.Packet[8:12])
txID := r.packet[12:16] txID := r.Packet[12:16]
// If this isn't requesting a new connection ID and the connection ID is // If this isn't requesting a new connection ID and the connection ID is
// invalid, then fail. // invalid, then fail.
if actionID != connectActionID && !ValidConnectionID(connID, r.IP, time.Now(), t.PrivateKey) { if actionID != connectActionID && !ValidConnectionID(connID, r.IP, time.Now(), t.MaxClockSkew, t.PrivateKey) {
err = errBadConnectionID err = errBadConnectionID
WriteError(w, txID, err) WriteError(w, txID, err)
return return
@ -206,7 +215,7 @@ func (t *Tracker) handleRequest(r *Request, w *ResponseWriter) (response []byte,
} }
var resp *bittorrent.AnnounceResponse var resp *bittorrent.AnnounceResponse
resp, err = t.HandleAnnounce(req) resp, err = t.HandleAnnounce(context.TODO(), req)
if err != nil { if err != nil {
WriteError(w, txID, err) WriteError(w, txID, err)
return return
@ -231,8 +240,7 @@ func (t *Tracker) handleRequest(r *Request, w *ResponseWriter) (response []byte,
} }
var resp *bittorrent.ScrapeResponse var resp *bittorrent.ScrapeResponse
ctx := context.TODO() resp, err = t.HandleScrape(context.TODO(), req)
resp, err = t.HandleScrape(ctx, req)
if err != nil { if err != nil {
WriteError(w, txID, err) WriteError(w, txID, err)
return return

View file

@ -18,58 +18,59 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"time" "time"
"github.com/jzelinskie/trakr/bittorrent" "github.com/jzelinskie/trakr/bittorrent"
) )
// WriteError writes the failure reason as a null-terminated string. // WriteError writes the failure reason as a null-terminated string.
func WriteError(writer io.Writer, txID []byte, err error) { func WriteError(w io.Writer, txID []byte, err error) {
// If the client wasn't at fault, acknowledge it. // If the client wasn't at fault, acknowledge it.
if _, ok := err.(bittorrent.ClientError); !ok { if _, ok := err.(bittorrent.ClientError); !ok {
err = fmt.Errorf("internal error occurred: %s", err.Error()) err = fmt.Errorf("internal error occurred: %s", err.Error())
} }
var buf bytes.Buffer var buf bytes.Buffer
writeHeader(buf, txID, errorActionID) writeHeader(&buf, txID, errorActionID)
buf.WriteString(err.Error()) buf.WriteString(err.Error())
buf.WriteRune('\000') buf.WriteRune('\000')
writer.Write(buf.Bytes()) w.Write(buf.Bytes())
} }
// WriteAnnounce encodes an announce response according to BEP 15. // WriteAnnounce encodes an announce response according to BEP 15.
func WriteAnnounce(respBuf *bytes.Buffer, txID []byte, resp *bittorrent.AnnounceResponse) { func WriteAnnounce(w io.Writer, txID []byte, resp *bittorrent.AnnounceResponse) {
writeHeader(respBuf, txID, announceActionID) writeHeader(w, txID, announceActionID)
binary.Write(respBuf, binary.BigEndian, uint32(resp.Interval/time.Second)) binary.Write(w, binary.BigEndian, uint32(resp.Interval/time.Second))
binary.Write(respBuf, binary.BigEndian, uint32(resp.Incomplete)) binary.Write(w, binary.BigEndian, uint32(resp.Incomplete))
binary.Write(respBuf, binary.BigEndian, uint32(resp.Complete)) binary.Write(w, binary.BigEndian, uint32(resp.Complete))
for _, peer := range resp.IPv4Peers { for _, peer := range resp.IPv4Peers {
respBuf.Write(peer.IP) w.Write(peer.IP)
binary.Write(respBuf, binary.BigEndian, peer.Port) binary.Write(w, binary.BigEndian, peer.Port)
} }
} }
// WriteScrape encodes a scrape response according to BEP 15. // WriteScrape encodes a scrape response according to BEP 15.
func WriteScrape(respBuf *bytes.Buffer, txID []byte, resp *bittorrent.ScrapeResponse) { func WriteScrape(w io.Writer, txID []byte, resp *bittorrent.ScrapeResponse) {
writeHeader(respBuf, txID, scrapeActionID) writeHeader(w, txID, scrapeActionID)
for _, scrape := range resp.Files { for _, scrape := range resp.Files {
binary.Write(respBuf, binary.BigEndian, scrape.Complete) binary.Write(w, binary.BigEndian, scrape.Complete)
binary.Write(respBuf, binary.BigEndian, scrape.Snatches) binary.Write(w, binary.BigEndian, scrape.Snatches)
binary.Write(respBuf, binary.BigEndian, scrape.Incomplete) binary.Write(w, binary.BigEndian, scrape.Incomplete)
} }
} }
// WriteConnectionID encodes a new connection response according to BEP 15. // WriteConnectionID encodes a new connection response according to BEP 15.
func WriteConnectionID(respBuf *bytes.Buffer, txID, connID []byte) { func WriteConnectionID(w io.Writer, txID, connID []byte) {
writeHeader(respBuf, txID, connectActionID) writeHeader(w, txID, connectActionID)
respBuf.Write(connID) w.Write(connID)
} }
// writeHeader writes the action and transaction ID to the provided response // writeHeader writes the action and transaction ID to the provided response
// buffer. // buffer.
func writeHeader(respBuf *bytes.Buffer, txID []byte, action uint32) { func writeHeader(w io.Writer, txID []byte, action uint32) {
binary.Write(respBuf, binary.BigEndian, action) binary.Write(w, binary.BigEndian, action)
respBuf.Write(txID) w.Write(txID)
} }

View file

View file

@ -0,0 +1,65 @@
package main
import (
"errors"
"log"
"os"
"os/signal"
"runtime/pprof"
"syscall"
"github.com/spf13/cobra"
"github.com/jzelinskie/trakr"
)
func main() {
var configFilePath string
var cpuProfilePath string
var rootCmd = &cobra.Command{
Use: "trakr",
Short: "BitTorrent Tracker",
Long: "A customizible, multi-protocol BitTorrent Tracker",
Run: func(cmd *cobra.Command, args []string) {
if err := func() error {
if cpuProfilePath != "" {
log.Println("enabled CPU profiling to " + cpuProfilePath)
f, err := os.Create(cpuProfilePath)
if err != nil {
return err
}
pprof.StartCPUProfile(f)
defer pprof.StopCPUProfile()
}
mt, err := trakr.MultiTrackerFromFile(configFilePath)
if err != nil {
return errors.New("failed to read config: " + err.Error())
}
go func() {
shutdown := make(chan os.Signal)
signal.Notify(shutdown, syscall.SIGINT, syscall.SIGTERM)
<-shutdown
mt.Stop()
}()
if err := mt.ListenAndServe(); err != nil {
return errors.New("failed to cleanly shutdown: " + err.Error())
}
return nil
}(); err != nil {
log.Fatal(err)
}
},
}
rootCmd.Flags().StringVar(&configFilePath, "config", "/etc/trakr.yaml", "location of configuration file (defaults to /etc/trakr.yaml)")
rootCmd.Flags().StringVarP(&cpuProfilePath, "cpuprofile", "", "", "location to save a CPU profile")
if err := rootCmd.Execute(); err != nil {
log.Fatal(err)
}
}

View file

@ -14,13 +14,19 @@
package trakr package trakr
import "github.com/jzelinskie/trakr/bittorrent" import (
"fmt"
"golang.org/x/net/context"
"github.com/jzelinskie/trakr/bittorrent"
)
// Hook abstracts the concept of anything that needs to interact with a // Hook abstracts the concept of anything that needs to interact with a
// BitTorrent client's request and response to a BitTorrent tracker. // BitTorrent client's request and response to a BitTorrent tracker.
type Hook interface { type Hook interface {
HandleAnnounce(context.Context, bittorrent.AnnounceRequest, bittorrent.AnnounceResponse) error HandleAnnounce(context.Context, *bittorrent.AnnounceRequest, *bittorrent.AnnounceResponse) error
HandleScrape(context.Context, bittorrent.ScrapeRequest, bittorrent.ScrapeResponse) error HandleScrape(context.Context, *bittorrent.ScrapeRequest, *bittorrent.ScrapeResponse) error
} }
// HookConstructor is a function used to create a new instance of a Hook. // HookConstructor is a function used to create a new instance of a Hook.
@ -36,7 +42,7 @@ func RegisterPreHook(name string, con HookConstructor) {
if con == nil { if con == nil {
panic("trakr: could not register nil HookConstructor") panic("trakr: could not register nil HookConstructor")
} }
if _, dup := constructors[name]; dup { if _, dup := preHooks[name]; dup {
panic("trakr: could not register duplicate HookConstructor: " + name) panic("trakr: could not register duplicate HookConstructor: " + name)
} }
preHooks[name] = con preHooks[name] = con
@ -61,7 +67,7 @@ func RegisterPostHook(name string, con HookConstructor) {
if con == nil { if con == nil {
panic("trakr: could not register nil HookConstructor") panic("trakr: could not register nil HookConstructor")
} }
if _, dup := constructors[name]; dup { if _, dup := postHooks[name]; dup {
panic("trakr: could not register duplicate HookConstructor: " + name) panic("trakr: could not register duplicate HookConstructor: " + name)
} }
preHooks[name] = con preHooks[name] = con

View file

@ -4,11 +4,11 @@ import (
"sync" "sync"
) )
// AlreadyStopped is a closed error channel to be used by StopperFuncs when // AlreadyStopped is a closed error channel to be used by Funcs when
// an element was already stopped. // an element was already stopped.
var AlreadyStopped <-chan error var AlreadyStopped <-chan error
// AlreadyStoppedFunc is a StopperFunc that returns AlreadyStopped. // AlreadyStoppedFunc is a Func that returns AlreadyStopped.
var AlreadyStoppedFunc = func() <-chan error { return AlreadyStopped } var AlreadyStoppedFunc = func() <-chan error { return AlreadyStopped }
func init() { func init() {
@ -30,7 +30,7 @@ type Stopper interface {
// StopGroup is a group that can be stopped. // StopGroup is a group that can be stopped.
type StopGroup struct { type StopGroup struct {
stoppables []StopperFunc stoppables []Func
stoppablesLock sync.Mutex stoppablesLock sync.Mutex
} }
@ -40,7 +40,7 @@ type Func func() <-chan error
// NewStopGroup creates a new StopGroup. // NewStopGroup creates a new StopGroup.
func NewStopGroup() *StopGroup { func NewStopGroup() *StopGroup {
return &StopGroup{ return &StopGroup{
stoppables: make([]StopperFunc, 0), stoppables: make([]Func, 0),
} }
} }
@ -53,9 +53,9 @@ func (cg *StopGroup) Add(toAdd Stopper) {
cg.stoppables = append(cg.stoppables, toAdd.Stop) cg.stoppables = append(cg.stoppables, toAdd.Stop)
} }
// AddFunc adds a StopperFunc to the StopGroup. // AddFunc adds a Func to the StopGroup.
// On the next call to Stop(), the StopperFunc will be called. // On the next call to Stop(), the Func will be called.
func (cg *StopGroup) AddFunc(toAddFunc StopperFunc) { func (cg *StopGroup) AddFunc(toAddFunc Func) {
cg.stoppablesLock.Lock() cg.stoppablesLock.Lock()
defer cg.stoppablesLock.Unlock() defer cg.stoppablesLock.Unlock()

View file

@ -10,7 +10,7 @@ import (
// ErrResourceDoesNotExist is the error returned by all delete methods in the // ErrResourceDoesNotExist is the error returned by all delete methods in the
// store if the requested resource does not exist. // store if the requested resource does not exist.
var ErrResourceDoesNotExist = bittorrent.ClientError(errors.New("resource does not exist")) var ErrResourceDoesNotExist = bittorrent.ClientError("resource does not exist")
// PeerStore is an interface that abstracts the interactions of storing and // PeerStore is an interface that abstracts the interactions of storing and
// manipulating Peers such that it can be implemented for various data stores. // manipulating Peers such that it can be implemented for various data stores.
@ -68,7 +68,7 @@ type PeerStore interface {
// PeerStore. // PeerStore.
type PeerStoreConstructor func(interface{}) (PeerStore, error) type PeerStoreConstructor func(interface{}) (PeerStore, error)
var peerStores = make(map[string]PeerStoreConstructors) var peerStores = make(map[string]PeerStoreConstructor)
// RegisterPeerStore makes a PeerStoreConstructor available by the provided // RegisterPeerStore makes a PeerStoreConstructor available by the provided
// name. // name.
@ -80,7 +80,7 @@ func RegisterPeerStore(name string, con PeerStoreConstructor) {
panic("trakr: could not register nil PeerStoreConstructor") panic("trakr: could not register nil PeerStoreConstructor")
} }
if _, dup := peerStore[name]; dup { if _, dup := peerStores[name]; dup {
panic("trakr: could not register duplicate PeerStoreConstructor: " + name) panic("trakr: could not register duplicate PeerStoreConstructor: " + name)
} }
@ -88,7 +88,7 @@ func RegisterPeerStore(name string, con PeerStoreConstructor) {
} }
// NewPeerStore creates an instance of the given PeerStore by name. // NewPeerStore creates an instance of the given PeerStore by name.
func NewPeerStore(name, config interface{}) (PeerStore, error) { func NewPeerStore(name string, config interface{}) (PeerStore, error) {
con, ok := peerStores[name] con, ok := peerStores[name]
if !ok { if !ok {
return nil, fmt.Errorf("trakr: unknown PeerStore %q (forgotten import?)", name) return nil, fmt.Errorf("trakr: unknown PeerStore %q (forgotten import?)", name)

View file

@ -17,22 +17,93 @@
// has been delievered to a BitTorrent client. // has been delievered to a BitTorrent client.
package trakr package trakr
import (
"errors"
"io"
"io/ioutil"
"os"
"time"
"github.com/jzelinskie/trakr/bittorrent/http"
"github.com/jzelinskie/trakr/bittorrent/udp"
"gopkg.in/yaml.v2"
)
// GenericConfig is a block of configuration who's structure is unknown.
type GenericConfig struct {
name string `yaml:"name"`
config interface{} `yaml:"config"`
}
// MultiTracker is a multi-protocol, customizable BitTorrent Tracker. // MultiTracker is a multi-protocol, customizable BitTorrent Tracker.
type MultiTracker struct { type MultiTracker struct {
HTTPConfig http.Config AnnounceInterval time.Duration `yaml:"announce_interval"`
UDPConfig udp.Config GCInterval time.Duration `yaml:"gc_interval"`
AnnounceInterval time.Duration GCExpiration time.Duration `yaml:"gc_expiration"`
GCInterval time.Duration HTTPConfig http.Config `yaml:"http"`
GCExpiration time.Duration UDPConfig udp.Config `yaml:"udp"`
PreHooks []Hook PeerStoreConfig []GenericConfig `yaml:"storage"`
PostHooks []Hook PreHooks []GenericConfig `yaml:"prehooks"`
PostHooks []GenericConfig `yaml:"posthooks"`
peerStore PeerStore
httpTracker http.Tracker httpTracker http.Tracker
udpTracker udp.Tracker udpTracker udp.Tracker
} }
// decodeConfigFile unmarshals an io.Reader into a new MultiTracker.
func decodeConfigFile(r io.Reader) (*MultiTracker, error) {
contents, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
cfgFile := struct {
mt MultiTracker `yaml:"trakr"`
}{}
err = yaml.Unmarshal(contents, cfgFile)
if err != nil {
return nil, err
}
return &cfgFile.mt, nil
}
// MultiTrackerFromFile returns a new MultiTracker given the path to a YAML
// configuration file.
//
// It supports relative and absolute paths and environment variables.
func MultiTrackerFromFile(path string) (*MultiTracker, error) {
if path == "" {
return nil, errors.New("no config path specified")
}
f, err := os.Open(os.ExpandEnv(path))
if err != nil {
return nil, err
}
defer f.Close()
cfg, err := decodeConfigFile(f)
if err != nil {
return nil, err
}
return cfg, nil
}
// Stop provides a thread-safe way to shutdown a currently running
// MultiTracker.
func (t *MultiTracker) Stop() {
}
// ListenAndServe listens on the protocols and addresses specified in the // ListenAndServe listens on the protocols and addresses specified in the
// HTTPConfig and UDPConfig then blocks serving BitTorrent requests until // HTTPConfig and UDPConfig then blocks serving BitTorrent requests until
// t.Stop() is called or an error is returned. // t.Stop() is called or an error is returned.
func (t *MultiTracker) ListenAndServe() error { func (t *MultiTracker) ListenAndServe() error {
// Build an TrackerFuncs from the PreHooks and PostHooks.
// Create a PeerStore instance.
// Create a HTTP Tracker instance.
// Create a UDP Tracker instance.
return nil
} }