Add connect backend tests (#546)

This commit is contained in:
Dan Sosedoff 2022-01-08 14:45:21 -06:00 committed by GitHub
parent 0794c642e4
commit e2f5e06c07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 166 additions and 7 deletions

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
@ -94,8 +95,11 @@ func ConnectWithBackend(c *gin.Context) {
PassHeaders: command.Opts.ConnectHeaders, PassHeaders: command.Opts.ConnectHeaders,
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
// Fetch connection credentials // Fetch connection credentials
cred, err := backend.FetchCredential(c.Param("resource"), c) cred, err := backend.FetchCredential(ctx, c.Param("resource"), c)
if err != nil { if err != nil {
badRequest(c, err) badRequest(c, err)
return return

View File

@ -2,6 +2,7 @@ package api
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
@ -31,7 +32,7 @@ type BackendCredential struct {
} }
// FetchCredential sends an authentication request to a third-party service // FetchCredential sends an authentication request to a third-party service
func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCredential, error) { func (be Backend) FetchCredential(ctx context.Context, resource string, c *gin.Context) (*BackendCredential, error) {
request := BackendRequest{ request := BackendRequest{
Resource: resource, Resource: resource,
Token: be.Token, Token: be.Token,
@ -45,21 +46,27 @@ func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCred
body, err := json.Marshal(request) body, err := json.Marshal(request)
if err != nil { if err != nil {
log.Println("[BACKEND] backend request serialization error:", err)
return nil, err return nil, err
} }
resp, err := http.Post(be.Endpoint, "application/json", bytes.NewReader(body)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, be.Endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
// Any connection-related issues will show up in the server log // Any connection-related issues will show up in the server log
log.Println("Unable to fetch backend credential:", err) log.Println("[BACKEND] unable to fetch credential:", err)
// We dont want to expose the url of the backend here, so reply with generic error
return nil, errBackendConnectError return nil, errBackendConnectError
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return nil, fmt.Errorf("Got HTTP error %v from backend", resp.StatusCode) return nil, fmt.Errorf("received HTTP status code %v", resp.StatusCode)
} }
cred := &BackendCredential{} cred := &BackendCredential{}

148
pkg/api/backend_test.go Normal file
View File

@ -0,0 +1,148 @@
package api
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestBackendFetchCredential(t *testing.T) {
examples := []struct {
name string
backend Backend
resourceName string
cred *BackendCredential
reqCtx *gin.Context
ctx func() (context.Context, context.CancelFunc)
err error
}{
{
name: "Bad auth token",
backend: Backend{Endpoint: "http://localhost:5555/unauthorized"},
err: errors.New("received HTTP status code 401"),
},
{
name: "Backend timeout",
backend: Backend{Endpoint: "http://localhost:5555/timeout"},
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Millisecond*100)
},
err: errors.New("Unable to connect to the auth backend"),
},
{
name: "Empty response",
backend: Backend{Endpoint: "http://localhost:5555/empty-response"},
err: errors.New("Connection string is required"),
},
{
name: "Missing header",
backend: Backend{Endpoint: "http://localhost:5555/pass-header"},
err: errors.New("received HTTP status code 400"),
},
{
name: "Require header",
backend: Backend{
Endpoint: "http://localhost:5555/pass-header",
PassHeaders: "x-foo",
},
reqCtx: &gin.Context{
Request: &http.Request{
Header: http.Header{
"X-Foo": []string{"bar"},
},
},
},
cred: &BackendCredential{DatabaseURL: "postgres://hostname/bar"},
},
{
name: "Success",
backend: Backend{Endpoint: "http://localhost:5555/success"},
cred: &BackendCredential{DatabaseURL: "postgres://hostname/dbname"},
},
}
srvCtx, srvCancel := context.WithTimeout(context.Background(), time.Minute)
defer srvCancel()
go startTestBackend(srvCtx, "localhost:5555")
for _, ex := range examples {
t.Run(ex.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
if ex.ctx != nil {
ctx, cancel = ex.ctx()
}
defer cancel()
reqCtx := ex.reqCtx
if reqCtx == nil {
reqCtx = &gin.Context{
Request: &http.Request{},
}
}
cred, err := ex.backend.FetchCredential(ctx, ex.resourceName, reqCtx)
assert.Equal(t, ex.err, err)
assert.Equal(t, ex.cred, cred)
})
}
}
func startTestBackend(ctx context.Context, listenAddr string) {
router := gin.New()
router.Use(func(c *gin.Context) {
if c.GetHeader("content-type") != "application/json" {
c.AbortWithStatus(http.StatusBadRequest)
}
})
router.POST("/unauthorized", func(c *gin.Context) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
})
router.POST("/timeout", func(c *gin.Context) {
time.Sleep(time.Second)
c.JSON(http.StatusOK, gin.H{})
})
router.POST("/empty-response", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
router.POST("/pass-header", func(c *gin.Context) {
req := BackendRequest{}
if err := c.BindJSON(&req); err != nil {
panic(err)
}
header := req.Headers["x-foo"]
if header == "" {
c.AbortWithStatus(http.StatusBadRequest)
return
}
c.JSON(http.StatusOK, gin.H{
"database_url": "postgres://hostname/" + header,
})
})
router.POST("/success", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"database_url": "postgres://hostname/dbname",
})
})
server := &http.Server{Addr: listenAddr, Handler: router}
go server.ListenAndServe()
select {
case <-ctx.Done():
server.Shutdown(context.Background())
}
}