Add support for adding a foreign key constraint to an existing column (#82)

Add support for adding a foreign key constraint to an existing column.
Such a migration looks like:

```json
{
  "name": "21_add_foreign_key_constraint",
  "operations": [
    {
      "set_foreign_key": {
        "table": "posts",
        "column": "user_id",
        "references": {
          "table": "users",
          "column": "id"
        },
        "up": "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
        "down": "user_id"
      }
    }
  ]
}
```

This migration adds a foreign key constraint to the `user_id` column in
the `posts` table, referencing the `id` column in the `users` table.

The implementation is similar to the **set not null** and **change
column type** operations:

* On `Start`:
* Create a new column, duplicating the one to which the FK constraint
should be added.
* The new column has the foreign key constraint added as `NOT VALID` to
avoid taking a long lived `SHARE ROW EXCLUSIVE` lock (see
[here](https://medium.com/paypal-tech/postgresql-at-scale-database-schema-changes-without-downtime-20d3749ed680#00dc)).
* Backfill the new column with values from the existing column,
rewriting values using the `up` SQL.
* Create a trigger to populate the new column when values are written to
the old column, converting values with `up`.
* Create a trigger to populate the old column when values are written to
the new column, converting values with `down`.
* On `Complete`
  * Validate the foreign key constraint.
  * Remove triggers
  * Drop the old column
  * Rename the new column to the old column name.
* Rename the foreign key constraint to be consistent with the new name
of the column.
* On `Rollback`
* Remove the new column and both triggers. Removing the new column also
removes the foreign key constraint on it.

The `up` SQL in this operation is critical. The old column does not have
a foreign key constraint imposed on it after `Start` as that would
violate the guarantee that `pg-roll` does not make changes to the
existing schema. The `up` SQL therefore needs to take into account that
not all rows inserted into the old schema will have a valid foreign key.
In the example `json` above, the `up` SQL ensures that values for which
there is no corresponding user in the `users` table result in `NULL`
values in the new column. Failure to do this would result in the old
schema failing to insert rows without a valid `user_id`. An alternative
would be to implement data quarantining for these values, as discussed
last week @exekias .
This commit is contained in:
Andrew Farries 2023-09-11 06:12:59 +01:00 committed by GitHub
parent 425118423c
commit c76ea9ce48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 590 additions and 12 deletions

View File

@ -0,0 +1,26 @@
{
"name": "20_create_posts_table",
"operations": [
{
"create_table": {
"name": "posts",
"columns": [
{
"name": "id",
"type": "serial",
"pk": true
},
{
"name": "title",
"type": "varchar(255)"
},
{
"name": "user_id",
"type": "integer",
"nullable": true
}
]
}
}
]
}

View File

@ -0,0 +1,17 @@
{
"name": "21_add_foreign_key_constraint",
"operations": [
{
"set_foreign_key": {
"table": "posts",
"column": "user_id",
"references": {
"table": "users",
"column": "id"
},
"up": "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
"down": "user_id"
}
}
]
}

View File

@ -11,18 +11,19 @@ import (
type OpName string
const (
OpNameCreateTable OpName = "create_table"
OpNameRenameTable OpName = "rename_table"
OpNameDropTable OpName = "drop_table"
OpNameAddColumn OpName = "add_column"
OpNameDropColumn OpName = "drop_column"
OpNameCreateIndex OpName = "create_index"
OpNameDropIndex OpName = "drop_index"
OpNameRenameColumn OpName = "rename_column"
OpNameSetUnique OpName = "set_unique"
OpNameSetNotNull OpName = "set_not_null"
OpNameChangeType OpName = "change_type"
OpRawSQLName OpName = "sql"
OpNameCreateTable OpName = "create_table"
OpNameRenameTable OpName = "rename_table"
OpNameDropTable OpName = "drop_table"
OpNameAddColumn OpName = "add_column"
OpNameDropColumn OpName = "drop_column"
OpNameCreateIndex OpName = "create_index"
OpNameDropIndex OpName = "drop_index"
OpNameRenameColumn OpName = "rename_column"
OpNameSetUnique OpName = "set_unique"
OpNameSetNotNull OpName = "set_not_null"
OpNameSetForeignKey OpName = "set_foreign_key"
OpNameChangeType OpName = "change_type"
OpRawSQLName OpName = "sql"
)
func TemporaryName(name string) string {
@ -108,6 +109,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error {
case OpNameSetNotNull:
item = &OpSetNotNull{}
case OpNameSetForeignKey:
item = &OpSetForeignKey{}
case OpNameChangeType:
item = &OpChangeType{}
@ -188,6 +192,9 @@ func OperationName(op Operation) OpName {
case *OpSetNotNull:
return OpNameSetNotNull
case *OpSetForeignKey:
return OpNameSetForeignKey
case *OpChangeType:
return OpNameChangeType

216
pkg/migrations/op_set_fk.go Normal file
View File

@ -0,0 +1,216 @@
package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pg-roll/pkg/schema"
)
type OpSetForeignKey struct {
Table string `json:"table"`
Column string `json:"column"`
References ColumnReference `json:"references"`
Up string `json:"up"`
Down string `json:"down"`
}
var _ Operation = (*OpSetForeignKey)(nil)
func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema) error {
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)
// Create a copy of the column on the underlying table.
if err := duplicateColumn(ctx, conn, table, *column); err != nil {
return fmt.Errorf("failed to duplicate column: %w", err)
}
// Create a NOT VALID foreign key constraint on the new column.
if err := o.addForeignKeyConstraint(ctx, conn); err != nil {
return fmt.Errorf("failed to add foreign key constraint: %w", err)
}
// Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL.
err := createTrigger(ctx, conn, triggerConfig{
Name: TriggerName(o.Table, o.Column),
Direction: TriggerDirectionUp,
Columns: table.Columns,
SchemaName: s.Name,
TableName: o.Table,
PhysicalColumn: TemporaryName(o.Column),
StateSchema: stateSchema,
SQL: o.Up,
})
if err != nil {
return fmt.Errorf("failed to create up trigger: %w", err)
}
// Backfill the new column with values from the old column.
if err := backFill(ctx, conn, o.Table, TemporaryName(o.Column)); err != nil {
return fmt.Errorf("failed to backfill column: %w", err)
}
// Add the new column to the internal schema representation. This is done
// here, before creation of the down trigger, so that the trigger can declare
// a variable for the new column.
table.AddColumn(o.Column, schema.Column{
Name: TemporaryName(o.Column),
})
// Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL.
err = createTrigger(ctx, conn, triggerConfig{
Name: TriggerName(o.Table, TemporaryName(o.Column)),
Direction: TriggerDirectionDown,
Columns: table.Columns,
SchemaName: s.Name,
TableName: o.Table,
PhysicalColumn: o.Column,
StateSchema: stateSchema,
SQL: o.Down,
})
if err != nil {
return fmt.Errorf("failed to create down trigger: %w", err)
}
return nil
}
func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB) error {
tempName := TemporaryName(o.Column)
tableRef := o.References.Table
columnRef := o.References.Column
// Validate the foreign key constraint
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(ForeignKeyConstraintName(tempName, tableRef, columnRef))))
if err != nil {
return err
}
// Remove the up function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
if err != nil {
return err
}
// Remove the down function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Column)))))
if err != nil {
return err
}
// Drop the old column
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column)))
if err != nil {
return err
}
// Rename the new column to the old column name
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(TemporaryName(o.Column)),
pq.QuoteIdentifier(o.Column)))
if err != nil {
return err
}
// Rename the foreign key constraint to use the final (non-temporary) column name.
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(ForeignKeyConstraintName(tempName, tableRef, columnRef)),
pq.QuoteIdentifier(ForeignKeyConstraintName(o.Column, tableRef, columnRef)),
))
return err
}
func (o *OpSetForeignKey) Rollback(ctx context.Context, conn *sql.DB) error {
// Drop the new column, taking the constraint on the column with it
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(TemporaryName(o.Column)),
))
if err != nil {
return err
}
// Remove the up function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)),
))
if err != nil {
return err
}
// Remove the down function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Column))),
))
return err
}
func (o *OpSetForeignKey) Validate(ctx context.Context, s *schema.Schema) error {
table := s.GetTable(o.Table)
if table == nil {
return TableDoesNotExistError{Name: o.Table}
}
column := table.GetColumn(o.Column)
if column == nil {
return ColumnDoesNotExistError{Table: o.Table, Name: o.Column}
}
refTable := s.GetTable(o.References.Table)
if refTable == nil {
return ColumnReferenceError{
Table: o.Table,
Column: o.Column,
Err: TableDoesNotExistError{Name: o.References.Table},
}
}
refColumn := refTable.GetColumn(o.References.Column)
if refColumn == nil {
return ColumnReferenceError{
Table: o.Table,
Column: o.Column,
Err: ColumnDoesNotExistError{
Table: o.References.Table,
Name: o.References.Column,
},
}
}
if o.Up == "" {
return FieldRequiredError{Name: "up"}
}
if o.Down == "" {
return FieldRequiredError{Name: "down"}
}
return nil
}
func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn *sql.DB) error {
tempColumnName := TemporaryName(o.Column)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(ForeignKeyConstraintName(tempColumnName, o.References.Table, o.References.Column)),
pq.QuoteIdentifier(tempColumnName),
pq.QuoteIdentifier(o.References.Table),
pq.QuoteIdentifier(o.References.Column),
))
return err
}

View File

@ -0,0 +1,312 @@
package migrations_test
import (
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"github.com/xataio/pg-roll/pkg/migrations"
)
func TestSetForeignKey(t *testing.T) {
t.Parallel()
ExecuteTests(t, TestCases{{
name: "add foreign key constraint",
migrations: []migrations.Migration{
{
Name: "01_add_tables",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "name",
Type: "text",
},
},
},
&migrations.OpCreateTable{
Name: "posts",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "title",
Type: "text",
},
{
Name: "user_id",
Type: "integer",
},
},
},
},
},
{
Name: "02_add_fk_constraint",
Operations: migrations.Operations{
&migrations.OpSetForeignKey{
Table: "posts",
Column: "user_id",
References: migrations.ColumnReference{
Table: "users",
Column: "id",
},
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
Down: "user_id",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// The new (temporary) `user_id` column should exist on the underlying table.
ColumnMustExist(t, db, "public", "posts", migrations.TemporaryName("user_id"))
// Inserting some data into the `users` table works.
MustInsert(t, db, "public", "02_add_fk_constraint", "users", map[string]string{
"name": "alice",
})
MustInsert(t, db, "public", "02_add_fk_constraint", "users", map[string]string{
"name": "bob",
})
// Inserting data into the new `posts` view with a valid user reference works.
MustInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{
"title": "post by alice",
"user_id": "1",
})
// Inserting data into the new `posts` view with an invalid user reference fails.
MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{
"title": "post by unknown user",
"user_id": "3",
})
// The post that was inserted successfully has been backfilled into the old view.
rows := MustSelect(t, db, "public", "01_add_tables", "posts")
assert.Equal(t, []map[string]any{
{"id": 1, "title": "post by alice", "user_id": 1},
}, rows)
// Inserting data into the old `posts` view with a valid user reference works.
MustInsert(t, db, "public", "01_add_tables", "posts", map[string]string{
"title": "post by bob",
"user_id": "2",
})
// Inserting data into the old `posts` view with an invalid user reference also works.
MustInsert(t, db, "public", "01_add_tables", "posts", map[string]string{
"title": "post by unknown user",
"user_id": "3",
})
// The post that was inserted successfully has been backfilled into the new view.
// The post by an unknown user has been backfilled with a NULL user_id.
rows = MustSelect(t, db, "public", "02_add_fk_constraint", "posts")
assert.Equal(t, []map[string]any{
{"id": 1, "title": "post by alice", "user_id": 1},
{"id": 3, "title": "post by bob", "user_id": 2},
{"id": 4, "title": "post by unknown user", "user_id": nil},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB) {
// The new (temporary) `user_id` column should not exist on the underlying table.
ColumnMustNotExist(t, db, "public", "posts", migrations.TemporaryName("user_id"))
// The up function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", "user_id"))
// The down function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", migrations.TemporaryName("user_id")))
// The up trigger no longer exists.
TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", "user_id"))
// The down trigger no longer exists.
TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", migrations.TemporaryName("user_id")))
},
afterComplete: func(t *testing.T, db *sql.DB) {
// The new (temporary) `user_id` column should not exist on the underlying table.
ColumnMustNotExist(t, db, "public", "posts", migrations.TemporaryName("user_id"))
// Inserting data into the new `posts` view with a valid user reference works.
MustInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{
"title": "another post by alice",
"user_id": "1",
})
// Inserting data into the new `posts` view with an invalid user reference fails.
MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{
"title": "post by unknown user",
"user_id": "3",
})
// The data in the new `posts` view is as expected.
rows := MustSelect(t, db, "public", "02_add_fk_constraint", "posts")
assert.Equal(t, []map[string]any{
{"id": 1, "title": "post by alice", "user_id": 1},
{"id": 3, "title": "post by bob", "user_id": 2},
{"id": 4, "title": "post by unknown user", "user_id": nil},
{"id": 5, "title": "another post by alice", "user_id": 1},
}, rows)
// The up function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", "user_id"))
// The down function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", migrations.TemporaryName("user_id")))
// The up trigger no longer exists.
TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", "user_id"))
// The down trigger no longer exists.
TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", migrations.TemporaryName("user_id")))
},
}})
}
func TestSetForeignKeyValidation(t *testing.T) {
t.Parallel()
createTablesMigration := migrations.Migration{
Name: "01_add_tables",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "name",
Type: "text",
},
},
},
&migrations.OpCreateTable{
Name: "posts",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "title",
Type: "text",
},
{
Name: "user_id",
Type: "integer",
},
},
},
},
}
ExecuteTests(t, TestCases{
{
name: "table must exist",
migrations: []migrations.Migration{
createTablesMigration,
{
Name: "02_add_fk_constraint",
Operations: migrations.Operations{
&migrations.OpSetForeignKey{
Table: "doesntexist",
Column: "user_id",
References: migrations.ColumnReference{
Table: "users",
Column: "id",
},
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
Down: "user_id",
},
},
},
},
wantStartErr: migrations.TableDoesNotExistError{Name: "doesntexist"},
},
{
name: "column must exist",
migrations: []migrations.Migration{
createTablesMigration,
{
Name: "02_add_fk_constraint",
Operations: migrations.Operations{
&migrations.OpSetForeignKey{
Table: "posts",
Column: "doesntexist",
References: migrations.ColumnReference{
Table: "users",
Column: "id",
},
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
Down: "user_id",
},
},
},
},
wantStartErr: migrations.ColumnDoesNotExistError{Table: "posts", Name: "doesntexist"},
},
{
name: "referenced table must exist",
migrations: []migrations.Migration{
createTablesMigration,
{
Name: "02_add_fk_constraint",
Operations: migrations.Operations{
&migrations.OpSetForeignKey{
Table: "posts",
Column: "user_id",
References: migrations.ColumnReference{
Table: "doesntexist",
Column: "id",
},
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
Down: "user_id",
},
},
},
},
wantStartErr: migrations.ColumnReferenceError{
Table: "posts",
Column: "user_id",
Err: migrations.TableDoesNotExistError{Name: "doesntexist"},
},
},
{
name: "referenced column must exist",
migrations: []migrations.Migration{
createTablesMigration,
{
Name: "02_add_fk_constraint",
Operations: migrations.Operations{
&migrations.OpSetForeignKey{
Table: "posts",
Column: "user_id",
References: migrations.ColumnReference{
Table: "users",
Column: "doesntexist",
},
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
Down: "user_id",
},
},
},
},
wantStartErr: migrations.ColumnReferenceError{
Table: "posts",
Column: "user_id",
Err: migrations.ColumnDoesNotExistError{Table: "users", Name: "doesntexist"},
},
},
})
}