mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-12-27 19:05:07 +03:00
Merge branch 'ip-range-exempt'
This commit is contained in:
commit
cbc912d1e3
37
cmd/serve.go
37
cmd/serve.go
@ -5,16 +5,18 @@ package cmd
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/log"
|
||||
"io/fs"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/log"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
"github.com/urfave/cli/v2/altsrc"
|
||||
"heckel.io/ntfy/server"
|
||||
@ -208,16 +210,14 @@ func execServe(c *cli.Context) error {
|
||||
}
|
||||
|
||||
// Resolve hosts
|
||||
visitorRequestLimitExemptIPs := make([]string, 0)
|
||||
visitorRequestLimitExemptIPs := make([]netip.Prefix, 0)
|
||||
for _, host := range visitorRequestLimitExemptHosts {
|
||||
ips, err := net.LookupIP(host)
|
||||
ips, err := parseIPHostPrefix(host)
|
||||
if err != nil {
|
||||
log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error())
|
||||
continue
|
||||
}
|
||||
for _, ip := range ips {
|
||||
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String())
|
||||
}
|
||||
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...)
|
||||
}
|
||||
|
||||
// Run server
|
||||
@ -303,6 +303,31 @@ func sigHandlerConfigReload(config string) {
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
|
||||
// Try parsing as prefix, e.g. 10.0.1.0/24
|
||||
prefix, err := netip.ParsePrefix(host)
|
||||
if err == nil {
|
||||
prefixes = append(prefixes, prefix.Masked())
|
||||
return prefixes, nil
|
||||
}
|
||||
// Not a prefix, parse as host or IP (LookupHost passes through an IP as is)
|
||||
ips, err := net.LookupHost(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, ipStr := range ips {
|
||||
ip, err := netip.ParseAddr(ipStr)
|
||||
if err == nil {
|
||||
prefix, err := ip.Prefix(ip.BitLen())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s successfully parsed but unable to make prefix: %s", ip.String(), err.Error())
|
||||
}
|
||||
prefixes = append(prefixes, prefix.Masked())
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func reloadLogLevel(inputSource altsrc.InputSourceContext) {
|
||||
newLevelStr, err := inputSource.String("log-level")
|
||||
if err != nil {
|
||||
|
@ -2,17 +2,19 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/client"
|
||||
"heckel.io/ntfy/test"
|
||||
"heckel.io/ntfy/util"
|
||||
"math/rand"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/client"
|
||||
"heckel.io/ntfy/test"
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -70,6 +72,22 @@ func TestCLI_Serve_WebSocket(t *testing.T) {
|
||||
require.Equal(t, "mytopic", m.Topic)
|
||||
}
|
||||
|
||||
func TestIP_Host_Parsing(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"1.1.1.1": "1.1.1.1/32",
|
||||
"fd00::1234": "fd00::1234/128",
|
||||
"192.168.0.3/24": "192.168.0.0/24",
|
||||
"10.1.2.3/8": "10.0.0.0/8",
|
||||
"201:be93::4a6/21": "201:b800::/21",
|
||||
}
|
||||
for q, expectedAnswer := range cases {
|
||||
ips, err := parseIPHostPrefix(q)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, len(ips))
|
||||
assert.Equal(t, expectedAnswer, ips[0].String())
|
||||
}
|
||||
}
|
||||
|
||||
func newEmptyFile(t *testing.T) string {
|
||||
filename := filepath.Join(t.TempDir(), "empty")
|
||||
require.Nil(t, os.WriteFile(filename, []byte{}, 0600))
|
||||
|
@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -92,7 +93,7 @@ type Config struct {
|
||||
VisitorAttachmentDailyBandwidthLimit int
|
||||
VisitorRequestLimitBurst int
|
||||
VisitorRequestLimitReplenish time.Duration
|
||||
VisitorRequestExemptIPAddrs []string
|
||||
VisitorRequestExemptIPAddrs []netip.Prefix
|
||||
VisitorEmailLimitBurst int
|
||||
VisitorEmailLimitReplenish time.Duration
|
||||
BehindProxy bool
|
||||
@ -135,7 +136,7 @@ func NewConfig() *Config {
|
||||
VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
|
||||
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
|
||||
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
|
||||
VisitorRequestExemptIPAddrs: make([]string, 0),
|
||||
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
|
||||
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
|
||||
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
||||
BehindProxy: false,
|
||||
|
@ -5,11 +5,13 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error {
|
||||
attachmentSize,
|
||||
attachmentExpires,
|
||||
attachmentURL,
|
||||
m.Sender,
|
||||
m.Sender.String(),
|
||||
m.Encoding,
|
||||
published,
|
||||
)
|
||||
@ -454,6 +456,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
senderIP, err := netip.ParseAddr(sender)
|
||||
if err != nil {
|
||||
senderIP = netip.IPv4Unspecified() // if no IP stored in database, 0.0.0.0
|
||||
}
|
||||
|
||||
var att *attachment
|
||||
if attachmentName != "" && attachmentURL != "" {
|
||||
att = &attachment{
|
||||
@ -477,7 +484,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
|
||||
Icon: icon,
|
||||
Actions: actions,
|
||||
Attachment: att,
|
||||
Sender: sender,
|
||||
Sender: senderIP, // Must parse assuming database must be correct
|
||||
Encoding: encoding,
|
||||
})
|
||||
}
|
||||
|
@ -3,11 +3,17 @@ package server
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
exampleIP1234 = netip.MustParseAddr("1.2.3.4")
|
||||
)
|
||||
|
||||
func TestSqliteCache_Messages(t *testing.T) {
|
||||
@ -281,7 +287,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||
expires1 := time.Now().Add(-4 * time.Hour).Unix()
|
||||
m := newDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.Sender = "1.2.3.4"
|
||||
m.Sender = exampleIP1234
|
||||
m.Attachment = &attachment{
|
||||
Name: "flower.jpg",
|
||||
Type: "image/jpeg",
|
||||
@ -294,7 +300,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
||||
m = newDefaultMessage("mytopic", "sending you a car")
|
||||
m.ID = "m2"
|
||||
m.Sender = "1.2.3.4"
|
||||
m.Sender = exampleIP1234
|
||||
m.Attachment = &attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
@ -307,7 +313,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
||||
m = newDefaultMessage("another-topic", "sending you another car")
|
||||
m.ID = "m3"
|
||||
m.Sender = "1.2.3.4"
|
||||
m.Sender = exampleIP1234
|
||||
m.Attachment = &attachment{
|
||||
Name: "another-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
@ -327,7 +333,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
||||
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[0].Sender)
|
||||
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
||||
|
||||
require.Equal(t, "sending you a car", messages[1].Message)
|
||||
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
||||
@ -335,7 +341,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
||||
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[1].Sender)
|
||||
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
||||
|
||||
size, err := c.AttachmentBytesUsed("1.2.3.4")
|
||||
require.Nil(t, err)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
@ -42,7 +43,7 @@ type Server struct {
|
||||
smtpServerBackend *smtpBackend
|
||||
smtpSender mailer
|
||||
topics map[string]*topic
|
||||
visitors map[string]*visitor
|
||||
visitors map[netip.Addr]*visitor
|
||||
firebaseClient *firebaseClient
|
||||
messages int64
|
||||
auth auth.Auther
|
||||
@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) {
|
||||
smtpSender: mailer,
|
||||
topics: topics,
|
||||
auth: auther,
|
||||
visitors: make(map[string]*visitor),
|
||||
visitors: make(map[netip.Addr]*visitor),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() {
|
||||
if s.firebaseClient == nil {
|
||||
return
|
||||
}
|
||||
v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
|
||||
v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0
|
||||
for {
|
||||
select {
|
||||
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||
@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||
|
||||
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
||||
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
||||
return next(w, r, v)
|
||||
} else if err := v.RequestAllowed(); err != nil {
|
||||
return errHTTPTooManyRequestsLimitRequests
|
||||
@ -1436,21 +1437,33 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
|
||||
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
||||
func (s *Server) visitor(r *http.Request) *visitor {
|
||||
remoteAddr := r.RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
||||
ip := addrPort.Addr()
|
||||
if err != nil {
|
||||
ip = remoteAddr // This should not happen in real life; only in tests.
|
||||
// This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified
|
||||
ip, err = netip.ParseAddr(remoteAddr)
|
||||
if err != nil {
|
||||
ip = netip.IPv4Unspecified()
|
||||
log.Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err)
|
||||
}
|
||||
}
|
||||
if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" {
|
||||
// X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
|
||||
// only the right-most address can be trusted (as this is the one added by our proxy server).
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
|
||||
ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
|
||||
ip = strings.TrimSpace(util.LastString(ips, remoteAddr))
|
||||
realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
|
||||
if err != nil {
|
||||
log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error())
|
||||
// Fall back to regular remote address if X-Forwarded-For is damaged
|
||||
} else {
|
||||
ip = realIP
|
||||
}
|
||||
}
|
||||
return s.visitorFromIP(ip)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromIP(ip string) *visitor {
|
||||
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, exists := s.visitors[ip]
|
||||
|
@ -3,13 +3,15 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"firebase.google.com/go/v4/messaging"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/auth"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"firebase.google.com/go/v4/messaging"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/auth"
|
||||
)
|
||||
|
||||
type testAuther struct {
|
||||
@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
|
||||
func TestToFirebaseSender_Abuse(t *testing.T) {
|
||||
sender := &testFirebaseSender{allowed: 2}
|
||||
client := newFirebaseClient(sender, &testAuther{})
|
||||
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
|
||||
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4"))
|
||||
|
||||
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Equal(t, 1, len(sender.Messages()))
|
||||
|
@ -1,11 +1,13 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) {
|
||||
@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) {
|
||||
func TestMatrix_WriteMatrixError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil)
|
||||
v := newVisitor(newTestConfig(t), nil, "1.2.3.4")
|
||||
v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4"))
|
||||
require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch}))
|
||||
require.Equal(t, 200, w.Result().StatusCode)
|
||||
require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String())
|
||||
|
@ -6,18 +6,20 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/auth"
|
||||
"heckel.io/ntfy/util"
|
||||
@ -292,13 +294,13 @@ func TestServer_PublishAt(t *testing.T) {
|
||||
messages = toMessages(t, response.Body.String())
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "a message", messages[0].Message)
|
||||
require.Equal(t, "", messages[0].Sender) // Never return the sender!
|
||||
require.Equal(t, netip.Addr{}, messages[0].Sender) // Never return the sender!
|
||||
|
||||
messages, err := s.messageCache.Messages("mytopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "a message", messages[0].Message)
|
||||
require.Equal(t, "9.9.9.9", messages[0].Sender) // It's stored in the DB though!
|
||||
require.Equal(t, "9.9.9.9", messages[0].Sender.String()) // It's stored in the DB though!
|
||||
}
|
||||
|
||||
func TestServer_PublishAtWithCacheError(t *testing.T) {
|
||||
@ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) {
|
||||
|
||||
func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
|
||||
c := newTestConfig(t)
|
||||
c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request()
|
||||
c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request()
|
||||
s := newTestServer(t, c)
|
||||
for i := 0; i < 65; i++ { // > 60
|
||||
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)
|
||||
@ -1132,7 +1134,7 @@ func TestServer_PublishAttachment(t *testing.T) {
|
||||
require.Equal(t, int64(5000), msg.Attachment.Size)
|
||||
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours
|
||||
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
|
||||
require.Equal(t, "", msg.Sender) // Should never be returned
|
||||
require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned
|
||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
|
||||
|
||||
// GET
|
||||
@ -1168,7 +1170,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) {
|
||||
require.Equal(t, int64(21), msg.Attachment.Size)
|
||||
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix())
|
||||
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
|
||||
require.Equal(t, "", msg.Sender) // Should never be returned
|
||||
require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned
|
||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
|
||||
|
||||
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
|
||||
@ -1195,7 +1197,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) {
|
||||
require.Equal(t, "", msg.Attachment.Type)
|
||||
require.Equal(t, int64(0), msg.Attachment.Size)
|
||||
require.Equal(t, int64(0), msg.Attachment.Expires)
|
||||
require.Equal(t, "", msg.Sender)
|
||||
require.Equal(t, netip.Addr{}, msg.Sender)
|
||||
|
||||
// Slightly unrelated cross-test: make sure we don't add an owner for external attachments
|
||||
size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1")
|
||||
@ -1216,7 +1218,7 @@ func TestServer_PublishAttachmentExternalWithFilename(t *testing.T) {
|
||||
require.Equal(t, "", msg.Attachment.Type)
|
||||
require.Equal(t, int64(0), msg.Attachment.Size)
|
||||
require.Equal(t, int64(0), msg.Attachment.Expires)
|
||||
require.Equal(t, "", msg.Sender)
|
||||
require.Equal(t, netip.Addr{}, msg.Sender)
|
||||
}
|
||||
|
||||
func TestServer_PublishAttachmentBadURL(t *testing.T) {
|
||||
@ -1391,7 +1393,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
|
||||
r.RemoteAddr = "8.9.10.11"
|
||||
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
|
||||
v := s.visitor(r)
|
||||
require.Equal(t, "8.9.10.11", v.ip)
|
||||
require.Equal(t, "8.9.10.11", v.ip.String())
|
||||
}
|
||||
|
||||
func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
|
||||
@ -1402,7 +1404,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
|
||||
r.RemoteAddr = "8.9.10.11"
|
||||
r.Header.Set("X-Forwarded-For", "1.1.1.1")
|
||||
v := s.visitor(r)
|
||||
require.Equal(t, "1.1.1.1", v.ip)
|
||||
require.Equal(t, "1.1.1.1", v.ip.String())
|
||||
}
|
||||
|
||||
func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
|
||||
@ -1413,7 +1415,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
|
||||
r.RemoteAddr = "8.9.10.11"
|
||||
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
|
||||
v := s.visitor(r)
|
||||
require.Equal(t, "234.5.2.1", v.ip)
|
||||
require.Equal(t, "234.5.2.1", v.ip.String())
|
||||
}
|
||||
|
||||
func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
||||
|
@ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m)
|
||||
message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1,9 +1,11 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/util"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
// List of possible events
|
||||
@ -33,7 +35,7 @@ type message struct {
|
||||
Actions []*action `json:"actions,omitempty"`
|
||||
Attachment *attachment `json:"attachment,omitempty"`
|
||||
PollID string `json:"poll_id,omitempty"`
|
||||
Sender string `json:"-"` // IP address of uploader, used for rate limiting
|
||||
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
||||
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
|
||||
}
|
||||
|
||||
|
@ -2,10 +2,12 @@ package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"golang.org/x/time/rate"
|
||||
"heckel.io/ntfy/util"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -23,7 +25,7 @@ var (
|
||||
type visitor struct {
|
||||
config *Config
|
||||
messageCache *messageCache
|
||||
ip string
|
||||
ip netip.Addr
|
||||
requests *rate.Limiter
|
||||
emails *rate.Limiter
|
||||
subscriptions util.Limiter
|
||||
@ -40,7 +42,7 @@ type visitorStats struct {
|
||||
VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"`
|
||||
}
|
||||
|
||||
func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
|
||||
func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor {
|
||||
return &visitor{
|
||||
config: conf,
|
||||
messageCache: messageCache,
|
||||
@ -115,7 +117,7 @@ func (v *visitor) Stale() bool {
|
||||
}
|
||||
|
||||
func (v *visitor) Stats() (*visitorStats, error) {
|
||||
attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip)
|
||||
attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
16
util/util.go
16
util/util.go
@ -5,16 +5,18 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"golang.org/x/term"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -45,6 +47,16 @@ func Contains[T comparable](haystack []T, needle T) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ContainsIP returns true if any one of the of prefixes contains the ip.
|
||||
func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool {
|
||||
for _, s := range haystack {
|
||||
if s.Contains(needle) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ContainsAll returns true if all needles are contained in haystack
|
||||
func ContainsAll[T comparable](haystack []T, needles []T) bool {
|
||||
matches := 0
|
||||
|
@ -1,10 +1,12 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRandomString(t *testing.T) {
|
||||
@ -42,6 +44,13 @@ func TestContains(t *testing.T) {
|
||||
require.False(t, Contains(s, 3))
|
||||
}
|
||||
|
||||
func TestContainsIP(t *testing.T) {
|
||||
require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.1.1.1")))
|
||||
require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fd12:1234:5678::9876")))
|
||||
require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.2.0.1")))
|
||||
require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fc00::1")))
|
||||
}
|
||||
|
||||
func TestSplitNoEmpty(t *testing.T) {
|
||||
require.Equal(t, []string{}, SplitNoEmpty("", ","))
|
||||
require.Equal(t, []string{}, SplitNoEmpty(",,,", ","))
|
||||
|
Loading…
Reference in New Issue
Block a user