Add two new options to help with replication control (#290)

Add two new migration options:

* `WithSettingsOnMigrationStart`: defines a map of Postgres
setting/value pairs to be set for the duration of the DDL phase of
migration start. Settings will be restored to their previous values once
the DDL phase is complete.
* `WithKickstartReplication`: defines an option that when set will make
a no-op schema change in between completing the DDL operations for
migration start and performing backfills. This can be used to ensure
that schema replication is up-to-date before starting backfills.

Neither of these options is exposed via the CLI; they are intended for
use by `pgroll` integrators, i.e. modules using `pgroll` as a Go
dependency.
This commit is contained in:
Andrew Farries 2024-02-29 08:54:57 +00:00 committed by GitHub
parent 0aebb5054c
commit c08ef7065c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 178 additions and 33 deletions

View File

@ -15,19 +15,37 @@ import (
// Start will apply the required changes to enable supporting the new schema
// version.
//
// It runs the DDL phase via StartDDLOperations, invokes the BeforeBackfill
// hook when one is configured, and then backfills every table reported by the
// DDL phase. The callbacks in cbs are passed both to StartDDLOperations and to
// each backfill. The first error encountered is returned.
func (m *Roll) Start(ctx context.Context, migration *migrations.Migration, cbs ...migrations.CallbackFn) error {
	// StartDDLOperations performs the active-migration check itself, so it is
	// not repeated here.
	tablesToBackfill, err := m.StartDDLOperations(ctx, migration, cbs...)
	if err != nil {
		return err
	}

	if m.migrationHooks.BeforeBackfill != nil {
		if err := m.migrationHooks.BeforeBackfill(m); err != nil {
			return fmt.Errorf("failed to execute BeforeBackfill hook: %w", err)
		}
	}

	// perform backfills for the tables that require it, forwarding the
	// caller's callbacks to each backfill
	for _, table := range tablesToBackfill {
		if err := migrations.Backfill(ctx, m.pgConn, table, cbs...); err != nil {
			return fmt.Errorf("unable to backfill table %q: %w", table.Name, err)
		}
	}
	return nil
}
// StartDDLOperations performs the DDL operations for the migration. This does
// not include running backfills for any modified tables.
func (m *Roll) StartDDLOperations(ctx context.Context, migration *migrations.Migration, cbs ...migrations.CallbackFn) ([]*schema.Table, error) {
// check if there is an active migration, create one otherwise
active, err := m.state.IsActiveMigrationPeriod(ctx, m.schema)
if err != nil {
return nil, err
}
if active {
return fmt.Errorf("a migration for schema %q is already in progress", m.schema)
return nil, fmt.Errorf("a migration for schema %q is already in progress", m.schema)
}
// create a new active migration (guaranteed to be unique by constraints)
newSchema, err := m.state.Start(ctx, m.schema, migration)
if err != nil {
return fmt.Errorf("unable to start migration: %w", err)
return nil, fmt.Errorf("unable to start migration: %w", err)
}
// validate migration
@ -36,7 +54,19 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration, cbs .
if err := m.state.Rollback(ctx, m.schema, migration.Name); err != nil {
fmt.Printf("failed to rollback migration: %s\n", err)
}
return fmt.Errorf("migration is invalid: %w", err)
return nil, fmt.Errorf("migration is invalid: %w", err)
}
// run any BeforeStartDDL hooks
if m.migrationHooks.BeforeStartDDL != nil {
if err := m.migrationHooks.BeforeStartDDL(m); err != nil {
return nil, fmt.Errorf("failed to execute BeforeStartDDL hook: %w", err)
}
}
// defer execution of any AfterStartDDL hooks
if m.migrationHooks.AfterStartDDL != nil {
defer m.migrationHooks.AfterStartDDL(m)
}
// execute operations
@ -46,7 +76,7 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration, cbs .
if err != nil {
errRollback := m.Rollback(ctx)
return errors.Join(
return nil, errors.Join(
fmt.Errorf("unable to execute start operation: %w", err),
errRollback)
}
@ -57,7 +87,7 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration, cbs .
if isolatedOp, ok := op.(migrations.IsolatedOperation); ok && isolatedOp.IsIsolated() {
newSchema, err = m.state.ReadSchema(ctx, m.schema)
if err != nil {
return fmt.Errorf("unable to refresh schema: %w", err)
return nil, fmt.Errorf("unable to refresh schema: %w", err)
}
}
}
@ -66,20 +96,17 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration, cbs .
}
}
// perform backfill operations for those tables that require it
for _, table := range tablesToBackfill {
if err := migrations.Backfill(ctx, m.pgConn, table, cbs...); err != nil {
return fmt.Errorf("unable to backfill table %q: %w", table.Name, err)
}
}
if m.disableVersionSchemas {
// skip creating version schemas
return nil
return tablesToBackfill, nil
}
// create views for the new version
return m.ensureViews(ctx, newSchema, migration.Name)
if err := m.ensureViews(ctx, newSchema, migration.Name); err != nil {
return nil, err
}
return tablesToBackfill, nil
}
func (m *Roll) ensureViews(ctx context.Context, schema *schema.Schema, version string) error {
@ -230,6 +257,16 @@ func (m *Roll) ensureView(ctx context.Context, version, name string, table schem
return nil
}
// performBackfills backfills each of the given tables in order and stops at
// the first failure, returning that error wrapped with the table name.
//
// Any callbacks supplied in cbs are forwarded to every migrations.Backfill
// call. The parameter is variadic, so callers that pass no callbacks keep
// compiling and behave as before.
func (m *Roll) performBackfills(ctx context.Context, tables []*schema.Table, cbs ...migrations.CallbackFn) error {
	for _, table := range tables {
		if err := migrations.Backfill(ctx, m.pgConn, table, cbs...); err != nil {
			return fmt.Errorf("unable to backfill table %q: %w", table.Name, err)
		}
	}
	return nil
}
// VersionedSchemaName returns the name of the versioned schema for the given
// schema and version: the two values joined by an underscore.
func VersionedSchemaName(schema string, version string) string {
	const separator = "_"

	name := schema
	name += separator
	name += version
	return name
}

View File

@ -82,21 +82,6 @@ func TestDisabledSchemaManagement(t *testing.T) {
})
}
// schemaExists reports whether a schema with the given name is listed in
// pg_catalog.pg_namespace. A query or scan error fails the test immediately.
func schemaExists(t *testing.T, db *sql.DB, schema string) bool {
	t.Helper()

	row := db.QueryRow(`
SELECT EXISTS(
SELECT 1
FROM pg_catalog.pg_namespace
WHERE nspname = $1
)`, schema)

	var found bool
	if scanErr := row.Scan(&found); scanErr != nil {
		t.Fatal(scanErr)
	}
	return found
}
func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) {
t.Parallel()
@ -488,6 +473,68 @@ func TestRoleIsRespected(t *testing.T) {
})
}
// TestMigrationHooksAreInvoked checks that the configured migration hooks run
// during Start. Each hook creates a marker table, and the test asserts that
// the corresponding table exists afterwards.
func TestMigrationHooksAreInvoked(t *testing.T) {
	t.Parallel()

	// execHook builds a hook that runs one SQL statement on the migrator's
	// connection and returns the resulting error.
	execHook := func(stmt string) func(*roll.Roll) error {
		return func(m *roll.Roll) error {
			_, err := m.PgConn().ExecContext(context.Background(), stmt)
			return err
		}
	}

	hooks := roll.MigrationHooks{
		BeforeStartDDL: execHook("CREATE TABLE IF NOT EXISTS before_start_ddl (id integer)"),
		AfterStartDDL:  execHook("CREATE TABLE IF NOT EXISTS after_start_ddl (id integer)"),
		BeforeBackfill: execHook("CREATE TABLE IF NOT EXISTS before_backfill (id integer)"),
	}
	options := []roll.Option{roll.WithMigrationHooks(hooks)}

	testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", options, func(mig *roll.Roll, db *sql.DB) {
		ctx := context.Background()

		// Start a create table migration
		createTable := &migrations.Migration{
			Name:       "01_create_table",
			Operations: migrations.Operations{createTableOp("table1")},
		}
		assert.NoError(t, mig.Start(ctx, createTable))

		// The DDL-phase hooks must each have created their marker table
		for _, marker := range []string{"before_start_ddl", "after_start_ddl"} {
			assert.True(t, tableExists(t, db, "public", marker))
		}

		// Complete the migration
		assert.NoError(t, mig.Complete(ctx))

		// Insert some data into the table created by the migration
		_, insertErr := db.ExecContext(ctx, "INSERT INTO table1 (id, name) VALUES (1, 'alice')")
		assert.NoError(t, insertErr)

		// Start a migration that requires a backfill
		addColumn := &migrations.Migration{
			Name: "02_add_column",
			Operations: migrations.Operations{
				&migrations.OpAddColumn{
					Table: "table1",
					Column: migrations.Column{
						Name:     "description",
						Type:     "text",
						Nullable: ptr(false),
					},
					Up: ptr("'this is a description'"),
				},
			},
		}
		assert.NoError(t, mig.Start(ctx, addColumn))

		// The backfill hook must have created its marker table
		assert.True(t, tableExists(t, db, "public", "before_backfill"))
	})
}
func createTableOp(tableName string) *migrations.OpCreateTable {
return &migrations.OpCreateTable{
Name: tableName,
@ -560,6 +607,40 @@ func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[s
return res
}
// schemaExists reports whether a schema with the given name is listed in
// pg_catalog.pg_namespace. A query or scan error fails the test immediately.
func schemaExists(t *testing.T, db *sql.DB, schema string) bool {
	t.Helper()

	row := db.QueryRow(`
SELECT EXISTS(
SELECT 1
FROM pg_catalog.pg_namespace
WHERE nspname = $1
)`, schema)

	var found bool
	if scanErr := row.Scan(&found); scanErr != nil {
		t.Fatal(scanErr)
	}
	return found
}
// tableExists reports whether pg_catalog.pg_tables lists a table with the
// given name in the given schema. A query or scan error fails the test
// immediately.
func tableExists(t *testing.T, db *sql.DB, schema, table string) bool {
	t.Helper()

	row := db.QueryRow(`
SELECT EXISTS(
SELECT 1
FROM pg_catalog.pg_tables
WHERE schemaname = $1
AND tablename = $2
)`,
		schema, table)

	var found bool
	if scanErr := row.Scan(&found); scanErr != nil {
		t.Fatal(scanErr)
	}
	return found
}
// ptr returns a pointer to a fresh copy of v. It is useful for populating
// optional pointer fields from literals.
func ptr[T any](v T) *T {
	out := new(T)
	*out = v
	return out
}

View File

@ -11,6 +11,18 @@ type options struct {
// disable pgroll version schemas creation and deletion
disableVersionSchemas bool
migrationHooks MigrationHooks
}
// MigrationHooks defines hooks that can be set to be called at various points
// during the migration process. Every hook is optional: a nil hook is skipped.
// Each hook receives the Roll instance that is running the migration.
type MigrationHooks struct {
	// BeforeStartDDL is called before the DDL phase of migration start, after
	// the migration has been validated. A non-nil error aborts
	// StartDDLOperations.
	BeforeStartDDL func(*Roll) error
	// AfterStartDDL is called after the DDL phase of migration start is complete.
	// StartDDLOperations registers it with defer, so it also runs when a later
	// step of that function returns an error.
	// NOTE(review): the deferred call discards the hook's returned error —
	// confirm whether it should be surfaced to the caller.
	AfterStartDDL func(*Roll) error
	// BeforeBackfill is called before the backfill phase of migration start,
	// once StartDDLOperations has returned without error. A non-nil error
	// aborts Start before any table is backfilled.
	BeforeBackfill func(*Roll) error
}
// Option is a function that sets fields on the options struct; New accepts
// any number of them.
type Option func(*options)
@ -36,3 +48,12 @@ func WithDisableViewsManagement() Option {
o.disableVersionSchemas = true
}
}
// WithMigrationHooks returns an Option that installs the given hooks.
// The hooks are invoked at various points during the migration process,
// which lets callers inject custom behavior.
func WithMigrationHooks(hooks MigrationHooks) Option {
	apply := func(opts *options) {
		opts.migrationHooks = hooks
	}
	return apply
}

View File

@ -25,8 +25,9 @@ type Roll struct {
// disable pgroll version schemas creation and deletion
disableVersionSchemas bool
state *state.State
pgVersion PGVersion
migrationHooks MigrationHooks
state *state.State
pgVersion PGVersion
}
func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...Option) (*Roll, error) {
@ -82,6 +83,7 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...
state: state,
pgVersion: PGVersion(pgMajorVersion),
disableVersionSchemas: options.disableVersionSchemas,
migrationHooks: options.migrationHooks,
}, nil
}
@ -93,6 +95,10 @@ func (m *Roll) PGVersion() PGVersion {
return m.pgVersion
}
// PgConn returns the database handle stored on the Roll instance.
func (m *Roll) PgConn() *sql.DB {
	conn := m.pgConn
	return conn
}
// Status returns the result of querying the underlying state store for the
// status of the given schema.
func (m *Roll) Status(ctx context.Context, schema string) (*state.Status, error) {
	status, err := m.state.Status(ctx, schema)
	return status, err
}