From e2f5e06c07669b81ab376bfc1cfe820b6ee95c23 Mon Sep 17 00:00:00 2001 From: Dan Sosedoff Date: Sat, 8 Jan 2022 14:45:21 -0600 Subject: [PATCH] Add connect backend tests (#546) --- pkg/api/api.go | 6 +- pkg/api/backend.go | 19 ++++-- pkg/api/backend_test.go | 148 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+), 7 deletions(-) create mode 100644 pkg/api/backend_test.go diff --git a/pkg/api/api.go b/pkg/api/api.go index c136999..c951727 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/base64" "fmt" "net/http" @@ -94,8 +95,11 @@ func ConnectWithBackend(c *gin.Context) { PassHeaders: command.Opts.ConnectHeaders, } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + // Fetch connection credentials - cred, err := backend.FetchCredential(c.Param("resource"), c) + cred, err := backend.FetchCredential(ctx, c.Param("resource"), c) if err != nil { badRequest(c, err) return diff --git a/pkg/api/backend.go b/pkg/api/backend.go index 17ad724..2d6ac2f 100644 --- a/pkg/api/backend.go +++ b/pkg/api/backend.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "fmt" "log" @@ -31,7 +32,7 @@ type BackendCredential struct { } // FetchCredential sends an authentication request to a third-party service -func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCredential, error) { +func (be Backend) FetchCredential(ctx context.Context, resource string, c *gin.Context) (*BackendCredential, error) { request := BackendRequest{ Resource: resource, Token: be.Token, @@ -45,21 +46,27 @@ func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCred body, err := json.Marshal(request) if err != nil { + log.Println("[BACKEND] backend request serialization error:", err) + return nil, err } - resp, err := http.Post(be.Endpoint, "application/json", bytes.NewReader(body)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, be.Endpoint, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("content-type", "application/json") + + resp, err := http.DefaultClient.Do(req) if err != nil { // Any connection-related issues will show up in the server log - log.Println("Unable to fetch backend credential:", err) - - // We dont want to expose the url of the backend here, so reply with generic error + log.Println("[BACKEND] unable to fetch credential:", err) return nil, errBackendConnectError } defer resp.Body.Close() if resp.StatusCode != 200 { - return nil, fmt.Errorf("Got HTTP error %v from backend", resp.StatusCode) + return nil, fmt.Errorf("received HTTP status code %v", resp.StatusCode) } cred := &BackendCredential{} diff --git a/pkg/api/backend_test.go b/pkg/api/backend_test.go new file mode 100644 index 0000000..c10be7e --- /dev/null +++ b/pkg/api/backend_test.go @@ -0,0 +1,148 @@ +package api + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestBackendFetchCredential(t *testing.T) { + examples := []struct { + name string + backend Backend + resourceName string + cred *BackendCredential + reqCtx *gin.Context + ctx func() (context.Context, context.CancelFunc) + err error + }{ + { + name: "Bad auth token", + backend: Backend{Endpoint: "http://localhost:5555/unauthorized"}, + err: errors.New("received HTTP status code 401"), + }, + { + name: "Backend timeout", + backend: Backend{Endpoint: "http://localhost:5555/timeout"}, + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Millisecond*100) + }, + err: errors.New("Unable to connect to the auth backend"), + }, + { + name: "Empty response", + backend: Backend{Endpoint: "http://localhost:5555/empty-response"}, + err: errors.New("Connection string is required"), + }, + { + name: "Missing header", + backend: Backend{Endpoint: "http://localhost:5555/pass-header"}, + err: errors.New("received HTTP status code 400"), + }, + { + name: "Require header", + backend: Backend{ + Endpoint: "http://localhost:5555/pass-header", + PassHeaders: "x-foo", + }, + reqCtx: &gin.Context{ + Request: &http.Request{ + Header: http.Header{ + "X-Foo": []string{"bar"}, + }, + }, + }, + cred: &BackendCredential{DatabaseURL: "postgres://hostname/bar"}, + }, + { + name: "Success", + backend: Backend{Endpoint: "http://localhost:5555/success"}, + cred: &BackendCredential{DatabaseURL: "postgres://hostname/dbname"}, + }, + } + + srvCtx, srvCancel := context.WithTimeout(context.Background(), time.Minute) + defer srvCancel() + + go startTestBackend(srvCtx, "localhost:5555") + + for _, ex := range examples { + t.Run(ex.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + if ex.ctx != nil { + ctx, cancel = ex.ctx() + } + defer cancel() + + reqCtx := ex.reqCtx + if reqCtx == nil { + reqCtx = &gin.Context{ + Request: &http.Request{}, + } + } + + cred, err := ex.backend.FetchCredential(ctx, ex.resourceName, reqCtx) + assert.Equal(t, ex.err, err) + assert.Equal(t, ex.cred, cred) + }) + } +} + +func startTestBackend(ctx context.Context, listenAddr string) { + router := gin.New() + + router.Use(func(c *gin.Context) { + if c.GetHeader("content-type") != "application/json" { + c.AbortWithStatus(http.StatusBadRequest) + } + }) + + router.POST("/unauthorized", func(c *gin.Context) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + }) + + router.POST("/timeout", func(c *gin.Context) { + time.Sleep(time.Second) + c.JSON(http.StatusOK, gin.H{}) + }) + + router.POST("/empty-response", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{}) + }) + + router.POST("/pass-header", func(c *gin.Context) { + req := BackendRequest{} + if err := c.BindJSON(&req); err != nil { + panic(err) + } + + header := req.Headers["x-foo"] + if header == "" { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + c.JSON(http.StatusOK, gin.H{ + "database_url": "postgres://hostname/" + header, + }) + }) + + router.POST("/success", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "database_url": "postgres://hostname/dbname", + }) + }) + + server := &http.Server{Addr: listenAddr, Handler: router} + go server.ListenAndServe() + + select { + case <-ctx.Done(): + server.Shutdown(context.Background()) + } +}