Mirror of https://github.com/neilotoole/sq.git, synced 2024-12-18 05:31:38 +03:00
parent 9f59bc4c76
commit 9a0b9b7a9c
@@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

Breaking changes are annotated with ☢️, and alpha/beta features with 🐥.

## [v0.47.0] - 2024-01-28
## [v0.47.0] - UPCOMING

This is a significant release, focused on improving i/o, responsiveness,
and performance. The headline features are [caching](https://sq.io/docs/source#cache)
@@ -5,6 +5,7 @@ import (
    "bytes"
    "context"
    crand "crypto/rand"
    "errors"
    "fmt"
    "io"
    mrand "math/rand"
@@ -382,6 +383,87 @@ func (w *notifyOnceWriter) Write(p []byte) (n int, err error) {
    return w.w.Write(p)
}

// NotifyOnEOFReader returns an [io.Reader] that invokes fn
// when r.Read returns [io.EOF]. The error that fn returns is
// what's returned to the r caller: fn can transform the error
// or return it unchanged. If r or fn is nil, r is returned.
//
// If r is an [io.ReadCloser], the returned reader will also
// implement [io.ReadCloser].
//
// See also: [NotifyOnErrorReader], which is a generalization of
// [NotifyOnEOFReader].
func NotifyOnEOFReader(r io.Reader, fn func(error) error) io.Reader {
    if r == nil || fn == nil {
        return r
    }

    if rc, ok := r.(io.ReadCloser); ok {
        return &notifyOnEOFReadCloser{notifyOnEOFReader{r: rc, fn: fn}}
    }

    return &notifyOnEOFReader{r: r, fn: fn}
}

type notifyOnEOFReader struct {
    r  io.Reader
    fn func(error) error
}

// Read implements io.Reader.
func (r *notifyOnEOFReader) Read(p []byte) (n int, err error) {
    n, err = r.r.Read(p)
    if err != nil && errors.Is(err, io.EOF) {
        err = r.fn(err)
    }

    return n, err
}

var _ io.ReadCloser = (*notifyOnEOFReadCloser)(nil)

type notifyOnEOFReadCloser struct {
    notifyOnEOFReader
}

// Close implements io.Closer.
func (r *notifyOnEOFReadCloser) Close() error {
    if c, ok := r.r.(io.Closer); ok {
        return c.Close()
    }
    return nil
}
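
// Illustrative usage sketch (not part of this commit), assuming this file's
// package is imported as ioz: the callback fires exactly once, when the
// wrapped reader returns io.EOF.
func ExampleNotifyOnEOFReader() {
    src := strings.NewReader("hello")
    r := ioz.NotifyOnEOFReader(src, func(err error) error {
        fmt.Println("EOF reached")
        return err // pass io.EOF through unchanged
    })
    _, _ = io.Copy(io.Discard, r) // fn is invoked when Read hits EOF
    // Output: EOF reached
}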

// NotifyOnErrorReader returns an [io.Reader] that invokes fn
// when r.Read returns an error. The error that fn returns is
// what's returned to the r caller: fn can transform the error
// or return it unchanged. If r or fn is nil, r is returned.
//
// See also: [NotifyOnEOFReader], which is a specialization of
// [NotifyOnErrorReader].
func NotifyOnErrorReader(r io.Reader, fn func(error) error) io.Reader {
    if r == nil || fn == nil {
        return r
    }

    return &notifyOnErrorReader{r: r, fn: fn}
}

type notifyOnErrorReader struct {
    r  io.Reader
    fn func(error) error
}

// Read implements io.Reader.
func (r *notifyOnErrorReader) Read(p []byte) (n int, err error) {
    n, err = r.r.Read(p)
    if err != nil {
        err = r.fn(err)
    }

    return n, err
}
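
// Illustrative usage sketch (not part of this commit): NotifyOnErrorReader
// fires on any error, including io.EOF, so it suits error translation at a
// package boundary. testing/iotest supplies the failing reader here.
func ExampleNotifyOnErrorReader() {
    src := iotest.ErrReader(errors.New("boom"))
    r := ioz.NotifyOnErrorReader(src, func(err error) error {
        return fmt.Errorf("read from upstream: %w", err) // transform the error
    })
    _, err := io.ReadAll(r)
    fmt.Println(err)
    // Output: read from upstream: boom
}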

// WriteCloser returns w as an io.WriteCloser. If w implements
// io.WriteCloser, w is returned. Otherwise, w is wrapped in a
// no-op decorator that implements io.WriteCloser.
@@ -458,10 +540,35 @@ func DirExists(dir string) bool {
    return fi.IsDir()
}

// Drain drains r.
func Drain(r io.Reader) error {
    _, err := io.Copy(io.Discard, r)
    return err
// DrainClose drains rc, returning the number of bytes read, and any error.
// The reader is always closed, even if the drain operation returned an error.
// If both the drain and the close operations return non-nil errors, the drain
// error is returned.
func DrainClose(rc io.ReadCloser) (n int, err error) {
    var n64 int64
    n64, err = io.Copy(io.Discard, rc)
    n = int(n64)

    closeErr := rc.Close()
    if err == nil {
        err = closeErr
    }
    return n, errz.Err(err)
}
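
// Illustrative usage sketch (not part of this commit): DrainClose covers the
// common "consume and close" pattern, e.g. for an HTTP response body. The URL
// is a placeholder.
func drainExample() error {
    resp, err := http.Get("https://example.com/data.csv")
    if err != nil {
        return err
    }
    // The body is closed even if the drain fails; if both drain and
    // close fail, the drain error wins.
    n, err := ioz.DrainClose(resp.Body)
    if err != nil {
        return err
    }
    fmt.Printf("drained %d bytes\n", n)
    return nil
}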

// Filesize returns the size of the file at fp. An error is returned
// if fp doesn't exist or is a directory.
func Filesize(fp string) (int64, error) {
    fi, err := os.Stat(fp)
    if err != nil {
        return 0, errz.Err(err)
    }

    if fi.IsDir() {
        return 0, errz.Errorf("not a file: %s", fp)
    }

    return fi.Size(), nil
}

// FileInfoEqual returns true if a and b are equal.
@@ -3,7 +3,9 @@ package files

import (
    "bytes"
    "context"
    "errors"
    "fmt"
    "io"
    "os"
    "path/filepath"
    "strings"
@@ -73,15 +75,44 @@ func (fs *Files) CacheDirFor(src *source.Source) (dir string, err error) {
}

// WriteIngestChecksum is invoked (after successful ingestion) to write the
// checksum for the ingest cache db.
// checksum of the source document file vs the ingest DB. Thus, if the source
// document changes, the checksum will no longer match, and the ingest DB
// will be considered invalid.
func (fs *Files) WriteIngestChecksum(ctx context.Context, src, backingSrc *source.Source) (err error) {
    fs.mu.Lock()
    defer fs.mu.Unlock()

    log := lg.FromContext(ctx)
    ingestFilePath, err := fs.filepath(src)
    if err != nil {
        return err
    }

    // Write the checksums file.
    if location.TypeOf(src.Location) == location.TypeHTTP {
        // If the source is remote, check if there was a download,
        // and if so, make sure it's completed.
        stream, ok := fs.streams[src.Handle]
        if ok {
            select {
            case <-stream.Filled():
            case <-stream.Done():
            case <-ctx.Done():
                return ctx.Err()
            }

            if err = stream.Err(); err != nil && !errors.Is(err, io.EOF) {
                return err
            }
        }

        // If we got this far, either there's no stream, or the stream is done,
        // which means that the download cache has been updated, and contains
        // the fresh cached body file that we'll use to calculate the checksum.
        // So, we'll go ahead and do the checksum stuff below.
    }

    // Now, we need to write a checksum file that contains the computed checksum
    // value from ingestFilePath.
    var sum checksum.Checksum
    if sum, err = checksum.ForFile(ingestFilePath); err != nil {
        log.Warn("Failed to compute checksum for source file; caching not in effect",
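
// Illustrative sketch (not part of this commit) of the read-side counterpart:
// a cached ingest DB is treated as valid only while the checksum recomputed
// from the source document matches the stored value. checksum.ForFile appears
// in this diff; cachedSum stands in for however the checksums file is read
// back, and checksum.Checksum is assumed to be a comparable string type.
func ingestCacheValid(srcFilePath string, cachedSum checksum.Checksum) bool {
    sum, err := checksum.ForFile(srcFilePath)
    if err != nil {
        return false // can't verify, so treat the cache as invalid
    }
    // If the source document changed, the sums differ and the
    // ingest DB must be rebuilt.
    return sum == cachedSum
}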
@@ -110,9 +141,9 @@ func (fs *Files) CachedBackingSourceFor(ctx context.Context, src *source.Source)
    defer fs.mu.Unlock()

    switch location.TypeOf(src.Location) {
    case location.TypeLocalFile:
    case location.TypeFile:
        return fs.cachedBackingSourceForFile(ctx, src)
    case location.TypeRemoteFile:
    case location.TypeHTTP:
        return fs.cachedBackingSourceForRemoteFile(ctx, src)
    default:
        return nil, false, errz.Errorf("caching not applicable for source: %s", src.Handle)
@@ -117,7 +117,7 @@ func (fs *Files) detectType(ctx context.Context, handle, loc string) (typ driver
    start := time.Now()

    var newRdrFn NewReaderFunc
    if location.TypeOf(loc) == location.TypeLocalFile {
    if location.TypeOf(loc) == location.TypeFile {
        newRdrFn = func(ctx context.Context) (io.ReadCloser, error) {
            return errz.Return(os.Open(loc))
        }
@@ -2,19 +2,15 @@ package files

import (
    "context"
    "io"
    "net/http"
    "path/filepath"
    "time"

    "github.com/neilotoole/streamcache"

    "github.com/neilotoole/sq/libsq/core/errz"
    "github.com/neilotoole/sq/libsq/core/ioz"
    "github.com/neilotoole/sq/libsq/core/ioz/checksum"
    "github.com/neilotoole/sq/libsq/core/ioz/httpz"
    "github.com/neilotoole/sq/libsq/core/lg"
    "github.com/neilotoole/sq/libsq/core/lg/lgm"
    "github.com/neilotoole/sq/libsq/core/options"
    "github.com/neilotoole/sq/libsq/files/internal/downloader"
    "github.com/neilotoole/sq/libsq/source"
@@ -108,78 +104,19 @@ func (fs *Files) maybeStartDownload(ctx context.Context, src *source.Source, che
        }
    }

    // Having got this far, we need to talk to the downloader.
    var (
        dlErrCh    = make(chan error, 1)
        dlStreamCh = make(chan *streamcache.Stream, 1)
        dlFileCh   = make(chan string, 1)
    )

    // Our handler simply pushes the callback values into the channels, which
    // this main goroutine will select on at the bottom of the func. The call
    // to downloader.Get will be executed in a newly spawned goroutine below.
    h := downloader.Handler{
        Cached:   func(dlFile string) { dlFileCh <- dlFile },
        Uncached: func(dlStream *streamcache.Stream) { dlStreamCh <- dlStream },
        Error:    func(dlErr error) { dlErrCh <- dlErr },
    }

    go func() {
        // Spawn a goroutine to execute the download process.
        // The handler will be called before Get returns.
        cacheFile := dldr.Get(ctx, h)
        if cacheFile == "" {
            // Either the download failed, or cache update failed.
            return
        }

        // The download succeeded, and the cache was successfully updated.
        // We know that cacheFile exists now. If a stream was created (and
        // thus added to Files.streams), we can swap it out and instead
        // populate Files.downloadedFiles with the cacheFile. Thus, going
        // forward, any clients of Files will get the cacheFile instead of
        // the stream.

        // We need to lock here because we're accessing Files.streams.
        // So, this goroutine will block until the lock is available.
        // That shouldn't be an issue: the up-stack Files function that
        // acquired the lock will eventually return, releasing the lock,
        // at which point the swap will happen. No big deal.
        fs.mu.Lock()
        defer fs.mu.Unlock()

        if stream, ok := fs.streams[src.Handle]; ok && stream != nil {
            // The stream exists, and it's safe to close the stream's source
            // (i.e. the http response body), because the body has already
            // been completely drained by the downloader: otherwise, we
            // wouldn't have a non-empty value for cacheFile.
            if c, ok := stream.Source().(io.Closer); ok {
                lg.WarnIfCloseError(lg.FromContext(ctx), lgm.CloseHTTPResponseBody, c)
            }
        }

        // Now perform the swap: populate Files.downloadedFiles with cacheFile,
        // and remove the stream from Files.streams.
        fs.downloadedFiles[src.Handle] = cacheFile
        delete(fs.streams, src.Handle)
    }() // end of goroutine

    // Here we wait on the handler channels.
    select {
    case <-ctx.Done():
        return "", nil, errz.Err(ctx.Err())
    case err = <-dlErrCh:
    dlFile, dlStream, err = dldr.Get(ctx)
    switch {
    case err != nil:
        return "", nil, err
    case dlStream = <-dlStreamCh:
        // New download stream. Add it to Files.streams,
    case dlFile != "":
        // The file is already on disk, so we can just return it.
        fs.downloadedFiles[src.Handle] = dlFile
        return dlFile, nil, nil
    default:
        // A new download stream was created. Add it to Files.streams,
        // and return the stream.
        fs.streams[src.Handle] = dlStream
        return "", dlStream, nil
    case dlFile = <-dlFileCh:
        // The file is already on disk, so we added it to Files.downloadedFiles,
        // and return its filepath.
        fs.downloadedFiles[src.Handle] = dlFile
        return dlFile, nil, nil
    }
}
@@ -192,7 +129,7 @@ func (fs *Files) downloadPaths(src *source.Source) (dlDir, dlFile string, err er
        return "", dlFile, err
    }

    // Note: we're depending on internal knowledge of the downloader impl here,
    // Note: we depend on internal knowledge of the downloader impl here,
    // which is not great. It might be better to implement a function
    // in pkg downloader.
    dlDir = filepath.Join(cacheDir, "download", checksum.Sum([]byte(src.Location)))
@@ -11,6 +11,8 @@ import (
    "os"
    "sync"

    "github.com/neilotoole/sq/libsq/core/ioz"

    "github.com/neilotoole/streamcache"

    "github.com/neilotoole/sq/libsq/core/cleanup"
@@ -105,7 +107,7 @@ func New(ctx context.Context, optReg *options.Registry, cfgLock lockfile.LockFun
// An error is returned if src is not a document/file source.
func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64, err error) {
    switch location.TypeOf(src.Location) {
    case location.TypeLocalFile:
    case location.TypeFile:
        var fi os.FileInfo
        if fi, err = os.Stat(src.Location); err != nil {
            return 0, errz.Err(err)

@@ -126,7 +128,7 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
        }
        return int64(total), nil

    case location.TypeRemoteFile:
    case location.TypeHTTP:
        fs.mu.Lock()

        // First check if the file is already downloaded

@@ -134,12 +136,8 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
        dlFile, ok := fs.downloadedFiles[src.Handle]
        if ok {
            // The file is already downloaded.
            fs.mu.Unlock()
            var fi os.FileInfo
            if fi, err = os.Stat(dlFile); err != nil {
                return 0, errz.Err(err)
            }
            return fi.Size(), nil
            defer fs.mu.Unlock()
            return ioz.Filesize(dlFile)
        }

        // It's not in the list of downloaded files, so

@@ -147,7 +145,9 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
        dlStream, ok := fs.streams[src.Handle]
        if ok {
            fs.mu.Unlock()

            var total int
            // Block until the download completes.
            if total, err = dlStream.Total(ctx); err != nil {
                return 0, err
            }

@@ -155,17 +155,17 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
        }

        // Finally, we turn to the downloader.
        defer fs.mu.Unlock()
        var dl *downloader.Downloader
        dl, err = fs.downloaderFor(ctx, src)
        fs.mu.Unlock()
        if err != nil {
        if dl, err = fs.downloaderFor(ctx, src); err != nil {
            return 0, err
        }

        // dl.Filesize will fail if the file has not been downloaded yet, which
        // means that the source has not been ingested; but Files.Filesize should
        // not have been invoked before ingestion.
        return dl.Filesize(ctx)
        if dlFile, err = dl.CacheFile(ctx); err != nil {
            return 0, err
        }

        return ioz.Filesize(dlFile)

    case location.TypeSQL:
        // Should be impossible.
@@ -202,9 +202,9 @@ func (fs *Files) AddStdin(ctx context.Context, f *os.File) error {
// file exists.
func (fs *Files) filepath(src *source.Source) (string, error) {
    switch location.TypeOf(src.Location) {
    case location.TypeLocalFile:
    case location.TypeFile:
        return src.Location, nil
    case location.TypeRemoteFile:
    case location.TypeHTTP:
        _, dlFile, err := fs.downloadPaths(src)
        if err != nil {
            return "", err

@@ -249,7 +249,7 @@ func (fs *Files) newReader(ctx context.Context, src *source.Source, finalRdr boo
        return nil, errz.Errorf("unknown source location type: %s", loc)
    case location.TypeSQL:
        return nil, errz.Errorf("invalid to read SQL source: %s", loc)
    case location.TypeLocalFile:
    case location.TypeFile:
        return errz.Return(os.Open(loc))
    case location.TypeStdin:
        stdinStream, ok := fs.streams[source.StdinHandle]

@@ -313,13 +313,13 @@ func (fs *Files) Ping(ctx context.Context, src *source.Source) error {
    case location.TypeStdin:
        // Stdin is always available.
        return nil
    case location.TypeLocalFile:
    case location.TypeFile:
        if _, err := os.Stat(src.Location); err != nil {
            return errz.Wrapf(err, "ping: failed to stat file source %s: %s", src.Handle, src.Location)
        }
        return nil

    case location.TypeRemoteFile:
    case location.TypeHTTP:
        req, err := http.NewRequestWithContext(ctx, http.MethodHead, src.Location, nil)
        if err != nil {
            return errz.Wrapf(err, "ping: %s", src.Handle)
@@ -12,7 +12,6 @@ import (
    "github.com/stretchr/testify/require"
    "golang.org/x/sync/errgroup"

    "github.com/neilotoole/sq/libsq/core/ioz"
    "github.com/neilotoole/sq/libsq/core/lg"
    "github.com/neilotoole/sq/libsq/core/lg/lgt"
    "github.com/neilotoole/sq/libsq/core/stringz"

@@ -102,11 +101,18 @@ func TestFiles_DriverType(t *testing.T) {
        t.Run(tu.Name(location.Redact(tc.loc)), func(t *testing.T) {
            ctx := lg.NewContext(context.Background(), lgt.New(t))

            fs, err := files.New(ctx, nil, testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true))
            fs, err := files.New(
                ctx,
                nil,
                testh.TempLockFunc(t),
                tu.TempDir(t, false),
                tu.CacheDir(t, false),
            )
            require.NoError(t, err)
            defer func() { assert.NoError(t, fs.Close()) }()
            fs.AddDriverDetectors(testh.DriverDetectors()...)

            gotType, gotErr := fs.DetectType(context.Background(), "@test_"+stringz.Uniq8(), tc.loc)
            gotType, gotErr := fs.DetectType(ctx, "@test_"+stringz.Uniq8(), tc.loc)
            if tc.wantErr {
                require.Error(t, gotErr)
                return

@@ -267,7 +273,8 @@ func TestFiles_Filesize(t *testing.T) {
    // Files.Filesize will block until the stream is fully read.
    r, err := fs.NewReader(th.Context, stdinSrc, false)
    require.NoError(t, err)
    require.NoError(t, ioz.Drain(r))
    _, err = io.Copy(io.Discard, r)
    require.NoError(t, err)

    gotSize2, err := fs.Filesize(th.Context, stdinSrc)
    require.NoError(t, err)
@@ -4,12 +4,13 @@ import (
    "bufio"
    "bytes"
    "context"
    "errors"
    "io"
    "net/http"
    "net/http/httputil"
    "os"
    "path/filepath"
    "time"
    "sync"

    "github.com/neilotoole/sq/libsq/core/errz"
    "github.com/neilotoole/sq/libsq/core/ioz"

@@ -17,7 +18,6 @@ import (
    "github.com/neilotoole/sq/libsq/core/ioz/contextio"
    "github.com/neilotoole/sq/libsq/core/lg"
    "github.com/neilotoole/sq/libsq/core/lg/lga"
    "github.com/neilotoole/sq/libsq/core/lg/lgm"
)

const (

@@ -229,8 +229,7 @@ func (c *cache) clear(ctx context.Context) error {
    recreateErr := ioz.RequireDir(c.dir)
    err := errz.Append(deleteErr, recreateErr)
    if err != nil {
        lg.FromContext(ctx).Error(msgDeleteCache,
            lga.Dir, c.dir, lga.Err, err)
        lg.FromContext(ctx).Error(msgDeleteCache, lga.Dir, c.dir, lga.Err, err)
        return err
    }
@@ -238,85 +237,235 @@ func (c *cache) clear(ctx context.Context) error {
    return nil
}

// write updates the cache. If headerOnly is true, only the header cache file
// is updated, and the function returns. Otherwise, the header and body
// cache files are updated, and a checksum file (computed from the body file)
// is also written to disk.
//
// If an error occurs while attempting to update the cache, any existing
// cache artifacts are left untouched. It's a sort of atomic-write-lite.
// To achieve this, cache files are first written to a staging dir, and that
// staging dir is only swapped with the main cache dir if there are no errors.
//
// The response body is always closed.
func (c *cache) write(ctx context.Context, resp *http.Response, headerOnly bool) (err error) {
    log := lg.FromContext(ctx)
    start := time.Now()
    var stagingDir string

    defer func() {
        lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)

        if stagingDir != "" && ioz.DirExists(stagingDir) {
            lg.WarnIfError(log, "Remove cache staging dir", os.RemoveAll(stagingDir))
        }
    }()
// writeHeader updates the main cache header file from resp. The response body
// is not written to the cache, nor is resp.Body closed.
func (c *cache) writeHeader(ctx context.Context, resp *http.Response) (err error) {
    header, err := httputil.DumpResponse(resp, false)
    if err != nil {
        return errz.Err(err)
    }

    mainDir := filepath.Join(c.dir, "main")
    if err = ioz.RequireDir(mainDir); err != nil {
        return err
    }

    stagingDir = filepath.Join(c.dir, "staging")
    if err = ioz.RequireDir(mainDir); err != nil {
    fp := filepath.Join(mainDir, "header")
    if _, err = ioz.WriteToFile(ctx, fp, bytes.NewReader(header)); err != nil {
        return err
    }

    headerBytes, err := httputil.DumpResponse(resp, false)
    lg.FromContext(ctx).Info("Updated download cache (header only)", lga.Dir, c.dir, lga.Resp, resp)
    return nil
}

// newResponseCacher returns a new responseCacher for resp.
// On return, resp.Body will be nil.
func (c *cache) newResponseCacher(ctx context.Context, resp *http.Response) (*responseCacher, error) {
    defer func() { resp.Body = nil }()

    stagingDir := filepath.Join(c.dir, "staging")

    if err := ioz.RequireDir(stagingDir); err != nil {
        _ = resp.Body.Close()
        return nil, err
    }

    header, err := httputil.DumpResponse(resp, false)
    if err != nil {
        return errz.Err(err)
        _ = resp.Body.Close()
        return nil, errz.Err(err)
    }

    fpHeaderStaging := filepath.Join(stagingDir, "header")
    if _, err = ioz.WriteToFile(ctx, fpHeaderStaging, bytes.NewReader(headerBytes)); err != nil {
        return err
    if _, err = ioz.WriteToFile(ctx, filepath.Join(stagingDir, "header"), bytes.NewReader(header)); err != nil {
        _ = resp.Body.Close()
        return nil, err
    }

    fpHeader, fpBody, _ := c.paths(resp.Request)
    if headerOnly {
        // It's only the header that we're changing, so we don't need to
        // swap the entire staging dir, just the header file.
        if err = os.Rename(fpHeaderStaging, fpHeader); err != nil {
            return errz.Wrap(err, "failed to move staging cache header file")
        }
    var f *os.File
    if f, err = os.Create(filepath.Join(stagingDir, "body")); err != nil {
        _ = resp.Body.Close()
        return nil, err
    }

    log.Info("Updated download cache (header only)", lga.Dir, c.dir, lga.Resp, resp)
    r := &responseCacher{
        stagingDir: stagingDir,
        mainDir:    filepath.Join(c.dir, "main"),
        body:       resp.Body,
        f:          f,
    }
    return r, nil
}

var _ io.ReadCloser = (*responseCacher)(nil)

// responseCacher is an io.ReadCloser that wraps an [http.Response.Body],
// appending bytes read via Read to a staging cache file. When Read receives
// [io.EOF] from the wrapped response body, the staging cache is promoted and
// replaces the main cache. If an error occurs during Read, the staging cache is
// discarded, and the main cache is left untouched. If an error occurs during
// cache promotion (which happens on receipt of io.EOF from resp.Body), the
// promotion error, not [io.EOF], is returned by Read. Thus, a consumer of
// responseCacher will not receive [io.EOF] unless the cache is successfully
// promoted.
type responseCacher struct {
    body       io.ReadCloser
    closeErr   *error
    f          *os.File
    mainDir    string
    stagingDir string
    mu         sync.Mutex
}

// Read implements [io.Reader]. It reads into p from the wrapped response body,
// appends the received bytes to the staging cache, and returns the number of
// bytes read, and any error. When Read encounters [io.EOF] from the response
// body, it promotes the staging cache to main, and on success returns [io.EOF].
// If an error occurs during cache promotion, Read returns that promotion error
// instead of [io.EOF].
func (r *responseCacher) Read(p []byte) (n int, err error) {
    r.mu.Lock()
    defer r.mu.Unlock()

    // Use r.body as a sentinel to indicate that the cache
    // has been closed.
    if r.body == nil {
        return 0, errz.New("response cache already closed")
    }

    n, err = r.body.Read(p)

    switch {
    case err == nil:
        err = r.cacheAppend(p, n)
        return n, err
    case !errors.Is(err, io.EOF):
        // It's some other kind of error.
        // Clean up and return.
        _ = r.body.Close()
        r.body = nil
        _ = r.f.Close()
        _ = os.Remove(r.f.Name())
        r.f = nil
        _ = os.RemoveAll(r.stagingDir)
        r.stagingDir = ""
        return n, err
    default:
        // It's EOF time! Let's promote the cache.
    }

    var err2 error
    if err2 = r.cacheAppend(p, n); err2 != nil {
        return n, err2
    }

    if err2 = r.cachePromote(); err2 != nil {
        return n, err2
    }

    return n, err
}

// cacheAppend appends n bytes from p to the staging cache. If an error occurs,
// the staging cache is discarded, and the error is returned.
func (r *responseCacher) cacheAppend(p []byte, n int) error {
    _, err := r.f.Write(p[:n])
    if err == nil {
        return nil
    }
    _ = r.body.Close()
    r.body = nil
    _ = r.f.Close()
    _ = os.Remove(r.f.Name())
    r.f = nil
    _ = os.RemoveAll(r.stagingDir)
    r.stagingDir = ""
    return errz.Wrap(err, "failed to append http response body bytes to staging cache")
}

    fpBodyStaging := filepath.Join(stagingDir, "body")
    var written int64
    if written, err = ioz.WriteToFile(ctx, fpBodyStaging, resp.Body); err != nil {
        log.Warn("Cache write: failed to write cache body file", lga.Err, err, lga.Path, fpBodyStaging)
        return err
// cachePromote is invoked by Read when it receives io.EOF from the wrapped
// response body. It promotes the staging cache to main, and on success returns
// nil. If an error occurs during promotion, the staging cache is discarded, and
// the promotion error is returned.
func (r *responseCacher) cachePromote() error {
    defer func() {
        if r.f != nil {
            _ = r.f.Close()
            r.f = nil
        }
        if r.body != nil {
            _ = r.body.Close()
            r.body = nil
        }
        if r.stagingDir != "" {
            _ = os.RemoveAll(r.stagingDir)
            r.stagingDir = ""
        }
    }()

    err := r.f.Close()
    fpBody := r.f.Name()
    r.f = nil
    if err != nil {
        return errz.Wrap(err, "failed to close cache body file")
    }

    sum, err := checksum.ForFile(fpBodyStaging)
    err = r.body.Close()
    r.body = nil
    if err != nil {
        return errz.Wrap(err, "failed to close http response body")
    }

    var sum checksum.Checksum
    if sum, err = checksum.ForFile(fpBody); err != nil {
        return errz.Wrap(err, "failed to compute checksum for cache body file")
    }

    if err = checksum.WriteFile(filepath.Join(stagingDir, "checksums.txt"), sum, "body"); err != nil {
    if err = checksum.WriteFile(filepath.Join(r.stagingDir, "checksums.txt"), sum, "body"); err != nil {
        return errz.Wrap(err, "failed to write checksum file for cache body")
    }

    // We've got good data in the staging dir. Now we do the switcheroo.
    if err = ioz.RenameDir(stagingDir, mainDir); err != nil {
    if err = ioz.RenameDir(r.stagingDir, r.mainDir); err != nil {
        return errz.Wrap(err, "failed to write download cache")
    }
    r.stagingDir = ""

    stagingDir = ""
    log.Info("Updated download cache (full)",
        lga.Written, written, lga.File, fpBody, lga.Elapsed, time.Since(start).Round(time.Millisecond))
    return nil
}

// Close implements [io.Closer]. Note that cache promotion happens when Read
// receives [io.EOF] from the wrapped response body, so the main action should
// be over by the time Close is invoked. Note that Close is idempotent, and
// returns the same error on subsequent invocations.
func (r *responseCacher) Close() error {
    r.mu.Lock()
    defer r.mu.Unlock()

    if r.closeErr != nil {
        // Already closed
        return *r.closeErr
    }

    // There's some duplication of logic in using both r.closeErr
    // and r.body == nil as sentinels. This could be cleaned up.

    var err error
    if r.f != nil {
        err = errz.Append(err, r.f.Close())
        r.f = nil
    }

    if r.body != nil {
        err = errz.Append(err, r.body.Close())
        r.body = nil
    }

    if r.stagingDir != "" {
        err = errz.Append(err, os.RemoveAll(r.stagingDir))
        r.stagingDir = ""
    }

    r.closeErr = &err
    return *r.closeErr
}
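
// Illustrative sketch (not part of this commit) of the intended wiring: the
// download path hands resp.Body to newResponseCacher, and consumers read from
// the cacher. io.EOF is only ever observed after the staging cache has been
// promoted to main. Names here are placeholders.
func cacheReadThrough(ctx context.Context, c *cache, resp *http.Response) error {
    rc, err := c.newResponseCacher(ctx, resp)
    if err != nil {
        return err
    }
    defer rc.Close()
    // Each Read tees bytes into the staging body file; on EOF the
    // staging dir atomically replaces the main cache dir.
    _, err = io.Copy(io.Discard, rc)
    return err
}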
@@ -104,9 +104,14 @@ func (s State) String() string {
const XFromCache = "X-From-Stream"

// Downloader encapsulates downloading a file from a URL, using a local
// disk cache if possible. Downloader.Get makes use of the Handler callback
// mechanism to facilitate early consumption of a download stream while the
// download is still in flight.
// disk cache if possible. Downloader.Get returns either a filepath to the
// already-downloaded file, or a stream of the download in progress, or an
// error. If a stream is returned, the Downloader cache is updated when the
// stream is fully consumed (this can be observed by the closing of the
// channel returned by [streamcache.Stream.Filled]).
//
// To be extra clear about that last point: the caller must consume the
// stream returned by Downloader.Get, or the cache will not be written.
type Downloader struct {
    // c is the HTTP client used to make requests.
    c *http.Client

@@ -115,10 +120,6 @@ type Downloader struct {
    // It will be created in dlDir.
    cache *cache

    // dlStream is the streamcache.Stream that is passed to Handler.Uncached for an
    // active download. This field is reset to nil on each call to Downloader.Get.
    dlStream *streamcache.Stream

    // name is a user-friendly name, such as a source handle like @data.
    name string

@@ -166,27 +167,28 @@ func New(name string, c *http.Client, dlURL, dlDir string) (*Downloader, error)
    return dl, nil
}

// Get attempts to get the remote file, invoking Handler as appropriate. Exactly
// one of the Handler methods will be invoked, one time.
// Get attempts to get the remote file, returning either the filepath of the
// already-cached file in dlFile, or a stream of a newly-started download in
// dlStream, or an error. Exactly one of the return values will be non-nil.
//
//   - If Handler.Uncached is invoked, a download stream has begun. Get will
//     then block until the download is completed. The download resp.Body is
//     written to cache, and on success, the filepath to the newly updated
//     cache file is returned.
//     If an error occurs during cache write, the error is logged, and Get
//     returns the filepath of the previously cached download, if permitted
//     by policy. If not permitted or not existing, empty string is returned.
//   - If Handler.Cached is invoked, Get returns immediately afterwards with
//     the filepath of the cached download (the same value provided to
//     Handler.Cached).
//   - If Handler.Error is invoked, there was an unrecoverable problem (e.g. a
//     transport error, and there's no previous cache) and the download is
//     unavailable. That error should be propagated up the stack. Get will
//     return empty string.
//   - If dlFile is non-empty, it is the filepath on disk of the cached download,
//     and dlStream and err are nil. However, depending on OptContinueOnError,
//     dlFile may be the path to a stale download. If the cache is stale and a
//     transport error occurs during refresh, and OptContinueOnError is true,
//     the previous cached download is returned. If OptContinueOnError is false,
//     the transport error is returned, and dlFile is empty. The caller can also
//     check the cache state via [Downloader.State].
//   - If dlStream is non-nil, it is a stream of the download in progress, and
//     dlFile is empty. The cache is updated when the stream has been completely
//     consumed. If the stream is not consumed, the cache is not updated. If an
//     error occurs reading from the stream, the cache is also not updated: this
//     means that the cache may still contain the previous (stale) download.
//   - If err is non-nil, there was an unrecoverable problem (e.g. a transport
//     error, and there's no previous cache) and the download is unavailable.
//
// Get consults the context for options. In particular, it makes
// use of OptCache and OptContinueOnError.
func (dl *Downloader) Get(ctx context.Context, h Handler) (cacheFile string) {
// Get consults the context for options. In particular, it makes use of OptCache
// and OptContinueOnError.
func (dl *Downloader) Get(ctx context.Context) (dlFile string, dlStream *streamcache.Stream, err error) {
    dl.mu.Lock()
    defer dl.mu.Unlock()
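
// Illustrative sketch (not part of this commit) of consuming Get's three-way
// result: exactly one of dlFile, dlStream, or err is set, and a returned
// stream must be fully drained or the cache won't be updated.
func getExample(ctx context.Context, dl *Downloader) (string, error) {
    dlFile, dlStream, err := dl.Get(ctx)
    switch {
    case err != nil:
        return "", err
    case dlFile != "":
        return dlFile, nil // already cached on disk
    default:
        r := dlStream.NewReader(ctx)
        dlStream.Seal() // no more readers will be attached
        if _, err = io.Copy(io.Discard, r); err != nil {
            return "", err
        }
        // The cache was promoted when the stream hit EOF.
        return dl.CacheFile(ctx)
    }
}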
@@ -198,19 +200,16 @@ func (dl *Downloader) Get(ctx context.Context, h Handler) (cacheFile string) {

    req := dl.mustRequest(ctx)
    lg.FromContext(ctx).Debug("Get download", lga.URL, dl.url)
    cacheFile = dl.get(req, h)
    return cacheFile
    return dl.get(req)
}

// get contains the main logic for getting the download.
// It invokes Handler as appropriate, and on success returns the
// filepath of the valid cached download file.
func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //nolint:gocognit,funlen,cyclop
func (dl *Downloader) get(req *http.Request) (dlFile string, //nolint:gocognit,funlen,cyclop
    dlStream *streamcache.Stream, err error,
) {
    ctx := req.Context()
    log := lg.FromContext(ctx)

    dl.dlStream = nil

    var fpBody string
    if dl.cache != nil {
        _, fpBody, _ = dl.cache.paths(req)

@@ -219,12 +218,10 @@ func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //n
    state := dl.state(req)
    if state == Fresh && fpBody != "" {
        // The cached response is fresh, so we can return it.
        h.Cached(fpBody)
        return fpBody
        return fpBody, nil, nil
    }

    cacheable := dl.isCacheable(req)
    var err error
    var cachedResp *http.Response
    if cacheable {
        cachedResp, err = dl.cache.get(req.Context(), req) //nolint:bodyclose

@@ -241,8 +238,7 @@ func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //n
    freshness := getFreshness(cachedResp.Header, req.Header)
    if freshness == Fresh && fpBody != "" {
        lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cachedResp.Body)
        h.Cached(fpBody)
        return fpBody
        return fpBody, nil, nil
    }

    if freshness == Stale {

@@ -279,12 +275,11 @@ func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //n

    case fpBody != "" && (err != nil || resp.StatusCode >= 500) &&
        req.Method == http.MethodGet && canStaleOnError(cachedResp.Header, req.Header):
        // In case of transport failure and stale-if-error activated, returns cached content
        // when available.
        // In case of transport failure where canStaleOnError returns true,
        // return the stale cached download.
        lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cachedResp.Body)
        log.Warn("Returning cached response due to transport failure", lga.Err, err)
        h.Cached(fpBody)
        return fpBody
        return fpBody, nil, nil

    default:
        if err != nil && resp != nil && resp.StatusCode != http.StatusOK {

@@ -293,46 +288,41 @@ func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //n

        if fp := dl.cacheFileOnError(req, err); fp != "" {
            lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
            h.Cached(fp)
            return fp
            return fp, nil, nil
        }
    }

    if err != nil {
        lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
        if fp := dl.cacheFileOnError(req, err); fp != "" {
            h.Cached(fp)
            return fp
            return fp, nil, nil
        }
        h.Error(err)
        return ""

        return "", nil, err
    }

    if resp.StatusCode != http.StatusOK {
        lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
        err = errz.Errorf("download: unexpected HTTP status: %s", httpz.StatusText(resp.StatusCode))
        if fp := dl.cacheFileOnError(req, err); fp != "" {
            h.Cached(fp)
            return fp
            return fp, nil, nil
        }

        h.Error(err)
        return ""
        return "", nil, err
    }
    }
    } else {
        reqCacheControl := parseCacheControl(req.Header)
        if _, ok := reqCacheControl["only-if-cached"]; ok {
            resp = newGatewayTimeoutResponse(req) //nolint:bodyclose
            resp = newGatewayTimeoutResponse(req)
        } else {
            resp, err = dl.do(req) //nolint:bodyclose
            if err != nil {
                if fp := dl.cacheFileOnError(req, err); fp != "" {
                    h.Cached(fp)
                    return fp
                    return fp, nil, nil
                }
                h.Error(err)
                return ""

                return "", nil, err
            }
        }
    }
@@ -348,59 +338,55 @@ func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //n
    }

    if resp == cachedResp {
        err = dl.cache.write(ctx, resp, true)
        err = dl.cache.writeHeader(ctx, resp)
        lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
        if err != nil {
            log.Error("Failed to update cache header", lga.Dir, dl.cache.dir, lga.Err, err)
            if fp := dl.cacheFileOnError(req, err); fp != "" {
                h.Cached(fp)
                return fp
                return fp, nil, nil
            }
            h.Error(err)
            return ""

            return "", nil, err
        }

        if fpBody != "" {
            h.Cached(fpBody)
            return fpBody
            return fpBody, nil, nil
        }
    } else if cachedResp != nil && cachedResp.Body != nil {
        lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cachedResp.Body)
    }

    dl.dlStream = streamcache.New(resp.Body)
    resp.Body = dl.dlStream.NewReader(ctx)
    h.Uncached(dl.dlStream)

    if err = dl.cache.write(req.Context(), resp, false); err != nil {
        // We don't explicitly call Handler.Error: it would be "illegal" to do so
        // anyway, because the Handler docs state that at most one Handler callback
        // func is ever invoked.
        //
        // The cache write could fail for one of two reasons:
        //
        // - The download didn't complete successfully: that is, there was an error
        //   reading from resp.Body. In this case, that same error will be propagated
        //   to the Handler via the streamcache.Stream that was provided to Handler.Uncached.
        // - The download completed, but there was a problem writing out the cache
        //   files (header, body, checksum). This is likely a very rare occurrence.
        //   In that case, any previous cache files are left untouched by cache.write,
        //   and all we do is log the error. If the cache is inconsistent, it will
        //   repair itself on next invocation, so it's not a big deal.
        log.Warn("Failed to write download cache", lga.Dir, dl.cache.dir, lga.Err, err)
        lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
        return ""
    // OK, this is where the funky stuff happens.
    //
    // First, note that the cache is two-stage: there's a staging cache, and a
    // main cache. The staging cache is used to write the response body to disk,
    // and when the response body is fully consumed, the staging cache is
    // promoted to main, in a sort of atomic-swap-lite. This is done to avoid
    // partially-written cache files in the main cache, and other such nastiness.
    //
    // The responseCacher type is an io.ReadCloser that wraps the response body.
    // As its Read method is called, it writes the body bytes to a staging cache
    // file (and also returns them to the caller). When rCacher encounters io.EOF,
    // it promotes the staging cache to main before returning io.EOF to the
    // caller. If promotion fails, the promotion error (not io.EOF) is returned
    // to the caller. Thus, it is guaranteed that any caller of rCacher's Read
    // method will only receive io.EOF if the cache has been promoted to main.
    var rCacher *responseCacher
    if rCacher, err = dl.cache.newResponseCacher(ctx, resp); err != nil {
        return "", nil, err
    }
    lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
    return fpBody

    // And now we wrap rCacher in a streamcache: any streamcache readers will
    // only receive io.EOF when/if the staging cache has been promoted to main.
    dlStream = streamcache.New(rCacher)
    return "", dlStream, nil
}

    // It's not cacheable, so we can just wrap resp.Body in a streamcache
    // and return it.
    dl.dlStream = streamcache.New(resp.Body)
    // The response is not cacheable, so we can just wrap resp.Body in a
    // streamcache and return it.
    dlStream = streamcache.New(resp.Body)
    resp.Body = nil // Unnecessary, but just to be explicit.
    h.Uncached(dl.dlStream)
    return ""
    return "", dlStream, nil
}

// do executes the request.
@@ -494,42 +480,12 @@ func (dl *Downloader) state(req *http.Request) State {
    return getFreshness(cachedResp.Header, req.Header)
}

// Filesize returns the size of the downloaded file. This should
// only be invoked after the download has completed or is cached,
// as it may block until the download completes.
func (dl *Downloader) Filesize(ctx context.Context) (int64, error) {
    dl.mu.Lock()
    defer dl.mu.Unlock()

    if dl.dlStream != nil {
        // There's an active download, so we can get the filesize
        // when the download completes.
        size, err := dl.dlStream.Total(ctx)
        return int64(size), err
    }

    if dl.cache == nil {
        return 0, errz.New("download file size not available")
    }

    req := dl.mustRequest(ctx)
    if !dl.cache.exists(req) {
        // It's not in the cache.
        return 0, errz.New("download file size not available")
    }

    // It's in the cache.
    _, fp, _ := dl.cache.paths(req)
    fi, err := os.Stat(fp)
    if err != nil {
        return 0, errz.Wrapf(err, "unable to stat cached download file: %s", fp)
    }

    return fi.Size(), nil
}

// CacheFile returns the path to the cached file, if it exists and has
// been fully downloaded.
// CacheFile returns the path to the cached file, if it exists. If there's
// a download in progress ([Downloader.Get] returned a stream), then CacheFile
// may return the filepath to the previously cached file. The caller should
// wait on any previously returned download stream to complete to ensure
// that the returned filepath is that of the current download. The caller
// can also check the cache state via [Downloader.State].
func (dl *Downloader) CacheFile(ctx context.Context) (fp string, err error) {
    dl.mu.Lock()
    defer dl.mu.Unlock()
@@ -12,23 +12,101 @@ import (
    "testing"
    "time"

    "github.com/neilotoole/streamcache"

    "github.com/neilotoole/sq/libsq/core/ioz"
    "github.com/neilotoole/sq/libsq/core/lg/lga"
    "github.com/neilotoole/sq/libsq/core/stringz"
    "github.com/neilotoole/sq/testh/sakila"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"

    "github.com/neilotoole/sq/libsq/core/ioz"
    "github.com/neilotoole/sq/libsq/core/ioz/httpz"
    "github.com/neilotoole/sq/libsq/core/lg"
    "github.com/neilotoole/sq/libsq/core/lg/lga"
    "github.com/neilotoole/sq/libsq/core/lg/lgt"
    "github.com/neilotoole/sq/libsq/core/options"
    "github.com/neilotoole/sq/libsq/core/stringz"
    "github.com/neilotoole/sq/libsq/files"
    "github.com/neilotoole/sq/libsq/files/internal/downloader"
    "github.com/neilotoole/sq/testh/sakila"
    "github.com/neilotoole/sq/testh/tu"
)

func TestDownload_redirect(t *testing.T) {
func TestDownloader(t *testing.T) {
    const dlURL = sakila.ActorCSVURL
    log := lgt.New(t)
    ctx := lg.NewContext(context.Background(), log)

    cacheDir := t.TempDir()

    dl, gotErr := downloader.New(t.Name(), httpz.NewDefaultClient(), dlURL, cacheDir)
    require.NoError(t, gotErr)
    require.NoError(t, dl.Clear(ctx))
    require.Equal(t, downloader.Uncached, dl.State(ctx))
    sum, ok := dl.Checksum(ctx)
    require.False(t, ok)
    require.Empty(t, sum)

    // Here's our first download; it's not cached, so we should
    // get a stream.
    gotFile, gotStream, gotErr := dl.Get(ctx)
    require.NoError(t, gotErr)
    require.Empty(t, gotFile)
    require.NotNil(t, gotStream)
    require.Equal(t, downloader.Uncached, dl.State(ctx))

    // The stream should not be filled yet, because we
    // haven't read from it.
    tu.RequireNoTake(t, gotStream.Filled())

    // And there should be no cache file, because the cache file
    // isn't written until the stream is drained.
    gotFile, gotErr = dl.CacheFile(ctx)
    require.Error(t, gotErr)
    require.Empty(t, gotFile)

    r := gotStream.NewReader(ctx)
    gotStream.Seal()

    // Now we drain the stream, and the cache should magically fill.
    var gotN int
    gotN, gotErr = ioz.DrainClose(r)
    require.NoError(t, gotErr)
    require.Equal(t, sakila.ActorCSVSize, gotN)
    tu.RequireTake(t, gotStream.Filled())
    tu.RequireTake(t, gotStream.Done())
    require.Equal(t, sakila.ActorCSVSize, gotStream.Size())
    require.Equal(t, downloader.Fresh, dl.State(ctx))

    // Now we should be able to access the cache.
    sum, ok = dl.Checksum(ctx)
    require.True(t, ok)
    require.NotEmpty(t, sum)
    gotFile, gotErr = dl.CacheFile(ctx)
    require.NoError(t, gotErr)
    require.NotEmpty(t, gotFile)
    gotSize, gotErr := ioz.Filesize(gotFile)
    require.NoError(t, gotErr)
    require.Equal(t, sakila.ActorCSVSize, int(gotSize))

    // Let's download again, and verify that the cache is used.
    gotFile, gotStream, gotErr = dl.Get(ctx)
    require.Nil(t, gotErr)
    require.Nil(t, gotStream)
    require.NotEmpty(t, gotFile)

    gotFileBytes, gotErr := os.ReadFile(gotFile)
    require.NoError(t, gotErr)
    require.Equal(t, sakila.ActorCSVSize, len(gotFileBytes))
    require.Equal(t, downloader.Fresh, dl.State(ctx))
    sum, ok = dl.Checksum(ctx)
    require.True(t, ok)
    require.NotEmpty(t, sum)

    require.NoError(t, dl.Clear(ctx))
    require.Equal(t, downloader.Uncached, dl.State(ctx))
    sum, ok = dl.Checksum(ctx)
    require.False(t, ok)
    require.Empty(t, sum)
}

func TestDownloader_redirect(t *testing.T) {
    const hello = `Hello World!`
    serveBody := hello
    lastModified := time.Now().UTC()
@@ -81,100 +159,28 @@ func TestDownload_redirect(t *testing.T) {
    dl, err := downloader.New(t.Name(), httpz.NewDefaultClient(), loc, cacheDir)
    require.NoError(t, err)
    require.NoError(t, dl.Clear(ctx))
    h := downloader.NewSinkHandler(log.With("origin", "handler"))

    gotGetFile := dl.Get(ctx, h.Handler)
    require.Empty(t, h.Errors)
    require.NotEmpty(t, gotGetFile)
    gotBody := tu.ReadToString(t, h.Streams[0].NewReader(ctx))
    gotFile, gotStream, gotErr := dl.Get(ctx)
    require.NoError(t, gotErr)
    require.NotNil(t, gotStream)
    require.Empty(t, gotFile)
    gotBody := tu.ReadToString(t, gotStream.NewReader(ctx))
    require.Equal(t, hello, gotBody)

    h.Reset()
    gotGetFile = dl.Get(ctx, h.Handler)
    require.NotEmpty(t, gotGetFile)
    require.Empty(t, h.Errors)
    require.Empty(t, h.Streams)
    gotDownloadedFile := h.Downloaded[0]
    t.Logf("got fp: %s", gotDownloadedFile)
    gotBody = tu.ReadFileToString(t, gotDownloadedFile)
    gotFile, gotStream, gotErr = dl.Get(ctx)
    require.NoError(t, gotErr)
    require.Nil(t, gotStream)
    require.NotEmpty(t, gotFile)
    t.Logf("got fp: %s", gotFile)
    gotBody = tu.ReadFileToString(t, gotFile)
    t.Logf("got body: \n\n%s\n\n", gotBody)
    require.Equal(t, serveBody, gotBody)

    h.Reset()
    gotGetFile = dl.Get(ctx, h.Handler)
    require.NotEmpty(t, gotGetFile)
    require.Empty(t, h.Errors)
    require.Empty(t, h.Streams)
    gotDownloadedFile = h.Downloaded[0]
    t.Logf("got fp: %s", gotDownloadedFile)
    gotBody = tu.ReadFileToString(t, gotDownloadedFile)
    t.Logf("got body: \n\n%s\n\n", gotBody)
    require.Equal(t, serveBody, gotBody)
}

func TestDownload_New(t *testing.T) {
    log := lgt.New(t)
    ctx := lg.NewContext(context.Background(), log)
    const dlURL = sakila.ActorCSVURL

    cacheDir := t.TempDir()

    dl, err := downloader.New(t.Name(), httpz.NewDefaultClient(), dlURL, cacheDir)
    require.NoError(t, err)
    require.NoError(t, dl.Clear(ctx))
    require.Equal(t, downloader.Uncached, dl.State(ctx))
    sum, ok := dl.Checksum(ctx)
    require.False(t, ok)
    require.Empty(t, sum)

    h := downloader.NewSinkHandler(log.With("origin", "handler"))
    gotGetFile := dl.Get(ctx, h.Handler)
    require.NotEmpty(t, gotGetFile)
    require.Empty(t, h.Errors)
    require.Empty(t, h.Downloaded)
    require.Equal(t, 1, len(h.Streams))
    require.Equal(t, int64(sakila.ActorCSVSize), int64(h.Streams[0].Size()))
    require.Equal(t, downloader.Fresh, dl.State(ctx))
    sum, ok = dl.Checksum(ctx)
    require.True(t, ok)
    require.NotEmpty(t, sum)

    h.Reset()
    gotGetFile = dl.Get(ctx, h.Handler)
    require.NotEmpty(t, gotGetFile)
    require.Empty(t, h.Errors)
    require.Empty(t, h.Streams)
    require.NotEmpty(t, h.Downloaded)
    require.Equal(t, gotGetFile, h.Downloaded[0])
    gotFileBytes, err := os.ReadFile(h.Downloaded[0])
    require.NoError(t, err)
    require.Equal(t, sakila.ActorCSVSize, len(gotFileBytes))
    require.Equal(t, downloader.Fresh, dl.State(ctx))
    sum, ok = dl.Checksum(ctx)
    require.True(t, ok)
    require.NotEmpty(t, sum)

    require.NoError(t, dl.Clear(ctx))
    require.Equal(t, downloader.Uncached, dl.State(ctx))
    sum, ok = dl.Checksum(ctx)
    require.False(t, ok)
    require.Empty(t, sum)

    h.Reset()
    gotGetFile = dl.Get(ctx, h.Handler)
    require.Empty(t, h.Errors)
    require.NotEmpty(t, gotGetFile)
    require.True(t, ioz.FileAccessible(gotGetFile))

    h.Reset()
}

func TestCachePreservedOnFailedRefresh(t *testing.T) {
    o := options.Options{files.OptHTTPResponseTimeout.Key(): "10m"}
    ctx := options.NewContext(context.Background(), o)
    ctx := lg.NewContext(context.Background(), lgt.New(t))

    var (
        log                 = lgt.New(t)
        srvr                *httptest.Server
        srvrShouldBodyError bool
        srvrShouldNoCache   bool
@@ -216,32 +222,34 @@ func TestCachePreservedOnFailedRefresh(t *testing.T) {
    }))
    t.Cleanup(srvr.Close)

    ctx = lg.NewContext(ctx, log.With("origin", "downloader"))
    cacheDir := filepath.Join(t.TempDir(), stringz.UniqSuffix("dlcache"))
    dl, err := downloader.New(t.Name(), httpz.NewDefaultClient(), srvr.URL, cacheDir)
    require.NoError(t, err)
    require.NoError(t, dl.Clear(ctx))
    h := downloader.NewSinkHandler(log.With("origin", "handler"))

    gotGetFile := dl.Get(ctx, h.Handler)
    require.Empty(t, h.Errors)
    require.NotEmpty(t, h.Streams)
    require.NotEmpty(t, gotGetFile, "cache file should have been filled")
    require.True(t, ioz.FileAccessible(gotGetFile))

    stream := h.Streams[0]
    start := time.Now()
    var gotFile string
    var gotStream *streamcache.Stream
    gotFile, gotStream, err = dl.Get(ctx)
    require.NoError(t, err)
    require.Empty(t, gotFile)
    require.NotNil(t, gotStream)
    tu.RequireNoTake(t, gotStream.Filled())
    r := gotStream.NewReader(ctx)
    gotStream.Seal()
    t.Logf("Waiting for download to complete")
    <-stream.Filled()
    start := time.Now()

    var gotN int
    gotN, err = ioz.DrainClose(r)
    require.NoError(t, err)
    t.Logf("Download completed after %s", time.Since(start))
    require.True(t, errors.Is(stream.Err(), io.EOF))
    gotSize, err := dl.Filesize(ctx)
    tu.RequireTake(t, gotStream.Filled())
    tu.RequireTake(t, gotStream.Done())
    require.Equal(t, len(sentBody), gotN)

    require.True(t, errors.Is(gotStream.Err(), io.EOF))
    require.NoError(t, err)
    require.Equal(t, len(sentBody), int(gotSize))
    require.Equal(t, len(sentBody), stream.Size())
    gotFilesize, err := dl.Filesize(ctx)
    require.NoError(t, err)
    require.Equal(t, len(sentBody), int(gotFilesize))
    require.Equal(t, len(sentBody), gotStream.Size())

    fpBody, err := dl.CacheFile(ctx)
    require.NoError(t, err)
@@ -256,23 +264,31 @@ func TestCachePreservedOnFailedRefresh(t *testing.T) {
    fiChecksums1 := tu.MustStat(t, fpChecksums)

    srvrShouldBodyError = true
    h.Reset()

    // Sleep to allow file modification timestamps to tick
    time.Sleep(time.Millisecond * 10)
    gotGetFile = dl.Get(ctx, h.Handler)
    require.Empty(t, h.Errors)
    require.Empty(t, gotGetFile,
        "gotCacheFile should be empty, because the server returned an error during cache write")
    require.NotEmpty(t, h.Streams,
        "h.Streams should not be empty, because the download was in fact initiated")
    stream = h.Streams[0]
    <-stream.Filled()
    err = stream.Err()
    gotFile, gotStream, err = dl.Get(ctx)
    require.NoError(t, err)
    require.Empty(t, gotFile,
        "gotFile should be empty, because the server returned an error during cache write")
    require.NotNil(t, gotStream,
        "gotStream should not be empty, because the download was in fact initiated")

    r = gotStream.NewReader(ctx)
    gotStream.Seal()

    gotN, err = ioz.DrainClose(r)
    require.Equal(t, len(sentBody), gotN)
    tu.RequireTake(t, gotStream.Filled())
    tu.RequireTake(t, gotStream.Done())

    streamErr := gotStream.Err()
    require.Error(t, streamErr)
    require.True(t, errors.Is(err, streamErr))
    t.Logf("got stream err: %v", err)
    require.Error(t, err)

    require.True(t, errors.Is(err, io.ErrUnexpectedEOF))
    require.Equal(t, len(sentBody), stream.Size())
    require.Equal(t, len(sentBody), gotStream.Size())

    // Verify that the server hasn't updated the cache,
    // by checking that the file modification timestamps

@@ -284,6 +300,4 @@ func TestCachePreservedOnFailedRefresh(t *testing.T) {
    require.True(t, ioz.FileInfoEqual(fiBody1, fiBody2))
    require.True(t, ioz.FileInfoEqual(fiHeader1, fiHeader2))
    require.True(t, ioz.FileInfoEqual(fiChecksums1, fiChecksums2))

    h.Reset()
}
@@ -1,6 +1,9 @@
// Package location contains functionality related to source location.
package location

// NOTE: This package contains code from several eras. There's a bunch of
// overlap and duplication. It should be consolidated.

import (
    "net/url"
    "path"
@@ -327,11 +330,11 @@ func isFpath(loc string) (fpath string, ok bool) {
type Type string

const (
    TypeStdin      = "stdin"
    TypeLocalFile  = "local_file"
    TypeSQL        = "sql"
    TypeRemoteFile = "remote_file"
    TypeUnknown    = "unknown"
    TypeStdin   = "stdin"
    TypeFile    = "local_file"
    TypeSQL     = "sql"
    TypeHTTP    = "http_file"
    TypeUnknown = "unknown"
)

// TypeOf returns the type of loc, or locTypeUnknown if it
@@ -345,14 +348,14 @@ func TypeOf(loc string) Type {
        return TypeSQL
    case strings.HasPrefix(loc, "http://"),
        strings.HasPrefix(loc, "https://"):
        return TypeRemoteFile
        return TypeHTTP
    default:
    }

    if _, err := filepath.Abs(loc); err != nil {
        return TypeUnknown
    }
    return TypeFile
}
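
// Illustrative sketch (not part of this commit) of how TypeOf classifies
// locations after this rename; the postgres line assumes a SQL-scheme
// location as handled by the TypeSQL case above.
//
//	location.TypeOf("https://example.com/data.csv")       // TypeHTTP
//	location.TypeOf("/tmp/data.csv")                      // TypeFile
//	location.TypeOf("postgres://sakila@localhost/sakila") // TypeSQL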
}
|
||||
|
||||
// isHTTP tests if s is a well-structured HTTP or HTTPS url, and
|
||||
|
@@ -216,6 +216,26 @@ var (
    _ AssertCompareFunc = require.Greater
)

// RequireNoTake fails if a value is taken from c.
func RequireNoTake[C any](tb testing.TB, c <-chan C, msgAndArgs ...any) {
    tb.Helper()
    select {
    case <-c:
        require.Fail(tb, "unexpected take from channel", msgAndArgs...)
    default:
    }
}

// RequireTake fails if a value is not taken from c.
func RequireTake[C any](tb testing.TB, c <-chan C, msgAndArgs ...any) {
    tb.Helper()
    select {
    case <-c:
    default:
        require.Fail(tb, "unexpected failure to take from channel", msgAndArgs...)
    }
}
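
// Illustrative sketch (not part of this commit): these helpers make
// channel-state assertions read naturally in tests, e.g. against a
// streamcache stream's Filled channel. The drain step is hypothetical.
//
//	tu.RequireNoTake(t, stream.Filled()) // nothing has been consumed yet
//	drain(t, stream)                     // hypothetical: fully read the stream
//	tu.RequireTake(t, stream.Filled())   // channel now closed: take succeeds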

// DirCopy copies the contents of sourceDir to a temp dir.
// If keep is false, temp dir will be cleaned up on test exit.
func DirCopy(tb testing.TB, sourceDir string, keep bool) (tmpDir string) {
@@ -356,6 +376,47 @@ func MustStat(tb testing.TB, fp string) os.FileInfo {
    return fi
}

// MustDrain drains r, failing t on error. If arg cloze is true,
// r is closed if it's an io.Closer, even if the drain fails.
// FIXME: delete this func.
func MustDrain(tb testing.TB, r io.Reader, cloze bool) {
    tb.Helper()
    _, cpErr := io.Copy(io.Discard, r)
    if !cloze {
        require.NoError(tb, cpErr)
        return
    }

    var closeErr error
    if rc, ok := r.(io.Closer); ok {
        closeErr = rc.Close()
    }

    require.NoError(tb, cpErr)
    require.NoError(tb, closeErr)
}

// MustDrainN is like MustDrain, but also reports the number of bytes
// drained. If arg cloze is true, r is closed if it's an io.Closer,
// even if the drain fails.
func MustDrainN(tb testing.TB, r io.Reader, cloze bool) int {
    tb.Helper()
    n, cpErr := io.Copy(io.Discard, r)
    if !cloze {
        require.NoError(tb, cpErr)
        return int(n)
    }

    var closeErr error
    if rc, ok := r.(io.Closer); ok {
        closeErr = rc.Close()
    }

    require.NoError(tb, cpErr)
    require.NoError(tb, closeErr)
    return int(n)
}

// TempDir is the standard means for obtaining a temp dir for tests.
// If arg clean is true, the temp dir is created via t.TempDir, and
// thus is deleted on test cleanup.