This commit is contained in:
Alex Grintsvayg 2017-08-16 11:52:19 -04:00
parent 09745cbdea
commit 1f26aeeb5c
22 changed files with 3869 additions and 0 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/.idea

1
dht/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
.DS_Store

21
dht/LICENSE Normal file
View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2015 Dean Karn
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

87
dht/README.md Normal file
View file

@ -0,0 +1,87 @@
![](https://raw.githubusercontent.com/shiyanhui/dht/master/doc/screen-shot.png)
See the video on the [Youtube](https://www.youtube.com/watch?v=AIpeQtw22kc).
[中文版README](https://github.com/shiyanhui/dht/blob/master/README_CN.md)
## Introduction
DHT implements the bittorrent DHT protocol in Go. Now it includes:
- [BEP-3 (part)](http://www.bittorrent.org/beps/bep_0003.html)
- [BEP-5](http://www.bittorrent.org/beps/bep_0005.html)
- [BEP-9](http://www.bittorrent.org/beps/bep_0009.html)
- [BEP-10](http://www.bittorrent.org/beps/bep_0010.html)
It contains two modes, the standard mode and the crawling mode. The standard
mode follows the BEPs, and you can use it as a standard dht server. The crawling
mode aims to crawl as more metadata info as possiple. It doesn't follow the
standard BEPs protocol. With the crawling mode, you can build another [BTDigg](http://btdigg.org/).
[bthub.io](http://bthub.io) is a BT search engine based on the crawling mode.
## Installation
go get github.com/shiyanhui/dht
## Example
Below is a simple spider. You can move [here](https://github.com/shiyanhui/dht/blob/master/sample)
to see more samples.
```go
import (
"fmt"
"github.com/shiyanhui/dht"
)
func main() {
downloader := dht.NewWire(65535)
go func() {
// once we got the request result
for resp := range downloader.Response() {
fmt.Println(resp.InfoHash, resp.MetadataInfo)
}
}()
go downloader.Run()
config := dht.NewCrawlConfig()
config.OnAnnouncePeer = func(infoHash, ip string, port int) {
// request to download the metadata info
downloader.Request([]byte(infoHash), ip, port)
}
d := dht.New(config)
d.Run()
}
```
## Download
You can download the demo compiled binary file [here](https://github.com/shiyanhui/dht/files/407021/spider.zip).
## Note
- The default crawl mode configure costs about 300M RAM. Set **MaxNodes**
and **BlackListMaxSize** to fit yourself.
- Now it cant't run in LAN because of NAT.
## TODO
- [ ] NAT Traversal.
- [ ] Implements the full BEP-3.
- [ ] Optimization.
## FAQ
#### Why it is slow compared to other spiders ?
Well, maybe there are several reasons.
- DHT aims to implements the standard BitTorrent DHT protocol, not born for crawling the DHT network.
- NAT Traversal issue. You run the crawler in a local network.
- It will block ip which looks like bad and a good ip may be mis-judged.
## License
MIT, read more [here](https://github.com/shiyanhui/dht/blob/master/LICENSE)

78
dht/README_CN.md Normal file
View file

@ -0,0 +1,78 @@
![](https://raw.githubusercontent.com/shiyanhui/dht/master/doc/screen-shot.png)
在这个视频上你可以看到爬取效果[Youtube](https://www.youtube.com/watch?v=AIpeQtw22kc).
## Introduction
DHT实现了BitTorrent DHT协议主要包括
- [BEP-3 (部分)](http://www.bittorrent.org/beps/bep_0003.html)
- [BEP-5](http://www.bittorrent.org/beps/bep_0005.html)
- [BEP-9](http://www.bittorrent.org/beps/bep_0009.html)
- [BEP-10](http://www.bittorrent.org/beps/bep_0010.html)
它包含两种模式标准模式和爬虫模式。标准模式遵循DHT协议你可以把它当做一个标准
的DHT组件。爬虫模式是为了嗅探到更多torrent文件信息它在某些方面不遵循DHT协议。
基于爬虫模式,你可以打造你自己的[BTDigg](http://btdigg.org/)。
[bthub.io](http://bthub.io)是一个基于这个爬虫而建的BT搜索引擎你可以把他当做
BTDigg的替代品。
## Installation
go get github.com/shiyanhui/dht
## Example
下面是一个简单的爬虫例子,你可以到[这里](https://github.com/shiyanhui/dht/blob/master/sample)看完整的Demo。
```go
import (
"fmt"
"github.com/shiyanhui/dht"
)
func main() {
downloader := dht.NewWire(65536)
go func() {
// once we got the request result
for resp := range downloader.Response() {
fmt.Println(resp.InfoHash, resp.MetadataInfo)
}
}()
go downloader.Run()
config := dht.NewCrawlConfig()
config.OnAnnouncePeer = func(infoHash, ip string, port int) {
// request to download the metadata info
downloader.Request([]byte(infoHash), ip, port)
}
d := dht.New(config)
d.Run()
}
```
## Download
这个是已经编译好的Demo二进制文件你可以到这里[下载](https://github.com/shiyanhui/dht/files/407021/spider.zip)。
## 注意
- 默认的爬虫配置需要300M左右内存你可以根据你的服务器内存大小调整MaxNodes和
BlackListMaxSize
- 目前还不能穿透NAT因此还不能在局域网运行
## TODO
- [ ] NAT穿透在局域网内也能够运行
- [ ] 完整地实现BEP-3这样不但能够下载种子也能够下载资源
- [ ] 优化
## Blog
你可以在[这里](https://github.com/shiyanhui/dht/wiki)看到DHT Spider教程。
## License
[MIT](https://github.com/shiyanhui/dht/blob/master/LICENSE)

263
dht/bencode.go Normal file
View file

@ -0,0 +1,263 @@
package dht
import (
"bytes"
"errors"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// find returns the index of first target in data starting from `start`.
// It returns -1 if target not found.
func find(data []byte, start int, target rune) (index int) {
index = bytes.IndexRune(data[start:], target)
if index != -1 {
return index + start
}
return index
}
// DecodeString decodes a string in the data. It returns a tuple
// (decoded result, the end position, error).
func DecodeString(data []byte, start int) (
result interface{}, index int, err error) {
if start >= len(data) || data[start] < '0' || data[start] > '9' {
err = errors.New("invalid string bencode")
return
}
i := find(data, start, ':')
if i == -1 {
err = errors.New("':' not found when decode string")
return
}
length, err := strconv.Atoi(string(data[start:i]))
if err != nil {
return
}
if length < 0 {
err = errors.New("invalid length of string")
return
}
index = i + 1 + length
if index > len(data) || index < i+1 {
err = errors.New("out of range")
return
}
result = string(data[i+1 : index])
return
}
// DecodeInt decodes int value in the data.
func DecodeInt(data []byte, start int) (
result interface{}, index int, err error) {
if start >= len(data) || data[start] != 'i' {
err = errors.New("invalid int bencode")
return
}
index = find(data, start+1, 'e')
if index == -1 {
err = errors.New("':' not found when decode int")
return
}
result, err = strconv.Atoi(string(data[start+1 : index]))
if err != nil {
return
}
index++
return
}
// decodeItem decodes an item of dict or list.
func decodeItem(data []byte, i int) (
result interface{}, index int, err error) {
var decodeFunc = []func([]byte, int) (interface{}, int, error){
DecodeString, DecodeInt, DecodeList, DecodeDict,
}
for _, f := range decodeFunc {
result, index, err = f(data, i)
if err == nil {
return
}
}
err = errors.New("invalid bencode when decode item")
return
}
// DecodeList decodes a list value.
func DecodeList(data []byte, start int) (
result interface{}, index int, err error) {
if start >= len(data) || data[start] != 'l' {
err = errors.New("invalid list bencode")
return
}
var item interface{}
r := make([]interface{}, 0, 8)
index = start + 1
for index < len(data) {
char, _ := utf8.DecodeRune(data[index:])
if char == 'e' {
break
}
item, index, err = decodeItem(data, index)
if err != nil {
return
}
r = append(r, item)
}
if index == len(data) {
err = errors.New("'e' not found when decode list")
return
}
index++
result = r
return
}
// DecodeDict decodes a map value.
func DecodeDict(data []byte, start int) (
result interface{}, index int, err error) {
if start >= len(data) || data[start] != 'd' {
err = errors.New("invalid dict bencode")
return
}
var item, key interface{}
r := make(map[string]interface{})
index = start + 1
for index < len(data) {
char, _ := utf8.DecodeRune(data[index:])
if char == 'e' {
break
}
if !unicode.IsDigit(char) {
err = errors.New("invalid dict bencode")
return
}
key, index, err = DecodeString(data, index)
if err != nil {
return
}
if index >= len(data) {
err = errors.New("out of range")
return
}
item, index, err = decodeItem(data, index)
if err != nil {
return
}
r[key.(string)] = item
}
if index == len(data) {
err = errors.New("'e' not found when decode dict")
return
}
index++
result = r
return
}
// Decode decodes a bencoded string to string, int, list or map.
func Decode(data []byte) (result interface{}, err error) {
result, _, err = decodeItem(data, 0)
return
}
// EncodeString encodes a string value.
func EncodeString(data string) string {
return strings.Join([]string{strconv.Itoa(len(data)), data}, ":")
}
// EncodeInt encodes a int value.
func EncodeInt(data int) string {
return strings.Join([]string{"i", strconv.Itoa(data), "e"}, "")
}
// EncodeItem encodes an item of dict or list.
func encodeItem(data interface{}) (item string) {
switch v := data.(type) {
case string:
item = EncodeString(v)
case int:
item = EncodeInt(v)
case []interface{}:
item = EncodeList(v)
case map[string]interface{}:
item = EncodeDict(v)
default:
panic("invalid type when encode item")
}
return
}
// EncodeList encodes a list value.
func EncodeList(data []interface{}) string {
result := make([]string, len(data))
for i, item := range data {
result[i] = encodeItem(item)
}
return strings.Join([]string{"l", strings.Join(result, ""), "e"}, "")
}
// EncodeDict encodes a dict value.
func EncodeDict(data map[string]interface{}) string {
result, i := make([]string, len(data)), 0
for key, val := range data {
result[i] = strings.Join(
[]string{EncodeString(key), encodeItem(val)},
"")
i++
}
return strings.Join([]string{"d", strings.Join(result, ""), "e"}, "")
}
// Encode encodes a string, int, dict or list value to a bencoded string.
func Encode(data interface{}) string {
switch v := data.(type) {
case string:
return EncodeString(v)
case int:
return EncodeInt(v)
case []interface{}:
return EncodeList(v)
case map[string]interface{}:
return EncodeDict(v)
default:
panic("invalid type when encode")
}
}

159
dht/bencode_test.go Normal file
View file

@ -0,0 +1,159 @@
package dht
import (
"testing"
)
func TestDecodeString(t *testing.T) {
cases := []struct {
in string
out string
}{
{"0:", ""},
{"1:a", "a"},
{"5:hello", "hello"},
}
for _, c := range cases {
if out, err := Decode([]byte(c.in)); err != nil || out != c.out {
t.Error(err)
}
}
}
func TestDecodeInt(t *testing.T) {
cases := []struct {
in string
out int
}{
{"i123e:", 123},
{"i0e", 0},
{"i-1e", -1},
}
for _, c := range cases {
if out, err := Decode([]byte(c.in)); err != nil || out != c.out {
t.Error(err)
}
}
}
func TestDecodeList(t *testing.T) {
cases := []struct {
in string
out []interface{}
}{
{"li123ei-1ee", []interface{}{123, -1}},
{"l5:helloe", []interface{}{"hello"}},
{"ld5:hello5:worldee", []interface{}{map[string]interface{}{"hello": "world"}}},
{"lli1ei2eee", []interface{}{[]interface{}{1, 2}}},
}
for i, c := range cases {
v, err := Decode([]byte(c.in))
if err != nil {
t.Fail()
}
out := v.([]interface{})
switch i {
case 0, 1:
for j, item := range out {
if item != c.out[j] {
t.Fail()
}
}
case 2:
if len(out) != 1 {
t.Fail()
}
o := out[0].(map[string]interface{})
cout := c.out[0].(map[string]interface{})
for k, v := range o {
if cv, ok := cout[k]; !ok || v != cv {
t.Fail()
}
}
case 3:
if len(out) != 1 {
t.Fail()
}
o := out[0].([]interface{})
cout := c.out[0].([]interface{})
for j, item := range o {
if item != cout[j] {
t.Fail()
}
}
}
}
}
func TestDecodeDict(t *testing.T) {
cases := []struct {
in string
out map[string]interface{}
}{
{"d5:helloi100ee", map[string]interface{}{"hello": 100}},
{"d3:foo3:bare", map[string]interface{}{"foo": "bar"}},
{"d1:ad3:foo3:baree", map[string]interface{}{"a": map[string]interface{}{"foo": "bar"}}},
{"d4:listli1eee", map[string]interface{}{"list": []interface{}{1}}},
}
for i, c := range cases {
v, err := Decode([]byte(c.in))
if err != nil {
t.Fail()
}
out := v.(map[string]interface{})
switch i {
case 0, 1:
for k, v := range out {
if cv, ok := c.out[k]; !ok || v != cv {
t.Fail()
}
}
case 2:
if len(out) != 1 {
t.Fail()
}
v, ok := out["a"]
if !ok {
t.Fail()
}
cout := c.out["a"].(map[string]interface{})
for k, v := range v.(map[string]interface{}) {
if cv, ok := cout[k]; !ok || v != cv {
t.Fail()
}
}
case 3:
if len(out) != 1 {
t.Fail()
}
v, ok := out["list"]
if !ok {
t.Fail()
}
cout := c.out["list"].([]interface{})
for j, v := range v.([]interface{}) {
if v != cout[j] {
t.Fail()
}
}
}
}
}

163
dht/bitmap.go Normal file
View file

@ -0,0 +1,163 @@
package dht
import (
"fmt"
"strings"
)
// bitmap represents a bit array.
type bitmap struct {
Size int
data []byte
}
// newBitmap returns a size-length bitmap pointer.
func newBitmap(size int) *bitmap {
div, mod := size/8, size%8
if mod > 0 {
div++
}
return &bitmap{size, make([]byte, div)}
}
// newBitmapFrom returns a new copyed bitmap pointer which
// newBitmap.data = other.data[:size].
func newBitmapFrom(other *bitmap, size int) *bitmap {
bitmap := newBitmap(size)
if size > other.Size {
size = other.Size
}
div := size / 8
for i := 0; i < div; i++ {
bitmap.data[i] = other.data[i]
}
for i := div * 8; i < size; i++ {
if other.Bit(i) == 1 {
bitmap.Set(i)
}
}
return bitmap
}
// newBitmapFromBytes returns a bitmap pointer created from a byte array.
func newBitmapFromBytes(data []byte) *bitmap {
bitmap := newBitmap(len(data) * 8)
copy(bitmap.data, data)
return bitmap
}
// newBitmapFromString returns a bitmap pointer created from a string.
func newBitmapFromString(data string) *bitmap {
return newBitmapFromBytes([]byte(data))
}
// Bit returns the bit at index.
func (bitmap *bitmap) Bit(index int) int {
if index >= bitmap.Size {
panic("index out of range")
}
div, mod := index/8, index%8
return int((uint(bitmap.data[div]) & (1 << uint(7-mod))) >> uint(7-mod))
}
// set sets the bit at index `index`. If bit is true, set 1, otherwise set 0.
func (bitmap *bitmap) set(index int, bit int) {
if index >= bitmap.Size {
panic("index out of range")
}
div, mod := index/8, index%8
shift := byte(1 << uint(7-mod))
bitmap.data[div] &= ^shift
if bit > 0 {
bitmap.data[div] |= shift
}
}
// Set sets the bit at idnex to 1.
func (bitmap *bitmap) Set(index int) {
bitmap.set(index, 1)
}
// Unset sets the bit at idnex to 0.
func (bitmap *bitmap) Unset(index int) {
bitmap.set(index, 0)
}
// Compare compares the prefixLen-prefix of two bitmap.
// - If bitmap.data[:prefixLen] < other.data[:prefixLen], return -1.
// - If bitmap.data[:prefixLen] > other.data[:prefixLen], return 1.
// - Otherwise return 0.
func (bitmap *bitmap) Compare(other *bitmap, prefixLen int) int {
if prefixLen > bitmap.Size || prefixLen > other.Size {
panic("index out of range")
}
div, mod := prefixLen/8, prefixLen%8
for i := 0; i < div; i++ {
if bitmap.data[i] > other.data[i] {
return 1
} else if bitmap.data[i] < other.data[i] {
return -1
}
}
for i := div * 8; i < div*8+mod; i++ {
bit1, bit2 := bitmap.Bit(i), other.Bit(i)
if bit1 > bit2 {
return 1
} else if bit1 < bit2 {
return -1
}
}
return 0
}
// Xor returns the xor value of two bitmap.
func (bitmap *bitmap) Xor(other *bitmap) *bitmap {
if bitmap.Size != other.Size {
panic("size not the same")
}
distance := newBitmap(bitmap.Size)
div, mod := distance.Size/8, distance.Size%8
for i := 0; i < div; i++ {
distance.data[i] = bitmap.data[i] ^ other.data[i]
}
for i := div * 8; i < div*8+mod; i++ {
distance.set(i, bitmap.Bit(i)^other.Bit(i))
}
return distance
}
// String returns the bit sequence string of the bitmap.
func (bitmap *bitmap) String() string {
div, mod := bitmap.Size/8, bitmap.Size%8
buff := make([]string, div+mod)
for i := 0; i < div; i++ {
buff[i] = fmt.Sprintf("%08b", bitmap.data[i])
}
for i := div; i < div+mod; i++ {
buff[i] = fmt.Sprintf("%1b", bitmap.Bit(div*8+(i-div)))
}
return strings.Join(buff, "")
}
// RawString returns the string value of bitmap.data.
func (bitmap *bitmap) RawString() string {
return string(bitmap.data)
}

69
dht/bitmap_test.go Normal file
View file

@ -0,0 +1,69 @@
package dht
import (
"testing"
)
func TestBitmap(t *testing.T) {
a := newBitmap(10)
b := newBitmapFrom(a, 10)
c := newBitmapFromBytes([]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57})
d := newBitmapFromString("0123456789")
e := newBitmap(10)
// Bit
for i := 0; i < a.Size; i++ {
if a.Bit(i) != 0 {
t.Fail()
}
}
// Compare
if c.Compare(d, d.Size) != 0 {
t.Fail()
}
// RawString
if c.RawString() != d.RawString() || c.RawString() != "0123456789" {
t.Fail()
}
// Set
b.Set(5)
if b.Bit(5) != 1 {
t.Fail()
}
// Unset
b.Unset(5)
if b.Bit(5) == 1 {
t.Fail()
}
// String
if e.String() != "0000000000" {
t.Fail()
}
e.Set(9)
if e.String() != "0000000001" {
t.Fail()
}
e.Set(2)
if e.String() != "0010000001" {
t.Fail()
}
a.Set(0)
a.Set(5)
a.Set(8)
if a.String() != "1000010010" {
t.Fail()
}
// Xor
b.Set(5)
b.Set(9)
if a.Xor(b).String() != "1000000011" {
t.Fail()
}
}

92
dht/blacklist.go Normal file
View file

@ -0,0 +1,92 @@
package dht
import (
"time"
)
// blockedItem represents a blocked node.
type blockedItem struct {
ip string
port int
createTime time.Time
}
// blackList manages the blocked nodes including which sends bad information
// and can't ping out.
type blackList struct {
list *syncedMap
maxSize int
expiredAfter time.Duration
}
// newBlackList returns a blackList pointer.
func newBlackList(size int) *blackList {
return &blackList{
list: newSyncedMap(),
maxSize: size,
expiredAfter: time.Hour * 1,
}
}
// genKey returns a key. If port is less than 0, the key wil be ip. Ohterwise
// it will be `ip:port` format.
func (bl *blackList) genKey(ip string, port int) string {
key := ip
if port >= 0 {
key = genAddress(ip, port)
}
return key
}
// insert adds a blocked item to the blacklist.
func (bl *blackList) insert(ip string, port int) {
if bl.list.Len() >= bl.maxSize {
return
}
bl.list.Set(bl.genKey(ip, port), &blockedItem{
ip: ip,
port: port,
createTime: time.Now(),
})
}
// delete removes blocked item form the blackList.
func (bl *blackList) delete(ip string, port int) {
bl.list.Delete(bl.genKey(ip, port))
}
// validate checks whether ip-port pair is in the block nodes list.
func (bl *blackList) in(ip string, port int) bool {
if _, ok := bl.list.Get(ip); ok {
return true
}
key := bl.genKey(ip, port)
v, ok := bl.list.Get(key)
if ok {
if time.Now().Sub(v.(*blockedItem).createTime) < bl.expiredAfter {
return true
}
bl.list.Delete(key)
}
return false
}
// clear cleans the expired items every 10 minutes.
func (bl *blackList) clear() {
for _ = range time.Tick(time.Minute * 10) {
keys := make([]interface{}, 0, 100)
for item := range bl.list.Iter() {
if time.Now().Sub(
item.val.(*blockedItem).createTime) > bl.expiredAfter {
keys = append(keys, item.key)
}
}
bl.list.DeleteMulti(keys)
}
}

57
dht/blacklist_test.go Normal file
View file

@ -0,0 +1,57 @@
package dht
import (
"fmt"
"testing"
)
var blacklist = newBlackList(256)
func TestGenKey(t *testing.T) {
cases := []struct {
in struct {
ip string
port int
}
out string
}{
{struct {
ip string
port int
}{"0.0.0.0", -1}, "0.0.0.0"},
{struct {
ip string
port int
}{"1.1.1.1", 8080}, "1.1.1.1:8080"},
}
for _, c := range cases {
if blacklist.genKey(c.in.ip, c.in.port) != c.out {
t.Fail()
}
}
}
func TestBlackList(t *testing.T) {
address := []struct {
ip string
port int
}{
{"0.0.0.0", -1},
{"1.1.1.1", 8080},
{"2.2.2.2", 8081},
}
for _, addr := range address {
blacklist.insert(addr.ip, addr.port)
if !blacklist.in(addr.ip, addr.port) {
t.Fail()
}
blacklist.delete(addr.ip, addr.port)
if blacklist.in(addr.ip, addr.port) {
fmt.Println(addr.ip)
t.Fail()
}
}
}

289
dht/container.go Normal file
View file

@ -0,0 +1,289 @@
package dht
import (
"container/list"
"sync"
)
type mapItem struct {
key interface{}
val interface{}
}
// syncedMap represents a goroutine-safe map.
type syncedMap struct {
*sync.RWMutex
data map[interface{}]interface{}
}
// newSyncedMap returns a syncedMap pointer.
func newSyncedMap() *syncedMap {
return &syncedMap{
RWMutex: &sync.RWMutex{},
data: make(map[interface{}]interface{}),
}
}
// Get returns the value mapped to key.
func (smap *syncedMap) Get(key interface{}) (val interface{}, ok bool) {
smap.RLock()
defer smap.RUnlock()
val, ok = smap.data[key]
return
}
// Has returns whether the syncedMap contains the key.
func (smap *syncedMap) Has(key interface{}) bool {
_, ok := smap.Get(key)
return ok
}
// Set sets pair {key: val}.
func (smap *syncedMap) Set(key interface{}, val interface{}) {
smap.Lock()
defer smap.Unlock()
smap.data[key] = val
}
// Delete deletes the key in the map.
func (smap *syncedMap) Delete(key interface{}) {
smap.Lock()
defer smap.Unlock()
delete(smap.data, key)
}
// DeleteMulti deletes keys in batch.
func (smap *syncedMap) DeleteMulti(keys []interface{}) {
smap.Lock()
defer smap.Unlock()
for _, key := range keys {
delete(smap.data, key)
}
}
// Clear resets the data.
func (smap *syncedMap) Clear() {
smap.Lock()
defer smap.Unlock()
smap.data = make(map[interface{}]interface{})
}
// Iter returns a chan which output all items.
func (smap *syncedMap) Iter() <-chan mapItem {
ch := make(chan mapItem)
go func() {
smap.RLock()
for key, val := range smap.data {
ch <- mapItem{
key: key,
val: val,
}
}
smap.RUnlock()
close(ch)
}()
return ch
}
// Len returns the length of syncedMap.
func (smap *syncedMap) Len() int {
smap.RLock()
defer smap.RUnlock()
return len(smap.data)
}
// syncedList represents a goroutine-safe list.
type syncedList struct {
*sync.RWMutex
queue *list.List
}
// newSyncedList returns a syncedList pointer.
func newSyncedList() *syncedList {
return &syncedList{
RWMutex: &sync.RWMutex{},
queue: list.New(),
}
}
// Front returns the first element of slist.
func (slist *syncedList) Front() *list.Element {
slist.RLock()
defer slist.RUnlock()
return slist.queue.Front()
}
// Back returns the last element of slist.
func (slist *syncedList) Back() *list.Element {
slist.RLock()
defer slist.RUnlock()
return slist.queue.Back()
}
// PushFront pushs an element to the head of slist.
func (slist *syncedList) PushFront(v interface{}) *list.Element {
slist.Lock()
defer slist.Unlock()
return slist.queue.PushFront(v)
}
// PushBack pushs an element to the tail of slist.
func (slist *syncedList) PushBack(v interface{}) *list.Element {
slist.Lock()
defer slist.Unlock()
return slist.queue.PushBack(v)
}
// InsertBefore inserts v before mark.
func (slist *syncedList) InsertBefore(
v interface{}, mark *list.Element) *list.Element {
slist.Lock()
defer slist.Unlock()
return slist.queue.InsertBefore(v, mark)
}
// InsertAfter inserts v after mark.
func (slist *syncedList) InsertAfter(
v interface{}, mark *list.Element) *list.Element {
slist.Lock()
defer slist.Unlock()
return slist.queue.InsertAfter(v, mark)
}
// Remove removes e from the slist.
func (slist *syncedList) Remove(e *list.Element) interface{} {
slist.Lock()
defer slist.Unlock()
return slist.queue.Remove(e)
}
// Clear resets the list queue.
func (slist *syncedList) Clear() {
slist.Lock()
defer slist.Unlock()
slist.queue.Init()
}
// Len returns length of the slist.
func (slist *syncedList) Len() int {
slist.RLock()
defer slist.RUnlock()
return slist.queue.Len()
}
// Iter returns a chan which output all elements.
func (slist *syncedList) Iter() <-chan *list.Element {
ch := make(chan *list.Element)
go func() {
slist.RLock()
for e := slist.queue.Front(); e != nil; e = e.Next() {
ch <- e
}
slist.RUnlock()
close(ch)
}()
return ch
}
// KeyedDeque represents a keyed deque.
type keyedDeque struct {
*sync.RWMutex
*syncedList
index map[interface{}]*list.Element
invertedIndex map[*list.Element]interface{}
}
// newKeyedDeque returns a newKeyedDeque pointer.
func newKeyedDeque() *keyedDeque {
return &keyedDeque{
RWMutex: &sync.RWMutex{},
syncedList: newSyncedList(),
index: make(map[interface{}]*list.Element),
invertedIndex: make(map[*list.Element]interface{}),
}
}
// Push pushs a keyed-value to the end of deque.
func (deque *keyedDeque) Push(key interface{}, val interface{}) {
deque.Lock()
defer deque.Unlock()
if e, ok := deque.index[key]; ok {
deque.syncedList.Remove(e)
}
deque.index[key] = deque.syncedList.PushBack(val)
deque.invertedIndex[deque.index[key]] = key
}
// Get returns the keyed value.
func (deque *keyedDeque) Get(key interface{}) (*list.Element, bool) {
deque.RLock()
defer deque.RUnlock()
v, ok := deque.index[key]
return v, ok
}
// Has returns whether key already exists.
func (deque *keyedDeque) HasKey(key interface{}) bool {
_, ok := deque.Get(key)
return ok
}
// Delete deletes a value named key.
func (deque *keyedDeque) Delete(key interface{}) (v interface{}) {
deque.RLock()
e, ok := deque.index[key]
deque.RUnlock()
deque.Lock()
defer deque.Unlock()
if ok {
v = deque.syncedList.Remove(e)
delete(deque.index, key)
delete(deque.invertedIndex, e)
}
return
}
// Removes overwrites list.List.Remove.
func (deque *keyedDeque) Remove(e *list.Element) (v interface{}) {
deque.RLock()
key, ok := deque.invertedIndex[e]
deque.RUnlock()
if ok {
v = deque.Delete(key)
}
return
}
// Clear resets the deque.
func (deque *keyedDeque) Clear() {
deque.Lock()
defer deque.Unlock()
deque.syncedList.Clear()
deque.index = make(map[interface{}]*list.Element)
deque.invertedIndex = make(map[*list.Element]interface{})
}

196
dht/container_test.go Normal file
View file

@ -0,0 +1,196 @@
package dht
import (
"sync"
"testing"
)
func TestSyncedMap(t *testing.T) {
cases := []mapItem{
{"a", 0},
{"b", 1},
{"c", 2},
}
sm := newSyncedMap()
set := func() {
group := sync.WaitGroup{}
for _, item := range cases {
group.Add(1)
go func(item mapItem) {
sm.Set(item.key, item.val)
group.Done()
}(item)
}
group.Wait()
}
isEmpty := func() {
if sm.Len() != 0 {
t.Fail()
}
}
// Set
set()
if sm.Len() != len(cases) {
t.Fail()
}
Loop:
// Iter
for item := range sm.Iter() {
for _, c := range cases {
if item.key == c.key && item.val == c.val {
continue Loop
}
}
t.Fail()
}
// Get, Delete, Has
for _, item := range cases {
val, ok := sm.Get(item.key)
if !ok || val != item.val {
t.Fail()
}
sm.Delete(item.key)
if sm.Has(item.key) {
t.Fail()
}
}
isEmpty()
// DeleteMulti
set()
sm.DeleteMulti([]interface{}{"a", "b", "c"})
isEmpty()
// Clear
set()
sm.Clear()
isEmpty()
}
func TestSyncedList(t *testing.T) {
sl := newSyncedList()
insert := func() {
for i := 0; i < 10; i++ {
sl.PushBack(i)
}
}
isEmpty := func() {
if sl.Len() != 0 {
t.Fail()
}
}
// PushBack
insert()
// Len
if sl.Len() != 10 {
t.Fail()
}
// Iter
i := 0
for item := range sl.Iter() {
if item.Value.(int) != i {
t.Fail()
}
i++
}
// Front
if sl.Front().Value.(int) != 0 {
t.Fail()
}
// Back
if sl.Back().Value.(int) != 9 {
t.Fail()
}
// Remove
for i := 0; i < 10; i++ {
if sl.Remove(sl.Front()).(int) != i {
t.Fail()
}
}
isEmpty()
// Clear
insert()
sl.Clear()
isEmpty()
}
func TestKeyedDeque(t *testing.T) {
cases := []mapItem{
{"a", 0},
{"b", 1},
{"c", 2},
}
deque := newKeyedDeque()
insert := func() {
for _, item := range cases {
deque.Push(item.key, item.val)
}
}
isEmpty := func() {
if deque.Len() != 0 {
t.Fail()
}
}
// Push
insert()
// Len
if deque.Len() != 3 {
t.Fail()
}
// Iter
i := 0
for e := range deque.Iter() {
if e.Value.(int) != i {
t.Fail()
}
i++
}
// HasKey, Get, Delete
for _, item := range cases {
if !deque.HasKey(item.key) {
t.Fail()
}
e, ok := deque.Get(item.key)
if !ok || e.Value.(int) != item.val {
t.Fail()
}
if deque.Delete(item.key) != item.val {
t.Fail()
}
if deque.HasKey(item.key) {
t.Fail()
}
}
isEmpty()
// Clear
insert()
deque.Clear()
isEmpty()
}

296
dht/dht.go Normal file
View file

@ -0,0 +1,296 @@
// Package dht implements the bittorrent dht protocol. For more information
// see http://www.bittorrent.org/beps/bep_0005.html.
package dht
import (
"encoding/hex"
"errors"
"math"
"net"
"time"
)
const (
// StandardMode follows the standard protocol
StandardMode = iota
// CrawlMode for crawling the dht network.
CrawlMode
)
// Config represents the configure of dht.
type Config struct {
// in mainline dht, k = 8
K int
// for crawling mode, we put all nodes in one bucket, so KBucketSize may
// not be K
KBucketSize int
// candidates are udp, udp4, udp6
Network string
// format is `ip:port`
Address string
// the prime nodes through which we can join in dht network
PrimeNodes []string
// the kbucket expired duration
KBucketExpiredAfter time.Duration
// the node expired duration
NodeExpriedAfter time.Duration
// how long it checks whether the bucket is expired
CheckKBucketPeriod time.Duration
// peer token expired duration
TokenExpiredAfter time.Duration
// the max transaction id
MaxTransactionCursor uint64
// how many nodes routing table can hold
MaxNodes int
// callback when got get_peers request
OnGetPeers func(string, string, int)
// callback when got announce_peer request
OnAnnouncePeer func(string, string, int)
// blcoked ips
BlockedIPs []string
// blacklist size
BlackListMaxSize int
// StandardMode or CrawlMode
Mode int
// the times it tries when send fails
Try int
// the size of packet need to be dealt with
PacketJobLimit int
// the size of packet handler
PacketWorkerLimit int
// the nodes num to be fresh in a kbucket
RefreshNodeNum int
}
// NewStandardConfig returns a Config pointer with default values.
func NewStandardConfig() *Config {
return &Config{
K: 8,
KBucketSize: 8,
Network: "udp4",
Address: ":6881",
PrimeNodes: []string{
"router.bittorrent.com:6881",
"router.utorrent.com:6881",
"dht.transmissionbt.com:6881",
},
NodeExpriedAfter: time.Duration(time.Minute * 15),
KBucketExpiredAfter: time.Duration(time.Minute * 15),
CheckKBucketPeriod: time.Duration(time.Second * 30),
TokenExpiredAfter: time.Duration(time.Minute * 10),
MaxTransactionCursor: math.MaxUint32,
MaxNodes: 5000,
BlockedIPs: make([]string, 0),
BlackListMaxSize: 65536,
Try: 2,
Mode: StandardMode,
PacketJobLimit: 1024,
PacketWorkerLimit: 256,
RefreshNodeNum: 8,
}
}
// NewCrawlConfig returns a config in crawling mode.
func NewCrawlConfig() *Config {
config := NewStandardConfig()
config.NodeExpriedAfter = 0
config.KBucketExpiredAfter = 0
config.CheckKBucketPeriod = time.Second * 5
config.KBucketSize = math.MaxInt32
config.Mode = CrawlMode
config.RefreshNodeNum = 256
return config
}
// DHT represents a DHT node.
type DHT struct {
*Config
node *node
conn *net.UDPConn
routingTable *routingTable
transactionManager *transactionManager
peersManager *peersManager
tokenManager *tokenManager
blackList *blackList
Ready bool
packets chan packet
workerTokens chan struct{}
}
// New returns a DHT pointer. If config is nil, then config will be set to
// the default config.
func New(config *Config) *DHT {
if config == nil {
config = NewStandardConfig()
}
node, err := newNode(randomString(20), config.Network, config.Address)
if err != nil {
panic(err)
}
d := &DHT{
Config: config,
node: node,
blackList: newBlackList(config.BlackListMaxSize),
packets: make(chan packet, config.PacketJobLimit),
workerTokens: make(chan struct{}, config.PacketWorkerLimit),
}
for _, ip := range config.BlockedIPs {
d.blackList.insert(ip, -1)
}
go func() {
for _, ip := range getLocalIPs() {
d.blackList.insert(ip, -1)
}
ip, err := getRemoteIP()
if err != nil {
d.blackList.insert(ip, -1)
}
}()
return d
}
// IsStandardMode returns whether mode is StandardMode.
func (dht *DHT) IsStandardMode() bool {
return dht.Mode == StandardMode
}
// IsCrawlMode returns whether mode is CrawlMode.
func (dht *DHT) IsCrawlMode() bool {
return dht.Mode == CrawlMode
}
// init initializes global varables.
func (dht *DHT) init() {
listener, err := net.ListenPacket(dht.Network, dht.Address)
if err != nil {
panic(err)
}
dht.conn = listener.(*net.UDPConn)
dht.routingTable = newRoutingTable(dht.KBucketSize, dht)
dht.peersManager = newPeersManager(dht)
dht.tokenManager = newTokenManager(dht.TokenExpiredAfter, dht)
dht.transactionManager = newTransactionManager(
dht.MaxTransactionCursor, dht)
go dht.transactionManager.run()
go dht.tokenManager.clear()
go dht.blackList.clear()
}
// join makes current node join the dht network.
func (dht *DHT) join() {
for _, addr := range dht.PrimeNodes {
raddr, err := net.ResolveUDPAddr(dht.Network, addr)
if err != nil {
continue
}
// NOTE: Temporary node has NOT node id.
dht.transactionManager.findNode(
&node{addr: raddr},
dht.node.id.RawString(),
)
}
}
// listen receives message from udp.
func (dht *DHT) listen() {
go func() {
buff := make([]byte, 8192)
for {
n, raddr, err := dht.conn.ReadFromUDP(buff)
if err != nil {
continue
}
dht.packets <- packet{buff[:n], raddr}
}
}()
}
// id returns a id near to target if target is not null, otherwise it returns
// the dht's node id.
func (dht *DHT) id(target string) string {
if dht.IsStandardMode() || target == "" {
return dht.node.id.RawString()
}
return target[:15] + dht.node.id.RawString()[15:]
}
// GetPeers returns peers who have announced having infoHash.
func (dht *DHT) GetPeers(infoHash string) ([]*Peer, error) {
if !dht.Ready {
return nil, errors.New("dht not ready")
}
if len(infoHash) == 40 {
data, err := hex.DecodeString(infoHash)
if err != nil {
return nil, err
}
infoHash = string(data)
}
peers := dht.peersManager.GetPeers(infoHash, dht.K)
if len(peers) != 0 {
return peers, nil
}
ch := make(chan struct{})
go func() {
neighbors := dht.routingTable.GetNeighbors(
newBitmapFromString(infoHash), dht.K)
for _, no := range neighbors {
dht.transactionManager.getPeers(no, infoHash)
}
i := 0
for _ = range time.Tick(time.Second * 1) {
i++
peers = dht.peersManager.GetPeers(infoHash, dht.K)
if len(peers) != 0 || i == 30 {
break
}
}
ch <- struct{}{}
}()
<-ch
return peers, nil
}
// Run starts the dht.
func (dht *DHT) Run() {
dht.init()
dht.listen()
dht.join()
dht.Ready = true
var pkt packet
tick := time.Tick(dht.CheckKBucketPeriod)
for {
select {
case pkt = <-dht.packets:
handle(dht, pkt)
case <-tick:
if dht.routingTable.Len() == 0 {
dht.join()
} else if dht.transactionManager.len() == 0 {
go dht.routingTable.Fresh()
}
}
}
}

BIN
dht/doc/screen-shot.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 695 KiB

782
dht/krpc.go Normal file
View file

@ -0,0 +1,782 @@
package dht
import (
"errors"
"net"
"strings"
"sync"
"time"
)
const (
pingType = "ping"
findNodeType = "find_node"
getPeersType = "get_peers"
announcePeerType = "announce_peer"
)
const (
generalError = 201 + iota
serverError
protocolError
unknownError
)
// packet represents the information receive from udp.
type packet struct {
data []byte
raddr *net.UDPAddr
}
// token represents the token when response getPeers request.
type token struct {
data string
createTime time.Time
}
// tokenManager managers the tokens.
type tokenManager struct {
*syncedMap
expiredAfter time.Duration
dht *DHT
}
// newTokenManager returns a new tokenManager.
func newTokenManager(expiredAfter time.Duration, dht *DHT) *tokenManager {
return &tokenManager{
syncedMap: newSyncedMap(),
expiredAfter: expiredAfter,
dht: dht,
}
}
// token returns a token. If it doesn't exist or is expired, it will add a
// new token.
func (tm *tokenManager) token(addr *net.UDPAddr) string {
v, ok := tm.Get(addr.IP.String())
tk, _ := v.(token)
if !ok || time.Now().Sub(tk.createTime) > tm.expiredAfter {
tk = token{
data: randomString(5),
createTime: time.Now(),
}
tm.Set(addr.IP.String(), tk)
}
return tk.data
}
// clear removes expired tokens.
func (tm *tokenManager) clear() {
for _ = range time.Tick(time.Minute * 3) {
keys := make([]interface{}, 0, 100)
for item := range tm.Iter() {
if time.Now().Sub(item.val.(token).createTime) > tm.expiredAfter {
keys = append(keys, item.key)
}
}
tm.DeleteMulti(keys)
}
}
// check returns whether the token is valid.
func (tm *tokenManager) check(addr *net.UDPAddr, tokenString string) bool {
key := addr.IP.String()
v, ok := tm.Get(key)
tk, _ := v.(token)
if ok {
tm.Delete(key)
}
return ok && tokenString == tk.data
}
// makeQuery returns a query-formed data.
func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} {
return map[string]interface{}{
"t": t,
"y": "q",
"q": q,
"a": a,
}
}
// makeResponse returns a response-formed data.
func makeResponse(t string, r map[string]interface{}) map[string]interface{} {
return map[string]interface{}{
"t": t,
"y": "r",
"r": r,
}
}
// makeError returns a err-formed data.
func makeError(t string, errCode int, errMsg string) map[string]interface{} {
return map[string]interface{}{
"t": t,
"y": "e",
"e": []interface{}{errCode, errMsg},
}
}
// send sends data to the udp.
func send(dht *DHT, addr *net.UDPAddr, data map[string]interface{}) error {
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
_, err := dht.conn.WriteToUDP([]byte(Encode(data)), addr)
if err != nil {
dht.blackList.insert(addr.IP.String(), -1)
}
return err
}
// query represents the query data included queried node and query-formed data.
type query struct {
node *node
data map[string]interface{}
}
// transaction implements transaction.
type transaction struct {
*query
id string
response chan struct{}
}
// transactionManager represents the manager of transactions.
type transactionManager struct {
*sync.RWMutex
transactions *syncedMap
index *syncedMap
cursor uint64
maxCursor uint64
queryChan chan *query
dht *DHT
}
// newTransactionManager returns new transactionManager pointer.
func newTransactionManager(maxCursor uint64, dht *DHT) *transactionManager {
return &transactionManager{
RWMutex: &sync.RWMutex{},
transactions: newSyncedMap(),
index: newSyncedMap(),
maxCursor: maxCursor,
queryChan: make(chan *query, 1024),
dht: dht,
}
}
// genTransID generates a transaction id and returns it.
func (tm *transactionManager) genTransID() string {
tm.Lock()
defer tm.Unlock()
tm.cursor = (tm.cursor + 1) % tm.maxCursor
return string(int2bytes(tm.cursor))
}
// newTransaction creates a new transaction.
func (tm *transactionManager) newTransaction(id string, q *query) *transaction {
return &transaction{
id: id,
query: q,
response: make(chan struct{}, tm.dht.Try+1),
}
}
// genIndexKey generates an indexed key which consists of queryType and
// address.
func (tm *transactionManager) genIndexKey(queryType, address string) string {
return strings.Join([]string{queryType, address}, ":")
}
// genIndexKeyByTrans generates an indexed key by a transaction.
func (tm *transactionManager) genIndexKeyByTrans(trans *transaction) string {
return tm.genIndexKey(trans.data["q"].(string), trans.node.addr.String())
}
// insert adds a transaction to transactionManager.
func (tm *transactionManager) insert(trans *transaction) {
tm.Lock()
defer tm.Unlock()
tm.transactions.Set(trans.id, trans)
tm.index.Set(tm.genIndexKeyByTrans(trans), trans)
}
// delete removes a transaction from transactionManager.
func (tm *transactionManager) delete(transID string) {
v, ok := tm.transactions.Get(transID)
if !ok {
return
}
tm.Lock()
defer tm.Unlock()
trans := v.(*transaction)
tm.transactions.Delete(trans.id)
tm.index.Delete(tm.genIndexKeyByTrans(trans))
}
// len returns how many transactions are requesting now.
func (tm *transactionManager) len() int {
return tm.transactions.Len()
}
// transaction returns a transaction. keyType should be one of 0, 1 which
// represents transId and index each.
func (tm *transactionManager) transaction(
key string, keyType int) *transaction {
sm := tm.transactions
if keyType == 1 {
sm = tm.index
}
v, ok := sm.Get(key)
if !ok {
return nil
}
return v.(*transaction)
}
// getByTransID returns a transaction by transID.
func (tm *transactionManager) getByTransID(transID string) *transaction {
return tm.transaction(transID, 0)
}
// getByIndex returns a transaction by indexed key.
func (tm *transactionManager) getByIndex(index string) *transaction {
return tm.transaction(index, 1)
}
// transaction gets the proper transaction with whose id is transId and
// address is addr.
func (tm *transactionManager) filterOne(
transID string, addr *net.UDPAddr) *transaction {
trans := tm.getByTransID(transID)
if trans == nil || trans.node.addr.String() != addr.String() {
return nil
}
return trans
}
// query sends the query-formed data to udp and wait for the response.
// When timeout, it will retry `try - 1` times, which means it will query
// `try` times totally.
func (tm *transactionManager) query(q *query, try int) {
transID := q.data["t"].(string)
trans := tm.newTransaction(transID, q)
tm.insert(trans)
defer tm.delete(trans.id)
success := false
for i := 0; i < try; i++ {
if err := send(tm.dht, q.node.addr, q.data); err != nil {
break
}
select {
case <-trans.response:
success = true
break
case <-time.After(time.Second * 15):
}
}
if !success && q.node.id != nil {
tm.dht.blackList.insert(q.node.addr.IP.String(), q.node.addr.Port)
tm.dht.routingTable.RemoveByAddr(q.node.addr.String())
}
}
// run starts to listen and consume the query chan.
func (tm *transactionManager) run() {
var q *query
for {
select {
case q = <-tm.queryChan:
go tm.query(q, tm.dht.Try)
}
}
}
// sendQuery send query-formed data to the chan.
func (tm *transactionManager) sendQuery(
no *node, queryType string, a map[string]interface{}) {
// If the target is self, then stop.
if no.id != nil && no.id.RawString() == tm.dht.node.id.RawString() ||
tm.getByIndex(tm.genIndexKey(queryType, no.addr.String())) != nil ||
tm.dht.blackList.in(no.addr.IP.String(), no.addr.Port) {
return
}
data := makeQuery(tm.genTransID(), queryType, a)
tm.queryChan <- &query{
node: no,
data: data,
}
}
// ping sends ping query to the chan.
func (tm *transactionManager) ping(no *node) {
tm.sendQuery(no, pingType, map[string]interface{}{
"id": tm.dht.id(no.id.RawString()),
})
}
// findNode sends find_node query to the chan.
func (tm *transactionManager) findNode(no *node, target string) {
tm.sendQuery(no, findNodeType, map[string]interface{}{
"id": tm.dht.id(target),
"target": target,
})
}
// getPeers sends get_peers query to the chan.
func (tm *transactionManager) getPeers(no *node, infoHash string) {
tm.sendQuery(no, getPeersType, map[string]interface{}{
"id": tm.dht.id(infoHash),
"info_hash": infoHash,
})
}
// announcePeer sends announce_peer query to the chan.
func (tm *transactionManager) announcePeer(
no *node, infoHash string, impliedPort, port int, token string) {
tm.sendQuery(no, announcePeerType, map[string]interface{}{
"id": tm.dht.id(no.id.RawString()),
"info_hash": infoHash,
"implied_port": impliedPort,
"port": port,
"token": token,
})
}
// parseKey parses the key in dict data. `t` is type of the keyed value.
// It's one of "int", "string", "map", "list".
func parseKey(data map[string]interface{}, key string, t string) error {
val, ok := data[key]
if !ok {
return errors.New("lack of key")
}
switch t {
case "string":
_, ok = val.(string)
case "int":
_, ok = val.(int)
case "map":
_, ok = val.(map[string]interface{})
case "list":
_, ok = val.([]interface{})
default:
panic("invalid type")
}
if !ok {
return errors.New("invalid key type")
}
return nil
}
// parseKeys parses keys. It just wraps parseKey.
func parseKeys(data map[string]interface{}, pairs [][]string) error {
for _, args := range pairs {
key, t := args[0], args[1]
if err := parseKey(data, key, t); err != nil {
return err
}
}
return nil
}
// parseMessage parses the basic data received from udp.
// It returns a map value.
func parseMessage(data interface{}) (map[string]interface{}, error) {
response, ok := data.(map[string]interface{})
if !ok {
return nil, errors.New("response is not dict")
}
if err := parseKeys(
response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil {
return nil, err
}
return response, nil
}
// handleRequest handles the requests received from udp.
func handleRequest(dht *DHT, addr *net.UDPAddr,
response map[string]interface{}) (success bool) {
t := response["t"].(string)
if err := parseKeys(
response, [][]string{{"q", "string"}, {"a", "map"}}); err != nil {
send(dht, addr, makeError(t, protocolError, err.Error()))
return
}
q := response["q"].(string)
a := response["a"].(map[string]interface{})
if err := parseKey(a, "id", "string"); err != nil {
send(dht, addr, makeError(t, protocolError, err.Error()))
return
}
id := a["id"].(string)
if id == dht.node.id.RawString() {
return
}
if len(id) != 20 {
send(dht, addr, makeError(t, protocolError, "invalid id"))
return
}
if no, ok := dht.routingTable.GetNodeByAddress(addr.String()); ok &&
no.id.RawString() != id {
dht.blackList.insert(addr.IP.String(), addr.Port)
dht.routingTable.RemoveByAddr(addr.String())
send(dht, addr, makeError(t, protocolError, "invalid id"))
return
}
switch q {
case pingType:
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(id),
}))
case findNodeType:
if dht.IsStandardMode() {
if err := parseKey(a, "target", "string"); err != nil {
send(dht, addr, makeError(t, protocolError, err.Error()))
return
}
target := a["target"].(string)
if len(target) != 20 {
send(dht, addr, makeError(t, protocolError, "invalid target"))
return
}
var nodes string
targetID := newBitmapFromString(target)
no, _ := dht.routingTable.GetNodeKBucktByID(targetID)
if no != nil {
nodes = no.CompactNodeInfo()
} else {
nodes = strings.Join(
dht.routingTable.GetNeighborCompactInfos(targetID, dht.K),
"",
)
}
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(target),
"nodes": nodes,
}))
}
case getPeersType:
if err := parseKey(a, "info_hash", "string"); err != nil {
send(dht, addr, makeError(t, protocolError, err.Error()))
return
}
infoHash := a["info_hash"].(string)
if len(infoHash) != 20 {
send(dht, addr, makeError(t, protocolError, "invalid info_hash"))
return
}
if dht.IsCrawlMode() {
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(infoHash),
"token": dht.tokenManager.token(addr),
"nodes": "",
}))
} else if peers := dht.peersManager.GetPeers(
infoHash, dht.K); len(peers) > 0 {
values := make([]interface{}, len(peers))
for i, p := range peers {
values[i] = p.CompactIPPortInfo()
}
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(infoHash),
"values": values,
"token": dht.tokenManager.token(addr),
}))
} else {
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(infoHash),
"token": dht.tokenManager.token(addr),
"nodes": strings.Join(dht.routingTable.GetNeighborCompactInfos(
newBitmapFromString(infoHash), dht.K), ""),
}))
}
if dht.OnGetPeers != nil {
dht.OnGetPeers(infoHash, addr.IP.String(), addr.Port)
}
case announcePeerType:
if err := parseKeys(a, [][]string{
{"info_hash", "string"},
{"port", "int"},
{"token", "string"}}); err != nil {
send(dht, addr, makeError(t, protocolError, err.Error()))
return
}
infoHash := a["info_hash"].(string)
port := a["port"].(int)
token := a["token"].(string)
if !dht.tokenManager.check(addr, token) {
// send(dht, addr, makeError(t, protocolError, "invalid token"))
return
}
if impliedPort, ok := a["implied_port"]; ok &&
impliedPort.(int) != 0 {
port = addr.Port
}
if dht.IsStandardMode() {
dht.peersManager.Insert(infoHash, newPeer(addr.IP, port, token))
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(id),
}))
}
if dht.OnAnnouncePeer != nil {
dht.OnAnnouncePeer(infoHash, addr.IP.String(), port)
}
default:
// send(dht, addr, makeError(t, protocolError, "invalid q"))
return
}
no, _ := newNode(id, addr.Network(), addr.String())
dht.routingTable.Insert(no)
return true
}
// findOn puts nodes in the response to the routingTable, then if target is in
// the nodes or all nodes are in the routingTable, it stops. Otherwise it
// continues to findNode or getPeers.
func findOn(dht *DHT, r map[string]interface{}, target *bitmap,
queryType string) error {
if err := parseKey(r, "nodes", "string"); err != nil {
return err
}
nodes := r["nodes"].(string)
if len(nodes)%26 != 0 {
return errors.New("the length of nodes should can be divided by 26")
}
hasNew, found := false, false
for i := 0; i < len(nodes)/26; i++ {
no, _ := newNodeFromCompactInfo(
string(nodes[i*26:(i+1)*26]), dht.Network)
if no.id.RawString() == target.RawString() {
found = true
}
if dht.routingTable.Insert(no) {
hasNew = true
}
}
if found || !hasNew {
return nil
}
targetID := target.RawString()
for _, no := range dht.routingTable.GetNeighbors(target, dht.K) {
switch queryType {
case findNodeType:
dht.transactionManager.findNode(no, targetID)
case getPeersType:
dht.transactionManager.getPeers(no, targetID)
default:
panic("invalid find type")
}
}
return nil
}
// handleResponse handles responses received from udp.
func handleResponse(dht *DHT, addr *net.UDPAddr,
response map[string]interface{}) (success bool) {
t := response["t"].(string)
trans := dht.transactionManager.filterOne(t, addr)
if trans == nil {
return
}
// inform transManager to delete the transaction.
if err := parseKey(response, "r", "map"); err != nil {
return
}
q := trans.data["q"].(string)
a := trans.data["a"].(map[string]interface{})
r := response["r"].(map[string]interface{})
if err := parseKey(r, "id", "string"); err != nil {
return
}
id := r["id"].(string)
// If response's node id is not the same with the node id in the
// transaction, raise error.
if trans.node.id != nil && trans.node.id.RawString() != r["id"].(string) {
dht.blackList.insert(addr.IP.String(), addr.Port)
dht.routingTable.RemoveByAddr(addr.String())
return
}
node, err := newNode(id, addr.Network(), addr.String())
if err != nil {
return
}
switch q {
case pingType:
case findNodeType:
if trans.data["q"].(string) != findNodeType {
return
}
target := trans.data["a"].(map[string]interface{})["target"].(string)
if findOn(dht, r, newBitmapFromString(target), findNodeType) != nil {
return
}
case getPeersType:
if err := parseKey(r, "token", "string"); err != nil {
return
}
token := r["token"].(string)
infoHash := a["info_hash"].(string)
if err := parseKey(r, "values", "list"); err == nil {
values := r["values"].([]interface{})
for _, v := range values {
p, err := newPeerFromCompactIPPortInfo(v.(string), token)
if err != nil {
continue
}
dht.peersManager.Insert(infoHash, p)
}
} else if findOn(
dht, r, newBitmapFromString(infoHash), getPeersType) != nil {
return
}
case announcePeerType:
default:
return
}
// inform transManager to delete transaction.
trans.response <- struct{}{}
dht.blackList.delete(addr.IP.String(), addr.Port)
dht.routingTable.Insert(node)
return true
}
// handleError handles errors received from udp.
func handleError(dht *DHT, addr *net.UDPAddr,
response map[string]interface{}) (success bool) {
if err := parseKey(response, "e", "list"); err != nil {
return
}
if e := response["e"].([]interface{}); len(e) != 2 {
return
}
if trans := dht.transactionManager.filterOne(
response["t"].(string), addr); trans != nil {
trans.response <- struct{}{}
}
return true
}
var handlers = map[string]func(*DHT, *net.UDPAddr, map[string]interface{}) bool{
"q": handleRequest,
"r": handleResponse,
"e": handleError,
}
// handle handles packets received from udp.
func handle(dht *DHT, pkt packet) {
if len(dht.workerTokens) == dht.PacketWorkerLimit {
return
}
dht.workerTokens <- struct{}{}
go func() {
defer func() {
<-dht.workerTokens
}()
if dht.blackList.in(pkt.raddr.IP.String(), pkt.raddr.Port) {
return
}
data, err := Decode(pkt.data)
if err != nil {
return
}
response, err := parseMessage(data)
if err != nil {
return
}
if f, ok := handlers[response["y"].(string)]; ok {
f(dht, pkt.raddr, response)
}
}()
}

385
dht/peerwire.go Normal file
View file

@ -0,0 +1,385 @@
package dht
import (
"bytes"
"crypto/sha1"
"encoding/binary"
"errors"
"io"
"io/ioutil"
"net"
"strings"
"time"
)
const (
// REQUEST represents request message type
REQUEST = iota
// DATA represents data message type
DATA
// REJECT represents reject message type
REJECT
)
const (
// BLOCK is 2 ^ 14
BLOCK = 16384
// MaxMetadataSize represents the max medata it can accept
MaxMetadataSize = BLOCK * 1000
// EXTENDED represents it is a extended message
EXTENDED = 20
// HANDSHAKE represents handshake bit
HANDSHAKE = 0
)
var handshakePrefix = []byte{
19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114,
111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1,
}
// read reads size-length bytes from conn to data.
func read(conn *net.TCPConn, size int, data *bytes.Buffer) error {
conn.SetReadDeadline(time.Now().Add(time.Second * 15))
n, err := io.CopyN(data, conn, int64(size))
if err != nil || n != int64(size) {
return errors.New("read error")
}
return nil
}
// readMessage gets a message from the tcp connection.
func readMessage(conn *net.TCPConn, data *bytes.Buffer) (
length int, err error) {
if err = read(conn, 4, data); err != nil {
return
}
length = int(bytes2int(data.Next(4)))
if length == 0 {
return
}
if err = read(conn, length, data); err != nil {
return
}
return
}
// sendMessage sends data to the connection.
func sendMessage(conn *net.TCPConn, data []byte) error {
length := int32(len(data))
buffer := bytes.NewBuffer(nil)
binary.Write(buffer, binary.BigEndian, length)
conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
_, err := conn.Write(append(buffer.Bytes(), data...))
return err
}
// sendHandshake sends handshake message to conn.
func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error {
data := make([]byte, 68)
copy(data[:28], handshakePrefix)
copy(data[28:48], infoHash)
copy(data[48:], peerID)
conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
_, err := conn.Write(data)
return err
}
// onHandshake handles the handshake response.
func onHandshake(data []byte) (err error) {
if !(bytes.Equal(handshakePrefix[:20], data[:20]) && data[25]&0x10 != 0) {
err = errors.New("invalid handshake response")
}
return
}
// sendExtHandshake requests for the ut_metadata and metadata_size.
func sendExtHandshake(conn *net.TCPConn) error {
data := append(
[]byte{EXTENDED, HANDSHAKE},
Encode(map[string]interface{}{
"m": map[string]interface{}{"ut_metadata": 1},
})...,
)
return sendMessage(conn, data)
}
// getUTMetaSize returns the ut_metadata and metadata_size.
func getUTMetaSize(data []byte) (
utMetadata int, metadataSize int, err error) {
v, err := Decode(data)
if err != nil {
return
}
dict, ok := v.(map[string]interface{})
if !ok {
err = errors.New("invalid dict")
return
}
if err = parseKeys(
dict, [][]string{{"metadata_size", "int"}, {"m", "map"}}); err != nil {
return
}
m := dict["m"].(map[string]interface{})
if err = parseKey(m, "ut_metadata", "int"); err != nil {
return
}
utMetadata = m["ut_metadata"].(int)
metadataSize = dict["metadata_size"].(int)
if metadataSize > MaxMetadataSize {
err = errors.New("metadata_size too long")
}
return
}
// Request represents the request context.
type Request struct {
InfoHash []byte
IP string
Port int
}
// Response contains the request context and the metadata info.
type Response struct {
Request
MetadataInfo []byte
}
// Wire represents the wire protocol.
type Wire struct {
blackList *blackList
queue *syncedMap
requests chan Request
responses chan Response
workerTokens chan struct{}
}
// NewWire returns a Wire pointer.
// - blackListSize: the blacklist size
// - requestQueueSize: the max requests it can buffers
// - workerQueueSize: the max goroutine downloading workers
func NewWire(blackListSize, requestQueueSize, workerQueueSize int) *Wire {
return &Wire{
blackList: newBlackList(blackListSize),
queue: newSyncedMap(),
requests: make(chan Request, requestQueueSize),
responses: make(chan Response, 1024),
workerTokens: make(chan struct{}, workerQueueSize),
}
}
// Request pushes the request to the queue.
func (wire *Wire) Request(infoHash []byte, ip string, port int) {
wire.requests <- Request{InfoHash: infoHash, IP: ip, Port: port}
}
// Response returns a chan of Response.
func (wire *Wire) Response() <-chan Response {
return wire.responses
}
// isDone returns whether the wire get all pieces of the metadata info.
func (wire *Wire) isDone(pieces [][]byte) bool {
for _, piece := range pieces {
if len(piece) == 0 {
return false
}
}
return true
}
func (wire *Wire) requestPieces(
conn *net.TCPConn, utMetadata int, metadataSize int, piecesNum int) {
buffer := make([]byte, 1024)
for i := 0; i < piecesNum; i++ {
buffer[0] = EXTENDED
buffer[1] = byte(utMetadata)
msg := Encode(map[string]interface{}{
"msg_type": REQUEST,
"piece": i,
})
length := len(msg) + 2
copy(buffer[2:length], msg)
sendMessage(conn, buffer[:length])
}
buffer = nil
}
// fetchMetadata fetchs medata info accroding to infohash from dht.
func (wire *Wire) fetchMetadata(r Request) {
var (
length int
msgType byte
piecesNum int
pieces [][]byte
utMetadata int
metadataSize int
)
defer func() {
pieces = nil
recover()
}()
infoHash := r.InfoHash
address := genAddress(r.IP, r.Port)
dial, err := net.DialTimeout("tcp", address, time.Second*15)
if err != nil {
wire.blackList.insert(r.IP, r.Port)
return
}
conn := dial.(*net.TCPConn)
conn.SetLinger(0)
defer conn.Close()
data := bytes.NewBuffer(nil)
data.Grow(BLOCK)
if sendHandshake(conn, infoHash, []byte(randomString(20))) != nil ||
read(conn, 68, data) != nil ||
onHandshake(data.Next(68)) != nil ||
sendExtHandshake(conn) != nil {
return
}
for {
length, err = readMessage(conn, data)
if err != nil {
return
}
if length == 0 {
continue
}
msgType, err = data.ReadByte()
if err != nil {
return
}
switch msgType {
case EXTENDED:
extendedID, err := data.ReadByte()
if err != nil {
return
}
payload, err := ioutil.ReadAll(data)
if err != nil {
return
}
if extendedID == 0 {
if pieces != nil {
return
}
utMetadata, metadataSize, err = getUTMetaSize(payload)
if err != nil {
return
}
piecesNum = metadataSize / BLOCK
if metadataSize%BLOCK != 0 {
piecesNum++
}
pieces = make([][]byte, piecesNum)
go wire.requestPieces(conn, utMetadata, metadataSize, piecesNum)
continue
}
if pieces == nil {
return
}
d, index, err := DecodeDict(payload, 0)
if err != nil {
return
}
dict := d.(map[string]interface{})
if err = parseKeys(dict, [][]string{
{"msg_type", "int"},
{"piece", "int"}}); err != nil {
return
}
if dict["msg_type"].(int) != DATA {
continue
}
piece := dict["piece"].(int)
pieceLen := length - 2 - index
if (piece != piecesNum-1 && pieceLen != BLOCK) ||
(piece == piecesNum-1 && pieceLen != metadataSize%BLOCK) {
return
}
pieces[piece] = payload[index:]
if wire.isDone(pieces) {
metadataInfo := bytes.Join(pieces, nil)
info := sha1.Sum(metadataInfo)
if !bytes.Equal(infoHash, info[:]) {
return
}
wire.responses <- Response{
Request: r,
MetadataInfo: metadataInfo,
}
return
}
default:
data.Reset()
}
}
}
// Run starts the peer wire protocol.
func (wire *Wire) Run() {
go wire.blackList.clear()
for r := range wire.requests {
wire.workerTokens <- struct{}{}
go func(r Request) {
defer func() {
<-wire.workerTokens
}()
key := strings.Join([]string{
string(r.InfoHash), genAddress(r.IP, r.Port),
}, ":")
if len(r.InfoHash) != 20 || wire.blackList.in(r.IP, r.Port) ||
wire.queue.Has(key) {
return
}
wire.fetchMetadata(r)
}(r)
}
}

596
dht/routingtable.go Normal file
View file

@ -0,0 +1,596 @@
package dht
import (
"container/heap"
"errors"
"net"
"strings"
"sync"
"time"
)
// maxPrefixLength is the length of DHT node.
const maxPrefixLength = 160
// node represents a DHT node.
type node struct {
id *bitmap
addr *net.UDPAddr
lastActiveTime time.Time
}
// newNode returns a node pointer.
func newNode(id, network, address string) (*node, error) {
if len(id) != 20 {
return nil, errors.New("node id should be a 20-length string")
}
addr, err := net.ResolveUDPAddr(network, address)
if err != nil {
return nil, err
}
return &node{newBitmapFromString(id), addr, time.Now()}, nil
}
// newNodeFromCompactInfo parses compactNodeInfo and returns a node pointer.
func newNodeFromCompactInfo(
compactNodeInfo string, network string) (*node, error) {
if len(compactNodeInfo) != 26 {
return nil, errors.New("compactNodeInfo should be a 26-length string")
}
id := compactNodeInfo[:20]
ip, port, _ := decodeCompactIPPortInfo(compactNodeInfo[20:])
return newNode(id, network, genAddress(ip.String(), port))
}
// CompactIPPortInfo returns "Compact IP-address/port info".
// See http://www.bittorrent.org/beps/bep_0005.html.
func (node *node) CompactIPPortInfo() string {
info, _ := encodeCompactIPPortInfo(node.addr.IP, node.addr.Port)
return info
}
// CompactNodeInfo returns "Compact node info".
// See http://www.bittorrent.org/beps/bep_0005.html.
func (node *node) CompactNodeInfo() string {
return strings.Join([]string{
node.id.RawString(), node.CompactIPPortInfo(),
}, "")
}
// Peer represents a peer contact.
type Peer struct {
IP net.IP
Port int
token string
}
// newPeer returns a new peer pointer.
func newPeer(ip net.IP, port int, token string) *Peer {
return &Peer{
IP: ip,
Port: port,
token: token,
}
}
// newPeerFromCompactIPPortInfo create a peer pointer by compact ip/port info.
func newPeerFromCompactIPPortInfo(compactInfo, token string) (*Peer, error) {
ip, port, err := decodeCompactIPPortInfo(compactInfo)
if err != nil {
return nil, err
}
return newPeer(ip, port, token), nil
}
// CompactIPPortInfo returns "Compact node info".
// See http://www.bittorrent.org/beps/bep_0005.html.
func (p *Peer) CompactIPPortInfo() string {
info, _ := encodeCompactIPPortInfo(p.IP, p.Port)
return info
}
// peersManager represents a proxy that manipulates peers.
type peersManager struct {
sync.RWMutex
table *syncedMap
dht *DHT
}
// newPeersManager returns a new peersManager.
func newPeersManager(dht *DHT) *peersManager {
return &peersManager{
table: newSyncedMap(),
dht: dht,
}
}
// Insert adds a peer into peersManager.
func (pm *peersManager) Insert(infoHash string, peer *Peer) {
pm.Lock()
if _, ok := pm.table.Get(infoHash); !ok {
pm.table.Set(infoHash, newKeyedDeque())
}
pm.Unlock()
v, _ := pm.table.Get(infoHash)
queue := v.(*keyedDeque)
queue.Push(peer.CompactIPPortInfo(), peer)
if queue.Len() > pm.dht.K {
queue.Remove(queue.Front())
}
}
// GetPeers returns size-length peers who announces having infoHash.
func (pm *peersManager) GetPeers(infoHash string, size int) []*Peer {
peers := make([]*Peer, 0, size)
v, ok := pm.table.Get(infoHash)
if !ok {
return peers
}
for e := range v.(*keyedDeque).Iter() {
peers = append(peers, e.Value.(*Peer))
}
if len(peers) > size {
peers = peers[len(peers)-size:]
}
return peers
}
// kbucket represents a k-size bucket.
type kbucket struct {
sync.RWMutex
nodes, candidates *keyedDeque
lastChanged time.Time
prefix *bitmap
}
// newKBucket returns a new kbucket pointer.
func newKBucket(prefix *bitmap) *kbucket {
bucket := &kbucket{
nodes: newKeyedDeque(),
candidates: newKeyedDeque(),
lastChanged: time.Now(),
prefix: prefix,
}
return bucket
}
// LastChanged return the last time when it changes.
func (bucket *kbucket) LastChanged() time.Time {
bucket.RLock()
defer bucket.RUnlock()
return bucket.lastChanged
}
// RandomChildID returns a random id that has the same prefix with bucket.
func (bucket *kbucket) RandomChildID() string {
prefixLen := bucket.prefix.Size / 8
return strings.Join([]string{
bucket.prefix.RawString()[:prefixLen],
randomString(20 - prefixLen),
}, "")
}
// UpdateTimestamp update bucket's last changed time..
func (bucket *kbucket) UpdateTimestamp() {
bucket.Lock()
defer bucket.Unlock()
bucket.lastChanged = time.Now()
}
// Insert inserts node to the bucket. It returns whether the node is new in
// the bucket.
func (bucket *kbucket) Insert(no *node) bool {
isNew := !bucket.nodes.HasKey(no.id.RawString())
bucket.nodes.Push(no.id.RawString(), no)
bucket.UpdateTimestamp()
return isNew
}
// Replace removes node, then put bucket.candidates.Back() to the right
// place of bucket.nodes.
func (bucket *kbucket) Replace(no *node) {
bucket.nodes.Delete(no.id.RawString())
bucket.UpdateTimestamp()
if bucket.candidates.Len() == 0 {
return
}
no = bucket.candidates.Remove(bucket.candidates.Back()).(*node)
inserted := false
for e := range bucket.nodes.Iter() {
if e.Value.(*node).lastActiveTime.After(
no.lastActiveTime) && !inserted {
bucket.nodes.InsertBefore(no, e)
inserted = true
}
}
if !inserted {
bucket.nodes.PushBack(no)
}
}
// Fresh pings the expired nodes in the bucket.
func (bucket *kbucket) Fresh(dht *DHT) {
for e := range bucket.nodes.Iter() {
no := e.Value.(*node)
if time.Since(no.lastActiveTime) > dht.NodeExpriedAfter {
dht.transactionManager.ping(no)
}
}
}
// routingTableNode represents routing table tree node.
type routingTableNode struct {
sync.RWMutex
children []*routingTableNode
bucket *kbucket
}
// newRoutingTableNode returns a new routingTableNode pointer.
func newRoutingTableNode(prefix *bitmap) *routingTableNode {
return &routingTableNode{
children: make([]*routingTableNode, 2),
bucket: newKBucket(prefix),
}
}
// Child returns routingTableNode's left or right child.
func (tableNode *routingTableNode) Child(index int) *routingTableNode {
if index >= 2 {
return nil
}
tableNode.RLock()
defer tableNode.RUnlock()
return tableNode.children[index]
}
// SetChild sets routingTableNode's left or right child. When index is 0, it's
// the left child, if 1, it's the right child.
func (tableNode *routingTableNode) SetChild(index int, c *routingTableNode) {
tableNode.Lock()
defer tableNode.Unlock()
tableNode.children[index] = c
}
// KBucket returns the bucket routingTableNode holds.
func (tableNode *routingTableNode) KBucket() *kbucket {
tableNode.RLock()
defer tableNode.RUnlock()
return tableNode.bucket
}
// SetKBucket sets the bucket.
func (tableNode *routingTableNode) SetKBucket(bucket *kbucket) {
tableNode.Lock()
defer tableNode.Unlock()
tableNode.bucket = bucket
}
// Split splits current routingTableNode and sets it's two children.
func (tableNode *routingTableNode) Split() {
prefixLen := tableNode.KBucket().prefix.Size
if prefixLen == maxPrefixLength {
return
}
for i := 0; i < 2; i++ {
tableNode.SetChild(i, newRoutingTableNode(newBitmapFrom(
tableNode.KBucket().prefix, prefixLen+1)))
}
tableNode.Lock()
tableNode.children[1].bucket.prefix.Set(prefixLen)
tableNode.Unlock()
for e := range tableNode.KBucket().nodes.Iter() {
nd := e.Value.(*node)
tableNode.Child(nd.id.Bit(prefixLen)).KBucket().nodes.PushBack(nd)
}
for e := range tableNode.KBucket().candidates.Iter() {
nd := e.Value.(*node)
tableNode.Child(nd.id.Bit(prefixLen)).KBucket().candidates.PushBack(nd)
}
for i := 0; i < 2; i++ {
tableNode.Child(i).KBucket().UpdateTimestamp()
}
}
// routingTable implements the routing table in DHT protocol.
type routingTable struct {
*sync.RWMutex
k int
root *routingTableNode
cachedNodes *syncedMap
cachedKBuckets *keyedDeque
dht *DHT
clearQueue *syncedList
}
// newRoutingTable returns a new routingTable pointer.
func newRoutingTable(k int, dht *DHT) *routingTable {
root := newRoutingTableNode(newBitmap(0))
rt := &routingTable{
RWMutex: &sync.RWMutex{},
k: k,
root: root,
cachedNodes: newSyncedMap(),
cachedKBuckets: newKeyedDeque(),
dht: dht,
clearQueue: newSyncedList(),
}
rt.cachedKBuckets.Push(root.bucket.prefix.String(), root.bucket)
return rt
}
// Insert adds a node to routing table. It returns whether the node is new
// in the routingtable.
func (rt *routingTable) Insert(nd *node) bool {
rt.Lock()
defer rt.Unlock()
if rt.dht.blackList.in(nd.addr.IP.String(), nd.addr.Port) ||
rt.cachedNodes.Len() >= rt.dht.MaxNodes {
return false
}
var (
next *routingTableNode
bucket *kbucket
)
root := rt.root
for prefixLen := 1; prefixLen <= maxPrefixLength; prefixLen++ {
next = root.Child(nd.id.Bit(prefixLen - 1))
if next != nil {
// If next is not the leaf.
root = next
} else if root.KBucket().nodes.Len() < rt.k ||
root.KBucket().nodes.HasKey(nd.id.RawString()) {
bucket = root.KBucket()
isNew := bucket.Insert(nd)
rt.cachedNodes.Set(nd.addr.String(), nd)
rt.cachedKBuckets.Push(bucket.prefix.String(), bucket)
return isNew
} else if root.KBucket().prefix.Compare(nd.id, prefixLen-1) == 0 {
// If node has the same prefix with bucket, split it.
root.Split()
rt.cachedKBuckets.Delete(root.KBucket().prefix.String())
root.SetKBucket(nil)
for i := 0; i < 2; i++ {
bucket = root.Child(i).KBucket()
rt.cachedKBuckets.Push(bucket.prefix.String(), bucket)
}
root = root.Child(nd.id.Bit(prefixLen - 1))
} else {
// Finally, store node as a candidate and fresh the bucket.
root.KBucket().candidates.PushBack(nd)
if root.KBucket().candidates.Len() > rt.k {
root.KBucket().candidates.Remove(
root.KBucket().candidates.Front())
}
go root.KBucket().Fresh(rt.dht)
return false
}
}
return false
}
// GetNeighbors returns the size-length nodes closest to id.
func (rt *routingTable) GetNeighbors(id *bitmap, size int) []*node {
rt.RLock()
nodes := make([]interface{}, 0, rt.cachedNodes.Len())
for item := range rt.cachedNodes.Iter() {
nodes = append(nodes, item.val.(*node))
}
rt.RUnlock()
neighbors := getTopK(nodes, id, size)
result := make([]*node, len(neighbors))
for i, nd := range neighbors {
result[i] = nd.(*node)
}
return result
}
// GetNeighborIds return the size-length compact node info closest to id.
func (rt *routingTable) GetNeighborCompactInfos(id *bitmap, size int) []string {
neighbors := rt.GetNeighbors(id, size)
infos := make([]string, len(neighbors))
for i, no := range neighbors {
infos[i] = no.CompactNodeInfo()
}
return infos
}
// GetNodeKBucktById returns node whose id is `id` and the bucket it
// belongs to.
func (rt *routingTable) GetNodeKBucktByID(id *bitmap) (
nd *node, bucket *kbucket) {
rt.RLock()
defer rt.RUnlock()
var next *routingTableNode
root := rt.root
for prefixLen := 1; prefixLen <= maxPrefixLength; prefixLen++ {
next = root.Child(id.Bit(prefixLen - 1))
if next == nil {
v, ok := root.KBucket().nodes.Get(id.RawString())
if !ok {
return
}
nd, bucket = v.Value.(*node), root.KBucket()
return
}
root = next
}
return
}
// GetNodeByAddress finds node by address.
func (rt *routingTable) GetNodeByAddress(address string) (no *node, ok bool) {
rt.RLock()
defer rt.RUnlock()
v, ok := rt.cachedNodes.Get(address)
if ok {
no = v.(*node)
}
return
}
// Remove deletes the node whose id is `id`.
func (rt *routingTable) Remove(id *bitmap) {
if nd, bucket := rt.GetNodeKBucktByID(id); nd != nil {
bucket.Replace(nd)
rt.cachedNodes.Delete(nd.addr.String())
rt.cachedKBuckets.Push(bucket.prefix.String(), bucket)
}
}
// Remove deletes the node whose address is `ip:port`.
func (rt *routingTable) RemoveByAddr(address string) {
v, ok := rt.cachedNodes.Get(address)
if ok {
rt.Remove(v.(*node).id)
}
}
// Fresh sends findNode to all nodes in the expired nodes.
func (rt *routingTable) Fresh() {
now := time.Now()
for e := range rt.cachedKBuckets.Iter() {
bucket := e.Value.(*kbucket)
if now.Sub(bucket.LastChanged()) < rt.dht.KBucketExpiredAfter ||
bucket.nodes.Len() == 0 {
continue
}
i := 0
for e := range bucket.nodes.Iter() {
if i < rt.dht.RefreshNodeNum {
no := e.Value.(*node)
rt.dht.transactionManager.findNode(no, bucket.RandomChildID())
rt.clearQueue.PushBack(no)
}
i++
}
}
if rt.dht.IsCrawlMode() {
for e := range rt.clearQueue.Iter() {
rt.Remove(e.Value.(*node).id)
}
}
rt.clearQueue.Clear()
}
// Len returns the number of nodes in table.
func (rt *routingTable) Len() int {
rt.RLock()
defer rt.RUnlock()
return rt.cachedNodes.Len()
}
// Implementation of heap with heap.Interface.
type heapItem struct {
distance *bitmap
value interface{}
}
type topKHeap []*heapItem
func (kHeap topKHeap) Len() int {
return len(kHeap)
}
func (kHeap topKHeap) Less(i, j int) bool {
return kHeap[i].distance.Compare(kHeap[j].distance, maxPrefixLength) == 1
}
func (kHeap topKHeap) Swap(i, j int) {
kHeap[i], kHeap[j] = kHeap[j], kHeap[i]
}
func (kHeap *topKHeap) Push(x interface{}) {
*kHeap = append(*kHeap, x.(*heapItem))
}
func (kHeap *topKHeap) Pop() interface{} {
n := len(*kHeap)
x := (*kHeap)[n-1]
*kHeap = (*kHeap)[:n-1]
return x
}
// getTopK solves the top-k problem with heap. It's time complexity is
// O(n*log(k)). When n is large, time complexity will be too high, need to be
// optimized.
func getTopK(queue []interface{}, id *bitmap, k int) []interface{} {
topkHeap := make(topKHeap, 0, k+1)
for _, value := range queue {
node := value.(*node)
item := &heapItem{
id.Xor(node.id),
value,
}
heap.Push(&topkHeap, item)
if topkHeap.Len() > k {
heap.Pop(&topkHeap)
}
}
tops := make([]interface{}, topkHeap.Len())
for i := len(tops) - 1; i >= 0; i-- {
tops[i] = heap.Pop(&topkHeap).(*heapItem).value
}
return tops
}

View file

@ -0,0 +1,23 @@
package main
import (
"fmt"
"github.com/shiyanhui/dht"
"time"
)
func main() {
d := dht.New(nil)
go d.Run()
for {
// ubuntu-14.04.2-desktop-amd64.iso
peers, err := d.GetPeers("546cf15f724d19c4319cc17b179d7e035f89c1f4")
if err != nil {
time.Sleep(time.Second * 1)
continue
}
fmt.Println("Found peers:", peers)
}
}

View file

@ -0,0 +1,77 @@
package main
import (
"encoding/hex"
"encoding/json"
"fmt"
"github.com/shiyanhui/dht"
"net/http"
_ "net/http/pprof"
)
type file struct {
Path []interface{} `json:"path"`
Length int `json:"length"`
}
type bitTorrent struct {
InfoHash string `json:"infohash"`
Name string `json:"name"`
Files []file `json:"files,omitempty"`
Length int `json:"length,omitempty"`
}
func main() {
go func() {
http.ListenAndServe(":6060", nil)
}()
w := dht.NewWire(65536, 1024, 256)
go func() {
for resp := range w.Response() {
metadata, err := dht.Decode(resp.MetadataInfo)
if err != nil {
continue
}
info := metadata.(map[string]interface{})
if _, ok := info["name"]; !ok {
continue
}
bt := bitTorrent{
InfoHash: hex.EncodeToString(resp.InfoHash),
Name: info["name"].(string),
}
if v, ok := info["files"]; ok {
files := v.([]interface{})
bt.Files = make([]file, len(files))
for i, item := range files {
f := item.(map[string]interface{})
bt.Files[i] = file{
Path: f["path"].([]interface{}),
Length: f["length"].(int),
}
}
} else if _, ok := info["length"]; ok {
bt.Length = info["length"].(int)
}
data, err := json.Marshal(bt)
if err == nil {
fmt.Printf("%s\n\n", data)
}
}
}()
go w.Run()
config := dht.NewCrawlConfig()
config.OnAnnouncePeer = func(infoHash, ip string, port int) {
w.Request([]byte(infoHash), ip, port)
}
d := dht.New(config)
d.Run()
}

134
dht/util.go Normal file
View file

@ -0,0 +1,134 @@
package dht
import (
"crypto/rand"
"errors"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
"time"
)
// randomString generates a size-length string randomly.
func randomString(size int) string {
buff := make([]byte, size)
rand.Read(buff)
return string(buff)
}
// bytes2int returns the int value it represents.
func bytes2int(data []byte) uint64 {
n, val := len(data), uint64(0)
if n > 8 {
panic("data too long")
}
for i, b := range data {
val += uint64(b) << uint64((n-i-1)*8)
}
return val
}
// int2bytes returns the byte array it represents.
func int2bytes(val uint64) []byte {
data, j := make([]byte, 8), -1
for i := 0; i < 8; i++ {
shift := uint64((7 - i) * 8)
data[i] = byte((val & (0xff << shift)) >> shift)
if j == -1 && data[i] != 0 {
j = i
}
}
if j != -1 {
return data[j:]
}
return data[:1]
}
// decodeCompactIPPortInfo decodes compactIP-address/port info in BitTorrent
// DHT Protocol. It returns the ip and port number.
func decodeCompactIPPortInfo(info string) (ip net.IP, port int, err error) {
if len(info) != 6 {
err = errors.New("compact info should be 6-length long")
return
}
ip = net.IPv4(info[0], info[1], info[2], info[3])
port = int((uint16(info[4]) << 8) | uint16(info[5]))
return
}
// encodeCompactIPPortInfo encodes an ip and a port number to
// compactIP-address/port info.
func encodeCompactIPPortInfo(ip net.IP, port int) (info string, err error) {
if port > 65535 || port < 0 {
err = errors.New(
"port should be no greater than 65535 and no less than 0")
return
}
p := int2bytes(uint64(port))
if len(p) < 2 {
p = append(p, p[0])
p[0] = 0
}
info = string(append(ip, p...))
return
}
// getLocalIPs returns local ips.
func getLocalIPs() (ips []string) {
ips = make([]string, 0, 6)
addrs, err := net.InterfaceAddrs()
if err != nil {
return
}
for _, addr := range addrs {
ip, _, err := net.ParseCIDR(addr.String())
if err != nil {
continue
}
ips = append(ips, ip.String())
}
return
}
// getRemoteIP returns the wlan ip.
func getRemoteIP() (ip string, err error) {
client := &http.Client{
Timeout: time.Second * 30,
}
req, err := http.NewRequest("GET", "http://ifconfig.me", nil)
if err != nil {
return
}
req.Header.Set("User-Agent", "curl")
res, err := client.Do(req)
if err != nil {
return
}
defer res.Body.Close()
data, err := ioutil.ReadAll(res.Body)
if err != nil {
return
}
ip = string(data)
return
}
// genAddress returns a ip:port address.
func genAddress(ip string, port int) string {
return strings.Join([]string{ip, strconv.Itoa(port)}, ":")
}

100
dht/util_test.go Normal file
View file

@ -0,0 +1,100 @@
package dht
import (
"testing"
)
func TestInt2Bytes(t *testing.T) {
cases := []struct {
in uint64
out []byte
}{
{0, []byte{0}},
{1, []byte{1}},
{256, []byte{1, 0}},
{22129, []byte{86, 113}},
}
for _, c := range cases {
r := int2bytes(c.in)
if len(r) != len(c.out) {
t.Fail()
}
for i, v := range r {
if v != c.out[i] {
t.Fail()
}
}
}
}
func TestBytes2Int(t *testing.T) {
cases := []struct {
in []byte
out uint64
}{
{[]byte{0}, 0},
{[]byte{1}, 1},
{[]byte{1, 0}, 256},
{[]byte{86, 113}, 22129},
}
for _, c := range cases {
if bytes2int(c.in) != c.out {
t.Fail()
}
}
}
func TestDecodeCompactIPPortInfo(t *testing.T) {
cases := []struct {
in string
out struct {
ip string
port int
}
}{
{"123456", struct {
ip string
port int
}{"49.50.51.52", 13622}},
{"abcdef", struct {
ip string
port int
}{"97.98.99.100", 25958}},
}
for _, item := range cases {
ip, port, err := decodeCompactIPPortInfo(item.in)
if err != nil || ip.String() != item.out.ip || port != item.out.port {
t.Fail()
}
}
}
func TestEncodeCompactIPPortInfo(t *testing.T) {
cases := []struct {
in struct {
ip []byte
port int
}
out string
}{
{struct {
ip []byte
port int
}{[]byte{49, 50, 51, 52}, 13622}, "123456"},
{struct {
ip []byte
port int
}{[]byte{97, 98, 99, 100}, 25958}, "abcdef"},
}
for _, item := range cases {
info, err := encodeCompactIPPortInfo(item.in.ip, item.in.port)
if err != nil || info != item.out {
t.Fail()
}
}
}