mirror of
https://github.com/sosedoff/pgweb.git
synced 2024-12-15 03:36:33 +03:00
Add tests for getSessionId helper
This commit is contained in:
parent
86f63eecc5
commit
c57b477dc9
@ -22,7 +22,7 @@ var (
|
||||
|
||||
func DB(c *gin.Context) *client.Client {
|
||||
if command.Opts.Sessions {
|
||||
return DbSessions[getSessionId(c)]
|
||||
return DbSessions[getSessionId(c.Request)]
|
||||
} else {
|
||||
return DbClient
|
||||
}
|
||||
@ -39,7 +39,7 @@ func setClient(c *gin.Context, newClient *client.Client) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
sessionId := getSessionId(c)
|
||||
sessionId := getSessionId(c.Request)
|
||||
if sessionId == "" {
|
||||
return errors.New("Session ID is required")
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"fmt"
|
||||
"mime"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -70,10 +71,10 @@ func desanitize64(query string) string {
|
||||
return query
|
||||
}
|
||||
|
||||
func getSessionId(c *gin.Context) string {
|
||||
id := c.Request.Header.Get("x-session-id")
|
||||
func getSessionId(req *http.Request) string {
|
||||
id := req.Header.Get("x-session-id")
|
||||
if id == "" {
|
||||
id = c.Request.URL.Query().Get("_session_id")
|
||||
id = req.URL.Query().Get("_session_id")
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
@ -1,8 +1,11 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_desanitize64(t *testing.T) {
|
||||
@ -23,3 +26,13 @@ func Test_cleanQuery(t *testing.T) {
|
||||
assert.Equal(t, "", cleanQuery("--something"))
|
||||
assert.Equal(t, "test", cleanQuery("--test\ntest\n -- test\n"))
|
||||
}
|
||||
|
||||
func Test_getSessionId(t *testing.T) {
|
||||
req := &http.Request{Header: http.Header{}}
|
||||
req.Header.Add("x-session-id", "token")
|
||||
assert.Equal(t, "token", getSessionId(req))
|
||||
|
||||
req = &http.Request{}
|
||||
req.URL, _ = url.Parse("http://foobar/?_session_id=token")
|
||||
assert.Equal(t, "token", getSessionId(req))
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ func dbCheckMiddleware() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
sessionId := getSessionId(c)
|
||||
sessionId := getSessionId(c.Request)
|
||||
if sessionId == "" {
|
||||
c.JSON(400, Error{"Session ID is required"})
|
||||
c.Abort()
|
||||
|
Loading…
Reference in New Issue
Block a user