From 425118423c44e6296d61301e207aa0ee30d8b548 Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Wed, 6 Sep 2023 05:42:59 +0100 Subject: [PATCH] Support creating foreign key constraints with the add column operation (#80) Allow the **add column** operation to create foreign key columns. An example of such an operation is: ```json { "name": "17_add_rating_column", "operations": [ { "add_column": { "table": "orders", "column": { "name": "user_id", "type": "integer", "references": { "table": "users", "column": "id", } } } } ] } ``` Most of the work to support the operation is in https://github.com/xataio/pg-roll/pull/79. * The constraint is added on `Start` (named according to the temporary name of the new column). * The entire new column, including the foreign key constraint, is removed on `Rollback`. * The constraint is renamed to use the final name of the new column on `Complete`. Test cases are included for both nullable and non-nullable FKs. --- pkg/migrations/op_add_column.go | 20 ++- pkg/migrations/op_add_column_test.go | 218 +++++++++++++++++++++++++++ 2 files changed, 232 insertions(+), 6 deletions(-) diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index 04010b2..0047539 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -96,9 +96,21 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB) error { if err != nil { return err } - } - return nil + + // Rename any foreign key constraint to use the final (non-temporary) column name. + if o.Column.References != nil { + tableRef := o.Column.References.Table + columnRef := o.Column.References.Column + + _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(ForeignKeyConstraintName(tempName, tableRef, columnRef)), + pq.QuoteIdentifier(ForeignKeyConstraintName(o.Column.Name, tableRef, columnRef)), + )) + } + + return err } func (o *OpAddColumn) Rollback(ctx context.Context, conn *sql.DB) error { @@ -130,10 +142,6 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error { return errors.New("adding primary key columns is not supported") } - if o.Column.References != nil { - return errors.New("adding foreign key columns is not supported") - } - return nil } diff --git a/pkg/migrations/op_add_column_test.go b/pkg/migrations/op_add_column_test.go index 5e777b6..830de81 100644 --- a/pkg/migrations/op_add_column_test.go +++ b/pkg/migrations/op_add_column_test.go @@ -105,6 +105,224 @@ func TestAddColumn(t *testing.T) { }}) } +func TestAddForeignKeyColumn(t *testing.T) { + t.Parallel() + + ExecuteTests(t, TestCases{ + { + name: "add nullable foreign key column", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "name", + Type: "varchar(255)", + Unique: true, + }, + }, + }, + &migrations.OpCreateTable{ + Name: "orders", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "quantity", + Type: "integer", + }, + }, + }, + }, + }, + { + Name: "02_add_column", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "orders", + Column: migrations.Column{ + Name: "user_id", + Type: "integer", + References: &migrations.ColumnReference{ + Table: "users", + Column: "id", + }, + Nullable: true, + }, + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // The foreign key constraint exists on the new table. + tempColumnName := migrations.TemporaryName("user_id") + constraintName := migrations.ForeignKeyConstraintName(tempColumnName, "users", "id") + ConstraintMustExist(t, db, "public", "orders", constraintName) + + // Inserting a row into the referenced table succeeds. + MustInsert(t, db, "public", "01_create_table", "users", map[string]string{ + "name": "alice", + }) + + // Inserting a row into the referencing table succeeds as the referenced row exists. + MustInsert(t, db, "public", "02_add_column", "orders", map[string]string{ + "user_id": "1", + "quantity": "100", + }) + + // Inserting a row into the referencing table fails as the referenced row does not exist. + MustNotInsert(t, db, "public", "02_create_table_with_fk", "orders", map[string]string{ + "user_id": "2", + "quantity": "200", + }) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + // The new column has been dropped, so the foreign key constraint is gone. + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // The foreign key constraint exists on the new table, using the final + // (non-temporary) name of the new column. + constraintName := migrations.ForeignKeyConstraintName("user_id", "users", "id") + ConstraintMustExist(t, db, "public", "orders", constraintName) + + // Inserting a row into the referenced table succeeds. + MustInsert(t, db, "public", "02_add_column", "users", map[string]string{ + "name": "bob", + }) + + // Inserting a row into the referencing table succeeds as the referenced row exists. + MustInsert(t, db, "public", "02_add_column", "orders", map[string]string{ + "user_id": "2", + "quantity": "200", + }) + + // Inserting a row into the referencing table fails as the referenced row does not exist. + MustNotInsert(t, db, "public", "02_add_column", "orders", map[string]string{ + "user_id": "3", + "quantity": "300", + }) + }, + }, + { + name: "add non-nullable foreign key column", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "name", + Type: "varchar(255)", + Unique: true, + }, + }, + }, + &migrations.OpCreateTable{ + Name: "orders", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "quantity", + Type: "integer", + }, + }, + }, + }, + }, + { + Name: "02_add_column", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "orders", + Column: migrations.Column{ + Name: "user_id", + Type: "integer", + References: &migrations.ColumnReference{ + Table: "users", + Column: "id", + }, + Nullable: false, + }, + Up: ptr("1"), + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // The foreign key constraint exists on the new table. + tempColumnName := migrations.TemporaryName("user_id") + constraintName := migrations.ForeignKeyConstraintName(tempColumnName, "users", "id") + ConstraintMustExist(t, db, "public", "orders", constraintName) + + // Inserting a row into the referenced table succeeds. + MustInsert(t, db, "public", "01_create_table", "users", map[string]string{ + "name": "alice", + }) + + // Inserting a row into the referencing table succeeds as the referenced row exists. + MustInsert(t, db, "public", "02_add_column", "orders", map[string]string{ + "user_id": "1", + "quantity": "100", + }) + + // Inserting a row into the referencing table fails as the referenced row does not exist. + MustNotInsert(t, db, "public", "02_create_table_with_fk", "orders", map[string]string{ + "user_id": "2", + "quantity": "200", + }) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + // The new column has been dropped, so the foreign key constraint is gone. + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // The foreign key constraint exists on the new table, using the final + // (non-temporary) name of the new column. + constraintName := migrations.ForeignKeyConstraintName("user_id", "users", "id") + ConstraintMustExist(t, db, "public", "orders", constraintName) + + // Inserting a row into the referenced table succeeds. + MustInsert(t, db, "public", "02_add_column", "users", map[string]string{ + "name": "bob", + }) + + // Inserting a row into the referencing table succeeds as the referenced row exists. + MustInsert(t, db, "public", "02_add_column", "orders", map[string]string{ + "user_id": "2", + "quantity": "200", + }) + + // Inserting a row into the referencing table fails as the referenced row does not exist. + MustNotInsert(t, db, "public", "02_add_column", "orders", map[string]string{ + "user_id": "3", + "quantity": "300", + }) + }, + }, + }) +} + func TestAddColumnWithUpSql(t *testing.T) { t.Parallel()