Make state initialization concurrency safe (#285)

Make `pgroll` state initialization concurrency safe by using Postgres
advisory locking to ensure at most one connection can initialize at at
time.

See docs on Postgres advisory locking:
*
https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS
*
https://www.postgresql.org/docs/current/explicit-locking.html#ADVISORY-LOCKS

Closes https://github.com/xataio/pgroll/issues/283
This commit is contained in:
Andrew Farries 2024-02-26 09:04:54 +00:00 committed by GitHub
parent 161fde60ca
commit b4e3044adf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 109 additions and 77 deletions

View File

@ -354,11 +354,28 @@ func New(ctx context.Context, pgURL, stateSchema string) (*State, error) {
}
func (s *State) Init(ctx context.Context) error {
// ensure pgroll internal tables exist
// TODO: eventually use migrations for this instead of hardcoding
_, err := s.pgConn.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema)))
tx, err := s.pgConn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
return err
// Try to obtain an advisory lock.
// The key is an arbitrary number, used to distinguish the lock from other locks.
// The lock is automatically released when the transaction is committed or rolled back.
const key int64 = 0x2c03057fb9525b
_, err = tx.ExecContext(ctx, "SELECT pg_advisory_xact_lock($1)", key)
if err != nil {
return err
}
// Perform pgroll state initialization
_, err = tx.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema)))
if err != nil {
return err
}
return tx.Commit()
}
func (s *State) Close() error {

View File

@ -8,6 +8,7 @@ import (
"encoding/json"
"fmt"
"strings"
"sync"
"testing"
"github.com/google/go-cmp/cmp"
@ -132,6 +133,29 @@ func TestPgRollInitializationInANonDefaultSchema(t *testing.T) {
})
}
func TestConcurrentInitialization(t *testing.T) {
t.Parallel()
testutils.WithUninitializedState(t, func(state *state.State) {
ctx := context.Background()
numGoroutines := 10
wg := sync.WaitGroup{}
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
if err := state.Init(ctx); err != nil {
t.Error(err)
}
}()
}
wg.Wait()
})
}
func TestReadSchema(t *testing.T) {
t.Parallel()

View File

@ -89,49 +89,13 @@ func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn f
t.Helper()
ctx := context.Background()
tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})
dbName := randomDBName()
_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}
u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}
u.Path = "/" + dbName
connStr := u.String()
db, connStr, _ := setupTestDatabase(t)
st, err := state.New(ctx, connStr, schema)
if err != nil {
t.Fatal(err)
}
db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})
// init the state
if err := st.Init(ctx); err != nil {
t.Fatal(err)
}
@ -143,46 +107,25 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.
WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn)
}
func WithUninitializedState(t *testing.T, fn func(*state.State)) {
t.Helper()
ctx := context.Background()
_, connStr, _ := setupTestDatabase(t)
st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}
fn(st)
}
func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()
tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})
dbName := randomDBName()
_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}
u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}
u.Path = "/" + dbName
connStr := u.String()
db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})
db, connStr, dbName := setupTestDatabase(t)
st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
@ -230,3 +173,51 @@ func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, f
func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
}
// setupTestDatabase creates a new database in the test container and returns:
// - a connection to the new database
// - the connection string to the new database
// - the name of the new database
func setupTestDatabase(t *testing.T) (*sql.DB, string, string) {
t.Helper()
ctx := context.Background()
tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})
dbName := randomDBName()
_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}
u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}
u.Path = "/" + dbName
connStr := u.String()
db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})
return db, connStr, dbName
}