Support up SQL on add column operations (#34)

Add a new field `Up` to **add column** migrations:

```json
{
  "name": "03_add_column_to_products",
  "operations": [
    {
      "add_column": {
        "table": "products",
        "up": "UPPER(name)",
        "column": {
          "name": "description",
          "type": "varchar(255)",
          "nullable": true
        }
      }
    }
  ]
}
```

The SQL specified by the `up` field will be run whenever an row is
inserted into the underlying table when the session's `search_path` is
not set to the latest version of the schema.

The `up` SQL snippet can refer to existing columns in the table by name
(as in the the above example, where the `description` field is set to
`UPPER(name)`).
This commit is contained in:
Andrew Farries 2023-07-20 06:37:03 +01:00 committed by GitHub
parent 37b75384a9
commit b4efd8ad50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 171 additions and 7 deletions

View File

@ -4,6 +4,7 @@
{
"add_column": {
"table": "products",
"up": "UPPER(name)",
"column": {
"name": "description",
"type": "varchar(255)",

View File

@ -13,7 +13,7 @@ type Operation interface {
// Start will apply the required changes to enable supporting the new schema
// version in the database (through a view)
// update the given views to expose the new schema version
Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error
Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error
// Complete will update the database schema to match the current version
// after calling Start.

View File

@ -12,19 +12,26 @@ import (
)
type OpAddColumn struct {
Table string `json:"table"`
Column Column `json:"column"`
Table string `json:"table"`
Up *string `json:"up"`
Column Column `json:"column"`
}
var _ Operation = (*OpAddColumn)(nil)
func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error {
table := s.GetTable(o.Table)
if err := addColumn(ctx, conn, *o, table); err != nil {
return fmt.Errorf("failed to start add column operation: %w", err)
}
if o.Up != nil {
if err := createTrigger(ctx, conn, o, schemaName, stateSchema, s); err != nil {
return fmt.Errorf("failed to create trigger: %w", err)
}
}
table.AddColumn(o.Column.Name, schema.Column{
Name: TemporaryName(o.Column.Name),
})
@ -82,3 +89,74 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table
))
return err
}
func createTrigger(ctx context.Context, conn *sql.DB, o *OpAddColumn, schemaName, stateSchema string, s *schema.Schema) error {
triggerFnName := func(o *OpAddColumn) string {
return "_pgroll_add_column_" + o.Table + "_" + o.Column.Name
}
triggerName := triggerFnName
// Generate the SQL declarations for the trigger function
// This results in declarations like:
// col1 table.col1%TYPE := NEW.col1;
// Without these declarations, users would have to reference
// `col1` as `NEW.col1` in their `up` SQL.
sqlDeclarations := func(s *schema.Schema) string {
table := s.GetTable(o.Table)
decls := ""
for _, c := range table.Columns {
decls += fmt.Sprintf("%[1]s %[2]s.%[1]s%%TYPE := NEW.%[1]s;\n",
pq.QuoteIdentifier(c.Name),
pq.QuoteIdentifier(table.Name))
}
return decls
}
//nolint:gosec // I don't think we can avoid SQL injection warnings here when running arbitrary SQL
triggerFn := fmt.Sprintf(`CREATE OR REPLACE FUNCTION %[1]s()
RETURNS TRIGGER
LANGUAGE PLPGSQL
AS $$
DECLARE
%[4]s
latest_schema text;
search_path text;
BEGIN
SELECT %[5]s || '_' || latest_version INTO latest_schema FROM %[6]s.latest_version(%[5]s);
SELECT current_setting INTO search_path FROM current_setting('search_path');
IF search_path <> latest_schema THEN
NEW.%[2]s = %[3]s;
END IF;
RETURN NEW;
END; $$`,
pq.QuoteIdentifier(triggerFnName(o)),
pq.QuoteIdentifier(TemporaryName(o.Column.Name)),
*o.Up,
sqlDeclarations(s),
pq.QuoteLiteral(schemaName),
pq.QuoteIdentifier(stateSchema))
_, err := conn.ExecContext(ctx, triggerFn)
if err != nil {
return err
}
trigger := fmt.Sprintf(`CREATE OR REPLACE TRIGGER %[1]s
BEFORE UPDATE OR INSERT
ON %[2]s
FOR EACH ROW
EXECUTE PROCEDURE %[3]s();`,
pq.QuoteIdentifier(triggerName(o)),
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(triggerFnName(o)))
_, err = conn.ExecContext(ctx, trigger)
if err != nil {
return err
}
return nil
}

View File

@ -106,3 +106,72 @@ func TestAddColumn(t *testing.T) {
},
}})
}
func TestAddColumnWithUpSql(t *testing.T) {
t.Parallel()
ptr := func(s string) *string { return &s }
ExecuteTests(t, TestCases{{
name: "add column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "products",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "name",
Type: "varchar(255)",
Unique: true,
},
},
},
},
},
{
Name: "02_add_column",
Operations: migrations.Operations{
&migrations.OpAddColumn{
Table: "products",
Up: ptr("UPPER(name)"),
Column: migrations.Column{
Name: "description",
Type: "varchar(255)",
Nullable: true,
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// inserting via both the old and the new views works
MustInsert(t, db, "public", "01_add_table", "products", map[string]string{
"name": "apple",
})
MustInsert(t, db, "public", "02_add_column", "products", map[string]string{
"name": "banana",
"description": "a yellow banana",
})
res := MustSelect(t, db, "public", "02_add_column", "products")
assert.Equal(t, []map[string]any{
// the description column has been populated for the product inserted into the old view.
{"id": 1, "name": "apple", "description": "APPLE"},
// the description column for the product inserted into the new view is as inserted.
{"id": 2, "name": "banana", "description": "a yellow banana"},
}, res)
},
afterRollback: func(t *testing.T, db *sql.DB) {
// TODO check that the trigger created by the start operation has been removed
},
afterComplete: func(t *testing.T, db *sql.DB) {
},
}})
}

View File

@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/lib/pq"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
@ -282,6 +283,8 @@ func MustInsert(t *testing.T, db *sql.DB, schema, version, table string, record
t.Helper()
versionSchema := roll.VersionedSchemaName(schema, version)
mustSetSearchPath(t, db, versionSchema)
cols := maps.Keys(record)
slices.Sort(cols)
@ -352,3 +355,12 @@ func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[s
return res
}
func mustSetSearchPath(t *testing.T, db *sql.DB, schema string) {
t.Helper()
_, err := db.Exec(fmt.Sprintf("SET search_path = %s", pq.QuoteIdentifier(schema)))
if err != nil {
t.Fatal(err)
}
}

View File

@ -26,7 +26,7 @@ type Column struct {
Default *string `json:"default"`
}
func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error {
tempName := TemporaryName(o.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (%s)",
pq.QuoteIdentifier(tempName),

View File

@ -17,7 +17,7 @@ type OpRenameTable struct {
To string `json:"to"`
}
func (o *OpRenameTable) Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpRenameTable) Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error {
return s.RenameTable(o.From, o.To)
}

View File

@ -39,7 +39,7 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration) error
// execute operations
for _, op := range migration.Operations {
err := op.Start(ctx, m.pgConn, newSchema)
err := op.Start(ctx, m.pgConn, m.schema, m.state.Schema(), newSchema)
if err != nil {
return fmt.Errorf("unable to execute start operation: %w", err)
}

View File

@ -163,6 +163,10 @@ func (s *State) Close() error {
return s.pgConn.Close()
}
func (s *State) Schema() string {
return s.schema
}
// IsActiveMigrationPeriod returns true if there is an active migration
func (s *State) IsActiveMigrationPeriod(ctx context.Context, schema string) (bool, error) {
var isActive bool