mirror of
https://github.com/xataio/pgroll.git
synced 2024-09-11 05:45:48 +03:00
Support up
SQL on add column operations (#34)
Add a new field `Up` to **add column** migrations: ```json { "name": "03_add_column_to_products", "operations": [ { "add_column": { "table": "products", "up": "UPPER(name)", "column": { "name": "description", "type": "varchar(255)", "nullable": true } } } ] } ``` The SQL specified by the `up` field will be run whenever an row is inserted into the underlying table when the session's `search_path` is not set to the latest version of the schema. The `up` SQL snippet can refer to existing columns in the table by name (as in the the above example, where the `description` field is set to `UPPER(name)`).
This commit is contained in:
parent
37b75384a9
commit
b4efd8ad50
@ -4,6 +4,7 @@
|
||||
{
|
||||
"add_column": {
|
||||
"table": "products",
|
||||
"up": "UPPER(name)",
|
||||
"column": {
|
||||
"name": "description",
|
||||
"type": "varchar(255)",
|
||||
|
@ -13,7 +13,7 @@ type Operation interface {
|
||||
// Start will apply the required changes to enable supporting the new schema
|
||||
// version in the database (through a view)
|
||||
// update the given views to expose the new schema version
|
||||
Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error
|
||||
Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error
|
||||
|
||||
// Complete will update the database schema to match the current version
|
||||
// after calling Start.
|
||||
|
@ -12,19 +12,26 @@ import (
|
||||
)
|
||||
|
||||
type OpAddColumn struct {
|
||||
Table string `json:"table"`
|
||||
Column Column `json:"column"`
|
||||
Table string `json:"table"`
|
||||
Up *string `json:"up"`
|
||||
Column Column `json:"column"`
|
||||
}
|
||||
|
||||
var _ Operation = (*OpAddColumn)(nil)
|
||||
|
||||
func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
|
||||
func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error {
|
||||
table := s.GetTable(o.Table)
|
||||
|
||||
if err := addColumn(ctx, conn, *o, table); err != nil {
|
||||
return fmt.Errorf("failed to start add column operation: %w", err)
|
||||
}
|
||||
|
||||
if o.Up != nil {
|
||||
if err := createTrigger(ctx, conn, o, schemaName, stateSchema, s); err != nil {
|
||||
return fmt.Errorf("failed to create trigger: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
table.AddColumn(o.Column.Name, schema.Column{
|
||||
Name: TemporaryName(o.Column.Name),
|
||||
})
|
||||
@ -82,3 +89,74 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table
|
||||
))
|
||||
return err
|
||||
}
|
||||
|
||||
func createTrigger(ctx context.Context, conn *sql.DB, o *OpAddColumn, schemaName, stateSchema string, s *schema.Schema) error {
|
||||
triggerFnName := func(o *OpAddColumn) string {
|
||||
return "_pgroll_add_column_" + o.Table + "_" + o.Column.Name
|
||||
}
|
||||
triggerName := triggerFnName
|
||||
|
||||
// Generate the SQL declarations for the trigger function
|
||||
// This results in declarations like:
|
||||
// col1 table.col1%TYPE := NEW.col1;
|
||||
// Without these declarations, users would have to reference
|
||||
// `col1` as `NEW.col1` in their `up` SQL.
|
||||
sqlDeclarations := func(s *schema.Schema) string {
|
||||
table := s.GetTable(o.Table)
|
||||
|
||||
decls := ""
|
||||
for _, c := range table.Columns {
|
||||
decls += fmt.Sprintf("%[1]s %[2]s.%[1]s%%TYPE := NEW.%[1]s;\n",
|
||||
pq.QuoteIdentifier(c.Name),
|
||||
pq.QuoteIdentifier(table.Name))
|
||||
}
|
||||
return decls
|
||||
}
|
||||
|
||||
//nolint:gosec // I don't think we can avoid SQL injection warnings here when running arbitrary SQL
|
||||
triggerFn := fmt.Sprintf(`CREATE OR REPLACE FUNCTION %[1]s()
|
||||
RETURNS TRIGGER
|
||||
LANGUAGE PLPGSQL
|
||||
AS $$
|
||||
DECLARE
|
||||
%[4]s
|
||||
latest_schema text;
|
||||
search_path text;
|
||||
BEGIN
|
||||
SELECT %[5]s || '_' || latest_version INTO latest_schema FROM %[6]s.latest_version(%[5]s);
|
||||
SELECT current_setting INTO search_path FROM current_setting('search_path');
|
||||
|
||||
IF search_path <> latest_schema THEN
|
||||
NEW.%[2]s = %[3]s;
|
||||
END IF;
|
||||
|
||||
RETURN NEW;
|
||||
END; $$`,
|
||||
pq.QuoteIdentifier(triggerFnName(o)),
|
||||
pq.QuoteIdentifier(TemporaryName(o.Column.Name)),
|
||||
*o.Up,
|
||||
sqlDeclarations(s),
|
||||
pq.QuoteLiteral(schemaName),
|
||||
pq.QuoteIdentifier(stateSchema))
|
||||
|
||||
_, err := conn.ExecContext(ctx, triggerFn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
trigger := fmt.Sprintf(`CREATE OR REPLACE TRIGGER %[1]s
|
||||
BEFORE UPDATE OR INSERT
|
||||
ON %[2]s
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE %[3]s();`,
|
||||
pq.QuoteIdentifier(triggerName(o)),
|
||||
pq.QuoteIdentifier(o.Table),
|
||||
pq.QuoteIdentifier(triggerFnName(o)))
|
||||
|
||||
_, err = conn.ExecContext(ctx, trigger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -106,3 +106,72 @@ func TestAddColumn(t *testing.T) {
|
||||
},
|
||||
}})
|
||||
}
|
||||
|
||||
func TestAddColumnWithUpSql(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ptr := func(s string) *string { return &s }
|
||||
|
||||
ExecuteTests(t, TestCases{{
|
||||
name: "add column",
|
||||
migrations: []migrations.Migration{
|
||||
{
|
||||
Name: "01_add_table",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpCreateTable{
|
||||
Name: "products",
|
||||
Columns: []migrations.Column{
|
||||
{
|
||||
Name: "id",
|
||||
Type: "serial",
|
||||
PrimaryKey: true,
|
||||
},
|
||||
{
|
||||
Name: "name",
|
||||
Type: "varchar(255)",
|
||||
Unique: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "02_add_column",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpAddColumn{
|
||||
Table: "products",
|
||||
Up: ptr("UPPER(name)"),
|
||||
Column: migrations.Column{
|
||||
Name: "description",
|
||||
Type: "varchar(255)",
|
||||
Nullable: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
afterStart: func(t *testing.T, db *sql.DB) {
|
||||
// inserting via both the old and the new views works
|
||||
MustInsert(t, db, "public", "01_add_table", "products", map[string]string{
|
||||
"name": "apple",
|
||||
})
|
||||
MustInsert(t, db, "public", "02_add_column", "products", map[string]string{
|
||||
"name": "banana",
|
||||
"description": "a yellow banana",
|
||||
})
|
||||
|
||||
res := MustSelect(t, db, "public", "02_add_column", "products")
|
||||
assert.Equal(t, []map[string]any{
|
||||
// the description column has been populated for the product inserted into the old view.
|
||||
{"id": 1, "name": "apple", "description": "APPLE"},
|
||||
// the description column for the product inserted into the new view is as inserted.
|
||||
{"id": 2, "name": "banana", "description": "a yellow banana"},
|
||||
}, res)
|
||||
},
|
||||
afterRollback: func(t *testing.T, db *sql.DB) {
|
||||
// TODO check that the trigger created by the start operation has been removed
|
||||
},
|
||||
afterComplete: func(t *testing.T, db *sql.DB) {
|
||||
},
|
||||
}})
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
@ -282,6 +283,8 @@ func MustInsert(t *testing.T, db *sql.DB, schema, version, table string, record
|
||||
t.Helper()
|
||||
versionSchema := roll.VersionedSchemaName(schema, version)
|
||||
|
||||
mustSetSearchPath(t, db, versionSchema)
|
||||
|
||||
cols := maps.Keys(record)
|
||||
slices.Sort(cols)
|
||||
|
||||
@ -352,3 +355,12 @@ func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[s
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func mustSetSearchPath(t *testing.T, db *sql.DB, schema string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.Exec(fmt.Sprintf("SET search_path = %s", pq.QuoteIdentifier(schema)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ type Column struct {
|
||||
Default *string `json:"default"`
|
||||
}
|
||||
|
||||
func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
|
||||
func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error {
|
||||
tempName := TemporaryName(o.Name)
|
||||
_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (%s)",
|
||||
pq.QuoteIdentifier(tempName),
|
||||
|
@ -17,7 +17,7 @@ type OpRenameTable struct {
|
||||
To string `json:"to"`
|
||||
}
|
||||
|
||||
func (o *OpRenameTable) Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
|
||||
func (o *OpRenameTable) Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error {
|
||||
return s.RenameTable(o.From, o.To)
|
||||
}
|
||||
|
||||
|
@ -39,7 +39,7 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration) error
|
||||
|
||||
// execute operations
|
||||
for _, op := range migration.Operations {
|
||||
err := op.Start(ctx, m.pgConn, newSchema)
|
||||
err := op.Start(ctx, m.pgConn, m.schema, m.state.Schema(), newSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to execute start operation: %w", err)
|
||||
}
|
||||
|
@ -163,6 +163,10 @@ func (s *State) Close() error {
|
||||
return s.pgConn.Close()
|
||||
}
|
||||
|
||||
func (s *State) Schema() string {
|
||||
return s.schema
|
||||
}
|
||||
|
||||
// IsActiveMigrationPeriod returns true if there is an active migration
|
||||
func (s *State) IsActiveMigrationPeriod(ctx context.Context, schema string) (bool, error) {
|
||||
var isActive bool
|
||||
|
Loading…
Reference in New Issue
Block a user