Allow retrying a connection on startup (#695)

* Allow retrying a connection on startup
* Remove unused vars
* Add an extra comment
* Restructure retry logic a bit
* Update retry logic
* Fix comment
* Update comment
* Change type for RetryCount and RetryDelay to uint
* Extra test cases
* Tweak
This commit is contained in:
Dan Sosedoff 2023-11-04 11:12:48 -05:00 committed by GitHub
parent f810c0227b
commit fe5039d17a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 32 deletions

View File

@ -6,7 +6,6 @@ import (
"os" "os"
"os/exec" "os/exec"
"os/signal" "os/signal"
"regexp"
"strings" "strings"
"syscall" "syscall"
"time" "time"
@ -35,9 +34,6 @@ SECURITY WARNING: You are running Pgweb in read-only mode.
This mode is designed for environments where users could potentially delete or change data. This mode is designed for environments where users could potentially delete or change data.
For proper read-only access please follow PostgreSQL role management documentation. For proper read-only access please follow PostgreSQL role management documentation.
--------------------------------------------------------------------------------` --------------------------------------------------------------------------------`
regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`)
regexErrAuthFailed = regexp.MustCompile(`authentication failed`)
) )
func init() { func init() {
@ -77,35 +73,20 @@ func initClient() {
} }
if command.Opts.Debug { if command.Opts.Debug {
fmt.Println("Server connection string:", cl.ConnectionString) fmt.Println("Opening database connection using string:", cl.ConnectionString)
} }
retryCount := command.Opts.RetryCount
retryDelay := time.Second * time.Duration(command.Opts.RetryDelay)
fmt.Println("Connecting to server...") fmt.Println("Connecting to server...")
if err := cl.Test(); err != nil { abort, err := testClient(cl, int(retryCount), retryDelay)
msg := err.Error() if err != nil {
if abort {
// Check if we're trying to connect to the default database. exitWithMessage(err.Error())
if command.Opts.DbName == "" && command.Opts.URL == "" { } else {
// If database does not exist, allow user to connect from the UI. return
if strings.Contains(msg, "database") && strings.Contains(msg, "does not exist") {
fmt.Println("Error:", msg)
return
}
// Do not bail if local server is not running.
if regexErrConnectionRefused.MatchString(msg) {
fmt.Println("Error:", msg)
return
}
// Do not bail if local auth is invalid
if regexErrAuthFailed.MatchString(msg) {
fmt.Println("Error:", msg)
return
}
} }
exitWithMessage(msg)
} }
if !command.Opts.Sessions { if !command.Opts.Sessions {
@ -280,6 +261,39 @@ func openPage() {
} }
} }
// testClient attempts to establish a database connection until it succeeds or
// give up after certain number of retries. Retries only available when database
// name or a connection string is provided.
func testClient(cl *client.Client, retryCount int, retryDelay time.Duration) (abort bool, err error) {
usingDefaultDB := command.Opts.DbName == "" && command.Opts.URL == ""
for {
err = cl.Test()
if err == nil {
return false, nil
}
// Continue normal start up if can't connect locally without database details.
if usingDefaultDB {
if errors.Is(err, client.ErrConnectionRefused) ||
errors.Is(err, client.ErrAuthFailed) ||
errors.Is(err, client.ErrDatabaseNotExist) {
return false, err
}
}
// Only retry if can't establish connection to the server.
if errors.Is(err, client.ErrConnectionRefused) && retryCount > 0 {
fmt.Printf("Connection error: %v, retrying in %v (%d remaining)\n", err, retryDelay, retryCount)
retryCount--
<-time.After(retryDelay)
continue
}
return true, err
}
}
func Run() { func Run() {
initOptions() initOptions()
initClient() initClient()

View File

@ -7,6 +7,7 @@ import (
"log" "log"
neturl "net/url" neturl "net/url"
"reflect" "reflect"
"regexp"
"strings" "strings"
"time" "time"
@ -21,6 +22,18 @@ import (
"github.com/sosedoff/pgweb/pkg/statements" "github.com/sosedoff/pgweb/pkg/statements"
) )
var (
regexErrAuthFailed = regexp.MustCompile(`(authentication failed|role "(.*)" does not exist)`)
regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`)
regexErrDatabaseNotExist = regexp.MustCompile(`database "(.*)" does not exist`)
)
var (
ErrAuthFailed = errors.New("authentication failed")
ErrConnectionRefused = errors.New("connection refused")
ErrDatabaseNotExist = errors.New("database does not exist")
)
type Client struct { type Client struct {
db *sqlx.DB db *sqlx.DB
tunnel *Tunnel tunnel *Tunnel
@ -179,7 +192,28 @@ func (client *Client) setServerVersion() {
} }
func (client *Client) Test() error { func (client *Client) Test() error {
return client.db.Ping() // NOTE: This is a different timeout defined in CLI OpenTimeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := client.db.PingContext(ctx)
if err == nil {
return nil
}
errMsg := err.Error()
if regexErrConnectionRefused.MatchString(errMsg) {
return ErrConnectionRefused
}
if regexErrAuthFailed.MatchString(errMsg) {
return ErrAuthFailed
}
if regexErrDatabaseNotExist.MatchString(errMsg) {
return ErrDatabaseNotExist
}
return err
} }
func (client *Client) TestWithTimeout(timeout time.Duration) (result error) { func (client *Client) TestWithTimeout(timeout time.Duration) (result error) {

View File

@ -14,6 +14,7 @@ import (
"github.com/sosedoff/pgweb/pkg/command" "github.com/sosedoff/pgweb/pkg/command"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
var ( var (
@ -199,7 +200,46 @@ func testClientIdleTime(t *testing.T) {
} }
func testTest(t *testing.T) { func testTest(t *testing.T) {
assert.NoError(t, testClient.Test()) examples := []struct {
name string
input string
err error
}{
{
name: "success",
input: fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase),
err: nil,
},
{
name: "connection refused",
input: "postgresql://localhost:5433/dbname",
err: ErrConnectionRefused,
},
{
name: "invalid user",
input: fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", "foo", serverPassword, serverHost, serverPort, serverDatabase),
err: ErrAuthFailed,
},
{
name: "invalid password",
input: fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", serverUser, "foo", serverHost, serverPort, serverDatabase),
err: ErrAuthFailed,
},
{
name: "invalid database",
input: fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, "foo"),
err: ErrDatabaseNotExist,
},
}
for _, ex := range examples {
t.Run(ex.name, func(t *testing.T) {
conn, err := NewFromUrl(ex.input, nil)
require.NoError(t, err)
require.Equal(t, ex.err, conn.Test())
})
}
} }
func testInfo(t *testing.T) { func testInfo(t *testing.T) {

View File

@ -36,7 +36,9 @@ type Options struct {
SSLRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"` SSLRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"`
SSLCert string `long:"ssl-cert" description:"SSL client certificate file"` SSLCert string `long:"ssl-cert" description:"SSL client certificate file"`
SSLKey string `long:"ssl-key" description:"SSL client certificate key file"` SSLKey string `long:"ssl-key" description:"SSL client certificate key file"`
OpenTimeout int `long:"open-timeout" description:" Maximum wait for connection, in seconds" default:"30"` OpenTimeout int `long:"open-timeout" description:"Maximum wait time for connection, in seconds" default:"30"`
RetryDelay uint `long:"open-retry-delay" description:"Number of seconds to wait before retrying the connection" default:"3"`
RetryCount uint `long:"open-retry" description:"Number of times to retry establishing connection" default:"0"`
HTTPHost string `long:"bind" description:"HTTP server host" default:"localhost"` HTTPHost string `long:"bind" description:"HTTP server host" default:"localhost"`
HTTPPort uint `long:"listen" description:"HTTP server listen port" default:"8081"` HTTPPort uint `long:"listen" description:"HTTP server listen port" default:"8081"`
AuthUser string `long:"auth-user" description:"HTTP basic auth user"` AuthUser string `long:"auth-user" description:"HTTP basic auth user"`