bitmaps in more places
This commit is contained in:
parent
1acfd13ee9
commit
dd8333db33
10 changed files with 83 additions and 81 deletions
|
@ -61,6 +61,11 @@ func (b bitmap) PrefixLen() int {
|
||||||
return numBuckets
|
return numBuckets
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b bitmap) MarshalBencode() ([]byte, error) {
|
||||||
|
str := string(b[:])
|
||||||
|
return bencode.EncodeBytes(str)
|
||||||
|
}
|
||||||
|
|
||||||
func (b *bitmap) UnmarshalBencode(encoded []byte) error {
|
func (b *bitmap) UnmarshalBencode(encoded []byte) error {
|
||||||
var str string
|
var str string
|
||||||
err := bencode.DecodeBytes(encoded, &str)
|
err := bencode.DecodeBytes(encoded, &str)
|
||||||
|
@ -68,17 +73,12 @@ func (b *bitmap) UnmarshalBencode(encoded []byte) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(str) != nodeIDLength {
|
if len(str) != nodeIDLength {
|
||||||
return errors.Err("invalid node ID length")
|
return errors.Err("invalid bitmap length")
|
||||||
}
|
}
|
||||||
copy(b[:], str)
|
copy(b[:], str)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bitmap) MarshalBencode() ([]byte, error) {
|
|
||||||
str := string(b[:])
|
|
||||||
return bencode.EncodeBytes(str)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBitmapFromBytes(data []byte) bitmap {
|
func newBitmapFromBytes(data []byte) bitmap {
|
||||||
if len(data) != nodeIDLength {
|
if len(data) != nodeIDLength {
|
||||||
panic("invalid bitmap of length " + strconv.Itoa(len(data)))
|
panic("invalid bitmap of length " + strconv.Itoa(len(data)))
|
||||||
|
|
10
dht/dht.go
10
dht/dht.go
|
@ -193,7 +193,7 @@ func (dht *DHT) join() {
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpNode := Node{id: newRandomBitmap(), ip: raddr.IP, port: raddr.Port}
|
tmpNode := Node{id: newRandomBitmap(), ip: raddr.IP, port: raddr.Port}
|
||||||
res := dht.tm.Send(tmpNode, &Request{Method: pingMethod})
|
res := dht.tm.Send(tmpNode, Request{Method: pingMethod})
|
||||||
if res == nil {
|
if res == nil {
|
||||||
log.Errorf("[%s] join: no response from seed node %s", dht.node.id.HexShort(), addr)
|
log.Errorf("[%s] join: no response from seed node %s", dht.node.id.HexShort(), addr)
|
||||||
}
|
}
|
||||||
|
@ -260,8 +260,8 @@ func (dht *DHT) Get(hash bitmap) ([]Node, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put announces to the DHT that this node has the blob for the given hash
|
// Announce announces to the DHT that this node has the blob for the given hash
|
||||||
func (dht *DHT) Put(hash bitmap) error {
|
func (dht *DHT) Announce(hash bitmap) error {
|
||||||
nf := newNodeFinder(dht, hash, false)
|
nf := newNodeFinder(dht, hash, false)
|
||||||
res, err := nf.Find()
|
res, err := nf.Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -269,10 +269,10 @@ func (dht *DHT) Put(hash bitmap) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, node := range res.Nodes {
|
for _, node := range res.Nodes {
|
||||||
send(dht, node.Addr(), &Request{
|
send(dht, node.Addr(), Request{
|
||||||
Method: storeMethod,
|
Method: storeMethod,
|
||||||
StoreArgs: &storeArgs{
|
StoreArgs: &storeArgs{
|
||||||
BlobHash: hash.RawString(),
|
BlobHash: hash,
|
||||||
Value: storeArgsValue{
|
Value: storeArgsValue{
|
||||||
Token: "",
|
Token: "",
|
||||||
LbryID: dht.node.id,
|
LbryID: dht.node.id,
|
||||||
|
|
|
@ -126,6 +126,17 @@ func (r *Request) UnmarshalBencode(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r Request) ArgsDebug() string {
|
||||||
|
argsCopy := make([]string, len(r.Args))
|
||||||
|
copy(argsCopy, r.Args)
|
||||||
|
for k, v := range argsCopy {
|
||||||
|
if len(v) == nodeIDLength {
|
||||||
|
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(argsCopy, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
type storeArgsValue struct {
|
type storeArgsValue struct {
|
||||||
Token string `bencode:"token"`
|
Token string `bencode:"token"`
|
||||||
LbryID bitmap `bencode:"lbryid"`
|
LbryID bitmap `bencode:"lbryid"`
|
||||||
|
@ -133,7 +144,7 @@ type storeArgsValue struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type storeArgs struct {
|
type storeArgs struct {
|
||||||
BlobHash string
|
BlobHash bitmap
|
||||||
Value storeArgsValue
|
Value storeArgsValue
|
||||||
NodeID bitmap
|
NodeID bitmap
|
||||||
SelfStore bool // this is an int on the wire
|
SelfStore bool // this is an int on the wire
|
||||||
|
|
|
@ -45,7 +45,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) {
|
if storeArgs.BlobHash.Hex() != strings.ToLower(blobHash) {
|
||||||
t.Error("blob hash mismatch")
|
t.Error("blob hash mismatch")
|
||||||
}
|
}
|
||||||
if storeArgs.Value.LbryID.Hex() != strings.ToLower(lbryID) {
|
if storeArgs.Value.LbryID.Hex() != strings.ToLower(lbryID) {
|
||||||
|
|
|
@ -99,7 +99,7 @@ func (nf *nodeFinder) iterationWorker(num int) {
|
||||||
continue // cannot contact self
|
continue // cannot contact self
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &Request{Args: []string{nf.target.RawString()}}
|
req := Request{Args: []string{nf.target.RawString()}}
|
||||||
if nf.findValue {
|
if nf.findValue {
|
||||||
req.Method = findValueMethod
|
req.Method = findValueMethod
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -109,10 +109,11 @@ func TestNodeFinder_FindValue(t *testing.T) {
|
||||||
|
|
||||||
time.Sleep(1 * time.Second) // give dhts a chance to connect
|
time.Sleep(1 * time.Second) // give dhts a chance to connect
|
||||||
|
|
||||||
|
blobHashToFind := newRandomBitmap()
|
||||||
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
|
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
|
||||||
dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind)
|
dht1.store.Upsert(blobHashToFind, nodeToFind)
|
||||||
|
|
||||||
nf := newNodeFinder(dht3, nodeToFind.id, true)
|
nf := newNodeFinder(dht3, blobHashToFind, true)
|
||||||
res, err := nf.Find()
|
res, err := nf.Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
44
dht/rpc.go
44
dht/rpc.go
|
@ -3,7 +3,6 @@ package dht
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lbryio/lbry.go/util"
|
"github.com/lbryio/lbry.go/util"
|
||||||
|
@ -35,7 +34,7 @@ func handlePacket(dht *DHT, pkt packet) {
|
||||||
log.Errorf("[%s] error decoding request: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
log.Errorf("[%s] error decoding request: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, argsToString(request.Args))
|
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, request.ArgsDebug())
|
||||||
handleRequest(dht, pkt.raddr, request)
|
handleRequest(dht, pkt.raddr, request)
|
||||||
|
|
||||||
case '0' + responseType:
|
case '0' + responseType:
|
||||||
|
@ -75,17 +74,13 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
case pingMethod:
|
case pingMethod:
|
||||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse})
|
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse})
|
||||||
case storeMethod:
|
case storeMethod:
|
||||||
if request.StoreArgs.BlobHash == "" {
|
|
||||||
log.Errorln("blobhash is empty")
|
|
||||||
return // nothing to store
|
|
||||||
}
|
|
||||||
// TODO: we should be sending the IP in the request, not just using the sender's IP
|
// TODO: we should be sending the IP in the request, not just using the sender's IP
|
||||||
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
|
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
|
||||||
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
||||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
|
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
|
||||||
case findNodeMethod:
|
case findNodeMethod:
|
||||||
if len(request.Args) < 1 {
|
if len(request.Args) != 1 {
|
||||||
log.Errorln("nothing to find")
|
log.Errorln("invalid number of args")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(request.Args[0]) != nodeIDLength {
|
if len(request.Args[0]) != nodeIDLength {
|
||||||
|
@ -94,20 +89,22 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
}
|
}
|
||||||
doFindNodes(dht, addr, request)
|
doFindNodes(dht, addr, request)
|
||||||
case findValueMethod:
|
case findValueMethod:
|
||||||
if len(request.Args) < 1 {
|
if len(request.Args) != 1 {
|
||||||
log.Errorln("nothing to find")
|
log.Errorln("invalid number of args")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(request.Args[0]) != nodeIDLength {
|
if len(request.Args[0]) != nodeIDLength {
|
||||||
log.Errorln("invalid node id")
|
log.Errorln("invalid blob hash")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 {
|
if nodes := dht.store.Get(newBitmapFromString(request.Args[0])); len(nodes) > 0 {
|
||||||
response := Response{ID: request.ID, NodeID: dht.node.id}
|
send(dht, addr, Response{
|
||||||
response.FindValueKey = request.Args[0]
|
ID: request.ID,
|
||||||
response.FindNodeData = nodes
|
NodeID: dht.node.id,
|
||||||
send(dht, addr, response)
|
FindValueKey: request.Args[0],
|
||||||
|
FindNodeData: nodes,
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
doFindNodes(dht, addr, request)
|
doFindNodes(dht, addr, request)
|
||||||
}
|
}
|
||||||
|
@ -140,7 +137,7 @@ func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
|
func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
|
||||||
tx := dht.tm.Find(response.ID, addr)
|
tx := dht.tm.Find(response.ID, addr)
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
tx.res <- &response
|
tx.res <- response
|
||||||
}
|
}
|
||||||
|
|
||||||
node := Node{id: response.NodeID, ip: addr.IP, port: addr.Port}
|
node := Node{id: response.NodeID, ip: addr.IP, port: addr.Port}
|
||||||
|
@ -163,7 +160,7 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
||||||
|
|
||||||
if req, ok := data.(Request); ok {
|
if req, ok := data.(Request); ok {
|
||||||
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
|
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
|
||||||
dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, argsToString(req.Args))
|
dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, req.ArgsDebug())
|
||||||
} else if res, ok := data.(Response); ok {
|
} else if res, ok := data.(Response); ok {
|
||||||
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
|
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
|
||||||
dht.node.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug())
|
dht.node.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug())
|
||||||
|
@ -176,14 +173,3 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
||||||
_, err = dht.conn.WriteToUDP(encoded, addr)
|
_, err = dht.conn.WriteToUDP(encoded, addr)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func argsToString(args []string) string {
|
|
||||||
argsCopy := make([]string, len(args))
|
|
||||||
copy(argsCopy, args)
|
|
||||||
for k, v := range argsCopy {
|
|
||||||
if len(v) == nodeIDLength {
|
|
||||||
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(argsCopy, ", ")
|
|
||||||
}
|
|
||||||
|
|
|
@ -199,7 +199,7 @@ func TestStore(t *testing.T) {
|
||||||
defer dht.Shutdown()
|
defer dht.Shutdown()
|
||||||
|
|
||||||
messageID := newMessageID()
|
messageID := newMessageID()
|
||||||
blobHashToStore := newRandomBitmap().RawString()
|
blobHashToStore := newRandomBitmap()
|
||||||
|
|
||||||
storeRequest := Request{
|
storeRequest := Request{
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
|
@ -383,7 +383,7 @@ func TestFindValueExisting(t *testing.T) {
|
||||||
//data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565")
|
//data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565")
|
||||||
|
|
||||||
messageID := newMessageID()
|
messageID := newMessageID()
|
||||||
valueToFind := newRandomBitmap().RawString()
|
valueToFind := newRandomBitmap()
|
||||||
|
|
||||||
nodeToFind := Node{id: newRandomBitmap(), ip: net.ParseIP("1.2.3.4"), port: 1286}
|
nodeToFind := Node{id: newRandomBitmap(), ip: net.ParseIP("1.2.3.4"), port: 1286}
|
||||||
dht.store.Upsert(valueToFind, nodeToFind)
|
dht.store.Upsert(valueToFind, nodeToFind)
|
||||||
|
@ -394,7 +394,7 @@ func TestFindValueExisting(t *testing.T) {
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
NodeID: testNodeID,
|
NodeID: testNodeID,
|
||||||
Method: findValueMethod,
|
Method: findValueMethod,
|
||||||
Args: []string{valueToFind},
|
Args: []string{valueToFind.RawString()},
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := bencode.EncodeBytes(request)
|
data, err := bencode.EncodeBytes(request)
|
||||||
|
@ -428,7 +428,7 @@ func TestFindValueExisting(t *testing.T) {
|
||||||
t.Fatal("payload is not a dictionary")
|
t.Fatal("payload is not a dictionary")
|
||||||
}
|
}
|
||||||
|
|
||||||
compactContacts, ok := payload[valueToFind]
|
compactContacts, ok := payload[valueToFind.RawString()]
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("payload is missing key for search value")
|
t.Fatal("payload is missing key for search value")
|
||||||
}
|
}
|
||||||
|
|
18
dht/store.go
18
dht/store.go
|
@ -10,33 +10,35 @@ type peer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type peerStore struct {
|
type peerStore struct {
|
||||||
nodeIDs map[string]map[bitmap]bool
|
// map of blob hashes to (map of node IDs to bools)
|
||||||
|
nodeIDs map[bitmap]map[bitmap]bool
|
||||||
|
// map of node IDs to peers
|
||||||
nodeInfo map[bitmap]peer
|
nodeInfo map[bitmap]peer
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPeerStore() *peerStore {
|
func newPeerStore() *peerStore {
|
||||||
return &peerStore{
|
return &peerStore{
|
||||||
nodeIDs: make(map[string]map[bitmap]bool),
|
nodeIDs: make(map[bitmap]map[bitmap]bool),
|
||||||
nodeInfo: make(map[bitmap]peer),
|
nodeInfo: make(map[bitmap]peer),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *peerStore) Upsert(key string, node Node) {
|
func (s *peerStore) Upsert(blobHash bitmap, node Node) {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
if _, ok := s.nodeIDs[key]; !ok {
|
if _, ok := s.nodeIDs[blobHash]; !ok {
|
||||||
s.nodeIDs[key] = make(map[bitmap]bool)
|
s.nodeIDs[blobHash] = make(map[bitmap]bool)
|
||||||
}
|
}
|
||||||
s.nodeIDs[key][node.id] = true
|
s.nodeIDs[blobHash][node.id] = true
|
||||||
s.nodeInfo[node.id] = peer{node: node}
|
s.nodeInfo[node.id] = peer{node: node}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *peerStore) Get(key string) []Node {
|
func (s *peerStore) Get(blobHash bitmap) []Node {
|
||||||
s.lock.RLock()
|
s.lock.RLock()
|
||||||
defer s.lock.RUnlock()
|
defer s.lock.RUnlock()
|
||||||
var nodes []Node
|
var nodes []Node
|
||||||
if ids, ok := s.nodeIDs[key]; ok {
|
if ids, ok := s.nodeIDs[blobHash]; ok {
|
||||||
for id := range ids {
|
for id := range ids {
|
||||||
peer, ok := s.nodeInfo[id]
|
peer, ok := s.nodeInfo[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
|
@ -9,21 +9,21 @@ import (
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// query represents the query data included queried node and query-formed data.
|
// transaction represents a single query to the dht. it stores the queried node, the request, and the response channel
|
||||||
type transaction struct {
|
type transaction struct {
|
||||||
node Node
|
node Node
|
||||||
req *Request
|
req Request
|
||||||
res chan *Response
|
res chan Response
|
||||||
}
|
}
|
||||||
|
|
||||||
// transactionManager represents the manager of transactions.
|
// transactionManager keeps track of the outstanding transactions
|
||||||
type transactionManager struct {
|
type transactionManager struct {
|
||||||
|
dht *DHT
|
||||||
lock *sync.RWMutex
|
lock *sync.RWMutex
|
||||||
transactions map[messageID]*transaction
|
transactions map[messageID]*transaction
|
||||||
dht *DHT
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTransactionManager returns new transactionManager pointer.
|
// newTransactionManager returns a new transactionManager
|
||||||
func newTransactionManager(dht *DHT) *transactionManager {
|
func newTransactionManager(dht *DHT) *transactionManager {
|
||||||
return &transactionManager{
|
return &transactionManager{
|
||||||
lock: &sync.RWMutex{},
|
lock: &sync.RWMutex{},
|
||||||
|
@ -32,36 +32,36 @@ func newTransactionManager(dht *DHT) *transactionManager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert adds a transaction to transactionManager.
|
// insert adds a transaction to the manager.
|
||||||
func (tm *transactionManager) insert(trans *transaction) {
|
func (tm *transactionManager) insert(tx *transaction) {
|
||||||
tm.lock.Lock()
|
tm.lock.Lock()
|
||||||
defer tm.lock.Unlock()
|
defer tm.lock.Unlock()
|
||||||
tm.transactions[trans.req.ID] = trans
|
tm.transactions[tx.req.ID] = tx
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete removes a transaction from transactionManager.
|
// delete removes a transaction from the manager.
|
||||||
func (tm *transactionManager) delete(id messageID) {
|
func (tm *transactionManager) delete(id messageID) {
|
||||||
tm.lock.Lock()
|
tm.lock.Lock()
|
||||||
defer tm.lock.Unlock()
|
defer tm.lock.Unlock()
|
||||||
delete(tm.transactions, id)
|
delete(tm.transactions, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// find transaction for id. optionally ensure that addr matches node from transaction
|
// Find finds a transaction for the given id. it optionally ensures that addr matches node from transaction
|
||||||
func (tm *transactionManager) Find(id messageID, addr *net.UDPAddr) *transaction {
|
func (tm *transactionManager) Find(id messageID, addr *net.UDPAddr) *transaction {
|
||||||
tm.lock.RLock()
|
tm.lock.RLock()
|
||||||
defer tm.lock.RUnlock()
|
defer tm.lock.RUnlock()
|
||||||
|
|
||||||
t, ok := tm.transactions[id]
|
t, ok := tm.transactions[id]
|
||||||
if !ok {
|
if !ok || (addr != nil && t.node.Addr().String() != addr.String()) {
|
||||||
return nil
|
|
||||||
} else if addr != nil && t.node.Addr().String() != addr.String() {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Request) <-chan *Response {
|
// SendAsync sends a transaction and returns a channel that will eventually contain the transaction response
|
||||||
|
// The response channel is closed when the transaction is completed or times out.
|
||||||
|
func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req Request) <-chan *Response {
|
||||||
if node.id.Equals(tm.dht.node.id) {
|
if node.id.Equals(tm.dht.node.id) {
|
||||||
log.Error("sending query to self")
|
log.Error("sending query to self")
|
||||||
return nil
|
return nil
|
||||||
|
@ -74,24 +74,24 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req
|
||||||
|
|
||||||
req.ID = newMessageID()
|
req.ID = newMessageID()
|
||||||
req.NodeID = tm.dht.node.id
|
req.NodeID = tm.dht.node.id
|
||||||
trans := &transaction{
|
tx := &transaction{
|
||||||
node: node,
|
node: node,
|
||||||
req: req,
|
req: req,
|
||||||
res: make(chan *Response),
|
res: make(chan Response),
|
||||||
}
|
}
|
||||||
|
|
||||||
tm.insert(trans)
|
tm.insert(tx)
|
||||||
defer tm.delete(trans.req.ID)
|
defer tm.delete(tx.req.ID)
|
||||||
|
|
||||||
for i := 0; i < udpRetry; i++ {
|
for i := 0; i < udpRetry; i++ {
|
||||||
if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil {
|
if err := send(tm.dht, node.Addr(), tx.req); err != nil {
|
||||||
log.Errorf("send error: ", err.Error())
|
log.Errorf("send error: ", err.Error())
|
||||||
continue // try again? return?
|
continue // try again? return?
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case res := <-trans.res:
|
case res := <-tx.res:
|
||||||
ch <- res
|
ch <- &res
|
||||||
return
|
return
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
|
@ -100,13 +100,15 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req
|
||||||
}
|
}
|
||||||
|
|
||||||
// if request timed out each time
|
// if request timed out each time
|
||||||
tm.dht.rt.RemoveByID(trans.node.id)
|
tm.dht.rt.RemoveByID(tx.node.id)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *transactionManager) Send(node Node, req *Request) *Response {
|
// Send sends a transaction and blocks until the response is available. It returns a response, or nil
|
||||||
|
// if the transaction timed out.
|
||||||
|
func (tm *transactionManager) Send(node Node, req Request) *Response {
|
||||||
return <-tm.SendAsync(context.Background(), node, req)
|
return <-tm.SendAsync(context.Background(), node, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue