diff --git a/README.md b/README.md index c91c693..32e62b6 100644 --- a/README.md +++ b/README.md @@ -285,6 +285,11 @@ The core functionality of reproxy can be extended with external plugins. Each pl - `HeadersIn` - incoming headers. Those will be sent to the proxied url - `HeadersOut` - outgoing headers. Will be sent back to the client +By default headers set by a plugin will be mixed with the original headers. In case if plugin need to control all the headers, for example drop some of them, `OverrideHeaders*` field can be set by a plugin indicating to the core reporxy process the need to overwrite all the headers instead of mixing them in. + +- `OverrideHeadersIn` - indicates plugin responsible for all incoming headers. +- `OverrideHeadersOut` - indicates plugin responsible for all outgoing headers + To simplify the development process all the building blocks provided. It includes `lib.Plugin` handling registration, listening and dispatching calls as well as `lib.Request` and `lib.Response` defining input and output. Plugin's authors should implement concrete handlers satisfying `func(req lib.Request, res *lib.HandlerResponse) (err error)` signature. Each plugin may contain multiple handlers like this. diff --git a/app/plugin/conductor.go b/app/plugin/conductor.go index 657d20d..e259075 100644 --- a/app/plugin/conductor.go +++ b/app/plugin/conductor.go @@ -87,6 +87,20 @@ func (c *Conductor) Run(ctx context.Context) error { // Failed plugin calls ignored. Status code from any plugin may stop the chain of calls if not 200. This is needed // to allow plugins like auth which has to terminate request in some cases. func (c *Conductor) Middleware(next http.Handler) http.Handler { + + setHeaders := func(src, alt http.Header, overrideHeaders bool) { + if overrideHeaders { + for k := range src { + src.Del(k) + } + } + for k, vv := range alt { + for _, v := range vv { + src.Add(k, v) + } + } + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c.lock.RLock() @@ -101,16 +115,10 @@ func (c *Conductor) Middleware(next http.Handler) http.Handler { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - for k, vv := range reply.HeadersIn { - for _, v := range vv { - r.Header.Add(k, v) - } - } - for k, vv := range reply.HeadersOut { - for _, v := range vv { - w.Header().Add(k, v) - } - } + + setHeaders(r.Header, reply.HeadersIn, reply.OverrideHeadersIn) + setHeaders(w.Header(), reply.HeadersOut, reply.OverrideHeadersOut) + if reply.StatusCode >= 400 { c.lock.RUnlock() http.Error(w, http.StatusText(reply.StatusCode), reply.StatusCode) diff --git a/app/plugin/conductor_test.go b/app/plugin/conductor_test.go index 99a1191..8cf4e0b 100644 --- a/app/plugin/conductor_test.go +++ b/app/plugin/conductor_test.go @@ -223,6 +223,7 @@ func TestConductor_Middleware(t *testing.T) { reply.(*lib.Response).StatusCode = 200 reply.(*lib.Response).HeadersOut = map[string][]string{} reply.(*lib.Response).HeadersOut.Set("k11", "v11") + reply.(*lib.Response).OverrideHeadersOut = true } if serviceMethod == "Test1.Mw3" { t.Fatal("shouldn't be called") @@ -285,9 +286,10 @@ func TestConductor_Middleware(t *testing.T) { })) h.ServeHTTP(w, rr) assert.Equal(t, 200, w.Result().StatusCode) - assert.Equal(t, "v1", w.Result().Header.Get("k1")) + assert.Equal(t, "", w.Result().Header.Get("k1")) assert.Equal(t, "v2", w.Result().Header.Get("k2")) assert.Equal(t, "v21", rr.Header.Get("k21")) + assert.Equal(t, "v11", w.Result().Header.Get("k11")) t.Logf("req: %+v", rr) t.Logf("resp: %+v", w.Result()) } diff --git a/lib/rpc.go b/lib/rpc.go index 9fd2b8f..54cc8d8 100644 --- a/lib/rpc.go +++ b/lib/rpc.go @@ -25,7 +25,9 @@ type Request struct { // Response from plugin's handler call type Response struct { - StatusCode int - HeadersIn http.Header - HeadersOut http.Header + StatusCode int + HeadersIn http.Header + HeadersOut http.Header + OverrideHeadersIn bool // indicates plugin removing all the original incoming headers + OverrideHeadersOut bool // indicates plugin removing all the original outgoing headers }