add token manager, add token to request/response, sucessfully perform a STORE request on the python daemon
This commit is contained in:
parent
5a37e49765
commit
f5f47aa079
14 changed files with 529 additions and 294 deletions
|
@ -3,27 +3,26 @@ package dht
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/lbryio/errors.go"
|
"github.com/lbryio/errors.go"
|
||||||
"github.com/lyoshenka/bencode"
|
"github.com/lyoshenka/bencode"
|
||||||
)
|
)
|
||||||
|
|
||||||
type bitmap [nodeIDLength]byte
|
type Bitmap [nodeIDLength]byte
|
||||||
|
|
||||||
func (b bitmap) RawString() string {
|
func (b Bitmap) RawString() string {
|
||||||
return string(b[:])
|
return string(b[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bitmap) Hex() string {
|
func (b Bitmap) Hex() string {
|
||||||
return hex.EncodeToString(b[:])
|
return hex.EncodeToString(b[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bitmap) HexShort() string {
|
func (b Bitmap) HexShort() string {
|
||||||
return hex.EncodeToString(b[:4])
|
return hex.EncodeToString(b[:4])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bitmap) Equals(other bitmap) bool {
|
func (b Bitmap) Equals(other Bitmap) bool {
|
||||||
for k := range b {
|
for k := range b {
|
||||||
if b[k] != other[k] {
|
if b[k] != other[k] {
|
||||||
return false
|
return false
|
||||||
|
@ -32,17 +31,17 @@ func (b bitmap) Equals(other bitmap) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bitmap) Less(other interface{}) bool {
|
func (b Bitmap) Less(other interface{}) bool {
|
||||||
for k := range b {
|
for k := range b {
|
||||||
if b[k] != other.(bitmap)[k] {
|
if b[k] != other.(Bitmap)[k] {
|
||||||
return b[k] < other.(bitmap)[k]
|
return b[k] < other.(Bitmap)[k]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bitmap) Xor(other bitmap) bitmap {
|
func (b Bitmap) Xor(other Bitmap) Bitmap {
|
||||||
var ret bitmap
|
var ret Bitmap
|
||||||
for k := range b {
|
for k := range b {
|
||||||
ret[k] = b[k] ^ other[k]
|
ret[k] = b[k] ^ other[k]
|
||||||
}
|
}
|
||||||
|
@ -50,7 +49,7 @@ func (b bitmap) Xor(other bitmap) bitmap {
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrefixLen returns the number of leading 0 bits
|
// PrefixLen returns the number of leading 0 bits
|
||||||
func (b bitmap) PrefixLen() int {
|
func (b Bitmap) PrefixLen() int {
|
||||||
for i := range b {
|
for i := range b {
|
||||||
for j := 0; j < 8; j++ {
|
for j := 0; j < 8; j++ {
|
||||||
if (b[i]>>uint8(7-j))&0x1 != 0 {
|
if (b[i]>>uint8(7-j))&0x1 != 0 {
|
||||||
|
@ -61,12 +60,12 @@ func (b bitmap) PrefixLen() int {
|
||||||
return numBuckets
|
return numBuckets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bitmap) MarshalBencode() ([]byte, error) {
|
func (b Bitmap) MarshalBencode() ([]byte, error) {
|
||||||
str := string(b[:])
|
str := string(b[:])
|
||||||
return bencode.EncodeBytes(str)
|
return bencode.EncodeBytes(str)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *bitmap) UnmarshalBencode(encoded []byte) error {
|
func (b *Bitmap) UnmarshalBencode(encoded []byte) error {
|
||||||
var str string
|
var str string
|
||||||
err := bencode.DecodeBytes(encoded, &str)
|
err := bencode.DecodeBytes(encoded, &str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -79,30 +78,55 @@ func (b *bitmap) UnmarshalBencode(encoded []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBitmapFromBytes(data []byte) bitmap {
|
func BitmapFromBytes(data []byte) (Bitmap, error) {
|
||||||
if len(data) != nodeIDLength {
|
var bmp Bitmap
|
||||||
panic("invalid bitmap of length " + strconv.Itoa(len(data)))
|
|
||||||
|
if len(data) != len(bmp) {
|
||||||
|
return bmp, errors.Err("invalid bitmap of length %d", len(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
var bmp bitmap
|
|
||||||
copy(bmp[:], data)
|
copy(bmp[:], data)
|
||||||
return bmp
|
return bmp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBitmapFromString(data string) bitmap {
|
func BitmapFromBytesP(data []byte) Bitmap {
|
||||||
return newBitmapFromBytes([]byte(data))
|
bmp, err := BitmapFromBytes(data)
|
||||||
}
|
|
||||||
|
|
||||||
func newBitmapFromHex(hexStr string) bitmap {
|
|
||||||
decoded, err := hex.DecodeString(hexStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return newBitmapFromBytes(decoded)
|
return bmp
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRandomBitmap() bitmap {
|
func BitmapFromString(data string) (Bitmap, error) {
|
||||||
var id bitmap
|
return BitmapFromBytes([]byte(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BitmapFromStringP(data string) Bitmap {
|
||||||
|
bmp, err := BitmapFromString(data)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return bmp
|
||||||
|
}
|
||||||
|
|
||||||
|
func BitmapFromHex(hexStr string) (Bitmap, error) {
|
||||||
|
decoded, err := hex.DecodeString(hexStr)
|
||||||
|
if err != nil {
|
||||||
|
return Bitmap{}, errors.Err(err)
|
||||||
|
}
|
||||||
|
return BitmapFromBytes(decoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BitmapFromHexP(hexStr string) Bitmap {
|
||||||
|
bmp, err := BitmapFromHex(hexStr)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return bmp
|
||||||
|
}
|
||||||
|
|
||||||
|
func RandomBitmapP() Bitmap {
|
||||||
|
var id Bitmap
|
||||||
_, err := rand.Read(id[:])
|
_, err := rand.Read(id[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|
|
@ -7,19 +7,19 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBitmap(t *testing.T) {
|
func TestBitmap(t *testing.T) {
|
||||||
a := bitmap{
|
a := Bitmap{
|
||||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
||||||
}
|
}
|
||||||
b := bitmap{
|
b := Bitmap{
|
||||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 46,
|
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 46,
|
||||||
}
|
}
|
||||||
c := bitmap{
|
c := Bitmap{
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
@ -46,13 +46,13 @@ func TestBitmap(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
id := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
id := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||||
if newBitmapFromHex(id).Hex() != id {
|
if BitmapFromHexP(id).Hex() != id {
|
||||||
t.Error(newBitmapFromHex(id).Hex())
|
t.Error(BitmapFromHexP(id).Hex())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitmapMarshal(t *testing.T) {
|
func TestBitmapMarshal(t *testing.T) {
|
||||||
b := newBitmapFromString("123456789012345678901234567890123456789012345678")
|
b := BitmapFromStringP("123456789012345678901234567890123456789012345678")
|
||||||
encoded, err := bencode.EncodeBytes(b)
|
encoded, err := bencode.EncodeBytes(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
|
@ -66,11 +66,11 @@ func TestBitmapMarshal(t *testing.T) {
|
||||||
func TestBitmapMarshalEmbedded(t *testing.T) {
|
func TestBitmapMarshalEmbedded(t *testing.T) {
|
||||||
e := struct {
|
e := struct {
|
||||||
A string
|
A string
|
||||||
B bitmap
|
B Bitmap
|
||||||
C int
|
C int
|
||||||
}{
|
}{
|
||||||
A: "1",
|
A: "1",
|
||||||
B: newBitmapFromString("222222222222222222222222222222222222222222222222"),
|
B: BitmapFromStringP("222222222222222222222222222222222222222222222222"),
|
||||||
C: 3,
|
C: 3,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ func TestBitmapMarshalEmbedded(t *testing.T) {
|
||||||
|
|
||||||
func TestBitmapMarshalEmbedded2(t *testing.T) {
|
func TestBitmapMarshalEmbedded2(t *testing.T) {
|
||||||
encoded, err := bencode.EncodeBytes([]interface{}{
|
encoded, err := bencode.EncodeBytes([]interface{}{
|
||||||
newBitmapFromString("333333333333333333333333333333333333333333333333"),
|
BitmapFromStringP("333333333333333333333333333333333333333333333333"),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
|
@ -113,7 +113,7 @@ func TestBitmap_PrefixLen(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tt {
|
for _, test := range tt {
|
||||||
len := newBitmapFromHex(test.str).PrefixLen()
|
len := BitmapFromHexP(test.str).PrefixLen()
|
||||||
if len != test.len {
|
if len != test.len {
|
||||||
t.Errorf("got prefix len %d; expected %d for %s", len, test.len, test.str)
|
t.Errorf("got prefix len %d; expected %d for %s", len, test.len, test.str)
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because one or more lines are too long
88
dht/dht.go
88
dht/dht.go
|
@ -3,6 +3,7 @@ package dht
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -28,6 +29,7 @@ const bucketSize = 8 // this is the constant k in the spec
|
||||||
|
|
||||||
const udpRetry = 3
|
const udpRetry = 3
|
||||||
const udpTimeout = 10 * time.Second
|
const udpTimeout = 10 * time.Second
|
||||||
|
const udpMaxMessageLength = 1024 // I think our longest message is ~676 bytes, so I rounded up
|
||||||
|
|
||||||
const tExpire = 86400 * time.Second // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date
|
const tExpire = 86400 * time.Second // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date
|
||||||
const tRefresh = 3600 * time.Second // the time after which an otherwise unaccessed bucket must be refreshed
|
const tRefresh = 3600 * time.Second // the time after which an otherwise unaccessed bucket must be refreshed
|
||||||
|
@ -37,6 +39,8 @@ const tRepublish = 86400 * time.Second // the time after which the original publ
|
||||||
const numBuckets = nodeIDLength * 8
|
const numBuckets = nodeIDLength * 8
|
||||||
const compactNodeInfoLength = nodeIDLength + 6
|
const compactNodeInfoLength = nodeIDLength + 6
|
||||||
|
|
||||||
|
const tokenSecretRotationInterval = 5 * time.Minute // how often the token-generating secret is rotated
|
||||||
|
|
||||||
// packet represents the information receive from udp.
|
// packet represents the information receive from udp.
|
||||||
type packet struct {
|
type packet struct {
|
||||||
data []byte
|
data []byte
|
||||||
|
@ -92,6 +96,8 @@ type DHT struct {
|
||||||
store *peerStore
|
store *peerStore
|
||||||
// transaction manager
|
// transaction manager
|
||||||
tm *transactionManager
|
tm *transactionManager
|
||||||
|
// token manager
|
||||||
|
tokens *tokenManager
|
||||||
// stopper to shut down DHT
|
// stopper to shut down DHT
|
||||||
stop *stopOnce.Stopper
|
stop *stopOnce.Stopper
|
||||||
// wait group for all the things that need to be stopped when DHT shuts down
|
// wait group for all the things that need to be stopped when DHT shuts down
|
||||||
|
@ -106,11 +112,11 @@ func New(config *Config) (*DHT, error) {
|
||||||
config = NewStandardConfig()
|
config = NewStandardConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
var id bitmap
|
var id Bitmap
|
||||||
if config.NodeID == "" {
|
if config.NodeID == "" {
|
||||||
id = newRandomBitmap()
|
id = RandomBitmapP()
|
||||||
} else {
|
} else {
|
||||||
id = newBitmapFromHex(config.NodeID)
|
id = BitmapFromHexP(config.NodeID)
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, port, err := net.SplitHostPort(config.Address)
|
ip, port, err := net.SplitHostPort(config.Address)
|
||||||
|
@ -141,8 +147,10 @@ func New(config *Config) (*DHT, error) {
|
||||||
stop: stopOnce.New(),
|
stop: stopOnce.New(),
|
||||||
stopWG: &sync.WaitGroup{},
|
stopWG: &sync.WaitGroup{},
|
||||||
joined: make(chan struct{}),
|
joined: make(chan struct{}),
|
||||||
|
tokens: &tokenManager{},
|
||||||
}
|
}
|
||||||
d.tm = newTransactionManager(d)
|
d.tm = newTransactionManager(d)
|
||||||
|
d.tokens.Start(tokenSecretRotationInterval)
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -173,7 +181,7 @@ func (dht *DHT) listen() {
|
||||||
dht.stopWG.Add(1)
|
dht.stopWG.Add(1)
|
||||||
defer dht.stopWG.Done()
|
defer dht.stopWG.Done()
|
||||||
|
|
||||||
buf := make([]byte, 16384)
|
buf := make([]byte, udpMaxMessageLength)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -212,7 +220,7 @@ func (dht *DHT) join() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpNode := Node{id: newRandomBitmap(), ip: raddr.IP, port: raddr.Port}
|
tmpNode := Node{id: RandomBitmapP(), ip: raddr.IP, port: raddr.Port}
|
||||||
res := dht.tm.Send(tmpNode, Request{Method: pingMethod})
|
res := dht.tm.Send(tmpNode, Request{Method: pingMethod})
|
||||||
if res == nil {
|
if res == nil {
|
||||||
log.Errorf("[%s] join: no response from seed node %s", dht.node.id.HexShort(), addr)
|
log.Errorf("[%s] join: no response from seed node %s", dht.node.id.HexShort(), addr)
|
||||||
|
@ -271,12 +279,13 @@ func (dht *DHT) Shutdown() {
|
||||||
log.Debugf("[%s] DHT shutting down", dht.node.id.HexShort())
|
log.Debugf("[%s] DHT shutting down", dht.node.id.HexShort())
|
||||||
dht.stop.Stop()
|
dht.stop.Stop()
|
||||||
dht.stopWG.Wait()
|
dht.stopWG.Wait()
|
||||||
|
dht.tokens.Stop()
|
||||||
dht.conn.Close()
|
dht.conn.Close()
|
||||||
log.Debugf("[%s] DHT stopped", dht.node.id.HexShort())
|
log.Debugf("[%s] DHT stopped", dht.node.id.HexShort())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get returns the list of nodes that have the blob for the given hash
|
// Get returns the list of nodes that have the blob for the given hash
|
||||||
func (dht *DHT) Get(hash bitmap) ([]Node, error) {
|
func (dht *DHT) Get(hash Bitmap) ([]Node, error) {
|
||||||
nf := newNodeFinder(dht, hash, true)
|
nf := newNodeFinder(dht, hash, true)
|
||||||
res, err := nf.Find()
|
res, err := nf.Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -290,20 +299,48 @@ func (dht *DHT) Get(hash bitmap) ([]Node, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Announce announces to the DHT that this node has the blob for the given hash
|
// Announce announces to the DHT that this node has the blob for the given hash
|
||||||
func (dht *DHT) Announce(hash bitmap) error {
|
func (dht *DHT) Announce(hash Bitmap) error {
|
||||||
nf := newNodeFinder(dht, hash, false)
|
nf := newNodeFinder(dht, hash, false)
|
||||||
res, err := nf.Find()
|
res, err := nf.Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: if this node is closer than farthest peer, store locally and pop farthest peer
|
||||||
|
|
||||||
for _, node := range res.Nodes {
|
for _, node := range res.Nodes {
|
||||||
|
go dht.storeOnNode(hash, node)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dht *DHT) storeOnNode(hash Bitmap, node Node) {
|
||||||
|
dht.stopWG.Add(1)
|
||||||
|
defer dht.stopWG.Done()
|
||||||
|
|
||||||
|
resCh := dht.tm.SendAsync(context.Background(), node, Request{
|
||||||
|
Method: findValueMethod,
|
||||||
|
Arg: &hash,
|
||||||
|
})
|
||||||
|
var res *Response
|
||||||
|
|
||||||
|
select {
|
||||||
|
case res = <-resCh:
|
||||||
|
case <-dht.stop.Chan():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if res == nil {
|
||||||
|
return // request timed out
|
||||||
|
}
|
||||||
|
|
||||||
dht.tm.SendAsync(context.Background(), node, Request{
|
dht.tm.SendAsync(context.Background(), node, Request{
|
||||||
Method: storeMethod,
|
Method: storeMethod,
|
||||||
StoreArgs: &storeArgs{
|
StoreArgs: &storeArgs{
|
||||||
BlobHash: hash,
|
BlobHash: hash,
|
||||||
Value: storeArgsValue{
|
Value: storeArgsValue{
|
||||||
Token: "",
|
Token: res.Token,
|
||||||
LbryID: dht.node.id,
|
LbryID: dht.node.id,
|
||||||
Port: dht.node.port,
|
Port: dht.node.port,
|
||||||
},
|
},
|
||||||
|
@ -311,11 +348,8 @@ func (dht *DHT) Announce(hash bitmap) error {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dht *DHT) PrintState() {
|
func (dht *DHT) PrintState() {
|
||||||
log.Printf("DHT state at %s", time.Now().Format(time.RFC822Z))
|
log.Printf("DHT node %s at %s", dht.node.String(), time.Now().Format(time.RFC822Z))
|
||||||
log.Printf("Outstanding transactions: %d", dht.tm.Count())
|
log.Printf("Outstanding transactions: %d", dht.tm.Count())
|
||||||
log.Printf("Stored hashes: %d", dht.store.CountStoredHashes())
|
log.Printf("Stored hashes: %d", dht.store.CountStoredHashes())
|
||||||
log.Printf("Buckets:")
|
log.Printf("Buckets:")
|
||||||
|
@ -326,6 +360,34 @@ func (dht *DHT) PrintState() {
|
||||||
|
|
||||||
func printNodeList(list []Node) {
|
func printNodeList(list []Node) {
|
||||||
for i, n := range list {
|
for i, n := range list {
|
||||||
log.Printf("%d) %s %s:%d", i, n.id.HexShort(), n.ip.String(), n.port)
|
log.Printf("%d) %s", i, n.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MakeTestDHT(numNodes int) []*DHT {
|
||||||
|
if numNodes < 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := "127.0.0.1"
|
||||||
|
firstPort := 21000
|
||||||
|
dhts := make([]*DHT, numNodes)
|
||||||
|
|
||||||
|
for i := 0; i < numNodes; i++ {
|
||||||
|
seeds := []string{}
|
||||||
|
if i > 0 {
|
||||||
|
seeds = []string{ip + ":" + strconv.Itoa(firstPort)}
|
||||||
|
}
|
||||||
|
|
||||||
|
dht, err := New(&Config{Address: ip + ":" + strconv.Itoa(firstPort+i), NodeID: RandomBitmapP().Hex(), SeedNodes: seeds})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go dht.Start()
|
||||||
|
dht.WaitUntilJoined()
|
||||||
|
dhts[i] = dht
|
||||||
|
}
|
||||||
|
|
||||||
|
return dhts
|
||||||
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package dht
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -12,14 +11,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNodeFinder_FindNodes(t *testing.T) {
|
func TestNodeFinder_FindNodes(t *testing.T) {
|
||||||
dhts := makeDHT(t, 3)
|
dhts := MakeTestDHT(3)
|
||||||
defer func() {
|
defer func() {
|
||||||
for i := range dhts {
|
for i := range dhts {
|
||||||
dhts[i].Shutdown()
|
dhts[i].Shutdown()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
nf := newNodeFinder(dhts[2], newRandomBitmap(), false)
|
nf := newNodeFinder(dhts[2], RandomBitmapP(), false)
|
||||||
res, err := nf.Find()
|
res, err := nf.Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -55,15 +54,15 @@ func TestNodeFinder_FindNodes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNodeFinder_FindValue(t *testing.T) {
|
func TestNodeFinder_FindValue(t *testing.T) {
|
||||||
dhts := makeDHT(t, 3)
|
dhts := MakeTestDHT(3)
|
||||||
defer func() {
|
defer func() {
|
||||||
for i := range dhts {
|
for i := range dhts {
|
||||||
dhts[i].Shutdown()
|
dhts[i].Shutdown()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
blobHashToFind := newRandomBitmap()
|
blobHashToFind := RandomBitmapP()
|
||||||
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
|
nodeToFind := Node{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
|
||||||
dhts[0].store.Upsert(blobHashToFind, nodeToFind)
|
dhts[0].store.Upsert(blobHashToFind, nodeToFind)
|
||||||
|
|
||||||
nf := newNodeFinder(dhts[2], blobHashToFind, true)
|
nf := newNodeFinder(dhts[2], blobHashToFind, true)
|
||||||
|
@ -90,7 +89,7 @@ func TestDHT_LargeDHT(t *testing.T) {
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
log.Println("if this takes longer than 20 seconds, its stuck. idk why it gets stuck sometimes, but its a bug.")
|
log.Println("if this takes longer than 20 seconds, its stuck. idk why it gets stuck sometimes, but its a bug.")
|
||||||
nodes := 100
|
nodes := 100
|
||||||
dhts := makeDHT(t, nodes)
|
dhts := MakeTestDHT(nodes)
|
||||||
defer func() {
|
defer func() {
|
||||||
for _, d := range dhts {
|
for _, d := range dhts {
|
||||||
go d.Shutdown()
|
go d.Shutdown()
|
||||||
|
@ -100,9 +99,9 @@ func TestDHT_LargeDHT(t *testing.T) {
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
numIDs := nodes / 2
|
numIDs := nodes / 2
|
||||||
ids := make([]bitmap, numIDs)
|
ids := make([]Bitmap, numIDs)
|
||||||
for i := 0; i < numIDs; i++ {
|
for i := 0; i < numIDs; i++ {
|
||||||
ids[i] = newRandomBitmap()
|
ids[i] = RandomBitmapP()
|
||||||
}
|
}
|
||||||
for i := 0; i < numIDs; i++ {
|
for i := 0; i < numIDs; i++ {
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
|
@ -116,31 +115,3 @@ func TestDHT_LargeDHT(t *testing.T) {
|
||||||
|
|
||||||
dhts[1].PrintState()
|
dhts[1].PrintState()
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeDHT(t *testing.T, numNodes int) []*DHT {
|
|
||||||
if numNodes < 1 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := "127.0.0.1"
|
|
||||||
firstPort := 21000
|
|
||||||
dhts := make([]*DHT, numNodes)
|
|
||||||
|
|
||||||
for i := 0; i < numNodes; i++ {
|
|
||||||
seeds := []string{}
|
|
||||||
if i > 0 {
|
|
||||||
seeds = []string{ip + ":" + strconv.Itoa(firstPort)}
|
|
||||||
}
|
|
||||||
|
|
||||||
dht, err := New(&Config{Address: ip + ":" + strconv.Itoa(firstPort+i), NodeID: newRandomBitmap().Hex(), SeedNodes: seeds})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go dht.Start()
|
|
||||||
dht.WaitUntilJoined()
|
|
||||||
dhts[i] = dht
|
|
||||||
}
|
|
||||||
|
|
||||||
return dhts
|
|
||||||
}
|
|
||||||
|
|
|
@ -38,6 +38,8 @@ const (
|
||||||
headerNodeIDField = "2" // node id is 48 bytes long
|
headerNodeIDField = "2" // node id is 48 bytes long
|
||||||
headerPayloadField = "3"
|
headerPayloadField = "3"
|
||||||
headerArgsField = "4"
|
headerArgsField = "4"
|
||||||
|
contactsField = "contacts"
|
||||||
|
tokenField = "token"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Message interface {
|
type Message interface {
|
||||||
|
@ -76,9 +78,9 @@ func newMessageID() messageID {
|
||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
ID messageID
|
ID messageID
|
||||||
NodeID bitmap
|
NodeID Bitmap
|
||||||
Method string
|
Method string
|
||||||
Arg *bitmap
|
Arg *Bitmap
|
||||||
StoreArgs *storeArgs
|
StoreArgs *storeArgs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,7 +89,7 @@ func (r Request) MarshalBencode() ([]byte, error) {
|
||||||
if r.StoreArgs != nil {
|
if r.StoreArgs != nil {
|
||||||
args = r.StoreArgs
|
args = r.StoreArgs
|
||||||
} else if r.Arg != nil {
|
} else if r.Arg != nil {
|
||||||
args = []bitmap{*r.Arg}
|
args = []Bitmap{*r.Arg}
|
||||||
}
|
}
|
||||||
return bencode.EncodeBytes(map[string]interface{}{
|
return bencode.EncodeBytes(map[string]interface{}{
|
||||||
headerTypeField: requestType,
|
headerTypeField: requestType,
|
||||||
|
@ -101,7 +103,7 @@ func (r Request) MarshalBencode() ([]byte, error) {
|
||||||
func (r *Request) UnmarshalBencode(b []byte) error {
|
func (r *Request) UnmarshalBencode(b []byte) error {
|
||||||
var raw struct {
|
var raw struct {
|
||||||
ID messageID `bencode:"1"`
|
ID messageID `bencode:"1"`
|
||||||
NodeID bitmap `bencode:"2"`
|
NodeID Bitmap `bencode:"2"`
|
||||||
Method string `bencode:"3"`
|
Method string `bencode:"3"`
|
||||||
Args bencode.RawMessage `bencode:"4"`
|
Args bencode.RawMessage `bencode:"4"`
|
||||||
}
|
}
|
||||||
|
@ -121,7 +123,7 @@ func (r *Request) UnmarshalBencode(b []byte) error {
|
||||||
return errors.Prefix("request unmarshal", err)
|
return errors.Prefix("request unmarshal", err)
|
||||||
}
|
}
|
||||||
} else if len(raw.Args) > 2 { // 2 because an empty list is `le`
|
} else if len(raw.Args) > 2 { // 2 because an empty list is `le`
|
||||||
tmp := []bitmap{}
|
tmp := []Bitmap{}
|
||||||
err = bencode.DecodeBytes(raw.Args, &tmp)
|
err = bencode.DecodeBytes(raw.Args, &tmp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Prefix("request unmarshal", err)
|
return errors.Prefix("request unmarshal", err)
|
||||||
|
@ -143,14 +145,14 @@ func (r Request) ArgsDebug() string {
|
||||||
|
|
||||||
type storeArgsValue struct {
|
type storeArgsValue struct {
|
||||||
Token string `bencode:"token"`
|
Token string `bencode:"token"`
|
||||||
LbryID bitmap `bencode:"lbryid"`
|
LbryID Bitmap `bencode:"lbryid"`
|
||||||
Port int `bencode:"port"`
|
Port int `bencode:"port"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type storeArgs struct {
|
type storeArgs struct {
|
||||||
BlobHash bitmap
|
BlobHash Bitmap
|
||||||
Value storeArgsValue
|
Value storeArgsValue
|
||||||
NodeID bitmap
|
NodeID Bitmap
|
||||||
SelfStore bool // this is an int on the wire
|
SelfStore bool // this is an int on the wire
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,10 +219,11 @@ func (s *storeArgs) UnmarshalBencode(b []byte) error {
|
||||||
|
|
||||||
type Response struct {
|
type Response struct {
|
||||||
ID messageID
|
ID messageID
|
||||||
NodeID bitmap
|
NodeID Bitmap
|
||||||
Data string
|
Data string
|
||||||
FindNodeData []Node
|
FindNodeData []Node
|
||||||
FindValueKey string
|
FindValueKey string
|
||||||
|
Token string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Response) ArgsDebug() string {
|
func (r Response) ArgsDebug() string {
|
||||||
|
@ -238,6 +241,11 @@ func (r Response) ArgsDebug() string {
|
||||||
str += c.Addr().String() + ":" + c.id.HexShort() + ","
|
str += c.Addr().String() + ":" + c.id.HexShort() + ","
|
||||||
}
|
}
|
||||||
str = strings.TrimRight(str, ",") + "|"
|
str = strings.TrimRight(str, ",") + "|"
|
||||||
|
|
||||||
|
if r.Token != "" {
|
||||||
|
str += " token: " + hex.EncodeToString([]byte(r.Token))[:8]
|
||||||
|
}
|
||||||
|
|
||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,9 +255,16 @@ func (r Response) MarshalBencode() ([]byte, error) {
|
||||||
headerMessageIDField: r.ID,
|
headerMessageIDField: r.ID,
|
||||||
headerNodeIDField: r.NodeID,
|
headerNodeIDField: r.NodeID,
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Data != "" {
|
if r.Data != "" {
|
||||||
|
// ping or store
|
||||||
data[headerPayloadField] = r.Data
|
data[headerPayloadField] = r.Data
|
||||||
} else if r.FindValueKey != "" {
|
} else if r.FindValueKey != "" {
|
||||||
|
// findValue success
|
||||||
|
if r.Token == "" {
|
||||||
|
return nil, errors.Err("response to findValue must have a token")
|
||||||
|
}
|
||||||
|
|
||||||
var contacts [][]byte
|
var contacts [][]byte
|
||||||
for _, n := range r.FindNodeData {
|
for _, n := range r.FindNodeData {
|
||||||
compact, err := n.MarshalCompact()
|
compact, err := n.MarshalCompact()
|
||||||
|
@ -258,9 +273,19 @@ func (r Response) MarshalBencode() ([]byte, error) {
|
||||||
}
|
}
|
||||||
contacts = append(contacts, compact)
|
contacts = append(contacts, compact)
|
||||||
}
|
}
|
||||||
data[headerPayloadField] = map[string][][]byte{r.FindValueKey: contacts}
|
data[headerPayloadField] = map[string]interface{}{
|
||||||
|
r.FindValueKey: contacts,
|
||||||
|
tokenField: r.Token,
|
||||||
|
}
|
||||||
|
} else if r.Token != "" {
|
||||||
|
// findValue failure falling back to findNode
|
||||||
|
data[headerPayloadField] = map[string]interface{}{
|
||||||
|
contactsField: r.FindNodeData,
|
||||||
|
tokenField: r.Token,
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
data[headerPayloadField] = map[string][]Node{"contacts": r.FindNodeData}
|
// straight up findNode
|
||||||
|
data[headerPayloadField] = r.FindNodeData
|
||||||
}
|
}
|
||||||
|
|
||||||
return bencode.EncodeBytes(data)
|
return bencode.EncodeBytes(data)
|
||||||
|
@ -269,7 +294,7 @@ func (r Response) MarshalBencode() ([]byte, error) {
|
||||||
func (r *Response) UnmarshalBencode(b []byte) error {
|
func (r *Response) UnmarshalBencode(b []byte) error {
|
||||||
var raw struct {
|
var raw struct {
|
||||||
ID messageID `bencode:"1"`
|
ID messageID `bencode:"1"`
|
||||||
NodeID bitmap `bencode:"2"`
|
NodeID Bitmap `bencode:"2"`
|
||||||
Data bencode.RawMessage `bencode:"3"`
|
Data bencode.RawMessage `bencode:"3"`
|
||||||
}
|
}
|
||||||
err := bencode.DecodeBytes(b, &raw)
|
err := bencode.DecodeBytes(b, &raw)
|
||||||
|
@ -280,15 +305,34 @@ func (r *Response) UnmarshalBencode(b []byte) error {
|
||||||
r.ID = raw.ID
|
r.ID = raw.ID
|
||||||
r.NodeID = raw.NodeID
|
r.NodeID = raw.NodeID
|
||||||
|
|
||||||
|
// maybe data is a string (response to ping or store)?
|
||||||
err = bencode.DecodeBytes(raw.Data, &r.Data)
|
err = bencode.DecodeBytes(raw.Data, &r.Data)
|
||||||
if err != nil {
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybe data is a list of nodes (response to findNode)?
|
||||||
|
err = bencode.DecodeBytes(raw.Data, &r.FindNodeData)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// it must be a response to findValue
|
||||||
var rawData map[string]bencode.RawMessage
|
var rawData map[string]bencode.RawMessage
|
||||||
err = bencode.DecodeBytes(raw.Data, &rawData)
|
err = bencode.DecodeBytes(raw.Data, &rawData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if contacts, ok := rawData["contacts"]; ok {
|
if token, ok := rawData[tokenField]; ok {
|
||||||
|
err = bencode.DecodeBytes(token, &r.Token)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
delete(rawData, tokenField) // it doesnt mess up findValue key finding below
|
||||||
|
}
|
||||||
|
|
||||||
|
if contacts, ok := rawData[contactsField]; ok {
|
||||||
err = bencode.DecodeBytes(contacts, &r.FindNodeData)
|
err = bencode.DecodeBytes(contacts, &r.FindNodeData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -312,14 +356,13 @@ func (r *Response) UnmarshalBencode(b []byte) error {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
ID messageID
|
ID messageID
|
||||||
NodeID bitmap
|
NodeID Bitmap
|
||||||
ExceptionType string
|
ExceptionType string
|
||||||
Response []string
|
Response []string
|
||||||
}
|
}
|
||||||
|
@ -337,7 +380,7 @@ func (e Error) MarshalBencode() ([]byte, error) {
|
||||||
func (e *Error) UnmarshalBencode(b []byte) error {
|
func (e *Error) UnmarshalBencode(b []byte) error {
|
||||||
var raw struct {
|
var raw struct {
|
||||||
ID messageID `bencode:"1"`
|
ID messageID `bencode:"1"`
|
||||||
NodeID bitmap `bencode:"2"`
|
NodeID Bitmap `bencode:"2"`
|
||||||
ExceptionType string `bencode:"3"`
|
ExceptionType string `bencode:"3"`
|
||||||
Args interface{} `bencode:"4"`
|
Args interface{} `bencode:"4"`
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -13,7 +13,7 @@ import (
|
||||||
|
|
||||||
type nodeFinder struct {
|
type nodeFinder struct {
|
||||||
findValue bool // true if we're using findValue
|
findValue bool // true if we're using findValue
|
||||||
target bitmap
|
target Bitmap
|
||||||
dht *DHT
|
dht *DHT
|
||||||
|
|
||||||
done *stopOnce.Stopper
|
done *stopOnce.Stopper
|
||||||
|
@ -26,7 +26,10 @@ type nodeFinder struct {
|
||||||
|
|
||||||
shortlistMutex *sync.Mutex
|
shortlistMutex *sync.Mutex
|
||||||
shortlist []Node
|
shortlist []Node
|
||||||
shortlistAdded map[bitmap]bool
|
shortlistAdded map[Bitmap]bool
|
||||||
|
|
||||||
|
outstandingRequestsMutex *sync.RWMutex
|
||||||
|
outstandingRequests uint
|
||||||
}
|
}
|
||||||
|
|
||||||
type findNodeResponse struct {
|
type findNodeResponse struct {
|
||||||
|
@ -34,7 +37,7 @@ type findNodeResponse struct {
|
||||||
Nodes []Node
|
Nodes []Node
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNodeFinder(dht *DHT, target bitmap, findValue bool) *nodeFinder {
|
func newNodeFinder(dht *DHT, target Bitmap, findValue bool) *nodeFinder {
|
||||||
return &nodeFinder{
|
return &nodeFinder{
|
||||||
dht: dht,
|
dht: dht,
|
||||||
target: target,
|
target: target,
|
||||||
|
@ -42,13 +45,18 @@ func newNodeFinder(dht *DHT, target bitmap, findValue bool) *nodeFinder {
|
||||||
findValueMutex: &sync.Mutex{},
|
findValueMutex: &sync.Mutex{},
|
||||||
activeNodesMutex: &sync.Mutex{},
|
activeNodesMutex: &sync.Mutex{},
|
||||||
shortlistMutex: &sync.Mutex{},
|
shortlistMutex: &sync.Mutex{},
|
||||||
shortlistAdded: make(map[bitmap]bool),
|
shortlistAdded: make(map[Bitmap]bool),
|
||||||
done: stopOnce.New(),
|
done: stopOnce.New(),
|
||||||
|
outstandingRequestsMutex: &sync.RWMutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nf *nodeFinder) Find() (findNodeResponse, error) {
|
func (nf *nodeFinder) Find() (findNodeResponse, error) {
|
||||||
log.Debugf("[%s] starting an iterative Find() for %s (findValue is %t)", nf.dht.node.id.HexShort(), nf.target.HexShort(), nf.findValue)
|
if nf.findValue {
|
||||||
|
log.Debugf("[%s] starting an iterative Find for the value %s", nf.dht.node.id.HexShort(), nf.target.HexShort())
|
||||||
|
} else {
|
||||||
|
log.Debugf("[%s] starting an iterative Find for nodes near %s", nf.dht.node.id.HexShort(), nf.target.HexShort())
|
||||||
|
}
|
||||||
nf.appendNewToShortlist(nf.dht.rt.GetClosest(nf.target, alpha))
|
nf.appendNewToShortlist(nf.dht.rt.GetClosest(nf.target, alpha))
|
||||||
if len(nf.shortlist) == 0 {
|
if len(nf.shortlist) == 0 {
|
||||||
return findNodeResponse{}, errors.Err("no nodes in routing table")
|
return findNodeResponse{}, errors.Err("no nodes in routing table")
|
||||||
|
@ -67,7 +75,7 @@ func (nf *nodeFinder) Find() (findNodeResponse, error) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
// TODO: what to do if we have less than K active nodes, shortlist is empty, but we
|
// TODO: what to do if we have less than K active nodes, shortlist is empty, but we
|
||||||
// TODO: have other nodes in our routing table whom we have not contacted. prolly contact them?
|
// TODO: have other nodes in our routing table whom we have not contacted. prolly contact them
|
||||||
|
|
||||||
result := findNodeResponse{}
|
result := findNodeResponse{}
|
||||||
if nf.findValue && len(nf.findValueResult) > 0 {
|
if nf.findValue && len(nf.findValueResult) > 0 {
|
||||||
|
@ -91,8 +99,8 @@ func (nf *nodeFinder) iterationWorker(num int) {
|
||||||
maybeNode := nf.popFromShortlist()
|
maybeNode := nf.popFromShortlist()
|
||||||
if maybeNode == nil {
|
if maybeNode == nil {
|
||||||
// TODO: block if there are pending requests out from other workers. there may be more shortlist values coming
|
// TODO: block if there are pending requests out from other workers. there may be more shortlist values coming
|
||||||
log.Debugf("[%s] no more nodes in shortlist", nf.dht.node.id.HexShort())
|
log.Debugf("[%s] worker %d: no nodes in shortlist, waiting...", nf.dht.node.id.HexShort(), num)
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
} else {
|
} else {
|
||||||
node := *maybeNode
|
node := *maybeNode
|
||||||
|
|
||||||
|
@ -107,7 +115,9 @@ func (nf *nodeFinder) iterationWorker(num int) {
|
||||||
req.Method = findNodeMethod
|
req.Method = findNodeMethod
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("[%s] contacting %s", nf.dht.node.id.HexShort(), node.id.HexShort())
|
log.Debugf("[%s] worker %d: contacting %s", nf.dht.node.id.HexShort(), num, node.id.HexShort())
|
||||||
|
|
||||||
|
nf.incrementOutstanding()
|
||||||
|
|
||||||
var res *Response
|
var res *Response
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -122,6 +132,7 @@ func (nf *nodeFinder) iterationWorker(num int) {
|
||||||
|
|
||||||
if res == nil {
|
if res == nil {
|
||||||
// nothing to do, response timed out
|
// nothing to do, response timed out
|
||||||
|
log.Debugf("[%s] worker %d: timed out waiting for %s", nf.dht.node.id.HexShort(), num, node.id.HexShort())
|
||||||
} else if nf.findValue && res.FindValueKey != "" {
|
} else if nf.findValue && res.FindValueKey != "" {
|
||||||
log.Debugf("[%s] worker %d: got value", nf.dht.node.id.HexShort(), num)
|
log.Debugf("[%s] worker %d: got value", nf.dht.node.id.HexShort(), num)
|
||||||
nf.findValueMutex.Lock()
|
nf.findValueMutex.Lock()
|
||||||
|
@ -130,10 +141,12 @@ func (nf *nodeFinder) iterationWorker(num int) {
|
||||||
nf.done.Stop()
|
nf.done.Stop()
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("[%s] worker %d: got more contacts", nf.dht.node.id.HexShort(), num)
|
log.Debugf("[%s] worker %d: got contacts", nf.dht.node.id.HexShort(), num)
|
||||||
nf.insertIntoActiveList(node)
|
nf.insertIntoActiveList(node)
|
||||||
nf.appendNewToShortlist(res.FindNodeData)
|
nf.appendNewToShortlist(res.FindNodeData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nf.decrementOutstanding() // this is all the way down here because we need to add to shortlist first
|
||||||
}
|
}
|
||||||
|
|
||||||
if nf.isSearchFinished() {
|
if nf.isSearchFinished() {
|
||||||
|
@ -199,6 +212,7 @@ func (nf *nodeFinder) isSearchFinished() bool {
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !nf.areRequestsOutstanding() {
|
||||||
nf.shortlistMutex.Lock()
|
nf.shortlistMutex.Lock()
|
||||||
defer nf.shortlistMutex.Unlock()
|
defer nf.shortlistMutex.Unlock()
|
||||||
|
|
||||||
|
@ -213,6 +227,25 @@ func (nf *nodeFinder) isSearchFinished() bool {
|
||||||
// we have at least K active nodes, and we don't have any closer nodes yet to contact
|
// we have at least K active nodes, and we don't have any closer nodes yet to contact
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (nf *nodeFinder) incrementOutstanding() {
|
||||||
|
nf.outstandingRequestsMutex.Lock()
|
||||||
|
defer nf.outstandingRequestsMutex.Unlock()
|
||||||
|
nf.outstandingRequests++
|
||||||
|
}
|
||||||
|
func (nf *nodeFinder) decrementOutstanding() {
|
||||||
|
nf.outstandingRequestsMutex.Lock()
|
||||||
|
defer nf.outstandingRequestsMutex.Unlock()
|
||||||
|
if nf.outstandingRequests > 0 {
|
||||||
|
nf.outstandingRequests--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (nf *nodeFinder) areRequestsOutstanding() bool {
|
||||||
|
nf.outstandingRequestsMutex.RLock()
|
||||||
|
defer nf.outstandingRequestsMutex.RUnlock()
|
||||||
|
return nf.outstandingRequests > 0
|
||||||
|
}
|
||||||
|
|
|
@ -15,9 +15,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Node struct {
|
type Node struct {
|
||||||
id bitmap
|
id Bitmap
|
||||||
ip net.IP
|
ip net.IP
|
||||||
port int
|
port int
|
||||||
|
token string // this is set when the node is returned from a FindNode call
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n Node) String() string {
|
||||||
|
return n.id.HexShort() + "@" + n.Addr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n Node) Addr() *net.UDPAddr {
|
func (n Node) Addr() *net.UDPAddr {
|
||||||
|
@ -51,7 +56,7 @@ func (n *Node) UnmarshalCompact(b []byte) error {
|
||||||
}
|
}
|
||||||
n.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4()
|
n.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4()
|
||||||
n.port = int(uint16(b[5]) | uint16(b[4])<<8)
|
n.port = int(uint16(b[5]) | uint16(b[4])<<8)
|
||||||
n.id = newBitmapFromBytes(b[6:])
|
n.id = BitmapFromBytesP(b[6:])
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,7 +100,7 @@ func (n *Node) UnmarshalBencode(b []byte) error {
|
||||||
|
|
||||||
type sortedNode struct {
|
type sortedNode struct {
|
||||||
node Node
|
node Node
|
||||||
xorDistanceToTarget bitmap
|
xorDistanceToTarget Bitmap
|
||||||
}
|
}
|
||||||
|
|
||||||
type byXorDistance []sortedNode
|
type byXorDistance []sortedNode
|
||||||
|
@ -128,6 +133,18 @@ func (rt *routingTable) BucketInfo() string {
|
||||||
|
|
||||||
bucketInfo := []string{}
|
bucketInfo := []string{}
|
||||||
for i, b := range rt.buckets {
|
for i, b := range rt.buckets {
|
||||||
|
contents := bucketContents(b)
|
||||||
|
if contents != "" {
|
||||||
|
bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: %s", i, contents))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(bucketInfo) == 0 {
|
||||||
|
return "buckets are empty"
|
||||||
|
}
|
||||||
|
return strings.Join(bucketInfo, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func bucketContents(b *list.List) string {
|
||||||
count := 0
|
count := 0
|
||||||
ids := ""
|
ids := ""
|
||||||
for curr := b.Front(); curr != nil; curr = curr.Next() {
|
for curr := b.Front(); curr != nil; curr = curr.Next() {
|
||||||
|
@ -139,14 +156,11 @@ func (rt *routingTable) BucketInfo() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: (%d) %s", i, count, ids))
|
return fmt.Sprintf("(%d) %s", count, ids)
|
||||||
|
} else {
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(bucketInfo) == 0 {
|
|
||||||
return "buckets are empty"
|
|
||||||
}
|
|
||||||
return strings.Join(bucketInfo, "\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rt *routingTable) Update(node Node) {
|
func (rt *routingTable) Update(node Node) {
|
||||||
rt.lock.Lock()
|
rt.lock.Lock()
|
||||||
|
@ -165,7 +179,7 @@ func (rt *routingTable) Update(node Node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rt *routingTable) RemoveByID(id bitmap) {
|
func (rt *routingTable) RemoveByID(id Bitmap) {
|
||||||
rt.lock.Lock()
|
rt.lock.Lock()
|
||||||
defer rt.lock.Unlock()
|
defer rt.lock.Unlock()
|
||||||
bucketNum := bucketFor(rt.node.id, id)
|
bucketNum := bucketFor(rt.node.id, id)
|
||||||
|
@ -176,7 +190,7 @@ func (rt *routingTable) RemoveByID(id bitmap) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rt *routingTable) GetClosest(target bitmap, limit int) []Node {
|
func (rt *routingTable) GetClosest(target Bitmap, limit int) []Node {
|
||||||
rt.lock.RLock()
|
rt.lock.RLock()
|
||||||
defer rt.lock.RUnlock()
|
defer rt.lock.RUnlock()
|
||||||
|
|
||||||
|
@ -216,7 +230,7 @@ func (rt *routingTable) GetClosest(target bitmap, limit int) []Node {
|
||||||
return nodes
|
return nodes
|
||||||
}
|
}
|
||||||
|
|
||||||
func findInList(bucket *list.List, value bitmap) *list.Element {
|
func findInList(bucket *list.List, value Bitmap) *list.Element {
|
||||||
for curr := bucket.Front(); curr != nil; curr = curr.Next() {
|
for curr := bucket.Front(); curr != nil; curr = curr.Next() {
|
||||||
if curr.Value.(Node).id.Equals(value) {
|
if curr.Value.(Node).id.Equals(value) {
|
||||||
return curr
|
return curr
|
||||||
|
@ -225,7 +239,7 @@ func findInList(bucket *list.List, value bitmap) *list.Element {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendNodes(nodes []sortedNode, start *list.Element, target bitmap) []sortedNode {
|
func appendNodes(nodes []sortedNode, start *list.Element, target Bitmap) []sortedNode {
|
||||||
for curr := start; curr != nil; curr = curr.Next() {
|
for curr := start; curr != nil; curr = curr.Next() {
|
||||||
node := curr.Value.(Node)
|
node := curr.Value.(Node)
|
||||||
nodes = append(nodes, sortedNode{node, node.id.Xor(target)})
|
nodes = append(nodes, sortedNode{node, node.id.Xor(target)})
|
||||||
|
@ -233,14 +247,14 @@ func appendNodes(nodes []sortedNode, start *list.Element, target bitmap) []sorte
|
||||||
return nodes
|
return nodes
|
||||||
}
|
}
|
||||||
|
|
||||||
func bucketFor(id bitmap, target bitmap) int {
|
func bucketFor(id Bitmap, target Bitmap) int {
|
||||||
if id.Equals(target) {
|
if id.Equals(target) {
|
||||||
panic("nodes do not have a bucket for themselves")
|
panic("nodes do not have a bucket for themselves")
|
||||||
}
|
}
|
||||||
return numBuckets - 1 - target.Xor(id).PrefixLen()
|
return numBuckets - 1 - target.Xor(id).PrefixLen()
|
||||||
}
|
}
|
||||||
|
|
||||||
func sortNodesInPlace(nodes []Node, target bitmap) {
|
func sortNodesInPlace(nodes []Node, target Bitmap) {
|
||||||
toSort := make([]sortedNode, len(nodes))
|
toSort := make([]sortedNode, len(nodes))
|
||||||
|
|
||||||
for i, n := range nodes {
|
for i, n := range nodes {
|
||||||
|
|
|
@ -7,21 +7,21 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRoutingTable_bucketFor(t *testing.T) {
|
func TestRoutingTable_bucketFor(t *testing.T) {
|
||||||
target := newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
target := BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||||
var tests = []struct {
|
var tests = []struct {
|
||||||
id bitmap
|
id Bitmap
|
||||||
target bitmap
|
target Bitmap
|
||||||
expected int
|
expected int
|
||||||
}{
|
}{
|
||||||
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), target, 0},
|
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), target, 0},
|
||||||
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), target, 1},
|
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), target, 1},
|
||||||
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), target, 1},
|
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), target, 1},
|
||||||
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), target, 2},
|
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), target, 2},
|
||||||
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), target, 2},
|
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), target, 2},
|
||||||
{newBitmapFromHex("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f"), target, 3},
|
{BitmapFromHexP("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f"), target, 3},
|
||||||
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"), target, 4},
|
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"), target, 4},
|
||||||
{newBitmapFromHex("F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), target, 383},
|
{BitmapFromHexP("F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), target, 383},
|
||||||
{newBitmapFromHex("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), target, 383},
|
{BitmapFromHexP("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), target, 383},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -33,14 +33,14 @@ func TestRoutingTable_bucketFor(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutingTable(t *testing.T) {
|
func TestRoutingTable(t *testing.T) {
|
||||||
n1 := newBitmapFromHex("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
n1 := BitmapFromHexP("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||||
n2 := newBitmapFromHex("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
n2 := BitmapFromHexP("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||||
n3 := newBitmapFromHex("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
n3 := BitmapFromHexP("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||||
rt := newRoutingTable(&Node{n1, net.ParseIP("127.0.0.1"), 8000})
|
rt := newRoutingTable(&Node{n1, net.ParseIP("127.0.0.1"), 8000, ""})
|
||||||
rt.Update(Node{n2, net.ParseIP("127.0.0.1"), 8001})
|
rt.Update(Node{n2, net.ParseIP("127.0.0.1"), 8001, ""})
|
||||||
rt.Update(Node{n3, net.ParseIP("127.0.0.1"), 8002})
|
rt.Update(Node{n3, net.ParseIP("127.0.0.1"), 8002, ""})
|
||||||
|
|
||||||
contacts := rt.GetClosest(newBitmapFromHex("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1)
|
contacts := rt.GetClosest(BitmapFromHexP("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1)
|
||||||
if len(contacts) != 1 {
|
if len(contacts) != 1 {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
return
|
return
|
||||||
|
@ -64,7 +64,7 @@ func TestRoutingTable(t *testing.T) {
|
||||||
|
|
||||||
func TestCompactEncoding(t *testing.T) {
|
func TestCompactEncoding(t *testing.T) {
|
||||||
n := Node{
|
n := Node{
|
||||||
id: newBitmapFromHex("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"),
|
id: BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"),
|
||||||
ip: net.ParseIP("1.2.3.4"),
|
ip: net.ParseIP("1.2.3.4"),
|
||||||
port: int(55<<8 + 66),
|
port: int(55<<8 + 66),
|
||||||
}
|
}
|
||||||
|
|
43
dht/rpc.go
43
dht/rpc.go
|
@ -15,7 +15,7 @@ import (
|
||||||
|
|
||||||
// handlePacket handles packets received from udp.
|
// handlePacket handles packets received from udp.
|
||||||
func handlePacket(dht *DHT, pkt packet) {
|
func handlePacket(dht *DHT, pkt packet) {
|
||||||
//log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data))
|
//log.Debugf("[%s] Received message from %s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||||
|
|
||||||
if !util.InSlice(string(pkt.data[0:5]), []string{"d1:0i", "di0ei"}) {
|
if !util.InSlice(string(pkt.data[0:5]), []string{"d1:0i", "di0ei"}) {
|
||||||
log.Errorf("[%s] data is not a well-formatted dict: (%d bytes) %s", dht.node.id.HexShort(), len(pkt.data), hex.EncodeToString(pkt.data))
|
log.Errorf("[%s] data is not a well-formatted dict: (%d bytes) %s", dht.node.id.HexShort(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||||
|
@ -32,7 +32,7 @@ func handlePacket(dht *DHT, pkt packet) {
|
||||||
request := Request{}
|
request := Request{}
|
||||||
err := bencode.DecodeBytes(pkt.data, &request)
|
err := bencode.DecodeBytes(pkt.data, &request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("[%s] error decoding request: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
log.Errorf("[%s] error decoding request from %s: %s: (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, request.ArgsDebug())
|
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, request.ArgsDebug())
|
||||||
|
@ -42,7 +42,7 @@ func handlePacket(dht *DHT, pkt packet) {
|
||||||
response := Response{}
|
response := Response{}
|
||||||
err := bencode.DecodeBytes(pkt.data, &response)
|
err := bencode.DecodeBytes(pkt.data, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("[%s] error decoding response: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
log.Errorf("[%s] error decoding response from %s: %s: (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug())
|
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug())
|
||||||
|
@ -52,7 +52,7 @@ func handlePacket(dht *DHT, pkt packet) {
|
||||||
e := Error{}
|
e := Error{}
|
||||||
err := bencode.DecodeBytes(pkt.data, &e)
|
err := bencode.DecodeBytes(pkt.data, &e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("[%s] error decoding error: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
log.Errorf("[%s] error decoding error from %s: %s: (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType)
|
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType)
|
||||||
|
@ -72,19 +72,28 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch request.Method {
|
switch request.Method {
|
||||||
|
default:
|
||||||
|
// send(dht, addr, makeError(t, protocolError, "invalid q"))
|
||||||
|
log.Errorln("invalid request method")
|
||||||
|
return
|
||||||
case pingMethod:
|
case pingMethod:
|
||||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse})
|
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse})
|
||||||
case storeMethod:
|
case storeMethod:
|
||||||
// TODO: we should be sending the IP in the request, not just using the sender's IP
|
// TODO: we should be sending the IP in the request, not just using the sender's IP
|
||||||
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
|
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
|
||||||
|
if dht.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) {
|
||||||
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
||||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
|
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
|
||||||
|
} else {
|
||||||
|
send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id, ExceptionType: "invalid-token"})
|
||||||
|
}
|
||||||
case findNodeMethod:
|
case findNodeMethod:
|
||||||
if request.Arg == nil {
|
if request.Arg == nil {
|
||||||
log.Errorln("request is missing arg")
|
log.Errorln("request is missing arg")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
doFindNodes(dht, addr, request)
|
send(dht, addr, getFindResponse(dht, request))
|
||||||
|
|
||||||
case findValueMethod:
|
case findValueMethod:
|
||||||
if request.Arg == nil {
|
if request.Arg == nil {
|
||||||
log.Errorln("request is missing arg")
|
log.Errorln("request is missing arg")
|
||||||
|
@ -97,32 +106,30 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
NodeID: dht.node.id,
|
NodeID: dht.node.id,
|
||||||
FindValueKey: request.Arg.RawString(),
|
FindValueKey: request.Arg.RawString(),
|
||||||
FindNodeData: nodes,
|
FindNodeData: nodes,
|
||||||
|
Token: dht.tokens.Get(request.NodeID, addr),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
doFindNodes(dht, addr, request)
|
res := getFindResponse(dht, request)
|
||||||
|
res.Token = dht.tokens.Get(request.NodeID, addr)
|
||||||
|
send(dht, addr, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
|
||||||
// send(dht, addr, makeError(t, protocolError, "invalid q"))
|
|
||||||
log.Errorln("invalid request method")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
node := Node{id: request.NodeID, ip: addr.IP, port: addr.Port}
|
node := Node{id: request.NodeID, ip: addr.IP, port: addr.Port}
|
||||||
dht.rt.Update(node)
|
dht.rt.Update(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
|
func getFindResponse(dht *DHT, request Request) Response {
|
||||||
closestNodes := dht.rt.GetClosest(*request.Arg, bucketSize)
|
closestNodes := dht.rt.GetClosest(*request.Arg, bucketSize)
|
||||||
if len(closestNodes) > 0 {
|
response := Response{
|
||||||
response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))}
|
ID: request.ID,
|
||||||
|
NodeID: dht.node.id,
|
||||||
|
FindNodeData: make([]Node, len(closestNodes)),
|
||||||
|
}
|
||||||
for i, n := range closestNodes {
|
for i, n := range closestNodes {
|
||||||
response.FindNodeData[i] = n
|
response.FindNodeData[i] = n
|
||||||
}
|
}
|
||||||
send(dht, addr, response)
|
return response
|
||||||
} else {
|
|
||||||
log.Warn("no nodes in routing table")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleResponse handles responses received from udp.
|
// handleResponse handles responses received from udp.
|
||||||
|
|
|
@ -88,8 +88,8 @@ func (t *testUDPConn) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPing(t *testing.T) {
|
func TestPing(t *testing.T) {
|
||||||
dhtNodeID := newRandomBitmap()
|
dhtNodeID := RandomBitmapP()
|
||||||
testNodeID := newRandomBitmap()
|
testNodeID := RandomBitmapP()
|
||||||
|
|
||||||
conn := newTestUDPConn("127.0.0.1:21217")
|
conn := newTestUDPConn("127.0.0.1:21217")
|
||||||
|
|
||||||
|
@ -183,8 +183,8 @@ func TestPing(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStore(t *testing.T) {
|
func TestStore(t *testing.T) {
|
||||||
dhtNodeID := newRandomBitmap()
|
dhtNodeID := RandomBitmapP()
|
||||||
testNodeID := newRandomBitmap()
|
testNodeID := RandomBitmapP()
|
||||||
|
|
||||||
conn := newTestUDPConn("127.0.0.1:21217")
|
conn := newTestUDPConn("127.0.0.1:21217")
|
||||||
|
|
||||||
|
@ -199,7 +199,7 @@ func TestStore(t *testing.T) {
|
||||||
defer dht.Shutdown()
|
defer dht.Shutdown()
|
||||||
|
|
||||||
messageID := newMessageID()
|
messageID := newMessageID()
|
||||||
blobHashToStore := newRandomBitmap()
|
blobHashToStore := RandomBitmapP()
|
||||||
|
|
||||||
storeRequest := Request{
|
storeRequest := Request{
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
|
@ -208,7 +208,7 @@ func TestStore(t *testing.T) {
|
||||||
StoreArgs: &storeArgs{
|
StoreArgs: &storeArgs{
|
||||||
BlobHash: blobHashToStore,
|
BlobHash: blobHashToStore,
|
||||||
Value: storeArgsValue{
|
Value: storeArgsValue{
|
||||||
Token: "arst",
|
Token: dht.tokens.Get(testNodeID, conn.addr),
|
||||||
LbryID: testNodeID,
|
LbryID: testNodeID,
|
||||||
Port: 9999,
|
Port: 9999,
|
||||||
},
|
},
|
||||||
|
@ -280,8 +280,8 @@ func TestStore(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindNode(t *testing.T) {
|
func TestFindNode(t *testing.T) {
|
||||||
dhtNodeID := newRandomBitmap()
|
dhtNodeID := RandomBitmapP()
|
||||||
testNodeID := newRandomBitmap()
|
testNodeID := RandomBitmapP()
|
||||||
|
|
||||||
conn := newTestUDPConn("127.0.0.1:21217")
|
conn := newTestUDPConn("127.0.0.1:21217")
|
||||||
|
|
||||||
|
@ -297,13 +297,13 @@ func TestFindNode(t *testing.T) {
|
||||||
nodesToInsert := 3
|
nodesToInsert := 3
|
||||||
var nodes []Node
|
var nodes []Node
|
||||||
for i := 0; i < nodesToInsert; i++ {
|
for i := 0; i < nodesToInsert; i++ {
|
||||||
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
|
n := Node{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
|
||||||
nodes = append(nodes, n)
|
nodes = append(nodes, n)
|
||||||
dht.rt.Update(n)
|
dht.rt.Update(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
messageID := newMessageID()
|
messageID := newMessageID()
|
||||||
blobHashToFind := newRandomBitmap()
|
blobHashToFind := RandomBitmapP()
|
||||||
|
|
||||||
request := Request{
|
request := Request{
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
|
@ -338,27 +338,17 @@ func TestFindNode(t *testing.T) {
|
||||||
t.Fatal("missing payload field")
|
t.Fatal("missing payload field")
|
||||||
}
|
}
|
||||||
|
|
||||||
payload, ok := response[headerPayloadField].(map[string]interface{})
|
contacts, ok := response[headerPayloadField].([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("payload is not a dictionary")
|
t.Fatal("payload is not a list")
|
||||||
}
|
|
||||||
|
|
||||||
contactsList, ok := payload["contacts"]
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("payload is missing 'contacts' key")
|
|
||||||
}
|
|
||||||
|
|
||||||
contacts, ok := contactsList.([]interface{})
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("'contacts' is not a list")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
verifyContacts(t, contacts, nodes)
|
verifyContacts(t, contacts, nodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindValueExisting(t *testing.T) {
|
func TestFindValueExisting(t *testing.T) {
|
||||||
dhtNodeID := newRandomBitmap()
|
dhtNodeID := RandomBitmapP()
|
||||||
testNodeID := newRandomBitmap()
|
testNodeID := RandomBitmapP()
|
||||||
|
|
||||||
conn := newTestUDPConn("127.0.0.1:21217")
|
conn := newTestUDPConn("127.0.0.1:21217")
|
||||||
|
|
||||||
|
@ -375,7 +365,7 @@ func TestFindValueExisting(t *testing.T) {
|
||||||
nodesToInsert := 3
|
nodesToInsert := 3
|
||||||
var nodes []Node
|
var nodes []Node
|
||||||
for i := 0; i < nodesToInsert; i++ {
|
for i := 0; i < nodesToInsert; i++ {
|
||||||
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
|
n := Node{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
|
||||||
nodes = append(nodes, n)
|
nodes = append(nodes, n)
|
||||||
dht.rt.Update(n)
|
dht.rt.Update(n)
|
||||||
}
|
}
|
||||||
|
@ -383,9 +373,9 @@ func TestFindValueExisting(t *testing.T) {
|
||||||
//data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565")
|
//data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565")
|
||||||
|
|
||||||
messageID := newMessageID()
|
messageID := newMessageID()
|
||||||
valueToFind := newRandomBitmap()
|
valueToFind := RandomBitmapP()
|
||||||
|
|
||||||
nodeToFind := Node{id: newRandomBitmap(), ip: net.ParseIP("1.2.3.4"), port: 1286}
|
nodeToFind := Node{id: RandomBitmapP(), ip: net.ParseIP("1.2.3.4"), port: 1286}
|
||||||
dht.store.Upsert(valueToFind, nodeToFind)
|
dht.store.Upsert(valueToFind, nodeToFind)
|
||||||
dht.store.Upsert(valueToFind, nodeToFind)
|
dht.store.Upsert(valueToFind, nodeToFind)
|
||||||
dht.store.Upsert(valueToFind, nodeToFind)
|
dht.store.Upsert(valueToFind, nodeToFind)
|
||||||
|
@ -442,8 +432,8 @@ func TestFindValueExisting(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindValueFallbackToFindNode(t *testing.T) {
|
func TestFindValueFallbackToFindNode(t *testing.T) {
|
||||||
dhtNodeID := newRandomBitmap()
|
dhtNodeID := RandomBitmapP()
|
||||||
testNodeID := newRandomBitmap()
|
testNodeID := RandomBitmapP()
|
||||||
|
|
||||||
conn := newTestUDPConn("127.0.0.1:21217")
|
conn := newTestUDPConn("127.0.0.1:21217")
|
||||||
|
|
||||||
|
@ -460,13 +450,13 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
|
||||||
nodesToInsert := 3
|
nodesToInsert := 3
|
||||||
var nodes []Node
|
var nodes []Node
|
||||||
for i := 0; i < nodesToInsert; i++ {
|
for i := 0; i < nodesToInsert; i++ {
|
||||||
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
|
n := Node{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
|
||||||
nodes = append(nodes, n)
|
nodes = append(nodes, n)
|
||||||
dht.rt.Update(n)
|
dht.rt.Update(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
messageID := newMessageID()
|
messageID := newMessageID()
|
||||||
valueToFind := newRandomBitmap()
|
valueToFind := RandomBitmapP()
|
||||||
|
|
||||||
request := Request{
|
request := Request{
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
|
@ -506,7 +496,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
|
||||||
t.Fatal("payload is not a dictionary")
|
t.Fatal("payload is not a dictionary")
|
||||||
}
|
}
|
||||||
|
|
||||||
contactsList, ok := payload["contacts"]
|
contactsList, ok := payload[contactsField]
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("payload is missing 'contacts' key")
|
t.Fatal("payload is missing 'contacts' key")
|
||||||
}
|
}
|
||||||
|
|
14
dht/store.go
14
dht/store.go
|
@ -11,30 +11,30 @@ type peer struct {
|
||||||
|
|
||||||
type peerStore struct {
|
type peerStore struct {
|
||||||
// map of blob hashes to (map of node IDs to bools)
|
// map of blob hashes to (map of node IDs to bools)
|
||||||
hashes map[bitmap]map[bitmap]bool
|
hashes map[Bitmap]map[Bitmap]bool
|
||||||
// map of node IDs to peers
|
// map of node IDs to peers
|
||||||
nodeInfo map[bitmap]peer
|
nodeInfo map[Bitmap]peer
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPeerStore() *peerStore {
|
func newPeerStore() *peerStore {
|
||||||
return &peerStore{
|
return &peerStore{
|
||||||
hashes: make(map[bitmap]map[bitmap]bool),
|
hashes: make(map[Bitmap]map[Bitmap]bool),
|
||||||
nodeInfo: make(map[bitmap]peer),
|
nodeInfo: make(map[Bitmap]peer),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *peerStore) Upsert(blobHash bitmap, node Node) {
|
func (s *peerStore) Upsert(blobHash Bitmap, node Node) {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
if _, ok := s.hashes[blobHash]; !ok {
|
if _, ok := s.hashes[blobHash]; !ok {
|
||||||
s.hashes[blobHash] = make(map[bitmap]bool)
|
s.hashes[blobHash] = make(map[Bitmap]bool)
|
||||||
}
|
}
|
||||||
s.hashes[blobHash][node.id] = true
|
s.hashes[blobHash][node.id] = true
|
||||||
s.nodeInfo[node.id] = peer{node: node}
|
s.nodeInfo[node.id] = peer{node: node}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *peerStore) Get(blobHash bitmap) []Node {
|
func (s *peerStore) Get(blobHash Bitmap) []Node {
|
||||||
s.lock.RLock()
|
s.lock.RLock()
|
||||||
defer s.lock.RUnlock()
|
defer s.lock.RUnlock()
|
||||||
var nodes []Node
|
var nodes []Node
|
||||||
|
|
80
dht/token_manager.go
Normal file
80
dht/token_manager.go
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
package dht
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lbryio/lbry.go/stopOnce"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tokenManager struct {
|
||||||
|
secret []byte
|
||||||
|
prevSecret []byte
|
||||||
|
lock *sync.RWMutex
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
done *stopOnce.Stopper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tokenManager) Start(interval time.Duration) {
|
||||||
|
tm.secret = make([]byte, 64)
|
||||||
|
tm.prevSecret = make([]byte, 64)
|
||||||
|
tm.lock = &sync.RWMutex{}
|
||||||
|
tm.wg = &sync.WaitGroup{}
|
||||||
|
tm.done = stopOnce.New()
|
||||||
|
|
||||||
|
tm.rotateSecret()
|
||||||
|
|
||||||
|
tm.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer tm.wg.Done()
|
||||||
|
tick := time.NewTicker(interval)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-tick.C:
|
||||||
|
tm.rotateSecret()
|
||||||
|
case <-tm.done.Chan():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tokenManager) Stop() {
|
||||||
|
tm.done.Stop()
|
||||||
|
tm.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tokenManager) Get(nodeID Bitmap, addr *net.UDPAddr) string {
|
||||||
|
return genToken(tm.secret, nodeID, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tokenManager) Verify(token string, nodeID Bitmap, addr *net.UDPAddr) bool {
|
||||||
|
return token == genToken(tm.secret, nodeID, addr) || token == genToken(tm.prevSecret, nodeID, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genToken(secret []byte, nodeID Bitmap, addr *net.UDPAddr) string {
|
||||||
|
buf := bytes.Buffer{}
|
||||||
|
buf.Write(nodeID[:])
|
||||||
|
buf.Write(addr.IP)
|
||||||
|
buf.WriteString(strconv.Itoa(addr.Port))
|
||||||
|
buf.Write(secret)
|
||||||
|
t := sha256.Sum256(buf.Bytes())
|
||||||
|
return string(t[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tokenManager) rotateSecret() {
|
||||||
|
tm.lock.Lock()
|
||||||
|
defer tm.lock.Unlock()
|
||||||
|
|
||||||
|
copy(tm.prevSecret, tm.secret)
|
||||||
|
|
||||||
|
_, err := rand.Read(tm.secret)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue