udp: Add basic UDP tests
This commit is contained in:
parent
105edf21f1
commit
7512f50731
6 changed files with 218 additions and 3 deletions
|
@ -76,7 +76,7 @@ func (s *Server) stats(w http.ResponseWriter, r *http.Request, p httprouter.Para
|
|||
func handleTorrentError(err error, w *Writer) (int, error) {
|
||||
if err == nil {
|
||||
return http.StatusOK, nil
|
||||
} else if _, ok := err.(models.ClientError); ok {
|
||||
} else if models.IsPublicError(err) {
|
||||
w.WriteError(err)
|
||||
stats.RecordEvent(stats.ClientError)
|
||||
return http.StatusOK, nil
|
||||
|
|
|
@ -45,6 +45,13 @@ func (e ClientError) Error() string { return string(e) }
|
|||
func (e NotFoundError) Error() string { return string(e) }
|
||||
func (e ProtocolError) Error() string { return string(e) }
|
||||
|
||||
func IsPublicError(err error) bool {
|
||||
_, cl := err.(ClientError)
|
||||
_, nf := err.(NotFoundError)
|
||||
_, pc := err.(ProtocolError)
|
||||
return cl || nf || pc
|
||||
}
|
||||
|
||||
type PeerList []Peer
|
||||
type PeerKey string
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ func handleTorrentError(err error, w *Writer) {
|
|||
return
|
||||
}
|
||||
|
||||
if _, ok := err.(models.ClientError); ok {
|
||||
if models.IsPublicError(err) {
|
||||
w.WriteError(err)
|
||||
stats.RecordEvent(stats.ClientError)
|
||||
}
|
||||
|
|
72
udp/scrape_test.go
Normal file
72
udp/scrape_test.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
// Copyright 2015 The Chihaya Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/chihaya/chihaya/config"
|
||||
)
|
||||
|
||||
func requestScrape(sock *net.UDPConn, connID []byte, hashes []string) ([]byte, error) {
|
||||
txID := makeTransactionID()
|
||||
request := []byte{}
|
||||
|
||||
request = append(request, connID...)
|
||||
request = append(request, scrapeAction...)
|
||||
request = append(request, txID...)
|
||||
|
||||
for _, hash := range hashes {
|
||||
request = append(request, []byte(hash)...)
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := sendRequest(sock, request, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !bytes.Equal(response[4:8], txID) {
|
||||
return nil, fmt.Errorf("transaction ID mismatch")
|
||||
}
|
||||
|
||||
return response[:n], nil
|
||||
}
|
||||
|
||||
func TestScrapeEmpty(t *testing.T) {
|
||||
srv, done, err := setupTracker(&config.DefaultConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, sock, err := setupSocket()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
connID, err := requestConnectionID(sock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scrape, err := requestScrape(sock, connID, []string{"aaaaaaaaaaaaaaaaaaaa"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(scrape[:4], errorAction) {
|
||||
t.Error("expected error response")
|
||||
}
|
||||
|
||||
if string(scrape[8:]) != "torrent does not exist\000" {
|
||||
t.Error("expected torrent to not exist")
|
||||
}
|
||||
|
||||
srv.Stop()
|
||||
<-done
|
||||
}
|
|
@ -71,10 +71,11 @@ func (s *Server) serve(listenAddr string) error {
|
|||
|
||||
go func() {
|
||||
response, action := s.handlePacket(buffer[:n], addr)
|
||||
pool.GiveSlice(buffer)
|
||||
|
||||
if response != nil {
|
||||
sock.WriteToUDP(response, addr)
|
||||
}
|
||||
pool.GiveSlice(buffer)
|
||||
|
||||
if glog.V(2) {
|
||||
duration := time.Since(start)
|
||||
|
|
135
udp/udp_test.go
Normal file
135
udp/udp_test.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
// Copyright 2015 The Chihaya Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/chihaya/chihaya/config"
|
||||
"github.com/chihaya/chihaya/stats"
|
||||
"github.com/chihaya/chihaya/tracker"
|
||||
|
||||
_ "github.com/chihaya/chihaya/backend/noop"
|
||||
)
|
||||
|
||||
var testPort = "34137"
|
||||
var connectAction = []byte{0, 0, 0, 0}
|
||||
var announceAction = []byte{0, 0, 0, 1}
|
||||
var scrapeAction = []byte{0, 0, 0, 2}
|
||||
var errorAction = []byte{0, 0, 0, 3}
|
||||
|
||||
func init() {
|
||||
stats.DefaultStats = stats.New(config.StatsConfig{})
|
||||
}
|
||||
|
||||
func setupTracker(cfg *config.Config) (*Server, chan struct{}, error) {
|
||||
tkr, err := tracker.New(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
srv := NewServer(cfg, tkr)
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
if err := srv.serve(":" + testPort); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
<-srv.booting
|
||||
return srv, done, nil
|
||||
}
|
||||
|
||||
func setupSocket() (*net.UDPAddr, *net.UDPConn, error) {
|
||||
srvAddr, err := net.ResolveUDPAddr("udp", "localhost:"+testPort)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sock, err := net.DialUDP("udp", nil, srvAddr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return srvAddr, sock, err
|
||||
}
|
||||
|
||||
func makeTransactionID() []byte {
|
||||
out := make([]byte, 4)
|
||||
rand.Read(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func sendRequest(sock *net.UDPConn, request, response []byte) (int, error) {
|
||||
if _, err := sock.Write(request); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
sock.SetReadDeadline(time.Now().Add(time.Second))
|
||||
n, err := sock.Read(response)
|
||||
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return 0, fmt.Errorf("no response from tracker: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func requestConnectionID(sock *net.UDPConn) ([]byte, error) {
|
||||
txID := makeTransactionID()
|
||||
request := []byte{}
|
||||
|
||||
request = append(request, initialConnectionID...)
|
||||
request = append(request, connectAction...)
|
||||
request = append(request, txID...)
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := sendRequest(sock, request, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n != 16 {
|
||||
return nil, fmt.Errorf("packet length mismatch: %d != 16", n)
|
||||
}
|
||||
|
||||
if !bytes.Equal(response[4:8], txID) {
|
||||
return nil, fmt.Errorf("transaction ID mismatch")
|
||||
}
|
||||
|
||||
if !bytes.Equal(response[0:4], connectAction) {
|
||||
return nil, fmt.Errorf("action mismatch")
|
||||
}
|
||||
|
||||
return response[8:16], nil
|
||||
}
|
||||
|
||||
func TestRequestConnectionID(t *testing.T) {
|
||||
srv, done, err := setupTracker(&config.DefaultConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, sock, err := setupSocket()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err = requestConnectionID(sock); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv.Stop()
|
||||
<-done
|
||||
}
|
Loading…
Reference in a new issue