diff --git a/CHANGELOG.md b/CHANGELOG.md index 8183fd68..ce333385 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,10 @@ and this project adheres to - Support for Discovery of Designated Resolvers (DDR) according to the [RFC draft][ddr-draft] ([#4463]). +### Fixed + +- Data races and concurrent map access in statistics module ([#4358], [#4342]). + ### Deprecated - Go 1.18 support. v0.109.0 will require at least Go 1.19 to build. @@ -35,6 +39,8 @@ and this project adheres to ([#4670]). [#2993]: https://github.com/AdguardTeam/AdGuardHome/issues/2993 +[#4342]: https://github.com/AdguardTeam/AdGuardHome/issues/4342 +[#4358]: https://github.com/AdguardTeam/AdGuardHome/issues/4358 [#4670]: https://github.com/AdguardTeam/AdGuardHome/issues/4670 [ddr-draft]: https://datatracker.ietf.org/doc/html/draft-ietf-add-ddr-08 diff --git a/internal/aghhttp/aghhttp.go b/internal/aghhttp/aghhttp.go index e186f8a3..57a1c868 100644 --- a/internal/aghhttp/aghhttp.go +++ b/internal/aghhttp/aghhttp.go @@ -9,6 +9,12 @@ import ( "github.com/AdguardTeam/golibs/log" ) +// RegisterFunc is the function that sets the handler to handle the URL for the +// method. +// +// TODO(e.burkov, a.garipov): Get rid of it. +type RegisterFunc func(method, url string, handler http.HandlerFunc) + // OK responds with word OK. 
func OK(w http.ResponseWriter) { if _, err := io.WriteString(w, "OK\n"); err != nil { diff --git a/internal/dhcpd/dhcpd.go b/internal/dhcpd/dhcpd.go index 55c56c18..a085e656 100644 --- a/internal/dhcpd/dhcpd.go +++ b/internal/dhcpd/dhcpd.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" "net" - "net/http" "path/filepath" "runtime" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" ) @@ -126,7 +126,7 @@ type ServerConfig struct { ConfigModified func() `yaml:"-"` // Register an HTTP handler - HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"` + HTTPRegister aghhttp.RegisterFunc `yaml:"-"` Enabled bool `yaml:"enabled"` InterfaceName string `yaml:"interface_name"` diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index eaee9155..d5e918c3 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -5,12 +5,12 @@ import ( "crypto/x509" "fmt" "net" - "net/http" "os" "sort" "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" @@ -193,7 +193,7 @@ type ServerConfig struct { ConfigModified func() // Register an HTTP handler - HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) + HTTPRegister aghhttp.RegisterFunc // ResolveClients signals if the RDNS should resolve clients' addresses. 
ResolveClients bool diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 81ac93ed..ca479fc4 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -61,7 +61,7 @@ type Server struct { dnsFilter *filtering.DNSFilter // DNS filter instance dhcpServer dhcpd.ServerInterface // DHCP server instance (optional) queryLog querylog.QueryLog // Query log instance - stats stats.Stats + stats stats.Interface access *accessCtx // localDomainSuffix is the suffix used to detect internal hosts. It @@ -107,7 +107,7 @@ const defaultLocalDomainSuffix = "lan" // DNSCreateParams are parameters to create a new server. type DNSCreateParams struct { DNSFilter *filtering.DNSFilter - Stats stats.Stats + Stats stats.Interface QueryLog querylog.QueryLog DHCPServer dhcpd.ServerInterface PrivateNets netutil.SubnetSet diff --git a/internal/dnsforward/stats_test.go b/internal/dnsforward/stats_test.go index fdaa3678..d991be12 100644 --- a/internal/dnsforward/stats_test.go +++ b/internal/dnsforward/stats_test.go @@ -34,7 +34,7 @@ func (l *testQueryLog) Add(p *querylog.AddParams) { type testStats struct { // Stats is embedded here simply to make testStats a stats.Stats without // actually implementing all methods. 
- stats.Stats + stats.Interface lastEntry stats.Entry } diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 8af49b0c..4a3e6b28 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -6,7 +6,6 @@ import ( "fmt" "io/fs" "net" - "net/http" "os" "runtime" "runtime/debug" @@ -14,6 +13,7 @@ import ( "sync" "sync/atomic" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" @@ -94,7 +94,7 @@ type Config struct { ConfigModified func() `yaml:"-"` // Register an HTTP handler - HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"` + HTTPRegister aghhttp.RegisterFunc `yaml:"-"` // CustomResolver is the resolver used by DNSFilter. CustomResolver Resolver `yaml:"-"` diff --git a/internal/home/clients.go b/internal/home/clients.go index 0b9bfcd0..e50b7904 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -2,6 +2,7 @@ package home import ( "bytes" + "encoding" "fmt" "net" "sort" @@ -60,6 +61,33 @@ const ( ClientSourceHostsFile ) +var _ fmt.Stringer = clientSource(0) + +// String returns a human-readable name of cs. +func (cs clientSource) String() (s string) { + switch cs { + case ClientSourceWHOIS: + return "WHOIS" + case ClientSourceARP: + return "ARP" + case ClientSourceRDNS: + return "rDNS" + case ClientSourceDHCP: + return "DHCP" + case ClientSourceHostsFile: + return "etc/hosts" + default: + return "" + } +} + +var _ encoding.TextMarshaler = clientSource(0) + +// MarshalText implements encoding.TextMarshaler for the clientSource. +func (cs clientSource) MarshalText() (text []byte, err error) { + return []byte(cs.String()), nil +} + // clientSourceConf is used to configure where the runtime clients will be // obtained from. 
type clientSourcesConf struct { @@ -397,6 +425,7 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) { c.Tags = stringutil.CloneSlice(c.Tags) c.BlockedServices = stringutil.CloneSlice(c.BlockedServices) c.Upstreams = stringutil.CloneSlice(c.Upstreams) + return c, true } diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 3bdf95e1..5f10ccbe 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -47,9 +47,9 @@ type clientJSON struct { type runtimeClientJSON struct { WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"` - Name string `json:"name"` - Source string `json:"source"` - IP net.IP `json:"ip"` + Name string `json:"name"` + Source clientSource `json:"source"` + IP net.IP `json:"ip"` } type clientListJSON struct { @@ -81,20 +81,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http cj := runtimeClientJSON{ WHOISInfo: rc.WHOISInfo, - Name: rc.Host, - IP: ip, - } - - cj.Source = "etc/hosts" - switch rc.Source { - case ClientSourceDHCP: - cj.Source = "DHCP" - case ClientSourceRDNS: - cj.Source = "rDNS" - case ClientSourceARP: - cj.Source = "ARP" - case ClientSourceWHOIS: - cj.Source = "WHOIS" + Name: rc.Host, + Source: rc.Source, + IP: ip, } data.RuntimeClients = append(data.RuntimeClients, cj) @@ -107,13 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http w.Header().Set("Content-Type", "application/json") e := json.NewEncoder(w).Encode(data) if e != nil { - aghhttp.Error( - r, - w, - http.StatusInternalServerError, - "Failed to encode to json: %v", - e, - ) + aghhttp.Error(r, w, http.StatusInternalServerError, "failed to encode to json: %v", e) return } @@ -279,9 +262,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) { rc, ok := clients.FindRuntimeClient(ip) if !ok { - // It is still possible that the 
IP used to be in the runtime - // clients list, but then the server was reloaded. So, check - // the DNS server's blocked IP list. + // It is still possible that the IP used to be in the runtime clients + // list, but then the server was reloaded. So, check the DNS server's + // blocked IP list. // // See https://github.com/AdguardTeam/AdGuardHome/issues/2428. disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) diff --git a/internal/home/control.go b/internal/home/control.go index 8234e00e..54d1652a 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -189,7 +189,7 @@ func registerControlHandlers() { RegisterAuthHandlers() } -func httpRegister(method, url string, handler func(http.ResponseWriter, *http.Request)) { +func httpRegister(method, url string, handler http.HandlerFunc) { if method == "" { // "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method Context.mux.HandleFunc(url, postInstall(handler)) diff --git a/internal/home/home.go b/internal/home/home.go index 15fd4c46..0fc16c09 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -47,7 +47,7 @@ type homeContext struct { // -- clients clientsContainer // per-client-settings module - stats stats.Stats // statistics module + stats stats.Interface // statistics module queryLog querylog.QueryLog // query log module dnsServer *dnsforward.Server // DNS module rdns *RDNS // rDNS module diff --git a/internal/querylog/querylog.go b/internal/querylog/querylog.go index a854c2c4..2d8e397f 100644 --- a/internal/querylog/querylog.go +++ b/internal/querylog/querylog.go @@ -2,10 +2,10 @@ package querylog import ( "net" - "net/http" "path/filepath" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/errors" @@ -38,7 +38,7 @@ type Config struct { ConfigModified func() // HTTPRegister registers an HTTP 
handler. - HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) + HTTPRegister aghhttp.RegisterFunc // FindClient returns client information by their IDs. FindClient func(ids []string) (c *Client, err error) diff --git a/internal/stats/http.go b/internal/stats/http.go index e2f00039..033dd3bb 100644 --- a/internal/stats/http.go +++ b/internal/stats/http.go @@ -39,34 +39,21 @@ type statsResponse struct { } // handleStats is a handler for getting statistics. -func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) { +func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) { start := time.Now() var resp statsResponse - if s.conf.limit == 0 { - resp = statsResponse{ - TimeUnits: "days", + var ok bool + resp, ok = s.getData() - TopBlocked: []topAddrs{}, - TopClients: []topAddrs{}, - TopQueried: []topAddrs{}, + log.Debug("stats: prepared data in %v", time.Since(start)) - BlockedFiltering: []uint64{}, - DNSQueries: []uint64{}, - ReplacedParental: []uint64{}, - ReplacedSafebrowsing: []uint64{}, - } - } else { - var ok bool - resp, ok = s.getData() + if !ok { + // Don't bring the message to the lower case since it's a part of UI + // text for the moment. 
+ aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data") - log.Debug("stats: prepared data in %v", time.Since(start)) - - if !ok { - aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data") - - return - } + return } w.Header().Set("Content-Type", "application/json") @@ -84,9 +71,9 @@ type config struct { } // Get configuration -func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) { +func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) { resp := config{} - resp.IntervalDays = s.conf.limit / 24 + resp.IntervalDays = s.limitHours / 24 data, err := json.Marshal(resp) if err != nil { @@ -102,7 +89,7 @@ func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) { } // Set configuration -func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) { +func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) { reqData := config{} err := json.NewDecoder(r.Body).Decode(&reqData) if err != nil { @@ -118,22 +105,22 @@ func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) { } s.setLimit(int(reqData.IntervalDays)) - s.conf.ConfigModified() + s.configModified() } // Reset data -func (s *statsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) { +func (s *StatsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) { s.clear() } // Register web handlers -func (s *statsCtx) initWeb() { - if s.conf.HTTPRegister == nil { +func (s *StatsCtx) initWeb() { + if s.httpRegister == nil { return } - s.conf.HTTPRegister(http.MethodGet, "/control/stats", s.handleStats) - s.conf.HTTPRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset) - s.conf.HTTPRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig) - s.conf.HTTPRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo) + s.httpRegister(http.MethodGet, "/control/stats", s.handleStats) + 
s.httpRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset) + s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig) + s.httpRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo) } diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 2944a163..04a933d4 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -4,75 +4,85 @@ package stats import ( "net" - "net/http" + + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" ) -type unitIDCallback func() uint32 +// UnitIDGenFunc is the signature of a function that generates a unique ID for +// the statistics unit. +type UnitIDGenFunc func() (id uint32) -// DiskConfig - configuration settings that are stored on disk +// DiskConfig is the configuration structure that is stored in file. type DiskConfig struct { - Interval uint32 `yaml:"statistics_interval"` // time interval for statistics (in days) + // Interval is the number of days for which the statistics are collected + // before flushing to the database. + Interval uint32 `yaml:"statistics_interval"` } -// Config - module configuration +// Config is the configuration structure for the statistics collecting. type Config struct { - Filename string // database file name - LimitDays uint32 // time limit (in days) - UnitID unitIDCallback // user function to get the current unit ID. If nil, the current time hour is used. + // UnitID is the function to generate the identifier for current unit. If + // nil, the default function is used, see newUnitID. + UnitID UnitIDGenFunc - // Called when the configuration is changed by HTTP request + // ConfigModified will be called each time the configuration changed via web + // interface. ConfigModified func() - // Register an HTTP handler - HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) + // HTTPRegister is the function that registers handlers for the stats + // endpoints. 
+ HTTPRegister aghhttp.RegisterFunc - limit uint32 // maximum time we need to keep data for (in hours) + // Filename is the name of the database file. + Filename string + + // LimitDays is the maximum number of days to collect statistics into the + // current unit. + LimitDays uint32 } -// New - create object -func New(conf Config) (Stats, error) { - return createObject(conf) -} - -// Stats - main interface -type Stats interface { +// Interface is the statistics interface to be used by other packages. +type Interface interface { + // Start begins the statistics collecting. Start() - // Close object. - // This function is not thread safe - // (can't be called in parallel with any other function of this interface). + // Close stops the statistics collecting. Close() - // Update counters + // Update collects the incoming statistics data. Update(e Entry) - // Get IP addresses of the clients with the most number of requests + // GetTopClientsIP returns at most limit IP addresses corresponding to the + // clients with the most number of requests. GetTopClientsIP(limit uint) []net.IP - // WriteDiskConfig - write configuration + // WriteDiskConfig puts the Interface's configuration to the dc. WriteDiskConfig(dc *DiskConfig) } -// TimeUnit - time unit +// TimeUnit is the unit of measuring time while aggregating the statistics. type TimeUnit int -// Supported time units +// Supported TimeUnit values. const ( Hours TimeUnit = iota Days ) -// Result of DNS request processing +// Result is the resulting code of processing the DNS request. type Result int -// Supported result values +// Supported Result values. +// +// TODO(e.burkov): Think about better naming. const ( RNotFiltered Result = iota + 1 RFiltered RSafeBrowsing RSafeSearch RParental - rLast + + resultLast = RParental + 1 ) // Entry is a statistics data entry. @@ -82,7 +92,12 @@ type Entry struct { // TODO(a.garipov): Make this a {net.IP, string} enum? Client string + // Domain is the domain name requested. 
Domain string + + // Result is the result of processing the request. Result Result - Time uint32 // processing time (msec) + + // Time is the duration of the request processing in milliseconds. + Time uint32 } diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index 70b71db8..0cffd2e3 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -37,7 +37,7 @@ func TestStats(t *testing.T) { LimitDays: 1, } - s, err := createObject(conf) + s, err := New(conf) require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, func() (err error) { s.clear() @@ -110,7 +110,7 @@ func TestLargeNumbers(t *testing.T) { LimitDays: 1, UnitID: newID, } - s, err := createObject(conf) + s, err := New(conf) require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, func() (err error) { s.Close() diff --git a/internal/stats/unit.go b/internal/stats/unit.go index 35d47a51..6d32a6d1 100644 --- a/internal/stats/unit.go +++ b/internal/stats/unit.go @@ -9,11 +9,13 @@ import ( "os" "sort" "sync" + "sync/atomic" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" - bolt "go.etcd.io/bbolt" + "go.etcd.io/bbolt" ) // TODO(a.garipov): Rewrite all of this. Add proper error handling and @@ -24,47 +26,130 @@ const ( maxClients = 100 // max number of top clients to store in file or return via Get() ) -// statsCtx - global context -type statsCtx struct { - // mu protects unit. - mu *sync.Mutex - // current is the actual statistics collection result. - current *unit +// StatsCtx collects the statistics and flushes it to the database. Its default +// flushing interval is one hour. +// +// TODO(e.burkov): Use atomic.Pointer for accessing curr and db in go1.19. +type StatsCtx struct { + // currMu protects the current unit. + currMu *sync.Mutex + // curr is the actual statistics collection result. + curr *unit - db *bolt.DB - conf *Config + // dbMu protects db. 
+ dbMu *sync.Mutex + // db is the opened statistics database, if any. + db *bbolt.DB + + // unitIDGen is the function that generates an identifier for the current + // unit. It's here for only testing purposes. + unitIDGen UnitIDGenFunc + + // httpRegister is used to set HTTP handlers. + httpRegister aghhttp.RegisterFunc + + // configModified is called whenever the configuration is modified via web + // interface. + configModified func() + + // filename is the name of database file. + filename string + + // limitHours is the maximum number of hours to collect statistics into the + // current unit. + limitHours uint32 } -// data for 1 time unit +// unit collects the statistics data for a specific period of time. type unit struct { - id uint32 // unit ID. Default: absolute hour since Jan 1, 1970 + // mu protects all the fields of a unit. + mu *sync.RWMutex - nTotal uint64 // total requests - nResult []uint64 // number of requests per one result - timeSum uint64 // sum of processing time of all requests (usec) + // id is the unique unit's identifier. It's set to an absolute hour number + // since the beginning of UNIX time by the default ID generating function. + id uint32 - // top: - domains map[string]uint64 // number of requests per domain - blockedDomains map[string]uint64 // number of blocked requests per domain - clients map[string]uint64 // number of requests per client + // nTotal stores the total number of requests. + nTotal uint64 + // nResult stores the number of requests grouped by it's result. + nResult []uint64 + // timeSum stores the sum of processing time in milliseconds of each request + // written by the unit. + timeSum uint64 + + // domains stores the number of requests for each domain. + domains map[string]uint64 + // blockedDomains stores the number of requests for each domain that has + // been blocked. + blockedDomains map[string]uint64 + // clients stores the number of requests from each client. 
+ clients map[string]uint64 } -// name-count pair +// ongoing returns the current unit. It's safe for concurrent use. +// +// Note that the unit itself should be locked before accessing. +func (s *StatsCtx) ongoing() (u *unit) { + s.currMu.Lock() + defer s.currMu.Unlock() + + return s.curr +} + +// swapCurrent swaps the current unit with another and returns it. It's safe +// for concurrent use. +func (s *StatsCtx) swapCurrent(with *unit) (old *unit) { + s.currMu.Lock() + defer s.currMu.Unlock() + + old, s.curr = s.curr, with + + return old +} + +// database returns the database if it's opened. It's safe for concurrent use. +func (s *StatsCtx) database() (db *bbolt.DB) { + s.dbMu.Lock() + defer s.dbMu.Unlock() + + return s.db +} + +// swapDatabase swaps the database with another one and returns it. It's safe +// for concurrent use. +func (s *StatsCtx) swapDatabase(with *bbolt.DB) (old *bbolt.DB) { + s.dbMu.Lock() + defer s.dbMu.Unlock() + + old, s.db = s.db, with + + return old +} + +// countPair is a single name-number pair for deserializing statistics data into +// the database. type countPair struct { Name string Count uint64 } -// structure for storing data in file +// unitDB is the structure for deserializing statistics data into the database. type unitDB struct { - NTotal uint64 + // NTotal is the total number of requests. + NTotal uint64 + // NResult is the number of requests by the result's kind. NResult []uint64 - Domains []countPair + // Domains is the number of requests for each domain name. + Domains []countPair + // BlockedDomains is the number of requests blocked for each domain name. BlockedDomains []countPair - Clients []countPair + // Clients is the number of requests from each client. + Clients []countPair - TimeAvg uint32 // usec + // TimeAvg is the average of processing times in milliseconds of all the + // requests in the unit. 
+ TimeAvg uint32 } // withRecovered turns the value recovered from panic if any into an error and @@ -86,34 +171,40 @@ func withRecovered(orig *error) { *orig = errors.WithDeferred(*orig, err) } -// createObject creates s from conf and properly initializes it. -func createObject(conf Config) (s *statsCtx, err error) { +// isEnabled is a helper that checks if the statistics collecting is enabled. +func (s *StatsCtx) isEnabled() (ok bool) { + return atomic.LoadUint32(&s.limitHours) != 0 +} + +// New creates s from conf and properly initializes it. Don't use s before +// calling its Start method. +func New(conf Config) (s *StatsCtx, err error) { defer withRecovered(&err) - s = &statsCtx{ - mu: &sync.Mutex{}, + s = &StatsCtx{ + currMu: &sync.Mutex{}, + dbMu: &sync.Mutex{}, + filename: conf.Filename, + configModified: conf.ConfigModified, + httpRegister: conf.HTTPRegister, } - if !checkInterval(conf.LimitDays) { - conf.LimitDays = 1 + if s.limitHours = conf.LimitDays * 24; !checkInterval(conf.LimitDays) { + s.limitHours = 24 + } + if s.unitIDGen = newUnitID; conf.UnitID != nil { + s.unitIDGen = conf.UnitID } - s.conf = &Config{} - *s.conf = conf - s.conf.limit = conf.LimitDays * 24 - if conf.UnitID == nil { - s.conf.UnitID = newUnitID + if err = s.dbOpen(); err != nil { + return nil, fmt.Errorf("opening database: %w", err) } - if !s.dbOpen() { - return nil, fmt.Errorf("open database") - } - - id := s.conf.UnitID() - tx := s.beginTxn(true) + id := s.unitIDGen() + tx := beginTxn(s.db, true) var udb *unitDB if tx != nil { log.Tracef("Deleting old units...") - firstID := id - s.conf.limit - 1 + firstID := id - s.limitHours - 1 unitDel := 0 err = tx.ForEach(newBucketWalker(tx, &unitDel, firstID)) @@ -133,12 +224,11 @@ func createObject(conf Config) (s *statsCtx, err error) { } } - u := unit{} - s.initUnit(&u, id) - if udb != nil { - deserialize(&u, udb) - } - s.current = &u + u := newUnit(id) + // This use of deserialize is safe since the accessed unit has just been + // 
created. + u.deserialize(udb) + s.curr = u log.Debug("stats: initialized") @@ -153,11 +243,11 @@ const errStop errors.Error = "stop iteration" // integer that unitDelPtr points to is incremented for every successful // deletion. If the bucket isn't deleted, f returns errStop. func newBucketWalker( - tx *bolt.Tx, + tx *bbolt.Tx, unitDelPtr *int, firstID uint32, -) (f func(name []byte, b *bolt.Bucket) (err error)) { - return func(name []byte, _ *bolt.Bucket) (err error) { +) (f func(name []byte, b *bbolt.Bucket) (err error)) { + return func(name []byte, _ *bbolt.Bucket) (err error) { nameID, ok := unitNameToID(name) if !ok || nameID < firstID { err = tx.DeleteBucket(name) @@ -178,80 +268,92 @@ func newBucketWalker( } } -func (s *statsCtx) Start() { +// Start makes s process the incoming data. +func (s *StatsCtx) Start() { s.initWeb() go s.periodicFlush() } -func checkInterval(days uint32) bool { +// checkInterval returns true if days is valid to be used as statistics +// retention interval. The valid values are 0, 1, 7, 30 and 90. +func checkInterval(days uint32) (ok bool) { return days == 0 || days == 1 || days == 7 || days == 30 || days == 90 } -func (s *statsCtx) dbOpen() bool { - var err error +// dbOpen returns an error if the database can't be opened from the specified +// file. It's safe for concurrent use. 
+func (s *StatsCtx) dbOpen() (err error) { log.Tracef("db.Open...") - s.db, err = bolt.Open(s.conf.Filename, 0o644, nil) + + s.dbMu.Lock() + defer s.dbMu.Unlock() + + s.db, err = bbolt.Open(s.filename, 0o644, nil) if err != nil { - log.Error("stats: open DB: %s: %s", s.conf.Filename, err) + log.Error("stats: open DB: %s: %s", s.filename, err) if err.Error() == "invalid argument" { log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations") } - return false + + return err } + log.Tracef("db.Open") - return true + + return nil } -// Atomically swap the currently active unit with a new value -// Return old value -func (s *statsCtx) swapUnit(new *unit) (u *unit) { - s.mu.Lock() - defer s.mu.Unlock() +// newUnitID is the default UnitIDGenFunc that generates the unique id hourly. +func newUnitID() (id uint32) { + const secsInHour = int64(time.Hour / time.Second) - u = s.current - s.current = new - - return u + return uint32(time.Now().Unix() / secsInHour) } -// Get unit ID for the current hour -func newUnitID() uint32 { - return uint32(time.Now().Unix() / (60 * 60)) +// newUnit allocates the new *unit. +func newUnit(id uint32) (u *unit) { + return &unit{ + mu: &sync.RWMutex{}, + id: id, + nResult: make([]uint64, resultLast), + domains: make(map[string]uint64), + blockedDomains: make(map[string]uint64), + clients: make(map[string]uint64), + } } -// Initialize a unit -func (s *statsCtx) initUnit(u *unit, id uint32) { - u.id = id - u.nResult = make([]uint64, rLast) - u.domains = make(map[string]uint64) - u.blockedDomains = make(map[string]uint64) - u.clients = make(map[string]uint64) -} - -// Open a DB transaction -func (s *statsCtx) beginTxn(wr bool) *bolt.Tx { - db := s.db +// beginTxn opens a new database transaction. If writable is true, the +// transaction will be opened for writing, and for reading otherwise. 
It +// returns nil if the transaction can't be created. +func beginTxn(db *bbolt.DB, writable bool) (tx *bbolt.Tx) { if db == nil { return nil } - log.Tracef("db.Begin...") - tx, err := db.Begin(wr) + log.Tracef("opening a database transaction") + + tx, err := db.Begin(writable) if err != nil { - log.Error("db.Begin: %s", err) + log.Error("stats: opening a transaction: %s", err) + return nil } - log.Tracef("db.Begin") + + log.Tracef("transaction has been opened") + return tx } -func (s *statsCtx) commitTxn(tx *bolt.Tx) { +// commitTxn applies the changes made in tx to the database. +func (s *StatsCtx) commitTxn(tx *bbolt.Tx) { err := tx.Commit() if err != nil { - log.Debug("tx.Commit: %s", err) + log.Error("stats: committing a transaction: %s", err) + return } - log.Tracef("tx.Commit") + + log.Tracef("transaction has been committed") } // bucketNameLen is the length of a bucket, a 64-bit unsigned integer. @@ -262,10 +364,10 @@ const bucketNameLen = 8 // idToUnitName converts a numerical ID into a database unit name. func idToUnitName(id uint32) (name []byte) { - name = make([]byte, bucketNameLen) - binary.BigEndian.PutUint64(name, uint64(id)) + n := [bucketNameLen]byte{} + binary.BigEndian.PutUint64(n[:], uint64(id)) - return name + return n[:] } // unitNameToID converts a database unit name into a numerical ID. ok is false @@ -278,13 +380,6 @@ func unitNameToID(name []byte) (id uint32, ok bool) { return uint32(binary.BigEndian.Uint64(name)), true } -func (s *statsCtx) ongoing() (u *unit) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.current -} - // Flush the current unit to DB and delete an old unit when a new hour is started // If a unit must be flushed: // . lock DB @@ -293,34 +388,29 @@ func (s *statsCtx) ongoing() (u *unit) { // . write the unit to DB // . remove the stale unit from DB // . 
. unlock DB -func (s *statsCtx) periodicFlush() { - for { - ptr := s.ongoing() - if ptr == nil { - break - } - - id := s.conf.UnitID() - if ptr.id == id || s.conf.limit == 0 { +func (s *StatsCtx) periodicFlush() { + for ptr := s.ongoing(); ptr != nil; ptr = s.ongoing() { + id := s.unitIDGen() + // Access the unit's ID with atomic to avoid locking the whole unit. + if !s.isEnabled() || atomic.LoadUint32(&ptr.id) == id { time.Sleep(time.Second) continue } - tx := s.beginTxn(true) + tx := beginTxn(s.database(), true) - nu := unit{} - s.initUnit(&nu, id) - u := s.swapUnit(&nu) - udb := serialize(u) + nu := newUnit(id) + u := s.swapCurrent(nu) + udb := u.serialize() if tx == nil { continue } - ok1 := s.flushUnitToDB(tx, u.id, udb) - ok2 := s.deleteUnit(tx, id-s.conf.limit) - if ok1 || ok2 { + flushOK := flushUnitToDB(tx, u.id, udb) + delOK := s.deleteUnit(tx, id-atomic.LoadUint32(&s.limitHours)) + if flushOK || delOK { s.commitTxn(tx) } else { _ = tx.Rollback() @@ -330,8 +420,8 @@ func (s *statsCtx) periodicFlush() { log.Tracef("periodicFlush() exited") } -// Delete unit's data from file -func (s *statsCtx) deleteUnit(tx *bolt.Tx, id uint32) bool { +// deleteUnit removes the unit by its ID from the database the tx belongs to. 
+func (s *StatsCtx) deleteUnit(tx *bbolt.Tx, id uint32) bool { err := tx.DeleteBucket(idToUnitName(id)) if err != nil { log.Tracef("stats: bolt DeleteBucket: %s", err) @@ -347,10 +437,7 @@ func convertMapToSlice(m map[string]uint64, max int) []countPair { a := []countPair{} for k, v := range m { - pair := countPair{} - pair.Name = k - pair.Count = v - a = append(a, pair) + a = append(a, countPair{Name: k, Count: v}) } less := func(i, j int) bool { return a[j].Count < a[i].Count @@ -370,41 +457,46 @@ func convertSliceToMap(a []countPair) map[string]uint64 { return m } -func serialize(u *unit) *unitDB { - udb := unitDB{} - udb.NTotal = u.nTotal - - udb.NResult = append(udb.NResult, u.nResult...) +// serialize converts u to the *unitDB. It's safe for concurrent use. +func (u *unit) serialize() (udb *unitDB) { + u.mu.RLock() + defer u.mu.RUnlock() + var timeAvg uint32 = 0 if u.nTotal != 0 { - udb.TimeAvg = uint32(u.timeSum / u.nTotal) + timeAvg = uint32(u.timeSum / u.nTotal) } - udb.Domains = convertMapToSlice(u.domains, maxDomains) - udb.BlockedDomains = convertMapToSlice(u.blockedDomains, maxDomains) - udb.Clients = convertMapToSlice(u.clients, maxClients) - - return &udb + return &unitDB{ + NTotal: u.nTotal, + NResult: append([]uint64{}, u.nResult...), + Domains: convertMapToSlice(u.domains, maxDomains), + BlockedDomains: convertMapToSlice(u.blockedDomains, maxDomains), + Clients: convertMapToSlice(u.clients, maxClients), + TimeAvg: timeAvg, + } } -func deserialize(u *unit, udb *unitDB) { +// deserialize assigns the appropriate values from udb to u. u must not be nil. +// It's safe for concurrent use. 
+func (u *unit) deserialize(udb *unitDB) { + if udb == nil { + return + } + + u.mu.Lock() + defer u.mu.Unlock() + u.nTotal = udb.NTotal - - n := len(udb.NResult) - if n < len(u.nResult) { - n = len(u.nResult) // n = min(len(udb.NResult), len(u.nResult)) - } - for i := 1; i < n; i++ { - u.nResult[i] = udb.NResult[i] - } - + u.nResult = make([]uint64, resultLast) + copy(u.nResult, udb.NResult) u.domains = convertSliceToMap(udb.Domains) u.blockedDomains = convertSliceToMap(udb.BlockedDomains) u.clients = convertSliceToMap(udb.Clients) - u.timeSum = uint64(udb.TimeAvg) * u.nTotal + u.timeSum = uint64(udb.TimeAvg) * udb.NTotal } -func (s *statsCtx) flushUnitToDB(tx *bolt.Tx, id uint32, udb *unitDB) bool { +func flushUnitToDB(tx *bbolt.Tx, id uint32, udb *unitDB) bool { log.Tracef("Flushing unit %d", id) bkt, err := tx.CreateBucketIfNotExists(idToUnitName(id)) @@ -430,7 +522,7 @@ func (s *statsCtx) flushUnitToDB(tx *bolt.Tx, id uint32, udb *unitDB) bool { return true } -func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB { +func (s *StatsCtx) loadUnitFromDB(tx *bbolt.Tx, id uint32) *unitDB { bkt := tx.Bucket(idToUnitName(id)) if bkt == nil { return nil @@ -451,44 +543,44 @@ func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB { return &udb } -func convertTopSlice(a []countPair) []map[string]uint64 { - m := []map[string]uint64{} +func convertTopSlice(a []countPair) (m []map[string]uint64) { + m = make([]map[string]uint64, 0, len(a)) for _, it := range a { - ent := map[string]uint64{} - ent[it.Name] = it.Count - m = append(m, ent) + m = append(m, map[string]uint64{it.Name: it.Count}) } + return m } -func (s *statsCtx) setLimit(limitDays int) { - s.conf.limit = uint32(limitDays) * 24 +func (s *StatsCtx) setLimit(limitDays int) { + atomic.StoreUint32(&s.limitHours, uint32(24*limitDays)) if limitDays == 0 { s.clear() } - log.Debug("stats: set limit: %d", limitDays) + log.Debug("stats: set limit: %d days", limitDays) } -func (s *statsCtx) 
WriteDiskConfig(dc *DiskConfig) { - dc.Interval = s.conf.limit / 24 +func (s *StatsCtx) WriteDiskConfig(dc *DiskConfig) { + dc.Interval = atomic.LoadUint32(&s.limitHours) / 24 } -func (s *statsCtx) Close() { - u := s.swapUnit(nil) - udb := serialize(u) - tx := s.beginTxn(true) - if tx != nil { - if s.flushUnitToDB(tx, u.id, udb) { +func (s *StatsCtx) Close() { + u := s.swapCurrent(nil) + + db := s.database() + if tx := beginTxn(db, true); tx != nil { + udb := u.serialize() + if flushUnitToDB(tx, u.id, udb) { s.commitTxn(tx) } else { _ = tx.Rollback() } } - if s.db != nil { + if db != nil { log.Tracef("db.Close...") - _ = s.db.Close() + _ = db.Close() log.Tracef("db.Close") } @@ -496,11 +588,11 @@ func (s *statsCtx) Close() { } // Reset counters and clear database -func (s *statsCtx) clear() { - tx := s.beginTxn(true) +func (s *StatsCtx) clear() { + db := s.database() + tx := beginTxn(db, true) if tx != nil { - db := s.db - s.db = nil + _ = s.swapDatabase(nil) _ = tx.Rollback() // the active transactions can continue using database, // but no new transactions will be opened @@ -509,11 +601,10 @@ func (s *statsCtx) clear() { // all active transactions are now closed } - u := unit{} - s.initUnit(&u, s.conf.UnitID()) - _ = s.swapUnit(&u) + u := newUnit(s.unitIDGen()) + _ = s.swapCurrent(u) - err := os.Remove(s.conf.Filename) + err := os.Remove(s.filename) if err != nil { log.Error("os.Remove: %s", err) } @@ -523,13 +614,13 @@ func (s *statsCtx) clear() { log.Debug("stats: cleared") } -func (s *statsCtx) Update(e Entry) { - if s.conf.limit == 0 { +func (s *StatsCtx) Update(e Entry) { + if !s.isEnabled() { return } if e.Result == 0 || - e.Result >= rLast || + e.Result >= resultLast || e.Domain == "" || e.Client == "" { return @@ -540,13 +631,15 @@ func (s *statsCtx) Update(e Entry) { clientID = ip.String() } - s.mu.Lock() - defer s.mu.Unlock() + u := s.ongoing() + if u == nil { + return + } - u := s.current + u.mu.Lock() + defer u.mu.Unlock() u.nResult[e.Result]++ - if 
e.Result == RNotFiltered { u.domains[e.Domain]++ } else { @@ -558,14 +651,19 @@ func (s *statsCtx) Update(e Entry) { u.nTotal++ } -func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) { - tx := s.beginTxn(false) +func (s *StatsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) { + tx := beginTxn(s.database(), false) if tx == nil { return nil, 0 } cur := s.ongoing() - curID := cur.id + var curID uint32 + if cur != nil { + curID = atomic.LoadUint32(&cur.id) + } else { + curID = s.unitIDGen() + } // Per-hour units. units := []*unitDB{} @@ -574,14 +672,16 @@ func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) { u := s.loadUnitFromDB(tx, i) if u == nil { u = &unitDB{} - u.NResult = make([]uint64, rLast) + u.NResult = make([]uint64, resultLast) } units = append(units, u) } _ = tx.Rollback() - units = append(units, serialize(cur)) + if cur != nil { + units = append(units, cur.serialize()) + } if len(units) != int(limit) { log.Fatalf("len(units) != limit: %d %d", len(units), limit) @@ -628,13 +728,13 @@ func statsCollector(units []*unitDB, firstID uint32, timeUnit TimeUnit, ng numsG // pairsGetter is a signature for topsCollector argument. type pairsGetter func(u *unitDB) (pairs []countPair) -// topsCollector collects statistics about highest values fro the given *unitDB +// topsCollector collects statistics about highest values from the given *unitDB // slice using pg to retrieve data. func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64 { m := map[string]uint64{} for _, u := range units { - for _, it := range pg(u) { - m[it.Name] += it.Count + for _, cp := range pg(u) { + m[cp.Name] += cp.Count } } a2 := convertMapToSlice(m, max) @@ -668,8 +768,22 @@ func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64 * parental-blocked These values are just the sum of data for all units. 
*/ -func (s *statsCtx) getData() (statsResponse, bool) { - limit := s.conf.limit +func (s *StatsCtx) getData() (statsResponse, bool) { + limit := atomic.LoadUint32(&s.limitHours) + if limit == 0 { + return statsResponse{ + TimeUnits: "days", + + TopBlocked: []topAddrs{}, + TopClients: []topAddrs{}, + TopQueried: []topAddrs{}, + + BlockedFiltering: []uint64{}, + DNSQueries: []uint64{}, + ReplacedParental: []uint64{}, + ReplacedSafebrowsing: []uint64{}, + }, true + } timeUnit := Hours if limit/24 > 7 { @@ -698,7 +812,7 @@ func (s *statsCtx) getData() (statsResponse, bool) { // Total counters: sum := unitDB{ - NResult: make([]uint64, rLast), + NResult: make([]uint64, resultLast), } timeN := 0 for _, u := range units { @@ -731,12 +845,12 @@ func (s *statsCtx) getData() (statsResponse, bool) { return data, true } -func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP { - if s.conf.limit == 0 { +func (s *StatsCtx) GetTopClientsIP(maxCount uint) []net.IP { + if !s.isEnabled() { return nil } - units, _ := s.loadUnits(s.conf.limit) + units, _ := s.loadUnits(atomic.LoadUint32(&s.limitHours)) if units == nil { return nil }