diff --git a/pkg/state/state.go b/pkg/state/state.go index 47a46b3..ec95ccc 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -73,18 +73,18 @@ STABLE; CREATE OR REPLACE FUNCTION %[1]s.previous_version(schemaname NAME) RETURNS text AS $$ WITH RECURSIVE find_ancestor AS ( - SELECT schema, name, parent, migration_type FROM pgroll.migrations + SELECT schema, name, parent, migration_type FROM %[1]s.migrations WHERE name = (SELECT %[1]s.latest_version(schemaname)) AND schema = schemaname UNION ALL - SELECT m.schema, m.name, m.parent, m.migration_type FROM pgroll.migrations m + SELECT m.schema, m.name, m.parent, m.migration_type FROM %[1]s.migrations m INNER JOIN find_ancestor fa ON fa.parent = m.name AND fa.schema = m.schema WHERE m.migration_type = 'inferred' ) SELECT a.parent FROM find_ancestor AS a - JOIN pgroll.migrations AS b ON a.parent = b.name AND a.schema = b.schema + JOIN %[1]s.migrations AS b ON a.parent = b.name AND a.schema = b.schema WHERE b.migration_type = 'pgroll'; $$ LANGUAGE SQL diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index 574835c..3f5a44c 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -114,6 +114,24 @@ func TestInferredMigration(t *testing.T) { }) } +func TestPgRollInitializationInANonDefaultSchema(t *testing.T) { + t.Parallel() + + testutils.WithStateInSchemaAndConnectionToContainer(t, "pgroll_foo", func(state *state.State, _ *sql.DB) { + ctx := context.Background() + + // Ensure that pgroll state has been correctly initialized in the + // non-default schema `pgroll_foo` by performing a basic operation on the + // state + migrationActive, err := state.IsActiveMigrationPeriod(ctx, "public") + if err != nil { + t.Fatal(err) + } + + assert.False(t, migrationActive) + }) +} + func TestReadSchema(t *testing.T) { t.Parallel() diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go index 4198575..45fe7d0 100644 --- a/pkg/testutils/util.go +++ b/pkg/testutils/util.go @@ -85,7 +85,7 @@ func TestSchema() string { return "public" } -func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { +func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(*state.State, *sql.DB)) { t.Helper() ctx := context.Background() @@ -115,7 +115,7 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql. u.Path = "/" + dbName connStr := u.String() - st, err := state.New(ctx, connStr, "pgroll") + st, err := state.New(ctx, connStr, schema) if err != nil { t.Fatal(err) } @@ -139,6 +139,10 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql. fn(st, db) } +func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { + WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn) +} + func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) { t.Helper() ctx := context.Background()