Pull request: all: client id support

Merge in DNS/adguard-home from 1383-client-id to master

Updates #1383.

Squashed commit of the following:

commit ebe2678bfa9bf651a2cb1e64499b38edcf19a7ad
Author: Ildar Kamalov <ik@adguard.com>
Date:   Wed Jan 27 17:51:59 2021 +0300

    - client: check if IP is valid

commit 0c330585a170ea149ee75e43dfa65211e057299c
Author: Ildar Kamalov <ik@adguard.com>
Date:   Wed Jan 27 17:07:50 2021 +0300

    - client: find clients by client_id

commit 71c9593ee35d996846f061e114b7867c3aa3c978
Merge: 9104f161 3e9edd9e
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Jan 27 16:09:45 2021 +0300

    Merge branch 'master' into 1383-client-id

commit 9104f1615d2d462606c52017df25a422df872cea
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Jan 27 13:28:50 2021 +0300

    dnsforward: imp tests

commit ed47f26e611ade625a2cc2c2f71a291b796bbf8f
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Jan 27 12:39:52 2021 +0300

    dnsforward: fix address

commit 98b222ba69a5d265f620c180c960d01c84a1fb3b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 19:50:31 2021 +0300

    home: imp code

commit 4f3966548a2d8437d0b68207dd108dd1a6cb7d20
Merge: 199fdc05 c215b820
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 19:45:13 2021 +0300

    Merge branch 'master' into 1383-client-id

commit 199fdc056f8a8be5500584f3aaee32865188aedc
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 19:20:37 2021 +0300

    all: imp tests, logging, etc

commit 35ff14f4d534251aecb2ea60baba225f3eed8a3e
Author: Ildar Kamalov <ik@adguard.com>
Date:   Tue Jan 26 18:55:19 2021 +0300

    + client: remove block button from clients with client_id

commit 32991a0b4c56583a02fb5e00bba95d96000bce20
Author: Ildar Kamalov <ik@adguard.com>
Date:   Tue Jan 26 18:54:25 2021 +0300

    + client: add requests count for client_id

commit 2d68df4d2eac4a296d7469923e601dad4575c1a1
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 15:49:50 2021 +0300

    stats: handle client ids

commit 4e14ab3590328f93a8cd6e9cbe1665baf74f220b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 13:45:25 2021 +0300

    openapi: fix example

commit ca9cf3f744fe197cace2c28ddc5bc68f71dad1f3
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 13:37:10 2021 +0300

    openapi: improve clients find api docs

commit f79876e550c424558b704bc316a4cd04f25db011
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 13:18:52 2021 +0300

    home: accept ids in clients find

commit 5b72595122aa0bd64debadfd753ed8a0e0840629
Merge: 607e241f abf8f65f
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Jan 25 18:34:56 2021 +0300

    Merge branch 'master' into 1383-client-id

commit 607e241f1c339dd6397218f70b8301e3de6a1ee0
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Jan 25 18:30:39 2021 +0300

    dnsforward: fix quic

commit f046352fef93e46234c2bbe8ae316d21034260e5
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Jan 25 16:53:09 2021 +0300

    all: remove wildcard requirement

commit 3b679489bae82c54177372be453fe184d8f0bab6
Author: Andrey Meshkov <am@adguard.com>
Date:   Mon Jan 25 16:02:28 2021 +0300

    workDir now supports symlinks

commit 0647ab4f113de2223f6949df001f42ecab05c995
Author: Ildar Kamalov <ik@adguard.com>
Date:   Mon Jan 25 14:59:46 2021 +0300

    - client: remove wildcard from domain validation

commit b1aec04a4ecadc9d65648ed6d284188fecce01c3
Author: Ildar Kamalov <ik@adguard.com>
Date:   Mon Jan 25 14:55:39 2021 +0300

    + client: add form to download mobileconfig

... and 12 more commits
This commit is contained in:
Ainar Garipov 2021-01-27 18:32:13 +03:00
parent 3e9edd9eac
commit fc9ddcf941
56 changed files with 1735 additions and 623 deletions

View File

@ -10,11 +10,13 @@ and this project adheres to
## [Unreleased]
<!--
## [v0.105.0] - 2021-01-18
## [v0.105.0] - 2021-01-27
-->
### Added
- Client ID support for DNS-over-HTTPS, DNS-over-QUIC, and DNS-over-TLS
([#1383]).
- `$dnsrewrite` modifier for filters ([#2102]).
- The host checking API and the query logs API can now return multiple matched
rules ([#2102]).
@ -27,6 +29,7 @@ and this project adheres to
- HTTP API request body size limit ([#2305]).
[#1361]: https://github.com/AdguardTeam/AdGuardHome/issues/1361
[#1383]: https://github.com/AdguardTeam/AdGuardHome/issues/1383
[#2102]: https://github.com/AdguardTeam/AdGuardHome/issues/2102
[#2302]: https://github.com/AdguardTeam/AdGuardHome/issues/2302
[#2304]: https://github.com/AdguardTeam/AdGuardHome/issues/2304
@ -35,6 +38,7 @@ and this project adheres to
### Changed
- `workDir` now supports symlinks.
- Stopped mounting together the directories `/opt/adguardhome/conf` and
`/opt/adguardhome/work` in our Docker images ([#2589]).
- When `dns.bogus_nxdomain` option is used, the server will now transform

View File

@ -62,9 +62,12 @@ The rules are mostly sorted in the alphabetical order.
* Don't use underscores in file and package names, unless they're build tags
or for tests. This is to prevent accidental build errors with weird tags.
* Don't write code with more than four (**4**) levels of indentation. Just
like [Linus said], plus an additional level for an occasional error check or
struct initialization.
* Don't write non-test code with more than four (**4**) levels of indentation.
Just like [Linus said], plus an additional level for an occasional error
check or struct initialization.
The exception proving the rule is the table-driven test code, where an
additional level of indentation is allowed.
* Eschew external dependencies, including transitive, unless
absolutely necessary.

View File

@ -80,7 +80,10 @@ go-lint: ; $(ENV) "$(SHELL)" ./scripts/make/go-lint.sh
go-test: ; $(ENV) "$(SHELL)" ./scripts/make/go-test.sh
go-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-tools.sh
go-check: go-tools go-lint go-test
openapi-lint: ; cd ./openapi/ && $(YARN) test
openapi-show: ; cd ./openapi/ && $(YARN) start
# TODO(a.garipov): Remove the legacy targets once the build
# infrastructure stops using them.

View File

@ -87,12 +87,21 @@ If you're running **Linux**, there's a secure and easy way to install AdGuard Ho
### Guides
* [Getting Started](https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started)
* [FAQ](https://github.com/AdguardTeam/AdGuardHome/wiki/FAQ)
* [How to Write Hosts Blocklists](https://github.com/AdguardTeam/AdGuardHome/wiki/Hosts-Blocklists)
* [Comparing AdGuard Home to Other Solutions](https://github.com/AdguardTeam/AdGuardHome/wiki/Comparison)
* Configuring AdGuard
* [Configuration](https://github.com/AdguardTeam/AdGuardHome/wiki/Configuration)
* [AdGuard Home as a DNS-over-HTTPS or DNS-over-TLS server](https://github.com/AdguardTeam/AdGuardHome/wiki/Encryption)
* [How to install and run AdGuard Home on Raspberry Pi](https://github.com/AdguardTeam/AdGuardHome/wiki/Raspberry-Pi)
* [How to install and run AdGuard Home on a Virtual Private Server](https://github.com/AdguardTeam/AdGuardHome/wiki/VPS)
* [How to write your own hosts blocklists properly](https://github.com/AdguardTeam/AdGuardHome/wiki/Hosts-Blocklists)
* [Configuring AdGuard Home Clients](https://github.com/AdguardTeam/AdGuardHome/wiki/Clients)
* [AdGuard Home as a DoH, DoT, or DoQ Server](https://github.com/AdguardTeam/AdGuardHome/wiki/Encryption)
* [AdGuard Home as a DNSCrypt Server](https://github.com/AdguardTeam/AdGuardHome/wiki/DNSCrypt)
* [AdGuard Home as a DHCP Server](https://github.com/AdguardTeam/AdGuardHome/wiki/DHCP)
* Installing AdGuard Home
* [Docker](https://github.com/AdguardTeam/AdGuardHome/wiki/Docker)
* [How to Install and Run AdGuard Home on a Raspberry Pi](https://github.com/AdguardTeam/AdGuardHome/wiki/Raspberry-Pi)
* [How to Install and Run AdGuard Home on a Virtual Private Server](https://github.com/AdguardTeam/AdGuardHome/wiki/VPS)
* [Verifying Releases](https://github.com/AdguardTeam/AdGuardHome/wiki/Verify-Releases)
### API

12
client/package-lock.json generated vendored
View File

@ -3066,12 +3066,6 @@
"pkg-up": "^2.0.0"
}
},
"caniuse-lite": {
"version": "1.0.30001062",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001062.tgz",
"integrity": "sha512-ei9ZqeOnN7edDrb24QfJ0OZicpEbsWxv7WusOiQGz/f2SfvBgHHbOEwBJ8HKGVSyx8Z6ndPjxzR6m0NQq+0bfw==",
"dev": true
},
"postcss": {
"version": "7.0.30",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-7.0.30.tgz",
@ -3928,9 +3922,9 @@
}
},
"caniuse-lite": {
"version": "1.0.30001059",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001059.tgz",
"integrity": "sha512-oOrc+jPJWooKIA0IrNZ5sYlsXc7NP7KLhNWrSGEJhnfSzDvDJ0zd3i6HXsslExY9bbu+x0FQ5C61LcqmPt7bOQ==",
"version": "1.0.30001165",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001165.tgz",
"integrity": "sha512-8cEsSMwXfx7lWSUMA2s08z9dIgsnR5NAqjXP23stdsU3AUWkCr/rr4s4OFtHXn5XXr6+7kam3QFVoYyXNPdJPA==",
"dev": true
},
"capture-exit": {

View File

@ -32,6 +32,7 @@
"form_error_ip_format": "Invalid IP format",
"form_error_mac_format": "Invalid MAC format",
"form_error_client_id_format": "Invalid client ID format",
"form_error_server_name": "Invalid server name",
"form_error_positive": "Must be greater than 0",
"form_error_negative": "Must be equal to 0 or greater",
"range_end_error": "Must be greater than range start",
@ -250,8 +251,12 @@
"dns_over_https": "DNS-over-HTTPS",
"dns_over_tls": "DNS-over-TLS",
"dns_over_quic": "DNS-over-QUIC",
"client_id": "Client ID",
"client_id_placeholder": "Enter client ID",
"client_id_desc": "Different clients can be identified by a special client ID. <a>Here</a> you can learn more about how to identify clients.",
"download_mobileconfig_doh": "Download .mobileconfig for DNS-over-HTTPS",
"download_mobileconfig_dot": "Download .mobileconfig for DNS-over-TLS",
"download_mobileconfig": "Download configuration file",
"plain_dns": "Plain DNS",
"form_enter_rate_limit": "Enter rate limit",
"rate_limit": "Rate limit",
@ -331,7 +336,7 @@
"encryption_config_saved": "Encryption config saved",
"encryption_server": "Server name",
"encryption_server_enter": "Enter your domain name",
"encryption_server_desc": "In order to use HTTPS, you need to enter the server name that matches your SSL certificate.",
"encryption_server_desc": "In order to use HTTPS, you need to enter the server name that matches your SSL certificate or wildcard certificate. If the field is not set, it will accept TLS connections for any domain.",
"encryption_redirect": "Redirect to HTTPS automatically",
"encryption_redirect_desc": "If checked, AdGuard Home will automatically redirect you from HTTP to HTTPS addresses.",
"encryption_https": "HTTPS port",
@ -387,7 +392,7 @@
"client_edit": "Edit Client",
"client_identifier": "Identifier",
"ip_address": "IP address",
"client_identifier_desc": "Clients can be identified by the IP address, CIDR, MAC address. Please note that using MAC as identifier is possible only if AdGuard Home is also a <0>DHCP server</0>",
"client_identifier_desc": "Clients can be identified by the IP address, CIDR, MAC address or a special client ID (can be used for DoT/DoH/DoQ). <0>Here</0> you can learn more about how to identify clients.",
"form_enter_ip": "Enter IP",
"form_enter_mac": "Enter MAC",
"form_enter_id": "Enter identifier",
@ -431,6 +436,7 @@
"setup_dns_privacy_other_3": "<0>dnscrypt-proxy</0> supports <1>DNS-over-HTTPS</1>.",
"setup_dns_privacy_other_4": "<0>Mozilla Firefox</0> supports <1>DNS-over-HTTPS</1>.",
"setup_dns_privacy_other_5": "You will find more implementations <0>here</0> and <1>here</1>.",
"setup_dns_privacy_ioc_mac": "iOS and macOS configuration",
"setup_dns_notice": "In order to use <1>DNS-over-HTTPS</1> or <1>DNS-over-TLS</1>, you need to <0>configure Encryption</0> in AdGuard Home settings.",
"rewrite_added": "DNS rewrite for \"{{key}}\" successfully added",
"rewrite_deleted": "DNS rewrite for \"{{key}}\" successfully deleted",

View File

@ -8,7 +8,7 @@ import {
import { addErrorToast, addSuccessToast } from './toasts';
const enrichWithClientInfo = async (logs) => {
const clientsParams = getParamsForClientsSearch(logs, 'client');
const clientsParams = getParamsForClientsSearch(logs, 'client', 'client_id');
if (Object.keys(clientsParams).length > 0) {
const clients = await apiClient.findClients(clientsParams);

View File

@ -81,3 +81,7 @@ body {
.ReactModal__Body--open {
overflow: hidden;
}
a.btn-success.disabled {
color: #fff;
}

View File

@ -9,7 +9,7 @@ import Card from '../ui/Card';
import Cell from '../ui/Cell';
import { getPercent, sortIp } from '../../helpers/helpers';
import { BLOCK_ACTIONS, STATUS_COLORS } from '../../helpers/constants';
import { BLOCK_ACTIONS, R_CLIENT_ID, STATUS_COLORS } from '../../helpers/constants';
import { toggleClientBlock } from '../../actions/access';
import { renderFormattedClientCell } from '../../helpers/renderFormattedClientCell';
import { getStats } from '../../actions/stats';
@ -35,6 +35,10 @@ const CountCell = (row) => {
};
const renderBlockingButton = (ip, disallowed, disallowed_rule) => {
if (R_CLIENT_ID.test(ip)) {
return null;
}
const dispatch = useDispatch();
const { t } = useTranslation();
const processingSet = useSelector((state) => state.access.processingSet);
@ -59,7 +63,8 @@ const renderBlockingButton = (ip, disallowed, disallowed_rule) => {
const text = disallowed ? BLOCK_ACTIONS.UNBLOCK : BLOCK_ACTIONS.BLOCK;
const isNotInAllowedList = disallowed && disallowed_rule === '';
return <div className="table__action pl-4">
return (
<div className="table__action pl-4">
<button
type="button"
className={buttonClass}
@ -69,7 +74,8 @@ const renderBlockingButton = (ip, disallowed, disallowed_rule) => {
>
<Trans>{text}</Trans>
</button>
</div>;
</div>
);
};
const ClientCell = (row) => {
@ -90,7 +96,8 @@ const Clients = ({
const { t } = useTranslation();
const topClients = useSelector((state) => state.stats.topClients, shallowEqual);
return <Card
return (
<Card
title={t('top_clients')}
subtitle={subtitle}
bodyType="card-table"
@ -107,7 +114,7 @@ const Clients = ({
}))}
columns={[
{
Header: 'IP',
Header: <Trans>client_table_header</Trans>,
accessor: 'ip',
sortMethod: sortIp,
Cell: ClientCell,
@ -135,7 +142,8 @@ const Clients = ({
return disallowed ? { className: 'logs__row--red' } : {};
}}
/>
</Card>;
</Card>
);
};
Clients.propTypes = {

View File

@ -16,6 +16,7 @@ import { updateLogs } from '../../../actions/queryLogs';
const ClientCell = ({
client,
client_id,
domain,
info,
info: {
@ -33,12 +34,14 @@ const ClientCell = ({
const autoClient = autoClients.find((autoClient) => autoClient.name === client);
const source = autoClient?.source;
const whoisAvailable = whois_info && Object.keys(whois_info).length > 0;
const clientName = name || client_id;
const clientInfo = { ...info, name: clientName };
const id = nanoid();
const data = {
address: client,
name,
name: clientName,
country: whois_info?.country,
city: whois_info?.city,
network: whois_info?.orgname,
@ -99,13 +102,20 @@ const ClientCell = ({
if (options.length === 0) {
return null;
}
return <>{options.map(({ name, onClick, disabled }) => <button
return (
<>
{options.map(({ name, onClick, disabled }) => (
<button
key={name}
className="button-action--arrow-option px-4 py-2"
onClick={onClick}
disabled={disabled}
>{t(name)}
</button>)}</>;
>
{t(name)}
</button>
))}
</>
);
};
const content = getOptions(BUTTON_OPTIONS);
@ -125,45 +135,70 @@ const ClientCell = ({
'button-action__container--detailed': isDetailed,
});
return <div className={containerClass}>
<button type="button"
return (
<div className={containerClass}>
<button
type="button"
className={buttonClass}
onClick={onClick}
disabled={processingRules}
>
{t(buttonType)}
</button>
{content && <button className={buttonArrowClass} disabled={processingRules}>
{content && (
<button className={buttonArrowClass} disabled={processingRules}>
<IconTooltip
className='h-100'
tooltipClass='button-action--arrow-option-container'
xlinkHref='chevron-down'
triggerClass='button-action--icon'
content={content} placement="bottom-end" trigger="click"
className="h-100"
tooltipClass="button-action--arrow-option-container"
xlinkHref="chevron-down"
triggerClass="button-action--icon"
content={content}
placement="bottom-end"
trigger="click"
onVisibilityChange={setOptionsOpened}
/>
</button>}
</div>;
</button>
)}
</div>
);
};
return <div className="o-hidden h-100 logs__cell logs__cell--client" role="gridcell">
<IconTooltip className={hintClass} columnClass='grid grid--limited' tooltipClass='px-5 pb-5 pt-4 mw-75'
xlinkHref='question' contentItemClass="contentItemClass" title="client_details"
content={processedData} placement="bottom" />
return (
<div
className="o-hidden h-100 logs__cell logs__cell--client"
role="gridcell"
>
<IconTooltip
className={hintClass}
columnClass="grid grid--limited"
tooltipClass="px-5 pb-5 pt-4"
xlinkHref="question"
contentItemClass="text-truncate key-colon o-hidden"
title="client_details"
content={processedData}
placement="bottom"
/>
<div className={nameClass}>
<div data-tip={true} data-for={id}>
{renderFormattedClientCell(client, info, isDetailed, true)}
{renderFormattedClientCell(client, clientInfo, isDetailed, true)}
</div>
{isDetailed && name && !whoisAvailable
&& <div className="detailed-info d-none d-sm-block logs__text"
title={name}>{name}</div>}
{isDetailed && clientName && !whoisAvailable && (
<div
className="detailed-info d-none d-sm-block logs__text"
title={clientName}
>
{clientName}
</div>
)}
</div>
{renderBlockingButton(isFiltered, domain)}
</div>;
</div>
);
};
ClientCell.propTypes = {
client: propTypes.string.isRequired,
client_id: propTypes.string,
domain: propTypes.string.isRequired,
info: propTypes.oneOfType([
propTypes.string,

View File

@ -70,6 +70,7 @@ const Row = memo(({
upstream,
type,
client_proto,
client_id,
rules,
originalResponse,
status,
@ -176,7 +177,7 @@ const Row = memo(({
response_code: status,
client_details: 'title',
ip_address: client,
name: info?.name,
name: info?.name || client_id,
country,
city,
network,
@ -233,6 +234,7 @@ Row.propTypes = {
upstream: propTypes.string.isRequired,
type: propTypes.string.isRequired,
client_proto: propTypes.string.isRequired,
client_id: propTypes.string,
rules: propTypes.arrayOf(propTypes.shape({
text: propTypes.string.isRequired,
filter_list_id: propTypes.number.isRequired,

View File

@ -282,7 +282,7 @@ let Form = (props) => {
<div className="form__desc mt-0">
<Trans
components={[
<a href="#dhcp" key="0">
<a href="https://github.com/AdguardTeam/AdGuardHome/wiki/Clients#idclient" key="0" target="_blank" rel="noopener noreferrer">
link
</a>,
]}

View File

@ -50,7 +50,7 @@ const CertificateStatus = ({
{dnsNames && (
<li>
<Trans>encryption_hostnames</Trans>:&nbsp;
{dnsNames}
{dnsNames.join(', ')}
</li>
)}
</Fragment>
@ -65,7 +65,7 @@ CertificateStatus.propTypes = {
subject: PropTypes.string,
issuer: PropTypes.string,
notAfter: PropTypes.string,
dnsNames: PropTypes.string,
dnsNames: PropTypes.arrayOf(PropTypes.string),
};
export default withTranslation()(CertificateStatus);

View File

@ -12,7 +12,7 @@ import {
toNumber,
} from '../../../helpers/form';
import {
validateIsSafePort, validatePort, validatePortQuic, validatePortTLS,
validateServerName, validateIsSafePort, validatePort, validatePortQuic, validatePortTLS,
} from '../../../helpers/validators';
import i18n from '../../../i18n';
import KeyStatus from './KeyStatus';
@ -127,6 +127,7 @@ let Form = (props) => {
placeholder={t('encryption_server_enter')}
onChange={handleChange}
disabled={!isEnabled}
validate={validateServerName}
/>
<div className="form__desc">
<Trans>encryption_server_desc</Trans>
@ -413,7 +414,7 @@ Form.propTypes = {
valid_key: PropTypes.bool,
valid_cert: PropTypes.bool,
valid_pair: PropTypes.bool,
dns_names: PropTypes.string,
dns_names: PropTypes.arrayOf(PropTypes.string),
key_type: PropTypes.string,
issuer: PropTypes.string,
subject: PropTypes.string,

View File

@ -3,27 +3,12 @@ import PropTypes from 'prop-types';
import { Trans, useTranslation } from 'react-i18next';
import i18next from 'i18next';
import { useSelector } from 'react-redux';
import Tabs from './Tabs';
import Icons from './Icons';
import { getPathWithQueryString } from '../../helpers/helpers';
const MOBILE_CONFIG_LINKS = {
DOT: '/apple/dot.mobileconfig',
DOH: '/apple/doh.mobileconfig',
};
const renderMobileconfigInfo = ({ label, components, server_name }) => <li key={label}>
<Trans components={components}>{label}</Trans>
<ul>
<li>
<a href={getPathWithQueryString(MOBILE_CONFIG_LINKS.DOT, { host: server_name })}
download>{i18next.t('download_mobileconfig_dot')}</a>
</li>
<li>
<a href={getPathWithQueryString(MOBILE_CONFIG_LINKS.DOH, { host: server_name })}
download>{i18next.t('download_mobileconfig_doh')}</a>
</li>
</ul>
</li>;
import { MOBILE_CONFIG_LINKS } from '../../../helpers/constants';
import Tabs from '../Tabs';
import Icons from '../Icons';
import MobileConfigForm from './MobileConfigForm';
const renderLi = ({ label, components }) => <li key={label}>
<Trans components={components?.map((props) => {
@ -41,49 +26,8 @@ const renderLi = ({ label, components }) => <li key={label}>
</Trans>
</li>;
const getDnsPrivacyList = (server_name) => {
const iosList = [
const getDnsPrivacyList = () => [
{
label: 'setup_dns_privacy_ios_2',
components: [
{
key: 0,
href: 'https://adguard.com/adguard-ios/overview.html',
},
<code key="1">text</code>,
],
},
{
label: 'setup_dns_privacy_ios_1',
components: [
{
key: 0,
href: 'https://itunes.apple.com/app/id1452162351',
},
<code key="1">text</code>,
{
key: 2,
href: 'https://dnscrypt.info/stamps',
},
],
}];
/* Insert second element if can generate .mobileconfig links */
if (server_name) {
iosList.splice(1, 0, {
label: 'setup_dns_privacy_4',
components: {
highlight: <code />,
},
renderComponent: ({ label, components }) => renderMobileconfigInfo({
label,
components,
server_name,
}),
});
}
return [{
title: 'Android',
list: [
{
@ -113,7 +57,32 @@ const getDnsPrivacyList = (server_name) => {
},
{
title: 'iOS',
list: iosList,
list: [
{
label: 'setup_dns_privacy_ios_2',
components: [
{
key: 0,
href: 'https://adguard.com/adguard-ios/overview.html',
},
<code key="1">text</code>,
],
},
{
label: 'setup_dns_privacy_ios_1',
components: [
{
key: 0,
href: 'https://itunes.apple.com/app/id1452162351',
},
<code key="1">text</code>,
{
key: 2,
href: 'https://dnscrypt.info/stamps',
},
],
},
],
},
{
title: 'setup_dns_privacy_other_title',
@ -167,19 +136,19 @@ const getDnsPrivacyList = (server_name) => {
],
},
];
};
const renderDnsPrivacyList = ({ title, list }) => <div className="tab__paragraph" key={title}>
<strong><Trans>{title}</Trans></strong>
<ul>{list.map(
({
label,
components,
renderComponent = renderLi,
}) => renderComponent({ label, components }),
)}
const renderDnsPrivacyList = ({ title, list }) => (
<div className="tab__paragraph" key={title}>
<strong>
<Trans>{title}</Trans>
</strong>
<ul>
{list.map(({ label, components, renderComponent = renderLi }) => (
renderComponent({ label, components })
))}
</ul>
</div>;
</div>
);
const getTabs = ({
tlsAddress,
@ -267,8 +236,8 @@ const getTabs = ({
</Trans>
</div>
)}
{showDnsPrivacyNotice
? <div className="tab__paragraph">
{showDnsPrivacyNotice ? (
<div className="tab__paragraph">
<Trans
components={[
<a
@ -285,35 +254,64 @@ const getTabs = ({
setup_dns_notice
</Trans>
</div>
: <>
) : (
<>
<div className="tab__paragraph">
<Trans components={[<p key="0">text</p>]}>
setup_dns_privacy_3
</Trans>
</div>
{getDnsPrivacyList(server_name).map(renderDnsPrivacyList)}
</>}
{getDnsPrivacyList().map(renderDnsPrivacyList)}
<div>
<strong>
<Trans>
setup_dns_privacy_ioc_mac
</Trans>
</strong>
</div>
<div className="mb-3">
<Trans components={{ highlight: <code /> }}>
setup_dns_privacy_4
</Trans>
</div>
<MobileConfigForm
initialValues={{
host: server_name,
clientId: '',
protocol: MOBILE_CONFIG_LINKS.DOH,
}}
/>
</>
)}
</div>
</div>;
},
},
});
const renderContent = ({ title, list, getTitle }) => <div key={title} label={i18next.t(title)}>
<div className="tab__title">{i18next.t(title)}</div>
const renderContent = ({ title, list, getTitle }) => (
<div key={title} label={i18next.t(title)}>
<div className="tab__title">
{i18next.t(title)}
</div>
<div className="tab__text">
{getTitle?.()}
{list
&& <ol>{list.map((item) => <li key={item}>
{list && (
<ol>
{list.map((item) => (
<li key={item}>
<Trans>{item}</Trans>
</li>)}
</ol>}
</li>
))}
</ol>
)}
</div>
</div>;
</div>
);
const Guide = ({ dnsAddresses }) => {
const { t } = useTranslation();
const server_name = useSelector((state) => state.encryption.server_name);
const server_name = useSelector((state) => state.encryption?.server_name);
const tlsAddress = dnsAddresses?.filter((item) => item.includes('tls://')) ?? '';
const httpsAddress = dnsAddresses?.filter((item) => item.includes('https://')) ?? '';
const showDnsPrivacyNotice = httpsAddress.length < 1 && tlsAddress.length < 1;
@ -332,9 +330,14 @@ const Guide = ({ dnsAddresses }) => {
return (
<div>
<Tabs
tabs={tabs}
activeTabLabel={activeTabLabel}
setActiveTabLabel={setActiveTabLabel}
>
{activeTab}
</Tabs>
<Icons />
<Tabs tabs={tabs} activeTabLabel={activeTabLabel}
setActiveTabLabel={setActiveTabLabel}>{activeTab}</Tabs>
</div>
);
};
@ -364,6 +367,4 @@ renderLi.propTypes = {
components: PropTypes.string,
};
renderMobileconfigInfo.propTypes = renderLi.propTypes;
export default Guide;

View File

@ -0,0 +1,131 @@
import React from 'react';
import PropTypes from 'prop-types';
import { Trans } from 'react-i18next';
import { useSelector } from 'react-redux';
import { Field, reduxForm } from 'redux-form';
import i18next from 'i18next';
import cn from 'classnames';
import { getPathWithQueryString } from '../../../helpers/helpers';
import { FORM_NAME, MOBILE_CONFIG_LINKS } from '../../../helpers/constants';
import {
renderInputField,
renderSelectField,
} from '../../../helpers/form';
import {
validateClientId,
validateServerName,
} from '../../../helpers/validators';
const getDownloadLink = (host, clientId, protocol, invalid) => {
if (!host || invalid) {
return (
<button
type="button"
className="btn btn-success btn-standard btn-large disabled"
>
<Trans>download_mobileconfig</Trans>
</button>
);
}
const linkParams = { host };
if (clientId) {
linkParams.client_id = clientId;
}
return (
<a
href={getPathWithQueryString(protocol, linkParams)}
className={cn('btn btn-success btn-standard btn-large')}
download
>
<Trans>download_mobileconfig</Trans>
</a>
);
};
const MobileConfigForm = ({ invalid }) => {
const formValues = useSelector((state) => state.form[FORM_NAME.MOBILE_CONFIG]?.values);
if (!formValues) {
return null;
}
const { host, clientId, protocol } = formValues;
const githubLink = (
<a
href="https://github.com/AdguardTeam/AdGuardHome/wiki/Clients#idclient"
target="_blank"
rel="noopener noreferrer"
>
text
</a>
);
return (
<form onSubmit={(e) => e.preventDefault()}>
<div>
<div className="form__group form__group--settings">
<label htmlFor="host" className="form__label">
{i18next.t('dhcp_table_hostname')}
</label>
<Field
name="host"
type="text"
component={renderInputField}
className="form-control"
placeholder={i18next.t('form_enter_hostname')}
validate={validateServerName}
/>
</div>
<div className="form__group form__group--settings">
<label htmlFor="clientId" className="form__label form__label--with-desc">
{i18next.t('client_id')}
</label>
<div className="form__desc form__desc--top">
<Trans components={{ a: githubLink }}>
client_id_desc
</Trans>
</div>
<Field
name="clientId"
type="text"
component={renderInputField}
className="form-control"
placeholder={i18next.t('client_id_placeholder')}
validate={validateClientId}
/>
</div>
<div className="form__group form__group--settings">
<label htmlFor="protocol" className="form__label">
{i18next.t('protocol')}
</label>
<Field
name="protocol"
type="text"
component={renderSelectField}
className="form-control"
>
<option value={MOBILE_CONFIG_LINKS.DOT}>
{i18next.t('dns_over_tls')}
</option>
<option value={MOBILE_CONFIG_LINKS.DOH}>
{i18next.t('dns_over_https')}
</option>
</Field>
</div>
</div>
{getDownloadLink(host, clientId, protocol, invalid)}
</form>
);
};
MobileConfigForm.propTypes = {
invalid: PropTypes.bool.isRequired,
};
export default reduxForm({ form: FORM_NAME.MOBILE_CONFIG })(MobileConfigForm);

View File

@ -0,0 +1 @@
export { default } from './Guide';

View File

@ -13,6 +13,8 @@ export const R_MAC = /^((([a-fA-F0-9][a-fA-F0-9]+[-]){5}|([a-fA-F0-9][a-fA-F0-9]
export const R_CIDR_IPV6 = /^s*((([0-9A-Fa-f]{1,4}:){7}([0-9A-Fa-f]{1,4}|:))|(([0-9A-Fa-f]{1,4}:){6}(:[0-9A-Fa-f]{1,4}|((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3})|:))|(([0-9A-Fa-f]{1,4}:){5}(((:[0-9A-Fa-f]{1,4}){1,2})|:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3})|:))|(([0-9A-Fa-f]{1,4}:){4}(((:[0-9A-Fa-f]{1,4}){1,3})|((:[0-9A-Fa-f]{1,4})?:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){3}(((:[0-9A-Fa-f]{1,4}){1,4})|((:[0-9A-Fa-f]{1,4}){0,2}:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){2}(((:[0-9A-Fa-f]{1,4}){1,5})|((:[0-9A-Fa-f]{1,4}){0,3}:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){1}(((:[0-9A-Fa-f]{1,4}){1,6})|((:[0-9A-Fa-f]{1,4}){0,4}:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3}))|:))|(:(((:[0-9A-Fa-f]{1,4}){1,7})|((:[0-9A-Fa-f]{1,4}){0,5}:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3}))|:)))(%.+)?s*(\/(12[0-8]|1[0-1][0-9]|[1-9][0-9]|[0-9]))$/;
export const R_DOMAIN = /^[a-zA-Z0-9][a-zA-Z0-9-]{1,61}[a-zA-Z0-9]\.[a-zA-Z]{2,}$/;
export const R_PATH_LAST_PART = /\/[^/]*$/;
// eslint-disable-next-line no-control-regex
@ -21,6 +23,8 @@ export const R_UNIX_ABSOLUTE_PATH = /^(\/[^/\x00]+)+$/;
// eslint-disable-next-line no-control-regex
export const R_WIN_ABSOLUTE_PATH = /^([a-zA-Z]:)?(\\|\/)(?:[^\\/:*?"<>|\x00]+\\)*[^\\/:*?"<>|\x00]*$/;
export const R_CLIENT_ID = /^[a-z0-9-]{1,64}$/;
export const HTML_PAGES = {
INSTALL: '/install.html',
LOGIN: '/login.html',
@ -514,6 +518,7 @@ export const FORM_NAME = {
INSTALL: 'install',
LOGIN: 'login',
CACHE: 'cache',
MOBILE_CONFIG: 'mobileConfig',
...DHCP_FORM_NAMES,
};
@ -574,6 +579,7 @@ export const TOAST_TIMEOUTS = {
export const ADDRESS_TYPES = {
IP: 'IP',
CIDR: 'CIDR',
CLIENT_ID: 'CLIENT_ID',
UNKNOWN: 'UNKNOWN',
};
@ -585,3 +591,8 @@ export const CACHE_CONFIG_FIELDS = {
export const isFirefox = navigator.userAgent.indexOf('Firefox') !== -1;
export const COMMENT_LINE_DEFAULT_TOKEN = '#';
export const MOBILE_CONFIG_LINKS = {
DOT: '/apple/dot.mobileconfig',
DOH: '/apple/doh.mobileconfig',
};

View File

@ -4,7 +4,6 @@ import dateFormat from 'date-fns/format';
import round from 'lodash/round';
import axios from 'axios';
import i18n from 'i18next';
import uniqBy from 'lodash/uniqBy';
import ipaddr from 'ipaddr.js';
import queryString from 'query-string';
import React from 'react';
@ -22,6 +21,7 @@ import {
DHCP_VALUES_PLACEHOLDERS,
FILTERED,
FILTERED_STATUS,
R_CLIENT_ID,
SERVICES_ID_NAME_MAP,
STANDARD_DNS_PORT,
STANDARD_HTTPS_PORT,
@ -62,6 +62,7 @@ export const normalizeLogs = (logs) => logs.map((log) => {
answer_dnssec,
client,
client_proto,
client_id,
elapsedMs,
question,
reason,
@ -99,6 +100,7 @@ export const normalizeLogs = (logs) => logs.map((log) => {
reason,
client,
client_proto,
client_id,
/* TODO 'filterId' and 'rule' are deprecated, will be removed in 0.106 */
filterId,
rule,
@ -414,14 +416,21 @@ export const getPathWithQueryString = (path, params) => {
return `${path}?${searchParams.toString()}`;
};
export const getParamsForClientsSearch = (data, param) => {
const uniqueClients = uniqBy(data, param);
return uniqueClients
.reduce((acc, item, idx) => {
const key = `ip${idx}`;
acc[key] = item[param];
return acc;
}, {});
export const getParamsForClientsSearch = (data, param, additionalParam) => {
const clients = new Set();
data.forEach((e) => {
clients.add(e[param]);
if (e[additionalParam]) {
clients.add(e[additionalParam]);
}
});
const params = {};
const ids = Array.from(clients.values());
ids.forEach((id, i) => {
params[`ip${i}`] = id;
});
return params;
};
/**
@ -534,7 +543,7 @@ export const isIpInCidr = (ip, cidr) => {
/**
*
* @param ipOrCidr
* @returns {'IP' | 'CIDR' | 'UNKNOWN'}
* @returns {'IP' | 'CIDR' | 'CLIENT_ID' | 'UNKNOWN'}
*
*/
export const findAddressType = (address) => {
@ -547,6 +556,9 @@ export const findAddressType = (address) => {
if (cidrMaybe && ipaddr.parseCIDR(address)) {
return ADDRESS_TYPES.CIDR;
}
if (R_CLIENT_ID.test(address)) {
return ADDRESS_TYPES.CLIENT_ID;
}
return ADDRESS_TYPES.UNKNOWN;
} catch (e) {
@ -567,20 +579,31 @@ export const separateIpsAndCidrs = (ids) => ids.reduce((acc, curr) => {
if (addressType === ADDRESS_TYPES.CIDR) {
acc.cidrs.push(curr);
}
if (addressType === ADDRESS_TYPES.CLIENT_ID) {
acc.clientIds.push(curr);
}
return acc;
}, { ips: [], cidrs: [] });
}, { ips: [], cidrs: [], clientIds: [] });
export const countClientsStatistics = (ids, autoClients) => {
const { ips, cidrs } = separateIpsAndCidrs(ids);
const { ips, cidrs, clientIds } = separateIpsAndCidrs(ids);
const ipsCount = ips.reduce((acc, curr) => {
const count = autoClients[curr] || 0;
return acc + count;
}, 0);
const clientIdsCount = clientIds.reduce((acc, curr) => {
const count = autoClients[curr] || 0;
return acc + count;
}, 0);
const cidrsCount = Object.entries(autoClients)
.reduce((acc, curr) => {
const [id, count] = curr;
if (!ipaddr.isValid(id)) {
return false;
}
if (cidrs.some((cidr) => isIpInCidr(id, cidr))) {
// eslint-disable-next-line no-param-reassign
acc += count;
@ -588,7 +611,7 @@ export const countClientsStatistics = (ids, autoClients) => {
return acc;
}, 0);
return ipsCount + cidrsCount;
return ipsCount + cidrsCount + clientIdsCount;
};
/**

View File

@ -9,6 +9,8 @@ import {
R_URL_REQUIRES_PROTOCOL,
STANDARD_WEB_PORT,
UNSAFE_PORTS,
R_CLIENT_ID,
R_DOMAIN,
} from './constants';
import { getLastIpv4Octet, isValidAbsolutePath } from './form';
@ -71,12 +73,28 @@ export const validateClientId = (value) => {
|| R_MAC.test(formattedValue)
|| R_CIDR.test(formattedValue)
|| R_CIDR_IPV6.test(formattedValue)
|| R_CLIENT_ID.test(formattedValue)
)) {
return 'form_error_client_id_format';
}
return undefined;
};
/**
* @param value {string}
* @returns {undefined|string}
*/
export const validateServerName = (value) => {
if (!value) {
return undefined;
}
const formattedValue = value ? value.trim() : value;
if (formattedValue && !R_DOMAIN.test(formattedValue)) {
return 'form_error_server_name';
}
return undefined;
};
/**
* @param value {string}
* @returns {undefined|string}

3
go.mod
View File

@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
go 1.14
require (
github.com/AdguardTeam/dnsproxy v0.33.7
github.com/AdguardTeam/dnsproxy v0.33.9
github.com/AdguardTeam/golibs v0.4.4
github.com/AdguardTeam/urlfilter v0.14.2
github.com/NYTimes/gziphandler v1.1.1
@ -17,6 +17,7 @@ require (
github.com/insomniacslk/dhcp v0.0.0-20201112113307-4de412bc85d8
github.com/kardianos/service v1.2.0
github.com/karrick/godirwalk v1.16.1 // indirect
github.com/lucas-clemente/quic-go v0.19.3
github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7
github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065
github.com/miekg/dns v1.1.35

4
go.sum
View File

@ -18,8 +18,8 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBr
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4=
dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU=
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
github.com/AdguardTeam/dnsproxy v0.33.7 h1:DXsLTJoBSUejB2ZqVHyMG0/kXD8PzuVPbLCsGKBdaDc=
github.com/AdguardTeam/dnsproxy v0.33.7/go.mod h1:dkI9VWh43XlOzF2XogDm1EmoVl7PANOR4isQV6X9LZs=
github.com/AdguardTeam/dnsproxy v0.33.9 h1:HUwywkhUV/M73E7qWcBAF+SdsNq742s82Lvox4pr/tM=
github.com/AdguardTeam/dnsproxy v0.33.9/go.mod h1:dkI9VWh43XlOzF2XogDm1EmoVl7PANOR4isQV6X9LZs=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=

View File

@ -24,11 +24,12 @@ type FilteringConfig struct {
// Callbacks for other modules
// --
// Filtering callback function
FilterHandler func(clientAddr net.IP, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// FilterHandler is an optional additional filtering callback.
FilterHandler func(clientAddr net.IP, clientID string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
// based on the client IP address. Returns nil if there are no custom upstreams for the client
//
// TODO(e.burkov): Replace argument type with net.IP.
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
@ -109,6 +110,10 @@ type TLSConfig struct {
CertificateChainData []byte `yaml:"-" json:"-"`
PrivateKeyData []byte `yaml:"-" json:"-"`
// ServerName is the hostname of the server. Currently, it is only
// being used for client ID checking.
ServerName string `yaml:"-" json:"-"`
cert tls.Certificate
// DNS names from certificate (SAN) or CN value from Subject
dnsNames []string

View File

@ -1,7 +1,10 @@
package dnsforward
import (
"crypto/tls"
"fmt"
"net"
"path"
"strings"
"time"
@ -10,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/lucas-clemente/quic-go"
"github.com/miekg/dns"
)
@ -17,30 +21,56 @@ import (
type dnsContext struct {
srv *Server
proxyCtx *proxy.DNSContext
setts *dnsfilter.RequestFilteringSettings // filtering settings for this client
// setts are the filtering settings for the client.
setts *dnsfilter.RequestFilteringSettings
startTime time.Time
result *dnsfilter.Result
origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering
origQuestion dns.Question // question received from client. Set when Rewrites are used.
err error // error returned from the module
protectionEnabled bool // filtering is enabled, dnsfilter object is ready
responseFromUpstream bool // response is received from upstream servers
origReqDNSSEC bool // DNSSEC flag in the original request from user
// origResp is the response received from upstream. It is set when the
// response is modified by filters.
origResp *dns.Msg
// err is the error returned from a processing function.
err error
// clientID is the clientID from DOH, DOQ, or DOT, if provided.
clientID string
// origQuestion is the question received from the client. It is set
// when the request is modified by rewrites.
origQuestion dns.Question
// protectionEnabled shows if the filtering is enabled, and if the
// server's DNS filter is ready.
protectionEnabled bool
// responseFromUpstream shows if the response is received from the
// upstream servers.
responseFromUpstream bool
// origReqDNSSEC shows if the DNSSEC flag in the original request from
// the client is set.
origReqDNSSEC bool
}
// resultCode is the result of a request processing function.
type resultCode int
const (
resultDone = iota // module has completed its job, continue
resultFinish // module has completed its job, exit normally
resultError // an error occurred, exit with an error
// resultCodeSuccess is returned when a handler performed successfully,
// and the next handler must be called.
resultCodeSuccess resultCode = iota
// resultCodeFinish is returned when a handler performed successfully,
// and the processing of the request must be stopped.
resultCodeFinish
// resultCodeError is returned when a handler failed, and the processing
// of the request must be stopped.
resultCodeError
)
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{srv: s, proxyCtx: d}
ctx.result = &dnsfilter.Result{}
ctx.startTime = time.Now()
ctx := &dnsContext{
srv: s,
proxyCtx: d,
result: &dnsfilter.Result{},
startTime: time.Now(),
}
type modProcessFunc func(ctx *dnsContext) int
type modProcessFunc func(ctx *dnsContext) (rc resultCode)
// Since (*dnsforward.Server).handleDNSRequest(...) is used as
// proxy.(Config).RequestHandler, there is no need for additional index
@ -51,6 +81,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
processInitial,
processInternalHosts,
processInternalIPAddrs,
processClientID,
processFilteringBeforeRequest,
processUpstream,
processDNSSECAfterResponse,
@ -61,13 +92,13 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
for _, process := range mods {
r := process(ctx)
switch r {
case resultDone:
case resultCodeSuccess:
// continue: call the next filter
case resultFinish:
case resultCodeFinish:
return nil
case resultError:
case resultCodeError:
return ctx.err
}
}
@ -79,12 +110,12 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
}
// Perform initial checks; process WHOIS & rDNS
func processInitial(ctx *dnsContext) int {
func processInitial(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
_ = proxy.CheckDisabledAAAARequest(d, true)
return resultFinish
return resultCodeFinish
}
if s.conf.OnDNSRequest != nil {
@ -96,10 +127,10 @@ func processInitial(ctx *dnsContext) int {
if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) &&
d.Req.Question[0].Name == "use-application-dns.net." {
d.Res = s.genNXDomain(d.Req)
return resultFinish
return resultCodeFinish
}
return resultDone
return resultCodeSuccess
}
// Return TRUE if host names doesn't contain disallowed characters
@ -157,29 +188,29 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
}
// Respond to A requests if the target host name is associated with a lease from our DHCP server
func processInternalHosts(ctx *dnsContext) int {
func processInternalHosts(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
req := ctx.proxyCtx.Req
if !(req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA) {
return resultDone
return resultCodeSuccess
}
host := req.Question[0].Name
host = strings.ToLower(host)
if !strings.HasSuffix(host, ".lan.") {
return resultDone
return resultCodeSuccess
}
host = strings.TrimSuffix(host, ".lan.")
s.tableHostToIPLock.Lock()
if s.tableHostToIP == nil {
s.tableHostToIPLock.Unlock()
return resultDone
return resultCodeSuccess
}
ip, ok := s.tableHostToIP[host]
s.tableHostToIPLock.Unlock()
if !ok {
return resultDone
return resultCodeSuccess
}
log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip)
@ -200,15 +231,163 @@ func processInternalHosts(ctx *dnsContext) int {
}
ctx.proxyCtx.Res = resp
return resultDone
return resultCodeSuccess
}
const maxDomainPartLen = 64
// ValidateClientID returns an error if clientID is not a valid client ID.
func ValidateClientID(clientID string) (err error) {
if len(clientID) > maxDomainPartLen {
return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen)
}
for i, r := range clientID {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
continue
}
return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID)
}
return nil
}
// clientIDFromClientServerName extracts and validates a client ID. hostSrvName
// is the server name of the host. cliSrvName is the server name as sent by the
// client.
func clientIDFromClientServerName(hostSrvName, cliSrvName string) (clientID string, err error) {
if hostSrvName == cliSrvName {
return "", nil
}
if !strings.HasSuffix(cliSrvName, hostSrvName) {
return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
}
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
err = ValidateClientID(clientID)
if err != nil {
return "", fmt.Errorf("invalid client id: %w", err)
}
return clientID, nil
}
// processClientIDHTTPS extracts the client's ID from the path of the
// client's DNS-over-HTTPS request.
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
pctx := ctx.proxyCtx
r := pctx.HTTPRequest
if r == nil {
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
return resultCodeError
}
origPath := r.URL.Path
parts := strings.Split(path.Clean(origPath), "/")
if parts[0] == "" {
parts = parts[1:]
}
if len(parts) == 0 || parts[0] != "dns-query" {
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
return resultCodeError
}
clientID := ""
switch len(parts) {
case 1:
// Just /dns-query, no client ID.
return resultCodeSuccess
case 2:
clientID = parts[1]
default:
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
return resultCodeError
}
err := ValidateClientID(clientID)
if err != nil {
ctx.err = fmt.Errorf("client id check: invalid client id: %w", err)
return resultCodeError
}
ctx.clientID = clientID
return resultCodeSuccess
}
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
type tlsConn interface {
ConnectionState() (cs tls.ConnectionState)
}
// quicSession is a narrow interface for quic.Session to simplify testing.
type quicSession interface {
ConnectionState() (cs quic.ConnectionState)
}
// processClientID extracts the client's ID from the server name of the client's
// DOT or DOQ request or the path of the client's DOH.
func processClientID(ctx *dnsContext) (rc resultCode) {
pctx := ctx.proxyCtx
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
return processClientIDHTTPS(ctx)
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return resultCodeSuccess
}
hostSrvName := ctx.srv.conf.TLSConfig.ServerName
if hostSrvName == "" {
return resultCodeSuccess
}
cliSrvName := ""
if proto == proxy.ProtoTLS {
conn := pctx.Conn
tc, ok := conn.(tlsConn)
if !ok {
ctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
return resultCodeError
}
cliSrvName = tc.ConnectionState().ServerName
} else if proto == proxy.ProtoQUIC {
qs, ok := pctx.QUICSession.(quicSession)
if !ok {
ctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
return resultCodeError
}
cliSrvName = qs.ConnectionState().ServerName
}
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName)
if err != nil {
ctx.err = fmt.Errorf("client id check: %w", err)
return resultCodeError
}
ctx.clientID = clientID
return resultCodeSuccess
}
// Respond to PTR requests if the target IP address is leased by our DHCP server
func processInternalIPAddrs(ctx *dnsContext) int {
func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
req := ctx.proxyCtx.Req
if req.Question[0].Qtype != dns.TypePTR {
return resultDone
return resultCodeSuccess
}
arpa := req.Question[0].Name
@ -216,18 +395,18 @@ func processInternalIPAddrs(ctx *dnsContext) int {
arpa = strings.ToLower(arpa)
ip := util.DNSUnreverseAddr(arpa)
if ip == nil {
return resultDone
return resultCodeSuccess
}
s.tablePTRLock.Lock()
if s.tablePTR == nil {
s.tablePTRLock.Unlock()
return resultDone
return resultCodeSuccess
}
host, ok := s.tablePTR[ip.String()]
s.tablePTRLock.Unlock()
if !ok {
return resultDone
return resultCodeSuccess
}
log.Debug("DNS: reverse-lookup: %s -> %s", arpa, host)
@ -243,16 +422,16 @@ func processInternalIPAddrs(ctx *dnsContext) int {
ptr.Ptr = host + "."
resp.Answer = append(resp.Answer, ptr)
ctx.proxyCtx.Res = resp
return resultDone
return resultCodeSuccess
}
// Apply filtering logic
func processFilteringBeforeRequest(ctx *dnsContext) int {
func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if d.Res != nil {
return resultDone // response is already set - nothing to do
return resultCodeSuccess // response is already set - nothing to do
}
s.RLock()
@ -266,24 +445,24 @@ func processFilteringBeforeRequest(ctx *dnsContext) int {
var err error
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
if ctx.protectionEnabled {
ctx.setts = s.getClientRequestFilteringSettings(d)
ctx.setts = s.getClientRequestFilteringSettings(ctx)
ctx.result, err = s.filterDNSRequest(ctx)
}
s.RUnlock()
if err != nil {
ctx.err = err
return resultError
return resultCodeError
}
return resultDone
return resultCodeSuccess
}
// processUpstream passes request to upstream servers and handles the response.
func processUpstream(ctx *dnsContext) int {
func processUpstream(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if d.Res != nil {
return resultDone // response is already set - nothing to do
return resultCodeSuccess // response is already set - nothing to do
}
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
@ -311,26 +490,26 @@ func processUpstream(ctx *dnsContext) int {
err := s.dnsProxy.Resolve(d)
if err != nil {
ctx.err = err
return resultError
return resultCodeError
}
ctx.responseFromUpstream = true
return resultDone
return resultCodeSuccess
}
// Process DNSSEC after response from upstream server
func processDNSSECAfterResponse(ctx *dnsContext) int {
func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers
!ctx.srv.conf.EnableDNSSEC {
return resultDone
return resultCodeSuccess
}
if !ctx.origReqDNSSEC {
optResp := d.Res.IsEdns0()
if optResp != nil && !optResp.Do() {
return resultDone
return resultCodeSuccess
}
// Remove RRSIG records from response
@ -361,11 +540,11 @@ func processDNSSECAfterResponse(ctx *dnsContext) int {
d.Res.Ns = answers
}
return resultDone
return resultCodeSuccess
}
// Apply filtering logic after we have received response from upstream servers
func processFilteringAfterResponse(ctx *dnsContext) int {
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
res := ctx.result
@ -402,7 +581,7 @@ func processFilteringAfterResponse(ctx *dnsContext) int {
ctx.result, err = s.filterDNSResponse(ctx)
if err != nil {
ctx.err = err
return resultError
return resultCodeError
}
if ctx.result != nil {
ctx.origResp = origResp2 // matched by response
@ -411,5 +590,5 @@ func processFilteringAfterResponse(ctx *dnsContext) int {
}
}
return resultDone
return resultCodeSuccess
}

View File

@ -0,0 +1,235 @@
package dnsforward
import (
"crypto/tls"
"net"
"net/http"
"net/url"
"testing"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/lucas-clemente/quic-go"
"github.com/stretchr/testify/assert"
)
// testTLSConn is a tlsConn for tests.
type testTLSConn struct {
// Conn is embedded here simply to make testTLSConn a net.Conn without
// acctually implementing all methods.
net.Conn
serverName string
}
// ConnectionState implements the tlsConn interface for testTLSConn.
func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
cs.ServerName = c.serverName
return cs
}
// testQUICSession is a quicSession for tests.
type testQUICSession struct {
// Session is embedded here simply to make testQUICSession
// a quic.Session without acctually implementing all methods.
quic.Session
serverName string
}
// ConnectionState implements the quicSession interface for testQUICSession.
func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
cs.ServerName = c.serverName
return cs
}
func TestProcessClientID(t *testing.T) {
testCases := []struct {
name string
proto string
hostSrvName string
cliSrvName string
wantClientID string
wantErrMsg string
wantRes resultCode
}{{
name: "udp",
proto: proxy.ProtoUDP,
hostSrvName: "",
cliSrvName: "",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "tls_no_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "example.com",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "tls_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "tls_client_id_hostname_error",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.example.net",
wantClientID: "",
wantErrMsg: `client id check: client server name "cli.example.net" doesn't match host server name "example.com"`,
wantRes: resultCodeError,
}, {
name: "tls_invalid_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "!!!.example.com",
wantClientID: "",
wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`,
wantRes: resultCodeError,
}, {
name: "tls_client_id_too_long",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789.example.com",
wantClientID: "",
wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" is too long, max: 64`,
wantRes: resultCodeError,
}, {
name: "quic_client_id",
proto: proxy.ProtoQUIC,
hostSrvName: "example.com",
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
srv := &Server{
conf: ServerConfig{
TLSConfig: TLSConfig{ServerName: tc.hostSrvName},
},
}
var conn net.Conn
if tc.proto == proxy.ProtoTLS {
conn = testTLSConn{
serverName: tc.cliSrvName,
}
}
var qs quic.Session
if tc.proto == proxy.ProtoQUIC {
qs = testQUICSession{
serverName: tc.cliSrvName,
}
}
dctx := &dnsContext{
srv: srv,
proxyCtx: &proxy.DNSContext{
Proto: tc.proto,
Conn: conn,
QUICSession: qs,
},
}
res := processClientID(dctx)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantClientID, dctx.clientID)
if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) {
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
} else {
assert.Nil(t, dctx.err)
}
})
}
}
func TestProcessClientID_https(t *testing.T) {
testCases := []struct {
name string
path string
wantClientID string
wantErrMsg string
wantRes resultCode
}{{
name: "no_client_id",
path: "/dns-query",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "no_client_id_slash",
path: "/dns-query/",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "client_id",
path: "/dns-query/cli",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "client_id_slash",
path: "/dns-query/cli/",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "bad_url",
path: "/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/foo"`,
wantRes: resultCodeError,
}, {
name: "extra",
path: "/dns-query/cli/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
wantRes: resultCodeError,
}, {
name: "invalid_client_id",
path: "/dns-query/!!!",
wantClientID: "",
wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`,
wantRes: resultCodeError,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := &http.Request{
URL: &url.URL{
Path: tc.path,
},
}
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Proto: proxy.ProtoHTTPS,
HTTPRequest: r,
},
}
res := processClientID(dctx)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantClientID, dctx.clientID)
if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) {
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
} else {
assert.Nil(t, dctx.err)
}
})
}
}

View File

@ -473,7 +473,7 @@ func TestBlockCNAME(t *testing.T) {
func TestClientRulesForCNAMEMatching(t *testing.T) {
s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
s.conf.FilterHandler = func(_ net.IP, settings *dnsfilter.RequestFilteringSettings) {
s.conf.FilterHandler = func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) {
settings.FilteringEnabled = false
}
err := s.startWithUpstream(testUpstm)
@ -1033,8 +1033,7 @@ func TestMatchDNSName(t *testing.T) {
assert.False(t, matchDNSName(dnsNames, "*.host2"))
}
type testDHCP struct {
}
type testDHCP struct{}
func (d *testDHCP) Leases(flags int) []dhcpd.Lease {
l := dhcpd.Lease{}

View File

@ -30,14 +30,15 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
return true, nil
}
// getClientRequestFilteringSettings lookups client filtering settings
// using the client's IP address from the DNSContext
func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilter.RequestFilteringSettings {
// getClientRequestFilteringSettings looks up client filtering settings using
// the client's IP address and ID, if any, from ctx.
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *dnsfilter.RequestFilteringSettings {
setts := s.dnsFilter.GetConfig()
setts.FilteringEnabled = true
if s.conf.FilterHandler != nil {
s.conf.FilterHandler(IPFromAddr(d.Addr), &setts)
s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
}
return &setts
}

View File

@ -529,5 +529,5 @@ func (s *Server) registerHandlers() {
s.conf.HTTPRegister(http.MethodGet, "/control/access/list", s.handleAccessList)
s.conf.HTTPRegister(http.MethodPost, "/control/access/set", s.handleAccessSet)
s.conf.HTTPRegister("", "/dns-query", s.handleDOH)
s.conf.HTTPRegister("", "/dns-query/", s.handleDOH)
}

View File

@ -99,12 +99,12 @@ func (c *ipsetCtx) getIP(rr dns.RR) net.IP {
}
// Add IP addresses of the specified in configuration domain names to an ipset list
func (c *ipsetCtx) process(ctx *dnsContext) int {
func (c *ipsetCtx) process(ctx *dnsContext) (rc resultCode) {
req := ctx.proxyCtx.Req
if !(req.Question[0].Qtype == dns.TypeA ||
req.Question[0].Qtype == dns.TypeAAAA) ||
!ctx.responseFromUpstream {
return resultDone
return resultCodeSuccess
}
host := req.Question[0].Name
@ -112,7 +112,7 @@ func (c *ipsetCtx) process(ctx *dnsContext) int {
host = strings.ToLower(host)
ipsetNames, found := c.ipsetList[host]
if !found {
return resultDone
return resultCodeSuccess
}
log.Debug("IPSET: found ipsets %v for host %s", ipsetNames, host)
@ -138,5 +138,5 @@ func (c *ipsetCtx) process(ctx *dnsContext) int {
}
}
return resultDone
return resultCodeSuccess
}

View File

@ -37,5 +37,5 @@ func TestIPSET(t *testing.T) {
},
},
}
assert.Equal(t, resultDone, c.process(ctx))
assert.Equal(t, resultCodeSuccess, c.process(ctx))
}

View File

@ -1,7 +1,6 @@
package dnsforward
import (
"net"
"strings"
"time"
@ -13,13 +12,13 @@ import (
)
// Write Stats data and logs
func processQueryLogsAndStats(ctx *dnsContext) int {
func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
elapsed := time.Since(ctx.startTime)
s := ctx.srv
d := ctx.proxyCtx
pctx := ctx.proxyCtx
shouldLog := true
msg := d.Req
msg := pctx.Req
// don't log ANY request if refuseAny is enabled
if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny {
@ -32,65 +31,67 @@ func processQueryLogsAndStats(ctx *dnsContext) int {
if shouldLog && s.queryLog != nil {
p := querylog.AddParams{
Question: msg,
Answer: d.Res,
Answer: pctx.Res,
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: IPFromAddr(d.Addr),
ClientIP: IPFromAddr(pctx.Addr),
ClientID: ctx.clientID,
}
switch d.Proto {
switch pctx.Proto {
case proxy.ProtoHTTPS:
p.ClientProto = querylog.ClientProtoDOH
case proxy.ProtoQUIC:
p.ClientProto = querylog.ClientProtoDOQ
case proxy.ProtoTLS:
p.ClientProto = querylog.ClientProtoDOT
case proxy.ProtoDNSCrypt:
p.ClientProto = querylog.ClientProtoDNSCrypt
default:
// Consider this a plain DNS-over-UDP or DNS-over-TCL
// Consider this a plain DNS-over-UDP or DNS-over-TCP
// request.
}
if d.Upstream != nil {
p.Upstream = d.Upstream.Address()
if pctx.Upstream != nil {
p.Upstream = pctx.Upstream.Address()
}
s.queryLog.Add(p)
}
s.updateStats(d, elapsed, *ctx.result)
s.updateStats(ctx, elapsed, *ctx.result)
s.RUnlock()
return resultDone
return resultCodeSuccess
}
func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dnsfilter.Result) {
func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res dnsfilter.Result) {
if s.stats == nil {
return
}
pctx := ctx.proxyCtx
e := stats.Entry{}
e.Domain = strings.ToLower(d.Req.Question[0].Name)
e.Domain = strings.ToLower(pctx.Req.Question[0].Name)
e.Domain = e.Domain[:len(e.Domain)-1] // remove last "."
switch addr := d.Addr.(type) {
case *net.UDPAddr:
e.Client = addr.IP
case *net.TCPAddr:
e.Client = addr.IP
if clientID := ctx.clientID; clientID != "" {
e.Client = clientID
} else if ip := IPFromAddr(pctx.Addr); ip != nil {
e.Client = ip.String()
}
e.Time = uint32(elapsed / 1000)
e.Result = stats.RNotFiltered
switch res.Reason {
case dnsfilter.FilteredSafeBrowsing:
e.Result = stats.RSafeBrowsing
case dnsfilter.FilteredParental:
e.Result = stats.RParental
case dnsfilter.FilteredSafeSearch:
e.Result = stats.RSafeSearch
case dnsfilter.FilteredBlockList:
fallthrough
case dnsfilter.FilteredInvalid:

View File

@ -0,0 +1,198 @@
package dnsforward
import (
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
// testQueryLog is a simple querylog.QueryLog implementation for tests.
type testQueryLog struct {
// QueryLog is embedded here simply to make testQueryLog
// a querylog.QueryLog without acctually implementing all methods.
querylog.QueryLog
lastParams querylog.AddParams
}
// Add implements the querylog.QueryLog interface for *testQueryLog.
func (l *testQueryLog) Add(p querylog.AddParams) {
l.lastParams = p
}
// testStats is a simple stats.Stats implementation for tests.
type testStats struct {
// Stats is embedded here simply to make testStats a stats.Stats without
// acctually implementing all methods.
stats.Stats
lastEntry stats.Entry
}
// Update implements the stats.Stats interface for *testStats.
func (l *testStats) Update(e stats.Entry) {
l.lastEntry = e
}
func TestProcessQueryLogsAndStats(t *testing.T) {
testCases := []struct {
name string
proto string
addr net.Addr
clientID string
wantLogProto querylog.ClientProto
wantStatClient string
wantCode resultCode
reason dnsfilter.Reason
wantStatResult stats.Result
}{{
name: "success_udp",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_tls_client_id",
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "cli42",
wantLogProto: querylog.ClientProtoDOT,
wantStatClient: "cli42",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_tls",
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDOT,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_quic",
proto: proxy.ProtoQUIC,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDOQ,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_https",
proto: proxy.ProtoHTTPS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDOH,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_dnscrypt",
proto: proxy.ProtoDNSCrypt,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDNSCrypt,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_udp_filtered",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredBlockList,
wantStatResult: stats.RFiltered,
}, {
name: "success_udp_sb",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredSafeBrowsing,
wantStatResult: stats.RSafeBrowsing,
}, {
name: "success_udp_ss",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredSafeSearch,
wantStatResult: stats.RSafeSearch,
}, {
name: "success_udp_pc",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredParental,
wantStatResult: stats.RParental,
}}
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
assert.Nil(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := &dns.Msg{
Question: []dns.Question{{
Name: "example.com.",
}},
}
pctx := &proxy.DNSContext{
Proto: tc.proto,
Req: req,
Res: &dns.Msg{},
Addr: tc.addr,
Upstream: ups,
}
ql := &testQueryLog{}
st := &testStats{}
dctx := &dnsContext{
srv: &Server{
queryLog: ql,
stats: st,
},
proxyCtx: pctx,
startTime: time.Now(),
result: &dnsfilter.Result{
Reason: tc.reason,
},
clientID: tc.clientID,
}
code := processQueryLogsAndStats(dctx)
assert.Equal(t, tc.wantCode, code)
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)
assert.Equal(t, tc.wantStatResult, st.lastEntry.Result)
})
}
}

View File

@ -11,23 +11,21 @@ import (
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"
)
const (
clientsUpdatePeriod = 10 * time.Minute
)
const clientsUpdatePeriod = 10 * time.Minute
var webHandlersRegistered = false
// Client information
// Client contains information about persistent clients.
type Client struct {
IDs []string
Tags []string
@ -52,14 +50,13 @@ type Client struct {
type clientSource uint
// Client sources
// Client sources. The order determines the priority.
const (
// Priority: etc/hosts > DHCP > ARP > rDNS > WHOIS
ClientSourceWHOIS clientSource = iota // from WHOIS
ClientSourceRDNS // from rDNS
ClientSourceDHCP // from DHCP
ClientSourceARP // from 'arp -a'
ClientSourceHostsFile // from /etc/hosts
ClientSourceWHOIS clientSource = iota
ClientSourceRDNS
ClientSourceDHCP
ClientSourceARP
ClientSourceHostsFile
)
// ClientHost information
@ -70,10 +67,10 @@ type ClientHost struct {
}
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for
// different types (string, net.IP, and so on).
list map[string]*Client // name -> client
idIndex map[string]*Client // IP -> client
// TODO(e.burkov): Think of a way to not require string conversion for
// IP addresses.
idIndex map[string]*Client // ID -> client
ipHost map[string]*ClientHost // IP -> Hostname
lock sync.Mutex
@ -158,7 +155,7 @@ func (clients *clientsContainer) tagKnown(tag string) bool {
func (clients *clientsContainer) addFromConfig(objects []clientObject) {
for _, cy := range objects {
cli := Client{
cli := &Client{
Name: cy.Name,
IDs: cy.IDs,
UseOwnSettings: !cy.UseGlobalSettings,
@ -174,7 +171,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
for _, s := range cy.BlockedServices {
if !dnsfilter.BlockedSvcKnown(s) {
log.Debug("Clients: skipping unknown blocked-service %q", s)
log.Debug("clients: skipping unknown blocked-service %q", s)
continue
}
cli.BlockedServices = append(cli.BlockedServices, s)
@ -182,7 +179,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
for _, t := range cy.Tags {
if !clients.tagKnown(t) {
log.Debug("Clients: skipping unknown tag %q", t)
log.Debug("clients: skipping unknown tag %q", t)
continue
}
cli.Tags = append(cli.Tags, t)
@ -210,10 +207,10 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
}
cy.Tags = stringArrayDup(cli.Tags)
cy.IDs = stringArrayDup(cli.IDs)
cy.BlockedServices = stringArrayDup(cli.BlockedServices)
cy.Upstreams = stringArrayDup(cli.Upstreams)
cy.Tags = copyStrings(cli.Tags)
cy.IDs = copyStrings(cli.IDs)
cy.BlockedServices = copyStrings(cli.BlockedServices)
cy.Upstreams = copyStrings(cli.Upstreams)
*objects = append(*objects, cy)
}
@ -240,45 +237,44 @@ func (clients *clientsContainer) onHostsChanged() {
clients.addFromHostsFile()
}
// Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) bool {
// Exists checks if client with this ID already exists.
func (clients *clientsContainer) Exists(id string, source clientSource) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findByIP(ip)
_, ok = clients.findLocked(id)
if ok {
return true
}
ch, ok := clients.ipHost[ip.String()]
var ch *ClientHost
ch, ok = clients.ipHost[id]
if !ok {
return false
}
if source > ch.Source {
return false // we're going to overwrite this client's info with a stronger source
}
return true
// Return false if the new source has higher priority.
return source <= ch.Source
}
func stringArrayDup(a []string) []string {
a2 := make([]string, len(a))
copy(a2, a)
return a2
func copyStrings(a []string) (b []string) {
return append(b, a...)
}
// Find searches for a client by IP
func (clients *clientsContainer) Find(ip net.IP) (Client, bool) {
// Find searches for a client by its ID.
func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.findByIP(ip)
c, ok = clients.findLocked(id)
if !ok {
return Client{}, false
return nil, false
}
c.IDs = stringArrayDup(c.IDs)
c.Tags = stringArrayDup(c.Tags)
c.BlockedServices = stringArrayDup(c.BlockedServices)
c.Upstreams = stringArrayDup(c.Upstreams)
c.IDs = copyStrings(c.IDs)
c.Tags = copyStrings(c.Tags)
c.BlockedServices = copyStrings(c.BlockedServices)
c.Upstreams = copyStrings(c.Upstreams)
return c, true
}
@ -289,7 +285,7 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.findByIP(net.ParseIP(ip))
c, ok := clients.findLocked(ip)
if !ok {
return nil
}
@ -308,15 +304,16 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
return c.upstreamConfig
}
// Find searches for a client by IP (and does not lock anything)
func (clients *clientsContainer) findByIP(ip net.IP) (Client, bool) {
if ip == nil {
return Client{}, false
// findLocked searches for a client by its ID. For internal use only.
func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
c, ok = clients.idIndex[id]
if ok {
return c, true
}
c, ok := clients.idIndex[ip.String()]
if ok {
return *c, true
ip := net.ParseIP(id)
if ip == nil {
return nil, false
}
for _, c = range clients.list {
@ -325,88 +322,96 @@ func (clients *clientsContainer) findByIP(ip net.IP) (Client, bool) {
if err != nil {
continue
}
if ipnet.Contains(ip) {
return *c, true
return c, true
}
}
}
if clients.dhcpServer == nil {
return Client{}, false
return nil, false
}
macFound := clients.dhcpServer.FindMACbyIP(ip)
if macFound == nil {
return Client{}, false
return nil, false
}
for _, c = range clients.list {
for _, id := range c.IDs {
hwAddr, err := net.ParseMAC(id)
if err != nil {
continue
}
if bytes.Equal(hwAddr, macFound) {
return *c, true
return c, true
}
}
}
return Client{}, false
return nil, false
}
// FindAutoClient - search for an auto-client by IP
func (clients *clientsContainer) FindAutoClient(ip net.IP) (ClientHost, bool) {
if ip == nil {
func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
return ClientHost{}, false
}
clients.lock.Lock()
defer clients.lock.Unlock()
ch, ok := clients.ipHost[ip.String()]
ch, ok := clients.ipHost[ip]
if ok {
return *ch, true
}
return ClientHost{}, false
}
// Check if Client object's fields are correct
func (clients *clientsContainer) check(c *Client) error {
if len(c.Name) == 0 {
return fmt.Errorf("invalid Name")
}
if len(c.IDs) == 0 {
return fmt.Errorf("id required")
// check validates the client.
func (clients *clientsContainer) check(c *Client) (err error) {
switch {
case c == nil:
return agherr.Error("client is nil")
case c.Name == "":
return agherr.Error("invalid name")
case len(c.IDs) == 0:
return agherr.Error("id required")
default:
// Go on.
}
for i, id := range c.IDs {
ip := net.ParseIP(id)
if ip != nil {
c.IDs[i] = ip.String() // normalize IP address
continue
// Normalize structured data.
var ip net.IP
var ipnet *net.IPNet
var mac net.HardwareAddr
if ip = net.ParseIP(id); ip != nil {
c.IDs[i] = ip.String()
} else if ip, ipnet, err = net.ParseCIDR(id); err == nil {
ipnet.IP = ip
c.IDs[i] = ipnet.String()
} else if mac, err = net.ParseMAC(id); err == nil {
c.IDs[i] = mac.String()
} else if err = dnsforward.ValidateClientID(id); err == nil {
c.IDs[i] = id
} else {
return fmt.Errorf("invalid client id at index %d: %q", i, id)
}
_, _, err := net.ParseCIDR(id)
if err == nil {
continue
}
_, err = net.ParseMAC(id)
if err == nil {
continue
}
return fmt.Errorf("invalid ID: %s", id)
}
for _, t := range c.Tags {
if !clients.tagKnown(t) {
return fmt.Errorf("invalid tag: %s", t)
return fmt.Errorf("invalid tag: %q", t)
}
}
sort.Strings(c.Tags)
err := dnsforward.ValidateUpstreams(c.Upstreams)
err = dnsforward.ValidateUpstreams(c.Upstreams)
if err != nil {
return fmt.Errorf("invalid upstream servers: %w", err)
}
@ -414,49 +419,52 @@ func (clients *clientsContainer) check(c *Client) error {
return nil
}
// Add a new client object
// Return true: success; false: client exists.
func (clients *clientsContainer) Add(c Client) (bool, error) {
e := clients.check(&c)
if e != nil {
return false, e
// Add adds a new client object. ok is false if such client already exists or
// if an error occurred.
func (clients *clientsContainer) Add(c *Client) (ok bool, err error) {
err = clients.check(c)
if err != nil {
return false, err
}
clients.lock.Lock()
defer clients.lock.Unlock()
// check Name index
_, ok := clients.list[c.Name]
_, ok = clients.list[c.Name]
if ok {
return false, nil
}
// check ID index
for _, id := range c.IDs {
c2, ok := clients.idIndex[id]
var c2 *Client
c2, ok = clients.idIndex[id]
if ok {
return false, fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name)
return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
}
}
// update Name index
clients.list[c.Name] = &c
clients.list[c.Name] = c
// update ID index
for _, id := range c.IDs {
clients.idIndex[id] = &c
clients.idIndex[id] = c
}
log.Debug("Clients: added %q: ID:%v [%d]", c.Name, c.IDs, len(clients.list))
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list))
return true, nil
}
// Del removes a client
func (clients *clientsContainer) Del(name string) bool {
// Del removes a client. ok is false if there is no such client.
func (clients *clientsContainer) Del(name string) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.list[name]
var c *Client
c, ok = clients.list[name]
if !ok {
return false
}
@ -468,25 +476,28 @@ func (clients *clientsContainer) Del(name string) bool {
for _, id := range c.IDs {
delete(clients.idIndex, id)
}
return true
}
// Return TRUE if arrays are equal
func arraysEqual(a, b []string) bool {
// equalStringSlices returns true if the slices are equal.
func equalStringSlices(a, b []string) (ok bool) {
if len(a) != len(b) {
return false
}
for i := 0; i != len(a); i++ {
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// Update a client
func (clients *clientsContainer) Update(name string, c Client) error {
err := clients.check(&c)
// Update updates a client by its name.
func (clients *clientsContainer) Update(name string, c *Client) (err error) {
err = clients.check(c)
if err != nil {
return err
}
@ -494,66 +505,69 @@ func (clients *clientsContainer) Update(name string, c Client) error {
clients.lock.Lock()
defer clients.lock.Unlock()
old, ok := clients.list[name]
prev, ok := clients.list[name]
if !ok {
return fmt.Errorf("client not found")
return agherr.Error("client not found")
}
// check Name index
if old.Name != c.Name {
if prev.Name != c.Name {
_, ok = clients.list[c.Name]
if ok {
return fmt.Errorf("client already exists")
return agherr.Error("client already exists")
}
}
// check IP index
if !arraysEqual(old.IDs, c.IDs) {
if !equalStringSlices(prev.IDs, c.IDs) {
for _, id := range c.IDs {
c2, ok := clients.idIndex[id]
if ok && c2 != old {
return fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name)
if ok && c2 != prev {
return fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
}
}
// update ID index
for _, id := range old.IDs {
for _, id := range prev.IDs {
delete(clients.idIndex, id)
}
for _, id := range c.IDs {
clients.idIndex[id] = old
clients.idIndex[id] = prev
}
}
// update Name index
if old.Name != c.Name {
delete(clients.list, old.Name)
clients.list[c.Name] = old
if prev.Name != c.Name {
delete(clients.list, prev.Name)
clients.list[c.Name] = prev
}
// update upstreams cache
c.upstreamConfig = nil
*old = c
*prev = *c
return nil
}
// SetWhoisInfo - associate WHOIS information with a client
func (clients *clientsContainer) SetWhoisInfo(ip net.IP, info [][]string) {
// SetWhoisInfo sets the WHOIS information for a client.
//
// TODO(a.garipov): Perhaps replace [][]string with map[string]string.
func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findByIP(ip)
_, ok := clients.findLocked(ip)
if ok {
log.Debug("Clients: client for %s is already created, ignore WHOIS info", ip)
log.Debug("clients: client for %s is already created, ignore whois info", ip)
return
}
ipStr := ip.String()
ch, ok := clients.ipHost[ipStr]
ch, ok := clients.ipHost[ip]
if ok {
ch.WhoisInfo = info
log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo)
log.Debug("clients: set whois info for auto-client %s: %q", ch.Host, info)
return
}
@ -562,32 +576,34 @@ func (clients *clientsContainer) SetWhoisInfo(ip net.IP, info [][]string) {
Source: ClientSourceWHOIS,
}
ch.WhoisInfo = info
clients.ipHost[ipStr] = ch
log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo)
clients.ipHost[ip] = ch
log.Debug("clients: set whois info for auto-client with IP %s: %q", ip, info)
}
// AddHost adds new IP -> Host pair
// Use priority of the source (etc/hosts > ARP > rDNS)
// so we overwrite existing entries with an equal or higher priority
func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
// taken into account. ok is true if the pairing was added.
func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) {
clients.lock.Lock()
b := clients.addHost(ip, host, source)
ok = clients.addHostLocked(ip, host, src)
clients.lock.Unlock()
return b, nil
return ok, nil
}
func (clients *clientsContainer) addHost(ip, host string, source clientSource) (addedNew bool) {
ch, ok := clients.ipHost[ip]
// addHostLocked adds a new IP-hostname pairing. For internal use only.
func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource) (ok bool) {
var ch *ClientHost
ch, ok = clients.ipHost[ip]
if ok {
if ch.Source > source {
if ch.Source > src {
return false
}
ch.Source = source
ch.Source = src
} else {
ch = &ClientHost{
Host: host,
Source: source,
Source: src,
}
clients.ipHost[ip] = ch
@ -598,11 +614,11 @@ func (clients *clientsContainer) addHost(ip, host string, source clientSource) (
return true
}
// Remove all entries that match the specified source
func (clients *clientsContainer) rmHosts(source clientSource) {
// rmHostsBySrc removes all entries that match the specified source.
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
n := 0
for k, v := range clients.ipHost {
if v.Source == source {
if v.Source == src {
delete(clients.ipHost, k)
n++
}
@ -611,19 +627,20 @@ func (clients *clientsContainer) rmHosts(source clientSource) {
log.Debug("clients: removed %d client aliases", n)
}
// addFromHostsFile fills the clients hosts list from the system's hosts files.
// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (clients *clientsContainer) addFromHostsFile() {
hosts := clients.autoHosts.List()
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHosts(ClientSourceHostsFile)
clients.rmHostsBySrc(ClientSourceHostsFile)
n := 0
for ip, names := range hosts {
for _, name := range names {
ok := clients.addHost(ip, name, ClientSourceHostsFile)
ok := clients.addHostLocked(ip, name, ClientSourceHostsFile)
if ok {
n++
}
@ -633,31 +650,31 @@ func (clients *clientsContainer) addFromHostsFile() {
log.Debug("Clients: added %d client aliases from system hosts-file", n)
}
// Add IP -> Host pairs from the system's `arp -a` command output
// The command's output is:
// HOST (IP) at MAC on IFACE
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
// command.
func (clients *clientsContainer) addFromSystemARP() {
if runtime.GOOS == "windows" {
return
}
cmd := exec.Command("arp", "-a")
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
log.Tracef("executing %q %q", cmd.Path, cmd.Args)
data, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Debug("command %s has failed: %v code:%d",
log.Debug("command %q has failed: %q code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHosts(ClientSourceARP)
clients.rmHostsBySrc(ClientSourceARP)
n := 0
// TODO(a.garipov): Rewrite to use bufio.Scanner.
lines := strings.Split(string(data), "\n")
for _, ln := range lines {
open := strings.Index(ln, " (")
close := strings.Index(ln, ") ")
if open == -1 || close == -1 || open >= close {
@ -670,16 +687,17 @@ func (clients *clientsContainer) addFromSystemARP() {
continue
}
ok := clients.addHost(ip, host, ClientSourceARP)
ok := clients.addHostLocked(ip, host, ClientSourceARP)
if ok {
n++
}
}
log.Debug("Clients: added %d client aliases from 'arp -a' command output", n)
log.Debug("clients: added %d client aliases from 'arp -a' command output", n)
}
// Add clients from DHCP that have non-empty Hostname property
// addFromDHCP adds the clients that have a non-empty hostname from the DHCP
// server.
func (clients *clientsContainer) addFromDHCP() {
if clients.dhcpServer == nil {
return
@ -688,18 +706,20 @@ func (clients *clientsContainer) addFromDHCP() {
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHosts(ClientSourceDHCP)
clients.rmHostsBySrc(ClientSourceDHCP)
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
n := 0
for _, l := range leases {
if len(l.Hostname) == 0 {
if l.Hostname == "" {
continue
}
ok := clients.addHost(l.IP.String(), l.Hostname, ClientSourceDHCP)
ok := clients.addHostLocked(l.IP.String(), l.Hostname, ClientSourceDHCP)
if ok {
n++
}
}
log.Debug("Clients: added %d client aliases from DHCP", n)
log.Debug("clients: added %d client aliases from dhcp", n)
}

View File

@ -18,65 +18,65 @@ func TestClients(t *testing.T) {
clients.Init(nil, nil, nil)
t.Run("add_success", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
}
b, err := clients.Add(c)
assert.True(t, b)
ok, err := clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
c = Client{
c = &Client{
IDs: []string{"2.2.2.2"},
Name: "client2",
}
b, err = clients.Add(c)
assert.True(t, b)
ok, err = clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
c, b = clients.Find(net.IPv4(1, 1, 1, 1))
assert.True(t, b)
assert.Equal(t, c.Name, "client1")
c, ok = clients.Find("1.1.1.1")
assert.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, b = clients.Find(net.ParseIP("1:2:3::4"))
assert.True(t, b)
assert.Equal(t, c.Name, "client1")
c, ok = clients.Find("1:2:3::4")
assert.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, b = clients.Find(net.IPv4(2, 2, 2, 2))
assert.True(t, b)
assert.Equal(t, c.Name, "client2")
c, ok = clients.Find("2.2.2.2")
assert.True(t, ok)
assert.Equal(t, "client2", c.Name)
assert.False(t, clients.Exists(net.IPv4(1, 2, 3, 4), ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(2, 2, 2, 2), ClientSourceHostsFile))
assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
})
t.Run("add_fail_name", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"1.2.3.5"},
Name: "client1",
}
b, err := clients.Add(c)
assert.False(t, b)
ok, err := clients.Add(c)
assert.False(t, ok)
assert.Nil(t, err)
})
t.Run("add_fail_ip", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
}
b, err := clients.Add(c)
assert.False(t, b)
ok, err := clients.Add(c)
assert.False(t, ok)
assert.NotNil(t, err)
})
t.Run("update_fail_name", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"1.2.3.0"},
Name: "client3",
}
@ -84,7 +84,7 @@ func TestClients(t *testing.T) {
err := clients.Update("client3", c)
assert.NotNil(t, err)
c = Client{
c = &Client{
IDs: []string{"1.2.3.0"},
Name: "client2",
}
@ -94,7 +94,7 @@ func TestClients(t *testing.T) {
})
t.Run("update_fail_ip", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"2.2.2.2"},
Name: "client1",
}
@ -104,7 +104,7 @@ func TestClients(t *testing.T) {
})
t.Run("update_success", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"1.1.1.2"},
Name: "client1",
}
@ -112,10 +112,10 @@ func TestClients(t *testing.T) {
err := clients.Update("client1", c)
assert.Nil(t, err)
assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile))
assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
c = Client{
c = &Client{
IDs: []string{"1.1.1.2"},
Name: "client1-renamed",
UseOwnSettings: true,
@ -124,77 +124,89 @@ func TestClients(t *testing.T) {
err = clients.Update("client1", c)
assert.Nil(t, err)
c, b := clients.Find(net.IPv4(1, 1, 1, 2))
assert.True(t, b)
c, ok := clients.Find("1.1.1.2")
assert.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name)
assert.Equal(t, "1.1.1.2", c.IDs[0])
assert.True(t, c.UseOwnSettings)
assert.Nil(t, clients.list["client1"])
if assert.Len(t, c.IDs, 1) {
assert.Equal(t, "1.1.1.2", c.IDs[0])
}
})
t.Run("del_success", func(t *testing.T) {
b := clients.Del("client1-renamed")
assert.True(t, b)
assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile))
ok := clients.Del("client1-renamed")
assert.True(t, ok)
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
})
t.Run("del_fail", func(t *testing.T) {
b := clients.Del("client3")
assert.False(t, b)
ok := clients.Del("client3")
assert.False(t, ok)
})
t.Run("addhost_success", func(t *testing.T) {
b, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
assert.True(t, b)
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
assert.True(t, ok)
assert.Nil(t, err)
b, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
assert.True(t, b)
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
assert.True(t, ok)
assert.Nil(t, err)
b, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
assert.True(t, b)
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
assert.True(t, ok)
assert.Nil(t, err)
assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
})
t.Run("addhost_fail", func(t *testing.T) {
b, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
assert.False(t, b)
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
assert.False(t, ok)
assert.Nil(t, err)
})
}
func TestClientsWhois(t *testing.T) {
var c Client
var c *Client
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
// set whois info on new client
clients.SetWhoisInfo(net.IPv4(1, 1, 1, 255), whois)
assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1])
clients.SetWhoisInfo("1.1.1.255", whois)
if assert.NotNil(t, clients.ipHost["1.1.1.255"]) {
h := clients.ipHost["1.1.1.255"]
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
}
}
// set whois info on existing auto-client
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
clients.SetWhoisInfo(net.IPv4(1, 1, 1, 1), whois)
assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1])
clients.SetWhoisInfo("1.1.1.1", whois)
if assert.NotNil(t, clients.ipHost["1.1.1.1"]) {
h := clients.ipHost["1.1.1.1"]
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
}
}
// Check that we cannot set whois info on a manually-added client
c = Client{
c = &Client{
IDs: []string{"1.1.1.2"},
Name: "client1",
}
_, _ = clients.Add(c)
clients.SetWhoisInfo(net.IPv4(1, 1, 1, 2), whois)
clients.SetWhoisInfo("1.1.1.2", whois)
assert.Nil(t, clients.ipHost["1.1.1.2"])
_ = clients.Del("client1")
}
func TestClientsAddExisting(t *testing.T) {
var c Client
var c *Client
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
@ -204,7 +216,7 @@ func TestClientsAddExisting(t *testing.T) {
testIP := "1.2.3.4"
// add a client
c = Client{
c = &Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1",
}
@ -233,7 +245,7 @@ func TestClientsAddExisting(t *testing.T) {
assert.Nil(t, err)
// add a new client with the same IP as for a client with MAC
c = Client{
c = &Client{
IDs: []string{testIP},
Name: "client2",
}
@ -242,7 +254,7 @@ func TestClientsAddExisting(t *testing.T) {
assert.Nil(t, err)
// add a new client with the IP from the client1's IP range
c = Client{
c = &Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
}
@ -258,7 +270,7 @@ func TestClientsCustomUpstream(t *testing.T) {
clients.Init(nil, nil, nil)
// add client with upstreams
client := Client{
c := &Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
Upstreams: []string{
@ -266,7 +278,7 @@ func TestClientsCustomUpstream(t *testing.T) {
"[/example.org/]8.8.8.8",
},
}
ok, err := clients.Add(client)
ok, err := clients.Add(c)
assert.Nil(t, err)
assert.True(t, ok)
@ -275,6 +287,6 @@ func TestClientsCustomUpstream(t *testing.T) {
config = clients.FindUpstreams("1.1.1.1")
assert.NotNil(t, config)
assert.Len(t, config.Upstreams, 1)
assert.Len(t, config.DomainReservedUpstreams, 1)
assert.Equal(t, 1, len(config.Upstreams))
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
}

View File

@ -158,7 +158,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
}
c := jsonToClient(cj)
ok, err := clients.Add(*c)
ok, err := clients.Add(c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
@ -216,7 +216,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
}
c := jsonToClient(dj.Data)
err = clients.Update(dj.Name, *c)
err = clients.Update(dj.Name, c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
@ -229,28 +229,28 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
data := []map[string]clientJSON{}
for i := 0; ; i++ {
ipStr := q.Get(fmt.Sprintf("ip%d", i))
ip := net.ParseIP(ipStr)
if ip == nil {
for i := 0; i < len(q); i++ {
idStr := q.Get(fmt.Sprintf("ip%d", i))
if idStr == "" {
break
}
c, ok := clients.Find(ip)
ip := net.ParseIP(idStr)
c, ok := clients.Find(idStr)
var cj clientJSON
if !ok {
var found bool
cj, found = clients.findTemporary(ip)
cj, found = clients.findTemporary(ip, idStr)
if !found {
continue
}
} else {
cj = clientToJSON(&c)
cj = clientToJSON(c)
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
}
data = append(data, map[string]clientJSON{
ipStr: cj,
idStr: cj,
})
}
@ -263,10 +263,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
// findTemporary looks up the IP in temporary storages, like autohosts or
// blocklists.
func (clients *clientsContainer) findTemporary(ip net.IP) (cj clientJSON, found bool) {
ipStr := ip.String()
ch, ok := clients.FindAutoClient(ip)
if !ok {
func (clients *clientsContainer) findTemporary(ip net.IP, idStr string) (cj clientJSON, found bool) {
ch, ok := clients.FindAutoClient(idStr)
if !ok && ip != nil {
// It is still possible that the IP used to be in the runtime
// clients list, but then the server was reloaded. So, check
// the DNS server's blocked IP list.
@ -278,7 +277,7 @@ func (clients *clientsContainer) findTemporary(ip net.IP) (cj clientJSON, found
}
cj = clientJSON{
IDs: []string{ipStr},
IDs: []string{idStr},
Disallowed: disallowed,
DisallowedRule: rule,
}
@ -286,8 +285,10 @@ func (clients *clientsContainer) findTemporary(ip net.IP) (cj clientJSON, found
return cj, true
}
cj = clientHostToJSON(ipStr, ch)
cj = clientHostToJSON(idStr, ch)
if ip != nil {
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
}
return cj, true
}

View File

@ -1,6 +1,7 @@
package home
import (
"errors"
"io/ioutil"
"net"
"os"
@ -188,7 +189,7 @@ func initConfig() {
func (c *configuration) getConfigFilename() string {
configFile, err := filepath.EvalSymlinks(Context.configFilename)
if err != nil {
if !os.IsNotExist(err) {
if !errors.Is(err, os.ErrNotExist) {
log.Error("unexpected error while config file path evaluation: %s", err)
}
configFile = Context.configFilename

View File

@ -3,8 +3,10 @@ package home
import (
"fmt"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
@ -58,7 +60,7 @@ func initDNSServer() error {
if config.DNS.BindHost.IsUnspecified() {
bindhost = net.IPv4(127, 0, 0, 1)
}
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
filterConf.ResolverAddress = net.JoinHostPort(bindhost.String(), strconv.Itoa(config.DNS.Port))
filterConf.AutoHosts = &Context.autoHosts
filterConf.ConfigModified = onConfigModified
filterConf.HTTPRegister = httpRegister
@ -126,6 +128,7 @@ func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) {
Context.tls.WriteDiskConfig(&tlsConf)
if tlsConf.Enabled {
newconfig.TLSConfig = tlsConf.TLSConfig
newconfig.TLSConfig.ServerName = tlsConf.ServerName
if tlsConf.PortDNSOverTLS != 0 {
newconfig.TLSListenAddr = &net.TCPAddr{
@ -207,36 +210,42 @@ type dnsEncryption struct {
quic string
}
func getDNSEncryption() dnsEncryption {
dnsEncryption := dnsEncryption{}
func getDNSEncryption() (de dnsEncryption) {
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
hostname := tlsConf.ServerName
if tlsConf.PortHTTPS != 0 {
addr := tlsConf.ServerName
addr := hostname
if tlsConf.PortHTTPS != 443 {
addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS)
addr = net.JoinHostPort(addr, strconv.Itoa(tlsConf.PortHTTPS))
}
addr = fmt.Sprintf("https://%s/dns-query", addr)
dnsEncryption.https = addr
de.https = (&url.URL{
Scheme: "https",
Host: addr,
Path: "/dns-query",
}).String()
}
if tlsConf.PortDNSOverTLS != 0 {
addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS)
dnsEncryption.tls = addr
de.tls = (&url.URL{
Scheme: "tls",
Host: net.JoinHostPort(hostname, strconv.Itoa(tlsConf.PortDNSOverTLS)),
}).String()
}
if tlsConf.PortDNSOverQUIC != 0 {
addr := fmt.Sprintf("quic://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverQUIC)
dnsEncryption.quic = addr
de.quic = (&url.URL{
Scheme: "quic",
Host: net.JoinHostPort(hostname, strconv.Itoa(int(tlsConf.PortDNSOverQUIC))),
}).String()
}
}
return dnsEncryption
return de
}
// Get the list of DNS addresses the server is listening on
@ -273,21 +282,26 @@ func getDNSAddresses() []string {
return dnsAddresses
}
// If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr net.IP, setts *dnsfilter.RequestFilteringSettings) {
// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.RequestFilteringSettings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
if clientAddr == nil {
return
}
setts.ClientIP = clientAddr
c, ok := Context.clients.Find(clientAddr)
c, ok := Context.clients.Find(clientID)
if !ok {
c, ok = Context.clients.Find(clientAddr.String())
if !ok {
return
}
}
log.Debug("Using settings for client %s with IP %s", c.Name, clientAddr)
log.Debug("using settings for client %s with ip %s and id %q", c.Name, clientAddr, clientID)
if c.UseOwnBlockedServices {
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)

View File

@ -5,6 +5,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
@ -434,6 +435,10 @@ func initWorkingDir(args options) {
} else {
Context.workDir = filepath.Dir(execPath)
}
if workDir, err := filepath.EvalSymlinks(Context.workDir); err == nil {
Context.workDir = workDir
}
}
// configureLogger configures logger level and output
@ -624,7 +629,7 @@ func detectFirstRun() bool {
configfile = filepath.Join(Context.workDir, Context.configFilename)
}
_, err := os.Stat(configfile)
return os.IsNotExist(err)
return errors.Is(err, os.ErrNotExist)
}
// Connect to a remote server resolving hostname using our own DNS server

View File

@ -4,7 +4,10 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"path"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/log"
uuid "github.com/satori/go.uuid"
"howett.net/plist"
@ -14,6 +17,7 @@ type dnsSettings struct {
DNSProtocol string
ServerURL string `plist:",omitempty"`
ServerName string `plist:",omitempty"`
clientID string
}
type payloadContent struct {
@ -23,19 +27,19 @@ type payloadContent struct {
PayloadIdentifier string
PayloadType string
PayloadUUID string
PayloadVersion int
DNSSettings dnsSettings
PayloadVersion int
}
type mobileConfig struct {
PayloadContent []payloadContent
PayloadDescription string
PayloadDisplayName string
PayloadIdentifier string
PayloadRemovalDisallowed bool
PayloadType string
PayloadUUID string
PayloadContent []payloadContent
PayloadVersion int
PayloadRemovalDisallowed bool
}
func genUUIDv4() string {
@ -48,22 +52,35 @@ const (
)
func getMobileConfig(d dnsSettings) ([]byte, error) {
var name string
var dspName string
switch d.DNSProtocol {
case dnsProtoHTTPS:
name = fmt.Sprintf("%s DoH", d.ServerName)
d.ServerURL = fmt.Sprintf("https://%s/dns-query", d.ServerName)
dspName = fmt.Sprintf("%s DoH", d.ServerName)
u := &url.URL{
Scheme: "https",
Host: d.ServerName,
Path: "/dns-query",
}
if d.clientID != "" {
u.Path = path.Join(u.Path, d.clientID)
}
d.ServerURL = u.String()
case dnsProtoTLS:
name = fmt.Sprintf("%s DoT", d.ServerName)
dspName = fmt.Sprintf("%s DoT", d.ServerName)
if d.clientID != "" {
d.ServerName = d.clientID + "." + d.ServerName
}
default:
return nil, fmt.Errorf("bad dns protocol %q", d.DNSProtocol)
}
data := mobileConfig{
PayloadContent: []payloadContent{{
Name: name,
Name: dspName,
PayloadDescription: "Configures device to use AdGuard Home",
PayloadDisplayName: name,
PayloadDisplayName: dspName,
PayloadIdentifier: fmt.Sprintf("com.apple.dnsSettings.managed.%s", genUUIDv4()),
PayloadType: "com.apple.dnsSettings.managed",
PayloadUUID: genUUIDv4(),
@ -71,7 +88,7 @@ func getMobileConfig(d dnsSettings) ([]byte, error) {
DNSSettings: d,
}},
PayloadDescription: "Adds AdGuard Home to Big Sur and iOS 14 or newer systems",
PayloadDisplayName: name,
PayloadDisplayName: dspName,
PayloadIdentifier: genUUIDv4(),
PayloadRemovalDisallowed: false,
PayloadType: "Configuration",
@ -83,7 +100,10 @@ func getMobileConfig(d dnsSettings) ([]byte, error) {
}
func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
host := r.URL.Query().Get("host")
var err error
q := r.URL.Query()
host := q.Get("host")
if host == "" {
host = Context.tls.conf.ServerName
}
@ -92,7 +112,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
w.WriteHeader(http.StatusInternalServerError)
const msg = "no host in query parameters and no server_name"
err := json.NewEncoder(w).Encode(&jsonError{
err = json.NewEncoder(w).Encode(&jsonError{
Message: msg,
})
if err != nil {
@ -102,9 +122,25 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
return
}
clientID := q.Get("client_id")
err = dnsforward.ValidateClientID(clientID)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
err = json.NewEncoder(w).Encode(&jsonError{
Message: err.Error(),
})
if err != nil {
log.Debug("writing 400 json response: %s", err)
}
return
}
d := dnsSettings{
DNSProtocol: dnsp,
ServerName: host,
clientID: clientID,
}
mobileconfig, err := getMobileConfig(d)
@ -115,6 +151,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
}
w.Header().Set("Content-Type", "application/xml")
_, _ = w.Write(mobileconfig)
}

View File

@ -73,6 +73,27 @@ func TestHandleMobileConfigDOH(t *testing.T) {
handleMobileConfigDOH(w, r)
assert.Equal(t, http.StatusInternalServerError, w.Code)
})
t.Run("client_id", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil)
assert.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDOH(w, r)
assert.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL)
}
})
}
func TestHandleMobileConfigDOT(t *testing.T) {
@ -137,4 +158,24 @@ func TestHandleMobileConfigDOT(t *testing.T) {
handleMobileConfigDOT(w, r)
assert.Equal(t, http.StatusInternalServerError, w.Code)
})
t.Run("client_id", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil)
assert.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDOT(w, r)
assert.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName)
}
})
}

View File

@ -57,7 +57,8 @@ func (r *RDNS) Begin(ip net.IP) {
binary.BigEndian.PutUint64(expire, now+ttl)
_ = r.ipAddrs.Set(ip, expire)
if r.clients.Exists(ip, ClientSourceRDNS) {
id := ip.String()
if r.clients.Exists(id, ClientSourceRDNS) {
return
}

View File

@ -11,7 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
)
@ -27,12 +26,14 @@ const (
type Whois struct {
clients *clientsContainer
ipChan chan net.IP
timeoutMsec uint
// Contains IP addresses of clients
// An active IP address is resolved once again after it expires.
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
ipAddrs cache.Cache
// TODO(a.garipov): Rewrite to use time.Duration. Like, seriously, why?
timeoutMsec uint
}
// initWhois creates the Whois module context.
@ -244,6 +245,7 @@ func (w *Whois) workerLoop() {
continue
}
w.clients.SetWhoisInfo(ip, info)
id := ip.String()
w.clients.SetWhoisInfo(id, info)
}
}

View File

@ -17,6 +17,16 @@ import (
type logEntryHandler (func(t json.Token, ent *logEntry) error)
var logEntryHandlers = map[string]logEntryHandler{
"CID": func(t json.Token, ent *logEntry) error {
v, ok := t.(string)
if !ok {
return nil
}
ent.ClientID = v
return nil
},
"IP": func(t json.Token, ent *logEntry) error {
v, ok := t.(string)
if !ok {

View File

@ -25,6 +25,7 @@ func TestDecodeLogEntry(t *testing.T) {
t.Run("success", func(t *testing.T) {
const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==`
const data = `{"IP":"127.0.0.1",` +
`"CID":"cli42",` +
`"T":"2020-11-25T18:55:56.519796+03:00",` +
`"QH":"an.yandex.ru",` +
`"QT":"A",` +
@ -52,6 +53,7 @@ func TestDecodeLogEntry(t *testing.T) {
QHost: "an.yandex.ru",
QType: "A",
QClass: "IN",
ClientID: "cli42",
ClientProto: "",
Answer: ans,
Result: dnsfilter.Result{

View File

@ -79,6 +79,10 @@ func (l *queryLog) logEntryToJSONEntry(entry *logEntry) (jsonEntry jobject) {
},
}
if entry.ClientID != "" {
jsonEntry["client_id"] = entry.ClientID
}
if msg != nil {
jsonEntry["status"] = dns.RcodeToString[msg.Rcode]

View File

@ -2,6 +2,7 @@
package querylog
import (
"errors"
"fmt"
"net"
"os"
@ -40,6 +41,7 @@ const (
ClientProtoDOH ClientProto = "doh"
ClientProtoDOQ ClientProto = "doq"
ClientProtoDOT ClientProto = "dot"
ClientProtoDNSCrypt ClientProto = "dnscrypt"
ClientProtoPlain ClientProto = ""
)
@ -68,6 +70,7 @@ type logEntry struct {
QType string `json:"QT"`
QClass string `json:"QC"`
ClientID string `json:"CID,omitempty"`
ClientProto ClientProto `json:"CP"`
Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
@ -119,14 +122,15 @@ func (l *queryLog) clear() {
l.flushPending = false
l.bufferLock.Unlock()
err := os.Remove(l.logFile + ".1")
if err != nil && !os.IsNotExist(err) {
log.Error("file remove: %s: %s", l.logFile+".1", err)
oldLogFile := l.logFile + ".1"
err := os.Remove(oldLogFile)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("removing old log file %q: %s", oldLogFile, err)
}
err = os.Remove(l.logFile)
if err != nil && !os.IsNotExist(err) {
log.Error("file remove: %s: %s", l.logFile, err)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("removing log file %q: %s", l.logFile, err)
}
log.Debug("Query log: cleared")
@ -154,6 +158,7 @@ func (l *queryLog) Add(params AddParams) {
Result: *params.Result,
Elapsed: params.Elapsed,
Upstream: params.Upstream,
ClientID: params.ClientID,
ClientProto: params.ClientProto,
}
q := params.Question.Question[0]

View File

@ -251,7 +251,7 @@ func (q *QLogFile) readNextLine(position int64) (string, int64, error) {
// the goal is to read a chunk of file that includes the line with the specified position.
func (q *QLogFile) initBuffer(position int64) error {
q.bufferStart = int64(0)
if (position - bufferSize) > 0 {
if position > bufferSize {
q.bufferStart = position - bufferSize
}
@ -264,12 +264,10 @@ func (q *QLogFile) initBuffer(position int64) error {
if q.buffer == nil {
q.buffer = make([]byte, bufferSize)
}
q.bufferLen, err = q.file.Read(q.buffer)
if err != nil {
return err
}
return nil
q.bufferLen, err = q.file.Read(q.buffer)
return err
}
// readProbeLine reads a line that includes the specified position
@ -280,7 +278,7 @@ func (q *QLogFile) readProbeLine(position int64) (string, int64, int64, error) {
// In order to do this, we'll define the boundaries
seekPosition := int64(0)
relativePos := position // position relative to the buffer we're going to read
if (position - maxEntrySize) > 0 {
if position > maxEntrySize {
seekPosition = position - maxEntrySize
relativePos = maxEntrySize
}

View File

@ -46,6 +46,7 @@ type AddParams struct {
OrigAnswer *dns.Msg // The response from an upstream server (optional)
Result *dnsfilter.Result // Filtering result (optional)
Elapsed time.Duration // Time spent for processing the request
ClientID string
ClientIP net.IP
Upstream string // Upstream server URL
ClientProto ClientProto

View File

@ -3,6 +3,7 @@ package querylog
import (
"bytes"
"encoding/json"
"errors"
"os"
"time"
@ -87,18 +88,19 @@ func (l *queryLog) rotate() error {
from := l.logFile
to := l.logFile + ".1"
if _, err := os.Stat(from); os.IsNotExist(err) {
// do nothing, file doesn't exist
err := os.Rename(from, to)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
err := os.Rename(from, to)
if err != nil {
log.Error("querylog: failed to rename file: %s", err)
return err
}
log.Debug("querylog: renamed %s -> %s", from, to)
return nil
}

View File

@ -9,8 +9,13 @@ import (
type criteriaType int
const (
ctDomainOrClient criteriaType = iota // domain name or client IP address
ctFilteringStatus // filtering status
// ctDomainOrClient is for searching by the domain name, the client's IP
// address, or the clinet's ID.
ctDomainOrClient criteriaType = iota
// ctFilteringStatus is for searching by the filtering status.
//
// See (*searchCriteria).ctFilteringStatusCase for details.
ctFilteringStatus
)
const (
@ -38,9 +43,9 @@ var filteringStatusValues = []string{
// searchCriteria - every search request may contain a list of different search criteria
// we use each of them to match the query
type searchCriteria struct {
value string // search criteria value
criteriaType criteriaType // type of the criteria
strict bool // should we strictly match (equality) or not (indexOf)
value string // search criteria value
}
// quickMatch - quickly checks if the log entry matches this search criteria
@ -51,7 +56,8 @@ func (c *searchCriteria) quickMatch(line string) bool {
switch c.criteriaType {
case ctDomainOrClient:
return c.quickMatchJSONValue(line, "QH") ||
c.quickMatchJSONValue(line, "IP")
c.quickMatchJSONValue(line, "IP") ||
c.quickMatchJSONValue(line, "ID")
default:
return true
}
@ -89,13 +95,14 @@ func (c *searchCriteria) match(entry *logEntry) bool {
}
func (c *searchCriteria) ctDomainOrClientCase(entry *logEntry) bool {
clientID := strings.ToLower(entry.ClientID)
qhost := strings.ToLower(entry.QHost)
searchVal := strings.ToLower(c.value)
if c.strict && qhost == searchVal {
if c.strict && (qhost == searchVal || clientID == searchVal) {
return true
}
if !c.strict && strings.Contains(qhost, searchVal) {
if !c.strict && (strings.Contains(qhost, searchVal) || strings.Contains(clientID, searchVal)) {
return true
}

View File

@ -76,10 +76,14 @@ const (
rLast
)
// Entry - data to add
// Entry is a statistics data entry.
type Entry struct {
// Clients is the client's primary ID.
//
// TODO(a.garipov): Make this a {net.IP, string} enum?
Client string
Domain string
Client net.IP
Result Result
Time uint32 // processing time (msec)
}

View File

@ -39,13 +39,13 @@ func TestStats(t *testing.T) {
e := Entry{}
e.Domain = "domain"
e.Client = net.IP{127, 0, 0, 1}
e.Client = "127.0.0.1"
e.Result = RFiltered
e.Time = 123456
s.Update(e)
e.Domain = "domain"
e.Client = net.IP{127, 0, 0, 1}
e.Client = "127.0.0.1"
e.Result = RNotFiltered
e.Time = 123456
s.Update(e)
@ -113,9 +113,10 @@ func TestLargeNumbers(t *testing.T) {
}
for i := 0; i != n; i++ {
e.Domain = fmt.Sprintf("domain%d", i)
e.Client = net.IP{127, 0, 0, 1}
e.Client[2] = byte((i & 0xff00) >> 8)
e.Client[3] = byte(i & 0xff)
ip := net.IP{127, 0, 0, 1}
ip[2] = byte((i & 0xff00) >> 8)
ip[3] = byte(i & 0xff)
e.Client = ip.String()
e.Result = RNotFiltered
e.Time = 123456
s.Update(e)

View File

@ -223,6 +223,7 @@ func (s *statsCtx) periodicFlush() {
s.unitLock.Lock()
ptr := s.unit
s.unitLock.Unlock()
if ptr == nil {
break
}
@ -230,6 +231,7 @@ func (s *statsCtx) periodicFlush() {
id := s.conf.UnitID()
if ptr.id == id {
time.Sleep(time.Second)
continue
}
@ -243,6 +245,7 @@ func (s *statsCtx) periodicFlush() {
if tx == nil {
continue
}
ok1 := s.flushUnitToDB(tx, u.id, udb)
ok2 := s.deleteUnit(tx, id-s.conf.limit)
if ok1 || ok2 {
@ -251,6 +254,7 @@ func (s *statsCtx) periodicFlush() {
_ = tx.Rollback()
}
}
log.Tracef("periodicFlush() exited")
}
@ -265,7 +269,7 @@ func (s *statsCtx) deleteUnit(tx *bolt.Tx, id uint32) bool {
return true
}
func convertMapToArray(m map[string]uint64, max int) []countPair {
func convertMapToSlice(m map[string]uint64, max int) []countPair {
a := []countPair{}
for k, v := range m {
pair := countPair{}
@ -283,7 +287,7 @@ func convertMapToArray(m map[string]uint64, max int) []countPair {
return a[:max]
}
func convertArrayToMap(a []countPair) map[string]uint64 {
func convertSliceToMap(a []countPair) map[string]uint64 {
m := map[string]uint64{}
for _, it := range a {
m[it.Name] = it.Count
@ -301,9 +305,9 @@ func serialize(u *unit) *unitDB {
udb.TimeAvg = uint32(u.timeSum / u.nTotal)
}
udb.Domains = convertMapToArray(u.domains, maxDomains)
udb.BlockedDomains = convertMapToArray(u.blockedDomains, maxDomains)
udb.Clients = convertMapToArray(u.clients, maxClients)
udb.Domains = convertMapToSlice(u.domains, maxDomains)
udb.BlockedDomains = convertMapToSlice(u.blockedDomains, maxDomains)
udb.Clients = convertMapToSlice(u.clients, maxClients)
return &udb
}
@ -319,9 +323,9 @@ func deserialize(u *unit, udb *unitDB) {
u.nResult[i] = udb.NResult[i]
}
u.domains = convertArrayToMap(udb.Domains)
u.blockedDomains = convertArrayToMap(udb.BlockedDomains)
u.clients = convertArrayToMap(udb.Clients)
u.domains = convertSliceToMap(udb.Domains)
u.blockedDomains = convertSliceToMap(udb.BlockedDomains)
u.clients = convertSliceToMap(udb.Clients)
u.timeSum = uint64(udb.TimeAvg) * u.nTotal
}
@ -372,7 +376,7 @@ func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB {
return &udb
}
func convertTopArray(a []countPair) []map[string]uint64 {
func convertTopSlice(a []countPair) []map[string]uint64 {
m := []map[string]uint64{}
for _, it := range a {
ent := map[string]uint64{}
@ -461,13 +465,20 @@ func (s *statsCtx) getClientIP(ip net.IP) (clientIP net.IP) {
func (s *statsCtx) Update(e Entry) {
if e.Result == 0 ||
e.Result >= rLast ||
len(e.Domain) == 0 ||
!(len(e.Client) == 4 || len(e.Client) == 16) {
e.Domain == "" ||
e.Client == "" {
return
}
client := s.getClientIP(e.Client)
clientID := e.Client
if ip := net.ParseIP(clientID); ip != nil {
ip = s.getClientIP(ip)
clientID = ip.String()
}
s.unitLock.Lock()
defer s.unitLock.Unlock()
u := s.unit
u.nResult[e.Result]++
@ -478,10 +489,9 @@ func (s *statsCtx) Update(e Entry) {
u.blockedDomains[e.Domain]++
}
u.clients[client.String()]++
u.clients[clientID]++
u.timeSum += uint64(e.Time)
u.nTotal++
s.unitLock.Unlock()
}
func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) {
@ -594,8 +604,8 @@ func (s *statsCtx) getData() (statsResponse, bool) {
m[it.Name] += it.Count
}
}
a2 := convertMapToArray(m, max)
return convertTopArray(a2)
a2 := convertMapToSlice(m, max)
return convertTopSlice(a2)
}
dnsQueries := statsCollector(func(u *unitDB) (num uint64) { return u.NTotal })
@ -661,7 +671,7 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
m[it.Name] += it.Count
}
}
a := convertMapToArray(m, int(maxCount))
a := convertMapToSlice(m, int(maxCount))
d := []net.IP{}
for _, it := range a {
d = append(d, net.ParseIP(it.Name))

View File

@ -4,6 +4,11 @@
## v0.105: API changes
### New `"dnscrypt"` `"client_proto"` value in `GET /querylog` response
* The field `"client_proto"` can now have the value `"dnscrypt"` when the
request was sent over a DNSCrypt connection.
### New `"reason"` in `GET /filtering/check_host` and `GET /querylog`
* The new `RewriteRule` reason is added to `GET /filtering/check_host` and

View File

@ -794,11 +794,17 @@
'tags':
- 'clients'
'operationId': 'clientsFind'
'summary': 'Get information about selected clients by their IP address'
'summary': >
Get information about clients by their IP addresses or client IDs.
'parameters':
- 'name': 'ip0'
'in': 'query'
'description': 'Filter by IP address'
'description': >
Filter by IP address or client IDs. Parameters with names `ip1`,
`ip2`, and so on are also accepted and interpreted as "ip0 OR ip1 OR
ip2".
TODO(a.garipov): Replace with a better query API.
'schema':
'type': 'string'
'responses':
@ -1109,6 +1115,13 @@
'name': 'host'
'schema':
'type': 'string'
- 'description': >
Client ID.
'example': 'client-1'
'in': 'query'
'name': 'client_id'
'schema':
'type': 'string'
'responses':
'200':
'description': 'DNS over HTTPS plist file.'
@ -1136,6 +1149,13 @@
'name': 'host'
'schema':
'type': 'string'
- 'description': >
Client ID.
'example': 'client-1'
'in': 'query'
'name': 'client_id'
'schema':
'type': 'string'
'responses':
'200':
'description': 'DNS over TLS plist file'
@ -1781,13 +1801,21 @@
'answer_dnssec':
'type': 'boolean'
'client':
'type': 'string'
'description': >
The client's IP address.
'example': '192.168.0.1'
'type': 'string'
'client_id':
'description': >
The client ID, if provided in DOH, DOQ, or DOT.
'example': 'cli123'
'type': 'string'
'client_proto':
'enum':
- 'dot'
- 'doh'
- 'doq'
- 'dnscrypt'
- ''
'elapsedMs':
'type': 'string'
@ -2094,7 +2122,7 @@
'type': 'string'
'Client':
'type': 'object'
'description': 'Client information'
'description': 'Client information.'
'properties':
'name':
'type': 'string'
@ -2102,7 +2130,7 @@
'example': 'localhost'
'ids':
'type': 'array'
'description': 'IP, CIDR or MAC address'
'description': 'IP, CIDR, MAC, or client ID.'
'items':
'type': 'string'
'use_global_settings':
@ -2157,9 +2185,38 @@
'type': 'string'
'ClientsFindResponse':
'type': 'array'
'description': 'Response to clients find operation'
'description': 'Client search results.'
'items':
'$ref': '#/components/schemas/ClientsFindEntry'
'example':
- 'cli42':
'name': 'Client 42'
'ids': ['cli42']
'use_global_settings': true
'filtering_enabled': true
'parental_enabled': true
'safebrowsing_enabled': true
'safesearch_enabled': true
'use_global_blocked_services': true
'blocked_services': null
'upstreams': null
'whois_info': null
'disallowed': false
'disallowed_rule': ''
- '1.2.3.4':
'name': 'Client 1-2-3-4'
'ids': ['1.2.3.4']
'use_global_settings': true
'filtering_enabled': true
'parental_enabled': true
'safebrowsing_enabled': true
'safesearch_enabled': true
'use_global_blocked_services': true
'blocked_services': null
'upstreams': null
'whois_info': null
'disallowed': false
'disallowed_rule': ''
'AccessListResponse':
'$ref': '#/components/schemas/AccessList'
'AccessSetRequest':
@ -2187,10 +2244,9 @@
'type': 'object'
'additionalProperties':
'$ref': '#/components/schemas/ClientFindSubEntry'
'example':
'1.2.3.4': 'test'
'ClientFindSubEntry':
'type': 'object'
'description': 'Client information.'
'properties':
'name':
'type': 'string'
@ -2198,7 +2254,7 @@
'example': 'localhost'
'ids':
'type': 'array'
'description': 'IP, CIDR or MAC address'
'description': 'IP, CIDR, MAC, or client ID.'
'items':
'type': 'string'
'use_global_settings':

View File

@ -54,7 +54,7 @@ esac
# TODO(a.garipov): Additional validation?
version="$VERSION"
# Set the linker flags accordingly: set the realease channel and the
# Set the linker flags accordingly: set the release channel and the
# current version as well as goarm and gomips variable values, if the
# variables are set and are not empty.
readonly version_pkg='github.com/AdguardTeam/AdGuardHome/internal/version'