tracker/frontends/udp/frontend.go

269 lines
6.4 KiB
Go
Raw Normal View History

// Copyright 2016 Jimmy Zelinskie
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
2016-08-04 20:08:26 +02:00
// Package udp implements a BitTorrent tracker via the UDP protocol as
// described in BEP 15.
package udp
import (
"bytes"
"encoding/binary"
2016-08-05 07:47:04 +02:00
"log"
"net"
2016-08-05 07:47:04 +02:00
"sync"
"time"
2016-08-05 07:47:04 +02:00
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/context"
2016-08-07 04:41:33 +02:00
"github.com/jzelinskie/trakr/backend"
"github.com/jzelinskie/trakr/bittorrent"
2016-08-07 04:41:33 +02:00
"github.com/jzelinskie/trakr/frontends/udp/bytepool"
)
2016-08-05 09:35:17 +02:00
func init() {
	// Register the response-latency histogram with the default Prometheus
	// registry. MustRegister panics on an invalid or duplicate registration,
	// which is acceptable during package initialisation.
	//
	// No placeholder sample is recorded here: a synthetic observation is
	// indistinguishable from a real request and would permanently skew the
	// histogram's count, sum and buckets.
	prometheus.MustRegister(promResponseDurationMilliseconds)
}
// promResponseDurationMilliseconds is a histogram of UDP response latency in
// milliseconds, partitioned by the handled action and the error message (if
// any). The buckets grow exponentially from 9.375 ms, doubling ten times.
var promResponseDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{
	Buckets: prometheus.ExponentialBuckets(9.375, 2, 10),
	Help:    "The duration of time it takes to receive and write a response to an API request",
	Name:    "trakr_udp_response_duration_milliseconds",
}, []string{"action", "error"})
2016-08-04 20:48:32 +02:00
// recordResponseDuration records the duration of time to respond to a UDP
// Request in milliseconds .
2016-08-05 07:47:04 +02:00
func recordResponseDuration(action string, err error, duration time.Duration) {
var errString string
if err != nil {
errString = err.Error()
}
2016-08-04 20:48:32 +02:00
promResponseDurationMilliseconds.
2016-08-05 07:47:04 +02:00
WithLabelValues(action, errString).
2016-08-04 20:48:32 +02:00
Observe(float64(duration.Nanoseconds()) / float64(time.Millisecond))
}
// Config represents all of the configurable options for a UDP BitTorrent
// Tracker.
type Config struct {
	// Addr is the UDP network address the frontend listens on.
	// PrivateKey is the key passed to NewConnectionID and ValidConnectionID
	// when issuing and checking client connection IDs.
	Addr, PrivateKey string

	// MaxClockSkew is passed to ValidConnectionID; presumably the tolerated
	// age/clock drift of a connection ID — confirm against that function.
	MaxClockSkew time.Duration

	// AllowIPSpoofing is passed to ParseAnnounce; presumably it lets clients
	// supply their own IP in announces instead of the packet source — confirm.
	AllowIPSpoofing bool
}
2016-08-07 04:41:33 +02:00
// Frontend holds the state of a UDP BitTorrent Frontend.
type Frontend struct {
2016-08-05 07:47:04 +02:00
socket *net.UDPConn
closing chan struct{}
wg sync.WaitGroup
2016-08-07 04:41:33 +02:00
backend.TrackerFuncs
Config
}
2016-08-07 04:41:33 +02:00
// NewFrontend allocates a new instance of a Frontend.
func NewFrontend(funcs backend.TrackerFuncs, cfg Config) *Frontend {
return &Frontend{
2016-08-04 06:18:58 +02:00
closing: make(chan struct{}),
TrackerFuncs: funcs,
Config: cfg,
}
}
2016-08-07 04:41:33 +02:00
// Stop provides a thread-safe way to shutdown a currently running Frontend.
func (t *Frontend) Stop() {
2016-08-04 06:18:58 +02:00
close(t.closing)
2016-08-05 07:47:04 +02:00
t.socket.SetReadDeadline(time.Now())
2016-08-04 06:18:58 +02:00
t.wg.Wait()
}
2016-08-04 20:08:26 +02:00
// ListenAndServe listens on the UDP network address t.Addr and blocks serving
// BitTorrent requests until t.Stop() is called or an error is returned.
2016-08-07 04:41:33 +02:00
func (t *Frontend) ListenAndServe() error {
2016-08-04 06:18:58 +02:00
udpAddr, err := net.ResolveUDPAddr("udp", t.Addr)
if err != nil {
return err
}
2016-08-05 07:47:04 +02:00
t.socket, err = net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
2016-08-05 07:47:04 +02:00
defer t.socket.Close()
pool := bytepool.New(256, 2048)
for {
// Check to see if we need to shutdown.
select {
2016-08-04 06:18:58 +02:00
case <-t.closing:
t.wg.Wait()
return nil
default:
}
// Read a UDP packet into a reusable buffer.
buffer := pool.Get()
2016-08-05 07:47:04 +02:00
t.socket.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := t.socket.ReadFromUDP(buffer)
if err != nil {
pool.Put(buffer)
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
// A temporary failure is not fatal; just pretend it never happened.
continue
}
return err
}
// We got nothin'
if n == 0 {
pool.Put(buffer)
continue
}
2016-08-04 20:48:32 +02:00
log.Println("Got UDP Request")
2016-08-04 06:18:58 +02:00
t.wg.Add(1)
2016-08-05 07:47:04 +02:00
go func() {
2016-08-04 06:18:58 +02:00
defer t.wg.Done()
defer pool.Put(buffer)
2016-08-04 20:48:32 +02:00
// Handle the request.
start := time.Now()
2016-08-05 07:47:04 +02:00
response, action, err := t.handleRequest(
Request{buffer[:n], addr.IP},
ResponseWriter{t.socket, addr},
)
2016-08-04 20:48:32 +02:00
log.Printf("Handled UDP Request: %s, %s, %s\n", response, action, err)
2016-08-05 07:47:04 +02:00
recordResponseDuration(action, err, time.Since(start))
2016-08-04 20:48:32 +02:00
}()
}
}
2016-08-04 20:08:26 +02:00
// Request represents a UDP payload received by a Tracker.
type Request struct {
	// Packet is the datagram payload. In ListenAndServe it aliases a pooled
	// buffer that is recycled once the request has been handled, so copy it
	// if it must outlive the handler.
	Packet []byte
	// IP is the source IP the datagram was received from.
	IP net.IP
}
2016-08-04 20:08:26 +02:00
// ResponseWriter implements the ability to respond to a Request via the
// io.Writer interface. Each Write sends one datagram from socket to addr.
type ResponseWriter struct {
	socket *net.UDPConn
	addr   *net.UDPAddr
}

// Write implements the io.Writer interface for a ResponseWriter by sending b
// to the client as a single UDP datagram.
//
// It returns the byte count and error reported by the socket. A failed send
// is surfaced to the caller rather than being claimed as a full write, since
// io.Writer requires a non-nil error whenever n < len(b).
func (w ResponseWriter) Write(b []byte) (int, error) {
	return w.socket.WriteToUDP(b, w.addr)
}
2016-08-04 20:08:26 +02:00
// handleRequest parses and responds to a UDP Request.
2016-08-07 04:41:33 +02:00
func (t *Frontend) handleRequest(r Request, w ResponseWriter) (response []byte, actionName string, err error) {
2016-08-05 07:47:04 +02:00
if len(r.Packet) < 16 {
// Malformed, no client packets are less than 16 bytes.
// We explicitly return nothing in case this is a DoS attempt.
err = errMalformedPacket
return
}
// Parse the headers of the UDP packet.
2016-08-05 07:47:04 +02:00
connID := r.Packet[0:8]
actionID := binary.BigEndian.Uint32(r.Packet[8:12])
txID := r.Packet[12:16]
// If this isn't requesting a new connection ID and the connection ID is
// invalid, then fail.
2016-08-05 07:47:04 +02:00
if actionID != connectActionID && !ValidConnectionID(connID, r.IP, time.Now(), t.MaxClockSkew, t.PrivateKey) {
err = errBadConnectionID
WriteError(w, txID, err)
return
}
// Handle the requested action.
switch actionID {
case connectActionID:
actionName = "connect"
if !bytes.Equal(connID, initialConnectionID) {
err = errMalformedPacket
return
}
2016-08-04 06:18:58 +02:00
WriteConnectionID(w, txID, NewConnectionID(r.IP, time.Now(), t.PrivateKey))
return
case announceActionID:
actionName = "announce"
var req *bittorrent.AnnounceRequest
2016-08-04 06:18:58 +02:00
req, err = ParseAnnounce(r, t.AllowIPSpoofing)
if err != nil {
WriteError(w, txID, err)
return
}
var resp *bittorrent.AnnounceResponse
2016-08-05 07:47:04 +02:00
resp, err = t.HandleAnnounce(context.TODO(), req)
if err != nil {
WriteError(w, txID, err)
return
}
WriteAnnounce(w, txID, resp)
2016-08-04 06:18:58 +02:00
if t.AfterAnnounce != nil {
2016-08-04 20:48:32 +02:00
go t.AfterAnnounce(req, resp)
}
return
case scrapeActionID:
actionName = "scrape"
var req *bittorrent.ScrapeRequest
req, err = ParseScrape(r)
if err != nil {
WriteError(w, txID, err)
return
}
var resp *bittorrent.ScrapeResponse
2016-08-05 07:47:04 +02:00
resp, err = t.HandleScrape(context.TODO(), req)
if err != nil {
WriteError(w, txID, err)
return
}
WriteScrape(w, txID, resp)
2016-08-04 06:18:58 +02:00
if t.AfterScrape != nil {
2016-08-04 20:48:32 +02:00
go t.AfterScrape(req, resp)
}
return
default:
err = errUnknownAction
WriteError(w, txID, err)
return
}
}