package libsq import ( "context" "database/sql" "fmt" "github.com/samber/lo" "golang.org/x/sync/errgroup" "github.com/neilotoole/sq/libsq/ast" "github.com/neilotoole/sq/libsq/ast/render" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlmodel" "github.com/neilotoole/sq/libsq/core/sqlz" "github.com/neilotoole/sq/libsq/core/tablefq" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" ) // pipeline is used to execute a SLQ query, // and write the resulting records to a RecordWriter. type pipeline struct { // query is the SLQ query query string // qc is the context in which the query is executed. qc *QueryContext // rc is the Context for rendering SQL. // This field is set during pipeline.prepare. It can't be set before // then because the target DB to use is calculated during pipeline.prepare, // based on the input query and other context. rc *render.Context // tasks contains tasks that must be completed before targetSQL // is executed against targetPool. Typically tasks is used to // set up the joindb before it is queried. tasks []tasker // targetSQL is the ultimate SQL query to be executed against targetPool. targetSQL string // targetPool is the destination for the ultimate SQL query to // be executed against. targetPool driver.Pool } // newPipeline parses query, returning a pipeline prepared for // execution via pipeline.execute. func newPipeline(ctx context.Context, qc *QueryContext, query string) (*pipeline, error) { log := lg.FromContext(ctx) a, err := ast.Parse(log, query) if err != nil { return nil, err } qModel, err := buildQueryModel(qc, a) if err != nil { return nil, err } p := &pipeline{ qc: qc, query: query, } if err = p.prepare(ctx, qModel); err != nil { return nil, err } return p, nil } // execute executes the pipeline, writing results to recw. func (p *pipeline) execute(ctx context.Context, recw RecordWriter) error { log := lg.FromContext(ctx) log.Debug( "Execute SQL query", lga.Src, p.targetPool.Source(), lga.SQL, p.targetSQL, ) errw := p.targetPool.SQLDriver().ErrWrapFunc() // TODO: The tasks might like to be executed in parallel. However, // what happens if a task does something that is session/connection-dependent? // When the query executes later (below), it could be on a different // connection. Maybe the tasks need a means of declaring that they // must be executed on the same connection as the main query? if err := p.executeTasks(ctx); err != nil { return errw(err) } var conn sqlz.DB if len(p.qc.PreExecStmts) > 0 || len(p.qc.PostExecStmts) > 0 { // If there's pre/post exec work to do, we need to // obtain a connection from the pool. We are responsible // for closing these resources. db, err := p.targetPool.DB(ctx) if err != nil { return errw(err) } defer lg.WarnIfCloseError(log, lgm.CloseDB, db) if conn, err = db.Conn(ctx); err != nil { return errw(err) } defer lg.WarnIfCloseError(log, lgm.CloseConn, conn.(*sql.Conn)) for _, stmt := range p.qc.PreExecStmts { if _, err = conn.ExecContext(ctx, stmt); err != nil { return errw(err) } } } if err := QuerySQL(ctx, p.targetPool, conn, recw, p.targetSQL); err != nil { return err } if conn != nil && len(p.qc.PostExecStmts) > 0 { for _, stmt := range p.qc.PostExecStmts { if _, err := conn.ExecContext(ctx, stmt); err != nil { return errw(err) } } } return nil } // executeTasks executes any tasks in pipeline.tasks. // These tasks may exist if preparatory work must be performed // before pipeline.targetSQL can be executed. func (p *pipeline) executeTasks(ctx context.Context) error { switch len(p.tasks) { case 0: return nil case 1: return p.tasks[0].executeTask(ctx) default: } g, gCtx := errgroup.WithContext(ctx) g.SetLimit(driver.OptTuningErrgroupLimit.Get(options.FromContext(ctx))) for _, task := range p.tasks { task := task g.Go(func() error { select { case <-gCtx.Done(): return gCtx.Err() default: } return task.executeTask(gCtx) }) } return g.Wait() } // prepareNoTable is invoked when the queryModel doesn't have a table. // That is to say, the query doesn't have a "FROM table" clause. It is // this function's responsibility to figure out what source to use, and // to set the relevant pipeline fields. func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { log := lg.FromContext(ctx) log.Debug("No table in query; will look for source to use...") var ( src *source.Source err error handle = ast.NewInspector(qm.AST).FindFirstHandle() ) if handle == "" { if src = p.qc.Collection.Active(); src == nil { log.Debug("No active source, will use scratchdb.") p.targetPool, err = p.qc.ScratchPoolOpener.OpenScratch(ctx, "scratch") if err != nil { return err } p.rc = &render.Context{ Renderer: p.targetPool.SQLDriver().Renderer(), Args: p.qc.Args, Dialect: p.targetPool.SQLDriver().Dialect(), } return nil } log.Debug("Using active source.", lga.Src, src) } else if src, err = p.qc.Collection.Get(handle); err != nil { return err } // At this point, src is non-nil. if p.targetPool, err = p.qc.PoolOpener.Open(ctx, src); err != nil { return err } p.rc = &render.Context{ Renderer: p.targetPool.SQLDriver().Renderer(), Args: p.qc.Args, Dialect: p.targetPool.SQLDriver().Dialect(), } return nil } // prepareFromTable builds the "FROM table" fragment. // // When this function returns, pipeline.rc will be set. func (p *pipeline) prepareFromTable(ctx context.Context, tblSel *ast.TblSelectorNode) (fromClause string, fromPool driver.Pool, err error, ) { handle := tblSel.Handle() if handle == "" { handle = p.qc.Collection.ActiveHandle() if handle == "" { return "", nil, errz.New("query does not specify source, and no active source") } } src, err := p.qc.Collection.Get(handle) if err != nil { return "", nil, err } fromPool, err = p.qc.PoolOpener.Open(ctx, src) if err != nil { return "", nil, err } rndr := fromPool.SQLDriver().Renderer() p.rc = &render.Context{ Renderer: rndr, Args: p.qc.Args, Dialect: fromPool.SQLDriver().Dialect(), } fromClause, err = rndr.FromTable(p.rc, tblSel) if err != nil { return "", nil, err } return fromClause, fromPool, nil } // joinClause models the SQL "JOIN" construct. type joinClause struct { leftTbl *ast.TblSelectorNode joins []*ast.JoinNode } // tables returns a new slice containing all referenced tables. func (jc *joinClause) tables() []*ast.TblSelectorNode { tbls := make([]*ast.TblSelectorNode, len(jc.joins)+1) tbls[0] = jc.leftTbl for i := range jc.joins { tbls[i+1] = jc.joins[i].Table() } return tbls } // handles returns the set of (non-empty) handles from the tables, // without any duplicates. func (jc *joinClause) handles() []string { handles := make([]string, len(jc.joins)+1) handles[0] = jc.leftTbl.Handle() for i := 0; i < len(jc.joins); i++ { handles[i+1] = jc.joins[i].Table().Handle() } handles = lo.Uniq(handles) handles = lo.Without(handles, "") return handles } // isSingleSource returns true if the joins refer to the same handle. func (jc *joinClause) isSingleSource() bool { leftHandle := jc.leftTbl.Handle() for _, join := range jc.joins { joinHandle := join.Table().Handle() if joinHandle == "" { continue } if joinHandle != leftHandle { return false } } return true } // prepareFromJoin builds the "JOIN" clause. // // When this function returns, pipeline.rc will be set. func (p *pipeline) prepareFromJoin(ctx context.Context, jc *joinClause) (fromClause string, fromConn driver.Pool, err error, ) { if jc.isSingleSource() { return p.joinSingleSource(ctx, jc) } return p.joinCrossSource(ctx, jc) } // joinSingleSource sets up a join against a single source. // // On return, pipeline.rc will be set. func (p *pipeline) joinSingleSource(ctx context.Context, jc *joinClause) (fromClause string, fromPool driver.Pool, err error, ) { src, err := p.qc.Collection.Get(jc.leftTbl.Handle()) if err != nil { return "", nil, err } fromPool, err = p.qc.PoolOpener.Open(ctx, src) if err != nil { return "", nil, err } rndr := fromPool.SQLDriver().Renderer() p.rc = &render.Context{ Renderer: rndr, Args: p.qc.Args, Dialect: fromPool.SQLDriver().Dialect(), } fromClause, err = rndr.Join(p.rc, jc.leftTbl, jc.joins) if err != nil { return "", nil, err } return fromClause, fromPool, nil } // joinCrossSource returns a FROM clause that forms part of // the SQL SELECT statement against fromDB. // // On return, pipeline.rc will be set. func (p *pipeline) joinCrossSource(ctx context.Context, jc *joinClause) (fromClause string, fromDB driver.Pool, err error, ) { handles := jc.handles() srcs := make([]*source.Source, 0, len(handles)) for _, handle := range handles { var src *source.Source if src, err = p.qc.Collection.Get(handle); err != nil { return "", nil, err } srcs = append(srcs, src) } // Open the join db joinPool, err := p.qc.JoinPoolOpener.OpenJoin(ctx, srcs...) if err != nil { return "", nil, err } rndr := joinPool.SQLDriver().Renderer() p.rc = &render.Context{ Renderer: rndr, Args: p.qc.Args, Dialect: joinPool.SQLDriver().Dialect(), } leftHandle := jc.leftTbl.Handle() // TODO: verify not empty tbls := jc.tables() for _, tbl := range tbls { tbl := tbl handle := tbl.Handle() if handle == "" { handle = leftHandle } var src *source.Source if src, err = p.qc.Collection.Get(handle); err != nil { return "", nil, err } var db driver.Pool if db, err = p.qc.PoolOpener.Open(ctx, src); err != nil { return "", nil, err } task := &joinCopyTask{ fromPool: db, fromTbl: tbl.Table(), toPool: joinPool, toTbl: tbl.TblAliasOrName(), } tbl.SyncTblNameAlias() p.tasks = append(p.tasks, task) } fromClause, err = rndr.Join(p.rc, jc.leftTbl, jc.joins) if err != nil { return "", nil, err } return fromClause, joinPool, nil } // tasker is the interface for executing a DB task. type tasker interface { // executeTask executes a task against the DB. executeTask(ctx context.Context) error } // joinCopyTask is a specification of a table data copy task to be performed // for a cross-source join. That is, the data in fromDB.fromTblName will // be copied to a table in toPool. If colNames is // empty, all cols in fromTbl are to be copied. type joinCopyTask struct { fromPool driver.Pool fromTbl tablefq.T toPool driver.Pool toTbl tablefq.T } func (jt *joinCopyTask) executeTask(ctx context.Context) error { return execCopyTable(ctx, jt.fromPool, jt.fromTbl, jt.toPool, jt.toTbl) } // execCopyTable performs the work of copying fromDB.fromTbl to destPool.destTbl. func execCopyTable(ctx context.Context, fromDB driver.Pool, fromTbl tablefq.T, destPool driver.Pool, destTbl tablefq.T, ) error { log := lg.FromContext(ctx) createTblHook := func(ctx context.Context, originRecMeta record.Meta, destPool driver.Pool, tx sqlz.DB, ) error { destColNames := originRecMeta.Names() destColKinds := originRecMeta.Kinds() destTblDef := sqlmodel.NewTableDef(destTbl.Table, destColNames, destColKinds) err := destPool.SQLDriver().CreateTable(ctx, tx, destTblDef) if err != nil { return errz.Wrapf(err, "failed to create dest table %s.%s", destPool.Source().Handle, destTbl) } return nil } inserter := NewDBWriter( destPool, destTbl.Table, driver.OptTuningRecChanSize.Get(destPool.Source().Options), createTblHook, ) query := "SELECT * FROM " + fromTbl.Render(fromDB.SQLDriver().Dialect().Enquote) err := QuerySQL(ctx, fromDB, nil, inserter, query) if err != nil { return errz.Wrapf(err, "insert %s.%s failed", destPool.Source().Handle, destTbl) } affected, err := inserter.Wait() // Wait for the writer to finish processing if err != nil { return errz.Wrapf(err, "insert %s.%s failed", destPool.Source().Handle, destTbl) } log.Debug("Copied rows to dest", lga.Count, affected, lga.From, fmt.Sprintf("%s.%s", fromDB.Source().Handle, fromTbl), lga.To, fmt.Sprintf("%s.%s", destPool.Source().Handle, destTbl)) return nil }