mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-10-10 16:58:34 +03:00
home: add tests
This commit is contained in:
parent
c6cdba7a8d
commit
bf5c23a72c
@ -57,7 +57,7 @@ type clientsContainer struct {
|
||||
dhcp DHCP
|
||||
|
||||
// dnsServer is used for checking clients IP status access list status
|
||||
dnsServer *dnsforward.Server
|
||||
dnsServer BlockedClientChecker
|
||||
|
||||
// etcHosts contains list of rewrite rules taken from the operating system's
|
||||
// hosts database.
|
||||
@ -86,6 +86,12 @@ type clientsContainer struct {
|
||||
testing bool
|
||||
}
|
||||
|
||||
// BlockedClientChecker checks if a client is blocked by the current access
|
||||
// settings.
|
||||
type BlockedClientChecker interface {
|
||||
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// Init initializes clients container
|
||||
// dhcpServer: optional
|
||||
// Note: this function must be called only once
|
||||
|
@ -2,15 +2,20 @@ package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -19,6 +24,24 @@ const (
|
||||
testClientIP2 = "2.2.2.2"
|
||||
)
|
||||
|
||||
// testBlockedClientChecker is a mock implementation of the
|
||||
// [BlockedClientChecker] interface.
|
||||
type testBlockedClientChecker struct {
|
||||
onIsBlockedClient func(ip netip.Addr, clientiD string) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ BlockedClientChecker = (*testBlockedClientChecker)(nil)
|
||||
|
||||
// IsBlockedClient implements the [BlockedClientChecker] interface for
|
||||
// *testBlockedClientChecker.
|
||||
func (c *testBlockedClientChecker) IsBlockedClient(
|
||||
ip netip.Addr,
|
||||
clientID string,
|
||||
) (blocked bool, rule string) {
|
||||
return c.onIsBlockedClient(ip, clientID)
|
||||
}
|
||||
|
||||
// newPersistentClient is a helper function that returns a persistent client
|
||||
// with the specified name and newly generated UID.
|
||||
func newPersistentClient(name string) (c *client.Persistent) {
|
||||
@ -43,10 +66,30 @@ func newPersistentClientWithIDs(tb testing.TB, name string, ids []string) (c *cl
|
||||
return c
|
||||
}
|
||||
|
||||
// clientsCompare is a helper function that uses HTTP API to check whether want
|
||||
// persistent clients are the same as the persistent clients stored in the
|
||||
// clients container.
|
||||
func clientsCompare(tb testing.TB, clients *clientsContainer, want []*client.Persistent) (ok bool) {
|
||||
// assertClients is a helper function that compares lists of persistent clients.
|
||||
func assertClients(tb testing.TB, want, got []*client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
require.Len(tb, want, len(got))
|
||||
|
||||
sortFunc := func(a, b *client.Persistent) (n int) {
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
}
|
||||
|
||||
slices.SortFunc(want, sortFunc)
|
||||
slices.SortFunc(got, sortFunc)
|
||||
|
||||
slices.CompareFunc(want, got, func(a, b *client.Persistent) (n int) {
|
||||
assert.True(tb, a.EqualIDs(b), "%q doesn't have the same ids as %q", a.Name, b.Name)
|
||||
|
||||
return 0
|
||||
})
|
||||
}
|
||||
|
||||
// assertPersistentClients is a helper function that uses HTTP API to check
|
||||
// whether want persistent clients are the same as the persistent clients stored
|
||||
// in the clients container.
|
||||
func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@ -59,25 +102,40 @@ func clientsCompare(tb testing.TB, clients *clientsContainer, want []*client.Per
|
||||
err = json.Unmarshal(body, clientList)
|
||||
require.NoError(tb, err)
|
||||
|
||||
got := map[string]*client.Persistent{}
|
||||
var got []*client.Persistent
|
||||
for _, cj := range clientList.Clients {
|
||||
var c *client.Persistent
|
||||
c, err = clients.jsonToClient(*cj, nil)
|
||||
require.NoError(tb, err)
|
||||
|
||||
got[c.Name] = c
|
||||
got = append(got, c)
|
||||
}
|
||||
require.Len(tb, want, len(got))
|
||||
|
||||
for _, c := range want {
|
||||
var gotClient *client.Persistent
|
||||
gotClient, ok = got[c.Name]
|
||||
if !ok || !gotClient.EqualIDs(c) {
|
||||
return false
|
||||
assertClients(tb, want, got)
|
||||
}
|
||||
|
||||
// assertPersistentClientsData is a helper function that checks whether want
|
||||
// persistent clients are the same as the persistent clients stored in data.
|
||||
func assertPersistentClientsData(
|
||||
tb testing.TB,
|
||||
clients *clientsContainer,
|
||||
data []map[string]*clientJSON,
|
||||
want []*client.Persistent,
|
||||
) {
|
||||
tb.Helper()
|
||||
|
||||
var got []*client.Persistent
|
||||
for _, cm := range data {
|
||||
for _, cj := range cm {
|
||||
var c *client.Persistent
|
||||
c, err := clients.jsonToClient(*cj, nil)
|
||||
require.NoError(tb, err)
|
||||
|
||||
got = append(got, c)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
assertClients(tb, want, got)
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleAddClient(t *testing.T) {
|
||||
@ -131,8 +189,7 @@ func TestClientsContainer_HandleAddClient(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
ok := clientsCompare(t, clients, tc.wantClient)
|
||||
require.True(t, ok)
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -148,8 +205,7 @@ func TestClientsContainer_HandleDelClient(t *testing.T) {
|
||||
err = clients.add(clientTwo)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok := clientsCompare(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
require.True(t, ok)
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -195,8 +251,7 @@ func TestClientsContainer_HandleDelClient(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
ok = clientsCompare(t, clients, tc.wantClient)
|
||||
require.True(t, ok)
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -208,8 +263,7 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) {
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok := clientsCompare(t, clients, []*client.Persistent{clientOne})
|
||||
require.True(t, ok)
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne})
|
||||
|
||||
clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
|
||||
@ -274,8 +328,73 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
ok = clientsCompare(t, clients, tc.wantClient)
|
||||
require.True(t, ok)
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleFindClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
clients.dnsServer = &testBlockedClientChecker{
|
||||
onIsBlockedClient: func(ip netip.Addr, clientID string) (ok bool, rule string) {
|
||||
return false, ""
|
||||
},
|
||||
}
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
err = clients.add(clientTwo)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
query url.Values
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "single",
|
||||
query: url.Values{
|
||||
"ip0": []string{testClientIP1},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne},
|
||||
}, {
|
||||
name: "multiple",
|
||||
query: url.Values{
|
||||
"ip0": []string{testClientIP1},
|
||||
"ip1": []string{testClientIP2},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rw := httptest.NewRecorder()
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
r.URL.RawQuery = tc.query.Encode()
|
||||
|
||||
clients.handleFindClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
var body []byte
|
||||
body, err = io.ReadAll(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientData := []map[string]*clientJSON{}
|
||||
err = json.Unmarshal(body, &clientData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClientsData(t, clients, clientData, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user