Retry on lock_timeout errors (#353)

Retry statements and transactions that fail due to `lock_timeout`
errors.

DDL operations and backfills are run in a session in which `SET
lock_timout TO xms'` has been set (`x` defaults to `500` but can be
specified with the `--lock-timeout` parameter). This ensures that a long
running query can't cause other queries to queue up behind a DDL
operation as it waits to acquire its lock.

The current behaviour if a DDL operation or backfill batch times out
when requesting a lock is to fail, forcing the user to retry the
migration operation (start, rollback, or complete).

This PR retries individual statements (like the DDL operations run by
migration operations) and transactions (used by backfills) if they fail
due to a `lock_timeout` error. The retry uses an exponential backoff
with jitter.

Fixes #171
This commit is contained in:
Andrew Farries 2024-05-08 15:54:27 +01:00 committed by GitHub
parent 4f0a715613
commit 5c1aef2f24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 365 additions and 166 deletions

1
go.mod
View File

@ -3,6 +3,7 @@ module github.com/xataio/pgroll
go 1.21
require (
github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a
github.com/google/go-cmp v0.6.0
github.com/lib/pq v1.10.9
github.com/oapi-codegen/nullable v1.1.0

2
go.sum
View File

@ -73,6 +73,8 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a h1:8d1CEOF1xldesKds5tRG3tExBsMOgWYownMHNCsev54=
github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=

82
pkg/db/db.go Normal file
View File

@ -0,0 +1,82 @@
// SPDX-License-Identifier: Apache-2.0
package db
import (
"context"
"database/sql"
"errors"
"time"
"github.com/cloudflare/backoff"
"github.com/lib/pq"
)
const (
lockNotAvailableErrorCode pq.ErrorCode = "55P03"
maxBackoffDuration = 1 * time.Minute
backoffInterval = 1 * time.Second
)
type DB interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error
Close() error
}
// RDB wraps a *sql.DB and retries queries using an exponential backoff (with
// jitter) on lock_timeout errors.
type RDB struct {
DB *sql.DB
}
// ExecContext wraps sql.DB.ExecContext, retrying queries on lock_timeout errors.
func (db *RDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
b := backoff.New(maxBackoffDuration, backoffInterval)
for {
res, err := db.DB.ExecContext(ctx, query, args...)
if err == nil {
return res, nil
}
pqErr := &pq.Error{}
if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode {
<-time.After(b.Duration())
} else {
return nil, err
}
}
}
// WithRetryableTransaction runs `f` in a transaction, retrying on lock_timeout errors.
func (db *RDB) WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error {
b := backoff.New(maxBackoffDuration, backoffInterval)
for {
tx, err := db.DB.BeginTx(ctx, nil)
if err != nil {
return err
}
err = f(ctx, tx)
if err == nil {
return tx.Commit()
}
if errRollback := tx.Rollback(); errRollback != nil {
return errRollback
}
pqErr := &pq.Error{}
if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode {
<-time.After(b.Duration())
} else {
return err
}
}
}
func (db *RDB) Close() error {
return db.DB.Close()
}

121
pkg/db/db_test.go Normal file
View File

@ -0,0 +1,121 @@
// SPDX-License-Identifier: Apache-2.0
package db_test
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/testutils"
)
func TestMain(m *testing.M) {
testutils.SharedTestMain(m)
}
func TestExecContext(t *testing.T) {
t.Parallel()
testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
ctx := context.Background()
// create a table on which an exclusive lock is held for 2 seconds
setupTableLock(t, connStr, 2*time.Second)
// set the lock timeout to 100ms
ensureLockTimeout(t, conn, 100)
// execute a query that should retry until the lock is released
rdb := &db.RDB{DB: conn}
_, err := rdb.ExecContext(ctx, "INSERT INTO test(id) VALUES (1)")
require.NoError(t, err)
})
}
func TestWithRetryableTransaction(t *testing.T) {
t.Parallel()
testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
ctx := context.Background()
// create a table on which an exclusive lock is held for 2 seconds
setupTableLock(t, connStr, 2*time.Second)
// set the lock timeout to 100ms
ensureLockTimeout(t, conn, 100)
// run a transaction that should retry until the lock is released
rdb := &db.RDB{DB: conn}
err := rdb.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
return tx.QueryRowContext(ctx, "SELECT 1 FROM test").Err()
})
require.NoError(t, err)
})
}
// setupTableLock:
// * connects to the database
// * creates a table in the database
// * starts a transaction that temporarily locks the table
func setupTableLock(t *testing.T, connStr string, d time.Duration) {
t.Helper()
ctx := context.Background()
// connect to the database
conn2, err := sql.Open("postgres", connStr)
require.NoError(t, err)
// create a table in the database
_, err = conn2.ExecContext(ctx, "CREATE TABLE test (id INT PRIMARY KEY)")
require.NoError(t, err)
// start a transaction that takes a temporary lock on the table
errCh := make(chan error)
go func() {
// begin a transaction
tx, err := conn2.Begin()
if err != nil {
errCh <- err
return
}
// lock the table
_, err = tx.ExecContext(ctx, "LOCK TABLE test IN ACCESS EXCLUSIVE MODE")
if err != nil {
errCh <- err
return
}
// signal that the lock is obtained
errCh <- nil
// temporarily hold the lock
time.Sleep(d)
// commit the transaction
tx.Commit()
}()
// wait for the lock to be obtained
err = <-errCh
require.NoError(t, err)
}
func ensureLockTimeout(t *testing.T, conn *sql.DB, ms int) {
t.Helper()
// Set the lock timeout
query := fmt.Sprintf("SET lock_timeout = '%dms'", ms)
_, err := conn.ExecContext(context.Background(), query)
require.NoError(t, err)
// Ensure the lock timeout is set
var lockTimeout string
err = conn.QueryRowContext(context.Background(), "SHOW lock_timeout").Scan(&lockTimeout)
require.NoError(t, err)
require.Equal(t, fmt.Sprintf("%dms", ms), lockTimeout)
}

View File

@ -9,6 +9,7 @@ import (
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -18,7 +19,7 @@ import (
// 2. Get the first batch of rows from the table, ordered by the primary key.
// 3. Update each row in the batch, setting the value of the primary key column to itself.
// 4. Repeat steps 2 and 3 until no more rows are returned.
func Backfill(ctx context.Context, conn *sql.DB, table *schema.Table, cbs ...CallbackFn) error {
func Backfill(ctx context.Context, conn db.DB, table *schema.Table, cbs ...CallbackFn) error {
// get the backfill column
identityColumn := getIdentityColumn(table)
if identityColumn == nil {
@ -85,27 +86,20 @@ type batcher struct {
batchSize int
}
// updateBatch updates the next batch of rows in the table.
func (b *batcher) updateBatch(ctx context.Context, conn *sql.DB) error {
// Start the transaction for this batch
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error {
return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
// Build the query to update the next batch of rows
query := b.buildQuery()
// Build the query to update the next batch of rows
query := b.buildQuery()
// Execute the query to update the next batch of rows and update the last PK
// value for the next batch
err := tx.QueryRowContext(ctx, query).Scan(&b.lastValue)
if err != nil {
return err
}
// Execute the query to update the next batch of rows and update the last PK
// value for the next batch
err = tx.QueryRowContext(ctx, query).Scan(&b.lastValue)
if err != nil {
return err
}
// Commit the transaction for this batch
return tx.Commit()
return nil
})
}
// buildQuery builds the query used to update the next batch of rows.

View File

@ -4,13 +4,13 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
)
func addCommentToColumn(ctx context.Context, conn *sql.DB, tableName, columnName string, comment *string) error {
func addCommentToColumn(ctx context.Context, conn db.DB, tableName, columnName string, comment *string) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s IS %s`,
pq.QuoteIdentifier(tableName),
pq.QuoteIdentifier(columnName),
@ -19,7 +19,7 @@ func addCommentToColumn(ctx context.Context, conn *sql.DB, tableName, columnName
return err
}
func addCommentToTable(ctx context.Context, conn *sql.DB, tableName string, comment *string) error {
func addCommentToTable(ctx context.Context, conn db.DB, tableName string, comment *string) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON TABLE %s IS %s`,
pq.QuoteIdentifier(tableName),
commentToSQL(comment)))

View File

@ -4,17 +4,17 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"slices"
"strings"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
type Duplicator struct {
conn *sql.DB
conn db.DB
table *schema.Table
column *schema.Column
asName string
@ -24,7 +24,7 @@ type Duplicator struct {
}
// NewColumnDuplicator creates a new Duplicator for a column.
func NewColumnDuplicator(conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator {
func NewColumnDuplicator(conn db.DB, table *schema.Table, column *schema.Column) *Duplicator {
return &Duplicator{
conn: conn,
table: table,

View File

@ -4,10 +4,10 @@ package migrations
import (
"context"
"database/sql"
"fmt"
_ "github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -18,16 +18,16 @@ type Operation interface {
// version in the database (through a view)
// update the given views to expose the new schema version
// Returns the table that requires backfilling, if any.
Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error)
Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error)
// Complete will update the database schema to match the current version
// after calling Start.
// This method should be called once the previous version is no longer used
Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error
Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error
// Rollback will revert the changes made by Start. It is not possible to
// rollback a completed migration.
Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error
Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error
// Validate returns a descriptive error if the operation cannot be applied to the given schema
Validate(ctx context.Context, s *schema.Schema) error

View File

@ -4,18 +4,18 @@ package migrations
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpAddColumn)(nil)
func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
if err := addColumn(ctx, conn, *o, table, tr); err != nil {
@ -65,7 +65,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema strin
return tableToBackfill, nil
}
func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
tempName := TemporaryName(o.Column.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s",
@ -118,7 +118,7 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfor
return err
}
func (o *OpAddColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpAddColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
tempName := TemporaryName(o.Column.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s",
@ -182,7 +182,7 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
return nil
}
func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error {
func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error {
// don't add non-nullable columns with no default directly
// they are handled by:
// - adding the column as nullable
@ -216,7 +216,7 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table
return err
}
func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, physicalColumn string) error {
func addNotNullConstraint(ctx context.Context, conn db.DB, table, column, physicalColumn string) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s IS NOT NULL) NOT VALID",
pq.QuoteIdentifier(table),
pq.QuoteIdentifier(NotNullConstraintName(column)),
@ -225,7 +225,7 @@ func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, phys
return err
}
func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn *sql.DB) error {
func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn db.DB) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column.Check.Name),

View File

@ -4,16 +4,16 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpAlterColumn)(nil)
func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)
ops := o.subOperations()
@ -84,7 +84,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema str
return table, nil
}
func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
ops := o.subOperations()
// Perform any operation specific completion steps
@ -139,7 +139,7 @@ func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransf
return nil
}
func (o *OpAlterColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
ops := o.subOperations()
// Perform any operation specific rollback steps
@ -315,7 +315,7 @@ func (o *OpAlterColumn) subOperations() []Operation {
}
// duplicatorForOperations returns a Duplicator for the given operations
func duplicatorForOperations(ops []Operation, conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator {
func duplicatorForOperations(ops []Operation, conn db.DB, table *schema.Table, column *schema.Column) *Duplicator {
d := NewColumnDuplicator(conn, table, column)
for _, op := range ops {

View File

@ -4,8 +4,8 @@ package migrations
import (
"context"
"database/sql"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -19,17 +19,17 @@ type OpChangeType struct {
var _ Operation = (*OpChangeType)(nil)
func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpChangeType) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
return table, nil
}
func (o *OpChangeType) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpChangeType) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
return nil
}
func (o *OpChangeType) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpChangeType) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil
}

View File

@ -4,17 +4,17 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpCreateIndex)(nil)
func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpCreateIndex) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// create index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)",
pq.QuoteIdentifier(o.Name),
@ -23,12 +23,12 @@ func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema str
return nil, err
}
func (o *OpCreateIndex) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpCreateIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// No-op
return nil
}
func (o *OpCreateIndex) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpCreateIndex) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
// drop the index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name))

View File

@ -4,17 +4,17 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpCreateTable)(nil)
func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpCreateTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// Generate SQL for the columns in the table
columnsSQL, err := columnsToSQL(o.Columns, tr)
if err != nil {
@ -61,7 +61,7 @@ func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema str
return nil, nil
}
func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpCreateTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
tempName := TemporaryName(o.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s",
pq.QuoteIdentifier(tempName),
@ -69,7 +69,7 @@ func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransf
return err
}
func (o *OpCreateTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpCreateTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
tempName := TemporaryName(o.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s",

View File

@ -4,16 +4,16 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpDropColumn)(nil)
func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpDropColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
if o.Down != "" {
err := createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, o.Column),
@ -34,7 +34,7 @@ func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema stri
return nil, nil
}
func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpDropColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column)))
@ -48,7 +48,7 @@ func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfo
return err
}
func (o *OpDropColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpDropColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))

View File

@ -4,16 +4,16 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpDropConstraint)(nil)
func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpDropConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)
@ -62,7 +62,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema
return table, nil
}
func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpDropConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// Remove the up function and trigger
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
@ -95,7 +95,7 @@ func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTra
return err
}
func (o *OpDropConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpDropConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
// Drop the new column
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),

View File

@ -4,27 +4,27 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpDropIndex)(nil)
func (o *OpDropIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpDropIndex) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// no-op
return nil, nil
}
func (o *OpDropIndex) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpDropIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// drop the index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name))
return err
}
func (o *OpDropIndex) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpDropIndex) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
// no-op
return nil
}

View File

@ -4,8 +4,8 @@ package migrations
import (
"context"
"database/sql"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -18,17 +18,17 @@ type OpDropNotNull struct {
var _ Operation = (*OpDropNotNull)(nil)
func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpDropNotNull) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
return table, nil
}
func (o *OpDropNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpDropNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
return nil
}
func (o *OpDropNotNull) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpDropNotNull) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil
}

View File

@ -4,27 +4,27 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpDropTable)(nil)
func (o *OpDropTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpDropTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
s.RemoveTable(o.Name)
return nil, nil
}
func (o *OpDropTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpDropTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(o.Name)))
return err
}
func (o *OpDropTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpDropTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil
}

View File

@ -4,14 +4,14 @@ package migrations
import (
"context"
"database/sql"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpRawSQL)(nil)
func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpRawSQL) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
if o.OnComplete {
return nil, nil
}
@ -25,7 +25,7 @@ func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string,
return nil, err
}
func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpRawSQL) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
if !o.OnComplete {
return nil
}
@ -39,7 +39,7 @@ func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer
return err
}
func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpRawSQL) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
if o.Down == "" {
return nil
}

View File

@ -4,21 +4,21 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpRenameConstraint)(nil)
func (o *OpRenameConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpRenameConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// no-op
return nil, nil
}
func (o *OpRenameConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpRenameConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// rename the constraint in the underlying table
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME CONSTRAINT %s TO %s",
pq.QuoteIdentifier(o.Table),
@ -27,7 +27,7 @@ func (o *OpRenameConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLT
return err
}
func (o *OpRenameConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpRenameConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
// no-op
return nil
}

View File

@ -4,27 +4,27 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpRenameTable)(nil)
func (o *OpRenameTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpRenameTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
return nil, s.RenameTable(o.From, o.To)
}
func (o *OpRenameTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpRenameTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s",
pq.QuoteIdentifier(o.From),
pq.QuoteIdentifier(o.To)))
return err
}
func (o *OpRenameTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpRenameTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil
}

View File

@ -4,11 +4,11 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -22,7 +22,7 @@ type OpSetCheckConstraint struct {
var _ Operation = (*OpSetCheckConstraint)(nil)
func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpSetCheckConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
// Add the check constraint to the new column as NOT VALID.
@ -33,7 +33,7 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSch
return table, nil
}
func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// Validate the check constraint
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
pq.QuoteIdentifier(o.Table),
@ -45,7 +45,7 @@ func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQ
return nil
}
func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil
}
@ -69,7 +69,7 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e
return nil
}
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn *sql.DB) error {
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.DB) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Check.Name),

View File

@ -4,8 +4,8 @@ package migrations
import (
"context"
"database/sql"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -19,17 +19,17 @@ type OpSetComment struct {
var _ Operation = (*OpSetComment)(nil)
func (o *OpSetComment) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpSetComment) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
tbl := s.GetTable(o.Table)
return tbl, addCommentToColumn(ctx, conn, o.Table, TemporaryName(o.Column), o.Comment)
}
func (o *OpSetComment) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpSetComment) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
return nil
}
func (o *OpSetComment) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpSetComment) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil
}

View File

@ -4,10 +4,10 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -21,7 +21,7 @@ type OpSetDefault struct {
var _ Operation = (*OpSetDefault)(nil)
func (o *OpSetDefault) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpSetDefault) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
tbl := s.GetTable(o.Table)
_, err := conn.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`,
@ -35,11 +35,11 @@ func (o *OpSetDefault) Start(ctx context.Context, conn *sql.DB, stateSchema stri
return tbl, nil
}
func (o *OpSetDefault) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpSetDefault) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
return nil
}
func (o *OpSetDefault) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpSetDefault) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil
}

View File

@ -4,11 +4,11 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -22,7 +22,7 @@ type OpSetForeignKey struct {
var _ Operation = (*OpSetForeignKey)(nil)
func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpSetForeignKey) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
// Create a NOT VALID foreign key constraint on the new column.
@ -33,7 +33,7 @@ func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema s
return table, nil
}
func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpSetForeignKey) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// Validate the foreign key constraint
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
pq.QuoteIdentifier(o.Table),
@ -45,7 +45,7 @@ func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB, tr SQLTran
return nil
}
func (o *OpSetForeignKey) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpSetForeignKey) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil
}
@ -69,7 +69,7 @@ func (o *OpSetForeignKey) Validate(ctx context.Context, s *schema.Schema) error
return nil
}
func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn *sql.DB) error {
func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn db.DB) error {
tempColumnName := TemporaryName(o.Column)
onDelete := "NO ACTION"

View File

@ -4,10 +4,10 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -20,7 +20,7 @@ type OpSetNotNull struct {
var _ Operation = (*OpSetNotNull)(nil)
func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpSetNotNull) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
// Add an unchecked NOT NULL constraint to the new column.
@ -31,7 +31,7 @@ func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema stri
return table, nil
}
func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpSetNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// Validate the NOT NULL constraint on the old column.
// The constraint must be valid because:
// * Existing NULL values in the old column were rewritten using the `up` SQL during backfill.
@ -62,7 +62,7 @@ func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfo
return nil
}
func (o *OpSetNotNull) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpSetNotNull) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil
}

View File

@ -4,18 +4,18 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"slices"
"strings"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
var _ Operation = (*OpSetReplicaIdentity)(nil)
func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// build the correct form of the `SET REPLICA IDENTITY` statement based on the`identity type
identitySQL := strings.ToUpper(o.Identity.Type)
if identitySQL == "INDEX" {
@ -29,12 +29,12 @@ func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSch
return nil, err
}
func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// No-op
return nil
}
func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
// No-op
return nil
}

View File

@ -4,10 +4,10 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -21,7 +21,7 @@ type OpSetUnique struct {
var _ Operation = (*OpSetUnique)(nil)
func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpSetUnique) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
// Add a unique index to the new column
@ -32,7 +32,7 @@ func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema strin
return table, nil
}
func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpSetUnique) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
// Create a unique constraint using the unique index
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s ADD CONSTRAINT %s UNIQUE USING INDEX %s",
pq.QuoteIdentifier(o.Table),
@ -45,7 +45,7 @@ func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfor
return err
}
func (o *OpSetUnique) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
func (o *OpSetUnique) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error {
return nil
}
@ -66,7 +66,7 @@ func (o *OpSetUnique) Validate(ctx context.Context, s *schema.Schema) error {
return nil
}
func (o *OpSetUnique) addUniqueIndex(ctx context.Context, conn *sql.DB) error {
func (o *OpSetUnique) addUniqueIndex(ctx context.Context, conn db.DB) error {
// create unique index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)",
pq.QuoteIdentifier(o.Name),

View File

@ -4,11 +4,11 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"slices"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
@ -16,7 +16,7 @@ import (
// * renames a duplicated column to its original name
// * renames any foreign keys on the duplicated column to their original name.
// * Validates and renames any temporary `CHECK` constraints on the duplicated column.
func RenameDuplicatedColumn(ctx context.Context, conn *sql.DB, table *schema.Table, column *schema.Column) error {
func RenameDuplicatedColumn(ctx context.Context, conn db.DB, table *schema.Table, column *schema.Column) error {
const (
cRenameColumnSQL = `ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s`
cRenameConstraintSQL = `ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s`

View File

@ -5,10 +5,10 @@ package migrations
import (
"bytes"
"context"
"database/sql"
"text/template"
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/migrations/templates"
"github.com/xataio/pgroll/pkg/schema"
)
@ -33,7 +33,7 @@ type triggerConfig struct {
SQL string
}
func createTrigger(ctx context.Context, conn *sql.DB, tr SQLTransformer, cfg triggerConfig) error {
func createTrigger(ctx context.Context, conn db.DB, tr SQLTransformer, cfg triggerConfig) error {
sql, err := tr.TransformSQL(cfg.SQL)
if err != nil {
return err

View File

@ -5,11 +5,10 @@ package roll_test
import (
"context"
"database/sql"
"errors"
"fmt"
"testing"
"time"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xataio/pgroll/pkg/migrations"
@ -313,57 +312,48 @@ func TestSchemaOptionIsRespected(t *testing.T) {
})
}
func TestLockTimeoutIsEnforced(t *testing.T) {
func TestMigrationDDLIsRetriedOnLockTimeouts(t *testing.T) {
t.Parallel()
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(100)}, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(50)}, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
// Start a create table migration
err := mig.Start(ctx, &migrations.Migration{
Name: "01_create_table",
Operations: migrations.Operations{createTableOp("table1")},
})
if err != nil {
t.Fatalf("Failed to start migration: %v", err)
}
// Create a table
_, err := db.ExecContext(ctx, "CREATE TABLE table1 (id integer, name text)")
require.NoError(t, err)
// Complete the create table migration
if err := mig.Complete(ctx); err != nil {
t.Fatalf("Failed to complete migration: %v", err)
}
// Start a goroutine which takes an ACCESS_EXCLUSIVE lock on the table for
// two seconds
errCh := make(chan error)
go func() {
tx, err := db.Begin()
if err != nil {
errCh <- err
}
// Start a transaction and take an ACCESS_EXCLUSIVE lock on the table
// Don't commit the transaction so that the lock is held indefinitely
tx, err := db.Begin()
if err != nil {
t.Fatalf("Failed to start transaction: %v", err)
}
t.Cleanup(func() {
if _, err := tx.ExecContext(ctx, "LOCK TABLE table1 IN ACCESS EXCLUSIVE MODE"); err != nil {
errCh <- err
}
errCh <- nil
// Sleep for two seconds to hold the lock
time.Sleep(2 * time.Second)
// Commit the transaction
tx.Commit()
})
if _, err := tx.ExecContext(ctx, "LOCK TABLE table1 IN ACCESS EXCLUSIVE MODE"); err != nil {
t.Fatalf("Failed to take ACCESS_EXCLUSIVE lock on table: %v", err)
}
}()
// Attempt to run a second migration on the table while the lock is held
// The migration should fail due to a lock timeout error
// Wait for lock to be taken
err = <-errCh
require.NoError(t, err)
// Attempt to start a second migration on the table while the lock is held.
// The migration should eventually succeed after the lock is released
err = mig.Start(ctx, &migrations.Migration{
Name: "02_create_table",
Name: "01_add_column",
Operations: migrations.Operations{addColumnOp("table1")},
})
if err == nil {
t.Fatalf("Expected migration to fail due to lock timeout")
}
if err != nil {
pqErr := &pq.Error{}
if ok := errors.As(err, &pqErr); !ok {
t.Fatalf("Migration failed with unexpected error: %v", err)
}
if pqErr.Code != "55P03" { // Lock not available error code
t.Fatalf("Migration failed with unexpected error: %v", err)
}
}
require.NoError(t, err)
})
}

View File

@ -9,6 +9,7 @@ import (
"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/state"
)
@ -18,7 +19,7 @@ type PGVersion int
const PGVersion15 PGVersion = 15
type Roll struct {
pgConn *sql.DB // TODO abstract sql connection
pgConn db.DB
// schema we are acting on
schema string
@ -57,7 +58,7 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...
}
return &Roll{
pgConn: conn,
pgConn: &db.RDB{DB: conn},
schema: schema,
state: state,
pgVersion: PGVersion(pgMajorVersion),
@ -114,7 +115,7 @@ func (m *Roll) PGVersion() PGVersion {
return m.pgVersion
}
func (m *Roll) PgConn() *sql.DB {
func (m *Roll) PgConn() db.DB {
return m.pgConn
}

View File

@ -103,6 +103,14 @@ func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn f
fn(st, db)
}
func WithConnectionToContainer(t *testing.T, fn func(*sql.DB, string)) {
t.Helper()
db, connStr, _ := setupTestDatabase(t)
fn(db, connStr)
}
func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) {
WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn)
}