Merge branch 'udp' into develop
This commit is contained in:
commit
a17474bb05
47 changed files with 1815 additions and 556 deletions
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
/config.json
|
||||
/chihaya
|
21
Godeps/Godeps.json
generated
21
Godeps/Godeps.json
generated
|
@ -1,10 +1,10 @@
|
|||
{
|
||||
"ImportPath": "github.com/chihaya/chihaya",
|
||||
"GoVersion": "go1.4.1",
|
||||
"GoVersion": "go1.4.2",
|
||||
"Deps": [
|
||||
{
|
||||
"ImportPath": "github.com/chihaya/bencode",
|
||||
"Rev": "e60878f635e1a61315c413492e133dd39769b1d1"
|
||||
"Rev": "3c485a8d166ff6a79baba90c2c2da01c8348e930"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/golang/glog",
|
||||
|
@ -12,7 +12,11 @@
|
|||
},
|
||||
{
|
||||
"ImportPath": "github.com/julienschmidt/httprouter",
|
||||
"Rev": "00ce1c6a267162792c367acc43b1681a884e1872"
|
||||
"Rev": "8c199fb6259ffc1af525cc3ad52ee60ba8359669"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/pushrax/bufferpool",
|
||||
"Rev": "7d6e1653dee10a165d1f357f3a57bc8031e9621b"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/pushrax/faststats",
|
||||
|
@ -23,16 +27,13 @@
|
|||
"Rev": "86044f1c998d49053e13293029414ddb63f3a422"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/stretchr/graceful",
|
||||
"Rev": "8e780ba3fe3d3e7ab15fc52e3d60a996587181dc"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/stretchr/pat/stop",
|
||||
"Rev": "f7fe051f2b9bcaca162b38de4f93c9a8457160b9"
|
||||
"ImportPath": "github.com/tylerb/graceful",
|
||||
"Comment": "v1-7-g0c01122",
|
||||
"Rev": "0c011221e91b35f488b8818b00ca279929e9ed7d"
|
||||
},
|
||||
{
|
||||
"ImportPath": "golang.org/x/net/netutil",
|
||||
"Rev": "c84eff7014eba178f68bd4c05b86780efe0fbf35"
|
||||
"Rev": "d175081df37eff8cda13f478bc11a0a65b39958b"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
6
Godeps/_workspace/src/github.com/chihaya/bencode/encoder.go
generated
vendored
6
Godeps/_workspace/src/github.com/chihaya/bencode/encoder.go
generated
vendored
|
@ -62,6 +62,12 @@ func marshal(w io.Writer, data interface{}) error {
|
|||
case uint:
|
||||
marshalUint(w, uint64(v))
|
||||
|
||||
case int16:
|
||||
marshalInt(w, int64(v))
|
||||
|
||||
case uint16:
|
||||
marshalUint(w, uint64(v))
|
||||
|
||||
case int64:
|
||||
marshalInt(w, v)
|
||||
|
||||
|
|
2
Godeps/_workspace/src/github.com/chihaya/bencode/encoder_test.go
generated
vendored
2
Godeps/_workspace/src/github.com/chihaya/bencode/encoder_test.go
generated
vendored
|
@ -19,6 +19,8 @@ var marshalTests = []struct {
|
|||
{uint(43), "i43e"},
|
||||
{int64(44), "i44e"},
|
||||
{uint64(45), "i45e"},
|
||||
{int16(44), "i44e"},
|
||||
{uint16(45), "i45e"},
|
||||
|
||||
{"example", "7:example"},
|
||||
{[]byte("example"), "7:example"},
|
||||
|
|
1
Godeps/_workspace/src/github.com/julienschmidt/httprouter/.travis.yml
generated
vendored
1
Godeps/_workspace/src/github.com/julienschmidt/httprouter/.travis.yml
generated
vendored
|
@ -1,3 +1,4 @@
|
|||
sudo: false
|
||||
language: go
|
||||
go:
|
||||
- 1.1
|
||||
|
|
52
Godeps/_workspace/src/github.com/julienschmidt/httprouter/README.md
generated
vendored
52
Godeps/_workspace/src/github.com/julienschmidt/httprouter/README.md
generated
vendored
|
@ -1,29 +1,17 @@
|
|||
# HttpRouter [![Build Status](https://travis-ci.org/julienschmidt/httprouter.png?branch=master)](https://travis-ci.org/julienschmidt/httprouter) [![GoDoc](http://godoc.org/github.com/julienschmidt/httprouter?status.png)](http://godoc.org/github.com/julienschmidt/httprouter)
|
||||
# HttpRouter [![Build Status](https://travis-ci.org/julienschmidt/httprouter.png?branch=master)](https://travis-ci.org/julienschmidt/httprouter) [![Coverage](http://gocover.io/_badge/github.com/julienschmidt/httprouter?0)](http://gocover.io/github.com/julienschmidt/httprouter) [![GoDoc](http://godoc.org/github.com/julienschmidt/httprouter?status.png)](http://godoc.org/github.com/julienschmidt/httprouter)
|
||||
|
||||
HttpRouter is a lightweight high performance HTTP request router
|
||||
(also called *multiplexer* or just *mux* for short) for [Go](http://golang.org/).
|
||||
|
||||
In contrast to the default mux of Go's net/http package, this router supports
|
||||
In contrast to the [default mux](http://golang.org/pkg/net/http/#ServeMux) of Go's net/http package, this router supports
|
||||
variables in the routing pattern and matches against the request method.
|
||||
It also scales better.
|
||||
|
||||
The router is optimized for best performance and a small memory footprint.
|
||||
The router is optimized for high performance and a small memory footprint.
|
||||
It scales well even with very long paths and a large number of routes.
|
||||
A compressing dynamic trie (radix tree) structure is used for efficient matching.
|
||||
|
||||
## Features
|
||||
**Zero Garbage:** The matching and dispatching process generates zero bytes of
|
||||
garbage. In fact, the only heap allocations that are made, is by building the
|
||||
slice of the key-value pairs for path parameters. If the request path contains
|
||||
no parameters, not a single heap allocation is necessary.
|
||||
|
||||
**Best Performance:** [Benchmarks speak for themselves](https://github.com/julienschmidt/go-http-routing-benchmark).
|
||||
See below for technical details of the implementation.
|
||||
|
||||
**Parameters in your routing pattern:** Stop parsing the requested URL path,
|
||||
just give the path segment a name and the router delivers the dynamic value to
|
||||
you. Because of the design of the router, path parameters are very cheap.
|
||||
|
||||
**Only explicit matches:** With other routers, like [http.ServeMux](http://golang.org/pkg/net/http/#ServeMux),
|
||||
a requested URL path could match multiple patterns. Therefore they have some
|
||||
awkward pattern priority rules, like *longest match* or *first registered,
|
||||
|
@ -34,7 +22,7 @@ great for SEO and improves the user experience.
|
|||
**Stop caring about trailing slashes:** Choose the URL style you like, the
|
||||
router automatically redirects the client if a trailing slash is missing or if
|
||||
there is one extra. Of course it only does so, if the new path has a handler.
|
||||
If you don't like it, you can turn off this behavior.
|
||||
If you don't like it, you can [turn off this behavior](http://godoc.org/github.com/julienschmidt/httprouter#Router.RedirectTrailingSlash).
|
||||
|
||||
**Path auto-correction:** Besides detecting the missing or additional trailing
|
||||
slash at no extra cost, the router can also fix wrong cases and remove
|
||||
|
@ -43,11 +31,23 @@ Is [CAPTAIN CAPS LOCK](http://www.urbandictionary.com/define.php?term=Captain+Ca
|
|||
HttpRouter can help him by making a case-insensitive look-up and redirecting him
|
||||
to the correct URL.
|
||||
|
||||
**No more server crashes:** You can set a PanicHandler to deal with panics
|
||||
**Parameters in your routing pattern:** Stop parsing the requested URL path,
|
||||
just give the path segment a name and the router delivers the dynamic value to
|
||||
you. Because of the design of the router, path parameters are very cheap.
|
||||
|
||||
**Zero Garbage:** The matching and dispatching process generates zero bytes of
|
||||
garbage. In fact, the only heap allocations that are made, is by building the
|
||||
slice of the key-value pairs for path parameters. If the request path contains
|
||||
no parameters, not a single heap allocation is necessary.
|
||||
|
||||
**Best Performance:** [Benchmarks speak for themselves](https://github.com/julienschmidt/go-http-routing-benchmark).
|
||||
See below for technical details of the implementation.
|
||||
|
||||
**No more server crashes:** You can set a [Panic handler](http://godoc.org/github.com/julienschmidt/httprouter#Router.PanicHandler) to deal with panics
|
||||
occurring during handling a HTTP request. The router then recovers and lets the
|
||||
PanicHandler log what happened and deliver a nice error page.
|
||||
|
||||
Of course you can also set a **custom NotFound handler** and **serve static files**.
|
||||
Of course you can also set **custom [NotFound](http://godoc.org/github.com/julienschmidt/httprouter#Router.NotFound) and [MethodNotAllowed](http://godoc.org/github.com/julienschmidt/httprouter#Router.MethodNotAllowed) handlers** and [**serve static files**](http://godoc.org/github.com/julienschmidt/httprouter#Router.ServeFiles).
|
||||
|
||||
## Usage
|
||||
This is just a quick introduction, view the [GoDoc](http://godoc.org/github.com/julienschmidt/httprouter) for details.
|
||||
|
@ -189,7 +189,7 @@ for example the [Gorilla handlers](http://www.gorillatoolkit.org/pkg/handlers).
|
|||
Or you could [just write your own](http://justinas.org/writing-http-middleware-in-go/),
|
||||
it's very easy!
|
||||
|
||||
Alternatively, you could try [a framework building upon HttpRouter](#web-frameworks--co-based-on-httprouter).
|
||||
Alternatively, you could try [a web framework based on HttpRouter](#web-frameworks-based-on-httprouter).
|
||||
|
||||
### Multi-domain / Sub-domains
|
||||
Here is a quick example: Does your server serve multiple domains / hosts?
|
||||
|
@ -256,7 +256,10 @@ func BasicAuth(h httprouter.Handle, user, pass []byte) httprouter.Handle {
|
|||
payload, err := base64.StdEncoding.DecodeString(auth[len(basicAuthPrefix):])
|
||||
if err == nil {
|
||||
pair := bytes.SplitN(payload, []byte(":"), 2)
|
||||
if len(pair) == 2 && bytes.Equal(pair[0], user) && bytes.Equal(pair[1], pass) {
|
||||
if len(pair) == 2 &&
|
||||
bytes.Equal(pair[0], user) &&
|
||||
bytes.Equal(pair[1], pass) {
|
||||
|
||||
// Delegate request to the given handle
|
||||
h(w, r, ps)
|
||||
return
|
||||
|
@ -305,9 +308,16 @@ router.NotFound = http.FileServer(http.Dir("public")).ServeHTTP
|
|||
|
||||
But this approach sidesteps the strict core rules of this router to avoid routing problems. A cleaner approach is to use a distinct sub-path for serving files, like `/static/*filepath` or `/files/*filepath`.
|
||||
|
||||
## Web Frameworks & Co based on HttpRouter
|
||||
## Web Frameworks based on HttpRouter
|
||||
If the HttpRouter is a bit too minimalistic for you, you might try one of the following more high-level 3rd-party web frameworks building upon the HttpRouter package:
|
||||
* [Ace](https://github.com/plimble/ace): Blazing fast Go Web Framework
|
||||
* [api2go](https://github.com/univedo/api2go): A JSON API Implementation for Go
|
||||
* [Gin](https://github.com/gin-gonic/gin): Features a martini-like API with much better performance
|
||||
* [Goat](https://github.com/bahlo/goat): A minimalistic REST API server in Go
|
||||
* [Hikaru](https://github.com/najeira/hikaru): Supports standalone and Google AppEngine
|
||||
* [Hitch](https://github.com/nbio/hitch): Hitch ties httprouter, [httpcontext](https://github.com/nbio/httpcontext), and middleware up in a bow
|
||||
* [kami](https://github.com/guregu/kami): A tiny web framework using x/net/context
|
||||
* [Medeina](https://github.com/imdario/medeina): Inspired by Ruby's Roda and Cuba
|
||||
* [Neko](https://github.com/rocwong/neko): A lightweight web application framework for Golang
|
||||
* [Roxanna](https://github.com/iamthemuffinman/Roxanna): An amalgamation of httprouter, better logging, and hot reload
|
||||
* [siesta](https://github.com/VividCortex/siesta): Composable HTTP handlers with contexts
|
||||
|
|
32
Godeps/_workspace/src/github.com/julienschmidt/httprouter/router.go
generated
vendored
32
Godeps/_workspace/src/github.com/julienschmidt/httprouter/router.go
generated
vendored
|
@ -142,6 +142,11 @@ type Router struct {
|
|||
// found. If it is not set, http.NotFound is used.
|
||||
NotFound http.HandlerFunc
|
||||
|
||||
// Configurable http.HandlerFunc which is called when a request
|
||||
// cannot be routed and HandleMethodNotAllowed is true.
|
||||
// If it is not set, http.Error with http.StatusMethodNotAllowed is used.
|
||||
MethodNotAllowed http.HandlerFunc
|
||||
|
||||
// Function to handle panics recovered from http handlers.
|
||||
// It should be used to generate a error page and return the http error code
|
||||
// 500 (Internal Server Error).
|
||||
|
@ -173,6 +178,11 @@ func (r *Router) HEAD(path string, handle Handle) {
|
|||
r.Handle("HEAD", path, handle)
|
||||
}
|
||||
|
||||
// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle)
|
||||
func (r *Router) OPTIONS(path string, handle Handle) {
|
||||
r.Handle("OPTIONS", path, handle)
|
||||
}
|
||||
|
||||
// POST is a shortcut for router.Handle("POST", path, handle)
|
||||
func (r *Router) POST(path string, handle Handle) {
|
||||
r.Handle("POST", path, handle)
|
||||
|
@ -203,7 +213,7 @@ func (r *Router) DELETE(path string, handle Handle) {
|
|||
// communication with a proxy).
|
||||
func (r *Router) Handle(method, path string, handle Handle) {
|
||||
if path[0] != '/' {
|
||||
panic("path must begin with '/'")
|
||||
panic("path must begin with '/' in path '" + path + "'")
|
||||
}
|
||||
|
||||
if r.trees == nil {
|
||||
|
@ -232,11 +242,7 @@ func (r *Router) Handler(method, path string, handler http.Handler) {
|
|||
// HandlerFunc is an adapter which allows the usage of an http.HandlerFunc as a
|
||||
// request handle.
|
||||
func (r *Router) HandlerFunc(method, path string, handler http.HandlerFunc) {
|
||||
r.Handle(method, path,
|
||||
func(w http.ResponseWriter, req *http.Request, _ Params) {
|
||||
handler(w, req)
|
||||
},
|
||||
)
|
||||
r.Handler(method, path, handler)
|
||||
}
|
||||
|
||||
// ServeFiles serves files from the given file system root.
|
||||
|
@ -251,7 +257,7 @@ func (r *Router) HandlerFunc(method, path string, handler http.HandlerFunc) {
|
|||
// router.ServeFiles("/src/*filepath", http.Dir("/var/www"))
|
||||
func (r *Router) ServeFiles(path string, root http.FileSystem) {
|
||||
if len(path) < 10 || path[len(path)-10:] != "/*filepath" {
|
||||
panic("path must end with /*filepath")
|
||||
panic("path must end with /*filepath in path '" + path + "'")
|
||||
}
|
||||
|
||||
fileServer := http.FileServer(root)
|
||||
|
@ -335,10 +341,14 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||
|
||||
handle, _, _ := r.trees[method].getValue(req.URL.Path)
|
||||
if handle != nil {
|
||||
http.Error(w,
|
||||
http.StatusText(http.StatusMethodNotAllowed),
|
||||
http.StatusMethodNotAllowed,
|
||||
)
|
||||
if r.MethodNotAllowed != nil {
|
||||
r.MethodNotAllowed(w, req)
|
||||
} else {
|
||||
http.Error(w,
|
||||
http.StatusText(http.StatusMethodNotAllowed),
|
||||
http.StatusMethodNotAllowed,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
27
Godeps/_workspace/src/github.com/julienschmidt/httprouter/router_test.go
generated
vendored
27
Godeps/_workspace/src/github.com/julienschmidt/httprouter/router_test.go
generated
vendored
|
@ -76,7 +76,7 @@ func (h handlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func TestRouterAPI(t *testing.T) {
|
||||
var get, head, post, put, patch, delete, handler, handlerFunc bool
|
||||
var get, head, options, post, put, patch, delete, handler, handlerFunc bool
|
||||
|
||||
httpHandler := handlerStruct{&handler}
|
||||
|
||||
|
@ -87,6 +87,9 @@ func TestRouterAPI(t *testing.T) {
|
|||
router.HEAD("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
|
||||
head = true
|
||||
})
|
||||
router.OPTIONS("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
|
||||
options = true
|
||||
})
|
||||
router.POST("/POST", func(w http.ResponseWriter, r *http.Request, _ Params) {
|
||||
post = true
|
||||
})
|
||||
|
@ -118,6 +121,12 @@ func TestRouterAPI(t *testing.T) {
|
|||
t.Error("routing HEAD failed")
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("OPTIONS", "/GET", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !options {
|
||||
t.Error("routing OPTIONS failed")
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("POST", "/POST", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !post {
|
||||
|
@ -176,7 +185,21 @@ func TestRouterNotAllowed(t *testing.T) {
|
|||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusMethodNotAllowed) {
|
||||
t.Errorf("NotAllowed handling route %s failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
t.Errorf("NotAllowed handling failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
responseText := "custom method"
|
||||
router.MethodNotAllowed = func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
w.Write([]byte(responseText))
|
||||
}
|
||||
router.ServeHTTP(w, r)
|
||||
if got := w.Body.String(); !(got == responseText) {
|
||||
t.Errorf("unexpected response got %q want %q", got, responseText)
|
||||
}
|
||||
if w.Code != http.StatusTeapot {
|
||||
t.Errorf("unexpected response code %d want %d", w.Code, http.StatusTeapot)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
207
Godeps/_workspace/src/github.com/julienschmidt/httprouter/tree.go
generated
vendored
207
Godeps/_workspace/src/github.com/julienschmidt/httprouter/tree.go
generated
vendored
|
@ -43,41 +43,48 @@ type node struct {
|
|||
wildChild bool
|
||||
nType nodeType
|
||||
maxParams uint8
|
||||
indices []byte
|
||||
indices string
|
||||
children []*node
|
||||
handle Handle
|
||||
priority uint32
|
||||
}
|
||||
|
||||
// increments priority of the given child and reorders if necessary
|
||||
func (n *node) incrementChildPrio(i int) int {
|
||||
n.children[i].priority++
|
||||
prio := n.children[i].priority
|
||||
func (n *node) incrementChildPrio(pos int) int {
|
||||
n.children[pos].priority++
|
||||
prio := n.children[pos].priority
|
||||
|
||||
// adjust position (move to front)
|
||||
for j := i - 1; j >= 0 && n.children[j].priority < prio; j-- {
|
||||
newPos := pos
|
||||
for newPos > 0 && n.children[newPos-1].priority < prio {
|
||||
// swap node positions
|
||||
tmpN := n.children[j]
|
||||
n.children[j] = n.children[i]
|
||||
n.children[i] = tmpN
|
||||
tmpI := n.indices[j]
|
||||
n.indices[j] = n.indices[i]
|
||||
n.indices[i] = tmpI
|
||||
tmpN := n.children[newPos-1]
|
||||
n.children[newPos-1] = n.children[newPos]
|
||||
n.children[newPos] = tmpN
|
||||
|
||||
i--
|
||||
newPos--
|
||||
}
|
||||
return i
|
||||
|
||||
// build new index char string
|
||||
if newPos != pos {
|
||||
n.indices = n.indices[:newPos] + // unchanged prefix, might be empty
|
||||
n.indices[pos:pos+1] + // the index char we move
|
||||
n.indices[newPos:pos] + n.indices[pos+1:] // rest without char at 'pos'
|
||||
}
|
||||
|
||||
return newPos
|
||||
}
|
||||
|
||||
// addRoute adds a node with the given handle to the path.
|
||||
// Not concurrency-safe!
|
||||
func (n *node) addRoute(path string, handle Handle) {
|
||||
fullPath := path
|
||||
n.priority++
|
||||
numParams := countParams(path)
|
||||
|
||||
// non-empty tree
|
||||
if len(n.path) > 0 || len(n.children) > 0 {
|
||||
WALK:
|
||||
walk:
|
||||
for {
|
||||
// Update maxParams of the current node
|
||||
if numParams > n.maxParams {
|
||||
|
@ -85,10 +92,12 @@ func (n *node) addRoute(path string, handle Handle) {
|
|||
}
|
||||
|
||||
// Find the longest common prefix.
|
||||
// This also implies that the commom prefix contains no ':' or '*'
|
||||
// since the existing key can't contain this chars.
|
||||
// This also implies that the common prefix contains no ':' or '*'
|
||||
// since the existing key can't contain those chars.
|
||||
i := 0
|
||||
for max := min(len(path), len(n.path)); i < max && path[i] == n.path[i]; i++ {
|
||||
max := min(len(path), len(n.path))
|
||||
for i < max && path[i] == n.path[i] {
|
||||
i++
|
||||
}
|
||||
|
||||
// Split edge
|
||||
|
@ -110,7 +119,8 @@ func (n *node) addRoute(path string, handle Handle) {
|
|||
}
|
||||
|
||||
n.children = []*node{&child}
|
||||
n.indices = []byte{n.path[i]}
|
||||
// []byte for proper unicode char conversion, see #65
|
||||
n.indices = string([]byte{n.path[i]})
|
||||
n.path = path[:i]
|
||||
n.handle = nil
|
||||
n.wildChild = false
|
||||
|
@ -134,11 +144,13 @@ func (n *node) addRoute(path string, handle Handle) {
|
|||
if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
|
||||
// check for longer wildcard, e.g. :name and :names
|
||||
if len(n.path) >= len(path) || path[len(n.path)] == '/' {
|
||||
continue WALK
|
||||
continue walk
|
||||
}
|
||||
}
|
||||
|
||||
panic("conflict with wildcard route")
|
||||
panic("path segment '" + path +
|
||||
"' conflicts with existing wildcard '" + n.path +
|
||||
"' in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
c := path[0]
|
||||
|
@ -147,21 +159,22 @@ func (n *node) addRoute(path string, handle Handle) {
|
|||
if n.nType == param && c == '/' && len(n.children) == 1 {
|
||||
n = n.children[0]
|
||||
n.priority++
|
||||
continue WALK
|
||||
continue walk
|
||||
}
|
||||
|
||||
// Check if a child with the next path byte exists
|
||||
for i, index := range n.indices {
|
||||
if c == index {
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if c == n.indices[i] {
|
||||
i = n.incrementChildPrio(i)
|
||||
n = n.children[i]
|
||||
continue WALK
|
||||
continue walk
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise insert it
|
||||
if c != ':' && c != '*' {
|
||||
n.indices = append(n.indices, c)
|
||||
// []byte for proper unicode char conversion, see #65
|
||||
n.indices += string([]byte{c})
|
||||
child := &node{
|
||||
maxParams: numParams,
|
||||
}
|
||||
|
@ -169,24 +182,24 @@ func (n *node) addRoute(path string, handle Handle) {
|
|||
n.incrementChildPrio(len(n.indices) - 1)
|
||||
n = child
|
||||
}
|
||||
n.insertChild(numParams, path, handle)
|
||||
n.insertChild(numParams, path, fullPath, handle)
|
||||
return
|
||||
|
||||
} else if i == len(path) { // Make node a (in-path) leaf
|
||||
if n.handle != nil {
|
||||
panic("a Handle is already registered for this path")
|
||||
panic("a handle is already registered for path ''" + fullPath + "'")
|
||||
}
|
||||
n.handle = handle
|
||||
}
|
||||
return
|
||||
}
|
||||
} else { // Empty tree
|
||||
n.insertChild(numParams, path, handle)
|
||||
n.insertChild(numParams, path, fullPath, handle)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *node) insertChild(numParams uint8, path string, handle Handle) {
|
||||
var offset int
|
||||
func (n *node) insertChild(numParams uint8, path, fullPath string, handle Handle) {
|
||||
var offset int // already handled bytes of the path
|
||||
|
||||
// find prefix until first wildcard (beginning with ':'' or '*'')
|
||||
for i, max := 0, len(path); numParams > 0; i++ {
|
||||
|
@ -195,20 +208,29 @@ func (n *node) insertChild(numParams uint8, path string, handle Handle) {
|
|||
continue
|
||||
}
|
||||
|
||||
// Check if this Node existing children which would be
|
||||
// unreachable if we insert the wildcard here
|
||||
if len(n.children) > 0 {
|
||||
panic("wildcard route conflicts with existing children")
|
||||
}
|
||||
|
||||
// find wildcard end (either '/' or path end)
|
||||
end := i + 1
|
||||
for end < max && path[end] != '/' {
|
||||
end++
|
||||
switch path[end] {
|
||||
// the wildcard name must not contain ':' and '*'
|
||||
case ':', '*':
|
||||
panic("only one wildcard per path segment is allowed, has: '" +
|
||||
path[i:] + "' in path '" + fullPath + "'")
|
||||
default:
|
||||
end++
|
||||
}
|
||||
}
|
||||
|
||||
// check if this Node existing children which would be
|
||||
// unreachable if we insert the wildcard here
|
||||
if len(n.children) > 0 {
|
||||
panic("wildcard route '" + path[i:end] +
|
||||
"' conflicts with existing children in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
// check if the wildcard has a name
|
||||
if end-i < 2 {
|
||||
panic("wildcards must be named with a non-empty name")
|
||||
panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
if c == ':' { // param
|
||||
|
@ -244,17 +266,17 @@ func (n *node) insertChild(numParams uint8, path string, handle Handle) {
|
|||
|
||||
} else { // catchAll
|
||||
if end != max || numParams > 1 {
|
||||
panic("catch-all routes are only allowed at the end of the path")
|
||||
panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
|
||||
panic("catch-all conflicts with existing handle for the path segment root")
|
||||
panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
// currently fixed width 1 for '/'
|
||||
i--
|
||||
if path[i] != '/' {
|
||||
panic("no / before catch-all")
|
||||
panic("no / before catch-all in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
n.path = path[offset:i]
|
||||
|
@ -266,7 +288,7 @@ func (n *node) insertChild(numParams uint8, path string, handle Handle) {
|
|||
maxParams: 1,
|
||||
}
|
||||
n.children = []*node{child}
|
||||
n.indices = []byte{path[i]}
|
||||
n.indices = string(path[i])
|
||||
n = child
|
||||
n.priority++
|
||||
|
||||
|
@ -305,8 +327,8 @@ walk: // Outer loop for walking the tree
|
|||
// to walk down the tree
|
||||
if !n.wildChild {
|
||||
c := path[0]
|
||||
for i, index := range n.indices {
|
||||
if c == index {
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if c == n.indices[i] {
|
||||
n = n.children[i]
|
||||
continue walk
|
||||
}
|
||||
|
@ -379,7 +401,7 @@ walk: // Outer loop for walking the tree
|
|||
return
|
||||
|
||||
default:
|
||||
panic("Invalid node type")
|
||||
panic("invalid node type")
|
||||
}
|
||||
}
|
||||
} else if path == n.path {
|
||||
|
@ -391,10 +413,10 @@ walk: // Outer loop for walking the tree
|
|||
|
||||
// No handle found. Check if a handle for this path + a
|
||||
// trailing slash exists for trailing slash recommendation
|
||||
for i, index := range n.indices {
|
||||
if index == '/' {
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if n.indices[i] == '/' {
|
||||
n = n.children[i]
|
||||
tsr = (n.path == "/" && n.handle != nil) ||
|
||||
tsr = (len(n.path) == 1 && n.handle != nil) ||
|
||||
(n.nType == catchAll && n.children[0].handle != nil)
|
||||
return
|
||||
}
|
||||
|
@ -414,7 +436,7 @@ walk: // Outer loop for walking the tree
|
|||
|
||||
// Makes a case-insensitive lookup of the given path and tries to find a handler.
|
||||
// It can optionally also fix trailing slashes.
|
||||
// It returns the case-corrected path and a bool indicating wether the lookup
|
||||
// It returns the case-corrected path and a bool indicating whether the lookup
|
||||
// was successful.
|
||||
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) {
|
||||
ciPath = make([]byte, 0, len(path)+1) // preallocate enough memory
|
||||
|
@ -433,7 +455,7 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPa
|
|||
for i, index := range n.indices {
|
||||
// must use recursive approach since both index and
|
||||
// ToLower(index) could exist. We must check both.
|
||||
if r == unicode.ToLower(rune(index)) {
|
||||
if r == unicode.ToLower(index) {
|
||||
out, found := n.children[i].findCaseInsensitivePath(path, fixTrailingSlash)
|
||||
if found {
|
||||
return append(ciPath, out...), true
|
||||
|
@ -445,53 +467,52 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPa
|
|||
// without a trailing slash if a leaf exists for that path
|
||||
found = (fixTrailingSlash && path == "/" && n.handle != nil)
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
n = n.children[0]
|
||||
n = n.children[0]
|
||||
switch n.nType {
|
||||
case param:
|
||||
// find param end (either '/' or path end)
|
||||
k := 0
|
||||
for k < len(path) && path[k] != '/' {
|
||||
k++
|
||||
}
|
||||
|
||||
switch n.nType {
|
||||
case param:
|
||||
// find param end (either '/' or path end)
|
||||
k := 0
|
||||
for k < len(path) && path[k] != '/' {
|
||||
k++
|
||||
}
|
||||
// add param value to case insensitive path
|
||||
ciPath = append(ciPath, path[:k]...)
|
||||
|
||||
// add param value to case insensitive path
|
||||
ciPath = append(ciPath, path[:k]...)
|
||||
|
||||
// we need to go deeper!
|
||||
if k < len(path) {
|
||||
if len(n.children) > 0 {
|
||||
path = path[k:]
|
||||
n = n.children[0]
|
||||
continue
|
||||
} else { // ... but we can't
|
||||
if fixTrailingSlash && len(path) == k+1 {
|
||||
return ciPath, true
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if n.handle != nil {
|
||||
return ciPath, true
|
||||
} else if fixTrailingSlash && len(n.children) == 1 {
|
||||
// No handle found. Check if a handle for this path + a
|
||||
// trailing slash exists
|
||||
// we need to go deeper!
|
||||
if k < len(path) {
|
||||
if len(n.children) > 0 {
|
||||
path = path[k:]
|
||||
n = n.children[0]
|
||||
if n.path == "/" && n.handle != nil {
|
||||
return append(ciPath, '/'), true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// ... but we can't
|
||||
if fixTrailingSlash && len(path) == k+1 {
|
||||
return ciPath, true
|
||||
}
|
||||
return
|
||||
|
||||
case catchAll:
|
||||
return append(ciPath, path...), true
|
||||
|
||||
default:
|
||||
panic("Invalid node type")
|
||||
}
|
||||
|
||||
if n.handle != nil {
|
||||
return ciPath, true
|
||||
} else if fixTrailingSlash && len(n.children) == 1 {
|
||||
// No handle found. Check if a handle for this path + a
|
||||
// trailing slash exists
|
||||
n = n.children[0]
|
||||
if n.path == "/" && n.handle != nil {
|
||||
return append(ciPath, '/'), true
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case catchAll:
|
||||
return append(ciPath, path...), true
|
||||
|
||||
default:
|
||||
panic("invalid node type")
|
||||
}
|
||||
} else {
|
||||
// We should have reached the node containing the handle.
|
||||
|
@ -503,10 +524,10 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPa
|
|||
// No handle found.
|
||||
// Try to fix the path by adding a trailing slash
|
||||
if fixTrailingSlash {
|
||||
for i, index := range n.indices {
|
||||
if index == '/' {
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if n.indices[i] == '/' {
|
||||
n = n.children[i]
|
||||
if (n.path == "/" && n.handle != nil) ||
|
||||
if (len(n.path) == 1 && n.handle != nil) ||
|
||||
(n.nType == catchAll && n.children[0].handle != nil) {
|
||||
return append(ciPath, '/'), true
|
||||
}
|
||||
|
|
35
Godeps/_workspace/src/github.com/julienschmidt/httprouter/tree_test.go
generated
vendored
35
Godeps/_workspace/src/github.com/julienschmidt/httprouter/tree_test.go
generated
vendored
|
@ -125,6 +125,8 @@ func TestTreeAddAndGet(t *testing.T) {
|
|||
"/doc/",
|
||||
"/doc/go_faq.html",
|
||||
"/doc/go1.html",
|
||||
"/α",
|
||||
"/β",
|
||||
}
|
||||
for _, route := range routes {
|
||||
tree.addRoute(route, fakeHandler(route))
|
||||
|
@ -142,6 +144,8 @@ func TestTreeAddAndGet(t *testing.T) {
|
|||
{"/cona", true, "", nil}, // key mismatch
|
||||
{"/no", true, "", nil}, // no matching child
|
||||
{"/ab", false, "/ab", nil},
|
||||
{"/α", false, "/α", nil},
|
||||
{"/β", false, "/β", nil},
|
||||
})
|
||||
|
||||
checkPriorities(t, tree)
|
||||
|
@ -339,6 +343,27 @@ func TestTreeCatchAllConflictRoot(t *testing.T) {
|
|||
testRoutes(t, routes)
|
||||
}
|
||||
|
||||
func TestTreeDoubleWildcard(t *testing.T) {
|
||||
const panicMsg = "only one wildcard per path segment is allowed"
|
||||
|
||||
routes := [...]string{
|
||||
"/:foo:bar",
|
||||
"/:foo:bar/",
|
||||
"/:foo*bar",
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
tree := &node{}
|
||||
recv := catchPanic(func() {
|
||||
tree.addRoute(route, nil)
|
||||
})
|
||||
|
||||
if rs, ok := recv.(string); !ok || !strings.HasPrefix(rs, panicMsg) {
|
||||
t.Fatalf(`"Expected panic "%s" for route '%s', got "%v"`, panicMsg, route, recv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*func TestTreeDuplicateWildcard(t *testing.T) {
|
||||
tree := &node{}
|
||||
|
||||
|
@ -559,6 +584,8 @@ func TestTreeFindCaseInsensitivePath(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTreeInvalidNodeType(t *testing.T) {
|
||||
const panicMsg = "invalid node type"
|
||||
|
||||
tree := &node{}
|
||||
tree.addRoute("/", fakeHandler("/"))
|
||||
tree.addRoute("/:page", fakeHandler("/:page"))
|
||||
|
@ -570,15 +597,15 @@ func TestTreeInvalidNodeType(t *testing.T) {
|
|||
recv := catchPanic(func() {
|
||||
tree.getValue("/test")
|
||||
})
|
||||
if rs, ok := recv.(string); !ok || rs != "Invalid node type" {
|
||||
t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv)
|
||||
if rs, ok := recv.(string); !ok || rs != panicMsg {
|
||||
t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv)
|
||||
}
|
||||
|
||||
// case-insensitive lookup
|
||||
recv = catchPanic(func() {
|
||||
tree.findCaseInsensitivePath("/test", true)
|
||||
})
|
||||
if rs, ok := recv.(string); !ok || rs != "Invalid node type" {
|
||||
t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv)
|
||||
if rs, ok := recv.(string); !ok || rs != panicMsg {
|
||||
t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv)
|
||||
}
|
||||
}
|
||||
|
|
1
Godeps/_workspace/src/github.com/pushrax/bufferpool/.travis.yml
generated
vendored
Normal file
1
Godeps/_workspace/src/github.com/pushrax/bufferpool/.travis.yml
generated
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
language: go
|
4
Godeps/_workspace/src/github.com/pushrax/bufferpool/AUTHORS
generated
vendored
Normal file
4
Godeps/_workspace/src/github.com/pushrax/bufferpool/AUTHORS
generated
vendored
Normal file
|
@ -0,0 +1,4 @@
|
|||
# This is the official list of Bufferpool authors for copyright purposes.
|
||||
|
||||
Jimmy Zelinskie <jimmyzelinskie@gmail.com>
|
||||
Justin Li <jli@j-li.net>
|
24
Godeps/_workspace/src/github.com/pushrax/bufferpool/LICENSE
generated
vendored
Normal file
24
Godeps/_workspace/src/github.com/pushrax/bufferpool/LICENSE
generated
vendored
Normal file
|
@ -0,0 +1,24 @@
|
|||
Bufferpool is released under a BSD 2-Clause license, reproduced below.
|
||||
|
||||
Copyright (c) 2013, The Bufferpool Authors
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice,
|
||||
this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
||||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
9
Godeps/_workspace/src/github.com/pushrax/bufferpool/README.md
generated
vendored
Normal file
9
Godeps/_workspace/src/github.com/pushrax/bufferpool/README.md
generated
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
# bufferpool [![Build Status](https://secure.travis-ci.org/pushrax/bufferpool.png)](http://travis-ci.org/pushrax/bufferpool)
|
||||
|
||||
The bufferpool package implements a thread-safe pool of reusable, equally sized `byte.Buffer`s.
|
||||
If you're allocating `byte.Buffer`s very frequently, you can use this to speed up your
|
||||
program and take strain off the garbage collector.
|
||||
|
||||
## docs
|
||||
|
||||
[GoDoc](http://godoc.org/github.com/pushrax/bufferpool)
|
69
Godeps/_workspace/src/github.com/pushrax/bufferpool/bufferpool.go
generated
vendored
Normal file
69
Godeps/_workspace/src/github.com/pushrax/bufferpool/bufferpool.go
generated
vendored
Normal file
|
@ -0,0 +1,69 @@
|
|||
// Copyright 2013 The Bufferpool Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
// Package bufferpool implements a capacity-limited pool of reusable,
|
||||
// equally-sized buffers.
|
||||
package bufferpool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// A BufferPool is a capacity-limited pool of equally sized buffers.
|
||||
type BufferPool struct {
|
||||
bufferSize int
|
||||
pool chan []byte
|
||||
}
|
||||
|
||||
// New returns a newly allocated BufferPool with the given maximum pool size
|
||||
// and buffer size.
|
||||
func New(poolSize, bufferSize int) *BufferPool {
|
||||
return &BufferPool{
|
||||
bufferSize,
|
||||
make(chan []byte, poolSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Take is used to obtain a new zeroed buffer. This will allocate a new buffer
|
||||
// if the pool was empty.
|
||||
func (pool *BufferPool) Take() *bytes.Buffer {
|
||||
return bytes.NewBuffer(pool.TakeSlice()[:0])
|
||||
}
|
||||
|
||||
// TakeSlice is used to obtain a new slice. This will allocate a new slice
|
||||
// if the pool was empty.
|
||||
func (pool *BufferPool) TakeSlice() (slice []byte) {
|
||||
select {
|
||||
case slice = <-pool.pool:
|
||||
default:
|
||||
slice = make([]byte, pool.bufferSize)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Give is used to attempt to return a buffer to the pool. It may not
|
||||
// be added to the pool if it was already full.
|
||||
func (pool *BufferPool) Give(buf *bytes.Buffer) error {
|
||||
buf.Reset()
|
||||
slice := buf.Bytes()
|
||||
|
||||
if cap(slice) < pool.bufferSize {
|
||||
return errors.New("Gave an incorrectly sized buffer to the pool.")
|
||||
}
|
||||
|
||||
return pool.GiveSlice(slice[:buf.Len()])
|
||||
}
|
||||
|
||||
// GiveSlice is used to attempt to return a slice to the pool. It may not
|
||||
// be added to the pool if it was already full.
|
||||
func (pool *BufferPool) GiveSlice(slice []byte) error {
|
||||
select {
|
||||
case pool.pool <- slice:
|
||||
// Everything went smoothly!
|
||||
default:
|
||||
return errors.New("Gave a buffer to a full pool.")
|
||||
}
|
||||
return nil
|
||||
}
|
67
Godeps/_workspace/src/github.com/pushrax/bufferpool/bufferpool_test.go
generated
vendored
Normal file
67
Godeps/_workspace/src/github.com/pushrax/bufferpool/bufferpool_test.go
generated
vendored
Normal file
|
@ -0,0 +1,67 @@
|
|||
// Copyright 2013 The Bufferpool Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
package bufferpool_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pushrax/bufferpool"
|
||||
)
|
||||
|
||||
func ExampleNew() {
|
||||
bp := bufferpool.New(10, 255)
|
||||
|
||||
dogBuffer := bp.Take()
|
||||
dogBuffer.WriteString("Dog!")
|
||||
bp.Give(dogBuffer)
|
||||
|
||||
catBuffer := bp.Take() // dogBuffer is reused and reset.
|
||||
catBuffer.WriteString("Cat!")
|
||||
|
||||
fmt.Println(catBuffer)
|
||||
// Output:
|
||||
// Cat!
|
||||
}
|
||||
|
||||
func TestTakeFromEmpty(t *testing.T) {
|
||||
bp := bufferpool.New(1, 1)
|
||||
poolBuf := bp.Take()
|
||||
if !bytes.Equal(poolBuf.Bytes(), []byte("")) {
|
||||
t.Fatalf("Buffer from empty bufferpool was allocated incorrectly.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTakeFromFilled(t *testing.T) {
|
||||
bp := bufferpool.New(1, 1)
|
||||
|
||||
origBuf := bytes.NewBuffer([]byte("X"))
|
||||
bp.Give(origBuf)
|
||||
|
||||
reusedBuf := bp.Take()
|
||||
if !bytes.Equal(reusedBuf.Bytes(), []byte("")) {
|
||||
t.Fatalf("Buffer from filled bufferpool was recycled incorrectly.")
|
||||
}
|
||||
|
||||
// Compare addresses of the first element in the underlying slice.
|
||||
if &origBuf.Bytes()[:1][0] != &reusedBuf.Bytes()[:1][0] {
|
||||
t.Fatalf("Recycled buffer points at different address.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceSemantics(t *testing.T) {
|
||||
bp := bufferpool.New(1, 8)
|
||||
|
||||
buf := bp.Take()
|
||||
buf.WriteString("abc")
|
||||
bp.Give(buf)
|
||||
|
||||
buf2 := bp.TakeSlice()
|
||||
|
||||
if !bytes.Equal(buf2[:3], []byte("abc")) {
|
||||
t.Fatalf("Buffer from filled bufferpool was recycled incorrectly.")
|
||||
}
|
||||
}
|
1
Godeps/_workspace/src/github.com/stretchr/graceful/wercker.yml
generated
vendored
1
Godeps/_workspace/src/github.com/stretchr/graceful/wercker.yml
generated
vendored
|
@ -1 +0,0 @@
|
|||
box: wercker/golang
|
46
Godeps/_workspace/src/github.com/stretchr/pat/stop/doc.go
generated
vendored
46
Godeps/_workspace/src/github.com/stretchr/pat/stop/doc.go
generated
vendored
|
@ -1,46 +0,0 @@
|
|||
// Package stop represents a pattern for types that need to do some work
|
||||
// when stopping. The StopChan method returns a <-chan stop.Signal which
|
||||
// is closed when the operation has completed.
|
||||
//
|
||||
// Stopper types when implementing the stop channel pattern should use stop.Make
|
||||
// to create and store a stop channel, and close the channel once stopping has completed:
|
||||
// func New() Type {
|
||||
// t := new(Type)
|
||||
// t.stopChan = stop.Make()
|
||||
// return t
|
||||
// }
|
||||
// func (t Type) Stop() {
|
||||
// go func(){
|
||||
// // TODO: tear stuff down
|
||||
// close(t.stopChan)
|
||||
// }()
|
||||
// }
|
||||
// func (t Type) StopChan() <-chan stop.Signal {
|
||||
// return t.stopChan
|
||||
// }
|
||||
//
|
||||
// Stopper types can be stopped in the following ways:
|
||||
// // stop and forget
|
||||
// t.Stop(1 * time.Second)
|
||||
//
|
||||
// // stop and wait
|
||||
// t.Stop(1 * time.Second)
|
||||
// <-t.StopChan()
|
||||
//
|
||||
// // stop, do more work, then wait
|
||||
// t.Stop(1 * time.Second);
|
||||
// // do more work
|
||||
// <-t.StopChan()
|
||||
//
|
||||
// // stop and timeout after 1 second
|
||||
// t.Stop(1 * time.Second)
|
||||
// select {
|
||||
// case <-t.StopChan():
|
||||
// case <-time.After(1 * time.Second):
|
||||
// }
|
||||
//
|
||||
// // stop.All is the same as calling Stop() then StopChan() so
|
||||
// // all above patterns also work on many Stopper types,
|
||||
// // for example; stop and wait for many things:
|
||||
// <-stop.All(1 * time.Second, t1, t2, t3)
|
||||
package stop
|
57
Godeps/_workspace/src/github.com/stretchr/pat/stop/stop.go
generated
vendored
57
Godeps/_workspace/src/github.com/stretchr/pat/stop/stop.go
generated
vendored
|
@ -1,57 +0,0 @@
|
|||
package stop
|
||||
|
||||
import "time"
|
||||
|
||||
// Signal is the type that gets sent down the stop channel.
|
||||
type Signal struct{}
|
||||
|
||||
// NoWait represents a time.Duration with zero value.
|
||||
// Logically meaning no grace wait period when stopping.
|
||||
var NoWait time.Duration
|
||||
|
||||
// Stopper represents types that implement
|
||||
// the stop channel pattern.
|
||||
type Stopper interface {
|
||||
// Stop instructs the type to halt operations and close
|
||||
// the stop channel when it is finished.
|
||||
Stop(wait time.Duration)
|
||||
// StopChan gets the stop channel which will block until
|
||||
// stopping has completed, at which point it is closed.
|
||||
// Callers should never close the stop channel.
|
||||
// The StopChan should exist from the point at which operations
|
||||
// begun, not the point at which Stop was called.
|
||||
StopChan() <-chan Signal
|
||||
}
|
||||
|
||||
// Stopped returns a channel that signals immediately. Useful for
|
||||
// cases when no tear-down work is required and stopping is
|
||||
// immediate.
|
||||
func Stopped() <-chan Signal {
|
||||
c := Make()
|
||||
close(c)
|
||||
return c
|
||||
}
|
||||
|
||||
// Make makes a new channel used to indicate when
|
||||
// stopping has finished. Sends to channel will not block.
|
||||
func Make() chan Signal {
|
||||
return make(chan Signal, 0)
|
||||
}
|
||||
|
||||
// All stops all Stopper types and returns another channel
|
||||
// which will close once all things have finished stopping.
|
||||
func All(wait time.Duration, stoppers ...Stopper) <-chan Signal {
|
||||
all := Make()
|
||||
go func() {
|
||||
var allChans []<-chan Signal
|
||||
for _, stopper := range stoppers {
|
||||
go stopper.Stop(wait)
|
||||
allChans = append(allChans, stopper.StopChan())
|
||||
}
|
||||
for _, ch := range allChans {
|
||||
<-ch
|
||||
}
|
||||
close(all)
|
||||
}()
|
||||
return all
|
||||
}
|
76
Godeps/_workspace/src/github.com/stretchr/pat/stop/stop_test.go
generated
vendored
76
Godeps/_workspace/src/github.com/stretchr/pat/stop/stop_test.go
generated
vendored
|
@ -1,76 +0,0 @@
|
|||
package stop_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/pat/stop"
|
||||
)
|
||||
|
||||
type testStopper struct {
|
||||
stopChan chan stop.Signal
|
||||
}
|
||||
|
||||
func NewTestStopper() *testStopper {
|
||||
s := new(testStopper)
|
||||
s.stopChan = stop.Make()
|
||||
return s
|
||||
}
|
||||
|
||||
func (t *testStopper) Stop(wait time.Duration) {
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(t.stopChan)
|
||||
}()
|
||||
}
|
||||
func (t *testStopper) StopChan() <-chan stop.Signal {
|
||||
return t.stopChan
|
||||
}
|
||||
|
||||
type noopStopper struct{}
|
||||
|
||||
func (t *noopStopper) Stop() {
|
||||
}
|
||||
func (t *noopStopper) StopChan() <-chan stop.Signal {
|
||||
return stop.Stopped()
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
|
||||
s := NewTestStopper()
|
||||
s.Stop(1 * time.Second)
|
||||
stopChan := s.StopChan()
|
||||
select {
|
||||
case <-stopChan:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Stop signal was never sent (timed out)")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAll(t *testing.T) {
|
||||
|
||||
s1 := NewTestStopper()
|
||||
s2 := NewTestStopper()
|
||||
s3 := NewTestStopper()
|
||||
|
||||
select {
|
||||
case <-stop.All(1*time.Second, s1, s2, s3):
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("All signal was never sent (timed out)")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestNoop(t *testing.T) {
|
||||
|
||||
s := new(noopStopper)
|
||||
s.Stop()
|
||||
stopChan := s.StopChan()
|
||||
select {
|
||||
case <-stopChan:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Stop signal was never sent (timed out)")
|
||||
}
|
||||
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Stretchr, Inc.
|
||||
Copyright (c) 2014 Tyler Bunnell
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
SOFTWARE.
|
|
@ -1,17 +1,30 @@
|
|||
graceful [![GoDoc](https://godoc.org/github.com/stretchr/graceful?status.png)](http://godoc.org/github.com/stretchr/graceful) [![wercker status](https://app.wercker.com/status/2729ba763abf87695a17547e0f7af4a4/s "wercker status")](https://app.wercker.com/project/bykey/2729ba763abf87695a17547e0f7af4a4)
|
||||
graceful [![GoDoc](https://godoc.org/github.com/tylerb/graceful?status.png)](http://godoc.org/github.com/tylerb/graceful) [![Build Status](https://drone.io/github.com/tylerb/graceful/status.png)](https://drone.io/github.com/tylerb/graceful/latest) [![Coverage Status](https://coveralls.io/repos/tylerb/graceful/badge.svg?branch=dronedebug)](https://coveralls.io/r/tylerb/graceful?branch=dronedebug) [![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/tylerb/graceful?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
|
||||
========
|
||||
|
||||
Graceful is a Go 1.3+ package enabling graceful shutdown of http.Handler servers.
|
||||
|
||||
## Installation
|
||||
|
||||
To install, simply execute:
|
||||
|
||||
```
|
||||
go get gopkg.in/tylerb/graceful.v1
|
||||
```
|
||||
|
||||
I am using [gopkg.in](http://http://labix.org/gopkg.in) to control releases.
|
||||
|
||||
## Usage
|
||||
|
||||
Usage of Graceful is simple. Create your http.Handler and pass it to the `Run` function:
|
||||
Using Graceful is easy. Simply create your http.Handler and pass it to the `Run` function:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/stretchr/graceful"
|
||||
"gopkg.in/tylerb/graceful.v1"
|
||||
"net/http"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -31,9 +44,10 @@ package main
|
|||
|
||||
import (
|
||||
"github.com/codegangsta/negroni"
|
||||
"github.com/stretchr/graceful"
|
||||
"gopkg.in/tylerb/graceful.v1"
|
||||
"net/http"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -111,4 +125,13 @@ same time and all will be signalled when stopping is complete.
|
|||
|
||||
## Contributing
|
||||
|
||||
Before sending a pull request, please open a new issue describing the feature/issue you wish to address so it can be discussed. The subsequent pull request should close that issue.
|
||||
If you would like to contribute, please:
|
||||
|
||||
1. Create a GitHub issue regarding the contribution. Features and bugs should be discussed beforehand.
|
||||
2. Fork the repository.
|
||||
3. Create a pull request with your solution. This pull request should reference and close the issues (Fix #2).
|
||||
|
||||
All pull requests should:
|
||||
|
||||
1. Pass [gometalinter -t .](https://github.com/alecthomas/gometalinter) with no warnings.
|
||||
2. Be `go fmt` formatted.
|
|
@ -11,7 +11,6 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/pat/stop"
|
||||
"golang.org/x/net/netutil"
|
||||
)
|
||||
|
||||
|
@ -41,30 +40,31 @@ type Server struct {
|
|||
// must not be set directly.
|
||||
ConnState func(net.Conn, http.ConnState)
|
||||
|
||||
// ShutdownInitiated is an optional callback function that is called
|
||||
// ShutdownInitiated is an optional callback function that is called
|
||||
// when shutdown is initiated. It can be used to notify the client
|
||||
// side of long lived connections (e.g. websockets) to reconnect.
|
||||
ShutdownInitiated func()
|
||||
|
||||
// NoSignalHandling prevents graceful from automatically shutting down
|
||||
// on SIGINT and SIGTERM. If set to true, you must shut down the server
|
||||
// manually with Stop().
|
||||
NoSignalHandling bool
|
||||
|
||||
// interrupt signals the listener to stop serving connections,
|
||||
// and the server to shut down.
|
||||
interrupt chan os.Signal
|
||||
|
||||
// stopChan is the channel on which callers may block while waiting for
|
||||
// the server to stop.
|
||||
stopChan chan stop.Signal
|
||||
stopChan chan struct{}
|
||||
|
||||
// stopChanOnce is used to create the stop channel on demand, once, per
|
||||
// instance.
|
||||
stopChanOnce sync.Once
|
||||
// stopLock is used to protect access to the stopChan.
|
||||
stopLock sync.RWMutex
|
||||
|
||||
// connections holds all connections managed by graceful
|
||||
connections map[net.Conn]struct{}
|
||||
}
|
||||
|
||||
// ensure Server conforms to stop.Stopper
|
||||
var _ stop.Stopper = (*Server)(nil)
|
||||
|
||||
// Run serves the http.Handler with graceful shutdown enabled.
|
||||
//
|
||||
// timeout is the duration to wait until killing active requests and stopping the server.
|
||||
|
@ -173,7 +173,6 @@ func (srv *Server) Serve(listener net.Listener) error {
|
|||
case http.StateClosed, http.StateHijacked:
|
||||
remove <- conn
|
||||
}
|
||||
|
||||
if srv.ConnState != nil {
|
||||
srv.ConnState(conn, state)
|
||||
}
|
||||
|
@ -182,7 +181,53 @@ func (srv *Server) Serve(listener net.Listener) error {
|
|||
// Manage open connections
|
||||
shutdown := make(chan chan struct{})
|
||||
kill := make(chan struct{})
|
||||
go func() {
|
||||
go srv.manageConnections(add, remove, shutdown, kill)
|
||||
|
||||
interrupt := srv.interruptChan()
|
||||
|
||||
// Set up the interrupt handler
|
||||
if !srv.NoSignalHandling {
|
||||
signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
}
|
||||
|
||||
go srv.handleInterrupt(interrupt, listener)
|
||||
|
||||
// Serve with graceful listener.
|
||||
// Execution blocks here until listener.Close() is called, above.
|
||||
err := srv.Server.Serve(listener)
|
||||
|
||||
srv.shutdown(shutdown, kill)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Stop instructs the type to halt operations and close
|
||||
// the stop channel when it is finished.
|
||||
//
|
||||
// timeout is grace period for which to wait before shutting
|
||||
// down the server. The timeout value passed here will override the
|
||||
// timeout given when constructing the server, as this is an explicit
|
||||
// command to stop the server.
|
||||
func (srv *Server) Stop(timeout time.Duration) {
|
||||
srv.Timeout = timeout
|
||||
interrupt := srv.interruptChan()
|
||||
interrupt <- syscall.SIGINT
|
||||
}
|
||||
|
||||
// StopChan gets the stop channel which will block until
|
||||
// stopping has completed, at which point it is closed.
|
||||
// Callers should never close the stop channel.
|
||||
func (srv *Server) StopChan() <-chan struct{} {
|
||||
srv.stopLock.Lock()
|
||||
if srv.stopChan == nil {
|
||||
srv.stopChan = make(chan struct{})
|
||||
}
|
||||
srv.stopLock.Unlock()
|
||||
return srv.stopChan
|
||||
}
|
||||
|
||||
func (srv *Server) manageConnections(add, remove chan net.Conn, shutdown chan chan struct{}, kill chan struct{}) {
|
||||
{
|
||||
var done chan struct{}
|
||||
srv.connections = map[net.Conn]struct{}{}
|
||||
for {
|
||||
|
@ -202,36 +247,39 @@ func (srv *Server) Serve(listener net.Listener) error {
|
|||
}
|
||||
case <-kill:
|
||||
for k := range srv.connections {
|
||||
k.Close()
|
||||
_ = k.Close() // nothing to do here if it errors
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) interruptChan() chan os.Signal {
|
||||
srv.stopLock.Lock()
|
||||
if srv.interrupt == nil {
|
||||
srv.interrupt = make(chan os.Signal, 1)
|
||||
}
|
||||
srv.stopLock.Unlock()
|
||||
|
||||
// Set up the interrupt catch
|
||||
signal.Notify(srv.interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-srv.interrupt
|
||||
srv.SetKeepAlivesEnabled(false)
|
||||
listener.Close()
|
||||
return srv.interrupt
|
||||
}
|
||||
|
||||
if srv.ShutdownInitiated != nil {
|
||||
srv.ShutdownInitiated()
|
||||
}
|
||||
func (srv *Server) handleInterrupt(interrupt chan os.Signal, listener net.Listener) {
|
||||
<-interrupt
|
||||
|
||||
signal.Stop(srv.interrupt)
|
||||
close(srv.interrupt)
|
||||
}()
|
||||
srv.SetKeepAlivesEnabled(false)
|
||||
_ = listener.Close() // we are shutting down anyway. ignore error.
|
||||
|
||||
// Serve with graceful listener.
|
||||
// Execution blocks here until listener.Close() is called, above.
|
||||
err := srv.Server.Serve(listener)
|
||||
if srv.ShutdownInitiated != nil {
|
||||
srv.ShutdownInitiated()
|
||||
}
|
||||
|
||||
signal.Stop(interrupt)
|
||||
close(interrupt)
|
||||
}
|
||||
|
||||
func (srv *Server) shutdown(shutdown chan chan struct{}, kill chan struct{}) {
|
||||
// Request done notification
|
||||
done := make(chan struct{})
|
||||
shutdown <- done
|
||||
|
@ -246,32 +294,9 @@ func (srv *Server) Serve(listener net.Listener) error {
|
|||
<-done
|
||||
}
|
||||
// Close the stopChan to wake up any blocked goroutines.
|
||||
srv.stopLock.Lock()
|
||||
if srv.stopChan != nil {
|
||||
close(srv.stopChan)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Stop instructs the type to halt operations and close
|
||||
// the stop channel when it is finished.
|
||||
//
|
||||
// timeout is grace period for which to wait before shutting
|
||||
// down the server. The timeout value passed here will override the
|
||||
// timeout given when constructing the server, as this is an explicit
|
||||
// command to stop the server.
|
||||
func (srv *Server) Stop(timeout time.Duration) {
|
||||
srv.Timeout = timeout
|
||||
srv.interrupt <- syscall.SIGINT
|
||||
}
|
||||
|
||||
// StopChan gets the stop channel which will block until
|
||||
// stopping has completed, at which point it is closed.
|
||||
// Callers should never close the stop channel.
|
||||
func (srv *Server) StopChan() <-chan stop.Signal {
|
||||
srv.stopChanOnce.Do(func() {
|
||||
if srv.stopChan == nil {
|
||||
srv.stopChan = stop.Make()
|
||||
}
|
||||
})
|
||||
return srv.stopChan
|
||||
srv.stopLock.Unlock()
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package graceful
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -13,34 +14,52 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
var killTime = 50 * time.Millisecond
|
||||
var (
|
||||
killTime = 500 * time.Millisecond
|
||||
timeoutTime = 1000 * time.Millisecond
|
||||
waitTime = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
func runQuery(t *testing.T, expected int, shouldErr bool, wg *sync.WaitGroup) {
|
||||
func runQuery(t *testing.T, expected int, shouldErr bool, wg *sync.WaitGroup, once *sync.Once) {
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
client := http.Client{}
|
||||
r, err := client.Get("http://localhost:3000")
|
||||
if shouldErr && err == nil {
|
||||
t.Fatal("Expected an error but none was encountered.")
|
||||
once.Do(func() {
|
||||
t.Fatal("Expected an error but none was encountered.")
|
||||
})
|
||||
} else if shouldErr && err != nil {
|
||||
if err.(*url.Error).Err == io.EOF {
|
||||
if checkErr(t, err, once) {
|
||||
return
|
||||
}
|
||||
errno := err.(*url.Error).Err.(*net.OpError).Err.(syscall.Errno)
|
||||
if errno == syscall.ECONNREFUSED {
|
||||
return
|
||||
} else if err != nil {
|
||||
t.Fatal("Error on Get:", err)
|
||||
}
|
||||
}
|
||||
|
||||
if r != nil && r.StatusCode != expected {
|
||||
t.Fatalf("Incorrect status code on response. Expected %d. Got %d", expected, r.StatusCode)
|
||||
once.Do(func() {
|
||||
t.Fatalf("Incorrect status code on response. Expected %d. Got %d", expected, r.StatusCode)
|
||||
})
|
||||
} else if r == nil {
|
||||
t.Fatal("No response when a response was expected.")
|
||||
once.Do(func() {
|
||||
t.Fatal("No response when a response was expected.")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkErr(t *testing.T, err error, once *sync.Once) bool {
|
||||
if err.(*url.Error).Err == io.EOF {
|
||||
return true
|
||||
}
|
||||
errno := err.(*url.Error).Err.(*net.OpError).Err.(syscall.Errno)
|
||||
if errno == syscall.ECONNREFUSED {
|
||||
return true
|
||||
} else if err != nil {
|
||||
once.Do(func() {
|
||||
t.Fatal("Error on Get:", err)
|
||||
})
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func createListener(sleep time.Duration) (*http.Server, net.Listener, error) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -50,6 +69,9 @@ func createListener(sleep time.Duration) (*http.Server, net.Listener, error) {
|
|||
|
||||
server := &http.Server{Addr: ":3000", Handler: mux}
|
||||
l, err := net.Listen("tcp", ":3000")
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
return server, l, err
|
||||
}
|
||||
|
||||
|
@ -64,16 +86,17 @@ func runServer(timeout, sleep time.Duration, c chan os.Signal) error {
|
|||
}
|
||||
|
||||
func launchTestQueries(t *testing.T, wg *sync.WaitGroup, c chan os.Signal) {
|
||||
var once sync.Once
|
||||
for i := 0; i < 8; i++ {
|
||||
go runQuery(t, http.StatusOK, false, wg)
|
||||
go runQuery(t, http.StatusOK, false, wg, &once)
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
time.Sleep(waitTime)
|
||||
c <- os.Interrupt
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
time.Sleep(waitTime)
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
go runQuery(t, 0, true, wg)
|
||||
go runQuery(t, 0, true, wg, &once)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
|
@ -106,16 +129,17 @@ func TestGracefulRunTimesOut(t *testing.T) {
|
|||
wg.Done()
|
||||
}()
|
||||
|
||||
var once sync.Once
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
for i := 0; i < 8; i++ {
|
||||
go runQuery(t, 0, true, &wg)
|
||||
go runQuery(t, 0, true, &wg, &once)
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
time.Sleep(waitTime)
|
||||
c <- os.Interrupt
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
time.Sleep(waitTime)
|
||||
for i := 0; i < 8; i++ {
|
||||
go runQuery(t, 0, true, &wg)
|
||||
go runQuery(t, 0, true, &wg, &once)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
@ -160,14 +184,23 @@ func TestGracefulRunNoRequests(t *testing.T) {
|
|||
func TestGracefulForwardsConnState(t *testing.T) {
|
||||
c := make(chan os.Signal, 1)
|
||||
states := make(map[http.ConnState]int)
|
||||
var stateLock sync.Mutex
|
||||
|
||||
connState := func(conn net.Conn, state http.ConnState) {
|
||||
stateLock.Lock()
|
||||
states[state]++
|
||||
stateLock.Unlock()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
expected := map[http.ConnState]int{
|
||||
http.StateNew: 8,
|
||||
http.StateActive: 8,
|
||||
http.StateClosed: 8,
|
||||
}
|
||||
|
||||
go func() {
|
||||
server, l, _ := createListener(killTime / 2)
|
||||
srv := &Server{
|
||||
|
@ -185,15 +218,11 @@ func TestGracefulForwardsConnState(t *testing.T) {
|
|||
go launchTestQueries(t, &wg, c)
|
||||
wg.Wait()
|
||||
|
||||
expected := map[http.ConnState]int{
|
||||
http.StateNew: 8,
|
||||
http.StateActive: 8,
|
||||
http.StateClosed: 8,
|
||||
}
|
||||
|
||||
stateLock.Lock()
|
||||
if !reflect.DeepEqual(states, expected) {
|
||||
t.Errorf("Incorrect connection state tracking.\n actual: %v\nexpected: %v\n", states, expected)
|
||||
}
|
||||
stateLock.Unlock()
|
||||
}
|
||||
|
||||
func TestGracefulExplicitStop(t *testing.T) {
|
||||
|
@ -206,14 +235,14 @@ func TestGracefulExplicitStop(t *testing.T) {
|
|||
|
||||
go func() {
|
||||
go srv.Serve(l)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
time.Sleep(waitTime)
|
||||
srv.Stop(killTime)
|
||||
}()
|
||||
|
||||
// block on the stopChan until the server has shut down
|
||||
select {
|
||||
case <-srv.StopChan():
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
case <-time.After(timeoutTime):
|
||||
t.Fatal("Timed out while waiting for explicit stop to complete")
|
||||
}
|
||||
}
|
||||
|
@ -228,7 +257,7 @@ func TestGracefulExplicitStopOverride(t *testing.T) {
|
|||
|
||||
go func() {
|
||||
go srv.Serve(l)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
time.Sleep(waitTime)
|
||||
srv.Stop(killTime / 2)
|
||||
}()
|
||||
|
||||
|
@ -253,7 +282,7 @@ func TestShutdownInitiatedCallback(t *testing.T) {
|
|||
|
||||
go func() {
|
||||
go srv.Serve(l)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
time.Sleep(waitTime)
|
||||
srv.Stop(killTime)
|
||||
}()
|
||||
|
||||
|
@ -302,12 +331,9 @@ func TestNotifyClosed(t *testing.T) {
|
|||
wg.Done()
|
||||
}()
|
||||
|
||||
var once sync.Once
|
||||
for i := 0; i < 8; i++ {
|
||||
runQuery(t, http.StatusOK, false, &wg)
|
||||
}
|
||||
|
||||
if len(srv.connections) > 0 {
|
||||
t.Fatal("hijacked connections should not be managed")
|
||||
runQuery(t, http.StatusOK, false, &wg, &once)
|
||||
}
|
||||
|
||||
srv.Stop(0)
|
||||
|
@ -315,8 +341,39 @@ func TestNotifyClosed(t *testing.T) {
|
|||
// block on the stopChan until the server has shut down
|
||||
select {
|
||||
case <-srv.StopChan():
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
case <-time.After(timeoutTime):
|
||||
t.Fatal("Timed out while waiting for explicit stop to complete")
|
||||
}
|
||||
|
||||
if len(srv.connections) > 0 {
|
||||
t.Fatal("hijacked connections should not be managed")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestStopDeadlock(t *testing.T) {
|
||||
c := make(chan struct{})
|
||||
|
||||
server, l, err := createListener(1 * time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := &Server{Server: server, NoSignalHandling: true}
|
||||
|
||||
go func() {
|
||||
time.Sleep(waitTime)
|
||||
srv.Serve(l)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
srv.Stop(0)
|
||||
close(c)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(timeoutTime):
|
||||
t.Fatal("Timed out while waiting for explicit stop to complete")
|
||||
}
|
||||
}
|
|
@ -5,7 +5,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/codegangsta/negroni"
|
||||
"github.com/stretchr/graceful"
|
||||
"github.com/tylerb/graceful"
|
||||
)
|
||||
|
||||
func main() {
|
52
chihaya.go
52
chihaya.go
|
@ -7,8 +7,11 @@ package chihaya
|
|||
import (
|
||||
"flag"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/golang/glog"
|
||||
|
||||
|
@ -16,6 +19,7 @@ import (
|
|||
"github.com/chihaya/chihaya/http"
|
||||
"github.com/chihaya/chihaya/stats"
|
||||
"github.com/chihaya/chihaya/tracker"
|
||||
"github.com/chihaya/chihaya/udp"
|
||||
|
||||
// See the README for how to import custom drivers.
|
||||
_ "github.com/chihaya/chihaya/backend/noop"
|
||||
|
@ -77,6 +81,50 @@ func Boot() {
|
|||
glog.Fatal("New: ", err)
|
||||
}
|
||||
|
||||
http.Serve(cfg, tkr)
|
||||
glog.Info("Gracefully shut down")
|
||||
var wg sync.WaitGroup
|
||||
var servers []tracker.Server
|
||||
|
||||
if cfg.HTTPListenAddr != "" {
|
||||
wg.Add(1)
|
||||
srv := http.NewServer(cfg, tkr)
|
||||
servers = append(servers, srv)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srv.Serve(cfg.HTTPListenAddr)
|
||||
}()
|
||||
}
|
||||
|
||||
if cfg.UDPListenAddr != "" {
|
||||
wg.Add(1)
|
||||
srv := udp.NewServer(cfg, tkr)
|
||||
servers = append(servers, srv)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srv.Serve(cfg.UDPListenAddr)
|
||||
}()
|
||||
}
|
||||
|
||||
shutdown := make(chan os.Signal)
|
||||
signal.Notify(shutdown, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
signal.Stop(shutdown)
|
||||
close(shutdown)
|
||||
}()
|
||||
|
||||
<-shutdown
|
||||
glog.Info("Shutting down...")
|
||||
|
||||
for _, srv := range servers {
|
||||
srv.Stop()
|
||||
}
|
||||
|
||||
<-shutdown
|
||||
|
||||
if err := tkr.Close(); err != nil {
|
||||
glog.Errorf("Failed to shut down tracker cleanly: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -90,17 +90,24 @@ type TrackerConfig struct {
|
|||
|
||||
// HTTPConfig is the configuration for HTTP functionality.
|
||||
type HTTPConfig struct {
|
||||
ListenAddr string `json:"http_listen_addr"`
|
||||
RequestTimeout Duration `json:"http_request_timeout"`
|
||||
HTTPReadTimeout Duration `json:"http_read_timeout"`
|
||||
HTTPWriteTimeout Duration `json:"http_write_timeout"`
|
||||
HTTPListenLimit int `json:"http_listen_limit"`
|
||||
HTTPListenAddr string `json:"http_listen_addr"`
|
||||
HTTPRequestTimeout Duration `json:"http_request_timeout"`
|
||||
HTTPReadTimeout Duration `json:"http_read_timeout"`
|
||||
HTTPWriteTimeout Duration `json:"http_write_timeout"`
|
||||
HTTPListenLimit int `json:"http_listen_limit"`
|
||||
}
|
||||
|
||||
// UDPConfig is the configuration for HTTP functionality.
|
||||
type UDPConfig struct {
|
||||
UDPListenAddr string `json:"udp_listen_addr"`
|
||||
UDPReadBufferSize int `json:"udp_read_buffer_size"`
|
||||
}
|
||||
|
||||
// Config is the global configuration for an instance of Chihaya.
|
||||
type Config struct {
|
||||
TrackerConfig
|
||||
HTTPConfig
|
||||
UDPConfig
|
||||
DriverConfig
|
||||
StatsConfig
|
||||
}
|
||||
|
@ -129,10 +136,14 @@ var DefaultConfig = Config{
|
|||
},
|
||||
|
||||
HTTPConfig: HTTPConfig{
|
||||
ListenAddr: ":6881",
|
||||
RequestTimeout: Duration{10 * time.Second},
|
||||
HTTPReadTimeout: Duration{10 * time.Second},
|
||||
HTTPWriteTimeout: Duration{10 * time.Second},
|
||||
HTTPListenAddr: ":6881",
|
||||
HTTPRequestTimeout: Duration{10 * time.Second},
|
||||
HTTPReadTimeout: Duration{10 * time.Second},
|
||||
HTTPWriteTimeout: Duration{10 * time.Second},
|
||||
},
|
||||
|
||||
UDPConfig: UDPConfig{
|
||||
UDPListenAddr: ":6882",
|
||||
},
|
||||
|
||||
DriverConfig: DriverConfig{
|
||||
|
|
|
@ -13,10 +13,11 @@
|
|||
"respect_af": false,
|
||||
"client_whitelist_enabled": false,
|
||||
"client_whitelist": ["OP1011"],
|
||||
"udp_listen_addr": ":6881",
|
||||
"http_listen_addr": ":6881",
|
||||
"http_request_timeout": "10s",
|
||||
"http_read_timeout": "10s",
|
||||
"http_write_timeout": "10s",
|
||||
"http_request_timeout": "4s",
|
||||
"http_read_timeout": "4s",
|
||||
"http_write_timeout": "4s",
|
||||
"http_listen_limit": 0,
|
||||
"driver": "noop",
|
||||
"stats_buffer_size": 0,
|
||||
|
|
|
@ -31,10 +31,10 @@ func TestPublicAnnounce(t *testing.T) {
|
|||
peer3 := makePeerParams("peer3", false)
|
||||
|
||||
peer1["event"] = "started"
|
||||
expected := makeResponse(1, 0)
|
||||
expected := makeResponse(1, 0, peer1)
|
||||
checkAnnounce(peer1, expected, srv, t)
|
||||
|
||||
expected = makeResponse(2, 0)
|
||||
expected = makeResponse(2, 0, peer2)
|
||||
checkAnnounce(peer2, expected, srv, t)
|
||||
|
||||
expected = makeResponse(2, 1, peer1, peer2)
|
||||
|
@ -147,7 +147,7 @@ func TestPrivateAnnounce(t *testing.T) {
|
|||
peer2 := makePeerParams("-TR2820-peer2", false)
|
||||
peer3 := makePeerParams("-TR2820-peer3", true)
|
||||
|
||||
expected := makeResponse(0, 1)
|
||||
expected := makeResponse(0, 1, peer1)
|
||||
srv.URL = baseURL + "/users/vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv1"
|
||||
checkAnnounce(peer1, expected, srv, t)
|
||||
|
||||
|
@ -189,7 +189,7 @@ func TestPreferredSubnet(t *testing.T) {
|
|||
peerD1 := makePeerParams("peerD1", false, "fc02::1")
|
||||
peerD2 := makePeerParams("peerD2", false, "fc02::2")
|
||||
|
||||
expected := makeResponse(0, 1)
|
||||
expected := makeResponse(0, 1, peerA1)
|
||||
checkAnnounce(peerA1, expected, srv, t)
|
||||
|
||||
expected = makeResponse(0, 2, peerA1)
|
||||
|
@ -255,7 +255,7 @@ func TestCompactAnnounce(t *testing.T) {
|
|||
peer3["compact"] = "1"
|
||||
|
||||
expected := makeResponse(0, 1)
|
||||
expected["peers"] = ""
|
||||
expected["peers"] = compact
|
||||
checkAnnounce(peer1, expected, srv, t)
|
||||
|
||||
expected = makeResponse(0, 2)
|
||||
|
|
59
http/http.go
59
http/http.go
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
"github.com/golang/glog"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/graceful"
|
||||
"github.com/tylerb/graceful"
|
||||
|
||||
"github.com/chihaya/chihaya/config"
|
||||
"github.com/chihaya/chihaya/stats"
|
||||
|
@ -24,8 +24,10 @@ type ResponseHandler func(http.ResponseWriter, *http.Request, httprouter.Params)
|
|||
|
||||
// Server represents an HTTP serving torrent tracker.
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
tracker *tracker.Tracker
|
||||
config *config.Config
|
||||
tracker *tracker.Tracker
|
||||
grace *graceful.Server
|
||||
stopping bool
|
||||
}
|
||||
|
||||
// makeHandler wraps our ResponseHandlers while timing requests, collecting,
|
||||
|
@ -118,40 +120,53 @@ func (s *Server) connState(conn net.Conn, state http.ConnState) {
|
|||
}
|
||||
}
|
||||
|
||||
// Serve creates a new Server and proceeds to block while handling requests
|
||||
// until a graceful shutdown.
|
||||
func Serve(cfg *config.Config, tkr *tracker.Tracker) {
|
||||
srv := &Server{
|
||||
config: cfg,
|
||||
tracker: tkr,
|
||||
}
|
||||
// Serve runs an HTTP server, blocking until the server has shut down.
|
||||
func (s *Server) Serve(addr string) {
|
||||
glog.V(0).Info("Starting HTTP on ", addr)
|
||||
|
||||
glog.V(0).Info("Starting on ", cfg.ListenAddr)
|
||||
if cfg.HTTPListenLimit != 0 {
|
||||
glog.V(0).Info("Limiting connections to ", cfg.HTTPListenLimit)
|
||||
if s.config.HTTPListenLimit != 0 {
|
||||
glog.V(0).Info("Limiting connections to ", s.config.HTTPListenLimit)
|
||||
}
|
||||
|
||||
grace := &graceful.Server{
|
||||
Timeout: cfg.RequestTimeout.Duration,
|
||||
ConnState: srv.connState,
|
||||
ListenLimit: cfg.HTTPListenLimit,
|
||||
Timeout: s.config.HTTPRequestTimeout.Duration,
|
||||
ConnState: s.connState,
|
||||
ListenLimit: s.config.HTTPListenLimit,
|
||||
|
||||
NoSignalHandling: true,
|
||||
Server: &http.Server{
|
||||
Addr: cfg.ListenAddr,
|
||||
Handler: newRouter(srv),
|
||||
ReadTimeout: cfg.HTTPReadTimeout.Duration,
|
||||
WriteTimeout: cfg.HTTPWriteTimeout.Duration,
|
||||
Addr: addr,
|
||||
Handler: newRouter(s),
|
||||
ReadTimeout: s.config.HTTPReadTimeout.Duration,
|
||||
WriteTimeout: s.config.HTTPWriteTimeout.Duration,
|
||||
},
|
||||
}
|
||||
|
||||
s.grace = grace
|
||||
grace.SetKeepAlivesEnabled(false)
|
||||
grace.ShutdownInitiated = func() { s.stopping = true }
|
||||
|
||||
if err := grace.ListenAndServe(); err != nil {
|
||||
if opErr, ok := err.(*net.OpError); !ok || (ok && opErr.Op != "accept") {
|
||||
glog.Errorf("Failed to gracefully run HTTP server: %s", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := srv.tracker.Close(); err != nil {
|
||||
glog.Errorf("Failed to shutdown tracker cleanly: %s", err.Error())
|
||||
glog.Info("HTTP server shut down cleanly")
|
||||
}
|
||||
|
||||
// Stop cleanly shuts down the server.
|
||||
func (s *Server) Stop() {
|
||||
if !s.stopping {
|
||||
s.grace.Stop(s.grace.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// NewServer returns a new HTTP server for a given configuration and tracker.
|
||||
func NewServer(cfg *config.Config, tkr *tracker.Tracker) *Server {
|
||||
return &Server{
|
||||
config: cfg,
|
||||
tracker: tkr,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,7 +76,7 @@ func (s *Server) stats(w http.ResponseWriter, r *http.Request, p httprouter.Para
|
|||
func handleTorrentError(err error, w *Writer) (int, error) {
|
||||
if err == nil {
|
||||
return http.StatusOK, nil
|
||||
} else if _, ok := err.(models.ClientError); ok {
|
||||
} else if models.IsPublicError(err) {
|
||||
w.WriteError(err)
|
||||
stats.RecordEvent(stats.ClientError)
|
||||
return http.StatusOK, nil
|
||||
|
@ -86,10 +86,8 @@ func handleTorrentError(err error, w *Writer) (int, error) {
|
|||
}
|
||||
|
||||
func (s *Server) serveAnnounce(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) {
|
||||
stats.RecordEvent(stats.Announce)
|
||||
|
||||
writer := &Writer{w}
|
||||
ann, err := NewAnnounce(s.config, r, p)
|
||||
ann, err := s.newAnnounce(r, p)
|
||||
if err != nil {
|
||||
return handleTorrentError(err, writer)
|
||||
}
|
||||
|
@ -98,10 +96,8 @@ func (s *Server) serveAnnounce(w http.ResponseWriter, r *http.Request, p httprou
|
|||
}
|
||||
|
||||
func (s *Server) serveScrape(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) {
|
||||
stats.RecordEvent(stats.Scrape)
|
||||
|
||||
writer := &Writer{w}
|
||||
scrape, err := NewScrape(s.config, r, p)
|
||||
scrape, err := s.newScrape(r, p)
|
||||
if err != nil {
|
||||
return handleTorrentError(err, writer)
|
||||
}
|
||||
|
|
102
http/tracker.go
102
http/tracker.go
|
@ -17,8 +17,8 @@ import (
|
|||
"github.com/chihaya/chihaya/tracker/models"
|
||||
)
|
||||
|
||||
// NewAnnounce parses an HTTP request and generates a models.Announce.
|
||||
func NewAnnounce(cfg *config.Config, r *http.Request, p httprouter.Params) (*models.Announce, error) {
|
||||
// newAnnounce parses an HTTP request and generates a models.Announce.
|
||||
func (s *Server) newAnnounce(r *http.Request, p httprouter.Params) (*models.Announce, error) {
|
||||
q, err := query.New(r.URL.RawQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -26,7 +26,7 @@ func NewAnnounce(cfg *config.Config, r *http.Request, p httprouter.Params) (*mod
|
|||
|
||||
compact := q.Params["compact"] != "0"
|
||||
event, _ := q.Params["event"]
|
||||
numWant := requestedPeerCount(q, cfg.NumWantFallback)
|
||||
numWant := requestedPeerCount(q, s.config.NumWantFallback)
|
||||
|
||||
infohash, exists := q.Params["info_hash"]
|
||||
if !exists {
|
||||
|
@ -38,17 +38,34 @@ func NewAnnounce(cfg *config.Config, r *http.Request, p httprouter.Params) (*mod
|
|||
return nil, models.ErrMalformedRequest
|
||||
}
|
||||
|
||||
ipv4, ipv6, err := requestedIP(q, r, &cfg.NetConfig)
|
||||
if err != nil {
|
||||
return nil, models.ErrMalformedRequest
|
||||
}
|
||||
|
||||
port, err := q.Uint64("port")
|
||||
if err != nil {
|
||||
return nil, models.ErrMalformedRequest
|
||||
}
|
||||
|
||||
left, err := q.Uint64("left")
|
||||
|
||||
ipv4, ipv6, err := requestedEndpoint(q, r, &s.config.NetConfig)
|
||||
if err != nil {
|
||||
return nil, models.ErrMalformedRequest
|
||||
}
|
||||
|
||||
var ipv4Endpoint, ipv6Endpoint models.Endpoint
|
||||
if ipv4 != nil {
|
||||
ipv4Endpoint = *ipv4
|
||||
// If the port we couldn't get the port before, fallback to the port param.
|
||||
if ipv4Endpoint.Port == uint16(0) {
|
||||
ipv4Endpoint.Port = uint16(port)
|
||||
}
|
||||
}
|
||||
if ipv6 != nil {
|
||||
ipv6Endpoint = *ipv6
|
||||
// If the port we couldn't get the port before, fallback to the port param.
|
||||
if ipv6Endpoint.Port == uint16(0) {
|
||||
ipv6Endpoint.Port = uint16(port)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, models.ErrMalformedRequest
|
||||
}
|
||||
|
@ -64,24 +81,23 @@ func NewAnnounce(cfg *config.Config, r *http.Request, p httprouter.Params) (*mod
|
|||
}
|
||||
|
||||
return &models.Announce{
|
||||
Config: cfg,
|
||||
Config: s.config,
|
||||
Compact: compact,
|
||||
Downloaded: downloaded,
|
||||
Event: event,
|
||||
IPv4: ipv4,
|
||||
IPv6: ipv6,
|
||||
IPv4: ipv4Endpoint,
|
||||
IPv6: ipv6Endpoint,
|
||||
Infohash: infohash,
|
||||
Left: left,
|
||||
NumWant: numWant,
|
||||
Passkey: p.ByName("passkey"),
|
||||
PeerID: peerID,
|
||||
Port: port,
|
||||
Uploaded: uploaded,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewScrape parses an HTTP request and generates a models.Scrape.
|
||||
func NewScrape(cfg *config.Config, r *http.Request, p httprouter.Params) (*models.Scrape, error) {
|
||||
// newScrape parses an HTTP request and generates a models.Scrape.
|
||||
func (s *Server) newScrape(r *http.Request, p httprouter.Params) (*models.Scrape, error) {
|
||||
q, err := query.New(r.URL.RawQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -96,7 +112,7 @@ func NewScrape(cfg *config.Config, r *http.Request, p httprouter.Params) (*model
|
|||
}
|
||||
|
||||
return &models.Scrape{
|
||||
Config: cfg,
|
||||
Config: s.config,
|
||||
|
||||
Passkey: p.ByName("passkey"),
|
||||
Infohashes: q.Infohashes,
|
||||
|
@ -116,26 +132,26 @@ func requestedPeerCount(q *query.Query, fallback int) int {
|
|||
return fallback
|
||||
}
|
||||
|
||||
// requestedIP returns the IP addresses for a request. If there are multiple
|
||||
// IP addresses in the request, one IPv4 and one IPv6 will be returned.
|
||||
func requestedIP(q *query.Query, r *http.Request, cfg *config.NetConfig) (v4, v6 net.IP, err error) {
|
||||
// requestedEndpoint returns the IP address and port pairs for a request. If
|
||||
// there are multiple in the request, one IPv4 and one IPv6 will be returned.
|
||||
func requestedEndpoint(q *query.Query, r *http.Request, cfg *config.NetConfig) (v4, v6 *models.Endpoint, err error) {
|
||||
var done bool
|
||||
|
||||
if cfg.AllowIPSpoofing {
|
||||
if str, ok := q.Params["ip"]; ok {
|
||||
if v4, v6, done = getIPs(str, v4, v6, cfg); done {
|
||||
if v4, v6, done = getEndpoints(str, v4, v6, cfg); done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if str, ok := q.Params["ipv4"]; ok {
|
||||
if v4, v6, done = getIPs(str, v4, v6, cfg); done {
|
||||
if v4, v6, done = getEndpoints(str, v4, v6, cfg); done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if str, ok := q.Params["ipv6"]; ok {
|
||||
if v4, v6, done = getIPs(str, v4, v6, cfg); done {
|
||||
if v4, v6, done = getEndpoints(str, v4, v6, cfg); done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -143,47 +159,49 @@ func requestedIP(q *query.Query, r *http.Request, cfg *config.NetConfig) (v4, v6
|
|||
|
||||
if cfg.RealIPHeader != "" {
|
||||
if xRealIPs, ok := r.Header[cfg.RealIPHeader]; ok {
|
||||
if v4, v6, done = getIPs(string(xRealIPs[0]), v4, v6, cfg); done {
|
||||
if v4, v6, done = getEndpoints(string(xRealIPs[0]), v4, v6, cfg); done {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if r.RemoteAddr == "" {
|
||||
if v4 == nil {
|
||||
v4 = net.ParseIP("127.0.0.1")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var host string
|
||||
host, _, err = net.SplitHostPort(r.RemoteAddr)
|
||||
|
||||
if err == nil && host != "" {
|
||||
if v4, v6, done = getIPs(host, v4, v6, cfg); done {
|
||||
if r.RemoteAddr == "" && v4 == nil {
|
||||
if v4, v6, done = getEndpoints("127.0.0.1", v4, v6, cfg); done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if v4, v6, done = getEndpoints(r.RemoteAddr, v4, v6, cfg); done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if v4 == nil && v6 == nil {
|
||||
err = errors.New("failed to parse IP address")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getIPs(ipstr string, ipv4, ipv6 net.IP, cfg *config.NetConfig) (net.IP, net.IP, bool) {
|
||||
var done bool
|
||||
func getEndpoints(ipstr string, ipv4, ipv6 *models.Endpoint, cfg *config.NetConfig) (*models.Endpoint, *models.Endpoint, bool) {
|
||||
host, port, err := net.SplitHostPort(ipstr)
|
||||
if err != nil {
|
||||
host = ipstr
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(ipstr); ip != nil {
|
||||
newIPv4 := ip.To4()
|
||||
// We can ignore this error, because ports that are 0 are assumed to be the
|
||||
// port parameter provided in the "port" param of the announce request.
|
||||
parsedPort, _ := strconv.ParseUint(port, 10, 16)
|
||||
|
||||
if ipv4 == nil && newIPv4 != nil {
|
||||
ipv4 = newIPv4
|
||||
} else if ipv6 == nil && newIPv4 == nil {
|
||||
ipv6 = ip
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
ipTo4 := ip.To4()
|
||||
if ipv4 == nil && ipTo4 != nil {
|
||||
ipv4 = &models.Endpoint{ipTo4, uint16(parsedPort)}
|
||||
} else if ipv6 == nil && ipTo4 == nil {
|
||||
ipv6 = &models.Endpoint{ip, uint16(parsedPort)}
|
||||
}
|
||||
}
|
||||
|
||||
var done bool
|
||||
if cfg.DualStackedPeers {
|
||||
done = ipv4 != nil && ipv6 != nil
|
||||
} else {
|
||||
|
|
|
@ -271,15 +271,21 @@ func (s *Stats) handlePeerEvent(ps *PeerStats, event int) {
|
|||
|
||||
// RecordEvent broadcasts an event to the default stats queue.
|
||||
func RecordEvent(event int) {
|
||||
DefaultStats.RecordEvent(event)
|
||||
if DefaultStats != nil {
|
||||
DefaultStats.RecordEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordPeerEvent broadcasts a peer event to the default stats queue.
|
||||
func RecordPeerEvent(event int, ipv6 bool) {
|
||||
DefaultStats.RecordPeerEvent(event, ipv6)
|
||||
if DefaultStats != nil {
|
||||
DefaultStats.RecordPeerEvent(event, ipv6)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordTiming broadcasts a timing event to the default stats queue.
|
||||
func RecordTiming(event int, duration time.Duration) {
|
||||
DefaultStats.RecordTiming(event, duration)
|
||||
if DefaultStats != nil {
|
||||
DefaultStats.RecordTiming(event, duration)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -70,6 +70,7 @@ func (tkr *Tracker) HandleAnnounce(ann *models.Announce, w Writer) (err error) {
|
|||
stats.RecordEvent(stats.DeletedTorrent)
|
||||
}
|
||||
|
||||
stats.RecordEvent(stats.Announce)
|
||||
return w.WriteAnnounce(newAnnounceResponse(ann))
|
||||
}
|
||||
|
||||
|
@ -263,6 +264,7 @@ func newAnnounceResponse(ann *models.Announce) *models.AnnounceResponse {
|
|||
leechCount := ann.Torrent.Leechers.Len()
|
||||
|
||||
res := &models.AnnounceResponse{
|
||||
Announce: ann,
|
||||
Complete: seedCount,
|
||||
Incomplete: leechCount,
|
||||
Interval: ann.Config.Announce.Duration,
|
||||
|
@ -272,6 +274,10 @@ func newAnnounceResponse(ann *models.Announce) *models.AnnounceResponse {
|
|||
|
||||
if ann.NumWant > 0 && ann.Event != "stopped" && ann.Event != "paused" {
|
||||
res.IPv4Peers, res.IPv6Peers = getPeers(ann)
|
||||
|
||||
if len(res.IPv4Peers)+len(res.IPv6Peers) == 0 {
|
||||
models.AppendPeer(&res.IPv4Peers, &res.IPv6Peers, ann, ann.Peer)
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
|
|
|
@ -8,7 +8,6 @@ package models
|
|||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -40,9 +39,19 @@ var (
|
|||
|
||||
type ClientError string
|
||||
type NotFoundError ClientError
|
||||
type ProtocolError ClientError
|
||||
|
||||
func (e ClientError) Error() string { return string(e) }
|
||||
func (e NotFoundError) Error() string { return string(e) }
|
||||
func (e ProtocolError) Error() string { return string(e) }
|
||||
|
||||
// IsPublicError determines whether an error should be propogated to the client.
|
||||
func IsPublicError(err error) bool {
|
||||
_, cl := err.(ClientError)
|
||||
_, nf := err.(NotFoundError)
|
||||
_, pc := err.(ProtocolError)
|
||||
return cl || nf || pc
|
||||
}
|
||||
|
||||
// PeerList represents a list of peers: either seeders or leechers.
|
||||
type PeerList []Peer
|
||||
|
@ -51,8 +60,8 @@ type PeerList []Peer
|
|||
type PeerKey string
|
||||
|
||||
// NewPeerKey creates a properly formatted PeerKey.
|
||||
func NewPeerKey(peerID string, ip net.IP, port string) PeerKey {
|
||||
return PeerKey(peerID + "//" + ip.String() + ":" + port)
|
||||
func NewPeerKey(peerID string, ip net.IP) PeerKey {
|
||||
return PeerKey(peerID + "//" + ip.String())
|
||||
}
|
||||
|
||||
// IP parses and returns the IP address for a given PeerKey.
|
||||
|
@ -69,26 +78,23 @@ func (pk PeerKey) PeerID() string {
|
|||
return strings.Split(string(pk), "//")[0]
|
||||
}
|
||||
|
||||
// Port returns the port section of the PeerKey.
|
||||
func (pk PeerKey) Port() string {
|
||||
return strings.Split(string(pk), "//")[2]
|
||||
// Endpoint is an IP and port pair.
|
||||
type Endpoint struct {
|
||||
// Always has length net.IPv4len if IPv4, and net.IPv6len if IPv6
|
||||
IP net.IP `json:"ip"`
|
||||
Port uint16 `json:"port"`
|
||||
}
|
||||
|
||||
// Peer is a participant in a swarm.
|
||||
type Peer struct {
|
||||
ID string `json:"id"`
|
||||
UserID uint64 `json:"user_id"`
|
||||
TorrentID uint64 `json:"torrent_id"`
|
||||
|
||||
// Always has length net.IPv4len if IPv4, and net.IPv6len if IPv6
|
||||
IP net.IP `json:"ip,omitempty"`
|
||||
|
||||
Port uint64 `json:"port"`
|
||||
|
||||
ID string `json:"id"`
|
||||
UserID uint64 `json:"user_id"`
|
||||
TorrentID uint64 `json:"torrent_id"`
|
||||
Uploaded uint64 `json:"uploaded"`
|
||||
Downloaded uint64 `json:"downloaded"`
|
||||
Left uint64 `json:"left"`
|
||||
LastAnnounce int64 `json:"last_announce"`
|
||||
Endpoint
|
||||
}
|
||||
|
||||
// HasIPv4 determines if a peer's IP address can be represented as an IPv4
|
||||
|
@ -105,7 +111,7 @@ func (p *Peer) HasIPv6() bool {
|
|||
|
||||
// Key returns a PeerKey for the given peer.
|
||||
func (p *Peer) Key() PeerKey {
|
||||
return NewPeerKey(p.ID, p.IP, strconv.FormatUint(p.Port, 10))
|
||||
return NewPeerKey(p.ID, p.IP)
|
||||
}
|
||||
|
||||
// Torrent is a swarm for a given torrent file.
|
||||
|
@ -140,18 +146,17 @@ type User struct {
|
|||
type Announce struct {
|
||||
Config *config.Config `json:"config"`
|
||||
|
||||
Compact bool `json:"compact"`
|
||||
Downloaded uint64 `json:"downloaded"`
|
||||
Event string `json:"event"`
|
||||
IPv4 net.IP `json:"ipv4"`
|
||||
IPv6 net.IP `json:"ipv6"`
|
||||
Infohash string `json:"infohash"`
|
||||
Left uint64 `json:"left"`
|
||||
NumWant int `json:"numwant"`
|
||||
Passkey string `json:"passkey"`
|
||||
PeerID string `json:"peer_id"`
|
||||
Port uint64 `json:"port"`
|
||||
Uploaded uint64 `json:"uploaded"`
|
||||
Compact bool `json:"compact"`
|
||||
Downloaded uint64 `json:"downloaded"`
|
||||
Event string `json:"event"`
|
||||
IPv4 Endpoint `json:"ipv4"`
|
||||
IPv6 Endpoint `json:"ipv6"`
|
||||
Infohash string `json:"infohash"`
|
||||
Left uint64 `json:"left"`
|
||||
NumWant int `json:"numwant"`
|
||||
Passkey string `json:"passkey"`
|
||||
PeerID string `json:"peer_id"`
|
||||
Uploaded uint64 `json:"uploaded"`
|
||||
|
||||
Torrent *Torrent `json:"-"`
|
||||
User *User `json:"-"`
|
||||
|
@ -177,12 +182,14 @@ func (a *Announce) ClientID() (clientID string) {
|
|||
return
|
||||
}
|
||||
|
||||
// HasIPv4 determines whether or not an announce has an IPv4 endpoint.
|
||||
func (a *Announce) HasIPv4() bool {
|
||||
return a.IPv4 != nil
|
||||
return a.IPv4.IP != nil
|
||||
}
|
||||
|
||||
// HasIPv6 determines whether or not an announce has an IPv6 endpoint.
|
||||
func (a *Announce) HasIPv6() bool {
|
||||
return a.IPv6 != nil
|
||||
return a.IPv6.IP != nil
|
||||
}
|
||||
|
||||
// BuildPeer creates the Peer representation of an Announce. When provided nil
|
||||
|
@ -192,7 +199,6 @@ func (a *Announce) HasIPv6() bool {
|
|||
func (a *Announce) BuildPeer(u *User, t *Torrent) {
|
||||
a.Peer = &Peer{
|
||||
ID: a.PeerID,
|
||||
Port: a.Port,
|
||||
Uploaded: a.Uploaded,
|
||||
Downloaded: a.Downloaded,
|
||||
Left: a.Left,
|
||||
|
@ -211,15 +217,15 @@ func (a *Announce) BuildPeer(u *User, t *Torrent) {
|
|||
|
||||
if a.HasIPv4() && a.HasIPv6() {
|
||||
a.PeerV4 = a.Peer
|
||||
a.PeerV4.IP = a.IPv4
|
||||
a.PeerV4.Endpoint = a.IPv4
|
||||
a.PeerV6 = &*a.Peer
|
||||
a.PeerV6.IP = a.IPv6
|
||||
a.PeerV6.Endpoint = a.IPv6
|
||||
} else if a.HasIPv4() {
|
||||
a.PeerV4 = a.Peer
|
||||
a.PeerV4.IP = a.IPv4
|
||||
a.PeerV4.Endpoint = a.IPv4
|
||||
} else if a.HasIPv6() {
|
||||
a.PeerV6 = a.Peer
|
||||
a.PeerV6.IP = a.IPv6
|
||||
a.PeerV6.Endpoint = a.IPv6
|
||||
} else {
|
||||
panic("models: announce must have an IP")
|
||||
}
|
||||
|
@ -250,6 +256,7 @@ type AnnounceDelta struct {
|
|||
|
||||
// AnnounceResponse contains the information needed to fulfill an announce.
|
||||
type AnnounceResponse struct {
|
||||
Announce *Announce
|
||||
Complete, Incomplete int
|
||||
Interval, MinInterval time.Duration
|
||||
IPv4Peers, IPv6Peers PeerList
|
||||
|
|
|
@ -159,7 +159,7 @@ func (pm *PeerMap) AppendPeers(ipv4s, ipv6s PeerList, ann *Announce, wanted int)
|
|||
} else if peersEquivalent(&peer, ann.Peer) {
|
||||
continue
|
||||
} else {
|
||||
appendPeer(&ipv4s, &ipv6s, ann, &peer, &count)
|
||||
count += AppendPeer(&ipv4s, &ipv6s, ann, &peer)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -174,7 +174,7 @@ func (pm *PeerMap) AppendPeers(ipv4s, ipv6s PeerList, ann *Announce, wanted int)
|
|||
} else if peersEquivalent(&peer, ann.Peer) {
|
||||
continue
|
||||
} else {
|
||||
appendPeer(&ipv4s, &ipv6s, ann, &peer, &count)
|
||||
count += AppendPeer(&ipv4s, &ipv6s, ann, &peer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -183,18 +183,20 @@ func (pm *PeerMap) AppendPeers(ipv4s, ipv6s PeerList, ann *Announce, wanted int)
|
|||
return ipv4s, ipv6s
|
||||
}
|
||||
|
||||
// appendPeer adds a peer to its corresponding peerlist.
|
||||
func appendPeer(ipv4s, ipv6s *PeerList, ann *Announce, peer *Peer, count *int) {
|
||||
// AppendPeer adds a peer to its corresponding peerlist.
|
||||
func AppendPeer(ipv4s, ipv6s *PeerList, ann *Announce, peer *Peer) int {
|
||||
if ann.HasIPv6() && peer.HasIPv6() {
|
||||
*ipv6s = append(*ipv6s, *peer)
|
||||
*count++
|
||||
return 1
|
||||
} else if ann.Config.RespectAF && ann.HasIPv4() && peer.HasIPv4() {
|
||||
*ipv4s = append(*ipv4s, *peer)
|
||||
*count++
|
||||
return 1
|
||||
} else if !ann.Config.RespectAF && peer.HasIPv4() {
|
||||
*ipv4s = append(*ipv4s, *peer)
|
||||
*count++
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// peersEquivalent checks if two peers represent the same entity.
|
||||
|
|
|
@ -4,7 +4,10 @@
|
|||
|
||||
package tracker
|
||||
|
||||
import "github.com/chihaya/chihaya/tracker/models"
|
||||
import (
|
||||
"github.com/chihaya/chihaya/stats"
|
||||
"github.com/chihaya/chihaya/tracker/models"
|
||||
)
|
||||
|
||||
// HandleScrape encapsulates all the logic of handling a BitTorrent client's
|
||||
// scrape without being coupled to any transport protocol.
|
||||
|
@ -24,6 +27,7 @@ func (tkr *Tracker) HandleScrape(scrape *models.Scrape, w Writer) (err error) {
|
|||
torrents = append(torrents, torrent)
|
||||
}
|
||||
|
||||
stats.RecordEvent(stats.Scrape)
|
||||
return w.WriteScrape(&models.ScrapeResponse{
|
||||
Files: torrents,
|
||||
})
|
||||
|
|
|
@ -24,6 +24,15 @@ type Tracker struct {
|
|||
*Storage
|
||||
}
|
||||
|
||||
// Server represents a server for a given BitTorrent tracker protocol.
|
||||
type Server interface {
|
||||
// Serve runs the server and blocks until the server has shut down.
|
||||
Serve(addr string)
|
||||
|
||||
// Stop cleanly shuts down the server in a non-blocking manner.
|
||||
Stop()
|
||||
}
|
||||
|
||||
// New creates a new Tracker, and opens any necessary connections.
|
||||
// Maintenance routines are automatically spawned in the background.
|
||||
func New(cfg *config.Config) (*Tracker, error) {
|
||||
|
@ -64,7 +73,8 @@ func (tkr *Tracker) LoadApprovedClients(clients []string) {
|
|||
}
|
||||
|
||||
// Writer serializes a tracker's responses, and is implemented for each
|
||||
// response transport used by the tracker.
|
||||
// response transport used by the tracker. Only one of these may be called
|
||||
// per request, and only once.
|
||||
//
|
||||
// Note, data passed into any of these functions will not contain sensitive
|
||||
// information, so it may be passed back the client freely.
|
||||
|
|
87
udp/announce_test.go
Normal file
87
udp/announce_test.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
// Copyright 2015 The Chihaya Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/chihaya/chihaya/config"
|
||||
)
|
||||
|
||||
func requestAnnounce(sock *net.UDPConn, connID []byte, hash string) ([]byte, error) {
|
||||
txID := makeTransactionID()
|
||||
peerID := []byte("-UT2210-b4a2h9a9f5c4")
|
||||
|
||||
var request []byte
|
||||
request = append(request, connID...)
|
||||
request = append(request, announceAction...)
|
||||
request = append(request, txID...)
|
||||
request = append(request, []byte(hash)...)
|
||||
request = append(request, peerID...)
|
||||
request = append(request, make([]byte, 8)...) // Downloaded
|
||||
request = append(request, make([]byte, 8)...) // Left
|
||||
request = append(request, make([]byte, 8)...) // Uploaded
|
||||
request = append(request, make([]byte, 4)...) // Event
|
||||
request = append(request, make([]byte, 4)...) // IP
|
||||
request = append(request, make([]byte, 4)...) // Key
|
||||
request = append(request, make([]byte, 4)...) // NumWant
|
||||
request = append(request, make([]byte, 2)...) // Port
|
||||
|
||||
return doRequest(sock, request, txID)
|
||||
}
|
||||
|
||||
func TestAnnounce(t *testing.T) {
|
||||
srv, done, err := setupTracker(&config.DefaultConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, sock, err := setupSocket()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
connID, err := requestConnectionID(sock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
announce, err := requestAnnounce(sock, connID, "aaaaaaaaaaaaaaaaaaaa")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Parse the response.
|
||||
var action, txID, interval, leechers, seeders uint32
|
||||
buf := bytes.NewReader(announce)
|
||||
binary.Read(buf, binary.BigEndian, &action)
|
||||
binary.Read(buf, binary.BigEndian, &txID)
|
||||
binary.Read(buf, binary.BigEndian, &interval)
|
||||
binary.Read(buf, binary.BigEndian, &leechers)
|
||||
binary.Read(buf, binary.BigEndian, &seeders)
|
||||
|
||||
if action != uint32(announceActionID) {
|
||||
t.Fatal("expected announce action")
|
||||
}
|
||||
|
||||
if interval != uint32(config.DefaultConfig.Announce.Seconds()) {
|
||||
t.Fatal("incorrect interval")
|
||||
}
|
||||
|
||||
if leechers != uint32(0) {
|
||||
t.Fatal("incorrect leecher count")
|
||||
}
|
||||
|
||||
// We're the only seeder.
|
||||
if seeders != uint32(1) {
|
||||
t.Fatal("incorrect seeder count")
|
||||
}
|
||||
|
||||
srv.Stop()
|
||||
<-done
|
||||
}
|
90
udp/connection.go
Normal file
90
udp/connection.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
// Copyright 2015 The Chihaya Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"net"
|
||||
)
|
||||
|
||||
// ConnectionIDGenerator represents the logic to generate 64-bit UDP
|
||||
// connection IDs from peer IP addresses.
|
||||
type ConnectionIDGenerator struct {
|
||||
iv, iv2 []byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
// NewConnectionIDGenerator creates a ConnectionIDGenerator and generates its
|
||||
// AES key and first initialization vector.
|
||||
func NewConnectionIDGenerator() (gen *ConnectionIDGenerator, err error) {
|
||||
gen = &ConnectionIDGenerator{}
|
||||
key := make([]byte, 16)
|
||||
|
||||
_, err = rand.Read(key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
gen.block, err = aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = gen.NewIV()
|
||||
return
|
||||
}
|
||||
|
||||
// Generate returns the 64-bit connection ID for an IP
|
||||
func (g *ConnectionIDGenerator) Generate(ip net.IP) []byte {
|
||||
return g.generate(ip, g.iv)
|
||||
}
|
||||
|
||||
func (g *ConnectionIDGenerator) generate(ip net.IP, iv []byte) []byte {
|
||||
for len(ip) < 8 {
|
||||
ip = append(ip, ip...) // Not enough bits in output.
|
||||
}
|
||||
|
||||
ct := make([]byte, 16)
|
||||
stream := cipher.NewCFBDecrypter(g.block, iv)
|
||||
stream.XORKeyStream(ct, ip)
|
||||
|
||||
for i := len(ip) - 1; i >= 8; i-- {
|
||||
ct[i-8] ^= ct[i]
|
||||
}
|
||||
|
||||
return ct[:8]
|
||||
}
|
||||
|
||||
// Matches checks if the given connection ID matches an IP with the current or
|
||||
// previous initialization vectors.
|
||||
func (g *ConnectionIDGenerator) Matches(id []byte, ip net.IP) bool {
|
||||
if expected := g.generate(ip, g.iv); bytes.Equal(id, expected) {
|
||||
return true
|
||||
}
|
||||
|
||||
if iv2 := g.iv2; iv2 != nil {
|
||||
if expected := g.generate(ip, iv2); bytes.Equal(id, expected) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// NewIV generates a new initialization vector and rotates the current one.
|
||||
func (g *ConnectionIDGenerator) NewIV() error {
|
||||
newiv := make([]byte, 16)
|
||||
if _, err := rand.Read(newiv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.iv2 = g.iv
|
||||
g.iv = newiv
|
||||
|
||||
return nil
|
||||
}
|
70
udp/connection_test.go
Normal file
70
udp/connection_test.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
// Copyright 2015 The Chihaya Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInitReturnsNoError(t *testing.T) {
|
||||
if _, err := NewConnectionIDGenerator(); err != nil {
|
||||
t.Error("Init returned", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testGenerateConnectionID(t *testing.T, ip net.IP) {
|
||||
gen, _ := NewConnectionIDGenerator()
|
||||
|
||||
id1 := gen.Generate(ip)
|
||||
id2 := gen.Generate(ip)
|
||||
|
||||
if !bytes.Equal(id1, id2) {
|
||||
t.Errorf("Connection ID mismatch: %x != %x", id1, id2)
|
||||
}
|
||||
|
||||
if len(id1) != 8 {
|
||||
t.Errorf("Connection ID had length: %d != 8", len(id1))
|
||||
}
|
||||
|
||||
if bytes.Count(id1, []byte{0}) == 8 {
|
||||
t.Errorf("Connection ID was 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateConnectionIDIPv4(t *testing.T) {
|
||||
testGenerateConnectionID(t, net.ParseIP("192.168.1.123").To4())
|
||||
}
|
||||
|
||||
func TestGenerateConnectionIDIPv6(t *testing.T) {
|
||||
testGenerateConnectionID(t, net.ParseIP("1:2:3:4::5:6"))
|
||||
}
|
||||
|
||||
func TestMatchesWorksWithPreviousIV(t *testing.T) {
|
||||
gen, _ := NewConnectionIDGenerator()
|
||||
ip := net.ParseIP("192.168.1.123").To4()
|
||||
|
||||
id1 := gen.Generate(ip)
|
||||
if !gen.Matches(id1, ip) {
|
||||
t.Errorf("Connection ID mismatch for current IV")
|
||||
}
|
||||
|
||||
gen.NewIV()
|
||||
if !gen.Matches(id1, ip) {
|
||||
t.Errorf("Connection ID mismatch for previous IV")
|
||||
}
|
||||
|
||||
id2 := gen.Generate(ip)
|
||||
gen.NewIV()
|
||||
|
||||
if gen.Matches(id1, ip) {
|
||||
t.Errorf("Connection ID matched for discarded IV")
|
||||
}
|
||||
|
||||
if !gen.Matches(id2, ip) {
|
||||
t.Errorf("Connection ID mismatch for previous IV")
|
||||
}
|
||||
}
|
252
udp/protocol.go
Normal file
252
udp/protocol.go
Normal file
|
@ -0,0 +1,252 @@
|
|||
// Copyright 2015 The Chihaya Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
|
||||
"github.com/chihaya/chihaya/stats"
|
||||
"github.com/chihaya/chihaya/tracker/models"
|
||||
)
|
||||
|
||||
const (
|
||||
connectActionID uint32 = iota
|
||||
announceActionID
|
||||
scrapeActionID
|
||||
errorActionID
|
||||
announceDualStackActionID
|
||||
)
|
||||
|
||||
var (
|
||||
// initialConnectionID is the magic initial connection ID specified by BEP 15.
|
||||
initialConnectionID = []byte{0, 0, 0x04, 0x17, 0x27, 0x10, 0x19, 0x80}
|
||||
|
||||
// emptyIPs are the value of an IP field that has been left blank.
|
||||
emptyIPv4 = []byte{0, 0, 0, 0}
|
||||
emptyIPv6 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
|
||||
// Option-Types described in BEP41 and BEP45.
|
||||
optionEndOfOptions = byte(0x0)
|
||||
optionNOP = byte(0x1)
|
||||
optionURLData = byte(0x2)
|
||||
optionIPv6 = byte(0x3)
|
||||
|
||||
// eventIDs map IDs to event names.
|
||||
eventIDs = []string{
|
||||
"",
|
||||
"completed",
|
||||
"started",
|
||||
"stopped",
|
||||
}
|
||||
|
||||
errMalformedPacket = models.ProtocolError("malformed packet")
|
||||
errMalformedIP = models.ProtocolError("malformed IP address")
|
||||
errMalformedEvent = models.ProtocolError("malformed event ID")
|
||||
errBadConnectionID = models.ProtocolError("bad connection ID")
|
||||
)
|
||||
|
||||
// handleTorrentError writes err to w if err is a models.ClientError.
|
||||
func handleTorrentError(err error, w *Writer) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if models.IsPublicError(err) {
|
||||
w.WriteError(err)
|
||||
stats.RecordEvent(stats.ClientError)
|
||||
}
|
||||
}
|
||||
|
||||
// handlePacket decodes and processes one UDP request, returning the response.
|
||||
func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte, actionName string) {
|
||||
if len(packet) < 16 {
|
||||
return // Malformed, no client packets are less than 16 bytes.
|
||||
}
|
||||
|
||||
connID := packet[0:8]
|
||||
action := binary.BigEndian.Uint32(packet[8:12])
|
||||
transactionID := packet[12:16]
|
||||
|
||||
writer := &Writer{
|
||||
buf: new(bytes.Buffer),
|
||||
|
||||
connectionID: connID,
|
||||
transactionID: transactionID,
|
||||
}
|
||||
defer func() { response = writer.buf.Bytes() }()
|
||||
|
||||
if action != 0 && !s.connIDGen.Matches(connID, addr.IP) {
|
||||
writer.WriteError(errBadConnectionID)
|
||||
return
|
||||
}
|
||||
|
||||
switch action {
|
||||
case connectActionID:
|
||||
actionName = "connect"
|
||||
if !bytes.Equal(connID, initialConnectionID) {
|
||||
return // Malformed packet.
|
||||
}
|
||||
|
||||
writer.writeHeader(0)
|
||||
writer.buf.Write(s.connIDGen.Generate(addr.IP))
|
||||
|
||||
case announceActionID:
|
||||
actionName = "announce"
|
||||
ann, err := s.newAnnounce(packet, addr.IP)
|
||||
|
||||
if err == nil {
|
||||
err = s.tracker.HandleAnnounce(ann, writer)
|
||||
}
|
||||
|
||||
handleTorrentError(err, writer)
|
||||
|
||||
case scrapeActionID:
|
||||
actionName = "scrape"
|
||||
scrape, err := s.newScrape(packet)
|
||||
|
||||
if err == nil {
|
||||
err = s.tracker.HandleScrape(scrape, writer)
|
||||
}
|
||||
|
||||
handleTorrentError(err, writer)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// newAnnounce decodes one announce packet, returning a models.Announce.
|
||||
func (s *Server) newAnnounce(packet []byte, ip net.IP) (*models.Announce, error) {
|
||||
if len(packet) < 98 {
|
||||
return nil, errMalformedPacket
|
||||
}
|
||||
|
||||
infohash := packet[16:36]
|
||||
peerID := packet[36:56]
|
||||
|
||||
downloaded := binary.BigEndian.Uint64(packet[56:64])
|
||||
left := binary.BigEndian.Uint64(packet[64:72])
|
||||
uploaded := binary.BigEndian.Uint64(packet[72:80])
|
||||
|
||||
eventID := packet[83]
|
||||
if eventID > 3 {
|
||||
return nil, errMalformedEvent
|
||||
}
|
||||
|
||||
ipv4bytes := packet[84:88]
|
||||
if s.config.AllowIPSpoofing && !bytes.Equal(ipv4bytes, emptyIPv4) {
|
||||
ip = net.ParseIP(string(ipv4bytes))
|
||||
}
|
||||
|
||||
if ip == nil {
|
||||
return nil, errMalformedIP
|
||||
} else if ipv4 := ip.To4(); ipv4 != nil {
|
||||
ip = ipv4
|
||||
}
|
||||
|
||||
numWant := binary.BigEndian.Uint32(packet[92:96])
|
||||
port := binary.BigEndian.Uint16(packet[96:98])
|
||||
|
||||
announce := &models.Announce{
|
||||
Config: s.config,
|
||||
Downloaded: downloaded,
|
||||
Event: eventIDs[eventID],
|
||||
IPv4: models.Endpoint{
|
||||
IP: ip,
|
||||
Port: port,
|
||||
},
|
||||
Infohash: string(infohash),
|
||||
Left: left,
|
||||
NumWant: int(numWant),
|
||||
PeerID: string(peerID),
|
||||
Uploaded: uploaded,
|
||||
}
|
||||
|
||||
if err := s.handleOptionalParameters(packet, announce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return announce, nil
|
||||
}
|
||||
|
||||
// handleOptionalParameters parses the optional parameters as described in BEP41
|
||||
// and updates an announce with the values parsed.
|
||||
func (s *Server) handleOptionalParameters(packet []byte, announce *models.Announce) error {
|
||||
if len(packet) > 98 {
|
||||
optionStartIndex := 98
|
||||
for optionStartIndex < len(packet)-1 {
|
||||
option := packet[optionStartIndex]
|
||||
switch option {
|
||||
case optionEndOfOptions:
|
||||
return nil
|
||||
|
||||
case optionNOP:
|
||||
optionStartIndex++
|
||||
|
||||
case optionURLData:
|
||||
if optionStartIndex+1 > len(packet)-1 {
|
||||
return errMalformedPacket
|
||||
}
|
||||
|
||||
length := int(packet[optionStartIndex+1])
|
||||
if optionStartIndex+1+length > len(packet)-1 {
|
||||
return errMalformedPacket
|
||||
}
|
||||
|
||||
// TODO: Actually parse the URL Data as described in BEP41.
|
||||
|
||||
optionStartIndex += 1 + length
|
||||
|
||||
case optionIPv6:
|
||||
if optionStartIndex+19 > len(packet)-1 {
|
||||
return errMalformedPacket
|
||||
}
|
||||
|
||||
ipv6bytes := packet[optionStartIndex+1 : optionStartIndex+17]
|
||||
if s.config.AllowIPSpoofing && !bytes.Equal(ipv6bytes, emptyIPv6) {
|
||||
announce.IPv6.IP = net.ParseIP(string(ipv6bytes)).To16()
|
||||
announce.IPv6.Port = binary.BigEndian.Uint16(packet[optionStartIndex+17 : optionStartIndex+19])
|
||||
if announce.IPv6.IP == nil {
|
||||
return errMalformedIP
|
||||
}
|
||||
}
|
||||
|
||||
optionStartIndex += 19
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// There was no optional parameters to parse.
|
||||
return nil
|
||||
}
|
||||
|
||||
// newScrape decodes one announce packet, returning a models.Scrape.
|
||||
func (s *Server) newScrape(packet []byte) (*models.Scrape, error) {
|
||||
if len(packet) < 36 {
|
||||
return nil, errMalformedPacket
|
||||
}
|
||||
|
||||
var infohashes []string
|
||||
packet = packet[16:]
|
||||
|
||||
if len(packet)%20 != 0 {
|
||||
return nil, errMalformedPacket
|
||||
}
|
||||
|
||||
for len(packet) >= 20 {
|
||||
infohash := packet[:20]
|
||||
infohashes = append(infohashes, string(infohash))
|
||||
packet = packet[20:]
|
||||
}
|
||||
|
||||
return &models.Scrape{
|
||||
Config: s.config,
|
||||
Infohashes: infohashes,
|
||||
}, nil
|
||||
}
|
77
udp/scrape_test.go
Normal file
77
udp/scrape_test.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
// Copyright 2015 The Chihaya Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/chihaya/chihaya/config"
|
||||
)
|
||||
|
||||
func doRequest(sock *net.UDPConn, request, txID []byte) ([]byte, error) {
|
||||
response := make([]byte, 1024)
|
||||
|
||||
n, err := sendRequest(sock, request, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !bytes.Equal(response[4:8], txID) {
|
||||
return nil, fmt.Errorf("transaction ID mismatch")
|
||||
}
|
||||
|
||||
return response[:n], nil
|
||||
}
|
||||
|
||||
func requestScrape(sock *net.UDPConn, connID []byte, hashes []string) ([]byte, error) {
|
||||
txID := makeTransactionID()
|
||||
|
||||
var request []byte
|
||||
request = append(request, connID...)
|
||||
request = append(request, scrapeAction...)
|
||||
request = append(request, txID...)
|
||||
|
||||
for _, hash := range hashes {
|
||||
request = append(request, []byte(hash)...)
|
||||
}
|
||||
|
||||
return doRequest(sock, request, txID)
|
||||
}
|
||||
|
||||
func TestScrapeEmpty(t *testing.T) {
|
||||
srv, done, err := setupTracker(&config.DefaultConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, sock, err := setupSocket()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
connID, err := requestConnectionID(sock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scrape, err := requestScrape(sock, connID, []string{"aaaaaaaaaaaaaaaaaaaa"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(scrape[:4], errorAction) {
|
||||
t.Error("expected error response")
|
||||
}
|
||||
|
||||
if string(scrape[8:]) != "torrent does not exist\000" {
|
||||
t.Error("expected torrent to not exist")
|
||||
}
|
||||
|
||||
srv.Stop()
|
||||
<-done
|
||||
}
|
128
udp/udp.go
Normal file
128
udp/udp.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
// Copyright 2015 The Chihaya Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
// Package udp implements a UDP BitTorrent tracker per BEP 15.
|
||||
// IPv6 is currently unsupported as there is no widely-implemented standard.
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/golang/glog"
|
||||
"github.com/pushrax/bufferpool"
|
||||
|
||||
"github.com/chihaya/chihaya/config"
|
||||
"github.com/chihaya/chihaya/tracker"
|
||||
)
|
||||
|
||||
// Server represents a UDP torrent tracker.
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
tracker *tracker.Tracker
|
||||
done bool
|
||||
booting chan struct{}
|
||||
sock *net.UDPConn
|
||||
|
||||
connIDGen *ConnectionIDGenerator
|
||||
}
|
||||
|
||||
func (s *Server) serve(listenAddr string) error {
|
||||
if s.sock != nil {
|
||||
return errors.New("server already booted")
|
||||
}
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
close(s.booting)
|
||||
return err
|
||||
}
|
||||
|
||||
sock, err := net.ListenUDP("udp", udpAddr)
|
||||
defer sock.Close()
|
||||
if err != nil {
|
||||
close(s.booting)
|
||||
return err
|
||||
}
|
||||
|
||||
if s.config.UDPReadBufferSize > 0 {
|
||||
sock.SetReadBuffer(s.config.UDPReadBufferSize)
|
||||
}
|
||||
|
||||
pool := bufferpool.New(1000, 2048)
|
||||
s.sock = sock
|
||||
close(s.booting)
|
||||
|
||||
for !s.done {
|
||||
buffer := pool.TakeSlice()
|
||||
sock.SetReadDeadline(time.Now().Add(time.Second))
|
||||
n, addr, err := sock.ReadFromUDP(buffer)
|
||||
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
|
||||
pool.GiveSlice(buffer)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
go func() {
|
||||
response, action := s.handlePacket(buffer[:n], addr)
|
||||
pool.GiveSlice(buffer)
|
||||
|
||||
if len(response) > 0 {
|
||||
sock.WriteToUDP(response, addr)
|
||||
}
|
||||
|
||||
if glog.V(2) {
|
||||
duration := time.Since(start)
|
||||
glog.Infof("[UDP - %9s] %s %s", duration, action, addr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve runs a UDP server, blocking until the server has shut down.
|
||||
func (s *Server) Serve(addr string) {
|
||||
glog.V(0).Info("Starting UDP on ", addr)
|
||||
|
||||
go func() {
|
||||
// Generate a new IV every hour.
|
||||
for range time.Tick(time.Hour) {
|
||||
s.connIDGen.NewIV()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := s.serve(addr); err != nil {
|
||||
glog.Errorf("Failed to run UDP server: %s", err.Error())
|
||||
} else {
|
||||
glog.Info("UDP server shut down cleanly")
|
||||
}
|
||||
}
|
||||
|
||||
// Stop cleanly shuts down the server.
|
||||
func (s *Server) Stop() {
|
||||
s.done = true
|
||||
s.sock.SetReadDeadline(time.Now())
|
||||
}
|
||||
|
||||
// NewServer returns a new UDP server for a given configuration and tracker.
|
||||
func NewServer(cfg *config.Config, tkr *tracker.Tracker) *Server {
|
||||
gen, err := NewConnectionIDGenerator()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &Server{
|
||||
config: cfg,
|
||||
tracker: tkr,
|
||||
connIDGen: gen,
|
||||
booting: make(chan struct{}),
|
||||
}
|
||||
}
|
132
udp/udp_test.go
Normal file
132
udp/udp_test.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
// Copyright 2015 The Chihaya Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/chihaya/chihaya/config"
|
||||
"github.com/chihaya/chihaya/tracker"
|
||||
|
||||
_ "github.com/chihaya/chihaya/backend/noop"
|
||||
)
|
||||
|
||||
var (
|
||||
testPort = "34137"
|
||||
connectAction = []byte{0, 0, 0, byte(connectActionID)}
|
||||
announceAction = []byte{0, 0, 0, byte(announceActionID)}
|
||||
scrapeAction = []byte{0, 0, 0, byte(scrapeActionID)}
|
||||
errorAction = []byte{0, 0, 0, byte(errorActionID)}
|
||||
)
|
||||
|
||||
func setupTracker(cfg *config.Config) (*Server, chan struct{}, error) {
|
||||
tkr, err := tracker.New(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
srv := NewServer(cfg, tkr)
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
if err := srv.serve(":" + testPort); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
<-srv.booting
|
||||
return srv, done, nil
|
||||
}
|
||||
|
||||
func setupSocket() (*net.UDPAddr, *net.UDPConn, error) {
|
||||
srvAddr, err := net.ResolveUDPAddr("udp", "localhost:"+testPort)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sock, err := net.DialUDP("udp", nil, srvAddr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return srvAddr, sock, err
|
||||
}
|
||||
|
||||
func makeTransactionID() []byte {
|
||||
out := make([]byte, 4)
|
||||
rand.Read(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func sendRequest(sock *net.UDPConn, request, response []byte) (int, error) {
|
||||
if _, err := sock.Write(request); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
sock.SetReadDeadline(time.Now().Add(time.Second))
|
||||
n, err := sock.Read(response)
|
||||
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return 0, fmt.Errorf("no response from tracker: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func requestConnectionID(sock *net.UDPConn) ([]byte, error) {
|
||||
txID := makeTransactionID()
|
||||
request := []byte{}
|
||||
|
||||
request = append(request, initialConnectionID...)
|
||||
request = append(request, connectAction...)
|
||||
request = append(request, txID...)
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := sendRequest(sock, request, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n != 16 {
|
||||
return nil, fmt.Errorf("packet length mismatch: %d != 16", n)
|
||||
}
|
||||
|
||||
if !bytes.Equal(response[4:8], txID) {
|
||||
return nil, fmt.Errorf("transaction ID mismatch")
|
||||
}
|
||||
|
||||
if !bytes.Equal(response[0:4], connectAction) {
|
||||
return nil, fmt.Errorf("action mismatch")
|
||||
}
|
||||
|
||||
return response[8:16], nil
|
||||
}
|
||||
|
||||
func TestRequestConnectionID(t *testing.T) {
|
||||
srv, done, err := setupTracker(&config.DefaultConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, sock, err := setupSocket()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err = requestConnectionID(sock); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv.Stop()
|
||||
<-done
|
||||
}
|
97
udp/writer.go
Normal file
97
udp/writer.go
Normal file
|
@ -0,0 +1,97 @@
|
|||
// Copyright 2015 The Chihaya Authors. All rights reserved.
|
||||
// Use of this source code is governed by the BSD 2-Clause license,
|
||||
// which can be found in the LICENSE file.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"time"
|
||||
|
||||
"github.com/chihaya/chihaya/tracker/models"
|
||||
)
|
||||
|
||||
// Writer implements the tracker.Writer interface for the UDP protocol.
|
||||
type Writer struct {
|
||||
buf *bytes.Buffer
|
||||
|
||||
connectionID []byte
|
||||
transactionID []byte
|
||||
}
|
||||
|
||||
// WriteError writes the failure reason as a null-terminated string.
|
||||
func (w *Writer) WriteError(err error) error {
|
||||
w.writeHeader(errorActionID)
|
||||
w.buf.WriteString(err.Error())
|
||||
w.buf.WriteRune('\000')
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteAnnounce encodes an announce response by selecting the proper announce
|
||||
// format based on the BitTorrent spec.
|
||||
func (w *Writer) WriteAnnounce(resp *models.AnnounceResponse) (err error) {
|
||||
if resp.Announce.HasIPv6() {
|
||||
err = w.WriteAnnounceIPv6(resp)
|
||||
} else {
|
||||
err = w.WriteAnnounceIPv4(resp)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// WriteAnnounceIPv6 encodes an announce response according to BEP45.
|
||||
func (w *Writer) WriteAnnounceIPv6(resp *models.AnnounceResponse) error {
|
||||
w.writeHeader(announceDualStackActionID)
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(resp.Interval/time.Second))
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(resp.Incomplete))
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(resp.Complete))
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(len(resp.IPv4Peers)))
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(len(resp.IPv6Peers)))
|
||||
|
||||
for _, peer := range resp.IPv4Peers {
|
||||
w.buf.Write(peer.IP)
|
||||
binary.Write(w.buf, binary.BigEndian, peer.Port)
|
||||
}
|
||||
|
||||
for _, peer := range resp.IPv6Peers {
|
||||
w.buf.Write(peer.IP)
|
||||
binary.Write(w.buf, binary.BigEndian, peer.Port)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteAnnounceIPv4 encodes an announce response according to BEP15.
|
||||
func (w *Writer) WriteAnnounceIPv4(resp *models.AnnounceResponse) error {
|
||||
w.writeHeader(announceActionID)
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(resp.Interval/time.Second))
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(resp.Incomplete))
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(resp.Complete))
|
||||
|
||||
for _, peer := range resp.IPv4Peers {
|
||||
w.buf.Write(peer.IP)
|
||||
binary.Write(w.buf, binary.BigEndian, peer.Port)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteScrape encodes a scrape response according to BEP15.
|
||||
func (w *Writer) WriteScrape(resp *models.ScrapeResponse) error {
|
||||
w.writeHeader(scrapeActionID)
|
||||
|
||||
for _, torrent := range resp.Files {
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Seeders.Len()))
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Snatches))
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Leechers.Len()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeHeader writes the action and transaction ID to the response.
|
||||
func (w *Writer) writeHeader(action uint32) {
|
||||
binary.Write(w.buf, binary.BigEndian, action)
|
||||
w.buf.Write(w.transactionID)
|
||||
}
|
Loading…
Reference in a new issue