mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-12-14 18:51:34 +03:00
Pull request 2193: AGDNS-1982 Upd proxy
Closes #6854.Updates #6875. Squashed commit of the following: commit b98adbc0cc6eeaffb262d57775c487e03b1d5ba5 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Wed Apr 10 19:21:44 2024 +0300 dnsforward: upd proxy, imp code, docs commit 4de1eb2bca1047426e02ba680c212f46782e5616 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Wed Apr 10 16:09:58 2024 +0300 WIP commit afa9d61e8dc129f907dc681cd2f831cb5c3b054a Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Tue Apr 9 19:24:09 2024 +0300 all: log changes commit c8340676a448687a39acd26bc8ce5f94473e441f Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Tue Apr 9 19:06:10 2024 +0300 dnsforward: move code commit 08bb7d43d2a3f689ef2ef2409935dc3946752e94 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Tue Apr 9 18:09:46 2024 +0300 dnsforward: imp code commit b27547ec806dd9bce502d3c6a7c28f33693ed575 Merge: b7efca7886f36ebc06
Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Tue Apr 9 17:33:19 2024 +0300 Merge branch 'master' into AGDNS-1982-upd-proxy commit b7efca788b66aa672598b088040d4534ce2e55b0 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Tue Apr 9 17:27:14 2024 +0300 all: upd proxy finally commit 3e16fa87befe4c0ef3a3e7a638d7add28627f9b6 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Fri Apr 5 18:20:13 2024 +0300 dnsforward: upd proxy commit f3cdfc86334a182effcd0de22fac5e678fa53ea7 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Thu Apr 4 20:37:32 2024 +0300 all: upd proxy, golibs commit a79298d6d0504521893ee11fdc3a23c098aea911 Merge: 9feeba5c7fd25dcacb
Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Thu Apr 4 20:34:01 2024 +0300 Merge branch 'master' into AGDNS-1982-upd-proxy commit 9feeba5c7f24ff1d308a216608d985cb2a7b7588 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Thu Apr 4 20:25:57 2024 +0300 all: imp code, docs commit 6c68d463db64293eb9c5e29ff91879fd68920a77 Merge: d8108e651ee619b2db
Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Thu Apr 4 18:46:11 2024 +0300 Merge branch 'master' into AGDNS-1982-upd-proxy commit d8108e65164df8d67aa4e95154a8768a06255b78 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Wed Apr 3 19:25:27 2024 +0300 all: imp code commit 20461565801c9fcd06a652c6066b524b06c80433 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Wed Apr 3 17:10:33 2024 +0300 all: remove private rdns logic
This commit is contained in:
parent
6f36ebc06c
commit
ff7c715c5f
10
CHANGELOG.md
10
CHANGELOG.md
@ -27,7 +27,17 @@ NOTE: Add new changes BELOW THIS COMMENT.
|
||||
|
||||
- Support for comments in the ipset file ([#5345]).
|
||||
|
||||
### Fixed
|
||||
|
||||
- Subdomains of `in-addr.arpa` and `ip6.arpa` containing zero-length prefix
|
||||
incorrectly considered invalid when specified for private RDNS upstream
|
||||
servers ([#6854]).
|
||||
- Unspecified IP addresses aren't checked when using "Fastest IP address" mode
|
||||
([#6875]).
|
||||
|
||||
[#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345
|
||||
[#6854]: https://github.com/AdguardTeam/AdGuardHome/issues/6854
|
||||
[#6875]: https://github.com/AdguardTeam/AdGuardHome/issues/6875
|
||||
|
||||
<!--
|
||||
NOTE: Add new changes ABOVE THIS COMMENT.
|
||||
|
10
go.mod
10
go.mod
@ -3,8 +3,8 @@ module github.com/AdguardTeam/AdGuardHome
|
||||
go 1.22.2
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/dnsproxy v0.67.0
|
||||
github.com/AdguardTeam/golibs v0.21.0
|
||||
github.com/AdguardTeam/dnsproxy v0.69.1
|
||||
github.com/AdguardTeam/golibs v0.23.0
|
||||
github.com/AdguardTeam/urlfilter v0.18.0
|
||||
github.com/NYTimes/gziphandler v1.1.1
|
||||
github.com/ameshkov/dnscrypt/v2 v2.2.7
|
||||
@ -28,12 +28,12 @@ require (
|
||||
// own code for that. Perhaps, use gopacket.
|
||||
github.com/mdlayher/raw v0.1.0
|
||||
github.com/miekg/dns v1.1.58
|
||||
github.com/quic-go/quic-go v0.41.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/quic-go/quic-go v0.42.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/ti-mo/netfilter v0.5.1
|
||||
go.etcd.io/bbolt v1.3.9
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8
|
||||
golang.org/x/net v0.23.0
|
||||
golang.org/x/sys v0.18.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
|
26
go.sum
26
go.sum
@ -1,7 +1,7 @@
|
||||
github.com/AdguardTeam/dnsproxy v0.67.0 h1:7oKfcA8sm9d1N4qvhsNmQWBX4+fs3sX4cAnERmBXEbw=
|
||||
github.com/AdguardTeam/dnsproxy v0.67.0/go.mod h1:XLfD6IpSplUZZ+f5vhWSJW1mp4wm+KkHWiMo9w7U1Ls=
|
||||
github.com/AdguardTeam/golibs v0.21.0 h1:0swWyNaHTmT7aMwffKd9d54g4wBd8Oaj0fl+5l/PRdE=
|
||||
github.com/AdguardTeam/golibs v0.21.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI=
|
||||
github.com/AdguardTeam/dnsproxy v0.69.1 h1:KiLkKUSrvHeUO/YEf4Bbo/5zyFRIvQstjL7W9G/24pk=
|
||||
github.com/AdguardTeam/dnsproxy v0.69.1/go.mod h1:atO3WeeuyepyhjSt6hC+MF7/IN7TZHfG3/ZwhImHzYs=
|
||||
github.com/AdguardTeam/golibs v0.23.0 h1:PHz/QhJhLmoaOokkqrPFUgu9Hw4iVAqLtBP0O3g1D3Q=
|
||||
github.com/AdguardTeam/golibs v0.23.0/go.mod h1:/xZCf6gZZzz7k1qaoJmI+hhxN98kHFr7LJ22j1nLH0c=
|
||||
github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ=
|
||||
github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s=
|
||||
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
|
||||
@ -101,19 +101,19 @@ github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k=
|
||||
github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA=
|
||||
github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
|
||||
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
|
||||
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
|
||||
github.com/shirou/gopsutil/v3 v3.23.7/go.mod h1:c4gnmoRC0hQuaLqvxnx1//VXQ0Ms/X9UnJF8pddY5z4=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU=
|
||||
github.com/ti-mo/netfilter v0.5.1 h1:cqamEd1c1zmpfpqvInLOro0Znq/RAfw2QL5wL2rAR/8=
|
||||
github.com/ti-mo/netfilter v0.5.1/go.mod h1:h9UPQ3ZrTZGBitay+LETMxZvNgWGK/efTUcqES2YiLw=
|
||||
@ -133,8 +133,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ=
|
||||
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
|
||||
@ -168,6 +168,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
|
||||
|
113
internal/dnsforward/beforerequest.go
Normal file
113
internal/dnsforward/beforerequest.go
Normal file
@ -0,0 +1,113 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// type check
|
||||
var _ proxy.BeforeRequestHandler = (*Server)(nil)
|
||||
|
||||
// HandleBefore is the handler that is called before any other processing,
|
||||
// including logs. It performs access checks and puts the client ID, if there
|
||||
// is one, into the server's cache.
|
||||
//
|
||||
// TODO(e.burkov): Write tests.
|
||||
func (s *Server) HandleBefore(
|
||||
_ *proxy.Proxy,
|
||||
pctx *proxy.DNSContext,
|
||||
) (err error) {
|
||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting clientid: %w", err)
|
||||
}
|
||||
|
||||
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
|
||||
if blocked {
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
|
||||
if len(pctx.Req.Question) == 1 {
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
host := aghnet.NormalizeDomain(q.Name)
|
||||
if s.access.isBlockedHost(host, qt) {
|
||||
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
|
||||
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
}
|
||||
|
||||
if clientID != "" {
|
||||
key := [8]byte{}
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
s.clientIDCache.Set(key[:], []byte(clientID))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientIDFromDNSContext extracts the client's ID from the server name of the
|
||||
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
|
||||
// is not one of these, clientID is an empty string and err is nil.
|
||||
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
proto := pctx.Proto
|
||||
if proto == proxy.ProtoHTTPS {
|
||||
clientID, err = clientIDFromDNSContextHTTPS(pctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("checking url: %w", err)
|
||||
} else if clientID != "" {
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// Go on and check the domain name as well.
|
||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
hostSrvName := s.conf.ServerName
|
||||
if hostSrvName == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cliSrvName, err := clientServerName(pctx, proto)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
hostSrvName,
|
||||
cliSrvName,
|
||||
s.conf.StrictSNICheck,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("clientid check: %w", err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// errAccessBlocked is a sentinel error returned when a request is blocked by
|
||||
// access settings.
|
||||
var errAccessBlocked errors.Error = "blocked by access settings"
|
||||
|
||||
// preBlockedResponse returns a protocol-appropriate response for a request that
|
||||
// was blocked by access settings.
|
||||
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (err error) {
|
||||
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
|
||||
// Return nil so that dnsproxy drops the connection and thus
|
||||
// prevent DNS amplification attacks.
|
||||
return errAccessBlocked
|
||||
}
|
||||
|
||||
return &proxy.BeforeRequestError{
|
||||
Err: errAccessBlocked,
|
||||
Response: s.makeResponseREFUSED(pctx.Req),
|
||||
}
|
||||
}
|
@ -110,46 +110,6 @@ type quicConnection interface {
|
||||
ConnectionState() (cs quic.ConnectionState)
|
||||
}
|
||||
|
||||
// clientIDFromDNSContext extracts the client's ID from the server name of the
|
||||
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
|
||||
// is not one of these, clientID is an empty string and err is nil.
|
||||
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
proto := pctx.Proto
|
||||
if proto == proxy.ProtoHTTPS {
|
||||
clientID, err = clientIDFromDNSContextHTTPS(pctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("checking url: %w", err)
|
||||
} else if clientID != "" {
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// Go on and check the domain name as well.
|
||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
hostSrvName := s.conf.ServerName
|
||||
if hostSrvName == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cliSrvName, err := clientServerName(pctx, proto)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
hostSrvName,
|
||||
cliSrvName,
|
||||
s.conf.StrictSNICheck,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("clientid check: %w", err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// clientServerName returns the TLS server name based on the protocol. For
|
||||
// DNS-over-HTTPS requests, it will return the hostname part of the Host header
|
||||
// if there is one.
|
||||
|
@ -235,9 +235,18 @@ type DNSCryptConfig struct {
|
||||
// ServerConfig represents server configuration.
|
||||
// The zero ServerConfig is empty and ready for use.
|
||||
type ServerConfig struct {
|
||||
UDPListenAddrs []*net.UDPAddr // UDP listen address
|
||||
TCPListenAddrs []*net.TCPAddr // TCP listen address
|
||||
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
|
||||
// UDPListenAddrs is the list of addresses to listen for DNS-over-UDP.
|
||||
UDPListenAddrs []*net.UDPAddr
|
||||
|
||||
// TCPListenAddrs is the list of addresses to listen for DNS-over-TCP.
|
||||
TCPListenAddrs []*net.TCPAddr
|
||||
|
||||
// UpstreamConfig is the general configuration of upstream DNS servers.
|
||||
UpstreamConfig *proxy.UpstreamConfig
|
||||
|
||||
// PrivateRDNSUpstreamConfig is the configuration of upstream DNS servers
|
||||
// for private reverse DNS.
|
||||
PrivateRDNSUpstreamConfig *proxy.UpstreamConfig
|
||||
|
||||
// AddrProcConf defines the configuration for the client IP processor.
|
||||
// If nil, [client.EmptyAddrProc] is used.
|
||||
@ -306,24 +315,28 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
|
||||
|
||||
conf = &proxy.Config{
|
||||
HTTP3: srvConf.ServeHTTP3,
|
||||
Ratelimit: int(srvConf.Ratelimit),
|
||||
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
|
||||
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
|
||||
RatelimitWhitelist: srvConf.RatelimitWhitelist,
|
||||
RefuseAny: srvConf.RefuseAny,
|
||||
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
|
||||
CacheMinTTL: srvConf.CacheMinTTL,
|
||||
CacheMaxTTL: srvConf.CacheMaxTTL,
|
||||
CacheOptimistic: srvConf.CacheOptimistic,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
BeforeRequestHandler: s.beforeRequestHandler,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
HTTPSServerName: aghhttp.UserAgent(),
|
||||
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
|
||||
MaxGoroutines: srvConf.MaxGoroutines,
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
DNS64Prefs: srvConf.DNS64Prefixes,
|
||||
HTTP3: srvConf.ServeHTTP3,
|
||||
Ratelimit: int(srvConf.Ratelimit),
|
||||
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
|
||||
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
|
||||
RatelimitWhitelist: srvConf.RatelimitWhitelist,
|
||||
RefuseAny: srvConf.RefuseAny,
|
||||
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
|
||||
CacheMinTTL: srvConf.CacheMinTTL,
|
||||
CacheMaxTTL: srvConf.CacheMaxTTL,
|
||||
CacheOptimistic: srvConf.CacheOptimistic,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
|
||||
BeforeRequestHandler: s,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
HTTPSServerName: aghhttp.UserAgent(),
|
||||
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
|
||||
MaxGoroutines: srvConf.MaxGoroutines,
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
DNS64Prefs: srvConf.DNS64Prefixes,
|
||||
UsePrivateRDNS: srvConf.UsePrivateRDNS,
|
||||
PrivateSubnets: s.privateNets,
|
||||
MessageConstructor: s,
|
||||
}
|
||||
|
||||
if srvConf.EDNSClientSubnet.UseCustom {
|
||||
@ -459,6 +472,26 @@ func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
return s.ipset.init(ipsets)
|
||||
}
|
||||
|
||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||
// the configuration itself.
|
||||
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
|
||||
if conf.UpstreamDNSFileName == "" {
|
||||
return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
data, err = os.ReadFile(conf.UpstreamDNSFileName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading upstream from file: %w", err)
|
||||
}
|
||||
|
||||
upstreams = stringutil.SplitTrimmed(string(data), "\n")
|
||||
|
||||
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
|
||||
|
||||
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
|
||||
}
|
||||
|
||||
// collectListenAddr adds addrPort to addrs. It also adds its port to
|
||||
// unspecPorts if its address is unspecified.
|
||||
func collectListenAddr(
|
||||
@ -530,8 +563,8 @@ func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
|
||||
return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr())
|
||||
}
|
||||
|
||||
// filterOut filters out all the upstreams that match um. It returns all the
|
||||
// closing errors joined.
|
||||
// filterOutAddrs filters out all the upstreams that match um. It returns all
|
||||
// the closing errors joined.
|
||||
func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) {
|
||||
var errs []error
|
||||
delFunc := func(u upstream.Upstream) (ok bool) {
|
||||
|
@ -3,7 +3,6 @@ package dnsforward
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@ -11,6 +10,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -64,6 +64,8 @@ func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns
|
||||
}
|
||||
|
||||
func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
ipv4Domain = "ipv4.only."
|
||||
ipv6Domain = "ipv6.only."
|
||||
@ -252,32 +254,33 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
require.Len(pt, m.Question, 1)
|
||||
require.Equal(pt, m.Question[0].Name, ptr64Domain)
|
||||
resp := (&dns.Msg{
|
||||
Answer: []dns.RR{localRR},
|
||||
}).SetReply(m)
|
||||
|
||||
resp := (&dns.Msg{}).SetReply(m)
|
||||
resp.Answer = []dns.RR{localRR}
|
||||
|
||||
require.NoError(t, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
client := &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: 1 * time.Second,
|
||||
Net: string(proxy.ProtoTCP),
|
||||
Timeout: testTimeout,
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
q := req.Question[0]
|
||||
require.Contains(pt, tc.upsAns, q.Qtype)
|
||||
|
||||
require.Contains(pt, tc.upsAns, q.Qtype)
|
||||
answer := tc.upsAns[q.Qtype]
|
||||
|
||||
resp := (&dns.Msg{
|
||||
Answer: answer[sectionAnswer],
|
||||
Ns: answer[sectionAuthority],
|
||||
Extra: answer[sectionAdditional],
|
||||
}).SetReply(req)
|
||||
resp := (&dns.Msg{}).SetReply(req)
|
||||
resp.Answer = answer[sectionAnswer]
|
||||
resp.Ns = answer[sectionAuthority]
|
||||
resp.Extra = answer[sectionAdditional]
|
||||
|
||||
require.NoError(pt, w.WriteMsg(resp))
|
||||
})
|
||||
@ -307,10 +310,54 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
|
||||
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
|
||||
|
||||
resp, _, excErr := client.Exchange(req, s.dnsProxy.Addr(proxy.ProtoTCP).String())
|
||||
resp, _, excErr := client.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
|
||||
require.NoError(t, excErr)
|
||||
|
||||
require.Equal(t, tc.wantAns, resp.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_dns64WithDisabledRDNS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Shouldn't go to upstream at all.
|
||||
panicHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
panic("not implemented")
|
||||
})
|
||||
upsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UseDNS64: true,
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
},
|
||||
UsePrivateRDNS: false,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
startDeferStop(t, s)
|
||||
|
||||
mappedIPv6 := net.ParseIP("64:ff9b::102:304")
|
||||
arpa, err := netutil.IPToReversedAddr(mappedIPv6)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := (&dns.Msg{}).SetQuestion(dns.Fqdn(arpa), dns.TypePTR)
|
||||
|
||||
cli := &dns.Client{
|
||||
Net: string(proxy.ProtoTCP),
|
||||
Timeout: testTimeout,
|
||||
}
|
||||
|
||||
resp, _, err := cli.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeNameError, resp.Rcode)
|
||||
}
|
||||
|
@ -135,12 +135,6 @@ type Server struct {
|
||||
// WHOIS, etc.
|
||||
addrProc client.AddressProcessor
|
||||
|
||||
// localResolvers is a DNS proxy instance used to resolve PTR records for
|
||||
// addresses considered private as per the [privateNets].
|
||||
//
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
localResolvers *proxy.Proxy
|
||||
|
||||
// sysResolvers used to fetch system resolvers to use by default for private
|
||||
// PTR resolving.
|
||||
sysResolvers SystemResolvers
|
||||
@ -158,12 +152,6 @@ type Server struct {
|
||||
// [upstream.Resolver] interface.
|
||||
bootResolvers []*upstream.UpstreamResolver
|
||||
|
||||
// recDetector is a cache for recursive requests. It is used to detect and
|
||||
// prevent recursive requests only for private upstreams.
|
||||
//
|
||||
// See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135.
|
||||
recDetector *recursionDetector
|
||||
|
||||
// dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major
|
||||
// part of DNS64 happens inside the [proxy] package, but there still are
|
||||
// some places where response mapping is needed (e.g. DHCP).
|
||||
@ -212,14 +200,6 @@ type DNSCreateParams struct {
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
const (
|
||||
// recursionTTL is the time recursive request is cached for.
|
||||
recursionTTL = 1 * time.Second
|
||||
// cachedRecurrentReqNum is the maximum number of cached recurrent
|
||||
// requests.
|
||||
cachedRecurrentReqNum = 1000
|
||||
)
|
||||
|
||||
// NewServer creates a new instance of the dnsforward.Server
|
||||
// Note: this function must be called only once
|
||||
//
|
||||
@ -256,7 +236,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
// TODO(e.burkov): Use some case-insensitive string comparison.
|
||||
localDomainSuffix: strings.ToLower(localDomainSuffix),
|
||||
etcHosts: etcHosts,
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: defaultClientIDCacheCount,
|
||||
@ -366,6 +345,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
// TODO(e.burkov): Migrate to [netip.Addr] already.
|
||||
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("reversing ip: %w", err)
|
||||
@ -386,25 +366,23 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
|
||||
}
|
||||
|
||||
dctx := &proxy.DNSContext{
|
||||
Proto: "udp",
|
||||
Req: req,
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: req,
|
||||
IsPrivateClient: true,
|
||||
}
|
||||
|
||||
var resolver *proxy.Proxy
|
||||
var errMsg string
|
||||
if s.privateNets.Contains(ip) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
return "", 0, nil
|
||||
}
|
||||
|
||||
resolver = s.localResolvers
|
||||
errMsg = "resolving a private address: %w"
|
||||
s.recDetector.add(*req)
|
||||
dctx.RequestedPrivateRDNS = netip.PrefixFrom(ip, ip.BitLen())
|
||||
} else {
|
||||
resolver = s.internalProxy
|
||||
errMsg = "resolving an address: %w"
|
||||
}
|
||||
if err = resolver.Resolve(dctx); err != nil {
|
||||
if err = s.internalProxy.Resolve(dctx); err != nil {
|
||||
return "", 0, fmt.Errorf(errMsg, err)
|
||||
}
|
||||
|
||||
@ -473,103 +451,6 @@ func (s *Server) startLocked() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// prepareLocalResolvers initializes the local upstreams configuration using
|
||||
// boot as bootstrap. It assumes that s.serverLock is locked or s not running.
|
||||
func (s *Server) prepareLocalResolvers(
|
||||
boot upstream.Resolver,
|
||||
) (uc *proxy.UpstreamConfig, err error) {
|
||||
set, err := s.conf.ourAddrsSet()
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resolvers := s.conf.LocalPTRResolvers
|
||||
confNeedsFiltering := len(resolvers) > 0
|
||||
if confNeedsFiltering {
|
||||
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
|
||||
} else {
|
||||
sysResolvers := slices.DeleteFunc(slices.Clone(s.sysResolvers.Addrs()), set.Has)
|
||||
resolvers = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
resolvers = append(resolvers, r.String())
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
|
||||
|
||||
uc, err = s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
|
||||
Bootstrap: boot,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing private upstreams: %w", err)
|
||||
}
|
||||
|
||||
if confNeedsFiltering {
|
||||
err = filterOutAddrs(uc, set)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("filtering private upstreams: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// LocalResolversError is an error type for errors during local resolvers setup.
|
||||
// This is only needed to distinguish these errors from errors returned by
|
||||
// creating the proxy.
|
||||
type LocalResolversError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ error = (*LocalResolversError)(nil)
|
||||
|
||||
// Error implements the error interface for *LocalResolversError.
|
||||
func (err *LocalResolversError) Error() (s string) {
|
||||
return fmt.Sprintf("creating local resolvers: %s", err.Err)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ errors.Wrapper = (*LocalResolversError)(nil)
|
||||
|
||||
// Unwrap implements the [errors.Wrapper] interface for *LocalResolversError.
|
||||
func (err *LocalResolversError) Unwrap() error {
|
||||
return err.Err
|
||||
}
|
||||
|
||||
// setupLocalResolvers initializes and sets the resolvers for local addresses.
|
||||
// It assumes s.serverLock is locked or s not running. It returns the upstream
|
||||
// configuration used for private PTR resolving, or nil if it's disabled. Note,
|
||||
// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
uc, err = s.prepareLocalResolvers(boot)
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
localResolvers, err := proxy.New(&proxy.Config{
|
||||
UpstreamConfig: uc,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &LocalResolversError{Err: err}
|
||||
}
|
||||
|
||||
s.localResolvers = localResolvers
|
||||
|
||||
// TODO(e.burkov): Should we also consider the DNS64 usage?
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// Prepare initializes parameters of s using data from conf. conf must not be
|
||||
// nil.
|
||||
func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
@ -586,7 +467,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
|
||||
s.initDefaultSettings()
|
||||
|
||||
boot, err := s.prepareInternalDNS()
|
||||
err = s.prepareInternalDNS()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
@ -608,12 +489,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
return fmt.Errorf("preparing access: %w", err)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up resolvers: %w", err)
|
||||
}
|
||||
|
||||
proxyConfig.Fallbacks, err = s.setupFallbackDNS()
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up fallback dns servers: %w", err)
|
||||
@ -626,8 +501,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
|
||||
s.dnsProxy = dnsProxy
|
||||
|
||||
s.recDetector.clear()
|
||||
|
||||
s.setupAddrProc()
|
||||
|
||||
s.registerHandlers()
|
||||
@ -635,36 +508,127 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareInternalDNS initializes the internal state of s before initializing
|
||||
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
|
||||
// Server not running.
|
||||
func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
|
||||
err = s.prepareIpsetListSettings()
|
||||
// prepareUpstreamSettings sets upstream DNS server settings.
|
||||
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
|
||||
// Load upstreams either from the file, or from the settings
|
||||
var upstreams []string
|
||||
upstreams, err = s.conf.loadUpstreams()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing ipset settings: %w", err)
|
||||
return fmt.Errorf("loading upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{
|
||||
Timeout: DefaultTimeout,
|
||||
uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
||||
Bootstrap: boot,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
// Use a customized set of RootCAs, because Go's default mechanism of
|
||||
// loading TLS roots does not always work properly on some routers so we're
|
||||
// loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
//
|
||||
// TODO(a.garipov): Investigate if that's true.
|
||||
RootCAs: s.conf.TLSv12Roots,
|
||||
CipherSuites: s.conf.TLSCiphers,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing upstream config: %w", err)
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig = uc
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrivateRDNSError is returned when the private rDNS upstreams are
|
||||
// invalid but enabled.
|
||||
//
|
||||
// TODO(e.burkov): Consider allowing to use incomplete private rDNS upstreams
|
||||
// configuration in proxy when the private rDNS function is enabled. In theory,
|
||||
// proxy supports the case when no upstreams provided to resolve the private
|
||||
// request, since it already supports this for DNS64-prefixed PTR requests.
|
||||
type PrivateRDNSError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// Error implements the [errors.Error] interface.
|
||||
func (e *PrivateRDNSError) Error() (s string) {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *PrivateRDNSError) Unwrap() (err error) {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// prepareLocalResolvers initializes the private RDNS upstream configuration
|
||||
// according to the server's settings. It assumes s.serverLock is locked or the
|
||||
// Server not running.
|
||||
func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var ownAddrs addrPortSet
|
||||
ownAddrs, err = s.conf.ourAddrsSet()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := &upstream.Options{
|
||||
Bootstrap: s.bootstrap,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
}
|
||||
|
||||
addrs := s.conf.LocalPTRResolvers
|
||||
uc, err = newPrivateConfig(addrs, ownAddrs, s.sysResolvers, s.privateNets, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing resolvers: %w", err)
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// prepareInternalDNS initializes the internal state of s before initializing
|
||||
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
|
||||
// Server not running.
|
||||
func (s *Server) prepareInternalDNS() (err error) {
|
||||
err = s.prepareIpsetListSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing ipset settings: %w", err)
|
||||
}
|
||||
|
||||
bootOpts := &upstream.Options{
|
||||
Timeout: DefaultTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
}
|
||||
|
||||
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.prepareUpstreamSettings(s.bootstrap)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return s.bootstrap, err
|
||||
return err
|
||||
}
|
||||
|
||||
s.conf.PrivateRDNSUpstreamConfig, err = s.prepareLocalResolvers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.prepareInternalProxy()
|
||||
if err != nil {
|
||||
return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err)
|
||||
return fmt.Errorf("preparing internal proxy: %w", err)
|
||||
}
|
||||
|
||||
return s.bootstrap, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupFallbackDNS initializes the fallback DNS servers.
|
||||
@ -743,10 +707,16 @@ func validateBlockingMode(
|
||||
func (s *Server) prepareInternalProxy() (err error) {
|
||||
srvConf := s.conf
|
||||
conf := &proxy.Config{
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
MaxGoroutines: s.conf.MaxGoroutines,
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
MaxGoroutines: srvConf.MaxGoroutines,
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
DNS64Prefs: srvConf.DNS64Prefixes,
|
||||
UsePrivateRDNS: srvConf.UsePrivateRDNS,
|
||||
PrivateSubnets: s.privateNets,
|
||||
MessageConstructor: s,
|
||||
}
|
||||
|
||||
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
|
||||
@ -782,11 +752,6 @@ func (s *Server) stopLocked() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
|
||||
if s.localResolvers != nil {
|
||||
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
|
||||
}
|
||||
|
||||
for _, b := range s.bootResolvers {
|
||||
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
|
||||
}
|
||||
|
@ -143,7 +143,7 @@ func (s *Server) filterDNSRewrite(
|
||||
res *filtering.Result,
|
||||
pctx *proxy.DNSContext,
|
||||
) (err error) {
|
||||
resp := s.makeResponse(req)
|
||||
resp := s.replyCompressed(req)
|
||||
dnsrr := res.DNSRewriteResult
|
||||
if dnsrr == nil {
|
||||
return errors.Error("no dns rewrite rule content")
|
||||
|
@ -1,57 +1,17 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// beforeRequestHandler is the handler that is called before any other
|
||||
// processing, including logs. It performs access checks and puts the client
|
||||
// ID, if there is one, into the server's cache.
|
||||
func (s *Server) beforeRequestHandler(
|
||||
_ *proxy.Proxy,
|
||||
pctx *proxy.DNSContext,
|
||||
) (reply bool, err error) {
|
||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("getting clientid: %w", err)
|
||||
}
|
||||
|
||||
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
|
||||
if blocked {
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
|
||||
if len(pctx.Req.Question) == 1 {
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
host := aghnet.NormalizeDomain(q.Name)
|
||||
if s.access.isBlockedHost(host, qt) {
|
||||
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
|
||||
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
}
|
||||
|
||||
if clientID != "" {
|
||||
key := [8]byte{}
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
s.clientIDCache.Set(key[:], []byte(clientID))
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// clientRequestFilteringSettings looks up client filtering settings using the
|
||||
// client's IP address and ID, if any, from dctx.
|
||||
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
|
||||
|
@ -261,55 +261,17 @@ func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// checkBootstrap returns an error if any bootstrap address is invalid.
|
||||
func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
if req.Bootstraps == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var b string
|
||||
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
|
||||
|
||||
for _, b = range *req.Bootstraps {
|
||||
if b == "" {
|
||||
return errors.Error("empty")
|
||||
}
|
||||
|
||||
var resolver *upstream.UpstreamResolver
|
||||
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
if err = resolver.Close(); err != nil {
|
||||
return fmt.Errorf("closing %s: %w", b, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkFallbacks returns an error if any fallback address is invalid.
|
||||
func (req *jsonDNSConfig) checkFallbacks() (err error) {
|
||||
if req.Fallbacks == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback servers: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
//
|
||||
// TODO(s.chzhen): Parse, don't validate.
|
||||
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
func (req *jsonDNSConfig) validate(
|
||||
ownAddrs addrPortSet,
|
||||
sysResolvers SystemResolvers,
|
||||
privateNets netutil.SubnetSet,
|
||||
) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating dns config: %w") }()
|
||||
|
||||
err = req.validateUpstreamDNSServers(privateNets)
|
||||
err = req.validateUpstreamDNSServers(ownAddrs, sysResolvers, privateNets)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
@ -342,17 +304,54 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkBootstrap returns an error if any bootstrap address is invalid.
|
||||
func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
if req.Bootstraps == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var b string
|
||||
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
|
||||
|
||||
for _, b = range *req.Bootstraps {
|
||||
if b == "" {
|
||||
return errors.Error("empty")
|
||||
}
|
||||
|
||||
var resolver *upstream.UpstreamResolver
|
||||
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
if err = resolver.Close(); err != nil {
|
||||
return fmt.Errorf("closing %s: %w", b, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUpstreamDNSServers returns an error if any field of req is invalid.
|
||||
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
|
||||
func (req *jsonDNSConfig) validateUpstreamDNSServers(
|
||||
ownAddrs addrPortSet,
|
||||
sysResolvers SystemResolvers,
|
||||
privateNets netutil.SubnetSet,
|
||||
) (err error) {
|
||||
var uc *proxy.UpstreamConfig
|
||||
opts := &upstream.Options{}
|
||||
|
||||
if req.Upstreams != nil {
|
||||
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{})
|
||||
uc, err = proxy.ParseUpstreamsConfig(*req.Upstreams, opts)
|
||||
err = errors.WithDeferred(err, uc.Close())
|
||||
if err != nil {
|
||||
return fmt.Errorf("upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if req.LocalPTRUpstreams != nil {
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
|
||||
if addrs := req.LocalPTRUpstreams; addrs != nil {
|
||||
uc, err = newPrivateConfig(*addrs, ownAddrs, sysResolvers, privateNets, opts)
|
||||
err = errors.WithDeferred(err, uc.Close())
|
||||
if err != nil {
|
||||
return fmt.Errorf("private upstream servers: %w", err)
|
||||
}
|
||||
@ -364,10 +363,12 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkFallbacks()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
if req.Fallbacks != nil {
|
||||
uc, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, opts)
|
||||
err = errors.WithDeferred(err, uc.Close())
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -436,7 +437,16 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = req.validate(s.privateNets)
|
||||
// TODO(e.burkov): Consider prebuilding this set on startup.
|
||||
ourAddrs, err := s.conf.ourAddrsSet()
|
||||
if err != nil {
|
||||
// TODO(e.burkov): !! Put into openapi
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "getting our addresses: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = req.validate(ourAddrs, s.sysResolvers, s.privateNets)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@ -587,7 +597,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var boots []*upstream.UpstreamResolver
|
||||
opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts)
|
||||
opts.Bootstrap, boots, err = newBootstrap(req.BootstrapDNS, s.etcHosts, opts)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err)
|
||||
|
||||
|
@ -245,9 +245,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "local_ptr_upstreams_bad",
|
||||
wantSet: `validating dns config: ` +
|
||||
`private upstream servers: checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "non.arpa.": not a reversed ip network`,
|
||||
wantSet: `validating dns config: private upstream servers: ` +
|
||||
`bad arpa domain name "non.arpa": not a reversed ip network`,
|
||||
}, {
|
||||
name: "local_ptr_upstreams_null",
|
||||
wantSet: "",
|
||||
@ -318,58 +317,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErr string
|
||||
u string
|
||||
}{{
|
||||
name: "success_address",
|
||||
wantErr: ``,
|
||||
u: "[/1.0.0.127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "success_subnet",
|
||||
wantErr: ``,
|
||||
u: "[/127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "not_arpa_subnet",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "hello.world.": not a reversed ip network`,
|
||||
u: "[/hello.world/]#",
|
||||
}, {
|
||||
name: "non-private_arpa_address",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`,
|
||||
u: "[/1.2.3.4.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "non-private_arpa_subnet",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "128.in-addr.arpa." should point to a locally-served network`,
|
||||
u: "[/128.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "several_bad",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network` + "\n" +
|
||||
`bad arpa domain name "non.arpa.": not a reversed ip network`,
|
||||
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "partial_good",
|
||||
wantErr: "",
|
||||
u: "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
set := []string{"192.168.0.1", tc.u}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateUpstreamsPrivate(set, ss)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
|
||||
t.Helper()
|
||||
|
||||
|
@ -11,17 +11,21 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// makeResponse creates a DNS response by req and sets necessary flags. It also
|
||||
// guarantees that req.Question will be not empty.
|
||||
func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
RecursionAvailable: true,
|
||||
},
|
||||
Compress: true,
|
||||
}
|
||||
// TODO(e.burkov): Name all the methods by a [proxy.MessageConstructor]
|
||||
// template. Also extract all the methods to a separate entity.
|
||||
|
||||
resp.SetReply(req)
|
||||
// reply creates a DNS response for req.
|
||||
func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {
|
||||
resp = (&dns.Msg{}).SetRcode(req, code)
|
||||
resp.RecursionAvailable = true
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// replyCompressed creates a DNS response for req and sets the compress flag.
|
||||
func (s *Server) replyCompressed(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.reply(req, dns.RcodeSuccess)
|
||||
resp.Compress = true
|
||||
|
||||
return resp
|
||||
}
|
||||
@ -51,7 +55,7 @@ func (s *Server) genDNSFilterMessage(
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
m, _, _ := s.dnsFilter.BlockingMode()
|
||||
if m == filtering.BlockingModeNullIP {
|
||||
return s.makeResponse(req)
|
||||
return s.replyCompressed(req)
|
||||
}
|
||||
|
||||
return s.newMsgNODATA(req)
|
||||
@ -75,7 +79,7 @@ func (s *Server) genDNSFilterMessage(
|
||||
// getCNAMEWithIPs generates a filtered response to req for with CNAME record
|
||||
// and provided ips.
|
||||
func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) {
|
||||
resp = s.makeResponse(req)
|
||||
resp = s.replyCompressed(req)
|
||||
|
||||
originalName := req.Question[0].Name
|
||||
|
||||
@ -121,13 +125,13 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
|
||||
case filtering.BlockingModeNullIP:
|
||||
return s.makeResponseNullIP(req)
|
||||
case filtering.BlockingModeNXDOMAIN:
|
||||
return s.genNXDomain(req)
|
||||
return s.NewMsgNXDOMAIN(req)
|
||||
case filtering.BlockingModeREFUSED:
|
||||
return s.makeResponseREFUSED(req)
|
||||
default:
|
||||
log.Error("dnsforward: invalid blocking mode %q", mode)
|
||||
|
||||
return s.makeResponse(req)
|
||||
return s.replyCompressed(req)
|
||||
}
|
||||
}
|
||||
|
||||
@ -148,25 +152,18 @@ func (s *Server) makeResponseCustomIP(
|
||||
// genDNSFilterMessage.
|
||||
log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
|
||||
|
||||
return s.makeResponse(req)
|
||||
return s.replyCompressed(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeServerFailure)
|
||||
resp.RecursionAvailable = true
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
|
||||
resp := s.makeResponse(request)
|
||||
resp := s.replyCompressed(request)
|
||||
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
|
||||
resp := s.makeResponse(request)
|
||||
resp := s.replyCompressed(request)
|
||||
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
|
||||
return resp
|
||||
}
|
||||
@ -252,7 +249,7 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
|
||||
// Go on and return an empty response.
|
||||
}
|
||||
|
||||
resp = s.makeResponse(req)
|
||||
resp = s.replyCompressed(req)
|
||||
resp.Answer = ans
|
||||
|
||||
return resp
|
||||
@ -288,7 +285,7 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
|
||||
case dns.TypeAAAA:
|
||||
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
|
||||
default:
|
||||
resp = s.makeResponse(req)
|
||||
resp = s.replyCompressed(req)
|
||||
}
|
||||
|
||||
return resp
|
||||
@ -298,7 +295,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
|
||||
if newAddr == "" {
|
||||
log.Info("dnsforward: block host is not specified")
|
||||
|
||||
return s.genServerFailure(request)
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(newAddr)
|
||||
@ -321,17 +318,17 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
|
||||
if prx == nil {
|
||||
log.Debug("dnsforward: %s", srvClosedErr)
|
||||
|
||||
return s.genServerFailure(request)
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
err = prx.Resolve(newContext)
|
||||
if err != nil {
|
||||
log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err)
|
||||
|
||||
return s.genServerFailure(request)
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
resp := s.makeResponse(request)
|
||||
resp := s.replyCompressed(request)
|
||||
if newContext.Res != nil {
|
||||
for _, answer := range newContext.Res.Answer {
|
||||
answer.Header().Name = request.Question[0].Name
|
||||
@ -342,47 +339,21 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
|
||||
return resp
|
||||
}
|
||||
|
||||
// preBlockedResponse returns a protocol-appropriate response for a request that
|
||||
// was blocked by access settings.
|
||||
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (reply bool, err error) {
|
||||
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
|
||||
// Return nil so that dnsproxy drops the connection and thus
|
||||
// prevent DNS amplification attacks.
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pctx.Res = s.makeResponseREFUSED(pctx.Req)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Create REFUSED DNS response
|
||||
func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeRefused)
|
||||
resp.RecursionAvailable = true
|
||||
return &resp
|
||||
func (s *Server) makeResponseREFUSED(req *dns.Msg) *dns.Msg {
|
||||
return s.reply(req, dns.RcodeRefused)
|
||||
}
|
||||
|
||||
// newMsgNODATA returns a properly initialized NODATA response.
|
||||
//
|
||||
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
|
||||
func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess)
|
||||
resp.RecursionAvailable = true
|
||||
resp = s.reply(req, dns.RcodeSuccess)
|
||||
resp.Ns = s.genSOA(req)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeNameError)
|
||||
resp.RecursionAvailable = true
|
||||
resp.Ns = s.genSOA(request)
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genSOA(request *dns.Msg) []dns.RR {
|
||||
zone := ""
|
||||
if len(request.Question) > 0 {
|
||||
@ -414,5 +385,43 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
|
||||
if len(zone) > 0 && zone[0] != '.' {
|
||||
soa.Mbox += zone
|
||||
}
|
||||
|
||||
return []dns.RR{&soa}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ proxy.MessageConstructor = (*Server)(nil)
|
||||
|
||||
// NewMsgNXDOMAIN implements the [proxy.MessageConstructor] interface for
|
||||
// *Server.
|
||||
func (s *Server) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.reply(req, dns.RcodeNameError)
|
||||
resp.Ns = s.genSOA(req)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// NewMsgSERVFAIL implements the [proxy.MessageConstructor] interface for
|
||||
// *Server.
|
||||
func (s *Server) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
|
||||
return s.reply(req, dns.RcodeServerFailure)
|
||||
}
|
||||
|
||||
// NewMsgNOTIMPLEMENTED implements the [proxy.MessageConstructor] interface for
|
||||
// *Server.
|
||||
func (s *Server) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.reply(req, dns.RcodeNotImplemented)
|
||||
|
||||
// Most of the Internet and especially the inner core has an MTU of at least
|
||||
// 1500 octets. Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet
|
||||
// is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)).
|
||||
//
|
||||
// See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17.
|
||||
const maxUDPPayload = 1452
|
||||
|
||||
// NOTIMPLEMENTED without EDNS is treated as 'we don't support EDNS', so
|
||||
// explicitly set it.
|
||||
resp.SetEdns0(maxUDPPayload, false)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
@ -1,20 +1,17 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@ -34,11 +31,6 @@ type dnsContext struct {
|
||||
// response is modified by filters.
|
||||
origResp *dns.Msg
|
||||
|
||||
// unreversedReqIP stores an IP address obtained from a PTR request if it
|
||||
// was parsed successfully and belongs to one of the locally served IP
|
||||
// ranges.
|
||||
unreversedReqIP netip.Addr
|
||||
|
||||
// err is the error returned from a processing function.
|
||||
err error
|
||||
|
||||
@ -63,10 +55,6 @@ type dnsContext struct {
|
||||
// responseAD shows if the response had the AD bit set.
|
||||
responseAD bool
|
||||
|
||||
// isLocalClient shows if client's IP address is from locally served
|
||||
// network.
|
||||
isLocalClient bool
|
||||
|
||||
// isDHCPHost is true if the request for a local domain name and the DHCP is
|
||||
// available for this request.
|
||||
isDHCPHost bool
|
||||
@ -109,15 +97,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
|
||||
// (*proxy.Proxy).handleDNSRequest method performs it before calling the
|
||||
// appropriate handler.
|
||||
mods := []modProcessFunc{
|
||||
s.processRecursion,
|
||||
s.processInitial,
|
||||
s.processDDRQuery,
|
||||
s.processDetermineLocal,
|
||||
s.processDHCPHosts,
|
||||
s.processRestrictLocal,
|
||||
s.processDHCPAddrs,
|
||||
s.processFilteringBeforeRequest,
|
||||
s.processLocalPTR,
|
||||
s.processUpstream,
|
||||
s.processFilteringAfterResponse,
|
||||
s.ipset.process,
|
||||
@ -145,24 +129,6 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// processRecursion checks the incoming request and halts its handling by
|
||||
// answering NXDOMAIN if s has tried to resolve it recently.
|
||||
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing recursion")
|
||||
defer log.Debug("dnsforward: finished processing recursion")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
|
||||
log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name)
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// mozillaFQDN is the domain used to signal the Firefox browser to not use its
|
||||
// own DoH server.
|
||||
//
|
||||
@ -199,14 +165,14 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
pctx.Res = s.NewMsgNXDOMAIN(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
if q.Name == healthcheckFQDN {
|
||||
// Generate a NODATA negative response to make nslookup exit with 0.
|
||||
pctx.Res = s.makeResponse(pctx.Req)
|
||||
pctx.Res = s.replyCompressed(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
@ -272,7 +238,7 @@ func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
|
||||
//
|
||||
// [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
|
||||
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.makeResponse(req)
|
||||
resp = s.replyCompressed(req)
|
||||
if req.Question[0].Qtype != dns.TypeSVCB {
|
||||
return resp
|
||||
}
|
||||
@ -339,19 +305,6 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
return resp
|
||||
}
|
||||
|
||||
// processDetermineLocal determines if the client's IP address is from locally
|
||||
// served network and saves the result into the context.
|
||||
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local detection")
|
||||
defer log.Debug("dnsforward: finished processing local detection")
|
||||
|
||||
rc = resultCodeSuccess
|
||||
|
||||
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr())
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
// processDHCPHosts respond to A requests if the target hostname is known to
|
||||
// the server. It responds with a mapped IP address if the DNS64 is enabled and
|
||||
// the request is for AAAA.
|
||||
@ -370,9 +323,9 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if !dctx.isLocalClient {
|
||||
if !pctx.IsPrivateClient {
|
||||
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
pctx.Res = s.NewMsgNXDOMAIN(req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
@ -389,7 +342,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip)
|
||||
|
||||
resp := s.makeResponse(req)
|
||||
resp := s.replyCompressed(req)
|
||||
switch q.Qtype {
|
||||
case dns.TypeA:
|
||||
a := &dns.A{
|
||||
@ -416,141 +369,6 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// indexFirstV4Label returns the index at which the reversed IPv4 address
|
||||
// starts, assuming the domain is pre-validated ARPA domain having in-addr and
|
||||
// arpa labels removed.
|
||||
func indexFirstV4Label(domain string) (idx int) {
|
||||
idx = len(domain)
|
||||
for labelsNum := 0; labelsNum < net.IPv4len && idx > 0; labelsNum++ {
|
||||
curIdx := strings.LastIndexByte(domain[:idx-1], '.') + 1
|
||||
_, parseErr := strconv.ParseUint(domain[curIdx:idx-1], 10, 8)
|
||||
if parseErr != nil {
|
||||
return idx
|
||||
}
|
||||
|
||||
idx = curIdx
|
||||
}
|
||||
|
||||
return idx
|
||||
}
|
||||
|
||||
// indexFirstV6Label returns the index at which the reversed IPv6 address
|
||||
// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa
|
||||
// labels removed.
|
||||
func indexFirstV6Label(domain string) (idx int) {
|
||||
idx = len(domain)
|
||||
for labelsNum := 0; labelsNum < net.IPv6len*2 && idx > 0; labelsNum++ {
|
||||
curIdx := idx - len("a.")
|
||||
if curIdx > 1 && domain[curIdx-1] != '.' {
|
||||
return idx
|
||||
}
|
||||
|
||||
nibble := domain[curIdx]
|
||||
if (nibble < '0' || nibble > '9') && (nibble < 'a' || nibble > 'f') {
|
||||
return idx
|
||||
}
|
||||
|
||||
idx = curIdx
|
||||
}
|
||||
|
||||
return idx
|
||||
}
|
||||
|
||||
// extractARPASubnet tries to convert a reversed ARPA address being a part of
|
||||
// domain to an IP network. domain must be an FQDN.
|
||||
//
|
||||
// TODO(e.burkov): Move to golibs.
|
||||
func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
|
||||
err = netutil.ValidateDomainName(strings.TrimSuffix(domain, "."))
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return netip.Prefix{}, err
|
||||
}
|
||||
|
||||
const (
|
||||
v4Suffix = "in-addr.arpa."
|
||||
v6Suffix = "ip6.arpa."
|
||||
)
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
var idx int
|
||||
switch {
|
||||
case strings.HasSuffix(domain, v4Suffix):
|
||||
idx = indexFirstV4Label(domain[:len(domain)-len(v4Suffix)])
|
||||
case strings.HasSuffix(domain, v6Suffix):
|
||||
idx = indexFirstV6Label(domain[:len(domain)-len(v6Suffix)])
|
||||
default:
|
||||
return netip.Prefix{}, &netutil.AddrError{
|
||||
Err: netutil.ErrNotAReversedSubnet,
|
||||
Kind: netutil.AddrKindARPA,
|
||||
Addr: domain,
|
||||
}
|
||||
}
|
||||
|
||||
return netutil.PrefixFromReversedAddr(domain[idx:])
|
||||
}
|
||||
|
||||
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
|
||||
// in locally served network from external clients.
|
||||
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local restriction")
|
||||
defer log.Debug("dnsforward: finished processing local restriction")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
if q.Qtype != dns.TypePTR {
|
||||
// No need for restriction.
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
subnet, err := extractARPASubnet(q.Name)
|
||||
if err != nil {
|
||||
if errors.Is(err, netutil.ErrNotAReversedSubnet) {
|
||||
log.Debug("dnsforward: request is not for arpa domain")
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: parsing reversed addr: %s", err)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
// Restrict an access to local addresses for external clients. We also
|
||||
// assume that all the DHCP leases we give are locally served or at least
|
||||
// shouldn't be accessible externally.
|
||||
subnetAddr := subnet.Addr()
|
||||
if !s.privateNets.Contains(subnetAddr) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: addr %s is from locally served network", subnetAddr)
|
||||
|
||||
if !dctx.isLocalClient {
|
||||
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
// Do not perform unreversing ever again.
|
||||
dctx.unreversedReqIP = subnetAddr
|
||||
|
||||
// There is no need to filter request from external addresses since this
|
||||
// code is only executed when the request is for locally served ARPA
|
||||
// hostname so disable redundant filters.
|
||||
dctx.setts.ParentalEnabled = false
|
||||
dctx.setts.SafeBrowsingEnabled = false
|
||||
dctx.setts.SafeSearchEnabled = false
|
||||
dctx.setts.ServicesRules = nil
|
||||
|
||||
// Nothing to restrict.
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
|
||||
// DHCP server.
|
||||
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
@ -562,20 +380,21 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
ipAddr := dctx.unreversedReqIP
|
||||
if ipAddr == (netip.Addr{}) {
|
||||
pref := pctx.RequestedPrivateRDNS
|
||||
if pref == (netip.Prefix{}) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
host := s.dhcpServer.HostByIP(ipAddr)
|
||||
addr := pref.Addr()
|
||||
host := s.dhcpServer.HostByIP(addr)
|
||||
if host == "" {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: dhcp client %s is %q", ipAddr, host)
|
||||
log.Debug("dnsforward: dhcp client %s is %q", addr, host)
|
||||
|
||||
req := pctx.Req
|
||||
resp := s.makeResponse(req)
|
||||
resp := s.replyCompressed(req)
|
||||
ptr := &dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
@ -593,62 +412,20 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// processLocalPTR responds to PTR requests if the target IP is detected to be
|
||||
// inside the local network and the query was not answered from DHCP.
|
||||
func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local ptr")
|
||||
defer log.Debug("dnsforward: finished processing local ptr")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if pctx.Res != nil {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
ip := dctx.unreversedReqIP
|
||||
if ip == (netip.Addr{}) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if s.conf.UsePrivateRDNS {
|
||||
s.recDetector.add(*pctx.Req)
|
||||
if err := s.localResolvers.Resolve(pctx); err != nil {
|
||||
log.Debug("dnsforward: resolving private address: %s", err)
|
||||
|
||||
// Generate the server failure if the private upstream configuration
|
||||
// is empty.
|
||||
//
|
||||
// This is a crutch, see TODO at [Server.localResolvers].
|
||||
if errors.Is(err, upstream.ErrNoUpstreams) {
|
||||
pctx.Res = s.genServerFailure(pctx.Req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
dctx.err = err
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
}
|
||||
|
||||
if pctx.Res == nil {
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Apply filtering logic
|
||||
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing filtering before req")
|
||||
defer log.Debug("dnsforward: finished processing filtering before req")
|
||||
|
||||
if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) {
|
||||
// There is no need to filter request for locally served ARPA hostname
|
||||
// so disable redundant filters.
|
||||
dctx.setts.ParentalEnabled = false
|
||||
dctx.setts.SafeBrowsingEnabled = false
|
||||
dctx.setts.SafeSearchEnabled = false
|
||||
dctx.setts.ServicesRules = nil
|
||||
}
|
||||
|
||||
if dctx.proxyCtx.Res != nil {
|
||||
// Go on since the response is already set.
|
||||
return resultCodeSuccess
|
||||
@ -695,7 +472,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
// local domain name if there is one.
|
||||
name := req.Question[0].Name
|
||||
log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1])
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
pctx.Res = s.NewMsgNXDOMAIN(req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
@ -712,21 +489,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
if err := prx.Resolve(pctx); err != nil {
|
||||
if errors.Is(err, upstream.ErrNoUpstreams) {
|
||||
// Do not even put into querylog. Currently this happens either
|
||||
// when the private resolvers enabled and the request is DNS64 PTR,
|
||||
// or when the client isn't considered local by prx.
|
||||
//
|
||||
// TODO(e.burkov): Make proxy detect local client the same way as
|
||||
// AGH does.
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
dctx.err = err
|
||||
|
||||
if dctx.err = prx.Resolve(pctx); dctx.err != nil {
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
@ -810,7 +573,7 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
}
|
||||
|
||||
// Use the ClientID first, since it has a higher priority.
|
||||
id := stringutil.Coalesce(clientID, pctx.Addr.Addr().String())
|
||||
id := cmp.Or(clientID, pctx.Addr.Addr().String())
|
||||
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
|
||||
if err != nil {
|
||||
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
@ -375,44 +376,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
|
||||
return f
|
||||
}
|
||||
|
||||
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
s := &Server{
|
||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
name string
|
||||
cliAddr netip.AddrPort
|
||||
}{{
|
||||
want: assert.True,
|
||||
name: "local",
|
||||
cliAddr: netip.MustParseAddrPort("192.168.0.1:1"),
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "external",
|
||||
cliAddr: netip.MustParseAddrPort("250.249.0.1:1"),
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "invalid",
|
||||
cliAddr: netip.AddrPort{},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
proxyCtx := &proxy.DNSContext{
|
||||
Addr: tc.cliAddr,
|
||||
}
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: proxyCtx,
|
||||
}
|
||||
s.processDetermineLocal(dctx)
|
||||
|
||||
tc.want(t, dctx.isLocalClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
const (
|
||||
localDomainSuffix = "lan"
|
||||
@ -482,9 +445,9 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
Req: req,
|
||||
IsPrivateClient: tc.isLocalCli,
|
||||
},
|
||||
isLocalClient: tc.isLocalCli,
|
||||
}
|
||||
|
||||
res := s.processDHCPHosts(dctx)
|
||||
@ -617,9 +580,9 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
Req: req,
|
||||
IsPrivateClient: true,
|
||||
},
|
||||
isLocalClient: true,
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@ -654,19 +617,28 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
// TODO(e.burkov): Rewrite this test to use the whole server instead of just
|
||||
// testing the [handleDNSRequest] method. See comment on
|
||||
// "from_external_for_local" test case.
|
||||
func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
|
||||
intAddr := netip.MustParseAddr("192.168.1.1")
|
||||
intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
extAddr := netip.MustParseAddr("254.253.252.1")
|
||||
extPTRQuestion, err := netutil.IPToReversedAddr(extAddr.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
extPTRQuestion = "251.252.253.254.in-addr.arpa."
|
||||
extPTRAnswer = "host1.example.net."
|
||||
intPTRQuestion = "1.1.168.192.in-addr.arpa."
|
||||
intPTRAnswer = "some.local-client."
|
||||
extPTRAnswer = "host1.example.net."
|
||||
intPTRAnswer = "some.local-client."
|
||||
)
|
||||
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
(&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||
@ -692,123 +664,165 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
question net.IP
|
||||
cliAddr netip.AddrPort
|
||||
wantLen int
|
||||
name string
|
||||
question string
|
||||
wantErr error
|
||||
wantAns []dns.RR
|
||||
isPrivate bool
|
||||
}{{
|
||||
name: "from_local_to_external",
|
||||
want: "host1.example.net.",
|
||||
question: net.IP{254, 253, 252, 251},
|
||||
cliAddr: netip.MustParseAddrPort("192.168.10.10:1"),
|
||||
wantLen: 1,
|
||||
name: "from_local_for_external",
|
||||
question: extPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(extPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(extPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(extPTRAnswer),
|
||||
}},
|
||||
isPrivate: true,
|
||||
}, {
|
||||
name: "from_external_for_local",
|
||||
want: "",
|
||||
question: net.IP{192, 168, 1, 1},
|
||||
cliAddr: netip.MustParseAddrPort("254.253.252.251:1"),
|
||||
wantLen: 0,
|
||||
// In theory this case is not reproducible because [proxy.Proxy] should
|
||||
// respond to such queries with NXDOMAIN before they reach
|
||||
// [Server.handleDNSRequest].
|
||||
name: "from_external_for_local",
|
||||
question: intPTRQuestion,
|
||||
wantErr: upstream.ErrNoUpstreams,
|
||||
wantAns: nil,
|
||||
isPrivate: false,
|
||||
}, {
|
||||
name: "from_local_for_local",
|
||||
want: "some.local-client.",
|
||||
question: net.IP{192, 168, 1, 1},
|
||||
cliAddr: netip.MustParseAddrPort("192.168.1.2:1"),
|
||||
wantLen: 1,
|
||||
question: intPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(intPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(intPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(intPTRAnswer),
|
||||
}},
|
||||
isPrivate: true,
|
||||
}, {
|
||||
name: "from_external_for_external",
|
||||
want: "host1.example.net.",
|
||||
question: net.IP{254, 253, 252, 251},
|
||||
cliAddr: netip.MustParseAddrPort("254.253.252.255:1"),
|
||||
wantLen: 1,
|
||||
question: extPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(extPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(extPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(extPTRAnswer),
|
||||
}},
|
||||
isPrivate: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
reqAddr, err := dns.ReverseAddr(tc.question.String())
|
||||
require.NoError(t, err)
|
||||
req := createTestMessageWithType(reqAddr, dns.TypePTR)
|
||||
pref, extErr := netutil.ExtractReversedAddr(tc.question)
|
||||
require.NoError(t, extErr)
|
||||
|
||||
req := createTestMessageWithType(dns.Fqdn(tc.question), dns.TypePTR)
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoTCP,
|
||||
Req: req,
|
||||
Addr: tc.cliAddr,
|
||||
Req: req,
|
||||
IsPrivateClient: tc.isPrivate,
|
||||
}
|
||||
// TODO(e.burkov): Configure the subnet set properly.
|
||||
if netutil.IsLocallyServed(pref.Addr()) {
|
||||
pctx.RequestedPrivateRDNS = pref
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.handleDNSRequest(nil, pctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pctx.Res)
|
||||
require.Len(t, pctx.Res.Answer, tc.wantLen)
|
||||
err = s.handleDNSRequest(s.dnsProxy, pctx)
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
if tc.wantLen > 0 {
|
||||
assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr)
|
||||
}
|
||||
require.NotNil(t, pctx.Res)
|
||||
assert.Equal(t, tc.wantAns, pctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
||||
const locDomain = "some.local."
|
||||
const reqAddr = "1.1.168.192.in-addr.arpa."
|
||||
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
(&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
s := createTestServer(
|
||||
t,
|
||||
&filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
},
|
||||
ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
)
|
||||
|
||||
var proxyCtx *proxy.DNSContext
|
||||
var dnsCtx *dnsContext
|
||||
setup := func(use bool) {
|
||||
proxyCtx = &proxy.DNSContext{
|
||||
Addr: testClientAddrPort,
|
||||
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
||||
newPrxCtx := func() (prxCtx *proxy.DNSContext) {
|
||||
return &proxy.DNSContext{
|
||||
Addr: testClientAddrPort,
|
||||
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
||||
IsPrivateClient: true,
|
||||
RequestedPrivateRDNS: netip.MustParsePrefix("192.168.1.1/32"),
|
||||
}
|
||||
dnsCtx = &dnsContext{
|
||||
proxyCtx: proxyCtx,
|
||||
unreversedReqIP: netip.MustParseAddr("192.168.1.1"),
|
||||
}
|
||||
s.conf.UsePrivateRDNS = use
|
||||
}
|
||||
|
||||
t.Run("enabled", func(t *testing.T) {
|
||||
setup(true)
|
||||
s := createTestServer(
|
||||
t,
|
||||
&filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
},
|
||||
ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
)
|
||||
pctx := newPrxCtx()
|
||||
|
||||
rc := s.processLocalPTR(dnsCtx)
|
||||
rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
|
||||
require.Equal(t, resultCodeSuccess, rc)
|
||||
require.NotEmpty(t, proxyCtx.Res.Answer)
|
||||
require.NotEmpty(t, pctx.Res.Answer)
|
||||
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
|
||||
|
||||
assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr)
|
||||
assert.Equal(t, locDomain, ptr.Ptr)
|
||||
})
|
||||
|
||||
t.Run("disabled", func(t *testing.T) {
|
||||
setup(false)
|
||||
s := createTestServer(
|
||||
t,
|
||||
&filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
},
|
||||
ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
UsePrivateRDNS: false,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
)
|
||||
pctx := newPrxCtx()
|
||||
|
||||
rc := s.processLocalPTR(dnsCtx)
|
||||
require.Equal(t, resultCodeFinish, rc)
|
||||
require.Empty(t, proxyCtx.Res.Answer)
|
||||
rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
|
||||
require.Equal(t, resultCodeError, rc)
|
||||
require.Empty(t, pctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
|
||||
@ -826,129 +840,3 @@ func TestIPStringFromAddr(t *testing.T) {
|
||||
assert.Empty(t, ipStringFromAddr(nil))
|
||||
})
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Add fuzzing when moving to golibs.
|
||||
func TestExtractARPASubnet(t *testing.T) {
|
||||
const (
|
||||
v4Suf = `in-addr.arpa.`
|
||||
v4Part = `2.1.` + v4Suf
|
||||
v4Whole = `4.3.` + v4Part
|
||||
|
||||
v6Suf = `ip6.arpa.`
|
||||
v6Part = `4.3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Suf
|
||||
v6Whole = `f.e.d.c.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Part
|
||||
)
|
||||
|
||||
v4Pref := netip.MustParsePrefix("1.2.3.4/32")
|
||||
v4PrefPart := netip.MustParsePrefix("1.2.0.0/16")
|
||||
v6Pref := netip.MustParsePrefix("::1234:0:0:0:cdef/128")
|
||||
v6PrefPart := netip.MustParsePrefix("0:0:0:1234::/64")
|
||||
|
||||
testCases := []struct {
|
||||
want netip.Prefix
|
||||
name string
|
||||
domain string
|
||||
wantErr string
|
||||
}{{
|
||||
want: netip.Prefix{},
|
||||
name: "not_an_arpa",
|
||||
domain: "some.domain.name.",
|
||||
wantErr: `bad arpa domain name "some.domain.name.": ` +
|
||||
`not a reversed ip network`,
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "bad_domain_name",
|
||||
domain: "abc.123.",
|
||||
wantErr: `bad domain name "abc.123": ` +
|
||||
`bad top-level domain name label "123": all octets are numeric`,
|
||||
}, {
|
||||
want: v4Pref,
|
||||
name: "whole_v4",
|
||||
domain: v4Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "partial_v4",
|
||||
domain: v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4Pref,
|
||||
name: "whole_v4_within_domain",
|
||||
domain: "a." + v4Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4Pref,
|
||||
name: "whole_v4_additional_label",
|
||||
domain: "5." + v4Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "partial_v4_within_domain",
|
||||
domain: "a." + v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "overflow_v4",
|
||||
domain: "256." + v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "overflow_v4_within_domain",
|
||||
domain: "a.256." + v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v4",
|
||||
domain: v4Suf,
|
||||
wantErr: `bad arpa domain name "in-addr.arpa": ` +
|
||||
`not a reversed ip network`,
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v4_within_domain",
|
||||
domain: "a." + v4Suf,
|
||||
wantErr: `bad arpa domain name "in-addr.arpa": ` +
|
||||
`not a reversed ip network`,
|
||||
}, {
|
||||
want: v6Pref,
|
||||
name: "whole_v6",
|
||||
domain: v6Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v6PrefPart,
|
||||
name: "partial_v6",
|
||||
domain: v6Part,
|
||||
}, {
|
||||
want: v6Pref,
|
||||
name: "whole_v6_within_domain",
|
||||
domain: "g." + v6Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v6Pref,
|
||||
name: "whole_v6_additional_label",
|
||||
domain: "1." + v6Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v6PrefPart,
|
||||
name: "partial_v6_within_domain",
|
||||
domain: "label." + v6Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v6",
|
||||
domain: v6Suf,
|
||||
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v6_within_domain",
|
||||
domain: "g." + v6Suf,
|
||||
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
subnet, err := extractARPASubnet(tc.domain)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
assert.Equal(t, tc.want, subnet)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,115 +0,0 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// uint* sizes in bytes to improve readability.
|
||||
//
|
||||
// TODO(e.burkov): Remove when there will be a more regardful way to define
|
||||
// those. See https://github.com/golang/go/issues/29982.
|
||||
const (
|
||||
uint16sz = 2
|
||||
uint64sz = 8
|
||||
)
|
||||
|
||||
// recursionDetector detects recursion in DNS forwarding.
|
||||
type recursionDetector struct {
|
||||
recentRequests cache.Cache
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// check checks if the passed req was already sent by the server.
|
||||
func (rd *recursionDetector) check(msg dns.Msg) (ok bool) {
|
||||
if len(msg.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
key := msgToSignature(msg)
|
||||
expireData := rd.recentRequests.Get(key)
|
||||
if expireData == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
expire := time.Unix(0, int64(binary.BigEndian.Uint64(expireData)))
|
||||
|
||||
return time.Now().Before(expire)
|
||||
}
|
||||
|
||||
// add caches the msg if it has anything in the questions section.
|
||||
func (rd *recursionDetector) add(msg dns.Msg) {
|
||||
now := time.Now()
|
||||
|
||||
if len(msg.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
key := msgToSignature(msg)
|
||||
expire64 := uint64(now.Add(rd.ttl).UnixNano())
|
||||
expire := make([]byte, uint64sz)
|
||||
binary.BigEndian.PutUint64(expire, expire64)
|
||||
|
||||
rd.recentRequests.Set(key, expire)
|
||||
}
|
||||
|
||||
// clear clears the recent requests cache.
|
||||
func (rd *recursionDetector) clear() {
|
||||
rd.recentRequests.Clear()
|
||||
}
|
||||
|
||||
// newRecursionDetector returns the initialized *recursionDetector.
|
||||
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
|
||||
return &recursionDetector{
|
||||
recentRequests: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: suspectsNum,
|
||||
}),
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// msgToSignature converts msg into it's signature represented in bytes.
|
||||
func msgToSignature(msg dns.Msg) (sig []byte) {
|
||||
sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen)
|
||||
// The binary.BigEndian byte order is used everywhere except when the real
|
||||
// machine's endianness is needed.
|
||||
byteOrder := binary.BigEndian
|
||||
byteOrder.PutUint16(sig[0:], msg.Id)
|
||||
q := msg.Question[0]
|
||||
byteOrder.PutUint16(sig[uint16sz:], q.Qtype)
|
||||
copy(sig[2*uint16sz:], []byte(q.Name))
|
||||
|
||||
return sig
|
||||
}
|
||||
|
||||
// msgToSignatureSlow converts msg into it's signature represented in bytes in
|
||||
// the less efficient way.
|
||||
//
|
||||
// See BenchmarkMsgToSignature.
|
||||
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
|
||||
type msgSignature struct {
|
||||
name [netutil.MaxDomainNameLen]byte
|
||||
id uint16
|
||||
qtype uint16
|
||||
}
|
||||
|
||||
b := bytes.NewBuffer(sig)
|
||||
q := msg.Question[0]
|
||||
signature := msgSignature{
|
||||
id: msg.Id,
|
||||
qtype: q.Qtype,
|
||||
}
|
||||
copy(signature.name[:], q.Name)
|
||||
if err := binary.Write(b, binary.BigEndian, signature); err != nil {
|
||||
log.Debug("writing message signature: %s", err)
|
||||
}
|
||||
|
||||
return b.Bytes()
|
||||
}
|
@ -1,148 +0,0 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRecursionDetector_Check(t *testing.T) {
|
||||
rd := newRecursionDetector(0, 2)
|
||||
|
||||
const (
|
||||
recID = 1234
|
||||
recTTL = time.Hour * 100
|
||||
)
|
||||
|
||||
const nonRecID = recID * 2
|
||||
|
||||
sampleQuestion := dns.Question{
|
||||
Name: "some.domain",
|
||||
Qtype: dns.TypeAAAA,
|
||||
}
|
||||
sampleMsg := dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: recID,
|
||||
},
|
||||
Question: []dns.Question{sampleQuestion},
|
||||
}
|
||||
|
||||
// Manually add the message with big ttl.
|
||||
key := msgToSignature(sampleMsg)
|
||||
expire := make([]byte, uint64sz)
|
||||
binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano()))
|
||||
rd.recentRequests.Set(key, expire)
|
||||
|
||||
// Add an expired message.
|
||||
sampleMsg.Id = nonRecID
|
||||
rd.add(sampleMsg)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
questions []dns.Question
|
||||
id uint16
|
||||
want bool
|
||||
}{{
|
||||
name: "recurrent",
|
||||
questions: []dns.Question{sampleQuestion},
|
||||
id: recID,
|
||||
want: true,
|
||||
}, {
|
||||
name: "not_suspected",
|
||||
questions: []dns.Question{sampleQuestion},
|
||||
id: recID + 1,
|
||||
want: false,
|
||||
}, {
|
||||
name: "expired",
|
||||
questions: []dns.Question{sampleQuestion},
|
||||
id: nonRecID,
|
||||
want: false,
|
||||
}, {
|
||||
name: "empty",
|
||||
questions: []dns.Question{},
|
||||
id: nonRecID,
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
sampleMsg.Id = tc.id
|
||||
sampleMsg.Question = tc.questions
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
detected := rd.check(sampleMsg)
|
||||
assert.Equal(t, tc.want, detected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecursionDetector_Suspect(t *testing.T) {
|
||||
rd := newRecursionDetector(0, 1)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg dns.Msg
|
||||
want int
|
||||
}{{
|
||||
name: "simple",
|
||||
msg: dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: 1234,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "some.domain",
|
||||
Qtype: dns.TypeA,
|
||||
}},
|
||||
},
|
||||
want: 1,
|
||||
}, {
|
||||
name: "unencumbered",
|
||||
msg: dns.Msg{},
|
||||
want: 0,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(rd.clear)
|
||||
rd.add(tc.msg)
|
||||
assert.Equal(t, tc.want, rd.recentRequests.Stats().Count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var sink []byte
|
||||
|
||||
func BenchmarkMsgToSignature(b *testing.B) {
|
||||
const name = "some.not.very.long.host.name"
|
||||
|
||||
msg := dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: 1234,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: name,
|
||||
Qtype: dns.TypeAAAA,
|
||||
}},
|
||||
}
|
||||
|
||||
b.Run("efficient", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
sink = msgToSignature(msg)
|
||||
}
|
||||
|
||||
assert.NotEmpty(b, sink)
|
||||
})
|
||||
|
||||
b.Run("inefficient", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
sink = msgToSignatureSlow(msg)
|
||||
}
|
||||
|
||||
assert.NotEmpty(b, sink)
|
||||
})
|
||||
}
|
@ -2,90 +2,77 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||
// the configuration itself.
|
||||
func (s *Server) loadUpstreams() (upstreams []string, err error) {
|
||||
if s.conf.UpstreamDNSFileName == "" {
|
||||
return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil
|
||||
// newBootstrap returns a bootstrap resolver based on the configuration of s.
|
||||
// boots are the upstream resolvers that should be closed after use. r is the
|
||||
// actual bootstrap resolver, which may include the system hosts.
|
||||
//
|
||||
// TODO(e.burkov): This function currently returns a resolver and a slice of
|
||||
// the upstream resolvers, which are essentially the same. boots are returned
|
||||
// for being able to close them afterwards, but it introduces an implicit
|
||||
// contract that r could only be used before that. Anyway, this code should
|
||||
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
|
||||
// and be used here.
|
||||
func newBootstrap(
|
||||
addrs []string,
|
||||
etcHosts upstream.Resolver,
|
||||
opts *upstream.Options,
|
||||
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
|
||||
if len(addrs) == 0 {
|
||||
addrs = defaultBootstrap
|
||||
}
|
||||
|
||||
var data []byte
|
||||
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
|
||||
boots, err = aghnet.ParseBootstraps(addrs, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading upstream from file: %w", err)
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
upstreams = stringutil.SplitTrimmed(string(data), "\n")
|
||||
var parallel upstream.ParallelResolver
|
||||
for _, b := range boots {
|
||||
parallel = append(parallel, upstream.NewCachingResolver(b))
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName)
|
||||
if etcHosts != nil {
|
||||
r = upstream.ConsequentResolver{etcHosts, parallel}
|
||||
} else {
|
||||
r = parallel
|
||||
}
|
||||
|
||||
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
|
||||
return r, boots, nil
|
||||
}
|
||||
|
||||
// prepareUpstreamSettings sets upstream DNS server settings.
|
||||
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
|
||||
// Load upstreams either from the file, or from the settings
|
||||
var upstreams []string
|
||||
upstreams, err = s.loadUpstreams()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
||||
Bootstrap: boot,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
// Use a customized set of RootCAs, because Go's default mechanism of
|
||||
// loading TLS roots does not always work properly on some routers so we're
|
||||
// loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
//
|
||||
// TODO(a.garipov): Investigate if that's true.
|
||||
RootCAs: s.conf.TLSv12Roots,
|
||||
CipherSuites: s.conf.TLSCiphers,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing upstream config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareUpstreamConfig returns the upstream configuration based on upstreams
|
||||
// and configuration of s.
|
||||
func (s *Server) prepareUpstreamConfig(
|
||||
// newUpstreamConfig returns the upstream configuration based on upstreams. If
|
||||
// upstreams slice specifies no default upstreams, defaultUpstreams are used to
|
||||
// create upstreams with no domain specifications. opts are used when creating
|
||||
// upstream configuration.
|
||||
func newUpstreamConfig(
|
||||
upstreams []string,
|
||||
defaultUpstreams []string,
|
||||
opts *upstream.Options,
|
||||
) (uc *proxy.UpstreamConfig, err error) {
|
||||
uc, err = proxy.ParseUpstreamsConfig(upstreams, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing upstream config: %w", err)
|
||||
return uc, fmt.Errorf("parsing upstreams: %w", err)
|
||||
}
|
||||
|
||||
if len(uc.Upstreams) == 0 && defaultUpstreams != nil {
|
||||
if len(uc.Upstreams) == 0 && len(defaultUpstreams) > 0 {
|
||||
log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)
|
||||
|
||||
var defaultUpstreamConfig *proxy.UpstreamConfig
|
||||
defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing default upstreams: %w", err)
|
||||
return uc, fmt.Errorf("parsing default upstreams: %w", err)
|
||||
}
|
||||
|
||||
uc.Upstreams = defaultUpstreamConfig.Upstreams
|
||||
@ -94,6 +81,54 @@ func (s *Server) prepareUpstreamConfig(
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// newPrivateConfig creates an upstream configuration for resolving PTR records
|
||||
// for local addresses. The configuration is built either from the provided
|
||||
// addresses or from the system resolvers. unwanted filters the resulting
|
||||
// upstream configuration.
|
||||
func newPrivateConfig(
|
||||
addrs []string,
|
||||
unwanted addrPortSet,
|
||||
sysResolvers SystemResolvers,
|
||||
privateNets netutil.SubnetSet,
|
||||
opts *upstream.Options,
|
||||
) (uc *proxy.UpstreamConfig, err error) {
|
||||
confNeedsFiltering := len(addrs) > 0
|
||||
if confNeedsFiltering {
|
||||
addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty)
|
||||
} else {
|
||||
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
|
||||
addrs = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
addrs = append(addrs, r.String())
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", addrs)
|
||||
|
||||
uc, err = proxy.ParseUpstreamsConfig(addrs, opts)
|
||||
if err != nil {
|
||||
return uc, fmt.Errorf("preparing private upstreams: %w", err)
|
||||
}
|
||||
|
||||
if !confNeedsFiltering {
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
err = filterOutAddrs(uc, unwanted)
|
||||
if err != nil {
|
||||
return uc, fmt.Errorf("filtering private upstreams: %w", err)
|
||||
}
|
||||
|
||||
// Prevalidate the config to catch the exact error before creating proxy.
|
||||
// See TODO on [PrivateRDNSError].
|
||||
err = proxy.ValidatePrivateConfig(uc, privateNets)
|
||||
if err != nil {
|
||||
return uc, &PrivateRDNSError{err: err}
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||
// depending on configuration.
|
||||
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
@ -130,85 +165,9 @@ func setProxyUpstreamMode(
|
||||
return nil
|
||||
}
|
||||
|
||||
// createBootstrap returns a bootstrap resolver based on the configuration of s.
|
||||
// boots are the upstream resolvers that should be closed after use. r is the
|
||||
// actual bootstrap resolver, which may include the system hosts.
|
||||
//
|
||||
// TODO(e.burkov): This function currently returns a resolver and a slice of
|
||||
// the upstream resolvers, which are essentially the same. boots are returned
|
||||
// for being able to close them afterwards, but it introduces an implicit
|
||||
// contract that r could only be used before that. Anyway, this code should
|
||||
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
|
||||
// and be used here.
|
||||
func (s *Server) createBootstrap(
|
||||
addrs []string,
|
||||
opts *upstream.Options,
|
||||
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
|
||||
if len(addrs) == 0 {
|
||||
addrs = defaultBootstrap
|
||||
}
|
||||
|
||||
boots, err = aghnet.ParseBootstraps(addrs, opts)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var parallel upstream.ParallelResolver
|
||||
for _, b := range boots {
|
||||
parallel = append(parallel, upstream.NewCachingResolver(b))
|
||||
}
|
||||
|
||||
if s.etcHosts != nil {
|
||||
r = upstream.ConsequentResolver{s.etcHosts, parallel}
|
||||
} else {
|
||||
r = parallel
|
||||
}
|
||||
|
||||
return r, boots, nil
|
||||
}
|
||||
|
||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. privateNets must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating config: %w", err)
|
||||
}
|
||||
|
||||
if conf == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := maps.Keys(conf.DomainReservedUpstreams)
|
||||
slices.Sort(keys)
|
||||
|
||||
var errs []error
|
||||
for _, domain := range keys {
|
||||
var subnet netip.Prefix
|
||||
subnet, err = extractARPASubnet(domain)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !privateNets.Contains(subnet.Addr()) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
||||
}
|
||||
|
@ -18,7 +18,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@ -157,14 +156,12 @@ func initDNSServer(
|
||||
return fmt.Errorf("newServerConfig: %w", err)
|
||||
}
|
||||
|
||||
// Try to prepare the server with disabled private RDNS resolution if it
|
||||
// failed to prepare as is. See TODO on [ErrBadPrivateRDNSUpstreams].
|
||||
err = Context.dnsServer.Prepare(dnsConf)
|
||||
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
|
||||
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
|
||||
|
||||
// TODO(e.burkov): Recreate the server with private RDNS disabled. This
|
||||
// should go away once the private RDNS resolution is moved to the proxy.
|
||||
var locResErr *dnsforward.LocalResolversError
|
||||
if errors.As(err, &locResErr) && errors.Is(locResErr.Err, upstream.ErrNoUpstreams) {
|
||||
log.Info("WARNING: no local resolvers configured while private RDNS " +
|
||||
"resolution enabled, trying to disable")
|
||||
dnsConf.UsePrivateRDNS = false
|
||||
err = Context.dnsServer.Prepare(dnsConf)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user