diff --git a/examples/26_add_column_with_check_constraint.json b/examples/26_add_column_with_check_constraint.json new file mode 100644 index 0000000..a38f9bd --- /dev/null +++ b/examples/26_add_column_with_check_constraint.json @@ -0,0 +1,19 @@ +{ + "name": "26_add_column_with_check_constraint", + "operations": [ + { + "add_column": { + "table": "people", + "column": { + "name": "age", + "type": "integer", + "default": "18", + "check": { + "name": "age_check", + "constraint": "age >= 18" + } + } + } + } + ] +} diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index c4d5721..d8b816f 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -27,6 +27,12 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema strin if !o.Column.Nullable && o.Column.Default == nil { if err := addNotNullConstraint(ctx, conn, o.Table, o.Column.Name, TemporaryName(o.Column.Name)); err != nil { + return fmt.Errorf("failed to add not null constraint: %w", err) + } + } + + if o.Column.Check != nil { + if err := o.addCheckConstraint(ctx, conn); err != nil { return fmt.Errorf("failed to add check constraint: %w", err) } } @@ -98,6 +104,15 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB) error { } } + if o.Column.Check != nil { + _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(o.Column.Check.Name))) + if err != nil { + return err + } + } + return err } @@ -136,6 +151,16 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error { } } + if o.Column.Check != nil { + if err := o.Column.Check.Validate(); err != nil { + return CheckConstraintError{ + Table: o.Table, + Column: o.Column.Name, + Err: err, + } + } + } + if !o.Column.Nullable && o.Column.Default == nil && o.Up == nil { return FieldRequiredError{Name: "up"} } @@ -159,8 +184,15 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table o.Column.Nullable = true } - o.Column.Name = TemporaryName(o.Column.Name) + // Don't add a column with a CHECK constraint directly. + // They are handled by: + // - adding the column without the constraint + // - adding a NOT VALID check constraint to the column + // - validating the constraint on migration completion + // This is to avoid unnecessary exclusive table locks. + o.Column.Check = nil + o.Column.Name = TemporaryName(o.Column.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", pq.QuoteIdentifier(t.Name), ColumnToSQL(o.Column), @@ -177,6 +209,15 @@ func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, phys return err } +func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn *sql.DB) error { + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(o.Column.Check.Name), + rewriteCheckExpression(o.Column.Check.Constraint, o.Column.Name, TemporaryName(o.Column.Name)), + )) + return err +} + func backFill(ctx context.Context, conn *sql.DB, table, column string) error { // touch rows without changing them in order to have the trigger fire // and set the value using the `up` SQL. diff --git a/pkg/migrations/op_add_column_test.go b/pkg/migrations/op_add_column_test.go index 22a4555..9e6b340 100644 --- a/pkg/migrations/op_add_column_test.go +++ b/pkg/migrations/op_add_column_test.go @@ -560,3 +560,78 @@ func TestAddColumnValidation(t *testing.T) { }, }) } + +func TestAddColumnWithCheckConstraint(t *testing.T) { + t.Parallel() + + ExecuteTests(t, TestCases{{ + name: "add column", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "name", + Type: "varchar(255)", + Unique: true, + }, + }, + }, + }, + }, + { + Name: "02_add_column", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "users", + Column: migrations.Column{ + Name: "age", + Type: "integer", + Default: ptr("18"), + Check: &migrations.CheckConstraint{ + Name: "age_check", + Constraint: "age >= 18", + }, + }, + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // Inserting a row that meets the constraint into the new view succeeds. + MustInsert(t, db, "public", "02_add_column", "users", map[string]string{ + "name": "alice", + "age": "30", + }) + + // Inserting a row that does not meet the constraint into the new view fails. + MustNotInsert(t, db, "public", "02_add_column", "users", map[string]string{ + "name": "bob", + "age": "3", + }) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // Inserting a row that meets the constraint into the new view succeeds. + MustInsert(t, db, "public", "02_add_column", "users", map[string]string{ + "name": "carl", + "age": "30", + }) + + // Inserting a row that does not meet the constraint into the new view fails. + MustNotInsert(t, db, "public", "02_add_column", "users", map[string]string{ + "name": "dana", + "age": "3", + }) + }, + }}) +}