mirror of
https://github.com/neilotoole/sq.git
synced 2024-11-23 19:33:22 +03:00
* JSON large data file
This commit is contained in:
parent
f445f67b44
commit
5f531ffb9d
@ -145,8 +145,8 @@ type objectValueSet map[*entity]map[string]any
|
||||
|
||||
// processor process JSON objects.
|
||||
type processor struct {
|
||||
root *entity
|
||||
importSchema *ingestSchema
|
||||
root *entity
|
||||
curSchema *ingestSchema
|
||||
|
||||
// schemaDirtyEntities tracks entities whose structure have been modified.
|
||||
schemaDirtyEntities map[*entity]struct{}
|
||||
@ -162,9 +162,13 @@ type processor struct {
|
||||
|
||||
func newProcessor(flatten bool) *processor {
|
||||
return &processor{
|
||||
flatten: flatten,
|
||||
importSchema: &ingestSchema{},
|
||||
root: &entity{name: source.MonotableName, detectors: map[string]*kind.Detector{}},
|
||||
flatten: flatten,
|
||||
curSchema: nil,
|
||||
root: &entity{
|
||||
name: source.MonotableName,
|
||||
detectors: map[string]*kind.Detector{},
|
||||
kinds: map[string]kind.Kind{},
|
||||
},
|
||||
schemaDirtyEntities: map[*entity]struct{}{},
|
||||
}
|
||||
}
|
||||
@ -263,7 +267,7 @@ func (p *processor) buildSchemaFlat() (*ingestSchema, error) {
|
||||
// processObject processes the parsed JSON object m. If the structure
|
||||
// of the ingestSchema changes due to this object, dirtySchema returns true.
|
||||
func (p *processor) processObject(m map[string]any, chunk []byte) (dirtySchema bool, err error) {
|
||||
p.curObjVals = objectValueSet{}
|
||||
p.curObjVals = make(objectValueSet)
|
||||
err = p.doAddObject(p.root, m)
|
||||
dirtySchema = len(p.schemaDirtyEntities) > 0
|
||||
if err != nil {
|
||||
@ -296,6 +300,8 @@ func (p *processor) updateColNames(chunk []byte) error {
|
||||
}
|
||||
|
||||
func (p *processor) doAddObject(ent *entity, m map[string]any) error {
|
||||
var err error
|
||||
|
||||
for fieldName, val := range m {
|
||||
switch val := val.(type) {
|
||||
case map[string]any:
|
||||
@ -315,6 +321,7 @@ func (p *processor) doAddObject(ent *entity, m map[string]any) error {
|
||||
name: fieldName,
|
||||
parent: ent,
|
||||
detectors: map[string]*kind.Detector{},
|
||||
kinds: map[string]kind.Kind{},
|
||||
}
|
||||
ent.children = append(ent.children, child)
|
||||
} else if child.isArray {
|
||||
@ -324,8 +331,7 @@ func (p *processor) doAddObject(ent *entity, m map[string]any) error {
|
||||
ent.String())
|
||||
}
|
||||
|
||||
err := p.doAddObject(child, val)
|
||||
if err != nil {
|
||||
if err = p.doAddObject(child, val); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -358,14 +364,64 @@ func (p *processor) doAddObject(ent *entity, m map[string]any) error {
|
||||
colName := p.calcColName(ent, fieldName)
|
||||
entVals[colName] = val
|
||||
|
||||
val = maybeFloatToInt(val)
|
||||
detector.Sample(val)
|
||||
colDef := p.getColDef(ent, colName)
|
||||
|
||||
if colDef == nil && val != nil {
|
||||
val = maybeFloatToInt(val)
|
||||
// We don't need to keep sampling after we've detected the kind.
|
||||
detector.Sample(val)
|
||||
} else
|
||||
// REVISIT: We don't need to hold onto the samples after we've detected
|
||||
// the kind, it's just holding onto memory. We should probably nil out
|
||||
// the detector.
|
||||
|
||||
// The column is already defined. Check if the value is allowed.
|
||||
if !p.fieldValAllowed(detector, colDef, val) {
|
||||
p.markSchemaDirty(ent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *processor) fieldValAllowed(detector *kind.Detector, col *schema.Column, val any) bool {
|
||||
if val == nil || col == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if col.Kind == kind.Null || col.Kind == kind.Unknown || col.Kind == kind.Text {
|
||||
return true
|
||||
}
|
||||
|
||||
detector.Sample(val)
|
||||
k, _, err := detector.Detect()
|
||||
if err != nil || k != col.Kind {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// getColDef returns the schema.Column, or nil if not existing.
|
||||
func (p *processor) getColDef(ent *entity, colName string) *schema.Column {
|
||||
if p == nil || p.curSchema == nil || p.curSchema.entityTbls == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tblDef, ok := p.curSchema.entityTbls[ent]
|
||||
if !ok || tblDef == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
colDef, err := tblDef.FindCol(colName)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return colDef
|
||||
}
|
||||
|
||||
// buildInsertionsFlat builds a set of DB insertions from the
|
||||
// processor's unwrittenObjVals. After a non-error return, unwrittenObjVals
|
||||
// is empty.
|
||||
@ -417,6 +473,10 @@ type entity struct {
|
||||
// field etc, but not for an object or array field.
|
||||
detectors map[string]*kind.Detector
|
||||
|
||||
// kinds is the sibling of detectors, holding a kind.Kind for each field,
|
||||
// once the detector has detected the kind.
|
||||
kinds map[string]kind.Kind
|
||||
|
||||
name string
|
||||
children []*entity
|
||||
|
||||
@ -487,31 +547,117 @@ type ingestSchema struct {
|
||||
tblDefs []*schema.Table
|
||||
}
|
||||
|
||||
func (s *ingestSchema) getTable(name string) *schema.Table {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, tbl := range s.tblDefs {
|
||||
if tbl.Name == name {
|
||||
return tbl
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// execSchemaDelta executes the schema delta between curSchema and newSchema.
|
||||
// That is, if curSchema is nil, then newSchema is created in the DB; if
|
||||
// newSchema has additional tables or columns, then those are created in the DB.
|
||||
//
|
||||
// TODO: execSchemaDelta is only partially implemented; it doesn't create
|
||||
// the new tables/columns.
|
||||
func execSchemaDelta(ctx context.Context, drvr driver.SQLDriver, db sqlz.DB,
|
||||
curSchema, newSchema *ingestSchema,
|
||||
) error {
|
||||
log := lg.FromContext(ctx)
|
||||
var err error
|
||||
|
||||
if curSchema == nil {
|
||||
for _, tblDef := range newSchema.tblDefs {
|
||||
err = drvr.CreateTable(ctx, db, tblDef)
|
||||
for _, tbl := range newSchema.tblDefs {
|
||||
err = drvr.CreateTable(ctx, db, tbl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Created table", lga.Table, tblDef.Name)
|
||||
log.Debug("Created table", lga.Table, tbl.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: implement execSchemaDelta fully.
|
||||
return errz.New("schema delta not yet implemented")
|
||||
var alterTbls []*schema.Table
|
||||
var createTbls []*schema.Table
|
||||
|
||||
for _, newTbl := range newSchema.tblDefs {
|
||||
oldTbl := curSchema.getTable(newTbl.Name)
|
||||
if oldTbl == nil {
|
||||
createTbls = append(createTbls, newTbl)
|
||||
} else if !oldTbl.Equal(newTbl) {
|
||||
alterTbls = append(alterTbls, newTbl)
|
||||
}
|
||||
}
|
||||
|
||||
for _, wantTbl := range alterTbls {
|
||||
oldTbl := curSchema.getTable(wantTbl.Name)
|
||||
if err = execMaybeAlterTable(ctx, drvr, db, oldTbl, wantTbl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, wantTbl := range createTbls {
|
||||
err = drvr.CreateTable(ctx, db, wantTbl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Created table", lga.Table, wantTbl.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func execMaybeAlterTable(ctx context.Context, drvr driver.SQLDriver, db sqlz.DB,
|
||||
oldTbl, newTbl *schema.Table,
|
||||
) error {
|
||||
log := lg.FromContext(ctx)
|
||||
if newTbl == nil {
|
||||
return nil
|
||||
}
|
||||
if oldTbl == nil {
|
||||
return drvr.CreateTable(ctx, db, newTbl)
|
||||
}
|
||||
|
||||
if oldTbl.Equal(newTbl) {
|
||||
return nil
|
||||
}
|
||||
|
||||
tblName := newTbl.Name
|
||||
|
||||
var createCols []*schema.Column
|
||||
var wantAlterColNames []string
|
||||
var wantAlterColKinds []kind.Kind
|
||||
|
||||
for _, newCol := range newTbl.Cols {
|
||||
oldCol, err := oldTbl.FindCol(newCol.Name)
|
||||
if err != nil {
|
||||
createCols = append(createCols, newCol)
|
||||
} else if newCol.Kind != oldCol.Kind {
|
||||
wantAlterColNames = append(wantAlterColNames, newCol.Name)
|
||||
wantAlterColKinds = append(wantAlterColKinds, newCol.Kind)
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantAlterColNames) > 0 {
|
||||
err := drvr.AlterTableColumnKinds(ctx, db, tblName, wantAlterColNames, wantAlterColKinds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, col := range createCols {
|
||||
if err := drvr.AlterTableAddColumn(ctx, db, tblName, col.Name, col.Kind); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug("Added column", lga.Table, newTbl.Name, lga.Col, col.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// columnOrderFlat parses the json chunk and returns a slice
|
||||
|
@ -6,12 +6,14 @@ import (
|
||||
stdj "encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/neilotoole/sq/libsq/core/errz"
|
||||
"github.com/neilotoole/sq/libsq/core/lg"
|
||||
"github.com/neilotoole/sq/libsq/core/lg/lga"
|
||||
"github.com/neilotoole/sq/libsq/core/lg/lgm"
|
||||
"github.com/neilotoole/sq/libsq/core/progress"
|
||||
"github.com/neilotoole/sq/libsq/core/stringz"
|
||||
"github.com/neilotoole/sq/libsq/files"
|
||||
"github.com/neilotoole/sq/libsq/source/drivertype"
|
||||
@ -105,7 +107,7 @@ func DetectJSON(sampleSize int) files.TypeDetectFunc {
|
||||
}
|
||||
defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r2)
|
||||
|
||||
sc := newObjectInArrayScanner(r2)
|
||||
sc := newObjectInArrayScanner(log, r2)
|
||||
var validObjCount int
|
||||
var obj map[string]any
|
||||
|
||||
@ -140,6 +142,9 @@ func DetectJSON(sampleSize int) files.TypeDetectFunc {
|
||||
}
|
||||
|
||||
func ingestJSON(ctx context.Context, job *ingestJob) error {
|
||||
bar := progress.FromContext(ctx).NewUnitCounter("Ingest JSON", "object")
|
||||
defer bar.Stop()
|
||||
|
||||
log := lg.FromContext(ctx)
|
||||
defer lg.WarnIfCloseError(log, "Close JSON ingest job", job)
|
||||
|
||||
@ -163,13 +168,12 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
|
||||
defer lg.WarnIfCloseError(log, lgm.CloseDB, conn)
|
||||
|
||||
proc := newProcessor(job.flatten)
|
||||
scan := newObjectInArrayScanner(r)
|
||||
scan := newObjectInArrayScanner(log, r)
|
||||
|
||||
var (
|
||||
obj map[string]any
|
||||
chunk []byte
|
||||
schemaModified bool
|
||||
curSchema *ingestSchema
|
||||
insertions []*insertion
|
||||
hasMore bool
|
||||
)
|
||||
@ -182,11 +186,14 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
|
||||
|
||||
// obj is returned nil by scan.next when end of input.
|
||||
hasMore = obj != nil
|
||||
if hasMore {
|
||||
bar.Incr(1)
|
||||
}
|
||||
|
||||
if schemaModified {
|
||||
if !hasMore || scan.objCount >= job.sampleSize {
|
||||
log.Debug("Time to (re)build the schema", lga.Line, scan.objCount)
|
||||
if curSchema == nil {
|
||||
if proc.curSchema == nil {
|
||||
log.Debug("First time building the schema")
|
||||
}
|
||||
|
||||
@ -195,7 +202,7 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = execSchemaDelta(ctx, drvr, conn, curSchema, newSchema); err != nil {
|
||||
if err = execSchemaDelta(ctx, drvr, conn, proc.curSchema, newSchema); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -203,10 +210,11 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
|
||||
// so we mark it as clean.
|
||||
proc.markSchemaClean()
|
||||
|
||||
curSchema = newSchema
|
||||
// curSchema = newSchema
|
||||
proc.curSchema = newSchema
|
||||
newSchema = nil //nolint:wastedassign
|
||||
|
||||
if insertions, err = proc.buildInsertionsFlat(curSchema); err != nil {
|
||||
if insertions, err = proc.buildInsertionsFlat(proc.curSchema); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -214,11 +222,11 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasMore {
|
||||
// We're done
|
||||
break
|
||||
}
|
||||
if !hasMore {
|
||||
// end of input
|
||||
break
|
||||
}
|
||||
|
||||
if schemaModified, err = proc.processObject(obj, chunk); err != nil {
|
||||
@ -227,7 +235,7 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
|
||||
|
||||
// Initial schema has not been created: we're still in
|
||||
// the sampling phase. So we loop.
|
||||
if curSchema == nil {
|
||||
if proc.curSchema == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -240,7 +248,7 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
|
||||
|
||||
// The schema exists in the DB, and the current JSON chunk hasn't
|
||||
// dirtied the schema, so it's safe to insert the recent rows.
|
||||
if insertions, err = proc.buildInsertionsFlat(curSchema); err != nil {
|
||||
if insertions, err = proc.buildInsertionsFlat(proc.curSchema); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -260,6 +268,7 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
|
||||
// JSON objects, returning the decoded object and the chunk of JSON
|
||||
// that it was scanned from. Example input: [{a:1},{a:2},{a:3}].
|
||||
type objectsInArrayScanner struct {
|
||||
log *slog.Logger
|
||||
// buf will get all the data that the JSON decoder reads.
|
||||
// buf's role is to keep track of JSON text that has already been
|
||||
// consumed by dec, so that we can return the raw JSON chunk
|
||||
@ -294,13 +303,13 @@ type objectsInArrayScanner struct {
|
||||
|
||||
// newObjectInArrayScanner returns a new instance that
|
||||
// reads from r.
|
||||
func newObjectInArrayScanner(r io.Reader) *objectsInArrayScanner {
|
||||
func newObjectInArrayScanner(log *slog.Logger, r io.Reader) *objectsInArrayScanner {
|
||||
buf := &buffer{b: []byte{}}
|
||||
// Everything that dec reads from r is written
|
||||
// to buf via the TeeReader.
|
||||
dec := stdj.NewDecoder(io.TeeReader(r, buf))
|
||||
|
||||
return &objectsInArrayScanner{buf: buf, dec: dec}
|
||||
return &objectsInArrayScanner{log: log, buf: buf, dec: dec}
|
||||
}
|
||||
|
||||
// next scans the next object from the reader. The returned chunk holds
|
||||
@ -358,10 +367,19 @@ func (s *objectsInArrayScanner) next() (obj map[string]any, chunk []byte, err er
|
||||
return nil, nil, errz.Err(err)
|
||||
}
|
||||
|
||||
more = s.dec.More()
|
||||
var delimIndex int
|
||||
var delim byte
|
||||
|
||||
if len(s.decBuf) == 0 {
|
||||
// We've landed right on the edge of the chunk, there'll be no delim (or any
|
||||
// other char), so we skip over delim searching.
|
||||
goto BOTTOM
|
||||
}
|
||||
|
||||
more = s.dec.More() // REVISIT: Should we not be testing this value?
|
||||
|
||||
// Peek ahead in the decoder buffer
|
||||
delimIndex, delim := nextDelim(s.decBuf, 0, true)
|
||||
delimIndex, delim = nextDelim(s.decBuf, 0, true)
|
||||
if delimIndex == -1 {
|
||||
return nil, nil, errz.New("invalid JSON: additional input expected")
|
||||
}
|
||||
@ -403,6 +421,7 @@ func (s *objectsInArrayScanner) next() (obj map[string]any, chunk []byte, err er
|
||||
}
|
||||
}
|
||||
|
||||
BOTTOM:
|
||||
// Note that we re-use the vars delimIndex and delim here.
|
||||
// Above us, these vars referred to s.decBuf, not s.buf as here.
|
||||
delimIndex, delim = nextDelim(s.buf.b, s.prevDecPos-s.bufOffset, false)
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/neilotoole/sq/libsq/core/lg"
|
||||
"github.com/neilotoole/sq/libsq/core/lg/lga"
|
||||
"github.com/neilotoole/sq/libsq/core/lg/lgm"
|
||||
"github.com/neilotoole/sq/libsq/core/progress"
|
||||
"github.com/neilotoole/sq/libsq/files"
|
||||
"github.com/neilotoole/sq/libsq/source/drivertype"
|
||||
)
|
||||
@ -89,6 +90,8 @@ func DetectJSONL(sampleSize int) files.TypeDetectFunc {
|
||||
func ingestJSONL(ctx context.Context, job *ingestJob) error { //nolint:gocognit
|
||||
log := lg.FromContext(ctx)
|
||||
defer lg.WarnIfCloseError(log, "Close JSONL ingest job", job)
|
||||
bar := progress.FromContext(ctx).NewUnitCounter("Ingest JSONL", "object")
|
||||
defer bar.Stop()
|
||||
|
||||
r, err := job.newRdrFn(ctx)
|
||||
if err != nil {
|
||||
@ -124,6 +127,10 @@ func ingestJSONL(ctx context.Context, job *ingestJob) error { //nolint:gocognit
|
||||
return err
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
bar.Incr(1)
|
||||
}
|
||||
|
||||
if schemaModified {
|
||||
if !hasMore || scan.validLineCount >= job.sampleSize {
|
||||
log.Debug("Time to (re)build the schema", lga.Line, scan.totalLineCount)
|
||||
|
@ -12,14 +12,26 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/neilotoole/sq/cli/testrun"
|
||||
"github.com/neilotoole/sq/drivers/json"
|
||||
"github.com/neilotoole/sq/libsq/core/kind"
|
||||
"github.com/neilotoole/sq/libsq/core/lg/lgt"
|
||||
"github.com/neilotoole/sq/libsq/driver"
|
||||
"github.com/neilotoole/sq/libsq/source"
|
||||
"github.com/neilotoole/sq/libsq/source/drivertype"
|
||||
"github.com/neilotoole/sq/libsq/source/metadata"
|
||||
"github.com/neilotoole/sq/testh"
|
||||
"github.com/neilotoole/sq/testh/proj"
|
||||
"github.com/neilotoole/sq/testh/sakila"
|
||||
"github.com/neilotoole/sq/testh/testsrc"
|
||||
"github.com/neilotoole/sq/testh/tu"
|
||||
)
|
||||
|
||||
const (
|
||||
citiesLargeObjCount = 146994
|
||||
citiesSmallObjCount = 3
|
||||
)
|
||||
|
||||
func BenchmarkIngestJSONL_Flat(b *testing.B) {
|
||||
// $ go test -count=10 -benchtime=5s -bench BenchmarkIngestJSONL_Flat > old.bench.txt
|
||||
// # Make changes
|
||||
@ -303,9 +315,10 @@ func TestScanObjectsInArray(t *testing.T) {
|
||||
|
||||
t.Run(tu.Name(i, tc.in), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
log := lgt.New(t)
|
||||
|
||||
r := bytes.NewReader([]byte(tc.in))
|
||||
gotObjs, gotChunks, err := json.ScanObjectsInArray(r)
|
||||
gotObjs, gotChunks, err := json.ScanObjectsInArray(log, r)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
@ -332,6 +345,8 @@ func TestScanObjectsInArray_Files(t *testing.T) {
|
||||
{fname: "testdata/actor.json", wantCount: sakila.TblActorCount},
|
||||
{fname: "testdata/film_actor.json", wantCount: sakila.TblFilmActorCount},
|
||||
{fname: "testdata/payment.json", wantCount: sakila.TblPaymentCount},
|
||||
{fname: "testdata/cities.small.json", wantCount: citiesSmallObjCount},
|
||||
{fname: "testdata/cities.large.json", wantCount: citiesLargeObjCount},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@ -339,12 +354,13 @@ func TestScanObjectsInArray_Files(t *testing.T) {
|
||||
|
||||
t.Run(tu.Name(tc.fname), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
log := lgt.New(t)
|
||||
|
||||
f, err := os.Open(tc.fname)
|
||||
require.NoError(t, err)
|
||||
defer f.Close()
|
||||
|
||||
gotObjs, gotChunks, err := json.ScanObjectsInArray(f)
|
||||
gotObjs, gotChunks, err := json.ScanObjectsInArray(log, f)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCount, len(gotObjs))
|
||||
require.Equal(t, tc.wantCount, len(gotChunks))
|
||||
@ -398,3 +414,71 @@ func TestColumnOrderFlat(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONData_Cities(t *testing.T) {
|
||||
t.Parallel()
|
||||
tu.SkipWindows(t, "Takes too long on Windows CI")
|
||||
tu.SkipShort(t, true)
|
||||
|
||||
const wantCSV = `name,lat,lng,country,admin1,admin2
|
||||
Sant Julià de Lòria,42.46372,1.49129,AD,6,
|
||||
Pas de la Casa,42.54277,1.73361,AD,3,
|
||||
Ordino,42.55623,1.53319,AD,5,`
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantRowCount int
|
||||
}{
|
||||
{
|
||||
name: "cities.small.json",
|
||||
wantRowCount: citiesSmallObjCount,
|
||||
},
|
||||
{
|
||||
name: "cities.large.json",
|
||||
wantRowCount: citiesLargeObjCount,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tr := testrun.New(context.Background(), t, nil).Hush()
|
||||
src := &source.Source{
|
||||
Handle: "@cities",
|
||||
Type: drivertype.JSON,
|
||||
Location: proj.Abs(filepath.Join("drivers/json/testdata/", tc.name)),
|
||||
}
|
||||
tr = tr.Add(*src)
|
||||
err := tr.Exec("--csv", ".data | .[0:3]")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wantCSV, tr.OutString())
|
||||
|
||||
// FIXME: test inspect table row count
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Cities_MixedFieldKind(t *testing.T) {
|
||||
src := &source.Source{
|
||||
Handle: "@cities",
|
||||
Type: drivertype.JSON,
|
||||
Location: proj.Abs("drivers/json/testdata/cities.sample-mixed-10.json"),
|
||||
}
|
||||
|
||||
tr := testrun.New(context.Background(), t, nil).Hush()
|
||||
tr.Add(*src)
|
||||
|
||||
require.NoError(t, tr.Exec("config", "set", driver.OptIngestSampleSize.Key(), "2"))
|
||||
|
||||
err := tr.Reset().Exec("inspect", "-j")
|
||||
require.NoError(t, err)
|
||||
var md *metadata.Source
|
||||
tr.Bind(&md)
|
||||
t.Log(md.Name)
|
||||
colAdmin1 := md.Table("data").Column("admin1")
|
||||
require.NotNil(t, colAdmin1)
|
||||
require.Equal(t, kind.Text.String(), colAdmin1.Kind.String())
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@ -75,8 +76,8 @@ func TestDetectColKindsJSONA(t *testing.T) {
|
||||
|
||||
// ScanObjectsInArray is a convenience function
|
||||
// for objectsInArrayScanner.
|
||||
func ScanObjectsInArray(r io.Reader) (objs []map[string]any, chunks [][]byte, err error) {
|
||||
sc := newObjectInArrayScanner(r)
|
||||
func ScanObjectsInArray(log *slog.Logger, r io.Reader) (objs []map[string]any, chunks [][]byte, err error) {
|
||||
sc := newObjectInArrayScanner(log, r)
|
||||
|
||||
for {
|
||||
var obj map[string]any
|
||||
|
26
drivers/json/testdata/README.md
vendored
Normal file
26
drivers/json/testdata/README.md
vendored
Normal file
@ -0,0 +1,26 @@
|
||||
# drivers/json/testdata
|
||||
|
||||
|
||||
## Cities
|
||||
|
||||
The various "cities" JSON files contain an array of data like this:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "Sant Julià de Lòria",
|
||||
"lat": "42.46372",
|
||||
"lng": "1.49129",
|
||||
"country": "AD",
|
||||
"admin1": "06",
|
||||
"admin2": ""
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
- [`cities.large.json`](cities.large.json) is `146,994` cities (~20MB uncompressed).
|
||||
- [`cities.small.json`](cities.small.json) 4 city objects.
|
||||
- [`cities.sample-mixed-10.json`](cities.sample-mixed-10.json) contains 10 objects, but the `admin1` field
|
||||
of the third object is the string `TEXT` instead of a number.
|
||||
|
||||
|
1175954
drivers/json/testdata/cities.large.json
vendored
Normal file
1175954
drivers/json/testdata/cities.large.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
82
drivers/json/testdata/cities.sample-mixed-10.json
vendored
Normal file
82
drivers/json/testdata/cities.sample-mixed-10.json
vendored
Normal file
@ -0,0 +1,82 @@
|
||||
[
|
||||
{
|
||||
"name": "Sant Julià de Lòria",
|
||||
"lat": "42.46372",
|
||||
"lng": "1.49129",
|
||||
"country": "AD",
|
||||
"admin1": "06",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "Pas de la Casa",
|
||||
"lat": "42.54277",
|
||||
"lng": "1.73361",
|
||||
"country": "AD",
|
||||
"admin1": "03",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "Ordino",
|
||||
"lat": "42.55623",
|
||||
"lng": "1.53319",
|
||||
"country": "AD",
|
||||
"admin1": "TEXT",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "les Escaldes",
|
||||
"lat": "42.50729",
|
||||
"lng": "1.53414",
|
||||
"country": "AD",
|
||||
"admin1": "08",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "la Massana",
|
||||
"lat": "42.54499",
|
||||
"lng": "1.51483",
|
||||
"country": "AD",
|
||||
"admin1": "04",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "Encamp",
|
||||
"lat": "42.53474",
|
||||
"lng": "1.58014",
|
||||
"country": "AD",
|
||||
"admin1": "03",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "Canillo",
|
||||
"lat": "42.5676",
|
||||
"lng": "1.59756",
|
||||
"country": "AD",
|
||||
"admin1": "02",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "Arinsal",
|
||||
"lat": "42.57205",
|
||||
"lng": "1.48453",
|
||||
"country": "AD",
|
||||
"admin1": "04",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "Andorra la Vella",
|
||||
"lat": "42.50779",
|
||||
"lng": "1.52109",
|
||||
"country": "AD",
|
||||
"admin1": "07",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "Umm Al Quwain City",
|
||||
"lat": "25.56473",
|
||||
"lng": "55.55517",
|
||||
"country": "AE",
|
||||
"admin1": "07",
|
||||
"admin2": ""
|
||||
}
|
||||
]
|
26
drivers/json/testdata/cities.small.json
vendored
Normal file
26
drivers/json/testdata/cities.small.json
vendored
Normal file
@ -0,0 +1,26 @@
|
||||
[
|
||||
{
|
||||
"name": "Sant Julià de Lòria",
|
||||
"lat": "42.46372",
|
||||
"lng": "1.49129",
|
||||
"country": "AD",
|
||||
"admin1": "06",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "Pas de la Casa",
|
||||
"lat": "42.54277",
|
||||
"lng": "1.73361",
|
||||
"country": "AD",
|
||||
"admin1": "03",
|
||||
"admin2": ""
|
||||
},
|
||||
{
|
||||
"name": "Ordino",
|
||||
"lat": "42.55623",
|
||||
"lng": "1.53319",
|
||||
"country": "AD",
|
||||
"admin1": "05",
|
||||
"admin2": ""
|
||||
}
|
||||
]
|
@ -354,6 +354,11 @@ func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, c
|
||||
return errz.Wrapf(errw(err), "alter table: failed to rename column {%s.%s} to {%s}", tbl, col, newName)
|
||||
}
|
||||
|
||||
// AlterTableColumnKinds is not yet implemented for mysql.
|
||||
func (d *driveri) AlterTableColumnKinds(_ context.Context, _ sqlz.DB, _ string, _ []string, _ []kind.Kind) error {
|
||||
return errz.New("not implemented")
|
||||
}
|
||||
|
||||
// PrepareInsertStmt implements driver.SQLDriver.
|
||||
func (d *driveri) PrepareInsertStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string,
|
||||
numRows int,
|
||||
|
@ -457,6 +457,11 @@ func (d *driveri) AlterTableAddColumn(ctx context.Context, db sqlz.DB, tbl, col
|
||||
return nil
|
||||
}
|
||||
|
||||
// AlterTableColumnKinds is not yet implemented for postgres.
|
||||
func (d *driveri) AlterTableColumnKinds(_ context.Context, _ sqlz.DB, _ string, _ []string, _ []kind.Kind) error {
|
||||
return errz.New("not implemented")
|
||||
}
|
||||
|
||||
// PrepareInsertStmt implements driver.SQLDriver.
|
||||
func (d *driveri) PrepareInsertStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string,
|
||||
numRows int,
|
||||
|
149
drivers/sqlite3/alter.go
Normal file
149
drivers/sqlite3/alter.go
Normal file
@ -0,0 +1,149 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/neilotoole/sq/drivers/sqlite3/internal/sqlparser"
|
||||
"github.com/neilotoole/sq/libsq/core/errz"
|
||||
"github.com/neilotoole/sq/libsq/core/kind"
|
||||
"github.com/neilotoole/sq/libsq/core/lg"
|
||||
"github.com/neilotoole/sq/libsq/core/lg/lga"
|
||||
"github.com/neilotoole/sq/libsq/core/sqlz"
|
||||
"github.com/neilotoole/sq/libsq/core/stringz"
|
||||
)
|
||||
|
||||
// AlterTableRename implements driver.SQLDriver.
|
||||
func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error {
|
||||
q := fmt.Sprintf(`ALTER TABLE %q RENAME TO %q`, tbl, newName)
|
||||
_, err := db.ExecContext(ctx, q)
|
||||
return errz.Wrapf(errw(err), "alter table: failed to rename table {%s} to {%s}", tbl, newName)
|
||||
}
|
||||
|
||||
// AlterTableRenameColumn implements driver.SQLDriver.
|
||||
func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, col, newName string) error {
|
||||
q := fmt.Sprintf("ALTER TABLE %q RENAME COLUMN %q TO %q", tbl, col, newName)
|
||||
_, err := db.ExecContext(ctx, q)
|
||||
return errz.Wrapf(errw(err), "alter table: failed to rename column {%s.%s} to {%s}", tbl, col, newName)
|
||||
}
|
||||
|
||||
// AlterTableAddColumn implements driver.SQLDriver.
|
||||
func (d *driveri) AlterTableAddColumn(ctx context.Context, db sqlz.DB, tbl, col string, knd kind.Kind) error {
|
||||
q := fmt.Sprintf("ALTER TABLE %q ADD COLUMN %q ", tbl, col) + DBTypeForKind(knd)
|
||||
|
||||
_, err := db.ExecContext(ctx, q)
|
||||
if err != nil {
|
||||
return errz.Wrapf(errw(err), "alter table: failed to add column {%s} to table {%s}", col, tbl)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AlterTableColumnKinds implements driver.SQLDriver. Note that SQLite doesn't
|
||||
// really support altering column types, so this is a hacky implementation.
|
||||
// It's not guaranteed that indices, constraints, etc. will be preserved. See:
|
||||
//
|
||||
// - https://www.sqlite.org/lang_altertable.html
|
||||
// - https://www.sqlite.org/faq.html#q11
|
||||
// - https://www.sqlitetutorial.net/sqlite-alter-table/
|
||||
//
|
||||
// Note that colNames and kinds must be the same length.
|
||||
func (d *driveri) AlterTableColumnKinds(ctx context.Context, db sqlz.DB,
|
||||
tblName string, colNames []string, kinds []kind.Kind,
|
||||
) (err error) {
|
||||
if len(colNames) != len(kinds) {
|
||||
return errz.New("sqlite3: alter table: mismatched count of columns and kinds")
|
||||
}
|
||||
|
||||
// It's recommended to disable foreign keys before this alter procedure.
|
||||
if restorePragmaFK, fkErr := pragmaDisableForeignKeys(ctx, db); fkErr != nil {
|
||||
return fkErr
|
||||
} else if restorePragmaFK != nil {
|
||||
defer restorePragmaFK()
|
||||
}
|
||||
|
||||
q := "SELECT sql FROM sqlite_master WHERE type='table' AND name=?"
|
||||
var ogDDL string
|
||||
if err = db.QueryRowContext(ctx, q, tblName).Scan(&ogDDL); err != nil {
|
||||
return errz.Wrapf(err, "sqlite3: alter table: failed to read original DDL")
|
||||
}
|
||||
|
||||
allColDefs, err := sqlparser.ExtractCreateTableStmtColDefs(ogDDL)
|
||||
if err != nil {
|
||||
return errz.Wrapf(err, "sqlite3: alter table: failed to extract column definitions from DDL")
|
||||
}
|
||||
|
||||
var colDefs []*sqlparser.ColDef
|
||||
for i, colName := range colNames {
|
||||
for _, cd := range allColDefs {
|
||||
if cd.Name == colName {
|
||||
colDefs = append(colDefs, cd)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(colDefs) != i+1 {
|
||||
return errz.Errorf("sqlite3: alter table: column {%s} not found in table DDL", colName)
|
||||
}
|
||||
}
|
||||
|
||||
nuDDL := ogDDL
|
||||
for i, colDef := range colDefs {
|
||||
wantType := DBTypeForKind(kinds[i])
|
||||
wantColDefText := strings.Replace(colDef.Raw, colDef.RawType, wantType, 1)
|
||||
nuDDL = strings.Replace(nuDDL, colDef.Raw, wantColDefText, 1)
|
||||
}
|
||||
|
||||
nuTblName := "tmp_tbl_alter_" + stringz.Uniq32()
|
||||
nuDDL = strings.Replace(nuDDL, tblName, nuTblName, 1)
|
||||
|
||||
if _, err = db.ExecContext(ctx, nuDDL); err != nil {
|
||||
return errz.Wrapf(err, "sqlite3: alter table: failed to create temporary table")
|
||||
}
|
||||
|
||||
copyStmt := fmt.Sprintf(
|
||||
"INSERT INTO %s SELECT * FROM %s",
|
||||
stringz.DoubleQuote(nuTblName),
|
||||
stringz.DoubleQuote(tblName),
|
||||
)
|
||||
if _, err = db.ExecContext(ctx, copyStmt); err != nil {
|
||||
return errz.Wrapf(err, "sqlite3: alter table: failed to copy data to temporary table")
|
||||
}
|
||||
|
||||
// Drop old table
|
||||
if _, err = db.ExecContext(ctx, "DROP TABLE "+stringz.DoubleQuote(tblName)); err != nil {
|
||||
return errz.Wrapf(err, "sqlite3: alter table: failed to drop original table")
|
||||
}
|
||||
|
||||
// Rename new table to old table name
|
||||
if _, err = db.ExecContext(ctx, fmt.Sprintf(
|
||||
"ALTER TABLE %s RENAME TO %s",
|
||||
stringz.DoubleQuote(nuTblName),
|
||||
stringz.DoubleQuote(tblName),
|
||||
)); err != nil {
|
||||
return errz.Wrapf(err, "sqlite3: alter table: failed to rename temporary table")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// pragmaDisableForeignKeys disables foreign keys, returning a function that
|
||||
// restores the original state of the foreign_keys pragma. If an error occurs,
|
||||
// the returned restore function will be nil.
|
||||
func pragmaDisableForeignKeys(ctx context.Context, db sqlz.DB) (restore func(), err error) {
|
||||
pragmaFkExisting, err := readPragma(ctx, db, "foreign_keys")
|
||||
if err != nil {
|
||||
return nil, errz.Wrapf(err, "sqlite3: alter table: failed to read foreign_keys pragma")
|
||||
}
|
||||
|
||||
if _, err = db.ExecContext(ctx, "PRAGMA foreign_keys=off"); err != nil {
|
||||
return nil, errz.Wrapf(err, "sqlite3: alter table: failed to disable foreign_keys pragma")
|
||||
}
|
||||
|
||||
return func() {
|
||||
_, restoreErr := db.ExecContext(ctx, fmt.Sprintf("PRAGMA foreign_keys=%v", pragmaFkExisting))
|
||||
if restoreErr != nil {
|
||||
lg.FromContext(ctx).Error("sqlite3: alter table: failed to restore foreign_keys pragma", lga.Err, restoreErr)
|
||||
}
|
||||
}, nil
|
||||
}
|
143
drivers/sqlite3/internal/sqlparser/create_table.go
Normal file
143
drivers/sqlite3/internal/sqlparser/create_table.go
Normal file
@ -0,0 +1,143 @@
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
antlr "github.com/antlr4-go/antlr/v4"
|
||||
|
||||
"github.com/neilotoole/sq/drivers/sqlite3/internal/sqlparser/sqlite"
|
||||
"github.com/neilotoole/sq/libsq/ast/antlrz"
|
||||
"github.com/neilotoole/sq/libsq/core/errz"
|
||||
"github.com/neilotoole/sq/libsq/core/stringz"
|
||||
)
|
||||
|
||||
func parseCreateTableStmt(input string) (*sqlite.Create_table_stmtContext, error) {
|
||||
lex := sqlite.NewSQLiteLexer(antlr.NewInputStream(input))
|
||||
lex.RemoveErrorListeners() // the generated lexer has default listeners we don't want
|
||||
lexErrs := &antlrErrorListener{name: "lexer"}
|
||||
lex.AddErrorListener(lexErrs)
|
||||
|
||||
p := sqlite.NewSQLiteParser(antlr.NewCommonTokenStream(lex, 0))
|
||||
p.RemoveErrorListeners() // the generated parser has default listeners we don't want
|
||||
parseErrs := &antlrErrorListener{name: "parser"}
|
||||
p.AddErrorListener(parseErrs)
|
||||
|
||||
qCtx := p.Create_table_stmt()
|
||||
|
||||
if err := lexErrs.error(); err != nil {
|
||||
return nil, errz.Err(err)
|
||||
}
|
||||
|
||||
if err := parseErrs.error(); err != nil {
|
||||
return nil, errz.Err(err)
|
||||
}
|
||||
|
||||
return qCtx.(*sqlite.Create_table_stmtContext), nil
|
||||
}
|
||||
|
||||
// ExtractTableIdentFromCreateTableStmt extracts table name (and the
|
||||
// table's schema if specified) from a CREATE TABLE statement.
|
||||
// If err is nil, table is guaranteed to be non-empty. If arg unescape is
|
||||
// true, any surrounding quotation chars are trimmed from the returned values.
|
||||
//
|
||||
// CREATE TABLE "sakila"."actor" ( actor_id INTEGER NOT NULL) --> sakila, actor, nil
|
||||
func ExtractTableIdentFromCreateTableStmt(stmt string, unescape bool) (schema, table string, err error) {
|
||||
stmtCtx, err := parseCreateTableStmt(stmt)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if n, ok := stmtCtx.Schema_name().(*sqlite.Schema_nameContext); ok {
|
||||
if n.Any_name() != nil && !n.Any_name().IsEmpty() && n.Any_name().IDENTIFIER() != nil {
|
||||
schema = n.Any_name().IDENTIFIER().GetText()
|
||||
if unescape {
|
||||
schema = trimIdentQuotes(schema)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if x, ok := stmtCtx.Table_name().(*sqlite.Table_nameContext); ok {
|
||||
if x.Any_name() != nil && !x.Any_name().IsEmpty() && x.Any_name().IDENTIFIER() != nil {
|
||||
table = x.Any_name().IDENTIFIER().GetText()
|
||||
if unescape {
|
||||
table = trimIdentQuotes(table)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return "", "", errz.Errorf("failed to extract table name from CREATE TABLE statement")
|
||||
}
|
||||
|
||||
return schema, table, nil
|
||||
}
|
||||
|
||||
// ExtractCreateTableStmtColDefs extracts the column definitions from a CREATE
|
||||
// TABLE statement.
|
||||
func ExtractCreateTableStmtColDefs(stmt string) ([]*ColDef, error) {
|
||||
stmtCtx, err := parseCreateTableStmt(stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var colDefs []*ColDef
|
||||
|
||||
tokx := antlrz.NewTokenExtractor(stmt)
|
||||
for _, child := range stmtCtx.GetChildren() {
|
||||
if defCtx, ok := child.(*sqlite.Column_defContext); ok {
|
||||
if defCtx == nil || defCtx.Column_name() == nil {
|
||||
// Shouldn't happen
|
||||
continue
|
||||
}
|
||||
|
||||
if defCtx.Type_name() == nil || defCtx.Type_name().GetText() == "" {
|
||||
// Shouldn't happen
|
||||
continue
|
||||
}
|
||||
|
||||
colDef := &ColDef{
|
||||
DefCtx: defCtx,
|
||||
Raw: tokx.Extract(defCtx),
|
||||
RawName: tokx.Extract(defCtx.Column_name()),
|
||||
Name: stringz.StripDoubleQuote(defCtx.Column_name().GetText()),
|
||||
RawType: tokx.Extract(defCtx.Type_name()),
|
||||
Type: defCtx.Type_name().GetText(),
|
||||
}
|
||||
|
||||
colDef.InputOffset, _ = tokx.Offset(defCtx)
|
||||
|
||||
colDefs = append(colDefs, colDef)
|
||||
}
|
||||
}
|
||||
|
||||
return colDefs, nil
|
||||
}
|
||||
|
||||
// ColDef represents a column definition in a CREATE TABLE statement.
// Instances are produced by ExtractCreateTableStmtColDefs.
type ColDef struct {
	// DefCtx is the antlr context for the column definition.
	DefCtx *sqlite.Column_defContext

	// Raw is the raw text of the entire column definition.
	Raw string

	// RawName is the raw text of the column name as it appeared in the input.
	// It may be double-quoted.
	RawName string

	// Name is the column name, stripped of any double-quotes.
	Name string

	// RawType is the raw text of the column type as it appeared in the input.
	RawType string

	// Type is the canonicalized column type.
	Type string

	// InputOffset is the character start index of the column definition in the
	// input. The def ends at InputOffset+len(Raw).
	InputOffset int
}
|
||||
|
||||
// String returns the raw text of the column definition,
// implementing fmt.Stringer.
func (cd *ColDef) String() string {
	return cd.Raw
}
|
@ -8,48 +8,8 @@ import (
|
||||
"strings"
|
||||
|
||||
antlr "github.com/antlr4-go/antlr/v4"
|
||||
|
||||
"github.com/neilotoole/sq/drivers/sqlite3/internal/sqlparser/sqlite"
|
||||
"github.com/neilotoole/sq/libsq/core/errz"
|
||||
)
|
||||
|
||||
// ExtractTableIdentFromCreateTableStmt extracts table name (and the
|
||||
// table's schema if specified) from a CREATE TABLE statement.
|
||||
// If err is nil, table is guaranteed to be non-empty. If arg unescape is
|
||||
// true, any surrounding quotation chars are trimmed from the returned values.
|
||||
//
|
||||
// CREATE TABLE "sakila"."actor" ( actor_id INTEGER NOT NULL) --> sakila, actor, nil
|
||||
func ExtractTableIdentFromCreateTableStmt(stmt string, unescape bool) (schema, table string, err error) {
|
||||
stmtCtx, err := parseCreateTableStmt(stmt)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if n, ok := stmtCtx.Schema_name().(*sqlite.Schema_nameContext); ok {
|
||||
if n.Any_name() != nil && !n.Any_name().IsEmpty() && n.Any_name().IDENTIFIER() != nil {
|
||||
schema = n.Any_name().IDENTIFIER().GetText()
|
||||
if unescape {
|
||||
schema = trimIdentQuotes(schema)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if x, ok := stmtCtx.Table_name().(*sqlite.Table_nameContext); ok {
|
||||
if x.Any_name() != nil && !x.Any_name().IsEmpty() && x.Any_name().IDENTIFIER() != nil {
|
||||
table = x.Any_name().IDENTIFIER().GetText()
|
||||
if unescape {
|
||||
table = trimIdentQuotes(table)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return "", "", errz.Errorf("failed to extract table name from CREATE TABLE statement")
|
||||
}
|
||||
|
||||
return schema, table, nil
|
||||
}
|
||||
|
||||
// trimIdentQuotes trims any of the legal quote characters from s.
|
||||
// These are double quote, single quote, backtick, and square brackets.
|
||||
//
|
||||
@ -78,30 +38,6 @@ func trimIdentQuotes(s string) string {
|
||||
return s
|
||||
}
|
||||
|
||||
func parseCreateTableStmt(input string) (*sqlite.Create_table_stmtContext, error) {
|
||||
lex := sqlite.NewSQLiteLexer(antlr.NewInputStream(input))
|
||||
lex.RemoveErrorListeners() // the generated lexer has default listeners we don't want
|
||||
lexErrs := &antlrErrorListener{name: "lexer"}
|
||||
lex.AddErrorListener(lexErrs)
|
||||
|
||||
p := sqlite.NewSQLiteParser(antlr.NewCommonTokenStream(lex, 0))
|
||||
p.RemoveErrorListeners() // the generated parser has default listeners we don't want
|
||||
parseErrs := &antlrErrorListener{name: "parser"}
|
||||
p.AddErrorListener(parseErrs)
|
||||
|
||||
qCtx := p.Create_table_stmt()
|
||||
|
||||
if err := lexErrs.error(); err != nil {
|
||||
return nil, errz.Err(err)
|
||||
}
|
||||
|
||||
if err := parseErrs.error(); err != nil {
|
||||
return nil, errz.Err(err)
|
||||
}
|
||||
|
||||
return qCtx.(*sqlite.Create_table_stmtContext), nil
|
||||
}
|
||||
|
||||
var _ antlr.ErrorListener = (*antlrErrorListener)(nil)
|
||||
|
||||
// antlrErrorListener implements antlr.ErrorListener.
|
||||
|
@ -67,3 +67,38 @@ func TestExtractTableNameFromCreateTableStmt(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractCreateTableStmtColDefs verifies that ExtractCreateTableStmtColDefs
// returns one ColDef per column, with correct raw and canonical name/type
// fields, and that each ColDef's InputOffset+Raw indexes back into the input.
func TestExtractCreateTableStmtColDefs(t *testing.T) {
	const input = `CREATE TABLE "og_table" (
"name" TEXT NOT NULL,
"age" INTEGER( 10 ) NOT NULL,
weight INTEGER NOT NULL
)`

	colDefs, err := sqlparser.ExtractCreateTableStmtColDefs(input)
	require.NoError(t, err)
	require.Len(t, colDefs, 3)

	// Quoted name, simple type.
	require.Equal(t, `"name" TEXT NOT NULL`, colDefs[0].Raw)
	require.Equal(t, `"name"`, colDefs[0].RawName)
	require.Equal(t, `name`, colDefs[0].Name)
	require.Equal(t, "TEXT", colDefs[0].Type)
	require.Equal(t, "TEXT", colDefs[0].RawType)
	// Raw must be recoverable from input via InputOffset.
	snippet := input[colDefs[0].InputOffset : colDefs[0].InputOffset+len(colDefs[0].Raw)]
	require.Equal(t, colDefs[0].Raw, snippet)

	// Parameterized type: RawType preserves whitespace, Type is canonicalized.
	require.Equal(t, `"age" INTEGER( 10 ) NOT NULL`, colDefs[1].Raw)
	require.Equal(t, `"age"`, colDefs[1].RawName)
	require.Equal(t, `age`, colDefs[1].Name)
	require.Equal(t, "INTEGER(10)", colDefs[1].Type)
	require.Equal(t, "INTEGER( 10 )", colDefs[1].RawType)
	snippet = input[colDefs[1].InputOffset : colDefs[1].InputOffset+len(colDefs[1].Raw)]
	require.Equal(t, colDefs[1].Raw, snippet)

	// Unquoted name: RawName and Name are identical.
	require.Equal(t, `weight INTEGER NOT NULL`, colDefs[2].Raw)
	require.Equal(t, `weight`, colDefs[2].RawName)
	require.Equal(t, `weight`, colDefs[2].Name)
	require.Equal(t, "INTEGER", colDefs[2].Type)
	require.Equal(t, "INTEGER", colDefs[2].RawType)
	snippet = input[colDefs[2].InputOffset : colDefs[2].InputOffset+len(colDefs[2].Raw)]
	require.Equal(t, colDefs[2].Raw, snippet)
}
|
||||
|
@ -796,32 +796,6 @@ func (d *driveri) ListCatalogs(_ context.Context, _ sqlz.DB) ([]string, error) {
|
||||
return nil, errz.New("sqlite3: catalog mechanism not supported")
|
||||
}
|
||||
|
||||
// AlterTableRename implements driver.SQLDriver.
// It renames table tbl to newName via ALTER TABLE ... RENAME TO.
func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error {
	q := fmt.Sprintf(`ALTER TABLE %q RENAME TO %q`, tbl, newName)
	_, err := db.ExecContext(ctx, q)
	// NOTE(review): relies on errz.Wrapf returning nil when err is nil — confirm.
	return errz.Wrapf(errw(err), "alter table: failed to rename table {%s} to {%s}", tbl, newName)
}
|
||||
|
||||
// AlterTableRenameColumn implements driver.SQLDriver.
// It renames column col of table tbl to newName.
func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, col, newName string) error {
	q := fmt.Sprintf("ALTER TABLE %q RENAME COLUMN %q TO %q", tbl, col, newName)
	_, err := db.ExecContext(ctx, q)
	// NOTE(review): relies on errz.Wrapf returning nil when err is nil — confirm.
	return errz.Wrapf(errw(err), "alter table: failed to rename column {%s.%s} to {%s}", tbl, col, newName)
}
|
||||
|
||||
// AlterTableAddColumn implements driver.SQLDriver.
|
||||
func (d *driveri) AlterTableAddColumn(ctx context.Context, db sqlz.DB, tbl, col string, knd kind.Kind) error {
|
||||
q := fmt.Sprintf("ALTER TABLE %q ADD COLUMN %q ", tbl, col) + DBTypeForKind(knd)
|
||||
|
||||
_, err := db.ExecContext(ctx, q)
|
||||
if err != nil {
|
||||
return errz.Wrapf(errw(err), "alter table: failed to add column {%s} to table {%s}", col, tbl)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableExists implements driver.SQLDriver.
|
||||
func (d *driveri) TableExists(ctx context.Context, db sqlz.DB, tbl string) (bool, error) {
|
||||
const query = `SELECT COUNT(*) FROM sqlite_master WHERE name = ? AND type='table'`
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/neilotoole/sq/drivers/sqlite3"
|
||||
"github.com/neilotoole/sq/libsq/core/kind"
|
||||
"github.com/neilotoole/sq/libsq/core/schema"
|
||||
"github.com/neilotoole/sq/libsq/core/sqlz"
|
||||
"github.com/neilotoole/sq/libsq/core/stringz"
|
||||
@ -327,3 +328,66 @@ func TestSQLQuery_Whitespace(t *testing.T) {
|
||||
require.Equal(t, "last name", sink.RecMeta[2].Name())
|
||||
require.Equal(t, "last name", sink.RecMeta[2].MungedName())
|
||||
}
|
||||
|
||||
// TestDriveri_AlterTableColumnKinds verifies the sqlite3 driver's
// AlterTableColumnKinds: it creates a table with known column kinds, alters
// the kinds of two columns, and confirms the change via table metadata.
func TestDriveri_AlterTableColumnKinds(t *testing.T) {
	th := testh.New(t)
	src := &source.Source{
		Handle:   "@test",
		Type:     drivertype.SQLite,
		Location: "sqlite3://" + tu.TempFile(t, "test.db"),
	}

	// og_table starts as: name TEXT, age INTEGER, weight INTEGER (all NOT NULL).
	ogTbl := &schema.Table{
		Name:          "og_table",
		PKColName:     "",
		AutoIncrement: false,
		Cols:          nil,
	}

	ogColName := &schema.Column{
		Name:    "name",
		Table:   ogTbl,
		Kind:    kind.Text,
		NotNull: true,
	}
	ogColAge := &schema.Column{
		Name:    "age",
		Table:   ogTbl,
		Kind:    kind.Int,
		NotNull: true,
	}
	ogColWeight := &schema.Column{
		Name:    "weight",
		Table:   ogTbl,
		Kind:    kind.Int,
		NotNull: true,
	}

	ogTbl.Cols = []*schema.Column{ogColName, ogColAge, ogColWeight}
	grip := th.Open(src)

	db, err := grip.DB(th.Context)
	require.NoError(t, err)
	drvr := grip.SQLDriver()

	err = drvr.CreateTable(th.Context, db, ogTbl)
	require.NoError(t, err)

	// Sanity-check the freshly created table's column kinds.
	gotTblMeta, err := grip.TableMetadata(th.Context, ogTbl.Name)
	require.NoError(t, err)
	require.Equal(t, 3, len(gotTblMeta.Columns))
	require.Equal(t, kind.Int, gotTblMeta.Column("age").Kind)
	require.Equal(t, kind.Int, gotTblMeta.Column("weight").Kind)

	// Alter age INTEGER -> TEXT, weight INTEGER -> REAL.
	alterColNames := []string{"age", "weight"}
	alterColKinds := []kind.Kind{kind.Text, kind.Float}

	err = drvr.AlterTableColumnKinds(th.Context, db, ogTbl.Name, alterColNames, alterColKinds)
	require.NoError(t, err)

	// The altered columns should now report their new kinds.
	gotTblMeta, err = grip.TableMetadata(th.Context, ogTbl.Name)
	require.NoError(t, err)
	require.Equal(t, 3, len(gotTblMeta.Columns))
	require.Equal(t, kind.Text.String(), gotTblMeta.Column("age").Kind.String())
	require.Equal(t, kind.Float.String(), gotTblMeta.Column("weight").Kind.String())
}
|
||||
|
@ -625,6 +625,11 @@ func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, c
|
||||
return errz.Wrapf(errw(err), "alter table: failed to rename column {%s.%s.%s} to {%s}", schma, tbl, col, newName)
|
||||
}
|
||||
|
||||
// AlterTableColumnKinds is not yet implemented for sqlserver.
// It always returns an error.
func (d *driveri) AlterTableColumnKinds(_ context.Context, _ sqlz.DB, _ string, _ []string, _ []kind.Kind) error {
	return errz.New("not implemented")
}
|
||||
|
||||
// CopyTable implements driver.SQLDriver.
|
||||
func (d *driveri) CopyTable(ctx context.Context, db sqlz.DB,
|
||||
fromTable, toTable tablefq.T, copyData bool,
|
||||
|
2
go.mod
2
go.mod
@ -10,6 +10,7 @@ require (
|
||||
github.com/alessio/shellescape v1.4.2
|
||||
github.com/antlr4-go/antlr/v4 v4.13.0
|
||||
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500
|
||||
github.com/djherbis/buffer v1.2.0
|
||||
github.com/dustin/go-humanize v1.0.1
|
||||
github.com/ecnepsnai/osquery v1.0.1
|
||||
github.com/emirpasic/gods v1.18.1
|
||||
@ -61,7 +62,6 @@ require (
|
||||
github.com/VividCortex/ewma v1.2.0 // indirect
|
||||
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/djherbis/buffer v1.2.0 // indirect
|
||||
github.com/felixge/fgprof v0.9.4 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
|
79
libsq/ast/antlrz/antlrz.go
Normal file
79
libsq/ast/antlrz/antlrz.go
Normal file
@ -0,0 +1,79 @@
|
||||
// Package antlrz contains utilities for working with ANTLR4.
|
||||
package antlrz
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
antlr "github.com/antlr4-go/antlr/v4"
|
||||
)
|
||||
|
||||
// TokenExtractor extracts the raw text of a parser rule from the input.
type TokenExtractor struct {
	// input is the full original input text.
	input string
	// lines is input split on "\n"; built lazily, guarded by once.
	lines []string
	// once guards the lazy construction of lines.
	once sync.Once
}
|
||||
|
||||
// NewTokenExtractor returns a new TokenExtractor for input.
func NewTokenExtractor(input string) *TokenExtractor {
	return &TokenExtractor{input: input}
}
|
||||
|
||||
// Offset returns the character offsets of the parser rule in the input:
// start is the offset of the rule's first character, and stop is the offset
// one past the rule's last character (stop = last token's column +
// len(token text), converted to an absolute offset — i.e. the half-open
// range [start, stop)).
func (l *TokenExtractor) Offset(prc antlr.ParserRuleContext) (start, stop int) {
	// Lazily split the input into lines; shared state with Extract.
	l.once.Do(func() {
		l.lines = strings.Split(l.input, "\n")
	})

	startToken := prc.GetStart()
	startLine := startToken.GetLine() - 1 // antlr line numbers are 1-indexed
	startCol := startToken.GetColumn()

	stopToken := prc.GetStop()
	stopLine := stopToken.GetLine() - 1
	stopCol := stopToken.GetColumn() + len(stopToken.GetText())

	// Convert line/column coordinates to absolute offsets by adding the
	// length (+1 for the newline) of every preceding line.
	for i := 0; i < startLine; i++ {
		startCol += len(l.lines[i]) + 1
	}

	for i := 0; i < stopLine; i++ {
		stopCol += len(l.lines[i]) + 1
	}

	return startCol, stopCol
}
|
||||
|
||||
// Extract extracts the raw text of the parser rule from the input. It may panic
|
||||
// if the parser rule is not found in the input.
|
||||
func (l *TokenExtractor) Extract(prc antlr.ParserRuleContext) string {
|
||||
l.once.Do(func() {
|
||||
l.lines = strings.Split(l.input, "\n")
|
||||
})
|
||||
|
||||
startToken := prc.GetStart()
|
||||
startLine := startToken.GetLine() - 1
|
||||
startCol := startToken.GetColumn()
|
||||
|
||||
stopToken := prc.GetStop()
|
||||
stopLine := stopToken.GetLine() - 1
|
||||
stopCol := stopToken.GetColumn() + len(stopToken.GetText())
|
||||
|
||||
if startLine == stopLine {
|
||||
return l.lines[startLine][startCol:stopCol]
|
||||
}
|
||||
|
||||
// multi-line
|
||||
var sb strings.Builder
|
||||
sb.WriteString(l.lines[startLine][startCol:])
|
||||
sb.WriteString("\n")
|
||||
for i := startLine + 1; i < stopLine; i++ {
|
||||
sb.WriteString(l.lines[i])
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString(l.lines[stopLine][:stopCol])
|
||||
|
||||
return sb.String()
|
||||
}
|
@ -2,8 +2,6 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/neilotoole/sq/libsq/core/errz"
|
||||
"github.com/neilotoole/sq/libsq/core/kind"
|
||||
)
|
||||
@ -28,6 +26,42 @@ type Table struct { //nolint:govet // field alignment
|
||||
Cols []*Column `json:"cols"`
|
||||
}
|
||||
|
||||
func (t *Table) Equal(b *Table) bool {
|
||||
if t == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if t == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
if t == b {
|
||||
return true
|
||||
}
|
||||
|
||||
if t.Name != b.Name {
|
||||
return false
|
||||
}
|
||||
|
||||
if t.PKColName != b.PKColName {
|
||||
return false
|
||||
}
|
||||
|
||||
if t.AutoIncrement != b.AutoIncrement {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(t.Cols) != len(b.Cols) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i, col := range t.Cols {
|
||||
if !col.Equal(b.Cols[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// NewTable is a convenience constructor for creating
|
||||
// a simple table definition.
|
||||
func NewTable(tblName string, colNames []string, colKinds []kind.Kind) *Table {
|
||||
@ -63,7 +97,7 @@ func (t *Table) ColKinds() []kind.Kind {
|
||||
}
|
||||
|
||||
func (t *Table) String() string {
|
||||
return t.Name + "(" + strings.Join(t.ColNames(), ",") + ")"
|
||||
return t.Name
|
||||
}
|
||||
|
||||
// ColsByName returns the ColDefs for each named column, or an error if any column
|
||||
@ -114,6 +148,52 @@ type Column struct { //nolint:govet // field alignment
|
||||
ForeignKey *FKConstraint `json:"foreign_key,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Column) Equal(b *Column) bool {
|
||||
if c == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if c == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
if c == b {
|
||||
return true
|
||||
}
|
||||
|
||||
if c.Name != b.Name {
|
||||
return false
|
||||
}
|
||||
|
||||
if c.Kind != b.Kind {
|
||||
return false
|
||||
}
|
||||
|
||||
if c.NotNull != b.NotNull {
|
||||
return false
|
||||
}
|
||||
|
||||
if c.HasDefault != b.HasDefault {
|
||||
return false
|
||||
}
|
||||
|
||||
if c.Size != b.Size {
|
||||
return false
|
||||
}
|
||||
|
||||
if c.Unique != b.Unique {
|
||||
return false
|
||||
}
|
||||
|
||||
if !c.ForeignKey.Equal(b.ForeignKey) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// String returns the column name, implementing fmt.Stringer.
func (c *Column) String() string {
	return c.Name
}
|
||||
|
||||
// FKConstraint models a foreign key constraint.
|
||||
type FKConstraint struct {
|
||||
// RefTable is the name of the referenced parent table.
|
||||
@ -125,3 +205,33 @@ type FKConstraint struct {
|
||||
// OnUpdate is one of CASCADE or SET_NULL, defaults to CASCADE.
|
||||
OnUpdate string `json:"on_update"`
|
||||
}
|
||||
|
||||
func (fk *FKConstraint) Equal(b *FKConstraint) bool {
|
||||
if fk == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if fk == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
if fk == b {
|
||||
return true
|
||||
}
|
||||
|
||||
if fk.RefTable != b.RefTable {
|
||||
return false
|
||||
}
|
||||
|
||||
if fk.RefCol != b.RefCol {
|
||||
return false
|
||||
}
|
||||
|
||||
if fk.OnDelete != b.OnDelete {
|
||||
return false
|
||||
}
|
||||
|
||||
if fk.OnUpdate != b.OnUpdate {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
@ -193,6 +193,11 @@ type SQLDriver interface {
|
||||
// AlterTableRenameColumn renames a column.
|
||||
AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, col, newName string) error
|
||||
|
||||
// AlterTableColumnKinds changes the kinds of colNames in tbl. The length of
|
||||
// args colNames and kinds must be equal. The method may create a transaction
|
||||
// internally if db is not already a transaction.
|
||||
AlterTableColumnKinds(ctx context.Context, db sqlz.DB, tbl string, colNames []string, kinds []kind.Kind) error
|
||||
|
||||
// DBProperties returns a map of key-value database properties. The value
|
||||
// is often a scalar such as an int, string, or bool, but can be a nested
|
||||
// map or array.
|
||||
|
Loading…
Reference in New Issue
Block a user