add dht
This commit is contained in:
parent
09745cbdea
commit
1f26aeeb5c
22 changed files with 3869 additions and 0 deletions
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
/.idea
|
1
dht/.gitignore
vendored
Normal file
1
dht/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
.DS_Store
|
21
dht/LICENSE
Normal file
21
dht/LICENSE
Normal file
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 Dean Karn
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
87
dht/README.md
Normal file
87
dht/README.md
Normal file
|
@ -0,0 +1,87 @@
|
|||
![](https://raw.githubusercontent.com/shiyanhui/dht/master/doc/screen-shot.png)
|
||||
|
||||
See the video on the [Youtube](https://www.youtube.com/watch?v=AIpeQtw22kc).
|
||||
|
||||
[中文版README](https://github.com/shiyanhui/dht/blob/master/README_CN.md)
|
||||
|
||||
## Introduction
|
||||
|
||||
DHT implements the bittorrent DHT protocol in Go. Now it includes:
|
||||
|
||||
- [BEP-3 (part)](http://www.bittorrent.org/beps/bep_0003.html)
|
||||
- [BEP-5](http://www.bittorrent.org/beps/bep_0005.html)
|
||||
- [BEP-9](http://www.bittorrent.org/beps/bep_0009.html)
|
||||
- [BEP-10](http://www.bittorrent.org/beps/bep_0010.html)
|
||||
|
||||
It contains two modes, the standard mode and the crawling mode. The standard
|
||||
mode follows the BEPs, and you can use it as a standard dht server. The crawling
|
||||
mode aims to crawl as more metadata info as possiple. It doesn't follow the
|
||||
standard BEPs protocol. With the crawling mode, you can build another [BTDigg](http://btdigg.org/).
|
||||
|
||||
[bthub.io](http://bthub.io) is a BT search engine based on the crawling mode.
|
||||
|
||||
## Installation
|
||||
|
||||
go get github.com/shiyanhui/dht
|
||||
|
||||
## Example
|
||||
|
||||
Below is a simple spider. You can move [here](https://github.com/shiyanhui/dht/blob/master/sample)
|
||||
to see more samples.
|
||||
|
||||
```go
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/shiyanhui/dht"
|
||||
)
|
||||
|
||||
func main() {
|
||||
downloader := dht.NewWire(65535)
|
||||
go func() {
|
||||
// once we got the request result
|
||||
for resp := range downloader.Response() {
|
||||
fmt.Println(resp.InfoHash, resp.MetadataInfo)
|
||||
}
|
||||
}()
|
||||
go downloader.Run()
|
||||
|
||||
config := dht.NewCrawlConfig()
|
||||
config.OnAnnouncePeer = func(infoHash, ip string, port int) {
|
||||
// request to download the metadata info
|
||||
downloader.Request([]byte(infoHash), ip, port)
|
||||
}
|
||||
d := dht.New(config)
|
||||
|
||||
d.Run()
|
||||
}
|
||||
```
|
||||
|
||||
## Download
|
||||
|
||||
You can download the demo compiled binary file [here](https://github.com/shiyanhui/dht/files/407021/spider.zip).
|
||||
|
||||
## Note
|
||||
|
||||
- The default crawl mode configure costs about 300M RAM. Set **MaxNodes**
|
||||
and **BlackListMaxSize** to fit yourself.
|
||||
- Now it cant't run in LAN because of NAT.
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] NAT Traversal.
|
||||
- [ ] Implements the full BEP-3.
|
||||
- [ ] Optimization.
|
||||
|
||||
## FAQ
|
||||
|
||||
#### Why it is slow compared to other spiders ?
|
||||
|
||||
Well, maybe there are several reasons.
|
||||
|
||||
- DHT aims to implements the standard BitTorrent DHT protocol, not born for crawling the DHT network.
|
||||
- NAT Traversal issue. You run the crawler in a local network.
|
||||
- It will block ip which looks like bad and a good ip may be mis-judged.
|
||||
|
||||
## License
|
||||
|
||||
MIT, read more [here](https://github.com/shiyanhui/dht/blob/master/LICENSE)
|
78
dht/README_CN.md
Normal file
78
dht/README_CN.md
Normal file
|
@ -0,0 +1,78 @@
|
|||
![](https://raw.githubusercontent.com/shiyanhui/dht/master/doc/screen-shot.png)
|
||||
|
||||
在这个视频上你可以看到爬取效果[Youtube](https://www.youtube.com/watch?v=AIpeQtw22kc).
|
||||
|
||||
## Introduction
|
||||
|
||||
DHT实现了BitTorrent DHT协议,主要包括:
|
||||
|
||||
- [BEP-3 (部分)](http://www.bittorrent.org/beps/bep_0003.html)
|
||||
- [BEP-5](http://www.bittorrent.org/beps/bep_0005.html)
|
||||
- [BEP-9](http://www.bittorrent.org/beps/bep_0009.html)
|
||||
- [BEP-10](http://www.bittorrent.org/beps/bep_0010.html)
|
||||
|
||||
它包含两种模式,标准模式和爬虫模式。标准模式遵循DHT协议,你可以把它当做一个标准
|
||||
的DHT组件。爬虫模式是为了嗅探到更多torrent文件信息,它在某些方面不遵循DHT协议。
|
||||
基于爬虫模式,你可以打造你自己的[BTDigg](http://btdigg.org/)。
|
||||
|
||||
[bthub.io](http://bthub.io)是一个基于这个爬虫而建的BT搜索引擎,你可以把他当做
|
||||
BTDigg的替代品。
|
||||
|
||||
## Installation
|
||||
|
||||
go get github.com/shiyanhui/dht
|
||||
|
||||
## Example
|
||||
|
||||
下面是一个简单的爬虫例子,你可以到[这里](https://github.com/shiyanhui/dht/blob/master/sample)看完整的Demo。
|
||||
|
||||
```go
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/shiyanhui/dht"
|
||||
)
|
||||
|
||||
func main() {
|
||||
downloader := dht.NewWire(65536)
|
||||
go func() {
|
||||
// once we got the request result
|
||||
for resp := range downloader.Response() {
|
||||
fmt.Println(resp.InfoHash, resp.MetadataInfo)
|
||||
}
|
||||
}()
|
||||
go downloader.Run()
|
||||
|
||||
config := dht.NewCrawlConfig()
|
||||
config.OnAnnouncePeer = func(infoHash, ip string, port int) {
|
||||
// request to download the metadata info
|
||||
downloader.Request([]byte(infoHash), ip, port)
|
||||
}
|
||||
d := dht.New(config)
|
||||
|
||||
d.Run()
|
||||
}
|
||||
```
|
||||
|
||||
## Download
|
||||
|
||||
这个是已经编译好的Demo二进制文件,你可以到这里[下载](https://github.com/shiyanhui/dht/files/407021/spider.zip)。
|
||||
|
||||
## 注意
|
||||
|
||||
- 默认的爬虫配置需要300M左右内存,你可以根据你的服务器内存大小调整MaxNodes和
|
||||
BlackListMaxSize
|
||||
- 目前还不能穿透NAT,因此还不能在局域网运行
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] NAT穿透,在局域网内也能够运行
|
||||
- [ ] 完整地实现BEP-3,这样不但能够下载种子,也能够下载资源
|
||||
- [ ] 优化
|
||||
|
||||
## Blog
|
||||
|
||||
你可以在[这里](https://github.com/shiyanhui/dht/wiki)看到DHT Spider教程。
|
||||
|
||||
## License
|
||||
|
||||
[MIT](https://github.com/shiyanhui/dht/blob/master/LICENSE)
|
263
dht/bencode.go
Normal file
263
dht/bencode.go
Normal file
|
@ -0,0 +1,263 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// find returns the index of first target in data starting from `start`.
|
||||
// It returns -1 if target not found.
|
||||
func find(data []byte, start int, target rune) (index int) {
|
||||
index = bytes.IndexRune(data[start:], target)
|
||||
if index != -1 {
|
||||
return index + start
|
||||
}
|
||||
return index
|
||||
}
|
||||
|
||||
// DecodeString decodes a string in the data. It returns a tuple
|
||||
// (decoded result, the end position, error).
|
||||
func DecodeString(data []byte, start int) (
|
||||
result interface{}, index int, err error) {
|
||||
|
||||
if start >= len(data) || data[start] < '0' || data[start] > '9' {
|
||||
err = errors.New("invalid string bencode")
|
||||
return
|
||||
}
|
||||
|
||||
i := find(data, start, ':')
|
||||
if i == -1 {
|
||||
err = errors.New("':' not found when decode string")
|
||||
return
|
||||
}
|
||||
|
||||
length, err := strconv.Atoi(string(data[start:i]))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if length < 0 {
|
||||
err = errors.New("invalid length of string")
|
||||
return
|
||||
}
|
||||
|
||||
index = i + 1 + length
|
||||
|
||||
if index > len(data) || index < i+1 {
|
||||
err = errors.New("out of range")
|
||||
return
|
||||
}
|
||||
|
||||
result = string(data[i+1 : index])
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeInt decodes int value in the data.
|
||||
func DecodeInt(data []byte, start int) (
|
||||
result interface{}, index int, err error) {
|
||||
|
||||
if start >= len(data) || data[start] != 'i' {
|
||||
err = errors.New("invalid int bencode")
|
||||
return
|
||||
}
|
||||
|
||||
index = find(data, start+1, 'e')
|
||||
|
||||
if index == -1 {
|
||||
err = errors.New("':' not found when decode int")
|
||||
return
|
||||
}
|
||||
|
||||
result, err = strconv.Atoi(string(data[start+1 : index]))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
index++
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// decodeItem decodes an item of dict or list.
|
||||
func decodeItem(data []byte, i int) (
|
||||
result interface{}, index int, err error) {
|
||||
|
||||
var decodeFunc = []func([]byte, int) (interface{}, int, error){
|
||||
DecodeString, DecodeInt, DecodeList, DecodeDict,
|
||||
}
|
||||
|
||||
for _, f := range decodeFunc {
|
||||
result, index, err = f(data, i)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = errors.New("invalid bencode when decode item")
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeList decodes a list value.
|
||||
func DecodeList(data []byte, start int) (
|
||||
result interface{}, index int, err error) {
|
||||
|
||||
if start >= len(data) || data[start] != 'l' {
|
||||
err = errors.New("invalid list bencode")
|
||||
return
|
||||
}
|
||||
|
||||
var item interface{}
|
||||
r := make([]interface{}, 0, 8)
|
||||
|
||||
index = start + 1
|
||||
for index < len(data) {
|
||||
char, _ := utf8.DecodeRune(data[index:])
|
||||
if char == 'e' {
|
||||
break
|
||||
}
|
||||
|
||||
item, index, err = decodeItem(data, index)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r = append(r, item)
|
||||
}
|
||||
|
||||
if index == len(data) {
|
||||
err = errors.New("'e' not found when decode list")
|
||||
return
|
||||
}
|
||||
index++
|
||||
|
||||
result = r
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeDict decodes a map value.
|
||||
func DecodeDict(data []byte, start int) (
|
||||
result interface{}, index int, err error) {
|
||||
|
||||
if start >= len(data) || data[start] != 'd' {
|
||||
err = errors.New("invalid dict bencode")
|
||||
return
|
||||
}
|
||||
|
||||
var item, key interface{}
|
||||
r := make(map[string]interface{})
|
||||
|
||||
index = start + 1
|
||||
for index < len(data) {
|
||||
char, _ := utf8.DecodeRune(data[index:])
|
||||
if char == 'e' {
|
||||
break
|
||||
}
|
||||
|
||||
if !unicode.IsDigit(char) {
|
||||
err = errors.New("invalid dict bencode")
|
||||
return
|
||||
}
|
||||
|
||||
key, index, err = DecodeString(data, index)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if index >= len(data) {
|
||||
err = errors.New("out of range")
|
||||
return
|
||||
}
|
||||
|
||||
item, index, err = decodeItem(data, index)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
r[key.(string)] = item
|
||||
}
|
||||
|
||||
if index == len(data) {
|
||||
err = errors.New("'e' not found when decode dict")
|
||||
return
|
||||
}
|
||||
index++
|
||||
|
||||
result = r
|
||||
return
|
||||
}
|
||||
|
||||
// Decode decodes a bencoded string to string, int, list or map.
|
||||
func Decode(data []byte) (result interface{}, err error) {
|
||||
result, _, err = decodeItem(data, 0)
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeString encodes a string value.
|
||||
func EncodeString(data string) string {
|
||||
return strings.Join([]string{strconv.Itoa(len(data)), data}, ":")
|
||||
}
|
||||
|
||||
// EncodeInt encodes a int value.
|
||||
func EncodeInt(data int) string {
|
||||
return strings.Join([]string{"i", strconv.Itoa(data), "e"}, "")
|
||||
}
|
||||
|
||||
// EncodeItem encodes an item of dict or list.
|
||||
func encodeItem(data interface{}) (item string) {
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
item = EncodeString(v)
|
||||
case int:
|
||||
item = EncodeInt(v)
|
||||
case []interface{}:
|
||||
item = EncodeList(v)
|
||||
case map[string]interface{}:
|
||||
item = EncodeDict(v)
|
||||
default:
|
||||
panic("invalid type when encode item")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeList encodes a list value.
|
||||
func EncodeList(data []interface{}) string {
|
||||
result := make([]string, len(data))
|
||||
|
||||
for i, item := range data {
|
||||
result[i] = encodeItem(item)
|
||||
}
|
||||
|
||||
return strings.Join([]string{"l", strings.Join(result, ""), "e"}, "")
|
||||
}
|
||||
|
||||
// EncodeDict encodes a dict value.
|
||||
func EncodeDict(data map[string]interface{}) string {
|
||||
result, i := make([]string, len(data)), 0
|
||||
|
||||
for key, val := range data {
|
||||
result[i] = strings.Join(
|
||||
[]string{EncodeString(key), encodeItem(val)},
|
||||
"")
|
||||
i++
|
||||
}
|
||||
|
||||
return strings.Join([]string{"d", strings.Join(result, ""), "e"}, "")
|
||||
}
|
||||
|
||||
// Encode encodes a string, int, dict or list value to a bencoded string.
|
||||
func Encode(data interface{}) string {
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
return EncodeString(v)
|
||||
case int:
|
||||
return EncodeInt(v)
|
||||
case []interface{}:
|
||||
return EncodeList(v)
|
||||
case map[string]interface{}:
|
||||
return EncodeDict(v)
|
||||
default:
|
||||
panic("invalid type when encode")
|
||||
}
|
||||
}
|
159
dht/bencode_test.go
Normal file
159
dht/bencode_test.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
out string
|
||||
}{
|
||||
{"0:", ""},
|
||||
{"1:a", "a"},
|
||||
{"5:hello", "hello"},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
if out, err := Decode([]byte(c.in)); err != nil || out != c.out {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeInt(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
out int
|
||||
}{
|
||||
{"i123e:", 123},
|
||||
{"i0e", 0},
|
||||
{"i-1e", -1},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
if out, err := Decode([]byte(c.in)); err != nil || out != c.out {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeList(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
out []interface{}
|
||||
}{
|
||||
{"li123ei-1ee", []interface{}{123, -1}},
|
||||
{"l5:helloe", []interface{}{"hello"}},
|
||||
{"ld5:hello5:worldee", []interface{}{map[string]interface{}{"hello": "world"}}},
|
||||
{"lli1ei2eee", []interface{}{[]interface{}{1, 2}}},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
v, err := Decode([]byte(c.in))
|
||||
if err != nil {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
out := v.([]interface{})
|
||||
|
||||
switch i {
|
||||
case 0, 1:
|
||||
for j, item := range out {
|
||||
if item != c.out[j] {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
case 2:
|
||||
if len(out) != 1 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
o := out[0].(map[string]interface{})
|
||||
cout := c.out[0].(map[string]interface{})
|
||||
|
||||
for k, v := range o {
|
||||
if cv, ok := cout[k]; !ok || v != cv {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
case 3:
|
||||
if len(out) != 1 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
o := out[0].([]interface{})
|
||||
cout := c.out[0].([]interface{})
|
||||
|
||||
for j, item := range o {
|
||||
if item != cout[j] {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeDict(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
out map[string]interface{}
|
||||
}{
|
||||
{"d5:helloi100ee", map[string]interface{}{"hello": 100}},
|
||||
{"d3:foo3:bare", map[string]interface{}{"foo": "bar"}},
|
||||
{"d1:ad3:foo3:baree", map[string]interface{}{"a": map[string]interface{}{"foo": "bar"}}},
|
||||
{"d4:listli1eee", map[string]interface{}{"list": []interface{}{1}}},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
v, err := Decode([]byte(c.in))
|
||||
if err != nil {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
out := v.(map[string]interface{})
|
||||
|
||||
switch i {
|
||||
case 0, 1:
|
||||
for k, v := range out {
|
||||
if cv, ok := c.out[k]; !ok || v != cv {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
case 2:
|
||||
if len(out) != 1 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
v, ok := out["a"]
|
||||
if !ok {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
cout := c.out["a"].(map[string]interface{})
|
||||
|
||||
for k, v := range v.(map[string]interface{}) {
|
||||
if cv, ok := cout[k]; !ok || v != cv {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
case 3:
|
||||
if len(out) != 1 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
v, ok := out["list"]
|
||||
if !ok {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
cout := c.out["list"].([]interface{})
|
||||
|
||||
for j, v := range v.([]interface{}) {
|
||||
if v != cout[j] {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
163
dht/bitmap.go
Normal file
163
dht/bitmap.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// bitmap represents a bit array.
|
||||
type bitmap struct {
|
||||
Size int
|
||||
data []byte
|
||||
}
|
||||
|
||||
// newBitmap returns a size-length bitmap pointer.
|
||||
func newBitmap(size int) *bitmap {
|
||||
div, mod := size/8, size%8
|
||||
if mod > 0 {
|
||||
div++
|
||||
}
|
||||
return &bitmap{size, make([]byte, div)}
|
||||
}
|
||||
|
||||
// newBitmapFrom returns a new copyed bitmap pointer which
|
||||
// newBitmap.data = other.data[:size].
|
||||
func newBitmapFrom(other *bitmap, size int) *bitmap {
|
||||
bitmap := newBitmap(size)
|
||||
|
||||
if size > other.Size {
|
||||
size = other.Size
|
||||
}
|
||||
|
||||
div := size / 8
|
||||
|
||||
for i := 0; i < div; i++ {
|
||||
bitmap.data[i] = other.data[i]
|
||||
}
|
||||
|
||||
for i := div * 8; i < size; i++ {
|
||||
if other.Bit(i) == 1 {
|
||||
bitmap.Set(i)
|
||||
}
|
||||
}
|
||||
|
||||
return bitmap
|
||||
}
|
||||
|
||||
// newBitmapFromBytes returns a bitmap pointer created from a byte array.
|
||||
func newBitmapFromBytes(data []byte) *bitmap {
|
||||
bitmap := newBitmap(len(data) * 8)
|
||||
copy(bitmap.data, data)
|
||||
return bitmap
|
||||
}
|
||||
|
||||
// newBitmapFromString returns a bitmap pointer created from a string.
|
||||
func newBitmapFromString(data string) *bitmap {
|
||||
return newBitmapFromBytes([]byte(data))
|
||||
}
|
||||
|
||||
// Bit returns the bit at index.
|
||||
func (bitmap *bitmap) Bit(index int) int {
|
||||
if index >= bitmap.Size {
|
||||
panic("index out of range")
|
||||
}
|
||||
|
||||
div, mod := index/8, index%8
|
||||
return int((uint(bitmap.data[div]) & (1 << uint(7-mod))) >> uint(7-mod))
|
||||
}
|
||||
|
||||
// set sets the bit at index `index`. If bit is true, set 1, otherwise set 0.
|
||||
func (bitmap *bitmap) set(index int, bit int) {
|
||||
if index >= bitmap.Size {
|
||||
panic("index out of range")
|
||||
}
|
||||
|
||||
div, mod := index/8, index%8
|
||||
shift := byte(1 << uint(7-mod))
|
||||
|
||||
bitmap.data[div] &= ^shift
|
||||
if bit > 0 {
|
||||
bitmap.data[div] |= shift
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the bit at idnex to 1.
|
||||
func (bitmap *bitmap) Set(index int) {
|
||||
bitmap.set(index, 1)
|
||||
}
|
||||
|
||||
// Unset sets the bit at idnex to 0.
|
||||
func (bitmap *bitmap) Unset(index int) {
|
||||
bitmap.set(index, 0)
|
||||
}
|
||||
|
||||
// Compare compares the prefixLen-prefix of two bitmap.
|
||||
// - If bitmap.data[:prefixLen] < other.data[:prefixLen], return -1.
|
||||
// - If bitmap.data[:prefixLen] > other.data[:prefixLen], return 1.
|
||||
// - Otherwise return 0.
|
||||
func (bitmap *bitmap) Compare(other *bitmap, prefixLen int) int {
|
||||
if prefixLen > bitmap.Size || prefixLen > other.Size {
|
||||
panic("index out of range")
|
||||
}
|
||||
|
||||
div, mod := prefixLen/8, prefixLen%8
|
||||
for i := 0; i < div; i++ {
|
||||
if bitmap.data[i] > other.data[i] {
|
||||
return 1
|
||||
} else if bitmap.data[i] < other.data[i] {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
for i := div * 8; i < div*8+mod; i++ {
|
||||
bit1, bit2 := bitmap.Bit(i), other.Bit(i)
|
||||
if bit1 > bit2 {
|
||||
return 1
|
||||
} else if bit1 < bit2 {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// Xor returns the xor value of two bitmap.
|
||||
func (bitmap *bitmap) Xor(other *bitmap) *bitmap {
|
||||
if bitmap.Size != other.Size {
|
||||
panic("size not the same")
|
||||
}
|
||||
|
||||
distance := newBitmap(bitmap.Size)
|
||||
div, mod := distance.Size/8, distance.Size%8
|
||||
|
||||
for i := 0; i < div; i++ {
|
||||
distance.data[i] = bitmap.data[i] ^ other.data[i]
|
||||
}
|
||||
|
||||
for i := div * 8; i < div*8+mod; i++ {
|
||||
distance.set(i, bitmap.Bit(i)^other.Bit(i))
|
||||
}
|
||||
|
||||
return distance
|
||||
}
|
||||
|
||||
// String returns the bit sequence string of the bitmap.
|
||||
func (bitmap *bitmap) String() string {
|
||||
div, mod := bitmap.Size/8, bitmap.Size%8
|
||||
buff := make([]string, div+mod)
|
||||
|
||||
for i := 0; i < div; i++ {
|
||||
buff[i] = fmt.Sprintf("%08b", bitmap.data[i])
|
||||
}
|
||||
|
||||
for i := div; i < div+mod; i++ {
|
||||
buff[i] = fmt.Sprintf("%1b", bitmap.Bit(div*8+(i-div)))
|
||||
}
|
||||
|
||||
return strings.Join(buff, "")
|
||||
}
|
||||
|
||||
// RawString returns the string value of bitmap.data.
|
||||
func (bitmap *bitmap) RawString() string {
|
||||
return string(bitmap.data)
|
||||
}
|
69
dht/bitmap_test.go
Normal file
69
dht/bitmap_test.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBitmap(t *testing.T) {
|
||||
a := newBitmap(10)
|
||||
b := newBitmapFrom(a, 10)
|
||||
c := newBitmapFromBytes([]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57})
|
||||
d := newBitmapFromString("0123456789")
|
||||
e := newBitmap(10)
|
||||
|
||||
// Bit
|
||||
for i := 0; i < a.Size; i++ {
|
||||
if a.Bit(i) != 0 {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
// Compare
|
||||
if c.Compare(d, d.Size) != 0 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// RawString
|
||||
if c.RawString() != d.RawString() || c.RawString() != "0123456789" {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Set
|
||||
b.Set(5)
|
||||
if b.Bit(5) != 1 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Unset
|
||||
b.Unset(5)
|
||||
if b.Bit(5) == 1 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// String
|
||||
if e.String() != "0000000000" {
|
||||
t.Fail()
|
||||
}
|
||||
e.Set(9)
|
||||
if e.String() != "0000000001" {
|
||||
t.Fail()
|
||||
}
|
||||
e.Set(2)
|
||||
if e.String() != "0010000001" {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
a.Set(0)
|
||||
a.Set(5)
|
||||
a.Set(8)
|
||||
if a.String() != "1000010010" {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Xor
|
||||
b.Set(5)
|
||||
b.Set(9)
|
||||
if a.Xor(b).String() != "1000000011" {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
92
dht/blacklist.go
Normal file
92
dht/blacklist.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// blockedItem represents a blocked node.
|
||||
type blockedItem struct {
|
||||
ip string
|
||||
port int
|
||||
createTime time.Time
|
||||
}
|
||||
|
||||
// blackList manages the blocked nodes including which sends bad information
|
||||
// and can't ping out.
|
||||
type blackList struct {
|
||||
list *syncedMap
|
||||
maxSize int
|
||||
expiredAfter time.Duration
|
||||
}
|
||||
|
||||
// newBlackList returns a blackList pointer.
|
||||
func newBlackList(size int) *blackList {
|
||||
return &blackList{
|
||||
list: newSyncedMap(),
|
||||
maxSize: size,
|
||||
expiredAfter: time.Hour * 1,
|
||||
}
|
||||
}
|
||||
|
||||
// genKey returns a key. If port is less than 0, the key wil be ip. Ohterwise
|
||||
// it will be `ip:port` format.
|
||||
func (bl *blackList) genKey(ip string, port int) string {
|
||||
key := ip
|
||||
if port >= 0 {
|
||||
key = genAddress(ip, port)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// insert adds a blocked item to the blacklist.
|
||||
func (bl *blackList) insert(ip string, port int) {
|
||||
if bl.list.Len() >= bl.maxSize {
|
||||
return
|
||||
}
|
||||
|
||||
bl.list.Set(bl.genKey(ip, port), &blockedItem{
|
||||
ip: ip,
|
||||
port: port,
|
||||
createTime: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// delete removes blocked item form the blackList.
|
||||
func (bl *blackList) delete(ip string, port int) {
|
||||
bl.list.Delete(bl.genKey(ip, port))
|
||||
}
|
||||
|
||||
// validate checks whether ip-port pair is in the block nodes list.
|
||||
func (bl *blackList) in(ip string, port int) bool {
|
||||
if _, ok := bl.list.Get(ip); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
key := bl.genKey(ip, port)
|
||||
|
||||
v, ok := bl.list.Get(key)
|
||||
if ok {
|
||||
if time.Now().Sub(v.(*blockedItem).createTime) < bl.expiredAfter {
|
||||
return true
|
||||
}
|
||||
bl.list.Delete(key)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// clear cleans the expired items every 10 minutes.
|
||||
func (bl *blackList) clear() {
|
||||
for _ = range time.Tick(time.Minute * 10) {
|
||||
keys := make([]interface{}, 0, 100)
|
||||
|
||||
for item := range bl.list.Iter() {
|
||||
if time.Now().Sub(
|
||||
item.val.(*blockedItem).createTime) > bl.expiredAfter {
|
||||
|
||||
keys = append(keys, item.key)
|
||||
}
|
||||
}
|
||||
|
||||
bl.list.DeleteMulti(keys)
|
||||
}
|
||||
}
|
57
dht/blacklist_test.go
Normal file
57
dht/blacklist_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var blacklist = newBlackList(256)
|
||||
|
||||
func TestGenKey(t *testing.T) {
|
||||
cases := []struct {
|
||||
in struct {
|
||||
ip string
|
||||
port int
|
||||
}
|
||||
out string
|
||||
}{
|
||||
{struct {
|
||||
ip string
|
||||
port int
|
||||
}{"0.0.0.0", -1}, "0.0.0.0"},
|
||||
{struct {
|
||||
ip string
|
||||
port int
|
||||
}{"1.1.1.1", 8080}, "1.1.1.1:8080"},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
if blacklist.genKey(c.in.ip, c.in.port) != c.out {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlackList(t *testing.T) {
|
||||
address := []struct {
|
||||
ip string
|
||||
port int
|
||||
}{
|
||||
{"0.0.0.0", -1},
|
||||
{"1.1.1.1", 8080},
|
||||
{"2.2.2.2", 8081},
|
||||
}
|
||||
|
||||
for _, addr := range address {
|
||||
blacklist.insert(addr.ip, addr.port)
|
||||
if !blacklist.in(addr.ip, addr.port) {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
blacklist.delete(addr.ip, addr.port)
|
||||
if blacklist.in(addr.ip, addr.port) {
|
||||
fmt.Println(addr.ip)
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
289
dht/container.go
Normal file
289
dht/container.go
Normal file
|
@ -0,0 +1,289 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type mapItem struct {
|
||||
key interface{}
|
||||
val interface{}
|
||||
}
|
||||
|
||||
// syncedMap represents a goroutine-safe map.
|
||||
type syncedMap struct {
|
||||
*sync.RWMutex
|
||||
data map[interface{}]interface{}
|
||||
}
|
||||
|
||||
// newSyncedMap returns a syncedMap pointer.
|
||||
func newSyncedMap() *syncedMap {
|
||||
return &syncedMap{
|
||||
RWMutex: &sync.RWMutex{},
|
||||
data: make(map[interface{}]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the value mapped to key.
|
||||
func (smap *syncedMap) Get(key interface{}) (val interface{}, ok bool) {
|
||||
smap.RLock()
|
||||
defer smap.RUnlock()
|
||||
|
||||
val, ok = smap.data[key]
|
||||
return
|
||||
}
|
||||
|
||||
// Has returns whether the syncedMap contains the key.
|
||||
func (smap *syncedMap) Has(key interface{}) bool {
|
||||
_, ok := smap.Get(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Set sets pair {key: val}.
|
||||
func (smap *syncedMap) Set(key interface{}, val interface{}) {
|
||||
smap.Lock()
|
||||
defer smap.Unlock()
|
||||
|
||||
smap.data[key] = val
|
||||
}
|
||||
|
||||
// Delete deletes the key in the map.
|
||||
func (smap *syncedMap) Delete(key interface{}) {
|
||||
smap.Lock()
|
||||
defer smap.Unlock()
|
||||
|
||||
delete(smap.data, key)
|
||||
}
|
||||
|
||||
// DeleteMulti deletes keys in batch.
|
||||
func (smap *syncedMap) DeleteMulti(keys []interface{}) {
|
||||
smap.Lock()
|
||||
defer smap.Unlock()
|
||||
|
||||
for _, key := range keys {
|
||||
delete(smap.data, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear resets the data.
|
||||
func (smap *syncedMap) Clear() {
|
||||
smap.Lock()
|
||||
defer smap.Unlock()
|
||||
|
||||
smap.data = make(map[interface{}]interface{})
|
||||
}
|
||||
|
||||
// Iter returns a chan which output all items.
|
||||
func (smap *syncedMap) Iter() <-chan mapItem {
|
||||
ch := make(chan mapItem)
|
||||
go func() {
|
||||
smap.RLock()
|
||||
for key, val := range smap.data {
|
||||
ch <- mapItem{
|
||||
key: key,
|
||||
val: val,
|
||||
}
|
||||
}
|
||||
smap.RUnlock()
|
||||
close(ch)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
// Len returns the length of syncedMap.
|
||||
func (smap *syncedMap) Len() int {
|
||||
smap.RLock()
|
||||
defer smap.RUnlock()
|
||||
|
||||
return len(smap.data)
|
||||
}
|
||||
|
||||
// syncedList represents a goroutine-safe list.
|
||||
type syncedList struct {
|
||||
*sync.RWMutex
|
||||
queue *list.List
|
||||
}
|
||||
|
||||
// newSyncedList returns a syncedList pointer.
|
||||
func newSyncedList() *syncedList {
|
||||
return &syncedList{
|
||||
RWMutex: &sync.RWMutex{},
|
||||
queue: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// Front returns the first element of slist.
|
||||
func (slist *syncedList) Front() *list.Element {
|
||||
slist.RLock()
|
||||
defer slist.RUnlock()
|
||||
|
||||
return slist.queue.Front()
|
||||
}
|
||||
|
||||
// Back returns the last element of slist.
|
||||
func (slist *syncedList) Back() *list.Element {
|
||||
slist.RLock()
|
||||
defer slist.RUnlock()
|
||||
|
||||
return slist.queue.Back()
|
||||
}
|
||||
|
||||
// PushFront pushs an element to the head of slist.
|
||||
func (slist *syncedList) PushFront(v interface{}) *list.Element {
|
||||
slist.Lock()
|
||||
defer slist.Unlock()
|
||||
|
||||
return slist.queue.PushFront(v)
|
||||
}
|
||||
|
||||
// PushBack pushs an element to the tail of slist.
|
||||
func (slist *syncedList) PushBack(v interface{}) *list.Element {
|
||||
slist.Lock()
|
||||
defer slist.Unlock()
|
||||
|
||||
return slist.queue.PushBack(v)
|
||||
}
|
||||
|
||||
// InsertBefore inserts v before mark.
|
||||
func (slist *syncedList) InsertBefore(
|
||||
v interface{}, mark *list.Element) *list.Element {
|
||||
|
||||
slist.Lock()
|
||||
defer slist.Unlock()
|
||||
|
||||
return slist.queue.InsertBefore(v, mark)
|
||||
}
|
||||
|
||||
// InsertAfter inserts v after mark.
|
||||
func (slist *syncedList) InsertAfter(
|
||||
v interface{}, mark *list.Element) *list.Element {
|
||||
|
||||
slist.Lock()
|
||||
defer slist.Unlock()
|
||||
|
||||
return slist.queue.InsertAfter(v, mark)
|
||||
}
|
||||
|
||||
// Remove removes e from the slist.
|
||||
func (slist *syncedList) Remove(e *list.Element) interface{} {
|
||||
slist.Lock()
|
||||
defer slist.Unlock()
|
||||
|
||||
return slist.queue.Remove(e)
|
||||
}
|
||||
|
||||
// Clear resets the list queue.
|
||||
func (slist *syncedList) Clear() {
|
||||
slist.Lock()
|
||||
defer slist.Unlock()
|
||||
|
||||
slist.queue.Init()
|
||||
}
|
||||
|
||||
// Len returns length of the slist.
|
||||
func (slist *syncedList) Len() int {
|
||||
slist.RLock()
|
||||
defer slist.RUnlock()
|
||||
|
||||
return slist.queue.Len()
|
||||
}
|
||||
|
||||
// Iter returns a chan which output all elements.
|
||||
func (slist *syncedList) Iter() <-chan *list.Element {
|
||||
ch := make(chan *list.Element)
|
||||
go func() {
|
||||
slist.RLock()
|
||||
for e := slist.queue.Front(); e != nil; e = e.Next() {
|
||||
ch <- e
|
||||
}
|
||||
slist.RUnlock()
|
||||
close(ch)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
// KeyedDeque represents a keyed deque.
|
||||
type keyedDeque struct {
|
||||
*sync.RWMutex
|
||||
*syncedList
|
||||
index map[interface{}]*list.Element
|
||||
invertedIndex map[*list.Element]interface{}
|
||||
}
|
||||
|
||||
// newKeyedDeque returns a newKeyedDeque pointer.
|
||||
func newKeyedDeque() *keyedDeque {
|
||||
return &keyedDeque{
|
||||
RWMutex: &sync.RWMutex{},
|
||||
syncedList: newSyncedList(),
|
||||
index: make(map[interface{}]*list.Element),
|
||||
invertedIndex: make(map[*list.Element]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Push pushs a keyed-value to the end of deque.
|
||||
func (deque *keyedDeque) Push(key interface{}, val interface{}) {
|
||||
deque.Lock()
|
||||
defer deque.Unlock()
|
||||
|
||||
if e, ok := deque.index[key]; ok {
|
||||
deque.syncedList.Remove(e)
|
||||
}
|
||||
deque.index[key] = deque.syncedList.PushBack(val)
|
||||
deque.invertedIndex[deque.index[key]] = key
|
||||
}
|
||||
|
||||
// Get returns the keyed value.
|
||||
func (deque *keyedDeque) Get(key interface{}) (*list.Element, bool) {
|
||||
deque.RLock()
|
||||
defer deque.RUnlock()
|
||||
|
||||
v, ok := deque.index[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Has returns whether key already exists.
|
||||
func (deque *keyedDeque) HasKey(key interface{}) bool {
|
||||
_, ok := deque.Get(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Delete deletes a value named key.
|
||||
func (deque *keyedDeque) Delete(key interface{}) (v interface{}) {
|
||||
deque.RLock()
|
||||
e, ok := deque.index[key]
|
||||
deque.RUnlock()
|
||||
|
||||
deque.Lock()
|
||||
defer deque.Unlock()
|
||||
|
||||
if ok {
|
||||
v = deque.syncedList.Remove(e)
|
||||
delete(deque.index, key)
|
||||
delete(deque.invertedIndex, e)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Removes overwrites list.List.Remove.
|
||||
func (deque *keyedDeque) Remove(e *list.Element) (v interface{}) {
|
||||
deque.RLock()
|
||||
key, ok := deque.invertedIndex[e]
|
||||
deque.RUnlock()
|
||||
|
||||
if ok {
|
||||
v = deque.Delete(key)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Clear resets the deque.
|
||||
func (deque *keyedDeque) Clear() {
|
||||
deque.Lock()
|
||||
defer deque.Unlock()
|
||||
|
||||
deque.syncedList.Clear()
|
||||
deque.index = make(map[interface{}]*list.Element)
|
||||
deque.invertedIndex = make(map[*list.Element]interface{})
|
||||
}
|
196
dht/container_test.go
Normal file
196
dht/container_test.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSyncedMap(t *testing.T) {
|
||||
cases := []mapItem{
|
||||
{"a", 0},
|
||||
{"b", 1},
|
||||
{"c", 2},
|
||||
}
|
||||
|
||||
sm := newSyncedMap()
|
||||
|
||||
set := func() {
|
||||
group := sync.WaitGroup{}
|
||||
for _, item := range cases {
|
||||
group.Add(1)
|
||||
go func(item mapItem) {
|
||||
sm.Set(item.key, item.val)
|
||||
group.Done()
|
||||
}(item)
|
||||
}
|
||||
group.Wait()
|
||||
}
|
||||
|
||||
isEmpty := func() {
|
||||
if sm.Len() != 0 {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
// Set
|
||||
set()
|
||||
if sm.Len() != len(cases) {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
Loop:
|
||||
// Iter
|
||||
for item := range sm.Iter() {
|
||||
for _, c := range cases {
|
||||
if item.key == c.key && item.val == c.val {
|
||||
continue Loop
|
||||
}
|
||||
}
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Get, Delete, Has
|
||||
for _, item := range cases {
|
||||
val, ok := sm.Get(item.key)
|
||||
if !ok || val != item.val {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
sm.Delete(item.key)
|
||||
if sm.Has(item.key) {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
isEmpty()
|
||||
|
||||
// DeleteMulti
|
||||
set()
|
||||
sm.DeleteMulti([]interface{}{"a", "b", "c"})
|
||||
isEmpty()
|
||||
|
||||
// Clear
|
||||
set()
|
||||
sm.Clear()
|
||||
isEmpty()
|
||||
}
|
||||
|
||||
func TestSyncedList(t *testing.T) {
|
||||
sl := newSyncedList()
|
||||
|
||||
insert := func() {
|
||||
for i := 0; i < 10; i++ {
|
||||
sl.PushBack(i)
|
||||
}
|
||||
}
|
||||
|
||||
isEmpty := func() {
|
||||
if sl.Len() != 0 {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
// PushBack
|
||||
insert()
|
||||
|
||||
// Len
|
||||
if sl.Len() != 10 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Iter
|
||||
i := 0
|
||||
for item := range sl.Iter() {
|
||||
if item.Value.(int) != i {
|
||||
t.Fail()
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
// Front
|
||||
if sl.Front().Value.(int) != 0 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Back
|
||||
if sl.Back().Value.(int) != 9 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Remove
|
||||
for i := 0; i < 10; i++ {
|
||||
if sl.Remove(sl.Front()).(int) != i {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
isEmpty()
|
||||
|
||||
// Clear
|
||||
insert()
|
||||
sl.Clear()
|
||||
isEmpty()
|
||||
}
|
||||
|
||||
func TestKeyedDeque(t *testing.T) {
|
||||
cases := []mapItem{
|
||||
{"a", 0},
|
||||
{"b", 1},
|
||||
{"c", 2},
|
||||
}
|
||||
|
||||
deque := newKeyedDeque()
|
||||
|
||||
insert := func() {
|
||||
for _, item := range cases {
|
||||
deque.Push(item.key, item.val)
|
||||
}
|
||||
}
|
||||
|
||||
isEmpty := func() {
|
||||
if deque.Len() != 0 {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
// Push
|
||||
insert()
|
||||
|
||||
// Len
|
||||
if deque.Len() != 3 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Iter
|
||||
i := 0
|
||||
for e := range deque.Iter() {
|
||||
if e.Value.(int) != i {
|
||||
t.Fail()
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
// HasKey, Get, Delete
|
||||
for _, item := range cases {
|
||||
if !deque.HasKey(item.key) {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
e, ok := deque.Get(item.key)
|
||||
if !ok || e.Value.(int) != item.val {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
if deque.Delete(item.key) != item.val {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
if deque.HasKey(item.key) {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
isEmpty()
|
||||
|
||||
// Clear
|
||||
insert()
|
||||
deque.Clear()
|
||||
isEmpty()
|
||||
}
|
296
dht/dht.go
Normal file
296
dht/dht.go
Normal file
|
@ -0,0 +1,296 @@
|
|||
// Package dht implements the bittorrent dht protocol. For more information
|
||||
// see http://www.bittorrent.org/beps/bep_0005.html.
|
||||
package dht
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// StandardMode follows the standard protocol
|
||||
StandardMode = iota
|
||||
// CrawlMode for crawling the dht network.
|
||||
CrawlMode
|
||||
)
|
||||
|
||||
// Config represents the configure of dht.
|
||||
type Config struct {
|
||||
// in mainline dht, k = 8
|
||||
K int
|
||||
// for crawling mode, we put all nodes in one bucket, so KBucketSize may
|
||||
// not be K
|
||||
KBucketSize int
|
||||
// candidates are udp, udp4, udp6
|
||||
Network string
|
||||
// format is `ip:port`
|
||||
Address string
|
||||
// the prime nodes through which we can join in dht network
|
||||
PrimeNodes []string
|
||||
// the kbucket expired duration
|
||||
KBucketExpiredAfter time.Duration
|
||||
// the node expired duration
|
||||
NodeExpriedAfter time.Duration
|
||||
// how long it checks whether the bucket is expired
|
||||
CheckKBucketPeriod time.Duration
|
||||
// peer token expired duration
|
||||
TokenExpiredAfter time.Duration
|
||||
// the max transaction id
|
||||
MaxTransactionCursor uint64
|
||||
// how many nodes routing table can hold
|
||||
MaxNodes int
|
||||
// callback when got get_peers request
|
||||
OnGetPeers func(string, string, int)
|
||||
// callback when got announce_peer request
|
||||
OnAnnouncePeer func(string, string, int)
|
||||
// blcoked ips
|
||||
BlockedIPs []string
|
||||
// blacklist size
|
||||
BlackListMaxSize int
|
||||
// StandardMode or CrawlMode
|
||||
Mode int
|
||||
// the times it tries when send fails
|
||||
Try int
|
||||
// the size of packet need to be dealt with
|
||||
PacketJobLimit int
|
||||
// the size of packet handler
|
||||
PacketWorkerLimit int
|
||||
// the nodes num to be fresh in a kbucket
|
||||
RefreshNodeNum int
|
||||
}
|
||||
|
||||
// NewStandardConfig returns a Config pointer with default values.
|
||||
func NewStandardConfig() *Config {
|
||||
return &Config{
|
||||
K: 8,
|
||||
KBucketSize: 8,
|
||||
Network: "udp4",
|
||||
Address: ":6881",
|
||||
PrimeNodes: []string{
|
||||
"router.bittorrent.com:6881",
|
||||
"router.utorrent.com:6881",
|
||||
"dht.transmissionbt.com:6881",
|
||||
},
|
||||
NodeExpriedAfter: time.Duration(time.Minute * 15),
|
||||
KBucketExpiredAfter: time.Duration(time.Minute * 15),
|
||||
CheckKBucketPeriod: time.Duration(time.Second * 30),
|
||||
TokenExpiredAfter: time.Duration(time.Minute * 10),
|
||||
MaxTransactionCursor: math.MaxUint32,
|
||||
MaxNodes: 5000,
|
||||
BlockedIPs: make([]string, 0),
|
||||
BlackListMaxSize: 65536,
|
||||
Try: 2,
|
||||
Mode: StandardMode,
|
||||
PacketJobLimit: 1024,
|
||||
PacketWorkerLimit: 256,
|
||||
RefreshNodeNum: 8,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCrawlConfig returns a config in crawling mode.
|
||||
func NewCrawlConfig() *Config {
|
||||
config := NewStandardConfig()
|
||||
config.NodeExpriedAfter = 0
|
||||
config.KBucketExpiredAfter = 0
|
||||
config.CheckKBucketPeriod = time.Second * 5
|
||||
config.KBucketSize = math.MaxInt32
|
||||
config.Mode = CrawlMode
|
||||
config.RefreshNodeNum = 256
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// DHT represents a DHT node.
|
||||
type DHT struct {
|
||||
*Config
|
||||
node *node
|
||||
conn *net.UDPConn
|
||||
routingTable *routingTable
|
||||
transactionManager *transactionManager
|
||||
peersManager *peersManager
|
||||
tokenManager *tokenManager
|
||||
blackList *blackList
|
||||
Ready bool
|
||||
packets chan packet
|
||||
workerTokens chan struct{}
|
||||
}
|
||||
|
||||
// New returns a DHT pointer. If config is nil, then config will be set to
|
||||
// the default config.
|
||||
func New(config *Config) *DHT {
|
||||
if config == nil {
|
||||
config = NewStandardConfig()
|
||||
}
|
||||
|
||||
node, err := newNode(randomString(20), config.Network, config.Address)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
d := &DHT{
|
||||
Config: config,
|
||||
node: node,
|
||||
blackList: newBlackList(config.BlackListMaxSize),
|
||||
packets: make(chan packet, config.PacketJobLimit),
|
||||
workerTokens: make(chan struct{}, config.PacketWorkerLimit),
|
||||
}
|
||||
|
||||
for _, ip := range config.BlockedIPs {
|
||||
d.blackList.insert(ip, -1)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for _, ip := range getLocalIPs() {
|
||||
d.blackList.insert(ip, -1)
|
||||
}
|
||||
|
||||
ip, err := getRemoteIP()
|
||||
if err != nil {
|
||||
d.blackList.insert(ip, -1)
|
||||
}
|
||||
}()
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// IsStandardMode returns whether mode is StandardMode.
|
||||
func (dht *DHT) IsStandardMode() bool {
|
||||
return dht.Mode == StandardMode
|
||||
}
|
||||
|
||||
// IsCrawlMode returns whether mode is CrawlMode.
|
||||
func (dht *DHT) IsCrawlMode() bool {
|
||||
return dht.Mode == CrawlMode
|
||||
}
|
||||
|
||||
// init initializes global varables.
|
||||
func (dht *DHT) init() {
|
||||
listener, err := net.ListenPacket(dht.Network, dht.Address)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
dht.conn = listener.(*net.UDPConn)
|
||||
dht.routingTable = newRoutingTable(dht.KBucketSize, dht)
|
||||
dht.peersManager = newPeersManager(dht)
|
||||
dht.tokenManager = newTokenManager(dht.TokenExpiredAfter, dht)
|
||||
dht.transactionManager = newTransactionManager(
|
||||
dht.MaxTransactionCursor, dht)
|
||||
|
||||
go dht.transactionManager.run()
|
||||
go dht.tokenManager.clear()
|
||||
go dht.blackList.clear()
|
||||
}
|
||||
|
||||
// join makes current node join the dht network.
|
||||
func (dht *DHT) join() {
|
||||
for _, addr := range dht.PrimeNodes {
|
||||
raddr, err := net.ResolveUDPAddr(dht.Network, addr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// NOTE: Temporary node has NOT node id.
|
||||
dht.transactionManager.findNode(
|
||||
&node{addr: raddr},
|
||||
dht.node.id.RawString(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// listen receives message from udp.
|
||||
func (dht *DHT) listen() {
|
||||
go func() {
|
||||
buff := make([]byte, 8192)
|
||||
for {
|
||||
n, raddr, err := dht.conn.ReadFromUDP(buff)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
dht.packets <- packet{buff[:n], raddr}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// id returns a id near to target if target is not null, otherwise it returns
|
||||
// the dht's node id.
|
||||
func (dht *DHT) id(target string) string {
|
||||
if dht.IsStandardMode() || target == "" {
|
||||
return dht.node.id.RawString()
|
||||
}
|
||||
return target[:15] + dht.node.id.RawString()[15:]
|
||||
}
|
||||
|
||||
// GetPeers returns peers who have announced having infoHash.
|
||||
func (dht *DHT) GetPeers(infoHash string) ([]*Peer, error) {
|
||||
if !dht.Ready {
|
||||
return nil, errors.New("dht not ready")
|
||||
}
|
||||
|
||||
if len(infoHash) == 40 {
|
||||
data, err := hex.DecodeString(infoHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
infoHash = string(data)
|
||||
}
|
||||
|
||||
peers := dht.peersManager.GetPeers(infoHash, dht.K)
|
||||
if len(peers) != 0 {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
ch := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
neighbors := dht.routingTable.GetNeighbors(
|
||||
newBitmapFromString(infoHash), dht.K)
|
||||
|
||||
for _, no := range neighbors {
|
||||
dht.transactionManager.getPeers(no, infoHash)
|
||||
}
|
||||
|
||||
i := 0
|
||||
for _ = range time.Tick(time.Second * 1) {
|
||||
i++
|
||||
peers = dht.peersManager.GetPeers(infoHash, dht.K)
|
||||
if len(peers) != 0 || i == 30 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ch <- struct{}{}
|
||||
}()
|
||||
|
||||
<-ch
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// Run starts the dht.
|
||||
func (dht *DHT) Run() {
|
||||
dht.init()
|
||||
dht.listen()
|
||||
dht.join()
|
||||
|
||||
dht.Ready = true
|
||||
|
||||
var pkt packet
|
||||
tick := time.Tick(dht.CheckKBucketPeriod)
|
||||
|
||||
for {
|
||||
select {
|
||||
case pkt = <-dht.packets:
|
||||
handle(dht, pkt)
|
||||
case <-tick:
|
||||
if dht.routingTable.Len() == 0 {
|
||||
dht.join()
|
||||
} else if dht.transactionManager.len() == 0 {
|
||||
go dht.routingTable.Fresh()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
BIN
dht/doc/screen-shot.png
Normal file
BIN
dht/doc/screen-shot.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 695 KiB |
782
dht/krpc.go
Normal file
782
dht/krpc.go
Normal file
|
@ -0,0 +1,782 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
pingType = "ping"
|
||||
findNodeType = "find_node"
|
||||
getPeersType = "get_peers"
|
||||
announcePeerType = "announce_peer"
|
||||
)
|
||||
|
||||
const (
|
||||
generalError = 201 + iota
|
||||
serverError
|
||||
protocolError
|
||||
unknownError
|
||||
)
|
||||
|
||||
// packet represents the information receive from udp.
|
||||
type packet struct {
|
||||
data []byte
|
||||
raddr *net.UDPAddr
|
||||
}
|
||||
|
||||
// token represents the token when response getPeers request.
|
||||
type token struct {
|
||||
data string
|
||||
createTime time.Time
|
||||
}
|
||||
|
||||
// tokenManager managers the tokens.
|
||||
type tokenManager struct {
|
||||
*syncedMap
|
||||
expiredAfter time.Duration
|
||||
dht *DHT
|
||||
}
|
||||
|
||||
// newTokenManager returns a new tokenManager.
|
||||
func newTokenManager(expiredAfter time.Duration, dht *DHT) *tokenManager {
|
||||
return &tokenManager{
|
||||
syncedMap: newSyncedMap(),
|
||||
expiredAfter: expiredAfter,
|
||||
dht: dht,
|
||||
}
|
||||
}
|
||||
|
||||
// token returns a token. If it doesn't exist or is expired, it will add a
|
||||
// new token.
|
||||
func (tm *tokenManager) token(addr *net.UDPAddr) string {
|
||||
v, ok := tm.Get(addr.IP.String())
|
||||
tk, _ := v.(token)
|
||||
|
||||
if !ok || time.Now().Sub(tk.createTime) > tm.expiredAfter {
|
||||
tk = token{
|
||||
data: randomString(5),
|
||||
createTime: time.Now(),
|
||||
}
|
||||
|
||||
tm.Set(addr.IP.String(), tk)
|
||||
}
|
||||
|
||||
return tk.data
|
||||
}
|
||||
|
||||
// clear removes expired tokens.
|
||||
func (tm *tokenManager) clear() {
|
||||
for _ = range time.Tick(time.Minute * 3) {
|
||||
keys := make([]interface{}, 0, 100)
|
||||
|
||||
for item := range tm.Iter() {
|
||||
if time.Now().Sub(item.val.(token).createTime) > tm.expiredAfter {
|
||||
keys = append(keys, item.key)
|
||||
}
|
||||
}
|
||||
|
||||
tm.DeleteMulti(keys)
|
||||
}
|
||||
}
|
||||
|
||||
// check returns whether the token is valid.
|
||||
func (tm *tokenManager) check(addr *net.UDPAddr, tokenString string) bool {
|
||||
key := addr.IP.String()
|
||||
v, ok := tm.Get(key)
|
||||
tk, _ := v.(token)
|
||||
|
||||
if ok {
|
||||
tm.Delete(key)
|
||||
}
|
||||
|
||||
return ok && tokenString == tk.data
|
||||
}
|
||||
|
||||
// makeQuery returns a query-formed data.
|
||||
func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"t": t,
|
||||
"y": "q",
|
||||
"q": q,
|
||||
"a": a,
|
||||
}
|
||||
}
|
||||
|
||||
// makeResponse returns a response-formed data.
|
||||
func makeResponse(t string, r map[string]interface{}) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"t": t,
|
||||
"y": "r",
|
||||
"r": r,
|
||||
}
|
||||
}
|
||||
|
||||
// makeError returns a err-formed data.
|
||||
func makeError(t string, errCode int, errMsg string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"t": t,
|
||||
"y": "e",
|
||||
"e": []interface{}{errCode, errMsg},
|
||||
}
|
||||
}
|
||||
|
||||
// send sends data to the udp.
|
||||
func send(dht *DHT, addr *net.UDPAddr, data map[string]interface{}) error {
|
||||
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
|
||||
|
||||
_, err := dht.conn.WriteToUDP([]byte(Encode(data)), addr)
|
||||
if err != nil {
|
||||
dht.blackList.insert(addr.IP.String(), -1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// query represents the query data included queried node and query-formed data.
|
||||
type query struct {
|
||||
node *node
|
||||
data map[string]interface{}
|
||||
}
|
||||
|
||||
// transaction implements transaction.
|
||||
type transaction struct {
|
||||
*query
|
||||
id string
|
||||
response chan struct{}
|
||||
}
|
||||
|
||||
// transactionManager represents the manager of transactions.
|
||||
type transactionManager struct {
|
||||
*sync.RWMutex
|
||||
transactions *syncedMap
|
||||
index *syncedMap
|
||||
cursor uint64
|
||||
maxCursor uint64
|
||||
queryChan chan *query
|
||||
dht *DHT
|
||||
}
|
||||
|
||||
// newTransactionManager returns new transactionManager pointer.
|
||||
func newTransactionManager(maxCursor uint64, dht *DHT) *transactionManager {
|
||||
return &transactionManager{
|
||||
RWMutex: &sync.RWMutex{},
|
||||
transactions: newSyncedMap(),
|
||||
index: newSyncedMap(),
|
||||
maxCursor: maxCursor,
|
||||
queryChan: make(chan *query, 1024),
|
||||
dht: dht,
|
||||
}
|
||||
}
|
||||
|
||||
// genTransID generates a transaction id and returns it.
|
||||
func (tm *transactionManager) genTransID() string {
|
||||
tm.Lock()
|
||||
defer tm.Unlock()
|
||||
|
||||
tm.cursor = (tm.cursor + 1) % tm.maxCursor
|
||||
return string(int2bytes(tm.cursor))
|
||||
}
|
||||
|
||||
// newTransaction creates a new transaction.
|
||||
func (tm *transactionManager) newTransaction(id string, q *query) *transaction {
|
||||
return &transaction{
|
||||
id: id,
|
||||
query: q,
|
||||
response: make(chan struct{}, tm.dht.Try+1),
|
||||
}
|
||||
}
|
||||
|
||||
// genIndexKey generates an indexed key which consists of queryType and
|
||||
// address.
|
||||
func (tm *transactionManager) genIndexKey(queryType, address string) string {
|
||||
return strings.Join([]string{queryType, address}, ":")
|
||||
}
|
||||
|
||||
// genIndexKeyByTrans generates an indexed key by a transaction.
|
||||
func (tm *transactionManager) genIndexKeyByTrans(trans *transaction) string {
|
||||
return tm.genIndexKey(trans.data["q"].(string), trans.node.addr.String())
|
||||
}
|
||||
|
||||
// insert adds a transaction to transactionManager.
|
||||
func (tm *transactionManager) insert(trans *transaction) {
|
||||
tm.Lock()
|
||||
defer tm.Unlock()
|
||||
|
||||
tm.transactions.Set(trans.id, trans)
|
||||
tm.index.Set(tm.genIndexKeyByTrans(trans), trans)
|
||||
}
|
||||
|
||||
// delete removes a transaction from transactionManager.
|
||||
func (tm *transactionManager) delete(transID string) {
|
||||
v, ok := tm.transactions.Get(transID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
tm.Lock()
|
||||
defer tm.Unlock()
|
||||
|
||||
trans := v.(*transaction)
|
||||
tm.transactions.Delete(trans.id)
|
||||
tm.index.Delete(tm.genIndexKeyByTrans(trans))
|
||||
}
|
||||
|
||||
// len returns how many transactions are requesting now.
|
||||
func (tm *transactionManager) len() int {
|
||||
return tm.transactions.Len()
|
||||
}
|
||||
|
||||
// transaction returns a transaction. keyType should be one of 0, 1 which
|
||||
// represents transId and index each.
|
||||
func (tm *transactionManager) transaction(
|
||||
key string, keyType int) *transaction {
|
||||
|
||||
sm := tm.transactions
|
||||
if keyType == 1 {
|
||||
sm = tm.index
|
||||
}
|
||||
|
||||
v, ok := sm.Get(key)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return v.(*transaction)
|
||||
}
|
||||
|
||||
// getByTransID returns a transaction by transID.
|
||||
func (tm *transactionManager) getByTransID(transID string) *transaction {
|
||||
return tm.transaction(transID, 0)
|
||||
}
|
||||
|
||||
// getByIndex returns a transaction by indexed key.
|
||||
func (tm *transactionManager) getByIndex(index string) *transaction {
|
||||
return tm.transaction(index, 1)
|
||||
}
|
||||
|
||||
// transaction gets the proper transaction with whose id is transId and
|
||||
// address is addr.
|
||||
func (tm *transactionManager) filterOne(
|
||||
transID string, addr *net.UDPAddr) *transaction {
|
||||
|
||||
trans := tm.getByTransID(transID)
|
||||
if trans == nil || trans.node.addr.String() != addr.String() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return trans
|
||||
}
|
||||
|
||||
// query sends the query-formed data to udp and wait for the response.
|
||||
// When timeout, it will retry `try - 1` times, which means it will query
|
||||
// `try` times totally.
|
||||
func (tm *transactionManager) query(q *query, try int) {
|
||||
transID := q.data["t"].(string)
|
||||
trans := tm.newTransaction(transID, q)
|
||||
|
||||
tm.insert(trans)
|
||||
defer tm.delete(trans.id)
|
||||
|
||||
success := false
|
||||
for i := 0; i < try; i++ {
|
||||
if err := send(tm.dht, q.node.addr, q.data); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-trans.response:
|
||||
success = true
|
||||
break
|
||||
case <-time.After(time.Second * 15):
|
||||
}
|
||||
}
|
||||
|
||||
if !success && q.node.id != nil {
|
||||
tm.dht.blackList.insert(q.node.addr.IP.String(), q.node.addr.Port)
|
||||
tm.dht.routingTable.RemoveByAddr(q.node.addr.String())
|
||||
}
|
||||
}
|
||||
|
||||
// run starts to listen and consume the query chan.
|
||||
func (tm *transactionManager) run() {
|
||||
var q *query
|
||||
|
||||
for {
|
||||
select {
|
||||
case q = <-tm.queryChan:
|
||||
go tm.query(q, tm.dht.Try)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendQuery send query-formed data to the chan.
|
||||
func (tm *transactionManager) sendQuery(
|
||||
no *node, queryType string, a map[string]interface{}) {
|
||||
|
||||
// If the target is self, then stop.
|
||||
if no.id != nil && no.id.RawString() == tm.dht.node.id.RawString() ||
|
||||
tm.getByIndex(tm.genIndexKey(queryType, no.addr.String())) != nil ||
|
||||
tm.dht.blackList.in(no.addr.IP.String(), no.addr.Port) {
|
||||
return
|
||||
}
|
||||
|
||||
data := makeQuery(tm.genTransID(), queryType, a)
|
||||
tm.queryChan <- &query{
|
||||
node: no,
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ping sends ping query to the chan.
|
||||
func (tm *transactionManager) ping(no *node) {
|
||||
tm.sendQuery(no, pingType, map[string]interface{}{
|
||||
"id": tm.dht.id(no.id.RawString()),
|
||||
})
|
||||
}
|
||||
|
||||
// findNode sends find_node query to the chan.
|
||||
func (tm *transactionManager) findNode(no *node, target string) {
|
||||
tm.sendQuery(no, findNodeType, map[string]interface{}{
|
||||
"id": tm.dht.id(target),
|
||||
"target": target,
|
||||
})
|
||||
}
|
||||
|
||||
// getPeers sends get_peers query to the chan.
|
||||
func (tm *transactionManager) getPeers(no *node, infoHash string) {
|
||||
tm.sendQuery(no, getPeersType, map[string]interface{}{
|
||||
"id": tm.dht.id(infoHash),
|
||||
"info_hash": infoHash,
|
||||
})
|
||||
}
|
||||
|
||||
// announcePeer sends announce_peer query to the chan.
|
||||
func (tm *transactionManager) announcePeer(
|
||||
no *node, infoHash string, impliedPort, port int, token string) {
|
||||
|
||||
tm.sendQuery(no, announcePeerType, map[string]interface{}{
|
||||
"id": tm.dht.id(no.id.RawString()),
|
||||
"info_hash": infoHash,
|
||||
"implied_port": impliedPort,
|
||||
"port": port,
|
||||
"token": token,
|
||||
})
|
||||
}
|
||||
|
||||
// parseKey parses the key in dict data. `t` is type of the keyed value.
|
||||
// It's one of "int", "string", "map", "list".
|
||||
func parseKey(data map[string]interface{}, key string, t string) error {
|
||||
val, ok := data[key]
|
||||
if !ok {
|
||||
return errors.New("lack of key")
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "string":
|
||||
_, ok = val.(string)
|
||||
case "int":
|
||||
_, ok = val.(int)
|
||||
case "map":
|
||||
_, ok = val.(map[string]interface{})
|
||||
case "list":
|
||||
_, ok = val.([]interface{})
|
||||
default:
|
||||
panic("invalid type")
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return errors.New("invalid key type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseKeys parses keys. It just wraps parseKey.
|
||||
func parseKeys(data map[string]interface{}, pairs [][]string) error {
|
||||
for _, args := range pairs {
|
||||
key, t := args[0], args[1]
|
||||
if err := parseKey(data, key, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseMessage parses the basic data received from udp.
|
||||
// It returns a map value.
|
||||
func parseMessage(data interface{}) (map[string]interface{}, error) {
|
||||
response, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, errors.New("response is not dict")
|
||||
}
|
||||
|
||||
if err := parseKeys(
|
||||
response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// handleRequest handles the requests received from udp.
|
||||
func handleRequest(dht *DHT, addr *net.UDPAddr,
|
||||
response map[string]interface{}) (success bool) {
|
||||
|
||||
t := response["t"].(string)
|
||||
|
||||
if err := parseKeys(
|
||||
response, [][]string{{"q", "string"}, {"a", "map"}}); err != nil {
|
||||
|
||||
send(dht, addr, makeError(t, protocolError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
q := response["q"].(string)
|
||||
a := response["a"].(map[string]interface{})
|
||||
|
||||
if err := parseKey(a, "id", "string"); err != nil {
|
||||
send(dht, addr, makeError(t, protocolError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
id := a["id"].(string)
|
||||
|
||||
if id == dht.node.id.RawString() {
|
||||
return
|
||||
}
|
||||
|
||||
if len(id) != 20 {
|
||||
send(dht, addr, makeError(t, protocolError, "invalid id"))
|
||||
return
|
||||
}
|
||||
|
||||
if no, ok := dht.routingTable.GetNodeByAddress(addr.String()); ok &&
|
||||
no.id.RawString() != id {
|
||||
|
||||
dht.blackList.insert(addr.IP.String(), addr.Port)
|
||||
dht.routingTable.RemoveByAddr(addr.String())
|
||||
|
||||
send(dht, addr, makeError(t, protocolError, "invalid id"))
|
||||
return
|
||||
}
|
||||
|
||||
switch q {
|
||||
case pingType:
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(id),
|
||||
}))
|
||||
case findNodeType:
|
||||
if dht.IsStandardMode() {
|
||||
if err := parseKey(a, "target", "string"); err != nil {
|
||||
send(dht, addr, makeError(t, protocolError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
target := a["target"].(string)
|
||||
if len(target) != 20 {
|
||||
send(dht, addr, makeError(t, protocolError, "invalid target"))
|
||||
return
|
||||
}
|
||||
|
||||
var nodes string
|
||||
targetID := newBitmapFromString(target)
|
||||
|
||||
no, _ := dht.routingTable.GetNodeKBucktByID(targetID)
|
||||
if no != nil {
|
||||
nodes = no.CompactNodeInfo()
|
||||
} else {
|
||||
nodes = strings.Join(
|
||||
dht.routingTable.GetNeighborCompactInfos(targetID, dht.K),
|
||||
"",
|
||||
)
|
||||
}
|
||||
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(target),
|
||||
"nodes": nodes,
|
||||
}))
|
||||
}
|
||||
case getPeersType:
|
||||
if err := parseKey(a, "info_hash", "string"); err != nil {
|
||||
send(dht, addr, makeError(t, protocolError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
infoHash := a["info_hash"].(string)
|
||||
|
||||
if len(infoHash) != 20 {
|
||||
send(dht, addr, makeError(t, protocolError, "invalid info_hash"))
|
||||
return
|
||||
}
|
||||
|
||||
if dht.IsCrawlMode() {
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(infoHash),
|
||||
"token": dht.tokenManager.token(addr),
|
||||
"nodes": "",
|
||||
}))
|
||||
} else if peers := dht.peersManager.GetPeers(
|
||||
infoHash, dht.K); len(peers) > 0 {
|
||||
|
||||
values := make([]interface{}, len(peers))
|
||||
for i, p := range peers {
|
||||
values[i] = p.CompactIPPortInfo()
|
||||
}
|
||||
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(infoHash),
|
||||
"values": values,
|
||||
"token": dht.tokenManager.token(addr),
|
||||
}))
|
||||
} else {
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(infoHash),
|
||||
"token": dht.tokenManager.token(addr),
|
||||
"nodes": strings.Join(dht.routingTable.GetNeighborCompactInfos(
|
||||
newBitmapFromString(infoHash), dht.K), ""),
|
||||
}))
|
||||
}
|
||||
|
||||
if dht.OnGetPeers != nil {
|
||||
dht.OnGetPeers(infoHash, addr.IP.String(), addr.Port)
|
||||
}
|
||||
case announcePeerType:
|
||||
if err := parseKeys(a, [][]string{
|
||||
{"info_hash", "string"},
|
||||
{"port", "int"},
|
||||
{"token", "string"}}); err != nil {
|
||||
|
||||
send(dht, addr, makeError(t, protocolError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
infoHash := a["info_hash"].(string)
|
||||
port := a["port"].(int)
|
||||
token := a["token"].(string)
|
||||
|
||||
if !dht.tokenManager.check(addr, token) {
|
||||
// send(dht, addr, makeError(t, protocolError, "invalid token"))
|
||||
return
|
||||
}
|
||||
|
||||
if impliedPort, ok := a["implied_port"]; ok &&
|
||||
impliedPort.(int) != 0 {
|
||||
|
||||
port = addr.Port
|
||||
}
|
||||
|
||||
if dht.IsStandardMode() {
|
||||
dht.peersManager.Insert(infoHash, newPeer(addr.IP, port, token))
|
||||
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(id),
|
||||
}))
|
||||
}
|
||||
|
||||
if dht.OnAnnouncePeer != nil {
|
||||
dht.OnAnnouncePeer(infoHash, addr.IP.String(), port)
|
||||
}
|
||||
default:
|
||||
// send(dht, addr, makeError(t, protocolError, "invalid q"))
|
||||
return
|
||||
}
|
||||
|
||||
no, _ := newNode(id, addr.Network(), addr.String())
|
||||
dht.routingTable.Insert(no)
|
||||
return true
|
||||
}
|
||||
|
||||
// findOn puts nodes in the response to the routingTable, then if target is in
|
||||
// the nodes or all nodes are in the routingTable, it stops. Otherwise it
|
||||
// continues to findNode or getPeers.
|
||||
func findOn(dht *DHT, r map[string]interface{}, target *bitmap,
|
||||
queryType string) error {
|
||||
|
||||
if err := parseKey(r, "nodes", "string"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nodes := r["nodes"].(string)
|
||||
if len(nodes)%26 != 0 {
|
||||
return errors.New("the length of nodes should can be divided by 26")
|
||||
}
|
||||
|
||||
hasNew, found := false, false
|
||||
for i := 0; i < len(nodes)/26; i++ {
|
||||
no, _ := newNodeFromCompactInfo(
|
||||
string(nodes[i*26:(i+1)*26]), dht.Network)
|
||||
|
||||
if no.id.RawString() == target.RawString() {
|
||||
found = true
|
||||
}
|
||||
|
||||
if dht.routingTable.Insert(no) {
|
||||
hasNew = true
|
||||
}
|
||||
}
|
||||
|
||||
if found || !hasNew {
|
||||
return nil
|
||||
}
|
||||
|
||||
targetID := target.RawString()
|
||||
for _, no := range dht.routingTable.GetNeighbors(target, dht.K) {
|
||||
switch queryType {
|
||||
case findNodeType:
|
||||
dht.transactionManager.findNode(no, targetID)
|
||||
case getPeersType:
|
||||
dht.transactionManager.getPeers(no, targetID)
|
||||
default:
|
||||
panic("invalid find type")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleResponse handles responses received from udp.
|
||||
func handleResponse(dht *DHT, addr *net.UDPAddr,
|
||||
response map[string]interface{}) (success bool) {
|
||||
|
||||
t := response["t"].(string)
|
||||
|
||||
trans := dht.transactionManager.filterOne(t, addr)
|
||||
if trans == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// inform transManager to delete the transaction.
|
||||
if err := parseKey(response, "r", "map"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
q := trans.data["q"].(string)
|
||||
a := trans.data["a"].(map[string]interface{})
|
||||
r := response["r"].(map[string]interface{})
|
||||
|
||||
if err := parseKey(r, "id", "string"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
id := r["id"].(string)
|
||||
|
||||
// If response's node id is not the same with the node id in the
|
||||
// transaction, raise error.
|
||||
if trans.node.id != nil && trans.node.id.RawString() != r["id"].(string) {
|
||||
dht.blackList.insert(addr.IP.String(), addr.Port)
|
||||
dht.routingTable.RemoveByAddr(addr.String())
|
||||
return
|
||||
}
|
||||
|
||||
node, err := newNode(id, addr.Network(), addr.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch q {
|
||||
case pingType:
|
||||
case findNodeType:
|
||||
if trans.data["q"].(string) != findNodeType {
|
||||
return
|
||||
}
|
||||
|
||||
target := trans.data["a"].(map[string]interface{})["target"].(string)
|
||||
if findOn(dht, r, newBitmapFromString(target), findNodeType) != nil {
|
||||
return
|
||||
}
|
||||
case getPeersType:
|
||||
if err := parseKey(r, "token", "string"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
token := r["token"].(string)
|
||||
infoHash := a["info_hash"].(string)
|
||||
|
||||
if err := parseKey(r, "values", "list"); err == nil {
|
||||
values := r["values"].([]interface{})
|
||||
for _, v := range values {
|
||||
p, err := newPeerFromCompactIPPortInfo(v.(string), token)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
dht.peersManager.Insert(infoHash, p)
|
||||
}
|
||||
} else if findOn(
|
||||
dht, r, newBitmapFromString(infoHash), getPeersType) != nil {
|
||||
return
|
||||
}
|
||||
case announcePeerType:
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
// inform transManager to delete transaction.
|
||||
trans.response <- struct{}{}
|
||||
|
||||
dht.blackList.delete(addr.IP.String(), addr.Port)
|
||||
dht.routingTable.Insert(node)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handleError handles errors received from udp.
|
||||
func handleError(dht *DHT, addr *net.UDPAddr,
|
||||
response map[string]interface{}) (success bool) {
|
||||
|
||||
if err := parseKey(response, "e", "list"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if e := response["e"].([]interface{}); len(e) != 2 {
|
||||
return
|
||||
}
|
||||
|
||||
if trans := dht.transactionManager.filterOne(
|
||||
response["t"].(string), addr); trans != nil {
|
||||
|
||||
trans.response <- struct{}{}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
var handlers = map[string]func(*DHT, *net.UDPAddr, map[string]interface{}) bool{
|
||||
"q": handleRequest,
|
||||
"r": handleResponse,
|
||||
"e": handleError,
|
||||
}
|
||||
|
||||
// handle handles packets received from udp.
|
||||
func handle(dht *DHT, pkt packet) {
|
||||
if len(dht.workerTokens) == dht.PacketWorkerLimit {
|
||||
return
|
||||
}
|
||||
|
||||
dht.workerTokens <- struct{}{}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
<-dht.workerTokens
|
||||
}()
|
||||
|
||||
if dht.blackList.in(pkt.raddr.IP.String(), pkt.raddr.Port) {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := Decode(pkt.data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := parseMessage(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if f, ok := handlers[response["y"].(string)]; ok {
|
||||
f(dht, pkt.raddr, response)
|
||||
}
|
||||
}()
|
||||
}
|
385
dht/peerwire.go
Normal file
385
dht/peerwire.go
Normal file
|
@ -0,0 +1,385 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// REQUEST represents request message type
|
||||
REQUEST = iota
|
||||
// DATA represents data message type
|
||||
DATA
|
||||
// REJECT represents reject message type
|
||||
REJECT
|
||||
)
|
||||
|
||||
const (
|
||||
// BLOCK is 2 ^ 14
|
||||
BLOCK = 16384
|
||||
// MaxMetadataSize represents the max medata it can accept
|
||||
MaxMetadataSize = BLOCK * 1000
|
||||
// EXTENDED represents it is a extended message
|
||||
EXTENDED = 20
|
||||
// HANDSHAKE represents handshake bit
|
||||
HANDSHAKE = 0
|
||||
)
|
||||
|
||||
var handshakePrefix = []byte{
|
||||
19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114,
|
||||
111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1,
|
||||
}
|
||||
|
||||
// read reads size-length bytes from conn to data.
|
||||
func read(conn *net.TCPConn, size int, data *bytes.Buffer) error {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second * 15))
|
||||
|
||||
n, err := io.CopyN(data, conn, int64(size))
|
||||
if err != nil || n != int64(size) {
|
||||
return errors.New("read error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readMessage gets a message from the tcp connection.
|
||||
func readMessage(conn *net.TCPConn, data *bytes.Buffer) (
|
||||
length int, err error) {
|
||||
|
||||
if err = read(conn, 4, data); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
length = int(bytes2int(data.Next(4)))
|
||||
if length == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err = read(conn, length, data); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// sendMessage sends data to the connection.
|
||||
func sendMessage(conn *net.TCPConn, data []byte) error {
|
||||
length := int32(len(data))
|
||||
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
binary.Write(buffer, binary.BigEndian, length)
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
_, err := conn.Write(append(buffer.Bytes(), data...))
|
||||
return err
|
||||
}
|
||||
|
||||
// sendHandshake sends handshake message to conn.
|
||||
func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error {
|
||||
data := make([]byte, 68)
|
||||
copy(data[:28], handshakePrefix)
|
||||
copy(data[28:48], infoHash)
|
||||
copy(data[48:], peerID)
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
_, err := conn.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
// onHandshake handles the handshake response.
|
||||
func onHandshake(data []byte) (err error) {
|
||||
if !(bytes.Equal(handshakePrefix[:20], data[:20]) && data[25]&0x10 != 0) {
|
||||
err = errors.New("invalid handshake response")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// sendExtHandshake requests for the ut_metadata and metadata_size.
|
||||
func sendExtHandshake(conn *net.TCPConn) error {
|
||||
data := append(
|
||||
[]byte{EXTENDED, HANDSHAKE},
|
||||
Encode(map[string]interface{}{
|
||||
"m": map[string]interface{}{"ut_metadata": 1},
|
||||
})...,
|
||||
)
|
||||
|
||||
return sendMessage(conn, data)
|
||||
}
|
||||
|
||||
// getUTMetaSize returns the ut_metadata and metadata_size.
|
||||
func getUTMetaSize(data []byte) (
|
||||
utMetadata int, metadataSize int, err error) {
|
||||
|
||||
v, err := Decode(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dict, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
err = errors.New("invalid dict")
|
||||
return
|
||||
}
|
||||
|
||||
if err = parseKeys(
|
||||
dict, [][]string{{"metadata_size", "int"}, {"m", "map"}}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
m := dict["m"].(map[string]interface{})
|
||||
if err = parseKey(m, "ut_metadata", "int"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
utMetadata = m["ut_metadata"].(int)
|
||||
metadataSize = dict["metadata_size"].(int)
|
||||
|
||||
if metadataSize > MaxMetadataSize {
|
||||
err = errors.New("metadata_size too long")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Request represents the request context.
|
||||
type Request struct {
|
||||
InfoHash []byte
|
||||
IP string
|
||||
Port int
|
||||
}
|
||||
|
||||
// Response contains the request context and the metadata info.
|
||||
type Response struct {
|
||||
Request
|
||||
MetadataInfo []byte
|
||||
}
|
||||
|
||||
// Wire represents the wire protocol.
|
||||
type Wire struct {
|
||||
blackList *blackList
|
||||
queue *syncedMap
|
||||
requests chan Request
|
||||
responses chan Response
|
||||
workerTokens chan struct{}
|
||||
}
|
||||
|
||||
// NewWire returns a Wire pointer.
|
||||
// - blackListSize: the blacklist size
|
||||
// - requestQueueSize: the max requests it can buffers
|
||||
// - workerQueueSize: the max goroutine downloading workers
|
||||
func NewWire(blackListSize, requestQueueSize, workerQueueSize int) *Wire {
|
||||
return &Wire{
|
||||
blackList: newBlackList(blackListSize),
|
||||
queue: newSyncedMap(),
|
||||
requests: make(chan Request, requestQueueSize),
|
||||
responses: make(chan Response, 1024),
|
||||
workerTokens: make(chan struct{}, workerQueueSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Request pushes the request to the queue.
|
||||
func (wire *Wire) Request(infoHash []byte, ip string, port int) {
|
||||
wire.requests <- Request{InfoHash: infoHash, IP: ip, Port: port}
|
||||
}
|
||||
|
||||
// Response returns a chan of Response.
|
||||
func (wire *Wire) Response() <-chan Response {
|
||||
return wire.responses
|
||||
}
|
||||
|
||||
// isDone returns whether the wire get all pieces of the metadata info.
|
||||
func (wire *Wire) isDone(pieces [][]byte) bool {
|
||||
for _, piece := range pieces {
|
||||
if len(piece) == 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (wire *Wire) requestPieces(
|
||||
conn *net.TCPConn, utMetadata int, metadataSize int, piecesNum int) {
|
||||
|
||||
buffer := make([]byte, 1024)
|
||||
for i := 0; i < piecesNum; i++ {
|
||||
buffer[0] = EXTENDED
|
||||
buffer[1] = byte(utMetadata)
|
||||
|
||||
msg := Encode(map[string]interface{}{
|
||||
"msg_type": REQUEST,
|
||||
"piece": i,
|
||||
})
|
||||
|
||||
length := len(msg) + 2
|
||||
copy(buffer[2:length], msg)
|
||||
|
||||
sendMessage(conn, buffer[:length])
|
||||
}
|
||||
buffer = nil
|
||||
}
|
||||
|
||||
// fetchMetadata fetchs medata info accroding to infohash from dht.
|
||||
func (wire *Wire) fetchMetadata(r Request) {
|
||||
var (
|
||||
length int
|
||||
msgType byte
|
||||
piecesNum int
|
||||
pieces [][]byte
|
||||
utMetadata int
|
||||
metadataSize int
|
||||
)
|
||||
|
||||
defer func() {
|
||||
pieces = nil
|
||||
recover()
|
||||
}()
|
||||
|
||||
infoHash := r.InfoHash
|
||||
address := genAddress(r.IP, r.Port)
|
||||
|
||||
dial, err := net.DialTimeout("tcp", address, time.Second*15)
|
||||
if err != nil {
|
||||
wire.blackList.insert(r.IP, r.Port)
|
||||
return
|
||||
}
|
||||
conn := dial.(*net.TCPConn)
|
||||
conn.SetLinger(0)
|
||||
defer conn.Close()
|
||||
|
||||
data := bytes.NewBuffer(nil)
|
||||
data.Grow(BLOCK)
|
||||
|
||||
if sendHandshake(conn, infoHash, []byte(randomString(20))) != nil ||
|
||||
read(conn, 68, data) != nil ||
|
||||
onHandshake(data.Next(68)) != nil ||
|
||||
sendExtHandshake(conn) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
length, err = readMessage(conn, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
msgType, err = data.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
case EXTENDED:
|
||||
extendedID, err := data.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := ioutil.ReadAll(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if extendedID == 0 {
|
||||
if pieces != nil {
|
||||
return
|
||||
}
|
||||
|
||||
utMetadata, metadataSize, err = getUTMetaSize(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
piecesNum = metadataSize / BLOCK
|
||||
if metadataSize%BLOCK != 0 {
|
||||
piecesNum++
|
||||
}
|
||||
|
||||
pieces = make([][]byte, piecesNum)
|
||||
go wire.requestPieces(conn, utMetadata, metadataSize, piecesNum)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if pieces == nil {
|
||||
return
|
||||
}
|
||||
|
||||
d, index, err := DecodeDict(payload, 0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dict := d.(map[string]interface{})
|
||||
|
||||
if err = parseKeys(dict, [][]string{
|
||||
{"msg_type", "int"},
|
||||
{"piece", "int"}}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if dict["msg_type"].(int) != DATA {
|
||||
continue
|
||||
}
|
||||
|
||||
piece := dict["piece"].(int)
|
||||
pieceLen := length - 2 - index
|
||||
|
||||
if (piece != piecesNum-1 && pieceLen != BLOCK) ||
|
||||
(piece == piecesNum-1 && pieceLen != metadataSize%BLOCK) {
|
||||
return
|
||||
}
|
||||
|
||||
pieces[piece] = payload[index:]
|
||||
|
||||
if wire.isDone(pieces) {
|
||||
metadataInfo := bytes.Join(pieces, nil)
|
||||
|
||||
info := sha1.Sum(metadataInfo)
|
||||
if !bytes.Equal(infoHash, info[:]) {
|
||||
return
|
||||
}
|
||||
|
||||
wire.responses <- Response{
|
||||
Request: r,
|
||||
MetadataInfo: metadataInfo,
|
||||
}
|
||||
return
|
||||
}
|
||||
default:
|
||||
data.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the peer wire protocol.
|
||||
func (wire *Wire) Run() {
|
||||
go wire.blackList.clear()
|
||||
|
||||
for r := range wire.requests {
|
||||
wire.workerTokens <- struct{}{}
|
||||
|
||||
go func(r Request) {
|
||||
defer func() {
|
||||
<-wire.workerTokens
|
||||
}()
|
||||
|
||||
key := strings.Join([]string{
|
||||
string(r.InfoHash), genAddress(r.IP, r.Port),
|
||||
}, ":")
|
||||
|
||||
if len(r.InfoHash) != 20 || wire.blackList.in(r.IP, r.Port) ||
|
||||
wire.queue.Has(key) {
|
||||
return
|
||||
}
|
||||
|
||||
wire.fetchMetadata(r)
|
||||
}(r)
|
||||
}
|
||||
}
|
596
dht/routingtable.go
Normal file
596
dht/routingtable.go
Normal file
|
@ -0,0 +1,596 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// maxPrefixLength is the length of DHT node.
|
||||
const maxPrefixLength = 160
|
||||
|
||||
// node represents a DHT node.
|
||||
type node struct {
|
||||
id *bitmap
|
||||
addr *net.UDPAddr
|
||||
lastActiveTime time.Time
|
||||
}
|
||||
|
||||
// newNode returns a node pointer.
|
||||
func newNode(id, network, address string) (*node, error) {
|
||||
if len(id) != 20 {
|
||||
return nil, errors.New("node id should be a 20-length string")
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &node{newBitmapFromString(id), addr, time.Now()}, nil
|
||||
}
|
||||
|
||||
// newNodeFromCompactInfo parses compactNodeInfo and returns a node pointer.
|
||||
func newNodeFromCompactInfo(
|
||||
compactNodeInfo string, network string) (*node, error) {
|
||||
|
||||
if len(compactNodeInfo) != 26 {
|
||||
return nil, errors.New("compactNodeInfo should be a 26-length string")
|
||||
}
|
||||
|
||||
id := compactNodeInfo[:20]
|
||||
ip, port, _ := decodeCompactIPPortInfo(compactNodeInfo[20:])
|
||||
|
||||
return newNode(id, network, genAddress(ip.String(), port))
|
||||
}
|
||||
|
||||
// CompactIPPortInfo returns "Compact IP-address/port info".
|
||||
// See http://www.bittorrent.org/beps/bep_0005.html.
|
||||
func (node *node) CompactIPPortInfo() string {
|
||||
info, _ := encodeCompactIPPortInfo(node.addr.IP, node.addr.Port)
|
||||
return info
|
||||
}
|
||||
|
||||
// CompactNodeInfo returns "Compact node info".
|
||||
// See http://www.bittorrent.org/beps/bep_0005.html.
|
||||
func (node *node) CompactNodeInfo() string {
|
||||
return strings.Join([]string{
|
||||
node.id.RawString(), node.CompactIPPortInfo(),
|
||||
}, "")
|
||||
}
|
||||
|
||||
// Peer represents a peer contact.
|
||||
type Peer struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
token string
|
||||
}
|
||||
|
||||
// newPeer returns a new peer pointer.
|
||||
func newPeer(ip net.IP, port int, token string) *Peer {
|
||||
return &Peer{
|
||||
IP: ip,
|
||||
Port: port,
|
||||
token: token,
|
||||
}
|
||||
}
|
||||
|
||||
// newPeerFromCompactIPPortInfo create a peer pointer by compact ip/port info.
|
||||
func newPeerFromCompactIPPortInfo(compactInfo, token string) (*Peer, error) {
|
||||
ip, port, err := decodeCompactIPPortInfo(compactInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newPeer(ip, port, token), nil
|
||||
}
|
||||
|
||||
// CompactIPPortInfo returns "Compact node info".
|
||||
// See http://www.bittorrent.org/beps/bep_0005.html.
|
||||
func (p *Peer) CompactIPPortInfo() string {
|
||||
info, _ := encodeCompactIPPortInfo(p.IP, p.Port)
|
||||
return info
|
||||
}
|
||||
|
||||
// peersManager represents a proxy that manipulates peers.
|
||||
type peersManager struct {
|
||||
sync.RWMutex
|
||||
table *syncedMap
|
||||
dht *DHT
|
||||
}
|
||||
|
||||
// newPeersManager returns a new peersManager.
|
||||
func newPeersManager(dht *DHT) *peersManager {
|
||||
return &peersManager{
|
||||
table: newSyncedMap(),
|
||||
dht: dht,
|
||||
}
|
||||
}
|
||||
|
||||
// Insert adds a peer into peersManager.
|
||||
func (pm *peersManager) Insert(infoHash string, peer *Peer) {
|
||||
pm.Lock()
|
||||
if _, ok := pm.table.Get(infoHash); !ok {
|
||||
pm.table.Set(infoHash, newKeyedDeque())
|
||||
}
|
||||
pm.Unlock()
|
||||
|
||||
v, _ := pm.table.Get(infoHash)
|
||||
queue := v.(*keyedDeque)
|
||||
|
||||
queue.Push(peer.CompactIPPortInfo(), peer)
|
||||
if queue.Len() > pm.dht.K {
|
||||
queue.Remove(queue.Front())
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeers returns size-length peers who announces having infoHash.
|
||||
func (pm *peersManager) GetPeers(infoHash string, size int) []*Peer {
|
||||
peers := make([]*Peer, 0, size)
|
||||
|
||||
v, ok := pm.table.Get(infoHash)
|
||||
if !ok {
|
||||
return peers
|
||||
}
|
||||
|
||||
for e := range v.(*keyedDeque).Iter() {
|
||||
peers = append(peers, e.Value.(*Peer))
|
||||
}
|
||||
|
||||
if len(peers) > size {
|
||||
peers = peers[len(peers)-size:]
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
// kbucket represents a k-size bucket.
|
||||
type kbucket struct {
|
||||
sync.RWMutex
|
||||
nodes, candidates *keyedDeque
|
||||
lastChanged time.Time
|
||||
prefix *bitmap
|
||||
}
|
||||
|
||||
// newKBucket returns a new kbucket pointer.
|
||||
func newKBucket(prefix *bitmap) *kbucket {
|
||||
bucket := &kbucket{
|
||||
nodes: newKeyedDeque(),
|
||||
candidates: newKeyedDeque(),
|
||||
lastChanged: time.Now(),
|
||||
prefix: prefix,
|
||||
}
|
||||
return bucket
|
||||
}
|
||||
|
||||
// LastChanged return the last time when it changes.
|
||||
func (bucket *kbucket) LastChanged() time.Time {
|
||||
bucket.RLock()
|
||||
defer bucket.RUnlock()
|
||||
|
||||
return bucket.lastChanged
|
||||
}
|
||||
|
||||
// RandomChildID returns a random id that has the same prefix with bucket.
|
||||
func (bucket *kbucket) RandomChildID() string {
|
||||
prefixLen := bucket.prefix.Size / 8
|
||||
|
||||
return strings.Join([]string{
|
||||
bucket.prefix.RawString()[:prefixLen],
|
||||
randomString(20 - prefixLen),
|
||||
}, "")
|
||||
}
|
||||
|
||||
// UpdateTimestamp update bucket's last changed time..
|
||||
func (bucket *kbucket) UpdateTimestamp() {
|
||||
bucket.Lock()
|
||||
defer bucket.Unlock()
|
||||
|
||||
bucket.lastChanged = time.Now()
|
||||
}
|
||||
|
||||
// Insert inserts node to the bucket. It returns whether the node is new in
|
||||
// the bucket.
|
||||
func (bucket *kbucket) Insert(no *node) bool {
|
||||
isNew := !bucket.nodes.HasKey(no.id.RawString())
|
||||
|
||||
bucket.nodes.Push(no.id.RawString(), no)
|
||||
bucket.UpdateTimestamp()
|
||||
|
||||
return isNew
|
||||
}
|
||||
|
||||
// Replace removes node, then put bucket.candidates.Back() to the right
|
||||
// place of bucket.nodes.
|
||||
func (bucket *kbucket) Replace(no *node) {
|
||||
bucket.nodes.Delete(no.id.RawString())
|
||||
bucket.UpdateTimestamp()
|
||||
|
||||
if bucket.candidates.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
no = bucket.candidates.Remove(bucket.candidates.Back()).(*node)
|
||||
|
||||
inserted := false
|
||||
for e := range bucket.nodes.Iter() {
|
||||
if e.Value.(*node).lastActiveTime.After(
|
||||
no.lastActiveTime) && !inserted {
|
||||
|
||||
bucket.nodes.InsertBefore(no, e)
|
||||
inserted = true
|
||||
}
|
||||
}
|
||||
|
||||
if !inserted {
|
||||
bucket.nodes.PushBack(no)
|
||||
}
|
||||
}
|
||||
|
||||
// Fresh pings the expired nodes in the bucket.
|
||||
func (bucket *kbucket) Fresh(dht *DHT) {
|
||||
for e := range bucket.nodes.Iter() {
|
||||
no := e.Value.(*node)
|
||||
if time.Since(no.lastActiveTime) > dht.NodeExpriedAfter {
|
||||
dht.transactionManager.ping(no)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// routingTableNode represents routing table tree node.
|
||||
type routingTableNode struct {
|
||||
sync.RWMutex
|
||||
children []*routingTableNode
|
||||
bucket *kbucket
|
||||
}
|
||||
|
||||
// newRoutingTableNode returns a new routingTableNode pointer.
|
||||
func newRoutingTableNode(prefix *bitmap) *routingTableNode {
|
||||
return &routingTableNode{
|
||||
children: make([]*routingTableNode, 2),
|
||||
bucket: newKBucket(prefix),
|
||||
}
|
||||
}
|
||||
|
||||
// Child returns routingTableNode's left or right child.
|
||||
func (tableNode *routingTableNode) Child(index int) *routingTableNode {
|
||||
if index >= 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tableNode.RLock()
|
||||
defer tableNode.RUnlock()
|
||||
|
||||
return tableNode.children[index]
|
||||
}
|
||||
|
||||
// SetChild sets routingTableNode's left or right child. When index is 0, it's
|
||||
// the left child, if 1, it's the right child.
|
||||
func (tableNode *routingTableNode) SetChild(index int, c *routingTableNode) {
|
||||
tableNode.Lock()
|
||||
defer tableNode.Unlock()
|
||||
|
||||
tableNode.children[index] = c
|
||||
}
|
||||
|
||||
// KBucket returns the bucket routingTableNode holds.
|
||||
func (tableNode *routingTableNode) KBucket() *kbucket {
|
||||
tableNode.RLock()
|
||||
defer tableNode.RUnlock()
|
||||
|
||||
return tableNode.bucket
|
||||
}
|
||||
|
||||
// SetKBucket sets the bucket.
|
||||
func (tableNode *routingTableNode) SetKBucket(bucket *kbucket) {
|
||||
tableNode.Lock()
|
||||
defer tableNode.Unlock()
|
||||
|
||||
tableNode.bucket = bucket
|
||||
}
|
||||
|
||||
// Split splits current routingTableNode and sets it's two children.
|
||||
func (tableNode *routingTableNode) Split() {
|
||||
prefixLen := tableNode.KBucket().prefix.Size
|
||||
|
||||
if prefixLen == maxPrefixLength {
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
tableNode.SetChild(i, newRoutingTableNode(newBitmapFrom(
|
||||
tableNode.KBucket().prefix, prefixLen+1)))
|
||||
}
|
||||
|
||||
tableNode.Lock()
|
||||
tableNode.children[1].bucket.prefix.Set(prefixLen)
|
||||
tableNode.Unlock()
|
||||
|
||||
for e := range tableNode.KBucket().nodes.Iter() {
|
||||
nd := e.Value.(*node)
|
||||
tableNode.Child(nd.id.Bit(prefixLen)).KBucket().nodes.PushBack(nd)
|
||||
}
|
||||
|
||||
for e := range tableNode.KBucket().candidates.Iter() {
|
||||
nd := e.Value.(*node)
|
||||
tableNode.Child(nd.id.Bit(prefixLen)).KBucket().candidates.PushBack(nd)
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
tableNode.Child(i).KBucket().UpdateTimestamp()
|
||||
}
|
||||
}
|
||||
|
||||
// routingTable implements the routing table in DHT protocol.
|
||||
type routingTable struct {
|
||||
*sync.RWMutex
|
||||
k int
|
||||
root *routingTableNode
|
||||
cachedNodes *syncedMap
|
||||
cachedKBuckets *keyedDeque
|
||||
dht *DHT
|
||||
clearQueue *syncedList
|
||||
}
|
||||
|
||||
// newRoutingTable returns a new routingTable pointer.
|
||||
func newRoutingTable(k int, dht *DHT) *routingTable {
|
||||
root := newRoutingTableNode(newBitmap(0))
|
||||
|
||||
rt := &routingTable{
|
||||
RWMutex: &sync.RWMutex{},
|
||||
k: k,
|
||||
root: root,
|
||||
cachedNodes: newSyncedMap(),
|
||||
cachedKBuckets: newKeyedDeque(),
|
||||
dht: dht,
|
||||
clearQueue: newSyncedList(),
|
||||
}
|
||||
|
||||
rt.cachedKBuckets.Push(root.bucket.prefix.String(), root.bucket)
|
||||
return rt
|
||||
}
|
||||
|
||||
// Insert adds a node to routing table. It returns whether the node is new
|
||||
// in the routingtable.
|
||||
func (rt *routingTable) Insert(nd *node) bool {
|
||||
rt.Lock()
|
||||
defer rt.Unlock()
|
||||
|
||||
if rt.dht.blackList.in(nd.addr.IP.String(), nd.addr.Port) ||
|
||||
rt.cachedNodes.Len() >= rt.dht.MaxNodes {
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
next *routingTableNode
|
||||
bucket *kbucket
|
||||
)
|
||||
root := rt.root
|
||||
|
||||
for prefixLen := 1; prefixLen <= maxPrefixLength; prefixLen++ {
|
||||
next = root.Child(nd.id.Bit(prefixLen - 1))
|
||||
|
||||
if next != nil {
|
||||
// If next is not the leaf.
|
||||
root = next
|
||||
} else if root.KBucket().nodes.Len() < rt.k ||
|
||||
root.KBucket().nodes.HasKey(nd.id.RawString()) {
|
||||
|
||||
bucket = root.KBucket()
|
||||
isNew := bucket.Insert(nd)
|
||||
|
||||
rt.cachedNodes.Set(nd.addr.String(), nd)
|
||||
rt.cachedKBuckets.Push(bucket.prefix.String(), bucket)
|
||||
|
||||
return isNew
|
||||
} else if root.KBucket().prefix.Compare(nd.id, prefixLen-1) == 0 {
|
||||
// If node has the same prefix with bucket, split it.
|
||||
|
||||
root.Split()
|
||||
|
||||
rt.cachedKBuckets.Delete(root.KBucket().prefix.String())
|
||||
root.SetKBucket(nil)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
bucket = root.Child(i).KBucket()
|
||||
rt.cachedKBuckets.Push(bucket.prefix.String(), bucket)
|
||||
}
|
||||
|
||||
root = root.Child(nd.id.Bit(prefixLen - 1))
|
||||
} else {
|
||||
// Finally, store node as a candidate and fresh the bucket.
|
||||
root.KBucket().candidates.PushBack(nd)
|
||||
if root.KBucket().candidates.Len() > rt.k {
|
||||
root.KBucket().candidates.Remove(
|
||||
root.KBucket().candidates.Front())
|
||||
}
|
||||
|
||||
go root.KBucket().Fresh(rt.dht)
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetNeighbors returns the size-length nodes closest to id.
|
||||
func (rt *routingTable) GetNeighbors(id *bitmap, size int) []*node {
|
||||
rt.RLock()
|
||||
nodes := make([]interface{}, 0, rt.cachedNodes.Len())
|
||||
for item := range rt.cachedNodes.Iter() {
|
||||
nodes = append(nodes, item.val.(*node))
|
||||
}
|
||||
rt.RUnlock()
|
||||
|
||||
neighbors := getTopK(nodes, id, size)
|
||||
result := make([]*node, len(neighbors))
|
||||
|
||||
for i, nd := range neighbors {
|
||||
result[i] = nd.(*node)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetNeighborIds return the size-length compact node info closest to id.
|
||||
func (rt *routingTable) GetNeighborCompactInfos(id *bitmap, size int) []string {
|
||||
neighbors := rt.GetNeighbors(id, size)
|
||||
infos := make([]string, len(neighbors))
|
||||
|
||||
for i, no := range neighbors {
|
||||
infos[i] = no.CompactNodeInfo()
|
||||
}
|
||||
|
||||
return infos
|
||||
}
|
||||
|
||||
// GetNodeKBucktById returns node whose id is `id` and the bucket it
|
||||
// belongs to.
|
||||
func (rt *routingTable) GetNodeKBucktByID(id *bitmap) (
|
||||
nd *node, bucket *kbucket) {
|
||||
|
||||
rt.RLock()
|
||||
defer rt.RUnlock()
|
||||
|
||||
var next *routingTableNode
|
||||
root := rt.root
|
||||
|
||||
for prefixLen := 1; prefixLen <= maxPrefixLength; prefixLen++ {
|
||||
next = root.Child(id.Bit(prefixLen - 1))
|
||||
if next == nil {
|
||||
v, ok := root.KBucket().nodes.Get(id.RawString())
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
nd, bucket = v.Value.(*node), root.KBucket()
|
||||
return
|
||||
}
|
||||
root = next
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetNodeByAddress finds node by address.
|
||||
func (rt *routingTable) GetNodeByAddress(address string) (no *node, ok bool) {
|
||||
rt.RLock()
|
||||
defer rt.RUnlock()
|
||||
|
||||
v, ok := rt.cachedNodes.Get(address)
|
||||
if ok {
|
||||
no = v.(*node)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Remove deletes the node whose id is `id`.
|
||||
func (rt *routingTable) Remove(id *bitmap) {
|
||||
if nd, bucket := rt.GetNodeKBucktByID(id); nd != nil {
|
||||
bucket.Replace(nd)
|
||||
rt.cachedNodes.Delete(nd.addr.String())
|
||||
rt.cachedKBuckets.Push(bucket.prefix.String(), bucket)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove deletes the node whose address is `ip:port`.
|
||||
func (rt *routingTable) RemoveByAddr(address string) {
|
||||
v, ok := rt.cachedNodes.Get(address)
|
||||
if ok {
|
||||
rt.Remove(v.(*node).id)
|
||||
}
|
||||
}
|
||||
|
||||
// Fresh sends findNode to all nodes in the expired nodes.
|
||||
func (rt *routingTable) Fresh() {
|
||||
now := time.Now()
|
||||
|
||||
for e := range rt.cachedKBuckets.Iter() {
|
||||
bucket := e.Value.(*kbucket)
|
||||
if now.Sub(bucket.LastChanged()) < rt.dht.KBucketExpiredAfter ||
|
||||
bucket.nodes.Len() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
i := 0
|
||||
for e := range bucket.nodes.Iter() {
|
||||
if i < rt.dht.RefreshNodeNum {
|
||||
no := e.Value.(*node)
|
||||
rt.dht.transactionManager.findNode(no, bucket.RandomChildID())
|
||||
rt.clearQueue.PushBack(no)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
if rt.dht.IsCrawlMode() {
|
||||
for e := range rt.clearQueue.Iter() {
|
||||
rt.Remove(e.Value.(*node).id)
|
||||
}
|
||||
}
|
||||
|
||||
rt.clearQueue.Clear()
|
||||
}
|
||||
|
||||
// Len returns the number of nodes in table.
|
||||
func (rt *routingTable) Len() int {
|
||||
rt.RLock()
|
||||
defer rt.RUnlock()
|
||||
|
||||
return rt.cachedNodes.Len()
|
||||
}
|
||||
|
||||
// Implementation of heap with heap.Interface.
|
||||
type heapItem struct {
|
||||
distance *bitmap
|
||||
value interface{}
|
||||
}
|
||||
|
||||
type topKHeap []*heapItem
|
||||
|
||||
func (kHeap topKHeap) Len() int {
|
||||
return len(kHeap)
|
||||
}
|
||||
|
||||
func (kHeap topKHeap) Less(i, j int) bool {
|
||||
return kHeap[i].distance.Compare(kHeap[j].distance, maxPrefixLength) == 1
|
||||
}
|
||||
|
||||
func (kHeap topKHeap) Swap(i, j int) {
|
||||
kHeap[i], kHeap[j] = kHeap[j], kHeap[i]
|
||||
}
|
||||
|
||||
func (kHeap *topKHeap) Push(x interface{}) {
|
||||
*kHeap = append(*kHeap, x.(*heapItem))
|
||||
}
|
||||
|
||||
func (kHeap *topKHeap) Pop() interface{} {
|
||||
n := len(*kHeap)
|
||||
x := (*kHeap)[n-1]
|
||||
*kHeap = (*kHeap)[:n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
// getTopK solves the top-k problem with heap. It's time complexity is
|
||||
// O(n*log(k)). When n is large, time complexity will be too high, need to be
|
||||
// optimized.
|
||||
func getTopK(queue []interface{}, id *bitmap, k int) []interface{} {
|
||||
topkHeap := make(topKHeap, 0, k+1)
|
||||
|
||||
for _, value := range queue {
|
||||
node := value.(*node)
|
||||
item := &heapItem{
|
||||
id.Xor(node.id),
|
||||
value,
|
||||
}
|
||||
heap.Push(&topkHeap, item)
|
||||
if topkHeap.Len() > k {
|
||||
heap.Pop(&topkHeap)
|
||||
}
|
||||
}
|
||||
|
||||
tops := make([]interface{}, topkHeap.Len())
|
||||
for i := len(tops) - 1; i >= 0; i-- {
|
||||
tops[i] = heap.Pop(&topkHeap).(*heapItem).value
|
||||
}
|
||||
|
||||
return tops
|
||||
}
|
23
dht/sample/getpeers/getpeers.go
Normal file
23
dht/sample/getpeers/getpeers.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/shiyanhui/dht"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
d := dht.New(nil)
|
||||
go d.Run()
|
||||
|
||||
for {
|
||||
// ubuntu-14.04.2-desktop-amd64.iso
|
||||
peers, err := d.GetPeers("546cf15f724d19c4319cc17b179d7e035f89c1f4")
|
||||
if err != nil {
|
||||
time.Sleep(time.Second * 1)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Println("Found peers:", peers)
|
||||
}
|
||||
}
|
77
dht/sample/spider/spider.go
Normal file
77
dht/sample/spider/spider.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/shiyanhui/dht"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
type file struct {
|
||||
Path []interface{} `json:"path"`
|
||||
Length int `json:"length"`
|
||||
}
|
||||
|
||||
type bitTorrent struct {
|
||||
InfoHash string `json:"infohash"`
|
||||
Name string `json:"name"`
|
||||
Files []file `json:"files,omitempty"`
|
||||
Length int `json:"length,omitempty"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
http.ListenAndServe(":6060", nil)
|
||||
}()
|
||||
|
||||
w := dht.NewWire(65536, 1024, 256)
|
||||
go func() {
|
||||
for resp := range w.Response() {
|
||||
metadata, err := dht.Decode(resp.MetadataInfo)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
info := metadata.(map[string]interface{})
|
||||
|
||||
if _, ok := info["name"]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
bt := bitTorrent{
|
||||
InfoHash: hex.EncodeToString(resp.InfoHash),
|
||||
Name: info["name"].(string),
|
||||
}
|
||||
|
||||
if v, ok := info["files"]; ok {
|
||||
files := v.([]interface{})
|
||||
bt.Files = make([]file, len(files))
|
||||
|
||||
for i, item := range files {
|
||||
f := item.(map[string]interface{})
|
||||
bt.Files[i] = file{
|
||||
Path: f["path"].([]interface{}),
|
||||
Length: f["length"].(int),
|
||||
}
|
||||
}
|
||||
} else if _, ok := info["length"]; ok {
|
||||
bt.Length = info["length"].(int)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(bt)
|
||||
if err == nil {
|
||||
fmt.Printf("%s\n\n", data)
|
||||
}
|
||||
}
|
||||
}()
|
||||
go w.Run()
|
||||
|
||||
config := dht.NewCrawlConfig()
|
||||
config.OnAnnouncePeer = func(infoHash, ip string, port int) {
|
||||
w.Request([]byte(infoHash), ip, port)
|
||||
}
|
||||
d := dht.New(config)
|
||||
|
||||
d.Run()
|
||||
}
|
134
dht/util.go
Normal file
134
dht/util.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// randomString generates a size-length string randomly.
|
||||
func randomString(size int) string {
|
||||
buff := make([]byte, size)
|
||||
rand.Read(buff)
|
||||
return string(buff)
|
||||
}
|
||||
|
||||
// bytes2int returns the int value it represents.
|
||||
func bytes2int(data []byte) uint64 {
|
||||
n, val := len(data), uint64(0)
|
||||
if n > 8 {
|
||||
panic("data too long")
|
||||
}
|
||||
|
||||
for i, b := range data {
|
||||
val += uint64(b) << uint64((n-i-1)*8)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// int2bytes returns the byte array it represents.
|
||||
func int2bytes(val uint64) []byte {
|
||||
data, j := make([]byte, 8), -1
|
||||
for i := 0; i < 8; i++ {
|
||||
shift := uint64((7 - i) * 8)
|
||||
data[i] = byte((val & (0xff << shift)) >> shift)
|
||||
|
||||
if j == -1 && data[i] != 0 {
|
||||
j = i
|
||||
}
|
||||
}
|
||||
|
||||
if j != -1 {
|
||||
return data[j:]
|
||||
}
|
||||
return data[:1]
|
||||
}
|
||||
|
||||
// decodeCompactIPPortInfo decodes compactIP-address/port info in BitTorrent
|
||||
// DHT Protocol. It returns the ip and port number.
|
||||
func decodeCompactIPPortInfo(info string) (ip net.IP, port int, err error) {
|
||||
if len(info) != 6 {
|
||||
err = errors.New("compact info should be 6-length long")
|
||||
return
|
||||
}
|
||||
|
||||
ip = net.IPv4(info[0], info[1], info[2], info[3])
|
||||
port = int((uint16(info[4]) << 8) | uint16(info[5]))
|
||||
return
|
||||
}
|
||||
|
||||
// encodeCompactIPPortInfo encodes an ip and a port number to
|
||||
// compactIP-address/port info.
|
||||
func encodeCompactIPPortInfo(ip net.IP, port int) (info string, err error) {
|
||||
if port > 65535 || port < 0 {
|
||||
err = errors.New(
|
||||
"port should be no greater than 65535 and no less than 0")
|
||||
return
|
||||
}
|
||||
|
||||
p := int2bytes(uint64(port))
|
||||
if len(p) < 2 {
|
||||
p = append(p, p[0])
|
||||
p[0] = 0
|
||||
}
|
||||
|
||||
info = string(append(ip, p...))
|
||||
return
|
||||
}
|
||||
|
||||
// getLocalIPs returns local ips.
|
||||
func getLocalIPs() (ips []string) {
|
||||
ips = make([]string, 0, 6)
|
||||
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
ip, _, err := net.ParseCIDR(addr.String())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ips = append(ips, ip.String())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// getRemoteIP returns the wlan ip.
|
||||
func getRemoteIP() (ip string, err error) {
|
||||
client := &http.Client{
|
||||
Timeout: time.Second * 30,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "http://ifconfig.me", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "curl")
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
|
||||
data, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ip = string(data)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// genAddress returns a ip:port address.
|
||||
func genAddress(ip string, port int) string {
|
||||
return strings.Join([]string{ip, strconv.Itoa(port)}, ":")
|
||||
}
|
100
dht/util_test.go
Normal file
100
dht/util_test.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInt2Bytes(t *testing.T) {
|
||||
cases := []struct {
|
||||
in uint64
|
||||
out []byte
|
||||
}{
|
||||
{0, []byte{0}},
|
||||
{1, []byte{1}},
|
||||
{256, []byte{1, 0}},
|
||||
{22129, []byte{86, 113}},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
r := int2bytes(c.in)
|
||||
if len(r) != len(c.out) {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
for i, v := range r {
|
||||
if v != c.out[i] {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytes2Int(t *testing.T) {
|
||||
cases := []struct {
|
||||
in []byte
|
||||
out uint64
|
||||
}{
|
||||
{[]byte{0}, 0},
|
||||
{[]byte{1}, 1},
|
||||
{[]byte{1, 0}, 256},
|
||||
{[]byte{86, 113}, 22129},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
if bytes2int(c.in) != c.out {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeCompactIPPortInfo(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
out struct {
|
||||
ip string
|
||||
port int
|
||||
}
|
||||
}{
|
||||
{"123456", struct {
|
||||
ip string
|
||||
port int
|
||||
}{"49.50.51.52", 13622}},
|
||||
{"abcdef", struct {
|
||||
ip string
|
||||
port int
|
||||
}{"97.98.99.100", 25958}},
|
||||
}
|
||||
|
||||
for _, item := range cases {
|
||||
ip, port, err := decodeCompactIPPortInfo(item.in)
|
||||
if err != nil || ip.String() != item.out.ip || port != item.out.port {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeCompactIPPortInfo(t *testing.T) {
|
||||
cases := []struct {
|
||||
in struct {
|
||||
ip []byte
|
||||
port int
|
||||
}
|
||||
out string
|
||||
}{
|
||||
{struct {
|
||||
ip []byte
|
||||
port int
|
||||
}{[]byte{49, 50, 51, 52}, 13622}, "123456"},
|
||||
{struct {
|
||||
ip []byte
|
||||
port int
|
||||
}{[]byte{97, 98, 99, 100}, 25958}, "abcdef"},
|
||||
}
|
||||
|
||||
for _, item := range cases {
|
||||
info, err := encodeCompactIPPortInfo(item.in.ip, item.in.port)
|
||||
if err != nil || info != item.out {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue