mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-18 21:52:28 +03:00
99454852f0
- Preliminary work on the (currently hidden) `db` cmds. - Improvements to `--src.schema`
678 lines
20 KiB
Go
678 lines
20 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"log/slog"
|
|
"slices"
|
|
"strings"
|
|
|
|
"github.com/go-sql-driver/mysql"
|
|
"github.com/samber/lo"
|
|
"github.com/xo/dburl"
|
|
|
|
"github.com/neilotoole/sq/libsq/ast"
|
|
"github.com/neilotoole/sq/libsq/ast/render"
|
|
"github.com/neilotoole/sq/libsq/core/errz"
|
|
"github.com/neilotoole/sq/libsq/core/jointype"
|
|
"github.com/neilotoole/sq/libsq/core/kind"
|
|
"github.com/neilotoole/sq/libsq/core/lg"
|
|
"github.com/neilotoole/sq/libsq/core/lg/lga"
|
|
"github.com/neilotoole/sq/libsq/core/lg/lgm"
|
|
"github.com/neilotoole/sq/libsq/core/loz"
|
|
"github.com/neilotoole/sq/libsq/core/options"
|
|
"github.com/neilotoole/sq/libsq/core/record"
|
|
"github.com/neilotoole/sq/libsq/core/retry"
|
|
"github.com/neilotoole/sq/libsq/core/schema"
|
|
"github.com/neilotoole/sq/libsq/core/sqlz"
|
|
"github.com/neilotoole/sq/libsq/core/stringz"
|
|
"github.com/neilotoole/sq/libsq/core/tablefq"
|
|
"github.com/neilotoole/sq/libsq/driver"
|
|
"github.com/neilotoole/sq/libsq/driver/dialect"
|
|
"github.com/neilotoole/sq/libsq/source"
|
|
"github.com/neilotoole/sq/libsq/source/drivertype"
|
|
"github.com/neilotoole/sq/libsq/source/metadata"
|
|
)
|
|
|
|
var _ driver.Provider = (*Provider)(nil)
|
|
|
|
// Provider is the MySQL implementation of driver.Provider.
|
|
type Provider struct {
|
|
Log *slog.Logger
|
|
}
|
|
|
|
// DriverFor implements driver.Provider.
|
|
func (p *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) {
|
|
if typ != drivertype.MySQL {
|
|
return nil, errz.Errorf("unsupported driver type {%s}", typ)
|
|
}
|
|
|
|
return &driveri{log: p.Log}, nil
|
|
}
|
|
|
|
var _ driver.SQLDriver = (*driveri)(nil)
|
|
|
|
// driveri is the MySQL implementation of driver.Driver.
|
|
type driveri struct {
|
|
log *slog.Logger
|
|
}
|
|
|
|
// ConnParams implements driver.SQLDriver.
|
|
// See: https://github.com/go-sql-driver/mysql#dsn-data-source-name.
|
|
func (d *driveri) ConnParams() map[string][]string {
|
|
return map[string][]string{
|
|
"allowAllFiles": {"false", "true"},
|
|
"allowCleartextPasswords": {"false", "true"},
|
|
"allowFallbackToPlaintext": {"false", "true"},
|
|
"allowNativePasswords": {"false", "true"},
|
|
"allowOldPasswords": {"false", "true"},
|
|
"charset": nil,
|
|
"checkConnLiveness": {"true", "false"},
|
|
"clientFoundRows": {"false", "true"},
|
|
"collation": collations,
|
|
"columnsWithAlias": {"false", "true"},
|
|
"connectionAttributes": nil,
|
|
"interpolateParams": {"false", "true"},
|
|
"loc": {"UTC"},
|
|
"maxAllowedPackage": {"0", "67108864"},
|
|
"multiStatements": {"false", "true"},
|
|
"parseTime": {"false", "true"},
|
|
"readTimeout": {"0"},
|
|
"rejectReadOnly": {"false", "true"},
|
|
"timeout": nil,
|
|
"tls": {"false", "true", "skip-verify", "preferred"},
|
|
"writeTimeout": {"0"},
|
|
}
|
|
}
|
|
|
|
// ErrWrapFunc implements driver.SQLDriver.
|
|
func (d *driveri) ErrWrapFunc() func(error) error {
|
|
return errw
|
|
}
|
|
|
|
// DBProperties implements driver.SQLDriver. It is a thin delegate to
// getDBProperties.
func (d *driveri) DBProperties(ctx context.Context, db sqlz.DB) (map[string]any, error) {
	props, err := getDBProperties(ctx, db)
	return props, err
}
|
|
|
|
// DriverMetadata implements driver.Driver.
|
|
func (d *driveri) DriverMetadata() driver.Metadata {
|
|
return driver.Metadata{
|
|
Type: drivertype.MySQL,
|
|
Description: "MySQL",
|
|
Doc: "https://github.com/go-sql-driver/mysql",
|
|
IsSQL: true,
|
|
DefaultPort: 3306,
|
|
}
|
|
}
|
|
|
|
// Dialect implements driver.Driver.
|
|
func (d *driveri) Dialect() dialect.Dialect {
|
|
return dialect.Dialect{
|
|
Type: drivertype.MySQL,
|
|
Placeholders: placeholders,
|
|
Enquote: stringz.BacktickQuote,
|
|
IntBool: true,
|
|
MaxBatchValues: 250,
|
|
Ops: dialect.DefaultOps(),
|
|
Joins: lo.Without(jointype.All(), jointype.FullOuter),
|
|
Catalog: false,
|
|
}
|
|
}
|
|
|
|
// placeholders returns the placeholder fragment for a multi-row insert:
// numRows copies of "(?...?)". Each copy holds numCols "?" joined by
// driver.Comma, and the copies are themselves joined by driver.Comma.
func placeholders(numCols, numRows int) string {
	// Every row group is identical, so build it once instead of per row.
	row := "(" + stringz.RepeatJoin("?", numCols, driver.Comma) + ")"

	rows := make([]string, numRows)
	for i := range rows {
		rows[i] = row
	}
	return strings.Join(rows, driver.Comma)
}
|
|
|
|
// Renderer implements driver.SQLDriver.
|
|
func (d *driveri) Renderer() *render.Renderer {
|
|
r := render.NewDefaultRenderer()
|
|
r.FunctionNames[ast.FuncNameSchema] = "DATABASE"
|
|
r.FunctionOverrides[ast.FuncNameCatalog] = doRenderFuncCatalog
|
|
r.FunctionOverrides[ast.FuncNameRowNum] = renderFuncRowNum
|
|
return r
|
|
}
|
|
|
|
// RecordMeta implements driver.SQLDriver. It derives record metadata from
// colTypes, and returns it along with the matching new-record func from
// getNewRecordFunc.
func (d *driveri) RecordMeta(ctx context.Context, colTypes []*sql.ColumnType) (
	record.Meta, driver.NewRecordFunc, error,
) {
	meta, err := recordMetaFromColumnTypes(ctx, colTypes)
	if err != nil {
		return nil, nil, err
	}
	return meta, getNewRecordFunc(meta), nil
}
|
|
|
|
// CreateSchema implements driver.SQLDriver. It executes CREATE SCHEMA for
// the backtick-quoted schemaName.
func (d *driveri) CreateSchema(ctx context.Context, db sqlz.DB, schemaName string) error {
	_, err := db.ExecContext(ctx, "CREATE SCHEMA "+stringz.BacktickQuote(schemaName))
	return errz.Wrapf(err, "failed to create schema {%s}", schemaName)
}
|
|
|
|
// DropSchema implements driver.SQLDriver. It executes DROP SCHEMA for the
// backtick-quoted schemaName.
func (d *driveri) DropSchema(ctx context.Context, db sqlz.DB, schemaName string) error {
	_, err := db.ExecContext(ctx, "DROP SCHEMA "+stringz.BacktickQuote(schemaName))
	return errz.Wrapf(err, "failed to drop schema {%s}", schemaName)
}
|
|
|
|
// CreateTable implements driver.SQLDriver. It executes the CREATE TABLE
// statement that buildCreateTableStmt produces from tblDef.
func (d *driveri) CreateTable(ctx context.Context, db sqlz.DB, tblDef *schema.Table) error {
	ddl := buildCreateTableStmt(tblDef)
	_, execErr := db.ExecContext(ctx, ddl)
	return errw(execErr)
}
|
|
|
|
// AlterTableAddColumn implements driver.SQLDriver. It adds column col, typed
// as the MySQL type that dbTypeNameFromKind maps knd to, to table tbl.
//
// Both identifiers are quoted with stringz.BacktickQuote, the same func the
// dialect uses as Enquote. Previously they were wrapped in literal backticks.
func (d *driveri) AlterTableAddColumn(ctx context.Context, db sqlz.DB, tbl, col string, knd kind.Kind) error {
	q := "ALTER TABLE " + stringz.BacktickQuote(tbl) +
		" ADD COLUMN " + stringz.BacktickQuote(col) +
		" " + dbTypeNameFromKind(knd)

	if _, err := db.ExecContext(ctx, q); err != nil {
		return errz.Wrapf(errw(err), "alter table: failed to add column {%s} to table {%s}", col, tbl)
	}

	return nil
}
|
|
|
|
// CurrentSchema implements driver.SQLDriver. It returns the connection's
// current database, as reported by SELECT DATABASE().
//
// When no database is selected, DATABASE() returns NULL. Scanning NULL into
// a plain string is an error in database/sql, so this method scans into a
// sql.NullString and returns "" in that case.
func (d *driveri) CurrentSchema(ctx context.Context, db sqlz.DB) (string, error) {
	var name sql.NullString
	if err := db.QueryRowContext(ctx, `SELECT DATABASE()`).Scan(&name); err != nil {
		return "", errw(err)
	}

	return name.String, nil
}
|
|
|
|
// ListSchemas implements driver.SQLDriver. It returns the names of all
// databases (MySQL's schemas) reported by SHOW DATABASES, sorted
// ascending. Errors are wrapped with errw, like the driver's other
// DB-originated errors.
func (d *driveri) ListSchemas(ctx context.Context, db sqlz.DB) ([]string, error) {
	log := lg.FromContext(ctx)

	const q = `SHOW DATABASES`
	rows, err := db.QueryContext(ctx, q)
	if err != nil {
		return nil, errw(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var schemas []string
	for rows.Next() {
		var schma string
		if err = rows.Scan(&schma); err != nil {
			return nil, errw(err)
		}
		schemas = append(schemas, schma)
	}

	if err = rows.Err(); err != nil {
		return nil, errw(err)
	}

	// Sort explicitly so the output order doesn't depend on the server.
	slices.Sort(schemas)
	return schemas, nil
}
|
|
|
|
// SchemaExists implements driver.SQLDriver. It reports whether a schema
// (database) named schma appears in INFORMATION_SCHEMA.SCHEMATA. An empty
// schma always yields false, without querying.
func (d *driveri) SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) {
	if schma == "" {
		return false, nil
	}

	const q = "SELECT COUNT(SCHEMA_NAME) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?"
	var count int
	// Scan in its own statement. In "return count > 0, f(&count)", the Go
	// spec leaves the order in which count is read, relative to the call,
	// unspecified, so count could be read before Scan writes it.
	if err := db.QueryRowContext(ctx, q, schma).Scan(&count); err != nil {
		return false, errw(err)
	}

	return count > 0, nil
}
|
|
|
|
// ListSchemaMetadata implements driver.SQLDriver. It returns one
// metadata.Schema per row of INFORMATION_SCHEMA.SCHEMATA, ordered by schema
// name. The query's third column is a constant empty string, so Owner is
// always "".
func (d *driveri) ListSchemaMetadata(ctx context.Context, db sqlz.DB) ([]*metadata.Schema, error) {
	log := lg.FromContext(ctx)

	const q = `SELECT SCHEMA_NAME, CATALOG_NAME, '' FROM INFORMATION_SCHEMA.SCHEMATA
ORDER BY SCHEMA_NAME`

	rows, err := db.QueryContext(ctx, q)
	if err != nil {
		return nil, errw(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var schemas []*metadata.Schema
	for rows.Next() {
		var (
			name           string
			catalog, owner sql.NullString
		)
		if err = rows.Scan(&name, &catalog, &owner); err != nil {
			return nil, errw(err)
		}

		schemas = append(schemas, &metadata.Schema{
			Name:    name,
			Catalog: catalog.String,
			Owner:   owner.String,
		})
	}

	if err = rows.Err(); err != nil {
		return nil, errw(err)
	}
	return schemas, nil
}
|
|
|
|
// CurrentCatalog implements driver.SQLDriver. MySQL doesn't really have
// catalogs. This method returns the CATALOG_NAME that
// INFORMATION_SCHEMA.SCHEMATA lists for the current database, typically
// "def".
func (d *driveri) CurrentCatalog(ctx context.Context, db sqlz.DB) (string, error) {
	row := db.QueryRowContext(ctx, selectCatalog)

	var catalog string
	if err := row.Scan(&catalog); err != nil {
		return "", errw(err)
	}
	return catalog, nil
}
|
|
|
|
// ListTableNames implements driver.SQLDriver. It returns the names of base
// tables and/or views, as selected by the tables and views args, ordered
// by name. They are taken from schema schma, or from the current database
// when schma is empty. If neither tables nor views is requested, it returns
// an empty non-nil slice without querying.
func (d *driveri) ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) {
	if !tables && !views {
		return []string{}, nil
	}

	typeFilter := " AND (TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW')"
	if !views {
		typeFilter = " AND TABLE_TYPE = 'BASE TABLE'"
	} else if !tables {
		typeFilter = " AND TABLE_TYPE = 'VIEW'"
	}

	// An empty schma means "the current database".
	schemaExpr := "DATABASE()"
	var args []any
	if schma != "" {
		schemaExpr = "?"
		args = []any{schma}
	}

	q := "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = " +
		schemaExpr + typeFilter + " ORDER BY TABLE_NAME"

	rows, err := db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, errw(err)
	}

	names, err := sqlz.RowsScanColumn[string](ctx, rows)
	if err != nil {
		return nil, errw(err)
	}
	return names, nil
}
|
|
|
|
// ListCatalogs implements driver.SQLDriver. MySQL does not really support catalogs,
|
|
// but this method simply delegates to CurrentCatalog, which returns the value
|
|
// found in INFORMATION_SCHEMA.SCHEMATA, i.e. "def".
|
|
func (d *driveri) ListCatalogs(ctx context.Context, db sqlz.DB) ([]string, error) {
|
|
catalog, err := d.CurrentCatalog(ctx, db)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return []string{catalog}, nil
|
|
}
|
|
|
|
// CatalogExists implements driver.SQLDriver. It returns true if catalog is "def",
|
|
// and false otherwise, nothing that MySQL doesn't really support catalogs.
|
|
func (d *driveri) CatalogExists(_ context.Context, _ sqlz.DB, catalog string) (bool, error) {
|
|
return catalog == "def", nil
|
|
}
|
|
|
|
// AlterTableRename implements driver.SQLDriver. It renames table tbl to
// newName using RENAME TABLE.
//
// Both names are quoted with stringz.BacktickQuote, the dialect's Enquote
// func, rather than being interpolated between literal backticks.
func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error {
	q := "RENAME TABLE " + stringz.BacktickQuote(tbl) + " TO " + stringz.BacktickQuote(newName)
	_, err := db.ExecContext(ctx, q)
	return errz.Wrapf(errw(err), "alter table: failed to rename table {%s} to {%s}", tbl, newName)
}
|
|
|
|
// AlterTableRenameColumn implements driver.SQLDriver. It renames column col
// of table tbl to newName.
//
// All identifiers are quoted with stringz.BacktickQuote, the dialect's
// Enquote func, rather than being interpolated between literal backticks.
func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, col, newName string) error {
	q := "ALTER TABLE " + stringz.BacktickQuote(tbl) +
		" RENAME COLUMN " + stringz.BacktickQuote(col) +
		" TO " + stringz.BacktickQuote(newName)
	_, err := db.ExecContext(ctx, q)
	return errz.Wrapf(errw(err), "alter table: failed to rename column {%s.%s} to {%s}", tbl, col, newName)
}
|
|
|
|
// PrepareInsertStmt implements driver.SQLDriver. It looks up the record
// metadata for destColNames in destTbl, prepares an insert of numRows rows
// via driver.PrepareInsertStmt, and wraps the statement in an execer.
func (d *driveri) PrepareInsertStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string,
	numRows int,
) (*driver.StmtExecer, error) {
	colsMeta, err := d.getTableRecordMeta(ctx, db, destTbl, destColNames)
	if err != nil {
		return nil, err
	}

	stmt, err := driver.PrepareInsertStmt(ctx, d, db, destTbl, colsMeta.Names(), numRows)
	if err != nil {
		return nil, err
	}

	mungeFn := newInsertMungeFunc(destTbl, colsMeta)
	execFn := newStmtExecFunc(stmt)
	return driver.NewStmtExecer(stmt, mungeFn, execFn, colsMeta), nil
}
|
|
|
|
// PrepareUpdateStmt implements driver.SQLDriver. It prepares the UPDATE
// that buildUpdateStmt produces for destColNames in destTbl, with the where
// arg (presumably a WHERE-clause fragment; see buildUpdateStmt). The
// statement is wrapped in an execer whose munge func is built from the
// table's record metadata for those columns.
func (d *driveri) PrepareUpdateStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string,
	where string,
) (*driver.StmtExecer, error) {
	destColsMeta, err := d.getTableRecordMeta(ctx, db, destTbl, destColNames)
	if err != nil {
		return nil, err
	}

	query, err := buildUpdateStmt(destTbl, destColNames, where)
	if err != nil {
		return nil, err
	}

	stmt, err := db.PrepareContext(ctx, query)
	if err != nil {
		// Wrap with errw, like the driver's other DB-originated errors.
		return nil, errw(err)
	}

	execer := driver.NewStmtExecer(stmt, newInsertMungeFunc(destTbl, destColsMeta), newStmtExecFunc(stmt), destColsMeta)
	return execer, nil
}
|
|
|
|
// newStmtExecFunc returns a driver.StmtExecFunc that executes stmt with the
// given args and reports the number of rows affected. Errors are passed
// through errw.
func newStmtExecFunc(stmt *sql.Stmt) driver.StmtExecFunc {
	exec := func(ctx context.Context, args ...any) (int64, error) {
		result, execErr := stmt.ExecContext(ctx, args...)
		if execErr != nil {
			return 0, errw(execErr)
		}

		n, rowsErr := result.RowsAffected()
		return n, errw(rowsErr)
	}
	return exec
}
|
|
|
|
// CopyTable implements driver.SQLDriver.
|
|
func (d *driveri) CopyTable(ctx context.Context, db sqlz.DB,
|
|
fromTable, toTable tablefq.T, copyData bool,
|
|
) (int64, error) {
|
|
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s SELECT * FROM %s",
|
|
tblfmt(toTable), tblfmt(fromTable))
|
|
|
|
if !copyData {
|
|
stmt += " WHERE 0"
|
|
}
|
|
|
|
affected, err := sqlz.ExecAffected(ctx, db, stmt)
|
|
if err != nil {
|
|
return 0, errw(err)
|
|
}
|
|
|
|
return affected, nil
|
|
}
|
|
|
|
// TableExists implements driver.SQLDriver. It reports whether a table or
// view named tbl exists in the connection's current database.
// INFORMATION_SCHEMA.TABLES lists both base tables and views.
//
// The lookup is limited to TABLE_SCHEMA = DATABASE(). Without that filter,
// same-named tables in other databases on the server are counted too. That
// produces a false positive when the table exists only elsewhere, and a
// false negative when it exists both here and elsewhere (count of 2).
func (d *driveri) TableExists(ctx context.Context, db sqlz.DB, tbl string) (bool, error) {
	const query = `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ?`

	var count int64
	if err := db.QueryRowContext(ctx, query, tbl).Scan(&count); err != nil {
		return false, errw(err)
	}

	// (TABLE_SCHEMA, TABLE_NAME) is unique, so count is 0 or 1. Using "> 0"
	// keeps the check robust anyway.
	return count > 0, nil
}
|
|
|
|
// DropTable implements driver.SQLDriver. It executes DROP TABLE ... RESTRICT
// on tbl, adding IF EXISTS when ifExists is true.
func (d *driveri) DropTable(ctx context.Context, db sqlz.DB, tbl tablefq.T, ifExists bool) error {
	var existsClause string
	if ifExists {
		existsClause = "IF EXISTS "
	}

	stmt := "DROP TABLE " + existsClause + tblfmt(tbl) + " RESTRICT"
	_, err := db.ExecContext(ctx, stmt)
	return errw(err)
}
|
|
|
|
// TableColumnTypes implements driver.SQLDriver.
|
|
func (d *driveri) TableColumnTypes(ctx context.Context, db sqlz.DB, tblName string,
|
|
colNames []string,
|
|
) ([]*sql.ColumnType, error) {
|
|
const queryTpl = "SELECT %s FROM %s LIMIT 0"
|
|
|
|
enquote := d.Dialect().Enquote
|
|
tblNameQuoted := enquote(tblName)
|
|
|
|
colsClause := "*"
|
|
if len(colNames) > 0 {
|
|
colNamesQuoted := loz.Apply(colNames, enquote)
|
|
colsClause = strings.Join(colNamesQuoted, driver.Comma)
|
|
}
|
|
|
|
query := fmt.Sprintf(queryTpl, colsClause, tblNameQuoted)
|
|
rows, err := db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
colTypes, err := rows.ColumnTypes()
|
|
if err != nil {
|
|
lg.WarnIfFuncError(d.log, lgm.CloseDBRows, rows.Close)
|
|
return nil, errw(err)
|
|
}
|
|
|
|
err = rows.Err()
|
|
if err != nil {
|
|
lg.WarnIfFuncError(d.log, lgm.CloseDBRows, rows.Close)
|
|
return nil, errw(err)
|
|
}
|
|
|
|
err = rows.Close()
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
return colTypes, nil
|
|
}
|
|
|
|
// getTableRecordMeta returns the record metadata for colNames in tblName, or
// for all columns when colNames is empty. It combines TableColumnTypes with
// RecordMeta and discards the new-record func.
func (d *driveri) getTableRecordMeta(ctx context.Context, db sqlz.DB, tblName string, colNames []string) (
	record.Meta, error,
) {
	colTypes, err := d.TableColumnTypes(ctx, db, tblName, colNames)
	if err != nil {
		return nil, err
	}

	meta, _, err := d.RecordMeta(ctx, colTypes)
	if err != nil {
		return nil, err
	}
	return meta, nil
}
|
|
|
|
// Open implements driver.Driver. It opens a *sql.DB for src, checks it with
// driver.OpeningPing, and returns it wrapped in a grip.
//
// If the opening ping fails, db is closed before returning, so the
// connection pool is not leaked. sql.DB.Close is idempotent, so this is safe
// even if OpeningPing has already closed db.
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) {
	lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src)

	db, err := d.doOpen(ctx, src)
	if err != nil {
		return nil, err
	}

	if err = driver.OpeningPing(ctx, src, db); err != nil {
		lg.WarnIfCloseError(d.log, lgm.CloseDB, db)
		return nil, err
	}

	return &grip{log: d.log, db: db, src: src, drvr: d}, nil
}
|
|
|
|
// doOpen opens a *sql.DB for src without pinging it. The steps are:
//   - build the driver DSN from src.Location (with parseTime enabled)
//   - set the connect timeout from driver.OptConnOpenTimeout
//   - if src.Schema is set, make it the connection's default database
//   - pass the resulting *sql.DB to driver.ConfigureDB with src.Options
func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, error) {
	dsn, err := dsnFromLocation(src, true)
	if err != nil {
		return nil, err
	}

	cfg, err := mysql.ParseDSN(dsn)
	if err != nil {
		return nil, errw(err)
	}

	// NOTE(review): this assignment always replaces any "timeout" param that
	// was present in the source location's DSN.
	cfg.Timeout = driver.OptConnOpenTimeout.Get(src.Options)
	// REVISIT: Perhaps allow setting cfg.ReadTimeout and cfg.WriteTimeout?
	// - https://github.com/go-sql-driver/mysql#writetimeout
	// - https://github.com/go-sql-driver/mysql#readtimeout

	if src.Schema != "" {
		lg.FromContext(ctx).Debug("Setting default schema for MySQL connection",
			lga.Src, src,
			lga.Schema, src.Schema,
		)
		cfg.DBName = src.Schema
	}

	connector, err := mysql.NewConnector(cfg)
	if err != nil {
		return nil, errw(err)
	}

	db := sql.OpenDB(connector)
	driver.ConfigureDB(ctx, db, src.Options)
	return db, nil
}
|
|
|
|
// ValidateSource implements driver.Driver.
|
|
func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) {
|
|
if src.Type != drivertype.MySQL {
|
|
return nil, errz.Errorf("expected driver type {%s} but got {%s}", drivertype.MySQL, src.Type)
|
|
}
|
|
return src, nil
|
|
}
|
|
|
|
// Ping implements driver.Driver. It opens a fresh *sql.DB for src, pings it,
// and closes it again; a close error is only logged.
func (d *driveri) Ping(ctx context.Context, src *source.Source) error {
	db, err := d.doOpen(ctx, src)
	if err != nil {
		return err
	}
	defer lg.WarnIfCloseError(d.log, lgm.CloseDB, db)

	pingErr := db.PingContext(ctx)
	return errz.Wrapf(errw(pingErr), "ping %s", src.Handle)
}
|
|
|
|
// Truncate implements driver.SQLDriver. It removes every row from table tbl
// in src and returns the number of rows removed. Arg reset is always
// ignored, because the TRUNCATE statement always resets the identity value.
//
// The table name is quoted with stringz.BacktickQuote, the dialect's Enquote
// func, rather than being interpolated between literal backticks.
//
// See: https://dev.mysql.com/doc/refman/8.0/en/truncate-table.html
func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string, _ bool) (affected int64,
	err error,
) {
	db, err := d.doOpen(ctx, src)
	if err != nil {
		return 0, errw(err)
	}
	defer lg.WarnIfFuncError(d.log, lgm.CloseDB, db.Close)

	tblQuoted := stringz.BacktickQuote(tbl)

	// Not sure about the Tx requirements?
	// NOTE(review): the MySQL docs list TRUNCATE TABLE as causing an implicit
	// commit, so this tx may give little isolation. Confirm.
	tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return 0, errw(err)
	}

	// For whatever reason, the "affected" count from TRUNCATE always comes
	// back as zero, so we synthesize it from a count taken beforehand.
	var beforeCount int64
	err = tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tblQuoted).Scan(&beforeCount)
	if err != nil {
		return 0, errz.Append(err, errw(tx.Rollback()))
	}

	affected, err = sqlz.ExecAffected(ctx, tx, "TRUNCATE TABLE "+tblQuoted)
	if err != nil {
		return affected, errz.Append(err, errw(tx.Rollback()))
	}

	if affected != 0 {
		// Note: At the time of writing, this doesn't happen: zero is always
		// returned (which we don't like). If this changes (driver changes?)
		// then we'll revisit.
		d.log.Warn("Unexpectedly got non-zero rows affected from TRUNCATE", lga.Count, affected)
		return affected, errw(tx.Commit())
	}

	// TRUNCATE succeeded, so tbl is now empty, and the number of truncated
	// rows is taken to be beforeCount.
	return beforeCount, errw(tx.Commit())
}
|
|
|
|
// dsnFromLocation builds the go-sql-driver/mysql DSN from src.Location.
// For example, the location
//
//	mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowOldPasswords=1
//
// becomes the driver DSN
//
//	sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=1
//
// The DSN's parseTime param is set to arg parseTime. Callers typically pass
// true, because of:
// https://stackoverflow.com/questions/29341590/how-to-parse-time-from-database/29343013#29343013
//
// Error messages include only src.RedactedLocation(), never the raw location
// or DSN, because both contain the password.
func dsnFromLocation(src *source.Source, parseTime bool) (string, error) {
	if !strings.HasPrefix(src.Location, "mysql://") || len(src.Location) < 10 {
		return "", errz.Errorf("invalid source location %s", src.RedactedLocation())
	}

	u, err := dburl.Parse(src.Location)
	if err != nil {
		return "", errz.Wrapf(errw(err), "invalid source location %s", src.RedactedLocation())
	}

	// Round-trip dburl's DSN through the mysql driver's parser, both to
	// validate it and to set ParseTime.
	myCfg, err := mysql.ParseDSN(u.DSN)
	if err != nil {
		return "", errz.Wrapf(errw(err), "invalid source location %s", src.RedactedLocation())
	}

	myCfg.ParseTime = parseTime
	return myCfg.FormatDSN(), nil
}
|
|
|
|
// doRetry runs fn via retry.Do, using isErrTooManyConnections as the retry
// predicate. The max retry interval comes from the driver.OptMaxRetryInterval
// option in ctx's options.
func doRetry(ctx context.Context, fn func() error) error {
	opts := options.FromContext(ctx)
	maxInterval := driver.OptMaxRetryInterval.Get(opts)
	return retry.Do(ctx, maxInterval, fn, isErrTooManyConnections)
}
|
|
|
|
// tblfmt formats a table name for use in a query. The arg can be a string,
|
|
// or a tablefq.T.
|
|
func tblfmt[T string | tablefq.T](tbl T) string {
|
|
tfq := tablefq.From(tbl)
|
|
return tfq.Render(stringz.BacktickQuote)
|
|
}
|
|
|
|
const selectCatalog = `SELECT CATALOG_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = DATABASE() LIMIT 1`
|
|
|
|
func doRenderFuncCatalog(_ *render.Context, fn *ast.FuncNode) (string, error) {
|
|
if fn.FuncName() != ast.FuncNameCatalog {
|
|
// Shouldn't happen
|
|
return "", errz.Errorf("expected %s function, got %q", ast.FuncNameCatalog, fn.FuncName())
|
|
}
|
|
|
|
const frag = `(` + selectCatalog + `)`
|
|
return frag, nil
|
|
}
|