195 lines
4.4 KiB
Go
195 lines
4.4 KiB
Go
package dht
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"sync"
|
|
|
|
"github.com/lbryio/lbry.go/v3/dht/bits"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/gorilla/mux"
|
|
rpc2 "github.com/gorilla/rpc/v2"
|
|
"github.com/gorilla/rpc/v2/json"
|
|
)
|
|
|
|
type rpcReceiver struct {
|
|
dht *DHT
|
|
}
|
|
|
|
// RpcPingArgs holds the arguments of the Ping RPC method.
type RpcPingArgs struct {
	// Address is handed to DHT.Ping; the Ping RPC rejects an empty value.
	Address string
}
|
|
|
|
func (rpc *rpcReceiver) Ping(r *http.Request, args *RpcPingArgs, result *string) error {
|
|
if args.Address == "" {
|
|
return errors.WithStack(errors.New("no address given"))
|
|
}
|
|
|
|
err := rpc.dht.Ping(args.Address)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
*result = pingSuccessResponse
|
|
return nil
|
|
}
|
|
|
|
// RpcFindArgs holds the arguments shared by the FindNode and FindValue RPC
// methods: the key to look up and the contact that should be queried for it.
type RpcFindArgs struct {
	Key    string // hex string; decoded with bits.FromHex and sent as the request argument
	NodeID string // hex string; decoded with bits.FromHex into the queried contact's ID
	IP     string // parsed with net.ParseIP into the queried contact's IP
	Port   int    // port of the queried contact
}
|
|
|
|
func (rpc *rpcReceiver) FindNode(r *http.Request, args *RpcFindArgs, result *[]Contact) error {
|
|
key, err := bits.FromHex(args.Key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
toQuery, err := bits.FromHex(args.NodeID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c := Contact{ID: toQuery, IP: net.ParseIP(args.IP), Port: args.Port}
|
|
req := Request{Method: findNodeMethod, Arg: &key}
|
|
|
|
nodeResponse := rpc.dht.node.Send(c, req)
|
|
if nodeResponse != nil && nodeResponse.Contacts != nil {
|
|
*result = nodeResponse.Contacts
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type RpcFindValueResult struct {
|
|
Contacts []Contact
|
|
Value string
|
|
}
|
|
|
|
func (rpc *rpcReceiver) FindValue(r *http.Request, args *RpcFindArgs, result *RpcFindValueResult) error {
|
|
key, err := bits.FromHex(args.Key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
toQuery, err := bits.FromHex(args.NodeID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c := Contact{ID: toQuery, IP: net.ParseIP(args.IP), Port: args.Port}
|
|
req := Request{Arg: &key, Method: findValueMethod}
|
|
|
|
nodeResponse := rpc.dht.node.Send(c, req)
|
|
if nodeResponse != nil && nodeResponse.FindValueKey != "" {
|
|
*result = RpcFindValueResult{Value: nodeResponse.FindValueKey}
|
|
return nil
|
|
}
|
|
if nodeResponse != nil && nodeResponse.Contacts != nil {
|
|
*result = RpcFindValueResult{Contacts: nodeResponse.Contacts}
|
|
return nil
|
|
}
|
|
|
|
return errors.WithStack(errors.New("not sure what happened"))
|
|
}
|
|
|
|
// RpcIterativeFindValueArgs holds the arguments of the IterativeFindValue
// RPC method.
type RpcIterativeFindValueArgs struct {
	// Key is a hex string; IterativeFindValue decodes it with bits.FromHex.
	Key string
}
|
|
|
|
type RpcIterativeFindValueResult struct {
|
|
Contacts []Contact
|
|
FoundValue bool
|
|
Values []Contact
|
|
}
|
|
|
|
func (rpc *rpcReceiver) IterativeFindValue(r *http.Request, args *RpcIterativeFindValueArgs, result *RpcIterativeFindValueResult) error {
|
|
key, err := bits.FromHex(args.Key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
foundContacts, found, err := FindContacts(rpc.dht.node, key, true, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
result.Contacts = foundContacts
|
|
result.FoundValue = found
|
|
if found {
|
|
for _, contact := range foundContacts {
|
|
if contact.PeerPort > 0 {
|
|
result.Values = append(result.Values, contact)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type RpcBucketResponse struct {
|
|
Start string
|
|
End string
|
|
NumContacts int
|
|
Contacts []Contact
|
|
}
|
|
|
|
type RpcRoutingTableResponse struct {
|
|
NodeID string
|
|
NumBuckets int
|
|
Buckets []RpcBucketResponse
|
|
}
|
|
|
|
func (rpc *rpcReceiver) GetRoutingTable(r *http.Request, args *struct{}, result *RpcRoutingTableResponse) error {
|
|
result.NodeID = rpc.dht.node.id.String()
|
|
result.NumBuckets = len(rpc.dht.node.rt.buckets)
|
|
for _, b := range rpc.dht.node.rt.buckets {
|
|
result.Buckets = append(result.Buckets, RpcBucketResponse{
|
|
Start: b.Range.Start.String(),
|
|
End: b.Range.End.String(),
|
|
NumContacts: b.Len(),
|
|
Contacts: b.Contacts(),
|
|
})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (rpc *rpcReceiver) AddKnownNode(r *http.Request, args *Contact, result *string) error {
|
|
rpc.dht.node.AddKnownNode(*args)
|
|
return nil
|
|
}
|
|
|
|
func (dht *DHT) runRPCServer(port int) {
|
|
addr := "0.0.0.0:" + strconv.Itoa(port)
|
|
|
|
s := rpc2.NewServer()
|
|
s.RegisterCodec(json.NewCodec(), "application/json")
|
|
s.RegisterCodec(json.NewCodec(), "application/json;charset=UTF-8")
|
|
err := s.RegisterService(&rpcReceiver{dht: dht}, "rpc")
|
|
if err != nil {
|
|
log.Error(errors.WithMessage(err, "registering rpc service"))
|
|
return
|
|
}
|
|
|
|
handler := mux.NewRouter()
|
|
handler.Handle("/", s)
|
|
server := &http.Server{Addr: addr, Handler: handler}
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
log.Printf("RPC server listening on %s", addr)
|
|
err := server.ListenAndServe()
|
|
if err != nil && err != http.ErrServerClosed {
|
|
log.Error(err)
|
|
}
|
|
}()
|
|
|
|
<-dht.grp.Ch()
|
|
err = server.Shutdown(context.Background())
|
|
if err != nil {
|
|
log.Error(errors.WithMessage(err, "shutting down rpc service"))
|
|
return
|
|
}
|
|
wg.Wait()
|
|
}
|