sq/drivers/mysql/metadata.go
Neil O'Toole 54a0155ed6
Deal with lots of connections, especially for sq inspect. (#186)
* wip: deal with lots of connections

* Clean up TestDatabase_SourceMetadata_concurrent

* Fixed error message
2023-04-08 12:09:27 -06:00

574 lines
18 KiB
Go

package mysql
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"time"
"golang.org/x/sync/errgroup"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/lg/lgm"
"github.com/neilotoole/sq/libsq/core/lg"
"golang.org/x/exp/slog"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/kind"
"github.com/neilotoole/sq/libsq/core/sqlz"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/neilotoole/sq/libsq/driver"
"github.com/neilotoole/sq/libsq/source"
)
// kindFromDBTypeName maps a MySQL column type name to a kind.Kind, for
// example "VARCHAR(64)" -> kind.Text. The comparison is case-insensitive,
// and any parenthesized suffix (length, precision, etc.) is ignored.
// An empty type name yields kind.Unknown. An unrecognized type name also
// yields kind.Unknown, after logging a warning that mentions colName.
func kindFromDBTypeName(log *slog.Logger, colName, dbTypeName string) kind.Kind {
	typeName := strings.ToUpper(dbTypeName)

	// Drop the "(...)" part, so that VARCHAR(255) is matched as VARCHAR.
	if paren := strings.IndexRune(typeName, '('); paren > 0 {
		typeName = typeName[:paren]
	}

	switch typeName {
	case "":
		return kind.Unknown
	case "INTEGER", "INT", "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT", "YEAR", "BIT",
		"UNSIGNED INTEGER", "UNSIGNED INT", "UNSIGNED TINYINT",
		"UNSIGNED SMALLINT", "UNSIGNED MEDIUMINT", "UNSIGNED BIGINT":
		return kind.Int
	case "DECIMAL", "NUMERIC":
		return kind.Decimal
	case "CHAR", "VARCHAR", "TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT",
		"ENUM", "SET",
		"JSON":
		return kind.Text
	case "VARBINARY", "BINARY", "BLOB", "MEDIUMBLOB", "LONGBLOB", "TINYBLOB":
		return kind.Bytes
	case "DATETIME", "TIMESTAMP":
		return kind.Datetime
	case "DATE":
		return kind.Date
	case "TIME": //nolint:goconst
		return kind.Time
	case "FLOAT", "DOUBLE", "DOUBLE PRECISION", "REAL":
		return kind.Float
	case "BOOL", "BOOLEAN":
		// In practice these are not returned by the mysql driver.
		return kind.Bool
	}

	// Not a type we recognize.
	log.Warn(
		"Unknown MySQL column type: using alt type",
		lga.DBType, typeName,
		lga.Col, colName,
		lga.Alt, kind.Unknown,
	)
	return kind.Unknown
}
// recordMetaFromColumnTypes builds a sqlz.RecordMeta holding one FieldMeta
// per element of colTypes, in the same order. Each field's kind is derived
// from the column's database type name via kindFromDBTypeName.
func recordMetaFromColumnTypes(log *slog.Logger, colTypes []*sql.ColumnType) sqlz.RecordMeta {
	recMeta := make(sqlz.RecordMeta, 0, len(colTypes))
	for _, ct := range colTypes {
		k := kindFromDBTypeName(log, ct.Name(), ct.DatabaseTypeName())
		recMeta = append(recMeta, sqlz.NewFieldMeta(sqlz.NewColumnTypeData(ct, k)))
	}
	return recMeta
}
// getNewRecordFunc returns a driver.NewRecordFunc for rows described by
// rowMeta. The returned func delegates to driver.NewRecordFromScanRow and
// then post-processes each field index that call reports as skipped:
//   - a *sql.NullTime becomes a *time.Time (pointing at a copy of the
//     value), or nil if the NullTime is not valid;
//   - for a TIME column, a *sql.RawBytes value becomes a *string ("00:00"
//     if the bytes are empty); any other non-nil value is left as-is;
//   - any other skipped field results in an error.
func getNewRecordFunc(rowMeta sqlz.RecordMeta) driver.NewRecordFunc {
	return func(row []any) (sqlz.Record, error) {
		rec, skipped := driver.NewRecordFromScanRow(rowMeta, row, nil)

		for _, idx := range skipped {
			if nt, isNullTime := rec[idx].(*sql.NullTime); isNullTime {
				if !nt.Valid {
					rec[idx] = nil
					continue
				}
				tm := nt.Time // copy the value rather than pointing into nt
				rec[idx] = &tm
				continue
			}

			isTimeCol := rowMeta[idx].DatabaseTypeName() == "TIME"
			if !isTimeCol || rec[idx] == nil {
				// We don't know what to do with this column.
				return nil, errz.Errorf("column %d %s: unknown type db(%T) with kind(%s), val(%v)", idx, rowMeta[idx].Name(),
					rec[idx], rowMeta[idx].Kind(), rec[idx])
			}

			// MySQL may return TIME as RawBytes... convert to a string.
			// https://github.com/go-sql-driver/mysql#timetime-support
			raw, isRaw := rec[idx].(*sql.RawBytes)
			if !isRaw {
				continue
			}

			var text string
			if len(*raw) == 0 {
				text = "00:00" // shouldn't happen
			} else {
				text = string(*raw)
			}
			rec[idx] = &text
		}

		return rec, nil
	}
}
// getTableMetadata gets the metadata for a single table. It is the
// implementation of driver.Database.TableMetadata.
//
// The row count comes from a "SELECT COUNT(*)" subquery against the table
// itself. Identifiers cannot be bound as query parameters, so tblName is
// embedded as a backtick-quoted identifier, with any backtick inside the
// name doubled (MySQL's escaping rule for quoted identifiers). Without that
// escaping, a table name containing a backtick would break the query (or
// inject SQL).
func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*source.TableMetadata, error) {
	quotedTbl := "`" + strings.ReplaceAll(tblName, "`", "``") + "`"
	query := `SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, TABLE_COMMENT, (DATA_LENGTH + INDEX_LENGTH) AS table_size,
(SELECT COUNT(*) FROM ` + quotedTbl + `) AS row_count
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ?`

	var schema string
	// TABLE_COMMENT is scanned as nullable, consistent with getAllTblMetas.
	var tblComment sql.NullString
	var tblSize sql.NullInt64
	tblMeta := &source.TableMetadata{}

	err := db.QueryRowContext(ctx, query, tblName).
		Scan(&schema, &tblMeta.Name, &tblMeta.DBTableType, &tblComment, &tblSize, &tblMeta.RowCount)
	if err != nil {
		return nil, errz.Err(err)
	}

	tblMeta.Comment = tblComment.String
	tblMeta.TableType = canonicalTableType(tblMeta.DBTableType)
	tblMeta.FQName = schema + "." + tblMeta.Name

	if tblSize.Valid {
		// For a view (as opposed to table), tblSize is typically nil
		tblMeta.Size = &tblSize.Int64
	}

	tblMeta.Columns, err = getColumnMetadata(ctx, db, tblMeta.Name)
	if err != nil {
		return nil, err
	}

	return tblMeta, nil
}
// getColumnMetadata returns metadata for each column of tblName in the
// current database (DATABASE()), ordered by ordinal position.
func getColumnMetadata(ctx context.Context, db sqlz.DB, tblName string) ([]*source.ColMetadata, error) {
	log := lg.FromContext(ctx)

	const query = `SELECT column_name, data_type, column_type, ordinal_position, column_default,
is_nullable, column_key, column_comment, extra
FROM information_schema.columns cols
WHERE cols.TABLE_SCHEMA = DATABASE() AND cols.TABLE_NAME = ?
ORDER BY cols.ordinal_position ASC`

	rows, err := db.QueryContext(ctx, query, tblName)
	if err != nil {
		return nil, errz.Err(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var colMetas []*source.ColMetadata
	for rows.Next() {
		var (
			col                  = &source.ColMetadata{}
			nullable, key, extra string
			defVal               sql.NullString
		)

		// extra is selected by the query but not currently used.
		if err = rows.Scan(&col.Name, &col.BaseType, &col.ColumnType, &col.Position, &defVal,
			&nullable, &key, &col.Comment, &extra); err != nil {
			return nil, errz.Err(err)
		}

		col.Nullable = strings.EqualFold("YES", nullable)
		col.PrimaryKey = strings.Contains(key, "PRI")
		col.DefaultValue = defVal.String
		col.Kind = kindFromDBTypeName(log, col.Name, col.BaseType)

		colMetas = append(colMetas, col)
	}

	return colMetas, errz.Err(rows.Err())
}
// getSourceMetadata is the implementation of driver.Database.SourceMetadata.
//
// Building the Metadata takes several independent queries: the server and
// schema summary, the DB variables, and the table/column details. These run
// concurrently via errgroup, and each one writes a distinct set of fields
// on md.
//
// Historically, the main perf bottleneck was the per-table row count. An
// accurate count needs "SELECT COUNT(*) FROM tbl" for every table/view.
// For some other databases (such as sqlite), it was performant to combine
// those SELECTs into one (or a few) queries with UNION ALL, e.g.:
//
//	SELECT COUNT(*) FROM actor
//	UNION ALL
//	SELECT COUNT(*) FROM address
//	UNION ALL
//	[...]
//
// On MySQL that approach seemed to perform poorly. Testing focused mainly
// on MySQL 5.6. Issuing each SELECT COUNT(*) from its own goroutine (via
// errgroup) produced reasonable results instead. That said, the
// testing/benchmarking was far from exhaustive, and the approach has a bit
// of a code smell.
func getSourceMetadata(ctx context.Context, src *source.Source, db sqlz.DB) (*source.Metadata, error) {
	md := &source.Metadata{
		SourceType:   Type,
		DBDriverType: Type,
		Handle:       src.Handle,
		Location:     src.Location,
	}

	g, gCtx := errgroup.WithContext(ctx)
	g.SetLimit(driver.Tuning.ErrgroupLimit)

	g.Go(func() error {
		return setSourceSummaryMeta(gCtx, db, md)
	})
	g.Go(func() (err error) {
		md.DBVars, err = getDBVarsMeta(gCtx, db)
		return err
	})
	g.Go(func() (err error) {
		md.Tables, err = getAllTblMetas(gCtx, db)
		return err
	})

	if err := g.Wait(); err != nil {
		return nil, err
	}
	return md, nil
}
// setSourceSummaryMeta populates the summary fields of md (Name, Schema,
// FQName, DBVersion, DBProduct, User, Size) using a single query.
//
// Size is the sum of the data and index lengths of the tables in the
// current schema. SUM returns NULL when no rows contribute a non-NULL
// value. That happens for a schema with no tables, and also for one
// containing only views, whose lengths are NULL. The query therefore
// coalesces the sum to 0 rather than handing NULL to Scan.
// NOTE(review): this assumes md.Size is a plain integer field, for which a
// NULL would fail the Scan.
func setSourceSummaryMeta(ctx context.Context, db sqlz.DB, md *source.Metadata) error {
	const summaryQuery = `SELECT @@GLOBAL.version, @@GLOBAL.version_comment, @@GLOBAL.version_compile_os,
@@GLOBAL.version_compile_machine, DATABASE(), CURRENT_USER(),
(SELECT COALESCE(SUM( data_length + index_length ), 0)
FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE()) AS size`

	var version, versionComment, versionOS, versionArch, schema string
	err := db.QueryRowContext(ctx, summaryQuery).Scan(&version, &versionComment, &versionOS, &versionArch, &schema,
		&md.User, &md.Size)
	if err != nil {
		return errz.Err(err)
	}

	md.Name = schema
	md.Schema = schema
	md.FQName = schema
	md.DBVersion = version
	md.DBProduct = fmt.Sprintf("%s %s / %s (%s)", versionComment, version, versionOS, versionArch)
	return nil
}
// getDBVarsMeta returns the database variables reported by
// "SHOW VARIABLES", in the order the server returns them.
func getDBVarsMeta(ctx context.Context, db sqlz.DB) ([]source.DBVar, error) {
	log := lg.FromContext(ctx)

	rows, err := db.QueryContext(ctx, "SHOW VARIABLES")
	if err != nil {
		return nil, errz.Err(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var vars []source.DBVar
	for rows.Next() {
		var v source.DBVar
		if err = rows.Scan(&v.Name, &v.Value); err != nil {
			return nil, errz.Err(err)
		}
		vars = append(vars, v)
	}

	if err = rows.Err(); err != nil {
		return nil, errz.Err(err)
	}
	return vars, nil
}
// getAllTblMetas returns TableMetadata for each table/view in db.
//
// A single information_schema query supplies the table and column details.
// Row counts need a "SELECT COUNT(*)" per table. Those are issued from an
// errgroup (bounded by driver.Tuning.ErrgroupLimit) while the main query's
// rows are still being scanned.
//
// A table whose row-count query fails with errNumTableNotExist is omitted
// from the result. This can happen when the table is dropped during
// metadata collection. Any other row-count failure aborts the call.
//
// Each row-count goroutine writes only to its own tblEntry, never to a
// shared slice. The entries are read only after g.Wait returns, and
// g.Wait is always called (even on early error), so no goroutine outlives
// this function.
func getAllTblMetas(ctx context.Context, db sqlz.DB) ([]*source.TableMetadata, error) {
	log := lg.FromContext(ctx)

	const query = `SELECT t.TABLE_SCHEMA, t.TABLE_NAME, t.TABLE_TYPE, t.TABLE_COMMENT,
(DATA_LENGTH + INDEX_LENGTH) AS table_size,
c.COLUMN_NAME, c.ORDINAL_POSITION, c.COLUMN_KEY, c.DATA_TYPE, c.COLUMN_TYPE,
c.IS_NULLABLE, c.COLUMN_DEFAULT, c.COLUMN_COMMENT, c.EXTRA
FROM information_schema.TABLES t
LEFT JOIN information_schema.COLUMNS c
ON c.TABLE_CATALOG = t.TABLE_CATALOG
AND c.TABLE_SCHEMA = t.TABLE_SCHEMA
AND c.TABLE_NAME = t.TABLE_NAME
WHERE t.TABLE_SCHEMA = DATABASE()
ORDER BY c.TABLE_NAME ASC, c.ORDINAL_POSITION ASC`

	//nolint:lll
	// Query results look like:
	// +------------+----------+----------+-------------+----------+-----------+----------------+----------+---------+--------------------+-----------+-----------------+--------------+---------------------------+
	// |TABLE_SCHEMA|TABLE_NAME|TABLE_TYPE|TABLE_COMMENT|table_size|COLUMN_NAME|ORDINAL_POSITION|COLUMN_KEY|DATA_TYPE|COLUMN_TYPE |IS_NULLABLE|COLUMN_DEFAULT |COLUMN_COMMENT|EXTRA |
	// +------------+----------+----------+-------------+----------+-----------+----------------+----------+---------+--------------------+-----------+-----------------+--------------+---------------------------+
	// |sakila |actor |BASE TABLE| |32768 |actor_id |1 |PRI |smallint |smallint(5) unsigned|NO |NULL | |auto_increment |
	// |sakila |actor |BASE TABLE| |32768 |first_name |2 | |varchar |varchar(45) |NO |NULL | | |
	// |sakila |actor |BASE TABLE| |32768 |last_name |3 |MUL |varchar |varchar(45) |NO |NULL | | |
	// |sakila |actor |BASE TABLE| |32768 |last_update|4 | |timestamp|timestamp |NO |CURRENT_TIMESTAMP| |on update CURRENT_TIMESTAMP|
	// |sakila |actor_info|VIEW |VIEW |NULL |actor_id |1 | |smallint |smallint(5) unsigned|NO |0 | | |

	// tblEntry pairs a table's metadata with a flag recording that the
	// table vanished before its row count could be fetched. At most one
	// row-count goroutine writes an entry's RowCount and dropped fields.
	type tblEntry struct {
		meta    *source.TableMetadata
		dropped bool
	}

	rows, err := db.QueryContext(ctx, query)
	if err != nil {
		return nil, errz.Err(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	// g is an errgroup for fetching the row count for each table. It is
	// created only once the main query has succeeded, and g.Wait is always
	// called below, so its context is always released.
	g, gCtx := errgroup.WithContext(ctx)
	g.SetLimit(driver.Tuning.ErrgroupLimit)

	var entries []*tblEntry

	// scanRows consumes rows, building up entries, and starts a row-count
	// goroutine for each newly encountered table.
	scanRows := func() error {
		var schema string
		var curTblName, curTblType, curTblComment sql.NullString
		var curTblSize sql.NullInt64
		var cur *tblEntry

		for rows.Next() {
			if gCtx.Err() != nil {
				// gCtx is done either because ctx was cancelled, or because a
				// row-count goroutine failed. In the latter case ctx.Err() is
				// nil; we just stop issuing queries and g.Wait reports the
				// goroutine's error.
				return ctx.Err()
			}

			var colName, colDefault, colNullable, colKey, colBaseType, colColumnType, colComment, colExtra sql.NullString
			var colPosition sql.NullInt64
			if scanErr := rows.Scan(&schema, &curTblName, &curTblType, &curTblComment, &curTblSize, &colName,
				&colPosition, &colKey, &colBaseType, &colColumnType, &colNullable, &colDefault, &colComment,
				&colExtra); scanErr != nil {
				return errz.Err(scanErr)
			}

			if !curTblName.Valid || !colName.Valid {
				// table may have been dropped during metadata collection
				log.Debug("Table not found during metadata collection")
				continue
			}

			if cur == nil || cur.meta.Name != curTblName.String {
				// On our first time encountering a new table name, we create
				// a new TableMetadata and kick off its row count.
				cur = &tblEntry{meta: &source.TableMetadata{
					Name:        curTblName.String,
					FQName:      schema + "." + curTblName.String,
					DBTableType: curTblType.String,
					TableType:   canonicalTableType(curTblType.String),
					Comment:     curTblComment.String,
				}}
				if curTblSize.Valid {
					size := curTblSize.Int64
					cur.meta.Size = &size
				}
				entries = append(entries, cur)

				// Capture per-table values: cur and curTblName are
				// overwritten as scanning continues.
				entry, tblName := cur, curTblName.String
				g.Go(func() error {
					// Backticks inside the name are doubled, per MySQL
					// quoted-identifier rules.
					countQuery := "SELECT COUNT(*) FROM `" + strings.ReplaceAll(tblName, "`", "``") + "`"
					gErr := db.QueryRowContext(gCtx, countQuery).Scan(&entry.meta.RowCount)
					if gErr == nil {
						return nil
					}
					if hasErrCode(gErr, errNumTableNotExist) {
						// The table was probably dropped while we were collecting
						// metadata, but that's ok. We flag the entry, and it is
						// filtered out after g.Wait.
						log.Debug("Failed to get row count for table: ignoring error",
							lga.Table, tblName,
							lga.Err, gErr)
						entry.dropped = true
						return nil
					}
					return errz.Err(gErr)
				})
			}

			col := &source.ColMetadata{
				Name:         colName.String,
				Position:     colPosition.Int64,
				BaseType:     colBaseType.String,
				ColumnType:   colColumnType.String,
				DefaultValue: colDefault.String,
				Comment:      colComment.String,
			}

			nullable, parseErr := stringz.ParseBool(colNullable.String)
			if parseErr != nil {
				return parseErr
			}
			col.Nullable = nullable
			col.Kind = kindFromDBTypeName(log, col.Name, col.BaseType)
			col.PrimaryKey = strings.Contains(colKey.String, "PRI")

			cur.meta.Columns = append(cur.meta.Columns, col)
		}

		if rowsErr := rows.Err(); rowsErr != nil {
			return errz.Err(rowsErr)
		}
		return nil
	}

	scanErr := scanRows()
	// Always wait, even if scanning failed, so that no row-count goroutine
	// is still running (and touching entries) when we return.
	waitErr := g.Wait()
	if scanErr != nil {
		return nil, scanErr
	}
	if waitErr != nil {
		return nil, waitErr
	}

	// Omit tables that were dropped before their row count could be fetched.
	tblMetas := make([]*source.TableMetadata, 0, len(entries))
	for _, e := range entries {
		if !e.dropped {
			tblMetas = append(tblMetas, e.meta)
		}
	}
	return tblMetas, nil
}
// newInsertMungeFunc is lifted from driver.DefaultInsertMungeFunc.
//
// The returned func adjusts rec in place before it is inserted into destTbl:
//   - rec must have exactly one value per column of destMeta; otherwise an
//     error is returned.
//   - A nil value (including a nil *string) destined for a non-nullable
//     column is replaced with the zero value of the column's scan type.
//   - For non-text columns, an empty string (string or *string) becomes
//     nil if the column is nullable, or the zero value otherwise.
//   - For datetime columns, a non-empty string (string or *string) that
//     parses with one of datetimeLayouts is replaced by a time.Time.
//   - Anything else is passed through for the DB to interpret.
func newInsertMungeFunc(destTbl string, destMeta sqlz.RecordMeta) driver.InsertMungeFunc {
	return func(rec sqlz.Record) error {
		if len(rec) != len(destMeta) {
			return errz.Errorf("insert record has %d vals but dest table %s has %d cols (%s)",
				len(rec), destTbl, len(destMeta), strings.Join(destMeta.Names(), ","))
		}

		for i := range rec {
			nullable, _ := destMeta[i].Nullable()

			// A nil *string is a NULL in disguise: the interface value is
			// non-nil, but dereferencing it below would panic.
			if sp, ok := rec[i].(*string); ok && sp == nil {
				rec[i] = nil
			}

			if rec[i] == nil && !nullable {
				mungeSetZeroValue(i, rec, destMeta)
				continue
			}

			if rec[i] == nil || destMeta[i].Kind() == kind.Text {
				// nil is fine for a nullable column; text doesn't need our help.
				continue
			}

			// The dest col kind is something other than text, let's inspect
			// the actual value and check its type.
			var s string
			switch val := rec[i].(type) {
			case string:
				s = val
			case *string:
				s = *val
			default:
				continue
			}

			if s == "" {
				if nullable {
					rec[i] = nil
				} else {
					mungeSetZeroValue(i, rec, destMeta)
				}
				continue
			}

			// string is non-empty
			if destMeta[i].Kind() == kind.Datetime {
				// special handling for datetime
				mungeSetDatetimeFromString(s, i, rec)
			}
			// else we let the DB figure it out
		}
		return nil
	}
}
// datetimeLayouts lists the layouts tried, in order, by
// mungeSetDatetimeFromString, so that mysql can be given a time.Time
// instead of a string.
var datetimeLayouts = []string{
	time.RFC3339Nano,
	time.RFC3339,
}

// mungeSetDatetimeFromString tries each of datetimeLayouts in turn. On the
// first layout that parses s, it stores the resulting time.Time in rec[i].
// If no layout matches, rec is left untouched and mysql deals with the
// text.
func mungeSetDatetimeFromString(s string, i int, rec []any) {
	for _, layout := range datetimeLayouts {
		if parsed, err := time.Parse(layout, s); err == nil {
			rec[i] = parsed
			return
		}
	}
}
// mungeSetZeroValue overwrites rec[i] with the zero value of destMeta[i]'s
// scan type. Callers use it when rec[i] is nil (or an empty string) but the
// destination column is not nullable.
func mungeSetZeroValue(i int, rec []any, destMeta sqlz.RecordMeta) {
	// REVISIT: do we need to do special handling for kind.Datetime
	// and kind.Time (e.g. "00:00" for time)?
	scanType := destMeta[i].ScanType()
	rec[i] = reflect.Zero(scanType).Interface()
}
// canonicalTableType maps a MySQL TABLE_TYPE value to its canonical name:
// "BASE TABLE" becomes sqlz.TableTypeTable and "VIEW" becomes
// sqlz.TableTypeView. Any other value yields the empty string.
func canonicalTableType(dbType string) string {
	if dbType == "BASE TABLE" {
		return sqlz.TableTypeTable
	}
	if dbType == "VIEW" {
		return sqlz.TableTypeView
	}
	return ""
}