reproxy/app/plugin/conductor_test.go
2023-02-27 14:51:25 -06:00

424 lines
14 KiB
Go

package plugin
import (
"bytes"
"context"
"encoding/json"
"errors"
"math/rand"
"net/http"
"net/http/httptest"
"regexp"
"strconv"
"testing"
"time"
log "github.com/go-pkgz/lgr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/umputun/reproxy/app/discovery"
"github.com/umputun/reproxy/lib"
)
// TestConductor_registrationHandler drives the plugin registration handler through a
// sequence of requests (register, identical re-register, address change, second plugin,
// malformed body, unsupported method, unregister, malformed unregister) and checks the
// resulting c.plugins entries, the response status codes and how many times the RPC
// dialer / RPC client were invoked.
func TestConductor_registrationHandler(t *testing.T) {
	rpcClient := &RPCClientMock{
		CallFunc: func(serviceMethod string, args interface{}, reply interface{}) error {
			return nil
		},
	}
	dialer := &RPCDialerMock{
		DialFunc: func(network string, address string) (RPCClient, error) {
			return rpcClient, nil
		},
	}
	c := Conductor{RPCDialer: dialer}
	ts := httptest.NewServer(c.registrationHandler())
	defer ts.Close()
	client := http.Client{Timeout: time.Second}

	// send issues a request with the given method and raw body to the test server and
	// returns the response status code. The body is closed before returning so no
	// connection is leaked between steps.
	send := func(method string, body []byte) int {
		t.Helper()
		req, err := http.NewRequest(method, ts.URL, bytes.NewReader(body))
		require.NoError(t, err)
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		return resp.StatusCode
	}
	// sendPlugin JSON-encodes plugin and sends it with the given method.
	sendPlugin := func(method string, plugin lib.Plugin) int {
		t.Helper()
		data, err := json.Marshal(plugin)
		require.NoError(t, err)
		return send(method, data)
	}

	{ // register plugin with two methods
		plugin := lib.Plugin{Name: "Test1", Address: "127.0.0.1:0001", Methods: []string{"Mw1", "Mw2"}}
		assert.Equal(t, http.StatusOK, sendPlugin(http.MethodPost, plugin))
		// require, not assert: the entries are indexed right below
		require.Len(t, c.plugins, 2, "two plugins registered")
		assert.Equal(t, "Test1.Mw1", c.plugins[0].Method)
		assert.Equal(t, "127.0.0.1:0001", c.plugins[0].Address)
		assert.True(t, c.plugins[0].Alive)
		assert.Equal(t, "127.0.0.1:0001", c.plugins[1].Address)
		assert.Equal(t, "Test1.Mw2", c.plugins[1].Method)
		assert.True(t, c.plugins[1].Alive)
		assert.Equal(t, 0, len(rpcClient.CallCalls()))
		assert.Equal(t, 1, len(dialer.DialCalls()))
	}
	{ // same registration, no new entries and no new dial
		plugin := lib.Plugin{Name: "Test1", Address: "127.0.0.1:0001", Methods: []string{"Mw1", "Mw2"}}
		assert.Equal(t, http.StatusOK, sendPlugin(http.MethodPost, plugin))
		assert.Equal(t, 2, len(c.plugins), "two plugins registered")
		assert.Equal(t, 0, len(rpcClient.CallCalls()))
		assert.Equal(t, 1, len(dialer.DialCalls()))
	}
	{ // address changed, entries updated and the plugin dialed again
		plugin := lib.Plugin{Name: "Test1", Address: "127.0.0.2:8002", Methods: []string{"Mw1", "Mw2"}}
		assert.Equal(t, http.StatusOK, sendPlugin(http.MethodPost, plugin))
		require.Len(t, c.plugins, 2, "two plugins registered")
		assert.Equal(t, "Test1.Mw1", c.plugins[0].Method)
		assert.Equal(t, "127.0.0.2:8002", c.plugins[0].Address)
		assert.True(t, c.plugins[0].Alive)
		assert.Equal(t, "127.0.0.2:8002", c.plugins[1].Address)
		assert.Equal(t, "Test1.Mw2", c.plugins[1].Method)
		assert.True(t, c.plugins[1].Alive)
		assert.Equal(t, 0, len(rpcClient.CallCalls()))
		assert.Equal(t, 2, len(dialer.DialCalls()))
	}
	{ // register a second plugin with three methods
		plugin := lib.Plugin{Name: "Test2", Address: "127.0.0.3:8003", Methods: []string{"Mw11", "Mw12", "Mw13"}}
		assert.Equal(t, http.StatusOK, sendPlugin(http.MethodPost, plugin))
		require.Len(t, c.plugins, 2+3, "3 more plugins registered")
		assert.Equal(t, "Test2.Mw11", c.plugins[2].Method)
		assert.Equal(t, "127.0.0.3:8003", c.plugins[2].Address)
		assert.True(t, c.plugins[2].Alive)
		assert.Equal(t, 0, len(rpcClient.CallCalls()))
		assert.Equal(t, 3, len(dialer.DialCalls()))
	}
	{ // bad registration
		assert.Equal(t, http.StatusBadRequest, send(http.MethodPost, []byte("bas json body")))
	}
	{ // unsupported registration method
		plugin := lib.Plugin{Name: "Test2", Address: "127.0.0.3:8003", Methods: []string{"Mw11", "Mw12", "Mw13"}}
		assert.Equal(t, http.StatusBadRequest, sendPlugin(http.MethodPut, plugin))
	}
	{ // unregister
		plugin := lib.Plugin{Name: "Test1", Address: "127.0.0.2:8002", Methods: []string{"Mw1", "Mw2"}}
		assert.Equal(t, http.StatusOK, sendPlugin(http.MethodDelete, plugin))
		require.Len(t, c.plugins, 3, "3 plugins left, 2 removed")
		assert.Equal(t, "Test2.Mw11", c.plugins[0].Method)
		assert.Equal(t, "127.0.0.3:8003", c.plugins[0].Address)
		assert.True(t, c.plugins[0].Alive)
		assert.Equal(t, 0, len(rpcClient.CallCalls()))
		assert.Equal(t, 3, len(dialer.DialCalls()))
	}
	{ // bad unregister
		assert.Equal(t, http.StatusBadRequest, send(http.MethodDelete, []byte("bad json body")))
		assert.Equal(t, 3, len(c.plugins), "still 3 plugins left, 2 removed")
	}
}
// TestConductor_registrationHandlerInternalError checks that the registration handler
// answers with 500 Internal Server Error when the RPC dialer returns an error.
func TestConductor_registrationHandlerInternalError(t *testing.T) {
	dialer := &RPCDialerMock{
		DialFunc: func(network string, address string) (RPCClient, error) {
			return nil, errors.New("failed")
		},
	}
	c := Conductor{RPCDialer: dialer}
	ts := httptest.NewServer(c.registrationHandler())
	defer ts.Close()
	client := http.Client{Timeout: time.Second}

	plugin := lib.Plugin{Name: "Test1", Address: "127.0.0.1:0001"}
	data, err := json.Marshal(plugin)
	require.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, ts.URL, bytes.NewReader(data))
	require.NoError(t, err)
	resp, err := client.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close() // release the client connection
	assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
}
// TestConductor_Middleware runs a conductor, registers a plugin with three methods via its
// HTTP endpoint and marks the third method as not alive. It then checks that Middleware:
//   - passes the matched route fields to the live plugin methods (Mw1, Mw2);
//   - never calls the dead one (Mw3);
//   - copies Mw1's HeadersIn onto the proxied request;
//   - ends up with Mw2's k11 response header (Mw2 sets OverrideHeadersOut) but not Mw1's k1.
func TestConductor_Middleware(t *testing.T) {
	rpcClient := &RPCClientMock{
		CallFunc: func(serviceMethod string, args interface{}, reply interface{}) error {
			switch serviceMethod {
			case "Test1.Mw1":
				req := args.(lib.Request)
				assert.Equal(t, "route123", req.Route)
				assert.Equal(t, "src123", req.Match.Src)
				assert.Equal(t, "dst123", req.Match.Dst)
				assert.Equal(t, "docker", req.Match.ProviderID)
				assert.Equal(t, "server123", req.Match.Server)
				assert.Equal(t, "proxy", req.Match.MatchType)
				assert.Equal(t, "/webroot", req.Match.AssetsWebRoot)
				assert.Equal(t, "loc", req.Match.AssetsLocation)
				log.Printf("rr: %+v", req)
				out := reply.(*lib.Response)
				out.StatusCode = 200
				out.HeadersOut = map[string][]string{}
				out.HeadersOut.Set("k1", "v1")
				out.HeadersIn = map[string][]string{}
				out.HeadersIn.Set("k21", "v21")
			case "Test1.Mw2":
				req := args.(lib.Request)
				assert.Equal(t, "route123", req.Route)
				assert.Equal(t, "src123", req.Match.Src)
				assert.Equal(t, "dst123", req.Match.Dst)
				assert.Equal(t, "docker", req.Match.ProviderID)
				assert.Equal(t, "server123", req.Match.Server)
				log.Printf("rr: %+v", req)
				out := reply.(*lib.Response)
				out.StatusCode = 200
				out.HeadersOut = map[string][]string{}
				out.HeadersOut.Set("k11", "v11")
				out.OverrideHeadersOut = true
			case "Test1.Mw3":
				t.Fatal("shouldn't be called")
			}
			return nil
		},
	}
	dialer := &RPCDialerMock{
		DialFunc: func(network string, address string) (RPCClient, error) {
			return rpcClient, nil
		},
	}
	c := Conductor{RPCDialer: dialer, Address: "127.0.0.1:50100"}
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	go func() {
		c.Run(ctx)
	}()
	// give Run a moment to start serving on c.Address before registering against it
	time.Sleep(time.Millisecond * 50)
	client := http.Client{Timeout: time.Second}

	// register plugin with 3 methods
	plugin := lib.Plugin{Name: "Test1", Address: "127.0.0.1:8001", Methods: []string{"Mw1", "Mw2", "Mw3"}}
	data, err := json.Marshal(plugin)
	require.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:50100", bytes.NewReader(data))
	require.NoError(t, err)
	resp, err := client.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	// require, not assert: c.plugins[2] is indexed right below
	require.Len(t, c.plugins, 3, "3 plugins registered")
	c.plugins[2].Alive = false // set 3rd to dead, Mw3 must not be called

	rr, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", http.NoBody)
	require.NoError(t, err)
	m := discovery.MatchedRoute{
		Destination: "route123",
		Mapper: discovery.URLMapper{
			Server:         "server123",
			ProviderID:     discovery.PIDocker,
			MatchType:      discovery.MTProxy,
			SrcMatch:       *regexp.MustCompile("src123"),
			Dst:            "dst123",
			AssetsWebRoot:  "/webroot",
			AssetsLocation: "loc",
		},
	}
	rr = rr.WithContext(context.WithValue(rr.Context(), CtxMatch, m))
	w := httptest.NewRecorder()
	h := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("k2", "v2")
		w.Write([]byte("something"))
		assert.Equal(t, "v21", r.Header.Get("k21"))
	}))
	h.ServeHTTP(w, rr)

	res := w.Result()
	assert.Equal(t, http.StatusOK, res.StatusCode)
	assert.Equal(t, "", res.Header.Get("k1"))
	assert.Equal(t, "v2", res.Header.Get("k2"))
	assert.Equal(t, "v21", rr.Header.Get("k21"))
	assert.Equal(t, "v11", res.Header.Get("k11"))
	t.Logf("req: %+v", rr)
	t.Logf("resp: %+v", res)
}
// TestConductor_MiddlewarePluginBadStatus checks two things when a plugin method replies
// with status 404: Middleware responds with that status, and the wrapped handler is
// never invoked.
func TestConductor_MiddlewarePluginBadStatus(t *testing.T) {
	rpcClient := &RPCClientMock{
		CallFunc: func(serviceMethod string, args interface{}, reply interface{}) error {
			if serviceMethod == "Test1.Mw1" {
				req := args.(lib.Request)
				assert.Equal(t, "route123", req.Route)
				assert.Equal(t, "src123", req.Match.Src)
				assert.Equal(t, "dst123", req.Match.Dst)
				assert.Equal(t, "docker", req.Match.ProviderID)
				assert.Equal(t, "server123", req.Match.Server)
				log.Printf("rr: %+v", req)
				reply.(*lib.Response).StatusCode = 404
			}
			return nil
		},
	}
	dialer := &RPCDialerMock{
		DialFunc: func(network string, address string) (RPCClient, error) {
			return rpcClient, nil
		},
	}
	// random port in [30000, 60000) to reduce clashes with other tests' listeners
	addr := "127.0.0.1:" + strconv.Itoa(30000+rand.Intn(30000))
	c := Conductor{RPCDialer: dialer, Address: addr}
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	go func() {
		c.Run(ctx)
	}()
	// give Run a moment to start serving on c.Address before registering against it
	time.Sleep(time.Millisecond * 150)
	client := http.Client{Timeout: time.Second}

	// register plugin with one method
	plugin := lib.Plugin{Name: "Test1", Address: "127.0.0.1:8001", Methods: []string{"Mw1"}}
	data, err := json.Marshal(plugin)
	require.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, "http://"+addr, bytes.NewReader(data))
	require.NoError(t, err)
	resp, err := client.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, 1, len(c.plugins), "one plugin registered")

	rr, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", http.NoBody)
	require.NoError(t, err)
	m := discovery.MatchedRoute{
		Destination: "route123",
		Mapper: discovery.URLMapper{
			Server:         "server123",
			ProviderID:     discovery.PIDocker,
			MatchType:      discovery.MTProxy,
			SrcMatch:       *regexp.MustCompile("src123"),
			Dst:            "dst123",
			AssetsWebRoot:  "/webroot",
			AssetsLocation: "loc",
		},
	}
	rr = rr.WithContext(context.WithValue(rr.Context(), CtxMatch, m))
	w := httptest.NewRecorder()
	h := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Failed() only reports state; t.Error actually fails the test if reached
		t.Error("handler must not be called on plugin middleware error")
	}))
	h.ServeHTTP(w, rr)

	res := w.Result()
	assert.Equal(t, http.StatusNotFound, res.StatusCode)
	assert.Equal(t, "", rr.Header.Get("k1")) // header not set by plugin on error
	t.Logf("req: %+v", rr)
	t.Logf("resp: %+v", res)
}
// TestConductor_MiddlewarePluginFailed checks two things when the plugin RPC call itself
// returns an error: Middleware responds with 500, and the wrapped handler is never
// invoked.
func TestConductor_MiddlewarePluginFailed(t *testing.T) {
	rpcClient := &RPCClientMock{
		CallFunc: func(serviceMethod string, args interface{}, reply interface{}) error {
			if serviceMethod == "Test1.Mw1" {
				return errors.New("something failed")
			}
			return nil
		},
	}
	dialer := &RPCDialerMock{
		DialFunc: func(network string, address string) (RPCClient, error) {
			return rpcClient, nil
		},
	}
	c := Conductor{RPCDialer: dialer, Address: "127.0.0.1:50100"}
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	go func() {
		c.Run(ctx)
	}()
	// give Run a moment to start serving on c.Address before registering against it
	time.Sleep(time.Millisecond * 250)
	client := http.Client{Timeout: time.Second}

	// register plugin with one method
	plugin := lib.Plugin{Name: "Test1", Address: "127.0.0.1:8001", Methods: []string{"Mw1"}}
	data, err := json.Marshal(plugin)
	require.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:50100", bytes.NewReader(data))
	require.NoError(t, err)
	resp, err := client.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, 1, len(c.plugins), "one plugin registered")

	rr, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", http.NoBody)
	require.NoError(t, err)
	w := httptest.NewRecorder()
	h := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Failed() only reports state; t.Error actually fails the test if reached
		t.Error("handler must not be called on plugin middleware error")
	}))
	h.ServeHTTP(w, rr)

	res := w.Result()
	assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
	assert.Equal(t, "", rr.Header.Get("k1")) // header not set by plugin on error
	t.Logf("req: %+v", rr)
	t.Logf("resp: %+v", res)
}