mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-18 21:52:28 +03:00
196 lines
5.5 KiB
Go
196 lines
5.5 KiB
Go
package libsq
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"sync"
|
|
|
|
"github.com/neilotoole/lg"
|
|
"go.uber.org/atomic"
|
|
|
|
"github.com/neilotoole/sq/libsq/driver"
|
|
"github.com/neilotoole/sq/libsq/errz"
|
|
"github.com/neilotoole/sq/libsq/sqlz"
|
|
)
|
|
|
|
const (
	// DefaultRecordChSize is the default buffer capacity of a record
	// channel.
	DefaultRecordChSize = 100
)
|
|
|
|
// DBWriter implements RecordWriter, writing
|
|
// records to a database table.
|
|
type DBWriter struct {
|
|
log lg.Log
|
|
wg *sync.WaitGroup
|
|
cancelFn context.CancelFunc
|
|
destDB driver.Database
|
|
destTbl string
|
|
recordCh chan sqlz.Record
|
|
written *atomic.Int64
|
|
errCh chan error
|
|
errs []error
|
|
|
|
// preWriteHook, when non-nil, is invoked by the Open method before any
|
|
// records are written. This is useful when the recMeta or tx are
|
|
// needed to perform actions before insertion, such as creating
|
|
// the dest table on the fly.
|
|
preWriteHook func(ctx context.Context, recMeta sqlz.RecordMeta, tx sqlz.DB) error
|
|
}
|
|
|
|
// NewDBWriter returns a new writer than implements RecordWriter.
|
|
// The writer writes records from recordCh to destTbl
|
|
// in destDB. The recChSize param controls the size of recordCh
|
|
// returned by the writer's Open method.
|
|
func NewDBWriter(log lg.Log, destDB driver.Database, destTbl string, recChSize int) *DBWriter {
|
|
return &DBWriter{
|
|
log: log,
|
|
destDB: destDB,
|
|
destTbl: destTbl,
|
|
recordCh: make(chan sqlz.Record, recChSize),
|
|
errCh: make(chan error, 3),
|
|
written: atomic.NewInt64(0),
|
|
wg: &sync.WaitGroup{},
|
|
}
|
|
|
|
// Note: errCh has size 3 because that's the maximum number of
|
|
// errs that could be sent. Frequently only one err is sent,
|
|
// but sometimes there are additional errs, e.g. when
|
|
// ctx is done, we send ctx.Err, followed by any rollback err.
|
|
}
|
|
|
|
// Open implements RecordWriter.
|
|
func (w *DBWriter) Open(ctx context.Context, cancelFn context.CancelFunc, recMeta sqlz.RecordMeta) (chan<- sqlz.Record, <-chan error, error) {
|
|
w.cancelFn = cancelFn
|
|
|
|
// REVISIT: tx could potentially be passed to NewDBWriter?
|
|
tx, err := w.destDB.DB().BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, nil, errz.Wrapf(err, "failed to open tx for %s.%s", w.destDB.Source().Handle, w.destTbl)
|
|
}
|
|
|
|
if w.preWriteHook != nil {
|
|
err = w.preWriteHook(ctx, recMeta, tx)
|
|
if err != nil {
|
|
w.rollback(tx, err)
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
// insertStmt, _, _, mungeFn, err := w.destDB.SQLDriver().PrepareInsertStmt(ctx, tx, w.destTbl, recMeta.Names())
|
|
stmtExecer, err := w.destDB.SQLDriver().PrepareInsertStmt(ctx, tx, w.destTbl, recMeta.Names())
|
|
if err != nil {
|
|
w.rollback(tx, err)
|
|
return nil, nil, err
|
|
}
|
|
|
|
w.wg.Add(1)
|
|
go func() {
|
|
defer func() {
|
|
// When the inserter goroutine finishes:
|
|
// - we close the errCh (and indicator that the writer is done)
|
|
// - and mark the wg as done, which the Wait method depends upon.
|
|
close(w.errCh)
|
|
w.wg.Done()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
// ctx is done (e.g. cancelled), so we're going to rollback.
|
|
w.rollback(tx, ctx.Err())
|
|
return
|
|
|
|
case rec := <-w.recordCh:
|
|
if rec == nil {
|
|
// No more results on recordCh, it has been closed.
|
|
// It's time to commit the tx.
|
|
// Note that Commit automatically closes any stmts
|
|
// that were prepared by tx.
|
|
commitErr := errz.Err(tx.Commit())
|
|
if commitErr != nil {
|
|
w.log.Error(commitErr)
|
|
w.addErrs(commitErr)
|
|
} else {
|
|
w.log.Debugf("Tx commit success for %s.%s", w.destDB.Source().Handle, w.destTbl)
|
|
}
|
|
return
|
|
}
|
|
|
|
// rec is not nil, therefore we write it out
|
|
err = w.doInsert(ctx, stmtExecer, rec)
|
|
if err != nil {
|
|
w.rollback(tx, err)
|
|
return
|
|
}
|
|
|
|
// Otherwise, we successfully wrote rec to tx.
|
|
// Therefore continue to wait/select for the next
|
|
// element on recordCh (or for recordCh to close)
|
|
// or for ctx.Done indicating timeout or cancel etc.
|
|
}
|
|
}
|
|
|
|
}()
|
|
|
|
return w.recordCh, w.errCh, nil
|
|
}
|
|
|
|
// Wait implements RecordWriter.
|
|
func (w *DBWriter) Wait() (written int64, err error) {
|
|
w.wg.Wait()
|
|
if w.cancelFn != nil {
|
|
w.cancelFn()
|
|
}
|
|
return w.written.Load(), errz.Combine(w.errs...)
|
|
}
|
|
|
|
// addErrs handles any non-nil err in errs by appending it to w.errs
|
|
// and sending it on w.errCh.
|
|
func (w *DBWriter) addErrs(errs ...error) {
|
|
for _, err := range errs {
|
|
if err != nil {
|
|
w.errs = append(w.errs, err)
|
|
w.errCh <- err
|
|
}
|
|
}
|
|
}
|
|
|
|
// rollback rolls back tx. Note that rollback or commit of the tx
|
|
// will close all of the tx's prepared statements, so we don't
|
|
// need to close those manually.
|
|
func (w *DBWriter) rollback(tx *sql.Tx, causeErrs ...error) {
|
|
// Guaranteed to be at least one causeErr
|
|
w.log.Errorf("failed to insert to %s.%s: tx rollback due to: %s",
|
|
w.destDB.Source().Handle, w.destTbl, causeErrs[0])
|
|
|
|
rollbackErr := errz.Err(tx.Rollback())
|
|
w.log.WarnIfError(rollbackErr)
|
|
|
|
w.addErrs(causeErrs...)
|
|
w.addErrs(rollbackErr)
|
|
}
|
|
|
|
func (w *DBWriter) doInsert(ctx context.Context, stmtExecer *driver.StmtExecer, rec sqlz.Record) error {
|
|
err := stmtExecer.Munge(rec)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
affected, err := stmtExecer.Exec(ctx, rec...)
|
|
if err != nil {
|
|
// NOTE: in the scenario where we're inserting into
|
|
// a SQLite db, and there's multiple writers (inserters) to
|
|
// the same db, a "database is locked" error from SQLite is
|
|
// possible. See https://github.com/mattn/go-sqlite3/issues/274
|
|
// Perhaps there's a sensible way to handle such an error that
|
|
// could be tackled here.
|
|
return errz.Err(err)
|
|
}
|
|
|
|
if affected != 1 {
|
|
w.log.Warnf("expected 1 affected row for insert, but got %d", affected)
|
|
}
|
|
|
|
w.written.Add(affected)
|
|
return nil
|
|
}
|