mirror of
https://github.com/xataio/pgroll.git
synced 2024-07-14 17:10:33 +03:00
Add a way to set postgres role when executing migrations (#226)
In same cases we want to set a specific role when executing migrations, so the ownerhsip of the created/updated objects is different from the pgroll user (storing pgroll state). This change allows to set a role that will be set in the connection executing migrations.
This commit is contained in:
parent
025a38f057
commit
73c2016929
@ -21,3 +21,7 @@ func StateSchema() string {
|
||||
func LockTimeout() int {
|
||||
return viper.GetInt("LOCK_TIMEOUT")
|
||||
}
|
||||
|
||||
func Role() string {
|
||||
return viper.GetString("ROLE")
|
||||
}
|
||||
|
@ -23,11 +23,13 @@ func init() {
|
||||
rootCmd.PersistentFlags().String("schema", "public", "Postgres schema to use for the migration")
|
||||
rootCmd.PersistentFlags().String("pgroll-schema", "pgroll", "Postgres schema to use for pgroll internal state")
|
||||
rootCmd.PersistentFlags().Int("lock-timeout", 500, "Postgres lock timeout in milliseconds for pgroll DDL operations")
|
||||
rootCmd.PersistentFlags().String("role", "", "Optional postgres role to set when executing migrations")
|
||||
|
||||
viper.BindPFlag("PG_URL", rootCmd.PersistentFlags().Lookup("postgres-url"))
|
||||
viper.BindPFlag("SCHEMA", rootCmd.PersistentFlags().Lookup("schema"))
|
||||
viper.BindPFlag("STATE_SCHEMA", rootCmd.PersistentFlags().Lookup("pgroll-schema"))
|
||||
viper.BindPFlag("LOCK_TIMEOUT", rootCmd.PersistentFlags().Lookup("lock-timeout"))
|
||||
viper.BindPFlag("ROLE", rootCmd.PersistentFlags().Lookup("role"))
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@ -41,13 +43,17 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) {
|
||||
schema := flags.Schema()
|
||||
stateSchema := flags.StateSchema()
|
||||
lockTimeout := flags.LockTimeout()
|
||||
role := flags.Role()
|
||||
|
||||
state, err := state.New(ctx, pgURL, stateSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return roll.New(ctx, pgURL, schema, lockTimeout, state)
|
||||
return roll.New(ctx, pgURL, schema, state,
|
||||
roll.WithLockTimeoutMs(lockTimeout),
|
||||
roll.WithRole(role),
|
||||
)
|
||||
}
|
||||
|
||||
// Execute executes the root command.
|
||||
|
@ -537,12 +537,14 @@ The `pgroll` CLI has the following top-level flags:
|
||||
* `--schema`: The Postgres schema in which migrations will be run (default `"public"`).
|
||||
* `--pgroll-schema`: The Postgres schema in which `pgroll` will store its internal state (default: `"pgroll"`).
|
||||
* `--lock-timeout`: The Postgres `lock_timeout` value to use for all `pgroll` DDL operations, specified in milliseconds (default `500`).
|
||||
* `--role``: The Postgres role to use for all `pgroll` DDL operations (default: `""`, which doesn't set any role).
|
||||
|
||||
Each of these flags can also be set via an environment variable:
|
||||
* `PGROLL_PG_URL`
|
||||
* `PGROLL_SCHEMA`
|
||||
* `PGROLL_STATE_SCHEMA`
|
||||
* `PGROLL_LOCK_TIMEOUT`
|
||||
* `PGROLL_ROLE`
|
||||
|
||||
The CLI flag takes precedence if a flag is set via both an environment variable and a CLI flag.
|
||||
|
||||
|
@ -253,7 +253,7 @@ func TestSchemaOptionIsRespected(t *testing.T) {
|
||||
func TestLockTimeoutIsEnforced(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testutils.WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) {
|
||||
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(100)}, func(mig *roll.Roll, db *sql.DB) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Start a create table migration
|
||||
@ -458,6 +458,38 @@ func TestStatusMethodReturnsCorrectStatus(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoleIsRespected(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithRole("pgroll")}, func(mig *roll.Roll, db *sql.DB) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Start a create table migration
|
||||
err := mig.Start(ctx, &migrations.Migration{
|
||||
Name: "01_create_table",
|
||||
Operations: migrations.Operations{createTableOp("table1")},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Complete the create table migration
|
||||
err = mig.Complete(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Ensure that the table exists in the correct schema and owned by the correct role
|
||||
var exists bool
|
||||
err = db.QueryRow(`
|
||||
SELECT EXISTS(
|
||||
SELECT 1
|
||||
FROM pg_catalog.pg_tables
|
||||
WHERE tablename = $1
|
||||
AND schemaname = $2
|
||||
AND tableowner = $3
|
||||
)`, "table1", "public", "pgroll").Scan(&exists)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
func createTableOp(tableName string) *migrations.OpCreateTable {
|
||||
return &migrations.OpCreateTable{
|
||||
Name: tableName,
|
||||
|
27
pkg/roll/options.go
Normal file
27
pkg/roll/options.go
Normal file
@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package roll
|
||||
|
||||
type options struct {
|
||||
// lock timeout in milliseconds for pgroll DDL operations
|
||||
lockTimeoutMs int
|
||||
|
||||
// optional role to set before executing migrations
|
||||
role string
|
||||
}
|
||||
|
||||
type Option func(*options)
|
||||
|
||||
// WithLockTimeoutMs sets the lock timeout in milliseconds for pgroll DDL operations
|
||||
func WithLockTimeoutMs(lockTimeoutMs int) Option {
|
||||
return func(o *options) {
|
||||
o.lockTimeoutMs = lockTimeoutMs
|
||||
}
|
||||
}
|
||||
|
||||
// WithRole sets the role to set before executing migrations
|
||||
func WithRole(role string) Option {
|
||||
return func(o *options) {
|
||||
o.role = role
|
||||
}
|
||||
}
|
@ -26,7 +26,12 @@ type Roll struct {
|
||||
pgVersion PGVersion
|
||||
}
|
||||
|
||||
func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *state.State) (*Roll, error) {
|
||||
func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...Option) (*Roll, error) {
|
||||
options := &options{}
|
||||
for _, o := range opts {
|
||||
o(options)
|
||||
}
|
||||
|
||||
dsn, err := pq.ParseURL(pgURL)
|
||||
if err != nil {
|
||||
dsn = pgURL
|
||||
@ -48,10 +53,19 @@ func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *st
|
||||
return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", lockTimeoutMs))
|
||||
if options.lockTimeoutMs > 0 {
|
||||
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", options.lockTimeoutMs))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to set lock_timeout: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if options.role != "" {
|
||||
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET ROLE %s", options.role))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to set role to '%s': %w", options.role, err)
|
||||
}
|
||||
}
|
||||
|
||||
var pgMajorVersion PGVersion
|
||||
err = conn.QueryRowContext(ctx, "SELECT split_part(split_part(version(), ' ', 2), '.', 1)").Scan(&pgMajorVersion)
|
||||
|
@ -55,6 +55,17 @@ func SharedTestMain(m *testing.M) {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", tConnStr)
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// create handy role for tests
|
||||
_, err = db.ExecContext(ctx, "CREATE ROLE pgroll")
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
exitCode := m.Run()
|
||||
|
||||
if err := ctr.Terminate(ctx); err != nil {
|
||||
@ -113,7 +124,7 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.
|
||||
fn(st, db)
|
||||
}
|
||||
|
||||
func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) {
|
||||
func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
@ -143,50 +154,60 @@ func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, s
|
||||
u.Path = "/" + dbName
|
||||
connStr := u.String()
|
||||
|
||||
st, err := state.New(ctx, connStr, "pgroll")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = st.Init(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mig, err := roll.New(ctx, connStr, schema, lockTimeoutMs, st)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := mig.Close(); err != nil {
|
||||
t.Fatalf("Failed to close migrator connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatalf("Failed to close database connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
st, err := state.New(ctx, connStr, "pgroll")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = st.Init(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mig, err := roll.New(ctx, connStr, schema, st, opts...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := mig.Close(); err != nil {
|
||||
t.Fatalf("Failed to close migrator connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON SCHEMA %s TO pgroll", schema))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO pgroll", dbName))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fn(mig, db)
|
||||
}
|
||||
|
||||
func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) {
|
||||
WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn)
|
||||
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, schema, []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
|
||||
}
|
||||
|
||||
func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
|
||||
WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 500, fn)
|
||||
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user