mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-12-15 19:31:45 +03:00
Merge branch 'master' into i18n
This commit is contained in:
commit
67c8abcb8e
@ -106,8 +106,10 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib
|
||||
* `parental_enabled` — Parental control-based DNS requests filtering
|
||||
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17
|
||||
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes)
|
||||
* `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname
|
||||
* `upstream_dns` — List of upstream DNS servers
|
||||
* `filters` — List of filters, each filter has the following values:
|
||||
* `ID` - filter ID (must be unique)
|
||||
* `url` — URL pointing to the filter contents (filtering rules)
|
||||
* `enabled` — Current filter's status (enabled/disabled)
|
||||
* `user_rules` — User-specified filtering rules
|
||||
|
@ -70,6 +70,7 @@ type coreDNSConfig struct {
|
||||
Pprof string `yaml:"-"`
|
||||
Cache string `yaml:"-"`
|
||||
Prometheus string `yaml:"-"`
|
||||
BootstrapDNS string `yaml:"bootstrap_dns"`
|
||||
UpstreamDNS []string `yaml:"upstream_dns"`
|
||||
}
|
||||
|
||||
@ -100,6 +101,7 @@ var config = configuration{
|
||||
SafeBrowsingEnabled: false,
|
||||
BlockedResponseTTL: 10, // in seconds
|
||||
QueryLogEnabled: true,
|
||||
BootstrapDNS: "8.8.8.8:53",
|
||||
UpstreamDNS: defaultDNS,
|
||||
Cache: "cache",
|
||||
Prometheus: "prometheus :9153",
|
||||
@ -253,7 +255,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} {
|
||||
hosts {
|
||||
fallthrough
|
||||
}
|
||||
{{if .UpstreamDNS}}forward . {{range .UpstreamDNS}}{{.}} {{end}}{{end}}
|
||||
{{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}}
|
||||
{{.Cache}}
|
||||
{{.Prometheus}}
|
||||
}
|
||||
|
147
control.go
147
control.go
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -15,8 +14,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/upstream"
|
||||
|
||||
corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
||||
"github.com/miekg/dns"
|
||||
"gopkg.in/asaskevich/govalidator.v4"
|
||||
)
|
||||
|
||||
@ -81,6 +81,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
"protection_enabled": config.CoreDNS.ProtectionEnabled,
|
||||
"querylog_enabled": config.CoreDNS.QueryLogEnabled,
|
||||
"running": isRunning(),
|
||||
"bootstrap_dns": config.CoreDNS.BootstrapDNS,
|
||||
"upstream_dns": config.CoreDNS.UpstreamDNS,
|
||||
"version": VersionString,
|
||||
}
|
||||
@ -134,17 +135,14 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
|
||||
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusBadRequest)
|
||||
errorText := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// if empty body -- user is asking for default servers
|
||||
hosts, err := sanitiseDNSServers(string(body))
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
|
||||
return
|
||||
}
|
||||
hosts := strings.Fields(string(body))
|
||||
|
||||
if len(hosts) == 0 {
|
||||
config.CoreDNS.UpstreamDNS = defaultDNS
|
||||
} else {
|
||||
@ -153,34 +151,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = writeAllConfigs()
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write config file: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tellCoreDNSToReload()
|
||||
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, 400)
|
||||
errorText := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, 400)
|
||||
return
|
||||
}
|
||||
hosts := strings.Fields(string(body))
|
||||
|
||||
if len(hosts) == 0 {
|
||||
errortext := fmt.Sprintf("No servers specified")
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusBadRequest)
|
||||
errorText := fmt.Sprintf("No servers specified")
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
@ -198,120 +196,43 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
jsonVal, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func checkDNS(input string) error {
|
||||
input, err := sanitizeDNSServer(input)
|
||||
|
||||
u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer u.Close()
|
||||
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
alive, err := upstream.IsAlive(u)
|
||||
|
||||
prefix, host := splitDNSServerPrefixServer(input)
|
||||
|
||||
c := dns.Client{
|
||||
Timeout: time.Minute,
|
||||
}
|
||||
switch prefix {
|
||||
case "tls://":
|
||||
c.Net = "tcp-tls"
|
||||
}
|
||||
|
||||
resp, rtt, err := c.Exchange(&req, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err)
|
||||
}
|
||||
trace("exchange with %s took %v", input, rtt)
|
||||
if len(resp.Answer) != 1 {
|
||||
return fmt.Errorf("DNS server %s returned wrong answer", input)
|
||||
}
|
||||
if t, ok := resp.Answer[0].(*dns.A); ok {
|
||||
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
|
||||
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
|
||||
}
|
||||
|
||||
if !alive {
|
||||
return fmt.Errorf("DNS server has not passed the healthcheck: %s", input)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sanitiseDNSServers(input string) ([]string, error) {
|
||||
fields := strings.Fields(input)
|
||||
hosts := make([]string, 0)
|
||||
for _, field := range fields {
|
||||
sanitized, err := sanitizeDNSServer(field)
|
||||
if err != nil {
|
||||
return hosts, err
|
||||
}
|
||||
hosts = append(hosts, sanitized)
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func getDNSServerPrefix(input string) string {
|
||||
prefix := ""
|
||||
switch {
|
||||
case strings.HasPrefix(input, "dns://"):
|
||||
prefix = "dns://"
|
||||
case strings.HasPrefix(input, "tls://"):
|
||||
prefix = "tls://"
|
||||
}
|
||||
return prefix
|
||||
}
|
||||
|
||||
func splitDNSServerPrefixServer(input string) (string, string) {
|
||||
prefix := getDNSServerPrefix(input)
|
||||
host := strings.TrimPrefix(input, prefix)
|
||||
return prefix, host
|
||||
}
|
||||
|
||||
func sanitizeDNSServer(input string) (string, error) {
|
||||
prefix, host := splitDNSServerPrefixServer(input)
|
||||
host = appendPortIfMissing(prefix, host)
|
||||
{
|
||||
h, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ip := net.ParseIP(h)
|
||||
if ip == nil {
|
||||
return "", fmt.Errorf("invalid DNS server field: %s", h)
|
||||
}
|
||||
}
|
||||
return prefix + host, nil
|
||||
}
|
||||
|
||||
func appendPortIfMissing(prefix, input string) string {
|
||||
port := "53"
|
||||
switch prefix {
|
||||
case "tls://":
|
||||
port = "853"
|
||||
}
|
||||
_, _, err := net.SplitHostPort(input)
|
||||
if err == nil {
|
||||
return input
|
||||
}
|
||||
return net.JoinHostPort(input, port)
|
||||
}
|
||||
|
||||
//noinspection GoUnusedParameter
|
||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
now := time.Now()
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"sync" // Include all plugins.
|
||||
|
||||
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
||||
_ "github.com/AdguardTeam/AdGuardHome/upstream"
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/coremain"
|
||||
_ "github.com/coredns/coredns/plugin/auto"
|
||||
@ -79,6 +80,7 @@ var directives = []string{
|
||||
"loop",
|
||||
"forward",
|
||||
"proxy",
|
||||
"upstream",
|
||||
"erratic",
|
||||
"whoami",
|
||||
"on",
|
||||
|
@ -41,6 +41,7 @@ paths:
|
||||
protection_enabled: true
|
||||
querylog_enabled: true
|
||||
running: true
|
||||
bootstrap_dns: 8.8.8.8:53
|
||||
upstream_dns:
|
||||
- 1.1.1.1
|
||||
- 1.0.0.1
|
||||
|
109
upstream/dns_upstream.go
Normal file
109
upstream/dns_upstream.go
Normal file
@ -0,0 +1,109 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// DnsUpstream is a very simple upstream implementation for plain DNS
|
||||
type DnsUpstream struct {
|
||||
endpoint string // IP:port
|
||||
timeout time.Duration // Max read and write timeout
|
||||
proto string // Protocol (tcp, tcp-tls, or udp)
|
||||
transport *Transport // Persistent connections cache
|
||||
}
|
||||
|
||||
// NewDnsUpstream creates a new DNS upstream
|
||||
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
|
||||
|
||||
u := &DnsUpstream{
|
||||
endpoint: endpoint,
|
||||
timeout: defaultTimeout,
|
||||
proto: proto,
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
|
||||
if proto == "tcp-tls" {
|
||||
tlsConfig = new(tls.Config)
|
||||
tlsConfig.ServerName = tlsServerName
|
||||
}
|
||||
|
||||
// Initialize the connections cache
|
||||
u.transport = NewTransport(endpoint)
|
||||
u.transport.tlsConfig = tlsConfig
|
||||
u.transport.Start()
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Exchange provides an implementation for the Upstream interface
|
||||
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
|
||||
resp, err := u.exchange(u.proto, query)
|
||||
|
||||
// Retry over TCP if response is truncated
|
||||
if err == dns.ErrTruncated && u.proto == "udp" {
|
||||
resp, err = u.exchange("tcp", query)
|
||||
} else if err == dns.ErrTruncated && resp != nil {
|
||||
// Reassemble something to be sent to client
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(query)
|
||||
m.Truncated = true
|
||||
m.Authoritative = true
|
||||
m.Rcode = dns.RcodeSuccess
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
resp = &dns.Msg{}
|
||||
resp.SetRcode(resp, dns.RcodeServerFailure)
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Clear resources
|
||||
func (u *DnsUpstream) Close() error {
|
||||
|
||||
// Close active connections
|
||||
u.transport.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Performs a synchronous query. It sends the message m via the conn
|
||||
// c and waits for a reply. The conn c is not closed.
|
||||
func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) {
|
||||
|
||||
// Establish a connection if needed (or reuse cached)
|
||||
conn, err := u.transport.Dial(proto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write the request with a timeout
|
||||
conn.SetWriteDeadline(time.Now().Add(u.timeout))
|
||||
if err = conn.WriteMsg(query); err != nil {
|
||||
conn.Close() // Not giving it back
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write response with a timeout
|
||||
conn.SetReadDeadline(time.Now().Add(u.timeout))
|
||||
r, err = conn.ReadMsg()
|
||||
if err != nil {
|
||||
conn.Close() // Not giving it back
|
||||
} else if err == nil && r.Id != query.Id {
|
||||
err = dns.ErrId
|
||||
conn.Close() // Not giving it back
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// Return it back to the connections cache if there were no errors
|
||||
u.transport.Yield(conn)
|
||||
}
|
||||
return r, err
|
||||
}
|
101
upstream/helpers.go
Normal file
101
upstream/helpers.go
Normal file
@ -0,0 +1,101 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Detects the upstream type from the specified url and creates a proper Upstream object
|
||||
func NewUpstream(url string, bootstrap string) (Upstream, error) {
|
||||
|
||||
proto := "udp"
|
||||
prefix := ""
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(url, "tcp://"):
|
||||
proto = "tcp"
|
||||
prefix = "tcp://"
|
||||
case strings.HasPrefix(url, "tls://"):
|
||||
proto = "tcp-tls"
|
||||
prefix = "tls://"
|
||||
case strings.HasPrefix(url, "https://"):
|
||||
return NewHttpsUpstream(url, bootstrap)
|
||||
}
|
||||
|
||||
hostname := strings.TrimPrefix(url, prefix)
|
||||
|
||||
host, port, err := net.SplitHostPort(hostname)
|
||||
if err != nil {
|
||||
// Set port depending on the protocol
|
||||
switch proto {
|
||||
case "udp":
|
||||
port = "53"
|
||||
case "tcp":
|
||||
port = "53"
|
||||
case "tcp-tls":
|
||||
port = "853"
|
||||
}
|
||||
|
||||
// Set host = hostname
|
||||
host = hostname
|
||||
}
|
||||
|
||||
// Try to resolve the host address (or check if it's an IP address)
|
||||
bootstrapResolver := CreateResolver(bootstrap)
|
||||
ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host)
|
||||
|
||||
if err != nil || len(ips) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr := ips[0].String()
|
||||
endpoint := net.JoinHostPort(addr, port)
|
||||
tlsServerName := ""
|
||||
|
||||
if proto == "tcp-tls" && host != addr {
|
||||
// Check if we need to specify TLS server name
|
||||
tlsServerName = host
|
||||
}
|
||||
|
||||
return NewDnsUpstream(endpoint, proto, tlsServerName)
|
||||
}
|
||||
|
||||
func CreateResolver(bootstrap string) *net.Resolver {
|
||||
|
||||
bootstrapResolver := net.DefaultResolver
|
||||
|
||||
if bootstrap != "" {
|
||||
bootstrapResolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, network, bootstrap)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return bootstrapResolver
|
||||
}
|
||||
|
||||
// Performs a simple health-check of the specified upstream
|
||||
func IsAlive(u Upstream) (bool, error) {
|
||||
|
||||
// Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere
|
||||
ping := new(dns.Msg)
|
||||
ping.SetQuestion("ipv4only.arpa.", dns.TypeA)
|
||||
|
||||
resp, err := u.Exchange(context.Background(), ping)
|
||||
|
||||
// If we got a header, we're alright, basically only care about I/O errors 'n stuff.
|
||||
if err != nil && resp != nil {
|
||||
// Silly check, something sane came back.
|
||||
if resp.Rcode != dns.RcodeServerFailure {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return err == nil, err
|
||||
}
|
128
upstream/https_upstream.go
Normal file
128
upstream/https_upstream.go
Normal file
@ -0,0 +1,128 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsMessageContentType = "application/dns-message"
|
||||
defaultKeepAlive = 30 * time.Second
|
||||
)
|
||||
|
||||
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
|
||||
type HttpsUpstream struct {
|
||||
client *http.Client
|
||||
endpoint *url.URL
|
||||
}
|
||||
|
||||
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url
|
||||
func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize bootstrap resolver
|
||||
bootstrapResolver := CreateResolver(bootstrap)
|
||||
dialer := &net.Dialer{
|
||||
Timeout: defaultTimeout,
|
||||
KeepAlive: defaultKeepAlive,
|
||||
DualStack: true,
|
||||
Resolver: bootstrapResolver,
|
||||
}
|
||||
|
||||
// Update TLS and HTTP client configuration
|
||||
tlsConfig := &tls.Config{ServerName: u.Hostname()}
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
DisableCompression: true,
|
||||
MaxIdleConns: 1,
|
||||
DialContext: dialer.DialContext,
|
||||
}
|
||||
http2.ConfigureTransport(transport)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
return &HttpsUpstream{client: client, endpoint: u}, nil
|
||||
}
|
||||
|
||||
// Exchange provides an implementation for the Upstream interface
|
||||
func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
queryBuf, err := query.Pack()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to pack DNS query")
|
||||
}
|
||||
|
||||
// No content negotiation for now, use DNS wire format
|
||||
buf, backendErr := u.exchangeWireformat(queryBuf)
|
||||
if backendErr == nil {
|
||||
response := &dns.Msg{}
|
||||
if err := response.Unpack(buf); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unpack DNS response from body")
|
||||
}
|
||||
|
||||
response.Id = query.Id
|
||||
return response, nil
|
||||
}
|
||||
|
||||
log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr)
|
||||
return nil, backendErr
|
||||
}
|
||||
|
||||
// Perform message exchange with the default UDP wireformat defined in current draft
|
||||
// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10
|
||||
func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
|
||||
req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create an HTTPS request")
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", dnsMessageContentType)
|
||||
req.Header.Add("Accept", dnsMessageContentType)
|
||||
req.Host = u.endpoint.Hostname()
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to perform an HTTPS request")
|
||||
}
|
||||
|
||||
// Check response status code
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != dnsMessageContentType {
|
||||
return nil, fmt.Errorf("return wrong content type %s", contentType)
|
||||
}
|
||||
|
||||
// Read application/dns-message response from the body
|
||||
buf, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read the response body")
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Clear resources
|
||||
func (u *HttpsUpstream) Close() error {
|
||||
return nil
|
||||
}
|
210
upstream/persistent.go
Normal file
210
upstream/persistent.go
Normal file
@ -0,0 +1,210 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sort"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin
|
||||
|
||||
const (
|
||||
defaultExpire = 10 * time.Second
|
||||
minDialTimeout = 100 * time.Millisecond
|
||||
maxDialTimeout = 30 * time.Second
|
||||
defaultDialTimeout = 30 * time.Second
|
||||
cumulativeAvgWeight = 4
|
||||
)
|
||||
|
||||
// a persistConn hold the dns.Conn and the last used time.
|
||||
type persistConn struct {
|
||||
c *dns.Conn
|
||||
used time.Time
|
||||
}
|
||||
|
||||
// Transport hold the persistent cache.
|
||||
type Transport struct {
|
||||
avgDialTime int64 // kind of average time of dial time
|
||||
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
||||
expire time.Duration // After this duration a connection is expired.
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
dial chan string
|
||||
yield chan *dns.Conn
|
||||
ret chan *dns.Conn
|
||||
stop chan bool
|
||||
}
|
||||
|
||||
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
|
||||
func (t *Transport) Dial(proto string) (*dns.Conn, error) {
|
||||
// If tls has been configured; use it.
|
||||
if t.tlsConfig != nil {
|
||||
proto = "tcp-tls"
|
||||
}
|
||||
|
||||
t.dial <- proto
|
||||
c := <-t.ret
|
||||
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
reqTime := time.Now()
|
||||
timeout := t.dialTimeout()
|
||||
if proto == "tcp-tls" {
|
||||
conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout)
|
||||
t.updateDialTimeout(time.Since(reqTime))
|
||||
return conn, err
|
||||
}
|
||||
conn, err := dns.DialTimeout(proto, t.addr, timeout)
|
||||
t.updateDialTimeout(time.Since(reqTime))
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// Yield return the connection to transport for reuse.
|
||||
func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
|
||||
|
||||
// Start starts the transport's connection manager.
|
||||
func (t *Transport) Start() { go t.connManager() }
|
||||
|
||||
// Stop stops the transport's connection manager.
|
||||
func (t *Transport) Stop() { close(t.stop) }
|
||||
|
||||
// SetExpire sets the connection expire time in transport.
|
||||
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
|
||||
|
||||
// SetTLSConfig sets the TLS config in transport.
|
||||
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
||||
|
||||
func NewTransport(addr string) *Transport {
|
||||
t := &Transport{
|
||||
avgDialTime: int64(defaultDialTimeout / 2),
|
||||
conns: make(map[string][]*persistConn),
|
||||
expire: defaultExpire,
|
||||
addr: addr,
|
||||
dial: make(chan string),
|
||||
yield: make(chan *dns.Conn),
|
||||
ret: make(chan *dns.Conn),
|
||||
stop: make(chan bool),
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
|
||||
dt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
|
||||
}
|
||||
|
||||
func (t *Transport) dialTimeout() time.Duration {
|
||||
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
|
||||
}
|
||||
|
||||
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
|
||||
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
|
||||
}
|
||||
|
||||
// limitTimeout is a utility function to auto-tune timeout values
|
||||
// average observed time is moved towards the last observed delay moderated by a weight
|
||||
// next timeout to use will be the double of the computed average, limited by min and max frame.
|
||||
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
|
||||
rt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||
if rt < minValue {
|
||||
return minValue
|
||||
}
|
||||
if rt < maxValue/2 {
|
||||
return 2 * rt
|
||||
}
|
||||
return maxValue
|
||||
}
|
||||
|
||||
// connManagers manages the persistent connection cache for UDP and TCP.
|
||||
func (t *Transport) connManager() {
|
||||
ticker := time.NewTicker(t.expire)
|
||||
Wait:
|
||||
for {
|
||||
select {
|
||||
case proto := <-t.dial:
|
||||
// take the last used conn - complexity O(1)
|
||||
if stack := t.conns[proto]; len(stack) > 0 {
|
||||
pc := stack[len(stack)-1]
|
||||
if time.Since(pc.used) < t.expire {
|
||||
// Found one, remove from pool and return this conn.
|
||||
t.conns[proto] = stack[:len(stack)-1]
|
||||
t.ret <- pc.c
|
||||
continue Wait
|
||||
}
|
||||
// clear entire cache if the last conn is expired
|
||||
t.conns[proto] = nil
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack)
|
||||
}
|
||||
|
||||
t.ret <- nil
|
||||
|
||||
case conn := <-t.yield:
|
||||
|
||||
// no proto here, infer from config and conn
|
||||
if _, ok := conn.Conn.(*net.UDPConn); ok {
|
||||
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
|
||||
continue Wait
|
||||
}
|
||||
|
||||
if t.tlsConfig == nil {
|
||||
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
|
||||
continue Wait
|
||||
}
|
||||
|
||||
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
|
||||
|
||||
case <-ticker.C:
|
||||
t.cleanup(false)
|
||||
|
||||
case <-t.stop:
|
||||
t.cleanup(true)
|
||||
close(t.ret)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeConns closes connections.
|
||||
func closeConns(conns []*persistConn) {
|
||||
for _, pc := range conns {
|
||||
pc.c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes connections from cache.
|
||||
func (t *Transport) cleanup(all bool) {
|
||||
staleTime := time.Now().Add(-t.expire)
|
||||
for proto, stack := range t.conns {
|
||||
if len(stack) == 0 {
|
||||
continue
|
||||
}
|
||||
if all {
|
||||
t.conns[proto] = nil
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack)
|
||||
continue
|
||||
}
|
||||
if stack[0].used.After(staleTime) {
|
||||
continue
|
||||
}
|
||||
|
||||
// connections in stack are sorted by "used"
|
||||
good := sort.Search(len(stack), func(i int) bool {
|
||||
return stack[i].used.After(staleTime)
|
||||
})
|
||||
t.conns[proto] = stack[good:]
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack[:good])
|
||||
}
|
||||
}
|
84
upstream/setup.go
Normal file
84
upstream/setup.go
Normal file
@ -0,0 +1,84 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("upstream", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// Read the configuration and initialize upstreams
|
||||
func setup(c *caddy.Controller) error {
|
||||
|
||||
p, err := setupPlugin(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := dnsserver.GetConfig(c)
|
||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
p.Next = next
|
||||
return p
|
||||
})
|
||||
|
||||
c.OnShutdown(p.onShutdown)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read the configuration
|
||||
func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) {
|
||||
|
||||
p := New()
|
||||
|
||||
log.Println("Initializing the Upstream plugin")
|
||||
|
||||
bootstrap := ""
|
||||
upstreamUrls := []string{}
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
if len(args) > 0 {
|
||||
upstreamUrls = append(upstreamUrls, args...)
|
||||
}
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "bootstrap":
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
bootstrap = c.Val()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, url := range upstreamUrls {
|
||||
u, err := NewUpstream(url, bootstrap)
|
||||
if err != nil {
|
||||
log.Printf("Cannot initialize upstream %s", url)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.Upstreams = append(p.Upstreams, u)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *UpstreamPlugin) onShutdown() error {
|
||||
for i := range p.Upstreams {
|
||||
|
||||
u := p.Upstreams[i]
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
log.Printf("Error while closing the upstream: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
30
upstream/setup_test.go
Normal file
30
upstream/setup_test.go
Normal file
@ -0,0 +1,30 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
config string
|
||||
}{
|
||||
{`upstream 8.8.8.8`},
|
||||
{`upstream 8.8.8.8 {
|
||||
bootstrap 8.8.8.8:53
|
||||
}`},
|
||||
{`upstream tls://1.1.1.1 8.8.8.8 {
|
||||
bootstrap 1.1.1.1
|
||||
}`},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
c := caddy.NewTestController("dns", test.config)
|
||||
err := setup(c)
|
||||
if err != nil {
|
||||
t.Fatalf("Test failed")
|
||||
}
|
||||
}
|
||||
}
|
57
upstream/upstream.go
Normal file
57
upstream/upstream.go
Normal file
@ -0,0 +1,57 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Upstream is a simplified interface for proxy destination
|
||||
type Upstream interface {
|
||||
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
|
||||
type UpstreamPlugin struct {
|
||||
Upstreams []Upstream
|
||||
Next plugin.Handler
|
||||
}
|
||||
|
||||
// Initialize the upstream plugin
|
||||
func New() *UpstreamPlugin {
|
||||
p := &UpstreamPlugin{
|
||||
Upstreams: []Upstream{},
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// ServeDNS implements interface for CoreDNS plugin
|
||||
func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
var reply *dns.Msg
|
||||
var backendErr error
|
||||
|
||||
for i := range p.Upstreams {
|
||||
upstream := p.Upstreams[i]
|
||||
reply, backendErr = upstream.Exchange(ctx, r)
|
||||
if backendErr == nil {
|
||||
w.WriteMsg(reply)
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
|
||||
}
|
||||
|
||||
// Name implements interface for CoreDNS plugin
|
||||
func (p *UpstreamPlugin) Name() string {
|
||||
return "upstream"
|
||||
}
|
194
upstream/upstream_test.go
Normal file
194
upstream/upstream_test.go
Normal file
@ -0,0 +1,194 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestDnsUpstreamIsAlive(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||
{"1.1.1.1", ""},
|
||||
{"tcp://1.1.1.1:53", ""},
|
||||
{"176.103.130.130:5353", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS upstream")
|
||||
}
|
||||
|
||||
testUpstreamIsAlive(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpsUpstreamIsAlive(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
||||
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
||||
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||
}
|
||||
|
||||
testUpstreamIsAlive(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDnsOverTlsIsAlive(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"tls://1.1.1.1", ""},
|
||||
{"tls://9.9.9.9:853", ""},
|
||||
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||
}
|
||||
|
||||
testUpstreamIsAlive(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDnsUpstream(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||
{"1.1.1.1", ""},
|
||||
{"tcp://1.1.1.1:53", ""},
|
||||
{"176.103.130.130:5353", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpsUpstream(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
||||
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
||||
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDnsOverTlsUpstream(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"tls://1.1.1.1", ""},
|
||||
{"tls://9.9.9.9:853", ""},
|
||||
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func testUpstreamIsAlive(t *testing.T, u Upstream) {
|
||||
alive, err := IsAlive(u)
|
||||
if !alive || err != nil {
|
||||
t.Errorf("Upstream is not alive")
|
||||
}
|
||||
|
||||
u.Close()
|
||||
}
|
||||
|
||||
func testUpstream(t *testing.T, u Upstream) {
|
||||
|
||||
var tests = []struct {
|
||||
name string
|
||||
expected net.IP
|
||||
}{
|
||||
{"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)},
|
||||
{"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
|
||||
resp, err := u.Exchange(context.Background(), &req)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error while making an upstream request: %s", err)
|
||||
}
|
||||
|
||||
if len(resp.Answer) != 1 {
|
||||
t.Errorf("no answer section in the response")
|
||||
}
|
||||
if answer, ok := resp.Answer[0].(*dns.A); ok {
|
||||
if !test.expected.Equal(answer.A) {
|
||||
t.Errorf("wrong IP in the response: %v", answer.A)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Error while closing the upstream: %s", err)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user