Add raw SQL operation (#43)

This change adds a new `sql` operation, that allows to define an `up`
SQL statement to perform a migration on the schema.

An optional `down` field can be provided, this will be used when trying
to do a rollback after (for instance, in case of migration failure).

A new trigger is installed to capture DDL events coming from direct user
manipulations (not done by pg-roll), so they are stored as a migration,
getting to know the resulting schema in all cases.
This commit is contained in:
Carlos Pérez-Aradros Herce 2023-08-30 11:50:59 +02:00 committed by GitHub
parent 2702343334
commit 16b1d75ee0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 413 additions and 43 deletions

View File

@ -6,6 +6,7 @@ import (
"strings"
"github.com/xataio/pg-roll/pkg/migrations"
"github.com/xataio/pg-roll/pkg/roll"
"github.com/spf13/cobra"
)
@ -44,7 +45,8 @@ func startCmd() *cobra.Command {
}
}
fmt.Printf("Migration successful!, new version of the schema available under postgres '%s' schema\n", version)
viewName := roll.VersionedSchemaName(Schema, version)
fmt.Printf("Migration successful! New version of the schema available under postgres '%s' schema\n", viewName)
return nil
},
}

View File

@ -42,8 +42,8 @@ var statusCmd = &cobra.Command{
},
}
func statusForSchema(ctx context.Context, state *state.State, schema string) (*statusLine, error) {
latestVersion, err := state.LatestVersion(ctx, schema)
func statusForSchema(ctx context.Context, st *state.State, schema string) (*statusLine, error) {
latestVersion, err := st.LatestVersion(ctx, schema)
if err != nil {
return nil, err
}
@ -51,7 +51,7 @@ func statusForSchema(ctx context.Context, state *state.State, schema string) (*s
latestVersion = new(string)
}
isActive, err := state.IsActiveMigrationPeriod(ctx, schema)
isActive, err := st.IsActiveMigrationPeriod(ctx, schema)
if err != nil {
return nil, err
}

11
examples/05_sql.json Normal file
View File

@ -0,0 +1,11 @@
{
"name": "05_sql",
"operations": [
{
"sql": {
"up": "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)",
"down": "DROP TABLE users"
}
}
]
}

View File

@ -0,0 +1,16 @@
{
"name": "06_add_column_to_sql_table",
"operations": [
{
"add_column": {
"table": "users",
"up": "UPPER(name)",
"column": {
"name": "description",
"type": "varchar(255)",
"nullable": true
}
}
}
]
}

View File

@ -2,6 +2,20 @@ package migrations
import "fmt"
type InvalidMigrationError struct {
Reason string
}
func (e InvalidMigrationError) Error() string {
return e.Reason
}
type EmptyMigrationError struct{}
func (e EmptyMigrationError) Error() string {
return "migration is empty"
}
type TableAlreadyExistsError struct {
Name string
}

View File

@ -3,6 +3,7 @@ package migrations
import (
"context"
"database/sql"
"fmt"
_ "github.com/lib/pq"
"github.com/xataio/pg-roll/pkg/schema"
@ -27,6 +28,17 @@ type Operation interface {
Validate(ctx context.Context, s *schema.Schema) error
}
// IsolatedOperation is an operation that cannot be executed with other operations
// in the same migration
type IsolatedOperation interface {
IsIsolated()
}
// RequiresSchemaRefreshOperation is an operation that requires the resulting schema to be refreshed
type RequiresSchemaRefreshOperation interface {
RequiresSchemaRefresh()
}
type (
Operations []Operation
Migration struct {
@ -39,6 +51,14 @@ type (
// Validate will check that the migration can be applied to the given schema
// returns a descriptive error if the migration is invalid
func (m *Migration) Validate(ctx context.Context, s *schema.Schema) error {
for _, op := range m.Operations {
if _, ok := op.(IsolatedOperation); ok {
if len(m.Operations) > 1 {
return InvalidMigrationError{Reason: fmt.Sprintf("operation %q cannot be executed with other operations", OperationName(op))}
}
}
}
for _, op := range m.Operations {
err := op.Validate(ctx, s)
if err != nil {

View File

@ -0,0 +1,38 @@
package migrations
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/xataio/pg-roll/pkg/schema"
)
func TestMigrationsIsolated(t *testing.T) {
migration := Migration{
Name: "sql",
Operations: Operations{
&OpRawSQL{
Up: `foo`,
},
&OpRenameColumn{},
},
}
err := migration.Validate(context.TODO(), schema.New())
var wantErr InvalidMigrationError
assert.ErrorAs(t, err, &wantErr)
}
func TestMigrationsIsolatedValid(t *testing.T) {
migration := Migration{
Name: "sql",
Operations: Operations{
&OpRawSQL{
Up: `foo`,
},
},
}
err := migration.Validate(context.TODO(), schema.New())
assert.NoError(t, err)
}

View File

@ -21,6 +21,7 @@ const (
OpNameRenameColumn OpName = "rename_column"
OpNameSetUnique OpName = "set_unique"
OpNameSetNotNull OpName = "set_not_null"
OpRawSQLName OpName = "sql"
)
func TemporaryName(name string) string {
@ -106,6 +107,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error {
case OpNameSetNotNull:
item = &OpSetNotNull{}
case OpRawSQLName:
item = &OpRawSQL{}
default:
return fmt.Errorf("unknown migration type: %v", opName)
}
@ -136,44 +140,8 @@ func (v Operations) MarshalJSON() ([]byte, error) {
buf.WriteByte(',')
}
var opName OpName
switch op.(type) {
case *OpCreateTable:
opName = OpNameCreateTable
case *OpRenameTable:
opName = OpNameRenameTable
case *OpDropTable:
opName = OpNameDropTable
case *OpAddColumn:
opName = OpNameAddColumn
case *OpDropColumn:
opName = OpNameDropColumn
case *OpRenameColumn:
opName = OpNameRenameColumn
case *OpCreateIndex:
opName = OpNameCreateIndex
case *OpDropIndex:
opName = OpNameDropIndex
case *OpSetUnique:
opName = OpNameSetUnique
case *OpSetNotNull:
opName = OpNameSetNotNull
default:
panic(fmt.Errorf("unknown operation for %T", op))
}
buf.WriteString(`{"`)
buf.WriteString(string(opName))
buf.WriteString(string(OperationName(op)))
buf.WriteString(`":`)
if err := enc.Encode(op); err != nil {
return nil, fmt.Errorf("unable to encode op [%v]: %w", i, err)
@ -183,3 +151,43 @@ func (v Operations) MarshalJSON() ([]byte, error) {
buf.WriteByte(']')
return buf.Bytes(), nil
}
func OperationName(op Operation) OpName {
switch op.(type) {
case *OpCreateTable:
return OpNameCreateTable
case *OpRenameTable:
return OpNameRenameTable
case *OpDropTable:
return OpNameDropTable
case *OpAddColumn:
return OpNameAddColumn
case *OpDropColumn:
return OpNameDropColumn
case *OpRenameColumn:
return OpNameRenameColumn
case *OpCreateIndex:
return OpNameCreateIndex
case *OpDropIndex:
return OpNameDropIndex
case *OpSetUnique:
return OpNameSetUnique
case *OpSetNotNull:
return OpNameSetNotNull
case *OpRawSQL:
return OpRawSQLName
}
panic(fmt.Errorf("unknown operation for %T", op))
}

View File

@ -0,0 +1,49 @@
package migrations
import (
"context"
"database/sql"
"github.com/xataio/pg-roll/pkg/schema"
)
var _ Operation = (*OpRawSQL)(nil)
type OpRawSQL struct {
Up string `json:"up"`
Down string `json:"down,omitempty"`
}
func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error {
_, err := conn.ExecContext(ctx, o.Up)
if err != nil {
return err
}
return nil
}
func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB) error {
return nil
}
func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB) error {
if o.Down != "" {
_, err := conn.ExecContext(ctx, o.Down)
return err
}
return nil
}
func (o *OpRawSQL) Validate(ctx context.Context, s *schema.Schema) error {
if o.Up == "" {
return EmptyMigrationError{}
}
return nil
}
// this operation is isolated, cannot be executed with other operations
func (o *OpRawSQL) IsIsolated() {}
// this operation requires the resulting schema to be refreshed
func (o *OpRawSQL) RequiresSchemaRefresh() {}

View File

@ -0,0 +1,109 @@
package migrations_test
import (
"database/sql"
"testing"
"github.com/xataio/pg-roll/pkg/migrations"
)
func TestRawSQL(t *testing.T) {
t.Parallel()
ExecuteTests(t, TestCases{
{
name: "raw SQL",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpRawSQL{
Up: `
CREATE TABLE test_table (
id serial,
name text
)
`,
Down: `
DROP TABLE test_table
`,
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// table can be accessed after start
ViewMustExist(t, db, "public", "01_create_table", "test_table")
// inserts work
MustInsert(t, db, "public", "01_create_table", "test_table", map[string]string{
"name": "foo",
})
},
afterRollback: func(t *testing.T, db *sql.DB) {
// table is dropped after rollback
TableMustNotExist(t, db, "public", "test_table")
},
afterComplete: func(t *testing.T, db *sql.DB) {
// inserts still work after complete
MustInsert(t, db, "public", "01_create_table", "test_table", map[string]string{
"name": "foo",
})
},
},
{
name: "migration on top of raw SQL",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpRawSQL{
Up: `
CREATE TABLE test_table (
id serial,
name text
)
`,
Down: `
DROP TABLE test_table
`,
},
},
},
{
Name: "02_rename_table",
Operations: migrations.Operations{
&migrations.OpRenameTable{
From: "test_table",
To: "test_table_renamed",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// table can be accessed after start
ViewMustExist(t, db, "public", "01_create_table", "test_table")
// table is renamed in new version
ViewMustExist(t, db, "public", "02_rename_table", "test_table_renamed")
// inserts work
MustInsert(t, db, "public", "01_create_table", "test_table", map[string]string{
"name": "foo",
})
MustInsert(t, db, "public", "02_rename_table", "test_table_renamed", map[string]string{
"name": "foo",
})
},
afterComplete: func(t *testing.T, db *sql.DB) {
// table can still be accessed after complete
ViewMustExist(t, db, "public", "02_rename_table", "test_table_renamed")
// inserts work
MustInsert(t, db, "public", "02_rename_table", "test_table_renamed", map[string]string{
"name": "foo",
})
},
},
})
}

View File

@ -42,6 +42,14 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration) error
if err != nil {
return fmt.Errorf("unable to execute start operation: %w", err)
}
if _, ok := op.(migrations.RequiresSchemaRefreshOperation); ok {
// refresh schema
newSchema, err = m.state.ReadSchema(ctx, m.schema)
if err != nil {
return fmt.Errorf("unable to refresh schema: %w", err)
}
}
}
// create schema for the new version

View File

@ -3,6 +3,7 @@ package roll
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
@ -31,6 +32,11 @@ func New(ctx context.Context, pgURL, schema string, state *state.State) (*Roll,
return nil, err
}
_, err = conn.ExecContext(ctx, "SET LOCAL pgroll.internal to 'TRUE'")
if err != nil {
return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err)
}
return &Roll{
pgConn: conn,
schema: schema,

View File

@ -1,6 +1,11 @@
package schema
import "fmt"
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
)
// XXX we create a view of the schema with the minimum required for us to
// know how to execute migrations and build views for the new schema version.
@ -57,6 +62,11 @@ type Index struct {
Name string `json:"name"`
}
// Replace replaces the contents of the schema with the contents of the given one
func (s *Schema) Replace(other *Schema) {
s.Tables = other.Tables
}
func (s *Schema) GetTable(name string) *Table {
if s.Tables == nil {
return nil
@ -121,3 +131,20 @@ func (t *Table) RenameColumn(from, to string) {
t.Columns[to] = t.Columns[from]
delete(t.Columns, from)
}
// Make the Schema struct implement the driver.Valuer interface. This method
// simply returns the JSON-encoded representation of the struct.
func (s Schema) Value() (driver.Value, error) {
return json.Marshal(s)
}
// Make the Schema struct implement the sql.Scanner interface. This method
// simply decodes a JSON-encoded value into the struct fields.
func (s *Schema) Scan(value interface{}) error {
b, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(b, &s)
}

View File

@ -140,6 +140,52 @@ BEGIN
RETURN tables;
END;
$$;
CREATE OR REPLACE FUNCTION %[1]s.raw_migration() RETURNS event_trigger
LANGUAGE plpgsql AS $$
DECLARE
schemaname TEXT;
BEGIN
-- Ignore migrations done by pg-roll
IF (current_setting('pgroll.internal', 'TRUE') <> 'TRUE') THEN
RETURN;
END IF;
-- Guess the schema from ddl commands, ignore migrations that touch several schemas
IF (SELECT COUNT(DISTINCT schema_name) FROM pg_event_trigger_ddl_commands() WHERE schema_name IS NOT NULL) > 1 THEN
RAISE NOTICE 'pg-roll: ignoring migration that touches several schemas';
RETURN;
END IF;
SELECT schema_name INTO schemaname FROM pg_event_trigger_ddl_commands() WHERE schema_name IS NOT NULL;
IF schemaname IS NULL THEN
RAISE NOTICE 'pg-roll: ignoring migration with null schema';
RETURN;
END IF;
-- Ignore migrations done during a migration period
IF %[1]s.is_active_migration_period(schemaname) THEN
RAISE NOTICE 'pg-roll: ignoring migration during active migration period';
RETURN;
END IF;
-- Someone did a schema change without pg-roll, include it in the history
INSERT INTO %[1]s.migrations (schema, name, migration, resulting_schema, done, parent)
VALUES (
schemaname,
format('sql_%%s', substr(md5(random()::text), 0, 15)),
json_build_object('sql', json_build_object('up', current_query())),
%[1]s.read_schema(schemaname),
true,
%[1]s.latest_version(schemaname)
);
END;
$$;
DROP EVENT TRIGGER IF EXISTS pg_roll_handle_ddl;
CREATE EVENT TRIGGER pg_roll_handle_ddl ON ddl_command_end
EXECUTE FUNCTION %[1]s.raw_migration() ;
`
type State struct {
@ -153,6 +199,11 @@ func New(ctx context.Context, pgURL, stateSchema string) (*State, error) {
return nil, err
}
_, err = conn.ExecContext(ctx, "SET LOCAL pgroll.internal to 'TRUE'")
if err != nil {
return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err)
}
return &State{
pgConn: conn,
schema: stateSchema,
@ -183,7 +234,7 @@ func (s *State) IsActiveMigrationPeriod(ctx context.Context, schema string) (boo
return false, err
}
return isActive, err
return isActive, nil
}
// GetActiveMigration returns the name & raw content of the active migration (if any), errors out otherwise
@ -232,6 +283,17 @@ func (s *State) PreviousVersion(ctx context.Context, schema string) (*string, er
return parent, nil
}
// ReadSchema reads & returns the current schema from postgres
func ReadSchema(ctx context.Context, conn *sql.DB, stateSchema, schemaname string) (*schema.Schema, error) {
var res schema.Schema
err := conn.QueryRowContext(ctx, fmt.Sprintf("SELECT %[1]s.read_schema($1)", pq.QuoteIdentifier(stateSchema)), schemaname).Scan(&res)
if err != nil {
return nil, err
}
return &res, nil
}
// Start creates a new migration, storing its name and raw content
// this will effectively activate a new migration period, so `IsActiveMigrationPeriod` will return true
// until the migration is completed