diff --git a/internal/client/client.go b/internal/client/client.go index d0a75045..d3ead923 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -7,6 +7,7 @@ package client import ( "encoding" "fmt" + "net/netip" "github.com/AdguardTeam/AdGuardHome/internal/whois" ) @@ -56,6 +57,9 @@ func (cs Source) MarshalText() (text []byte, err error) { // Runtime is a client information from different sources. type Runtime struct { + // ip is an IP address of a client. + ip netip.Addr + // whois is the filtered WHOIS information of a client. whois *whois.Info @@ -80,6 +84,15 @@ type Runtime struct { hostsFile []string } +// NewRuntime constructs a new runtime client. ip must be valid IP address. +// +// TODO(s.chzhen): Validate IP address. +func NewRuntime(ip netip.Addr) (r *Runtime) { + return &Runtime{ + ip: ip, + } +} + // Info returns a client information from the highest-priority source. func (r *Runtime) Info() (cs Source, host string) { info := []string{} @@ -133,8 +146,8 @@ func (r *Runtime) SetWHOIS(info *whois.Info) { r.whois = info } -// Unset clears a cs information. -func (r *Runtime) Unset(cs Source) { +// unset clears a cs information. +func (r *Runtime) unset(cs Source) { switch cs { case SourceWHOIS: r.whois = nil @@ -149,11 +162,16 @@ func (r *Runtime) Unset(cs Source) { } } -// IsEmpty returns true if there is no information from any source. -func (r *Runtime) IsEmpty() (ok bool) { +// isEmpty returns true if there is no information from any source. +func (r *Runtime) isEmpty() (ok bool) { return r.whois == nil && r.arp == nil && r.rdns == nil && r.dhcp == nil && r.hostsFile == nil } + +// Addr returns an IP address of the client. +func (r *Runtime) Addr() (ip netip.Addr) { + return r.ip +} diff --git a/internal/client/runtimeindex.go b/internal/client/runtimeindex.go new file mode 100644 index 00000000..300fdca0 --- /dev/null +++ b/internal/client/runtimeindex.go @@ -0,0 +1,63 @@ +package client + +import "net/netip" + +// RuntimeIndex stores information about runtime clients. +type RuntimeIndex struct { + // index maps IP address to runtime client. + index map[netip.Addr]*Runtime +} + +// NewRuntimeIndex returns initialized runtime index. +func NewRuntimeIndex() (ri *RuntimeIndex) { + return &RuntimeIndex{ + index: map[netip.Addr]*Runtime{}, + } +} + +// Client returns the saved runtime client by ip. If no such client exists, +// returns nil. +func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) { + return ri.index[ip] +} + +// Add saves the runtime client in the index. IP address of a client must be +// unique. See [Runtime.Client]. rc must not be nil. +func (ri *RuntimeIndex) Add(rc *Runtime) { + ip := rc.Addr() + ri.index[ip] = rc +} + +// Size returns the number of the runtime clients. +func (ri *RuntimeIndex) Size() (n int) { + return len(ri.index) +} + +// Range calls f for each runtime client in an undefined order. +func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) { + for _, rc := range ri.index { + if !f(rc) { + return + } + } +} + +// Delete removes the runtime client by ip. +func (ri *RuntimeIndex) Delete(ip netip.Addr) { + delete(ri.index, ip) +} + +// DeleteBySource removes all runtime clients that have information only from +// the specified source and returns the number of removed clients. +func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) { + for ip, rc := range ri.index { + rc.unset(src) + + if rc.isEmpty() { + delete(ri.index, ip) + n++ + } + } + + return n +} diff --git a/internal/client/runtimeindex_test.go b/internal/client/runtimeindex_test.go new file mode 100644 index 00000000..66b975a0 --- /dev/null +++ b/internal/client/runtimeindex_test.go @@ -0,0 +1,85 @@ +package client_test + +import ( + "net/netip" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/stretchr/testify/assert" +) + +func TestRuntimeIndex(t *testing.T) { + const cliSrc = client.SourceARP + + var ( + ip1 = netip.MustParseAddr("1.1.1.1") + ip2 = netip.MustParseAddr("2.2.2.2") + ip3 = netip.MustParseAddr("3.3.3.3") + ) + + ri := client.NewRuntimeIndex() + currentSize := 0 + + testCases := []struct { + ip netip.Addr + name string + hosts []string + src client.Source + }{{ + src: cliSrc, + ip: ip1, + name: "1", + hosts: []string{"host1"}, + }, { + src: cliSrc, + ip: ip2, + name: "2", + hosts: []string{"host2"}, + }, { + src: cliSrc, + ip: ip3, + name: "3", + hosts: []string{"host3"}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rc := client.NewRuntime(tc.ip) + rc.SetInfo(tc.src, tc.hosts) + + ri.Add(rc) + currentSize++ + + got := ri.Client(tc.ip) + assert.Equal(t, rc, got) + }) + } + + t.Run("size", func(t *testing.T) { + assert.Equal(t, currentSize, ri.Size()) + }) + + t.Run("range", func(t *testing.T) { + s := 0 + + ri.Range(func(rc *client.Runtime) (cont bool) { + s++ + + return true + }) + + assert.Equal(t, currentSize, s) + }) + + t.Run("delete", func(t *testing.T) { + ri.Delete(ip1) + currentSize-- + + assert.Equal(t, currentSize, ri.Size()) + }) + + t.Run("delete_by_src", func(t *testing.T) { + assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc)) + assert.Equal(t, 0, ri.Size()) + }) +} diff --git a/internal/home/clients.go b/internal/home/clients.go index 2d5b1231..5c204973 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -50,10 +50,11 @@ type clientsContainer struct { // types (string, netip.Addr, and so on). list map[string]*client.Persistent // name -> client + // clientIndex stores information about persistent clients. clientIndex *client.Index - // ipToRC maps IP addresses to runtime client information. - ipToRC map[netip.Addr]*client.Runtime + // runtimeIndex stores information about runtime clients. + runtimeIndex *client.RuntimeIndex allTags *container.MapSet[string] @@ -105,7 +106,7 @@ func (clients *clientsContainer) Init( } clients.list = map[string]*client.Persistent{} - clients.ipToRC = map[netip.Addr]*client.Runtime{} + clients.runtimeIndex = client.NewRuntimeIndex() clients.clientIndex = client.NewIndex() @@ -363,8 +364,8 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) return client.SourcePersistent } - rc, ok := clients.ipToRC[ip] - if ok { + rc := clients.runtimeIndex.Client(ip) + if rc != nil { src, _ = rc.Info() } @@ -420,9 +421,8 @@ func (clients *clientsContainer) clientOrArtificial( }, false } - var rc *client.Runtime - rc, ok = clients.findRuntimeClient(ip) - if ok { + rc := clients.findRuntimeClient(ip) + if rc != nil { _, host := rc.Info() return &querylog.Client{ @@ -554,35 +554,33 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, // runtimeClient returns a runtime client from internal index. Note that it // doesn't include DHCP clients. -func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) { +func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) { if ip == (netip.Addr{}) { - return nil, false + return nil } clients.lock.Lock() defer clients.lock.Unlock() - rc, ok = clients.ipToRC[ip] - - return rc, ok + return clients.runtimeIndex.Client(ip) } // findRuntimeClient finds a runtime client by their IP. -func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) { - rc, ok = clients.runtimeClient(ip) +func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) { + rc = clients.runtimeClient(ip) host := clients.dhcp.HostByIP(ip) if host != "" { - if !ok { - rc = &client.Runtime{} + if rc == nil { + rc = client.NewRuntime(ip) } rc.SetInfo(client.SourceDHCP, []string{host}) - return rc, true + return rc } - return rc, ok + return rc } // check validates the client. It also sorts the client tags. @@ -734,12 +732,12 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) { return } - rc, ok := clients.ipToRC[ip] - if !ok { + rc := clients.runtimeIndex.Client(ip) + if rc == nil { // Create a RuntimeClient implicitly so that we don't do this check // again. - rc = &client.Runtime{} - clients.ipToRC[ip] = rc + rc = client.NewRuntime(ip) + clients.runtimeIndex.Add(rc) log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) } else { @@ -798,61 +796,54 @@ func (clients *clientsContainer) addHostLocked( host string, src client.Source, ) (ok bool) { - rc, ok := clients.ipToRC[ip] - if !ok { + rc := clients.runtimeIndex.Client(ip) + if rc == nil { if src < client.SourceDHCP { if clients.dhcp.HostByIP(ip) != "" { return false } } - rc = &client.Runtime{} - clients.ipToRC[ip] = rc + rc = client.NewRuntime(ip) + clients.runtimeIndex.Add(rc) } rc.SetInfo(src, []string{host}) - log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, len(clients.ipToRC)) + log.Debug( + "clients: adding client info %s -> %q %q [%d]", + ip, + src, + host, + clients.runtimeIndex.Size(), + ) return true } -// rmHostsBySrc removes all entries that match the specified source. -func (clients *clientsContainer) rmHostsBySrc(src client.Source) { - n := 0 - for ip, rc := range clients.ipToRC { - rc.Unset(src) - if rc.IsEmpty() { - delete(clients.ipToRC, ip) - n++ - } - } - - log.Debug("clients: removed %d client aliases", n) -} - // addFromHostsFile fills the client-hostname pairing index from the system's // hosts files. func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) { clients.lock.Lock() defer clients.lock.Unlock() - clients.rmHostsBySrc(client.SourceHostsFile) + deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile) + log.Debug("clients: removed %d client aliases from system hosts file", deleted) - n := 0 + added := 0 hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { // Only the first name of the first record is considered a canonical // hostname for the IP address. // // TODO(e.burkov): Consider using all the names from all the records. if clients.addHostLocked(addr, names[0], client.SourceHostsFile) { - n++ + added++ } return true }) - log.Debug("clients: added %d client aliases from system hosts file", n) + log.Debug("clients: added %d client aliases from system hosts file", added) } // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a @@ -876,7 +867,8 @@ func (clients *clientsContainer) addFromSystemARP() { clients.lock.Lock() defer clients.lock.Unlock() - clients.rmHostsBySrc(client.SourceARP) + deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP) + log.Debug("clients: removed %d client aliases from arp neighborhood", deleted) added := 0 for _, n := range ns { diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 4f9cb946..ac83bb5e 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -244,7 +244,7 @@ func TestClientsWHOIS(t *testing.T) { t.Run("new_client", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.255") clients.setWHOISInfo(ip, whois) - rc := clients.ipToRC[ip] + rc := clients.runtimeIndex.Client(ip) require.NotNil(t, rc) assert.Equal(t, whois, rc.WHOIS()) @@ -256,7 +256,7 @@ func TestClientsWHOIS(t *testing.T) { assert.True(t, ok) clients.setWHOISInfo(ip, whois) - rc := clients.ipToRC[ip] + rc := clients.runtimeIndex.Client(ip) require.NotNil(t, rc) assert.Equal(t, whois, rc.WHOIS()) @@ -274,7 +274,7 @@ func TestClientsWHOIS(t *testing.T) { assert.True(t, ok) clients.setWHOISInfo(ip, whois) - rc := clients.ipToRC[ip] + rc := clients.runtimeIndex.Client(ip) require.Nil(t, rc) assert.True(t, clients.remove("client1")) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 77f877bd..03762f30 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -101,17 +101,19 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http data.Clients = append(data.Clients, cj) } - for ip, rc := range clients.ipToRC { + clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) { src, host := rc.Info() cj := runtimeClientJSON{ WHOIS: whoisOrEmpty(rc), Name: host, Source: src, - IP: ip, + IP: rc.Addr(), } data.RuntimeClients = append(data.RuntimeClients, cj) - } + + return true + }) for _, l := range clients.dhcp.Leases() { cj := runtimeClientJSON{ @@ -463,8 +465,8 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http // /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be // non-nil. func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) { - rc, ok := clients.findRuntimeClient(ip) - if !ok { + rc := clients.findRuntimeClient(ip) + if rc == nil { // It is still possible that the IP used to be in the runtime clients // list, but then the server was reloaded. So, check the DNS server's // blocked IP list.