diff --git a/http/routes.go b/http/routes.go index 719c92d..3a8f1c7 100644 --- a/http/routes.go +++ b/http/routes.go @@ -76,7 +76,7 @@ func (s *Server) stats(w http.ResponseWriter, r *http.Request, p httprouter.Para func handleTorrentError(err error, w *Writer) (int, error) { if err == nil { return http.StatusOK, nil - } else if _, ok := err.(models.ClientError); ok { + } else if models.IsPublicError(err) { w.WriteError(err) stats.RecordEvent(stats.ClientError) return http.StatusOK, nil diff --git a/tracker/models/models.go b/tracker/models/models.go index 1fc1f2a..7ff5cd4 100644 --- a/tracker/models/models.go +++ b/tracker/models/models.go @@ -45,6 +45,13 @@ func (e ClientError) Error() string { return string(e) } func (e NotFoundError) Error() string { return string(e) } func (e ProtocolError) Error() string { return string(e) } +func IsPublicError(err error) bool { + _, cl := err.(ClientError) + _, nf := err.(NotFoundError) + _, pc := err.(ProtocolError) + return cl || nf || pc +} + type PeerList []Peer type PeerKey string diff --git a/udp/protocol.go b/udp/protocol.go index 15c0ef8..bc5aff0 100644 --- a/udp/protocol.go +++ b/udp/protocol.go @@ -32,7 +32,7 @@ func handleTorrentError(err error, w *Writer) { return } - if _, ok := err.(models.ClientError); ok { + if models.IsPublicError(err) { w.WriteError(err) stats.RecordEvent(stats.ClientError) } diff --git a/udp/scrape_test.go b/udp/scrape_test.go new file mode 100644 index 0000000..79482e0 --- /dev/null +++ b/udp/scrape_test.go @@ -0,0 +1,72 @@ +// Copyright 2015 The Chihaya Authors. All rights reserved. +// Use of this source code is governed by the BSD 2-Clause license, +// which can be found in the LICENSE file. + +package udp + +import ( + "bytes" + "fmt" + "net" + "testing" + + "github.com/chihaya/chihaya/config" +) + +func requestScrape(sock *net.UDPConn, connID []byte, hashes []string) ([]byte, error) { + txID := makeTransactionID() + request := []byte{} + + request = append(request, connID...) + request = append(request, scrapeAction...) + request = append(request, txID...) + + for _, hash := range hashes { + request = append(request, []byte(hash)...) + } + + response := make([]byte, 1024) + n, err := sendRequest(sock, request, response) + if err != nil { + return nil, err + } + + if !bytes.Equal(response[4:8], txID) { + return nil, fmt.Errorf("transaction ID mismatch") + } + + return response[:n], nil +} + +func TestScrapeEmpty(t *testing.T) { + srv, done, err := setupTracker(&config.DefaultConfig) + if err != nil { + t.Fatal(err) + } + + _, sock, err := setupSocket() + if err != nil { + t.Fatal(err) + } + + connID, err := requestConnectionID(sock) + if err != nil { + t.Fatal(err) + } + + scrape, err := requestScrape(sock, connID, []string{"aaaaaaaaaaaaaaaaaaaa"}) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(scrape[:4], errorAction) { + t.Error("expected error response") + } + + if string(scrape[8:]) != "torrent does not exist\000" { + t.Error("expected torrent to not exist") + } + + srv.Stop() + <-done +} diff --git a/udp/udp.go b/udp/udp.go index f5ebe66..8609eaa 100644 --- a/udp/udp.go +++ b/udp/udp.go @@ -71,10 +71,11 @@ func (s *Server) serve(listenAddr string) error { go func() { response, action := s.handlePacket(buffer[:n], addr) + pool.GiveSlice(buffer) + if response != nil { sock.WriteToUDP(response, addr) } - pool.GiveSlice(buffer) if glog.V(2) { duration := time.Since(start) diff --git a/udp/udp_test.go b/udp/udp_test.go new file mode 100644 index 0000000..b9d57e3 --- /dev/null +++ b/udp/udp_test.go @@ -0,0 +1,135 @@ +// Copyright 2015 The Chihaya Authors. All rights reserved. +// Use of this source code is governed by the BSD 2-Clause license, +// which can be found in the LICENSE file. + +package udp + +import ( + "bytes" + "crypto/rand" + "fmt" + "net" + "testing" + "time" + + "github.com/chihaya/chihaya/config" + "github.com/chihaya/chihaya/stats" + "github.com/chihaya/chihaya/tracker" + + _ "github.com/chihaya/chihaya/backend/noop" +) + +var testPort = "34137" +var connectAction = []byte{0, 0, 0, 0} +var announceAction = []byte{0, 0, 0, 1} +var scrapeAction = []byte{0, 0, 0, 2} +var errorAction = []byte{0, 0, 0, 3} + +func init() { + stats.DefaultStats = stats.New(config.StatsConfig{}) +} + +func setupTracker(cfg *config.Config) (*Server, chan struct{}, error) { + tkr, err := tracker.New(cfg) + if err != nil { + return nil, nil, err + } + + srv := NewServer(cfg, tkr) + done := make(chan struct{}) + + go func() { + if err := srv.serve(":" + testPort); err != nil { + panic(err) + } + close(done) + }() + + <-srv.booting + return srv, done, nil +} + +func setupSocket() (*net.UDPAddr, *net.UDPConn, error) { + srvAddr, err := net.ResolveUDPAddr("udp", "localhost:"+testPort) + if err != nil { + return nil, nil, err + } + + sock, err := net.DialUDP("udp", nil, srvAddr) + if err != nil { + return nil, nil, err + } + + return srvAddr, sock, err +} + +func makeTransactionID() []byte { + out := make([]byte, 4) + rand.Read(out) + return out +} + +func sendRequest(sock *net.UDPConn, request, response []byte) (int, error) { + if _, err := sock.Write(request); err != nil { + return 0, err + } + + sock.SetReadDeadline(time.Now().Add(time.Second)) + n, err := sock.Read(response) + + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return 0, fmt.Errorf("no response from tracker: %s", err) + } + } + + return n, err +} + +func requestConnectionID(sock *net.UDPConn) ([]byte, error) { + txID := makeTransactionID() + request := []byte{} + + request = append(request, initialConnectionID...) + request = append(request, connectAction...) + request = append(request, txID...) + + response := make([]byte, 1024) + n, err := sendRequest(sock, request, response) + if err != nil { + return nil, err + } + + if n != 16 { + return nil, fmt.Errorf("packet length mismatch: %d != 16", n) + } + + if !bytes.Equal(response[4:8], txID) { + return nil, fmt.Errorf("transaction ID mismatch") + } + + if !bytes.Equal(response[0:4], connectAction) { + return nil, fmt.Errorf("action mismatch") + } + + return response[8:16], nil +} + +func TestRequestConnectionID(t *testing.T) { + srv, done, err := setupTracker(&config.DefaultConfig) + if err != nil { + t.Fatal(err) + } + + _, sock, err := setupSocket() + if err != nil { + t.Fatal(err) + } + + if _, err = requestConnectionID(sock); err != nil { + t.Fatal(err) + } + + srv.Stop() + <-done +}