From e268abf9268aef7a5386b5e126b01b249c590f49 Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Thu, 13 Jun 2024 15:19:36 +0300 Subject: [PATCH] client: add tests --- internal/client/storage.go | 21 +++++++- internal/client/storage_test.go | 85 +++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) create mode 100644 internal/client/storage_test.go diff --git a/internal/client/storage.go b/internal/client/storage.go index 46597cec..a336125c 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -4,6 +4,7 @@ import ( "fmt" "net/netip" "slices" + "sync" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" @@ -17,6 +18,9 @@ type Storage struct { // allTags is a set of all client tags. allTags *container.MapSet[string] + // mu protects index of persistent clients. + mu *sync.Mutex + // index contains information about persistent clients. index *Index @@ -30,13 +34,18 @@ func NewStorage(clientTags []string) (s *Storage) { return &Storage{ allTags: allTags, + mu: &sync.Mutex{}, index: NewIndex(), runtimeIndex: map[netip.Addr]*Runtime{}, } } -// Add stores persistent client information or returns an error. +// Add stores persistent client information or returns an error. p must be +// valid persistent client. func (s *Storage) Add(p *Persistent) (err error) { + s.mu.Lock() + defer s.mu.Unlock() + err = s.check(p) if err != nil { return fmt.Errorf("adding client: %w", err) @@ -48,6 +57,8 @@ func (s *Storage) Add(p *Persistent) (err error) { } // check returns an error if persistent client information contains errors. +// +// TODO(s.chzhen): Remove persistent client information validation. func (s *Storage) check(p *Persistent) (err error) { switch { case p == nil: @@ -56,6 +67,14 @@ func (s *Storage) check(p *Persistent) (err error) { return errors.Error("empty name") case p.IDsLen() == 0: return errors.Error("id required") + case p.UID == UID{}: + return errors.Error("uid required") + } + + err = s.index.ClashesUID(p) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err } conf, err := proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{}) diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go new file mode 100644 index 00000000..66e23c4e --- /dev/null +++ b/internal/client/storage_test.go @@ -0,0 +1,85 @@ +package client_test + +import ( + "net/netip" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/require" +) + +func TestStorage_Add(t *testing.T) { + testCases := []struct { + name string + cli *client.Persistent + wantErrMsg string + }{{ + name: "basic", + cli: &client.Persistent{ + Name: "basic", + IPs: []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + }, + UID: client.MustNewUID(), + }, + wantErrMsg: "", + }, { + name: "nil", + cli: nil, + wantErrMsg: "adding client: client is nil", + }, { + name: "empty_name", + cli: &client.Persistent{ + Name: "", + }, + wantErrMsg: "adding client: empty name", + }, { + name: "no_id", + cli: &client.Persistent{ + Name: "no_id", + }, + wantErrMsg: "adding client: id required", + }, { + name: "no_uid", + cli: &client.Persistent{ + Name: "no_uid", + IPs: []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + }, + }, + wantErrMsg: "adding client: uid required", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := client.NewStorage(nil) + err := s.Add(tc.cli) + + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } + + t.Run("duplicate_uid", func(t *testing.T) { + sameUID := client.MustNewUID() + s := client.NewStorage(nil) + + cli1 := &client.Persistent{ + Name: "cli1", + IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, + UID: sameUID, + } + + cli2 := &client.Persistent{ + Name: "cli2", + IPs: []netip.Addr{netip.MustParseAddr("4.3.2.1")}, + UID: sameUID, + } + + err := s.Add(cli1) + require.NoError(t, err) + + err = s.Add(cli2) + testutil.AssertErrorMsg(t, `adding client: another client "cli1" uses the same uid`, err) + }) +}