diff --git a/dht/rpc.go b/dht/rpc.go index fb8df4c..ae8af09 100644 --- a/dht/rpc.go +++ b/dht/rpc.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/lbryio/lbry.go/util" + "github.com/davecgh/go-spew/spew" "github.com/lyoshenka/bencode" log "github.com/sirupsen/logrus" @@ -15,52 +17,49 @@ import ( func handlePacket(dht *DHT, pkt packet) { //log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data)) - var data map[string]interface{} - err := bencode.DecodeBytes(pkt.data, &data) - if err != nil { - log.Errorf("[%s] error decoding data: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) + if !util.InSlice(string(pkt.data[0:5]), []string{"d1:0i", "di0ei"}) { + log.Errorf("[%s] data is not a well-formatted dict: (%d bytes) %s", dht.node.id.HexShort(), len(pkt.data), hex.EncodeToString(pkt.data)) return } - msgType, ok := data[headerTypeField] - if !ok { - log.Errorf("[%s] decoded data has no message type: %s", dht.node.id.HexShort(), spew.Sdump(data)) - return - } + // TODO: test this stuff more thoroughly - switch msgType.(int64) { - case requestType: + // the following is a bit of a hack, but it lets us avoid decoding every message twice + // it depends on the data being a dict with 0 as the first key (so it starts with "d1:0i") and the message type as the first value + + switch pkt.data[5] { + case '0' + requestType: request := Request{} - err = bencode.DecodeBytes(pkt.data, &request) + err := bencode.DecodeBytes(pkt.data, &request) if err != nil { - log.Errorln(err) + log.Errorf("[%s] error decoding request: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) return } log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, argsToString(request.Args)) handleRequest(dht, pkt.raddr, request) - case responseType: + case '0' + responseType: response := Response{} - err = bencode.DecodeBytes(pkt.data, &response) + err := bencode.DecodeBytes(pkt.data, &response) if err != nil { - log.Errorln(err) + log.Errorf("[%s] error decoding response: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) return } log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug()) handleResponse(dht, pkt.raddr, response) - case errorType: + case '0' + errorType: e := Error{} - err = bencode.DecodeBytes(pkt.data, &e) + err := bencode.DecodeBytes(pkt.data, &e) if err != nil { - log.Errorln(err) + log.Errorf("[%s] error decoding error: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) return } log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType) handleError(dht, pkt.raddr, e) default: - log.Errorf("[%s] invalid message type: %s", dht.node.id.HexShort(), msgType) + log.Errorf("[%s] invalid message type: %s", dht.node.id.HexShort(), pkt.data[5]) return } }