// Package sqlserver implements the sq driver for SQL Server.
package sqlserver

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strconv"
	"strings"

	"github.com/neilotoole/sq/libsq/driver/dialect"

	"github.com/neilotoole/sq/libsq/core/lg/lga"

	"github.com/neilotoole/sq/libsq/core/lg/lgm"

	"github.com/neilotoole/sq/libsq/core/lg"

	"golang.org/x/exp/slog"

	mssql "github.com/microsoft/go-mssqldb"

	"github.com/neilotoole/sq/libsq/ast/render"
	"github.com/neilotoole/sq/libsq/core/errz"
	"github.com/neilotoole/sq/libsq/core/kind"
	"github.com/neilotoole/sq/libsq/core/sqlmodel"
	"github.com/neilotoole/sq/libsq/core/sqlz"
	"github.com/neilotoole/sq/libsq/core/stringz"
	"github.com/neilotoole/sq/libsq/driver"
	"github.com/neilotoole/sq/libsq/source"
)

const (
	// Type is the SQL Server source driver type.
	Type = source.Type("sqlserver")

	// dbDrvr is the backing SQL Server driver impl name.
	dbDrvr = "sqlserver"
)

var _ driver.Provider = (*Provider)(nil)

// Provider is the SQL Server implementation of driver.Provider.
type Provider struct {
	Log *slog.Logger
}

// DriverFor implements driver.Provider.
func (p *Provider) DriverFor(typ source.Type) (driver.Driver, error) {
	if typ != Type {
		return nil, errz.Errorf("unsupported driver type %q", typ)
	}

	return &driveri{log: p.Log}, nil
}

var _ driver.SQLDriver = (*driveri)(nil)

// driveri is the SQL Server implementation of driver.Driver.
type driveri struct {
	log *slog.Logger
}

// DriverMetadata implements driver.SQLDriver.
func (d *driveri) DriverMetadata() driver.Metadata {
	return driver.Metadata{
		Type:        Type,
		Description: "Microsoft SQL Server / Azure SQL Edge",
		Doc:         "https://github.com/microsoft/go-mssqldb",
		IsSQL:       true,
	}
}

// Dialect implements driver.SQLDriver.
func (d *driveri) Dialect() dialect.Dialect {
	return dialect.Dialect{
		Type:           Type,
		Placeholders:   placeholders,
		IdentQuote:     '"',
		Enquote:        stringz.DoubleQuote,
		MaxBatchValues: 1000,
		Ops:            dialect.DefaultOps(),
	}
}

func placeholders(numCols, numRows int) string {
	rows := make([]string, numRows)

	n := 1
	var sb strings.Builder
	for i := 0; i < numRows; i++ {
		sb.Reset()
		sb.WriteRune('(')
		for j := 1; j <= numCols; j++ {
			sb.WriteString("@p")
			sb.WriteString(strconv.Itoa(n))
			n++
			if j < numCols {
				sb.WriteString(driver.Comma)
			}
		}
		sb.WriteRune(')')
		rows[i] = sb.String()
	}

	return strings.Join(rows, driver.Comma)
}

// Renderer implements driver.SQLDriver.
func (d *driveri) Renderer() *render.Renderer {
	r := render.NewDefaultRenderer()

	// Custom functions for SQLServer specific stuff.
	r.Range = renderRange
	r.PreRender = preRender

	return r
}

// Open implements driver.Driver.
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Database, error) {
	db, err := sql.Open(dbDrvr, src.Location)
	if err != nil {
		return nil, errz.Err(err)
	}

	err = db.PingContext(ctx)
	if err != nil {
		lg.WarnIfCloseError(d.log, lgm.CloseDB, db)
		return nil, errz.Err(err)
	}

	return &database{log: d.log, db: db, src: src, drvr: d}, nil
}

// ValidateSource implements driver.Driver.
func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) {
	if src.Type != Type {
		return nil, errz.Errorf("expected source type %q but got %q", Type, src.Type)
	}
	return src, nil
}

// Ping implements driver.Driver.
func (d *driveri) Ping(ctx context.Context, src *source.Source) error {
	db, err := sql.Open(dbDrvr, src.Location)
	if err != nil {
		return errz.Err(err)
	}

	defer lg.WarnIfCloseError(d.log, lgm.CloseDB, db)

	err = db.PingContext(ctx)
	return errz.Err(err)
}

// Truncate implements driver.Driver. Due to a quirk of SQL Server, the
// operation is implemented in two statements. First "DELETE FROM tbl" to
// delete all rows. Then, if reset is true, the table sequence counter
// is reset via RESEED.
//
//nolint:lll
func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string, reset bool) (affected int64,
	err error,
) {
	// https://docs.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15

	// When there are foreign key constraints on mssql tables,
	// it's not possible to TRUNCATE the table. An alternative is
	// to delete all rows and reseed the identity column.
	//
	//  DELETE FROM "table1"; DBCC CHECKIDENT ('table1', RESEED, 1);
	//
	// See: https://stackoverflow.com/questions/253849/cannot-truncate-table-because-it-is-being-referenced-by-a-foreign-key-constraint

	db, err := sql.Open(dbDrvr, src.Location)
	if err != nil {
		return 0, errz.Err(err)
	}
	defer lg.WarnIfFuncError(d.log, lgm.CloseDB, db.Close)

	affected, err = sqlz.ExecAffected(ctx, db, fmt.Sprintf("DELETE FROM %q", tbl))
	if err != nil {
		return affected, errz.Wrapf(err, "truncate: failed to delete from %q", tbl)
	}

	if reset {
		_, err = db.ExecContext(ctx, fmt.Sprintf("DBCC CHECKIDENT ('%s', RESEED, 1)", tbl))
		if err != nil {
			return affected, errz.Wrapf(err, "truncate: deleted %d rows from %q but RESEED failed", affected, tbl)
		}
	}

	return affected, nil
}

// TableColumnTypes implements driver.SQLDriver.
func (d *driveri) TableColumnTypes(ctx context.Context, db sqlz.DB, tblName string,
	colNames []string,
) ([]*sql.ColumnType, error) {
	// SQLServer has this unusual incantation for its LIMIT equivalent:
	//
	// SELECT username, email, address_id FROM person
	// ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT 1 ROWS ONLY;
	const queryTpl = "SELECT %s FROM %s ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT 1 ROWS ONLY"

	dialect := d.Dialect()
	quote := string(dialect.IdentQuote)
	tblNameQuoted := stringz.Surround(tblName, quote)

	colsClause := "*"
	if len(colNames) > 0 {
		colNamesQuoted := stringz.SurroundSlice(colNames, quote)
		colsClause = strings.Join(colNamesQuoted, driver.Comma)
	}

	query := fmt.Sprintf(queryTpl, colsClause, tblNameQuoted)
	rows, err := db.QueryContext(ctx, query)
	if err != nil {
		return nil, errz.Err(err)
	}

	colTypes, err := rows.ColumnTypes()
	if err != nil {
		lg.WarnIfFuncError(d.log, lgm.CloseDBRows, rows.Close)
		return nil, errz.Err(err)
	}

	err = rows.Err()
	if err != nil {
		lg.WarnIfFuncError(d.log, lgm.CloseDBRows, rows.Close)
		return nil, errz.Err(err)
	}

	err = rows.Close()
	if err != nil {
		return nil, errz.Err(err)
	}

	return colTypes, nil
}

// RecordMeta implements driver.SQLDriver.
func (d *driveri) RecordMeta(colTypes []*sql.ColumnType) (sqlz.RecordMeta, driver.NewRecordFunc, error) {
	recMeta := make([]*sqlz.FieldMeta, len(colTypes))
	for i, colType := range colTypes {
		kind := kindFromDBTypeName(d.log, colType.Name(), colType.DatabaseTypeName())
		colTypeData := sqlz.NewColumnTypeData(colType, kind)
		setScanType(colTypeData, kind)
		recMeta[i] = sqlz.NewFieldMeta(colTypeData)
	}

	mungeFn := func(vals []any) (sqlz.Record, error) {
		// sqlserver doesn't need to do any special munging, so we
		// just use the default munging.
		rec, skipped := driver.NewRecordFromScanRow(recMeta, vals, nil)
		if len(skipped) > 0 {
			return nil, errz.Errorf("expected zero skipped cols but have %d", skipped)
		}
		return rec, nil
	}

	return recMeta, mungeFn, nil
}

// TableExists implements driver.SQLDriver.
func (d *driveri) TableExists(ctx context.Context, db sqlz.DB, tbl string) (bool, error) {
	const query = `SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'dbo' AND table_name = @p1`

	var count int64
	err := db.QueryRowContext(ctx, query, tbl).Scan(&count)
	if err != nil {
		return false, errz.Err(err)
	}

	return count == 1, nil
}

// CurrentSchema implements driver.SQLDriver.
func (d *driveri) CurrentSchema(ctx context.Context, db sqlz.DB) (string, error) {
	var name string
	if err := db.QueryRowContext(ctx, `SELECT SCHEMA_NAME()`).Scan(&name); err != nil {
		return "", errz.Err(err)
	}

	return name, nil
}

// CreateTable implements driver.SQLDriver.
func (d *driveri) CreateTable(ctx context.Context, db sqlz.DB, tblDef *sqlmodel.TableDef) error {
	stmt := buildCreateTableStmt(tblDef)

	_, err := db.ExecContext(ctx, stmt)
	return errz.Err(err)
}

// AlterTableAddColumn implements driver.SQLDriver.
func (d *driveri) AlterTableAddColumn(ctx context.Context, db sqlz.DB, tbl, col string, knd kind.Kind) error {
	q := fmt.Sprintf("ALTER TABLE %q ADD %q ", tbl, col) + dbTypeNameFromKind(knd)

	_, err := db.ExecContext(ctx, q)
	return errz.Wrapf(err, "alter table: failed to add column %q to table %q", col, tbl)
}

// AlterTableRename implements driver.SQLDriver.
func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error {
	schema, err := d.CurrentSchema(ctx, db)
	if err != nil {
		return err
	}

	q := fmt.Sprintf(`exec sp_rename '[%s].[%s]', '%s'`, schema, tbl, newName)
	_, err = db.ExecContext(ctx, q)
	return errz.Wrapf(err, "alter table: failed to rename table %q to %q", tbl, newName)
}

// AlterTableRenameColumn implements driver.SQLDriver.
func (d *driveri) AlterTableRenameColumn(ctx context.Context, db sqlz.DB, tbl, col, newName string) error {
	schema, err := d.CurrentSchema(ctx, db)
	if err != nil {
		return err
	}

	q := fmt.Sprintf(`exec sp_rename '[%s].[%s].[%s]', '%s'`, schema, tbl, col, newName)
	_, err = db.ExecContext(ctx, q)
	return errz.Wrapf(err, "alter table: failed to rename column {%s.%s.%s} to {%s}", schema, tbl, col, newName)
}

// CopyTable implements driver.SQLDriver.
func (d *driveri) CopyTable(ctx context.Context, db sqlz.DB, fromTable, toTable string, copyData bool) (int64, error) {
	var stmt string

	if copyData {
		stmt = fmt.Sprintf("SELECT * INTO %q FROM %q", toTable, fromTable)
	} else {
		stmt = fmt.Sprintf("SELECT TOP(0) * INTO %q FROM %q", toTable, fromTable)
	}

	affected, err := sqlz.ExecAffected(ctx, db, stmt)
	if err != nil {
		return 0, errz.Err(err)
	}

	return affected, nil
}

// DropTable implements driver.SQLDriver.
func (d *driveri) DropTable(ctx context.Context, db sqlz.DB, tbl string, ifExists bool) error {
	var stmt string

	if ifExists {
		stmt = fmt.Sprintf("IF OBJECT_ID('dbo.%s', 'U') IS NOT NULL DROP TABLE dbo.%q", tbl, tbl)
	} else {
		stmt = fmt.Sprintf("DROP TABLE dbo.%q", tbl)
	}

	_, err := db.ExecContext(ctx, stmt)
	return errz.Err(err)
}

// PrepareInsertStmt implements driver.SQLDriver.
func (d *driveri) PrepareInsertStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string,
	numRows int,
) (*driver.StmtExecer, error) {
	destColsMeta, err := d.getTableColsMeta(ctx, db, destTbl, destColNames)
	if err != nil {
		return nil, err
	}

	stmt, err := driver.PrepareInsertStmt(ctx, d, db, destTbl, destColsMeta.Names(), numRows)
	if err != nil {
		return nil, err
	}

	execer := driver.NewStmtExecer(stmt, driver.DefaultInsertMungeFunc(destTbl, destColsMeta),
		newStmtExecFunc(stmt, db, destTbl), destColsMeta)
	return execer, nil
}

// PrepareUpdateStmt implements driver.SQLDriver.
func (d *driveri) PrepareUpdateStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string,
	where string,
) (*driver.StmtExecer, error) {
	destColsMeta, err := d.getTableColsMeta(ctx, db, destTbl, destColNames)
	if err != nil {
		return nil, err
	}

	query, err := buildUpdateStmt(destTbl, destColNames, where)
	if err != nil {
		return nil, err
	}

	stmt, err := db.PrepareContext(ctx, query)
	if err != nil {
		return nil, err
	}

	execer := driver.NewStmtExecer(stmt, driver.DefaultInsertMungeFunc(destTbl, destColsMeta),
		newStmtExecFunc(stmt, db, destTbl), destColsMeta)
	return execer, nil
}

func (d *driveri) getTableColsMeta(ctx context.Context, db sqlz.DB, tblName string, colNames []string) (sqlz.RecordMeta,
	error,
) {
	// SQLServer has this unusual incantation for its LIMIT equivalent:
	//
	// SELECT username, email, address_id FROM person
	// ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT 1 ROWS ONLY;
	const queryTpl = "SELECT %s FROM %s ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT 1 ROWS ONLY"

	dialect := d.Dialect()
	quote := string(dialect.IdentQuote)
	tblNameQuoted := stringz.Surround(tblName, quote)
	colNamesQuoted := stringz.SurroundSlice(colNames, quote)
	colsJoined := strings.Join(colNamesQuoted, driver.Comma)

	query := fmt.Sprintf(queryTpl, colsJoined, tblNameQuoted)
	rows, err := db.QueryContext(ctx, query)
	if err != nil {
		return nil, errz.Err(err)
	}

	colTypes, err := rows.ColumnTypes()
	if err != nil {
		lg.WarnIfFuncError(d.log, lgm.CloseDBRows, rows.Close)
		return nil, errz.Err(err)
	}

	if rows.Err() != nil {
		return nil, errz.Err(rows.Err())
	}

	destCols, _, err := d.RecordMeta(colTypes)
	if err != nil {
		lg.WarnIfFuncError(d.log, lgm.CloseDBRows, rows.Close)
		return nil, errz.Err(err)
	}

	err = rows.Close()
	if err != nil {
		return nil, errz.Err(err)
	}

	return destCols, nil
}

// database implements driver.Database.
type database struct {
	log  *slog.Logger
	drvr *driveri
	db   *sql.DB
	src  *source.Source
}

var _ driver.Database = (*database)(nil)

// DB implements driver.Database.
func (d *database) DB() *sql.DB {
	return d.db
}

// SQLDriver implements driver.Database.
func (d *database) SQLDriver() driver.SQLDriver {
	return d.drvr
}

// Source implements driver.Database.
func (d *database) Source() *source.Source {
	return d.src
}

// TableMetadata implements driver.Database.
func (d *database) TableMetadata(ctx context.Context, tblName string) (*source.TableMetadata, error) {
	const query = `SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_TYPE
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_NAME = @p1`

	var catalog, schema, tblType string
	err := d.db.QueryRowContext(ctx, query, tblName).Scan(&catalog, &schema, &tblType)
	if err != nil {
		return nil, errz.Err(err)
	}

	return getTableMetadata(ctx, d.db, catalog, schema, tblName, tblType)
}

// SourceMetadata implements driver.Database.
func (d *database) SourceMetadata(ctx context.Context) (*source.Metadata, error) {
	return getSourceMetadata(ctx, d.src, d.db)
}

// Close implements driver.Database.
func (d *database) Close() error {
	d.log.Debug(lgm.CloseDB, lga.Src, d.src)

	return errz.Err(d.db.Close())
}

// newStmtExecFunc returns a StmtExecFunc that has logic to deal with
// the "identity insert" error. If the error is encountered, setIdentityInsert
// is called and stmt is executed again.
func newStmtExecFunc(stmt *sql.Stmt, db sqlz.DB, tbl string) driver.StmtExecFunc {
	return func(ctx context.Context, args ...any) (int64, error) {
		res, err := stmt.ExecContext(ctx, args...)
		if err == nil {
			var affected int64
			affected, err = res.RowsAffected()
			return affected, errz.Err(err)
		}

		if !hasErrCode(err, errCodeIdentityInsert) {
			return 0, errz.Err(err)
		}

		idErr := setIdentityInsert(ctx, db, tbl, true)
		if idErr != nil {
			return 0, errz.Combine(err, idErr)
		}

		res, err = stmt.ExecContext(ctx, args...)
		if err != nil {
			return 0, errz.Err(err)
		}

		affected, err := res.RowsAffected()
		return affected, errz.Err(err)
	}
}

// setIdentityInsert enables (or disables) "identity insert" for tbl on db.
// SQLServer is fussy about inserting values to the identity col. This
// error can be returned from the driver:
//
//	mssql: Cannot insert explicit value for identity column in table 'payment' when IDENTITY_INSERT is set to OFF
//
// The solution is "SET IDENTITY_INSERT tbl ON".
//
// See: https://docs.microsoft.com/en-us/sql/t-sql/statements/set-identity-insert-transact-sql?view=sql-server-ver15
func setIdentityInsert(ctx context.Context, db sqlz.DB, tbl string, on bool) error {
	mode := "ON"
	if !on {
		mode = "OFF"
	}

	query := fmt.Sprintf("SET IDENTITY_INSERT %q %s", tbl, mode)
	_, err := db.ExecContext(ctx, query)
	return errz.Wrapf(err, "failed to SET IDENTITY INSERT %s %s", tbl, mode)
}

// mssql error codes
//
//nolint:lll
const (
	// See: https://docs.microsoft.com/en-us/sql/relational-databases/errors-events/database-engine-events-and-errors?view=sql-server-ver15
	errCodeIdentityInsert int32 = 544
	errCodeObjectNotExist int32 = 15009
)

// hasErrCode returns true if err (or its cause err) is
// of type mssql.Error and err.Number equals code.
func hasErrCode(err error, code int32) bool {
	if err == nil {
		return false
	}

	var msErr mssql.Error
	if errors.As(err, &msErr) {
		return msErr.Number == code
	}

	return false
}