reproxy/app/proxy/proxy_test.go

917 lines
28 KiB
Go

package proxy
import (
"bytes"
"context"
"fmt"
"io"
"math/rand"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/umputun/reproxy/app/discovery"
"github.com/umputun/reproxy/app/discovery/provider"
"github.com/umputun/reproxy/app/mgmt"
)
// TestHttp_Do starts the proxy on a random local port in front of a stub upstream
// and checks host-based routing, response headers, the HTML 502 page returned for
// an unmatched path, and that an escaped request path reaches the upstream unchanged.
func TestHttp_Do(t *testing.T) {
	port := rand.Intn(10000) + 40000
	h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
		AccessLog: io.Discard, Signature: true, ProxyHeaders: []string{"hh1:vv1", "hh2:vv2"}, StdOutEnabled: true,
		Reporter: &ErrorReporter{Nice: true}}
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	// stub upstream: echoes the URL it was asked for
	ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("req: %v", r)
		w.Header().Add("h1", "v1")
		// assert (not require): the handler runs on a server goroutine, where FailNow must not be called
		assert.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
		fmt.Fprintf(w, "response %s", r.URL.String())
	}))
	defer ds.Close() // release the upstream listener and its goroutines

	svc := discovery.NewService([]discovery.Provider{
		&provider.Static{Rules: []string{
			"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
			"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
		},
		}}, time.Millisecond*10)

	go func() {
		// bound discovery to the test context so the goroutine does not outlive the test
		_ = svc.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give discovery time to load the static rules

	h.Matcher = svc
	h.Metrics = mgmt.NewMetrics()
	go func() {
		_ = h.Run(ctx)
	}()
	time.Sleep(10 * time.Millisecond) // give the proxy time to start listening

	client := http.Client{}

	// request by 127.0.0.1 is expected to hit the /567 rule and carry all headers
	{
		req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "response /567/something", string(body))
		assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
		assert.Equal(t, "v1", resp.Header.Get("h1"))
		assert.Equal(t, "vv1", resp.Header.Get("hh1"))
		assert.Equal(t, "vv2", resp.Header.Get("hh2"))
	}

	// request by localhost is expected to hit the /123 rule
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/something")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "response /123/something", string(body))
		assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
		assert.Equal(t, "v1", resp.Header.Get("h1"))
	}

	// unmatched path: 502 with an HTML error page
	{
		resp, err := client.Get("http://127.0.0.1:" + strconv.Itoa(port) + "/bad/something")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusBadGateway, resp.StatusCode)

		b, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Contains(t, string(b), "Sorry for the inconvenience")
		assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
	}

	// percent-encoded and special characters in the path must be passed through as is
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "response /123/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21", string(body))
		assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
		assert.Equal(t, "v1", resp.Header.Get("h1"))
	}
}
// TestHttp_DoWithAssets runs the proxy with a static assets location mounted on /static
// and checks proxied requests, served assets with Cache-Control, the default 404 body
// for a missing asset and the plain-text 502 for an unmatched path.
func TestHttp_DoWithAssets(t *testing.T) {
	port := rand.Intn(10000) + 40000
	cc := NewCacheControl(time.Hour * 12)
	h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
		AccessLog: io.Discard, AssetsWebRoot: "/static", AssetsLocation: "testdata", CacheControl: cc, Reporter: &ErrorReporter{Nice: false}}
	ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
	defer cancel()

	// stub upstream: echoes the URL it was asked for
	ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("req: %v", r)
		w.Header().Add("h1", "v1")
		// assert (not require): the handler runs on a server goroutine, where FailNow must not be called
		assert.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
		fmt.Fprintf(w, "response %s", r.URL.String())
	}))
	defer ds.Close() // release the upstream listener and its goroutines

	svc := discovery.NewService([]discovery.Provider{
		&provider.Static{Rules: []string{
			"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
			"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
		},
		}}, time.Millisecond*10)

	go func() {
		// bound discovery to the test context so the goroutine does not outlive the test
		_ = svc.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give discovery time to load the static rules

	h.Matcher = svc
	h.Metrics = mgmt.NewMetrics()
	go func() {
		_ = h.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give the proxy time to start listening

	client := http.Client{}

	// proxied request still works with assets enabled
	{
		req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "response /567/something", string(body))
		assert.Equal(t, "", resp.Header.Get("App-Method"))
		assert.Equal(t, "v1", resp.Header.Get("h1"))
	}

	// existing asset: served from testdata with 12h max-age and no upstream headers
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "test html", string(body))
		assert.Equal(t, "", resp.Header.Get("App-Method"))
		assert.Equal(t, "", resp.Header.Get("h1"))
		assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
	}

	// missing asset: default 404 body
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusNotFound, resp.StatusCode)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "404 page not found\n", string(body))
	}

	// path matching neither assets nor rules: plain-text 502
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusBadGateway, resp.StatusCode)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Contains(t, string(body), "Server error")
		assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
	}
}
// TestHttp_DoWithAssetsCustom404 runs the proxy with static assets and a custom 404 page
// (Assets404) and checks that missing assets get the custom page as text/html,
// while proxied requests and existing assets keep working.
func TestHttp_DoWithAssetsCustom404(t *testing.T) {
	port := rand.Intn(10000) + 40000
	cc := NewCacheControl(time.Hour * 12)
	h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
		AccessLog: io.Discard, AssetsWebRoot: "/static", AssetsLocation: "testdata", Assets404: "404.html",
		CacheControl: cc, Reporter: &ErrorReporter{Nice: false}}
	ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
	defer cancel()

	// stub upstream: echoes the URL it was asked for
	ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("req: %v", r)
		w.Header().Add("h1", "v1")
		// assert (not require): the handler runs on a server goroutine, where FailNow must not be called
		assert.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
		fmt.Fprintf(w, "response %s", r.URL.String())
	}))
	defer ds.Close() // release the upstream listener and its goroutines

	svc := discovery.NewService([]discovery.Provider{
		&provider.Static{Rules: []string{
			"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
			"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
		},
		}}, time.Millisecond*10)

	go func() {
		// bound discovery to the test context so the goroutine does not outlive the test
		_ = svc.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give discovery time to load the static rules

	h.Matcher = svc
	h.Metrics = mgmt.NewMetrics()
	go func() {
		_ = h.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give the proxy time to start listening

	client := http.Client{}

	// proxied request still works with assets enabled
	{
		req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "response /567/something", string(body))
		assert.Equal(t, "", resp.Header.Get("App-Method"))
		assert.Equal(t, "v1", resp.Header.Get("h1"))
	}

	// existing asset: served from testdata with 12h max-age
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "test html", string(body))
		assert.Equal(t, "", resp.Header.Get("App-Method"))
		assert.Equal(t, "", resp.Header.Get("h1"))
		assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
	}

	// missing asset: custom 404 page
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusNotFound, resp.StatusCode)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body))
		t.Logf("%+v", resp.Header)
		assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
	}

	// a second missing asset gets the same custom page
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad2.html")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusNotFound, resp.StatusCode)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body))
		t.Logf("%+v", resp.Header)
		assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
	}
}
// TestHttp_DoWithSpaAssets runs the proxy with static assets in SPA mode (AssetsSPA)
// and checks that a missing asset is answered with the index page and status 200,
// while proxied requests, existing assets and unmatched paths behave as usual.
func TestHttp_DoWithSpaAssets(t *testing.T) {
	port := rand.Intn(10000) + 40000
	cc := NewCacheControl(time.Hour * 12)
	h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
		AccessLog: io.Discard, AssetsWebRoot: "/static", AssetsLocation: "testdata", AssetsSPA: true,
		CacheControl: cc, Reporter: &ErrorReporter{Nice: false}}
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	// stub upstream: echoes the URL it was asked for
	ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("req: %v", r)
		w.Header().Add("h1", "v1")
		// assert (not require): the handler runs on a server goroutine, where FailNow must not be called
		assert.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
		fmt.Fprintf(w, "response %s", r.URL.String())
	}))
	defer ds.Close() // release the upstream listener and its goroutines

	svc := discovery.NewService([]discovery.Provider{
		&provider.Static{Rules: []string{
			"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
			"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
		},
		}}, time.Millisecond*10)

	go func() {
		// bound discovery to the test context so the goroutine does not outlive the test
		_ = svc.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give discovery time to load the static rules

	h.Matcher = svc
	h.Metrics = mgmt.NewMetrics()
	go func() {
		_ = h.Run(ctx)
	}()
	time.Sleep(10 * time.Millisecond) // give the proxy time to start listening

	client := http.Client{}

	// proxied request still works with assets enabled
	{
		req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "response /567/something", string(body))
		assert.Equal(t, "", resp.Header.Get("App-Method"))
		assert.Equal(t, "v1", resp.Header.Get("h1"))
	}

	// existing asset: served from testdata with 12h max-age
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "test html", string(body))
		assert.Equal(t, "", resp.Header.Get("App-Method"))
		assert.Equal(t, "", resp.Header.Get("h1"))
		assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
	}

	// missing asset in SPA mode: index page with status 200
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "index html", string(body))
		assert.Equal(t, "", resp.Header.Get("App-Method"))
		assert.Equal(t, "", resp.Header.Get("h1"))
		assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
	}

	// path matching neither assets nor rules: plain-text 502
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusBadGateway, resp.StatusCode)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Contains(t, string(body), "Server error")
		assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
	}
}
// TestHttp_DoWithAssetRules checks assets configured through discovery rules
// ("assets:" and "spa:" destinations) rather than through Http fields: /web serves
// plain assets (404 on a missing file) and /web2 serves in SPA mode (index page on a missing file).
func TestHttp_DoWithAssetRules(t *testing.T) {
	port := rand.Intn(10000) + 40000
	cc := NewCacheControl(time.Hour * 12)
	h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
		AccessLog: io.Discard, CacheControl: cc, Reporter: &ErrorReporter{}}
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	// stub upstream: echoes the URL it was asked for
	ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("req: %v", r)
		w.Header().Add("h1", "v1")
		// assert (not require): the handler runs on a server goroutine, where FailNow must not be called
		assert.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
		fmt.Fprintf(w, "response %s", r.URL.String())
	}))
	defer ds.Close() // release the upstream listener and its goroutines

	svc := discovery.NewService([]discovery.Provider{
		&provider.Static{Rules: []string{
			"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
			"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
			"*,/web,assets:testdata,",
			"*,/web2,spa:testdata,",
		},
		}}, time.Millisecond*10)

	go func() {
		// bound discovery to the test context so the goroutine does not outlive the test
		_ = svc.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give discovery time to load the static rules

	h.Matcher = svc
	h.Metrics = mgmt.NewMetrics()
	go func() {
		_ = h.Run(ctx)
	}()
	time.Sleep(10 * time.Millisecond) // give the proxy time to start listening

	client := http.Client{}

	// spa rule, missing file: index page with status 200
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/web2/nop.html")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "index html", string(body))
		assert.Equal(t, "", resp.Header.Get("App-Method"))
		assert.Equal(t, "", resp.Header.Get("h1"))
		assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
	}

	// proxy rule still works alongside asset rules
	{
		req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "response /567/something", string(body))
		assert.Equal(t, "", resp.Header.Get("App-Method"))
		assert.Equal(t, "v1", resp.Header.Get("h1"))
	}

	// assets rule, existing file
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/web/1.html")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "test html", string(body))
		assert.Equal(t, "", resp.Header.Get("App-Method"))
		assert.Equal(t, "", resp.Header.Get("h1"))
		assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
	}

	// assets rule, missing file: 404
	{
		resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/web/nop.html")
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusNotFound, resp.StatusCode)
	}
}
// TestHttp_DoWithRedirects checks redirect rules: an "@perm" destination must produce
// a 301 and an "@302" destination a 302, each with the Location header built from the rule.
func TestHttp_DoWithRedirects(t *testing.T) {
	port := rand.Intn(10000) + 40000
	cc := NewCacheControl(time.Hour * 12)
	h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
		AccessLog: io.Discard, CacheControl: cc, Reporter: &ErrorReporter{}}
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	svc := discovery.NewService([]discovery.Provider{
		&provider.Static{Rules: []string{
			"localhost,^/api/(.*),@perm http://example.com/123/$1,",
			"127.0.0.1,^/api/(.*),@302 http://example.com/567/$1,",
		},
		}}, time.Millisecond*10)

	go func() {
		// bound discovery to the test context so the goroutine does not outlive the test
		_ = svc.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give discovery time to load the static rules

	h.Matcher = svc
	h.Metrics = mgmt.NewMetrics()
	go func() {
		_ = h.Run(ctx)
	}()
	time.Sleep(10 * time.Millisecond) // give the proxy time to start listening

	// do not follow redirects: the test inspects the redirect response itself
	client := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	// permanent redirect
	{
		req, err := http.NewRequest("GET", "http://localhost:"+strconv.Itoa(port)+"/api/something", http.NoBody)
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusMovedPermanently, resp.StatusCode)
		t.Logf("%+v", resp.Header)
		assert.Equal(t, "http://example.com/123/something", resp.Header.Get("Location"))
	}

	// temporary (302) redirect
	{
		req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusFound, resp.StatusCode)
		t.Logf("%+v", resp.Header)
		assert.Equal(t, "http://example.com/567/something", resp.Header.Get("Location"))
	}
}
// TestHttp_DoLimitedReq runs the proxy with MaxBodySize set to 10 and checks that
// a 7-byte POST body is proxied normally while a 14-byte body is rejected with 413.
func TestHttp_DoLimitedReq(t *testing.T) {
	port := rand.Intn(10000) + 40000
	h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
		AccessLog: io.Discard, Signature: true, ProxyHeaders: []string{"hh1:vv1", "hh2:vv2"}, StdOutEnabled: true,
		Reporter: &ErrorReporter{Nice: true}, MaxBodySize: 10}
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	// stub upstream: echoes the URL it was asked for
	ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("req: %v", r)
		w.Header().Add("h1", "v1")
		// assert (not require): the handler runs on a server goroutine, where FailNow must not be called
		assert.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
		fmt.Fprintf(w, "response %s", r.URL.String())
	}))
	defer ds.Close() // release the upstream listener and its goroutines

	svc := discovery.NewService([]discovery.Provider{
		&provider.Static{Rules: []string{
			"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
			"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
		},
		}}, time.Millisecond*10)

	go func() {
		// bound discovery to the test context so the goroutine does not outlive the test
		_ = svc.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give discovery time to load the static rules

	h.Matcher, h.Metrics = svc, mgmt.NewMetrics()
	go func() {
		_ = h.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give the proxy time to start listening

	client := http.Client{}

	// body within the limit: proxied as usual
	{
		req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "response /567/something", string(body))
		assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
		assert.Equal(t, "v1", resp.Header.Get("h1"))
		assert.Equal(t, "vv1", resp.Header.Get("hh1"))
		assert.Equal(t, "vv2", resp.Header.Get("hh2"))
	}

	// body over the limit: rejected
	{
		req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg1234567"))
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode)
	}
}
// TestHttp_health checks that the proxy answers GET /health with status 200 and the
// JSON status body, alongside a regular proxied request.
func TestHttp_health(t *testing.T) {
	port := rand.Intn(10000) + 40000
	h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
		AccessLog: io.Discard, Signature: true, ProxyHeaders: []string{"hh1:vv1", "hh2:vv2"}, StdOutEnabled: true,
		Reporter: &ErrorReporter{Nice: true}}
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	// stub upstream: echoes the URL it was asked for
	ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("req: %v", r)
		w.Header().Add("h1", "v1")
		// assert (not require): the handler runs on a server goroutine, where FailNow must not be called
		assert.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
		fmt.Fprintf(w, "response %s", r.URL.String())
	}))
	defer ds.Close() // release the upstream listener and its goroutines

	svc := discovery.NewService([]discovery.Provider{
		&provider.Static{Rules: []string{
			"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
			"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
			"*,/web,spa:testdata,",
		},
		}}, time.Millisecond*10)

	go func() {
		// bound discovery to the test context so the goroutine does not outlive the test
		_ = svc.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give discovery time to load the static rules

	h.Matcher, h.Metrics = svc, mgmt.NewMetrics()
	go func() {
		_ = h.Run(ctx)
	}()
	time.Sleep(10 * time.Millisecond) // give the proxy time to start listening

	client := http.Client{}

	// regular proxied request
	{
		req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		t.Logf("%+v", resp.Header)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, "response /567/something", string(body))
		assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
		assert.Equal(t, "v1", resp.Header.Get("h1"))
		assert.Equal(t, "vv1", resp.Header.Get("hh1"))
		assert.Equal(t, "vv2", resp.Header.Get("hh2"))
	}

	// health endpoint
	{
		req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/health", http.NoBody)
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		assert.Equal(t, `{"status": "ok", "services": 2}`, string(body))
	}
}
// TestHttp_withBasicAuth runs the proxy with basic auth enabled for two users and checks
// that requests without credentials or with a wrong password get 401, while each user's
// valid credentials get 200.
func TestHttp_withBasicAuth(t *testing.T) {
	port := rand.Intn(10000) + 40000
	h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
		AccessLog: io.Discard, Signature: true, ProxyHeaders: []string{"hh1:vv1", "hh2:vv2"}, StdOutEnabled: true,
		Reporter: &ErrorReporter{Nice: true}, BasicAuthEnabled: true, BasicAuthAllowed: []string{
			"test:$2y$05$zMxDmK65SjcH2vJQNopVSO/nE8ngVLx65RoETyHpez7yTS/8CLEiW",
			"test2:$2y$05$TLQqHh6VT4JxysdKGPOlJeSkkMsv.Ku/G45i7ssIm80XuouCrES12",
		}}
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	// stub upstream: echoes the URL it was asked for
	ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("req: %v", r)
		w.Header().Add("h1", "v1")
		// assert (not require): the handler runs on a server goroutine, where FailNow must not be called
		assert.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
		fmt.Fprintf(w, "response %s", r.URL.String())
	}))
	defer ds.Close() // release the upstream listener and its goroutines

	svc := discovery.NewService([]discovery.Provider{
		&provider.Static{Rules: []string{
			"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
			"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
			"*,/web,spa:testdata,",
		},
		}}, time.Millisecond*10)

	go func() {
		// bound discovery to the test context so the goroutine does not outlive the test
		_ = svc.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond) // give discovery time to load the static rules

	h.Matcher, h.Metrics = svc, mgmt.NewMetrics()
	go func() {
		_ = h.Run(ctx)
	}()
	time.Sleep(10 * time.Millisecond) // give the proxy time to start listening

	client := http.Client{}

	// no credentials
	{
		req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	}

	// known user, wrong password
	{
		req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
		require.NoError(t, err) // check before touching req, it is nil on error
		req.SetBasicAuth("test", "badpasswd")
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	}

	// first user, valid password
	{
		req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
		require.NoError(t, err) // check before touching req, it is nil on error
		req.SetBasicAuth("test", "passwd")
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
	}

	// second user, valid password
	{
		req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
		require.NoError(t, err) // check before touching req, it is nil on error
		req.SetBasicAuth("test2", "passwd2")
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
	}
}
// TestHttp_toHttp verifies that toHTTP keeps the host part of an address
// and replaces its port with the given one.
func TestHttp_toHttp(t *testing.T) {
	type testCase struct {
		address  string
		httpPort int
		expected string
	}
	cases := []testCase{
		{address: "localhost:1234", httpPort: 80, expected: "localhost:80"},
		{address: "m.example.com:443", httpPort: 8080, expected: "m.example.com:8080"},
		{address: "192.168.1.1:1443", httpPort: 8080, expected: "192.168.1.1:8080"},
	}

	srv := Http{}
	for idx := range cases {
		tc := cases[idx] // per-iteration copy captured by the subtest closure
		t.Run(strconv.Itoa(idx), func(t *testing.T) {
			actual := srv.toHTTP(tc.address, tc.httpPort)
			assert.Equal(t, tc.expected, actual)
		})
	}
}
func TestHttp_isAssetRequest(t *testing.T) {
tbl := []struct {
req string
assetsLocation string
assetsWebRoot string
res bool
}{
{"/static/123.html", "/tmp", "/static", true},
{"/static/123.html", "/tmp", "/static/", true},
{"/static", "/tmp", "/static", true},
{"/static/", "/tmp", "/static", true},
{"/bad/", "/tmp", "/static", false},
{"/static/", "", "/static", false},
{"/static/", "/tmp", "", false},
}
for i, tt := range tbl {
t.Run(strconv.Itoa(i), func(t *testing.T) {
h := Http{AssetsLocation: tt.assetsLocation, AssetsWebRoot: tt.assetsWebRoot}
r, err := http.NewRequest("GET", tt.req, http.NoBody)
require.NoError(t, err)
assert.Equal(t, tt.res, h.isAssetRequest(r))
})
}
}
// TestHttp_matchHandler feeds matchHandler with predefined match sets and checks which
// destination ends up in the request context (ctxURL): with the selector always returning 0
// the first alive route is expected, and no value at all when no route is alive.
func TestHttp_matchHandler(t *testing.T) {
	tbl := []struct {
		matches discovery.Matches
		res     string
		ok      bool
	}{
		{
			discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
				{Destination: "dest1", Alive: true},
				{Destination: "dest2", Alive: true},
				{Destination: "dest3", Alive: true},
			}},
			"dest1", true,
		},
		{
			discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
				{Destination: "dest1", Alive: false},
				{Destination: "dest2", Alive: true},
				{Destination: "dest3", Alive: false},
			}},
			"dest2", true,
		},
		{
			discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
				{Destination: "dest1", Alive: false},
				{Destination: "dest2", Alive: true},
				{Destination: "dest3", Alive: true},
			}},
			"dest2", true,
		},
		{
			discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
				{Destination: "dest1", Alive: false},
				{Destination: "dest2", Alive: false},
				{Destination: "dest3", Alive: true},
			}},
			"dest3", true,
		},
		{
			discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
				{Destination: "dest1", Alive: false},
				{Destination: "dest2", Alive: false},
				{Destination: "dest3", Alive: false},
			}},
			"", false,
		},
		{
			discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{}}, "", false,
		},
	}

	// count holds the index of the table entry the mock should currently return
	var count int32
	matcherMock := &MatcherMock{
		MatchFunc: func(srv string, src string) discovery.Matches {
			return tbl[atomic.LoadInt32(&count)].matches
		},
	}

	for i, tt := range tbl {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			// pin the mock to this case up front, so an early failure in one subtest
			// cannot shift the matches seen by the following ones
			atomic.StoreInt32(&count, int32(i))

			h := Http{Matcher: matcherMock, LBSelector: func(n int) int { return 0 }}
			handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				t.Logf("req: %+v", r)
				t.Logf("dst: %v", r.Context().Value(ctxURL))
				v := r.Context().Value(ctxURL)
				if v == nil {
					require.False(t, tt.ok)
					return
				}
				assert.Equal(t, tt.res, v.(*url.URL).String())
			}))

			req, err := http.NewRequest("GET", "http://example.com", http.NoBody)
			require.NoError(t, err)
			wr := httptest.NewRecorder()
			handler.ServeHTTP(wr, req)
			// check the handler's own response; no real request to example.com is made
			assert.Equal(t, http.StatusOK, wr.Code)
		})
	}
}
func TestHttp_discoveredServers(t *testing.T) {
calls := 0
m := &MatcherMock{ServersFunc: func() []string {
defer func() { calls++ }()
switch calls {
case 0, 1, 2, 3, 4:
return []string{}
case 5:
return []string{"s1", "s2"}
case 6, 7:
return []string{"s1", "s2", "s3"}
default:
t.Fatalf("shoudn't be called %d times", calls)
return nil
}
}}
h := Http{Matcher: m}
res := h.discoveredServers(context.Background(), time.Millisecond)
assert.Equal(t, []string{"s1", "s2", "s3"}, res)
}