diff --git a/.gitignore b/.gitignore index 25a6136..94e9a20 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ /vendor /blobs -/config.json +/config.json* /prism-bin diff --git a/cluster/cluster.go b/cluster/cluster.go index 6aa833f..8e78f61 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -4,6 +4,7 @@ import ( "io/ioutil" baselog "log" "sort" + "time" "github.com/lbryio/lbry.go/crypto" "github.com/lbryio/lbry.go/errors" @@ -15,12 +16,13 @@ import ( const ( // DefaultClusterPort is the default port used when starting up a Cluster. - DefaultClusterPort = 17946 + DefaultClusterPort = 17946 + MembershipChangeBufferWindow = 1 * time.Second ) // Cluster maintains cluster membership and notifies on certain events type Cluster struct { - OnHashRangeChange func(n, total int) + OnMembershipChange func(n, total int) name string port int @@ -91,6 +93,9 @@ func (c *Cluster) Shutdown() { } func (c *Cluster) listen() { + var timerCh <-chan time.Time + timer := time.NewTimer(0) + for { select { case <-c.stop.Ch(): @@ -104,11 +109,17 @@ func (c *Cluster) listen() { continue } - if c.OnHashRangeChange != nil { - alive := getAliveMembers(c.s.Members()) - c.OnHashRangeChange(getHashInterval(c.name, alive), len(alive)) + if timerCh == nil { + timer.Reset(MembershipChangeBufferWindow) + timerCh = timer.C } } + case <-timerCh: + if c.OnMembershipChange != nil { + alive := getAliveMembers(c.s.Members()) + c.OnMembershipChange(getHashInterval(c.name, alive), len(alive)) + } + timerCh = nil } } } diff --git a/cmd/dht.go b/cmd/dht.go index 4dd1259..62be601 100644 --- a/cmd/dht.go +++ b/cmd/dht.go @@ -31,7 +31,7 @@ func init() { func dhtCmd(cmd *cobra.Command, args []string) { if args[0] == "bootstrap" { - node := dht.NewBootstrapNode(bits.Rand(), 1*time.Millisecond, 1*time.Millisecond) + node := dht.NewBootstrapNode(bits.Rand(), 1*time.Millisecond, 1*time.Minute) listener, err := net.ListenPacket(dht.Network, "127.0.0.1:"+strconv.Itoa(dhtPort)) checkErr(err) diff --git a/cmd/start.go b/cmd/start.go index 14a3479..5a65730 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -49,6 +49,12 @@ func startCmd(cmd *cobra.Command, args []string) { s3 := store.NewS3BlobStore(globalConfig.AwsID, globalConfig.AwsSecret, globalConfig.BucketRegion, globalConfig.BucketName) comboStore := store.NewDBBackedS3Store(s3, db) + // TODO: args we need: + // clusterAddr - to connect to cluster (or start new cluster if empty) + // minNodes - minimum number of nodes before announcing starts. otherwise first node will try to announce all the blobs in the db + // or maybe we should do maxHashesPerNode? + // in either case, this should not kill the cluster, but should only limit announces (and notify when some hashes are being left unannounced) + //clusterAddr := "" //if len(args) > 0 { // clusterAddr = args[0] diff --git a/cmd/upload.go b/cmd/upload.go index fde76e6..49e82ba 100644 --- a/cmd/upload.go +++ b/cmd/upload.go @@ -184,9 +184,9 @@ func launchFileUploader(params *uploaderParams, blobStore *store.DBBackedS3Store } } else { log.Printf("worker %d: putting %s", worker, hash) - err := blobStore.Put(hash, blob) + err = blobStore.Put(hash, blob) if err != nil { - log.Error("Put Blob Error: ", err) + log.Error("put blob error: ", err) } select { case params.countChan <- blobInc: diff --git a/db/db.go b/db/db.go index 425c800..f3fc445 100644 --- a/db/db.go +++ b/db/db.go @@ -49,23 +49,23 @@ func (s *SQL) Connect(dsn string) error { } // AddBlob adds a blob's information to the database.
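(Note on the cluster/cluster.go hunk above: the rewritten listen loop debounces serf membership events. The first event of a burst arms a shared timer for MembershipChangeBufferWindow; further events inside the window are absorbed, and OnMembershipChange fires once when the timer expires. A minimal standalone sketch of the same pattern, with hypothetical debounce/onChange names that are not from this diff; the sketch also stops and drains the timer before its first Reset, which Go's time.Timer requires to avoid receiving a stale fire:)

```go
package main

import (
	"fmt"
	"time"
)

// debounce coalesces bursts of events into a single onChange call, the same
// pattern as Cluster.listen above. The first event of a burst arms the timer;
// events that arrive while it is armed are absorbed.
func debounce(events <-chan struct{}, window time.Duration, onChange func()) {
	timer := time.NewTimer(0)
	if !timer.Stop() {
		<-timer.C // drain the initial fire so the first Reset starts clean
	}
	var timerCh <-chan time.Time // nil: receiving from a nil channel blocks forever

	for {
		select {
		case _, ok := <-events:
			if !ok {
				return
			}
			if timerCh == nil { // first event of a burst: arm the timer
				timer.Reset(window)
				timerCh = timer.C
			}
		case <-timerCh:
			onChange() // one callback per burst
			timerCh = nil
		}
	}
}

func main() {
	events := make(chan struct{})
	go debounce(events, time.Second, func() { fmt.Println("membership changed") })
	for i := 0; i < 5; i++ {
		events <- struct{}{} // five rapid events...
	}
	time.Sleep(2 * time.Second) // ...produce a single "membership changed"
}
```

The nil-channel trick is what keeps the select simple: the timer case is effectively disabled until a burst arms it.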
-func (s *SQL) AddBlob(hash string, length int, stored bool) error { +func (s *SQL) AddBlob(hash string, length int, isStored bool) error { if s.conn == nil { return errors.Err("not connected") } return withTx(s.conn, func(tx *sql.Tx) error { - return addBlob(tx, hash, length, stored) + return addBlob(tx, hash, length, isStored) }) } -func addBlob(tx *sql.Tx, hash string, length int, stored bool) error { +func addBlob(tx *sql.Tx, hash string, length int, isStored bool) error { if length <= 0 { return errors.Err("length must be positive") } - query := "INSERT INTO blob_ (hash, stored, length) VALUES (?,?,?) ON DUPLICATE KEY UPDATE stored = (stored or VALUES(stored))" - args := []interface{}{hash, stored, length} + query := "INSERT INTO blob_ (hash, is_stored, length) VALUES (?,?,?) ON DUPLICATE KEY UPDATE is_stored = (is_stored or VALUES(is_stored))" + args := []interface{}{hash, isStored, length} logQuery(query, args...) @@ -88,7 +88,7 @@ func (s *SQL) HasBlob(hash string) (bool, error) { return false, errors.Err("not connected") } - query := "SELECT EXISTS(SELECT 1 FROM blob_ WHERE hash = ? AND stored = ?)" + query := "SELECT EXISTS(SELECT 1 FROM blob_ WHERE hash = ? AND is_stored = ?)" args := []interface{}{hash, true} logQuery(query, args...) @@ -120,7 +120,7 @@ func (s *SQL) HasBlobs(hashes []string) (map[string]bool, error) { log.Debugf("getting hashes[%d:%d] of %d", doneIndex, sliceEnd, len(hashes)) batch := hashes[doneIndex:sliceEnd] - query := "SELECT hash FROM blob_ WHERE stored = ? && hash IN (" + querytools.Qs(len(batch)) + ")" + query := "SELECT hash FROM blob_ WHERE is_stored = ? && hash IN (" + querytools.Qs(len(batch)) + ")" args := make([]interface{}, len(batch)+1) args[0] = true for i := range batch { @@ -219,6 +219,23 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob types.SdBlob) er }) } +// GetHashRange gets the smallest and biggest hashes in the db +func (s *SQL) GetHashRange() (string, string, error) { + var min string + var max string + + if s.conn == nil { + return "", "", errors.Err("not connected") + } + + query := "SELECT MIN(hash), MAX(hash) from blob_" + + logQuery(query) + + err := s.conn.QueryRow(query).Scan(&min, &max) + return min, max, err +} + // GetHashesInRange gets blobs with hashes in a given range, and sends the hashes into a channel func (s *SQL) GetHashesInRange(ctx context.Context, start, end bits.Bitmap) (ch chan bits.Bitmap, ech chan error) { ch = make(chan bits.Bitmap) @@ -320,7 +337,7 @@ func closeRows(rows *sql.Rows) { CREATE TABLE blob_ ( hash char(96) NOT NULL, - stored TINYINT(1) NOT NULL DEFAULT 0, + is_stored TINYINT(1) NOT NULL DEFAULT 0, length bigint(20) unsigned DEFAULT NULL, last_announced_at datetime DEFAULT NULL, PRIMARY KEY (hash), diff --git a/dht/bits/bitmap.go b/dht/bits/bitmap.go index 88f9f3f..4edffcc 100644 --- a/dht/bits/bitmap.go +++ b/dht/bits/bitmap.go @@ -23,14 +23,12 @@ const ( // package as a way to handle the unique identifiers of a DHT node. type Bitmap [NumBytes]byte -func (b Bitmap) String() string { +func (b Bitmap) RawString() string { return string(b[:]) } -func (b Bitmap) Big() *big.Int { - i := new(big.Int) - i.SetString(b.Hex(), 16) - return i +func (b Bitmap) String() string { + return b.Hex() } // BString returns the bitmap as a string of 0s and 1s @@ -61,6 +59,12 @@ func (b Bitmap) HexSimplified() string { return simple } +func (b Bitmap) Big() *big.Int { + i := new(big.Int) + i.SetString(b.Hex(), 16) + return i +} + // Equals returns T/F if every byte in the two bitmaps is equal.
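(Note on the dht/bits/bitmap.go hunk above: this swap is why so many call sites later in this diff move from String() to RawString(). String() now returns hex, so logging a Bitmap is human-readable, while the raw bytes that travel inside bencoded DHT messages must be requested explicitly. A toy 4-byte stand-in for the 48-byte type, for illustration only:)

```go
package main

import (
	"encoding/hex"
	"fmt"
)

// Toy 4-byte stand-in for the 48-byte bits.Bitmap.
type Bitmap [4]byte

// RawString is the wire form used inside bencoded DHT messages.
func (b Bitmap) RawString() string { return string(b[:]) }

// String is now hex, so fmt/log output stays readable.
func (b Bitmap) String() string { return hex.EncodeToString(b[:]) }

func main() {
	b := Bitmap{0xde, 0xad, 0xbe, 0xef}
	fmt.Println(b)                    // deadbeef
	fmt.Printf("%q\n", b.RawString()) // "\xde\xad\xbe\xef"
}
```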
func (b Bitmap) Equals(other Bitmap) bool { for k := range b { @@ -356,7 +360,7 @@ func FromBigP(b *big.Int) Bitmap { // MaxP returns a bitmap with all bits set to 1 func MaxP() Bitmap { - return FromHexP(strings.Repeat("1", NumBytes*2)) + return FromHexP(strings.Repeat("f", NumBytes*2)) } // Rand generates a cryptographically random bitmap within the confines of the parameters specified. diff --git a/dht/bits/range.go b/dht/bits/range.go new file mode 100644 index 0000000..349e148 --- /dev/null +++ b/dht/bits/range.go @@ -0,0 +1,63 @@ +package bits + +import ( + "math/big" + + "github.com/lbryio/errors.go" +) + +// Range has a start and end +type Range struct { + Start Bitmap + End Bitmap +} + +func MaxRange() Range { + return Range{ + Start: Bitmap{}, + End: MaxP(), + } +} + +// IntervalP divides the range into `num` intervals and returns the `n`th one +// intervals are approximately the same size, but may not be exact because of rounding issues +// the first interval always starts at the beginning of the range, and the last interval always ends at the end +func (r Range) IntervalP(n, num int) Range { + if num < 1 || n < 1 || n > num { + panic(errors.Err("invalid interval %d of %d", n, num)) + } + + start := r.intervalStart(n, num) + end := new(big.Int) + if n == num { + end = r.End.Big() + } else { + end = r.intervalStart(n+1, num) + end.Sub(end, big.NewInt(1)) + } + + return Range{FromBigP(start), FromBigP(end)} +} + +func (r Range) intervalStart(n, num int) *big.Int { + // formula: + // size = (end - start) / num + // rem = (end - start) % num + // intervalStart = rangeStart + (size * n-1) + ((rem * n-1) % num) + + size := new(big.Int) + rem := new(big.Int) + size.Sub(r.End.Big(), r.Start.Big()).DivMod(size, big.NewInt(int64(num)), rem) + + size.Mul(size, big.NewInt(int64(n-1))) + rem.Mul(rem, big.NewInt(int64(n-1))).Mod(rem, big.NewInt(int64(num))) + + start := r.Start.Big() + start.Add(start, size).Add(start, rem) + + return start +} + +func (r Range) IntervalSize() *big.Int { + return (&big.Int{}).Sub(r.End.Big(), r.Start.Big()) } diff --git a/dht/bits/range_test.go b/dht/bits/range_test.go new file mode 100644 index 0000000..79d31eb --- /dev/null +++ b/dht/bits/range_test.go @@ -0,0 +1,48 @@ +package bits + +import ( + "math/big" + "testing" +) + +func TestMaxRange(t *testing.T) { + start := FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") + end := FromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + r := MaxRange() + + if !r.Start.Equals(start) { + t.Error("max range does not start at the beginning") + } + if !r.End.Equals(end) { + t.Error("max range does not end at the end") + } +} + +func TestRange_IntervalP(t *testing.T) { + max := MaxRange() + + numIntervals := 97 + expectedAvg := (&big.Int{}).Div(max.IntervalSize(), big.NewInt(int64(numIntervals))) + maxDiff := big.NewInt(int64(numIntervals)) + + var lastEnd Bitmap + + for i := 1; i <= numIntervals; i++ { + ival := max.IntervalP(i, numIntervals) + if i == 1 && !ival.Start.Equals(max.Start) { + t.Error("first interval does not start at 0") + } + if i == numIntervals && !ival.End.Equals(max.End) { + t.Error("last interval does not end at max") + } + if i > 1 && !ival.Start.Equals(lastEnd.Add(FromShortHexP("1"))) { + t.Errorf("interval %d of %d: last end was %s, this start is %s", i, numIntervals, lastEnd.Hex(), ival.Start.Hex()) + } + + if ival.IntervalSize().Cmp((&big.Int{}).Add(expectedAvg, maxDiff)) > 0 ||
ival.IntervalSize().Cmp((&big.Int{}).Sub(expectedAvg, maxDiff)) < 0 { + t.Errorf("interval %d of %d: interval size is outside the normal range", i, numIntervals) + } + + lastEnd = ival.End + } +} diff --git a/dht/bootstrap.go b/dht/bootstrap.go index 833f203..0ca6864 100644 --- a/dht/bootstrap.go +++ b/dht/bootstrap.go @@ -21,9 +21,9 @@ type BootstrapNode struct { initialPingInterval time.Duration checkInterval time.Duration - nlock *sync.RWMutex - nodes []peer - nodeKeys map[bits.Bitmap]int + nlock *sync.RWMutex + nodes map[bits.Bitmap]*peer + nodeIDs []bits.Bitmap // necessary for efficient random ID selection } // NewBootstrapNode returns a BootstrapNode pointer. @@ -34,9 +34,9 @@ func NewBootstrapNode(id bits.Bitmap, initialPingInterval, rePingInterval time.D initialPingInterval: initialPingInterval, checkInterval: rePingInterval, - nlock: &sync.RWMutex{}, - nodes: make([]peer, 0), - nodeKeys: make(map[bits.Bitmap]int), + nlock: &sync.RWMutex{}, + nodes: make(map[bits.Bitmap]*peer), + nodeIDs: make([]bits.Bitmap, 0), } b.requestHandler = b.handleRequest @@ -78,15 +78,15 @@ func (b *BootstrapNode) upsert(c Contact) { b.nlock.Lock() defer b.nlock.Unlock() - if i, exists := b.nodeKeys[c.ID]; exists { - log.Debugf("[%s] bootstrap: touching contact %s", b.id.HexShort(), b.nodes[i].Contact.ID.HexShort()) - b.nodes[i].Touch() + if node, exists := b.nodes[c.ID]; exists { + log.Debugf("[%s] bootstrap: touching contact %s", b.id.HexShort(), node.Contact.ID.HexShort()) + node.Touch() return } log.Debugf("[%s] bootstrap: adding new contact %s", b.id.HexShort(), c.ID.HexShort()) - b.nodeKeys[c.ID] = len(b.nodes) - b.nodes = append(b.nodes, peer{c, time.Now(), 0}) + b.nodes[c.ID] = &peer{c, time.Now(), 0} + b.nodeIDs = append(b.nodeIDs, c.ID) } // remove removes the contact from the list @@ -94,14 +94,19 @@ func (b *BootstrapNode) remove(c Contact) { b.nlock.Lock() defer b.nlock.Unlock() - i, exists := b.nodeKeys[c.ID] + _, exists := b.nodes[c.ID] if !exists { return } log.Debugf("[%s] bootstrap: removing contact %s", b.id.HexShort(), c.ID.HexShort()) - b.nodes = append(b.nodes[:i], b.nodes[i+1:]...) - delete(b.nodeKeys, c.ID) + delete(b.nodes, c.ID) + for i := range b.nodeIDs { + if b.nodeIDs[i].Equals(c.ID) { + b.nodeIDs = append(b.nodeIDs[:i], b.nodeIDs[i+1:]...) + break + } + } } // get returns up to `limit` random contacts from the list @@ -114,8 +119,8 @@ func (b *BootstrapNode) get(limit int) []Contact { } ret := make([]Contact, limit) - for i, k := range randKeys(len(b.nodes))[:limit] { - ret[i] = b.nodes[k].Contact + for i, k := range randKeys(len(b.nodeIDs))[:limit] { + ret[i] = b.nodes[b.nodeIDs[k]].Contact } return ret @@ -123,6 +128,7 @@ func (b *BootstrapNode) get(limit int) []Contact { // ping pings a node. if the node responds, it is added to the list. 
otherwise, it is removed func (b *BootstrapNode) ping(c Contact) { + log.Debugf("[%s] bootstrap: pinging %s", b.id.HexShort(), c.ID.HexShort()) b.stop.Add(1) defer b.stop.Done() @@ -180,9 +186,19 @@ func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) { } go func() { - log.Debugf("[%s] bootstrap: queuing %s to ping", b.id.HexShort(), request.NodeID.HexShort()) - <-time.After(b.initialPingInterval) - b.ping(Contact{ID: request.NodeID, IP: addr.IP, Port: addr.Port}) + b.nlock.RLock() + _, exists := b.nodes[request.NodeID] + b.nlock.RUnlock() + if !exists { + log.Debugf("[%s] bootstrap: queuing %s to ping", b.id.HexShort(), request.NodeID.HexShort()) + <-time.After(b.initialPingInterval) + b.nlock.RLock() + _, exists = b.nodes[request.NodeID] + b.nlock.RUnlock() + if !exists { + b.ping(Contact{ID: request.NodeID, IP: addr.IP, Port: addr.Port}) + } + } }() } diff --git a/dht/dht.go b/dht/dht.go index aabac59..0df919d 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -200,7 +200,7 @@ func (dht *DHT) Ping(addr string) error { } tmpNode := Contact{ID: bits.Rand(), IP: raddr.IP, Port: raddr.Port} - res := dht.node.Send(tmpNode, Request{Method: pingMethod}) + res := dht.node.Send(tmpNode, Request{Method: pingMethod}, SendOptions{skipIDCheck: true}) if res == nil { return errors.Err("no response from node %s", addr) } @@ -222,15 +222,23 @@ func (dht *DHT) Get(hash bits.Bitmap) ([]Contact, error) { } // Add adds the hash to the list of hashes this node has -func (dht *DHT) Add(hash bits.Bitmap) error { +func (dht *DHT) Add(hash bits.Bitmap) { // TODO: calling Add several times quickly could cause it to be announced multiple times before dht.announced[hash] is set to true dht.lock.RLock() exists := dht.announced[hash] dht.lock.RUnlock() if exists { - return nil + return } - return dht.announce(hash) + + dht.stop.Add(1) + go func() { + defer dht.stop.Done() + err := dht.announce(hash) + if err != nil { + log.Error(errors.Prefix("error announcing bitmap", err)) + } + }() } // Announce announces to the DHT that this node has the blob for the given hash @@ -241,7 +249,10 @@ func (dht *DHT) announce(hash bits.Bitmap) error { } // if we found less than K contacts, or current node is closer than farthest contact - if len(contacts) < bucketSize || dht.node.id.Xor(hash).Less(contacts[bucketSize-1].ID.Xor(hash)) { + if len(contacts) < bucketSize { + // append self to contacts, and self-store + contacts = append(contacts, dht.contact) + } else if dht.node.id.Xor(hash).Less(contacts[bucketSize-1].ID.Xor(hash)) { // pop last contact, and self-store instead contacts[bucketSize-1] = dht.contact } @@ -289,7 +300,7 @@ func (dht *DHT) startReannouncer() { func (dht *DHT) storeOnNode(hash bits.Bitmap, c Contact) { // self-store - if dht.contact.Equals(c) { + if dht.contact.ID == c.ID { dht.node.Store(hash, c) return } diff --git a/dht/message_test.go b/dht/message_test.go index 71e99de..c42cdef 100644 --- a/dht/message_test.go +++ b/dht/message_test.go @@ -102,7 +102,7 @@ func TestBencodeFindValueResponse(t *testing.T) { res := Response{ ID: newMessageID(), NodeID: bits.Rand(), - FindValueKey: bits.Rand().String(), + FindValueKey: bits.Rand().RawString(), Token: "arst", Contacts: []Contact{ {ID: bits.Rand(), IP: net.IPv4(1, 2, 3, 4).To4(), Port: 5678}, diff --git a/dht/node.go b/dht/node.go index 3e1b838..3be10d5 100644 --- a/dht/node.go +++ b/dht/node.go @@ -276,7 +276,7 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) { } if contacts := n.store.Get(*request.Arg); len(contacts) > 0 { - 
res.FindValueKey = request.Arg.RawString() res.Contacts = contacts } else { res.Contacts = n.rt.GetClosest(*request.Arg, bucketSize) @@ -297,7 +297,7 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) { // handleResponse handles responses received from udp. func (n *Node) handleResponse(addr *net.UDPAddr, response Response) { - tx := n.txFind(response.ID, addr) + tx := n.txFind(response.ID, Contact{ID: response.NodeID, IP: addr.IP, Port: addr.Port}) if tx != nil { tx.res <- response } @@ -339,9 +339,10 @@ func (n *Node) sendMessage(addr *net.UDPAddr, data Message) error { // transaction represents a single query to the dht. it stores the queried contact, the request, and the response channel type transaction struct { - contact Contact - req Request - res chan Response + contact Contact + req Request + res chan Response + skipIDCheck bool } // insert adds a transaction to the manager. @@ -358,24 +359,27 @@ func (n *Node) txDelete(id messageID) { delete(n.transactions, id) } -// Find finds a transaction for the given id. it optionally ensures that addr matches contact from transaction -func (n *Node) txFind(id messageID, addr *net.UDPAddr) *transaction { +// txFind finds a transaction for the given id and contact +func (n *Node) txFind(id messageID, c Contact) *transaction { n.txLock.RLock() defer n.txLock.RUnlock() - // TODO: also check that the response's nodeid matches the id you thought you sent to? - t, ok := n.transactions[id] - if !ok || (addr != nil && t.contact.Addr().String() != addr.String()) { + if !ok || !t.contact.Equals(c, !t.skipIDCheck) { return nil } return t } +// SendOptions controls the behavior of send calls +type SendOptions struct { + skipIDCheck bool +} + // SendAsync sends a transaction and returns a channel that will eventually contain the transaction response // The response channel is closed when the transaction is completed or times out. -func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-chan *Response { if contact.ID.Equals(n.id) { log.Error("sending query to self") return nil @@ -394,6 +398,10 @@ func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-ch res: make(chan Response), } + if len(options) > 0 && options[0].skipIDCheck { + tx.skipIDCheck = true + } + n.txInsert(tx) defer n.txDelete(tx.req.ID) @@ -425,14 +433,14 @@ func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-ch // Send sends a transaction and blocks until the response is available. It returns a response, or nil // if the transaction timed out. -func (n *Node) Send(contact Contact, req Request) *Response { - return <-n.SendAsync(context.Background(), contact, req) +func (n *Node) Send(contact Contact, req Request, options ...SendOptions) *Response { + return <-n.SendAsync(context.Background(), contact, req, options...)
} // SendCancelable sends the transaction asynchronously and allows the transaction to be canceled -func (n *Node) SendCancelable(contact Contact, req Request) (<-chan *Response, context.CancelFunc) { +func (n *Node) SendCancelable(contact Contact, req Request, options ...SendOptions) (<-chan *Response, context.CancelFunc) { ctx, cancel := context.WithCancel(context.Background()) - return n.SendAsync(ctx, contact, req), cancel + return n.SendAsync(ctx, contact, req, options...), cancel } // CountActiveTransactions returns the number of transactions in the manager diff --git a/dht/node_finder.go b/dht/node_finder.go index 72c1b1a..3780281 100644 --- a/dht/node_finder.go +++ b/dht/node_finder.go @@ -48,9 +48,17 @@ func FindContacts(node *Node, target bits.Bitmap, findValue bool, upstreamStop s stop: stopOnce.New(), outstandingRequestsMutex: &sync.RWMutex{}, } + if upstreamStop != nil { - cf.stop.Link(upstreamStop) + go func() { + select { + case <-upstreamStop: + cf.Stop() + case <-cf.stop.Ch(): + } + }() } + return cf.Find() } diff --git a/dht/node_test.go b/dht/node_test.go index 009f754..26d0cc7 100644 --- a/dht/node_test.go +++ b/dht/node_test.go @@ -28,7 +28,7 @@ func TestPing(t *testing.T) { data, err := bencode.EncodeBytes(map[string]interface{}{ headerTypeField: requestType, headerMessageIDField: messageID, - headerNodeIDField: testNodeID.String(), + headerNodeIDField: testNodeID.RawString(), headerPayloadField: "ping", headerArgsField: []string{}, }) @@ -84,7 +84,7 @@ func TestPing(t *testing.T) { rNodeID, ok := response[headerNodeIDField].(string) if !ok { t.Error("node ID is not a string") - } else if rNodeID != dhtNodeID.String() { + } else if rNodeID != dhtNodeID.RawString() { t.Error("unexpected node ID") } } @@ -171,7 +171,7 @@ func TestStore(t *testing.T) { } } - verifyResponse(t, response, messageID, dhtNodeID.String()) + verifyResponse(t, response, messageID, dhtNodeID.RawString()) _, ok := response[headerPayloadField] if !ok { @@ -249,7 +249,7 @@ func TestFindNode(t *testing.T) { } } - verifyResponse(t, response, messageID, dhtNodeID.String()) + verifyResponse(t, response, messageID, dhtNodeID.RawString()) _, ok := response[headerPayloadField] if !ok { @@ -320,7 +320,7 @@ func TestFindValueExisting(t *testing.T) { } } - verifyResponse(t, response, messageID, dhtNodeID.String()) + verifyResponse(t, response, messageID, dhtNodeID.RawString()) _, ok := response[headerPayloadField] if !ok { @@ -332,7 +332,7 @@ func TestFindValueExisting(t *testing.T) { t.Fatal("payload is not a dictionary") } - compactContacts, ok := payload[valueToFind.String()] + compactContacts, ok := payload[valueToFind.RawString()] if !ok { t.Fatal("payload is missing key for search value") } @@ -396,7 +396,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) { } } - verifyResponse(t, response, messageID, dhtNodeID.String()) + verifyResponse(t, response, messageID, dhtNodeID.RawString()) _, ok := response[headerPayloadField] if !ok { diff --git a/dht/routing_table.go b/dht/routing_table.go index 7ac80d8..410420f 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -32,8 +32,8 @@ type Contact struct { } // Equals returns T/F if two contacts are the same. -func (c Contact) Equals(other Contact) bool { - return c.ID == other.ID +func (c Contact) Equals(other Contact, checkID bool) bool { + return c.IP.Equal(other.IP) && c.Port == other.Port && (!checkID || c.ID == other.ID) } // Addr returns the UDP address of the contact.
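(Note on the dht/routing_table.go hunk above: the two-argument Equals is what makes SendOptions{skipIDCheck: true} in node.go work end to end. Ping addresses a peer by IP:port with a random placeholder ID, so handleResponse must be able to match the pending transaction by address alone. A toy model of the match, using string IDs in place of bits.Bitmap:)

```go
package main

import (
	"fmt"
	"net"
)

// Toy model of the new Contact.Equals and transaction matching. A ping is
// sent to an address with a made-up ID, so the response's node ID will not
// match what we stored; skipIDCheck relaxes the match to IP and port only.
type Contact struct {
	ID   string
	IP   net.IP
	Port int
}

// Equals mirrors the two-argument form introduced in routing_table.go.
func (c Contact) Equals(other Contact, checkID bool) bool {
	return c.IP.Equal(other.IP) && c.Port == other.Port && (!checkID || c.ID == other.ID)
}

func main() {
	queried := Contact{ID: "random-placeholder", IP: net.IPv4(1, 2, 3, 4), Port: 4444}
	responder := Contact{ID: "actual-node-id", IP: net.IPv4(1, 2, 3, 4), Port: 4444}

	fmt.Println(queried.Equals(responder, true))  // false: ID mismatch, response dropped
	fmt.Println(queried.Equals(responder, false)) // true: skipIDCheck lets the ping response match
}
```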
@@ -150,7 +150,7 @@ func (p *peer) Touch() { // ActiveInLast returns whether a peer has responded in the last `d` duration // this is used to check if the peer is "good", meaning that we believe the peer will respond to our requests func (p *peer) ActiveInLast(d time.Duration) bool { - return time.Since(p.LastActivity) > d + return time.Since(p.LastActivity) < d } // IsBad returns whether a peer is "bad", meaning that it has failed to respond to multiple pings in a row @@ -352,20 +352,14 @@ func (rt *routingTable) Count() int { return count } -// Range is a structure that holds a min and max bitmaps. The range is used in bucket sizing. -type Range struct { - start bits.Bitmap - end bits.Bitmap -} - // BucketRanges returns a slice of ranges, where the `start` of each range is the smallest id that can // go in that bucket, and the `end` is the largest id -func (rt *routingTable) BucketRanges() []Range { - ranges := make([]Range, len(rt.buckets)) +func (rt *routingTable) BucketRanges() []bits.Range { + ranges := make([]bits.Range, len(rt.buckets)) for i := range rt.buckets { - ranges[i] = Range{ - rt.id.Suffix(i, false).Set(nodeIDBits-1-i, !rt.id.Get(nodeIDBits-1-i)), - rt.id.Suffix(i, true).Set(nodeIDBits-1-i, !rt.id.Get(nodeIDBits-1-i)), + ranges[i] = bits.Range{ + Start: rt.id.Suffix(i, false).Set(nodeIDBits-1-i, !rt.id.Get(nodeIDBits-1-i)), + End: rt.id.Suffix(i, true).Set(nodeIDBits-1-i, !rt.id.Get(nodeIDBits-1-i)), } } return ranges } diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go index a889a5a..7a48467 100644 --- a/dht/routing_table_test.go +++ b/dht/routing_table_test.go @@ -147,14 +147,14 @@ func TestRoutingTable_MoveToBack(t *testing.T) { func TestRoutingTable_BucketRanges(t *testing.T) { id := bits.FromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41") ranges := newRoutingTable(id).BucketRanges() - if !ranges[0].start.Equals(ranges[0].end) { + if !ranges[0].Start.Equals(ranges[0].End) { t.Error("first bucket should only fit exactly one id") } for i := 0; i < 1000; i++ { randID := bits.Rand() found := -1 for i, r := range ranges { - if r.start.LessOrEqual(randID) && r.end.GreaterOrEqual(randID) { + if r.Start.LessOrEqual(randID) && r.End.GreaterOrEqual(randID) { if found >= 0 { t.Errorf("%s appears in buckets %d and %d", randID.Hex(), found, i) } else { @@ -176,10 +176,10 @@ func TestRoutingTable_Save(t *testing.T) { for i, r := range ranges { for j := 0; j < bucketSize; j++ { - toAdd := r.start.Add(bits.FromShortHexP(strconv.Itoa(j))) - if toAdd.LessOrEqual(r.end) { + toAdd := r.Start.Add(bits.FromShortHexP(strconv.Itoa(j))) + if toAdd.LessOrEqual(r.End) { rt.Update(Contact{ - ID: r.start.Add(bits.FromShortHexP(strconv.Itoa(j))), + ID: r.Start.Add(bits.FromShortHexP(strconv.Itoa(j))), IP: net.ParseIP("1.2.3."
+ strconv.Itoa(j)), Port: 1 + i*bucketSize + j, }) diff --git a/dht/testing.go b/dht/testing.go index 7775b8f..cda1071 100644 --- a/dht/testing.go +++ b/dht/testing.go @@ -226,7 +226,7 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) { continue } for _, n := range nodes { - if n.ID.String() == id { + if n.ID.RawString() == id { currNode = n currNodeFound = true foundNodes[id] = true diff --git a/prism/prism.go b/prism/prism.go index 976b63d..e62b916 100644 --- a/prism/prism.go +++ b/prism/prism.go @@ -2,7 +2,6 @@ package prism import ( "context" - "math/big" "strconv" "sync" @@ -79,7 +78,7 @@ func New(conf *Config) *Prism { stop: stopOnce.New(), } - c.OnHashRangeChange = func(n, total int) { + c.OnMembershipChange = func(n, total int) { p.stop.Add(1) go func() { p.AnnounceRange(n, total) @@ -144,22 +143,20 @@ func (p *Prism) AnnounceRange(n, total int) { return } - max := bits.MaxP().Big() - interval := new(big.Int).Div(max, big.NewInt(int64(total))) + //r := bits.MaxRange().IntervalP(n, total) - start := new(big.Int).Mul(interval, big.NewInt(int64(n-1))) - end := new(big.Int).Add(start, interval) - if n == total { - end = end.Add(end, big.NewInt(10000)) // there are rounding issues sometimes, so lets make sure we get the full range - } - if end.Cmp(max) > 0 { - end.Set(max) + // TODO: this is temporary. it lets me test with a small number of hashes. use the full range in production + min, max, err := p.db.GetHashRange() + if err != nil { + log.Errorf("%s: error getting hash range: %s", p.dht.ID().HexShort(), err.Error()) + return } + r := (bits.Range{Start: bits.FromHexP(min), End: bits.FromHexP(max)}).IntervalP(n, total) - log.Debugf("%s: hash range is now %s to %s\n", p.dht.ID().HexShort(), bits.FromBigP(start).Hex(), bits.FromBigP(end).Hex()) + log.Infof("%s: hash range is now %s to %s", p.dht.ID().HexShort(), r.Start, r.End) ctx, cancel := context.WithCancel(context.Background()) - hashCh, errCh := p.db.GetHashesInRange(ctx, bits.FromBigP(start), bits.FromBigP(end)) + hashCh, errCh := p.db.GetHashesInRange(ctx, r.Start, r.End) var wg sync.WaitGroup @@ -188,6 +185,7 @@ func (p *Prism) AnnounceRange(n, total int) { if !more { return } + //log.Infof("%s: announcing %s", p.dht.ID().HexShort(), hash.Hex()) p.dht.Add(hash) } }
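(Note on the prism/prism.go hunk above: AnnounceRange now delegates the hash-space math to bits.Range.IntervalP from dht/bits/range.go earlier in this diff. A worked int64 miniature of the intervalStart formula, with small illustrative numbers standing in for the 48-byte bitmaps:)

```go
package main

import "fmt"

// int64 miniature of Range.intervalStart from dht/bits/range.go:
//   size     = (end - start) / num
//   rem      = (end - start) % num
//   start(n) = rangeStart + size*(n-1) + (rem*(n-1))%num
// The real code does the same arithmetic with big.Int over 48-byte hashes.
func intervalStart(rangeStart, rangeEnd, n, num int64) int64 {
	size := (rangeEnd - rangeStart) / num
	rem := (rangeEnd - rangeStart) % num
	return rangeStart + size*(n-1) + (rem*(n-1))%num
}

func main() {
	var start, end, num int64 = 0, 10, 3
	for n := int64(1); n <= num; n++ {
		s := intervalStart(start, end, n, num)
		e := end
		if n < num {
			e = intervalStart(start, end, n+1, num) - 1 // each end is the next start minus one
		}
		fmt.Printf("interval %d of %d: [%d, %d]\n", n, num, s, e)
	}
	// Output:
	// interval 1 of 3: [0, 3]
	// interval 2 of 3: [4, 7]
	// interval 3 of 3: [8, 10]
	// contiguous and non-overlapping, tiling the range exactly
}
```

The intervals always tile the range with no gaps or overlaps, but the modulo redistribution means their sizes are only approximately equal, which is why TestRange_IntervalP above tolerates a per-interval deviation of up to numIntervals.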