mirror of
https://github.com/xataio/pgroll.git
synced 2024-09-17 16:57:30 +03:00
Rewrite raw SQL operations using a SQL transformer (#330)
Use the SQL transformer to transform the `up` and `down` fields of a raw SQL migration. Builds on https://github.com/xataio/pgroll/pull/329 which added a `SQLTransformer` option to rewrite user-supplied SQL.
This commit is contained in:
parent
cc8c2d38ba
commit
0673b1b470
@ -32,12 +32,12 @@ func TestMain(m *testing.M) {
|
||||
testutils.SharedTestMain(m)
|
||||
}
|
||||
|
||||
func ExecuteTests(t *testing.T, tests TestCases) {
|
||||
func ExecuteTests(t *testing.T, tests TestCases, opts ...roll.Option) {
|
||||
testSchema := testutils.TestSchema()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testutils.WithMigratorInSchemaAndConnectionToContainer(t, testSchema, func(mig *roll.Roll, db *sql.DB) {
|
||||
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, testSchema, opts, func(mig *roll.Roll, db *sql.DB) {
|
||||
ctx := context.Background()
|
||||
|
||||
// run all migrations except the last one
|
||||
|
@ -12,27 +12,45 @@ import (
|
||||
var _ Operation = (*OpRawSQL)(nil)
|
||||
|
||||
func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
|
||||
if !o.OnComplete {
|
||||
_, err := conn.ExecContext(ctx, o.Up)
|
||||
if o.OnComplete {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
up, err := tr.TransformSQL(o.Up)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
|
||||
_, err = conn.ExecContext(ctx, up)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
|
||||
if o.OnComplete {
|
||||
_, err := conn.ExecContext(ctx, o.Up)
|
||||
if !o.OnComplete {
|
||||
return nil
|
||||
}
|
||||
|
||||
up, err := tr.TransformSQL(o.Up)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
_, err = conn.ExecContext(ctx, up)
|
||||
return err
|
||||
}
|
||||
|
||||
func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
|
||||
if o.Down != "" {
|
||||
_, err := conn.ExecContext(ctx, o.Down)
|
||||
if o.Down == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
down, err := tr.TransformSQL(o.Down)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
_, err = conn.ExecContext(ctx, down)
|
||||
return err
|
||||
}
|
||||
|
||||
func (o *OpRawSQL) Validate(ctx context.Context, s *schema.Schema) error {
|
||||
|
@ -4,9 +4,11 @@ package migrations_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/xataio/pgroll/pkg/migrations"
|
||||
"github.com/xataio/pgroll/pkg/roll"
|
||||
)
|
||||
|
||||
func TestRawSQL(t *testing.T) {
|
||||
@ -184,3 +186,75 @@ func TestRawSQL(t *testing.T) {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestRawSQLTransformation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("for normal raw SQL operations with up and down SQL", func(t *testing.T) {
|
||||
ExecuteTests(t, TestCases{
|
||||
{
|
||||
name: "SQL transformer rewrites up and down SQL",
|
||||
migrations: []migrations.Migration{
|
||||
{
|
||||
Name: "01_create_table",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpRawSQL{
|
||||
Up: "CREATE TABLE apples(id int)",
|
||||
Down: "CREATE TABLE bananas(id int)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
afterStart: func(t *testing.T, db *sql.DB, schema string) {
|
||||
// The transformed `up` SQL was used in place of the original SQL
|
||||
TableMustExist(t, db, schema, "table_1")
|
||||
TableMustNotExist(t, db, schema, "apples")
|
||||
},
|
||||
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
|
||||
// The transformed `down` SQL was used in place of the original SQL
|
||||
TableMustExist(t, db, schema, "table_2")
|
||||
TableMustNotExist(t, db, schema, "bananas")
|
||||
},
|
||||
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
|
||||
},
|
||||
},
|
||||
}, roll.WithSQLTransformer(&simpleSQLTransformer{}))
|
||||
})
|
||||
|
||||
t.Run("for raw SQL operations that run on complete", func(t *testing.T) {
|
||||
ExecuteTests(t, TestCases{
|
||||
{
|
||||
name: "SQL transformer rewrites up SQL when up is run on completion",
|
||||
migrations: []migrations.Migration{
|
||||
{
|
||||
Name: "01_create_table",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpRawSQL{
|
||||
Up: "CREATE TABLE apples(id int)",
|
||||
OnComplete: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
afterStart: func(t *testing.T, db *sql.DB, schema string) {
|
||||
},
|
||||
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
|
||||
},
|
||||
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
|
||||
// The transformed `up` SQL was used in place of the original SQL
|
||||
TableMustExist(t, db, schema, "table_1")
|
||||
TableMustNotExist(t, db, schema, "apples")
|
||||
},
|
||||
},
|
||||
}, roll.WithSQLTransformer(&simpleSQLTransformer{}))
|
||||
})
|
||||
}
|
||||
|
||||
type simpleSQLTransformer struct {
|
||||
counter int
|
||||
}
|
||||
|
||||
func (s *simpleSQLTransformer) TransformSQL(sql string) (string, error) {
|
||||
s.counter++
|
||||
return fmt.Sprintf("CREATE TABLE table_%d(id int)", s.counter), nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user