diff --git a/go.mod b/go.mod index 5e62e8ef..375bed97 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome go 1.20 require ( - github.com/AdguardTeam/dnsproxy v0.57.3 + github.com/AdguardTeam/dnsproxy v0.58.0 github.com/AdguardTeam/golibs v0.17.2 github.com/AdguardTeam/urlfilter v0.17.3 github.com/NYTimes/gziphandler v1.1.1 diff --git a/go.sum b/go.sum index aa43d0dd..dcb90678 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/dnsproxy v0.57.3 h1:0v7D+LQrOL2k2fvkG3Ft3Cn3ayUsvAdlOlJR+gLxSGA= -github.com/AdguardTeam/dnsproxy v0.57.3/go.mod h1:ZvkbM71HwpilgkCnTubDiR4Ba6x5Qvnhy2iasMWaTDM= +github.com/AdguardTeam/dnsproxy v0.58.0 h1:1zPmDYWIc60D5Mn2idt3TcH+CQzKBvkWzJ5/u49wraw= +github.com/AdguardTeam/dnsproxy v0.58.0/go.mod h1:ZvkbM71HwpilgkCnTubDiR4Ba6x5Qvnhy2iasMWaTDM= github.com/AdguardTeam/golibs v0.17.2 h1:vg6wHMjUKscnyPGRvxS5kAt7Uw4YxcJiITZliZ476W8= github.com/AdguardTeam/golibs v0.17.2/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U= github.com/AdguardTeam/urlfilter v0.17.3 h1:fg/ObbnO0Cv6aw0tW6N/ETDMhhNvmcUUOZ7HlmKC3rw= diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index 0d9a4bcc..e27b115f 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -1,14 +1,17 @@ package aghnet import ( + "context" "fmt" "io" "io/fs" "net/netip" "path" + "strings" "sync/atomic" "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" @@ -141,13 +144,9 @@ func NewHostsContainer( func (hc *HostsContainer) Close() (err error) { log.Debug("%s: closing", hostsContainerPrefix) - err = hc.watcher.Close() - if err != nil { - err = fmt.Errorf("closing fs watcher: %w", err) - - // Go on and close the container either way. - } + err = errors.Annotate(hc.watcher.Close(), "closing fs watcher: %w") + // Go on and close the container either way. close(hc.done) return err @@ -319,3 +318,39 @@ func (hc *HostsContainer) refresh() (err error) { return nil } + +// type check +var _ upstream.Resolver = (*HostsContainer)(nil) + +// LookupNetIP implements the [upstream.Resolver] interface for *HostsContainer. +func (hc *HostsContainer) LookupNetIP( + ctx context.Context, + network string, + hostname string, +) (addrs []netip.Addr, err error) { + // TODO(e.burkov): Think of extracting this logic to a golibs function if + // needed anywhere else. + var isDesiredProto func(ip netip.Addr) (ok bool) + switch network { + case "ip4": + isDesiredProto = (netip.Addr).Is4 + case "ip6": + isDesiredProto = (netip.Addr).Is6 + case "ip": + isDesiredProto = func(ip netip.Addr) (ok bool) { return true } + default: + return nil, fmt.Errorf("unsupported network: %q", network) + } + + idx := hc.current.Load() + recs := idx.names[strings.ToLower(hostname)] + + addrs = make([]netip.Addr, 0, len(recs)) + for _, rec := range recs { + if isDesiredProto(rec.Addr) { + addrs = append(addrs, rec.Addr) + } + } + + return slices.Clip(addrs), nil +} diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index a0489dd8..ac9fb7bd 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -10,9 +10,11 @@ import ( "net" "net/netip" "net/url" + "strings" "syscall" "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -307,6 +309,50 @@ func ParseAddrPort(s string, defaultPort uint16) (ipp netip.AddrPort, err error) return ipp, nil } +// ParseSubnet parses s either as a CIDR prefix itself, or as an IP address, +// returning the corresponding single-IP CIDR prefix. +// +// TODO(e.burkov): Taken from dnsproxy, move to golibs. +func ParseSubnet(s string) (p netip.Prefix, err error) { + if strings.Contains(s, "/") { + p, err = netip.ParsePrefix(s) + if err != nil { + return netip.Prefix{}, err + } + } else { + var ip netip.Addr + ip, err = netip.ParseAddr(s) + if err != nil { + return netip.Prefix{}, err + } + + p = netip.PrefixFrom(ip, ip.BitLen()) + } + + return p, nil +} + +// ParseBootstraps returns the slice of upstream resolvers parsed from addrs. +// It additionally returns the closers for each resolver, that should be closed +// after use. +func ParseBootstraps( + addrs []string, + opts *upstream.Options, +) (boots []*upstream.UpstreamResolver, err error) { + boots = make([]*upstream.UpstreamResolver, 0, len(boots)) + for i, b := range addrs { + var r *upstream.UpstreamResolver + r, err = upstream.NewUpstreamResolver(b, opts) + if err != nil { + return nil, fmt.Errorf("bootstrap at index %d: %w", i, err) + } + + boots = append(boots, r) + } + + return boots, nil +} + // BroadcastFromPref calculates the broadcast IP address for p. func BroadcastFromPref(p netip.Prefix) (bc netip.Addr) { bc = p.Addr().Unmap() diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index 10789d8e..3d88f0c1 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/AdGuardHome/internal/whois" + "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/miekg/dns" ) @@ -116,6 +117,26 @@ func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.I p.OnUpdateAddress(ip, host, info) } +// Package dnsforward + +// ClientsContainer is a fake [dnsforward.ClientsContainer] implementation for +// tests. +type ClientsContainer struct { + OnUpstreamConfigByID func( + id string, + boot upstream.Resolver, + ) (conf *proxy.UpstreamConfig, err error) +} + +// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface +// for *ClientsContainer. +func (c *ClientsContainer) UpstreamConfigByID( + id string, + boot upstream.Resolver, +) (conf *proxy.UpstreamConfig, err error) { + return c.OnUpstreamConfigByID(id, boot) +} + // Package filtering // Resolver is a fake [filtering.Resolver] implementation for tests. diff --git a/internal/aghtest/interface_test.go b/internal/aghtest/interface_test.go index a17c5e67..c1a376ba 100644 --- a/internal/aghtest/interface_test.go +++ b/internal/aghtest/interface_test.go @@ -2,6 +2,7 @@ package aghtest_test import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" ) @@ -9,3 +10,6 @@ import ( // type check var _ filtering.Resolver = (*aghtest.Resolver)(nil) + +// type check +var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil) diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 32a0d52a..0947bf8f 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -27,6 +27,16 @@ import ( "golang.org/x/exp/slices" ) +// ClientsContainer provides information about preconfigured DNS clients. +type ClientsContainer interface { + // UpstreamConfigByID returns the custom upstream configuration for the + // client having id, using boot to initialize the one if necessary. It + // returns nil if there is no custom upstream configuration for the client. + // The id is expected to be either a string representation of an IP address + // or the ClientID. + UpstreamConfigByID(id string, boot upstream.Resolver) (conf *proxy.UpstreamConfig, err error) +} + // Config represents the DNS filtering configuration of AdGuard Home. The zero // Config is empty and ready for use. type Config struct { @@ -35,10 +45,9 @@ type Config struct { // FilterHandler is an optional additional filtering callback. FilterHandler func(cliAddr netip.Addr, clientID string, settings *filtering.Settings) `yaml:"-"` - // GetCustomUpstreamByClient is a callback that returns upstreams - // configuration based on the client IP address or ClientID. It returns - // nil if there are no custom upstreams for the client. - GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"` + // ClientsContainer stores the information about special handling of some + // DNS clients. + ClientsContainer ClientsContainer `yaml:"-"` // Anti-DNS amplification @@ -323,8 +332,8 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) { ) for i, s := range srvConf.BogusNXDomain { - var subnet *net.IPNet - subnet, err = netutil.ParseSubnet(s) + var subnet netip.Prefix + subnet, err = aghnet.ParseSubnet(s) if err != nil { log.Error("subnet at index %d: %s", i, err) @@ -423,10 +432,7 @@ func collectListenAddr( // collectDNSAddrs returns configured set of listening addresses. It also // returns a set of ports of each unspecified listening address. -func (conf *ServerConfig) collectDNSAddrs() ( - addrs map[netip.AddrPort]unit, - unspecPorts map[uint16]unit, -) { +func (conf *ServerConfig) collectDNSAddrs() (addrs mapAddrPortSet, unspecPorts map[uint16]unit) { // TODO(e.burkov): Perhaps, we shouldn't allocate as much memory, since the // TCP and UDP listening addresses are currently the same. addrs = make(map[netip.AddrPort]unit, len(conf.TCPListenAddrs)+len(conf.UDPListenAddrs)) @@ -446,20 +452,64 @@ func (conf *ServerConfig) collectDNSAddrs() ( // defaultPlainDNSPort is the default port for plain DNS. const defaultPlainDNSPort uint16 = 53 -// addrPortMatcher is a function that matches an IP address with port. -type addrPortMatcher func(addr netip.AddrPort) (ok bool) +// addrPortSet is a set of [netip.AddrPort] values. +type addrPortSet interface { + // Has returns true if addrPort is in the set. + Has(addrPort netip.AddrPort) (ok bool) +} + +// type check +var _ addrPortSet = emptyAddrPortSet{} + +// emptyAddrPortSet is the [addrPortSet] containing no values. +type emptyAddrPortSet struct{} + +// Has implements the [addrPortSet] interface for [emptyAddrPortSet]. +func (emptyAddrPortSet) Has(_ netip.AddrPort) (ok bool) { return false } + +// mapAddrPortSet is the [addrPortSet] containing values of [netip.AddrPort] as +// keys of a map. +type mapAddrPortSet map[netip.AddrPort]unit + +// type check +var _ addrPortSet = mapAddrPortSet{} + +// Has implements the [addrPortSet] interface for [mapAddrPortSet]. +func (m mapAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) { + _, ok = m[addrPort] + + return ok +} + +// combinedAddrPortSet is the [addrPortSet] defined by some IP addresses along +// with ports, any combination of which is considered being in the set. +type combinedAddrPortSet struct { + // TODO(e.burkov): Use sorted slices in combination with binary search. + ports map[uint16]unit + addrs []netip.Addr +} + +// type check +var _ addrPortSet = (*combinedAddrPortSet)(nil) + +// Has implements the [addrPortSet] interface for [*combinedAddrPortSet]. +func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) { + _, ok = m.ports[addrPort.Port()] + + return ok && slices.Contains(m.addrs, addrPort.Addr()) +} // filterOut filters out all the upstreams that match um. It returns all the // closing errors joined. -func (m addrPortMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) { +func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) { var errs []error delFunc := func(u upstream.Upstream) (ok bool) { // TODO(e.burkov): We should probably consider the protocol of u to // only filter out the listening addresses of the same protocol. addr, parseErr := aghnet.ParseAddrPort(u.Address(), defaultPlainDNSPort) - if parseErr != nil || !m(addr) { + if parseErr != nil || !set.Has(addr) { // Don't filter out the upstream if it either cannot be parsed, or - // does not match um. + // does not match m. return false } @@ -479,26 +529,20 @@ func (m addrPortMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) { return errors.Join(errs...) } -// ourAddrsMatcher returns a matcher that matches all the configured listening +// ourAddrsSet returns an addrPortSet that contains all the configured listening // addresses. -func (conf *ServerConfig) ourAddrsMatcher() (m addrPortMatcher, err error) { +func (conf *ServerConfig) ourAddrsSet() (m addrPortSet, err error) { addrs, unspecPorts := conf.collectDNSAddrs() - if len(addrs) == 0 { + switch { + case len(addrs) == 0: log.Debug("dnsforward: no listen addresses") - // Match no addresses. - return func(_ netip.AddrPort) (ok bool) { return false }, nil - } - - if len(unspecPorts) == 0 { + return emptyAddrPortSet{}, nil + case len(unspecPorts) == 0: log.Debug("dnsforward: filtering out addresses %s", addrs) - m = func(a netip.AddrPort) (ok bool) { - _, ok = addrs[a] - - return ok - } - } else { + return addrs, nil + default: var ifaceAddrs []netip.Addr ifaceAddrs, err = aghnet.CollectAllIfacesAddrs() if err != nil { @@ -508,16 +552,11 @@ func (conf *ServerConfig) ourAddrsMatcher() (m addrPortMatcher, err error) { log.Debug("dnsforward: filtering out addresses %s on ports %d", ifaceAddrs, unspecPorts) - m = func(a netip.AddrPort) (ok bool) { - if _, ok = unspecPorts[a.Port()]; ok { - return slices.Contains(ifaceAddrs, a.Addr()) - } - - return false - } + return &combinedAddrPortSet{ + ports: unspecPorts, + addrs: ifaceAddrs, + }, nil } - - return m, nil } // prepareTLS - prepares TLS configuration for the DNS proxy @@ -574,7 +613,7 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) (err error) { // isWildcard returns true if host is a wildcard hostname. func isWildcard(host string) (ok bool) { - return len(host) >= 2 && host[0] == '*' && host[1] == '.' + return strings.HasPrefix(host, "*.") } // matchesDomainWildcard returns true if host matches the domain wildcard diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index d4c61205..1fa7979e 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -3,6 +3,7 @@ package dnsforward import ( "fmt" + "io" "net" "net/http" "net/netip" @@ -135,8 +136,21 @@ type Server struct { // PTR resolving. sysResolvers SystemResolvers - // recDetector is a cache for recursive requests. It is used to detect - // and prevent recursive requests only for private upstreams. + // etcHosts contains the data from the system's hosts files. + etcHosts upstream.Resolver + + // bootstrap is the resolver for upstreams' hostnames. + bootstrap upstream.Resolver + + // bootResolvers are the resolvers that should be used for + // bootstrapping along with [etcHosts]. + // + // TODO(e.burkov): Use [proxy.UpstreamConfig] when it will implement the + // [upstream.Resolver] interface. + bootResolvers []*upstream.UpstreamResolver + + // recDetector is a cache for recursive requests. It is used to detect and + // prevent recursive requests only for private upstreams. // // See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135. recDetector *recursionDetector @@ -153,8 +167,8 @@ type Server struct { // during the BeforeRequestHandler stage. clientIDCache cache.Cache - // DNS proxy instance for internal usage - // We don't Start() it and so no listen port is required. + // internalProxy resolves internal requests from the application itself. It + // isn't started and so no listen ports are required. internalProxy *proxy.Proxy // isRunning is true if the DNS server is running. @@ -185,6 +199,7 @@ type DNSCreateParams struct { DHCPServer DHCP PrivateNets netutil.SubnetSet Anonymizer *aghnet.IPMut + EtcHosts *aghnet.HostsContainer LocalDomain string } @@ -224,6 +239,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { privateNets: p.PrivateNets, // TODO(e.burkov): Use some case-insensitive string comparison. localDomainSuffix: strings.ToLower(localDomainSuffix), + etcHosts: p.EtcHosts, recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum), clientIDCache: cache.New(cache.Config{ EnableLRU: true, @@ -421,7 +437,7 @@ func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) { return "", 0, ErrRDNSNoData } -// Start starts the DNS server. +// Start starts the DNS server. It must only be called after [Server.Prepare]. func (s *Server) Start() error { s.serverLock.Lock() defer s.serverLock.Unlock() @@ -429,12 +445,14 @@ func (s *Server) Start() error { return s.startLocked() } -// startLocked starts the DNS server without locking. For internal use only. +// startLocked starts the DNS server without locking. s.serverLock is expected +// to be locked. func (s *Server) startLocked() error { err := s.dnsProxy.Start() if err == nil { s.isRunning = true } + return err } @@ -443,21 +461,20 @@ func (s *Server) startLocked() error { // faster than ordinary upstreams. const defaultLocalTimeout = 1 * time.Second -// setupLocalResolvers initializes the resolvers for local addresses. For -// internal use only. -func (s *Server) setupLocalResolvers() (err error) { - matcher, err := s.conf.ourAddrsMatcher() +// setupLocalResolvers initializes the resolvers for local addresses. It +// assumes s.serverLock is locked or the Server not running. +func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) { + set, err := s.conf.ourAddrsSet() if err != nil { // Don't wrap the error because it's informative enough as is. return err } - bootstraps := s.conf.BootstrapDNS resolvers := s.conf.LocalPTRResolvers filterConfig := false if len(resolvers) == 0 { - sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher) + sysResolvers := slices.DeleteFunc(slices.Clone(s.sysResolvers.Addrs()), set.Has) resolvers = make([]string, 0, len(sysResolvers)) for _, r := range sysResolvers { resolvers = append(resolvers, r.String()) @@ -470,7 +487,7 @@ func (s *Server) setupLocalResolvers() (err error) { log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers) uc, err := s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{ - Bootstrap: bootstraps, + Bootstrap: boot, Timeout: defaultLocalTimeout, // TODO(e.burkov): Should we verify server's certificates? PreferIPv6: s.conf.BootstrapPreferIPv6, @@ -480,7 +497,7 @@ func (s *Server) setupLocalResolvers() (err error) { } if filterConfig { - if err = matcher.filterOut(uc); err != nil { + if err = filterOutAddrs(uc, set); err != nil { return fmt.Errorf("filtering private upstreams: %w", err) } } @@ -491,6 +508,7 @@ func (s *Server) setupLocalResolvers() (err error) { }, } + // TODO(e.burkov): Should we also consider the DNS64 usage? if s.conf.UsePrivateRDNS && // Only set the upstream config if there are any upstreams. It's safe // to put nil into [proxy.Config.PrivateRDNSUpstreamConfig]. @@ -517,31 +535,19 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { s.initDefaultSettings() - err = s.prepareIpsetListSettings() - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return fmt.Errorf("preparing ipset settings: %w", err) - } - - err = s.prepareUpstreamSettings() + boot, err := s.prepareInternalDNS() if err != nil { // Don't wrap the error, because it's informative enough as is. return err } - var proxyConfig proxy.Config - proxyConfig, err = s.createProxyConfig() + proxyConfig, err := s.createProxyConfig() if err != nil { return fmt.Errorf("preparing proxy: %w", err) } s.setupDNS64() - err = s.prepareInternalProxy() - if err != nil { - return fmt.Errorf("preparing internal proxy: %w", err) - } - s.access, err = newAccessCtx( s.conf.AllowedClients, s.conf.DisallowedClients, @@ -556,7 +562,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { // TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy. s.dnsProxy = &proxy.Proxy{Config: proxyConfig} - err = s.setupLocalResolvers() + err = s.setupLocalResolvers(boot) if err != nil { return fmt.Errorf("setting up resolvers: %w", err) } @@ -575,6 +581,38 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { return nil } +// prepareInternalDNS initializes the internal state of s before initializing +// the primary DNS proxy instance. It assumes s.serverLock is locked or the +// Server not running. +func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) { + err = s.prepareIpsetListSettings() + if err != nil { + return nil, fmt.Errorf("preparing ipset settings: %w", err) + } + + s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{ + Timeout: DefaultTimeout, + HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), + }) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return nil, err + } + + err = s.prepareUpstreamSettings(boot) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return s.bootstrap, err + } + + err = s.prepareInternalProxy() + if err != nil { + return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err) + } + + return s.bootstrap, nil +} + // setupFallbackDNS initializes the fallback DNS servers. func (s *Server) setupFallbackDNS() (err error) { fallbacks := s.conf.FallbackDNS @@ -598,7 +636,8 @@ func (s *Server) setupFallbackDNS() (err error) { return nil } -// setupAddrProc initializes the address processor. For internal use only. +// setupAddrProc initializes the address processor. It assumes s.serverLock is +// locked or the Server not running. func (s *Server) setupAddrProc() { // TODO(a.garipov): This is a crutch for tests; remove. if s.conf.AddrProcConf == nil { @@ -687,7 +726,8 @@ func (s *Server) Stop() error { return s.stopLocked() } -// stopLocked stops the DNS server without locking. For internal use only. +// stopLocked stops the DNS server without locking. s.serverLock is expected to +// be locked. func (s *Server) stopLocked() (err error) { // TODO(e.burkov, a.garipov): Return critical errors, not just log them. // This will require filtering all the non-critical errors in @@ -700,18 +740,11 @@ func (s *Server) stopLocked() (err error) { } } - if upsConf := s.internalProxy.UpstreamConfig; upsConf != nil { - err = upsConf.Close() - if err != nil { - log.Error("dnsforward: closing internal resolvers: %s", err) - } - } + logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s") + logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s") - if upsConf := s.localResolvers.UpstreamConfig; upsConf != nil { - err = upsConf.Close() - if err != nil { - log.Error("dnsforward: closing local resolvers: %s", err) - } + for _, b := range s.bootResolvers { + logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address()) } s.isRunning = false @@ -719,6 +752,18 @@ func (s *Server) stopLocked() (err error) { return nil } +// logCloserErr logs the error returned by c, if any. +func logCloserErr(c io.Closer, format string, args ...any) { + if c == nil { + return + } + + err := c.Close() + if err != nil { + log.Error(format, append(args, err)...) + } +} + // IsRunning returns true if the DNS server is running. func (s *Server) IsRunning() bool { s.serverLock.RLock() diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index f315a661..2b52aba0 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -656,17 +656,20 @@ func TestServerCustomClientUpstream(t *testing.T) { s := createTestServer(t, &filtering.Config{ BlockingMode: filtering.BlockingModeDefault, }, forwardConf, nil) - s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) { - ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { - return aghalg.Coalesce( - aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"), - new(dns.Msg).SetRcode(req, dns.RcodeNameError), - ), nil - }) - return &proxy.UpstreamConfig{ - Upstreams: []upstream.Upstream{ups}, - }, nil + ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { + return aghalg.Coalesce( + aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"), + new(dns.Msg).SetRcode(req, dns.RcodeNameError), + ), nil + }) + s.conf.ClientsContainer = &aghtest.ClientsContainer{ + OnUpstreamConfigByID: func( + _ string, + _ upstream.Resolver, + ) (conf *proxy.UpstreamConfig, err error) { + return &proxy.UpstreamConfig{Upstreams: []upstream.Upstream{ups}}, nil + }, } startDeferStop(t, s) diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 816e7aff..53874578 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/netip" "strings" @@ -197,13 +198,13 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) { // defaultLocalPTRUpstreams returns the list of default local PTR resolvers // filtered of AdGuard Home's own DNS server addresses. It may appear empty. func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) { - matcher, err := s.conf.ourAddrsMatcher() + matcher, err := s.conf.ourAddrsSet() if err != nil { // Don't wrap the error because it's informative enough as is. return nil, err } - sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher) + sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher.Has) ups = make([]string, 0, len(sysResolvers)) for _, r := range sysResolvers { ups = append(ups, r.String()) @@ -575,7 +576,7 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro conf, err = proxy.ParseUpstreamsConfig( upstreams, &upstream.Options{ - Bootstrap: []string{}, + Bootstrap: net.DefaultResolver, Timeout: DefaultTimeout, }, ) @@ -890,22 +891,11 @@ func (s *Server) checkUpstreamAddr( } }() - opts = &upstream.Options{ + u, err := upstream.AddressToUpstream(addr, &upstream.Options{ Bootstrap: opts.Bootstrap, Timeout: opts.Timeout, PreferIPv6: opts.PreferIPv6, - } - - // dnsFilter can be nil during application update. - if s.dnsFilter != nil { - recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr)) - for _, rec := range recs { - opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice()) - } - sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6) - } - - u, err := upstream.AddressToUpstream(addr, opts) + }) if err != nil { return fmt.Errorf("creating upstream for %q: %w", addr, err) } @@ -915,6 +905,13 @@ func (s *Server) checkUpstreamAddr( return check(u) } +// closeBoots closes all the provided bootstrap servers and logs errors if any. +func closeBoots(boots []*upstream.UpstreamResolver) { + for _, c := range boots { + logCloserErr(c, "dnsforward: closing bootstrap %s: %s", c.Address()) + } +} + // handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns // endpoint. func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { @@ -929,15 +926,21 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty) req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty) req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty) + req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty) opts := &upstream.Options{ - Bootstrap: req.BootstrapDNS, Timeout: s.conf.UpstreamTimeout, PreferIPv6: s.conf.BootstrapPreferIPv6, } - if len(opts.Bootstrap) == 0 { - opts.Bootstrap = defaultBootstrap + + var boots []*upstream.UpstreamResolver + opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts) + if err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err) + + return } + defer closeBoots(boots) wg := &sync.WaitGroup{} m := &sync.Map{} diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index 2beed7f6..99b03786 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -223,7 +223,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { `upstream servers: validating upstream "!!!": not an ip:port`, }, { name: "bootstraps_bad", - wantSet: `validating dns config: checking bootstrap a: invalid address: bootstrap a:53: ` + + wantSet: `validating dns config: checking bootstrap a: invalid address: not a bootstrap: ` + `ParseAddr("a"): unable to parse IP`, }, { name: "cache_bad_ttl", @@ -534,6 +534,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, }, }, nil) + srv.etcHosts = hc startDeferStop(t, srv) testCases := []struct { diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go index 0b572d8b..11e7459d 100644 --- a/internal/dnsforward/process.go +++ b/internal/dnsforward/process.go @@ -831,14 +831,13 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) { // setCustomUpstream sets custom upstream settings in pctx, if necessary. func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) { - customUpsByClient := s.conf.GetCustomUpstreamByClient - if pctx.Addr == nil || customUpsByClient == nil { + if pctx.Addr == nil || s.conf.ClientsContainer == nil { return } // Use the ClientID first, since it has a higher priority. id := stringutil.Coalesce(clientID, ipStringFromAddr(pctx.Addr)) - upsConf, err := customUpsByClient(id) + upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap) if err != nil { log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err) @@ -847,9 +846,9 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) { if upsConf != nil { log.Debug("dnsforward: using custom upstreams for client %s", id) - } - pctx.CustomUpstreamConfig = upsConf + pctx.CustomUpstreamConfig = upsConf + } } // Apply filtering logic after we have received response from upstream servers diff --git a/internal/dnsforward/upstreams.go b/internal/dnsforward/upstreams.go index f30dce69..e71a9672 100644 --- a/internal/dnsforward/upstreams.go +++ b/internal/dnsforward/upstreams.go @@ -1,21 +1,15 @@ package dnsforward import ( - "bytes" "fmt" - "net" - "net/url" "os" - "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" ) // loadUpstreams parses upstream DNS servers from the configured file or from @@ -39,7 +33,7 @@ func (s *Server) loadUpstreams() (upstreams []string, err error) { } // prepareUpstreamSettings sets upstream DNS server settings. -func (s *Server) prepareUpstreamSettings() (err error) { +func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) { // Load upstreams either from the file, or from the settings var upstreams []string upstreams, err = s.loadUpstreams() @@ -48,7 +42,7 @@ func (s *Server) prepareUpstreamSettings() (err error) { } s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{ - Bootstrap: s.conf.BootstrapDNS, + Bootstrap: boot, Timeout: s.conf.UpstreamTimeout, HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), PreferIPv6: s.conf.BootstrapPreferIPv6, @@ -92,178 +86,9 @@ func (s *Server) prepareUpstreamConfig( uc.Upstreams = defaultUpstreamConfig.Upstreams } - // dnsFilter can be nil during application update. - if s.dnsFilter != nil { - err = s.replaceUpstreamsWithHosts(uc, opts) - if err != nil { - return nil, fmt.Errorf("resolving upstreams with hosts: %w", err) - } - } - return uc, nil } -// replaceUpstreamsWithHosts replaces unique upstreams with their resolved -// versions based on the system hosts file. -// -// TODO(e.burkov): This should be performed inside dnsproxy, which should -// actually consider /etc/hosts. See TODO on [aghnet.HostsContainer]. -func (s *Server) replaceUpstreamsWithHosts( - upsConf *proxy.UpstreamConfig, - opts *upstream.Options, -) (err error) { - resolved := map[string]*upstream.Options{} - - err = s.resolveUpstreamsWithHosts(resolved, upsConf.Upstreams, opts) - if err != nil { - return fmt.Errorf("resolving upstreams: %w", err) - } - - hosts := maps.Keys(upsConf.DomainReservedUpstreams) - // TODO(e.burkov): Think of extracting sorted range into an util function. - slices.Sort(hosts) - for _, host := range hosts { - err = s.resolveUpstreamsWithHosts(resolved, upsConf.DomainReservedUpstreams[host], opts) - if err != nil { - return fmt.Errorf("resolving upstreams reserved for %s: %w", host, err) - } - } - - hosts = maps.Keys(upsConf.SpecifiedDomainUpstreams) - slices.Sort(hosts) - for _, host := range hosts { - err = s.resolveUpstreamsWithHosts(resolved, upsConf.SpecifiedDomainUpstreams[host], opts) - if err != nil { - return fmt.Errorf("resolving upstreams specific for %s: %w", host, err) - } - } - - return nil -} - -// resolveUpstreamsWithHosts resolves the IP addresses of each of the upstreams -// and replaces those both in upstreams and resolved. Upstreams that failed to -// resolve are placed to resolved as-is. This function only returns error of -// upstreams closing. -func (s *Server) resolveUpstreamsWithHosts( - resolved map[string]*upstream.Options, - upstreams []upstream.Upstream, - opts *upstream.Options, -) (err error) { - for i := range upstreams { - u := upstreams[i] - addr := u.Address() - host := extractUpstreamHost(addr) - - withIPs, ok := resolved[host] - if !ok { - recs := s.dnsFilter.EtcHostsRecords(host) - if len(recs) == 0 { - resolved[host] = nil - - return nil - } - - withIPs = opts.Clone() - withIPs.ServerIPAddrs = make([]net.IP, 0, len(recs)) - for _, rec := range recs { - withIPs.ServerIPAddrs = append(withIPs.ServerIPAddrs, rec.Addr.AsSlice()) - } - - sortNetIPAddrs(withIPs.ServerIPAddrs, opts.PreferIPv6) - resolved[host] = withIPs - } else if withIPs == nil { - continue - } - - if err = u.Close(); err != nil { - return fmt.Errorf("closing upstream %s: %w", addr, err) - } - - upstreams[i], err = upstream.AddressToUpstream(addr, withIPs) - if err != nil { - return fmt.Errorf("replacing upstream %s with resolved %s: %w", addr, host, err) - } - - log.Debug("dnsforward: using %s for %s", withIPs.ServerIPAddrs, upstreams[i].Address()) - } - - return nil -} - -// extractUpstreamHost returns the hostname of addr without port with an -// assumption that any address passed here has already been successfully parsed -// by [upstream.AddressToUpstream]. This function essentially mirrors the logic -// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts]. -func extractUpstreamHost(addr string) (host string) { - var err error - if strings.Contains(addr, "://") { - var u *url.URL - u, err = url.Parse(addr) - if err != nil { - log.Debug("dnsforward: parsing upstream %s: %s", addr, err) - - return addr - } - - return u.Hostname() - } - - // Probably, plain UDP upstream defined by address or address:port. - host, err = netutil.SplitHost(addr) - if err != nil { - return addr - } - - return host -} - -// sortNetIPAddrs sorts addrs in accordance with the protocol preferences. -// Invalid addresses are sorted near the end. -// -// TODO(e.burkov): This function taken from dnsproxy, which also already -// contains a few similar functions. Think of moving to golibs. -func sortNetIPAddrs(addrs []net.IP, preferIPv6 bool) { - l := len(addrs) - if l <= 1 { - return - } - - slices.SortStableFunc(addrs, func(addrA, addrB net.IP) (res int) { - switch len(addrA) { - case net.IPv4len, net.IPv6len: - switch len(addrB) { - case net.IPv4len, net.IPv6len: - // Go on. - default: - return -1 - } - default: - return 1 - } - - // Treat IPv6-mapped IPv4 addresses as IPv6 addresses. - aIs4, bIs4 := addrA.To4() != nil, addrB.To4() != nil - if aIs4 == bIs4 { - return bytes.Compare(addrA, addrB) - } - - if aIs4 { - if preferIPv6 { - return 1 - } - - return -1 - } - - if preferIPv6 { - return -1 - } - - return 1 - }) -} - // UpstreamHTTPVersions returns the HTTP versions for upstream configuration // depending on configuration. func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) { @@ -295,3 +120,41 @@ func setProxyUpstreamMode( conf.UpstreamMode = proxy.UModeLoadBalance } } + +// createBootstrap returns a bootstrap resolver based on the configuration of s. +// boots are the upstream resolvers that should be closed after use. r is the +// actual bootstrap resolver, which may include the system hosts. +// +// TODO(e.burkov): This function currently returns a resolver and a slice of +// the upstream resolvers, which are essentially the same. boots are returned +// for being able to close them afterwards, but it introduces an implicit +// contract that r could only be used before that. Anyway, this code should +// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver] +// and be used here. +func (s *Server) createBootstrap( + addrs []string, + opts *upstream.Options, +) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) { + if len(addrs) == 0 { + addrs = defaultBootstrap + } + + boots, err = aghnet.ParseBootstraps(addrs, opts) + if err != nil { + // Don't wrap the error, since it's informative enough as is. + return nil, nil, err + } + + var parallel upstream.ParallelResolver + for _, b := range boots { + parallel = append(parallel, b) + } + + if s.etcHosts != nil { + r = upstream.ConsequentResolver{s.etcHosts, parallel} + } else { + r = parallel + } + + return r, boots, nil +} diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 6fae9c60..703f6c71 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -98,6 +98,8 @@ type Config struct { // EtcHosts is a container of IP-hostname pairs taken from the operating // system configuration files (e.g. /etc/hosts). + // + // TODO(e.burkov): Move it to dnsforward entirely. EtcHosts *aghnet.HostsContainer `yaml:"-"` // Called when the configuration is changed by HTTP request diff --git a/internal/home/clients.go b/internal/home/clients.go index 68e0b7d1..d9f5dcd4 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -126,7 +126,13 @@ func (clients *clientsContainer) Init( return nil } - if clients.etcHosts != nil { + // The clients.etcHosts may be nil even if config.Clients.Sources.HostsFile + // is true, because of the deprecated option --no-etc-hosts. + // + // TODO(e.burkov): The option should probably be returned, since hosts file + // currently used not only for clients' information enrichment, but also in + // the filtering module and upstream addresses resolution. + if config.Clients.Sources.HostsFile && clients.etcHosts != nil { go clients.handleHostsUpdates() } @@ -419,11 +425,14 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) { return true } -// findUpstreams returns upstreams configured for the client, identified either -// by its IP address or its ClientID. upsConf is nil if the client isn't found -// or if the client has no custom upstreams. -func (clients *clientsContainer) findUpstreams( +// type check +var _ dnsforward.ClientsContainer = (*clientsContainer)(nil) + +// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface for +// *clientsContainer. +func (clients *clientsContainer) UpstreamConfigByID( id string, + bootstrap upstream.Resolver, ) (upsConf *proxy.UpstreamConfig, err error) { clients.lock.Lock() defer clients.lock.Unlock() @@ -431,6 +440,8 @@ func (clients *clientsContainer) findUpstreams( c, ok := clients.findLocked(id) if !ok { return nil, nil + } else if c.upstreamConfig != nil { + return c.upstreamConfig, nil } upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty) @@ -438,21 +449,18 @@ func (clients *clientsContainer) findUpstreams( return nil, nil } - if c.upstreamConfig != nil { - return c.upstreamConfig, nil - } - var conf *proxy.UpstreamConfig conf, err = proxy.ParseUpstreamsConfig( upstreams, &upstream.Options{ - Bootstrap: config.DNS.BootstrapDNS, + Bootstrap: bootstrap, Timeout: config.DNS.UpstreamTimeout.Duration, HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams), PreferIPv6: config.DNS.BootstrapPreferIPv6, }, ) if err != nil { + // Don't wrap the error since it's informative enough as is. return nil, err } @@ -672,10 +680,6 @@ func (clients *clientsContainer) Del(name string) (ok bool) { return false } - if err := c.closeUpstreams(); err != nil { - log.Error("client container: removing client %s: %s", name, err) - } - clients.del(c) return true @@ -683,10 +687,14 @@ func (clients *clientsContainer) Del(name string) (ok bool) { // del removes c from the indexes. clients.lock is expected to be locked. func (clients *clientsContainer) del(c *Client) { - // update Name index + if err := c.closeUpstreams(); err != nil { + log.Error("client container: removing client %s: %s", c.Name, err) + } + + // Update the name index. delete(clients.list, c.Name) - // update ID index + // Update the ID index. for _, id := range c.IDs { delete(clients.idIndex, id) } diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index b8ef598f..2efd2d4a 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -355,11 +355,11 @@ func TestClientsCustomUpstream(t *testing.T) { require.NoError(t, err) assert.True(t, ok) - config, err := clients.findUpstreams("1.2.3.4") + config, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver) assert.Nil(t, config) assert.NoError(t, err) - config, err = clients.findUpstreams("1.1.1.1") + config, err = clients.UpstreamConfigByID("1.1.1.1", net.DefaultResolver) require.NotNil(t, config) assert.NoError(t, err) assert.Len(t, config.Upstreams, 1) diff --git a/internal/home/dns.go b/internal/home/dns.go index d26dd0d8..fa822eeb 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -138,8 +138,9 @@ func initDNSServer( QueryLog: qlog, PrivateNets: privateNets, Anonymizer: anonymizer, - LocalDomain: config.DHCP.LocalDomainName, DHCPServer: dhcpSrv, + EtcHosts: Context.etcHosts, + LocalDomain: config.DHCP.LocalDomainName, }) if err != nil { closeDNSServer() @@ -288,7 +289,7 @@ func newServerConfig( newConf.TLSAllowUnencryptedDoH = tlsConf.AllowUnencryptedDoH newConf.FilterHandler = applyAdditionalFiltering - newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams + newConf.ClientsContainer = &Context.clients newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration diff --git a/internal/home/home.go b/internal/home/home.go index 7e57652e..ab2d83a2 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -6,7 +6,6 @@ import ( "crypto/x509" "fmt" "io/fs" - "net" "net/http" "net/netip" "net/url" @@ -160,7 +159,7 @@ func setupContext(opts options) (err error) { os.Exit(0) } - if !opts.noEtcHosts && config.Clients.Sources.HostsFile { + if !opts.noEtcHosts { err = setupHostsContainer() if err != nil { // Don't wrap the error, because it's informative enough as is. @@ -239,13 +238,13 @@ func setupHostsContainer() (err error) { ) if err != nil { closeErr := hostsWatcher.Close() - if errors.Is(err, aghnet.ErrNoHostsPaths) && closeErr == nil { + if errors.Is(err, aghnet.ErrNoHostsPaths) { log.Info("warning: initing hosts container: %s", err) - return nil + return closeErr } - return errors.WithDeferred(fmt.Errorf("initing hosts container: %w", err), closeErr) + return errors.Join(fmt.Errorf("initializing hosts container: %w", err), closeErr) } return nil @@ -294,19 +293,13 @@ func initContextClients() (err error) { arpDB = arpdb.New() } - err = Context.clients.Init( + return Context.clients.Init( config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpDB, config.Filtering, ) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } - - return nil } // setupBindOpts overrides bind host/port from the opts. @@ -376,11 +369,15 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) { upsOpts := &upstream.Options{ Timeout: dnsTimeout, - ServerIPAddrs: []net.IP{ - {94, 140, 14, 15}, - {94, 140, 15, 16}, - net.ParseIP("2a10:50c0::bad1:ff"), - net.ParseIP("2a10:50c0::bad2:ff"), + Bootstrap: upstream.StaticResolver{ + // 94.140.14.15. + netip.AddrFrom4([4]byte{94, 140, 14, 15}), + // 94.140.14.16. + netip.AddrFrom4([4]byte{94, 140, 14, 16}), + // 2a10:50c0::bad1:ff. + netip.AddrFrom16([16]byte{42, 16, 80, 192, 12: 186, 209, 0, 255}), + // 2a10:50c0::bad2:ff. + netip.AddrFrom16([16]byte{42, 16, 80, 192, 12: 186, 210, 0, 255}), }, } diff --git a/internal/next/dnssvc/dnssvc.go b/internal/next/dnssvc/dnssvc.go index 07ce4096..e0056dcb 100644 --- a/internal/next/dnssvc/dnssvc.go +++ b/internal/next/dnssvc/dnssvc.go @@ -12,11 +12,14 @@ import ( "sync/atomic" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/next/agh" + // TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes // and replacement of module dnsproxy. "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" ) // Service is the AdGuard Home DNS service. A nil *Service is a valid @@ -27,6 +30,7 @@ import ( type Service struct { proxy *proxy.Proxy bootstraps []string + bootstrapResolvers []*upstream.UpstreamResolver upstreams []string dns64Prefixes []netip.Prefix upsTimeout time.Duration @@ -52,7 +56,7 @@ func New(c *Config) (svc *Service, err error) { useDNS64: c.UseDNS64, } - upstreams, err := addressesToUpstreams( + upstreams, resolvers, err := addressesToUpstreams( c.UpstreamServers, c.BootstrapServers, c.UpstreamTimeout, @@ -62,6 +66,7 @@ func New(c *Config) (svc *Service, err error) { return nil, fmt.Errorf("converting upstreams: %w", err) } + svc.bootstrapResolvers = resolvers svc.proxy = &proxy.Proxy{ Config: proxy.Config{ UDPListenAddr: udpAddrs(c.Addresses), @@ -90,20 +95,37 @@ func addressesToUpstreams( bootstraps []string, timeout time.Duration, preferIPv6 bool, -) (upstreams []upstream.Upstream, err error) { +) (upstreams []upstream.Upstream, boots []*upstream.UpstreamResolver, err error) { + opts := &upstream.Options{ + Timeout: timeout, + PreferIPv6: preferIPv6, + } + + boots, err = aghnet.ParseBootstraps(bootstraps, opts) + if err != nil { + // Don't wrap the error, since it's informative enough as is. + return nil, nil, err + } + + // TODO(e.burkov): Add system hosts resolver here. + var bootstrap upstream.ParallelResolver + for _, r := range boots { + bootstrap = append(bootstrap, r) + } + upstreams = make([]upstream.Upstream, len(upsStrs)) for i, upsStr := range upsStrs { upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{ - Bootstrap: bootstraps, + Bootstrap: bootstrap, Timeout: timeout, PreferIPv6: preferIPv6, }) if err != nil { - return nil, fmt.Errorf("upstream at index %d: %w", i, err) + return nil, boots, fmt.Errorf("upstream at index %d: %w", i, err) } } - return upstreams, nil + return upstreams, boots, nil } // tcpAddrs converts []netip.AddrPort into []*net.TCPAddr. @@ -162,7 +184,15 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) { return nil } - return svc.proxy.Stop() + errs := []error{ + svc.proxy.Stop(), + } + + for _, b := range svc.bootstrapResolvers { + errs = append(errs, errors.Annotate(b.Close(), "closing bootstrap %s: %w", b.Address())) + } + + return errors.Join(errs...) } // Config returns the current configuration of the web service. Config must not