sq/cli/output/adapter.go
2023-04-09 19:29:13 -06:00

182 lines
4.5 KiB
Go

package output
import (
"context"
"sync"
"time"
"go.uber.org/atomic"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/sqlz"
)
// RecordWriterAdapter implements libsq.RecordWriter and
// wraps an output.RecordWriter instance, providing a
// bridge between the asynchronous libsq.RecordWriter and
// synchronous output.RecordWriter interfaces.
//
// Note that a writer implementation such as the JSON or
// CSV writer could directly implement libsq.RecordWriter.
// But that interface is non-trivial to implement, hence
// this bridge type.
//
// The FlushAfterN and FlushAfterDuration fields control
// flushing of the writer.
type RecordWriterAdapter struct {
rw RecordWriter
wg *sync.WaitGroup
recCh chan sqlz.Record
errCh chan error
errs []error
written *atomic.Int64
cancelFn context.CancelFunc
// FlushAfterN indicates that the writer's Flush method
// should be invoked after N invocations of WriteRecords.
// A value of 0 will flush every time a record is written.
// Set to -1 to disable.
FlushAfterN int64
// FlushAfterDuration controls whether the writer's Flush method
// is invoked periodically. A duration <= 0 disables periodic flushing.
FlushAfterDuration time.Duration
}
// adapterRecChSize is the size of the record chan (effectively
// the buffer) used by RecordWriterAdapter.
// FIXME: adapterRecChSize should be user-configurable.
const adapterRecChSize = 1000
// NewRecordWriterAdapter returns a new RecordWriterAdapter.
func NewRecordWriterAdapter(rw RecordWriter) *RecordWriterAdapter {
recCh := make(chan sqlz.Record, adapterRecChSize)
return &RecordWriterAdapter{rw: rw, recCh: recCh, wg: &sync.WaitGroup{}, written: atomic.NewInt64(0)}
}
// Open implements libsq.RecordWriter.
func (w *RecordWriterAdapter) Open(ctx context.Context, cancelFn context.CancelFunc,
recMeta sqlz.RecordMeta,
) (chan<- sqlz.Record, <-chan error, error) {
w.cancelFn = cancelFn
err := w.rw.Open(recMeta)
if err != nil {
return nil, nil, err
}
// errCh has size 2 because that's the maximum number of
// errs that could be sent. Typically only one err is sent,
// but in the case of ctx.Done, we send ctx.Err, followed
// by any error returned by r.rw.Close.
w.errCh = make(chan error, 2)
w.wg.Add(1)
go func() {
defer func() {
w.wg.Done()
close(w.errCh)
}()
var lastFlushN, recN int64
var flushTimer *time.Timer
var flushCh <-chan time.Time
if w.FlushAfterDuration > 0 {
flushTimer = time.NewTimer(w.FlushAfterDuration)
flushCh = flushTimer.C
defer flushTimer.Stop()
}
for {
select {
case <-ctx.Done():
w.addErrs(ctx.Err(), w.rw.Close())
return
case <-flushCh:
// The flushTimer has expired, time to flush.
err = w.rw.Flush()
if err != nil {
w.addErrs(err)
return
}
lastFlushN = recN
flushTimer.Reset(w.FlushAfterDuration)
continue
case rec := <-w.recCh:
if rec == nil { // no more results on recCh, it has been closed
err = w.rw.Close()
if err != nil {
w.addErrs()
}
return
}
// rec is not nil, therefore we write it out.
// We could accumulate a bunch of recs into a slice here,
// but we'll worry about that if benchmarking shows it'll matter.
writeErr := w.rw.WriteRecords([]sqlz.Record{rec})
if writeErr != nil {
w.addErrs(writeErr)
return
}
recN = w.written.Inc()
// Check if we should flush
if w.FlushAfterN >= 0 && (recN-lastFlushN >= w.FlushAfterN) {
err = w.rw.Flush()
if err != nil {
w.addErrs(err)
return
}
lastFlushN = recN
if flushTimer != nil {
// Reset the timer, but we need to stop and drain it first.
// See the timer.Reset docs.
if !flushTimer.Stop() {
<-flushTimer.C
}
flushTimer.Reset(w.FlushAfterDuration)
}
}
// If we got this far, we successfully wrote rec to rw.
// Therefore continue to wait/select for the next
// element on recCh (or for recCh to close)
// or for ctx.Done indicating timeout or cancel etc.
continue
}
}
}()
return w.recCh, w.errCh, nil
}
// Wait implements libsq.RecordWriter.
func (w *RecordWriterAdapter) Wait() (written int64, err error) {
w.wg.Wait()
if w.cancelFn != nil {
w.cancelFn()
}
return w.written.Load(), errz.Combine(w.errs...)
}
// addErrs handles any non-nil err in errs by appending it to w.errs
// and sending it on w.errCh.
func (w *RecordWriterAdapter) addErrs(errs ...error) {
for _, err := range errs {
if err != nil {
w.errs = append(w.errs, err)
w.errCh <- err
}
}
}