Pull request 2286: AGDNS-2374-slog-safesearch

Squashed commit of the following:

commit 1909dfed99
Merge: 3856fda5f 2c64ab5a5
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Oct 9 16:21:38 2024 +0300

    Merge branch 'master' into AGDNS-2374-slog-safesearch

commit 3856fda5f3
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Oct 8 20:04:34 2024 +0300

    home: imp code

commit de774009aa
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Oct 7 16:41:58 2024 +0300

    all: imp code

commit 038bae59d5
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Oct 3 20:24:48 2024 +0300

    all: imp code

commit 792975e248
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Oct 3 15:46:40 2024 +0300

    all: slog safesearch
This commit is contained in:
Stanislav Chzhen 2024-10-09 16:31:03 +03:00
parent 2c64ab5a51
commit 6363f8a2e7
12 changed files with 237 additions and 122 deletions

View File

@ -7,10 +7,8 @@ import (
"net/netip" "net/netip"
"slices" "slices"
"strings" "strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -323,20 +321,3 @@ func (c *Persistent) CloseUpstreams() (err error) {
return nil return nil
} }
// SetSafeSearch initializes and sets the safe search filter for this client.
func (c *Persistent) SetSafeSearch(
conf filtering.SafeSearchConfig,
cacheSize uint,
cacheTTL time.Duration,
) (err error) {
ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
c.SafeSearch = ss
return nil
}

View File

@ -513,12 +513,14 @@ func TestSafeSearch(t *testing.T) {
SafeSearchCacheSize: 1000, SafeSearchCacheSize: 1000,
CacheTime: 30, CacheTime: 30,
} }
safeSearch, err := safesearch.NewDefault(
safeSearchConf, ctx := testutil.ContextWithTimeout(t, testTimeout)
"", safeSearch, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
filterConf.SafeSearchCacheSize, Logger: slogutil.NewDiscardLogger(),
time.Minute*time.Duration(filterConf.CacheTime), ServicesConfig: safeSearchConf,
) CacheSize: filterConf.SafeSearchCacheSize,
CacheTTL: time.Minute * time.Duration(filterConf.CacheTime),
})
require.NoError(t, err) require.NoError(t, err)
filterConf.SafeSearch = safeSearch filterConf.SafeSearch = safeSearch

View File

@ -1,15 +1,17 @@
package filtering package filtering
import "context"
// SafeSearch interface describes a service for search engines hosts rewrites. // SafeSearch interface describes a service for search engines hosts rewrites.
type SafeSearch interface { type SafeSearch interface {
// CheckHost checks host with safe search filter. CheckHost must be safe // CheckHost checks host with safe search filter. CheckHost must be safe
// for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA]. // for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA].
CheckHost(host string, qtype uint16) (res Result, err error) CheckHost(ctx context.Context, host string, qtype uint16) (res Result, err error)
// Update updates the configuration of the safe search filter. Update must // Update updates the configuration of the safe search filter. Update must
// be safe for concurrent use. An implementation of Update may ignore some // be safe for concurrent use. An implementation of Update may ignore some
// fields, but it must document which. // fields, but it must document which.
Update(conf SafeSearchConfig) (err error) Update(ctx context.Context, conf SafeSearchConfig) (err error)
} }
// SafeSearchConfig is a struct with safe search related settings. // SafeSearchConfig is a struct with safe search related settings.
@ -40,10 +42,13 @@ func (d *DNSFilter) checkSafeSearch(
return Result{}, nil return Result{}, nil
} }
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
clientSafeSearch := setts.ClientSafeSearch clientSafeSearch := setts.ClientSafeSearch
if clientSafeSearch != nil { if clientSafeSearch != nil {
return clientSafeSearch.CheckHost(host, qtype) return clientSafeSearch.CheckHost(ctx, host, qtype)
} }
return d.safeSearch.CheckHost(host, qtype) return d.safeSearch.CheckHost(ctx, host, qtype)
} }

View File

@ -3,9 +3,11 @@ package safesearch
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"encoding/gob" "encoding/gob"
"fmt" "fmt"
"log/slog"
"net/netip" "net/netip"
"strings" "strings"
"sync" "sync"
@ -14,13 +16,20 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist" "github.com/AdguardTeam/urlfilter/filterlist"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/c2h5oh/datasize"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// Attribute keys and values for logging.
const (
LogPrefix = "safesearch"
LogKeyClient = "client"
)
// Service is a enum with service names used as search providers. // Service is a enum with service names used as search providers.
type Service string type Service string
@ -57,9 +66,32 @@ func isServiceProtected(s filtering.SafeSearchConfig, service Service) (ok bool)
} }
} }
// DefaultConfig is the configuration structure for [Default].
type DefaultConfig struct {
// Logger is used for logging the operation of the safe search filter.
Logger *slog.Logger
// ClientName is the name of the persistent client associated with the safe
// search filter, if there is one.
ClientName string
// CacheSize is the size of the filter results cache.
CacheSize uint
// CacheTTL is the Time to Live duration for cached items.
CacheTTL time.Duration
// ServicesConfig contains safe search settings for services. It must not
// be nil.
ServicesConfig filtering.SafeSearchConfig
}
// Default is the default safe search filter that uses filtering rules with the // Default is the default safe search filter that uses filtering rules with the
// dnsrewrite modifier. // dnsrewrite modifier.
type Default struct { type Default struct {
// logger is used for logging the operation of the safe search filter.
logger *slog.Logger
// mu protects engine. // mu protects engine.
mu *sync.RWMutex mu *sync.RWMutex
@ -67,33 +99,28 @@ type Default struct {
// engine may be nil, which means that this safe search filter is disabled. // engine may be nil, which means that this safe search filter is disabled.
engine *urlfilter.DNSEngine engine *urlfilter.DNSEngine
cache cache.Cache // cache stores safe search filtering results.
logPrefix string cache cache.Cache
cacheTTL time.Duration
// cacheTTL is the Time to Live duration for cached items.
cacheTTL time.Duration
} }
// NewDefault returns an initialized default safe search filter. name is used // NewDefault returns an initialized default safe search filter. ctx is used
// for logging. // to log the initial refresh.
func NewDefault( func NewDefault(ctx context.Context, conf *DefaultConfig) (ss *Default, err error) {
conf filtering.SafeSearchConfig,
name string,
cacheSize uint,
cacheTTL time.Duration,
) (ss *Default, err error) {
ss = &Default{ ss = &Default{
mu: &sync.RWMutex{}, logger: conf.Logger,
mu: &sync.RWMutex{},
cache: cache.New(cache.Config{ cache: cache.New(cache.Config{
EnableLRU: true, EnableLRU: true,
MaxSize: cacheSize, MaxSize: conf.CacheSize,
}), }),
// Use %s, because the client safe-search names already contain double cacheTTL: conf.CacheTTL,
// quotes.
logPrefix: fmt.Sprintf("safesearch %s: ", name),
cacheTTL: cacheTTL,
} }
err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf) // TODO(s.chzhen): Move to [Default.InitialRefresh].
err = ss.resetEngine(ctx, rulelist.URLFilterIDSafeSearch, conf.ServicesConfig)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return nil, err return nil, err
@ -102,29 +129,15 @@ func NewDefault(
return ss, nil return ss, nil
} }
// log is a helper for logging that includes the name of the safe search
// filter. level must be one of [log.DEBUG], [log.INFO], and [log.ERROR].
func (ss *Default) log(level log.Level, msg string, args ...any) {
switch level {
case log.DEBUG:
log.Debug(ss.logPrefix+msg, args...)
case log.INFO:
log.Info(ss.logPrefix+msg, args...)
case log.ERROR:
log.Error(ss.logPrefix+msg, args...)
default:
panic(fmt.Errorf("safesearch: unsupported logging level %d", level))
}
}
// resetEngine creates new engine for provided safe search configuration and // resetEngine creates new engine for provided safe search configuration and
// sets it in ss. // sets it in ss.
func (ss *Default) resetEngine( func (ss *Default) resetEngine(
ctx context.Context,
listID int, listID int,
conf filtering.SafeSearchConfig, conf filtering.SafeSearchConfig,
) (err error) { ) (err error) {
if !conf.Enabled { if !conf.Enabled {
ss.log(log.INFO, "disabled") ss.logger.DebugContext(ctx, "disabled")
return nil return nil
} }
@ -149,7 +162,7 @@ func (ss *Default) resetEngine(
ss.engine = urlfilter.NewDNSEngine(rs) ss.engine = urlfilter.NewDNSEngine(rs)
ss.log(log.INFO, "reset %d rules", ss.engine.RulesCount) ss.logger.InfoContext(ctx, "reset rules", "count", ss.engine.RulesCount)
return nil return nil
} }
@ -158,10 +171,14 @@ func (ss *Default) resetEngine(
var _ filtering.SafeSearch = (*Default)(nil) var _ filtering.SafeSearch = (*Default)(nil)
// CheckHost implements the [filtering.SafeSearch] interface for *Default. // CheckHost implements the [filtering.SafeSearch] interface for *Default.
func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Result, err error) { func (ss *Default) CheckHost(
ctx context.Context,
host string,
qtype rules.RRType,
) (res filtering.Result, err error) {
start := time.Now() start := time.Now()
defer func() { defer func() {
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start)) ss.logger.DebugContext(ctx, "lookup finished", "host", host, "elapsed", time.Since(start))
}() }()
switch qtype { switch qtype {
@ -172,9 +189,9 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
} }
// Check cache. Return cached result if it was found // Check cache. Return cached result if it was found
cachedValue, isFound := ss.getCachedResult(host, qtype) cachedValue, isFound := ss.getCachedResult(ctx, host, qtype)
if isFound { if isFound {
ss.log(log.DEBUG, "found in cache: %q", host) ss.logger.DebugContext(ctx, "found in cache", "host", host)
return cachedValue, nil return cachedValue, nil
} }
@ -186,7 +203,7 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
fltRes, err := ss.newResult(rewrite, qtype) fltRes, err := ss.newResult(rewrite, qtype)
if err != nil { if err != nil {
ss.log(log.DEBUG, "looking up addresses for %q: %s", host, err) ss.logger.ErrorContext(ctx, "looking up addresses", "host", host, slogutil.KeyError, err)
return filtering.Result{}, err return filtering.Result{}, err
} }
@ -195,7 +212,7 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
// TODO(a.garipov): Consider switch back to resolving CNAME records IPs and // TODO(a.garipov): Consider switch back to resolving CNAME records IPs and
// saving results to cache. // saving results to cache.
ss.setCacheResult(host, qtype, res) ss.setCacheResult(ctx, host, qtype, res)
return res, nil return res, nil
} }
@ -255,7 +272,12 @@ func (ss *Default) newResult(
// setCacheResult stores data in cache for host. qtype is expected to be either // setCacheResult stores data in cache for host. qtype is expected to be either
// [dns.TypeA] or [dns.TypeAAAA]. // [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) { func (ss *Default) setCacheResult(
ctx context.Context,
host string,
qtype rules.RRType,
res filtering.Result,
) {
expire := uint32(time.Now().Add(ss.cacheTTL).Unix()) expire := uint32(time.Now().Add(ss.cacheTTL).Unix())
exp := make([]byte, 4) exp := make([]byte, 4)
binary.BigEndian.PutUint32(exp, expire) binary.BigEndian.PutUint32(exp, expire)
@ -263,7 +285,7 @@ func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering
err := gob.NewEncoder(buf).Encode(res) err := gob.NewEncoder(buf).Encode(res)
if err != nil { if err != nil {
ss.log(log.ERROR, "cache encoding: %s", err) ss.logger.ErrorContext(ctx, "cache encoding", slogutil.KeyError, err)
return return
} }
@ -271,12 +293,18 @@ func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering
val := buf.Bytes() val := buf.Bytes()
_ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val) _ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val)
ss.log(log.DEBUG, "stored in cache: %q, %d bytes", host, len(val)) ss.logger.DebugContext(
ctx,
"stored in cache",
"host", host,
"entry_size", datasize.ByteSize(len(val)),
)
} }
// getCachedResult returns stored data from cache for host. qtype is expected // getCachedResult returns stored data from cache for host. qtype is expected
// to be either [dns.TypeA] or [dns.TypeAAAA]. // to be either [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) getCachedResult( func (ss *Default) getCachedResult(
ctx context.Context,
host string, host string,
qtype rules.RRType, qtype rules.RRType,
) (res filtering.Result, ok bool) { ) (res filtering.Result, ok bool) {
@ -298,7 +326,7 @@ func (ss *Default) getCachedResult(
err := gob.NewDecoder(buf).Decode(&res) err := gob.NewDecoder(buf).Decode(&res)
if err != nil { if err != nil {
ss.log(log.ERROR, "cache decoding: %s", err) ss.logger.ErrorContext(ctx, "cache decoding", slogutil.KeyError, err)
return filtering.Result{}, false return filtering.Result{}, false
} }
@ -308,11 +336,11 @@ func (ss *Default) getCachedResult(
// Update implements the [filtering.SafeSearch] interface for *Default. Update // Update implements the [filtering.SafeSearch] interface for *Default. Update
// ignores the CustomResolver and Enabled fields. // ignores the CustomResolver and Enabled fields.
func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) { func (ss *Default) Update(ctx context.Context, conf filtering.SafeSearchConfig) (err error) {
ss.mu.Lock() ss.mu.Lock()
defer ss.mu.Unlock() defer ss.mu.Unlock()
err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf) err = ss.resetEngine(ctx, rulelist.URLFilterIDSafeSearch, conf)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return err return err

View File

@ -6,6 +6,8 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -21,6 +23,9 @@ const (
testCacheTTL = 30 * time.Minute testCacheTTL = 30 * time.Minute
) )
// testTimeout is the common timeout for tests and contexts.
const testTimeout = 1 * time.Second
var defaultSafeSearchConf = filtering.SafeSearchConfig{ var defaultSafeSearchConf = filtering.SafeSearchConfig{
Enabled: true, Enabled: true,
Bing: true, Bing: true,
@ -35,7 +40,12 @@ var defaultSafeSearchConf = filtering.SafeSearchConfig{
var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56}) var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56})
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) { func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) {
ss, err := NewDefault(ssConf, "", testCacheSize, testCacheTTL) ss, err := NewDefault(testutil.ContextWithTimeout(t, testTimeout), &DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: ssConf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
})
require.NoError(t, err) require.NoError(t, err)
return ss return ss
@ -52,16 +62,17 @@ func TestSafeSearchCacheYandex(t *testing.T) {
const domain = "yandex.ru" const domain = "yandex.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false}) ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Check host with disabled safesearch. // Check host with disabled safesearch.
res, err := ss.CheckHost(domain, testQType) res, err := ss.CheckHost(ctx, domain, testQType)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules) assert.Empty(t, res.Rules)
ss = newForTest(t, defaultSafeSearchConf) ss = newForTest(t, defaultSafeSearchConf)
res, err = ss.CheckHost(domain, testQType) res, err = ss.CheckHost(ctx, domain, testQType)
require.NoError(t, err) require.NoError(t, err)
// For yandex we already know valid IP. // For yandex we already know valid IP.
@ -70,7 +81,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
assert.Equal(t, res.Rules[0].IP, yandexIP) assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache. // Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType) cachedValue, isFound := ss.getCachedResult(ctx, domain, testQType)
require.True(t, isFound) require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1) require.Len(t, cachedValue.Rules, 1)

View File

@ -10,15 +10,15 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) { // testTimeout is the common timeout for tests and contexts.
testutil.DiscardLogOutput(m) const testTimeout = 1 * time.Second
}
// Common test constants. // Common test constants.
const ( const (
@ -47,7 +47,13 @@ var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56})
func TestDefault_CheckHost_yandex(t *testing.T) { func TestDefault_CheckHost_yandex(t *testing.T) {
conf := testConf conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: conf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
})
require.NoError(t, err) require.NoError(t, err)
hosts := []string{ hosts := []string{
@ -82,7 +88,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
for _, host := range hosts { for _, host := range hosts {
// Check host for each domain. // Check host for each domain.
var res filtering.Result var res filtering.Result
res, err = ss.CheckHost(host, tc.qt) res, err = ss.CheckHost(ctx, host, tc.qt)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@ -103,7 +109,13 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
} }
func TestDefault_CheckHost_google(t *testing.T) { func TestDefault_CheckHost_google(t *testing.T) {
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL) ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: testConf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
})
require.NoError(t, err) require.NoError(t, err)
// Check host for each domain. // Check host for each domain.
@ -118,7 +130,7 @@ func TestDefault_CheckHost_google(t *testing.T) {
} { } {
t.Run(host, func(t *testing.T) { t.Run(host, func(t *testing.T) {
var res filtering.Result var res filtering.Result
res, err = ss.CheckHost(host, testQType) res, err = ss.CheckHost(ctx, host, testQType)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@ -149,13 +161,19 @@ func (r *testResolver) LookupIP(
} }
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) { func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL) ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: testConf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
})
require.NoError(t, err) require.NoError(t, err)
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but // The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
// DuckDuckGo doesn't have a safe-search IPv6 address. The result should be // DuckDuckGo doesn't have a safe-search IPv6 address. The result should be
// the same as the one for Yandex IPv6. That is, a NODATA response. // the same as the one for Yandex IPv6. That is, a NODATA response.
res, err := ss.CheckHost("www.duckduckgo.com", dns.TypeAAAA) res, err := ss.CheckHost(ctx, "www.duckduckgo.com", dns.TypeAAAA)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@ -166,32 +184,38 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
func TestDefault_Update(t *testing.T) { func TestDefault_Update(t *testing.T) {
conf := testConf conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: conf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
})
require.NoError(t, err) require.NoError(t, err)
res, err := ss.CheckHost("www.yandex.com", testQType) res, err := ss.CheckHost(ctx, "www.yandex.com", testQType)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
err = ss.Update(filtering.SafeSearchConfig{ err = ss.Update(ctx, filtering.SafeSearchConfig{
Enabled: true, Enabled: true,
Google: false, Google: false,
}) })
require.NoError(t, err) require.NoError(t, err)
res, err = ss.CheckHost("www.yandex.com", testQType) res, err = ss.CheckHost(ctx, "www.yandex.com", testQType)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
err = ss.Update(filtering.SafeSearchConfig{ err = ss.Update(ctx, filtering.SafeSearchConfig{
Enabled: false, Enabled: false,
Google: true, Google: true,
}) })
require.NoError(t, err) require.NoError(t, err)
res, err = ss.CheckHost("www.yandex.com", testQType) res, err = ss.CheckHost(ctx, "www.yandex.com", testQType)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)

View File

@ -51,7 +51,7 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ
} }
conf := *req conf := *req
err = d.safeSearch.Update(conf) err = d.safeSearch.Update(r.Context(), conf)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err)

View File

@ -3,6 +3,7 @@ package home
import ( import (
"context" "context"
"fmt" "fmt"
"log/slog"
"net/netip" "net/netip"
"slices" "slices"
"sync" "sync"
@ -13,17 +14,23 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
) )
// clientsContainer is the storage of all runtime and persistent clients. // clientsContainer is the storage of all runtime and persistent clients.
type clientsContainer struct { type clientsContainer struct {
// baseLogger is used to create loggers with custom prefixes for safe search
// filter. It must not be nil.
baseLogger *slog.Logger
// storage stores information about persistent clients. // storage stores information about persistent clients.
storage *client.Storage storage *client.Storage
@ -61,6 +68,8 @@ type BlockedClientChecker interface {
// dhcpServer: optional // dhcpServer: optional
// Note: this function must be called only once // Note: this function must be called only once
func (clients *clientsContainer) Init( func (clients *clientsContainer) Init(
ctx context.Context,
baseLogger *slog.Logger,
objects []*clientObject, objects []*clientObject,
dhcpServer client.DHCP, dhcpServer client.DHCP,
etcHosts *aghnet.HostsContainer, etcHosts *aghnet.HostsContainer,
@ -72,13 +81,14 @@ func (clients *clientsContainer) Init(
return errors.Error("clients container already initialized") return errors.Error("clients container already initialized")
} }
clients.baseLogger = baseLogger
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime) clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
confClients := make([]*client.Persistent, 0, len(objects)) confClients := make([]*client.Persistent, 0, len(objects))
for i, o := range objects { for i, o := range objects {
var p *client.Persistent var p *client.Persistent
p, err = o.toPersistent(clients.safeSearchCacheSize, clients.safeSearchCacheTTL) p, err = o.toPersistent(ctx, baseLogger, clients.safeSearchCacheSize, clients.safeSearchCacheTTL)
if err != nil { if err != nil {
return fmt.Errorf("init persistent client at index %d: %w", i, err) return fmt.Errorf("init persistent client at index %d: %w", i, err)
} }
@ -168,6 +178,8 @@ type clientObject struct {
// toPersistent returns an initialized persistent client if there are no errors. // toPersistent returns an initialized persistent client if there are no errors.
func (o *clientObject) toPersistent( func (o *clientObject) toPersistent(
ctx context.Context,
baseLogger *slog.Logger,
safeSearchCacheSize uint, safeSearchCacheSize uint,
safeSearchCacheTTL time.Duration, safeSearchCacheTTL time.Duration,
) (cli *client.Persistent, err error) { ) (cli *client.Persistent, err error) {
@ -203,14 +215,23 @@ func (o *clientObject) toPersistent(
} }
if o.SafeSearchConf.Enabled { if o.SafeSearchConf.Enabled {
err = cli.SetSafeSearch( logger := baseLogger.With(
o.SafeSearchConf, slogutil.KeyPrefix, safesearch.LogPrefix,
safeSearchCacheSize, safesearch.LogKeyClient, cli.Name,
safeSearchCacheTTL,
) )
var ss *safesearch.Default
ss, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: logger,
ServicesConfig: o.SafeSearchConf,
ClientName: cli.Name,
CacheSize: safeSearchCacheSize,
CacheTTL: safeSearchCacheTTL,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("init safesearch %q: %w", cli.Name, err) return nil, fmt.Errorf("init safesearch %q: %w", cli.Name, err)
} }
cli.SafeSearch = ss
} }
if o.BlockedServices == nil { if o.BlockedServices == nil {

View File

@ -7,6 +7,8 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -20,7 +22,18 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
testing: true, testing: true,
} }
require.NoError(t, c.Init(nil, client.EmptyDHCP{}, nil, nil, &filtering.Config{})) ctx := testutil.ContextWithTimeout(t, testTimeout)
err := c.Init(
ctx,
slogutil.NewDiscardLogger(),
nil,
client.EmptyDHCP{},
nil,
nil,
&filtering.Config{},
)
require.NoError(t, err)
return c return c
} }

View File

@ -1,6 +1,7 @@
package home package home
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -10,8 +11,10 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/logutil/slogutil"
) )
// clientJSON is a common structure used by several handlers to deal with // clientJSON is a common structure used by several handlers to deal with
@ -181,6 +184,7 @@ func initPrev(cj clientJSON, prev *client.Persistent) (c *client.Persistent, err
// jsonToClient converts JSON object to persistent client object if there are no // jsonToClient converts JSON object to persistent client object if there are no
// errors. // errors.
func (clients *clientsContainer) jsonToClient( func (clients *clientsContainer) jsonToClient(
ctx context.Context,
cj clientJSON, cj clientJSON,
prev *client.Persistent, prev *client.Persistent,
) (c *client.Persistent, err error) { ) (c *client.Persistent, err error) {
@ -207,14 +211,23 @@ func (clients *clientsContainer) jsonToClient(
c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices
if c.SafeSearchConf.Enabled { if c.SafeSearchConf.Enabled {
err = c.SetSafeSearch( logger := clients.baseLogger.With(
c.SafeSearchConf, slogutil.KeyPrefix, safesearch.LogPrefix,
clients.safeSearchCacheSize, safesearch.LogKeyClient, c.Name,
clients.safeSearchCacheTTL,
) )
var ss *safesearch.Default
ss, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: logger,
ServicesConfig: c.SafeSearchConf,
ClientName: c.Name,
CacheSize: clients.safeSearchCacheSize,
CacheTTL: clients.safeSearchCacheTTL,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err) return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err)
} }
c.SafeSearch = ss
} }
return c, nil return c, nil
@ -321,7 +334,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
return return
} }
c, err := clients.jsonToClient(cj, nil) c, err := clients.jsonToClient(r.Context(), cj, nil)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@ -391,7 +404,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return return
} }
c, err := clients.jsonToClient(dj.Data, nil) c, err := clients.jsonToClient(r.Context(), dj.Data, nil)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)

View File

@ -11,14 +11,19 @@ import (
"net/url" "net/url"
"slices" "slices"
"testing" "testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// testTimeout is the common timeout for tests and contexts.
const testTimeout = 1 * time.Second
const ( const (
testClientIP1 = "1.1.1.1" testClientIP1 = "1.1.1.1"
testClientIP2 = "2.2.2.2" testClientIP2 = "2.2.2.2"
@ -103,9 +108,10 @@ func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*c
require.NoError(tb, err) require.NoError(tb, err)
var got []*client.Persistent var got []*client.Persistent
ctx := testutil.ContextWithTimeout(tb, testTimeout)
for _, cj := range clientList.Clients { for _, cj := range clientList.Clients {
var c *client.Persistent var c *client.Persistent
c, err = clients.jsonToClient(*cj, nil) c, err = clients.jsonToClient(ctx, *cj, nil)
require.NoError(tb, err) require.NoError(tb, err)
got = append(got, c) got = append(got, c)
@ -125,10 +131,11 @@ func assertPersistentClientsData(
tb.Helper() tb.Helper()
var got []*client.Persistent var got []*client.Persistent
ctx := testutil.ContextWithTimeout(tb, testTimeout)
for _, cm := range data { for _, cm := range data {
for _, cj := range cm { for _, cj := range cm {
var c *client.Persistent var c *client.Persistent
c, err := clients.jsonToClient(*cj, nil) c, err := clients.jsonToClient(ctx, *cj, nil)
require.NoError(tb, err) require.NoError(tb, err)
got = append(got, c) got = append(got, c)

View File

@ -278,8 +278,8 @@ func setupOpts(opts options) (err error) {
} }
// initContextClients initializes Context clients and related fields. // initContextClients initializes Context clients and related fields.
func initContextClients(logger *slog.Logger) (err error) { func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
err = setupDNSFilteringConf(config.Filtering) err = setupDNSFilteringConf(ctx, logger, config.Filtering)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return err return err
@ -306,6 +306,8 @@ func initContextClients(logger *slog.Logger) (err error) {
} }
return Context.clients.Init( return Context.clients.Init(
ctx,
logger,
config.Clients.Persistent, config.Clients.Persistent,
Context.dhcpServer, Context.dhcpServer,
Context.etcHosts, Context.etcHosts,
@ -355,7 +357,11 @@ func setupBindOpts(opts options) (err error) {
} }
// setupDNSFilteringConf sets up DNS filtering configuration settings. // setupDNSFilteringConf sets up DNS filtering configuration settings.
func setupDNSFilteringConf(conf *filtering.Config) (err error) { func setupDNSFilteringConf(
ctx context.Context,
baseLogger *slog.Logger,
conf *filtering.Config,
) (err error) {
const ( const (
dnsTimeout = 3 * time.Second dnsTimeout = 3 * time.Second
@ -446,12 +452,13 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
conf.ParentalBlockHost = host conf.ParentalBlockHost = host
} }
conf.SafeSearch, err = safesearch.NewDefault( logger := baseLogger.With(slogutil.KeyPrefix, safesearch.LogPrefix)
conf.SafeSearchConf, conf.SafeSearch, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
"default", Logger: logger,
conf.SafeSearchCacheSize, ServicesConfig: conf.SafeSearchConf,
cacheTime, CacheSize: conf.SafeSearchCacheSize,
) CacheTTL: cacheTime,
})
if err != nil { if err != nil {
return fmt.Errorf("initializing safesearch: %w", err) return fmt.Errorf("initializing safesearch: %w", err)
} }
@ -584,7 +591,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
// data first, but also to avoid relying on automatic Go init() function. // data first, but also to avoid relying on automatic Go init() function.
filtering.InitModule() filtering.InitModule()
err = initContextClients(slogLogger) // TODO(s.chzhen): Use it for the entire initialization process.
ctx := context.Background()
err = initContextClients(ctx, slogLogger)
fatalOnError(err) fatalOnError(err)
err = setupOpts(opts) err = setupOpts(opts)