Retry on lock_timeout errors (#353)

Retry statements and transactions that fail due to `lock_timeout`
errors.

DDL operations and backfills are run in a session in which `SET
lock_timout TO xms'` has been set (`x` defaults to `500` but can be
specified with the `--lock-timeout` parameter). This ensures that a long
running query can't cause other queries to queue up behind a DDL
operation as it waits to acquire its lock.

The current behaviour if a DDL operation or backfill batch times out
when requesting a lock is to fail, forcing the user to retry the
migration operation (start, rollback, or complete).

This PR retries individual statements (like the DDL operations run by
migration operations) and transactions (used by backfills) if they fail
due to a `lock_timeout` error. The retry uses an exponential backoff
with jitter.

Fixes #171
This commit is contained in:
Andrew Farries 2024-05-08 15:54:27 +01:00 committed by GitHub
parent 4f0a715613
commit 5c1aef2f24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 365 additions and 166 deletions

1
go.mod
View File

@ -3,6 +3,7 @@ module github.com/xataio/pgroll
go 1.21 go 1.21
require ( require (
github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a
github.com/google/go-cmp v0.6.0 github.com/google/go-cmp v0.6.0
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/oapi-codegen/nullable v1.1.0 github.com/oapi-codegen/nullable v1.1.0

2
go.sum
View File

@ -73,6 +73,8 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a h1:8d1CEOF1xldesKds5tRG3tExBsMOgWYownMHNCsev54=
github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=

82
pkg/db/db.go Normal file
View File

@ -0,0 +1,82 @@
// SPDX-License-Identifier: Apache-2.0
package db
import (
"context"
"database/sql"
"errors"
"time"
"github.com/cloudflare/backoff"
"github.com/lib/pq"
)
const (
lockNotAvailableErrorCode pq.ErrorCode = "55P03"
maxBackoffDuration = 1 * time.Minute
backoffInterval = 1 * time.Second
)
type DB interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error
Close() error
}
// RDB wraps a *sql.DB and retries queries using an exponential backoff (with
// jitter) on lock_timeout errors.
type RDB struct {
DB *sql.DB
}
// ExecContext wraps sql.DB.ExecContext, retrying queries on lock_timeout errors.
func (db *RDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
b := backoff.New(maxBackoffDuration, backoffInterval)
for {
res, err := db.DB.ExecContext(ctx, query, args...)
if err == nil {
return res, nil
}
pqErr := &pq.Error{}
if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode {
<-time.After(b.Duration())
} else {
return nil, err
}
}
}
// WithRetryableTransaction runs `f` in a transaction, retrying on lock_timeout errors.
func (db *RDB) WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error {
b := backoff.New(maxBackoffDuration, backoffInterval)
for {
tx, err := db.DB.BeginTx(ctx, nil)
if err != nil {
return err
}
err = f(ctx, tx)
if err == nil {
return tx.Commit()
}
if errRollback := tx.Rollback(); errRollback != nil {
return errRollback
}
pqErr := &pq.Error{}
if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode {
<-time.After(b.Duration())
} else {
return err
}
}
}
func (db *RDB) Close() error {
return db.DB.Close()
}

121
pkg/db/db_test.go Normal file
View File

@ -0,0 +1,121 @@
// SPDX-License-Identifier: Apache-2.0
package db_test
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/testutils"
)
func TestMain(m *testing.M) {
testutils.SharedTestMain(m)
}
func TestExecContext(t *testing.T) {
t.Parallel()
testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
ctx := context.Background()
// create a table on which an exclusive lock is held for 2 seconds
setupTableLock(t, connStr, 2*time.Second)
// set the lock timeout to 100ms
ensureLockTimeout(t, conn, 100)
// execute a query that should retry until the lock is released
rdb := &db.RDB{DB: conn}
_, err := rdb.ExecContext(ctx, "INSERT INTO test(id) VALUES (1)")
require.NoError(t, err)
})
}
func TestWithRetryableTransaction(t *testing.T) {
t.Parallel()
testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
ctx := context.Background()
// create a table on which an exclusive lock is held for 2 seconds
setupTableLock(t, connStr, 2*time.Second)
// set the lock timeout to 100ms
ensureLockTimeout(t, conn, 100)
// run a transaction that should retry until the lock is released
rdb := &db.RDB{DB: conn}
err := rdb.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
return tx.QueryRowContext(ctx, "SELECT 1 FROM test").Err()
})
require.NoError(t, err)
})
}
// setupTableLock:
// * connects to the database
// * creates a table in the database
// * starts a transaction that temporarily locks the table
func setupTableLock(t *testing.T, connStr string, d time.Duration) {
t.Helper()
ctx := context.Background()
// connect to the database
conn2, err := sql.Open("postgres", connStr)
require.NoError(t, err)
// create a table in the database
_, err = conn2.ExecContext(ctx, "CREATE TABLE test (id INT PRIMARY KEY)")
require.NoError(t, err)
// start a transaction that takes a temporary lock on the table
errCh := make(chan error)
go func() {
// begin a transaction
tx, err := conn2.Begin()
if err != nil {
errCh <- err
return
}
// lock the table
_, err = tx.ExecContext(ctx, "LOCK TABLE test IN ACCESS EXCLUSIVE MODE")
if err != nil {
errCh <- err
return
}
// signal that the lock is obtained
errCh <- nil
// temporarily hold the lock
time.Sleep(d)
// commit the transaction
tx.Commit()
}()
// wait for the lock to be obtained
err = <-errCh
require.NoError(t, err)
}
func ensureLockTimeout(t *testing.T, conn *sql.DB, ms int) {
t.Helper()
// Set the lock timeout
query := fmt.Sprintf("SET lock_timeout = '%dms'", ms)
_, err := conn.ExecContext(context.Background(), query)
require.NoError(t, err)
// Ensure the lock timeout is set
var lockTimeout string
err = conn.QueryRowContext(context.Background(), "SHOW lock_timeout").Scan(&lockTimeout)
require.NoError(t, err)
require.Equal(t, fmt.Sprintf("%dms", ms), lockTimeout)
}

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -18,7 +19,7 @@ import (
// 2. Get the first batch of rows from the table, ordered by the primary key. // 2. Get the first batch of rows from the table, ordered by the primary key.
// 3. Update each row in the batch, setting the value of the primary key column to itself. // 3. Update each row in the batch, setting the value of the primary key column to itself.
// 4. Repeat steps 2 and 3 until no more rows are returned. // 4. Repeat steps 2 and 3 until no more rows are returned.
func Backfill(ctx context.Context, conn *sql.DB, table *schema.Table, cbs ...CallbackFn) error { func Backfill(ctx context.Context, conn db.DB, table *schema.Table, cbs ...CallbackFn) error {
// get the backfill column // get the backfill column
identityColumn := getIdentityColumn(table) identityColumn := getIdentityColumn(table)
if identityColumn == nil { if identityColumn == nil {
@ -85,27 +86,20 @@ type batcher struct {
batchSize int batchSize int
} }
// updateBatch updates the next batch of rows in the table. func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error {
func (b *batcher) updateBatch(ctx context.Context, conn *sql.DB) error { return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
// Start the transaction for this batch // Build the query to update the next batch of rows
tx, err := conn.BeginTx(ctx, nil) query := b.buildQuery()
if err != nil {
return err
}
defer tx.Rollback()
// Build the query to update the next batch of rows // Execute the query to update the next batch of rows and update the last PK
query := b.buildQuery() // value for the next batch
err := tx.QueryRowContext(ctx, query).Scan(&b.lastValue)
if err != nil {
return err
}
// Execute the query to update the next batch of rows and update the last PK return nil
// value for the next batch })
err = tx.QueryRowContext(ctx, query).Scan(&b.lastValue)
if err != nil {
return err
}
// Commit the transaction for this batch
return tx.Commit()
} }
// buildQuery builds the query used to update the next batch of rows. // buildQuery builds the query used to update the next batch of rows.

View File

@ -4,13 +4,13 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
) )
func addCommentToColumn(ctx context.Context, conn *sql.DB, tableName, columnName string, comment *string) error { func addCommentToColumn(ctx context.Context, conn db.DB, tableName, columnName string, comment *string) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s IS %s`, _, err := conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s IS %s`,
pq.QuoteIdentifier(tableName), pq.QuoteIdentifier(tableName),
pq.QuoteIdentifier(columnName), pq.QuoteIdentifier(columnName),
@ -19,7 +19,7 @@ func addCommentToColumn(ctx context.Context, conn *sql.DB, tableName, columnName
return err return err
} }
func addCommentToTable(ctx context.Context, conn *sql.DB, tableName string, comment *string) error { func addCommentToTable(ctx context.Context, conn db.DB, tableName string, comment *string) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON TABLE %s IS %s`, _, err := conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON TABLE %s IS %s`,
pq.QuoteIdentifier(tableName), pq.QuoteIdentifier(tableName),
commentToSQL(comment))) commentToSQL(comment)))

View File

@ -4,17 +4,17 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"slices" "slices"
"strings" "strings"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
type Duplicator struct { type Duplicator struct {
conn *sql.DB conn db.DB
table *schema.Table table *schema.Table
column *schema.Column column *schema.Column
asName string asName string
@ -24,7 +24,7 @@ type Duplicator struct {
} }
// NewColumnDuplicator creates a new Duplicator for a column. // NewColumnDuplicator creates a new Duplicator for a column.
func NewColumnDuplicator(conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator { func NewColumnDuplicator(conn db.DB, table *schema.Table, column *schema.Column) *Duplicator {
return &Duplicator{ return &Duplicator{
conn: conn, conn: conn,
table: table, table: table,

View File

@ -4,10 +4,10 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -18,16 +18,16 @@ type Operation interface {
// version in the database (through a view) // version in the database (through a view)
// update the given views to expose the new schema version // update the given views to expose the new schema version
// Returns the table that requires backfilling, if any. // Returns the table that requires backfilling, if any.
Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error)
// Complete will update the database schema to match the current version // Complete will update the database schema to match the current version
// after calling Start. // after calling Start.
// This method should be called once the previous version is no longer used // This method should be called once the previous version is no longer used
Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error
// Rollback will revert the changes made by Start. It is not possible to // Rollback will revert the changes made by Start. It is not possible to
// rollback a completed migration. // rollback a completed migration.
Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error
// Validate returns a descriptive error if the operation cannot be applied to the given schema // Validate returns a descriptive error if the operation cannot be applied to the given schema
Validate(ctx context.Context, s *schema.Schema) error Validate(ctx context.Context, s *schema.Schema) error

View File

@ -4,18 +4,18 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpAddColumn)(nil) var _ Operation = (*OpAddColumn)(nil)
func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table) table := s.GetTable(o.Table)
if err := addColumn(ctx, conn, *o, table, tr); err != nil { if err := addColumn(ctx, conn, *o, table, tr); err != nil {
@ -65,7 +65,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema strin
return tableToBackfill, nil return tableToBackfill, nil
} }
func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
tempName := TemporaryName(o.Column.Name) tempName := TemporaryName(o.Column.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s",
@ -118,7 +118,7 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfor
return err return err
} }
func (o *OpAddColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpAddColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
tempName := TemporaryName(o.Column.Name) tempName := TemporaryName(o.Column.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s",
@ -182,7 +182,7 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
return nil return nil
} }
func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error { func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error {
// don't add non-nullable columns with no default directly // don't add non-nullable columns with no default directly
// they are handled by: // they are handled by:
// - adding the column as nullable // - adding the column as nullable
@ -216,7 +216,7 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table
return err return err
} }
func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, physicalColumn string) error { func addNotNullConstraint(ctx context.Context, conn db.DB, table, column, physicalColumn string) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s IS NOT NULL) NOT VALID", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s IS NOT NULL) NOT VALID",
pq.QuoteIdentifier(table), pq.QuoteIdentifier(table),
pq.QuoteIdentifier(NotNullConstraintName(column)), pq.QuoteIdentifier(NotNullConstraintName(column)),
@ -225,7 +225,7 @@ func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, phys
return err return err
} }
func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn *sql.DB) error { func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn db.DB) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column.Check.Name), pq.QuoteIdentifier(o.Column.Check.Name),

View File

@ -4,16 +4,16 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpAlterColumn)(nil) var _ Operation = (*OpAlterColumn)(nil)
func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table) table := s.GetTable(o.Table)
column := table.GetColumn(o.Column) column := table.GetColumn(o.Column)
ops := o.subOperations() ops := o.subOperations()
@ -84,7 +84,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema str
return table, nil return table, nil
} }
func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
ops := o.subOperations() ops := o.subOperations()
// Perform any operation specific completion steps // Perform any operation specific completion steps
@ -139,7 +139,7 @@ func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransf
return nil return nil
} }
func (o *OpAlterColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
ops := o.subOperations() ops := o.subOperations()
// Perform any operation specific rollback steps // Perform any operation specific rollback steps
@ -315,7 +315,7 @@ func (o *OpAlterColumn) subOperations() []Operation {
} }
// duplicatorForOperations returns a Duplicator for the given operations // duplicatorForOperations returns a Duplicator for the given operations
func duplicatorForOperations(ops []Operation, conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator { func duplicatorForOperations(ops []Operation, conn db.DB, table *schema.Table, column *schema.Column) *Duplicator {
d := NewColumnDuplicator(conn, table, column) d := NewColumnDuplicator(conn, table, column)
for _, op := range ops { for _, op := range ops {

View File

@ -4,8 +4,8 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -19,17 +19,17 @@ type OpChangeType struct {
var _ Operation = (*OpChangeType)(nil) var _ Operation = (*OpChangeType)(nil)
func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpChangeType) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table) table := s.GetTable(o.Table)
return table, nil return table, nil
} }
func (o *OpChangeType) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpChangeType) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
return nil return nil
} }
func (o *OpChangeType) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpChangeType) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil return nil
} }

View File

@ -4,17 +4,17 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"strings" "strings"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpCreateIndex)(nil) var _ Operation = (*OpCreateIndex)(nil)
func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpCreateIndex) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// create index concurrently // create index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)", _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)",
pq.QuoteIdentifier(o.Name), pq.QuoteIdentifier(o.Name),
@ -23,12 +23,12 @@ func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema str
return nil, err return nil, err
} }
func (o *OpCreateIndex) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpCreateIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// No-op // No-op
return nil return nil
} }
func (o *OpCreateIndex) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpCreateIndex) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
// drop the index concurrently // drop the index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name)) _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name))

View File

@ -4,17 +4,17 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"strings" "strings"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpCreateTable)(nil) var _ Operation = (*OpCreateTable)(nil)
func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpCreateTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// Generate SQL for the columns in the table // Generate SQL for the columns in the table
columnsSQL, err := columnsToSQL(o.Columns, tr) columnsSQL, err := columnsToSQL(o.Columns, tr)
if err != nil { if err != nil {
@ -61,7 +61,7 @@ func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema str
return nil, nil return nil, nil
} }
func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpCreateTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
tempName := TemporaryName(o.Name) tempName := TemporaryName(o.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s",
pq.QuoteIdentifier(tempName), pq.QuoteIdentifier(tempName),
@ -69,7 +69,7 @@ func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransf
return err return err
} }
func (o *OpCreateTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpCreateTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
tempName := TemporaryName(o.Name) tempName := TemporaryName(o.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s",

View File

@ -4,16 +4,16 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpDropColumn)(nil) var _ Operation = (*OpDropColumn)(nil)
func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpDropColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
if o.Down != "" { if o.Down != "" {
err := createTrigger(ctx, conn, tr, triggerConfig{ err := createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, o.Column), Name: TriggerName(o.Table, o.Column),
@ -34,7 +34,7 @@ func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema stri
return nil, nil return nil, nil
} }
func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpDropColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s",
pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column))) pq.QuoteIdentifier(o.Column)))
@ -48,7 +48,7 @@ func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfo
return err return err
} }
func (o *OpDropColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpDropColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)))) pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))

View File

@ -4,16 +4,16 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpDropConstraint)(nil) var _ Operation = (*OpDropConstraint)(nil)
func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpDropConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table) table := s.GetTable(o.Table)
column := table.GetColumn(o.Column) column := table.GetColumn(o.Column)
@ -62,7 +62,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema
return table, nil return table, nil
} }
func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpDropConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// Remove the up function and trigger // Remove the up function and trigger
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)))) pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
@ -95,7 +95,7 @@ func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTra
return err return err
} }
func (o *OpDropConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpDropConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
// Drop the new column // Drop the new column
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Table),

View File

@ -4,27 +4,27 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpDropIndex)(nil) var _ Operation = (*OpDropIndex)(nil)
func (o *OpDropIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpDropIndex) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// no-op // no-op
return nil, nil return nil, nil
} }
func (o *OpDropIndex) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpDropIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// drop the index concurrently // drop the index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name)) _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name))
return err return err
} }
func (o *OpDropIndex) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpDropIndex) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
// no-op // no-op
return nil return nil
} }

View File

@ -4,8 +4,8 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -18,17 +18,17 @@ type OpDropNotNull struct {
var _ Operation = (*OpDropNotNull)(nil) var _ Operation = (*OpDropNotNull)(nil)
func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpDropNotNull) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table) table := s.GetTable(o.Table)
return table, nil return table, nil
} }
func (o *OpDropNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpDropNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
return nil return nil
} }
func (o *OpDropNotNull) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpDropNotNull) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil return nil
} }

View File

@ -4,27 +4,27 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpDropTable)(nil) var _ Operation = (*OpDropTable)(nil)
func (o *OpDropTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpDropTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
s.RemoveTable(o.Name) s.RemoveTable(o.Name)
return nil, nil return nil, nil
} }
func (o *OpDropTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpDropTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(o.Name))) _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(o.Name)))
return err return err
} }
func (o *OpDropTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpDropTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil return nil
} }

View File

@ -4,14 +4,14 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpRawSQL)(nil) var _ Operation = (*OpRawSQL)(nil)
func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpRawSQL) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
if o.OnComplete { if o.OnComplete {
return nil, nil return nil, nil
} }
@ -25,7 +25,7 @@ func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string,
return nil, err return nil, err
} }
func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpRawSQL) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
if !o.OnComplete { if !o.OnComplete {
return nil return nil
} }
@ -39,7 +39,7 @@ func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer
return err return err
} }
func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpRawSQL) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
if o.Down == "" { if o.Down == "" {
return nil return nil
} }

View File

@ -4,21 +4,21 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpRenameConstraint)(nil) var _ Operation = (*OpRenameConstraint)(nil)
func (o *OpRenameConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpRenameConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// no-op // no-op
return nil, nil return nil, nil
} }
func (o *OpRenameConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpRenameConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// rename the constraint in the underlying table // rename the constraint in the underlying table
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME CONSTRAINT %s TO %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME CONSTRAINT %s TO %s",
pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Table),
@ -27,7 +27,7 @@ func (o *OpRenameConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLT
return err return err
} }
func (o *OpRenameConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpRenameConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
// no-op // no-op
return nil return nil
} }

View File

@ -4,27 +4,27 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpRenameTable)(nil) var _ Operation = (*OpRenameTable)(nil)
func (o *OpRenameTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpRenameTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
return nil, s.RenameTable(o.From, o.To) return nil, s.RenameTable(o.From, o.To)
} }
func (o *OpRenameTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpRenameTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s",
pq.QuoteIdentifier(o.From), pq.QuoteIdentifier(o.From),
pq.QuoteIdentifier(o.To))) pq.QuoteIdentifier(o.To)))
return err return err
} }
func (o *OpRenameTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpRenameTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil return nil
} }

View File

@ -4,11 +4,11 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"strings" "strings"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -22,7 +22,7 @@ type OpSetCheckConstraint struct {
var _ Operation = (*OpSetCheckConstraint)(nil) var _ Operation = (*OpSetCheckConstraint)(nil)
func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpSetCheckConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table) table := s.GetTable(o.Table)
// Add the check constraint to the new column as NOT VALID. // Add the check constraint to the new column as NOT VALID.
@ -33,7 +33,7 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSch
return table, nil return table, nil
} }
func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// Validate the check constraint // Validate the check constraint
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Table),
@ -45,7 +45,7 @@ func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQ
return nil return nil
} }
func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil return nil
} }
@ -69,7 +69,7 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e
return nil return nil
} }
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn *sql.DB) error { func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.DB) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Check.Name), pq.QuoteIdentifier(o.Check.Name),

View File

@ -4,8 +4,8 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -19,17 +19,17 @@ type OpSetComment struct {
var _ Operation = (*OpSetComment)(nil) var _ Operation = (*OpSetComment)(nil)
func (o *OpSetComment) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpSetComment) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
tbl := s.GetTable(o.Table) tbl := s.GetTable(o.Table)
return tbl, addCommentToColumn(ctx, conn, o.Table, TemporaryName(o.Column), o.Comment) return tbl, addCommentToColumn(ctx, conn, o.Table, TemporaryName(o.Column), o.Comment)
} }
func (o *OpSetComment) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpSetComment) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
return nil return nil
} }
func (o *OpSetComment) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpSetComment) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil return nil
} }

View File

@ -4,10 +4,10 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -21,7 +21,7 @@ type OpSetDefault struct {
var _ Operation = (*OpSetDefault)(nil) var _ Operation = (*OpSetDefault)(nil)
func (o *OpSetDefault) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpSetDefault) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
tbl := s.GetTable(o.Table) tbl := s.GetTable(o.Table)
_, err := conn.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`, _, err := conn.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`,
@ -35,11 +35,11 @@ func (o *OpSetDefault) Start(ctx context.Context, conn *sql.DB, stateSchema stri
return tbl, nil return tbl, nil
} }
func (o *OpSetDefault) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpSetDefault) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
return nil return nil
} }
func (o *OpSetDefault) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpSetDefault) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil return nil
} }

View File

@ -4,11 +4,11 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"strings" "strings"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -22,7 +22,7 @@ type OpSetForeignKey struct {
var _ Operation = (*OpSetForeignKey)(nil) var _ Operation = (*OpSetForeignKey)(nil)
func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpSetForeignKey) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table) table := s.GetTable(o.Table)
// Create a NOT VALID foreign key constraint on the new column. // Create a NOT VALID foreign key constraint on the new column.
@ -33,7 +33,7 @@ func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema s
return table, nil return table, nil
} }
func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpSetForeignKey) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// Validate the foreign key constraint // Validate the foreign key constraint
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Table),
@ -45,7 +45,7 @@ func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB, tr SQLTran
return nil return nil
} }
func (o *OpSetForeignKey) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpSetForeignKey) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil return nil
} }
@ -69,7 +69,7 @@ func (o *OpSetForeignKey) Validate(ctx context.Context, s *schema.Schema) error
return nil return nil
} }
func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn *sql.DB) error { func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn db.DB) error {
tempColumnName := TemporaryName(o.Column) tempColumnName := TemporaryName(o.Column)
onDelete := "NO ACTION" onDelete := "NO ACTION"

View File

@ -4,10 +4,10 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -20,7 +20,7 @@ type OpSetNotNull struct {
var _ Operation = (*OpSetNotNull)(nil) var _ Operation = (*OpSetNotNull)(nil)
func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpSetNotNull) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table) table := s.GetTable(o.Table)
// Add an unchecked NOT NULL constraint to the new column. // Add an unchecked NOT NULL constraint to the new column.
@ -31,7 +31,7 @@ func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema stri
return table, nil return table, nil
} }
func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpSetNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// Validate the NOT NULL constraint on the old column. // Validate the NOT NULL constraint on the old column.
// The constraint must be valid because: // The constraint must be valid because:
// * Existing NULL values in the old column were rewritten using the `up` SQL during backfill. // * Existing NULL values in the old column were rewritten using the `up` SQL during backfill.
@ -62,7 +62,7 @@ func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfo
return nil return nil
} }
func (o *OpSetNotNull) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpSetNotNull) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil return nil
} }

View File

@ -4,18 +4,18 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"slices" "slices"
"strings" "strings"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
var _ Operation = (*OpSetReplicaIdentity)(nil) var _ Operation = (*OpSetReplicaIdentity)(nil)
func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// build the correct form of the `SET REPLICA IDENTITY` statement based on the`identity type // build the correct form of the `SET REPLICA IDENTITY` statement based on the`identity type
identitySQL := strings.ToUpper(o.Identity.Type) identitySQL := strings.ToUpper(o.Identity.Type)
if identitySQL == "INDEX" { if identitySQL == "INDEX" {
@ -29,12 +29,12 @@ func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSch
return nil, err return nil, err
} }
func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// No-op // No-op
return nil return nil
} }
func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
// No-op // No-op
return nil return nil
} }

View File

@ -4,10 +4,10 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -21,7 +21,7 @@ type OpSetUnique struct {
var _ Operation = (*OpSetUnique)(nil) var _ Operation = (*OpSetUnique)(nil)
func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { func (o *OpSetUnique) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table) table := s.GetTable(o.Table)
// Add a unique index to the new column // Add a unique index to the new column
@ -32,7 +32,7 @@ func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema strin
return table, nil return table, nil
} }
func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { func (o *OpSetUnique) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// Create a unique constraint using the unique index // Create a unique constraint using the unique index
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s ADD CONSTRAINT %s UNIQUE USING INDEX %s", _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s ADD CONSTRAINT %s UNIQUE USING INDEX %s",
pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Table),
@ -45,7 +45,7 @@ func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfor
return err return err
} }
func (o *OpSetUnique) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { func (o *OpSetUnique) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil return nil
} }
@ -66,7 +66,7 @@ func (o *OpSetUnique) Validate(ctx context.Context, s *schema.Schema) error {
return nil return nil
} }
func (o *OpSetUnique) addUniqueIndex(ctx context.Context, conn *sql.DB) error { func (o *OpSetUnique) addUniqueIndex(ctx context.Context, conn db.DB) error {
// create unique index concurrently // create unique index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)", _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)",
pq.QuoteIdentifier(o.Name), pq.QuoteIdentifier(o.Name),

View File

@ -4,11 +4,11 @@ package migrations
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"slices" "slices"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -16,7 +16,7 @@ import (
// * renames a duplicated column to its original name // * renames a duplicated column to its original name
// * renames any foreign keys on the duplicated column to their original name. // * renames any foreign keys on the duplicated column to their original name.
// * Validates and renames any temporary `CHECK` constraints on the duplicated column. // * Validates and renames any temporary `CHECK` constraints on the duplicated column.
func RenameDuplicatedColumn(ctx context.Context, conn *sql.DB, table *schema.Table, column *schema.Column) error { func RenameDuplicatedColumn(ctx context.Context, conn db.DB, table *schema.Table, column *schema.Column) error {
const ( const (
cRenameColumnSQL = `ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s` cRenameColumnSQL = `ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s`
cRenameConstraintSQL = `ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s` cRenameConstraintSQL = `ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s`

View File

@ -5,10 +5,10 @@ package migrations
import ( import (
"bytes" "bytes"
"context" "context"
"database/sql"
"text/template" "text/template"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/migrations/templates" "github.com/xataio/pgroll/pkg/migrations/templates"
"github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/schema"
) )
@ -33,7 +33,7 @@ type triggerConfig struct {
SQL string SQL string
} }
func createTrigger(ctx context.Context, conn *sql.DB, tr SQLTransformer, cfg triggerConfig) error { func createTrigger(ctx context.Context, conn db.DB, tr SQLTransformer, cfg triggerConfig) error {
sql, err := tr.TransformSQL(cfg.SQL) sql, err := tr.TransformSQL(cfg.SQL)
if err != nil { if err != nil {
return err return err

View File

@ -5,11 +5,10 @@ package roll_test
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/lib/pq"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/migrations"
@ -313,57 +312,48 @@ func TestSchemaOptionIsRespected(t *testing.T) {
}) })
} }
func TestLockTimeoutIsEnforced(t *testing.T) { func TestMigrationDDLIsRetriedOnLockTimeouts(t *testing.T) {
t.Parallel() t.Parallel()
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(100)}, func(mig *roll.Roll, db *sql.DB) { testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(50)}, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background() ctx := context.Background()
// Start a create table migration // Create a table
err := mig.Start(ctx, &migrations.Migration{ _, err := db.ExecContext(ctx, "CREATE TABLE table1 (id integer, name text)")
Name: "01_create_table", require.NoError(t, err)
Operations: migrations.Operations{createTableOp("table1")},
})
if err != nil {
t.Fatalf("Failed to start migration: %v", err)
}
// Complete the create table migration // Start a goroutine which takes an ACCESS_EXCLUSIVE lock on the table for
if err := mig.Complete(ctx); err != nil { // two seconds
t.Fatalf("Failed to complete migration: %v", err) errCh := make(chan error)
} go func() {
tx, err := db.Begin()
if err != nil {
errCh <- err
}
// Start a transaction and take an ACCESS_EXCLUSIVE lock on the table if _, err := tx.ExecContext(ctx, "LOCK TABLE table1 IN ACCESS EXCLUSIVE MODE"); err != nil {
// Don't commit the transaction so that the lock is held indefinitely errCh <- err
tx, err := db.Begin() }
if err != nil { errCh <- nil
t.Fatalf("Failed to start transaction: %v", err)
} // Sleep for two seconds to hold the lock
t.Cleanup(func() { time.Sleep(2 * time.Second)
// Commit the transaction
tx.Commit() tx.Commit()
}) }()
if _, err := tx.ExecContext(ctx, "LOCK TABLE table1 IN ACCESS EXCLUSIVE MODE"); err != nil {
t.Fatalf("Failed to take ACCESS_EXCLUSIVE lock on table: %v", err)
}
// Attempt to run a second migration on the table while the lock is held // Wait for lock to be taken
// The migration should fail due to a lock timeout error err = <-errCh
require.NoError(t, err)
// Attempt to start a second migration on the table while the lock is held.
// The migration should eventually succeed after the lock is released
err = mig.Start(ctx, &migrations.Migration{ err = mig.Start(ctx, &migrations.Migration{
Name: "02_create_table", Name: "01_add_column",
Operations: migrations.Operations{addColumnOp("table1")}, Operations: migrations.Operations{addColumnOp("table1")},
}) })
if err == nil { require.NoError(t, err)
t.Fatalf("Expected migration to fail due to lock timeout")
}
if err != nil {
pqErr := &pq.Error{}
if ok := errors.As(err, &pqErr); !ok {
t.Fatalf("Migration failed with unexpected error: %v", err)
}
if pqErr.Code != "55P03" { // Lock not available error code
t.Fatalf("Migration failed with unexpected error: %v", err)
}
}
}) })
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/state" "github.com/xataio/pgroll/pkg/state"
) )
@ -18,7 +19,7 @@ type PGVersion int
const PGVersion15 PGVersion = 15 const PGVersion15 PGVersion = 15
type Roll struct { type Roll struct {
pgConn *sql.DB // TODO abstract sql connection pgConn db.DB
// schema we are acting on // schema we are acting on
schema string schema string
@ -57,7 +58,7 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...
} }
return &Roll{ return &Roll{
pgConn: conn, pgConn: &db.RDB{DB: conn},
schema: schema, schema: schema,
state: state, state: state,
pgVersion: PGVersion(pgMajorVersion), pgVersion: PGVersion(pgMajorVersion),
@ -114,7 +115,7 @@ func (m *Roll) PGVersion() PGVersion {
return m.pgVersion return m.pgVersion
} }
func (m *Roll) PgConn() *sql.DB { func (m *Roll) PgConn() db.DB {
return m.pgConn return m.pgConn
} }

View File

@ -103,6 +103,14 @@ func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn f
fn(st, db) fn(st, db)
} }
func WithConnectionToContainer(t *testing.T, fn func(*sql.DB, string)) {
t.Helper()
db, connStr, _ := setupTestDatabase(t)
fn(db, connStr)
}
func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) {
WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn) WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn)
} }