mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-22 07:31:39 +03:00
239 lines
6.2 KiB
Go
239 lines
6.2 KiB
Go
package diff
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"slices"
|
|
"time"
|
|
|
|
"github.com/samber/lo"
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
udiff "github.com/neilotoole/sq/cli/diff/internal/go-udiff"
|
|
"github.com/neilotoole/sq/cli/diff/internal/go-udiff/myers"
|
|
"github.com/neilotoole/sq/cli/output"
|
|
"github.com/neilotoole/sq/cli/run"
|
|
"github.com/neilotoole/sq/libsq"
|
|
"github.com/neilotoole/sq/libsq/core/errz"
|
|
"github.com/neilotoole/sq/libsq/core/lg"
|
|
"github.com/neilotoole/sq/libsq/core/lg/lga"
|
|
"github.com/neilotoole/sq/libsq/core/options"
|
|
"github.com/neilotoole/sq/libsq/driver"
|
|
)
|
|
|
|
// buildTableDataDiff compares the row data in td1 and td2, returning
|
|
// a tableDataDiff instance if there's a difference between the
|
|
// equivalent rows.
|
|
//
|
|
// NOTE: This is a naive implementation. It simply prints the entire
|
|
// results from each table to an output.RecordWriter, and diffs the
|
|
// resulting text. Thus, if the target tables contained a million rows
|
|
// but no diffs, each row will still have been rendered to text.
|
|
// A better mechanism might be to perform the diff against the
|
|
// raw record.Record values, and only generate the diff text if there
|
|
// are differences, and even then, to only selectively generate the
|
|
// needed text.
|
|
func buildTableDataDiff(ctx context.Context, ru *run.Run, cfg *Config,
|
|
td1, td2 *tableData,
|
|
) (*tableDataDiff, error) {
|
|
qc := run.NewQueryContext(ru, nil)
|
|
query1 := td1.src.Handle + "." + td1.tblName
|
|
query2 := td2.src.Handle + "." + td2.tblName
|
|
log := lg.FromContext(ctx).With("a", query1).With("b", query2)
|
|
|
|
pr := ru.Writers.Printing.Clone()
|
|
pr.EnableColor(false)
|
|
|
|
buf1, buf2 := &bytes.Buffer{}, &bytes.Buffer{}
|
|
w1, w2 := cfg.RecordWriterFn(buf1, pr), cfg.RecordWriterFn(buf2, pr)
|
|
recw1, recw2 := output.NewRecordWriterAdapter(ctx, w1), output.NewRecordWriterAdapter(ctx, w2)
|
|
|
|
g, gCtx := errgroup.WithContext(ctx)
|
|
g.Go(func() error {
|
|
if err := libsq.ExecuteSLQ(gCtx, qc, query1, recw1); err != nil {
|
|
if errz.IsErrNotExist(err) {
|
|
// It's totally ok if a table is not found.
|
|
log.Debug("Diff: table not found", lga.Src, td1.src, lga.Table, td1.tblName)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
_, err := recw1.Wait()
|
|
return err
|
|
})
|
|
g.Go(func() error {
|
|
if err := libsq.ExecuteSLQ(gCtx, qc, query2, recw2); err != nil {
|
|
if errz.IsErrNotExist(err) {
|
|
log.Debug("Diff: table not found", lga.Src, td2.src, lga.Table, td2.tblName)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
_, err := recw2.Wait()
|
|
return err
|
|
})
|
|
if err := g.Wait(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var (
|
|
body1, body2 = buf1.String(), buf2.String()
|
|
err error
|
|
)
|
|
|
|
edits := myers.ComputeEdits(body1, body2)
|
|
unified, err := udiff.ToUnified(
|
|
query1,
|
|
query2,
|
|
body1,
|
|
edits,
|
|
cfg.Lines,
|
|
)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
return &tableDataDiff{
|
|
td1: td1,
|
|
td2: td2,
|
|
header: fmt.Sprintf("sq diff --data %s %s", query1, query2),
|
|
diff: unified,
|
|
}, nil
|
|
}
|
|
|
|
// execSourceDataDiff executes a diff all tables found in either source.
|
|
func execSourceDataDiff(ctx context.Context, ru *run.Run, cfg *Config, sd1, sd2 *sourceData) error {
|
|
o := options.FromContext(ctx)
|
|
|
|
allTblNames := append(sd1.srcMeta.TableNames(), sd2.srcMeta.TableNames()...)
|
|
allTblNames = lo.Uniq(allTblNames)
|
|
slices.Sort(allTblNames)
|
|
|
|
diffs := make([]*tableDataDiff, len(allTblNames))
|
|
|
|
// mIndex is a map of table name to its index in allTblNames.
|
|
mIndex := make(map[string]int, len(allTblNames))
|
|
for i := range allTblNames {
|
|
mIndex[allTblNames[i]] = i
|
|
}
|
|
|
|
g, gCtx := errgroup.WithContext(ctx)
|
|
g.SetLimit(driver.OptTuningErrgroupLimit.Get(o))
|
|
diffCh := make(chan *tableDataDiff, driver.OptTuningRecChanSize.Get(o))
|
|
|
|
printErrCh := make(chan error, 1)
|
|
printIndex := 0
|
|
|
|
go func() {
|
|
// This is the read/print goroutine. Further down, in an errgroup, we spin
|
|
// up multiple goroutines to build a tableDataDiff for each table. Those
|
|
// errgroup goroutines write their results to diffCh. But they can be sent
|
|
// on diffCh in any order.
|
|
//
|
|
// This goroutine reads *tblDataDiff from diffCh, and looks up the index
|
|
// of that table in mIndex. Then, it sets diffs[i] with the tableDataDiff.
|
|
//
|
|
// Note that we want to print the diffs in the order specified by
|
|
// allTblNames. A printIndex counter is maintained. After the goroutine
|
|
// inserts a value into diffs, we then check if the next (as defined by
|
|
// printIndex) element of diffs is non-nil. If so, the goroutine prints
|
|
// that diff, and advances the printIndex counter.
|
|
|
|
defer func() { close(printErrCh) }()
|
|
|
|
var tblDataDiff *tableDataDiff
|
|
for {
|
|
select {
|
|
case <-gCtx.Done():
|
|
return
|
|
case tblDataDiff = <-diffCh:
|
|
if tblDataDiff == nil {
|
|
// Channel is closed, means we're done.
|
|
return
|
|
}
|
|
}
|
|
|
|
diffIndex, ok := mIndex[tblDataDiff.td1.tblName]
|
|
if !ok {
|
|
// Shouldn't happen
|
|
err := errz.Errorf("Index not found for table: %s", tblDataDiff.td1.tblName)
|
|
printErrCh <- err
|
|
return
|
|
}
|
|
|
|
// Put tblDataDiff into diffs.
|
|
diffs[diffIndex] = tblDataDiff
|
|
|
|
// Now check if the next diff is available to print.
|
|
for {
|
|
select {
|
|
case <-gCtx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
if printIndex >= len(diffs) {
|
|
return
|
|
}
|
|
|
|
if diffs[printIndex] == nil {
|
|
break
|
|
}
|
|
|
|
tblDataDiff = diffs[printIndex]
|
|
if err := Print(ru.Out, ru.Writers.Printing, tblDataDiff.header, tblDataDiff.diff); err != nil {
|
|
printErrCh <- err
|
|
return
|
|
}
|
|
|
|
printIndex++
|
|
}
|
|
}
|
|
}()
|
|
|
|
for _, tblName := range allTblNames {
|
|
tblName := tblName
|
|
|
|
// Add a little delay to allow the goroutine to get a head start
|
|
// on its successors. Benchmark: does this actually help?
|
|
time.Sleep(time.Microsecond * 10)
|
|
|
|
g.Go(func() error {
|
|
select {
|
|
case <-gCtx.Done():
|
|
return errz.Err(gCtx.Err())
|
|
default:
|
|
}
|
|
|
|
td1 := &tableData{
|
|
tblName: tblName,
|
|
src: sd1.src,
|
|
srcMeta: sd1.srcMeta,
|
|
}
|
|
td2 := &tableData{
|
|
tblName: tblName,
|
|
src: sd2.src,
|
|
srcMeta: sd2.srcMeta,
|
|
}
|
|
|
|
tblDataDiff, err := buildTableDataDiff(gCtx, ru, cfg, td1, td2)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
diffCh <- tblDataDiff
|
|
return nil
|
|
})
|
|
}
|
|
|
|
if err := g.Wait(); err != nil {
|
|
return err
|
|
}
|
|
close(diffCh)
|
|
|
|
// This has the effect of waiting until the print goroutine completes.
|
|
err := <-printErrCh
|
|
return err
|
|
}
|