package diffdoc import ( "context" "io" "golang.org/x/sync/errgroup" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" ) // Differ encapsulates a [Doc] and a function that populates the [Doc]. // Create one via [NewDiffer], and then pass it to [Execute]. type Differ struct { doc Doc fn func(ctx context.Context, cancelFn func(error)) } // NewDiffer returns a new [Differ] that can be passed to [Execute]. Arg doc is // the [Doc] to be populated, and fn populates the [Doc]. The cancelFn arg to fn // must only be invoked in the event of an error; it must not be invoked on the // happy path. func NewDiffer(doc Doc, fn func(ctx context.Context, cancelFn func(error))) *Differ { return &Differ{doc: doc, fn: fn} } // execute returns a function that, when invoked, populates the doc by executing // the function passed to NewDiffer. If that function returns an error, // cancelFn is invoked with that error, and the error is returned. func (d *Differ) execute(ctx context.Context, cancelFn func(error)) func() error { if cancelFn == nil { ctx, cancelFn = context.WithCancelCause(ctx) } return func() error { d.fn(ctx, cancelFn) err := d.doc.Err() if err != nil { cancelFn(err) } return err } } // Execute executes differs concurrently, writing output sequentially to w. // // Arg concurrency specifies the maximum number of concurrent Differ executions. // Zero indicates sequential execution; a negative values indicates unbounded // concurrency. // // The first error encountered is returned; hasDiff returns true if differences // were found, and false if no differences. func Execute(ctx context.Context, w io.Writer, concurrency int, differs []*Differ) (hasDiffs bool, err error) { defer func() { for _, differ := range differs { if differs == nil || differ.doc == nil { continue } if closeErr := differ.doc.Close(); closeErr != nil { lg.FromContext(ctx).Warn(lgm.CloseDiffDoc, lga.Doc, differ.doc.String(), lga.Err, closeErr) } } }() var cancelFn context.CancelCauseFunc ctx, cancelFn = context.WithCancelCause(ctx) defer func() { cancelFn(err) }() g := &errgroup.Group{} g.SetLimit(concurrency) for i := range differs { if differs[i] == nil { continue } g.Go(differs[i].execute(ctx, cancelFn)) } // We don't call g.Wait() here because we're using errgroup solely to limit // the number of concurrent goroutines. We don't actually want to wait for all // the goroutines to finish; we want to stream the output (via io.Copy below) // as soon as it's available. rdrs := make([]io.Reader, 0, len(differs)) for i := range differs { if differs[i] == nil { continue } rdrs = append(rdrs, differs[i].doc) } var n int64 n, err = io.Copy(w, contextio.NewReader(ctx, io.MultiReader(rdrs...))) return n > 0, err }