#415 - bug with ingesting large JSON file (#418)

* JSON large data file
This commit is contained in:
Neil O'Toole 2024-03-11 16:38:34 -06:00 committed by GitHub
parent f445f67b44
commit 5f531ffb9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1176988 additions and 133 deletions

View File

@ -145,8 +145,8 @@ type objectValueSet map[*entity]map[string]any
// processor process JSON objects.
type processor struct {
root *entity
importSchema *ingestSchema
root *entity
curSchema *ingestSchema
// schemaDirtyEntities tracks entities whose structure have been modified.
schemaDirtyEntities map[*entity]struct{}
@ -162,9 +162,13 @@ type processor struct {
func newProcessor(flatten bool) *processor {
return &processor{
flatten: flatten,
importSchema: &ingestSchema{},
root: &entity{name: source.MonotableName, detectors: map[string]*kind.Detector{}},
flatten: flatten,
curSchema: nil,
root: &entity{
name: source.MonotableName,
detectors: map[string]*kind.Detector{},
kinds: map[string]kind.Kind{},
},
schemaDirtyEntities: map[*entity]struct{}{},
}
}
@ -263,7 +267,7 @@ func (p *processor) buildSchemaFlat() (*ingestSchema, error) {
// processObject processes the parsed JSON object m. If the structure
// of the ingestSchema changes due to this object, dirtySchema returns true.
func (p *processor) processObject(m map[string]any, chunk []byte) (dirtySchema bool, err error) {
p.curObjVals = objectValueSet{}
p.curObjVals = make(objectValueSet)
err = p.doAddObject(p.root, m)
dirtySchema = len(p.schemaDirtyEntities) > 0
if err != nil {
@ -296,6 +300,8 @@ func (p *processor) updateColNames(chunk []byte) error {
}
func (p *processor) doAddObject(ent *entity, m map[string]any) error {
var err error
for fieldName, val := range m {
switch val := val.(type) {
case map[string]any:
@ -315,6 +321,7 @@ func (p *processor) doAddObject(ent *entity, m map[string]any) error {
name: fieldName,
parent: ent,
detectors: map[string]*kind.Detector{},
kinds: map[string]kind.Kind{},
}
ent.children = append(ent.children, child)
} else if child.isArray {
@ -324,8 +331,7 @@ func (p *processor) doAddObject(ent *entity, m map[string]any) error {
ent.String())
}
err := p.doAddObject(child, val)
if err != nil {
if err = p.doAddObject(child, val); err != nil {
return err
}
@ -358,14 +364,64 @@ func (p *processor) doAddObject(ent *entity, m map[string]any) error {
colName := p.calcColName(ent, fieldName)
entVals[colName] = val
val = maybeFloatToInt(val)
detector.Sample(val)
colDef := p.getColDef(ent, colName)
if colDef == nil && val != nil {
val = maybeFloatToInt(val)
// We don't need to keep sampling after we've detected the kind.
detector.Sample(val)
} else
// REVISIT: We don't need to hold onto the samples after we've detected
// the kind, it's just holding onto memory. We should probably nil out
// the detector.
// The column is already defined. Check if the value is allowed.
if !p.fieldValAllowed(detector, colDef, val) {
p.markSchemaDirty(ent)
}
}
}
return nil
}
// fieldValAllowed reports whether val is acceptable for the existing
// column definition col. A nil val or nil col is always allowed, as is
// any value when col's kind is permissive (Null, Unknown, or Text).
// Otherwise val is sampled into detector, and the detected kind must
// match col.Kind.
func (p *processor) fieldValAllowed(detector *kind.Detector, col *schema.Column, val any) bool {
	switch {
	case val == nil, col == nil:
		return true
	case col.Kind == kind.Null, col.Kind == kind.Unknown, col.Kind == kind.Text:
		// These column kinds accept any value.
		return true
	}

	detector.Sample(val)
	detected, _, err := detector.Detect()
	return err == nil && detected == col.Kind
}
// getColDef returns the schema.Column for colName in the table mapped
// to ent in the current schema, or nil if the schema, table, or column
// doesn't exist.
func (p *processor) getColDef(ent *entity, colName string) *schema.Column {
	if p == nil || p.curSchema == nil || p.curSchema.entityTbls == nil {
		return nil
	}

	// A missing map key yields a nil table, which is handled below.
	tblDef := p.curSchema.entityTbls[ent]
	if tblDef == nil {
		return nil
	}

	if colDef, err := tblDef.FindCol(colName); err == nil {
		return colDef
	}
	return nil
}
// buildInsertionsFlat builds a set of DB insertions from the
// processor's unwrittenObjVals. After a non-error return, unwrittenObjVals
// is empty.
@ -417,6 +473,10 @@ type entity struct {
// field etc, but not for an object or array field.
detectors map[string]*kind.Detector
// kinds is the sibling of detectors, holding a kind.Kind for each field,
// once the detector has detected the kind.
kinds map[string]kind.Kind
name string
children []*entity
@ -487,31 +547,117 @@ type ingestSchema struct {
tblDefs []*schema.Table
}
// getTable returns the table definition with the given name, or nil if
// s is nil or no table matches.
func (s *ingestSchema) getTable(name string) *schema.Table {
	if s == nil {
		return nil
	}

	for i := range s.tblDefs {
		if s.tblDefs[i].Name == name {
			return s.tblDefs[i]
		}
	}

	return nil
}
// execSchemaDelta executes the schema delta between curSchema and newSchema.
// That is, if curSchema is nil, then newSchema is created in the DB; if
// newSchema has additional tables or columns, then those are created in the DB.
//
// TODO: execSchemaDelta is only partially implemented; it doesn't create
// the new tables/columns.
func execSchemaDelta(ctx context.Context, drvr driver.SQLDriver, db sqlz.DB,
curSchema, newSchema *ingestSchema,
) error {
log := lg.FromContext(ctx)
var err error
if curSchema == nil {
for _, tblDef := range newSchema.tblDefs {
err = drvr.CreateTable(ctx, db, tblDef)
for _, tbl := range newSchema.tblDefs {
err = drvr.CreateTable(ctx, db, tbl)
if err != nil {
return err
}
log.Debug("Created table", lga.Table, tblDef.Name)
log.Debug("Created table", lga.Table, tbl.Name)
}
return nil
}
// TODO: implement execSchemaDelta fully.
return errz.New("schema delta not yet implemented")
var alterTbls []*schema.Table
var createTbls []*schema.Table
for _, newTbl := range newSchema.tblDefs {
oldTbl := curSchema.getTable(newTbl.Name)
if oldTbl == nil {
createTbls = append(createTbls, newTbl)
} else if !oldTbl.Equal(newTbl) {
alterTbls = append(alterTbls, newTbl)
}
}
for _, wantTbl := range alterTbls {
oldTbl := curSchema.getTable(wantTbl.Name)
if err = execMaybeAlterTable(ctx, drvr, db, oldTbl, wantTbl); err != nil {
return err
}
}
for _, wantTbl := range createTbls {
err = drvr.CreateTable(ctx, db, wantTbl)
if err != nil {
return err
}
log.Debug("Created table", lga.Table, wantTbl.Name)
}
return nil
}
// execMaybeAlterTable brings the DB table described by oldTbl into line
// with newTbl. If newTbl is nil, or the tables are already equal, it's
// a no-op. If oldTbl is nil, the table is created. Otherwise, columns
// whose kind differs are altered, and columns missing from oldTbl are
// added.
func execMaybeAlterTable(ctx context.Context, drvr driver.SQLDriver, db sqlz.DB,
	oldTbl, newTbl *schema.Table,
) error {
	log := lg.FromContext(ctx)

	switch {
	case newTbl == nil:
		return nil
	case oldTbl == nil:
		// Nothing to alter: just create the table.
		return drvr.CreateTable(ctx, db, newTbl)
	case oldTbl.Equal(newTbl):
		return nil
	}

	tblName := newTbl.Name

	// Partition newTbl's columns into those that must be added and
	// those whose kind must change.
	var (
		addCols    []*schema.Column
		alterNames []string
		alterKinds []kind.Kind
	)
	for _, wantCol := range newTbl.Cols {
		existing, err := oldTbl.FindCol(wantCol.Name)
		switch {
		case err != nil:
			// Column not found in the old table: it must be added.
			addCols = append(addCols, wantCol)
		case wantCol.Kind != existing.Kind:
			alterNames = append(alterNames, wantCol.Name)
			alterKinds = append(alterKinds, wantCol.Kind)
		}
	}

	if len(alterNames) > 0 {
		if err := drvr.AlterTableColumnKinds(ctx, db, tblName, alterNames, alterKinds); err != nil {
			return err
		}
	}

	for _, col := range addCols {
		if err := drvr.AlterTableAddColumn(ctx, db, tblName, col.Name, col.Kind); err != nil {
			return err
		}
		log.Debug("Added column", lga.Table, newTbl.Name, lga.Col, col.Name)
	}

	return nil
}
// columnOrderFlat parses the json chunk and returns a slice

View File

@ -6,12 +6,14 @@ import (
stdj "encoding/json"
"fmt"
"io"
"log/slog"
"time"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/lg/lgm"
"github.com/neilotoole/sq/libsq/core/progress"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/neilotoole/sq/libsq/files"
"github.com/neilotoole/sq/libsq/source/drivertype"
@ -105,7 +107,7 @@ func DetectJSON(sampleSize int) files.TypeDetectFunc {
}
defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r2)
sc := newObjectInArrayScanner(r2)
sc := newObjectInArrayScanner(log, r2)
var validObjCount int
var obj map[string]any
@ -140,6 +142,9 @@ func DetectJSON(sampleSize int) files.TypeDetectFunc {
}
func ingestJSON(ctx context.Context, job *ingestJob) error {
bar := progress.FromContext(ctx).NewUnitCounter("Ingest JSON", "object")
defer bar.Stop()
log := lg.FromContext(ctx)
defer lg.WarnIfCloseError(log, "Close JSON ingest job", job)
@ -163,13 +168,12 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
defer lg.WarnIfCloseError(log, lgm.CloseDB, conn)
proc := newProcessor(job.flatten)
scan := newObjectInArrayScanner(r)
scan := newObjectInArrayScanner(log, r)
var (
obj map[string]any
chunk []byte
schemaModified bool
curSchema *ingestSchema
insertions []*insertion
hasMore bool
)
@ -182,11 +186,14 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
// obj is returned nil by scan.next when end of input.
hasMore = obj != nil
if hasMore {
bar.Incr(1)
}
if schemaModified {
if !hasMore || scan.objCount >= job.sampleSize {
log.Debug("Time to (re)build the schema", lga.Line, scan.objCount)
if curSchema == nil {
if proc.curSchema == nil {
log.Debug("First time building the schema")
}
@ -195,7 +202,7 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
return err
}
if err = execSchemaDelta(ctx, drvr, conn, curSchema, newSchema); err != nil {
if err = execSchemaDelta(ctx, drvr, conn, proc.curSchema, newSchema); err != nil {
return err
}
@ -203,10 +210,11 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
// so we mark it as clean.
proc.markSchemaClean()
curSchema = newSchema
// curSchema = newSchema
proc.curSchema = newSchema
newSchema = nil //nolint:wastedassign
if insertions, err = proc.buildInsertionsFlat(curSchema); err != nil {
if insertions, err = proc.buildInsertionsFlat(proc.curSchema); err != nil {
return err
}
@ -214,11 +222,11 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
return err
}
}
}
if !hasMore {
// We're done
break
}
if !hasMore {
// end of input
break
}
if schemaModified, err = proc.processObject(obj, chunk); err != nil {
@ -227,7 +235,7 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
// Initial schema has not been created: we're still in
// the sampling phase. So we loop.
if curSchema == nil {
if proc.curSchema == nil {
continue
}
@ -240,7 +248,7 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
// The schema exists in the DB, and the current JSON chunk hasn't
// dirtied the schema, so it's safe to insert the recent rows.
if insertions, err = proc.buildInsertionsFlat(curSchema); err != nil {
if insertions, err = proc.buildInsertionsFlat(proc.curSchema); err != nil {
return err
}
@ -260,6 +268,7 @@ func ingestJSON(ctx context.Context, job *ingestJob) error {
// JSON objects, returning the decoded object and the chunk of JSON
// that it was scanned from. Example input: [{a:1},{a:2},{a:3}].
type objectsInArrayScanner struct {
log *slog.Logger
// buf will get all the data that the JSON decoder reads.
// buf's role is to keep track of JSON text that has already been
// consumed by dec, so that we can return the raw JSON chunk
@ -294,13 +303,13 @@ type objectsInArrayScanner struct {
// newObjectInArrayScanner returns a new instance that
// reads from r.
func newObjectInArrayScanner(r io.Reader) *objectsInArrayScanner {
func newObjectInArrayScanner(log *slog.Logger, r io.Reader) *objectsInArrayScanner {
buf := &buffer{b: []byte{}}
// Everything that dec reads from r is written
// to buf via the TeeReader.
dec := stdj.NewDecoder(io.TeeReader(r, buf))
return &objectsInArrayScanner{buf: buf, dec: dec}
return &objectsInArrayScanner{log: log, buf: buf, dec: dec}
}
// next scans the next object from the reader. The returned chunk holds
@ -358,10 +367,19 @@ func (s *objectsInArrayScanner) next() (obj map[string]any, chunk []byte, err er
return nil, nil, errz.Err(err)
}
more = s.dec.More()
var delimIndex int
var delim byte
if len(s.decBuf) == 0 {
// We've landed right on the edge of the chunk, there'll be no delim (or any
// other char), so we skip over delim searching.
goto BOTTOM
}
more = s.dec.More() // REVISIT: Should we not be testing this value?
// Peek ahead in the decoder buffer
delimIndex, delim := nextDelim(s.decBuf, 0, true)
delimIndex, delim = nextDelim(s.decBuf, 0, true)
if delimIndex == -1 {
return nil, nil, errz.New("invalid JSON: additional input expected")
}
@ -403,6 +421,7 @@ func (s *objectsInArrayScanner) next() (obj map[string]any, chunk []byte, err er
}
}
BOTTOM:
// Note that we re-use the vars delimIndex and delim here.
// Above us, these vars referred to s.decBuf, not s.buf as here.
delimIndex, delim = nextDelim(s.buf.b, s.prevDecPos-s.bufOffset, false)

View File

@ -12,6 +12,7 @@ import (
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/lg/lgm"
"github.com/neilotoole/sq/libsq/core/progress"
"github.com/neilotoole/sq/libsq/files"
"github.com/neilotoole/sq/libsq/source/drivertype"
)
@ -89,6 +90,8 @@ func DetectJSONL(sampleSize int) files.TypeDetectFunc {
func ingestJSONL(ctx context.Context, job *ingestJob) error { //nolint:gocognit
log := lg.FromContext(ctx)
defer lg.WarnIfCloseError(log, "Close JSONL ingest job", job)
bar := progress.FromContext(ctx).NewUnitCounter("Ingest JSONL", "object")
defer bar.Stop()
r, err := job.newRdrFn(ctx)
if err != nil {
@ -124,6 +127,10 @@ func ingestJSONL(ctx context.Context, job *ingestJob) error { //nolint:gocognit
return err
}
if hasMore {
bar.Incr(1)
}
if schemaModified {
if !hasMore || scan.validLineCount >= job.sampleSize {
log.Debug("Time to (re)build the schema", lga.Line, scan.totalLineCount)

View File

@ -12,14 +12,26 @@ import (
"github.com/stretchr/testify/require"
"github.com/neilotoole/sq/cli/testrun"
"github.com/neilotoole/sq/drivers/json"
"github.com/neilotoole/sq/libsq/core/kind"
"github.com/neilotoole/sq/libsq/core/lg/lgt"
"github.com/neilotoole/sq/libsq/driver"
"github.com/neilotoole/sq/libsq/source"
"github.com/neilotoole/sq/libsq/source/drivertype"
"github.com/neilotoole/sq/libsq/source/metadata"
"github.com/neilotoole/sq/testh"
"github.com/neilotoole/sq/testh/proj"
"github.com/neilotoole/sq/testh/sakila"
"github.com/neilotoole/sq/testh/testsrc"
"github.com/neilotoole/sq/testh/tu"
)
const (
citiesLargeObjCount = 146994
citiesSmallObjCount = 3
)
func BenchmarkIngestJSONL_Flat(b *testing.B) {
// $ go test -count=10 -benchtime=5s -bench BenchmarkIngestJSONL_Flat > old.bench.txt
// # Make changes
@ -303,9 +315,10 @@ func TestScanObjectsInArray(t *testing.T) {
t.Run(tu.Name(i, tc.in), func(t *testing.T) {
t.Parallel()
log := lgt.New(t)
r := bytes.NewReader([]byte(tc.in))
gotObjs, gotChunks, err := json.ScanObjectsInArray(r)
gotObjs, gotChunks, err := json.ScanObjectsInArray(log, r)
if tc.wantErr {
require.Error(t, err)
return
@ -332,6 +345,8 @@ func TestScanObjectsInArray_Files(t *testing.T) {
{fname: "testdata/actor.json", wantCount: sakila.TblActorCount},
{fname: "testdata/film_actor.json", wantCount: sakila.TblFilmActorCount},
{fname: "testdata/payment.json", wantCount: sakila.TblPaymentCount},
{fname: "testdata/cities.small.json", wantCount: citiesSmallObjCount},
{fname: "testdata/cities.large.json", wantCount: citiesLargeObjCount},
}
for _, tc := range testCases {
@ -339,12 +354,13 @@ func TestScanObjectsInArray_Files(t *testing.T) {
t.Run(tu.Name(tc.fname), func(t *testing.T) {
t.Parallel()
log := lgt.New(t)
f, err := os.Open(tc.fname)
require.NoError(t, err)
defer f.Close()
gotObjs, gotChunks, err := json.ScanObjectsInArray(f)
gotObjs, gotChunks, err := json.ScanObjectsInArray(log, f)
require.NoError(t, err)
require.Equal(t, tc.wantCount, len(gotObjs))
require.Equal(t, tc.wantCount, len(gotChunks))
@ -398,3 +414,71 @@ func TestColumnOrderFlat(t *testing.T) {
})
}
}
func TestJSONData_Cities(t *testing.T) {
t.Parallel()
tu.SkipWindows(t, "Takes too long on Windows CI")
tu.SkipShort(t, true)
const wantCSV = `name,lat,lng,country,admin1,admin2
Sant Julià de Lòria,42.46372,1.49129,AD,6,
Pas de la Casa,42.54277,1.73361,AD,3,
Ordino,42.55623,1.53319,AD,5,`
testCases := []struct {
name string
wantRowCount int
}{
{
name: "cities.small.json",
wantRowCount: citiesSmallObjCount,
},
{
name: "cities.large.json",
wantRowCount: citiesLargeObjCount,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tr := testrun.New(context.Background(), t, nil).Hush()
src := &source.Source{
Handle: "@cities",
Type: drivertype.JSON,
Location: proj.Abs(filepath.Join("drivers/json/testdata/", tc.name)),
}
tr = tr.Add(*src)
err := tr.Exec("--csv", ".data | .[0:3]")
require.NoError(t, err)
require.Equal(t, wantCSV, tr.OutString())
// FIXME: test inspect table row count
})
}
}
func Test_Cities_MixedFieldKind(t *testing.T) {
src := &source.Source{
Handle: "@cities",
Type: drivertype.JSON,
Location: proj.Abs("drivers/json/testdata/cities.sample-mixed-10.json"),
}
tr := testrun.New(context.Background(), t, nil).Hush()
tr.Add(*src)
require.NoError(t, tr.Exec("config", "set", driver.OptIngestSampleSize.Key(), "2"))
err := tr.Reset().Exec("inspect", "-j")
require.NoError(t, err)
var md *metadata.Source
tr.Bind(&md)
t.Log(md.Name)
colAdmin1 := md.Table("data").Column("admin1")
require.NotNil(t, colAdmin1)
require.Equal(t, kind.Text.String(), colAdmin1.Kind.String())
}

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"os"
"testing"
@ -75,8 +76,8 @@ func TestDetectColKindsJSONA(t *testing.T) {
// ScanObjectsInArray is a convenience function
// for objectsInArrayScanner.
func ScanObjectsInArray(r io.Reader) (objs []map[string]any, chunks [][]byte, err error) {
sc := newObjectInArrayScanner(r)
func ScanObjectsInArray(log *slog.Logger, r io.Reader) (objs []map[string]any, chunks [][]byte, err error) {
sc := newObjectInArrayScanner(log, r)
for {
var obj map[string]any

26
drivers/json/testdata/README.md vendored Normal file
View File

@ -0,0 +1,26 @@
# drivers/json/testdata
## Cities
The various "cities" JSON files contain an array of data like this:
```json
{
"name": "Sant Julià de Lòria",
"lat": "42.46372",
"lng": "1.49129",
"country": "AD",
"admin1": "06",
"admin2": ""
}
```
- [`cities.large.json`](cities.large.json) is `146,994` cities (~20MB uncompressed).
- [`cities.small.json`](cities.small.json) is 3 city objects.
- [`cities.sample-mixed-10.json`](cities.sample-mixed-10.json) contains 10 objects, but the `admin1` field
of the third object is the string `TEXT` instead of a number.

1175954
drivers/json/testdata/cities.large.json vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,82 @@
[
{
"name": "Sant Julià de Lòria",
"lat": "42.46372",
"lng": "1.49129",
"country": "AD",
"admin1": "06",
"admin2": ""
},
{
"name": "Pas de la Casa",
"lat": "42.54277",
"lng": "1.73361",
"country": "AD",
"admin1": "03",
"admin2": ""
},
{
"name": "Ordino",
"lat": "42.55623",
"lng": "1.53319",
"country": "AD",
"admin1": "TEXT",
"admin2": ""
},
{
"name": "les Escaldes",
"lat": "42.50729",
"lng": "1.53414",
"country": "AD",
"admin1": "08",
"admin2": ""
},
{
"name": "la Massana",
"lat": "42.54499",
"lng": "1.51483",
"country": "AD",
"admin1": "04",
"admin2": ""
},
{
"name": "Encamp",
"lat": "42.53474",
"lng": "1.58014",
"country": "AD",
"admin1": "03",
"admin2": ""
},
{
"name": "Canillo",
"lat": "42.5676",
"lng": "1.59756",
"country": "AD",
"admin1": "02",
"admin2": ""
},
{
"name": "Arinsal",
"lat": "42.57205",
"lng": "1.48453",
"country": "AD",
"admin1": "04",
"admin2": ""
},
{
"name": "Andorra la Vella",
"lat": "42.50779",
"lng": "1.52109",
"country": "AD",
"admin1": "07",
"admin2": ""
},
{
"name": "Umm Al Quwain City",
"lat": "25.56473",
"lng": "55.55517",
"country": "AE",
"admin1": "07",
"admin2": ""
}
]

26
drivers/json/testdata/cities.small.json vendored Normal file
View File

@ -0,0 +1,26 @@
[
{
"name": "Sant Julià de Lòria",
"lat": "42.46372",
"lng": "1.49129",
"country": "AD",
"admin1": "06",
"admin2": ""
},
{
"name": "Pas de la Casa",
"lat": "42.54277",
"lng": "1.73361",
"country": "AD",
"admin1": "03",
"admin2": ""
},
{
"name": "Ordino",
"lat": "42.55623",
"lng": "1.53319",
"country": "AD",
"admin1": "05",
"admin2": ""
}
]

View File

@ -354,6 +354,11 @@ func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, c
return errz.Wrapf(errw(err), "alter table: failed to rename column {%s.%s} to {%s}", tbl, col, newName)
}
// AlterTableColumnKinds implements driver.SQLDriver. It is not yet
// implemented for mysql, and always returns an error.
func (d *driveri) AlterTableColumnKinds(_ context.Context, _ sqlz.DB, _ string, _ []string, _ []kind.Kind) error {
	return errz.New("not implemented")
}
// PrepareInsertStmt implements driver.SQLDriver.
func (d *driveri) PrepareInsertStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string,
numRows int,

View File

@ -457,6 +457,11 @@ func (d *driveri) AlterTableAddColumn(ctx context.Context, db sqlz.DB, tbl, col
return nil
}
// AlterTableColumnKinds implements driver.SQLDriver. It is not yet
// implemented for postgres, and always returns an error.
func (d *driveri) AlterTableColumnKinds(_ context.Context, _ sqlz.DB, _ string, _ []string, _ []kind.Kind) error {
	return errz.New("not implemented")
}
// PrepareInsertStmt implements driver.SQLDriver.
func (d *driveri) PrepareInsertStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string,
numRows int,

149
drivers/sqlite3/alter.go Normal file
View File

@ -0,0 +1,149 @@
package sqlite3
import (
"context"
"fmt"
"strings"
"github.com/neilotoole/sq/drivers/sqlite3/internal/sqlparser"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/kind"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/sqlz"
"github.com/neilotoole/sq/libsq/core/stringz"
)
// AlterTableRename implements driver.SQLDriver. It renames table tbl to
// newName via ALTER TABLE ... RENAME TO. The returned error wraps any
// DB error with context; on success the wrap of a nil error is returned.
func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error {
	q := fmt.Sprintf(`ALTER TABLE %q RENAME TO %q`, tbl, newName)
	_, err := db.ExecContext(ctx, q)
	return errz.Wrapf(errw(err), "alter table: failed to rename table {%s} to {%s}", tbl, newName)
}
// AlterTableRenameColumn implements driver.SQLDriver. It renames column
// col of table tbl to newName via ALTER TABLE ... RENAME COLUMN. The
// returned error wraps any DB error with context; on success the wrap
// of a nil error is returned.
func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, col, newName string) error {
	q := fmt.Sprintf("ALTER TABLE %q RENAME COLUMN %q TO %q", tbl, col, newName)
	_, err := db.ExecContext(ctx, q)
	return errz.Wrapf(errw(err), "alter table: failed to rename column {%s.%s} to {%s}", tbl, col, newName)
}
// AlterTableAddColumn implements driver.SQLDriver. It adds column col
// of kind knd to table tbl, using the sqlite3 DB type mapped from knd.
func (d *driveri) AlterTableAddColumn(ctx context.Context, db sqlz.DB, tbl, col string, knd kind.Kind) error {
	stmt := fmt.Sprintf("ALTER TABLE %q ADD COLUMN %q ", tbl, col) + DBTypeForKind(knd)
	if _, err := db.ExecContext(ctx, stmt); err != nil {
		return errz.Wrapf(errw(err), "alter table: failed to add column {%s} to table {%s}", col, tbl)
	}
	return nil
}
// AlterTableColumnKinds implements driver.SQLDriver. Note that SQLite doesn't
// really support altering column types, so this is a hacky implementation.
// It's not guaranteed that indices, constraints, etc. will be preserved. See:
//
//   - https://www.sqlite.org/lang_altertable.html
//   - https://www.sqlite.org/faq.html#q11
//   - https://www.sqlitetutorial.net/sqlite-alter-table/
//
// The procedure is: read the table's original CREATE TABLE DDL from
// sqlite_master, rewrite the named columns' types in that DDL, create a
// temporary table from the rewritten DDL, copy all rows into it, drop
// the original table, and rename the temporary table back.
//
// Note that colNames and kinds must be the same length.
func (d *driveri) AlterTableColumnKinds(ctx context.Context, db sqlz.DB,
	tblName string, colNames []string, kinds []kind.Kind,
) (err error) {
	if len(colNames) != len(kinds) {
		return errz.New("sqlite3: alter table: mismatched count of columns and kinds")
	}
	// It's recommended to disable foreign keys before this alter procedure.
	// The restore func re-enables the pragma when we're done.
	if restorePragmaFK, fkErr := pragmaDisableForeignKeys(ctx, db); fkErr != nil {
		return fkErr
	} else if restorePragmaFK != nil {
		defer restorePragmaFK()
	}
	// Fetch the original CREATE TABLE DDL from sqlite_master.
	q := "SELECT sql FROM sqlite_master WHERE type='table' AND name=?"
	var ogDDL string
	if err = db.QueryRowContext(ctx, q, tblName).Scan(&ogDDL); err != nil {
		return errz.Wrapf(err, "sqlite3: alter table: failed to read original DDL")
	}
	allColDefs, err := sqlparser.ExtractCreateTableStmtColDefs(ogDDL)
	if err != nil {
		return errz.Wrapf(err, "sqlite3: alter table: failed to extract column definitions from DDL")
	}
	// Resolve each requested column name to its parsed def. The i+1 length
	// check detects a name that matched nothing in the DDL.
	var colDefs []*sqlparser.ColDef
	for i, colName := range colNames {
		for _, cd := range allColDefs {
			if cd.Name == colName {
				colDefs = append(colDefs, cd)
				break
			}
		}
		if len(colDefs) != i+1 {
			return errz.Errorf("sqlite3: alter table: column {%s} not found in table DDL", colName)
		}
	}
	// Rewrite each column def's type text in the DDL.
	nuDDL := ogDDL
	for i, colDef := range colDefs {
		wantType := DBTypeForKind(kinds[i])
		wantColDefText := strings.Replace(colDef.Raw, colDef.RawType, wantType, 1)
		nuDDL = strings.Replace(nuDDL, colDef.Raw, wantColDefText, 1)
	}
	nuTblName := "tmp_tbl_alter_" + stringz.Uniq32()
	// NOTE(review): this replaces the first textual occurrence of tblName
	// in the DDL, which could in theory match a substring before the real
	// table name — confirm this is safe for all table names we generate.
	nuDDL = strings.Replace(nuDDL, tblName, nuTblName, 1)
	if _, err = db.ExecContext(ctx, nuDDL); err != nil {
		return errz.Wrapf(err, "sqlite3: alter table: failed to create temporary table")
	}
	// Copy all rows into the temporary table. This relies on both tables
	// having identical column order.
	copyStmt := fmt.Sprintf(
		"INSERT INTO %s SELECT * FROM %s",
		stringz.DoubleQuote(nuTblName),
		stringz.DoubleQuote(tblName),
	)
	if _, err = db.ExecContext(ctx, copyStmt); err != nil {
		return errz.Wrapf(err, "sqlite3: alter table: failed to copy data to temporary table")
	}
	// Drop old table
	if _, err = db.ExecContext(ctx, "DROP TABLE "+stringz.DoubleQuote(tblName)); err != nil {
		return errz.Wrapf(err, "sqlite3: alter table: failed to drop original table")
	}
	// Rename new table to old table name
	if _, err = db.ExecContext(ctx, fmt.Sprintf(
		"ALTER TABLE %s RENAME TO %s",
		stringz.DoubleQuote(nuTblName),
		stringz.DoubleQuote(tblName),
	)); err != nil {
		return errz.Wrapf(err, "sqlite3: alter table: failed to rename temporary table")
	}
	return nil
}
// pragmaDisableForeignKeys disables the foreign_keys pragma, returning
// a func that restores the pragma to its prior value. If an error
// occurs, the returned restore func is nil. Any error during restore is
// logged, not returned.
func pragmaDisableForeignKeys(ctx context.Context, db sqlz.DB) (restore func(), err error) {
	ogValue, err := readPragma(ctx, db, "foreign_keys")
	if err != nil {
		return nil, errz.Wrapf(err, "sqlite3: alter table: failed to read foreign_keys pragma")
	}

	if _, err = db.ExecContext(ctx, "PRAGMA foreign_keys=off"); err != nil {
		return nil, errz.Wrapf(err, "sqlite3: alter table: failed to disable foreign_keys pragma")
	}

	restore = func() {
		if _, restoreErr := db.ExecContext(ctx, fmt.Sprintf("PRAGMA foreign_keys=%v", ogValue)); restoreErr != nil {
			lg.FromContext(ctx).Error("sqlite3: alter table: failed to restore foreign_keys pragma", lga.Err, restoreErr)
		}
	}
	return restore, nil
}

View File

@ -0,0 +1,143 @@
package sqlparser
import (
antlr "github.com/antlr4-go/antlr/v4"
"github.com/neilotoole/sq/drivers/sqlite3/internal/sqlparser/sqlite"
"github.com/neilotoole/sq/libsq/ast/antlrz"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/stringz"
)
// parseCreateTableStmt lexes and parses input as a SQLite CREATE TABLE
// statement, returning the parse tree context, or an error if either
// the lexer or parser reported problems.
func parseCreateTableStmt(input string) (*sqlite.Create_table_stmtContext, error) {
	lexer := sqlite.NewSQLiteLexer(antlr.NewInputStream(input))
	lexer.RemoveErrorListeners() // the generated lexer has default listeners we don't want
	lexErrs := &antlrErrorListener{name: "lexer"}
	lexer.AddErrorListener(lexErrs)

	parser := sqlite.NewSQLiteParser(antlr.NewCommonTokenStream(lexer, 0))
	parser.RemoveErrorListeners() // the generated parser has default listeners we don't want
	parseErrs := &antlrErrorListener{name: "parser"}
	parser.AddErrorListener(parseErrs)

	stmtCtx := parser.Create_table_stmt()

	// Surface any accumulated lexer errors first, then parser errors.
	if err := lexErrs.error(); err != nil {
		return nil, errz.Err(err)
	}
	if err := parseErrs.error(); err != nil {
		return nil, errz.Err(err)
	}

	return stmtCtx.(*sqlite.Create_table_stmtContext), nil
}
// ExtractTableIdentFromCreateTableStmt extracts table name (and the
// table's schema if specified) from a CREATE TABLE statement.
// If err is nil, table is guaranteed to be non-empty. If arg unescape is
// true, any surrounding quotation chars are trimmed from the returned values.
//
// CREATE TABLE "sakila"."actor" ( actor_id INTEGER NOT NULL) --> sakila, actor, nil
func ExtractTableIdentFromCreateTableStmt(stmt string, unescape bool) (schema, table string, err error) {
stmtCtx, err := parseCreateTableStmt(stmt)
if err != nil {
return "", "", err
}
if n, ok := stmtCtx.Schema_name().(*sqlite.Schema_nameContext); ok {
if n.Any_name() != nil && !n.Any_name().IsEmpty() && n.Any_name().IDENTIFIER() != nil {
schema = n.Any_name().IDENTIFIER().GetText()
if unescape {
schema = trimIdentQuotes(schema)
}
}
}
if x, ok := stmtCtx.Table_name().(*sqlite.Table_nameContext); ok {
if x.Any_name() != nil && !x.Any_name().IsEmpty() && x.Any_name().IDENTIFIER() != nil {
table = x.Any_name().IDENTIFIER().GetText()
if unescape {
table = trimIdentQuotes(table)
}
}
}
if table == "" {
return "", "", errz.Errorf("failed to extract table name from CREATE TABLE statement")
}
return schema, table, nil
}
// ExtractCreateTableStmtColDefs extracts the column definitions from a
// CREATE TABLE statement, in the order they appear in the statement.
func ExtractCreateTableStmtColDefs(stmt string) ([]*ColDef, error) {
	stmtCtx, err := parseCreateTableStmt(stmt)
	if err != nil {
		return nil, err
	}

	extractor := antlrz.NewTokenExtractor(stmt)
	var defs []*ColDef
	for _, child := range stmtCtx.GetChildren() {
		defCtx, ok := child.(*sqlite.Column_defContext)
		if !ok {
			// Not a column def: skip.
			continue
		}
		if defCtx == nil || defCtx.Column_name() == nil {
			// Shouldn't happen
			continue
		}
		typeCtx := defCtx.Type_name()
		if typeCtx == nil || typeCtx.GetText() == "" {
			// Shouldn't happen
			continue
		}

		def := &ColDef{
			DefCtx:  defCtx,
			Raw:     extractor.Extract(defCtx),
			RawName: extractor.Extract(defCtx.Column_name()),
			Name:    stringz.StripDoubleQuote(defCtx.Column_name().GetText()),
			RawType: extractor.Extract(typeCtx),
			Type:    typeCtx.GetText(),
		}
		def.InputOffset, _ = extractor.Offset(defCtx)
		defs = append(defs, def)
	}

	return defs, nil
}
// ColDef represents a column definition in a CREATE TABLE statement.
type ColDef struct {
	// DefCtx is the antlr context for the column definition.
	DefCtx *sqlite.Column_defContext

	// Raw is the raw text of the entire column definition, exactly as it
	// appeared in the input, e.g. `"name" TEXT NOT NULL`.
	Raw string

	// RawName is the raw text of the column name as it appeared in the input.
	// It may be double-quoted.
	RawName string

	// Name is the column name, stripped of any double-quotes.
	Name string

	// RawType is the raw text of the column type as it appeared in the input,
	// including any interior whitespace, e.g. `INTEGER( 10 )`.
	RawType string

	// Type is the canonicalized column type, e.g. `INTEGER(10)` for a
	// RawType of `INTEGER( 10 )`.
	Type string

	// InputOffset is the character start index of the column definition in the
	// input. The def ends at InputOffset+len(Raw).
	InputOffset int
}
// String returns the raw text of the column definition, i.e. cd.Raw.
func (cd *ColDef) String() string { return cd.Raw }

View File

@ -8,48 +8,8 @@ import (
"strings"
antlr "github.com/antlr4-go/antlr/v4"
"github.com/neilotoole/sq/drivers/sqlite3/internal/sqlparser/sqlite"
"github.com/neilotoole/sq/libsq/core/errz"
)
// ExtractTableIdentFromCreateTableStmt extracts table name (and the
// table's schema if specified) from a CREATE TABLE statement.
// If err is nil, table is guaranteed to be non-empty. If arg unescape is
// true, any surrounding quotation chars are trimmed from the returned values.
//
// CREATE TABLE "sakila"."actor" ( actor_id INTEGER NOT NULL) --> sakila, actor, nil
func ExtractTableIdentFromCreateTableStmt(stmt string, unescape bool) (schema, table string, err error) {
stmtCtx, err := parseCreateTableStmt(stmt)
if err != nil {
return "", "", err
}
if n, ok := stmtCtx.Schema_name().(*sqlite.Schema_nameContext); ok {
if n.Any_name() != nil && !n.Any_name().IsEmpty() && n.Any_name().IDENTIFIER() != nil {
schema = n.Any_name().IDENTIFIER().GetText()
if unescape {
schema = trimIdentQuotes(schema)
}
}
}
if x, ok := stmtCtx.Table_name().(*sqlite.Table_nameContext); ok {
if x.Any_name() != nil && !x.Any_name().IsEmpty() && x.Any_name().IDENTIFIER() != nil {
table = x.Any_name().IDENTIFIER().GetText()
if unescape {
table = trimIdentQuotes(table)
}
}
}
if table == "" {
return "", "", errz.Errorf("failed to extract table name from CREATE TABLE statement")
}
return schema, table, nil
}
// trimIdentQuotes trims any of the legal quote characters from s.
// These are double quote, single quote, backtick, and square brackets.
//
@ -78,30 +38,6 @@ func trimIdentQuotes(s string) string {
return s
}
// parseCreateTableStmt lexes and parses input as a CREATE TABLE statement,
// returning the resulting parse-tree context, or an error if input is not
// lexically or syntactically valid.
func parseCreateTableStmt(input string) (*sqlite.Create_table_stmtContext, error) {
	// The generated lexer and parser install default error listeners that
	// we don't want; swap in listeners that collect errors for us.
	lexer := sqlite.NewSQLiteLexer(antlr.NewInputStream(input))
	lexer.RemoveErrorListeners()
	lexErrs := &antlrErrorListener{name: "lexer"}
	lexer.AddErrorListener(lexErrs)

	parser := sqlite.NewSQLiteParser(antlr.NewCommonTokenStream(lexer, 0))
	parser.RemoveErrorListeners()
	parseErrs := &antlrErrorListener{name: "parser"}
	parser.AddErrorListener(parseErrs)

	stmtCtx := parser.Create_table_stmt()

	// Lexer errors are checked before parser errors, as in the original flow.
	if err := lexErrs.error(); err != nil {
		return nil, errz.Err(err)
	}
	if err := parseErrs.error(); err != nil {
		return nil, errz.Err(err)
	}

	return stmtCtx.(*sqlite.Create_table_stmtContext), nil
}
var _ antlr.ErrorListener = (*antlrErrorListener)(nil)
// antlrErrorListener implements antlr.ErrorListener.

View File

@ -67,3 +67,38 @@ func TestExtractTableNameFromCreateTableStmt(t *testing.T) {
})
}
}
// TestExtractCreateTableStmtColDefs verifies that column defs (raw text,
// names, types, and input offsets) are correctly extracted from a
// CREATE TABLE statement.
func TestExtractCreateTableStmtColDefs(t *testing.T) {
	const input = `CREATE TABLE "og_table" (
	"name" TEXT NOT NULL,
	"age" INTEGER( 10 ) NOT NULL,
	weight INTEGER NOT NULL
)`

	colDefs, err := sqlparser.ExtractCreateTableStmtColDefs(input)
	require.NoError(t, err)

	want := []struct {
		raw, rawName, name, typ, rawType string
	}{
		{`"name" TEXT NOT NULL`, `"name"`, `name`, "TEXT", "TEXT"},
		{`"age" INTEGER( 10 ) NOT NULL`, `"age"`, `age`, "INTEGER(10)", "INTEGER( 10 )"},
		{`weight INTEGER NOT NULL`, `weight`, `weight`, "INTEGER", "INTEGER"},
	}
	require.Len(t, colDefs, len(want))

	for i, w := range want {
		cd := colDefs[i]
		require.Equal(t, w.raw, cd.Raw)
		require.Equal(t, w.rawName, cd.RawName)
		require.Equal(t, w.name, cd.Name)
		require.Equal(t, w.typ, cd.Type)
		require.Equal(t, w.rawType, cd.RawType)
		// InputOffset must index the def's raw text within the input.
		snippet := input[cd.InputOffset : cd.InputOffset+len(cd.Raw)]
		require.Equal(t, cd.Raw, snippet)
	}
}

View File

@ -796,32 +796,6 @@ func (d *driveri) ListCatalogs(_ context.Context, _ sqlz.DB) ([]string, error) {
return nil, errz.New("sqlite3: catalog mechanism not supported")
}
// AlterTableRename implements driver.SQLDriver. It renames table tbl to
// newName via ALTER TABLE ... RENAME TO.
func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error {
	_, err := db.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %q RENAME TO %q`, tbl, newName))
	// NOTE: relies on errz.Wrapf passing through a nil err unchanged (success path).
	return errz.Wrapf(errw(err), "alter table: failed to rename table {%s} to {%s}", tbl, newName)
}
// AlterTableRenameColumn implements driver.SQLDriver. It renames column
// tbl.col to newName via ALTER TABLE ... RENAME COLUMN.
func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, col, newName string) error {
	_, err := db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %q RENAME COLUMN %q TO %q", tbl, col, newName))
	// NOTE: relies on errz.Wrapf passing through a nil err unchanged (success path).
	return errz.Wrapf(errw(err), "alter table: failed to rename column {%s.%s} to {%s}", tbl, col, newName)
}
// AlterTableAddColumn implements driver.SQLDriver. It adds column col of
// kind knd to table tbl via ALTER TABLE ... ADD COLUMN.
func (d *driveri) AlterTableAddColumn(ctx context.Context, db sqlz.DB, tbl, col string, knd kind.Kind) error {
	stmt := fmt.Sprintf("ALTER TABLE %q ADD COLUMN %q %s", tbl, col, DBTypeForKind(knd))
	if _, err := db.ExecContext(ctx, stmt); err != nil {
		return errz.Wrapf(errw(err), "alter table: failed to add column {%s} to table {%s}", col, tbl)
	}

	return nil
}
// TableExists implements driver.SQLDriver.
func (d *driveri) TableExists(ctx context.Context, db sqlz.DB, tbl string) (bool, error) {
const query = `SELECT COUNT(*) FROM sqlite_master WHERE name = ? AND type='table'`

View File

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/neilotoole/sq/drivers/sqlite3"
"github.com/neilotoole/sq/libsq/core/kind"
"github.com/neilotoole/sq/libsq/core/schema"
"github.com/neilotoole/sq/libsq/core/sqlz"
"github.com/neilotoole/sq/libsq/core/stringz"
@ -327,3 +328,66 @@ func TestSQLQuery_Whitespace(t *testing.T) {
require.Equal(t, "last name", sink.RecMeta[2].Name())
require.Equal(t, "last name", sink.RecMeta[2].MungedName())
}
// TestDriveri_AlterTableColumnKinds creates a table with columns
// (name TEXT, age INTEGER, weight INTEGER), alters age->TEXT and
// weight->FLOAT, and verifies the new kinds via table metadata.
func TestDriveri_AlterTableColumnKinds(t *testing.T) {
	th := testh.New(t)
	src := &source.Source{
		Handle:   "@test",
		Type:     drivertype.SQLite,
		Location: "sqlite3://" + tu.TempFile(t, "test.db"),
	}

	// Build the original table definition in a loop: all cols NOT NULL.
	tbl := &schema.Table{Name: "og_table"}
	for _, spec := range []struct {
		name string
		knd  kind.Kind
	}{
		{"name", kind.Text},
		{"age", kind.Int},
		{"weight", kind.Int},
	} {
		tbl.Cols = append(tbl.Cols, &schema.Column{
			Name:    spec.name,
			Table:   tbl,
			Kind:    spec.knd,
			NotNull: true,
		})
	}

	grip := th.Open(src)
	db, err := grip.DB(th.Context)
	require.NoError(t, err)
	drvr := grip.SQLDriver()

	require.NoError(t, drvr.CreateTable(th.Context, db, tbl))

	md, err := grip.TableMetadata(th.Context, tbl.Name)
	require.NoError(t, err)
	require.Len(t, md.Columns, 3)
	require.Equal(t, kind.Int, md.Column("age").Kind)
	require.Equal(t, kind.Int, md.Column("weight").Kind)

	// Alter age -> TEXT and weight -> FLOAT.
	err = drvr.AlterTableColumnKinds(th.Context, db, tbl.Name,
		[]string{"age", "weight"}, []kind.Kind{kind.Text, kind.Float})
	require.NoError(t, err)

	md, err = grip.TableMetadata(th.Context, tbl.Name)
	require.NoError(t, err)
	require.Len(t, md.Columns, 3)
	require.Equal(t, kind.Text.String(), md.Column("age").Kind.String())
	require.Equal(t, kind.Float.String(), md.Column("weight").Kind.String())
}

View File

@ -625,6 +625,11 @@ func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, c
return errz.Wrapf(errw(err), "alter table: failed to rename column {%s.%s.%s} to {%s}", schma, tbl, col, newName)
}
// AlterTableColumnKinds is not yet implemented for sqlserver; it always
// returns a "not implemented" error. It exists so that driveri satisfies
// the driver.SQLDriver interface.
func (d *driveri) AlterTableColumnKinds(_ context.Context, _ sqlz.DB, _ string, _ []string, _ []kind.Kind) error {
	return errz.New("not implemented")
}
// CopyTable implements driver.SQLDriver.
func (d *driveri) CopyTable(ctx context.Context, db sqlz.DB,
fromTable, toTable tablefq.T, copyData bool,

2
go.mod
View File

@ -10,6 +10,7 @@ require (
github.com/alessio/shellescape v1.4.2
github.com/antlr4-go/antlr/v4 v4.13.0
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500
github.com/djherbis/buffer v1.2.0
github.com/dustin/go-humanize v1.0.1
github.com/ecnepsnai/osquery v1.0.1
github.com/emirpasic/gods v1.18.1
@ -61,7 +62,6 @@ require (
github.com/VividCortex/ewma v1.2.0 // indirect
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/djherbis/buffer v1.2.0 // indirect
github.com/felixge/fgprof v0.9.4 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect

View File

@ -0,0 +1,79 @@
// Package antlrz contains utilities for working with ANTLR4.
package antlrz
import (
"strings"
"sync"
antlr "github.com/antlr4-go/antlr/v4"
)
// TokenExtractor extracts the raw text of a parser rule from the input.
type TokenExtractor struct {
	// input is the complete source text being parsed.
	input string
	// lines is input split on "\n"; populated lazily, guarded by once.
	lines []string
	// once guards the one-time population of lines.
	once sync.Once
}
// NewTokenExtractor returns a new TokenExtractor for the given input text.
func NewTokenExtractor(input string) *TokenExtractor {
	te := &TokenExtractor{input: input}
	return te
}
// Offset returns the absolute byte offsets of the parser rule in the input:
// start is the index of the rule's first character, and stop is the index
// just past its last character (the stop token's column plus the length of
// its text), so input[start:stop] yields the rule's raw text. Note that
// stop is therefore an exclusive bound, not inclusive.
func (l *TokenExtractor) Offset(prc antlr.ParserRuleContext) (start, stop int) {
	// Lazily split input into lines exactly once; the cached split is
	// shared with Extract.
	l.once.Do(func() {
		l.lines = strings.Split(l.input, "\n")
	})

	startToken := prc.GetStart()
	startLine := startToken.GetLine() - 1 // ANTLR line numbers are 1-based.
	startCol := startToken.GetColumn()

	stopToken := prc.GetStop()
	stopLine := stopToken.GetLine() - 1
	// NOTE(review): len() counts bytes; this assumes ANTLR columns are also
	// byte-based — confirm behavior for non-ASCII input.
	stopCol := stopToken.GetColumn() + len(stopToken.GetText())

	// Convert the line-relative columns to absolute input offsets by adding
	// the lengths of all preceding lines (+1 for each "\n" removed by Split).
	for i := 0; i < startLine; i++ {
		startCol += len(l.lines[i]) + 1
	}
	for i := 0; i < stopLine; i++ {
		stopCol += len(l.lines[i]) + 1
	}

	return startCol, stopCol
}
// Extract extracts the raw text of the parser rule from the input. It may panic
// if the parser rule is not found in the input.
func (l *TokenExtractor) Extract(prc antlr.ParserRuleContext) string {
	// Offset returns the rule's absolute [start, stop) byte offsets within
	// l.input, so a single slice of the original input replaces the previous
	// duplicated line-by-line offset computation and reassembly: because
	// Offset sums the same line lengths (+1 per newline) that Split removed,
	// input[start:stop] is byte-identical to rejoining the line fragments.
	start, stop := l.Offset(prc)
	return l.input[start:stop]
}

View File

@ -2,8 +2,6 @@
package schema
import (
"strings"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/kind"
)
@ -28,6 +26,42 @@ type Table struct { //nolint:govet // field alignment
Cols []*Column `json:"cols"`
}
// Equal returns true if t and b are equal: same name, PK column,
// auto-increment setting, and pairwise-equal columns. Two nil tables
// are considered equal.
func (t *Table) Equal(b *Table) bool {
	switch {
	case t == nil && b == nil:
		return true
	case t == nil || b == nil:
		return false
	case t == b:
		// Identical pointer: trivially equal.
		return true
	}

	if t.Name != b.Name ||
		t.PKColName != b.PKColName ||
		t.AutoIncrement != b.AutoIncrement ||
		len(t.Cols) != len(b.Cols) {
		return false
	}

	for i := range t.Cols {
		if !t.Cols[i].Equal(b.Cols[i]) {
			return false
		}
	}

	return true
}
// NewTable is a convenience constructor for creating
// a simple table definition.
func NewTable(tblName string, colNames []string, colKinds []kind.Kind) *Table {
@ -63,7 +97,7 @@ func (t *Table) ColKinds() []kind.Kind {
}
func (t *Table) String() string {
return t.Name + "(" + strings.Join(t.ColNames(), ",") + ")"
return t.Name
}
// ColsByName returns the ColDefs for each named column, or an error if any column
@ -114,6 +148,52 @@ type Column struct { //nolint:govet // field alignment
ForeignKey *FKConstraint `json:"foreign_key,omitempty"`
}
// Equal returns true if c and b are equal: same name, kind, nullability,
// default, size, uniqueness, and foreign-key constraint. Two nil columns
// are considered equal.
func (c *Column) Equal(b *Column) bool {
	switch {
	case c == nil && b == nil:
		return true
	case c == nil || b == nil:
		return false
	case c == b:
		// Identical pointer: trivially equal.
		return true
	}

	if c.Name != b.Name ||
		c.Kind != b.Kind ||
		c.NotNull != b.NotNull ||
		c.HasDefault != b.HasDefault ||
		c.Size != b.Size ||
		c.Unique != b.Unique {
		return false
	}

	return c.ForeignKey.Equal(b.ForeignKey)
}
// String returns the column name.
func (c *Column) String() string {
	return c.Name
}
// FKConstraint models a foreign key constraint.
type FKConstraint struct {
// RefTable is the name of the referenced parent table.
@ -125,3 +205,33 @@ type FKConstraint struct {
// OnUpdate is one of CASCADE or SET_NULL, defaults to CASCADE.
OnUpdate string `json:"on_update"`
}
// Equal returns true if fk and b are equal: same referenced table and
// column, and same ON DELETE / ON UPDATE actions. Two nil constraints
// are considered equal.
func (fk *FKConstraint) Equal(b *FKConstraint) bool {
	switch {
	case fk == nil && b == nil:
		return true
	case fk == nil || b == nil:
		return false
	case fk == b:
		// Identical pointer: trivially equal.
		return true
	}

	return fk.RefTable == b.RefTable &&
		fk.RefCol == b.RefCol &&
		fk.OnDelete == b.OnDelete &&
		fk.OnUpdate == b.OnUpdate
}

View File

@ -193,6 +193,11 @@ type SQLDriver interface {
// AlterTableRenameColumn renames a column.
AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, col, newName string) error
// AlterTableColumnKinds changes the kinds of colNames in tbl. The length of
// args colNames and kinds must be equal. The method may create a transaction
// internally if db is not already a transaction.
AlterTableColumnKinds(ctx context.Context, db sqlz.DB, tbl string, colNames []string, kinds []kind.Kind) error
// DBProperties returns a map of key-value database properties. The value
// is often a scalar such as an int, string, or bool, but can be a nested
// map or array.