Pull request: 3185 detecting recursion

Merge in DNS/adguard-home from 3185-recursion to master

Closes #3185.

Squashed commit of the following:

commit 2fa44223f533c471f2b8c0e17d8550bf4ff73c7b
Merge: 7975957c 7a48e92e
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu May 27 19:04:44 2021 +0300

    Merge branch 'master' into 3185-recursion

commit 7975957cceb840f76eef0e2e434f4163a122ac34
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu May 27 17:36:22 2021 +0300

    dnsforward: imp docs

commit 1af7131a5b7c1fefed2d1eb8ee24ebfd3602dc77
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu May 27 17:15:00 2021 +0300

    dnsforward: imp code, tests, docs

commit f3f9145fb5e1174fab87ca6890da9df722cfebf0
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu May 27 15:45:44 2021 +0300

    dnsforward: add recursion detector
This commit is contained in:
Eugene Burkov 2021-05-27 19:19:19 +03:00
parent 7a48e92e4d
commit 48b8579703
6 changed files with 324 additions and 23 deletions

View File

@ -15,6 +15,7 @@ and this project adheres to
### Added
- Detection and handling of recurrent requests ([#3185]).
- The ability to completely disable reverse DNS resolving of IPs from
locally-served networks ([#3184]).
- New flag `--local-frontend` to serve dinamically changeable frontend files
@ -38,6 +39,7 @@ released by then.
- Go 1.15 support.
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
[#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185

View File

@ -42,11 +42,11 @@ func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) {
// according to RFC 1035.
const maxDomainLabelLen = 63
// maxDomainNameLen is the maximum allowed length of a full domain name
// MaxDomainNameLen is the maximum allowed length of a full domain name
// according to RFC 1035.
//
// See https://stackoverflow.com/a/32294443/1892060.
const maxDomainNameLen = 253
const MaxDomainNameLen = 253
// ValidateDomainNameLabel returns an error if label is not a valid label of
// a domain name.
@ -97,8 +97,8 @@ func ValidateDomainName(name string) (err error) {
l := len(name)
if l == 0 {
return errors.Error("domain name is empty")
} else if l > maxDomainNameLen {
return fmt.Errorf("too long, max: %d", maxDomainNameLen)
} else if l > MaxDomainNameLen {
return fmt.Errorf("too long, max: %d", MaxDomainNameLen)
}
labels := strings.Split(name, ".")

View File

@ -82,6 +82,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
// (*proxy.Proxy).handleDNSRequest method performs it before calling the
// appropriate handler.
mods := []modProcessFunc{
s.processRecursion,
processInitial,
s.processDetermineLocal,
s.processInternalHosts,
@ -90,7 +91,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
processClientID,
processFilteringBeforeRequest,
s.processLocalPTR,
processUpstream,
s.processUpstream,
processDNSSECAfterResponse,
processFilteringAfterResponse,
s.ipset.process,
@ -116,6 +117,22 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
return nil
}
// processRecursion checks the incoming request and halts it's handling if s
// have tried to resolve it recently.
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
log.Debug("recursion detected resolving %q", msg.Question[0].Name)
pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish
}
return resultCodeSuccess
}
// Perform initial checks; process WHOIS & rDNS
func processInitial(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
@ -422,6 +439,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
}
if s.conf.UsePrivateRDNS {
s.recDetector.add(*d.Req)
if err := s.localResolvers.Resolve(d); err != nil {
ctx.err = err
@ -472,8 +490,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
}
// processUpstream passes request to upstream servers and handles the response.
func processUpstream(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
if d.Res != nil {
return resultCodeSuccess // response is already set - nothing to do
@ -481,18 +498,18 @@ func processUpstream(ctx *dnsContext) (rc resultCode) {
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
clientIP := IPStringFromAddr(d.Addr)
upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP)
if upstreamsConf != nil {
log.Debug("Using custom upstreams for %s", clientIP)
d.CustomUpstreamConfig = upstreamsConf
if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil {
log.Debug("dns: using custom upstreams for client %s", clientIP)
d.CustomUpstreamConfig = upsConf
}
}
req := d.Req
if s.conf.EnableDNSSEC {
opt := d.Req.IsEdns0()
opt := req.IsEdns0()
if opt == nil {
log.Debug("dns: Adding OPT record with DNSSEC flag")
d.Req.SetEdns0(4096, true)
log.Debug("dns: adding OPT record with DNSSEC flag")
req.SetEdns0(4096, true)
} else if !opt.Do() {
opt.SetDo(true)
} else {
@ -501,13 +518,13 @@ func processUpstream(ctx *dnsContext) (rc resultCode) {
}
// request was not filtered so let it be processed further
err := s.dnsProxy.Resolve(d)
if err != nil {
ctx.err = err
s.recDetector.add(*req)
if ctx.err = s.dnsProxy.Resolve(d); ctx.err != nil {
return resultCodeError
}
ctx.responseFromUpstream = true
return resultCodeSuccess
}

View File

@ -76,6 +76,7 @@ type Server struct {
ipset ipsetCtx
subnetDetector *aghnet.SubnetDetector
localResolvers *proxy.Proxy
recDetector *recursionDetector
tableHostToIP hostToIPTable
tableHostToIPLock sync.Mutex
@ -121,6 +122,14 @@ func domainNameToSuffix(tld string) (suffix string) {
return string(b)
}
const (
// recursionTTL is the time recursive request is cached for.
recursionTTL = 5 * time.Second
// cachedRecurrentReqNum is the maximum number of cached recurrent
// requests.
cachedRecurrentReqNum = 1000
)
// NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once
func NewServer(p DNSCreateParams) (s *Server, err error) {
@ -142,6 +151,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
queryLog: p.QueryLog,
subnetDetector: p.SubnetDetector,
localDomainSuffix: localDomainSuffix,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
}
if p.DHCPServer != nil {
@ -160,7 +170,9 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// NewCustomServer creates a new instance of *Server with custom internal proxy.
func NewCustomServer(internalProxy *proxy.Proxy) *Server {
s := &Server{}
s := &Server{
recDetector: newRecursionDetector(0, 1),
}
if internalProxy != nil {
s.internalProxy = internalProxy
}
@ -278,14 +290,13 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
Req: req,
StartTime: time.Now(),
}
var resp *dns.Msg
err = resolver.Resolve(ctx)
if err != nil {
s.recDetector.add(*req)
if err = resolver.Resolve(ctx); err != nil {
return "", err
}
resp = ctx.Res
resp := ctx.Res
if len(resp.Answer) == 0 {
return "", fmt.Errorf("lookup for %q: %w", arpa, rDNSEmptyAnswerErr)
}
@ -490,6 +501,8 @@ func (s *Server) Prepare(config *ServerConfig) error {
return fmt.Errorf("setting up resolvers: %w", err)
}
s.recDetector.clear()
return nil
}

View File

@ -0,0 +1,115 @@
package dnsforward
import (
"bytes"
"encoding/binary"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// uint* sizes in bytes to improve readability.
//
// TODO(e.burkov): Remove when there will be a more regardful way to define
// those. See https://github.com/golang/go/issues/29982.
const (
uint16sz = 2
uint64sz = 8
)
// recursionDetector detects recursion in DNS forwarding.
type recursionDetector struct {
recentRequests cache.Cache
ttl time.Duration
}
// check checks if the passed req was already sent by s.
func (rd *recursionDetector) check(msg dns.Msg) (ok bool) {
if len(msg.Question) == 0 {
return false
}
key := msgToSignature(msg)
expireData := rd.recentRequests.Get(key)
if expireData == nil {
return false
}
expire := time.Unix(0, int64(binary.BigEndian.Uint64(expireData)))
return time.Now().Before(expire)
}
// add caches the msg if it has anything in the questions section.
func (rd *recursionDetector) add(msg dns.Msg) {
now := time.Now()
if len(msg.Question) == 0 {
return
}
key := msgToSignature(msg)
expire64 := uint64(now.Add(rd.ttl).UnixNano())
expire := make([]byte, uint64sz)
binary.BigEndian.PutUint64(expire, expire64)
rd.recentRequests.Set(key, expire)
}
// clear clears the recent requests cache.
func (rd *recursionDetector) clear() {
rd.recentRequests.Clear()
}
// newRecursionDetector returns the initialized *recursionDetector.
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
return &recursionDetector{
recentRequests: cache.New(cache.Config{
EnableLRU: true,
MaxCount: suspectsNum,
}),
ttl: ttl,
}
}
// msgToSignature converts msg into it's signature represented in bytes.
func msgToSignature(msg dns.Msg) (sig []byte) {
sig = make([]byte, uint16sz*2+aghnet.MaxDomainNameLen)
// The binary.BigEndian byte order is used everywhere except when the
// real machine's endianess is needed.
byteOrder := binary.BigEndian
byteOrder.PutUint16(sig[0:], msg.Id)
q := msg.Question[0]
byteOrder.PutUint16(sig[uint16sz:], q.Qtype)
copy(sig[2*uint16sz:], []byte(q.Name))
return sig
}
// msgToSignatureSlow converts msg into it's signature represented in bytes in
// the less efficient way.
//
// See BenchmarkMsgToSignature.
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
type msgSignature struct {
name [aghnet.MaxDomainNameLen]byte
id uint16
qtype uint16
}
b := bytes.NewBuffer(sig)
q := msg.Question[0]
signature := msgSignature{
id: msg.Id,
qtype: q.Qtype,
}
copy(signature.name[:], q.Name)
if err := binary.Write(b, binary.BigEndian, signature); err != nil {
log.Debug("writing message signature: %s", err)
}
return b.Bytes()
}

View File

@ -0,0 +1,154 @@
package dnsforward
import (
"encoding/binary"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func TestRecursionDetector_Check(t *testing.T) {
rd := newRecursionDetector(0, 2)
const (
recID = 1234
recTTL = time.Hour * 100
)
const nonRecID = recID * 2
sampleQuestion := dns.Question{
Name: "some.domain",
Qtype: dns.TypeAAAA,
}
sampleMsg := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: recID,
},
Question: []dns.Question{sampleQuestion},
}
// Manually add the message with big ttl.
key := msgToSignature(sampleMsg)
expire := make([]byte, uint64sz)
binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano()))
rd.recentRequests.Set(key, expire)
// Add an expired message.
sampleMsg.Id = nonRecID
rd.add(sampleMsg)
testCases := []struct {
name string
questions []dns.Question
id uint16
want bool
}{{
name: "recurrent",
questions: []dns.Question{sampleQuestion},
id: recID,
want: true,
}, {
name: "not_suspected",
questions: []dns.Question{sampleQuestion},
id: recID + 1,
want: false,
}, {
name: "expired",
questions: []dns.Question{sampleQuestion},
id: nonRecID,
want: false,
}, {
name: "empty",
questions: []dns.Question{},
id: nonRecID,
want: false,
}}
for _, tc := range testCases {
sampleMsg.Id = tc.id
sampleMsg.Question = tc.questions
t.Run(tc.name, func(t *testing.T) {
detected := rd.check(sampleMsg)
assert.Equal(t, tc.want, detected)
})
}
}
func TestRecursionDetector_Suspect(t *testing.T) {
rd := newRecursionDetector(0, 1)
testCases := []struct {
name string
msg dns.Msg
want bool
}{{
name: "simple",
msg: dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: "some.domain",
Qtype: dns.TypeA,
}},
},
want: true,
}, {
name: "unencumbered",
msg: dns.Msg{},
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(rd.clear)
rd.add(tc.msg)
if tc.want {
assert.Equal(t, 1, rd.recentRequests.Stats().Count)
} else {
assert.Zero(t, rd.recentRequests.Stats().Count)
}
})
}
}
var sink []byte
func BenchmarkMsgToSignature(b *testing.B) {
const name = "some.not.very.long.host.name"
msg := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: name,
Qtype: dns.TypeAAAA,
}},
}
b.Run("efficient", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sink = msgToSignature(msg)
}
assert.NotEmpty(b, sink)
})
b.Run("inefficient", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sink = msgToSignatureSlow(msg)
}
assert.NotEmpty(b, sink)
})
}