sq/drivers/csv/insert.go
2023-04-09 19:29:13 -06:00

137 lines
3.3 KiB
Go

package csv
import (
"context"
"encoding/csv"
"errors"
"io"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/driver"
"github.com/neilotoole/sq/libsq/core/kind"
"github.com/neilotoole/sq/libsq/core/sqlmodel"
"github.com/neilotoole/sq/libsq"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/sqlz"
)
// execInsert inserts the CSV records in readAheadRecs (followed by records
// from the csv.Reader) via recw. The caller should wait on recw to complete.
func execInsert(ctx context.Context, recw libsq.RecordWriter, recMeta sqlz.RecordMeta,
mungers []kind.MungeFunc, readAheadRecs [][]string, r *csv.Reader,
) error {
ctx, cancelFn := context.WithCancel(ctx)
// We don't do "defer cancelFn" here. The cancelFn is passed
// to recw.
recordCh, errCh, err := recw.Open(ctx, cancelFn, recMeta)
if err != nil {
return err
}
defer close(recordCh)
// Before we continue reading from CSV, we first write out
// any CSV records we read earlier.
for i := range readAheadRecs {
var rec []any
if rec, err = mungeCSV2InsertRecord(ctx, mungers, readAheadRecs[i]); err != nil {
return err
}
select {
case err = <-errCh:
cancelFn()
return err
case <-ctx.Done():
cancelFn()
return ctx.Err()
case recordCh <- rec:
}
}
var csvRecord []string
for {
csvRecord, err = r.Read()
if errors.Is(err, io.EOF) {
// We're done reading
return nil
}
if err != nil {
cancelFn()
return errz.Wrap(err, "read from CSV data source")
}
rec, err := mungeCSV2InsertRecord(ctx, mungers, csvRecord)
if err != nil {
return err
}
select {
case err = <-errCh:
cancelFn()
return err
case <-ctx.Done():
cancelFn()
return ctx.Err()
case recordCh <- rec:
}
}
}
// mungeCSV2InsertRecord returns a new []any containing
// the values of the csvRec []string.
func mungeCSV2InsertRecord(ctx context.Context, mungers []kind.MungeFunc, csvRec []string) ([]any, error) {
var err error
a := make([]any, len(csvRec))
for i := range csvRec {
if i >= len(mungers) {
lg.FromContext(ctx).Error("no munger for field", lga.Index, i, lga.Val, csvRec[i])
// Maybe should panic here, or return an error?
// But, in future we may be able to handle ragged-edge records,
// so maybe logging the error is best.
continue
}
if mungers[i] != nil {
a[i], err = mungers[i](csvRec[i])
if err != nil {
return nil, err
}
} else {
a[i] = csvRec[i]
}
}
return a, nil
}
func createTblDef(tblName string, colNames []string, kinds []kind.Kind) *sqlmodel.TableDef {
tbl := &sqlmodel.TableDef{Name: tblName}
cols := make([]*sqlmodel.ColDef, len(colNames))
for i := range colNames {
cols[i] = &sqlmodel.ColDef{Table: tbl, Name: colNames[i], Kind: kinds[i]}
}
tbl.Cols = cols
return tbl
}
// getRecMeta returns RecordMeta to use with RecordWriter.Open.
func getRecMeta(ctx context.Context, scratchDB driver.Database, tblDef *sqlmodel.TableDef) (sqlz.RecordMeta, error) {
colTypes, err := scratchDB.SQLDriver().TableColumnTypes(ctx, scratchDB.DB(), tblDef.Name, tblDef.ColNames())
if err != nil {
return nil, err
}
destMeta, _, err := scratchDB.SQLDriver().RecordMeta(colTypes)
if err != nil {
return nil, err
}
return destMeta, nil
}