lbry.go/dht/node.go
2021-10-06 16:28:17 -04:00

473 lines
13 KiB
Go

package dht
import (
"encoding/hex"
"net"
"strings"
"sync"
"time"
"github.com/cockroachdb/errors"
"github.com/davecgh/go-spew/spew"
"github.com/lbryio/lbry.go/v3/dht/bits"
"github.com/lbryio/lbry.go/v3/extras/stop"
"github.com/lyoshenka/bencode"
)
// packet represents a single datagram received from udp together with the
// address it was received from.
type packet struct {
// payload of the datagram; the read loop in Connect stores a private copy here
data []byte
// remote address returned by ReadFromUDP for this datagram
raddr *net.UDPAddr
}
// UDPConn is the set of connection operations the node relies on. Depending on
// this interface instead of a concrete connection allows using a mocked
// connection to test sending/receiving data.
// TODO: stop mocking this and use the real thing
type UDPConn interface {
// called by the read loop in Connect to receive one datagram into the given buffer
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
// called by sendMessage to send an encoded message to the given address
WriteToUDP([]byte, *net.UDPAddr) (int, error)
SetReadDeadline(time.Time) error
// called by sendMessage before each write
SetWriteDeadline(time.Time) error
// called by the shutdown goroutine started in Connect
Close() error
}
// RequestHandlerFunc is the signature of a custom request handler. When a Node's
// requestHandler field is set, handleRequest calls it with the sender's address
// and the decoded request instead of running the built-in method handlers.
type RequestHandlerFunc func(addr *net.UDPAddr, request Request)
// Node is a single node on the dht network. It owns the UDP connection, the map
// of outstanding transactions, the routing table and the contact store.
type Node struct {
// the node's id
id bits.Bitmap
// UDP connection for sending and receiving data; assigned in Connect
conn UDPConn
// true if we've closed the connection on purpose; set by the shutdown goroutine
// started in Connect just before it closes conn
connClosed bool
// token manager; started in Connect and stopped on shutdown
tokens *tokenManager
// map of outstanding transactions (keyed by message id) + the mutex that guards it
txLock *sync.RWMutex
transactions map[messageID]*transaction
// routing table
rt *routingTable
// data store
store *contactStore
// overrides for request handlers; when non-nil, handleRequest calls it instead
// of the built-in handlers
requestHandler RequestHandlerFunc
// stop the node neatly and clean up after itself
grp *stop.Group
}
// NewNode builds a Node with the given id and returns a pointer to it. The node
// gets its own routing table, contact store, transaction map (with its lock),
// stop group and token manager. No connection is attached; call Connect for that.
func NewNode(id bits.Bitmap) *Node {
	node := new(Node)
	node.id = id
	node.rt = newRoutingTable(id)
	node.store = newStore()
	node.txLock = new(sync.RWMutex)
	node.transactions = map[messageID]*transaction{}
	node.grp = stop.New()
	node.tokens = new(tokenManager)
	return node
}
// Connect attaches the node to conn and starts the background goroutines the
// node needs:
//
//   - a shutdown watcher that waits for the stop group, then stops the token
//     manager and closes conn
//   - a reader that pulls datagrams off conn
//   - a consumer that passes each datagram to handlePacket
//   - the routing table grooming loop
//
// The reader, consumer and grooming goroutines are counted on the stop group
// via Add/Done. The returned error is currently always nil.
func (n *Node) Connect(conn UDPConn) error {
	n.conn = conn

	n.tokens.Start(tokenSecretRotationInterval)

	go func() {
		// stop tokens and close the connection when we're shutting down
		<-n.grp.Ch()
		n.tokens.Stop()
		// the flag is still maintained for any code that consults it
		n.connClosed = true
		err := n.conn.Close()
		if err != nil {
			log.Error("error closing node connection on shutdown - ", err)
		}
	}()

	packets := make(chan packet)

	n.grp.Add(1)
	go func() {
		defer n.grp.Done()

		buf := make([]byte, udpMaxMessageLength)

		for {
			bytesRead, raddr, err := n.conn.ReadFromUDP(buf)
			if err != nil {
				// The stop channel is closed before the shutdown goroutine sets
				// connClosed and closes the connection, so checking it here covers
				// every case the flag covered without reading a variable that
				// another goroutine writes unsynchronized.
				select {
				case <-n.grp.Ch():
					return
				default:
				}
				log.Errorf("udp read error: %v", err)
				continue
			}
			if raddr == nil {
				log.Errorf("udp read with no raddr")
				continue
			}

			data := make([]byte, bytesRead)
			copy(data, buf[:bytesRead]) // slices use the same underlying array, so we need a new one for each packet

			select { // needs select here because packet consumer can quit and the packets channel gets filled up and blocks
			case packets <- packet{data: data, raddr: raddr}:
			case <-n.grp.Ch():
				return
			}
		}
	}()

	n.grp.Add(1)
	go func() {
		defer n.grp.Done()

		var pkt packet

		for {
			select {
			case pkt = <-packets:
				n.handlePacket(pkt)
			case <-n.grp.Ch():
				return
			}
		}
	}()

	// periodically check for routing table buckets that need to be refreshed
	n.grp.Add(1)
	go func() {
		defer n.grp.Done()
		n.startRoutingTableGrooming()
	}()

	return nil
}
// Shutdown stops the node's stop group and blocks in StopAndWait until it
// returns, logging a debug line before and after.
func (n *Node) Shutdown() {
	shortID := n.id.HexShort()
	log.Debugf("[%s] node shutting down", shortID)
	n.grp.StopAndWait()
	log.Debugf("[%s] node stopped", shortID)
}
// handlePacket decodes one datagram received from udp and passes it to
// handleRequest, handleResponse or handleError depending on its message type.
//
// Datagrams that are too short, do not start with one of the expected bencode
// dict prefixes, carry an unknown message type, or fail to decode are logged
// and dropped.
func (n *Node) handlePacket(pkt packet) {
	//log.Debugf("[%s] Received message from %s (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), len(pkt.data), hex.EncodeToString(pkt.data))

	// The checks below read the first five bytes and then the byte at index 5.
	// Datagrams come from the network and may be any length (including empty),
	// so reject anything shorter than six bytes instead of panicking on the
	// slice/index expressions.
	const minLength = 6
	if len(pkt.data) < minLength {
		log.Errorf("[%s] data is not a well-formatted dict: (%d bytes) %s", n.id.HexShort(), len(pkt.data), hex.EncodeToString(pkt.data))
		return
	}

	firstFive := string(pkt.data[0:5])
	if firstFive != "d1:0i" && firstFive != "di0ei" {
		log.Errorf("[%s] data is not a well-formatted dict: (%d bytes) %s", n.id.HexShort(), len(pkt.data), hex.EncodeToString(pkt.data))
		return
	}

	// the following is a bit of a hack, but it lets us avoid decoding every message twice
	// it depends on the data being a dict with 0 as the first key (so it starts with "d1:0i") and the message type as the first value
	// TODO: test this more thoroughly

	switch pkt.data[5] {
	case '0' + requestType:
		request := Request{}
		err := bencode.DecodeBytes(pkt.data, &request)
		if err != nil {
			log.Errorf("[%s] error decoding request from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
			return
		}
		log.Debugf("[%s] query %s: received request from %s: %s(%s)", n.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, request.argsDebug())
		n.handleRequest(pkt.raddr, request)

	case '0' + responseType:
		response := Response{}
		err := bencode.DecodeBytes(pkt.data, &response)
		if err != nil {
			log.Errorf("[%s] error decoding response from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
			return
		}
		log.Debugf("[%s] query %s: received response from %s: %s", n.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.argsDebug())
		n.handleResponse(pkt.raddr, response)

	case '0' + errorType:
		e := Error{}
		err := bencode.DecodeBytes(pkt.data, &e)
		if err != nil {
			log.Errorf("[%s] error decoding error from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
			return
		}
		log.Debugf("[%s] query %s: received error from %s: %s", n.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType)
		n.handleError(pkt.raddr, e)

	default:
		log.Errorf("[%s] invalid message type: %s", n.id.HexShort(), string(pkt.data[5]))
		return
	}
}
// handleRequest answers a request received from udp.
//
// Requests carrying this node's own id are ignored. If a requestHandler
// override is set, it is called and nothing else happens. Otherwise the request
// is dispatched on its method (ping, store, find node, find value) and the reply
// is sent back to addr. Send failures are logged, not returned.
//
// Once a known method has been handled, the sender is refreshed in the routing
// table. Unknown methods and find requests without an arg return early and skip
// the refresh.
func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
	if request.NodeID.Equals(n.id) {
		log.Warn("ignoring self-request")
		return
	}

	// an overridden handler replaces the built-in dispatch entirely
	if custom := n.requestHandler; custom != nil {
		custom(addr, request)
		return
	}

	switch request.Method {
	case pingMethod:
		pong := Response{ID: request.ID, NodeID: n.id, Data: pingSuccessResponse}
		if err := n.sendMessage(addr, pong); err != nil {
			log.Error("error sending 'pingmethod' response message - ", err)
		}

	case storeMethod:
		// TODO: we should be sending the IP in the request, not just using the sender's IP
		// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
		if !n.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) {
			rejection := Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"}
			if err := n.sendMessage(addr, rejection); err != nil {
				log.Error("error sending 'storemethod'response message for invalid-token - ", err)
			}
			break // leaves the switch; the sender is still refreshed below
		}

		n.Store(request.StoreArgs.BlobHash, Contact{ID: request.StoreArgs.NodeID, IP: addr.IP, Port: addr.Port, PeerPort: request.StoreArgs.Value.Port})
		stored := Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse}
		if err := n.sendMessage(addr, stored); err != nil {
			log.Error("error sending 'storemethod' response message - ", err)
		}

	case findNodeMethod:
		if request.Arg == nil {
			log.Errorln("request is missing arg")
			return
		}

		closest := Response{
			ID:       request.ID,
			NodeID:   n.id,
			Contacts: n.rt.GetClosest(*request.Arg, bucketSize),
		}
		if err := n.sendMessage(addr, closest); err != nil {
			log.Error("error sending 'findnodemethod' response message - ", err)
		}

	case findValueMethod:
		if request.Arg == nil {
			log.Errorln("request is missing arg")
			return
		}

		reply := Response{
			ID:     request.ID,
			NodeID: n.id,
			Token:  n.tokens.Get(request.NodeID, addr),
		}

		// answer with stored contacts when we have any for the key,
		// otherwise with the closest contacts from the routing table
		stored := n.store.Get(*request.Arg)
		if len(stored) > 0 {
			reply.FindValueKey = request.Arg.RawString()
			reply.Contacts = stored
		} else {
			reply.Contacts = n.rt.GetClosest(*request.Arg, bucketSize)
		}

		if err := n.sendMessage(addr, reply); err != nil {
			log.Error("error sending 'findvaluemethod' response message - ", err)
		}

	default:
		//n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-request-method"})
		log.Errorln("invalid request method")
		return
	}

	// nodes that send us requests should not be inserted, only refreshed.
	// the routing table must only contain "good" nodes, which are nodes that reply to our requests
	// if a node is already good (aka in the table), its fine to refresh it
	// http://www.bittorrent.org/beps/bep_0005.html#routing-table
	sender := Contact{ID: request.NodeID, IP: addr.IP, Port: addr.Port}
	n.rt.Fresh(sender)
}
// handleResponse processes a response received from udp. If an outstanding
// transaction matches the response id and the responding contact, the response
// is offered to that transaction's channel without blocking. In every case the
// responder is passed to the routing table's Update.
func (n *Node) handleResponse(addr *net.UDPAddr, response Response) {
	responder := Contact{ID: response.NodeID, IP: addr.IP, Port: addr.Port}

	if tx := n.txFind(response.ID, responder); tx != nil {
		select {
		case tx.res <- response:
		default:
			// the send could not proceed right now, so the response is dropped
			//log.Errorf("[%s] query %s: response received, but tx has no listener or multiple responses to the same tx", n.id.HexShort(), response.ID.HexShort())
		}
	}

	n.rt.Update(responder)
}
// handleError processes an error message received from udp: it dumps the error
// with spew and refreshes the sender in the routing table.
func (n *Node) handleError(addr *net.UDPAddr, e Error) {
	// NOTE(review): spew.Dump prints to standard output for every error message
	// received - confirm this is intended outside of debugging
	spew.Dump(e)

	sender := Contact{ID: e.NodeID, IP: addr.IP, Port: addr.Port}
	n.rt.Fresh(sender)
}
// sendMessage bencodes data and writes it to addr over the node's connection.
//
// A debug line describing the message is logged first. A 5 second write
// deadline is set before writing. If setting the deadline fails while the node
// is shutting down, the message is dropped and nil is returned. If it fails
// otherwise, the failure is logged and the write is attempted anyway.
//
// It returns the encoding or write error (with a stack attached), or nil.
func (n *Node) sendMessage(addr *net.UDPAddr, data Message) error {
	encoded, err := bencode.EncodeBytes(data)
	if err != nil {
		return errors.WithStack(err)
	}

	switch msg := data.(type) {
	case Request:
		log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
			n.id.HexShort(), msg.ID.HexShort(), addr.String(), len(encoded), msg.Method, msg.argsDebug())
	case Response:
		log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
			n.id.HexShort(), msg.ID.HexShort(), addr.String(), len(encoded), msg.argsDebug())
	default:
		log.Debugf("[%s] (%d bytes) %s", n.id.HexShort(), len(encoded), spew.Sdump(data))
	}

	err = n.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	if err != nil {
		// This function runs on many goroutines. The stop channel is closed
		// before the connection is closed on purpose, so checking it detects an
		// intentional shutdown without reading a flag that the shutdown
		// goroutine writes unsynchronized.
		select {
		case <-n.grp.Ch():
			return nil
		default:
		}
		log.Error("error setting write deadline - ", err)
	}

	_, err = n.conn.WriteToUDP(encoded, addr)
	return errors.WithStack(err)
}
// transaction represents a single query to the dht. it stores the queried contact, the request, and the response channel
type transaction struct {
// the contact the request was sent to; txFind compares responders against it
contact Contact
// the request that was sent; its ID is the key in the node's transactions map
req Request
// channel on which handleResponse delivers the matching response
res chan Response
// passed (negated) as the second argument of contact.Equals in txFind -
// presumably disables the node id comparison; confirm against Contact.Equals
skipIDCheck bool
}
// txInsert registers tx in the node's map of outstanding transactions, keyed by
// the id of its request. An existing entry with the same id is replaced.
func (n *Node) txInsert(tx *transaction) {
	key := tx.req.ID

	n.txLock.Lock()
	defer n.txLock.Unlock()
	n.transactions[key] = tx
}
// txDelete removes the transaction with the given id from the node's map of
// outstanding transactions while holding the write lock. Deleting an id that is
// not in the map does nothing.
func (n *Node) txDelete(id messageID) {
n.txLock.Lock()
defer n.txLock.Unlock()
delete(n.transactions, id)
}
// txFind looks up the outstanding transaction with the given id and returns it
// only if its contact equals c. The transaction's skipIDCheck flag is passed,
// negated, as the second argument of that comparison. It returns nil when there
// is no such transaction or the contact does not match.
func (n *Node) txFind(id messageID, c Contact) *transaction {
	n.txLock.RLock()
	defer n.txLock.RUnlock()

	if tx, found := n.transactions[id]; found && tx.contact.Equals(c, !tx.skipIDCheck) {
		return tx
	}
	return nil
}
// SendOptions controls the behavior of send calls
type SendOptions struct {
// when true, SendAsync sets skipIDCheck on the transaction it creates
skipIDCheck bool
}
// SendAsync sends req to contact in a background goroutine and returns a channel
// that will receive the response.
//
// The request is given a fresh message id and this node's id, then sent up to
// udpRetry times, waiting up to udpTimeout for a response after each successful
// send. The returned channel is closed when the transaction completes, runs out
// of attempts, or the node is stopped; in the latter two cases it is closed
// without a value. When every attempt goes unanswered, the routing table is
// told about the failure.
//
// Only the first element of options, if present, is used. Sending to the node's
// own id is refused: an error is logged and a closed channel is returned.
func (n *Node) SendAsync(contact Contact, req Request, options ...SendOptions) <-chan *Response {
	ch := make(chan *Response, 1)

	if contact.ID.Equals(n.id) {
		log.Error("sending query to self")
		close(ch)
		return ch
	}

	go func() {
		defer close(ch)

		req.ID = newMessageID()
		req.NodeID = n.id
		tx := &transaction{
			contact: contact,
			req:     req,
			// handleResponse delivers with a non-blocking send. With an unbuffered
			// channel, a response that arrives before this goroutine reaches the
			// select below (or between attempts) would be dropped. One slot of
			// buffer keeps the first response until it is read.
			res: make(chan Response, 1),
		}

		if len(options) > 0 && options[0].skipIDCheck {
			tx.skipIDCheck = true
		}

		n.txInsert(tx)
		defer n.txDelete(tx.req.ID)

		for i := 0; i < udpRetry; i++ {
			err := n.sendMessage(contact.Addr(), tx.req)
			if err != nil {
				if !strings.Contains(err.Error(), "use of closed network connection") { // this only happens on localhost. real UDP has no connections
					log.Error("send error: ", err)
				}
				continue
			}

			select {
			case res := <-tx.res:
				ch <- &res
				return
			case <-n.grp.Ch():
				return
			case <-time.After(udpTimeout):
				// no response in time; try again if attempts remain
			}
		}

		// notify routing table about a failure to respond
		n.rt.Fail(tx.contact)
	}()

	return ch
}
// Send is the blocking form of SendAsync: it starts the transaction and waits on
// the channel SendAsync returns. It returns the response, or nil when that
// channel is closed without a value (for example when the transaction timed out).
func (n *Node) Send(contact Contact, req Request, options ...SendOptions) *Response {
	pending := n.SendAsync(contact, req, options...)
	return <-pending
}
// CountActiveTransactions returns the number of transactions currently in the
// node's map of outstanding transactions.
func (n *Node) CountActiveTransactions() int {
	// read-only access: take the shared lock, as txFind does, so counting does
	// not block concurrent lookups
	n.txLock.RLock()
	defer n.txLock.RUnlock()
	return len(n.transactions)
}
// startRoutingTableGrooming runs the routing table refresh loop. Every
// tRefresh/5 it calls RoutingTableRefresh with a child of the node's stop group.
// It blocks until the node's stop group is stopped.
func (n *Node) startRoutingTableGrooming() {
	refreshTicker := time.NewTicker(tRefresh / 5) // how often to check for buckets that need to be refreshed
	// release the ticker's resources once the loop exits
	defer refreshTicker.Stop()

	for {
		select {
		case <-refreshTicker.C:
			RoutingTableRefresh(n, tRefresh, n.grp.Child())
		case <-n.grp.Ch():
			return
		}
	}
}
// Store records contact c under hash in the node's contact store by calling the
// store's Upsert. handleRequest uses it for store requests that carry a valid token.
func (n *Node) Store(hash bits.Bitmap, c Contact) {
n.store.Upsert(hash, c)
}
// AddKnownNode adds a known-good node to the routing table by passing c to the
// routing table's Update.
func (n *Node) AddKnownNode(c Contact) {
n.rt.Update(c)
}