Add a way to set postgres role when executing migrations (#226)

In same cases we want to set a specific role when executing migrations,
so the ownerhsip of the created/updated objects is different from the
pgroll user (storing pgroll state). This change allows to set a role
that will be set in the connection executing migrations.
This commit is contained in:
Carlos Pérez-Aradros Herce 2024-01-15 11:36:01 +01:00 committed by GitHub
parent 025a38f057
commit 73c2016929
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 141 additions and 35 deletions

View File

@ -21,3 +21,7 @@ func StateSchema() string {
func LockTimeout() int {
return viper.GetInt("LOCK_TIMEOUT")
}
func Role() string {
return viper.GetString("ROLE")
}

View File

@ -23,11 +23,13 @@ func init() {
rootCmd.PersistentFlags().String("schema", "public", "Postgres schema to use for the migration")
rootCmd.PersistentFlags().String("pgroll-schema", "pgroll", "Postgres schema to use for pgroll internal state")
rootCmd.PersistentFlags().Int("lock-timeout", 500, "Postgres lock timeout in milliseconds for pgroll DDL operations")
rootCmd.PersistentFlags().String("role", "", "Optional postgres role to set when executing migrations")
viper.BindPFlag("PG_URL", rootCmd.PersistentFlags().Lookup("postgres-url"))
viper.BindPFlag("SCHEMA", rootCmd.PersistentFlags().Lookup("schema"))
viper.BindPFlag("STATE_SCHEMA", rootCmd.PersistentFlags().Lookup("pgroll-schema"))
viper.BindPFlag("LOCK_TIMEOUT", rootCmd.PersistentFlags().Lookup("lock-timeout"))
viper.BindPFlag("ROLE", rootCmd.PersistentFlags().Lookup("role"))
}
var rootCmd = &cobra.Command{
@ -41,13 +43,17 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) {
schema := flags.Schema()
stateSchema := flags.StateSchema()
lockTimeout := flags.LockTimeout()
role := flags.Role()
state, err := state.New(ctx, pgURL, stateSchema)
if err != nil {
return nil, err
}
return roll.New(ctx, pgURL, schema, lockTimeout, state)
return roll.New(ctx, pgURL, schema, state,
roll.WithLockTimeoutMs(lockTimeout),
roll.WithRole(role),
)
}
// Execute executes the root command.

View File

@ -537,12 +537,14 @@ The `pgroll` CLI has the following top-level flags:
* `--schema`: The Postgres schema in which migrations will be run (default `"public"`).
* `--pgroll-schema`: The Postgres schema in which `pgroll` will store its internal state (default: `"pgroll"`).
* `--lock-timeout`: The Postgres `lock_timeout` value to use for all `pgroll` DDL operations, specified in milliseconds (default `500`).
* `--role``: The Postgres role to use for all `pgroll` DDL operations (default: `""`, which doesn't set any role).
Each of these flags can also be set via an environment variable:
* `PGROLL_PG_URL`
* `PGROLL_SCHEMA`
* `PGROLL_STATE_SCHEMA`
* `PGROLL_LOCK_TIMEOUT`
* `PGROLL_ROLE`
The CLI flag takes precedence if a flag is set via both an environment variable and a CLI flag.

View File

@ -253,7 +253,7 @@ func TestSchemaOptionIsRespected(t *testing.T) {
func TestLockTimeoutIsEnforced(t *testing.T) {
t.Parallel()
testutils.WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(100)}, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
// Start a create table migration
@ -458,6 +458,38 @@ func TestStatusMethodReturnsCorrectStatus(t *testing.T) {
})
}
func TestRoleIsRespected(t *testing.T) {
t.Parallel()
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithRole("pgroll")}, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
// Start a create table migration
err := mig.Start(ctx, &migrations.Migration{
Name: "01_create_table",
Operations: migrations.Operations{createTableOp("table1")},
})
assert.NoError(t, err)
// Complete the create table migration
err = mig.Complete(ctx)
assert.NoError(t, err)
// Ensure that the table exists in the correct schema and owned by the correct role
var exists bool
err = db.QueryRow(`
SELECT EXISTS(
SELECT 1
FROM pg_catalog.pg_tables
WHERE tablename = $1
AND schemaname = $2
AND tableowner = $3
)`, "table1", "public", "pgroll").Scan(&exists)
assert.NoError(t, err)
assert.True(t, exists)
})
}
func createTableOp(tableName string) *migrations.OpCreateTable {
return &migrations.OpCreateTable{
Name: tableName,

27
pkg/roll/options.go Normal file
View File

@ -0,0 +1,27 @@
// SPDX-License-Identifier: Apache-2.0
package roll
type options struct {
// lock timeout in milliseconds for pgroll DDL operations
lockTimeoutMs int
// optional role to set before executing migrations
role string
}
type Option func(*options)
// WithLockTimeoutMs sets the lock timeout in milliseconds for pgroll DDL operations
func WithLockTimeoutMs(lockTimeoutMs int) Option {
return func(o *options) {
o.lockTimeoutMs = lockTimeoutMs
}
}
// WithRole sets the role to set before executing migrations
func WithRole(role string) Option {
return func(o *options) {
o.role = role
}
}

View File

@ -26,7 +26,12 @@ type Roll struct {
pgVersion PGVersion
}
func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *state.State) (*Roll, error) {
func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...Option) (*Roll, error) {
options := &options{}
for _, o := range opts {
o(options)
}
dsn, err := pq.ParseURL(pgURL)
if err != nil {
dsn = pgURL
@ -48,9 +53,18 @@ func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *st
return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err)
}
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", lockTimeoutMs))
if err != nil {
return nil, fmt.Errorf("unable to set lock_timeout: %w", err)
if options.lockTimeoutMs > 0 {
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", options.lockTimeoutMs))
if err != nil {
return nil, fmt.Errorf("unable to set lock_timeout: %w", err)
}
}
if options.role != "" {
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET ROLE %s", options.role))
if err != nil {
return nil, fmt.Errorf("unable to set role to '%s': %w", options.role, err)
}
}
var pgMajorVersion PGVersion

View File

@ -55,6 +55,17 @@ func SharedTestMain(m *testing.M) {
os.Exit(1)
}
db, err := sql.Open("postgres", tConnStr)
if err != nil {
os.Exit(1)
}
// create handy role for tests
_, err = db.ExecContext(ctx, "CREATE ROLE pgroll")
if err != nil {
os.Exit(1)
}
exitCode := m.Run()
if err := ctr.Terminate(ctx); err != nil {
@ -113,7 +124,7 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.
fn(st, db)
}
func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) {
func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()
@ -143,50 +154,60 @@ func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, s
u.Path = "/" + dbName
connStr := u.String()
st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}
err = st.Init(ctx)
if err != nil {
t.Fatal(err)
}
mig, err := roll.New(ctx, connStr, schema, lockTimeoutMs, st)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := mig.Close(); err != nil {
t.Fatalf("Failed to close migrator connection: %v", err)
}
})
db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}
_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})
st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}
err = st.Init(ctx)
if err != nil {
t.Fatal(err)
}
mig, err := roll.New(ctx, connStr, schema, st, opts...)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := mig.Close(); err != nil {
t.Fatalf("Failed to close migrator connection: %v", err)
}
})
_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
if err != nil {
t.Fatal(err)
}
_, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON SCHEMA %s TO pgroll", schema))
if err != nil {
t.Fatal(err)
}
_, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO pgroll", dbName))
if err != nil {
t.Fatal(err)
}
fn(mig, db)
}
func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) {
WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn)
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, schema, []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
}
func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 500, fn)
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
}