From 5c1aef2f24f3317b1406e75906d7ef58a1639a71 Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Wed, 8 May 2024 15:54:27 +0100 Subject: [PATCH] Retry on `lock_timeout` errors (#353) Retry statements and transactions that fail due to `lock_timeout` errors. DDL operations and backfills are run in a session in which `SET lock_timout TO xms'` has been set (`x` defaults to `500` but can be specified with the `--lock-timeout` parameter). This ensures that a long running query can't cause other queries to queue up behind a DDL operation as it waits to acquire its lock. The current behaviour if a DDL operation or backfill batch times out when requesting a lock is to fail, forcing the user to retry the migration operation (start, rollback, or complete). This PR retries individual statements (like the DDL operations run by migration operations) and transactions (used by backfills) if they fail due to a `lock_timeout` error. The retry uses an exponential backoff with jitter. Fixes #171 --- go.mod | 1 + go.sum | 2 + pkg/db/db.go | 82 +++++++++++++++ pkg/db/db_test.go | 121 ++++++++++++++++++++++ pkg/migrations/backfill.go | 34 +++--- pkg/migrations/comment.go | 6 +- pkg/migrations/duplicate.go | 6 +- pkg/migrations/migrations.go | 8 +- pkg/migrations/op_add_column.go | 14 +-- pkg/migrations/op_alter_column.go | 10 +- pkg/migrations/op_change_type.go | 8 +- pkg/migrations/op_create_index.go | 8 +- pkg/migrations/op_create_table.go | 8 +- pkg/migrations/op_drop_column.go | 8 +- pkg/migrations/op_drop_constraint.go | 8 +- pkg/migrations/op_drop_index.go | 8 +- pkg/migrations/op_drop_not_null.go | 8 +- pkg/migrations/op_drop_table.go | 8 +- pkg/migrations/op_raw_sql.go | 8 +- pkg/migrations/op_rename_constraint.go | 8 +- pkg/migrations/op_rename_table.go | 8 +- pkg/migrations/op_set_check.go | 10 +- pkg/migrations/op_set_comment.go | 8 +- pkg/migrations/op_set_default.go | 8 +- pkg/migrations/op_set_fk.go | 10 +- pkg/migrations/op_set_notnull.go | 8 +- pkg/migrations/op_set_replica_identity.go | 8 +- pkg/migrations/op_set_unique.go | 10 +- pkg/migrations/rename.go | 4 +- pkg/migrations/trigger.go | 4 +- pkg/roll/execute_test.go | 74 ++++++------- pkg/roll/roll.go | 7 +- pkg/testutils/util.go | 8 ++ 33 files changed, 365 insertions(+), 166 deletions(-) create mode 100644 pkg/db/db.go create mode 100644 pkg/db/db_test.go diff --git a/go.mod b/go.mod index b7d6578..e1787f8 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/xataio/pgroll go 1.21 require ( + github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a github.com/google/go-cmp v0.6.0 github.com/lib/pq v1.10.9 github.com/oapi-codegen/nullable v1.1.0 diff --git a/go.sum b/go.sum index 0811811..d402871 100644 --- a/go.sum +++ b/go.sum @@ -73,6 +73,8 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a h1:8d1CEOF1xldesKds5tRG3tExBsMOgWYownMHNCsev54= +github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= diff --git a/pkg/db/db.go b/pkg/db/db.go new file mode 100644 index 0000000..dc5be90 --- /dev/null +++ b/pkg/db/db.go @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "database/sql" + "errors" + "time" + + "github.com/cloudflare/backoff" + "github.com/lib/pq" +) + +const ( + lockNotAvailableErrorCode pq.ErrorCode = "55P03" + maxBackoffDuration = 1 * time.Minute + backoffInterval = 1 * time.Second +) + +type DB interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error + Close() error +} + +// RDB wraps a *sql.DB and retries queries using an exponential backoff (with +// jitter) on lock_timeout errors. +type RDB struct { + DB *sql.DB +} + +// ExecContext wraps sql.DB.ExecContext, retrying queries on lock_timeout errors. +func (db *RDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + b := backoff.New(maxBackoffDuration, backoffInterval) + + for { + res, err := db.DB.ExecContext(ctx, query, args...) + if err == nil { + return res, nil + } + + pqErr := &pq.Error{} + if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode { + <-time.After(b.Duration()) + } else { + return nil, err + } + } +} + +// WithRetryableTransaction runs `f` in a transaction, retrying on lock_timeout errors. +func (db *RDB) WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error { + b := backoff.New(maxBackoffDuration, backoffInterval) + + for { + tx, err := db.DB.BeginTx(ctx, nil) + if err != nil { + return err + } + + err = f(ctx, tx) + if err == nil { + return tx.Commit() + } + + if errRollback := tx.Rollback(); errRollback != nil { + return errRollback + } + + pqErr := &pq.Error{} + if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode { + <-time.After(b.Duration()) + } else { + return err + } + } +} + +func (db *RDB) Close() error { + return db.DB.Close() +} diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go new file mode 100644 index 0000000..bfc37b2 --- /dev/null +++ b/pkg/db/db_test.go @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 + +package db_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/xataio/pgroll/pkg/db" + "github.com/xataio/pgroll/pkg/testutils" +) + +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} + +func TestExecContext(t *testing.T) { + t.Parallel() + + testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) { + ctx := context.Background() + // create a table on which an exclusive lock is held for 2 seconds + setupTableLock(t, connStr, 2*time.Second) + + // set the lock timeout to 100ms + ensureLockTimeout(t, conn, 100) + + // execute a query that should retry until the lock is released + rdb := &db.RDB{DB: conn} + _, err := rdb.ExecContext(ctx, "INSERT INTO test(id) VALUES (1)") + require.NoError(t, err) + }) +} + +func TestWithRetryableTransaction(t *testing.T) { + t.Parallel() + + testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) { + ctx := context.Background() + + // create a table on which an exclusive lock is held for 2 seconds + setupTableLock(t, connStr, 2*time.Second) + + // set the lock timeout to 100ms + ensureLockTimeout(t, conn, 100) + + // run a transaction that should retry until the lock is released + rdb := &db.RDB{DB: conn} + err := rdb.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + return tx.QueryRowContext(ctx, "SELECT 1 FROM test").Err() + }) + require.NoError(t, err) + }) +} + +// setupTableLock: +// * connects to the database +// * creates a table in the database +// * starts a transaction that temporarily locks the table +func setupTableLock(t *testing.T, connStr string, d time.Duration) { + t.Helper() + ctx := context.Background() + + // connect to the database + conn2, err := sql.Open("postgres", connStr) + require.NoError(t, err) + + // create a table in the database + _, err = conn2.ExecContext(ctx, "CREATE TABLE test (id INT PRIMARY KEY)") + require.NoError(t, err) + + // start a transaction that takes a temporary lock on the table + errCh := make(chan error) + go func() { + // begin a transaction + tx, err := conn2.Begin() + if err != nil { + errCh <- err + return + } + + // lock the table + _, err = tx.ExecContext(ctx, "LOCK TABLE test IN ACCESS EXCLUSIVE MODE") + if err != nil { + errCh <- err + return + } + + // signal that the lock is obtained + errCh <- nil + + // temporarily hold the lock + time.Sleep(d) + + // commit the transaction + tx.Commit() + }() + + // wait for the lock to be obtained + err = <-errCh + require.NoError(t, err) +} + +func ensureLockTimeout(t *testing.T, conn *sql.DB, ms int) { + t.Helper() + + // Set the lock timeout + query := fmt.Sprintf("SET lock_timeout = '%dms'", ms) + _, err := conn.ExecContext(context.Background(), query) + require.NoError(t, err) + + // Ensure the lock timeout is set + var lockTimeout string + err = conn.QueryRowContext(context.Background(), "SHOW lock_timeout").Scan(&lockTimeout) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%dms", ms), lockTimeout) +} diff --git a/pkg/migrations/backfill.go b/pkg/migrations/backfill.go index 35ef729..56fcae9 100644 --- a/pkg/migrations/backfill.go +++ b/pkg/migrations/backfill.go @@ -9,6 +9,7 @@ import ( "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -18,7 +19,7 @@ import ( // 2. Get the first batch of rows from the table, ordered by the primary key. // 3. Update each row in the batch, setting the value of the primary key column to itself. // 4. Repeat steps 2 and 3 until no more rows are returned. -func Backfill(ctx context.Context, conn *sql.DB, table *schema.Table, cbs ...CallbackFn) error { +func Backfill(ctx context.Context, conn db.DB, table *schema.Table, cbs ...CallbackFn) error { // get the backfill column identityColumn := getIdentityColumn(table) if identityColumn == nil { @@ -85,27 +86,20 @@ type batcher struct { batchSize int } -// updateBatch updates the next batch of rows in the table. -func (b *batcher) updateBatch(ctx context.Context, conn *sql.DB) error { - // Start the transaction for this batch - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() +func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error { + return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + // Build the query to update the next batch of rows + query := b.buildQuery() - // Build the query to update the next batch of rows - query := b.buildQuery() + // Execute the query to update the next batch of rows and update the last PK + // value for the next batch + err := tx.QueryRowContext(ctx, query).Scan(&b.lastValue) + if err != nil { + return err + } - // Execute the query to update the next batch of rows and update the last PK - // value for the next batch - err = tx.QueryRowContext(ctx, query).Scan(&b.lastValue) - if err != nil { - return err - } - - // Commit the transaction for this batch - return tx.Commit() + return nil + }) } // buildQuery builds the query used to update the next batch of rows. diff --git a/pkg/migrations/comment.go b/pkg/migrations/comment.go index b3f232d..f1a5028 100644 --- a/pkg/migrations/comment.go +++ b/pkg/migrations/comment.go @@ -4,13 +4,13 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" ) -func addCommentToColumn(ctx context.Context, conn *sql.DB, tableName, columnName string, comment *string) error { +func addCommentToColumn(ctx context.Context, conn db.DB, tableName, columnName string, comment *string) error { _, err := conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s IS %s`, pq.QuoteIdentifier(tableName), pq.QuoteIdentifier(columnName), @@ -19,7 +19,7 @@ func addCommentToColumn(ctx context.Context, conn *sql.DB, tableName, columnName return err } -func addCommentToTable(ctx context.Context, conn *sql.DB, tableName string, comment *string) error { +func addCommentToTable(ctx context.Context, conn db.DB, tableName string, comment *string) error { _, err := conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON TABLE %s IS %s`, pq.QuoteIdentifier(tableName), commentToSQL(comment))) diff --git a/pkg/migrations/duplicate.go b/pkg/migrations/duplicate.go index 7bcd58d..19e386f 100644 --- a/pkg/migrations/duplicate.go +++ b/pkg/migrations/duplicate.go @@ -4,17 +4,17 @@ package migrations import ( "context" - "database/sql" "fmt" "slices" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) type Duplicator struct { - conn *sql.DB + conn db.DB table *schema.Table column *schema.Column asName string @@ -24,7 +24,7 @@ type Duplicator struct { } // NewColumnDuplicator creates a new Duplicator for a column. -func NewColumnDuplicator(conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator { +func NewColumnDuplicator(conn db.DB, table *schema.Table, column *schema.Column) *Duplicator { return &Duplicator{ conn: conn, table: table, diff --git a/pkg/migrations/migrations.go b/pkg/migrations/migrations.go index 76bc055..3e26c46 100644 --- a/pkg/migrations/migrations.go +++ b/pkg/migrations/migrations.go @@ -4,10 +4,10 @@ package migrations import ( "context" - "database/sql" "fmt" _ "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -18,16 +18,16 @@ type Operation interface { // version in the database (through a view) // update the given views to expose the new schema version // Returns the table that requires backfilling, if any. - Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) + Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) // Complete will update the database schema to match the current version // after calling Start. // This method should be called once the previous version is no longer used - Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error + Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error // Rollback will revert the changes made by Start. It is not possible to // rollback a completed migration. - Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error + Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error // Validate returns a descriptive error if the operation cannot be applied to the given schema Validate(ctx context.Context, s *schema.Schema) error diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index aa251e2..88d1f67 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -4,18 +4,18 @@ package migrations import ( "context" - "database/sql" "errors" "fmt" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpAddColumn)(nil) -func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) if err := addColumn(ctx, conn, *o, table, tr); err != nil { @@ -65,7 +65,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema strin return tableToBackfill, nil } -func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { tempName := TemporaryName(o.Column.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s", @@ -118,7 +118,7 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfor return err } -func (o *OpAddColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpAddColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { tempName := TemporaryName(o.Column.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s", @@ -182,7 +182,7 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error { return nil } -func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error { +func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error { // don't add non-nullable columns with no default directly // they are handled by: // - adding the column as nullable @@ -216,7 +216,7 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table return err } -func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, physicalColumn string) error { +func addNotNullConstraint(ctx context.Context, conn db.DB, table, column, physicalColumn string) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s IS NOT NULL) NOT VALID", pq.QuoteIdentifier(table), pq.QuoteIdentifier(NotNullConstraintName(column)), @@ -225,7 +225,7 @@ func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, phys return err } -func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn *sql.DB) error { +func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn db.DB) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Column.Check.Name), diff --git a/pkg/migrations/op_alter_column.go b/pkg/migrations/op_alter_column.go index 8653fad..95ae8c3 100644 --- a/pkg/migrations/op_alter_column.go +++ b/pkg/migrations/op_alter_column.go @@ -4,16 +4,16 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpAlterColumn)(nil) -func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) ops := o.subOperations() @@ -84,7 +84,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema str return table, nil } -func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { ops := o.subOperations() // Perform any operation specific completion steps @@ -139,7 +139,7 @@ func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransf return nil } -func (o *OpAlterColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { ops := o.subOperations() // Perform any operation specific rollback steps @@ -315,7 +315,7 @@ func (o *OpAlterColumn) subOperations() []Operation { } // duplicatorForOperations returns a Duplicator for the given operations -func duplicatorForOperations(ops []Operation, conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator { +func duplicatorForOperations(ops []Operation, conn db.DB, table *schema.Table, column *schema.Column) *Duplicator { d := NewColumnDuplicator(conn, table, column) for _, op := range ops { diff --git a/pkg/migrations/op_change_type.go b/pkg/migrations/op_change_type.go index ea2c903..d9bf0a3 100644 --- a/pkg/migrations/op_change_type.go +++ b/pkg/migrations/op_change_type.go @@ -4,8 +4,8 @@ package migrations import ( "context" - "database/sql" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -19,17 +19,17 @@ type OpChangeType struct { var _ Operation = (*OpChangeType)(nil) -func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpChangeType) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) return table, nil } -func (o *OpChangeType) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpChangeType) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { return nil } -func (o *OpChangeType) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpChangeType) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_create_index.go b/pkg/migrations/op_create_index.go index d09b726..d2aff86 100644 --- a/pkg/migrations/op_create_index.go +++ b/pkg/migrations/op_create_index.go @@ -4,17 +4,17 @@ package migrations import ( "context" - "database/sql" "fmt" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpCreateIndex)(nil) -func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpCreateIndex) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // create index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)", pq.QuoteIdentifier(o.Name), @@ -23,12 +23,12 @@ func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema str return nil, err } -func (o *OpCreateIndex) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpCreateIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // No-op return nil } -func (o *OpCreateIndex) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpCreateIndex) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { // drop the index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name)) diff --git a/pkg/migrations/op_create_table.go b/pkg/migrations/op_create_table.go index adc4a37..a829c72 100644 --- a/pkg/migrations/op_create_table.go +++ b/pkg/migrations/op_create_table.go @@ -4,17 +4,17 @@ package migrations import ( "context" - "database/sql" "fmt" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpCreateTable)(nil) -func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpCreateTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // Generate SQL for the columns in the table columnsSQL, err := columnsToSQL(o.Columns, tr) if err != nil { @@ -61,7 +61,7 @@ func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema str return nil, nil } -func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpCreateTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { tempName := TemporaryName(o.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s", pq.QuoteIdentifier(tempName), @@ -69,7 +69,7 @@ func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransf return err } -func (o *OpCreateTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpCreateTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { tempName := TemporaryName(o.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", diff --git a/pkg/migrations/op_drop_column.go b/pkg/migrations/op_drop_column.go index fdb1d2d..cd9494a 100644 --- a/pkg/migrations/op_drop_column.go +++ b/pkg/migrations/op_drop_column.go @@ -4,16 +4,16 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpDropColumn)(nil) -func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpDropColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { if o.Down != "" { err := createTrigger(ctx, conn, tr, triggerConfig{ Name: TriggerName(o.Table, o.Column), @@ -34,7 +34,7 @@ func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema stri return nil, nil } -func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Column))) @@ -48,7 +48,7 @@ func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfo return err } -func (o *OpDropColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpDropColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)))) diff --git a/pkg/migrations/op_drop_constraint.go b/pkg/migrations/op_drop_constraint.go index 907d1a9..64784cc 100644 --- a/pkg/migrations/op_drop_constraint.go +++ b/pkg/migrations/op_drop_constraint.go @@ -4,16 +4,16 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpDropConstraint)(nil) -func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpDropConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) @@ -62,7 +62,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema return table, nil } -func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // Remove the up function and trigger _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)))) @@ -95,7 +95,7 @@ func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTra return err } -func (o *OpDropConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpDropConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { // Drop the new column _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", pq.QuoteIdentifier(o.Table), diff --git a/pkg/migrations/op_drop_index.go b/pkg/migrations/op_drop_index.go index 884d6d6..dab42ed 100644 --- a/pkg/migrations/op_drop_index.go +++ b/pkg/migrations/op_drop_index.go @@ -4,27 +4,27 @@ package migrations import ( "context" - "database/sql" "fmt" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpDropIndex)(nil) -func (o *OpDropIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpDropIndex) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // no-op return nil, nil } -func (o *OpDropIndex) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // drop the index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name)) return err } -func (o *OpDropIndex) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpDropIndex) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { // no-op return nil } diff --git a/pkg/migrations/op_drop_not_null.go b/pkg/migrations/op_drop_not_null.go index 3d3b2b3..acf7530 100644 --- a/pkg/migrations/op_drop_not_null.go +++ b/pkg/migrations/op_drop_not_null.go @@ -4,8 +4,8 @@ package migrations import ( "context" - "database/sql" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -18,17 +18,17 @@ type OpDropNotNull struct { var _ Operation = (*OpDropNotNull)(nil) -func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpDropNotNull) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) return table, nil } -func (o *OpDropNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { return nil } -func (o *OpDropNotNull) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpDropNotNull) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_drop_table.go b/pkg/migrations/op_drop_table.go index b25e4e2..ce745e9 100644 --- a/pkg/migrations/op_drop_table.go +++ b/pkg/migrations/op_drop_table.go @@ -4,27 +4,27 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpDropTable)(nil) -func (o *OpDropTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpDropTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { s.RemoveTable(o.Name) return nil, nil } -func (o *OpDropTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(o.Name))) return err } -func (o *OpDropTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpDropTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_raw_sql.go b/pkg/migrations/op_raw_sql.go index 7f3f9e2..910a439 100644 --- a/pkg/migrations/op_raw_sql.go +++ b/pkg/migrations/op_raw_sql.go @@ -4,14 +4,14 @@ package migrations import ( "context" - "database/sql" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpRawSQL)(nil) -func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpRawSQL) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { if o.OnComplete { return nil, nil } @@ -25,7 +25,7 @@ func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, return nil, err } -func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRawSQL) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { if !o.OnComplete { return nil } @@ -39,7 +39,7 @@ func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer return err } -func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpRawSQL) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { if o.Down == "" { return nil } diff --git a/pkg/migrations/op_rename_constraint.go b/pkg/migrations/op_rename_constraint.go index 875693a..2ff3347 100644 --- a/pkg/migrations/op_rename_constraint.go +++ b/pkg/migrations/op_rename_constraint.go @@ -4,21 +4,21 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpRenameConstraint)(nil) -func (o *OpRenameConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpRenameConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // no-op return nil, nil } -func (o *OpRenameConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRenameConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // rename the constraint in the underlying table _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME CONSTRAINT %s TO %s", pq.QuoteIdentifier(o.Table), @@ -27,7 +27,7 @@ func (o *OpRenameConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLT return err } -func (o *OpRenameConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpRenameConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { // no-op return nil } diff --git a/pkg/migrations/op_rename_table.go b/pkg/migrations/op_rename_table.go index eb85e1e..98b42b6 100644 --- a/pkg/migrations/op_rename_table.go +++ b/pkg/migrations/op_rename_table.go @@ -4,27 +4,27 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpRenameTable)(nil) -func (o *OpRenameTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpRenameTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { return nil, s.RenameTable(o.From, o.To) } -func (o *OpRenameTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRenameTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s", pq.QuoteIdentifier(o.From), pq.QuoteIdentifier(o.To))) return err } -func (o *OpRenameTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpRenameTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_set_check.go b/pkg/migrations/op_set_check.go index 36910fd..5eaa7de 100644 --- a/pkg/migrations/op_set_check.go +++ b/pkg/migrations/op_set_check.go @@ -4,11 +4,11 @@ package migrations import ( "context" - "database/sql" "fmt" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -22,7 +22,7 @@ type OpSetCheckConstraint struct { var _ Operation = (*OpSetCheckConstraint)(nil) -func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetCheckConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) // Add the check constraint to the new column as NOT VALID. @@ -33,7 +33,7 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSch return table, nil } -func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // Validate the check constraint _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s", pq.QuoteIdentifier(o.Table), @@ -45,7 +45,7 @@ func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQ return nil } -func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } @@ -69,7 +69,7 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e return nil } -func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn *sql.DB) error { +func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.DB) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Check.Name), diff --git a/pkg/migrations/op_set_comment.go b/pkg/migrations/op_set_comment.go index fc8d13a..ec3804b 100644 --- a/pkg/migrations/op_set_comment.go +++ b/pkg/migrations/op_set_comment.go @@ -4,8 +4,8 @@ package migrations import ( "context" - "database/sql" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -19,17 +19,17 @@ type OpSetComment struct { var _ Operation = (*OpSetComment)(nil) -func (o *OpSetComment) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetComment) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { tbl := s.GetTable(o.Table) return tbl, addCommentToColumn(ctx, conn, o.Table, TemporaryName(o.Column), o.Comment) } -func (o *OpSetComment) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetComment) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { return nil } -func (o *OpSetComment) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetComment) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_set_default.go b/pkg/migrations/op_set_default.go index 6da2a04..4fbbcda 100644 --- a/pkg/migrations/op_set_default.go +++ b/pkg/migrations/op_set_default.go @@ -4,10 +4,10 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -21,7 +21,7 @@ type OpSetDefault struct { var _ Operation = (*OpSetDefault)(nil) -func (o *OpSetDefault) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetDefault) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { tbl := s.GetTable(o.Table) _, err := conn.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`, @@ -35,11 +35,11 @@ func (o *OpSetDefault) Start(ctx context.Context, conn *sql.DB, stateSchema stri return tbl, nil } -func (o *OpSetDefault) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetDefault) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { return nil } -func (o *OpSetDefault) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetDefault) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_set_fk.go b/pkg/migrations/op_set_fk.go index 6cbe711..90541e3 100644 --- a/pkg/migrations/op_set_fk.go +++ b/pkg/migrations/op_set_fk.go @@ -4,11 +4,11 @@ package migrations import ( "context" - "database/sql" "fmt" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -22,7 +22,7 @@ type OpSetForeignKey struct { var _ Operation = (*OpSetForeignKey)(nil) -func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetForeignKey) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) // Create a NOT VALID foreign key constraint on the new column. @@ -33,7 +33,7 @@ func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema s return table, nil } -func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetForeignKey) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // Validate the foreign key constraint _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s", pq.QuoteIdentifier(o.Table), @@ -45,7 +45,7 @@ func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB, tr SQLTran return nil } -func (o *OpSetForeignKey) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetForeignKey) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } @@ -69,7 +69,7 @@ func (o *OpSetForeignKey) Validate(ctx context.Context, s *schema.Schema) error return nil } -func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn *sql.DB) error { +func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn db.DB) error { tempColumnName := TemporaryName(o.Column) onDelete := "NO ACTION" diff --git a/pkg/migrations/op_set_notnull.go b/pkg/migrations/op_set_notnull.go index 066b56a..c6d6395 100644 --- a/pkg/migrations/op_set_notnull.go +++ b/pkg/migrations/op_set_notnull.go @@ -4,10 +4,10 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -20,7 +20,7 @@ type OpSetNotNull struct { var _ Operation = (*OpSetNotNull)(nil) -func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetNotNull) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) // Add an unchecked NOT NULL constraint to the new column. @@ -31,7 +31,7 @@ func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema stri return table, nil } -func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // Validate the NOT NULL constraint on the old column. // The constraint must be valid because: // * Existing NULL values in the old column were rewritten using the `up` SQL during backfill. @@ -62,7 +62,7 @@ func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfo return nil } -func (o *OpSetNotNull) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetNotNull) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_set_replica_identity.go b/pkg/migrations/op_set_replica_identity.go index 75c196b..282a554 100644 --- a/pkg/migrations/op_set_replica_identity.go +++ b/pkg/migrations/op_set_replica_identity.go @@ -4,18 +4,18 @@ package migrations import ( "context" - "database/sql" "fmt" "slices" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpSetReplicaIdentity)(nil) -func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // build the correct form of the `SET REPLICA IDENTITY` statement based on the`identity type identitySQL := strings.ToUpper(o.Identity.Type) if identitySQL == "INDEX" { @@ -29,12 +29,12 @@ func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSch return nil, err } -func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // No-op return nil } -func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { // No-op return nil } diff --git a/pkg/migrations/op_set_unique.go b/pkg/migrations/op_set_unique.go index a19614e..5e89e16 100644 --- a/pkg/migrations/op_set_unique.go +++ b/pkg/migrations/op_set_unique.go @@ -4,10 +4,10 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -21,7 +21,7 @@ type OpSetUnique struct { var _ Operation = (*OpSetUnique)(nil) -func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetUnique) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) // Add a unique index to the new column @@ -32,7 +32,7 @@ func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema strin return table, nil } -func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetUnique) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // Create a unique constraint using the unique index _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s ADD CONSTRAINT %s UNIQUE USING INDEX %s", pq.QuoteIdentifier(o.Table), @@ -45,7 +45,7 @@ func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfor return err } -func (o *OpSetUnique) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetUnique) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } @@ -66,7 +66,7 @@ func (o *OpSetUnique) Validate(ctx context.Context, s *schema.Schema) error { return nil } -func (o *OpSetUnique) addUniqueIndex(ctx context.Context, conn *sql.DB) error { +func (o *OpSetUnique) addUniqueIndex(ctx context.Context, conn db.DB) error { // create unique index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)", pq.QuoteIdentifier(o.Name), diff --git a/pkg/migrations/rename.go b/pkg/migrations/rename.go index d75d6ca..b8b888b 100644 --- a/pkg/migrations/rename.go +++ b/pkg/migrations/rename.go @@ -4,11 +4,11 @@ package migrations import ( "context" - "database/sql" "fmt" "slices" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -16,7 +16,7 @@ import ( // * renames a duplicated column to its original name // * renames any foreign keys on the duplicated column to their original name. // * Validates and renames any temporary `CHECK` constraints on the duplicated column. -func RenameDuplicatedColumn(ctx context.Context, conn *sql.DB, table *schema.Table, column *schema.Column) error { +func RenameDuplicatedColumn(ctx context.Context, conn db.DB, table *schema.Table, column *schema.Column) error { const ( cRenameColumnSQL = `ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s` cRenameConstraintSQL = `ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s` diff --git a/pkg/migrations/trigger.go b/pkg/migrations/trigger.go index eb932d3..e065f71 100644 --- a/pkg/migrations/trigger.go +++ b/pkg/migrations/trigger.go @@ -5,10 +5,10 @@ package migrations import ( "bytes" "context" - "database/sql" "text/template" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/migrations/templates" "github.com/xataio/pgroll/pkg/schema" ) @@ -33,7 +33,7 @@ type triggerConfig struct { SQL string } -func createTrigger(ctx context.Context, conn *sql.DB, tr SQLTransformer, cfg triggerConfig) error { +func createTrigger(ctx context.Context, conn db.DB, tr SQLTransformer, cfg triggerConfig) error { sql, err := tr.TransformSQL(cfg.SQL) if err != nil { return err diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index d43b7f2..db419e9 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -5,11 +5,10 @@ package roll_test import ( "context" "database/sql" - "errors" "fmt" "testing" + "time" - "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xataio/pgroll/pkg/migrations" @@ -313,57 +312,48 @@ func TestSchemaOptionIsRespected(t *testing.T) { }) } -func TestLockTimeoutIsEnforced(t *testing.T) { +func TestMigrationDDLIsRetriedOnLockTimeouts(t *testing.T) { t.Parallel() - testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(100)}, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(50)}, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() - // Start a create table migration - err := mig.Start(ctx, &migrations.Migration{ - Name: "01_create_table", - Operations: migrations.Operations{createTableOp("table1")}, - }) - if err != nil { - t.Fatalf("Failed to start migration: %v", err) - } + // Create a table + _, err := db.ExecContext(ctx, "CREATE TABLE table1 (id integer, name text)") + require.NoError(t, err) - // Complete the create table migration - if err := mig.Complete(ctx); err != nil { - t.Fatalf("Failed to complete migration: %v", err) - } + // Start a goroutine which takes an ACCESS_EXCLUSIVE lock on the table for + // two seconds + errCh := make(chan error) + go func() { + tx, err := db.Begin() + if err != nil { + errCh <- err + } - // Start a transaction and take an ACCESS_EXCLUSIVE lock on the table - // Don't commit the transaction so that the lock is held indefinitely - tx, err := db.Begin() - if err != nil { - t.Fatalf("Failed to start transaction: %v", err) - } - t.Cleanup(func() { + if _, err := tx.ExecContext(ctx, "LOCK TABLE table1 IN ACCESS EXCLUSIVE MODE"); err != nil { + errCh <- err + } + errCh <- nil + + // Sleep for two seconds to hold the lock + time.Sleep(2 * time.Second) + + // Commit the transaction tx.Commit() - }) - if _, err := tx.ExecContext(ctx, "LOCK TABLE table1 IN ACCESS EXCLUSIVE MODE"); err != nil { - t.Fatalf("Failed to take ACCESS_EXCLUSIVE lock on table: %v", err) - } + }() - // Attempt to run a second migration on the table while the lock is held - // The migration should fail due to a lock timeout error + // Wait for lock to be taken + err = <-errCh + require.NoError(t, err) + + // Attempt to start a second migration on the table while the lock is held. + // The migration should eventually succeed after the lock is released err = mig.Start(ctx, &migrations.Migration{ - Name: "02_create_table", + Name: "01_add_column", Operations: migrations.Operations{addColumnOp("table1")}, }) - if err == nil { - t.Fatalf("Expected migration to fail due to lock timeout") - } - if err != nil { - pqErr := &pq.Error{} - if ok := errors.As(err, &pqErr); !ok { - t.Fatalf("Migration failed with unexpected error: %v", err) - } - if pqErr.Code != "55P03" { // Lock not available error code - t.Fatalf("Migration failed with unexpected error: %v", err) - } - } + require.NoError(t, err) }) } diff --git a/pkg/roll/roll.go b/pkg/roll/roll.go index 46df978..d2e37e7 100644 --- a/pkg/roll/roll.go +++ b/pkg/roll/roll.go @@ -9,6 +9,7 @@ import ( "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/state" ) @@ -18,7 +19,7 @@ type PGVersion int const PGVersion15 PGVersion = 15 type Roll struct { - pgConn *sql.DB // TODO abstract sql connection + pgConn db.DB // schema we are acting on schema string @@ -57,7 +58,7 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ... } return &Roll{ - pgConn: conn, + pgConn: &db.RDB{DB: conn}, schema: schema, state: state, pgVersion: PGVersion(pgMajorVersion), @@ -114,7 +115,7 @@ func (m *Roll) PGVersion() PGVersion { return m.pgVersion } -func (m *Roll) PgConn() *sql.DB { +func (m *Roll) PgConn() db.DB { return m.pgConn } diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go index c815340..7cb1b54 100644 --- a/pkg/testutils/util.go +++ b/pkg/testutils/util.go @@ -103,6 +103,14 @@ func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn f fn(st, db) } +func WithConnectionToContainer(t *testing.T, fn func(*sql.DB, string)) { + t.Helper() + + db, connStr, _ := setupTestDatabase(t) + + fn(db, connStr) +} + func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn) }