mirror of
https://github.com/sosedoff/pgweb.git
synced 2024-12-13 15:35:28 +03:00
Allow retrying a connection on startup (#695)
* Allow retrying a connection on startup * Remove unused vars * Add an extra comment * Restructure retry logic a bit * Update retry logic * Fix comment * Update comment * Change type for RetryCount and RetryDelay to uint * Extra test cases * Tweak
This commit is contained in:
parent
f810c0227b
commit
fe5039d17a
@ -6,7 +6,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@ -35,9 +34,6 @@ SECURITY WARNING: You are running Pgweb in read-only mode.
|
|||||||
This mode is designed for environments where users could potentially delete or change data.
|
This mode is designed for environments where users could potentially delete or change data.
|
||||||
For proper read-only access please follow PostgreSQL role management documentation.
|
For proper read-only access please follow PostgreSQL role management documentation.
|
||||||
--------------------------------------------------------------------------------`
|
--------------------------------------------------------------------------------`
|
||||||
|
|
||||||
regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`)
|
|
||||||
regexErrAuthFailed = regexp.MustCompile(`authentication failed`)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -77,35 +73,20 @@ func initClient() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if command.Opts.Debug {
|
if command.Opts.Debug {
|
||||||
fmt.Println("Server connection string:", cl.ConnectionString)
|
fmt.Println("Opening database connection using string:", cl.ConnectionString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
retryCount := command.Opts.RetryCount
|
||||||
|
retryDelay := time.Second * time.Duration(command.Opts.RetryDelay)
|
||||||
|
|
||||||
fmt.Println("Connecting to server...")
|
fmt.Println("Connecting to server...")
|
||||||
if err := cl.Test(); err != nil {
|
abort, err := testClient(cl, int(retryCount), retryDelay)
|
||||||
msg := err.Error()
|
if err != nil {
|
||||||
|
if abort {
|
||||||
// Check if we're trying to connect to the default database.
|
exitWithMessage(err.Error())
|
||||||
if command.Opts.DbName == "" && command.Opts.URL == "" {
|
} else {
|
||||||
// If database does not exist, allow user to connect from the UI.
|
return
|
||||||
if strings.Contains(msg, "database") && strings.Contains(msg, "does not exist") {
|
|
||||||
fmt.Println("Error:", msg)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do not bail if local server is not running.
|
|
||||||
if regexErrConnectionRefused.MatchString(msg) {
|
|
||||||
fmt.Println("Error:", msg)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do not bail if local auth is invalid
|
|
||||||
if regexErrAuthFailed.MatchString(msg) {
|
|
||||||
fmt.Println("Error:", msg)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
exitWithMessage(msg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !command.Opts.Sessions {
|
if !command.Opts.Sessions {
|
||||||
@ -280,6 +261,39 @@ func openPage() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testClient attempts to establish a database connection until it succeeds or
|
||||||
|
// give up after certain number of retries. Retries only available when database
|
||||||
|
// name or a connection string is provided.
|
||||||
|
func testClient(cl *client.Client, retryCount int, retryDelay time.Duration) (abort bool, err error) {
|
||||||
|
usingDefaultDB := command.Opts.DbName == "" && command.Opts.URL == ""
|
||||||
|
|
||||||
|
for {
|
||||||
|
err = cl.Test()
|
||||||
|
if err == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Continue normal start up if can't connect locally without database details.
|
||||||
|
if usingDefaultDB {
|
||||||
|
if errors.Is(err, client.ErrConnectionRefused) ||
|
||||||
|
errors.Is(err, client.ErrAuthFailed) ||
|
||||||
|
errors.Is(err, client.ErrDatabaseNotExist) {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only retry if can't establish connection to the server.
|
||||||
|
if errors.Is(err, client.ErrConnectionRefused) && retryCount > 0 {
|
||||||
|
fmt.Printf("Connection error: %v, retrying in %v (%d remaining)\n", err, retryDelay, retryCount)
|
||||||
|
retryCount--
|
||||||
|
<-time.After(retryDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Run() {
|
func Run() {
|
||||||
initOptions()
|
initOptions()
|
||||||
initClient()
|
initClient()
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
neturl "net/url"
|
neturl "net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -21,6 +22,18 @@ import (
|
|||||||
"github.com/sosedoff/pgweb/pkg/statements"
|
"github.com/sosedoff/pgweb/pkg/statements"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
regexErrAuthFailed = regexp.MustCompile(`(authentication failed|role "(.*)" does not exist)`)
|
||||||
|
regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`)
|
||||||
|
regexErrDatabaseNotExist = regexp.MustCompile(`database "(.*)" does not exist`)
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrAuthFailed = errors.New("authentication failed")
|
||||||
|
ErrConnectionRefused = errors.New("connection refused")
|
||||||
|
ErrDatabaseNotExist = errors.New("database does not exist")
|
||||||
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
tunnel *Tunnel
|
tunnel *Tunnel
|
||||||
@ -179,7 +192,28 @@ func (client *Client) setServerVersion() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) Test() error {
|
func (client *Client) Test() error {
|
||||||
return client.db.Ping()
|
// NOTE: This is a different timeout defined in CLI OpenTimeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := client.db.PingContext(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
errMsg := err.Error()
|
||||||
|
|
||||||
|
if regexErrConnectionRefused.MatchString(errMsg) {
|
||||||
|
return ErrConnectionRefused
|
||||||
|
}
|
||||||
|
if regexErrAuthFailed.MatchString(errMsg) {
|
||||||
|
return ErrAuthFailed
|
||||||
|
}
|
||||||
|
if regexErrDatabaseNotExist.MatchString(errMsg) {
|
||||||
|
return ErrDatabaseNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) TestWithTimeout(timeout time.Duration) (result error) {
|
func (client *Client) TestWithTimeout(timeout time.Duration) (result error) {
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/sosedoff/pgweb/pkg/command"
|
"github.com/sosedoff/pgweb/pkg/command"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -199,7 +200,46 @@ func testClientIdleTime(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testTest(t *testing.T) {
|
func testTest(t *testing.T) {
|
||||||
assert.NoError(t, testClient.Test())
|
examples := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
input: fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase),
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "connection refused",
|
||||||
|
input: "postgresql://localhost:5433/dbname",
|
||||||
|
err: ErrConnectionRefused,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid user",
|
||||||
|
input: fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", "foo", serverPassword, serverHost, serverPort, serverDatabase),
|
||||||
|
err: ErrAuthFailed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid password",
|
||||||
|
input: fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", serverUser, "foo", serverHost, serverPort, serverDatabase),
|
||||||
|
err: ErrAuthFailed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid database",
|
||||||
|
input: fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, "foo"),
|
||||||
|
err: ErrDatabaseNotExist,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ex := range examples {
|
||||||
|
t.Run(ex.name, func(t *testing.T) {
|
||||||
|
conn, err := NewFromUrl(ex.input, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, ex.err, conn.Test())
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInfo(t *testing.T) {
|
func testInfo(t *testing.T) {
|
||||||
|
@ -36,7 +36,9 @@ type Options struct {
|
|||||||
SSLRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"`
|
SSLRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"`
|
||||||
SSLCert string `long:"ssl-cert" description:"SSL client certificate file"`
|
SSLCert string `long:"ssl-cert" description:"SSL client certificate file"`
|
||||||
SSLKey string `long:"ssl-key" description:"SSL client certificate key file"`
|
SSLKey string `long:"ssl-key" description:"SSL client certificate key file"`
|
||||||
OpenTimeout int `long:"open-timeout" description:" Maximum wait for connection, in seconds" default:"30"`
|
OpenTimeout int `long:"open-timeout" description:"Maximum wait time for connection, in seconds" default:"30"`
|
||||||
|
RetryDelay uint `long:"open-retry-delay" description:"Number of seconds to wait before retrying the connection" default:"3"`
|
||||||
|
RetryCount uint `long:"open-retry" description:"Number of times to retry establishing connection" default:"0"`
|
||||||
HTTPHost string `long:"bind" description:"HTTP server host" default:"localhost"`
|
HTTPHost string `long:"bind" description:"HTTP server host" default:"localhost"`
|
||||||
HTTPPort uint `long:"listen" description:"HTTP server listen port" default:"8081"`
|
HTTPPort uint `long:"listen" description:"HTTP server listen port" default:"8081"`
|
||||||
AuthUser string `long:"auth-user" description:"HTTP basic auth user"`
|
AuthUser string `long:"auth-user" description:"HTTP basic auth user"`
|
||||||
|
Loading…
Reference in New Issue
Block a user