diff --git a/internal/client/index.go b/internal/client/index.go index 18f826de..63ae690e 100644 --- a/internal/client/index.go +++ b/internal/client/index.go @@ -4,9 +4,12 @@ import ( "fmt" "net" "net/netip" + "slices" + "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/golibs/errors" + "golang.org/x/exp/maps" ) // macKey contains MAC as byte array of 6, 8, or 20 bytes. @@ -29,6 +32,9 @@ func macToKey(mac net.HardwareAddr) (key macKey) { // Index stores all information about persistent clients. type Index struct { + // nameToUID maps client name to UID. + nameToUID map[string]UID + // clientIDToUID maps client ID to UID. clientIDToUID map[string]UID @@ -48,6 +54,7 @@ type Index struct { // NewIndex initializes the new instance of client index. func NewIndex() (ci *Index) { return &Index{ + nameToUID: map[string]UID{}, clientIDToUID: map[string]UID{}, ipToUID: map[netip.Addr]UID{}, subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare), @@ -63,6 +70,8 @@ func (ci *Index) Add(c *Persistent) { panic("client must contain uid") } + ci.nameToUID[c.Name] = c.UID + for _, id := range c.ClientIDs { ci.clientIDToUID[id] = c.UID } @@ -83,21 +92,26 @@ func (ci *Index) Add(c *Persistent) { ci.uidToClient[c.UID] = c } -// ErrDuplicateUID is an error returned by [Index.Clashes] when adding a -// persistent client with a UID that already exists in an index. -const ErrDuplicateUID errors.Error = "duplicate uid" +// ClashesUID returns existing persistent client with the same UID as c. Note +// that this is only possible when configuration contains duplicate fields. +func (ci *Index) ClashesUID(c *Persistent) (err error) { + p, ok := ci.uidToClient[c.UID] + if ok { + return fmt.Errorf("another client %q uses the same uid", p.Name) + } + + return nil +} // Clashes returns an error if the index contains a different persistent client // with at least a single identifier contained by c. c must be non-nil. func (ci *Index) Clashes(c *Persistent) (err error) { - _, ok := ci.uidToClient[c.UID] - if ok { - return ErrDuplicateUID + if p := ci.clashesName(c); p != nil { + return fmt.Errorf("another client uses the same name %q", p.Name) } for _, id := range c.ClientIDs { - var existing UID - existing, ok = ci.clientIDToUID[id] + existing, ok := ci.clientIDToUID[id] if ok && existing != c.UID { p := ci.uidToClient[existing] @@ -123,6 +137,21 @@ func (ci *Index) Clashes(c *Persistent) (err error) { return nil } +// clashesName returns existing persistent client with the same name as c or +// nil. c must be non-nil. +func (ci *Index) clashesName(c *Persistent) (existing *Persistent) { + existing, ok := ci.FindByName(c.Name) + if !ok { + return nil + } + + if existing.UID != c.UID { + return existing + } + + return nil +} + // clashesIP returns a previous client with the same IP address as c. c must be // non-nil. func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) { @@ -195,13 +224,23 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) { mac, err := net.ParseMAC(id) if err == nil { - return ci.findByMAC(mac) + return ci.FindByMAC(mac) } return nil, false } -// find finds persistent client by IP address. +// FindByName finds persistent client by name. +func (ci *Index) FindByName(name string) (c *Persistent, found bool) { + uid, found := ci.nameToUID[name] + if found { + return ci.uidToClient[uid], true + } + + return nil, false +} + +// findByIP finds persistent client by IP address. func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) { uid, found := ci.ipToUID[ip] if found { @@ -227,6 +266,17 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) { return nil, false } +// FindByMAC finds persistent client by MAC. +func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { + k := macToKey(mac) + uid, found := ci.macToUID[k] + if found { + return ci.uidToClient[uid], true + } + + return nil, false +} + // FindByIPWithoutZone finds a persistent client by IP address without zone. It // strips the IPv6 zone index from the stored IP addresses before comparing, // because querylog entries don't have it. See TODO on [querylog.logEntry.IP]. @@ -247,20 +297,11 @@ func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) { return nil } -// find finds persistent client by MAC. -func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { - k := macToKey(mac) - uid, found := ci.macToUID[k] - if found { - return ci.uidToClient[uid], true - } - - return nil, false -} - // Delete removes information about persistent client from the index. c must be // non-nil. func (ci *Index) Delete(c *Persistent) { + delete(ci.nameToUID, c.Name) + for _, id := range c.ClientIDs { delete(ci.clientIDToUID, id) } @@ -280,3 +321,48 @@ func (ci *Index) Delete(c *Persistent) { delete(ci.uidToClient, c.UID) } + +// Size returns the number of persistent clients. +func (ci *Index) Size() (n int) { + return len(ci.uidToClient) +} + +// Range calls f for each persistent client, unless cont is false. The order is +// undefined. +func (ci *Index) Range(f func(c *Persistent) (cont bool)) { + for _, c := range ci.uidToClient { + if !f(c) { + return + } + } +} + +// RangeByName is like [Index.Range] but sorts the persistent clients by name +// before iterating ensuring a predictable order. +func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) { + cs := maps.Values(ci.uidToClient) + slices.SortFunc(cs, func(a, b *Persistent) (n int) { + return strings.Compare(a.Name, b.Name) + }) + + for _, c := range cs { + if !f(c) { + break + } + } +} + +// CloseUpstreams closes upstream configurations of persistent clients. +func (ci *Index) CloseUpstreams() (err error) { + var errs []error + ci.RangeByName(func(c *Persistent) (cont bool) { + err = c.CloseUpstreams() + if err != nil { + errs = append(errs, err) + } + + return true + }) + + return errors.Join(errs...) +} diff --git a/internal/client/index_internal_test.go b/internal/client/index_internal_test.go index 4e478462..38c0df15 100644 --- a/internal/client/index_internal_test.go +++ b/internal/client/index_internal_test.go @@ -22,7 +22,7 @@ func newIDIndex(m []*Persistent) (ci *Index) { return ci } -func TestClientIndex(t *testing.T) { +func TestClientIndex_Find(t *testing.T) { const ( cliIPNone = "1.2.3.4" cliIP1 = "1.1.1.1" @@ -71,13 +71,14 @@ func TestClientIndex(t *testing.T) { } ) - ci := newIDIndex([]*Persistent{ + clients := []*Persistent{ clientWithBothFams, clientWithSubnet, clientWithMAC, clientWithID, clientLinkLocal, - }) + } + ci := newIDIndex(clients) testCases := []struct { want *Persistent @@ -296,3 +297,54 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) { }) } } + +func TestClientIndex_RangeByName(t *testing.T) { + sortedClients := []*Persistent{{ + Name: "clientA", + ClientIDs: []string{"A"}, + }, { + Name: "clientB", + ClientIDs: []string{"B"}, + }, { + Name: "clientC", + ClientIDs: []string{"C"}, + }, { + Name: "clientD", + ClientIDs: []string{"D"}, + }, { + Name: "clientE", + ClientIDs: []string{"E"}, + }} + + testCases := []struct { + name string + want []*Persistent + }{{ + name: "basic", + want: sortedClients, + }, { + name: "nil", + want: nil, + }, { + name: "one_element", + want: sortedClients[:1], + }, { + name: "two_elements", + want: sortedClients[:2], + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ci := newIDIndex(tc.want) + + var got []*Persistent + ci.RangeByName(func(c *Persistent) (cont bool) { + got = append(got, c) + + return true + }) + + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/internal/home/clients.go b/internal/home/clients.go index 9cc210ab..4f3870ec 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -24,7 +24,6 @@ import ( "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" - "golang.org/x/exp/maps" ) // DHCP is an interface for accessing DHCP lease data the [clientsContainer] @@ -46,10 +45,6 @@ type DHCP interface { // clientsContainer is the storage of all runtime and persistent clients. type clientsContainer struct { - // TODO(a.garipov): Perhaps use a number of separate indices for different - // types (string, netip.Addr, and so on). - list map[string]*client.Persistent // name -> client - // clientIndex stores information about persistent clients. clientIndex *client.Index @@ -61,8 +56,9 @@ type clientsContainer struct { // dhcp is the DHCP service implementation. dhcp DHCP - // dnsServer is used for checking clients IP status access list status - dnsServer *dnsforward.Server + // clientChecker checks if a client is blocked by the current access + // settings. + clientChecker BlockedClientChecker // etcHosts contains list of rewrite rules taken from the operating system's // hosts database. @@ -91,6 +87,12 @@ type clientsContainer struct { testing bool } +// BlockedClientChecker checks if a client is blocked by the current access +// settings. +type BlockedClientChecker interface { + IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string) +} + // Init initializes clients container // dhcpServer: optional // Note: this function must be called only once @@ -101,11 +103,11 @@ func (clients *clientsContainer) Init( arpDB arpdb.Interface, filteringConf *filtering.Config, ) (err error) { - if clients.list != nil { - log.Fatal("clients.list != nil") + // TODO(s.chzhen): Refactor it. + if clients.clientIndex != nil { + return errors.Error("clients container already initialized") } - clients.list = map[string]*client.Persistent{} clients.runtimeIndex = client.NewRuntimeIndex() clients.clientIndex = client.NewIndex() @@ -284,12 +286,14 @@ func (clients *clientsContainer) addFromConfig( return fmt.Errorf("clients: init persistent client at index %d: %w", i, err) } - _, err = clients.add(cli) + // TODO(s.chzhen): Consider moving to the client index constructor. + err = clients.clientIndex.ClashesUID(cli) if err != nil { - if errors.Is(err, client.ErrDuplicateUID) { - return fmt.Errorf("clients: adding client %s at index %d: %w", cli.Name, i, err) - } + return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err) + } + err = clients.add(cli) + if err != nil { // TODO(s.chzhen): Return an error instead of logging if more // stringent requirements are implemented. log.Error("clients: adding client %s at index %d: %s", cli.Name, i, err) @@ -305,9 +309,9 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { clients.lock.Lock() defer clients.lock.Unlock() - objs = make([]*clientObject, 0, len(clients.list)) - for _, cli := range clients.list { - o := &clientObject{ + objs = make([]*clientObject, 0, clients.clientIndex.Size()) + clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) { + objs = append(objs, &clientObject{ Name: cli.Name, BlockedServices: cli.BlockedServices.Clone(), @@ -328,10 +332,10 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { IgnoreStatistics: cli.IgnoreStatistics, UpstreamsCacheEnabled: cli.UpstreamsCacheEnabled, UpstreamsCacheSize: cli.UpstreamsCacheSize, - } + }) - objs = append(objs, o) - } + return true + }) // Maps aren't guaranteed to iterate in the same order each time, so the // above loop can generate different orderings when writing to the config @@ -411,7 +415,7 @@ func (clients *clientsContainer) clientOrArtificial( id string, ) (c *querylog.Client, art bool) { defer func() { - c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id) + c.Disallowed, c.DisallowedRule = clients.clientChecker.IsBlockedClient(ip, id) if c.WHOIS == nil { c.WHOIS = &whois.Info{} } @@ -550,14 +554,7 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, return nil, false } - for _, c = range clients.list { - _, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr]) - if found { - return c, true - } - } - - return nil, false + return clients.clientIndex.FindByMAC(foundMAC) } // runtimeClient returns a runtime client from internal index. Note that it @@ -621,43 +618,32 @@ func (clients *clientsContainer) check(c *client.Persistent) (err error) { return nil } -// add adds a new client object. ok is false if such client already exists or -// if an error occurred. -func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) { +// add adds a persistent client or returns an error. +func (clients *clientsContainer) add(c *client.Persistent) (err error) { err = clients.check(c) if err != nil { - return false, err + // Don't wrap the error since it's informative enough as is. + return err } clients.lock.Lock() defer clients.lock.Unlock() - // check Name index - _, ok = clients.list[c.Name] - if ok { - return false, nil - } - - // check ID index err = clients.clientIndex.Clashes(c) if err != nil { // Don't wrap the error since it's informative enough as is. - return false, err + return err } clients.addLocked(c) - log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), len(clients.list)) + log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size()) - return true, nil + return nil } // addLocked c to the indexes. clients.lock is expected to be locked. func (clients *clientsContainer) addLocked(c *client.Persistent) { - // update Name index - clients.list[c.Name] = c - - // update ID index clients.clientIndex.Add(c) } @@ -666,8 +652,7 @@ func (clients *clientsContainer) remove(name string) (ok bool) { clients.lock.Lock() defer clients.lock.Unlock() - var c *client.Persistent - c, ok = clients.list[name] + c, ok := clients.clientIndex.FindByName(name) if !ok { return false } @@ -684,9 +669,6 @@ func (clients *clientsContainer) removeLocked(c *client.Persistent) { log.Error("client container: removing client %s: %s", c.Name, err) } - // Update the name index. - delete(clients.list, c.Name) - // Update the ID index. clients.clientIndex.Delete(c) } @@ -702,22 +684,6 @@ func (clients *clientsContainer) update(prev, c *client.Persistent) (err error) clients.lock.Lock() defer clients.lock.Unlock() - // Check the name index. - if prev.Name != c.Name { - _, ok := clients.list[c.Name] - if ok { - return errors.Error("client already exists") - } - } - - if c.EqualIDs(prev) { - clients.removeLocked(prev) - clients.addLocked(c) - - return nil - } - - // Check the ID index. err = clients.clientIndex.Clashes(c) if err != nil { // Don't wrap the error since it's informative enough as is. @@ -891,18 +857,5 @@ func (clients *clientsContainer) addFromSystemARP() { // close gracefully closes all the client-specific upstream configurations of // the persistent clients. func (clients *clientsContainer) close() (err error) { - persistent := maps.Values(clients.list) - slices.SortFunc(persistent, func(a, b *client.Persistent) (res int) { - return strings.Compare(a.Name, b.Name) - }) - - var errs []error - - for _, cli := range persistent { - if err = cli.CloseUpstreams(); err != nil { - errs = append(errs, err) - } - } - - return errors.Join(errs...) + return clients.clientIndex.CloseUpstreams() } diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index ac83bb5e..d371df7b 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -41,7 +41,7 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) { } dhcp := &testDHCP{ - OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") }, + OnLeases: func() (leases []*dhcpsvc.Lease) { return nil }, OnHostBy: func(ip netip.Addr) (host string) { return "" }, OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil }, } @@ -72,23 +72,19 @@ func TestClients(t *testing.T) { IPs: []netip.Addr{cli1IP, cliIPv6}, } - ok, err := clients.add(c) + err := clients.add(c) require.NoError(t, err) - assert.True(t, ok) - c = &client.Persistent{ Name: "client2", UID: client.MustNewUID(), IPs: []netip.Addr{cli2IP}, } - ok, err = clients.add(c) + err = clients.add(c) require.NoError(t, err) - assert.True(t, ok) - - c, ok = clients.find(cli1) + c, ok := clients.find(cli1) require.True(t, ok) assert.Equal(t, "client1", c.Name) @@ -111,22 +107,20 @@ func TestClients(t *testing.T) { }) t.Run("add_fail_name", func(t *testing.T) { - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, }) - require.NoError(t, err) - assert.False(t, ok) + require.Error(t, err) }) t.Run("add_fail_ip", func(t *testing.T) { - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client3", UID: client.MustNewUID(), }) require.Error(t, err) - assert.False(t, ok) }) t.Run("update_fail_ip", func(t *testing.T) { @@ -145,12 +139,13 @@ func TestClients(t *testing.T) { cliNewIP = netip.MustParseAddr(cliNew) ) - prev, ok := clients.list["client1"] + prev, ok := clients.clientIndex.FindByName("client1") require.True(t, ok) + require.NotNil(t, prev) err := clients.update(prev, &client.Persistent{ Name: "client1", - UID: client.MustNewUID(), + UID: prev.UID, IPs: []netip.Addr{cliNewIP}, }) require.NoError(t, err) @@ -160,12 +155,13 @@ func TestClients(t *testing.T) { assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent) - prev, ok = clients.list["client1"] + prev, ok = clients.clientIndex.FindByName("client1") require.True(t, ok) + require.NotNil(t, prev) err = clients.update(prev, &client.Persistent{ Name: "client1-renamed", - UID: client.MustNewUID(), + UID: prev.UID, IPs: []netip.Addr{cliNewIP}, UseOwnSettings: true, }) @@ -177,7 +173,7 @@ func TestClients(t *testing.T) { assert.Equal(t, "client1-renamed", c.Name) assert.True(t, c.UseOwnSettings) - nilCli, ok := clients.list["client1"] + nilCli, ok := clients.clientIndex.FindByName("client1") require.False(t, ok) assert.Nil(t, nilCli) @@ -265,13 +261,12 @@ func TestClientsWHOIS(t *testing.T) { t.Run("can't_set_manually-added", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.2") - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, }) require.NoError(t, err) - assert.True(t, ok) clients.setWHOISInfo(ip, whois) rc := clients.runtimeIndex.Client(ip) @@ -288,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") // Add a client. - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, @@ -296,10 +291,9 @@ func TestClientsAddExisting(t *testing.T) { MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, }) require.NoError(t, err) - assert.True(t, ok) // Now add an auto-client with the same IP. - ok = clients.addHost(ip, "test", client.SourceRDNS) + ok := clients.addHost(ip, "test", client.SourceRDNS) assert.True(t, ok) }) @@ -339,22 +333,20 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Add a new client with the same IP as for a client with MAC. - ok, err := clients.add(&client.Persistent{ + err = clients.add(&client.Persistent{ Name: "client2", UID: client.MustNewUID(), IPs: []netip.Addr{ip}, }) require.NoError(t, err) - assert.True(t, ok) // Add a new client with the IP from the first client's IP range. - ok, err = clients.add(&client.Persistent{ + err = clients.add(&client.Persistent{ Name: "client3", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, }) require.NoError(t, err) - assert.True(t, ok) }) } @@ -362,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) { clients := newClientsContainer(t) // Add client with upstreams. - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, @@ -372,7 +364,6 @@ func TestClientsCustomUpstream(t *testing.T) { }, }) require.NoError(t, err) - assert.True(t, ok) upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver) assert.Nil(t, upsConf) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 03762f30..40a91f86 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -96,10 +96,12 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http clients.lock.Lock() defer clients.lock.Unlock() - for _, c := range clients.list { + clients.clientIndex.Range(func(c *client.Persistent) (cont bool) { cj := clientToJSON(c) data.Clients = append(data.Clients, cj) - } + + return true + }) clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) { src, host := rc.Info() @@ -334,20 +336,16 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - ok, err := clients.add(c) + err = clients.add(c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - if !ok { - aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists") - - return + if !clients.testing { + onConfigModified() } - - onConfigModified() } // handleDelClient is the handler for POST /control/clients/delete HTTP API. @@ -372,7 +370,9 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http. return } - onConfigModified() + if !clients.testing { + onConfigModified() + } } // updateJSON contains the name and data of the updated persistent client. @@ -406,7 +406,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht clients.lock.Lock() defer clients.lock.Unlock() - prev, ok = clients.list[dj.Name] + prev, ok = clients.clientIndex.FindByName(dj.Name) }() if !ok { @@ -429,7 +429,9 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - onConfigModified() + if !clients.testing { + onConfigModified() + } } // handleFindClient is the handler for GET /control/clients/find HTTP API. @@ -449,7 +451,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http cj = clients.findRuntime(ip, idStr) } else { cj = clientToJSON(c) - disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) + disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr) cj.Disallowed, cj.DisallowedRule = &disallowed, &rule } @@ -472,7 +474,7 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c // blocked IP list. // // See https://github.com/AdguardTeam/AdGuardHome/issues/2428. - disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) + disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr) cj = &clientJSON{ IDs: []string{idStr}, Disallowed: &disallowed, @@ -490,7 +492,7 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c WHOIS: whoisOrEmpty(rc), } - disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) + disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr) cj.Disallowed, cj.DisallowedRule = &disallowed, &rule return cj diff --git a/internal/home/clientshttp_internal_test.go b/internal/home/clientshttp_internal_test.go new file mode 100644 index 00000000..dc1aa87d --- /dev/null +++ b/internal/home/clientshttp_internal_test.go @@ -0,0 +1,399 @@ +package home + +import ( + "bytes" + "cmp" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "slices" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/AdGuardHome/internal/schedule" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testClientIP1 = "1.1.1.1" + testClientIP2 = "2.2.2.2" +) + +// testBlockedClientChecker is a mock implementation of the +// [BlockedClientChecker] interface. +type testBlockedClientChecker struct { + onIsBlockedClient func(ip netip.Addr, clientiD string) (blocked bool, rule string) +} + +// type check +var _ BlockedClientChecker = (*testBlockedClientChecker)(nil) + +// IsBlockedClient implements the [BlockedClientChecker] interface for +// *testBlockedClientChecker. +func (c *testBlockedClientChecker) IsBlockedClient( + ip netip.Addr, + clientID string, +) (blocked bool, rule string) { + return c.onIsBlockedClient(ip, clientID) +} + +// newPersistentClient is a helper function that returns a persistent client +// with the specified name and newly generated UID. +func newPersistentClient(name string) (c *client.Persistent) { + return &client.Persistent{ + Name: name, + UID: client.MustNewUID(), + BlockedServices: &filtering.BlockedServices{ + Schedule: &schedule.Weekly{}, + }, + } +} + +// newPersistentClientWithIDs is a helper function that returns a persistent +// client with the specified name and ids. +func newPersistentClientWithIDs(tb testing.TB, name string, ids []string) (c *client.Persistent) { + tb.Helper() + + c = newPersistentClient(name) + err := c.SetIDs(ids) + require.NoError(tb, err) + + return c +} + +// assertClients is a helper function that compares lists of persistent clients. +func assertClients(tb testing.TB, want, got []*client.Persistent) { + tb.Helper() + + require.Len(tb, got, len(want)) + + sortFunc := func(a, b *client.Persistent) (n int) { + return cmp.Compare(a.Name, b.Name) + } + + slices.SortFunc(want, sortFunc) + slices.SortFunc(got, sortFunc) + + slices.CompareFunc(want, got, func(a, b *client.Persistent) (n int) { + assert.True(tb, a.EqualIDs(b), "%q doesn't have the same ids as %q", a.Name, b.Name) + + return 0 + }) +} + +// assertPersistentClients is a helper function that uses HTTP API to check +// whether want persistent clients are the same as the persistent clients stored +// in the clients container. +func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*client.Persistent) { + tb.Helper() + + rw := httptest.NewRecorder() + clients.handleGetClients(rw, &http.Request{}) + + body, err := io.ReadAll(rw.Body) + require.NoError(tb, err) + + clientList := &clientListJSON{} + err = json.Unmarshal(body, clientList) + require.NoError(tb, err) + + var got []*client.Persistent + for _, cj := range clientList.Clients { + var c *client.Persistent + c, err = clients.jsonToClient(*cj, nil) + require.NoError(tb, err) + + got = append(got, c) + } + + assertClients(tb, want, got) +} + +// assertPersistentClientsData is a helper function that checks whether want +// persistent clients are the same as the persistent clients stored in data. +func assertPersistentClientsData( + tb testing.TB, + clients *clientsContainer, + data []map[string]*clientJSON, + want []*client.Persistent, +) { + tb.Helper() + + var got []*client.Persistent + for _, cm := range data { + for _, cj := range cm { + var c *client.Persistent + c, err := clients.jsonToClient(*cj, nil) + require.NoError(tb, err) + + got = append(got, c) + } + } + + assertClients(tb, want, got) +} + +func TestClientsContainer_HandleAddClient(t *testing.T) { + clients := newClientsContainer(t) + + clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) + clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) + + clientEmptyID := newPersistentClient("empty_client_id") + clientEmptyID.ClientIDs = []string{""} + + testCases := []struct { + name string + client *client.Persistent + wantCode int + wantClient []*client.Persistent + }{{ + name: "add_one", + client: clientOne, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientOne}, + }, { + name: "add_two", + client: clientTwo, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientOne, clientTwo}, + }, { + name: "duplicate_client", + client: clientTwo, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientOne, clientTwo}, + }, { + name: "empty_client_id", + client: clientEmptyID, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientOne, clientTwo}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cj := clientToJSON(tc.client) + + body, err := json.Marshal(cj) + require.NoError(t, err) + + r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(body)) + require.NoError(t, err) + + rw := httptest.NewRecorder() + clients.handleAddClient(rw, r) + require.NoError(t, err) + require.Equal(t, tc.wantCode, rw.Code) + + assertPersistentClients(t, clients, tc.wantClient) + }) + } +} + +func TestClientsContainer_HandleDelClient(t *testing.T) { + clients := newClientsContainer(t) + + clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) + err := clients.add(clientOne) + require.NoError(t, err) + + clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) + err = clients.add(clientTwo) + require.NoError(t, err) + + assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) + + testCases := []struct { + name string + client *client.Persistent + wantCode int + wantClient []*client.Persistent + }{{ + name: "remove_one", + client: clientOne, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientTwo}, + }, { + name: "duplicate_client", + client: clientOne, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientTwo}, + }, { + name: "empty_client_name", + client: newPersistentClient(""), + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientTwo}, + }, { + name: "remove_two", + client: clientTwo, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cj := clientToJSON(tc.client) + + var body []byte + body, err = json.Marshal(cj) + require.NoError(t, err) + + var r *http.Request + r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body)) + require.NoError(t, err) + + rw := httptest.NewRecorder() + clients.handleDelClient(rw, r) + require.NoError(t, err) + require.Equal(t, tc.wantCode, rw.Code) + + assertPersistentClients(t, clients, tc.wantClient) + }) + } +} + +func TestClientsContainer_HandleUpdateClient(t *testing.T) { + clients := newClientsContainer(t) + + clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) + err := clients.add(clientOne) + require.NoError(t, err) + + assertPersistentClients(t, clients, []*client.Persistent{clientOne}) + + clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) + + clientEmptyID := newPersistentClient("empty_client_id") + clientEmptyID.ClientIDs = []string{""} + + testCases := []struct { + name string + clientName string + modified *client.Persistent + wantCode int + wantClient []*client.Persistent + }{{ + name: "update_one", + clientName: clientOne.Name, + modified: clientModified, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientModified}, + }, { + name: "empty_name", + clientName: "", + modified: clientOne, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientModified}, + }, { + name: "client_not_found", + clientName: "client_not_found", + modified: clientOne, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientModified}, + }, { + name: "empty_client_id", + clientName: clientModified.Name, + modified: clientEmptyID, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientModified}, + }, { + name: "no_ids", + clientName: clientModified.Name, + modified: newPersistentClient("no_ids"), + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientModified}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + uj := updateJSON{ + Name: tc.clientName, + Data: *clientToJSON(tc.modified), + } + + var body []byte + body, err = json.Marshal(uj) + require.NoError(t, err) + + var r *http.Request + r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body)) + require.NoError(t, err) + + rw := httptest.NewRecorder() + clients.handleUpdateClient(rw, r) + require.NoError(t, err) + require.Equal(t, tc.wantCode, rw.Code) + + assertPersistentClients(t, clients, tc.wantClient) + }) + } +} + +func TestClientsContainer_HandleFindClient(t *testing.T) { + clients := newClientsContainer(t) + clients.clientChecker = &testBlockedClientChecker{ + onIsBlockedClient: func(ip netip.Addr, clientID string) (ok bool, rule string) { + return false, "" + }, + } + + clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) + err := clients.add(clientOne) + require.NoError(t, err) + + clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) + err = clients.add(clientTwo) + require.NoError(t, err) + + assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) + + testCases := []struct { + name string + query url.Values + wantCode int + wantClient []*client.Persistent + }{{ + name: "single", + query: url.Values{ + "ip0": []string{testClientIP1}, + }, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientOne}, + }, { + name: "multiple", + query: url.Values{ + "ip0": []string{testClientIP1}, + "ip1": []string{testClientIP2}, + }, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientOne, clientTwo}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var r *http.Request + r, err = http.NewRequest(http.MethodGet, "", nil) + require.NoError(t, err) + + r.URL.RawQuery = tc.query.Encode() + rw := httptest.NewRecorder() + clients.handleFindClient(rw, r) + require.NoError(t, err) + require.Equal(t, tc.wantCode, rw.Code) + + var body []byte + body, err = io.ReadAll(rw.Body) + require.NoError(t, err) + + clientData := []map[string]*clientJSON{} + err = json.Unmarshal(body, &clientData) + require.NoError(t, err) + + assertPersistentClientsData(t, clients, clientData, tc.wantClient) + }) + } +} diff --git a/internal/home/dns.go b/internal/home/dns.go index 1e67c4ef..d64effd5 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -148,7 +148,7 @@ func initDNSServer( return fmt.Errorf("dnsforward.NewServer: %w", err) } - Context.clients.dnsServer = Context.dnsServer + Context.clients.clientChecker = Context.dnsServer dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg) if err != nil {