From 176a344aeea480f83276eb1b9997dbb0aa0f023e Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Wed, 15 Sep 2021 20:09:32 +0300 Subject: [PATCH] Pull request: 3567 filters update Merge in DNS/adguard-home from 3567-old-filters to master Updates #3567. Squashed commit of the following: commit d5cc419f1b01f89b2cbf40ff98b562d3498c15c2 Author: Eugene Burkov Date: Wed Sep 15 19:22:49 2021 +0300 home: lock doc commit 54edba6b3bd87a5e6a46c626db8eca9f4cd50858 Author: Eugene Burkov Date: Wed Sep 15 14:16:20 2021 +0300 home: imp code, docs commit e6dde1d3b3e3e0b196361806e77708bb797f5d29 Author: Eugene Burkov Date: Wed Sep 15 13:53:50 2021 +0300 home: imp code, logic commit b258b62948504e62d0e6366605dbd288f4584ada Author: Eugene Burkov Date: Tue Sep 14 19:35:14 2021 +0300 all: imp log of changes commit 9b66cde852ae1741d10e54fcb1d13d9676b42436 Author: Eugene Burkov Date: Tue Sep 14 18:56:52 2021 +0300 home: imp filter upd --- CHANGELOG.md | 4 +- internal/home/controlfiltering.go | 14 ++- internal/home/filter.go | 160 ++++++++++++++++++------------ internal/home/filter_test.go | 122 +++++++++++++++-------- 4 files changed, 191 insertions(+), 109 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index db8d8dc9..07827d96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ and this project adheres to ## [Unreleased] ### Added @@ -115,6 +115,7 @@ In this release, the schema version has changed from 10 to 12. ### Fixed +- Removal of temporary filter files ([#3567]). - Panic when an upstream server responds with an empty question section ([#3551]). - 9GAG blocking ([#3564]). @@ -195,6 +196,7 @@ In this release, the schema version has changed from 10 to 12. 
[#3538]: https://github.com/AdguardTeam/AdGuardHome/issues/3538 [#3551]: https://github.com/AdguardTeam/AdGuardHome/issues/3551 [#3564]: https://github.com/AdguardTeam/AdGuardHome/issues/3564 +[#3567]: https://github.com/AdguardTeam/AdGuardHome/issues/3567 [#3568]: https://github.com/AdguardTeam/AdGuardHome/issues/3568 [#3579]: https://github.com/AdguardTeam/AdGuardHome/issues/3579 diff --git a/internal/home/controlfiltering.go b/internal/home/controlfiltering.go index 229c1e4c..550b4b87 100644 --- a/internal/home/controlfiltering.go +++ b/internal/home/controlfiltering.go @@ -254,13 +254,21 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques return } - Context.controlLock.Unlock() flags := filterRefreshBlocklists if req.White { flags = filterRefreshAllowlists } - resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false) - Context.controlLock.Lock() + func() { + // Temporarily unlock the Context.controlLock because the + // f.refreshFilters waits for it to be unlocked but it's + // actually locked in ensure wrapper. + // + // TODO(e.burkov): Reconsider this messy syncing process. + Context.controlLock.Unlock() + defer Context.controlLock.Lock() + + resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false) + }() if err != nil { httpError(w, http.StatusInternalServerError, "%s", err) return diff --git a/internal/home/filter.go b/internal/home/filter.go index 03965fbf..22bc7435 100644 --- a/internal/home/filter.go +++ b/internal/home/filter.go @@ -18,6 +18,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/stringutil" ) var nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID @@ -535,26 +536,85 @@ func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (in } } -// updateIntl returns true if filter update performed successfully. 
-func (f *Filtering) updateIntl(filter *filter) (updated bool, err error) { - updated = false - log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL) +// finalizeUpdate closes and gets rid of temporary file f with filter's content +// according to updated. It also saves new values of flt's name, rules number +// and checksum if succeeded. +func finalizeUpdate( + f *os.File, + flt *filter, + updated bool, + name string, + rnum int, + cs uint32, +) (err error) { + tmpFileName := f.Name() - tmpFile, err := os.CreateTemp(filepath.Join(Context.getDataDir(), filterDir), "") + // Close the file before renaming it because it's required on Windows. + // + // See https://github.com/adguardTeam/adGuardHome/issues/1553. + if err = f.Close(); err != nil { + return fmt.Errorf("closing temporary file: %w", err) + } + + if !updated { + log.Tracef("filter #%d from %s has no changes, skip", flt.ID, flt.URL) + + return os.Remove(tmpFileName) + } + + log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path()) + + if err = os.Rename(tmpFileName, flt.Path()); err != nil { + return errors.WithDeferred(err, os.Remove(tmpFileName)) + } + + flt.Name = stringutil.Coalesce(flt.Name, name) + flt.checksum = cs + flt.RulesCount = rnum + + return nil +} + +// processUpdate copies filter's content from src to dst and returns the name, +// rules number, and checksum for it. It also returns the number of bytes read +// from src. +func (f *Filtering) processUpdate( + src io.Reader, + dst *os.File, + flt *filter, +) (name string, rnum int, cs uint32, n int, err error) { + if n, err = f.read(src, dst, flt); err != nil { + return "", 0, 0, 0, err + } + + if _, err = dst.Seek(0, io.SeekStart); err != nil { + return "", 0, 0, 0, err + } + + rnum, cs, name = f.parseFilterContents(dst) + + return name, rnum, cs, n, nil +} + +// updateIntl updates the flt rewriting its actual file. It returns true if +// the actual update has been performed. 
+func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) { + log.Tracef("downloading update for filter %d from %s", flt.ID, flt.URL) + + var name string + var rnum, n int + var cs uint32 + + var tmpFile *os.File + tmpFile, err = os.CreateTemp(filepath.Join(Context.getDataDir(), filterDir), "") if err != nil { - return updated, err + return false, err } defer func() { - var derr error - if tmpFile != nil { - if derr = tmpFile.Close(); derr != nil { - log.Printf("Couldn't close temporary file: %s", derr) - } - - tmpFileName := tmpFile.Name() - if derr = os.Remove(tmpFileName); derr != nil { - log.Printf("Couldn't delete temporary file %s: %s", tmpFileName, derr) - } + err = errors.WithDeferred(err, finalizeUpdate(tmpFile, flt, ok, name, rnum, cs)) + ok = ok && err == nil + if ok { + log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum) } }() @@ -562,72 +622,42 @@ func (f *Filtering) updateIntl(filter *filter) (updated bool, err error) { // end users. // // See https://github.com/AdguardTeam/AdGuardHome/issues/3198. 
- err = tmpFile.Chmod(0o644) - if err != nil { - return updated, fmt.Errorf("changing file mode: %w", err) + if err = tmpFile.Chmod(0o644); err != nil { + return false, fmt.Errorf("changing file mode: %w", err) } - var reader io.Reader - if filepath.IsAbs(filter.URL) { - var f io.ReadCloser - f, err = os.Open(filter.URL) + var r io.Reader + if filepath.IsAbs(flt.URL) { + var file io.ReadCloser + file, err = os.Open(flt.URL) if err != nil { - return updated, fmt.Errorf("open file: %w", err) + return false, fmt.Errorf("open file: %w", err) } - defer func() { err = errors.WithDeferred(err, f.Close()) }() + defer func() { err = errors.WithDeferred(err, file.Close()) }() - reader = f + r = file } else { var resp *http.Response - resp, err = Context.client.Get(filter.URL) + resp, err = Context.client.Get(flt.URL) if err != nil { - log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err) + log.Printf("requesting filter from %s, skip: %s", flt.URL, err) - return updated, err + return false, err } defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }() if resp.StatusCode != http.StatusOK { - log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL) - return updated, fmt.Errorf("got status code != 200: %d", resp.StatusCode) + log.Printf("got status code %d from %s, skip", resp.StatusCode, flt.URL) + + return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode) } - reader = resp.Body + + r = resp.Body } - total, err := f.read(reader, tmpFile, filter) - if err != nil { - return updated, err - } + name, rnum, cs, n, err = f.processUpdate(r, tmpFile, flt) - // Extract filter name and count number of rules - _, _ = tmpFile.Seek(0, io.SeekStart) - rulesCount, checksum, filterName := f.parseFilterContents(tmpFile) - // Check if the filter has been really changed - if filter.checksum == checksum { - log.Tracef("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL) - return updated, nil 
- } - - log.Printf("Filter %d has been updated: %d bytes, %d rules", - filter.ID, total, rulesCount) - if len(filter.Name) == 0 { - filter.Name = filterName - } - filter.RulesCount = rulesCount - filter.checksum = checksum - filterFilePath := filter.Path() - log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) - - // Closing the file before renaming it is necessary on Windows - _ = tmpFile.Close() - err = os.Rename(tmpFile.Name(), filterFilePath) - if err != nil { - return updated, err - } - tmpFile = nil - updated = true - - return updated, nil + return cs != flt.checksum, err } // loads filter contents from the file in dataDir diff --git a/internal/home/filter_test.go b/internal/home/filter_test.go index 21156474..0aa8d3b1 100644 --- a/internal/home/filter_test.go +++ b/internal/home/filter_test.go @@ -1,76 +1,118 @@ package home import ( - "fmt" + "io/fs" "net" "net/http" + "net/url" "os" + "path" + "path/filepath" "testing" "time" + "github.com/AdguardTeam/golibs/netutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func testStartFilterListener(t *testing.T) net.Listener { +const testFltsFileName = "1.txt" + +func testStartFilterListener(t *testing.T, fltContent *[]byte) (l net.Listener) { t.Helper() + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n, werr := w.Write(*fltContent) + require.NoError(t, werr) + require.Equal(t, len(*fltContent), n) + }) + + var err error + l, err = net.Listen("tcp", ":0") + require.NoError(t, err) + + go func() { + _ = http.Serve(l, h) + }() + t.Cleanup(func() { + require.NoError(t, l.Close()) + }) + + return l +} + +func TestFilters(t *testing.T) { const content = `||example.org^$third-party # Inline comment example ||example.com^$third-party 0.0.0.0 example.com ` - mux := http.NewServeMux() - mux.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) { - _, werr := w.Write([]byte(content)) - assert.Nil(t, werr) - }) + fltContent := 
[]byte(content) - listener, err := net.Listen("tcp", ":0") - require.Nil(t, err) - - go func() { - _ = http.Serve(listener, mux) - }() - - t.Cleanup(func() { - assert.Nil(t, listener.Close()) - }) - - return listener -} - -func TestFilters(t *testing.T) { - l := testStartFilterListener(t) - dir := t.TempDir() + l := testStartFilterListener(t, &fltContent) Context = homeContext{ - workDir: dir, + workDir: t.TempDir(), client: &http.Client{ Timeout: 5 * time.Second, }, } Context.filters.Init() - f := filter{ - URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port), + f := &filter{ + URL: (&url.URL{ + Scheme: "http", + Host: (&netutil.IPPort{ + IP: net.IP{127, 0, 0, 1}, + Port: l.Addr().(*net.TCPAddr).Port, + }).String(), + Path: path.Join(filterDir, testFltsFileName), + }).String(), } - // Download. - ok, err := Context.filters.update(&f) - require.Nil(t, err) - require.True(t, ok) - assert.Equal(t, 3, f.RulesCount) + updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) { + ok, err := Context.filters.update(f) + require.NoError(t, err) + want(t, ok) - // Refresh. 
- ok, err = Context.filters.update(&f) - require.Nil(t, err) - require.False(t, ok) + assert.Equal(t, wantRulesCount, f.RulesCount) - err = Context.filters.load(&f) - require.Nil(t, err) + var dir []fs.DirEntry + dir, err = os.ReadDir(filepath.Join(Context.getDataDir(), filterDir)) + require.NoError(t, err) - f.unload() - require.Nil(t, os.Remove(f.Path())) + assert.Len(t, dir, 1) + + require.FileExists(t, f.Path()) + + err = Context.filters.load(f) + require.NoError(t, err) + } + + t.Run("download", func(t *testing.T) { + updateAndAssert(t, require.True, 3) + }) + + t.Run("refresh_idle", func(t *testing.T) { + updateAndAssert(t, require.False, 3) + }) + + t.Run("refresh_actually", func(t *testing.T) { + fltContent = []byte(`||example.com^`) + t.Cleanup(func() { + fltContent = []byte(content) + }) + + updateAndAssert(t, require.True, 1) + }) + + t.Run("load_unload", func(t *testing.T) { + err := Context.filters.load(f) + require.NoError(t, err) + + f.unload() + }) + + require.NoError(t, os.Remove(f.Path())) }