Pull request: 2574 external tests vol.3

Merge in DNS/adguard-home from 2574-external-tests-3 to master

Updates #2574.

Squashed commit of the following:

commit 29d429c65dee2621ca503710a7ba9522f14f55f9
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Feb 4 20:06:57 2021 +0300

    all: finally fix spacing

commit 9e3a3be63b74852a7802e3f1832648444b58e4d0
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Feb 4 19:59:09 2021 +0300

    aghtest: polish spacing

commit 8a984159fe813b95b989803f5b8b78d01a41bd39
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Feb 4 18:44:47 2021 +0300

    all: fix linux tests, imp code quality

commit 0c1b42bacba1b23fa847e1fa032579c525b3eaa1
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Feb 4 17:33:12 2021 +0300

    all: mv testutil to aghtest package, imp tests
This commit is contained in:
Eugene Burkov 2021-02-04 20:35:13 +03:00
parent 8aec08727c
commit c9d2436d77
24 changed files with 737 additions and 496 deletions

View File

@ -5,14 +5,9 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
func TestError_Error(t *testing.T) { func TestError_Error(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string

View File

@ -1,5 +1,5 @@
// Package testutil contains utilities for testing. // Package aghtest contains utilities for testing.
package testutil package aghtest
import ( import (
"io" "io"

View File

@ -0,0 +1,63 @@
package aghtest
import (
"context"
"crypto/sha256"
"net"
"sync"
)
// TestResolver is a Resolver for tests.
type TestResolver struct {
counter int
counterLock sync.Mutex
}
// HostToIPs generates IPv4 and IPv6 from host.
//
// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15.
func (r *TestResolver) HostToIPs(host string) (ipv4, ipv6 net.IP) {
hash := sha256.Sum256([]byte(host))
return net.IP(hash[:4]), net.IP(hash[4:20])
}
// LookupIPAddr implements Resolver interface for *testResolver. It returns the
// slice of net.IPAddr with IPv4 and IPv6 instances.
func (r *TestResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) {
ipv4, ipv6 := r.HostToIPs(host)
addrs := []net.IPAddr{{
IP: ipv4,
}, {
IP: ipv6,
}}
r.counterLock.Lock()
defer r.counterLock.Unlock()
r.counter++
return addrs, nil
}
// LookupHost implements Resolver interface for *testResolver. It returns the
// slice of IPv4 and IPv6 instances converted to strings.
func (r *TestResolver) LookupHost(host string) (addrs []string, err error) {
ipv4, ipv6 := r.HostToIPs(host)
r.counterLock.Lock()
defer r.counterLock.Unlock()
r.counter++
return []string{
ipv4.String(),
ipv6.String(),
}, nil
}
// Counter returns the number of requests handled.
func (r *TestResolver) Counter() int {
r.counterLock.Lock()
defer r.counterLock.Unlock()
return r.counter
}

View File

@ -0,0 +1,175 @@
package aghtest
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/miekg/dns"
)
// TestUpstream is a mock of real upstream.
type TestUpstream struct {
// Addr is the address for Address method.
Addr string
// CName is a map of hostname to canonical name.
CName map[string]string
// IPv4 is a map of hostname to IPv4.
IPv4 map[string][]net.IP
// IPv6 is a map of hostname to IPv6.
IPv6 map[string][]net.IP
// Reverse is a map of address to domain name.
Reverse map[string][]string
}
// Exchange implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{}
resp.SetReply(m)
if len(m.Question) == 0 {
return nil, fmt.Errorf("question should not be empty")
}
name := m.Question[0].Name
if cname, ok := u.CName[name]; ok {
resp.Answer = append(resp.Answer, &dns.CNAME{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeCNAME,
},
Target: cname,
})
}
var hasRec bool
var rrType uint16
var ips []net.IP
switch m.Question[0].Qtype {
case dns.TypeA:
rrType = dns.TypeA
if ipv4addr, ok := u.IPv4[name]; ok {
hasRec = true
ips = ipv4addr
}
case dns.TypeAAAA:
rrType = dns.TypeAAAA
if ipv6addr, ok := u.IPv6[name]; ok {
hasRec = true
ips = ipv6addr
}
case dns.TypePTR:
names, ok := u.Reverse[name]
if !ok {
break
}
for _, n := range names {
resp.Answer = append(resp.Answer, &dns.PTR{
Hdr: dns.RR_Header{
Name: name,
Rrtype: rrType,
},
Ptr: n,
})
}
}
for _, ip := range ips {
resp.Answer = append(resp.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: rrType,
},
A: ip,
})
}
if len(resp.Answer) == 0 {
if hasRec {
// Set no error RCode if there are some records for
// given Qname but we didn't apply them.
resp.SetRcode(m, dns.RcodeSuccess)
return resp, nil
}
// Set NXDomain RCode otherwise.
resp.SetRcode(m, dns.RcodeNameError)
}
return resp, nil
}
// Address implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Address() string {
return u.Addr
}
// TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestBlockUpstream struct {
Hostname string
Block bool
requestsCount int
lock sync.RWMutex
}
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.lock.Lock()
defer u.lock.Unlock()
u.requestsCount++
hash := sha256.Sum256([]byte(u.Hostname))
hashToReturn := hex.EncodeToString(hash[:])
if !u.Block {
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
}
m := &dns.Msg{}
m.Answer = []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: r.Question[0].Name,
},
Txt: []string{
hashToReturn,
},
},
}
return m, nil
}
// Address always returns an empty string.
func (u *TestBlockUpstream) Address() string {
return ""
}
// RequestsCount returns the number of handled requests. It's safe for
// concurrent use.
func (u *TestBlockUpstream) RequestsCount() int {
u.lock.Lock()
defer u.lock.Unlock()
return u.requestsCount
}
// TestErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestErrUpstream struct{}
// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, agherr.Error("bad")
}
// Address always returns an empty string.
func (u *TestErrUpstream) Address() string {
return ""
}

View File

@ -9,12 +9,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
func testNotify(flags uint32) { func testNotify(flags uint32) {

View File

@ -17,14 +17,14 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/hugelgupf/socketpair" "github.com/hugelgupf/socketpair"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4" "github.com/insomniacslk/dhcp/dhcpv4/server4"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
type handler struct { type handler struct {

View File

@ -43,6 +43,12 @@ type RequestFilteringSettings struct {
ServicesRules []ServiceEntry ServicesRules []ServiceEntry
} }
// Resolver is the interface for net.Resolver to simplify testing.
type Resolver interface {
// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15.
LookupIPAddr(ctx context.Context, host string) (ips []net.IPAddr, err error)
}
// Config allows you to configure DNS filtering with New() or just change variables directly. // Config allows you to configure DNS filtering with New() or just change variables directly.
type Config struct { type Config struct {
ParentalEnabled bool `yaml:"parental_enabled"` ParentalEnabled bool `yaml:"parental_enabled"`
@ -69,6 +75,9 @@ type Config struct {
// Register an HTTP handler // Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"` HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
// CustomResolver is the resolver used by DNSFilter.
CustomResolver Resolver
} }
// LookupStats store stats collected during safebrowsing or parental checks // LookupStats store stats collected during safebrowsing or parental checks
@ -92,12 +101,6 @@ type filtersInitializerParams struct {
blockFilters []Filter blockFilters []Filter
} }
// Resolver is the interface for net.Resolver to simplify testing.
type Resolver interface {
// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15.
LookupIPAddr(ctx context.Context, host string) (ips []net.IPAddr, err error)
}
// DNSFilter matches hostnames and DNS requests against filtering rules. // DNSFilter matches hostnames and DNS requests against filtering rules.
type DNSFilter struct { type DNSFilter struct {
rulesStorage *filterlist.RuleStorage rulesStorage *filterlist.RuleStorage
@ -796,6 +799,7 @@ func InitModule() {
// New creates properly initialized DNS Filter that is ready to be used. // New creates properly initialized DNS Filter that is ready to be used.
func New(c *Config, blockFilters []Filter) *DNSFilter { func New(c *Config, blockFilters []Filter) *DNSFilter {
var resolver Resolver = net.DefaultResolver
if c != nil { if c != nil {
cacheConf := cache.Config{ cacheConf := cache.Config{
EnableLRU: true, EnableLRU: true,
@ -815,10 +819,14 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
cacheConf.MaxSize = c.ParentalCacheSize cacheConf.MaxSize = c.ParentalCacheSize
gctx.parentalCache = cache.New(cacheConf) gctx.parentalCache = cache.New(cacheConf)
} }
if c.CustomResolver != nil {
resolver = c.CustomResolver
}
} }
d := &DNSFilter{ d := &DNSFilter{
resolver: net.DefaultResolver, resolver: resolver,
} }
err := d.initSecurityServices() err := d.initSecurityServices()

View File

@ -3,12 +3,11 @@ package dnsfilter
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/sha256"
"fmt" "fmt"
"net" "net"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
@ -17,7 +16,7 @@ import (
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
var setts RequestFilteringSettings var setts RequestFilteringSettings
@ -37,7 +36,9 @@ func purgeCaches() {
} }
func newForTest(c *Config, filters []Filter) *DNSFilter { func newForTest(c *Config, filters []Filter) *DNSFilter {
setts = RequestFilteringSettings{} setts = RequestFilteringSettings{
FilteringEnabled: true,
}
setts.FilteringEnabled = true setts.FilteringEnabled = true
if c != nil { if c != nil {
c.SafeBrowsingCacheSize = 10000 c.SafeBrowsingCacheSize = 10000
@ -149,16 +150,16 @@ func TestEtcHostsMatching(t *testing.T) {
func TestSafeBrowsing(t *testing.T) { func TestSafeBrowsing(t *testing.T) {
logOutput := &bytes.Buffer{} logOutput := &bytes.Buffer{}
testutil.ReplaceLogWriter(t, logOutput) aghtest.ReplaceLogWriter(t, logOutput)
testutil.ReplaceLogLevel(t, log.DEBUG) aghtest.ReplaceLogLevel(t, log.DEBUG)
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
matching := "wmconvirus.narod.ru" matching := "wmconvirus.narod.ru"
d.safeBrowsingUpstream = &testSbUpstream{ d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
hostname: matching, Hostname: matching,
block: true, Block: true,
} })
d.checkMatch(t, matching) d.checkMatch(t, matching)
assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching) assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
@ -178,10 +179,10 @@ func TestParallelSB(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
matching := "wmconvirus.narod.ru" matching := "wmconvirus.narod.ru"
d.safeBrowsingUpstream = &testSbUpstream{ d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
hostname: matching, Hostname: matching,
block: true, Block: true,
} })
t.Run("group", func(t *testing.T) { t.Run("group", func(t *testing.T) {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@ -228,26 +229,12 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
} }
} }
// testResolver is a Resolver for tests.
type testResolver struct{}
// LookupIP implements Resolver interface for *testResolver.
func (r *testResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) {
hash := sha256.Sum256([]byte(host))
addrs := []net.IPAddr{{
IP: net.IP(hash[:4]),
Zone: "somezone",
}, {
IP: net.IP(hash[4:20]),
Zone: "somezone",
}}
return addrs, nil
}
func TestCheckHostSafeSearchGoogle(t *testing.T) { func TestCheckHostSafeSearchGoogle(t *testing.T) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil) d := newForTest(&Config{
SafeSearchEnabled: true,
CustomResolver: &aghtest.TestResolver{},
}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.resolver = &testResolver{}
// Check host for each domain. // Check host for each domain.
for _, host := range []string{ for _, host := range []string{
@ -299,12 +286,12 @@ func TestSafeSearchCacheYandex(t *testing.T) {
} }
func TestSafeSearchCacheGoogle(t *testing.T) { func TestSafeSearchCacheGoogle(t *testing.T) {
d := newForTest(nil, nil) resolver := &aghtest.TestResolver{}
d := newForTest(&Config{
CustomResolver: resolver,
}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
resolver := &testResolver{}
d.resolver = resolver
domain := "www.google.ru" domain := "www.google.ru"
res, err := d.CheckHost(domain, dns.TypeA, &setts) res, err := d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err) assert.Nil(t, err)
@ -350,16 +337,16 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
func TestParentalControl(t *testing.T) { func TestParentalControl(t *testing.T) {
logOutput := &bytes.Buffer{} logOutput := &bytes.Buffer{}
testutil.ReplaceLogWriter(t, logOutput) aghtest.ReplaceLogWriter(t, logOutput)
testutil.ReplaceLogLevel(t, log.DEBUG) aghtest.ReplaceLogLevel(t, log.DEBUG)
d := newForTest(&Config{ParentalEnabled: true}, nil) d := newForTest(&Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
matching := "pornhub.com" matching := "pornhub.com"
d.parentalUpstream = &testSbUpstream{ d.SetParentalUpstream(&aghtest.TestBlockUpstream{
hostname: matching, Hostname: matching,
block: true, Block: true,
} })
d.checkMatch(t, matching) d.checkMatch(t, matching)
assert.Contains(t, logOutput.String(), "Parental lookup for "+matching) assert.Contains(t, logOutput.String(), "Parental lookup for "+matching)
@ -733,14 +720,14 @@ func TestClientSettings(t *testing.T) {
}}, }},
) )
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.parentalUpstream = &testSbUpstream{ d.SetParentalUpstream(&aghtest.TestBlockUpstream{
hostname: "pornhub.com", Hostname: "pornhub.com",
block: true, Block: true,
} })
d.safeBrowsingUpstream = &testSbUpstream{ d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
hostname: "wmconvirus.narod.ru", Hostname: "wmconvirus.narod.ru",
block: true, Block: true,
} })
type testCase struct { type testCase struct {
name string name string
@ -801,10 +788,10 @@ func BenchmarkSafeBrowsing(b *testing.B) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru" blocked := "wmconvirus.narod.ru"
d.safeBrowsingUpstream = &testSbUpstream{ d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
hostname: blocked, Hostname: blocked,
block: true, Block: true,
} })
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
res, err := d.CheckHost(blocked, dns.TypeA, &setts) res, err := d.CheckHost(blocked, dns.TypeA, &setts)
assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err) assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
@ -816,10 +803,10 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru" blocked := "wmconvirus.narod.ru"
d.safeBrowsingUpstream = &testSbUpstream{ d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
hostname: blocked, Hostname: blocked,
block: true, Block: true,
} })
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
res, err := d.CheckHost(blocked, dns.TypeA, &setts) res, err := d.CheckHost(blocked, dns.TypeA, &setts)

View File

@ -30,6 +30,20 @@ const (
pcTXTSuffix = `pc.dns.adguard.com.` pcTXTSuffix = `pc.dns.adguard.com.`
) )
// SetParentalUpstream sets the parental upstream for *DNSFilter.
//
// TODO(e.burkov): Remove this in v1 API to forbid the direct access.
func (d *DNSFilter) SetParentalUpstream(u upstream.Upstream) {
d.parentalUpstream = u
}
// SetSafeBrowsingUpstream sets the safe browsing upstream for *DNSFilter.
//
// TODO(e.burkov): Remove this in v1 API to forbid the direct access.
func (d *DNSFilter) SetSafeBrowsingUpstream(u upstream.Upstream) {
d.safeBrowsingUpstream = u
}
func (d *DNSFilter) initSecurityServices() error { func (d *DNSFilter) initSecurityServices() error {
var err error var err error
d.safeBrowsingServer = defaultSafebrowsingServer d.safeBrowsingServer = defaultSafebrowsingServer
@ -44,15 +58,17 @@ func (d *DNSFilter) initSecurityServices() error {
}, },
} }
d.parentalUpstream, err = upstream.AddressToUpstream(d.parentalServer, opts) parUps, err := upstream.AddressToUpstream(d.parentalServer, opts)
if err != nil { if err != nil {
return fmt.Errorf("converting parental server: %w", err) return fmt.Errorf("converting parental server: %w", err)
} }
d.SetParentalUpstream(parUps)
d.safeBrowsingUpstream, err = upstream.AddressToUpstream(d.safeBrowsingServer, opts) sbUps, err := upstream.AddressToUpstream(d.safeBrowsingServer, opts)
if err != nil { if err != nil {
return fmt.Errorf("converting safe browsing server: %w", err) return fmt.Errorf("converting safe browsing server: %w", err)
} }
d.SetSafeBrowsingUpstream(sbUps)
return nil return nil
} }
@ -227,7 +243,7 @@ func (c *sbCtx) processTXT(resp *dns.Msg) (bool, [][]byte) {
func (c *sbCtx) storeCache(hashes [][]byte) { func (c *sbCtx) storeCache(hashes [][]byte) {
sort.Slice(hashes, func(a, b int) bool { sort.Slice(hashes, func(a, b int) bool {
return bytes.Compare(hashes[a], hashes[b]) < 0 return bytes.Compare(hashes[a], hashes[b]) == -1
}) })
var curData []byte var curData []byte

View File

@ -2,14 +2,11 @@ package dnsfilter
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"strings" "strings"
"sync"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -108,27 +105,14 @@ func TestSafeBrowsingCache(t *testing.T) {
assert.Empty(t, c.getCached()) assert.Empty(t, c.getCached())
} }
// testErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type testErrUpstream struct{}
// Exchange always returns nil Msg and non-nil error.
func (teu *testErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, agherr.Error("bad")
}
func (teu *testErrUpstream) Address() string {
return ""
}
func TestSBPC_checkErrorUpstream(t *testing.T) { func TestSBPC_checkErrorUpstream(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
ups := &testErrUpstream{} ups := &aghtest.TestErrUpstream{}
d.safeBrowsingUpstream = ups d.SetSafeBrowsingUpstream(ups)
d.parentalUpstream = ups d.SetParentalUpstream(ups)
_, err := d.checkSafeBrowsing("smthng.com") _, err := d.checkSafeBrowsing("smthng.com")
assert.NotNil(t, err) assert.NotNil(t, err)
@ -137,122 +121,86 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
assert.NotNil(t, err) assert.NotNil(t, err)
} }
// testSbUpstream implements upstream.Upstream interface for replacing real func TestSBPC(t *testing.T) {
// upstream in tests.
type testSbUpstream struct {
hostname string
block bool
requestsCount int
counterLock sync.RWMutex
}
// Exchange returns a message depending on the upstream settings (hostname, block)
func (u *testSbUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.counterLock.Lock()
u.requestsCount++
u.counterLock.Unlock()
hash := sha256.Sum256([]byte(u.hostname))
prefix := hash[0:2]
hashToReturn := hex.EncodeToString(prefix) + strings.Repeat("ab", 28)
if u.block {
hashToReturn = hex.EncodeToString(hash[:])
}
m := &dns.Msg{}
m.Answer = []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: r.Question[0].Name,
},
Txt: []string{
hashToReturn,
},
},
}
return m, nil
}
func (u *testSbUpstream) Address() string {
return ""
}
func TestSBPC_sbValidResponse(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
ups := &testSbUpstream{} const hostname = "example.org"
d.safeBrowsingUpstream = ups
d.parentalUpstream = ups
// Prepare the upstream testCases := []struct {
ups.hostname = "example.org" name string
ups.block = false block bool
ups.requestsCount = 0 testFunc func(string) (Result, error)
testCache cache.Cache
}{{
name: "sb_no_block",
block: false,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, {
name: "sb_block",
block: true,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, {
name: "pc_no_block",
block: false,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}, {
name: "pc_block",
block: true,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}}
// First - check that the request is not blocked for _, tc := range testCases {
res, err := d.checkSafeBrowsing("example.org") t.Run(tc.name, func(t *testing.T) {
assert.Nil(t, err) // Prepare the upstream.
assert.False(t, res.IsFiltered) ups := &aghtest.TestBlockUpstream{
Hostname: hostname,
Block: tc.block,
}
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)
// Check the cache state, check that the response is now cached // Firstly, check the request blocking.
assert.Equal(t, 1, gctx.safebrowsingCache.Stats().Count) hits := 0
assert.Equal(t, 0, gctx.safebrowsingCache.Stats().Hit) res, err := tc.testFunc(hostname)
assert.Nil(t, err)
if tc.block {
assert.True(t, res.IsFiltered)
assert.Len(t, res.Rules, 1)
hits++
} else {
assert.False(t, res.IsFiltered)
}
// There was one request to an upstream // Check the cache state, check the response is now cached.
assert.Equal(t, 1, ups.requestsCount) assert.Equal(t, 1, tc.testCache.Stats().Count)
assert.Equal(t, hits, tc.testCache.Stats().Hit)
// Now make the same request to check that the cache was used // There was one request to an upstream.
res, err = d.checkSafeBrowsing("example.org") assert.Equal(t, 1, ups.RequestsCount())
assert.Nil(t, err)
assert.False(t, res.IsFiltered)
// Check the cache state, it should've been used // Now make the same request to check the cache was used.
assert.Equal(t, 1, gctx.safebrowsingCache.Stats().Count) res, err = tc.testFunc(hostname)
assert.Equal(t, 1, gctx.safebrowsingCache.Stats().Hit) assert.Nil(t, err)
if tc.block {
assert.True(t, res.IsFiltered)
assert.Len(t, res.Rules, 1)
} else {
assert.False(t, res.IsFiltered)
}
// Check that there were no additional requests // Check the cache state, it should've been used.
assert.Equal(t, 1, ups.requestsCount) assert.Equal(t, 1, tc.testCache.Stats().Count)
} assert.Equal(t, hits+1, tc.testCache.Stats().Hit)
func TestSBPC_pcBlockedResponse(t *testing.T) { // Check that there were no additional requests.
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) assert.Equal(t, 1, ups.RequestsCount())
t.Cleanup(d.Close)
purgeCaches()
ups := &testSbUpstream{} })
d.safeBrowsingUpstream = ups }
d.parentalUpstream = ups
// Prepare the upstream
// Make sure that the upstream will return a response that matches the queried domain
ups.hostname = "example.com"
ups.block = true
ups.requestsCount = 0
// Make a lookup
res, err := d.checkParental("example.com")
assert.Nil(t, err)
assert.True(t, res.IsFiltered)
assert.Len(t, res.Rules, 1)
// Check the cache state, check that the response is now cached
assert.Equal(t, 1, gctx.parentalCache.Stats().Count)
assert.Equal(t, 1, gctx.parentalCache.Stats().Hit)
// There was one request to an upstream
assert.Equal(t, 1, ups.requestsCount)
// Make a second lookup for the same domain
res, err = d.checkParental("example.com")
assert.Nil(t, err)
assert.True(t, res.IsFiltered)
assert.Len(t, res.Rules, 1)
// Check the cache state, it should've been used
assert.Equal(t, 1, gctx.parentalCache.Stats().Count)
assert.Equal(t, 2, gctx.parentalCache.Stats().Hit)
// Check that there were no additional requests
assert.Equal(t, 1, ups.requestsCount)
} }

View File

@ -296,12 +296,13 @@ func (s *Server) prepareUpstreamSettings() error {
// prepareIntlProxy - initializes DNS proxy that we use for internal DNS queries // prepareIntlProxy - initializes DNS proxy that we use for internal DNS queries
func (s *Server) prepareIntlProxy() { func (s *Server) prepareIntlProxy() {
intlProxyConfig := proxy.Config{ s.internalProxy = &proxy.Proxy{
CacheEnabled: true, Config: proxy.Config{
CacheSizeBytes: 4096, CacheEnabled: true,
UpstreamConfig: s.conf.UpstreamConfig, CacheSizeBytes: 4096,
UpstreamConfig: s.conf.UpstreamConfig,
},
} }
s.internalProxy = &proxy.Proxy{Config: intlProxyConfig}
} }
// prepareTLS - prepares TLS configuration for the DNS proxy // prepareTLS - prepares TLS configuration for the DNS proxy

View File

@ -85,10 +85,11 @@ type DNSCreateParams struct {
// NewServer creates a new instance of the dnsforward.Server // NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once // Note: this function must be called only once
func NewServer(p DNSCreateParams) *Server { func NewServer(p DNSCreateParams) *Server {
s := &Server{} s := &Server{
s.dnsFilter = p.DNSFilter dnsFilter: p.DNSFilter,
s.stats = p.Stats stats: p.Stats,
s.queryLog = p.QueryLog queryLog: p.QueryLog,
}
if p.DHCPServer != nil { if p.DHCPServer != nil {
s.dhcpServer = p.DHCPServer s.dhcpServer = p.DHCPServer
@ -103,6 +104,16 @@ func NewServer(p DNSCreateParams) *Server {
return s return s
} }
// NewCustomServer creates a new instance of *Server with custom internal proxy.
func NewCustomServer(internalProxy *proxy.Proxy) *Server {
s := &Server{}
if internalProxy != nil {
s.internalProxy = internalProxy
}
return s
}
// Close - close object // Close - close object
func (s *Server) Close() { func (s *Server) Close() {
s.Lock() s.Lock()

View File

@ -1,11 +1,9 @@
package dnsforward package dnsforward
import ( import (
"context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
@ -20,7 +18,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
@ -32,7 +30,7 @@ import (
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
const ( const (
@ -53,10 +51,13 @@ func startDeferStop(t *testing.T, s *Server) {
} }
func TestServer(t *testing.T) { func TestServer(t *testing.T) {
s := createTestServer(t) s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
})
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&testUpstream{ &aghtest.TestUpstream{
ipv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}}, "google-public-dns-a.google.com.": {{8, 8, 8, 8}},
}, },
}, },
@ -88,11 +89,13 @@ func TestServer(t *testing.T) {
} }
func TestServerWithProtectionDisabled(t *testing.T) { func TestServerWithProtectionDisabled(t *testing.T) {
s := createTestServer(t) s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
s.conf.ProtectionEnabled = false UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
})
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&testUpstream{ &aghtest.TestUpstream{
ipv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}}, "google-public-dns-a.google.com.": {{8, 8, 8, 8}},
}, },
}, },
@ -113,7 +116,11 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
var keyPem []byte var keyPem []byte
_, certPem, keyPem = createServerTLSConfig(t) _, certPem, keyPem = createServerTLSConfig(t)
s = createTestServer(t)
s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
})
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
s.conf.TLSConfig = tlsConf s.conf.TLSConfig = tlsConf
@ -126,11 +133,11 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
func TestDoTServer(t *testing.T) { func TestDoTServer(t *testing.T) {
s, certPem := createTestTLS(t, TLSConfig{ s, certPem := createTestTLS(t, TLSConfig{
TLSListenAddr: &net.TCPAddr{Port: 0}, TLSListenAddr: &net.TCPAddr{},
}) })
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&testUpstream{ &aghtest.TestUpstream{
ipv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}}, "google-public-dns-a.google.com.": {{8, 8, 8, 8}},
}, },
}, },
@ -156,11 +163,11 @@ func TestDoTServer(t *testing.T) {
func TestDoQServer(t *testing.T) { func TestDoQServer(t *testing.T) {
s, _ := createTestTLS(t, TLSConfig{ s, _ := createTestTLS(t, TLSConfig{
QUICListenAddr: &net.UDPAddr{Port: 0}, QUICListenAddr: &net.UDPAddr{},
}) })
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&testUpstream{ &aghtest.TestUpstream{
ipv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}}, "google-public-dns-a.google.com.": {{8, 8, 8, 8}},
}, },
}, },
@ -184,10 +191,27 @@ func TestDoQServer(t *testing.T) {
func TestServerRace(t *testing.T) { func TestServerRace(t *testing.T) {
t.Skip("TODO(e.burkov): inspect the golibs/cache package for locks") t.Skip("TODO(e.burkov): inspect the golibs/cache package for locks")
s := createTestServer(t) filterConf := &dnsfilter.Config{
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchEnabled: true,
SafeSearchCacheSize: 1000,
ParentalCacheSize: 1000,
CacheTime: 30,
}
forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
},
ConfigModified: func() {},
}
s := createTestServer(t, filterConf, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&testUpstream{ &aghtest.TestUpstream{
ipv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}}, "google-public-dns-a.google.com.": {{8, 8, 8, 8}},
}, },
}, },
@ -202,68 +226,74 @@ func TestServerRace(t *testing.T) {
sendTestMessagesAsync(t, conn) sendTestMessagesAsync(t, conn)
} }
// testResolver is a Resolver for tests.
//
//lint:ignore U1000 TODO(e.burkov): move into aghtest package.
type testResolver struct{}
// LookupIPAddr implements Resolver interface for *testResolver.
//
//lint:ignore U1000 TODO(e.burkov): move into aghtest package.
func (r *testResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) {
hash := sha256.Sum256([]byte(host))
addrs := []net.IPAddr{{
IP: net.IP(hash[:4]),
Zone: "somezone",
}, {
IP: net.IP(hash[4:20]),
Zone: "somezone",
}}
return addrs, nil
}
// LookupHost implements Resolver interface for *testResolver.
//
//lint:ignore U1000 TODO(e.burkov): move into aghtest package.
func (r *testResolver) LookupHost(host string) (addrs []string, err error) {
hash := sha256.Sum256([]byte(host))
addrs = []string{
net.IP(hash[:4]).String(),
net.IP(hash[4:20]).String(),
}
return addrs, nil
}
func TestSafeSearch(t *testing.T) { func TestSafeSearch(t *testing.T) {
t.Skip("TODO(e.burkov): substitute the dnsfilter by one with custom resolver from aghtest") resolver := &aghtest.TestResolver{}
filterConf := &dnsfilter.Config{
testUpstreamIP := net.IP{213, 180, 193, 56} SafeSearchEnabled: true,
testCases := []string{ SafeSearchCacheSize: 1000,
"yandex.com.", CacheTime: 30,
"yandex.by.", CustomResolver: resolver,
"yandex.kz.",
"yandex.ru.",
"www.google.com.",
"www.google.com.af.",
"www.google.be.",
"www.google.by.",
} }
forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
},
}
s := createTestServer(t, filterConf, forwardConf)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
client := dns.Client{Net: proxy.ProtoUDP}
yandexIP := net.IP{213, 180, 193, 56}
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
testCases := []struct {
host string
want net.IP
}{{
host: "yandex.com.",
want: yandexIP,
}, {
host: "yandex.by.",
want: yandexIP,
}, {
host: "yandex.kz.",
want: yandexIP,
}, {
host: "yandex.ru.",
want: yandexIP,
}, {
host: "www.google.com.",
want: googleIP,
}, {
host: "www.google.com.af.",
want: googleIP,
}, {
host: "www.google.be.",
want: googleIP,
}, {
host: "www.google.by.",
want: googleIP,
}}
for _, tc := range testCases { for _, tc := range testCases {
t.Run("safe_search_"+tc, func(t *testing.T) { t.Run(tc.host, func(t *testing.T) {
s := createTestServer(t) req := createTestMessage(tc.host)
startDeferStop(t, s) reply, _, err := client.Exchange(req, addr)
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) assertResponse(t, reply, tc.want)
client := dns.Client{Net: proxy.ProtoUDP}
exchangeAndAssertResponse(t, &client, addr, tc, testUpstreamIP)
}) })
} }
} }
func TestInvalidRequest(t *testing.T) { func TestInvalidRequest(t *testing.T) {
s := createTestServer(t) s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
})
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
@ -284,7 +314,14 @@ func TestInvalidRequest(t *testing.T) {
} }
func TestBlockedRequest(t *testing.T) { func TestBlockedRequest(t *testing.T) {
s := createTestServer(t) forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
},
}
s := createTestServer(t, &dnsfilter.Config{}, forwardConf)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -299,12 +336,19 @@ func TestBlockedRequest(t *testing.T) {
} }
func TestServerCustomClientUpstream(t *testing.T) { func TestServerCustomClientUpstream(t *testing.T) {
s := createTestServer(t) forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
},
}
s := createTestServer(t, &dnsfilter.Config{}, forwardConf)
s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig { s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig {
return &proxy.UpstreamConfig{ return &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ Upstreams: []upstream.Upstream{
&testUpstream{ &aghtest.TestUpstream{
ipv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"host.": {{192, 168, 0, 1}}, "host.": {{192, 168, 0, 1}},
}, },
}, },
@ -327,82 +371,6 @@ func TestServerCustomClientUpstream(t *testing.T) {
assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A) assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A)
} }
// testUpstream is a mock of real upstream. specify fields with necessary values
// to simulate real upstream behaviour.
//
// TODO(e.burkov): move into aghtest package.
type testUpstream struct {
// cn is a map of hostname to canonical name.
cn map[string]string
// ipv4 is a map of hostname to IPv4.
ipv4 map[string][]net.IP
// ipv6 is a map of hostname to IPv6.
ipv6 map[string][]net.IP
}
// Exchange implements upstream.Upstream interface for *testUpstream.
func (u *testUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) {
resp := &dns.Msg{}
resp.SetReply(m)
hasRec := false
name := m.Question[0].Name
if cname, ok := u.cn[name]; ok {
resp.Answer = append(resp.Answer, &dns.CNAME{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeCNAME,
},
Target: cname,
})
}
var rrtype uint16
var a []net.IP
switch m.Question[0].Qtype {
case dns.TypeA:
rrtype = dns.TypeA
if ipv4addr, ok := u.ipv4[name]; ok {
hasRec = true
a = ipv4addr
}
case dns.TypeAAAA:
rrtype = dns.TypeAAAA
if ipv6addr, ok := u.ipv6[name]; ok {
hasRec = true
a = ipv6addr
}
}
for _, ip := range a {
resp.Answer = append(resp.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: rrtype,
},
A: ip,
})
}
if len(resp.Answer) == 0 {
if hasRec {
// Set no error RCode if there are some records for
// given Qname but we didn't apply them.
resp.SetRcode(m, dns.RcodeSuccess)
return resp, nil
}
// Set NXDomain RCode otherwise.
resp.SetRcode(m, dns.RcodeNameError)
}
return resp, nil
}
// Address implements upstream.Upstream interface for *testUpstream.
func (u *testUpstream) Address() string {
return "test"
}
func (s *Server) startWithUpstream(u upstream.Upstream) error { func (s *Server) startWithUpstream(u upstream.Upstream) error {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
@ -410,30 +378,35 @@ func (s *Server) startWithUpstream(u upstream.Upstream) error {
if err != nil { if err != nil {
return err return err
} }
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{ s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{u}, Upstreams: []upstream.Upstream{u},
} }
return s.dnsProxy.Start() return s.dnsProxy.Start()
} }
// testCNAMEs is a simple map of names and CNAMEs necessary for the testUpstream work // testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
var testCNAMEs = map[string]string{ var testCNAMEs = map[string]string{
"badhost.": "null.example.org.", "badhost.": "null.example.org.",
"whitelist.example.org.": "null.example.org.", "whitelist.example.org.": "null.example.org.",
} }
// testIPv4 is a simple map of names and IPv4s necessary for the testUpstream work // testIPv4 is a map of names and IPv4s necessary for the TestUpstream work.
var testIPv4 = map[string][]net.IP{ var testIPv4 = map[string][]net.IP{
"null.example.org.": {{1, 2, 3, 4}}, "null.example.org.": {{1, 2, 3, 4}},
"example.org.": {{127, 0, 0, 255}}, "example.org.": {{127, 0, 0, 255}},
} }
func TestBlockCNAMEProtectionEnabled(t *testing.T) { func TestBlockCNAMEProtectionEnabled(t *testing.T) {
s := createTestServer(t) s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
testUpstm := &testUpstream{ UDPListenAddr: &net.UDPAddr{},
cn: testCNAMEs, TCPListenAddr: &net.TCPAddr{},
ipv4: testIPv4, })
ipv6: nil, testUpstm := &aghtest.TestUpstream{
CName: testCNAMEs,
IPv4: testIPv4,
IPv6: nil,
} }
s.conf.ProtectionEnabled = false s.conf.ProtectionEnabled = false
err := s.startWithUpstream(testUpstm) err := s.startWithUpstream(testUpstm)
@ -449,11 +422,18 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
} }
func TestBlockCNAME(t *testing.T) { func TestBlockCNAME(t *testing.T) {
s := createTestServer(t) forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
},
}
s := createTestServer(t, &dnsfilter.Config{}, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&testUpstream{ &aghtest.TestUpstream{
cn: testCNAMEs, CName: testCNAMEs,
ipv4: testIPv4, IPv4: testIPv4,
}, },
} }
startDeferStop(t, s) startDeferStop(t, s)
@ -496,14 +476,21 @@ func TestBlockCNAME(t *testing.T) {
} }
func TestClientRulesForCNAMEMatching(t *testing.T) { func TestClientRulesForCNAMEMatching(t *testing.T) {
s := createTestServer(t) forwardConf := ServerConfig{
s.conf.FilterHandler = func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) { UDPListenAddr: &net.UDPAddr{},
settings.FilteringEnabled = false TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
FilterHandler: func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) {
settings.FilteringEnabled = false
},
},
} }
s := createTestServer(t, &dnsfilter.Config{}, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&testUpstream{ &aghtest.TestUpstream{
cn: testCNAMEs, CName: testCNAMEs,
ipv4: testIPv4, IPv4: testIPv4,
}, },
} }
startDeferStop(t, s) startDeferStop(t, s)
@ -531,8 +518,15 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
} }
func TestNullBlockedRequest(t *testing.T) { func TestNullBlockedRequest(t *testing.T) {
s := createTestServer(t) forwardConf := ServerConfig{
s.conf.FilteringConfig.BlockingMode = "null_ip" UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: "null_ip",
},
}
s := createTestServer(t, &dnsfilter.Config{}, forwardConf)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -568,8 +562,8 @@ func TestBlockedCustomIP(t *testing.T) {
DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters), DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters),
}) })
conf := ServerConfig{ conf := ServerConfig{
UDPListenAddr: &net.UDPAddr{Port: 0}, UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{Port: 0}, TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{ FilteringConfig: FilteringConfig{
ProtectionEnabled: true, ProtectionEnabled: true,
BlockingMode: "custom_ip", BlockingMode: "custom_ip",
@ -606,7 +600,14 @@ func TestBlockedCustomIP(t *testing.T) {
} }
func TestBlockedByHosts(t *testing.T) { func TestBlockedByHosts(t *testing.T) {
s := createTestServer(t) forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
},
}
s := createTestServer(t, &dnsfilter.Config{}, forwardConf)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -623,24 +624,32 @@ func TestBlockedByHosts(t *testing.T) {
} }
func TestBlockedBySafeBrowsing(t *testing.T) { func TestBlockedBySafeBrowsing(t *testing.T) {
t.Skip("TODO(e.burkov): substitute the dnsfilter by one with custom safeBrowsingUpstream") const hostname = "wmconvirus.narod.ru"
resolver := &testResolver{}
ips, _ := resolver.LookupIPAddr(context.Background(), safeBrowsingBlockHost)
addrs, _ := resolver.LookupHost(safeBrowsingBlockHost)
s := createTestServer(t) sbUps := &aghtest.TestBlockUpstream{
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ Hostname: hostname,
&testUpstream{ Block: true,
ipv4: map[string][]net.IP{ }
"wmconvirus.narod.ru.": {ips[0].IP}, ans, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
},
filterConf := &dnsfilter.Config{
SafeBrowsingEnabled: true,
}
forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
SafeBrowsingBlockHost: ans.String(),
ProtectionEnabled: true,
}, },
} }
s := createTestServer(t, filterConf, forwardConf)
s.dnsFilter.SetSafeBrowsingUpstream(sbUps)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// SafeBrowsing blocking. // SafeBrowsing blocking.
req := createTestMessage("wmconvirus.narod.ru.") req := createTestMessage(hostname + ".")
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
@ -648,14 +657,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
if assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) { if assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) {
found := false assert.Equal(t, ans, a.A, "dns server %s returned wrong answer: %v", addr, a.A)
for _, blockAddr := range addrs {
if blockAddr == a.A.String() {
found = true
break
}
}
assert.Truef(t, found, "dns server %s returned wrong answer: %v", addr, a.A)
} }
} }
@ -679,19 +681,19 @@ func TestRewrite(t *testing.T) {
s := NewServer(DNSCreateParams{DNSFilter: f}) s := NewServer(DNSCreateParams{DNSFilter: f})
err := s.Prepare(&ServerConfig{ err := s.Prepare(&ServerConfig{
UDPListenAddr: &net.UDPAddr{Port: 0}, UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{Port: 0}, TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{ FilteringConfig: FilteringConfig{
ProtectionEnabled: true, ProtectionEnabled: true,
UpstreamDNS: []string{"8.8.8.8:53"}, UpstreamDNS: []string{"8.8.8.8:53"},
}, },
}) })
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&testUpstream{ &aghtest.TestUpstream{
cn: map[string]string{ CName: map[string]string{
"example.org": "somename", "example.org": "somename",
}, },
ipv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"example.org.": {{4, 3, 2, 1}}, "example.org.": {{4, 3, 2, 1}},
}, },
}, },
@ -724,13 +726,14 @@ func TestRewrite(t *testing.T) {
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) // the original question is restored // The original question is restored.
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
assert.Len(t, reply.Answer, 2) assert.Len(t, reply.Answer, 2)
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
} }
func createTestServer(t *testing.T) *Server { func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server {
rules := `||nxdomain.example.org rules := `||nxdomain.example.org
||null.example.org^ ||null.example.org^
127.0.0.1 host.example.org 127.0.0.1 host.example.org
@ -739,30 +742,13 @@ func createTestServer(t *testing.T) *Server {
filters := []dnsfilter.Filter{{ filters := []dnsfilter.Filter{{
ID: 0, Data: []byte(rules), ID: 0, Data: []byte(rules),
}} }}
c := dnsfilter.Config{
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchEnabled: true,
SafeSearchCacheSize: 1000,
ParentalCacheSize: 1000,
CacheTime: 30,
}
f := dnsfilter.New(&c, filters) f := dnsfilter.New(filterConf, filters)
s := NewServer(DNSCreateParams{DNSFilter: f}) s := NewServer(DNSCreateParams{DNSFilter: f})
s.conf = ServerConfig{ s.conf = forwardConf
UDPListenAddr: &net.UDPAddr{Port: 0}, assert.Nil(t, s.Prepare(nil))
TCPListenAddr: &net.TCPAddr{Port: 0},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
},
ConfigModified: func() {},
}
err := s.Prepare(nil)
assert.Nil(t, err)
return s return s
} }
@ -849,15 +835,6 @@ func sendTestMessages(t *testing.T, conn *dns.Conn) {
} }
} }
func exchangeAndAssertResponse(t *testing.T, client *dns.Client, addr net.Addr, host string, ip net.IP) {
t.Helper()
req := createTestMessage(host)
reply, _, err := client.Exchange(req, addr.String())
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
assertResponse(t, reply, ip)
}
func createGoogleATestMessage() *dns.Msg { func createGoogleATestMessage() *dns.Msg {
return createTestMessage("google-public-dns-a.google.com.") return createTestMessage("google-public-dns-a.google.com.")
} }
@ -879,6 +856,7 @@ func createTestMessage(host string) *dns.Msg {
func createTestMessageWithType(host string, qtype uint16) *dns.Msg { func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
req := createTestMessage(host) req := createTestMessage(host)
req.Question[0].Qtype = qtype req.Question[0].Qtype = qtype
return req return req
} }
@ -889,7 +867,10 @@ func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) { func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) {
t.Helper() t.Helper()
assert.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) if !assert.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) {
return
}
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
if assert.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0]) { if assert.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0]) {
assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A) assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A)
@ -900,8 +881,10 @@ func publicKey(priv interface{}) interface{} {
switch k := priv.(type) { switch k := priv.(type) {
case *rsa.PrivateKey: case *rsa.PrivateKey:
return &k.PublicKey return &k.PublicKey
case *ecdsa.PrivateKey: case *ecdsa.PrivateKey:
return &k.PublicKey return &k.PublicKey
default: default:
return nil return nil
} }
@ -1082,6 +1065,7 @@ func (d *testDHCP) Leases(flags int) []dhcpd.Lease {
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
Hostname: "localhost", Hostname: "localhost",
} }
return []dhcpd.Lease{l} return []dhcpd.Lease{l}
} }
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
@ -1094,8 +1078,8 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
DHCPServer: dhcp, DHCPServer: dhcp,
}) })
s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.UDPListenAddr = &net.UDPAddr{}
s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true s.conf.FilteringConfig.ProtectionEnabled = true
err := s.Prepare(nil) err := s.Prepare(nil)
@ -1143,8 +1127,8 @@ func TestPTRResponseFromHosts(t *testing.T) {
t.Cleanup(c.AutoHosts.Close) t.Cleanup(c.AutoHosts.Close)
s := NewServer(DNSCreateParams{DNSFilter: dnsfilter.New(&c, nil)}) s := NewServer(DNSCreateParams{DNSFilter: dnsfilter.New(&c, nil)})
s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.UDPListenAddr = &net.UDPAddr{}
s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true s.conf.FilteringConfig.ProtectionEnabled = true
assert.Nil(t, s.Prepare(nil)) assert.Nil(t, s.Prepare(nil))

View File

@ -2,16 +2,35 @@ package dnsforward
import ( import (
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) { func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
s := createTestServer(t) filterConf := &dnsfilter.Config{
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchEnabled: true,
SafeSearchCacheSize: 1000,
ParentalCacheSize: 1000,
CacheTime: 30,
}
forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
},
ConfigModified: func() {},
}
s := createTestServer(t, filterConf, forwardConf)
err := s.Start() err := s.Start()
assert.Nil(t, err) assert.Nil(t, err)
defer assert.Nil(t, s.Stop()) defer assert.Nil(t, s.Stop())
@ -35,6 +54,7 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
conf: func() ServerConfig { conf: func() ServerConfig {
conf := defaultConf conf := defaultConf
conf.FastestAddr = true conf.FastestAddr = true
return conf return conf
}, },
want: "{\"upstream_dns\":[\"8.8.8.8:53\",\"8.8.4.4:53\"],\"upstream_dns_file\":\"\",\"bootstrap_dns\":[\"9.9.9.10\",\"149.112.112.10\",\"2620:fe::10\",\"2620:fe::fe:10\"],\"protection_enabled\":true,\"ratelimit\":0,\"blocking_mode\":\"\",\"blocking_ipv4\":\"\",\"blocking_ipv6\":\"\",\"edns_cs_enabled\":false,\"dnssec_enabled\":false,\"disable_ipv6\":false,\"upstream_mode\":\"fastest_addr\",\"cache_size\":0,\"cache_ttl_min\":0,\"cache_ttl_max\":0}\n", want: "{\"upstream_dns\":[\"8.8.8.8:53\",\"8.8.4.4:53\"],\"upstream_dns_file\":\"\",\"bootstrap_dns\":[\"9.9.9.10\",\"149.112.112.10\",\"2620:fe::10\",\"2620:fe::fe:10\"],\"protection_enabled\":true,\"ratelimit\":0,\"blocking_mode\":\"\",\"blocking_ipv4\":\"\",\"blocking_ipv6\":\"\",\"edns_cs_enabled\":false,\"dnssec_enabled\":false,\"disable_ipv6\":false,\"upstream_mode\":\"fastest_addr\",\"cache_size\":0,\"cache_ttl_min\":0,\"cache_ttl_max\":0}\n",
@ -43,6 +63,7 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
conf: func() ServerConfig { conf: func() ServerConfig {
conf := defaultConf conf := defaultConf
conf.AllServers = true conf.AllServers = true
return conf return conf
}, },
want: "{\"upstream_dns\":[\"8.8.8.8:53\",\"8.8.4.4:53\"],\"upstream_dns_file\":\"\",\"bootstrap_dns\":[\"9.9.9.10\",\"149.112.112.10\",\"2620:fe::10\",\"2620:fe::fe:10\"],\"protection_enabled\":true,\"ratelimit\":0,\"blocking_mode\":\"\",\"blocking_ipv4\":\"\",\"blocking_ipv6\":\"\",\"edns_cs_enabled\":false,\"dnssec_enabled\":false,\"disable_ipv6\":false,\"upstream_mode\":\"parallel\",\"cache_size\":0,\"cache_ttl_min\":0,\"cache_ttl_max\":0}\n", want: "{\"upstream_dns\":[\"8.8.8.8:53\",\"8.8.4.4:53\"],\"upstream_dns_file\":\"\",\"bootstrap_dns\":[\"9.9.9.10\",\"149.112.112.10\",\"2620:fe::10\",\"2620:fe::fe:10\"],\"protection_enabled\":true,\"ratelimit\":0,\"blocking_mode\":\"\",\"blocking_ipv4\":\"\",\"blocking_ipv6\":\"\",\"edns_cs_enabled\":false,\"dnssec_enabled\":false,\"disable_ipv6\":false,\"upstream_mode\":\"parallel\",\"cache_size\":0,\"cache_ttl_min\":0,\"cache_ttl_max\":0}\n",
@ -61,7 +82,24 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
} }
func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) { func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
s := createTestServer(t) filterConf := &dnsfilter.Config{
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchEnabled: true,
SafeSearchCacheSize: 1000,
ParentalCacheSize: 1000,
CacheTime: 30,
}
forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
},
ConfigModified: func() {},
}
s := createTestServer(t, filterConf, forwardConf)
defaultConf := s.conf defaultConf := s.conf

View File

@ -9,12 +9,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
func prepareTestDir() string { func prepareTestDir() string {

View File

@ -27,18 +27,18 @@ type RDNS struct {
// InitRDNS - create module context // InitRDNS - create module context
func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS { func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS {
r := RDNS{} r := &RDNS{
r.dnsServer = dnsServer dnsServer: dnsServer,
r.clients = clients clients: clients,
ipAddrs: cache.New(cache.Config{
EnableLRU: true,
MaxCount: 10000,
}),
ipChannel: make(chan net.IP, 256),
}
cconf := cache.Config{}
cconf.EnableLRU = true
cconf.MaxCount = 10000
r.ipAddrs = cache.New(cconf)
r.ipChannel = make(chan net.IP, 256)
go r.workerLoop() go r.workerLoop()
return &r return r
} }
// Begin - add IP address to rDNS queue // Begin - add IP address to rDNS queue
@ -75,23 +75,23 @@ func (r *RDNS) Begin(ip net.IP) {
func (r *RDNS) resolve(ip net.IP) string { func (r *RDNS) resolve(ip net.IP) string {
log.Tracef("Resolving host for %s", ip) log.Tracef("Resolving host for %s", ip)
req := dns.Msg{} name, err := dns.ReverseAddr(ip.String())
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
}
var err error
req.Question[0].Name, err = dns.ReverseAddr(ip.String())
if err != nil { if err != nil {
log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err) log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err)
return "" return ""
} }
resp, err := r.dnsServer.Exchange(&req) resp, err := r.dnsServer.Exchange(&dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: name,
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
}},
})
if err != nil { if err != nil {
log.Debug("Error while making an rDNS lookup for %s: %s", ip, err) log.Debug("Error while making an rDNS lookup for %s: %s", ip, err)
return "" return ""

View File

@ -4,16 +4,26 @@ import (
"net" "net"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestResolveRDNS(t *testing.T) { func TestResolveRDNS(t *testing.T) {
dns := &dnsforward.Server{} ups := &aghtest.TestUpstream{
conf := &dnsforward.ServerConfig{} Reverse: map[string][]string{
conf.UpstreamDNS = []string{"8.8.8.8"} "1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
err := dns.Prepare(conf) },
assert.Nil(t, err) }
dns := dnsforward.NewCustomServer(&proxy.Proxy{
Config: proxy.Config{
UpstreamConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
},
})
clients := &clientsContainer{} clients := &clientsContainer{}
rdns := InitRDNS(dns, clients) rdns := InitRDNS(dns, clients)

View File

@ -13,9 +13,14 @@ func prepareTestDNSServer() error {
Context.dnsServer = dnsforward.NewServer(dnsforward.DNSCreateParams{}) Context.dnsServer = dnsforward.NewServer(dnsforward.DNSCreateParams{})
conf := &dnsforward.ServerConfig{} conf := &dnsforward.ServerConfig{}
conf.UpstreamDNS = []string{"8.8.8.8"} conf.UpstreamDNS = []string{"8.8.8.8"}
return Context.dnsServer.Prepare(conf) return Context.dnsServer.Prepare(conf)
} }
// TODO(e.burkov): It's kind of complicated to get rid of network access in this
// test. The thing is that *Whois creates new *net.Dialer each time it requests
// the server, so it becomes hard to simulate handling of request from test even
// with substituted upstream. However, it must be done.
func TestWhois(t *testing.T) { func TestWhois(t *testing.T) {
assert.Nil(t, prepareTestDNSServer()) assert.Nil(t, prepareTestDNSServer())

View File

@ -8,8 +8,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -19,8 +19,8 @@ import (
func TestDecodeLogEntry(t *testing.T) { func TestDecodeLogEntry(t *testing.T) {
logOutput := &bytes.Buffer{} logOutput := &bytes.Buffer{}
testutil.ReplaceLogWriter(t, logOutput) aghtest.ReplaceLogWriter(t, logOutput)
testutil.ReplaceLogLevel(t, log.DEBUG) aghtest.ReplaceLogLevel(t, log.DEBUG)
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==` const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==`

View File

@ -10,14 +10,14 @@ import (
"github.com/AdguardTeam/dnsproxy/proxyutil" "github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
func prepareTestDir() string { func prepareTestDir() string {

View File

@ -7,12 +7,12 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
func UIntArrayEquals(a, b []uint64) bool { func UIntArrayEquals(a, b []uint64) bool {

View File

@ -3,9 +3,9 @@ package sysutil
import ( import (
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }

View File

@ -11,7 +11,7 @@ import (
"strconv" "strconv"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -19,7 +19,7 @@ import (
// TODO(a.garipov): Rewrite these tests. // TODO(a.garipov): Rewrite these tests.
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
func startHTTPServer(data string) (l net.Listener, portStr string) { func startHTTPServer(data string) (l net.Listener, portStr string) {

View File

@ -8,13 +8,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
func prepareTestDir() string { func prepareTestDir() string {