Implement adding check constraints to existing columns (#83)

Add support for an operation to add a `CHECK` constraint to an existing
column. The new operation looks like this:

```json
{
  "name": "22_add_check_constraint",
  "operations": [
    {
      "set_check_constraint": {
        "table": "posts",
        "column": "title",
        "check": "length(title) > 3",
        "up": "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
        "down": "title"
      }
    }
  ]
}
```

This migration adds a `CHECK (length(title) > 3)` constraint to the
`title` column on the `posts` table. Pre-existing values in the old
schema are rewritten to meet the constraint using the `up` SQL.

The implementation is similar to the **set not null**, **change column
type** and **set foreign key** operations.

* On `Start`:
  * The column is duplicated and a `NOT VALID` `CHECK` constraint is added
    to the new column.
  * Values from the old column are backfilled into the new column using
    `up` SQL.
  * Triggers are created to copy values from old -> new with `up` SQL and
    from new -> old using `down` SQL.
* On `Complete`:
  * The `CHECK` constraint is validated.
  * The old column is dropped and the new column renamed to the name of
    the old column.
  * Postgres ensures that the `CHECK` constraint is also updated to apply
    to the new column.
  * Triggers and trigger functions are removed.
* On `Rollback`:
  * The new column is removed.
  * Triggers and trigger functions are removed.
As with other operations involving `up` and `down` SQL, it is the user's
responsibility to ensure that values from the old schema that don't meet
the new `CHECK` constraint are correctly rewritten to meet the
constraint with `up` SQL. If the `up` SQL fails to produce a value that
meets the constraint, the migration will fail either at start (for
existing values in the old schema) or at runtime (for values written to
the old schema during the migration period).
This commit is contained in:
Andrew Farries 2023-09-11 06:17:28 +01:00 committed by GitHub
parent c76ea9ce48
commit a105edd05e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 493 additions and 13 deletions

View File

@ -0,0 +1,14 @@
{
"name": "22_add_check_constraint",
"operations": [
{
"set_check_constraint": {
"table": "posts",
"column": "title",
"check": "length(title) > 3",
"up": "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
"down": "title"
}
}
]
}

View File

@ -11,19 +11,20 @@ import (
type OpName string
const (
OpNameCreateTable OpName = "create_table"
OpNameRenameTable OpName = "rename_table"
OpNameDropTable OpName = "drop_table"
OpNameAddColumn OpName = "add_column"
OpNameDropColumn OpName = "drop_column"
OpNameCreateIndex OpName = "create_index"
OpNameDropIndex OpName = "drop_index"
OpNameRenameColumn OpName = "rename_column"
OpNameSetUnique OpName = "set_unique"
OpNameSetNotNull OpName = "set_not_null"
OpNameSetForeignKey OpName = "set_foreign_key"
OpNameChangeType OpName = "change_type"
OpRawSQLName OpName = "sql"
OpNameCreateTable OpName = "create_table"
OpNameRenameTable OpName = "rename_table"
OpNameDropTable OpName = "drop_table"
OpNameAddColumn OpName = "add_column"
OpNameDropColumn OpName = "drop_column"
OpNameCreateIndex OpName = "create_index"
OpNameDropIndex OpName = "drop_index"
OpNameRenameColumn OpName = "rename_column"
OpNameSetUnique OpName = "set_unique"
OpNameSetNotNull OpName = "set_not_null"
OpNameSetForeignKey OpName = "set_foreign_key"
OpNameSetCheckConstraint OpName = "set_check_constraint"
OpNameChangeType OpName = "change_type"
OpRawSQLName OpName = "sql"
)
func TemporaryName(name string) string {
@ -112,6 +113,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error {
case OpNameSetForeignKey:
item = &OpSetForeignKey{}
case OpNameSetCheckConstraint:
item = &OpSetCheckConstraint{}
case OpNameChangeType:
item = &OpChangeType{}
@ -195,6 +199,9 @@ func OperationName(op Operation) OpName {
case *OpSetForeignKey:
return OpNameSetForeignKey
case *OpSetCheckConstraint:
return OpNameSetCheckConstraint
case *OpChangeType:
return OpNameChangeType

View File

@ -0,0 +1,206 @@
package migrations
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/xataio/pg-roll/pkg/schema"
)
// OpSetCheckConstraint is the migration operation that adds a CHECK
// constraint to an existing column. It is decoded from the
// `set_check_constraint` entry of a migration file.
type OpSetCheckConstraint struct {
// Table is the name of the table that owns the column.
Table string `json:"table"`
// Column is the name of the existing column to constrain.
Column string `json:"column"`
// Check is the CHECK expression, written in terms of the existing column
// name; it is rewritten to reference the temporary column before being
// applied.
Check string `json:"check"`
// Up is the SQL used by the up trigger to rewrite values copied from the
// old column to the new column.
Up string `json:"up"`
// Down is the SQL used by the down trigger to rewrite values copied from
// the new column back to the old column.
Down string `json:"down"`
}
// Compile-time assertion that *OpSetCheckConstraint implements Operation.
var _ Operation = (*OpSetCheckConstraint)(nil)
// Start begins the migration. In order, it:
//
//  1. duplicates the target column on the underlying table,
//  2. adds the CHECK constraint to the duplicate as NOT VALID,
//  3. creates the up trigger (old column -> new column, using `up` SQL),
//  4. backfills the new column,
//  5. registers the new column in the in-memory schema, and
//  6. creates the down trigger (new column -> old column, using `down` SQL).
//
// The first failing step aborts the sequence and its error is returned
// wrapped with a description of the step.
func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema) error {
	tbl := s.GetTable(o.Table)
	col := tbl.GetColumn(o.Column)
	tempCol := TemporaryName(o.Column)

	// Step 1: copy the column on the underlying table.
	if err := duplicateColumn(ctx, conn, tbl, *col); err != nil {
		return fmt.Errorf("failed to duplicate column: %w", err)
	}

	// Step 2: attach the NOT VALID check constraint to the copy.
	if err := o.addCheckConstraint(ctx, conn); err != nil {
		return fmt.Errorf("failed to add check constraint: %w", err)
	}

	// Step 3: old -> new trigger, rewriting values with the `up` SQL.
	upTrigger := triggerConfig{
		Name:           TriggerName(o.Table, o.Column),
		Direction:      TriggerDirectionUp,
		Columns:        tbl.Columns,
		SchemaName:     s.Name,
		TableName:      o.Table,
		PhysicalColumn: tempCol,
		StateSchema:    stateSchema,
		SQL:            o.Up,
	}
	if err := createTrigger(ctx, conn, upTrigger); err != nil {
		return fmt.Errorf("failed to create up trigger: %w", err)
	}

	// Step 4: populate the new column from the old one.
	if err := backFill(ctx, conn, o.Table, tempCol); err != nil {
		return fmt.Errorf("failed to backfill column: %w", err)
	}

	// Step 5: the new column must be part of the internal schema
	// representation before the down trigger is created, so that the trigger
	// can declare a variable for it.
	tbl.AddColumn(o.Column, schema.Column{Name: tempCol})

	// Step 6: new -> old trigger, rewriting values with the `down` SQL.
	downTrigger := triggerConfig{
		Name:           TriggerName(o.Table, tempCol),
		Direction:      TriggerDirectionDown,
		Columns:        tbl.Columns,
		SchemaName:     s.Name,
		TableName:      o.Table,
		PhysicalColumn: o.Column,
		StateSchema:    stateSchema,
		SQL:            o.Down,
	}
	if err := createTrigger(ctx, conn, downTrigger); err != nil {
		return fmt.Errorf("failed to create down trigger: %w", err)
	}

	return nil
}
// Complete finalizes the migration by running a fixed sequence of statements
// against conn. Execution stops at, and returns, the first error.
func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB) error {
	var (
		tempColumn     = TemporaryName(o.Column)
		quotedTable    = pq.QuoteIdentifier(o.Table)
		quotedColumn   = pq.QuoteIdentifier(o.Column)
		quotedTemp     = pq.QuoteIdentifier(tempColumn)
		tempConstraint = pq.QuoteIdentifier(CheckConstraintName(o.Table, tempColumn))
	)

	// The statements are order-sensitive: the constraint is validated first,
	// and the column rename can only happen after the old column is gone.
	statements := []string{
		// Validate the check constraint.
		fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
			quotedTable, tempConstraint),
		// Remove the up function and trigger.
		fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
			pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))),
		// Remove the down function and trigger.
		fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
			pq.QuoteIdentifier(TriggerFunctionName(o.Table, tempColumn))),
		// Drop the old column.
		fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s",
			quotedTable, quotedColumn),
		// Give the new column the old column's name.
		fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s",
			quotedTable, quotedTemp, quotedColumn),
		// Rename the check constraint to use the final (non-temporary) column name.
		fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s",
			quotedTable, tempConstraint,
			pq.QuoteIdentifier(CheckConstraintName(o.Table, o.Column))),
	}

	for _, stmt := range statements {
		if _, err := conn.ExecContext(ctx, stmt); err != nil {
			return err
		}
	}
	return nil
}
// Rollback undoes Start: it drops the temporary column and then the up and
// down trigger functions (with CASCADE). Execution stops at, and returns,
// the first error.
func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn *sql.DB) error {
	tempColumn := TemporaryName(o.Column)

	statements := []string{
		// Drop the new column, taking the constraint on the column with it.
		fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
			pq.QuoteIdentifier(o.Table),
			pq.QuoteIdentifier(tempColumn)),
		// Remove the up function and trigger.
		fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
			pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))),
		// Remove the down function and trigger.
		fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
			pq.QuoteIdentifier(TriggerFunctionName(o.Table, tempColumn))),
	}

	for _, stmt := range statements {
		if _, err := conn.ExecContext(ctx, stmt); err != nil {
			return err
		}
	}
	return nil
}
// Validate reports whether the operation can be applied to schema s.
//
// It returns TableDoesNotExistError if the table is missing,
// ColumnDoesNotExistError if the column is missing, and FieldRequiredError
// for the first of `check`, `up` and `down` that is empty. It returns nil
// when all checks pass.
func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) error {
	tbl := s.GetTable(o.Table)
	if tbl == nil {
		return TableDoesNotExistError{Name: o.Table}
	}

	if tbl.GetColumn(o.Column) == nil {
		return ColumnDoesNotExistError{Table: o.Table, Name: o.Column}
	}

	// Required fields are reported in a fixed order: check, up, down.
	switch {
	case o.Check == "":
		return FieldRequiredError{Name: "check"}
	case o.Up == "":
		return FieldRequiredError{Name: "up"}
	case o.Down == "":
		return FieldRequiredError{Name: "down"}
	}

	return nil
}
// CheckConstraintName returns the name used for the CHECK constraint on
// column columnName of table tableName, in the form
// `_pgroll_check_<table>_<column>`.
func CheckConstraintName(tableName, columnName string) string {
	const format = "_pgroll_check_%s_%s"

	name := fmt.Sprintf(format, tableName, columnName)
	return name
}
// addCheckConstraint adds the operation's CHECK constraint to the table as
// NOT VALID. The constraint is named after, and its expression rewritten to
// reference, the temporary copy of the column.
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn *sql.DB) error {
	tempColumn := TemporaryName(o.Column)

	stmt := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
		pq.QuoteIdentifier(o.Table),
		pq.QuoteIdentifier(CheckConstraintName(o.Table, tempColumn)),
		rewriteCheckExpression(o.Check, o.Column, tempColumn),
	)

	_, err := conn.ExecContext(ctx, stmt)
	return err
}
// rewriteCheckExpression returns check with references to oldColumn replaced
// by newColumn.
//
// In order for the `check` expression to be easy to write, migration authors
// specify it as though it were being applied to the old column. On migration
// start, however, the check is actually applied to the new (temporary)
// column, so the expression has to be rewritten.
//
// Only occurrences of oldColumn that stand alone as an identifier are
// replaced: a match that is directly preceded or followed by an identifier
// character (letter, digit, `_`, `$` or any non-ASCII byte) is part of a
// longer name such as `subtitle` or `title_len` and is left untouched. An
// empty oldColumn leaves check unchanged.
//
// The rewrite is still textual rather than a real SQL parse: matching is
// case-sensitive and occurrences inside string literals are rewritten too.
func rewriteCheckExpression(check string, oldColumn, newColumn string) string {
	// Nothing sensible can be matched; strings.ReplaceAll would otherwise
	// insert newColumn between every character.
	if oldColumn == "" {
		return check
	}

	isIdentByte := func(c byte) bool {
		return c == '_' || c == '$' || c >= 0x80 ||
			(c >= '0' && c <= '9') ||
			(c >= 'a' && c <= 'z') ||
			(c >= 'A' && c <= 'Z')
	}

	var out strings.Builder
	out.Grow(len(check) + len(newColumn))

	copied := 0 // check[:copied] has already been written to out
	search := 0 // offset from which the next candidate is searched
	for {
		i := strings.Index(check[search:], oldColumn)
		if i < 0 {
			break
		}
		start := search + i
		end := start + len(oldColumn)

		leftOK := start == 0 || !isIdentByte(check[start-1])
		rightOK := end == len(check) || !isIdentByte(check[end])
		if !leftOK || !rightOK {
			// Part of a longer identifier: skip past this candidate's first
			// byte and keep looking.
			search = start + 1
			continue
		}

		out.WriteString(check[copied:start])
		out.WriteString(newColumn)
		copied = end
		search = end
	}
	out.WriteString(check[copied:])

	return out.String()
}

View File

@ -0,0 +1,253 @@
package migrations_test
import (
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"github.com/xataio/pg-roll/pkg/migrations"
)
// TestSetCheckConstraint exercises the full lifecycle of the
// `set_check_constraint` operation: a `posts` table is created, then a
// `length(title) > 3` check is added to its `title` column. The hooks assert
// the observable state after start, after rollback and after complete.
func TestSetCheckConstraint(t *testing.T) {
t.Parallel()
ExecuteTests(t, TestCases{{
name: "add check constraint",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "posts",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "title",
Type: "text",
},
},
},
},
},
{
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpSetCheckConstraint{
Table: "posts",
Column: "title",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// The new (temporary) `title` column should exist on the underlying table.
ColumnMustExist(t, db, "public", "posts", migrations.TemporaryName("title"))
// Inserting a row that meets the check constraint into the old view works.
MustInsert(t, db, "public", "01_add_table", "posts", map[string]string{
"title": "post by alice",
})
// Inserting a row that does not meet the check constraint into the old view also works.
MustInsert(t, db, "public", "01_add_table", "posts", map[string]string{
"title": "b",
})
// Both rows have been backfilled into the new view; the short title has
// been rewritten using `up` SQL to meet the length constraint.
rows := MustSelect(t, db, "public", "02_add_check_constraint", "posts")
assert.Equal(t, []map[string]any{
{"id": 1, "title": "post by alice"},
{"id": 2, "title": "---b"},
}, rows)
// Inserting a row that meets the check constraint into the new view works.
MustInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{
"title": "post by carl",
})
// Inserting a row that does not meet the check constraint into the new view fails.
MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{
"title": "d",
})
// The row that was inserted into the new view has been backfilled into the old view.
rows = MustSelect(t, db, "public", "01_add_table", "posts")
assert.Equal(t, []map[string]any{
{"id": 1, "title": "post by alice"},
{"id": 2, "title": "b"},
{"id": 3, "title": "post by carl"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB) {
// The new (temporary) `title` column should not exist on the underlying table.
ColumnMustNotExist(t, db, "public", "posts", migrations.TemporaryName("title"))
// The up function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", "title"))
// The down function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", migrations.TemporaryName("title")))
// The up trigger no longer exists.
TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", "title"))
// The down trigger no longer exists.
TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", migrations.TemporaryName("title")))
},
afterComplete: func(t *testing.T, db *sql.DB) {
// Inserting a row that meets the check constraint into the new view works.
MustInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{
"title": "post by dana",
})
// Inserting a row that does not meet the check constraint into the new view fails.
MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{
"title": "e",
})
// The data in the new `posts` view is as expected.
// NOTE(review): the expected ids skip 4; presumably the rejected insert
// of "d" in afterStart consumed that value from the serial sequence —
// confirm.
rows := MustSelect(t, db, "public", "02_add_check_constraint", "posts")
assert.Equal(t, []map[string]any{
{"id": 1, "title": "post by alice"},
{"id": 2, "title": "---b"},
{"id": 3, "title": "post by carl"},
{"id": 5, "title": "post by dana"},
}, rows)
// The up function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", "title"))
// The down function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", migrations.TemporaryName("title")))
// The up trigger no longer exists.
TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", "title"))
// The down trigger no longer exists.
TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", migrations.TemporaryName("title")))
},
}})
}
// TestSetCheckConstraintValidation checks that starting a
// `set_check_constraint` migration fails with the expected error when the
// operation names a missing table or column, or omits one of the required
// `check`, `up` or `down` fields.
func TestSetCheckConstraintValidation(t *testing.T) {
	t.Parallel()

	// SQL fragments shared by the cases below.
	const (
		checkSQL = "length(title) > 3"
		upSQL    = "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)"
		downSQL  = "title"
	)

	createTableMigration := migrations.Migration{
		Name: "01_add_table",
		Operations: migrations.Operations{
			&migrations.OpCreateTable{
				Name: "posts",
				Columns: []migrations.Column{
					{
						Name:       "id",
						Type:       "serial",
						PrimaryKey: true,
					},
					{
						Name: "title",
						Type: "text",
					},
				},
			},
		},
	}

	// withOp builds the migration sequence used by every case: the table
	// creation migration followed by a migration containing only op.
	withOp := func(op *migrations.OpSetCheckConstraint) []migrations.Migration {
		return []migrations.Migration{
			createTableMigration,
			{
				Name:       "02_add_check_constraint",
				Operations: migrations.Operations{op},
			},
		}
	}

	ExecuteTests(t, TestCases{
		{
			name: "table must exist",
			migrations: withOp(&migrations.OpSetCheckConstraint{
				Table:  "doesntexist",
				Column: "title",
				Check:  checkSQL,
				Up:     upSQL,
				Down:   downSQL,
			}),
			wantStartErr: migrations.TableDoesNotExistError{Name: "doesntexist"},
		},
		{
			name: "column must exist",
			migrations: withOp(&migrations.OpSetCheckConstraint{
				Table:  "posts",
				Column: "doesntexist",
				Check:  checkSQL,
				Up:     upSQL,
				Down:   downSQL,
			}),
			wantStartErr: migrations.ColumnDoesNotExistError{Table: "posts", Name: "doesntexist"},
		},
		{
			name: "check SQL is mandatory",
			migrations: withOp(&migrations.OpSetCheckConstraint{
				Table:  "posts",
				Column: "title",
				Up:     upSQL,
				Down:   downSQL,
			}),
			wantStartErr: migrations.FieldRequiredError{Name: "check"},
		},
		{
			name: "up SQL is mandatory",
			migrations: withOp(&migrations.OpSetCheckConstraint{
				Table:  "posts",
				Column: "title",
				Check:  checkSQL,
				Down:   downSQL,
			}),
			wantStartErr: migrations.FieldRequiredError{Name: "up"},
		},
		{
			name: "down SQL is mandatory",
			migrations: withOp(&migrations.OpSetCheckConstraint{
				Table:  "posts",
				Column: "title",
				Check:  checkSQL,
				Up:     upSQL,
			}),
			wantStartErr: migrations.FieldRequiredError{Name: "down"},
		},
	})
}