diff --git a/CHANGELOG.md b/CHANGELOG.md index d784fb76..f86766af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to ### Added +- Detection and handling of recurrent requests ([#3185]). - The ability to completely disable reverse DNS resolving of IPs from locally-served networks ([#3184]). - New flag `--local-frontend` to serve dinamically changeable frontend files @@ -38,6 +39,7 @@ released by then. - Go 1.15 support. [#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184 +[#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185 diff --git a/internal/aghnet/addr.go b/internal/aghnet/addr.go index ebe7d48d..09971955 100644 --- a/internal/aghnet/addr.go +++ b/internal/aghnet/addr.go @@ -42,11 +42,11 @@ func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) { // according to RFC 1035. const maxDomainLabelLen = 63 -// maxDomainNameLen is the maximum allowed length of a full domain name +// MaxDomainNameLen is the maximum allowed length of a full domain name // according to RFC 1035. // // See https://stackoverflow.com/a/32294443/1892060. -const maxDomainNameLen = 253 +const MaxDomainNameLen = 253 // ValidateDomainNameLabel returns an error if label is not a valid label of // a domain name. @@ -97,8 +97,8 @@ func ValidateDomainName(name string) (err error) { l := len(name) if l == 0 { return errors.Error("domain name is empty") - } else if l > maxDomainNameLen { - return fmt.Errorf("too long, max: %d", maxDomainNameLen) + } else if l > MaxDomainNameLen { + return fmt.Errorf("too long, max: %d", MaxDomainNameLen) } labels := strings.Split(name, ".") diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 364d447a..92780057 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -82,6 +82,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { // (*proxy.Proxy).handleDNSRequest method performs it before calling the // appropriate handler. mods := []modProcessFunc{ + s.processRecursion, processInitial, s.processDetermineLocal, s.processInternalHosts, @@ -90,7 +91,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { processClientID, processFilteringBeforeRequest, s.processLocalPTR, - processUpstream, + s.processUpstream, processDNSSECAfterResponse, processFilteringAfterResponse, s.ipset.process, @@ -116,6 +117,22 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { return nil } +// processRecursion checks the incoming request and halts it's handling if s +// have tried to resolve it recently. +func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) { + pctx := dctx.proxyCtx + + if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) { + log.Debug("recursion detected resolving %q", msg.Question[0].Name) + pctx.Res = s.genNXDomain(pctx.Req) + + return resultCodeFinish + + } + + return resultCodeSuccess +} + // Perform initial checks; process WHOIS & rDNS func processInitial(ctx *dnsContext) (rc resultCode) { s := ctx.srv @@ -422,6 +439,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) { } if s.conf.UsePrivateRDNS { + s.recDetector.add(*d.Req) if err := s.localResolvers.Resolve(d); err != nil { ctx.err = err @@ -472,8 +490,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { } // processUpstream passes request to upstream servers and handles the response. -func processUpstream(ctx *dnsContext) (rc resultCode) { - s := ctx.srv +func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) { d := ctx.proxyCtx if d.Res != nil { return resultCodeSuccess // response is already set - nothing to do @@ -481,18 +498,18 @@ func processUpstream(ctx *dnsContext) (rc resultCode) { if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil { clientIP := IPStringFromAddr(d.Addr) - upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP) - if upstreamsConf != nil { - log.Debug("Using custom upstreams for %s", clientIP) - d.CustomUpstreamConfig = upstreamsConf + if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil { + log.Debug("dns: using custom upstreams for client %s", clientIP) + d.CustomUpstreamConfig = upsConf } } + req := d.Req if s.conf.EnableDNSSEC { - opt := d.Req.IsEdns0() + opt := req.IsEdns0() if opt == nil { - log.Debug("dns: Adding OPT record with DNSSEC flag") - d.Req.SetEdns0(4096, true) + log.Debug("dns: adding OPT record with DNSSEC flag") + req.SetEdns0(4096, true) } else if !opt.Do() { opt.SetDo(true) } else { @@ -501,13 +518,13 @@ func processUpstream(ctx *dnsContext) (rc resultCode) { } // request was not filtered so let it be processed further - err := s.dnsProxy.Resolve(d) - if err != nil { - ctx.err = err + s.recDetector.add(*req) + if ctx.err = s.dnsProxy.Resolve(d); ctx.err != nil { return resultCodeError } ctx.responseFromUpstream = true + return resultCodeSuccess } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 15ddbc57..e0c938e0 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -76,6 +76,7 @@ type Server struct { ipset ipsetCtx subnetDetector *aghnet.SubnetDetector localResolvers *proxy.Proxy + recDetector *recursionDetector tableHostToIP hostToIPTable tableHostToIPLock sync.Mutex @@ -121,6 +122,14 @@ func domainNameToSuffix(tld string) (suffix string) { return string(b) } +const ( + // recursionTTL is the time recursive request is cached for. + recursionTTL = 5 * time.Second + // cachedRecurrentReqNum is the maximum number of cached recurrent + // requests. + cachedRecurrentReqNum = 1000 +) + // NewServer creates a new instance of the dnsforward.Server // Note: this function must be called only once func NewServer(p DNSCreateParams) (s *Server, err error) { @@ -142,6 +151,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { queryLog: p.QueryLog, subnetDetector: p.SubnetDetector, localDomainSuffix: localDomainSuffix, + recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum), } if p.DHCPServer != nil { @@ -160,7 +170,9 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { // NewCustomServer creates a new instance of *Server with custom internal proxy. func NewCustomServer(internalProxy *proxy.Proxy) *Server { - s := &Server{} + s := &Server{ + recDetector: newRecursionDetector(0, 1), + } if internalProxy != nil { s.internalProxy = internalProxy } @@ -278,14 +290,13 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) { Req: req, StartTime: time.Now(), } - var resp *dns.Msg - err = resolver.Resolve(ctx) - if err != nil { + + s.recDetector.add(*req) + if err = resolver.Resolve(ctx); err != nil { return "", err } - resp = ctx.Res - + resp := ctx.Res if len(resp.Answer) == 0 { return "", fmt.Errorf("lookup for %q: %w", arpa, rDNSEmptyAnswerErr) } @@ -490,6 +501,8 @@ func (s *Server) Prepare(config *ServerConfig) error { return fmt.Errorf("setting up resolvers: %w", err) } + s.recDetector.clear() + return nil } diff --git a/internal/dnsforward/recursiondetector.go b/internal/dnsforward/recursiondetector.go new file mode 100644 index 00000000..5203f518 --- /dev/null +++ b/internal/dnsforward/recursiondetector.go @@ -0,0 +1,115 @@ +package dnsforward + +import ( + "bytes" + "encoding/binary" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/golibs/cache" + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" +) + +// uint* sizes in bytes to improve readability. +// +// TODO(e.burkov): Remove when there will be a more regardful way to define +// those. See https://github.com/golang/go/issues/29982. +const ( + uint16sz = 2 + uint64sz = 8 +) + +// recursionDetector detects recursion in DNS forwarding. +type recursionDetector struct { + recentRequests cache.Cache + ttl time.Duration +} + +// check checks if the passed req was already sent by s. +func (rd *recursionDetector) check(msg dns.Msg) (ok bool) { + if len(msg.Question) == 0 { + return false + } + + key := msgToSignature(msg) + expireData := rd.recentRequests.Get(key) + if expireData == nil { + return false + } + + expire := time.Unix(0, int64(binary.BigEndian.Uint64(expireData))) + + return time.Now().Before(expire) +} + +// add caches the msg if it has anything in the questions section. +func (rd *recursionDetector) add(msg dns.Msg) { + now := time.Now() + + if len(msg.Question) == 0 { + return + } + + key := msgToSignature(msg) + expire64 := uint64(now.Add(rd.ttl).UnixNano()) + expire := make([]byte, uint64sz) + binary.BigEndian.PutUint64(expire, expire64) + + rd.recentRequests.Set(key, expire) +} + +// clear clears the recent requests cache. +func (rd *recursionDetector) clear() { + rd.recentRequests.Clear() +} + +// newRecursionDetector returns the initialized *recursionDetector. +func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) { + return &recursionDetector{ + recentRequests: cache.New(cache.Config{ + EnableLRU: true, + MaxCount: suspectsNum, + }), + ttl: ttl, + } +} + +// msgToSignature converts msg into it's signature represented in bytes. +func msgToSignature(msg dns.Msg) (sig []byte) { + sig = make([]byte, uint16sz*2+aghnet.MaxDomainNameLen) + // The binary.BigEndian byte order is used everywhere except when the + // real machine's endianess is needed. + byteOrder := binary.BigEndian + byteOrder.PutUint16(sig[0:], msg.Id) + q := msg.Question[0] + byteOrder.PutUint16(sig[uint16sz:], q.Qtype) + copy(sig[2*uint16sz:], []byte(q.Name)) + + return sig +} + +// msgToSignatureSlow converts msg into it's signature represented in bytes in +// the less efficient way. +// +// See BenchmarkMsgToSignature. +func msgToSignatureSlow(msg dns.Msg) (sig []byte) { + type msgSignature struct { + name [aghnet.MaxDomainNameLen]byte + id uint16 + qtype uint16 + } + + b := bytes.NewBuffer(sig) + q := msg.Question[0] + signature := msgSignature{ + id: msg.Id, + qtype: q.Qtype, + } + copy(signature.name[:], q.Name) + if err := binary.Write(b, binary.BigEndian, signature); err != nil { + log.Debug("writing message signature: %s", err) + } + + return b.Bytes() +} diff --git a/internal/dnsforward/recursiondetector_test.go b/internal/dnsforward/recursiondetector_test.go new file mode 100644 index 00000000..7573b668 --- /dev/null +++ b/internal/dnsforward/recursiondetector_test.go @@ -0,0 +1,154 @@ +package dnsforward + +import ( + "encoding/binary" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func TestRecursionDetector_Check(t *testing.T) { + rd := newRecursionDetector(0, 2) + + const ( + recID = 1234 + recTTL = time.Hour * 100 + ) + + const nonRecID = recID * 2 + + sampleQuestion := dns.Question{ + Name: "some.domain", + Qtype: dns.TypeAAAA, + } + sampleMsg := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: recID, + }, + Question: []dns.Question{sampleQuestion}, + } + + // Manually add the message with big ttl. + key := msgToSignature(sampleMsg) + expire := make([]byte, uint64sz) + binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano())) + rd.recentRequests.Set(key, expire) + + // Add an expired message. + sampleMsg.Id = nonRecID + rd.add(sampleMsg) + + testCases := []struct { + name string + questions []dns.Question + id uint16 + want bool + }{{ + name: "recurrent", + questions: []dns.Question{sampleQuestion}, + id: recID, + want: true, + }, { + name: "not_suspected", + questions: []dns.Question{sampleQuestion}, + id: recID + 1, + want: false, + }, { + name: "expired", + questions: []dns.Question{sampleQuestion}, + id: nonRecID, + want: false, + }, { + name: "empty", + questions: []dns.Question{}, + id: nonRecID, + want: false, + }} + + for _, tc := range testCases { + sampleMsg.Id = tc.id + sampleMsg.Question = tc.questions + t.Run(tc.name, func(t *testing.T) { + detected := rd.check(sampleMsg) + assert.Equal(t, tc.want, detected) + }) + } +} + +func TestRecursionDetector_Suspect(t *testing.T) { + rd := newRecursionDetector(0, 1) + + testCases := []struct { + name string + msg dns.Msg + want bool + }{{ + name: "simple", + msg: dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 1234, + }, + Question: []dns.Question{{ + Name: "some.domain", + Qtype: dns.TypeA, + }}, + }, + want: true, + }, { + name: "unencumbered", + msg: dns.Msg{}, + want: false, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Cleanup(rd.clear) + + rd.add(tc.msg) + + if tc.want { + assert.Equal(t, 1, rd.recentRequests.Stats().Count) + } else { + assert.Zero(t, rd.recentRequests.Stats().Count) + } + }) + } +} + +var sink []byte + +func BenchmarkMsgToSignature(b *testing.B) { + const name = "some.not.very.long.host.name" + + msg := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 1234, + }, + Question: []dns.Question{{ + Name: name, + Qtype: dns.TypeAAAA, + }}, + } + + b.Run("efficient", func(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + sink = msgToSignature(msg) + } + + assert.NotEmpty(b, sink) + }) + + b.Run("inefficient", func(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + sink = msgToSignatureSlow(msg) + } + + assert.NotEmpty(b, sink) + }) +}