mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-24 16:51:34 +03:00
484 lines
12 KiB
Go
484 lines
12 KiB
Go
package libsq
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"github.com/samber/lo"
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
"github.com/neilotoole/sq/libsq/ast"
|
|
"github.com/neilotoole/sq/libsq/ast/render"
|
|
"github.com/neilotoole/sq/libsq/core/errz"
|
|
"github.com/neilotoole/sq/libsq/core/lg"
|
|
"github.com/neilotoole/sq/libsq/core/lg/lga"
|
|
"github.com/neilotoole/sq/libsq/core/lg/lgm"
|
|
"github.com/neilotoole/sq/libsq/core/options"
|
|
"github.com/neilotoole/sq/libsq/core/record"
|
|
"github.com/neilotoole/sq/libsq/core/sqlmodel"
|
|
"github.com/neilotoole/sq/libsq/core/sqlz"
|
|
"github.com/neilotoole/sq/libsq/core/tablefq"
|
|
"github.com/neilotoole/sq/libsq/driver"
|
|
"github.com/neilotoole/sq/libsq/source"
|
|
)
|
|
|
|
// pipeline is used to execute a SLQ query,
|
|
// and write the resulting records to a RecordWriter.
|
|
type pipeline struct {
|
|
// query is the SLQ query
|
|
query string
|
|
|
|
// qc is the context in which the query is executed.
|
|
qc *QueryContext
|
|
|
|
// rc is the Context for rendering SQL.
|
|
// This field is set during pipeline.prepare. It can't be set before
|
|
// then because the target DB to use is calculated during pipeline.prepare,
|
|
// based on the input query and other context.
|
|
rc *render.Context
|
|
|
|
// tasks contains tasks that must be completed before targetSQL
|
|
// is executed against targetPool. Typically tasks is used to
|
|
// set up the joindb before it is queried.
|
|
tasks []tasker
|
|
|
|
// targetSQL is the ultimate SQL query to be executed against targetPool.
|
|
targetSQL string
|
|
|
|
// targetPool is the destination for the ultimate SQL query to
|
|
// be executed against.
|
|
targetPool driver.Pool
|
|
}
|
|
|
|
// newPipeline parses query and builds a pipeline for it. It runs
// pipeline.prepare, so the returned pipeline is ready to be run via
// pipeline.execute.
//
// An error is returned if the query fails to parse, if the query model
// cannot be built from the parsed AST, or if preparation fails.
func newPipeline(ctx context.Context, qc *QueryContext, query string) (*pipeline, error) {
	p := &pipeline{qc: qc, query: query}

	tree, err := ast.Parse(lg.FromContext(ctx), query)
	if err != nil {
		return nil, err
	}

	model, err := buildQueryModel(qc, tree)
	if err != nil {
		return nil, err
	}

	if err = p.prepare(ctx, model); err != nil {
		return nil, err
	}

	return p, nil
}
|
|
|
|
// execute runs the pipeline, writing the result records to recw.
//
// It first completes any preparatory tasks (see executeTasks), then runs
// p.targetSQL against p.targetPool via QuerySQL.
//
// If the QueryContext has PreExecStmts or PostExecStmts, a dedicated
// connection is obtained from the target pool. PreExecStmts are executed on
// it, and the same connection is handed to QuerySQL for the main query.
// PostExecStmts then run on that connection, but only if the main query
// succeeded. Errors other than those from QuerySQL are wrapped with the
// target driver's ErrWrapFunc.
func (p *pipeline) execute(ctx context.Context, recw RecordWriter) error {
	log := lg.FromContext(ctx)
	log.Debug(
		"Execute SQL query",
		lga.Src, p.targetPool.Source(),
		lga.SQL, p.targetSQL,
	)

	errw := p.targetPool.SQLDriver().ErrWrapFunc()

	// TODO: The tasks might like to be executed in parallel. However, a task
	// may do something session/connection-dependent, and the main query
	// below could then run on a different connection. Perhaps tasks need a
	// way to declare that they must share the main query's connection?
	if err := p.executeTasks(ctx); err != nil {
		return errw(err)
	}

	// conn stays nil unless pre/post statements require a dedicated
	// connection; QuerySQL receives it either way.
	var conn sqlz.DB
	needsConn := len(p.qc.PreExecStmts) > 0 || len(p.qc.PostExecStmts) > 0
	if needsConn {
		// We own these resources and must close them; the defers run when
		// execute returns, i.e. after the post-exec statements.
		db, err := p.targetPool.DB(ctx)
		if err != nil {
			return errw(err)
		}
		defer lg.WarnIfCloseError(log, lgm.CloseDB, db)

		var sqlConn *sql.Conn
		if sqlConn, err = db.Conn(ctx); err != nil {
			return errw(err)
		}
		defer lg.WarnIfCloseError(log, lgm.CloseConn, sqlConn)
		conn = sqlConn

		for _, stmt := range p.qc.PreExecStmts {
			if _, err = conn.ExecContext(ctx, stmt); err != nil {
				return errw(err)
			}
		}
	}

	if err := QuerySQL(ctx, p.targetPool, conn, recw, p.targetSQL); err != nil {
		return err
	}

	if conn == nil {
		return nil
	}

	for _, stmt := range p.qc.PostExecStmts {
		if _, err := conn.ExecContext(ctx, stmt); err != nil {
			return errw(err)
		}
	}

	return nil
}
|
|
|
|
// executeTasks runs every task in p.tasks, which hold any preparatory work
// that must be done before p.targetSQL can be executed.
//
// A single task runs directly on ctx. Multiple tasks run concurrently in an
// errgroup whose concurrency limit comes from driver.OptTuningErrgroupLimit.
// The first task error cancels the group's context and is returned.
func (p *pipeline) executeTasks(ctx context.Context) error {
	if len(p.tasks) == 0 {
		return nil
	}
	if len(p.tasks) == 1 {
		return p.tasks[0].executeTask(ctx)
	}

	g, gCtx := errgroup.WithContext(ctx)
	limit := driver.OptTuningErrgroupLimit.Get(options.FromContext(ctx))
	g.SetLimit(limit)

	for i := range p.tasks {
		t := p.tasks[i] // per-iteration copy for the closure below

		g.Go(func() error {
			// Skip starting work if a sibling task already failed or
			// the parent context is done.
			if err := gCtx.Err(); err != nil {
				return err
			}
			return t.executeTask(gCtx)
		})
	}

	return g.Wait()
}
|
|
|
|
// prepareNoTable is invoked when the queryModel doesn't have a table, i.e.
// the query has no "FROM table" clause. It decides which source to run the
// query against and sets p.targetPool and p.rc accordingly:
//
//   - if the query references a handle, that source is used;
//   - otherwise, if there is an active source, it is used;
//   - otherwise, a scratch DB is opened.
func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error {
	log := lg.FromContext(ctx)
	log.Debug("No table in query; will look for source to use...")

	handle := ast.NewInspector(qm.AST).FindFirstHandle()

	var (
		src *source.Source
		err error
	)

	if handle != "" {
		if src, err = p.qc.Collection.Get(handle); err != nil {
			return err
		}
		p.targetPool, err = p.qc.PoolOpener.Open(ctx, src)
	} else if src = p.qc.Collection.Active(); src != nil {
		log.Debug("Using active source.", lga.Src, src)
		p.targetPool, err = p.qc.PoolOpener.Open(ctx, src)
	} else {
		log.Debug("No active source, will use scratchdb.")
		p.targetPool, err = p.qc.ScratchPoolOpener.OpenScratch(ctx, "scratch")
	}
	if err != nil {
		return err
	}

	// Whichever pool was chosen, SQL is rendered with its driver.
	drvr := p.targetPool.SQLDriver()
	p.rc = &render.Context{
		Renderer: drvr.Renderer(),
		Args:     p.qc.Args,
		Dialect:  drvr.Dialect(),
	}

	return nil
}
|
|
|
|
// prepareFromTable builds the "FROM table" fragment.
|
|
//
|
|
// When this function returns, pipeline.rc will be set.
|
|
func (p *pipeline) prepareFromTable(ctx context.Context, tblSel *ast.TblSelectorNode) (fromClause string,
|
|
fromPool driver.Pool, err error,
|
|
) {
|
|
handle := tblSel.Handle()
|
|
if handle == "" {
|
|
handle = p.qc.Collection.ActiveHandle()
|
|
if handle == "" {
|
|
return "", nil, errz.New("query does not specify source, and no active source")
|
|
}
|
|
}
|
|
|
|
src, err := p.qc.Collection.Get(handle)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
fromPool, err = p.qc.PoolOpener.Open(ctx, src)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
rndr := fromPool.SQLDriver().Renderer()
|
|
p.rc = &render.Context{
|
|
Renderer: rndr,
|
|
Args: p.qc.Args,
|
|
Dialect: fromPool.SQLDriver().Dialect(),
|
|
}
|
|
|
|
fromClause, err = rndr.FromTable(p.rc, tblSel)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
return fromClause, fromPool, nil
|
|
}
|
|
|
|
// joinClause models the SQL "JOIN" construct.
|
|
type joinClause struct {
|
|
leftTbl *ast.TblSelectorNode
|
|
joins []*ast.JoinNode
|
|
}
|
|
|
|
// tables returns a new slice containing all referenced tables.
|
|
func (jc *joinClause) tables() []*ast.TblSelectorNode {
|
|
tbls := make([]*ast.TblSelectorNode, len(jc.joins)+1)
|
|
tbls[0] = jc.leftTbl
|
|
for i := range jc.joins {
|
|
tbls[i+1] = jc.joins[i].Table()
|
|
}
|
|
|
|
return tbls
|
|
}
|
|
|
|
// handles returns the set of (non-empty) handles from the tables,
|
|
// without any duplicates.
|
|
func (jc *joinClause) handles() []string {
|
|
handles := make([]string, len(jc.joins)+1)
|
|
handles[0] = jc.leftTbl.Handle()
|
|
for i := 0; i < len(jc.joins); i++ {
|
|
handles[i+1] = jc.joins[i].Table().Handle()
|
|
}
|
|
|
|
handles = lo.Uniq(handles)
|
|
handles = lo.Without(handles, "")
|
|
return handles
|
|
}
|
|
|
|
// isSingleSource returns true if the joins refer to the same handle.
|
|
func (jc *joinClause) isSingleSource() bool {
|
|
leftHandle := jc.leftTbl.Handle()
|
|
|
|
for _, join := range jc.joins {
|
|
joinHandle := join.Table().Handle()
|
|
if joinHandle == "" {
|
|
continue
|
|
}
|
|
|
|
if joinHandle != leftHandle {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// prepareFromJoin builds the "JOIN" clause, delegating to joinCrossSource
// when the join spans multiple sources and to joinSingleSource otherwise.
// It returns the rendered FROM clause and the pool to execute it against.
//
// When this function returns, pipeline.rc will be set.
func (p *pipeline) prepareFromJoin(ctx context.Context, jc *joinClause) (fromClause string,
	fromConn driver.Pool, err error,
) {
	if !jc.isSingleSource() {
		return p.joinCrossSource(ctx, jc)
	}

	return p.joinSingleSource(ctx, jc)
}
|
|
|
|
// joinSingleSource sets up a join whose tables all belong to one source, so
// the join can be rendered and executed directly against that source.
//
// The source is taken from the left table's handle. If the left table has
// no handle, the active source is used, matching prepareFromTable. If there
// is no active source either, an error is returned.
//
// It returns the rendered FROM clause and the pool it targets. Once the pool
// has been opened, pipeline.rc is set.
func (p *pipeline) joinSingleSource(ctx context.Context, jc *joinClause) (fromClause string,
	fromPool driver.Pool, err error,
) {
	handle := jc.leftTbl.Handle()
	if handle == "" {
		// The join doesn't name its source: fall back to the active source
		// rather than looking up an empty handle.
		if handle = p.qc.Collection.ActiveHandle(); handle == "" {
			return "", nil, errz.New("query does not specify source, and no active source")
		}
	}

	src, err := p.qc.Collection.Get(handle)
	if err != nil {
		return "", nil, err
	}

	fromPool, err = p.qc.PoolOpener.Open(ctx, src)
	if err != nil {
		return "", nil, err
	}

	rndr := fromPool.SQLDriver().Renderer()
	p.rc = &render.Context{
		Renderer: rndr,
		Args:     p.qc.Args,
		Dialect:  fromPool.SQLDriver().Dialect(),
	}

	fromClause, err = rndr.Join(p.rc, jc.leftTbl, jc.joins)
	if err != nil {
		return "", nil, err
	}

	return fromClause, fromPool, nil
}
|
|
|
|
// joinCrossSource sets up a join whose tables span more than one source.
//
// The tables live in different databases, so each referenced table is first
// copied into a join DB opened via JoinPoolOpener. The copies are queued on
// p.tasks as joinCopyTasks, which pipeline.execute runs (via executeTasks)
// before the main query. The returned FROM clause forms part of the SQL
// SELECT statement executed against the returned join pool.
//
// A table without an explicit handle is read from the left table's source.
// If the left table itself has no handle, the active source is used. If
// there is no active source either, an error is returned.
//
// Once the join pool has been opened, pipeline.rc is set.
func (p *pipeline) joinCrossSource(ctx context.Context, jc *joinClause) (fromClause string,
	fromDB driver.Pool, err error,
) {
	leftHandle := jc.leftTbl.Handle()
	if leftHandle == "" {
		// Same fallback as prepareFromTable: use the active source.
		if leftHandle = p.qc.Collection.ActiveHandle(); leftHandle == "" {
			return "", nil, errz.New("query does not specify source, and no active source")
		}
	}

	// Resolve each distinct source exactly once. leftHandle is listed first:
	// when the left table names its source this yields the same order as
	// jc.handles(). When it doesn't, this ensures the fallback source (which
	// handle-less tables read from) is included.
	handles := lo.Uniq(append([]string{leftHandle}, jc.handles()...))
	srcs := make([]*source.Source, 0, len(handles))
	srcByHandle := make(map[string]*source.Source, len(handles))
	for _, handle := range handles {
		var src *source.Source
		if src, err = p.qc.Collection.Get(handle); err != nil {
			return "", nil, err
		}
		srcs = append(srcs, src)
		srcByHandle[handle] = src
	}

	// Open the join db.
	joinPool, err := p.qc.JoinPoolOpener.OpenJoin(ctx, srcs...)
	if err != nil {
		return "", nil, err
	}

	rndr := joinPool.SQLDriver().Renderer()
	p.rc = &render.Context{
		Renderer: rndr,
		Args:     p.qc.Args,
		Dialect:  joinPool.SQLDriver().Dialect(),
	}

	for _, tbl := range jc.tables() {
		handle := tbl.Handle()
		if handle == "" {
			handle = leftHandle
		}

		// Every handle here was resolved above: non-empty table handles come
		// from jc.handles(), and empty ones map to leftHandle.
		var tblPool driver.Pool
		if tblPool, err = p.qc.PoolOpener.Open(ctx, srcByHandle[handle]); err != nil {
			return "", nil, err
		}

		task := &joinCopyTask{
			fromPool: tblPool,
			fromTbl:  tbl.Table(),
			toPool:   joinPool,
			toTbl:    tbl.TblAliasOrName(),
		}

		// Keep this after the TblAliasOrName call above. NOTE(review):
		// presumably it rewrites the node's name/alias so the rendered join
		// refers to the copied table — confirm.
		tbl.SyncTblNameAlias()

		p.tasks = append(p.tasks, task)
	}

	if fromClause, err = rndr.Join(p.rc, jc.leftTbl, jc.joins); err != nil {
		return "", nil, err
	}

	return fromClause, joinPool, nil
}
|
|
|
|
// tasker is the interface for executing a DB task.
|
|
type tasker interface {
|
|
// executeTask executes a task against the DB.
|
|
executeTask(ctx context.Context) error
|
|
}
|
|
|
|
// joinCopyTask is a specification of a table data copy task to be performed
|
|
// for a cross-source join. That is, the data in fromDB.fromTblName will
|
|
// be copied to a table in toPool. If colNames is
|
|
// empty, all cols in fromTbl are to be copied.
|
|
type joinCopyTask struct {
|
|
fromPool driver.Pool
|
|
fromTbl tablefq.T
|
|
toPool driver.Pool
|
|
toTbl tablefq.T
|
|
}
|
|
|
|
func (jt *joinCopyTask) executeTask(ctx context.Context) error {
|
|
return execCopyTable(ctx, jt.fromPool, jt.fromTbl, jt.toPool, jt.toTbl)
|
|
}
|
|
|
|
// execCopyTable copies every row of fromTbl in fromDB into destTbl in
// destPool.
//
// The rows are read with "SELECT * FROM fromTbl" and streamed into a
// DBWriter targeting destTbl. The writer is given a create-table hook that
// creates destTbl using the column names and kinds of the origin record
// metadata. On success, the number of rows copied is logged at debug level.
func execCopyTable(ctx context.Context, fromDB driver.Pool, fromTbl tablefq.T,
	destPool driver.Pool, destTbl tablefq.T,
) error {
	log := lg.FromContext(ctx)

	// createDestTbl is the create-table hook handed to NewDBWriter.
	// NOTE(review): presumably invoked once the origin record metadata is
	// known, before rows are inserted — confirm against NewDBWriter.
	createDestTbl := func(ctx context.Context, originRecMeta record.Meta, destPool driver.Pool,
		tx sqlz.DB,
	) error {
		tblDef := sqlmodel.NewTableDef(destTbl.Table, originRecMeta.Names(), originRecMeta.Kinds())
		if err := destPool.SQLDriver().CreateTable(ctx, tx, tblDef); err != nil {
			return errz.Wrapf(err, "failed to create dest table %s.%s", destPool.Source().Handle, destTbl)
		}
		return nil
	}

	recChanSize := driver.OptTuningRecChanSize.Get(destPool.Source().Options)
	inserter := NewDBWriter(destPool, destTbl.Table, recChanSize, createDestTbl)

	enquote := fromDB.SQLDriver().Dialect().Enquote
	query := "SELECT * FROM " + fromTbl.Render(enquote)
	if err := QuerySQL(ctx, fromDB, nil, inserter, query); err != nil {
		return errz.Wrapf(err, "insert %s.%s failed", destPool.Source().Handle, destTbl)
	}

	// Block until the writer has finished processing all records.
	affected, err := inserter.Wait()
	if err != nil {
		return errz.Wrapf(err, "insert %s.%s failed", destPool.Source().Handle, destTbl)
	}

	log.Debug("Copied rows to dest", lga.Count, affected,
		lga.From, fmt.Sprintf("%s.%s", fromDB.Source().Handle, fromTbl),
		lga.To, fmt.Sprintf("%s.%s", destPool.Source().Handle, destTbl))

	return nil
}
|