package json

// import.go contains functionality common to the
// various JSON import mechanisms.

import (
	"bytes"
	"context"
	stdj "encoding/json"
	"io"
	"sort"
	"strings"

	"github.com/neilotoole/lg"

	"github.com/neilotoole/sq/libsq/core/errz"
	"github.com/neilotoole/sq/libsq/core/kind"
	"github.com/neilotoole/sq/libsq/core/sqlmodel"
	"github.com/neilotoole/sq/libsq/core/sqlz"
	"github.com/neilotoole/sq/libsq/core/stringz"
	"github.com/neilotoole/sq/libsq/driver"
	"github.com/neilotoole/sq/libsq/source"
)

// importJob describes a single import job, where the JSON
// at fromSrc is read via openFn and the resulting records
// are written to destDB.
type importJob struct {
	fromSrc *source.Source
	openFn  source.FileOpenFunc
	destDB  driver.Database

	// sampleSize is the maximum number of values to
	// sample to determine the kind of an element.
	sampleSize int

	// flatten specifies that the fields of nested JSON objects are
	// imported as fields of the single top-level table, with a
	// scoped column name.
	flatten bool
}

type importFunc func(ctx context.Context, log lg.Log, job importJob) error

var (
	_ importFunc = importJSON
	_ importFunc = importJSONA
	_ importFunc = importJSONL
)

// getRecMeta returns RecordMeta to use with RecordWriter.Open.
func getRecMeta(ctx context.Context, scratchDB driver.Database, tblDef *sqlmodel.TableDef) (sqlz.RecordMeta, error) {
	colTypes, err := scratchDB.SQLDriver().TableColumnTypes(ctx, scratchDB.DB(), tblDef.Name, tblDef.ColNames())
	if err != nil {
		return nil, err
	}

	destMeta, _, err := scratchDB.SQLDriver().RecordMeta(colTypes)
	if err != nil {
		return nil, err
	}

	return destMeta, nil
}

const (
	leftBrace    = stdj.Delim('{')
	rightBrace   = stdj.Delim('}')
	leftBracket  = stdj.Delim('[')
	rightBracket = stdj.Delim(']')

	// colScopeSep is used when generating flat column names. Thus
	// an entity "name.first" becomes "name_first".
	colScopeSep = "_"
)

// objectValueSet is the set of values for each of the fields of
// a top-level JSON object. It is a map of entity to a map
// of fieldName:fieldValue. For a nested JSON object, the value set
// may refer to several entities, and thus may decompose into
// insertions to several tables.
type objectValueSet map[*entity]map[string]any

// processor processes JSON objects.
type processor struct {
	// flatten, if true, indicates that the JSON object should be
	// flattened into a single table.
	flatten bool

	root   *entity
	schema *importSchema

	colNamesOrdered []string

	// schemaDirtyEntities tracks entities whose structures have
	// been modified.
	schemaDirtyEntities map[*entity]struct{}

	unwrittenObjVals []objectValueSet
	curObjVals       objectValueSet
}

func newProcessor(flatten bool) *processor {
	return &processor{
		flatten:             flatten,
		schema:              &importSchema{},
		root:                &entity{name: source.MonotableName, detectors: map[string]*kind.Detector{}},
		schemaDirtyEntities: map[*entity]struct{}{},
	}
}

func (p *processor) markSchemaDirty(e *entity) {
	p.schemaDirtyEntities[e] = struct{}{}
}

func (p *processor) markSchemaClean() {
	for k := range p.schemaDirtyEntities {
		delete(p.schemaDirtyEntities, k)
	}
}

// calcColName calculates the appropriate DB column name from
// a field. The result is different if p.flatten is true (in which
// case the column name may have a prefix derived from the entity's
// parent).
func (p *processor) calcColName(ent *entity, fieldName string) string {
	if !p.flatten {
		return fieldName
	}

	// Otherwise we namespace the column name.
	if ent.parent == nil {
		return fieldName
	}

	colName := ent.name + colScopeSep + fieldName
	return p.calcColName(ent.parent, colName)
}
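// An illustrative walk-through of calcColName (hypothetical entities;
// not from the original source). With flatten enabled, a field of a
// nested entity is prefixed with its ancestors' names, joined by
// colScopeSep:
//
//	p := newProcessor(true)
//	name := &entity{name: "name", parent: p.root}
//	p.calcColName(name, "first") // --> "name_first"
//	p.calcColName(p.root, "age") // --> "age" (root has no parent)
//
// With flatten disabled, the field name is returned unchanged.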
// buildSchemaFlat currently only builds a flat (single table) schema.
func (p *processor) buildSchemaFlat() (*importSchema, error) {
	tblDef := &sqlmodel.TableDef{
		Name: source.MonotableName,
	}

	var colDefs []*sqlmodel.ColDef

	schema := &importSchema{
		colMungeFns: map[*sqlmodel.ColDef]kind.MungeFunc{},
		entityTbls:  map[*entity]*sqlmodel.TableDef{},
		tblDefs:     []*sqlmodel.TableDef{tblDef}, // Single table only because flat
	}

	visitFn := func(e *entity) error {
		schema.entityTbls[e] = tblDef

		for _, field := range e.fieldNames {
			if detector, ok := e.detectors[field]; ok {
				// If it has a detector, it's a regular field
				k, mungeFn, err := detector.Detect()
				if err != nil {
					return errz.Err(err)
				}

				if k == kind.Null {
					k = kind.Text
				}

				colDef := &sqlmodel.ColDef{
					Name:  p.calcColName(e, field),
					Table: tblDef,
					Kind:  k,
				}

				colDefs = append(colDefs, colDef)

				if mungeFn != nil {
					schema.colMungeFns[colDef] = mungeFn
				}
				continue
			}
		}

		return nil
	}

	err := walkEntity(p.root, visitFn)
	if err != nil {
		return nil, err
	}

	// Add the column names, in the correct order
	for _, colName := range p.colNamesOrdered {
		for j := range colDefs {
			if colDefs[j].Name == colName {
				tblDef.Cols = append(tblDef.Cols, colDefs[j])
			}
		}
	}

	return schema, nil
}

// processObject processes the parsed JSON object m. If the structure
// of the importSchema changes due to this object, the returned
// dirtySchema is true.
func (p *processor) processObject(m map[string]any, chunk []byte) (dirtySchema bool, err error) {
	p.curObjVals = objectValueSet{}
	err = p.doAddObject(p.root, m)
	dirtySchema = len(p.schemaDirtyEntities) > 0
	if err != nil {
		return dirtySchema, err
	}

	p.unwrittenObjVals = append(p.unwrittenObjVals, p.curObjVals)
	p.curObjVals = nil

	if dirtySchema {
		err = p.updateColNames(chunk)
	}

	return dirtySchema, err
}

func (p *processor) updateColNames(chunk []byte) error {
	colNames, err := columnOrderFlat(chunk)
	if err != nil {
		return err
	}

	for _, colName := range colNames {
		if !stringz.InSlice(p.colNamesOrdered, colName) {
			p.colNamesOrdered = append(p.colNamesOrdered, colName)
		}
	}

	return nil
}

func (p *processor) doAddObject(ent *entity, m map[string]any) error {
	for fieldName, val := range m {
		switch val := val.(type) {
		case map[string]any:
			// time to recurse
			child := ent.getChild(fieldName)
			if child == nil {
				p.markSchemaDirty(ent)

				if !stringz.InSlice(ent.fieldNames, fieldName) {
					// The field name could already exist (even without
					// the child existing) if we encountered
					// the field before but it was nil
					ent.fieldNames = append(ent.fieldNames, fieldName)
				}

				child = &entity{
					name:      fieldName,
					parent:    ent,
					detectors: map[string]*kind.Detector{},
				}
				ent.children = append(ent.children, child)
			} else if child.isArray {
				// Child already exists
				// Safety check
				return errz.Errorf("JSON entity %q previously detected as array, but now detected as object", ent.String())
			}

			err := p.doAddObject(child, val)
			if err != nil {
				return err
			}

		case []any:
			if !stringz.InSlice(ent.fieldNames, fieldName) {
				ent.fieldNames = append(ent.fieldNames, fieldName)
			}

		default:
			// It's a regular value
			detector, ok := ent.detectors[fieldName]
			if !ok {
				p.markSchemaDirty(ent)
				if stringz.InSlice(ent.fieldNames, fieldName) {
					return errz.Errorf("JSON field %q was previously detected as a nested field (object or array)", fieldName)
				}

				ent.fieldNames = append(ent.fieldNames, fieldName)
				detector = kind.NewDetector()
				ent.detectors[fieldName] = detector
			}

			entVals := p.curObjVals[ent]
			if entVals == nil {
				entVals = map[string]any{}
				p.curObjVals[ent] = entVals
			}

			colName := p.calcColName(ent, fieldName)
			entVals[colName] = val

			val = maybeFloatToInt(val)
			detector.Sample(val)
		}
	}

	return nil
}
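// A sketch of the processing flow (hypothetical input, flatten enabled;
// not from the original source). Given the chunk:
//
//	{"a": 1, "b": {"c": true}}
//
// doAddObject records field "a" on the root entity, creates a child
// entity "b" holding field "c", and accumulates the flattened values
// {"a": 1} (for root) and {"b_c": true} (for child "b") in curObjVals.
// Because new fields were seen, processObject reports dirtySchema=true,
// and updateColNames appends "a" and "b_c" to colNamesOrdered.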
// buildInsertionsFlat builds a set of DB insertions from the
// processor's unwrittenObjVals. After a non-error return, unwrittenObjVals
// is empty.
func (p *processor) buildInsertionsFlat(schema *importSchema) ([]*insertion, error) {
	if len(schema.tblDefs) != 1 {
		return nil, errz.Errorf("expected 1 table for flat JSON processing but got %d", len(schema.tblDefs))
	}

	tblDef := schema.tblDefs[0]
	var insertions []*insertion

	// Each of unwrittenObjVals is effectively an INSERT row
	for _, objValSet := range p.unwrittenObjVals {
		var colNames []string
		colVals := map[string]any{}

		for ent, fieldVals := range objValSet {
			// For each entity, we get its values and add them to colVals.
			for colName, val := range fieldVals {
				if _, ok := colVals[colName]; ok {
					return nil, errz.Errorf("column %q already exists, but found column with same name in %q", colName, ent)
				}

				colVals[colName] = val
				colNames = append(colNames, colName)
			}
		}

		sort.Strings(colNames)
		vals := make([]any, len(colNames))
		for i, colName := range colNames {
			vals[i] = colVals[colName]
		}
		insertions = append(insertions, newInsertion(tblDef.Name, colNames, vals))
	}

	p.unwrittenObjVals = p.unwrittenObjVals[:0]

	return insertions, nil
}

// entity models the structure of a JSON entity, either an object or an array.
type entity struct {
	// isArray is true if the entity is an array, false if an object.
	isArray bool

	name     string
	parent   *entity
	children []*entity

	// fieldNames holds the names of each field. This includes simple
	// fields (such as a number or string) and nested types like
	// object or array.
	fieldNames []string

	// detectors holds a kind detector for each non-entity field
	// of entity. That is, it holds a detector for each string or number
	// field etc, but not for an object or array field.
	detectors map[string]*kind.Detector
}

func (e *entity) String() string {
	name := e.name
	if name == "" {
		name = source.MonotableName
	}

	parent := e.parent
	for parent != nil {
		name = parent.String() + "." + name
		parent = parent.parent
	}

	return name
}

// fqFieldName returns the fully-qualified field name, such
// as "data.name.first_name".
func (e *entity) fqFieldName(field string) string { //nolint:unused
	return e.String() + "." + field
}

// getChild returns the named child, or nil.
func (e *entity) getChild(name string) *entity {
	for _, child := range e.children {
		if child.name == name {
			return child
		}
	}
	return nil
}

// walkEntity visits ent and each of its descendants, depth-first,
// invoking visitFn for every entity visited.
func walkEntity(ent *entity, visitFn func(*entity) error) error {
	err := visitFn(ent)
	if err != nil {
		return err
	}

	for _, child := range ent.children {
		err = walkEntity(child, visitFn)
		if err != nil {
			return err
		}
	}

	return nil
}

// importSchema encapsulates the table definitions that
// the JSON is imported to.
type importSchema struct {
	tblDefs     []*sqlmodel.TableDef
	colMungeFns map[*sqlmodel.ColDef]kind.MungeFunc

	// entityTbls is a mapping of entity to the table in which
	// the entity's fields will be inserted.
	entityTbls map[*entity]*sqlmodel.TableDef
}
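// A sketch of how the pieces relate in the flat case (hypothetical;
// not from the original source). After buildSchemaFlat, every entity
// maps to the same single table:
//
//	schema, _ := p.buildSchemaFlat()
//	// schema.tblDefs holds exactly one TableDef, named source.MonotableName,
//	// and schema.entityTbls maps the root entity and all of its
//	// descendants to that same TableDef.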
func execSchemaDelta(ctx context.Context, log lg.Log, drvr driver.SQLDriver, db sqlz.DB,
	curSchema, newSchema *importSchema,
) error {
	var err error

	if curSchema == nil {
		for _, tblDef := range newSchema.tblDefs {
			err = drvr.CreateTable(ctx, db, tblDef)
			if err != nil {
				return err
			}

			log.Debugf("Created table %q", tblDef.Name)
		}

		return nil
	}

	return errz.New("schema delta not yet implemented")
}

// columnOrderFlat parses the JSON chunk and returns a slice
// containing column names, in the order they appear in chunk.
// Nested fields are flattened, e.g.:
//
//	{"a":1, "b": {"c":2, "d":3}}  -->  ["a", "b_c", "b_d"]
func columnOrderFlat(chunk []byte) ([]string, error) {
	dec := stdj.NewDecoder(bytes.NewReader(chunk))

	var (
		cols  []string
		stack []string
		tok   stdj.Token
		err   error
	)

	// Get the opening left-brace
	_, err = requireDelimToken(dec, leftBrace)
	if err != nil {
		return nil, err
	}

loop:
	for {
		// Expect tok to be a field name, or else the terminating right-brace.
		tok, err = dec.Token()
		if err != nil {
			if err == io.EOF { //nolint:errorlint
				break
			}
			return nil, errz.Err(err)
		}

		switch tok := tok.(type) {
		case string:
			// tok is a field name
			stack = append(stack, tok)

		case stdj.Delim:
			if tok == rightBrace {
				if len(stack) == 0 {
					// This is the terminating right-brace
					break loop
				}

				// Else we've come to the end of an object
				stack = stack[:len(stack)-1]
				continue
			}

		default:
			return nil, errz.Errorf("expected string field name but got %T: %s", tok, formatToken(tok))
		}

		// We've consumed the field name above, now let's see what
		// the next token is
		tok, err = dec.Token()
		if err != nil {
			return nil, errz.Err(err)
		}

		switch tok := tok.(type) {
		default:
			// This next token was a regular old value.
			// The field name is already on the stack. We generate
			// the column name...
			cols = append(cols, strings.Join(stack, colScopeSep))
			// And pop the stack.
			stack = stack[0 : len(stack)-1]

		case stdj.Delim:
			// The next token was a delimiter.
			if tok == leftBrace {
				// It's the start of a nested object.
				// Back to the top of the loop we go, so that
				// we can descend into the nested object.
				continue loop
			}

			if tok == leftBracket {
				// It's the start of an array.
				// Note that we don't descend into arrays.
				cols = append(cols, strings.Join(stack, colScopeSep))
				stack = stack[0 : len(stack)-1]

				err = decoderFindArrayClose(dec)
				if err != nil {
					return nil, err
				}
			}
		}
	}

	return cols, nil
}

// decoderFindArrayClose advances dec until a closing
// right-bracket ']' is located at the correct nesting level.
// The most recently returned decoder token should have been
// the opening left-bracket '['.
func decoderFindArrayClose(dec *stdj.Decoder) error {
	var depth int
	var tok stdj.Token
	var err error

	for {
		tok, err = dec.Token()
		if err != nil {
			break
		}

		if tok == leftBracket {
			// Nested array
			depth++
			continue
		}

		if tok == rightBracket {
			if depth == 0 {
				return nil
			}
			depth--
		}
	}

	return errz.Err(err)
}
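// An illustrative trace of columnOrderFlat (hypothetical input; not
// from the original source):
//
//	{"a": 1, "tags": ["x", ["y"]], "b": {"c": 2}}  -->  ["a", "tags", "b_c"]
//
// The "tags" array becomes a single column: columnOrderFlat does not
// descend into it, and decoderFindArrayClose skips its contents,
// including the nested array ["y"].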
// execInsertions performs db INSERT for each of the insertions.
func execInsertions(ctx context.Context, log lg.Log, drvr driver.SQLDriver, db sqlz.DB, insertions []*insertion) error {
	// FIXME: This is an inefficient way of performing insertion.
	// We should be re-using the prepared statement, and probably
	// should batch the inserts as well. See driver.BatchInsert.

	var err error
	var execer *driver.StmtExecer

	for _, insert := range insertions {
		execer, err = drvr.PrepareInsertStmt(ctx, db, insert.tbl, insert.cols, 1)
		if err != nil {
			return err
		}

		err = execer.Munge(insert.vals)
		if err != nil {
			log.WarnIfCloseError(execer)
			return err
		}

		_, err = execer.Exec(ctx, insert.vals...)
		if err != nil {
			log.WarnIfCloseError(execer)
			return err
		}

		err = execer.Close()
		if err != nil {
			return err
		}
	}

	return nil
}

type insertion struct {
	// stmtKey is a concatenation of tbl and cols that can
	// uniquely identify a db insert statement.
	stmtKey string

	tbl  string
	cols []string
	vals []any
}

func newInsertion(tbl string, cols []string, vals []any) *insertion {
	return &insertion{
		stmtKey: buildInsertStmtKey(tbl, cols),
		tbl:     tbl,
		cols:    cols,
		vals:    vals,
	}
}

// buildInsertStmtKey returns a concatenation of tbl and cols that can
// uniquely identify a db insert statement.
func buildInsertStmtKey(tbl string, cols []string) string {
	return tbl + "__" + strings.Join(cols, "_")
}
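// For example (hypothetical values):
//
//	ins := newInsertion("data", []string{"a", "b_c"}, []any{1, true})
//	// ins.stmtKey --> "data__a_b_c"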