diff --git a/server/peer/server.go b/server/peer/server.go index 7de420b..72ab04a 100644 --- a/server/peer/server.go +++ b/server/peer/server.go @@ -234,10 +234,10 @@ func (s *Server) handleCompositeRequest(data []byte) ([]byte, error) { response := compositeResponse{ LbrycrdAddress: LbrycrdAddress, + AvailableBlobs: []string{}, } if len(request.RequestedBlobs) > 0 { - var availableBlobs []string for _, blobHash := range request.RequestedBlobs { if reflector.IsProtected(blobHash) { return nil, errors.Err("requested blob is protected") @@ -247,15 +247,16 @@ func (s *Server) handleCompositeRequest(data []byte) ([]byte, error) { return nil, err } if exists { - availableBlobs = append(availableBlobs, blobHash) + response.AvailableBlobs = append(response.AvailableBlobs, blobHash) } } - response.AvailableBlobs = availableBlobs } - response.BlobDataPaymentRate = paymentRateAccepted - if request.BlobDataPaymentRate < 0 { - response.BlobDataPaymentRate = paymentRateTooLow + if request.BlobDataPaymentRate != nil { + response.BlobDataPaymentRate = paymentRateAccepted + if *request.BlobDataPaymentRate < 0 { + response.BlobDataPaymentRate = paymentRateTooLow + } } var blob []byte @@ -270,13 +271,13 @@ func (s *Server) handleCompositeRequest(data []byte) ([]byte, error) { blob, trace, err = s.store.Get(request.RequestedBlob) log.Debug(trace.String()) if errors.Is(err, store.ErrBlobNotFound) { - response.IncomingBlob = incomingBlob{ + response.IncomingBlob = &incomingBlob{ Error: err.Error(), } } else if err != nil { return nil, err } else { - response.IncomingBlob = incomingBlob{ + response.IncomingBlob = &incomingBlob{ BlobHash: request.RequestedBlob, Length: len(blob), } @@ -305,7 +306,15 @@ func (s *Server) logError(e error) { } func readNextMessage(buf *bufio.Reader) ([]byte, error) { - msg := make([]byte, 0) + first_byte, err := buf.ReadByte() + if err != nil { + return nil, err + } + if first_byte != '{' { + // every request starts with '{'. Checking here disconnects earlier, so we don't wait until timeout + return nil, errInvalidData + } + msg := []byte("{") eof := false for { @@ -326,6 +335,8 @@ func readNextMessage(buf *bufio.Reader) ([]byte, error) { if len(msg) > maxRequestSize { return msg, errRequestTooLarge + } else if len(msg) > 0 && msg[0] != '{' { + return msg, errInvalidData } // yes, this is how the peer protocol knows when the request finishes @@ -360,6 +371,7 @@ const ( ) var errRequestTooLarge = errors.Base("request is too large") +var errInvalidData = errors.Base("Invalid data") type availabilityRequest struct { LbrycrdAddress bool `json:"lbrycrd_address"` @@ -396,13 +408,13 @@ type blobResponse struct { type compositeRequest struct { LbrycrdAddress bool `json:"lbrycrd_address"` RequestedBlobs []string `json:"requested_blobs"` - BlobDataPaymentRate float64 `json:"blob_data_payment_rate"` + BlobDataPaymentRate *float64 `json:"blob_data_payment_rate"` RequestedBlob string `json:"requested_blob"` } type compositeResponse struct { - LbrycrdAddress string `json:"lbrycrd_address,omitempty"` - AvailableBlobs []string `json:"available_blobs,omitempty"` - BlobDataPaymentRate string `json:"blob_data_payment_rate,omitempty"` - IncomingBlob incomingBlob `json:"incoming_blob,omitempty"` + LbrycrdAddress string `json:"lbrycrd_address,omitempty"` + AvailableBlobs []string `json:"available_blobs"` + BlobDataPaymentRate string `json:"blob_data_payment_rate,omitempty"` + IncomingBlob *incomingBlob `json:"incoming_blob,omitempty"` } diff --git a/server/peer/server_test.go b/server/peer/server_test.go index 7d21921..911d543 100644 --- a/server/peer/server_test.go +++ b/server/peer/server_test.go @@ -2,7 +2,10 @@ package peer import ( "bytes" + "io" + "net" "testing" + "time" "github.com/lbryio/reflector.go/store" ) @@ -75,3 +78,59 @@ func TestAvailabilityRequest_WithBlobs(t *testing.T) { } } } + +func TestRequestFromConnection(t *testing.T) { + s := getServer(t, true) + err := s.Start("127.0.0.1:50505") + defer s.Shutdown() + if err != nil { + t.Error("error starting server", err) + } + + for _, p := range availabilityRequests { + conn, err := net.Dial("tcp", "127.0.0.1:50505") + if err != nil { + t.Error("error opening connection", err) + } + defer conn.Close() + + response := make([]byte, 8192) + _, err = conn.Write(p.request) + if err != nil { + t.Error("error writing", err) + } + _, err = conn.Read(response) + if err != nil { + t.Error("error reading", err) + } + if !bytes.Equal(response[:len(p.response)], p.response) { + t.Errorf("Response did not match expected response.\nExpected: %s\nGot: %s", string(p.response), string(response)) + } + } +} + +func TestInvalidData(t *testing.T) { + s := getServer(t, true) + err := s.Start("127.0.0.1:50503") + defer s.Shutdown() + if err != nil { + t.Error("error starting server", err) + } + conn, err := net.Dial("tcp", "127.0.0.1:50503") + if err != nil { + t.Error("error opening connection", err) + } + defer conn.Close() + + response := make([]byte, 8192) + _, err = conn.Write([]byte("hello dear server, I would like blobs. Please")) + if err != nil { + t.Error("error writing", err) + } + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + _, err = conn.Read(response) + if err != io.EOF { + t.Error("error reading", err) + } + println(response) +}