s/Server/Tracker

This commit is contained in:
Jimmy Zelinskie 2016-08-04 00:18:58 -04:00
parent dc25c8cab2
commit 0d054414ab
3 changed files with 67 additions and 66 deletions

View file

@ -161,15 +161,16 @@ type ClientError string
// Error implements the error interface for ClientError. // Error implements the error interface for ClientError.
func (c ClientError) Error() string { return string(c) } func (c ClientError) Error() string { return string(c) }
// Server represents an implementation of the BitTorrent tracker protocol. // Tracker represents an implementation of the BitTorrent tracker protocol.
type Server interface { type Tracker interface {
ListenAndServe() error ListenAndServe() error
Stop() Stop()
} }
// ServerFuncs are the collection of protocol-agnostic functions used to handle // TrackerFuncs is the collection of callback functions provided to a Tracker
// requests in a Server. // to (1) generate a response from a parsed request, and (2) observe anything
type ServerFuncs struct { // after the response has been delivered to the client.
type TrackerFuncs struct {
HandleAnnounce AnnounceHandler HandleAnnounce AnnounceHandler
HandleScrape ScrapeHandler HandleScrape ScrapeHandler
AfterAnnounce AnnounceCallback AfterAnnounce AnnounceCallback

View file

@ -23,41 +23,41 @@ type Config struct {
RealIPHeader string RealIPHeader string
} }
type Server struct { type Tracker struct {
grace *graceful.Server grace *graceful.Server
bittorrent.ServerFuncs bittorrent.TrackerFuncs
Config Config
} }
func NewServer(funcs bittorrent.ServerFuncs, cfg Config) { func NewTracker(funcs bittorrent.TrackerFuncs, cfg Config) {
return &Server{ return &Server{
ServerFuncs: funcs, TrackerFuncs: funcs,
Config: cfg, Config: cfg,
} }
} }
func (s *Server) Stop() { func (t *Tracker) Stop() {
s.grace.Stop(s.grace.Timeout) t.grace.Stop(t.grace.Timeout)
<-s.grace.StopChan() <-t.grace.StopChan()
} }
func (s *Server) handler() { func (t *Tracker) handler() {
router := httprouter.New() router := httprouter.New()
router.GET("/announce", s.announceRoute) router.GET("/announce", t.announceRoute)
router.GET("/scrape", s.scrapeRoute) router.GET("/scrape", t.scrapeRoute)
return server return server
} }
func (s *Server) ListenAndServe() error { func (t *Tracker) ListenAndServe() error {
s.grace = &graceful.Server{ t.grace = &graceful.Server{
Server: &http.Server{ Server: &http.Server{
Addr: s.Addr, Addr: t.Addr,
Handler: s.handler(), Handler: t.handler(),
ReadTimeout: s.ReadTimeout, ReadTimeout: t.ReadTimeout,
WriteTimeout: s.WriteTimeout, WriteTimeout: t.WriteTimeout,
}, },
Timeout: s.RequestTimeout, Timeout: t.RequestTimeout,
NoSignalHandling: true, NoSignalHandling: true,
ConnState: func(conn net.Conn, state http.ConnState) { ConnState: func(conn net.Conn, state http.ConnState) {
switch state { switch state {
@ -78,23 +78,23 @@ func (s *Server) ListenAndServe() error {
} }
}, },
} }
s.grace.SetKeepAlivesEnabled(false) t.grace.SetKeepAlivesEnabled(false)
if err := s.grace.ListenAndServe(); err != nil { if err := t.grace.ListenAndServe(); err != nil {
if opErr, ok := err.(*net.OpError); !ok || (ok && opErr.Op != "accept") { if opErr, ok := err.(*net.OpError); !ok || (ok && opErr.Op != "accept") {
panic("http: failed to gracefully run HTTP server: " + err.Error()) panic("http: failed to gracefully run HTTP server: " + err.Error())
} }
} }
} }
func (s *Server) announceRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { func (t *Tracker) announceRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
req, err := ParseAnnounce(r, s.RealIPHeader, s.AllowIPSpoofing) req, err := ParseAnnounce(r, t.RealIPHeader, t.AllowIPSpoofing)
if err != nil { if err != nil {
WriteError(w, err) WriteError(w, err)
return return
} }
resp, err := s.HandleAnnounce(req) resp, err := t.HandleAnnounce(req)
if err != nil { if err != nil {
WriteError(w, err) WriteError(w, err)
return return
@ -106,19 +106,19 @@ func (s *Server) announceRoute(w http.ResponseWriter, r *http.Request, _ httprou
return return
} }
if s.AfterAnnounce != nil { if t.AfterAnnounce != nil {
s.AfterAnnounce(req, resp) t.AfterAnnounce(req, resp)
} }
} }
func (s *Server) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { func (t *Tracker) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
req, err := ParseScrape(r) req, err := ParseScrape(r)
if err != nil { if err != nil {
WriteError(w, err) WriteError(w, err)
return return
} }
resp, err := s.HandleScrape(req) resp, err := t.HandleScrape(req)
if err != nil { if err != nil {
WriteError(w, err) WriteError(w, err)
return return
@ -130,7 +130,7 @@ func (s *Server) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httproute
return return
} }
if s.AfterScrape != nil { if t.AfterScrape != nil {
s.AfterScrape(req, resp) t.AfterScrape(req, resp)
} }
} }

View file

@ -38,56 +38,56 @@ type Config struct {
AllowIPSpoofing bool AllowIPSpoofing bool
} }
type Server struct { type Tracker struct {
sock *net.UDPConn sock *net.UDPConn
closing chan struct{} closing chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
bittorrent.ServerFuncs bittorrent.TrackerFuncs
Config Config
} }
func NewServer(funcs bittorrent.ServerFuncs, cfg Config) { func NewTracker(funcs bittorrent.TrackerFuncs, cfg Config) {
return &Server{ return &Tracker{
closing: make(chan struct{}), closing: make(chan struct{}),
ServerFuncs: funcs, TrackerFuncs: funcs,
Config: cfg, Config: cfg,
} }
} }
func (s *udpServer) Stop() { func (t *Tracker) Stop() {
close(s.closing) close(t.closing)
s.sock.SetReadDeadline(time.Now()) t.sock.SetReadDeadline(time.Now())
s.wg.Wait() t.wg.Wait()
} }
func (s *Server) ListenAndServe() error { func (t *Tracker) ListenAndServe() error {
udpAddr, err := net.ResolveUDPAddr("udp", s.Addr) udpAddr, err := net.ResolveUDPAddr("udp", t.Addr)
if err != nil { if err != nil {
return err return err
} }
s.sock, err = net.ListenUDP("udp", udpAddr) t.sock, err = net.ListenUDP("udp", udpAddr)
if err != nil { if err != nil {
return err return err
} }
defer s.sock.Close() defer t.sock.Close()
pool := bytepool.New(256, 2048) pool := bytepool.New(256, 2048)
for { for {
// Check to see if we need to shutdown. // Check to see if we need to shutdown.
select { select {
case <-s.closing: case <-t.closing:
s.wg.Wait() t.wg.Wait()
return nil return nil
default: default:
} }
// Read a UDP packet into a reusable buffer. // Read a UDP packet into a reusable buffer.
buffer := pool.Get() buffer := pool.Get()
s.sock.SetReadDeadline(time.Now().Add(time.Second)) t.sock.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := s.sock.ReadFromUDP(buffer) n, addr, err := t.sock.ReadFromUDP(buffer)
if err != nil { if err != nil {
pool.Put(buffer) pool.Put(buffer)
if netErr, ok := err.(net.Error); ok && netErr.Temporary() { if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
@ -105,13 +105,13 @@ func (s *Server) ListenAndServe() error {
log.Println("Got UDP packet") log.Println("Got UDP packet")
start := time.Now() start := time.Now()
s.wg.Add(1) t.wg.Add(1)
go func(start time.Time) { go func(start time.Time) {
defer s.wg.Done() defer t.wg.Done()
defer pool.Put(buffer) defer pool.Put(buffer)
// Handle the response. // Handle the response.
response, action, err := s.handlePacket(buffer[:n], addr) response, action, err := t.handlePacket(buffer[:n], addr)
log.Printf("Handled UDP packet: %s, %s, %s\n", response, action, err) log.Printf("Handled UDP packet: %s, %s, %s\n", response, action, err)
// Record to Prometheus the time in milliseconds to receive, handle, and // Record to Prometheus the time in milliseconds to receive, handle, and
@ -141,7 +141,7 @@ func (w *ResponseWriter) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, actionName string, err error) { func (t *Tracker) handlePacket(r *Request, w *ResponseWriter) (response []byte, actionName string, err error) {
if len(r.packet) < 16 { if len(r.packet) < 16 {
// Malformed, no client packets are less than 16 bytes. // Malformed, no client packets are less than 16 bytes.
// We explicitly return nothing in case this is a DoS attempt. // We explicitly return nothing in case this is a DoS attempt.
@ -156,7 +156,7 @@ func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, a
// If this isn't requesting a new connection ID and the connection ID is // If this isn't requesting a new connection ID and the connection ID is
// invalid, then fail. // invalid, then fail.
if actionID != connectActionID && !ValidConnectionID(connID, r.IP, time.Now(), s.PrivateKey) { if actionID != connectActionID && !ValidConnectionID(connID, r.IP, time.Now(), t.PrivateKey) {
err = errBadConnectionID err = errBadConnectionID
WriteError(w, txID, err) WriteError(w, txID, err)
return return
@ -172,21 +172,21 @@ func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, a
return return
} }
WriteConnectionID(w, txID, NewConnectionID(r.IP, time.Now(), s.PrivateKey)) WriteConnectionID(w, txID, NewConnectionID(r.IP, time.Now(), t.PrivateKey))
return return
case announceActionID: case announceActionID:
actionName = "announce" actionName = "announce"
var req *bittorrent.AnnounceRequest var req *bittorrent.AnnounceRequest
req, err = ParseAnnounce(r, s.AllowIPSpoofing) req, err = ParseAnnounce(r, t.AllowIPSpoofing)
if err != nil { if err != nil {
WriteError(w, txID, err) WriteError(w, txID, err)
return return
} }
var resp *bittorrent.AnnounceResponse var resp *bittorrent.AnnounceResponse
resp, err = s.HandleAnnounce(req) resp, err = t.HandleAnnounce(req)
if err != nil { if err != nil {
WriteError(w, txID, err) WriteError(w, txID, err)
return return
@ -194,8 +194,8 @@ func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, a
WriteAnnounce(w, txID, resp) WriteAnnounce(w, txID, resp)
if s.AfterAnnounce != nil { if t.AfterAnnounce != nil {
s.AfterAnnounce(req, resp) t.AfterAnnounce(req, resp)
} }
return return
@ -212,7 +212,7 @@ func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, a
var resp *bittorrent.ScrapeResponse var resp *bittorrent.ScrapeResponse
ctx := context.TODO() ctx := context.TODO()
resp, err = s.HandleScrape(ctx, req) resp, err = t.HandleScrape(ctx, req)
if err != nil { if err != nil {
WriteError(w, txID, err) WriteError(w, txID, err)
return return
@ -220,8 +220,8 @@ func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, a
WriteScrape(w, txID, resp) WriteScrape(w, txID, resp)
if s.AfterScrape != nil { if t.AfterScrape != nil {
s.AfterScrape(req, resp) t.AfterScrape(req, resp)
} }
return return