Implement 'drop column' migrations (#48)

Implement the **drop column** migration operation.

A migration to drop a column looks like this:

```json
{
  "name": "09_drop_column",
  "operations": [
    {
      "drop_column": {
        "table": "fruits",
        "column": "price",
        "down": "0"
      }
    }
  ]
}
```

The migration takes the name of the table and column that should be
dropped along with (optionally) some `down` SQL to run to populate the
field in the underlying table when insertions are done via the new
schema version while the migration is in progress.

* On `Start`, the relevant view in the new version schema is created
without the dropped column. The column is not deleted from the
underlying table.
* If `down` SQL is specified, a trigger is created on the underlying
table to populate the column to be removed when inserts are made from
the new schema version.
* On `Rollback`, any triggers on the underlying table are removed.
* On `Complete`, the old version of the schema is removed and the column
is removed from the underlying table. Any triggers are also removed.
This commit is contained in:
Andrew Farries 2023-08-17 07:37:48 +01:00 committed by GitHub
parent 450e6db231
commit f764993640
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 355 additions and 75 deletions

View File

@ -0,0 +1,26 @@
{
"name": "08_create_fruits_table",
"operations": [
{
"create_table": {
"name": "fruits",
"columns": [
{
"name": "id",
"type": "serial",
"pk": true
},
{
"name": "name",
"type": "varchar(255)",
"unique": true
},
{
"name": "price",
"type": "decimal(10,2)"
}
]
}
}
]
}

View File

@ -0,0 +1,12 @@
{
"name": "09_drop_column",
"operations": [
{
"drop_column": {
"table": "fruits",
"column": "price",
"down": "0"
}
}
]
}

View File

@ -26,3 +26,12 @@ type ColumnAlreadyExistsError struct {
func (e ColumnAlreadyExistsError) Error() string {
return fmt.Sprintf("column %q already exists in table %q", e.Name, e.Table)
}
// ColumnDoesNotExistError is the error reported when a named column cannot
// be found on a table.
type ColumnDoesNotExistError struct {
	// Table is the name of the table that was searched.
	Table string
	// Name is the name of the column that was not found.
	Name string
}

// Error implements the error interface. The message quotes both the column
// name and the table name.
func (e ColumnDoesNotExistError) Error() string {
	column, table := e.Name, e.Table
	return fmt.Sprintf("column %q does not exist on table %q", column, table)
}

View File

@ -33,7 +33,16 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, schemaName, state
}
if o.Up != nil {
if err := createTrigger(ctx, conn, o, schemaName, stateSchema, s); err != nil {
err := createTrigger(ctx, conn, s, triggerConfig{
Direction: TriggerDirectionUp,
SchemaName: schemaName,
StateSchema: stateSchema,
Table: o.Table,
Column: o.Column.Name,
PhysicalColumn: TemporaryName(o.Column.Name),
SQL: *o.Up,
})
if err != nil {
return fmt.Errorf("failed to create trigger: %w", err)
}
if err := backFill(ctx, conn, o); err != nil {
@ -154,72 +163,6 @@ func addNotNullConstraint(ctx context.Context, conn *sql.DB, o *OpAddColumn) err
return err
}
func createTrigger(ctx context.Context, conn *sql.DB, o *OpAddColumn, schemaName, stateSchema string, s *schema.Schema) error {
// Generate the SQL declarations for the trigger function
// This results in declarations like:
// col1 table.col1%TYPE := NEW.col1;
// Without these declarations, users would have to reference
// `col1` as `NEW.col1` in their `up` SQL.
sqlDeclarations := func(s *schema.Schema) string {
table := s.GetTable(o.Table)
decls := ""
for _, c := range table.Columns {
decls += fmt.Sprintf("%[1]s %[2]s.%[1]s%%TYPE := NEW.%[1]s;\n",
pq.QuoteIdentifier(c.Name),
pq.QuoteIdentifier(table.Name))
}
return decls
}
//nolint:gosec // unavoidable SQL injection warning when running arbitrary SQL
triggerFn := fmt.Sprintf(`CREATE OR REPLACE FUNCTION %[1]s()
RETURNS TRIGGER
LANGUAGE PLPGSQL
AS $$
DECLARE
%[4]s
latest_schema text;
search_path text;
BEGIN
SELECT %[5]s || '_' || latest_version INTO latest_schema FROM %[6]s.latest_version(%[5]s);
SELECT current_setting INTO search_path FROM current_setting('search_path');
IF search_path <> latest_schema THEN
NEW.%[2]s = %[3]s;
END IF;
RETURN NEW;
END; $$`,
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column.Name)),
pq.QuoteIdentifier(TemporaryName(o.Column.Name)),
*o.Up,
sqlDeclarations(s),
pq.QuoteLiteral(schemaName),
pq.QuoteIdentifier(stateSchema))
_, err := conn.ExecContext(ctx, triggerFn)
if err != nil {
return err
}
trigger := fmt.Sprintf(`CREATE OR REPLACE TRIGGER %[1]s
BEFORE UPDATE OR INSERT
ON %[2]s
FOR EACH ROW
EXECUTE PROCEDURE %[3]s();`,
pq.QuoteIdentifier(TriggerName(o.Table, o.Column.Name)),
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column.Name)))
_, err = conn.ExecContext(ctx, trigger)
if err != nil {
return err
}
return nil
}
func backFill(ctx context.Context, conn *sql.DB, o *OpAddColumn) error {
// touch rows without changing them in order to have the trigger fire
// and set the value using the `up` SQL.
@ -235,11 +178,3 @@ func backFill(ctx context.Context, conn *sql.DB, o *OpAddColumn) error {
// NotNullConstraintName builds a constraint name for columnName by prefixing
// it with "_pgroll_add_column_check_".
func NotNullConstraintName(columnName string) string {
	const prefix = "_pgroll_add_column_check_"
	return prefix + columnName
}
func TriggerFunctionName(tableName, columnName string) string {
return "_pgroll_add_column_" + tableName + "_" + columnName
}
func TriggerName(tableName, columnName string) string {
return TriggerFunctionName(tableName, columnName)
}

View File

@ -15,6 +15,7 @@ const (
OpNameRenameTable OpName = "rename_table"
OpNameDropTable OpName = "drop_table"
OpNameAddColumn OpName = "add_column"
OpNameDropColumn OpName = "drop_column"
)
func TemporaryName(name string) string {
@ -82,6 +83,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error {
case OpNameAddColumn:
item = &OpAddColumn{}
case OpNameDropColumn:
item = &OpDropColumn{}
default:
return fmt.Errorf("unknown migration type: %v", opName)
}
@ -126,6 +130,9 @@ func (v Operations) MarshalJSON() ([]byte, error) {
case *OpAddColumn:
opName = OpNameAddColumn
case *OpDropColumn:
opName = OpNameDropColumn
default:
panic(fmt.Errorf("unknown operation for %T", op))
}

View File

@ -0,0 +1,70 @@
package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"pg-roll/pkg/schema"
)
// OpDropColumn is the "drop_column" migration operation. It identifies the
// column to remove from a table and, optionally, SQL used to keep that
// column populated while the migration is in progress.
type OpDropColumn struct {
	// Table is the name of the table that owns the column.
	Table string `json:"table"`
	// Column is the name of the column to drop.
	Column string `json:"column"`
	// Down is an optional SQL expression. When it is set, Start creates a
	// trigger that assigns the expression's value to the column being dropped.
	Down *string `json:"down,omitempty"`
}

// Compile-time check that *OpDropColumn satisfies the Operation interface.
var _ Operation = (*OpDropColumn)(nil)
// Start begins the drop column operation.
//
// If Down SQL was supplied, a "down" trigger is created on the table so that
// the column being dropped keeps receiving values. The column is then removed
// from the in-memory schema s; the column itself is not dropped from the
// database here (that happens in Complete).
//
// It returns TableDoesNotExistError if o.Table is not present in s, and a
// wrapped error if the trigger cannot be created.
func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, schemaName string, stateSchema string, s *schema.Schema) error {
	// Resolve the table first: it is dereferenced below, so an unknown table
	// must surface as an error rather than a nil-pointer panic.
	table := s.GetTable(o.Table)
	if table == nil {
		return TableDoesNotExistError{Name: o.Table}
	}

	if o.Down != nil {
		err := createTrigger(ctx, conn, s, triggerConfig{
			Direction:      TriggerDirectionDown,
			SchemaName:     schemaName,
			StateSchema:    stateSchema,
			Table:          o.Table,
			Column:         o.Column,
			PhysicalColumn: o.Column,
			SQL:            *o.Down,
		})
		if err != nil {
			// Same wrapping as the add column operation uses for this failure.
			return fmt.Errorf("failed to create trigger: %w", err)
		}
	}

	table.RemoveColumn(o.Column)
	return nil
}
// Complete finishes the drop column operation: it drops the column from the
// underlying table and then drops the trigger function (with CASCADE, which
// also removes the trigger that depends on it), if one exists.
//
// The first database error encountered is returned unchanged.
func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB) error {
	// Quote both identifiers, as every other statement built for this
	// operation does. Interpolating them raw breaks on mixed-case, reserved
	// or special-character names and lets a crafted name inject SQL.
	_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s",
		pq.QuoteIdentifier(o.Table),
		pq.QuoteIdentifier(o.Column)))
	if err != nil {
		return err
	}

	_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
		pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
	return err
}
// Rollback undoes Start by dropping the trigger function for the column, if
// it exists. CASCADE also removes objects that depend on the function.
// The column itself is left untouched.
func (o *OpDropColumn) Rollback(ctx context.Context, conn *sql.DB) error {
	fnName := pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))
	dropFn := fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", fnName)

	if _, err := conn.ExecContext(ctx, dropFn); err != nil {
		return err
	}
	return nil
}
// Validate checks the operation against schema s. It returns
// TableDoesNotExistError when o.Table is not in s, ColumnDoesNotExistError
// when the table has no column named o.Column, and nil otherwise.
// ctx is not used.
func (o *OpDropColumn) Validate(ctx context.Context, s *schema.Schema) error {
	tbl := s.GetTable(o.Table)

	switch {
	case tbl == nil:
		return TableDoesNotExistError{Name: o.Table}
	case tbl.GetColumn(o.Column) == nil:
		return ColumnDoesNotExistError{Table: o.Table, Name: o.Column}
	default:
		return nil
	}
}

View File

@ -0,0 +1,108 @@
package migrations_test
import (
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"pg-roll/pkg/migrations"
"pg-roll/pkg/roll"
)
func TestDropColumnWithDownSQL(t *testing.T) {
t.Parallel()
ptr := func(s string) *string { return &s }
ExecuteTests(t, TestCases{{
name: "drop column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "name",
Type: "varchar(255)",
Nullable: false,
},
{
Name: "email",
Type: "varchar(255)",
Nullable: false,
},
},
},
},
},
{
Name: "02_drop_column",
Operations: migrations.Operations{
&migrations.OpDropColumn{
Table: "users",
Column: "name",
Down: ptr("UPPER(email)"),
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// The deleted column is not present on the view in the new version schema.
versionSchema := roll.VersionedSchemaName("public", "02_drop_column")
ColumnMustNotExist(t, db, versionSchema, "users", "name")
// But the column is still present on the underlying table.
ColumnMustExist(t, db, "public", "users", "name")
// Inserting into the view in the new version schema should succeed.
MustInsert(t, db, "public", "02_drop_column", "users", map[string]string{
"email": "foo@example.com",
})
// The "down" SQL has populated the removed column ("name")
results := MustSelect(t, db, "public", "01_add_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "FOO@EXAMPLE.COM", "email": "foo@example.com"},
}, results)
},
afterRollback: func(t *testing.T, db *sql.DB) {
// The trigger function has been dropped.
triggerFnName := migrations.TriggerFunctionName("users", "name")
FunctionMustNotExist(t, db, "public", triggerFnName)
// The trigger has been dropped.
triggerName := migrations.TriggerName("users", "name")
TriggerMustNotExist(t, db, "public", "users", triggerName)
},
afterComplete: func(t *testing.T, db *sql.DB) {
// The column has been deleted from the underlying table.
ColumnMustNotExist(t, db, "public", "users", "name")
// The trigger function has been dropped.
triggerFnName := migrations.TriggerFunctionName("users", "name")
FunctionMustNotExist(t, db, "public", triggerFnName)
// The trigger has been dropped.
triggerName := migrations.TriggerName("users", "name")
TriggerMustNotExist(t, db, "public", "users", triggerName)
// Inserting into the view in the new version schema should succeed.
MustInsert(t, db, "public", "02_drop_column", "users", map[string]string{
"email": "bar@example.com",
})
results := MustSelect(t, db, "public", "02_drop_column", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "email": "foo@example.com"},
{"id": 2, "email": "bar@example.com"},
}, results)
},
}})
}

109
pkg/migrations/triggers.go Normal file
View File

@ -0,0 +1,109 @@
package migrations
import (
"context"
"database/sql"
"fmt"
"pg-roll/pkg/schema"
"github.com/lib/pq"
)
// TriggerDirection selects the condition under which a trigger created by
// createTrigger assigns its column.
type TriggerDirection string

const (
	// TriggerDirectionUp makes the trigger assign the column when the
	// session's search_path differs from the latest version schema.
	TriggerDirectionUp TriggerDirection = "up"
	// TriggerDirectionDown makes the trigger assign the column when the
	// session's search_path equals the latest version schema.
	TriggerDirectionDown TriggerDirection = "down"
)
// triggerConfig carries the settings createTrigger needs to build a trigger
// function and its trigger.
type triggerConfig struct {
	// Direction chooses the search_path comparison used inside the trigger
	// function ("=" for TriggerDirectionDown, "<>" otherwise).
	Direction TriggerDirection
	// SchemaName is passed to latest_version, prefixes the latest schema
	// name, and qualifies the table in the generated %TYPE declarations.
	SchemaName string
	// StateSchema is the schema that holds the latest_version function.
	StateSchema string
	// Table is the table the trigger is created on.
	Table string
	// Column is combined with Table to name the trigger and its function.
	Column string
	// PhysicalColumn is the column on NEW that the trigger function assigns.
	PhysicalColumn string
	// SQL is the expression whose value is assigned to PhysicalColumn.
	SQL string
}
// createTrigger creates (or replaces) a PL/pgSQL trigger function and a
// BEFORE UPDATE OR INSERT row-level trigger on cfg.Table that calls it.
//
// The trigger function looks up the latest version schema via
// <cfg.StateSchema>.latest_version(<cfg.SchemaName>) and compares it with the
// session's search_path. When the comparison selected by cfg.Direction holds
// ("=" for TriggerDirectionDown, "<>" otherwise), it assigns the value of
// cfg.SQL to NEW.<cfg.PhysicalColumn>.
//
// It returns TableDoesNotExistError if cfg.Table is not present in s, or the
// first error returned by the database.
func createTrigger(ctx context.Context, conn *sql.DB, s *schema.Schema, cfg triggerConfig) error {
	// The table's columns are needed below; report an unknown table as an
	// error instead of dereferencing a nil table.
	table := s.GetTable(cfg.Table)
	if table == nil {
		return TableDoesNotExistError{Name: cfg.Table}
	}

	// Generate the SQL declarations for the trigger function
	// This results in declarations like:
	// col1 schema.table.col1%TYPE := NEW.col1;
	// Without these declarations, users would have to reference
	// `col1` as `NEW.col1` in their `up` or `down` SQL.
	decls := ""
	for _, c := range table.Columns {
		decls += fmt.Sprintf("%[1]s %[3]s.%[2]s.%[1]s%%TYPE := NEW.%[1]s;\n",
			pq.QuoteIdentifier(c.Name),
			pq.QuoteIdentifier(table.Name),
			pq.QuoteIdentifier(cfg.SchemaName))
	}

	// "up" triggers fire for sessions that are NOT on the latest schema;
	// "down" triggers fire for sessions that ARE on it.
	cmp := "<>"
	if cfg.Direction == TriggerDirectionDown {
		cmp = "="
	}

	//nolint:gosec // unavoidable SQL injection warning when running arbitrary SQL
	triggerFn := fmt.Sprintf(`CREATE OR REPLACE FUNCTION %[1]s()
RETURNS TRIGGER
LANGUAGE PLPGSQL
AS $$
DECLARE
%[4]s
latest_schema text;
search_path text;
BEGIN
SELECT %[5]s || '_' || latest_version INTO latest_schema FROM %[6]s.latest_version(%[5]s);
SELECT current_setting INTO search_path FROM current_setting('search_path');
IF search_path %[7]s latest_schema THEN
NEW.%[2]s = %[3]s;
END IF;
RETURN NEW;
END; $$`,
		pq.QuoteIdentifier(TriggerFunctionName(cfg.Table, cfg.Column)),
		pq.QuoteIdentifier(cfg.PhysicalColumn),
		cfg.SQL,
		decls,
		pq.QuoteLiteral(cfg.SchemaName),
		pq.QuoteIdentifier(cfg.StateSchema),
		cmp)

	if _, err := conn.ExecContext(ctx, triggerFn); err != nil {
		return err
	}

	trigger := fmt.Sprintf(`CREATE OR REPLACE TRIGGER %[1]s
BEFORE UPDATE OR INSERT
ON %[2]s
FOR EACH ROW
EXECUTE PROCEDURE %[3]s();`,
		pq.QuoteIdentifier(TriggerName(cfg.Table, cfg.Column)),
		pq.QuoteIdentifier(cfg.Table),
		pq.QuoteIdentifier(TriggerFunctionName(cfg.Table, cfg.Column)))

	_, err := conn.ExecContext(ctx, trigger)
	return err
}
// TriggerFunctionName returns the trigger function name for the given table
// and column: "_pgroll_trigger_", the table name, "_", then the column name.
func TriggerFunctionName(tableName, columnName string) string {
	const prefix = "_pgroll_trigger_"
	suffix := tableName + "_" + columnName
	return prefix + suffix
}
// TriggerName returns the trigger name for the given table and column. It is
// the same string that TriggerFunctionName returns for those arguments.
func TriggerName(tableName, columnName string) string {
	name := TriggerFunctionName(tableName, columnName)
	return name
}

View File

@ -102,3 +102,7 @@ func (t *Table) AddColumn(name string, c Column) {
t.Columns[name] = c
}
// RemoveColumn deletes the entry keyed by column from t.Columns. It is a
// no-op when no such entry exists.
func (t *Table) RemoveColumn(column string) {
	delete(t.Columns, column)
}