diff --git a/internal/home/clients.go b/internal/home/clients.go index c924d258..f1626eb6 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -57,7 +57,7 @@ type clientsContainer struct { dhcp DHCP // dnsServer is used for checking clients IP status access list status - dnsServer *dnsforward.Server + dnsServer BlockedClientChecker // etcHosts contains list of rewrite rules taken from the operating system's // hosts database. @@ -86,6 +86,12 @@ type clientsContainer struct { testing bool } +// BlockedClientChecker checks if a client is blocked by the current access +// settings. +type BlockedClientChecker interface { + IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string) +} + // Init initializes clients container // dhcpServer: optional // Note: this function must be called only once diff --git a/internal/home/clientshttp_internal_test.go b/internal/home/clientshttp_internal_test.go index 9697abfe..c0dd2a7f 100644 --- a/internal/home/clientshttp_internal_test.go +++ b/internal/home/clientshttp_internal_test.go @@ -2,15 +2,20 @@ package home import ( "bytes" + "cmp" "encoding/json" "io" "net/http" "net/http/httptest" + "net/netip" + "net/url" + "slices" "testing" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/schedule" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,6 +24,24 @@ const ( testClientIP2 = "2.2.2.2" ) +// testBlockedClientChecker is a mock implementation of the +// [BlockedClientChecker] interface. +type testBlockedClientChecker struct { + onIsBlockedClient func(ip netip.Addr, clientiD string) (blocked bool, rule string) +} + +// type check +var _ BlockedClientChecker = (*testBlockedClientChecker)(nil) + +// IsBlockedClient implements the [BlockedClientChecker] interface for +// *testBlockedClientChecker. +func (c *testBlockedClientChecker) IsBlockedClient( + ip netip.Addr, + clientID string, +) (blocked bool, rule string) { + return c.onIsBlockedClient(ip, clientID) +} + // newPersistentClient is a helper function that returns a persistent client // with the specified name and newly generated UID. func newPersistentClient(name string) (c *client.Persistent) { @@ -43,10 +66,30 @@ func newPersistentClientWithIDs(tb testing.TB, name string, ids []string) (c *cl return c } -// clientsCompare is a helper function that uses HTTP API to check whether want -// persistent clients are the same as the persistent clients stored in the -// clients container. -func clientsCompare(tb testing.TB, clients *clientsContainer, want []*client.Persistent) (ok bool) { +// assertClients is a helper function that compares lists of persistent clients. +func assertClients(tb testing.TB, want, got []*client.Persistent) { + tb.Helper() + + require.Len(tb, want, len(got)) + + sortFunc := func(a, b *client.Persistent) (n int) { + return cmp.Compare(a.Name, b.Name) + } + + slices.SortFunc(want, sortFunc) + slices.SortFunc(got, sortFunc) + + slices.CompareFunc(want, got, func(a, b *client.Persistent) (n int) { + assert.True(tb, a.EqualIDs(b), "%q doesn't have the same ids as %q", a.Name, b.Name) + + return 0 + }) +} + +// assertPersistentClients is a helper function that uses HTTP API to check +// whether want persistent clients are the same as the persistent clients stored +// in the clients container. +func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*client.Persistent) { tb.Helper() rw := httptest.NewRecorder() @@ -59,25 +102,40 @@ func clientsCompare(tb testing.TB, clients *clientsContainer, want []*client.Per err = json.Unmarshal(body, clientList) require.NoError(tb, err) - got := map[string]*client.Persistent{} + var got []*client.Persistent for _, cj := range clientList.Clients { var c *client.Persistent c, err = clients.jsonToClient(*cj, nil) require.NoError(tb, err) - got[c.Name] = c + got = append(got, c) } - require.Len(tb, want, len(got)) - for _, c := range want { - var gotClient *client.Persistent - gotClient, ok = got[c.Name] - if !ok || !gotClient.EqualIDs(c) { - return false + assertClients(tb, want, got) +} + +// assertPersistentClientsData is a helper function that checks whether want +// persistent clients are the same as the persistent clients stored in data. +func assertPersistentClientsData( + tb testing.TB, + clients *clientsContainer, + data []map[string]*clientJSON, + want []*client.Persistent, +) { + tb.Helper() + + var got []*client.Persistent + for _, cm := range data { + for _, cj := range cm { + var c *client.Persistent + c, err := clients.jsonToClient(*cj, nil) + require.NoError(tb, err) + + got = append(got, c) } } - return true + assertClients(tb, want, got) } func TestClientsContainer_HandleAddClient(t *testing.T) { @@ -131,8 +189,7 @@ func TestClientsContainer_HandleAddClient(t *testing.T) { require.NoError(t, err) require.Equal(t, tc.wantCode, rw.Code) - ok := clientsCompare(t, clients, tc.wantClient) - require.True(t, ok) + assertPersistentClients(t, clients, tc.wantClient) }) } } @@ -148,8 +205,7 @@ func TestClientsContainer_HandleDelClient(t *testing.T) { err = clients.add(clientTwo) require.NoError(t, err) - ok := clientsCompare(t, clients, []*client.Persistent{clientOne, clientTwo}) - require.True(t, ok) + assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) testCases := []struct { name string @@ -195,8 +251,7 @@ func TestClientsContainer_HandleDelClient(t *testing.T) { require.NoError(t, err) require.Equal(t, tc.wantCode, rw.Code) - ok = clientsCompare(t, clients, tc.wantClient) - require.True(t, ok) + assertPersistentClients(t, clients, tc.wantClient) }) } } @@ -208,8 +263,7 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) { err := clients.add(clientOne) require.NoError(t, err) - ok := clientsCompare(t, clients, []*client.Persistent{clientOne}) - require.True(t, ok) + assertPersistentClients(t, clients, []*client.Persistent{clientOne}) clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) @@ -274,8 +328,73 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) { require.NoError(t, err) require.Equal(t, tc.wantCode, rw.Code) - ok = clientsCompare(t, clients, tc.wantClient) - require.True(t, ok) + assertPersistentClients(t, clients, tc.wantClient) + }) + } +} + +func TestClientsContainer_HandleFindClient(t *testing.T) { + clients := newClientsContainer(t) + clients.dnsServer = &testBlockedClientChecker{ + onIsBlockedClient: func(ip netip.Addr, clientID string) (ok bool, rule string) { + return false, "" + }, + } + + clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) + err := clients.add(clientOne) + require.NoError(t, err) + + clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) + err = clients.add(clientTwo) + require.NoError(t, err) + + assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) + + testCases := []struct { + name string + query url.Values + wantCode int + wantClient []*client.Persistent + }{{ + name: "single", + query: url.Values{ + "ip0": []string{testClientIP1}, + }, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientOne}, + }, { + name: "multiple", + query: url.Values{ + "ip0": []string{testClientIP1}, + "ip1": []string{testClientIP2}, + }, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientOne, clientTwo}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rw := httptest.NewRecorder() + var r *http.Request + r, err = http.NewRequest(http.MethodGet, "", nil) + require.NoError(t, err) + + r.URL.RawQuery = tc.query.Encode() + + clients.handleFindClient(rw, r) + require.NoError(t, err) + require.Equal(t, tc.wantCode, rw.Code) + + var body []byte + body, err = io.ReadAll(rw.Body) + require.NoError(t, err) + + clientData := []map[string]*clientJSON{} + err = json.Unmarshal(body, &clientData) + require.NoError(t, err) + + assertPersistentClientsData(t, clients, clientData, tc.wantClient) }) } }