terminate stream after consuming all the data

This commit is contained in:
Alex Grintsvayg 2021-04-02 14:16:46 -04:00
parent be64130ae1
commit 6bc878d657
No known key found for this signature in database
GPG key ID: AEB3F089F86A22B5
2 changed files with 68 additions and 10 deletions

View file

@ -143,10 +143,15 @@ func NewEncoderFromSD(src io.Reader, sdBlob *SDBlob) *Encoder {
// TODO: consider making a NewPartialEncoder that also copies blobinfos from sdBlobs and seeks forward in the data // TODO: consider making a NewPartialEncoder that also copies blobinfos from sdBlobs and seeks forward in the data
// this would avoid re-creating blobs that were created in the past // this would avoid re-creating blobs that were created in the past
// Next returns the next blob in the stream // Next reads the next chunk of data, encodes it into a blob, and adds it to the stream
// When the source is fully consumed, Next() makes sure the stream is terminated (i.e. the sd blob
// ends with an empty terminating blob) and returns io.EOF
func (e *Encoder) Next() (Blob, error) { func (e *Encoder) Next() (Blob, error) {
n, err := e.src.Read(e.buf) n, err := e.src.Read(e.buf)
if err != nil { if err != nil {
if errors.Is(err, io.EOF) {
e.ensureTerminated()
}
return nil, err return nil, err
} }
@ -171,17 +176,11 @@ func (e *Encoder) Stream() (Stream, error) {
for { for {
blob, err := e.Next() blob, err := e.Next()
if err != nil { if err != nil {
if !errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
return nil, err
}
// if stream is not terminated, terminate it
if e.sd.BlobInfos[len(e.sd.BlobInfos)-1].Length > 0 {
e.sd.addBlob(Blob{}, e.nextIV())
}
break break
} }
return nil, err
}
s = append(s, blob) s = append(s, blob)
} }
@ -221,6 +220,16 @@ func (e *Encoder) SourceSizeHint(size int) *Encoder {
return e return e
} }
func (e *Encoder) isTerminated() bool {
return len(e.sd.BlobInfos) >= 1 && e.sd.BlobInfos[len(e.sd.BlobInfos)-1].Length == 0
}
func (e *Encoder) ensureTerminated() {
if !e.isTerminated() {
e.sd.addBlob(Blob{}, e.nextIV())
}
}
// nextIV returns the next preset IV if there is one // nextIV returns the next preset IV if there is one
func (e *Encoder) nextIV() []byte { func (e *Encoder) nextIV() []byte {
if len(e.ivs) == 0 { if len(e.ivs) == 0 {

View file

@ -6,7 +6,10 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/sha512" "crypto/sha512"
"encoding/hex" "encoding/hex"
"io"
"testing" "testing"
"github.com/lbryio/lbry.go/v2/extras/errors"
) )
var testdataBlobHashes = []string{ var testdataBlobHashes = []string{
@ -135,6 +138,52 @@ func TestMakeStream(t *testing.T) {
} }
} }
func TestEmptyStream(t *testing.T) {
enc := NewEncoder(bytes.NewBuffer(nil))
_, err := enc.Next()
if !errors.Is(err, io.EOF) {
t.Errorf("expected io.EOF, got %v", err)
}
sd := enc.SDBlob()
if len(sd.BlobInfos) != 1 {
t.Errorf("expected 1 blobinfos in sd blob, got %d", len(sd.BlobInfos))
}
if sd.BlobInfos[0].Length != 0 {
t.Errorf("first and only blob to be the terminator blob")
}
}
func TestTermination(t *testing.T) {
b := make([]byte, 12)
enc := NewEncoder(bytes.NewBuffer(b))
_, err := enc.Next()
if err != nil {
t.Error(err)
}
if enc.isTerminated() {
t.Errorf("stream should not terminate until after EOF")
}
_, err = enc.Next()
if !errors.Is(err, io.EOF) {
t.Errorf("expected io.EOF, got %v", err)
}
if !enc.isTerminated() {
t.Errorf("stream should be terminated after EOF")
}
_, err = enc.Next()
if !errors.Is(err, io.EOF) {
t.Errorf("expected io.EOF on all subsequent reads, got %v", err)
}
sd := enc.SDBlob()
if len(sd.BlobInfos) != 2 {
t.Errorf("expected 2 blobinfos in sd blob, got %d", len(sd.BlobInfos))
}
}
func TestSizeHint(t *testing.T) { func TestSizeHint(t *testing.T) {
b := make([]byte, 12) b := make([]byte, 12)