sq/drivers/userdriver/xmlud/xmlimport.go

607 lines
16 KiB
Go
Raw Normal View History

2020-08-06 20:58:47 +03:00
// Package xmlud provides user driver XML import functionality.
// Note that this implementation is experimental, not well-tested,
// inefficient, possibly incomprehensible, and subject to change.
package xmlud
import (
	"context"
	"encoding/xml"
	"fmt"
	"io"
	"sort"
	"strconv"
	"strings"

	"golang.org/x/exp/slog"

	"github.com/neilotoole/sq/drivers/userdriver"
	"github.com/neilotoole/sq/libsq/core/cleanup"
	"github.com/neilotoole/sq/libsq/core/errz"
	"github.com/neilotoole/sq/libsq/core/kind"
	"github.com/neilotoole/sq/libsq/core/lg"
	"github.com/neilotoole/sq/libsq/core/lg/lga"
	"github.com/neilotoole/sq/libsq/core/sqlmodel"
	"github.com/neilotoole/sq/libsq/driver"
	"github.com/neilotoole/sq/libsq/source"
)
// Genre is the user driver genre that this package supports. Import
// rejects any driver definition whose Genre field differs from this value.
const Genre = "xml"
// Import implements userdriver.ImportFunc.
func Import(ctx context.Context, def *userdriver.DriverDef, data io.Reader, destDB driver.Database) error {
2020-08-06 20:58:47 +03:00
if def.Genre != Genre {
return errz.Errorf("xmlud.Import does not support genre {%s}", def.Genre)
2020-08-06 20:58:47 +03:00
}
im := &importer{
log: lg.From(ctx),
2020-08-06 20:58:47 +03:00
def: def,
selStack: newSelStack(),
rowStack: newRowStack(),
tblDefs: map[string]*sqlmodel.TableDef{},
tblSequence: map[string]int64{},
2022-12-17 02:34:33 +03:00
execInsertFns: map[string]func(ctx context.Context, insertVals []any) error{},
execUpdateFns: map[string]func(ctx context.Context, updateVals, whereArgs []any) error{},
2020-08-06 20:58:47 +03:00
clnup: cleanup.New(),
msgOnce: map[string]struct{}{},
}
err := im.execImport(ctx, data, destDB)
err2 := im.clnup.Run()
if err != nil {
return errz.Wrap(err, "xml import")
}
return errz.Wrap(err2, "xml import: cleanup")
}
// importer does the work of importing data from XML.
type importer struct {
log *slog.Logger
2020-08-06 20:58:47 +03:00
def *userdriver.DriverDef
data io.Reader
destDB driver.Database
selStack *selStack
rowStack *rowStack
tblDefs map[string]*sqlmodel.TableDef
// tblSequence is a map of table name to the last
// insert ID value for that table. See dbInsert for more.
tblSequence map[string]int64
2023-03-15 10:43:48 +03:00
// execInsertFns is a map of a table+cols key to a func for inserting
2020-08-06 20:58:47 +03:00
// vals. Effectively it can be considered a cache of prepared insert
// statements. See the dbInsert function.
2022-12-17 02:34:33 +03:00
execInsertFns map[string]func(ctx context.Context, vals []any) error
2020-08-06 20:58:47 +03:00
// execUpdateFns is similar to execInsertFns, but for UPDATE instead
// of INSERT. The whereArgs param is the arguments for the
// update's WHERE clause.
execUpdateFns map[string]func(ctx context.Context, updateVals, whereArgs []any) error
2020-08-06 20:58:47 +03:00
// clnup holds cleanup funcs that should be run when the importer
// finishes.
clnup *cleanup.Cleanup
// msgOnce is used by method msgOncef.
msgOnce map[string]struct{}
}
func (im *importer) execImport(ctx context.Context, r io.Reader, destDB driver.Database) error { //nolint:gocognit
2020-08-06 20:58:47 +03:00
im.data, im.destDB = r, destDB
err := im.createTables(ctx)
if err != nil {
return err
}
decoder := xml.NewDecoder(im.data)
for {
t, err := decoder.Token()
if t == nil {
break
}
if err != nil {
return errz.Err(err)
}
switch elem := t.(type) {
case xml.StartElement:
im.selStack.push(elem.Name.Local)
if im.isRootSelector() {
continue
}
if im.isRowSelector() {
// We found a new row...
prevRow := im.rowStack.peek()
if prevRow != nil {
// Because the new row might require the primary key of the prev row,
// we need to save the previous row, to ensure its primary key is
// generated.
err = im.saveRow(ctx, prevRow)
if err != nil {
return err
}
}
var curRow *rowState
curRow, err = im.buildRow()
if err != nil {
return err
}
im.rowStack.push(curRow)
err = im.handleElemAttrs(elem, curRow)
if err != nil {
return err
}
continue
}
// It's not a row element, it's a col element
curRow := im.rowStack.peek()
if curRow == nil {
return errz.Errorf("unable to parse XML: no current row on stack for elem {%s}", elem.Name.Local)
2020-08-06 20:58:47 +03:00
}
col := curRow.tbl.ColBySelector(im.selStack.selector())
if col == nil {
if msg, ok := im.msgOncef("Skip: element {%s} is not a column of table {%s}", elem.Name.Local,
curRow.tbl.Name); ok {
2020-08-06 20:58:47 +03:00
im.log.Debug(msg)
}
continue
}
curRow.curCol = col
err = im.handleElemAttrs(elem, curRow)
if err != nil {
return err
}
case xml.EndElement:
if im.isRowSelector() {
row := im.rowStack.peek()
if row.dirty() {
err = im.saveRow(ctx, row)
if err != nil {
return err
}
}
im.rowStack.pop()
}
im.selStack.pop()
case xml.CharData:
data := string(elem)
curRow := im.rowStack.peek()
if curRow == nil {
continue
}
if curRow.curCol == nil {
continue
}
val, err := im.convertVal(curRow.tbl.Name, curRow.curCol, data)
if err != nil {
return err
}
curRow.dirtyColVals[curRow.curCol.Name] = val
curRow.curCol = nil
}
}
return nil
}
2022-12-17 02:34:33 +03:00
func (im *importer) convertVal(tbl string, col *userdriver.ColMapping, data any) (any, error) {
2020-08-06 20:58:47 +03:00
const errTpl = `conversion error: %s.%s: expected "%s" but got %T(%v)`
const errTplMsg = `conversion error: %s.%s: expected "%s" but got %T(%v): %v`
switch col.Kind { //nolint:exhaustive
2020-08-06 20:58:47 +03:00
default:
return nil, errz.Errorf("unknown data kind {%s} for col %s", col.Kind, col.Name)
case kind.Text, kind.Time:
2020-08-06 20:58:47 +03:00
return data, nil
case kind.Int:
2020-08-06 20:58:47 +03:00
switch data := data.(type) {
case int, int32, int64:
return data, nil
case string:
val, err := strconv.ParseInt(data, 0, 64)
if err != nil {
return nil, errz.Errorf(errTplMsg, tbl, col.Name, col.Kind, data, data, err)
}
return val, nil
default:
return nil, errz.Errorf(errTpl, tbl, col.Name, col.Kind, data, data)
}
case kind.Float:
2020-08-06 20:58:47 +03:00
switch data := data.(type) {
case float32, float64:
return data, nil
case string:
val, err := strconv.ParseFloat(data, 64)
if err != nil {
return nil, errz.Errorf(errTplMsg, tbl, col.Name, col.Kind, data, data, err)
}
return val, nil
default:
return nil, errz.Errorf(errTpl, tbl, col.Name, col.Kind, data, data)
}
case kind.Decimal:
2020-08-06 20:58:47 +03:00
return data, nil
case kind.Bool:
2020-08-06 20:58:47 +03:00
switch data := data.(type) {
case bool:
return data, nil
case int, int32, int64:
if data == 0 {
return false, nil
}
return true, nil
case string:
val, err := strconv.ParseBool(data)
if err != nil {
return nil, errz.Errorf(errTplMsg, tbl, col.Name, col.Kind, data, data, err)
}
return val, nil
default:
return nil, errz.Errorf(errTpl, tbl, col.Name, col.Kind, data, data)
}
case kind.Datetime, kind.Date:
2020-08-06 20:58:47 +03:00
return data, nil
case kind.Bytes:
2020-08-06 20:58:47 +03:00
return data, nil
case kind.Null:
2020-08-06 20:58:47 +03:00
return data, nil
}
}
func (im *importer) handleElemAttrs(elem xml.StartElement, curRow *rowState) error {
if len(elem.Attr) > 0 {
baseSel := im.selStack.selector()
for _, attr := range elem.Attr {
attrSel := baseSel + "/@" + attr.Name.Local
attrCol := curRow.tbl.ColBySelector(attrSel)
if attrCol == nil {
if msg, ok := im.msgOncef("Skip: attr {%s} is not a column of table {%s}", attrSel, curRow.tbl.Name); ok {
im.log.Debug(msg)
2020-08-06 20:58:47 +03:00
}
continue
}
// We have found the col matching the attribute
val, err := im.convertVal(curRow.tbl.Name, attrCol, attr.Value)
if err != nil {
return err
}
curRow.dirtyColVals[attrCol.Name] = val
}
}
return nil
}
// setForeignColsVals sets the values of any column that needs to be
// populated from a foreign key.
func (im *importer) setForeignColsVals(row *rowState) error {
// check if we need to populate any of the row's values with
// foreign key data (e.g. from parent table).
for _, col := range row.tbl.Cols {
if col.Foreign == "" {
continue
}
// yep, we need to add a foreign key
parts := strings.Split(col.Foreign, "/")
// parts will look like [ "..", "channel_id" ]
if len(parts) != 2 || parts[0] != ".." {
return errz.Errorf(`%s.%s: "foreign" field should be of form "../col_name" but was {%s}`, row.tbl.Name,
col.Name, col.Foreign)
2020-08-06 20:58:47 +03:00
}
fkName := parts[1]
parentRow := im.rowStack.peekN(1)
if parentRow == nil {
return errz.Errorf("unable to find parent() table for foreign key for %s.%s", row.tbl.Name, col.Name)
}
fkVal, ok := parentRow.savedColVals[fkName]
if !ok {
return errz.Errorf(`%s.%s: unable to find foreign key value in parent table {%s}`, row.tbl.Name, col.Name,
parentRow.tbl.Name)
2020-08-06 20:58:47 +03:00
}
row.dirtyColVals[col.Name] = fkVal
}
return nil
}
func (im *importer) setSequenceColsVals(row *rowState, nextSeqVal int64) {
seqColNames := userdriver.NamesFromCols(row.tbl.SequenceCols())
for _, seqColName := range seqColNames {
if _, ok := row.savedColVals[seqColName]; ok {
// This seq col has already been saved
continue
}
if _, ok := row.dirtyColVals[seqColName]; ok {
// Hmmmn... seqColName is already present. This shouldn't happen,
// as the point of a sequence col is to auto-generate the col
// value. The input data is inconsistent, or at least, it
// clashes with the user driver def.
//
// We could override the value, or trust the input.
//
// But given that the seqCol is typically the primary key,
// trusting the input could cause a constraint violation
// if a subsequent row doesn't have a value for the seqCol.
//
// Probably safer to override the value.
row.dirtyColVals[seqColName] = nextSeqVal
im.log.Warn("%s.%s is a auto-generated sequence() column: ignoring the value found in input",
2020-08-06 20:58:47 +03:00
row.tbl.Name, seqColName)
continue
}
// Else, this seq col has not yet been saved
row.dirtyColVals[seqColName] = nextSeqVal
}
}
func (im *importer) saveRow(ctx context.Context, row *rowState) error {
if !row.dirty() {
return nil
}
tblDef, ok := im.tblDefs[row.tbl.Name]
if !ok {
return errz.Errorf("unable to find definition for table {%s}", row.tbl.Name)
2020-08-06 20:58:47 +03:00
}
if row.created() {
// Row already exists in the db
err := im.dbUpdate(ctx, row)
if err != nil {
return errz.Wrapf(err, "failed to update table {%s}", tblDef.Name)
2020-08-06 20:58:47 +03:00
}
row.markDirtyAsSaved()
return nil
}
// We're going to INSERT the row.
// Maintain the table's sequence. Note that we always increment the
// seq val even if there are no sequence cols for this table.
prevSeqVal := im.tblSequence[tblDef.Name]
nextSeqVal := prevSeqVal + 1
im.tblSequence[tblDef.Name] = nextSeqVal
im.setSequenceColsVals(row, nextSeqVal)
// Collection any foreign cols
2020-08-06 20:58:47 +03:00
err := im.setForeignColsVals(row)
if err != nil {
return err
}
// Verify that all required cols are present
for _, requiredCol := range row.tbl.RequiredCols() {
if _, ok = row.dirtyColVals[requiredCol.Name]; !ok {
return errz.Errorf("no value for required column %s.%s", row.tbl.Name, requiredCol.Name)
}
}
err = im.dbInsert(ctx, row)
if err != nil {
return errz.Wrapf(err, "failed to insert to table {%s}", tblDef.Name)
2020-08-06 20:58:47 +03:00
}
row.markDirtyAsSaved()
return nil
}
// dbInsert inserts row's dirty col values to row's table.
func (im *importer) dbInsert(ctx context.Context, row *rowState) error {
tblName := row.tbl.Name
colNames := make([]string, len(row.dirtyColVals))
2022-12-17 02:34:33 +03:00
vals := make([]any, len(row.dirtyColVals))
2020-08-06 20:58:47 +03:00
i := 0
for k, v := range row.dirtyColVals {
colNames[i], vals[i] = k, v
i++
}
// We cache the prepared insert statements.
cacheKey := "##insert_func__" + tblName + "__" + strings.Join(colNames, ",")
execInsertFn, ok := im.execInsertFns[cacheKey]
if !ok {
// Nothing cached, prepare the insert statement and insert munge func
stmtExecer, err := im.destDB.SQLDriver().PrepareInsertStmt(ctx, im.destDB.DB(), tblName, colNames, 1)
2020-08-06 20:58:47 +03:00
if err != nil {
return err
}
// Make sure we close stmt eventually.
im.clnup.AddC(stmtExecer)
2022-12-17 02:34:33 +03:00
execInsertFn = func(ctx context.Context, vals []any) error {
2020-08-06 20:58:47 +03:00
// Munge vals so that they're as the target DB expects
err = stmtExecer.Munge(vals)
if err != nil {
return err
}
_, err = stmtExecer.Exec(ctx, vals...)
return errz.Err(err)
}
// Cache the execInsertFn.
im.execInsertFns[cacheKey] = execInsertFn
}
err := execInsertFn(ctx, vals)
if err != nil {
return err
}
return nil
}
// dbUpdate updates row's table with row's dirty values, using row's
// primary key cols as the args to the WHERE clause.
func (im *importer) dbUpdate(ctx context.Context, row *rowState) error {
drvr := im.destDB.SQLDriver()
tblName := row.tbl.Name
pkColNames := row.tbl.PrimaryKey
var whereBuilder strings.Builder
2022-12-17 02:34:33 +03:00
var pkVals []any
2020-08-06 20:58:47 +03:00
for i, pkColName := range pkColNames {
if pkVal, ok := row.savedColVals[pkColName]; ok {
pkVals = append(pkVals, pkVal)
if i > 0 {
whereBuilder.WriteString(" AND ")
}
whereBuilder.WriteString(drvr.Dialect().Enquote(pkColName))
whereBuilder.WriteString(" = ?")
continue
}
// Else, we're missing a pk val
return errz.Errorf("failed to update table {%s}: primary key value {%s} not present", tblName, pkColName)
2020-08-06 20:58:47 +03:00
}
whereClause := whereBuilder.String()
colNames := make([]string, len(row.dirtyColVals))
2022-12-17 02:34:33 +03:00
dirtyVals := make([]any, len(row.dirtyColVals))
2020-08-06 20:58:47 +03:00
i := 0
for k, v := range row.dirtyColVals {
colNames[i], dirtyVals[i] = k, v
i++
}
// We cache the prepared statement.
cacheKey := "##update_func__" + tblName + "__" + strings.Join(colNames, ",") + whereClause
execUpdateFn, ok := im.execUpdateFns[cacheKey]
if !ok {
// Nothing cached, prepare the update statement and munge func
stmtExecer, err := drvr.PrepareUpdateStmt(ctx, im.destDB.DB(), tblName, colNames, whereClause)
if err != nil {
return err
}
// Make sure we close stmt eventually.
im.clnup.AddC(stmtExecer)
execUpdateFn = func(ctx context.Context, updateVals, whereArgs []any) error {
2020-08-06 20:58:47 +03:00
// Munge vals so that they're as the target DB expects
err := stmtExecer.Munge(updateVals)
if err != nil {
return err
}
// Append the WHERE clause args
updateVals = append(updateVals, whereArgs...)
_, err = stmtExecer.Exec(ctx, updateVals...)
2020-08-06 20:58:47 +03:00
return errz.Err(err)
}
// Cache the execInsertFn.
im.execUpdateFns[cacheKey] = execUpdateFn
}
err := execUpdateFn(ctx, dirtyVals, pkVals)
if err != nil {
return err
}
return nil
}
func (im *importer) buildRow() (*rowState, error) {
tbl := im.def.TableBySelector(im.selStack.selector())
if tbl == nil {
return nil, errz.Errorf("no tbl matching current selector: %s", im.selStack.selector())
}
r := &rowState{tbl: tbl}
2022-12-17 02:34:33 +03:00
r.dirtyColVals = make(map[string]any)
r.savedColVals = make(map[string]any)
2020-08-06 20:58:47 +03:00
for i := range r.tbl.Cols {
// If the table has a column that has a "text()" selector, then we need to capture the
// next CharData token, so we mark that col as the current col.
if strings.HasSuffix(r.tbl.Cols[i].Selector, "text()") {
r.curCol = r.tbl.Cols[i]
break
}
}
return r, nil
}
func (im *importer) createTables(ctx context.Context) error {
for i := range im.def.Tables {
tblDef, err := userdriver.ToTableDef(im.def.Tables[i])
if err != nil {
return err
}
im.tblDefs[tblDef.Name] = tblDef
err = im.destDB.SQLDriver().CreateTable(ctx, im.destDB.DB(), tblDef)
if err != nil {
return err
}
im.log.Debug("Created table", lga.Target, source.Target(im.destDB.Source(), tblDef.Name))
2020-08-06 20:58:47 +03:00
}
return nil
}
// isRootSelector reports whether the current selector (per the
// selector stack) is equal to the driver def's root selector.
func (im *importer) isRootSelector() bool {
	current := im.selStack.selector()
	return current == im.def.Selector
}
// isRowSelector reports whether the entity referred to by the current
// selector maps to a table row (as opposed to a column). That is, it
// reports whether the driver def has a table for the current selector.
func (im *importer) isRowSelector() bool {
	tbl := im.def.TableBySelector(im.selStack.selector())
	return tbl != nil
}
// msgOncef is used to prevent repeated logging of a message. The
// method returns ok=true and the formatted string if the formatted
// string has not been previous seen by msgOncef.
2022-12-17 02:34:33 +03:00
func (im *importer) msgOncef(format string, a ...any) (msg string, ok bool) {
2020-08-06 20:58:47 +03:00
msg = fmt.Sprintf(format, a...)
if _, exists := im.msgOnce[msg]; exists {
// msg already seen, return ok=false.
return "", false
}
im.msgOnce[msg] = struct{}{}
return msg, true
}