diff --git a/cmd/serve.go b/cmd/serve.go index 952c426e..3cc01143 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -5,16 +5,18 @@ package cmd import ( "errors" "fmt" - "heckel.io/ntfy/log" "io/fs" "math" "net" + "net/netip" "os" "os/signal" "strings" "syscall" "time" + "heckel.io/ntfy/log" + "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" "heckel.io/ntfy/server" @@ -208,15 +210,15 @@ func execServe(c *cli.Context) error { } // Resolve hosts - visitorRequestLimitExemptIPs := make([]string, 0) + visitorRequestLimitExemptIPs := make([]netip.Prefix, 0) for _, host := range visitorRequestLimitExemptHosts { - ips, err := net.LookupIP(host) + ips, err := parseIPHostPrefix(host) if err != nil { log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error()) continue } for _, ip := range ips { - visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String()) + visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip) } } @@ -303,6 +305,33 @@ func sigHandlerConfigReload(config string) { } } +func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { + //try parsing as prefix + prefix, err := netip.ParsePrefix(host) + if err == nil { + prefixes = append(prefixes, prefix.Masked()) // masked and canonical for easy of debugging, shouldn't matter + return prefixes, nil // success + } + + // not a prefix, parse as host or IP + // LookupHost forwards through if it's an IP + ips, err := net.LookupHost(host) + if err == nil { + for _, i := range ips { + ip, err := netip.ParseAddr(i) + if err == nil { + prefix, err := ip.Prefix(ip.BitLen()) + if err != nil { + return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error())) + } + prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip + } + } + } + return +} + + func reloadLogLevel(inputSource altsrc.InputSourceContext) { newLevelStr, err := inputSource.String("log-level") if err != nil { diff --git a/server/config.go b/server/config.go index e117da88..d8fd429e 100644 --- a/server/config.go +++ b/server/config.go @@ -2,6 +2,7 @@ package server import ( "io/fs" + "net/netip" "time" ) @@ -92,7 +93,7 @@ type Config struct { VisitorAttachmentDailyBandwidthLimit int VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration - VisitorRequestExemptIPAddrs []string + VisitorRequestExemptIPAddrs []netip.Prefix VisitorEmailLimitBurst int VisitorEmailLimitReplenish time.Duration BehindProxy bool @@ -135,7 +136,7 @@ func NewConfig() *Config { VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, - VisitorRequestExemptIPAddrs: make([]string, 0), + VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0), VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, BehindProxy: false, diff --git a/server/message_cache.go b/server/message_cache.go index a2f49e75..4845a918 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -5,11 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "net/netip" + "strings" + "time" + _ "github.com/mattn/go-sqlite3" // SQLite driver "heckel.io/ntfy/log" "heckel.io/ntfy/util" - "strings" - "time" ) var ( @@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error { attachmentSize, attachmentExpires, attachmentURL, - m.Sender, + m.Sender.String(), m.Encoding, published, ) @@ -477,7 +479,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) { Icon: icon, Actions: actions, Attachment: att, - Sender: sender, + Sender: netip.MustParseAddr(sender), // Must parse assuming database must be correct Encoding: encoding, }) } diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 23c080d4..ea9580a5 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -3,11 +3,13 @@ package server import ( "database/sql" "fmt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "net/netip" "path/filepath" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSqliteCache_Messages(t *testing.T) { @@ -281,7 +283,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires1 := time.Now().Add(-4 * time.Hour).Unix() m := newDefaultMessage("mytopic", "flower for you") m.ID = "m1" - m.Sender = "1.2.3.4" + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "flower.jpg", Type: "image/jpeg", @@ -294,7 +296,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires2 := time.Now().Add(2 * time.Hour).Unix() // Future m = newDefaultMessage("mytopic", "sending you a car") m.ID = "m2" - m.Sender = "1.2.3.4" + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "car.jpg", Type: "image/jpeg", @@ -307,7 +309,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires3 := time.Now().Add(1 * time.Hour).Unix() // Future m = newDefaultMessage("another-topic", "sending you another car") m.ID = "m3" - m.Sender = "1.2.3.4" + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "another-car.jpg", Type: "image/jpeg", diff --git a/server/server.go b/server/server.go index 276e56fa..0b9cb21a 100644 --- a/server/server.go +++ b/server/server.go @@ -11,6 +11,7 @@ import ( "io" "net" "net/http" + "net/netip" "net/url" "os" "path" @@ -42,7 +43,7 @@ type Server struct { smtpServerBackend *smtpBackend smtpSender mailer topics map[string]*topic - visitors map[string]*visitor + visitors map[netip.Addr]*visitor firebaseClient *firebaseClient messages int64 auth auth.Auther @@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) { smtpSender: mailer, topics: topics, auth: auther, - visitors: make(map[string]*visitor), + visitors: make(map[netip.Addr]*visitor), }, nil } @@ -642,8 +643,8 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca return false, false, "", false, errHTTPBadRequestDelayTooLarge } m.Time = delay.Unix() - m.Sender = v.ip // Important for rate limiting } + m.Sender = v.ip // Important for rate limiting actionsStr := readParam(r, "x-actions", "actions", "action") if actionsStr != "" { m.Actions, err = parseActions(actionsStr) @@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() { if s.firebaseClient == nil { return } - v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor + v := newVisitor(s.config, s.messageCache, netip.MustParseAddr("0.0.0.0")) // Background process, not a real visitor for { select { case <-time.After(s.config.FirebaseKeepaliveInterval): @@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error { func (s *Server) limitRequests(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) { + if util.ContainsContains(s.config.VisitorRequestExemptIPAddrs, v.ip) { return next(w, r, v) } else if err := v.RequestAllowed(); err != nil { return errHTTPTooManyRequestsLimitRequests @@ -1436,21 +1437,29 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). func (s *Server) visitor(r *http.Request) *visitor { remoteAddr := r.RemoteAddr - ip, _, err := net.SplitHostPort(remoteAddr) + ipport, err := netip.ParseAddrPort(remoteAddr) + ip := ipport.Addr() if err != nil { - ip = remoteAddr // This should not happen in real life; only in tests. + ip = netip.MustParseAddr(remoteAddr) // This should not happen in real life; only in tests. So, using MustParse, which panics on error. } if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" { // X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy, // only the right-most address can be trusted (as this is the one added by our proxy server). // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details. ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",") - ip = strings.TrimSpace(util.LastString(ips, remoteAddr)) + myip, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr))) + if err != nil { + log.Error("Invalid IP Address Received from proxy in X-Forwarded-For header. This should NEVER happen, your proxy is seriously broken: ", ip, err) + // fall back to regular remote address if x forwarded for is damaged + } else { + ip = myip + } + } return s.visitorFromIP(ip) } -func (s *Server) visitorFromIP(ip string) *visitor { +func (s *Server) visitorFromIP(ip netip.Addr) *visitor { s.mu.Lock() defer s.mu.Unlock() v, exists := s.visitors[ip] diff --git a/server/server_firebase_test.go b/server/server_firebase_test.go index 3e034c06..36fd8b51 100644 --- a/server/server_firebase_test.go +++ b/server/server_firebase_test.go @@ -3,13 +3,15 @@ package server import ( "encoding/json" "errors" - "firebase.google.com/go/v4/messaging" "fmt" - "github.com/stretchr/testify/require" - "heckel.io/ntfy/auth" + "net/netip" "strings" "sync" "testing" + + "firebase.google.com/go/v4/messaging" + "github.com/stretchr/testify/require" + "heckel.io/ntfy/auth" ) type testAuther struct { @@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) { func TestToFirebaseSender_Abuse(t *testing.T) { sender := &testFirebaseSender{allowed: 2} client := newFirebaseClient(sender, &testAuther{}) - visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4") + visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4")) require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) require.Equal(t, 1, len(sender.Messages())) diff --git a/server/server_matrix_test.go b/server/server_matrix_test.go index b2f9b1d5..4b5a66c4 100644 --- a/server/server_matrix_test.go +++ b/server/server_matrix_test.go @@ -1,11 +1,13 @@ package server import ( - "github.com/stretchr/testify/require" "net/http" "net/http/httptest" + "net/netip" "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) { @@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) { func TestMatrix_WriteMatrixError(t *testing.T) { w := httptest.NewRecorder() r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil) - v := newVisitor(newTestConfig(t), nil, "1.2.3.4") + v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4")) require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch})) require.Equal(t, 200, w.Result().StatusCode) require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String()) diff --git a/server/server_test.go b/server/server_test.go index ea3495d6..5a3dcc8d 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -6,18 +6,20 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/stretchr/testify/assert" "io" "log" "math/rand" "net/http" "net/http/httptest" + "net/netip" "path/filepath" "strings" "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "heckel.io/ntfy/auth" "heckel.io/ntfy/util" @@ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) { func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) { c := newTestConfig(t) - c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request() + c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request() s := newTestServer(t, c) for i := 0; i < 65; i++ { // > 60 response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil) diff --git a/server/smtp_sender.go b/server/smtp_sender.go index ecefd9c2..7d6b7519 100644 --- a/server/smtp_sender.go +++ b/server/smtp_sender.go @@ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error { if err != nil { return err } - message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m) + message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m) if err != nil { return err } diff --git a/server/types.go b/server/types.go index b217b9db..ce57c9b5 100644 --- a/server/types.go +++ b/server/types.go @@ -1,9 +1,11 @@ package server import ( - "heckel.io/ntfy/util" "net/http" + "net/netip" "time" + + "heckel.io/ntfy/util" ) // List of possible events @@ -33,7 +35,7 @@ type message struct { Actions []*action `json:"actions,omitempty"` Attachment *attachment `json:"attachment,omitempty"` PollID string `json:"poll_id,omitempty"` - Sender string `json:"-"` // IP address of uploader, used for rate limiting + Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes } diff --git a/server/visitor.go b/server/visitor.go index 5a8e186b..cd120c43 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -2,10 +2,12 @@ package server import ( "errors" - "golang.org/x/time/rate" - "heckel.io/ntfy/util" + "net/netip" "sync" "time" + + "golang.org/x/time/rate" + "heckel.io/ntfy/util" ) const ( @@ -23,7 +25,7 @@ var ( type visitor struct { config *Config messageCache *messageCache - ip string + ip netip.Addr requests *rate.Limiter emails *rate.Limiter subscriptions util.Limiter @@ -40,7 +42,7 @@ type visitorStats struct { VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"` } -func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor { +func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor { return &visitor{ config: conf, messageCache: messageCache, @@ -115,7 +117,7 @@ func (v *visitor) Stale() bool { } func (v *visitor) Stats() (*visitorStats, error) { - attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip) + attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String()) if err != nil { return nil, err } diff --git a/util/util.go b/util/util.go index 05079180..de4b908f 100644 --- a/util/util.go +++ b/util/util.go @@ -5,8 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gabriel-vasile/mimetype" - "golang.org/x/term" "io" "math/rand" "os" @@ -15,6 +13,9 @@ import ( "strings" "sync" "time" + + "github.com/gabriel-vasile/mimetype" + "golang.org/x/term" ) const ( @@ -45,6 +46,16 @@ func Contains[T comparable](haystack []T, needle T) bool { return false } +// ContainsContains returns true if any element of haystack .Contains(needle). +func ContainsContains[T interface{ Contains(U) bool }, U any](haystack []T, needle U) bool { + for _, s := range haystack { + if s.Contains(needle) { + return true + } + } + return false +} + // ContainsAll returns true if all needles are contained in haystack func ContainsAll[T comparable](haystack []T, needles []T) bool { matches := 0