diff --git a/examples/08_create_fruits_table.json b/examples/08_create_fruits_table.json new file mode 100644 index 0000000..fb034d1 --- /dev/null +++ b/examples/08_create_fruits_table.json @@ -0,0 +1,26 @@ +{ + "name": "08_create_fruits_table", + "operations": [ + { + "create_table": { + "name": "fruits", + "columns": [ + { + "name": "id", + "type": "serial", + "pk": true + }, + { + "name": "name", + "type": "varchar(255)", + "unique": true + }, + { + "name": "price", + "type": "decimal(10,2)" + } + ] + } + } + ] +} diff --git a/examples/09_drop_column.json b/examples/09_drop_column.json new file mode 100644 index 0000000..4dad911 --- /dev/null +++ b/examples/09_drop_column.json @@ -0,0 +1,12 @@ +{ + "name": "09_drop_column", + "operations": [ + { + "drop_column": { + "table": "fruits", + "column": "price", + "down": "0" + } + } + ] +} diff --git a/pkg/migrations/errors.go b/pkg/migrations/errors.go index 87dc2fa..302d6ef 100644 --- a/pkg/migrations/errors.go +++ b/pkg/migrations/errors.go @@ -26,3 +26,12 @@ type ColumnAlreadyExistsError struct { func (e ColumnAlreadyExistsError) Error() string { return fmt.Sprintf("column %q already exists in table %q", e.Name, e.Table) } + +type ColumnDoesNotExistError struct { + Table string + Name string +} + +func (e ColumnDoesNotExistError) Error() string { + return fmt.Sprintf("column %q does not exist on table %q", e.Name, e.Table) +} diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index 92f0ac7..d2a82fe 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -33,7 +33,16 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, schemaName, state } if o.Up != nil { - if err := createTrigger(ctx, conn, o, schemaName, stateSchema, s); err != nil { + err := createTrigger(ctx, conn, s, triggerConfig{ + Direction: TriggerDirectionUp, + SchemaName: schemaName, + StateSchema: stateSchema, + Table: o.Table, + Column: o.Column.Name, + PhysicalColumn: TemporaryName(o.Column.Name), + SQL: *o.Up, + }) + if err != nil { return fmt.Errorf("failed to create trigger: %w", err) } if err := backFill(ctx, conn, o); err != nil { @@ -154,72 +163,6 @@ func addNotNullConstraint(ctx context.Context, conn *sql.DB, o *OpAddColumn) err return err } -func createTrigger(ctx context.Context, conn *sql.DB, o *OpAddColumn, schemaName, stateSchema string, s *schema.Schema) error { - // Generate the SQL declarations for the trigger function - // This results in declarations like: - // col1 table.col1%TYPE := NEW.col1; - // Without these declarations, users would have to reference - // `col1` as `NEW.col1` in their `up` SQL. - sqlDeclarations := func(s *schema.Schema) string { - table := s.GetTable(o.Table) - - decls := "" - for _, c := range table.Columns { - decls += fmt.Sprintf("%[1]s %[2]s.%[1]s%%TYPE := NEW.%[1]s;\n", - pq.QuoteIdentifier(c.Name), - pq.QuoteIdentifier(table.Name)) - } - return decls - } - - //nolint:gosec // unavoidable SQL injection warning when running arbitrary SQL - triggerFn := fmt.Sprintf(`CREATE OR REPLACE FUNCTION %[1]s() - RETURNS TRIGGER - LANGUAGE PLPGSQL - AS $$ - DECLARE - %[4]s - latest_schema text; - search_path text; - BEGIN - SELECT %[5]s || '_' || latest_version INTO latest_schema FROM %[6]s.latest_version(%[5]s); - SELECT current_setting INTO search_path FROM current_setting('search_path'); - - IF search_path <> latest_schema THEN - NEW.%[2]s = %[3]s; - END IF; - - RETURN NEW; - END; $$`, - pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column.Name)), - pq.QuoteIdentifier(TemporaryName(o.Column.Name)), - *o.Up, - sqlDeclarations(s), - pq.QuoteLiteral(schemaName), - pq.QuoteIdentifier(stateSchema)) - - _, err := conn.ExecContext(ctx, triggerFn) - if err != nil { - return err - } - - trigger := fmt.Sprintf(`CREATE OR REPLACE TRIGGER %[1]s - BEFORE UPDATE OR INSERT - ON %[2]s - FOR EACH ROW - EXECUTE PROCEDURE %[3]s();`, - pq.QuoteIdentifier(TriggerName(o.Table, o.Column.Name)), - pq.QuoteIdentifier(o.Table), - pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column.Name))) - - _, err = conn.ExecContext(ctx, trigger) - if err != nil { - return err - } - - return nil -} - func backFill(ctx context.Context, conn *sql.DB, o *OpAddColumn) error { // touch rows without changing them in order to have the trigger fire // and set the value using the `up` SQL. @@ -235,11 +178,3 @@ func backFill(ctx context.Context, conn *sql.DB, o *OpAddColumn) error { func NotNullConstraintName(columnName string) string { return "_pgroll_add_column_check_" + columnName } - -func TriggerFunctionName(tableName, columnName string) string { - return "_pgroll_add_column_" + tableName + "_" + columnName -} - -func TriggerName(tableName, columnName string) string { - return TriggerFunctionName(tableName, columnName) -} diff --git a/pkg/migrations/op_common.go b/pkg/migrations/op_common.go index ef4c790..0e38dac 100644 --- a/pkg/migrations/op_common.go +++ b/pkg/migrations/op_common.go @@ -15,6 +15,7 @@ const ( OpNameRenameTable OpName = "rename_table" OpNameDropTable OpName = "drop_table" OpNameAddColumn OpName = "add_column" + OpNameDropColumn OpName = "drop_column" ) func TemporaryName(name string) string { @@ -82,6 +83,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error { case OpNameAddColumn: item = &OpAddColumn{} + case OpNameDropColumn: + item = &OpDropColumn{} + default: return fmt.Errorf("unknown migration type: %v", opName) } @@ -126,6 +130,9 @@ func (v Operations) MarshalJSON() ([]byte, error) { case *OpAddColumn: opName = OpNameAddColumn + case *OpDropColumn: + opName = OpNameDropColumn + default: panic(fmt.Errorf("unknown operation for %T", op)) } diff --git a/pkg/migrations/op_drop_column.go b/pkg/migrations/op_drop_column.go new file mode 100644 index 0000000..cffe9f9 --- /dev/null +++ b/pkg/migrations/op_drop_column.go @@ -0,0 +1,70 @@ +package migrations + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + + "pg-roll/pkg/schema" +) + +type OpDropColumn struct { + Table string `json:"table"` + Column string `json:"column"` + Down *string `json:"down,omitempty"` +} + +var _ Operation = (*OpDropColumn)(nil) + +func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, schemaName string, stateSchema string, s *schema.Schema) error { + if o.Down != nil { + err := createTrigger(ctx, conn, s, triggerConfig{ + Direction: TriggerDirectionDown, + SchemaName: schemaName, + StateSchema: stateSchema, + Table: o.Table, + Column: o.Column, + PhysicalColumn: o.Column, + SQL: *o.Down, + }) + if err != nil { + return err + } + } + + s.GetTable(o.Table).RemoveColumn(o.Column) + return nil +} + +func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB) error { + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", o.Table, o.Column)) + if err != nil { + return err + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", + pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)))) + + return err +} + +func (o *OpDropColumn) Rollback(ctx context.Context, conn *sql.DB) error { + _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", + pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)))) + + return err +} + +func (o *OpDropColumn) Validate(ctx context.Context, s *schema.Schema) error { + table := s.GetTable(o.Table) + + if table == nil { + return TableDoesNotExistError{Name: o.Table} + } + if table.GetColumn(o.Column) == nil { + return ColumnDoesNotExistError{Table: o.Table, Name: o.Column} + } + return nil +} diff --git a/pkg/migrations/op_drop_column_test.go b/pkg/migrations/op_drop_column_test.go new file mode 100644 index 0000000..7114ad9 --- /dev/null +++ b/pkg/migrations/op_drop_column_test.go @@ -0,0 +1,108 @@ +package migrations_test + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + + "pg-roll/pkg/migrations" + "pg-roll/pkg/roll" +) + +func TestDropColumnWithDownSQL(t *testing.T) { + t.Parallel() + + ptr := func(s string) *string { return &s } + + ExecuteTests(t, TestCases{{ + name: "drop column", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: false, + }, + { + Name: "email", + Type: "varchar(255)", + Nullable: false, + }, + }, + }, + }, + }, + { + Name: "02_drop_column", + Operations: migrations.Operations{ + &migrations.OpDropColumn{ + Table: "users", + Column: "name", + Down: ptr("UPPER(email)"), + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // The deleted column is not present on the view in the new version schema. + versionSchema := roll.VersionedSchemaName("public", "02_drop_column") + ColumnMustNotExist(t, db, versionSchema, "users", "name") + + // But the column is still present on the underlying table. + ColumnMustExist(t, db, "public", "users", "name") + + // Inserting into the view in the new version schema should succeed. + MustInsert(t, db, "public", "02_drop_column", "users", map[string]string{ + "email": "foo@example.com", + }) + + // The "down" SQL has populated the removed column ("name") + results := MustSelect(t, db, "public", "01_add_table", "users") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "FOO@EXAMPLE.COM", "email": "foo@example.com"}, + }, results) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + // The trigger function has been dropped. + triggerFnName := migrations.TriggerFunctionName("users", "name") + FunctionMustNotExist(t, db, "public", triggerFnName) + + // The trigger has been dropped. + triggerName := migrations.TriggerName("users", "name") + TriggerMustNotExist(t, db, "public", "users", triggerName) + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // The column has been deleted from the underlying table. + ColumnMustNotExist(t, db, "public", "users", "name") + + // The trigger function has been dropped. + triggerFnName := migrations.TriggerFunctionName("users", "name") + FunctionMustNotExist(t, db, "public", triggerFnName) + + // The trigger has been dropped. + triggerName := migrations.TriggerName("users", "name") + TriggerMustNotExist(t, db, "public", "users", triggerName) + + // Inserting into the view in the new version schema should succeed. + MustInsert(t, db, "public", "02_drop_column", "users", map[string]string{ + "email": "bar@example.com", + }) + results := MustSelect(t, db, "public", "02_drop_column", "users") + assert.Equal(t, []map[string]any{ + {"id": 1, "email": "foo@example.com"}, + {"id": 2, "email": "bar@example.com"}, + }, results) + }, + }}) +} diff --git a/pkg/migrations/triggers.go b/pkg/migrations/triggers.go new file mode 100644 index 0000000..8eef46a --- /dev/null +++ b/pkg/migrations/triggers.go @@ -0,0 +1,109 @@ +package migrations + +import ( + "context" + "database/sql" + "fmt" + + "pg-roll/pkg/schema" + + "github.com/lib/pq" +) + +type TriggerDirection string + +const ( + TriggerDirectionUp TriggerDirection = "up" + TriggerDirectionDown TriggerDirection = "down" +) + +type triggerConfig struct { + Direction TriggerDirection + SchemaName string + StateSchema string + Table string + Column string + PhysicalColumn string + SQL string +} + +func createTrigger(ctx context.Context, conn *sql.DB, s *schema.Schema, cfg triggerConfig) error { + // Generate the SQL declarations for the trigger function + // This results in declarations like: + // col1 table.col1%TYPE := NEW.col1; + // Without these declarations, users would have to reference + // `col1` as `NEW.col1` in their `up` SQL. + sqlDeclarations := func(s *schema.Schema) string { + table := s.GetTable(cfg.Table) + + decls := "" + for _, c := range table.Columns { + decls += fmt.Sprintf("%[1]s %[3]s.%[2]s.%[1]s%%TYPE := NEW.%[1]s;\n", + pq.QuoteIdentifier(c.Name), + pq.QuoteIdentifier(table.Name), + pq.QuoteIdentifier(cfg.SchemaName)) + } + return decls + } + + cmp := "<>" + if cfg.Direction == TriggerDirectionDown { + cmp = "=" + } + + //nolint:gosec // unavoidable SQL injection warning when running arbitrary SQL + triggerFn := fmt.Sprintf(`CREATE OR REPLACE FUNCTION %[1]s() + RETURNS TRIGGER + LANGUAGE PLPGSQL + AS $$ + DECLARE + %[4]s + latest_schema text; + search_path text; + BEGIN + SELECT %[5]s || '_' || latest_version INTO latest_schema FROM %[6]s.latest_version(%[5]s); + SELECT current_setting INTO search_path FROM current_setting('search_path'); + + IF search_path %[7]s latest_schema THEN + NEW.%[2]s = %[3]s; + END IF; + + RETURN NEW; + END; $$`, + pq.QuoteIdentifier(TriggerFunctionName(cfg.Table, cfg.Column)), + pq.QuoteIdentifier(cfg.PhysicalColumn), + cfg.SQL, + sqlDeclarations(s), + pq.QuoteLiteral(cfg.SchemaName), + pq.QuoteIdentifier(cfg.StateSchema), + cmp) + + _, err := conn.ExecContext(ctx, triggerFn) + if err != nil { + return err + } + + trigger := fmt.Sprintf(`CREATE OR REPLACE TRIGGER %[1]s + BEFORE UPDATE OR INSERT + ON %[2]s + FOR EACH ROW + EXECUTE PROCEDURE %[3]s();`, + pq.QuoteIdentifier(TriggerName(cfg.Table, cfg.Column)), + pq.QuoteIdentifier(cfg.Table), + pq.QuoteIdentifier(TriggerFunctionName(cfg.Table, cfg.Column))) + + _, err = conn.ExecContext(ctx, trigger) + if err != nil { + return err + } + + return nil +} + +func TriggerFunctionName(tableName, columnName string) string { + return "_pgroll_trigger_" + tableName + "_" + columnName +} + +func TriggerName(tableName, columnName string) string { + return TriggerFunctionName(tableName, columnName) +} diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 4bd25c5..801a29a 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -102,3 +102,7 @@ func (t *Table) AddColumn(name string, c Column) { t.Columns[name] = c } + +func (t *Table) RemoveColumn(column string) { + delete(t.Columns, column) +}