diff --git a/dht/bitmap.go b/dht/bitmap.go index 54dc267..9966443 100644 --- a/dht/bitmap.go +++ b/dht/bitmap.go @@ -61,6 +61,11 @@ func (b bitmap) PrefixLen() int { return numBuckets } +func (b bitmap) MarshalBencode() ([]byte, error) { + str := string(b[:]) + return bencode.EncodeBytes(str) +} + func (b *bitmap) UnmarshalBencode(encoded []byte) error { var str string err := bencode.DecodeBytes(encoded, &str) @@ -68,17 +73,12 @@ func (b *bitmap) UnmarshalBencode(encoded []byte) error { return err } if len(str) != nodeIDLength { - return errors.Err("invalid node ID length") + return errors.Err("invalid bitmap length") } copy(b[:], str) return nil } -func (b bitmap) MarshalBencode() ([]byte, error) { - str := string(b[:]) - return bencode.EncodeBytes(str) -} - func newBitmapFromBytes(data []byte) bitmap { if len(data) != nodeIDLength { panic("invalid bitmap of length " + strconv.Itoa(len(data))) diff --git a/dht/dht.go b/dht/dht.go index 3d13722..795970f 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -193,7 +193,7 @@ func (dht *DHT) join() { } tmpNode := Node{id: newRandomBitmap(), ip: raddr.IP, port: raddr.Port} - res := dht.tm.Send(tmpNode, &Request{Method: pingMethod}) + res := dht.tm.Send(tmpNode, Request{Method: pingMethod}) if res == nil { log.Errorf("[%s] join: no response from seed node %s", dht.node.id.HexShort(), addr) } @@ -260,8 +260,8 @@ func (dht *DHT) Get(hash bitmap) ([]Node, error) { return nil, nil } -// Put announces to the DHT that this node has the blob for the given hash -func (dht *DHT) Put(hash bitmap) error { +// Announce announces to the DHT that this node has the blob for the given hash +func (dht *DHT) Announce(hash bitmap) error { nf := newNodeFinder(dht, hash, false) res, err := nf.Find() if err != nil { @@ -269,10 +269,10 @@ func (dht *DHT) Put(hash bitmap) error { } for _, node := range res.Nodes { - send(dht, node.Addr(), &Request{ + send(dht, node.Addr(), Request{ Method: storeMethod, StoreArgs: &storeArgs{ - BlobHash: hash.RawString(), + BlobHash: hash, Value: storeArgsValue{ Token: "", LbryID: dht.node.id, diff --git a/dht/message.go b/dht/message.go index 77ef27a..fbd123f 100644 --- a/dht/message.go +++ b/dht/message.go @@ -126,6 +126,17 @@ func (r *Request) UnmarshalBencode(b []byte) error { return nil } +func (r Request) ArgsDebug() string { + argsCopy := make([]string, len(r.Args)) + copy(argsCopy, r.Args) + for k, v := range argsCopy { + if len(v) == nodeIDLength { + argsCopy[k] = hex.EncodeToString([]byte(v))[:8] + } + } + return strings.Join(argsCopy, ", ") +} + type storeArgsValue struct { Token string `bencode:"token"` LbryID bitmap `bencode:"lbryid"` @@ -133,7 +144,7 @@ type storeArgsValue struct { } type storeArgs struct { - BlobHash string + BlobHash bitmap Value storeArgsValue NodeID bitmap SelfStore bool // this is an int on the wire diff --git a/dht/message_test.go b/dht/message_test.go index 55ebf2a..c256dab 100644 --- a/dht/message_test.go +++ b/dht/message_test.go @@ -45,7 +45,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) { t.Error(err) } - if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) { + if storeArgs.BlobHash.Hex() != strings.ToLower(blobHash) { t.Error("blob hash mismatch") } if storeArgs.Value.LbryID.Hex() != strings.ToLower(lbryID) { diff --git a/dht/node_finder.go b/dht/node_finder.go index 14cfbc4..73ff523 100644 --- a/dht/node_finder.go +++ b/dht/node_finder.go @@ -99,7 +99,7 @@ func (nf *nodeFinder) iterationWorker(num int) { continue // cannot contact self } - req := &Request{Args: []string{nf.target.RawString()}} + req := Request{Args: []string{nf.target.RawString()}} if nf.findValue { req.Method = findValueMethod } else { diff --git a/dht/node_finder_test.go b/dht/node_finder_test.go index f72fc99..d98af14 100644 --- a/dht/node_finder_test.go +++ b/dht/node_finder_test.go @@ -109,10 +109,11 @@ func TestNodeFinder_FindValue(t *testing.T) { time.Sleep(1 * time.Second) // give dhts a chance to connect + blobHashToFind := newRandomBitmap() nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678} - dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind) + dht1.store.Upsert(blobHashToFind, nodeToFind) - nf := newNodeFinder(dht3, nodeToFind.id, true) + nf := newNodeFinder(dht3, blobHashToFind, true) res, err := nf.Find() if err != nil { t.Fatal(err) diff --git a/dht/rpc.go b/dht/rpc.go index ae8af09..60ccec9 100644 --- a/dht/rpc.go +++ b/dht/rpc.go @@ -3,7 +3,6 @@ package dht import ( "encoding/hex" "net" - "strings" "time" "github.com/lbryio/lbry.go/util" @@ -35,7 +34,7 @@ func handlePacket(dht *DHT, pkt packet) { log.Errorf("[%s] error decoding request: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) return } - log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, argsToString(request.Args)) + log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, request.ArgsDebug()) handleRequest(dht, pkt.raddr, request) case '0' + responseType: @@ -75,17 +74,13 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { case pingMethod: send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse}) case storeMethod: - if request.StoreArgs.BlobHash == "" { - log.Errorln("blobhash is empty") - return // nothing to store - } // TODO: we should be sending the IP in the request, not just using the sender's IP // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse}) case findNodeMethod: - if len(request.Args) < 1 { - log.Errorln("nothing to find") + if len(request.Args) != 1 { + log.Errorln("invalid number of args") return } if len(request.Args[0]) != nodeIDLength { @@ -94,20 +89,22 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { } doFindNodes(dht, addr, request) case findValueMethod: - if len(request.Args) < 1 { - log.Errorln("nothing to find") + if len(request.Args) != 1 { + log.Errorln("invalid number of args") return } if len(request.Args[0]) != nodeIDLength { - log.Errorln("invalid node id") + log.Errorln("invalid blob hash") return } - if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 { - response := Response{ID: request.ID, NodeID: dht.node.id} - response.FindValueKey = request.Args[0] - response.FindNodeData = nodes - send(dht, addr, response) + if nodes := dht.store.Get(newBitmapFromString(request.Args[0])); len(nodes) > 0 { + send(dht, addr, Response{ + ID: request.ID, + NodeID: dht.node.id, + FindValueKey: request.Args[0], + FindNodeData: nodes, + }) } else { doFindNodes(dht, addr, request) } @@ -140,7 +137,7 @@ func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) { func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) { tx := dht.tm.Find(response.ID, addr) if tx != nil { - tx.res <- &response + tx.res <- response } node := Node{id: response.NodeID, ip: addr.IP, port: addr.Port} @@ -163,7 +160,7 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error { if req, ok := data.(Request); ok { log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)", - dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, argsToString(req.Args)) + dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, req.ArgsDebug()) } else if res, ok := data.(Response); ok { log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s", dht.node.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug()) @@ -176,14 +173,3 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error { _, err = dht.conn.WriteToUDP(encoded, addr) return err } - -func argsToString(args []string) string { - argsCopy := make([]string, len(args)) - copy(argsCopy, args) - for k, v := range argsCopy { - if len(v) == nodeIDLength { - argsCopy[k] = hex.EncodeToString([]byte(v))[:8] - } - } - return strings.Join(argsCopy, ", ") -} diff --git a/dht/rpc_test.go b/dht/rpc_test.go index 54414c3..8129941 100644 --- a/dht/rpc_test.go +++ b/dht/rpc_test.go @@ -199,7 +199,7 @@ func TestStore(t *testing.T) { defer dht.Shutdown() messageID := newMessageID() - blobHashToStore := newRandomBitmap().RawString() + blobHashToStore := newRandomBitmap() storeRequest := Request{ ID: messageID, @@ -383,7 +383,7 @@ func TestFindValueExisting(t *testing.T) { //data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565") messageID := newMessageID() - valueToFind := newRandomBitmap().RawString() + valueToFind := newRandomBitmap() nodeToFind := Node{id: newRandomBitmap(), ip: net.ParseIP("1.2.3.4"), port: 1286} dht.store.Upsert(valueToFind, nodeToFind) @@ -394,7 +394,7 @@ func TestFindValueExisting(t *testing.T) { ID: messageID, NodeID: testNodeID, Method: findValueMethod, - Args: []string{valueToFind}, + Args: []string{valueToFind.RawString()}, } data, err := bencode.EncodeBytes(request) @@ -428,7 +428,7 @@ func TestFindValueExisting(t *testing.T) { t.Fatal("payload is not a dictionary") } - compactContacts, ok := payload[valueToFind] + compactContacts, ok := payload[valueToFind.RawString()] if !ok { t.Fatal("payload is missing key for search value") } diff --git a/dht/store.go b/dht/store.go index 8a34bab..0654598 100644 --- a/dht/store.go +++ b/dht/store.go @@ -10,33 +10,35 @@ type peer struct { } type peerStore struct { - nodeIDs map[string]map[bitmap]bool + // map of blob hashes to (map of node IDs to bools) + nodeIDs map[bitmap]map[bitmap]bool + // map of node IDs to peers nodeInfo map[bitmap]peer lock sync.RWMutex } func newPeerStore() *peerStore { return &peerStore{ - nodeIDs: make(map[string]map[bitmap]bool), + nodeIDs: make(map[bitmap]map[bitmap]bool), nodeInfo: make(map[bitmap]peer), } } -func (s *peerStore) Upsert(key string, node Node) { +func (s *peerStore) Upsert(blobHash bitmap, node Node) { s.lock.Lock() defer s.lock.Unlock() - if _, ok := s.nodeIDs[key]; !ok { - s.nodeIDs[key] = make(map[bitmap]bool) + if _, ok := s.nodeIDs[blobHash]; !ok { + s.nodeIDs[blobHash] = make(map[bitmap]bool) } - s.nodeIDs[key][node.id] = true + s.nodeIDs[blobHash][node.id] = true s.nodeInfo[node.id] = peer{node: node} } -func (s *peerStore) Get(key string) []Node { +func (s *peerStore) Get(blobHash bitmap) []Node { s.lock.RLock() defer s.lock.RUnlock() var nodes []Node - if ids, ok := s.nodeIDs[key]; ok { + if ids, ok := s.nodeIDs[blobHash]; ok { for id := range ids { peer, ok := s.nodeInfo[id] if !ok { diff --git a/dht/transaction_manager.go b/dht/transaction_manager.go index e8799aa..8a9c806 100644 --- a/dht/transaction_manager.go +++ b/dht/transaction_manager.go @@ -9,21 +9,21 @@ import ( log "github.com/sirupsen/logrus" ) -// query represents the query data included queried node and query-formed data. +// transaction represents a single query to the dht. it stores the queried node, the request, and the response channel type transaction struct { node Node - req *Request - res chan *Response + req Request + res chan Response } -// transactionManager represents the manager of transactions. +// transactionManager keeps track of the outstanding transactions type transactionManager struct { + dht *DHT lock *sync.RWMutex transactions map[messageID]*transaction - dht *DHT } -// newTransactionManager returns new transactionManager pointer. +// newTransactionManager returns a new transactionManager func newTransactionManager(dht *DHT) *transactionManager { return &transactionManager{ lock: &sync.RWMutex{}, @@ -32,36 +32,36 @@ func newTransactionManager(dht *DHT) *transactionManager { } } -// insert adds a transaction to transactionManager. -func (tm *transactionManager) insert(trans *transaction) { +// insert adds a transaction to the manager. +func (tm *transactionManager) insert(tx *transaction) { tm.lock.Lock() defer tm.lock.Unlock() - tm.transactions[trans.req.ID] = trans + tm.transactions[tx.req.ID] = tx } -// delete removes a transaction from transactionManager. +// delete removes a transaction from the manager. func (tm *transactionManager) delete(id messageID) { tm.lock.Lock() defer tm.lock.Unlock() delete(tm.transactions, id) } -// find transaction for id. optionally ensure that addr matches node from transaction +// Find finds a transaction for the given id. it optionally ensures that addr matches node from transaction func (tm *transactionManager) Find(id messageID, addr *net.UDPAddr) *transaction { tm.lock.RLock() defer tm.lock.RUnlock() t, ok := tm.transactions[id] - if !ok { - return nil - } else if addr != nil && t.node.Addr().String() != addr.String() { + if !ok || (addr != nil && t.node.Addr().String() != addr.String()) { return nil } return t } -func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Request) <-chan *Response { +// SendAsync sends a transaction and returns a channel that will eventually contain the transaction response +// The response channel is closed when the transaction is completed or times out. +func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req Request) <-chan *Response { if node.id.Equals(tm.dht.node.id) { log.Error("sending query to self") return nil @@ -74,24 +74,24 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req req.ID = newMessageID() req.NodeID = tm.dht.node.id - trans := &transaction{ + tx := &transaction{ node: node, req: req, - res: make(chan *Response), + res: make(chan Response), } - tm.insert(trans) - defer tm.delete(trans.req.ID) + tm.insert(tx) + defer tm.delete(tx.req.ID) for i := 0; i < udpRetry; i++ { - if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil { + if err := send(tm.dht, node.Addr(), tx.req); err != nil { log.Errorf("send error: ", err.Error()) continue // try again? return? } select { - case res := <-trans.res: - ch <- res + case res := <-tx.res: + ch <- &res return case <-ctx.Done(): return @@ -100,13 +100,15 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req } // if request timed out each time - tm.dht.rt.RemoveByID(trans.node.id) + tm.dht.rt.RemoveByID(tx.node.id) }() return ch } -func (tm *transactionManager) Send(node Node, req *Request) *Response { +// Send sends a transaction and blocks until the response is available. It returns a response, or nil +// if the transaction timed out. +func (tm *transactionManager) Send(node Node, req Request) *Response { return <-tm.SendAsync(context.Background(), node, req) }