From c76ea9ce4840dc9297ea966079a2bb67f5f44c19 Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Mon, 11 Sep 2023 06:12:59 +0100 Subject: [PATCH] Add support for adding a foreign key constraint to an existing column (#82) Add support for adding a foreign key constraint to an existing column. Such a migration looks like: ```json { "name": "21_add_foreign_key_constraint", "operations": [ { "set_foreign_key": { "table": "posts", "column": "user_id", "references": { "table": "users", "column": "id" }, "up": "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", "down": "user_id" } } ] } ``` This migration adds a foreign key constraint to the `user_id` column in the `posts` table, referencing the `id` column in the `users` table. The implementation is similar to the **set not null** and **change column type** operations: * On `Start`: * Create a new column, duplicating the one to which the FK constraint should be added. * The new column has the foreign key constraint added as `NOT VALID` to avoid taking a long-lived `SHARE ROW EXCLUSIVE` lock (see [here](https://medium.com/paypal-tech/postgresql-at-scale-database-schema-changes-without-downtime-20d3749ed680#00dc)). * Backfill the new column with values from the existing column, rewriting values using the `up` SQL. * Create a trigger to populate the new column when values are written to the old column, converting values with `up`. * Create a trigger to populate the old column when values are written to the new column, converting values with `down`. * On `Complete`: * Validate the foreign key constraint. * Remove the triggers. * Drop the old column. * Rename the new column to the old column name. * Rename the foreign key constraint to be consistent with the new name of the column. * On `Rollback`: * Remove the new column and both triggers. Removing the new column also removes the foreign key constraint on it. The `up` SQL in this operation is critical. 
The old column does not have a foreign key constraint imposed on it after `Start` as that would violate the guarantee that `pg-roll` does not make changes to the existing schema. The `up` SQL therefore needs to take into account that not all rows inserted into the old schema will have a valid foreign key. In the example JSON above, the `up` SQL ensures that values for which there is no corresponding user in the `users` table result in `NULL` values in the new column. Failure to do this would cause inserts through the old schema to fail for rows without a valid `user_id`. An alternative would be to implement data quarantining for these values, as discussed last week with @exekias. --- examples/20_create_posts_table.json | 26 ++ examples/21_add_foreign_key_constraint.json | 17 ++ pkg/migrations/op_common.go | 31 +- pkg/migrations/op_set_fk.go | 216 ++++++++++++++ pkg/migrations/op_set_fk_test.go | 312 ++++++++++++++++++++ 5 files changed, 590 insertions(+), 12 deletions(-) create mode 100644 examples/20_create_posts_table.json create mode 100644 examples/21_add_foreign_key_constraint.json create mode 100644 pkg/migrations/op_set_fk.go create mode 100644 pkg/migrations/op_set_fk_test.go diff --git a/examples/20_create_posts_table.json b/examples/20_create_posts_table.json new file mode 100644 index 0000000..df0e28f --- /dev/null +++ b/examples/20_create_posts_table.json @@ -0,0 +1,26 @@ +{ + "name": "20_create_posts_table", + "operations": [ + { + "create_table": { + "name": "posts", + "columns": [ + { + "name": "id", + "type": "serial", + "pk": true + }, + { + "name": "title", + "type": "varchar(255)" + }, + { + "name": "user_id", + "type": "integer", + "nullable": true + } + ] + } + } + ] +} diff --git a/examples/21_add_foreign_key_constraint.json b/examples/21_add_foreign_key_constraint.json new file mode 100644 index 0000000..19ca1a5 --- /dev/null +++ b/examples/21_add_foreign_key_constraint.json @@ -0,0 +1,17 @@ +{ + "name": "21_add_foreign_key_constraint", + 
"operations": [ + { + "set_foreign_key": { + "table": "posts", + "column": "user_id", + "references": { + "table": "users", + "column": "id" + }, + "up": "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", + "down": "user_id" + } + } + ] +} diff --git a/pkg/migrations/op_common.go b/pkg/migrations/op_common.go index 36cf1cf..3e16ab4 100644 --- a/pkg/migrations/op_common.go +++ b/pkg/migrations/op_common.go @@ -11,18 +11,19 @@ import ( type OpName string const ( - OpNameCreateTable OpName = "create_table" - OpNameRenameTable OpName = "rename_table" - OpNameDropTable OpName = "drop_table" - OpNameAddColumn OpName = "add_column" - OpNameDropColumn OpName = "drop_column" - OpNameCreateIndex OpName = "create_index" - OpNameDropIndex OpName = "drop_index" - OpNameRenameColumn OpName = "rename_column" - OpNameSetUnique OpName = "set_unique" - OpNameSetNotNull OpName = "set_not_null" - OpNameChangeType OpName = "change_type" - OpRawSQLName OpName = "sql" + OpNameCreateTable OpName = "create_table" + OpNameRenameTable OpName = "rename_table" + OpNameDropTable OpName = "drop_table" + OpNameAddColumn OpName = "add_column" + OpNameDropColumn OpName = "drop_column" + OpNameCreateIndex OpName = "create_index" + OpNameDropIndex OpName = "drop_index" + OpNameRenameColumn OpName = "rename_column" + OpNameSetUnique OpName = "set_unique" + OpNameSetNotNull OpName = "set_not_null" + OpNameSetForeignKey OpName = "set_foreign_key" + OpNameChangeType OpName = "change_type" + OpRawSQLName OpName = "sql" ) func TemporaryName(name string) string { @@ -108,6 +109,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error { case OpNameSetNotNull: item = &OpSetNotNull{} + case OpNameSetForeignKey: + item = &OpSetForeignKey{} + case OpNameChangeType: item = &OpChangeType{} @@ -188,6 +192,9 @@ func OperationName(op Operation) OpName { case *OpSetNotNull: return OpNameSetNotNull + case *OpSetForeignKey: + return OpNameSetForeignKey + case 
*OpChangeType: return OpNameChangeType diff --git a/pkg/migrations/op_set_fk.go b/pkg/migrations/op_set_fk.go new file mode 100644 index 0000000..e727965 --- /dev/null +++ b/pkg/migrations/op_set_fk.go @@ -0,0 +1,216 @@ +package migrations + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/xataio/pg-roll/pkg/schema" +) + +type OpSetForeignKey struct { + Table string `json:"table"` + Column string `json:"column"` + References ColumnReference `json:"references"` + Up string `json:"up"` + Down string `json:"down"` +} + +var _ Operation = (*OpSetForeignKey)(nil) + +func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema) error { + table := s.GetTable(o.Table) + column := table.GetColumn(o.Column) + + // Create a copy of the column on the underlying table. + if err := duplicateColumn(ctx, conn, table, *column); err != nil { + return fmt.Errorf("failed to duplicate column: %w", err) + } + + // Create a NOT VALID foreign key constraint on the new column. + if err := o.addForeignKeyConstraint(ctx, conn); err != nil { + return fmt.Errorf("failed to add foreign key constraint: %w", err) + } + + // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. + err := createTrigger(ctx, conn, triggerConfig{ + Name: TriggerName(o.Table, o.Column), + Direction: TriggerDirectionUp, + Columns: table.Columns, + SchemaName: s.Name, + TableName: o.Table, + PhysicalColumn: TemporaryName(o.Column), + StateSchema: stateSchema, + SQL: o.Up, + }) + if err != nil { + return fmt.Errorf("failed to create up trigger: %w", err) + } + + // Backfill the new column with values from the old column. + if err := backFill(ctx, conn, o.Table, TemporaryName(o.Column)); err != nil { + return fmt.Errorf("failed to backfill column: %w", err) + } + + // Add the new column to the internal schema representation. 
This is done + // here, before creation of the down trigger, so that the trigger can declare + // a variable for the new column. + table.AddColumn(o.Column, schema.Column{ + Name: TemporaryName(o.Column), + }) + + // Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL. + err = createTrigger(ctx, conn, triggerConfig{ + Name: TriggerName(o.Table, TemporaryName(o.Column)), + Direction: TriggerDirectionDown, + Columns: table.Columns, + SchemaName: s.Name, + TableName: o.Table, + PhysicalColumn: o.Column, + StateSchema: stateSchema, + SQL: o.Down, + }) + if err != nil { + return fmt.Errorf("failed to create down trigger: %w", err) + } + + return nil +} + +func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB) error { + tempName := TemporaryName(o.Column) + tableRef := o.References.Table + columnRef := o.References.Column + + // Validate the foreign key constraint + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(ForeignKeyConstraintName(tempName, tableRef, columnRef)))) + if err != nil { + return err + } + + // Remove the up function and trigger + _, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", + pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)))) + if err != nil { + return err + } + + // Remove the down function and trigger + _, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", + pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Column))))) + if err != nil { + return err + } + + // Drop the old column + _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(o.Column))) + if err != nil { + return err + } + + // Rename the new column to the old column name + _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN 
%s TO %s", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(TemporaryName(o.Column)), + pq.QuoteIdentifier(o.Column))) + if err != nil { + return err + } + + // Rename the foreign key constraint to use the final (non-temporary) column name. + _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(ForeignKeyConstraintName(tempName, tableRef, columnRef)), + pq.QuoteIdentifier(ForeignKeyConstraintName(o.Column, tableRef, columnRef)), + )) + + return err +} + +func (o *OpSetForeignKey) Rollback(ctx context.Context, conn *sql.DB) error { + // Drop the new column, taking the constraint on the column with it + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(TemporaryName(o.Column)), + )) + if err != nil { + return err + } + + // Remove the up function and trigger + _, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", + pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)), + )) + if err != nil { + return err + } + + // Remove the down function and trigger + _, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", + pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Column))), + )) + + return err +} + +func (o *OpSetForeignKey) Validate(ctx context.Context, s *schema.Schema) error { + table := s.GetTable(o.Table) + if table == nil { + return TableDoesNotExistError{Name: o.Table} + } + + column := table.GetColumn(o.Column) + if column == nil { + return ColumnDoesNotExistError{Table: o.Table, Name: o.Column} + } + + refTable := s.GetTable(o.References.Table) + if refTable == nil { + return ColumnReferenceError{ + Table: o.Table, + Column: o.Column, + Err: TableDoesNotExistError{Name: o.References.Table}, + } + } + + refColumn := refTable.GetColumn(o.References.Column) + if refColumn == nil { + return 
ColumnReferenceError{ + Table: o.Table, + Column: o.Column, + Err: ColumnDoesNotExistError{ + Table: o.References.Table, + Name: o.References.Column, + }, + } + } + + if o.Up == "" { + return FieldRequiredError{Name: "up"} + } + + if o.Down == "" { + return FieldRequiredError{Name: "down"} + } + + return nil +} + +func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn *sql.DB) error { + tempColumnName := TemporaryName(o.Column) + + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) NOT VALID", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(ForeignKeyConstraintName(tempColumnName, o.References.Table, o.References.Column)), + pq.QuoteIdentifier(tempColumnName), + pq.QuoteIdentifier(o.References.Table), + pq.QuoteIdentifier(o.References.Column), + )) + + return err +} diff --git a/pkg/migrations/op_set_fk_test.go b/pkg/migrations/op_set_fk_test.go new file mode 100644 index 0000000..f58219a --- /dev/null +++ b/pkg/migrations/op_set_fk_test.go @@ -0,0 +1,312 @@ +package migrations_test + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xataio/pg-roll/pkg/migrations" +) + +func TestSetForeignKey(t *testing.T) { + t.Parallel() + + ExecuteTests(t, TestCases{{ + name: "add foreign key constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_tables", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "name", + Type: "text", + }, + }, + }, + &migrations.OpCreateTable{ + Name: "posts", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "title", + Type: "text", + }, + { + Name: "user_id", + Type: "integer", + }, + }, + }, + }, + }, + { + Name: "02_add_fk_constraint", + Operations: migrations.Operations{ + &migrations.OpSetForeignKey{ + Table: 
"posts", + Column: "user_id", + References: migrations.ColumnReference{ + Table: "users", + Column: "id", + }, + Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", + Down: "user_id", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // The new (temporary) `user_id` column should exist on the underlying table. + ColumnMustExist(t, db, "public", "posts", migrations.TemporaryName("user_id")) + + // Inserting some data into the `users` table works. + MustInsert(t, db, "public", "02_add_fk_constraint", "users", map[string]string{ + "name": "alice", + }) + MustInsert(t, db, "public", "02_add_fk_constraint", "users", map[string]string{ + "name": "bob", + }) + + // Inserting data into the new `posts` view with a valid user reference works. + MustInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{ + "title": "post by alice", + "user_id": "1", + }) + + // Inserting data into the new `posts` view with an invalid user reference fails. + MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{ + "title": "post by unknown user", + "user_id": "3", + }) + + // The post that was inserted successfully has been backfilled into the old view. + rows := MustSelect(t, db, "public", "01_add_tables", "posts") + assert.Equal(t, []map[string]any{ + {"id": 1, "title": "post by alice", "user_id": 1}, + }, rows) + + // Inserting data into the old `posts` view with a valid user reference works. + MustInsert(t, db, "public", "01_add_tables", "posts", map[string]string{ + "title": "post by bob", + "user_id": "2", + }) + + // Inserting data into the old `posts` view with an invalid user reference also works. + MustInsert(t, db, "public", "01_add_tables", "posts", map[string]string{ + "title": "post by unknown user", + "user_id": "3", + }) + + // The post that was inserted successfully has been backfilled into the new view. 
+ // The post by an unknown user has been backfilled with a NULL user_id. + rows = MustSelect(t, db, "public", "02_add_fk_constraint", "posts") + assert.Equal(t, []map[string]any{ + {"id": 1, "title": "post by alice", "user_id": 1}, + {"id": 3, "title": "post by bob", "user_id": 2}, + {"id": 4, "title": "post by unknown user", "user_id": nil}, + }, rows) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + // The new (temporary) `user_id` column should not exist on the underlying table. + ColumnMustNotExist(t, db, "public", "posts", migrations.TemporaryName("user_id")) + + // The up function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", "user_id")) + // The down function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", migrations.TemporaryName("user_id"))) + + // The up trigger no longer exists. + TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", "user_id")) + // The down trigger no longer exists. + TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", migrations.TemporaryName("user_id"))) + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // The new (temporary) `user_id` column should not exist on the underlying table. + ColumnMustNotExist(t, db, "public", "posts", migrations.TemporaryName("user_id")) + + // Inserting data into the new `posts` view with a valid user reference works. + MustInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{ + "title": "another post by alice", + "user_id": "1", + }) + + // Inserting data into the new `posts` view with an invalid user reference fails. + MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{ + "title": "post by unknown user", + "user_id": "3", + }) + + // The data in the new `posts` view is as expected. 
+ rows := MustSelect(t, db, "public", "02_add_fk_constraint", "posts") + assert.Equal(t, []map[string]any{ + {"id": 1, "title": "post by alice", "user_id": 1}, + {"id": 3, "title": "post by bob", "user_id": 2}, + {"id": 4, "title": "post by unknown user", "user_id": nil}, + {"id": 5, "title": "another post by alice", "user_id": 1}, + }, rows) + + // The up function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", "user_id")) + // The down function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", migrations.TemporaryName("user_id"))) + + // The up trigger no longer exists. + TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", "user_id")) + // The down trigger no longer exists. + TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", migrations.TemporaryName("user_id"))) + }, + }}) +} + +func TestSetForeignKeyValidation(t *testing.T) { + t.Parallel() + + createTablesMigration := migrations.Migration{ + Name: "01_add_tables", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "name", + Type: "text", + }, + }, + }, + &migrations.OpCreateTable{ + Name: "posts", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "title", + Type: "text", + }, + { + Name: "user_id", + Type: "integer", + }, + }, + }, + }, + } + + ExecuteTests(t, TestCases{ + { + name: "table must exist", + migrations: []migrations.Migration{ + createTablesMigration, + { + Name: "02_add_fk_constraint", + Operations: migrations.Operations{ + &migrations.OpSetForeignKey{ + Table: "doesntexist", + Column: "user_id", + References: migrations.ColumnReference{ + Table: "users", + Column: "id", + }, + Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN 
user_id ELSE NULL END)", + Down: "user_id", + }, + }, + }, + }, + wantStartErr: migrations.TableDoesNotExistError{Name: "doesntexist"}, + }, + { + name: "column must exist", + migrations: []migrations.Migration{ + createTablesMigration, + { + Name: "02_add_fk_constraint", + Operations: migrations.Operations{ + &migrations.OpSetForeignKey{ + Table: "posts", + Column: "doesntexist", + References: migrations.ColumnReference{ + Table: "users", + Column: "id", + }, + Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", + Down: "user_id", + }, + }, + }, + }, + wantStartErr: migrations.ColumnDoesNotExistError{Table: "posts", Name: "doesntexist"}, + }, + { + name: "referenced table must exist", + migrations: []migrations.Migration{ + createTablesMigration, + { + Name: "02_add_fk_constraint", + Operations: migrations.Operations{ + &migrations.OpSetForeignKey{ + Table: "posts", + Column: "user_id", + References: migrations.ColumnReference{ + Table: "doesntexist", + Column: "id", + }, + Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", + Down: "user_id", + }, + }, + }, + }, + wantStartErr: migrations.ColumnReferenceError{ + Table: "posts", + Column: "user_id", + Err: migrations.TableDoesNotExistError{Name: "doesntexist"}, + }, + }, + { + name: "referenced column must exist", + migrations: []migrations.Migration{ + createTablesMigration, + { + Name: "02_add_fk_constraint", + Operations: migrations.Operations{ + &migrations.OpSetForeignKey{ + Table: "posts", + Column: "user_id", + References: migrations.ColumnReference{ + Table: "users", + Column: "doesntexist", + }, + Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", + Down: "user_id", + }, + }, + }, + }, + wantStartErr: migrations.ColumnReferenceError{ + Table: "posts", + Column: "user_id", + Err: migrations.ColumnDoesNotExistError{Table: "users", Name: 
"doesntexist"}, + }, + }, + }) +}