diff --git a/internal/client/client.go b/internal/client/client.go index 9e76f01e..24e8c9a2 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -119,8 +119,8 @@ func (r *Runtime) Info() (cs Source, host string) { return cs, info[0] } -// SetInfo sets a host as a client information from the cs. -func (r *Runtime) SetInfo(cs Source, hosts []string) { +// setInfo sets a host as a client information from the cs. +func (r *Runtime) setInfo(cs Source, hosts []string) { // TODO(s.chzhen): Use contract where hosts must contain non-empty host. if len(hosts) == 1 && hosts[0] == "" { hosts = []string{} @@ -138,13 +138,13 @@ func (r *Runtime) SetInfo(cs Source, hosts []string) { } } -// WHOIS returns a WHOIS client information. +// WHOIS returns a copy of WHOIS client information. func (r *Runtime) WHOIS() (info *whois.Info) { - return r.whois + return r.whois.Clone() } -// SetWHOIS sets a WHOIS client information. info must be non-nil. -func (r *Runtime) SetWHOIS(info *whois.Info) { +// setWHOIS sets a WHOIS client information. info must be non-nil. +func (r *Runtime) setWHOIS(info *whois.Info) { r.whois = info } @@ -178,8 +178,8 @@ func (r *Runtime) Addr() (ip netip.Addr) { return r.ip } -// Clone returns a deep copy of the runtime client. -func (r *Runtime) Clone() (c *Runtime) { +// clone returns a deep copy of the runtime client. +func (r *Runtime) clone() (c *Runtime) { return &Runtime{ ip: r.ip, whois: r.whois.Clone(), diff --git a/internal/client/runtimeindex.go b/internal/client/runtimeindex.go index 300fdca0..10ee0b47 100644 --- a/internal/client/runtimeindex.go +++ b/internal/client/runtimeindex.go @@ -2,39 +2,34 @@ package client import "net/netip" -// RuntimeIndex stores information about runtime clients. -type RuntimeIndex struct { +// runtimeIndex stores information about runtime clients. +type runtimeIndex struct { // index maps IP address to runtime client. index map[netip.Addr]*Runtime } -// NewRuntimeIndex returns initialized runtime index. -func NewRuntimeIndex() (ri *RuntimeIndex) { - return &RuntimeIndex{ +// newRuntimeIndex returns initialized runtime index. +func newRuntimeIndex() (ri *runtimeIndex) { + return &runtimeIndex{ index: map[netip.Addr]*Runtime{}, } } -// Client returns the saved runtime client by ip. If no such client exists, +// client returns the saved runtime client by ip. If no such client exists, // returns nil. -func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) { +func (ri *runtimeIndex) client(ip netip.Addr) (rc *Runtime) { return ri.index[ip] } -// Add saves the runtime client in the index. IP address of a client must be +// add saves the runtime client in the index. IP address of a client must be // unique. See [Runtime.Client]. rc must not be nil. -func (ri *RuntimeIndex) Add(rc *Runtime) { +func (ri *runtimeIndex) add(rc *Runtime) { ip := rc.Addr() ri.index[ip] = rc } -// Size returns the number of the runtime clients. -func (ri *RuntimeIndex) Size() (n int) { - return len(ri.index) -} - -// Range calls f for each runtime client in an undefined order. -func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) { +// rangeF calls f for each runtime client in an undefined order. +func (ri *runtimeIndex) rangeF(f func(rc *Runtime) (cont bool)) { for _, rc := range ri.index { if !f(rc) { return @@ -42,17 +37,31 @@ func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) { } } -// Delete removes the runtime client by ip. -func (ri *RuntimeIndex) Delete(ip netip.Addr) { - delete(ri.index, ip) +// setInfo sets the client information from cs for runtime client stored by ip. +// If no such client exists, it creates one. +func (ri *runtimeIndex) setInfo(ip netip.Addr, cs Source, hosts []string) (rc *Runtime) { + rc = ri.index[ip] + if rc == nil { + rc = NewRuntime(ip) + ri.add(rc) + } + + rc.setInfo(cs, hosts) + + return rc } -// DeleteBySource removes all runtime clients that have information only from -// the specified source and returns the number of removed clients. -func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) { - for ip, rc := range ri.index { +// clearSource removes information from the specified source from all clients. +func (ri *runtimeIndex) clearSource(src Source) { + for _, rc := range ri.index { rc.unset(src) + } +} +// removeEmpty removes empty runtime clients and returns the number of removed +// clients. +func (ri *runtimeIndex) removeEmpty() (n int) { + for ip, rc := range ri.index { if rc.isEmpty() { delete(ri.index, ip) n++ diff --git a/internal/client/runtimeindex_test.go b/internal/client/runtimeindex_test.go deleted file mode 100644 index 66b975a0..00000000 --- a/internal/client/runtimeindex_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package client_test - -import ( - "net/netip" - "testing" - - "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/stretchr/testify/assert" -) - -func TestRuntimeIndex(t *testing.T) { - const cliSrc = client.SourceARP - - var ( - ip1 = netip.MustParseAddr("1.1.1.1") - ip2 = netip.MustParseAddr("2.2.2.2") - ip3 = netip.MustParseAddr("3.3.3.3") - ) - - ri := client.NewRuntimeIndex() - currentSize := 0 - - testCases := []struct { - ip netip.Addr - name string - hosts []string - src client.Source - }{{ - src: cliSrc, - ip: ip1, - name: "1", - hosts: []string{"host1"}, - }, { - src: cliSrc, - ip: ip2, - name: "2", - hosts: []string{"host2"}, - }, { - src: cliSrc, - ip: ip3, - name: "3", - hosts: []string{"host3"}, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - rc := client.NewRuntime(tc.ip) - rc.SetInfo(tc.src, tc.hosts) - - ri.Add(rc) - currentSize++ - - got := ri.Client(tc.ip) - assert.Equal(t, rc, got) - }) - } - - t.Run("size", func(t *testing.T) { - assert.Equal(t, currentSize, ri.Size()) - }) - - t.Run("range", func(t *testing.T) { - s := 0 - - ri.Range(func(rc *client.Runtime) (cont bool) { - s++ - - return true - }) - - assert.Equal(t, currentSize, s) - }) - - t.Run("delete", func(t *testing.T) { - ri.Delete(ip1) - currentSize-- - - assert.Equal(t, currentSize, ri.Size()) - }) - - t.Run("delete_by_src", func(t *testing.T) { - assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc)) - assert.Equal(t, 0, ri.Size()) - }) -} diff --git a/internal/client/storage.go b/internal/client/storage.go index d1e306f9..3a311c78 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -2,10 +2,10 @@ package client import ( "cmp" + "context" "fmt" "net" "net/netip" - "slices" "sync" "time" @@ -24,8 +24,8 @@ type DHCP interface { Leases() (leases []*dhcpsvc.Lease) // HostByIP returns the hostname of the DHCP client with the given IP - // address. The address will be netip.Addr{} if there is no such client, - // due to an assumption that a DHCP client must always have a hostname. + // address. host will be empty if there is no such client, due to an + // assumption that a DHCP client must always have a hostname. HostByIP(ip netip.Addr) (host string) // MACByIP returns the MAC address for the given IP address leased. It @@ -34,31 +34,47 @@ type DHCP interface { MACByIP(ip netip.Addr) (mac net.HardwareAddr) } +// emptyDHCP is the empty [DHCP] implementation that does nothing. type emptyDHCP struct{} // type check var _ DHCP = emptyDHCP{} -func (emptyDHCP) Leases() (_ []*dhcpsvc.Lease) { return nil } +// Leases implements the [DHCP] interface for emptyDHCP. +func (emptyDHCP) Leases() (leases []*dhcpsvc.Lease) { return nil } -func (emptyDHCP) HostByIP(_ netip.Addr) (_ string) { return "" } +// HostByIP implements the [DHCP] interface for emptyDHCP. +func (emptyDHCP) HostByIP(_ netip.Addr) (host string) { return "" } -func (emptyDHCP) MACByIP(_ netip.Addr) (_ net.HardwareAddr) { return nil } +// MACByIP implements the [DHCP] interface for emptyDHCP. +func (emptyDHCP) MACByIP(_ netip.Addr) (mac net.HardwareAddr) { return nil } +// HostsContainer is an interface for receiving updates to the system hosts +// file. type HostsContainer interface { Upd() (updates <-chan *hostsfile.DefaultStorage) } // Config is the client storage configuration structure. type Config struct { - DHCP DHCP + // DHCP is used to update [SourceDHCP] runtime client information. + DHCP DHCP + + // EtcHosts is used to update [SourceHostsFile] runtime client information. EtcHosts HostsContainer - ARPDB arpdb.Interface + + // ARPDB is used to update [SourceARP] runtime client information. + ARPDB arpdb.Interface // AllowedTags is a list of all allowed client tags. AllowedTags []string - InitialClients []*Persistent + // InitialClients is a list of persistent clients parsed from the + // configuration file. Each client must not be nil. + InitialClients []*Persistent + + // ARPClientsUpdatePeriod defines how often [SourceARP] runtime client + // information is updated. ARPClientsUpdatePeriod time.Duration } @@ -74,11 +90,22 @@ type Storage struct { index *index // runtimeIndex contains information about runtime clients. - runtimeIndex *RuntimeIndex + runtimeIndex *runtimeIndex - dhcp DHCP - etcHosts HostsContainer - arpDB arpdb.Interface + // dhcp is used to update [SourceDHCP] runtime client information. + dhcp DHCP + + // etcHosts is used to update [SourceHostsFile] runtime client information. + etcHosts HostsContainer + + // arpDB is used to update [SourceARP] runtime client information. + arpDB arpdb.Interface + + // done is the shutdown signaling channel. + done chan struct{} + + // arpClientsUpdatePeriod defines how often [SourceARP] runtime client + // information is updated. It must be greater than zero. arpClientsUpdatePeriod time.Duration } @@ -89,11 +116,12 @@ func NewStorage(conf *Config) (s *Storage, err error) { allowedTags: allowedTags, mu: &sync.Mutex{}, index: newIndex(), - runtimeIndex: NewRuntimeIndex(), + runtimeIndex: newRuntimeIndex(), dhcp: cmp.Or(conf.DHCP, DHCP(emptyDHCP{})), etcHosts: conf.EtcHosts, arpDB: conf.ARPDB, arpClientsUpdatePeriod: conf.ARPClientsUpdatePeriod, + done: make(chan struct{}), } for i, p := range conf.InitialClients { @@ -107,9 +135,18 @@ func NewStorage(conf *Config) (s *Storage, err error) { } // Start starts the goroutines for updating the runtime client information. -func (s *Storage) Start() { +func (s *Storage) Start(_ context.Context) (err error) { go s.periodicARPUpdate() go s.handleHostsUpdates() + + return nil +} + +// Shutdown gracefully stops the client storage. +func (s *Storage) Shutdown(_ context.Context) (err error) { + close(s.done) + + return nil } // periodicARPUpdate periodically reloads runtime clients from ARP. It is @@ -117,9 +154,15 @@ func (s *Storage) Start() { func (s *Storage) periodicARPUpdate() { defer log.OnPanic("storage") + t := time.NewTicker(s.arpClientsUpdatePeriod) + for { - s.ReloadARP() - time.Sleep(s.arpClientsUpdatePeriod) + select { + case <-t.C: + s.ReloadARP() + case <-s.done: + return + } } } @@ -133,6 +176,9 @@ func (s *Storage) ReloadARP() { // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a // command. func (s *Storage) addFromSystemARP() { + s.mu.Lock() + defer s.mu.Unlock() + if err := s.arpDB.Refresh(); err != nil { s.arpDB = arpdb.Empty{} log.Error("refreshing arp container: %s", err) @@ -147,16 +193,16 @@ func (s *Storage) addFromSystemARP() { return } - var rcs []*Runtime - for _, n := range ns { - rc := NewRuntime(n.IP) - rc.SetInfo(SourceARP, []string{n.Name}) + src := SourceARP + s.runtimeIndex.clearSource(src) - rcs = append(rcs, rc) + for _, n := range ns { + s.runtimeIndex.setInfo(n.IP, src, []string{n.Name}) } - added, removed := s.BatchUpdateBySource(SourceARP, rcs) - log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", added, removed) + removed := s.runtimeIndex.removeEmpty() + + log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", len(ns), removed) } // handleHostsUpdates receives the updates from the hosts container and adds @@ -168,29 +214,38 @@ func (s *Storage) handleHostsUpdates() { defer log.OnPanic("storage") - for upd := range s.etcHosts.Upd() { - s.addFromHostsFile(upd) + for { + select { + case upd := <-s.etcHosts.Upd(): + s.addFromHostsFile(upd) + case <-s.done: + return + } } } // addFromHostsFile fills the client-hostname pairing index from the system's // hosts files. func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) { - var rcs []*Runtime + s.mu.Lock() + defer s.mu.Unlock() + + src := SourceHostsFile + s.runtimeIndex.clearSource(src) + + added := 0 hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { // Only the first name of the first record is considered a canonical // hostname for the IP address. // // TODO(e.burkov): Consider using all the names from all the records. - rc := NewRuntime(addr) - rc.SetInfo(SourceHostsFile, []string{names[0]}) - - rcs = append(rcs, rc) + s.runtimeIndex.setInfo(addr, src, []string{names[0]}) + added++ return true }) - added, removed := s.BatchUpdateBySource(SourceHostsFile, rcs) + removed := s.runtimeIndex.removeEmpty() log.Debug("storage: added %d, removed %d client aliases from system hosts file", added, removed) } @@ -204,10 +259,11 @@ func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { return } + s.mu.Lock() + defer s.mu.Unlock() + if host != "" { - rc := NewRuntime(ip) - rc.SetInfo(SourceRDNS, []string{host}) - s.UpdateRuntime(rc) + s.runtimeIndex.setInfo(ip, SourceRDNS, []string{host}) } if info != nil { @@ -215,18 +271,44 @@ func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { } } +// UpdateDHCP updates [SourceDHCP] runtime client information. +func (s *Storage) UpdateDHCP() { + if s.dhcp == nil { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + src := SourceDHCP + s.runtimeIndex.clearSource(src) + + added := 0 + for _, l := range s.dhcp.Leases() { + s.runtimeIndex.setInfo(l.IP, src, []string{l.Hostname}) + added++ + } + + removed := s.runtimeIndex.removeEmpty() + log.Debug("storage: added %d, removed %d client aliases from dhcp", added, removed) +} + // setWHOISInfo sets the WHOIS information for a runtime client. func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) { - _, ok := s.Find(ip.String()) + _, ok := s.index.findByIP(ip) if ok { log.Debug("storage: client for %s is already created, ignore whois info", ip) return } - rc := NewRuntime(ip) - rc.SetWHOIS(wi) - s.UpdateRuntime(rc) + rc := s.runtimeIndex.client(ip) + if rc == nil { + rc = NewRuntime(ip) + s.runtimeIndex.add(rc) + } + + rc.setWHOIS(wi) log.Debug("storage: set whois info for runtime client with ip %s: %+v", ip, wi) } @@ -422,9 +504,9 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { s.mu.Lock() defer s.mu.Unlock() - rc = s.runtimeIndex.Client(ip) + rc = s.runtimeIndex.client(ip) if rc != nil { - return rc + return rc.clone() } host := s.dhcp.HostByIP(ip) @@ -432,87 +514,9 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { return nil } - // TODO(s.chzhen): Update runtime index. - rc = NewRuntime(ip) - rc.SetInfo(SourceDHCP, []string{host}) + rc = s.runtimeIndex.setInfo(ip, SourceDHCP, []string{host}) - return rc -} - -// UpdateRuntime updates the stored runtime client with information from rc. If -// no such client exists, saves the copy of rc in storage. rc must not be nil. -func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.updateRuntimeLocked(rc) -} - -// updateRuntimeLocked updates the stored runtime client with information from -// rc. rc must not be nil. Storage.mu is expected to be locked. -func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) { - stored := s.runtimeIndex.Client(rc.ip) - if stored == nil { - s.runtimeIndex.Add(rc.Clone()) - - return true - } - - if rc.whois != nil { - stored.whois = rc.whois.Clone() - } - - if rc.arp != nil { - stored.arp = slices.Clone(rc.arp) - } - - if rc.rdns != nil { - stored.rdns = slices.Clone(rc.rdns) - } - - if rc.dhcp != nil { - stored.dhcp = slices.Clone(rc.dhcp) - } - - if rc.hostsFile != nil { - stored.hostsFile = slices.Clone(rc.hostsFile) - } - - return false -} - -// BatchUpdateBySource updates the stored runtime clients information from the -// specified source and returns the number of added and removed clients. -func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) { - s.mu.Lock() - defer s.mu.Unlock() - - for _, rc := range s.runtimeIndex.index { - rc.unset(src) - } - - for _, rc := range rcs { - if s.updateRuntimeLocked(rc) { - added++ - } - } - - for ip, rc := range s.runtimeIndex.index { - if rc.isEmpty() { - delete(s.runtimeIndex.index, ip) - removed++ - } - } - - return added, removed -} - -// SizeRuntime returns the number of the runtime clients. -func (s *Storage) SizeRuntime() (n int) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.runtimeIndex.Size() + return rc.clone() } // RangeRuntime calls f for each runtime client in an undefined order. @@ -520,16 +524,5 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { s.mu.Lock() defer s.mu.Unlock() - s.runtimeIndex.Range(f) -} - -// DeleteBySource removes all runtime clients that have information only from -// the specified source and returns the number of removed clients. -// -// TODO(s.chzhen): Use it. -func (s *Storage) DeleteBySource(src Source) (n int) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.runtimeIndex.DeleteBySource(src) + s.runtimeIndex.rangeF(f) } diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 1f2dd022..7f9ffe60 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -19,6 +19,8 @@ import ( "github.com/stretchr/testify/require" ) +// testHostsContainer is a mock implementation of the [client.HostsContainer] +// interface. type testHostsContainer struct { onUpd func() (updates <-chan *hostsfile.DefaultStorage) } @@ -26,6 +28,7 @@ type testHostsContainer struct { // type check var _ client.HostsContainer = (*testHostsContainer)(nil) +// Upd implements the [client.HostsContainer] interface for *testHostsContainer. func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) { return c.onUpd() } @@ -41,33 +44,42 @@ type Interface interface { Neighbors() (ns []arpdb.Neighbor) } +// testARP is a mock implementation of the [arpdb.Interface]. type testARP struct { - onRefresh func() (err error) - + onRefresh func() (err error) onNeighbors func() (ns []arpdb.Neighbor) } +// type check +var _ arpdb.Interface = (*testARP)(nil) + +// Refresh implements the [arpdb.Interface] interface for *testARP. func (c *testARP) Refresh() (err error) { return c.onRefresh() } +// Neighbors implements the [arpdb.Interface] interface for *testARP. func (c *testARP) Neighbors() (ns []arpdb.Neighbor) { return c.onNeighbors() } +// testDHCP is a mock implementation of the [client.DHCP]. type testDHCP struct { OnLeases func() (leases []*dhcpsvc.Lease) OnHostBy func(ip netip.Addr) (host string) OnMACBy func(ip netip.Addr) (mac net.HardwareAddr) } -// Lease implements the [DHCP] interface for testDHCP. +// type check +var _ client.DHCP = (*testDHCP)(nil) + +// Lease implements the [client.DHCP] interface for *testDHCP. func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() } -// HostByIP implements the [DHCP] interface for testDHCP. +// HostByIP implements the [client.DHCP] interface for *testDHCP. func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) } -// MACByIP implements the [DHCP] interface for testDHCP. +// MACByIP implements the [client.DHCP] interface for *testDHCP. func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) } // compareRuntimeInfo is a helper function that returns true if the runtime @@ -98,11 +110,17 @@ func TestStorage_Add_hostsfile(t *testing.T) { } storage, err := client.NewStorage(&client.Config{ - EtcHosts: h, + EtcHosts: h, + ARPClientsUpdatePeriod: testTimeout / 10, }) require.NoError(t, err) - storage.Start() + err = storage.Start(testutil.ContextWithTimeout(t, testTimeout)) + require.NoError(t, err) + + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout)) + }) t.Run("add_hosts", func(t *testing.T) { var s *hostsfile.DefaultStorage @@ -184,7 +202,12 @@ func TestStorage_Add_arp(t *testing.T) { }) require.NoError(t, err) - storage.Start() + err = storage.Start(testutil.ContextWithTimeout(t, testTimeout)) + require.NoError(t, err) + + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout)) + }) t.Run("add_hosts", func(t *testing.T) { func() { @@ -292,11 +315,19 @@ func TestStorage_Add_whois(t *testing.T) { func TestClientsDHCP(t *testing.T) { var ( cliIP1 = netip.MustParseAddr("1.1.1.1") - cliName1 = "client_one" + cliName1 = "one.dhcp" + + cliIP2 = netip.MustParseAddr("2.2.2.2") + cliMAC2 = mustParseMAC("22:22:22:22:22:22") + cliName2 = "two.dhcp" + + cliIP3 = netip.MustParseAddr("3.3.3.3") + cliMAC3 = mustParseMAC("33:33:33:33:33:33") + cliName3 = "three.dhcp" prsCliIP = netip.MustParseAddr("4.3.2.1") prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA") - prsCliName = "persitent_client" + prsCliName = "persitent.dhcp" ) ipToHost := map[netip.Addr]string{ @@ -306,8 +337,20 @@ func TestClientsDHCP(t *testing.T) { prsCliIP: prsCliMAC, } + leases := []*dhcpsvc.Lease{{ + IP: cliIP2, + Hostname: cliName2, + HWAddr: cliMAC2, + }, { + IP: cliIP3, + Hostname: cliName3, + HWAddr: cliMAC3, + }} + d := &testDHCP{ - OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") }, + OnLeases: func() (ls []*dhcpsvc.Lease) { + return leases + }, OnHostBy: func(ip netip.Addr) (host string) { return ipToHost[ip] }, @@ -341,6 +384,34 @@ func TestClientsDHCP(t *testing.T) { assert.Equal(t, prsCliName, prsCli.Name) }) + + t.Run("leases", func(t *testing.T) { + delete(ipToHost, cliIP1) + storage.UpdateDHCP() + + cli1 := storage.ClientRuntime(cliIP1) + require.Nil(t, cli1) + + for i, l := range leases { + cli := storage.ClientRuntime(l.IP) + require.NotNil(t, cli) + + src, host := cli.Info() + assert.Equal(t, client.SourceDHCP, src) + assert.Equal(t, leases[i].Hostname, host) + } + }) + + t.Run("range", func(t *testing.T) { + s := 0 + storage.RangeRuntime(func(rc *client.Runtime) (cont bool) { + s++ + + return true + }) + + assert.Equal(t, len(leases), s) + }) } func TestClientsAddExisting(t *testing.T) { @@ -439,14 +510,6 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { return s } -// newRuntimeClient is a helper function that returns a new runtime client. -func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) { - rc = client.NewRuntime(ip) - rc.SetInfo(source, []string{host}) - - return rc -} - // mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an // error. func mustParseMAC(s string) (mac net.HardwareAddr) { @@ -1037,159 +1100,3 @@ func TestStorage_RangeByName(t *testing.T) { }) } } - -func TestStorage_UpdateRuntime(t *testing.T) { - const ( - addedARP = "added_arp" - addedSecondARP = "added_arp" - - updatedARP = "updated_arp" - - cliCity = "City" - cliCountry = "Country" - cliOrgname = "Orgname" - ) - - var ( - ip = netip.MustParseAddr("1.1.1.1") - ip2 = netip.MustParseAddr("2.2.2.2") - ) - - updated := client.NewRuntime(ip) - updated.SetInfo(client.SourceARP, []string{updatedARP}) - - info := &whois.Info{ - City: cliCity, - Country: cliCountry, - Orgname: cliOrgname, - } - updated.SetWHOIS(info) - - s, err := client.NewStorage(&client.Config{ - AllowedTags: nil, - }) - require.NoError(t, err) - - t.Run("add_arp_client", func(t *testing.T) { - added := client.NewRuntime(ip) - added.SetInfo(client.SourceARP, []string{addedARP}) - - require.True(t, s.UpdateRuntime(added)) - require.Equal(t, 1, s.SizeRuntime()) - - got := s.ClientRuntime(ip) - source, host := got.Info() - assert.Equal(t, client.SourceARP, source) - assert.Equal(t, addedARP, host) - }) - - t.Run("add_second_arp_client", func(t *testing.T) { - added := client.NewRuntime(ip2) - added.SetInfo(client.SourceARP, []string{addedSecondARP}) - - require.True(t, s.UpdateRuntime(added)) - require.Equal(t, 2, s.SizeRuntime()) - - got := s.ClientRuntime(ip2) - source, host := got.Info() - assert.Equal(t, client.SourceARP, source) - assert.Equal(t, addedSecondARP, host) - }) - - t.Run("update_first_client", func(t *testing.T) { - require.False(t, s.UpdateRuntime(updated)) - got := s.ClientRuntime(ip) - require.Equal(t, 2, s.SizeRuntime()) - - source, host := got.Info() - assert.Equal(t, client.SourceARP, source) - assert.Equal(t, updatedARP, host) - }) - - t.Run("remove_arp_info", func(t *testing.T) { - n := s.DeleteBySource(client.SourceARP) - require.Equal(t, 1, n) - require.Equal(t, 1, s.SizeRuntime()) - - got := s.ClientRuntime(ip) - source, _ := got.Info() - assert.Equal(t, client.SourceWHOIS, source) - assert.Equal(t, info, got.WHOIS()) - }) - - t.Run("remove_whois_info", func(t *testing.T) { - n := s.DeleteBySource(client.SourceWHOIS) - require.Equal(t, 1, n) - require.Equal(t, 0, s.SizeRuntime()) - }) -} - -func TestStorage_BatchUpdateBySource(t *testing.T) { - const ( - defSrc = client.SourceARP - - cliFirstHost1 = "host1" - cliFirstHost2 = "host2" - cliUpdatedHost3 = "host3" - cliUpdatedHost4 = "host4" - cliUpdatedHost5 = "host5" - ) - - var ( - cliFirstIP1 = netip.MustParseAddr("1.1.1.1") - cliFirstIP2 = netip.MustParseAddr("2.2.2.2") - cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3") - cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4") - cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5") - ) - - firstClients := []*client.Runtime{ - newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1), - newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2), - } - - updatedClients := []*client.Runtime{ - newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3), - newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4), - newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5), - } - - s, err := client.NewStorage(&client.Config{ - AllowedTags: nil, - }) - require.NoError(t, err) - - t.Run("populate_storage_with_first_clients", func(t *testing.T) { - added, removed := s.BatchUpdateBySource(defSrc, firstClients) - require.Equal(t, len(firstClients), added) - require.Equal(t, 0, removed) - require.Equal(t, len(firstClients), s.SizeRuntime()) - - rc := s.ClientRuntime(cliFirstIP1) - src, host := rc.Info() - assert.Equal(t, defSrc, src) - assert.Equal(t, cliFirstHost1, host) - }) - - t.Run("update_storage", func(t *testing.T) { - added, removed := s.BatchUpdateBySource(defSrc, updatedClients) - require.Equal(t, len(updatedClients), added) - require.Equal(t, len(firstClients), removed) - require.Equal(t, len(updatedClients), s.SizeRuntime()) - - rc := s.ClientRuntime(cliUpdatedIP3) - src, host := rc.Info() - assert.Equal(t, defSrc, src) - assert.Equal(t, cliUpdatedHost3, host) - - rc = s.ClientRuntime(cliFirstIP1) - assert.Nil(t, rc) - }) - - t.Run("remove_all", func(t *testing.T) { - added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{}) - require.Equal(t, 0, added) - require.Equal(t, len(updatedClients), removed) - require.Equal(t, 0, s.SizeRuntime()) - }) -} diff --git a/internal/home/clients.go b/internal/home/clients.go index 0952991c..9a4dc630 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -1,6 +1,7 @@ package home import ( + "context" "fmt" "net/netip" "slices" @@ -110,7 +111,7 @@ func (clients *clientsContainer) Init( var webHandlersRegistered = false // Start starts the clients container. -func (clients *clientsContainer) Start() { +func (clients *clientsContainer) Start(ctx context.Context) (err error) { if clients.testing { return } @@ -120,7 +121,7 @@ func (clients *clientsContainer) Start() { clients.registerWebHandlers() } - clients.storage.Start() + return clients.storage.Start(ctx) } // clientObject is the YAML representation of a persistent client. diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 1fd004c8..927f9a32 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -3,34 +3,14 @@ package home import ( "net" "net/netip" - "runtime" "testing" - "time" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type testDHCP struct { - OnLeases func() (leases []*dhcpsvc.Lease) - OnHostBy func(ip netip.Addr) (host string) - OnMACBy func(ip netip.Addr) (mac net.HardwareAddr) -} - -// Lease implements the [DHCP] interface for testDHCP. -func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() } - -// HostByIP implements the [DHCP] interface for testDHCP. -func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) } - -// MACByIP implements the [DHCP] interface for testDHCP. -func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) } - // newClientsContainer is a helper that creates a new clients container for // tests. func newClientsContainer(t *testing.T) (c *clientsContainer) { @@ -40,359 +20,11 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) { testing: true, } - dhcp := &testDHCP{ - OnLeases: func() (leases []*dhcpsvc.Lease) { return nil }, - OnHostBy: func(ip netip.Addr) (host string) { return "" }, - OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil }, - } - - require.NoError(t, c.Init(nil, dhcp, nil, nil, &filtering.Config{})) + require.NoError(t, c.Init(nil, nil, nil, nil, &filtering.Config{})) return c } -// addHost adds a new IP-hostname pairing. -func (clients *clientsContainer) addHost( - ip netip.Addr, - host string, - src client.Source, -) (ok bool) { - rc := client.NewRuntime(ip) - rc.SetInfo(src, []string{host}) - clients.storage.UpdateRuntime(rc) - - return true -} - -// setWHOISInfo sets the WHOIS information for a client. -func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) { - _, ok := clients.storage.Find(ip.String()) - if ok { - return - } - - rc := client.NewRuntime(ip) - rc.SetWHOIS(wi) - clients.storage.UpdateRuntime(rc) -} - -// clientSource checks if client with this IP address already exists and returns -// the highest priority client source. -func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) { - _, ok := clients.storage.Find(ip.String()) - if ok { - return client.SourcePersistent - } - - rc := clients.storage.ClientRuntime(ip) - if rc != nil { - src, _ = rc.Info() - } - - return src -} - -func TestClients(t *testing.T) { - clients := newClientsContainer(t) - - t.Run("add_success", func(t *testing.T) { - var ( - cliNone = "1.2.3.4" - cli1 = "1.1.1.1" - cli2 = "2.2.2.2" - - cli1IP = netip.MustParseAddr(cli1) - cli2IP = netip.MustParseAddr(cli2) - - cliIPv6 = netip.MustParseAddr("1:2:3::4") - ) - - c := &client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - IPs: []netip.Addr{cli1IP, cliIPv6}, - } - - err := clients.storage.Add(c) - require.NoError(t, err) - - c = &client.Persistent{ - Name: "client2", - UID: client.MustNewUID(), - IPs: []netip.Addr{cli2IP}, - } - - err = clients.storage.Add(c) - require.NoError(t, err) - - c, ok := clients.storage.Find(cli1) - require.True(t, ok) - - assert.Equal(t, "client1", c.Name) - - c, ok = clients.storage.Find("1:2:3::4") - require.True(t, ok) - - assert.Equal(t, "client1", c.Name) - - c, ok = clients.storage.Find(cli2) - require.True(t, ok) - - assert.Equal(t, "client2", c.Name) - - _, ok = clients.storage.Find(cliNone) - assert.False(t, ok) - - assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent) - assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent) - }) - - t.Run("add_fail_name", func(t *testing.T) { - err := clients.storage.Add(&client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, - }) - require.Error(t, err) - }) - - t.Run("add_fail_ip", func(t *testing.T) { - err := clients.storage.Add(&client.Persistent{ - Name: "client3", - UID: client.MustNewUID(), - }) - require.Error(t, err) - }) - - t.Run("update_fail_ip", func(t *testing.T) { - err := clients.storage.Update("client1", &client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - }) - assert.Error(t, err) - }) - - t.Run("update_success", func(t *testing.T) { - var ( - cliOld = "1.1.1.1" - cliNew = "1.1.1.2" - - cliNewIP = netip.MustParseAddr(cliNew) - ) - - prev, ok := clients.storage.FindByName("client1") - require.True(t, ok) - require.NotNil(t, prev) - - err := clients.storage.Update("client1", &client.Persistent{ - Name: "client1", - UID: prev.UID, - IPs: []netip.Addr{cliNewIP}, - }) - require.NoError(t, err) - - _, ok = clients.storage.Find(cliOld) - assert.False(t, ok) - - assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent) - - prev, ok = clients.storage.FindByName("client1") - require.True(t, ok) - require.NotNil(t, prev) - - err = clients.storage.Update("client1", &client.Persistent{ - Name: "client1-renamed", - UID: prev.UID, - IPs: []netip.Addr{cliNewIP}, - UseOwnSettings: true, - }) - require.NoError(t, err) - - c, ok := clients.storage.Find(cliNew) - require.True(t, ok) - - assert.Equal(t, "client1-renamed", c.Name) - assert.True(t, c.UseOwnSettings) - - nilCli, ok := clients.storage.FindByName("client1") - require.False(t, ok) - - assert.Nil(t, nilCli) - - require.Len(t, c.IDs(), 1) - - assert.Equal(t, cliNewIP, c.IPs[0]) - }) - - t.Run("del_success", func(t *testing.T) { - ok := clients.storage.RemoveByName("client1-renamed") - require.True(t, ok) - - _, ok = clients.storage.Find("1.1.1.2") - assert.False(t, ok) - }) - - t.Run("del_fail", func(t *testing.T) { - ok := clients.storage.RemoveByName("client3") - assert.False(t, ok) - }) - - t.Run("addhost_success", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.1") - ok := clients.addHost(ip, "host", client.SourceARP) - assert.True(t, ok) - - ok = clients.addHost(ip, "host2", client.SourceARP) - assert.True(t, ok) - - ok = clients.addHost(ip, "host3", client.SourceHostsFile) - assert.True(t, ok) - - assert.Equal(t, clients.clientSource(ip), client.SourceHostsFile) - }) - - t.Run("dhcp_replaces_arp", func(t *testing.T) { - ip := netip.MustParseAddr("1.2.3.4") - ok := clients.addHost(ip, "from_arp", client.SourceARP) - assert.True(t, ok) - assert.Equal(t, clients.clientSource(ip), client.SourceARP) - - ok = clients.addHost(ip, "from_dhcp", client.SourceDHCP) - assert.True(t, ok) - assert.Equal(t, clients.clientSource(ip), client.SourceDHCP) - }) - - t.Run("addhost_priority", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.1") - ok := clients.addHost(ip, "host1", client.SourceRDNS) - assert.True(t, ok) - - assert.Equal(t, client.SourceHostsFile, clients.clientSource(ip)) - }) -} - -func TestClientsWHOIS(t *testing.T) { - clients := newClientsContainer(t) - whois := &whois.Info{ - Country: "AU", - Orgname: "Example Org", - } - - t.Run("new_client", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.255") - clients.setWHOISInfo(ip, whois) - rc := clients.storage.ClientRuntime(ip) - require.NotNil(t, rc) - - assert.Equal(t, whois, rc.WHOIS()) - }) - - t.Run("existing_auto-client", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.1") - ok := clients.addHost(ip, "host", client.SourceRDNS) - assert.True(t, ok) - - clients.setWHOISInfo(ip, whois) - rc := clients.storage.ClientRuntime(ip) - require.NotNil(t, rc) - - assert.Equal(t, whois, rc.WHOIS()) - }) - - t.Run("can't_set_manually-added", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.2") - - err := clients.storage.Add(&client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, - }) - require.NoError(t, err) - - clients.setWHOISInfo(ip, whois) - rc := clients.storage.ClientRuntime(ip) - require.Nil(t, rc) - - assert.True(t, clients.storage.RemoveByName("client1")) - }) -} - -func TestClientsAddExisting(t *testing.T) { - clients := &clientsContainer{ - testing: true, - } - - // First, init a DHCP server with a single static lease. - config := &dhcpd.ServerConfig{ - Enabled: true, - DataDir: t.TempDir(), - Conf4: dhcpd.V4ServerConf{ - Enabled: true, - GatewayIP: netip.MustParseAddr("1.2.3.1"), - SubnetMask: netip.MustParseAddr("255.255.255.0"), - RangeStart: netip.MustParseAddr("1.2.3.2"), - RangeEnd: netip.MustParseAddr("1.2.3.10"), - }, - } - - dhcpServer, err := dhcpd.Create(config) - require.NoError(t, err) - - require.NoError(t, clients.Init(nil, dhcpServer, nil, nil, &filtering.Config{})) - - t.Run("simple", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.1") - - // Add a client. - err = clients.storage.Add(&client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, - Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")}, - MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, - }) - require.NoError(t, err) - - // Now add an auto-client with the same IP. - ok := clients.addHost(ip, "test", client.SourceRDNS) - assert.True(t, ok) - }) - - t.Run("complicated", func(t *testing.T) { - // TODO(a.garipov): Properly decouple the DHCP server from the client - // storage. - if runtime.GOOS == "windows" { - t.Skip("skipping dhcp test on windows") - } - - ip := netip.MustParseAddr("1.2.3.4") - - err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{ - HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, - IP: ip, - Hostname: "testhost", - Expiry: time.Now().Add(time.Hour), - }) - require.NoError(t, err) - - // Add a new client with the same IP as for a client with MAC. - err = clients.storage.Add(&client.Persistent{ - Name: "client2", - UID: client.MustNewUID(), - IPs: []netip.Addr{ip}, - }) - require.NoError(t, err) - - // Add a new client with the IP from the first client's IP range. - err = clients.storage.Add(&client.Persistent{ - Name: "client3", - UID: client.MustNewUID(), - IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, - }) - require.NoError(t, err) - }) -} - func TestClientsCustomUpstream(t *testing.T) { clients := newClientsContainer(t) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 7766ba57..b6f27448 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -103,6 +103,8 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http return true }) + clients.storage.UpdateDHCP() + clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) { src, host := rc.Info() cj := runtimeClientJSON{ @@ -117,18 +119,6 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http return true }) - // TODO(s.chzhen): Remove. - for _, l := range clients.dhcp.Leases() { - cj := runtimeClientJSON{ - Name: l.Hostname, - Source: client.SourceDHCP, - IP: l.IP, - WHOIS: &whois.Info{}, - } - - data.RuntimeClients = append(data.RuntimeClients, cj) - } - data.Tags = clientTags aghhttp.WriteJSONResponseOK(w, r, data) diff --git a/internal/home/dns.go b/internal/home/dns.go index 44c8f93f..72eb3a88 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -1,6 +1,7 @@ package home import ( + "context" "fmt" "log/slog" "net" @@ -457,9 +458,12 @@ func startDNSServer() error { Context.filters.EnableFilters(false) - Context.clients.Start() + err := Context.clients.Start(context.TODO()) + if err != nil { + return fmt.Errorf("couldn't start clients container: %w", err) + } - err := Context.dnsServer.Start() + err = Context.dnsServer.Start() if err != nil { return fmt.Errorf("couldn't start forwarding DNS server: %w", err) }