mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-18 21:52:28 +03:00
a3cd01f36a
* Diff refactor
861 lines
24 KiB
Go
861 lines
24 KiB
Go
// Package sqlserver implements the sq driver for SQL Server.
|
|
package sqlserver
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"log/slog"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/microsoft/go-mssqldb/msdsn"
|
|
|
|
"github.com/neilotoole/sq/libsq/ast"
|
|
"github.com/neilotoole/sq/libsq/ast/render"
|
|
"github.com/neilotoole/sq/libsq/core/errz"
|
|
"github.com/neilotoole/sq/libsq/core/jointype"
|
|
"github.com/neilotoole/sq/libsq/core/kind"
|
|
"github.com/neilotoole/sq/libsq/core/langz"
|
|
"github.com/neilotoole/sq/libsq/core/lg"
|
|
"github.com/neilotoole/sq/libsq/core/lg/lga"
|
|
"github.com/neilotoole/sq/libsq/core/lg/lgm"
|
|
"github.com/neilotoole/sq/libsq/core/record"
|
|
"github.com/neilotoole/sq/libsq/core/schema"
|
|
"github.com/neilotoole/sq/libsq/core/sqlz"
|
|
"github.com/neilotoole/sq/libsq/core/stringz"
|
|
"github.com/neilotoole/sq/libsq/core/tablefq"
|
|
"github.com/neilotoole/sq/libsq/driver"
|
|
"github.com/neilotoole/sq/libsq/driver/dialect"
|
|
"github.com/neilotoole/sq/libsq/source"
|
|
"github.com/neilotoole/sq/libsq/source/drivertype"
|
|
"github.com/neilotoole/sq/libsq/source/location"
|
|
"github.com/neilotoole/sq/libsq/source/metadata"
|
|
)
|
|
|
|
// dbDrvr is the driver name handed to sql.Open when doOpen creates the
// backing *sql.DB for a SQL Server source.
const dbDrvr = "sqlserver"
|
|
|
|
var _ driver.Provider = (*Provider)(nil)
|
|
|
|
// Provider is the SQL Server implementation of driver.Provider.
|
|
type Provider struct {
|
|
Log *slog.Logger
|
|
}
|
|
|
|
// DriverFor implements driver.Provider.
|
|
func (p *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) {
|
|
if typ != drivertype.MSSQL {
|
|
return nil, errz.Errorf("unsupported driver type {%s}}", typ)
|
|
}
|
|
|
|
return &driveri{log: p.Log}, nil
|
|
}
|
|
|
|
var _ driver.SQLDriver = (*driveri)(nil)
|
|
|
|
// driveri is the SQL Server implementation of driver.Driver.
|
|
type driveri struct {
|
|
log *slog.Logger
|
|
}
|
|
|
|
// ConnParams implements driver.SQLDriver.
|
|
func (d *driveri) ConnParams() map[string][]string {
|
|
// https://github.com/microsoft/go-mssqldb#connection-parameters-and-dsn.
|
|
return map[string][]string{
|
|
"ApplicationIntent": {"ReadOnly"},
|
|
"ServerSPN": nil,
|
|
"TrustServerCertificate": {"false", "true"},
|
|
"Workstation ID": nil,
|
|
"app name": {"sq"},
|
|
"certificate": nil,
|
|
"connection timeout": {"0"},
|
|
"database": nil,
|
|
"dial timeout": {"0"},
|
|
"encrypt": {"disable", "false", "true"},
|
|
"failoverpartner": nil,
|
|
"failoverport": {"1433"},
|
|
"hostNameInCertificate": nil,
|
|
"keepAlive": {"0", "30"},
|
|
"log": {"0", "1", "2", "4", "8", "16", "32", "64", "128", "255"},
|
|
"packet size": {"512", "4096", "16383", "32767"},
|
|
"protocol": nil,
|
|
"tlsmin": {"1.0", "1.1", "1.2", "1.3"},
|
|
"user id": nil,
|
|
}
|
|
}
|
|
|
|
// ErrWrapFunc implements driver.SQLDriver.
|
|
func (d *driveri) ErrWrapFunc() func(error) error {
|
|
return errw
|
|
}
|
|
|
|
// DriverMetadata implements driver.SQLDriver.
|
|
func (d *driveri) DriverMetadata() driver.Metadata {
|
|
return driver.Metadata{
|
|
Type: drivertype.MSSQL,
|
|
Description: "Microsoft SQL Server / Azure SQL Edge",
|
|
Doc: "https://github.com/microsoft/go-mssqldb",
|
|
IsSQL: true,
|
|
DefaultPort: 1433,
|
|
}
|
|
}
|
|
|
|
// Dialect implements driver.SQLDriver.
|
|
func (d *driveri) Dialect() dialect.Dialect {
|
|
return dialect.Dialect{
|
|
Type: drivertype.MSSQL,
|
|
Placeholders: placeholders,
|
|
Enquote: stringz.DoubleQuote,
|
|
MaxBatchValues: 1000,
|
|
Ops: dialect.DefaultOps(),
|
|
Joins: jointype.All(),
|
|
Catalog: true,
|
|
}
|
|
}
|
|
|
|
func placeholders(numCols, numRows int) string {
|
|
rows := make([]string, numRows)
|
|
|
|
n := 1
|
|
var sb strings.Builder
|
|
for i := 0; i < numRows; i++ {
|
|
sb.Reset()
|
|
sb.WriteRune('(')
|
|
for j := 1; j <= numCols; j++ {
|
|
sb.WriteString("@p")
|
|
sb.WriteString(strconv.Itoa(n))
|
|
n++
|
|
if j < numCols {
|
|
sb.WriteString(driver.Comma)
|
|
}
|
|
}
|
|
sb.WriteRune(')')
|
|
rows[i] = sb.String()
|
|
}
|
|
|
|
return strings.Join(rows, driver.Comma)
|
|
}
|
|
|
|
// Renderer implements driver.SQLDriver.
|
|
func (d *driveri) Renderer() *render.Renderer {
|
|
r := render.NewDefaultRenderer()
|
|
|
|
// Custom functions for SQLServer-specific stuff.
|
|
r.Range = renderRange
|
|
r.PreRender = append(r.PreRender, preRender)
|
|
|
|
r.FunctionNames[ast.FuncNameSchema] = "SCHEMA_NAME"
|
|
r.FunctionNames[ast.FuncNameCatalog] = "DB_NAME"
|
|
r.FunctionOverrides[ast.FuncNameRowNum] = renderFuncRowNum
|
|
|
|
return r
|
|
}
|
|
|
|
// Open implements driver.Driver.
|
|
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) {
|
|
lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src)
|
|
|
|
db, err := d.doOpen(ctx, src)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err = driver.OpeningPing(ctx, src, db); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &grip{log: d.log, db: db, src: src, drvr: d}, nil
|
|
}
|
|
|
|
func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, error) {
|
|
log := lg.FromContext(ctx)
|
|
loc := src.Location
|
|
cfg, err := msdsn.Parse(loc)
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
if src.Catalog != "" {
|
|
cfg.Database = src.Catalog
|
|
loc = cfg.URL().String()
|
|
|
|
log.Debug("Using catalog as database in connection string",
|
|
lga.Src, src,
|
|
lga.Catalog, src.Catalog,
|
|
lga.Conn, location.Redact(loc),
|
|
)
|
|
}
|
|
|
|
cfg.DialTimeout = driver.OptConnOpenTimeout.Get(src.Options)
|
|
loc = cfg.URL().String()
|
|
|
|
db, err := sql.Open(dbDrvr, loc)
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
driver.ConfigureDB(ctx, db, src.Options)
|
|
return db, nil
|
|
}
|
|
|
|
// ValidateSource implements driver.Driver.
|
|
func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) {
|
|
if src.Type != drivertype.MSSQL {
|
|
return nil, errz.Errorf("expected driver type %q but got %q", drivertype.MSSQL, src.Type)
|
|
}
|
|
return src, nil
|
|
}
|
|
|
|
// Ping implements driver.Driver.
|
|
func (d *driveri) Ping(ctx context.Context, src *source.Source) error {
|
|
db, err := d.doOpen(ctx, src)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer lg.WarnIfCloseError(d.log, lgm.CloseDB, db)
|
|
|
|
err = db.PingContext(ctx)
|
|
return errz.Wrapf(errw(err), "ping %s", src.Handle)
|
|
}
|
|
|
|
// DBProperties implements driver.SQLDriver.
|
|
func (d *driveri) DBProperties(ctx context.Context, db sqlz.DB) (map[string]any, error) {
|
|
return getDBProperties(ctx, db)
|
|
}
|
|
|
|
// Truncate implements driver.Driver. Due to a quirk of SQL Server, the
|
|
// operation is implemented in two statements. First "DELETE FROM tbl" to
|
|
// delete all rows. Then, if reset is true, the table sequence counter
|
|
// is reset via RESEED.
|
|
//
|
|
//nolint:lll
|
|
func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string, reset bool,
|
|
) (affected int64, err error) {
|
|
// https://docs.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15
|
|
|
|
// When there are foreign key constraints on mssql tables,
|
|
// it's not possible to TRUNCATE the table. An alternative is
|
|
// to delete all rows and reseed the identity column.
|
|
//
|
|
// DELETE FROM "table1"; DBCC CHECKIDENT ('table1', RESEED, 1);
|
|
//
|
|
// See: https://stackoverflow.com/questions/253849/cannot-truncate-table-because-it-is-being-referenced-by-a-foreign-key-constraint
|
|
|
|
db, err := d.doOpen(ctx, src)
|
|
if err != nil {
|
|
return 0, errw(err)
|
|
}
|
|
defer lg.WarnIfFuncError(d.log, lgm.CloseDB, db.Close)
|
|
|
|
affected, err = sqlz.ExecAffected(ctx, db, fmt.Sprintf("DELETE FROM %q", tbl))
|
|
if err != nil {
|
|
return affected, errz.Wrapf(errw(err), "truncate: failed to delete from %q", tbl)
|
|
}
|
|
|
|
if reset {
|
|
_, err = db.ExecContext(ctx, fmt.Sprintf("DBCC CHECKIDENT ('%s', RESEED, 1)", tbl))
|
|
if err != nil {
|
|
if hasErrCode(err, errNoIdentityColumn) {
|
|
// The table has no identity column, so we can't reseed.
|
|
lg.FromContext(ctx).Warn("truncate: table has no identity column, so cannot reseed",
|
|
lga.Src, src, lga.Table, tbl, lga.Err, errw(err))
|
|
return affected, nil
|
|
}
|
|
return affected, errz.Wrapf(errw(err), "truncate: deleted %d rows from %q but RESEED failed", affected, tbl)
|
|
}
|
|
}
|
|
|
|
return affected, nil
|
|
}
|
|
|
|
// TableColumnTypes implements driver.SQLDriver.
|
|
func (d *driveri) TableColumnTypes(ctx context.Context, db sqlz.DB, tblName string,
|
|
colNames []string,
|
|
) ([]*sql.ColumnType, error) {
|
|
// SQLServer has this unusual incantation for its LIMIT equivalent:
|
|
//
|
|
// SELECT username, email, address_id FROM person
|
|
// ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT 1 ROWS ONLY;
|
|
const queryTpl = "SELECT %s FROM %s ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT 1 ROWS ONLY"
|
|
|
|
enquote := d.Dialect().Enquote
|
|
tblNameQuoted := enquote(tblName)
|
|
|
|
colsClause := "*"
|
|
if len(colNames) > 0 {
|
|
colNamesQuoted := langz.Apply(colNames, enquote)
|
|
colsClause = strings.Join(colNamesQuoted, driver.Comma)
|
|
}
|
|
|
|
query := fmt.Sprintf(queryTpl, colsClause, tblNameQuoted)
|
|
rows, err := db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
colTypes, err := rows.ColumnTypes()
|
|
if err != nil {
|
|
lg.WarnIfFuncError(d.log, lgm.CloseDBRows, rows.Close)
|
|
return nil, errw(err)
|
|
}
|
|
|
|
err = rows.Err()
|
|
if err != nil {
|
|
lg.WarnIfFuncError(d.log, lgm.CloseDBRows, rows.Close)
|
|
return nil, errw(err)
|
|
}
|
|
|
|
err = rows.Close()
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
return colTypes, nil
|
|
}
|
|
|
|
// RecordMeta implements driver.SQLDriver.
|
|
func (d *driveri) RecordMeta(ctx context.Context, colTypes []*sql.ColumnType) (
|
|
record.Meta, driver.NewRecordFunc, error,
|
|
) {
|
|
sColTypeData := make([]*record.ColumnTypeData, len(colTypes))
|
|
ogColNames := make([]string, len(colTypes))
|
|
for i, colType := range colTypes {
|
|
knd := kindFromDBTypeName(d.log, colType.Name(), colType.DatabaseTypeName())
|
|
colTypeData := record.NewColumnTypeData(colType, knd)
|
|
setScanType(colTypeData, knd)
|
|
sColTypeData[i] = colTypeData
|
|
ogColNames[i] = colTypeData.Name
|
|
}
|
|
|
|
mungedColNames, err := driver.MungeResultColNames(ctx, ogColNames)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
recMeta := make(record.Meta, len(colTypes))
|
|
for i := range sColTypeData {
|
|
recMeta[i] = record.NewFieldMeta(sColTypeData[i], mungedColNames[i])
|
|
}
|
|
|
|
mungeFn := func(vals []any) (record.Record, error) {
|
|
// sqlserver doesn't need to do any special munging, so we
|
|
// just use the default munging.
|
|
rec, skipped := driver.NewRecordFromScanRow(recMeta, vals, nil)
|
|
if len(skipped) > 0 {
|
|
return nil, errz.Errorf("expected zero skipped cols but have %d", skipped)
|
|
}
|
|
return rec, nil
|
|
}
|
|
|
|
return recMeta, mungeFn, nil
|
|
}
|
|
|
|
// TableExists implements driver.SQLDriver.
|
|
func (d *driveri) TableExists(ctx context.Context, db sqlz.DB, tbl string) (bool, error) {
|
|
const query = `SELECT COUNT(*) FROM information_schema.tables
|
|
WHERE table_schema = schema_name() AND table_name = @p1`
|
|
|
|
var count int64
|
|
if err := db.QueryRowContext(ctx, query, tbl).Scan(&count); err != nil {
|
|
return false, errw(err)
|
|
}
|
|
|
|
return count == 1, nil
|
|
}
|
|
|
|
// CurrentSchema implements driver.SQLDriver.
|
|
func (d *driveri) CurrentSchema(ctx context.Context, db sqlz.DB) (string, error) {
|
|
var name string
|
|
if err := db.QueryRowContext(ctx, `SELECT SCHEMA_NAME()`).Scan(&name); err != nil {
|
|
return "", errw(err)
|
|
}
|
|
|
|
return name, nil
|
|
}
|
|
|
|
// ListSchemas implements driver.SQLDriver.
|
|
func (d *driveri) ListSchemas(ctx context.Context, db sqlz.DB) ([]string, error) {
|
|
log := lg.FromContext(ctx)
|
|
|
|
const q = `SELECT name FROM sys.schemas ORDER BY name`
|
|
var schemas []string
|
|
rows, err := db.QueryContext(ctx, q)
|
|
if err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)
|
|
|
|
for rows.Next() {
|
|
var schma string
|
|
if err = rows.Scan(&schma); err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
schemas = append(schemas, schma)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, errz.Err(err)
|
|
}
|
|
|
|
return schemas, nil
|
|
}
|
|
|
|
// SchemaExists implements driver.SQLDriver.
|
|
func (d *driveri) SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) {
|
|
if schma == "" {
|
|
return false, nil
|
|
}
|
|
|
|
const q = `SELECT COUNT(SCHEMA_NAME) FROM INFORMATION_SCHEMA.SCHEMATA
|
|
WHERE SCHEMA_NAME = @p1 AND CATALOG_NAME = DB_NAME()`
|
|
|
|
var count int
|
|
return count > 0, errw(db.QueryRowContext(ctx, q, schma).Scan(&count))
|
|
}
|
|
|
|
// ListSchemaMetadata implements driver.SQLDriver.
|
|
func (d *driveri) ListSchemaMetadata(ctx context.Context, db sqlz.DB) ([]*metadata.Schema, error) {
|
|
log := lg.FromContext(ctx)
|
|
|
|
const q = `SELECT schema_name, catalog_name, schema_owner FROM information_schema.schemata
|
|
WHERE catalog_name = DB_NAME()
|
|
ORDER BY schema_name`
|
|
var schemas []*metadata.Schema
|
|
rows, err := db.QueryContext(ctx, q)
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)
|
|
|
|
var name string
|
|
var catalog, owner sql.NullString
|
|
|
|
for rows.Next() {
|
|
if err = rows.Scan(&name, &catalog, &owner); err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
s := &metadata.Schema{
|
|
Name: name,
|
|
Catalog: catalog.String,
|
|
Owner: owner.String,
|
|
}
|
|
|
|
schemas = append(schemas, s)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
return schemas, nil
|
|
}
|
|
|
|
// CurrentCatalog implements driver.SQLDriver.
|
|
func (d *driveri) CurrentCatalog(ctx context.Context, db sqlz.DB) (string, error) {
|
|
var name string
|
|
if err := db.QueryRowContext(ctx, `SELECT DB_NAME()`).Scan(&name); err != nil {
|
|
return "", errw(err)
|
|
}
|
|
|
|
return name, nil
|
|
}
|
|
|
|
// CatalogExists implements driver.SQLDriver.
|
|
func (d *driveri) CatalogExists(ctx context.Context, db sqlz.DB, catalog string) (bool, error) {
|
|
if catalog == "" {
|
|
return false, nil
|
|
}
|
|
|
|
const q = `SELECT COUNT(name) FROM sys.databases WHERE name = @p1`
|
|
|
|
var count int
|
|
return count > 0, errw(db.QueryRowContext(ctx, q, catalog).Scan(&count))
|
|
}
|
|
|
|
// ListCatalogs implements driver.SQLDriver.
|
|
func (d *driveri) ListCatalogs(ctx context.Context, db sqlz.DB) ([]string, error) {
|
|
catalogs := make([]string, 1, 3)
|
|
if err := db.QueryRowContext(ctx, `SELECT DB_NAME()`).Scan(&catalogs[0]); err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
const q = `SELECT name FROM sys.databases WHERE name != DB_NAME() ORDER BY name`
|
|
|
|
rows, err := db.QueryContext(ctx, q)
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
defer lg.WarnIfCloseError(lg.FromContext(ctx), lgm.CloseDBRows, rows)
|
|
|
|
for rows.Next() {
|
|
var catalog string
|
|
if err = rows.Scan(&catalog); err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
catalogs = append(catalogs, catalog)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
return catalogs, nil
|
|
}
|
|
|
|
// CreateSchema implements driver.SQLDriver.
|
|
func (d *driveri) CreateSchema(ctx context.Context, db sqlz.DB, schemaName string) error {
|
|
stmt := `CREATE SCHEMA ` + stringz.DoubleQuote(schemaName)
|
|
if _, err := db.ExecContext(ctx, stmt); err != nil {
|
|
return errz.Wrapf(err, "failed to create schema {%s}", schemaName)
|
|
}
|
|
|
|
lg.FromContext(ctx).Debug("Created schema", lga.Schema, schemaName)
|
|
return nil
|
|
}
|
|
|
|
// DropSchema implements driver.SQLDriver.
|
|
func (d *driveri) DropSchema(ctx context.Context, db sqlz.DB, schemaName string) error {
|
|
dropObjectsStmt := genDropSchemaObjectsStmt(schemaName)
|
|
|
|
if _, err := db.ExecContext(ctx, dropObjectsStmt); err != nil {
|
|
return errz.Wrapf(err, "failed to drop objects in schema {%s}", schemaName)
|
|
}
|
|
|
|
dropSchemaStmt := `DROP SCHEMA [` + schemaName + `]`
|
|
if _, err := db.ExecContext(ctx, dropSchemaStmt); err != nil {
|
|
return errz.Wrapf(err, "failed to drop schema {%s}", schemaName)
|
|
}
|
|
|
|
lg.FromContext(ctx).Debug("Dropped schema", lga.Schema, schemaName)
|
|
return nil
|
|
}
|
|
|
|
// ListTableNames implements driver.SQLDriver.
|
|
func (d *driveri) ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) {
|
|
var tblClause string
|
|
|
|
switch {
|
|
case tables && views:
|
|
tblClause = " AND (TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW')"
|
|
case tables:
|
|
tblClause = " AND TABLE_TYPE = 'BASE TABLE'"
|
|
case views:
|
|
tblClause = " AND TABLE_TYPE = 'VIEW'"
|
|
default:
|
|
return []string{}, nil
|
|
}
|
|
|
|
var args []any
|
|
q := "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = "
|
|
if schma == "" {
|
|
q += "SCHEMA_NAME()"
|
|
} else {
|
|
q += "@p1"
|
|
args = append(args, schma)
|
|
}
|
|
q += tblClause + " ORDER BY TABLE_NAME"
|
|
|
|
rows, err := db.QueryContext(ctx, q, args...)
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
names, err := sqlz.RowsScanColumn[string](ctx, rows)
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
return names, nil
|
|
}
|
|
|
|
// CreateTable implements driver.SQLDriver.
|
|
func (d *driveri) CreateTable(ctx context.Context, db sqlz.DB, tblDef *schema.Table) error {
|
|
stmt := buildCreateTableStmt(tblDef)
|
|
|
|
_, err := db.ExecContext(ctx, stmt)
|
|
return errw(err)
|
|
}
|
|
|
|
// AlterTableAddColumn implements driver.SQLDriver.
|
|
func (d *driveri) AlterTableAddColumn(ctx context.Context, db sqlz.DB, tbl, col string, knd kind.Kind) error {
|
|
q := fmt.Sprintf("ALTER TABLE %q ADD %q ", tbl, col) + dbTypeNameFromKind(knd)
|
|
|
|
_, err := db.ExecContext(ctx, q)
|
|
return errz.Wrapf(errw(err), "alter table: failed to add column %q to table %q", col, tbl)
|
|
}
|
|
|
|
// AlterTableRename implements driver.SQLDriver.
|
|
func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error {
|
|
schma, err := d.CurrentSchema(ctx, db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
q := fmt.Sprintf(`exec sp_rename '[%s].[%s]', '%s'`, schma, tbl, newName)
|
|
_, err = db.ExecContext(ctx, q)
|
|
return errz.Wrapf(errw(err), "alter table: failed to rename table %q to %q", tbl, newName)
|
|
}
|
|
|
|
// AlterTableRenameColumn implements driver.SQLDriver.
|
|
func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, col, newName string) error {
|
|
schma, err := d.CurrentSchema(ctx, db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
q := fmt.Sprintf(`exec sp_rename '[%s].[%s].[%s]', '%s'`, schma, tbl, col, newName)
|
|
_, err = db.ExecContext(ctx, q)
|
|
return errz.Wrapf(errw(err), "alter table: failed to rename column {%s.%s.%s} to {%s}", schma, tbl, col, newName)
|
|
}
|
|
|
|
// CopyTable implements driver.SQLDriver.
|
|
func (d *driveri) CopyTable(ctx context.Context, db sqlz.DB,
|
|
fromTable, toTable tablefq.T, copyData bool,
|
|
) (int64, error) {
|
|
var stmt string
|
|
|
|
if copyData {
|
|
stmt = fmt.Sprintf("SELECT * INTO %s FROM %s", tblfmt(toTable), tblfmt(fromTable))
|
|
} else {
|
|
stmt = fmt.Sprintf("SELECT TOP(0) * INTO %s FROM %s", tblfmt(toTable), tblfmt(fromTable))
|
|
}
|
|
|
|
affected, err := sqlz.ExecAffected(ctx, db, stmt)
|
|
if err != nil {
|
|
return 0, errw(err)
|
|
}
|
|
|
|
return affected, nil
|
|
}
|
|
|
|
// DropTable implements driver.SQLDriver.
|
|
func (d *driveri) DropTable(ctx context.Context, db sqlz.DB, tbl tablefq.T, ifExists bool) error {
|
|
var stmt string
|
|
|
|
// We don't want the catalog for this part.
|
|
tbl.Catalog = ""
|
|
tblID := tblfmt(tbl)
|
|
|
|
if ifExists {
|
|
stmt = fmt.Sprintf("IF OBJECT_ID('%s', 'U') IS NOT NULL DROP TABLE %s", tblID, tblID)
|
|
} else {
|
|
stmt = fmt.Sprintf("DROP TABLE %s", tblID)
|
|
}
|
|
|
|
_, err := db.ExecContext(ctx, stmt)
|
|
return errw(err)
|
|
}
|
|
|
|
// PrepareInsertStmt implements driver.SQLDriver.
|
|
func (d *driveri) PrepareInsertStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string,
|
|
numRows int,
|
|
) (*driver.StmtExecer, error) {
|
|
destColsMeta, err := d.getTableColsMeta(ctx, db, destTbl, destColNames)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
stmt, err := driver.PrepareInsertStmt(ctx, d, db, destTbl, destColsMeta.Names(), numRows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
execer := driver.NewStmtExecer(stmt, driver.DefaultInsertMungeFunc(destTbl, destColsMeta),
|
|
newStmtExecFunc(stmt, db, destTbl), destColsMeta)
|
|
return execer, nil
|
|
}
|
|
|
|
// PrepareUpdateStmt implements driver.SQLDriver.
|
|
func (d *driveri) PrepareUpdateStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string,
|
|
where string,
|
|
) (*driver.StmtExecer, error) {
|
|
destColsMeta, err := d.getTableColsMeta(ctx, db, destTbl, destColNames)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query, err := buildUpdateStmt(destTbl, destColNames, where)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
stmt, err := db.PrepareContext(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
execer := driver.NewStmtExecer(stmt, driver.DefaultInsertMungeFunc(destTbl, destColsMeta),
|
|
newStmtExecFunc(stmt, db, destTbl), destColsMeta)
|
|
return execer, nil
|
|
}
|
|
|
|
func (d *driveri) getTableColsMeta(ctx context.Context, db sqlz.DB, tblName string, colNames []string) (
|
|
record.Meta, error,
|
|
) {
|
|
// SQLServer has this unusual incantation for its LIMIT equivalent:
|
|
//
|
|
// SELECT username, email, address_id FROM person
|
|
// ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT 1 ROWS ONLY;
|
|
const queryTpl = "SELECT %s FROM %s ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT 1 ROWS ONLY"
|
|
|
|
enquote := d.Dialect().Enquote
|
|
tblNameQuoted := enquote(tblName)
|
|
colNamesQuoted := langz.Apply(colNames, enquote)
|
|
colsJoined := strings.Join(colNamesQuoted, driver.Comma)
|
|
|
|
query := fmt.Sprintf(queryTpl, colsJoined, tblNameQuoted)
|
|
rows, err := db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
colTypes, err := rows.ColumnTypes()
|
|
if err != nil {
|
|
lg.WarnIfFuncError(d.log, lgm.CloseDBRows, rows.Close)
|
|
return nil, errw(err)
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
return nil, errw(rows.Err())
|
|
}
|
|
|
|
destCols, _, err := d.RecordMeta(ctx, colTypes)
|
|
if err != nil {
|
|
lg.WarnIfFuncError(d.log, lgm.CloseDBRows, rows.Close)
|
|
return nil, errw(err)
|
|
}
|
|
|
|
err = rows.Close()
|
|
if err != nil {
|
|
return nil, errw(err)
|
|
}
|
|
|
|
return destCols, nil
|
|
}
|
|
|
|
// newStmtExecFunc returns a StmtExecFunc that has logic to deal with
|
|
// the "identity insert" error. If the error is encountered, setIdentityInsert
|
|
// is called and stmt is executed again.
|
|
func newStmtExecFunc(stmt *sql.Stmt, db sqlz.DB, tbl string) driver.StmtExecFunc {
|
|
return func(ctx context.Context, args ...any) (int64, error) {
|
|
res, err := stmt.ExecContext(ctx, args...)
|
|
if err == nil {
|
|
var affected int64
|
|
affected, err = res.RowsAffected()
|
|
return affected, errw(err)
|
|
}
|
|
|
|
if !hasErrCode(err, errCodeIdentityInsert) {
|
|
return 0, errw(err)
|
|
}
|
|
|
|
idErr := setIdentityInsert(ctx, db, tbl, true)
|
|
if idErr != nil {
|
|
return 0, errz.Append(errw(err), idErr)
|
|
}
|
|
|
|
res, err = stmt.ExecContext(ctx, args...)
|
|
if err != nil {
|
|
return 0, errw(err)
|
|
}
|
|
|
|
affected, err := res.RowsAffected()
|
|
return affected, errw(err)
|
|
}
|
|
}
|
|
|
|
// setIdentityInsert enables (or disables) "identity insert" for tbl on db.
|
|
// SQLServer is fussy about inserting values to the identity col. This
|
|
// error can be returned from the driver:
|
|
//
|
|
// mssql: Cannot insert explicit value for identity column in table 'payment' when IDENTITY_INSERT is set to OFF
|
|
//
|
|
// The solution is "SET IDENTITY_INSERT tbl ON".
|
|
//
|
|
// See: https://docs.microsoft.com/en-us/sql/t-sql/statements/set-identity-insert-transact-sql?view=sql-server-ver15
|
|
func setIdentityInsert(ctx context.Context, db sqlz.DB, tbl string, on bool) error {
|
|
mode := "ON"
|
|
if !on {
|
|
mode = "OFF"
|
|
}
|
|
|
|
query := fmt.Sprintf("SET IDENTITY_INSERT %s %s", tblfmt(tbl), mode)
|
|
_, err := db.ExecContext(ctx, query)
|
|
return errz.Wrapf(errw(err), "failed to SET IDENTITY INSERT %s %s", tblfmt(tbl), mode)
|
|
}
|
|
|
|
// tblfmt formats a table name for use in a query. The arg can be a string,
|
|
// or a tablefq.T.
|
|
func tblfmt[T string | tablefq.T](tbl T) string {
|
|
tfq := tablefq.From(tbl)
|
|
return tfq.Render(stringz.DoubleQuote)
|
|
}
|
|
|
|
// genDropSchemaObjectsStmt generates a T-SQL batch that drops all objects
// in the named schema. It is used by driveri.DropSchema, and is needed
// because SQL Server doesn't support "DROP SCHEMA [NAME] CASCADE".
//
// Objects are dropped in this order:
//  1. procedures
//  2. check constraints
//  3. functions (FN, IF, TF)
//  4. views
//  5. foreign keys
//  6. tables
//  7. user-defined types
//
// The script may not be comprehensive; there could be other object kinds
// not handled here. But it works on everything tested so far.
//
// schemaName is embedded in a T-SQL string literal, so single quotes are
// doubled. The variable is declared as sysname (nvarchar(128)) so that long
// names are not truncated. Object names are delimited with quotename, which
// escapes any embedded "]".
//
// See: https://stackoverflow.com/a/8150428
//
//nolint:lll
func genDropSchemaObjectsStmt(schemaName string) string {
	const tpl = `
declare @SchemaName sysname = '%s'
declare @SchemaID int = schema_id(@SchemaName)

declare @n char(1)
set @n = char(10)
declare @stmt nvarchar(max)

-- procedures
select @stmt = isnull( @stmt + @n, '' ) +
'drop procedure ' + quotename(@SchemaName) + '.' + quotename(name)
from sys.procedures where schema_id = @SchemaID

-- check constraints
select @stmt = isnull( @stmt + @n, '' ) +
'alter table ' + quotename(@SchemaName) + '.' + quotename(object_name( parent_object_id )) + ' drop constraint ' + quotename(name)
from sys.check_constraints where schema_id = @SchemaID

-- functions
select @stmt = isnull( @stmt + @n, '' ) +
'drop function ' + quotename(@SchemaName) + '.' + quotename(name)
from sys.objects
where schema_id = @SchemaID and type in ( 'FN', 'IF', 'TF' )
--
-- views
select @stmt = isnull( @stmt + @n, '' ) +
'drop view ' + quotename(@SchemaName) + '.' + quotename(name)
from sys.views where schema_id = @SchemaID
--
-- foreign keys
select @stmt = isnull( @stmt + @n, '' ) +
'alter table ' + quotename(@SchemaName) + '.' + quotename(object_name( parent_object_id )) + ' drop constraint ' + quotename(name)
from sys.foreign_keys where schema_id = @SchemaID

-- tables
select @stmt = isnull( @stmt + @n, '' ) +
'drop table ' + quotename(@SchemaName) + '.' + quotename(name)
from sys.tables where schema_id = @SchemaID

-- user defined types
select @stmt = isnull( @stmt + @n, '' ) +
'drop type ' + quotename(@SchemaName) + '.' + quotename(name)
from sys.types
where schema_id = @SchemaID and is_user_defined = 1

exec sp_executesql @stmt
`

	// Double single quotes so the schema name stays a well-formed literal.
	return fmt.Sprintf(tpl, strings.ReplaceAll(schemaName, "'", "''"))
}
|