mirror of
https://github.com/xataio/pgroll.git
synced 2024-10-05 17:47:59 +03:00
Add a way to set postgres role when executing migrations (#226)
In same cases we want to set a specific role when executing migrations, so the ownerhsip of the created/updated objects is different from the pgroll user (storing pgroll state). This change allows to set a role that will be set in the connection executing migrations.
This commit is contained in:
parent
025a38f057
commit
73c2016929
@ -21,3 +21,7 @@ func StateSchema() string {
|
|||||||
func LockTimeout() int {
|
func LockTimeout() int {
|
||||||
return viper.GetInt("LOCK_TIMEOUT")
|
return viper.GetInt("LOCK_TIMEOUT")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Role() string {
|
||||||
|
return viper.GetString("ROLE")
|
||||||
|
}
|
||||||
|
@ -23,11 +23,13 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().String("schema", "public", "Postgres schema to use for the migration")
|
rootCmd.PersistentFlags().String("schema", "public", "Postgres schema to use for the migration")
|
||||||
rootCmd.PersistentFlags().String("pgroll-schema", "pgroll", "Postgres schema to use for pgroll internal state")
|
rootCmd.PersistentFlags().String("pgroll-schema", "pgroll", "Postgres schema to use for pgroll internal state")
|
||||||
rootCmd.PersistentFlags().Int("lock-timeout", 500, "Postgres lock timeout in milliseconds for pgroll DDL operations")
|
rootCmd.PersistentFlags().Int("lock-timeout", 500, "Postgres lock timeout in milliseconds for pgroll DDL operations")
|
||||||
|
rootCmd.PersistentFlags().String("role", "", "Optional postgres role to set when executing migrations")
|
||||||
|
|
||||||
viper.BindPFlag("PG_URL", rootCmd.PersistentFlags().Lookup("postgres-url"))
|
viper.BindPFlag("PG_URL", rootCmd.PersistentFlags().Lookup("postgres-url"))
|
||||||
viper.BindPFlag("SCHEMA", rootCmd.PersistentFlags().Lookup("schema"))
|
viper.BindPFlag("SCHEMA", rootCmd.PersistentFlags().Lookup("schema"))
|
||||||
viper.BindPFlag("STATE_SCHEMA", rootCmd.PersistentFlags().Lookup("pgroll-schema"))
|
viper.BindPFlag("STATE_SCHEMA", rootCmd.PersistentFlags().Lookup("pgroll-schema"))
|
||||||
viper.BindPFlag("LOCK_TIMEOUT", rootCmd.PersistentFlags().Lookup("lock-timeout"))
|
viper.BindPFlag("LOCK_TIMEOUT", rootCmd.PersistentFlags().Lookup("lock-timeout"))
|
||||||
|
viper.BindPFlag("ROLE", rootCmd.PersistentFlags().Lookup("role"))
|
||||||
}
|
}
|
||||||
|
|
||||||
var rootCmd = &cobra.Command{
|
var rootCmd = &cobra.Command{
|
||||||
@ -41,13 +43,17 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) {
|
|||||||
schema := flags.Schema()
|
schema := flags.Schema()
|
||||||
stateSchema := flags.StateSchema()
|
stateSchema := flags.StateSchema()
|
||||||
lockTimeout := flags.LockTimeout()
|
lockTimeout := flags.LockTimeout()
|
||||||
|
role := flags.Role()
|
||||||
|
|
||||||
state, err := state.New(ctx, pgURL, stateSchema)
|
state, err := state.New(ctx, pgURL, stateSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return roll.New(ctx, pgURL, schema, lockTimeout, state)
|
return roll.New(ctx, pgURL, schema, state,
|
||||||
|
roll.WithLockTimeoutMs(lockTimeout),
|
||||||
|
roll.WithRole(role),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute executes the root command.
|
// Execute executes the root command.
|
||||||
|
@ -537,12 +537,14 @@ The `pgroll` CLI has the following top-level flags:
|
|||||||
* `--schema`: The Postgres schema in which migrations will be run (default `"public"`).
|
* `--schema`: The Postgres schema in which migrations will be run (default `"public"`).
|
||||||
* `--pgroll-schema`: The Postgres schema in which `pgroll` will store its internal state (default: `"pgroll"`).
|
* `--pgroll-schema`: The Postgres schema in which `pgroll` will store its internal state (default: `"pgroll"`).
|
||||||
* `--lock-timeout`: The Postgres `lock_timeout` value to use for all `pgroll` DDL operations, specified in milliseconds (default `500`).
|
* `--lock-timeout`: The Postgres `lock_timeout` value to use for all `pgroll` DDL operations, specified in milliseconds (default `500`).
|
||||||
|
* `--role``: The Postgres role to use for all `pgroll` DDL operations (default: `""`, which doesn't set any role).
|
||||||
|
|
||||||
Each of these flags can also be set via an environment variable:
|
Each of these flags can also be set via an environment variable:
|
||||||
* `PGROLL_PG_URL`
|
* `PGROLL_PG_URL`
|
||||||
* `PGROLL_SCHEMA`
|
* `PGROLL_SCHEMA`
|
||||||
* `PGROLL_STATE_SCHEMA`
|
* `PGROLL_STATE_SCHEMA`
|
||||||
* `PGROLL_LOCK_TIMEOUT`
|
* `PGROLL_LOCK_TIMEOUT`
|
||||||
|
* `PGROLL_ROLE`
|
||||||
|
|
||||||
The CLI flag takes precedence if a flag is set via both an environment variable and a CLI flag.
|
The CLI flag takes precedence if a flag is set via both an environment variable and a CLI flag.
|
||||||
|
|
||||||
|
@ -253,7 +253,7 @@ func TestSchemaOptionIsRespected(t *testing.T) {
|
|||||||
func TestLockTimeoutIsEnforced(t *testing.T) {
|
func TestLockTimeoutIsEnforced(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testutils.WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) {
|
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(100)}, func(mig *roll.Roll, db *sql.DB) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// Start a create table migration
|
// Start a create table migration
|
||||||
@ -458,6 +458,38 @@ func TestStatusMethodReturnsCorrectStatus(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRoleIsRespected(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithRole("pgroll")}, func(mig *roll.Roll, db *sql.DB) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Start a create table migration
|
||||||
|
err := mig.Start(ctx, &migrations.Migration{
|
||||||
|
Name: "01_create_table",
|
||||||
|
Operations: migrations.Operations{createTableOp("table1")},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Complete the create table migration
|
||||||
|
err = mig.Complete(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Ensure that the table exists in the correct schema and owned by the correct role
|
||||||
|
var exists bool
|
||||||
|
err = db.QueryRow(`
|
||||||
|
SELECT EXISTS(
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_catalog.pg_tables
|
||||||
|
WHERE tablename = $1
|
||||||
|
AND schemaname = $2
|
||||||
|
AND tableowner = $3
|
||||||
|
)`, "table1", "public", "pgroll").Scan(&exists)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, exists)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func createTableOp(tableName string) *migrations.OpCreateTable {
|
func createTableOp(tableName string) *migrations.OpCreateTable {
|
||||||
return &migrations.OpCreateTable{
|
return &migrations.OpCreateTable{
|
||||||
Name: tableName,
|
Name: tableName,
|
||||||
|
27
pkg/roll/options.go
Normal file
27
pkg/roll/options.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package roll
|
||||||
|
|
||||||
|
type options struct {
|
||||||
|
// lock timeout in milliseconds for pgroll DDL operations
|
||||||
|
lockTimeoutMs int
|
||||||
|
|
||||||
|
// optional role to set before executing migrations
|
||||||
|
role string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Option func(*options)
|
||||||
|
|
||||||
|
// WithLockTimeoutMs sets the lock timeout in milliseconds for pgroll DDL operations
|
||||||
|
func WithLockTimeoutMs(lockTimeoutMs int) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.lockTimeoutMs = lockTimeoutMs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRole sets the role to set before executing migrations
|
||||||
|
func WithRole(role string) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.role = role
|
||||||
|
}
|
||||||
|
}
|
@ -26,7 +26,12 @@ type Roll struct {
|
|||||||
pgVersion PGVersion
|
pgVersion PGVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *state.State) (*Roll, error) {
|
func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...Option) (*Roll, error) {
|
||||||
|
options := &options{}
|
||||||
|
for _, o := range opts {
|
||||||
|
o(options)
|
||||||
|
}
|
||||||
|
|
||||||
dsn, err := pq.ParseURL(pgURL)
|
dsn, err := pq.ParseURL(pgURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
dsn = pgURL
|
dsn = pgURL
|
||||||
@ -48,9 +53,18 @@ func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *st
|
|||||||
return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err)
|
return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", lockTimeoutMs))
|
if options.lockTimeoutMs > 0 {
|
||||||
if err != nil {
|
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", options.lockTimeoutMs))
|
||||||
return nil, fmt.Errorf("unable to set lock_timeout: %w", err)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to set lock_timeout: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.role != "" {
|
||||||
|
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET ROLE %s", options.role))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to set role to '%s': %w", options.role, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var pgMajorVersion PGVersion
|
var pgMajorVersion PGVersion
|
||||||
|
@ -55,6 +55,17 @@ func SharedTestMain(m *testing.M) {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db, err := sql.Open("postgres", tConnStr)
|
||||||
|
if err != nil {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create handy role for tests
|
||||||
|
_, err = db.ExecContext(ctx, "CREATE ROLE pgroll")
|
||||||
|
if err != nil {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
exitCode := m.Run()
|
exitCode := m.Run()
|
||||||
|
|
||||||
if err := ctr.Terminate(ctx); err != nil {
|
if err := ctr.Terminate(ctx); err != nil {
|
||||||
@ -113,7 +124,7 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.
|
|||||||
fn(st, db)
|
fn(st, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) {
|
func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
@ -143,50 +154,60 @@ func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, s
|
|||||||
u.Path = "/" + dbName
|
u.Path = "/" + dbName
|
||||||
connStr := u.String()
|
connStr := u.String()
|
||||||
|
|
||||||
st, err := state.New(ctx, connStr, "pgroll")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = st.Init(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
mig, err := roll.New(ctx, connStr, schema, lockTimeoutMs, st)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
if err := mig.Close(); err != nil {
|
|
||||||
t.Fatalf("Failed to close migrator connection: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
db, err := sql.Open("postgres", connStr)
|
db, err := sql.Open("postgres", connStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
if err := db.Close(); err != nil {
|
if err := db.Close(); err != nil {
|
||||||
t.Fatalf("Failed to close database connection: %v", err)
|
t.Fatalf("Failed to close database connection: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
st, err := state.New(ctx, connStr, "pgroll")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = st.Init(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mig, err := roll.New(ctx, connStr, schema, st, opts...)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if err := mig.Close(); err != nil {
|
||||||
|
t.Fatalf("Failed to close migrator connection: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON SCHEMA %s TO pgroll", schema))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO pgroll", dbName))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
fn(mig, db)
|
fn(mig, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) {
|
func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) {
|
||||||
WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn)
|
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, schema, []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
|
func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
|
||||||
WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 500, fn)
|
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user