diff --git a/cmd/getstream.go b/cmd/getstream.go index ca3de6d..6a8c053 100644 --- a/cmd/getstream.go +++ b/cmd/getstream.go @@ -1,10 +1,9 @@ package cmd import ( - "io/ioutil" "os" - "github.com/lbryio/lbry.go/stream" + "github.com/lbryio/reflector.go/store" "github.com/lbryio/lbry.go/extras/errors" "github.com/lbryio/reflector.go/peer" @@ -33,39 +32,16 @@ func getStreamCmd(cmd *cobra.Command, args []string) { log.Fatal("error connecting client to server: ", err) } - s, err := c.GetStream(sdHash) - if err != nil { - log.Error(errors.FullTrace(err)) - return - } - - var sd stream.SDBlob - err = sd.FromBlob(s[0]) - if err != nil { - log.Error(errors.FullTrace(err)) - return - } - - log.Printf("Downloading %d blobs for %s", len(sd.BlobInfos)-1, sd.SuggestedFileName) - - data, err := s.Data() - if err != nil { - log.Error(errors.FullTrace(err)) - return - } + cache := store.NewFileBlobStore("/tmp/lbry_downloaded_blobs") wd, err := os.Getwd() if err != nil { - log.Error(errors.FullTrace(err)) - return + log.Fatal(err) } - filename := wd + "/" + sd.SuggestedFileName - err = ioutil.WriteFile(filename, data, 0644) + err = c.WriteStream(sdHash, wd, cache) if err != nil { log.Error(errors.FullTrace(err)) return } - - log.Printf("Wrote %d bytes to %s\n", len(data), filename) } diff --git a/peer/client.go b/peer/client.go index a8e2591..8558f1a 100644 --- a/peer/client.go +++ b/peer/client.go @@ -6,8 +6,11 @@ import ( "encoding/json" "io" "net" + "os" "time" + "github.com/lbryio/reflector.go/store" + "github.com/lbryio/lbry.go/stream" "github.com/lbryio/lbry.go/extras/errors" @@ -47,15 +50,65 @@ func (c *Client) Close() error { return c.conn.Close() } +// WriteStream downloads and writes a stream to file +func (c *Client) WriteStream(sdHash, dir string, blobStore store.BlobStore) error { + if !c.connected { + return errors.Err("not connected") + } + + var sd stream.SDBlob + + sdb, err := c.getBlobWithCache(sdHash, blobStore) + if err != nil { + return err + } + + err = sd.FromBlob(sdb) + if err != nil { + return err + } + + info, err := os.Stat(dir) + if err != nil { + return errors.Prefix("cannot stat "+dir, err) + } else if !info.IsDir() { + return errors.Err(dir + " must be a directory") + } + + f, err := os.Create(dir + "/" + sd.SuggestedFileName) + if err != nil { + return err + } + + for i := 0; i < len(sd.BlobInfos)-1; i++ { + b, err := c.getBlobWithCache(hex.EncodeToString(sd.BlobInfos[i].BlobHash), blobStore) + if err != nil { + return err + } + + data, err := b.Plaintext(sd.Key, sd.BlobInfos[i].IV) + if err != nil { + return err + } + + _, err = f.Write(data) + if err != nil { + return err + } + } + + return nil +} + // GetStream gets a stream -func (c *Client) GetStream(sdHash string) (stream.Stream, error) { +func (c *Client) GetStream(sdHash string, blobCache store.BlobStore) (stream.Stream, error) { if !c.connected { return nil, errors.Err("not connected") } var sd stream.SDBlob - b, err := c.GetBlob(sdHash) + b, err := c.getBlobWithCache(sdHash, blobCache) if err != nil { return nil, err } @@ -69,7 +122,7 @@ func (c *Client) GetStream(sdHash string) (stream.Stream, error) { s[0] = b for i := 0; i < len(sd.BlobInfos)-1; i++ { - s[i+1], err = c.GetBlob(hex.EncodeToString(sd.BlobInfos[i].BlobHash)) + s[i+1], err = c.getBlobWithCache(hex.EncodeToString(sd.BlobInfos[i].BlobHash), blobCache) if err != nil { return nil, err } @@ -78,14 +131,34 @@ func (c *Client) GetStream(sdHash string) (stream.Stream, error) { return s, nil } +func (c *Client) getBlobWithCache(hash string, blobCache store.BlobStore) (stream.Blob, error) { + if blobCache == nil { + return c.GetBlob(hash) + } + + blob, err := blobCache.Get(hash) + if err == nil || !errors.Is(err, store.ErrBlobNotFound) { + return blob, err + } + + blob, err = c.GetBlob(hash) + if err != nil { + return nil, err + } + + err = blobCache.Put(hash, blob) + + return blob, err +} + // GetBlob gets a blob -func (c *Client) GetBlob(blobHash string) (stream.Blob, error) { +func (c *Client) GetBlob(hash string) (stream.Blob, error) { if !c.connected { return nil, errors.Err("not connected") } sendRequest, err := json.Marshal(blobRequest{ - RequestedBlob: blobHash, + RequestedBlob: hash, }) if err != nil { return nil, err @@ -103,16 +176,16 @@ func (c *Client) GetBlob(blobHash string) (stream.Blob, error) { } if resp.IncomingBlob.Error != "" { - return nil, errors.Prefix(blobHash[:8], resp.IncomingBlob.Error) + return nil, errors.Prefix(hash[:8], resp.IncomingBlob.Error) } - if resp.IncomingBlob.BlobHash != blobHash { - return nil, errors.Prefix(blobHash[:8], "Blob hash in response does not match requested hash") + if resp.IncomingBlob.BlobHash != hash { + return nil, errors.Prefix(hash[:8], "Blob hash in response does not match requested hash") } if resp.IncomingBlob.Length <= 0 { - return nil, errors.Prefix(blobHash[:8], "Length reported as <= 0") + return nil, errors.Prefix(hash[:8], "Length reported as <= 0") } - log.Println("Receiving blob " + blobHash[:8]) + log.Println("Receiving blob " + hash[:8]) blob, err := c.readRawBlob(resp.IncomingBlob.Length) if err != nil {