Rewrite raw SQL operations using a SQL transformer (#330)

Use the SQL transformer to transform the `up` and `down` fields of a raw
SQL migration.

Builds on https://github.com/xataio/pgroll/pull/329 which added a
`SQLTransformer` option to rewrite user-supplied SQL.
This commit is contained in:
Andrew Farries 2024-03-27 10:59:07 +00:00 committed by GitHub
parent cc8c2d38ba
commit 0673b1b470
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 103 additions and 11 deletions

View File

@ -32,12 +32,12 @@ func TestMain(m *testing.M) {
testutils.SharedTestMain(m)
}
func ExecuteTests(t *testing.T, tests TestCases) {
func ExecuteTests(t *testing.T, tests TestCases, opts ...roll.Option) {
testSchema := testutils.TestSchema()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testutils.WithMigratorInSchemaAndConnectionToContainer(t, testSchema, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, testSchema, opts, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
// run all migrations except the last one

View File

@ -12,27 +12,45 @@ import (
var _ Operation = (*OpRawSQL)(nil)
func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
if !o.OnComplete {
_, err := conn.ExecContext(ctx, o.Up)
if o.OnComplete {
return nil, nil
}
up, err := tr.TransformSQL(o.Up)
if err != nil {
return nil, err
}
return nil, nil
_, err = conn.ExecContext(ctx, up)
return nil, err
}
func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
if o.OnComplete {
_, err := conn.ExecContext(ctx, o.Up)
if !o.OnComplete {
return nil
}
up, err := tr.TransformSQL(o.Up)
if err != nil {
return err
}
return nil
_, err = conn.ExecContext(ctx, up)
return err
}
func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
if o.Down != "" {
_, err := conn.ExecContext(ctx, o.Down)
if o.Down == "" {
return nil
}
down, err := tr.TransformSQL(o.Down)
if err != nil {
return err
}
return nil
_, err = conn.ExecContext(ctx, down)
return err
}
func (o *OpRawSQL) Validate(ctx context.Context, s *schema.Schema) error {

View File

@ -4,9 +4,11 @@ package migrations_test
import (
"database/sql"
"fmt"
"testing"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/roll"
)
func TestRawSQL(t *testing.T) {
@ -184,3 +186,75 @@ func TestRawSQL(t *testing.T) {
},
})
}
func TestRawSQLTransformation(t *testing.T) {
t.Parallel()
t.Run("for normal raw SQL operations with up and down SQL", func(t *testing.T) {
ExecuteTests(t, TestCases{
{
name: "SQL transformer rewrites up and down SQL",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpRawSQL{
Up: "CREATE TABLE apples(id int)",
Down: "CREATE TABLE bananas(id int)",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// The transformed `up` SQL was used in place of the original SQL
TableMustExist(t, db, schema, "table_1")
TableMustNotExist(t, db, schema, "apples")
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
// The transformed `down` SQL was used in place of the original SQL
TableMustExist(t, db, schema, "table_2")
TableMustNotExist(t, db, schema, "bananas")
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
},
},
}, roll.WithSQLTransformer(&simpleSQLTransformer{}))
})
t.Run("for raw SQL operations that run on complete", func(t *testing.T) {
ExecuteTests(t, TestCases{
{
name: "SQL transformer rewrites up SQL when up is run on completion",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpRawSQL{
Up: "CREATE TABLE apples(id int)",
OnComplete: true,
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// The transformed `up` SQL was used in place of the original SQL
TableMustExist(t, db, schema, "table_1")
TableMustNotExist(t, db, schema, "apples")
},
},
}, roll.WithSQLTransformer(&simpleSQLTransformer{}))
})
}
type simpleSQLTransformer struct {
counter int
}
func (s *simpleSQLTransformer) TransformSQL(sql string) (string, error) {
s.counter++
return fmt.Sprintf("CREATE TABLE table_%d(id int)", s.counter), nil
}