starting to put together the pieces

- prism start command
- more configs for prism when assembling the pieces
- cluster notifies on membership change, determines hash range, announces hashes
This commit is contained in:
Alex Grintsvayg 2018-06-14 22:30:37 -04:00
parent 3e7f7583d6
commit 4535122a06
26 changed files with 565 additions and 256 deletions

2
.gitignore vendored
View file

@ -1,4 +1,4 @@
/vendor /vendor
/blobs /blobs
/config.json /config.json
/prism /prism-bin

View file

@ -1,4 +1,4 @@
BINARY=prism BINARY=prism-bin
DIR = $(shell cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd) DIR = $(shell cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd)
VENDOR_DIR = vendor VENDOR_DIR = vendor

View file

@ -18,8 +18,10 @@ const (
DefaultClusterPort = 17946 DefaultClusterPort = 17946
) )
// Cluster is a management type for Serf which is used to maintain cluster membership of lbry nodes. // Cluster maintains cluster membership and notifies on certain events
type Cluster struct { type Cluster struct {
OnHashRangeChange func(n, total int)
name string name string
port int port int
seedAddr string seedAddr string
@ -66,20 +68,26 @@ func (c *Cluster) Connect() error {
return err return err
} }
} }
c.stop.Add(1) c.stop.Add(1)
go func() { go func() {
defer c.stop.Done()
c.listen() c.listen()
c.stop.Done()
}() }()
log.Debugf("cluster started")
return nil return nil
} }
// Shutdown safely shuts down the cluster. // Shutdown safely shuts down the cluster.
func (c *Cluster) Shutdown() { func (c *Cluster) Shutdown() {
log.Debug("shutting down cluster...")
c.stop.StopAndWait() c.stop.StopAndWait()
if err := c.s.Leave(); err != nil { err := c.s.Leave()
log.Error("error shutting down cluster - ", err) if err != nil {
log.Error(errors.Prefix("shutting down cluster", err))
} }
log.Debugf("cluster stopped")
} }
func (c *Cluster) listen() { func (c *Cluster) listen() {
@ -96,19 +104,16 @@ func (c *Cluster) listen() {
continue continue
} }
//spew.Dump(c.Members()) if c.OnHashRangeChange != nil {
alive := getAliveMembers(c.s.Members()) alive := getAliveMembers(c.s.Members())
log.Printf("%s: my hash range is now %d of %d\n", c.name, getHashRangeStart(c.name, alive), len(alive)) c.OnHashRangeChange(getHashInterval(c.name, alive), len(alive))
// figure out my new hash range based on the start and the number of alive members }
// get hashes in that range that need announcing
// announce them
// if more than one node is announcing each hash, figure out how to deal with last_announced_at so both nodes dont announce the same thing at the same time
} }
} }
} }
} }
func getHashRangeStart(myName string, members []serf.Member) int { func getHashInterval(myName string, members []serf.Member) int {
var names []string var names []string
for _, m := range members { for _, m := range members {
names = append(names, m.Name) names = append(names, m.Name)

View file

@ -1,41 +1,51 @@
package cmd package cmd
import ( import (
"net"
"os"
"os/signal"
"strconv"
"syscall"
"time"
"github.com/lbryio/reflector.go/dht" "github.com/lbryio/reflector.go/dht"
"github.com/lbryio/reflector.go/dht/bits" "github.com/lbryio/reflector.go/dht/bits"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var dhtPort int
func init() { func init() {
var cmd = &cobra.Command{ var cmd = &cobra.Command{
Use: "dht", Use: "dht [start|bootstrap]",
Short: "Run interactive dht node", Short: "Run dht node",
Run: dhtCmd, ValidArgs: []string{"start", "bootstrap"},
Args: argFuncs(cobra.ExactArgs(1), cobra.OnlyValidArgs),
Run: dhtCmd,
} }
cmd.PersistentFlags().IntVar(&dhtPort, "port", 4567, "Port to start DHT on")
rootCmd.AddCommand(cmd) rootCmd.AddCommand(cmd)
} }
func dhtCmd(cmd *cobra.Command, args []string) { func dhtCmd(cmd *cobra.Command, args []string) {
conf := &dht.Config{ if args[0] == "bootstrap" {
Address: "0.0.0.0:4460", node := dht.NewBootstrapNode(bits.Rand(), 1*time.Millisecond, 1*time.Millisecond)
SeedNodes: []string{
"34.231.152.182:4460", listener, err := net.ListenPacket(dht.Network, "127.0.0.1:"+strconv.Itoa(dhtPort))
}, checkErr(err)
conn := listener.(*net.UDPConn)
err = node.Connect(conn)
checkErr(err)
interruptChan := make(chan os.Signal, 1)
signal.Notify(interruptChan, os.Interrupt, syscall.SIGTERM)
<-interruptChan
log.Printf("shutting down bootstrap node")
node.Shutdown()
} else {
log.Fatal("not implemented")
} }
d, err := dht.New(conf)
checkErr(err)
err = d.Start()
checkErr(err)
defer d.Shutdown()
err = d.Ping("34.231.152.182:4470")
checkErr(err)
err = d.Announce(bits.Rand())
checkErr(err)
d.PrintState()
} }

View file

@ -31,7 +31,9 @@ func peerCmd(cmd *cobra.Command, args []string) {
s3 := store.NewS3BlobStore(globalConfig.AwsID, globalConfig.AwsSecret, globalConfig.BucketRegion, globalConfig.BucketName) s3 := store.NewS3BlobStore(globalConfig.AwsID, globalConfig.AwsSecret, globalConfig.BucketRegion, globalConfig.BucketName)
combo := store.NewDBBackedS3Store(s3, db) combo := store.NewDBBackedS3Store(s3, db)
peerServer := peer.NewServer(combo) peerServer := peer.NewServer(combo)
if err := peerServer.Start("localhost:" + strconv.Itoa(peer.DefaultPort)); err != nil {
err = peerServer.Start("localhost:" + strconv.Itoa(peer.DefaultPort))
if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View file

@ -31,7 +31,8 @@ func reflectorCmd(cmd *cobra.Command, args []string) {
s3 := store.NewS3BlobStore(globalConfig.AwsID, globalConfig.AwsSecret, globalConfig.BucketRegion, globalConfig.BucketName) s3 := store.NewS3BlobStore(globalConfig.AwsID, globalConfig.AwsSecret, globalConfig.BucketRegion, globalConfig.BucketName)
combo := store.NewDBBackedS3Store(s3, db) combo := store.NewDBBackedS3Store(s3, db)
reflectorServer := reflector.NewServer(combo) reflectorServer := reflector.NewServer(combo)
if err := reflectorServer.Start("localhost:" + strconv.Itoa(reflector.DefaultPort)); err != nil { err = reflectorServer.Start("localhost:" + strconv.Itoa(reflector.DefaultPort))
if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View file

@ -58,7 +58,8 @@ func init() {
// Execute adds all child commands to the root command and sets flags appropriately. // Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd. // This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() { func Execute() {
if err := rootCmd.Execute(); err != nil { err := rootCmd.Execute()
if err != nil {
log.Errorln(err) log.Errorln(err)
os.Exit(1) os.Exit(1)
} }

View file

@ -3,9 +3,12 @@ package cmd
import ( import (
"os" "os"
"os/signal" "os/signal"
"strconv"
"syscall" "syscall"
"github.com/lbryio/reflector.go/db" "github.com/lbryio/reflector.go/db"
"github.com/lbryio/reflector.go/peer"
"github.com/lbryio/reflector.go/prism"
"github.com/lbryio/reflector.go/reflector" "github.com/lbryio/reflector.go/reflector"
"github.com/lbryio/reflector.go/store" "github.com/lbryio/reflector.go/store"
@ -13,6 +16,15 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var (
startPeerPort int
startReflectorPort int
startDhtPort int
startDhtSeedPort int
startClusterPort int
startClusterSeedPort int
)
func init() { func init() {
var cmd = &cobra.Command{ var cmd = &cobra.Command{
Use: "start [cluster-address]", Use: "start [cluster-address]",
@ -20,6 +32,13 @@ func init() {
Run: startCmd, Run: startCmd,
Args: cobra.RangeArgs(0, 1), Args: cobra.RangeArgs(0, 1),
} }
cmd.PersistentFlags().IntVar(&startDhtPort, "dht-port", 4570, "Port to start DHT on")
cmd.PersistentFlags().IntVar(&startDhtSeedPort, "dht-seed-port", 4567, "Port to connect to DHT bootstrap node on")
cmd.PersistentFlags().IntVar(&startClusterPort, "cluster-port", 5678, "Port to start DHT on")
cmd.PersistentFlags().IntVar(&startClusterSeedPort, "cluster-seed-port", 0, "Port to start DHT on")
cmd.PersistentFlags().IntVar(&startPeerPort, "peer-port", peer.DefaultPort, "Port to start peer protocol on")
cmd.PersistentFlags().IntVar(&startReflectorPort, "reflector-port", reflector.DefaultPort, "Port to start reflector protocol on")
rootCmd.AddCommand(cmd) rootCmd.AddCommand(cmd)
} }
@ -30,15 +49,27 @@ func startCmd(cmd *cobra.Command, args []string) {
s3 := store.NewS3BlobStore(globalConfig.AwsID, globalConfig.AwsSecret, globalConfig.BucketRegion, globalConfig.BucketName) s3 := store.NewS3BlobStore(globalConfig.AwsID, globalConfig.AwsSecret, globalConfig.BucketRegion, globalConfig.BucketName)
comboStore := store.NewDBBackedS3Store(s3, db) comboStore := store.NewDBBackedS3Store(s3, db)
clusterAddr := "" //clusterAddr := ""
if len(args) > 0 { //if len(args) > 0 {
clusterAddr = args[0] // clusterAddr = args[0]
//}
conf := prism.DefaultConf()
conf.DB = db
conf.Blobs = comboStore
conf.DhtAddress = "127.0.0.1:" + strconv.Itoa(startDhtPort)
conf.DhtSeedNodes = []string{"127.0.0.1:" + strconv.Itoa(startDhtSeedPort)}
conf.ClusterPort = startClusterPort
if startClusterSeedPort > 0 {
conf.ClusterSeedAddr = "127.0.0.1:" + strconv.Itoa(startClusterSeedPort)
} }
p := reflector.NewPrism(comboStore, clusterAddr) p := prism.New(conf)
if err = p.Start(); err != nil { err = p.Start()
if err != nil {
log.Fatal(err) log.Fatal(err)
} }
interruptChan := make(chan os.Signal, 1) interruptChan := make(chan os.Signal, 1)
signal.Notify(interruptChan, os.Interrupt, syscall.SIGTERM) signal.Notify(interruptChan, os.Interrupt, syscall.SIGTERM)
<-interruptChan <-interruptChan

View file

@ -174,7 +174,8 @@ func launchFileUploader(params *uploaderParams, blobStore *store.DBBackedS3Store
if isJSON(blob) { if isJSON(blob) {
log.Printf("worker %d: PUTTING SD BLOB %s", worker, hash) log.Printf("worker %d: PUTTING SD BLOB %s", worker, hash)
if err := blobStore.PutSD(hash, blob); err != nil { err := blobStore.PutSD(hash, blob)
if err != nil {
log.Error("PutSD Error: ", err) log.Error("PutSD Error: ", err)
} }
select { select {
@ -183,7 +184,8 @@ func launchFileUploader(params *uploaderParams, blobStore *store.DBBackedS3Store
} }
} else { } else {
log.Printf("worker %d: putting %s", worker, hash) log.Printf("worker %d: putting %s", worker, hash)
if err := blobStore.Put(hash, blob); err != nil { err := blobStore.Put(hash, blob)
if err != nil {
log.Error("Put Blob Error: ", err) log.Error("Put Blob Error: ", err)
} }
select { select {

103
db/db.go
View file

@ -1,10 +1,12 @@
package db package db
import ( import (
"context"
"database/sql" "database/sql"
"github.com/lbryio/lbry.go/errors" "github.com/lbryio/lbry.go/errors"
"github.com/lbryio/lbry.go/querytools" "github.com/lbryio/lbry.go/querytools"
"github.com/lbryio/reflector.go/dht/bits"
"github.com/lbryio/reflector.go/types" "github.com/lbryio/reflector.go/types"
// blank import for db driver // blank import for db driver
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
@ -217,6 +219,57 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob types.SdBlob) er
}) })
} }
// GetHashesInRange gets blobs with hashes in a given range, and sends the hashes into a channel
func (s *SQL) GetHashesInRange(ctx context.Context, start, end bits.Bitmap) (ch chan bits.Bitmap, ech chan error) {
ch = make(chan bits.Bitmap)
ech = make(chan error)
// TODO: needs waitgroup?
go func() {
defer close(ch)
defer close(ech)
if s.conn == nil {
ech <- errors.Err("not connected")
return
}
query := "SELECT hash FROM blob_ WHERE hash >= ? AND hash <= ?"
args := []interface{}{start.Hex(), end.Hex()}
logQuery(query, args...)
rows, err := s.conn.Query(query, args...)
defer closeRows(rows)
if err != nil {
ech <- err
return
}
var hash string
for rows.Next() {
err := rows.Scan(&hash)
if err != nil {
ech <- err
return
}
select {
case <-ctx.Done():
break
case ch <- bits.FromHexP(hash):
}
}
err = rows.Err()
if err != nil {
ech <- err
return
}
}()
return
}
// txFunc is a function that can be wrapped in a transaction // txFunc is a function that can be wrapped in a transaction
type txFunc func(tx *sql.Tx) error type txFunc func(tx *sql.Tx) error
@ -255,14 +308,16 @@ func withTx(dbOrTx interface{}, f txFunc) (err error) {
} }
func closeRows(rows *sql.Rows) { func closeRows(rows *sql.Rows) {
if err := rows.Close(); err != nil { if rows != nil {
log.Error("error closing rows: ", err) err := rows.Close()
if err != nil {
log.Error("error closing rows: ", err)
}
} }
} }
/*// func to generate schema. SQL below that. /* SQL schema
func schema() {
_ = `
CREATE TABLE blob_ ( CREATE TABLE blob_ (
hash char(96) NOT NULL, hash char(96) NOT NULL,
stored TINYINT(1) NOT NULL DEFAULT 0, stored TINYINT(1) NOT NULL DEFAULT 0,
@ -270,7 +325,7 @@ CREATE TABLE blob_ (
last_announced_at datetime DEFAULT NULL, last_announced_at datetime DEFAULT NULL,
PRIMARY KEY (hash), PRIMARY KEY (hash),
KEY last_announced_at_idx (last_announced_at) KEY last_announced_at_idx (last_announced_at)
) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; );
CREATE TABLE stream ( CREATE TABLE stream (
hash char(96) NOT NULL, hash char(96) NOT NULL,
@ -278,7 +333,7 @@ CREATE TABLE stream (
PRIMARY KEY (hash), PRIMARY KEY (hash),
KEY sd_hash_idx (sd_hash), KEY sd_hash_idx (sd_hash),
FOREIGN KEY (sd_hash) REFERENCES blob_ (hash) ON DELETE RESTRICT ON UPDATE CASCADE FOREIGN KEY (sd_hash) REFERENCES blob_ (hash) ON DELETE RESTRICT ON UPDATE CASCADE
) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; );
CREATE TABLE stream_blob ( CREATE TABLE stream_blob (
stream_hash char(96) NOT NULL, stream_hash char(96) NOT NULL,
@ -287,38 +342,8 @@ CREATE TABLE stream_blob (
PRIMARY KEY (stream_hash, blob_hash), PRIMARY KEY (stream_hash, blob_hash),
FOREIGN KEY (stream_hash) REFERENCES stream (hash) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (stream_hash) REFERENCES stream (hash) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (blob_hash) REFERENCES blob_ (hash) ON DELETE CASCADE ON UPDATE CASCADE FOREIGN KEY (blob_hash) REFERENCES blob_ (hash) ON DELETE CASCADE ON UPDATE CASCADE
) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; );
` could add UNIQUE KEY (stream_hash, num) to stream_blob ...
}*/
/* SQL script to create schema
CREATE TABLE `reflector`.`blob_`
(
`hash` char(96) NOT NULL,
`stored` TINYINT(1) NOT NULL DEFAULT 0,
`length` bigint(20) unsigned DEFAULT NULL,
`last_announced_at` datetime DEFAULT NULL,
PRIMARY KEY (`hash`),
KEY `last_announced_at_idx` (`last_announced_at`)
) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
CREATE TABLE `reflector`.`stream`
(
`hash` char(96) NOT NULL,
`sd_hash` char(96) NOT NULL,
PRIMARY KEY (hash),
KEY `sd_hash_idx` (`sd_hash`),
FOREIGN KEY (`sd_hash`) REFERENCES `blob_` (`hash`) ON DELETE RESTRICT ON UPDATE CASCADE
) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
CREATE TABLE `reflector`.`stream_blob`
(
`stream_hash` char(96) NOT NULL,
`blob_hash` char(96) NOT NULL,
`num` int NOT NULL,
PRIMARY KEY (`stream_hash`, `blob_hash`),
FOREIGN KEY (`stream_hash`) REFERENCES `stream` (`hash`) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (`blob_hash`) REFERENCES `blob_` (`hash`) ON DELETE CASCADE ON UPDATE CASCADE
) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
*/ */

View file

@ -3,6 +3,7 @@ package bits
import ( import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"math/big"
"strconv" "strconv"
"strings" "strings"
@ -26,6 +27,12 @@ func (b Bitmap) String() string {
return string(b[:]) return string(b[:])
} }
func (b Bitmap) Big() *big.Int {
i := new(big.Int)
i.SetString(b.Hex(), 16)
return i
}
// BString returns the bitmap as a string of 0s and 1s // BString returns the bitmap as a string of 0s and 1s
func (b Bitmap) BString() string { func (b Bitmap) BString() string {
var s string var s string
@ -343,6 +350,15 @@ func FromShortHexP(hexStr string) Bitmap {
return bmp return bmp
} }
func FromBigP(b *big.Int) Bitmap {
return FromShortHexP(b.Text(16))
}
// Max returns a bitmap with all bits set to 1
func MaxP() Bitmap {
return FromHexP(strings.Repeat("1", NumBytes*2))
}
// Rand generates a cryptographically random bitmap with the confines of the parameters specified. // Rand generates a cryptographically random bitmap with the confines of the parameters specified.
func Rand() Bitmap { func Rand() Bitmap {
var id Bitmap var id Bitmap

View file

@ -159,7 +159,8 @@ func (b *BootstrapNode) check() {
func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) { func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) {
switch request.Method { switch request.Method {
case pingMethod: case pingMethod:
if err := b.sendMessage(addr, Response{ID: request.ID, NodeID: b.id, Data: pingSuccessResponse}); err != nil { err := b.sendMessage(addr, Response{ID: request.ID, NodeID: b.id, Data: pingSuccessResponse})
if err != nil {
log.Error("error sending response message - ", err) log.Error("error sending response message - ", err)
} }
case findNodeMethod: case findNodeMethod:
@ -167,11 +168,13 @@ func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) {
log.Errorln("request is missing arg") log.Errorln("request is missing arg")
return return
} }
if err := b.sendMessage(addr, Response{
err := b.sendMessage(addr, Response{
ID: request.ID, ID: request.ID,
NodeID: b.id, NodeID: b.id,
Contacts: b.get(bucketSize), Contacts: b.get(bucketSize),
}); err != nil { })
if err != nil {
log.Error("error sending 'findnodemethod' response message - ", err) log.Error("error sending 'findnodemethod' response message - ", err)
} }
} }

View file

@ -10,15 +10,15 @@ import (
func TestBootstrapPing(t *testing.T) { func TestBootstrapPing(t *testing.T) {
b := NewBootstrapNode(bits.Rand(), 10, bootstrapDefaultRefreshDuration) b := NewBootstrapNode(bits.Rand(), 10, bootstrapDefaultRefreshDuration)
listener, err := net.ListenPacket(network, "127.0.0.1:54320") listener, err := net.ListenPacket(Network, "127.0.0.1:54320")
if err != nil { if err != nil {
panic(err) panic(err)
} }
if err := b.Connect(listener.(*net.UDPConn)); err != nil { err = b.Connect(listener.(*net.UDPConn))
if err != nil {
t.Error(err) t.Error(err)
} }
defer b.Shutdown()
b.Shutdown() b.Shutdown()
} }

View file

@ -21,7 +21,7 @@ func init() {
} }
const ( const (
network = "udp4" Network = "udp4"
// TODO: all these constants should be defaults, and should be used to set values in the standard Config. then the code should use values in the config // TODO: all these constants should be defaults, and should be used to set values in the standard Config. then the code should use values in the config
// TODO: alternatively, have a global Config for constants. at least that way tests can modify the values // TODO: alternatively, have a global Config for constants. at least that way tests can modify the values
@ -90,26 +90,57 @@ type DHT struct {
} }
// New returns a DHT pointer. If config is nil, then config will be set to the default config. // New returns a DHT pointer. If config is nil, then config will be set to the default config.
func New(config *Config) (*DHT, error) { func New(config *Config) *DHT {
if config == nil { if config == nil {
config = NewStandardConfig() config = NewStandardConfig()
} }
contact, err := getContact(config.NodeID, config.Address)
if err != nil {
return nil, err
}
d := &DHT{ d := &DHT{
conf: config, conf: config,
contact: contact,
node: NewNode(contact.ID),
stop: stopOnce.New(), stop: stopOnce.New(),
joined: make(chan struct{}), joined: make(chan struct{}),
lock: &sync.RWMutex{}, lock: &sync.RWMutex{},
announced: make(map[bits.Bitmap]bool), announced: make(map[bits.Bitmap]bool),
} }
return d, nil return d
}
func (dht *DHT) connect(conn UDPConn) error {
contact, err := getContact(dht.conf.NodeID, dht.conf.Address)
if err != nil {
return err
}
dht.contact = contact
dht.node = NewNode(contact.ID)
err = dht.node.Connect(conn)
if err != nil {
return err
}
return nil
}
// Start starts the dht
func (dht *DHT) Start() error {
listener, err := net.ListenPacket(Network, dht.conf.Address)
if err != nil {
return errors.Err(err)
}
conn := listener.(*net.UDPConn)
err = dht.connect(conn)
if err != nil {
return err
}
dht.join()
log.Debugf("[%s] DHT ready on %s (%d nodes found during join)",
dht.node.id.HexShort(), dht.contact.Addr().String(), dht.node.rt.Count())
go dht.startReannouncer()
return nil
} }
// join makes current node join the dht network. // join makes current node join the dht network.
@ -144,27 +175,6 @@ func (dht *DHT) join() {
// http://xlattice.sourceforge.net/components/protocol/kademlia/specs.html#join // http://xlattice.sourceforge.net/components/protocol/kademlia/specs.html#join
} }
// Start starts the dht
func (dht *DHT) Start() error {
listener, err := net.ListenPacket(network, dht.conf.Address)
if err != nil {
return errors.Err(err)
}
conn := listener.(*net.UDPConn)
err = dht.node.Connect(conn)
if err != nil {
return err
}
dht.join()
log.Debugf("[%s] DHT ready on %s (%d nodes found during join)",
dht.node.id.HexShort(), dht.contact.Addr().String(), dht.node.rt.Count())
go dht.startReannouncer()
return nil
}
// WaitUntilJoined blocks until the node joins the network. // WaitUntilJoined blocks until the node joins the network.
func (dht *DHT) WaitUntilJoined() { func (dht *DHT) WaitUntilJoined() {
if dht.joined == nil { if dht.joined == nil {
@ -184,7 +194,7 @@ func (dht *DHT) Shutdown() {
// Ping pings a given address, creates a temporary contact for sending a message, and returns an error if communication // Ping pings a given address, creates a temporary contact for sending a message, and returns an error if communication
// fails. // fails.
func (dht *DHT) Ping(addr string) error { func (dht *DHT) Ping(addr string) error {
raddr, err := net.ResolveUDPAddr(network, addr) raddr, err := net.ResolveUDPAddr(Network, addr)
if err != nil { if err != nil {
return err return err
} }
@ -211,8 +221,20 @@ func (dht *DHT) Get(hash bits.Bitmap) ([]Contact, error) {
return nil, nil return nil, nil
} }
// Add adds the hash to the list of hashes this node has
func (dht *DHT) Add(hash bits.Bitmap) error {
// TODO: calling Add several times quickly could cause it to be announced multiple times before dht.announced[hash] is set to true
dht.lock.RLock()
exists := dht.announced[hash]
dht.lock.RUnlock()
if exists {
return nil
}
return dht.announce(hash)
}
// Announce announces to the DHT that this node has the blob for the given hash // Announce announces to the DHT that this node has the blob for the given hash
func (dht *DHT) Announce(hash bits.Bitmap) error { func (dht *DHT) announce(hash bits.Bitmap) error {
contacts, _, err := FindContacts(dht.node, hash, false, dht.stop.Ch()) contacts, _, err := FindContacts(dht.node, hash, false, dht.stop.Ch())
if err != nil { if err != nil {
return err return err
@ -254,7 +276,7 @@ func (dht *DHT) startReannouncer() {
dht.stop.Add(1) dht.stop.Add(1)
go func(bm bits.Bitmap) { go func(bm bits.Bitmap) {
defer dht.stop.Done() defer dht.stop.Done()
err := dht.Announce(bm) err := dht.announce(bm)
if err != nil { if err != nil {
log.Error("error re-announcing bitmap - ", err) log.Error("error re-announcing bitmap - ", err)
} }

View file

@ -121,7 +121,8 @@ func TestDHT_LargeDHT(t *testing.T) {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int) {
defer wg.Done() defer wg.Done()
if err := dhts[index].Announce(ids[index]); err != nil { err := dhts[index].announce(ids[index])
if err != nil {
t.Error("error announcing random bitmap - ", err) t.Error("error announcing random bitmap - ", err)
} }
}(i) }(i)

View file

@ -229,7 +229,8 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
log.Errorln("invalid request method") log.Errorln("invalid request method")
return return
case pingMethod: case pingMethod:
if err := n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: pingSuccessResponse}); err != nil { err := n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: pingSuccessResponse})
if err != nil {
log.Error("error sending 'pingmethod' response message - ", err) log.Error("error sending 'pingmethod' response message - ", err)
} }
case storeMethod: case storeMethod:
@ -237,11 +238,14 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
if n.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) { if n.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) {
n.Store(request.StoreArgs.BlobHash, Contact{ID: request.StoreArgs.NodeID, IP: addr.IP, Port: request.StoreArgs.Value.Port}) n.Store(request.StoreArgs.BlobHash, Contact{ID: request.StoreArgs.NodeID, IP: addr.IP, Port: request.StoreArgs.Value.Port})
if err := n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse}); err != nil {
err := n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse})
if err != nil {
log.Error("error sending 'storemethod' response message - ", err) log.Error("error sending 'storemethod' response message - ", err)
} }
} else { } else {
if err := n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"}); err != nil { err := n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"})
if err != nil {
log.Error("error sending 'storemethod'response message for invalid-token - ", err) log.Error("error sending 'storemethod'response message for invalid-token - ", err)
} }
} }
@ -250,11 +254,12 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
log.Errorln("request is missing arg") log.Errorln("request is missing arg")
return return
} }
if err := n.sendMessage(addr, Response{ err := n.sendMessage(addr, Response{
ID: request.ID, ID: request.ID,
NodeID: n.id, NodeID: n.id,
Contacts: n.rt.GetClosest(*request.Arg, bucketSize), Contacts: n.rt.GetClosest(*request.Arg, bucketSize),
}); err != nil { })
if err != nil {
log.Error("error sending 'findnodemethod' response message - ", err) log.Error("error sending 'findnodemethod' response message - ", err)
} }
@ -277,7 +282,8 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
res.Contacts = n.rt.GetClosest(*request.Arg, bucketSize) res.Contacts = n.rt.GetClosest(*request.Arg, bucketSize)
} }
if err := n.sendMessage(addr, res); err != nil { err := n.sendMessage(addr, res)
if err != nil {
log.Error("error sending 'findvaluemethod' response message - ", err) log.Error("error sending 'findvaluemethod' response message - ", err)
} }
} }
@ -322,7 +328,8 @@ func (n *Node) sendMessage(addr *net.UDPAddr, data Message) error {
log.Debugf("[%s] (%d bytes) %s", n.id.HexShort(), len(encoded), spew.Sdump(data)) log.Debugf("[%s] (%d bytes) %s", n.id.HexShort(), len(encoded), spew.Sdump(data))
} }
if err := n.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { err = n.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
if err != nil {
log.Error("error setting write deadline - ", err) log.Error("error setting write deadline - ", err)
} }
@ -391,7 +398,8 @@ func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-ch
defer n.txDelete(tx.req.ID) defer n.txDelete(tx.req.ID)
for i := 0; i < udpRetry; i++ { for i := 0; i < udpRetry; i++ {
if err := n.sendMessage(contact.Addr(), tx.req); err != nil { err := n.sendMessage(contact.Addr(), tx.req)
if err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") { // this only happens on localhost. real UDP has no connections if !strings.Contains(err.Error(), "use of closed network connection") { // this only happens on localhost. real UDP has no connections
log.Error("send error: ", err) log.Error("send error: ", err)
} }

View file

@ -15,12 +15,9 @@ func TestPing(t *testing.T) {
conn := newTestUDPConn("127.0.0.1:21217") conn := newTestUDPConn("127.0.0.1:21217")
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
if err != nil {
t.Fatal(err)
}
err = dht.node.Connect(conn) err := dht.connect(conn)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -112,12 +109,9 @@ func TestStore(t *testing.T) {
conn := newTestUDPConn("127.0.0.1:21217") conn := newTestUDPConn("127.0.0.1:21217")
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
if err != nil {
t.Fatal(err)
}
err = dht.node.Connect(conn) err := dht.connect(conn)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -210,12 +204,9 @@ func TestFindNode(t *testing.T) {
conn := newTestUDPConn("127.0.0.1:21217") conn := newTestUDPConn("127.0.0.1:21217")
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
if err != nil {
t.Fatal(err)
}
err = dht.node.Connect(conn) err := dht.connect(conn)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -279,12 +270,9 @@ func TestFindValueExisting(t *testing.T) {
conn := newTestUDPConn("127.0.0.1:21217") conn := newTestUDPConn("127.0.0.1:21217")
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
if err != nil {
t.Fatal(err)
}
err = dht.node.Connect(conn) err := dht.connect(conn)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -363,12 +351,9 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
conn := newTestUDPConn("127.0.0.1:21217") conn := newTestUDPConn("127.0.0.1:21217")
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
if err != nil {
t.Fatal(err)
}
err = dht.node.Connect(conn) err := dht.connect(conn)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -23,11 +23,13 @@ func TestingCreateDHT(t *testing.T, numNodes int, bootstrap, concurrent bool) (*
bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort) bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort)
seeds = []string{bootstrapAddress} seeds = []string{bootstrapAddress}
bootstrapNode = NewBootstrapNode(bits.Rand(), 0, bootstrapDefaultRefreshDuration) bootstrapNode = NewBootstrapNode(bits.Rand(), 0, bootstrapDefaultRefreshDuration)
listener, err := net.ListenPacket(network, bootstrapAddress) listener, err := net.ListenPacket(Network, bootstrapAddress)
if err != nil { if err != nil {
panic(err) panic(err)
} }
if err := bootstrapNode.Connect(listener.(*net.UDPConn)); err != nil {
err = bootstrapNode.Connect(listener.(*net.UDPConn))
if err != nil {
t.Error("error connecting bootstrap node - ", err) t.Error("error connecting bootstrap node - ", err)
} }
} }
@ -40,13 +42,11 @@ func TestingCreateDHT(t *testing.T, numNodes int, bootstrap, concurrent bool) (*
dhts := make([]*DHT, numNodes) dhts := make([]*DHT, numNodes)
for i := 0; i < numNodes; i++ { for i := 0; i < numNodes; i++ {
dht, err := New(&Config{Address: testingDHTIP + ":" + strconv.Itoa(firstPort+i), NodeID: bits.Rand().Hex(), SeedNodes: seeds}) dht := New(&Config{Address: testingDHTIP + ":" + strconv.Itoa(firstPort+i), NodeID: bits.Rand().Hex(), SeedNodes: seeds})
if err != nil {
panic(err)
}
go func() { go func() {
if err := dht.Start(); err != nil { err := dht.Start()
if err != nil {
t.Error("error starting dht - ", err) t.Error("error starting dht - ", err)
} }
}() }()

View file

@ -45,11 +45,11 @@ func NewServer(store store.BlobStore) *Server {
func (s *Server) Shutdown() { func (s *Server) Shutdown() {
log.Debug("shutting down peer server...") log.Debug("shutting down peer server...")
s.stop.StopAndWait() s.stop.StopAndWait()
log.Debug("peer server stopped")
} }
// Start starts the server listener to handle connections. // Start starts the server listener to handle connections.
func (s *Server) Start(address string) error { func (s *Server) Start(address string) error {
log.Println("peer listening on " + address) log.Println("peer listening on " + address)
l, err := net.Listen("tcp", address) l, err := net.Listen("tcp", address)
if err != nil { if err != nil {
@ -59,8 +59,8 @@ func (s *Server) Start(address string) error {
go s.listenForShutdown(l) go s.listenForShutdown(l)
s.stop.Add(1) s.stop.Add(1)
go func() { go func() {
defer s.stop.Done()
s.listenAndServe(l) s.listenAndServe(l)
s.stop.Done()
}() }()
return nil return nil
@ -69,7 +69,8 @@ func (s *Server) Start(address string) error {
func (s *Server) listenForShutdown(listener net.Listener) { func (s *Server) listenForShutdown(listener net.Listener) {
<-s.stop.Ch() <-s.stop.Ch()
s.closed = true s.closed = true
if err := listener.Close(); err != nil { err := listener.Close()
if err != nil {
log.Error("error closing listener for peer server - ", err) log.Error("error closing listener for peer server - ", err)
} }
} }
@ -84,13 +85,21 @@ func (s *Server) listenAndServe(listener net.Listener) {
log.Error(err) log.Error(err)
} else { } else {
s.stop.Add(1) s.stop.Add(1)
go s.handleConnection(conn) go func() {
s.handleConnection(conn)
s.stop.Done()
}()
} }
} }
} }
func (s *Server) handleConnection(conn net.Conn) { func (s *Server) handleConnection(conn net.Conn) {
defer s.stop.Done() defer func() {
if err := conn.Close(); err != nil {
log.Error(errors.Prefix("closing peer conn", err))
}
}()
timeoutDuration := 5 * time.Second timeoutDuration := 5 * time.Second
for { for {
@ -98,9 +107,11 @@ func (s *Server) handleConnection(conn net.Conn) {
var response []byte var response []byte
var err error var err error
if err := conn.SetReadDeadline(time.Now().Add(timeoutDuration)); err != nil { err = conn.SetReadDeadline(time.Now().Add(timeoutDuration))
log.Error("error setting read deadline for client connection - ", err) if err != nil {
log.Error(errors.FullTrace(err))
} }
request, err = readNextRequest(conn) request, err = readNextRequest(conn)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
@ -108,8 +119,10 @@ func (s *Server) handleConnection(conn net.Conn) {
} }
return return
} }
if err := conn.SetReadDeadline(time.Time{}); err != nil {
log.Error("error setting read deadline client connection - ", err) err = conn.SetReadDeadline(time.Time{})
if err != nil {
log.Error(errors.FullTrace(err))
} }
if strings.Contains(string(request), `"requested_blobs"`) { if strings.Contains(string(request), `"requested_blobs"`) {

View file

@ -37,7 +37,8 @@ func getServer(t *testing.T, withBlobs bool) *Server {
st := store.MemoryBlobStore{} st := store.MemoryBlobStore{}
if withBlobs { if withBlobs {
for k, v := range blobs { for k, v := range blobs {
if err := st.Put(k, v); err != nil { err := st.Put(k, v)
if err != nil {
t.Error("error during put operation of memory blobstore - ", err) t.Error("error during put operation of memory blobstore - ", err)
} }
} }

197
prism/prism.go Normal file
View file

@ -0,0 +1,197 @@
package prism
import (
"context"
"math/big"
"strconv"
"sync"
"github.com/lbryio/reflector.go/cluster"
"github.com/lbryio/reflector.go/db"
"github.com/lbryio/reflector.go/dht"
"github.com/lbryio/reflector.go/dht/bits"
"github.com/lbryio/reflector.go/peer"
"github.com/lbryio/reflector.go/reflector"
"github.com/lbryio/reflector.go/store"
"github.com/lbryio/lbry.go/errors"
"github.com/lbryio/lbry.go/stopOnce"
log "github.com/sirupsen/logrus"
)
type Config struct {
PeerPort int
ReflectorPort int
DhtAddress string
DhtSeedNodes []string
ClusterPort int
ClusterSeedAddr string
DB *db.SQL
Blobs store.BlobStore
}
// DefaultConf returns a default config
func DefaultConf() *Config {
return &Config{
ClusterPort: cluster.DefaultClusterPort,
}
}
// Prism is the root instance of the application and houses the DHT, Peer Server, Reflector Server, and Cluster.
type Prism struct {
conf *Config
db *db.SQL
dht *dht.DHT
peer *peer.Server
reflector *reflector.Server
cluster *cluster.Cluster
stop *stopOnce.Stopper
}
// New returns an initialized Prism instance
func New(conf *Config) *Prism {
if conf == nil {
conf = DefaultConf()
}
dhtConf := dht.NewStandardConfig()
dhtConf.Address = conf.DhtAddress
dhtConf.SeedNodes = conf.DhtSeedNodes
d := dht.New(dhtConf)
c := cluster.New(conf.ClusterPort, conf.ClusterSeedAddr)
p := &Prism{
conf: conf,
db: conf.DB,
dht: d,
cluster: c,
peer: peer.NewServer(conf.Blobs),
reflector: reflector.NewServer(conf.Blobs),
stop: stopOnce.New(),
}
c.OnHashRangeChange = func(n, total int) {
p.stop.Add(1)
go func() {
p.AnnounceRange(n, total)
p.stop.Done()
}()
}
return p
}
// Start starts the components of the application.
func (p *Prism) Start() error {
if p.conf.DB == nil {
return errors.Err("db required in conf")
}
if p.conf.Blobs == nil {
return errors.Err("blobs required in conf")
}
err := p.dht.Start()
if err != nil {
return err
}
err = p.cluster.Connect()
if err != nil {
return err
}
// TODO: should not be localhost forever. should prolly be 0.0.0.0, or configurable
err = p.peer.Start("localhost:" + strconv.Itoa(p.conf.PeerPort))
if err != nil {
return err
}
// TODO: should not be localhost forever. should prolly be 0.0.0.0, or configurable
err = p.reflector.Start("localhost:" + strconv.Itoa(p.conf.ReflectorPort))
if err != nil {
return err
}
return nil
}
// Shutdown gracefully shuts down the different prism components before exiting.
func (p *Prism) Shutdown() {
p.stop.StopAndWait()
p.reflector.Shutdown()
p.peer.Shutdown()
p.cluster.Shutdown()
p.dht.Shutdown()
}
// AnnounceRange announces the `n`th interval of hashes, out of a total of `total` intervals
func (p *Prism) AnnounceRange(n, total int) {
// TODO: if more than one node is announcing each hash, figure out how to deal with last_announced_at so both nodes dont announce the same thing at the same time
// num and total are 1-indexed
if n < 1 {
log.Errorf("%s: n must be >= 1", p.dht.ID().HexShort())
return
}
max := bits.MaxP().Big()
interval := new(big.Int).Div(max, big.NewInt(int64(total)))
start := new(big.Int).Mul(interval, big.NewInt(int64(n-1)))
end := new(big.Int).Add(start, interval)
if n == total {
end = end.Add(end, big.NewInt(10000)) // there are rounding issues sometimes, so lets make sure we get the full range
}
if end.Cmp(max) > 0 {
end.Set(max)
}
log.Debugf("%s: hash range is now %s to %s\n", p.dht.ID().HexShort(), bits.FromBigP(start).Hex(), bits.FromBigP(end).Hex())
ctx, cancel := context.WithCancel(context.Background())
hashCh, errCh := p.db.GetHashesInRange(ctx, bits.FromBigP(start), bits.FromBigP(end))
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
select {
case <-p.stop.Ch():
return
case err, more := <-errCh:
if more && err != nil {
log.Error(err)
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-p.stop.Ch():
cancel()
return
case hash, more := <-hashCh:
if !more {
return
}
p.dht.Add(hash)
}
}
}()
wg.Wait()
}

40
prism/prism_test.go Normal file
View file

@ -0,0 +1,40 @@
package prism
import (
"math/big"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/lbryio/reflector.go/dht/bits"
)
func TestAnnounceRange(t *testing.T) {
t.Skip("TODO: this needs to actually test the thing")
total := 17
max := bits.MaxP().Big()
interval := bits.MaxP().Big()
spew.Dump(interval)
interval.Div(interval, big.NewInt(int64(total)))
for i := 0; i < total; i++ {
start := big.NewInt(0).Mul(interval, big.NewInt(int64(i)))
end := big.NewInt(0).Add(start, interval)
if i == total-1 {
end = end.Add(end, big.NewInt(10000)) // there are rounding issues sometimes, so lets make sure we get the full range
}
if end.Cmp(max) > 0 {
end.Set(max)
}
spew.Dump(i, start, end, bits.FromBigP(start).Hex(), bits.FromBigP(end).Hex())
}
//startB := bits.FromBigP(start)
//endB := bits.FromBigP(end)
//
//t.Logf("%s to %s\n", startB.Hex(), endB.Hex())
}

View file

@ -19,7 +19,8 @@ func TestMain(m *testing.M) {
log.Panic("could not create temp directory - ", err) log.Panic("could not create temp directory - ", err)
} }
defer func(directory string) { defer func(directory string) {
if err := os.RemoveAll(dir); err != nil { err := os.RemoveAll(dir)
if err != nil {
log.Panic("error removing files and directory - ", err) log.Panic("error removing files and directory - ", err)
} }
}(dir) }(dir)
@ -27,7 +28,8 @@ func TestMain(m *testing.M) {
ms := store.MemoryBlobStore{} ms := store.MemoryBlobStore{}
s := NewServer(&ms) s := NewServer(&ms)
go func() { go func() {
if err := s.Start(address); err != nil { err := s.Start(address)
if err != nil {
log.Panic("error starting up reflector server - ", err) log.Panic("error starting up reflector server - ", err)
} }
}() }()

View file

@ -1,70 +0,0 @@
package reflector
import (
"strconv"
"github.com/lbryio/lbry.go/stopOnce"
"github.com/lbryio/reflector.go/cluster"
"github.com/lbryio/reflector.go/dht"
"github.com/lbryio/reflector.go/peer"
"github.com/lbryio/reflector.go/store"
)
// Prism is the root instance of the application and houses the DHT, Peer Server, Reflector Server, and Cluster.
type Prism struct {
dht *dht.DHT
peer *peer.Server
reflector *Server
cluster *cluster.Cluster
stop *stopOnce.Stopper
}
// NewPrism returns an initialized Prism instance pointer.
func NewPrism(store store.BlobStore, clusterSeedAddr string) *Prism {
d, err := dht.New(nil)
if err != nil {
panic(err)
}
return &Prism{
dht: d,
peer: peer.NewServer(store),
reflector: NewServer(store),
cluster: cluster.New(cluster.DefaultClusterPort, clusterSeedAddr),
stop: stopOnce.New(),
}
}
// Start starts the components of the application.
func (p *Prism) Start() error {
err := p.dht.Start()
if err != nil {
return err
}
err = p.cluster.Connect()
if err != nil {
return err
}
err = p.peer.Start("localhost:" + strconv.Itoa(peer.DefaultPort))
if err != nil {
return err
}
err = p.reflector.Start("localhost:" + strconv.Itoa(DefaultPort))
if err != nil {
return err
}
return nil
}
// Shutdown gracefully shuts down the different prism components before exiting.
func (p *Prism) Shutdown() {
p.stop.StopAndWait()
p.reflector.Shutdown()
p.peer.Shutdown()
p.cluster.Shutdown()
p.dht.Shutdown()
}

View file

@ -34,6 +34,7 @@ func NewServer(store store.BlobStore) *Server {
func (s *Server) Shutdown() { func (s *Server) Shutdown() {
log.Debug("shutting down reflector server...") log.Debug("shutting down reflector server...")
s.stop.StopAndWait() s.stop.StopAndWait()
log.Debug("reflector server stopped")
} }
//Start starts the server listener to handle connections. //Start starts the server listener to handle connections.
@ -49,8 +50,8 @@ func (s *Server) Start(address string) error {
s.stop.Add(1) s.stop.Add(1)
go func() { go func() {
defer s.stop.Done()
s.listenAndServe(l) s.listenAndServe(l)
s.stop.Done()
}() }()
return nil return nil
@ -59,7 +60,8 @@ func (s *Server) Start(address string) error {
func (s *Server) listenForShutdown(listener net.Listener) { func (s *Server) listenForShutdown(listener net.Listener) {
<-s.stop.Ch() <-s.stop.Ch()
s.closed = true s.closed = true
if err := listener.Close(); err != nil { err := listener.Close()
if err != nil {
log.Error("error closing listener for peer server - ", err) log.Error("error closing listener for peer server - ", err)
} }
} }
@ -74,20 +76,30 @@ func (s *Server) listenAndServe(listener net.Listener) {
log.Error(err) log.Error(err)
} else { } else {
s.stop.Add(1) s.stop.Add(1)
go s.handleConn(conn) go func() {
s.handleConn(conn)
s.stop.Done()
}()
} }
} }
} }
func (s *Server) handleConn(conn net.Conn) { func (s *Server) handleConn(conn net.Conn) {
defer s.stop.Done() defer func() {
if err := conn.Close(); err != nil {
log.Error(errors.Prefix("closing peer conn", err))
}
}()
// TODO: connection should time out eventually // TODO: connection should time out eventually
err := s.doHandshake(conn) err := s.doHandshake(conn)
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
return return
} }
if err := s.doError(conn, err); err != nil { err := s.doError(conn, err)
if err != nil {
log.Error("error sending error response to reflector client connection - ", err) log.Error("error sending error response to reflector client connection - ", err)
} }
return return
@ -97,7 +109,8 @@ func (s *Server) handleConn(conn net.Conn) {
err = s.receiveBlob(conn) err = s.receiveBlob(conn)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
if err := s.doError(conn, err); err != nil { err := s.doError(conn, err)
if err != nil {
log.Error("error sending error response for receiving a blob to reflector client connection - ", err) log.Error("error sending error response for receiving a blob to reflector client connection - ", err)
} }
} }

View file

@ -20,7 +20,8 @@ func TestMemoryBlobStore_Get(t *testing.T) {
s := MemoryBlobStore{} s := MemoryBlobStore{}
hash := "abc" hash := "abc"
blob := []byte("abcdefg") blob := []byte("abcdefg")
if err := s.Put(hash, blob); err != nil { err := s.Put(hash, blob)
if err != nil {
t.Error("error getting memory blob - ", err) t.Error("error getting memory blob - ", err)
} }