mirror of
https://github.com/sosedoff/pgweb.git
synced 2024-12-14 10:23:02 +03:00
Add connect backend tests (#546)
This commit is contained in:
parent
0794c642e4
commit
e2f5e06c07
@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -94,8 +95,11 @@ func ConnectWithBackend(c *gin.Context) {
|
|||||||
PassHeaders: command.Opts.ConnectHeaders,
|
PassHeaders: command.Opts.ConnectHeaders,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// Fetch connection credentials
|
// Fetch connection credentials
|
||||||
cred, err := backend.FetchCredential(c.Param("resource"), c)
|
cred, err := backend.FetchCredential(ctx, c.Param("resource"), c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
badRequest(c, err)
|
badRequest(c, err)
|
||||||
return
|
return
|
||||||
|
@ -2,6 +2,7 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
@ -31,7 +32,7 @@ type BackendCredential struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// FetchCredential sends an authentication request to a third-party service
|
// FetchCredential sends an authentication request to a third-party service
|
||||||
func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCredential, error) {
|
func (be Backend) FetchCredential(ctx context.Context, resource string, c *gin.Context) (*BackendCredential, error) {
|
||||||
request := BackendRequest{
|
request := BackendRequest{
|
||||||
Resource: resource,
|
Resource: resource,
|
||||||
Token: be.Token,
|
Token: be.Token,
|
||||||
@ -45,21 +46,27 @@ func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCred
|
|||||||
|
|
||||||
body, err := json.Marshal(request)
|
body, err := json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Println("[BACKEND] backend request serialization error:", err)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.Post(be.Endpoint, "application/json", bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, be.Endpoint, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("content-type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Any connection-related issues will show up in the server log
|
// Any connection-related issues will show up in the server log
|
||||||
log.Println("Unable to fetch backend credential:", err)
|
log.Println("[BACKEND] unable to fetch credential:", err)
|
||||||
|
|
||||||
// We dont want to expose the url of the backend here, so reply with generic error
|
|
||||||
return nil, errBackendConnectError
|
return nil, errBackendConnectError
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return nil, fmt.Errorf("Got HTTP error %v from backend", resp.StatusCode)
|
return nil, fmt.Errorf("received HTTP status code %v", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
cred := &BackendCredential{}
|
cred := &BackendCredential{}
|
||||||
|
148
pkg/api/backend_test.go
Normal file
148
pkg/api/backend_test.go
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBackendFetchCredential(t *testing.T) {
|
||||||
|
examples := []struct {
|
||||||
|
name string
|
||||||
|
backend Backend
|
||||||
|
resourceName string
|
||||||
|
cred *BackendCredential
|
||||||
|
reqCtx *gin.Context
|
||||||
|
ctx func() (context.Context, context.CancelFunc)
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bad auth token",
|
||||||
|
backend: Backend{Endpoint: "http://localhost:5555/unauthorized"},
|
||||||
|
err: errors.New("received HTTP status code 401"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Backend timeout",
|
||||||
|
backend: Backend{Endpoint: "http://localhost:5555/timeout"},
|
||||||
|
ctx: func() (context.Context, context.CancelFunc) {
|
||||||
|
return context.WithTimeout(context.Background(), time.Millisecond*100)
|
||||||
|
},
|
||||||
|
err: errors.New("Unable to connect to the auth backend"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty response",
|
||||||
|
backend: Backend{Endpoint: "http://localhost:5555/empty-response"},
|
||||||
|
err: errors.New("Connection string is required"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing header",
|
||||||
|
backend: Backend{Endpoint: "http://localhost:5555/pass-header"},
|
||||||
|
err: errors.New("received HTTP status code 400"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Require header",
|
||||||
|
backend: Backend{
|
||||||
|
Endpoint: "http://localhost:5555/pass-header",
|
||||||
|
PassHeaders: "x-foo",
|
||||||
|
},
|
||||||
|
reqCtx: &gin.Context{
|
||||||
|
Request: &http.Request{
|
||||||
|
Header: http.Header{
|
||||||
|
"X-Foo": []string{"bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
cred: &BackendCredential{DatabaseURL: "postgres://hostname/bar"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success",
|
||||||
|
backend: Backend{Endpoint: "http://localhost:5555/success"},
|
||||||
|
cred: &BackendCredential{DatabaseURL: "postgres://hostname/dbname"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
srvCtx, srvCancel := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer srvCancel()
|
||||||
|
|
||||||
|
go startTestBackend(srvCtx, "localhost:5555")
|
||||||
|
|
||||||
|
for _, ex := range examples {
|
||||||
|
t.Run(ex.name, func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
if ex.ctx != nil {
|
||||||
|
ctx, cancel = ex.ctx()
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
reqCtx := ex.reqCtx
|
||||||
|
if reqCtx == nil {
|
||||||
|
reqCtx = &gin.Context{
|
||||||
|
Request: &http.Request{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cred, err := ex.backend.FetchCredential(ctx, ex.resourceName, reqCtx)
|
||||||
|
assert.Equal(t, ex.err, err)
|
||||||
|
assert.Equal(t, ex.cred, cred)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startTestBackend(ctx context.Context, listenAddr string) {
|
||||||
|
router := gin.New()
|
||||||
|
|
||||||
|
router.Use(func(c *gin.Context) {
|
||||||
|
if c.GetHeader("content-type") != "application/json" {
|
||||||
|
c.AbortWithStatus(http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
router.POST("/unauthorized", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||||
|
})
|
||||||
|
|
||||||
|
router.POST("/timeout", func(c *gin.Context) {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
c.JSON(http.StatusOK, gin.H{})
|
||||||
|
})
|
||||||
|
|
||||||
|
router.POST("/empty-response", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{})
|
||||||
|
})
|
||||||
|
|
||||||
|
router.POST("/pass-header", func(c *gin.Context) {
|
||||||
|
req := BackendRequest{}
|
||||||
|
if err := c.BindJSON(&req); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
header := req.Headers["x-foo"]
|
||||||
|
if header == "" {
|
||||||
|
c.AbortWithStatus(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"database_url": "postgres://hostname/" + header,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
router.POST("/success", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"database_url": "postgres://hostname/dbname",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
server := &http.Server{Addr: listenAddr, Handler: router}
|
||||||
|
go server.ListenAndServe()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
server.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user