diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f888d72..d20ef582 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ and this project adheres to ### Changed +- Better OpenWrt detection ([#3435]). - DNS-over-HTTPS queries that come from HTTP proxies in the `trusted_proxies` list now use the real IP address of the client instead of the address of the proxy ([#2799]). @@ -123,6 +124,7 @@ and this project adheres to [#3351]: https://github.com/AdguardTeam/AdGuardHome/issues/3351 [#3372]: https://github.com/AdguardTeam/AdGuardHome/issues/3372 [#3417]: https://github.com/AdguardTeam/AdGuardHome/issues/3417 +[#3435]: https://github.com/AdguardTeam/AdGuardHome/issues/3435 [#3437]: https://github.com/AdguardTeam/AdGuardHome/issues/3437 diff --git a/go.mod b/go.mod index 0e6d2f59..198bba1e 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.16 require ( github.com/AdguardTeam/dnsproxy v0.39.2 - github.com/AdguardTeam/golibs v0.9.0 + github.com/AdguardTeam/golibs v0.9.1 github.com/AdguardTeam/urlfilter v0.14.6 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.2.1 diff --git a/go.sum b/go.sum index cce9ed46..850b2375 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,8 @@ github.com/AdguardTeam/dnsproxy v0.39.2/go.mod h1:aNXKNdTyKfgAG2OS712SYSaGIM9Aas github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.8.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= -github.com/AdguardTeam/golibs v0.9.0 h1:QwmHqeZOVs9XpkmPb2iYpZ35OBArjgTesE8gLtEFRFg= -github.com/AdguardTeam/golibs v0.9.0/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY= +github.com/AdguardTeam/golibs v0.9.1 h1:mHSN4LfaY1uGmHPsl97paAND/VeSnM5r9XQ7pSYx93o= +github.com/AdguardTeam/golibs v0.9.1/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/urlfilter v0.14.6 h1:emqoKZElooHACYehRBYENeKVN1a/rspxiqTIMYLuoIo= github.com/AdguardTeam/urlfilter v0.14.6/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U= diff --git a/internal/aghnet/net_freebsd.go b/internal/aghnet/net_freebsd.go index df6e8970..f4c106f1 100644 --- a/internal/aghnet/net_freebsd.go +++ b/internal/aghnet/net_freebsd.go @@ -8,69 +8,47 @@ import ( "fmt" "io" "net" - "os" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/golibs/errors" ) func canBindPrivilegedPorts() (can bool, err error) { return aghos.HaveAdminRights() } -// maxCheckedFileSize is the maximum acceptable length of the /etc/rc.conf file. -const maxCheckedFileSize = 1024 * 1024 - func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { const filename = "/etc/rc.conf" - var f *os.File - f, err = os.Open(filename) - if err != nil { - return false, err - } - defer func() { err = errors.WithDeferred(err, f.Close()) }() - - var r io.Reader - r, err = aghio.LimitReader(f, maxCheckedFileSize) - if err != nil { - return false, err - } - - return rcConfStaticConfig(r, ifaceName) + return aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig).Walk(filename) } // rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to // have a static IP. -func rcConfStaticConfig(r io.Reader, ifaceName string) (has bool, err error) { +func (n interfaceName) rcConfStaticConfig(r io.Reader) (_ []string, cont bool, err error) { s := bufio.NewScanner(r) - for ifaceLinePref := fmt.Sprintf("ifconfig_%s", ifaceName); s.Scan(); { + for pref := fmt.Sprintf("ifconfig_%s=", n); s.Scan(); { line := strings.TrimSpace(s.Text()) - if !strings.HasPrefix(line, ifaceLinePref) { + if !strings.HasPrefix(line, pref) { continue } - eqIdx := len(ifaceLinePref) - if line[eqIdx] != '=' { + cfgLeft, cfgRight := len(pref)+1, len(line)-1 + if cfgLeft >= cfgRight { continue } - fieldsStart, fieldsEnd := eqIdx+2, len(line)-1 - if fieldsStart >= fieldsEnd { - continue - } - - fields := strings.Fields(line[fieldsStart:fieldsEnd]) + // TODO(e.burkov): Expand the check to cover possible + // configurations from man rc.conf(5). + fields := strings.Fields(line[cfgLeft:cfgRight]) if len(fields) >= 2 && - strings.ToLower(fields[0]) == "inet" && + strings.EqualFold(fields[0], "inet") && net.ParseIP(fields[1]) != nil { - return true, s.Err() + return nil, false, s.Err() } } - return false, s.Err() + return nil, true, s.Err() } func ifaceSetStaticIP(string) (err error) { diff --git a/internal/aghnet/net_freebsd_test.go b/internal/aghnet/net_freebsd_test.go index bf8c7ec3..3781b154 100644 --- a/internal/aghnet/net_freebsd_test.go +++ b/internal/aghnet/net_freebsd_test.go @@ -12,49 +12,48 @@ import ( ) func TestRcConfStaticConfig(t *testing.T) { - const ifaceName = `em0` + const iface interfaceName = `em0` const nl = "\n" testCases := []struct { name string rcconfData string - wantHas bool + wantCont bool }{{ name: "simple", rcconfData: `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, - wantHas: true, + wantCont: false, }, { name: "case_insensitiveness", rcconfData: `ifconfig_em0="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl, - wantHas: true, + wantCont: false, }, { name: "comments_and_trash", rcconfData: `# comment 1` + nl + `` + nl + `# comment 2` + nl + `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, - wantHas: true, + wantCont: false, }, { name: "aliases", rcconfData: `ifconfig_em0_alias="inet 127.0.0.1/24"` + nl + `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, - wantHas: true, + wantCont: false, }, { name: "incorrect_config", rcconfData: `ifconfig_em0="inet6 127.0.0.253 netmask 0xffffffff"` + nl + - `ifconfig_em0="inet 127.0.0.253 net-mask 0xffffffff"` + nl + `ifconfig_em0="inet 256.256.256.256 netmask 0xffffffff"` + nl + `ifconfig_em0=""` + nl, - wantHas: false, + wantCont: true, }} for _, tc := range testCases { r := strings.NewReader(tc.rcconfData) t.Run(tc.name, func(t *testing.T) { - has, err := rcConfStaticConfig(r, ifaceName) + _, cont, err := iface.rcConfStaticConfig(r) require.NoError(t, err) - assert.Equal(t, tc.wantHas, has) + assert.Equal(t, tc.wantCont, cont) }) } } diff --git a/internal/aghnet/net_linux.go b/internal/aghnet/net_linux.go index 0bdfbeb7..3c6a6659 100644 --- a/internal/aghnet/net_linux.go +++ b/internal/aghnet/net_linux.go @@ -9,130 +9,72 @@ import ( "io" "net" "os" - "path/filepath" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/stringutil" "github.com/google/renameio/maybe" "golang.org/x/sys/unix" ) -// recurrentChecker is used to check all the files which may include references -// for other ones. -type recurrentChecker struct { - // checker is the function to check if r's stream contains the desired - // attribute. It must return all the patterns for files which should - // also be checked and each of them should be valid for filepath.Glob - // function. - checker func(r io.Reader, desired string) (patterns []string, has bool, err error) - // initPath is the path of the first member in the sequence of checked - // files. - initPath string +// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to +// have a static IP. +func (n interfaceName) dhcpcdStaticConfig(r io.Reader) (subsources []string, cont bool, err error) { + s := bufio.NewScanner(r) + ifaceFound := findIfaceLine(s, string(n)) + if !ifaceFound { + return nil, true, s.Err() + } + + for s.Scan() { + line := strings.TrimSpace(s.Text()) + fields := strings.Fields(line) + if len(fields) >= 2 && + fields[0] == "static" && + strings.HasPrefix(fields[1], "ip_address=") { + return nil, false, s.Err() + } + + if len(fields) > 0 && fields[0] == "interface" { + // Another interface found. + break + } + } + + return nil, true, s.Err() } -// maxCheckedFileSize is the maximum length of the file that recurrentChecker -// may check. -const maxCheckedFileSize = 1024 * 1024 - -// checkFile tries to open and to check single file located on the sourcePath. -func (rc *recurrentChecker) checkFile(sourcePath, desired string) ( - subsources []string, - has bool, - err error, -) { - var f *os.File - f, err = os.Open(sourcePath) - if err != nil { - return nil, false, err - } - defer func() { err = errors.WithDeferred(err, f.Close()) }() - - var r io.Reader - r, err = aghio.LimitReader(f, maxCheckedFileSize) - if err != nil { - return nil, false, err - } - - subsources, has, err = rc.checker(r, desired) - if err != nil { - return nil, false, err - } - - if has { - return nil, true, nil - } - - return subsources, has, nil -} - -// handlePatterns parses the patterns and takes care of duplicates. -func (rc *recurrentChecker) handlePatterns(sourcesSet *stringutil.Set, patterns []string) ( - subsources []string, - err error, -) { - subsources = make([]string, 0, len(patterns)) - for _, p := range patterns { - var matches []string - matches, err = filepath.Glob(p) - if err != nil { - return nil, fmt.Errorf("invalid pattern %q: %w", p, err) +// ifacesStaticConfig checks if the interface is configured by any file of +// /etc/network/interfaces format to have a static IP. +func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool, err error) { + s := bufio.NewScanner(r) + for s.Scan() { + line := strings.TrimSpace(s.Text()) + if len(line) == 0 || line[0] == '#' { + continue } - for _, m := range matches { - if sourcesSet.Has(m) { - continue - } + // TODO(e.burkov): As man page interfaces(5) says, a line may be + // extended across multiple lines by making the last character a + // backslash. Provide extended lines support. - sourcesSet.Add(m) - subsources = append(subsources, m) + fields := strings.Fields(line) + fieldsNum := len(fields) + + // Man page interfaces(5) declares that interface definition + // should consist of the key word "iface" followed by interface + // name, and method at fourth field. + if fieldsNum >= 4 && + fields[0] == "iface" && fields[1] == string(n) && fields[3] == "static" { + return nil, false, nil + } + + if fieldsNum >= 2 && fields[0] == "source" { + sub = append(sub, fields[1]) } } - return subsources, nil -} - -// check walks through all the files searching for the desired attribute. -func (rc *recurrentChecker) check(desired string) (has bool, err error) { - var i int - sources := []string{rc.initPath} - - defer func() { - if i >= len(sources) { - return - } - - err = errors.Annotate(err, "checking %q: %w", sources[i]) - }() - - var patterns, subsources []string - // The slice of sources is separate from the set of sources to keep the - // order in which the files are walked. - for sourcesSet := stringutil.NewSet(rc.initPath); i < len(sources); i++ { - patterns, has, err = rc.checkFile(sources[i], desired) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - continue - } - - return false, err - } - - if has { - return true, nil - } - - subsources, err = rc.handlePatterns(sourcesSet, patterns) - if err != nil { - return false, err - } - - sources = append(sources, subsources...) - } - - return false, nil + return sub, true, s.Err() } func ifaceHasStaticIP(ifaceName string) (has bool, err error) { @@ -141,14 +83,19 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) { // /etc/network/interfaces doesn't, it will return true. Perhaps this // is not the most desirable behavior. - for _, rc := range []*recurrentChecker{{ - checker: dhcpcdStaticConfig, - initPath: "/etc/dhcpcd.conf", + iface := interfaceName(ifaceName) + + for _, pair := range []struct { + aghos.FileWalker + filename string + }{{ + FileWalker: iface.dhcpcdStaticConfig, + filename: "/etc/dhcpcd.conf", }, { - checker: ifacesStaticConfig, - initPath: "/etc/network/interfaces", + FileWalker: iface.ifacesStaticConfig, + filename: "/etc/network/interfaces", }} { - has, err = rc.check(ifaceName) + has, err = pair.Walk(pair.filename) if err != nil { return false, err } @@ -183,67 +130,6 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) { return false } -// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to -// have a static IP. -func dhcpcdStaticConfig(r io.Reader, ifaceName string) (subsources []string, has bool, err error) { - s := bufio.NewScanner(r) - ifaceFound := findIfaceLine(s, ifaceName) - if !ifaceFound { - return nil, false, s.Err() - } - - for s.Scan() { - line := strings.TrimSpace(s.Text()) - fields := strings.Fields(line) - if len(fields) >= 2 && - fields[0] == "static" && - strings.HasPrefix(fields[1], "ip_address=") { - return nil, true, s.Err() - } - - if len(fields) > 0 && fields[0] == "interface" { - // Another interface found. - break - } - } - - return nil, false, s.Err() -} - -// ifacesStaticConfig checks if the interface is configured by any file of -// /etc/network/interfaces format to have a static IP. -func ifacesStaticConfig(r io.Reader, ifaceName string) (subsources []string, has bool, err error) { - s := bufio.NewScanner(r) - for s.Scan() { - line := strings.TrimSpace(s.Text()) - if len(line) == 0 || line[0] == '#' { - continue - } - - // TODO(e.burkov): As man page interfaces(5) says, a line may be - // extended across multiple lines by making the last character a - // backslash. Provide extended lines and "source-directory" - // stanzas support. - - fields := strings.Fields(line) - fieldsNum := len(fields) - - // Man page interfaces(5) declares that interface definition - // should consist of the key word "iface" followed by interface - // name, and method at fourth field. - if fieldsNum >= 4 && - fields[0] == "iface" && fields[1] == ifaceName && fields[3] == "static" { - return nil, true, nil - } - - if fieldsNum >= 2 && fields[0] == "source" { - subsources = append(subsources, fields[1]) - } - } - - return subsources, false, s.Err() -} - // ifaceSetStaticIP configures the system to retain its current IP on the // interface through dhcpdc.conf. func ifaceSetStaticIP(ifaceName string) (err error) { diff --git a/internal/aghnet/net_linux_test.go b/internal/aghnet/net_linux_test.go index 9dd9a866..a907819c 100644 --- a/internal/aghnet/net_linux_test.go +++ b/internal/aghnet/net_linux_test.go @@ -12,101 +12,90 @@ import ( "github.com/stretchr/testify/require" ) -func TestRecurrentChecker(t *testing.T) { - c := &recurrentChecker{ - checker: ifacesStaticConfig, - initPath: "./testdata/include-subsources", - } - - has, err := c.check("sample_name") - require.NoError(t, err) - assert.True(t, has) - - has, err = c.check("another_name") - require.NoError(t, err) - assert.False(t, has) -} - const nl = "\n" func TestDHCPCDStaticConfig(t *testing.T) { + const iface interfaceName = `wlan0` + testCases := []struct { - name string - data []byte - want bool + name string + data []byte + wantCont bool }{{ name: "has_not", data: []byte(`#comment` + nl + `# comment` + nl + `interface eth0` + nl + `static ip_address=192.168.0.1/24` + nl + - `# interface wlan0` + nl + + `# interface ` + iface + nl + `static ip_address=192.168.1.1/24` + nl + `# comment` + nl, ), - want: false, + wantCont: true, }, { name: "has", data: []byte(`#comment` + nl + `# comment` + nl + `interface eth0` + nl + `static ip_address=192.168.0.1/24` + nl + - `# interface wlan0` + nl + + `# interface ` + iface + nl + `static ip_address=192.168.1.1/24` + nl + `# comment` + nl + - `interface wlan0` + nl + + `interface ` + iface + nl + `# comment` + nl + `static ip_address=192.168.2.1/24` + nl, ), - want: true, + wantCont: false, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { r := bytes.NewReader(tc.data) - _, has, err := dhcpcdStaticConfig(r, "wlan0") + _, cont, err := iface.dhcpcdStaticConfig(r) require.NoError(t, err) - assert.Equal(t, tc.want, has) + assert.Equal(t, tc.wantCont, cont) }) } } func TestIfacesStaticConfig(t *testing.T) { + const iface interfaceName = `enp0s3` + testCases := []struct { name string data []byte - want bool + wantCont bool wantPatterns []string }{{ name: "has_not", - data: []byte(`allow-hotplug enp0s3` + nl + + data: []byte(`allow-hotplug ` + iface + nl + `#iface enp0s3 inet static` + nl + `# address 192.168.0.200` + nl + `# netmask 255.255.255.0` + nl + `# gateway 192.168.0.1` + nl + - `iface enp0s3 inet dhcp` + nl, + `iface ` + iface + ` inet dhcp` + nl, ), - want: false, + wantCont: true, wantPatterns: []string{}, }, { name: "has", - data: []byte(`allow-hotplug enp0s3` + nl + - `iface enp0s3 inet static` + nl + + data: []byte(`allow-hotplug ` + iface + nl + + `iface ` + iface + ` inet static` + nl + ` address 192.168.0.200` + nl + ` netmask 255.255.255.0` + nl + ` gateway 192.168.0.1` + nl + - `#iface enp0s3 inet dhcp` + nl, + `#iface ` + iface + ` inet dhcp` + nl, ), - want: true, + wantCont: false, wantPatterns: []string{}, }, { name: "return_patterns", data: []byte(`source hello` + nl + `source world` + nl + - `#iface enp0s3 inet static` + nl, + `#iface ` + iface + ` inet static` + nl, ), - want: false, + wantCont: true, wantPatterns: []string{"hello", "world"}, }, { // This one tests if the first found valid interface prevents @@ -114,19 +103,19 @@ func TestIfacesStaticConfig(t *testing.T) { name: "ignore_patterns", data: []byte(`source hello` + nl + `source world` + nl + - `iface enp0s3 inet static` + nl, + `iface ` + iface + ` inet static` + nl, ), - want: true, + wantCont: false, wantPatterns: []string{}, }} for _, tc := range testCases { + r := bytes.NewReader(tc.data) t.Run(tc.name, func(t *testing.T) { - r := bytes.NewReader(tc.data) - patterns, has, err := ifacesStaticConfig(r, "enp0s3") + patterns, has, err := iface.ifacesStaticConfig(r) require.NoError(t, err) - assert.Equal(t, tc.want, has) + assert.Equal(t, tc.wantCont, has) assert.ElementsMatch(t, tc.wantPatterns, patterns) }) } diff --git a/internal/aghnet/net_openbsd.go b/internal/aghnet/net_openbsd.go index c62689a9..d2604005 100644 --- a/internal/aghnet/net_openbsd.go +++ b/internal/aghnet/net_openbsd.go @@ -8,61 +8,34 @@ import ( "fmt" "io" "net" - "os" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/golibs/errors" ) func canBindPrivilegedPorts() (can bool, err error) { return aghos.HaveAdminRights() } -// maxCheckedFileSize is the maximum acceptable length of the /etc/hostname.* -// files. -const maxCheckedFileSize = 1024 * 1024 - func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { - const filenameFmt = "/etc/hostname.%s" + filename := fmt.Sprintf("/etc/hostname.%s", ifaceName) - filename := fmt.Sprintf(filenameFmt, ifaceName) - var f *os.File - if f, err = os.Open(filename); err != nil { - if errors.Is(err, os.ErrNotExist) { - err = nil - } - - return false, err - } - defer func() { err = errors.WithDeferred(err, f.Close()) }() - - var r io.Reader - r, err = aghio.LimitReader(f, maxCheckedFileSize) - if err != nil { - return false, err - } - - return hostnameIfStaticConfig(r) + return aghos.FileWalker(hostnameIfStaticConfig).Walk(filename) } // hostnameIfStaticConfig checks if the interface is configured by // /etc/hostname.* to have a static IP. -// -// TODO(e.burkov): The platform-dependent functions to check the static IP -// address configured are rather similar. Think about unifying common parts. -func hostnameIfStaticConfig(r io.Reader) (has bool, err error) { +func hostnameIfStaticConfig(r io.Reader) (_ []string, ok bool, err error) { s := bufio.NewScanner(r) for s.Scan() { line := strings.TrimSpace(s.Text()) fields := strings.Fields(line) if len(fields) >= 2 && fields[0] == "inet" && net.ParseIP(fields[1]) != nil { - return true, s.Err() + return nil, true, s.Err() } } - return false, s.Err() + return nil, false, s.Err() } func ifaceSetStaticIP(string) (err error) { diff --git a/internal/aghnet/net_openbsd_test.go b/internal/aghnet/net_openbsd_test.go index 5b005a6b..e157d93a 100644 --- a/internal/aghnet/net_openbsd_test.go +++ b/internal/aghnet/net_openbsd_test.go @@ -43,7 +43,7 @@ func TestHostnameIfStaticConfig(t *testing.T) { for _, tc := range testCases { r := strings.NewReader(tc.rcconfData) t.Run(tc.name, func(t *testing.T) { - has, err := hostnameIfStaticConfig(r) + _, has, err := hostnameIfStaticConfig(r) require.NoError(t, err) assert.Equal(t, tc.wantHas, has) diff --git a/internal/aghnet/net_unix.go b/internal/aghnet/net_unix.go new file mode 100644 index 00000000..efca131b --- /dev/null +++ b/internal/aghnet/net_unix.go @@ -0,0 +1,8 @@ +//go:build openbsd || freebsd || linux +// +build openbsd freebsd linux + +package aghnet + +// interfaceName is a string containing network interface's name. The name is +// used in file walking methods. +type interfaceName string diff --git a/internal/aghos/sysutil_test.go b/internal/aghos/aghos_test.go similarity index 100% rename from internal/aghos/sysutil_test.go rename to internal/aghos/aghos_test.go diff --git a/internal/aghos/filewalker.go b/internal/aghos/filewalker.go new file mode 100644 index 00000000..b6473d1f --- /dev/null +++ b/internal/aghos/filewalker.go @@ -0,0 +1,119 @@ +package aghos + +import ( + "fmt" + "io" + "os" + "path/filepath" + + "github.com/AdguardTeam/AdGuardHome/internal/aghio" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/stringutil" +) + +// FileWalker is the signature of a function called for files in the file tree. +// As opposed to filepath.Walk it only walk the files (not directories) matching +// the provided pattern and those returned by function itself. All patterns +// should be valid for filepath.Glob. If cont is false, the walking terminates. +// Each opened file is also limited for reading to MaxWalkedFileSize. +// +// TODO(e.burkov): Consider moving to the separate package like pathutil. +// +// TODO(e.burkov): Think about passing filename or any additional data. +type FileWalker func(r io.Reader) (patterns []string, cont bool, err error) + +// MaxWalkedFileSize is the maximum length of the file that FileWalker can +// check. +const MaxWalkedFileSize = 1024 * 1024 + +// checkFile tries to open and process a single file located on sourcePath. +func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, err error) { + var f *os.File + f, err = os.Open(sourcePath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + // Ignore non-existing files since this may only happen + // when the file was removed after filepath.Glob matched + // it. + return nil, true, nil + } + + return nil, false, err + } + defer func() { err = errors.WithDeferred(err, f.Close()) }() + + var r io.Reader + // Ignore the error since LimitReader function returns error only if + // passed limit value is less than zero, but the constant used. + // + // TODO(e.burkov): Make variable. + r, _ = aghio.LimitReader(f, MaxWalkedFileSize) + + return c(r) +} + +// handlePatterns parses the patterns and ignores duplicates using srcSet. +// srcSet must be non-nil. +func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, err error) { + sub = make([]string, 0, len(patterns)) + for _, p := range patterns { + var matches []string + matches, err = filepath.Glob(p) + if err != nil { + // Enrich error with the pattern because filepath.Glob + // doesn't do it. + return nil, fmt.Errorf("invalid pattern %q: %w", p, err) + } + + for _, m := range matches { + if srcSet.Has(m) { + continue + } + + srcSet.Add(m) + sub = append(sub, m) + } + } + + return sub, nil +} + +// Walk starts walking the files defined by initPattern. It only returns true +// if c signed to stop walking. +func (c FileWalker) Walk(initPattern string) (ok bool, err error) { + // The slice of sources keeps the order in which the files are walked + // since srcSet.Values() returns strings in undefined order. + srcSet := stringutil.NewSet() + var src []string + src, err = handlePatterns(srcSet, initPattern) + if err != nil { + return false, err + } + + var filename string + defer func() { err = errors.Annotate(err, "checking %q: %w", filename) }() + + for i := 0; i < len(src); i++ { + var patterns []string + var cont bool + filename = src[i] + patterns, cont, err = checkFile(c, src[i]) + if err != nil { + return false, err + } + + if !cont { + return true, nil + } + + var subsrc []string + subsrc, err = handlePatterns(srcSet, patterns...) + if err != nil { + return false, err + } + + src = append(src, subsrc...) + } + + return false, nil +} diff --git a/internal/aghos/filewalker_test.go b/internal/aghos/filewalker_test.go new file mode 100644 index 00000000..4ba1db20 --- /dev/null +++ b/internal/aghos/filewalker_test.go @@ -0,0 +1,209 @@ +package aghos + +import ( + "bufio" + "io" + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/AdguardTeam/golibs/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testFSDir maps entries' names to entries which should either be a testFSDir +// or byte slice. +type testFSDir map[string]interface{} + +// testFSGen is used to generate a temporary filesystem consisting of +// directories and plain text files from itself. +type testFSGen testFSDir + +// gen returns the name of top directory of the generated filesystem. +func (g testFSGen) gen(t *testing.T) (dirName string) { + t.Helper() + + dirName = t.TempDir() + g.rangeThrough(t, dirName) + + return dirName +} + +func (g testFSGen) rangeThrough(t *testing.T, dirName string) { + const perm fs.FileMode = 0o777 + + for k, e := range g { + switch e := e.(type) { + case []byte: + require.NoError(t, os.WriteFile(filepath.Join(dirName, k), e, perm)) + + case testFSDir: + newDir := filepath.Join(dirName, k) + require.NoError(t, os.Mkdir(newDir, perm)) + + testFSGen(e).rangeThrough(t, newDir) + default: + t.Fatalf("unexpected entry type %T", e) + } + } +} + +func TestFileWalker_Walk(t *testing.T) { + const attribute = `000` + + makeFileWalker := func(dirName string) (fw FileWalker) { + return func(r io.Reader) (patterns []string, cont bool, err error) { + s := bufio.NewScanner(r) + for s.Scan() { + line := s.Text() + if line == attribute { + return nil, false, nil + } + + if len(line) != 0 { + patterns = append(patterns, filepath.Join(dirName, line)) + } + } + + return patterns, true, s.Err() + } + } + + const nl = "\n" + + testCases := []struct { + name string + testFS testFSGen + initPattern string + want bool + }{{ + name: "simple", + testFS: testFSGen{ + "simple_0001.txt": []byte(attribute + nl), + }, + initPattern: "simple_0001.txt", + want: true, + }, { + name: "chain", + testFS: testFSGen{ + "chain_0001.txt": []byte(`chain_0002.txt` + nl), + "chain_0002.txt": []byte(`chain_0003.txt` + nl), + "chain_0003.txt": []byte(attribute + nl), + }, + initPattern: "chain_0001.txt", + want: true, + }, { + name: "several", + testFS: testFSGen{ + "several_0001.txt": []byte(`several_*` + nl), + "several_0002.txt": []byte(`several_0001.txt` + nl), + "several_0003.txt": []byte(attribute + nl), + }, + initPattern: "several_0001.txt", + want: true, + }, { + name: "no", + testFS: testFSGen{ + "no_0001.txt": []byte(nl), + "no_0002.txt": []byte(nl), + "no_0003.txt": []byte(nl), + }, + initPattern: "no_*", + want: false, + }, { + name: "subdirectory", + testFS: testFSGen{ + "dir": testFSDir{ + "subdir_0002.txt": []byte(attribute + nl), + }, + "subdir_0001.txt": []byte(`dir/*`), + }, + initPattern: "subdir_0001.txt", + want: true, + }} + + for _, tc := range testCases { + testDir := tc.testFS.gen(t) + fw := makeFileWalker(testDir) + + t.Run(tc.name, func(t *testing.T) { + ok, err := fw.Walk(filepath.Join(testDir, tc.initPattern)) + require.NoError(t, err) + + assert.Equal(t, tc.want, ok) + }) + } + + t.Run("pattern_malformed", func(t *testing.T) { + ok, err := makeFileWalker("").Walk("[]") + require.Error(t, err) + + assert.False(t, ok) + assert.ErrorIs(t, err, filepath.ErrBadPattern) + }) + + t.Run("bad_filename", func(t *testing.T) { + dir := testFSGen{ + "bad_filename.txt": []byte("[]"), + }.gen(t) + fw := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) { + s := bufio.NewScanner(r) + for s.Scan() { + patterns = append(patterns, s.Text()) + } + + return patterns, true, s.Err() + }) + + ok, err := fw.Walk(filepath.Join(dir, "bad_filename.txt")) + require.Error(t, err) + + assert.False(t, ok) + assert.ErrorIs(t, err, filepath.ErrBadPattern) + }) + + t.Run("itself_error", func(t *testing.T) { + const rerr errors.Error = "returned error" + + dir := testFSGen{ + "mockfile.txt": []byte(`mockdata`), + }.gen(t) + + ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) { + return nil, true, rerr + }).Walk(filepath.Join(dir, "*")) + require.Error(t, err) + require.False(t, ok) + + assert.ErrorIs(t, err, rerr) + }) +} + +func TestWalkerFunc_CheckFile(t *testing.T) { + t.Run("non-existing", func(t *testing.T) { + _, ok, err := checkFile(nil, "lol") + require.NoError(t, err) + + assert.True(t, ok) + }) + + t.Run("invalid_argument", func(t *testing.T) { + const badPath = "\x00" + + _, ok, err := checkFile(nil, badPath) + require.Error(t, err) + + assert.False(t, ok) + // TODO(e.burkov): Use assert.ErrorsIs within the error from + // less platform-dependent package instead of syscall.EINVAL. + // + // See https://github.com/golang/go/issues/46849 and + // https://github.com/golang/go/issues/30322. + pathErr := &os.PathError{} + require.ErrorAs(t, err, &pathErr) + assert.Equal(t, "open", pathErr.Op) + assert.Equal(t, badPath, pathErr.Path) + }) +} diff --git a/internal/aghos/os_linux.go b/internal/aghos/os_linux.go index f61a6874..5b1e5e87 100644 --- a/internal/aghos/os_linux.go +++ b/internal/aghos/os_linux.go @@ -4,11 +4,11 @@ package aghos import ( - "bytes" + "io" "os" - "path/filepath" - "strings" "syscall" + + "github.com/AdguardTeam/golibs/stringutil" ) func setRlimit(val uint64) (err error) { @@ -30,37 +30,20 @@ func sendProcessSignal(pid int, sig syscall.Signal) error { } func isOpenWrt() (ok bool) { - const etcDir = "/etc" + var err error + ok, err = FileWalker(func(r io.Reader) (_ []string, cont bool, err error) { + const osNameData = "openwrt" - dirEnts, err := os.ReadDir(etcDir) - if err != nil { - return false - } - - // fNameSubstr is a part of a name of the desired file. - const fNameSubstr = "release" - osNameData := []byte("OpenWrt") - - for _, dirEnt := range dirEnts { - if dirEnt.IsDir() { - continue - } - - fn := dirEnt.Name() - if !strings.Contains(fn, fNameSubstr) { - continue - } - - var body []byte - body, err = os.ReadFile(filepath.Join(etcDir, fn)) + // This use of ReadAll is now safe, because FileWalker's Walk() + // have limited r. + var data []byte + data, err = io.ReadAll(r) if err != nil { - continue + return nil, false, err } - if bytes.Contains(body, osNameData) { - return true - } - } + return nil, !stringutil.ContainsFold(string(data), osNameData), nil + }).Walk("/etc/*release*") - return false + return err == nil && ok } diff --git a/internal/querylog/searchcriterion.go b/internal/querylog/searchcriterion.go index 25ffc216..6517c936 100644 --- a/internal/querylog/searchcriterion.go +++ b/internal/querylog/searchcriterion.go @@ -2,10 +2,9 @@ package querylog import ( "strings" - "unicode" - "unicode/utf8" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/stringutil" ) type criterionType int @@ -69,37 +68,6 @@ func ctDomainOrClientCaseStrict( strings.EqualFold(name, term) } -// containsFold reports whehter s contains, ignoring letter case, substr. -// -// TODO(a.garipov): Move to aghstrings if needed elsewhere. -func containsFold(s, substr string) (ok bool) { - sLen, substrLen := len(s), len(substr) - if sLen < substrLen { - return false - } - - if sLen == substrLen { - return strings.EqualFold(s, substr) - } - - first, _ := utf8.DecodeRuneInString(substr) - firstFolded := unicode.SimpleFold(first) - - for i := 0; i != -1 && len(s) >= len(substr); { - if strings.EqualFold(s[:substrLen], substr) { - return true - } - - i = strings.IndexFunc(s[1:], func(r rune) (eq bool) { - return r == first || r == firstFolded - }) - - s = s[1+i:] - } - - return false -} - func ctDomainOrClientCaseNonStrict( term string, asciiTerm string, @@ -108,11 +76,11 @@ func ctDomainOrClientCaseNonStrict( host string, ip string, ) (ok bool) { - return containsFold(clientID, term) || - containsFold(host, term) || - (asciiTerm != "" && containsFold(host, asciiTerm)) || - containsFold(ip, term) || - containsFold(name, term) + return stringutil.ContainsFold(clientID, term) || + stringutil.ContainsFold(host, term) || + (asciiTerm != "" && stringutil.ContainsFold(host, asciiTerm)) || + stringutil.ContainsFold(ip, term) || + stringutil.ContainsFold(name, term) } // quickMatch quickly checks if the line matches the given search criterion. diff --git a/internal/querylog/searchcriterion_test.go b/internal/querylog/searchcriterion_test.go deleted file mode 100644 index 65ee6645..00000000 --- a/internal/querylog/searchcriterion_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package querylog - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestContainsFold(t *testing.T) { - testCases := []struct { - name string - inS string - inSubstr string - want bool - }{{ - name: "empty", - inS: "", - inSubstr: "", - want: true, - }, { - name: "shorter", - inS: "a", - inSubstr: "abc", - want: false, - }, { - name: "same_len_true", - inS: "abc", - inSubstr: "abc", - want: true, - }, { - name: "same_len_true_fold", - inS: "abc", - inSubstr: "aBc", - want: true, - }, { - name: "same_len_false", - inS: "abc", - inSubstr: "def", - want: false, - }, { - name: "longer_true", - inS: "abcdedef", - inSubstr: "def", - want: true, - }, { - name: "longer_false", - inS: "abcded", - inSubstr: "ghi", - want: false, - }, { - name: "longer_true_fold", - inS: "abcdedef", - inSubstr: "dEf", - want: true, - }, { - name: "longer_false_fold", - inS: "abcded", - inSubstr: "gHi", - want: false, - }, { - name: "longer_true_cyr_fold", - inS: "абвгдедеё", - inSubstr: "дЕЁ", - want: true, - }, { - name: "longer_false_cyr_fold", - inS: "абвгдедеё", - inSubstr: "жЗИ", - want: false, - }, { - name: "no_letters_true", - inS: "1.2.3.4", - inSubstr: "2.3.4", - want: true, - }, { - name: "no_letters_false", - inS: "1.2.3.4", - inSubstr: "2.3.5", - want: false, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if tc.want { - assert.True(t, containsFold(tc.inS, tc.inSubstr)) - } else { - assert.False(t, containsFold(tc.inS, tc.inSubstr)) - } - }) - } -} - -var sink bool - -func BenchmarkContainsFold(b *testing.B) { - const s = "aaahBbBhccchDDDeEehFfFhGGGhHhh" - const substr = "HHH" - - // Compare our implementation of containsFold against a stupid solution - // of calling strings.ToLower and strings.Contains. - b.Run("containsfold", func(b *testing.B) { - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - sink = containsFold(s, substr) - } - - assert.True(b, sink) - }) - - b.Run("tolower_contains", func(b *testing.B) { - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - sink = strings.Contains(strings.ToLower(s), strings.ToLower(substr)) - } - - assert.True(b, sink) - }) -}