Allow columns with CHECK constraints on add column operations (#109)

Allow columns with `CHECK` constraints in `add_column` operations:

```json
{
  "name": "26_add_column_with_check_constraint",
  "operations": [
    {
      "add_column": {
        "table": "people",
        "column": {
          "name": "age",
          "type": "integer",
          "default": "18",
          "check": {
            "name": "age_check",
            "constraint": "age >= 18"
          }
        }
      }
    }
  ]
}
```
This commit is contained in:
Andrew Farries 2023-09-20 09:52:22 +01:00 committed by GitHub
parent 200529d5a3
commit 947b239b05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 136 additions and 1 deletions

View File

@ -0,0 +1,19 @@
{
"name": "26_add_column_with_check_constraint",
"operations": [
{
"add_column": {
"table": "people",
"column": {
"name": "age",
"type": "integer",
"default": "18",
"check": {
"name": "age_check",
"constraint": "age >= 18"
}
}
}
}
]
}

View File

@ -27,6 +27,12 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema strin
if !o.Column.Nullable && o.Column.Default == nil {
if err := addNotNullConstraint(ctx, conn, o.Table, o.Column.Name, TemporaryName(o.Column.Name)); err != nil {
return fmt.Errorf("failed to add not null constraint: %w", err)
}
}
if o.Column.Check != nil {
if err := o.addCheckConstraint(ctx, conn); err != nil {
return fmt.Errorf("failed to add check constraint: %w", err)
}
}
@ -98,6 +104,15 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB) error {
}
}
if o.Column.Check != nil {
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column.Check.Name)))
if err != nil {
return err
}
}
return err
}
@ -136,6 +151,16 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
}
}
if o.Column.Check != nil {
if err := o.Column.Check.Validate(); err != nil {
return CheckConstraintError{
Table: o.Table,
Column: o.Column.Name,
Err: err,
}
}
}
if !o.Column.Nullable && o.Column.Default == nil && o.Up == nil {
return FieldRequiredError{Name: "up"}
}
@ -159,8 +184,15 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table
o.Column.Nullable = true
}
o.Column.Name = TemporaryName(o.Column.Name)
// Don't add a column with a CHECK constraint directly.
// They are handled by:
// - adding the column without the constraint
// - adding a NOT VALID check constraint to the column
// - validating the constraint on migration completion
// This is to avoid unnecessary exclusive table locks.
o.Column.Check = nil
o.Column.Name = TemporaryName(o.Column.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s",
pq.QuoteIdentifier(t.Name),
ColumnToSQL(o.Column),
@ -177,6 +209,15 @@ func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, phys
return err
}
func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn *sql.DB) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column.Check.Name),
rewriteCheckExpression(o.Column.Check.Constraint, o.Column.Name, TemporaryName(o.Column.Name)),
))
return err
}
func backFill(ctx context.Context, conn *sql.DB, table, column string) error {
// touch rows without changing them in order to have the trigger fire
// and set the value using the `up` SQL.

View File

@ -560,3 +560,78 @@ func TestAddColumnValidation(t *testing.T) {
},
})
}
func TestAddColumnWithCheckConstraint(t *testing.T) {
t.Parallel()
ExecuteTests(t, TestCases{{
name: "add column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "name",
Type: "varchar(255)",
Unique: true,
},
},
},
},
},
{
Name: "02_add_column",
Operations: migrations.Operations{
&migrations.OpAddColumn{
Table: "users",
Column: migrations.Column{
Name: "age",
Type: "integer",
Default: ptr("18"),
Check: &migrations.CheckConstraint{
Name: "age_check",
Constraint: "age >= 18",
},
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// Inserting a row that meets the constraint into the new view succeeds.
MustInsert(t, db, "public", "02_add_column", "users", map[string]string{
"name": "alice",
"age": "30",
})
// Inserting a row that does not meet the constraint into the new view fails.
MustNotInsert(t, db, "public", "02_add_column", "users", map[string]string{
"name": "bob",
"age": "3",
})
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// Inserting a row that meets the constraint into the new view succeeds.
MustInsert(t, db, "public", "02_add_column", "users", map[string]string{
"name": "carl",
"age": "30",
})
// Inserting a row that does not meet the constraint into the new view fails.
MustNotInsert(t, db, "public", "02_add_column", "users", map[string]string{
"name": "dana",
"age": "3",
})
},
}})
}