From bde3baa5da85dd5404f78bd79a6a3e85c55cf7fc Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Mon, 20 May 2024 14:39:35 +0300 Subject: [PATCH] all: persistent client storage --- internal/client/index.go | 2 +- internal/client/storage.go | 106 +++++++++++++++++++++++++++++++++++++ internal/home/clients.go | 11 +--- 3 files changed, 108 insertions(+), 11 deletions(-) create mode 100644 internal/client/storage.go diff --git a/internal/client/index.go b/internal/client/index.go index 63ae690e..8cdbad13 100644 --- a/internal/client/index.go +++ b/internal/client/index.go @@ -64,7 +64,7 @@ func NewIndex() (ci *Index) { } // Add stores information about a persistent client in the index. c must be -// non-nil and contain UID. +// non-nil, have a UID, and contain at least one identifier. func (ci *Index) Add(c *Persistent) { if (c.UID == UID{}) { panic("client must contain uid") diff --git a/internal/client/storage.go b/internal/client/storage.go new file mode 100644 index 00000000..d6de62c3 --- /dev/null +++ b/internal/client/storage.go @@ -0,0 +1,106 @@ +package client + +import ( + "fmt" + "slices" + + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/container" + "github.com/AdguardTeam/golibs/errors" +) + +// Storage contains information about persistent and runtime clients. +type Storage struct { + // allTags is a set of all client tags. + allTags *container.MapSet[string] + + // index contains information about persistent clients. + index *Index +} + +// NewStorage returns initialized client storage. +func NewStorage(clientTags []string) (s *Storage) { + allTags := container.NewMapSet(clientTags...) + + return &Storage{ + allTags: allTags, + index: NewIndex(), + } +} + +// Add stores persistent client information or returns an error. +func (s *Storage) Add(p *Persistent) (err error) { + err = s.check(p) + if err != nil { + return fmt.Errorf("adding client: %w", err) + } + + s.index.Add(p) + + return nil +} + +// check returns an error if persistent client information contains errors. +func (s *Storage) check(p *Persistent) (err error) { + switch { + case p == nil: + return errors.Error("client is nil") + case p.Name == "": + return errors.Error("empty name") + case p.IDsLen() == 0: + return errors.Error("id required") + } + + _, err = proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{}) + if err != nil { + return fmt.Errorf("invalid upstream servers: %w", err) + } + + for _, t := range p.Tags { + if !s.allTags.Has(t) { + return fmt.Errorf("invalid tag: %q", t) + } + } + + // TODO(s.chzhen): Move to the constructor. + slices.Sort(p.Tags) + + return nil +} + +// RemoveByName removes persistent client information. ok is false if no such +// client exists by that name. +func (s *Storage) RemoveByName(name string) (ok bool) { + p, ok := s.index.FindByName(name) + if !ok { + return false + } + + s.index.Delete(p) + + return true +} + +// Update updates stored persistent client information p with new information n +// or returns an error. p and n must have the same UID. +func (s *Storage) Update(p, n *Persistent) (err error) { + defer func() { err = errors.Annotate(err, "updating client: %w") }() + + err = s.check(n) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + + err = s.index.Clashes(n) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + + s.index.Delete(p) + s.index.Add(n) + + return nil +} diff --git a/internal/home/clients.go b/internal/home/clients.go index 4f3870ec..72c14178 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -5,7 +5,6 @@ import ( "net" "net/netip" "slices" - "strings" "sync" "time" @@ -310,7 +309,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { defer clients.lock.Unlock() objs = make([]*clientObject, 0, clients.clientIndex.Size()) - clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) { + clients.clientIndex.RangeByName(func(cli *client.Persistent) (cont bool) { objs = append(objs, &clientObject{ Name: cli.Name, @@ -337,14 +336,6 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { return true }) - // Maps aren't guaranteed to iterate in the same order each time, so the - // above loop can generate different orderings when writing to the config - // file: this produces lots of diffs in config files, so sort objects by - // name before writing. - slices.SortStableFunc(objs, func(a, b *clientObject) (res int) { - return strings.Compare(a.Name, b.Name) - }) - return objs }