remove blacklist, peerwire, and custom bencode lib. successfully receive a request and send a response

This commit is contained in:
Alex Grintsvayg 2017-08-16 17:34:53 -04:00
parent fa8a4a59bc
commit c0290497de
9 changed files with 263 additions and 1381 deletions

View file

@ -1,257 +0,0 @@
package dht
import (
"bytes"
"errors"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// find returns the index of first target in data starting from `start`.
// It returns -1 if target not found.
func find(data []byte, start int, target rune) (index int) {
index = bytes.IndexRune(data[start:], target)
if index != -1 {
return index + start
}
return index
}
// DecodeString decodes a string in the data. It returns a tuple
// (decoded result, the end position, error).
func DecodeString(data []byte, start int) (result interface{}, index int, err error) {
if start >= len(data) || data[start] < '0' || data[start] > '9' {
err = errors.New("invalid string bencode")
return
}
i := find(data, start, ':')
if i == -1 {
err = errors.New("':' not found when decode string")
return
}
length, err := strconv.Atoi(string(data[start:i]))
if err != nil {
return
}
if length < 0 {
err = errors.New("invalid length of string")
return
}
index = i + 1 + length
if index > len(data) || index < i+1 {
err = errors.New("out of range")
return
}
result = string(data[i+1 : index])
return
}
// DecodeInt decodes int value in the data.
func DecodeInt(data []byte, start int) (result interface{}, index int, err error) {
if start >= len(data) || data[start] != 'i' {
err = errors.New("invalid int bencode")
return
}
index = find(data, start+1, 'e')
if index == -1 {
err = errors.New("':' not found when decode int")
return
}
result, err = strconv.Atoi(string(data[start+1 : index]))
if err != nil {
return
}
index++
return
}
// decodeItem decodes an item of dict or list.
func decodeItem(data []byte, i int) (result interface{}, index int, err error) {
var decodeFunc = []func([]byte, int) (interface{}, int, error){
DecodeString, DecodeInt, DecodeList, DecodeDict,
}
for _, f := range decodeFunc {
result, index, err = f(data, i)
if err == nil {
return
}
}
err = errors.New("invalid bencode when decode item")
return
}
// DecodeList decodes a list value.
func DecodeList(data []byte, start int) (result interface{}, index int, err error) {
if start >= len(data) || data[start] != 'l' {
err = errors.New("invalid list bencode")
return
}
var item interface{}
r := make([]interface{}, 0, 8)
index = start + 1
for index < len(data) {
char, _ := utf8.DecodeRune(data[index:])
if char == 'e' {
break
}
item, index, err = decodeItem(data, index)
if err != nil {
return
}
r = append(r, item)
}
if index == len(data) {
err = errors.New("'e' not found when decode list")
return
}
index++
result = r
return
}
// DecodeDict decodes a map value.
func DecodeDict(data []byte, start int) (result interface{}, index int, err error) {
if start >= len(data) || data[start] != 'd' {
err = errors.New("invalid dict bencode")
return
}
var item, key interface{}
r := make(map[string]interface{})
index = start + 1
for index < len(data) {
char, _ := utf8.DecodeRune(data[index:])
if char == 'e' {
break
}
if !unicode.IsDigit(char) {
err = errors.New("invalid dict bencode")
return
}
key, index, err = DecodeString(data, index)
if err != nil {
return
}
if index >= len(data) {
err = errors.New("out of range")
return
}
item, index, err = decodeItem(data, index)
if err != nil {
return
}
r[key.(string)] = item
}
if index == len(data) {
err = errors.New("'e' not found when decode dict")
return
}
index++
result = r
return
}
// Decode decodes a bencoded string to string, int, list or map.
func Decode(data []byte) (result interface{}, err error) {
result, _, err = decodeItem(data, 0)
return
}
// EncodeString encodes a string value.
func EncodeString(data string) string {
return strings.Join([]string{strconv.Itoa(len(data)), data}, ":")
}
// EncodeInt encodes a int value.
func EncodeInt(data int) string {
return strings.Join([]string{"i", strconv.Itoa(data), "e"}, "")
}
// EncodeItem encodes an item of dict or list.
func encodeItem(data interface{}) (item string) {
switch v := data.(type) {
case string:
item = EncodeString(v)
case int:
item = EncodeInt(v)
case []interface{}:
item = EncodeList(v)
case map[string]interface{}:
item = EncodeDict(v)
default:
panic("invalid type when encode item")
}
return
}
// EncodeList encodes a list value.
func EncodeList(data []interface{}) string {
result := make([]string, len(data))
for i, item := range data {
result[i] = encodeItem(item)
}
return strings.Join([]string{"l", strings.Join(result, ""), "e"}, "")
}
// EncodeDict encodes a dict value.
func EncodeDict(data map[string]interface{}) string {
result, i := make([]string, len(data)), 0
for key, val := range data {
result[i] = strings.Join(
[]string{EncodeString(key), encodeItem(val)},
"")
i++
}
return strings.Join([]string{"d", strings.Join(result, ""), "e"}, "")
}
// Encode encodes a string, int, dict or list value to a bencoded string.
func Encode(data interface{}) string {
switch v := data.(type) {
case string:
return EncodeString(v)
case int:
return EncodeInt(v)
case []interface{}:
return EncodeList(v)
case map[string]interface{}:
return EncodeDict(v)
default:
panic("invalid type when encode")
}
}

View file

@ -1,159 +0,0 @@
package dht
import (
"testing"
)
func TestDecodeString(t *testing.T) {
cases := []struct {
in string
out string
}{
{"0:", ""},
{"1:a", "a"},
{"5:hello", "hello"},
}
for _, c := range cases {
if out, err := Decode([]byte(c.in)); err != nil || out != c.out {
t.Error(err)
}
}
}
func TestDecodeInt(t *testing.T) {
cases := []struct {
in string
out int
}{
{"i123e:", 123},
{"i0e", 0},
{"i-1e", -1},
}
for _, c := range cases {
if out, err := Decode([]byte(c.in)); err != nil || out != c.out {
t.Error(err)
}
}
}
func TestDecodeList(t *testing.T) {
cases := []struct {
in string
out []interface{}
}{
{"li123ei-1ee", []interface{}{123, -1}},
{"l5:helloe", []interface{}{"hello"}},
{"ld5:hello5:worldee", []interface{}{map[string]interface{}{"hello": "world"}}},
{"lli1ei2eee", []interface{}{[]interface{}{1, 2}}},
}
for i, c := range cases {
v, err := Decode([]byte(c.in))
if err != nil {
t.Fail()
}
out := v.([]interface{})
switch i {
case 0, 1:
for j, item := range out {
if item != c.out[j] {
t.Fail()
}
}
case 2:
if len(out) != 1 {
t.Fail()
}
o := out[0].(map[string]interface{})
cout := c.out[0].(map[string]interface{})
for k, v := range o {
if cv, ok := cout[k]; !ok || v != cv {
t.Fail()
}
}
case 3:
if len(out) != 1 {
t.Fail()
}
o := out[0].([]interface{})
cout := c.out[0].([]interface{})
for j, item := range o {
if item != cout[j] {
t.Fail()
}
}
}
}
}
func TestDecodeDict(t *testing.T) {
cases := []struct {
in string
out map[string]interface{}
}{
{"d5:helloi100ee", map[string]interface{}{"hello": 100}},
{"d3:foo3:bare", map[string]interface{}{"foo": "bar"}},
{"d1:ad3:foo3:baree", map[string]interface{}{"a": map[string]interface{}{"foo": "bar"}}},
{"d4:listli1eee", map[string]interface{}{"list": []interface{}{1}}},
}
for i, c := range cases {
v, err := Decode([]byte(c.in))
if err != nil {
t.Fail()
}
out := v.(map[string]interface{})
switch i {
case 0, 1:
for k, v := range out {
if cv, ok := c.out[k]; !ok || v != cv {
t.Fail()
}
}
case 2:
if len(out) != 1 {
t.Fail()
}
v, ok := out["a"]
if !ok {
t.Fail()
}
cout := c.out["a"].(map[string]interface{})
for k, v := range v.(map[string]interface{}) {
if cv, ok := cout[k]; !ok || v != cv {
t.Fail()
}
}
case 3:
if len(out) != 1 {
t.Fail()
}
v, ok := out["list"]
if !ok {
t.Fail()
}
cout := c.out["list"].([]interface{})
for j, v := range v.([]interface{}) {
if v != cout[j] {
t.Fail()
}
}
}
}
}

View file

@ -1,92 +0,0 @@
package dht
import (
"time"
)
// blockedItem represents a blocked node.
type blockedItem struct {
ip string
port int
createTime time.Time
}
// blackList manages the blocked nodes including which sends bad information
// and can't ping out.
type blackList struct {
list *syncedMap
maxSize int
expiredAfter time.Duration
}
// newBlackList returns a blackList pointer.
func newBlackList(size int) *blackList {
return &blackList{
list: newSyncedMap(),
maxSize: size,
expiredAfter: time.Hour * 1,
}
}
// genKey returns a key. If port is less than 0, the key wil be ip. Ohterwise
// it will be `ip:port` format.
func (bl *blackList) genKey(ip string, port int) string {
key := ip
if port >= 0 {
key = genAddress(ip, port)
}
return key
}
// insert adds a blocked item to the blacklist.
func (bl *blackList) insert(ip string, port int) {
if bl.list.Len() >= bl.maxSize {
return
}
bl.list.Set(bl.genKey(ip, port), &blockedItem{
ip: ip,
port: port,
createTime: time.Now(),
})
}
// delete removes blocked item form the blackList.
func (bl *blackList) delete(ip string, port int) {
bl.list.Delete(bl.genKey(ip, port))
}
// validate checks whether ip-port pair is in the block nodes list.
func (bl *blackList) in(ip string, port int) bool {
if _, ok := bl.list.Get(ip); ok {
return true
}
key := bl.genKey(ip, port)
v, ok := bl.list.Get(key)
if ok {
if time.Now().Sub(v.(*blockedItem).createTime) < bl.expiredAfter {
return true
}
bl.list.Delete(key)
}
return false
}
// clear cleans the expired items every 10 minutes.
func (bl *blackList) clear() {
for _ = range time.Tick(time.Minute * 10) {
keys := make([]interface{}, 0, 100)
for item := range bl.list.Iter() {
if time.Now().Sub(
item.val.(*blockedItem).createTime) > bl.expiredAfter {
keys = append(keys, item.key)
}
}
bl.list.DeleteMulti(keys)
}
}

View file

@ -1,57 +0,0 @@
package dht
import (
"fmt"
"testing"
)
var blacklist = newBlackList(256)
func TestGenKey(t *testing.T) {
cases := []struct {
in struct {
ip string
port int
}
out string
}{
{struct {
ip string
port int
}{"0.0.0.0", -1}, "0.0.0.0"},
{struct {
ip string
port int
}{"1.1.1.1", 8080}, "1.1.1.1:8080"},
}
for _, c := range cases {
if blacklist.genKey(c.in.ip, c.in.port) != c.out {
t.Fail()
}
}
}
func TestBlackList(t *testing.T) {
address := []struct {
ip string
port int
}{
{"0.0.0.0", -1},
{"1.1.1.1", 8080},
{"2.2.2.2", 8081},
}
for _, addr := range address {
blacklist.insert(addr.ip, addr.port)
if !blacklist.in(addr.ip, addr.port) {
t.Fail()
}
blacklist.delete(addr.ip, addr.port)
if blacklist.in(addr.ip, addr.port) {
fmt.Println(addr.ip)
t.Fail()
}
}
}

View file

@ -15,9 +15,6 @@ import (
type Config struct {
// in mainline dht, k = 8
K int
// for crawling mode, we put all nodes in one bucket, so KBucketSize may
// not be K
KBucketSize int
// candidates are udp, udp4, udp6
Network string
// format is `ip:port`
@ -40,10 +37,6 @@ type Config struct {
OnGetPeers func(string, string, int)
// callback when got announce_peer request
OnAnnouncePeer func(string, string, int)
// blcoked ips
BlockedIPs []string
// blacklist size
BlackListMaxSize int
// the times it tries when send fails
Try int
// the size of packet need to be dealt with
@ -57,14 +50,13 @@ type Config struct {
// NewStandardConfig returns a Config pointer with default values.
func NewStandardConfig() *Config {
return &Config{
K: 8,
KBucketSize: 8,
Network: "udp4",
Address: ":6881",
K: 8,
Network: "udp4",
Address: ":4444",
PrimeNodes: []string{
"router.bittorrent.com:6881",
"router.utorrent.com:6881",
"dht.transmissionbt.com:6881",
"lbrynet1.lbry.io:4444",
"lbrynet2.lbry.io:4444",
"lbrynet3.lbry.io:4444",
},
NodeExpriedAfter: time.Duration(time.Minute * 15),
KBucketExpiredAfter: time.Duration(time.Minute * 15),
@ -72,8 +64,6 @@ func NewStandardConfig() *Config {
TokenExpiredAfter: time.Duration(time.Minute * 10),
MaxTransactionCursor: math.MaxUint32,
MaxNodes: 5000,
BlockedIPs: make([]string, 0),
BlackListMaxSize: 65536,
Try: 2,
PacketJobLimit: 1024,
PacketWorkerLimit: 256,
@ -90,7 +80,6 @@ type DHT struct {
transactionManager *transactionManager
peersManager *peersManager
tokenManager *tokenManager
blackList *blackList
Ready bool
packets chan packet
workerTokens chan struct{}
@ -111,26 +100,10 @@ func New(config *Config) *DHT {
d := &DHT{
Config: config,
node: node,
blackList: newBlackList(config.BlackListMaxSize),
packets: make(chan packet, config.PacketJobLimit),
workerTokens: make(chan struct{}, config.PacketWorkerLimit),
}
for _, ip := range config.BlockedIPs {
d.blackList.insert(ip, -1)
}
go func() {
for _, ip := range getLocalIPs() {
d.blackList.insert(ip, -1)
}
ip, err := getRemoteIP()
if err != nil {
d.blackList.insert(ip, -1)
}
}()
return d
}
@ -143,15 +116,13 @@ func (dht *DHT) init() {
}
dht.conn = listener.(*net.UDPConn)
dht.routingTable = newRoutingTable(dht.KBucketSize, dht)
dht.routingTable = newRoutingTable(dht.K, dht)
dht.peersManager = newPeersManager(dht)
dht.tokenManager = newTokenManager(dht.TokenExpiredAfter, dht)
dht.transactionManager = newTransactionManager(
dht.MaxTransactionCursor, dht)
dht.transactionManager = newTransactionManager(dht.MaxTransactionCursor, dht)
go dht.transactionManager.run()
go dht.tokenManager.clear()
go dht.blackList.clear()
}
// join makes current node join the dht network.
@ -162,9 +133,8 @@ func (dht *DHT) join() {
continue
}
// NOTE: Temporary node has NOT node id.
dht.transactionManager.findNode(
&node{addr: raddr},
&node{id: dht.node.id, addr: raddr},
dht.node.id.RawString(),
)
}
@ -179,7 +149,6 @@ func (dht *DHT) listen() {
if err != nil {
continue
}
log.Infof("Received %s", buff)
dht.packets <- packet{buff[:n], raddr}
}
@ -214,12 +183,12 @@ func (dht *DHT) GetPeers(infoHash string) ([]*Peer, error) {
ch := make(chan struct{})
go func() {
neighbors := dht.routingTable.GetNeighbors(
newBitmapFromString(infoHash), dht.K)
for _, no := range neighbors {
dht.transactionManager.getPeers(no, infoHash)
}
//neighbors := dht.routingTable.GetNeighbors(
// newBitmapFromString(infoHash), dht.K)
//
//for _, no := range neighbors {
// dht.transactionManager.getPeers(no, infoHash)
//}
i := 0
for range time.Tick(time.Second * 1) {

View file

@ -1,20 +1,23 @@
package dht
import (
"errors"
"fmt"
"github.com/davecgh/go-spew/spew"
log "github.com/sirupsen/logrus"
"github.com/spf13/cast"
"github.com/zeebo/bencode"
"net"
"reflect"
"strings"
"sync"
"time"
)
const (
pingType = "ping"
findNodeType = "find_node"
getPeersType = "get_peers"
announcePeerType = "announce_peer"
pingMethod = "ping"
storeMethod = "store"
findNodeMethod = "findNode"
findValueMethod = "findValue"
)
const (
@ -24,6 +27,78 @@ const (
unknownError
)
const (
requestType = 0
responseType = 1
errorType = 2
)
const (
// these are strings because bencode requires bytestring keys
headerTypeField = "0"
headerMessageIDField = "1"
headerNodeIDField = "2"
headerPayloadField = "3"
headerArgsField = "4"
)
type Message interface {
GetID() string
Encode() ([]byte, error)
}
type Request struct {
ID string
NodeID string
Method string
Args []string
}
func (r Request) GetID() string { return r.ID }
func (r Request) Encode() ([]byte, error) {
return bencode.EncodeBytes(map[string]interface{}{
headerTypeField: requestType,
headerMessageIDField: r.ID,
headerNodeIDField: r.NodeID,
headerPayloadField: r.Method,
headerArgsField: r.Args,
})
}
type Response struct {
ID string
NodeID string
Response string
}
func (r Response) GetID() string { return r.ID }
func (r Response) Encode() ([]byte, error) {
return bencode.EncodeBytes(map[string]interface{}{
headerTypeField: responseType,
headerMessageIDField: r.ID,
headerNodeIDField: r.NodeID,
headerPayloadField: r.Response,
})
}
type Error struct {
ID string
NodeID string
Response []string
ExceptionType string
}
func (e Error) GetID() string { return e.ID }
func (e Error) Encode() ([]byte, error) {
return bencode.EncodeBytes(map[string]interface{}{
headerTypeField: errorType,
headerMessageIDField: e.ID,
headerNodeIDField: e.NodeID,
headerPayloadField: e.ExceptionType,
headerArgsField: e.Response,
})
}
// packet represents the information receive from udp.
type packet struct {
data []byte
@ -108,39 +183,25 @@ func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} {
}
}
// makeResponse returns a response-formed data.
func makeResponse(t string, r map[string]interface{}) map[string]interface{} {
return map[string]interface{}{
"t": t,
"y": "r",
"r": r,
}
}
// makeError returns a err-formed data.
func makeError(t string, errCode int, errMsg string) map[string]interface{} {
return map[string]interface{}{
"t": t,
"y": "e",
"e": []interface{}{errCode, errMsg},
}
}
// send sends data to the udp.
func send(dht *DHT, addr *net.UDPAddr, data map[string]interface{}) error {
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
log.Infof("Sending %s", spew.Sdump(data))
encoded, err := data.Encode()
if err != nil {
return err
}
log.Infof("Encoded: %s", string(encoded))
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
_, err := dht.conn.WriteToUDP([]byte(Encode(data)), addr)
if err != nil {
dht.blackList.insert(addr.IP.String(), -1)
}
_, err = dht.conn.WriteToUDP(encoded, addr)
return err
}
// query represents the query data included queried node and query-formed data.
type query struct {
node *node
data map[string]interface{}
node *node
request Request
}
// transaction implements transaction.
@ -199,7 +260,7 @@ func (tm *transactionManager) genIndexKey(queryType, address string) string {
// genIndexKeyByTrans generates an indexed key by a transaction.
func (tm *transactionManager) genIndexKeyByTrans(trans *transaction) string {
return tm.genIndexKey(trans.data["q"].(string), trans.node.addr.String())
return tm.genIndexKey(trans.request.Method, trans.node.addr.String())
}
// insert adds a transaction to transactionManager.
@ -233,8 +294,7 @@ func (tm *transactionManager) len() int {
// transaction returns a transaction. keyType should be one of 0, 1 which
// represents transId and index each.
func (tm *transactionManager) transaction(
key string, keyType int) *transaction {
func (tm *transactionManager) transaction(key string, keyType int) *transaction {
sm := tm.transactions
if keyType == 1 {
@ -261,14 +321,11 @@ func (tm *transactionManager) getByIndex(index string) *transaction {
// transaction gets the proper transaction with whose id is transId and
// address is addr.
func (tm *transactionManager) filterOne(
transID string, addr *net.UDPAddr) *transaction {
func (tm *transactionManager) filterOne(transID string, addr *net.UDPAddr) *transaction {
trans := tm.getByTransID(transID)
if trans == nil || trans.node.addr.String() != addr.String() {
return nil
}
return trans
}
@ -276,15 +333,14 @@ func (tm *transactionManager) filterOne(
// When timeout, it will retry `try - 1` times, which means it will query
// `try` times totally.
func (tm *transactionManager) query(q *query, try int) {
transID := q.data["t"].(string)
trans := tm.newTransaction(transID, q)
trans := tm.newTransaction(q.request.ID, q)
tm.insert(trans)
defer tm.delete(trans.id)
success := false
for i := 0; i < try; i++ {
if err := send(tm.dht, q.node.addr, q.data); err != nil {
if err := send(tm.dht, q.node.addr, q.request); err != nil {
break
}
@ -297,7 +353,6 @@ func (tm *transactionManager) query(q *query, try int) {
}
if !success && q.node.id != nil {
tm.dht.blackList.insert(q.node.addr.IP.String(), q.node.addr.Port)
tm.dht.routingTable.RemoveByAddr(q.node.addr.String())
}
}
@ -315,166 +370,130 @@ func (tm *transactionManager) run() {
}
// sendQuery send query-formed data to the chan.
func (tm *transactionManager) sendQuery(no *node, queryType string, a map[string]interface{}) {
func (tm *transactionManager) sendQuery(no *node, request Request) {
// If the target is self, then stop.
if no.id != nil && no.id.RawString() == tm.dht.node.id.RawString() ||
tm.getByIndex(tm.genIndexKey(queryType, no.addr.String())) != nil ||
tm.dht.blackList.in(no.addr.IP.String(), no.addr.Port) {
tm.getByIndex(tm.genIndexKey(request.Method, no.addr.String())) != nil {
return
}
data := makeQuery(tm.genTransID(), queryType, a)
tm.queryChan <- &query{
node: no,
data: data,
}
request.ID = tm.genTransID()
request.NodeID = tm.dht.id(no.id.RawString())
tm.queryChan <- &query{node: no, request: request}
}
// ping sends ping query to the chan.
func (tm *transactionManager) ping(no *node) {
tm.sendQuery(no, pingType, map[string]interface{}{
"id": tm.dht.id(no.id.RawString()),
})
tm.sendQuery(no, Request{Method: pingMethod})
}
// findNode sends find_node query to the chan.
func (tm *transactionManager) findNode(no *node, target string) {
tm.sendQuery(no, findNodeType, map[string]interface{}{
"id": tm.dht.id(target),
"target": target,
})
tm.sendQuery(no, Request{Method: findNodeMethod, Args: []string{target}})
}
// getPeers sends get_peers query to the chan.
func (tm *transactionManager) getPeers(no *node, infoHash string) {
tm.sendQuery(no, getPeersType, map[string]interface{}{
"id": tm.dht.id(infoHash),
"info_hash": infoHash,
})
}
// announcePeer sends announce_peer query to the chan.
func (tm *transactionManager) announcePeer(no *node, infoHash string, impliedPort, port int, token string) {
tm.sendQuery(no, announcePeerType, map[string]interface{}{
"id": tm.dht.id(no.id.RawString()),
"info_hash": infoHash,
"implied_port": impliedPort,
"port": port,
"token": token,
})
}
// parseKey parses the key in dict data. `t` is type of the keyed value.
// It's one of "int", "string", "map", "list".
func parseKey(data map[string]interface{}, key string, t string) error {
val, ok := data[key]
if !ok {
return errors.New("lack of key")
}
switch t {
case "string":
_, ok = val.(string)
case "int":
_, ok = val.(int)
case "map":
_, ok = val.(map[string]interface{})
case "list":
_, ok = val.([]interface{})
default:
panic("invalid type")
}
if !ok {
return errors.New("invalid key type")
}
return nil
}
// parseKeys parses keys. It just wraps parseKey.
func parseKeys(data map[string]interface{}, pairs [][]string) error {
for _, args := range pairs {
key, t := args[0], args[1]
if err := parseKey(data, key, t); err != nil {
return err
}
}
return nil
}
// parseMessage parses the basic data received from udp.
// It returns a map value.
func parseMessage(data interface{}) (map[string]interface{}, error) {
response, ok := data.(map[string]interface{})
if !ok {
return nil, errors.New("response is not dict")
}
if err := parseKeys(
response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil {
return nil, err
}
return response, nil
}
// handleRequest handles the requests received from udp.
func handleRequest(dht *DHT, addr *net.UDPAddr, response map[string]interface{}) (success bool) {
t := response["t"].(string)
if err := parseKeys(
response, [][]string{{"q", "string"}, {"a", "map"}}); err != nil {
send(dht, addr, makeError(t, protocolError, err.Error()))
// handle handles packets received from udp.
func handle(dht *DHT, pkt packet) {
log.Infof("Received message from %s: %s", pkt.raddr.IP.String(), string(pkt.data))
if len(dht.workerTokens) == dht.PacketWorkerLimit {
return
}
q := response["q"].(string)
a := response["a"].(map[string]interface{})
dht.workerTokens <- struct{}{}
if err := parseKey(a, "id", "string"); err != nil {
send(dht, addr, makeError(t, protocolError, err.Error()))
return
}
go func() {
defer func() {
<-dht.workerTokens
}()
id := a["id"].(string)
if id == dht.node.id.RawString() {
return
}
if len(id) != nodeIDLength {
send(dht, addr, makeError(t, protocolError, "invalid id"))
return
}
if no, ok := dht.routingTable.GetNodeByAddress(addr.String()); ok &&
no.id.RawString() != id {
dht.blackList.insert(addr.IP.String(), addr.Port)
dht.routingTable.RemoveByAddr(addr.String())
send(dht, addr, makeError(t, protocolError, "invalid id"))
return
}
switch q {
case pingType:
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(id),
}))
case findNodeType:
if err := parseKey(a, "target", "string"); err != nil {
send(dht, addr, makeError(t, protocolError, err.Error()))
var data map[string]interface{}
err := bencode.DecodeBytes(pkt.data, &data)
if err != nil {
log.Errorf("Error decoding data: %s\n%s", err, pkt.data)
return
}
target := a["target"].(string)
msgType, ok := data[headerTypeField]
if !ok {
log.Errorf("Decoded data has no message type: %s", data)
return
}
switch msgType.(int64) {
case requestType:
request := Request{
ID: data[headerMessageIDField].(string),
NodeID: data[headerNodeIDField].(string),
Method: data[headerPayloadField].(string),
Args: getArgs(data[headerArgsField]),
}
spew.Dump(request)
handleRequest(dht, pkt.raddr, request)
case responseType:
response := Response{
ID: data[headerMessageIDField].(string),
NodeID: data[headerNodeIDField].(string),
Response: data[headerPayloadField].(string),
}
handleResponse(dht, pkt.raddr, response)
case errorType:
e := Error{
ID: data[headerMessageIDField].(string),
NodeID: data[headerNodeIDField].(string),
ExceptionType: data[headerPayloadField].(string),
Response: getArgs(data[headerArgsField]),
}
handleError(dht, pkt.raddr, e)
default:
log.Errorf("Invalid message type: %s", msgType)
return
}
}()
}
func getArgs(argsInt interface{}) (args []string) {
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
v := reflect.ValueOf(argsInt)
for i := 0; i < v.Len(); i++ {
args = append(args, cast.ToString(v.Index(i).Interface()))
}
}
return
}
// handleRequest handles the requests received from udp.
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) {
if request.NodeID == dht.node.id.RawString() {
return
}
if len(request.NodeID) != nodeIDLength {
send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"Invalid ID"}})
return
}
if no, ok := dht.routingTable.GetNodeByAddress(addr.String()); ok && no.id.RawString() != request.NodeID {
dht.routingTable.RemoveByAddr(addr.String())
send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"Invalid ID"}})
return
}
switch request.Method {
case pingMethod:
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Response: "pong"})
case findNodeMethod:
if len(request.Args) < 1 {
send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"No target"}})
return
}
target := request.Args[0]
if len(target) != nodeIDLength {
send(dht, addr, makeError(t, protocolError, "invalid target"))
send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"Invalid target"}})
return
}
@ -485,94 +504,17 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, response map[string]interface{})
if no != nil {
nodes = no.CompactNodeInfo()
} else {
nodes = strings.Join(
dht.routingTable.GetNeighborCompactInfos(targetID, dht.K),
"",
)
nodes = strings.Join(dht.routingTable.GetNeighborCompactInfos(targetID, dht.K), "")
}
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(target),
"nodes": nodes,
}))
case getPeersType:
if err := parseKey(a, "info_hash", "string"); err != nil {
send(dht, addr, makeError(t, protocolError, err.Error()))
return
}
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Response: nodes})
infoHash := a["info_hash"].(string)
if len(infoHash) != nodeIDLength {
send(dht, addr, makeError(t, protocolError, "invalid info_hash"))
return
}
if peers := dht.peersManager.GetPeers(
infoHash, dht.K); len(peers) > 0 {
values := make([]interface{}, len(peers))
for i, p := range peers {
values[i] = p.CompactIPPortInfo()
}
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(infoHash),
"values": values,
"token": dht.tokenManager.token(addr),
}))
} else {
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(infoHash),
"token": dht.tokenManager.token(addr),
"nodes": strings.Join(dht.routingTable.GetNeighborCompactInfos(
newBitmapFromString(infoHash), dht.K), ""),
}))
}
if dht.OnGetPeers != nil {
dht.OnGetPeers(infoHash, addr.IP.String(), addr.Port)
}
case announcePeerType:
if err := parseKeys(a, [][]string{
{"info_hash", "string"},
{"port", "int"},
{"token", "string"}}); err != nil {
send(dht, addr, makeError(t, protocolError, err.Error()))
return
}
infoHash := a["info_hash"].(string)
port := a["port"].(int)
token := a["token"].(string)
if !dht.tokenManager.check(addr, token) {
// send(dht, addr, makeError(t, protocolError, "invalid token"))
return
}
if impliedPort, ok := a["implied_port"]; ok &&
impliedPort.(int) != 0 {
port = addr.Port
}
dht.peersManager.Insert(infoHash, newPeer(addr.IP, port, token))
send(dht, addr, makeResponse(t, map[string]interface{}{
"id": dht.id(id),
}))
if dht.OnAnnouncePeer != nil {
dht.OnAnnouncePeer(infoHash, addr.IP.String(), port)
}
default:
// send(dht, addr, makeError(t, protocolError, "invalid q"))
return
}
no, _ := newNode(id, addr.Network(), addr.String())
no, _ := newNode(request.NodeID, addr.Network(), addr.String())
dht.routingTable.Insert(no)
return true
}
@ -580,13 +522,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, response map[string]interface{})
// findOn puts nodes in the response to the routingTable, then if target is in
// the nodes or all nodes are in the routingTable, it stops. Otherwise it
// continues to findNode or getPeers.
func findOn(dht *DHT, r map[string]interface{}, target *bitmap, queryType string) error {
if err := parseKey(r, "nodes", "string"); err != nil {
return err
}
nodes := r["nodes"].(string)
func findOn(dht *DHT, nodes string, target *bitmap, queryType string) error {
if len(nodes)%compactNodeInfoLength != 0 {
return fmt.Errorf("the length of nodes should can be divided by %d", compactNodeInfoLength)
}
@ -612,10 +548,8 @@ func findOn(dht *DHT, r map[string]interface{}, target *bitmap, queryType string
targetID := target.RawString()
for _, no := range dht.routingTable.GetNeighbors(target, dht.K) {
switch queryType {
case findNodeType:
case findNodeMethod:
dht.transactionManager.findNode(no, targetID)
case getPeersType:
dht.transactionManager.getPeers(no, targetID)
default:
panic("invalid find type")
}
@ -624,76 +558,32 @@ func findOn(dht *DHT, r map[string]interface{}, target *bitmap, queryType string
}
// handleResponse handles responses received from udp.
func handleResponse(dht *DHT, addr *net.UDPAddr, response map[string]interface{}) (success bool) {
t := response["t"].(string)
trans := dht.transactionManager.filterOne(t, addr)
func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) (success bool) {
trans := dht.transactionManager.filterOne(response.ID, addr)
if trans == nil {
return
}
// inform transManager to delete the transaction.
if err := parseKey(response, "r", "map"); err != nil {
return
}
q := trans.data["q"].(string)
a := trans.data["a"].(map[string]interface{})
r := response["r"].(map[string]interface{})
if err := parseKey(r, "id", "string"); err != nil {
return
}
id := r["id"].(string)
// If response's node id is not the same with the node id in the
// transaction, raise error.
if trans.node.id != nil && trans.node.id.RawString() != r["id"].(string) {
dht.blackList.insert(addr.IP.String(), addr.Port)
// TODO: is this necessary??? why??
if trans.node.id != nil && trans.node.id.RawString() != response.NodeID {
dht.routingTable.RemoveByAddr(addr.String())
return
}
node, err := newNode(id, addr.Network(), addr.String())
node, err := newNode(response.NodeID, addr.Network(), addr.String())
if err != nil {
return
}
switch q {
case pingType:
case findNodeType:
if trans.data["q"].(string) != findNodeType {
switch trans.request.Method {
case pingMethod:
case findNodeMethod:
target := trans.request.Args[0]
if findOn(dht, response.Response, newBitmapFromString(target), findNodeMethod) != nil {
return
}
target := trans.data["a"].(map[string]interface{})["target"].(string)
if findOn(dht, r, newBitmapFromString(target), findNodeType) != nil {
return
}
case getPeersType:
if err := parseKey(r, "token", "string"); err != nil {
return
}
token := r["token"].(string)
infoHash := a["info_hash"].(string)
if err := parseKey(r, "values", "list"); err == nil {
values := r["values"].([]interface{})
for _, v := range values {
p, err := newPeerFromCompactIPPortInfo(v.(string), token)
if err != nil {
continue
}
dht.peersManager.Insert(infoHash, p)
}
} else if findOn(
dht, r, newBitmapFromString(infoHash), getPeersType) != nil {
return
}
case announcePeerType:
default:
return
}
@ -701,71 +591,16 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response map[string]interface{}
// inform transManager to delete transaction.
trans.response <- struct{}{}
dht.blackList.delete(addr.IP.String(), addr.Port)
dht.routingTable.Insert(node)
return true
}
// handleError handles errors received from udp.
func handleError(dht *DHT, addr *net.UDPAddr, response map[string]interface{}) (success bool) {
if err := parseKey(response, "e", "list"); err != nil {
return
}
if e := response["e"].([]interface{}); len(e) != 2 {
return
}
if trans := dht.transactionManager.filterOne(
response["t"].(string), addr); trans != nil {
func handleError(dht *DHT, addr *net.UDPAddr, e Error) (success bool) {
if trans := dht.transactionManager.filterOne(e.ID, addr); trans != nil {
trans.response <- struct{}{}
}
return true
}
var handlers = map[string]func(*DHT, *net.UDPAddr, map[string]interface{}) bool{
"q": handleRequest,
"r": handleResponse,
"e": handleError,
}
// handle handles packets received from udp.
func handle(dht *DHT, pkt packet) {
log.Infof("Packet from %s: %s", pkt.raddr.IP.String(), pkt.data)
if len(dht.workerTokens) == dht.PacketWorkerLimit {
return
}
dht.workerTokens <- struct{}{}
go func() {
defer func() {
<-dht.workerTokens
}()
if dht.blackList.in(pkt.raddr.IP.String(), pkt.raddr.Port) {
log.Infof("%s blacklisted, ignoring packet", pkt.raddr.IP.String())
return
}
data, err := Decode(pkt.data)
if err != nil {
log.Errorf("Error decoding data: %s\n%s", err, pkt.data)
return
}
response, err := parseMessage(data)
if err != nil {
log.Errorf("Error parsing message: %s", err)
return
}
if f, ok := handlers[response["y"].(string)]; ok {
f(dht, pkt.raddr, response)
}
}()
}

View file

@ -1,386 +0,0 @@
package dht
import (
"bytes"
"crypto/sha1"
"encoding/binary"
"errors"
"io"
"io/ioutil"
"net"
"strings"
"time"
)
const (
// REQUEST represents request message type
REQUEST = iota
// DATA represents data message type
DATA
// REJECT represents reject message type
REJECT
)
const (
// BLOCK is 2 ^ 14
BLOCK = 16384
// MaxMetadataSize represents the max medata it can accept
MaxMetadataSize = BLOCK * 1000
// EXTENDED represents it is a extended message
EXTENDED = 20
// HANDSHAKE represents handshake bit
HANDSHAKE = 0
)
var handshakePrefix = []byte{
19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114,
111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1,
}
var handshakePrefixLength = len(handshakePrefix)
// read reads size-length bytes from conn to data.
func read(conn *net.TCPConn, size int, data *bytes.Buffer) error {
conn.SetReadDeadline(time.Now().Add(time.Second * 15))
n, err := io.CopyN(data, conn, int64(size))
if err != nil || n != int64(size) {
return errors.New("read error")
}
return nil
}
// readMessage gets a message from the tcp connection.
func readMessage(conn *net.TCPConn, data *bytes.Buffer) (
length int, err error) {
if err = read(conn, 4, data); err != nil {
return
}
length = int(bytes2int(data.Next(4)))
if length == 0 {
return
}
if err = read(conn, length, data); err != nil {
return
}
return
}
// sendMessage sends data to the connection.
func sendMessage(conn *net.TCPConn, data []byte) error {
length := int32(len(data))
buffer := bytes.NewBuffer(nil)
binary.Write(buffer, binary.BigEndian, length)
conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
_, err := conn.Write(append(buffer.Bytes(), data...))
return err
}
// sendHandshake sends handshake message to conn.
func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error {
data := make([]byte, handshakePrefixLength+nodeIDLength+len(peerID))
copy(data[:handshakePrefixLength], handshakePrefix)
copy(data[handshakePrefixLength:handshakePrefixLength+nodeIDLength], infoHash)
copy(data[handshakePrefixLength+nodeIDLength:], peerID)
conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
_, err := conn.Write(data)
return err
}
// onHandshake handles the handshake response.
func onHandshake(data []byte) (err error) {
if !(bytes.Equal(handshakePrefix[:handshakePrefixLength], data[:handshakePrefixLength]) && data[25]&0x10 != 0) {
err = errors.New("invalid handshake response")
}
return
}
// sendExtHandshake requests for the ut_metadata and metadata_size.
func sendExtHandshake(conn *net.TCPConn) error {
data := append(
[]byte{EXTENDED, HANDSHAKE},
Encode(map[string]interface{}{
"m": map[string]interface{}{"ut_metadata": 1},
})...,
)
return sendMessage(conn, data)
}
// getUTMetaSize returns the ut_metadata and metadata_size.
func getUTMetaSize(data []byte) (
utMetadata int, metadataSize int, err error) {
v, err := Decode(data)
if err != nil {
return
}
dict, ok := v.(map[string]interface{})
if !ok {
err = errors.New("invalid dict")
return
}
if err = parseKeys(
dict, [][]string{{"metadata_size", "int"}, {"m", "map"}}); err != nil {
return
}
m := dict["m"].(map[string]interface{})
if err = parseKey(m, "ut_metadata", "int"); err != nil {
return
}
utMetadata = m["ut_metadata"].(int)
metadataSize = dict["metadata_size"].(int)
if metadataSize > MaxMetadataSize {
err = errors.New("metadata_size too long")
}
return
}
// Request represents the request context.
type Request struct {
InfoHash []byte
IP string
Port int
}
// Response contains the request context and the metadata info.
type Response struct {
Request
MetadataInfo []byte
}
// Wire represents the wire protocol.
type Wire struct {
blackList *blackList
queue *syncedMap
requests chan Request
responses chan Response
workerTokens chan struct{}
}
// NewWire returns a Wire pointer.
// - blackListSize: the blacklist size
// - requestQueueSize: the max requests it can buffers
// - workerQueueSize: the max goroutine downloading workers
func NewWire(blackListSize, requestQueueSize, workerQueueSize int) *Wire {
return &Wire{
blackList: newBlackList(blackListSize),
queue: newSyncedMap(),
requests: make(chan Request, requestQueueSize),
responses: make(chan Response, 1024),
workerTokens: make(chan struct{}, workerQueueSize),
}
}
// Request pushes the request to the queue.
func (wire *Wire) Request(infoHash []byte, ip string, port int) {
wire.requests <- Request{InfoHash: infoHash, IP: ip, Port: port}
}
// Response returns a chan of Response.
func (wire *Wire) Response() <-chan Response {
return wire.responses
}
// isDone returns whether the wire get all pieces of the metadata info.
func (wire *Wire) isDone(pieces [][]byte) bool {
for _, piece := range pieces {
if len(piece) == 0 {
return false
}
}
return true
}
func (wire *Wire) requestPieces(
conn *net.TCPConn, utMetadata int, metadataSize int, piecesNum int) {
buffer := make([]byte, 1024)
for i := 0; i < piecesNum; i++ {
buffer[0] = EXTENDED
buffer[1] = byte(utMetadata)
msg := Encode(map[string]interface{}{
"msg_type": REQUEST,
"piece": i,
})
length := len(msg) + 2
copy(buffer[2:length], msg)
sendMessage(conn, buffer[:length])
}
buffer = nil
}
// fetchMetadata fetchs medata info accroding to infohash from dht.
func (wire *Wire) fetchMetadata(r Request) {
var (
length int
msgType byte
piecesNum int
pieces [][]byte
utMetadata int
metadataSize int
)
defer func() {
pieces = nil
recover()
}()
infoHash := r.InfoHash
address := genAddress(r.IP, r.Port)
dial, err := net.DialTimeout("tcp", address, time.Second*15)
if err != nil {
wire.blackList.insert(r.IP, r.Port)
return
}
conn := dial.(*net.TCPConn)
conn.SetLinger(0)
defer conn.Close()
data := bytes.NewBuffer(nil)
data.Grow(BLOCK)
if sendHandshake(conn, infoHash, []byte(randomString(nodeIDLength))) != nil ||
read(conn, 68, data) != nil ||
onHandshake(data.Next(68)) != nil ||
sendExtHandshake(conn) != nil {
return
}
for {
length, err = readMessage(conn, data)
if err != nil {
return
}
if length == 0 {
continue
}
msgType, err = data.ReadByte()
if err != nil {
return
}
switch msgType {
case EXTENDED:
extendedID, err := data.ReadByte()
if err != nil {
return
}
payload, err := ioutil.ReadAll(data)
if err != nil {
return
}
if extendedID == 0 {
if pieces != nil {
return
}
utMetadata, metadataSize, err = getUTMetaSize(payload)
if err != nil {
return
}
piecesNum = metadataSize / BLOCK
if metadataSize%BLOCK != 0 {
piecesNum++
}
pieces = make([][]byte, piecesNum)
go wire.requestPieces(conn, utMetadata, metadataSize, piecesNum)
continue
}
if pieces == nil {
return
}
d, index, err := DecodeDict(payload, 0)
if err != nil {
return
}
dict := d.(map[string]interface{})
if err = parseKeys(dict, [][]string{
{"msg_type", "int"},
{"piece", "int"}}); err != nil {
return
}
if dict["msg_type"].(int) != DATA {
continue
}
piece := dict["piece"].(int)
pieceLen := length - 2 - index
if (piece != piecesNum-1 && pieceLen != BLOCK) ||
(piece == piecesNum-1 && pieceLen != metadataSize%BLOCK) {
return
}
pieces[piece] = payload[index:]
if wire.isDone(pieces) {
metadataInfo := bytes.Join(pieces, nil)
info := sha1.Sum(metadataInfo)
if !bytes.Equal(infoHash, info[:]) {
return
}
wire.responses <- Response{
Request: r,
MetadataInfo: metadataInfo,
}
return
}
default:
data.Reset()
}
}
}
// Run starts the peer wire protocol.
func (wire *Wire) Run() {
go wire.blackList.clear()
for r := range wire.requests {
wire.workerTokens <- struct{}{}
go func(r Request) {
defer func() {
<-wire.workerTokens
}()
key := strings.Join([]string{
string(r.InfoHash), genAddress(r.IP, r.Port),
}, ":")
if len(r.InfoHash) != nodeIDLength || wire.blackList.in(r.IP, r.Port) ||
wire.queue.Has(key) {
return
}
wire.fetchMetadata(r)
}(r)
}
}

View file

@ -11,7 +11,7 @@ import (
// maxPrefixLength is the length of DHT node.
const maxPrefixLength = 160
const nodeIDLength = 20
const nodeIDLength = 48
const compactNodeInfoLength = nodeIDLength + 6
// node represents a DHT node.
@ -359,11 +359,6 @@ func (rt *routingTable) Insert(nd *node) bool {
rt.Lock()
defer rt.Unlock()
if rt.dht.blackList.in(nd.addr.IP.String(), nd.addr.Port) ||
rt.cachedNodes.Len() >= rt.dht.MaxNodes {
return false
}
var (
next *routingTableNode
bucket *kbucket

34
main.go Normal file
View file

@ -0,0 +1,34 @@
package main
import (
"fmt"
"time"
"github.com/lbryio/lbry.go/dht"
log "github.com/sirupsen/logrus"
)
func main() {
config := dht.NewStandardConfig()
config.Address = ":49449" // dont pollute real port
config.PrimeNodes = []string{
"127.0.0.1:10001",
}
d := dht.New(config)
log.Info("Starting...")
go d.Run()
time.Sleep(5 * time.Second)
for {
peers, err := d.GetPeers("012b66fc7052d9a0c8cb563b8ede7662003ba65f425c2661b5c6919d445deeb31469be8b842d6faeea3f2b3ebcaec845")
if err != nil {
time.Sleep(time.Second * 1)
continue
}
fmt.Println("Found peers:", peers)
break
}
}