Added a unit-test for custom upstreams

This commit is contained in:
Andrey Meshkov 2020-05-14 12:57:41 +03:00
parent 67a39045fc
commit ae51de9335
3 changed files with 44 additions and 14 deletions

View File

@ -249,6 +249,39 @@ func TestBlockedRequest(t *testing.T) {
} }
} }
func TestServerCustomClientUpstream(t *testing.T) {
s := createTestServer(t)
err := s.Start()
if err != nil {
t.Fatalf("Failed to start server: %s", err)
}
s.conf.GetCustomUpstreamByClient = func(clientAddr string) *proxy.UpstreamConfig {
uc := &proxy.UpstreamConfig{}
u := &testUpstream{}
u.ipv4 = map[string][]net.IP{}
u.ipv4["host."] = []net.IP{net.ParseIP("192.168.0.1")}
uc.Upstreams = append(uc.Upstreams, u)
return uc
}
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// Send test request
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: "host.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
reply, err := dns.Exchange(&req, addr.String())
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.NotNil(t, reply.Answer)
assert.Equal(t, "192.168.0.1", reply.Answer[0].(*dns.A).A.String())
assert.Nil(t, s.Stop())
}
// testUpstream is a mock of real upstream. // testUpstream is a mock of real upstream.
// specify fields with necessary values to simulate real upstream behaviour // specify fields with necessary values to simulate real upstream behaviour
type testUpstream struct { type testUpstream struct {

View File

@ -1,14 +0,0 @@
package dnsforward
import "net"
// GetIPString is a helper function that extracts IP address from net.Addr
func GetIPString(addr net.Addr) string {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP.String()
case *net.TCPAddr:
return addr.IP.String()
}
return ""
}

View File

@ -8,6 +8,17 @@ import (
"github.com/AdguardTeam/golibs/utils" "github.com/AdguardTeam/golibs/utils"
) )
// GetIPString is a helper function that extracts IP address from net.Addr
func GetIPString(addr net.Addr) string {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP.String()
case *net.TCPAddr:
return addr.IP.String()
}
return ""
}
func stringArrayDup(a []string) []string { func stringArrayDup(a []string) []string {
a2 := make([]string, len(a)) a2 := make([]string, len(a))
copy(a2, a) copy(a2, a)