From a7f9e0122be964d78af3fd6d0e9737b173bcc628 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Thu, 25 Mar 2021 16:00:27 +0300 Subject: [PATCH] Pull request: all: custom autohost tlds Updates #2393. Squashed commit of the following: commit 87034134e240480938cdeec14d6b44294bf6442c Author: Ainar Garipov Date: Thu Mar 25 15:48:46 2021 +0300 dnsforward: fix commit abf3a1ce8ed7a148d1cc631007fb0422f6da4ae6 Author: Ainar Garipov Date: Thu Mar 25 15:21:11 2021 +0300 dnsforward: imp code, validation commit fac389bdafc093ce17a7e0831166b89293b550be Author: Ainar Garipov Date: Thu Mar 25 14:54:45 2021 +0300 all: add validation, imp docs, tests commit 21b4532afe59f3b89383cb330c9a7d49ec124b6e Author: Ainar Garipov Date: Wed Mar 24 19:09:43 2021 +0300 all: custom autohost tlds --- CHANGELOG.md | 2 + internal/dnsforward/clientid.go | 38 +++-- internal/dnsforward/clientid_test.go | 17 +-- internal/dnsforward/config.go | 8 +- internal/dnsforward/dns.go | 50 +++---- internal/dnsforward/dns_test.go | 107 ++++++++++++++ internal/dnsforward/dnsforward.go | 56 ++++++-- internal/dnsforward/dnsforward_test.go | 187 ++++++++++++++++++------- internal/home/config.go | 6 + internal/home/dns.go | 17 ++- internal/home/whois_test.go | 21 ++- staticcheck.conf | 1 + 12 files changed, 389 insertions(+), 121 deletions(-) create mode 100644 internal/dnsforward/dns_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index ad463ee4..ef25039f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to ### Added +- The ability to set a custom TLD for known local-network hosts ([#2393]). - The ability to serve DNS queries on multiple hosts and interfaces ([#1401]). - `ips` and `text` DHCP server options ([#2385]). - `SRV` records support in `$dnsrewrite` filters ([#2533]). @@ -41,6 +42,7 @@ and this project adheres to [#1401]: https://github.com/AdguardTeam/AdGuardHome/issues/1401 [#2385]: https://github.com/AdguardTeam/AdGuardHome/issues/2385 +[#2393]: https://github.com/AdguardTeam/AdGuardHome/issues/2393 [#2412]: https://github.com/AdguardTeam/AdGuardHome/issues/2412 [#2498]: https://github.com/AdguardTeam/AdGuardHome/issues/2498 [#2533]: https://github.com/AdguardTeam/AdGuardHome/issues/2533 diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index c497c7b7..21dcac53 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -10,20 +10,31 @@ import ( "github.com/lucas-clemente/quic-go" ) -const maxDomainPartLen = 64 +// maxDomainLabelLen is the maximum allowed length of a domain name label +// according to RFC 1035. +const maxDomainLabelLen = 63 + +// validateDomainNameLabel returns an error if label is not a valid label of +// a domain name. +func validateDomainNameLabel(label string) (err error) { + if len(label) > maxDomainLabelLen { + return fmt.Errorf("%q is too long, max: %d", label, maxDomainLabelLen) + } + + for i, r := range label { + if (r < 'a' || r > 'z') && (r < '0' || r > '9') && r != '-' { + return fmt.Errorf("invalid char %q at index %d in %q", r, i, label) + } + } + + return nil +} // ValidateClientID returns an error if clientID is not a valid client ID. func ValidateClientID(clientID string) (err error) { - if len(clientID) > maxDomainPartLen { - return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen) - } - - for i, r := range clientID { - if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' { - continue - } - - return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID) + err = validateDomainNameLabel(clientID) + if err != nil { + return fmt.Errorf("invalid client id: %w", err) } return nil @@ -49,7 +60,8 @@ func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) ( clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1] err = ValidateClientID(clientID) if err != nil { - return "", fmt.Errorf("invalid client id: %w", err) + // Don't wrap the error, because it's informative enough as is. + return "", err } return clientID, nil @@ -93,7 +105,7 @@ func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) { err := ValidateClientID(clientID) if err != nil { - ctx.err = fmt.Errorf("client id check: invalid client id: %w", err) + ctx.err = fmt.Errorf("client id check: %w", err) return resultCodeError } diff --git a/internal/dnsforward/clientid_test.go b/internal/dnsforward/clientid_test.go index 503203f9..463e841e 100644 --- a/internal/dnsforward/clientid_test.go +++ b/internal/dnsforward/clientid_test.go @@ -118,7 +118,7 @@ func TestProcessClientID(t *testing.T) { cliSrvName: "!!!.example.com", wantClientID: "", wantErrMsg: `client id check: invalid client id: invalid char '!' ` + - `at index 0 in client id "!!!"`, + `at index 0 in "!!!"`, wantRes: resultCodeError, strictSNI: true, }, { @@ -128,9 +128,9 @@ func TestProcessClientID(t *testing.T) { cliSrvName: `abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmno` + `pqrstuvwxyz0123456789.example.com`, wantClientID: "", - wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmno` + + wantErrMsg: `client id check: invalid client id: "abcdefghijklmno` + `pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" ` + - `is too long, max: 64`, + `is too long, max: 63`, wantRes: resultCodeError, strictSNI: true, }, { @@ -182,9 +182,9 @@ func TestProcessClientID(t *testing.T) { assert.Equal(t, tc.wantClientID, dctx.clientID) if tc.wantErrMsg == "" { - assert.Nil(t, dctx.err) + assert.NoError(t, dctx.err) } else { - require.NotNil(t, dctx.err) + require.Error(t, dctx.err) assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) } }) @@ -239,7 +239,7 @@ func TestProcessClientID_https(t *testing.T) { path: "/dns-query/!!!", wantClientID: "", wantErrMsg: `client id check: invalid client id: invalid char '!'` + - ` at index 0 in client id "!!!"`, + ` at index 0 in "!!!"`, wantRes: resultCodeError, }} @@ -263,9 +263,10 @@ func TestProcessClientID_https(t *testing.T) { assert.Equal(t, tc.wantClientID, dctx.clientID) if tc.wantErrMsg == "" { - assert.Nil(t, dctx.err) + assert.NoError(t, dctx.err) } else { - require.NotNil(t, dctx.err) + require.Error(t, dctx.err) + assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) } }) diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index e80ac941..fe2007c7 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -277,7 +277,7 @@ func (s *Server) prepareUpstreamSettings() error { s := util.SplitNext(&d, '\n') upstreams = append(upstreams, s) } - log.Debug("DNS: using %d upstream servers from file %s", len(upstreams), s.conf.UpstreamDNSFileName) + log.Debug("dns: using %d upstream servers from file %s", len(upstreams), s.conf.UpstreamDNSFileName) } else { upstreams = s.conf.UpstreamDNS } @@ -357,11 +357,11 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error { } if len(x.DNSNames) != 0 { s.conf.dnsNames = x.DNSNames - log.Debug("DNS: using DNS names from certificate's SAN: %v", x.DNSNames) + log.Debug("dns: using DNS names from certificate's SAN: %v", x.DNSNames) sort.Strings(s.conf.dnsNames) } else { s.conf.dnsNames = append(s.conf.dnsNames, x.Subject.CommonName) - log.Debug("DNS: using DNS name from certificate's CN: %s", x.Subject.CommonName) + log.Debug("dns: using DNS name from certificate's CN: %s", x.Subject.CommonName) } } @@ -377,7 +377,7 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error { // If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake. func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) { - log.Info("DNS: TLS: unknown SNI in Client Hello: %s", ch.ServerName) + log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName) return nil, fmt.Errorf("invalid SNI") } return &s.conf.cert, nil diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index acc6aa86..b5f6c4ec 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -15,6 +15,8 @@ import ( // To transfer information between modules type dnsContext struct { + // TODO(a.garipov): Remove this and rewrite processors to be methods of + // *Server instead. srv *Server proxyCtx *proxy.DNSContext // setts are the filtering settings for the client. @@ -75,7 +77,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { // appropriate handler. mods := []modProcessFunc{ processInitial, - processInternalHosts, + s.processInternalHosts, processInternalIPAddrs, processClientID, processFilteringBeforeRequest, @@ -136,7 +138,7 @@ func isHostnameOK(hostname string) bool { (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '.' || c == '-') { - log.Debug("DNS: skipping invalid hostname %s from DHCP", hostname) + log.Debug("dns: skipping invalid hostname %s from DHCP", hostname) return false } } @@ -172,7 +174,7 @@ func (s *Server) onDHCPLeaseChanged(flags int) { hostToIP[lowhost] = ip } - log.Debug("DNS: added %d A/PTR entries from DHCP", len(m)) + log.Debug("dns: added %d A/PTR entries from DHCP", len(m)) s.tableHostToIPLock.Lock() s.tableHostToIP = hostToIP @@ -183,20 +185,22 @@ func (s *Server) onDHCPLeaseChanged(flags int) { s.tablePTRLock.Unlock() } -// Respond to A requests if the target host name is associated with a lease from our DHCP server -func processInternalHosts(ctx *dnsContext) (rc resultCode) { - s := ctx.srv - req := ctx.proxyCtx.Req - if !(req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA) { +// processInternalHosts respond to A requests if the target hostname is known to +// the server. +// +// TODO(a.garipov): Adapt to AAAA as well. +func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) { + req := dctx.proxyCtx.Req + q := req.Question[0] + if q.Qtype != dns.TypeA { return resultCodeSuccess } - host := req.Question[0].Name - host = strings.ToLower(host) - if !strings.HasSuffix(host, ".lan.") { + reqHost := strings.ToLower(q.Name) + host := strings.TrimSuffix(reqHost, s.autohostSuffix) + if host == reqHost { return resultCodeSuccess } - host = strings.TrimSuffix(host, ".lan.") s.tableHostToIPLock.Lock() if s.tableHostToIP == nil { @@ -209,24 +213,22 @@ func processInternalHosts(ctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip) + log.Debug("dns: internal record: %s -> %s", q.Name, ip) resp := s.makeResponse(req) - if req.Question[0].Qtype == dns.TypeA { - a := &dns.A{} - a.Hdr = dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypeA, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, + if q.Qtype == dns.TypeA { + a := &dns.A{ + Hdr: s.hdr(req, dns.TypeA), + A: make([]byte, len(ip)), } - a.A = make([]byte, 4) + copy(a.A, ip) resp.Answer = append(resp.Answer, a) } - ctx.proxyCtx.Res = resp + dctx.proxyCtx.Res = resp + return resultCodeSuccess } @@ -257,7 +259,7 @@ func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - log.Debug("DNS: reverse-lookup: %s -> %s", arpa, host) + log.Debug("dns: reverse-lookup: %s -> %s", arpa, host) resp := s.makeResponse(req) ptr := &dns.PTR{} @@ -325,7 +327,7 @@ func processUpstream(ctx *dnsContext) (rc resultCode) { if s.conf.EnableDNSSEC { opt := d.Req.IsEdns0() if opt == nil { - log.Debug("DNS: Adding OPT record with DNSSEC flag") + log.Debug("dns: Adding OPT record with DNSSEC flag") d.Req.SetEdns0(4096, true) } else if !opt.Do() { opt.SetDo(true) diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go new file mode 100644 index 00000000..682a12c5 --- /dev/null +++ b/internal/dnsforward/dns_test.go @@ -0,0 +1,107 @@ +package dnsforward + +import ( + "net" + "testing" + + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServer_ProcessInternalHosts(t *testing.T) { + knownIP := net.IP{1, 2, 3, 4} + testCases := []struct { + name string + host string + suffix string + wantErrMsg string + wantIP net.IP + qtyp uint16 + wantRes resultCode + }{{ + name: "success_external", + host: "example.com", + suffix: defaultAutohostSuffix, + wantErrMsg: "", + wantIP: nil, + qtyp: dns.TypeA, + wantRes: resultCodeSuccess, + }, { + name: "success_external_non_a", + host: "example.com", + suffix: defaultAutohostSuffix, + wantErrMsg: "", + wantIP: nil, + qtyp: dns.TypeCNAME, + wantRes: resultCodeSuccess, + }, { + name: "success_internal", + host: "example.lan", + suffix: defaultAutohostSuffix, + wantErrMsg: "", + wantIP: knownIP, + qtyp: dns.TypeA, + wantRes: resultCodeSuccess, + }, { + name: "success_internal_unknown", + host: "example-new.lan", + suffix: defaultAutohostSuffix, + wantErrMsg: "", + wantIP: nil, + qtyp: dns.TypeA, + wantRes: resultCodeSuccess, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := &Server{ + autohostSuffix: tc.suffix, + tableHostToIP: map[string]net.IP{ + "example": knownIP, + }, + } + + req := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 1234, + }, + Question: []dns.Question{{ + Name: dns.Fqdn(tc.host), + Qtype: tc.qtyp, + Qclass: dns.ClassINET, + }}, + } + + dctx := &dnsContext{ + proxyCtx: &proxy.DNSContext{ + Req: req, + }, + } + + res := s.processInternalHosts(dctx) + assert.Equal(t, tc.wantRes, res) + + if tc.wantErrMsg == "" { + assert.NoError(t, dctx.err) + } else { + require.Error(t, dctx.err) + + assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) + } + + pctx := dctx.proxyCtx + if tc.wantIP == nil { + assert.Nil(t, pctx.Res) + } else { + require.NotNil(t, pctx.Res) + + ans := pctx.Res.Answer + require.Len(t, ans, 1) + + assert.Equal(t, tc.wantIP, ans[0].(*dns.A).A) + } + }) + } +} diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 68a0b6ae..85acf201 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -56,6 +56,10 @@ type Server struct { stats stats.Stats access *accessCtx + // autohostSuffix is the suffix used to detect internal hosts. It must + // be a valid top-level domain plus dots on each side. + autohostSuffix string + ipset ipsetCtx tableHostToIP map[string]net.IP // "hostname -> IP" table for internal addresses (DHCP) @@ -74,21 +78,50 @@ type Server struct { conf ServerConfig } -// DNSCreateParams - parameters for NewServer() +// defaultAutohostSuffix is the default suffix used to detect internal hosts +// when no suffix is provided. See the documentation for Server.autohostSuffix. +const defaultAutohostSuffix = ".lan." + +// DNSCreateParams are parameters to create a new server. type DNSCreateParams struct { - DNSFilter *dnsfilter.DNSFilter - Stats stats.Stats - QueryLog querylog.QueryLog - DHCPServer dhcpd.ServerInterface + DNSFilter *dnsfilter.DNSFilter + Stats stats.Stats + QueryLog querylog.QueryLog + DHCPServer dhcpd.ServerInterface + AutohostTLD string +} + +// tldToSuffix converts a top-level domain into an autohost suffix. +func tldToSuffix(tld string) (suffix string) { + l := len(tld) + 2 + b := make([]byte, l) + b[0] = '.' + copy(b[1:], tld) + b[l-1] = '.' + + return string(b) } // NewServer creates a new instance of the dnsforward.Server // Note: this function must be called only once -func NewServer(p DNSCreateParams) *Server { - s := &Server{ - dnsFilter: p.DNSFilter, - stats: p.Stats, - queryLog: p.QueryLog, +func NewServer(p DNSCreateParams) (s *Server, err error) { + var autohostSuffix string + if p.AutohostTLD == "" { + autohostSuffix = defaultAutohostSuffix + } else { + err = validateDomainNameLabel(p.AutohostTLD) + if err != nil { + return nil, fmt.Errorf("autohost tld: %w", err) + } + + autohostSuffix = tldToSuffix(p.AutohostTLD) + } + + s = &Server{ + dnsFilter: p.DNSFilter, + stats: p.Stats, + queryLog: p.QueryLog, + autohostSuffix: autohostSuffix, } if p.DHCPServer != nil { @@ -101,7 +134,8 @@ func NewServer(p DNSCreateParams) *Server { // Use plain DNS on MIPS, encryption is too slow defaultDNS = defaultBootstrap } - return s + + return s, nil } // NewCustomServer creates a new instance of *Server with custom internal proxy. diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index a1237849..9ee4d719 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -19,10 +19,9 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/AdguardTeam/AdGuardHome/internal/util" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/miekg/dns" @@ -43,11 +42,11 @@ func startDeferStop(t *testing.T, s *Server) { t.Helper() err := s.Start() - require.Nilf(t, err, "failed to start server: %s", err) + require.NoErrorf(t, err, "failed to start server: %s", err) t.Cleanup(func() { serr := s.Stop() - require.Nilf(t, serr, "dns server failed to stop: %s", serr) + require.NoErrorf(t, serr, "dns server failed to stop: %s", serr) }) } @@ -65,9 +64,13 @@ func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf Se f := dnsfilter.New(filterConf, filters) - s := NewServer(DNSCreateParams{DNSFilter: f}) + s, err := NewServer(DNSCreateParams{DNSFilter: f}) + require.NoError(t, err) + s.conf = forwardConf - require.Nil(t, s.Prepare(nil)) + + err = s.Prepare(nil) + require.NoError(t, err) return s } @@ -76,11 +79,11 @@ func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) { t.Helper() privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.Nilf(t, err, "cannot generate RSA key: %s", err) + require.NoErrorf(t, err, "cannot generate RSA key: %s", err) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - require.Nilf(t, err, "failed to generate serial number: %s", err) + require.NoErrorf(t, err, "failed to generate serial number: %s", err) notBefore := time.Now() notAfter := notBefore.Add(5 * 365 * time.Hour * 24) @@ -101,13 +104,13 @@ func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) { template.DNSNames = append(template.DNSNames, tlsServerName) derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey) - require.Nilf(t, err, "failed to create certificate: %s", err) + require.NoErrorf(t, err, "failed to create certificate: %s", err) certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) cert, err := tls.X509KeyPair(certPem, keyPem) - require.Nilf(t, err, "failed to create certificate: %s", err) + require.NoErrorf(t, err, "failed to create certificate: %s", err) return &tls.Config{ Certificates: []tls.Certificate{cert}, @@ -131,7 +134,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) s.conf.TLSConfig = tlsConf err := s.Prepare(nil) - require.Nilf(t, err, "failed to prepare server: %s", err) + require.NoErrorf(t, err, "failed to prepare server: %s", err) return s, certPem } @@ -191,10 +194,10 @@ func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) { defer wg.Done() err := conn.WriteMsg(msg) - require.Nilf(t, err, "cannot write message: %s", err) + require.NoErrorf(t, err, "cannot write message: %s", err) res, err := conn.ReadMsg() - require.Nilf(t, err, "cannot read response to message: %s", err) + require.NoErrorf(t, err, "cannot read response to message: %s", err) assertGoogleAResponse(t, res) }() @@ -248,7 +251,7 @@ func TestServer(t *testing.T) { client := dns.Client{Net: tc.proto} reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String()) - require.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err) + require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err) assertGoogleAResponse(t, reply) }) @@ -275,7 +278,7 @@ func TestServerWithProtectionDisabled(t *testing.T) { client := dns.Client{Net: proxy.ProtoUDP} reply, _, err := client.Exchange(req, addr.String()) - require.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err) + require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err) assertGoogleAResponse(t, reply) } @@ -304,7 +307,7 @@ func TestDoTServer(t *testing.T) { // Create a DNS-over-TLS client connection. addr := s.dnsProxy.Addr(proxy.ProtoTLS) conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig) - require.Nilf(t, err, "cannot connect to the proxy: %s", err) + require.NoErrorf(t, err, "cannot connect to the proxy: %s", err) sendTestMessages(t, conn) } @@ -326,12 +329,12 @@ func TestDoQServer(t *testing.T) { addr := s.dnsProxy.Addr(proxy.ProtoQUIC) opts := upstream.Options{InsecureSkipVerify: true} u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts) - require.Nil(t, err) + require.NoError(t, err) // Send the test message. req := createGoogleATestMessage() res, err := u.Exchange(req) - require.Nil(t, err) + require.NoError(t, err) assertGoogleAResponse(t, res) } @@ -369,7 +372,7 @@ func TestServerRace(t *testing.T) { // Message over UDP. addr := s.dnsProxy.Addr(proxy.ProtoUDP) conn, err := dns.Dial(proxy.ProtoUDP, addr.String()) - require.Nilf(t, err, "cannot connect to the proxy: %s", err) + require.NoErrorf(t, err, "cannot connect to the proxy: %s", err) sendTestMessagesAsync(t, conn) } @@ -432,7 +435,7 @@ func TestSafeSearch(t *testing.T) { req := createTestMessage(tc.host) reply, _, err := client.Exchange(req, addr) - require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err) assertResponse(t, reply, tc.want) }) } @@ -459,7 +462,7 @@ func TestInvalidRequest(t *testing.T) { Timeout: 500 * time.Millisecond, }).Exchange(&req, addr) - assert.Nil(t, err, "got a response to an invalid query") + assert.NoErrorf(t, err, "got a response to an invalid query") } func TestBlockedRequest(t *testing.T) { @@ -479,7 +482,8 @@ func TestBlockedRequest(t *testing.T) { req := createTestMessage("nxdomain.example.org.") reply, err := dns.Exchange(req, addr.String()) - require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err) + assert.Equal(t, dns.RcodeSuccess, reply.Rcode) require.Len(t, reply.Answer, 1) @@ -514,8 +518,8 @@ func TestServerCustomClientUpstream(t *testing.T) { req := createTestMessage("host.") reply, err := dns.Exchange(req, addr.String()) + require.NoError(t, err) - require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) require.NotEmpty(t, reply.Answer) @@ -558,7 +562,8 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { req := createTestMessage("badhost.") reply, err := dns.Exchange(req, addr.String()) - require.Nil(t, err) + require.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, reply.Rcode) } @@ -608,7 +613,8 @@ func TestBlockCNAME(t *testing.T) { req := createTestMessage(tc.host) reply, err := dns.Exchange(req, addr) - require.Nil(t, err) + require.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, reply.Rcode) if tc.want { require.Len(t, reply.Answer, 1) @@ -658,7 +664,8 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { // However, in our case it should not be blocked as filtering is // disabled on the client level. reply, err := dns.Exchange(&req, addr.String()) - require.Nil(t, err) + require.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, reply.Rcode) } @@ -689,7 +696,7 @@ func TestNullBlockedRequest(t *testing.T) { } reply, err := dns.Exchange(&req, addr.String()) - require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err) require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) a, ok := reply.Answer[0].(*dns.A) require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) @@ -703,9 +710,11 @@ func TestBlockedCustomIP(t *testing.T) { Data: []byte(rules), }} - s := NewServer(DNSCreateParams{ + s, err := NewServer(DNSCreateParams{ DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters), }) + require.NoError(t, err) + conf := &ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, @@ -716,12 +725,16 @@ func TestBlockedCustomIP(t *testing.T) { UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, }, } + // Invalid BlockingIPv4. - assert.NotNil(t, s.Prepare(conf)) + err = s.Prepare(conf) + assert.Error(t, err) conf.BlockingIPv4 = net.IP{0, 0, 0, 1} conf.BlockingIPv6 = net.ParseIP("::1") - require.Nil(t, s.Prepare(conf)) + + err = s.Prepare(conf) + require.NoError(t, err) startDeferStop(t, s) @@ -729,18 +742,24 @@ func TestBlockedCustomIP(t *testing.T) { req := createTestMessageWithType("null.example.org.", dns.TypeA) reply, err := dns.Exchange(req, addr.String()) - require.Nil(t, err) + require.NoError(t, err) + require.Len(t, reply.Answer, 1) + a, ok := reply.Answer[0].(*dns.A) require.True(t, ok) + assert.True(t, net.IP{0, 0, 0, 1}.Equal(a.A)) req = createTestMessageWithType("null.example.org.", dns.TypeAAAA) reply, err = dns.Exchange(req, addr.String()) - require.Nil(t, err) + require.NoError(t, err) + require.Len(t, reply.Answer, 1) + a6, ok := reply.Answer[0].(*dns.AAAA) require.True(t, ok) + assert.Equal(t, "::1", a6.AAAA.String()) } @@ -760,7 +779,7 @@ func TestBlockedByHosts(t *testing.T) { req := createTestMessage("host.example.org.") reply, err := dns.Exchange(req, addr.String()) - require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err) require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) a, ok := reply.Answer[0].(*dns.A) require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) @@ -796,7 +815,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { req := createTestMessage(hostname + ".") reply, err := dns.Exchange(req, addr.String()) - require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err) require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) a, ok := reply.Answer[0].(*dns.A) @@ -822,15 +841,19 @@ func TestRewrite(t *testing.T) { } f := dnsfilter.New(c, nil) - s := NewServer(DNSCreateParams{DNSFilter: f}) - assert.Nil(t, s.Prepare(&ServerConfig{ + s, err := NewServer(DNSCreateParams{DNSFilter: f}) + require.NoError(t, err) + + err = s.Prepare(&ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, UpstreamDNS: []string{"8.8.8.8:53"}, }, - })) + }) + assert.NoError(t, err) + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.TestUpstream{ CName: map[string]string{ @@ -847,34 +870,41 @@ func TestRewrite(t *testing.T) { req := createTestMessageWithType("test.com.", dns.TypeA) reply, err := dns.Exchange(req, addr.String()) - require.Nil(t, err) + require.NoError(t, err) + require.Len(t, reply.Answer, 1) + a, ok := reply.Answer[0].(*dns.A) require.True(t, ok) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A)) req = createTestMessageWithType("test.com.", dns.TypeAAAA) reply, err = dns.Exchange(req, addr.String()) - require.Nil(t, err) + require.NoError(t, err) + assert.Empty(t, reply.Answer) req = createTestMessageWithType("alias.test.com.", dns.TypeA) reply, err = dns.Exchange(req, addr.String()) - require.Nil(t, err) + require.NoError(t, err) require.Len(t, reply.Answer, 2) + assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) reply, err = dns.Exchange(req, addr.String()) - require.Nil(t, err) + require.NoError(t, err) // The original question is restored. require.Len(t, reply.Question, 1) + assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) require.Len(t, reply.Answer, 2) + assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) } @@ -1104,17 +1134,23 @@ func (d *testDHCP) Leases(flags int) []dhcpd.Lease { func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} func TestPTRResponseFromDHCPLeases(t *testing.T) { - s := NewServer(DNSCreateParams{ + s, err := NewServer(DNSCreateParams{ DNSFilter: dnsfilter.New(&dnsfilter.Config{}, nil), DHCPServer: &testDHCP{}, }) + require.NoError(t, err) s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.FilteringConfig.ProtectionEnabled = true - require.Nil(t, s.Prepare(nil)) - require.Nil(t, s.Start()) + + err = s.Prepare(nil) + require.NoError(t, err) + + err = s.Start() + require.NoError(t, err) + t.Cleanup(func() { s.Close() }) @@ -1123,8 +1159,10 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) resp, err := dns.Exchange(req, addr.String()) - require.Nil(t, err) + require.NoError(t, err) + require.Len(t, resp.Answer, 1) + assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) @@ -1140,10 +1178,11 @@ func TestPTRResponseFromHosts(t *testing.T) { // Prepare test hosts file. hf, err := ioutil.TempFile("", "") - require.Nil(t, err) + require.NoError(t, err) + t.Cleanup(func() { - assert.Nil(t, hf.Close()) - assert.Nil(t, os.Remove(hf.Name())) + assert.NoError(t, hf.Close()) + assert.NoError(t, os.Remove(hf.Name())) }) _, _ = hf.WriteString(" 127.0.0.1 host # comment \n") @@ -1153,14 +1192,20 @@ func TestPTRResponseFromHosts(t *testing.T) { c.AutoHosts.Init(hf.Name()) t.Cleanup(c.AutoHosts.Close) - s := NewServer(DNSCreateParams{DNSFilter: dnsfilter.New(&c, nil)}) + s, err := NewServer(DNSCreateParams{DNSFilter: dnsfilter.New(&c, nil)}) + require.NoError(t, err) + s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.FilteringConfig.ProtectionEnabled = true - require.Nil(t, s.Prepare(nil)) - require.Nil(t, s.Start()) + err = s.Prepare(nil) + require.NoError(t, err) + + err = s.Start() + require.NoError(t, err) + t.Cleanup(func() { s.Close() }) @@ -1169,8 +1214,10 @@ func TestPTRResponseFromHosts(t *testing.T) { req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) resp, err := dns.Exchange(req, addr.String()) - require.Nil(t, err) + require.NoError(t, err) + require.Len(t, resp.Answer, 1) + assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) @@ -1178,3 +1225,39 @@ func TestPTRResponseFromHosts(t *testing.T) { require.True(t, ok) assert.Equal(t, "host.", ptr.Ptr) } + +func TestNewServer(t *testing.T) { + testCases := []struct { + name string + in DNSCreateParams + wantErrMsg string + }{{ + name: "success", + in: DNSCreateParams{}, + wantErrMsg: "", + }, { + name: "success_autohost_tld", + in: DNSCreateParams{ + AutohostTLD: "mynet", + }, + wantErrMsg: "", + }, { + name: "bad_autohost_tld", + in: DNSCreateParams{ + AutohostTLD: "!!!", + }, + wantErrMsg: `autohost tld: invalid char '!' at index 0 in "!!!"`, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := NewServer(tc.in) + if tc.wantErrMsg == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Equal(t, tc.wantErrMsg, err.Error()) + } + }) + } +} diff --git a/internal/home/config.go b/internal/home/config.go index 6bb87c52..58970520 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -93,6 +93,11 @@ type dnsConfig struct { FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) DnsfilterConf dnsfilter.Config `yaml:",inline"` + + // AutohostTLD is the top-level domain used for known internal hosts. + // For example, a machine called "myhost" can be addressed as + // "myhost.lan" when AutohostTLD is "lan". + AutohostTLD string `yaml:"autohost_tld"` } type tlsConfigSettings struct { @@ -144,6 +149,7 @@ var config = configuration{ }, FilteringEnabled: true, // whether or not use filter lists FiltersUpdateIntervalHours: 24, + AutohostTLD: "lan", }, TLS: tlsConfigSettings{ PortHTTPS: 443, diff --git a/internal/home/dns.go b/internal/home/dns.go index c3ef8c09..e90e07c9 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -61,18 +61,27 @@ func initDNSServer() error { Context.dnsFilter = dnsfilter.New(&filterConf, nil) p := dnsforward.DNSCreateParams{ - DNSFilter: Context.dnsFilter, - Stats: Context.stats, - QueryLog: Context.queryLog, + DNSFilter: Context.dnsFilter, + Stats: Context.stats, + QueryLog: Context.queryLog, + AutohostTLD: config.DNS.AutohostTLD, } if Context.dhcpServer != nil { p.DHCPServer = Context.dhcpServer } - Context.dnsServer = dnsforward.NewServer(p) + + Context.dnsServer, err = dnsforward.NewServer(p) + if err != nil { + closeDNSServer() + + return fmt.Errorf("dnsforward.NewServer: %w", err) + } + Context.clients.dnsServer = Context.dnsServer dnsConfig, err := generateServerConfig() if err != nil { closeDNSServer() + return fmt.Errorf("generateServerConfig: %w", err) } diff --git a/internal/home/whois_test.go b/internal/home/whois_test.go index ed72740d..550cc297 100644 --- a/internal/home/whois_test.go +++ b/internal/home/whois_test.go @@ -6,15 +6,23 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func prepareTestDNSServer() error { +func prepareTestDNSServer(t *testing.T) { + t.Helper() + config.DNS.Port = 1234 - Context.dnsServer = dnsforward.NewServer(dnsforward.DNSCreateParams{}) + + var err error + Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{}) + require.NoError(t, err) + conf := &dnsforward.ServerConfig{} conf.UpstreamDNS = []string{"8.8.8.8"} - return Context.dnsServer.Prepare(conf) + err = Context.dnsServer.Prepare(conf) + require.NoError(t, err) } // TODO(e.burkov): It's kind of complicated to get rid of network access in this @@ -22,12 +30,15 @@ func prepareTestDNSServer() error { // the server, so it becomes hard to simulate handling of request from test even // with substituted upstream. However, it must be done. func TestWhois(t *testing.T) { - assert.Nil(t, prepareTestDNSServer()) + prepareTestDNSServer(t) w := Whois{timeoutMsec: 5000} resp, err := w.queryAll(context.Background(), "8.8.8.8") - assert.Nil(t, err) + assert.NoError(t, err) + m := whoisParse(resp) + require.NotEmpty(t, m) + assert.Equal(t, "Google LLC", m["orgname"]) assert.Equal(t, "US", m["country"]) assert.Equal(t, "Mountain View", m["city"]) diff --git a/staticcheck.conf b/staticcheck.conf index b997f6a9..3890f802 100644 --- a/staticcheck.conf +++ b/staticcheck.conf @@ -15,6 +15,7 @@ initialisms = [ , "SDNS" , "SLAAC" , "SVCB" +, "TLD" ] dot_import_whitelist = [] http_status_code_whitelist = []