diff --git a/pkg/api/api.go b/pkg/api/api.go index 291097a..d1afaed 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -92,7 +92,7 @@ func ConnectWithBackend(c *gin.Context) { backend := Backend{ Endpoint: command.Opts.ConnectBackend, Token: command.Opts.ConnectToken, - PassHeaders: command.Opts.ConnectHeaders, + PassHeaders: strings.Split(command.Opts.ConnectHeaders, ","), } ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) diff --git a/pkg/api/backend.go b/pkg/api/backend.go index 2d6ac2f..906dd06 100644 --- a/pkg/api/backend.go +++ b/pkg/api/backend.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "log" "net/http" "strings" @@ -16,7 +15,7 @@ import ( type Backend struct { Endpoint string Token string - PassHeaders string + PassHeaders []string } // BackendRequest represents a payload sent to the third-party source @@ -33,6 +32,8 @@ type BackendCredential struct { // FetchCredential sends an authentication request to a third-party service func (be Backend) FetchCredential(ctx context.Context, resource string, c *gin.Context) (*BackendCredential, error) { + logger.WithField("resource", resource).Debug("fetching database credential") + request := BackendRequest{ Resource: resource, Token: be.Token, @@ -40,14 +41,13 @@ func (be Backend) FetchCredential(ctx context.Context, resource string, c *gin.C } // Pass white-listed client headers to the backend request - for _, name := range strings.Split(be.PassHeaders, ",") { + for _, name := range be.PassHeaders { request.Headers[strings.ToLower(name)] = c.Request.Header.Get(name) } body, err := json.Marshal(request) if err != nil { - log.Println("[BACKEND] backend request serialization error:", err) - + logger.WithField("resource", resource).Error("backend request serialization error:", err) return nil, err } @@ -59,14 +59,20 @@ func (be Backend) FetchCredential(ctx context.Context, resource string, c *gin.C resp, err := http.DefaultClient.Do(req) if err != nil { - // Any connection-related issues will show up in the server log - log.Println("[BACKEND] unable to fetch credential:", err) + logger.WithField("resource", resource).Error("backend credential fetch failed:", err) return nil, errBackendConnectError } defer resp.Body.Close() if resp.StatusCode != 200 { - return nil, fmt.Errorf("received HTTP status code %v", resp.StatusCode) + err = fmt.Errorf("backend credential fetch received HTTP status code %v", resp.StatusCode) + + logger. + WithField("resource", resource). + WithField("status", resp.StatusCode). + Error(err) + + return nil, err } cred := &BackendCredential{} diff --git a/pkg/api/backend_test.go b/pkg/api/backend_test.go index c0110a0..1a71c4b 100644 --- a/pkg/api/backend_test.go +++ b/pkg/api/backend_test.go @@ -24,7 +24,7 @@ func TestBackendFetchCredential(t *testing.T) { { name: "Bad auth token", backend: Backend{Endpoint: "http://localhost:5555/unauthorized"}, - err: errors.New("received HTTP status code 401"), + err: errors.New("backend credential fetch received HTTP status code 401"), }, { name: "Backend timeout", @@ -42,13 +42,13 @@ func TestBackendFetchCredential(t *testing.T) { { name: "Missing header", backend: Backend{Endpoint: "http://localhost:5555/pass-header"}, - err: errors.New("received HTTP status code 400"), + err: errors.New("backend credential fetch received HTTP status code 400"), }, { name: "Require header", backend: Backend{ Endpoint: "http://localhost:5555/pass-header", - PassHeaders: "x-foo", + PassHeaders: []string{"x-foo"}, }, reqCtx: &gin.Context{ Request: &http.Request{ diff --git a/pkg/api/logger.go b/pkg/api/logger.go index 9f87d0b..8673caf 100644 --- a/pkg/api/logger.go +++ b/pkg/api/logger.go @@ -11,6 +11,19 @@ import ( const loggerMessage = "http_request" +var logger *logrus.Logger + +func init() { + if logger == nil { + logger = logrus.New() + } +} + +// TODO: Move this into server struct when it's ready +func SetLogger(l *logrus.Logger) { + logger = l +} + func RequestLogger(logger *logrus.Logger) gin.HandlerFunc { debug := logger.Level > logrus.InfoLevel diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 8247da8..8338295 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -199,6 +199,7 @@ func startServer() { router.Use(gin.BasicAuth(auth)) } + api.SetLogger(logger) api.SetupRoutes(router) fmt.Println("Starting server...") diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index ef346ca..d72e5f4 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -287,7 +287,7 @@ func testEstimatedTableRowsCount(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, []string{"reltuples"}, res.Columns) - assert.Equal(t, []Row{Row{count}}, res.Rows) + assert.Equal(t, []Row{{count}}, res.Rows) } func testTableRowsCount(t *testing.T) { @@ -296,7 +296,7 @@ func testTableRowsCount(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, []string{"count"}, res.Columns) - assert.Equal(t, []Row{Row{count}}, res.Rows) + assert.Equal(t, []Row{{count}}, res.Rows) } func testTableRowsCountWithLargeTable(t *testing.T) { @@ -307,7 +307,7 @@ func testTableRowsCountWithLargeTable(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, []string{"reltuples"}, res.Columns) - assert.Equal(t, []Row{Row{count}}, res.Rows) + assert.Equal(t, []Row{{count}}, res.Rows) } func testTableIndexes(t *testing.T) {