// Package xmlud provides user driver XML import functionality. // Note that this implementation is experimental, not well-tested, // inefficient, possibly incomprehensible, and subject to change. package xmlud import ( "context" "encoding/xml" "fmt" "io" "strconv" "strings" "github.com/neilotoole/lg" "github.com/neilotoole/sq/drivers/userdriver" "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/sqlmodel" "github.com/neilotoole/sq/libsq/driver" ) // Genre is the user driver genre that this package supports. const Genre = "xml" // Import implements userdriver.ImportFunc. func Import(ctx context.Context, log lg.Log, def *userdriver.DriverDef, data io.Reader, destDB driver.Database) error { if def.Genre != Genre { return errz.Errorf("xmlud.Import does not support genre %q", def.Genre) } im := &importer{ log: log, def: def, selStack: newSelStack(), rowStack: newRowStack(), tblDefs: map[string]*sqlmodel.TableDef{}, tblSequence: map[string]int64{}, execInsertFns: map[string]func(ctx context.Context, insertVals []any) error{}, execUpdateFns: map[string]func(ctx context.Context, updateVals []any, whereArgs []any) error{}, clnup: cleanup.New(), msgOnce: map[string]struct{}{}, } err := im.execImport(ctx, data, destDB) err2 := im.clnup.Run() if err != nil { return errz.Wrap(err, "xml import") } return errz.Wrap(err2, "xml import: cleanup") } // importer does the work of importing data from XML. type importer struct { log lg.Log def *userdriver.DriverDef data io.Reader destDB driver.Database selStack *selStack rowStack *rowStack tblDefs map[string]*sqlmodel.TableDef // tblSequence is a map of table name to the last // insert ID value for that table. See dbInsert for more. tblSequence map[string]int64 // execInsertFns is a map of a table+cols key to an func for inserting // vals. Effectively it can be considered a cache of prepared insert // statements. 
See the dbInsert function. execInsertFns map[string]func(ctx context.Context, vals []any) error // execUpdateFns is similar to execInsertFns, but for UPDATE instead // of INSERT. The whereArgs param is the arguments for the // update's WHERE clause. execUpdateFns map[string]func(ctx context.Context, updateVals []any, whereArgs []any) error // clnup holds cleanup funcs that should be run when the importer // finishes. clnup *cleanup.Cleanup // msgOnce is used by method msgOncef. msgOnce map[string]struct{} } func (im *importer) execImport(ctx context.Context, r io.Reader, destDB driver.Database) error { im.data, im.destDB = r, destDB err := im.createTables(ctx) if err != nil { return err } decoder := xml.NewDecoder(im.data) for { t, err := decoder.Token() if t == nil { break } if err != nil { return errz.Err(err) } switch elem := t.(type) { case xml.StartElement: im.selStack.push(elem.Name.Local) if im.isRootSelector() { continue } if im.isRowSelector() { // We found a new row... prevRow := im.rowStack.peek() if prevRow != nil { // Because the new row might require the primary key of the prev row, // we need to save the previous row, to ensure its primary key is // generated. 
err = im.saveRow(ctx, prevRow) if err != nil { return err } } var curRow *rowState curRow, err = im.buildRow() if err != nil { return err } im.rowStack.push(curRow) err = im.handleElemAttrs(elem, curRow) if err != nil { return err } continue } // It's not a row element, it's a col element curRow := im.rowStack.peek() if curRow == nil { return errz.Errorf("unable to parse XML: no current row on stack for elem %q", elem.Name.Local) } col := curRow.tbl.ColBySelector(im.selStack.selector()) if col == nil { if msg, ok := im.msgOncef("Skip: element %q is not a column of table %q", elem.Name.Local, curRow.tbl.Name); ok { im.log.Debug(msg) } continue } curRow.curCol = col err = im.handleElemAttrs(elem, curRow) if err != nil { return err } case xml.EndElement: if im.isRowSelector() { row := im.rowStack.peek() if row.dirty() { err = im.saveRow(ctx, row) if err != nil { return err } } im.rowStack.pop() } im.selStack.pop() case xml.CharData: data := string(elem) curRow := im.rowStack.peek() if curRow == nil { continue } if curRow.curCol == nil { continue } val, err := im.convertVal(curRow.tbl.Name, curRow.curCol, data) if err != nil { return err } curRow.dirtyColVals[curRow.curCol.Name] = val curRow.curCol = nil } } return nil } func (im *importer) convertVal(tbl string, col *userdriver.ColMapping, data any) (any, error) { const errTpl = `conversion error: %s.%s: expected "%s" but got %T(%v)` const errTplMsg = `conversion error: %s.%s: expected "%s" but got %T(%v): %v` switch col.Kind { //nolint:exhaustive default: return nil, errz.Errorf("unknown data kind %q for col %s", col.Kind, col.Name) case kind.Text, kind.Time: return data, nil case kind.Int: switch data := data.(type) { case int, int32, int64: return data, nil case string: val, err := strconv.ParseInt(data, 0, 64) if err != nil { return nil, errz.Errorf(errTplMsg, tbl, col.Name, col.Kind, data, data, err) } return val, nil default: return nil, errz.Errorf(errTpl, tbl, col.Name, col.Kind, data, data) } case kind.Float: 
switch data := data.(type) { case float32, float64: return data, nil case string: val, err := strconv.ParseFloat(data, 64) if err != nil { return nil, errz.Errorf(errTplMsg, tbl, col.Name, col.Kind, data, data, err) } return val, nil default: return nil, errz.Errorf(errTpl, tbl, col.Name, col.Kind, data, data) } case kind.Decimal: return data, nil case kind.Bool: switch data := data.(type) { case bool: return data, nil case int, int32, int64: if data == 0 { return false, nil } return true, nil case string: val, err := strconv.ParseBool(data) if err != nil { return nil, errz.Errorf(errTplMsg, tbl, col.Name, col.Kind, data, data, err) } return val, nil default: return nil, errz.Errorf(errTpl, tbl, col.Name, col.Kind, data, data) } case kind.Datetime, kind.Date: return data, nil case kind.Bytes: return data, nil case kind.Null: return data, nil } } func (im *importer) handleElemAttrs(elem xml.StartElement, curRow *rowState) error { if len(elem.Attr) > 0 { baseSel := im.selStack.selector() for _, attr := range elem.Attr { attrSel := baseSel + "/@" + attr.Name.Local attrCol := curRow.tbl.ColBySelector(attrSel) if attrCol == nil { if msg, ok := im.msgOncef("Skip: attr %q is not a column of table %q", attrSel, curRow.tbl.Name); ok { im.log.Debugf(msg) } continue } // We have found the col matching the attribute val, err := im.convertVal(curRow.tbl.Name, attrCol, attr.Value) if err != nil { return err } curRow.dirtyColVals[attrCol.Name] = val } } return nil } // setForeignColsVals sets the values of any column that needs to be // populated from a foreign key. func (im *importer) setForeignColsVals(row *rowState) error { // check if we need to populate any of the row's values with // foreign key data (e.g. from parent table). for _, col := range row.tbl.Cols { if col.Foreign == "" { continue } // yep, we need to add a foreign key parts := strings.Split(col.Foreign, "/") // parts will look like [ "..", "channel_id" ] if len(parts) != 2 || parts[0] != ".." 
{ return errz.Errorf(`%s.%s: "foreign" field should be of form "../col_name" but was %q`, row.tbl.Name, col.Name, col.Foreign) } fkName := parts[1] parentRow := im.rowStack.peekN(1) if parentRow == nil { return errz.Errorf("unable to find parent() table for foreign key for %s.%s", row.tbl.Name, col.Name) } fkVal, ok := parentRow.savedColVals[fkName] if !ok { return errz.Errorf(`%s.%s: unable to find foreign key value in parent table %q`, row.tbl.Name, col.Name, parentRow.tbl.Name) } row.dirtyColVals[col.Name] = fkVal } return nil } func (im *importer) setSequenceColsVals(row *rowState, nextSeqVal int64) { seqColNames := userdriver.NamesFromCols(row.tbl.SequenceCols()) for _, seqColName := range seqColNames { if _, ok := row.savedColVals[seqColName]; ok { // This seq col has already been saved continue } if _, ok := row.dirtyColVals[seqColName]; ok { // Hmmmn... seqColName is already present. This shouldn't happen, // as the point of a sequence col is to auto-generate the col // value. The input data is inconsistent, or at least, it // clashes with the user driver def. // // We could override the value, or trust the input. // // But given that the seqCol is typically the primary key, // trusting the input could cause a constraint violation // if a subsequent row doesn't have a value for the seqCol. // // Probably safer to override the value. 
row.dirtyColVals[seqColName] = nextSeqVal im.log.Warnf("%s.%s is a auto-generated sequence() column: ignoring the value found in input", row.tbl.Name, seqColName) continue } // Else, this seq col has not yet been saved row.dirtyColVals[seqColName] = nextSeqVal } } func (im *importer) saveRow(ctx context.Context, row *rowState) error { if !row.dirty() { return nil } tblDef, ok := im.tblDefs[row.tbl.Name] if !ok { return errz.Errorf("unable to find definition for table %q", row.tbl.Name) } if row.created() { // Row already exists in the db err := im.dbUpdate(ctx, row) if err != nil { return errz.Wrapf(err, "failed to update table %q", tblDef.Name) } row.markDirtyAsSaved() return nil } // We're going to INSERT the row. // Maintain the table's sequence. Note that we always increment the // seq val even if there are no sequence cols for this table. prevSeqVal := im.tblSequence[tblDef.Name] nextSeqVal := prevSeqVal + 1 im.tblSequence[tblDef.Name] = nextSeqVal im.setSequenceColsVals(row, nextSeqVal) // Set any foreign cols err := im.setForeignColsVals(row) if err != nil { return err } // Verify that all required cols are present for _, requiredCol := range row.tbl.RequiredCols() { if _, ok = row.dirtyColVals[requiredCol.Name]; !ok { return errz.Errorf("no value for required column %s.%s", row.tbl.Name, requiredCol.Name) } } err = im.dbInsert(ctx, row) if err != nil { return errz.Wrapf(err, "failed to insert to table %q", tblDef.Name) } row.markDirtyAsSaved() return nil } // dbInsert inserts row's dirty col values to row's table. func (im *importer) dbInsert(ctx context.Context, row *rowState) error { tblName := row.tbl.Name colNames := make([]string, len(row.dirtyColVals)) vals := make([]any, len(row.dirtyColVals)) i := 0 for k, v := range row.dirtyColVals { colNames[i], vals[i] = k, v i++ } // We cache the prepared insert statements. 
cacheKey := "##insert_func__" + tblName + "__" + strings.Join(colNames, ",") execInsertFn, ok := im.execInsertFns[cacheKey] if !ok { // Nothing cached, prepare the insert statement and insert munge func stmtExecer, err := im.destDB.SQLDriver().PrepareInsertStmt(ctx, im.destDB.DB(), tblName, colNames, 1) if err != nil { return err } // Make sure we close stmt eventually. im.clnup.AddC(stmtExecer) execInsertFn = func(ctx context.Context, vals []any) error { // Munge vals so that they're as the target DB expects err = stmtExecer.Munge(vals) if err != nil { return err } _, err = stmtExecer.Exec(ctx, vals...) return errz.Err(err) } // Cache the execInsertFn. im.execInsertFns[cacheKey] = execInsertFn } err := execInsertFn(ctx, vals) if err != nil { return err } return nil } // dbUpdate updates row's table with row's dirty values, using row's // primary key cols as the args to the WHERE clause. func (im *importer) dbUpdate(ctx context.Context, row *rowState) error { drvr := im.destDB.SQLDriver() tblName := row.tbl.Name pkColNames := row.tbl.PrimaryKey var whereBuilder strings.Builder var pkVals []any for i, pkColName := range pkColNames { if pkVal, ok := row.savedColVals[pkColName]; ok { pkVals = append(pkVals, pkVal) if i > 0 { whereBuilder.WriteString(" AND ") } whereBuilder.WriteString(drvr.Dialect().Enquote(pkColName)) whereBuilder.WriteString(" = ?") continue } // Else, we're missing a pk val return errz.Errorf("failed to update table %q: primary key value %q not present", tblName, pkColName) } whereClause := whereBuilder.String() colNames := make([]string, len(row.dirtyColVals)) dirtyVals := make([]any, len(row.dirtyColVals)) i := 0 for k, v := range row.dirtyColVals { colNames[i], dirtyVals[i] = k, v i++ } // We cache the prepared statement. 
cacheKey := "##update_func__" + tblName + "__" + strings.Join(colNames, ",") + whereClause execUpdateFn, ok := im.execUpdateFns[cacheKey] if !ok { // Nothing cached, prepare the update statement and munge func stmtExecer, err := drvr.PrepareUpdateStmt(ctx, im.destDB.DB(), tblName, colNames, whereClause) if err != nil { return err } // Make sure we close stmt eventually. im.clnup.AddC(stmtExecer) execUpdateFn = func(ctx context.Context, updateVals []any, whereArgs []any) error { // Munge vals so that they're as the target DB expects err := stmtExecer.Munge(updateVals) if err != nil { return err } // Append the WHERE clause args updateVals = append(updateVals, whereArgs...) _, err = stmtExecer.Exec(ctx, updateVals...) return errz.Err(err) } // Cache the execInsertFn. im.execUpdateFns[cacheKey] = execUpdateFn } err := execUpdateFn(ctx, dirtyVals, pkVals) if err != nil { return err } return nil } func (im *importer) buildRow() (*rowState, error) { tbl := im.def.TableBySelector(im.selStack.selector()) if tbl == nil { return nil, errz.Errorf("no tbl matching current selector: %s", im.selStack.selector()) } r := &rowState{tbl: tbl} r.dirtyColVals = make(map[string]any) r.savedColVals = make(map[string]any) for i := range r.tbl.Cols { // If the table has a column that has a "text()" selector, then we need to capture the // next CharData token, so we mark that col as the current col. if strings.HasSuffix(r.tbl.Cols[i].Selector, "text()") { r.curCol = r.tbl.Cols[i] break } } return r, nil } func (im *importer) createTables(ctx context.Context) error { for i := range im.def.Tables { tblDef, err := userdriver.ToTableDef(im.def.Tables[i]) if err != nil { return err } im.tblDefs[tblDef.Name] = tblDef err = im.destDB.SQLDriver().CreateTable(ctx, im.destDB.DB(), tblDef) if err != nil { return err } im.log.Debugf("Created table %s.%s", im.destDB.Source().Handle, tblDef.Name) } return nil } // isRootSelector returns true if the current selector matches the root selector. 
func (im *importer) isRootSelector() bool {
	current := im.selStack.selector()
	return current == im.def.Selector
}

// isRowSelector reports whether the entity named by the current selector
// is a table row, i.e. whether im.def has a table for that selector, as
// opposed to the entity being a column.
func (im *importer) isRowSelector() bool {
	tbl := im.def.TableBySelector(im.selStack.selector())
	return tbl != nil
}

// msgOncef formats a message and reports whether it is new. The first time
// a given formatted string is produced, it is recorded in im.msgOnce and
// returned with ok=true. Subsequent identical messages return ("", false),
// which lets callers avoid logging the same message repeatedly.
func (im *importer) msgOncef(format string, a ...any) (msg string, ok bool) {
	formatted := fmt.Sprintf(format, a...)
	if _, seen := im.msgOnce[formatted]; !seen {
		im.msgOnce[formatted] = struct{}{}
		return formatted, true
	}

	// Already reported once.
	return "", false
}