diff --git a/.gitignore b/.gitignore index e22df4e9..5cfd4889 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ debug /AdGuardHome /AdGuardHome.yaml +/data/ /build/ /client/node_modules/ /coredns diff --git a/app.go b/app.go index a317aa29..647ffa63 100644 --- a/app.go +++ b/app.go @@ -25,10 +25,18 @@ func main() { if err != nil { panic(err) } - config.ourBinaryDir = filepath.Dir(executable) - } - doConfigRename := true + executableName := filepath.Base(executable) + if executableName == "AdGuardHome" { + // Binary build + config.ourBinaryDir = filepath.Dir(executable) + } else { + // Most likely we're debugging -- using current working directory in this case + workDir, _ := os.Getwd() + config.ourBinaryDir = workDir + } + log.Printf("Current working directory is %s", config.ourBinaryDir) + } // config can be specified, which reads options from there, but other command line flags have to override config values // therefore, we must do it manually instead of using a lib @@ -98,18 +106,9 @@ func main() { } } if configFilename != nil { - // config was manually specified, don't do anything - doConfigRename = false config.ourConfigFilename = *configFilename } - if doConfigRename { - err := renameOldConfigIfNeccessary() - if err != nil { - panic(err) - } - } - err := askUsernamePasswordIfPossible() if err != nil { log.Fatal(err) @@ -120,6 +119,8 @@ func main() { if err != nil { log.Fatal(err) } + + // override bind host/port from the console if bindHost != nil { config.BindHost = *bindHost } @@ -128,19 +129,36 @@ func main() { } } - // eat all args so that coredns can start happily + // Eat all args so that coredns can start happily if len(os.Args) > 1 { os.Args = os.Args[:1] } - err := writeConfig() + // Do the upgrade if necessary + err := upgradeConfig() if err != nil { log.Fatal(err) } + // Save the updated config + err = writeConfig() + if err != nil { + log.Fatal(err) + } + + // Load filters from the disk + for i := range config.Filters { + filter := &config.Filters[i] + err = filter.load() + if err != nil { + // This is okay for the first start, the filter will be loaded later + log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err) + } + } + address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) - runFilterRefreshers() + runFiltersUpdatesTimer() http.Handle("/", optionalAuthHandler(http.FileServer(box))) registerControlHandlers() @@ -240,27 +258,79 @@ func askUsernamePasswordIfPossible() error { return nil } -func renameOldConfigIfNeccessary() error { - oldConfigFile := filepath.Join(config.ourBinaryDir, "AdguardDNS.yaml") - _, err := os.Stat(oldConfigFile) - if os.IsNotExist(err) { - // do nothing, file doesn't exist - trace("File %s doesn't exist, nothing to do", oldConfigFile) +// Performs necessary upgrade operations if needed +func upgradeConfig() error { + + if config.SchemaVersion == SchemaVersion { + // No upgrade, do nothing return nil } - newConfigFile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) - _, err = os.Stat(newConfigFile) - if !os.IsNotExist(err) { - // do nothing, file doesn't exist - trace("File %s already exists, will not overwrite", newConfigFile) - return nil + if config.SchemaVersion > SchemaVersion { + // Unexpected -- the config file is newer than we expect + return fmt.Errorf("configuration file is supposed to be used with a newer version of AdGuard Home, schema=%d", config.SchemaVersion) } - err = os.Rename(oldConfigFile, newConfigFile) - if err != nil { - log.Printf("Failed to rename %s to %s: %s", oldConfigFile, newConfigFile, err) - return err + // Perform upgrade operations for each consecutive version upgrade + for oldVersion, newVersion := config.SchemaVersion, config.SchemaVersion+1; newVersion <= SchemaVersion; { + + err := upgradeConfigSchema(oldVersion, newVersion) + if err != nil { + log.Fatal(err) + } + + // Increment old and new versions + oldVersion++ + newVersion++ + } + + // Save the current schema version + config.SchemaVersion = SchemaVersion + + return nil +} + +// Upgrade from oldVersion to newVersion +func upgradeConfigSchema(oldVersion int, newVersion int) error { + + if oldVersion == 0 && newVersion == 1 { + log.Printf("Updating schema from %d to %d", oldVersion, newVersion) + + // The first schema upgrade: + // Added "ID" field to "filter" -- we need to populate this field now + // Added "config.ourDataDir" -- where we will now store filters contents + for i := range config.Filters { + + filter := &config.Filters[i] // otherwise we will be operating on a copy + + // Set the filter ID + log.Printf("Seting ID=%d for filter %s", NextFilterId, filter.URL) + filter.ID = NextFilterId + NextFilterId++ + + // Forcibly update the filter + _, err := filter.update(true) + if err != nil { + log.Fatal(err) + } + + // Saving it to the filters dir now + err = filter.save() + if err != nil { + log.Fatal(err) + } + } + + // No more "dnsfilter.txt", filters are now loaded from config.ourDataDir/filters/ + dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt") + _, err := os.Stat(dnsFilterPath) + if !os.IsNotExist(err) { + log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath) + err = os.Remove(dnsFilterPath) + if err != nil { + log.Printf("Cannot remove %s due to %s", dnsFilterPath, err) + } + } } return nil diff --git a/client/src/components/Logs/index.js b/client/src/components/Logs/index.js index b533f190..23409f6e 100644 --- a/client/src/components/Logs/index.js +++ b/client/src/components/Logs/index.js @@ -10,7 +10,7 @@ import { getTrackerData } from '../../helpers/trackers/trackers'; import PageTitle from '../ui/PageTitle'; import Card from '../ui/Card'; import Loading from '../ui/Loading'; -import Tooltip from '../ui/Tooltip'; +import PopoverFiltered from '../ui/PopoverFilter'; import Popover from '../ui/Popover'; import './Logs.css'; @@ -36,9 +36,9 @@ class Logs extends Component { } } - renderTooltip(isFiltered, rule) { + renderTooltip(isFiltered, rule, filter) { if (rule) { - return (isFiltered && ); + return (isFiltered && ); } return ''; } @@ -117,14 +117,27 @@ class Logs extends Component { const isFiltered = row ? reason.indexOf('Filtered') === 0 : false; const parsedFilteredReason = reason.replace('Filtered', 'Filtered by '); const rule = row && row.original && row.original.rule; + const { filterId } = row.original; + const { filters } = this.props.filtering; + let filterName = ''; + + if (reason === 'FilteredBlackList' || reason === 'NotFilteredWhiteList') { + if (filterId === 0) { + filterName = 'Custom filtering rules'; + } else { + const filterItem = Object.keys(filters) + .filter(key => filters[key].id === filterId); + filterName = filters[filterItem].name; + } + } if (isFiltered) { return (
- {this.renderTooltip(isFiltered, rule)} {parsedFilteredReason} + {this.renderTooltip(isFiltered, rule, filterName)}
); } @@ -132,17 +145,19 @@ class Logs extends Component { if (responses.length > 0) { const liNodes = responses.map((response, index) => (
  • {response}
  • )); + const isRenderTooltip = reason === 'NotFilteredWhiteList'; + return (
    - {this.renderTooltip(isFiltered, rule)}
      {liNodes}
    + {this.renderTooltip(isRenderTooltip, rule, filterName)}
    ); } return (
    - {this.renderTooltip(isFiltered, rule)} Empty + {this.renderTooltip(isFiltered, rule, filterName)}
    ); }, @@ -208,8 +223,19 @@ class Logs extends Component { if (!rowInfo) { return {}; } + + if (rowInfo.original.reason.indexOf('Filtered') === 0) { + return { + className: 'red', + }; + } else if (rowInfo.original.reason === 'NotFilteredWhiteList') { + return { + className: 'green', + }; + } + return { - className: (rowInfo.original.reason.indexOf('Filtered') === 0 ? 'red' : ''), + className: '', }; }} />); diff --git a/client/src/components/ui/Popover.css b/client/src/components/ui/Popover.css index cdaf5fe6..dc83e4cb 100644 --- a/client/src/components/ui/Popover.css +++ b/client/src/components/ui/Popover.css @@ -1,7 +1,9 @@ .popover-wrap { position: relative; + top: 1px; display: inline-block; vertical-align: middle; + align-self: flex-start; } .popover__trigger { @@ -24,9 +26,9 @@ content: ""; display: flex; position: absolute; - min-width: 275px; bottom: calc(100% + 3px); left: 50%; + min-width: 275px; padding: 10px 15px; font-size: 0.8rem; white-space: normal; @@ -39,6 +41,10 @@ opacity: 0; } +.popover__body--filter { + min-width: 100%; +} + .popover__body:after { content: ""; position: absolute; @@ -63,6 +69,10 @@ stroke: #9aa0ac; } +.popover__icon--green { + stroke: #66b574; +} + .popover__list-title { margin-bottom: 3px; } @@ -71,6 +81,13 @@ margin-bottom: 2px; } +.popover__list-item--nowrap { + max-width: 300px; + text-overflow: ellipsis; + white-space: nowrap; + overflow: hidden; +} + .popover__list-item:last-child { margin-bottom: 0; } diff --git a/client/src/components/ui/PopoverFilter.js b/client/src/components/ui/PopoverFilter.js new file mode 100644 index 00000000..9eb041dc --- /dev/null +++ b/client/src/components/ui/PopoverFilter.js @@ -0,0 +1,33 @@ +import React, { Component } from 'react'; +import PropTypes from 'prop-types'; + +import './Popover.css'; + +class PopoverFilter extends Component { + render() { + return ( +
    +
    + +
    +
    +
    +
    + Rule: {this.props.rule} +
    + {this.props.filter &&
    + Filter: {this.props.filter} +
    } +
    +
    +
    + ); + } +} + +PopoverFilter.propTypes = { + rule: PropTypes.string.isRequired, + filter: PropTypes.string, +}; + +export default PopoverFilter; diff --git a/client/src/components/ui/ReactTable.css b/client/src/components/ui/ReactTable.css index bdcf3576..e1de27b4 100644 --- a/client/src/components/ui/ReactTable.css +++ b/client/src/components/ui/ReactTable.css @@ -11,3 +11,7 @@ .rt-tr-group .red { background-color: #fff4f2; } + +.rt-tr-group .green { + background-color: #f1faf3; +} diff --git a/client/src/helpers/helpers.js b/client/src/helpers/helpers.js index 07c343bc..870d320f 100644 --- a/client/src/helpers/helpers.js +++ b/client/src/helpers/helpers.js @@ -18,6 +18,7 @@ export const normalizeLogs = logs => logs.map((log) => { answer: response, reason, client, + filterId, rule, } = log; const { host: domain, type } = question; @@ -32,6 +33,7 @@ export const normalizeLogs = logs => logs.map((log) => { response: responsesArray, reason, client, + filterId, rule, }; }); @@ -64,11 +66,11 @@ export const normalizeFilteringStatus = (filteringStatus) => { const { enabled, filters, user_rules: userRules } = filteringStatus; const newFilters = filters ? filters.map((filter) => { const { - url, enabled, last_updated: lastUpdated = Date.now(), name = 'Default name', rules_count: rulesCount = 0, + id, url, enabled, lastUpdated: lastUpdated = Date.now(), name = 'Default name', rulesCount: rulesCount = 0, } = filter; return { - url, enabled, lastUpdated: formatTime(lastUpdated), name, rulesCount, + id, url, enabled, lastUpdated: formatTime(lastUpdated), name, rulesCount, }; }) : []; const newUserRules = Array.isArray(userRules) ? userRules.join('\n') : ''; diff --git a/config.go b/config.go index 09e89ee1..bf63c0ee 100644 --- a/config.go +++ b/config.go @@ -14,48 +14,73 @@ import ( "gopkg.in/yaml.v2" ) +// Current schema version. We compare it with the value from +// the configuration file and perform necessary upgrade operations if needed +const SchemaVersion = 1 + +// Directory where we'll store all downloaded filters contents +const FiltersDir = "filters" + +// User filter ID is always 0 +const UserFilterId = 0 + +// Just a counter that we use for incrementing the filter ID +var NextFilterId = time.Now().Unix() + // configuration is loaded from YAML type configuration struct { + // Config filename (can be overriden via the command line arguments) ourConfigFilename string - ourBinaryDir string + // Basically, this is our working directory + ourBinaryDir string + // Directory to store data (i.e. filters contents) + ourDataDir string - BindHost string `yaml:"bind_host"` - BindPort int `yaml:"bind_port"` - AuthName string `yaml:"auth_name"` - AuthPass string `yaml:"auth_pass"` - CoreDNS coreDNSConfig `yaml:"coredns"` - Filters []filter `yaml:"filters"` - UserRules []string `yaml:"user_rules"` + // Schema version of the config file. This value is used when performing the app updates. + SchemaVersion int `yaml:"schema_version"` + BindHost string `yaml:"bind_host"` + BindPort int `yaml:"bind_port"` + AuthName string `yaml:"auth_name"` + AuthPass string `yaml:"auth_pass"` + CoreDNS coreDNSConfig `yaml:"coredns"` + Filters []filter `yaml:"filters"` + UserRules []string `yaml:"user_rules"` sync.RWMutex `yaml:"-"` } +type coreDnsFilter struct { + ID int64 `yaml:"-"` + Path string `yaml:"-"` +} + type coreDNSConfig struct { binaryFile string coreFile string - FilterFile string `yaml:"-"` - Port int `yaml:"port"` - ProtectionEnabled bool `yaml:"protection_enabled"` - FilteringEnabled bool `yaml:"filtering_enabled"` - SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` - SafeSearchEnabled bool `yaml:"safesearch_enabled"` - ParentalEnabled bool `yaml:"parental_enabled"` - ParentalSensitivity int `yaml:"parental_sensitivity"` - BlockedResponseTTL int `yaml:"blocked_response_ttl"` - QueryLogEnabled bool `yaml:"querylog_enabled"` - Pprof string `yaml:"-"` - Cache string `yaml:"-"` - Prometheus string `yaml:"-"` - UpstreamDNS []string `yaml:"upstream_dns"` + Filters []coreDnsFilter `yaml:"-"` + Port int `yaml:"port"` + ProtectionEnabled bool `yaml:"protection_enabled"` + FilteringEnabled bool `yaml:"filtering_enabled"` + SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` + SafeSearchEnabled bool `yaml:"safesearch_enabled"` + ParentalEnabled bool `yaml:"parental_enabled"` + ParentalSensitivity int `yaml:"parental_sensitivity"` + BlockedResponseTTL int `yaml:"blocked_response_ttl"` + QueryLogEnabled bool `yaml:"querylog_enabled"` + Pprof string `yaml:"-"` + Cache string `yaml:"-"` + Prometheus string `yaml:"-"` + UpstreamDNS []string `yaml:"upstream_dns"` } type filter struct { + ID int64 `json:"id" yaml:"id"` // auto-assigned when filter is added (see NextFilterId) URL string `json:"url"` Name string `json:"name" yaml:"name"` Enabled bool `json:"enabled"` - RulesCount int `json:"rules_count" yaml:"-"` + RulesCount int `json:"rulesCount" yaml:"-"` contents []byte - LastUpdated time.Time `json:"last_updated" yaml:"-"` + LastUpdated time.Time `json:"lastUpdated" yaml:"last_updated"` } var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} @@ -63,13 +88,13 @@ var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} // initialize to default values, will be changed later when reading config or parsing command line var config = configuration{ ourConfigFilename: "AdGuardHome.yaml", + ourDataDir: "data", BindPort: 3000, BindHost: "127.0.0.1", CoreDNS: coreDNSConfig{ Port: 53, - binaryFile: "coredns", // only filename, no path - coreFile: "Corefile", // only filename, no path - FilterFile: "dnsfilter.txt", // only filename, no path + binaryFile: "coredns", // only filename, no path + coreFile: "Corefile", // only filename, no path ProtectionEnabled: true, FilteringEnabled: true, SafeBrowsingEnabled: false, @@ -80,22 +105,43 @@ var config = configuration{ Prometheus: "prometheus :9153", }, Filters: []filter{ - {Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt"}, - {Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, - {Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, - {Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, + {ID: 1, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, + {ID: 2, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, + {ID: 3, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, + {ID: 4, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, }, } +// Creates a helper object for working with the user rules +func getUserFilter() filter { + + // TODO: This should be calculated when UserRules are set + var contents []byte + for _, rule := range config.UserRules { + contents = append(contents, []byte(rule)...) + contents = append(contents, '\n') + } + + userFilter := filter{ + // User filter always has constant ID=0 + ID: UserFilterId, + contents: contents, + Enabled: true, + } + + return userFilter +} + +// Loads configuration from the YAML file func parseConfig() error { - configfile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) - log.Printf("Reading YAML file: %s", configfile) - if _, err := os.Stat(configfile); os.IsNotExist(err) { + configFile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) + log.Printf("Reading YAML file: %s", configFile) + if _, err := os.Stat(configFile); os.IsNotExist(err) { // do nothing, file doesn't exist - log.Printf("YAML file doesn't exist, skipping: %s", configfile) + log.Printf("YAML file doesn't exist, skipping: %s", configFile) return nil } - yamlFile, err := ioutil.ReadFile(configfile) + yamlFile, err := ioutil.ReadFile(configFile) if err != nil { log.Printf("Couldn't read config file: %s", err) return err @@ -106,27 +152,54 @@ func parseConfig() error { return err } + // Deduplicate filters + { + i := 0 // output index, used for deletion later + urls := map[string]bool{} + for _, filter := range config.Filters { + if _, ok := urls[filter.URL]; !ok { + // we didn't see it before, keep it + urls[filter.URL] = true // remember the URL + config.Filters[i] = filter + i++ + } + } + // all entries we want to keep are at front, delete the rest + config.Filters = config.Filters[:i] + } + + // Set the next filter ID to max(filter.ID) + 1 + for i := range config.Filters { + if NextFilterId < config.Filters[i].ID { + NextFilterId = config.Filters[i].ID + 1 + } + } + return nil } +// Saves configuration to the YAML file and also saves the user filter contents to a file func writeConfig() error { - configfile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) - log.Printf("Writing YAML file: %s", configfile) + configFile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) + log.Printf("Writing YAML file: %s", configFile) yamlText, err := yaml.Marshal(&config) if err != nil { log.Printf("Couldn't generate YAML file: %s", err) return err } - err = ioutil.WriteFile(configfile+".tmp", yamlText, 0644) + err = writeFileSafe(configFile, yamlText) if err != nil { - log.Printf("Couldn't write YAML config: %s", err) + log.Printf("Couldn't save YAML config: %s", err) return err } - err = os.Rename(configfile+".tmp", configfile) + + userFilter := getUserFilter() + err = userFilter.save() if err != nil { - log.Printf("Couldn't rename YAML config: %s", err) + log.Printf("Couldn't save the user filter: %s", err) return err } + return nil } @@ -134,22 +207,19 @@ func writeConfig() error { // coredns config // -------------- func writeCoreDNSConfig() error { - corefile := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile) - log.Printf("Writing DNS config: %s", corefile) - configtext, err := generateCoreDNSConfigText() + coreFile := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile) + log.Printf("Writing DNS config: %s", coreFile) + configText, err := generateCoreDNSConfigText() if err != nil { log.Printf("Couldn't generate DNS config: %s", err) return err } - err = ioutil.WriteFile(corefile+".tmp", []byte(configtext), 0644) + err = writeFileSafe(coreFile, []byte(configText)) if err != nil { - log.Printf("Couldn't write DNS config: %s", err) + log.Printf("Couldn't save DNS config: %s", err) + return err } - err = os.Rename(corefile+".tmp", corefile) - if err != nil { - log.Printf("Couldn't rename DNS config: %s", err) - } - return err + return nil } func writeAllConfigs() error { @@ -167,12 +237,17 @@ func writeAllConfigs() error { } const coreDNSConfigTemplate = `.:{{.Port}} { - {{if .ProtectionEnabled}}dnsfilter {{if .FilteringEnabled}}{{.FilterFile}}{{end}} { + {{if .ProtectionEnabled}}dnsfilter { {{if .SafeBrowsingEnabled}}safebrowsing{{end}} {{if .ParentalEnabled}}parental {{.ParentalSensitivity}}{{end}} {{if .SafeSearchEnabled}}safesearch{{end}} {{if .QueryLogEnabled}}querylog{{end}} blocked_ttl {{.BlockedResponseTTL}} + {{if .FilteringEnabled}} + {{range .Filters}} + filter {{.ID}} "{{.Path}}" + {{end}} + {{end}} }{{end}} {{.Pprof}} hosts { @@ -186,7 +261,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} { var removeEmptyLines = regexp.MustCompile("([\t ]*\n)+") -// generate config text +// generate CoreDNS config text func generateCoreDNSConfigText() (string, error) { t, err := template.New("config").Parse(coreDNSConfigTemplate) if err != nil { @@ -196,16 +271,36 @@ func generateCoreDNSConfigText() (string, error) { var configBytes bytes.Buffer temporaryConfig := config.CoreDNS - temporaryConfig.FilterFile = filepath.Join(config.ourBinaryDir, config.CoreDNS.FilterFile) + + // fill the list of filters + filters := make([]coreDnsFilter, 0) + + // first of all, append the user filter + userFilter := getUserFilter() + + if len(userFilter.contents) > 0 { + filters = append(filters, coreDnsFilter{ID: userFilter.ID, Path: userFilter.getFilterFilePath()}) + } + + // then go through other filters + for i := range config.Filters { + filter := &config.Filters[i] + + if filter.Enabled && len(filter.contents) > 0 { + filters = append(filters, coreDnsFilter{ID: filter.ID, Path: filter.getFilterFilePath()}) + } + } + temporaryConfig.Filters = filters + // run the template err = t.Execute(&configBytes, &temporaryConfig) if err != nil { log.Printf("Couldn't generate DNS config: %s", err) return "", err } - configtext := configBytes.String() + configText := configBytes.String() // remove empty lines from generated config - configtext = removeEmptyLines.ReplaceAllString(configtext, "\n") - return configtext, nil + configText = removeEmptyLines.ReplaceAllString(configText, "\n") + return configText, nil } diff --git a/control.go b/control.go index 1afd4e24..1e1084e8 100644 --- a/control.go +++ b/control.go @@ -15,15 +15,14 @@ import ( "strings" "time" - coredns_plugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" + corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin" "github.com/miekg/dns" "gopkg.in/asaskevich/govalidator.v4" ) const updatePeriod = time.Minute * 30 -var filterTitle = regexp.MustCompile(`^! Title: +(.*)$`) +var filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) // cached version.json to avoid hammering github.io for each page reload var versionCheckJSON []byte @@ -40,7 +39,7 @@ var client = &http.Client{ // coredns run control // ------------------- func tellCoreDNSToReload() { - coredns_plugin.Reload <- true + corednsplugin.Reload <- true } func writeAllConfigsAndReloadCoreDNS() error { @@ -64,6 +63,7 @@ func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) { returnOK(w, r) } +//noinspection GoUnusedParameter func returnOK(w http.ResponseWriter, r *http.Request) { _, err := fmt.Fprintf(w, "OK\n") if err != nil { @@ -73,6 +73,7 @@ func returnOK(w http.ResponseWriter, r *http.Request) { } } +//noinspection GoUnusedParameter func handleStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ "dns_address": config.BindHost, @@ -237,7 +238,7 @@ func checkDNS(input string) error { resp, rtt, err := c.Exchange(&req, host) if err != nil { - return fmt.Errorf("Couldn't communicate with DNS server %s: %s", input, err) + return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) } trace("exchange with %s took %v", input, rtt) if len(resp.Answer) != 1 { @@ -254,7 +255,7 @@ func checkDNS(input string) error { func sanitiseDNSServers(input string) ([]string, error) { fields := strings.Fields(input) - hosts := []string{} + hosts := make([]string, 0) for _, field := range fields { sanitized, err := sanitizeDNSServer(field) if err != nil { @@ -292,7 +293,7 @@ func sanitizeDNSServer(input string) (string, error) { } ip := net.ParseIP(h) if ip == nil { - return "", fmt.Errorf("Invalid DNS server field: %s", h) + return "", fmt.Errorf("invalid DNS server field: %s", h) } } return prefix + host, nil @@ -311,6 +312,7 @@ func appendPortIfMissing(prefix, input string) string { return net.JoinHostPort(input, port) } +//noinspection GoUnusedParameter func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { now := time.Now() if now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0 { @@ -366,6 +368,7 @@ func handleFilteringDisable(w http.ResponseWriter, r *http.Request) { httpUpdateConfigReloadDNSReturnOK(w, r) } +//noinspection GoUnusedParameter func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ "enabled": config.CoreDNS.FilteringEnabled, @@ -395,6 +398,7 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { } func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { + filter := filter{} err := json.NewDecoder(r.Body).Decode(&filter) if err != nil { @@ -402,7 +406,6 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { return } - filter.Enabled = true if len(filter.URL) == 0 { http.Error(w, "URL parameter was not specified", 400) return @@ -413,33 +416,48 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { return } - // check for duplicates + // Check for duplicates for i := range config.Filters { if config.Filters[i].URL == filter.URL { - errortext := fmt.Sprintf("Filter URL already added -- %s", filter.URL) - log.Println(errortext) - http.Error(w, errortext, http.StatusBadRequest) + errorText := fmt.Sprintf("Filter URL already added -- %s", filter.URL) + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) return } } - ok, err := filter.update(time.Now()) + // Set necessary properties + filter.ID = NextFilterId + filter.Enabled = true + NextFilterId++ + + // Download the filter contents + ok, err := filter.update(true) if err != nil { - errortext := fmt.Sprintf("Couldn't fetch filter from url %s: %s", filter.URL, err) - log.Println(errortext) - http.Error(w, errortext, http.StatusBadRequest) + errorText := fmt.Sprintf("Couldn't fetch filter from url %s: %s", filter.URL, err) + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) return } if filter.RulesCount == 0 { - errortext := fmt.Sprintf("Filter at url %s has no rules (maybe it points to blank page?)", filter.URL) - log.Println(errortext) - http.Error(w, errortext, http.StatusBadRequest) + errorText := fmt.Sprintf("Filter at the url %s has no rules (maybe it points to blank page?)", filter.URL) + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) return } if !ok { - errortext := fmt.Sprintf("Filter at url %s is invalid (maybe it points to blank page?)", filter.URL) - log.Println(errortext) - http.Error(w, errortext, http.StatusBadRequest) + errorText := fmt.Sprintf("Filter at the url %s is invalid (maybe it points to blank page?)", filter.URL) + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) + return + } + + // Save the filter contents + err = filter.save() + if err != nil { + errorText := fmt.Sprintf("Failed to save filter %d due to %s", filter.ID, err) + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) return } @@ -447,33 +465,28 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { config.Filters = append(config.Filters, filter) err = writeAllConfigs() if err != nil { - errortext := fmt.Sprintf("Couldn't write config file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return - } - err = writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write config file: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) return } + tellCoreDNSToReload() + _, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount) if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) } } func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { parameters, err := parseParametersFromBody(r.Body) if err != nil { - errortext := fmt.Sprintf("failed to parse parameters from body: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) + errorText := fmt.Sprintf("failed to parse parameters from body: %s", err) + log.Println(errorText) + http.Error(w, errorText, 400) return } @@ -493,25 +506,27 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { for _, filter := range config.Filters { if filter.URL != url { newFilters = append(newFilters, filter) + } else { + // Remove the filter file + err := os.Remove(filter.getFilterFilePath()) + if err != nil { + errorText := fmt.Sprintf("Couldn't remove the filter file: %s", err) + http.Error(w, errorText, http.StatusInternalServerError) + return + } } } + // Update the configuration after removing filter files config.Filters = newFilters - err = writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return - } httpUpdateConfigReloadDNSReturnOK(w, r) } func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) { parameters, err := parseParametersFromBody(r.Body) if err != nil { - errortext := fmt.Sprintf("failed to parse parameters from body: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) + errorText := fmt.Sprintf("failed to parse parameters from body: %s", err) + log.Println(errorText) + http.Error(w, errorText, 400) return } @@ -541,23 +556,16 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) { } // kick off refresh of rules from new URLs - refreshFiltersIfNeccessary() - err = writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return - } + checkFiltersUpdates(false) httpUpdateConfigReloadDNSReturnOK(w, r) } func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) { parameters, err := parseParametersFromBody(r.Body) if err != nil { - errortext := fmt.Sprintf("failed to parse parameters from body: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) + errorText := fmt.Sprintf("failed to parse parameters from body: %s", err) + log.Println(errorText) + http.Error(w, errorText, 400) return } @@ -586,116 +594,108 @@ func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) { return } - err = writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return - } httpUpdateConfigReloadDNSReturnOK(w, r) } func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - errortext := fmt.Sprintf("Failed to read request body: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) + errorText := fmt.Sprintf("Failed to read request body: %s", err) + log.Println(errorText) + http.Error(w, errorText, 400) return } config.UserRules = strings.Split(string(body), "\n") - err = writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return - } httpUpdateConfigReloadDNSReturnOK(w, r) } func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { force := r.URL.Query().Get("force") - if force != "" { - config.Lock() - for i := range config.Filters { - filter := &config.Filters[i] // otherwise we will be operating on a copy - filter.LastUpdated = time.Unix(0, 0) - } - config.Unlock() // not defer because refreshFiltersIfNeccessary locks it too - } - updated := refreshFiltersIfNeccessary() + updated := checkFiltersUpdates(force != "") fmt.Fprintf(w, "OK %d filters updated\n", updated) } -func runFilterRefreshers() { +// Sets up a timer that will be checking for filters updates periodically +func runFiltersUpdatesTimer() { go func() { - for range time.Tick(time.Second) { - refreshFiltersIfNeccessary() + for range time.Tick(time.Minute) { + checkFiltersUpdates(false) } }() } -func refreshFiltersIfNeccessary() int { - now := time.Now() +// Checks filters updates if necessary +// If force is true, it ignores the filter.LastUpdated field value +func checkFiltersUpdates(force bool) int { config.Lock() - // deduplicate - // TODO: move it somewhere else - { - i := 0 // output index, used for deletion later - urls := map[string]bool{} - for _, filter := range config.Filters { - if _, ok := urls[filter.URL]; !ok { - // we didn't see it before, keep it - urls[filter.URL] = true // remember the URL - config.Filters[i] = filter - i++ - } - } - // all entries we want to keep are at front, delete the rest - config.Filters = config.Filters[:i] - } - // fetch URLs updateCount := 0 for i := range config.Filters { filter := &config.Filters[i] // otherwise we will be operating on a copy - updated, err := filter.update(now) + updated, err := filter.update(force) if err != nil { log.Printf("Failed to update filter %s: %s\n", filter.URL, err) continue } if updated { + // Saving it to the filters dir now + err = filter.save() + if err != nil { + log.Printf("Failed to save the updated filter %d: %s", filter.ID, err) + continue + } + updateCount++ } } config.Unlock() if updateCount > 0 { - err := writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - } tellCoreDNSToReload() } return updateCount } -func (filter *filter) update(now time.Time) (bool, error) { +// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any) +func parseFilterContents(contents []byte) (int, string) { + lines := strings.Split(string(contents), "\n") + rulesCount := 0 + name := "" + seenTitle := false + + // Count lines in the filter + for _, line := range lines { + line = strings.TrimSpace(line) + if len(line) > 0 && line[0] == '!' { + if m := filterTitleRegexp.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle { + name = m[0][1] + seenTitle = true + } + } else if len(line) != 0 { + rulesCount++ + } + } + + return rulesCount, name +} + +// Checks for filters updates +// If "force" is true -- does not check the filter's LastUpdated field +// Call "save" to persist the filter contents +func (filter *filter) update(force bool) (bool, error) { if !filter.Enabled { return false, nil } - elapsed := time.Since(filter.LastUpdated) - if elapsed <= updatePeriod { + if !force && time.Since(filter.LastUpdated) <= updatePeriod { return false, nil } - // use same update period for failed filter downloads to avoid flooding with requests - filter.LastUpdated = now + log.Printf("Downloading update for filter %d from %s", filter.ID, filter.URL) + + // use the same update period for failed filter downloads to avoid flooding with requests + filter.LastUpdated = time.Now() resp, err := client.Get(filter.URL) if resp != nil && resp.Body != nil { @@ -706,9 +706,15 @@ func (filter *filter) update(now time.Time) (bool, error) { return false, err } - if resp.StatusCode >= 400 { + if resp.StatusCode != 200 { log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL) - return false, fmt.Errorf("Got status code >= 400: %d", resp.StatusCode) + return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode) + } + + contentType := strings.ToLower(resp.Header.Get("content-type")) + if !strings.HasPrefix(contentType, "text/plain") { + log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL) + return false, fmt.Errorf("non-text response %s", contentType) } body, err := ioutil.ReadAll(resp.Body) @@ -717,74 +723,76 @@ func (filter *filter) update(now time.Time) (bool, error) { return false, err } - // extract filter name and count number of rules - lines := strings.Split(string(body), "\n") - rulesCount := 0 - seenTitle := false - d := dnsfilter.New() - for _, line := range lines { - line = strings.TrimSpace(line) - if len(line) > 0 && line[0] == '!' { - if m := filterTitle.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle { - filter.Name = m[0][1] - seenTitle = true - } - } else if len(line) != 0 { - err = d.AddRule(line, 0) - if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { - continue - } - if err != nil { - log.Printf("Cannot add rule %s from %s: %s", line, filter.URL, err) - // Just ignore invalid rules - continue - } - rulesCount++ - } + // Extract filter name and count number of rules + rulesCount, filterName := parseFilterContents(body) + + if filterName != "" { + filter.Name = filterName } + + // Check if the filter has been really changed if bytes.Equal(filter.contents, body) { + log.Printf("The filter %d text has not changed", filter.ID) return false, nil } - log.Printf("Filter %s updated: %d bytes, %d rules", filter.URL, len(body), rulesCount) + + log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount) filter.RulesCount = rulesCount filter.contents = body + return true, nil } -// write filter file -func writeFilterFile() error { - filterpath := filepath.Join(config.ourBinaryDir, config.CoreDNS.FilterFile) - log.Printf("Writing filter file: %s", filterpath) - // TODO: check if file contents have modified - data := []byte{} - config.RLock() - filters := config.Filters - for _, filter := range filters { - if !filter.Enabled { - continue - } - data = append(data, filter.contents...) - data = append(data, '\n') - } - for _, rule := range config.UserRules { - data = append(data, []byte(rule)...) - data = append(data, '\n') - } - config.RUnlock() - err := ioutil.WriteFile(filterpath+".tmp", data, 0644) +// saves filter contents to the file in config.ourDataDir +func (filter *filter) save() error { + + filterFilePath := filter.getFilterFilePath() + log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) + + err := writeFileSafe(filterFilePath, filter.contents) if err != nil { - log.Printf("Couldn't write filter file: %s", err) return err } - err = os.Rename(filterpath+".tmp", filterpath) - if err != nil { - log.Printf("Couldn't rename filter file: %s", err) + return nil +} + +// loads filter contents from the file in config.ourDataDir +func (filter *filter) load() error { + + if !filter.Enabled { + // No need to load a filter that is not enabled + return nil + } + + filterFilePath := filter.getFilterFilePath() + log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath) + + if _, err := os.Stat(filterFilePath); os.IsNotExist(err) { + // do nothing, file doesn't exist return err } + + filterFileContents, err := ioutil.ReadFile(filterFilePath) + if err != nil { + return err + } + + log.Printf("Filter %d length is %d", filter.ID, len(filterFileContents)) + filter.contents = filterFileContents + + // Now extract the rules count + rulesCount, _ := parseFilterContents(filter.contents) + filter.RulesCount = rulesCount + return nil } +// Path to the filter contents +func (filter *filter) getFilterFilePath() string { + return filepath.Join(config.ourBinaryDir, config.ourDataDir, FiltersDir, strconv.FormatInt(filter.ID, 10)+".txt") +} + // ------------ // safebrowsing // ------------ @@ -799,6 +807,7 @@ func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { httpUpdateConfigReloadDNSReturnOK(w, r) } +//noinspection GoUnusedParameter func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ "enabled": config.CoreDNS.SafeBrowsingEnabled, @@ -874,6 +883,7 @@ func handleParentalDisable(w http.ResponseWriter, r *http.Request) { httpUpdateConfigReloadDNSReturnOK(w, r) } +//noinspection GoUnusedParameter func handleParentalStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ "enabled": config.CoreDNS.ParentalEnabled, @@ -913,6 +923,7 @@ func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { httpUpdateConfigReloadDNSReturnOK(w, r) } +//noinspection GoUnusedParameter func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ "enabled": config.CoreDNS.SafeSearchEnabled, @@ -939,15 +950,15 @@ func registerControlHandlers() { http.HandleFunc("/control/status", optionalAuth(ensureGET(handleStatus))) http.HandleFunc("/control/enable_protection", optionalAuth(ensurePOST(handleProtectionEnable))) http.HandleFunc("/control/disable_protection", optionalAuth(ensurePOST(handleProtectionDisable))) - http.HandleFunc("/control/querylog", optionalAuth(ensureGET(coredns_plugin.HandleQueryLog))) + http.HandleFunc("/control/querylog", optionalAuth(ensureGET(corednsplugin.HandleQueryLog))) http.HandleFunc("/control/querylog_enable", optionalAuth(ensurePOST(handleQueryLogEnable))) http.HandleFunc("/control/querylog_disable", optionalAuth(ensurePOST(handleQueryLogDisable))) http.HandleFunc("/control/set_upstream_dns", optionalAuth(ensurePOST(handleSetUpstreamDNS))) http.HandleFunc("/control/test_upstream_dns", optionalAuth(ensurePOST(handleTestUpstreamDNS))) - http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(coredns_plugin.HandleStatsTop))) - http.HandleFunc("/control/stats", optionalAuth(ensureGET(coredns_plugin.HandleStats))) - http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(coredns_plugin.HandleStatsHistory))) - http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(coredns_plugin.HandleStatsReset))) + http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(corednsplugin.HandleStatsTop))) + http.HandleFunc("/control/stats", optionalAuth(ensureGET(corednsplugin.HandleStats))) + http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(corednsplugin.HandleStatsHistory))) + http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(corednsplugin.HandleStatsReset))) http.HandleFunc("/control/version.json", optionalAuth(handleGetVersionJSON)) http.HandleFunc("/control/filtering/enable", optionalAuth(ensurePOST(handleFilteringEnable))) http.HandleFunc("/control/filtering/disable", optionalAuth(ensurePOST(handleFilteringDisable))) diff --git a/coredns.go b/coredns.go index b6941f2b..5dbe01b4 100644 --- a/coredns.go +++ b/coredns.go @@ -120,12 +120,6 @@ func startDNSServer() error { log.Println(errortext) return errortext } - err = writeFilterFile() - if err != nil { - errortext := fmt.Errorf("Couldn't write filter file: %s", err) - log.Println(errortext) - return errortext - } go coremain.Run() return nil diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go index 31b147cf..5209e37b 100644 --- a/coredns_plugin/coredns_plugin.go +++ b/coredns_plugin/coredns_plugin.go @@ -51,11 +51,17 @@ var ( lookupCache = map[string]cacheEntry{} ) +type plugFilter struct { + ID int64 + Path string +} + type plugSettings struct { SafeBrowsingBlockHost string ParentalBlockHost string QueryLogEnabled bool BlockedTTL uint32 // in seconds, default 3600 + Filters []plugFilter } type plug struct { @@ -71,6 +77,7 @@ var defaultPluginSettings = plugSettings{ SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com", ParentalBlockHost: "family.block.dns.adguard.com", BlockedTTL: 3600, // in seconds + Filters: make([]plugFilter, 0), } // @@ -83,15 +90,14 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { d: dnsfilter.New(), } - filterFileNames := []string{} + log.Println("Initializing the CoreDNS plugin") + for c.Next() { - args := c.RemainingArgs() - if len(args) > 0 { - filterFileNames = append(filterFileNames, args...) - } for c.NextBlock() { - switch c.Val() { + blockValue := c.Val() + switch blockValue { case "safebrowsing": + log.Println("Browsing security service is enabled") p.d.EnableSafeBrowsing() if c.NextArg() { if len(c.Val()) == 0 { @@ -100,6 +106,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { p.d.SetSafeBrowsingServer(c.Val()) } case "safesearch": + log.Println("Safe search is enabled") p.d.EnableSafeSearch() case "parental": if !c.NextArg() { @@ -109,6 +116,8 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { if err != nil { return nil, c.ArgErr() } + + log.Println("Parental control is enabled") err = p.d.EnableParental(sensitivity) if err != nil { return nil, c.ArgErr() @@ -123,24 +132,46 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { if !c.NextArg() { return nil, c.ArgErr() } - blockttl, err := strconv.ParseUint(c.Val(), 10, 32) + blockedTtl, err := strconv.ParseUint(c.Val(), 10, 32) if err != nil { return nil, c.ArgErr() } - p.settings.BlockedTTL = uint32(blockttl) + log.Printf("Blocked request TTL is %d", blockedTtl) + p.settings.BlockedTTL = uint32(blockedTtl) case "querylog": + log.Println("Query log is enabled") p.settings.QueryLogEnabled = true + case "filter": + if !c.NextArg() { + return nil, c.ArgErr() + } + + filterId, err := strconv.ParseInt(c.Val(), 10, 64) + if err != nil { + return nil, c.ArgErr() + } + if !c.NextArg() { + return nil, c.ArgErr() + } + filterPath := c.Val() + + // Initialize filter and add it to the list + p.settings.Filters = append(p.settings.Filters, plugFilter{ + ID: filterId, + Path: filterPath, + }) } } } - log.Printf("filterFileNames = %+v", filterFileNames) + for _, filter := range p.settings.Filters { + log.Printf("Loading rules from %s", filter.Path) - for i, filterFileName := range filterFileNames { - file, err := os.Open(filterFileName) + file, err := os.Open(filter.Path) if err != nil { return nil, err } + //noinspection GoDeferInLoop defer file.Close() count := 0 @@ -148,7 +179,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { for scanner.Scan() { text := scanner.Text() - err = p.d.AddRule(text, uint32(i)) + err = p.d.AddRule(text, filter.ID) if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { continue } @@ -159,7 +190,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { } count++ } - log.Printf("Added %d rules from %s", count, filterFileName) + log.Printf("Added %d rules from filter ID=%d", count, filter.ID) if err = scanner.Err(); err != nil { return nil, err @@ -250,6 +281,7 @@ func (p *plug) onFinalShutdown() error { type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) +//noinspection GoUnusedParameter func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { realch, ok := ch.(chan<- *prometheus.Desc) if !ok { @@ -391,7 +423,7 @@ func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.M func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) { if len(r.Question) != 1 { // google DNS, bind and others do the same - return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("Got DNS request with != 1 questions") + return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("got a DNS request with more than one Question") } for _, question := range r.Question { host := strings.ToLower(strings.TrimSuffix(question.Name, ".")) diff --git a/coredns_plugin/coredns_plugin_test.go b/coredns_plugin/coredns_plugin_test.go index 2f65cf9a..1733fd6f 100644 --- a/coredns_plugin/coredns_plugin_test.go +++ b/coredns_plugin/coredns_plugin_test.go @@ -21,10 +21,20 @@ func TestSetup(t *testing.T) { failing bool }{ {`dnsfilter`, false}, - {`dnsfilter /dev/nonexistent/abcdef`, true}, - {`dnsfilter ../tests/dns.txt`, false}, - {`dnsfilter ../tests/dns.txt { safebrowsing }`, false}, - {`dnsfilter ../tests/dns.txt { parental }`, true}, + {`dnsfilter { + filter 0 /dev/nonexistent/abcdef + }`, true}, + {`dnsfilter { + filter 0 ../tests/dns.txt + }`, false}, + {`dnsfilter { + safebrowsing + filter 0 ../tests/dns.txt + }`, false}, + {`dnsfilter { + parental + filter 0 ../tests/dns.txt + }`, true}, } { c := caddy.NewTestController("dns", testcase.config) err := setup(c) @@ -55,7 +65,8 @@ func TestEtcHostsFilter(t *testing.T) { defer os.Remove(tmpfile.Name()) - c := caddy.NewTestController("dns", fmt.Sprintf("dnsfilter %s", tmpfile.Name())) + configText := fmt.Sprintf("dnsfilter {\nfilter 0 %s\n}", tmpfile.Name()) + c := caddy.NewTestController("dns", configText) p, err := setupPlugin(c) if err != nil { t.Fatal(err) diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go index b3e1a5f0..2280fcde 100644 --- a/coredns_plugin/querylog.go +++ b/coredns_plugin/querylog.go @@ -25,7 +25,6 @@ const ( queryLogFileName = "querylog.json" // .gz added during compression queryLogSize = 5000 // maximum API response for /querylog queryLogTopSize = 500 // Keep in memory only top N values - queryLogAPIPort = "8618" // 8618 is sha512sum of "querylog" then each byte summed ) var ( @@ -34,7 +33,6 @@ var ( queryLogCache []*logEntry queryLogLock sync.RWMutex - queryLogTime time.Time ) type logEntry struct { @@ -107,6 +105,7 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela } } +//noinspection GoUnusedParameter func HandleQueryLog(w http.ResponseWriter, r *http.Request) { queryLogLock.RLock() values := make([]*logEntry, len(queryLogCache)) @@ -140,14 +139,14 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) { } } - jsonentry := map[string]interface{}{ + jsonEntry := map[string]interface{}{ "reason": entry.Result.Reason.String(), - "elapsed_ms": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), + "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), "time": entry.Time.Format(time.RFC3339), "client": entry.IP, } if q != nil { - jsonentry["question"] = map[string]interface{}{ + jsonEntry["question"] = map[string]interface{}{ "host": strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")), "type": dns.Type(q.Question[0].Qtype).String(), "class": dns.Class(q.Question[0].Qclass).String(), @@ -156,10 +155,11 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) { if a != nil { status, _ := response.Typify(a, time.Now().UTC()) - jsonentry["status"] = status.String() + jsonEntry["status"] = status.String() } if len(entry.Result.Rule) > 0 { - jsonentry["rule"] = entry.Result.Rule + jsonEntry["rule"] = entry.Result.Rule + jsonEntry["filterId"] = entry.Result.FilterID } if a != nil && len(a.Answer) > 0 { @@ -202,26 +202,26 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) { } answers = append(answers, answer) } - jsonentry["answer"] = answers + jsonEntry["answer"] = answers } - data = append(data, jsonentry) + data = append(data, jsonEntry) } jsonVal, err := json.Marshal(data) if err != nil { - errortext := fmt.Sprintf("Couldn't marshal data into json: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") _, err = w.Write(jsonVal) if err != nil { - errortext := fmt.Sprintf("Unable to write response json: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Unable to write response json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) } } diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 96ba2512..5fb5d0ed 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -71,7 +71,7 @@ type rule struct { isImportant bool // user-supplied data - listID uint32 + listID int64 // suffix matching isSuffix bool @@ -146,7 +146,7 @@ type Result struct { Reason Reason `json:",omitempty"` // Reason for blocking / unblocking Rule string `json:",omitempty"` // Original rule text Ip net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax - FilterID uint32 `json:",omitempty"` // Filter ID the rule belongs to + FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to } // Matched can be used to see if any match at all was found, no matter filtered or not @@ -499,11 +499,12 @@ func (rule *rule) match(host string) (Result, error) { if matched { res.Reason = FilteredBlackList res.IsFiltered = true + res.FilterID = rule.listID + res.Rule = rule.originalText if rule.isWhitelist { res.Reason = NotFilteredWhiteList res.IsFiltered = false } - res.Rule = rule.text } return res, nil } @@ -733,7 +734,7 @@ func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gc // // AddRule adds a rule, checking if it is a valid rule first and if it wasn't added already -func (d *Dnsfilter) AddRule(input string, filterListID uint32) error { +func (d *Dnsfilter) AddRule(input string, filterListID int64) error { input = strings.TrimSpace(input) d.storageMutex.RLock() _, exists := d.storage[input] @@ -796,7 +797,7 @@ func (d *Dnsfilter) AddRule(input string, filterListID uint32) error { } // Parses the hosts-syntax rules. Returns false if the input string is not of hosts-syntax. -func (d *Dnsfilter) parseEtcHosts(input string, filterListID uint32) bool { +func (d *Dnsfilter) parseEtcHosts(input string, filterListID int64) bool { // Strip the trailing comment ruleText := input if pos := strings.IndexByte(ruleText, '#'); pos != -1 { diff --git a/helpers.go b/helpers.go index 6d598224..7ae69b8a 100644 --- a/helpers.go +++ b/helpers.go @@ -5,21 +5,39 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net/http" "os" "path" + "path/filepath" "runtime" "strings" ) -func clamp(value, low, high int) int { - if value < low { - return low +// ---------------------------------- +// helper functions for working with files +// ---------------------------------- + +// Writes data first to a temporary file and then renames it to what's specified in path +func writeFileSafe(path string, data []byte) error { + + dir := filepath.Dir(path) + err := os.MkdirAll(dir, 0755) + if err != nil { + return err } - if value > high { - return high + + tmpPath := path + ".tmp" + err = ioutil.WriteFile(tmpPath, data, 0644) + if err != nil { + return err } - return value + err = os.Rename(tmpPath, path) + if err != nil { + return err + } + + return nil } // ---------------------------------- @@ -117,13 +135,6 @@ func parseParametersFromBody(r io.Reader) (map[string]string, error) { // --------------------- // debug logging helpers // --------------------- -func _Func() string { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - return path.Base(f.Name()) -} - func trace(format string, args ...interface{}) { pc := make([]uintptr, 10) // at least 1 entry needed runtime.Callers(2, pc) diff --git a/openapi.yaml b/openapi.yaml index 16421815..35e32a90 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -92,7 +92,7 @@ paths: - ttl: 55 type: A value: 217.69.139.200 - elapsed_ms: '65.469556' + elapsedMs: '65.469556' question: class: IN host: mail.ru @@ -100,7 +100,7 @@ paths: reason: DNSFILTER_NOTFILTERED_NOTFOUND status: NOERROR time: '2018-07-16T22:24:02+03:00' - - elapsed_ms: '0.15716999999999998' + - elapsedMs: '0.15716999999999998' question: class: IN host: doubleclick.net @@ -113,13 +113,14 @@ paths: - ttl: 299 type: A value: 176.103.133.78 - elapsed_ms: '132.110929' + elapsedMs: '132.110929' question: class: IN host: wmconvirus.narod.ru type: A reason: DNSFILTER_FILTERED_SAFEBROWSING rule: adguard-malware-shavar + filterId: 1 status: NOERROR time: '2018-07-16T22:24:02+03:00' /querylog_enable: @@ -448,9 +449,13 @@ paths: examples: application/json: enabled: false - urls: - - 'https://filters.adtidy.org/windows/filters/1.txt' - - 'https://filters.adtidy.org/windows/filters/2.txt' + - filters: + enabled: true + id: 1 + lastUpdated: "2018-10-30T12:18:57.223101822+03:00" + name: "AdGuard Simplified Domain Names filter" + rulesCount: 24896 + url: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt" rules: - '@@||yandex.ru^|' /filtering/set_rules: