mirror of
https://github.com/xataio/pgroll.git
synced 2024-09-11 05:45:48 +03:00
Make state initialization concurrency safe (#285)
Make `pgroll` state initialization concurrency safe by using Postgres advisory locking to ensure at most one connection can initialize at at time. See docs on Postgres advisory locking: * https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS * https://www.postgresql.org/docs/current/explicit-locking.html#ADVISORY-LOCKS Closes https://github.com/xataio/pgroll/issues/283
This commit is contained in:
parent
161fde60ca
commit
b4e3044adf
@ -354,11 +354,28 @@ func New(ctx context.Context, pgURL, stateSchema string) (*State, error) {
|
||||
}
|
||||
|
||||
func (s *State) Init(ctx context.Context) error {
|
||||
// ensure pgroll internal tables exist
|
||||
// TODO: eventually use migrations for this instead of hardcoding
|
||||
_, err := s.pgConn.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema)))
|
||||
tx, err := s.pgConn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
return err
|
||||
// Try to obtain an advisory lock.
|
||||
// The key is an arbitrary number, used to distinguish the lock from other locks.
|
||||
// The lock is automatically released when the transaction is committed or rolled back.
|
||||
const key int64 = 0x2c03057fb9525b
|
||||
_, err = tx.ExecContext(ctx, "SELECT pg_advisory_xact_lock($1)", key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform pgroll state initialization
|
||||
_, err = tx.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *State) Close() error {
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@ -132,6 +133,29 @@ func TestPgRollInitializationInANonDefaultSchema(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestConcurrentInitialization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testutils.WithUninitializedState(t, func(state *state.State) {
|
||||
ctx := context.Background()
|
||||
numGoroutines := 10
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
if err := state.Init(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadSchema(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -89,49 +89,13 @@ func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn f
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
tDB, err := sql.Open("postgres", tConnStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := tDB.Close(); err != nil {
|
||||
t.Fatalf("Failed to close database connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
dbName := randomDBName()
|
||||
|
||||
_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(tConnStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u.Path = "/" + dbName
|
||||
connStr := u.String()
|
||||
db, connStr, _ := setupTestDatabase(t)
|
||||
|
||||
st, err := state.New(ctx, connStr, schema)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatalf("Failed to close database connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// init the state
|
||||
if err := st.Init(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -143,46 +107,25 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.
|
||||
WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn)
|
||||
}
|
||||
|
||||
func WithUninitializedState(t *testing.T, fn func(*state.State)) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
_, connStr, _ := setupTestDatabase(t)
|
||||
|
||||
st, err := state.New(ctx, connStr, "pgroll")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fn(st)
|
||||
}
|
||||
|
||||
func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
tDB, err := sql.Open("postgres", tConnStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := tDB.Close(); err != nil {
|
||||
t.Fatalf("Failed to close database connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
dbName := randomDBName()
|
||||
|
||||
_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(tConnStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u.Path = "/" + dbName
|
||||
connStr := u.String()
|
||||
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatalf("Failed to close database connection: %v", err)
|
||||
}
|
||||
})
|
||||
db, connStr, dbName := setupTestDatabase(t)
|
||||
|
||||
st, err := state.New(ctx, connStr, "pgroll")
|
||||
if err != nil {
|
||||
@ -230,3 +173,51 @@ func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, f
|
||||
func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
|
||||
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
|
||||
}
|
||||
|
||||
// setupTestDatabase creates a new database in the test container and returns:
|
||||
// - a connection to the new database
|
||||
// - the connection string to the new database
|
||||
// - the name of the new database
|
||||
func setupTestDatabase(t *testing.T) (*sql.DB, string, string) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
tDB, err := sql.Open("postgres", tConnStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := tDB.Close(); err != nil {
|
||||
t.Fatalf("Failed to close database connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
dbName := randomDBName()
|
||||
|
||||
_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(tConnStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u.Path = "/" + dbName
|
||||
connStr := u.String()
|
||||
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatalf("Failed to close database connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
return db, connStr, dbName
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user