s/Server/Tracker

This commit is contained in:
Jimmy Zelinskie 2016-08-04 00:18:58 -04:00
parent dc25c8cab2
commit 0d054414ab
3 changed files with 67 additions and 66 deletions

View file

@ -161,15 +161,16 @@ type ClientError string
// Error implements the error interface for ClientError.
func (c ClientError) Error() string { return string(c) }
// Server represents an implementation of the BitTorrent tracker protocol.
type Server interface {
// Tracker represents an implementation of the BitTorrent tracker protocol.
type Tracker interface {
ListenAndServe() error
Stop()
}
// ServerFuncs are the collection of protocol-agnostic functions used to handle
// requests in a Server.
type ServerFuncs struct {
// TrackerFuncs is the collection of callback functions provided to a Tracker
// to (1) generate a response from a parsed request, and (2) observe anything
// after the response has been delivered to the client.
type TrackerFuncs struct {
HandleAnnounce AnnounceHandler
HandleScrape ScrapeHandler
AfterAnnounce AnnounceCallback

View file

@ -23,41 +23,41 @@ type Config struct {
RealIPHeader string
}
type Server struct {
type Tracker struct {
grace *graceful.Server
bittorrent.ServerFuncs
bittorrent.TrackerFuncs
Config
}
func NewServer(funcs bittorrent.ServerFuncs, cfg Config) {
func NewTracker(funcs bittorrent.TrackerFuncs, cfg Config) {
return &Server{
ServerFuncs: funcs,
Config: cfg,
TrackerFuncs: funcs,
Config: cfg,
}
}
func (s *Server) Stop() {
s.grace.Stop(s.grace.Timeout)
<-s.grace.StopChan()
func (t *Tracker) Stop() {
t.grace.Stop(t.grace.Timeout)
<-t.grace.StopChan()
}
func (s *Server) handler() {
func (t *Tracker) handler() {
router := httprouter.New()
router.GET("/announce", s.announceRoute)
router.GET("/scrape", s.scrapeRoute)
router.GET("/announce", t.announceRoute)
router.GET("/scrape", t.scrapeRoute)
return server
}
func (s *Server) ListenAndServe() error {
s.grace = &graceful.Server{
func (t *Tracker) ListenAndServe() error {
t.grace = &graceful.Server{
Server: &http.Server{
Addr: s.Addr,
Handler: s.handler(),
ReadTimeout: s.ReadTimeout,
WriteTimeout: s.WriteTimeout,
Addr: t.Addr,
Handler: t.handler(),
ReadTimeout: t.ReadTimeout,
WriteTimeout: t.WriteTimeout,
},
Timeout: s.RequestTimeout,
Timeout: t.RequestTimeout,
NoSignalHandling: true,
ConnState: func(conn net.Conn, state http.ConnState) {
switch state {
@ -78,23 +78,23 @@ func (s *Server) ListenAndServe() error {
}
},
}
s.grace.SetKeepAlivesEnabled(false)
t.grace.SetKeepAlivesEnabled(false)
if err := s.grace.ListenAndServe(); err != nil {
if err := t.grace.ListenAndServe(); err != nil {
if opErr, ok := err.(*net.OpError); !ok || (ok && opErr.Op != "accept") {
panic("http: failed to gracefully run HTTP server: " + err.Error())
}
}
}
func (s *Server) announceRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
req, err := ParseAnnounce(r, s.RealIPHeader, s.AllowIPSpoofing)
func (t *Tracker) announceRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
req, err := ParseAnnounce(r, t.RealIPHeader, t.AllowIPSpoofing)
if err != nil {
WriteError(w, err)
return
}
resp, err := s.HandleAnnounce(req)
resp, err := t.HandleAnnounce(req)
if err != nil {
WriteError(w, err)
return
@ -106,19 +106,19 @@ func (s *Server) announceRoute(w http.ResponseWriter, r *http.Request, _ httprou
return
}
if s.AfterAnnounce != nil {
s.AfterAnnounce(req, resp)
if t.AfterAnnounce != nil {
t.AfterAnnounce(req, resp)
}
}
func (s *Server) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (t *Tracker) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
req, err := ParseScrape(r)
if err != nil {
WriteError(w, err)
return
}
resp, err := s.HandleScrape(req)
resp, err := t.HandleScrape(req)
if err != nil {
WriteError(w, err)
return
@ -130,7 +130,7 @@ func (s *Server) scrapeRoute(w http.ResponseWriter, r *http.Request, _ httproute
return
}
if s.AfterScrape != nil {
s.AfterScrape(req, resp)
if t.AfterScrape != nil {
t.AfterScrape(req, resp)
}
}

View file

@ -38,56 +38,56 @@ type Config struct {
AllowIPSpoofing bool
}
type Server struct {
type Tracker struct {
sock *net.UDPConn
closing chan struct{}
wg sync.WaitGroup
bittorrent.ServerFuncs
bittorrent.TrackerFuncs
Config
}
func NewServer(funcs bittorrent.ServerFuncs, cfg Config) {
return &Server{
closing: make(chan struct{}),
ServerFuncs: funcs,
Config: cfg,
func NewTracker(funcs bittorrent.TrackerFuncs, cfg Config) {
return &Tracker{
closing: make(chan struct{}),
TrackerFuncs: funcs,
Config: cfg,
}
}
func (s *udpServer) Stop() {
close(s.closing)
s.sock.SetReadDeadline(time.Now())
s.wg.Wait()
func (t *Tracker) Stop() {
close(t.closing)
t.sock.SetReadDeadline(time.Now())
t.wg.Wait()
}
func (s *Server) ListenAndServe() error {
udpAddr, err := net.ResolveUDPAddr("udp", s.Addr)
func (t *Tracker) ListenAndServe() error {
udpAddr, err := net.ResolveUDPAddr("udp", t.Addr)
if err != nil {
return err
}
s.sock, err = net.ListenUDP("udp", udpAddr)
t.sock, err = net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
defer s.sock.Close()
defer t.sock.Close()
pool := bytepool.New(256, 2048)
for {
// Check to see if we need to shutdown.
select {
case <-s.closing:
s.wg.Wait()
case <-t.closing:
t.wg.Wait()
return nil
default:
}
// Read a UDP packet into a reusable buffer.
buffer := pool.Get()
s.sock.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := s.sock.ReadFromUDP(buffer)
t.sock.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := t.sock.ReadFromUDP(buffer)
if err != nil {
pool.Put(buffer)
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
@ -105,13 +105,13 @@ func (s *Server) ListenAndServe() error {
log.Println("Got UDP packet")
start := time.Now()
s.wg.Add(1)
t.wg.Add(1)
go func(start time.Time) {
defer s.wg.Done()
defer t.wg.Done()
defer pool.Put(buffer)
// Handle the response.
response, action, err := s.handlePacket(buffer[:n], addr)
response, action, err := t.handlePacket(buffer[:n], addr)
log.Printf("Handled UDP packet: %s, %s, %s\n", response, action, err)
// Record to Prometheus the time in milliseconds to receive, handle, and
@ -141,7 +141,7 @@ func (w *ResponseWriter) Write(b []byte) (int, error) {
return len(b), nil
}
func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, actionName string, err error) {
func (t *Tracker) handlePacket(r *Request, w *ResponseWriter) (response []byte, actionName string, err error) {
if len(r.packet) < 16 {
// Malformed, no client packets are less than 16 bytes.
// We explicitly return nothing in case this is a DoS attempt.
@ -156,7 +156,7 @@ func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, a
// If this isn't requesting a new connection ID and the connection ID is
// invalid, then fail.
if actionID != connectActionID && !ValidConnectionID(connID, r.IP, time.Now(), s.PrivateKey) {
if actionID != connectActionID && !ValidConnectionID(connID, r.IP, time.Now(), t.PrivateKey) {
err = errBadConnectionID
WriteError(w, txID, err)
return
@ -172,21 +172,21 @@ func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, a
return
}
WriteConnectionID(w, txID, NewConnectionID(r.IP, time.Now(), s.PrivateKey))
WriteConnectionID(w, txID, NewConnectionID(r.IP, time.Now(), t.PrivateKey))
return
case announceActionID:
actionName = "announce"
var req *bittorrent.AnnounceRequest
req, err = ParseAnnounce(r, s.AllowIPSpoofing)
req, err = ParseAnnounce(r, t.AllowIPSpoofing)
if err != nil {
WriteError(w, txID, err)
return
}
var resp *bittorrent.AnnounceResponse
resp, err = s.HandleAnnounce(req)
resp, err = t.HandleAnnounce(req)
if err != nil {
WriteError(w, txID, err)
return
@ -194,8 +194,8 @@ func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, a
WriteAnnounce(w, txID, resp)
if s.AfterAnnounce != nil {
s.AfterAnnounce(req, resp)
if t.AfterAnnounce != nil {
t.AfterAnnounce(req, resp)
}
return
@ -212,7 +212,7 @@ func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, a
var resp *bittorrent.ScrapeResponse
ctx := context.TODO()
resp, err = s.HandleScrape(ctx, req)
resp, err = t.HandleScrape(ctx, req)
if err != nil {
WriteError(w, txID, err)
return
@ -220,8 +220,8 @@ func (s *Server) handlePacket(r *Request, w *ResponseWriter) (response []byte, a
WriteScrape(w, txID, resp)
if s.AfterScrape != nil {
s.AfterScrape(req, resp)
if t.AfterScrape != nil {
t.AfterScrape(req, resp)
}
return