Fix rate limiting behind proxy, make configurable

This commit is contained in:
Philipp Heckel 2021-11-05 13:46:27 -04:00
parent 86a16e3944
commit 0170f673bd
5 changed files with 99 additions and 45 deletions

View File

@ -22,8 +22,13 @@ func New() *cli.App {
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: config.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: config.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "default interval of keepalive messages"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "default interval of for message pruning and stats printing"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: config.DefaultGlobalTopicLimit, Usage: "total number of topics allowed"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"V"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: config.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"B"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: config.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"R"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: config.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
} }
return &cli.App{ return &cli.App{
Name: "ntfy", Name: "ntfy",
@ -50,6 +55,11 @@ func execRun(c *cli.Context) error {
cacheDuration := c.Duration("cache-duration") cacheDuration := c.Duration("cache-duration")
keepaliveInterval := c.Duration("keepalive-interval") keepaliveInterval := c.Duration("keepalive-interval")
managerInterval := c.Duration("manager-interval") managerInterval := c.Duration("manager-interval")
globalTopicLimit := c.Int("global-topic-limit")
visitorSubscriptionLimit := c.Int("visitor-subscription-limit")
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
behindProxy := c.Bool("behind-proxy")
// Check values // Check values
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) { if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
@ -69,6 +79,11 @@ func execRun(c *cli.Context) error {
conf.CacheDuration = cacheDuration conf.CacheDuration = cacheDuration
conf.KeepaliveInterval = keepaliveInterval conf.KeepaliveInterval = keepaliveInterval
conf.ManagerInterval = managerInterval conf.ManagerInterval = managerInterval
conf.GlobalTopicLimit = globalTopicLimit
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
conf.BehindProxy = behindProxy
s, err := server.New(conf) s, err := server.New(conf)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)

View File

@ -2,7 +2,6 @@
package config package config
import ( import (
"golang.org/x/time/rate"
"time" "time"
) )
@ -15,42 +14,44 @@ const (
) )
// Defines all the limits // Defines all the limits
// - request limit: max number of PUT/GET/.. requests (here: 50 requests bucket, replenished at a rate of one per 10 seconds)
// - global topic limit: max number of topics overall // - global topic limit: max number of topics overall
// - subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP // - per visistor request limit: max number of PUT/GET/.. requests (here: 60 requests bucket, replenished at a rate of one per 10 seconds)
var ( // - per visistor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
defaultGlobalTopicLimit = 5000 const (
defaultVisitorRequestLimit = rate.Every(10 * time.Second) DefaultGlobalTopicLimit = 5000
defaultVisitorRequestLimitBurst = 60 DefaultVisitorRequestLimitBurst = 60
defaultVisitorSubscriptionLimit = 30 DefaultVisitorRequestLimitReplenish = 10 * time.Second
DefaultVisitorSubscriptionLimit = 30
) )
// Config is the main config struct for the application. Use New to instantiate a default config struct. // Config is the main config struct for the application. Use New to instantiate a default config struct.
type Config struct { type Config struct {
ListenHTTP string ListenHTTP string
FirebaseKeyFile string FirebaseKeyFile string
CacheFile string CacheFile string
CacheDuration time.Duration CacheDuration time.Duration
KeepaliveInterval time.Duration KeepaliveInterval time.Duration
ManagerInterval time.Duration ManagerInterval time.Duration
GlobalTopicLimit int GlobalTopicLimit int
VisitorRequestLimit rate.Limit VisitorRequestLimitBurst int
VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration
VisitorSubscriptionLimit int VisitorSubscriptionLimit int
BehindProxy bool
} }
// New instantiates a default new config // New instantiates a default new config
func New(listenHTTP string) *Config { func New(listenHTTP string) *Config {
return &Config{ return &Config{
ListenHTTP: listenHTTP, ListenHTTP: listenHTTP,
FirebaseKeyFile: "", FirebaseKeyFile: "",
CacheFile: "", CacheFile: "",
CacheDuration: DefaultCacheDuration, CacheDuration: DefaultCacheDuration,
KeepaliveInterval: DefaultKeepaliveInterval, KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval, ManagerInterval: DefaultManagerInterval,
GlobalTopicLimit: defaultGlobalTopicLimit, GlobalTopicLimit: DefaultGlobalTopicLimit,
VisitorRequestLimit: defaultVisitorRequestLimit, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitBurst: defaultVisitorRequestLimitBurst, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorSubscriptionLimit: defaultVisitorSubscriptionLimit, VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
BehindProxy: false,
} }
} }

View File

@ -25,6 +25,30 @@
# #
# keepalive-interval: 30s # keepalive-interval: 30s
# Interval in which the manager prunes old messages, deletes topics and prints the stats. # Interval in which the manager prunes old messages, deletes topics
# and prints the stats.
# #
# manager-interval: 1m # manager-interval: 1m
# Rate limiting: Total number of topics before the server rejects new topics.
#
# global-topic-limit: 5000
# Rate limiting: Number of subscriptions per visitor (IP address)
#
# visitor-subscription-limit: 30
# Rate limiting: Allowed GET/PUT/POST requests per second, per visitor:
# - visitor-request-limit-burst is the initial bucket of requests each visitor has
# - visitor-request-limit-replenish is the rate at which the bucket is refilled
#
# visitor-request-limit-burst: 60
# visitor-request-limit-replenish: 10s
# If set, the X-Forwarded-For header is used to determine the visitor IP address
# instead of the remote address of the connection.
#
# WARNING: If you are behind a proxy, you must set this, otherwise all visitors are rate limited
# as if they are one.
#
# behind-proxy: false

View File

@ -159,24 +159,22 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
v := s.visitor(r.RemoteAddr)
if err := v.RequestAllowed(); err != nil {
return err
}
if r.Method == http.MethodGet && r.URL.Path == "/" { if r.Method == http.MethodGet && r.URL.Path == "/" {
return s.handleHome(w, r) return s.handleHome(w, r)
} else if r.Method == http.MethodHead && r.URL.Path == "/" {
return s.handleEmpty(w, r)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
return s.handleStatic(w, r) return s.handleStatic(w, r)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
return s.handlePublish(w, r, v)
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
return s.handleSubscribeJSON(w, r, v)
} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
return s.handleSubscribeSSE(w, r, v)
} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
return s.handleSubscribeRaw(w, r, v)
} else if r.Method == http.MethodOptions { } else if r.Method == http.MethodOptions {
return s.handleOptions(w, r) return s.handleOptions(w, r)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handlePublish)
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeJSON)
} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeSSE)
} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeRaw)
} }
return errHTTPNotFound return errHTTPNotFound
} }
@ -186,6 +184,10 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
return err return err
} }
func (s *Server) handleEmpty(w http.ResponseWriter, r *http.Request) error {
return nil
}
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error { func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r) http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r)
return nil return nil
@ -394,15 +396,27 @@ func (s *Server) updateStatsAndExpire() {
s.messages, len(s.topics), subscribers, messages, len(s.visitors)) s.messages, len(s.topics), subscribers, messages, len(s.visitors))
} }
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
v := s.visitor(r)
if err := v.RequestAllowed(); err != nil {
return err
}
return handler(w, r, v)
}
// visitor creates or retrieves a rate.Limiter for the given visitor. // visitor creates or retrieves a rate.Limiter for the given visitor.
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(remoteAddr string) *visitor { func (s *Server) visitor(r *http.Request) *visitor {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr) ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil { if err != nil {
ip = remoteAddr // This should not happen in real life; only in tests. ip = remoteAddr // This should not happen in real life; only in tests.
} }
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
ip = r.Header.Get("X-Forwarded-For")
}
v, exists := s.visitors[ip] v, exists := s.visitors[ip]
if !exists { if !exists {
s.visitors[ip] = newVisitor(s.config) s.visitors[ip] = newVisitor(s.config)

View File

@ -24,7 +24,7 @@ type visitor struct {
func newVisitor(conf *config.Config) *visitor { func newVisitor(conf *config.Config) *visitor {
return &visitor{ return &visitor{
config: conf, config: conf,
limiter: rate.NewLimiter(conf.VisitorRequestLimit, conf.VisitorRequestLimitBurst), limiter: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)), subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
seen: time.Now(), seen: time.Now(),
} }