mirror of
https://github.com/xataio/pgroll.git
synced 2024-07-14 17:10:33 +03:00
Implement 'drop column' migrations (#48)
Implement the **drop column** migration operation. A migration to drop a column looks like this: ```json { "name": "09_drop_column", "operations": [ { "drop_column": { "table": "fruits", "column": "price", "down": "0" } } ] } ``` The migration takes the name of the table and column that should be dropped along with (optionally) some `down` SQL to run to populate the field in the underlying table when insertions are done via the new schema version while the migration is in progress. * On `Start`, the relevant view in the new version schema is created without the dropped column. The column is not deleted from the underlying table. * If `down` SQL is specified, a trigger is created on the underlying table to populate the column to be removed when inserts are made from the new schema version. * On `Rollback` any triggers on the underlying table are removed. * On `Complete` the old version of the schema is removed and the column is removed from the underlying table. Any triggers are also removed.
This commit is contained in:
parent
450e6db231
commit
f764993640
26
examples/08_create_fruits_table.json
Normal file
26
examples/08_create_fruits_table.json
Normal file
@ -0,0 +1,26 @@
|
||||
{
|
||||
"name": "08_create_fruits_table",
|
||||
"operations": [
|
||||
{
|
||||
"create_table": {
|
||||
"name": "fruits",
|
||||
"columns": [
|
||||
{
|
||||
"name": "id",
|
||||
"type": "serial",
|
||||
"pk": true
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"type": "varchar(255)",
|
||||
"unique": true
|
||||
},
|
||||
{
|
||||
"name": "price",
|
||||
"type": "decimal(10,2)"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
12
examples/09_drop_column.json
Normal file
12
examples/09_drop_column.json
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"name": "09_drop_column",
|
||||
"operations": [
|
||||
{
|
||||
"drop_column": {
|
||||
"table": "fruits",
|
||||
"column": "price",
|
||||
"down": "0"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
@ -26,3 +26,12 @@ type ColumnAlreadyExistsError struct {
|
||||
func (e ColumnAlreadyExistsError) Error() string {
|
||||
return fmt.Sprintf("column %q already exists in table %q", e.Name, e.Table)
|
||||
}
|
||||
|
||||
type ColumnDoesNotExistError struct {
|
||||
Table string
|
||||
Name string
|
||||
}
|
||||
|
||||
func (e ColumnDoesNotExistError) Error() string {
|
||||
return fmt.Sprintf("column %q does not exist on table %q", e.Name, e.Table)
|
||||
}
|
||||
|
@ -33,7 +33,16 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, schemaName, state
|
||||
}
|
||||
|
||||
if o.Up != nil {
|
||||
if err := createTrigger(ctx, conn, o, schemaName, stateSchema, s); err != nil {
|
||||
err := createTrigger(ctx, conn, s, triggerConfig{
|
||||
Direction: TriggerDirectionUp,
|
||||
SchemaName: schemaName,
|
||||
StateSchema: stateSchema,
|
||||
Table: o.Table,
|
||||
Column: o.Column.Name,
|
||||
PhysicalColumn: TemporaryName(o.Column.Name),
|
||||
SQL: *o.Up,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create trigger: %w", err)
|
||||
}
|
||||
if err := backFill(ctx, conn, o); err != nil {
|
||||
@ -154,72 +163,6 @@ func addNotNullConstraint(ctx context.Context, conn *sql.DB, o *OpAddColumn) err
|
||||
return err
|
||||
}
|
||||
|
||||
func createTrigger(ctx context.Context, conn *sql.DB, o *OpAddColumn, schemaName, stateSchema string, s *schema.Schema) error {
|
||||
// Generate the SQL declarations for the trigger function
|
||||
// This results in declarations like:
|
||||
// col1 table.col1%TYPE := NEW.col1;
|
||||
// Without these declarations, users would have to reference
|
||||
// `col1` as `NEW.col1` in their `up` SQL.
|
||||
sqlDeclarations := func(s *schema.Schema) string {
|
||||
table := s.GetTable(o.Table)
|
||||
|
||||
decls := ""
|
||||
for _, c := range table.Columns {
|
||||
decls += fmt.Sprintf("%[1]s %[2]s.%[1]s%%TYPE := NEW.%[1]s;\n",
|
||||
pq.QuoteIdentifier(c.Name),
|
||||
pq.QuoteIdentifier(table.Name))
|
||||
}
|
||||
return decls
|
||||
}
|
||||
|
||||
//nolint:gosec // unavoidable SQL injection warning when running arbitrary SQL
|
||||
triggerFn := fmt.Sprintf(`CREATE OR REPLACE FUNCTION %[1]s()
|
||||
RETURNS TRIGGER
|
||||
LANGUAGE PLPGSQL
|
||||
AS $$
|
||||
DECLARE
|
||||
%[4]s
|
||||
latest_schema text;
|
||||
search_path text;
|
||||
BEGIN
|
||||
SELECT %[5]s || '_' || latest_version INTO latest_schema FROM %[6]s.latest_version(%[5]s);
|
||||
SELECT current_setting INTO search_path FROM current_setting('search_path');
|
||||
|
||||
IF search_path <> latest_schema THEN
|
||||
NEW.%[2]s = %[3]s;
|
||||
END IF;
|
||||
|
||||
RETURN NEW;
|
||||
END; $$`,
|
||||
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column.Name)),
|
||||
pq.QuoteIdentifier(TemporaryName(o.Column.Name)),
|
||||
*o.Up,
|
||||
sqlDeclarations(s),
|
||||
pq.QuoteLiteral(schemaName),
|
||||
pq.QuoteIdentifier(stateSchema))
|
||||
|
||||
_, err := conn.ExecContext(ctx, triggerFn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
trigger := fmt.Sprintf(`CREATE OR REPLACE TRIGGER %[1]s
|
||||
BEFORE UPDATE OR INSERT
|
||||
ON %[2]s
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE %[3]s();`,
|
||||
pq.QuoteIdentifier(TriggerName(o.Table, o.Column.Name)),
|
||||
pq.QuoteIdentifier(o.Table),
|
||||
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column.Name)))
|
||||
|
||||
_, err = conn.ExecContext(ctx, trigger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func backFill(ctx context.Context, conn *sql.DB, o *OpAddColumn) error {
|
||||
// touch rows without changing them in order to have the trigger fire
|
||||
// and set the value using the `up` SQL.
|
||||
@ -235,11 +178,3 @@ func backFill(ctx context.Context, conn *sql.DB, o *OpAddColumn) error {
|
||||
func NotNullConstraintName(columnName string) string {
|
||||
return "_pgroll_add_column_check_" + columnName
|
||||
}
|
||||
|
||||
func TriggerFunctionName(tableName, columnName string) string {
|
||||
return "_pgroll_add_column_" + tableName + "_" + columnName
|
||||
}
|
||||
|
||||
func TriggerName(tableName, columnName string) string {
|
||||
return TriggerFunctionName(tableName, columnName)
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ const (
|
||||
OpNameRenameTable OpName = "rename_table"
|
||||
OpNameDropTable OpName = "drop_table"
|
||||
OpNameAddColumn OpName = "add_column"
|
||||
OpNameDropColumn OpName = "drop_column"
|
||||
)
|
||||
|
||||
func TemporaryName(name string) string {
|
||||
@ -82,6 +83,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error {
|
||||
case OpNameAddColumn:
|
||||
item = &OpAddColumn{}
|
||||
|
||||
case OpNameDropColumn:
|
||||
item = &OpDropColumn{}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown migration type: %v", opName)
|
||||
}
|
||||
@ -126,6 +130,9 @@ func (v Operations) MarshalJSON() ([]byte, error) {
|
||||
case *OpAddColumn:
|
||||
opName = OpNameAddColumn
|
||||
|
||||
case *OpDropColumn:
|
||||
opName = OpNameDropColumn
|
||||
|
||||
default:
|
||||
panic(fmt.Errorf("unknown operation for %T", op))
|
||||
}
|
||||
|
70
pkg/migrations/op_drop_column.go
Normal file
70
pkg/migrations/op_drop_column.go
Normal file
@ -0,0 +1,70 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"pg-roll/pkg/schema"
|
||||
)
|
||||
|
||||
type OpDropColumn struct {
|
||||
Table string `json:"table"`
|
||||
Column string `json:"column"`
|
||||
Down *string `json:"down,omitempty"`
|
||||
}
|
||||
|
||||
var _ Operation = (*OpDropColumn)(nil)
|
||||
|
||||
func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, schemaName string, stateSchema string, s *schema.Schema) error {
|
||||
if o.Down != nil {
|
||||
err := createTrigger(ctx, conn, s, triggerConfig{
|
||||
Direction: TriggerDirectionDown,
|
||||
SchemaName: schemaName,
|
||||
StateSchema: stateSchema,
|
||||
Table: o.Table,
|
||||
Column: o.Column,
|
||||
PhysicalColumn: o.Column,
|
||||
SQL: *o.Down,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.GetTable(o.Table).RemoveColumn(o.Column)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB) error {
|
||||
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", o.Table, o.Column))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
|
||||
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (o *OpDropColumn) Rollback(ctx context.Context, conn *sql.DB) error {
|
||||
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
|
||||
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (o *OpDropColumn) Validate(ctx context.Context, s *schema.Schema) error {
|
||||
table := s.GetTable(o.Table)
|
||||
|
||||
if table == nil {
|
||||
return TableDoesNotExistError{Name: o.Table}
|
||||
}
|
||||
if table.GetColumn(o.Column) == nil {
|
||||
return ColumnDoesNotExistError{Table: o.Table, Name: o.Column}
|
||||
}
|
||||
return nil
|
||||
}
|
108
pkg/migrations/op_drop_column_test.go
Normal file
108
pkg/migrations/op_drop_column_test.go
Normal file
@ -0,0 +1,108 @@
|
||||
package migrations_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"pg-roll/pkg/migrations"
|
||||
"pg-roll/pkg/roll"
|
||||
)
|
||||
|
||||
func TestDropColumnWithDownSQL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ptr := func(s string) *string { return &s }
|
||||
|
||||
ExecuteTests(t, TestCases{{
|
||||
name: "drop column",
|
||||
migrations: []migrations.Migration{
|
||||
{
|
||||
Name: "01_add_table",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpCreateTable{
|
||||
Name: "users",
|
||||
Columns: []migrations.Column{
|
||||
{
|
||||
Name: "id",
|
||||
Type: "serial",
|
||||
PrimaryKey: true,
|
||||
},
|
||||
{
|
||||
Name: "name",
|
||||
Type: "varchar(255)",
|
||||
Nullable: false,
|
||||
},
|
||||
{
|
||||
Name: "email",
|
||||
Type: "varchar(255)",
|
||||
Nullable: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "02_drop_column",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpDropColumn{
|
||||
Table: "users",
|
||||
Column: "name",
|
||||
Down: ptr("UPPER(email)"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
afterStart: func(t *testing.T, db *sql.DB) {
|
||||
// The deleted column is not present on the view in the new version schema.
|
||||
versionSchema := roll.VersionedSchemaName("public", "02_drop_column")
|
||||
ColumnMustNotExist(t, db, versionSchema, "users", "name")
|
||||
|
||||
// But the column is still present on the underlying table.
|
||||
ColumnMustExist(t, db, "public", "users", "name")
|
||||
|
||||
// Inserting into the view in the new version schema should succeed.
|
||||
MustInsert(t, db, "public", "02_drop_column", "users", map[string]string{
|
||||
"email": "foo@example.com",
|
||||
})
|
||||
|
||||
// The "down" SQL has populated the removed column ("name")
|
||||
results := MustSelect(t, db, "public", "01_add_table", "users")
|
||||
assert.Equal(t, []map[string]any{
|
||||
{"id": 1, "name": "FOO@EXAMPLE.COM", "email": "foo@example.com"},
|
||||
}, results)
|
||||
},
|
||||
afterRollback: func(t *testing.T, db *sql.DB) {
|
||||
// The trigger function has been dropped.
|
||||
triggerFnName := migrations.TriggerFunctionName("users", "name")
|
||||
FunctionMustNotExist(t, db, "public", triggerFnName)
|
||||
|
||||
// The trigger has been dropped.
|
||||
triggerName := migrations.TriggerName("users", "name")
|
||||
TriggerMustNotExist(t, db, "public", "users", triggerName)
|
||||
},
|
||||
afterComplete: func(t *testing.T, db *sql.DB) {
|
||||
// The column has been deleted from the underlying table.
|
||||
ColumnMustNotExist(t, db, "public", "users", "name")
|
||||
|
||||
// The trigger function has been dropped.
|
||||
triggerFnName := migrations.TriggerFunctionName("users", "name")
|
||||
FunctionMustNotExist(t, db, "public", triggerFnName)
|
||||
|
||||
// The trigger has been dropped.
|
||||
triggerName := migrations.TriggerName("users", "name")
|
||||
TriggerMustNotExist(t, db, "public", "users", triggerName)
|
||||
|
||||
// Inserting into the view in the new version schema should succeed.
|
||||
MustInsert(t, db, "public", "02_drop_column", "users", map[string]string{
|
||||
"email": "bar@example.com",
|
||||
})
|
||||
results := MustSelect(t, db, "public", "02_drop_column", "users")
|
||||
assert.Equal(t, []map[string]any{
|
||||
{"id": 1, "email": "foo@example.com"},
|
||||
{"id": 2, "email": "bar@example.com"},
|
||||
}, results)
|
||||
},
|
||||
}})
|
||||
}
|
109
pkg/migrations/triggers.go
Normal file
109
pkg/migrations/triggers.go
Normal file
@ -0,0 +1,109 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"pg-roll/pkg/schema"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type TriggerDirection string
|
||||
|
||||
const (
|
||||
TriggerDirectionUp TriggerDirection = "up"
|
||||
TriggerDirectionDown TriggerDirection = "down"
|
||||
)
|
||||
|
||||
type triggerConfig struct {
|
||||
Direction TriggerDirection
|
||||
SchemaName string
|
||||
StateSchema string
|
||||
Table string
|
||||
Column string
|
||||
PhysicalColumn string
|
||||
SQL string
|
||||
}
|
||||
|
||||
func createTrigger(ctx context.Context, conn *sql.DB, s *schema.Schema, cfg triggerConfig) error {
|
||||
// Generate the SQL declarations for the trigger function
|
||||
// This results in declarations like:
|
||||
// col1 table.col1%TYPE := NEW.col1;
|
||||
// Without these declarations, users would have to reference
|
||||
// `col1` as `NEW.col1` in their `up` SQL.
|
||||
sqlDeclarations := func(s *schema.Schema) string {
|
||||
table := s.GetTable(cfg.Table)
|
||||
|
||||
decls := ""
|
||||
for _, c := range table.Columns {
|
||||
decls += fmt.Sprintf("%[1]s %[3]s.%[2]s.%[1]s%%TYPE := NEW.%[1]s;\n",
|
||||
pq.QuoteIdentifier(c.Name),
|
||||
pq.QuoteIdentifier(table.Name),
|
||||
pq.QuoteIdentifier(cfg.SchemaName))
|
||||
}
|
||||
return decls
|
||||
}
|
||||
|
||||
cmp := "<>"
|
||||
if cfg.Direction == TriggerDirectionDown {
|
||||
cmp = "="
|
||||
}
|
||||
|
||||
//nolint:gosec // unavoidable SQL injection warning when running arbitrary SQL
|
||||
triggerFn := fmt.Sprintf(`CREATE OR REPLACE FUNCTION %[1]s()
|
||||
RETURNS TRIGGER
|
||||
LANGUAGE PLPGSQL
|
||||
AS $$
|
||||
DECLARE
|
||||
%[4]s
|
||||
latest_schema text;
|
||||
search_path text;
|
||||
BEGIN
|
||||
SELECT %[5]s || '_' || latest_version INTO latest_schema FROM %[6]s.latest_version(%[5]s);
|
||||
SELECT current_setting INTO search_path FROM current_setting('search_path');
|
||||
|
||||
IF search_path %[7]s latest_schema THEN
|
||||
NEW.%[2]s = %[3]s;
|
||||
END IF;
|
||||
|
||||
RETURN NEW;
|
||||
END; $$`,
|
||||
pq.QuoteIdentifier(TriggerFunctionName(cfg.Table, cfg.Column)),
|
||||
pq.QuoteIdentifier(cfg.PhysicalColumn),
|
||||
cfg.SQL,
|
||||
sqlDeclarations(s),
|
||||
pq.QuoteLiteral(cfg.SchemaName),
|
||||
pq.QuoteIdentifier(cfg.StateSchema),
|
||||
cmp)
|
||||
|
||||
_, err := conn.ExecContext(ctx, triggerFn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
trigger := fmt.Sprintf(`CREATE OR REPLACE TRIGGER %[1]s
|
||||
BEFORE UPDATE OR INSERT
|
||||
ON %[2]s
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE %[3]s();`,
|
||||
pq.QuoteIdentifier(TriggerName(cfg.Table, cfg.Column)),
|
||||
pq.QuoteIdentifier(cfg.Table),
|
||||
pq.QuoteIdentifier(TriggerFunctionName(cfg.Table, cfg.Column)))
|
||||
|
||||
_, err = conn.ExecContext(ctx, trigger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TriggerFunctionName(tableName, columnName string) string {
|
||||
return "_pgroll_trigger_" + tableName + "_" + columnName
|
||||
}
|
||||
|
||||
func TriggerName(tableName, columnName string) string {
|
||||
return TriggerFunctionName(tableName, columnName)
|
||||
}
|
@ -102,3 +102,7 @@ func (t *Table) AddColumn(name string, c Column) {
|
||||
|
||||
t.Columns[name] = c
|
||||
}
|
||||
|
||||
func (t *Table) RemoveColumn(column string) {
|
||||
delete(t.Columns, column)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user