mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-26 23:53:21 +03:00
Fix rate limiting behind proxy, make configurable
This commit is contained in:
parent
86a16e3944
commit
0170f673bd
19
cmd/app.go
19
cmd/app.go
@ -22,8 +22,13 @@ func New() *cli.App {
|
|||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
|
||||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: config.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}),
|
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: config.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}),
|
||||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "default interval of keepalive messages"}),
|
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}),
|
||||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "default interval of for message pruning and stats printing"}),
|
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}),
|
||||||
|
altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: config.DefaultGlobalTopicLimit, Usage: "total number of topics allowed"}),
|
||||||
|
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"V"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: config.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}),
|
||||||
|
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"B"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: config.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
|
||||||
|
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"R"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: config.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
|
||||||
|
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
|
||||||
}
|
}
|
||||||
return &cli.App{
|
return &cli.App{
|
||||||
Name: "ntfy",
|
Name: "ntfy",
|
||||||
@ -50,6 +55,11 @@ func execRun(c *cli.Context) error {
|
|||||||
cacheDuration := c.Duration("cache-duration")
|
cacheDuration := c.Duration("cache-duration")
|
||||||
keepaliveInterval := c.Duration("keepalive-interval")
|
keepaliveInterval := c.Duration("keepalive-interval")
|
||||||
managerInterval := c.Duration("manager-interval")
|
managerInterval := c.Duration("manager-interval")
|
||||||
|
globalTopicLimit := c.Int("global-topic-limit")
|
||||||
|
visitorSubscriptionLimit := c.Int("visitor-subscription-limit")
|
||||||
|
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
|
||||||
|
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
|
||||||
|
behindProxy := c.Bool("behind-proxy")
|
||||||
|
|
||||||
// Check values
|
// Check values
|
||||||
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
||||||
@ -69,6 +79,11 @@ func execRun(c *cli.Context) error {
|
|||||||
conf.CacheDuration = cacheDuration
|
conf.CacheDuration = cacheDuration
|
||||||
conf.KeepaliveInterval = keepaliveInterval
|
conf.KeepaliveInterval = keepaliveInterval
|
||||||
conf.ManagerInterval = managerInterval
|
conf.ManagerInterval = managerInterval
|
||||||
|
conf.GlobalTopicLimit = globalTopicLimit
|
||||||
|
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
|
||||||
|
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
|
||||||
|
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
|
||||||
|
conf.BehindProxy = behindProxy
|
||||||
s, err := server.New(conf)
|
s, err := server.New(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang.org/x/time/rate"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,42 +14,44 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Defines all the limits
|
// Defines all the limits
|
||||||
// - request limit: max number of PUT/GET/.. requests (here: 50 requests bucket, replenished at a rate of one per 10 seconds)
|
|
||||||
// - global topic limit: max number of topics overall
|
// - global topic limit: max number of topics overall
|
||||||
// - subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
|
// - per visistor request limit: max number of PUT/GET/.. requests (here: 60 requests bucket, replenished at a rate of one per 10 seconds)
|
||||||
var (
|
// - per visistor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
|
||||||
defaultGlobalTopicLimit = 5000
|
const (
|
||||||
defaultVisitorRequestLimit = rate.Every(10 * time.Second)
|
DefaultGlobalTopicLimit = 5000
|
||||||
defaultVisitorRequestLimitBurst = 60
|
DefaultVisitorRequestLimitBurst = 60
|
||||||
defaultVisitorSubscriptionLimit = 30
|
DefaultVisitorRequestLimitReplenish = 10 * time.Second
|
||||||
|
DefaultVisitorSubscriptionLimit = 30
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config is the main config struct for the application. Use New to instantiate a default config struct.
|
// Config is the main config struct for the application. Use New to instantiate a default config struct.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ListenHTTP string
|
ListenHTTP string
|
||||||
FirebaseKeyFile string
|
FirebaseKeyFile string
|
||||||
CacheFile string
|
CacheFile string
|
||||||
CacheDuration time.Duration
|
CacheDuration time.Duration
|
||||||
KeepaliveInterval time.Duration
|
KeepaliveInterval time.Duration
|
||||||
ManagerInterval time.Duration
|
ManagerInterval time.Duration
|
||||||
GlobalTopicLimit int
|
GlobalTopicLimit int
|
||||||
VisitorRequestLimit rate.Limit
|
VisitorRequestLimitBurst int
|
||||||
VisitorRequestLimitBurst int
|
VisitorRequestLimitReplenish time.Duration
|
||||||
VisitorSubscriptionLimit int
|
VisitorSubscriptionLimit int
|
||||||
|
BehindProxy bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New instantiates a default new config
|
// New instantiates a default new config
|
||||||
func New(listenHTTP string) *Config {
|
func New(listenHTTP string) *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
ListenHTTP: listenHTTP,
|
ListenHTTP: listenHTTP,
|
||||||
FirebaseKeyFile: "",
|
FirebaseKeyFile: "",
|
||||||
CacheFile: "",
|
CacheFile: "",
|
||||||
CacheDuration: DefaultCacheDuration,
|
CacheDuration: DefaultCacheDuration,
|
||||||
KeepaliveInterval: DefaultKeepaliveInterval,
|
KeepaliveInterval: DefaultKeepaliveInterval,
|
||||||
ManagerInterval: DefaultManagerInterval,
|
ManagerInterval: DefaultManagerInterval,
|
||||||
GlobalTopicLimit: defaultGlobalTopicLimit,
|
GlobalTopicLimit: DefaultGlobalTopicLimit,
|
||||||
VisitorRequestLimit: defaultVisitorRequestLimit,
|
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
|
||||||
VisitorRequestLimitBurst: defaultVisitorRequestLimitBurst,
|
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
|
||||||
VisitorSubscriptionLimit: defaultVisitorSubscriptionLimit,
|
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
|
||||||
|
BehindProxy: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,6 +25,30 @@
|
|||||||
#
|
#
|
||||||
# keepalive-interval: 30s
|
# keepalive-interval: 30s
|
||||||
|
|
||||||
# Interval in which the manager prunes old messages, deletes topics and prints the stats.
|
# Interval in which the manager prunes old messages, deletes topics
|
||||||
|
# and prints the stats.
|
||||||
#
|
#
|
||||||
# manager-interval: 1m
|
# manager-interval: 1m
|
||||||
|
|
||||||
|
# Rate limiting: Total number of topics before the server rejects new topics.
|
||||||
|
#
|
||||||
|
# global-topic-limit: 5000
|
||||||
|
|
||||||
|
# Rate limiting: Number of subscriptions per visitor (IP address)
|
||||||
|
#
|
||||||
|
# visitor-subscription-limit: 30
|
||||||
|
|
||||||
|
# Rate limiting: Allowed GET/PUT/POST requests per second, per visitor:
|
||||||
|
# - visitor-request-limit-burst is the initial bucket of requests each visitor has
|
||||||
|
# - visitor-request-limit-replenish is the rate at which the bucket is refilled
|
||||||
|
#
|
||||||
|
# visitor-request-limit-burst: 60
|
||||||
|
# visitor-request-limit-replenish: 10s
|
||||||
|
|
||||||
|
# If set, the X-Forwarded-For header is used to determine the visitor IP address
|
||||||
|
# instead of the remote address of the connection.
|
||||||
|
#
|
||||||
|
# WARNING: If you are behind a proxy, you must set this, otherwise all visitors are rate limited
|
||||||
|
# as if they are one.
|
||||||
|
#
|
||||||
|
# behind-proxy: false
|
||||||
|
@ -159,24 +159,22 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
||||||
v := s.visitor(r.RemoteAddr)
|
|
||||||
if err := v.RequestAllowed(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
||||||
return s.handleHome(w, r)
|
return s.handleHome(w, r)
|
||||||
|
} else if r.Method == http.MethodHead && r.URL.Path == "/" {
|
||||||
|
return s.handleEmpty(w, r)
|
||||||
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
|
||||||
return s.handleStatic(w, r)
|
return s.handleStatic(w, r)
|
||||||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
|
|
||||||
return s.handlePublish(w, r, v)
|
|
||||||
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
|
|
||||||
return s.handleSubscribeJSON(w, r, v)
|
|
||||||
} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
|
|
||||||
return s.handleSubscribeSSE(w, r, v)
|
|
||||||
} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
|
|
||||||
return s.handleSubscribeRaw(w, r, v)
|
|
||||||
} else if r.Method == http.MethodOptions {
|
} else if r.Method == http.MethodOptions {
|
||||||
return s.handleOptions(w, r)
|
return s.handleOptions(w, r)
|
||||||
|
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
|
||||||
|
return s.withRateLimit(w, r, s.handlePublish)
|
||||||
|
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
|
||||||
|
return s.withRateLimit(w, r, s.handleSubscribeJSON)
|
||||||
|
} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
|
||||||
|
return s.withRateLimit(w, r, s.handleSubscribeSSE)
|
||||||
|
} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
|
||||||
|
return s.withRateLimit(w, r, s.handleSubscribeRaw)
|
||||||
}
|
}
|
||||||
return errHTTPNotFound
|
return errHTTPNotFound
|
||||||
}
|
}
|
||||||
@ -186,6 +184,10 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleEmpty(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
|
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
|
||||||
http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r)
|
http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r)
|
||||||
return nil
|
return nil
|
||||||
@ -394,15 +396,27 @@ func (s *Server) updateStatsAndExpire() {
|
|||||||
s.messages, len(s.topics), subscribers, messages, len(s.visitors))
|
s.messages, len(s.topics), subscribers, messages, len(s.visitors))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
|
||||||
|
v := s.visitor(r)
|
||||||
|
if err := v.RequestAllowed(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return handler(w, r, v)
|
||||||
|
}
|
||||||
|
|
||||||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||||
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
||||||
func (s *Server) visitor(remoteAddr string) *visitor {
|
func (s *Server) visitor(r *http.Request) *visitor {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
remoteAddr := r.RemoteAddr
|
||||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ip = remoteAddr // This should not happen in real life; only in tests.
|
ip = remoteAddr // This should not happen in real life; only in tests.
|
||||||
}
|
}
|
||||||
|
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
|
||||||
|
ip = r.Header.Get("X-Forwarded-For")
|
||||||
|
}
|
||||||
v, exists := s.visitors[ip]
|
v, exists := s.visitors[ip]
|
||||||
if !exists {
|
if !exists {
|
||||||
s.visitors[ip] = newVisitor(s.config)
|
s.visitors[ip] = newVisitor(s.config)
|
||||||
|
@ -24,7 +24,7 @@ type visitor struct {
|
|||||||
func newVisitor(conf *config.Config) *visitor {
|
func newVisitor(conf *config.Config) *visitor {
|
||||||
return &visitor{
|
return &visitor{
|
||||||
config: conf,
|
config: conf,
|
||||||
limiter: rate.NewLimiter(conf.VisitorRequestLimit, conf.VisitorRequestLimitBurst),
|
limiter: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
|
||||||
subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
|
subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
|
||||||
seen: time.Now(),
|
seen: time.Now(),
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user