mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-12-13 14:34:35 +03:00
Pull request 2286: AGDNS-2374-slog-safesearch
Squashed commit of the following: commit1909dfed99
Merge:3856fda5f
2c64ab5a5
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Oct 9 16:21:38 2024 +0300 Merge branch 'master' into AGDNS-2374-slog-safesearch commit3856fda5f3
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Oct 8 20:04:34 2024 +0300 home: imp code commitde774009aa
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Oct 7 16:41:58 2024 +0300 all: imp code commit038bae59d5
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Oct 3 20:24:48 2024 +0300 all: imp code commit792975e248
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Oct 3 15:46:40 2024 +0300 all: slog safesearch
This commit is contained in:
parent
2c64ab5a51
commit
6363f8a2e7
@ -7,10 +7,8 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
@ -323,20 +321,3 @@ func (c *Persistent) CloseUpstreams() (err error) {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSafeSearch initializes and sets the safe search filter for this client.
|
|
||||||
func (c *Persistent) SetSafeSearch(
|
|
||||||
conf filtering.SafeSearchConfig,
|
|
||||||
cacheSize uint,
|
|
||||||
cacheTTL time.Duration,
|
|
||||||
) (err error) {
|
|
||||||
ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL)
|
|
||||||
if err != nil {
|
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SafeSearch = ss
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -513,12 +513,14 @@ func TestSafeSearch(t *testing.T) {
|
|||||||
SafeSearchCacheSize: 1000,
|
SafeSearchCacheSize: 1000,
|
||||||
CacheTime: 30,
|
CacheTime: 30,
|
||||||
}
|
}
|
||||||
safeSearch, err := safesearch.NewDefault(
|
|
||||||
safeSearchConf,
|
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||||
"",
|
safeSearch, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
|
||||||
filterConf.SafeSearchCacheSize,
|
Logger: slogutil.NewDiscardLogger(),
|
||||||
time.Minute*time.Duration(filterConf.CacheTime),
|
ServicesConfig: safeSearchConf,
|
||||||
)
|
CacheSize: filterConf.SafeSearchCacheSize,
|
||||||
|
CacheTTL: time.Minute * time.Duration(filterConf.CacheTime),
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
filterConf.SafeSearch = safeSearch
|
filterConf.SafeSearch = safeSearch
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
package filtering
|
package filtering
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
// SafeSearch interface describes a service for search engines hosts rewrites.
|
// SafeSearch interface describes a service for search engines hosts rewrites.
|
||||||
type SafeSearch interface {
|
type SafeSearch interface {
|
||||||
// CheckHost checks host with safe search filter. CheckHost must be safe
|
// CheckHost checks host with safe search filter. CheckHost must be safe
|
||||||
// for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA].
|
// for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA].
|
||||||
CheckHost(host string, qtype uint16) (res Result, err error)
|
CheckHost(ctx context.Context, host string, qtype uint16) (res Result, err error)
|
||||||
|
|
||||||
// Update updates the configuration of the safe search filter. Update must
|
// Update updates the configuration of the safe search filter. Update must
|
||||||
// be safe for concurrent use. An implementation of Update may ignore some
|
// be safe for concurrent use. An implementation of Update may ignore some
|
||||||
// fields, but it must document which.
|
// fields, but it must document which.
|
||||||
Update(conf SafeSearchConfig) (err error)
|
Update(ctx context.Context, conf SafeSearchConfig) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SafeSearchConfig is a struct with safe search related settings.
|
// SafeSearchConfig is a struct with safe search related settings.
|
||||||
@ -40,10 +42,13 @@ func (d *DNSFilter) checkSafeSearch(
|
|||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(s.chzhen): Pass context.
|
||||||
|
ctx := context.TODO()
|
||||||
|
|
||||||
clientSafeSearch := setts.ClientSafeSearch
|
clientSafeSearch := setts.ClientSafeSearch
|
||||||
if clientSafeSearch != nil {
|
if clientSafeSearch != nil {
|
||||||
return clientSafeSearch.CheckHost(host, qtype)
|
return clientSafeSearch.CheckHost(ctx, host, qtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.safeSearch.CheckHost(host, qtype)
|
return d.safeSearch.CheckHost(ctx, host, qtype)
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,11 @@ package safesearch
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -14,13 +16,20 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||||
"github.com/AdguardTeam/golibs/cache"
|
"github.com/AdguardTeam/golibs/cache"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/urlfilter"
|
"github.com/AdguardTeam/urlfilter"
|
||||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
|
"github.com/c2h5oh/datasize"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Attribute keys and values for logging.
|
||||||
|
const (
|
||||||
|
LogPrefix = "safesearch"
|
||||||
|
LogKeyClient = "client"
|
||||||
|
)
|
||||||
|
|
||||||
// Service is a enum with service names used as search providers.
|
// Service is a enum with service names used as search providers.
|
||||||
type Service string
|
type Service string
|
||||||
|
|
||||||
@ -57,9 +66,32 @@ func isServiceProtected(s filtering.SafeSearchConfig, service Service) (ok bool)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultConfig is the configuration structure for [Default].
|
||||||
|
type DefaultConfig struct {
|
||||||
|
// Logger is used for logging the operation of the safe search filter.
|
||||||
|
Logger *slog.Logger
|
||||||
|
|
||||||
|
// ClientName is the name of the persistent client associated with the safe
|
||||||
|
// search filter, if there is one.
|
||||||
|
ClientName string
|
||||||
|
|
||||||
|
// CacheSize is the size of the filter results cache.
|
||||||
|
CacheSize uint
|
||||||
|
|
||||||
|
// CacheTTL is the Time to Live duration for cached items.
|
||||||
|
CacheTTL time.Duration
|
||||||
|
|
||||||
|
// ServicesConfig contains safe search settings for services. It must not
|
||||||
|
// be nil.
|
||||||
|
ServicesConfig filtering.SafeSearchConfig
|
||||||
|
}
|
||||||
|
|
||||||
// Default is the default safe search filter that uses filtering rules with the
|
// Default is the default safe search filter that uses filtering rules with the
|
||||||
// dnsrewrite modifier.
|
// dnsrewrite modifier.
|
||||||
type Default struct {
|
type Default struct {
|
||||||
|
// logger is used for logging the operation of the safe search filter.
|
||||||
|
logger *slog.Logger
|
||||||
|
|
||||||
// mu protects engine.
|
// mu protects engine.
|
||||||
mu *sync.RWMutex
|
mu *sync.RWMutex
|
||||||
|
|
||||||
@ -67,33 +99,28 @@ type Default struct {
|
|||||||
// engine may be nil, which means that this safe search filter is disabled.
|
// engine may be nil, which means that this safe search filter is disabled.
|
||||||
engine *urlfilter.DNSEngine
|
engine *urlfilter.DNSEngine
|
||||||
|
|
||||||
cache cache.Cache
|
// cache stores safe search filtering results.
|
||||||
logPrefix string
|
cache cache.Cache
|
||||||
cacheTTL time.Duration
|
|
||||||
|
// cacheTTL is the Time to Live duration for cached items.
|
||||||
|
cacheTTL time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefault returns an initialized default safe search filter. name is used
|
// NewDefault returns an initialized default safe search filter. ctx is used
|
||||||
// for logging.
|
// to log the initial refresh.
|
||||||
func NewDefault(
|
func NewDefault(ctx context.Context, conf *DefaultConfig) (ss *Default, err error) {
|
||||||
conf filtering.SafeSearchConfig,
|
|
||||||
name string,
|
|
||||||
cacheSize uint,
|
|
||||||
cacheTTL time.Duration,
|
|
||||||
) (ss *Default, err error) {
|
|
||||||
ss = &Default{
|
ss = &Default{
|
||||||
mu: &sync.RWMutex{},
|
logger: conf.Logger,
|
||||||
|
mu: &sync.RWMutex{},
|
||||||
cache: cache.New(cache.Config{
|
cache: cache.New(cache.Config{
|
||||||
EnableLRU: true,
|
EnableLRU: true,
|
||||||
MaxSize: cacheSize,
|
MaxSize: conf.CacheSize,
|
||||||
}),
|
}),
|
||||||
// Use %s, because the client safe-search names already contain double
|
cacheTTL: conf.CacheTTL,
|
||||||
// quotes.
|
|
||||||
logPrefix: fmt.Sprintf("safesearch %s: ", name),
|
|
||||||
cacheTTL: cacheTTL,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf)
|
// TODO(s.chzhen): Move to [Default.InitialRefresh].
|
||||||
|
err = ss.resetEngine(ctx, rulelist.URLFilterIDSafeSearch, conf.ServicesConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
// Don't wrap the error, because it's informative enough as is.
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -102,29 +129,15 @@ func NewDefault(
|
|||||||
return ss, nil
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// log is a helper for logging that includes the name of the safe search
|
|
||||||
// filter. level must be one of [log.DEBUG], [log.INFO], and [log.ERROR].
|
|
||||||
func (ss *Default) log(level log.Level, msg string, args ...any) {
|
|
||||||
switch level {
|
|
||||||
case log.DEBUG:
|
|
||||||
log.Debug(ss.logPrefix+msg, args...)
|
|
||||||
case log.INFO:
|
|
||||||
log.Info(ss.logPrefix+msg, args...)
|
|
||||||
case log.ERROR:
|
|
||||||
log.Error(ss.logPrefix+msg, args...)
|
|
||||||
default:
|
|
||||||
panic(fmt.Errorf("safesearch: unsupported logging level %d", level))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// resetEngine creates new engine for provided safe search configuration and
|
// resetEngine creates new engine for provided safe search configuration and
|
||||||
// sets it in ss.
|
// sets it in ss.
|
||||||
func (ss *Default) resetEngine(
|
func (ss *Default) resetEngine(
|
||||||
|
ctx context.Context,
|
||||||
listID int,
|
listID int,
|
||||||
conf filtering.SafeSearchConfig,
|
conf filtering.SafeSearchConfig,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
if !conf.Enabled {
|
if !conf.Enabled {
|
||||||
ss.log(log.INFO, "disabled")
|
ss.logger.DebugContext(ctx, "disabled")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -149,7 +162,7 @@ func (ss *Default) resetEngine(
|
|||||||
|
|
||||||
ss.engine = urlfilter.NewDNSEngine(rs)
|
ss.engine = urlfilter.NewDNSEngine(rs)
|
||||||
|
|
||||||
ss.log(log.INFO, "reset %d rules", ss.engine.RulesCount)
|
ss.logger.InfoContext(ctx, "reset rules", "count", ss.engine.RulesCount)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -158,10 +171,14 @@ func (ss *Default) resetEngine(
|
|||||||
var _ filtering.SafeSearch = (*Default)(nil)
|
var _ filtering.SafeSearch = (*Default)(nil)
|
||||||
|
|
||||||
// CheckHost implements the [filtering.SafeSearch] interface for *Default.
|
// CheckHost implements the [filtering.SafeSearch] interface for *Default.
|
||||||
func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Result, err error) {
|
func (ss *Default) CheckHost(
|
||||||
|
ctx context.Context,
|
||||||
|
host string,
|
||||||
|
qtype rules.RRType,
|
||||||
|
) (res filtering.Result, err error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
|
ss.logger.DebugContext(ctx, "lookup finished", "host", host, "elapsed", time.Since(start))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
switch qtype {
|
switch qtype {
|
||||||
@ -172,9 +189,9 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check cache. Return cached result if it was found
|
// Check cache. Return cached result if it was found
|
||||||
cachedValue, isFound := ss.getCachedResult(host, qtype)
|
cachedValue, isFound := ss.getCachedResult(ctx, host, qtype)
|
||||||
if isFound {
|
if isFound {
|
||||||
ss.log(log.DEBUG, "found in cache: %q", host)
|
ss.logger.DebugContext(ctx, "found in cache", "host", host)
|
||||||
|
|
||||||
return cachedValue, nil
|
return cachedValue, nil
|
||||||
}
|
}
|
||||||
@ -186,7 +203,7 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
|
|||||||
|
|
||||||
fltRes, err := ss.newResult(rewrite, qtype)
|
fltRes, err := ss.newResult(rewrite, qtype)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ss.log(log.DEBUG, "looking up addresses for %q: %s", host, err)
|
ss.logger.ErrorContext(ctx, "looking up addresses", "host", host, slogutil.KeyError, err)
|
||||||
|
|
||||||
return filtering.Result{}, err
|
return filtering.Result{}, err
|
||||||
}
|
}
|
||||||
@ -195,7 +212,7 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
|
|||||||
|
|
||||||
// TODO(a.garipov): Consider switch back to resolving CNAME records IPs and
|
// TODO(a.garipov): Consider switch back to resolving CNAME records IPs and
|
||||||
// saving results to cache.
|
// saving results to cache.
|
||||||
ss.setCacheResult(host, qtype, res)
|
ss.setCacheResult(ctx, host, qtype, res)
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
@ -255,7 +272,12 @@ func (ss *Default) newResult(
|
|||||||
|
|
||||||
// setCacheResult stores data in cache for host. qtype is expected to be either
|
// setCacheResult stores data in cache for host. qtype is expected to be either
|
||||||
// [dns.TypeA] or [dns.TypeAAAA].
|
// [dns.TypeA] or [dns.TypeAAAA].
|
||||||
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {
|
func (ss *Default) setCacheResult(
|
||||||
|
ctx context.Context,
|
||||||
|
host string,
|
||||||
|
qtype rules.RRType,
|
||||||
|
res filtering.Result,
|
||||||
|
) {
|
||||||
expire := uint32(time.Now().Add(ss.cacheTTL).Unix())
|
expire := uint32(time.Now().Add(ss.cacheTTL).Unix())
|
||||||
exp := make([]byte, 4)
|
exp := make([]byte, 4)
|
||||||
binary.BigEndian.PutUint32(exp, expire)
|
binary.BigEndian.PutUint32(exp, expire)
|
||||||
@ -263,7 +285,7 @@ func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering
|
|||||||
|
|
||||||
err := gob.NewEncoder(buf).Encode(res)
|
err := gob.NewEncoder(buf).Encode(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ss.log(log.ERROR, "cache encoding: %s", err)
|
ss.logger.ErrorContext(ctx, "cache encoding", slogutil.KeyError, err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -271,12 +293,18 @@ func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering
|
|||||||
val := buf.Bytes()
|
val := buf.Bytes()
|
||||||
_ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val)
|
_ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val)
|
||||||
|
|
||||||
ss.log(log.DEBUG, "stored in cache: %q, %d bytes", host, len(val))
|
ss.logger.DebugContext(
|
||||||
|
ctx,
|
||||||
|
"stored in cache",
|
||||||
|
"host", host,
|
||||||
|
"entry_size", datasize.ByteSize(len(val)),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCachedResult returns stored data from cache for host. qtype is expected
|
// getCachedResult returns stored data from cache for host. qtype is expected
|
||||||
// to be either [dns.TypeA] or [dns.TypeAAAA].
|
// to be either [dns.TypeA] or [dns.TypeAAAA].
|
||||||
func (ss *Default) getCachedResult(
|
func (ss *Default) getCachedResult(
|
||||||
|
ctx context.Context,
|
||||||
host string,
|
host string,
|
||||||
qtype rules.RRType,
|
qtype rules.RRType,
|
||||||
) (res filtering.Result, ok bool) {
|
) (res filtering.Result, ok bool) {
|
||||||
@ -298,7 +326,7 @@ func (ss *Default) getCachedResult(
|
|||||||
|
|
||||||
err := gob.NewDecoder(buf).Decode(&res)
|
err := gob.NewDecoder(buf).Decode(&res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ss.log(log.ERROR, "cache decoding: %s", err)
|
ss.logger.ErrorContext(ctx, "cache decoding", slogutil.KeyError, err)
|
||||||
|
|
||||||
return filtering.Result{}, false
|
return filtering.Result{}, false
|
||||||
}
|
}
|
||||||
@ -308,11 +336,11 @@ func (ss *Default) getCachedResult(
|
|||||||
|
|
||||||
// Update implements the [filtering.SafeSearch] interface for *Default. Update
|
// Update implements the [filtering.SafeSearch] interface for *Default. Update
|
||||||
// ignores the CustomResolver and Enabled fields.
|
// ignores the CustomResolver and Enabled fields.
|
||||||
func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) {
|
func (ss *Default) Update(ctx context.Context, conf filtering.SafeSearchConfig) (err error) {
|
||||||
ss.mu.Lock()
|
ss.mu.Lock()
|
||||||
defer ss.mu.Unlock()
|
defer ss.mu.Unlock()
|
||||||
|
|
||||||
err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf)
|
err = ss.resetEngine(ctx, rulelist.URLFilterIDSafeSearch, conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
// Don't wrap the error, because it's informative enough as is.
|
||||||
return err
|
return err
|
||||||
|
@ -6,6 +6,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -21,6 +23,9 @@ const (
|
|||||||
testCacheTTL = 30 * time.Minute
|
testCacheTTL = 30 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// testTimeout is the common timeout for tests and contexts.
|
||||||
|
const testTimeout = 1 * time.Second
|
||||||
|
|
||||||
var defaultSafeSearchConf = filtering.SafeSearchConfig{
|
var defaultSafeSearchConf = filtering.SafeSearchConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Bing: true,
|
Bing: true,
|
||||||
@ -35,7 +40,12 @@ var defaultSafeSearchConf = filtering.SafeSearchConfig{
|
|||||||
var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56})
|
var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56})
|
||||||
|
|
||||||
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) {
|
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) {
|
||||||
ss, err := NewDefault(ssConf, "", testCacheSize, testCacheTTL)
|
ss, err := NewDefault(testutil.ContextWithTimeout(t, testTimeout), &DefaultConfig{
|
||||||
|
Logger: slogutil.NewDiscardLogger(),
|
||||||
|
ServicesConfig: ssConf,
|
||||||
|
CacheSize: testCacheSize,
|
||||||
|
CacheTTL: testCacheTTL,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return ss
|
return ss
|
||||||
@ -52,16 +62,17 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
|||||||
const domain = "yandex.ru"
|
const domain = "yandex.ru"
|
||||||
|
|
||||||
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
|
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
|
||||||
|
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||||
|
|
||||||
// Check host with disabled safesearch.
|
// Check host with disabled safesearch.
|
||||||
res, err := ss.CheckHost(domain, testQType)
|
res, err := ss.CheckHost(ctx, domain, testQType)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
assert.Empty(t, res.Rules)
|
assert.Empty(t, res.Rules)
|
||||||
|
|
||||||
ss = newForTest(t, defaultSafeSearchConf)
|
ss = newForTest(t, defaultSafeSearchConf)
|
||||||
res, err = ss.CheckHost(domain, testQType)
|
res, err = ss.CheckHost(ctx, domain, testQType)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// For yandex we already know valid IP.
|
// For yandex we already know valid IP.
|
||||||
@ -70,7 +81,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
|||||||
assert.Equal(t, res.Rules[0].IP, yandexIP)
|
assert.Equal(t, res.Rules[0].IP, yandexIP)
|
||||||
|
|
||||||
// Check cache.
|
// Check cache.
|
||||||
cachedValue, isFound := ss.getCachedResult(domain, testQType)
|
cachedValue, isFound := ss.getCachedResult(ctx, domain, testQType)
|
||||||
require.True(t, isFound)
|
require.True(t, isFound)
|
||||||
require.Len(t, cachedValue.Rules, 1)
|
require.Len(t, cachedValue.Rules, 1)
|
||||||
|
|
||||||
|
@ -10,15 +10,15 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
// testTimeout is the common timeout for tests and contexts.
|
||||||
testutil.DiscardLogOutput(m)
|
const testTimeout = 1 * time.Second
|
||||||
}
|
|
||||||
|
|
||||||
// Common test constants.
|
// Common test constants.
|
||||||
const (
|
const (
|
||||||
@ -47,7 +47,13 @@ var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56})
|
|||||||
|
|
||||||
func TestDefault_CheckHost_yandex(t *testing.T) {
|
func TestDefault_CheckHost_yandex(t *testing.T) {
|
||||||
conf := testConf
|
conf := testConf
|
||||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||||
|
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
|
||||||
|
Logger: slogutil.NewDiscardLogger(),
|
||||||
|
ServicesConfig: conf,
|
||||||
|
CacheSize: testCacheSize,
|
||||||
|
CacheTTL: testCacheTTL,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
hosts := []string{
|
hosts := []string{
|
||||||
@ -82,7 +88,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
|
|||||||
for _, host := range hosts {
|
for _, host := range hosts {
|
||||||
// Check host for each domain.
|
// Check host for each domain.
|
||||||
var res filtering.Result
|
var res filtering.Result
|
||||||
res, err = ss.CheckHost(host, tc.qt)
|
res, err = ss.CheckHost(ctx, host, tc.qt)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
@ -103,7 +109,13 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDefault_CheckHost_google(t *testing.T) {
|
func TestDefault_CheckHost_google(t *testing.T) {
|
||||||
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
|
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||||
|
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
|
||||||
|
Logger: slogutil.NewDiscardLogger(),
|
||||||
|
ServicesConfig: testConf,
|
||||||
|
CacheSize: testCacheSize,
|
||||||
|
CacheTTL: testCacheTTL,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Check host for each domain.
|
// Check host for each domain.
|
||||||
@ -118,7 +130,7 @@ func TestDefault_CheckHost_google(t *testing.T) {
|
|||||||
} {
|
} {
|
||||||
t.Run(host, func(t *testing.T) {
|
t.Run(host, func(t *testing.T) {
|
||||||
var res filtering.Result
|
var res filtering.Result
|
||||||
res, err = ss.CheckHost(host, testQType)
|
res, err = ss.CheckHost(ctx, host, testQType)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
@ -149,13 +161,19 @@ func (r *testResolver) LookupIP(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
|
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
|
||||||
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
|
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||||
|
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
|
||||||
|
Logger: slogutil.NewDiscardLogger(),
|
||||||
|
ServicesConfig: testConf,
|
||||||
|
CacheSize: testCacheSize,
|
||||||
|
CacheTTL: testCacheTTL,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
|
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
|
||||||
// DuckDuckGo doesn't have a safe-search IPv6 address. The result should be
|
// DuckDuckGo doesn't have a safe-search IPv6 address. The result should be
|
||||||
// the same as the one for Yandex IPv6. That is, a NODATA response.
|
// the same as the one for Yandex IPv6. That is, a NODATA response.
|
||||||
res, err := ss.CheckHost("www.duckduckgo.com", dns.TypeAAAA)
|
res, err := ss.CheckHost(ctx, "www.duckduckgo.com", dns.TypeAAAA)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
@ -166,32 +184,38 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
|
|||||||
|
|
||||||
func TestDefault_Update(t *testing.T) {
|
func TestDefault_Update(t *testing.T) {
|
||||||
conf := testConf
|
conf := testConf
|
||||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||||
|
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
|
||||||
|
Logger: slogutil.NewDiscardLogger(),
|
||||||
|
ServicesConfig: conf,
|
||||||
|
CacheSize: testCacheSize,
|
||||||
|
CacheTTL: testCacheTTL,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
res, err := ss.CheckHost("www.yandex.com", testQType)
|
res, err := ss.CheckHost(ctx, "www.yandex.com", testQType)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
|
||||||
err = ss.Update(filtering.SafeSearchConfig{
|
err = ss.Update(ctx, filtering.SafeSearchConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Google: false,
|
Google: false,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
res, err = ss.CheckHost("www.yandex.com", testQType)
|
res, err = ss.CheckHost(ctx, "www.yandex.com", testQType)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
|
|
||||||
err = ss.Update(filtering.SafeSearchConfig{
|
err = ss.Update(ctx, filtering.SafeSearchConfig{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
Google: true,
|
Google: true,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
res, err = ss.CheckHost("www.yandex.com", testQType)
|
res, err = ss.CheckHost(ctx, "www.yandex.com", testQType)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
|
@ -51,7 +51,7 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ
|
|||||||
}
|
}
|
||||||
|
|
||||||
conf := *req
|
conf := *req
|
||||||
err = d.safeSearch.Update(conf)
|
err = d.safeSearch.Update(r.Context(), conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err)
|
aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err)
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package home
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
@ -13,17 +14,23 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// clientsContainer is the storage of all runtime and persistent clients.
|
// clientsContainer is the storage of all runtime and persistent clients.
|
||||||
type clientsContainer struct {
|
type clientsContainer struct {
|
||||||
|
// baseLogger is used to create loggers with custom prefixes for safe search
|
||||||
|
// filter. It must not be nil.
|
||||||
|
baseLogger *slog.Logger
|
||||||
|
|
||||||
// storage stores information about persistent clients.
|
// storage stores information about persistent clients.
|
||||||
storage *client.Storage
|
storage *client.Storage
|
||||||
|
|
||||||
@ -61,6 +68,8 @@ type BlockedClientChecker interface {
|
|||||||
// dhcpServer: optional
|
// dhcpServer: optional
|
||||||
// Note: this function must be called only once
|
// Note: this function must be called only once
|
||||||
func (clients *clientsContainer) Init(
|
func (clients *clientsContainer) Init(
|
||||||
|
ctx context.Context,
|
||||||
|
baseLogger *slog.Logger,
|
||||||
objects []*clientObject,
|
objects []*clientObject,
|
||||||
dhcpServer client.DHCP,
|
dhcpServer client.DHCP,
|
||||||
etcHosts *aghnet.HostsContainer,
|
etcHosts *aghnet.HostsContainer,
|
||||||
@ -72,13 +81,14 @@ func (clients *clientsContainer) Init(
|
|||||||
return errors.Error("clients container already initialized")
|
return errors.Error("clients container already initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
clients.baseLogger = baseLogger
|
||||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
||||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
||||||
|
|
||||||
confClients := make([]*client.Persistent, 0, len(objects))
|
confClients := make([]*client.Persistent, 0, len(objects))
|
||||||
for i, o := range objects {
|
for i, o := range objects {
|
||||||
var p *client.Persistent
|
var p *client.Persistent
|
||||||
p, err = o.toPersistent(clients.safeSearchCacheSize, clients.safeSearchCacheTTL)
|
p, err = o.toPersistent(ctx, baseLogger, clients.safeSearchCacheSize, clients.safeSearchCacheTTL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("init persistent client at index %d: %w", i, err)
|
return fmt.Errorf("init persistent client at index %d: %w", i, err)
|
||||||
}
|
}
|
||||||
@ -168,6 +178,8 @@ type clientObject struct {
|
|||||||
|
|
||||||
// toPersistent returns an initialized persistent client if there are no errors.
|
// toPersistent returns an initialized persistent client if there are no errors.
|
||||||
func (o *clientObject) toPersistent(
|
func (o *clientObject) toPersistent(
|
||||||
|
ctx context.Context,
|
||||||
|
baseLogger *slog.Logger,
|
||||||
safeSearchCacheSize uint,
|
safeSearchCacheSize uint,
|
||||||
safeSearchCacheTTL time.Duration,
|
safeSearchCacheTTL time.Duration,
|
||||||
) (cli *client.Persistent, err error) {
|
) (cli *client.Persistent, err error) {
|
||||||
@ -203,14 +215,23 @@ func (o *clientObject) toPersistent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if o.SafeSearchConf.Enabled {
|
if o.SafeSearchConf.Enabled {
|
||||||
err = cli.SetSafeSearch(
|
logger := baseLogger.With(
|
||||||
o.SafeSearchConf,
|
slogutil.KeyPrefix, safesearch.LogPrefix,
|
||||||
safeSearchCacheSize,
|
safesearch.LogKeyClient, cli.Name,
|
||||||
safeSearchCacheTTL,
|
|
||||||
)
|
)
|
||||||
|
var ss *safesearch.Default
|
||||||
|
ss, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
|
||||||
|
Logger: logger,
|
||||||
|
ServicesConfig: o.SafeSearchConf,
|
||||||
|
ClientName: cli.Name,
|
||||||
|
CacheSize: safeSearchCacheSize,
|
||||||
|
CacheTTL: safeSearchCacheTTL,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("init safesearch %q: %w", cli.Name, err)
|
return nil, fmt.Errorf("init safesearch %q: %w", cli.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cli.SafeSearch = ss
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.BlockedServices == nil {
|
if o.BlockedServices == nil {
|
||||||
|
@ -7,6 +7,8 @@ import (
|
|||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@ -20,7 +22,18 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
|||||||
testing: true,
|
testing: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, c.Init(nil, client.EmptyDHCP{}, nil, nil, &filtering.Config{}))
|
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||||
|
err := c.Init(
|
||||||
|
ctx,
|
||||||
|
slogutil.NewDiscardLogger(),
|
||||||
|
nil,
|
||||||
|
client.EmptyDHCP{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
&filtering.Config{},
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package home
|
package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -10,8 +11,10 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// clientJSON is a common structure used by several handlers to deal with
|
// clientJSON is a common structure used by several handlers to deal with
|
||||||
@ -181,6 +184,7 @@ func initPrev(cj clientJSON, prev *client.Persistent) (c *client.Persistent, err
|
|||||||
// jsonToClient converts JSON object to persistent client object if there are no
|
// jsonToClient converts JSON object to persistent client object if there are no
|
||||||
// errors.
|
// errors.
|
||||||
func (clients *clientsContainer) jsonToClient(
|
func (clients *clientsContainer) jsonToClient(
|
||||||
|
ctx context.Context,
|
||||||
cj clientJSON,
|
cj clientJSON,
|
||||||
prev *client.Persistent,
|
prev *client.Persistent,
|
||||||
) (c *client.Persistent, err error) {
|
) (c *client.Persistent, err error) {
|
||||||
@ -207,14 +211,23 @@ func (clients *clientsContainer) jsonToClient(
|
|||||||
c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices
|
c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices
|
||||||
|
|
||||||
if c.SafeSearchConf.Enabled {
|
if c.SafeSearchConf.Enabled {
|
||||||
err = c.SetSafeSearch(
|
logger := clients.baseLogger.With(
|
||||||
c.SafeSearchConf,
|
slogutil.KeyPrefix, safesearch.LogPrefix,
|
||||||
clients.safeSearchCacheSize,
|
safesearch.LogKeyClient, c.Name,
|
||||||
clients.safeSearchCacheTTL,
|
|
||||||
)
|
)
|
||||||
|
var ss *safesearch.Default
|
||||||
|
ss, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
|
||||||
|
Logger: logger,
|
||||||
|
ServicesConfig: c.SafeSearchConf,
|
||||||
|
ClientName: c.Name,
|
||||||
|
CacheSize: clients.safeSearchCacheSize,
|
||||||
|
CacheTTL: clients.safeSearchCacheTTL,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err)
|
return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.SafeSearch = ss
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
@ -321,7 +334,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := clients.jsonToClient(cj, nil)
|
c, err := clients.jsonToClient(r.Context(), cj, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||||
|
|
||||||
@ -391,7 +404,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := clients.jsonToClient(dj.Data, nil)
|
c, err := clients.jsonToClient(r.Context(), dj.Data, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||||
|
|
||||||
|
@ -11,14 +11,19 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// testTimeout is the common timeout for tests and contexts.
|
||||||
|
const testTimeout = 1 * time.Second
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testClientIP1 = "1.1.1.1"
|
testClientIP1 = "1.1.1.1"
|
||||||
testClientIP2 = "2.2.2.2"
|
testClientIP2 = "2.2.2.2"
|
||||||
@ -103,9 +108,10 @@ func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*c
|
|||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
|
|
||||||
var got []*client.Persistent
|
var got []*client.Persistent
|
||||||
|
ctx := testutil.ContextWithTimeout(tb, testTimeout)
|
||||||
for _, cj := range clientList.Clients {
|
for _, cj := range clientList.Clients {
|
||||||
var c *client.Persistent
|
var c *client.Persistent
|
||||||
c, err = clients.jsonToClient(*cj, nil)
|
c, err = clients.jsonToClient(ctx, *cj, nil)
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
|
|
||||||
got = append(got, c)
|
got = append(got, c)
|
||||||
@ -125,10 +131,11 @@ func assertPersistentClientsData(
|
|||||||
tb.Helper()
|
tb.Helper()
|
||||||
|
|
||||||
var got []*client.Persistent
|
var got []*client.Persistent
|
||||||
|
ctx := testutil.ContextWithTimeout(tb, testTimeout)
|
||||||
for _, cm := range data {
|
for _, cm := range data {
|
||||||
for _, cj := range cm {
|
for _, cj := range cm {
|
||||||
var c *client.Persistent
|
var c *client.Persistent
|
||||||
c, err := clients.jsonToClient(*cj, nil)
|
c, err := clients.jsonToClient(ctx, *cj, nil)
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
|
|
||||||
got = append(got, c)
|
got = append(got, c)
|
||||||
|
@ -278,8 +278,8 @@ func setupOpts(opts options) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// initContextClients initializes Context clients and related fields.
|
// initContextClients initializes Context clients and related fields.
|
||||||
func initContextClients(logger *slog.Logger) (err error) {
|
func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
|
||||||
err = setupDNSFilteringConf(config.Filtering)
|
err = setupDNSFilteringConf(ctx, logger, config.Filtering)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
// Don't wrap the error, because it's informative enough as is.
|
||||||
return err
|
return err
|
||||||
@ -306,6 +306,8 @@ func initContextClients(logger *slog.Logger) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return Context.clients.Init(
|
return Context.clients.Init(
|
||||||
|
ctx,
|
||||||
|
logger,
|
||||||
config.Clients.Persistent,
|
config.Clients.Persistent,
|
||||||
Context.dhcpServer,
|
Context.dhcpServer,
|
||||||
Context.etcHosts,
|
Context.etcHosts,
|
||||||
@ -355,7 +357,11 @@ func setupBindOpts(opts options) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// setupDNSFilteringConf sets up DNS filtering configuration settings.
|
// setupDNSFilteringConf sets up DNS filtering configuration settings.
|
||||||
func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
func setupDNSFilteringConf(
|
||||||
|
ctx context.Context,
|
||||||
|
baseLogger *slog.Logger,
|
||||||
|
conf *filtering.Config,
|
||||||
|
) (err error) {
|
||||||
const (
|
const (
|
||||||
dnsTimeout = 3 * time.Second
|
dnsTimeout = 3 * time.Second
|
||||||
|
|
||||||
@ -446,12 +452,13 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
|||||||
conf.ParentalBlockHost = host
|
conf.ParentalBlockHost = host
|
||||||
}
|
}
|
||||||
|
|
||||||
conf.SafeSearch, err = safesearch.NewDefault(
|
logger := baseLogger.With(slogutil.KeyPrefix, safesearch.LogPrefix)
|
||||||
conf.SafeSearchConf,
|
conf.SafeSearch, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
|
||||||
"default",
|
Logger: logger,
|
||||||
conf.SafeSearchCacheSize,
|
ServicesConfig: conf.SafeSearchConf,
|
||||||
cacheTime,
|
CacheSize: conf.SafeSearchCacheSize,
|
||||||
)
|
CacheTTL: cacheTime,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("initializing safesearch: %w", err)
|
return fmt.Errorf("initializing safesearch: %w", err)
|
||||||
}
|
}
|
||||||
@ -584,7 +591,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
|||||||
// data first, but also to avoid relying on automatic Go init() function.
|
// data first, but also to avoid relying on automatic Go init() function.
|
||||||
filtering.InitModule()
|
filtering.InitModule()
|
||||||
|
|
||||||
err = initContextClients(slogLogger)
|
// TODO(s.chzhen): Use it for the entire initialization process.
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err = initContextClients(ctx, slogLogger)
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
err = setupOpts(opts)
|
err = setupOpts(opts)
|
||||||
|
Loading…
Reference in New Issue
Block a user