mirror of
https://github.com/neilotoole/sq.git
synced 2025-01-07 08:46:26 +03:00
a1a89ee9dd
* sakila: initial test data * sakila: more test data * sakila: yet more test data setup * whitespace cols: now working for sqlite * grammar cleanup * whitespace cols: now working inside count() func for sqlite * whitespace cols: tests mostly passing; begining refactoring * grammar: refactor handle * grammar: more refactoring * grammar: rename selElement to selector * wip * all tests passing * all tests passing * linting * driver: implement CurrentSchema for all driver.SQLDriver impls * driver: tests for AlterTableRename and AlterTableRenameColumn * undo reformat of SQL * undo reformat of SQL * undo reformat of SQL * undo reformat of SQL
440 lines
13 KiB
Go
440 lines
13 KiB
Go
package sqlserver
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/c2h5oh/datasize"
|
|
"github.com/neilotoole/errgroup"
|
|
"github.com/neilotoole/lg"
|
|
|
|
"github.com/neilotoole/sq/libsq/core/errz"
|
|
"github.com/neilotoole/sq/libsq/core/kind"
|
|
"github.com/neilotoole/sq/libsq/core/sqlz"
|
|
"github.com/neilotoole/sq/libsq/driver"
|
|
"github.com/neilotoole/sq/libsq/source"
|
|
)
|
|
|
|
// kindFromDBTypeName determines the kind.Kind from the database
|
|
// type name. For example, "VARCHAR" -> kind.Text.
|
|
func kindFromDBTypeName(log lg.Log, colName, dbTypeName string) kind.Kind {
|
|
var knd kind.Kind
|
|
dbTypeName = strings.ToUpper(dbTypeName)
|
|
|
|
switch dbTypeName {
|
|
default:
|
|
log.Warnf("Unknown SQLServer database type '%s' for column '%s': using %s", dbTypeName, colName, kind.Unknown)
|
|
knd = kind.Unknown
|
|
case "INT", "BIGINT", "SMALLINT", "TINYINT":
|
|
knd = kind.Int
|
|
case "CHAR", "NCHAR", "VARCHAR", "JSON", "NVARCHAR", "NTEXT", "TEXT":
|
|
knd = kind.Text
|
|
case "BIT":
|
|
knd = kind.Bool
|
|
case "BINARY", "VARBINARY", "IMAGE":
|
|
knd = kind.Bytes
|
|
case "DECIMAL", "NUMERIC":
|
|
knd = kind.Decimal
|
|
case "MONEY", "SMALLMONEY":
|
|
knd = kind.Decimal
|
|
case "DATETIME", "DATETIME2", "SMALLDATETIME", "DATETIMEOFFSET":
|
|
knd = kind.Datetime
|
|
case "DATE":
|
|
knd = kind.Date
|
|
case "TIME":
|
|
knd = kind.Time
|
|
case "FLOAT", "REAL":
|
|
knd = kind.Float
|
|
case "XML":
|
|
knd = kind.Text
|
|
case "UNIQUEIDENTIFIER":
|
|
knd = kind.Text
|
|
case "ROWVERSION", "TIMESTAMP":
|
|
knd = kind.Int
|
|
}
|
|
|
|
return knd
|
|
}
|
|
|
|
// setScanType does some manipulation of ct's scan type.
|
|
// Most importantly, if ct is nullable column, setwe colTypeData.ScanType to a
|
|
// nullable type. This is because the driver doesn't
|
|
// report nullable scan types.
|
|
func setScanType(ct *sqlz.ColumnTypeData, knd kind.Kind) {
|
|
if knd == kind.Decimal {
|
|
// The driver wants us to use []byte instead of string for DECIMAL,
|
|
// but we want to use string.
|
|
if ct.Nullable {
|
|
ct.ScanType = sqlz.RTypeNullString
|
|
} else {
|
|
ct.ScanType = sqlz.RTypeString
|
|
}
|
|
return
|
|
}
|
|
|
|
if !ct.Nullable {
|
|
// If the col type is not nullable, there's nothing
|
|
// to do here.
|
|
return
|
|
}
|
|
|
|
switch ct.ScanType {
|
|
default:
|
|
ct.ScanType = sqlz.RTypeNullString
|
|
|
|
case sqlz.RTypeInt64:
|
|
ct.ScanType = sqlz.RTypeNullInt64
|
|
|
|
case sqlz.RTypeBool:
|
|
ct.ScanType = sqlz.RTypeNullBool
|
|
|
|
case sqlz.RTypeFloat64:
|
|
ct.ScanType = sqlz.RTypeNullFloat64
|
|
|
|
case sqlz.RTypeString:
|
|
ct.ScanType = sqlz.RTypeNullString
|
|
|
|
case sqlz.RTypeTime:
|
|
ct.ScanType = sqlz.RTypeNullTime
|
|
|
|
case sqlz.RTypeBytes:
|
|
ct.ScanType = sqlz.RTypeBytes // no change
|
|
}
|
|
}
|
|
|
|
func getSourceMetadata(ctx context.Context, log lg.Log, src *source.Source, db sqlz.DB) (*source.Metadata, error) {
|
|
const query = `SELECT DB_NAME(), SCHEMA_NAME(), SERVERPROPERTY('ProductVersion'), @@VERSION,
|
|
(SELECT SUM(size) * 8192
|
|
FROM sys.master_files WITH(NOWAIT)
|
|
WHERE database_id = DB_ID()
|
|
GROUP BY database_id) AS total_size_bytes`
|
|
|
|
md := &source.Metadata{SourceType: Type, DBDriverType: Type}
|
|
md.Handle = src.Handle
|
|
md.Location = src.Location
|
|
|
|
var catalog, schema string
|
|
err := db.QueryRowContext(ctx, query).
|
|
Scan(&catalog, &schema, &md.DBVersion, &md.DBProduct, &md.Size)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
md.Name = catalog
|
|
md.FQName = catalog + "." + schema
|
|
md.Schema = schema
|
|
|
|
tblNames, tblTypes, err := getAllTables(ctx, log, db)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
g, gctx := errgroup.WithContextN(ctx, driver.Tuning.ErrgroupNumG, driver.Tuning.ErrgroupQSize)
|
|
tblMetas := make([]*source.TableMetadata, len(tblNames))
|
|
|
|
for i := range tblNames {
|
|
select {
|
|
case <-gctx.Done():
|
|
return nil, errz.Err(gctx.Err())
|
|
default:
|
|
}
|
|
|
|
i := i
|
|
g.Go(func() error {
|
|
var tblMeta *source.TableMetadata
|
|
tblMeta, err = getTableMetadata(gctx, log, db, catalog, schema, tblNames[i], tblTypes[i])
|
|
if err != nil {
|
|
if hasErrCode(err, errCodeObjectNotExist) {
|
|
// This can happen if the table is dropped while
|
|
// we're collecting metadata. We log a warning and continue.
|
|
log.Warnf("table metadata: table %q appears not to exist (continuing regardless): %v",
|
|
tblNames[i], err)
|
|
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
tblMetas[i] = tblMeta
|
|
return nil
|
|
})
|
|
}
|
|
|
|
err = g.Wait()
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
// If a table wasn't found (possibly dropped while querying), then
|
|
// its entry could be nil. We copy the non-nil elements to the
|
|
// final slice.
|
|
md.Tables = make([]*source.TableMetadata, 0, len(tblMetas))
|
|
for i := range tblMetas {
|
|
if tblMetas[i] != nil {
|
|
md.Tables = append(md.Tables, tblMetas[i])
|
|
}
|
|
}
|
|
return md, nil
|
|
}
|
|
|
|
func getTableMetadata(ctx context.Context, log lg.Log, db sqlz.DB,
|
|
tblCatalog, tblSchema, tblName, tblType string,
|
|
) (*source.TableMetadata, error) {
|
|
const tplTblUsage = `sp_spaceused '%s'`
|
|
|
|
tblMeta := &source.TableMetadata{Name: tblName, DBTableType: tblType}
|
|
tblMeta.FQName = tblCatalog + "." + tblSchema + "." + tblName
|
|
|
|
switch tblMeta.DBTableType {
|
|
case "BASE TABLE":
|
|
tblMeta.TableType = sqlz.TableTypeTable
|
|
case "VIEW":
|
|
tblMeta.TableType = sqlz.TableTypeView
|
|
default:
|
|
}
|
|
|
|
var rowCount, reserved, data, indexSize, unused sql.NullString
|
|
row := db.QueryRowContext(ctx, fmt.Sprintf(tplTblUsage, tblName))
|
|
err := row.Scan(&tblMeta.Name, &rowCount, &reserved, &data, &indexSize, &unused)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
if rowCount.Valid {
|
|
tblMeta.RowCount, err = strconv.ParseInt(strings.TrimSpace(rowCount.String), 10, 64)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
} else {
|
|
// We can't get the "row count" for a VIEW from sp_spaceused,
|
|
// so we need to select it the old-fashioned way.
|
|
err = db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %q", tblName)).Scan(&tblMeta.RowCount)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
}
|
|
|
|
if reserved.Valid {
|
|
var byteCount datasize.ByteSize
|
|
err = byteCount.UnmarshalText([]byte(reserved.String))
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
size := int64(byteCount.Bytes())
|
|
tblMeta.Size = &size
|
|
}
|
|
|
|
var dbCols []columnMeta
|
|
dbCols, err = getColumnMeta(ctx, log, db, tblCatalog, tblSchema, tblName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var dbConstraints []constraintMeta
|
|
dbConstraints, err = getConstraints(ctx, log, db, tblCatalog, tblSchema, tblName)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
cols := make([]*source.ColMetadata, len(dbCols))
|
|
for i := range dbCols {
|
|
cols[i] = &source.ColMetadata{
|
|
Name: dbCols[i].ColumnName,
|
|
Position: dbCols[i].OrdinalPosition,
|
|
BaseType: dbCols[i].DataType,
|
|
Kind: kindFromDBTypeName(log, dbCols[i].ColumnName, dbCols[i].DataType),
|
|
Nullable: dbCols[i].Nullable.Bool,
|
|
DefaultValue: dbCols[i].ColumnDefault.String,
|
|
}
|
|
|
|
// We want to output something like VARCHAR(255) for ColType
|
|
|
|
// REVISIT: This is all a bit messy and inconsistent with other drivers
|
|
var colLength *int64
|
|
switch {
|
|
case dbCols[i].CharMaxLength.Valid:
|
|
colLength = &dbCols[i].CharMaxLength.Int64
|
|
case dbCols[i].NumericPrecision.Valid:
|
|
colLength = &dbCols[i].NumericPrecision.Int64
|
|
case dbCols[i].DateTimePrecision.Valid:
|
|
colLength = &dbCols[i].DateTimePrecision.Int64
|
|
}
|
|
|
|
if colLength != nil {
|
|
cols[i].ColumnType = fmt.Sprintf("%s(%v)", dbCols[i].DataType, *colLength)
|
|
} else {
|
|
cols[i].ColumnType = dbCols[i].DataType
|
|
}
|
|
|
|
for _, dbConstraint := range dbConstraints {
|
|
if dbCols[i].ColumnName == dbConstraint.ColumnName {
|
|
if dbConstraint.ConstraintType == "PRIMARY KEY" {
|
|
cols[i].PrimaryKey = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
tblMeta.Columns = cols
|
|
return tblMeta, nil
|
|
}
|
|
|
|
// getAllTables returns all of the table names, and the table types
|
|
// (i.e. "BASE TABLE" or "VIEW").
|
|
func getAllTables(ctx context.Context, log lg.Log, db sqlz.DB) (tblNames, tblTypes []string, err error) {
|
|
const query = `SELECT TABLE_NAME, TABLE_TYPE FROM INFORMATION_SCHEMA.TABLES
|
|
WHERE TABLE_TYPE='BASE TABLE' OR TABLE_TYPE='VIEW'
|
|
ORDER BY TABLE_NAME ASC, TABLE_TYPE ASC`
|
|
|
|
rows, err := db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
defer log.WarnIfCloseError(rows)
|
|
|
|
for rows.Next() {
|
|
var tblName, tblType string
|
|
err = rows.Scan(&tblName, &tblType)
|
|
if err != nil {
|
|
return nil, nil, errz.Err(err)
|
|
}
|
|
|
|
tblNames = append(tblNames, tblName)
|
|
tblTypes = append(tblTypes, tblType)
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
return nil, nil, errz.Err(rows.Err())
|
|
}
|
|
|
|
return tblNames, tblTypes, nil
|
|
}
|
|
|
|
func getColumnMeta(ctx context.Context, log lg.Log, db sqlz.DB, tblCatalog, tblSchema, tblName string) ([]columnMeta,
|
|
error,
|
|
) {
|
|
// TODO: sq doesn't use all of these columns, no need to select them all.
|
|
|
|
const query = `SELECT
|
|
TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME,
|
|
COLUMN_NAME, ORDINAL_POSITION, COLUMN_DEFAULT, IS_NULLABLE, DATA_TYPE,
|
|
CHARACTER_MAXIMUM_LENGTH, CHARACTER_OCTET_LENGTH,
|
|
NUMERIC_PRECISION, NUMERIC_PRECISION_RADIX, NUMERIC_SCALE,
|
|
DATETIME_PRECISION,
|
|
CHARACTER_SET_CATALOG, CHARACTER_SET_SCHEMA, CHARACTER_SET_NAME,
|
|
COLLATION_CATALOG, COLLATION_SCHEMA, COLLATION_NAME,
|
|
DOMAIN_CATALOG, DOMAIN_SCHEMA, DOMAIN_NAME
|
|
FROM INFORMATION_SCHEMA.COLUMNS
|
|
WHERE TABLE_CATALOG = @p1 AND TABLE_SCHEMA = @p2 AND TABLE_NAME = @p3`
|
|
|
|
rows, err := db.QueryContext(ctx, query, tblCatalog, tblSchema, tblName)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
defer func() { log.WarnIfCloseError(rows) }()
|
|
|
|
var cols []columnMeta
|
|
|
|
for rows.Next() {
|
|
c := columnMeta{}
|
|
err = rows.Scan(&c.TableCatalog, &c.TableSchema, &c.TableName, &c.ColumnName, &c.OrdinalPosition,
|
|
&c.ColumnDefault, &c.Nullable, &c.DataType, &c.CharMaxLength, &c.CharOctetLength, &c.NumericPrecision,
|
|
&c.NumericPrecisionRadix, &c.NumericScale, &c.DateTimePrecision, &c.CharSetCatalog, &c.CharSetSchema,
|
|
&c.CharSetName, &c.CollationCatalog, &c.CollationSchema, &c.CollationName, &c.DomainCatalog,
|
|
&c.DomainSchema, &c.DomainName)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
cols = append(cols, c)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
return cols, nil
|
|
}
|
|
|
|
func getConstraints(ctx context.Context, log lg.Log, db sqlz.DB,
|
|
tblCatalog, tblSchema, tblName string,
|
|
) ([]constraintMeta, error) {
|
|
const query = `SELECT kcu.TABLE_CATALOG, kcu.TABLE_SCHEMA, kcu.TABLE_NAME, tc.CONSTRAINT_TYPE,
|
|
kcu.COLUMN_NAME, kcu.CONSTRAINT_NAME
|
|
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
|
|
JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS kcu
|
|
ON tc.TABLE_NAME = kcu.TABLE_NAME
|
|
AND tc.CONSTRAINT_CATALOG = kcu.CONSTRAINT_CATALOG
|
|
AND tc.CONSTRAINT_SCHEMA = kcu.CONSTRAINT_SCHEMA
|
|
AND tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME
|
|
WHERE tc.TABLE_CATALOG = @p1 AND tc.TABLE_SCHEMA = @p2 AND tc.TABLE_NAME = @p3
|
|
ORDER BY kcu.TABLE_NAME, tc.CONSTRAINT_TYPE, kcu.CONSTRAINT_NAME`
|
|
|
|
rows, err := db.QueryContext(ctx, query, tblCatalog, tblSchema, tblName)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
defer func() { log.WarnIfCloseError(rows) }()
|
|
|
|
var constraints []constraintMeta
|
|
|
|
for rows.Next() {
|
|
c := constraintMeta{}
|
|
err = rows.Scan(&c.TableCatalog, &c.TableSchema, &c.TableName, &c.ConstraintType, &c.ColumnName,
|
|
&c.ConstraintName)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
constraints = append(constraints, c)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
return constraints, nil
|
|
}
|
|
|
|
// constraintMeta holds one row of constraint metadata read from the
// information schema. Each field's db tag names the result column
// that the field corresponds to.
type constraintMeta struct {
	// TableCatalog is the catalog of the constrained table.
	TableCatalog string `db:"TABLE_CATALOG"`

	// TableSchema is the schema of the constrained table.
	TableSchema string `db:"TABLE_SCHEMA"`

	// TableName is the name of the constrained table.
	TableName string `db:"TABLE_NAME"`

	// ConstraintType is the type of constraint, e.g. "PRIMARY KEY".
	ConstraintType string `db:"CONSTRAINT_TYPE"`

	// ColumnName is the column that the constraint applies to.
	ColumnName string `db:"COLUMN_NAME"`

	// ConstraintName is the name of the constraint.
	ConstraintName string `db:"CONSTRAINT_NAME"`
}
|
|
|
|
// columnMeta models column metadata from information schema.
|
|
type columnMeta struct {
|
|
TableCatalog string `db:"TABLE_CATALOG"`
|
|
TableSchema string `db:"TABLE_SCHEMA"`
|
|
TableName string `db:"TABLE_NAME"`
|
|
ColumnName string `db:"COLUMN_NAME"`
|
|
OrdinalPosition int64 `db:"ORDINAL_POSITION"`
|
|
ColumnDefault sql.NullString `db:"COLUMN_DEFAULT"`
|
|
Nullable sqlz.NullBool `db:"IS_NULLABLE"`
|
|
DataType string `db:"DATA_TYPE"`
|
|
CharMaxLength sql.NullInt64 `db:"CHARACTER_MAXIMUM_LENGTH"`
|
|
CharOctetLength sql.NullString `db:"CHARACTER_OCTET_LENGTH"`
|
|
NumericPrecision sql.NullInt64 `db:"NUMERIC_PRECISION"`
|
|
NumericPrecisionRadix sql.NullInt64 `db:"NUMERIC_PRECISION_RADIX"`
|
|
NumericScale sql.NullInt64 `db:"NUMERIC_SCALE"`
|
|
DateTimePrecision sql.NullInt64 `db:"DATETIME_PRECISION"`
|
|
CharSetCatalog sql.NullString `db:"CHARACTER_SET_CATALOG"`
|
|
CharSetSchema sql.NullString `db:"CHARACTER_SET_SCHEMA"`
|
|
CharSetName sql.NullString `db:"CHARACTER_SET_NAME"`
|
|
CollationCatalog sql.NullString `db:"COLLATION_CATALOG"`
|
|
CollationSchema sql.NullString `db:"COLLATION_SCHEMA"`
|
|
CollationName sql.NullString `db:"COLLATION_NAME"`
|
|
DomainCatalog sql.NullString `db:"DOMAIN_CATALOG"`
|
|
DomainSchema sql.NullString `db:"DOMAIN_SCHEMA"`
|
|
DomainName sql.NullString `db:"DOMAIN_NAME"`
|
|
}
|