sq/drivers/postgres/metadata.go
Neil O'Toole 791d41ad1d
Removed dependencies on neilotoole/errgroup in favor of sync/errgroup (#176)
* Removed dependencies on neilotoole/errgroup in favor of sync/errgroup

* Removed deadcode
2023-04-02 15:17:15 -06:00

683 lines
19 KiB
Go

package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/lg/lgm"
"github.com/neilotoole/sq/libsq/core/lg"
"golang.org/x/exp/slog"
"github.com/jackc/pgconn"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/kind"
"github.com/neilotoole/sq/libsq/core/sqlz"
"github.com/neilotoole/sq/libsq/driver"
"github.com/neilotoole/sq/libsq/source"
"golang.org/x/sync/errgroup"
)
// kindFromDBTypeName determines the kind.Kind from the database
// type name. For example, "VARCHAR" -> kind.Text.
//
// The match is case-insensitive: dbTypeName is upper-cased before
// it is compared. An empty dbTypeName silently yields kind.Unknown.
// Any other unrecognized name also yields kind.Unknown, but a warning
// naming colName and the type is logged first.
//
// Both the short names (e.g. "TIMESTAMPTZ") and the long SQL-standard
// spellings ("TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITHOUT TIME ZONE")
// of the date/time types are recognized.
//
// See https://www.postgresql.org/docs/9.5/datatype.html
func kindFromDBTypeName(log *slog.Logger, colName, dbTypeName string) kind.Kind {
	dbTypeName = strings.ToUpper(dbTypeName)

	switch dbTypeName {
	case "":
		return kind.Unknown
	case "INT", "INTEGER", "INT2", "INT4", "INT8", "SMALLINT", "BIGINT":
		return kind.Int
	case "CHAR", "CHARACTER", "VARCHAR", "TEXT", "BPCHAR", "CHARACTER VARYING": //nolint:goconst
		return kind.Text
	case "BYTEA":
		return kind.Bytes
	case "BOOL", "BOOLEAN":
		return kind.Bool
	case "TIMESTAMP", "TIMESTAMPTZ", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE":
		return kind.Datetime
	case "TIME", "TIMETZ", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE": //nolint:goconst
		return kind.Time
	case "DATE":
		return kind.Date
	case "INTERVAL": // interval meaning time duration
		return kind.Text
	case "FLOAT", "FLOAT4", "FLOAT8", "DOUBLE", "DOUBLE PRECISION":
		return kind.Float
	case "UUID":
		return kind.Text
	case "DECIMAL", "NUMERIC", "MONEY":
		return kind.Decimal
	case "JSON", "JSONB":
		return kind.Text
	case "BIT", "VARBIT":
		return kind.Text
	case "XML":
		return kind.Text
	case "BOX", "CIRCLE", "LINE", "LSEG", "PATH", "POINT", "POLYGON":
		return kind.Text
	case "CIDR", "INET", "MACADDR":
		return kind.Text
	case "USER-DEFINED":
		// REVISIT: How to handle USER-DEFINED type?
		return kind.Text
	case "TSVECTOR":
		// REVISIT: how to handle TSVECTOR type?
		return kind.Text
	case "ARRAY":
		// REVISIT: how to handle ARRAY type?
		return kind.Text
	default:
		// Truly unknown type: make it visible in the log so that
		// the mapping above can be extended.
		log.Warn(
			"Unknown Postgres column type: using alt type",
			lga.DBType, dbTypeName,
			lga.Col, colName,
			lga.Alt, kind.Unknown,
		)
		return kind.Unknown
	}
}
// setScanType assigns an appropriate scan type to ctd.ScanType,
// based on knd and the scan type ctd currently carries.
func setScanType(log *slog.Logger, ctd *sqlz.ColumnTypeData, knd kind.Kind) {
	switch knd {
	case kind.Decimal:
		// Decimals are always scanned as strings, because the
		// driver will sometimes prefer float.
		ctd.ScanType = sqlz.RTypeNullString
	default:
		// The backing driver doesn't report nullable info accurately,
		// so every other kind is switched to its nullable scan type.
		ctd.ScanType = toNullableScanType(log, ctd.Name, ctd.DatabaseTypeName, knd, ctd.ScanType)
	}
}
// toNullableScanType maps pgScanType (the scan type reported by the
// postgres driver's ColumnType.ScanType) to its nullable equivalent.
// The mapping is needed because the pgx driver does not support the
// stdlib sql driver.RowsColumnTypeNullable interface.
//
// Recognized scan types map to the matching sqlz.RTypeNullXYZ value
// (sqlz.RTypeBytes maps to itself). Any other scan type results in
// sqlz.RTypeNullString; in that case a warning is logged unless
// dbTypeName is one of the names known to arrive this way.
func toNullableScanType(log *slog.Logger, colName, dbTypeName string, knd kind.Kind,
	pgScanType reflect.Type,
) reflect.Type {
	switch pgScanType {
	case sqlz.RTypeInt64, sqlz.RTypeInt, sqlz.RTypeInt8, sqlz.RTypeInt16, sqlz.RTypeInt32, sqlz.RTypeNullInt64:
		return sqlz.RTypeNullInt64
	case sqlz.RTypeFloat32, sqlz.RTypeFloat64, sqlz.RTypeNullFloat64:
		return sqlz.RTypeNullFloat64
	case sqlz.RTypeString, sqlz.RTypeNullString:
		return sqlz.RTypeNullString
	case sqlz.RTypeBool, sqlz.RTypeNullBool:
		return sqlz.RTypeNullBool
	case sqlz.RTypeTime, sqlz.RTypeNullTime:
		return sqlz.RTypeNullTime
	case sqlz.RTypeBytes:
		return sqlz.RTypeBytes
	}

	// The scan type isn't recognized (likely it's any). NullString is
	// used for every such column, but the db type names we know about
	// are still enumerated so that the warning below only fires for
	// truly unknown types.
	//
	// NOTE: the pgx driver currently reports an empty dbTypeName for
	// certain cols such as XML or MONEY.
	switch dbTypeName {
	case "",
		"TIME",
		"BIT", "VARBIT",
		"BPCHAR",
		"BOX", "CIRCLE", "LINE", "LSEG", "PATH", "POINT", "POLYGON",
		"CIDR", "INET", "MACADDR",
		"INTERVAL",
		"JSON", "JSONB",
		"XML",
		"UUID":
		return sqlz.RTypeNullString
	}

	fallback := sqlz.RTypeNullString
	log.Warn("Unknown Postgres scan type",
		lga.Col, colName,
		lga.ScanType, pgScanType,
		lga.DBType, dbTypeName,
		lga.Kind, knd,
		lga.DefaultTo, fallback,
	)
	return fallback
}
// getSourceMetadata collects metadata for src using db. It queries
// the database summary (name, schema, size, version, user), the
// pg_settings values, and then the metadata of each table in the
// current catalog & schema. Table metadata is gathered concurrently,
// bounded by driver.Tuning.ErrgroupNumG.
//
// An error is returned if current_schema() is NULL, if any query
// fails, or if ctx is cancelled before all tables were scheduled.
// A table that disappears during collection (error code
// errCodeRelationNotExist) is logged and omitted from the result
// rather than treated as a failure.
func getSourceMetadata(ctx context.Context, src *source.Source, db sqlz.DB) (*source.Metadata, error) {
	log := lg.FromContext(ctx)

	md := &source.Metadata{
		Handle:       src.Handle,
		Location:     src.Location,
		SourceType:   src.Type,
		DBDriverType: src.Type,
	}

	var schema sql.NullString
	const summaryQuery = `SELECT current_catalog, current_schema(), pg_database_size(current_catalog),
current_setting('server_version'), version(), "current_user"()`

	err := db.QueryRowContext(ctx, summaryQuery).
		Scan(&md.Name, &schema, &md.Size, &md.DBVersion, &md.DBProduct, &md.User)
	if err != nil {
		return nil, errz.Err(err)
	}

	if !schema.Valid {
		return nil, errz.New("NULL value for current_schema(): check privileges and search_path")
	}

	md.Schema = schema.String
	md.FQName = md.Name + "." + schema.String

	md.DBVars, err = getPgSettings(ctx, db)
	if err != nil {
		return nil, err
	}

	tblNames, err := getAllTableNames(ctx, db)
	if err != nil {
		return nil, err
	}

	g, gCtx := errgroup.WithContext(ctx)
	g.SetLimit(driver.Tuning.ErrgroupNumG)
	tblMetas := make([]*source.TableMetadata, len(tblNames))

	// abortErr is set if scheduling stopped early because gCtx was
	// done, either because a goroutine failed or ctx was cancelled.
	var abortErr error
	for i := range tblNames {
		if abortErr = gCtx.Err(); abortErr != nil {
			// Stop scheduling, but do NOT return yet: goroutines that
			// are already running must be waited for, and g.Wait
			// yields the error that actually caused the cancellation.
			break
		}

		i := i
		g.Go(func() error {
			tblMeta, mdErr := getTableMetadata(gCtx, db, tblNames[i])
			if mdErr != nil {
				if hasErrCode(mdErr, errCodeRelationNotExist) {
					// If the table is dropped while we're collecting metadata,
					// for example, we log a warning and suppress the error.
					log.Warn("metadata collection: table not found (continuing regardless)",
						lga.Table, tblNames[i],
						lga.Err, mdErr,
					)
					return nil
				}
				return mdErr
			}

			// Each goroutine writes only to its own index.
			tblMetas[i] = tblMeta
			return nil
		})
	}

	if err = g.Wait(); err != nil {
		return nil, errz.Err(err)
	}

	if abortErr != nil {
		// No goroutine reported an error, so the group's context was
		// done because ctx itself was cancelled or timed out.
		return nil, errz.Err(abortErr)
	}

	// If a table wasn't found (possibly dropped while querying), then
	// its entry could be nil. We copy the non-nil elements to the
	// final slice.
	md.Tables = make([]*source.TableMetadata, 0, len(tblMetas))
	for i := range tblMetas {
		if tblMetas[i] != nil {
			md.Tables = append(md.Tables, tblMetas[i])
		}
	}

	return md, nil
}
// hasErrCode reports whether err, or an error in its chain, is a
// *pgconn.PgError whose Code field equals code. It returns false
// when err is nil.
func hasErrCode(err error, code string) bool {
	var pgErr *pgconn.PgError
	return err != nil && errors.As(err, &pgErr) && pgErr.Code == code
}
// errCodeRelationNotExist is the Postgres error code compared against
// pgconn.PgError.Code (via hasErrCode) to detect that a table does
// not exist, e.g. because it was dropped while metadata was being
// collected. "42P01" is the SQLSTATE for undefined_table.
const errCodeRelationNotExist = "42P01"
// getPgSettings returns the name/setting pairs found in pg_settings,
// ordered by name.
func getPgSettings(ctx context.Context, db sqlz.DB) ([]source.DBVar, error) {
	const settingsQuery = "SELECT name, setting FROM pg_settings ORDER BY name"

	log := lg.FromContext(ctx)

	rows, err := db.QueryContext(ctx, settingsQuery)
	if err != nil {
		return nil, errz.Err(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var settings []source.DBVar
	for rows.Next() {
		var setting source.DBVar
		if err = rows.Scan(&setting.Name, &setting.Value); err != nil {
			return nil, errz.Err(err)
		}

		settings = append(settings, setting)
	}

	// Surface any iteration or close error before returning results.
	if err = closeRows(rows); err != nil {
		return nil, err
	}

	return settings, nil
}
// getAllTableNames returns the names of all tables (or views) listed
// in information_schema.tables for the current catalog & schema,
// ordered by name.
func getAllTableNames(ctx context.Context, db sqlz.DB) ([]string, error) {
	const tblNamesQuery = `SELECT table_name FROM information_schema.tables
WHERE table_catalog = current_catalog AND table_schema = current_schema()
ORDER BY table_name`

	log := lg.FromContext(ctx)

	rows, err := db.QueryContext(ctx, tblNamesQuery)
	if err != nil {
		return nil, errz.Err(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var names []string
	for rows.Next() {
		var name string
		if err = rows.Scan(&name); err != nil {
			return nil, errz.Err(err)
		}

		names = append(names, name)
	}

	// Surface any iteration or close error before returning results.
	if err = closeRows(rows); err != nil {
		return nil, err
	}

	return names, nil
}
// getTableMetadata returns metadata for the table (or view) named
// tblName in the current catalog & schema: the table-level details
// (type, row count, size, comment), its columns, and its primary key
// columns as determined from the table's constraints.
//
// An error is returned if any of the underlying queries fail, or if
// the row returned for the table does not carry the name tblName.
func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*source.TableMetadata, error) {
	log := lg.FromContext(ctx)

	// The first verb takes a quoted identifier; the remaining verbs
	// take that same quoted identifier embedded in a SQL string
	// literal (which Postgres then casts to regclass).
	const tblsQueryTpl = `SELECT table_catalog, table_schema, table_name, table_type, is_insertable_into,
(SELECT COUNT(*) FROM %s) AS table_row_count,
pg_total_relation_size('%s') AS table_size,
(SELECT '%s'::regclass::oid AS table_oid),
obj_description('%s'::REGCLASS, 'pg_class') AS table_comment
FROM information_schema.tables
WHERE table_catalog = current_database()
AND table_schema = current_schema()
AND table_name = $1`

	// The table name can't be passed as a bind parameter where an
	// identifier is required, so it is spliced into the query text.
	// It must therefore be escaped using SQL rules: embedded double
	// quotes are doubled inside the quoted identifier, and embedded
	// single quotes are doubled inside the string literal. Go's %q
	// verb is NOT suitable here, as it applies Go escaping (\", \\).
	//
	// NOTE(review): the literal form assumes the server default of
	// standard_conforming_strings = on, i.e. backslashes are not
	// escape characters in plain string literals.
	quotedIdent := `"` + strings.ReplaceAll(tblName, `"`, `""`) + `"`
	identLiteral := strings.ReplaceAll(quotedIdent, `'`, `''`)

	tablesQuery := fmt.Sprintf(tblsQueryTpl, quotedIdent, identLiteral, identLiteral, identLiteral)

	pgTbl := &pgTable{}
	err := db.QueryRowContext(ctx, tablesQuery, tblName).
		Scan(&pgTbl.tableCatalog, &pgTbl.tableSchema, &pgTbl.tableName, &pgTbl.tableType, &pgTbl.isInsertable,
			&pgTbl.rowCount, &pgTbl.size, &pgTbl.oid, &pgTbl.comment)
	if err != nil {
		return nil, errz.Err(err)
	}

	tblMeta := tblMetaFromPgTable(pgTbl)
	if tblMeta.Name != tblName {
		// Shouldn't happen, but we'll error if it does
		return nil, errz.Errorf("table {%s} not found in %s.%s", tblName, pgTbl.tableCatalog, pgTbl.tableSchema)
	}

	pgCols, err := getPgColumns(ctx, db, tblName)
	if err != nil {
		return nil, err
	}

	for _, pgCol := range pgCols {
		colMeta := colMetaFromPgColumn(log, pgCol)
		tblMeta.Columns = append(tblMeta.Columns, colMeta)
	}

	// We need to fetch the constraints to set the PK etc.
	pgConstraints, err := getPgConstraints(ctx, db, tblName)
	if err != nil {
		return nil, err
	}

	setTblMetaConstraints(log, tblMeta, pgConstraints)

	return tblMeta, nil
}
// pgTable holds query results for table metadata. It is populated
// by the table query in getTableMetadata, and converted to a
// source.TableMetadata by tblMetaFromPgTable.
type pgTable struct {
	tableCatalog string
	tableSchema  string
	tableName    string
	// tableType is the table_type value, e.g. "BASE TABLE" or "VIEW".
	tableType string
	// isInsertable uses sqlz.NullBool (rather than sql.NullBool)
	// because the column holds "YES", "NO" values.
	isInsertable sqlz.NullBool
	rowCount     int64
	// size is the result of pg_total_relation_size for the table.
	size sql.NullInt64
	oid  string
	// comment is the result of obj_description for the table; it may be NULL.
	comment sql.NullString
}
// tblMetaFromPgTable builds a source.TableMetadata from pgt. The
// Columns field is left nil: columns are populated separately.
// Size is only set when pgt reports a valid, positive size, and
// TableType is only set for the "BASE TABLE" and "VIEW" table types.
func tblMetaFromPgTable(pgt *pgTable) *source.TableMetadata {
	md := &source.TableMetadata{
		Name:        pgt.tableName,
		FQName:      pgt.tableCatalog + "." + pgt.tableSchema + "." + pgt.tableName,
		DBTableType: pgt.tableType,
		RowCount:    pgt.rowCount,
		Comment:     pgt.comment.String,
	}

	if hasSize := pgt.size.Valid && pgt.size.Int64 > 0; hasSize {
		md.Size = &pgt.size.Int64
	}

	if md.DBTableType == "BASE TABLE" {
		md.TableType = sqlz.TableTypeTable
	} else if md.DBTableType == "VIEW" {
		md.TableType = sqlz.TableTypeView
	}

	return md
}
// pgColumn holds query results for column metadata: one instance
// per row returned by the query in getPgColumns. Except for comment,
// the fields correspond one-to-one (and in order) to the columns
// selected from information_schema.columns. Fields that are scanned
// into sql.NullXYZ types may be NULL in the query result.
//
// See https://www.postgresql.org/docs/8.0/infoschema-columns.html
type pgColumn struct {
	tableCatalog           string
	tableSchema            string
	tableName              string
	columnName             string
	ordinalPosition        int64
	columnDefault          sql.NullString
	isNullable             sqlz.NullBool
	dataType               string
	characterMaximumLength sql.NullInt64
	characterOctetLength   sql.NullInt64
	numericPrecision       sql.NullInt64
	numericPrecisionRadix  sql.NullInt64
	numericScale           sql.NullInt64
	datetimePrecision      sql.NullInt64
	domainCatalog          sql.NullString
	domainSchema           sql.NullString
	domainName             sql.NullString
	udtCatalog             string
	udtSchema              string
	udtName                string
	isIdentity             sqlz.NullBool
	isGenerated            sql.NullString
	isUpdatable            sqlz.NullBool

	// comment holds any column comment. Note that this field is
	// not part of the standard postgres infoschema, but is
	// separately fetched.
	comment sql.NullString
}
// getPgColumns returns the column metadata for the table named
// tblName in the current catalog & schema, ordered by ordinal
// position.
func getPgColumns(ctx context.Context, db sqlz.DB, tblName string) ([]*pgColumn, error) {
	// colsQuery gets column information from information_schema.columns.
	//
	// It also has a subquery to get column comments. See:
	// - https://stackoverflow.com/a/22547588
	// - https://dba.stackexchange.com/a/160668
	const colsQuery = `SELECT table_catalog,
table_schema,
table_name,
column_name,
ordinal_position,
column_default,
is_nullable,
data_type,
character_maximum_length,
character_octet_length,
numeric_precision,
numeric_precision_radix,
numeric_scale,
datetime_precision,
domain_catalog,
domain_schema,
domain_name,
udt_catalog,
udt_schema,
udt_name,
is_identity,
is_generated,
is_updatable,
(
SELECT
pg_catalog.col_description(c.oid, cols.ordinal_position::INT)
FROM
pg_catalog.pg_class c
WHERE
c.oid = (SELECT ('"' || cols.table_name || '"')::regclass::oid)
AND c.relname = cols.table_name
) AS column_comment
FROM information_schema.columns cols
WHERE cols.table_catalog = current_catalog AND cols.table_schema = current_schema() AND cols.table_name = $1
ORDER BY cols.table_catalog, cols.table_schema, cols.table_name, cols.ordinal_position`

	log := lg.FromContext(ctx)

	rows, err := db.QueryContext(ctx, colsQuery, tblName)
	if err != nil {
		return nil, errz.Err(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var pgCols []*pgColumn
	for rows.Next() {
		pgCol := new(pgColumn)
		if err = scanPgColumn(rows, pgCol); err != nil {
			return nil, err
		}

		pgCols = append(pgCols, pgCol)
	}

	// Surface any iteration or close error before returning results.
	if err = closeRows(rows); err != nil {
		return nil, err
	}

	return pgCols, nil
}
// scanPgColumn scans the current row of rows into c. The scan
// destinations are listed in the same order as the columns
// selected by the query in getPgColumns.
func scanPgColumn(rows *sql.Rows, c *pgColumn) error {
	dests := []any{
		&c.tableCatalog,
		&c.tableSchema,
		&c.tableName,
		&c.columnName,
		&c.ordinalPosition,
		&c.columnDefault,
		&c.isNullable,
		&c.dataType,
		&c.characterMaximumLength,
		&c.characterOctetLength,
		&c.numericPrecision,
		&c.numericPrecisionRadix,
		&c.numericScale,
		&c.datetimePrecision,
		&c.domainCatalog,
		&c.domainSchema,
		&c.domainName,
		&c.udtCatalog,
		&c.udtSchema,
		&c.udtName,
		&c.isIdentity,
		&c.isGenerated,
		&c.isUpdatable,
		&c.comment,
	}

	return errz.Err(rows.Scan(dests...))
}
// colMetaFromPgColumn converts pgCol to a source.ColMetadata. The
// column's Kind is derived from its udt_name. PrimaryKey is always
// false in the returned value: it is set separately, from the
// table's constraints.
func colMetaFromPgColumn(log *slog.Logger, pgCol *pgColumn) *source.ColMetadata {
	knd := kindFromDBTypeName(log, pgCol.columnName, pgCol.udtName)

	return &source.ColMetadata{
		Name:         pgCol.columnName,
		Position:     pgCol.ordinalPosition,
		PrimaryKey:   false,
		BaseType:     pgCol.udtName,
		ColumnType:   pgCol.dataType,
		Kind:         knd,
		Nullable:     pgCol.isNullable.Bool,
		DefaultValue: pgCol.columnDefault.String,
		Comment:      pgCol.comment.String,
	}
}
// getPgConstraints queries constraint metadata in the current
// catalog & schema. When tblName is non-empty, only constraints for
// that table are returned; when tblName is empty, constraints for
// all tables are returned.
func getPgConstraints(ctx context.Context, db sqlz.DB, tblName string) ([]*pgConstraint, error) {
	const (
		selectClause = `SELECT kcu.table_catalog,kcu.table_schema,kcu.table_name,kcu.column_name,
kcu.ordinal_position,tc.constraint_name,tc.constraint_type,
(
SELECT pg_catalog.pg_get_constraintdef(pgc.oid, TRUE)
FROM pg_catalog.pg_constraint pgc
WHERE pgc.conrelid = (SELECT ('"' || kcu.table_name || '"')::regclass::oid)
AND pgc.conname = tc.constraint_name
limit 1
) AS constraint_def,
(
SELECT pgc.confrelid::regclass
FROM pg_catalog.pg_constraint pgc
WHERE pgc.conrelid = (SELECT ('"' || kcu.table_name || '"')::regclass::oid)
AND pgc.conname = tc.constraint_name
AND pgc.confrelid > 0
LIMIT 1
) AS constraint_fkey_table_name
FROM information_schema.key_column_usage AS kcu
LEFT JOIN information_schema.table_constraints AS tc
ON tc.constraint_name = kcu.constraint_name
WHERE kcu.table_catalog = current_catalog AND kcu.table_schema = current_schema()
`
		tblFilterClause = ` AND kcu.table_name = $1 `
		orderByClause   = ` ORDER BY kcu.table_catalog, kcu.table_schema, kcu.table_name, tc.constraint_type DESC, kcu.ordinal_position`
	)

	log := lg.FromContext(ctx)

	// The table filter, and its bind arg, are only added when a
	// table name was supplied.
	query := selectClause
	var args []any
	if tblName != "" {
		query += tblFilterClause
		args = append(args, tblName)
	}
	query += orderByClause

	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, errz.Err(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var results []*pgConstraint
	for rows.Next() {
		c := new(pgConstraint)
		err = rows.Scan(
			&c.tableCatalog,
			&c.tableSchema,
			&c.tableName,
			&c.columnName,
			&c.ordinalPosition,
			&c.constraintName,
			&c.constraintType,
			&c.constraintDef,
			&c.constraintFKeyTableName,
		)
		if err != nil {
			return nil, errz.Err(err)
		}

		results = append(results, c)
	}

	// Surface any iteration or close error before returning results.
	if err = closeRows(rows); err != nil {
		return nil, err
	}

	return results, nil
}
// pgConstraint holds query results for constraint metadata.
// This type is column-focused: that is, an instance is produced
// for each constraint/column pair. Thus, if a table has a
// composite primary key (col_a, col_b), two pgConstraint instances
// are produced.
type pgConstraint struct {
	tableCatalog    string
	tableSchema     string
	tableName       string
	columnName      string
	ordinalPosition int64

	// constraintName, constraintType and constraintDef are nullable,
	// because the constraint details come from a LEFT JOIN and from
	// a subquery, either of which may produce no match.
	constraintName sql.NullString
	// constraintType is compared against constraintTypePK to
	// identify primary key columns.
	constraintType sql.NullString
	constraintDef  sql.NullString

	// constraintFKeyTableName holds the name of the table to which
	// a foreign-key constraint points. This is null if this
	// constraint is not a foreign key.
	constraintFKeyTableName sql.NullString
}
// setTblMetaConstraints applies pgConstraints to tblMeta: for each
// primary key constraint that belongs to tblMeta's table, the
// matching column is flagged as PrimaryKey. Constraints for other
// tables, and constraints of other types, are ignored.
func setTblMetaConstraints(log *slog.Logger, tblMeta *source.TableMetadata, pgConstraints []*pgConstraint) {
	for _, pgc := range pgConstraints {
		ownerFQName := pgc.tableCatalog + "." + pgc.tableSchema + "." + pgc.tableName

		isTargetTbl := ownerFQName == tblMeta.FQName
		isPK := pgc.constraintType.String == constraintTypePK
		if !isTargetTbl || !isPK {
			continue
		}

		if col := tblMeta.Column(pgc.columnName); col != nil {
			col.PrimaryKey = true
			continue
		}

		// Shouldn't happen
		log.Warn("No column found matching constraint",
			lga.Target, tblMeta.Name+"."+pgc.columnName,
			"constraint", pgc.constraintName,
		)
	}
}
// Constraint type values, as found in the constraint_type column
// of information_schema.table_constraints.
const (
	// constraintTypePK identifies a primary key constraint. It is
	// matched against pgConstraint.constraintType.
	constraintTypePK = "PRIMARY KEY"
	// constraintTypeFK identifies a foreign key constraint.
	constraintTypeFK = "FOREIGN KEY"
)
// closeRows calls rows.Err followed by rows.Close. If rows.Err
// reported an error, that error is returned (any close error is
// then dropped); otherwise the result of rows.Close is returned.
// It is a no-op returning nil when rows is nil.
func closeRows(rows *sql.Rows) error {
	if rows == nil {
		return nil
	}

	iterErr := rows.Err()
	// Close is always invoked, even when iteration already failed.
	if closeErr := rows.Close(); iterErr == nil {
		return errz.Err(closeErr)
	}

	return errz.Err(iterErr)
}