mirror of
https://github.com/xataio/pgroll.git
synced 2024-07-14 17:10:33 +03:00
Add raw SQL operation (#43)
This change adds a new `sql` operation, that allows to define an `up` SQL statement to perform a migration on the schema. An optional `down` field can be provided, this will be used when trying to do a rollback after (for instance, in case of migration failure). A new trigger is installed to capture DDL events coming from direct user manipulations (not done by pg-roll), so they are stored as a migration, getting to know the resulting schema in all cases.
This commit is contained in:
parent
2702343334
commit
16b1d75ee0
@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/xataio/pg-roll/pkg/migrations"
|
||||
"github.com/xataio/pg-roll/pkg/roll"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@ -44,7 +45,8 @@ func startCmd() *cobra.Command {
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Migration successful!, new version of the schema available under postgres '%s' schema\n", version)
|
||||
viewName := roll.VersionedSchemaName(Schema, version)
|
||||
fmt.Printf("Migration successful! New version of the schema available under postgres '%s' schema\n", viewName)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
@ -42,8 +42,8 @@ var statusCmd = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
func statusForSchema(ctx context.Context, state *state.State, schema string) (*statusLine, error) {
|
||||
latestVersion, err := state.LatestVersion(ctx, schema)
|
||||
func statusForSchema(ctx context.Context, st *state.State, schema string) (*statusLine, error) {
|
||||
latestVersion, err := st.LatestVersion(ctx, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -51,7 +51,7 @@ func statusForSchema(ctx context.Context, state *state.State, schema string) (*s
|
||||
latestVersion = new(string)
|
||||
}
|
||||
|
||||
isActive, err := state.IsActiveMigrationPeriod(ctx, schema)
|
||||
isActive, err := st.IsActiveMigrationPeriod(ctx, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
11
examples/05_sql.json
Normal file
11
examples/05_sql.json
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"name": "05_sql",
|
||||
"operations": [
|
||||
{
|
||||
"sql": {
|
||||
"up": "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)",
|
||||
"down": "DROP TABLE users"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
16
examples/06_add_column_to_sql_table.json
Normal file
16
examples/06_add_column_to_sql_table.json
Normal file
@ -0,0 +1,16 @@
|
||||
{
|
||||
"name": "06_add_column_to_sql_table",
|
||||
"operations": [
|
||||
{
|
||||
"add_column": {
|
||||
"table": "users",
|
||||
"up": "UPPER(name)",
|
||||
"column": {
|
||||
"name": "description",
|
||||
"type": "varchar(255)",
|
||||
"nullable": true
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
@ -2,6 +2,20 @@ package migrations
|
||||
|
||||
import "fmt"
|
||||
|
||||
type InvalidMigrationError struct {
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e InvalidMigrationError) Error() string {
|
||||
return e.Reason
|
||||
}
|
||||
|
||||
type EmptyMigrationError struct{}
|
||||
|
||||
func (e EmptyMigrationError) Error() string {
|
||||
return "migration is empty"
|
||||
}
|
||||
|
||||
type TableAlreadyExistsError struct {
|
||||
Name string
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package migrations
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/xataio/pg-roll/pkg/schema"
|
||||
@ -27,6 +28,17 @@ type Operation interface {
|
||||
Validate(ctx context.Context, s *schema.Schema) error
|
||||
}
|
||||
|
||||
// IsolatedOperation is an operation that cannot be executed with other operations
|
||||
// in the same migration
|
||||
type IsolatedOperation interface {
|
||||
IsIsolated()
|
||||
}
|
||||
|
||||
// RequiresSchemaRefreshOperation is an operation that requires the resulting schema to be refreshed
|
||||
type RequiresSchemaRefreshOperation interface {
|
||||
RequiresSchemaRefresh()
|
||||
}
|
||||
|
||||
type (
|
||||
Operations []Operation
|
||||
Migration struct {
|
||||
@ -39,6 +51,14 @@ type (
|
||||
// Validate will check that the migration can be applied to the given schema
|
||||
// returns a descriptive error if the migration is invalid
|
||||
func (m *Migration) Validate(ctx context.Context, s *schema.Schema) error {
|
||||
for _, op := range m.Operations {
|
||||
if _, ok := op.(IsolatedOperation); ok {
|
||||
if len(m.Operations) > 1 {
|
||||
return InvalidMigrationError{Reason: fmt.Sprintf("operation %q cannot be executed with other operations", OperationName(op))}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, op := range m.Operations {
|
||||
err := op.Validate(ctx, s)
|
||||
if err != nil {
|
||||
|
38
pkg/migrations/migrations_test.go
Normal file
38
pkg/migrations/migrations_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/xataio/pg-roll/pkg/schema"
|
||||
)
|
||||
|
||||
func TestMigrationsIsolated(t *testing.T) {
|
||||
migration := Migration{
|
||||
Name: "sql",
|
||||
Operations: Operations{
|
||||
&OpRawSQL{
|
||||
Up: `foo`,
|
||||
},
|
||||
&OpRenameColumn{},
|
||||
},
|
||||
}
|
||||
|
||||
err := migration.Validate(context.TODO(), schema.New())
|
||||
var wantErr InvalidMigrationError
|
||||
assert.ErrorAs(t, err, &wantErr)
|
||||
}
|
||||
|
||||
func TestMigrationsIsolatedValid(t *testing.T) {
|
||||
migration := Migration{
|
||||
Name: "sql",
|
||||
Operations: Operations{
|
||||
&OpRawSQL{
|
||||
Up: `foo`,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := migration.Validate(context.TODO(), schema.New())
|
||||
assert.NoError(t, err)
|
||||
}
|
@ -21,6 +21,7 @@ const (
|
||||
OpNameRenameColumn OpName = "rename_column"
|
||||
OpNameSetUnique OpName = "set_unique"
|
||||
OpNameSetNotNull OpName = "set_not_null"
|
||||
OpRawSQLName OpName = "sql"
|
||||
)
|
||||
|
||||
func TemporaryName(name string) string {
|
||||
@ -106,6 +107,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error {
|
||||
case OpNameSetNotNull:
|
||||
item = &OpSetNotNull{}
|
||||
|
||||
case OpRawSQLName:
|
||||
item = &OpRawSQL{}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown migration type: %v", opName)
|
||||
}
|
||||
@ -136,44 +140,8 @@ func (v Operations) MarshalJSON() ([]byte, error) {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
|
||||
var opName OpName
|
||||
switch op.(type) {
|
||||
case *OpCreateTable:
|
||||
opName = OpNameCreateTable
|
||||
|
||||
case *OpRenameTable:
|
||||
opName = OpNameRenameTable
|
||||
|
||||
case *OpDropTable:
|
||||
opName = OpNameDropTable
|
||||
|
||||
case *OpAddColumn:
|
||||
opName = OpNameAddColumn
|
||||
|
||||
case *OpDropColumn:
|
||||
opName = OpNameDropColumn
|
||||
|
||||
case *OpRenameColumn:
|
||||
opName = OpNameRenameColumn
|
||||
|
||||
case *OpCreateIndex:
|
||||
opName = OpNameCreateIndex
|
||||
|
||||
case *OpDropIndex:
|
||||
opName = OpNameDropIndex
|
||||
|
||||
case *OpSetUnique:
|
||||
opName = OpNameSetUnique
|
||||
|
||||
case *OpSetNotNull:
|
||||
opName = OpNameSetNotNull
|
||||
|
||||
default:
|
||||
panic(fmt.Errorf("unknown operation for %T", op))
|
||||
}
|
||||
|
||||
buf.WriteString(`{"`)
|
||||
buf.WriteString(string(opName))
|
||||
buf.WriteString(string(OperationName(op)))
|
||||
buf.WriteString(`":`)
|
||||
if err := enc.Encode(op); err != nil {
|
||||
return nil, fmt.Errorf("unable to encode op [%v]: %w", i, err)
|
||||
@ -183,3 +151,43 @@ func (v Operations) MarshalJSON() ([]byte, error) {
|
||||
buf.WriteByte(']')
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func OperationName(op Operation) OpName {
|
||||
switch op.(type) {
|
||||
case *OpCreateTable:
|
||||
return OpNameCreateTable
|
||||
|
||||
case *OpRenameTable:
|
||||
return OpNameRenameTable
|
||||
|
||||
case *OpDropTable:
|
||||
return OpNameDropTable
|
||||
|
||||
case *OpAddColumn:
|
||||
return OpNameAddColumn
|
||||
|
||||
case *OpDropColumn:
|
||||
return OpNameDropColumn
|
||||
|
||||
case *OpRenameColumn:
|
||||
return OpNameRenameColumn
|
||||
|
||||
case *OpCreateIndex:
|
||||
return OpNameCreateIndex
|
||||
|
||||
case *OpDropIndex:
|
||||
return OpNameDropIndex
|
||||
|
||||
case *OpSetUnique:
|
||||
return OpNameSetUnique
|
||||
|
||||
case *OpSetNotNull:
|
||||
return OpNameSetNotNull
|
||||
|
||||
case *OpRawSQL:
|
||||
return OpRawSQLName
|
||||
|
||||
}
|
||||
|
||||
panic(fmt.Errorf("unknown operation for %T", op))
|
||||
}
|
||||
|
49
pkg/migrations/op_raw_sql.go
Normal file
49
pkg/migrations/op_raw_sql.go
Normal file
@ -0,0 +1,49 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/xataio/pg-roll/pkg/schema"
|
||||
)
|
||||
|
||||
var _ Operation = (*OpRawSQL)(nil)
|
||||
|
||||
type OpRawSQL struct {
|
||||
Up string `json:"up"`
|
||||
Down string `json:"down,omitempty"`
|
||||
}
|
||||
|
||||
func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error {
|
||||
_, err := conn.ExecContext(ctx, o.Up)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB) error {
|
||||
if o.Down != "" {
|
||||
_, err := conn.ExecContext(ctx, o.Down)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OpRawSQL) Validate(ctx context.Context, s *schema.Schema) error {
|
||||
if o.Up == "" {
|
||||
return EmptyMigrationError{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// this operation is isolated, cannot be executed with other operations
|
||||
func (o *OpRawSQL) IsIsolated() {}
|
||||
|
||||
// this operation requires the resulting schema to be refreshed
|
||||
func (o *OpRawSQL) RequiresSchemaRefresh() {}
|
109
pkg/migrations/op_raw_sql_test.go
Normal file
109
pkg/migrations/op_raw_sql_test.go
Normal file
@ -0,0 +1,109 @@
|
||||
package migrations_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/xataio/pg-roll/pkg/migrations"
|
||||
)
|
||||
|
||||
func TestRawSQL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ExecuteTests(t, TestCases{
|
||||
{
|
||||
name: "raw SQL",
|
||||
migrations: []migrations.Migration{
|
||||
{
|
||||
Name: "01_create_table",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpRawSQL{
|
||||
Up: `
|
||||
CREATE TABLE test_table (
|
||||
id serial,
|
||||
name text
|
||||
)
|
||||
`,
|
||||
Down: `
|
||||
DROP TABLE test_table
|
||||
`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
afterStart: func(t *testing.T, db *sql.DB) {
|
||||
// table can be accessed after start
|
||||
ViewMustExist(t, db, "public", "01_create_table", "test_table")
|
||||
|
||||
// inserts work
|
||||
MustInsert(t, db, "public", "01_create_table", "test_table", map[string]string{
|
||||
"name": "foo",
|
||||
})
|
||||
},
|
||||
afterRollback: func(t *testing.T, db *sql.DB) {
|
||||
// table is dropped after rollback
|
||||
TableMustNotExist(t, db, "public", "test_table")
|
||||
},
|
||||
afterComplete: func(t *testing.T, db *sql.DB) {
|
||||
// inserts still work after complete
|
||||
MustInsert(t, db, "public", "01_create_table", "test_table", map[string]string{
|
||||
"name": "foo",
|
||||
})
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "migration on top of raw SQL",
|
||||
migrations: []migrations.Migration{
|
||||
{
|
||||
Name: "01_create_table",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpRawSQL{
|
||||
Up: `
|
||||
CREATE TABLE test_table (
|
||||
id serial,
|
||||
name text
|
||||
)
|
||||
`,
|
||||
Down: `
|
||||
DROP TABLE test_table
|
||||
`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "02_rename_table",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpRenameTable{
|
||||
From: "test_table",
|
||||
To: "test_table_renamed",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
afterStart: func(t *testing.T, db *sql.DB) {
|
||||
// table can be accessed after start
|
||||
ViewMustExist(t, db, "public", "01_create_table", "test_table")
|
||||
|
||||
// table is renamed in new version
|
||||
ViewMustExist(t, db, "public", "02_rename_table", "test_table_renamed")
|
||||
|
||||
// inserts work
|
||||
MustInsert(t, db, "public", "01_create_table", "test_table", map[string]string{
|
||||
"name": "foo",
|
||||
})
|
||||
MustInsert(t, db, "public", "02_rename_table", "test_table_renamed", map[string]string{
|
||||
"name": "foo",
|
||||
})
|
||||
},
|
||||
afterComplete: func(t *testing.T, db *sql.DB) {
|
||||
// table can still be accessed after complete
|
||||
ViewMustExist(t, db, "public", "02_rename_table", "test_table_renamed")
|
||||
|
||||
// inserts work
|
||||
MustInsert(t, db, "public", "02_rename_table", "test_table_renamed", map[string]string{
|
||||
"name": "foo",
|
||||
})
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
@ -42,6 +42,14 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration) error
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to execute start operation: %w", err)
|
||||
}
|
||||
|
||||
if _, ok := op.(migrations.RequiresSchemaRefreshOperation); ok {
|
||||
// refresh schema
|
||||
newSchema, err = m.state.ReadSchema(ctx, m.schema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to refresh schema: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create schema for the new version
|
||||
|
@ -3,6 +3,7 @@ package roll
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
@ -31,6 +32,11 @@ func New(ctx context.Context, pgURL, schema string, state *state.State) (*Roll,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = conn.ExecContext(ctx, "SET LOCAL pgroll.internal to 'TRUE'")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err)
|
||||
}
|
||||
|
||||
return &Roll{
|
||||
pgConn: conn,
|
||||
schema: schema,
|
||||
|
@ -1,6 +1,11 @@
|
||||
package schema
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// XXX we create a view of the schema with the minimum required for us to
|
||||
// know how to execute migrations and build views for the new schema version.
|
||||
@ -57,6 +62,11 @@ type Index struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Replace replaces the contents of the schema with the contents of the given one
|
||||
func (s *Schema) Replace(other *Schema) {
|
||||
s.Tables = other.Tables
|
||||
}
|
||||
|
||||
func (s *Schema) GetTable(name string) *Table {
|
||||
if s.Tables == nil {
|
||||
return nil
|
||||
@ -121,3 +131,20 @@ func (t *Table) RenameColumn(from, to string) {
|
||||
t.Columns[to] = t.Columns[from]
|
||||
delete(t.Columns, from)
|
||||
}
|
||||
|
||||
// Make the Schema struct implement the driver.Valuer interface. This method
|
||||
// simply returns the JSON-encoded representation of the struct.
|
||||
func (s Schema) Value() (driver.Value, error) {
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
// Make the Schema struct implement the sql.Scanner interface. This method
|
||||
// simply decodes a JSON-encoded value into the struct fields.
|
||||
func (s *Schema) Scan(value interface{}) error {
|
||||
b, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("type assertion to []byte failed")
|
||||
}
|
||||
|
||||
return json.Unmarshal(b, &s)
|
||||
}
|
||||
|
@ -140,6 +140,52 @@ BEGIN
|
||||
RETURN tables;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION %[1]s.raw_migration() RETURNS event_trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
schemaname TEXT;
|
||||
BEGIN
|
||||
-- Ignore migrations done by pg-roll
|
||||
IF (current_setting('pgroll.internal', 'TRUE') <> 'TRUE') THEN
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
-- Guess the schema from ddl commands, ignore migrations that touch several schemas
|
||||
IF (SELECT COUNT(DISTINCT schema_name) FROM pg_event_trigger_ddl_commands() WHERE schema_name IS NOT NULL) > 1 THEN
|
||||
RAISE NOTICE 'pg-roll: ignoring migration that touches several schemas';
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
SELECT schema_name INTO schemaname FROM pg_event_trigger_ddl_commands() WHERE schema_name IS NOT NULL;
|
||||
|
||||
IF schemaname IS NULL THEN
|
||||
RAISE NOTICE 'pg-roll: ignoring migration with null schema';
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
-- Ignore migrations done during a migration period
|
||||
IF %[1]s.is_active_migration_period(schemaname) THEN
|
||||
RAISE NOTICE 'pg-roll: ignoring migration during active migration period';
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
-- Someone did a schema change without pg-roll, include it in the history
|
||||
INSERT INTO %[1]s.migrations (schema, name, migration, resulting_schema, done, parent)
|
||||
VALUES (
|
||||
schemaname,
|
||||
format('sql_%%s', substr(md5(random()::text), 0, 15)),
|
||||
json_build_object('sql', json_build_object('up', current_query())),
|
||||
%[1]s.read_schema(schemaname),
|
||||
true,
|
||||
%[1]s.latest_version(schemaname)
|
||||
);
|
||||
END;
|
||||
$$;
|
||||
|
||||
DROP EVENT TRIGGER IF EXISTS pg_roll_handle_ddl;
|
||||
CREATE EVENT TRIGGER pg_roll_handle_ddl ON ddl_command_end
|
||||
EXECUTE FUNCTION %[1]s.raw_migration() ;
|
||||
`
|
||||
|
||||
type State struct {
|
||||
@ -153,6 +199,11 @@ func New(ctx context.Context, pgURL, stateSchema string) (*State, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = conn.ExecContext(ctx, "SET LOCAL pgroll.internal to 'TRUE'")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err)
|
||||
}
|
||||
|
||||
return &State{
|
||||
pgConn: conn,
|
||||
schema: stateSchema,
|
||||
@ -183,7 +234,7 @@ func (s *State) IsActiveMigrationPeriod(ctx context.Context, schema string) (boo
|
||||
return false, err
|
||||
}
|
||||
|
||||
return isActive, err
|
||||
return isActive, nil
|
||||
}
|
||||
|
||||
// GetActiveMigration returns the name & raw content of the active migration (if any), errors out otherwise
|
||||
@ -232,6 +283,17 @@ func (s *State) PreviousVersion(ctx context.Context, schema string) (*string, er
|
||||
return parent, nil
|
||||
}
|
||||
|
||||
// ReadSchema reads & returns the current schema from postgres
|
||||
func ReadSchema(ctx context.Context, conn *sql.DB, stateSchema, schemaname string) (*schema.Schema, error) {
|
||||
var res schema.Schema
|
||||
err := conn.QueryRowContext(ctx, fmt.Sprintf("SELECT %[1]s.read_schema($1)", pq.QuoteIdentifier(stateSchema)), schemaname).Scan(&res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
// Start creates a new migration, storing its name and raw content
|
||||
// this will effectively activate a new migration period, so `IsActiveMigrationPeriod` will return true
|
||||
// until the migration is completed
|
||||
|
Loading…
Reference in New Issue
Block a user