From fe5039d17a452dfabed77ef9dd060b16114bca18 Mon Sep 17 00:00:00 2001 From: Dan Sosedoff Date: Sat, 4 Nov 2023 11:12:48 -0500 Subject: [PATCH] Allow retrying a connection on startup (#695) * Allow retrying a connection on startup * Remove unused vars * Add an extra comment * Restructure retry logic a bit * Update retry logic * Fix comment * Update comment * Change type for RetryCount and RetryDelay to uint * Extra test cases * Tweak --- pkg/cli/cli.go | 72 +++++++++++++++++++++++---------------- pkg/client/client.go | 36 +++++++++++++++++++- pkg/client/client_test.go | 42 ++++++++++++++++++++++- pkg/command/options.go | 4 ++- 4 files changed, 122 insertions(+), 32 deletions(-) diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 0073f79..8667381 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -6,7 +6,6 @@ import ( "os" "os/exec" "os/signal" - "regexp" "strings" "syscall" "time" @@ -35,9 +34,6 @@ SECURITY WARNING: You are running Pgweb in read-only mode. This mode is designed for environments where users could potentially delete or change data. For proper read-only access please follow PostgreSQL role management documentation. --------------------------------------------------------------------------------` - - regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`) - regexErrAuthFailed = regexp.MustCompile(`authentication failed`) ) func init() { @@ -77,35 +73,20 @@ func initClient() { } if command.Opts.Debug { - fmt.Println("Server connection string:", cl.ConnectionString) + fmt.Println("Opening database connection using string:", cl.ConnectionString) } + retryCount := command.Opts.RetryCount + retryDelay := time.Second * time.Duration(command.Opts.RetryDelay) + fmt.Println("Connecting to server...") - if err := cl.Test(); err != nil { - msg := err.Error() - - // Check if we're trying to connect to the default database. - if command.Opts.DbName == "" && command.Opts.URL == "" { - // If database does not exist, allow user to connect from the UI. - if strings.Contains(msg, "database") && strings.Contains(msg, "does not exist") { - fmt.Println("Error:", msg) - return - } - - // Do not bail if local server is not running. - if regexErrConnectionRefused.MatchString(msg) { - fmt.Println("Error:", msg) - return - } - - // Do not bail if local auth is invalid - if regexErrAuthFailed.MatchString(msg) { - fmt.Println("Error:", msg) - return - } + abort, err := testClient(cl, int(retryCount), retryDelay) + if err != nil { + if abort { + exitWithMessage(err.Error()) + } else { + return } - - exitWithMessage(msg) } if !command.Opts.Sessions { @@ -280,6 +261,39 @@ func openPage() { } } +// testClient attempts to establish a database connection until it succeeds or +// give up after certain number of retries. Retries only available when database +// name or a connection string is provided. +func testClient(cl *client.Client, retryCount int, retryDelay time.Duration) (abort bool, err error) { + usingDefaultDB := command.Opts.DbName == "" && command.Opts.URL == "" + + for { + err = cl.Test() + if err == nil { + return false, nil + } + + // Continue normal start up if can't connect locally without database details. + if usingDefaultDB { + if errors.Is(err, client.ErrConnectionRefused) || + errors.Is(err, client.ErrAuthFailed) || + errors.Is(err, client.ErrDatabaseNotExist) { + return false, err + } + } + + // Only retry if can't establish connection to the server. + if errors.Is(err, client.ErrConnectionRefused) && retryCount > 0 { + fmt.Printf("Connection error: %v, retrying in %v (%d remaining)\n", err, retryDelay, retryCount) + retryCount-- + <-time.After(retryDelay) + continue + } + + return true, err + } +} + func Run() { initOptions() initClient() diff --git a/pkg/client/client.go b/pkg/client/client.go index 0d3988b..00ff7d1 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -7,6 +7,7 @@ import ( "log" neturl "net/url" "reflect" + "regexp" "strings" "time" @@ -21,6 +22,18 @@ import ( "github.com/sosedoff/pgweb/pkg/statements" ) +var ( + regexErrAuthFailed = regexp.MustCompile(`(authentication failed|role "(.*)" does not exist)`) + regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`) + regexErrDatabaseNotExist = regexp.MustCompile(`database "(.*)" does not exist`) +) + +var ( + ErrAuthFailed = errors.New("authentication failed") + ErrConnectionRefused = errors.New("connection refused") + ErrDatabaseNotExist = errors.New("database does not exist") +) + type Client struct { db *sqlx.DB tunnel *Tunnel @@ -179,7 +192,28 @@ func (client *Client) setServerVersion() { } func (client *Client) Test() error { - return client.db.Ping() + // NOTE: This is a different timeout defined in CLI OpenTimeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := client.db.PingContext(ctx) + if err == nil { + return nil + } + + errMsg := err.Error() + + if regexErrConnectionRefused.MatchString(errMsg) { + return ErrConnectionRefused + } + if regexErrAuthFailed.MatchString(errMsg) { + return ErrAuthFailed + } + if regexErrDatabaseNotExist.MatchString(errMsg) { + return ErrDatabaseNotExist + } + + return err } func (client *Client) TestWithTimeout(timeout time.Duration) (result error) { diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index a139f77..f2488b9 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -14,6 +14,7 @@ import ( "github.com/sosedoff/pgweb/pkg/command" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -199,7 +200,46 @@ func testClientIdleTime(t *testing.T) { } func testTest(t *testing.T) { - assert.NoError(t, testClient.Test()) + examples := []struct { + name string + input string + err error + }{ + { + name: "success", + input: fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase), + err: nil, + }, + { + name: "connection refused", + input: "postgresql://localhost:5433/dbname", + err: ErrConnectionRefused, + }, + { + name: "invalid user", + input: fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", "foo", serverPassword, serverHost, serverPort, serverDatabase), + err: ErrAuthFailed, + }, + { + name: "invalid password", + input: fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", serverUser, "foo", serverHost, serverPort, serverDatabase), + err: ErrAuthFailed, + }, + { + name: "invalid database", + input: fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, "foo"), + err: ErrDatabaseNotExist, + }, + } + + for _, ex := range examples { + t.Run(ex.name, func(t *testing.T) { + conn, err := NewFromUrl(ex.input, nil) + require.NoError(t, err) + + require.Equal(t, ex.err, conn.Test()) + }) + } } func testInfo(t *testing.T) { diff --git a/pkg/command/options.go b/pkg/command/options.go index 6edb43b..0e9fce2 100644 --- a/pkg/command/options.go +++ b/pkg/command/options.go @@ -36,7 +36,9 @@ type Options struct { SSLRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"` SSLCert string `long:"ssl-cert" description:"SSL client certificate file"` SSLKey string `long:"ssl-key" description:"SSL client certificate key file"` - OpenTimeout int `long:"open-timeout" description:" Maximum wait for connection, in seconds" default:"30"` + OpenTimeout int `long:"open-timeout" description:"Maximum wait time for connection, in seconds" default:"30"` + RetryDelay uint `long:"open-retry-delay" description:"Number of seconds to wait before retrying the connection" default:"3"` + RetryCount uint `long:"open-retry" description:"Number of times to retry establishing connection" default:"0"` HTTPHost string `long:"bind" description:"HTTP server host" default:"localhost"` HTTPPort uint `long:"listen" description:"HTTP server listen port" default:"8081"` AuthUser string `long:"auth-user" description:"HTTP basic auth user"`