diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index eea5e8d..281d4b9 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -138,15 +138,9 @@ func (h *Http) Run(ctx context.Context) error { gzipHandler(h.GzEnabled), // gzip response ) + // no FQDNs defined, use the list of discovered servers if len(h.SSLConfig.FQDNs) == 0 && h.SSLConfig.SSLMode == SSLAuto { - // discovery async and may happen not right away. Try to get servers for some time - for i := 0; i < 100; i++ { - h.SSLConfig.FQDNs = h.Servers() // fill all discovered if nothing defined - if len(h.SSLConfig.FQDNs) > 0 { - break - } - time.Sleep(50 * time.Millisecond) - } + h.SSLConfig.FQDNs = h.discoveredServers(ctx, 50*time.Millisecond) } switch h.SSLConfig.SSLMode { @@ -423,3 +417,26 @@ func (h *Http) setXRealIP(r *http.Request) { } r.Header.Add("X-Real-IP", ip) } + +// discoveredServers gets the list of servers discovered by providers. +// The underlying discovery is async and may happen not right away. +// We should try to get servers for some time and make sure we have the complete list of servers +// by checking if the number of servers has not changed between two calls. +func (h *Http) discoveredServers(ctx context.Context, interval time.Duration) (servers []string) { + discoveredServers := 0 + + for i := 0; i < 100; i++ { + select { + case <-ctx.Done(): + return nil + default: + } + servers = h.Servers() // fill all discovered if nothing defined + if len(servers) > 0 && len(servers) == discoveredServers { + break + } + discoveredServers = len(servers) + time.Sleep(interval) + } + return servers +} diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index 7b18716..5fb6c2d 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -890,3 +890,27 @@ func TestHttp_matchHandler(t *testing.T) { }) } } + +func TestHttp_discoveredServers(t *testing.T) { + + calls := 0 + m := &MatcherMock{ServersFunc: func() []string { + defer func() { calls++ }() + switch calls { + case 0, 1, 2, 3, 4: + return []string{} + case 5: + return []string{"s1", "s2"} + case 6, 7: + return []string{"s1", "s2", "s3"} + default: + t.Fatalf("shoudn't be called %d times", calls) + return nil + } + }} + + h := Http{Matcher: m} + + res := h.discoveredServers(context.Background(), time.Millisecond) + assert.Equal(t, []string{"s1", "s2", "s3"}, res) +}