Implement Start for adding columns with NOT NULL and no DEFAULT (#37)

Implement `Start` for **add column** operations that add a `NOT NULL`
column without a `DEFAULT`.

To add such a column without forcing a exclusive lock while a full table
scan is performed, these steps need to be followed:

On `Start`:
1. Add the new column
2. Add a `CHECK IS NOT NULL` constraint to the new column, but with `NOT
VALID`, to avoid the scan.
3. Backfill the new column with the provided `up` SQL.

On `Complete`
1. Validate the constraint (with `ALTER TABLE VALIDATE CONSTRAINT`).
2. Add the `NOT NULL` attribute to the column. The presence of a valid
`NOT NULL` constraint on the column means that adding `NOT NULL` to the
column does not perform a full table scan.

See [this
post](https://medium.com/paypal-tech/postgresql-at-scale-database-schema-changes-without-downtime-20d3749ed680#00dc)
for a summary of these steps.
This commit is contained in:
Andrew Farries 2023-07-21 07:47:42 +01:00 committed by GitHub
parent 7e209da2ea
commit 9a08b6cc77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 114 additions and 7 deletions

View File

@ -22,6 +22,17 @@
"default": "100"
}
}
},
{
"add_column": {
"table": "products",
"up": "name || '-category'",
"column": {
"name": "category",
"type": "varchar(255)",
"nullable": false
}
}
}
]
}

View File

@ -26,6 +26,12 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, schemaName, state
return fmt.Errorf("failed to start add column operation: %w", err)
}
if !o.Column.Nullable && o.Column.Default == nil {
if err := addNotNullConstraint(ctx, conn, o); err != nil {
return fmt.Errorf("failed to add check constraint: %w", err)
}
}
if o.Up != nil {
if err := createTrigger(ctx, conn, o, schemaName, stateSchema, s); err != nil {
return fmt.Errorf("failed to create trigger: %w", err)
@ -78,10 +84,6 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
return ColumnAlreadyExistsError{Name: o.Column.Name, Table: o.Table}
}
if !o.Column.Nullable && o.Column.Default == nil {
return errors.New("adding non-nullable columns without a default is not supported")
}
if o.Column.PrimaryKey {
return errors.New("adding primary key columns is not supported")
}
@ -90,6 +92,17 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
}
func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table) error {
// don't add non-nullable columns with no default directly
// they are handled by:
// - adding the column as nullable
// - adding a NOT VALID check constraint on the column
// - validating the constraint and converting the column to not null
// on migration completion
// This is to avoid unnecessary exclusive table locks.
if !o.Column.Nullable && o.Column.Default == nil {
o.Column.Nullable = true
}
o.Column.Name = TemporaryName(o.Column.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s",
@ -99,6 +112,15 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table
return err
}
func addNotNullConstraint(ctx context.Context, conn *sql.DB, o *OpAddColumn) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s IS NOT NULL) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(NotNullConstraintName(o.Column.Name)),
pq.QuoteIdentifier(TemporaryName(o.Column.Name)),
))
return err
}
func createTrigger(ctx context.Context, conn *sql.DB, o *OpAddColumn, schemaName, stateSchema string, s *schema.Schema) error {
// Generate the SQL declarations for the trigger function
// This results in declarations like:
@ -177,6 +199,10 @@ func backFill(ctx context.Context, conn *sql.DB, o *OpAddColumn) error {
return err
}
func NotNullConstraintName(columnName string) string {
return "_pgroll_add_column_check_" + columnName
}
func TriggerFunctionName(tableName, columnName string) string {
return "_pgroll_add_column_" + tableName + "_" + columnName
}

View File

@ -187,3 +187,63 @@ func TestAddColumnWithUpSql(t *testing.T) {
},
}})
}
func TestAddNotNullColumnWithNoDefault(t *testing.T) {
t.Parallel()
ptr := func(s string) *string { return &s }
ExecuteTests(t, TestCases{{
name: "add column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "products",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "name",
Type: "varchar(255)",
Unique: true,
},
},
},
},
},
{
Name: "02_add_column",
Operations: migrations.Operations{
&migrations.OpAddColumn{
Table: "products",
Up: ptr("UPPER(name)"),
Column: migrations.Column{
Name: "description",
Type: "varchar(255)",
Nullable: false,
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// Inserting a null description through the old view works (due to `up` sql populating the column).
MustInsert(t, db, "public", "01_add_table", "products", map[string]string{
"name": "apple",
})
// Inserting a null description through the new view fails.
MustNotInsert(t, db, "public", "02_add_column", "products", map[string]string{
"name": "banana",
})
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
},
}})
}

View File

@ -346,6 +346,18 @@ func columnExists(t *testing.T, db *sql.DB, schema, table, column string) bool {
}
func MustInsert(t *testing.T, db *sql.DB, schema, version, table string, record map[string]string) {
if err := insert(t, db, schema, version, table, record); err != nil {
t.Fatal(err)
}
}
func MustNotInsert(t *testing.T, db *sql.DB, schema, version, table string, record map[string]string) {
if err := insert(t, db, schema, version, table, record); err == nil {
t.Fatal("Expected INSERT to fail")
}
}
func insert(t *testing.T, db *sql.DB, schema, version, table string, record map[string]string) error {
t.Helper()
versionSchema := roll.VersionedSchemaName(schema, version)
@ -374,9 +386,7 @@ func MustInsert(t *testing.T, db *sql.DB, schema, version, table string, record
stmt := fmt.Sprintf("INSERT INTO %s.%s %s", versionSchema, table, recordStr)
_, err := db.Exec(stmt)
if err != nil {
t.Fatal(err)
}
return err
}
func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[string]any {