diff --git a/config.go b/config.go index f9907feb..52efdfd9 100644 --- a/config.go +++ b/config.go @@ -46,14 +46,11 @@ type coreDNSConfig struct { dnsforward.FilteringConfig `yaml:",inline"` - QueryLogEnabled bool `yaml:"querylog_enabled"` - Ratelimit int `yaml:"ratelimit"` - RefuseAny bool `yaml:"refuse_any"` - Pprof string `yaml:"-"` - Cache string `yaml:"-"` - Prometheus string `yaml:"-"` - BootstrapDNS string `yaml:"bootstrap_dns"` - UpstreamDNS []string `yaml:"upstream_dns"` + Pprof string `yaml:"-"` + Cache string `yaml:"-"` + Prometheus string `yaml:"-"` + BootstrapDNS string `yaml:"bootstrap_dns"` + UpstreamDNS []string `yaml:"upstream_dns"` } var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} @@ -71,14 +68,14 @@ var config = configuration{ ProtectionEnabled: true, // whether or not use any of dnsfilter features FilteringEnabled: true, // whether or not use filter lists BlockedResponseTTL: 10, // in seconds + QueryLogEnabled: true, + Ratelimit: 20, + RefuseAny: true, }, - QueryLogEnabled: true, - Ratelimit: 20, - RefuseAny: true, - BootstrapDNS: "8.8.8.8:53", - UpstreamDNS: defaultDNS, - Cache: "cache", - Prometheus: "prometheus :9153", + BootstrapDNS: "8.8.8.8:53", + UpstreamDNS: defaultDNS, + Cache: "cache", + Prometheus: "prometheus :9153", }, Filters: []filter{ {Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 4b731d45..bee85d3a 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/joomcode/errorx" "github.com/miekg/dns" + gocache "github.com/patrickmn/go-cache" ) // Server is the main way to start a DNS server. @@ -31,6 +32,8 @@ type Server struct { cache cache + ratelimitBuckets *gocache.Cache // where the ratelimiters are stored, per IP + sync.RWMutex ServerConfig } @@ -76,9 +79,13 @@ func (s *Server) RUnlock() { */ type FilteringConfig struct { - ProtectionEnabled bool `yaml:"protection_enabled"` - FilteringEnabled bool `yaml:"filtering_enabled"` - BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) + ProtectionEnabled bool `yaml:"protection_enabled"` + FilteringEnabled bool `yaml:"filtering_enabled"` + BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) + QueryLogEnabled bool `yaml:"querylog_enabled"` + Ratelimit int `yaml:"ratelimit"` + RatelimitWhitelist []string `yaml:"ratelimit_whitelist"` + RefuseAny bool `yaml:"refuse_any"` dnsfilter.Config `yaml:",inline"` } @@ -92,6 +99,7 @@ type ServerConfig struct { FilteringConfig } +// if any of ServerConfig values are zero, then default values from below are used var defaultValues = ServerConfig{ UDPListenAddr: &net.UDPAddr{Port: 53}, FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, @@ -413,6 +421,10 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP return s.genServerFailure(msg), nil, nil, nil } + if msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny { + return s.genNotImpl(msg), nil, nil, nil + } + // use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise host := strings.TrimSuffix(msg.Question[0].Name, ".") res, err := s.dnsFilter.CheckHost(host) @@ -450,16 +462,36 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { start := time.Now() + ip, _, err := net.SplitHostPort(addr.String()) + if err != nil { + log.Printf("Failed to split %v into host/port: %s", addr, err) + // not a fatal error, move on + } + + // ratelimit based on IP only, protects CPU cycles and outbound connections + if s.isRatelimited(ip) { + // log.Printf("Ratelimiting %s based on IP only", ip) + return // do nothing, don't reply, we got ratelimited + } msg := &dns.Msg{} - err := msg.Unpack(p) + err = msg.Unpack(p) if err != nil { log.Printf("got invalid DNS packet: %s", err) return // do nothing } reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn) + if reply != nil { + // ratelimit based on reply size now + replysize := reply.Len() + if s.isRatelimitedForReply(ip, replysize) { + log.Printf("Ratelimiting %s based on IP and size %d", ip, replysize) + return // do nothing, don't reply, we got ratelimited + } + + // we're good to respond rerr := s.respond(reply, addr, conn) if rerr != nil { log.Printf("Couldn't respond to UDP packet: %s", err) @@ -467,16 +499,14 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { } // query logging and stats counters - elapsed := time.Since(start) - upstreamAddr := "" - if upstream != nil { - upstreamAddr = upstream.Address() + if s.QueryLogEnabled { + elapsed := time.Since(start) + upstreamAddr := "" + if upstream != nil { + upstreamAddr = upstream.Address() + } + logRequest(msg, reply, result, elapsed, ip, upstreamAddr) } - host, _, err := net.SplitHostPort(addr.String()) - if err != nil { - log.Printf("Failed to split %v into host/port: %s", addr, err) - } - logRequest(msg, reply, result, elapsed, host, upstreamAddr) } // @@ -506,12 +536,22 @@ func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { resp := dns.Msg{} resp.SetRcode(request, dns.RcodeServerFailure) + resp.RecursionAvailable = true + return &resp +} + +func (s *Server) genNotImpl(request *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeNotImplemented) + resp.RecursionAvailable = true + resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it return &resp } func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg { resp := dns.Msg{} resp.SetRcode(request, dns.RcodeNameError) + resp.RecursionAvailable = true resp.Ns = s.genSOA(request) return &resp } diff --git a/dnsforward/ratelimit.go b/dnsforward/ratelimit.go new file mode 100644 index 00000000..9ea8d216 --- /dev/null +++ b/dnsforward/ratelimit.go @@ -0,0 +1,80 @@ +package dnsforward + +import ( + "log" + "sort" + "time" + + "github.com/beefsack/go-rate" + gocache "github.com/patrickmn/go-cache" +) + +func (s *Server) limiterForIP(ip string) interface{} { + if s.ratelimitBuckets == nil { + s.ratelimitBuckets = gocache.New(time.Hour, time.Hour) + } + + // check if ratelimiter for that IP already exists, if not, create + value, found := s.ratelimitBuckets.Get(ip) + if !found { + value = rate.New(s.Ratelimit, time.Second) + s.ratelimitBuckets.Set(ip, value, time.Hour) + } + + return value +} + +func (s *Server) isRatelimited(ip string) bool { + if s.Ratelimit == 0 { // 0 -- disabled + return false + } + if len(s.RatelimitWhitelist) > 0 { + i := sort.SearchStrings(s.RatelimitWhitelist, ip) + + if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip { + // found, don't ratelimit + return false + } + } + + value := s.limiterForIP(ip) + rl, ok := value.(*rate.RateLimiter) + if !ok { + log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache") + return false + } + + allow, _ := rl.Try() + return !allow +} + +func (s *Server) isRatelimitedForReply(ip string, size int) bool { + if s.Ratelimit == 0 { // 0 -- disabled + return false + } + if len(s.RatelimitWhitelist) > 0 { + i := sort.SearchStrings(s.RatelimitWhitelist, ip) + + if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip { + // found, don't ratelimit + return false + } + } + + value := s.limiterForIP(ip) + rl, ok := value.(*rate.RateLimiter) + if !ok { + log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache") + return false + } + + // For large UDP responses we try more times, effectively limiting per bandwidth + // The exact number of times depends on the response size + for i := 0; i < size/1000; i++ { + allow, _ := rl.Try() + if !allow { // not allowed -> ratelimited + return true + } + } + return false +} diff --git a/dnsforward/ratelimit_test.go b/dnsforward/ratelimit_test.go new file mode 100644 index 00000000..ed6f5ce9 --- /dev/null +++ b/dnsforward/ratelimit_test.go @@ -0,0 +1,42 @@ +package dnsforward + +import ( + "testing" +) + +func TestRatelimiting(t *testing.T) { + // rate limit is 1 per sec + p := Server{} + p.Ratelimit = 1 + + limited := p.isRatelimited("127.0.0.1") + + if limited { + t.Fatal("First request must have been allowed") + } + + limited = p.isRatelimited("127.0.0.1") + + if !limited { + t.Fatal("Second request must have been ratelimited") + } +} + +func TestWhitelist(t *testing.T) { + // rate limit is 1 per sec with whitelist + p := Server{} + p.Ratelimit = 1 + p.RatelimitWhitelist = []string{"127.0.0.1", "127.0.0.2", "127.0.0.125"} + + limited := p.isRatelimited("127.0.0.1") + + if limited { + t.Fatal("First request must have been allowed") + } + + limited = p.isRatelimited("127.0.0.1") + + if limited { + t.Fatal("Second request must have been allowed due to whitelist") + } +}