diff --git a/dht/bits/bitmap.go b/dht/bits/bitmap.go index 164df90..88f9f3f 100644 --- a/dht/bits/bitmap.go +++ b/dht/bits/bitmap.go @@ -3,6 +3,7 @@ package bits import ( "crypto/rand" "encoding/hex" + "math/big" "strconv" "strings" @@ -26,6 +27,12 @@ func (b Bitmap) String() string { return string(b[:]) } +func (b Bitmap) Big() *big.Int { + i := new(big.Int) + i.SetString(b.Hex(), 16) + return i +} + // BString returns the bitmap as a string of 0s and 1s func (b Bitmap) BString() string { var s string @@ -343,6 +350,15 @@ func FromShortHexP(hexStr string) Bitmap { return bmp } +func FromBigP(b *big.Int) Bitmap { + return FromShortHexP(b.Text(16)) +} + +// Max returns a bitmap with all bits set to 1 +func MaxP() Bitmap { + return FromHexP(strings.Repeat("1", NumBytes*2)) +} + // Rand generates a cryptographically random bitmap with the confines of the parameters specified. func Rand() Bitmap { var id Bitmap diff --git a/dht/bootstrap.go b/dht/bootstrap.go index 797ac22..833f203 100644 --- a/dht/bootstrap.go +++ b/dht/bootstrap.go @@ -159,7 +159,8 @@ func (b *BootstrapNode) check() { func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) { switch request.Method { case pingMethod: - if err := b.sendMessage(addr, Response{ID: request.ID, NodeID: b.id, Data: pingSuccessResponse}); err != nil { + err := b.sendMessage(addr, Response{ID: request.ID, NodeID: b.id, Data: pingSuccessResponse}) + if err != nil { log.Error("error sending response message - ", err) } case findNodeMethod: @@ -167,11 +168,13 @@ func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) { log.Errorln("request is missing arg") return } - if err := b.sendMessage(addr, Response{ + + err := b.sendMessage(addr, Response{ ID: request.ID, NodeID: b.id, Contacts: b.get(bucketSize), - }); err != nil { + }) + if err != nil { log.Error("error sending 'findnodemethod' response message - ", err) } } diff --git a/dht/bootstrap_test.go b/dht/bootstrap_test.go index 1b89466..4997467 100644 --- a/dht/bootstrap_test.go +++ b/dht/bootstrap_test.go @@ -10,15 +10,15 @@ import ( func TestBootstrapPing(t *testing.T) { b := NewBootstrapNode(bits.Rand(), 10, bootstrapDefaultRefreshDuration) - listener, err := net.ListenPacket(network, "127.0.0.1:54320") + listener, err := net.ListenPacket(Network, "127.0.0.1:54320") if err != nil { panic(err) } - if err := b.Connect(listener.(*net.UDPConn)); err != nil { + err = b.Connect(listener.(*net.UDPConn)) + if err != nil { t.Error(err) } - defer b.Shutdown() b.Shutdown() } diff --git a/dht/dht.go b/dht/dht.go index 2cbaae9..aabac59 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -21,7 +21,7 @@ func init() { } const ( - network = "udp4" + Network = "udp4" // TODO: all these constants should be defaults, and should be used to set values in the standard Config. then the code should use values in the config // TODO: alternatively, have a global Config for constants. at least that way tests can modify the values @@ -90,26 +90,57 @@ type DHT struct { } // New returns a DHT pointer. If config is nil, then config will be set to the default config. -func New(config *Config) (*DHT, error) { +func New(config *Config) *DHT { if config == nil { config = NewStandardConfig() } - contact, err := getContact(config.NodeID, config.Address) - if err != nil { - return nil, err - } - d := &DHT{ conf: config, - contact: contact, - node: NewNode(contact.ID), stop: stopOnce.New(), joined: make(chan struct{}), lock: &sync.RWMutex{}, announced: make(map[bits.Bitmap]bool), } - return d, nil + return d +} + +func (dht *DHT) connect(conn UDPConn) error { + contact, err := getContact(dht.conf.NodeID, dht.conf.Address) + if err != nil { + return err + } + + dht.contact = contact + dht.node = NewNode(contact.ID) + + err = dht.node.Connect(conn) + if err != nil { + return err + } + return nil +} + +// Start starts the dht +func (dht *DHT) Start() error { + listener, err := net.ListenPacket(Network, dht.conf.Address) + if err != nil { + return errors.Err(err) + } + conn := listener.(*net.UDPConn) + + err = dht.connect(conn) + if err != nil { + return err + } + + dht.join() + log.Debugf("[%s] DHT ready on %s (%d nodes found during join)", + dht.node.id.HexShort(), dht.contact.Addr().String(), dht.node.rt.Count()) + + go dht.startReannouncer() + + return nil } // join makes current node join the dht network. @@ -144,27 +175,6 @@ func (dht *DHT) join() { // http://xlattice.sourceforge.net/components/protocol/kademlia/specs.html#join } -// Start starts the dht -func (dht *DHT) Start() error { - listener, err := net.ListenPacket(network, dht.conf.Address) - if err != nil { - return errors.Err(err) - } - conn := listener.(*net.UDPConn) - err = dht.node.Connect(conn) - if err != nil { - return err - } - - dht.join() - log.Debugf("[%s] DHT ready on %s (%d nodes found during join)", - dht.node.id.HexShort(), dht.contact.Addr().String(), dht.node.rt.Count()) - - go dht.startReannouncer() - - return nil -} - // WaitUntilJoined blocks until the node joins the network. func (dht *DHT) WaitUntilJoined() { if dht.joined == nil { @@ -184,7 +194,7 @@ func (dht *DHT) Shutdown() { // Ping pings a given address, creates a temporary contact for sending a message, and returns an error if communication // fails. func (dht *DHT) Ping(addr string) error { - raddr, err := net.ResolveUDPAddr(network, addr) + raddr, err := net.ResolveUDPAddr(Network, addr) if err != nil { return err } @@ -211,8 +221,20 @@ func (dht *DHT) Get(hash bits.Bitmap) ([]Contact, error) { return nil, nil } +// Add adds the hash to the list of hashes this node has +func (dht *DHT) Add(hash bits.Bitmap) error { + // TODO: calling Add several times quickly could cause it to be announced multiple times before dht.announced[hash] is set to true + dht.lock.RLock() + exists := dht.announced[hash] + dht.lock.RUnlock() + if exists { + return nil + } + return dht.announce(hash) +} + // Announce announces to the DHT that this node has the blob for the given hash -func (dht *DHT) Announce(hash bits.Bitmap) error { +func (dht *DHT) announce(hash bits.Bitmap) error { contacts, _, err := FindContacts(dht.node, hash, false, dht.stop.Ch()) if err != nil { return err @@ -254,7 +276,7 @@ func (dht *DHT) startReannouncer() { dht.stop.Add(1) go func(bm bits.Bitmap) { defer dht.stop.Done() - err := dht.Announce(bm) + err := dht.announce(bm) if err != nil { log.Error("error re-announcing bitmap - ", err) } diff --git a/dht/dht_test.go b/dht/dht_test.go index 1de7e31..7ee2ecf 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -121,7 +121,8 @@ func TestDHT_LargeDHT(t *testing.T) { wg.Add(1) go func(index int) { defer wg.Done() - if err := dhts[index].Announce(ids[index]); err != nil { + err := dhts[index].announce(ids[index]) + if err != nil { t.Error("error announcing random bitmap - ", err) } }(i) diff --git a/dht/node.go b/dht/node.go index d35de51..3e1b838 100644 --- a/dht/node.go +++ b/dht/node.go @@ -229,7 +229,8 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) { log.Errorln("invalid request method") return case pingMethod: - if err := n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: pingSuccessResponse}); err != nil { + err := n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: pingSuccessResponse}) + if err != nil { log.Error("error sending 'pingmethod' response message - ", err) } case storeMethod: @@ -237,11 +238,14 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) { // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? if n.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) { n.Store(request.StoreArgs.BlobHash, Contact{ID: request.StoreArgs.NodeID, IP: addr.IP, Port: request.StoreArgs.Value.Port}) - if err := n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse}); err != nil { + + err := n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse}) + if err != nil { log.Error("error sending 'storemethod' response message - ", err) } } else { - if err := n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"}); err != nil { + err := n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"}) + if err != nil { log.Error("error sending 'storemethod'response message for invalid-token - ", err) } } @@ -250,11 +254,12 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) { log.Errorln("request is missing arg") return } - if err := n.sendMessage(addr, Response{ + err := n.sendMessage(addr, Response{ ID: request.ID, NodeID: n.id, Contacts: n.rt.GetClosest(*request.Arg, bucketSize), - }); err != nil { + }) + if err != nil { log.Error("error sending 'findnodemethod' response message - ", err) } @@ -277,7 +282,8 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) { res.Contacts = n.rt.GetClosest(*request.Arg, bucketSize) } - if err := n.sendMessage(addr, res); err != nil { + err := n.sendMessage(addr, res) + if err != nil { log.Error("error sending 'findvaluemethod' response message - ", err) } } @@ -322,7 +328,8 @@ func (n *Node) sendMessage(addr *net.UDPAddr, data Message) error { log.Debugf("[%s] (%d bytes) %s", n.id.HexShort(), len(encoded), spew.Sdump(data)) } - if err := n.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { + err = n.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err != nil { log.Error("error setting write deadline - ", err) } @@ -391,7 +398,8 @@ func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-ch defer n.txDelete(tx.req.ID) for i := 0; i < udpRetry; i++ { - if err := n.sendMessage(contact.Addr(), tx.req); err != nil { + err := n.sendMessage(contact.Addr(), tx.req) + if err != nil { if !strings.Contains(err.Error(), "use of closed network connection") { // this only happens on localhost. real UDP has no connections log.Error("send error: ", err) } diff --git a/dht/node_test.go b/dht/node_test.go index b9537f7..009f754 100644 --- a/dht/node_test.go +++ b/dht/node_test.go @@ -15,12 +15,9 @@ func TestPing(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) - if err != nil { - t.Fatal(err) - } + dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) - err = dht.node.Connect(conn) + err := dht.connect(conn) if err != nil { t.Fatal(err) } @@ -112,12 +109,9 @@ func TestStore(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) - if err != nil { - t.Fatal(err) - } + dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) - err = dht.node.Connect(conn) + err := dht.connect(conn) if err != nil { t.Fatal(err) } @@ -210,12 +204,9 @@ func TestFindNode(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) - if err != nil { - t.Fatal(err) - } + dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) - err = dht.node.Connect(conn) + err := dht.connect(conn) if err != nil { t.Fatal(err) } @@ -279,12 +270,9 @@ func TestFindValueExisting(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) - if err != nil { - t.Fatal(err) - } + dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) - err = dht.node.Connect(conn) + err := dht.connect(conn) if err != nil { t.Fatal(err) } @@ -363,12 +351,9 @@ func TestFindValueFallbackToFindNode(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) - if err != nil { - t.Fatal(err) - } + dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) - err = dht.node.Connect(conn) + err := dht.connect(conn) if err != nil { t.Fatal(err) } diff --git a/dht/testing.go b/dht/testing.go index 95efea7..7775b8f 100644 --- a/dht/testing.go +++ b/dht/testing.go @@ -23,11 +23,13 @@ func TestingCreateDHT(t *testing.T, numNodes int, bootstrap, concurrent bool) (* bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort) seeds = []string{bootstrapAddress} bootstrapNode = NewBootstrapNode(bits.Rand(), 0, bootstrapDefaultRefreshDuration) - listener, err := net.ListenPacket(network, bootstrapAddress) + listener, err := net.ListenPacket(Network, bootstrapAddress) if err != nil { panic(err) } - if err := bootstrapNode.Connect(listener.(*net.UDPConn)); err != nil { + + err = bootstrapNode.Connect(listener.(*net.UDPConn)) + if err != nil { t.Error("error connecting bootstrap node - ", err) } } @@ -40,13 +42,11 @@ func TestingCreateDHT(t *testing.T, numNodes int, bootstrap, concurrent bool) (* dhts := make([]*DHT, numNodes) for i := 0; i < numNodes; i++ { - dht, err := New(&Config{Address: testingDHTIP + ":" + strconv.Itoa(firstPort+i), NodeID: bits.Rand().Hex(), SeedNodes: seeds}) - if err != nil { - panic(err) - } + dht := New(&Config{Address: testingDHTIP + ":" + strconv.Itoa(firstPort+i), NodeID: bits.Rand().Hex(), SeedNodes: seeds}) go func() { - if err := dht.Start(); err != nil { + err := dht.Start() + if err != nil { t.Error("error starting dht - ", err) } }()