Ingest checksum issue (#378)

- Refactor `Downloader.Get`.
Neil O'Toole 2024-01-28 23:55:25 -07:00 committed by GitHub
parent 9f59bc4c76
commit 9a0b9b7a9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 684 additions and 419 deletions

View File

@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Breaking changes are annotated with ☢️, and alpha/beta features with 🐥.
## [v0.47.0] - 2024-01-28
## [v0.47.0] - UPCOMING
This is a significant release, focused on improving i/o, responsiveness,
and performance. The headline features are [caching](https://sq.io/docs/source#cache)

View File

@ -5,6 +5,7 @@ import (
"bytes"
"context"
crand "crypto/rand"
"errors"
"fmt"
"io"
mrand "math/rand"
@ -382,6 +383,87 @@ func (w *notifyOnceWriter) Write(p []byte) (n int, err error) {
return w.w.Write(p)
}
// NotifyOnEOFReader returns an [io.Reader] that invokes fn
// when r.Read returns [io.EOF]. The error that fn returns is
// what the caller of Read receives: fn can transform the error
// or return it unchanged. If r or fn is nil, r is returned.
//
// If r is an [io.ReadCloser], the returned reader will also
// implement [io.ReadCloser].
//
// See also: [NotifyOnErrorReader], which is a generalization of
// [NotifyOnEOFReader].
func NotifyOnEOFReader(r io.Reader, fn func(error) error) io.Reader {
if r == nil || fn == nil {
return r
}
if rc, ok := r.(io.ReadCloser); ok {
return &notifyOnEOFReadCloser{notifyOnEOFReader{r: rc, fn: fn}}
}
return &notifyOnEOFReader{r: r, fn: fn}
}
type notifyOnEOFReader struct {
r io.Reader
fn func(error) error
}
// Read implements io.Reader.
func (r *notifyOnEOFReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
if err != nil && errors.Is(err, io.EOF) {
err = r.fn(err)
}
return n, err
}
var _ io.ReadCloser = (*notifyOnEOFReadCloser)(nil)
type notifyOnEOFReadCloser struct {
notifyOnEOFReader
}
// Close implements io.Closer.
func (r *notifyOnEOFReadCloser) Close() error {
if c, ok := r.r.(io.Closer); ok {
return c.Close()
}
return nil
}
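For orientation, here's a minimal usage sketch of NotifyOnEOFReader (caller code assumed for illustration, not part of this commit):

package main

import (
	"fmt"
	"io"
	"strings"

	"github.com/neilotoole/sq/libsq/core/ioz"
)

func main() {
	src := strings.NewReader("hello")
	// fn fires when the underlying Read returns io.EOF; returning
	// err unchanged passes io.EOF through to the caller.
	r := ioz.NotifyOnEOFReader(src, func(err error) error {
		fmt.Println("EOF reached")
		return err
	})
	n, _ := io.Copy(io.Discard, r)
	fmt.Println("bytes read:", n) // bytes read: 5
}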
// NotifyOnErrorReader returns an [io.Reader] that invokes fn
// when r.Read returns an error. The error that fn returns is
// what the caller of Read receives: fn can transform the error
// or return it unchanged. If r or fn is nil, r is returned.
//
// See also: [NotifyOnEOFReader], which is a specialization of
// [NotifyOnErrorReader].
func NotifyOnErrorReader(r io.Reader, fn func(error) error) io.Reader {
if r == nil || fn == nil {
return r
}
return &notifyOnErrorReader{r: r, fn: fn}
}
type notifyOnErrorReader struct {
r io.Reader
fn func(error) error
}
// Read implements io.Reader.
func (r *notifyOnErrorReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
if err != nil {
err = r.fn(err)
}
return n, err
}
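And a companion sketch for the generalized NotifyOnErrorReader, here decorating every read error with context (the wrapping behavior is an illustrative choice):

package main

import (
	"fmt"
	"io"
	"testing/iotest"

	"github.com/neilotoole/sq/libsq/core/ioz"
)

func main() {
	src := iotest.ErrReader(io.ErrUnexpectedEOF) // a reader that always fails
	// fn sees every error from Read, io.EOF included.
	r := ioz.NotifyOnErrorReader(src, func(err error) error {
		return fmt.Errorf("download: %w", err)
	})
	_, err := io.ReadAll(r)
	fmt.Println(err) // download: unexpected EOF
}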
// WriteCloser returns w as an io.WriteCloser. If w implements
// io.WriteCloser, w is returned. Otherwise, w is wrapped in a
// no-op decorator that implements io.WriteCloser.
@ -458,10 +540,35 @@ func DirExists(dir string) bool {
return fi.IsDir()
}
// Drain drains r.
func Drain(r io.Reader) error {
_, err := io.Copy(io.Discard, r)
return err
// DrainClose drains rc, returning the number of bytes read, and any error.
// The reader is always closed, even if the drain operation returned an error.
// If both the drain and the close operations return non-nil errors, the drain
// error is returned.
func DrainClose(rc io.ReadCloser) (n int, err error) {
var n64 int64
n64, err = io.Copy(io.Discard, rc)
n = int(n64)
closeErr := rc.Close()
if err == nil {
err = closeErr
}
return n, errz.Err(err)
}
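DrainClose replaces the old Drain, folding in the close step. A hedged caller-side sketch, discarding an HTTP response body while guaranteeing it's closed:

package main

import (
	"fmt"
	"net/http"

	"github.com/neilotoole/sq/libsq/core/ioz"
)

func main() {
	resp, err := http.Get("https://sq.io")
	if err != nil {
		panic(err)
	}
	// Drain and close in one step: the body is always closed, and if
	// both the drain and the close fail, the drain error wins.
	n, err := ioz.DrainClose(resp.Body)
	fmt.Println(n, err)
}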
// Filesize returns the size of the file at fp. An error is returned
// if fp doesn't exist or is a directory.
func Filesize(fp string) (int64, error) {
fi, err := os.Stat(fp)
if err != nil {
return 0, errz.Err(err)
}
if fi.IsDir() {
return 0, errz.Errorf("not a file: %s", fp)
}
return fi.Size(), nil
}
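Filesize rejects directories, which this tiny sketch demonstrates:

package main

import (
	"fmt"
	"os"

	"github.com/neilotoole/sq/libsq/core/ioz"
)

func main() {
	// A directory is not a file, so Filesize returns an error.
	if _, err := ioz.Filesize(os.TempDir()); err != nil {
		fmt.Println(err) // not a file: <platform temp dir>
	}
}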
// FileInfoEqual returns true if a and b are equal.

View File

@ -3,7 +3,9 @@ package files
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
@ -73,15 +75,44 @@ func (fs *Files) CacheDirFor(src *source.Source) (dir string, err error) {
}
// WriteIngestChecksum is invoked (after successful ingestion) to write the
// checksum for the ingest cache db.
// checksum of the source document file that backs the ingest DB. Thus, if the
// source document changes, the checksum will no longer match, and the ingest
// DB will be considered invalid.
func (fs *Files) WriteIngestChecksum(ctx context.Context, src, backingSrc *source.Source) (err error) {
fs.mu.Lock()
defer fs.mu.Unlock()
log := lg.FromContext(ctx)
ingestFilePath, err := fs.filepath(src)
if err != nil {
return err
}
// Write the checksums file.
if location.TypeOf(src.Location) == location.TypeHTTP {
// If the source is remote, check if there was a download,
// and if so, make sure it's completed.
stream, ok := fs.streams[src.Handle]
if ok {
select {
case <-stream.Filled():
case <-stream.Done():
case <-ctx.Done():
return ctx.Err()
}
if err = stream.Err(); err != nil && !errors.Is(err, io.EOF) {
return err
}
}
// If we got this far, either there's no stream, or the stream is done,
// which means that the download cache has been updated, and contains
// the fresh cached body file that we'll use to calculate the checksum.
// So, we'll go ahead and do the checksum stuff below.
}
// Now, we need to write a checksum file that contains the computed checksum
// value from ingestFilePath.
var sum checksum.Checksum
if sum, err = checksum.ForFile(ingestFilePath); err != nil {
log.Warn("Failed to compute checksum for source file; caching not in effect",
@ -110,9 +141,9 @@ func (fs *Files) CachedBackingSourceFor(ctx context.Context, src *source.Source)
defer fs.mu.Unlock()
switch location.TypeOf(src.Location) {
case location.TypeLocalFile:
case location.TypeFile:
return fs.cachedBackingSourceForFile(ctx, src)
case location.TypeRemoteFile:
case location.TypeHTTP:
return fs.cachedBackingSourceForRemoteFile(ctx, src)
default:
return nil, false, errz.Errorf("caching not applicable for source: %s", src.Handle)
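Stepping back from the diff: the invalidation mechanism behind WriteIngestChecksum can be made concrete with a short sketch. It uses only checksum.ForFile as seen above; the helper function, the path, and the equality comparison are illustrative assumptions, not commit code:

package main

import (
	"fmt"

	"github.com/neilotoole/sq/libsq/core/ioz/checksum"
)

// ingestCacheValid sketches the idea: the ingest cache DB is trusted
// only while the source document's current checksum matches the sum
// recorded at ingest time (hypothetical helper).
func ingestCacheValid(srcDocPath string, recorded checksum.Checksum) bool {
	now, err := checksum.ForFile(srcDocPath)
	if err != nil {
		return false // can't verify, so treat the cache as invalid
	}
	return now == recorded
}

func main() {
	fmt.Println(ingestCacheValid("/tmp/actor.csv", "")) // false: file missing or sum mismatch
}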

View File

@ -117,7 +117,7 @@ func (fs *Files) detectType(ctx context.Context, handle, loc string) (typ driver
start := time.Now()
var newRdrFn NewReaderFunc
if location.TypeOf(loc) == location.TypeLocalFile {
if location.TypeOf(loc) == location.TypeFile {
newRdrFn = func(ctx context.Context) (io.ReadCloser, error) {
return errz.Return(os.Open(loc))
}

View File

@ -2,19 +2,15 @@ package files
import (
"context"
"io"
"net/http"
"path/filepath"
"time"
"github.com/neilotoole/streamcache"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/ioz"
"github.com/neilotoole/sq/libsq/core/ioz/checksum"
"github.com/neilotoole/sq/libsq/core/ioz/httpz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lgm"
"github.com/neilotoole/sq/libsq/core/options"
"github.com/neilotoole/sq/libsq/files/internal/downloader"
"github.com/neilotoole/sq/libsq/source"
@ -108,78 +104,19 @@ func (fs *Files) maybeStartDownload(ctx context.Context, src *source.Source, che
}
}
// Having got this far, we need to talk to the downloader.
var (
dlErrCh = make(chan error, 1)
dlStreamCh = make(chan *streamcache.Stream, 1)
dlFileCh = make(chan string, 1)
)
// Our handler simply pushes the callback values into the channels, which
// this main goroutine will select on at the bottom of the func. The call
// to downloader.Get will be executed in a newly spawned goroutine below.
h := downloader.Handler{
Cached: func(dlFile string) { dlFileCh <- dlFile },
Uncached: func(dlStream *streamcache.Stream) { dlStreamCh <- dlStream },
Error: func(dlErr error) { dlErrCh <- dlErr },
}
go func() {
// Spawn a goroutine to execute the download process.
// The handler will be called before Get returns.
cacheFile := dldr.Get(ctx, h)
if cacheFile == "" {
// Either the download failed, or cache update failed.
return
}
// The download succeeded, and the cache was successfully updated.
// We know that cacheFile exists now. If a stream was created (and
// thus added to Files.streams), we can swap it out and instead
// populate Files.downloadedFiles with the cacheFile. Thus, going
// forward, any clients of Files will get the cacheFile instead of
// the stream.
// We need to lock here because we're accessing Files.streams.
// So, this goroutine will block until the lock is available.
// That shouldn't be an issue: the up-stack Files function that
// acquired the lock will eventually return, releasing the lock,
// at which point the swap will happen. No big deal.
fs.mu.Lock()
defer fs.mu.Unlock()
if stream, ok := fs.streams[src.Handle]; ok && stream != nil {
// The stream exists, and it's safe to close the stream's source,
// (i.e. the http response body), because the body has already
// been completely drained by the downloader: otherwise, we
// wouldn't have a non-empty value for cacheFile.
if c, ok := stream.Source().(io.Closer); ok {
lg.WarnIfCloseError(lg.FromContext(ctx), lgm.CloseHTTPResponseBody, c)
}
}
// Now perform the swap: populate Files.downloadedFiles with cacheFile,
// and remove the stream from Files.streams.
fs.downloadedFiles[src.Handle] = cacheFile
delete(fs.streams, src.Handle)
}() // end of goroutine
// Here we wait on the handler channels.
select {
case <-ctx.Done():
return "", nil, errz.Err(ctx.Err())
case err = <-dlErrCh:
dlFile, dlStream, err = dldr.Get(ctx)
switch {
case err != nil:
return "", nil, err
case dlStream = <-dlStreamCh:
// New download stream. Add it to Files.streams,
case dlFile != "":
// The file is already on disk, so we can just return it.
fs.downloadedFiles[src.Handle] = dlFile
return dlFile, nil, nil
default:
// A new download stream was created. Add it to Files.streams,
// and return the stream.
fs.streams[src.Handle] = dlStream
return "", dlStream, nil
case dlFile = <-dlFileCh:
// The file is already on disk, so we added it to Files.downloadedFiles,
// and return its filepath.
fs.downloadedFiles[src.Handle] = dlFile
return dlFile, nil, nil
}
}
@ -192,7 +129,7 @@ func (fs *Files) downloadPaths(src *source.Source) (dlDir, dlFile string, err er
return "", dlFile, err
}
// Note: we're depending on internal knowledge of the downloader impl here,
// Note: we depend on internal knowledge of the downloader impl here,
// which is not great. It might be better to implement a function
// in pkg downloader.
dlDir = filepath.Join(cacheDir, "download", checksum.Sum([]byte(src.Location)))
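So the download directory is a pure function of the source URL. A small sketch (the cacheDir value is illustrative):

package main

import (
	"fmt"
	"path/filepath"

	"github.com/neilotoole/sq/libsq/core/ioz/checksum"
)

func main() {
	const cacheDir = "/tmp/sq/cache" // illustrative
	loc := "https://sq.io/testdata/actor.csv"
	// The same URL always yields the same download dir, so repeat
	// runs find their previously cached download.
	dlDir := filepath.Join(cacheDir, "download", checksum.Sum([]byte(loc)))
	fmt.Println(dlDir)
}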

View File

@ -11,6 +11,8 @@ import (
"os"
"sync"
"github.com/neilotoole/sq/libsq/core/ioz"
"github.com/neilotoole/streamcache"
"github.com/neilotoole/sq/libsq/core/cleanup"
@ -105,7 +107,7 @@ func New(ctx context.Context, optReg *options.Registry, cfgLock lockfile.LockFun
// An error is returned if src is not a document/file source.
func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64, err error) {
switch location.TypeOf(src.Location) {
case location.TypeLocalFile:
case location.TypeFile:
var fi os.FileInfo
if fi, err = os.Stat(src.Location); err != nil {
return 0, errz.Err(err)
@ -126,7 +128,7 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
}
return int64(total), nil
case location.TypeRemoteFile:
case location.TypeHTTP:
fs.mu.Lock()
// First check if the file is already downloaded
@ -134,12 +136,8 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
dlFile, ok := fs.downloadedFiles[src.Handle]
if ok {
// The file is already downloaded.
fs.mu.Unlock()
var fi os.FileInfo
if fi, err = os.Stat(dlFile); err != nil {
return 0, errz.Err(err)
}
return fi.Size(), nil
defer fs.mu.Unlock()
return ioz.Filesize(dlFile)
}
// It's not in the list of downloaded files, so
@ -147,7 +145,9 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
dlStream, ok := fs.streams[src.Handle]
if ok {
fs.mu.Unlock()
var total int
// Block until the download completes.
if total, err = dlStream.Total(ctx); err != nil {
return 0, err
}
@ -155,17 +155,17 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
}
// Finally, we turn to the downloader.
defer fs.mu.Unlock()
var dl *downloader.Downloader
dl, err = fs.downloaderFor(ctx, src)
fs.mu.Unlock()
if err != nil {
if dl, err = fs.downloaderFor(ctx, src); err != nil {
return 0, err
}
// dl.Filesize will fail if the file has not been downloaded yet, which
// means that the source has not been ingested; but Files.Filesize should
// not have been invoked before ingestion.
return dl.Filesize(ctx)
if dlFile, err = dl.CacheFile(ctx); err != nil {
return 0, err
}
return ioz.Filesize(dlFile)
case location.TypeSQL:
// Should be impossible.
@ -202,9 +202,9 @@ func (fs *Files) AddStdin(ctx context.Context, f *os.File) error {
// file exists.
func (fs *Files) filepath(src *source.Source) (string, error) {
switch location.TypeOf(src.Location) {
case location.TypeLocalFile:
case location.TypeFile:
return src.Location, nil
case location.TypeRemoteFile:
case location.TypeHTTP:
_, dlFile, err := fs.downloadPaths(src)
if err != nil {
return "", err
@ -249,7 +249,7 @@ func (fs *Files) newReader(ctx context.Context, src *source.Source, finalRdr boo
return nil, errz.Errorf("unknown source location type: %s", loc)
case location.TypeSQL:
return nil, errz.Errorf("invalid to read SQL source: %s", loc)
case location.TypeLocalFile:
case location.TypeFile:
return errz.Return(os.Open(loc))
case location.TypeStdin:
stdinStream, ok := fs.streams[source.StdinHandle]
@ -313,13 +313,13 @@ func (fs *Files) Ping(ctx context.Context, src *source.Source) error {
case location.TypeStdin:
// Stdin is always available.
return nil
case location.TypeLocalFile:
case location.TypeFile:
if _, err := os.Stat(src.Location); err != nil {
return errz.Wrapf(err, "ping: failed to stat file source %s: %s", src.Handle, src.Location)
}
return nil
case location.TypeRemoteFile:
case location.TypeHTTP:
req, err := http.NewRequestWithContext(ctx, http.MethodHead, src.Location, nil)
if err != nil {
return errz.Wrapf(err, "ping: %s", src.Handle)

View File

@ -12,7 +12,6 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"github.com/neilotoole/sq/libsq/core/ioz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lgt"
"github.com/neilotoole/sq/libsq/core/stringz"
@ -102,11 +101,18 @@ func TestFiles_DriverType(t *testing.T) {
t.Run(tu.Name(location.Redact(tc.loc)), func(t *testing.T) {
ctx := lg.NewContext(context.Background(), lgt.New(t))
fs, err := files.New(ctx, nil, testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true))
fs, err := files.New(
ctx,
nil,
testh.TempLockFunc(t),
tu.TempDir(t, false),
tu.CacheDir(t, false),
)
require.NoError(t, err)
defer func() { assert.NoError(t, fs.Close()) }()
fs.AddDriverDetectors(testh.DriverDetectors()...)
gotType, gotErr := fs.DetectType(context.Background(), "@test_"+stringz.Uniq8(), tc.loc)
gotType, gotErr := fs.DetectType(ctx, "@test_"+stringz.Uniq8(), tc.loc)
if tc.wantErr {
require.Error(t, gotErr)
return
@ -267,7 +273,8 @@ func TestFiles_Filesize(t *testing.T) {
// Files.Filesize will block until the stream is fully read.
r, err := fs.NewReader(th.Context, stdinSrc, false)
require.NoError(t, err)
require.NoError(t, ioz.Drain(r))
_, err = io.Copy(io.Discard, r)
require.NoError(t, err)
gotSize2, err := fs.Filesize(th.Context, stdinSrc)
require.NoError(t, err)

View File

@ -4,12 +4,13 @@ import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httputil"
"os"
"path/filepath"
"time"
"sync"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/ioz"
@ -17,7 +18,6 @@ import (
"github.com/neilotoole/sq/libsq/core/ioz/contextio"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/lg/lgm"
)
const (
@ -229,8 +229,7 @@ func (c *cache) clear(ctx context.Context) error {
recreateErr := ioz.RequireDir(c.dir)
err := errz.Append(deleteErr, recreateErr)
if err != nil {
lg.FromContext(ctx).Error(msgDeleteCache,
lga.Dir, c.dir, lga.Err, err)
lg.FromContext(ctx).Error(msgDeleteCache, lga.Dir, c.dir, lga.Err, err)
return err
}
@ -238,85 +237,235 @@ func (c *cache) clear(ctx context.Context) error {
return nil
}
// write updates the cache. If headerOnly is true, only the header cache file
// is updated, and the function returns. Otherwise, the header and body
// cache files are updated, and a checksum file (computed from the body file)
// is also written to disk.
//
// If an error occurs while attempting to update the cache, any existing
// cache artifacts are left untouched. It's a sort of atomic-write-lite.
// To achieve this, cache files are first written to a staging dir, and that
// staging dir is only swapped with the main cache dir if there are no errors.
//
// The response body is always closed.
func (c *cache) write(ctx context.Context, resp *http.Response, headerOnly bool) (err error) {
log := lg.FromContext(ctx)
start := time.Now()
var stagingDir string
defer func() {
lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
if stagingDir != "" && ioz.DirExists(stagingDir) {
lg.WarnIfError(log, "Remove cache staging dir", os.RemoveAll(stagingDir))
// writeHeader updates the main cache header file from resp. The response body
// is not written to the cache, nor is resp.Body closed.
func (c *cache) writeHeader(ctx context.Context, resp *http.Response) (err error) {
header, err := httputil.DumpResponse(resp, false)
if err != nil {
return errz.Err(err)
}
}()
mainDir := filepath.Join(c.dir, "main")
if err = ioz.RequireDir(mainDir); err != nil {
return err
}
stagingDir = filepath.Join(c.dir, "staging")
if err = ioz.RequireDir(stagingDir); err != nil {
fp := filepath.Join(mainDir, "header")
if _, err = ioz.WriteToFile(ctx, fp, bytes.NewReader(header)); err != nil {
return err
}
headerBytes, err := httputil.DumpResponse(resp, false)
lg.FromContext(ctx).Info("Updated download cache (header only)", lga.Dir, c.dir, lga.Resp, resp)
return nil
}
// newResponseCacher returns a new responseCacher for resp.
// On return, resp.Body will be nil.
func (c *cache) newResponseCacher(ctx context.Context, resp *http.Response) (*responseCacher, error) {
defer func() { resp.Body = nil }()
stagingDir := filepath.Join(c.dir, "staging")
if err := ioz.RequireDir(stagingDir); err != nil {
_ = resp.Body.Close()
return nil, err
}
header, err := httputil.DumpResponse(resp, false)
if err != nil {
return errz.Err(err)
_ = resp.Body.Close()
return nil, errz.Err(err)
}
fpHeaderStaging := filepath.Join(stagingDir, "header")
if _, err = ioz.WriteToFile(ctx, fpHeaderStaging, bytes.NewReader(headerBytes)); err != nil {
return err
if _, err = ioz.WriteToFile(ctx, filepath.Join(stagingDir, "header"), bytes.NewReader(header)); err != nil {
_ = resp.Body.Close()
return nil, err
}
fpHeader, fpBody, _ := c.paths(resp.Request)
if headerOnly {
// It's only the header that we're changing, so we don't need to
// swap the entire staging dir, just the header file.
if err = os.Rename(fpHeaderStaging, fpHeader); err != nil {
return errz.Wrap(err, "failed to move staging cache header file")
var f *os.File
if f, err = os.Create(filepath.Join(stagingDir, "body")); err != nil {
_ = resp.Body.Close()
return nil, err
}
log.Info("Updated download cache (header only)", lga.Dir, c.dir, lga.Resp, resp)
r := &responseCacher{
stagingDir: stagingDir,
mainDir: filepath.Join(c.dir, "main"),
body: resp.Body,
f: f,
}
return r, nil
}
var _ io.ReadCloser = (*responseCacher)(nil)
// responseCacher is an io.ReadCloser that wraps an [http.Response.Body],
// appending bytes read via Read to a staging cache file. When Read receives
// [io.EOF] from the wrapped response body, the staging cache is promoted and
// replaces the main cache. If an error occurs during Read, the staging cache is
// discarded, and the main cache is left untouched. If an error occurs during
// cache promotion (which happens on receipt of io.EOF from resp.Body), the
// promotion error, not [io.EOF], is returned by Read. Thus, a consumer of
// responseCacher will not receive [io.EOF] unless the cache is successfully
// promoted.
type responseCacher struct {
body io.ReadCloser
closeErr *error
f *os.File
mainDir string
stagingDir string
mu sync.Mutex
}
// Read implements [io.Reader]. It reads into p from the wrapped response body,
// appends the received bytes to the staging cache, and returns the number of
// bytes read, and any error. When Read encounters [io.EOF] from the response
// body, it promotes the staging cache to main, and on success returns [io.EOF].
// If an error occurs during cache promotion, Read returns that promotion error
// instead of [io.EOF].
func (r *responseCacher) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
// Use r.body as a sentinel to indicate that the cache
// has been closed.
if r.body == nil {
return 0, errz.New("response cache already closed")
}
n, err = r.body.Read(p)
switch {
case err == nil:
err = r.cacheAppend(p, n)
return n, err
case !errors.Is(err, io.EOF):
// It's some other kind of error.
// Clean up and return.
_ = r.body.Close()
r.body = nil
_ = r.f.Close()
_ = os.Remove(r.f.Name())
r.f = nil
_ = os.RemoveAll(r.stagingDir)
r.stagingDir = ""
return n, err
default:
// It's EOF time! Let's promote the cache.
}
var err2 error
if err2 = r.cacheAppend(p, n); err2 != nil {
return n, err2
}
if err2 = r.cachePromote(); err2 != nil {
return n, err2
}
return n, err
}
// cacheAppend appends n bytes from p to the staging cache. If an error occurs,
// the staging cache is discarded, and the error is returned.
func (r *responseCacher) cacheAppend(p []byte, n int) error {
_, err := r.f.Write(p[:n])
if err == nil {
return nil
}
_ = r.body.Close()
r.body = nil
_ = r.f.Close()
_ = os.Remove(r.f.Name())
r.f = nil
_ = os.RemoveAll(r.stagingDir)
r.stagingDir = ""
return errz.Wrap(err, "failed to append http response body bytes to staging cache")
}
fpBodyStaging := filepath.Join(stagingDir, "body")
var written int64
if written, err = ioz.WriteToFile(ctx, fpBodyStaging, resp.Body); err != nil {
log.Warn("Cache write: failed to write cache body file", lga.Err, err, lga.Path, fpBodyStaging)
return err
// cachePromote is invoked by Read when it receives io.EOF from the wrapped
// response body. It promotes the staging cache to main, and on success returns
// nil. If an error occurs during promotion, the staging cache is discarded, and
// the promotion error is returned.
func (r *responseCacher) cachePromote() error {
defer func() {
if r.f != nil {
_ = r.f.Close()
r.f = nil
}
if r.body != nil {
_ = r.body.Close()
r.body = nil
}
if r.stagingDir != "" {
_ = os.RemoveAll(r.stagingDir)
r.stagingDir = ""
}
}()
err := r.f.Close()
fpBody := r.f.Name()
r.f = nil
if err != nil {
return errz.Wrap(err, "failed to close cache body file")
}
sum, err := checksum.ForFile(fpBodyStaging)
err = r.body.Close()
r.body = nil
if err != nil {
return errz.Wrap(err, "failed to close http response body")
}
var sum checksum.Checksum
if sum, err = checksum.ForFile(fpBody); err != nil {
return errz.Wrap(err, "failed to compute checksum for cache body file")
}
if err = checksum.WriteFile(filepath.Join(stagingDir, "checksums.txt"), sum, "body"); err != nil {
if err = checksum.WriteFile(filepath.Join(r.stagingDir, "checksums.txt"), sum, "body"); err != nil {
return errz.Wrap(err, "failed to write checksum file for cache body")
}
// We've got good data in the staging dir. Now we do the switcheroo.
if err = ioz.RenameDir(stagingDir, mainDir); err != nil {
if err = ioz.RenameDir(r.stagingDir, r.mainDir); err != nil {
return errz.Wrap(err, "failed to write download cache")
}
r.stagingDir = ""
stagingDir = ""
log.Info("Updated download cache (full)",
lga.Written, written, lga.File, fpBody, lga.Elapsed, time.Since(start).Round(time.Millisecond))
return nil
}
// Close implements [io.Closer]. Note that cache promotion happens when Read
// receives [io.EOF] from the wrapped response body, so the main action should
// be over by the time Close is invoked. Note that Close is idempotent, and
// returns the same error on subsequent invocations.
func (r *responseCacher) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.closeErr != nil {
// Already closed
return *r.closeErr
}
// There's some duplication of logic with using both r.closeErr
// and r.body == nil as sentinels. This could be cleaned up.
var err error
if r.f != nil {
err = errz.Append(err, r.f.Close())
r.f = nil
}
if r.body != nil {
err = errz.Append(err, r.body.Close())
r.body = nil
}
if r.stagingDir != "" {
err = errz.Append(err, os.RemoveAll(r.stagingDir))
r.stagingDir = ""
}
r.closeErr = &err
return *r.closeErr
}
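The staging-then-promote dance that responseCacher implements can be seen in miniature below: write to a staging location, and only rename over the live location once the write has fully succeeded. This is a generic sketch of the pattern, not code from this commit:

package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// writePromote writes data to a staging file, then renames it over the
// live path (atomic on POSIX when both are on the same filesystem). On
// any error the live file is left untouched, which is the same
// guarantee responseCacher gives the main cache dir.
func writePromote(livePath string, data []byte) error {
	staging := livePath + ".staging"
	if err := os.WriteFile(staging, data, 0o600); err != nil {
		_ = os.Remove(staging) // discard any partial staging file
		return err
	}
	return os.Rename(staging, livePath)
}

func main() {
	fp := filepath.Join(os.TempDir(), "body")
	fmt.Println(writePromote(fp, []byte("cached response body")))
}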

View File

@ -104,9 +104,14 @@ func (s State) String() string {
const XFromCache = "X-From-Stream"
// Downloader encapsulates downloading a file from a URL, using a local
// disk cache if possible. Downloader.Get makes use of the Handler callback
// mechanism to facilitate early consumption of a download stream while the
// download is still in flight.
// disk cache if possible. Downloader.Get returns either a filepath to the
// already-downloaded file, or a stream of the download in progress, or an
// error. If a stream is returned, the Downloader cache is updated when the
// stream is fully consumed (this can be observed by the closing of the
// channel returned by [streamcache.Stream.Filled]).
//
// To be extra clear about that last point: the caller must consume the
// stream returned by Downloader.Get, or the cache will not be written.
type Downloader struct {
// c is the HTTP client used to make requests.
c *http.Client
@ -115,10 +120,6 @@ type Downloader struct {
// It will be created in dlDir.
cache *cache
// dlStream is the streamcache.Stream that is passed Handler.Uncached for an
// active download. This field is reset to nil on each call to Downloader.Get.
dlStream *streamcache.Stream
// name is a user-friendly name, such as a source handle like @data.
name string
@ -166,27 +167,28 @@ func New(name string, c *http.Client, dlURL, dlDir string) (*Downloader, error)
return dl, nil
}
// Get attempts to get the remote file, invoking Handler as appropriate. Exactly
// one of the Handler methods will be invoked, one time.
// Get attempts to get the remote file, returning either the filepath of the
// already-cached file in dlFile, or a stream of a newly-started download in
// dlStream, or an error. Exactly one of the return values will be non-nil.
//
// - If Handler.Uncached is invoked, a download stream has begun. Get will
// then block until the download is completed. The download resp.Body is
// written to cache, and on success, the filepath to the newly updated
// cache file is returned.
// If an error occurs during cache write, the error is logged, and Get
// returns the filepath of the previously cached download, if permitted
// by policy. If not permitted or not existing, empty string is returned.
// - If Handler.Cached is invoked, Get returns immediately afterwards with
// the filepath of the cached download (the same value provided to
// Handler.Cached).
// - If Handler.Error is invoked, there was an unrecoverable problem (e.g. a
// transport error, and there's no previous cache) and the download is
// unavailable. That error should be propagated up the stack. Get will
// return empty string.
// - If dlFile is non-empty, it is the filepath on disk of the cached download,
// and dlStream and err are nil. However, depending on OptContinueOnError,
// dlFile may be the path to a stale download. If the cache is stale and a
// transport error occurs during refresh, and OptContinueOnError is true,
// the previous cached download is returned. If OptContinueOnError is false,
// the transport error is returned, and dlFile is empty. The caller can also
// check the cache state via [Downloader.State].
// - If dlStream is non-nil, it is a stream of the download in progress, and
// dlFile is empty. The cache is updated when the stream has been completely
// consumed. If the stream is not consumed, the cache is not updated. If an
// error occurs reading from the stream, the cache is also not updated: this
// means that the cache may still contain the previous (stale) download.
// - If err is non-nil, there was an unrecoverable problem (e.g. a transport
// error, and there's no previous cache) and the download is unavailable.
//
// Get consults the context for options. In particular, it makes
// use of OptCache and OptContinueOnError.
func (dl *Downloader) Get(ctx context.Context, h Handler) (cacheFile string) {
// Get consults the context for options. In particular, it makes use of OptCache
// and OptContinueOnError.
func (dl *Downloader) Get(ctx context.Context) (dlFile string, dlStream *streamcache.Stream, err error) {
dl.mu.Lock()
defer dl.mu.Unlock()
@ -198,19 +200,16 @@ func (dl *Downloader) Get(ctx context.Context, h Handler) (cacheFile string) {
req := dl.mustRequest(ctx)
lg.FromContext(ctx).Debug("Get download", lga.URL, dl.url)
cacheFile = dl.get(req, h)
return cacheFile
return dl.get(req)
}
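A hedged sketch of a typical caller of the new Get, handling the three mutually exclusive outcomes (the URL, cache dir, and handle are illustrative, and downloader is an internal package, so this only compiles inside the sq module; compare Files.maybeStartDownload above):

package main

import (
	"context"
	"fmt"

	"github.com/neilotoole/sq/libsq/core/ioz"
	"github.com/neilotoole/sq/libsq/core/ioz/httpz"
	"github.com/neilotoole/sq/libsq/files/internal/downloader"
)

func fetch(ctx context.Context, dlURL, cacheDir string) error {
	dl, err := downloader.New("@example", httpz.NewDefaultClient(), dlURL, cacheDir)
	if err != nil {
		return err
	}
	dlFile, dlStream, err := dl.Get(ctx)
	switch {
	case err != nil:
		return err // unrecoverable: no cache, no stream
	case dlFile != "":
		fmt.Println("served from cache:", dlFile)
		return nil
	default:
		// Download in flight: the disk cache is promoted only after
		// this stream has been fully consumed.
		r := dlStream.NewReader(ctx)
		dlStream.Seal()
		_, err = ioz.DrainClose(r)
		return err
	}
}

func main() {
	_ = fetch(context.Background(), "https://sq.io/testdata/actor.csv", "/tmp/sq-dl")
}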
// get contains the main logic for getting the download.
// It invokes Handler as appropriate, and on success returns the
// filepath of the valid cached download file.
func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //nolint:gocognit,funlen,cyclop
func (dl *Downloader) get(req *http.Request) (dlFile string, //nolint:gocognit,funlen,cyclop
dlStream *streamcache.Stream, err error,
) {
ctx := req.Context()
log := lg.FromContext(ctx)
dl.dlStream = nil
var fpBody string
if dl.cache != nil {
_, fpBody, _ = dl.cache.paths(req)
@ -219,12 +218,10 @@ func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //n
state := dl.state(req)
if state == Fresh && fpBody != "" {
// The cached response is fresh, so we can return it.
h.Cached(fpBody)
return fpBody
return fpBody, nil, nil
}
cacheable := dl.isCacheable(req)
var err error
var cachedResp *http.Response
if cacheable {
cachedResp, err = dl.cache.get(req.Context(), req) //nolint:bodyclose
@ -241,8 +238,7 @@ func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //n
freshness := getFreshness(cachedResp.Header, req.Header)
if freshness == Fresh && fpBody != "" {
lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cachedResp.Body)
h.Cached(fpBody)
return fpBody
return fpBody, nil, nil
}
if freshness == Stale {
@ -279,12 +275,11 @@ func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //n
case fpBody != "" && (err != nil || resp.StatusCode >= 500) &&
req.Method == http.MethodGet && canStaleOnError(cachedResp.Header, req.Header):
// In case of transport failure and stale-if-error activated, returns cached content
// when available.
// In case of transport failure, if canStaleOnError returns true,
// return the stale cached download.
lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cachedResp.Body)
log.Warn("Returning cached response due to transport failure", lga.Err, err)
h.Cached(fpBody)
return fpBody
return fpBody, nil, nil
default:
if err != nil && resp != nil && resp.StatusCode != http.StatusOK {
@ -293,46 +288,41 @@ func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //n
if fp := dl.cacheFileOnError(req, err); fp != "" {
lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
h.Cached(fp)
return fp
return fp, nil, nil
}
}
if err != nil {
lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
if fp := dl.cacheFileOnError(req, err); fp != "" {
h.Cached(fp)
return fp
return fp, nil, nil
}
h.Error(err)
return ""
return "", nil, err
}
if resp.StatusCode != http.StatusOK {
lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
err = errz.Errorf("download: unexpected HTTP status: %s", httpz.StatusText(resp.StatusCode))
if fp := dl.cacheFileOnError(req, err); fp != "" {
h.Cached(fp)
return fp
return fp, nil, nil
}
h.Error(err)
return ""
return "", nil, err
}
}
} else {
reqCacheControl := parseCacheControl(req.Header)
if _, ok := reqCacheControl["only-if-cached"]; ok {
resp = newGatewayTimeoutResponse(req) //nolint:bodyclose
resp = newGatewayTimeoutResponse(req)
} else {
resp, err = dl.do(req) //nolint:bodyclose
if err != nil {
if fp := dl.cacheFileOnError(req, err); fp != "" {
h.Cached(fp)
return fp
return fp, nil, nil
}
h.Error(err)
return ""
return "", nil, err
}
}
}
@ -348,59 +338,55 @@ func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //n
}
if resp == cachedResp {
err = dl.cache.write(ctx, resp, true)
err = dl.cache.writeHeader(ctx, resp)
lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
if err != nil {
log.Error("Failed to update cache header", lga.Dir, dl.cache.dir, lga.Err, err)
if fp := dl.cacheFileOnError(req, err); fp != "" {
h.Cached(fp)
return fp
return fp, nil, nil
}
h.Error(err)
return ""
return "", nil, err
}
if fpBody != "" {
h.Cached(fpBody)
return fpBody
return fpBody, nil, nil
}
} else if cachedResp != nil && cachedResp.Body != nil {
lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cachedResp.Body)
}
dl.dlStream = streamcache.New(resp.Body)
resp.Body = dl.dlStream.NewReader(ctx)
h.Uncached(dl.dlStream)
if err = dl.cache.write(req.Context(), resp, false); err != nil {
// We don't explicitly call Handler.Error: it would be "illegal" to do so
// anyway, because the Handler docs state that at most one Handler callback
// func is ever invoked.
// OK, this is where the funky stuff happens.
//
// The cache write could fail for one of two reasons:
// First, note that the cache is two-stage: there's a staging cache, and a
// main cache. The staging cache is used to write the response body to disk,
// and when the response body is fully consumed, the staging cache is
// promoted to main, in a sort of atomic-swap-lite. This is done to avoid
// partially-written cache files in the main cache, and other such nastiness.
//
// - The download didn't complete successfully: that is, there was an error
// reading from resp.Body. In this case, that same error will be propagated
// to the Handler via the streamcache.Stream that was provided to Handler.Uncached.
// - The download completed, but there was a problem writing out the cache
// files (header, body, checksum). This is likely a very rare occurrence.
// In that case, any previous cache files are left untouched by cache.write,
// and all we do is log the error. If the cache is inconsistent, it will
// repair itself on next invocation, so it's not a big deal.
log.Warn("Failed to write download cache", lga.Dir, dl.cache.dir, lga.Err, err)
lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
return ""
}
lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
return fpBody
// The responseCacher type is an io.ReadCloser that wraps the response body.
// As its Read method is called, it writes the body bytes to a staging cache
// file (and also returns them to the caller). When rCacher encounters io.EOF,
// it promotes the staging cache to main before returning io.EOF to the
// caller. If promotion fails, the promotion error (not io.EOF) is returned
// to the caller. Thus, it is guaranteed that any caller of rCacher's Read
// method will only receive io.EOF if the cache has been promoted to main.
var rCacher *responseCacher
if rCacher, err = dl.cache.newResponseCacher(ctx, resp); err != nil {
return "", nil, err
}
// It's not cacheable, so we can just wrap resp.Body in a streamcache
// and return it.
dl.dlStream = streamcache.New(resp.Body)
// And now we wrap rCacher in a streamcache: any streamcache readers will
// only receive io.EOF when/if the staging cache has been promoted to main.
dlStream = streamcache.New(rCacher)
return "", dlStream, nil
}
// The response is not cacheable, so we can just wrap resp.Body in a
// streamcache and return it.
dlStream = streamcache.New(resp.Body)
resp.Body = nil // Unnecessary, but just to be explicit.
h.Uncached(dl.dlStream)
return ""
return "", dlStream, nil
}
// do executes the request.
@ -494,42 +480,12 @@ func (dl *Downloader) state(req *http.Request) State {
return getFreshness(cachedResp.Header, req.Header)
}
// Filesize returns the size of the downloaded file. This should
// only be invoked after the download has completed or is cached,
// as it may block until the download completes.
func (dl *Downloader) Filesize(ctx context.Context) (int64, error) {
dl.mu.Lock()
defer dl.mu.Unlock()
if dl.dlStream != nil {
// There's an active download, so we can get the filesize
// when the download completes.
size, err := dl.dlStream.Total(ctx)
return int64(size), err
}
if dl.cache == nil {
return 0, errz.New("download file size not available")
}
req := dl.mustRequest(ctx)
if !dl.cache.exists(req) {
// It's not in the cache.
return 0, errz.New("download file size not available")
}
// It's in the cache.
_, fp, _ := dl.cache.paths(req)
fi, err := os.Stat(fp)
if err != nil {
return 0, errz.Wrapf(err, "unable to stat cached download file: %s", fp)
}
return fi.Size(), nil
}
// CacheFile returns the path to the cached file, if it exists and has
// been fully downloaded.
// CacheFile returns the path to the cached file, if it exists. If there's
// a download in progress ([Downloader.Get] returned a stream), then CacheFile
// may return the filepath to the previously cached file. The caller should
// wait on any previously returned download stream to complete to ensure
// that the returned filepath is that of the current download. The caller
// can also check the cache state via [Downloader.State].
func (dl *Downloader) CacheFile(ctx context.Context) (fp string, err error) {
dl.mu.Lock()
defer dl.mu.Unlock()

View File

@ -12,23 +12,101 @@ import (
"testing"
"time"
"github.com/neilotoole/streamcache"
"github.com/neilotoole/sq/libsq/core/ioz"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/neilotoole/sq/testh/sakila"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/neilotoole/sq/libsq/core/ioz"
"github.com/neilotoole/sq/libsq/core/ioz/httpz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/lg/lgt"
"github.com/neilotoole/sq/libsq/core/options"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/neilotoole/sq/libsq/files"
"github.com/neilotoole/sq/libsq/files/internal/downloader"
"github.com/neilotoole/sq/testh/sakila"
"github.com/neilotoole/sq/testh/tu"
)
func TestDownload_redirect(t *testing.T) {
func TestDownloader(t *testing.T) {
const dlURL = sakila.ActorCSVURL
log := lgt.New(t)
ctx := lg.NewContext(context.Background(), log)
cacheDir := t.TempDir()
dl, gotErr := downloader.New(t.Name(), httpz.NewDefaultClient(), dlURL, cacheDir)
require.NoError(t, gotErr)
require.NoError(t, dl.Clear(ctx))
require.Equal(t, downloader.Uncached, dl.State(ctx))
sum, ok := dl.Checksum(ctx)
require.False(t, ok)
require.Empty(t, sum)
// Here's our first download: it's not cached, so we should
// get a stream.
gotFile, gotStream, gotErr := dl.Get(ctx)
require.NoError(t, gotErr)
require.Empty(t, gotFile)
require.NotNil(t, gotStream)
require.Equal(t, downloader.Uncached, dl.State(ctx))
// The stream should not be filled yet, because we
// haven't read from it.
tu.RequireNoTake(t, gotStream.Filled())
// And there should be no cache file, because the cache file
// isn't written until the stream is drained.
gotFile, gotErr = dl.CacheFile(ctx)
require.Error(t, gotErr)
require.Empty(t, gotFile)
r := gotStream.NewReader(ctx)
gotStream.Seal()
// Now we drain the stream, and the cache should magically fill.
var gotN int
gotN, gotErr = ioz.DrainClose(r)
require.NoError(t, gotErr)
require.Equal(t, sakila.ActorCSVSize, gotN)
tu.RequireTake(t, gotStream.Filled())
tu.RequireTake(t, gotStream.Done())
require.Equal(t, sakila.ActorCSVSize, gotStream.Size())
require.Equal(t, downloader.Fresh, dl.State(ctx))
// Now we should be able to access the cache.
sum, ok = dl.Checksum(ctx)
require.True(t, ok)
require.NotEmpty(t, sum)
gotFile, gotErr = dl.CacheFile(ctx)
require.NoError(t, gotErr)
require.NotEmpty(t, gotFile)
gotSize, gotErr := ioz.Filesize(gotFile)
require.NoError(t, gotErr)
require.Equal(t, sakila.ActorCSVSize, int(gotSize))
// Let's download again, and verify that the cache is used.
gotFile, gotStream, gotErr = dl.Get(ctx)
require.Nil(t, gotErr)
require.Nil(t, gotStream)
require.NotEmpty(t, gotFile)
gotFileBytes, gotErr := os.ReadFile(gotFile)
require.NoError(t, gotErr)
require.Equal(t, sakila.ActorCSVSize, len(gotFileBytes))
require.Equal(t, downloader.Fresh, dl.State(ctx))
sum, ok = dl.Checksum(ctx)
require.True(t, ok)
require.NotEmpty(t, sum)
require.NoError(t, dl.Clear(ctx))
require.Equal(t, downloader.Uncached, dl.State(ctx))
sum, ok = dl.Checksum(ctx)
require.False(t, ok)
require.Empty(t, sum)
}
func TestDownloader_redirect(t *testing.T) {
const hello = `Hello World!`
serveBody := hello
lastModified := time.Now().UTC()
@ -81,100 +159,28 @@ func TestDownload_redirect(t *testing.T) {
dl, err := downloader.New(t.Name(), httpz.NewDefaultClient(), loc, cacheDir)
require.NoError(t, err)
require.NoError(t, dl.Clear(ctx))
h := downloader.NewSinkHandler(log.With("origin", "handler"))
gotGetFile := dl.Get(ctx, h.Handler)
require.Empty(t, h.Errors)
require.NotEmpty(t, gotGetFile)
gotBody := tu.ReadToString(t, h.Streams[0].NewReader(ctx))
gotFile, gotStream, gotErr := dl.Get(ctx)
require.NoError(t, gotErr)
require.NotNil(t, gotStream)
require.Empty(t, gotFile)
gotBody := tu.ReadToString(t, gotStream.NewReader(ctx))
require.Equal(t, hello, gotBody)
h.Reset()
gotGetFile = dl.Get(ctx, h.Handler)
require.NotEmpty(t, gotGetFile)
require.Empty(t, h.Errors)
require.Empty(t, h.Streams)
gotDownloadedFile := h.Downloaded[0]
t.Logf("got fp: %s", gotDownloadedFile)
gotBody = tu.ReadFileToString(t, gotDownloadedFile)
gotFile, gotStream, gotErr = dl.Get(ctx)
require.NoError(t, gotErr)
require.Nil(t, gotStream)
require.NotEmpty(t, gotFile)
t.Logf("got fp: %s", gotFile)
gotBody = tu.ReadFileToString(t, gotFile)
t.Logf("got body: \n\n%s\n\n", gotBody)
require.Equal(t, serveBody, gotBody)
h.Reset()
gotGetFile = dl.Get(ctx, h.Handler)
require.NotEmpty(t, gotGetFile)
require.Empty(t, h.Errors)
require.Empty(t, h.Streams)
gotDownloadedFile = h.Downloaded[0]
t.Logf("got fp: %s", gotDownloadedFile)
gotBody = tu.ReadFileToString(t, gotDownloadedFile)
t.Logf("got body: \n\n%s\n\n", gotBody)
require.Equal(t, serveBody, gotBody)
}
func TestDownload_New(t *testing.T) {
log := lgt.New(t)
ctx := lg.NewContext(context.Background(), log)
const dlURL = sakila.ActorCSVURL
cacheDir := t.TempDir()
dl, err := downloader.New(t.Name(), httpz.NewDefaultClient(), dlURL, cacheDir)
require.NoError(t, err)
require.NoError(t, dl.Clear(ctx))
require.Equal(t, downloader.Uncached, dl.State(ctx))
sum, ok := dl.Checksum(ctx)
require.False(t, ok)
require.Empty(t, sum)
h := downloader.NewSinkHandler(log.With("origin", "handler"))
gotGetFile := dl.Get(ctx, h.Handler)
require.NotEmpty(t, gotGetFile)
require.Empty(t, h.Errors)
require.Empty(t, h.Downloaded)
require.Equal(t, 1, len(h.Streams))
require.Equal(t, int64(sakila.ActorCSVSize), int64(h.Streams[0].Size()))
require.Equal(t, downloader.Fresh, dl.State(ctx))
sum, ok = dl.Checksum(ctx)
require.True(t, ok)
require.NotEmpty(t, sum)
h.Reset()
gotGetFile = dl.Get(ctx, h.Handler)
require.NotEmpty(t, gotGetFile)
require.Empty(t, h.Errors)
require.Empty(t, h.Streams)
require.NotEmpty(t, h.Downloaded)
require.Equal(t, gotGetFile, h.Downloaded[0])
gotFileBytes, err := os.ReadFile(h.Downloaded[0])
require.NoError(t, err)
require.Equal(t, sakila.ActorCSVSize, len(gotFileBytes))
require.Equal(t, downloader.Fresh, dl.State(ctx))
sum, ok = dl.Checksum(ctx)
require.True(t, ok)
require.NotEmpty(t, sum)
require.NoError(t, dl.Clear(ctx))
require.Equal(t, downloader.Uncached, dl.State(ctx))
sum, ok = dl.Checksum(ctx)
require.False(t, ok)
require.Empty(t, sum)
h.Reset()
gotGetFile = dl.Get(ctx, h.Handler)
require.Empty(t, h.Errors)
require.NotEmpty(t, gotGetFile)
require.True(t, ioz.FileAccessible(gotGetFile))
h.Reset()
}
func TestCachePreservedOnFailedRefresh(t *testing.T) {
o := options.Options{files.OptHTTPResponseTimeout.Key(): "10m"}
ctx := options.NewContext(context.Background(), o)
ctx := lg.NewContext(context.Background(), lgt.New(t))
var (
log = lgt.New(t)
srvr *httptest.Server
srvrShouldBodyError bool
srvrShouldNoCache bool
@ -216,32 +222,34 @@ func TestCachePreservedOnFailedRefresh(t *testing.T) {
}))
t.Cleanup(srvr.Close)
ctx = lg.NewContext(ctx, log.With("origin", "downloader"))
cacheDir := filepath.Join(t.TempDir(), stringz.UniqSuffix("dlcache"))
dl, err := downloader.New(t.Name(), httpz.NewDefaultClient(), srvr.URL, cacheDir)
require.NoError(t, err)
require.NoError(t, dl.Clear(ctx))
h := downloader.NewSinkHandler(log.With("origin", "handler"))
gotGetFile := dl.Get(ctx, h.Handler)
require.Empty(t, h.Errors)
require.NotEmpty(t, h.Streams)
require.NotEmpty(t, gotGetFile, "cache file should have been filled")
require.True(t, ioz.FileAccessible(gotGetFile))
stream := h.Streams[0]
start := time.Now()
var gotFile string
var gotStream *streamcache.Stream
gotFile, gotStream, err = dl.Get(ctx)
require.NoError(t, err)
require.Empty(t, gotFile)
require.NotNil(t, gotStream)
tu.RequireNoTake(t, gotStream.Filled())
r := gotStream.NewReader(ctx)
gotStream.Seal()
t.Logf("Waiting for download to complete")
<-stream.Filled()
start := time.Now()
var gotN int
gotN, err = ioz.DrainClose(r)
require.NoError(t, err)
t.Logf("Download completed after %s", time.Since(start))
require.True(t, errors.Is(stream.Err(), io.EOF))
gotSize, err := dl.Filesize(ctx)
tu.RequireTake(t, gotStream.Filled())
tu.RequireTake(t, gotStream.Done())
require.Equal(t, len(sentBody), gotN)
require.True(t, errors.Is(gotStream.Err(), io.EOF))
require.NoError(t, err)
require.Equal(t, len(sentBody), int(gotSize))
require.Equal(t, len(sentBody), stream.Size())
gotFilesize, err := dl.Filesize(ctx)
require.NoError(t, err)
require.Equal(t, len(sentBody), int(gotFilesize))
require.Equal(t, len(sentBody), gotStream.Size())
fpBody, err := dl.CacheFile(ctx)
require.NoError(t, err)
@ -256,23 +264,31 @@ func TestCachePreservedOnFailedRefresh(t *testing.T) {
fiChecksums1 := tu.MustStat(t, fpChecksums)
srvrShouldBodyError = true
h.Reset()
// Sleep to allow file modification timestamps to tick
time.Sleep(time.Millisecond * 10)
gotGetFile = dl.Get(ctx, h.Handler)
require.Empty(t, h.Errors)
require.Empty(t, gotGetFile,
"gotCacheFile should be empty, because the server returned an error during cache write")
require.NotEmpty(t, h.Streams,
"h.Streams should not be empty, because the download was in fact initiated")
stream = h.Streams[0]
<-stream.Filled()
err = stream.Err()
gotFile, gotStream, err = dl.Get(ctx)
require.NoError(t, err)
require.Empty(t, gotFile,
"gotFile should be empty, because the server returned an error during cache write")
require.NotNil(t, gotStream,
"gotStream should not be empty, because the download was in fact initiated")
r = gotStream.NewReader(ctx)
gotStream.Seal()
gotN, err = ioz.DrainClose(r)
require.Equal(t, len(sentBody), gotN)
tu.RequireTake(t, gotStream.Filled())
tu.RequireTake(t, gotStream.Done())
streamErr := gotStream.Err()
require.Error(t, streamErr)
require.True(t, errors.Is(err, streamErr))
t.Logf("got stream err: %v", err)
require.Error(t, err)
require.True(t, errors.Is(err, io.ErrUnexpectedEOF))
require.Equal(t, len(sentBody), stream.Size())
require.Equal(t, len(sentBody), gotStream.Size())
// Verify that the server hasn't updated the cache,
// by checking that the file modification timestamps
@ -284,6 +300,4 @@ func TestCachePreservedOnFailedRefresh(t *testing.T) {
require.True(t, ioz.FileInfoEqual(fiBody1, fiBody2))
require.True(t, ioz.FileInfoEqual(fiHeader1, fiHeader2))
require.True(t, ioz.FileInfoEqual(fiChecksums1, fiChecksums2))
h.Reset()
}

View File

@ -1,6 +1,9 @@
// Package location contains functionality related to source location.
package location
// NOTE: This package contains code from several eras. There's a bunch of
// overlap and duplication. It should be consolidated.
import (
"net/url"
"path"
@ -328,9 +331,9 @@ type Type string
const (
TypeStdin = "stdin"
TypeLocalFile = "local_file"
TypeFile = "local_file"
TypeSQL = "sql"
TypeRemoteFile = "remote_file"
TypeHTTP = "http_file"
TypeUnknown = "unknown"
)
@ -345,14 +348,14 @@ func TypeOf(loc string) Type {
return TypeSQL
case strings.HasPrefix(loc, "http://"),
strings.HasPrefix(loc, "https://"):
return TypeRemoteFile
return TypeHTTP
default:
}
if _, err := filepath.Abs(loc); err != nil {
return TypeUnknown
}
return TypeLocalFile
return TypeFile
}
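The renamed constants in action, as a sketch (the import path is assumed from the package name, and the sample locations are illustrative):

package main

import (
	"fmt"

	"github.com/neilotoole/sq/libsq/source/location"
)

func main() {
	for _, loc := range []string{
		"https://sq.io/testdata/actor.csv",   // TypeHTTP (was TypeRemoteFile)
		"/tmp/actor.csv",                     // TypeFile (was TypeLocalFile)
		"postgres://sakila@localhost/sakila", // TypeSQL
	} {
		fmt.Println(location.TypeOf(loc), loc)
	}
}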
// isHTTP tests if s is a well-structured HTTP or HTTPS url, and

View File

@ -216,6 +216,26 @@ var (
_ AssertCompareFunc = require.Greater
)
// RequireNoTake fails if a value is taken from c.
func RequireNoTake[C any](tb testing.TB, c <-chan C, msgAndArgs ...any) {
tb.Helper()
select {
case <-c:
require.Fail(tb, "unexpected take from channel", msgAndArgs...)
default:
}
}
// RequireTake fails if a value is not taken from c.
func RequireTake[C any](tb testing.TB, c <-chan C, msgAndArgs ...any) {
tb.Helper()
select {
case <-c:
default:
require.Fail(tb, "unexpected failure to take from channel", msgAndArgs...)
}
}
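A tiny sketch of the new helpers in a hypothetical test, relying on their non-blocking receive semantics:

package tu_test

import (
	"testing"

	"github.com/neilotoole/sq/testh/tu"
)

func TestChannelHelpers(t *testing.T) {
	ch := make(chan struct{}, 1)
	tu.RequireNoTake(t, ch) // nothing is buffered: a receive must not succeed
	ch <- struct{}{}
	tu.RequireTake(t, ch) // a value is buffered: a receive must succeed now
}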
// DirCopy copies the contents of sourceDir to a temp dir.
// If keep is false, temp dir will be cleaned up on test exit.
func DirCopy(tb testing.TB, sourceDir string, keep bool) (tmpDir string) {
@ -356,6 +376,47 @@ func MustStat(tb testing.TB, fp string) os.FileInfo {
return fi
}
// MustDrain drains r, failing tb on error. If arg cloze is true,
// r is closed if it's an io.Closer, even if the drain fails.
// FIXME: delete this func.
func MustDrain(tb testing.TB, r io.Reader, cloze bool) {
tb.Helper()
_, cpErr := io.Copy(io.Discard, r)
if !cloze {
require.NoError(tb, cpErr)
return
}
var closeErr error
if rc, ok := r.(io.Closer); ok {
closeErr = rc.Close()
}
require.NoError(tb, cpErr)
require.NoError(tb, closeErr)
}
// MustDrainN is like MustDrain, but also reports the number of bytes
// drained. If arg cloze is true, r is closed if it's an io.Closer,
// even if the drain fails.
func MustDrainN(tb testing.TB, r io.Reader, cloze bool) int {
tb.Helper()
n, cpErr := io.Copy(io.Discard, r)
if !cloze {
require.NoError(tb, cpErr)
return int(n)
}
var closeErr error
if rc, ok := r.(io.Closer); ok {
closeErr = rc.Close()
}
require.NoError(tb, cpErr)
require.NoError(tb, closeErr)
return int(n)
}
// TempDir is the standard means for obtaining a temp dir for tests.
// If arg clean is true, the temp dir is created via t.TempDir, and
// thus is deleted on test cleanup.