diff --git a/pkg/migrations/op_create_table_test.go b/pkg/migrations/op_create_table_test.go index 02e1c28..e977afc 100644 --- a/pkg/migrations/op_create_table_test.go +++ b/pkg/migrations/op_create_table_test.go @@ -67,61 +67,11 @@ func TestRecordsCanBeInsertedIntoAndReadFromNewViewAfterMigrationStart(t *testin t.Fatalf("Failed to start migration: %v", err) } - // - // Insert records via the view - // - sql := fmt.Sprintf(`INSERT INTO %s.%s (id, name) VALUES ($1, $2)`, - pq.QuoteIdentifier(versionSchema), - pq.QuoteIdentifier(viewName)) - - insertStmt, err := db.Prepare(sql) - if err != nil { - t.Fatal(err) - } - defer insertStmt.Close() - - type user struct { - ID int - Name string - } - inserted := []user{{ID: 1, Name: "Alice"}, {ID: 2, Name: "Bob"}} - - for _, v := range inserted { - _, err = insertStmt.Exec(v.ID, v.Name) - if err != nil { - t.Fatal(err) - } - } - - // - // Read the records back via the view - // - sql = fmt.Sprintf(`SELECT id, name FROM %q.%q`, versionSchema, viewName) - rows, err := db.Query(sql) - if err != nil { - t.Fatal(err) - } - defer rows.Close() - - var retrieved []user - for rows.Next() { - var user user - if err := rows.Scan(&user.ID, &user.Name); err != nil { - t.Fatal(err) - } - retrieved = append(retrieved, user) - } - if err := rows.Err(); err != nil { - t.Fatal(err) - } - - if !slices.Equal(inserted, retrieved) { - t.Error(cmp.Diff(inserted, retrieved)) - } + insertAndSelectRows(t, db, versionSchema) }) } -func TestViewSchemaAndTableAreDroppedAfterMigrationRevert(t *testing.T) { +func TestRecordsCanBeInsertedIntoAndReadFromNewViewAfterMigrationComplete(t *testing.T) { t.Parallel() withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { @@ -129,6 +79,24 @@ func TestViewSchemaAndTableAreDroppedAfterMigrationRevert(t *testing.T) { version := "1_create_table" versionSchema := roll.VersionedSchemaName(schema, version) + if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp()}}); err != nil { + t.Fatalf("Failed to start migration: %v", err) + } + if err := mig.Complete(ctx); err != nil { + t.Fatalf("Failed to complete migration: %v", err) + } + + insertAndSelectRows(t, db, versionSchema) + }) +} + +func TestTableIsDroppedAfterMigrationRollback(t *testing.T) { + t.Parallel() + + withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + ctx := context.Background() + version := "1_create_table" + migration := &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp()}} if err := mig.Start(ctx, migration); err != nil { @@ -157,44 +125,63 @@ func TestViewSchemaAndTableAreDroppedAfterMigrationRevert(t *testing.T) { if exists { t.Errorf("Expected table %q to not exist", tableName) } - - // - // Check that the view in the new schema has been dropped - // - err = db.QueryRow(` - SELECT EXISTS ( - SELECT 1 - FROM pg_catalog.pg_views - WHERE schemaname = $1 - AND viewname = $2 - ) `, versionSchema, viewName).Scan(&exists) - if err != nil { - t.Fatal(err) - } - - if exists { - t.Errorf("Expected view %q to not exist", viewName) - } - - // - // Check that the new schema has been dropped - // - err = db.QueryRow(` - SELECT EXISTS( - SELECT 1 - FROM pg_catalog.pg_namespace - WHERE nspname = $1 - )`, versionSchema).Scan(&exists) - if err != nil { - t.Fatal(err) - } - - if exists { - t.Errorf("Expected schema %q to not exist", version) - } }) } +func insertAndSelectRows(t *testing.T, db *sql.DB, schemaName string) { + // + // Insert records via the view + // + sql := fmt.Sprintf(`INSERT INTO %s.%s (id, name) VALUES ($1, $2)`, + pq.QuoteIdentifier(schemaName), + pq.QuoteIdentifier(viewName)) + + insertStmt, err := db.Prepare(sql) + if err != nil { + t.Fatal(err) + } + defer insertStmt.Close() + + type user struct { + ID int + Name string + } + inserted := []user{{ID: 1, Name: "Alice"}, {ID: 2, Name: "Bob"}} + + for _, v := range inserted { + _, err = insertStmt.Exec(v.ID, v.Name) + if err != nil { + t.Fatal(err) + } + } + + // + // Read the records back via the view + // + sql = fmt.Sprintf(`SELECT id, name FROM %q.%q`, schemaName, viewName) + rows, err := db.Query(sql) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + var retrieved []user + for rows.Next() { + var user user + if err := rows.Scan(&user.ID, &user.Name); err != nil { + t.Fatal(err) + } + retrieved = append(retrieved, user) + } + if err := rows.Err(); err != nil { + t.Fatal(err) + } + + if !slices.Equal(inserted, retrieved) { + t.Error(cmp.Diff(inserted, retrieved)) + } +} + func createTableOp() *migrations.OpCreateTable { return &migrations.OpCreateTable{ Name: viewName, diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index 683a974..edb9f00 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -20,6 +20,37 @@ const ( postgresImage = "postgres:15.3" ) +func TestSchemaIsCreatedfterMigrationStart(t *testing.T) { + t.Parallel() + + withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + ctx := context.Background() + version := "1_create_table" + + if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("table1")}}); err != nil { + t.Fatalf("Failed to start migration: %v", err) + } + + // + // Check that the schema exists + // + var exists bool + err := db.QueryRow(` + SELECT EXISTS( + SELECT 1 + FROM pg_catalog.pg_namespace + WHERE nspname = $1 + )`, roll.VersionedSchemaName(schema, version)).Scan(&exists) + if err != nil { + t.Fatal(err) + } + + if !exists { + t.Errorf("Expected schema %q to exist", version) + } + }) +} + func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { t.Parallel() @@ -63,6 +94,40 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { }) } +func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) { + t.Parallel() + + withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + ctx := context.Background() + version := "1_create_table" + + if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("table1")}}); err != nil { + t.Fatalf("Failed to start migration: %v", err) + } + if err := mig.Rollback(ctx); err != nil { + t.Fatalf("Failed to rollback migration: %v", err) + } + + // + // Check that the schema has been dropped + // + var exists bool + err := db.QueryRow(` + SELECT EXISTS( + SELECT 1 + FROM pg_catalog.pg_namespace + WHERE nspname = $1 + )`, roll.VersionedSchemaName(schema, version)).Scan(&exists) + if err != nil { + t.Fatal(err) + } + + if exists { + t.Errorf("Expected schema %q to not exist", version) + } + }) +} + func createTableOp(tableName string) *migrations.OpCreateTable { return &migrations.OpCreateTable{ Name: tableName,