diff --git a/cli/cmd_slq.go b/cli/cmd_slq.go index 5f8c6ecc..4d334f17 100644 --- a/cli/cmd_slq.go +++ b/cli/cmd_slq.go @@ -129,7 +129,15 @@ func execSLQInsert(ctx context.Context, rc *RunContext, destSrc *source.Source, driver.Tuning.RecordChSize, libsq.DBWriterCreateTableIfNotExistsHook(destTbl), ) - execErr := libsq.ExecuteSLQ(ctx, rc.Log, rc.databases, rc.databases, srcs, slq, inserter) + + qc := &libsq.QueryContext{ + Sources: srcs, + DBOpener: rc.databases, + JoinDBOpener: rc.databases, + Args: nil, + } + + execErr := libsq.ExecuteSLQ(ctx, rc.Log, qc, slq, inserter) affected, waitErr := inserter.Wait() // Wait for the writer to finish processing if execErr != nil { return errz.Wrapf(execErr, "insert %s.%s failed", destSrc.Handle, destTbl) @@ -150,8 +158,15 @@ func execSLQPrint(ctx context.Context, rc *RunContext) error { return err } + qc := &libsq.QueryContext{ + Sources: rc.Config.Sources, + DBOpener: rc.databases, + JoinDBOpener: rc.databases, + Args: nil, + } + recw := output.NewRecordWriterAdapter(rc.writers.recordw) - execErr := libsq.ExecuteSLQ(ctx, rc.Log, rc.databases, rc.databases, rc.Config.Sources, slq, recw) + execErr := libsq.ExecuteSLQ(ctx, rc.Log, qc, slq, recw) _, waitErr := recw.Wait() if execErr != nil { return execErr diff --git a/cli/completion.go b/cli/completion.go index c678436e..c0f6ae97 100644 --- a/cli/completion.go +++ b/cli/completion.go @@ -16,7 +16,7 @@ type completionFunc func(cmd *cobra.Command, args []string, toComplete string) ( var ( _ completionFunc = completeDriverType _ completionFunc = completeSLQ - _ completionFunc = new(handleTableCompleter).complete + _ completionFunc = (*handleTableCompleter)(nil).complete ) // completeHandle is a completionFunc that suggests handles. @@ -110,8 +110,6 @@ type handleTableCompleter struct { max int } -var _ completionFunc = (*handleTableCompleter)(nil).complete - // complete is the completionFunc for handleTableCompleter. func (c *handleTableCompleter) complete(cmd *cobra.Command, args []string, toComplete string, diff --git a/libsq/engine.go b/libsq/engine.go index 05e6c12a..8673cf0a 100644 --- a/libsq/engine.go +++ b/libsq/engine.go @@ -12,15 +12,13 @@ import ( "github.com/neilotoole/sq/libsq/core/sqlmodel" "github.com/neilotoole/sq/libsq/core/sqlz" "github.com/neilotoole/sq/libsq/driver" - "github.com/neilotoole/sq/libsq/source" ) // engine executes a queryModel and writes to a RecordWriter. type engine struct { - log lg.Log - srcs *source.Set - dbOpener driver.DatabaseOpener - joinDBOpener driver.JoinDatabaseOpener + log lg.Log + + qc *QueryContext // tasks contains tasks that must be completed before targetSQL // is executed against targetDB. Typically tasks is used to @@ -36,6 +34,29 @@ type engine struct { targetDB driver.Database } +func newEngine(ctx context.Context, log lg.Log, qc *QueryContext, query string) (*engine, error) { + a, err := ast.Parse(log, query) + if err != nil { + return nil, err + } + + qModel, err := buildQueryModel(log, a) + if err != nil { + return nil, err + } + + ng := &engine{ + log: log, + qc: qc, + } + + if err = ng.prepare(ctx, qModel); err != nil { + return nil, err + } + + return ng, nil +} + // prepare prepares the engine to execute queryModel. // When this method returns, targetDB and targetSQL will be set, // as will any tasks (which may be empty). The tasks must be executed @@ -156,12 +177,12 @@ func (ng *engine) executeTasks(ctx context.Context) error { func (ng *engine) buildTableFromClause(ctx context.Context, tblSel *ast.TblSelectorNode) (fromClause string, fromConn driver.Database, err error, ) { - src, err := ng.srcs.Get(tblSel.Handle()) + src, err := ng.qc.Sources.Get(tblSel.Handle()) if err != nil { return "", nil, err } - fromConn, err = ng.dbOpener.Open(ctx, src) + fromConn, err = ng.qc.DBOpener.Open(ctx, src) if err != nil { return "", nil, err } @@ -196,12 +217,12 @@ func (ng *engine) buildJoinFromClause(ctx context.Context, fnJoin *ast.JoinNode) func (ng *engine) singleSourceJoin(ctx context.Context, fnJoin *ast.JoinNode) (fromClause string, fromDB driver.Database, err error, ) { - src, err := ng.srcs.Get(fnJoin.LeftTbl().Handle()) + src, err := ng.qc.Sources.Get(fnJoin.LeftTbl().Handle()) if err != nil { return "", nil, err } - fromDB, err = ng.dbOpener.Open(ctx, src) + fromDB, err = ng.qc.DBOpener.Open(ctx, src) if err != nil { return "", nil, err } @@ -226,23 +247,23 @@ func (ng *engine) crossSourceJoin(ctx context.Context, fnJoin *ast.JoinNode) (fr fnJoin.LeftTbl().TblName()) } - leftSrc, err := ng.srcs.Get(fnJoin.LeftTbl().Handle()) + leftSrc, err := ng.qc.Sources.Get(fnJoin.LeftTbl().Handle()) if err != nil { return "", nil, err } - rightSrc, err := ng.srcs.Get(fnJoin.RightTbl().Handle()) + rightSrc, err := ng.qc.Sources.Get(fnJoin.RightTbl().Handle()) if err != nil { return "", nil, err } // Open the join db - joinDB, err := ng.joinDBOpener.OpenJoin(ctx, leftSrc, rightSrc) + joinDB, err := ng.qc.JoinDBOpener.OpenJoin(ctx, leftSrc, rightSrc) if err != nil { return "", nil, err } - leftDB, err := ng.dbOpener.Open(ctx, leftSrc) + leftDB, err := ng.qc.DBOpener.Open(ctx, leftSrc) if err != nil { return "", nil, err } @@ -253,7 +274,7 @@ func (ng *engine) crossSourceJoin(ctx context.Context, fnJoin *ast.JoinNode) (fr toTblName: leftTblName, } - rightDB, err := ng.dbOpener.Open(ctx, rightSrc) + rightDB, err := ng.qc.DBOpener.Open(ctx, rightSrc) if err != nil { return "", nil, err } @@ -276,20 +297,6 @@ func (ng *engine) crossSourceJoin(ctx context.Context, fnJoin *ast.JoinNode) (fr return fromClause, joinDB, nil } -// SLQ2SQL simulates execution of a SLQ query, but instead of executing -// the resulting SQL query, that ultimate SQL is returned. Effectively it is -// equivalent to libsq.ExecuteSLQ, but without the execution. -func SLQ2SQL(ctx context.Context, log lg.Log, dbOpener driver.DatabaseOpener, - joinDBOpener driver.JoinDatabaseOpener, srcs *source.Set, query string, -) (targetSQL string, err error) { - var ng *engine - ng, err = newEngine(ctx, log, dbOpener, joinDBOpener, srcs, query) - if err != nil { - return "", err - } - return ng.targetSQL, nil -} - // tasker is the interface for executing a DB task. type tasker interface { // executeTask executes a task against the DB. diff --git a/libsq/libsq.go b/libsq/libsq.go index 017fce31..4e237fea 100644 --- a/libsq/libsq.go +++ b/libsq/libsq.go @@ -13,13 +13,28 @@ import ( "context" "github.com/neilotoole/lg" - "github.com/neilotoole/sq/libsq/ast" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/sqlz" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" ) +// QueryContext encapsulates the context a SLQ query is executed within. +type QueryContext struct { + // Sources is the set of sources. + Sources *source.Set + + // DBOpener is used to open databases. + DBOpener driver.DatabaseOpener + + // JoinDBOpener is used to open the joindb (if needed). + JoinDBOpener driver.JoinDatabaseOpener + + // Args defines variables that are substituted into the query. + // May be nil or empty. + Args map[string]string +} + // RecordWriter is the interface for writing records to a // destination. The Open method returns a channel to // which the records are sent. The Wait method allows @@ -71,11 +86,10 @@ type RecordWriter interface { } // ExecuteSLQ executes the slq query, writing the results to recw. -// The caller is responsible for closing dbases. -func ExecuteSLQ(ctx context.Context, log lg.Log, dbOpener driver.DatabaseOpener, joinDBOpener driver.JoinDatabaseOpener, - srcs *source.Set, query string, recw RecordWriter, +// The caller is responsible for closing qc. +func ExecuteSLQ(ctx context.Context, log lg.Log, qc *QueryContext, query string, recw RecordWriter, ) error { - ng, err := newEngine(ctx, log, dbOpener, joinDBOpener, srcs, query) + ng, err := newEngine(ctx, log, qc, query) if err != nil { return err } @@ -83,31 +97,17 @@ func ExecuteSLQ(ctx context.Context, log lg.Log, dbOpener driver.DatabaseOpener, return ng.execute(ctx, recw) } -func newEngine(ctx context.Context, log lg.Log, dbOpener driver.DatabaseOpener, joinDBOpener driver.JoinDatabaseOpener, - srcs *source.Set, query string, -) (*engine, error) { - a, err := ast.Parse(log, query) +// SLQ2SQL simulates execution of a SLQ query, but instead of executing +// the resulting SQL query, that ultimate SQL is returned. Effectively it is +// equivalent to libsq.ExecuteSLQ, but without the execution. +func SLQ2SQL(ctx context.Context, log lg.Log, qc *QueryContext, query string, +) (targetSQL string, err error) { + var ng *engine + ng, err = newEngine(ctx, log, qc, query) if err != nil { - return nil, err + return "", err } - - qModel, err := buildQueryModel(log, a) - if err != nil { - return nil, err - } - - ng := &engine{ - log: log, - srcs: srcs, - dbOpener: dbOpener, - joinDBOpener: joinDBOpener, - } - - if err = ng.prepare(ctx, qModel); err != nil { - return nil, err - } - - return ng, nil + return ng.targetSQL, nil } // QuerySQL executes the SQL query against dbase, writing diff --git a/drivers/query_args_test.go b/libsq/query_args_test.go similarity index 97% rename from drivers/query_args_test.go rename to libsq/query_args_test.go index ec1b2365..803083c3 100644 --- a/drivers/query_args_test.go +++ b/libsq/query_args_test.go @@ -1,4 +1,4 @@ -package drivers_test +package libsq_test import ( "testing" diff --git a/drivers/query_cols_test.go b/libsq/query_cols_test.go similarity index 99% rename from drivers/query_cols_test.go rename to libsq/query_cols_test.go index 227b71f8..5b2c54b9 100644 --- a/drivers/query_cols_test.go +++ b/libsq/query_cols_test.go @@ -1,4 +1,4 @@ -package drivers_test +package libsq_test import ( "testing" diff --git a/drivers/query_count_test.go b/libsq/query_count_test.go similarity index 99% rename from drivers/query_count_test.go rename to libsq/query_count_test.go index 708256ac..f935041c 100644 --- a/drivers/query_count_test.go +++ b/libsq/query_count_test.go @@ -1,4 +1,4 @@ -package drivers_test +package libsq_test import ( "testing" diff --git a/drivers/query_datetime_test.go b/libsq/query_datetime_test.go similarity index 98% rename from drivers/query_datetime_test.go rename to libsq/query_datetime_test.go index fb39aa79..13532501 100644 --- a/drivers/query_datetime_test.go +++ b/libsq/query_datetime_test.go @@ -1,4 +1,4 @@ -package drivers_test +package libsq_test import ( "testing" diff --git a/drivers/query_filter_test.go b/libsq/query_filter_test.go similarity index 96% rename from drivers/query_filter_test.go rename to libsq/query_filter_test.go index 7a9365f1..0878bc6f 100644 --- a/drivers/query_filter_test.go +++ b/libsq/query_filter_test.go @@ -1,4 +1,4 @@ -package drivers_test +package libsq_test import ( "testing" diff --git a/drivers/query_groupby_test.go b/libsq/query_groupby_test.go similarity index 98% rename from drivers/query_groupby_test.go rename to libsq/query_groupby_test.go index f43f3185..b009e3e3 100644 --- a/drivers/query_groupby_test.go +++ b/libsq/query_groupby_test.go @@ -1,4 +1,4 @@ -package drivers_test +package libsq_test import ( "testing" diff --git a/drivers/query_join_test.go b/libsq/query_join_test.go similarity index 98% rename from drivers/query_join_test.go rename to libsq/query_join_test.go index 6582517d..5bb8881a 100644 --- a/drivers/query_join_test.go +++ b/libsq/query_join_test.go @@ -1,4 +1,4 @@ -package drivers_test +package libsq_test import ( "testing" diff --git a/drivers/query_orderby_test.go b/libsq/query_orderby_test.go similarity index 99% rename from drivers/query_orderby_test.go rename to libsq/query_orderby_test.go index ad3bb68b..3832c075 100644 --- a/drivers/query_orderby_test.go +++ b/libsq/query_orderby_test.go @@ -1,4 +1,4 @@ -package drivers_test +package libsq_test import ( "testing" diff --git a/drivers/query_test.go b/libsq/query_test.go similarity index 50% rename from drivers/query_test.go rename to libsq/query_test.go index c9159684..a610d82a 100644 --- a/drivers/query_test.go +++ b/libsq/query_test.go @@ -1,16 +1,15 @@ -package drivers_test +package libsq_test import ( "strings" "testing" - "golang.org/x/exp/slices" + "github.com/neilotoole/sq/libsq" - "github.com/neilotoole/sq/drivers/mysql" + "golang.org/x/exp/slices" "github.com/neilotoole/sq/libsq/source" - "github.com/neilotoole/sq/libsq" "github.com/stretchr/testify/require" _ "github.com/mattn/go-sqlite3" @@ -67,7 +66,6 @@ func execQueryTestCase(t *testing.T, tc queryTestCase) { } srcs := testh.New(t).NewSourceSet(sakila.SQLLatest()...) - // srcs := testh.New(t).NewSourceSet(sakila.SL3) // FIXME: remove when done debugging for _, src := range srcs.Items() { src := src @@ -91,7 +89,14 @@ func execQueryTestCase(t *testing.T, tc queryTestCase) { th := testh.New(t) dbases := th.Databases() - gotSQL, gotErr := libsq.SLQ2SQL(th.Context, th.Log, dbases, dbases, srcs, in) + qc := &libsq.QueryContext{ + Sources: srcs, + DBOpener: dbases, + JoinDBOpener: dbases, + Args: tc.args, + } + + gotSQL, gotErr := libsq.SLQ2SQL(th.Context, th.Log, qc, in) if tc.wantErr { require.Error(t, gotErr) return @@ -111,68 +116,3 @@ func execQueryTestCase(t *testing.T, tc queryTestCase) { }) } } - -//nolint:exhaustive,lll -func TestSLQ2SQL(t *testing.T) { - testCases := []queryTestCase{ - { - name: "select/cols", - in: `@sakila | .actor | .first_name, .last_name`, - wantSQL: `SELECT "first_name", "last_name" FROM "actor"`, - override: map[source.Type]string{mysql.Type: "SELECT `first_name`, `last_name` FROM `actor`"}, - wantRecs: sakila.TblActorCount, - }, - { - name: "select/cols-whitespace-single-col", - in: `@sakila | .actor | ."first name"`, - wantSQL: `SELECT "first name" FROM "actor"`, - override: map[source.Type]string{mysql.Type: "SELECT `first name` FROM `actor`"}, - wantRecs: sakila.TblActorCount, - skipExec: true, - }, - { - name: "select/cols-whitespace-multiple-cols", - in: `@sakila | .actor | .actor_id, ."first name", ."last name"`, - wantSQL: `SELECT "actor_id", "first name", "last name" FROM "actor"`, - override: map[source.Type]string{mysql.Type: "SELECT `actor_id`, `first name`, `last name` FROM `actor`"}, - wantRecs: sakila.TblActorCount, - skipExec: true, - }, - { - name: "count/whitespace-col", - in: `@sakila | .actor | count(."first name")`, - wantSQL: `SELECT count("first name") FROM "actor"`, - override: map[source.Type]string{mysql.Type: "SELECT count(`first name`) FROM `actor`"}, - skipExec: true, - }, - { - name: "select/table-whitespace", - in: `@sakila | ."film actor"`, - wantSQL: `SELECT * FROM "film actor"`, - override: map[source.Type]string{mysql.Type: "SELECT * FROM `film actor`"}, - skipExec: true, - }, - { - name: "select/cols-aliases", - in: `@sakila | .actor | .first_name:given_name, .last_name:family_name`, - wantSQL: `SELECT "first_name" AS "given_name", "last_name" AS "family_name" FROM "actor"`, - override: map[source.Type]string{mysql.Type: "SELECT `first_name` AS `given_name`, `last_name` AS `family_name` FROM `actor`"}, - wantRecs: sakila.TblActorCount, - }, - - { - name: "select/handle-table/cols", - in: `@sakila.actor | .first_name, .last_name`, - wantSQL: `SELECT "first_name", "last_name" FROM "actor"`, - override: map[source.Type]string{mysql.Type: "SELECT `first_name`, `last_name` FROM `actor`"}, - wantRecs: sakila.TblActorCount, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - execQueryTestCase(t, tc) - }) - } -} diff --git a/drivers/query_unique_test.go b/libsq/query_unique_test.go similarity index 98% rename from drivers/query_unique_test.go rename to libsq/query_unique_test.go index 28db726f..1d8b92dd 100644 --- a/drivers/query_unique_test.go +++ b/libsq/query_unique_test.go @@ -1,4 +1,4 @@ -package drivers_test +package libsq_test import ( "testing" diff --git a/testh/testh.go b/testh/testh.go index f59f0619..73782b27 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -458,12 +458,16 @@ func (h *Helper) QuerySLQ(query string) (*RecordSink, error) { _ = h.Source(handle) } - srcs := h.srcs - dbases := h.Databases() + qc := &libsq.QueryContext{ + Sources: h.srcs, + DBOpener: h.databases, + JoinDBOpener: h.databases, + } + sink := &RecordSink{} recw := output.NewRecordWriterAdapter(sink) - err = libsq.ExecuteSLQ(h.Context, h.Log, dbases, dbases, srcs, query, recw) + err = libsq.ExecuteSLQ(h.Context, h.Log, qc, query, recw) if err != nil { return nil, err }