Add migrations state handling (#7)

This migrations introduces state handling by creating a dedicated
`pgroll` schema (name configurable). We will store migrations there, as
well as their state. So we keep some useful information, ie the
migration definition (so we don't need it for the `complete` state).

Schema includes the proper constraints to guarantee that:
* Only a migration is active at a time
* Migration history is linear (all migrations have a unique parent,
except the first one which is NULL)
* We now the current migration at all times

Some helper functions are included:

* `is_active_migration_period()` will return true if there is an active
migration.
* `latest_version()` will return the name of the latest version of the
schema.
This commit is contained in:
Carlos Pérez-Aradros Herce 2023-06-28 11:10:03 +02:00 committed by GitHub
parent f9a530c900
commit a8c4fddd14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 601 additions and 249 deletions

View File

@ -8,29 +8,41 @@ PostgreSQL zero-downtime migrations made easy.
* Bring a development PostgreSQL up:
```sh
docker compose up
```
```sh
docker compose up
```
* Initialize pg-roll (first time only):
```sh
go run . init
```
* Start a migration:
```sh
go run . start examples/01_create_tables.json
```
```sh
go run . start examples/01_create_tables.json
```
* Inspect the results:
```sh
psql postgres://localhost -U postgres
```
```sh
psql postgres://localhost -U postgres
```
```sql
\d+ public.*
\d+ 01_create_tables.*
```
```sql
\d+ public.*
\d+ 01_create_tables.*
```
* (Optional) Rollback the migration (undo):
```sh
go run . rollback
```
* Complete the migration:
```sh
go run . complete examples/01_create_tables.json
```
```sh
go run . complete
```

View File

@ -13,7 +13,6 @@ import (
var analyzeCmd = &cobra.Command{
Use: "analyze",
Short: "Analyze the SQL schema of the target database",
Long: "Analyse the SQL schema of the target database and output the result as JSON",
Hidden: true,
RunE: func(_ *cobra.Command, _ []string) error {
db, err := sql.Open("postgres", PGURL)

View File

@ -2,10 +2,6 @@ package cmd
import (
"fmt"
"path/filepath"
"strings"
"pg-roll/pkg/migrations"
"github.com/spf13/cobra"
)
@ -13,25 +9,14 @@ import (
var completeCmd = &cobra.Command{
Use: "complete <file>",
Short: "Complete an ongoing migration with the operations present in the given file",
Long: `TODO: Add long description`,
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
fileName := args[0]
m, err := migrations.New(cmd.Context(), PGURL)
m, err := NewRoll(cmd.Context())
if err != nil {
return err
}
defer m.Close()
ops, err := migrations.ReadMigrationFile(args[0])
if err != nil {
return fmt.Errorf("reading migration file: %w", err)
}
version := strings.TrimSuffix(filepath.Base(fileName), filepath.Ext(fileName))
err = m.Complete(cmd.Context(), version, ops)
err = m.Complete(cmd.Context())
if err != nil {
return err
}

27
cmd/init.go Normal file
View File

@ -0,0 +1,27 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
)
var initCmd = &cobra.Command{
Use: "init <file>",
Short: "Initializes pg-roll, creating the required pg_roll schema to store state",
RunE: func(cmd *cobra.Command, args []string) error {
m, err := NewRoll(cmd.Context())
if err != nil {
return err
}
defer m.Close()
err = m.Init(cmd.Context())
if err != nil {
return err
}
fmt.Printf("Initialization done! pg-roll is ready to be used\n")
return nil
},
}

View File

@ -2,10 +2,6 @@ package cmd
import (
"fmt"
"path/filepath"
"strings"
"pg-roll/pkg/migrations"
"github.com/spf13/cobra"
)
@ -13,30 +9,19 @@ import (
var rollbackCmd = &cobra.Command{
Use: "rollback <file>",
Short: "Roll back an ongoing migration",
Long: "Roll back an ongoing migration. This will revert the changes made by the migration.",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
fileName := args[0]
m, err := migrations.New(cmd.Context(), PGURL)
m, err := NewRoll(cmd.Context())
if err != nil {
return err
}
defer m.Close()
ops, err := migrations.ReadMigrationFile(args[0])
if err != nil {
return fmt.Errorf("reading migration file: %w", err)
}
version := strings.TrimSuffix(filepath.Base(fileName), filepath.Ext(fileName))
err = m.Rollback(cmd.Context(), version, ops)
err = m.Rollback(cmd.Context())
if err != nil {
return err
}
fmt.Printf("Migration rolled back. Changes made by %q have been reverted.\n", version)
fmt.Printf("Migration rolled back. Changes made since the last version have been reverted.\n")
return nil
},
}

View File

@ -1,22 +1,45 @@
package cmd
import (
"context"
"pg-roll/pkg/roll"
"pg-roll/pkg/state"
"github.com/spf13/cobra"
)
var PGURL string
var (
// PGURL is the Postgres URL to connect to
PGURL string
// Schema is the schema to use for the migration
Schema string
// StateSchema is the Postgres schema where pg-roll will store its state
StateSchema string
)
func init() {
rootCmd.PersistentFlags().StringVar(&PGURL, "postgres_url", "postgres://postgres:postgres@localhost?sslmode=disable", "Postgres URL")
rootCmd.PersistentFlags().StringVar(&PGURL, "postgres-url", "postgres://postgres:postgres@localhost?sslmode=disable", "Postgres URL")
rootCmd.PersistentFlags().StringVar(&Schema, "schema", "public", "Postgres schema to use for the migration")
rootCmd.PersistentFlags().StringVar(&StateSchema, "pgroll-schema", "pgroll", "Postgres schema in which the migration should be applied")
}
var rootCmd = &cobra.Command{
Use: "pg-roll",
Short: "TODO: Add short description",
Long: `TODO: Add long description`,
SilenceUsage: true,
}
func NewRoll(ctx context.Context) (*roll.Roll, error) {
state, err := state.New(ctx, PGURL, StateSchema)
if err != nil {
return nil, err
}
return roll.New(ctx, PGURL, Schema, state)
}
// Execute executes the root command.
func Execute() error {
// register subcommands
@ -24,6 +47,7 @@ func Execute() error {
rootCmd.AddCommand(completeCmd)
rootCmd.AddCommand(rollbackCmd)
rootCmd.AddCommand(analyzeCmd)
rootCmd.AddCommand(initCmd)
return rootCmd.Execute()
}

View File

@ -13,25 +13,24 @@ import (
var startCmd = &cobra.Command{
Use: "start <file>",
Short: "Start a migration for the operations present in the given file",
Long: `TODO: Add long description`,
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
fileName := args[0]
m, err := migrations.New(cmd.Context(), PGURL)
m, err := NewRoll(cmd.Context())
if err != nil {
return err
}
defer m.Close()
ops, err := migrations.ReadMigrationFile(args[0])
migration, err := migrations.ReadMigrationFile(args[0])
if err != nil {
return fmt.Errorf("reading migration file: %w", err)
}
version := strings.TrimSuffix(filepath.Base(fileName), filepath.Ext(fileName))
err = m.Start(cmd.Context(), version, ops)
err = m.Start(cmd.Context(), migration)
if err != nil {
return err
}

View File

@ -1,44 +1,47 @@
[
{
"create_table": {
"name": "customers",
"columns": [
{
"name": "id",
"type": "integer",
"pk": true
},
{
"name": "name",
"type": "varchar(255)",
"unique": true
},
{
"name": "credit_card",
"type": "text",
"nullable": true
}
]
{
"name": "01_create_tables",
"operations": [
{
"create_table": {
"name": "customers",
"columns": [
{
"name": "id",
"type": "integer",
"pk": true
},
{
"name": "name",
"type": "varchar(255)",
"unique": true
},
{
"name": "credit_card",
"type": "text",
"nullable": true
}
]
}
},
{
"create_table": {
"name": "bills",
"columns": [
{
"name": "id",
"type": "integer",
"pk": true
},
{
"name": "date",
"type": "time with time zone"
},
{
"name": "quantity",
"type": "integer"
}
]
}
}
},
{
"create_table": {
"name": "bills",
"columns": [
{
"name": "id",
"type": "integer",
"pk": true
},
{
"name": "date",
"type": "time with time zone"
},
{
"name": "quantity",
"type": "integer"
}
]
}
}
]
]
}

View File

@ -0,0 +1,26 @@
{
"name": "02_create_another_table",
"operations": [
{
"create_table": {
"name": "products",
"columns": [
{
"name": "id",
"type": "serial",
"pk": true
},
{
"name": "name",
"type": "varchar(255)",
"unique": true
},
{
"name": "price",
"type": "decimal(10,2)"
}
]
}
}
]
}

View File

@ -1,97 +0,0 @@
package migrations
import (
"context"
"fmt"
"strings"
"pg-roll/pkg/schema"
"github.com/lib/pq"
)
// Start will apply the required changes to enable supporting the new schema version
func (m *Migrations) Start(ctx context.Context, version string, ops Operations) error {
newSchema := schema.New()
// execute operations
for _, op := range ops {
err := op.Start(ctx, m.pgConn, newSchema)
if err != nil {
return fmt.Errorf("unable to execute start operation: %w", err)
}
}
// create schema for the new version
_, err := m.pgConn.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", pq.QuoteIdentifier(version)))
if err != nil {
return err
}
// create views in the new schema
for name, table := range newSchema.Tables {
err = m.createView(ctx, version, name, table)
if err != nil {
return fmt.Errorf("unable to create view: %w", err)
}
}
return nil
}
// Complete will update the database schema to match the current version
func (m *Migrations) Complete(ctx context.Context, version string, ops Operations) error {
// execute operations
for _, op := range ops {
err := op.Complete(ctx, m.pgConn)
if err != nil {
return fmt.Errorf("unable to execute complete operation: %w", err)
}
}
// TODO: once we have state, drop views for previous versions
return nil
}
func (m *Migrations) Rollback(ctx context.Context, version string, ops Operations) error {
// delete the schema and view for the new version
_, err := m.pgConn.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", pq.QuoteIdentifier(version)))
if err != nil {
return err
}
// reverse the order of the operations so that they are undone in the correct order
for i, j := 0, len(ops)-1; i < j; i, j = i+1, j-1 {
ops[i], ops[j] = ops[j], ops[i]
}
// execute operations
for _, op := range ops {
err := op.Rollback(ctx, m.pgConn)
if err != nil {
return fmt.Errorf("unable to execute rollback operation: %w", err)
}
}
return nil
}
// create view creates a view for the new version of the schema
func (m *Migrations) createView(ctx context.Context, version string, name string, table schema.Table) error {
columns := make([]string, 0, len(table.Columns))
for k, v := range table.Columns {
columns = append(columns, fmt.Sprintf("%s AS %s", pq.QuoteIdentifier(k), pq.QuoteIdentifier(v.Name)))
}
_, err := m.pgConn.ExecContext(ctx,
fmt.Sprintf("CREATE OR REPLACE VIEW %s.%s AS SELECT %s FROM %s",
pq.QuoteIdentifier(version),
pq.QuoteIdentifier(name),
strings.Join(columns, ","),
pq.QuoteIdentifier(table.Name)))
if err != nil {
return err
}
return nil
}

View File

@ -9,10 +9,6 @@ import (
_ "github.com/lib/pq"
)
type Migrations struct {
pgConn *sql.DB // TODO abstract sql connection
}
type Operation interface {
// Start will apply the required changes to enable supporting the new schema
// version in the database (through a view)
@ -29,17 +25,11 @@ type Operation interface {
Rollback(ctx context.Context, conn *sql.DB) error
}
func New(ctx context.Context, pgURL string) (*Migrations, error) {
conn, err := sql.Open("postgres", pgURL)
if err != nil {
return nil, err
type (
Operations []Operation
Migration struct {
Name string `json:"name"`
Operations Operations `json:"operations"`
}
return &Migrations{
pgConn: conn,
}, nil
}
func (m *Migrations) Close() error {
return m.pgConn.Close()
}
)

View File

@ -1,19 +1,18 @@
package migrations
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
)
type Operations []Operation
func TemporaryName(name string) string {
return "_pgroll_new_" + name
}
func ReadMigrationFile(file string) ([]Operation, error) {
func ReadMigrationFile(file string) (*Migration, error) {
// read operations from file
jsonFile, err := os.Open(file)
if err != nil {
@ -27,15 +26,16 @@ func ReadMigrationFile(file string) ([]Operation, error) {
return nil, err
}
ops := Operations{}
err = json.Unmarshal(byteValue, &ops)
mig := Migration{}
err = json.Unmarshal(byteValue, &mig)
if err != nil {
return nil, err
}
return ops, nil
return &mig, nil
}
// UnmarshalJSON deserializes the list of operations from a JSON array.
func (v *Operations) UnmarshalJSON(data []byte) error {
var tmp []map[string]json.RawMessage
if err := json.Unmarshal(data, &tmp); err != nil {
@ -77,3 +77,39 @@ func (v *Operations) UnmarshalJSON(data []byte) error {
*v = ops
return nil
}
// MarshalJSON serializes the list of operations into a JSON array.
func (v Operations) MarshalJSON() ([]byte, error) {
if len(v) == 0 {
return []byte(`[]`), nil
}
var buf bytes.Buffer
buf.WriteByte('[')
enc := json.NewEncoder(&buf)
for i, op := range v {
if i != 0 {
buf.WriteByte(',')
}
var opName string
switch op.(type) {
case *OpCreateTable:
opName = "create_table"
default:
panic(fmt.Errorf("unknown operation for %T", op))
}
buf.WriteString(`{"`)
buf.WriteString(opName)
buf.WriteString(`":`)
if err := enc.Encode(op); err != nil {
return nil, fmt.Errorf("unable to encode op [%v]: %w", i, err)
}
buf.WriteByte('}')
}
buf.WriteByte(']')
return buf.Bytes(), nil
}

View File

@ -14,16 +14,16 @@ var _ Operation = (*OpCreateTable)(nil)
type OpCreateTable struct {
Name string `json:"name"`
Columns []column `json:"columns"`
Columns []Column `json:"columns"`
}
type column struct {
Name string `json:"name"`
Type string `json:"type"`
Nullable bool `json:"nullable"`
Unique bool `json:"unique"`
PrimaryKey bool `json:"pk"`
Default sql.NullString `json:"default"`
type Column struct {
Name string `json:"name"`
Type string `json:"type"`
Nullable bool `json:"nullable"`
Unique bool `json:"unique"`
PrimaryKey bool `json:"pk"`
Default *string `json:"default"`
}
func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
@ -50,7 +50,7 @@ func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, s *schema.Schem
return nil
}
func columnsToSQL(cols []column) string {
func columnsToSQL(cols []Column) string {
var sql string
for i, col := range cols {
if i > 0 {
@ -67,8 +67,8 @@ func columnsToSQL(cols []column) string {
if !col.Nullable {
sql += " NOT NULL"
}
if col.Default.Valid {
sql += fmt.Sprintf(" DEFAULT %s", col.Default.String)
if col.Default != nil {
sql += fmt.Sprintf(" DEFAULT %s", *col.Default)
}
}
return sql

View File

@ -1,4 +1,4 @@
package migrations
package migrations_test
import (
"context"
@ -7,6 +7,10 @@ import (
"testing"
"time"
"pg-roll/pkg/migrations"
"pg-roll/pkg/roll"
"pg-roll/pkg/state"
"github.com/google/go-cmp/cmp"
"github.com/lib/pq"
"github.com/testcontainers/testcontainers-go"
@ -23,11 +27,11 @@ const (
func TestViewForNewVersionIsCreatedAfterMigrationStart(t *testing.T) {
t.Parallel()
withMigratorAndConnectionToContainer(t, func(mig *Migrations, db *sql.DB) {
withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
version := "1_create_table"
if err := mig.Start(ctx, version, Operations{createTableOp()}); err != nil {
if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp()}}); err != nil {
t.Fatalf("Failed to start migration: %v", err)
}
@ -52,11 +56,11 @@ func TestViewForNewVersionIsCreatedAfterMigrationStart(t *testing.T) {
func TestRecordsCanBeInsertedIntoAndReadFromNewViewAfterMigrationStart(t *testing.T) {
t.Parallel()
withMigratorAndConnectionToContainer(t, func(mig *Migrations, db *sql.DB) {
withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
version := "1_create_table"
if err := mig.Start(ctx, version, Operations{createTableOp()}); err != nil {
if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp()}}); err != nil {
t.Fatalf("Failed to start migration: %v", err)
}
@ -117,16 +121,17 @@ func TestRecordsCanBeInsertedIntoAndReadFromNewViewAfterMigrationStart(t *testin
func TestViewSchemaAndTableAreDroppedAfterMigrationRevert(t *testing.T) {
t.Parallel()
withMigratorAndConnectionToContainer(t, func(mig *Migrations, db *sql.DB) {
withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
ops := Operations{createTableOp()}
version := "1_create_table"
if err := mig.Start(ctx, version, ops); err != nil {
migration := &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp()}}
if err := mig.Start(ctx, migration); err != nil {
t.Fatalf("Failed to start migration: %v", err)
}
if err := mig.Rollback(ctx, version, ops); err != nil {
if err := mig.Rollback(ctx); err != nil {
t.Fatalf("Failed to revert migration: %v", err)
}
@ -134,7 +139,7 @@ func TestViewSchemaAndTableAreDroppedAfterMigrationRevert(t *testing.T) {
//
// Check that the new table has been dropped
//
tableName := TemporaryName(viewName)
tableName := migrations.TemporaryName(viewName)
err := db.QueryRow(`
SELECT EXISTS(
SELECT 1
@ -186,10 +191,10 @@ func TestViewSchemaAndTableAreDroppedAfterMigrationRevert(t *testing.T) {
})
}
func createTableOp() *OpCreateTable {
return &OpCreateTable{
func createTableOp() *migrations.OpCreateTable {
return &migrations.OpCreateTable{
Name: viewName,
Columns: []column{
Columns: []migrations.Column{
{
Name: "id",
Type: "integer",
@ -204,7 +209,7 @@ func createTableOp() *OpCreateTable {
}
}
func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *Migrations, db *sql.DB)) {
func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()
@ -232,7 +237,15 @@ func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *Migrations,
t.Fatal(err)
}
mig, err := New(ctx, cStr)
st, err := state.New(ctx, cStr, "pgroll")
if err != nil {
t.Fatal(err)
}
err = st.Init(ctx)
if err != nil {
t.Fatal(err)
}
mig, err := roll.New(ctx, cStr, "public", st)
if err != nil {
t.Fatal(err)
}

139
pkg/roll/execute.go Normal file
View File

@ -0,0 +1,139 @@
package roll
import (
"context"
"fmt"
"strings"
"pg-roll/pkg/migrations"
"pg-roll/pkg/schema"
"github.com/lib/pq"
)
// Start will apply the required changes to enable supporting the new schema version
func (m *Roll) Start(ctx context.Context, migration *migrations.Migration) error {
// check if there is an active migration, create one otherwise
active, err := m.state.IsActiveMigrationPeriod(ctx, m.schema)
if err != nil {
return err
}
if active {
return fmt.Errorf("there is an active migration already")
}
// TODO: retrieve current schema + store it as state?
newSchema := schema.New()
// create a new active migration (guaranteed to be unique by constraints)
err = m.state.Start(ctx, m.schema, migration)
if err != nil {
return fmt.Errorf("unable to start migration: %w", err)
}
// execute operations
for _, op := range migration.Operations {
err := op.Start(ctx, m.pgConn, newSchema)
if err != nil {
return fmt.Errorf("unable to execute start operation: %w", err)
}
}
// create schema for the new version
_, err = m.pgConn.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", pq.QuoteIdentifier(migration.Name)))
if err != nil {
return err
}
// create views in the new schema
for name, table := range newSchema.Tables {
err = m.createView(ctx, migration.Name, name, table)
if err != nil {
return fmt.Errorf("unable to create view: %w", err)
}
}
return nil
}
// Complete will update the database schema to match the current version
func (m *Roll) Complete(ctx context.Context) error {
// get current ongoing migration
migration, err := m.state.GetActiveMigration(ctx, m.schema)
if err != nil {
return fmt.Errorf("unable to get active migration: %w", err)
}
// execute operations
for _, op := range migration.Operations {
err := op.Complete(ctx, m.pgConn)
if err != nil {
return fmt.Errorf("unable to execute complete operation: %w", err)
}
}
// TODO: drop views from previous version
// mark as completed
err = m.state.Complete(ctx, m.schema, migration.Name)
if err != nil {
return fmt.Errorf("unable to complete migration: %w", err)
}
return nil
}
func (m *Roll) Rollback(ctx context.Context) error {
// get current ongoing migration
migration, err := m.state.GetActiveMigration(ctx, m.schema)
if err != nil {
return fmt.Errorf("unable to get active migration: %w", err)
}
// delete the schema and view for the new version
_, err = m.pgConn.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", pq.QuoteIdentifier(migration.Name)))
if err != nil {
return err
}
// reverse the order of the operations so that they are undone in the correct order
ops := migration.Operations
for i, j := 0, len(ops)-1; i < j; i, j = i+1, j-1 {
ops[i], ops[j] = ops[j], ops[i]
}
// execute operations
for _, op := range ops {
err := op.Rollback(ctx, m.pgConn)
if err != nil {
return fmt.Errorf("unable to execute rollback operation: %w", err)
}
}
// mark as completed
err = m.state.Rollback(ctx, m.schema, migration.Name)
if err != nil {
return fmt.Errorf("unable to rollback migration: %w", err)
}
return nil
}
// create view creates a view for the new version of the schema
func (m *Roll) createView(ctx context.Context, version string, name string, table schema.Table) error {
columns := make([]string, 0, len(table.Columns))
for k, v := range table.Columns {
columns = append(columns, fmt.Sprintf("%s AS %s", pq.QuoteIdentifier(k), pq.QuoteIdentifier(v.Name)))
}
_, err := m.pgConn.ExecContext(ctx,
fmt.Sprintf("CREATE OR REPLACE VIEW %s.%s AS SELECT %s FROM %s",
pq.QuoteIdentifier(version),
pq.QuoteIdentifier(name),
strings.Join(columns, ","),
pq.QuoteIdentifier(table.Name)))
if err != nil {
return err
}
return nil
}

43
pkg/roll/roll.go Normal file
View File

@ -0,0 +1,43 @@
package roll
import (
"context"
"database/sql"
"pg-roll/pkg/state"
)
type Roll struct {
pgConn *sql.DB // TODO abstract sql connection
// schema we are acting on
schema string
state *state.State
}
func New(ctx context.Context, pgURL, schema string, state *state.State) (*Roll, error) {
conn, err := sql.Open("postgres", pgURL)
if err != nil {
return nil, err
}
return &Roll{
pgConn: conn,
schema: schema,
state: state,
}, nil
}
func (m *Roll) Init(ctx context.Context) error {
return m.state.Init(ctx)
}
func (m *Roll) Close() error {
err := m.state.Close()
if err != nil {
return err
}
return m.pgConn.Close()
}

168
pkg/state/state.go Normal file
View File

@ -0,0 +1,168 @@
package state
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"pg-roll/pkg/migrations"
"github.com/lib/pq"
)
const sqlInit = `
CREATE SCHEMA IF NOT EXISTS %[1]s;
CREATE TABLE IF NOT EXISTS %[1]s.migrations (
schema NAME NOT NULL,
name TEXT NOT NULL,
migration JSONB NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
parent TEXT,
done BOOLEAN NOT NULL DEFAULT false,
failed BOOLEAN NOT NULL DEFAULT false,
PRIMARY KEY (schema, name),
FOREIGN KEY (schema, parent) REFERENCES %[1]s.migrations(schema, name)
);
-- Only one migration can be active at a time
CREATE UNIQUE INDEX IF NOT EXISTS only_one_active ON %[1]s.migrations (schema, name, done) WHERE done = false;
-- Only first migration can exist without parent
CREATE UNIQUE INDEX IF NOT EXISTS only_first_migration_without_parent ON %[1]s.migrations ((1)) WHERE parent IS NULL;
-- History is linear
CREATE UNIQUE INDEX IF NOT EXISTS history_is_linear ON %[1]s.migrations (schema, parent);
-- Helper functions
-- Are we in the middle of a migration?
CREATE OR REPLACE FUNCTION %[1]s.is_active_migration_period(schemaname NAME) RETURNS boolean
AS $$ SELECT EXISTS (SELECT 1 FROM %[1]s.migrations WHERE schema=schemaname AND done=false) $$
LANGUAGE SQL
STABLE;
-- Get the latest version name (this is the one with child migrations)
CREATE OR REPLACE FUNCTION %[1]s.latest_version(schemaname NAME) RETURNS text
AS $$ SELECT p.name FROM %[1]s.migrations p WHERE NOT EXISTS (SELECT 1 FROM %[1]s.migrations c WHERE schema=schemaname AND c.parent=p.name) $$
LANGUAGE SQL
STABLE;
`
type State struct {
pgConn *sql.DB
schema string
}
func New(ctx context.Context, pgURL, stateSchema string) (*State, error) {
conn, err := sql.Open("postgres", pgURL)
if err != nil {
return nil, err
}
return &State{
pgConn: conn,
schema: stateSchema,
}, nil
}
func (s *State) Init(ctx context.Context) error {
// ensure pg-roll internal tables exist
// TODO: eventually use migrations for this instead of hardcoding
_, err := s.pgConn.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema)))
return err
}
func (s *State) Close() error {
return s.pgConn.Close()
}
// IsActiveMigrationPeriod returns true if there is an active migration
func (s *State) IsActiveMigrationPeriod(ctx context.Context, schema string) (bool, error) {
var isActive bool
err := s.pgConn.QueryRowContext(ctx, fmt.Sprintf("SELECT %s.is_active_migration_period($1)", pq.QuoteIdentifier(s.schema)), schema).Scan(&isActive)
if err != nil {
return false, err
}
return isActive, err
}
// GetActiveMigration returns the name & raw content of the active migration (if any), errors out otherwise
func (s *State) GetActiveMigration(ctx context.Context, schema string) (*migrations.Migration, error) {
var name, rawMigration string
err := s.pgConn.QueryRowContext(ctx, fmt.Sprintf("SELECT name, migration FROM %s.migrations WHERE schema=$1 AND done=false", pq.QuoteIdentifier(s.schema)), schema).Scan(&name, &rawMigration)
if err != nil {
return nil, err
}
var migration migrations.Migration
err = json.Unmarshal([]byte(rawMigration), &migration)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal migration: %w", err)
}
return &migration, nil
}
// Start creates a new migration, storing it's name and raw content
// this will effectively activate a new migration period, so `IsActiveMigrationPeriod` will return true
// until the migration is completed
func (s *State) Start(ctx context.Context, schema string, migration *migrations.Migration) error {
rawMigration, err := json.Marshal(migration)
if err != nil {
return fmt.Errorf("unable to marshal migration: %w", err)
}
_, err = s.pgConn.ExecContext(ctx,
fmt.Sprintf("INSERT INTO %[1]s.migrations (schema, name, parent, migration) VALUES ($1, $2, %[1]s.latest_version($1), $3)", pq.QuoteIdentifier(s.schema)),
schema, migration.Name, rawMigration)
// TODO handle constraint violations, ie to detect an active migration, or duplicated names
return err
}
// Complete marks a migration as completed
func (s *State) Complete(ctx context.Context, schema, name string) error {
res, err := s.pgConn.ExecContext(ctx, fmt.Sprintf("UPDATE %s.migrations SET done=$1 WHERE schema=$2 AND name=$3 AND done=$4", pq.QuoteIdentifier(s.schema)), true, schema, name, false)
if err != nil {
return err
}
// TODO handle constraint violations, ie trying to complete a migration that is not active
rows, err := res.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("no migration found with name %s", name)
}
return err
}
// Rollback removes a migration from the state (we consider it rolled back, as if it never started)
func (s *State) Rollback(ctx context.Context, schema, name string) error {
res, err := s.pgConn.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s.migrations WHERE schema=$1 AND name=$2 AND done=$3", pq.QuoteIdentifier(s.schema)), schema, name, false)
if err != nil {
return err
}
// TODO handle constraint violations, ie trying to complete a migration that is not active
rows, err := res.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("no migration found with name %s", name)
}
return nil
}