diff --git a/cmd/start.go b/cmd/start.go index 24d72de..e14051d 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/xataio/pg-roll/pkg/migrations" + "github.com/xataio/pg-roll/pkg/roll" "github.com/spf13/cobra" ) @@ -44,7 +45,8 @@ func startCmd() *cobra.Command { } } - fmt.Printf("Migration successful!, new version of the schema available under postgres '%s' schema\n", version) + viewName := roll.VersionedSchemaName(Schema, version) + fmt.Printf("Migration successful! New version of the schema available under postgres '%s' schema\n", viewName) return nil }, } diff --git a/cmd/status.go b/cmd/status.go index 54c0233..7f37e33 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -42,8 +42,8 @@ var statusCmd = &cobra.Command{ }, } -func statusForSchema(ctx context.Context, state *state.State, schema string) (*statusLine, error) { - latestVersion, err := state.LatestVersion(ctx, schema) +func statusForSchema(ctx context.Context, st *state.State, schema string) (*statusLine, error) { + latestVersion, err := st.LatestVersion(ctx, schema) if err != nil { return nil, err } @@ -51,7 +51,7 @@ func statusForSchema(ctx context.Context, state *state.State, schema string) (*s latestVersion = new(string) } - isActive, err := state.IsActiveMigrationPeriod(ctx, schema) + isActive, err := st.IsActiveMigrationPeriod(ctx, schema) if err != nil { return nil, err } diff --git a/examples/05_sql.json b/examples/05_sql.json new file mode 100644 index 0000000..7eb801f --- /dev/null +++ b/examples/05_sql.json @@ -0,0 +1,11 @@ +{ + "name": "05_sql", + "operations": [ + { + "sql": { + "up": "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)", + "down": "DROP TABLE users" + } + } + ] +} diff --git a/examples/06_add_column_to_sql_table.json b/examples/06_add_column_to_sql_table.json new file mode 100644 index 0000000..5473bef --- /dev/null +++ b/examples/06_add_column_to_sql_table.json @@ -0,0 +1,16 @@ +{ + "name": "06_add_column_to_sql_table", + "operations": [ + { + "add_column": { + "table": "users", + "up": "UPPER(name)", + "column": { + "name": "description", + "type": "varchar(255)", + "nullable": true + } + } + } + ] +} diff --git a/pkg/migrations/errors.go b/pkg/migrations/errors.go index 2945bfa..2009ce0 100644 --- a/pkg/migrations/errors.go +++ b/pkg/migrations/errors.go @@ -2,6 +2,20 @@ package migrations import "fmt" +type InvalidMigrationError struct { + Reason string +} + +func (e InvalidMigrationError) Error() string { + return e.Reason +} + +type EmptyMigrationError struct{} + +func (e EmptyMigrationError) Error() string { + return "migration is empty" +} + type TableAlreadyExistsError struct { Name string } diff --git a/pkg/migrations/migrations.go b/pkg/migrations/migrations.go index 8ec9c15..2128806 100644 --- a/pkg/migrations/migrations.go +++ b/pkg/migrations/migrations.go @@ -3,6 +3,7 @@ package migrations import ( "context" "database/sql" + "fmt" _ "github.com/lib/pq" "github.com/xataio/pg-roll/pkg/schema" @@ -27,6 +28,17 @@ type Operation interface { Validate(ctx context.Context, s *schema.Schema) error } +// IsolatedOperation is an operation that cannot be executed with other operations +// in the same migration +type IsolatedOperation interface { + IsIsolated() +} + +// RequiresSchemaRefreshOperation is an operation that requires the resulting schema to be refreshed +type RequiresSchemaRefreshOperation interface { + RequiresSchemaRefresh() +} + type ( Operations []Operation Migration struct { @@ -39,6 +51,14 @@ type ( // Validate will check that the migration can be applied to the given schema // returns a descriptive error if the migration is invalid func (m *Migration) Validate(ctx context.Context, s *schema.Schema) error { + for _, op := range m.Operations { + if _, ok := op.(IsolatedOperation); ok { + if len(m.Operations) > 1 { + return InvalidMigrationError{Reason: fmt.Sprintf("operation %q cannot be executed with other operations", OperationName(op))} + } + } + } + for _, op := range m.Operations { err := op.Validate(ctx, s) if err != nil { diff --git a/pkg/migrations/migrations_test.go b/pkg/migrations/migrations_test.go new file mode 100644 index 0000000..bbb6478 --- /dev/null +++ b/pkg/migrations/migrations_test.go @@ -0,0 +1,38 @@ +package migrations + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xataio/pg-roll/pkg/schema" +) + +func TestMigrationsIsolated(t *testing.T) { + migration := Migration{ + Name: "sql", + Operations: Operations{ + &OpRawSQL{ + Up: `foo`, + }, + &OpRenameColumn{}, + }, + } + + err := migration.Validate(context.TODO(), schema.New()) + var wantErr InvalidMigrationError + assert.ErrorAs(t, err, &wantErr) +} + +func TestMigrationsIsolatedValid(t *testing.T) { + migration := Migration{ + Name: "sql", + Operations: Operations{ + &OpRawSQL{ + Up: `foo`, + }, + }, + } + err := migration.Validate(context.TODO(), schema.New()) + assert.NoError(t, err) +} diff --git a/pkg/migrations/op_common.go b/pkg/migrations/op_common.go index bc3a9e1..dc2c830 100644 --- a/pkg/migrations/op_common.go +++ b/pkg/migrations/op_common.go @@ -21,6 +21,7 @@ const ( OpNameRenameColumn OpName = "rename_column" OpNameSetUnique OpName = "set_unique" OpNameSetNotNull OpName = "set_not_null" + OpRawSQLName OpName = "sql" ) func TemporaryName(name string) string { @@ -106,6 +107,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error { case OpNameSetNotNull: item = &OpSetNotNull{} + case OpRawSQLName: + item = &OpRawSQL{} + default: return fmt.Errorf("unknown migration type: %v", opName) } @@ -136,44 +140,8 @@ func (v Operations) MarshalJSON() ([]byte, error) { buf.WriteByte(',') } - var opName OpName - switch op.(type) { - case *OpCreateTable: - opName = OpNameCreateTable - - case *OpRenameTable: - opName = OpNameRenameTable - - case *OpDropTable: - opName = OpNameDropTable - - case *OpAddColumn: - opName = OpNameAddColumn - - case *OpDropColumn: - opName = OpNameDropColumn - - case *OpRenameColumn: - opName = OpNameRenameColumn - - case *OpCreateIndex: - opName = OpNameCreateIndex - - case *OpDropIndex: - opName = OpNameDropIndex - - case *OpSetUnique: - opName = OpNameSetUnique - - case *OpSetNotNull: - opName = OpNameSetNotNull - - default: - panic(fmt.Errorf("unknown operation for %T", op)) - } - buf.WriteString(`{"`) - buf.WriteString(string(opName)) + buf.WriteString(string(OperationName(op))) buf.WriteString(`":`) if err := enc.Encode(op); err != nil { return nil, fmt.Errorf("unable to encode op [%v]: %w", i, err) @@ -183,3 +151,43 @@ func (v Operations) MarshalJSON() ([]byte, error) { buf.WriteByte(']') return buf.Bytes(), nil } + +func OperationName(op Operation) OpName { + switch op.(type) { + case *OpCreateTable: + return OpNameCreateTable + + case *OpRenameTable: + return OpNameRenameTable + + case *OpDropTable: + return OpNameDropTable + + case *OpAddColumn: + return OpNameAddColumn + + case *OpDropColumn: + return OpNameDropColumn + + case *OpRenameColumn: + return OpNameRenameColumn + + case *OpCreateIndex: + return OpNameCreateIndex + + case *OpDropIndex: + return OpNameDropIndex + + case *OpSetUnique: + return OpNameSetUnique + + case *OpSetNotNull: + return OpNameSetNotNull + + case *OpRawSQL: + return OpRawSQLName + + } + + panic(fmt.Errorf("unknown operation for %T", op)) +} diff --git a/pkg/migrations/op_raw_sql.go b/pkg/migrations/op_raw_sql.go new file mode 100644 index 0000000..abcd684 --- /dev/null +++ b/pkg/migrations/op_raw_sql.go @@ -0,0 +1,49 @@ +package migrations + +import ( + "context" + "database/sql" + + "github.com/xataio/pg-roll/pkg/schema" +) + +var _ Operation = (*OpRawSQL)(nil) + +type OpRawSQL struct { + Up string `json:"up"` + Down string `json:"down,omitempty"` +} + +func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, schemaName, stateSchema string, s *schema.Schema) error { + _, err := conn.ExecContext(ctx, o.Up) + if err != nil { + return err + } + return nil +} + +func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB) error { + return nil +} + +func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB) error { + if o.Down != "" { + _, err := conn.ExecContext(ctx, o.Down) + return err + } + return nil +} + +func (o *OpRawSQL) Validate(ctx context.Context, s *schema.Schema) error { + if o.Up == "" { + return EmptyMigrationError{} + } + + return nil +} + +// this operation is isolated, cannot be executed with other operations +func (o *OpRawSQL) IsIsolated() {} + +// this operation requires the resulting schema to be refreshed +func (o *OpRawSQL) RequiresSchemaRefresh() {} diff --git a/pkg/migrations/op_raw_sql_test.go b/pkg/migrations/op_raw_sql_test.go new file mode 100644 index 0000000..0985543 --- /dev/null +++ b/pkg/migrations/op_raw_sql_test.go @@ -0,0 +1,109 @@ +package migrations_test + +import ( + "database/sql" + "testing" + + "github.com/xataio/pg-roll/pkg/migrations" +) + +func TestRawSQL(t *testing.T) { + t.Parallel() + + ExecuteTests(t, TestCases{ + { + name: "raw SQL", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpRawSQL{ + Up: ` + CREATE TABLE test_table ( + id serial, + name text + ) + `, + Down: ` + DROP TABLE test_table + `, + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // table can be accessed after start + ViewMustExist(t, db, "public", "01_create_table", "test_table") + + // inserts work + MustInsert(t, db, "public", "01_create_table", "test_table", map[string]string{ + "name": "foo", + }) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + // table is dropped after rollback + TableMustNotExist(t, db, "public", "test_table") + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // inserts still work after complete + MustInsert(t, db, "public", "01_create_table", "test_table", map[string]string{ + "name": "foo", + }) + }, + }, + { + name: "migration on top of raw SQL", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpRawSQL{ + Up: ` + CREATE TABLE test_table ( + id serial, + name text + ) + `, + Down: ` + DROP TABLE test_table + `, + }, + }, + }, + { + Name: "02_rename_table", + Operations: migrations.Operations{ + &migrations.OpRenameTable{ + From: "test_table", + To: "test_table_renamed", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // table can be accessed after start + ViewMustExist(t, db, "public", "01_create_table", "test_table") + + // table is renamed in new version + ViewMustExist(t, db, "public", "02_rename_table", "test_table_renamed") + + // inserts work + MustInsert(t, db, "public", "01_create_table", "test_table", map[string]string{ + "name": "foo", + }) + MustInsert(t, db, "public", "02_rename_table", "test_table_renamed", map[string]string{ + "name": "foo", + }) + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // table can still be accessed after complete + ViewMustExist(t, db, "public", "02_rename_table", "test_table_renamed") + + // inserts work + MustInsert(t, db, "public", "02_rename_table", "test_table_renamed", map[string]string{ + "name": "foo", + }) + }, + }, + }) +} diff --git a/pkg/roll/execute.go b/pkg/roll/execute.go index 39aab46..721bc95 100644 --- a/pkg/roll/execute.go +++ b/pkg/roll/execute.go @@ -42,6 +42,14 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration) error if err != nil { return fmt.Errorf("unable to execute start operation: %w", err) } + + if _, ok := op.(migrations.RequiresSchemaRefreshOperation); ok { + // refresh schema + newSchema, err = m.state.ReadSchema(ctx, m.schema) + if err != nil { + return fmt.Errorf("unable to refresh schema: %w", err) + } + } } // create schema for the new version diff --git a/pkg/roll/roll.go b/pkg/roll/roll.go index bb1d7d1..044e2c5 100644 --- a/pkg/roll/roll.go +++ b/pkg/roll/roll.go @@ -3,6 +3,7 @@ package roll import ( "context" "database/sql" + "fmt" "github.com/lib/pq" @@ -31,6 +32,11 @@ func New(ctx context.Context, pgURL, schema string, state *state.State) (*Roll, return nil, err } + _, err = conn.ExecContext(ctx, "SET LOCAL pgroll.internal to 'TRUE'") + if err != nil { + return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err) + } + return &Roll{ pgConn: conn, schema: schema, diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index cfab465..e51a586 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -1,6 +1,11 @@ package schema -import "fmt" +import ( + "database/sql/driver" + "encoding/json" + "errors" + "fmt" +) // XXX we create a view of the schema with the minimum required for us to // know how to execute migrations and build views for the new schema version. @@ -57,6 +62,11 @@ type Index struct { Name string `json:"name"` } +// Replace replaces the contents of the schema with the contents of the given one +func (s *Schema) Replace(other *Schema) { + s.Tables = other.Tables +} + func (s *Schema) GetTable(name string) *Table { if s.Tables == nil { return nil @@ -121,3 +131,20 @@ func (t *Table) RenameColumn(from, to string) { t.Columns[to] = t.Columns[from] delete(t.Columns, from) } + +// Make the Schema struct implement the driver.Valuer interface. This method +// simply returns the JSON-encoded representation of the struct. +func (s Schema) Value() (driver.Value, error) { + return json.Marshal(s) +} + +// Make the Schema struct implement the sql.Scanner interface. This method +// simply decodes a JSON-encoded value into the struct fields. +func (s *Schema) Scan(value interface{}) error { + b, ok := value.([]byte) + if !ok { + return errors.New("type assertion to []byte failed") + } + + return json.Unmarshal(b, &s) +} diff --git a/pkg/state/state.go b/pkg/state/state.go index f4f59dd..2887e8e 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -140,6 +140,52 @@ BEGIN RETURN tables; END; $$; + +CREATE OR REPLACE FUNCTION %[1]s.raw_migration() RETURNS event_trigger +LANGUAGE plpgsql AS $$ +DECLARE + schemaname TEXT; +BEGIN + -- Ignore migrations done by pg-roll + IF (current_setting('pgroll.internal', 'TRUE') <> 'TRUE') THEN + RETURN; + END IF; + + -- Guess the schema from ddl commands, ignore migrations that touch several schemas + IF (SELECT COUNT(DISTINCT schema_name) FROM pg_event_trigger_ddl_commands() WHERE schema_name IS NOT NULL) > 1 THEN + RAISE NOTICE 'pg-roll: ignoring migration that touches several schemas'; + RETURN; + END IF; + + SELECT schema_name INTO schemaname FROM pg_event_trigger_ddl_commands() WHERE schema_name IS NOT NULL; + + IF schemaname IS NULL THEN + RAISE NOTICE 'pg-roll: ignoring migration with null schema'; + RETURN; + END IF; + + -- Ignore migrations done during a migration period + IF %[1]s.is_active_migration_period(schemaname) THEN + RAISE NOTICE 'pg-roll: ignoring migration during active migration period'; + RETURN; + END IF; + + -- Someone did a schema change without pg-roll, include it in the history + INSERT INTO %[1]s.migrations (schema, name, migration, resulting_schema, done, parent) + VALUES ( + schemaname, + format('sql_%%s', substr(md5(random()::text), 0, 15)), + json_build_object('sql', json_build_object('up', current_query())), + %[1]s.read_schema(schemaname), + true, + %[1]s.latest_version(schemaname) + ); +END; +$$; + +DROP EVENT TRIGGER IF EXISTS pg_roll_handle_ddl; +CREATE EVENT TRIGGER pg_roll_handle_ddl ON ddl_command_end + EXECUTE FUNCTION %[1]s.raw_migration() ; ` type State struct { @@ -153,6 +199,11 @@ func New(ctx context.Context, pgURL, stateSchema string) (*State, error) { return nil, err } + _, err = conn.ExecContext(ctx, "SET LOCAL pgroll.internal to 'TRUE'") + if err != nil { + return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err) + } + return &State{ pgConn: conn, schema: stateSchema, @@ -183,7 +234,7 @@ func (s *State) IsActiveMigrationPeriod(ctx context.Context, schema string) (boo return false, err } - return isActive, err + return isActive, nil } // GetActiveMigration returns the name & raw content of the active migration (if any), errors out otherwise @@ -232,6 +283,17 @@ func (s *State) PreviousVersion(ctx context.Context, schema string) (*string, er return parent, nil } +// ReadSchema reads & returns the current schema from postgres +func ReadSchema(ctx context.Context, conn *sql.DB, stateSchema, schemaname string) (*schema.Schema, error) { + var res schema.Schema + err := conn.QueryRowContext(ctx, fmt.Sprintf("SELECT %[1]s.read_schema($1)", pq.QuoteIdentifier(stateSchema)), schemaname).Scan(&res) + if err != nil { + return nil, err + } + + return &res, nil +} + // Start creates a new migration, storing its name and raw content // this will effectively activate a new migration period, so `IsActiveMigrationPeriod` will return true // until the migration is completed