treat 0 max request size limit as unlimited

This commit is contained in:
Umputun 2021-05-12 21:54:41 -05:00
parent dc39e2d090
commit a93bd40f8a
2 changed files with 76 additions and 1 deletions

View File

@ -106,7 +106,7 @@ func (h *Http) Run(ctx context.Context) error {
h.headersHandler(h.ProxyHeaders),
h.accessLogHandler(h.AccessLog),
h.stdoutLogHandler(h.StdOutEnabled, logger.New(logger.Log(log.Default()), logger.Prefix("[INFO]")).Handler),
R.SizeLimit(h.MaxBodySize),
h.maxReqSizeHandler(h.MaxBodySize),
h.gzipHandler(),
)
@ -335,6 +335,17 @@ func (h *Http) stdoutLogHandler(enable bool, lh func(next http.Handler) http.Han
}
}
func (h *Http) maxReqSizeHandler(maxSize int64) func(next http.Handler) http.Handler {
if maxSize <= 0 {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
})
}
}
return R.SizeLimit(maxSize)
}
func (h *Http) makeHTTPServer(addr string, router http.Handler) *http.Server {
return &http.Server{
Addr: addr,

View File

@ -1,6 +1,7 @@
package proxy
import (
"bytes"
"context"
"fmt"
"io"
@ -248,7 +249,70 @@ func TestHttp_DoWithAssetRules(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
}
}
func TestHttp_DoLimitedReq(t *testing.T) {
port := rand.Intn(10000) + 40000
h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
AccessLog: io.Discard, Signature: true, ProxyHeaders: []string{"hh1:vv1", "hh2:vv2"}, StdOutEnabled: true,
Reporter: &ErrorReporter{Nice: true}, MaxBodySize: 10}
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("req: %v", r)
w.Header().Add("h1", "v1")
require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
fmt.Fprintf(w, "response %s", r.URL.String())
}))
svc := discovery.NewService([]discovery.Provider{
&provider.Static{Rules: []string{
"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
},
}}, time.Millisecond*10)
go func() {
_ = svc.Run(context.Background())
}()
time.Sleep(50 * time.Millisecond)
h.Matcher, h.Metrics = svc, mgmt.NewMetrics()
go func() {
_ = h.Run(ctx)
}()
time.Sleep(10 * time.Millisecond)
client := http.Client{}
{
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
t.Logf("%+v", resp.Header)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "response /567/something", string(body))
assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
assert.Equal(t, "v1", resp.Header.Get("h1"))
assert.Equal(t, "vv1", resp.Header.Get("hh1"))
assert.Equal(t, "vv2", resp.Header.Get("hh2"))
}
{
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg1234567"))
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode)
}
}
func TestHttp_toHttp(t *testing.T) {