mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-12-15 11:22:49 +03:00
Pull request: 3185 detecting recursion
Merge in DNS/adguard-home from 3185-recursion to master
Closes #3185.
Squashed commit of the following:
commit 2fa44223f533c471f2b8c0e17d8550bf4ff73c7b
Merge: 7975957c 7a48e92e
Author: Eugene Burkov <e.burkov@adguard.com>
Date: Thu May 27 19:04:44 2021 +0300
Merge branch 'master' into 3185-recursion
commit 7975957cceb840f76eef0e2e434f4163a122ac34
Author: Eugene Burkov <e.burkov@adguard.com>
Date: Thu May 27 17:36:22 2021 +0300
dnsforward: imp docs
commit 1af7131a5b7c1fefed2d1eb8ee24ebfd3602dc77
Author: Eugene Burkov <e.burkov@adguard.com>
Date: Thu May 27 17:15:00 2021 +0300
dnsforward: imp code, tests, docs
commit f3f9145fb5e1174fab87ca6890da9df722cfebf0
Author: Eugene Burkov <e.burkov@adguard.com>
Date: Thu May 27 15:45:44 2021 +0300
dnsforward: add recursion detector
This commit is contained in:
parent
7a48e92e4d
commit
48b8579703
@ -15,6 +15,7 @@ and this project adheres to
|
||||
|
||||
### Added
|
||||
|
||||
- Detection and handling of recurrent requests ([#3185]).
|
||||
- The ability to completely disable reverse DNS resolving of IPs from
|
||||
locally-served networks ([#3184]).
|
||||
- New flag `--local-frontend` to serve dinamically changeable frontend files
|
||||
@ -38,6 +39,7 @@ released by then.
|
||||
- Go 1.15 support.
|
||||
|
||||
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
|
||||
[#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185
|
||||
|
||||
|
||||
|
||||
|
@ -42,11 +42,11 @@ func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) {
|
||||
// according to RFC 1035.
|
||||
const maxDomainLabelLen = 63
|
||||
|
||||
// maxDomainNameLen is the maximum allowed length of a full domain name
|
||||
// MaxDomainNameLen is the maximum allowed length of a full domain name
|
||||
// according to RFC 1035.
|
||||
//
|
||||
// See https://stackoverflow.com/a/32294443/1892060.
|
||||
const maxDomainNameLen = 253
|
||||
const MaxDomainNameLen = 253
|
||||
|
||||
// ValidateDomainNameLabel returns an error if label is not a valid label of
|
||||
// a domain name.
|
||||
@ -97,8 +97,8 @@ func ValidateDomainName(name string) (err error) {
|
||||
l := len(name)
|
||||
if l == 0 {
|
||||
return errors.Error("domain name is empty")
|
||||
} else if l > maxDomainNameLen {
|
||||
return fmt.Errorf("too long, max: %d", maxDomainNameLen)
|
||||
} else if l > MaxDomainNameLen {
|
||||
return fmt.Errorf("too long, max: %d", MaxDomainNameLen)
|
||||
}
|
||||
|
||||
labels := strings.Split(name, ".")
|
||||
|
@ -82,6 +82,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
// (*proxy.Proxy).handleDNSRequest method performs it before calling the
|
||||
// appropriate handler.
|
||||
mods := []modProcessFunc{
|
||||
s.processRecursion,
|
||||
processInitial,
|
||||
s.processDetermineLocal,
|
||||
s.processInternalHosts,
|
||||
@ -90,7 +91,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
processClientID,
|
||||
processFilteringBeforeRequest,
|
||||
s.processLocalPTR,
|
||||
processUpstream,
|
||||
s.processUpstream,
|
||||
processDNSSECAfterResponse,
|
||||
processFilteringAfterResponse,
|
||||
s.ipset.process,
|
||||
@ -116,6 +117,22 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// processRecursion checks the incoming request and halts it's handling if s
|
||||
// have tried to resolve it recently.
|
||||
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
|
||||
log.Debug("recursion detected resolving %q", msg.Question[0].Name)
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Perform initial checks; process WHOIS & rDNS
|
||||
func processInitial(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
@ -422,6 +439,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
if s.conf.UsePrivateRDNS {
|
||||
s.recDetector.add(*d.Req)
|
||||
if err := s.localResolvers.Resolve(d); err != nil {
|
||||
ctx.err = err
|
||||
|
||||
@ -472,8 +490,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
// processUpstream passes request to upstream servers and handles the response.
|
||||
func processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
if d.Res != nil {
|
||||
return resultCodeSuccess // response is already set - nothing to do
|
||||
@ -481,18 +498,18 @@ func processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||
|
||||
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
||||
clientIP := IPStringFromAddr(d.Addr)
|
||||
upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP)
|
||||
if upstreamsConf != nil {
|
||||
log.Debug("Using custom upstreams for %s", clientIP)
|
||||
d.CustomUpstreamConfig = upstreamsConf
|
||||
if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil {
|
||||
log.Debug("dns: using custom upstreams for client %s", clientIP)
|
||||
d.CustomUpstreamConfig = upsConf
|
||||
}
|
||||
}
|
||||
|
||||
req := d.Req
|
||||
if s.conf.EnableDNSSEC {
|
||||
opt := d.Req.IsEdns0()
|
||||
opt := req.IsEdns0()
|
||||
if opt == nil {
|
||||
log.Debug("dns: Adding OPT record with DNSSEC flag")
|
||||
d.Req.SetEdns0(4096, true)
|
||||
log.Debug("dns: adding OPT record with DNSSEC flag")
|
||||
req.SetEdns0(4096, true)
|
||||
} else if !opt.Do() {
|
||||
opt.SetDo(true)
|
||||
} else {
|
||||
@ -501,13 +518,13 @@ func processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
// request was not filtered so let it be processed further
|
||||
err := s.dnsProxy.Resolve(d)
|
||||
if err != nil {
|
||||
ctx.err = err
|
||||
s.recDetector.add(*req)
|
||||
if ctx.err = s.dnsProxy.Resolve(d); ctx.err != nil {
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
ctx.responseFromUpstream = true
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
|
@ -76,6 +76,7 @@ type Server struct {
|
||||
ipset ipsetCtx
|
||||
subnetDetector *aghnet.SubnetDetector
|
||||
localResolvers *proxy.Proxy
|
||||
recDetector *recursionDetector
|
||||
|
||||
tableHostToIP hostToIPTable
|
||||
tableHostToIPLock sync.Mutex
|
||||
@ -121,6 +122,14 @@ func domainNameToSuffix(tld string) (suffix string) {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
const (
|
||||
// recursionTTL is the time recursive request is cached for.
|
||||
recursionTTL = 5 * time.Second
|
||||
// cachedRecurrentReqNum is the maximum number of cached recurrent
|
||||
// requests.
|
||||
cachedRecurrentReqNum = 1000
|
||||
)
|
||||
|
||||
// NewServer creates a new instance of the dnsforward.Server
|
||||
// Note: this function must be called only once
|
||||
func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
@ -142,6 +151,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
queryLog: p.QueryLog,
|
||||
subnetDetector: p.SubnetDetector,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
}
|
||||
|
||||
if p.DHCPServer != nil {
|
||||
@ -160,7 +170,9 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
|
||||
// NewCustomServer creates a new instance of *Server with custom internal proxy.
|
||||
func NewCustomServer(internalProxy *proxy.Proxy) *Server {
|
||||
s := &Server{}
|
||||
s := &Server{
|
||||
recDetector: newRecursionDetector(0, 1),
|
||||
}
|
||||
if internalProxy != nil {
|
||||
s.internalProxy = internalProxy
|
||||
}
|
||||
@ -278,14 +290,13 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
||||
Req: req,
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
var resp *dns.Msg
|
||||
err = resolver.Resolve(ctx)
|
||||
if err != nil {
|
||||
|
||||
s.recDetector.add(*req)
|
||||
if err = resolver.Resolve(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp = ctx.Res
|
||||
|
||||
resp := ctx.Res
|
||||
if len(resp.Answer) == 0 {
|
||||
return "", fmt.Errorf("lookup for %q: %w", arpa, rDNSEmptyAnswerErr)
|
||||
}
|
||||
@ -490,6 +501,8 @@ func (s *Server) Prepare(config *ServerConfig) error {
|
||||
return fmt.Errorf("setting up resolvers: %w", err)
|
||||
}
|
||||
|
||||
s.recDetector.clear()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
115
internal/dnsforward/recursiondetector.go
Normal file
115
internal/dnsforward/recursiondetector.go
Normal file
@ -0,0 +1,115 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// uint* sizes in bytes to improve readability.
|
||||
//
|
||||
// TODO(e.burkov): Remove when there will be a more regardful way to define
|
||||
// those. See https://github.com/golang/go/issues/29982.
|
||||
const (
|
||||
uint16sz = 2
|
||||
uint64sz = 8
|
||||
)
|
||||
|
||||
// recursionDetector detects recursion in DNS forwarding.
|
||||
type recursionDetector struct {
|
||||
recentRequests cache.Cache
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// check checks if the passed req was already sent by s.
|
||||
func (rd *recursionDetector) check(msg dns.Msg) (ok bool) {
|
||||
if len(msg.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
key := msgToSignature(msg)
|
||||
expireData := rd.recentRequests.Get(key)
|
||||
if expireData == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
expire := time.Unix(0, int64(binary.BigEndian.Uint64(expireData)))
|
||||
|
||||
return time.Now().Before(expire)
|
||||
}
|
||||
|
||||
// add caches the msg if it has anything in the questions section.
|
||||
func (rd *recursionDetector) add(msg dns.Msg) {
|
||||
now := time.Now()
|
||||
|
||||
if len(msg.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
key := msgToSignature(msg)
|
||||
expire64 := uint64(now.Add(rd.ttl).UnixNano())
|
||||
expire := make([]byte, uint64sz)
|
||||
binary.BigEndian.PutUint64(expire, expire64)
|
||||
|
||||
rd.recentRequests.Set(key, expire)
|
||||
}
|
||||
|
||||
// clear clears the recent requests cache.
|
||||
func (rd *recursionDetector) clear() {
|
||||
rd.recentRequests.Clear()
|
||||
}
|
||||
|
||||
// newRecursionDetector returns the initialized *recursionDetector.
|
||||
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
|
||||
return &recursionDetector{
|
||||
recentRequests: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: suspectsNum,
|
||||
}),
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// msgToSignature converts msg into it's signature represented in bytes.
|
||||
func msgToSignature(msg dns.Msg) (sig []byte) {
|
||||
sig = make([]byte, uint16sz*2+aghnet.MaxDomainNameLen)
|
||||
// The binary.BigEndian byte order is used everywhere except when the
|
||||
// real machine's endianess is needed.
|
||||
byteOrder := binary.BigEndian
|
||||
byteOrder.PutUint16(sig[0:], msg.Id)
|
||||
q := msg.Question[0]
|
||||
byteOrder.PutUint16(sig[uint16sz:], q.Qtype)
|
||||
copy(sig[2*uint16sz:], []byte(q.Name))
|
||||
|
||||
return sig
|
||||
}
|
||||
|
||||
// msgToSignatureSlow converts msg into it's signature represented in bytes in
|
||||
// the less efficient way.
|
||||
//
|
||||
// See BenchmarkMsgToSignature.
|
||||
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
|
||||
type msgSignature struct {
|
||||
name [aghnet.MaxDomainNameLen]byte
|
||||
id uint16
|
||||
qtype uint16
|
||||
}
|
||||
|
||||
b := bytes.NewBuffer(sig)
|
||||
q := msg.Question[0]
|
||||
signature := msgSignature{
|
||||
id: msg.Id,
|
||||
qtype: q.Qtype,
|
||||
}
|
||||
copy(signature.name[:], q.Name)
|
||||
if err := binary.Write(b, binary.BigEndian, signature); err != nil {
|
||||
log.Debug("writing message signature: %s", err)
|
||||
}
|
||||
|
||||
return b.Bytes()
|
||||
}
|
154
internal/dnsforward/recursiondetector_test.go
Normal file
154
internal/dnsforward/recursiondetector_test.go
Normal file
@ -0,0 +1,154 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRecursionDetector_Check(t *testing.T) {
|
||||
rd := newRecursionDetector(0, 2)
|
||||
|
||||
const (
|
||||
recID = 1234
|
||||
recTTL = time.Hour * 100
|
||||
)
|
||||
|
||||
const nonRecID = recID * 2
|
||||
|
||||
sampleQuestion := dns.Question{
|
||||
Name: "some.domain",
|
||||
Qtype: dns.TypeAAAA,
|
||||
}
|
||||
sampleMsg := dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: recID,
|
||||
},
|
||||
Question: []dns.Question{sampleQuestion},
|
||||
}
|
||||
|
||||
// Manually add the message with big ttl.
|
||||
key := msgToSignature(sampleMsg)
|
||||
expire := make([]byte, uint64sz)
|
||||
binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano()))
|
||||
rd.recentRequests.Set(key, expire)
|
||||
|
||||
// Add an expired message.
|
||||
sampleMsg.Id = nonRecID
|
||||
rd.add(sampleMsg)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
questions []dns.Question
|
||||
id uint16
|
||||
want bool
|
||||
}{{
|
||||
name: "recurrent",
|
||||
questions: []dns.Question{sampleQuestion},
|
||||
id: recID,
|
||||
want: true,
|
||||
}, {
|
||||
name: "not_suspected",
|
||||
questions: []dns.Question{sampleQuestion},
|
||||
id: recID + 1,
|
||||
want: false,
|
||||
}, {
|
||||
name: "expired",
|
||||
questions: []dns.Question{sampleQuestion},
|
||||
id: nonRecID,
|
||||
want: false,
|
||||
}, {
|
||||
name: "empty",
|
||||
questions: []dns.Question{},
|
||||
id: nonRecID,
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
sampleMsg.Id = tc.id
|
||||
sampleMsg.Question = tc.questions
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
detected := rd.check(sampleMsg)
|
||||
assert.Equal(t, tc.want, detected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecursionDetector_Suspect(t *testing.T) {
|
||||
rd := newRecursionDetector(0, 1)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg dns.Msg
|
||||
want bool
|
||||
}{{
|
||||
name: "simple",
|
||||
msg: dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: 1234,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "some.domain",
|
||||
Qtype: dns.TypeA,
|
||||
}},
|
||||
},
|
||||
want: true,
|
||||
}, {
|
||||
name: "unencumbered",
|
||||
msg: dns.Msg{},
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(rd.clear)
|
||||
|
||||
rd.add(tc.msg)
|
||||
|
||||
if tc.want {
|
||||
assert.Equal(t, 1, rd.recentRequests.Stats().Count)
|
||||
} else {
|
||||
assert.Zero(t, rd.recentRequests.Stats().Count)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var sink []byte
|
||||
|
||||
func BenchmarkMsgToSignature(b *testing.B) {
|
||||
const name = "some.not.very.long.host.name"
|
||||
|
||||
msg := dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: 1234,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: name,
|
||||
Qtype: dns.TypeAAAA,
|
||||
}},
|
||||
}
|
||||
|
||||
b.Run("efficient", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
sink = msgToSignature(msg)
|
||||
}
|
||||
|
||||
assert.NotEmpty(b, sink)
|
||||
})
|
||||
|
||||
b.Run("inefficient", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
sink = msgToSignatureSlow(msg)
|
||||
}
|
||||
|
||||
assert.NotEmpty(b, sink)
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue
Block a user