Pull request: all: allow clientid in access settings

Updates #2624.
Updates #3162.

Squashed commit of the following:

commit 68860da717a23a0bfeba14b7fe10b5e4ad38726d
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 15:41:33 2021 +0300

    all: imp types, names

commit ebd4ec26636853d0d58c4e331e6a78feede20813
Merge: 239eb721 16e5e09c
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 15:14:33 2021 +0300

    Merge branch 'master' into 2624-clientid-access

commit 239eb7215abc47e99a0300a0f4cf56002689b1a9
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 15:13:10 2021 +0300

    all: fix client blocking check

commit e6bece3ea8367b3cbe3d90702a3368c870ad4f13
Merge: 9935f2a3 9d1656b5
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 13:12:28 2021 +0300

    Merge branch 'master' into 2624-clientid-access

commit 9935f2a30bcfae2b853f3ef610c0ab7a56a8f448
Author: Ildar Kamalov <ik@adguard.com>
Date:   Tue Jun 29 11:26:51 2021 +0300

    client: show block button for client id

commit ed786a6a74a081cd89e9d67df3537a4fadd54831
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Jun 25 15:56:23 2021 +0300

    client: imp i18n

commit 4fed21c68473ad408960c08a7d87624cabce1911
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Jun 25 15:34:09 2021 +0300

    all: imp i18n, docs

commit 55e65c0d6b939560c53dcb834a4557eb3853d194
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Jun 25 13:34:01 2021 +0300

    all: fix cache, imp code, docs, tests

commit c1e5a83e76deb44b1f92729bb9ddfcc6a96ac4a8
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Jun 24 19:27:12 2021 +0300

    all: allow clientid in access settings
This commit is contained in:
Ainar Garipov 2021-06-29 15:53:28 +03:00
parent 16e5e09c2e
commit e08a64ebe4
33 changed files with 955 additions and 604 deletions

View File

@ -15,6 +15,7 @@ and this project adheres to
### Added
- Blocking access using client IDs ([#2624], [#3162]).
- `source` directives support in `/etc/network/interfaces` on Linux ([#3257]).
- RFC 9000 support in DNS-over-QUIC.
- Completely disabling statistics by setting the statistics interval to zero
@ -80,9 +81,11 @@ released by then.
[#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439
[#2441]: https://github.com/AdguardTeam/AdGuardHome/issues/2441
[#2443]: https://github.com/AdguardTeam/AdGuardHome/issues/2443
[#2624]: https://github.com/AdguardTeam/AdGuardHome/issues/2624
[#2763]: https://github.com/AdguardTeam/AdGuardHome/issues/2763
[#3013]: https://github.com/AdguardTeam/AdGuardHome/issues/3013
[#3136]: https://github.com/AdguardTeam/AdGuardHome/issues/3136
[#3162]: https://github.com/AdguardTeam/AdGuardHome/issues/3162
[#3166]: https://github.com/AdguardTeam/AdGuardHome/issues/3166
[#3172]: https://github.com/AdguardTeam/AdGuardHome/issues/3172
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184

View File

@ -159,8 +159,10 @@ attributes to make it work in Markdown renderers that strip "id". -->
* Minimize scope of variables as much as possible.
* No shadowing, since it can often lead to subtle bugs, especially with
errors.
* No name shadowing, including of predeclared identifiers, since it can often
lead to subtle bugs, especially with errors. This rule does not apply to
struct fields, since they are always used together with the name of the
struct value, so there isn't any confusion.
* Prefer constants to variables where possible. Avoid global variables. Use
[constant errors] instead of `errors.New`.

View File

@ -426,9 +426,9 @@
"access_title": "Access settings",
"access_desc": "Here you can configure access rules for the AdGuard Home DNS server.",
"access_allowed_title": "Allowed clients",
"access_allowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will accept requests from these IP addresses only.",
"access_allowed_desc": "A list of CIDRs, IP addresses, or client IDs. If configured, AdGuard Home will accept requests only from these clients.",
"access_disallowed_title": "Disallowed clients",
"access_disallowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will drop requests from these IP addresses.",
"access_disallowed_desc": "A list of CIDRs, IP addresses, or client IDs. If configured, AdGuard Home will drop requests from these clients. If allowed clients are configured, this field is ignored.",
"access_blocked_title": "Disallowed domains",
"access_blocked_desc": "Not to be confused with filters. AdGuard Home drops DNS queries matching these domains, and these queries don't even appear in the query log. You can specify exact domain names, wildcards, or URL filter rules, e.g. \"example.org\", \"*.example.org\", or \"||example.org^\" correspondingly.",
"access_settings_saved": "Access settings successfully saved",

View File

@ -9,7 +9,7 @@ import Card from '../ui/Card';
import Cell from '../ui/Cell';
import { getPercent, sortIp } from '../../helpers/helpers';
import { BLOCK_ACTIONS, R_CLIENT_ID, STATUS_COLORS } from '../../helpers/constants';
import { BLOCK_ACTIONS, STATUS_COLORS } from '../../helpers/constants';
import { toggleClientBlock } from '../../actions/access';
import { renderFormattedClientCell } from '../../helpers/renderFormattedClientCell';
import { getStats } from '../../actions/stats';
@ -35,10 +35,6 @@ const CountCell = (row) => {
};
const renderBlockingButton = (ip, disallowed, disallowed_rule) => {
if (R_CLIENT_ID.test(ip)) {
return null;
}
const dispatch = useDispatch();
const { t } = useTranslation();
const processingSet = useSelector((state) => state.access.processingSet);

2
go.mod
View File

@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
go 1.16
require (
github.com/AdguardTeam/dnsproxy v0.37.7
github.com/AdguardTeam/dnsproxy v0.38.0
github.com/AdguardTeam/golibs v0.8.0
github.com/AdguardTeam/urlfilter v0.14.6
github.com/NYTimes/gziphandler v1.1.1

4
go.sum
View File

@ -9,8 +9,8 @@ dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf h1:gc042VRSIRSUzZ+Px6xQCRWNJZTaPkomisDfUZmoFNk=
github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
github.com/AdguardTeam/dnsproxy v0.37.7 h1:yp0vEVYobf/1l8iY7es9yMqguw8BUEeC74OGA4G2v2A=
github.com/AdguardTeam/dnsproxy v0.37.7/go.mod h1:xMfevPAwpK1ULoLO0CARg/OiUsPH92kfyliXhPTc62M=
github.com/AdguardTeam/dnsproxy v0.38.0 h1:7GyyNJOieIVOgdnhu47exqWjHPQro7wQhqzvQjaZt6M=
github.com/AdguardTeam/dnsproxy v0.38.0/go.mod h1:xMfevPAwpK1ULoLO0CARg/OiUsPH92kfyliXhPTc62M=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=

View File

@ -27,10 +27,9 @@ type EtcHostsContainer struct {
lock sync.RWMutex
// table is the host-to-IPs map.
table map[string][]net.IP
// tableReverse is the IP-to-hosts map.
//
// TODO(a.garipov): Make better use of newtypes. Perhaps a custom map.
tableReverse map[string][]string
// tableReverse is the IP-to-hosts map. The type of the values in the
// map is []string.
tableReverse *IPMap
hostsFn string // path to the main hosts-file
hostsDirs []string // paths to OS-specific directories with hosts-files
@ -80,7 +79,7 @@ func (ehc *EtcHostsContainer) Init(hostsFn string) {
var err error
ehc.watcher, err = fsnotify.NewWatcher()
if err != nil {
log.Error("etchostscontainer: %s", err)
log.Error("etchosts: %s", err)
}
}
@ -141,7 +140,7 @@ func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP {
copy(ipsCopy, ips)
}
log.Debug("etchostscontainer: answer: %s -> %v", host, ipsCopy)
log.Debug("etchosts: answer: %s -> %v", host, ipsCopy)
return ipsCopy
}
@ -151,38 +150,40 @@ func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts [
return nil
}
ipReal := UnreverseAddr(addr)
if ipReal == nil {
ip := UnreverseAddr(addr)
if ip == nil {
return nil
}
ipStr := ipReal.String()
ehc.lock.RLock()
defer ehc.lock.RUnlock()
hosts = ehc.tableReverse[ipStr]
if len(hosts) == 0 {
return nil // not found
v, ok := ehc.tableReverse.Get(ip)
if !ok {
return nil
}
log.Debug("etchostscontainer: reverse-lookup: %s -> %s", addr, hosts)
hosts, ok = v.([]string)
if !ok {
log.Error("etchosts: bad type %T in tableReverse for %s", v, ip)
return nil
} else if len(hosts) == 0 {
return nil
}
log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts)
return hosts
}
// List returns an IP-to-hostnames table. It is safe for concurrent use.
func (ehc *EtcHostsContainer) List() (ipToHosts map[string][]string) {
// List returns an IP-to-hostnames table. The type of the values in the map is
// []string. It is safe for concurrent use.
func (ehc *EtcHostsContainer) List() (ipToHosts *IPMap) {
ehc.lock.RLock()
defer ehc.lock.RUnlock()
ipToHosts = make(map[string][]string, len(ehc.tableReverse))
for k, v := range ehc.tableReverse {
ipToHosts[k] = v
}
return ipToHosts
return ehc.tableReverse.ShallowClone()
}
// update table
@ -205,29 +206,31 @@ func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string
ok = true
}
if ok {
log.Debug("etchostscontainer: added %s -> %s", ipAddr, host)
log.Debug("etchosts: added %s -> %s", ipAddr, host)
}
}
// updateTableRev updates the reverse address table.
func (ehc *EtcHostsContainer) updateTableRev(tableRev map[string][]string, newHost string, ipAddr net.IP) {
ipStr := ipAddr.String()
hosts, ok := tableRev[ipStr]
func (ehc *EtcHostsContainer) updateTableRev(tableRev *IPMap, newHost string, ip net.IP) {
v, ok := tableRev.Get(ip)
if !ok {
tableRev[ipStr] = []string{newHost}
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
tableRev.Set(ip, []string{newHost})
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
return
}
hosts, _ := v.([]string)
for _, host := range hosts {
if host == newHost {
return
}
}
tableRev[ipStr] = append(tableRev[ipStr], newHost)
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
hosts = append(hosts, newHost)
tableRev.Set(ip, hosts)
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
}
// parseHostsLine parses hosts from the fields.
@ -255,12 +258,12 @@ func parseHostsLine(fields []string) (hosts []string) {
// line for one IP are supported.
func (ehc *EtcHostsContainer) load(
table map[string][]net.IP,
tableRev map[string][]string,
tableRev *IPMap,
fn string,
) {
f, err := os.Open(fn)
if err != nil {
log.Error("etchostscontainer: %s", err)
log.Error("etchosts: %s", err)
return
}
@ -268,11 +271,11 @@ func (ehc *EtcHostsContainer) load(
defer func() {
derr := f.Close()
if derr != nil {
log.Error("etchostscontainer: closing file: %s", err)
log.Error("etchosts: closing file: %s", err)
}
}()
log.Debug("etchostscontainer: loading hosts from file %s", fn)
log.Debug("etchosts: loading hosts from file %s", fn)
s := bufio.NewScanner(f)
for s.Scan() {
@ -296,7 +299,7 @@ func (ehc *EtcHostsContainer) load(
err = s.Err()
if err != nil {
log.Error("etchostscontainer: %s", err)
log.Error("etchosts: %s", err)
}
}
@ -334,7 +337,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
}
if event.Op&fsnotify.Write == fsnotify.Write {
log.Debug("etchostscontainer: modified: %s", event.Name)
log.Debug("etchosts: modified: %s", event.Name)
ehc.updateHosts()
}
@ -342,7 +345,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
if !ok {
return
}
log.Error("etchostscontainer: %s", err)
log.Error("etchosts: %s", err)
}
}
}
@ -350,7 +353,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
// updateHosts - loads system hosts
func (ehc *EtcHostsContainer) updateHosts() {
table := make(map[string][]net.IP)
tableRev := make(map[string][]string)
tableRev := NewIPMap(0)
ehc.load(table, tableRev, ehc.hostsFn)
@ -358,7 +361,7 @@ func (ehc *EtcHostsContainer) updateHosts() {
des, err := os.ReadDir(dir)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
log.Error("etchostscontainer: Opening directory: %q: %s", dir, err)
log.Error("etchosts: Opening directory: %q: %s", dir, err)
}
continue

View File

@ -70,7 +70,7 @@ func TestEtcHostsContainerResolution(t *testing.T) {
})
t.Run("hosts_file", func(t *testing.T) {
names, ok := ehc.List()["127.0.0.1"]
names, ok := ehc.List().Get(net.IP{127, 0, 0, 1})
require.True(t, ok)
assert.Equal(t, []string{"host", "localhost"}, names)
})

112
internal/aghnet/ipmap.go Normal file
View File

@ -0,0 +1,112 @@
package aghnet
import (
"fmt"
"net"
)
// ipArr is a representation of an IP address as an array of bytes.
type ipArr [16]byte
// String implements the fmt.Stringer interface for ipArr.
func (a ipArr) String() (s string) {
return net.IP(a[:]).String()
}
// IPMap is a map of IP addresses.
type IPMap struct {
m map[ipArr]interface{}
}
// NewIPMap returns a new empty IP map using hint as a size hint for the
// underlying map.
func NewIPMap(hint int) (m *IPMap) {
return &IPMap{
m: make(map[ipArr]interface{}, hint),
}
}
// ipToArr converts a net.IP into an ipArr.
//
// TODO(a.garipov): Use the slice-to-array conversion in Go 1.17.
func ipToArr(ip net.IP) (a ipArr) {
copy(a[:], ip.To16())
return a
}
// Del deletes ip from the map. Calling Del on a nil *IPMap has no effect, just
// like delete on an empty map doesn't.
func (m *IPMap) Del(ip net.IP) {
if m != nil {
delete(m.m, ipToArr(ip))
}
}
// Get returns the value from the map. Calling Get on a nil *IPMap returns nil
// and false, just like indexing on an empty map does.
func (m *IPMap) Get(ip net.IP) (v interface{}, ok bool) {
if m != nil {
v, ok = m.m[ipToArr(ip)]
return v, ok
}
return nil, false
}
// Len returns the length of the map. A nil *IPMap has a length of zero, just
// like an empty map.
func (m *IPMap) Len() (n int) {
if m == nil {
return 0
}
return len(m.m)
}
// Range calls f for each key and value present in the map in an undefined
// order. If cont is false, range stops the iteration. Calling Range on a nil
// *IPMap has no effect, just like ranging over a nil map.
func (m *IPMap) Range(f func(ip net.IP, v interface{}) (cont bool)) {
if m == nil {
return
}
for k, v := range m.m {
if !f(net.IP(k[:]), v) {
break
}
}
}
// Set sets the value. Set panics if the m is a nil *IPMap, just like a nil map
// does.
func (m *IPMap) Set(ip net.IP, v interface{}) {
m.m[ipToArr(ip)] = v
}
// ShallowClone returns a shallow clone of the map.
func (m *IPMap) ShallowClone() (sclone *IPMap) {
if m == nil {
return nil
}
sclone = NewIPMap(m.Len())
m.Range(func(ip net.IP, v interface{}) (cont bool) {
sclone.Set(ip, v)
return true
})
return sclone
}
// String implements the fmt.Stringer interface for *IPMap.
func (m *IPMap) String() (s string) {
if m == nil {
return "<nil>"
}
return fmt.Sprint(m.m)
}

View File

@ -0,0 +1,142 @@
package aghnet
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIPMap_allocs(t *testing.T) {
ip4 := net.IP{1, 2, 3, 4}
m := NewIPMap(0)
m.Set(ip4, 42)
t.Run("get", func(t *testing.T) {
var v interface{}
var ok bool
allocs := testing.AllocsPerRun(100, func() {
v, ok = m.Get(ip4)
})
require.True(t, ok)
require.Equal(t, 42, v)
assert.Equal(t, float64(0), allocs)
})
t.Run("len", func(t *testing.T) {
var n int
allocs := testing.AllocsPerRun(100, func() {
n = m.Len()
})
require.Equal(t, 1, n)
assert.Equal(t, float64(0), allocs)
})
}
func TestIPMap(t *testing.T) {
ip4 := net.IP{1, 2, 3, 4}
ip6 := net.IP{
0x12, 0x34, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x56, 0x78,
}
val := 42
t.Run("nil", func(t *testing.T) {
var m *IPMap
assert.NotPanics(t, func() {
m.Del(ip4)
m.Del(ip6)
})
assert.NotPanics(t, func() {
v, ok := m.Get(ip4)
assert.Nil(t, v)
assert.False(t, ok)
v, ok = m.Get(ip6)
assert.Nil(t, v)
assert.False(t, ok)
})
assert.NotPanics(t, func() {
assert.Equal(t, 0, m.Len())
})
assert.NotPanics(t, func() {
n := 0
m.Range(func(_ net.IP, _ interface{}) (cont bool) {
n++
return true
})
assert.Equal(t, 0, n)
})
assert.Panics(t, func() {
m.Set(ip4, val)
})
assert.Panics(t, func() {
m.Set(ip6, val)
})
assert.NotPanics(t, func() {
sclone := m.ShallowClone()
assert.Nil(t, sclone)
})
})
testIPMap := func(t *testing.T, ip net.IP, s string) {
m := NewIPMap(0)
assert.Equal(t, 0, m.Len())
v, ok := m.Get(ip)
assert.Nil(t, v)
assert.False(t, ok)
m.Set(ip, val)
v, ok = m.Get(ip)
assert.Equal(t, val, v)
assert.True(t, ok)
n := 0
m.Range(func(ipKey net.IP, v interface{}) (cont bool) {
assert.Equal(t, ip.To16(), ipKey)
assert.Equal(t, val, v)
n++
return false
})
assert.Equal(t, 1, n)
sclone := m.ShallowClone()
assert.Equal(t, m, sclone)
assert.Equal(t, s, m.String())
m.Del(ip)
v, ok = m.Get(ip)
assert.Nil(t, v)
assert.False(t, ok)
assert.Equal(t, 0, m.Len())
}
t.Run("ipv4", func(t *testing.T) {
testIPMap(t, ip4, "map[1.2.3.4:42]")
})
t.Run("ipv6", func(t *testing.T) {
testIPMap(t, ip6, "map[1234::5678:42]")
})
}

View File

@ -6,138 +6,163 @@ import (
"net"
"net/http"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist"
)
// accessCtx controls IP and client blocking that takes place before all other
// processing. An accessCtx is safe for concurrent use.
type accessCtx struct {
lock sync.Mutex
allowedIPs *aghnet.IPMap
blockedIPs *aghnet.IPMap
// allowedClients are the IP addresses of clients in the allowlist.
allowedClients *aghstrings.Set
allowedClientIDs *aghstrings.Set
blockedClientIDs *aghstrings.Set
// disallowedClients are the IP addresses of clients in the blocklist.
disallowedClients *aghstrings.Set
blockedHostsEng *urlfilter.DNSEngine
allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients
disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked
blockedHostsEngine *urlfilter.DNSEngine // finds hosts that should be blocked
// TODO(a.garipov): Create a type for a set of IP networks.
// aghnet.IPNetSet?
allowedNets []*net.IPNet
blockedNets []*net.IPNet
}
func newAccessCtx(allowedClients, disallowedClients, blockedHosts []string) (a *accessCtx, err error) {
a = &accessCtx{
allowedClients: aghstrings.NewSet(),
disallowedClients: aghstrings.NewSet(),
}
// unit is a convenient alias for struct{}
type unit = struct{}
err = processIPCIDRArray(a.allowedClients, &a.allowedClientsIPNet, allowedClients)
if err != nil {
return nil, fmt.Errorf("processing allowed clients: %w", err)
}
// processAccessClients is a helper for processing a list of client strings,
// which may be an IP address, a CIDR, or a ClientID.
func processAccessClients(
clientStrs []string,
ips *aghnet.IPMap,
nets *[]*net.IPNet,
clientIDs *aghstrings.Set,
) (err error) {
for i, s := range clientStrs {
if ip := net.ParseIP(s); ip != nil {
ips.Set(ip, unit{})
} else if cidrIP, ipnet, cidrErr := net.ParseCIDR(s); cidrErr == nil {
ipnet.IP = cidrIP
*nets = append(*nets, ipnet)
} else {
idErr := ValidateClientID(s)
if idErr != nil {
return fmt.Errorf(
"value %q at index %d: bad ip, cidr, or clientid",
s,
i,
)
}
err = processIPCIDRArray(a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients)
if err != nil {
return nil, fmt.Errorf("processing disallowed clients: %w", err)
}
b := &strings.Builder{}
for _, s := range blockedHosts {
aghstrings.WriteToBuilder(b, strings.ToLower(s), "\n")
}
listArray := []filterlist.RuleList{}
list := &filterlist.StringRuleList{
ID: int(0),
RulesText: b.String(),
IgnoreCosmetic: true,
}
listArray = append(listArray, list)
rulesStorage, err := filterlist.NewRuleStorage(listArray)
if err != nil {
return nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
}
a.blockedHostsEngine = urlfilter.NewDNSEngine(rulesStorage)
return a, nil
}
// Split array of IP or CIDR into 2 containers for fast search
func processIPCIDRArray(dst *aghstrings.Set, dstIPNet *[]net.IPNet, src []string) error {
for _, s := range src {
ip := net.ParseIP(s)
if ip != nil {
dst.Add(s)
continue
clientIDs.Add(s)
}
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
return err
}
*dstIPNet = append(*dstIPNet, *ipnet)
}
return nil
}
// IsBlockedIP - return TRUE if this client should be blocked
// Returns the item from the "disallowedClients" list that lead to blocking IP.
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
// but the ip does not belong to it.
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
ipStr := ip.String()
// newAccessCtx creates a new accessCtx.
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
a = &accessCtx{
allowedIPs: aghnet.NewIPMap(0),
blockedIPs: aghnet.NewIPMap(0),
a.lock.Lock()
defer a.lock.Unlock()
if a.allowedClients.Len() != 0 || len(a.allowedClientsIPNet) != 0 {
if a.allowedClients.Has(ipStr) {
return false, ""
}
if len(a.allowedClientsIPNet) != 0 {
for _, ipnet := range a.allowedClientsIPNet {
if ipnet.Contains(ip) {
return false, ""
}
}
}
return true, ""
allowedClientIDs: aghstrings.NewSet(),
blockedClientIDs: aghstrings.NewSet(),
}
if a.disallowedClients.Has(ipStr) {
return true, ipStr
err = processAccessClients(allowed, a.allowedIPs, &a.allowedNets, a.allowedClientIDs)
if err != nil {
return nil, fmt.Errorf("adding allowed: %w", err)
}
if len(a.disallowedClientsIPNet) != 0 {
for _, ipnet := range a.disallowedClientsIPNet {
if ipnet.Contains(ip) {
return true, ipnet.String()
}
}
err = processAccessClients(blocked, a.blockedIPs, &a.blockedNets, a.blockedClientIDs)
if err != nil {
return nil, fmt.Errorf("adding blocked: %w", err)
}
return false, ""
b := &strings.Builder{}
for _, h := range blockedHosts {
aghstrings.WriteToBuilder(b, strings.ToLower(h), "\n")
}
lists := []filterlist.RuleList{
&filterlist.StringRuleList{
ID: int(0),
RulesText: b.String(),
IgnoreCosmetic: true,
},
}
rulesStrg, err := filterlist.NewRuleStorage(lists)
if err != nil {
return nil, fmt.Errorf("adding blocked hosts: %w", err)
}
a.blockedHostsEng = urlfilter.NewDNSEngine(rulesStrg)
return a, nil
}
// IsBlockedDomain - return TRUE if this domain should be blocked
func (a *accessCtx) IsBlockedDomain(host string) (ok bool) {
a.lock.Lock()
defer a.lock.Unlock()
// allowlistMode returns true if this *accessCtx is in the allowlist mode.
func (a *accessCtx) allowlistMode() (ok bool) {
return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
}
_, ok = a.blockedHostsEngine.Match(strings.ToLower(host))
// isBlockedClientID returns true if the ClientID should be blocked.
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
allowlistMode := a.allowlistMode()
if id == "" {
// In allowlist mode, consider requests without client IDs
// blocked by default.
return allowlistMode
}
if allowlistMode {
return !a.allowedClientIDs.Has(id)
}
return a.blockedClientIDs.Has(id)
}
// isBlockedHost returns true if host should be blocked.
func (a *accessCtx) isBlockedHost(host string) (ok bool) {
_, ok = a.blockedHostsEng.Match(strings.ToLower(host))
return ok
}
// isBlockedIP returns the status of the IP address blocking as well as the rule
// that blocked it.
func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
blocked = true
ips := a.blockedIPs
ipnets := a.blockedNets
if a.allowlistMode() {
// Enable allowlist mode and use the allowlist sets.
blocked = false
ips = a.allowedIPs
ipnets = a.allowedNets
}
if _, ok := ips.Get(ip); ok {
return blocked, ip.String()
}
for _, ipnet := range ipnets {
if ipnet.Contains(ip) {
return blocked, ipnet.String()
}
}
return !blocked, ""
}
type accessListJSON struct {
AllowedClients []string `json:"allowed_clients"`
DisallowedClients []string `json:"disallowed_clients"`
@ -161,62 +186,43 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(j)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err)
return
}
}
func checkIPCIDRArray(src []string) error {
for _, s := range src {
ip := net.ParseIP(s)
if ip != nil {
continue
}
_, _, err := net.ParseCIDR(s)
if err != nil {
return err
}
}
return nil
}
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
j := accessListJSON{}
err := json.NewDecoder(r.Body).Decode(&j)
list := accessListJSON{}
err := json.NewDecoder(r.Body).Decode(&list)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
httpError(r, w, http.StatusBadRequest, "decoding request: %s", err)
err = checkIPCIDRArray(j.AllowedClients)
if err == nil {
err = checkIPCIDRArray(j.DisallowedClients)
}
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
var a *accessCtx
a, err = newAccessCtx(j.AllowedClients, j.DisallowedClients, j.BlockedHosts)
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
if err != nil {
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
return
}
defer log.Debug("Access: updated lists: %d, %d, %d",
len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts))
defer log.Debug(
"access: updated lists: %d, %d, %d",
len(list.AllowedClients),
len(list.DisallowedClients),
len(list.BlockedHosts),
)
defer s.conf.ConfigModified()
s.serverLock.Lock()
defer s.serverLock.Unlock()
s.conf.AllowedClients = j.AllowedClients
s.conf.DisallowedClients = j.DisallowedClients
s.conf.BlockedHosts = j.BlockedHosts
s.conf.AllowedClients = list.AllowedClients
s.conf.DisallowedClients = list.DisallowedClients
s.conf.BlockedHosts = list.BlockedHosts
s.access = a
}

View File

@ -8,99 +8,23 @@ import (
"github.com/stretchr/testify/require"
)
func TestIsBlockedIP(t *testing.T) {
const (
ip int = iota
cidr
)
func TestIsBlockedClientID(t *testing.T) {
clientID := "client-1"
clients := []string{clientID}
rules := []string{
ip: "1.1.1.1",
cidr: "2.2.0.0/16",
}
a, err := newAccessCtx(clients, nil, nil)
require.NoError(t, err)
testCases := []struct {
name string
allowed bool
ip net.IP
wantDis bool
wantRule string
}{{
name: "allow_ip",
allowed: true,
ip: net.IPv4(1, 1, 1, 1),
wantDis: false,
wantRule: "",
}, {
name: "disallow_ip",
allowed: true,
ip: net.IPv4(1, 1, 1, 2),
wantDis: true,
wantRule: "",
}, {
name: "allow_cidr",
allowed: true,
ip: net.IPv4(2, 2, 1, 1),
wantDis: false,
wantRule: "",
}, {
name: "disallow_cidr",
allowed: true,
ip: net.IPv4(2, 3, 1, 1),
wantDis: true,
wantRule: "",
}, {
name: "allow_ip",
allowed: false,
ip: net.IPv4(1, 1, 1, 1),
wantDis: true,
wantRule: rules[ip],
}, {
name: "disallow_ip",
allowed: false,
ip: net.IPv4(1, 1, 1, 2),
wantDis: false,
wantRule: "",
}, {
name: "allow_cidr",
allowed: false,
ip: net.IPv4(2, 2, 1, 1),
wantDis: true,
wantRule: rules[cidr],
}, {
name: "disallow_cidr",
allowed: false,
ip: net.IPv4(2, 3, 1, 1),
wantDis: false,
wantRule: "",
}}
assert.False(t, a.isBlockedClientID(clientID))
for _, tc := range testCases {
prefix := "allowed_"
if !tc.allowed {
prefix = "disallowed_"
}
a, err = newAccessCtx(nil, clients, nil)
require.NoError(t, err)
t.Run(prefix+tc.name, func(t *testing.T) {
allowedRules := rules
var disallowedRules []string
if !tc.allowed {
allowedRules, disallowedRules = disallowedRules, allowedRules
}
aCtx, err := newAccessCtx(allowedRules, disallowedRules, nil)
require.NoError(t, err)
disallowed, rule := aCtx.IsBlockedIP(tc.ip)
assert.Equal(t, tc.wantDis, disallowed)
assert.Equal(t, tc.wantRule, rule)
})
}
assert.True(t, a.isBlockedClientID(clientID))
}
func TestIsBlockedDomain(t *testing.T) {
aCtx, err := newAccessCtx(nil, nil, []string{
func TestIsBlockedHost(t *testing.T) {
a, err := newAccessCtx(nil, nil, []string{
"host1",
"*.host.com",
"||host3.com^",
@ -108,50 +32,106 @@ func TestIsBlockedDomain(t *testing.T) {
require.NoError(t, err)
testCases := []struct {
name string
domain string
want bool
name string
host string
want bool
}{{
name: "plain_match",
domain: "host1",
want: true,
name: "plain_match",
host: "host1",
want: true,
}, {
name: "plain_mismatch",
domain: "host2",
want: false,
name: "plain_mismatch",
host: "host2",
want: false,
}, {
name: "wildcard_type-1_match_short",
domain: "asdf.host.com",
want: true,
name: "subdomain_match_short",
host: "asdf.host.com",
want: true,
}, {
name: "wildcard_type-1_match_long",
domain: "qwer.asdf.host.com",
want: true,
name: "subdomain_match_long",
host: "qwer.asdf.host.com",
want: true,
}, {
name: "wildcard_type-1_mismatch_no-lead",
domain: "host.com",
want: false,
name: "subdomain_mismatch_no_lead",
host: "host.com",
want: false,
}, {
name: "wildcard_type-1_mismatch_bad-asterisk",
domain: "asdf.zhost.com",
want: false,
name: "subdomain_mismatch_bad_asterisk",
host: "asdf.zhost.com",
want: false,
}, {
name: "wildcard_type-2_match_simple",
domain: "host3.com",
want: true,
name: "rule_match_simple",
host: "host3.com",
want: true,
}, {
name: "wildcard_type-2_match_complex",
domain: "asdf.host3.com",
want: true,
name: "rule_match_complex",
host: "asdf.host3.com",
want: true,
}, {
name: "wildcard_type-2_mismatch",
domain: ".host3.com",
want: false,
name: "rule_mismatch",
host: ".host3.com",
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain))
assert.Equal(t, tc.want, a.isBlockedHost(tc.host))
})
}
}
func TestIsBlockedIP(t *testing.T) {
clients := []string{
"1.2.3.4",
"5.6.7.8/24",
}
allowCtx, err := newAccessCtx(clients, nil, nil)
require.NoError(t, err)
blockCtx, err := newAccessCtx(nil, clients, nil)
require.NoError(t, err)
testCases := []struct {
name string
wantRule string
ip net.IP
wantBlocked bool
}{{
name: "match_ip",
wantRule: "1.2.3.4",
ip: net.IP{1, 2, 3, 4},
wantBlocked: true,
}, {
name: "match_cidr",
wantRule: "5.6.7.8/24",
ip: net.IP{5, 6, 7, 100},
wantBlocked: true,
}, {
name: "no_match_ip",
wantRule: "",
ip: net.IP{9, 2, 3, 4},
wantBlocked: false,
}, {
name: "no_match_cidr",
wantRule: "",
ip: net.IP{9, 6, 7, 100},
wantBlocked: false,
}}
t.Run("allow", func(t *testing.T) {
for _, tc := range testCases {
blocked, rule := allowCtx.isBlockedIP(tc.ip)
assert.Equal(t, !tc.wantBlocked, blocked)
assert.Equal(t, tc.wantRule, rule)
}
})
t.Run("block", func(t *testing.T) {
for _, tc := range testCases {
blocked, rule := blockCtx.isBlockedIP(tc.ip)
assert.Equal(t, tc.wantBlocked, blocked)
assert.Equal(t, tc.wantRule, rule)
}
})
}

View File

@ -2,6 +2,7 @@ package dnsforward
import (
"crypto/tls"
"encoding/binary"
"fmt"
"path"
"strings"
@ -50,15 +51,15 @@ func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (
return clientID, nil
}
// processClientIDHTTPS extracts the client's ID from the path of the
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
// client's DNS-over-HTTPS request.
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
pctx := ctx.proxyCtx
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
r := pctx.HTTPRequest
if r == nil {
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
return resultCodeError
return "", fmt.Errorf(
"proxy ctx http request of proto %s is nil",
pctx.Proto,
)
}
origPath := r.URL.Path
@ -68,34 +69,25 @@ func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
}
if len(parts) == 0 || parts[0] != "dns-query" {
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
return resultCodeError
return "", fmt.Errorf("client id check: invalid path %q", origPath)
}
clientID := ""
switch len(parts) {
case 1:
// Just /dns-query, no client ID.
return resultCodeSuccess
return "", nil
case 2:
clientID = parts[1]
default:
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
return resultCodeError
return "", fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
}
err := ValidateClientID(clientID)
err = ValidateClientID(clientID)
if err != nil {
ctx.err = fmt.Errorf("client id check: %w", err)
return resultCodeError
return "", fmt.Errorf("client id check: %w", err)
}
ctx.clientID = clientID
return resultCodeSuccess
return clientID, nil
}
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
@ -108,53 +100,73 @@ type quicSession interface {
ConnectionState() (cs quic.ConnectionState)
}
// processClientID extracts the client's ID from the server name of the client's
// DoT or DoQ request or the path of the client's DoH.
func processClientID(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
return processClientIDHTTPS(dctx)
return clientIDFromDNSContextHTTPS(pctx)
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return resultCodeSuccess
return "", nil
}
srvConf := dctx.srv.conf
hostSrvName := srvConf.TLSConfig.ServerName
hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return resultCodeSuccess
return "", nil
}
cliSrvName := ""
if proto == proxy.ProtoTLS {
switch proto {
case proxy.ProtoTLS:
conn := pctx.Conn
tc, ok := conn.(tlsConn)
if !ok {
dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
return resultCodeError
return "", fmt.Errorf(
"proxy ctx conn of proto %s is %T, want *tls.Conn",
proto,
conn,
)
}
cliSrvName = tc.ConnectionState().ServerName
} else if proto == proxy.ProtoQUIC {
case proxy.ProtoQUIC:
qs, ok := pctx.QUICSession.(quicSession)
if !ok {
dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
return resultCodeError
return "", fmt.Errorf(
"proxy ctx quic session of proto %s is %T, want quic.Session",
proto,
pctx.QUICSession,
)
}
cliSrvName = qs.ConnectionState().TLS.ServerName
}
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck)
clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
dctx.err = fmt.Errorf("client id check: %w", err)
return resultCodeError
return "", fmt.Errorf("client id check: %w", err)
}
dctx.clientID = clientID
return clientID, nil
}
// processClientID puts the clientID into the DNS context, if there is one.
func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
var key [8]byte
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
clientIDData := s.clientIDCache.Get(key[:])
if clientIDData == nil {
return resultCodeSuccess
}
dctx.clientID = string(clientIDData)
return resultCodeSuccess
}

View File

@ -45,15 +45,14 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
return cs
}
func TestProcessClientID(t *testing.T) {
func TestServer_clientIDFromDNSContext(t *testing.T) {
testCases := []struct {
name string
proto string
proto proxy.Proto
hostSrvName string
cliSrvName string
wantClientID string
wantErrMsg string
wantRes resultCode
strictSNI bool
}{{
name: "udp",
@ -62,7 +61,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: false,
}, {
name: "tls_no_client_id",
@ -71,7 +69,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "example.com",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true,
}, {
name: "tls_no_client_server_name",
@ -81,7 +78,6 @@ func TestProcessClientID(t *testing.T) {
wantClientID: "",
wantErrMsg: `client id check: client server name "" ` +
`doesn't match host server name "example.com"`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "tls_no_client_server_name_no_strict",
@ -90,7 +86,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: false,
}, {
name: "tls_client_id",
@ -99,7 +94,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true,
}, {
name: "tls_client_id_hostname_error",
@ -109,7 +103,6 @@ func TestProcessClientID(t *testing.T) {
wantClientID: "",
wantErrMsg: `client id check: client server name "cli.example.net" ` +
`doesn't match host server name "example.com"`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "tls_invalid_client_id",
@ -119,7 +112,6 @@ func TestProcessClientID(t *testing.T) {
wantClientID: "",
wantErrMsg: `client id check: invalid client id "!!!": ` +
`invalid char '!' at index 0`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "tls_client_id_too_long",
@ -131,7 +123,6 @@ func TestProcessClientID(t *testing.T) {
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
`label is too long, max: 63`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "quic_client_id",
@ -140,7 +131,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true,
}}
@ -150,6 +140,7 @@ func TestProcessClientID(t *testing.T) {
ServerName: tc.hostSrvName,
StrictSNICheck: tc.strictSNI,
}
srv := &Server{
conf: ServerConfig{TLSConfig: tlsConf},
}
@ -168,79 +159,68 @@ func TestProcessClientID(t *testing.T) {
}
}
dctx := &dnsContext{
srv: srv,
proxyCtx: &proxy.DNSContext{
Proto: tc.proto,
Conn: conn,
QUICSession: qs,
},
pctx := &proxy.DNSContext{
Proto: tc.proto,
Conn: conn,
QUICSession: qs,
}
res := processClientID(dctx)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantClientID, dctx.clientID)
clientID, err := srv.clientIDFromDNSContext(pctx)
assert.Equal(t, tc.wantClientID, clientID)
if tc.wantErrMsg == "" {
assert.NoError(t, dctx.err)
assert.NoError(t, err)
} else {
require.Error(t, dctx.err)
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
})
}
}
func TestProcessClientID_https(t *testing.T) {
func TestClientIDFromDNSContextHTTPS(t *testing.T) {
testCases := []struct {
name string
path string
wantClientID string
wantErrMsg string
wantRes resultCode
}{{
name: "no_client_id",
path: "/dns-query",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "no_client_id_slash",
path: "/dns-query/",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "client_id",
path: "/dns-query/cli",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "client_id_slash",
path: "/dns-query/cli/",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "bad_url",
path: "/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/foo"`,
wantRes: resultCodeError,
}, {
name: "extra",
path: "/dns-query/cli/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
wantRes: resultCodeError,
}, {
name: "invalid_client_id",
path: "/dns-query/!!!",
wantClientID: "",
wantErrMsg: `client id check: invalid client id "!!!": ` +
`invalid char '!' at index 0`,
wantRes: resultCodeError,
}}
for _, tc := range testCases {
@ -251,23 +231,20 @@ func TestProcessClientID_https(t *testing.T) {
},
}
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Proto: proxy.ProtoHTTPS,
HTTPRequest: r,
},
pctx := &proxy.DNSContext{
Proto: proxy.ProtoHTTPS,
HTTPRequest: r,
}
res := processClientID(dctx)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantClientID, dctx.clientID)
clientID, err := clientIDFromDNSContextHTTPS(pctx)
assert.Equal(t, tc.wantClientID, clientID)
if tc.wantErrMsg == "" {
assert.NoError(t, dctx.err)
assert.NoError(t, err)
} else {
require.Error(t, dctx.err)
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
assert.Equal(t, tc.wantErrMsg, err.Error())
}
})
}

View File

@ -331,7 +331,7 @@ func (s *Server) prepareUpstreamSettings() error {
upstreams = aghstrings.FilterOut(upstreams, aghstrings.IsCommentOrEmpty)
upstreamConfig, err := proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
&upstream.Options{
Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
},
@ -342,10 +342,10 @@ func (s *Server) prepareUpstreamSettings() error {
if len(upstreamConfig.Upstreams) == 0 {
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
var uc proxy.UpstreamConfig
var uc *proxy.UpstreamConfig
uc, err = proxy.ParseUpstreamsConfig(
defaultDNS,
upstream.Options{
&upstream.Options{
Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
},
@ -356,7 +356,8 @@ func (s *Server) prepareUpstreamSettings() error {
upstreamConfig.Upstreams = uc.Upstreams
}
s.conf.UpstreamConfig = &upstreamConfig
s.conf.UpstreamConfig = upstreamConfig
return nil
}

View File

@ -89,7 +89,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
s.processInternalHosts,
s.processRestrictLocal,
s.processInternalIPAddrs,
processClientID,
s.processClientID,
processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream,
@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
s.tableHostToIP = t
}
func (s *Server) setTableIPToHost(t ipToHostTable) {
func (s *Server) setTableIPToHost(t *aghnet.IPMap) {
s.tableIPToHostLock.Lock()
defer s.tableIPToHostLock.Unlock()
@ -188,13 +188,13 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
}
var hostToIP hostToIPTable
var ipToHost ipToHostTable
var ipToHost *aghnet.IPMap
if add {
hostToIP = make(hostToIPTable)
ipToHost = make(ipToHostTable)
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
hostToIP = make(hostToIPTable, len(ll))
ipToHost = aghnet.NewIPMap(len(ll))
for _, l := range ll {
// TODO(a.garipov): Remove this after we're finished
// with the client hostname validations in the DHCP
@ -210,14 +210,14 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
lowhost := strings.ToLower(l.Hostname)
ipToHost[l.IP.String()] = lowhost
ipToHost.Set(l.IP, lowhost)
ip := make(net.IP, 4)
copy(ip, l.IP.To4())
hostToIP[lowhost] = ip
}
log.Debug("dns: added %d A/PTR entries from DHCP", len(ipToHost))
log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len())
}
s.setTableHostToIP(hostToIP)
@ -377,7 +377,15 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
return "", false
}
host, ok = s.tableIPToHost[ip.String()]
var v interface{}
v, ok = s.tableIPToHost.Get(ip)
var typOK bool
if host, typOK = v.(string); !typOK {
log.Error("dns: bad type %T in tableIPToHost for %s", v, ip)
return "", false
}
return host, ok
}

View File

@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
@ -26,6 +27,11 @@ import (
// DefaultTimeout is the default upstream timeout
const DefaultTimeout = 10 * time.Second
// defaultClientIDCacheCount is the default count of items in the LRU client ID
// cache. The assumption here is that there won't be more than this many
// requests between the BeforeRequestHandler stage and the actual processing.
const defaultClientIDCacheCount = 1024
const (
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
parentalBlockHost = "family-block.dns.adguard.com"
@ -44,12 +50,6 @@ var webRegistered bool
// hostToIPTable is an alias for the type of Server.tableHostToIP.
type hostToIPTable = map[string]net.IP
// ipToHostTable is an alias for the type of Server.tableIPToHost.
//
// TODO(a.garipov): Define an IPMap type in aghnet and use here and in other
// places?
type ipToHostTable = map[string]string
// Server is the main way to start a DNS server.
//
// Example:
@ -81,9 +81,13 @@ type Server struct {
tableHostToIP hostToIPTable
tableHostToIPLock sync.Mutex
tableIPToHost ipToHostTable
tableIPToHost *aghnet.IPMap
tableIPToHostLock sync.Mutex
// clientIDCache is a temporary storage for clientIDs that were
// extracted during the BeforeRequestHandler stage.
clientIDCache cache.Cache
// DNS proxy instance for internal usage
// We don't Start() it and so no listen port is required.
internalProxy *proxy.Proxy
@ -152,6 +156,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
subnetDetector: p.SubnetDetector,
localDomainSuffix: localDomainSuffix,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{
EnableLRU: true,
MaxCount: defaultClientIDCacheCount,
}),
}
// TODO(e.burkov): Enable the refresher after the actual implementation
@ -414,19 +422,22 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
var upsConfig proxy.UpstreamConfig
upsConfig, err = proxy.ParseUpstreamsConfig(localAddrs, upstream.Options{
Bootstrap: bootstraps,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's ceritificates?
})
var upsConfig *proxy.UpstreamConfig
upsConfig, err = proxy.ParseUpstreamsConfig(
localAddrs,
&upstream.Options{
Bootstrap: bootstraps,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's ceritificates?
},
)
if err != nil {
return fmt.Errorf("parsing upstreams: %w", err)
}
s.localResolvers = &proxy.Proxy{
Config: proxy.Config{
UpstreamConfig: &upsConfig,
UpstreamConfig: upsConfig,
},
}
@ -577,11 +588,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
// IsBlockedIP - return TRUE if this client should be blocked
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
if ip == nil {
return false, ""
// IsBlockedClient returns true if the client is blocked by the current access
// settings.
func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
allowlistMode := s.access.allowlistMode()
blockedByIP, rule := s.access.isBlockedIP(ip)
blockedByClientID := s.access.isBlockedClientID(clientID)
// Allow if at least one of the checks allows in allowlist mode, but
// block if at least one of the checks blocks in blocklist mode.
if allowlistMode && blockedByIP && blockedByClientID {
log.Debug("client %s (id %q) is not in access allowlist", ip, clientID)
// Return now without substituting the empty rule for the
// clientID because the rule can't be empty here.
return true, rule
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
log.Debug("client %s (id %q) is in access blocklist", ip, clientID)
blocked = true
}
return s.access.IsBlockedIP(ip)
if rule == "" {
rule = clientID
}
return blocked, rule
}

View File

@ -257,19 +257,22 @@ func TestServer(t *testing.T) {
testCases := []struct {
name string
proto string
net string
proto proxy.Proto
}{{
name: "message_over_udp",
net: "",
proto: proxy.ProtoUDP,
}, {
name: "message_over_tcp",
net: "tcp",
proto: proxy.ProtoTCP,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
addr := s.dnsProxy.Addr(tc.proto)
client := dns.Client{Net: tc.proto}
client := dns.Client{Net: tc.net}
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
@ -324,7 +327,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
// Message over UDP.
req := createGoogleATestMessage()
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
client := dns.Client{Net: proxy.ProtoUDP}
client := &dns.Client{}
reply, _, err := client.Exchange(req, addr.String())
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
@ -376,7 +379,7 @@ func TestDoQServer(t *testing.T) {
// Create a DNS-over-QUIC upstream.
addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
opts := upstream.Options{InsecureSkipVerify: true}
opts := &upstream.Options{InsecureSkipVerify: true}
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
require.NoError(t, err)
@ -420,7 +423,7 @@ func TestServerRace(t *testing.T) {
// Message over UDP.
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
conn, err := dns.Dial("udp", addr.String())
require.NoErrorf(t, err, "cannot connect to the proxy: %s", err)
sendTestMessagesAsync(t, conn)
@ -445,7 +448,7 @@ func TestSafeSearch(t *testing.T) {
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
client := dns.Client{Net: proxy.ProtoUDP}
client := &dns.Client{}
yandexIP := net.IP{213, 180, 193, 56}
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
@ -507,7 +510,6 @@ func TestInvalidRequest(t *testing.T) {
// Send a DNS request without question.
_, _, err := (&dns.Client{
Net: proxy.ProtoUDP,
Timeout: 500 * time.Millisecond,
}).Exchange(&req, addr)

View File

@ -1,6 +1,7 @@
package dnsforward
import (
"encoding/binary"
"fmt"
"strings"
@ -11,23 +12,39 @@ import (
"github.com/miekg/dns"
)
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := aghnet.IPFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip)
// beforeRequestHandler is the handler that is called before any other
// processing, including logs. It performs access checks and puts the client
// ID, if there is one, into the server's cache.
func (s *Server) beforeRequestHandler(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (reply bool, err error) {
ip := aghnet.IPFromAddr(pctx.Addr)
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return false, fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(ip, clientID)
if blocked {
return false, nil
}
if len(d.Req.Question) == 1 {
host := strings.TrimSuffix(d.Req.Question[0].Name, ".")
if s.access.IsBlockedDomain(host) {
log.Tracef("domain %s is blocked by access settings", host)
if len(pctx.Req.Question) == 1 {
host := strings.TrimSuffix(pctx.Req.Question[0].Name, ".")
if s.access.isBlockedHost(host) {
log.Debug("host %s is in access blocklist", host)
return false, nil
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return true, nil
}

View File

@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (string, error) {
return boot, fmt.Errorf("invalid bootstrap server address: empty")
}
if _, err := upstream.NewResolver(boot, upstream.Options{Timeout: 0}); err != nil {
if _, err := upstream.NewResolver(boot, nil); err != nil {
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
}
}
@ -348,7 +348,7 @@ func ValidateUpstreams(upstreams []string) (err error) {
_, err = proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
&upstream.Options{
Bootstrap: []string{},
Timeout: DefaultTimeout,
},
@ -546,7 +546,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
log.Debug("checking if dns server %q works...", input)
var u upstream.Upstream
u, err = upstream.AddressToUpstream(input, upstream.Options{
u, err = upstream.AddressToUpstream(input, &upstream.Options{
Bootstrap: bootstrap,
Timeout: timeout,
})

View File

@ -46,7 +46,7 @@ func (l *testStats) Update(e stats.Entry) {
func TestProcessQueryLogsAndStats(t *testing.T) {
testCases := []struct {
name string
proto string
proto proxy.Proto
addr net.Addr
clientID string
wantLogProto querylog.ClientProto
@ -156,7 +156,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RParental,
}}
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
require.Nil(t, err)
for _, tc := range testCases {

View File

@ -49,7 +49,7 @@ func (d *DNSFilter) initSecurityServices() error {
var err error
d.safeBrowsingServer = defaultSafebrowsingServer
d.parentalServer = defaultParentalServer
opts := upstream.Options{
opts := &upstream.Options{
Timeout: dnsTimeout,
ServerIPAddrs: []net.IP{
{94, 140, 14, 15},

View File

@ -78,10 +78,13 @@ type RuntimeClientWHOISInfo struct {
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for
// different types (string, net.IP, and so on).
list map[string]*Client // name -> client
idIndex map[string]*Client // ID -> client
ipToRC map[string]*RuntimeClient // IP -> runtime client
lock sync.Mutex
list map[string]*Client // name -> client
idIndex map[string]*Client // ID -> client
// ipToRC is the IP address to *RuntimeClient map.
ipToRC *aghnet.IPMap
lock sync.Mutex
allTags *aghstrings.Set
@ -109,7 +112,7 @@ func (clients *clientsContainer) Init(
}
clients.list = make(map[string]*Client)
clients.idIndex = make(map[string]*Client)
clients.ipToRC = make(map[string]*RuntimeClient)
clients.ipToRC = aghnet.NewIPMap(0)
clients.allTags = aghstrings.NewSet(clientTags...)
@ -250,18 +253,17 @@ func (clients *clientsContainer) onHostsChanged() {
clients.addFromHostsFile()
}
// Exists checks if client with this ID already exists.
func (clients *clientsContainer) Exists(id string, source clientSource) (ok bool) {
// Exists checks if client with this IP address already exists.
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok = clients.findLocked(id)
_, ok = clients.findLocked(ip.String())
if ok {
return true
}
var rc *RuntimeClient
rc, ok = clients.ipToRC[id]
rc, ok := clients.findRuntimeClientLocked(ip)
if !ok {
return false
}
@ -288,13 +290,14 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
for _, id := range ids {
var name string
whois := &querylog.ClientWHOIS{}
ip := net.ParseIP(id)
c, ok := clients.Find(id)
if ok {
name = c.Name
} else {
var rc RuntimeClient
rc, ok = clients.FindRuntimeClient(id)
} else if ip != nil {
var rc *RuntimeClient
rc, ok = clients.FindRuntimeClient(ip)
if !ok {
continue
}
@ -303,8 +306,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
whois = toQueryLogWHOIS(rc.WHOISInfo)
}
ip := net.ParseIP(id)
disallowed, disallowedRule := clients.dnsServer.IsBlockedIP(ip)
disallowed, disallowedRule := clients.dnsServer.IsBlockedClient(ip, id)
return &querylog.Client{
Name: name,
@ -356,10 +358,10 @@ func (clients *clientsContainer) findUpstreams(
return c.upstreamConfig, nil
}
var conf proxy.UpstreamConfig
var conf *proxy.UpstreamConfig
conf, err = proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
&upstream.Options{
Bootstrap: config.DNS.BootstrapDNS,
Timeout: config.DNS.UpstreamTimeout.Duration,
},
@ -368,9 +370,9 @@ func (clients *clientsContainer) findUpstreams(
return nil, err
}
c.upstreamConfig = &conf
c.upstreamConfig = conf
return &conf, nil
return conf, nil
}
// findLocked searches for a client by its ID. For internal use only.
@ -423,22 +425,35 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return nil, false
}
// findRuntimeClientLocked finds a runtime client by their IP address. For
// internal use only.
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
var v interface{}
v, ok = clients.ipToRC.Get(ip)
if !ok {
return nil, false
}
rc, ok = v.(*RuntimeClient)
if !ok {
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
return nil, false
}
return rc, true
}
// FindRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) FindRuntimeClient(ip string) (RuntimeClient, bool) {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
return RuntimeClient{}, false
func (clients *clientsContainer) FindRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) {
if ip == nil {
return nil, false
}
clients.lock.Lock()
defer clients.lock.Unlock()
rc, ok := clients.ipToRC[ip]
if ok {
return *rc, true
}
return RuntimeClient{}, false
return clients.findRuntimeClientLocked(ip)
}
// check validates the client.
@ -621,17 +636,17 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
}
// SetWHOISInfo sets the WHOIS information for a client.
func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISInfo) {
func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findLocked(ip)
_, ok := clients.findLocked(ip.String())
if ok {
log.Debug("clients: client for %s is already created, ignore whois info", ip)
return
}
rc, ok := clients.ipToRC[ip]
rc, ok := clients.findRuntimeClientLocked(ip)
if ok {
rc.WHOISInfo = wi
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
@ -646,14 +661,15 @@ func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISI
}
rc.WHOISInfo = wi
clients.ipToRC[ip] = rc
clients.ipToRC.Set(ip, rc)
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
}
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
// taken into account. ok is true if the pairing was added.
func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) {
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -663,9 +679,9 @@ func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok
}
// addHostLocked adds a new IP-hostname pairing. For internal use only.
func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource) (ok bool) {
func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) {
var rc *RuntimeClient
rc, ok = clients.ipToRC[ip]
rc, ok = clients.findRuntimeClientLocked(ip)
if ok {
if rc.Source > src {
return false
@ -679,10 +695,10 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource
WHOISInfo: &RuntimeClientWHOISInfo{},
}
clients.ipToRC[ip] = rc
clients.ipToRC.Set(ip, rc)
}
log.Debug("clients: added %q -> %q [%d]", ip, host, len(clients.ipToRC))
log.Debug("clients: added %s -> %q [%d]", ip, host, clients.ipToRC.Len())
return true
}
@ -690,12 +706,21 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource
// rmHostsBySrc removes all entries that match the specified source.
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
n := 0
for k, v := range clients.ipToRC {
if v.Source == src {
delete(clients.ipToRC, k)
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
rc, ok := v.(*RuntimeClient)
if !ok {
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
return true
}
if rc.Source == src {
clients.ipToRC.Del(ip)
n++
}
}
return true
})
log.Debug("clients: removed %d client aliases", n)
}
@ -715,16 +740,23 @@ func (clients *clientsContainer) addFromHostsFile() {
clients.rmHostsBySrc(ClientSourceHostsFile)
n := 0
for ip, names := range hosts {
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
names, ok := v.([]string)
if !ok {
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
}
for _, name := range names {
ok := clients.addHostLocked(ip, name, ClientSourceHostsFile)
ok = clients.addHostLocked(ip, name, ClientSourceHostsFile)
if ok {
n++
}
}
}
log.Debug("Clients: added %d client aliases from system hosts-file", n)
return true
})
log.Debug("clients: added %d client aliases from system hosts-file", n)
}
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
@ -752,15 +784,16 @@ func (clients *clientsContainer) addFromSystemARP() {
// TODO(a.garipov): Rewrite to use bufio.Scanner.
lines := strings.Split(string(data), "\n")
for _, ln := range lines {
open := strings.Index(ln, " (")
close := strings.Index(ln, ") ")
if open == -1 || close == -1 || open >= close {
lparen := strings.Index(ln, " (")
rparen := strings.Index(ln, ") ")
if lparen == -1 || rparen == -1 || lparen >= rparen {
continue
}
host := ln[:open]
ip := ln[open+2 : close]
if aghnet.ValidateDomainName(host) != nil || net.ParseIP(ip) == nil {
host := ln[:lparen]
ipStr := ln[lparen+2 : rparen]
ip := net.ParseIP(ipStr)
if aghnet.ValidateDomainName(host) != nil || ip == nil {
continue
}
@ -796,7 +829,7 @@ func (clients *clientsContainer) updateFromDHCP(add bool) {
continue
}
ok := clients.addHostLocked(l.IP.String(), l.Hostname, ClientSourceDHCP)
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
if ok {
n++
}

View File

@ -26,6 +26,7 @@ func TestClients(t *testing.T) {
ok, err := clients.Add(c)
require.NoError(t, err)
assert.True(t, ok)
c = &Client{
@ -35,23 +36,27 @@ func TestClients(t *testing.T) {
ok, err = clients.Add(c)
require.NoError(t, err)
assert.True(t, ok)
c, ok = clients.Find("1.1.1.1")
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("1:2:3::4")
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("2.2.2.2")
require.True(t, ok)
assert.Equal(t, "client2", c.Name)
assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
assert.False(t, clients.Exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile))
})
t.Run("add_fail_name", func(t *testing.T) {
@ -101,8 +106,8 @@ func TestClients(t *testing.T) {
})
require.NoError(t, err)
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
assert.False(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
err = clients.Update("client1", &Client{
IDs: []string{"1.1.1.2"},
@ -113,21 +118,25 @@ func TestClients(t *testing.T) {
c, ok := clients.Find("1.1.1.2")
require.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.UseOwnSettings)
nilCli, ok := clients.list["client1"]
require.False(t, ok)
assert.Nil(t, nilCli)
require.Len(t, c.IDs, 1)
assert.Equal(t, "1.1.1.2", c.IDs[0])
})
t.Run("del_success", func(t *testing.T) {
ok := clients.Del("client1-renamed")
require.True(t, ok)
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
assert.False(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
})
t.Run("del_fail", func(t *testing.T) {
@ -136,37 +145,44 @@ func TestClients(t *testing.T) {
})
t.Run("addhost_success", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
ip := net.IP{1, 1, 1, 1}
ok, err := clients.AddHost(ip, "host", ClientSourceARP)
require.NoError(t, err)
assert.True(t, ok)
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
ok, err = clients.AddHost(ip, "host2", ClientSourceARP)
require.NoError(t, err)
assert.True(t, ok)
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile)
require.NoError(t, err)
assert.True(t, ok)
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists(ip, ClientSourceHostsFile))
})
t.Run("dhcp_replaces_arp", func(t *testing.T) {
ok, err := clients.AddHost("1.2.3.4", "from_arp", ClientSourceARP)
ip := net.IP{1, 2, 3, 4}
ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP)
require.NoError(t, err)
assert.True(t, ok)
assert.True(t, clients.Exists(ip, ClientSourceARP))
assert.True(t, clients.Exists("1.2.3.4", ClientSourceARP))
ok, err = clients.AddHost("1.2.3.4", "from_dhcp", ClientSourceDHCP)
ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
require.NoError(t, err)
assert.True(t, ok)
assert.True(t, clients.Exists("1.2.3.4", ClientSourceDHCP))
assert.True(t, ok)
assert.True(t, clients.Exists(ip, ClientSourceDHCP))
})
t.Run("addhost_fail", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS)
require.NoError(t, err)
assert.False(t, ok)
})
@ -183,31 +199,39 @@ func TestClientsWHOIS(t *testing.T) {
}
t.Run("new_client", func(t *testing.T) {
clients.SetWHOISInfo("1.1.1.255", whois)
ip := net.IP{1, 1, 1, 255}
clients.SetWHOISInfo(ip, whois)
v, _ := clients.ipToRC.Get(ip)
require.NotNil(t, v)
require.NotNil(t, clients.ipToRC["1.1.1.255"])
rc, ok := v.(*RuntimeClient)
require.True(t, ok)
require.NotNil(t, rc)
h := clients.ipToRC["1.1.1.255"]
require.NotNil(t, h)
assert.Equal(t, h.WHOISInfo, whois)
assert.Equal(t, rc.WHOISInfo, whois)
})
t.Run("existing_auto-client", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
ip := net.IP{1, 1, 1, 1}
ok, err := clients.AddHost(ip, "host", ClientSourceRDNS)
require.NoError(t, err)
assert.True(t, ok)
clients.SetWHOISInfo("1.1.1.1", whois)
clients.SetWHOISInfo(ip, whois)
v, _ := clients.ipToRC.Get(ip)
require.NotNil(t, v)
require.NotNil(t, clients.ipToRC["1.1.1.1"])
h := clients.ipToRC["1.1.1.1"]
require.NotNil(t, h)
rc, ok := v.(*RuntimeClient)
require.True(t, ok)
require.NotNil(t, rc)
assert.Equal(t, h.WHOISInfo, whois)
assert.Equal(t, rc.WHOISInfo, whois)
})
t.Run("can't_set_manually-added", func(t *testing.T) {
ip := net.IP{1, 1, 1, 2}
ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.2"},
Name: "client1",
@ -215,8 +239,10 @@ func TestClientsWHOIS(t *testing.T) {
require.NoError(t, err)
assert.True(t, ok)
clients.SetWHOISInfo("1.1.1.2", whois)
require.Nil(t, clients.ipToRC["1.1.1.2"])
clients.SetWHOISInfo(ip, whois)
v, _ := clients.ipToRC.Get(ip)
require.Nil(t, v)
assert.True(t, clients.Del("client1"))
})
}
@ -228,16 +254,18 @@ func TestClientsAddExisting(t *testing.T) {
clients.Init(nil, nil, nil)
t.Run("simple", func(t *testing.T) {
ip := net.IP{1, 1, 1, 1}
// Add a client.
ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
IDs: []string{ip.String(), "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1",
})
require.NoError(t, err)
assert.True(t, ok)
// Now add an auto-client with the same IP.
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
ok, err = clients.AddHost(ip, "test", ClientSourceRDNS)
require.NoError(t, err)
assert.True(t, ok)
})
@ -245,7 +273,7 @@ func TestClientsAddExisting(t *testing.T) {
t.Run("complicated", func(t *testing.T) {
var err error
testIP := net.IP{1, 2, 3, 4}
ip := net.IP{1, 2, 3, 4}
// First, init a DHCP server with a single static lease.
config := dhcpd.ServerConfig{
@ -267,7 +295,7 @@ func TestClientsAddExisting(t *testing.T) {
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: testIP,
IP: ip,
Hostname: "testhost",
Expiry: time.Now().Add(time.Hour),
})
@ -275,7 +303,7 @@ func TestClientsAddExisting(t *testing.T) {
// Add a new client with the same IP as for a client with MAC.
ok, err := clients.Add(&Client{
IDs: []string{testIP.String()},
IDs: []string{ip.String()},
Name: "client2",
})
require.NoError(t, err)

View File

@ -5,6 +5,8 @@ import (
"fmt"
"net"
"net/http"
"github.com/AdguardTeam/golibs/log"
)
// clientJSON is a common structure used by several handlers to deal with
@ -44,13 +46,13 @@ type clientJSON struct {
type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
IP string `json:"ip"`
Name string `json:"name"`
Source string `json:"source"`
IP net.IP `json:"ip"`
}
type clientListJSON struct {
Clients []clientJSON `json:"clients"`
Clients []*clientJSON `json:"clients"`
RuntimeClients []runtimeClientJSON `json:"auto_clients"`
Tags []string `json:"supported_tags"`
}
@ -66,11 +68,20 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
cj := clientToJSON(c)
data.Clients = append(data.Clients, cj)
}
for ip, rc := range clients.ipToRC {
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
rc, ok := v.(*RuntimeClient)
if !ok {
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
return true
}
cj := runtimeClientJSON{
IP: ip,
Name: rc.Host,
WHOISInfo: rc.WHOISInfo,
Name: rc.Host,
IP: ip,
}
cj.Source = "etc/hosts"
@ -86,7 +97,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
}
data.RuntimeClients = append(data.RuntimeClients, cj)
}
return true
})
data.Tags = clientTags
@ -118,8 +131,8 @@ func jsonToClient(cj clientJSON) (c *Client) {
}
// Convert Client object to JSON
func clientToJSON(c *Client) clientJSON {
cj := clientJSON{
func clientToJSON(c *Client) (cj *clientJSON) {
return &clientJSON{
Name: c.Name,
IDs: c.IDs,
Tags: c.Tags,
@ -134,19 +147,6 @@ func clientToJSON(c *Client) clientJSON {
Upstreams: c.Upstreams,
}
return cj
}
// runtimeClientToJSON converts a RuntimeClient into a JSON struct.
func runtimeClientToJSON(ip string, rc RuntimeClient) (cj clientJSON) {
cj = clientJSON{
Name: rc.Host,
IDs: []string{ip},
WHOISInfo: rc.WHOISInfo,
}
return cj
}
// Add a new client
@ -230,7 +230,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
// Get the list of clients by IP address list
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
data := []map[string]clientJSON{}
data := []map[string]*clientJSON{}
for i := 0; i < len(q); i++ {
idStr := q.Get(fmt.Sprintf("ip%d", i))
if idStr == "" {
@ -239,20 +239,16 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
ip := net.ParseIP(idStr)
c, ok := clients.Find(idStr)
var cj clientJSON
var cj *clientJSON
if !ok {
var found bool
cj, found = clients.findRuntime(ip, idStr)
if !found {
continue
}
cj = clients.findRuntime(ip, idStr)
} else {
cj = clientToJSON(c)
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
}
data = append(data, map[string]clientJSON{
data = append(data, map[string]*clientJSON{
idStr: cj,
})
}
@ -265,39 +261,37 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
}
// findRuntime looks up the IP in runtime and temporary storages, like
// /etc/hosts tables, DHCP leases, or blocklists.
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj clientJSON, found bool) {
if ip == nil {
return cj, false
}
rc, ok := clients.FindRuntimeClient(idStr)
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
// non-nil.
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
rc, ok := clients.FindRuntimeClient(ip)
if !ok {
// It is still possible that the IP used to be in the runtime
// clients list, but then the server was reloaded. So, check
// the DNS server's blocked IP list.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
if rule == "" {
return clientJSON{}, false
}
cj = clientJSON{
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
cj = &clientJSON{
IDs: []string{idStr},
Disallowed: &disallowed,
DisallowedRule: &rule,
WHOISInfo: &RuntimeClientWHOISInfo{},
}
return cj, true
return cj
}
cj = runtimeClientToJSON(idStr, rc)
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
cj = &clientJSON{
Name: rc.Host,
IDs: []string{idStr},
WHOISInfo: rc.WHOISInfo,
}
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
return cj, true
return cj
}
// RegisterClientsHandlers registers HTTP handlers

View File

@ -105,8 +105,8 @@ func isRunning() bool {
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
}
func onDNSRequest(d *proxy.DNSContext) {
ip := aghnet.IPFromAddr(d.Addr)
func onDNSRequest(pctx *proxy.DNSContext) {
ip := aghnet.IPFromAddr(pctx.Addr)
if ip == nil {
// This would be quite weird if we get here.
return

View File

@ -503,7 +503,7 @@ Please note, that this is crucial for a server to be able to use privileged port
You have two options:
1. Run AdGuard Home with root privileges
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
https://github.com/AdguardTeam/AdGuardHome/internal/wiki/Getting-Started#running-without-superuser`
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`
log.Fatal(msg)
}

View File

@ -102,12 +102,7 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) {
func (r *RDNS) Begin(ip net.IP) {
r.ensurePrivateCache()
if r.isCached(ip) {
return
}
id := ip.String()
if r.clients.Exists(id, ClientSourceRDNS) {
if r.isCached(ip) || r.clients.Exists(ip, ClientSourceRDNS) {
return
}
@ -138,6 +133,6 @@ func (r *RDNS) workerLoop() {
// Don't handle any errors since AddHost doesn't return non-nil
// errors for now.
_, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS)
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
}
}

View File

@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/dnsproxy/upstream"
@ -84,7 +85,7 @@ func TestRDNS_Begin(t *testing.T) {
clients: &clientsContainer{
list: map[string]*Client{},
idIndex: tc.cliIDIndex,
ipToRC: map[string]*RuntimeClient{},
ipToRC: aghnet.NewIPMap(0),
allTags: aghstrings.NewSet(),
},
}
@ -204,7 +205,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
cc := &clientsContainer{
list: map[string]*Client{},
idIndex: map[string]*Client{},
ipToRC: map[string]*RuntimeClient{},
ipToRC: aghnet.NewIPMap(0),
allTags: aghstrings.NewSet(),
}
ch := make(chan net.IP)
@ -236,7 +237,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
return
}
assert.True(t, cc.Exists(tc.cliIP.String(), ClientSourceRDNS))
assert.True(t, cc.Exists(tc.cliIP, ClientSourceRDNS))
})
}
}

View File

@ -252,7 +252,6 @@ func (w *WHOIS) workerLoop() {
continue
}
id := ip.String()
w.clients.SetWHOISInfo(id, info)
w.clients.SetWHOISInfo(ip, info)
}
}

View File

@ -720,7 +720,10 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
a := convertMapToSlice(m, int(maxCount))
d := []net.IP{}
for _, it := range a {
d = append(d, net.ParseIP(it.Name))
ip := net.ParseIP(it.Name)
if ip != nil {
d = append(d, ip)
}
}
return d
}

View File

@ -4,6 +4,11 @@
## v0.107: API changes
### Client IDs in Access Settings
* The `POST /control/access/set` HTTP API now accepts client IDs in
`"allowed_clients"` and `"disallowed_clients"` fields.
### The new field `"unicode_name"` in `DNSQuestion`
* The new optional field `"unicode_name"` is the Unicode representation of
@ -17,7 +22,7 @@
### Disabling Statistics
* The API `POST /control/stats_config` HTTP API allows disabling statistics by
* The `POST /control/stats_config` HTTP API allows disabling statistics by
setting `"interval"` to `0`.
### `POST /control/dhcp/reset_leases`

View File

@ -1957,10 +1957,7 @@
'disallowed_rule':
'type': 'string'
'description': >
The rule due to which the client is disallowed. If disallowed is
set to true, and this string is empty, then the client IP is
disallowed by the "allowed IP list", that is it is not included in
the allowed list.
The rule due to which the client is allowed or blocked.
'name':
'description': >
Persistent client's name or an empty string if this is a runtime
@ -2352,17 +2349,19 @@
'description': 'Client and host access list'
'properties':
'allowed_clients':
'description': 'Allowlist of clients.'
'description': >
The allowlist of clients: IP addresses, CIDRs, or client IDs.
'items':
'type': 'string'
'type': 'array'
'disallowed_clients':
'description': 'Blocklist of clients.'
'description': >
The blocklist of clients: IP addresses, CIDRs, or client IDs.
'items':
'type': 'string'
'type': 'array'
'blocked_hosts':
'description': 'Blocklist of hosts.'
'description': 'The blocklist of hosts.'
'items':
'type': 'string'
'type': 'array'