diff --git a/go.mod b/go.mod index b7d6578..e1787f8 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/xataio/pgroll go 1.21 require ( + github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a github.com/google/go-cmp v0.6.0 github.com/lib/pq v1.10.9 github.com/oapi-codegen/nullable v1.1.0 diff --git a/go.sum b/go.sum index 0811811..d402871 100644 --- a/go.sum +++ b/go.sum @@ -73,6 +73,8 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a h1:8d1CEOF1xldesKds5tRG3tExBsMOgWYownMHNCsev54= +github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= diff --git a/pkg/db/db.go b/pkg/db/db.go new file mode 100644 index 0000000..dc5be90 --- /dev/null +++ b/pkg/db/db.go @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "database/sql" + "errors" + "time" + + "github.com/cloudflare/backoff" + "github.com/lib/pq" +) + +const ( + lockNotAvailableErrorCode pq.ErrorCode = "55P03" + maxBackoffDuration = 1 * time.Minute + backoffInterval = 1 * time.Second +) + +type DB interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error + Close() error +} + +// RDB wraps a *sql.DB and retries queries using an exponential backoff (with +// jitter) on lock_timeout errors. +type RDB struct { + DB *sql.DB +} + +// ExecContext wraps sql.DB.ExecContext, retrying queries on lock_timeout errors. +func (db *RDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + b := backoff.New(maxBackoffDuration, backoffInterval) + + for { + res, err := db.DB.ExecContext(ctx, query, args...) + if err == nil { + return res, nil + } + + pqErr := &pq.Error{} + if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode { + <-time.After(b.Duration()) + } else { + return nil, err + } + } +} + +// WithRetryableTransaction runs `f` in a transaction, retrying on lock_timeout errors. +func (db *RDB) WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error { + b := backoff.New(maxBackoffDuration, backoffInterval) + + for { + tx, err := db.DB.BeginTx(ctx, nil) + if err != nil { + return err + } + + err = f(ctx, tx) + if err == nil { + return tx.Commit() + } + + if errRollback := tx.Rollback(); errRollback != nil { + return errRollback + } + + pqErr := &pq.Error{} + if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode { + <-time.After(b.Duration()) + } else { + return err + } + } +} + +func (db *RDB) Close() error { + return db.DB.Close() +} diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go new file mode 100644 index 0000000..bfc37b2 --- /dev/null +++ b/pkg/db/db_test.go @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 + +package db_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/xataio/pgroll/pkg/db" + "github.com/xataio/pgroll/pkg/testutils" +) + +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} + +func TestExecContext(t *testing.T) { + t.Parallel() + + testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) { + ctx := context.Background() + // create a table on which an exclusive lock is held for 2 seconds + setupTableLock(t, connStr, 2*time.Second) + + // set the lock timeout to 100ms + ensureLockTimeout(t, conn, 100) + + // execute a query that should retry until the lock is released + rdb := &db.RDB{DB: conn} + _, err := rdb.ExecContext(ctx, "INSERT INTO test(id) VALUES (1)") + require.NoError(t, err) + }) +} + +func TestWithRetryableTransaction(t *testing.T) { + t.Parallel() + + testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) { + ctx := context.Background() + + // create a table on which an exclusive lock is held for 2 seconds + setupTableLock(t, connStr, 2*time.Second) + + // set the lock timeout to 100ms + ensureLockTimeout(t, conn, 100) + + // run a transaction that should retry until the lock is released + rdb := &db.RDB{DB: conn} + err := rdb.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + return tx.QueryRowContext(ctx, "SELECT 1 FROM test").Err() + }) + require.NoError(t, err) + }) +} + +// setupTableLock: +// * connects to the database +// * creates a table in the database +// * starts a transaction that temporarily locks the table +func setupTableLock(t *testing.T, connStr string, d time.Duration) { + t.Helper() + ctx := context.Background() + + // connect to the database + conn2, err := sql.Open("postgres", connStr) + require.NoError(t, err) + + // create a table in the database + _, err = conn2.ExecContext(ctx, "CREATE TABLE test (id INT PRIMARY KEY)") + require.NoError(t, err) + + // start a transaction that takes a temporary lock on the table + errCh := make(chan error) + go func() { + // begin a transaction + tx, err := conn2.Begin() + if err != nil { + errCh <- err + return + } + + // lock the table + _, err = tx.ExecContext(ctx, "LOCK TABLE test IN ACCESS EXCLUSIVE MODE") + if err != nil { + errCh <- err + return + } + + // signal that the lock is obtained + errCh <- nil + + // temporarily hold the lock + time.Sleep(d) + + // commit the transaction + tx.Commit() + }() + + // wait for the lock to be obtained + err = <-errCh + require.NoError(t, err) +} + +func ensureLockTimeout(t *testing.T, conn *sql.DB, ms int) { + t.Helper() + + // Set the lock timeout + query := fmt.Sprintf("SET lock_timeout = '%dms'", ms) + _, err := conn.ExecContext(context.Background(), query) + require.NoError(t, err) + + // Ensure the lock timeout is set + var lockTimeout string + err = conn.QueryRowContext(context.Background(), "SHOW lock_timeout").Scan(&lockTimeout) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%dms", ms), lockTimeout) +} diff --git a/pkg/migrations/backfill.go b/pkg/migrations/backfill.go index 35ef729..56fcae9 100644 --- a/pkg/migrations/backfill.go +++ b/pkg/migrations/backfill.go @@ -9,6 +9,7 @@ import ( "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -18,7 +19,7 @@ import ( // 2. Get the first batch of rows from the table, ordered by the primary key. // 3. Update each row in the batch, setting the value of the primary key column to itself. // 4. Repeat steps 2 and 3 until no more rows are returned. -func Backfill(ctx context.Context, conn *sql.DB, table *schema.Table, cbs ...CallbackFn) error { +func Backfill(ctx context.Context, conn db.DB, table *schema.Table, cbs ...CallbackFn) error { // get the backfill column identityColumn := getIdentityColumn(table) if identityColumn == nil { @@ -85,27 +86,20 @@ type batcher struct { batchSize int } -// updateBatch updates the next batch of rows in the table. -func (b *batcher) updateBatch(ctx context.Context, conn *sql.DB) error { - // Start the transaction for this batch - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() +func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error { + return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + // Build the query to update the next batch of rows + query := b.buildQuery() - // Build the query to update the next batch of rows - query := b.buildQuery() + // Execute the query to update the next batch of rows and update the last PK + // value for the next batch + err := tx.QueryRowContext(ctx, query).Scan(&b.lastValue) + if err != nil { + return err + } - // Execute the query to update the next batch of rows and update the last PK - // value for the next batch - err = tx.QueryRowContext(ctx, query).Scan(&b.lastValue) - if err != nil { - return err - } - - // Commit the transaction for this batch - return tx.Commit() + return nil + }) } // buildQuery builds the query used to update the next batch of rows. diff --git a/pkg/migrations/comment.go b/pkg/migrations/comment.go index b3f232d..f1a5028 100644 --- a/pkg/migrations/comment.go +++ b/pkg/migrations/comment.go @@ -4,13 +4,13 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" ) -func addCommentToColumn(ctx context.Context, conn *sql.DB, tableName, columnName string, comment *string) error { +func addCommentToColumn(ctx context.Context, conn db.DB, tableName, columnName string, comment *string) error { _, err := conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s IS %s`, pq.QuoteIdentifier(tableName), pq.QuoteIdentifier(columnName), @@ -19,7 +19,7 @@ func addCommentToColumn(ctx context.Context, conn *sql.DB, tableName, columnName return err } -func addCommentToTable(ctx context.Context, conn *sql.DB, tableName string, comment *string) error { +func addCommentToTable(ctx context.Context, conn db.DB, tableName string, comment *string) error { _, err := conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON TABLE %s IS %s`, pq.QuoteIdentifier(tableName), commentToSQL(comment))) diff --git a/pkg/migrations/duplicate.go b/pkg/migrations/duplicate.go index 7bcd58d..19e386f 100644 --- a/pkg/migrations/duplicate.go +++ b/pkg/migrations/duplicate.go @@ -4,17 +4,17 @@ package migrations import ( "context" - "database/sql" "fmt" "slices" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) type Duplicator struct { - conn *sql.DB + conn db.DB table *schema.Table column *schema.Column asName string @@ -24,7 +24,7 @@ type Duplicator struct { } // NewColumnDuplicator creates a new Duplicator for a column. -func NewColumnDuplicator(conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator { +func NewColumnDuplicator(conn db.DB, table *schema.Table, column *schema.Column) *Duplicator { return &Duplicator{ conn: conn, table: table, diff --git a/pkg/migrations/migrations.go b/pkg/migrations/migrations.go index 76bc055..3e26c46 100644 --- a/pkg/migrations/migrations.go +++ b/pkg/migrations/migrations.go @@ -4,10 +4,10 @@ package migrations import ( "context" - "database/sql" "fmt" _ "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -18,16 +18,16 @@ type Operation interface { // version in the database (through a view) // update the given views to expose the new schema version // Returns the table that requires backfilling, if any. - Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) + Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) // Complete will update the database schema to match the current version // after calling Start. // This method should be called once the previous version is no longer used - Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error + Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error // Rollback will revert the changes made by Start. It is not possible to // rollback a completed migration. - Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error + Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error // Validate returns a descriptive error if the operation cannot be applied to the given schema Validate(ctx context.Context, s *schema.Schema) error diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index aa251e2..88d1f67 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -4,18 +4,18 @@ package migrations import ( "context" - "database/sql" "errors" "fmt" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpAddColumn)(nil) -func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) if err := addColumn(ctx, conn, *o, table, tr); err != nil { @@ -65,7 +65,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema strin return tableToBackfill, nil } -func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { tempName := TemporaryName(o.Column.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s", @@ -118,7 +118,7 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfor return err } -func (o *OpAddColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpAddColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { tempName := TemporaryName(o.Column.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s", @@ -182,7 +182,7 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error { return nil } -func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error { +func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error { // don't add non-nullable columns with no default directly // they are handled by: // - adding the column as nullable @@ -216,7 +216,7 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table return err } -func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, physicalColumn string) error { +func addNotNullConstraint(ctx context.Context, conn db.DB, table, column, physicalColumn string) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s IS NOT NULL) NOT VALID", pq.QuoteIdentifier(table), pq.QuoteIdentifier(NotNullConstraintName(column)), @@ -225,7 +225,7 @@ func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, phys return err } -func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn *sql.DB) error { +func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn db.DB) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Column.Check.Name), diff --git a/pkg/migrations/op_alter_column.go b/pkg/migrations/op_alter_column.go index 8653fad..95ae8c3 100644 --- a/pkg/migrations/op_alter_column.go +++ b/pkg/migrations/op_alter_column.go @@ -4,16 +4,16 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpAlterColumn)(nil) -func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) ops := o.subOperations() @@ -84,7 +84,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema str return table, nil } -func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { ops := o.subOperations() // Perform any operation specific completion steps @@ -139,7 +139,7 @@ func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransf return nil } -func (o *OpAlterColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { ops := o.subOperations() // Perform any operation specific rollback steps @@ -315,7 +315,7 @@ func (o *OpAlterColumn) subOperations() []Operation { } // duplicatorForOperations returns a Duplicator for the given operations -func duplicatorForOperations(ops []Operation, conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator { +func duplicatorForOperations(ops []Operation, conn db.DB, table *schema.Table, column *schema.Column) *Duplicator { d := NewColumnDuplicator(conn, table, column) for _, op := range ops { diff --git a/pkg/migrations/op_change_type.go b/pkg/migrations/op_change_type.go index ea2c903..d9bf0a3 100644 --- a/pkg/migrations/op_change_type.go +++ b/pkg/migrations/op_change_type.go @@ -4,8 +4,8 @@ package migrations import ( "context" - "database/sql" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -19,17 +19,17 @@ type OpChangeType struct { var _ Operation = (*OpChangeType)(nil) -func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpChangeType) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) return table, nil } -func (o *OpChangeType) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpChangeType) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { return nil } -func (o *OpChangeType) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpChangeType) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_create_index.go b/pkg/migrations/op_create_index.go index d09b726..d2aff86 100644 --- a/pkg/migrations/op_create_index.go +++ b/pkg/migrations/op_create_index.go @@ -4,17 +4,17 @@ package migrations import ( "context" - "database/sql" "fmt" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpCreateIndex)(nil) -func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpCreateIndex) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // create index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)", pq.QuoteIdentifier(o.Name), @@ -23,12 +23,12 @@ func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema str return nil, err } -func (o *OpCreateIndex) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpCreateIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // No-op return nil } -func (o *OpCreateIndex) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpCreateIndex) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { // drop the index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name)) diff --git a/pkg/migrations/op_create_table.go b/pkg/migrations/op_create_table.go index adc4a37..a829c72 100644 --- a/pkg/migrations/op_create_table.go +++ b/pkg/migrations/op_create_table.go @@ -4,17 +4,17 @@ package migrations import ( "context" - "database/sql" "fmt" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpCreateTable)(nil) -func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpCreateTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // Generate SQL for the columns in the table columnsSQL, err := columnsToSQL(o.Columns, tr) if err != nil { @@ -61,7 +61,7 @@ func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema str return nil, nil } -func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpCreateTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { tempName := TemporaryName(o.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s", pq.QuoteIdentifier(tempName), @@ -69,7 +69,7 @@ func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransf return err } -func (o *OpCreateTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpCreateTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { tempName := TemporaryName(o.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", diff --git a/pkg/migrations/op_drop_column.go b/pkg/migrations/op_drop_column.go index fdb1d2d..cd9494a 100644 --- a/pkg/migrations/op_drop_column.go +++ b/pkg/migrations/op_drop_column.go @@ -4,16 +4,16 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpDropColumn)(nil) -func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpDropColumn) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { if o.Down != "" { err := createTrigger(ctx, conn, tr, triggerConfig{ Name: TriggerName(o.Table, o.Column), @@ -34,7 +34,7 @@ func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema stri return nil, nil } -func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Column))) @@ -48,7 +48,7 @@ func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfo return err } -func (o *OpDropColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpDropColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)))) diff --git a/pkg/migrations/op_drop_constraint.go b/pkg/migrations/op_drop_constraint.go index 907d1a9..64784cc 100644 --- a/pkg/migrations/op_drop_constraint.go +++ b/pkg/migrations/op_drop_constraint.go @@ -4,16 +4,16 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpDropConstraint)(nil) -func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpDropConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) @@ -62,7 +62,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema return table, nil } -func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // Remove the up function and trigger _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)))) @@ -95,7 +95,7 @@ func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTra return err } -func (o *OpDropConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpDropConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { // Drop the new column _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", pq.QuoteIdentifier(o.Table), diff --git a/pkg/migrations/op_drop_index.go b/pkg/migrations/op_drop_index.go index 884d6d6..dab42ed 100644 --- a/pkg/migrations/op_drop_index.go +++ b/pkg/migrations/op_drop_index.go @@ -4,27 +4,27 @@ package migrations import ( "context" - "database/sql" "fmt" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpDropIndex)(nil) -func (o *OpDropIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpDropIndex) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // no-op return nil, nil } -func (o *OpDropIndex) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // drop the index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name)) return err } -func (o *OpDropIndex) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpDropIndex) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { // no-op return nil } diff --git a/pkg/migrations/op_drop_not_null.go b/pkg/migrations/op_drop_not_null.go index 3d3b2b3..acf7530 100644 --- a/pkg/migrations/op_drop_not_null.go +++ b/pkg/migrations/op_drop_not_null.go @@ -4,8 +4,8 @@ package migrations import ( "context" - "database/sql" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -18,17 +18,17 @@ type OpDropNotNull struct { var _ Operation = (*OpDropNotNull)(nil) -func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpDropNotNull) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) return table, nil } -func (o *OpDropNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { return nil } -func (o *OpDropNotNull) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpDropNotNull) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_drop_table.go b/pkg/migrations/op_drop_table.go index b25e4e2..ce745e9 100644 --- a/pkg/migrations/op_drop_table.go +++ b/pkg/migrations/op_drop_table.go @@ -4,27 +4,27 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpDropTable)(nil) -func (o *OpDropTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpDropTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { s.RemoveTable(o.Name) return nil, nil } -func (o *OpDropTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(o.Name))) return err } -func (o *OpDropTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpDropTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_raw_sql.go b/pkg/migrations/op_raw_sql.go index 7f3f9e2..910a439 100644 --- a/pkg/migrations/op_raw_sql.go +++ b/pkg/migrations/op_raw_sql.go @@ -4,14 +4,14 @@ package migrations import ( "context" - "database/sql" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpRawSQL)(nil) -func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpRawSQL) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { if o.OnComplete { return nil, nil } @@ -25,7 +25,7 @@ func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, return nil, err } -func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRawSQL) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { if !o.OnComplete { return nil } @@ -39,7 +39,7 @@ func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer return err } -func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpRawSQL) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { if o.Down == "" { return nil } diff --git a/pkg/migrations/op_rename_constraint.go b/pkg/migrations/op_rename_constraint.go index 875693a..2ff3347 100644 --- a/pkg/migrations/op_rename_constraint.go +++ b/pkg/migrations/op_rename_constraint.go @@ -4,21 +4,21 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpRenameConstraint)(nil) -func (o *OpRenameConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpRenameConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // no-op return nil, nil } -func (o *OpRenameConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRenameConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // rename the constraint in the underlying table _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME CONSTRAINT %s TO %s", pq.QuoteIdentifier(o.Table), @@ -27,7 +27,7 @@ func (o *OpRenameConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLT return err } -func (o *OpRenameConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpRenameConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { // no-op return nil } diff --git a/pkg/migrations/op_rename_table.go b/pkg/migrations/op_rename_table.go index eb85e1e..98b42b6 100644 --- a/pkg/migrations/op_rename_table.go +++ b/pkg/migrations/op_rename_table.go @@ -4,27 +4,27 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpRenameTable)(nil) -func (o *OpRenameTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpRenameTable) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { return nil, s.RenameTable(o.From, o.To) } -func (o *OpRenameTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRenameTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s", pq.QuoteIdentifier(o.From), pq.QuoteIdentifier(o.To))) return err } -func (o *OpRenameTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpRenameTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_set_check.go b/pkg/migrations/op_set_check.go index 36910fd..5eaa7de 100644 --- a/pkg/migrations/op_set_check.go +++ b/pkg/migrations/op_set_check.go @@ -4,11 +4,11 @@ package migrations import ( "context" - "database/sql" "fmt" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -22,7 +22,7 @@ type OpSetCheckConstraint struct { var _ Operation = (*OpSetCheckConstraint)(nil) -func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetCheckConstraint) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) // Add the check constraint to the new column as NOT VALID. @@ -33,7 +33,7 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSch return table, nil } -func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // Validate the check constraint _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s", pq.QuoteIdentifier(o.Table), @@ -45,7 +45,7 @@ func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQ return nil } -func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } @@ -69,7 +69,7 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e return nil } -func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn *sql.DB) error { +func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.DB) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Check.Name), diff --git a/pkg/migrations/op_set_comment.go b/pkg/migrations/op_set_comment.go index fc8d13a..ec3804b 100644 --- a/pkg/migrations/op_set_comment.go +++ b/pkg/migrations/op_set_comment.go @@ -4,8 +4,8 @@ package migrations import ( "context" - "database/sql" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -19,17 +19,17 @@ type OpSetComment struct { var _ Operation = (*OpSetComment)(nil) -func (o *OpSetComment) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetComment) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { tbl := s.GetTable(o.Table) return tbl, addCommentToColumn(ctx, conn, o.Table, TemporaryName(o.Column), o.Comment) } -func (o *OpSetComment) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetComment) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { return nil } -func (o *OpSetComment) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetComment) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_set_default.go b/pkg/migrations/op_set_default.go index 6da2a04..4fbbcda 100644 --- a/pkg/migrations/op_set_default.go +++ b/pkg/migrations/op_set_default.go @@ -4,10 +4,10 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -21,7 +21,7 @@ type OpSetDefault struct { var _ Operation = (*OpSetDefault)(nil) -func (o *OpSetDefault) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetDefault) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { tbl := s.GetTable(o.Table) _, err := conn.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`, @@ -35,11 +35,11 @@ func (o *OpSetDefault) Start(ctx context.Context, conn *sql.DB, stateSchema stri return tbl, nil } -func (o *OpSetDefault) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetDefault) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { return nil } -func (o *OpSetDefault) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetDefault) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_set_fk.go b/pkg/migrations/op_set_fk.go index 6cbe711..90541e3 100644 --- a/pkg/migrations/op_set_fk.go +++ b/pkg/migrations/op_set_fk.go @@ -4,11 +4,11 @@ package migrations import ( "context" - "database/sql" "fmt" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -22,7 +22,7 @@ type OpSetForeignKey struct { var _ Operation = (*OpSetForeignKey)(nil) -func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetForeignKey) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) // Create a NOT VALID foreign key constraint on the new column. @@ -33,7 +33,7 @@ func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema s return table, nil } -func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetForeignKey) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // Validate the foreign key constraint _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s", pq.QuoteIdentifier(o.Table), @@ -45,7 +45,7 @@ func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB, tr SQLTran return nil } -func (o *OpSetForeignKey) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetForeignKey) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } @@ -69,7 +69,7 @@ func (o *OpSetForeignKey) Validate(ctx context.Context, s *schema.Schema) error return nil } -func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn *sql.DB) error { +func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn db.DB) error { tempColumnName := TemporaryName(o.Column) onDelete := "NO ACTION" diff --git a/pkg/migrations/op_set_notnull.go b/pkg/migrations/op_set_notnull.go index 066b56a..c6d6395 100644 --- a/pkg/migrations/op_set_notnull.go +++ b/pkg/migrations/op_set_notnull.go @@ -4,10 +4,10 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -20,7 +20,7 @@ type OpSetNotNull struct { var _ Operation = (*OpSetNotNull)(nil) -func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetNotNull) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) // Add an unchecked NOT NULL constraint to the new column. @@ -31,7 +31,7 @@ func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema stri return table, nil } -func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // Validate the NOT NULL constraint on the old column. // The constraint must be valid because: // * Existing NULL values in the old column were rewritten using the `up` SQL during backfill. @@ -62,7 +62,7 @@ func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfo return nil } -func (o *OpSetNotNull) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetNotNull) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } diff --git a/pkg/migrations/op_set_replica_identity.go b/pkg/migrations/op_set_replica_identity.go index 75c196b..282a554 100644 --- a/pkg/migrations/op_set_replica_identity.go +++ b/pkg/migrations/op_set_replica_identity.go @@ -4,18 +4,18 @@ package migrations import ( "context" - "database/sql" "fmt" "slices" "strings" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) var _ Operation = (*OpSetReplicaIdentity)(nil) -func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // build the correct form of the `SET REPLICA IDENTITY` statement based on the`identity type identitySQL := strings.ToUpper(o.Identity.Type) if identitySQL == "INDEX" { @@ -29,12 +29,12 @@ func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSch return nil, err } -func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // No-op return nil } -func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { // No-op return nil } diff --git a/pkg/migrations/op_set_unique.go b/pkg/migrations/op_set_unique.go index a19614e..5e89e16 100644 --- a/pkg/migrations/op_set_unique.go +++ b/pkg/migrations/op_set_unique.go @@ -4,10 +4,10 @@ package migrations import ( "context" - "database/sql" "fmt" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -21,7 +21,7 @@ type OpSetUnique struct { var _ Operation = (*OpSetUnique)(nil) -func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { +func (o *OpSetUnique) Start(ctx context.Context, conn db.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) // Add a unique index to the new column @@ -32,7 +32,7 @@ func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema strin return table, nil } -func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetUnique) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { // Create a unique constraint using the unique index _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s ADD CONSTRAINT %s UNIQUE USING INDEX %s", pq.QuoteIdentifier(o.Table), @@ -45,7 +45,7 @@ func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB, tr SQLTransfor return err } -func (o *OpSetUnique) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error { +func (o *OpSetUnique) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer) error { return nil } @@ -66,7 +66,7 @@ func (o *OpSetUnique) Validate(ctx context.Context, s *schema.Schema) error { return nil } -func (o *OpSetUnique) addUniqueIndex(ctx context.Context, conn *sql.DB) error { +func (o *OpSetUnique) addUniqueIndex(ctx context.Context, conn db.DB) error { // create unique index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)", pq.QuoteIdentifier(o.Name), diff --git a/pkg/migrations/rename.go b/pkg/migrations/rename.go index d75d6ca..b8b888b 100644 --- a/pkg/migrations/rename.go +++ b/pkg/migrations/rename.go @@ -4,11 +4,11 @@ package migrations import ( "context" - "database/sql" "fmt" "slices" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -16,7 +16,7 @@ import ( // * renames a duplicated column to its original name // * renames any foreign keys on the duplicated column to their original name. // * Validates and renames any temporary `CHECK` constraints on the duplicated column. -func RenameDuplicatedColumn(ctx context.Context, conn *sql.DB, table *schema.Table, column *schema.Column) error { +func RenameDuplicatedColumn(ctx context.Context, conn db.DB, table *schema.Table, column *schema.Column) error { const ( cRenameColumnSQL = `ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s` cRenameConstraintSQL = `ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s` diff --git a/pkg/migrations/trigger.go b/pkg/migrations/trigger.go index eb932d3..e065f71 100644 --- a/pkg/migrations/trigger.go +++ b/pkg/migrations/trigger.go @@ -5,10 +5,10 @@ package migrations import ( "bytes" "context" - "database/sql" "text/template" "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/migrations/templates" "github.com/xataio/pgroll/pkg/schema" ) @@ -33,7 +33,7 @@ type triggerConfig struct { SQL string } -func createTrigger(ctx context.Context, conn *sql.DB, tr SQLTransformer, cfg triggerConfig) error { +func createTrigger(ctx context.Context, conn db.DB, tr SQLTransformer, cfg triggerConfig) error { sql, err := tr.TransformSQL(cfg.SQL) if err != nil { return err diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index d43b7f2..db419e9 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -5,11 +5,10 @@ package roll_test import ( "context" "database/sql" - "errors" "fmt" "testing" + "time" - "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xataio/pgroll/pkg/migrations" @@ -313,57 +312,48 @@ func TestSchemaOptionIsRespected(t *testing.T) { }) } -func TestLockTimeoutIsEnforced(t *testing.T) { +func TestMigrationDDLIsRetriedOnLockTimeouts(t *testing.T) { t.Parallel() - testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(100)}, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(50)}, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() - // Start a create table migration - err := mig.Start(ctx, &migrations.Migration{ - Name: "01_create_table", - Operations: migrations.Operations{createTableOp("table1")}, - }) - if err != nil { - t.Fatalf("Failed to start migration: %v", err) - } + // Create a table + _, err := db.ExecContext(ctx, "CREATE TABLE table1 (id integer, name text)") + require.NoError(t, err) - // Complete the create table migration - if err := mig.Complete(ctx); err != nil { - t.Fatalf("Failed to complete migration: %v", err) - } + // Start a goroutine which takes an ACCESS_EXCLUSIVE lock on the table for + // two seconds + errCh := make(chan error) + go func() { + tx, err := db.Begin() + if err != nil { + errCh <- err + } - // Start a transaction and take an ACCESS_EXCLUSIVE lock on the table - // Don't commit the transaction so that the lock is held indefinitely - tx, err := db.Begin() - if err != nil { - t.Fatalf("Failed to start transaction: %v", err) - } - t.Cleanup(func() { + if _, err := tx.ExecContext(ctx, "LOCK TABLE table1 IN ACCESS EXCLUSIVE MODE"); err != nil { + errCh <- err + } + errCh <- nil + + // Sleep for two seconds to hold the lock + time.Sleep(2 * time.Second) + + // Commit the transaction tx.Commit() - }) - if _, err := tx.ExecContext(ctx, "LOCK TABLE table1 IN ACCESS EXCLUSIVE MODE"); err != nil { - t.Fatalf("Failed to take ACCESS_EXCLUSIVE lock on table: %v", err) - } + }() - // Attempt to run a second migration on the table while the lock is held - // The migration should fail due to a lock timeout error + // Wait for lock to be taken + err = <-errCh + require.NoError(t, err) + + // Attempt to start a second migration on the table while the lock is held. + // The migration should eventually succeed after the lock is released err = mig.Start(ctx, &migrations.Migration{ - Name: "02_create_table", + Name: "01_add_column", Operations: migrations.Operations{addColumnOp("table1")}, }) - if err == nil { - t.Fatalf("Expected migration to fail due to lock timeout") - } - if err != nil { - pqErr := &pq.Error{} - if ok := errors.As(err, &pqErr); !ok { - t.Fatalf("Migration failed with unexpected error: %v", err) - } - if pqErr.Code != "55P03" { // Lock not available error code - t.Fatalf("Migration failed with unexpected error: %v", err) - } - } + require.NoError(t, err) }) } diff --git a/pkg/roll/roll.go b/pkg/roll/roll.go index 46df978..d2e37e7 100644 --- a/pkg/roll/roll.go +++ b/pkg/roll/roll.go @@ -9,6 +9,7 @@ import ( "github.com/lib/pq" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/state" ) @@ -18,7 +19,7 @@ type PGVersion int const PGVersion15 PGVersion = 15 type Roll struct { - pgConn *sql.DB // TODO abstract sql connection + pgConn db.DB // schema we are acting on schema string @@ -57,7 +58,7 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ... } return &Roll{ - pgConn: conn, + pgConn: &db.RDB{DB: conn}, schema: schema, state: state, pgVersion: PGVersion(pgMajorVersion), @@ -114,7 +115,7 @@ func (m *Roll) PGVersion() PGVersion { return m.pgVersion } -func (m *Roll) PgConn() *sql.DB { +func (m *Roll) PgConn() db.DB { return m.pgConn } diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go index c815340..7cb1b54 100644 --- a/pkg/testutils/util.go +++ b/pkg/testutils/util.go @@ -103,6 +103,14 @@ func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn f fn(st, db) } +func WithConnectionToContainer(t *testing.T, fn func(*sql.DB, string)) { + t.Helper() + + db, connStr, _ := setupTestDatabase(t) + + fn(db, connStr) +} + func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn) }