Pull request: 4074 fix upstream test

Merge in DNS/adguard-home from 4074-upstream-test to master

Updates #4074.

Squashed commit of the following:

commit 0de155b1e175a892b259791ff6d6e6f351bcfcf2
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Jan 12 19:20:01 2022 +0500

    dnsforward: fix upstream test
This commit is contained in:
Eugene Burkov 2022-01-13 15:05:44 +03:00
parent 1458600c37
commit 0e4ffd339f
4 changed files with 126 additions and 136 deletions

View File

@ -29,6 +29,7 @@ TODO(a.garipov): Remove this deprecation, if v0.108.0 is released before the Go
### Fixed ### Fixed
- Poor testing of domain-specific upstream servers ([#4074]).
- Omitted aliases of hosts specified by another line within the OS's hosts file - Omitted aliases of hosts specified by another line within the OS's hosts file
([#4079]). ([#4079]).
@ -37,6 +38,8 @@ TODO(a.garipov): Remove this deprecation, if v0.108.0 is released before the Go
- Go 1.16 support. - Go 1.16 support.
[#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057 [#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057
[#4074]: https://github.com/AdguardTeam/AdGuardHome/issues/4074
[#4079]: https://github.com/AdguardTeam/AdGuardHome/issues/4079
@ -82,7 +85,6 @@ TODO(a.garipov): Remove this deprecation, if v0.108.0 is released before the Go
[#4008]: https://github.com/AdguardTeam/AdGuardHome/issues/4008 [#4008]: https://github.com/AdguardTeam/AdGuardHome/issues/4008
[#4016]: https://github.com/AdguardTeam/AdGuardHome/issues/4016 [#4016]: https://github.com/AdguardTeam/AdGuardHome/issues/4016
[#4027]: https://github.com/AdguardTeam/AdGuardHome/issues/4027 [#4027]: https://github.com/AdguardTeam/AdGuardHome/issues/4027
[#4079]: https://github.com/AdguardTeam/AdGuardHome/issues/4079

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
@ -192,22 +191,23 @@ func (req *dnsConfig) checkCacheTTL() bool {
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
req := dnsConfig{} req := dnsConfig{}
dec := json.NewDecoder(r.Body) err := json.NewDecoder(r.Body).Decode(&req)
if err := dec.Decode(&req); err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json Encode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json Encode: %s", err)
return return
} }
if req.Upstreams != nil { if req.Upstreams != nil {
if err := ValidateUpstreams(*req.Upstreams); err != nil { if err = ValidateUpstreams(*req.Upstreams); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err)
return return
} }
} }
if errBoot, err := req.checkBootstrap(); err != nil { var errBoot string
if errBoot, err = req.checkBootstrap(); err != nil {
aghhttp.Error( aghhttp.Error(
r, r,
w, w,
@ -220,19 +220,16 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
if !req.checkBlockingMode() { switch {
case !req.checkBlockingMode():
aghhttp.Error(r, w, http.StatusBadRequest, "blocking_mode: incorrect value") aghhttp.Error(r, w, http.StatusBadRequest, "blocking_mode: incorrect value")
return return
} case !req.checkUpstreamsMode():
if !req.checkUpstreamsMode() {
aghhttp.Error(r, w, http.StatusBadRequest, "upstream_mode: incorrect value") aghhttp.Error(r, w, http.StatusBadRequest, "upstream_mode: incorrect value")
return return
} case !req.checkCacheTTL():
if !req.checkCacheTTL() {
aghhttp.Error( aghhttp.Error(
r, r,
w, w,
@ -241,13 +238,15 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
) )
return return
default:
// Go on.
} }
restart := s.setConfig(req) restart := s.setConfig(req)
s.conf.ConfigModified() s.conf.ConfigModified()
if restart { if restart {
if err := s.Reconfigure(nil); err != nil { if err = s.Reconfigure(nil); err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
} }
} }
@ -387,14 +386,14 @@ func ValidateUpstreams(upstreams []string) (err error) {
var defaultUpstreamFound bool var defaultUpstreamFound bool
for _, u := range upstreams { for _, u := range upstreams {
var ok bool var useDefault bool
ok, err = validateUpstream(u) useDefault, err = validateUpstream(u)
if err != nil { if err != nil {
return err return err
} }
if !defaultUpstreamFound { if !defaultUpstreamFound {
defaultUpstreamFound = ok defaultUpstreamFound = useDefault
} }
} }
@ -407,50 +406,62 @@ func ValidateUpstreams(upstreams []string) (err error) {
var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"} var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"}
func validateUpstream(u string) (bool, error) { func validateUpstream(u string) (useDefault bool, err error) {
// Check if the user tries to specify upstream for domain. // Check if the user tries to specify upstream for domain.
u, useDefault, err := separateUpstream(u) var isDomainSpec bool
u, isDomainSpec, err = separateUpstream(u)
if err != nil { if err != nil {
return useDefault, err return !isDomainSpec, err
} }
// The special server address '#' means "use the default servers" // The special server address '#' means that default server must be used.
if u == "#" && !useDefault { if useDefault = !isDomainSpec; u == "#" && isDomainSpec {
return useDefault, nil return useDefault, nil
} }
// Check if the upstream has a valid protocol prefix // Check if the upstream has a valid protocol prefix.
//
// TODO(e.burkov): Validate the domain name.
for _, proto := range protocols { for _, proto := range protocols {
if strings.HasPrefix(u, proto) { if strings.HasPrefix(u, proto) {
return useDefault, nil return useDefault, nil
} }
} }
// Return error if the upstream contains '://' without any valid protocol
if strings.Contains(u, "://") { if strings.Contains(u, "://") {
return useDefault, fmt.Errorf("wrong protocol") return useDefault, errors.Error("wrong protocol")
} }
// Check if upstream is valid plain DNS // Check if upstream is either an IP or IP with port.
return useDefault, checkPlainDNS(u) if net.ParseIP(u) != nil {
return useDefault, nil
} else if _, err = netutil.ParseIPPort(u); err != nil {
return useDefault, err
}
return useDefault, nil
} }
// separateUpstream returns the upstream without the specified domains. // separateUpstream returns the upstream without the specified domains.
// useDefault is true when a default upstream must be used. // isDomainSpec is true when the upstream is domains-specific.
func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) { func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, err error) {
defer func() { err = errors.Annotate(err, "bad upstream for domain spec %q: %w", upstreamStr) }()
if !strings.HasPrefix(upstreamStr, "[/") { if !strings.HasPrefix(upstreamStr, "[/") {
return upstreamStr, true, nil return upstreamStr, false, nil
} }
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
parts := strings.Split(upstreamStr[2:], "/]") parts := strings.Split(upstreamStr[2:], "/]")
if len(parts) != 2 { switch len(parts) {
return "", false, errors.Error("duplicated separator") case 2:
// Go on.
case 1:
return "", false, errors.Error("missing separator")
default:
return "", true, errors.Error("duplicated separator")
} }
domains := parts[0] var domains string
upstream = parts[1] domains, upstream = parts[0], parts[1]
for i, host := range strings.Split(domains, "/") { for i, host := range strings.Split(domains, "/") {
if host == "" { if host == "" {
continue continue
@ -458,36 +469,11 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err
err = netutil.ValidateDomainName(host) err = netutil.ValidateDomainName(host)
if err != nil { if err != nil {
return "", false, fmt.Errorf("domain at index %d: %w", i, err) return "", true, fmt.Errorf("domain at index %d: %w", i, err)
} }
} }
return upstream, false, nil return upstream, true, nil
}
// checkPlainDNS checks if host is plain DNS
func checkPlainDNS(upstream string) error {
// Check if host is ip without port
if net.ParseIP(upstream) != nil {
return nil
}
// Check if host is ip with port
ip, port, err := net.SplitHostPort(upstream)
if err != nil {
return err
}
if net.ParseIP(ip) == nil {
return fmt.Errorf("%s is not a valid IP", ip)
}
_, err = strconv.ParseInt(port, 0, 64)
if err != nil {
return fmt.Errorf("%s is not a valid port: %w", port, err)
}
return nil
} }
// excFunc is a signature of function to check if upstream exchanges correctly. // excFunc is a signature of function to check if upstream exchanges correctly.
@ -515,12 +501,8 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
if len(reply.Answer) != 1 { if len(reply.Answer) != 1 {
return fmt.Errorf("wrong response") return fmt.Errorf("wrong response")
} } else if a, ok := reply.Answer[0].(*dns.A); !ok || !a.A.Equal(net.IP{8, 8, 8, 8}) {
return fmt.Errorf("wrong response")
if t, ok := reply.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
return fmt.Errorf("wrong response")
}
} }
return nil return nil
@ -555,7 +537,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
// Separate upstream from domains list. // Separate upstream from domains list.
var useDefault bool var useDefault bool
if input, useDefault, err = separateUpstream(input); err != nil { if useDefault, err = validateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %w", err) return fmt.Errorf("wrong upstream format: %w", err)
} }
@ -564,7 +546,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
return nil return nil
} }
if _, err = validateUpstream(input); err != nil { if input, _, err = separateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %w", err) return fmt.Errorf("wrong upstream format: %w", err)
} }
@ -572,7 +554,8 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
bootstrap = defaultBootstrap bootstrap = defaultBootstrap
} }
log.Debug("checking if dns server %q works...", input) log.Debug("checking if upstream %s works", input)
var u upstream.Upstream var u upstream.Upstream
u, err = upstream.AddressToUpstream(input, &upstream.Options{ u, err = upstream.AddressToUpstream(input, &upstream.Options{
Bootstrap: bootstrap, Bootstrap: bootstrap,
@ -586,7 +569,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
return fmt.Errorf("upstream %q fails to exchange: %w", input, err) return fmt.Errorf("upstream %q fails to exchange: %w", input, err)
} }
log.Debug("dns %s works OK", input) log.Debug("upstream %s is ok", input)
return nil return nil
} }
@ -620,9 +603,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
err = checkDNS(host, bootstraps, timeout, checkPrivateUpstreamExc) err = checkDNS(host, bootstraps, timeout, checkPrivateUpstreamExc)
if err != nil { if err != nil {
log.Info("%v", err) log.Info("%v", err)
// TODO(e.burkov): If passed upstream have already // TODO(e.burkov): If passed upstream have already written an error
// written an error above, we rewriting the error for // above, we rewriting the error for it. These cases should be
// it. These cases should be handled properly instead. // handled properly instead.
result[host] = err.Error() result[host] = err.Error()
continue continue

View File

@ -184,7 +184,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "", wantSet: "",
}, { }, {
name: "upstream_dns_bad", name: "upstream_dns_bad",
wantSet: `wrong upstreams specification: address !!!: ` + wantSet: `wrong upstreams specification: bad ipport address "!!!": address !!!: ` +
`missing port in address`, `missing port in address`,
}, { }, {
name: "bootstraps_bad", name: "bootstraps_bad",
@ -235,107 +235,117 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
} }
func TestIsCommentOrEmpty(t *testing.T) { func TestIsCommentOrEmpty(t *testing.T) {
assert.True(t, IsCommentOrEmpty("")) for _, tc := range []struct {
assert.True(t, IsCommentOrEmpty("# comment")) want assert.BoolAssertionFunc
assert.False(t, IsCommentOrEmpty("1.2.3.4")) str string
}{{
want: assert.True,
str: "",
}, {
want: assert.True,
str: "# comment",
}, {
want: assert.False,
str: "1.2.3.4",
}} {
tc.want(t, IsCommentOrEmpty(tc.str))
}
} }
// TODO(a.garipov): Rewrite to check the actual error messages.
func TestValidateUpstream(t *testing.T) { func TestValidateUpstream(t *testing.T) {
testCases := []struct { testCases := []struct {
wantDef assert.BoolAssertionFunc
name string name string
upstream string upstream string
valid bool wantErr string
wantDef bool
}{{ }{{
wantDef: assert.True,
name: "invalid", name: "invalid",
upstream: "1.2.3.4.5", upstream: "1.2.3.4.5",
valid: false, wantErr: `bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
wantDef: false,
}, { }, {
wantDef: assert.True,
name: "invalid", name: "invalid",
upstream: "123.3.7m", upstream: "123.3.7m",
valid: false, wantErr: `bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
wantDef: false,
}, { }, {
wantDef: assert.True,
name: "invalid", name: "invalid",
upstream: "htttps://google.com/dns-query", upstream: "htttps://google.com/dns-query",
valid: false, wantErr: `wrong protocol`,
wantDef: false,
}, { }, {
wantDef: assert.True,
name: "invalid", name: "invalid",
upstream: "[/host.com]tls://dns.adguard.com", upstream: "[/host.com]tls://dns.adguard.com",
valid: false, wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
wantDef: false,
}, { }, {
wantDef: assert.True,
name: "invalid", name: "invalid",
upstream: "[host.ru]#", upstream: "[host.ru]#",
valid: false, wantErr: `bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
wantDef: false,
}, { }, {
wantDef: assert.True,
name: "valid_default", name: "valid_default",
upstream: "1.1.1.1", upstream: "1.1.1.1",
valid: true, wantErr: ``,
wantDef: true,
}, { }, {
wantDef: assert.True,
name: "valid_default", name: "valid_default",
upstream: "tls://1.1.1.1", upstream: "tls://1.1.1.1",
valid: true, wantErr: ``,
wantDef: true,
}, { }, {
wantDef: assert.True,
name: "valid_default", name: "valid_default",
upstream: "https://dns.adguard.com/dns-query", upstream: "https://dns.adguard.com/dns-query",
valid: true, wantErr: ``,
wantDef: true,
}, { }, {
wantDef: assert.True,
name: "valid_default", name: "valid_default",
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
valid: true, wantErr: ``,
wantDef: true,
}, { }, {
wantDef: assert.False,
name: "valid", name: "valid",
upstream: "[/host.com/]1.1.1.1", upstream: "[/host.com/]1.1.1.1",
valid: true, wantErr: ``,
wantDef: false,
}, { }, {
wantDef: assert.False,
name: "valid", name: "valid",
upstream: "[//]tls://1.1.1.1", upstream: "[//]tls://1.1.1.1",
valid: true, wantErr: ``,
wantDef: false,
}, { }, {
wantDef: assert.False,
name: "valid", name: "valid",
upstream: "[/www.host.com/]#", upstream: "[/www.host.com/]#",
valid: true, wantErr: ``,
wantDef: false,
}, { }, {
wantDef: assert.False,
name: "valid", name: "valid",
upstream: "[/host.com/google.com/]8.8.8.8", upstream: "[/host.com/google.com/]8.8.8.8",
valid: true, wantErr: ``,
wantDef: false,
}, { }, {
wantDef: assert.False,
name: "valid", name: "valid",
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
valid: true, wantErr: ``,
wantDef: false,
}, { }, {
wantDef: assert.False,
name: "idna", name: "idna",
upstream: "[/пример.рф/]8.8.8.8", upstream: "[/пример.рф/]8.8.8.8",
valid: true, wantErr: ``,
wantDef: false,
}, { }, {
wantDef: assert.False,
name: "bad_domain", name: "bad_domain",
upstream: "[/!/]8.8.8.8", upstream: "[/!/]8.8.8.8",
valid: false, wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
wantDef: false, `bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
defaultUpstream, err := validateUpstream(tc.upstream) defaultUpstream, err := validateUpstream(tc.upstream)
require.Equal(t, tc.valid, err == nil) testutil.AssertErrorMsg(t, tc.wantErr, err)
if tc.valid { tc.wantDef(t, defaultUpstream)
assert.Equal(t, tc.wantDef, defaultUpstream)
}
}) })
} }
} }
@ -343,22 +353,19 @@ func TestValidateUpstream(t *testing.T) {
func TestValidateUpstreamsSet(t *testing.T) { func TestValidateUpstreamsSet(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
msg string wantErr string
set []string set []string
wantNil bool
}{{ }{{
name: "empty", name: "empty",
msg: "empty upstreams array should be valid", wantErr: ``,
set: nil, set: nil,
wantNil: true,
}, { }, {
name: "comment", name: "comment",
msg: "comments should not be validated", wantErr: ``,
set: []string{"# comment"}, set: []string{"# comment"},
wantNil: true,
}, { }, {
name: "valid_no_default", name: "valid_no_default",
msg: "there is no default upstream", wantErr: `no default upstreams specified`,
set: []string{ set: []string{
"[/host.com/]1.1.1.1", "[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1", "[//]tls://1.1.1.1",
@ -366,10 +373,9 @@ func TestValidateUpstreamsSet(t *testing.T) {
"[/host.com/google.com/]8.8.8.8", "[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
}, },
wantNil: false,
}, { }, {
name: "valid_with_default", name: "valid_with_default",
msg: "upstreams set is valid, but doesn't pass through validation cause: %s", wantErr: ``,
set: []string{ set: []string{
"[/host.com/]1.1.1.1", "[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1", "[//]tls://1.1.1.1",
@ -378,19 +384,16 @@ func TestValidateUpstreamsSet(t *testing.T) {
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
"8.8.8.8", "8.8.8.8",
}, },
wantNil: true,
}, { }, {
name: "invalid", name: "invalid",
msg: "there is an invalid upstream in set, but it pass through validation", wantErr: `cannot prepare the upstream dhcp://fake.dns ([]): unsupported URL scheme: dhcp`,
set: []string{"dhcp://fake.dns"}, set: []string{"dhcp://fake.dns"},
wantNil: false,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err := ValidateUpstreams(tc.set) err := ValidateUpstreams(tc.set)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err)
}) })
} }
} }

View File

@ -123,7 +123,9 @@
'8.8.8.8': 'OK' '8.8.8.8': 'OK'
'8.8.4.4': 'OK' '8.8.4.4': 'OK'
'192.168.1.104:53535': > '192.168.1.104:53535': >
Couldn't communicate with DNS server upstream "192.168.1.104:1234" fails to exchange: couldn't
communicate with upstream: read udp
192.168.1.100:60675->8.8.8.8:1234: i/o timeout
'/version.json': '/version.json':
'post': 'post':
'tags': 'tags':