package postgres

import (
	"context"
	"database/sql"
	"fmt"
	"log/slog"
	"reflect"
	"strconv"
	"strings"

	"golang.org/x/sync/errgroup"

	"github.com/neilotoole/sq/libsq/core/errz"
	"github.com/neilotoole/sq/libsq/core/kind"
	"github.com/neilotoole/sq/libsq/core/lg"
	"github.com/neilotoole/sq/libsq/core/lg/lga"
	"github.com/neilotoole/sq/libsq/core/lg/lgm"
	"github.com/neilotoole/sq/libsq/core/options"
	"github.com/neilotoole/sq/libsq/core/progress"
	"github.com/neilotoole/sq/libsq/core/record"
	"github.com/neilotoole/sq/libsq/core/sqlz"
	"github.com/neilotoole/sq/libsq/core/stringz"
	"github.com/neilotoole/sq/libsq/driver"
	"github.com/neilotoole/sq/libsq/source"
	"github.com/neilotoole/sq/libsq/source/metadata"
)

// kindFromDBTypeName determines the kind.Kind from the database
// type name. For example, "VARCHAR" -> kind.Text.
// See https://www.postgresql.org/docs/9.5/datatype.html
func kindFromDBTypeName(log *slog.Logger, colName, dbTypeName string) kind.Kind {
	var knd kind.Kind
	dbTypeName = strings.ToUpper(dbTypeName)

	switch dbTypeName {
	default:
		log.Warn(
			"Unknown Postgres column type: using alt type",
			lga.DBType, dbTypeName,
			lga.Col, colName,
			lga.Alt, kind.Unknown,
		)
		knd = kind.Unknown
	case "":
		knd = kind.Unknown
	case "INT", "INTEGER", "INT2", "INT4", "INT8", "SMALLINT", "BIGINT":
		knd = kind.Int
	case "CHAR", "CHARACTER", "VARCHAR", "TEXT", "BPCHAR", "CHARACTER VARYING": //nolint:goconst
		knd = kind.Text
	case "BYTEA":
		knd = kind.Bytes
	case "BOOL", "BOOLEAN":
		knd = kind.Bool
	case "TIMESTAMP", "TIMESTAMPTZ", "TIMESTAMP WITHOUT TIME ZONE":
		knd = kind.Datetime
	case "TIME", "TIMETZ", "TIME WITHOUT TIME ZONE": //nolint:goconst
		knd = kind.Time
	case "DATE":
		knd = kind.Date
	case "INTERVAL": // interval meaning time duration
		knd = kind.Text
	case "FLOAT", "FLOAT4", "FLOAT8", "DOUBLE", "DOUBLE PRECISION":
		knd = kind.Float
	case "UUID":
		knd = kind.Text
	case "DECIMAL", "NUMERIC", "MONEY":
		knd = kind.Decimal
	case "JSON", "JSONB":
		knd = kind.Text
	case "BIT", "VARBIT":
		knd = kind.Text
	case "XML":
		knd = kind.Text
	case "BOX", "CIRCLE", "LINE", "LSEG", "PATH", "POINT", "POLYGON":
		knd = kind.Text
	case "CIDR", "INET", "MACADDR":
		knd = kind.Text
	case "USER-DEFINED": // REVISIT: How to handle USER-DEFINED type?
		knd = kind.Text
	case "TSVECTOR": // REVISIT: how to handle TSVECTOR type?
		knd = kind.Text
	case "ARRAY": // REVISIT: how to handle ARRAY type?
		knd = kind.Text
	}

	return knd
}

// setScanType ensures that ctd's scan type field is set appropriately.
func setScanType(log *slog.Logger, ctd *record.ColumnTypeData, knd kind.Kind) {
	if knd == kind.Decimal {
		ctd.ScanType = sqlz.RTypeNullDecimal
		return
	}

	// Need to switch to the nullable scan types because the
	// backing driver doesn't report nullable info accurately.
	ctd.ScanType = toNullableScanType(log, ctd.Name, ctd.DatabaseTypeName, knd, ctd.ScanType)
}

// toNullableScanType returns the nullable equivalent of the scan type
// reported by the postgres driver's ColumnType.ScanType. This is necessary
// because the pgx driver does not support the stdlib sql
// driver.RowsColumnTypeNullable interface.
func toNullableScanType(log *slog.Logger, colName, dbTypeName string, knd kind.Kind,
	pgScanType reflect.Type,
) reflect.Type {
	var nullableScanType reflect.Type

	switch pgScanType {
	default:
		// If we don't recognize the scan type (likely it's any),
		// we explicitly switch through the db type names that we know.
		// At this time, we will use NullString for all unrecognized
		// scan types, but nonetheless we switch through the known db type
		// names so that we see the log warning for truly unknown types.
		switch dbTypeName {
		default:
			nullableScanType = sqlz.RTypeNullString
			log.Warn("Unknown Postgres scan type",
				lga.Col, colName,
				lga.ScanType, pgScanType,
				lga.DBType, dbTypeName,
				lga.Kind, knd,
				lga.DefaultTo, nullableScanType,
			)
		case "":
			// NOTE: the pgx driver currently reports an empty dbTypeName for certain
			// cols such as XML or MONEY.
			nullableScanType = sqlz.RTypeNullString
		case "TIME":
			nullableScanType = sqlz.RTypeNullString
		case "BIT", "VARBIT":
			nullableScanType = sqlz.RTypeNullString
		case "BPCHAR":
			nullableScanType = sqlz.RTypeNullString
		case "BOX", "CIRCLE", "LINE", "LSEG", "PATH", "POINT", "POLYGON":
			nullableScanType = sqlz.RTypeNullString
		case "CIDR", "INET", "MACADDR":
			nullableScanType = sqlz.RTypeNullString
		case "INTERVAL":
			nullableScanType = sqlz.RTypeNullString
		case "JSON", "JSONB":
			nullableScanType = sqlz.RTypeNullString
		case "XML":
			nullableScanType = sqlz.RTypeNullString
		case "UUID":
			nullableScanType = sqlz.RTypeNullString
		case "NUMERIC":
			nullableScanType = sqlz.RTypeNullDecimal
		}

	case sqlz.RTypeInt64, sqlz.RTypeInt, sqlz.RTypeInt8, sqlz.RTypeInt16, sqlz.RTypeInt32, sqlz.RTypeNullInt64:
		nullableScanType = sqlz.RTypeNullInt64
	case sqlz.RTypeFloat32, sqlz.RTypeFloat64, sqlz.RTypeNullFloat64:
		nullableScanType = sqlz.RTypeNullFloat64
	case sqlz.RTypeString, sqlz.RTypeNullString:
		nullableScanType = sqlz.RTypeNullString
	case sqlz.RTypeBool, sqlz.RTypeNullBool:
		nullableScanType = sqlz.RTypeNullBool
	case sqlz.RTypeTime, sqlz.RTypeNullTime:
		nullableScanType = sqlz.RTypeNullTime
	case sqlz.RTypeBytes:
		nullableScanType = sqlz.RTypeBytes
	case sqlz.RTypeDecimal:
		nullableScanType = sqlz.RTypeNullDecimal
	}

	return nullableScanType
}

// getSourceMetadata returns the metadata for src. If noSchema is true, only
// the source-level fields (catalog, schema, size, version, DB properties)
// are collected, and table metadata is skipped.
func getSourceMetadata(ctx context.Context, src *source.Source, db sqlz.DB, noSchema bool) (*metadata.Source, error) {
	log := lg.FromContext(ctx)
	ctx = options.NewContext(ctx, src.Options)

	md := &metadata.Source{
		Handle:   src.Handle,
		Location: src.Location,
		Driver:   src.Type,
		DBDriver: src.Type,
	}

	var schema sql.NullString
	const summaryQuery = `SELECT current_catalog, current_schema(), pg_database_size(current_catalog),
current_setting('server_version'), version(), "current_user"()`

	err := db.QueryRowContext(ctx, summaryQuery).
		Scan(&md.Name, &schema, &md.Size, &md.DBVersion, &md.DBProduct, &md.User)
	if err != nil {
		return nil, errw(err)
	}
	progress.Incr(ctx, 1)
	progress.DebugSleep(ctx)

	if !schema.Valid {
		return nil, errz.New("NULL value for current_schema(): check privileges and search_path")
	}

	md.Catalog = md.Name
	md.Schema = schema.String
	md.FQName = md.Name + "." + schema.String

	md.DBProperties, err = getPgSettings(ctx, db)
	if err != nil {
		return nil, err
	}

	if noSchema {
		return md, nil
	}

	tblNames, err := getAllTableNames(ctx, db)
	if err != nil {
		return nil, err
	}

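	// Collect metadata for each table concurrently; concurrency is bounded
	// by driver.OptTuningErrgroupLimit.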
	g, gCtx := errgroup.WithContext(ctx)
	g.SetLimit(driver.OptTuningErrgroupLimit.Get(src.Options))
	tblMetas := make([]*metadata.Table, len(tblNames))

	for i := range tblNames {
		i := i
		g.Go(func() error {
			select {
			case <-gCtx.Done():
				return gCtx.Err()
			default:
			}

			var tblMeta *metadata.Table
			var mdErr error
			mdErr = doRetry(gCtx, func() error {
				tblMeta, mdErr = getTableMetadata(gCtx, db, tblNames[i])
				return mdErr
			})
			if mdErr != nil {
				switch {
				case isErrRelationNotExist(mdErr):
					// For example, if the table is dropped while we're collecting
					// metadata, we log a warning and suppress the error.
					log.Warn("metadata collection: table not found (continuing regardless)",
						lga.Table, tblNames[i],
						lga.Err, mdErr,
					)
				default:
					return mdErr
				}
			}

			tblMetas[i] = tblMeta
			return nil
		})
	}

	err = g.Wait()
	if err != nil {
		return nil, errw(err)
	}

	// If a table wasn't found (possibly dropped while querying), then
	// its entry could be nil. We copy the non-nil elements to the
	// final slice.
	md.Tables = make([]*metadata.Table, 0, len(tblMetas))
	for i := range tblMetas {
		if tblMetas[i] != nil {
			md.Tables = append(md.Tables, tblMetas[i])
		}
	}

	for _, tbl := range md.Tables {
		if tbl.TableType == sqlz.TableTypeTable {
			md.TableCount++
		} else if tbl.TableType == sqlz.TableTypeView {
			md.ViewCount++
		}
	}

	return md, nil
}

// getPgSettings returns the database settings from pg_settings, narrowing
// each setting value to bool, int, or float where possible.
func getPgSettings(ctx context.Context, db sqlz.DB) (map[string]any, error) {
	rows, err := db.QueryContext(ctx, "SELECT name, setting, vartype FROM pg_settings ORDER BY name")
	if err != nil {
		return nil, errw(err)
	}
	defer lg.WarnIfCloseError(lg.FromContext(ctx), lgm.CloseDBRows, rows)

	m := map[string]any{}

	for rows.Next() {
		var (
			name    string
			setting string
			typ     string
			val     any
		)

		if err = rows.Scan(&name, &setting, &typ); err != nil {
			return nil, errw(err)
		}
		progress.Incr(ctx, 1)
		progress.DebugSleep(ctx)

		// Narrow the setting value to bool, int, etc.
		val = setting
		switch typ {
		case "integer":
			var i int
			if i, err = strconv.Atoi(setting); err == nil {
				val = i
			}
		case "bool":
			var b bool
			if b, err = stringz.ParseBool(setting); err == nil {
				val = b
			}
		case "real":
			var f float64
			if f, err = strconv.ParseFloat(setting, 64); err == nil {
				val = f
			}
		case "enum", "string":
		default:
			// Leave as string
		}

		m[name] = val
	}

	if err = closeRows(rows); err != nil {
		return nil, err
	}

	return m, nil
}

// getAllTableNames returns all table (or view) names in the current
// catalog & schema.
func getAllTableNames(ctx context.Context, db sqlz.DB) ([]string, error) {
	log := lg.FromContext(ctx)

	const tblNamesQuery = `SELECT table_name FROM information_schema.tables
WHERE table_catalog = current_catalog AND table_schema = current_schema()
ORDER BY table_name`

	rows, err := db.QueryContext(ctx, tblNamesQuery)
	if err != nil {
		return nil, errw(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var tblNames []string
	for rows.Next() {
		var s string
		err = rows.Scan(&s)
		if err != nil {
			return nil, errw(err)
		}
		tblNames = append(tblNames, s)
		progress.Incr(ctx, 1)
		progress.DebugSleep(ctx)
	}

	err = closeRows(rows)
	if err != nil {
		return nil, err
	}

	return tblNames, nil
}

// getTableMetadata returns metadata for tblName in the current catalog
// and schema.
func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*metadata.Table, error) {
	log := lg.FromContext(ctx)

	const tblsQueryTpl = `SELECT table_catalog, table_schema, table_name, table_type, is_insertable_into,
(SELECT COUNT(*) FROM "%s") AS table_row_count,
pg_total_relation_size('%q') AS table_size,
(SELECT '%q'::regclass::oid AS table_oid),
obj_description('%q'::REGCLASS, 'pg_class') AS table_comment
FROM information_schema.tables
WHERE table_catalog = current_database()
AND table_schema = current_schema()
AND table_name = $1`

	tablesQuery := fmt.Sprintf(tblsQueryTpl, tblName, tblName, tblName, tblName)
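	// Note that tblName is interpolated directly into the query text (for the
	// row count, relation size, oid, and comment lookups), and is also bound
	// as the $1 arg for the information_schema filter.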
	pgTbl := &pgTable{}
	err := db.QueryRowContext(ctx, tablesQuery, tblName).
		Scan(&pgTbl.tableCatalog, &pgTbl.tableSchema, &pgTbl.tableName, &pgTbl.tableType, &pgTbl.isInsertable,
			&pgTbl.rowCount, &pgTbl.size, &pgTbl.oid, &pgTbl.comment)
	if err != nil {
		return nil, errw(err)
	}
	progress.Incr(ctx, 1)
	progress.DebugSleep(ctx)

	tblMeta := tblMetaFromPgTable(pgTbl)
	if tblMeta.Name != tblName {
		// Shouldn't happen, but we'll error if it does.
		return nil, errz.Errorf("table {%s} not found in %s.%s", tblName, pgTbl.tableCatalog, pgTbl.tableSchema)
	}

	pgCols, err := getPgColumns(ctx, db, tblName)
	if err != nil {
		return nil, err
	}

	for _, pgCol := range pgCols {
		colMeta := colMetaFromPgColumn(log, pgCol)
		tblMeta.Columns = append(tblMeta.Columns, colMeta)
	}

	// We need to fetch the constraints to set the PK etc.
	pgConstraints, err := getPgConstraints(ctx, db, tblName)
	if err != nil {
		return nil, err
	}

	setTblMetaConstraints(log, tblMeta, pgConstraints)

	return tblMeta, nil
}

// pgTable holds query results for table metadata.
type pgTable struct {
	tableCatalog string
	tableSchema  string
	tableName    string
	tableType    string
	oid          string
	comment      sql.NullString
	size         sql.NullInt64
	rowCount     int64
	isInsertable sqlz.NullBool // sqlz.NullBool handles the "YES"/"NO" values
}

// tblMetaFromPgTable builds a metadata.Table from pgt.
func tblMetaFromPgTable(pgt *pgTable) *metadata.Table {
	md := &metadata.Table{
		Name:        pgt.tableName,
		FQName:      fmt.Sprintf("%s.%s.%s", pgt.tableCatalog, pgt.tableSchema, pgt.tableName),
		DBTableType: pgt.tableType,
		RowCount:    pgt.rowCount,
		Comment:     pgt.comment.String,
		Columns:     nil, // Note: columns are set independently later
	}

	if pgt.size.Valid && pgt.size.Int64 > 0 {
		md.Size = &pgt.size.Int64
	}

	switch md.DBTableType {
	case "BASE TABLE":
		md.TableType = sqlz.TableTypeTable
	case "VIEW":
		md.TableType = sqlz.TableTypeView
	default:
	}

	return md
}

// pgColumn holds query results for column metadata.
// See https://www.postgresql.org/docs/8.0/infoschema-columns.html
type pgColumn struct {
	tableCatalog  string
	tableSchema   string
	tableName     string
	columnName    string
	dataType      string
	udtCatalog    string
	udtSchema     string
	udtName       string
	columnDefault sql.NullString
	domainCatalog sql.NullString
	domainSchema  sql.NullString
	domainName    sql.NullString
	isGenerated   sql.NullString

	// comment holds any column comment. Note that this field is
	// not part of the standard postgres infoschema, but is
	// separately fetched.
	comment                sql.NullString
	characterMaximumLength sql.NullInt64
	characterOctetLength   sql.NullInt64
	numericPrecision       sql.NullInt64
	numericPrecisionRadix  sql.NullInt64
	numericScale           sql.NullInt64
	datetimePrecision      sql.NullInt64
	ordinalPosition        int64
	isNullable             sqlz.NullBool
	isIdentity             sqlz.NullBool
	isUpdatable            sqlz.NullBool
}

// getPgColumns queries the column metadata for tblName.
func getPgColumns(ctx context.Context, db sqlz.DB, tblName string) ([]*pgColumn, error) {
	log := lg.FromContext(ctx)

	// colsQuery gets column information from information_schema.columns.
	//
	// It also has a subquery to get column comments. See:
	// - https://stackoverflow.com/a/22547588
	// - https://dba.stackexchange.com/a/160668
	const colsQuery = `SELECT table_catalog, table_schema, table_name, column_name, ordinal_position, column_default,
is_nullable, data_type, character_maximum_length, character_octet_length, numeric_precision,
numeric_precision_radix, numeric_scale, datetime_precision, domain_catalog, domain_schema, domain_name,
udt_catalog, udt_schema, udt_name, is_identity, is_generated, is_updatable,
(
	SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::INT)
	FROM pg_catalog.pg_class c
	WHERE c.oid = (SELECT ('"' || cols.table_name || '"')::regclass::oid)
	AND c.relname = cols.table_name
) AS column_comment
FROM information_schema.columns cols
WHERE cols.table_catalog = current_catalog
AND cols.table_schema = current_schema()
AND cols.table_name = $1
ORDER BY cols.table_catalog, cols.table_schema, cols.table_name, cols.ordinal_position`

	rows, err := db.QueryContext(ctx, colsQuery, tblName)
	if err != nil {
		return nil, errw(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var cols []*pgColumn
	for rows.Next() {
		col := &pgColumn{}
		err = scanPgColumn(rows, col)
		if err != nil {
			return nil, err
		}
		progress.Incr(ctx, 1)
		progress.DebugSleep(ctx)

		cols = append(cols, col)
	}

	err = closeRows(rows)
	if err != nil {
		return nil, err
	}

	return cols, nil
}

// scanPgColumn scans a row from the colsQuery result set into c.
func scanPgColumn(rows *sql.Rows, c *pgColumn) error {
	err := rows.Scan(&c.tableCatalog, &c.tableSchema, &c.tableName, &c.columnName, &c.ordinalPosition,
		&c.columnDefault, &c.isNullable, &c.dataType, &c.characterMaximumLength, &c.characterOctetLength,
		&c.numericPrecision, &c.numericPrecisionRadix, &c.numericScale, &c.datetimePrecision, &c.domainCatalog,
		&c.domainSchema, &c.domainName, &c.udtCatalog, &c.udtSchema, &c.udtName, &c.isIdentity, &c.isGenerated,
		&c.isUpdatable, &c.comment)
	return errw(err)
}

// colMetaFromPgColumn builds a metadata.Column from pgCol.
func colMetaFromPgColumn(log *slog.Logger, pgCol *pgColumn) *metadata.Column {
	colMeta := &metadata.Column{
		Name:         pgCol.columnName,
		Position:     pgCol.ordinalPosition,
		PrimaryKey:   false, // Note that PrimaryKey is set separately from pgConstraint.
		BaseType:     pgCol.udtName,
		ColumnType:   pgCol.dataType,
		Kind:         kindFromDBTypeName(log, pgCol.columnName, pgCol.udtName),
		Nullable:     pgCol.isNullable.Bool,
		DefaultValue: pgCol.columnDefault.String,
		Comment:      pgCol.comment.String,
	}
	return colMeta
}

// getPgConstraints returns a slice of pgConstraint. If tblName is
// empty, constraints for all tables in the current catalog & schema
// are returned. If tblName is specified, constraints just for that
// table are returned.
func getPgConstraints(ctx context.Context, db sqlz.DB, tblName string) ([]*pgConstraint, error) {
	log := lg.FromContext(ctx)

	var args []any
	query := `SELECT kcu.table_catalog, kcu.table_schema, kcu.table_name, kcu.column_name,
kcu.ordinal_position, tc.constraint_name, tc.constraint_type,
(
	SELECT pg_catalog.pg_get_constraintdef(pgc.oid, TRUE)
	FROM pg_catalog.pg_constraint pgc
	WHERE pgc.conrelid = (SELECT ('"' || kcu.table_name || '"')::regclass::oid)
	AND pgc.conname = tc.constraint_name
	LIMIT 1
) AS constraint_def,
(
	SELECT pgc.confrelid::regclass
	FROM pg_catalog.pg_constraint pgc
	WHERE pgc.conrelid = (SELECT ('"' || kcu.table_name || '"')::regclass::oid)
	AND pgc.conname = tc.constraint_name
	AND pgc.confrelid > 0
	LIMIT 1
) AS constraint_fkey_table_name
FROM information_schema.key_column_usage AS kcu
LEFT JOIN information_schema.table_constraints AS tc
ON tc.constraint_name = kcu.constraint_name
WHERE kcu.table_catalog = current_catalog
AND kcu.table_schema = current_schema()
`

	if tblName != "" {
		query += `AND kcu.table_name = $1
`
		args = append(args, tblName)
	}

	query += `ORDER BY kcu.table_catalog, kcu.table_schema, kcu.table_name, tc.constraint_type DESC,
kcu.ordinal_position`

	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, errw(err)
	}
	defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows)

	var constraints []*pgConstraint

	for rows.Next() {
		pgc := &pgConstraint{}
		err = rows.Scan(&pgc.tableCatalog, &pgc.tableSchema, &pgc.tableName, &pgc.columnName, &pgc.ordinalPosition,
			&pgc.constraintName, &pgc.constraintType, &pgc.constraintDef, &pgc.constraintFKeyTableName)
		if err != nil {
			return nil, errw(err)
		}
		progress.Incr(ctx, 1)
		progress.DebugSleep(ctx)

		constraints = append(constraints, pgc)
	}

	err = closeRows(rows)
	if err != nil {
		return nil, err
	}

	return constraints, nil
}

// pgConstraint holds query results for constraint metadata.
// This type is column-focused: that is, an instance is produced
// for each constraint/column pair. Thus, if a table has a
// composite primary key (col_a, col_b), two pgConstraint instances
// are produced.
type pgConstraint struct {
	tableCatalog string
	tableSchema  string
	tableName    string
	columnName   string

	constraintName sql.NullString
	constraintType sql.NullString
	constraintDef  sql.NullString

	// constraintFKeyTableName holds the name of the table that a
	// foreign-key constraint points to. This is null if this
	// constraint is not a foreign key.
	constraintFKeyTableName sql.NullString
	ordinalPosition         int64
}

// setTblMetaConstraints updates tblMeta with constraints found
// in pgConstraints.
func setTblMetaConstraints(log *slog.Logger, tblMeta *metadata.Table, pgConstraints []*pgConstraint) {
	for _, pgc := range pgConstraints {
		fqTblName := pgc.tableCatalog + "." + pgc.tableSchema + "." + pgc.tableName
		if fqTblName != tblMeta.FQName {
			continue
		}

		if pgc.constraintType.String == constraintTypePK {
			colMeta := tblMeta.Column(pgc.columnName)
			if colMeta == nil {
				// Shouldn't happen.
				log.Warn("No column found matching constraint",
					lga.Target, tblMeta.Name+"."+pgc.columnName,
					"constraint", pgc.constraintName,
				)
				continue
			}
			colMeta.PrimaryKey = true
		}
	}
}

const (
	constraintTypePK = "PRIMARY KEY"
	constraintTypeFK = "FOREIGN KEY"
)

// closeRows invokes rows.Err and rows.Close, returning
// an error if either of those methods returned an error.
func closeRows(rows *sql.Rows) error {
	if rows == nil {
		return nil
	}
	err1 := rows.Err()
	err2 := rows.Close()
	if err1 != nil {
		return errw(err1)
	}
	return errw(err2)
}