From 73c201692906f2073b1a477f801315b37beee764 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20P=C3=A9rez-Aradros=20Herce?= Date: Mon, 15 Jan 2024 11:36:01 +0100 Subject: [PATCH] Add a way to set postgres role when executing migrations (#226) In same cases we want to set a specific role when executing migrations, so the ownerhsip of the created/updated objects is different from the pgroll user (storing pgroll state). This change allows to set a role that will be set in the connection executing migrations. --- cmd/flags/flags.go | 4 ++ cmd/root.go | 8 +++- docs/README.md | 2 + pkg/roll/execute_test.go | 34 ++++++++++++++++- pkg/roll/options.go | 27 ++++++++++++++ pkg/roll/roll.go | 22 +++++++++-- pkg/testutils/util.go | 79 +++++++++++++++++++++++++--------------- 7 files changed, 141 insertions(+), 35 deletions(-) create mode 100644 pkg/roll/options.go diff --git a/cmd/flags/flags.go b/cmd/flags/flags.go index b063804..6d821b4 100644 --- a/cmd/flags/flags.go +++ b/cmd/flags/flags.go @@ -21,3 +21,7 @@ func StateSchema() string { func LockTimeout() int { return viper.GetInt("LOCK_TIMEOUT") } + +func Role() string { + return viper.GetString("ROLE") +} diff --git a/cmd/root.go b/cmd/root.go index 9f9c27b..5936e72 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -23,11 +23,13 @@ func init() { rootCmd.PersistentFlags().String("schema", "public", "Postgres schema to use for the migration") rootCmd.PersistentFlags().String("pgroll-schema", "pgroll", "Postgres schema to use for pgroll internal state") rootCmd.PersistentFlags().Int("lock-timeout", 500, "Postgres lock timeout in milliseconds for pgroll DDL operations") + rootCmd.PersistentFlags().String("role", "", "Optional postgres role to set when executing migrations") viper.BindPFlag("PG_URL", rootCmd.PersistentFlags().Lookup("postgres-url")) viper.BindPFlag("SCHEMA", rootCmd.PersistentFlags().Lookup("schema")) viper.BindPFlag("STATE_SCHEMA", rootCmd.PersistentFlags().Lookup("pgroll-schema")) viper.BindPFlag("LOCK_TIMEOUT", rootCmd.PersistentFlags().Lookup("lock-timeout")) + viper.BindPFlag("ROLE", rootCmd.PersistentFlags().Lookup("role")) } var rootCmd = &cobra.Command{ @@ -41,13 +43,17 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) { schema := flags.Schema() stateSchema := flags.StateSchema() lockTimeout := flags.LockTimeout() + role := flags.Role() state, err := state.New(ctx, pgURL, stateSchema) if err != nil { return nil, err } - return roll.New(ctx, pgURL, schema, lockTimeout, state) + return roll.New(ctx, pgURL, schema, state, + roll.WithLockTimeoutMs(lockTimeout), + roll.WithRole(role), + ) } // Execute executes the root command. diff --git a/docs/README.md b/docs/README.md index 891af50..9c09737 100644 --- a/docs/README.md +++ b/docs/README.md @@ -537,12 +537,14 @@ The `pgroll` CLI has the following top-level flags: * `--schema`: The Postgres schema in which migrations will be run (default `"public"`). * `--pgroll-schema`: The Postgres schema in which `pgroll` will store its internal state (default: `"pgroll"`). * `--lock-timeout`: The Postgres `lock_timeout` value to use for all `pgroll` DDL operations, specified in milliseconds (default `500`). +* `--role``: The Postgres role to use for all `pgroll` DDL operations (default: `""`, which doesn't set any role). Each of these flags can also be set via an environment variable: * `PGROLL_PG_URL` * `PGROLL_SCHEMA` * `PGROLL_STATE_SCHEMA` * `PGROLL_LOCK_TIMEOUT` +* `PGROLL_ROLE` The CLI flag takes precedence if a flag is set via both an environment variable and a CLI flag. diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index 93c740c..198305f 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -253,7 +253,7 @@ func TestSchemaOptionIsRespected(t *testing.T) { func TestLockTimeoutIsEnforced(t *testing.T) { t.Parallel() - testutils.WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(100)}, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() // Start a create table migration @@ -458,6 +458,38 @@ func TestStatusMethodReturnsCorrectStatus(t *testing.T) { }) } +func TestRoleIsRespected(t *testing.T) { + t.Parallel() + + testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithRole("pgroll")}, func(mig *roll.Roll, db *sql.DB) { + ctx := context.Background() + + // Start a create table migration + err := mig.Start(ctx, &migrations.Migration{ + Name: "01_create_table", + Operations: migrations.Operations{createTableOp("table1")}, + }) + assert.NoError(t, err) + + // Complete the create table migration + err = mig.Complete(ctx) + assert.NoError(t, err) + + // Ensure that the table exists in the correct schema and owned by the correct role + var exists bool + err = db.QueryRow(` + SELECT EXISTS( + SELECT 1 + FROM pg_catalog.pg_tables + WHERE tablename = $1 + AND schemaname = $2 + AND tableowner = $3 + )`, "table1", "public", "pgroll").Scan(&exists) + assert.NoError(t, err) + assert.True(t, exists) + }) +} + func createTableOp(tableName string) *migrations.OpCreateTable { return &migrations.OpCreateTable{ Name: tableName, diff --git a/pkg/roll/options.go b/pkg/roll/options.go new file mode 100644 index 0000000..ddae63b --- /dev/null +++ b/pkg/roll/options.go @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: Apache-2.0 + +package roll + +type options struct { + // lock timeout in milliseconds for pgroll DDL operations + lockTimeoutMs int + + // optional role to set before executing migrations + role string +} + +type Option func(*options) + +// WithLockTimeoutMs sets the lock timeout in milliseconds for pgroll DDL operations +func WithLockTimeoutMs(lockTimeoutMs int) Option { + return func(o *options) { + o.lockTimeoutMs = lockTimeoutMs + } +} + +// WithRole sets the role to set before executing migrations +func WithRole(role string) Option { + return func(o *options) { + o.role = role + } +} diff --git a/pkg/roll/roll.go b/pkg/roll/roll.go index 30a200e..e15d0ff 100644 --- a/pkg/roll/roll.go +++ b/pkg/roll/roll.go @@ -26,7 +26,12 @@ type Roll struct { pgVersion PGVersion } -func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *state.State) (*Roll, error) { +func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...Option) (*Roll, error) { + options := &options{} + for _, o := range opts { + o(options) + } + dsn, err := pq.ParseURL(pgURL) if err != nil { dsn = pgURL @@ -48,9 +53,18 @@ func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *st return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err) } - _, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", lockTimeoutMs)) - if err != nil { - return nil, fmt.Errorf("unable to set lock_timeout: %w", err) + if options.lockTimeoutMs > 0 { + _, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", options.lockTimeoutMs)) + if err != nil { + return nil, fmt.Errorf("unable to set lock_timeout: %w", err) + } + } + + if options.role != "" { + _, err = conn.ExecContext(ctx, fmt.Sprintf("SET ROLE %s", options.role)) + if err != nil { + return nil, fmt.Errorf("unable to set role to '%s': %w", options.role, err) + } } var pgMajorVersion PGVersion diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go index 12836fe..ebdbd99 100644 --- a/pkg/testutils/util.go +++ b/pkg/testutils/util.go @@ -55,6 +55,17 @@ func SharedTestMain(m *testing.M) { os.Exit(1) } + db, err := sql.Open("postgres", tConnStr) + if err != nil { + os.Exit(1) + } + + // create handy role for tests + _, err = db.ExecContext(ctx, "CREATE ROLE pgroll") + if err != nil { + os.Exit(1) + } + exitCode := m.Run() if err := ctr.Terminate(ctx); err != nil { @@ -113,7 +124,7 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql. fn(st, db) } -func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) { +func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) { t.Helper() ctx := context.Background() @@ -143,50 +154,60 @@ func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, s u.Path = "/" + dbName connStr := u.String() - st, err := state.New(ctx, connStr, "pgroll") - if err != nil { - t.Fatal(err) - } - - err = st.Init(ctx) - if err != nil { - t.Fatal(err) - } - - mig, err := roll.New(ctx, connStr, schema, lockTimeoutMs, st) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := mig.Close(); err != nil { - t.Fatalf("Failed to close migrator connection: %v", err) - } - }) - db, err := sql.Open("postgres", connStr) if err != nil { t.Fatal(err) } - _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { if err := db.Close(); err != nil { t.Fatalf("Failed to close database connection: %v", err) } }) + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + err = st.Init(ctx) + if err != nil { + t.Fatal(err) + } + + mig, err := roll.New(ctx, connStr, schema, st, opts...) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := mig.Close(); err != nil { + t.Fatalf("Failed to close migrator connection: %v", err) + } + }) + + _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) + if err != nil { + t.Fatal(err) + } + + _, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON SCHEMA %s TO pgroll", schema)) + if err != nil { + t.Fatal(err) + } + + _, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO pgroll", dbName)) + if err != nil { + t.Fatal(err) + } + fn(mig, db) } func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) { - WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn) + WithMigratorInSchemaAndConnectionToContainerWithOptions(t, schema, []roll.Option{roll.WithLockTimeoutMs(500)}, fn) } func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { - WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 500, fn) + WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn) }