Merge branch 'udp' into develop

This commit is contained in:
Justin Li 2015-06-17 19:39:51 -04:00
commit a17474bb05
47 changed files with 1815 additions and 556 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
/config.json
/chihaya

21
Godeps/Godeps.json generated
View file

@ -1,10 +1,10 @@
{ {
"ImportPath": "github.com/chihaya/chihaya", "ImportPath": "github.com/chihaya/chihaya",
"GoVersion": "go1.4.1", "GoVersion": "go1.4.2",
"Deps": [ "Deps": [
{ {
"ImportPath": "github.com/chihaya/bencode", "ImportPath": "github.com/chihaya/bencode",
"Rev": "e60878f635e1a61315c413492e133dd39769b1d1" "Rev": "3c485a8d166ff6a79baba90c2c2da01c8348e930"
}, },
{ {
"ImportPath": "github.com/golang/glog", "ImportPath": "github.com/golang/glog",
@ -12,7 +12,11 @@
}, },
{ {
"ImportPath": "github.com/julienschmidt/httprouter", "ImportPath": "github.com/julienschmidt/httprouter",
"Rev": "00ce1c6a267162792c367acc43b1681a884e1872" "Rev": "8c199fb6259ffc1af525cc3ad52ee60ba8359669"
},
{
"ImportPath": "github.com/pushrax/bufferpool",
"Rev": "7d6e1653dee10a165d1f357f3a57bc8031e9621b"
}, },
{ {
"ImportPath": "github.com/pushrax/faststats", "ImportPath": "github.com/pushrax/faststats",
@ -23,16 +27,13 @@
"Rev": "86044f1c998d49053e13293029414ddb63f3a422" "Rev": "86044f1c998d49053e13293029414ddb63f3a422"
}, },
{ {
"ImportPath": "github.com/stretchr/graceful", "ImportPath": "github.com/tylerb/graceful",
"Rev": "8e780ba3fe3d3e7ab15fc52e3d60a996587181dc" "Comment": "v1-7-g0c01122",
}, "Rev": "0c011221e91b35f488b8818b00ca279929e9ed7d"
{
"ImportPath": "github.com/stretchr/pat/stop",
"Rev": "f7fe051f2b9bcaca162b38de4f93c9a8457160b9"
}, },
{ {
"ImportPath": "golang.org/x/net/netutil", "ImportPath": "golang.org/x/net/netutil",
"Rev": "c84eff7014eba178f68bd4c05b86780efe0fbf35" "Rev": "d175081df37eff8cda13f478bc11a0a65b39958b"
} }
] ]
} }

View file

@ -62,6 +62,12 @@ func marshal(w io.Writer, data interface{}) error {
case uint: case uint:
marshalUint(w, uint64(v)) marshalUint(w, uint64(v))
case int16:
marshalInt(w, int64(v))
case uint16:
marshalUint(w, uint64(v))
case int64: case int64:
marshalInt(w, v) marshalInt(w, v)

View file

@ -19,6 +19,8 @@ var marshalTests = []struct {
{uint(43), "i43e"}, {uint(43), "i43e"},
{int64(44), "i44e"}, {int64(44), "i44e"},
{uint64(45), "i45e"}, {uint64(45), "i45e"},
{int16(44), "i44e"},
{uint16(45), "i45e"},
{"example", "7:example"}, {"example", "7:example"},
{[]byte("example"), "7:example"}, {[]byte("example"), "7:example"},

View file

@ -1,3 +1,4 @@
sudo: false
language: go language: go
go: go:
- 1.1 - 1.1

View file

@ -1,29 +1,17 @@
# HttpRouter [![Build Status](https://travis-ci.org/julienschmidt/httprouter.png?branch=master)](https://travis-ci.org/julienschmidt/httprouter) [![GoDoc](http://godoc.org/github.com/julienschmidt/httprouter?status.png)](http://godoc.org/github.com/julienschmidt/httprouter) # HttpRouter [![Build Status](https://travis-ci.org/julienschmidt/httprouter.png?branch=master)](https://travis-ci.org/julienschmidt/httprouter) [![Coverage](http://gocover.io/_badge/github.com/julienschmidt/httprouter?0)](http://gocover.io/github.com/julienschmidt/httprouter) [![GoDoc](http://godoc.org/github.com/julienschmidt/httprouter?status.png)](http://godoc.org/github.com/julienschmidt/httprouter)
HttpRouter is a lightweight high performance HTTP request router HttpRouter is a lightweight high performance HTTP request router
(also called *multiplexer* or just *mux* for short) for [Go](http://golang.org/). (also called *multiplexer* or just *mux* for short) for [Go](http://golang.org/).
In contrast to the default mux of Go's net/http package, this router supports In contrast to the [default mux](http://golang.org/pkg/net/http/#ServeMux) of Go's net/http package, this router supports
variables in the routing pattern and matches against the request method. variables in the routing pattern and matches against the request method.
It also scales better. It also scales better.
The router is optimized for best performance and a small memory footprint. The router is optimized for high performance and a small memory footprint.
It scales well even with very long paths and a large number of routes. It scales well even with very long paths and a large number of routes.
A compressing dynamic trie (radix tree) structure is used for efficient matching. A compressing dynamic trie (radix tree) structure is used for efficient matching.
## Features ## Features
**Zero Garbage:** The matching and dispatching process generates zero bytes of
garbage. In fact, the only heap allocations that are made, is by building the
slice of the key-value pairs for path parameters. If the request path contains
no parameters, not a single heap allocation is necessary.
**Best Performance:** [Benchmarks speak for themselves](https://github.com/julienschmidt/go-http-routing-benchmark).
See below for technical details of the implementation.
**Parameters in your routing pattern:** Stop parsing the requested URL path,
just give the path segment a name and the router delivers the dynamic value to
you. Because of the design of the router, path parameters are very cheap.
**Only explicit matches:** With other routers, like [http.ServeMux](http://golang.org/pkg/net/http/#ServeMux), **Only explicit matches:** With other routers, like [http.ServeMux](http://golang.org/pkg/net/http/#ServeMux),
a requested URL path could match multiple patterns. Therefore they have some a requested URL path could match multiple patterns. Therefore they have some
awkward pattern priority rules, like *longest match* or *first registered, awkward pattern priority rules, like *longest match* or *first registered,
@ -34,7 +22,7 @@ great for SEO and improves the user experience.
**Stop caring about trailing slashes:** Choose the URL style you like, the **Stop caring about trailing slashes:** Choose the URL style you like, the
router automatically redirects the client if a trailing slash is missing or if router automatically redirects the client if a trailing slash is missing or if
there is one extra. Of course it only does so, if the new path has a handler. there is one extra. Of course it only does so, if the new path has a handler.
If you don't like it, you can turn off this behavior. If you don't like it, you can [turn off this behavior](http://godoc.org/github.com/julienschmidt/httprouter#Router.RedirectTrailingSlash).
**Path auto-correction:** Besides detecting the missing or additional trailing **Path auto-correction:** Besides detecting the missing or additional trailing
slash at no extra cost, the router can also fix wrong cases and remove slash at no extra cost, the router can also fix wrong cases and remove
@ -43,11 +31,23 @@ Is [CAPTAIN CAPS LOCK](http://www.urbandictionary.com/define.php?term=Captain+Ca
HttpRouter can help him by making a case-insensitive look-up and redirecting him HttpRouter can help him by making a case-insensitive look-up and redirecting him
to the correct URL. to the correct URL.
**No more server crashes:** You can set a PanicHandler to deal with panics **Parameters in your routing pattern:** Stop parsing the requested URL path,
just give the path segment a name and the router delivers the dynamic value to
you. Because of the design of the router, path parameters are very cheap.
**Zero Garbage:** The matching and dispatching process generates zero bytes of
garbage. In fact, the only heap allocations that are made, is by building the
slice of the key-value pairs for path parameters. If the request path contains
no parameters, not a single heap allocation is necessary.
**Best Performance:** [Benchmarks speak for themselves](https://github.com/julienschmidt/go-http-routing-benchmark).
See below for technical details of the implementation.
**No more server crashes:** You can set a [Panic handler](http://godoc.org/github.com/julienschmidt/httprouter#Router.PanicHandler) to deal with panics
occurring during handling a HTTP request. The router then recovers and lets the occurring during handling a HTTP request. The router then recovers and lets the
PanicHandler log what happened and deliver a nice error page. PanicHandler log what happened and deliver a nice error page.
Of course you can also set a **custom NotFound handler** and **serve static files**. Of course you can also set **custom [NotFound](http://godoc.org/github.com/julienschmidt/httprouter#Router.NotFound) and [MethodNotAllowed](http://godoc.org/github.com/julienschmidt/httprouter#Router.MethodNotAllowed) handlers** and [**serve static files**](http://godoc.org/github.com/julienschmidt/httprouter#Router.ServeFiles).
## Usage ## Usage
This is just a quick introduction, view the [GoDoc](http://godoc.org/github.com/julienschmidt/httprouter) for details. This is just a quick introduction, view the [GoDoc](http://godoc.org/github.com/julienschmidt/httprouter) for details.
@ -189,7 +189,7 @@ for example the [Gorilla handlers](http://www.gorillatoolkit.org/pkg/handlers).
Or you could [just write your own](http://justinas.org/writing-http-middleware-in-go/), Or you could [just write your own](http://justinas.org/writing-http-middleware-in-go/),
it's very easy! it's very easy!
Alternatively, you could try [a framework building upon HttpRouter](#web-frameworks--co-based-on-httprouter). Alternatively, you could try [a web framework based on HttpRouter](#web-frameworks-based-on-httprouter).
### Multi-domain / Sub-domains ### Multi-domain / Sub-domains
Here is a quick example: Does your server serve multiple domains / hosts? Here is a quick example: Does your server serve multiple domains / hosts?
@ -256,7 +256,10 @@ func BasicAuth(h httprouter.Handle, user, pass []byte) httprouter.Handle {
payload, err := base64.StdEncoding.DecodeString(auth[len(basicAuthPrefix):]) payload, err := base64.StdEncoding.DecodeString(auth[len(basicAuthPrefix):])
if err == nil { if err == nil {
pair := bytes.SplitN(payload, []byte(":"), 2) pair := bytes.SplitN(payload, []byte(":"), 2)
if len(pair) == 2 && bytes.Equal(pair[0], user) && bytes.Equal(pair[1], pass) { if len(pair) == 2 &&
bytes.Equal(pair[0], user) &&
bytes.Equal(pair[1], pass) {
// Delegate request to the given handle // Delegate request to the given handle
h(w, r, ps) h(w, r, ps)
return return
@ -305,9 +308,16 @@ router.NotFound = http.FileServer(http.Dir("public")).ServeHTTP
But this approach sidesteps the strict core rules of this router to avoid routing problems. A cleaner approach is to use a distinct sub-path for serving files, like `/static/*filepath` or `/files/*filepath`. But this approach sidesteps the strict core rules of this router to avoid routing problems. A cleaner approach is to use a distinct sub-path for serving files, like `/static/*filepath` or `/files/*filepath`.
## Web Frameworks & Co based on HttpRouter ## Web Frameworks based on HttpRouter
If the HttpRouter is a bit too minimalistic for you, you might try one of the following more high-level 3rd-party web frameworks building upon the HttpRouter package: If the HttpRouter is a bit too minimalistic for you, you might try one of the following more high-level 3rd-party web frameworks building upon the HttpRouter package:
* [Ace](https://github.com/plimble/ace): Blazing fast Go Web Framework
* [api2go](https://github.com/univedo/api2go): A JSON API Implementation for Go
* [Gin](https://github.com/gin-gonic/gin): Features a martini-like API with much better performance * [Gin](https://github.com/gin-gonic/gin): Features a martini-like API with much better performance
* [Goat](https://github.com/bahlo/goat): A minimalistic REST API server in Go
* [Hikaru](https://github.com/najeira/hikaru): Supports standalone and Google AppEngine * [Hikaru](https://github.com/najeira/hikaru): Supports standalone and Google AppEngine
* [Hitch](https://github.com/nbio/hitch): Hitch ties httprouter, [httpcontext](https://github.com/nbio/httpcontext), and middleware up in a bow * [Hitch](https://github.com/nbio/hitch): Hitch ties httprouter, [httpcontext](https://github.com/nbio/httpcontext), and middleware up in a bow
* [kami](https://github.com/guregu/kami): A tiny web framework using x/net/context
* [Medeina](https://github.com/imdario/medeina): Inspired by Ruby's Roda and Cuba
* [Neko](https://github.com/rocwong/neko): A lightweight web application framework for Golang * [Neko](https://github.com/rocwong/neko): A lightweight web application framework for Golang
* [Roxanna](https://github.com/iamthemuffinman/Roxanna): An amalgamation of httprouter, better logging, and hot reload
* [siesta](https://github.com/VividCortex/siesta): Composable HTTP handlers with contexts

View file

@ -142,6 +142,11 @@ type Router struct {
// found. If it is not set, http.NotFound is used. // found. If it is not set, http.NotFound is used.
NotFound http.HandlerFunc NotFound http.HandlerFunc
// Configurable http.HandlerFunc which is called when a request
// cannot be routed and HandleMethodNotAllowed is true.
// If it is not set, http.Error with http.StatusMethodNotAllowed is used.
MethodNotAllowed http.HandlerFunc
// Function to handle panics recovered from http handlers. // Function to handle panics recovered from http handlers.
// It should be used to generate a error page and return the http error code // It should be used to generate a error page and return the http error code
// 500 (Internal Server Error). // 500 (Internal Server Error).
@ -173,6 +178,11 @@ func (r *Router) HEAD(path string, handle Handle) {
r.Handle("HEAD", path, handle) r.Handle("HEAD", path, handle)
} }
// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle)
func (r *Router) OPTIONS(path string, handle Handle) {
r.Handle("OPTIONS", path, handle)
}
// POST is a shortcut for router.Handle("POST", path, handle) // POST is a shortcut for router.Handle("POST", path, handle)
func (r *Router) POST(path string, handle Handle) { func (r *Router) POST(path string, handle Handle) {
r.Handle("POST", path, handle) r.Handle("POST", path, handle)
@ -203,7 +213,7 @@ func (r *Router) DELETE(path string, handle Handle) {
// communication with a proxy). // communication with a proxy).
func (r *Router) Handle(method, path string, handle Handle) { func (r *Router) Handle(method, path string, handle Handle) {
if path[0] != '/' { if path[0] != '/' {
panic("path must begin with '/'") panic("path must begin with '/' in path '" + path + "'")
} }
if r.trees == nil { if r.trees == nil {
@ -232,11 +242,7 @@ func (r *Router) Handler(method, path string, handler http.Handler) {
// HandlerFunc is an adapter which allows the usage of an http.HandlerFunc as a // HandlerFunc is an adapter which allows the usage of an http.HandlerFunc as a
// request handle. // request handle.
func (r *Router) HandlerFunc(method, path string, handler http.HandlerFunc) { func (r *Router) HandlerFunc(method, path string, handler http.HandlerFunc) {
r.Handle(method, path, r.Handler(method, path, handler)
func(w http.ResponseWriter, req *http.Request, _ Params) {
handler(w, req)
},
)
} }
// ServeFiles serves files from the given file system root. // ServeFiles serves files from the given file system root.
@ -251,7 +257,7 @@ func (r *Router) HandlerFunc(method, path string, handler http.HandlerFunc) {
// router.ServeFiles("/src/*filepath", http.Dir("/var/www")) // router.ServeFiles("/src/*filepath", http.Dir("/var/www"))
func (r *Router) ServeFiles(path string, root http.FileSystem) { func (r *Router) ServeFiles(path string, root http.FileSystem) {
if len(path) < 10 || path[len(path)-10:] != "/*filepath" { if len(path) < 10 || path[len(path)-10:] != "/*filepath" {
panic("path must end with /*filepath") panic("path must end with /*filepath in path '" + path + "'")
} }
fileServer := http.FileServer(root) fileServer := http.FileServer(root)
@ -335,10 +341,14 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
handle, _, _ := r.trees[method].getValue(req.URL.Path) handle, _, _ := r.trees[method].getValue(req.URL.Path)
if handle != nil { if handle != nil {
http.Error(w, if r.MethodNotAllowed != nil {
http.StatusText(http.StatusMethodNotAllowed), r.MethodNotAllowed(w, req)
http.StatusMethodNotAllowed, } else {
) http.Error(w,
http.StatusText(http.StatusMethodNotAllowed),
http.StatusMethodNotAllowed,
)
}
return return
} }
} }

View file

@ -76,7 +76,7 @@ func (h handlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
func TestRouterAPI(t *testing.T) { func TestRouterAPI(t *testing.T) {
var get, head, post, put, patch, delete, handler, handlerFunc bool var get, head, options, post, put, patch, delete, handler, handlerFunc bool
httpHandler := handlerStruct{&handler} httpHandler := handlerStruct{&handler}
@ -87,6 +87,9 @@ func TestRouterAPI(t *testing.T) {
router.HEAD("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) { router.HEAD("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
head = true head = true
}) })
router.OPTIONS("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
options = true
})
router.POST("/POST", func(w http.ResponseWriter, r *http.Request, _ Params) { router.POST("/POST", func(w http.ResponseWriter, r *http.Request, _ Params) {
post = true post = true
}) })
@ -118,6 +121,12 @@ func TestRouterAPI(t *testing.T) {
t.Error("routing HEAD failed") t.Error("routing HEAD failed")
} }
r, _ = http.NewRequest("OPTIONS", "/GET", nil)
router.ServeHTTP(w, r)
if !options {
t.Error("routing OPTIONS failed")
}
r, _ = http.NewRequest("POST", "/POST", nil) r, _ = http.NewRequest("POST", "/POST", nil)
router.ServeHTTP(w, r) router.ServeHTTP(w, r)
if !post { if !post {
@ -176,7 +185,21 @@ func TestRouterNotAllowed(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, r) router.ServeHTTP(w, r)
if !(w.Code == http.StatusMethodNotAllowed) { if !(w.Code == http.StatusMethodNotAllowed) {
t.Errorf("NotAllowed handling route %s failed: Code=%d, Header=%v", w.Code, w.Header()) t.Errorf("NotAllowed handling failed: Code=%d, Header=%v", w.Code, w.Header())
}
w = httptest.NewRecorder()
responseText := "custom method"
router.MethodNotAllowed = func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusTeapot)
w.Write([]byte(responseText))
}
router.ServeHTTP(w, r)
if got := w.Body.String(); !(got == responseText) {
t.Errorf("unexpected response got %q want %q", got, responseText)
}
if w.Code != http.StatusTeapot {
t.Errorf("unexpected response code %d want %d", w.Code, http.StatusTeapot)
} }
} }

View file

@ -43,41 +43,48 @@ type node struct {
wildChild bool wildChild bool
nType nodeType nType nodeType
maxParams uint8 maxParams uint8
indices []byte indices string
children []*node children []*node
handle Handle handle Handle
priority uint32 priority uint32
} }
// increments priority of the given child and reorders if necessary // increments priority of the given child and reorders if necessary
func (n *node) incrementChildPrio(i int) int { func (n *node) incrementChildPrio(pos int) int {
n.children[i].priority++ n.children[pos].priority++
prio := n.children[i].priority prio := n.children[pos].priority
// adjust position (move to front) // adjust position (move to front)
for j := i - 1; j >= 0 && n.children[j].priority < prio; j-- { newPos := pos
for newPos > 0 && n.children[newPos-1].priority < prio {
// swap node positions // swap node positions
tmpN := n.children[j] tmpN := n.children[newPos-1]
n.children[j] = n.children[i] n.children[newPos-1] = n.children[newPos]
n.children[i] = tmpN n.children[newPos] = tmpN
tmpI := n.indices[j]
n.indices[j] = n.indices[i]
n.indices[i] = tmpI
i-- newPos--
} }
return i
// build new index char string
if newPos != pos {
n.indices = n.indices[:newPos] + // unchanged prefix, might be empty
n.indices[pos:pos+1] + // the index char we move
n.indices[newPos:pos] + n.indices[pos+1:] // rest without char at 'pos'
}
return newPos
} }
// addRoute adds a node with the given handle to the path. // addRoute adds a node with the given handle to the path.
// Not concurrency-safe! // Not concurrency-safe!
func (n *node) addRoute(path string, handle Handle) { func (n *node) addRoute(path string, handle Handle) {
fullPath := path
n.priority++ n.priority++
numParams := countParams(path) numParams := countParams(path)
// non-empty tree // non-empty tree
if len(n.path) > 0 || len(n.children) > 0 { if len(n.path) > 0 || len(n.children) > 0 {
WALK: walk:
for { for {
// Update maxParams of the current node // Update maxParams of the current node
if numParams > n.maxParams { if numParams > n.maxParams {
@ -85,10 +92,12 @@ func (n *node) addRoute(path string, handle Handle) {
} }
// Find the longest common prefix. // Find the longest common prefix.
// This also implies that the commom prefix contains no ':' or '*' // This also implies that the common prefix contains no ':' or '*'
// since the existing key can't contain this chars. // since the existing key can't contain those chars.
i := 0 i := 0
for max := min(len(path), len(n.path)); i < max && path[i] == n.path[i]; i++ { max := min(len(path), len(n.path))
for i < max && path[i] == n.path[i] {
i++
} }
// Split edge // Split edge
@ -110,7 +119,8 @@ func (n *node) addRoute(path string, handle Handle) {
} }
n.children = []*node{&child} n.children = []*node{&child}
n.indices = []byte{n.path[i]} // []byte for proper unicode char conversion, see #65
n.indices = string([]byte{n.path[i]})
n.path = path[:i] n.path = path[:i]
n.handle = nil n.handle = nil
n.wildChild = false n.wildChild = false
@ -134,11 +144,13 @@ func (n *node) addRoute(path string, handle Handle) {
if len(path) >= len(n.path) && n.path == path[:len(n.path)] { if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
// check for longer wildcard, e.g. :name and :names // check for longer wildcard, e.g. :name and :names
if len(n.path) >= len(path) || path[len(n.path)] == '/' { if len(n.path) >= len(path) || path[len(n.path)] == '/' {
continue WALK continue walk
} }
} }
panic("conflict with wildcard route") panic("path segment '" + path +
"' conflicts with existing wildcard '" + n.path +
"' in path '" + fullPath + "'")
} }
c := path[0] c := path[0]
@ -147,21 +159,22 @@ func (n *node) addRoute(path string, handle Handle) {
if n.nType == param && c == '/' && len(n.children) == 1 { if n.nType == param && c == '/' && len(n.children) == 1 {
n = n.children[0] n = n.children[0]
n.priority++ n.priority++
continue WALK continue walk
} }
// Check if a child with the next path byte exists // Check if a child with the next path byte exists
for i, index := range n.indices { for i := 0; i < len(n.indices); i++ {
if c == index { if c == n.indices[i] {
i = n.incrementChildPrio(i) i = n.incrementChildPrio(i)
n = n.children[i] n = n.children[i]
continue WALK continue walk
} }
} }
// Otherwise insert it // Otherwise insert it
if c != ':' && c != '*' { if c != ':' && c != '*' {
n.indices = append(n.indices, c) // []byte for proper unicode char conversion, see #65
n.indices += string([]byte{c})
child := &node{ child := &node{
maxParams: numParams, maxParams: numParams,
} }
@ -169,24 +182,24 @@ func (n *node) addRoute(path string, handle Handle) {
n.incrementChildPrio(len(n.indices) - 1) n.incrementChildPrio(len(n.indices) - 1)
n = child n = child
} }
n.insertChild(numParams, path, handle) n.insertChild(numParams, path, fullPath, handle)
return return
} else if i == len(path) { // Make node a (in-path) leaf } else if i == len(path) { // Make node a (in-path) leaf
if n.handle != nil { if n.handle != nil {
panic("a Handle is already registered for this path") panic("a handle is already registered for path ''" + fullPath + "'")
} }
n.handle = handle n.handle = handle
} }
return return
} }
} else { // Empty tree } else { // Empty tree
n.insertChild(numParams, path, handle) n.insertChild(numParams, path, fullPath, handle)
} }
} }
func (n *node) insertChild(numParams uint8, path string, handle Handle) { func (n *node) insertChild(numParams uint8, path, fullPath string, handle Handle) {
var offset int var offset int // already handled bytes of the path
// find prefix until first wildcard (beginning with ':'' or '*'') // find prefix until first wildcard (beginning with ':'' or '*'')
for i, max := 0, len(path); numParams > 0; i++ { for i, max := 0, len(path); numParams > 0; i++ {
@ -195,20 +208,29 @@ func (n *node) insertChild(numParams uint8, path string, handle Handle) {
continue continue
} }
// Check if this Node existing children which would be
// unreachable if we insert the wildcard here
if len(n.children) > 0 {
panic("wildcard route conflicts with existing children")
}
// find wildcard end (either '/' or path end) // find wildcard end (either '/' or path end)
end := i + 1 end := i + 1
for end < max && path[end] != '/' { for end < max && path[end] != '/' {
end++ switch path[end] {
// the wildcard name must not contain ':' and '*'
case ':', '*':
panic("only one wildcard per path segment is allowed, has: '" +
path[i:] + "' in path '" + fullPath + "'")
default:
end++
}
} }
// check if this Node existing children which would be
// unreachable if we insert the wildcard here
if len(n.children) > 0 {
panic("wildcard route '" + path[i:end] +
"' conflicts with existing children in path '" + fullPath + "'")
}
// check if the wildcard has a name
if end-i < 2 { if end-i < 2 {
panic("wildcards must be named with a non-empty name") panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
} }
if c == ':' { // param if c == ':' { // param
@ -244,17 +266,17 @@ func (n *node) insertChild(numParams uint8, path string, handle Handle) {
} else { // catchAll } else { // catchAll
if end != max || numParams > 1 { if end != max || numParams > 1 {
panic("catch-all routes are only allowed at the end of the path") panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
} }
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
panic("catch-all conflicts with existing handle for the path segment root") panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
} }
// currently fixed width 1 for '/' // currently fixed width 1 for '/'
i-- i--
if path[i] != '/' { if path[i] != '/' {
panic("no / before catch-all") panic("no / before catch-all in path '" + fullPath + "'")
} }
n.path = path[offset:i] n.path = path[offset:i]
@ -266,7 +288,7 @@ func (n *node) insertChild(numParams uint8, path string, handle Handle) {
maxParams: 1, maxParams: 1,
} }
n.children = []*node{child} n.children = []*node{child}
n.indices = []byte{path[i]} n.indices = string(path[i])
n = child n = child
n.priority++ n.priority++
@ -305,8 +327,8 @@ walk: // Outer loop for walking the tree
// to walk down the tree // to walk down the tree
if !n.wildChild { if !n.wildChild {
c := path[0] c := path[0]
for i, index := range n.indices { for i := 0; i < len(n.indices); i++ {
if c == index { if c == n.indices[i] {
n = n.children[i] n = n.children[i]
continue walk continue walk
} }
@ -379,7 +401,7 @@ walk: // Outer loop for walking the tree
return return
default: default:
panic("Invalid node type") panic("invalid node type")
} }
} }
} else if path == n.path { } else if path == n.path {
@ -391,10 +413,10 @@ walk: // Outer loop for walking the tree
// No handle found. Check if a handle for this path + a // No handle found. Check if a handle for this path + a
// trailing slash exists for trailing slash recommendation // trailing slash exists for trailing slash recommendation
for i, index := range n.indices { for i := 0; i < len(n.indices); i++ {
if index == '/' { if n.indices[i] == '/' {
n = n.children[i] n = n.children[i]
tsr = (n.path == "/" && n.handle != nil) || tsr = (len(n.path) == 1 && n.handle != nil) ||
(n.nType == catchAll && n.children[0].handle != nil) (n.nType == catchAll && n.children[0].handle != nil)
return return
} }
@ -414,7 +436,7 @@ walk: // Outer loop for walking the tree
// Makes a case-insensitive lookup of the given path and tries to find a handler. // Makes a case-insensitive lookup of the given path and tries to find a handler.
// It can optionally also fix trailing slashes. // It can optionally also fix trailing slashes.
// It returns the case-corrected path and a bool indicating wether the lookup // It returns the case-corrected path and a bool indicating whether the lookup
// was successful. // was successful.
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) { func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) {
ciPath = make([]byte, 0, len(path)+1) // preallocate enough memory ciPath = make([]byte, 0, len(path)+1) // preallocate enough memory
@ -433,7 +455,7 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPa
for i, index := range n.indices { for i, index := range n.indices {
// must use recursive approach since both index and // must use recursive approach since both index and
// ToLower(index) could exist. We must check both. // ToLower(index) could exist. We must check both.
if r == unicode.ToLower(rune(index)) { if r == unicode.ToLower(index) {
out, found := n.children[i].findCaseInsensitivePath(path, fixTrailingSlash) out, found := n.children[i].findCaseInsensitivePath(path, fixTrailingSlash)
if found { if found {
return append(ciPath, out...), true return append(ciPath, out...), true
@ -445,53 +467,52 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPa
// without a trailing slash if a leaf exists for that path // without a trailing slash if a leaf exists for that path
found = (fixTrailingSlash && path == "/" && n.handle != nil) found = (fixTrailingSlash && path == "/" && n.handle != nil)
return return
}
} else { n = n.children[0]
n = n.children[0] switch n.nType {
case param:
// find param end (either '/' or path end)
k := 0
for k < len(path) && path[k] != '/' {
k++
}
switch n.nType { // add param value to case insensitive path
case param: ciPath = append(ciPath, path[:k]...)
// find param end (either '/' or path end)
k := 0
for k < len(path) && path[k] != '/' {
k++
}
// add param value to case insensitive path // we need to go deeper!
ciPath = append(ciPath, path[:k]...) if k < len(path) {
if len(n.children) > 0 {
// we need to go deeper! path = path[k:]
if k < len(path) {
if len(n.children) > 0 {
path = path[k:]
n = n.children[0]
continue
} else { // ... but we can't
if fixTrailingSlash && len(path) == k+1 {
return ciPath, true
}
return
}
}
if n.handle != nil {
return ciPath, true
} else if fixTrailingSlash && len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
// trailing slash exists
n = n.children[0] n = n.children[0]
if n.path == "/" && n.handle != nil { continue
return append(ciPath, '/'), true }
}
// ... but we can't
if fixTrailingSlash && len(path) == k+1 {
return ciPath, true
} }
return return
case catchAll:
return append(ciPath, path...), true
default:
panic("Invalid node type")
} }
if n.handle != nil {
return ciPath, true
} else if fixTrailingSlash && len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
// trailing slash exists
n = n.children[0]
if n.path == "/" && n.handle != nil {
return append(ciPath, '/'), true
}
}
return
case catchAll:
return append(ciPath, path...), true
default:
panic("invalid node type")
} }
} else { } else {
// We should have reached the node containing the handle. // We should have reached the node containing the handle.
@ -503,10 +524,10 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPa
// No handle found. // No handle found.
// Try to fix the path by adding a trailing slash // Try to fix the path by adding a trailing slash
if fixTrailingSlash { if fixTrailingSlash {
for i, index := range n.indices { for i := 0; i < len(n.indices); i++ {
if index == '/' { if n.indices[i] == '/' {
n = n.children[i] n = n.children[i]
if (n.path == "/" && n.handle != nil) || if (len(n.path) == 1 && n.handle != nil) ||
(n.nType == catchAll && n.children[0].handle != nil) { (n.nType == catchAll && n.children[0].handle != nil) {
return append(ciPath, '/'), true return append(ciPath, '/'), true
} }

View file

@ -125,6 +125,8 @@ func TestTreeAddAndGet(t *testing.T) {
"/doc/", "/doc/",
"/doc/go_faq.html", "/doc/go_faq.html",
"/doc/go1.html", "/doc/go1.html",
"/α",
"/β",
} }
for _, route := range routes { for _, route := range routes {
tree.addRoute(route, fakeHandler(route)) tree.addRoute(route, fakeHandler(route))
@ -142,6 +144,8 @@ func TestTreeAddAndGet(t *testing.T) {
{"/cona", true, "", nil}, // key mismatch {"/cona", true, "", nil}, // key mismatch
{"/no", true, "", nil}, // no matching child {"/no", true, "", nil}, // no matching child
{"/ab", false, "/ab", nil}, {"/ab", false, "/ab", nil},
{"/α", false, "/α", nil},
{"/β", false, "/β", nil},
}) })
checkPriorities(t, tree) checkPriorities(t, tree)
@ -339,6 +343,27 @@ func TestTreeCatchAllConflictRoot(t *testing.T) {
testRoutes(t, routes) testRoutes(t, routes)
} }
func TestTreeDoubleWildcard(t *testing.T) {
const panicMsg = "only one wildcard per path segment is allowed"
routes := [...]string{
"/:foo:bar",
"/:foo:bar/",
"/:foo*bar",
}
for _, route := range routes {
tree := &node{}
recv := catchPanic(func() {
tree.addRoute(route, nil)
})
if rs, ok := recv.(string); !ok || !strings.HasPrefix(rs, panicMsg) {
t.Fatalf(`"Expected panic "%s" for route '%s', got "%v"`, panicMsg, route, recv)
}
}
}
/*func TestTreeDuplicateWildcard(t *testing.T) { /*func TestTreeDuplicateWildcard(t *testing.T) {
tree := &node{} tree := &node{}
@ -559,6 +584,8 @@ func TestTreeFindCaseInsensitivePath(t *testing.T) {
} }
func TestTreeInvalidNodeType(t *testing.T) { func TestTreeInvalidNodeType(t *testing.T) {
const panicMsg = "invalid node type"
tree := &node{} tree := &node{}
tree.addRoute("/", fakeHandler("/")) tree.addRoute("/", fakeHandler("/"))
tree.addRoute("/:page", fakeHandler("/:page")) tree.addRoute("/:page", fakeHandler("/:page"))
@ -570,15 +597,15 @@ func TestTreeInvalidNodeType(t *testing.T) {
recv := catchPanic(func() { recv := catchPanic(func() {
tree.getValue("/test") tree.getValue("/test")
}) })
if rs, ok := recv.(string); !ok || rs != "Invalid node type" { if rs, ok := recv.(string); !ok || rs != panicMsg {
t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv) t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv)
} }
// case-insensitive lookup // case-insensitive lookup
recv = catchPanic(func() { recv = catchPanic(func() {
tree.findCaseInsensitivePath("/test", true) tree.findCaseInsensitivePath("/test", true)
}) })
if rs, ok := recv.(string); !ok || rs != "Invalid node type" { if rs, ok := recv.(string); !ok || rs != panicMsg {
t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv) t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv)
} }
} }

View file

@ -0,0 +1 @@
language: go

View file

@ -0,0 +1,4 @@
# This is the official list of Bufferpool authors for copyright purposes.
Jimmy Zelinskie <jimmyzelinskie@gmail.com>
Justin Li <jli@j-li.net>

View file

@ -0,0 +1,24 @@
Bufferpool is released under a BSD 2-Clause license, reproduced below.
Copyright (c) 2013, The Bufferpool Authors
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,9 @@
# bufferpool [![Build Status](https://secure.travis-ci.org/pushrax/bufferpool.png)](http://travis-ci.org/pushrax/bufferpool)
The bufferpool package implements a thread-safe pool of reusable, equally sized `byte.Buffer`s.
If you're allocating `byte.Buffer`s very frequently, you can use this to speed up your
program and take strain off the garbage collector.
## docs
[GoDoc](http://godoc.org/github.com/pushrax/bufferpool)

View file

@ -0,0 +1,69 @@
// Copyright 2013 The Bufferpool Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
// Package bufferpool implements a capacity-limited pool of reusable,
// equally-sized buffers.
package bufferpool
import (
"bytes"
"errors"
)
// A BufferPool is a capacity-limited pool of equally sized buffers.
type BufferPool struct {
bufferSize int
pool chan []byte
}
// New returns a newly allocated BufferPool with the given maximum pool size
// and buffer size.
func New(poolSize, bufferSize int) *BufferPool {
return &BufferPool{
bufferSize,
make(chan []byte, poolSize),
}
}
// Take is used to obtain a new zeroed buffer. This will allocate a new buffer
// if the pool was empty.
func (pool *BufferPool) Take() *bytes.Buffer {
return bytes.NewBuffer(pool.TakeSlice()[:0])
}
// TakeSlice is used to obtain a new slice. This will allocate a new slice
// if the pool was empty.
func (pool *BufferPool) TakeSlice() (slice []byte) {
select {
case slice = <-pool.pool:
default:
slice = make([]byte, pool.bufferSize)
}
return
}
// Give is used to attempt to return a buffer to the pool. It may not
// be added to the pool if it was already full.
func (pool *BufferPool) Give(buf *bytes.Buffer) error {
buf.Reset()
slice := buf.Bytes()
if cap(slice) < pool.bufferSize {
return errors.New("Gave an incorrectly sized buffer to the pool.")
}
return pool.GiveSlice(slice[:buf.Len()])
}
// GiveSlice is used to attempt to return a slice to the pool. It may not
// be added to the pool if it was already full.
func (pool *BufferPool) GiveSlice(slice []byte) error {
select {
case pool.pool <- slice:
// Everything went smoothly!
default:
return errors.New("Gave a buffer to a full pool.")
}
return nil
}

View file

@ -0,0 +1,67 @@
// Copyright 2013 The Bufferpool Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
package bufferpool_test
import (
"bytes"
"fmt"
"testing"
"github.com/pushrax/bufferpool"
)
func ExampleNew() {
bp := bufferpool.New(10, 255)
dogBuffer := bp.Take()
dogBuffer.WriteString("Dog!")
bp.Give(dogBuffer)
catBuffer := bp.Take() // dogBuffer is reused and reset.
catBuffer.WriteString("Cat!")
fmt.Println(catBuffer)
// Output:
// Cat!
}
func TestTakeFromEmpty(t *testing.T) {
bp := bufferpool.New(1, 1)
poolBuf := bp.Take()
if !bytes.Equal(poolBuf.Bytes(), []byte("")) {
t.Fatalf("Buffer from empty bufferpool was allocated incorrectly.")
}
}
func TestTakeFromFilled(t *testing.T) {
bp := bufferpool.New(1, 1)
origBuf := bytes.NewBuffer([]byte("X"))
bp.Give(origBuf)
reusedBuf := bp.Take()
if !bytes.Equal(reusedBuf.Bytes(), []byte("")) {
t.Fatalf("Buffer from filled bufferpool was recycled incorrectly.")
}
// Compare addresses of the first element in the underlying slice.
if &origBuf.Bytes()[:1][0] != &reusedBuf.Bytes()[:1][0] {
t.Fatalf("Recycled buffer points at different address.")
}
}
func TestSliceSemantics(t *testing.T) {
bp := bufferpool.New(1, 8)
buf := bp.Take()
buf.WriteString("abc")
bp.Give(buf)
buf2 := bp.TakeSlice()
if !bytes.Equal(buf2[:3], []byte("abc")) {
t.Fatalf("Buffer from filled bufferpool was recycled incorrectly.")
}
}

View file

@ -1 +0,0 @@
box: wercker/golang

View file

@ -1,46 +0,0 @@
// Package stop represents a pattern for types that need to do some work
// when stopping. The StopChan method returns a <-chan stop.Signal which
// is closed when the operation has completed.
//
// Stopper types when implementing the stop channel pattern should use stop.Make
// to create and store a stop channel, and close the channel once stopping has completed:
// func New() Type {
// t := new(Type)
// t.stopChan = stop.Make()
// return t
// }
// func (t Type) Stop() {
// go func(){
// // TODO: tear stuff down
// close(t.stopChan)
// }()
// }
// func (t Type) StopChan() <-chan stop.Signal {
// return t.stopChan
// }
//
// Stopper types can be stopped in the following ways:
// // stop and forget
// t.Stop(1 * time.Second)
//
// // stop and wait
// t.Stop(1 * time.Second)
// <-t.StopChan()
//
// // stop, do more work, then wait
// t.Stop(1 * time.Second);
// // do more work
// <-t.StopChan()
//
// // stop and timeout after 1 second
// t.Stop(1 * time.Second)
// select {
// case <-t.StopChan():
// case <-time.After(1 * time.Second):
// }
//
// // stop.All is the same as calling Stop() then StopChan() so
// // all above patterns also work on many Stopper types,
// // for example; stop and wait for many things:
// <-stop.All(1 * time.Second, t1, t2, t3)
package stop

View file

@ -1,57 +0,0 @@
package stop
import "time"
// Signal is the type that gets sent down the stop channel.
type Signal struct{}
// NoWait represents a time.Duration with zero value.
// Logically meaning no grace wait period when stopping.
var NoWait time.Duration
// Stopper represents types that implement
// the stop channel pattern.
type Stopper interface {
// Stop instructs the type to halt operations and close
// the stop channel when it is finished.
Stop(wait time.Duration)
// StopChan gets the stop channel which will block until
// stopping has completed, at which point it is closed.
// Callers should never close the stop channel.
// The StopChan should exist from the point at which operations
// begun, not the point at which Stop was called.
StopChan() <-chan Signal
}
// Stopped returns a channel that signals immediately. Useful for
// cases when no tear-down work is required and stopping is
// immediate.
func Stopped() <-chan Signal {
c := Make()
close(c)
return c
}
// Make makes a new channel used to indicate when
// stopping has finished. Sends to channel will not block.
func Make() chan Signal {
return make(chan Signal, 0)
}
// All stops all Stopper types and returns another channel
// which will close once all things have finished stopping.
func All(wait time.Duration, stoppers ...Stopper) <-chan Signal {
all := Make()
go func() {
var allChans []<-chan Signal
for _, stopper := range stoppers {
go stopper.Stop(wait)
allChans = append(allChans, stopper.StopChan())
}
for _, ch := range allChans {
<-ch
}
close(all)
}()
return all
}

View file

@ -1,76 +0,0 @@
package stop_test
import (
"testing"
"time"
"github.com/stretchr/pat/stop"
)
type testStopper struct {
stopChan chan stop.Signal
}
func NewTestStopper() *testStopper {
s := new(testStopper)
s.stopChan = stop.Make()
return s
}
func (t *testStopper) Stop(wait time.Duration) {
go func() {
time.Sleep(100 * time.Millisecond)
close(t.stopChan)
}()
}
func (t *testStopper) StopChan() <-chan stop.Signal {
return t.stopChan
}
type noopStopper struct{}
func (t *noopStopper) Stop() {
}
func (t *noopStopper) StopChan() <-chan stop.Signal {
return stop.Stopped()
}
func TestStop(t *testing.T) {
s := NewTestStopper()
s.Stop(1 * time.Second)
stopChan := s.StopChan()
select {
case <-stopChan:
case <-time.After(1 * time.Second):
t.Error("Stop signal was never sent (timed out)")
}
}
func TestAll(t *testing.T) {
s1 := NewTestStopper()
s2 := NewTestStopper()
s3 := NewTestStopper()
select {
case <-stop.All(1*time.Second, s1, s2, s3):
case <-time.After(1 * time.Second):
t.Error("All signal was never sent (timed out)")
}
}
func TestNoop(t *testing.T) {
s := new(noopStopper)
s.Stop()
stopChan := s.StopChan()
select {
case <-stopChan:
case <-time.After(1 * time.Second):
t.Error("Stop signal was never sent (timed out)")
}
}

View file

@ -1,6 +1,6 @@
The MIT License (MIT) The MIT License (MIT)
Copyright (c) 2014 Stretchr, Inc. Copyright (c) 2014 Tyler Bunnell
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

View file

@ -1,17 +1,30 @@
graceful [![GoDoc](https://godoc.org/github.com/stretchr/graceful?status.png)](http://godoc.org/github.com/stretchr/graceful) [![wercker status](https://app.wercker.com/status/2729ba763abf87695a17547e0f7af4a4/s "wercker status")](https://app.wercker.com/project/bykey/2729ba763abf87695a17547e0f7af4a4) graceful [![GoDoc](https://godoc.org/github.com/tylerb/graceful?status.png)](http://godoc.org/github.com/tylerb/graceful) [![Build Status](https://drone.io/github.com/tylerb/graceful/status.png)](https://drone.io/github.com/tylerb/graceful/latest) [![Coverage Status](https://coveralls.io/repos/tylerb/graceful/badge.svg?branch=dronedebug)](https://coveralls.io/r/tylerb/graceful?branch=dronedebug) [![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/tylerb/graceful?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
======== ========
Graceful is a Go 1.3+ package enabling graceful shutdown of http.Handler servers. Graceful is a Go 1.3+ package enabling graceful shutdown of http.Handler servers.
## Installation
To install, simply execute:
```
go get gopkg.in/tylerb/graceful.v1
```
I am using [gopkg.in](http://http://labix.org/gopkg.in) to control releases.
## Usage ## Usage
Usage of Graceful is simple. Create your http.Handler and pass it to the `Run` function: Using Graceful is easy. Simply create your http.Handler and pass it to the `Run` function:
```go ```go
package main
import ( import (
"github.com/stretchr/graceful" "gopkg.in/tylerb/graceful.v1"
"net/http" "net/http"
"fmt" "fmt"
"time"
) )
func main() { func main() {
@ -31,9 +44,10 @@ package main
import ( import (
"github.com/codegangsta/negroni" "github.com/codegangsta/negroni"
"github.com/stretchr/graceful" "gopkg.in/tylerb/graceful.v1"
"net/http" "net/http"
"fmt" "fmt"
"time"
) )
func main() { func main() {
@ -111,4 +125,13 @@ same time and all will be signalled when stopping is complete.
## Contributing ## Contributing
Before sending a pull request, please open a new issue describing the feature/issue you wish to address so it can be discussed. The subsequent pull request should close that issue. If you would like to contribute, please:
1. Create a GitHub issue regarding the contribution. Features and bugs should be discussed beforehand.
2. Fork the repository.
3. Create a pull request with your solution. This pull request should reference and close the issues (Fix #2).
All pull requests should:
1. Pass [gometalinter -t .](https://github.com/alecthomas/gometalinter) with no warnings.
2. Be `go fmt` formatted.

View file

@ -11,7 +11,6 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/stretchr/pat/stop"
"golang.org/x/net/netutil" "golang.org/x/net/netutil"
) )
@ -41,30 +40,31 @@ type Server struct {
// must not be set directly. // must not be set directly.
ConnState func(net.Conn, http.ConnState) ConnState func(net.Conn, http.ConnState)
// ShutdownInitiated is an optional callback function that is called // ShutdownInitiated is an optional callback function that is called
// when shutdown is initiated. It can be used to notify the client // when shutdown is initiated. It can be used to notify the client
// side of long lived connections (e.g. websockets) to reconnect. // side of long lived connections (e.g. websockets) to reconnect.
ShutdownInitiated func() ShutdownInitiated func()
// NoSignalHandling prevents graceful from automatically shutting down
// on SIGINT and SIGTERM. If set to true, you must shut down the server
// manually with Stop().
NoSignalHandling bool
// interrupt signals the listener to stop serving connections, // interrupt signals the listener to stop serving connections,
// and the server to shut down. // and the server to shut down.
interrupt chan os.Signal interrupt chan os.Signal
// stopChan is the channel on which callers may block while waiting for // stopChan is the channel on which callers may block while waiting for
// the server to stop. // the server to stop.
stopChan chan stop.Signal stopChan chan struct{}
// stopChanOnce is used to create the stop channel on demand, once, per // stopLock is used to protect access to the stopChan.
// instance. stopLock sync.RWMutex
stopChanOnce sync.Once
// connections holds all connections managed by graceful // connections holds all connections managed by graceful
connections map[net.Conn]struct{} connections map[net.Conn]struct{}
} }
// ensure Server conforms to stop.Stopper
var _ stop.Stopper = (*Server)(nil)
// Run serves the http.Handler with graceful shutdown enabled. // Run serves the http.Handler with graceful shutdown enabled.
// //
// timeout is the duration to wait until killing active requests and stopping the server. // timeout is the duration to wait until killing active requests and stopping the server.
@ -173,7 +173,6 @@ func (srv *Server) Serve(listener net.Listener) error {
case http.StateClosed, http.StateHijacked: case http.StateClosed, http.StateHijacked:
remove <- conn remove <- conn
} }
if srv.ConnState != nil { if srv.ConnState != nil {
srv.ConnState(conn, state) srv.ConnState(conn, state)
} }
@ -182,7 +181,53 @@ func (srv *Server) Serve(listener net.Listener) error {
// Manage open connections // Manage open connections
shutdown := make(chan chan struct{}) shutdown := make(chan chan struct{})
kill := make(chan struct{}) kill := make(chan struct{})
go func() { go srv.manageConnections(add, remove, shutdown, kill)
interrupt := srv.interruptChan()
// Set up the interrupt handler
if !srv.NoSignalHandling {
signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
}
go srv.handleInterrupt(interrupt, listener)
// Serve with graceful listener.
// Execution blocks here until listener.Close() is called, above.
err := srv.Server.Serve(listener)
srv.shutdown(shutdown, kill)
return err
}
// Stop instructs the type to halt operations and close
// the stop channel when it is finished.
//
// timeout is grace period for which to wait before shutting
// down the server. The timeout value passed here will override the
// timeout given when constructing the server, as this is an explicit
// command to stop the server.
func (srv *Server) Stop(timeout time.Duration) {
srv.Timeout = timeout
interrupt := srv.interruptChan()
interrupt <- syscall.SIGINT
}
// StopChan gets the stop channel which will block until
// stopping has completed, at which point it is closed.
// Callers should never close the stop channel.
func (srv *Server) StopChan() <-chan struct{} {
srv.stopLock.Lock()
if srv.stopChan == nil {
srv.stopChan = make(chan struct{})
}
srv.stopLock.Unlock()
return srv.stopChan
}
func (srv *Server) manageConnections(add, remove chan net.Conn, shutdown chan chan struct{}, kill chan struct{}) {
{
var done chan struct{} var done chan struct{}
srv.connections = map[net.Conn]struct{}{} srv.connections = map[net.Conn]struct{}{}
for { for {
@ -202,36 +247,39 @@ func (srv *Server) Serve(listener net.Listener) error {
} }
case <-kill: case <-kill:
for k := range srv.connections { for k := range srv.connections {
k.Close() _ = k.Close() // nothing to do here if it errors
} }
return return
} }
} }
}() }
}
func (srv *Server) interruptChan() chan os.Signal {
srv.stopLock.Lock()
if srv.interrupt == nil { if srv.interrupt == nil {
srv.interrupt = make(chan os.Signal, 1) srv.interrupt = make(chan os.Signal, 1)
} }
srv.stopLock.Unlock()
// Set up the interrupt catch return srv.interrupt
signal.Notify(srv.interrupt, syscall.SIGINT, syscall.SIGTERM) }
go func() {
<-srv.interrupt
srv.SetKeepAlivesEnabled(false)
listener.Close()
if srv.ShutdownInitiated != nil { func (srv *Server) handleInterrupt(interrupt chan os.Signal, listener net.Listener) {
srv.ShutdownInitiated() <-interrupt
}
signal.Stop(srv.interrupt) srv.SetKeepAlivesEnabled(false)
close(srv.interrupt) _ = listener.Close() // we are shutting down anyway. ignore error.
}()
// Serve with graceful listener. if srv.ShutdownInitiated != nil {
// Execution blocks here until listener.Close() is called, above. srv.ShutdownInitiated()
err := srv.Server.Serve(listener) }
signal.Stop(interrupt)
close(interrupt)
}
func (srv *Server) shutdown(shutdown chan chan struct{}, kill chan struct{}) {
// Request done notification // Request done notification
done := make(chan struct{}) done := make(chan struct{})
shutdown <- done shutdown <- done
@ -246,32 +294,9 @@ func (srv *Server) Serve(listener net.Listener) error {
<-done <-done
} }
// Close the stopChan to wake up any blocked goroutines. // Close the stopChan to wake up any blocked goroutines.
srv.stopLock.Lock()
if srv.stopChan != nil { if srv.stopChan != nil {
close(srv.stopChan) close(srv.stopChan)
} }
return err srv.stopLock.Unlock()
}
// Stop instructs the type to halt operations and close
// the stop channel when it is finished.
//
// timeout is grace period for which to wait before shutting
// down the server. The timeout value passed here will override the
// timeout given when constructing the server, as this is an explicit
// command to stop the server.
func (srv *Server) Stop(timeout time.Duration) {
srv.Timeout = timeout
srv.interrupt <- syscall.SIGINT
}
// StopChan gets the stop channel which will block until
// stopping has completed, at which point it is closed.
// Callers should never close the stop channel.
func (srv *Server) StopChan() <-chan stop.Signal {
srv.stopChanOnce.Do(func() {
if srv.stopChan == nil {
srv.stopChan = stop.Make()
}
})
return srv.stopChan
} }

View file

@ -1,6 +1,7 @@
package graceful package graceful
import ( import (
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -13,34 +14,52 @@ import (
"time" "time"
) )
var killTime = 50 * time.Millisecond var (
killTime = 500 * time.Millisecond
timeoutTime = 1000 * time.Millisecond
waitTime = 100 * time.Millisecond
)
func runQuery(t *testing.T, expected int, shouldErr bool, wg *sync.WaitGroup) { func runQuery(t *testing.T, expected int, shouldErr bool, wg *sync.WaitGroup, once *sync.Once) {
wg.Add(1) wg.Add(1)
defer wg.Done() defer wg.Done()
client := http.Client{} client := http.Client{}
r, err := client.Get("http://localhost:3000") r, err := client.Get("http://localhost:3000")
if shouldErr && err == nil { if shouldErr && err == nil {
t.Fatal("Expected an error but none was encountered.") once.Do(func() {
t.Fatal("Expected an error but none was encountered.")
})
} else if shouldErr && err != nil { } else if shouldErr && err != nil {
if err.(*url.Error).Err == io.EOF { if checkErr(t, err, once) {
return return
} }
errno := err.(*url.Error).Err.(*net.OpError).Err.(syscall.Errno)
if errno == syscall.ECONNREFUSED {
return
} else if err != nil {
t.Fatal("Error on Get:", err)
}
} }
if r != nil && r.StatusCode != expected { if r != nil && r.StatusCode != expected {
t.Fatalf("Incorrect status code on response. Expected %d. Got %d", expected, r.StatusCode) once.Do(func() {
t.Fatalf("Incorrect status code on response. Expected %d. Got %d", expected, r.StatusCode)
})
} else if r == nil { } else if r == nil {
t.Fatal("No response when a response was expected.") once.Do(func() {
t.Fatal("No response when a response was expected.")
})
} }
} }
func checkErr(t *testing.T, err error, once *sync.Once) bool {
if err.(*url.Error).Err == io.EOF {
return true
}
errno := err.(*url.Error).Err.(*net.OpError).Err.(syscall.Errno)
if errno == syscall.ECONNREFUSED {
return true
} else if err != nil {
once.Do(func() {
t.Fatal("Error on Get:", err)
})
}
return false
}
func createListener(sleep time.Duration) (*http.Server, net.Listener, error) { func createListener(sleep time.Duration) (*http.Server, net.Listener, error) {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
@ -50,6 +69,9 @@ func createListener(sleep time.Duration) (*http.Server, net.Listener, error) {
server := &http.Server{Addr: ":3000", Handler: mux} server := &http.Server{Addr: ":3000", Handler: mux}
l, err := net.Listen("tcp", ":3000") l, err := net.Listen("tcp", ":3000")
if err != nil {
fmt.Println(err)
}
return server, l, err return server, l, err
} }
@ -64,16 +86,17 @@ func runServer(timeout, sleep time.Duration, c chan os.Signal) error {
} }
func launchTestQueries(t *testing.T, wg *sync.WaitGroup, c chan os.Signal) { func launchTestQueries(t *testing.T, wg *sync.WaitGroup, c chan os.Signal) {
var once sync.Once
for i := 0; i < 8; i++ { for i := 0; i < 8; i++ {
go runQuery(t, http.StatusOK, false, wg) go runQuery(t, http.StatusOK, false, wg, &once)
} }
time.Sleep(10 * time.Millisecond) time.Sleep(waitTime)
c <- os.Interrupt c <- os.Interrupt
time.Sleep(10 * time.Millisecond) time.Sleep(waitTime)
for i := 0; i < 8; i++ { for i := 0; i < 8; i++ {
go runQuery(t, 0, true, wg) go runQuery(t, 0, true, wg, &once)
} }
wg.Done() wg.Done()
@ -106,16 +129,17 @@ func TestGracefulRunTimesOut(t *testing.T) {
wg.Done() wg.Done()
}() }()
var once sync.Once
wg.Add(1) wg.Add(1)
go func() { go func() {
for i := 0; i < 8; i++ { for i := 0; i < 8; i++ {
go runQuery(t, 0, true, &wg) go runQuery(t, 0, true, &wg, &once)
} }
time.Sleep(10 * time.Millisecond) time.Sleep(waitTime)
c <- os.Interrupt c <- os.Interrupt
time.Sleep(10 * time.Millisecond) time.Sleep(waitTime)
for i := 0; i < 8; i++ { for i := 0; i < 8; i++ {
go runQuery(t, 0, true, &wg) go runQuery(t, 0, true, &wg, &once)
} }
wg.Done() wg.Done()
}() }()
@ -160,14 +184,23 @@ func TestGracefulRunNoRequests(t *testing.T) {
func TestGracefulForwardsConnState(t *testing.T) { func TestGracefulForwardsConnState(t *testing.T) {
c := make(chan os.Signal, 1) c := make(chan os.Signal, 1)
states := make(map[http.ConnState]int) states := make(map[http.ConnState]int)
var stateLock sync.Mutex
connState := func(conn net.Conn, state http.ConnState) { connState := func(conn net.Conn, state http.ConnState) {
stateLock.Lock()
states[state]++ states[state]++
stateLock.Unlock()
} }
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
expected := map[http.ConnState]int{
http.StateNew: 8,
http.StateActive: 8,
http.StateClosed: 8,
}
go func() { go func() {
server, l, _ := createListener(killTime / 2) server, l, _ := createListener(killTime / 2)
srv := &Server{ srv := &Server{
@ -185,15 +218,11 @@ func TestGracefulForwardsConnState(t *testing.T) {
go launchTestQueries(t, &wg, c) go launchTestQueries(t, &wg, c)
wg.Wait() wg.Wait()
expected := map[http.ConnState]int{ stateLock.Lock()
http.StateNew: 8,
http.StateActive: 8,
http.StateClosed: 8,
}
if !reflect.DeepEqual(states, expected) { if !reflect.DeepEqual(states, expected) {
t.Errorf("Incorrect connection state tracking.\n actual: %v\nexpected: %v\n", states, expected) t.Errorf("Incorrect connection state tracking.\n actual: %v\nexpected: %v\n", states, expected)
} }
stateLock.Unlock()
} }
func TestGracefulExplicitStop(t *testing.T) { func TestGracefulExplicitStop(t *testing.T) {
@ -206,14 +235,14 @@ func TestGracefulExplicitStop(t *testing.T) {
go func() { go func() {
go srv.Serve(l) go srv.Serve(l)
time.Sleep(10 * time.Millisecond) time.Sleep(waitTime)
srv.Stop(killTime) srv.Stop(killTime)
}() }()
// block on the stopChan until the server has shut down // block on the stopChan until the server has shut down
select { select {
case <-srv.StopChan(): case <-srv.StopChan():
case <-time.After(100 * time.Millisecond): case <-time.After(timeoutTime):
t.Fatal("Timed out while waiting for explicit stop to complete") t.Fatal("Timed out while waiting for explicit stop to complete")
} }
} }
@ -228,7 +257,7 @@ func TestGracefulExplicitStopOverride(t *testing.T) {
go func() { go func() {
go srv.Serve(l) go srv.Serve(l)
time.Sleep(10 * time.Millisecond) time.Sleep(waitTime)
srv.Stop(killTime / 2) srv.Stop(killTime / 2)
}() }()
@ -253,7 +282,7 @@ func TestShutdownInitiatedCallback(t *testing.T) {
go func() { go func() {
go srv.Serve(l) go srv.Serve(l)
time.Sleep(10 * time.Millisecond) time.Sleep(waitTime)
srv.Stop(killTime) srv.Stop(killTime)
}() }()
@ -302,12 +331,9 @@ func TestNotifyClosed(t *testing.T) {
wg.Done() wg.Done()
}() }()
var once sync.Once
for i := 0; i < 8; i++ { for i := 0; i < 8; i++ {
runQuery(t, http.StatusOK, false, &wg) runQuery(t, http.StatusOK, false, &wg, &once)
}
if len(srv.connections) > 0 {
t.Fatal("hijacked connections should not be managed")
} }
srv.Stop(0) srv.Stop(0)
@ -315,8 +341,39 @@ func TestNotifyClosed(t *testing.T) {
// block on the stopChan until the server has shut down // block on the stopChan until the server has shut down
select { select {
case <-srv.StopChan(): case <-srv.StopChan():
case <-time.After(100 * time.Millisecond): case <-time.After(timeoutTime):
t.Fatal("Timed out while waiting for explicit stop to complete") t.Fatal("Timed out while waiting for explicit stop to complete")
} }
if len(srv.connections) > 0 {
t.Fatal("hijacked connections should not be managed")
}
}
func TestStopDeadlock(t *testing.T) {
c := make(chan struct{})
server, l, err := createListener(1 * time.Millisecond)
if err != nil {
t.Fatal(err)
}
srv := &Server{Server: server, NoSignalHandling: true}
go func() {
time.Sleep(waitTime)
srv.Serve(l)
}()
go func() {
srv.Stop(0)
close(c)
}()
select {
case <-c:
case <-time.After(timeoutTime):
t.Fatal("Timed out while waiting for explicit stop to complete")
}
} }

View file

@ -5,7 +5,7 @@ import (
"sync" "sync"
"github.com/codegangsta/negroni" "github.com/codegangsta/negroni"
"github.com/stretchr/graceful" "github.com/tylerb/graceful"
) )
func main() { func main() {

View file

@ -7,8 +7,11 @@ package chihaya
import ( import (
"flag" "flag"
"os" "os"
"os/signal"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"sync"
"syscall"
"github.com/golang/glog" "github.com/golang/glog"
@ -16,6 +19,7 @@ import (
"github.com/chihaya/chihaya/http" "github.com/chihaya/chihaya/http"
"github.com/chihaya/chihaya/stats" "github.com/chihaya/chihaya/stats"
"github.com/chihaya/chihaya/tracker" "github.com/chihaya/chihaya/tracker"
"github.com/chihaya/chihaya/udp"
// See the README for how to import custom drivers. // See the README for how to import custom drivers.
_ "github.com/chihaya/chihaya/backend/noop" _ "github.com/chihaya/chihaya/backend/noop"
@ -77,6 +81,50 @@ func Boot() {
glog.Fatal("New: ", err) glog.Fatal("New: ", err)
} }
http.Serve(cfg, tkr) var wg sync.WaitGroup
glog.Info("Gracefully shut down") var servers []tracker.Server
if cfg.HTTPListenAddr != "" {
wg.Add(1)
srv := http.NewServer(cfg, tkr)
servers = append(servers, srv)
go func() {
defer wg.Done()
srv.Serve(cfg.HTTPListenAddr)
}()
}
if cfg.UDPListenAddr != "" {
wg.Add(1)
srv := udp.NewServer(cfg, tkr)
servers = append(servers, srv)
go func() {
defer wg.Done()
srv.Serve(cfg.UDPListenAddr)
}()
}
shutdown := make(chan os.Signal)
signal.Notify(shutdown, syscall.SIGINT, syscall.SIGTERM)
go func() {
wg.Wait()
signal.Stop(shutdown)
close(shutdown)
}()
<-shutdown
glog.Info("Shutting down...")
for _, srv := range servers {
srv.Stop()
}
<-shutdown
if err := tkr.Close(); err != nil {
glog.Errorf("Failed to shut down tracker cleanly: %s", err.Error())
}
} }

View file

@ -90,17 +90,24 @@ type TrackerConfig struct {
// HTTPConfig is the configuration for HTTP functionality. // HTTPConfig is the configuration for HTTP functionality.
type HTTPConfig struct { type HTTPConfig struct {
ListenAddr string `json:"http_listen_addr"` HTTPListenAddr string `json:"http_listen_addr"`
RequestTimeout Duration `json:"http_request_timeout"` HTTPRequestTimeout Duration `json:"http_request_timeout"`
HTTPReadTimeout Duration `json:"http_read_timeout"` HTTPReadTimeout Duration `json:"http_read_timeout"`
HTTPWriteTimeout Duration `json:"http_write_timeout"` HTTPWriteTimeout Duration `json:"http_write_timeout"`
HTTPListenLimit int `json:"http_listen_limit"` HTTPListenLimit int `json:"http_listen_limit"`
}
// UDPConfig is the configuration for HTTP functionality.
type UDPConfig struct {
UDPListenAddr string `json:"udp_listen_addr"`
UDPReadBufferSize int `json:"udp_read_buffer_size"`
} }
// Config is the global configuration for an instance of Chihaya. // Config is the global configuration for an instance of Chihaya.
type Config struct { type Config struct {
TrackerConfig TrackerConfig
HTTPConfig HTTPConfig
UDPConfig
DriverConfig DriverConfig
StatsConfig StatsConfig
} }
@ -129,10 +136,14 @@ var DefaultConfig = Config{
}, },
HTTPConfig: HTTPConfig{ HTTPConfig: HTTPConfig{
ListenAddr: ":6881", HTTPListenAddr: ":6881",
RequestTimeout: Duration{10 * time.Second}, HTTPRequestTimeout: Duration{10 * time.Second},
HTTPReadTimeout: Duration{10 * time.Second}, HTTPReadTimeout: Duration{10 * time.Second},
HTTPWriteTimeout: Duration{10 * time.Second}, HTTPWriteTimeout: Duration{10 * time.Second},
},
UDPConfig: UDPConfig{
UDPListenAddr: ":6882",
}, },
DriverConfig: DriverConfig{ DriverConfig: DriverConfig{

View file

@ -13,10 +13,11 @@
"respect_af": false, "respect_af": false,
"client_whitelist_enabled": false, "client_whitelist_enabled": false,
"client_whitelist": ["OP1011"], "client_whitelist": ["OP1011"],
"udp_listen_addr": ":6881",
"http_listen_addr": ":6881", "http_listen_addr": ":6881",
"http_request_timeout": "10s", "http_request_timeout": "4s",
"http_read_timeout": "10s", "http_read_timeout": "4s",
"http_write_timeout": "10s", "http_write_timeout": "4s",
"http_listen_limit": 0, "http_listen_limit": 0,
"driver": "noop", "driver": "noop",
"stats_buffer_size": 0, "stats_buffer_size": 0,

View file

@ -31,10 +31,10 @@ func TestPublicAnnounce(t *testing.T) {
peer3 := makePeerParams("peer3", false) peer3 := makePeerParams("peer3", false)
peer1["event"] = "started" peer1["event"] = "started"
expected := makeResponse(1, 0) expected := makeResponse(1, 0, peer1)
checkAnnounce(peer1, expected, srv, t) checkAnnounce(peer1, expected, srv, t)
expected = makeResponse(2, 0) expected = makeResponse(2, 0, peer2)
checkAnnounce(peer2, expected, srv, t) checkAnnounce(peer2, expected, srv, t)
expected = makeResponse(2, 1, peer1, peer2) expected = makeResponse(2, 1, peer1, peer2)
@ -147,7 +147,7 @@ func TestPrivateAnnounce(t *testing.T) {
peer2 := makePeerParams("-TR2820-peer2", false) peer2 := makePeerParams("-TR2820-peer2", false)
peer3 := makePeerParams("-TR2820-peer3", true) peer3 := makePeerParams("-TR2820-peer3", true)
expected := makeResponse(0, 1) expected := makeResponse(0, 1, peer1)
srv.URL = baseURL + "/users/vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv1" srv.URL = baseURL + "/users/vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv1"
checkAnnounce(peer1, expected, srv, t) checkAnnounce(peer1, expected, srv, t)
@ -189,7 +189,7 @@ func TestPreferredSubnet(t *testing.T) {
peerD1 := makePeerParams("peerD1", false, "fc02::1") peerD1 := makePeerParams("peerD1", false, "fc02::1")
peerD2 := makePeerParams("peerD2", false, "fc02::2") peerD2 := makePeerParams("peerD2", false, "fc02::2")
expected := makeResponse(0, 1) expected := makeResponse(0, 1, peerA1)
checkAnnounce(peerA1, expected, srv, t) checkAnnounce(peerA1, expected, srv, t)
expected = makeResponse(0, 2, peerA1) expected = makeResponse(0, 2, peerA1)
@ -255,7 +255,7 @@ func TestCompactAnnounce(t *testing.T) {
peer3["compact"] = "1" peer3["compact"] = "1"
expected := makeResponse(0, 1) expected := makeResponse(0, 1)
expected["peers"] = "" expected["peers"] = compact
checkAnnounce(peer1, expected, srv, t) checkAnnounce(peer1, expected, srv, t)
expected = makeResponse(0, 2) expected = makeResponse(0, 2)

View file

@ -12,7 +12,7 @@ import (
"github.com/golang/glog" "github.com/golang/glog"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/stretchr/graceful" "github.com/tylerb/graceful"
"github.com/chihaya/chihaya/config" "github.com/chihaya/chihaya/config"
"github.com/chihaya/chihaya/stats" "github.com/chihaya/chihaya/stats"
@ -24,8 +24,10 @@ type ResponseHandler func(http.ResponseWriter, *http.Request, httprouter.Params)
// Server represents an HTTP serving torrent tracker. // Server represents an HTTP serving torrent tracker.
type Server struct { type Server struct {
config *config.Config config *config.Config
tracker *tracker.Tracker tracker *tracker.Tracker
grace *graceful.Server
stopping bool
} }
// makeHandler wraps our ResponseHandlers while timing requests, collecting, // makeHandler wraps our ResponseHandlers while timing requests, collecting,
@ -118,40 +120,53 @@ func (s *Server) connState(conn net.Conn, state http.ConnState) {
} }
} }
// Serve creates a new Server and proceeds to block while handling requests // Serve runs an HTTP server, blocking until the server has shut down.
// until a graceful shutdown. func (s *Server) Serve(addr string) {
func Serve(cfg *config.Config, tkr *tracker.Tracker) { glog.V(0).Info("Starting HTTP on ", addr)
srv := &Server{
config: cfg,
tracker: tkr,
}
glog.V(0).Info("Starting on ", cfg.ListenAddr) if s.config.HTTPListenLimit != 0 {
if cfg.HTTPListenLimit != 0 { glog.V(0).Info("Limiting connections to ", s.config.HTTPListenLimit)
glog.V(0).Info("Limiting connections to ", cfg.HTTPListenLimit)
} }
grace := &graceful.Server{ grace := &graceful.Server{
Timeout: cfg.RequestTimeout.Duration, Timeout: s.config.HTTPRequestTimeout.Duration,
ConnState: srv.connState, ConnState: s.connState,
ListenLimit: cfg.HTTPListenLimit, ListenLimit: s.config.HTTPListenLimit,
NoSignalHandling: true,
Server: &http.Server{ Server: &http.Server{
Addr: cfg.ListenAddr, Addr: addr,
Handler: newRouter(srv), Handler: newRouter(s),
ReadTimeout: cfg.HTTPReadTimeout.Duration, ReadTimeout: s.config.HTTPReadTimeout.Duration,
WriteTimeout: cfg.HTTPWriteTimeout.Duration, WriteTimeout: s.config.HTTPWriteTimeout.Duration,
}, },
} }
s.grace = grace
grace.SetKeepAlivesEnabled(false) grace.SetKeepAlivesEnabled(false)
grace.ShutdownInitiated = func() { s.stopping = true }
if err := grace.ListenAndServe(); err != nil { if err := grace.ListenAndServe(); err != nil {
if opErr, ok := err.(*net.OpError); !ok || (ok && opErr.Op != "accept") { if opErr, ok := err.(*net.OpError); !ok || (ok && opErr.Op != "accept") {
glog.Errorf("Failed to gracefully run HTTP server: %s", err.Error()) glog.Errorf("Failed to gracefully run HTTP server: %s", err.Error())
return
} }
} }
if err := srv.tracker.Close(); err != nil { glog.Info("HTTP server shut down cleanly")
glog.Errorf("Failed to shutdown tracker cleanly: %s", err.Error()) }
// Stop cleanly shuts down the server.
func (s *Server) Stop() {
if !s.stopping {
s.grace.Stop(s.grace.Timeout)
}
}
// NewServer returns a new HTTP server for a given configuration and tracker.
func NewServer(cfg *config.Config, tkr *tracker.Tracker) *Server {
return &Server{
config: cfg,
tracker: tkr,
} }
} }

View file

@ -76,7 +76,7 @@ func (s *Server) stats(w http.ResponseWriter, r *http.Request, p httprouter.Para
func handleTorrentError(err error, w *Writer) (int, error) { func handleTorrentError(err error, w *Writer) (int, error) {
if err == nil { if err == nil {
return http.StatusOK, nil return http.StatusOK, nil
} else if _, ok := err.(models.ClientError); ok { } else if models.IsPublicError(err) {
w.WriteError(err) w.WriteError(err)
stats.RecordEvent(stats.ClientError) stats.RecordEvent(stats.ClientError)
return http.StatusOK, nil return http.StatusOK, nil
@ -86,10 +86,8 @@ func handleTorrentError(err error, w *Writer) (int, error) {
} }
func (s *Server) serveAnnounce(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) { func (s *Server) serveAnnounce(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) {
stats.RecordEvent(stats.Announce)
writer := &Writer{w} writer := &Writer{w}
ann, err := NewAnnounce(s.config, r, p) ann, err := s.newAnnounce(r, p)
if err != nil { if err != nil {
return handleTorrentError(err, writer) return handleTorrentError(err, writer)
} }
@ -98,10 +96,8 @@ func (s *Server) serveAnnounce(w http.ResponseWriter, r *http.Request, p httprou
} }
func (s *Server) serveScrape(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) { func (s *Server) serveScrape(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) {
stats.RecordEvent(stats.Scrape)
writer := &Writer{w} writer := &Writer{w}
scrape, err := NewScrape(s.config, r, p) scrape, err := s.newScrape(r, p)
if err != nil { if err != nil {
return handleTorrentError(err, writer) return handleTorrentError(err, writer)
} }

View file

@ -17,8 +17,8 @@ import (
"github.com/chihaya/chihaya/tracker/models" "github.com/chihaya/chihaya/tracker/models"
) )
// NewAnnounce parses an HTTP request and generates a models.Announce. // newAnnounce parses an HTTP request and generates a models.Announce.
func NewAnnounce(cfg *config.Config, r *http.Request, p httprouter.Params) (*models.Announce, error) { func (s *Server) newAnnounce(r *http.Request, p httprouter.Params) (*models.Announce, error) {
q, err := query.New(r.URL.RawQuery) q, err := query.New(r.URL.RawQuery)
if err != nil { if err != nil {
return nil, err return nil, err
@ -26,7 +26,7 @@ func NewAnnounce(cfg *config.Config, r *http.Request, p httprouter.Params) (*mod
compact := q.Params["compact"] != "0" compact := q.Params["compact"] != "0"
event, _ := q.Params["event"] event, _ := q.Params["event"]
numWant := requestedPeerCount(q, cfg.NumWantFallback) numWant := requestedPeerCount(q, s.config.NumWantFallback)
infohash, exists := q.Params["info_hash"] infohash, exists := q.Params["info_hash"]
if !exists { if !exists {
@ -38,17 +38,34 @@ func NewAnnounce(cfg *config.Config, r *http.Request, p httprouter.Params) (*mod
return nil, models.ErrMalformedRequest return nil, models.ErrMalformedRequest
} }
ipv4, ipv6, err := requestedIP(q, r, &cfg.NetConfig)
if err != nil {
return nil, models.ErrMalformedRequest
}
port, err := q.Uint64("port") port, err := q.Uint64("port")
if err != nil { if err != nil {
return nil, models.ErrMalformedRequest return nil, models.ErrMalformedRequest
} }
left, err := q.Uint64("left") left, err := q.Uint64("left")
ipv4, ipv6, err := requestedEndpoint(q, r, &s.config.NetConfig)
if err != nil {
return nil, models.ErrMalformedRequest
}
var ipv4Endpoint, ipv6Endpoint models.Endpoint
if ipv4 != nil {
ipv4Endpoint = *ipv4
// If the port we couldn't get the port before, fallback to the port param.
if ipv4Endpoint.Port == uint16(0) {
ipv4Endpoint.Port = uint16(port)
}
}
if ipv6 != nil {
ipv6Endpoint = *ipv6
// If the port we couldn't get the port before, fallback to the port param.
if ipv6Endpoint.Port == uint16(0) {
ipv6Endpoint.Port = uint16(port)
}
}
if err != nil { if err != nil {
return nil, models.ErrMalformedRequest return nil, models.ErrMalformedRequest
} }
@ -64,24 +81,23 @@ func NewAnnounce(cfg *config.Config, r *http.Request, p httprouter.Params) (*mod
} }
return &models.Announce{ return &models.Announce{
Config: cfg, Config: s.config,
Compact: compact, Compact: compact,
Downloaded: downloaded, Downloaded: downloaded,
Event: event, Event: event,
IPv4: ipv4, IPv4: ipv4Endpoint,
IPv6: ipv6, IPv6: ipv6Endpoint,
Infohash: infohash, Infohash: infohash,
Left: left, Left: left,
NumWant: numWant, NumWant: numWant,
Passkey: p.ByName("passkey"), Passkey: p.ByName("passkey"),
PeerID: peerID, PeerID: peerID,
Port: port,
Uploaded: uploaded, Uploaded: uploaded,
}, nil }, nil
} }
// NewScrape parses an HTTP request and generates a models.Scrape. // newScrape parses an HTTP request and generates a models.Scrape.
func NewScrape(cfg *config.Config, r *http.Request, p httprouter.Params) (*models.Scrape, error) { func (s *Server) newScrape(r *http.Request, p httprouter.Params) (*models.Scrape, error) {
q, err := query.New(r.URL.RawQuery) q, err := query.New(r.URL.RawQuery)
if err != nil { if err != nil {
return nil, err return nil, err
@ -96,7 +112,7 @@ func NewScrape(cfg *config.Config, r *http.Request, p httprouter.Params) (*model
} }
return &models.Scrape{ return &models.Scrape{
Config: cfg, Config: s.config,
Passkey: p.ByName("passkey"), Passkey: p.ByName("passkey"),
Infohashes: q.Infohashes, Infohashes: q.Infohashes,
@ -116,26 +132,26 @@ func requestedPeerCount(q *query.Query, fallback int) int {
return fallback return fallback
} }
// requestedIP returns the IP addresses for a request. If there are multiple // requestedEndpoint returns the IP address and port pairs for a request. If
// IP addresses in the request, one IPv4 and one IPv6 will be returned. // there are multiple in the request, one IPv4 and one IPv6 will be returned.
func requestedIP(q *query.Query, r *http.Request, cfg *config.NetConfig) (v4, v6 net.IP, err error) { func requestedEndpoint(q *query.Query, r *http.Request, cfg *config.NetConfig) (v4, v6 *models.Endpoint, err error) {
var done bool var done bool
if cfg.AllowIPSpoofing { if cfg.AllowIPSpoofing {
if str, ok := q.Params["ip"]; ok { if str, ok := q.Params["ip"]; ok {
if v4, v6, done = getIPs(str, v4, v6, cfg); done { if v4, v6, done = getEndpoints(str, v4, v6, cfg); done {
return return
} }
} }
if str, ok := q.Params["ipv4"]; ok { if str, ok := q.Params["ipv4"]; ok {
if v4, v6, done = getIPs(str, v4, v6, cfg); done { if v4, v6, done = getEndpoints(str, v4, v6, cfg); done {
return return
} }
} }
if str, ok := q.Params["ipv6"]; ok { if str, ok := q.Params["ipv6"]; ok {
if v4, v6, done = getIPs(str, v4, v6, cfg); done { if v4, v6, done = getEndpoints(str, v4, v6, cfg); done {
return return
} }
} }
@ -143,47 +159,49 @@ func requestedIP(q *query.Query, r *http.Request, cfg *config.NetConfig) (v4, v6
if cfg.RealIPHeader != "" { if cfg.RealIPHeader != "" {
if xRealIPs, ok := r.Header[cfg.RealIPHeader]; ok { if xRealIPs, ok := r.Header[cfg.RealIPHeader]; ok {
if v4, v6, done = getIPs(string(xRealIPs[0]), v4, v6, cfg); done { if v4, v6, done = getEndpoints(string(xRealIPs[0]), v4, v6, cfg); done {
return return
} }
} }
} else { } else {
if r.RemoteAddr == "" { if r.RemoteAddr == "" && v4 == nil {
if v4 == nil { if v4, v6, done = getEndpoints("127.0.0.1", v4, v6, cfg); done {
v4 = net.ParseIP("127.0.0.1")
}
return
}
var host string
host, _, err = net.SplitHostPort(r.RemoteAddr)
if err == nil && host != "" {
if v4, v6, done = getIPs(host, v4, v6, cfg); done {
return return
} }
} }
if v4, v6, done = getEndpoints(r.RemoteAddr, v4, v6, cfg); done {
return
}
} }
if v4 == nil && v6 == nil { if v4 == nil && v6 == nil {
err = errors.New("failed to parse IP address") err = errors.New("failed to parse IP address")
} }
return return
} }
func getIPs(ipstr string, ipv4, ipv6 net.IP, cfg *config.NetConfig) (net.IP, net.IP, bool) { func getEndpoints(ipstr string, ipv4, ipv6 *models.Endpoint, cfg *config.NetConfig) (*models.Endpoint, *models.Endpoint, bool) {
var done bool host, port, err := net.SplitHostPort(ipstr)
if err != nil {
host = ipstr
}
if ip := net.ParseIP(ipstr); ip != nil { // We can ignore this error, because ports that are 0 are assumed to be the
newIPv4 := ip.To4() // port parameter provided in the "port" param of the announce request.
parsedPort, _ := strconv.ParseUint(port, 10, 16)
if ipv4 == nil && newIPv4 != nil { if ip := net.ParseIP(host); ip != nil {
ipv4 = newIPv4 ipTo4 := ip.To4()
} else if ipv6 == nil && newIPv4 == nil { if ipv4 == nil && ipTo4 != nil {
ipv6 = ip ipv4 = &models.Endpoint{ipTo4, uint16(parsedPort)}
} else if ipv6 == nil && ipTo4 == nil {
ipv6 = &models.Endpoint{ip, uint16(parsedPort)}
} }
} }
var done bool
if cfg.DualStackedPeers { if cfg.DualStackedPeers {
done = ipv4 != nil && ipv6 != nil done = ipv4 != nil && ipv6 != nil
} else { } else {

View file

@ -271,15 +271,21 @@ func (s *Stats) handlePeerEvent(ps *PeerStats, event int) {
// RecordEvent broadcasts an event to the default stats queue. // RecordEvent broadcasts an event to the default stats queue.
func RecordEvent(event int) { func RecordEvent(event int) {
DefaultStats.RecordEvent(event) if DefaultStats != nil {
DefaultStats.RecordEvent(event)
}
} }
// RecordPeerEvent broadcasts a peer event to the default stats queue. // RecordPeerEvent broadcasts a peer event to the default stats queue.
func RecordPeerEvent(event int, ipv6 bool) { func RecordPeerEvent(event int, ipv6 bool) {
DefaultStats.RecordPeerEvent(event, ipv6) if DefaultStats != nil {
DefaultStats.RecordPeerEvent(event, ipv6)
}
} }
// RecordTiming broadcasts a timing event to the default stats queue. // RecordTiming broadcasts a timing event to the default stats queue.
func RecordTiming(event int, duration time.Duration) { func RecordTiming(event int, duration time.Duration) {
DefaultStats.RecordTiming(event, duration) if DefaultStats != nil {
DefaultStats.RecordTiming(event, duration)
}
} }

View file

@ -70,6 +70,7 @@ func (tkr *Tracker) HandleAnnounce(ann *models.Announce, w Writer) (err error) {
stats.RecordEvent(stats.DeletedTorrent) stats.RecordEvent(stats.DeletedTorrent)
} }
stats.RecordEvent(stats.Announce)
return w.WriteAnnounce(newAnnounceResponse(ann)) return w.WriteAnnounce(newAnnounceResponse(ann))
} }
@ -263,6 +264,7 @@ func newAnnounceResponse(ann *models.Announce) *models.AnnounceResponse {
leechCount := ann.Torrent.Leechers.Len() leechCount := ann.Torrent.Leechers.Len()
res := &models.AnnounceResponse{ res := &models.AnnounceResponse{
Announce: ann,
Complete: seedCount, Complete: seedCount,
Incomplete: leechCount, Incomplete: leechCount,
Interval: ann.Config.Announce.Duration, Interval: ann.Config.Announce.Duration,
@ -272,6 +274,10 @@ func newAnnounceResponse(ann *models.Announce) *models.AnnounceResponse {
if ann.NumWant > 0 && ann.Event != "stopped" && ann.Event != "paused" { if ann.NumWant > 0 && ann.Event != "stopped" && ann.Event != "paused" {
res.IPv4Peers, res.IPv6Peers = getPeers(ann) res.IPv4Peers, res.IPv6Peers = getPeers(ann)
if len(res.IPv4Peers)+len(res.IPv6Peers) == 0 {
models.AppendPeer(&res.IPv4Peers, &res.IPv6Peers, ann, ann.Peer)
}
} }
return res return res

View file

@ -8,7 +8,6 @@ package models
import ( import (
"net" "net"
"strconv"
"strings" "strings"
"time" "time"
@ -40,9 +39,19 @@ var (
type ClientError string type ClientError string
type NotFoundError ClientError type NotFoundError ClientError
type ProtocolError ClientError
func (e ClientError) Error() string { return string(e) } func (e ClientError) Error() string { return string(e) }
func (e NotFoundError) Error() string { return string(e) } func (e NotFoundError) Error() string { return string(e) }
func (e ProtocolError) Error() string { return string(e) }
// IsPublicError determines whether an error should be propogated to the client.
func IsPublicError(err error) bool {
_, cl := err.(ClientError)
_, nf := err.(NotFoundError)
_, pc := err.(ProtocolError)
return cl || nf || pc
}
// PeerList represents a list of peers: either seeders or leechers. // PeerList represents a list of peers: either seeders or leechers.
type PeerList []Peer type PeerList []Peer
@ -51,8 +60,8 @@ type PeerList []Peer
type PeerKey string type PeerKey string
// NewPeerKey creates a properly formatted PeerKey. // NewPeerKey creates a properly formatted PeerKey.
func NewPeerKey(peerID string, ip net.IP, port string) PeerKey { func NewPeerKey(peerID string, ip net.IP) PeerKey {
return PeerKey(peerID + "//" + ip.String() + ":" + port) return PeerKey(peerID + "//" + ip.String())
} }
// IP parses and returns the IP address for a given PeerKey. // IP parses and returns the IP address for a given PeerKey.
@ -69,26 +78,23 @@ func (pk PeerKey) PeerID() string {
return strings.Split(string(pk), "//")[0] return strings.Split(string(pk), "//")[0]
} }
// Port returns the port section of the PeerKey. // Endpoint is an IP and port pair.
func (pk PeerKey) Port() string { type Endpoint struct {
return strings.Split(string(pk), "//")[2] // Always has length net.IPv4len if IPv4, and net.IPv6len if IPv6
IP net.IP `json:"ip"`
Port uint16 `json:"port"`
} }
// Peer is a participant in a swarm. // Peer is a participant in a swarm.
type Peer struct { type Peer struct {
ID string `json:"id"` ID string `json:"id"`
UserID uint64 `json:"user_id"` UserID uint64 `json:"user_id"`
TorrentID uint64 `json:"torrent_id"` TorrentID uint64 `json:"torrent_id"`
// Always has length net.IPv4len if IPv4, and net.IPv6len if IPv6
IP net.IP `json:"ip,omitempty"`
Port uint64 `json:"port"`
Uploaded uint64 `json:"uploaded"` Uploaded uint64 `json:"uploaded"`
Downloaded uint64 `json:"downloaded"` Downloaded uint64 `json:"downloaded"`
Left uint64 `json:"left"` Left uint64 `json:"left"`
LastAnnounce int64 `json:"last_announce"` LastAnnounce int64 `json:"last_announce"`
Endpoint
} }
// HasIPv4 determines if a peer's IP address can be represented as an IPv4 // HasIPv4 determines if a peer's IP address can be represented as an IPv4
@ -105,7 +111,7 @@ func (p *Peer) HasIPv6() bool {
// Key returns a PeerKey for the given peer. // Key returns a PeerKey for the given peer.
func (p *Peer) Key() PeerKey { func (p *Peer) Key() PeerKey {
return NewPeerKey(p.ID, p.IP, strconv.FormatUint(p.Port, 10)) return NewPeerKey(p.ID, p.IP)
} }
// Torrent is a swarm for a given torrent file. // Torrent is a swarm for a given torrent file.
@ -140,18 +146,17 @@ type User struct {
type Announce struct { type Announce struct {
Config *config.Config `json:"config"` Config *config.Config `json:"config"`
Compact bool `json:"compact"` Compact bool `json:"compact"`
Downloaded uint64 `json:"downloaded"` Downloaded uint64 `json:"downloaded"`
Event string `json:"event"` Event string `json:"event"`
IPv4 net.IP `json:"ipv4"` IPv4 Endpoint `json:"ipv4"`
IPv6 net.IP `json:"ipv6"` IPv6 Endpoint `json:"ipv6"`
Infohash string `json:"infohash"` Infohash string `json:"infohash"`
Left uint64 `json:"left"` Left uint64 `json:"left"`
NumWant int `json:"numwant"` NumWant int `json:"numwant"`
Passkey string `json:"passkey"` Passkey string `json:"passkey"`
PeerID string `json:"peer_id"` PeerID string `json:"peer_id"`
Port uint64 `json:"port"` Uploaded uint64 `json:"uploaded"`
Uploaded uint64 `json:"uploaded"`
Torrent *Torrent `json:"-"` Torrent *Torrent `json:"-"`
User *User `json:"-"` User *User `json:"-"`
@ -177,12 +182,14 @@ func (a *Announce) ClientID() (clientID string) {
return return
} }
// HasIPv4 determines whether or not an announce has an IPv4 endpoint.
func (a *Announce) HasIPv4() bool { func (a *Announce) HasIPv4() bool {
return a.IPv4 != nil return a.IPv4.IP != nil
} }
// HasIPv6 determines whether or not an announce has an IPv6 endpoint.
func (a *Announce) HasIPv6() bool { func (a *Announce) HasIPv6() bool {
return a.IPv6 != nil return a.IPv6.IP != nil
} }
// BuildPeer creates the Peer representation of an Announce. When provided nil // BuildPeer creates the Peer representation of an Announce. When provided nil
@ -192,7 +199,6 @@ func (a *Announce) HasIPv6() bool {
func (a *Announce) BuildPeer(u *User, t *Torrent) { func (a *Announce) BuildPeer(u *User, t *Torrent) {
a.Peer = &Peer{ a.Peer = &Peer{
ID: a.PeerID, ID: a.PeerID,
Port: a.Port,
Uploaded: a.Uploaded, Uploaded: a.Uploaded,
Downloaded: a.Downloaded, Downloaded: a.Downloaded,
Left: a.Left, Left: a.Left,
@ -211,15 +217,15 @@ func (a *Announce) BuildPeer(u *User, t *Torrent) {
if a.HasIPv4() && a.HasIPv6() { if a.HasIPv4() && a.HasIPv6() {
a.PeerV4 = a.Peer a.PeerV4 = a.Peer
a.PeerV4.IP = a.IPv4 a.PeerV4.Endpoint = a.IPv4
a.PeerV6 = &*a.Peer a.PeerV6 = &*a.Peer
a.PeerV6.IP = a.IPv6 a.PeerV6.Endpoint = a.IPv6
} else if a.HasIPv4() { } else if a.HasIPv4() {
a.PeerV4 = a.Peer a.PeerV4 = a.Peer
a.PeerV4.IP = a.IPv4 a.PeerV4.Endpoint = a.IPv4
} else if a.HasIPv6() { } else if a.HasIPv6() {
a.PeerV6 = a.Peer a.PeerV6 = a.Peer
a.PeerV6.IP = a.IPv6 a.PeerV6.Endpoint = a.IPv6
} else { } else {
panic("models: announce must have an IP") panic("models: announce must have an IP")
} }
@ -250,6 +256,7 @@ type AnnounceDelta struct {
// AnnounceResponse contains the information needed to fulfill an announce. // AnnounceResponse contains the information needed to fulfill an announce.
type AnnounceResponse struct { type AnnounceResponse struct {
Announce *Announce
Complete, Incomplete int Complete, Incomplete int
Interval, MinInterval time.Duration Interval, MinInterval time.Duration
IPv4Peers, IPv6Peers PeerList IPv4Peers, IPv6Peers PeerList

View file

@ -159,7 +159,7 @@ func (pm *PeerMap) AppendPeers(ipv4s, ipv6s PeerList, ann *Announce, wanted int)
} else if peersEquivalent(&peer, ann.Peer) { } else if peersEquivalent(&peer, ann.Peer) {
continue continue
} else { } else {
appendPeer(&ipv4s, &ipv6s, ann, &peer, &count) count += AppendPeer(&ipv4s, &ipv6s, ann, &peer)
} }
} }
@ -174,7 +174,7 @@ func (pm *PeerMap) AppendPeers(ipv4s, ipv6s PeerList, ann *Announce, wanted int)
} else if peersEquivalent(&peer, ann.Peer) { } else if peersEquivalent(&peer, ann.Peer) {
continue continue
} else { } else {
appendPeer(&ipv4s, &ipv6s, ann, &peer, &count) count += AppendPeer(&ipv4s, &ipv6s, ann, &peer)
} }
} }
} }
@ -183,18 +183,20 @@ func (pm *PeerMap) AppendPeers(ipv4s, ipv6s PeerList, ann *Announce, wanted int)
return ipv4s, ipv6s return ipv4s, ipv6s
} }
// appendPeer adds a peer to its corresponding peerlist. // AppendPeer adds a peer to its corresponding peerlist.
func appendPeer(ipv4s, ipv6s *PeerList, ann *Announce, peer *Peer, count *int) { func AppendPeer(ipv4s, ipv6s *PeerList, ann *Announce, peer *Peer) int {
if ann.HasIPv6() && peer.HasIPv6() { if ann.HasIPv6() && peer.HasIPv6() {
*ipv6s = append(*ipv6s, *peer) *ipv6s = append(*ipv6s, *peer)
*count++ return 1
} else if ann.Config.RespectAF && ann.HasIPv4() && peer.HasIPv4() { } else if ann.Config.RespectAF && ann.HasIPv4() && peer.HasIPv4() {
*ipv4s = append(*ipv4s, *peer) *ipv4s = append(*ipv4s, *peer)
*count++ return 1
} else if !ann.Config.RespectAF && peer.HasIPv4() { } else if !ann.Config.RespectAF && peer.HasIPv4() {
*ipv4s = append(*ipv4s, *peer) *ipv4s = append(*ipv4s, *peer)
*count++ return 1
} }
return 0
} }
// peersEquivalent checks if two peers represent the same entity. // peersEquivalent checks if two peers represent the same entity.

View file

@ -4,7 +4,10 @@
package tracker package tracker
import "github.com/chihaya/chihaya/tracker/models" import (
"github.com/chihaya/chihaya/stats"
"github.com/chihaya/chihaya/tracker/models"
)
// HandleScrape encapsulates all the logic of handling a BitTorrent client's // HandleScrape encapsulates all the logic of handling a BitTorrent client's
// scrape without being coupled to any transport protocol. // scrape without being coupled to any transport protocol.
@ -24,6 +27,7 @@ func (tkr *Tracker) HandleScrape(scrape *models.Scrape, w Writer) (err error) {
torrents = append(torrents, torrent) torrents = append(torrents, torrent)
} }
stats.RecordEvent(stats.Scrape)
return w.WriteScrape(&models.ScrapeResponse{ return w.WriteScrape(&models.ScrapeResponse{
Files: torrents, Files: torrents,
}) })

View file

@ -24,6 +24,15 @@ type Tracker struct {
*Storage *Storage
} }
// Server represents a server for a given BitTorrent tracker protocol.
type Server interface {
// Serve runs the server and blocks until the server has shut down.
Serve(addr string)
// Stop cleanly shuts down the server in a non-blocking manner.
Stop()
}
// New creates a new Tracker, and opens any necessary connections. // New creates a new Tracker, and opens any necessary connections.
// Maintenance routines are automatically spawned in the background. // Maintenance routines are automatically spawned in the background.
func New(cfg *config.Config) (*Tracker, error) { func New(cfg *config.Config) (*Tracker, error) {
@ -64,7 +73,8 @@ func (tkr *Tracker) LoadApprovedClients(clients []string) {
} }
// Writer serializes a tracker's responses, and is implemented for each // Writer serializes a tracker's responses, and is implemented for each
// response transport used by the tracker. // response transport used by the tracker. Only one of these may be called
// per request, and only once.
// //
// Note, data passed into any of these functions will not contain sensitive // Note, data passed into any of these functions will not contain sensitive
// information, so it may be passed back the client freely. // information, so it may be passed back the client freely.

87
udp/announce_test.go Normal file
View file

@ -0,0 +1,87 @@
// Copyright 2015 The Chihaya Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
package udp
import (
"bytes"
"encoding/binary"
"net"
"testing"
"github.com/chihaya/chihaya/config"
)
func requestAnnounce(sock *net.UDPConn, connID []byte, hash string) ([]byte, error) {
txID := makeTransactionID()
peerID := []byte("-UT2210-b4a2h9a9f5c4")
var request []byte
request = append(request, connID...)
request = append(request, announceAction...)
request = append(request, txID...)
request = append(request, []byte(hash)...)
request = append(request, peerID...)
request = append(request, make([]byte, 8)...) // Downloaded
request = append(request, make([]byte, 8)...) // Left
request = append(request, make([]byte, 8)...) // Uploaded
request = append(request, make([]byte, 4)...) // Event
request = append(request, make([]byte, 4)...) // IP
request = append(request, make([]byte, 4)...) // Key
request = append(request, make([]byte, 4)...) // NumWant
request = append(request, make([]byte, 2)...) // Port
return doRequest(sock, request, txID)
}
func TestAnnounce(t *testing.T) {
srv, done, err := setupTracker(&config.DefaultConfig)
if err != nil {
t.Fatal(err)
}
_, sock, err := setupSocket()
if err != nil {
t.Fatal(err)
}
connID, err := requestConnectionID(sock)
if err != nil {
t.Fatal(err)
}
announce, err := requestAnnounce(sock, connID, "aaaaaaaaaaaaaaaaaaaa")
if err != nil {
t.Fatal(err)
}
// Parse the response.
var action, txID, interval, leechers, seeders uint32
buf := bytes.NewReader(announce)
binary.Read(buf, binary.BigEndian, &action)
binary.Read(buf, binary.BigEndian, &txID)
binary.Read(buf, binary.BigEndian, &interval)
binary.Read(buf, binary.BigEndian, &leechers)
binary.Read(buf, binary.BigEndian, &seeders)
if action != uint32(announceActionID) {
t.Fatal("expected announce action")
}
if interval != uint32(config.DefaultConfig.Announce.Seconds()) {
t.Fatal("incorrect interval")
}
if leechers != uint32(0) {
t.Fatal("incorrect leecher count")
}
// We're the only seeder.
if seeders != uint32(1) {
t.Fatal("incorrect seeder count")
}
srv.Stop()
<-done
}

90
udp/connection.go Normal file
View file

@ -0,0 +1,90 @@
// Copyright 2015 The Chihaya Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
package udp
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"net"
)
// ConnectionIDGenerator represents the logic to generate 64-bit UDP
// connection IDs from peer IP addresses.
type ConnectionIDGenerator struct {
iv, iv2 []byte
block cipher.Block
}
// NewConnectionIDGenerator creates a ConnectionIDGenerator and generates its
// AES key and first initialization vector.
func NewConnectionIDGenerator() (gen *ConnectionIDGenerator, err error) {
gen = &ConnectionIDGenerator{}
key := make([]byte, 16)
_, err = rand.Read(key)
if err != nil {
return
}
gen.block, err = aes.NewCipher(key)
if err != nil {
return
}
err = gen.NewIV()
return
}
// Generate returns the 64-bit connection ID for an IP
func (g *ConnectionIDGenerator) Generate(ip net.IP) []byte {
return g.generate(ip, g.iv)
}
func (g *ConnectionIDGenerator) generate(ip net.IP, iv []byte) []byte {
for len(ip) < 8 {
ip = append(ip, ip...) // Not enough bits in output.
}
ct := make([]byte, 16)
stream := cipher.NewCFBDecrypter(g.block, iv)
stream.XORKeyStream(ct, ip)
for i := len(ip) - 1; i >= 8; i-- {
ct[i-8] ^= ct[i]
}
return ct[:8]
}
// Matches checks if the given connection ID matches an IP with the current or
// previous initialization vectors.
func (g *ConnectionIDGenerator) Matches(id []byte, ip net.IP) bool {
if expected := g.generate(ip, g.iv); bytes.Equal(id, expected) {
return true
}
if iv2 := g.iv2; iv2 != nil {
if expected := g.generate(ip, iv2); bytes.Equal(id, expected) {
return true
}
}
return false
}
// NewIV generates a new initialization vector and rotates the current one.
func (g *ConnectionIDGenerator) NewIV() error {
newiv := make([]byte, 16)
if _, err := rand.Read(newiv); err != nil {
return err
}
g.iv2 = g.iv
g.iv = newiv
return nil
}

70
udp/connection_test.go Normal file
View file

@ -0,0 +1,70 @@
// Copyright 2015 The Chihaya Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
package udp
import (
"bytes"
"net"
"testing"
)
func TestInitReturnsNoError(t *testing.T) {
if _, err := NewConnectionIDGenerator(); err != nil {
t.Error("Init returned", err)
}
}
func testGenerateConnectionID(t *testing.T, ip net.IP) {
gen, _ := NewConnectionIDGenerator()
id1 := gen.Generate(ip)
id2 := gen.Generate(ip)
if !bytes.Equal(id1, id2) {
t.Errorf("Connection ID mismatch: %x != %x", id1, id2)
}
if len(id1) != 8 {
t.Errorf("Connection ID had length: %d != 8", len(id1))
}
if bytes.Count(id1, []byte{0}) == 8 {
t.Errorf("Connection ID was 0")
}
}
func TestGenerateConnectionIDIPv4(t *testing.T) {
testGenerateConnectionID(t, net.ParseIP("192.168.1.123").To4())
}
func TestGenerateConnectionIDIPv6(t *testing.T) {
testGenerateConnectionID(t, net.ParseIP("1:2:3:4::5:6"))
}
func TestMatchesWorksWithPreviousIV(t *testing.T) {
gen, _ := NewConnectionIDGenerator()
ip := net.ParseIP("192.168.1.123").To4()
id1 := gen.Generate(ip)
if !gen.Matches(id1, ip) {
t.Errorf("Connection ID mismatch for current IV")
}
gen.NewIV()
if !gen.Matches(id1, ip) {
t.Errorf("Connection ID mismatch for previous IV")
}
id2 := gen.Generate(ip)
gen.NewIV()
if gen.Matches(id1, ip) {
t.Errorf("Connection ID matched for discarded IV")
}
if !gen.Matches(id2, ip) {
t.Errorf("Connection ID mismatch for previous IV")
}
}

252
udp/protocol.go Normal file
View file

@ -0,0 +1,252 @@
// Copyright 2015 The Chihaya Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
package udp
import (
"bytes"
"encoding/binary"
"net"
"github.com/chihaya/chihaya/stats"
"github.com/chihaya/chihaya/tracker/models"
)
const (
connectActionID uint32 = iota
announceActionID
scrapeActionID
errorActionID
announceDualStackActionID
)
var (
// initialConnectionID is the magic initial connection ID specified by BEP 15.
initialConnectionID = []byte{0, 0, 0x04, 0x17, 0x27, 0x10, 0x19, 0x80}
// emptyIPs are the value of an IP field that has been left blank.
emptyIPv4 = []byte{0, 0, 0, 0}
emptyIPv6 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// Option-Types described in BEP41 and BEP45.
optionEndOfOptions = byte(0x0)
optionNOP = byte(0x1)
optionURLData = byte(0x2)
optionIPv6 = byte(0x3)
// eventIDs map IDs to event names.
eventIDs = []string{
"",
"completed",
"started",
"stopped",
}
errMalformedPacket = models.ProtocolError("malformed packet")
errMalformedIP = models.ProtocolError("malformed IP address")
errMalformedEvent = models.ProtocolError("malformed event ID")
errBadConnectionID = models.ProtocolError("bad connection ID")
)
// handleTorrentError writes err to w if err is a models.ClientError.
func handleTorrentError(err error, w *Writer) {
if err == nil {
return
}
if models.IsPublicError(err) {
w.WriteError(err)
stats.RecordEvent(stats.ClientError)
}
}
// handlePacket decodes and processes one UDP request, returning the response.
func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte, actionName string) {
if len(packet) < 16 {
return // Malformed, no client packets are less than 16 bytes.
}
connID := packet[0:8]
action := binary.BigEndian.Uint32(packet[8:12])
transactionID := packet[12:16]
writer := &Writer{
buf: new(bytes.Buffer),
connectionID: connID,
transactionID: transactionID,
}
defer func() { response = writer.buf.Bytes() }()
if action != 0 && !s.connIDGen.Matches(connID, addr.IP) {
writer.WriteError(errBadConnectionID)
return
}
switch action {
case connectActionID:
actionName = "connect"
if !bytes.Equal(connID, initialConnectionID) {
return // Malformed packet.
}
writer.writeHeader(0)
writer.buf.Write(s.connIDGen.Generate(addr.IP))
case announceActionID:
actionName = "announce"
ann, err := s.newAnnounce(packet, addr.IP)
if err == nil {
err = s.tracker.HandleAnnounce(ann, writer)
}
handleTorrentError(err, writer)
case scrapeActionID:
actionName = "scrape"
scrape, err := s.newScrape(packet)
if err == nil {
err = s.tracker.HandleScrape(scrape, writer)
}
handleTorrentError(err, writer)
}
return
}
// newAnnounce decodes one announce packet, returning a models.Announce.
func (s *Server) newAnnounce(packet []byte, ip net.IP) (*models.Announce, error) {
if len(packet) < 98 {
return nil, errMalformedPacket
}
infohash := packet[16:36]
peerID := packet[36:56]
downloaded := binary.BigEndian.Uint64(packet[56:64])
left := binary.BigEndian.Uint64(packet[64:72])
uploaded := binary.BigEndian.Uint64(packet[72:80])
eventID := packet[83]
if eventID > 3 {
return nil, errMalformedEvent
}
ipv4bytes := packet[84:88]
if s.config.AllowIPSpoofing && !bytes.Equal(ipv4bytes, emptyIPv4) {
ip = net.ParseIP(string(ipv4bytes))
}
if ip == nil {
return nil, errMalformedIP
} else if ipv4 := ip.To4(); ipv4 != nil {
ip = ipv4
}
numWant := binary.BigEndian.Uint32(packet[92:96])
port := binary.BigEndian.Uint16(packet[96:98])
announce := &models.Announce{
Config: s.config,
Downloaded: downloaded,
Event: eventIDs[eventID],
IPv4: models.Endpoint{
IP: ip,
Port: port,
},
Infohash: string(infohash),
Left: left,
NumWant: int(numWant),
PeerID: string(peerID),
Uploaded: uploaded,
}
if err := s.handleOptionalParameters(packet, announce); err != nil {
return nil, err
}
return announce, nil
}
// handleOptionalParameters parses the optional parameters as described in BEP41
// and updates an announce with the values parsed.
func (s *Server) handleOptionalParameters(packet []byte, announce *models.Announce) error {
if len(packet) > 98 {
optionStartIndex := 98
for optionStartIndex < len(packet)-1 {
option := packet[optionStartIndex]
switch option {
case optionEndOfOptions:
return nil
case optionNOP:
optionStartIndex++
case optionURLData:
if optionStartIndex+1 > len(packet)-1 {
return errMalformedPacket
}
length := int(packet[optionStartIndex+1])
if optionStartIndex+1+length > len(packet)-1 {
return errMalformedPacket
}
// TODO: Actually parse the URL Data as described in BEP41.
optionStartIndex += 1 + length
case optionIPv6:
if optionStartIndex+19 > len(packet)-1 {
return errMalformedPacket
}
ipv6bytes := packet[optionStartIndex+1 : optionStartIndex+17]
if s.config.AllowIPSpoofing && !bytes.Equal(ipv6bytes, emptyIPv6) {
announce.IPv6.IP = net.ParseIP(string(ipv6bytes)).To16()
announce.IPv6.Port = binary.BigEndian.Uint16(packet[optionStartIndex+17 : optionStartIndex+19])
if announce.IPv6.IP == nil {
return errMalformedIP
}
}
optionStartIndex += 19
default:
return nil
}
}
}
// There was no optional parameters to parse.
return nil
}
// newScrape decodes one announce packet, returning a models.Scrape.
func (s *Server) newScrape(packet []byte) (*models.Scrape, error) {
if len(packet) < 36 {
return nil, errMalformedPacket
}
var infohashes []string
packet = packet[16:]
if len(packet)%20 != 0 {
return nil, errMalformedPacket
}
for len(packet) >= 20 {
infohash := packet[:20]
infohashes = append(infohashes, string(infohash))
packet = packet[20:]
}
return &models.Scrape{
Config: s.config,
Infohashes: infohashes,
}, nil
}

77
udp/scrape_test.go Normal file
View file

@ -0,0 +1,77 @@
// Copyright 2015 The Chihaya Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
package udp
import (
"bytes"
"fmt"
"net"
"testing"
"github.com/chihaya/chihaya/config"
)
func doRequest(sock *net.UDPConn, request, txID []byte) ([]byte, error) {
response := make([]byte, 1024)
n, err := sendRequest(sock, request, response)
if err != nil {
return nil, err
}
if !bytes.Equal(response[4:8], txID) {
return nil, fmt.Errorf("transaction ID mismatch")
}
return response[:n], nil
}
func requestScrape(sock *net.UDPConn, connID []byte, hashes []string) ([]byte, error) {
txID := makeTransactionID()
var request []byte
request = append(request, connID...)
request = append(request, scrapeAction...)
request = append(request, txID...)
for _, hash := range hashes {
request = append(request, []byte(hash)...)
}
return doRequest(sock, request, txID)
}
func TestScrapeEmpty(t *testing.T) {
srv, done, err := setupTracker(&config.DefaultConfig)
if err != nil {
t.Fatal(err)
}
_, sock, err := setupSocket()
if err != nil {
t.Fatal(err)
}
connID, err := requestConnectionID(sock)
if err != nil {
t.Fatal(err)
}
scrape, err := requestScrape(sock, connID, []string{"aaaaaaaaaaaaaaaaaaaa"})
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(scrape[:4], errorAction) {
t.Error("expected error response")
}
if string(scrape[8:]) != "torrent does not exist\000" {
t.Error("expected torrent to not exist")
}
srv.Stop()
<-done
}

128
udp/udp.go Normal file
View file

@ -0,0 +1,128 @@
// Copyright 2015 The Chihaya Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
// Package udp implements a UDP BitTorrent tracker per BEP 15.
// IPv6 is currently unsupported as there is no widely-implemented standard.
package udp
import (
"errors"
"net"
"time"
"github.com/golang/glog"
"github.com/pushrax/bufferpool"
"github.com/chihaya/chihaya/config"
"github.com/chihaya/chihaya/tracker"
)
// Server represents a UDP torrent tracker.
type Server struct {
config *config.Config
tracker *tracker.Tracker
done bool
booting chan struct{}
sock *net.UDPConn
connIDGen *ConnectionIDGenerator
}
func (s *Server) serve(listenAddr string) error {
if s.sock != nil {
return errors.New("server already booted")
}
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
close(s.booting)
return err
}
sock, err := net.ListenUDP("udp", udpAddr)
defer sock.Close()
if err != nil {
close(s.booting)
return err
}
if s.config.UDPReadBufferSize > 0 {
sock.SetReadBuffer(s.config.UDPReadBufferSize)
}
pool := bufferpool.New(1000, 2048)
s.sock = sock
close(s.booting)
for !s.done {
buffer := pool.TakeSlice()
sock.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := sock.ReadFromUDP(buffer)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
pool.GiveSlice(buffer)
continue
}
return err
}
start := time.Now()
go func() {
response, action := s.handlePacket(buffer[:n], addr)
pool.GiveSlice(buffer)
if len(response) > 0 {
sock.WriteToUDP(response, addr)
}
if glog.V(2) {
duration := time.Since(start)
glog.Infof("[UDP - %9s] %s %s", duration, action, addr)
}
}()
}
return nil
}
// Serve runs a UDP server, blocking until the server has shut down.
func (s *Server) Serve(addr string) {
glog.V(0).Info("Starting UDP on ", addr)
go func() {
// Generate a new IV every hour.
for range time.Tick(time.Hour) {
s.connIDGen.NewIV()
}
}()
if err := s.serve(addr); err != nil {
glog.Errorf("Failed to run UDP server: %s", err.Error())
} else {
glog.Info("UDP server shut down cleanly")
}
}
// Stop cleanly shuts down the server.
func (s *Server) Stop() {
s.done = true
s.sock.SetReadDeadline(time.Now())
}
// NewServer returns a new UDP server for a given configuration and tracker.
func NewServer(cfg *config.Config, tkr *tracker.Tracker) *Server {
gen, err := NewConnectionIDGenerator()
if err != nil {
panic(err)
}
return &Server{
config: cfg,
tracker: tkr,
connIDGen: gen,
booting: make(chan struct{}),
}
}

132
udp/udp_test.go Normal file
View file

@ -0,0 +1,132 @@
// Copyright 2015 The Chihaya Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
package udp
import (
"bytes"
"crypto/rand"
"fmt"
"net"
"testing"
"time"
"github.com/chihaya/chihaya/config"
"github.com/chihaya/chihaya/tracker"
_ "github.com/chihaya/chihaya/backend/noop"
)
var (
testPort = "34137"
connectAction = []byte{0, 0, 0, byte(connectActionID)}
announceAction = []byte{0, 0, 0, byte(announceActionID)}
scrapeAction = []byte{0, 0, 0, byte(scrapeActionID)}
errorAction = []byte{0, 0, 0, byte(errorActionID)}
)
func setupTracker(cfg *config.Config) (*Server, chan struct{}, error) {
tkr, err := tracker.New(cfg)
if err != nil {
return nil, nil, err
}
srv := NewServer(cfg, tkr)
done := make(chan struct{})
go func() {
if err := srv.serve(":" + testPort); err != nil {
panic(err)
}
close(done)
}()
<-srv.booting
return srv, done, nil
}
func setupSocket() (*net.UDPAddr, *net.UDPConn, error) {
srvAddr, err := net.ResolveUDPAddr("udp", "localhost:"+testPort)
if err != nil {
return nil, nil, err
}
sock, err := net.DialUDP("udp", nil, srvAddr)
if err != nil {
return nil, nil, err
}
return srvAddr, sock, err
}
func makeTransactionID() []byte {
out := make([]byte, 4)
rand.Read(out)
return out
}
func sendRequest(sock *net.UDPConn, request, response []byte) (int, error) {
if _, err := sock.Write(request); err != nil {
return 0, err
}
sock.SetReadDeadline(time.Now().Add(time.Second))
n, err := sock.Read(response)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return 0, fmt.Errorf("no response from tracker: %s", err)
}
}
return n, err
}
func requestConnectionID(sock *net.UDPConn) ([]byte, error) {
txID := makeTransactionID()
request := []byte{}
request = append(request, initialConnectionID...)
request = append(request, connectAction...)
request = append(request, txID...)
response := make([]byte, 1024)
n, err := sendRequest(sock, request, response)
if err != nil {
return nil, err
}
if n != 16 {
return nil, fmt.Errorf("packet length mismatch: %d != 16", n)
}
if !bytes.Equal(response[4:8], txID) {
return nil, fmt.Errorf("transaction ID mismatch")
}
if !bytes.Equal(response[0:4], connectAction) {
return nil, fmt.Errorf("action mismatch")
}
return response[8:16], nil
}
func TestRequestConnectionID(t *testing.T) {
srv, done, err := setupTracker(&config.DefaultConfig)
if err != nil {
t.Fatal(err)
}
_, sock, err := setupSocket()
if err != nil {
t.Fatal(err)
}
if _, err = requestConnectionID(sock); err != nil {
t.Fatal(err)
}
srv.Stop()
<-done
}

97
udp/writer.go Normal file
View file

@ -0,0 +1,97 @@
// Copyright 2015 The Chihaya Authors. All rights reserved.
// Use of this source code is governed by the BSD 2-Clause license,
// which can be found in the LICENSE file.
package udp
import (
"bytes"
"encoding/binary"
"time"
"github.com/chihaya/chihaya/tracker/models"
)
// Writer implements the tracker.Writer interface for the UDP protocol.
type Writer struct {
buf *bytes.Buffer
connectionID []byte
transactionID []byte
}
// WriteError writes the failure reason as a null-terminated string.
func (w *Writer) WriteError(err error) error {
w.writeHeader(errorActionID)
w.buf.WriteString(err.Error())
w.buf.WriteRune('\000')
return nil
}
// WriteAnnounce encodes an announce response by selecting the proper announce
// format based on the BitTorrent spec.
func (w *Writer) WriteAnnounce(resp *models.AnnounceResponse) (err error) {
if resp.Announce.HasIPv6() {
err = w.WriteAnnounceIPv6(resp)
} else {
err = w.WriteAnnounceIPv4(resp)
}
return
}
// WriteAnnounceIPv6 encodes an announce response according to BEP45.
func (w *Writer) WriteAnnounceIPv6(resp *models.AnnounceResponse) error {
w.writeHeader(announceDualStackActionID)
binary.Write(w.buf, binary.BigEndian, uint32(resp.Interval/time.Second))
binary.Write(w.buf, binary.BigEndian, uint32(resp.Incomplete))
binary.Write(w.buf, binary.BigEndian, uint32(resp.Complete))
binary.Write(w.buf, binary.BigEndian, uint32(len(resp.IPv4Peers)))
binary.Write(w.buf, binary.BigEndian, uint32(len(resp.IPv6Peers)))
for _, peer := range resp.IPv4Peers {
w.buf.Write(peer.IP)
binary.Write(w.buf, binary.BigEndian, peer.Port)
}
for _, peer := range resp.IPv6Peers {
w.buf.Write(peer.IP)
binary.Write(w.buf, binary.BigEndian, peer.Port)
}
return nil
}
// WriteAnnounceIPv4 encodes an announce response according to BEP15.
func (w *Writer) WriteAnnounceIPv4(resp *models.AnnounceResponse) error {
w.writeHeader(announceActionID)
binary.Write(w.buf, binary.BigEndian, uint32(resp.Interval/time.Second))
binary.Write(w.buf, binary.BigEndian, uint32(resp.Incomplete))
binary.Write(w.buf, binary.BigEndian, uint32(resp.Complete))
for _, peer := range resp.IPv4Peers {
w.buf.Write(peer.IP)
binary.Write(w.buf, binary.BigEndian, peer.Port)
}
return nil
}
// WriteScrape encodes a scrape response according to BEP15.
func (w *Writer) WriteScrape(resp *models.ScrapeResponse) error {
w.writeHeader(scrapeActionID)
for _, torrent := range resp.Files {
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Seeders.Len()))
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Snatches))
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Leechers.Len()))
}
return nil
}
// writeHeader writes the action and transaction ID to the response.
func (w *Writer) writeHeader(action uint32) {
binary.Write(w.buf, binary.BigEndian, action)
w.buf.Write(w.transactionID)
}