diff --git a/pkg/api/api.go b/pkg/api/api.go index e7ecb5d..0cad11d 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -14,11 +14,37 @@ import ( "github.com/sosedoff/pgweb/pkg/connection" ) -var DbClient *client.Client -var DbSessions = map[string]*client.Client{} +var ( + DbClient *client.Client + DbSessions = map[string]*client.Client{} +) func DB(c *gin.Context) *client.Client { - return DbSessions[getSessionId(c)] + if command.Opts.Sessions { + return DbSessions[getSessionId(c)] + } else { + return DbClient + } +} + +func setClient(c *gin.Context, newClient *client.Client) error { + currentClient := DB(c) + if currentClient != nil { + currentClient.Close() + } + + if !command.Opts.Sessions { + DbClient = newClient + return nil + } + + sessionId := getSessionId(c) + if sessionId == "" { + return errors.New("Session ID is required") + } + + DbSessions[sessionId] = newClient + return nil } func GetHome(c *gin.Context) { @@ -41,12 +67,6 @@ func Connect(c *gin.Context) { return } - sessionId := getSessionId(c) - if sessionId == "" { - c.JSON(400, Error{"Session ID is required"}) - return - } - opts := command.Options{Url: url} url, err := connection.FormatUrl(opts) @@ -68,14 +88,13 @@ func Connect(c *gin.Context) { } info, err := cl.Info() - if err == nil { - db := DbSessions[sessionId] - if db != nil { - db.Close() + err = setClient(c, cl) + if err != nil { + cl.Close() + c.JSON(400, Error{err.Error()}) + return } - - DbSessions[sessionId] = cl } c.JSON(200, info.Format()[0]) diff --git a/pkg/api/helpers.go b/pkg/api/helpers.go index 4372dfa..821abbb 100644 --- a/pkg/api/helpers.go +++ b/pkg/api/helpers.go @@ -8,6 +8,8 @@ import ( "strconv" "github.com/gin-gonic/gin" + + "github.com/sosedoff/pgweb/pkg/command" "github.com/sosedoff/pgweb/pkg/data" ) @@ -98,6 +100,18 @@ func dbCheckMiddleware() gin.HandlerFunc { return } + // We dont care about sessions unless they're enabled + if !command.Opts.Sessions { + if DbClient == nil { + c.JSON(400, Error{"Not connected"}) + c.Abort() + return + } + + c.Next() + return + } + sessionId := getSessionId(c) if sessionId == "" { c.JSON(400, Error{"Session ID is required"}) @@ -112,7 +126,6 @@ func dbCheckMiddleware() gin.HandlerFunc { return } - c.Set("db", conn) c.Next() } } diff --git a/pkg/api/routes.go b/pkg/api/routes.go index 0fa643d..b97bb5d 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -21,8 +21,11 @@ func SetupRoutes(router *gin.Engine) { { SetupMiddlewares(api) + if command.Opts.Sessions { + api.GET("/sessions", GetSessions) + } + api.GET("/info", GetInfo) - api.GET("/sessions", GetSessions) api.POST("/connect", Connect) api.GET("/databases", GetDatabases) api.GET("/connection", GetConnectionInfo) diff --git a/pkg/command/options.go b/pkg/command/options.go index 814893f..2a53565 100644 --- a/pkg/command/options.go +++ b/pkg/command/options.go @@ -21,6 +21,7 @@ type Options struct { AuthUser string `long:"auth-user" description:"HTTP basic auth user"` AuthPass string `long:"auth-pass" description:"HTTP basic auth password"` SkipOpen bool `short:"s" long:"skip-open" description:"Skip browser open on start"` + Sessions bool `long:"sessions" description:"Enable multiple database sessions" default:"false"` } var Opts Options