remove blacklist, peerwire, and custom bencode lib. successfully receive a request and send a response
This commit is contained in:
parent
fa8a4a59bc
commit
c0290497de
9 changed files with 263 additions and 1381 deletions
257
dht/bencode.go
257
dht/bencode.go
|
@ -1,257 +0,0 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// find returns the index of first target in data starting from `start`.
|
||||
// It returns -1 if target not found.
|
||||
func find(data []byte, start int, target rune) (index int) {
|
||||
index = bytes.IndexRune(data[start:], target)
|
||||
if index != -1 {
|
||||
return index + start
|
||||
}
|
||||
return index
|
||||
}
|
||||
|
||||
// DecodeString decodes a string in the data. It returns a tuple
|
||||
// (decoded result, the end position, error).
|
||||
func DecodeString(data []byte, start int) (result interface{}, index int, err error) {
|
||||
|
||||
if start >= len(data) || data[start] < '0' || data[start] > '9' {
|
||||
err = errors.New("invalid string bencode")
|
||||
return
|
||||
}
|
||||
|
||||
i := find(data, start, ':')
|
||||
if i == -1 {
|
||||
err = errors.New("':' not found when decode string")
|
||||
return
|
||||
}
|
||||
|
||||
length, err := strconv.Atoi(string(data[start:i]))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if length < 0 {
|
||||
err = errors.New("invalid length of string")
|
||||
return
|
||||
}
|
||||
|
||||
index = i + 1 + length
|
||||
|
||||
if index > len(data) || index < i+1 {
|
||||
err = errors.New("out of range")
|
||||
return
|
||||
}
|
||||
|
||||
result = string(data[i+1 : index])
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeInt decodes int value in the data.
|
||||
func DecodeInt(data []byte, start int) (result interface{}, index int, err error) {
|
||||
|
||||
if start >= len(data) || data[start] != 'i' {
|
||||
err = errors.New("invalid int bencode")
|
||||
return
|
||||
}
|
||||
|
||||
index = find(data, start+1, 'e')
|
||||
|
||||
if index == -1 {
|
||||
err = errors.New("':' not found when decode int")
|
||||
return
|
||||
}
|
||||
|
||||
result, err = strconv.Atoi(string(data[start+1 : index]))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
index++
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// decodeItem decodes an item of dict or list.
|
||||
func decodeItem(data []byte, i int) (result interface{}, index int, err error) {
|
||||
var decodeFunc = []func([]byte, int) (interface{}, int, error){
|
||||
DecodeString, DecodeInt, DecodeList, DecodeDict,
|
||||
}
|
||||
|
||||
for _, f := range decodeFunc {
|
||||
result, index, err = f(data, i)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = errors.New("invalid bencode when decode item")
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeList decodes a list value.
|
||||
func DecodeList(data []byte, start int) (result interface{}, index int, err error) {
|
||||
|
||||
if start >= len(data) || data[start] != 'l' {
|
||||
err = errors.New("invalid list bencode")
|
||||
return
|
||||
}
|
||||
|
||||
var item interface{}
|
||||
r := make([]interface{}, 0, 8)
|
||||
|
||||
index = start + 1
|
||||
for index < len(data) {
|
||||
char, _ := utf8.DecodeRune(data[index:])
|
||||
if char == 'e' {
|
||||
break
|
||||
}
|
||||
|
||||
item, index, err = decodeItem(data, index)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r = append(r, item)
|
||||
}
|
||||
|
||||
if index == len(data) {
|
||||
err = errors.New("'e' not found when decode list")
|
||||
return
|
||||
}
|
||||
index++
|
||||
|
||||
result = r
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeDict decodes a map value.
|
||||
func DecodeDict(data []byte, start int) (result interface{}, index int, err error) {
|
||||
|
||||
if start >= len(data) || data[start] != 'd' {
|
||||
err = errors.New("invalid dict bencode")
|
||||
return
|
||||
}
|
||||
|
||||
var item, key interface{}
|
||||
r := make(map[string]interface{})
|
||||
|
||||
index = start + 1
|
||||
for index < len(data) {
|
||||
char, _ := utf8.DecodeRune(data[index:])
|
||||
if char == 'e' {
|
||||
break
|
||||
}
|
||||
|
||||
if !unicode.IsDigit(char) {
|
||||
err = errors.New("invalid dict bencode")
|
||||
return
|
||||
}
|
||||
|
||||
key, index, err = DecodeString(data, index)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if index >= len(data) {
|
||||
err = errors.New("out of range")
|
||||
return
|
||||
}
|
||||
|
||||
item, index, err = decodeItem(data, index)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
r[key.(string)] = item
|
||||
}
|
||||
|
||||
if index == len(data) {
|
||||
err = errors.New("'e' not found when decode dict")
|
||||
return
|
||||
}
|
||||
index++
|
||||
|
||||
result = r
|
||||
return
|
||||
}
|
||||
|
||||
// Decode decodes a bencoded string to string, int, list or map.
|
||||
func Decode(data []byte) (result interface{}, err error) {
|
||||
result, _, err = decodeItem(data, 0)
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeString encodes a string value.
|
||||
func EncodeString(data string) string {
|
||||
return strings.Join([]string{strconv.Itoa(len(data)), data}, ":")
|
||||
}
|
||||
|
||||
// EncodeInt encodes a int value.
|
||||
func EncodeInt(data int) string {
|
||||
return strings.Join([]string{"i", strconv.Itoa(data), "e"}, "")
|
||||
}
|
||||
|
||||
// EncodeItem encodes an item of dict or list.
|
||||
func encodeItem(data interface{}) (item string) {
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
item = EncodeString(v)
|
||||
case int:
|
||||
item = EncodeInt(v)
|
||||
case []interface{}:
|
||||
item = EncodeList(v)
|
||||
case map[string]interface{}:
|
||||
item = EncodeDict(v)
|
||||
default:
|
||||
panic("invalid type when encode item")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeList encodes a list value.
|
||||
func EncodeList(data []interface{}) string {
|
||||
result := make([]string, len(data))
|
||||
|
||||
for i, item := range data {
|
||||
result[i] = encodeItem(item)
|
||||
}
|
||||
|
||||
return strings.Join([]string{"l", strings.Join(result, ""), "e"}, "")
|
||||
}
|
||||
|
||||
// EncodeDict encodes a dict value.
|
||||
func EncodeDict(data map[string]interface{}) string {
|
||||
result, i := make([]string, len(data)), 0
|
||||
|
||||
for key, val := range data {
|
||||
result[i] = strings.Join(
|
||||
[]string{EncodeString(key), encodeItem(val)},
|
||||
"")
|
||||
i++
|
||||
}
|
||||
|
||||
return strings.Join([]string{"d", strings.Join(result, ""), "e"}, "")
|
||||
}
|
||||
|
||||
// Encode encodes a string, int, dict or list value to a bencoded string.
|
||||
func Encode(data interface{}) string {
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
return EncodeString(v)
|
||||
case int:
|
||||
return EncodeInt(v)
|
||||
case []interface{}:
|
||||
return EncodeList(v)
|
||||
case map[string]interface{}:
|
||||
return EncodeDict(v)
|
||||
default:
|
||||
panic("invalid type when encode")
|
||||
}
|
||||
}
|
|
@ -1,159 +0,0 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
out string
|
||||
}{
|
||||
{"0:", ""},
|
||||
{"1:a", "a"},
|
||||
{"5:hello", "hello"},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
if out, err := Decode([]byte(c.in)); err != nil || out != c.out {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeInt(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
out int
|
||||
}{
|
||||
{"i123e:", 123},
|
||||
{"i0e", 0},
|
||||
{"i-1e", -1},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
if out, err := Decode([]byte(c.in)); err != nil || out != c.out {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeList(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
out []interface{}
|
||||
}{
|
||||
{"li123ei-1ee", []interface{}{123, -1}},
|
||||
{"l5:helloe", []interface{}{"hello"}},
|
||||
{"ld5:hello5:worldee", []interface{}{map[string]interface{}{"hello": "world"}}},
|
||||
{"lli1ei2eee", []interface{}{[]interface{}{1, 2}}},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
v, err := Decode([]byte(c.in))
|
||||
if err != nil {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
out := v.([]interface{})
|
||||
|
||||
switch i {
|
||||
case 0, 1:
|
||||
for j, item := range out {
|
||||
if item != c.out[j] {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
case 2:
|
||||
if len(out) != 1 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
o := out[0].(map[string]interface{})
|
||||
cout := c.out[0].(map[string]interface{})
|
||||
|
||||
for k, v := range o {
|
||||
if cv, ok := cout[k]; !ok || v != cv {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
case 3:
|
||||
if len(out) != 1 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
o := out[0].([]interface{})
|
||||
cout := c.out[0].([]interface{})
|
||||
|
||||
for j, item := range o {
|
||||
if item != cout[j] {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeDict(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
out map[string]interface{}
|
||||
}{
|
||||
{"d5:helloi100ee", map[string]interface{}{"hello": 100}},
|
||||
{"d3:foo3:bare", map[string]interface{}{"foo": "bar"}},
|
||||
{"d1:ad3:foo3:baree", map[string]interface{}{"a": map[string]interface{}{"foo": "bar"}}},
|
||||
{"d4:listli1eee", map[string]interface{}{"list": []interface{}{1}}},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
v, err := Decode([]byte(c.in))
|
||||
if err != nil {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
out := v.(map[string]interface{})
|
||||
|
||||
switch i {
|
||||
case 0, 1:
|
||||
for k, v := range out {
|
||||
if cv, ok := c.out[k]; !ok || v != cv {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
case 2:
|
||||
if len(out) != 1 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
v, ok := out["a"]
|
||||
if !ok {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
cout := c.out["a"].(map[string]interface{})
|
||||
|
||||
for k, v := range v.(map[string]interface{}) {
|
||||
if cv, ok := cout[k]; !ok || v != cv {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
case 3:
|
||||
if len(out) != 1 {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
v, ok := out["list"]
|
||||
if !ok {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
cout := c.out["list"].([]interface{})
|
||||
|
||||
for j, v := range v.([]interface{}) {
|
||||
if v != cout[j] {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,92 +0,0 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// blockedItem represents a blocked node.
|
||||
type blockedItem struct {
|
||||
ip string
|
||||
port int
|
||||
createTime time.Time
|
||||
}
|
||||
|
||||
// blackList manages the blocked nodes including which sends bad information
|
||||
// and can't ping out.
|
||||
type blackList struct {
|
||||
list *syncedMap
|
||||
maxSize int
|
||||
expiredAfter time.Duration
|
||||
}
|
||||
|
||||
// newBlackList returns a blackList pointer.
|
||||
func newBlackList(size int) *blackList {
|
||||
return &blackList{
|
||||
list: newSyncedMap(),
|
||||
maxSize: size,
|
||||
expiredAfter: time.Hour * 1,
|
||||
}
|
||||
}
|
||||
|
||||
// genKey returns a key. If port is less than 0, the key wil be ip. Ohterwise
|
||||
// it will be `ip:port` format.
|
||||
func (bl *blackList) genKey(ip string, port int) string {
|
||||
key := ip
|
||||
if port >= 0 {
|
||||
key = genAddress(ip, port)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// insert adds a blocked item to the blacklist.
|
||||
func (bl *blackList) insert(ip string, port int) {
|
||||
if bl.list.Len() >= bl.maxSize {
|
||||
return
|
||||
}
|
||||
|
||||
bl.list.Set(bl.genKey(ip, port), &blockedItem{
|
||||
ip: ip,
|
||||
port: port,
|
||||
createTime: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// delete removes blocked item form the blackList.
|
||||
func (bl *blackList) delete(ip string, port int) {
|
||||
bl.list.Delete(bl.genKey(ip, port))
|
||||
}
|
||||
|
||||
// validate checks whether ip-port pair is in the block nodes list.
|
||||
func (bl *blackList) in(ip string, port int) bool {
|
||||
if _, ok := bl.list.Get(ip); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
key := bl.genKey(ip, port)
|
||||
|
||||
v, ok := bl.list.Get(key)
|
||||
if ok {
|
||||
if time.Now().Sub(v.(*blockedItem).createTime) < bl.expiredAfter {
|
||||
return true
|
||||
}
|
||||
bl.list.Delete(key)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// clear cleans the expired items every 10 minutes.
|
||||
func (bl *blackList) clear() {
|
||||
for _ = range time.Tick(time.Minute * 10) {
|
||||
keys := make([]interface{}, 0, 100)
|
||||
|
||||
for item := range bl.list.Iter() {
|
||||
if time.Now().Sub(
|
||||
item.val.(*blockedItem).createTime) > bl.expiredAfter {
|
||||
|
||||
keys = append(keys, item.key)
|
||||
}
|
||||
}
|
||||
|
||||
bl.list.DeleteMulti(keys)
|
||||
}
|
||||
}
|
|
@ -1,57 +0,0 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var blacklist = newBlackList(256)
|
||||
|
||||
func TestGenKey(t *testing.T) {
|
||||
cases := []struct {
|
||||
in struct {
|
||||
ip string
|
||||
port int
|
||||
}
|
||||
out string
|
||||
}{
|
||||
{struct {
|
||||
ip string
|
||||
port int
|
||||
}{"0.0.0.0", -1}, "0.0.0.0"},
|
||||
{struct {
|
||||
ip string
|
||||
port int
|
||||
}{"1.1.1.1", 8080}, "1.1.1.1:8080"},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
if blacklist.genKey(c.in.ip, c.in.port) != c.out {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlackList(t *testing.T) {
|
||||
address := []struct {
|
||||
ip string
|
||||
port int
|
||||
}{
|
||||
{"0.0.0.0", -1},
|
||||
{"1.1.1.1", 8080},
|
||||
{"2.2.2.2", 8081},
|
||||
}
|
||||
|
||||
for _, addr := range address {
|
||||
blacklist.insert(addr.ip, addr.port)
|
||||
if !blacklist.in(addr.ip, addr.port) {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
blacklist.delete(addr.ip, addr.port)
|
||||
if blacklist.in(addr.ip, addr.port) {
|
||||
fmt.Println(addr.ip)
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
61
dht/dht.go
61
dht/dht.go
|
@ -15,9 +15,6 @@ import (
|
|||
type Config struct {
|
||||
// in mainline dht, k = 8
|
||||
K int
|
||||
// for crawling mode, we put all nodes in one bucket, so KBucketSize may
|
||||
// not be K
|
||||
KBucketSize int
|
||||
// candidates are udp, udp4, udp6
|
||||
Network string
|
||||
// format is `ip:port`
|
||||
|
@ -40,10 +37,6 @@ type Config struct {
|
|||
OnGetPeers func(string, string, int)
|
||||
// callback when got announce_peer request
|
||||
OnAnnouncePeer func(string, string, int)
|
||||
// blcoked ips
|
||||
BlockedIPs []string
|
||||
// blacklist size
|
||||
BlackListMaxSize int
|
||||
// the times it tries when send fails
|
||||
Try int
|
||||
// the size of packet need to be dealt with
|
||||
|
@ -57,14 +50,13 @@ type Config struct {
|
|||
// NewStandardConfig returns a Config pointer with default values.
|
||||
func NewStandardConfig() *Config {
|
||||
return &Config{
|
||||
K: 8,
|
||||
KBucketSize: 8,
|
||||
Network: "udp4",
|
||||
Address: ":6881",
|
||||
K: 8,
|
||||
Network: "udp4",
|
||||
Address: ":4444",
|
||||
PrimeNodes: []string{
|
||||
"router.bittorrent.com:6881",
|
||||
"router.utorrent.com:6881",
|
||||
"dht.transmissionbt.com:6881",
|
||||
"lbrynet1.lbry.io:4444",
|
||||
"lbrynet2.lbry.io:4444",
|
||||
"lbrynet3.lbry.io:4444",
|
||||
},
|
||||
NodeExpriedAfter: time.Duration(time.Minute * 15),
|
||||
KBucketExpiredAfter: time.Duration(time.Minute * 15),
|
||||
|
@ -72,8 +64,6 @@ func NewStandardConfig() *Config {
|
|||
TokenExpiredAfter: time.Duration(time.Minute * 10),
|
||||
MaxTransactionCursor: math.MaxUint32,
|
||||
MaxNodes: 5000,
|
||||
BlockedIPs: make([]string, 0),
|
||||
BlackListMaxSize: 65536,
|
||||
Try: 2,
|
||||
PacketJobLimit: 1024,
|
||||
PacketWorkerLimit: 256,
|
||||
|
@ -90,7 +80,6 @@ type DHT struct {
|
|||
transactionManager *transactionManager
|
||||
peersManager *peersManager
|
||||
tokenManager *tokenManager
|
||||
blackList *blackList
|
||||
Ready bool
|
||||
packets chan packet
|
||||
workerTokens chan struct{}
|
||||
|
@ -111,26 +100,10 @@ func New(config *Config) *DHT {
|
|||
d := &DHT{
|
||||
Config: config,
|
||||
node: node,
|
||||
blackList: newBlackList(config.BlackListMaxSize),
|
||||
packets: make(chan packet, config.PacketJobLimit),
|
||||
workerTokens: make(chan struct{}, config.PacketWorkerLimit),
|
||||
}
|
||||
|
||||
for _, ip := range config.BlockedIPs {
|
||||
d.blackList.insert(ip, -1)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for _, ip := range getLocalIPs() {
|
||||
d.blackList.insert(ip, -1)
|
||||
}
|
||||
|
||||
ip, err := getRemoteIP()
|
||||
if err != nil {
|
||||
d.blackList.insert(ip, -1)
|
||||
}
|
||||
}()
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
|
@ -143,15 +116,13 @@ func (dht *DHT) init() {
|
|||
}
|
||||
|
||||
dht.conn = listener.(*net.UDPConn)
|
||||
dht.routingTable = newRoutingTable(dht.KBucketSize, dht)
|
||||
dht.routingTable = newRoutingTable(dht.K, dht)
|
||||
dht.peersManager = newPeersManager(dht)
|
||||
dht.tokenManager = newTokenManager(dht.TokenExpiredAfter, dht)
|
||||
dht.transactionManager = newTransactionManager(
|
||||
dht.MaxTransactionCursor, dht)
|
||||
dht.transactionManager = newTransactionManager(dht.MaxTransactionCursor, dht)
|
||||
|
||||
go dht.transactionManager.run()
|
||||
go dht.tokenManager.clear()
|
||||
go dht.blackList.clear()
|
||||
}
|
||||
|
||||
// join makes current node join the dht network.
|
||||
|
@ -162,9 +133,8 @@ func (dht *DHT) join() {
|
|||
continue
|
||||
}
|
||||
|
||||
// NOTE: Temporary node has NOT node id.
|
||||
dht.transactionManager.findNode(
|
||||
&node{addr: raddr},
|
||||
&node{id: dht.node.id, addr: raddr},
|
||||
dht.node.id.RawString(),
|
||||
)
|
||||
}
|
||||
|
@ -179,7 +149,6 @@ func (dht *DHT) listen() {
|
|||
if err != nil {
|
||||
continue
|
||||
}
|
||||
log.Infof("Received %s", buff)
|
||||
|
||||
dht.packets <- packet{buff[:n], raddr}
|
||||
}
|
||||
|
@ -214,12 +183,12 @@ func (dht *DHT) GetPeers(infoHash string) ([]*Peer, error) {
|
|||
ch := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
neighbors := dht.routingTable.GetNeighbors(
|
||||
newBitmapFromString(infoHash), dht.K)
|
||||
|
||||
for _, no := range neighbors {
|
||||
dht.transactionManager.getPeers(no, infoHash)
|
||||
}
|
||||
//neighbors := dht.routingTable.GetNeighbors(
|
||||
// newBitmapFromString(infoHash), dht.K)
|
||||
//
|
||||
//for _, no := range neighbors {
|
||||
// dht.transactionManager.getPeers(no, infoHash)
|
||||
//}
|
||||
|
||||
i := 0
|
||||
for range time.Tick(time.Second * 1) {
|
||||
|
|
591
dht/krpc.go
591
dht/krpc.go
|
@ -1,20 +1,23 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/zeebo/bencode"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
pingType = "ping"
|
||||
findNodeType = "find_node"
|
||||
getPeersType = "get_peers"
|
||||
announcePeerType = "announce_peer"
|
||||
pingMethod = "ping"
|
||||
storeMethod = "store"
|
||||
findNodeMethod = "findNode"
|
||||
findValueMethod = "findValue"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -24,6 +27,78 @@ const (
|
|||
unknownError
|
||||
)
|
||||
|
||||
const (
|
||||
requestType = 0
|
||||
responseType = 1
|
||||
errorType = 2
|
||||
)
|
||||
|
||||
const (
|
||||
// these are strings because bencode requires bytestring keys
|
||||
headerTypeField = "0"
|
||||
headerMessageIDField = "1"
|
||||
headerNodeIDField = "2"
|
||||
headerPayloadField = "3"
|
||||
headerArgsField = "4"
|
||||
)
|
||||
|
||||
type Message interface {
|
||||
GetID() string
|
||||
Encode() ([]byte, error)
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
ID string
|
||||
NodeID string
|
||||
Method string
|
||||
Args []string
|
||||
}
|
||||
|
||||
func (r Request) GetID() string { return r.ID }
|
||||
func (r Request) Encode() ([]byte, error) {
|
||||
return bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: requestType,
|
||||
headerMessageIDField: r.ID,
|
||||
headerNodeIDField: r.NodeID,
|
||||
headerPayloadField: r.Method,
|
||||
headerArgsField: r.Args,
|
||||
})
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
ID string
|
||||
NodeID string
|
||||
Response string
|
||||
}
|
||||
|
||||
func (r Response) GetID() string { return r.ID }
|
||||
func (r Response) Encode() ([]byte, error) {
|
||||
return bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: responseType,
|
||||
headerMessageIDField: r.ID,
|
||||
headerNodeIDField: r.NodeID,
|
||||
headerPayloadField: r.Response,
|
||||
})
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
ID string
|
||||
NodeID string
|
||||
Response []string
|
||||
ExceptionType string
|
||||
}
|
||||
|
||||
func (e Error) GetID() string { return e.ID }
|
||||
func (e Error) Encode() ([]byte, error) {
|
||||
return bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: errorType,
|
||||
headerMessageIDField: e.ID,
|
||||
headerNodeIDField: e.NodeID,
|
||||
headerPayloadField: e.ExceptionType,
|
||||
headerArgsField: e.Response,
|
||||
})
|
||||
}
|
||||
|
||||
// packet represents the information receive from udp.
|
||||
type packet struct {
|
||||
data []byte
|
||||
|
@ -108,39 +183,25 @@ func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} {
|
|||
}
|
||||
}
|
||||
|
||||
// makeResponse returns a response-formed data.
|
||||
func makeResponse(t string, r map[string]interface{}) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"t": t,
|
||||
"y": "r",
|
||||
"r": r,
|
||||
}
|
||||
}
|
||||
|
||||
// makeError returns a err-formed data.
|
||||
func makeError(t string, errCode int, errMsg string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"t": t,
|
||||
"y": "e",
|
||||
"e": []interface{}{errCode, errMsg},
|
||||
}
|
||||
}
|
||||
|
||||
// send sends data to the udp.
|
||||
func send(dht *DHT, addr *net.UDPAddr, data map[string]interface{}) error {
|
||||
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
||||
log.Infof("Sending %s", spew.Sdump(data))
|
||||
encoded, err := data.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("Encoded: %s", string(encoded))
|
||||
|
||||
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
|
||||
|
||||
_, err := dht.conn.WriteToUDP([]byte(Encode(data)), addr)
|
||||
if err != nil {
|
||||
dht.blackList.insert(addr.IP.String(), -1)
|
||||
}
|
||||
_, err = dht.conn.WriteToUDP(encoded, addr)
|
||||
return err
|
||||
}
|
||||
|
||||
// query represents the query data included queried node and query-formed data.
|
||||
type query struct {
|
||||
node *node
|
||||
data map[string]interface{}
|
||||
node *node
|
||||
request Request
|
||||
}
|
||||
|
||||
// transaction implements transaction.
|
||||
|
@ -199,7 +260,7 @@ func (tm *transactionManager) genIndexKey(queryType, address string) string {
|
|||
|
||||
// genIndexKeyByTrans generates an indexed key by a transaction.
|
||||
func (tm *transactionManager) genIndexKeyByTrans(trans *transaction) string {
|
||||
return tm.genIndexKey(trans.data["q"].(string), trans.node.addr.String())
|
||||
return tm.genIndexKey(trans.request.Method, trans.node.addr.String())
|
||||
}
|
||||
|
||||
// insert adds a transaction to transactionManager.
|
||||
|
@ -233,8 +294,7 @@ func (tm *transactionManager) len() int {
|
|||
|
||||
// transaction returns a transaction. keyType should be one of 0, 1 which
|
||||
// represents transId and index each.
|
||||
func (tm *transactionManager) transaction(
|
||||
key string, keyType int) *transaction {
|
||||
func (tm *transactionManager) transaction(key string, keyType int) *transaction {
|
||||
|
||||
sm := tm.transactions
|
||||
if keyType == 1 {
|
||||
|
@ -261,14 +321,11 @@ func (tm *transactionManager) getByIndex(index string) *transaction {
|
|||
|
||||
// transaction gets the proper transaction with whose id is transId and
|
||||
// address is addr.
|
||||
func (tm *transactionManager) filterOne(
|
||||
transID string, addr *net.UDPAddr) *transaction {
|
||||
|
||||
func (tm *transactionManager) filterOne(transID string, addr *net.UDPAddr) *transaction {
|
||||
trans := tm.getByTransID(transID)
|
||||
if trans == nil || trans.node.addr.String() != addr.String() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return trans
|
||||
}
|
||||
|
||||
|
@ -276,15 +333,14 @@ func (tm *transactionManager) filterOne(
|
|||
// When timeout, it will retry `try - 1` times, which means it will query
|
||||
// `try` times totally.
|
||||
func (tm *transactionManager) query(q *query, try int) {
|
||||
transID := q.data["t"].(string)
|
||||
trans := tm.newTransaction(transID, q)
|
||||
trans := tm.newTransaction(q.request.ID, q)
|
||||
|
||||
tm.insert(trans)
|
||||
defer tm.delete(trans.id)
|
||||
|
||||
success := false
|
||||
for i := 0; i < try; i++ {
|
||||
if err := send(tm.dht, q.node.addr, q.data); err != nil {
|
||||
if err := send(tm.dht, q.node.addr, q.request); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -297,7 +353,6 @@ func (tm *transactionManager) query(q *query, try int) {
|
|||
}
|
||||
|
||||
if !success && q.node.id != nil {
|
||||
tm.dht.blackList.insert(q.node.addr.IP.String(), q.node.addr.Port)
|
||||
tm.dht.routingTable.RemoveByAddr(q.node.addr.String())
|
||||
}
|
||||
}
|
||||
|
@ -315,166 +370,130 @@ func (tm *transactionManager) run() {
|
|||
}
|
||||
|
||||
// sendQuery send query-formed data to the chan.
|
||||
func (tm *transactionManager) sendQuery(no *node, queryType string, a map[string]interface{}) {
|
||||
|
||||
func (tm *transactionManager) sendQuery(no *node, request Request) {
|
||||
// If the target is self, then stop.
|
||||
if no.id != nil && no.id.RawString() == tm.dht.node.id.RawString() ||
|
||||
tm.getByIndex(tm.genIndexKey(queryType, no.addr.String())) != nil ||
|
||||
tm.dht.blackList.in(no.addr.IP.String(), no.addr.Port) {
|
||||
tm.getByIndex(tm.genIndexKey(request.Method, no.addr.String())) != nil {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
data := makeQuery(tm.genTransID(), queryType, a)
|
||||
tm.queryChan <- &query{
|
||||
node: no,
|
||||
data: data,
|
||||
}
|
||||
request.ID = tm.genTransID()
|
||||
request.NodeID = tm.dht.id(no.id.RawString())
|
||||
tm.queryChan <- &query{node: no, request: request}
|
||||
}
|
||||
|
||||
// ping sends ping query to the chan.
|
||||
func (tm *transactionManager) ping(no *node) {
|
||||
tm.sendQuery(no, pingType, map[string]interface{}{
|
||||
"id": tm.dht.id(no.id.RawString()),
|
||||
})
|
||||
tm.sendQuery(no, Request{Method: pingMethod})
|
||||
}
|
||||
|
||||
// findNode sends find_node query to the chan.
|
||||
func (tm *transactionManager) findNode(no *node, target string) {
|
||||
tm.sendQuery(no, findNodeType, map[string]interface{}{
|
||||
"id": tm.dht.id(target),
|
||||
"target": target,
|
||||
})
|
||||
tm.sendQuery(no, Request{Method: findNodeMethod, Args: []string{target}})
|
||||
}
|
||||
|
||||
// getPeers sends get_peers query to the chan.
|
||||
func (tm *transactionManager) getPeers(no *node, infoHash string) {
|
||||
tm.sendQuery(no, getPeersType, map[string]interface{}{
|
||||
"id": tm.dht.id(infoHash),
|
||||
"info_hash": infoHash,
|
||||
})
|
||||
}
|
||||
|
||||
// announcePeer sends announce_peer query to the chan.
|
||||
func (tm *transactionManager) announcePeer(no *node, infoHash string, impliedPort, port int, token string) {
|
||||
tm.sendQuery(no, announcePeerType, map[string]interface{}{
|
||||
"id": tm.dht.id(no.id.RawString()),
|
||||
"info_hash": infoHash,
|
||||
"implied_port": impliedPort,
|
||||
"port": port,
|
||||
"token": token,
|
||||
})
|
||||
}
|
||||
|
||||
// parseKey parses the key in dict data. `t` is type of the keyed value.
|
||||
// It's one of "int", "string", "map", "list".
|
||||
func parseKey(data map[string]interface{}, key string, t string) error {
|
||||
val, ok := data[key]
|
||||
if !ok {
|
||||
return errors.New("lack of key")
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "string":
|
||||
_, ok = val.(string)
|
||||
case "int":
|
||||
_, ok = val.(int)
|
||||
case "map":
|
||||
_, ok = val.(map[string]interface{})
|
||||
case "list":
|
||||
_, ok = val.([]interface{})
|
||||
default:
|
||||
panic("invalid type")
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return errors.New("invalid key type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseKeys parses keys. It just wraps parseKey.
|
||||
func parseKeys(data map[string]interface{}, pairs [][]string) error {
|
||||
for _, args := range pairs {
|
||||
key, t := args[0], args[1]
|
||||
if err := parseKey(data, key, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseMessage parses the basic data received from udp.
|
||||
// It returns a map value.
|
||||
func parseMessage(data interface{}) (map[string]interface{}, error) {
|
||||
response, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, errors.New("response is not dict")
|
||||
}
|
||||
|
||||
if err := parseKeys(
|
||||
response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// handleRequest handles the requests received from udp.
|
||||
func handleRequest(dht *DHT, addr *net.UDPAddr, response map[string]interface{}) (success bool) {
|
||||
|
||||
t := response["t"].(string)
|
||||
|
||||
if err := parseKeys(
|
||||
response, [][]string{{"q", "string"}, {"a", "map"}}); err != nil {
|
||||
|
||||
send(dht, addr, makeError(t, protocolError, err.Error()))
|
||||
// handle handles packets received from udp.
|
||||
func handle(dht *DHT, pkt packet) {
|
||||
log.Infof("Received message from %s: %s", pkt.raddr.IP.String(), string(pkt.data))
|
||||
if len(dht.workerTokens) == dht.PacketWorkerLimit {
|
||||
return
|
||||
}
|
||||
|
||||
q := response["q"].(string)
|
||||
a := response["a"].(map[string]interface{})
|
||||
dht.workerTokens <- struct{}{}
|
||||
|
||||
if err := parseKey(a, "id", "string"); err != nil {
|
||||
send(dht, addr, makeError(t, protocolError, err.Error()))
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
<-dht.workerTokens
|
||||
}()
|
||||
|
||||
id := a["id"].(string)
|
||||
|
||||
if id == dht.node.id.RawString() {
|
||||
return
|
||||
}
|
||||
|
||||
if len(id) != nodeIDLength {
|
||||
send(dht, addr, makeError(t, protocolError, "invalid id"))
|
||||
return
|
||||
}
|
||||
|
||||
if no, ok := dht.routingTable.GetNodeByAddress(addr.String()); ok &&
|
||||
no.id.RawString() != id {
|
||||
|
||||
dht.blackList.insert(addr.IP.String(), addr.Port)
|
||||
dht.routingTable.RemoveByAddr(addr.String())
|
||||
|
||||
send(dht, addr, makeError(t, protocolError, "invalid id"))
|
||||
return
|
||||
}
|
||||
|
||||
switch q {
|
||||
case pingType:
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(id),
|
||||
}))
|
||||
case findNodeType:
|
||||
if err := parseKey(a, "target", "string"); err != nil {
|
||||
send(dht, addr, makeError(t, protocolError, err.Error()))
|
||||
var data map[string]interface{}
|
||||
err := bencode.DecodeBytes(pkt.data, &data)
|
||||
if err != nil {
|
||||
log.Errorf("Error decoding data: %s\n%s", err, pkt.data)
|
||||
return
|
||||
}
|
||||
|
||||
target := a["target"].(string)
|
||||
msgType, ok := data[headerTypeField]
|
||||
if !ok {
|
||||
log.Errorf("Decoded data has no message type: %s", data)
|
||||
return
|
||||
}
|
||||
|
||||
switch msgType.(int64) {
|
||||
case requestType:
|
||||
request := Request{
|
||||
ID: data[headerMessageIDField].(string),
|
||||
NodeID: data[headerNodeIDField].(string),
|
||||
Method: data[headerPayloadField].(string),
|
||||
Args: getArgs(data[headerArgsField]),
|
||||
}
|
||||
spew.Dump(request)
|
||||
handleRequest(dht, pkt.raddr, request)
|
||||
|
||||
case responseType:
|
||||
response := Response{
|
||||
ID: data[headerMessageIDField].(string),
|
||||
NodeID: data[headerNodeIDField].(string),
|
||||
Response: data[headerPayloadField].(string),
|
||||
}
|
||||
handleResponse(dht, pkt.raddr, response)
|
||||
|
||||
case errorType:
|
||||
e := Error{
|
||||
ID: data[headerMessageIDField].(string),
|
||||
NodeID: data[headerNodeIDField].(string),
|
||||
ExceptionType: data[headerPayloadField].(string),
|
||||
Response: getArgs(data[headerArgsField]),
|
||||
}
|
||||
handleError(dht, pkt.raddr, e)
|
||||
|
||||
default:
|
||||
log.Errorf("Invalid message type: %s", msgType)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func getArgs(argsInt interface{}) (args []string) {
|
||||
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
|
||||
v := reflect.ValueOf(argsInt)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
args = append(args, cast.ToString(v.Index(i).Interface()))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// handleRequest handles the requests received from udp.
|
||||
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) {
|
||||
if request.NodeID == dht.node.id.RawString() {
|
||||
return
|
||||
}
|
||||
|
||||
if len(request.NodeID) != nodeIDLength {
|
||||
send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"Invalid ID"}})
|
||||
return
|
||||
}
|
||||
|
||||
if no, ok := dht.routingTable.GetNodeByAddress(addr.String()); ok && no.id.RawString() != request.NodeID {
|
||||
dht.routingTable.RemoveByAddr(addr.String())
|
||||
send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"Invalid ID"}})
|
||||
return
|
||||
}
|
||||
|
||||
switch request.Method {
|
||||
case pingMethod:
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Response: "pong"})
|
||||
case findNodeMethod:
|
||||
if len(request.Args) < 1 {
|
||||
send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"No target"}})
|
||||
return
|
||||
}
|
||||
|
||||
target := request.Args[0]
|
||||
if len(target) != nodeIDLength {
|
||||
send(dht, addr, makeError(t, protocolError, "invalid target"))
|
||||
send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"Invalid target"}})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -485,94 +504,17 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, response map[string]interface{})
|
|||
if no != nil {
|
||||
nodes = no.CompactNodeInfo()
|
||||
} else {
|
||||
nodes = strings.Join(
|
||||
dht.routingTable.GetNeighborCompactInfos(targetID, dht.K),
|
||||
"",
|
||||
)
|
||||
nodes = strings.Join(dht.routingTable.GetNeighborCompactInfos(targetID, dht.K), "")
|
||||
}
|
||||
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(target),
|
||||
"nodes": nodes,
|
||||
}))
|
||||
case getPeersType:
|
||||
if err := parseKey(a, "info_hash", "string"); err != nil {
|
||||
send(dht, addr, makeError(t, protocolError, err.Error()))
|
||||
return
|
||||
}
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Response: nodes})
|
||||
|
||||
infoHash := a["info_hash"].(string)
|
||||
|
||||
if len(infoHash) != nodeIDLength {
|
||||
send(dht, addr, makeError(t, protocolError, "invalid info_hash"))
|
||||
return
|
||||
}
|
||||
|
||||
if peers := dht.peersManager.GetPeers(
|
||||
infoHash, dht.K); len(peers) > 0 {
|
||||
|
||||
values := make([]interface{}, len(peers))
|
||||
for i, p := range peers {
|
||||
values[i] = p.CompactIPPortInfo()
|
||||
}
|
||||
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(infoHash),
|
||||
"values": values,
|
||||
"token": dht.tokenManager.token(addr),
|
||||
}))
|
||||
} else {
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(infoHash),
|
||||
"token": dht.tokenManager.token(addr),
|
||||
"nodes": strings.Join(dht.routingTable.GetNeighborCompactInfos(
|
||||
newBitmapFromString(infoHash), dht.K), ""),
|
||||
}))
|
||||
}
|
||||
|
||||
if dht.OnGetPeers != nil {
|
||||
dht.OnGetPeers(infoHash, addr.IP.String(), addr.Port)
|
||||
}
|
||||
case announcePeerType:
|
||||
if err := parseKeys(a, [][]string{
|
||||
{"info_hash", "string"},
|
||||
{"port", "int"},
|
||||
{"token", "string"}}); err != nil {
|
||||
|
||||
send(dht, addr, makeError(t, protocolError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
infoHash := a["info_hash"].(string)
|
||||
port := a["port"].(int)
|
||||
token := a["token"].(string)
|
||||
|
||||
if !dht.tokenManager.check(addr, token) {
|
||||
// send(dht, addr, makeError(t, protocolError, "invalid token"))
|
||||
return
|
||||
}
|
||||
|
||||
if impliedPort, ok := a["implied_port"]; ok &&
|
||||
impliedPort.(int) != 0 {
|
||||
|
||||
port = addr.Port
|
||||
}
|
||||
|
||||
dht.peersManager.Insert(infoHash, newPeer(addr.IP, port, token))
|
||||
|
||||
send(dht, addr, makeResponse(t, map[string]interface{}{
|
||||
"id": dht.id(id),
|
||||
}))
|
||||
|
||||
if dht.OnAnnouncePeer != nil {
|
||||
dht.OnAnnouncePeer(infoHash, addr.IP.String(), port)
|
||||
}
|
||||
default:
|
||||
// send(dht, addr, makeError(t, protocolError, "invalid q"))
|
||||
return
|
||||
}
|
||||
|
||||
no, _ := newNode(id, addr.Network(), addr.String())
|
||||
no, _ := newNode(request.NodeID, addr.Network(), addr.String())
|
||||
dht.routingTable.Insert(no)
|
||||
return true
|
||||
}
|
||||
|
@ -580,13 +522,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, response map[string]interface{})
|
|||
// findOn puts nodes in the response to the routingTable, then if target is in
|
||||
// the nodes or all nodes are in the routingTable, it stops. Otherwise it
|
||||
// continues to findNode or getPeers.
|
||||
func findOn(dht *DHT, r map[string]interface{}, target *bitmap, queryType string) error {
|
||||
|
||||
if err := parseKey(r, "nodes", "string"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nodes := r["nodes"].(string)
|
||||
func findOn(dht *DHT, nodes string, target *bitmap, queryType string) error {
|
||||
if len(nodes)%compactNodeInfoLength != 0 {
|
||||
return fmt.Errorf("the length of nodes should can be divided by %d", compactNodeInfoLength)
|
||||
}
|
||||
|
@ -612,10 +548,8 @@ func findOn(dht *DHT, r map[string]interface{}, target *bitmap, queryType string
|
|||
targetID := target.RawString()
|
||||
for _, no := range dht.routingTable.GetNeighbors(target, dht.K) {
|
||||
switch queryType {
|
||||
case findNodeType:
|
||||
case findNodeMethod:
|
||||
dht.transactionManager.findNode(no, targetID)
|
||||
case getPeersType:
|
||||
dht.transactionManager.getPeers(no, targetID)
|
||||
default:
|
||||
panic("invalid find type")
|
||||
}
|
||||
|
@ -624,76 +558,32 @@ func findOn(dht *DHT, r map[string]interface{}, target *bitmap, queryType string
|
|||
}
|
||||
|
||||
// handleResponse handles responses received from udp.
|
||||
func handleResponse(dht *DHT, addr *net.UDPAddr, response map[string]interface{}) (success bool) {
|
||||
|
||||
t := response["t"].(string)
|
||||
|
||||
trans := dht.transactionManager.filterOne(t, addr)
|
||||
func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) (success bool) {
|
||||
trans := dht.transactionManager.filterOne(response.ID, addr)
|
||||
if trans == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// inform transManager to delete the transaction.
|
||||
if err := parseKey(response, "r", "map"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
q := trans.data["q"].(string)
|
||||
a := trans.data["a"].(map[string]interface{})
|
||||
r := response["r"].(map[string]interface{})
|
||||
|
||||
if err := parseKey(r, "id", "string"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
id := r["id"].(string)
|
||||
|
||||
// If response's node id is not the same with the node id in the
|
||||
// transaction, raise error.
|
||||
if trans.node.id != nil && trans.node.id.RawString() != r["id"].(string) {
|
||||
dht.blackList.insert(addr.IP.String(), addr.Port)
|
||||
// TODO: is this necessary??? why??
|
||||
if trans.node.id != nil && trans.node.id.RawString() != response.NodeID {
|
||||
dht.routingTable.RemoveByAddr(addr.String())
|
||||
return
|
||||
}
|
||||
|
||||
node, err := newNode(id, addr.Network(), addr.String())
|
||||
node, err := newNode(response.NodeID, addr.Network(), addr.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch q {
|
||||
case pingType:
|
||||
case findNodeType:
|
||||
if trans.data["q"].(string) != findNodeType {
|
||||
switch trans.request.Method {
|
||||
case pingMethod:
|
||||
case findNodeMethod:
|
||||
target := trans.request.Args[0]
|
||||
if findOn(dht, response.Response, newBitmapFromString(target), findNodeMethod) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
target := trans.data["a"].(map[string]interface{})["target"].(string)
|
||||
if findOn(dht, r, newBitmapFromString(target), findNodeType) != nil {
|
||||
return
|
||||
}
|
||||
case getPeersType:
|
||||
if err := parseKey(r, "token", "string"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
token := r["token"].(string)
|
||||
infoHash := a["info_hash"].(string)
|
||||
|
||||
if err := parseKey(r, "values", "list"); err == nil {
|
||||
values := r["values"].([]interface{})
|
||||
for _, v := range values {
|
||||
p, err := newPeerFromCompactIPPortInfo(v.(string), token)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
dht.peersManager.Insert(infoHash, p)
|
||||
}
|
||||
} else if findOn(
|
||||
dht, r, newBitmapFromString(infoHash), getPeersType) != nil {
|
||||
return
|
||||
}
|
||||
case announcePeerType:
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
@ -701,71 +591,16 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response map[string]interface{}
|
|||
// inform transManager to delete transaction.
|
||||
trans.response <- struct{}{}
|
||||
|
||||
dht.blackList.delete(addr.IP.String(), addr.Port)
|
||||
dht.routingTable.Insert(node)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handleError handles errors received from udp.
|
||||
func handleError(dht *DHT, addr *net.UDPAddr, response map[string]interface{}) (success bool) {
|
||||
|
||||
if err := parseKey(response, "e", "list"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if e := response["e"].([]interface{}); len(e) != 2 {
|
||||
return
|
||||
}
|
||||
|
||||
if trans := dht.transactionManager.filterOne(
|
||||
response["t"].(string), addr); trans != nil {
|
||||
|
||||
func handleError(dht *DHT, addr *net.UDPAddr, e Error) (success bool) {
|
||||
if trans := dht.transactionManager.filterOne(e.ID, addr); trans != nil {
|
||||
trans.response <- struct{}{}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
var handlers = map[string]func(*DHT, *net.UDPAddr, map[string]interface{}) bool{
|
||||
"q": handleRequest,
|
||||
"r": handleResponse,
|
||||
"e": handleError,
|
||||
}
|
||||
|
||||
// handle handles packets received from udp.
|
||||
func handle(dht *DHT, pkt packet) {
|
||||
log.Infof("Packet from %s: %s", pkt.raddr.IP.String(), pkt.data)
|
||||
if len(dht.workerTokens) == dht.PacketWorkerLimit {
|
||||
return
|
||||
}
|
||||
|
||||
dht.workerTokens <- struct{}{}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
<-dht.workerTokens
|
||||
}()
|
||||
|
||||
if dht.blackList.in(pkt.raddr.IP.String(), pkt.raddr.Port) {
|
||||
log.Infof("%s blacklisted, ignoring packet", pkt.raddr.IP.String())
|
||||
return
|
||||
}
|
||||
|
||||
data, err := Decode(pkt.data)
|
||||
if err != nil {
|
||||
log.Errorf("Error decoding data: %s\n%s", err, pkt.data)
|
||||
return
|
||||
}
|
||||
|
||||
response, err := parseMessage(data)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing message: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if f, ok := handlers[response["y"].(string)]; ok {
|
||||
f(dht, pkt.raddr, response)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
386
dht/peerwire.go
386
dht/peerwire.go
|
@ -1,386 +0,0 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// REQUEST represents request message type
|
||||
REQUEST = iota
|
||||
// DATA represents data message type
|
||||
DATA
|
||||
// REJECT represents reject message type
|
||||
REJECT
|
||||
)
|
||||
|
||||
const (
|
||||
// BLOCK is 2 ^ 14
|
||||
BLOCK = 16384
|
||||
// MaxMetadataSize represents the max medata it can accept
|
||||
MaxMetadataSize = BLOCK * 1000
|
||||
// EXTENDED represents it is a extended message
|
||||
EXTENDED = 20
|
||||
// HANDSHAKE represents handshake bit
|
||||
HANDSHAKE = 0
|
||||
)
|
||||
|
||||
var handshakePrefix = []byte{
|
||||
19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114,
|
||||
111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1,
|
||||
}
|
||||
var handshakePrefixLength = len(handshakePrefix)
|
||||
|
||||
// read reads size-length bytes from conn to data.
|
||||
func read(conn *net.TCPConn, size int, data *bytes.Buffer) error {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second * 15))
|
||||
|
||||
n, err := io.CopyN(data, conn, int64(size))
|
||||
if err != nil || n != int64(size) {
|
||||
return errors.New("read error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readMessage gets a message from the tcp connection.
|
||||
func readMessage(conn *net.TCPConn, data *bytes.Buffer) (
|
||||
length int, err error) {
|
||||
|
||||
if err = read(conn, 4, data); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
length = int(bytes2int(data.Next(4)))
|
||||
if length == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err = read(conn, length, data); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// sendMessage sends data to the connection.
|
||||
func sendMessage(conn *net.TCPConn, data []byte) error {
|
||||
length := int32(len(data))
|
||||
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
binary.Write(buffer, binary.BigEndian, length)
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
_, err := conn.Write(append(buffer.Bytes(), data...))
|
||||
return err
|
||||
}
|
||||
|
||||
// sendHandshake sends handshake message to conn.
|
||||
func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error {
|
||||
data := make([]byte, handshakePrefixLength+nodeIDLength+len(peerID))
|
||||
copy(data[:handshakePrefixLength], handshakePrefix)
|
||||
copy(data[handshakePrefixLength:handshakePrefixLength+nodeIDLength], infoHash)
|
||||
copy(data[handshakePrefixLength+nodeIDLength:], peerID)
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
_, err := conn.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
// onHandshake handles the handshake response.
|
||||
func onHandshake(data []byte) (err error) {
|
||||
if !(bytes.Equal(handshakePrefix[:handshakePrefixLength], data[:handshakePrefixLength]) && data[25]&0x10 != 0) {
|
||||
err = errors.New("invalid handshake response")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// sendExtHandshake requests for the ut_metadata and metadata_size.
|
||||
func sendExtHandshake(conn *net.TCPConn) error {
|
||||
data := append(
|
||||
[]byte{EXTENDED, HANDSHAKE},
|
||||
Encode(map[string]interface{}{
|
||||
"m": map[string]interface{}{"ut_metadata": 1},
|
||||
})...,
|
||||
)
|
||||
|
||||
return sendMessage(conn, data)
|
||||
}
|
||||
|
||||
// getUTMetaSize returns the ut_metadata and metadata_size.
|
||||
func getUTMetaSize(data []byte) (
|
||||
utMetadata int, metadataSize int, err error) {
|
||||
|
||||
v, err := Decode(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dict, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
err = errors.New("invalid dict")
|
||||
return
|
||||
}
|
||||
|
||||
if err = parseKeys(
|
||||
dict, [][]string{{"metadata_size", "int"}, {"m", "map"}}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
m := dict["m"].(map[string]interface{})
|
||||
if err = parseKey(m, "ut_metadata", "int"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
utMetadata = m["ut_metadata"].(int)
|
||||
metadataSize = dict["metadata_size"].(int)
|
||||
|
||||
if metadataSize > MaxMetadataSize {
|
||||
err = errors.New("metadata_size too long")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Request represents the request context.
|
||||
type Request struct {
|
||||
InfoHash []byte
|
||||
IP string
|
||||
Port int
|
||||
}
|
||||
|
||||
// Response contains the request context and the metadata info.
|
||||
type Response struct {
|
||||
Request
|
||||
MetadataInfo []byte
|
||||
}
|
||||
|
||||
// Wire represents the wire protocol.
|
||||
type Wire struct {
|
||||
blackList *blackList
|
||||
queue *syncedMap
|
||||
requests chan Request
|
||||
responses chan Response
|
||||
workerTokens chan struct{}
|
||||
}
|
||||
|
||||
// NewWire returns a Wire pointer.
|
||||
// - blackListSize: the blacklist size
|
||||
// - requestQueueSize: the max requests it can buffers
|
||||
// - workerQueueSize: the max goroutine downloading workers
|
||||
func NewWire(blackListSize, requestQueueSize, workerQueueSize int) *Wire {
|
||||
return &Wire{
|
||||
blackList: newBlackList(blackListSize),
|
||||
queue: newSyncedMap(),
|
||||
requests: make(chan Request, requestQueueSize),
|
||||
responses: make(chan Response, 1024),
|
||||
workerTokens: make(chan struct{}, workerQueueSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Request pushes the request to the queue.
|
||||
func (wire *Wire) Request(infoHash []byte, ip string, port int) {
|
||||
wire.requests <- Request{InfoHash: infoHash, IP: ip, Port: port}
|
||||
}
|
||||
|
||||
// Response returns a chan of Response.
|
||||
func (wire *Wire) Response() <-chan Response {
|
||||
return wire.responses
|
||||
}
|
||||
|
||||
// isDone returns whether the wire get all pieces of the metadata info.
|
||||
func (wire *Wire) isDone(pieces [][]byte) bool {
|
||||
for _, piece := range pieces {
|
||||
if len(piece) == 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (wire *Wire) requestPieces(
|
||||
conn *net.TCPConn, utMetadata int, metadataSize int, piecesNum int) {
|
||||
|
||||
buffer := make([]byte, 1024)
|
||||
for i := 0; i < piecesNum; i++ {
|
||||
buffer[0] = EXTENDED
|
||||
buffer[1] = byte(utMetadata)
|
||||
|
||||
msg := Encode(map[string]interface{}{
|
||||
"msg_type": REQUEST,
|
||||
"piece": i,
|
||||
})
|
||||
|
||||
length := len(msg) + 2
|
||||
copy(buffer[2:length], msg)
|
||||
|
||||
sendMessage(conn, buffer[:length])
|
||||
}
|
||||
buffer = nil
|
||||
}
|
||||
|
||||
// fetchMetadata fetchs medata info accroding to infohash from dht.
|
||||
func (wire *Wire) fetchMetadata(r Request) {
|
||||
var (
|
||||
length int
|
||||
msgType byte
|
||||
piecesNum int
|
||||
pieces [][]byte
|
||||
utMetadata int
|
||||
metadataSize int
|
||||
)
|
||||
|
||||
defer func() {
|
||||
pieces = nil
|
||||
recover()
|
||||
}()
|
||||
|
||||
infoHash := r.InfoHash
|
||||
address := genAddress(r.IP, r.Port)
|
||||
|
||||
dial, err := net.DialTimeout("tcp", address, time.Second*15)
|
||||
if err != nil {
|
||||
wire.blackList.insert(r.IP, r.Port)
|
||||
return
|
||||
}
|
||||
conn := dial.(*net.TCPConn)
|
||||
conn.SetLinger(0)
|
||||
defer conn.Close()
|
||||
|
||||
data := bytes.NewBuffer(nil)
|
||||
data.Grow(BLOCK)
|
||||
|
||||
if sendHandshake(conn, infoHash, []byte(randomString(nodeIDLength))) != nil ||
|
||||
read(conn, 68, data) != nil ||
|
||||
onHandshake(data.Next(68)) != nil ||
|
||||
sendExtHandshake(conn) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
length, err = readMessage(conn, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
msgType, err = data.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
case EXTENDED:
|
||||
extendedID, err := data.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := ioutil.ReadAll(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if extendedID == 0 {
|
||||
if pieces != nil {
|
||||
return
|
||||
}
|
||||
|
||||
utMetadata, metadataSize, err = getUTMetaSize(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
piecesNum = metadataSize / BLOCK
|
||||
if metadataSize%BLOCK != 0 {
|
||||
piecesNum++
|
||||
}
|
||||
|
||||
pieces = make([][]byte, piecesNum)
|
||||
go wire.requestPieces(conn, utMetadata, metadataSize, piecesNum)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if pieces == nil {
|
||||
return
|
||||
}
|
||||
|
||||
d, index, err := DecodeDict(payload, 0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dict := d.(map[string]interface{})
|
||||
|
||||
if err = parseKeys(dict, [][]string{
|
||||
{"msg_type", "int"},
|
||||
{"piece", "int"}}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if dict["msg_type"].(int) != DATA {
|
||||
continue
|
||||
}
|
||||
|
||||
piece := dict["piece"].(int)
|
||||
pieceLen := length - 2 - index
|
||||
|
||||
if (piece != piecesNum-1 && pieceLen != BLOCK) ||
|
||||
(piece == piecesNum-1 && pieceLen != metadataSize%BLOCK) {
|
||||
return
|
||||
}
|
||||
|
||||
pieces[piece] = payload[index:]
|
||||
|
||||
if wire.isDone(pieces) {
|
||||
metadataInfo := bytes.Join(pieces, nil)
|
||||
|
||||
info := sha1.Sum(metadataInfo)
|
||||
if !bytes.Equal(infoHash, info[:]) {
|
||||
return
|
||||
}
|
||||
|
||||
wire.responses <- Response{
|
||||
Request: r,
|
||||
MetadataInfo: metadataInfo,
|
||||
}
|
||||
return
|
||||
}
|
||||
default:
|
||||
data.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the peer wire protocol.
|
||||
func (wire *Wire) Run() {
|
||||
go wire.blackList.clear()
|
||||
|
||||
for r := range wire.requests {
|
||||
wire.workerTokens <- struct{}{}
|
||||
|
||||
go func(r Request) {
|
||||
defer func() {
|
||||
<-wire.workerTokens
|
||||
}()
|
||||
|
||||
key := strings.Join([]string{
|
||||
string(r.InfoHash), genAddress(r.IP, r.Port),
|
||||
}, ":")
|
||||
|
||||
if len(r.InfoHash) != nodeIDLength || wire.blackList.in(r.IP, r.Port) ||
|
||||
wire.queue.Has(key) {
|
||||
return
|
||||
}
|
||||
|
||||
wire.fetchMetadata(r)
|
||||
}(r)
|
||||
}
|
||||
}
|
|
@ -11,7 +11,7 @@ import (
|
|||
|
||||
// maxPrefixLength is the length of DHT node.
|
||||
const maxPrefixLength = 160
|
||||
const nodeIDLength = 20
|
||||
const nodeIDLength = 48
|
||||
const compactNodeInfoLength = nodeIDLength + 6
|
||||
|
||||
// node represents a DHT node.
|
||||
|
@ -359,11 +359,6 @@ func (rt *routingTable) Insert(nd *node) bool {
|
|||
rt.Lock()
|
||||
defer rt.Unlock()
|
||||
|
||||
if rt.dht.blackList.in(nd.addr.IP.String(), nd.addr.Port) ||
|
||||
rt.cachedNodes.Len() >= rt.dht.MaxNodes {
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
next *routingTableNode
|
||||
bucket *kbucket
|
||||
|
|
34
main.go
Normal file
34
main.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/lbry.go/dht"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func main() {
|
||||
config := dht.NewStandardConfig()
|
||||
config.Address = ":49449" // dont pollute real port
|
||||
config.PrimeNodes = []string{
|
||||
"127.0.0.1:10001",
|
||||
}
|
||||
|
||||
d := dht.New(config)
|
||||
log.Info("Starting...")
|
||||
go d.Run()
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
for {
|
||||
peers, err := d.GetPeers("012b66fc7052d9a0c8cb563b8ede7662003ba65f425c2661b5c6919d445deeb31469be8b842d6faeea3f2b3ebcaec845")
|
||||
if err != nil {
|
||||
time.Sleep(time.Second * 1)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Println("Found peers:", peers)
|
||||
break
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue