From f0049ef7ac13c190e888bc75c64908c57b40d105 Mon Sep 17 00:00:00 2001 From: Umputun Date: Mon, 7 Jun 2021 18:57:42 -0500 Subject: [PATCH] add support of spa-like assets handling --- app/discovery/discovery.go | 9 ++- app/discovery/discovery_test.go | 17 +++-- app/discovery/provider/docker.go | 7 +- app/discovery/provider/file.go | 4 +- app/discovery/provider/file_test.go | 15 ++++- app/discovery/provider/static.go | 8 ++- app/discovery/provider/static_test.go | 18 ++--- app/discovery/provider/testdata/config.yml | 3 +- app/proxy/proxy.go | 18 +++-- app/proxy/proxy_test.go | 23 +++++++ app/proxy/testdata/index.html | 1 + go.mod | 2 +- go.sum | 2 + vendor/github.com/go-pkgz/rest/.golangci.yml | 2 +- vendor/github.com/go-pkgz/rest/file_server.go | 67 ++++++++++++++++++- vendor/github.com/go-pkgz/rest/sizelimit.go | 2 + vendor/modules.txt | 2 +- 17 files changed, 168 insertions(+), 32 deletions(-) create mode 100644 app/proxy/testdata/index.html diff --git a/app/discovery/discovery.go b/app/discovery/discovery.go index 13e8805..f17da2a 100644 --- a/app/discovery/discovery.go +++ b/app/discovery/discovery.go @@ -39,6 +39,7 @@ type URLMapper struct { AssetsLocation string AssetsWebRoot string + AssetsSPA bool dead bool } @@ -176,7 +177,12 @@ func (s *Service) Match(srv, src string) (res Matches) { case MTStatic: if src == m.AssetsWebRoot || strings.HasPrefix(src, m.AssetsWebRoot+"/") { res.MatchType = MTStatic - res.Routes = append(res.Routes, MatchedRoute{Destination: m.AssetsWebRoot + ":" + m.AssetsLocation, Alive: true}) + destSfx := ":norm" + if m.AssetsSPA { + destSfx = ":spa" + } + res.Routes = append(res.Routes, MatchedRoute{ + Destination: m.AssetsWebRoot + ":" + m.AssetsLocation + destSfx, Alive: true}) return res } } @@ -370,6 +376,7 @@ func (s *Service) extendMapper(m URLMapper) URLMapper { MatchType: m.MatchType, AssetsWebRoot: m.AssetsWebRoot, AssetsLocation: m.AssetsLocation, + AssetsSPA: m.AssetsSPA, } rx, err := regexp.Compile("^" + strings.TrimSuffix(src, "/") + "/(.*)") diff --git a/app/discovery/discovery_test.go b/app/discovery/discovery_test.go index 7b0fbf1..cb300fc 100644 --- a/app/discovery/discovery_test.go +++ b/app/discovery/discovery_test.go @@ -110,6 +110,8 @@ func TestService_Match(t *testing.T) { {SrcMatch: *regexp.MustCompile("/www/"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic, AssetsWebRoot: "/www", AssetsLocation: "/var/web"}, {SrcMatch: *regexp.MustCompile("/path/"), Dst: "/var/web/path", ProviderID: PIDocker, MatchType: MTStatic}, + {SrcMatch: *regexp.MustCompile("/www2/"), Dst: "/var/web2", ProviderID: PIDocker, MatchType: MTStatic, + AssetsWebRoot: "/www2", AssetsLocation: "/var/web2", AssetsSPA: true}, }, nil }, } @@ -120,7 +122,7 @@ func TestService_Match(t *testing.T) { err := svc.Run(ctx) require.Error(t, err) assert.Equal(t, context.DeadlineExceeded, err) - assert.Equal(t, 10, len(svc.Mappers())) + assert.Equal(t, 11, len(svc.Mappers())) tbl := []struct { server, src string @@ -145,12 +147,13 @@ func TestService_Match(t *testing.T) { {Destination: "http://127.0.0.5:8080/blah2/num123456/abc/3", Alive: false}, }}}, - {"m1.example.com", "/web/index.html", Matches{MTStatic, []MatchedRoute{{Destination: "/web:/var/web/", Alive: true}}}}, - {"m1.example.com", "/web/", Matches{MTStatic, []MatchedRoute{{Destination: "/web:/var/web/", Alive: true}}}}, - {"m1.example.com", "/www/something", Matches{MTStatic, []MatchedRoute{{Destination: "/www:/var/web/", Alive: true}}}}, - {"m1.example.com", "/www/", Matches{MTStatic, []MatchedRoute{{Destination: "/www:/var/web/", Alive: true}}}}, - {"m1.example.com", "/www", Matches{MTStatic, []MatchedRoute{{Destination: "/www:/var/web/", Alive: true}}}}, - {"xyx.example.com", "/path/something", Matches{MTStatic, []MatchedRoute{{Destination: "/path:/var/web/path/", Alive: true}}}}, + {"m1.example.com", "/web/index.html", Matches{MTStatic, []MatchedRoute{{Destination: "/web:/var/web/:norm", Alive: true}}}}, + {"m1.example.com", "/web/", Matches{MTStatic, []MatchedRoute{{Destination: "/web:/var/web/:norm", Alive: true}}}}, + {"m1.example.com", "/www/something", Matches{MTStatic, []MatchedRoute{{Destination: "/www:/var/web/:norm", Alive: true}}}}, + {"m1.example.com", "/www/", Matches{MTStatic, []MatchedRoute{{Destination: "/www:/var/web/:norm", Alive: true}}}}, + {"m1.example.com", "/www", Matches{MTStatic, []MatchedRoute{{Destination: "/www:/var/web/:norm", Alive: true}}}}, + {"xyx.example.com", "/path/something", Matches{MTStatic, []MatchedRoute{{Destination: "/path:/var/web/path/:norm", Alive: true}}}}, + {"m1.example.com", "/www2", Matches{MTStatic, []MatchedRoute{{Destination: "/www2:/var/web2/:spa", Alive: true}}}}, } for i, tt := range tbl { diff --git a/app/discovery/provider/docker.go b/app/discovery/provider/docker.go index 4fa40a4..537d0f8 100644 --- a/app/discovery/provider/docker.go +++ b/app/discovery/provider/docker.go @@ -102,7 +102,7 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper) // defaults destURL, pingURL, server := fmt.Sprintf("http://%s:%d/$1", c.IP, port), fmt.Sprintf("http://%s:%d/ping", c.IP, port), "*" - assetsWebRoot, assetsLocation := "", "" + assetsWebRoot, assetsLocation, assetsSPA := "", "", false if d.AutoAPI && n == 0 { enabled = true @@ -149,6 +149,10 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper) } } + if _, ok := d.labelN(c.Labels, n, "spa"); ok { + assetsSPA = true + } + // should not set anything, handled on matchedPort level. just use to enable implicitly if _, ok := d.labelN(c.Labels, n, "port"); ok { enabled = true @@ -179,6 +183,7 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper) mp.MatchType = discovery.MTStatic mp.AssetsWebRoot = assetsWebRoot mp.AssetsLocation = assetsLocation + mp.AssetsSPA = assetsSPA } res = append(res, mp) } diff --git a/app/discovery/provider/file.go b/app/discovery/provider/file.go index 35706e6..43799ce 100644 --- a/app/discovery/provider/file.go +++ b/app/discovery/provider/file.go @@ -74,6 +74,7 @@ func (d *File) List() (res []discovery.URLMapper, err error) { Dest string `yaml:"dest"` Ping string `yaml:"ping"` AssetsEnabled bool `yaml:"assets"` + AssetsSPA bool `yaml:"spa"` } fh, err := os.Open(d.FileName) if err != nil { @@ -103,8 +104,9 @@ func (d *File) List() (res []discovery.URLMapper, err error) { ProviderID: discovery.PIFile, MatchType: discovery.MTProxy, } - if f.AssetsEnabled { + if f.AssetsEnabled || f.AssetsSPA { mapper.MatchType = discovery.MTStatic + mapper.AssetsSPA = f.AssetsSPA } res = append(res, mapper) } diff --git a/app/discovery/provider/file_test.go b/app/discovery/provider/file_test.go index 85d8209..c55d370 100644 --- a/app/discovery/provider/file_test.go +++ b/app/discovery/provider/file_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/umputun/reproxy/app/discovery" ) func TestFile_Events(t *testing.T) { @@ -105,26 +107,37 @@ func TestFile_List(t *testing.T) { res, err := f.List() require.NoError(t, err) t.Logf("%+v", res) - assert.Equal(t, 4, len(res)) + assert.Equal(t, 5, len(res)) assert.Equal(t, "^/api/svc2/(.*)", res[0].SrcMatch.String()) assert.Equal(t, "http://127.0.0.2:8080/blah2/$1/abc", res[0].Dst) assert.Equal(t, "", res[0].PingURL) assert.Equal(t, "srv.example.com", res[0].Server) + assert.Equal(t, discovery.MTProxy, res[0].MatchType) assert.Equal(t, "^/api/svc1/(.*)", res[1].SrcMatch.String()) assert.Equal(t, "http://127.0.0.1:8080/blah1/$1", res[1].Dst) assert.Equal(t, "", res[1].PingURL) assert.Equal(t, "*", res[1].Server) + assert.Equal(t, discovery.MTProxy, res[1].MatchType) assert.Equal(t, "/api/svc3/xyz", res[2].SrcMatch.String()) assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", res[2].Dst) assert.Equal(t, "http://127.0.0.3:8080/ping", res[2].PingURL) assert.Equal(t, "*", res[2].Server) + assert.Equal(t, discovery.MTProxy, res[2].MatchType) assert.Equal(t, "/web/", res[3].SrcMatch.String()) assert.Equal(t, "/var/web", res[3].Dst) assert.Equal(t, "", res[3].PingURL) assert.Equal(t, "*", res[3].Server) + assert.Equal(t, discovery.MTStatic, res[3].MatchType) + assert.Equal(t, false, res[3].AssetsSPA) + assert.Equal(t, "/web2/", res[4].SrcMatch.String()) + assert.Equal(t, "/var/web2", res[4].Dst) + assert.Equal(t, "", res[4].PingURL) + assert.Equal(t, "*", res[4].Server) + assert.Equal(t, discovery.MTStatic, res[4].MatchType) + assert.Equal(t, true, res[4].AssetsSPA) } diff --git a/app/discovery/provider/static.go b/app/discovery/provider/static.go index 9adc11d..01f5a83 100644 --- a/app/discovery/provider/static.go +++ b/app/discovery/provider/static.go @@ -41,11 +41,16 @@ func (s *Static) List() (res []discovery.URLMapper, err error) { } dst := strings.TrimSpace(elems[2]) - assets := false + assets, spa := false, false if strings.HasPrefix(dst, "assets:") { dst = strings.TrimPrefix(dst, "assets:") assets = true } + if strings.HasPrefix(dst, "spa:") { + dst = strings.TrimPrefix(dst, "spa:") + assets = true + spa = true + } res := discovery.URLMapper{ Server: strings.TrimSpace(elems[0]), @@ -57,6 +62,7 @@ func (s *Static) List() (res []discovery.URLMapper, err error) { } if assets { res.MatchType = discovery.MTStatic + res.AssetsSPA = spa } return res, nil diff --git a/app/discovery/provider/static_test.go b/app/discovery/provider/static_test.go index d61240e..d42fd3b 100644 --- a/app/discovery/provider/static_test.go +++ b/app/discovery/provider/static_test.go @@ -15,16 +15,17 @@ func TestStatic_List(t *testing.T) { tbl := []struct { rule string server, src, dst, ping string - static bool + static, spa bool err bool }{ - {"example.com,123,456, ping ", "example.com", "123", "456", "ping", false, false}, - {"*,123,456,", "*", "123", "456", "", false, false}, - {"123,456", "", "", "", "", false, true}, - {"123", "", "", "", "", false, true}, - {"example.com , 123, 456 ,ping", "example.com", "123", "456", "ping", false, false}, - {"example.com,123, assets:456, ping ", "example.com", "123", "456", "ping", true, false}, - {"example.com,123, assets:456 ", "example.com", "123", "456", "", true, false}, + {"example.com,123,456, ping ", "example.com", "123", "456", "ping", false, false, false}, + {"*,123,456,", "*", "123", "456", "", false, false, false}, + {"123,456", "", "", "", "", false, false, true}, + {"123", "", "", "", "", false, false, true}, + {"example.com , 123, 456 ,ping", "example.com", "123", "456", "ping", false, false, false}, + {"example.com,123, assets:456, ping ", "example.com", "123", "456", "ping", true, false, false}, + {"example.com,123, assets:456 ", "example.com", "123", "456", "", true, false, false}, + {"example.com,123, spa:456 ", "example.com", "123", "456", "", true, true, false}, } for i, tt := range tbl { @@ -43,6 +44,7 @@ func TestStatic_List(t *testing.T) { assert.Equal(t, tt.ping, res[0].PingURL) if tt.static { assert.Equal(t, discovery.MTStatic, res[0].MatchType) + assert.Equal(t, tt.spa, res[0].AssetsSPA) } else { assert.Equal(t, discovery.MTProxy, res[0].MatchType) } diff --git a/app/discovery/provider/testdata/config.yml b/app/discovery/provider/testdata/config.yml index 5f3200c..d0662bb 100644 --- a/app/discovery/provider/testdata/config.yml +++ b/app/discovery/provider/testdata/config.yml @@ -1,6 +1,7 @@ default: - {route: "^/api/svc1/(.*)", dest: "http://127.0.0.1:8080/blah1/$1"} - {route: "/api/svc3/xyz", dest: "http://127.0.0.3:8080/blah3/xyz", "ping": "http://127.0.0.3:8080/ping"} - - {route: "/web/", dest: "/var/web", "static": yes} + - {route: "/web/", dest: "/var/web", "assets": yes} + - {route: "/web2/", dest: "/var/web2", "spa": yes} srv.example.com: - {route: "^/api/svc2/(.*)", dest: "http://127.0.0.2:8080/blah2/$1/abc"} diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index 7b25499..433e606 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -28,6 +28,7 @@ type Http struct { // nolint golint Address string AssetsLocation string AssetsWebRoot string + AssetsSPA bool MaxBodySize int64 GzEnabled bool ProxyHeaders []string @@ -242,14 +243,14 @@ func (h *Http) proxyHandler() http.HandlerFunc { } case discovery.MTStatic: - // static match result has webroot:location, i.e. /www:/var/somedir/ + // static match result has webroot:location:[spa:normal], i.e. /www:/var/somedir/:normal ae := strings.Split(match.Destination, ":") - if len(ae) != 2 { // shouldn't happen + if len(ae) != 3 { // shouldn't happen log.Printf("[WARN] unexpected static assets destination: %s", match.Destination) h.Reporter.Report(w, http.StatusInternalServerError) return } - fs, err := R.FileServer(ae[0], ae[1]) + fs, err := h.fileServer(ae[0], ae[1], ae[2] == "spa") if err != nil { log.Printf("[WARN] file server error, %v", err) h.Reporter.Report(w, http.StatusInternalServerError) @@ -288,7 +289,7 @@ func (h *Http) matchHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := r.URL.Hostname() if server == "" { - server = strings.Split(r.Host, ":")[0] + server = strings.Split(r.Host, ":")[0] // drop port } matches := h.Match(server, r.URL.Path) // get all matches for the server:path pair match, ok := getMatch(matches, h.LBSelector) @@ -317,7 +318,7 @@ func (h *Http) assetsHandler() http.HandlerFunc { return func(writer http.ResponseWriter, request *http.Request) {} } log.Printf("[DEBUG] shared assets server enabled for %s %s", h.AssetsWebRoot, h.AssetsLocation) - fs, err := R.FileServer(h.AssetsWebRoot, h.AssetsLocation) + fs, err := h.fileServer(h.AssetsWebRoot, h.AssetsLocation, h.AssetsSPA) if err != nil { log.Printf("[WARN] can't initialize assets server, %v", err) return func(writer http.ResponseWriter, request *http.Request) {} @@ -325,6 +326,13 @@ func (h *Http) assetsHandler() http.HandlerFunc { return h.CacheControl.Middleware(fs).ServeHTTP } +func (h *Http) fileServer(assetsWebRoot, assetsLocation string, spa bool) (http.Handler, error) { + if spa { + return R.FileServerSPA(assetsWebRoot, assetsLocation, nil) + } + return R.FileServer(assetsWebRoot, assetsLocation, nil) +} + func (h *Http) isAssetRequest(r *http.Request) bool { if h.AssetsLocation == "" || h.AssetsWebRoot == "" { return false diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index 7380a64..a1b31d0 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -204,6 +204,7 @@ func TestHttp_DoWithAssetRules(t *testing.T) { "localhost,^/api/(.*)," + ds.URL + "/123/$1,", "127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,", "*,/web,assets:testdata,", + "*,/web2,spa:testdata,", }, }}, time.Millisecond*10) @@ -220,6 +221,20 @@ func TestHttp_DoWithAssetRules(t *testing.T) { time.Sleep(10 * time.Millisecond) client := http.Client{} + { + resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/web2/nop.html") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + t.Logf("%+v", resp.Header) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "index html", string(body)) + assert.Equal(t, "", resp.Header.Get("App-Method")) + assert.Equal(t, "", resp.Header.Get("h1")) + assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) + } { req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", nil) @@ -251,6 +266,14 @@ func TestHttp_DoWithAssetRules(t *testing.T) { assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) } + + { + resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/web/nop.html") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + } + } func TestHttp_DoWithRedirects(t *testing.T) { diff --git a/app/proxy/testdata/index.html b/app/proxy/testdata/index.html new file mode 100644 index 0000000..e978caf --- /dev/null +++ b/app/proxy/testdata/index.html @@ -0,0 +1 @@ +index html \ No newline at end of file diff --git a/go.mod b/go.mod index 86117dd..6c2cea7 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.16 require ( github.com/go-pkgz/lgr v0.10.4 github.com/go-pkgz/repeater v1.1.3 - github.com/go-pkgz/rest v1.9.2 + github.com/go-pkgz/rest v1.10.0 github.com/gorilla/handlers v1.5.1 github.com/prometheus/client_golang v1.10.0 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index 86a0d47..29fef4a 100644 --- a/go.sum +++ b/go.sum @@ -73,6 +73,8 @@ github.com/go-pkgz/rest v1.9.2 h1:RyBBRXBYY6eBgTW3UGYOyT4VQPDiBBFh/tesELWsryQ= github.com/go-pkgz/rest v1.9.2/go.mod h1:wZ/dGipZUaF9to0vIQl7PwDHgWQDB0jsrFg1xnAKLDw= github.com/go-pkgz/rest v1.9.3-0.20210514184429-77a1bddb51db h1:PoIO+kDPc0A6m5xlRao4No1P9Ew4hdyZ4UFnX9fbanc= github.com/go-pkgz/rest v1.9.3-0.20210514184429-77a1bddb51db/go.mod h1:wZ/dGipZUaF9to0vIQl7PwDHgWQDB0jsrFg1xnAKLDw= +github.com/go-pkgz/rest v1.10.0 h1:4tkm8IrI+Gke2uyq/NYRwIZ1nnta17Q2LB1bUHNh7OQ= +github.com/go-pkgz/rest v1.10.0/go.mod h1:wZ/dGipZUaF9to0vIQl7PwDHgWQDB0jsrFg1xnAKLDw= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= diff --git a/vendor/github.com/go-pkgz/rest/.golangci.yml b/vendor/github.com/go-pkgz/rest/.golangci.yml index 831557a..4cd6d42 100644 --- a/vendor/github.com/go-pkgz/rest/.golangci.yml +++ b/vendor/github.com/go-pkgz/rest/.golangci.yml @@ -44,7 +44,7 @@ linters: - varcheck - stylecheck - gochecknoinits - - scopelint + - exportloopref - gocritic - nakedret - gosimple diff --git a/vendor/github.com/go-pkgz/rest/file_server.go b/vendor/github.com/go-pkgz/rest/file_server.go index 8c981a0..fc6a15d 100644 --- a/vendor/github.com/go-pkgz/rest/file_server.go +++ b/vendor/github.com/go-pkgz/rest/file_server.go @@ -2,6 +2,8 @@ package rest import ( "fmt" + "io" + "io/ioutil" "net/http" "os" "path/filepath" @@ -12,7 +14,8 @@ import ( // prevents directory listing. // - public defines base path of the url, i.e. for http://example.com/static/* it should be /static // - local for the local path to the root of the served directory -func FileServer(public, local string) (http.Handler, error) { +// - notFound is the reader for the custom 404 html, can be nil for default +func FileServer(public, local string, notFound io.Reader) (http.Handler, error) { root, err := filepath.Abs(local) if err != nil { @@ -22,15 +25,38 @@ func FileServer(public, local string) (http.Handler, error) { return nil, fmt.Errorf("local path %s doesn't exist: %w", root, err) } - return http.StripPrefix(public, http.FileServer(noDirListingFS{http.Dir(root)})), nil + fs := http.StripPrefix(public, http.FileServer(noDirListingFS{http.Dir(root), false})) + return custom404Handler(fs, notFound) } -type noDirListingFS struct{ fs http.FileSystem } +// FileServerSPA returns FileServer as above, but instead of no-found returns /local/index.html +func FileServerSPA(public, local string, notFound io.Reader) (http.Handler, error) { + + root, err := filepath.Abs(local) + if err != nil { + return nil, fmt.Errorf("can't get absolute path for %s: %w", local, err) + } + if _, err = os.Stat(root); os.IsNotExist(err) { + return nil, fmt.Errorf("local path %s doesn't exist: %w", root, err) + } + + fs := http.StripPrefix(public, http.FileServer(noDirListingFS{http.Dir(root), true})) + return custom404Handler(fs, notFound) +} + +type noDirListingFS struct { + fs http.FileSystem + spa bool +} // Open file on FS, for directory enforce index.html and fail on a missing index func (fs noDirListingFS) Open(name string) (http.File, error) { + f, err := fs.fs.Open(name) if err != nil { + if fs.spa { + return fs.fs.Open("/index.html") + } return nil, err } @@ -47,3 +73,38 @@ func (fs noDirListingFS) Open(name string) (http.File, error) { } return f, nil } + +// respWriter404 intercept Write to provide custom 404 response +type respWriter404 struct { + http.ResponseWriter + status int + msg []byte +} + +func (w *respWriter404) WriteHeader(status int) { + w.status = status + w.ResponseWriter.WriteHeader(status) +} + +func (w *respWriter404) Write(p []byte) (n int, err error) { + if w.status != http.StatusNotFound || w.msg == nil { + return w.ResponseWriter.Write(p) + } + _, err = w.ResponseWriter.Write(w.msg) + return len(p), err +} + +func custom404Handler(next http.Handler, notFound io.Reader) (http.Handler, error) { + if notFound == nil { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r) }), nil + } + + body, err := ioutil.ReadAll(notFound) + if err != nil { + return nil, err + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(&respWriter404{ResponseWriter: w, msg: body}, r) + }), nil +} diff --git a/vendor/github.com/go-pkgz/rest/sizelimit.go b/vendor/github.com/go-pkgz/rest/sizelimit.go index 1b6be67..6770e88 100644 --- a/vendor/github.com/go-pkgz/rest/sizelimit.go +++ b/vendor/github.com/go-pkgz/rest/sizelimit.go @@ -26,6 +26,8 @@ func SizeLimit(size int64) func(http.Handler) http.Handler { w.WriteHeader(http.StatusServiceUnavailable) return } + _ = r.Body.Close() // the original body already consumed + if int64(len(content)) > size { w.WriteHeader(http.StatusRequestEntityTooLarge) return diff --git a/vendor/modules.txt b/vendor/modules.txt index f45f7ac..62baaf8 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -13,7 +13,7 @@ github.com/go-pkgz/lgr ## explicit github.com/go-pkgz/repeater github.com/go-pkgz/repeater/strategy -# github.com/go-pkgz/rest v1.9.2 +# github.com/go-pkgz/rest v1.10.0 ## explicit github.com/go-pkgz/rest github.com/go-pkgz/rest/logger