Ingest checksum issue (#378)

- Refactor `Downloader.Get`.
Neil O'Toole 2024-01-28 23:55:25 -07:00 committed by GitHub
parent 9f59bc4c76
commit 9a0b9b7a9c
12 changed files with 684 additions and 419 deletions

View File

@@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 Breaking changes are annotated with ☢️, and alpha/beta features with 🐥.
-## [v0.47.0] - 2024-01-28
+## [v0.47.0] - UPCOMING
 This is a significant release, focused on improving i/o, responsiveness,
 and performance. The headline features are [caching](https://sq.io/docs/source#cache)

View File

@@ -5,6 +5,7 @@ import (
     "bytes"
     "context"
     crand "crypto/rand"
+    "errors"
     "fmt"
     "io"
    mrand "math/rand"
@@ -382,6 +383,87 @@ func (w *notifyOnceWriter) Write(p []byte) (n int, err error) {
     return w.w.Write(p)
 }
 
+// NotifyOnEOFReader returns an [io.Reader] that invokes fn
+// when r.Read returns [io.EOF]. The error that fn returns is
+// what's returned to the r caller: fn can transform the error
+// or return it unchanged. If r or fn is nil, r is returned.
+//
+// If r is an [io.ReadCloser], the returned reader will also
+// implement [io.ReadCloser].
+//
+// See also: [NotifyOnErrorReader], which is a generalization of
+// [NotifyOnEOFReader].
+func NotifyOnEOFReader(r io.Reader, fn func(error) error) io.Reader {
+    if r == nil || fn == nil {
+        return r
+    }
+
+    if rc, ok := r.(io.ReadCloser); ok {
+        return &notifyOnEOFReadCloser{notifyOnEOFReader{r: rc, fn: fn}}
+    }
+
+    return &notifyOnEOFReader{r: r, fn: fn}
+}
+
+type notifyOnEOFReader struct {
+    r  io.Reader
+    fn func(error) error
+}
+
+// Read implements io.Reader.
+func (r *notifyOnEOFReader) Read(p []byte) (n int, err error) {
+    n, err = r.r.Read(p)
+    if err != nil && errors.Is(err, io.EOF) {
+        err = r.fn(err)
+    }
+
+    return n, err
+}
+
+var _ io.ReadCloser = (*notifyOnEOFReadCloser)(nil)
+
+type notifyOnEOFReadCloser struct {
+    notifyOnEOFReader
+}
+
+// Close implements io.Closer.
+func (r *notifyOnEOFReadCloser) Close() error {
+    if c, ok := r.r.(io.Closer); ok {
+        return c.Close()
+    }
+
+    return nil
+}
+
+// NotifyOnErrorReader returns an [io.Reader] that invokes fn
+// when r.Read returns an error. The error that fn returns is
+// what's returned to the r caller: fn can transform the error
+// or return it unchanged. If r or fn is nil, r is returned.
+//
+// See also: [NotifyOnEOFReader], which is a specialization of
+// [NotifyOnErrorReader].
+func NotifyOnErrorReader(r io.Reader, fn func(error) error) io.Reader {
+    if r == nil || fn == nil {
+        return r
+    }
+
+    return &notifyOnErrorReader{r: r, fn: fn}
+}
+
+type notifyOnErrorReader struct {
+    r  io.Reader
+    fn func(error) error
+}
+
+// Read implements io.Reader.
+func (r *notifyOnErrorReader) Read(p []byte) (n int, err error) {
+    n, err = r.r.Read(p)
+    if err != nil {
+        err = r.fn(err)
+    }
+
+    return n, err
+}
+
 // WriteCloser returns w as an io.WriteCloser. If w implements
 // io.WriteCloser, w is returned. Otherwise, w is wrapped in a
 // no-op decorator that implements io.WriteCloser.
@@ -458,10 +540,35 @@ func DirExists(dir string) bool {
     return fi.IsDir()
 }
 
-// Drain drains r.
-func Drain(r io.Reader) error {
-    _, err := io.Copy(io.Discard, r)
-    return err
-}
+// DrainClose drains rc, returning the number of bytes read, and any error.
+// The reader is always closed, even if the drain operation returned an error.
+// If both the drain and the close operations return non-nil errors, the drain
+// error is returned.
+func DrainClose(rc io.ReadCloser) (n int, err error) {
+    var n64 int64
+    n64, err = io.Copy(io.Discard, rc)
+    n = int(n64)
+
+    closeErr := rc.Close()
+    if err == nil {
+        err = closeErr
+    }
+
+    return n, errz.Err(err)
+}
+
+// Filesize returns the size of the file at fp. An error is returned
+// if fp doesn't exist or is a directory.
+func Filesize(fp string) (int64, error) {
+    fi, err := os.Stat(fp)
+    if err != nil {
+        return 0, errz.Err(err)
+    }
+
+    if fi.IsDir() {
+        return 0, errz.Errorf("not a file: %s", fp)
+    }
+
+    return fi.Size(), nil
+}
 
 // FileInfoEqual returns true if a and b are equal.
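For context, here's a minimal sketch of how these new helpers compose in a hypothetical caller (it assumes `errz.Err(nil)` yields nil, as the DrainClose doc implies):

```go
package main

import (
	"fmt"
	"io"
	"strings"

	"github.com/neilotoole/sq/libsq/core/ioz"
)

func main() {
	// Wrap a reader so that fn is invoked when Read hits EOF. Because rc
	// is an io.ReadCloser, the returned reader is also an io.ReadCloser.
	rc := io.NopCloser(strings.NewReader("hello"))
	r := ioz.NotifyOnEOFReader(rc, func(err error) error {
		fmt.Println("EOF reached") // e.g. promote a cache here
		return err                 // return unchanged, or transform it
	})

	// DrainClose reads r to exhaustion and always closes it,
	// reporting the number of bytes drained.
	n, err := ioz.DrainClose(r.(io.ReadCloser))
	fmt.Println(n, err) // 5 <nil>
}
```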

View File

@@ -3,7 +3,9 @@ package files
 import (
     "bytes"
     "context"
+    "errors"
     "fmt"
+    "io"
     "os"
     "path/filepath"
     "strings"
@@ -73,15 +75,44 @@ func (fs *Files) CacheDirFor(src *source.Source) (dir string, err error) {
 }
 
 // WriteIngestChecksum is invoked (after successful ingestion) to write the
-// checksum for the ingest cache db.
+// checksum of the source document file vs the ingest DB. Thus, if the source
+// document changes, the checksum will no longer match, and the ingest DB
+// will be considered invalid.
 func (fs *Files) WriteIngestChecksum(ctx context.Context, src, backingSrc *source.Source) (err error) {
+    fs.mu.Lock()
+    defer fs.mu.Unlock()
+
     log := lg.FromContext(ctx)
     ingestFilePath, err := fs.filepath(src)
     if err != nil {
         return err
     }
 
-    // Write the checksums file.
+    if location.TypeOf(src.Location) == location.TypeHTTP {
+        // If the source is remote, check if there was a download,
+        // and if so, make sure it's completed.
+        stream, ok := fs.streams[src.Handle]
+        if ok {
+            select {
+            case <-stream.Filled():
+            case <-stream.Done():
+            case <-ctx.Done():
+                return ctx.Err()
+            }
+
+            if err = stream.Err(); err != nil && !errors.Is(err, io.EOF) {
+                return err
+            }
+        }
+
+        // If we got this far, either there's no stream, or the stream is done,
+        // which means that the download cache has been updated, and contains
+        // the fresh cached body file that we'll use to calculate the checksum.
+        // So, we'll go ahead and do the checksum stuff below.
+    }
+
+    // Now, we need to write a checksum file that contains the computed checksum
+    // value from ingestFilePath.
     var sum checksum.Checksum
     if sum, err = checksum.ForFile(ingestFilePath); err != nil {
         log.Warn("Failed to compute checksum for source file; caching not in effect",
@@ -110,9 +141,9 @@ func (fs *Files) CachedBackingSourceFor(ctx context.Context, src *source.Source)
     defer fs.mu.Unlock()
 
     switch location.TypeOf(src.Location) {
-    case location.TypeLocalFile:
+    case location.TypeFile:
         return fs.cachedBackingSourceForFile(ctx, src)
-    case location.TypeRemoteFile:
+    case location.TypeHTTP:
         return fs.cachedBackingSourceForRemoteFile(ctx, src)
     default:
         return nil, false, errz.Errorf("caching not applicable for source: %s", src.Handle)
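A rough sketch of the invalidation scheme that doc comment describes, using the checksum API that appears in this diff (paths are illustrative, and it assumes `checksum.Checksum` values are directly comparable):

```go
package main

import (
	"fmt"

	"github.com/neilotoole/sq/libsq/core/ioz/checksum"
)

func main() {
	// After successful ingestion, compute the checksum of the source
	// document file...
	sum, err := checksum.ForFile("/tmp/actor.csv")
	if err != nil {
		panic(err)
	}

	// ...and persist it alongside the ingest cache DB.
	if err = checksum.WriteFile("/tmp/cache/checksums.txt", sum, "ingest"); err != nil {
		panic(err)
	}

	// Later: recompute and compare. If the source document has changed,
	// the sums won't match, and the ingest DB is considered invalid.
	sum2, err := checksum.ForFile("/tmp/actor.csv")
	if err != nil {
		panic(err)
	}
	fmt.Println("ingest cache valid:", sum == sum2)
}
```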

View File

@@ -117,7 +117,7 @@ func (fs *Files) detectType(ctx context.Context, handle, loc string) (typ driver
     start := time.Now()
 
     var newRdrFn NewReaderFunc
-    if location.TypeOf(loc) == location.TypeLocalFile {
+    if location.TypeOf(loc) == location.TypeFile {
         newRdrFn = func(ctx context.Context) (io.ReadCloser, error) {
             return errz.Return(os.Open(loc))
         }

View File

@@ -2,19 +2,15 @@ package files
 import (
     "context"
-    "io"
     "net/http"
     "path/filepath"
     "time"
 
     "github.com/neilotoole/streamcache"
 
-    "github.com/neilotoole/sq/libsq/core/errz"
     "github.com/neilotoole/sq/libsq/core/ioz"
     "github.com/neilotoole/sq/libsq/core/ioz/checksum"
     "github.com/neilotoole/sq/libsq/core/ioz/httpz"
-    "github.com/neilotoole/sq/libsq/core/lg"
-    "github.com/neilotoole/sq/libsq/core/lg/lgm"
     "github.com/neilotoole/sq/libsq/core/options"
     "github.com/neilotoole/sq/libsq/files/internal/downloader"
     "github.com/neilotoole/sq/libsq/source"
@@ -108,78 +104,19 @@ func (fs *Files) maybeStartDownload(ctx context.Context, src *source.Source, che
         }
     }
 
-    // Having got this far, we need to talk to the downloader.
-    var (
-        dlErrCh    = make(chan error, 1)
-        dlStreamCh = make(chan *streamcache.Stream, 1)
-        dlFileCh   = make(chan string, 1)
-    )
-
-    // Our handler simply pushes the callback values into the channels, which
-    // this main goroutine will select on at the bottom of the func. The call
-    // to downloader.Get will be executed in a newly spawned goroutine below.
-    h := downloader.Handler{
-        Cached:   func(dlFile string) { dlFileCh <- dlFile },
-        Uncached: func(dlStream *streamcache.Stream) { dlStreamCh <- dlStream },
-        Error:    func(dlErr error) { dlErrCh <- dlErr },
-    }
-
-    go func() {
-        // Spawn a goroutine to execute the download process.
-        // The handler will be called before Get returns.
-        cacheFile := dldr.Get(ctx, h)
-        if cacheFile == "" {
-            // Either the download failed, or cache update failed.
-            return
-        }
-
-        // The download succeeded, and the cache was successfully updated.
-        // We know that cacheFile exists now. If a stream was created (and
-        // thus added to Files.streams), we can swap it out and instead
-        // populate Files.downloadedFiles with the cacheFile. Thus, going
-        // forward, any clients of Files will get the cacheFile instead of
-        // the stream.
-
-        // We need to lock here because we're accessing Files.streams.
-        // So, this goroutine will block until the lock is available.
-        // That shouldn't be an issue: the up-stack Files function that
-        // acquired the lock will eventually return, releasing the lock,
-        // at which point the swap will happen. No big deal.
-        fs.mu.Lock()
-        defer fs.mu.Unlock()
-
-        if stream, ok := fs.streams[src.Handle]; ok && stream != nil {
-            // The stream exists, and it's safe to close the stream's source,
-            // (i.e. the http response body), because the body has already
-            // been completely drained by the downloader: otherwise, we
-            // wouldn't have a non-empty value for cacheFile.
-            if c, ok := stream.Source().(io.Closer); ok {
-                lg.WarnIfCloseError(lg.FromContext(ctx), lgm.CloseHTTPResponseBody, c)
-            }
-        }
-
-        // Now perform the swap: populate Files.downloadedFiles with cacheFile,
-        // and remove the stream from Files.streams.
-        fs.downloadedFiles[src.Handle] = cacheFile
-        delete(fs.streams, src.Handle)
-    }() // end of goroutine
-
-    // Here we wait on the handler channels.
-    select {
-    case <-ctx.Done():
-        return "", nil, errz.Err(ctx.Err())
-    case err = <-dlErrCh:
-        return "", nil, err
-    case dlStream = <-dlStreamCh:
-        // New download stream. Add it to Files.streams,
-        // and return the stream.
-        fs.streams[src.Handle] = dlStream
-        return "", dlStream, nil
-    case dlFile = <-dlFileCh:
-        // The file is already on disk, so we added it to Files.downloadedFiles,
-        // and return its filepath.
-        fs.downloadedFiles[src.Handle] = dlFile
-        return dlFile, nil, nil
-    }
+    dlFile, dlStream, err = dldr.Get(ctx)
+    switch {
+    case err != nil:
+        return "", nil, err
+    case dlFile != "":
+        // The file is already on disk, so we can just return it.
+        fs.downloadedFiles[src.Handle] = dlFile
+        return dlFile, nil, nil
+    default:
+        // A new download stream was created. Add it to Files.streams,
+        // and return the stream.
+        fs.streams[src.Handle] = dlStream
+        return "", dlStream, nil
+    }
 }
@@ -192,7 +129,7 @@ func (fs *Files) downloadPaths(src *source.Source) (dlDir, dlFile string, err er
         return "", dlFile, err
     }
 
-    // Note: we're depending on internal knowledge of the downloader impl here,
+    // Note: we depend on internal knowledge of the downloader impl here,
     // which is not great. It might be better to implement a function
     // in pkg downloader.
     dlDir = filepath.Join(cacheDir, "download", checksum.Sum([]byte(src.Location)))
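The swap above leans on streamcache semantics: the cache fills only as the stream is consumed. A minimal, self-contained sketch (values illustrative):

```go
package main

import (
	"context"
	"fmt"
	"io"
	"strings"

	"github.com/neilotoole/streamcache"
)

func main() {
	ctx := context.Background()
	stream := streamcache.New(strings.NewReader("hello world"))

	r := stream.NewReader(ctx)
	stream.Seal() // no further readers may be created

	// The cache fills only as readers consume the stream.
	n, err := io.Copy(io.Discard, r)
	fmt.Println(n, err) // 11 <nil>
	_ = r.Close()

	<-stream.Filled() // closed once the source is fully read
	<-stream.Done()   // closed once sealed and all readers are closed
	fmt.Println(stream.Size(), stream.Err()) // 11 EOF
}
```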

View File

@@ -11,6 +11,8 @@ import (
     "os"
     "sync"
 
+    "github.com/neilotoole/sq/libsq/core/ioz"
+
     "github.com/neilotoole/streamcache"
 
     "github.com/neilotoole/sq/libsq/core/cleanup"
@@ -105,7 +107,7 @@ func New(ctx context.Context, optReg *options.Registry, cfgLock lockfile.LockFun
 // An error is returned if src is not a document/file source.
 func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64, err error) {
     switch location.TypeOf(src.Location) {
-    case location.TypeLocalFile:
+    case location.TypeFile:
         var fi os.FileInfo
         if fi, err = os.Stat(src.Location); err != nil {
             return 0, errz.Err(err)
@@ -126,7 +128,7 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
         }
         return int64(total), nil
 
-    case location.TypeRemoteFile:
+    case location.TypeHTTP:
         fs.mu.Lock()
 
         // First check if the file is already downloaded
@@ -134,12 +136,8 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
         dlFile, ok := fs.downloadedFiles[src.Handle]
         if ok {
             // The file is already downloaded.
-            fs.mu.Unlock()
-            var fi os.FileInfo
-            if fi, err = os.Stat(dlFile); err != nil {
-                return 0, errz.Err(err)
-            }
-            return fi.Size(), nil
+            defer fs.mu.Unlock()
+            return ioz.Filesize(dlFile)
         }
 
         // It's not in the list of downloaded files, so
@@ -147,7 +145,9 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
         dlStream, ok := fs.streams[src.Handle]
         if ok {
             fs.mu.Unlock()
+
             var total int
+            // Block until the download completes.
             if total, err = dlStream.Total(ctx); err != nil {
                 return 0, err
             }
@@ -155,17 +155,17 @@ func (fs *Files) Filesize(ctx context.Context, src *source.Source) (size int64,
         }
 
         // Finally, we turn to the downloader.
+        defer fs.mu.Unlock()
+
         var dl *downloader.Downloader
-        dl, err = fs.downloaderFor(ctx, src)
-        fs.mu.Unlock()
-        if err != nil {
+        if dl, err = fs.downloaderFor(ctx, src); err != nil {
             return 0, err
         }
 
-        // dl.Filesize will fail if the file has not been downloaded yet, which
-        // means that the source has not been ingested; but Files.Filesize should
-        // not have been invoked before ingestion.
-        return dl.Filesize(ctx)
+        if dlFile, err = dl.CacheFile(ctx); err != nil {
+            return 0, err
+        }
+
+        return ioz.Filesize(dlFile)
 
     case location.TypeSQL:
         // Should be impossible.
@@ -202,9 +202,9 @@ func (fs *Files) AddStdin(ctx context.Context, f *os.File) error {
 // file exists.
 func (fs *Files) filepath(src *source.Source) (string, error) {
     switch location.TypeOf(src.Location) {
-    case location.TypeLocalFile:
+    case location.TypeFile:
         return src.Location, nil
-    case location.TypeRemoteFile:
+    case location.TypeHTTP:
         _, dlFile, err := fs.downloadPaths(src)
         if err != nil {
             return "", err
@@ -249,7 +249,7 @@ func (fs *Files) newReader(ctx context.Context, src *source.Source, finalRdr boo
         return nil, errz.Errorf("unknown source location type: %s", loc)
     case location.TypeSQL:
         return nil, errz.Errorf("invalid to read SQL source: %s", loc)
-    case location.TypeLocalFile:
+    case location.TypeFile:
         return errz.Return(os.Open(loc))
     case location.TypeStdin:
         stdinStream, ok := fs.streams[source.StdinHandle]
@@ -313,13 +313,13 @@ func (fs *Files) Ping(ctx context.Context, src *source.Source) error {
     case location.TypeStdin:
         // Stdin is always available.
         return nil
-    case location.TypeLocalFile:
+    case location.TypeFile:
         if _, err := os.Stat(src.Location); err != nil {
             return errz.Wrapf(err, "ping: failed to stat file source %s: %s", src.Handle, src.Location)
         }
         return nil
-    case location.TypeRemoteFile:
+    case location.TypeHTTP:
         req, err := http.NewRequestWithContext(ctx, http.MethodHead, src.Location, nil)
         if err != nil {
             return errz.Wrapf(err, "ping: %s", src.Handle)

View File

@@ -12,7 +12,6 @@ import (
     "github.com/stretchr/testify/require"
     "golang.org/x/sync/errgroup"
 
-    "github.com/neilotoole/sq/libsq/core/ioz"
     "github.com/neilotoole/sq/libsq/core/lg"
     "github.com/neilotoole/sq/libsq/core/lg/lgt"
     "github.com/neilotoole/sq/libsq/core/stringz"
@@ -102,11 +101,18 @@ func TestFiles_DriverType(t *testing.T) {
         t.Run(tu.Name(location.Redact(tc.loc)), func(t *testing.T) {
             ctx := lg.NewContext(context.Background(), lgt.New(t))
 
-            fs, err := files.New(ctx, nil, testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true))
+            fs, err := files.New(
+                ctx,
+                nil,
+                testh.TempLockFunc(t),
+                tu.TempDir(t, false),
+                tu.CacheDir(t, false),
+            )
             require.NoError(t, err)
+            defer func() { assert.NoError(t, fs.Close()) }()
             fs.AddDriverDetectors(testh.DriverDetectors()...)
 
-            gotType, gotErr := fs.DetectType(context.Background(), "@test_"+stringz.Uniq8(), tc.loc)
+            gotType, gotErr := fs.DetectType(ctx, "@test_"+stringz.Uniq8(), tc.loc)
             if tc.wantErr {
                 require.Error(t, gotErr)
                 return
@@ -267,7 +273,8 @@ func TestFiles_Filesize(t *testing.T) {
     // Files.Filesize will block until the stream is fully read.
     r, err := fs.NewReader(th.Context, stdinSrc, false)
     require.NoError(t, err)
-    require.NoError(t, ioz.Drain(r))
+    _, err = io.Copy(io.Discard, r)
+    require.NoError(t, err)
 
     gotSize2, err := fs.Filesize(th.Context, stdinSrc)
     require.NoError(t, err)

View File

@@ -4,12 +4,13 @@ import (
     "bufio"
     "bytes"
     "context"
+    "errors"
     "io"
     "net/http"
     "net/http/httputil"
     "os"
     "path/filepath"
-    "time"
+    "sync"
 
     "github.com/neilotoole/sq/libsq/core/errz"
     "github.com/neilotoole/sq/libsq/core/ioz"
@@ -17,7 +18,6 @@ import (
     "github.com/neilotoole/sq/libsq/core/ioz/contextio"
     "github.com/neilotoole/sq/libsq/core/lg"
     "github.com/neilotoole/sq/libsq/core/lg/lga"
-    "github.com/neilotoole/sq/libsq/core/lg/lgm"
 )
 
 const (
@@ -229,8 +229,7 @@ func (c *cache) clear(ctx context.Context) error {
     recreateErr := ioz.RequireDir(c.dir)
     err := errz.Append(deleteErr, recreateErr)
     if err != nil {
-        lg.FromContext(ctx).Error(msgDeleteCache,
-            lga.Dir, c.dir, lga.Err, err)
+        lg.FromContext(ctx).Error(msgDeleteCache, lga.Dir, c.dir, lga.Err, err)
         return err
     }
@@ -238,85 +237,235 @@ func (c *cache) clear(ctx context.Context) error {
     return nil
 }
 
-// write updates the cache. If headerOnly is true, only the header cache file
-// is updated, and the function returns. Otherwise, the header and body
-// cache files are updated, and a checksum file (computed from the body file)
-// is also written to disk.
-//
-// If an error occurs while attempting to update the cache, any existing
-// cache artifacts are left untouched. It's a sort of atomic-write-lite.
-// To achieve this, cache files are first written to a staging dir, and that
-// staging dir is only swapped with the main cache dir if there are no errors.
-//
-// The response body is always closed.
-func (c *cache) write(ctx context.Context, resp *http.Response, headerOnly bool) (err error) {
-    log := lg.FromContext(ctx)
-    start := time.Now()
-
-    var stagingDir string
-    defer func() {
-        lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
-        if stagingDir != "" && ioz.DirExists(stagingDir) {
-            lg.WarnIfError(log, "Remove cache staging dir", os.RemoveAll(stagingDir))
-        }
-    }()
-
-    mainDir := filepath.Join(c.dir, "main")
-    if err = ioz.RequireDir(mainDir); err != nil {
-        return err
-    }
-
-    stagingDir = filepath.Join(c.dir, "staging")
-    if err = ioz.RequireDir(mainDir); err != nil {
-        return err
-    }
-
-    headerBytes, err := httputil.DumpResponse(resp, false)
-    if err != nil {
-        return errz.Err(err)
-    }
-
-    fpHeaderStaging := filepath.Join(stagingDir, "header")
-    if _, err = ioz.WriteToFile(ctx, fpHeaderStaging, bytes.NewReader(headerBytes)); err != nil {
-        return err
-    }
-
-    fpHeader, fpBody, _ := c.paths(resp.Request)
-    if headerOnly {
-        // It's only the header that we're changing, so we don't need to
-        // swap the entire staging dir, just the header file.
-        if err = os.Rename(fpHeaderStaging, fpHeader); err != nil {
-            return errz.Wrap(err, "failed to move staging cache header file")
-        }
-
-        log.Info("Updated download cache (header only)", lga.Dir, c.dir, lga.Resp, resp)
-        return nil
-    }
-
-    fpBodyStaging := filepath.Join(stagingDir, "body")
-    var written int64
-    if written, err = ioz.WriteToFile(ctx, fpBodyStaging, resp.Body); err != nil {
-        log.Warn("Cache write: failed to write cache body file", lga.Err, err, lga.Path, fpBodyStaging)
-        return err
-    }
-
-    sum, err := checksum.ForFile(fpBodyStaging)
-    if err != nil {
-        return errz.Wrap(err, "failed to compute checksum for cache body file")
-    }
-
-    if err = checksum.WriteFile(filepath.Join(stagingDir, "checksums.txt"), sum, "body"); err != nil {
-        return errz.Wrap(err, "failed to write checksum file for cache body")
-    }
-
-    // We've got good data in the staging dir. Now we do the switcheroo.
-    if err = ioz.RenameDir(stagingDir, mainDir); err != nil {
-        return errz.Wrap(err, "failed to write download cache")
-    }
-    stagingDir = ""
-
-    log.Info("Updated download cache (full)",
-        lga.Written, written, lga.File, fpBody, lga.Elapsed, time.Since(start).Round(time.Millisecond))
-    return nil
-}
+// writeHeader updates the main cache header file from resp. The response body
+// is not written to the cache, nor is resp.Body closed.
+func (c *cache) writeHeader(ctx context.Context, resp *http.Response) (err error) {
+    header, err := httputil.DumpResponse(resp, false)
+    if err != nil {
+        return errz.Err(err)
+    }
+
+    mainDir := filepath.Join(c.dir, "main")
+    if err = ioz.RequireDir(mainDir); err != nil {
+        return err
+    }
+
+    fp := filepath.Join(mainDir, "header")
+    if _, err = ioz.WriteToFile(ctx, fp, bytes.NewReader(header)); err != nil {
+        return err
+    }
+
+    lg.FromContext(ctx).Info("Updated download cache (header only)", lga.Dir, c.dir, lga.Resp, resp)
+    return nil
+}
+
+// newResponseCacher returns a new responseCacher for resp.
+// On return, resp.Body will be nil.
+func (c *cache) newResponseCacher(ctx context.Context, resp *http.Response) (*responseCacher, error) {
+    defer func() { resp.Body = nil }()
+
+    stagingDir := filepath.Join(c.dir, "staging")
+    if err := ioz.RequireDir(stagingDir); err != nil {
+        _ = resp.Body.Close()
+        return nil, err
+    }
+
+    header, err := httputil.DumpResponse(resp, false)
+    if err != nil {
+        _ = resp.Body.Close()
+        return nil, errz.Err(err)
+    }
+
+    if _, err = ioz.WriteToFile(ctx, filepath.Join(stagingDir, "header"), bytes.NewReader(header)); err != nil {
+        _ = resp.Body.Close()
+        return nil, err
+    }
+
+    var f *os.File
+    if f, err = os.Create(filepath.Join(stagingDir, "body")); err != nil {
+        _ = resp.Body.Close()
+        return nil, err
+    }
+
+    r := &responseCacher{
+        stagingDir: stagingDir,
+        mainDir:    filepath.Join(c.dir, "main"),
+        body:       resp.Body,
+        f:          f,
+    }
+    return r, nil
+}
+
+var _ io.ReadCloser = (*responseCacher)(nil)
+
+// responseCacher is an io.ReadCloser that wraps an [http.Response.Body],
+// appending bytes read via Read to a staging cache file. When Read receives
+// [io.EOF] from the wrapped response body, the staging cache is promoted and
+// replaces the main cache. If an error occurs during Read, the staging cache is
+// discarded, and the main cache is left untouched. If an error occurs during
+// cache promotion (which happens on receipt of io.EOF from resp.Body), the
+// promotion error, not [io.EOF], is returned by Read. Thus, a consumer of
+// responseCacher will not receive [io.EOF] unless the cache is successfully
+// promoted.
+type responseCacher struct {
+    body       io.ReadCloser
+    closeErr   *error
+    f          *os.File
+    mainDir    string
+    stagingDir string
+    mu         sync.Mutex
+}
+
+// Read implements [io.Reader]. It reads into p from the wrapped response body,
+// appends the received bytes to the staging cache, and returns the number of
+// bytes read, and any error. When Read encounters [io.EOF] from the response
+// body, it promotes the staging cache to main, and on success returns [io.EOF].
+// If an error occurs during cache promotion, Read returns that promotion error
+// instead of [io.EOF].
+func (r *responseCacher) Read(p []byte) (n int, err error) {
+    r.mu.Lock()
+    defer r.mu.Unlock()
+
+    // Use r.body as a sentinel to indicate that the cache
+    // has been closed.
+    if r.body == nil {
+        return 0, errz.New("response cache already closed")
+    }
+
+    n, err = r.body.Read(p)
+
+    switch {
+    case err == nil:
+        err = r.cacheAppend(p, n)
+        return n, err
+    case !errors.Is(err, io.EOF):
+        // It's some other kind of error.
+        // Clean up and return.
+        _ = r.body.Close()
+        r.body = nil
+        _ = r.f.Close()
+        _ = os.Remove(r.f.Name())
+        r.f = nil
+        _ = os.RemoveAll(r.stagingDir)
+        r.stagingDir = ""
+        return n, err
+    default:
+        // It's EOF time! Let's promote the cache.
+    }
+
+    var err2 error
+    if err2 = r.cacheAppend(p, n); err2 != nil {
+        return n, err2
+    }
+
+    if err2 = r.cachePromote(); err2 != nil {
+        return n, err2
+    }
+
+    return n, err
+}
+
+// cacheAppend appends n bytes from p to the staging cache. If an error occurs,
+// the staging cache is discarded, and the error is returned.
+func (r *responseCacher) cacheAppend(p []byte, n int) error {
+    _, err := r.f.Write(p[:n])
+    if err == nil {
+        return nil
+    }
+    _ = r.body.Close()
+    r.body = nil
+    _ = r.f.Close()
+    _ = os.Remove(r.f.Name())
+    r.f = nil
+    _ = os.RemoveAll(r.stagingDir)
+    r.stagingDir = ""
+    return errz.Wrap(err, "failed to append http response body bytes to staging cache")
+}
+
+// cachePromote is invoked by Read when it receives io.EOF from the wrapped
+// response body. It promotes the staging cache to main, and on success returns
+// nil. If an error occurs during promotion, the staging cache is discarded, and
+// the promotion error is returned.
+func (r *responseCacher) cachePromote() error {
+    defer func() {
+        if r.f != nil {
+            _ = r.f.Close()
+            r.f = nil
+        }
+        if r.body != nil {
+            _ = r.body.Close()
+            r.body = nil
+        }
+        if r.stagingDir != "" {
+            _ = os.RemoveAll(r.stagingDir)
+            r.stagingDir = ""
+        }
+    }()
+
+    err := r.f.Close()
+    fpBody := r.f.Name()
+    r.f = nil
+    if err != nil {
+        return errz.Wrap(err, "failed to close cache body file")
+    }
+
+    err = r.body.Close()
+    r.body = nil
+    if err != nil {
+        return errz.Wrap(err, "failed to close http response body")
+    }
+
+    var sum checksum.Checksum
+    if sum, err = checksum.ForFile(fpBody); err != nil {
+        return errz.Wrap(err, "failed to compute checksum for cache body file")
+    }
+
+    if err = checksum.WriteFile(filepath.Join(r.stagingDir, "checksums.txt"), sum, "body"); err != nil {
+        return errz.Wrap(err, "failed to write checksum file for cache body")
+    }
+
+    // We've got good data in the staging dir. Now we do the switcheroo.
+    if err = ioz.RenameDir(r.stagingDir, r.mainDir); err != nil {
+        return errz.Wrap(err, "failed to write download cache")
+    }
+    r.stagingDir = ""
+
+    return nil
+}
+
+// Close implements [io.Closer]. Note that cache promotion happens when Read
+// receives [io.EOF] from the wrapped response body, so the main action should
+// be over by the time Close is invoked. Note that Close is idempotent, and
+// returns the same error on subsequent invocations.
+func (r *responseCacher) Close() error {
+    r.mu.Lock()
+    defer r.mu.Unlock()
+
+    if r.closeErr != nil {
+        // Already closed
+        return *r.closeErr
+    }
+
+    // There's some duplication of logic with using both r.closeErr
+    // and r.body == nil as sentinels. This could be cleaned up.
+    var err error
+    if r.f != nil {
+        err = errz.Append(err, r.f.Close())
+        r.f = nil
+    }
+    if r.body != nil {
+        err = errz.Append(err, r.body.Close())
+        r.body = nil
+    }
+    if r.stagingDir != "" {
+        err = errz.Append(err, os.RemoveAll(r.stagingDir))
+        r.stagingDir = ""
+    }
+
+    r.closeErr = &err
+    return *r.closeErr
+}
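For illustration, here's the promote-on-EOF pattern at the heart of responseCacher, reduced to its essentials (the real type also handles headers, checksums, locking, and idempotent Close; names here are illustrative):

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"strings"
)

// promoteOnEOF mimics the core of responseCacher: the caller only ever
// sees io.EOF if promotion succeeded.
type promoteOnEOF struct {
	src     io.Reader
	promote func() error // e.g. rename the staging dir to main
}

func (r *promoteOnEOF) Read(p []byte) (n int, err error) {
	n, err = r.src.Read(p)
	if err != nil && errors.Is(err, io.EOF) {
		if perr := r.promote(); perr != nil {
			return n, perr // promotion error replaces io.EOF
		}
	}
	return n, err
}

func main() {
	r := &promoteOnEOF{
		src:     strings.NewReader("payload"),
		promote: func() error { fmt.Println("staging -> main"); return nil },
	}
	n, err := io.Copy(io.Discard, r)
	fmt.Println(n, err) // 7 <nil>
}
```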

View File

@@ -104,9 +104,14 @@ func (s State) String() string {
 const XFromCache = "X-From-Stream"
 
 // Downloader encapsulates downloading a file from a URL, using a local
-// disk cache if possible. Downloader.Get makes uses of the Handler callback
-// mechanism to facilitate early consumption of a download stream while the
-// download is still in flight.
+// disk cache if possible. Downloader.Get returns either a filepath to the
+// already-downloaded file, or a stream of the download in progress, or an
+// error. If a stream is returned, the Downloader cache is updated when the
+// stream is fully consumed (this can be observed by the closing of the
+// channel returned by [streamcache.Stream.Filled]).
+//
+// To be extra clear about that last point: the caller must consume the
+// stream returned by Downloader.Get, or the cache will not be written.
 type Downloader struct {
     // c is the HTTP client used to make requests.
     c *http.Client
@@ -115,10 +120,6 @@ type Downloader struct {
     // It will be created in dlDir.
     cache *cache
 
-    // dlStream is the streamcache.Stream that is passed Handler.Uncached for an
-    // active download. This field is reset to nil on each call to Downloader.Get.
-    dlStream *streamcache.Stream
-
     // name is a user-friendly name, such as a source handle like @data.
     name string
@@ -166,27 +167,28 @@ func New(name string, c *http.Client, dlURL, dlDir string) (*Downloader, error)
     return dl, nil
 }
 
-// Get attempts to get the remote file, invoking Handler as appropriate. Exactly
-// one of the Handler methods will be invoked, one time.
+// Get attempts to get the remote file, returning either the filepath of the
+// already-cached file in dlFile, or a stream of a newly-started download in
+// dlStream, or an error. Exactly one of the return values will be non-nil.
 //
-//   - If Handler.Uncached is invoked, a download stream has begun. Get will
-//     then block until the download is completed. The download resp.Body is
-//     written to cache, and on success, the filepath to the newly updated
-//     cache file is returned.
-//     If an error occurs during cache write, the error is logged, and Get
-//     returns the filepath of the previously cached download, if permitted
-//     by policy. If not permitted or not existing, empty string is returned.
-//   - If Handler.Cached is invoked, Get returns immediately afterwards with
-//     the filepath of the cached download (the same value provided to
-//     Handler.Cached).
-//   - If Handler.Error is invoked, there was an unrecoverable problem (e.g. a
-//     transport error, and there's no previous cache) and the download is
-//     unavailable. That error should be propagated up the stack. Get will
-//     return empty string.
+//   - If dlFile is non-empty, it is the filepath on disk of the cached download,
+//     and dlStream and err are nil. However, depending on OptContinueOnError,
+//     dlFile may be the path to a stale download. If the cache is stale and a
+//     transport error occurs during refresh, and OptContinueOnError is true,
+//     the previous cached download is returned. If OptContinueOnError is false,
+//     the transport error is returned, and dlFile is empty. The caller can also
+//     check the cache state via [Downloader.State].
+//   - If dlStream is non-nil, it is a stream of the download in progress, and
+//     dlFile is empty. The cache is updated when the stream has been completely
+//     consumed. If the stream is not consumed, the cache is not updated. If an
+//     error occurs reading from the stream, the cache is also not updated: this
+//     means that the cache may still contain the previous (stale) download.
+//   - If err is non-nil, there was an unrecoverable problem (e.g. a transport
+//     error, and there's no previous cache) and the download is unavailable.
 //
-// Get consults the context for options. In particular, it makes
-// use of OptCache and OptContinueOnError.
-func (dl *Downloader) Get(ctx context.Context, h Handler) (cacheFile string) {
+// Get consults the context for options. In particular, it makes use of OptCache
+// and OptContinueOnError.
+func (dl *Downloader) Get(ctx context.Context) (dlFile string, dlStream *streamcache.Stream, err error) {
     dl.mu.Lock()
     defer dl.mu.Unlock()
 
@@ -198,19 +200,16 @@ func (dl *Downloader) Get(ctx context.Context, h Handler) (cacheFile string) {
     req := dl.mustRequest(ctx)
     lg.FromContext(ctx).Debug("Get download", lga.URL, dl.url)
-    cacheFile = dl.get(req, h)
-    return cacheFile
+    return dl.get(req)
 }
 
 // get contains the main logic for getting the download.
-// It invokes Handler as appropriate, and on success returns the
-// filepath of the valid cached download file.
-func (dl *Downloader) get(req *http.Request, h Handler) (cacheFile string) { //nolint:gocognit,funlen,cyclop
+func (dl *Downloader) get(req *http.Request) (dlFile string, //nolint:gocognit,funlen,cyclop
+    dlStream *streamcache.Stream, err error,
+) {
     ctx := req.Context()
     log := lg.FromContext(ctx)
 
-    dl.dlStream = nil
-
     var fpBody string
     if dl.cache != nil {
         _, fpBody, _ = dl.cache.paths(req)
@@ -219,12 +218,10 @@ func (dl *Downloader) get(req *http.Request) (dlFile string, dlStream *streamcache.Stream, err error) {
     state := dl.state(req)
     if state == Fresh && fpBody != "" {
         // The cached response is fresh, so we can return it.
-        h.Cached(fpBody)
-        return fpBody
+        return fpBody, nil, nil
     }
 
     cacheable := dl.isCacheable(req)
-    var err error
     var cachedResp *http.Response
     if cacheable {
         cachedResp, err = dl.cache.get(req.Context(), req) //nolint:bodyclose
@@ -241,8 +238,7 @@ func (dl *Downloader) get(req *http.Request) (dlFile string, dlStream *streamcache.Stream, err error) {
     freshness := getFreshness(cachedResp.Header, req.Header)
     if freshness == Fresh && fpBody != "" {
         lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cachedResp.Body)
-        h.Cached(fpBody)
-        return fpBody
+        return fpBody, nil, nil
     }
 
     if freshness == Stale {
@@ -279,12 +275,11 @@ func (dl *Downloader) get(req *http.Request) (dlFile string, dlStream *streamcache.Stream, err error) {
         case fpBody != "" && (err != nil || resp.StatusCode >= 500) &&
             req.Method == http.MethodGet && canStaleOnError(cachedResp.Header, req.Header):
-            // In case of transport failure and stale-if-error activated, returns cached content
-            // when available.
+            // In case of transport failure canStaleOnError returns true,
+            // return the stale cached download.
             lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cachedResp.Body)
             log.Warn("Returning cached response due to transport failure", lga.Err, err)
-            h.Cached(fpBody)
-            return fpBody
+            return fpBody, nil, nil
 
         default:
             if err != nil && resp != nil && resp.StatusCode != http.StatusOK {
@@ -293,46 +288,41 @@ func (dl *Downloader) get(req *http.Request) (dlFile string, dlStream *streamcache.Stream, err error) {
                 if fp := dl.cacheFileOnError(req, err); fp != "" {
                     lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
-                    h.Cached(fp)
-                    return fp
+                    return fp, nil, nil
                 }
             }
 
             if err != nil {
                 lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
                 if fp := dl.cacheFileOnError(req, err); fp != "" {
-                    h.Cached(fp)
-                    return fp
+                    return fp, nil, nil
                 }
-                h.Error(err)
-                return ""
+                return "", nil, err
             }
 
             if resp.StatusCode != http.StatusOK {
                 lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
                 err = errz.Errorf("download: unexpected HTTP status: %s", httpz.StatusText(resp.StatusCode))
                 if fp := dl.cacheFileOnError(req, err); fp != "" {
-                    h.Cached(fp)
-                    return fp
+                    return fp, nil, nil
                 }
-                h.Error(err)
-                return ""
+                return "", nil, err
             }
         }
     } else {
         reqCacheControl := parseCacheControl(req.Header)
         if _, ok := reqCacheControl["only-if-cached"]; ok {
-            resp = newGatewayTimeoutResponse(req) //nolint:bodyclose
+            resp = newGatewayTimeoutResponse(req)
         } else {
             resp, err = dl.do(req) //nolint:bodyclose
             if err != nil {
                 if fp := dl.cacheFileOnError(req, err); fp != "" {
-                    h.Cached(fp)
-                    return fp
+                    return fp, nil, nil
                 }
-                h.Error(err)
-                return ""
+                return "", nil, err
             }
         }
     }
@@ -348,59 +338,55 @@ func (dl *Downloader) get(req *http.Request) (dlFile string, dlStream *streamcache.Stream, err error) {
         }
 
         if resp == cachedResp {
-            err = dl.cache.write(ctx, resp, true)
+            err = dl.cache.writeHeader(ctx, resp)
             lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
             if err != nil {
                 log.Error("Failed to update cache header", lga.Dir, dl.cache.dir, lga.Err, err)
                 if fp := dl.cacheFileOnError(req, err); fp != "" {
-                    h.Cached(fp)
-                    return fp
+                    return fp, nil, nil
                 }
-                h.Error(err)
-                return ""
+                return "", nil, err
             }
 
             if fpBody != "" {
-                h.Cached(fpBody)
-                return fpBody
+                return fpBody, nil, nil
             }
         } else if cachedResp != nil && cachedResp.Body != nil {
             lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cachedResp.Body)
         }
 
-        dl.dlStream = streamcache.New(resp.Body)
-        resp.Body = dl.dlStream.NewReader(ctx)
-        h.Uncached(dl.dlStream)
-
-        if err = dl.cache.write(req.Context(), resp, false); err != nil {
-            // We don't explicitly call Handler.Error: it would be "illegal" to do so
-            // anyway, because the Handler docs state that at most one Handler callback
-            // func is ever invoked.
-            //
-            // The cache write could fail for one of two reasons:
-            //
-            //  - The download didn't complete successfully: that is, there was an error
-            //    reading from resp.Body. In this case, that same error will be propagated
-            //    to the Handler via the streamcache.Stream that was provided to Handler.Uncached.
-            //  - The download completed, but there was a problem writing out the cache
-            //    files (header, body, checksum). This is likely a very rare occurrence.
-            //    In that case, any previous cache files are left untouched by cache.write,
-            //    and all we do is log the error. If the cache is inconsistent, it will
-            //    repair itself on next invocation, so it's not a big deal.
-            log.Warn("Failed to write download cache", lga.Dir, dl.cache.dir, lga.Err, err)
-            lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
-            return ""
-        }
-
-        lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body)
-        return fpBody
+        // OK, this is where the funky stuff happens.
+        //
+        // First, note that the cache is two-stage: there's a staging cache, and a
+        // main cache. The staging cache is used to write the response body to disk,
+        // and when the response body is fully consumed, the staging cache is
+        // promoted to main, in a sort of atomic-swap-lite. This is done to avoid
+        // partially-written cache files in the main cache, and other such nastiness.
+        //
+        // The responseCacher type is an io.ReadCloser that wraps the response body.
+        // As its Read method is called, it writes the body bytes to a staging cache
+        // file (and also returns them to the caller). When rCacher encounters io.EOF,
+        // it promotes the staging cache to main before returning io.EOF to the
+        // caller. If promotion fails, the promotion error (not io.EOF) is returned
+        // to the caller. Thus, it is guaranteed that any caller of rCacher's Read
+        // method will only receive io.EOF if the cache has been promoted to main.
+        var rCacher *responseCacher
+        if rCacher, err = dl.cache.newResponseCacher(ctx, resp); err != nil {
+            return "", nil, err
+        }
+
+        // And now we wrap rCacher in a streamcache: any streamcache readers will
+        // only receive io.EOF when/if the staging cache has been promoted to main.
+        dlStream = streamcache.New(rCacher)
+        return "", dlStream, nil
     }
 
-    // It's not cacheable, so we can just wrap resp.Body in a streamcache
-    // and return it.
-    dl.dlStream = streamcache.New(resp.Body)
+    // The response is not cacheable, so we can just wrap resp.Body in a
+    // streamcache and return it.
+    dlStream = streamcache.New(resp.Body)
     resp.Body = nil // Unnecessary, but just to be explicit.
-    h.Uncached(dl.dlStream)
-    return ""
+    return "", dlStream, nil
 }
 // do executes the request.
@@ -494,42 +480,12 @@ func (dl *Downloader) state(req *http.Request) State {
     return getFreshness(cachedResp.Header, req.Header)
 }
 
-// Filesize returns the size of the downloaded file. This should
-// only be invoked after the download has completed or is cached,
-// as it may block until the download completes.
-func (dl *Downloader) Filesize(ctx context.Context) (int64, error) {
-    dl.mu.Lock()
-    defer dl.mu.Unlock()
-
-    if dl.dlStream != nil {
-        // There's an active download, so we can get the filesize
-        // when the download completes.
-        size, err := dl.dlStream.Total(ctx)
-        return int64(size), err
-    }
-
-    if dl.cache == nil {
-        return 0, errz.New("download file size not available")
-    }
-
-    req := dl.mustRequest(ctx)
-    if !dl.cache.exists(req) {
-        // It's not in the cache.
-        return 0, errz.New("download file size not available")
-    }
-
-    // It's in the cache.
-    _, fp, _ := dl.cache.paths(req)
-    fi, err := os.Stat(fp)
-    if err != nil {
-        return 0, errz.Wrapf(err, "unable to stat cached download file: %s", fp)
-    }
-
-    return fi.Size(), nil
-}
-
-// CacheFile returns the path to the cached file, if it exists and has
-// been fully downloaded.
+// CacheFile returns the path to the cached file, if it exists. If there's
+// a download in progress ([Downloader.Get] returned a stream), then CacheFile
+// may return the filepath to the previously cached file. The caller should
+// wait on any previously returned download stream to complete to ensure
+// that the returned filepath is that of the current download. The caller
+// can also check the cache state via [Downloader.State].
 func (dl *Downloader) CacheFile(ctx context.Context) (fp string, err error) {
     dl.mu.Lock()
     defer dl.mu.Unlock()
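A sketch of the new `Downloader.Get` contract from a caller's perspective. Note that `downloader` is an internal package, so this only compiles inside the sq module; the URL and cache dir are illustrative:

```go
package main

import (
	"context"
	"fmt"

	"github.com/neilotoole/sq/libsq/core/ioz"
	"github.com/neilotoole/sq/libsq/core/ioz/httpz"
	"github.com/neilotoole/sq/libsq/files/internal/downloader"
)

func main() {
	ctx := context.Background()
	dl, err := downloader.New("@actor", httpz.NewDefaultClient(),
		"https://sq.io/testdata/actor.csv", "/tmp/dlcache")
	if err != nil {
		panic(err)
	}

	dlFile, dlStream, err := dl.Get(ctx)
	switch {
	case err != nil:
		// Unrecoverable: transport error and no usable cache.
		panic(err)
	case dlFile != "":
		// Cache hit (possibly stale, per OptContinueOnError).
		fmt.Println("cached:", dlFile)
	default:
		// Download in flight: the cache is only promoted once the
		// stream is fully consumed.
		r := dlStream.NewReader(ctx)
		dlStream.Seal()
		n, err := ioz.DrainClose(r)
		fmt.Println("downloaded bytes:", n, err)

		fp, err := dl.CacheFile(ctx) // now reflects the fresh download
		fmt.Println(fp, err)
	}
}
```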

View File

@@ -12,23 +12,101 @@ import (
     "testing"
     "time"
 
+    "github.com/neilotoole/streamcache"
+
+    "github.com/neilotoole/sq/libsq/core/ioz"
+    "github.com/neilotoole/sq/libsq/core/lg/lga"
+    "github.com/neilotoole/sq/libsq/core/stringz"
+    "github.com/neilotoole/sq/testh/sakila"
+
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
 
-    "github.com/neilotoole/sq/libsq/core/ioz"
     "github.com/neilotoole/sq/libsq/core/ioz/httpz"
     "github.com/neilotoole/sq/libsq/core/lg"
-    "github.com/neilotoole/sq/libsq/core/lg/lga"
     "github.com/neilotoole/sq/libsq/core/lg/lgt"
-    "github.com/neilotoole/sq/libsq/core/options"
-    "github.com/neilotoole/sq/libsq/core/stringz"
-    "github.com/neilotoole/sq/libsq/files"
     "github.com/neilotoole/sq/libsq/files/internal/downloader"
-    "github.com/neilotoole/sq/testh/sakila"
     "github.com/neilotoole/sq/testh/tu"
 )
-func TestDownload_redirect(t *testing.T) {
+func TestDownloader(t *testing.T) {
+    const dlURL = sakila.ActorCSVURL
+    log := lgt.New(t)
+    ctx := lg.NewContext(context.Background(), log)
+    cacheDir := t.TempDir()
+
+    dl, gotErr := downloader.New(t.Name(), httpz.NewDefaultClient(), dlURL, cacheDir)
+    require.NoError(t, gotErr)
+    require.NoError(t, dl.Clear(ctx))
+    require.Equal(t, downloader.Uncached, dl.State(ctx))
+    sum, ok := dl.Checksum(ctx)
+    require.False(t, ok)
+    require.Empty(t, sum)
+
+    // Here's our first download, it's not cached, so we should
+    // get a stream.
+    gotFile, gotStream, gotErr := dl.Get(ctx)
+    require.NoError(t, gotErr)
+    require.Empty(t, gotFile)
+    require.NotNil(t, gotStream)
+    require.Equal(t, downloader.Uncached, dl.State(ctx))
+
+    // The stream should not be filled yet, because we
+    // haven't read from it.
+    tu.RequireNoTake(t, gotStream.Filled())
+
+    // And there should be no cache file, because the cache file
+    // isn't written until the stream is drained.
+    gotFile, gotErr = dl.CacheFile(ctx)
+    require.Error(t, gotErr)
+    require.Empty(t, gotFile)
+
+    r := gotStream.NewReader(ctx)
+    gotStream.Seal()
+
+    // Now we drain the stream, and the cache should magically fill.
+    var gotN int
+    gotN, gotErr = ioz.DrainClose(r)
+    require.NoError(t, gotErr)
+    require.Equal(t, sakila.ActorCSVSize, gotN)
+    tu.RequireTake(t, gotStream.Filled())
+    tu.RequireTake(t, gotStream.Done())
+    require.Equal(t, sakila.ActorCSVSize, gotStream.Size())
+    require.Equal(t, downloader.Fresh, dl.State(ctx))
+
+    // Now we should be able to access the cache.
+    sum, ok = dl.Checksum(ctx)
+    require.True(t, ok)
+    require.NotEmpty(t, sum)
+    gotFile, gotErr = dl.CacheFile(ctx)
+    require.NoError(t, gotErr)
+    require.NotEmpty(t, gotFile)
+    gotSize, gotErr := ioz.Filesize(gotFile)
+    require.NoError(t, gotErr)
+    require.Equal(t, sakila.ActorCSVSize, int(gotSize))
+
+    // Let's download again, and verify that the cache is used.
+    gotFile, gotStream, gotErr = dl.Get(ctx)
+    require.Nil(t, gotErr)
+    require.Nil(t, gotStream)
+    require.NotEmpty(t, gotFile)
+
+    gotFileBytes, gotErr := os.ReadFile(gotFile)
+    require.NoError(t, gotErr)
+    require.Equal(t, sakila.ActorCSVSize, len(gotFileBytes))
+    require.Equal(t, downloader.Fresh, dl.State(ctx))
+    sum, ok = dl.Checksum(ctx)
+    require.True(t, ok)
+    require.NotEmpty(t, sum)
+
+    require.NoError(t, dl.Clear(ctx))
+    require.Equal(t, downloader.Uncached, dl.State(ctx))
+    sum, ok = dl.Checksum(ctx)
+    require.False(t, ok)
+    require.Empty(t, sum)
+}
+
+func TestDownloader_redirect(t *testing.T) {
     const hello = `Hello World!`
     serveBody := hello
     lastModified := time.Now().UTC()
@@ -81,100 +159,28 @@ func TestDownload_redirect(t *testing.T) {
     dl, err := downloader.New(t.Name(), httpz.NewDefaultClient(), loc, cacheDir)
     require.NoError(t, err)
     require.NoError(t, dl.Clear(ctx))
-    h := downloader.NewSinkHandler(log.With("origin", "handler"))
 
-    gotGetFile := dl.Get(ctx, h.Handler)
-    require.Empty(t, h.Errors)
-    require.NotEmpty(t, gotGetFile)
-    gotBody := tu.ReadToString(t, h.Streams[0].NewReader(ctx))
+    gotFile, gotStream, gotErr := dl.Get(ctx)
+    require.NoError(t, gotErr)
+    require.NotNil(t, gotStream)
+    require.Empty(t, gotFile)
+    gotBody := tu.ReadToString(t, gotStream.NewReader(ctx))
     require.Equal(t, hello, gotBody)
 
-    h.Reset()
-    gotGetFile = dl.Get(ctx, h.Handler)
-    require.NotEmpty(t, gotGetFile)
-    require.Empty(t, h.Errors)
-    require.Empty(t, h.Streams)
-    gotDownloadedFile := h.Downloaded[0]
-    t.Logf("got fp: %s", gotDownloadedFile)
-    gotBody = tu.ReadFileToString(t, gotDownloadedFile)
-    t.Logf("got body: \n\n%s\n\n", gotBody)
-    require.Equal(t, serveBody, gotBody)
-
-    h.Reset()
-    gotGetFile = dl.Get(ctx, h.Handler)
-    require.NotEmpty(t, gotGetFile)
-    require.Empty(t, h.Errors)
-    require.Empty(t, h.Streams)
-    gotDownloadedFile = h.Downloaded[0]
-    t.Logf("got fp: %s", gotDownloadedFile)
-    gotBody = tu.ReadFileToString(t, gotDownloadedFile)
+    gotFile, gotStream, gotErr = dl.Get(ctx)
+    require.NoError(t, gotErr)
+    require.Nil(t, gotStream)
+    require.NotEmpty(t, gotFile)
+    t.Logf("got fp: %s", gotFile)
+    gotBody = tu.ReadFileToString(t, gotFile)
     t.Logf("got body: \n\n%s\n\n", gotBody)
     require.Equal(t, serveBody, gotBody)
-}
-
-func TestDownload_New(t *testing.T) {
-    log := lgt.New(t)
-    ctx := lg.NewContext(context.Background(), log)
-    const dlURL = sakila.ActorCSVURL
-
-    cacheDir := t.TempDir()
-
-    dl, err := downloader.New(t.Name(), httpz.NewDefaultClient(), dlURL, cacheDir)
-    require.NoError(t, err)
-    require.NoError(t, dl.Clear(ctx))
-    require.Equal(t, downloader.Uncached, dl.State(ctx))
-    sum, ok := dl.Checksum(ctx)
-    require.False(t, ok)
-    require.Empty(t, sum)
-
-    h := downloader.NewSinkHandler(log.With("origin", "handler"))
-    gotGetFile := dl.Get(ctx, h.Handler)
-    require.NotEmpty(t, gotGetFile)
-    require.Empty(t, h.Errors)
-    require.Empty(t, h.Downloaded)
-    require.Equal(t, 1, len(h.Streams))
-    require.Equal(t, int64(sakila.ActorCSVSize), int64(h.Streams[0].Size()))
-    require.Equal(t, downloader.Fresh, dl.State(ctx))
-    sum, ok = dl.Checksum(ctx)
-    require.True(t, ok)
-    require.NotEmpty(t, sum)
-
-    h.Reset()
-    gotGetFile = dl.Get(ctx, h.Handler)
-    require.NotEmpty(t, gotGetFile)
-    require.Empty(t, h.Errors)
-    require.Empty(t, h.Streams)
-    require.NotEmpty(t, h.Downloaded)
-    require.Equal(t, gotGetFile, h.Downloaded[0])
-    gotFileBytes, err := os.ReadFile(h.Downloaded[0])
-    require.NoError(t, err)
-    require.Equal(t, sakila.ActorCSVSize, len(gotFileBytes))
-    require.Equal(t, downloader.Fresh, dl.State(ctx))
-    sum, ok = dl.Checksum(ctx)
-    require.True(t, ok)
-    require.NotEmpty(t, sum)
-
-    require.NoError(t, dl.Clear(ctx))
-    require.Equal(t, downloader.Uncached, dl.State(ctx))
-    sum, ok = dl.Checksum(ctx)
-    require.False(t, ok)
-    require.Empty(t, sum)
-
-    h.Reset()
-    gotGetFile = dl.Get(ctx, h.Handler)
-    require.Empty(t, h.Errors)
-    require.NotEmpty(t, gotGetFile)
-    require.True(t, ioz.FileAccessible(gotGetFile))
-    h.Reset()
 }
 func TestCachePreservedOnFailedRefresh(t *testing.T) {
-    o := options.Options{files.OptHTTPResponseTimeout.Key(): "10m"}
-    ctx := options.NewContext(context.Background(), o)
+    ctx := lg.NewContext(context.Background(), lgt.New(t))
 
     var (
-        log                 = lgt.New(t)
         srvr                *httptest.Server
         srvrShouldBodyError bool
         srvrShouldNoCache   bool
@@ -216,32 +222,34 @@ func TestCachePreservedOnFailedRefresh(t *testing.T) {
     }))
     t.Cleanup(srvr.Close)
 
-    ctx = lg.NewContext(ctx, log.With("origin", "downloader"))
     cacheDir := filepath.Join(t.TempDir(), stringz.UniqSuffix("dlcache"))
     dl, err := downloader.New(t.Name(), httpz.NewDefaultClient(), srvr.URL, cacheDir)
     require.NoError(t, err)
     require.NoError(t, dl.Clear(ctx))
-    h := downloader.NewSinkHandler(log.With("origin", "handler"))
 
-    gotGetFile := dl.Get(ctx, h.Handler)
-    require.Empty(t, h.Errors)
-    require.NotEmpty(t, h.Streams)
-    require.NotEmpty(t, gotGetFile, "cache file should have been filled")
-    require.True(t, ioz.FileAccessible(gotGetFile))
-
-    stream := h.Streams[0]
-    start := time.Now()
+    var gotFile string
+    var gotStream *streamcache.Stream
+    gotFile, gotStream, err = dl.Get(ctx)
+    require.NoError(t, err)
+    require.Empty(t, gotFile)
+    require.NotNil(t, gotStream)
+    tu.RequireNoTake(t, gotStream.Filled())
+
+    r := gotStream.NewReader(ctx)
+    gotStream.Seal()
+
     t.Logf("Waiting for download to complete")
-    <-stream.Filled()
+    start := time.Now()
+    var gotN int
+    gotN, err = ioz.DrainClose(r)
+    require.NoError(t, err)
     t.Logf("Download completed after %s", time.Since(start))
-    require.True(t, errors.Is(stream.Err(), io.EOF))
 
-    gotSize, err := dl.Filesize(ctx)
-    require.NoError(t, err)
-    require.Equal(t, len(sentBody), int(gotSize))
-    require.Equal(t, len(sentBody), stream.Size())
-
-    gotFilesize, err := dl.Filesize(ctx)
-    require.NoError(t, err)
-    require.Equal(t, len(sentBody), int(gotFilesize))
+    tu.RequireTake(t, gotStream.Filled())
+    tu.RequireTake(t, gotStream.Done())
+    require.Equal(t, len(sentBody), gotN)
+    require.True(t, errors.Is(gotStream.Err(), io.EOF))
+    require.Equal(t, len(sentBody), gotStream.Size())
 
     fpBody, err := dl.CacheFile(ctx)
     require.NoError(t, err)
@@ -256,23 +264,31 @@ func TestCachePreservedOnFailedRefresh(t *testing.T) {
     fiChecksums1 := tu.MustStat(t, fpChecksums)
 
     srvrShouldBodyError = true
-    h.Reset()
 
     // Sleep to allow file modification timestamps to tick
     time.Sleep(time.Millisecond * 10)
 
-    gotGetFile = dl.Get(ctx, h.Handler)
-    require.Empty(t, h.Errors)
-    require.Empty(t, gotGetFile,
-        "gotCacheFile should be empty, because the server returned an error during cache write")
-    require.NotEmpty(t, h.Streams,
-        "h.Streams should not be empty, because the download was in fact initiated")
-
-    stream = h.Streams[0]
-    <-stream.Filled()
-    err = stream.Err()
+    gotFile, gotStream, err = dl.Get(ctx)
+    require.NoError(t, err)
+    require.Empty(t, gotFile,
+        "gotFile should be empty, because the server returned an error during cache write")
+    require.NotNil(t, gotStream,
+        "gotStream should not be empty, because the download was in fact initiated")
+
+    r = gotStream.NewReader(ctx)
+    gotStream.Seal()
+    gotN, err = ioz.DrainClose(r)
+    require.Equal(t, len(sentBody), gotN)
+    tu.RequireTake(t, gotStream.Filled())
+    tu.RequireTake(t, gotStream.Done())
+
+    streamErr := gotStream.Err()
+    require.Error(t, streamErr)
+    require.True(t, errors.Is(err, streamErr))
     t.Logf("got stream err: %v", err)
-    require.Error(t, err)
     require.True(t, errors.Is(err, io.ErrUnexpectedEOF))
-    require.Equal(t, len(sentBody), stream.Size())
+    require.Equal(t, len(sentBody), gotStream.Size())
 
     // Verify that the server hasn't updated the cache,
     // by checking that the file modification timestamps
@@ -284,6 +300,4 @@ func TestCachePreservedOnFailedRefresh(t *testing.T) {
     require.True(t, ioz.FileInfoEqual(fiBody1, fiBody2))
     require.True(t, ioz.FileInfoEqual(fiHeader1, fiHeader2))
     require.True(t, ioz.FileInfoEqual(fiChecksums1, fiChecksums2))
-
-    h.Reset()
 }

View File

@@ -1,6 +1,9 @@
 // Package location contains functionality related to source location.
 package location
 
+// NOTE: This package contains code from several eras. There's a bunch of
+// overlap and duplication. It should be consolidated.
+
 import (
     "net/url"
     "path"
@@ -327,11 +330,11 @@ func isFpath(loc string) (fpath string, ok bool) {
 type Type string
 
 const (
     TypeStdin   = "stdin"
-    TypeLocalFile  = "local_file"
+    TypeFile    = "local_file"
     TypeSQL     = "sql"
-    TypeRemoteFile = "remote_file"
+    TypeHTTP    = "http_file"
     TypeUnknown = "unknown"
 )
 
 // TypeOf returns the type of loc, or locTypeUnknown if it
@@ -345,14 +348,14 @@ func TypeOf(loc string) Type {
         return TypeSQL
     case strings.HasPrefix(loc, "http://"),
         strings.HasPrefix(loc, "https://"):
-        return TypeRemoteFile
+        return TypeHTTP
     default:
     }
 
     if _, err := filepath.Abs(loc); err != nil {
         return TypeUnknown
     }
 
-    return TypeLocalFile
+    return TypeFile
 }
 
 // isHTTP tests if s is a well-structured HTTP or HTTPS url, and
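A quick sketch of `TypeOf` with the renamed constants (assuming the package's import path; the locations are illustrative):

```go
package main

import (
	"fmt"

	"github.com/neilotoole/sq/libsq/source/location"
)

func main() {
	fmt.Println(location.TypeOf("https://sq.io/testdata/actor.csv")) // http_file
	fmt.Println(location.TypeOf("/tmp/actor.csv"))                   // local_file
	// SQL locations (e.g. "postgres://user@localhost/db") yield TypeSQL,
	// and unresolvable paths yield TypeUnknown.
}
```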

View File

@@ -216,6 +216,26 @@ var (
     _ AssertCompareFunc = require.Greater
 )
 
+// RequireNoTake fails if a value is taken from c.
+func RequireNoTake[C any](tb testing.TB, c <-chan C, msgAndArgs ...any) {
+    tb.Helper()
+    select {
+    case <-c:
+        require.Fail(tb, "unexpected take from channel", msgAndArgs...)
+    default:
+    }
+}
+
+// RequireTake fails if a value is not taken from c.
+func RequireTake[C any](tb testing.TB, c <-chan C, msgAndArgs ...any) {
+    tb.Helper()
+    select {
+    case <-c:
+    default:
+        require.Fail(tb, "unexpected failure to take from channel", msgAndArgs...)
+    }
+}
+
 // DirCopy copies the contents of sourceDir to a temp dir.
 // If keep is false, temp dir will be cleaned up on test exit.
 func DirCopy(tb testing.TB, sourceDir string, keep bool) (tmpDir string) {
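How these helpers read in practice, in a hypothetical test:

```go
package tu_test

import (
	"testing"

	"github.com/neilotoole/sq/testh/tu"
)

func TestChannelHelpers(t *testing.T) {
	ch := make(chan struct{})

	// Nothing is available on ch yet, so a non-blocking take must fail.
	tu.RequireNoTake(t, ch)

	// Closing the channel (as streamcache does for Filled and Done)
	// makes a receive succeed immediately.
	close(ch)
	tu.RequireTake(t, ch)
}
```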
@@ -356,6 +376,47 @@ func MustStat(tb testing.TB, fp string) os.FileInfo {
     return fi
 }
 
+// MustDrain drains r, failing t on error. If arg cloze is true,
+// r is closed if it's an io.Closer, even if the drain fails.
+// FIXME: delete this func.
+func MustDrain(tb testing.TB, r io.Reader, cloze bool) {
+    tb.Helper()
+    _, cpErr := io.Copy(io.Discard, r)
+    if !cloze {
+        require.NoError(tb, cpErr)
+        return
+    }
+
+    var closeErr error
+    if rc, ok := r.(io.Closer); ok {
+        closeErr = rc.Close()
+    }
+
+    require.NoError(tb, cpErr)
+    require.NoError(tb, closeErr)
+}
+
+// MustDrainN is like MustDrain, but also reports the number of bytes
+// drained. If arg cloze is true, r is closed if it's an io.Closer,
+// even if the drain fails.
+func MustDrainN(tb testing.TB, r io.Reader, cloze bool) int {
+    tb.Helper()
+    n, cpErr := io.Copy(io.Discard, r)
+    if !cloze {
+        require.NoError(tb, cpErr)
+        return int(n)
+    }
+
+    var closeErr error
+    if rc, ok := r.(io.Closer); ok {
+        closeErr = rc.Close()
+    }
+
+    require.NoError(tb, cpErr)
+    require.NoError(tb, closeErr)
+    return int(n)
+}
+
 // TempDir is the standard means for obtaining a temp dir for tests.
 // If arg clean is true, the temp dir is created via t.TempDir, and
 // thus is deleted on test cleanup.