mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-18 21:52:28 +03:00
1ceb50e795
* initial refactoring for the numRows param * work on driver.NewBatchInsert * work on NewBatchInsert * batch insert seems to work * switched testh.Insert to use BatchInsert * doc cleanup * batch insert for dbwriter and csv * removed unneeded NumRows from driver.StmtExecer * minor tidyup
351 lines
9.6 KiB
Go
351 lines
9.6 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/go-sql-driver/mysql"
|
|
"github.com/neilotoole/lg"
|
|
"github.com/xo/dburl"
|
|
|
|
"github.com/neilotoole/sq/libsq/driver"
|
|
"github.com/neilotoole/sq/libsq/errz"
|
|
"github.com/neilotoole/sq/libsq/source"
|
|
"github.com/neilotoole/sq/libsq/sqlbuilder"
|
|
"github.com/neilotoole/sq/libsq/sqlmodel"
|
|
"github.com/neilotoole/sq/libsq/sqlz"
|
|
"github.com/neilotoole/sq/libsq/stringz"
|
|
)
|
|
|
|
const (
|
|
// Type is the MySQL source driver type.
|
|
Type = source.Type("mysql")
|
|
|
|
// dbDrvr is the backing MySQL SQL driver impl name.
|
|
dbDrvr = "mysql"
|
|
)
|
|
|
|
// Provider is the MySQL implementation of driver.Provider.
|
|
type Provider struct {
|
|
Log lg.Log
|
|
}
|
|
|
|
// DriverFor implements driver.Provider.
|
|
func (p *Provider) DriverFor(typ source.Type) (driver.Driver, error) {
|
|
if typ != Type {
|
|
return nil, errz.Errorf("unsupported driver type %q", typ)
|
|
}
|
|
|
|
return &Driver{log: p.Log}, nil
|
|
}
|
|
|
|
// Driver is the MySQL implementation of driver.Driver.
|
|
type Driver struct {
|
|
log lg.Log
|
|
}
|
|
|
|
// DriverMetadata implements driver.Driver.
|
|
func (d *Driver) DriverMetadata() driver.Metadata {
|
|
return driver.Metadata{
|
|
Type: Type,
|
|
Description: "MySQL",
|
|
Doc: "https://github.com/go-sql-driver/mysql",
|
|
IsSQL: true,
|
|
}
|
|
}
|
|
|
|
// Dialect implements driver.Driver.
|
|
func (d *Driver) Dialect() driver.Dialect {
|
|
return driver.Dialect{
|
|
Type: Type,
|
|
Placeholders: placeholders,
|
|
Quote: '`',
|
|
IntBool: true,
|
|
MaxBatchValues: 250,
|
|
}
|
|
}
|
|
|
|
func placeholders(numCols, numRows int) string {
|
|
rows := make([]string, numRows)
|
|
for i := 0; i < numRows; i++ {
|
|
rows[i] = "(" + stringz.RepeatJoin("?", numCols, driver.Comma) + ")"
|
|
}
|
|
return strings.Join(rows, driver.Comma)
|
|
}
|
|
|
|
// SQLBuilder implements driver.SQLDriver.
|
|
func (d *Driver) SQLBuilder() (sqlbuilder.FragmentBuilder, sqlbuilder.QueryBuilder) {
|
|
return newFragmentBuilder(d.log), &sqlbuilder.BaseQueryBuilder{}
|
|
}
|
|
|
|
// RecordMeta implements driver.SQLDriver.
|
|
func (d *Driver) RecordMeta(colTypes []*sql.ColumnType) (sqlz.RecordMeta, driver.NewRecordFunc, error) {
|
|
recMeta := recordMetaFromColumnTypes(d.log, colTypes)
|
|
mungeFn := getNewRecordFunc(recMeta)
|
|
return recMeta, mungeFn, nil
|
|
}
|
|
|
|
// CreateTable implements driver.SQLDriver.
|
|
func (d *Driver) CreateTable(ctx context.Context, db sqlz.DB, tblDef *sqlmodel.TableDef) error {
|
|
createStmt, err := buildCreateTableStmt(tblDef)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = db.ExecContext(ctx, createStmt)
|
|
return errz.Err(err)
|
|
}
|
|
|
|
// PrepareInsertStmt implements driver.SQLDriver.
|
|
func (d *Driver) PrepareInsertStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string, numRows int) (*driver.StmtExecer, error) {
|
|
destColsMeta, err := d.getTableRecordMeta(ctx, db, destTbl, destColNames)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
stmt, err := driver.PrepareInsertStmt(ctx, d, db, destTbl, destColsMeta.Names(), numRows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
execer := driver.NewStmtExecer(stmt, newInsertMungeFunc(destTbl, destColsMeta), newStmtExecFunc(stmt), destColsMeta)
|
|
return execer, nil
|
|
}
|
|
|
|
// PrepareUpdateStmt implements driver.SQLDriver.
|
|
func (d *Driver) PrepareUpdateStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string, where string) (*driver.StmtExecer, error) {
|
|
destColsMeta, err := d.getTableRecordMeta(ctx, db, destTbl, destColNames)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query, err := buildUpdateStmt(destTbl, destColNames, where)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
stmt, err := db.PrepareContext(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
execer := driver.NewStmtExecer(stmt, newInsertMungeFunc(destTbl, destColsMeta), newStmtExecFunc(stmt), destColsMeta)
|
|
return execer, nil
|
|
}
|
|
|
|
func newStmtExecFunc(stmt *sql.Stmt) driver.StmtExecFunc {
|
|
return func(ctx context.Context, args ...interface{}) (int64, error) {
|
|
res, err := stmt.ExecContext(ctx, args...)
|
|
if err != nil {
|
|
return 0, errz.Err(err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
return affected, errz.Err(err)
|
|
}
|
|
}
|
|
|
|
// CopyTable implements driver.SQLDriver.
|
|
func (d *Driver) CopyTable(ctx context.Context, db sqlz.DB, fromTable, toTable string, copyData bool) (int64, error) {
|
|
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` SELECT * FROM `%s`", toTable, fromTable)
|
|
|
|
if !copyData {
|
|
stmt += " WHERE 0"
|
|
}
|
|
|
|
affected, err := sqlz.ExecResult(ctx, db, stmt)
|
|
if err != nil {
|
|
return 0, errz.Err(err)
|
|
}
|
|
|
|
return affected, nil
|
|
}
|
|
|
|
// DropTable implements driver.SQLDriver.
|
|
func (d *Driver) DropTable(ctx context.Context, db sqlz.DB, tbl string, ifExists bool) error {
|
|
var stmt string
|
|
|
|
if ifExists {
|
|
stmt = fmt.Sprintf("DROP TABLE IF EXISTS `%s` RESTRICT", tbl)
|
|
} else {
|
|
stmt = fmt.Sprintf("DROP TABLE `%s` RESTRICT", tbl)
|
|
}
|
|
|
|
_, err := db.ExecContext(ctx, stmt)
|
|
return errz.Err(err)
|
|
}
|
|
|
|
// TableColumnTypes implements driver.SQLDriver.
|
|
func (d *Driver) TableColumnTypes(ctx context.Context, db sqlz.DB, tblName string, colNames []string) ([]*sql.ColumnType, error) {
|
|
const queryTpl = "SELECT %s FROM %s LIMIT 0"
|
|
|
|
dialect := d.Dialect()
|
|
quote := string(dialect.Quote)
|
|
tblNameQuoted := stringz.Surround(tblName, quote)
|
|
|
|
var colsClause = "*"
|
|
if len(colNames) > 0 {
|
|
colNamesQuoted := stringz.SurroundSlice(colNames, quote)
|
|
colsClause = strings.Join(colNamesQuoted, driver.Comma)
|
|
}
|
|
|
|
query := fmt.Sprintf(queryTpl, colsClause, tblNameQuoted)
|
|
rows, err := db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
colTypes, err := rows.ColumnTypes()
|
|
if err != nil {
|
|
d.log.WarnIfFuncError(rows.Close)
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
err = rows.Err()
|
|
if err != nil {
|
|
d.log.WarnIfFuncError(rows.Close)
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
err = rows.Close()
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
return colTypes, nil
|
|
}
|
|
|
|
func (d *Driver) getTableRecordMeta(ctx context.Context, db sqlz.DB, tblName string, colNames []string) (sqlz.RecordMeta, error) {
|
|
colTypes, err := d.TableColumnTypes(ctx, db, tblName, colNames)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
destCols, _, err := d.RecordMeta(colTypes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return destCols, nil
|
|
}
|
|
|
|
// SourceDSN extracts the mysql driver DSN from src.Location.
|
|
func SourceDSN(src *source.Source) (string, error) {
|
|
if !strings.HasPrefix(src.Location, "mysql://") || len(src.Location) < 10 {
|
|
return "", errz.Errorf("invalid source location %s", src.RedactedLocation())
|
|
}
|
|
|
|
u, err := dburl.Parse(src.Location)
|
|
if err != nil {
|
|
return "", errz.Wrapf(err, "invalid source location %s", src.RedactedLocation())
|
|
}
|
|
|
|
// Convert the location to the desired driver DSN.
|
|
// Location: mysql://sakila:p_ssW0rd@localhost:3306/sqtest
|
|
// Driver DSN: sakila:p_ssW0rd@tcp(localhost:3306)/sqtest
|
|
driverDSN := fmt.Sprintf("%s@tcp(%s)%s", u.User.String(), u.Host, u.Path)
|
|
|
|
// REVISIT: extra check for safety, can prob delete later
|
|
_, err = mysql.ParseDSN(driverDSN)
|
|
if err != nil {
|
|
return "", errz.Wrapf(err, "invalid source location: %q", driverDSN)
|
|
}
|
|
|
|
return driverDSN, nil
|
|
}
|
|
|
|
// Open implements driver.Driver.
|
|
func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Database, error) {
|
|
dsn, err := SourceDSN(src)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
db, err := sql.Open(dbDrvr, dsn)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
return &database{log: d.log, db: db, src: src, drvr: d}, nil
|
|
}
|
|
|
|
// ValidateSource implements driver.Driver.
|
|
func (d *Driver) ValidateSource(src *source.Source) (*source.Source, error) {
|
|
if src.Type != Type {
|
|
return nil, errz.Errorf("expected source type %q but got %q", Type, src.Type)
|
|
}
|
|
return src, nil
|
|
}
|
|
|
|
// Ping implements driver.Driver.
|
|
func (d *Driver) Ping(ctx context.Context, src *source.Source) error {
|
|
dbase, err := d.Open(context.TODO(), src)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer d.log.WarnIfCloseError(dbase.DB())
|
|
|
|
return dbase.DB().Ping()
|
|
}
|
|
|
|
// Truncate implements driver.SQLDriver. Arg reset is
// always ignored: the identity value is always reset by
// the TRUNCATE statement.
//
// It opens a separate *sql.DB for src, which is closed on return. Inside
// a serializable transaction it counts the rows in tbl and then executes
// TRUNCATE TABLE on tbl.
//
// The driver has been observed to report zero rows affected for
// TRUNCATE, so the pre-truncate row count is returned as affected. If a
// non-zero count is ever reported, that value is returned instead and a
// warning is logged.
//
// On failure after BeginTx, the transaction is rolled back and the
// rollback error (if any) is combined with the original error via
// errz.Append.
//
// NOTE(review): tbl is interpolated inside backticks without escaping, so
// a table name containing a backtick would produce invalid SQL.
func (d *Driver) Truncate(ctx context.Context, src *source.Source, tbl string, reset bool) (affected int64, err error) {
	// https://dev.mysql.com/doc/refman/8.0/en/truncate-table.html
	dsn, err := SourceDSN(src)
	if err != nil {
		return 0, err
	}

	db, err := sql.Open(dbDrvr, dsn)
	if err != nil {
		return 0, errz.Err(err)
	}
	defer d.log.WarnIfFuncError(db.Close)

	// Not sure about the Tx requirements?
	// NOTE(review): the MySQL docs linked above state that TRUNCATE TABLE
	// causes an implicit commit, so this transaction may not provide the
	// isolation it appears to; confirm whether it is needed.
	tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return 0, errz.Err(err)
	}

	// For whatever reason, the "affected" count from TRUNCATE
	// always returns zero. So, we're going to synthesize it.
	var beforeCount int64
	err = tx.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM `%s`", tbl)).Scan(&beforeCount)
	if err != nil {
		return 0, errz.Append(err, errz.Err(tx.Rollback()))
	}

	affected, err = sqlz.ExecResult(ctx, tx, fmt.Sprintf("TRUNCATE TABLE `%s`", tbl))
	if err != nil {
		return affected, errz.Append(err, errz.Err(tx.Rollback()))
	}

	if affected != 0 {
		// Note: At the time of writing, this doesn't happen:
		// zero is always returned (which we don't like).
		// If this changes (driver changes?) then we'll revisit.
		d.log.Warnf("Unexpectedly got non-zero (%d) rows affected from TRUNCATE", affected)
		return affected, errz.Err(tx.Commit())
	}

	// TRUNCATE succeeded, therefore tbl is empty, therefore
	// the count of truncated rows must be beforeCount?
	return beforeCount, errz.Err(tx.Commit())
}
|
|
|
|
// hasErrCode returns true if err (or its cause error)
|
|
// is of type *mysql.MySQLError and err.Number equals code.
|
|
func hasErrCode(err error, code uint16) bool {
|
|
err = errz.Cause(err)
|
|
if err2, ok := err.(*mysql.MySQLError); ok {
|
|
return err2.Number == code
|
|
}
|
|
return false
|
|
}
|
|
|
|
// errNumTableNotExist is MySQL server error number 1146, which is
// reported when a statement references a table that does not exist.
const errNumTableNotExist uint16 = 1146
|