Pull request 2007: 6183-orig-resp

Closes #6183.

Squashed commit of the following:

commit a99b935d7a152f2cf2d003057cfb8e3c7c3579c5
Merge: 3534f663f 36517fc21
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Sep 8 17:46:51 2023 +0300

    Merge branch 'master' into 6183-orig-resp

commit 3534f663ff4aaacc4a1044b018802bd23cd8f7ec
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Sep 8 17:00:54 2023 +0300

    dnsforward: fix orig resp
This commit is contained in:
Ainar Garipov 2023-09-08 17:55:13 +03:00
parent 36517fc21b
commit 8b8ae8ffad
4 changed files with 50 additions and 42 deletions

View File

@ -25,12 +25,14 @@ NOTE: Add new changes BELOW THIS COMMENT.
### Fixed ### Fixed
- Incorrect original answer when a response is filtered ([#6183]).
- Comments in the *Fallback DNS servers* field in the UI ([#6182]). - Comments in the *Fallback DNS servers* field in the UI ([#6182]).
- Empty or default Safe Browsing and Parental Control settings ([#6181]). - Empty or default Safe Browsing and Parental Control settings ([#6181]).
- Various UI issues. - Various UI issues.
[#6181]: https://github.com/AdguardTeam/AdGuardHome/issues/6181 [#6181]: https://github.com/AdguardTeam/AdGuardHome/issues/6181
[#6182]: https://github.com/AdguardTeam/AdGuardHome/issues/6182 [#6182]: https://github.com/AdguardTeam/AdGuardHome/issues/6182
[#6183]: https://github.com/AdguardTeam/AdGuardHome/issues/6183
<!-- <!--
NOTE: Add new changes ABOVE THIS COMMENT. NOTE: Add new changes ABOVE THIS COMMENT.

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
@ -140,15 +141,15 @@ func (s *Server) filterRewritten(
// checkHostRules checks the host against filters. It is safe for concurrent // checkHostRules checks the host against filters. It is safe for concurrent
// use. // use.
func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) ( func (s *Server) checkHostRules(
r *filtering.Result, host string,
err error, rrtype rules.RRType,
) { setts *filtering.Settings,
) (r *filtering.Result, err error) {
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
var res filtering.Result res, err := s.dnsFilter.CheckHostRules(host, rrtype, setts)
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -156,20 +157,21 @@ func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Set
return &res, err return &res, err
} }
// filterDNSResponse checks each resource record of the response's answer // filterDNSResponse checks each resource record of answer section of
// section from pctx and returns a non-nil res if at least one of canonical // dctx.proxyCtx.Res. It sets dctx.result and dctx.origResp if at least one of
// names or IP addresses in it matches the filtering rules. // canonical names, IP addresses, or HTTPS RR hints in it matches the filtering
func (s *Server) filterDNSResponse( // rules, as well as sets dctx.proxyCtx.Res to the filtered response.
pctx *proxy.DNSContext, func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) {
setts *filtering.Settings, setts := dctx.setts
) (res *filtering.Result, err error) {
if !setts.FilteringEnabled { if !setts.FilteringEnabled {
return nil, nil return nil
} }
for _, a := range pctx.Res.Answer { var res *filtering.Result
pctx := dctx.proxyCtx
for i, a := range pctx.Res.Answer {
host := "" host := ""
var rrtype uint16 var rrtype rules.RRType
switch a := a.(type) { switch a := a.(type) {
case *dns.CNAME: case *dns.CNAME:
host = strings.TrimSuffix(a.Target, ".") host = strings.TrimSuffix(a.Target, ".")
@ -195,18 +197,19 @@ func (s *Server) filterDNSResponse(
log.Debug("dnsforward: checked %s %s for %s", dns.Type(rrtype), host, a.Header().Name) log.Debug("dnsforward: checked %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
if err != nil { if err != nil {
return nil, err return fmt.Errorf("filtering answer at index %d: %w", i, err)
} else if res == nil { } else if res != nil && res.IsFiltered {
continue dctx.result = res
} else if res.IsFiltered { dctx.origResp = pctx.Res
pctx.Res = s.genDNSFilterMessage(pctx, res) pctx.Res = s.genDNSFilterMessage(pctx, res)
log.Debug("dnsforward: matched %q by response: %q", pctx.Req.Question[0].Name, host) log.Debug("dnsforward: matched %q by response: %q", pctx.Req.Question[0].Name, host)
return res, nil break
} }
} }
return nil, nil return nil
} }
// removeIPv6Hints deletes IPv6 hints from RR values. // removeIPv6Hints deletes IPv6 hints from RR values.

View File

@ -328,26 +328,34 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1}, Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
} }
res, rErr := s.filterDNSResponse(pctx, &filtering.Settings{ dctx := &dnsContext{
ProtectionEnabled: true, proxyCtx: pctx,
FilteringEnabled: true, setts: &filtering.Settings{
}) ProtectionEnabled: true,
require.NoError(t, rErr) FilteringEnabled: true,
},
}
fltErr := s.filterDNSResponse(dctx)
require.NoError(t, fltErr)
res := dctx.result
if tc.wantRule == "" { if tc.wantRule == "" {
assert.Nil(t, res) assert.Nil(t, res)
return return
} }
want := &filtering.Result{ wantResult := &filtering.Result{
IsFiltered: true, IsFiltered: true,
Reason: filtering.FilteredBlockList, Reason: filtering.FilteredBlockList,
Rules: []*filtering.ResultRule{{ Rules: []*filtering.ResultRule{{
Text: tc.wantRule, Text: tc.wantRule,
}}, }},
} }
assert.Equal(t, want, res)
assert.Equal(t, wantResult, res)
assert.Equal(t, resp, dctx.origResp)
}) })
} }
} }

View File

@ -671,11 +671,11 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
} }
// Apply filtering logic // Apply filtering logic
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing filtering before req") log.Debug("dnsforward: started processing filtering before req")
defer log.Debug("dnsforward: finished processing filtering before req") defer log.Debug("dnsforward: finished processing filtering before req")
if ctx.proxyCtx.Res != nil { if dctx.proxyCtx.Res != nil {
// Go on since the response is already set. // Go on since the response is already set.
return resultCodeSuccess return resultCodeSuccess
} }
@ -684,8 +684,8 @@ func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode)
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
var err error var err error
if ctx.result, err = s.filterDNSRequest(ctx); err != nil { if dctx.result, err = s.filterDNSRequest(dctx); err != nil {
ctx.err = err dctx.err = err
return resultCodeError return resultCodeError
} }
@ -857,7 +857,6 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
log.Debug("dnsforward: started processing filtering after resp") log.Debug("dnsforward: started processing filtering after resp")
defer log.Debug("dnsforward: finished processing filtering after resp") defer log.Debug("dnsforward: finished processing filtering after resp")
pctx := dctx.proxyCtx
switch res := dctx.result; res.Reason { switch res := dctx.result; res.Reason {
case filtering.NotFilteredAllowList: case filtering.NotFilteredAllowList:
return resultCodeSuccess return resultCodeSuccess
@ -871,6 +870,7 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
return resultCodeSuccess return resultCodeSuccess
} }
pctx := dctx.proxyCtx
pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion
if len(pctx.Res.Answer) > 0 { if len(pctx.Res.Answer) > 0 {
rr := s.genAnswerCNAME(pctx.Req, res.CanonName) rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
@ -880,13 +880,13 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
return resultCodeSuccess return resultCodeSuccess
default: default:
return s.filterAfterResponse(dctx, pctx) return s.filterAfterResponse(dctx)
} }
} }
// filterAfterResponse returns the result of filtering the response that wasn't // filterAfterResponse returns the result of filtering the response that wasn't
// explicitly allowed or rewritten. // explicitly allowed or rewritten.
func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (res resultCode) { func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) {
// Check the response only if it's from an upstream. Don't check the // Check the response only if it's from an upstream. Don't check the
// response if the protection is disabled since dnsrewrite rules aren't // response if the protection is disabled since dnsrewrite rules aren't
// applied to it anyway. // applied to it anyway.
@ -894,17 +894,12 @@ func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (
return resultCodeSuccess return resultCodeSuccess
} }
result, err := s.filterDNSResponse(pctx, dctx.setts) err := s.filterDNSResponse(dctx)
if err != nil { if err != nil {
dctx.err = err dctx.err = err
return resultCodeError return resultCodeError
} }
if result != nil {
dctx.result = result
dctx.origResp = pctx.Res
}
return resultCodeSuccess return resultCodeSuccess
} }