diff --git a/internal/dnsfilter/dnsfilter_test.go b/internal/dnsfilter/dnsfilter_test.go index 10f34089..76cbe5a2 100644 --- a/internal/dnsfilter/dnsfilter_test.go +++ b/internal/dnsfilter/dnsfilter_test.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net" + "strings" "testing" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" @@ -13,6 +14,7 @@ import ( "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -58,7 +60,7 @@ func (d *DNSFilter) checkMatch(t *testing.T, hostname string) { t.Helper() res, err := d.CheckHost(hostname, dns.TypeA, &setts) - assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err) + require.Nilf(t, err, "Error while matching host %s: %s", hostname, err) assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname) } @@ -66,20 +68,20 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16 t.Helper() res, err := d.CheckHost(hostname, qtype, &setts) - assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err) + require.Nilf(t, err, "Error while matching host %s: %s", hostname, err) assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname) - if assert.NotEmpty(t, res.Rules, "Expected result to have rules") { - r := res.Rules[0] - assert.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP) - assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP) - } + + require.NotEmpty(t, res.Rules, "Expected result to have rules") + r := res.Rules[0] + require.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP) + assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP) } func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) { t.Helper() res, err := d.CheckHost(hostname, dns.TypeA, &setts) - assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err) + require.Nilf(t, err, "Error while matching host %s: %s", hostname, err) assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname) } @@ -110,40 +112,40 @@ func TestEtcHostsMatching(t *testing.T) { // Empty IPv6. res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts) - assert.Nil(t, err) + require.Nil(t, err) assert.True(t, res.IsFiltered) - if assert.Len(t, res.Rules, 1) { - assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text) - assert.Empty(t, res.Rules[0].IP) - } + + require.Len(t, res.Rules, 1) + assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text) + assert.Empty(t, res.Rules[0].IP) // IPv6 match. d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA) // Empty IPv4. res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts) - assert.Nil(t, err) + require.Nil(t, err) assert.True(t, res.IsFiltered) - if assert.Len(t, res.Rules, 1) { - assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text) - assert.Empty(t, res.Rules[0].IP) - } + + require.Len(t, res.Rules, 1) + assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text) + assert.Empty(t, res.Rules[0].IP) // Two IPv4, the first one returned. res, err = d.CheckHost("host2", dns.TypeA, &setts) - assert.Nil(t, err) + require.Nil(t, err) assert.True(t, res.IsFiltered) - if assert.Len(t, res.Rules, 1) { - assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1}) - } + + require.Len(t, res.Rules, 1) + assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1}) // One IPv6 address. res, err = d.CheckHost("host2", dns.TypeAAAA, &setts) - assert.Nil(t, err) + require.Nil(t, err) assert.True(t, res.IsFiltered) - if assert.Len(t, res.Rules, 1) { - assert.Equal(t, res.Rules[0].IP, net.IPv6loopback) - } + + require.Len(t, res.Rules, 1) + assert.Equal(t, res.Rules[0].IP, net.IPv6loopback) } // Safe Browsing. @@ -155,14 +157,14 @@ func TestSafeBrowsing(t *testing.T) { d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) - matching := "wmconvirus.narod.ru" + const matching = "wmconvirus.narod.ru" d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ Hostname: matching, Block: true, }) d.checkMatch(t, matching) - assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching) + require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching) d.checkMatch(t, "test."+matching) d.checkMatchEmpty(t, "yandex.ru") @@ -178,7 +180,7 @@ func TestSafeBrowsing(t *testing.T) { func TestParallelSB(t *testing.T) { d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) - matching := "wmconvirus.narod.ru" + const matching = "wmconvirus.narod.ru" d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ Hostname: matching, Block: true, @@ -203,7 +205,7 @@ func TestSafeSearch(t *testing.T) { d := newForTest(&Config{SafeSearchEnabled: true}, nil) t.Cleanup(d.Close) val, ok := d.SafeSearchDomain("www.google.com") - assert.True(t, ok, "Expected safesearch to find result for www.google.com") + require.True(t, ok, "Expected safesearch to find result for www.google.com") assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") } @@ -211,6 +213,8 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { d := newForTest(&Config{SafeSearchEnabled: true}, nil) t.Cleanup(d.Close) + yandexIP := net.IPv4(213, 180, 193, 56) + // Check host for each domain. for _, host := range []string{ "yAndeX.ru", @@ -220,22 +224,27 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { "yandex.kz", "www.yandex.com", } { - res, err := d.CheckHost(host, dns.TypeA, &setts) - assert.Nil(t, err) - assert.True(t, res.IsFiltered) - if assert.Len(t, res.Rules, 1) { - assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56)) - } + t.Run(strings.ToLower(host), func(t *testing.T) { + res, err := d.CheckHost(host, dns.TypeA, &setts) + require.Nil(t, err) + assert.True(t, res.IsFiltered) + + require.Len(t, res.Rules, 1) + assert.Equal(t, yandexIP, res.Rules[0].IP) + }) } } func TestCheckHostSafeSearchGoogle(t *testing.T) { + resolver := &aghtest.TestResolver{} d := newForTest(&Config{ SafeSearchEnabled: true, - CustomResolver: &aghtest.TestResolver{}, + CustomResolver: resolver, }, nil) t.Cleanup(d.Close) + ip, _ := resolver.HostToIPs("forcesafesearch.google.com") + // Check host for each domain. for _, host := range []string{ "www.google.com", @@ -248,11 +257,10 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { } { t.Run(host, func(t *testing.T) { res, err := d.CheckHost(host, dns.TypeA, &setts) - assert.Nil(t, err) + require.Nil(t, err) assert.True(t, res.IsFiltered) - if assert.Len(t, res.Rules, 1) { - assert.NotEqual(t, res.Rules[0].IP.String(), "0.0.0.0") - } + require.Len(t, res.Rules, 1) + assert.Equal(t, ip, res.Rules[0].IP) }) } } @@ -260,31 +268,31 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { func TestSafeSearchCacheYandex(t *testing.T) { d := newForTest(nil, nil) t.Cleanup(d.Close) - domain := "yandex.ru" + const domain = "yandex.ru" // Check host with disabled safesearch. res, err := d.CheckHost(domain, dns.TypeA, &setts) - assert.Nil(t, err) + require.Nil(t, err) assert.False(t, res.IsFiltered) - assert.Empty(t, res.Rules) + require.Empty(t, res.Rules) + + yandexIP := net.IPv4(213, 180, 193, 56) d = newForTest(&Config{SafeSearchEnabled: true}, nil) t.Cleanup(d.Close) res, err = d.CheckHost(domain, dns.TypeA, &setts) - assert.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err) + require.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err) // For yandex we already know valid IP. - if assert.Len(t, res.Rules, 1) { - assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56)) - } + require.Len(t, res.Rules, 1) + assert.Equal(t, res.Rules[0].IP, yandexIP) // Check cache. cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain) - assert.True(t, isFound) - if assert.Len(t, cachedValue.Rules, 1) { - assert.Equal(t, cachedValue.Rules[0].IP, net.IPv4(213, 180, 193, 56)) - } + require.True(t, isFound) + require.Len(t, cachedValue.Rules, 1) + assert.Equal(t, cachedValue.Rules[0].IP, yandexIP) } func TestSafeSearchCacheGoogle(t *testing.T) { @@ -294,11 +302,11 @@ func TestSafeSearchCacheGoogle(t *testing.T) { }, nil) t.Cleanup(d.Close) - domain := "www.google.ru" + const domain = "www.google.ru" res, err := d.CheckHost(domain, dns.TypeA, &setts) - assert.Nil(t, err) + require.Nil(t, err) assert.False(t, res.IsFiltered) - assert.Empty(t, res.Rules) + require.Empty(t, res.Rules) d = newForTest(&Config{SafeSearchEnabled: true}, nil) t.Cleanup(d.Close) @@ -306,12 +314,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) { // Lookup for safesearch domain. safeDomain, ok := d.SafeSearchDomain(domain) - assert.Truef(t, ok, "Failed to get safesearch domain for %s", domain) + require.Truef(t, ok, "Failed to get safesearch domain for %s", domain) ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain) - if err != nil { - t.Fatalf("Failed to lookup for %s", safeDomain) - } + require.Nilf(t, err, "Failed to lookup for %s", safeDomain) var ip net.IP for _, foundIP := range ips { @@ -323,17 +329,15 @@ func TestSafeSearchCacheGoogle(t *testing.T) { } res, err = d.CheckHost(domain, dns.TypeA, &setts) - assert.Nil(t, err) - if assert.Len(t, res.Rules, 1) { - assert.True(t, res.Rules[0].IP.Equal(ip)) - } + require.Nil(t, err) + require.Len(t, res.Rules, 1) + assert.True(t, res.Rules[0].IP.Equal(ip)) // Check cache. cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain) - assert.True(t, isFound) - if assert.Len(t, cachedValue.Rules, 1) { - assert.True(t, cachedValue.Rules[0].IP.Equal(ip)) - } + require.True(t, isFound) + require.Len(t, cachedValue.Rules, 1) + assert.True(t, cachedValue.Rules[0].IP.Equal(ip)) } // Parental. @@ -345,24 +349,23 @@ func TestParentalControl(t *testing.T) { d := newForTest(&Config{ParentalEnabled: true}, nil) t.Cleanup(d.Close) - matching := "pornhub.com" + const matching = "pornhub.com" d.SetParentalUpstream(&aghtest.TestBlockUpstream{ Hostname: matching, Block: true, }) d.checkMatch(t, matching) - assert.Contains(t, logOutput.String(), "Parental lookup for "+matching) + require.Contains(t, logOutput.String(), "Parental lookup for "+matching) d.checkMatch(t, "www."+matching) d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "api.jquery.com") - // test cached result + // Test cached result. d.parentalServer = "127.0.0.1" d.checkMatch(t, matching) d.checkMatchEmpty(t, "yandex.ru") - d.parentalServer = defaultParentalServer } // Filtering. @@ -651,7 +654,7 @@ func TestMatching(t *testing.T) { t.Cleanup(d.Close) res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts) - assert.Nilf(t, err, "Error while matching host %s: %s", tc.host, err) + require.Nilf(t, err, "Error while matching host %s: %s", tc.host, err) assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered) assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason) }) @@ -674,28 +677,24 @@ func TestWhitelist(t *testing.T) { }} d := newForTest(nil, filters) - err := d.SetFilters(filters, whiteFilters, false) - assert.Nil(t, err) - + require.Nil(t, d.SetFilters(filters, whiteFilters, false)) t.Cleanup(d.Close) // Matched by white filter. res, err := d.CheckHost("host1", dns.TypeA, &setts) - assert.Nil(t, err) + require.Nil(t, err) assert.False(t, res.IsFiltered) assert.Equal(t, res.Reason, NotFilteredAllowList) - if assert.Len(t, res.Rules, 1) { - assert.Equal(t, "||host1^", res.Rules[0].Text) - } + require.Len(t, res.Rules, 1) + assert.Equal(t, "||host1^", res.Rules[0].Text) // Not matched by white filter, but matched by block filter. res, err = d.CheckHost("host2", dns.TypeA, &setts) - assert.Nil(t, err) + require.Nil(t, err) assert.True(t, res.IsFiltered) assert.Equal(t, res.Reason, FilteredBlockList) - if assert.Len(t, res.Rules, 1) { - assert.Equal(t, "||host2^", res.Rules[0].Text) - } + require.Len(t, res.Rules, 1) + assert.Equal(t, "||host2^", res.Rules[0].Text) } // Client Settings. @@ -797,7 +796,7 @@ func BenchmarkSafeBrowsing(b *testing.B) { }) for n := 0; n < b.N; n++ { res, err := d.CheckHost(blocked, dns.TypeA, &setts) - assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err) + require.Nilf(b, err, "Error while matching host %s: %s", blocked, err) assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) } } @@ -813,7 +812,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { res, err := d.CheckHost(blocked, dns.TypeA, &setts) - assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err) + require.Nilf(b, err, "Error while matching host %s: %s", blocked, err) assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) } }) @@ -824,7 +823,7 @@ func BenchmarkSafeSearch(b *testing.B) { b.Cleanup(d.Close) for n := 0; n < b.N; n++ { val, ok := d.SafeSearchDomain("www.google.com") - assert.True(b, ok, "Expected safesearch to find result for www.google.com") + require.True(b, ok, "Expected safesearch to find result for www.google.com") assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") } } @@ -835,7 +834,7 @@ func BenchmarkSafeSearchParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { val, ok := d.SafeSearchDomain("www.google.com") - assert.True(b, ok, "Expected safesearch to find result for www.google.com") + require.True(b, ok, "Expected safesearch to find result for www.google.com") assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") } }) diff --git a/internal/dnsfilter/dnsrewrite_test.go b/internal/dnsfilter/dnsrewrite_test.go index c915d920..03b9f1ce 100644 --- a/internal/dnsfilter/dnsrewrite_test.go +++ b/internal/dnsfilter/dnsrewrite_test.go @@ -7,6 +7,7 @@ import ( "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { @@ -55,138 +56,89 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { ipv6p1 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} ipv6p2 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} + testCasesA := []struct { + name string + dtyp uint16 + rcode int + want []interface{} + }{{ + name: "a-record", + dtyp: dns.TypeA, + rcode: dns.RcodeSuccess, + want: []interface{}{ipv4p1}, + }, { + name: "aaaa-record", + dtyp: dns.TypeAAAA, + rcode: dns.RcodeSuccess, + want: []interface{}{ipv6p1}, + }, { + name: "txt-record", + dtyp: dns.TypeTXT, + rcode: dns.RcodeSuccess, + want: []interface{}{"hello-world"}, + }, { + name: "refused", + rcode: dns.RcodeRefused, + }, { + name: "a-records", + dtyp: dns.TypeA, + rcode: dns.RcodeSuccess, + want: []interface{}{ipv4p1, ipv4p2}, + }, { + name: "aaaa-records", + dtyp: dns.TypeAAAA, + rcode: dns.RcodeSuccess, + want: []interface{}{ipv6p1, ipv6p2}, + }, { + name: "disable-one", + dtyp: dns.TypeA, + rcode: dns.RcodeSuccess, + want: []interface{}{ipv4p2}, + }, { + name: "disable-cname", + dtyp: dns.TypeA, + rcode: dns.RcodeSuccess, + want: []interface{}{ipv4p1}, + }} + + for _, tc := range testCasesA { + t.Run(tc.name, func(t *testing.T) { + host := path.Base(tc.name) + + res, err := f.CheckHostRules(host, tc.dtyp, setts) + require.Nil(t, err) + + dnsrr := res.DNSRewriteResult + require.NotNil(t, dnsrr) + assert.Equal(t, tc.rcode, dnsrr.RCode) + + if tc.rcode == dns.RcodeRefused { + return + } + + ipVals := dnsrr.Response[tc.dtyp] + require.Len(t, ipVals, len(tc.want)) + for i, val := range tc.want { + require.Equal(t, val, ipVals[i]) + } + }) + } + t.Run("cname", func(t *testing.T) { dtyp := dns.TypeA host := path.Base(t.Name()) res, err := f.CheckHostRules(host, dtyp, setts) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, "new-cname", res.CanonName) }) - t.Run("a-record", func(t *testing.T) { - dtyp := dns.TypeA - host := path.Base(t.Name()) - - res, err := f.CheckHostRules(host, dtyp, setts) - assert.Nil(t, err) - - if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) { - assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode) - if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) { - assert.Equal(t, ipv4p1, ipVals[0]) - } - } - }) - - t.Run("aaaa-record", func(t *testing.T) { - dtyp := dns.TypeAAAA - host := path.Base(t.Name()) - - res, err := f.CheckHostRules(host, dtyp, setts) - assert.Nil(t, err) - - if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) { - assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode) - if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) { - assert.Equal(t, ipv6p1, ipVals[0]) - } - } - }) - - t.Run("txt-record", func(t *testing.T) { - dtyp := dns.TypeTXT - host := path.Base(t.Name()) - res, err := f.CheckHostRules(host, dtyp, setts) - assert.Nil(t, err) - - if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) { - assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode) - if strVals := dnsrr.Response[dtyp]; assert.Len(t, strVals, 1) { - assert.Equal(t, "hello-world", strVals[0]) - } - } - }) - - t.Run("refused", func(t *testing.T) { - host := path.Base(t.Name()) - res, err := f.CheckHostRules(host, dns.TypeA, setts) - assert.Nil(t, err) - - if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) { - assert.Equal(t, dns.RcodeRefused, dnsrr.RCode) - } - }) - - t.Run("a-records", func(t *testing.T) { - dtyp := dns.TypeA - host := path.Base(t.Name()) - - res, err := f.CheckHostRules(host, dtyp, setts) - assert.Nil(t, err) - - if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) { - assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode) - if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 2) { - assert.Equal(t, ipv4p1, ipVals[0]) - assert.Equal(t, ipv4p2, ipVals[1]) - } - } - }) - - t.Run("aaaa-records", func(t *testing.T) { - dtyp := dns.TypeAAAA - host := path.Base(t.Name()) - - res, err := f.CheckHostRules(host, dtyp, setts) - assert.Nil(t, err) - - if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) { - assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode) - if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 2) { - assert.Equal(t, ipv6p1, ipVals[0]) - assert.Equal(t, ipv6p2, ipVals[1]) - } - } - }) - - t.Run("disable-one", func(t *testing.T) { - dtyp := dns.TypeA - host := path.Base(t.Name()) - - res, err := f.CheckHostRules(host, dtyp, setts) - assert.Nil(t, err) - - if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) { - assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode) - if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) { - assert.Equal(t, ipv4p2, ipVals[0]) - } - } - }) - - t.Run("disable-cname", func(t *testing.T) { - dtyp := dns.TypeA - host := path.Base(t.Name()) - - res, err := f.CheckHostRules(host, dtyp, setts) - assert.Nil(t, err) - assert.Empty(t, res.CanonName) - - if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) { - assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode) - if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) { - assert.Equal(t, ipv4p1, ipVals[0]) - } - } - }) - t.Run("disable-cname-many", func(t *testing.T) { dtyp := dns.TypeA host := path.Base(t.Name()) res, err := f.CheckHostRules(host, dtyp, setts) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, "new-cname-2", res.CanonName) assert.Nil(t, res.DNSRewriteResult) }) @@ -196,7 +148,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { host := path.Base(t.Name()) res, err := f.CheckHostRules(host, dtyp, setts) - assert.Nil(t, err) + require.Nil(t, err) assert.Empty(t, res.CanonName) assert.Empty(t, res.Rules) }) diff --git a/internal/dnsfilter/rewrites_test.go b/internal/dnsfilter/rewrites_test.go index a56d2a48..bf5c0fcc 100644 --- a/internal/dnsfilter/rewrites_test.go +++ b/internal/dnsfilter/rewrites_test.go @@ -6,215 +6,297 @@ import ( "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// TODO(e.burkov): All the tests in this file may and should me merged together. + func TestRewrites(t *testing.T) { d := newForTest(nil, nil) t.Cleanup(d.Close) - // CNAME, A, AAAA - d.Rewrites = []RewriteEntry{ - {"somecname", "somehost.com", 0, nil}, - {"somehost.com", "0.0.0.0", 0, nil}, - {"host.com", "1.2.3.4", 0, nil}, - {"host.com", "1.2.3.5", 0, nil}, - {"host.com", "1:2:3::4", 0, nil}, - {"www.host.com", "host.com", 0, nil}, - } + d.Rewrites = []RewriteEntry{{ + // This one and below are about CNAME, A and AAAA. + Domain: "somecname", + Answer: "somehost.com", + }, { + Domain: "somehost.com", + Answer: "0.0.0.0", + }, { + Domain: "host.com", + Answer: "1.2.3.4", + }, { + Domain: "host.com", + Answer: "1.2.3.5", + }, { + Domain: "host.com", + Answer: "1:2:3::4", + }, { + Domain: "www.host.com", + Answer: "host.com", + }, { + // This one is a wildcard. + Domain: "*.host.com", + Answer: "1.2.3.5", + }, { + // This one and below are about wildcard overriding. + Domain: "a.host.com", + Answer: "1.2.3.4", + }, { + // This one is about CNAME and wildcard interacting. + Domain: "*.host2.com", + Answer: "host.com", + }, { + // This one and below are about 2 level CNAME. + Domain: "b.host.com", + Answer: "somecname", + }, { + // This one and below are about 2 level CNAME and wildcard. + Domain: "b.host3.com", + Answer: "a.host3.com", + }, { + Domain: "a.host3.com", + Answer: "x.host.com", + }} d.prepareRewrites() - r := d.processRewrites("host2.com", dns.TypeA) - assert.Equal(t, NotFilteredNotFound, r.Reason) - r = d.processRewrites("www.host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, "host.com", r.CanonName) - assert.Len(t, r.IPList, 2) - assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) - assert.True(t, r.IPList[1].Equal(net.IP{1, 2, 3, 5})) + testCases := []struct { + name string + host string + dtyp uint16 + wantCName string + wantVals []net.IP + }{{ + name: "not_filtered_not_found", + host: "hoost.com", + dtyp: dns.TypeA, + }, { + name: "rewritten_a", + host: "www.host.com", + dtyp: dns.TypeA, + wantCName: "host.com", + wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}}, + }, { + name: "rewritten_aaaa", + host: "www.host.com", + dtyp: dns.TypeAAAA, + wantCName: "host.com", + wantVals: []net.IP{net.ParseIP("1:2:3::4")}, + }, { + name: "wildcard_match", + host: "abc.host.com", + dtyp: dns.TypeA, + wantVals: []net.IP{{1, 2, 3, 5}}, + }, { + name: "wildcard_override", + host: "a.host.com", + dtyp: dns.TypeA, + wantVals: []net.IP{{1, 2, 3, 4}}, + }, { + name: "wildcard_cname_interaction", + host: "www.host2.com", + dtyp: dns.TypeA, + wantCName: "host.com", + wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}}, + }, { + name: "two_cnames", + host: "b.host.com", + dtyp: dns.TypeA, + wantCName: "somehost.com", + wantVals: []net.IP{{0, 0, 0, 0}}, + }, { + name: "two_cnames_and_wildcard", + host: "b.host3.com", + dtyp: dns.TypeA, + wantCName: "x.host.com", + wantVals: []net.IP{{1, 2, 3, 5}}, + }} - r = d.processRewrites("www.host.com", dns.TypeAAAA) - assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, "host.com", r.CanonName) - assert.Len(t, r.IPList, 1) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4"))) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + valsNum := len(tc.wantVals) - // wildcard - d.Rewrites = []RewriteEntry{ - {"host.com", "1.2.3.4", 0, nil}, - {"*.host.com", "1.2.3.5", 0, nil}, + r := d.processRewrites(tc.host, tc.dtyp) + if valsNum == 0 { + assert.Equal(t, NotFilteredNotFound, r.Reason) + + return + } + + require.Equal(t, Rewritten, r.Reason) + if tc.wantCName != "" { + assert.Equal(t, tc.wantCName, r.CanonName) + } + + require.Len(t, r.IPList, valsNum) + for i, ip := range tc.wantVals { + assert.Equal(t, ip, r.IPList[i]) + } + }) } - d.prepareRewrites() - r = d.processRewrites("host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) - - r = d.processRewrites("www.host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 5})) - - r = d.processRewrites("www.host2.com", dns.TypeA) - assert.Equal(t, NotFilteredNotFound, r.Reason) - - // override a wildcard - d.Rewrites = []RewriteEntry{ - {"a.host.com", "1.2.3.4", 0, nil}, - {"*.host.com", "1.2.3.5", 0, nil}, - } - d.prepareRewrites() - r = d.processRewrites("a.host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Len(t, r.IPList, 1) - assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) - - // wildcard + CNAME - d.Rewrites = []RewriteEntry{ - {"host.com", "1.2.3.4", 0, nil}, - {"*.host.com", "host.com", 0, nil}, - } - d.prepareRewrites() - r = d.processRewrites("www.host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, "host.com", r.CanonName) - assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) - - // 2 CNAMEs - d.Rewrites = []RewriteEntry{ - {"b.host.com", "a.host.com", 0, nil}, - {"a.host.com", "host.com", 0, nil}, - {"host.com", "1.2.3.4", 0, nil}, - } - d.prepareRewrites() - r = d.processRewrites("b.host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, "host.com", r.CanonName) - assert.Len(t, r.IPList, 1) - assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) - - // 2 CNAMEs + wildcard - d.Rewrites = []RewriteEntry{ - {"b.host.com", "a.host.com", 0, nil}, - {"a.host.com", "x.somehost.com", 0, nil}, - {"*.somehost.com", "1.2.3.4", 0, nil}, - } - d.prepareRewrites() - r = d.processRewrites("b.host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, "x.somehost.com", r.CanonName) - assert.Len(t, r.IPList, 1) - assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) } func TestRewritesLevels(t *testing.T) { d := newForTest(nil, nil) t.Cleanup(d.Close) - // exact host, wildcard L2, wildcard L3 - d.Rewrites = []RewriteEntry{ - {"host.com", "1.1.1.1", 0, nil}, - {"*.host.com", "2.2.2.2", 0, nil}, - {"*.sub.host.com", "3.3.3.3", 0, nil}, - } + // Exact host, wildcard L2, wildcard L3. + d.Rewrites = []RewriteEntry{{ + Domain: "host.com", + Answer: "1.1.1.1", + }, { + Domain: "*.host.com", + Answer: "2.2.2.2", + }, { + Domain: "*.sub.host.com", + Answer: "3.3.3.3", + }} d.prepareRewrites() - // match exact - r := d.processRewrites("host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Len(t, r.IPList, 1) - assert.True(t, net.IP{1, 1, 1, 1}.Equal(r.IPList[0])) + testCases := []struct { + name string + host string + want net.IP + }{{ + name: "exact_match", + host: "host.com", + want: net.IP{1, 1, 1, 1}, + }, { + name: "l2_match", + host: "sub.host.com", + want: net.IP{2, 2, 2, 2}, + }, { + name: "l3_match", + host: "my.sub.host.com", + want: net.IP{3, 3, 3, 3}, + }} - // match L2 - r = d.processRewrites("sub.host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Len(t, r.IPList, 1) - assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0])) - - // match L3 - r = d.processRewrites("my.sub.host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Len(t, r.IPList, 1) - assert.True(t, net.IP{3, 3, 3, 3}.Equal(r.IPList[0])) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := d.processRewrites(tc.host, dns.TypeA) + assert.Equal(t, Rewritten, r.Reason) + require.Len(t, r.IPList, 1) + }) + } } func TestRewritesExceptionCNAME(t *testing.T) { d := newForTest(nil, nil) t.Cleanup(d.Close) - // wildcard; exception for a sub-domain - d.Rewrites = []RewriteEntry{ - {"*.host.com", "2.2.2.2", 0, nil}, - {"sub.host.com", "sub.host.com", 0, nil}, - } + // Wildcard and exception for a sub-domain. + d.Rewrites = []RewriteEntry{{ + Domain: "*.host.com", + Answer: "2.2.2.2", + }, { + Domain: "sub.host.com", + Answer: "sub.host.com", + }, { + Domain: "*.sub.host.com", + Answer: "*.sub.host.com", + }} d.prepareRewrites() - // match sub-domain - r := d.processRewrites("my.host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Len(t, r.IPList, 1) - assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0])) + testCases := []struct { + name string + host string + want net.IP + }{{ + name: "match_sub-domain", + host: "my.host.com", + want: net.IP{2, 2, 2, 2}, + }, { + name: "exception_cname", + host: "sub.host.com", + }, { + name: "exception_wildcard", + host: "my.sub.host.com", + }} - // match sub-domain, but handle exception - r = d.processRewrites("sub.host.com", dns.TypeA) - assert.Equal(t, NotFilteredNotFound, r.Reason) -} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := d.processRewrites(tc.host, dns.TypeA) + if tc.want == nil { + assert.Equal(t, NotFilteredNotFound, r.Reason) -func TestRewritesExceptionWC(t *testing.T) { - d := newForTest(nil, nil) - t.Cleanup(d.Close) - // wildcard; exception for a sub-wildcard - d.Rewrites = []RewriteEntry{ - {"*.host.com", "2.2.2.2", 0, nil}, - {"*.sub.host.com", "*.sub.host.com", 0, nil}, + return + } + + assert.Equal(t, Rewritten, r.Reason) + require.Len(t, r.IPList, 1) + assert.True(t, tc.want.Equal(r.IPList[0])) + }) } - d.prepareRewrites() - - // match sub-domain - r := d.processRewrites("my.host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Len(t, r.IPList, 1) - assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0])) - - // match sub-domain, but handle exception - r = d.processRewrites("my.sub.host.com", dns.TypeA) - assert.Equal(t, NotFilteredNotFound, r.Reason) } func TestRewritesExceptionIP(t *testing.T) { d := newForTest(nil, nil) t.Cleanup(d.Close) - // exception for AAAA record - d.Rewrites = []RewriteEntry{ - {"host.com", "1.2.3.4", 0, nil}, - {"host.com", "AAAA", 0, nil}, - {"host2.com", "::1", 0, nil}, - {"host2.com", "A", 0, nil}, - {"host3.com", "A", 0, nil}, - } + // Exception for AAAA record. + d.Rewrites = []RewriteEntry{{ + Domain: "host.com", + Answer: "1.2.3.4", + }, { + Domain: "host.com", + Answer: "AAAA", + }, { + Domain: "host2.com", + Answer: "::1", + }, { + Domain: "host2.com", + Answer: "A", + }, { + Domain: "host3.com", + Answer: "A", + }} d.prepareRewrites() - // match domain - r := d.processRewrites("host.com", dns.TypeA) - assert.Equal(t, Rewritten, r.Reason) - assert.Len(t, r.IPList, 1) - assert.True(t, net.IP{1, 2, 3, 4}.Equal(r.IPList[0])) + testCases := []struct { + name string + host string + dtyp uint16 + want []net.IP + }{{ + name: "match_A", + host: "host.com", + dtyp: dns.TypeA, + want: []net.IP{{1, 2, 3, 4}}, + }, { + name: "exception_AAAA_host.com", + host: "host.com", + dtyp: dns.TypeAAAA, + }, { + name: "exception_A_host2.com", + host: "host2.com", + dtyp: dns.TypeA, + }, { + name: "match_AAAA_host2.com", + host: "host2.com", + dtyp: dns.TypeAAAA, + want: []net.IP{net.ParseIP("::1")}, + }, { + name: "exception_A_host3.com", + host: "host3.com", + dtyp: dns.TypeA, + }, { + name: "match_AAAA_host3.com", + host: "host3.com", + dtyp: dns.TypeAAAA, + want: []net.IP{}, + }} - // match exception - r = d.processRewrites("host.com", dns.TypeAAAA) - assert.Equal(t, NotFilteredNotFound, r.Reason) + for _, tc := range testCases { + t.Run(tc.name+"_"+tc.host, func(t *testing.T) { + r := d.processRewrites(tc.host, tc.dtyp) + if tc.want == nil { + assert.Equal(t, NotFilteredNotFound, r.Reason) - // match exception - r = d.processRewrites("host2.com", dns.TypeA) - assert.Equal(t, NotFilteredNotFound, r.Reason) + return + } - // match domain - r = d.processRewrites("host2.com", dns.TypeAAAA) - assert.Equal(t, Rewritten, r.Reason) - assert.Len(t, r.IPList, 1) - assert.Equal(t, "::1", r.IPList[0].String()) - - // match exception - r = d.processRewrites("host3.com", dns.TypeA) - assert.Equal(t, NotFilteredNotFound, r.Reason) - - // match domain - r = d.processRewrites("host3.com", dns.TypeAAAA) - assert.Equal(t, Rewritten, r.Reason) - assert.Empty(t, r.IPList) + assert.Equal(t, Rewritten, r.Reason) + require.Len(t, r.IPList, len(tc.want)) + for _, ip := range tc.want { + assert.True(t, ip.Equal(r.IPList[0])) + } + }) + } } diff --git a/internal/dnsfilter/safebrowsing_test.go b/internal/dnsfilter/safebrowsing_test.go index f94a3c99..d5a3acc1 100644 --- a/internal/dnsfilter/safebrowsing_test.go +++ b/internal/dnsfilter/safebrowsing_test.go @@ -8,6 +8,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/cache" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSafeBrowsingHash(t *testing.T) { @@ -155,25 +156,25 @@ func TestSBPC(t *testing.T) { }} for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Prepare the upstream. - ups := &aghtest.TestBlockUpstream{ - Hostname: hostname, - Block: tc.block, - } - d.SetSafeBrowsingUpstream(ups) - d.SetParentalUpstream(ups) + // Prepare the upstream. + ups := &aghtest.TestBlockUpstream{ + Hostname: hostname, + Block: tc.block, + } + d.SetSafeBrowsingUpstream(ups) + d.SetParentalUpstream(ups) + t.Run(tc.name, func(t *testing.T) { // Firstly, check the request blocking. hits := 0 res, err := tc.testFunc(hostname) - assert.Nil(t, err) + require.Nil(t, err) if tc.block { assert.True(t, res.IsFiltered) - assert.Len(t, res.Rules, 1) + require.Len(t, res.Rules, 1) hits++ } else { - assert.False(t, res.IsFiltered) + require.False(t, res.IsFiltered) } // Check the cache state, check the response is now cached. @@ -185,12 +186,12 @@ func TestSBPC(t *testing.T) { // Now make the same request to check the cache was used. res, err = tc.testFunc(hostname) - assert.Nil(t, err) + require.Nil(t, err) if tc.block { assert.True(t, res.IsFiltered) - assert.Len(t, res.Rules, 1) + require.Len(t, res.Rules, 1) } else { - assert.False(t, res.IsFiltered) + require.False(t, res.IsFiltered) } // Check the cache state, it should've been used. @@ -199,8 +200,8 @@ func TestSBPC(t *testing.T) { // Check that there were no additional requests. assert.Equal(t, 1, ups.RequestsCount()) - - purgeCaches() }) + + purgeCaches() } } diff --git a/internal/dnsforward/access_test.go b/internal/dnsforward/access_test.go index af13b02e..6c8a3766 100644 --- a/internal/dnsforward/access_test.go +++ b/internal/dnsforward/access_test.go @@ -5,71 +5,153 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestIsBlockedIPAllowed(t *testing.T) { - a := &accessCtx{} - assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil)) +func TestIsBlockedIP(t *testing.T) { + const ( + ip int = iota + cidr + ) - disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1)) - assert.False(t, disallowed) - assert.Empty(t, disallowedRule) + rules := []string{ + ip: "1.1.1.1", + cidr: "2.2.0.0/16", + } - disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2)) - assert.True(t, disallowed) - assert.Empty(t, disallowedRule) + testCases := []struct { + name string + allowed bool + ip net.IP + wantDis bool + wantRule string + }{{ + name: "allow_ip", + allowed: true, + ip: net.IPv4(1, 1, 1, 1), + wantDis: false, + wantRule: "", + }, { + name: "disallow_ip", + allowed: true, + ip: net.IPv4(1, 1, 1, 2), + wantDis: true, + wantRule: "", + }, { + name: "allow_cidr", + allowed: true, + ip: net.IPv4(2, 2, 1, 1), + wantDis: false, + wantRule: "", + }, { + name: "disallow_cidr", + allowed: true, + ip: net.IPv4(2, 3, 1, 1), + wantDis: true, + wantRule: "", + }, { + name: "allow_ip", + allowed: false, + ip: net.IPv4(1, 1, 1, 1), + wantDis: true, + wantRule: rules[ip], + }, { + name: "disallow_ip", + allowed: false, + ip: net.IPv4(1, 1, 1, 2), + wantDis: false, + wantRule: "", + }, { + name: "allow_cidr", + allowed: false, + ip: net.IPv4(2, 2, 1, 1), + wantDis: true, + wantRule: rules[cidr], + }, { + name: "disallow_cidr", + allowed: false, + ip: net.IPv4(2, 3, 1, 1), + wantDis: false, + wantRule: "", + }} - disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1)) - assert.False(t, disallowed) - assert.Empty(t, disallowedRule) + for _, tc := range testCases { + prefix := "allowed_" + if !tc.allowed { + prefix = "disallowed_" + } - disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1)) - assert.True(t, disallowed) - assert.Empty(t, disallowedRule) + t.Run(prefix+tc.name, func(t *testing.T) { + aCtx := &accessCtx{} + allowedRules := rules + var disallowedRules []string + + if !tc.allowed { + allowedRules, disallowedRules = disallowedRules, allowedRules + } + + require.Nil(t, aCtx.Init(allowedRules, disallowedRules, nil)) + + disallowed, rule := aCtx.IsBlockedIP(tc.ip) + assert.Equal(t, tc.wantDis, disallowed) + assert.Equal(t, tc.wantRule, rule) + }) + } } -func TestIsBlockedIPDisallowed(t *testing.T) { - a := &accessCtx{} - assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil)) - - disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1)) - assert.True(t, disallowed) - assert.Equal(t, "1.1.1.1", disallowedRule) - - disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2)) - assert.False(t, disallowed) - assert.Empty(t, disallowedRule) - - disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1)) - assert.True(t, disallowed) - assert.Equal(t, "2.2.0.0/16", disallowedRule) - - disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1)) - assert.False(t, disallowed) - assert.Empty(t, disallowedRule) -} - -func TestIsBlockedIPBlockedDomain(t *testing.T) { - a := &accessCtx{} - assert.True(t, a.Init(nil, nil, []string{ +func TestIsBlockedDomain(t *testing.T) { + aCtx := &accessCtx{} + require.Nil(t, aCtx.Init(nil, nil, []string{ "host1", - "host2", "*.host.com", "||host3.com^", - }) == nil) + })) - // match by "host2.com" - assert.True(t, a.IsBlockedDomain("host1")) - assert.True(t, a.IsBlockedDomain("host2")) - assert.False(t, a.IsBlockedDomain("host3")) + testCases := []struct { + name string + domain string + want bool + }{{ + name: "plain_match", + domain: "host1", + want: true, + }, { + name: "plain_mismatch", + domain: "host2", + want: false, + }, { + name: "wildcard_type-1_match_short", + domain: "asdf.host.com", + want: true, + }, { + name: "wildcard_type-1_match_long", + domain: "qwer.asdf.host.com", + want: true, + }, { + name: "wildcard_type-1_mismatch_no-lead", + domain: "host.com", + want: false, + }, { + name: "wildcard_type-1_mismatch_bad-asterisk", + domain: "asdf.zhost.com", + want: false, + }, { + name: "wildcard_type-2_match_simple", + domain: "host3.com", + want: true, + }, { + name: "wildcard_type-2_match_complex", + domain: "asdf.host3.com", + want: true, + }, { + name: "wildcard_type-2_mismatch", + domain: ".host3.com", + want: false, + }} - // match by wildcard "*.host.com" - assert.False(t, a.IsBlockedDomain("host.com")) - assert.True(t, a.IsBlockedDomain("asdf.host.com")) - assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com")) - assert.False(t, a.IsBlockedDomain("asdf.zhost.com")) - - // match by wildcard "||host3.com^" - assert.True(t, a.IsBlockedDomain("host3.com")) - assert.True(t, a.IsBlockedDomain("asdf.host3.com")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain)) + }) + } } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 93b6bcb7..81ca4e94 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -27,6 +27,7 @@ import ( "github.com/AdguardTeam/dnsproxy/upstream" "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -42,14 +43,180 @@ func startDeferStop(t *testing.T, s *Server) { t.Helper() err := s.Start() - assert.Nilf(t, err, "failed to start server: %s", err) + require.Nilf(t, err, "failed to start server: %s", err) t.Cleanup(func() { err := s.Stop() - assert.Nilf(t, err, "dns server failed to stop: %s", err) + require.Nilf(t, err, "dns server failed to stop: %s", err) }) } +func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server { + t.Helper() + + rules := `||nxdomain.example.org +||null.example.org^ +127.0.0.1 host.example.org +@@||whitelist.example.org^ +||127.0.0.255` + filters := []dnsfilter.Filter{{ + ID: 0, Data: []byte(rules), + }} + + f := dnsfilter.New(filterConf, filters) + + s := NewServer(DNSCreateParams{DNSFilter: f}) + s.conf = forwardConf + require.Nil(t, s.Prepare(nil)) + + return s +} + +func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.Nilf(t, err, "cannot generate RSA key: %s", err) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.Nilf(t, err, "failed to generate serial number: %s", err) + + notBefore := time.Now() + notAfter := notBefore.Add(5 * 365 * time.Hour * 24) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"AdGuard Tests"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + template.DNSNames = append(template.DNSNames, tlsServerName) + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey) + require.Nilf(t, err, "failed to create certificate: %s", err) + + certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + + cert, err := tls.X509KeyPair(certPem, keyPem) + require.Nilf(t, err, "failed to create certificate: %s", err) + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + ServerName: tlsServerName, + MinVersion: tls.VersionTLS12, + }, certPem, keyPem +} + +func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) { + t.Helper() + + var keyPem []byte + _, certPem, keyPem = createServerTLSConfig(t) + + s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + }) + + tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem + s.conf.TLSConfig = tlsConf + + err := s.Prepare(nil) + require.Nilf(t, err, "failed to prepare server: %s", err) + + return s, certPem +} + +func createGoogleATestMessage() *dns.Msg { + return createTestMessage("google-public-dns-a.google.com.") +} + +func createTestMessage(host string) *dns.Msg { + return &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: dns.Id(), + RecursionDesired: true, + }, + Question: []dns.Question{{ + Name: host, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, + } +} + +func createTestMessageWithType(host string, qtype uint16) *dns.Msg { + req := createTestMessage(host) + req.Question[0].Qtype = qtype + + return req +} + +func assertGoogleAResponse(t *testing.T, reply *dns.Msg) { + assertResponse(t, reply, net.IP{8, 8, 8, 8}) +} + +func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) { + t.Helper() + + require.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) + + a, ok := reply.Answer[0].(*dns.A) + require.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0]) + assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A) +} + +// sendTestMessagesAsync sends messages in parallel to check for race issues. +// +//lint:ignore U1000 it's called from the function which is skipped for now. +func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) { + t.Helper() + + wg := &sync.WaitGroup{} + + for i := 0; i < testMessagesCount; i++ { + msg := createGoogleATestMessage() + wg.Add(1) + + go func() { + defer wg.Done() + + err := conn.WriteMsg(msg) + require.Nilf(t, err, "cannot write message: %s", err) + + res, err := conn.ReadMsg() + require.Nilf(t, err, "cannot read response to message: %s", err) + + assertGoogleAResponse(t, res) + }() + } + + wg.Wait() +} + +func sendTestMessages(t *testing.T, conn *dns.Conn) { + t.Helper() + + for i := 0; i < testMessagesCount; i++ { + req := createGoogleATestMessage() + err := conn.WriteMsg(req) + assert.Nilf(t, err, "cannot write message #%d: %s", i, err) + + res, err := conn.ReadMsg() + assert.Nilf(t, err, "cannot read response to message #%d: %s", i, err) + assertGoogleAResponse(t, res) + } +} + func TestServer(t *testing.T) { s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ UDPListenAddr: &net.UDPAddr{}, @@ -81,7 +248,7 @@ func TestServer(t *testing.T) { client := dns.Client{Net: tc.proto} reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String()) - assert.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err) + require.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err) assertGoogleAResponse(t, reply) }) @@ -106,31 +273,12 @@ func TestServerWithProtectionDisabled(t *testing.T) { req := createGoogleATestMessage() addr := s.dnsProxy.Addr(proxy.ProtoUDP) client := dns.Client{Net: proxy.ProtoUDP} + reply, _, err := client.Exchange(req, addr.String()) - assert.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err) + require.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err) assertGoogleAResponse(t, reply) } -func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) { - t.Helper() - - var keyPem []byte - _, certPem, keyPem = createServerTLSConfig(t) - - s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{ - UDPListenAddr: &net.UDPAddr{}, - TCPListenAddr: &net.TCPAddr{}, - }) - - tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem - s.conf.TLSConfig = tlsConf - - err := s.Prepare(nil) - assert.Nilf(t, err, "failed to prepare server: %s", err) - - return s, certPem -} - func TestDoTServer(t *testing.T) { s, certPem := createTestTLS(t, TLSConfig{ TLSListenAddr: &net.TCPAddr{}, @@ -156,7 +304,7 @@ func TestDoTServer(t *testing.T) { // Create a DNS-over-TLS client connection. addr := s.dnsProxy.Addr(proxy.ProtoTLS) conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig) - assert.Nilf(t, err, "cannot connect to the proxy: %s", err) + require.Nilf(t, err, "cannot connect to the proxy: %s", err) sendTestMessages(t, conn) } @@ -178,12 +326,12 @@ func TestDoQServer(t *testing.T) { addr := s.dnsProxy.Addr(proxy.ProtoQUIC) opts := upstream.Options{InsecureSkipVerify: true} u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts) - assert.Nil(t, err) + require.Nil(t, err) // Send the test message. req := createGoogleATestMessage() res, err := u.Exchange(req) - assert.Nil(t, err) + require.Nil(t, err) assertGoogleAResponse(t, res) } @@ -221,7 +369,7 @@ func TestServerRace(t *testing.T) { // Message over UDP. addr := s.dnsProxy.Addr(proxy.ProtoUDP) conn, err := dns.Dial(proxy.ProtoUDP, addr.String()) - assert.Nilf(t, err, "cannot connect to the proxy: %s", err) + require.Nilf(t, err, "cannot connect to the proxy: %s", err) sendTestMessagesAsync(t, conn) } @@ -282,8 +430,9 @@ func TestSafeSearch(t *testing.T) { for _, tc := range testCases { t.Run(tc.host, func(t *testing.T) { req := createTestMessage(tc.host) + reply, _, err := client.Exchange(req, addr) - assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) assertResponse(t, reply, tc.want) }) } @@ -330,8 +479,10 @@ func TestBlockedRequest(t *testing.T) { req := createTestMessage("nxdomain.example.org.") reply, err := dns.Exchange(req, addr.String()) - assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) + + require.Len(t, reply.Answer, 1) assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified()) } @@ -364,28 +515,14 @@ func TestServerCustomClientUpstream(t *testing.T) { reply, err := dns.Exchange(req, addr.String()) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) - assert.NotEmpty(t, reply.Answer) + require.NotEmpty(t, reply.Answer) + require.Len(t, reply.Answer, 1) assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A) } -func (s *Server) startWithUpstream(u upstream.Upstream) error { - s.Lock() - defer s.Unlock() - err := s.Prepare(nil) - if err != nil { - return err - } - - s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{ - Upstreams: []upstream.Upstream{u}, - } - - return s.dnsProxy.Start() -} - // testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work. var testCNAMEs = map[string]string{ "badhost.": "null.example.org.", @@ -409,15 +546,19 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { IPv6: nil, } s.conf.ProtectionEnabled = false - err := s.startWithUpstream(testUpstm) - assert.Nil(t, err) + s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{ + Upstreams: []upstream.Upstream{testUpstm}, + } + startDeferStop(t, s) + addr := s.dnsProxy.Addr(proxy.ProtoUDP) - // 'badhost' has a canonical name 'null.example.org' which is blocked by - // filters: but protection is disabled so response is _not_ blocked. + // 'badhost' has a canonical name 'null.example.org' which should be + // blocked by filters, but protection is disabled so it is not. req := createTestMessage("badhost.") + reply, err := dns.Exchange(req, addr.String()) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) } @@ -465,11 +606,15 @@ func TestBlockCNAME(t *testing.T) { for _, tc := range testCases { t.Run("block_cname_"+tc.host, func(t *testing.T) { req := createTestMessage(tc.host) + reply, err := dns.Exchange(req, addr) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) if tc.want { - assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified()) + require.Len(t, reply.Answer, 1) + a, ok := reply.Answer[0].(*dns.A) + require.True(t, ok) + assert.True(t, a.A.IsUnspecified()) } }) } @@ -513,7 +658,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { // However, in our case it should not be blocked as filtering is // disabled on the client level. reply, err := dns.Exchange(&req, addr.String()) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) } @@ -544,10 +689,10 @@ func TestNullBlockedRequest(t *testing.T) { } reply, err := dns.Exchange(&req, addr.String()) - assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) - assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) + require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) a, ok := reply.Answer[0].(*dns.A) - assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) + require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) assert.Truef(t, a.A.IsUnspecified(), "dns server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A) } @@ -561,7 +706,7 @@ func TestBlockedCustomIP(t *testing.T) { s := NewServer(DNSCreateParams{ DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters), }) - conf := ServerConfig{ + conf := &ServerConfig{ UDPListenAddr: &net.UDPAddr{}, TCPListenAddr: &net.TCPAddr{}, FilteringConfig: FilteringConfig{ @@ -572,11 +717,11 @@ func TestBlockedCustomIP(t *testing.T) { }, } // Invalid BlockingIPv4. - assert.NotNil(t, s.Prepare(&conf)) + assert.NotNil(t, s.Prepare(conf)) conf.BlockingIPv4 = net.IP{0, 0, 0, 1} conf.BlockingIPv6 = net.ParseIP("::1") - assert.Nil(t, s.Prepare(&conf)) + require.Nil(t, s.Prepare(conf)) startDeferStop(t, s) @@ -584,18 +729,18 @@ func TestBlockedCustomIP(t *testing.T) { req := createTestMessageWithType("null.example.org.", dns.TypeA) reply, err := dns.Exchange(req, addr.String()) - assert.Nil(t, err) - assert.Len(t, reply.Answer, 1) + require.Nil(t, err) + require.Len(t, reply.Answer, 1) a, ok := reply.Answer[0].(*dns.A) - assert.True(t, ok) + require.True(t, ok) assert.True(t, net.IP{0, 0, 0, 1}.Equal(a.A)) req = createTestMessageWithType("null.example.org.", dns.TypeAAAA) reply, err = dns.Exchange(req, addr.String()) - assert.Nil(t, err) - assert.Len(t, reply.Answer, 1) + require.Nil(t, err) + require.Len(t, reply.Answer, 1) a6, ok := reply.Answer[0].(*dns.AAAA) - assert.True(t, ok) + require.True(t, ok) assert.Equal(t, "::1", a6.AAAA.String()) } @@ -615,11 +760,10 @@ func TestBlockedByHosts(t *testing.T) { req := createTestMessage("host.example.org.") reply, err := dns.Exchange(req, addr.String()) - assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) - assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) - + require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) a, ok := reply.Answer[0].(*dns.A) - assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) + require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) assert.Equalf(t, net.IP{127, 0, 0, 1}, a.A, "dns server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A) } @@ -630,7 +774,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { Hostname: hostname, Block: true, } - ans, _ := (&aghtest.TestResolver{}).HostToIPs(hostname) + ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname) filterConf := &dnsfilter.Config{ SafeBrowsingEnabled: true, @@ -639,7 +783,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { UDPListenAddr: &net.UDPAddr{}, TCPListenAddr: &net.TCPAddr{}, FilteringConfig: FilteringConfig{ - SafeBrowsingBlockHost: ans.String(), + SafeBrowsingBlockHost: ans4.String(), ProtectionEnabled: true, }, } @@ -652,13 +796,12 @@ func TestBlockedBySafeBrowsing(t *testing.T) { req := createTestMessage(hostname + ".") reply, err := dns.Exchange(req, addr.String()) - assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) - assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) + require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) a, ok := reply.Answer[0].(*dns.A) - if assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) { - assert.Equal(t, ans, a.A, "dns server %s returned wrong answer: %v", addr, a.A) - } + require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) + assert.Equal(t, ans4, a.A, "dns server %s returned wrong answer: %v", addr, a.A) } func TestRewrite(t *testing.T) { @@ -680,14 +823,14 @@ func TestRewrite(t *testing.T) { f := dnsfilter.New(c, nil) s := NewServer(DNSCreateParams{DNSFilter: f}) - err := s.Prepare(&ServerConfig{ + assert.Nil(t, s.Prepare(&ServerConfig{ UDPListenAddr: &net.UDPAddr{}, TCPListenAddr: &net.TCPAddr{}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, UpstreamDNS: []string{"8.8.8.8:53"}, }, - }) + })) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.TestUpstream{ CName: map[string]string{ @@ -698,185 +841,44 @@ func TestRewrite(t *testing.T) { }, }, } - assert.Nil(t, err) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) req := createTestMessageWithType("test.com.", dns.TypeA) reply, err := dns.Exchange(req, addr.String()) - assert.Nil(t, err) - assert.Len(t, reply.Answer, 1) + require.Nil(t, err) + require.Len(t, reply.Answer, 1) a, ok := reply.Answer[0].(*dns.A) - assert.True(t, ok) + require.True(t, ok) assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A)) req = createTestMessageWithType("test.com.", dns.TypeAAAA) reply, err = dns.Exchange(req, addr.String()) - assert.Nil(t, err) + require.Nil(t, err) assert.Empty(t, reply.Answer) req = createTestMessageWithType("alias.test.com.", dns.TypeA) reply, err = dns.Exchange(req, addr.String()) - assert.Nil(t, err) - assert.Len(t, reply.Answer, 2) + require.Nil(t, err) + + require.Len(t, reply.Answer, 2) assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) reply, err = dns.Exchange(req, addr.String()) - assert.Nil(t, err) + require.Nil(t, err) + // The original question is restored. + require.Len(t, reply.Question, 1) assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) - assert.Len(t, reply.Answer, 2) + + require.Len(t, reply.Answer, 2) assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) } -func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server { - rules := `||nxdomain.example.org -||null.example.org^ -127.0.0.1 host.example.org -@@||whitelist.example.org^ -||127.0.0.255` - filters := []dnsfilter.Filter{{ - ID: 0, Data: []byte(rules), - }} - - f := dnsfilter.New(filterConf, filters) - - s := NewServer(DNSCreateParams{DNSFilter: f}) - s.conf = forwardConf - assert.Nil(t, s.Prepare(nil)) - - return s -} - -func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) { - t.Helper() - - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - assert.Nilf(t, err, "cannot generate RSA key: %s", err) - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - assert.Nilf(t, err, "failed to generate serial number: %s", err) - - notBefore := time.Now() - notAfter := notBefore.Add(5 * 365 * time.Hour * 24) - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"AdGuard Tests"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - IsCA: true, - } - template.DNSNames = append(template.DNSNames, tlsServerName) - - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey) - assert.Nilf(t, err, "failed to create certificate: %s", err) - - certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) - - cert, err := tls.X509KeyPair(certPem, keyPem) - assert.Nilf(t, err, "failed to create certificate: %s", err) - - return &tls.Config{ - Certificates: []tls.Certificate{cert}, - ServerName: tlsServerName, - MinVersion: tls.VersionTLS12, - }, certPem, keyPem -} - -// sendTestMessagesAsync sends messages in parallel to check for race issues. -//lint:ignore U1000 it's called from the function which is skipped for now. -func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) { - wg := &sync.WaitGroup{} - - for i := 0; i < testMessagesCount; i++ { - msg := createGoogleATestMessage() - wg.Add(1) - - go func() { - defer wg.Done() - - err := conn.WriteMsg(msg) - assert.Nilf(t, err, "cannot write message: %s", err) - - res, err := conn.ReadMsg() - assert.Nilf(t, err, "cannot read response to message: %s", err) - - assertGoogleAResponse(t, res) - }() - } - - wg.Wait() -} - -func sendTestMessages(t *testing.T, conn *dns.Conn) { - t.Helper() - - for i := 0; i < testMessagesCount; i++ { - req := createGoogleATestMessage() - err := conn.WriteMsg(req) - assert.Nilf(t, err, "cannot write message #%d: %s", i, err) - - res, err := conn.ReadMsg() - assert.Nilf(t, err, "cannot read response to message #%d: %s", i, err) - assertGoogleAResponse(t, res) - } -} - -func createGoogleATestMessage() *dns.Msg { - return createTestMessage("google-public-dns-a.google.com.") -} - -func createTestMessage(host string) *dns.Msg { - return &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Id: dns.Id(), - RecursionDesired: true, - }, - Question: []dns.Question{{ - Name: host, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } -} - -func createTestMessageWithType(host string, qtype uint16) *dns.Msg { - req := createTestMessage(host) - req.Question[0].Qtype = qtype - - return req -} - -func assertGoogleAResponse(t *testing.T, reply *dns.Msg) { - assertResponse(t, reply, net.IP{8, 8, 8, 8}) -} - -func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) { - t.Helper() - - if !assert.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) { - return - } - - a, ok := reply.Answer[0].(*dns.A) - if assert.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0]) { - assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A) - } -} - func publicKey(priv interface{}) interface{} { switch k := priv.(type) { case *rsa.PrivateKey: @@ -966,8 +968,8 @@ func TestValidateUpstream(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { defaultUpstream, err := validateUpstream(tc.upstream) - assert.Equal(t, tc.valid, err == nil) - if err == nil { + require.Equal(t, tc.valid, err == nil) + if tc.valid { assert.Equal(t, tc.wantDef, defaultUpstream) } }) @@ -975,42 +977,73 @@ func TestValidateUpstream(t *testing.T) { } func TestValidateUpstreamsSet(t *testing.T) { - // Empty upstreams array. - var upstreamsSet []string - assert.Nil(t, ValidateUpstreams(upstreamsSet), "empty upstreams array should be valid") + testCases := []struct { + name string + msg string + set []string + wantNil bool + }{{ + name: "empty", + msg: "empty upstreams array should be valid", + set: nil, + wantNil: true, + }, { + name: "comment", + msg: "comments should not be validated", + set: []string{"# comment"}, + wantNil: true, + }, { + name: "valid_no_default", + msg: "there is no default upstream", + set: []string{ + "[/host.com/]1.1.1.1", + "[//]tls://1.1.1.1", + "[/www.host.com/]#", + "[/host.com/google.com/]8.8.8.8", + "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + }, + wantNil: false, + }, { + name: "valid_with_default", + msg: "upstreams set is valid, but doesn't pass through validation cause: %s", + set: []string{ + "[/host.com/]1.1.1.1", + "[//]tls://1.1.1.1", + "[/www.host.com/]#", + "[/host.com/google.com/]8.8.8.8", + "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + "8.8.8.8", + }, + wantNil: true, + }, { + name: "invalid", + msg: "there is an invalid upstream in set, but it pass through validation", + set: []string{"dhcp://fake.dns"}, + wantNil: false, + }} - // Comment in upstreams array. - upstreamsSet = []string{"# comment"} - assert.Nil(t, ValidateUpstreams(upstreamsSet), "comments should not be validated") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateUpstreams(tc.set) - // Set of valid upstreams. There is no default upstream specified. - upstreamsSet = []string{ - "[/host.com/]1.1.1.1", - "[//]tls://1.1.1.1", - "[/www.host.com/]#", - "[/host.com/google.com/]8.8.8.8", - "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err) + }) } - assert.NotNil(t, ValidateUpstreams(upstreamsSet), "there is no default upstream") - - // Let's add default upstream. - upstreamsSet = append(upstreamsSet, "8.8.8.8") - err := ValidateUpstreams(upstreamsSet) - assert.Nilf(t, err, "upstreams set is valid, but doesn't pass through validation cause: %s", err) - - // Let's add invalid upstream. - upstreamsSet = append(upstreamsSet, "dhcp://fake.dns") - assert.NotNil(t, ValidateUpstreams(upstreamsSet), "there is an invalid upstream in set, but it pass through validation") } func TestIPStringFromAddr(t *testing.T) { - addr := net.UDPAddr{ - IP: net.ParseIP("1:2:3::4"), - Port: 12345, - Zone: "eth0", - } - assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String()) - assert.Empty(t, IPStringFromAddr(nil)) + t.Run("not_nil", func(t *testing.T) { + addr := net.UDPAddr{ + IP: net.ParseIP("1:2:3::4"), + Port: 12345, + Zone: "eth0", + } + assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String()) + }) + + t.Run("nil", func(t *testing.T) { + assert.Empty(t, IPStringFromAddr(nil)) + }) } func TestMatchDNSName(t *testing.T) { @@ -1071,38 +1104,33 @@ func (d *testDHCP) Leases(flags int) []dhcpd.Lease { func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} func TestPTRResponseFromDHCPLeases(t *testing.T) { - dhcp := &testDHCP{} - s := NewServer(DNSCreateParams{ DNSFilter: dnsfilter.New(&dnsfilter.Config{}, nil), - DHCPServer: dhcp, + DHCPServer: &testDHCP{}, }) s.conf.UDPListenAddr = &net.UDPAddr{} s.conf.TCPListenAddr = &net.TCPAddr{} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.FilteringConfig.ProtectionEnabled = true - err := s.Prepare(nil) - assert.Nil(t, err) + require.Nil(t, s.Prepare(nil)) + require.Nil(t, s.Start()) + t.Cleanup(func() { + s.Close() + }) - assert.Nil(t, s.Start()) addr := s.dnsProxy.Addr(proxy.ProtoUDP) - req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) resp, err := dns.Exchange(req, addr.String()) - - assert.Nil(t, err) - assert.Len(t, resp.Answer, 1) + require.Nil(t, err) + require.Len(t, resp.Answer, 1) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) ptr, ok := resp.Answer[0].(*dns.PTR) - if assert.True(t, ok) { - assert.Equal(t, "localhost.", ptr.Ptr) - } - - s.Close() + require.True(t, ok) + assert.Equal(t, "localhost.", ptr.Ptr) } func TestPTRResponseFromHosts(t *testing.T) { @@ -1112,12 +1140,11 @@ func TestPTRResponseFromHosts(t *testing.T) { // Prepare test hosts file. hf, err := ioutil.TempFile("", "") - if assert.Nil(t, err) { - t.Cleanup(func() { - assert.Nil(t, hf.Close()) - assert.Nil(t, os.Remove(hf.Name())) - }) - } + require.Nil(t, err) + t.Cleanup(func() { + assert.Nil(t, hf.Close()) + assert.Nil(t, os.Remove(hf.Name())) + }) _, _ = hf.WriteString(" 127.0.0.1 host # comment \n") _, _ = hf.WriteString(" ::1 localhost#comment \n") @@ -1131,23 +1158,23 @@ func TestPTRResponseFromHosts(t *testing.T) { s.conf.TCPListenAddr = &net.TCPAddr{} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.FilteringConfig.ProtectionEnabled = true - assert.Nil(t, s.Prepare(nil)) + require.Nil(t, s.Prepare(nil)) - assert.Nil(t, s.Start()) + require.Nil(t, s.Start()) + t.Cleanup(func() { + s.Close() + }) addr := s.dnsProxy.Addr(proxy.ProtoUDP) req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) resp, err := dns.Exchange(req, addr.String()) - assert.Nil(t, err) - assert.Len(t, resp.Answer, 1) + require.Nil(t, err) + require.Len(t, resp.Answer, 1) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) ptr, ok := resp.Answer[0].(*dns.PTR) - if assert.True(t, ok) { - assert.Equal(t, "host.", ptr.Ptr) - } - - s.Close() + require.True(t, ok) + assert.Equal(t, "host.", ptr.Ptr) } diff --git a/internal/dnsforward/dnsrewrite_test.go b/internal/dnsforward/dnsrewrite_test.go index 4029038a..38f81081 100644 --- a/internal/dnsforward/dnsrewrite_test.go +++ b/internal/dnsforward/dnsrewrite_test.go @@ -9,6 +9,7 @@ import ( "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestServer_FilterDNSRewrite(t *testing.T) { @@ -54,7 +55,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) { d := &proxy.DNSContext{} err := srv.filterDNSRewrite(req, res, d) - assert.Nil(t, err) + + require.Nil(t, err) assert.Equal(t, dns.RcodeNameError, d.Res.Rcode) }) @@ -64,7 +66,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { d := &proxy.DNSContext{} err := srv.filterDNSRewrite(req, res, d) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Empty(t, d.Res.Answer) }) @@ -75,11 +77,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) { d := &proxy.DNSContext{} err := srv.filterDNSRewrite(req, res, d) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) - if assert.Len(t, d.Res.Answer, 1) { - assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A) - } + + require.Len(t, d.Res.Answer, 1) + assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A) }) t.Run("noerror_aaaa", func(t *testing.T) { @@ -88,11 +90,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) { d := &proxy.DNSContext{} err := srv.filterDNSRewrite(req, res, d) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) - if assert.Len(t, d.Res.Answer, 1) { - assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA) - } + + require.Len(t, d.Res.Answer, 1) + assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA) }) t.Run("noerror_ptr", func(t *testing.T) { @@ -101,11 +103,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) { d := &proxy.DNSContext{} err := srv.filterDNSRewrite(req, res, d) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) - if assert.Len(t, d.Res.Answer, 1) { - assert.Equal(t, domain, d.Res.Answer[0].(*dns.PTR).Ptr) - } + + require.Len(t, d.Res.Answer, 1) + assert.Equal(t, domain, d.Res.Answer[0].(*dns.PTR).Ptr) }) t.Run("noerror_txt", func(t *testing.T) { @@ -114,11 +116,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) { d := &proxy.DNSContext{} err := srv.filterDNSRewrite(req, res, d) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) - if assert.Len(t, d.Res.Answer, 1) { - assert.Equal(t, []string{domain}, d.Res.Answer[0].(*dns.TXT).Txt) - } + + require.Len(t, d.Res.Answer, 1) + assert.Equal(t, []string{domain}, d.Res.Answer[0].(*dns.TXT).Txt) }) t.Run("noerror_mx", func(t *testing.T) { @@ -127,15 +129,15 @@ func TestServer_FilterDNSRewrite(t *testing.T) { d := &proxy.DNSContext{} err := srv.filterDNSRewrite(req, res, d) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) - if assert.Len(t, d.Res.Answer, 1) { - ans, ok := d.Res.Answer[0].(*dns.MX) - if assert.True(t, ok) { - assert.Equal(t, mx.Exchange, ans.Mx) - assert.Equal(t, mx.Preference, ans.Preference) - } - } + + require.Len(t, d.Res.Answer, 1) + ans, ok := d.Res.Answer[0].(*dns.MX) + + require.True(t, ok) + assert.Equal(t, mx.Exchange, ans.Mx) + assert.Equal(t, mx.Preference, ans.Preference) }) t.Run("noerror_svcb", func(t *testing.T) { @@ -144,17 +146,17 @@ func TestServer_FilterDNSRewrite(t *testing.T) { d := &proxy.DNSContext{} err := srv.filterDNSRewrite(req, res, d) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) - if assert.Len(t, d.Res.Answer, 1) { - ans, ok := d.Res.Answer[0].(*dns.SVCB) - if assert.True(t, ok) { - assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key()) - assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String()) - assert.Equal(t, svcb.Target, ans.Target) - assert.Equal(t, svcb.Priority, ans.Priority) - } - } + + require.Len(t, d.Res.Answer, 1) + ans, ok := d.Res.Answer[0].(*dns.SVCB) + require.True(t, ok) + + assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key()) + assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String()) + assert.Equal(t, svcb.Target, ans.Target) + assert.Equal(t, svcb.Priority, ans.Priority) }) t.Run("noerror_https", func(t *testing.T) { @@ -163,16 +165,16 @@ func TestServer_FilterDNSRewrite(t *testing.T) { d := &proxy.DNSContext{} err := srv.filterDNSRewrite(req, res, d) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) - if assert.Len(t, d.Res.Answer, 1) { - ans, ok := d.Res.Answer[0].(*dns.HTTPS) - if assert.True(t, ok) { - assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key()) - assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String()) - assert.Equal(t, svcb.Target, ans.Target) - assert.Equal(t, svcb.Priority, ans.Priority) - } - } + + require.Len(t, d.Res.Answer, 1) + ans, ok := d.Res.Answer[0].(*dns.HTTPS) + + require.True(t, ok) + assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key()) + assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String()) + assert.Equal(t, svcb.Target, ans.Target) + assert.Equal(t, svcb.Priority, ans.Priority) }) } diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index f1ba7031..302ce34d 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) { @@ -31,9 +32,10 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) { ConfigModified: func() {}, } s := createTestServer(t, filterConf, forwardConf) - err := s.Start() - assert.Nil(t, err) - defer assert.Nil(t, s.Stop()) + require.Nil(t, s.Start()) + t.Cleanup(func() { + require.Nil(t, s.Stop()) + }) defaultConf := s.conf @@ -71,13 +73,14 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + t.Cleanup(w.Body.Reset) + s.conf = tc.conf() s.handleGetConfig(w, nil) - assert.Equal(t, tc.want, w.Body.String()) assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + assert.Equal(t, tc.want, w.Body.String()) }) - w.Body.Reset() } } @@ -191,9 +194,13 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + t.Cleanup(func() { + s.conf = defaultConf + }) + rBody := ioutil.NopCloser(strings.NewReader(tc.req)) r, err := http.NewRequest(http.MethodPost, "http://example.com", rBody) - assert.Nil(t, err) + require.Nil(t, err) s.handleSetConfig(w, r) assert.Equal(t, tc.wantSet, w.Body.String()) @@ -203,6 +210,5 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) { assert.Equal(t, tc.wantGet, w.Body.String()) w.Body.Reset() }) - s.conf = defaultConf } } diff --git a/internal/dnsforward/stats_test.go b/internal/dnsforward/stats_test.go index 3b5981bb..4eb2d3e5 100644 --- a/internal/dnsforward/stats_test.go +++ b/internal/dnsforward/stats_test.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/dnsproxy/upstream" "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // testQueryLog is a simple querylog.QueryLog implementation for tests. @@ -156,7 +157,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) { }} ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{}) - assert.Nil(t, err) + require.Nil(t, err) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { diff --git a/internal/home/auth_test.go b/internal/home/auth_test.go index 7dbbf3c6..00ce309c 100644 --- a/internal/home/auth_test.go +++ b/internal/home/auth_test.go @@ -20,10 +20,18 @@ func TestMain(m *testing.M) { aghtest.DiscardLogOutput(m) } -func prepareTestDir() string { +func prepareTestDir(t *testing.T) string { + t.Helper() + const dir = "./agh-test" - _ = os.RemoveAll(dir) - _ = os.MkdirAll(dir, 0o755) + + require.Nil(t, os.RemoveAll(dir)) + // TODO(e.burkov): Replace with testing.TempDir after updating Go + // version to 1.16. + require.Nil(t, os.MkdirAll(dir, 0o755)) + + t.Cleanup(func() { require.Nil(t, os.RemoveAll(dir)) }) + return dir } @@ -47,8 +55,7 @@ func TestNewSessionToken(t *testing.T) { } func TestAuth(t *testing.T) { - dir := prepareTestDir() - t.Cleanup(func() { _ = os.RemoveAll(dir) }) + dir := prepareTestDir(t) fn := filepath.Join(dir, "sessions.db") users := []User{{ @@ -123,8 +130,7 @@ func (w *testResponseWriter) WriteHeader(statusCode int) { } func TestAuthHTTP(t *testing.T) { - dir := prepareTestDir() - defer func() { _ = os.RemoveAll(dir) }() + dir := prepareTestDir(t) fn := filepath.Join(dir, "sessions.db") users := []User{ diff --git a/internal/home/authglinet_test.go b/internal/home/authglinet_test.go index 70bb6636..6f53c181 100644 --- a/internal/home/authglinet_test.go +++ b/internal/home/authglinet_test.go @@ -4,40 +4,38 @@ import ( "encoding/binary" "io/ioutil" "net/http" - "os" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAuthGL(t *testing.T) { - dir := prepareTestDir() - defer func() { _ = os.RemoveAll(dir) }() + dir := prepareTestDir(t) GLMode = true + t.Cleanup(func() { + GLMode = false + }) glFilePrefix = dir + "/gl_token_" - tval := uint32(1) - data := make([]byte, 4) + putFunc := binary.BigEndian.PutUint32 if archIsLittleEndian() { - binary.LittleEndian.PutUint32(data, tval) - } else { - binary.BigEndian.PutUint32(data, tval) + putFunc = binary.LittleEndian.PutUint32 } - assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644)) + + data := make([]byte, 4) + putFunc(data, 1) + + require.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644)) assert.False(t, glCheckToken("test")) - tval = uint32(time.Now().UTC().Unix() + 60) data = make([]byte, 4) - if archIsLittleEndian() { - binary.LittleEndian.PutUint32(data, tval) - } else { - binary.BigEndian.PutUint32(data, tval) - } - assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644)) + putFunc(data, uint32(time.Now().UTC().Unix()+60)) + + require.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644)) r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil) r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"}) assert.True(t, glProcessCookie(r)) - GLMode = false } diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index a098bf4c..bfd0c11b 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -9,6 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestClients(t *testing.T) { @@ -24,8 +25,8 @@ func TestClients(t *testing.T) { } ok, err := clients.Add(c) + require.Nil(t, err) assert.True(t, ok) - assert.Nil(t, err) c = &Client{ IDs: []string{"2.2.2.2"}, @@ -33,110 +34,99 @@ func TestClients(t *testing.T) { } ok, err = clients.Add(c) + require.Nil(t, err) assert.True(t, ok) - assert.Nil(t, err) c, ok = clients.Find("1.1.1.1") - assert.True(t, ok) + require.True(t, ok) assert.Equal(t, "client1", c.Name) c, ok = clients.Find("1:2:3::4") - assert.True(t, ok) + require.True(t, ok) assert.Equal(t, "client1", c.Name) c, ok = clients.Find("2.2.2.2") - assert.True(t, ok) + require.True(t, ok) assert.Equal(t, "client2", c.Name) - assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile)) + assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile)) }) t.Run("add_fail_name", func(t *testing.T) { - c := &Client{ + ok, err := clients.Add(&Client{ IDs: []string{"1.2.3.5"}, Name: "client1", - } - - ok, err := clients.Add(c) + }) + require.Nil(t, err) assert.False(t, ok) - assert.Nil(t, err) }) t.Run("add_fail_ip", func(t *testing.T) { - c := &Client{ + ok, err := clients.Add(&Client{ IDs: []string{"2.2.2.2"}, Name: "client3", - } - - ok, err := clients.Add(c) + }) + require.NotNil(t, err) assert.False(t, ok) - assert.NotNil(t, err) }) t.Run("update_fail_name", func(t *testing.T) { - c := &Client{ + err := clients.Update("client3", &Client{ IDs: []string{"1.2.3.0"}, Name: "client3", - } + }) + require.NotNil(t, err) - err := clients.Update("client3", c) - assert.NotNil(t, err) - - c = &Client{ + err = clients.Update("client3", &Client{ IDs: []string{"1.2.3.0"}, Name: "client2", - } - - err = clients.Update("client3", c) + }) assert.NotNil(t, err) }) t.Run("update_fail_ip", func(t *testing.T) { - c := &Client{ + err := clients.Update("client1", &Client{ IDs: []string{"2.2.2.2"}, Name: "client1", - } - - err := clients.Update("client1", c) + }) assert.NotNil(t, err) }) t.Run("update_success", func(t *testing.T) { - c := &Client{ + err := clients.Update("client1", &Client{ IDs: []string{"1.1.1.2"}, Name: "client1", - } + }) + require.Nil(t, err) - err := clients.Update("client1", c) - assert.Nil(t, err) - - assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile)) + assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) - c = &Client{ + err = clients.Update("client1", &Client{ IDs: []string{"1.1.1.2"}, Name: "client1-renamed", UseOwnSettings: true, - } - - err = clients.Update("client1", c) - assert.Nil(t, err) + }) + require.Nil(t, err) c, ok := clients.Find("1.1.1.2") - assert.True(t, ok) + require.True(t, ok) assert.Equal(t, "client1-renamed", c.Name) assert.True(t, c.UseOwnSettings) - assert.Nil(t, clients.list["client1"]) - if assert.Len(t, c.IDs, 1) { - assert.Equal(t, "1.1.1.2", c.IDs[0]) - } + + nilCli, ok := clients.list["client1"] + require.False(t, ok) + assert.Nil(t, nilCli) + + require.Len(t, c.IDs, 1) + assert.Equal(t, "1.1.1.2", c.IDs[0]) }) t.Run("del_success", func(t *testing.T) { ok := clients.Del("client1-renamed") - assert.True(t, ok) + require.True(t, ok) assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) }) @@ -147,146 +137,155 @@ func TestClients(t *testing.T) { t.Run("addhost_success", func(t *testing.T) { ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP) + require.Nil(t, err) assert.True(t, ok) - assert.Nil(t, err) ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP) + require.Nil(t, err) assert.True(t, ok) - assert.Nil(t, err) ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile) + require.Nil(t, err) assert.True(t, ok) - assert.Nil(t, err) assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) }) t.Run("addhost_fail", func(t *testing.T) { ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS) + require.Nil(t, err) assert.False(t, ok) - assert.Nil(t, err) }) } func TestClientsWhois(t *testing.T) { - var c *Client - clients := clientsContainer{} - clients.testing = true + clients := clientsContainer{ + testing: true, + } clients.Init(nil, nil, nil) - whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}} - // set whois info on new client - clients.SetWhoisInfo("1.1.1.255", whois) - if assert.NotNil(t, clients.ipHost["1.1.1.255"]) { + + t.Run("new_client", func(t *testing.T) { + clients.SetWhoisInfo("1.1.1.255", whois) + + require.NotNil(t, clients.ipHost["1.1.1.255"]) h := clients.ipHost["1.1.1.255"] - if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) { - assert.Equal(t, "orgname-val", h.WhoisInfo[0][1]) - } - } - // set whois info on existing auto-client - _, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) - clients.SetWhoisInfo("1.1.1.1", whois) - if assert.NotNil(t, clients.ipHost["1.1.1.1"]) { + require.Len(t, h.WhoisInfo, 2) + require.Len(t, h.WhoisInfo[0], 2) + assert.Equal(t, "orgname-val", h.WhoisInfo[0][1]) + }) + + t.Run("existing_auto-client", func(t *testing.T) { + ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) + require.Nil(t, err) + assert.True(t, ok) + + clients.SetWhoisInfo("1.1.1.1", whois) + + require.NotNil(t, clients.ipHost["1.1.1.1"]) h := clients.ipHost["1.1.1.1"] - if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) { - assert.Equal(t, "orgname-val", h.WhoisInfo[0][1]) - } - } - // Check that we cannot set whois info on a manually-added client - c = &Client{ - IDs: []string{"1.1.1.2"}, - Name: "client1", - } - _, _ = clients.Add(c) - clients.SetWhoisInfo("1.1.1.2", whois) - assert.Nil(t, clients.ipHost["1.1.1.2"]) - _ = clients.Del("client1") + require.Len(t, h.WhoisInfo, 2) + require.Len(t, h.WhoisInfo[0], 2) + assert.Equal(t, "orgname-val", h.WhoisInfo[0][1]) + }) + + t.Run("can't_set_manually-added", func(t *testing.T) { + ok, err := clients.Add(&Client{ + IDs: []string{"1.1.1.2"}, + Name: "client1", + }) + require.Nil(t, err) + assert.True(t, ok) + + clients.SetWhoisInfo("1.1.1.2", whois) + require.Nil(t, clients.ipHost["1.1.1.2"]) + assert.True(t, clients.Del("client1")) + }) } func TestClientsAddExisting(t *testing.T) { - var c *Client - clients := clientsContainer{} - clients.testing = true + clients := clientsContainer{ + testing: true, + } clients.Init(nil, nil, nil) - // some test variables - mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") - testIP := "1.2.3.4" + t.Run("simple", func(t *testing.T) { + // Add a client. + ok, err := clients.Add(&Client{ + IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"}, + Name: "client1", + }) + require.Nil(t, err) + assert.True(t, ok) - // add a client - c = &Client{ - IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"}, - Name: "client1", - } - ok, err := clients.Add(c) - assert.True(t, ok) - assert.Nil(t, err) - - // add an auto-client with the same IP - it's allowed - ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS) - assert.True(t, ok) - assert.Nil(t, err) - - // now some more complicated stuff - // first, init a DHCP server with a single static lease - config := dhcpd.ServerConfig{ - DBFilePath: "leases.db", - } - defer func() { _ = os.Remove("leases.db") }() - clients.dhcpServer = dhcpd.Create(config) - err = clients.dhcpServer.AddStaticLease(dhcpd.Lease{ - HWAddr: mac, - IP: net.ParseIP(testIP).To4(), - Hostname: "testhost", - Expiry: time.Now().Add(time.Hour), + // Now add an auto-client with the same IP. + ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS) + require.Nil(t, err) + assert.True(t, ok) }) - assert.Nil(t, err) - // add a new client with the same IP as for a client with MAC - c = &Client{ - IDs: []string{testIP}, - Name: "client2", - } - ok, err = clients.Add(c) - assert.True(t, ok) - assert.Nil(t, err) + t.Run("complicated", func(t *testing.T) { + testIP := net.IP{1, 2, 3, 4} - // add a new client with the IP from the client1's IP range - c = &Client{ - IDs: []string{"2.2.2.2"}, - Name: "client3", - } - ok, err = clients.Add(c) - assert.True(t, ok) - assert.Nil(t, err) + // First, init a DHCP server with a single static lease. + config := dhcpd.ServerConfig{ + DBFilePath: "leases.db", + } + clients.dhcpServer = dhcpd.Create(config) + t.Cleanup(func() { _ = os.Remove("leases.db") }) + + err := clients.dhcpServer.AddStaticLease(dhcpd.Lease{ + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + IP: testIP, + Hostname: "testhost", + Expiry: time.Now().Add(time.Hour), + }) + require.Nil(t, err) + + // Add a new client with the same IP as for a client with MAC. + ok, err := clients.Add(&Client{ + IDs: []string{testIP.String()}, + Name: "client2", + }) + require.Nil(t, err) + assert.True(t, ok) + + // Add a new client with the IP from the first client's IP + // range. + ok, err = clients.Add(&Client{ + IDs: []string{"2.2.2.2"}, + Name: "client3", + }) + require.Nil(t, err) + assert.True(t, ok) + }) } func TestClientsCustomUpstream(t *testing.T) { - clients := clientsContainer{} - clients.testing = true - + clients := clientsContainer{ + testing: true, + } clients.Init(nil, nil, nil) - // add client with upstreams - c := &Client{ + // Add client with upstreams. + ok, err := clients.Add(&Client{ IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"}, Name: "client1", Upstreams: []string{ "1.1.1.1", "[/example.org/]8.8.8.8", }, - } - ok, err := clients.Add(c) - assert.Nil(t, err) + }) + require.Nil(t, err) assert.True(t, ok) config := clients.FindUpstreams("1.2.3.4") assert.Nil(t, config) config = clients.FindUpstreams("1.1.1.1") - assert.NotNil(t, config) - assert.Equal(t, 1, len(config.Upstreams)) - assert.Equal(t, 1, len(config.DomainReservedUpstreams)) + require.NotNil(t, config) + assert.Len(t, config.Upstreams, 1) + assert.Len(t, config.DomainReservedUpstreams, 1) } diff --git a/internal/home/control_test.go b/internal/home/control_test.go index 5b08c2bd..46f14a2a 100644 --- a/internal/home/control_test.go +++ b/internal/home/control_test.go @@ -3,32 +3,12 @@ package home import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) -/* Tests performed: -. Bad certificate -. Bad private key -. Valid certificate & private key */ -func TestValidateCertificates(t *testing.T) { - var data tlsConfigStatus - - // bad cert - data = validateCertificates("bad cert", "", "") - if !(data.WarningValidation != "" && - !data.ValidCert && - !data.ValidChain) { - t.Fatalf("bad cert: validateCertificates(): %v", data) - } - - // bad priv key - data = validateCertificates("", "bad priv key", "") - if !(data.WarningValidation != "" && - !data.ValidKey) { - t.Fatalf("bad priv key: validateCertificates(): %v", data) - } - - // valid cert & priv key - CertificateChain := `-----BEGIN CERTIFICATE----- +const ( + CertificateChain = `-----BEGIN CERTIFICATE----- MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3 MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV @@ -42,7 +22,7 @@ LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ== -----END CERTIFICATE-----` - PrivateKey := `-----BEGIN PRIVATE KEY----- + PrivateKey = `-----BEGIN PRIVATE KEY----- MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX @@ -58,20 +38,35 @@ O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU kXS9jgARhhiWXJrk -----END PRIVATE KEY-----` - data = validateCertificates(CertificateChain, PrivateKey, "") - notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z") - notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z") - if !(data.WarningValidation != "" /* self signed */ && - data.ValidCert && - !data.ValidChain && - data.ValidKey && - data.KeyType == "RSA" && - data.Subject == "CN=AdGuard Home,O=AdGuard Ltd" && - data.Issuer == "CN=AdGuard Home,O=AdGuard Ltd" && - data.NotBefore.Equal(notBefore) && - data.NotAfter.Equal(notAfter) && - // data.DNSNames[0] == && - data.ValidPair) { - t.Fatalf("valid cert & priv key: validateCertificates(): %v", data) - } +) + +func TestValidateCertificates(t *testing.T) { + t.Run("bad_certificate", func(t *testing.T) { + data := validateCertificates("bad cert", "", "") + assert.NotEmpty(t, data.WarningValidation) + assert.False(t, data.ValidCert) + assert.False(t, data.ValidChain) + }) + + t.Run("bad_private_key", func(t *testing.T) { + data := validateCertificates("", "bad priv key", "") + assert.NotEmpty(t, data.WarningValidation) + assert.False(t, data.ValidKey) + }) + + t.Run("valid", func(t *testing.T) { + data := validateCertificates(CertificateChain, PrivateKey, "") + notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z") + notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z") + assert.NotEmpty(t, data.WarningValidation) + assert.True(t, data.ValidCert) + assert.False(t, data.ValidChain) + assert.True(t, data.ValidKey) + assert.Equal(t, "RSA", data.KeyType) + assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Subject) + assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Issuer) + assert.Equal(t, notBefore, data.NotBefore) + assert.Equal(t, notAfter, data.NotAfter) + assert.True(t, data.ValidPair) + }) } diff --git a/internal/home/filter_test.go b/internal/home/filter_test.go index a5b6d20b..dc4b295c 100644 --- a/internal/home/filter_test.go +++ b/internal/home/filter_test.go @@ -9,38 +9,47 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func testStartFilterListener() net.Listener { +func testStartFilterListener(t *testing.T) net.Listener { + t.Helper() + + const content = `||example.org^$third-party + # Inline comment example + ||example.com^$third-party + 0.0.0.0 example.com + ` + mux := http.NewServeMux() mux.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) { - content := `||example.org^$third-party -# Inline comment example -||example.com^$third-party -0.0.0.0 example.com -` - _, _ = w.Write([]byte(content)) + _, werr := w.Write([]byte(content)) + assert.Nil(t, werr) }) listener, err := net.Listen("tcp", ":0") - if err != nil { - panic(err) - } + require.Nil(t, err) + + go func() { + _ = http.Serve(listener, mux) + }() + + t.Cleanup(func() { + assert.Nil(t, listener.Close()) + }) - go func() { _ = http.Serve(listener, mux) }() return listener } func TestFilters(t *testing.T) { - l := testStartFilterListener() - defer func() { _ = l.Close() }() + l := testStartFilterListener(t) + dir := prepareTestDir(t) - dir := prepareTestDir() - defer func() { _ = os.RemoveAll(dir) }() - Context = homeContext{} - Context.workDir = dir - Context.client = &http.Client{ - Timeout: 5 * time.Second, + Context = homeContext{ + workDir: dir, + client: &http.Client{ + Timeout: 5 * time.Second, + }, } Context.filters.Init() @@ -48,20 +57,20 @@ func TestFilters(t *testing.T) { URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port), } - // download + // Download. ok, err := Context.filters.update(&f) - assert.Nil(t, err) - assert.True(t, ok) + require.Nil(t, err) + require.True(t, ok) assert.Equal(t, 3, f.RulesCount) - // refresh + // Refresh. ok, err = Context.filters.update(&f) - assert.False(t, ok) - assert.Nil(t, err) + require.Nil(t, err) + require.False(t, ok) err = Context.filters.load(&f) - assert.Nil(t, err) + require.Nil(t, err) f.unload() - _ = os.Remove(f.Path()) + require.Nil(t, os.Remove(f.Path())) } diff --git a/internal/home/home_test.go b/internal/home/home_test.go index 033b8e26..c18c6207 100644 --- a/internal/home/home_test.go +++ b/internal/home/home_test.go @@ -114,8 +114,7 @@ func TestHome(t *testing.T) { // Init new context Context = homeContext{} - dir := prepareTestDir() - defer func() { _ = os.RemoveAll(dir) }() + dir := prepareTestDir(t) fn := filepath.Join(dir, "AdGuardHome.yaml") // Prepare the test config diff --git a/internal/home/middlewares_test.go b/internal/home/middlewares_test.go index 8397302b..ebf0008a 100644 --- a/internal/home/middlewares_test.go +++ b/internal/home/middlewares_test.go @@ -39,21 +39,21 @@ func TestLimitRequestBody(t *testing.T) { wantErr: nil, }} - makeHandler := func(err *error) http.HandlerFunc { + makeHandler := func(t *testing.T, err *error) http.HandlerFunc { + t.Helper() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var b []byte b, *err = ioutil.ReadAll(r.Body) _, werr := w.Write(b) - if werr != nil { - panic(werr) - } + require.Nil(t, werr) }) } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var err error - handler := makeHandler(&err) + handler := makeHandler(t, &err) lim := limitRequestBody(handler) req := httptest.NewRequest(http.MethodPost, "https://www.example.com", strings.NewReader(tc.body)) @@ -61,7 +61,7 @@ func TestLimitRequestBody(t *testing.T) { lim.ServeHTTP(res, req) - require.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantErr, err) assert.Equal(t, tc.want, res.Body.Bytes()) }) } diff --git a/internal/home/mobileconfig_test.go b/internal/home/mobileconfig_test.go index 9dcafc97..2a0e6d43 100644 --- a/internal/home/mobileconfig_test.go +++ b/internal/home/mobileconfig_test.go @@ -6,29 +6,29 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "howett.net/plist" ) func TestHandleMobileConfigDOH(t *testing.T) { t.Run("success", func(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org", nil) - assert.Nil(t, err) + require.Nil(t, err) w := httptest.NewRecorder() handleMobileConfigDOH(w, r) - assert.Equal(t, http.StatusOK, w.Code) + require.Equal(t, http.StatusOK, w.Code) var mc mobileConfig _, err = plist.Unmarshal(w.Body.Bytes(), &mc) - assert.Nil(t, err) + require.Nil(t, err) - if assert.Len(t, mc.PayloadContent, 1) { - assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) - assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) - assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) - assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL) - } + require.Len(t, mc.PayloadContent, 1) + assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) + assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) + assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) + assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL) }) t.Run("success_no_host", func(t *testing.T) { @@ -40,23 +40,22 @@ func TestHandleMobileConfigDOH(t *testing.T) { } r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil) - assert.Nil(t, err) + require.Nil(t, err) w := httptest.NewRecorder() handleMobileConfigDOH(w, r) - assert.Equal(t, http.StatusOK, w.Code) + require.Equal(t, http.StatusOK, w.Code) var mc mobileConfig _, err = plist.Unmarshal(w.Body.Bytes(), &mc) - assert.Nil(t, err) + require.Nil(t, err) - if assert.Len(t, mc.PayloadContent, 1) { - assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) - assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) - assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) - assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL) - } + require.Len(t, mc.PayloadContent, 1) + assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) + assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) + assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) + assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL) }) t.Run("error_no_host", func(t *testing.T) { @@ -66,7 +65,7 @@ func TestHandleMobileConfigDOH(t *testing.T) { Context.tls = &TLSMod{conf: tlsConfigSettings{}} r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil) - assert.Nil(t, err) + require.Nil(t, err) w := httptest.NewRecorder() @@ -76,45 +75,43 @@ func TestHandleMobileConfigDOH(t *testing.T) { t.Run("client_id", func(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil) - assert.Nil(t, err) + require.Nil(t, err) w := httptest.NewRecorder() handleMobileConfigDOH(w, r) - assert.Equal(t, http.StatusOK, w.Code) + require.Equal(t, http.StatusOK, w.Code) var mc mobileConfig _, err = plist.Unmarshal(w.Body.Bytes(), &mc) - assert.Nil(t, err) + require.Nil(t, err) - if assert.Len(t, mc.PayloadContent, 1) { - assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) - assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) - assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) - assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL) - } + require.Len(t, mc.PayloadContent, 1) + assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) + assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) + assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) + assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL) }) } func TestHandleMobileConfigDOT(t *testing.T) { t.Run("success", func(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org", nil) - assert.Nil(t, err) + require.Nil(t, err) w := httptest.NewRecorder() handleMobileConfigDOT(w, r) - assert.Equal(t, http.StatusOK, w.Code) + require.Equal(t, http.StatusOK, w.Code) var mc mobileConfig _, err = plist.Unmarshal(w.Body.Bytes(), &mc) - assert.Nil(t, err) + require.Nil(t, err) - if assert.Len(t, mc.PayloadContent, 1) { - assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) - assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) - assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) - } + require.Len(t, mc.PayloadContent, 1) + assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) + assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) + assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) }) t.Run("success_no_host", func(t *testing.T) { @@ -126,22 +123,21 @@ func TestHandleMobileConfigDOT(t *testing.T) { } r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil) - assert.Nil(t, err) + require.Nil(t, err) w := httptest.NewRecorder() handleMobileConfigDOT(w, r) - assert.Equal(t, http.StatusOK, w.Code) + require.Equal(t, http.StatusOK, w.Code) var mc mobileConfig _, err = plist.Unmarshal(w.Body.Bytes(), &mc) - assert.Nil(t, err) + require.Nil(t, err) - if assert.Len(t, mc.PayloadContent, 1) { - assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) - assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) - assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) - } + require.Len(t, mc.PayloadContent, 1) + assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) + assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) + assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) }) t.Run("error_no_host", func(t *testing.T) { @@ -151,7 +147,7 @@ func TestHandleMobileConfigDOT(t *testing.T) { Context.tls = &TLSMod{conf: tlsConfigSettings{}} r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil) - assert.Nil(t, err) + require.Nil(t, err) w := httptest.NewRecorder() @@ -161,21 +157,20 @@ func TestHandleMobileConfigDOT(t *testing.T) { t.Run("client_id", func(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil) - assert.Nil(t, err) + require.Nil(t, err) w := httptest.NewRecorder() handleMobileConfigDOT(w, r) - assert.Equal(t, http.StatusOK, w.Code) + require.Equal(t, http.StatusOK, w.Code) var mc mobileConfig _, err = plist.Unmarshal(w.Body.Bytes(), &mc) - assert.Nil(t, err) + require.Nil(t, err) - if assert.Len(t, mc.PayloadContent, 1) { - assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) - assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) - assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName) - } + require.Len(t, mc.PayloadContent, 1) + assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) + assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) + assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName) }) } diff --git a/internal/home/options_test.go b/internal/home/options_test.go index f24dc816..ee4cc4a5 100644 --- a/internal/home/options_test.go +++ b/internal/home/options_test.go @@ -4,96 +4,74 @@ import ( "fmt" "net" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func testParseOk(t *testing.T, ss ...string) options { +func testParseOK(t *testing.T, ss ...string) options { + t.Helper() + o, _, err := parse("", ss) - if err != nil { - t.Fatal(err.Error()) - } + require.Nil(t, err) + return o } func testParseErr(t *testing.T, descr string, ss ...string) { + t.Helper() + _, _, err := parse("", ss) - if err == nil { - t.Fatalf("expected an error because %s but no error returned", descr) - } + require.NotNilf(t, err, "expected an error because %s but no error returned", descr) } func testParseParamMissing(t *testing.T, param string) { + t.Helper() + testParseErr(t, fmt.Sprintf("%s parameter missing", param), param) } func TestParseVerbose(t *testing.T) { - if testParseOk(t).verbose { - t.Fatal("empty is not verbose") - } - if !testParseOk(t, "-v").verbose { - t.Fatal("-v is verbose") - } - if !testParseOk(t, "--verbose").verbose { - t.Fatal("--verbose is verbose") - } + assert.False(t, testParseOK(t).verbose, "empty is not verbose") + assert.True(t, testParseOK(t, "-v").verbose, "-v is verbose") + assert.True(t, testParseOK(t, "--verbose").verbose, "--verbose is verbose") } func TestParseConfigFilename(t *testing.T) { - if testParseOk(t).configFilename != "" { - t.Fatal("empty is no config filename") - } - if testParseOk(t, "-c", "path").configFilename != "path" { - t.Fatal("-c is config filename") - } + assert.Equal(t, "", testParseOK(t).configFilename, "empty is no config filename") + assert.Equal(t, "path", testParseOK(t, "-c", "path").configFilename, "-c is config filename") testParseParamMissing(t, "-c") - if testParseOk(t, "--config", "path").configFilename != "path" { - t.Fatal("--configFilename is config filename") - } + + assert.Equal(t, "path", testParseOK(t, "--config", "path").configFilename, "--config is config filename") testParseParamMissing(t, "--config") } func TestParseWorkDir(t *testing.T) { - if testParseOk(t).workDir != "" { - t.Fatal("empty is no work dir") - } - if testParseOk(t, "-w", "path").workDir != "path" { - t.Fatal("-w is work dir") - } + assert.Equal(t, "", testParseOK(t).workDir, "empty is no work dir") + assert.Equal(t, "path", testParseOK(t, "-w", "path").workDir, "-w is work dir") testParseParamMissing(t, "-w") - if testParseOk(t, "--work-dir", "path").workDir != "path" { - t.Fatal("--work-dir is work dir") - } + + assert.Equal(t, "path", testParseOK(t, "--work-dir", "path").workDir, "--work-dir is work dir") testParseParamMissing(t, "--work-dir") } func TestParseBindHost(t *testing.T) { - if testParseOk(t).bindHost != nil { - t.Fatal("empty is no host") - } - if !testParseOk(t, "-h", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) { - t.Fatal("-h is host") - } + assert.Nil(t, testParseOK(t).bindHost, "empty is not host") + assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "-h", "1.2.3.4").bindHost, "-h is host") testParseParamMissing(t, "-h") - if !testParseOk(t, "--host", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) { - t.Fatal("--host is host") - } + + assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "--host", "1.2.3.4").bindHost, "--host is host") testParseParamMissing(t, "--host") } func TestParseBindPort(t *testing.T) { - if testParseOk(t).bindPort != 0 { - t.Fatal("empty is port 0") - } - if testParseOk(t, "-p", "65535").bindPort != 65535 { - t.Fatal("-p is port") - } + assert.Equal(t, 0, testParseOK(t).bindPort, "empty is port 0") + assert.Equal(t, 65535, testParseOK(t, "-p", "65535").bindPort, "-p is port") testParseParamMissing(t, "-p") - if testParseOk(t, "--port", "65535").bindPort != 65535 { - t.Fatal("--port is port") - } - testParseParamMissing(t, "--port") -} -func TestParseBindPortBad(t *testing.T) { + assert.Equal(t, 65535, testParseOK(t, "--port", "65535").bindPort, "--port is port") + testParseParamMissing(t, "--port") + testParseErr(t, "not an int", "-p", "x") testParseErr(t, "hex not supported", "-p", "0x100") testParseErr(t, "port negative", "-p", "-1") @@ -103,72 +81,40 @@ func TestParseBindPortBad(t *testing.T) { } func TestParseLogfile(t *testing.T) { - if testParseOk(t).logFile != "" { - t.Fatal("empty is no log file") - } - if testParseOk(t, "-l", "path").logFile != "path" { - t.Fatal("-l is log file") - } - if testParseOk(t, "--logfile", "path").logFile != "path" { - t.Fatal("--logfile is log file") - } + assert.Equal(t, "", testParseOK(t).logFile, "empty is no log file") + assert.Equal(t, "path", testParseOK(t, "-l", "path").logFile, "-l is log file") + assert.Equal(t, "path", testParseOK(t, "--logfile", "path").logFile, "--logfile is log file") } func TestParsePidfile(t *testing.T) { - if testParseOk(t).pidFile != "" { - t.Fatal("empty is no pid file") - } - if testParseOk(t, "--pidfile", "path").pidFile != "path" { - t.Fatal("--pidfile is pid file") - } + assert.Equal(t, "", testParseOK(t).pidFile, "empty is no pid file") + assert.Equal(t, "path", testParseOK(t, "--pidfile", "path").pidFile, "--pidfile is pid file") } func TestParseCheckConfig(t *testing.T) { - if testParseOk(t).checkConfig { - t.Fatal("empty is not check config") - } - if !testParseOk(t, "--check-config").checkConfig { - t.Fatal("--check-config is check config") - } + assert.False(t, testParseOK(t).checkConfig, "empty is not check config") + assert.True(t, testParseOK(t, "--check-config").checkConfig, "--check-config is check config") } func TestParseDisableUpdate(t *testing.T) { - if testParseOk(t).disableUpdate { - t.Fatal("empty is not disable update") - } - if !testParseOk(t, "--no-check-update").disableUpdate { - t.Fatal("--no-check-update is disable update") - } + assert.False(t, testParseOK(t).disableUpdate, "empty is not disable update") + assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update") } func TestParseDisableMemoryOptimization(t *testing.T) { - if testParseOk(t).disableMemoryOptimization { - t.Fatal("empty is not disable update") - } - if !testParseOk(t, "--no-mem-optimization").disableMemoryOptimization { - t.Fatal("--no-mem-optimization is disable update") - } + assert.False(t, testParseOK(t).disableMemoryOptimization, "empty is not disable update") + assert.True(t, testParseOK(t, "--no-mem-optimization").disableMemoryOptimization, "--no-mem-optimization is disable update") } func TestParseService(t *testing.T) { - if testParseOk(t).serviceControlAction != "" { - t.Fatal("empty is no service command") - } - if testParseOk(t, "-s", "command").serviceControlAction != "command" { - t.Fatal("-s is service command") - } - if testParseOk(t, "--service", "command").serviceControlAction != "command" { - t.Fatal("--service is service command") - } + assert.Equal(t, "", testParseOK(t).serviceControlAction, "empty is not service cmd") + assert.Equal(t, "cmd", testParseOK(t, "-s", "cmd").serviceControlAction, "-s is service cmd") + assert.Equal(t, "cmd", testParseOK(t, "--service", "cmd").serviceControlAction, "--service is service cmd") } func TestParseGLInet(t *testing.T) { - if testParseOk(t).glinetMode { - t.Fatal("empty is not GL-Inet mode") - } - if !testParseOk(t, "--glinet").glinetMode { - t.Fatal("--glinet is GL-Inet mode") - } + assert.False(t, testParseOK(t).glinetMode, "empty is not GL-Inet mode") + assert.True(t, testParseOK(t, "--glinet").glinetMode, "--glinet is GL-Inet mode") } func TestParseUnknown(t *testing.T) { @@ -180,73 +126,85 @@ func TestParseUnknown(t *testing.T) { testParseErr(t, "unknown dash", "-") } -func testSerialize(t *testing.T, o options, ss ...string) { - result := serialize(o) - if len(result) != len(ss) { - t.Fatalf("expected %s but got %s", ss, result) - } - for i, r := range result { - if r != ss[i] { - t.Fatalf("expected %s but got %s", ss, result) - } +func TestSerialize(t *testing.T) { + const reportFmt = "expected %s but got %s" + + testCases := []struct { + name string + opts options + ss []string + }{{ + name: "empty", + opts: options{}, + ss: []string{}, + }, { + name: "config_filename", + opts: options{configFilename: "path"}, + ss: []string{"-c", "path"}, + }, { + name: "work_dir", + opts: options{workDir: "path"}, + ss: []string{"-w", "path"}, + }, { + name: "bind_host", + opts: options{bindHost: net.IP{1, 2, 3, 4}}, + ss: []string{"-h", "1.2.3.4"}, + }, { + name: "bind_port", + opts: options{bindPort: 666}, + ss: []string{"-p", "666"}, + }, { + name: "log_file", + opts: options{logFile: "path"}, + ss: []string{"-l", "path"}, + }, { + name: "pid_file", + opts: options{pidFile: "path"}, + ss: []string{"--pidfile", "path"}, + }, { + name: "disable_update", + opts: options{disableUpdate: true}, + ss: []string{"--no-check-update"}, + }, { + name: "control_action", + opts: options{serviceControlAction: "run"}, + ss: []string{"-s", "run"}, + }, { + name: "glinet_mode", + opts: options{glinetMode: true}, + ss: []string{"--glinet"}, + }, { + name: "disable_mem_opt", + opts: options{disableMemoryOptimization: true}, + ss: []string{"--no-mem-optimization"}, + }, { + name: "multiple", + opts: options{ + serviceControlAction: "run", + configFilename: "config", + workDir: "work", + pidFile: "pid", + disableUpdate: true, + disableMemoryOptimization: true, + }, + ss: []string{ + "-c", "config", + "-w", "work", + "-s", "run", + "--pidfile", "pid", + "--no-check-update", + "--no-mem-optimization", + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := serialize(tc.opts) + require.Lenf(t, result, len(tc.ss), reportFmt, tc.ss, result) + + for i, r := range result { + assert.Equalf(t, tc.ss[i], r, reportFmt, tc.ss, result) + } + }) } } - -func TestSerializeEmpty(t *testing.T) { - testSerialize(t, options{}) -} - -func TestSerializeConfigFilename(t *testing.T) { - testSerialize(t, options{configFilename: "path"}, "-c", "path") -} - -func TestSerializeWorkDir(t *testing.T) { - testSerialize(t, options{workDir: "path"}, "-w", "path") -} - -func TestSerializeBindHost(t *testing.T) { - testSerialize(t, options{bindHost: net.IP{1, 2, 3, 4}}, "-h", "1.2.3.4") -} - -func TestSerializeBindPort(t *testing.T) { - testSerialize(t, options{bindPort: 666}, "-p", "666") -} - -func TestSerializeLogfile(t *testing.T) { - testSerialize(t, options{logFile: "path"}, "-l", "path") -} - -func TestSerializePidfile(t *testing.T) { - testSerialize(t, options{pidFile: "path"}, "--pidfile", "path") -} - -func TestSerializeCheckConfig(t *testing.T) { - testSerialize(t, options{checkConfig: true}, "--check-config") -} - -func TestSerializeDisableUpdate(t *testing.T) { - testSerialize(t, options{disableUpdate: true}, "--no-check-update") -} - -func TestSerializeService(t *testing.T) { - testSerialize(t, options{serviceControlAction: "run"}, "-s", "run") -} - -func TestSerializeGLInet(t *testing.T) { - testSerialize(t, options{glinetMode: true}, "--glinet") -} - -func TestSerializeDisableMemoryOptimization(t *testing.T) { - testSerialize(t, options{disableMemoryOptimization: true}, "--no-mem-optimization") -} - -func TestSerializeMultiple(t *testing.T) { - testSerialize(t, options{ - serviceControlAction: "run", - configFilename: "config", - workDir: "work", - pidFile: "pid", - disableUpdate: true, - disableMemoryOptimization: true, - }, "-c", "config", "-w", "work", "-s", "run", "--pidfile", "pid", "--no-check-update", "--no-mem-optimization") -} diff --git a/internal/querylog/qlogreader_test.go b/internal/querylog/qlogreader_test.go index 07fd6fd3..31622866 100644 --- a/internal/querylog/qlogreader_test.go +++ b/internal/querylog/qlogreader_test.go @@ -20,6 +20,7 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *QLogReader // Create the new QLogReader instance. reader, err := NewQLogReader(testFiles) require.Nil(t, err) + assert.NotNil(t, reader) t.Cleanup(func() { assert.Nil(t, reader.Close()) @@ -112,11 +113,7 @@ func TestQLogReader_Seek(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { timestamp, err := time.Parse(time.RFC3339Nano, tc.time) - assert.Nil(t, err) - - if tc.name == "first" { - assert.True(t, true) - } + require.Nil(t, err) err = r.SeekTS(timestamp.UnixNano()) assert.True(t, errors.Is(err, tc.want)) @@ -146,11 +143,11 @@ func TestQLogReader_ReadNext(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := r.SeekStart() - assert.Nil(t, err, err) + require.Nil(t, err) for i := 1; i < tc.start; i++ { _, err := r.ReadNext() - assert.Nil(t, err) + require.Nil(t, err) } _, err = r.ReadNext()