diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f851a6b..480449f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ NOTE: Add new changes BELOW THIS COMMENT. ### Added +- IPv6 support in Safe Search for some services. - The ability to make bootstrap DNS lookups prefer IPv6 addresses to IPv4 ones using the new `dns.bootstrap_prefer_ipv6` configuration file property ([#4262]). diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index bb4b28c6..ce8b1cf2 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -453,8 +453,9 @@ func TestSafeSearch(t *testing.T) { SafeSearchCacheSize: 1000, CacheTime: 30, } - safeSearch, err := safesearch.NewDefaultSafeSearch( + safeSearch, err := safesearch.NewDefault( safeSearchConf, + "", filterConf.SafeSearchCacheSize, time.Minute*time.Duration(filterConf.CacheTime), ) diff --git a/internal/filtering/safesearch.go b/internal/filtering/safesearch.go index a57394a3..003b9ee1 100644 --- a/internal/filtering/safesearch.go +++ b/internal/filtering/safesearch.go @@ -1,17 +1,17 @@ package filtering -import ( - "github.com/AdguardTeam/urlfilter/rules" - "github.com/miekg/dns" -) +import "github.com/miekg/dns" // SafeSearch interface describes a service for search engines hosts rewrites. type SafeSearch interface { - // SearchHost returns a replacement address for the search engine host. - SearchHost(host string, qtype uint16) (res *rules.DNSRewrite) - - // CheckHost checks host with safe search engine. + // CheckHost checks host with safe search filter. CheckHost must be safe + // for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA]. CheckHost(host string, qtype uint16) (res Result, err error) + + // Update updates the configuration of the safe search filter. Update must + // be safe for concurrent use. An implementation of Update may ignore some + // fields, but it must document which. + Update(conf SafeSearchConfig) (err error) } // SafeSearchConfig is a struct with safe search related settings. @@ -37,10 +37,12 @@ type SafeSearchConfig struct { // [hostChecker.check]. func (d *DNSFilter) checkSafeSearch( host string, - _ uint16, + qtype uint16, setts *Settings, ) (res Result, err error) { - if !setts.ProtectionEnabled || !setts.SafeSearchEnabled { + if !setts.ProtectionEnabled || + !setts.SafeSearchEnabled || + (qtype != dns.TypeA && qtype != dns.TypeAAAA) { return Result{}, nil } @@ -50,8 +52,8 @@ func (d *DNSFilter) checkSafeSearch( clientSafeSearch := setts.ClientSafeSearch if clientSafeSearch != nil { - return clientSafeSearch.CheckHost(host, dns.TypeA) + return clientSafeSearch.CheckHost(host, qtype) } - return d.safeSearch.CheckHost(host, dns.TypeA) + return d.safeSearch.CheckHost(host, qtype) } diff --git a/internal/filtering/safesearch/rules/bing.txt b/internal/filtering/safesearch/rules/bing.txt index 8c61b63c..4e1a7bc8 100644 --- a/internal/filtering/safesearch/rules/bing.txt +++ b/internal/filtering/safesearch/rules/bing.txt @@ -1 +1 @@ -|www.bing.com^$dnsrewrite=NOERROR;CNAME;strict.bing.com \ No newline at end of file +|www.bing.com^$dnsrewrite=NOERROR;CNAME;strict.bing.com diff --git a/internal/filtering/safesearch/rules/duckduckgo.txt b/internal/filtering/safesearch/rules/duckduckgo.txt index 084f1be0..268d1f3d 100644 --- a/internal/filtering/safesearch/rules/duckduckgo.txt +++ b/internal/filtering/safesearch/rules/duckduckgo.txt @@ -1,3 +1,3 @@ |duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com |start.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com -|www.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com \ No newline at end of file +|www.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com diff --git a/internal/filtering/safesearch/rules/google.txt b/internal/filtering/safesearch/rules/google.txt index 62f13067..60df6617 100644 --- a/internal/filtering/safesearch/rules/google.txt +++ b/internal/filtering/safesearch/rules/google.txt @@ -188,4 +188,4 @@ |www.google.tt^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com |www.google.vg^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com |www.google.vu^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com -|www.google.ws^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com \ No newline at end of file +|www.google.ws^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com diff --git a/internal/filtering/safesearch/rules/pixabay.txt b/internal/filtering/safesearch/rules/pixabay.txt index 0ab07746..7fe39482 100644 --- a/internal/filtering/safesearch/rules/pixabay.txt +++ b/internal/filtering/safesearch/rules/pixabay.txt @@ -1 +1 @@ -|pixabay.com^$dnsrewrite=NOERROR;CNAME;safesearch.pixabay.com \ No newline at end of file +|pixabay.com^$dnsrewrite=NOERROR;CNAME;safesearch.pixabay.com diff --git a/internal/filtering/safesearch/rules/yandex.txt b/internal/filtering/safesearch/rules/yandex.txt index b6f4afb7..c54a5148 100644 --- a/internal/filtering/safesearch/rules/yandex.txt +++ b/internal/filtering/safesearch/rules/yandex.txt @@ -49,4 +49,4 @@ |yandex.ru^$dnsrewrite=NOERROR;A;213.180.193.56 |yandex.tj^$dnsrewrite=NOERROR;A;213.180.193.56 |yandex.tm^$dnsrewrite=NOERROR;A;213.180.193.56 -|yandex.uz^$dnsrewrite=NOERROR;A;213.180.193.56 \ No newline at end of file +|yandex.uz^$dnsrewrite=NOERROR;A;213.180.193.56 diff --git a/internal/filtering/safesearch/rules/youtube.txt b/internal/filtering/safesearch/rules/youtube.txt index 70e3ae46..8a3fe247 100644 --- a/internal/filtering/safesearch/rules/youtube.txt +++ b/internal/filtering/safesearch/rules/youtube.txt @@ -2,4 +2,4 @@ |m.youtube.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com |youtubei.googleapis.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com |youtube.googleapis.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com -|www.youtube-nocookie.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com \ No newline at end of file +|www.youtube-nocookie.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com diff --git a/internal/filtering/safesearch/safesearch.go b/internal/filtering/safesearch/safesearch.go index e944e217..880f406c 100644 --- a/internal/filtering/safesearch/safesearch.go +++ b/internal/filtering/safesearch/safesearch.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "strings" + "sync" "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -53,44 +54,85 @@ func isServiceProtected(s filtering.SafeSearchConfig, service Service) (ok bool) } } -// DefaultSafeSearch is the default safesearch struct. -type DefaultSafeSearch struct { - engine *urlfilter.DNSEngine - safeSearchCache cache.Cache - resolver filtering.Resolver - cacheTime time.Duration +// Default is the default safe search filter that uses filtering rules with the +// dnsrewrite modifier. +type Default struct { + // mu protects engine. + mu *sync.RWMutex + + // engine is the filtering engine that contains the DNS rewrite rules. + // engine may be nil, which means that this safe search filter is disabled. + engine *urlfilter.DNSEngine + + cache cache.Cache + resolver filtering.Resolver + logPrefix string + cacheTTL time.Duration } -// NewDefaultSafeSearch returns new safesearch struct. CacheTime is an element -// TTL (in minutes). -func NewDefaultSafeSearch( +// NewDefault returns an initialized default safe search filter. name is used +// for logging. +func NewDefault( conf filtering.SafeSearchConfig, + name string, cacheSize uint, - cacheTime time.Duration, -) (ss *DefaultSafeSearch, err error) { - engine, err := newEngine(filtering.SafeSearchListID, conf) - if err != nil { - return nil, err - } - + cacheTTL time.Duration, +) (ss *Default, err error) { var resolver filtering.Resolver = net.DefaultResolver if conf.CustomResolver != nil { resolver = conf.CustomResolver } - return &DefaultSafeSearch{ - engine: engine, - safeSearchCache: cache.New(cache.Config{ + ss = &Default{ + mu: &sync.RWMutex{}, + + cache: cache.New(cache.Config{ EnableLRU: true, MaxSize: cacheSize, }), - cacheTime: cacheTime, - resolver: resolver, - }, nil + resolver: resolver, + // Use %s, because the client safe-search names already contain double + // quotes. + logPrefix: fmt.Sprintf("safesearch %s: ", name), + cacheTTL: cacheTTL, + } + + err = ss.resetEngine(filtering.SafeSearchListID, conf) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return nil, err + } + + return ss, nil } -// newEngine creates new engine for provided safe search configuration. -func newEngine(listID int, conf filtering.SafeSearchConfig) (engine *urlfilter.DNSEngine, err error) { +// log is a helper for logging that includes the name of the safe search +// filter. level must be one of [log.DEBUG], [log.INFO], and [log.ERROR]. +func (ss *Default) log(level log.Level, msg string, args ...any) { + switch level { + case log.DEBUG: + log.Debug(ss.logPrefix+msg, args...) + case log.INFO: + log.Info(ss.logPrefix+msg, args...) + case log.ERROR: + log.Error(ss.logPrefix+msg, args...) + default: + panic(fmt.Errorf("safesearch: unsupported logging level %d", level)) + } +} + +// resetEngine creates new engine for provided safe search configuration and +// sets it in ss. +func (ss *Default) resetEngine( + listID int, + conf filtering.SafeSearchConfig, +) (err error) { + if !conf.Enabled { + ss.log(log.INFO, "disabled") + + return nil + } + var sb strings.Builder for service, serviceRules := range safeSearchRules { if isServiceProtected(conf, service) { @@ -106,20 +148,73 @@ func newEngine(listID int, conf filtering.SafeSearchConfig) (engine *urlfilter.D rs, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList}) if err != nil { - return nil, fmt.Errorf("creating rule storage: %w", err) + return fmt.Errorf("creating rule storage: %w", err) } - engine = urlfilter.NewDNSEngine(rs) - log.Info("safesearch: filter %d: reset %d rules", listID, engine.RulesCount) + ss.engine = urlfilter.NewDNSEngine(rs) - return engine, nil + ss.log(log.INFO, "reset %d rules", ss.engine.RulesCount) + + return nil } // type check -var _ filtering.SafeSearch = (*DefaultSafeSearch)(nil) +var _ filtering.SafeSearch = (*Default)(nil) + +// CheckHost implements the [filtering.SafeSearch] interface for +// *DefaultSafeSearch. +func (ss *Default) CheckHost( + host string, + qtype rules.RRType, +) (res filtering.Result, err error) { + start := time.Now() + defer func() { + ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start)) + }() + + if qtype != dns.TypeA && qtype != dns.TypeAAAA { + return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype)) + } + + // Check cache. Return cached result if it was found + cachedValue, isFound := ss.getCachedResult(host, qtype) + if isFound { + ss.log(log.DEBUG, "found in cache: %q", host) + + return cachedValue, nil + } + + rewrite := ss.searchHost(host, qtype) + if rewrite == nil { + return filtering.Result{}, nil + } + + fltRes, err := ss.newResult(rewrite, qtype) + if err != nil { + ss.log(log.DEBUG, "looking up addresses for %q: %s", host, err) + + return filtering.Result{}, err + } + + if fltRes != nil { + res = *fltRes + ss.setCacheResult(host, qtype, res) + + return res, nil + } + + return filtering.Result{}, fmt.Errorf("no ipv4 addresses for %q", host) +} + +// searchHost looks up DNS rewrites in the internal DNS filtering engine. +func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRewrite) { + ss.mu.RLock() + defer ss.mu.RUnlock() + + if ss.engine == nil { + return nil + } -// SearchHost implements the [filtering.SafeSearch] interface for *DefaultSafeSearch. -func (ss *DefaultSafeSearch) SearchHost(host string, qtype uint16) (res *rules.DNSRewrite) { r, _ := ss.engine.MatchRequest(&urlfilter.DNSRequest{ Hostname: strings.ToLower(host), DNSType: qtype, @@ -133,51 +228,11 @@ func (ss *DefaultSafeSearch) SearchHost(host string, qtype uint16) (res *rules.D return nil } -// CheckHost implements the [filtering.SafeSearch] interface for -// *DefaultSafeSearch. -func (ss *DefaultSafeSearch) CheckHost( - host string, - qtype uint16, -) (res filtering.Result, err error) { - if log.GetLevel() >= log.DEBUG { - timer := log.StartTimer() - defer timer.LogElapsed("safesearch: lookup for %s", host) - } - - // Check cache. Return cached result if it was found - cachedValue, isFound := ss.getCachedResult(host) - if isFound { - log.Debug("safesearch: found in cache: %s", host) - - return cachedValue, nil - } - - rewrite := ss.SearchHost(host, qtype) - if rewrite == nil { - return filtering.Result{}, nil - } - - dRes, err := ss.newResult(rewrite, qtype) - if err != nil { - log.Debug("safesearch: failed to lookup addresses for %s: %s", host, err) - - return filtering.Result{}, err - } - - if dRes != nil { - res = *dRes - ss.setCacheResult(host, res) - - return res, nil - } - - return filtering.Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", host) -} - -// newResult creates Result object from rewrite rule. -func (ss *DefaultSafeSearch) newResult( +// newResult creates Result object from rewrite rule. qtype must be either +// [dns.TypeA] or [dns.TypeAAAA]. +func (ss *Default) newResult( rewrite *rules.DNSRewrite, - qtype uint16, + qtype rules.RRType, ) (res *filtering.Result, err error) { res = &filtering.Result{ Rules: []*filtering.ResultRule{{ @@ -187,7 +242,7 @@ func (ss *DefaultSafeSearch) newResult( IsFiltered: true, } - if rewrite.RRType == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) { + if rewrite.RRType == qtype { ip, ok := rewrite.Value.(net.IP) if !ok || ip == nil { return nil, nil @@ -198,17 +253,25 @@ func (ss *DefaultSafeSearch) newResult( return res, nil } - if rewrite.NewCNAME == "" { + host := rewrite.NewCNAME + if host == "" { return nil, nil } - ips, err := ss.resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME) + ss.log(log.DEBUG, "resolving %q", host) + + ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host) if err != nil { return nil, err } + ss.log(log.DEBUG, "resolved %s", ips) + for _, ip := range ips { - if ip = ip.To4(); ip == nil { + // TODO(a.garipov): Remove this filtering once the resolver we use + // actually learns about network. + ip = fitToProto(ip, qtype) + if ip == nil { continue } @@ -220,38 +283,71 @@ func (ss *DefaultSafeSearch) newResult( return nil, nil } -// setCacheResult stores data in cache for host. -func (ss *DefaultSafeSearch) setCacheResult(host string, res filtering.Result) { - expire := uint32(time.Now().Add(ss.cacheTime).Unix()) +// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA]. +// It panics for other types. +func qtypeToProto(qtype rules.RRType) (proto string) { + switch qtype { + case dns.TypeA: + return "ip4" + case dns.TypeAAAA: + return "ip6" + default: + panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype))) + } +} + +// fitToProto returns a non-nil IP address if ip is the correct protocol version +// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA]. +func fitToProto(ip net.IP, qtype rules.RRType) (res net.IP) { + ip4 := ip.To4() + if qtype == dns.TypeA { + return ip4 + } + + if ip4 == nil { + return ip + } + + return nil +} + +// setCacheResult stores data in cache for host. qtype is expected to be either +// [dns.TypeA] or [dns.TypeAAAA]. +func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) { + expire := uint32(time.Now().Add(ss.cacheTTL).Unix()) exp := make([]byte, 4) binary.BigEndian.PutUint32(exp, expire) buf := bytes.NewBuffer(exp) err := gob.NewEncoder(buf).Encode(res) if err != nil { - log.Error("safesearch: cache encoding: %s", err) + ss.log(log.ERROR, "cache encoding: %s", err) return } val := buf.Bytes() - _ = ss.safeSearchCache.Set([]byte(host), val) + _ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val) - log.Debug("safesearch: stored in cache: %s (%d bytes)", host, len(val)) + ss.log(log.DEBUG, "stored in cache: %q, %d bytes", host, len(val)) } -// getCachedResult returns stored data from cache for host. -func (ss *DefaultSafeSearch) getCachedResult(host string) (res filtering.Result, ok bool) { +// getCachedResult returns stored data from cache for host. qtype is expected +// to be either [dns.TypeA] or [dns.TypeAAAA]. +func (ss *Default) getCachedResult( + host string, + qtype rules.RRType, +) (res filtering.Result, ok bool) { res = filtering.Result{} - data := ss.safeSearchCache.Get([]byte(host)) + data := ss.cache.Get([]byte(dns.Type(qtype).String() + " " + host)) if data == nil { return res, false } exp := binary.BigEndian.Uint32(data[:4]) if exp <= uint32(time.Now().Unix()) { - ss.safeSearchCache.Del([]byte(host)) + ss.cache.Del([]byte(host)) return res, false } @@ -260,10 +356,27 @@ func (ss *DefaultSafeSearch) getCachedResult(host string) (res filtering.Result, err := gob.NewDecoder(buf).Decode(&res) if err != nil { - log.Debug("safesearch: cache decoding: %s", err) + ss.log(log.ERROR, "cache decoding: %s", err) return filtering.Result{}, false } return res, true } + +// Update implements the [filtering.SafeSearch] interface for *Default. Update +// ignores the CustomResolver and Enabled fields. +func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) { + ss.mu.Lock() + defer ss.mu.Unlock() + + err = ss.resetEngine(filtering.SafeSearchListID, conf) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + ss.cache.Clear() + + return nil +} diff --git a/internal/filtering/safesearch/safesearch_internal_test.go b/internal/filtering/safesearch/safesearch_internal_test.go new file mode 100644 index 00000000..c87a9ad5 --- /dev/null +++ b/internal/filtering/safesearch/safesearch_internal_test.go @@ -0,0 +1,137 @@ +package safesearch + +import ( + "context" + "net" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/urlfilter/rules" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TODO(a.garipov): Move as much of this as possible into proper external tests. + +const ( + // TODO(a.garipov): Add IPv6 tests. + testQType = dns.TypeA + testCacheSize = 5000 + testCacheTTL = 30 * time.Minute +) + +var defaultSafeSearchConf = filtering.SafeSearchConfig{ + Enabled: true, + Bing: true, + DuckDuckGo: true, + Google: true, + Pixabay: true, + Yandex: true, + YouTube: true, +} + +var yandexIP = net.IPv4(213, 180, 193, 56) + +func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) { + ss, err := NewDefault(ssConf, "", testCacheSize, testCacheTTL) + require.NoError(t, err) + + return ss +} + +func TestSafeSearch(t *testing.T) { + ss := newForTest(t, defaultSafeSearchConf) + val := ss.searchHost("www.google.com", testQType) + + assert.Equal(t, &rules.DNSRewrite{NewCNAME: "forcesafesearch.google.com"}, val) +} + +func TestSafeSearchCacheYandex(t *testing.T) { + const domain = "yandex.ru" + + ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false}) + + // Check host with disabled safesearch. + res, err := ss.CheckHost(domain, testQType) + require.NoError(t, err) + + assert.False(t, res.IsFiltered) + assert.Empty(t, res.Rules) + + ss = newForTest(t, defaultSafeSearchConf) + res, err = ss.CheckHost(domain, testQType) + require.NoError(t, err) + + // For yandex we already know valid IP. + require.Len(t, res.Rules, 1) + + assert.Equal(t, res.Rules[0].IP, yandexIP) + + // Check cache. + cachedValue, isFound := ss.getCachedResult(domain, testQType) + require.True(t, isFound) + require.Len(t, cachedValue.Rules, 1) + + assert.Equal(t, cachedValue.Rules[0].IP, yandexIP) +} + +func TestSafeSearchCacheGoogle(t *testing.T) { + const domain = "www.google.ru" + + ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false}) + + res, err := ss.CheckHost(domain, testQType) + require.NoError(t, err) + + assert.False(t, res.IsFiltered) + assert.Empty(t, res.Rules) + + resolver := &aghtest.TestResolver{} + ss = newForTest(t, defaultSafeSearchConf) + ss.resolver = resolver + + // Lookup for safesearch domain. + rewrite := ss.searchHost(domain, testQType) + + ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME) + require.NoError(t, err) + + var foundIP net.IP + for _, ip := range ips { + if ip.To4() != nil { + foundIP = ip + + break + } + } + + res, err = ss.CheckHost(domain, testQType) + require.NoError(t, err) + require.Len(t, res.Rules, 1) + + assert.True(t, res.Rules[0].IP.Equal(foundIP)) + + // Check cache. + cachedValue, isFound := ss.getCachedResult(domain, testQType) + require.True(t, isFound) + require.Len(t, cachedValue.Rules, 1) + + assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP)) +} + +const googleHost = "www.google.com" + +var dnsRewriteSink *rules.DNSRewrite + +func BenchmarkSafeSearch(b *testing.B) { + ss := newForTest(b, defaultSafeSearchConf) + + for n := 0; n < b.N; n++ { + dnsRewriteSink = ss.searchHost(googleHost, testQType) + } + + assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME) +} diff --git a/internal/filtering/safesearch/safesearch_test.go b/internal/filtering/safesearch/safesearch_test.go index 97d18f95..8e1aea2e 100644 --- a/internal/filtering/safesearch/safesearch_test.go +++ b/internal/filtering/safesearch/safesearch_test.go @@ -1,26 +1,37 @@ -package safesearch +package safesearch_test import ( - "context" "net" "testing" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/urlfilter/rules" + "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} + +// Common test constants. const ( - safeSearchCacheSize = 5000 - cacheTime = 30 * time.Minute + // TODO(a.garipov): Add IPv6 tests. + testQType = dns.TypeA + testCacheSize = 5000 + testCacheTTL = 30 * time.Minute ) -var defaultSafeSearchConf = filtering.SafeSearchConfig{ - Enabled: true, +// testConf is the default safe search configuration for tests. +var testConf = filtering.SafeSearchConfig{ + CustomResolver: nil, + + Enabled: true, + Bing: true, DuckDuckGo: true, Google: true, @@ -29,25 +40,15 @@ var defaultSafeSearchConf = filtering.SafeSearchConfig{ YouTube: true, } +// yandexIP is the expected IP address of Yandex safe search results. Keep in +// sync with the rules data. var yandexIP = net.IPv4(213, 180, 193, 56) -func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *DefaultSafeSearch) { - ss, err := NewDefaultSafeSearch(ssConf, safeSearchCacheSize, cacheTime) +func TestDefault_CheckHost_yandex(t *testing.T) { + conf := testConf + ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) require.NoError(t, err) - return ss -} - -func TestSafeSearch(t *testing.T) { - ss := newForTest(t, defaultSafeSearchConf) - val := ss.SearchHost("www.google.com", dns.TypeA) - - assert.Equal(t, &rules.DNSRewrite{NewCNAME: "forcesafesearch.google.com"}, val) -} - -func TestCheckHostSafeSearchYandex(t *testing.T) { - ss := newForTest(t, defaultSafeSearchConf) - // Check host for each domain. for _, host := range []string{ "yandex.ru", @@ -57,7 +58,8 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { "yandex.kz", "www.yandex.com", } { - res, err := ss.CheckHost(host, dns.TypeA) + var res filtering.Result + res, err = ss.CheckHost(host, testQType) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -69,12 +71,14 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { } } -func TestCheckHostSafeSearchGoogle(t *testing.T) { +func TestDefault_CheckHost_google(t *testing.T) { resolver := &aghtest.TestResolver{} ip, _ := resolver.HostToIPs("forcesafesearch.google.com") - ss := newForTest(t, defaultSafeSearchConf) - ss.resolver = resolver + conf := testConf + conf.CustomResolver = resolver + ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) + require.NoError(t, err) // Check host for each domain. for _, host := range []string{ @@ -87,7 +91,8 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { "www.google.je", } { t.Run(host, func(t *testing.T) { - res, err := ss.CheckHost(host, dns.TypeA) + var res filtering.Result + res, err = ss.CheckHost(host, testQType) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -100,103 +105,35 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { } } -func TestSafeSearchCacheYandex(t *testing.T) { - const domain = "yandex.ru" - - ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false}) - - // Check host with disabled safesearch. - res, err := ss.CheckHost(domain, dns.TypeA) +func TestDefault_Update(t *testing.T) { + conf := testConf + ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) require.NoError(t, err) - assert.False(t, res.IsFiltered) - assert.Empty(t, res.Rules) - - ss = newForTest(t, defaultSafeSearchConf) - res, err = ss.CheckHost(domain, dns.TypeA) + res, err := ss.CheckHost("www.yandex.com", testQType) require.NoError(t, err) - // For yandex we already know valid IP. - require.Len(t, res.Rules, 1) + assert.True(t, res.IsFiltered) - assert.Equal(t, res.Rules[0].IP, yandexIP) - - // Check cache. - cachedValue, isFound := ss.getCachedResult(domain) - require.True(t, isFound) - require.Len(t, cachedValue.Rules, 1) - - assert.Equal(t, cachedValue.Rules[0].IP, yandexIP) -} - -func TestSafeSearchCacheGoogle(t *testing.T) { - const domain = "www.google.ru" - - ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false}) - - res, err := ss.CheckHost(domain, dns.TypeA) - require.NoError(t, err) - - assert.False(t, res.IsFiltered) - assert.Empty(t, res.Rules) - - resolver := &aghtest.TestResolver{} - ss = newForTest(t, defaultSafeSearchConf) - ss.resolver = resolver - - // Lookup for safesearch domain. - rewrite := ss.SearchHost(domain, dns.TypeA) - - ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME) - require.NoError(t, err) - - var foundIP net.IP - for _, ip := range ips { - if ip.To4() != nil { - foundIP = ip - - break - } - } - - res, err = ss.CheckHost(domain, dns.TypeA) - require.NoError(t, err) - require.Len(t, res.Rules, 1) - - assert.True(t, res.Rules[0].IP.Equal(foundIP)) - - // Check cache. - cachedValue, isFound := ss.getCachedResult(domain) - require.True(t, isFound) - require.Len(t, cachedValue.Rules, 1) - - assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP)) -} - -const googleHost = "www.google.com" - -var dnsRewriteSink *rules.DNSRewrite - -func BenchmarkSafeSearch(b *testing.B) { - ss := newForTest(b, defaultSafeSearchConf) - - for n := 0; n < b.N; n++ { - dnsRewriteSink = ss.SearchHost(googleHost, dns.TypeA) - } - - assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME) -} - -var dnsRewriteParallelSink *rules.DNSRewrite - -func BenchmarkSafeSearch_parallel(b *testing.B) { - ss := newForTest(b, defaultSafeSearchConf) - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - dnsRewriteParallelSink = ss.SearchHost(googleHost, dns.TypeA) - } + err = ss.Update(filtering.SafeSearchConfig{ + Enabled: true, + Google: false, }) + require.NoError(t, err) - assert.Equal(b, "forcesafesearch.google.com", dnsRewriteParallelSink.NewCNAME) + res, err = ss.CheckHost("www.yandex.com", testQType) + require.NoError(t, err) + + assert.False(t, res.IsFiltered) + + err = ss.Update(filtering.SafeSearchConfig{ + Enabled: false, + Google: true, + }) + require.NoError(t, err) + + res, err = ss.CheckHost("www.yandex.com", testQType) + require.NoError(t, err) + + assert.False(t, res.IsFiltered) } diff --git a/internal/filtering/safesearchhttp.go b/internal/filtering/safesearchhttp.go index db293231..6048cfea 100644 --- a/internal/filtering/safesearchhttp.go +++ b/internal/filtering/safesearchhttp.go @@ -50,11 +50,19 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ return } + conf := *req + err = d.safeSearch.Update(conf) + if err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err) + + return + } + func() { d.confLock.Lock() defer d.confLock.Unlock() - d.Config.SafeSearchConf = *req + d.Config.SafeSearchConf = conf }() d.Config.ConfigModified() diff --git a/internal/home/client.go b/internal/home/client.go index 31e20743..c3946ffb 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -3,8 +3,10 @@ package home import ( "encoding" "fmt" + "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" ) @@ -45,6 +47,23 @@ func (c *Client) closeUpstreams() (err error) { return nil } +// setSafeSearch initializes and sets the safe search filter for this client. +func (c *Client) setSafeSearch( + conf filtering.SafeSearchConfig, + cacheSize uint, + cacheTTL time.Duration, +) (err error) { + ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + c.SafeSearch = ss + + return nil +} + // clientSource represents the source from which the information about the // client has been obtained. type clientSource uint diff --git a/internal/home/clients.go b/internal/home/clients.go index 9453b951..58be4bde 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -13,7 +13,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" @@ -55,6 +54,14 @@ type clientsContainer struct { // more detail. lock sync.Mutex + // safeSearchCacheSize is the size of the safe search cache to use for + // persistent clients. + safeSearchCacheSize uint + + // safeSearchCacheTTL is the TTL of the safe search cache to use for + // persistent clients. + safeSearchCacheTTL time.Duration + // testing is a flag that disables some features for internal tests. // // TODO(a.garipov): Awful. Remove. @@ -74,6 +81,7 @@ func (clients *clientsContainer) Init( if clients.list != nil { log.Fatal("clients.list != nil") } + clients.list = make(map[string]*Client) clients.idIndex = make(map[string]*Client) clients.ipToRC = map[netip.Addr]*RuntimeClient{} @@ -85,6 +93,9 @@ func (clients *clientsContainer) Init( clients.arpdb = arpdb clients.addFromConfig(objects, filteringConf) + clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize + clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime) + if clients.testing { return } @@ -171,18 +182,16 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin if o.SafeSearchConf.Enabled { o.SafeSearchConf.CustomResolver = safeSearchResolver{} - ss, err := safesearch.NewDefaultSafeSearch( + err := cli.setSafeSearch( o.SafeSearchConf, filteringConf.SafeSearchCacheSize, time.Minute*time.Duration(filteringConf.CacheTime), ) if err != nil { - log.Error("clients: init client safesearch %s: %s", cli.Name, err) + log.Error("clients: init client safesearch %q: %s", cli.Name, err) continue } - - cli.SafeSearch = ss } for _, s := range o.BlockedServices { diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index 1c08348e..410ef6d4 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -9,17 +9,27 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestClients(t *testing.T) { - clients := clientsContainer{} - clients.testing = true +// newClientsContainer is a helper that creates a new clients container for +// tests. +func newClientsContainer() (c *clientsContainer) { + c = &clientsContainer{ + testing: true, + } - clients.Init(nil, nil, nil, nil, nil) + c.Init(nil, nil, nil, nil, &filtering.Config{}) + + return c +} + +func TestClients(t *testing.T) { + clients := newClientsContainer() t.Run("add_success", func(t *testing.T) { var ( @@ -198,10 +208,7 @@ func TestClients(t *testing.T) { } func TestClientsWHOIS(t *testing.T) { - clients := clientsContainer{ - testing: true, - } - clients.Init(nil, nil, nil, nil, nil) + clients := newClientsContainer() whois := &RuntimeClientWHOISInfo{ Country: "AU", Orgname: "Example Org", @@ -247,10 +254,7 @@ func TestClientsWHOIS(t *testing.T) { } func TestClientsAddExisting(t *testing.T) { - clients := clientsContainer{ - testing: true, - } - clients.Init(nil, nil, nil, nil, nil) + clients := newClientsContainer() t.Run("simple", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") @@ -325,10 +329,7 @@ func TestClientsAddExisting(t *testing.T) { } func TestClientsCustomUpstream(t *testing.T) { - clients := clientsContainer{ - testing: true, - } - clients.Init(nil, nil, nil, nil, nil) + clients := newClientsContainer() // Add client with upstreams. ok, err := clients.Add(&Client{ diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index c666d821..9a948d1e 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -49,8 +49,8 @@ type clientJSON struct { type runtimeClientJSON struct { WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"` - Name string `json:"name"` IP netip.Addr `json:"ip"` + Name string `json:"name"` Source clientSource `json:"source"` } @@ -90,14 +90,16 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http } // jsonToClient converts JSON object to Client object. -func jsonToClient(cj clientJSON) (c *Client) { +func (clients *clientsContainer) jsonToClient(cj clientJSON) (c *Client, err error) { var safeSearchConf filtering.SafeSearchConfig if cj.SafeSearchConf != nil { safeSearchConf = *cj.SafeSearchConf } else { // TODO(d.kolyshev): Remove after cleaning the deprecated // [clientJSON.SafeSearchEnabled] field. - safeSearchConf = filtering.SafeSearchConfig{Enabled: cj.SafeSearchEnabled} + safeSearchConf = filtering.SafeSearchConfig{ + Enabled: cj.SafeSearchEnabled, + } // Set default service flags for enabled safesearch. if safeSearchConf.Enabled { @@ -110,20 +112,35 @@ func jsonToClient(cj clientJSON) (c *Client) { } } - return &Client{ - Name: cj.Name, - IDs: cj.IDs, - Tags: cj.Tags, + c = &Client{ + safeSearchConf: safeSearchConf, + + Name: cj.Name, + + IDs: cj.IDs, + Tags: cj.Tags, + BlockedServices: cj.BlockedServices, + Upstreams: cj.Upstreams, + UseOwnSettings: !cj.UseGlobalSettings, FilteringEnabled: cj.FilteringEnabled, ParentalEnabled: cj.ParentalEnabled, SafeBrowsingEnabled: cj.SafeBrowsingEnabled, - safeSearchConf: safeSearchConf, UseOwnBlockedServices: !cj.UseGlobalBlockedServices, - BlockedServices: cj.BlockedServices, - - Upstreams: cj.Upstreams, } + + if safeSearchConf.Enabled { + err = c.setSafeSearch( + safeSearchConf, + clients.safeSearchCacheSize, + clients.safeSearchCacheTTL, + ) + if err != nil { + return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err) + } + } + + return c, nil } // clientToJSON converts Client object to JSON. @@ -161,7 +178,13 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - c := jsonToClient(cj) + c, err := clients.jsonToClient(cj) + if err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + + return + } + ok, err := clients.Add(c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) @@ -224,7 +247,13 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - c := jsonToClient(dj.Data) + c, err := clients.jsonToClient(dj.Data) + if err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + + return + } + err = clients.Update(dj.Name, c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) diff --git a/internal/home/dns.go b/internal/home/dns.go index 47d6f177..b14fa440 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -545,6 +545,8 @@ var _ filtering.Resolver = safeSearchResolver{} // LookupIP implements [filtering.Resolver] interface for safeSearchResolver. // It returns the slice of net.IP with IPv4 and IPv6 instances. +// +// TODO(a.garipov): Support network. func (r safeSearchResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) { addrs, err := Context.dnsServer.Resolve(host) if err != nil { diff --git a/internal/home/home.go b/internal/home/home.go index fe78344f..2e4751c5 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -297,8 +297,9 @@ func setupConfig(opts options) (err error) { config.DNS.DnsfilterConf.HTTPClient = Context.client config.DNS.DnsfilterConf.SafeSearchConf.CustomResolver = safeSearchResolver{} - config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefaultSafeSearch( + config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefault( config.DNS.DnsfilterConf.SafeSearchConf, + "default", config.DNS.DnsfilterConf.SafeSearchCacheSize, time.Minute*time.Duration(config.DNS.DnsfilterConf.CacheTime), ) @@ -869,8 +870,10 @@ func detectFirstRun() bool { // Connect to a remote server resolving hostname using our own DNS server. // // TODO(e.burkov): This messy logic should be decomposed and clarified. +// +// TODO(a.garipov): Support network. func customDialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { - log.Tracef("network:%v addr:%v", network, addr) + log.Debug("home: customdial: dialing addr %q for network %s", addr, network) host, port, err := net.SplitHostPort(addr) if err != nil {