Pull request: 2508 ip conversion vol.2

Merge in DNS/adguard-home from 2508-ip-conversion-vol2 to master

Closes #2508.

Squashed commit of the following:

commit 5b9d33f9cd352756831f63e34c4aea48674628c1
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Jan 20 17:15:17 2021 +0300

    util: replace net.IPNet with pointer

commit 680126de7d59464077f9edf1bbaa925dd3fcee19
Merge: d3ba6a6c 5a50efad
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Jan 20 17:02:41 2021 +0300

    Merge branch 'master' into 2508-ip-conversion-vol2

commit d3ba6a6cdd01c0aa736418fdb86ed40120169fe9
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 19 18:29:54 2021 +0300

    all: remove last conversion

commit 88b63f11a6c3f8705d7fa0c448c50dd646cc9214
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 19 14:12:45 2021 +0300

    all: improve code quality

commit 71af60c70a0dbaf55e2221023d6d2e4993c9e9a7
Merge: 98af3784 9f75725d
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Jan 18 17:13:27 2021 +0300

    Merge branch 'master' into 2508-ip-conversion-vol2

commit 98af3784ce44d0993d171653c13d6e83bb8d1e6a
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Jan 18 16:32:53 2021 +0300

    all: log changes

commit e99595a172bae1e844019d344544be84ddd65e4e
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Jan 18 16:06:49 2021 +0300

    all: fix or remove remaining net.IP <-> string conversions

commit 7fd0634ce945f7e4c9b856684c5199f8a84a543e
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Jan 15 15:36:17 2021 +0300

    all: remove redundant net.IP <-> string converions

commit 5df8af030421237d41b67ed659f83526cc258199
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Jan 14 16:35:25 2021 +0300

    stats: remove redundant net.IP <-> string conversion

commit fbe4e3fc015e6898063543a90c04401d76dbb18f
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Jan 14 16:20:35 2021 +0300

    querylog: remove redundant net.IP <-> string conversion
This commit is contained in:
Eugene Burkov 2021-01-20 17:27:53 +03:00
parent 5a50efadb2
commit 7fab31beae
45 changed files with 324 additions and 302 deletions

View File

@ -66,6 +66,7 @@ and this project adheres to
### Fixed
- Unnecessary conversions from `string` to `net.IP`, and vice versa ([#2508]).
- Inability to set DNS cache TTL limits ([#2459]).
- Possible freezes on slower machines ([#2225]).
- A mitigation against records being shown in the wrong order on the query log
@ -79,9 +80,13 @@ and this project adheres to
[#2345]: https://github.com/AdguardTeam/AdGuardHome/issues/2345
[#2355]: https://github.com/AdguardTeam/AdGuardHome/issues/2355
[#2459]: https://github.com/AdguardTeam/AdGuardHome/issues/2459
[#2508]: https://github.com/AdguardTeam/AdGuardHome/issues/2508
### Removed
- The undocumented ability to use hostnames as any of `bind_host` values in
configuration. Documentation requires them to be valid IP addresses, and now
the implementation makes sure that that is the case ([#2508]).
- `Dockerfile` ([#2276]). Replaced with the script
`scripts/make/build-docker.sh` which uses `scripts/make/Dockerfile`.
- Support for pre-v0.99.3 format of query logs ([#2102]).

View File

@ -297,9 +297,6 @@ func parseOptionString(s string) (uint8, []byte) {
return 0, nil
}
val = ip
if ip.To4() != nil {
val = ip.To4()
}
default:
return 0, nil

View File

@ -61,11 +61,11 @@ func TestDB(t *testing.T) {
ll := s.srv4.GetLeases(LeasesAll)
assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String())
assert.Equal(t, "192.168.10.101", ll[0].IP.String())
assert.True(t, net.IP{192, 168, 10, 101}.Equal(ll[0].IP))
assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String())
assert.Equal(t, "192.168.10.100", ll[1].IP.String())
assert.True(t, net.IP{192, 168, 10, 100}.Equal(ll[1].IP))
assert.Equal(t, exp1.Unix(), ll[1].Expiry.Unix())
_ = os.Remove("leases.db")
@ -117,7 +117,7 @@ func TestOptions(t *testing.T) {
code, val = parseOptionString("123 ip 1.2.3.4")
assert.EqualValues(t, 123, code)
assert.Equal(t, "1.2.3.4", net.IP(string(val)).String())
assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(val)))
code, _ = parseOptionString("256 ip 1.1.1.1")
assert.EqualValues(t, 0, code)

View File

@ -40,7 +40,7 @@ func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf {
}
type v6ServerConfJSON struct {
RangeStart string `json:"range_start"`
RangeStart net.IP `json:"range_start"`
LeaseDuration uint32 `json:"lease_duration"`
}
@ -331,7 +331,7 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
result.V4.StaticIP.Error = err.Error()
} else if !isStaticIP {
result.V4.StaticIP.Static = "no"
result.V4.StaticIP.IP = util.GetSubnet(interfaceName)
result.V4.StaticIP.IP = util.GetSubnet(interfaceName).String()
}
if found4 {

View File

@ -79,7 +79,7 @@ type V6ServerConf struct {
// The first IP address for dynamic leases
// The last allowed IP address ends with 0xff byte
RangeStart string `yaml:"range_start" json:"range_start"`
RangeStart net.IP `yaml:"range_start"`
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds

View File

@ -40,7 +40,7 @@ func TestV4StaticLeaseAddRemove(t *testing.T) {
// check
ls = s.GetLeases(LeasesStatic)
assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP))
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
@ -102,11 +102,11 @@ func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) {
ls := s.GetLeases(LeasesStatic)
assert.Len(t, ls, 2)
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP))
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
assert.Equal(t, "192.168.10.152", ls[1].IP.String())
assert.True(t, net.IP{192, 168, 10, 152}.Equal(ls[1].IP))
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix())
}
@ -139,10 +139,10 @@ func TestV4StaticLeaseGet(t *testing.T) {
// check "Offer"
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.True(t, net.IP{192, 168, 10, 150}.Equal(resp.YourIPAddr))
assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0]))
assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier()))
assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask())))
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
// "Request"
@ -153,20 +153,20 @@ func TestV4StaticLeaseGet(t *testing.T) {
// check "Ack"
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.True(t, net.IP{192, 168, 10, 150}.Equal(resp.YourIPAddr))
assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0]))
assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier()))
assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask())))
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
dnsAddrs := resp.DNS()
assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0]))
// check lease
ls := s.GetLeases(LeasesStatic)
assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP))
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
}
@ -197,13 +197,13 @@ func TestV4DynamicLeaseGet(t *testing.T) {
// check "Offer"
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.True(t, net.IP{192, 168, 10, 100}.Equal(resp.YourIPAddr))
assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0]))
assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier()))
assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask())))
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)])
assert.Equal(t, "1.2.3.4", net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]).String())
assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)])))
// "Request"
req, _ = dhcpv4.NewRequestFromOffer(resp)
@ -213,20 +213,20 @@ func TestV4DynamicLeaseGet(t *testing.T) {
// check "Ack"
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.True(t, net.IP{192, 168, 10, 100}.Equal(resp.YourIPAddr))
assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0]))
assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier()))
assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask())))
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
dnsAddrs := resp.DNS()
assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0]))
// check lease
ls := s.GetLeases(LeasesDynamic)
assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.100", ls[0].IP.String())
assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP))
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
start := net.IP{192, 168, 10, 100}

View File

@ -660,7 +660,7 @@ func v6Create(conf V6ServerConf) (DHCPServer, error) {
return s, nil
}
s.conf.ipStart = net.ParseIP(conf.RangeStart)
s.conf.ipStart = conf.RangeStart
if s.conf.ipStart == nil || s.conf.ipStart.To16() == nil {
return s, fmt.Errorf("dhcpv6: invalid range-start IP: %s", conf.RangeStart)
}

View File

@ -17,7 +17,7 @@ func notify6(flags uint32) {
func TestV6StaticLeaseAddRemove(t *testing.T) {
conf := V6ServerConf{
Enabled: true,
RangeStart: "2001::1",
RangeStart: net.ParseIP("2001::1"),
notify: notify6,
}
s, err := v6Create(conf)
@ -60,7 +60,7 @@ func TestV6StaticLeaseAddRemove(t *testing.T) {
func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
conf := V6ServerConf{
Enabled: true,
RangeStart: "2001::1",
RangeStart: net.ParseIP("2001::1"),
notify: notify6,
}
sIface, err := v6Create(conf)
@ -109,7 +109,7 @@ func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
func TestV6GetLease(t *testing.T) {
conf := V6ServerConf{
Enabled: true,
RangeStart: "2001::1",
RangeStart: net.ParseIP("2001::1"),
notify: notify6,
}
sIface, err := v6Create(conf)
@ -169,7 +169,7 @@ func TestV6GetLease(t *testing.T) {
func TestV6GetDynamicLease(t *testing.T) {
conf := V6ServerConf{
Enabled: true,
RangeStart: "2001::2",
RangeStart: net.ParseIP("2001::2"),
notify: notify6,
}
sIface, err := v6Create(conf)

View File

@ -36,7 +36,7 @@ type RequestFilteringSettings struct {
ParentalEnabled bool
ClientName string
ClientIP string
ClientIP net.IP
ClientTags []string
ServicesRules []ServiceEntry
@ -676,7 +676,8 @@ func (d *DNSFilter) matchHost(host string, qtype uint16, setts RequestFilteringS
ureq := urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
ClientIP: setts.ClientIP,
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName,
DNSType: qtype,
}

View File

@ -117,19 +117,19 @@ func TestRewritesLevels(t *testing.T) {
r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.Equal(t, "1.1.1.1", r.IPList[0].String())
assert.True(t, net.IP{1, 1, 1, 1}.Equal(r.IPList[0]))
// match L2
r = d.processRewrites("sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match L3
r = d.processRewrites("my.sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.Equal(t, "3.3.3.3", r.IPList[0].String())
assert.True(t, net.IP{3, 3, 3, 3}.Equal(r.IPList[0]))
}
func TestRewritesExceptionCNAME(t *testing.T) {
@ -145,7 +145,7 @@ func TestRewritesExceptionCNAME(t *testing.T) {
r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match sub-domain, but handle exception
r = d.processRewrites("sub.host.com", dns.TypeA)
@ -165,7 +165,7 @@ func TestRewritesExceptionWC(t *testing.T) {
r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match sub-domain, but handle exception
r = d.processRewrites("my.sub.host.com", dns.TypeA)
@ -188,7 +188,7 @@ func TestRewritesExceptionIP(t *testing.T) {
r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.Equal(t, "1.2.3.4", r.IPList[0].String())
assert.True(t, net.IP{1, 2, 3, 4}.Equal(r.IPList[0]))
// match exception
r = d.processRewrites("host.com", dns.TypeAAAA)

View File

@ -83,20 +83,21 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin
// Returns the item from the "disallowedClients" list that lead to blocking IP.
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
// but the ip does not belong to it.
func (a *accessCtx) IsBlockedIP(ip string) (bool, string) {
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
ipStr := ip.String()
a.lock.Lock()
defer a.lock.Unlock()
if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 {
_, ok := a.allowedClients[ip]
_, ok := a.allowedClients[ipStr]
if ok {
return false, ""
}
if len(a.allowedClientsIPNet) != 0 {
ipAddr := net.ParseIP(ip)
for _, ipnet := range a.allowedClientsIPNet {
if ipnet.Contains(ipAddr) {
if ipnet.Contains(ip) {
return false, ""
}
}
@ -105,15 +106,14 @@ func (a *accessCtx) IsBlockedIP(ip string) (bool, string) {
return true, ""
}
_, ok := a.disallowedClients[ip]
_, ok := a.disallowedClients[ipStr]
if ok {
return true, ip
return true, ipStr
}
if len(a.disallowedClientsIPNet) != 0 {
ipAddr := net.ParseIP(ip)
for _, ipnet := range a.disallowedClientsIPNet {
if ipnet.Contains(ipAddr) {
if ipnet.Contains(ip) {
return true, ipnet.String()
}
}

View File

@ -1,6 +1,7 @@
package dnsforward
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
@ -10,19 +11,19 @@ func TestIsBlockedIPAllowed(t *testing.T) {
a := &accessCtx{}
assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil))
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
assert.False(t, disallowed)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
assert.True(t, disallowed)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
assert.False(t, disallowed)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
assert.True(t, disallowed)
assert.Empty(t, disallowedRule)
}
@ -31,19 +32,19 @@ func TestIsBlockedIPDisallowed(t *testing.T) {
a := &accessCtx{}
assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil))
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
assert.True(t, disallowed)
assert.Equal(t, "1.1.1.1", disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
assert.False(t, disallowed)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
assert.True(t, disallowed)
assert.Equal(t, "2.2.0.0/16", disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
assert.False(t, disallowed)
assert.Empty(t, disallowedRule)
}

View File

@ -25,11 +25,11 @@ type FilteringConfig struct {
// --
// Filtering callback function
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
FilterHandler func(clientAddr net.IP, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
// based on the client IP address. Returns nil if there are no custom upstreams for the client
// TODO(e.burkov): replace argument type with net.IP.
// TODO(e.burkov): Replace argument type with net.IP.
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
// Protection configuration

View File

@ -298,6 +298,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// IsBlockedIP - return TRUE if this client should be blocked
func (s *Server) IsBlockedIP(ip string) (bool, string) {
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
return s.access.IsBlockedIP(ip)
}

View File

@ -322,7 +322,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.NotNil(t, reply.Answer)
assert.Equal(t, "192.168.0.1", reply.Answer[0].(*dns.A).A.String())
assert.True(t, net.IP{192, 168, 0, 1}.Equal(reply.Answer[0].(*dns.A).A))
assert.Nil(t, s.Stop())
}
@ -473,7 +473,7 @@ func TestBlockCNAME(t *testing.T) {
func TestClientRulesForCNAMEMatching(t *testing.T) {
s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
s.conf.FilterHandler = func(_ string, settings *dnsfilter.RequestFilteringSettings) {
s.conf.FilterHandler = func(_ net.IP, settings *dnsfilter.RequestFilteringSettings) {
settings.FilteringEnabled = false
}
err := s.startWithUpstream(testUpstm)
@ -568,7 +568,7 @@ func TestBlockedCustomIP(t *testing.T) {
assert.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "0.0.0.1", a.A.String())
assert.True(t, net.IP{0, 0, 0, 1}.Equal(a.A))
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String())
@ -713,7 +713,7 @@ func TestRewrite(t *testing.T) {
assert.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "1.2.3.4", a.A.String())
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String())
@ -725,7 +725,7 @@ func TestRewrite(t *testing.T) {
assert.Nil(t, err)
assert.Len(t, reply.Answer, 2)
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, "1.2.3.4", reply.Answer[1].(*dns.A).A.String())
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String())

View File

@ -12,7 +12,7 @@ import (
)
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := IPStringFromAddr(d.Addr)
ip := IPFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip)
@ -36,8 +36,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
setts := s.dnsFilter.GetConfig()
setts.FilteringEnabled = true
if s.conf.FilterHandler != nil {
clientAddr := IPStringFromAddr(d.Addr)
s.conf.FilterHandler(clientAddr, &setts)
s.conf.FilterHandler(IPFromAddr(d.Addr), &setts)
}
return &setts
}

View File

@ -36,7 +36,7 @@ func processQueryLogsAndStats(ctx *dnsContext) int {
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: ipFromAddr(d.Addr),
ClientIP: IPFromAddr(d.Addr),
}
switch d.Proto {

View File

@ -8,8 +8,8 @@ import (
"github.com/AdguardTeam/golibs/utils"
)
// ipFromAddr gets IP address from addr.
func ipFromAddr(addr net.Addr) (ip net.IP) {
// IPFromAddr gets IP address from addr.
func IPFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP
@ -22,8 +22,8 @@ func ipFromAddr(addr net.Addr) (ip net.IP) {
// IPStringFromAddr extracts IP address from net.Addr.
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func IPStringFromAddr(addr net.Addr) (ipstr string) {
if ip := ipFromAddr(addr); ip != nil {
func IPStringFromAddr(addr net.Addr) (ipStr string) {
if ip := IPFromAddr(addr); ip != nil {
return ip.String()
}

View File

@ -72,6 +72,8 @@ type ClientHost struct {
type clientsContainer struct {
list map[string]*Client // name -> client
idIndex map[string]*Client // IP -> client
// TODO(e.burkov): Think of a way to not require string conversion for
// IP addresses.
ipHost map[string]*ClientHost // IP -> Hostname
lock sync.Mutex
@ -239,7 +241,7 @@ func (clients *clientsContainer) onHostsChanged() {
}
// Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) bool {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -248,7 +250,7 @@ func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
return true
}
ch, ok := clients.ipHost[ip]
ch, ok := clients.ipHost[ip.String()]
if !ok {
return false
}
@ -265,7 +267,7 @@ func stringArrayDup(a []string) []string {
}
// Find searches for a client by IP
func (clients *clientsContainer) Find(ip string) (Client, bool) {
func (clients *clientsContainer) Find(ip net.IP) (Client, bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -287,7 +289,7 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.findByIP(ip)
c, ok := clients.findByIP(net.ParseIP(ip))
if !ok {
return nil
}
@ -307,13 +309,12 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
}
// Find searches for a client by IP (and does not lock anything)
func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
func (clients *clientsContainer) findByIP(ip net.IP) (Client, bool) {
if ip == nil {
return Client{}, false
}
c, ok := clients.idIndex[ip]
c, ok := clients.idIndex[ip.String()]
if ok {
return *c, true
}
@ -324,7 +325,7 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
if err != nil {
continue
}
if ipnet.Contains(ipAddr) {
if ipnet.Contains(ip) {
return *c, true
}
}
@ -333,7 +334,7 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
if clients.dhcpServer == nil {
return Client{}, false
}
macFound := clients.dhcpServer.FindMACbyIP(ipAddr)
macFound := clients.dhcpServer.FindMACbyIP(ip)
if macFound == nil {
return Client{}, false
}
@ -353,16 +354,15 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
}
// FindAutoClient - search for an auto-client by IP
func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
func (clients *clientsContainer) FindAutoClient(ip net.IP) (ClientHost, bool) {
if ip == nil {
return ClientHost{}, false
}
clients.lock.Lock()
defer clients.lock.Unlock()
ch, ok := clients.ipHost[ip]
ch, ok := clients.ipHost[ip.String()]
if ok {
return *ch, true
}
@ -539,7 +539,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
}
// SetWhoisInfo - associate WHOIS information with a client
func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
func (clients *clientsContainer) SetWhoisInfo(ip net.IP, info [][]string) {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -549,7 +549,8 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
return
}
ch, ok := clients.ipHost[ip]
ipStr := ip.String()
ch, ok := clients.ipHost[ipStr]
if ok {
ch.WhoisInfo = info
log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo)
@ -561,7 +562,7 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
Source: ClientSourceWHOIS,
}
ch.WhoisInfo = info
clients.ipHost[ip] = ch
clients.ipHost[ipStr] = ch
log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo)
}

View File

@ -36,21 +36,21 @@ func TestClients(t *testing.T) {
assert.True(t, b)
assert.Nil(t, err)
c, b = clients.Find("1.1.1.1")
c, b = clients.Find(net.IPv4(1, 1, 1, 1))
assert.True(t, b)
assert.Equal(t, c.Name, "client1")
c, b = clients.Find("1:2:3::4")
c, b = clients.Find(net.ParseIP("1:2:3::4"))
assert.True(t, b)
assert.Equal(t, c.Name, "client1")
c, b = clients.Find("2.2.2.2")
c, b = clients.Find(net.IPv4(2, 2, 2, 2))
assert.True(t, b)
assert.Equal(t, c.Name, "client2")
assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
assert.False(t, clients.Exists(net.IPv4(1, 2, 3, 4), ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(2, 2, 2, 2), ClientSourceHostsFile))
})
t.Run("add_fail_name", func(t *testing.T) {
@ -112,8 +112,8 @@ func TestClients(t *testing.T) {
err := clients.Update("client1", c)
assert.Nil(t, err)
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile))
c = Client{
IDs: []string{"1.1.1.2"},
@ -124,7 +124,7 @@ func TestClients(t *testing.T) {
err = clients.Update("client1", c)
assert.Nil(t, err)
c, b := clients.Find("1.1.1.2")
c, b := clients.Find(net.IPv4(1, 1, 1, 2))
assert.True(t, b)
assert.Equal(t, "client1-renamed", c.Name)
assert.Equal(t, "1.1.1.2", c.IDs[0])
@ -135,7 +135,7 @@ func TestClients(t *testing.T) {
t.Run("del_success", func(t *testing.T) {
b := clients.Del("client1-renamed")
assert.True(t, b)
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile))
})
t.Run("del_fail", func(t *testing.T) {
@ -156,7 +156,7 @@ func TestClients(t *testing.T) {
assert.True(t, b)
assert.Nil(t, err)
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
})
t.Run("addhost_fail", func(t *testing.T) {
@ -174,12 +174,12 @@ func TestClientsWhois(t *testing.T) {
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
// set whois info on new client
clients.SetWhoisInfo("1.1.1.255", whois)
clients.SetWhoisInfo(net.IPv4(1, 1, 1, 255), whois)
assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1])
// set whois info on existing auto-client
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
clients.SetWhoisInfo("1.1.1.1", whois)
clients.SetWhoisInfo(net.IPv4(1, 1, 1, 1), whois)
assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1])
// Check that we cannot set whois info on a manually-added client
@ -188,7 +188,7 @@ func TestClientsWhois(t *testing.T) {
Name: "client1",
}
_, _ = clients.Add(c)
clients.SetWhoisInfo("1.1.1.2", whois)
clients.SetWhoisInfo(net.IPv4(1, 1, 1, 2), whois)
assert.Nil(t, clients.ipHost["1.1.1.2"])
_ = clients.Del("client1")
}

View File

@ -3,6 +3,7 @@ package home
import (
"encoding/json"
"fmt"
"net"
"net/http"
)
@ -229,8 +230,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
q := r.URL.Query()
data := []map[string]interface{}{}
for i := 0; ; i++ {
ip := q.Get(fmt.Sprintf("ip%d", i))
if len(ip) == 0 {
ipStr := q.Get(fmt.Sprintf("ip%d", i))
ip := net.ParseIP(ipStr)
if ip == nil {
break
}
@ -248,7 +250,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
}
el[ip] = cj
el[ipStr] = cj
data = append(data, el)
}
@ -267,7 +269,8 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
// findTemporary looks up the IP in temporary storages, like autohosts or
// blocklists.
func (clients *clientsContainer) findTemporary(ip string) (cj clientJSON, found bool) {
func (clients *clientsContainer) findTemporary(ip net.IP) (cj clientJSON, found bool) {
ipStr := ip.String()
ch, ok := clients.FindAutoClient(ip)
if !ok {
// It is still possible that the IP used to be in the runtime
@ -281,7 +284,7 @@ func (clients *clientsContainer) findTemporary(ip string) (cj clientJSON, found
}
cj = clientJSON{
IDs: []string{ip},
IDs: []string{ipStr},
Disallowed: disallowed,
DisallowedRule: rule,
}
@ -289,7 +292,7 @@ func (clients *clientsContainer) findTemporary(ip string) (cj clientJSON, found
return cj, true
}
cj = clientHostToJSON(ip, ch)
cj = clientHostToJSON(ipStr, ch)
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
return cj, true

View File

@ -2,6 +2,7 @@ package home
import (
"io/ioutil"
"net"
"os"
"path/filepath"
"sync"
@ -40,7 +41,7 @@ type configuration struct {
// It's reset after config is parsed
fileData []byte
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client
Users []User `yaml:"users"` // Users that can access HTTP server
@ -74,7 +75,7 @@ type configuration struct {
// field ordering is important -- yaml fields will mirror ordering from here
type dnsConfig struct {
BindHost string `yaml:"bind_host"`
BindHost net.IP `yaml:"bind_host"`
Port int `yaml:"port"`
// time interval for statistics (in days)
@ -121,9 +122,9 @@ type tlsConfigSettings struct {
var config = configuration{
BindPort: 3000,
BetaBindPort: 0,
BindHost: "0.0.0.0",
BindHost: net.IP{0, 0, 0, 0},
DNS: dnsConfig{
BindHost: "0.0.0.0",
BindHost: net.IP{0, 0, 0, 0},
Port: 53,
StatsInterval: 1,
FilteringConfig: dnsforward.FilteringConfig{

View File

@ -36,11 +36,12 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
// ---------------
// dns run control
// ---------------
func addDNSAddress(dnsAddresses *[]string, addr string) {
func addDNSAddress(dnsAddresses *[]string, addr net.IP) {
hostport := addr.String()
if config.DNS.Port != 53 {
addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port)
hostport = net.JoinHostPort(hostport, strconv.Itoa(config.DNS.Port))
}
*dnsAddresses = append(*dnsAddresses, addr)
*dnsAddresses = append(*dnsAddresses, hostport)
}
func handleStatus(w http.ResponseWriter, _ *http.Request) {

View File

@ -31,7 +31,7 @@ type netInterfaceJSON struct {
Name string `json:"name"`
MTU int `json:"mtu"`
HardwareAddr string `json:"hardware_address"`
Addresses []string `json:"ip_addresses"`
Addresses []net.IP `json:"ip_addresses"`
Flags string `json:"flags"`
}
@ -69,7 +69,7 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
type checkConfigReqEnt struct {
Port int `json:"port"`
IP string `json:"ip"`
IP net.IP `json:"ip"`
Autofix bool `json:"autofix"`
}
@ -138,7 +138,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
if err != nil {
respData.DNS.Status = err.Error()
} else if reqData.DNS.IP != "0.0.0.0" {
} else if !reqData.DNS.IP.IsUnspecified() {
respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP)
}
}
@ -154,7 +154,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
// handleStaticIP - handles static IP request
// It either checks if we have a static IP
// Or if set=true, it tries to set it
func handleStaticIP(ip string, set bool) staticIPJSON {
func handleStaticIP(ip net.IP, set bool) staticIPJSON {
resp := staticIPJSON{}
interfaceName := util.GetInterfaceByIP(ip)
@ -186,7 +186,7 @@ func handleStaticIP(ip string, set bool) staticIPJSON {
if isStaticIP {
resp.Static = "yes"
}
resp.IP = util.GetSubnet(interfaceName)
resp.IP = util.GetSubnet(interfaceName).String()
}
return resp
}
@ -262,7 +262,7 @@ func disableDNSStubListener() error {
}
type applyConfigReqEnt struct {
IP string `json:"ip"`
IP net.IP `json:"ip"`
Port int `json:"port"`
}
@ -297,7 +297,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
}
restartHTTP := true
if config.BindHost == newSettings.Web.IP && config.BindPort == newSettings.Web.Port {
if config.BindHost.Equal(newSettings.Web.IP) && config.BindPort == newSettings.Web.Port {
// no need to rebind
restartHTTP = false
}
@ -307,7 +307,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err)
net.JoinHostPort(newSettings.Web.IP.String(), strconv.Itoa(newSettings.Web.Port)), err)
return
}
@ -388,18 +388,18 @@ func (web *Web) registerInstallHandlers() {
// checkConfigReqEntBeta is a struct representing new client's config check
// request entry. It supports multiple IP values unlike the checkConfigReqEnt.
//
// TODO(e.burkov): this should removed with the API v1 when the appropriate
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default checkConfigReqEnt.
type checkConfigReqEntBeta struct {
Port int `json:"port"`
IP []string `json:"ip"`
IP []net.IP `json:"ip"`
Autofix bool `json:"autofix"`
}
// checkConfigReqBeta is a struct representing new client's config check request
// body. It uses checkConfigReqEntBeta instead of checkConfigReqEnt.
//
// TODO(e.burkov): this should removed with the API v1 when the appropriate
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default checkConfigReq.
type checkConfigReqBeta struct {
Web checkConfigReqEntBeta `json:"web"`
@ -410,7 +410,7 @@ type checkConfigReqBeta struct {
// handleInstallCheckConfigBeta is a substitution of /install/check_config
// handler for new client.
//
// TODO(e.burkov): this should removed with the API v1 when the appropriate
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallCheckConfig.
func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Request) {
reqData := checkConfigReqBeta{}
@ -456,17 +456,17 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
// applyConfigReqEntBeta is a struct representing new client's config setting
// request entry. It supports multiple IP values unlike the applyConfigReqEnt.
//
// TODO(e.burkov): this should removed with the API v1 when the appropriate
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default applyConfigReqEnt.
type applyConfigReqEntBeta struct {
IP []string `json:"ip"`
IP []net.IP `json:"ip"`
Port int `json:"port"`
}
// applyConfigReqBeta is a struct representing new client's config setting
// request body. It uses applyConfigReqEntBeta instead of applyConfigReqEnt.
//
// TODO(e.burkov): this should removed with the API v1 when the appropriate
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default applyConfigReq.
type applyConfigReqBeta struct {
Web applyConfigReqEntBeta `json:"web"`
@ -478,7 +478,7 @@ type applyConfigReqBeta struct {
// handleInstallConfigureBeta is a substitution of /install/configure handler
// for new client.
//
// TODO(e.burkov): this should removed with the API v1 when the appropriate
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallConfigure.
func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Request) {
reqData := applyConfigReqBeta{}
@ -523,7 +523,7 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
// firstRunDataBeta is a struct representing new client's getting addresses
// request body. It uses array of structs instead of map.
//
// TODO(e.burkov): this should removed with the API v1 when the appropriate
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default firstRunData.
type firstRunDataBeta struct {
WebPort int `json:"web_port"`
@ -534,7 +534,7 @@ type firstRunDataBeta struct {
// handleInstallConfigureBeta is a substitution of /install/get_addresses
// handler for new client.
//
// TODO(e.burkov): this should removed with the API v1 when the appropriate
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallGetAddresses.
func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Request) {
data := firstRunDataBeta{}
@ -570,7 +570,7 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req
// registerBetaInstallHandlers registers the install handlers for new client
// with the structures it supports.
//
// TODO(e.burkov): this should removed with the API v1 when the appropriate
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handlers.
func (web *Web) registerBetaInstallHandlers() {
Context.mux.HandleFunc("/control/install/get_addresses_beta", preInstall(ensureGET(web.handleInstallGetAddressesBeta)))

View File

@ -55,8 +55,8 @@ func initDNSServer() error {
filterConf := config.DNS.DnsfilterConf
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"
if config.DNS.BindHost.IsUnspecified() {
bindhost = net.IPv4(127, 0, 0, 1)
}
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
filterConf.AutoHosts = &Context.autoHosts
@ -98,26 +98,24 @@ func isRunning() bool {
}
func onDNSRequest(d *proxy.DNSContext) {
ip := dnsforward.IPStringFromAddr(d.Addr)
if ip == "" {
ip := dnsforward.IPFromAddr(d.Addr)
if ip == nil {
// This would be quite weird if we get here
return
}
ipAddr := net.ParseIP(ip)
if !ipAddr.IsLoopback() {
if !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if !Context.ipDetector.detectSpecialNetwork(ipAddr) {
if !Context.ipDetector.detectSpecialNetwork(ip) {
Context.whois.Begin(ip)
}
}
func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) {
bindHost := net.ParseIP(config.DNS.BindHost)
newconfig = dnsforward.ServerConfig{
UDPListenAddr: &net.UDPAddr{IP: bindHost, Port: config.DNS.Port},
TCPListenAddr: &net.TCPAddr{IP: bindHost, Port: config.DNS.Port},
UDPListenAddr: &net.UDPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port},
TCPListenAddr: &net.TCPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port},
FilteringConfig: config.DNS.FilteringConfig,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
@ -131,20 +129,20 @@ func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) {
if tlsConf.PortDNSOverTLS != 0 {
newconfig.TLSListenAddr = &net.TCPAddr{
IP: bindHost,
IP: config.DNS.BindHost,
Port: tlsConf.PortDNSOverTLS,
}
}
if tlsConf.PortDNSOverQUIC != 0 {
newconfig.QUICListenAddr = &net.UDPAddr{
IP: bindHost,
IP: config.DNS.BindHost,
Port: int(tlsConf.PortDNSOverQUIC),
}
}
if tlsConf.PortDNSCrypt != 0 {
newconfig.DNSCryptConfig, err = newDNSCrypt(bindHost, tlsConf)
newconfig.DNSCryptConfig, err = newDNSCrypt(config.DNS.BindHost, tlsConf)
if err != nil {
// Don't wrap the error, because it's already
// wrapped by newDNSCrypt.
@ -245,7 +243,7 @@ func getDNSEncryption() dnsEncryption {
func getDNSAddresses() []string {
dnsAddresses := []string{}
if config.DNS.BindHost == "0.0.0.0" {
if config.DNS.BindHost.IsUnspecified() {
ifaces, e := util.GetValidNetInterfacesForWeb()
if e != nil {
log.Error("Couldn't get network interfaces: %v", e)
@ -276,10 +274,10 @@ func getDNSAddresses() []string {
}
// If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
func applyAdditionalFiltering(clientAddr net.IP, setts *dnsfilter.RequestFilteringSettings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
if len(clientAddr) == 0 {
if clientAddr == nil {
return
}
setts.ClientIP = clientAddr
@ -328,13 +326,11 @@ func startDNSServer() error {
Context.queryLog.Start()
const topClientsNumber = 100 // the number of clients to get
topClients := Context.stats.GetTopClientsIP(topClientsNumber)
for _, ip := range topClients {
ipAddr := net.ParseIP(ip)
if !ipAddr.IsLoopback() {
for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
if !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if !Context.ipDetector.detectSpecialNetwork(ipAddr) {
if !Context.ipDetector.detectSpecialNetwork(ip) {
Context.whois.Begin(ip)
}
}

View File

@ -206,7 +206,7 @@ func setupConfig(args options) {
}
// override bind host/port from the console
if args.bindHost != "" {
if args.bindHost != nil {
config.BindHost = args.bindHost
}
if args.bindPort != 0 {
@ -575,36 +575,40 @@ func printHTTPAddresses(proto string) {
port = strconv.Itoa(tlsConf.PortHTTPS)
}
var hostStr string
if proto == "https" && tlsConf.ServerName != "" {
if tlsConf.PortHTTPS == 443 {
log.Printf("Go to https://%s", tlsConf.ServerName)
} else {
log.Printf("Go to https://%s:%s", tlsConf.ServerName, port)
}
} else if config.BindHost == "0.0.0.0" {
} else if config.BindHost.IsUnspecified() {
log.Println("AdGuard Home is available on the following addresses:")
ifaces, err := util.GetValidNetInterfacesForWeb()
if err != nil {
// That's weird, but we'll ignore it
log.Printf("Go to %s://%s", proto, net.JoinHostPort(config.BindHost, port))
hostStr = config.BindHost.String()
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, port))
if config.BetaBindPort != 0 {
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(config.BindHost, strconv.Itoa(config.BetaBindPort)))
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
}
return
}
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
log.Printf("Go to %s://%s", proto, net.JoinHostPort(addr, strconv.Itoa(config.BindPort)))
hostStr = addr.String()
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BindPort)))
if config.BetaBindPort != 0 {
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(addr, strconv.Itoa(config.BetaBindPort)))
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
}
}
}
} else {
log.Printf("Go to %s://%s", proto, net.JoinHostPort(config.BindHost, port))
hostStr = config.BindHost.String()
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, port))
if config.BetaBindPort != 0 {
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(config.BindHost, strconv.Itoa(config.BetaBindPort)))
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
}
}
}

View File

@ -1,6 +1,6 @@
// +build !race
// TODO(e.burkov): remove this weird buildtag.
// TODO(e.burkov): Remove this weird buildtag.
package home

View File

@ -2,6 +2,7 @@ package home
import (
"fmt"
"net"
"os"
"strconv"
@ -13,7 +14,7 @@ type options struct {
verbose bool // is verbose logging enabled
configFilename string // path to the config file
workDir string // path to the working directory where we will store the filters data and the querylog
bindHost string // host address to bind HTTP server on
bindHost net.IP // host address to bind HTTP server on
bindPort int // port to serve HTTP pages on
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
pidFile string // File name to save PID to
@ -54,10 +55,19 @@ type arg struct {
// against its zero value and return nil if the parameter value is
// zero otherwise they return a string slice of the parameter
func ipSliceOrNil(ip net.IP) []string {
if ip == nil {
return nil
}
return []string{ip.String()}
}
func stringSliceOrNil(s string) []string {
if s == "" {
return nil
}
return []string{s}
}
@ -65,6 +75,7 @@ func intSliceOrNil(i int) []string {
if i == 0 {
return nil
}
return []string{strconv.Itoa(i)}
}
@ -72,6 +83,7 @@ func boolSliceOrNil(b bool) []string {
if b {
return []string{}
}
return nil
}
@ -96,8 +108,8 @@ var workDirArg = arg{
var hostArg = arg{
"Host address to bind HTTP server on",
"host", "h",
func(o options, v string) (options, error) { o.bindHost = v; return o, nil }, nil, nil,
func(o options) []string { return stringSliceOrNil(o.bindHost) },
func(o options, v string) (options, error) { o.bindHost = net.ParseIP(v); return o, nil }, nil, nil,
func(o options) []string { return ipSliceOrNil(o.bindHost) },
}
var portArg = arg{

View File

@ -2,6 +2,7 @@ package home
import (
"fmt"
"net"
"testing"
)
@ -65,14 +66,14 @@ func TestParseWorkDir(t *testing.T) {
}
func TestParseBindHost(t *testing.T) {
if testParseOk(t).bindHost != "" {
if testParseOk(t).bindHost != nil {
t.Fatal("empty is no host")
}
if testParseOk(t, "-h", "addr").bindHost != "addr" {
if !testParseOk(t, "-h", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
t.Fatal("-h is host")
}
testParseParamMissing(t, "-h")
if testParseOk(t, "--host", "addr").bindHost != "addr" {
if !testParseOk(t, "--host", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
t.Fatal("--host is host")
}
testParseParamMissing(t, "--host")
@ -204,7 +205,7 @@ func TestSerializeWorkDir(t *testing.T) {
}
func TestSerializeBindHost(t *testing.T) {
testSerialize(t, options{bindHost: "addr"}, "-h", "addr")
testSerialize(t, options{bindHost: net.IP{1, 2, 3, 4}}, "-h", "1.2.3.4")
}
func TestSerializeBindPort(t *testing.T) {

View File

@ -2,6 +2,7 @@ package home
import (
"encoding/binary"
"net"
"strings"
"time"
@ -15,7 +16,7 @@ import (
type RDNS struct {
dnsServer *dnsforward.Server
clients *clientsContainer
ipChannel chan string // pass data from DNS request handling thread to rDNS thread
ipChannel chan net.IP // pass data from DNS request handling thread to rDNS thread
// Contains IP addresses of clients to be resolved by rDNS
// If IP address is resolved, it stays here while it's inside Clients.
@ -35,15 +36,15 @@ func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS {
cconf.MaxCount = 10000
r.ipAddrs = cache.New(cconf)
r.ipChannel = make(chan string, 256)
r.ipChannel = make(chan net.IP, 256)
go r.workerLoop()
return &r
}
// Begin - add IP address to rDNS queue
func (r *RDNS) Begin(ip string) {
func (r *RDNS) Begin(ip net.IP) {
now := uint64(time.Now().Unix())
expire := r.ipAddrs.Get([]byte(ip))
expire := r.ipAddrs.Get(ip)
if len(expire) != 0 {
exp := binary.BigEndian.Uint64(expire)
if exp > now {
@ -54,7 +55,7 @@ func (r *RDNS) Begin(ip string) {
expire = make([]byte, 8)
const ttl = 1 * 60 * 60
binary.BigEndian.PutUint64(expire, now+ttl)
_ = r.ipAddrs.Set([]byte(ip), expire)
_ = r.ipAddrs.Set(ip, expire)
if r.clients.Exists(ip, ClientSourceRDNS) {
return
@ -70,7 +71,7 @@ func (r *RDNS) Begin(ip string) {
}
// Use rDNS to get hostname by IP address
func (r *RDNS) resolve(ip string) string {
func (r *RDNS) resolve(ip net.IP) string {
log.Tracef("Resolving host for %s", ip)
req := dns.Msg{}
@ -83,7 +84,7 @@ func (r *RDNS) resolve(ip string) string {
},
}
var err error
req.Question[0].Name, err = dns.ReverseAddr(ip)
req.Question[0].Name, err = dns.ReverseAddr(ip.String())
if err != nil {
log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err)
return ""
@ -123,6 +124,6 @@ func (r *RDNS) workerLoop() {
continue
}
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
_, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS)
}
}

View File

@ -1,6 +1,7 @@
package home
import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
@ -16,6 +17,6 @@ func TestResolveRDNS(t *testing.T) {
clients := &clientsContainer{}
rdns := InitRDNS(dns, clients)
r := rdns.resolve("1.1.1.1")
r := rdns.resolve(net.IP{1, 1, 1, 1})
assert.Equal(t, "one.one.one.one", r, r)
}

View File

@ -31,7 +31,7 @@ const (
type webConfig struct {
firstRun bool
BindHost string
BindHost net.IP
BindPort int
BetaBindPort int
PortHTTPS int
@ -161,10 +161,11 @@ func (web *Web) Start() {
printHTTPAddresses("http")
errs := make(chan error, 2)
hostStr := web.conf.BindHost.String()
// we need to have new instance, because after Shutdown() the Server is not usable
web.httpServer = &http.Server{
ErrorLog: log.StdLog("web: http", log.DEBUG),
Addr: net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BindPort)),
Addr: net.JoinHostPort(hostStr, strconv.Itoa(web.conf.BindPort)),
Handler: withMiddlewares(Context.mux, limitRequestBody),
ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
@ -177,7 +178,7 @@ func (web *Web) Start() {
if web.conf.BetaBindPort != 0 {
web.httpServerBeta = &http.Server{
ErrorLog: log.StdLog("web: http", log.DEBUG),
Addr: net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BetaBindPort)),
Addr: net.JoinHostPort(hostStr, strconv.Itoa(web.conf.BetaBindPort)),
Handler: withMiddlewares(Context.mux, limitRequestBody, web.wrapIndexBeta),
ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
@ -236,7 +237,7 @@ func (web *Web) tlsServerLoop() {
web.httpsServer.cond.L.Unlock()
// prepare HTTPS server
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.PortHTTPS))
address := net.JoinHostPort(web.conf.BindHost.String(), strconv.Itoa(web.conf.PortHTTPS))
web.httpsServer.server = &http.Server{
ErrorLog: log.StdLog("web: https", log.DEBUG),
Addr: address,

View File

@ -26,7 +26,7 @@ const (
// Whois - module context
type Whois struct {
clients *clientsContainer
ipChan chan string
ipChan chan net.IP
timeoutMsec uint
// Contains IP addresses of clients
@ -46,7 +46,7 @@ func initWhois(clients *clientsContainer) *Whois {
cconf.MaxCount = 10000
w.ipAddrs = cache.New(cconf)
w.ipChan = make(chan string, 255)
w.ipChan = make(chan net.IP, 255)
go w.workerLoop()
return &w
}
@ -183,9 +183,9 @@ func (w *Whois) queryAll(target string) (string, error) {
}
// Request WHOIS information
func (w *Whois) process(ip string) [][]string {
func (w *Whois) process(ip net.IP) [][]string {
data := [][]string{}
resp, err := w.queryAll(ip)
resp, err := w.queryAll(ip.String())
if err != nil {
log.Debug("Whois: error: %s IP:%s", err, ip)
return data
@ -209,7 +209,7 @@ func (w *Whois) process(ip string) [][]string {
}
// Begin - begin requesting WHOIS info
func (w *Whois) Begin(ip string) {
func (w *Whois) Begin(ip net.IP) {
now := uint64(time.Now().Unix())
expire := w.ipAddrs.Get([]byte(ip))
if len(expire) != 0 {

View File

@ -22,9 +22,11 @@ var logEntryHandlers = map[string]logEntryHandler{
if !ok {
return nil
}
if len(ent.IP) == 0 {
ent.IP = v
if ent.IP == nil {
ent.IP = net.ParseIP(v)
}
return nil
},
"T": func(t json.Token, ent *logEntry) error {

View File

@ -47,7 +47,7 @@ func TestDecodeLogEntry(t *testing.T) {
assert.Nil(t, err)
want := &logEntry{
IP: "127.0.0.1",
IP: net.IPv4(127, 0, 0, 1),
Time: time.Date(2020, 11, 25, 15, 55, 56, 519796000, time.UTC),
QHost: "an.yandex.ru",
QType: "A",

View File

@ -14,22 +14,19 @@ import (
// TODO(a.garipov): Use a proper structured approach here.
// Get Client IP address
func (l *queryLog) getClientIP(clientIP string) string {
if l.conf.AnonymizeClientIP {
ip := net.ParseIP(clientIP)
if ip != nil {
ip4 := ip.To4()
const AnonymizeClientIP4Mask = 16
const AnonymizeClientIP6Mask = 112
if ip4 != nil {
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
} else {
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
}
}
func (l *queryLog) getClientIP(ip net.IP) (clientIP net.IP) {
if l.conf.AnonymizeClientIP && ip != nil {
const AnonymizeClientIPv4Mask = 16
const AnonymizeClientIPv6Mask = 112
if ip.To4() != nil {
return ip.Mask(net.CIDRMask(AnonymizeClientIPv4Mask, 32))
}
return clientIP
return ip.Mask(net.CIDRMask(AnonymizeClientIPv6Mask, 128))
}
return ip
}
// jobject is a JSON object alias.
@ -153,9 +150,9 @@ func answerToMap(a *dns.Msg) (answers []jobject) {
// try most common record types
switch v := k.(type) {
case *dns.A:
answer["value"] = v.A.String()
answer["value"] = v.A
case *dns.AAAA:
answer["value"] = v.AAAA.String()
answer["value"] = v.AAAA
case *dns.MX:
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
case *dns.CNAME:

View File

@ -3,6 +3,7 @@ package querylog
import (
"fmt"
"net"
"os"
"path/filepath"
"strings"
@ -60,7 +61,7 @@ func NewClientProto(s string) (cp ClientProto, err error) {
// logEntry - represents a single log entry
type logEntry struct {
IP string `json:"IP"` // Client IP
IP net.IP `json:"IP"` // Client IP
Time time.Time `json:"T"`
QHost string `json:"QH"`
@ -147,7 +148,7 @@ func (l *queryLog) Add(params AddParams) {
now := time.Now()
entry := logEntry{
IP: l.getClientIP(params.ClientIP.String()),
IP: l.getClientIP(params.ClientIP),
Time: now,
Result: *params.Result,

View File

@ -40,27 +40,27 @@ func TestQueryLog(t *testing.T) {
l := newQueryLog(conf)
// add disk entries
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// write to disk (first file)
_ = l.flushLogBuffer(true)
// start writing to the second file
_ = l.rotate()
// add disk entries
addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// write to disk
_ = l.flushLogBuffer(true)
// add memory entries
addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
// get all entries
params := newSearchParams()
entries, _ := l.search(params)
assert.Len(t, entries, 4)
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
assertLogEntry(t, entries[0], "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
assertLogEntry(t, entries[1], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
assertLogEntry(t, entries[2], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
assertLogEntry(t, entries[3], "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// search by domain (strict)
params = newSearchParams()
@ -71,7 +71,7 @@ func TestQueryLog(t *testing.T) {
})
entries, _ = l.search(params)
assert.Len(t, entries, 1)
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[0], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
// search by domain (not strict)
params = newSearchParams()
@ -82,9 +82,9 @@ func TestQueryLog(t *testing.T) {
})
entries, _ = l.search(params)
assert.Len(t, entries, 3)
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1")
assertLogEntry(t, entries[0], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
assertLogEntry(t, entries[1], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
assertLogEntry(t, entries[2], "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// search by client IP (strict)
params = newSearchParams()
@ -95,7 +95,7 @@ func TestQueryLog(t *testing.T) {
})
entries, _ = l.search(params)
assert.Len(t, entries, 1)
assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[0], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// search by client IP (part of)
params = newSearchParams()
@ -106,10 +106,10 @@ func TestQueryLog(t *testing.T) {
})
entries, _ = l.search(params)
assert.Len(t, entries, 4)
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
assertLogEntry(t, entries[0], "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
assertLogEntry(t, entries[1], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
assertLogEntry(t, entries[2], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
assertLogEntry(t, entries[3], "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
func TestQueryLogOffsetLimit(t *testing.T) {
@ -124,13 +124,13 @@ func TestQueryLogOffsetLimit(t *testing.T) {
// add 10 entries to the log
for i := 0; i < 10; i++ {
addEntry(l, "second.example.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "second.example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// write them to disk (first file)
_ = l.flushLogBuffer(true)
// add 10 more entries to the log (memory)
for i := 0; i < 10; i++ {
addEntry(l, "first.example.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "first.example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// First page
@ -178,7 +178,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
// add 10 entries to the log
for i := 0; i < 10; i++ {
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// write them to disk (first file)
_ = l.flushLogBuffer(true)
@ -204,9 +204,9 @@ func TestQueryLogFileDisabled(t *testing.T) {
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
addEntry(l, "example1.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "example2.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "example3.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
addEntry(l, "example2.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// the oldest entry is now removed from mem buffer
params := newSearchParams()
@ -216,7 +216,7 @@ func TestQueryLogFileDisabled(t *testing.T) {
assert.Equal(t, "example2.org", ll[1].QHost)
}
func addEntry(l *queryLog, host, answerStr, client string) {
func addEntry(l *queryLog, host string, answerStr, client net.IP) {
q := dns.Msg{}
q.Question = append(q.Question, dns.Question{
Name: host + ".",
@ -232,7 +232,7 @@ func addEntry(l *queryLog, host, answerStr, client string) {
Rrtype: dns.TypeA,
Class: dns.ClassINET,
}
answer.A = net.ParseIP(answerStr)
answer.A = answerStr
a.Answer = append(a.Answer, answer)
res := dnsfilter.Result{
IsFiltered: true,
@ -248,13 +248,13 @@ func addEntry(l *queryLog, host, answerStr, client string) {
Answer: &a,
OrigAnswer: &a,
Result: &res,
ClientIP: net.ParseIP(client),
ClientIP: client,
Upstream: "upstream",
}
l.Add(params)
}
func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) bool {
func assertLogEntry(t *testing.T, entry *logEntry, host string, answer, client net.IP) bool {
assert.Equal(t, host, entry.QHost)
assert.Equal(t, client, entry.IP)
assert.Equal(t, "A", entry.QType)
@ -263,9 +263,9 @@ func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string)
msg := new(dns.Msg)
assert.Nil(t, msg.Unpack(entry.Answer))
assert.Len(t, msg.Answer, 1)
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0])
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]).To16()
assert.NotNil(t, ip)
assert.Equal(t, answer, ip.String())
assert.Equal(t, answer, ip)
return true
}

View File

@ -94,16 +94,20 @@ func (c *searchCriteria) ctDomainOrClientCase(entry *logEntry) bool {
if c.strict && qhost == searchVal {
return true
}
if !c.strict && strings.Contains(qhost, searchVal) {
return true
}
if c.strict && entry.IP == c.value {
ipStr := entry.IP.String()
if c.strict && ipStr == c.value {
return true
}
if !c.strict && strings.Contains(entry.IP, c.value) {
if !c.strict && strings.Contains(ipStr, c.value) {
return true
}
return false
}

View File

@ -48,7 +48,7 @@ type Stats interface {
Update(e Entry)
// Get IP addresses of the clients with the most number of requests
GetTopClientsIP(limit uint) []string
GetTopClientsIP(limit uint) []net.IP
// WriteDiskConfig - write configuration
WriteDiskConfig(dc *DiskConfig)

View File

@ -80,7 +80,7 @@ func TestStats(t *testing.T) {
assert.EqualValues(t, 0.123456, d["avg_processing_time"].(float64))
topClients := s.GetTopClientsIP(2)
assert.Equal(t, "127.0.0.1", topClients[0])
assert.True(t, net.IP{127, 0, 0, 1}.Equal(topClients[0]))
s.clear()
s.Close()

View File

@ -443,22 +443,19 @@ func (s *statsCtx) clear() {
}
// Get Client IP address
func (s *statsCtx) getClientIP(clientIP string) string {
if s.conf.AnonymizeClientIP {
ip := net.ParseIP(clientIP)
if ip != nil {
ip4 := ip.To4()
func (s *statsCtx) getClientIP(ip net.IP) (clientIP net.IP) {
if s.conf.AnonymizeClientIP && ip != nil {
const AnonymizeClientIP4Mask = 16
const AnonymizeClientIP6Mask = 112
if ip4 != nil {
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
} else {
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
}
}
if ip.To4() != nil {
return ip.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32))
}
return clientIP
return ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128))
}
return ip
}
func (s *statsCtx) Update(e Entry) {
@ -468,7 +465,7 @@ func (s *statsCtx) Update(e Entry) {
!(len(e.Client) == 4 || len(e.Client) == 16) {
return
}
client := s.getClientIP(e.Client.String())
client := s.getClientIP(e.Client)
s.unitLock.Lock()
u := s.unit
@ -481,7 +478,7 @@ func (s *statsCtx) Update(e Entry) {
u.blockedDomains[e.Domain]++
}
u.clients[client]++
u.clients[client.String()]++
u.timeSum += uint64(e.Time)
u.nTotal++
s.unitLock.Unlock()
@ -658,7 +655,7 @@ func (s *statsCtx) getData() map[string]interface{} {
return d
}
func (s *statsCtx) GetTopClientsIP(maxCount uint) []string {
func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
units, _ := s.loadUnits(s.conf.limit)
if units == nil {
return nil
@ -672,9 +669,9 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []string {
}
}
a := convertMapToArray(m, int(maxCount))
d := []string{}
d := []net.IP{}
for _, it := range a {
d = append(d, it.Name)
d = append(d, net.ParseIP(it.Name))
}
return d
}

View File

@ -119,17 +119,13 @@ func ifacesStaticConfig(r io.Reader, ifaceName string) (has bool, err error) {
}
func ifaceSetStaticIP(ifaceName string) (err error) {
ip := util.GetSubnet(ifaceName)
if len(ip) == 0 {
ipNet := util.GetSubnet(ifaceName)
if ipNet.IP == nil {
return errors.New("can't get IP address")
}
ip4, _, err := net.ParseCIDR(ip)
if err != nil {
return err
}
gatewayIP := GatewayIP(ifaceName)
add := updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, ip4)
add := updateStaticIPdhcpcdConf(ifaceName, ipNet.String(), gatewayIP, ipNet.IP)
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
if err != nil {

View File

@ -108,11 +108,11 @@ func TestAutoHostsFSNotify(t *testing.T) {
ips = ah.Process("newhost", dns.TypeA)
assert.NotNil(t, ips)
assert.Len(t, ips, 1)
assert.Equal(t, "127.0.0.2", ips[0].String())
assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0]))
}
func TestIP(t *testing.T) {
assert.Equal(t, "127.0.0.1", DNSUnreverseAddr("1.0.0.127.in-addr.arpa").String())
assert.True(t, net.IP{127, 0, 0, 1}.Equal(DNSUnreverseAddr("1.0.0.127.in-addr.arpa")))
assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String())
assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String())

View File

@ -18,8 +18,8 @@ type NetInterface struct {
Name string // Network interface name
MTU int // MTU
HardwareAddr string // Hardware address
Addresses []string // Array with the network interface addresses
Subnets []string // Array with CIDR addresses of this network interface
Addresses []net.IP // Array with the network interface addresses
Subnets []*net.IPNet // Array with CIDR addresses of this network interface
Flags string // Network interface flags (up, broadcast, etc)
}
@ -78,8 +78,8 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
if ipNet.IP.IsLinkLocalUnicast() {
continue
}
netIface.Addresses = append(netIface.Addresses, ipNet.IP.String())
netIface.Subnets = append(netIface.Subnets, ipNet.String())
netIface.Addresses = append(netIface.Addresses, ipNet.IP)
netIface.Subnets = append(netIface.Subnets, ipNet)
}
// Discard interfaces with no addresses
@ -91,8 +91,8 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
return netInterfaces, nil
}
// GetInterfaceByIP - Get interface name by its IP address.
func GetInterfaceByIP(ip string) string {
// GetInterfaceByIP returns the name of interface containing provided ip.
func GetInterfaceByIP(ip net.IP) string {
ifaces, err := GetValidNetInterfacesForWeb()
if err != nil {
return ""
@ -100,7 +100,7 @@ func GetInterfaceByIP(ip string) string {
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
if ip == addr {
if ip.Equal(addr) {
return iface.Name
}
}
@ -109,13 +109,13 @@ func GetInterfaceByIP(ip string) string {
return ""
}
// GetSubnet - Get IP address with netmask for the specified interface
// Returns an empty string if it fails to find it
func GetSubnet(ifaceName string) string {
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if
// the search fails.
func GetSubnet(ifaceName string) *net.IPNet {
netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil {
log.Error("Could not get network interfaces info: %v", err)
return ""
return nil
}
for _, netIface := range netIfaces {
@ -124,12 +124,12 @@ func GetSubnet(ifaceName string) string {
}
}
return ""
return nil
}
// CheckPortAvailable - check if TCP port is available
func CheckPortAvailable(host string, port int) error {
ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port)))
func CheckPortAvailable(host net.IP, port int) error {
ln, err := net.Listen("tcp", net.JoinHostPort(host.String(), strconv.Itoa(port)))
if err != nil {
return err
}
@ -142,8 +142,8 @@ func CheckPortAvailable(host string, port int) error {
}
// CheckPacketPortAvailable - check if UDP port is available
func CheckPacketPortAvailable(host string, port int) error {
ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port)))
func CheckPacketPortAvailable(host net.IP, port int) error {
ln, err := net.ListenPacket("udp", net.JoinHostPort(host.String(), strconv.Itoa(port)))
if err != nil {
return err
}