sq/drivers/csv/insert.go
Neil O'Toole 2898a92983
Refactor sqlmodel pkg (#364)
* Refactor sqlmodel pkg
2024-01-25 00:42:51 -07:00

141 lines
3.3 KiB
Go

package csv
import (
"context"
"encoding/csv"
"errors"
"io"
"github.com/neilotoole/sq/libsq"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/kind"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/record"
"github.com/neilotoole/sq/libsq/core/schema"
"github.com/neilotoole/sq/libsq/driver"
)
// execInsert inserts the CSV records in readAheadRecs (followed by records
// from the csv.Reader) via recw. The caller should wait on recw to complete.
//
// Each CSV record is converted via mungers (see mungeCSV2InsertRecord)
// before being sent to recw. Reading stops at io.EOF, at which point nil
// is returned and the deferred close of the record channel tells recw that
// no more records are coming. On any failure (munge error, recw error,
// context cancellation, or CSV read error) the context handed to recw is
// cancelled before returning, so that recw aborts rather than treating the
// channel close as a normal end of input.
func execInsert(ctx context.Context, recw libsq.RecordWriter, recMeta record.Meta,
	mungers []kind.MungeFunc, readAheadRecs [][]string, r *csv.Reader,
) error {
	ctx, cancelFn := context.WithCancel(ctx)
	// We don't do "defer cancelFn" here. The cancelFn is passed
	// to recw. We only call it ourselves on failure paths.
	recordCh, errCh, err := recw.Open(ctx, cancelFn, recMeta)
	if err != nil {
		// recw never took ownership of the context; release it here so it
		// doesn't leak. Calling cancelFn is idempotent.
		cancelFn()
		return err
	}
	defer close(recordCh)

	// send munges csvRec and delivers the result to recw. Every error path
	// cancels ctx first, so that the deferred close(recordCh) is not
	// mistaken by recw for a successful end of records.
	send := func(csvRec []string) error {
		rec, mungeErr := mungeCSV2InsertRecord(ctx, mungers, csvRec)
		if mungeErr != nil {
			cancelFn()
			return mungeErr
		}

		select {
		case writeErr := <-errCh:
			cancelFn()
			return writeErr
		case <-ctx.Done():
			cancelFn()
			return ctx.Err()
		case recordCh <- rec:
			return nil
		}
	}

	// Before we continue reading from CSV, we first write out
	// any CSV records we read earlier.
	for i := range readAheadRecs {
		if err = send(readAheadRecs[i]); err != nil {
			return err
		}
	}

	var csvRecord []string
	for {
		csvRecord, err = r.Read()
		if errors.Is(err, io.EOF) {
			// We're done reading.
			return nil
		}
		if err != nil {
			cancelFn()
			return errz.Wrap(err, "read from CSV data source")
		}

		if err = send(csvRecord); err != nil {
			return err
		}
	}
}
// mungeCSV2InsertRecord converts csvRec into a newly allocated []any of the
// same length. For field i:
//   - if i has no corresponding munger (i >= len(mungers)), an error is
//     logged and the slot is left nil;
//   - if mungers[i] is nil, the raw string is stored unchanged;
//   - otherwise the value returned by mungers[i] is stored.
//
// The first munger error aborts the conversion and is returned together
// with a nil slice.
func mungeCSV2InsertRecord(ctx context.Context, mungers []kind.MungeFunc, csvRec []string) ([]any, error) {
	vals := make([]any, len(csvRec))
	for idx, field := range csvRec {
		switch {
		case idx >= len(mungers):
			lg.FromContext(ctx).Error("no munger for field", lga.Index, idx, lga.Val, field)
			// Deliberately log rather than panic or return an error: in
			// future we may be able to handle ragged-edge records.
		case mungers[idx] == nil:
			vals[idx] = field
		default:
			v, err := mungers[idx](field)
			if err != nil {
				return nil, err
			}
			vals[idx] = v
		}
	}
	return vals, nil
}
// createTblDef builds a *schema.Table named tblName with one column per
// entry in colNames. Column i is given kind kinds[i] and a back-reference
// to the returned table. kinds must be at least as long as colNames.
func createTblDef(tblName string, colNames []string, kinds []kind.Kind) *schema.Table {
	tbl := &schema.Table{Name: tblName}
	tbl.Cols = make([]*schema.Column, 0, len(colNames))
	for i, name := range colNames {
		tbl.Cols = append(tbl.Cols, &schema.Column{Table: tbl, Name: name, Kind: kinds[i]})
	}
	return tbl
}
// getIngestRecMeta returns the record.Meta to use with RecordWriter.Open
// when ingesting into tblDef. It asks destGrip's SQL driver for the column
// types of tblDef's columns, then derives the record metadata from them.
// Any error from the DB handle, the column-type lookup, or the metadata
// derivation is returned as-is with a nil Meta.
func getIngestRecMeta(ctx context.Context, destGrip driver.Grip, tblDef *schema.Table) (record.Meta, error) {
	db, err := destGrip.DB(ctx)
	if err != nil {
		return nil, err
	}

	sqlDrvr := destGrip.SQLDriver()

	colTypes, err := sqlDrvr.TableColumnTypes(ctx, db, tblDef.Name, tblDef.ColNames())
	if err != nil {
		return nil, err
	}

	// Only the Meta is needed here; RecordMeta's second result is discarded.
	meta, _, err := sqlDrvr.RecordMeta(ctx, colTypes)
	if err != nil {
		return nil, err
	}
	return meta, nil
}