mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-23 03:24:27 +03:00
Fix a bunch of FIXMEs
This commit is contained in:
parent
f945fb4cdd
commit
3bd6518309
9
go.mod
9
go.mod
@ -25,7 +25,10 @@ require (
|
|||||||
|
|
||||||
require github.com/pkg/errors v0.9.1 // indirect
|
require github.com/pkg/errors v0.9.1 // indirect
|
||||||
|
|
||||||
require firebase.google.com/go/v4 v4.10.0
|
require (
|
||||||
|
firebase.google.com/go/v4 v4.10.0
|
||||||
|
github.com/stripe/stripe-go/v74 v74.5.0
|
||||||
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cloud.google.com/go v0.107.0 // indirect
|
cloud.google.com/go v0.107.0 // indirect
|
||||||
@ -46,10 +49,6 @@ require (
|
|||||||
github.com/googleapis/gax-go/v2 v2.7.0 // indirect
|
github.com/googleapis/gax-go/v2 v2.7.0 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||||
github.com/stripe/stripe-go/v74 v74.5.0 // indirect
|
|
||||||
github.com/tidwall/gjson v1.14.4 // indirect
|
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
|
||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
|
||||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
||||||
go.opencensus.io v0.24.0 // indirect
|
go.opencensus.io v0.24.0 // indirect
|
||||||
golang.org/x/net v0.4.0 // indirect
|
golang.org/x/net v0.4.0 // indirect
|
||||||
|
7
go.sum
7
go.sum
@ -102,13 +102,6 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs
|
|||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stripe/stripe-go/v74 v74.5.0 h1:YyqTvVQdS34KYGCfVB87EMn9eDV3FCFkSwfdOQhiVL4=
|
github.com/stripe/stripe-go/v74 v74.5.0 h1:YyqTvVQdS34KYGCfVB87EMn9eDV3FCFkSwfdOQhiVL4=
|
||||||
github.com/stripe/stripe-go/v74 v74.5.0/go.mod h1:5PoXNp30AJ3tGq57ZcFuaMylzNi8KpwlrYAFmO1fHZw=
|
github.com/stripe/stripe-go/v74 v74.5.0/go.mod h1:5PoXNp30AJ3tGq57ZcFuaMylzNi8KpwlrYAFmO1fHZw=
|
||||||
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
|
|
||||||
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
|
||||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
|
||||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
|
||||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
|
||||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
|
||||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
|
||||||
github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY=
|
github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY=
|
||||||
github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc=
|
github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc=
|
||||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
|
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
|
||||||
|
@ -19,6 +19,7 @@ const (
|
|||||||
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
|
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
|
||||||
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
|
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
|
||||||
DefaultFirebaseQuotaExceededPenaltyDuration = 10 * time.Minute // Time that over-users are locked out of Firebase if it returns "quota exceeded"
|
DefaultFirebaseQuotaExceededPenaltyDuration = 10 * time.Minute // Time that over-users are locked out of Firebase if it returns "quota exceeded"
|
||||||
|
DefaultStripePriceCacheDuration = time.Hour // Time to keep Stripe prices cached in memory before a refresh is needed
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines all global and per-visitor limits
|
// Defines all global and per-visitor limits
|
||||||
@ -112,10 +113,12 @@ type Config struct {
|
|||||||
BehindProxy bool
|
BehindProxy bool
|
||||||
StripeSecretKey string
|
StripeSecretKey string
|
||||||
StripeWebhookKey string
|
StripeWebhookKey string
|
||||||
|
StripePriceCacheDuration time.Duration
|
||||||
EnableWeb bool
|
EnableWeb bool
|
||||||
EnableSignup bool // Enable creation of accounts via API and UI
|
EnableSignup bool // Enable creation of accounts via API and UI
|
||||||
EnableLogin bool
|
EnableLogin bool
|
||||||
EnableReservations bool // Allow users with role "user" to own/reserve topics
|
EnableReservations bool // Allow users with role "user" to own/reserve topics
|
||||||
|
AccessControlAllowOrigin string // CORS header field to restrict access from web clients
|
||||||
Version string // injected by App
|
Version string // injected by App
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,9 +135,11 @@ func NewConfig() *Config {
|
|||||||
FirebaseKeyFile: "",
|
FirebaseKeyFile: "",
|
||||||
CacheFile: "",
|
CacheFile: "",
|
||||||
CacheDuration: DefaultCacheDuration,
|
CacheDuration: DefaultCacheDuration,
|
||||||
|
CacheStartupQueries: "",
|
||||||
CacheBatchSize: 0,
|
CacheBatchSize: 0,
|
||||||
CacheBatchTimeout: 0,
|
CacheBatchTimeout: 0,
|
||||||
AuthFile: "",
|
AuthFile: "",
|
||||||
|
AuthStartupQueries: "",
|
||||||
AuthDefault: user.NewPermission(true, true),
|
AuthDefault: user.NewPermission(true, true),
|
||||||
AttachmentCacheDir: "",
|
AttachmentCacheDir: "",
|
||||||
AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit,
|
AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit,
|
||||||
@ -142,14 +147,24 @@ func NewConfig() *Config {
|
|||||||
AttachmentExpiryDuration: DefaultAttachmentExpiryDuration,
|
AttachmentExpiryDuration: DefaultAttachmentExpiryDuration,
|
||||||
KeepaliveInterval: DefaultKeepaliveInterval,
|
KeepaliveInterval: DefaultKeepaliveInterval,
|
||||||
ManagerInterval: DefaultManagerInterval,
|
ManagerInterval: DefaultManagerInterval,
|
||||||
MessageLimit: DefaultMessageLengthLimit,
|
WebRootIsApp: false,
|
||||||
MinDelay: DefaultMinDelay,
|
|
||||||
MaxDelay: DefaultMaxDelay,
|
|
||||||
DelayedSenderInterval: DefaultDelayedSenderInterval,
|
DelayedSenderInterval: DefaultDelayedSenderInterval,
|
||||||
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
|
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
|
||||||
FirebasePollInterval: DefaultFirebasePollInterval,
|
FirebasePollInterval: DefaultFirebasePollInterval,
|
||||||
FirebaseQuotaExceededPenaltyDuration: DefaultFirebaseQuotaExceededPenaltyDuration,
|
FirebaseQuotaExceededPenaltyDuration: DefaultFirebaseQuotaExceededPenaltyDuration,
|
||||||
|
UpstreamBaseURL: "",
|
||||||
|
SMTPSenderAddr: "",
|
||||||
|
SMTPSenderUser: "",
|
||||||
|
SMTPSenderPass: "",
|
||||||
|
SMTPSenderFrom: "",
|
||||||
|
SMTPServerListen: "",
|
||||||
|
SMTPServerDomain: "",
|
||||||
|
SMTPServerAddrPrefix: "",
|
||||||
|
MessageLimit: DefaultMessageLengthLimit,
|
||||||
|
MinDelay: DefaultMinDelay,
|
||||||
|
MaxDelay: DefaultMaxDelay,
|
||||||
TotalTopicLimit: DefaultTotalTopicLimit,
|
TotalTopicLimit: DefaultTotalTopicLimit,
|
||||||
|
TotalAttachmentSizeLimit: 0,
|
||||||
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
|
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
|
||||||
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,
|
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,
|
||||||
VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
|
VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
|
||||||
@ -162,7 +177,14 @@ func NewConfig() *Config {
|
|||||||
VisitorAccountCreateLimitReplenish: DefaultVisitorAccountCreateLimitReplenish,
|
VisitorAccountCreateLimitReplenish: DefaultVisitorAccountCreateLimitReplenish,
|
||||||
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
|
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
|
||||||
BehindProxy: false,
|
BehindProxy: false,
|
||||||
|
StripeSecretKey: "",
|
||||||
|
StripeWebhookKey: "",
|
||||||
|
StripePriceCacheDuration: DefaultStripePriceCacheDuration,
|
||||||
EnableWeb: true,
|
EnableWeb: true,
|
||||||
|
EnableSignup: false,
|
||||||
|
EnableLogin: false,
|
||||||
|
EnableReservations: false,
|
||||||
|
AccessControlAllowOrigin: "*",
|
||||||
Version: "",
|
Version: "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -39,21 +39,18 @@ import (
|
|||||||
payments:
|
payments:
|
||||||
- send dunning emails when overdue
|
- send dunning emails when overdue
|
||||||
- payment methods
|
- payment methods
|
||||||
- unmarshal to stripe.Subscription instead of gjson
|
|
||||||
- delete subscription when account deleted
|
- delete subscription when account deleted
|
||||||
- delete messages + reserved topics on ResetTier
|
- delete messages + reserved topics on ResetTier
|
||||||
|
|
||||||
- move v1/account/tiers to v1/tiers
|
|
||||||
|
|
||||||
Limits & rate limiting:
|
Limits & rate limiting:
|
||||||
users without tier: should the stats be persisted? are they meaningful?
|
users without tier: should the stats be persisted? are they meaningful?
|
||||||
-> test that the visitor is based on the IP address!
|
-> test that the visitor is based on the IP address!
|
||||||
login/account endpoints
|
login/account endpoints
|
||||||
when ResetStats() is run, reset messagesLimiter (and others)?
|
when ResetStats() is run, reset messagesLimiter (and others)?
|
||||||
update last_seen when API is accessed
|
|
||||||
Make sure account endpoints make sense for admins
|
Make sure account endpoints make sense for admins
|
||||||
|
|
||||||
UI:
|
UI:
|
||||||
|
- revert home page change
|
||||||
- flicker of upgrade banner
|
- flicker of upgrade banner
|
||||||
- JS constants
|
- JS constants
|
||||||
Sync:
|
Sync:
|
||||||
@ -82,7 +79,7 @@ type Server struct {
|
|||||||
userManager *user.Manager // Might be nil!
|
userManager *user.Manager // Might be nil!
|
||||||
messageCache *messageCache
|
messageCache *messageCache
|
||||||
fileCache *fileCache
|
fileCache *fileCache
|
||||||
priceCache map[string]string // Stripe price ID -> formatted price
|
priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
|
||||||
closeChan chan bool
|
closeChan chan bool
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
@ -144,7 +141,8 @@ const (
|
|||||||
emptyMessageBody = "triggered" // Used if message body is empty
|
emptyMessageBody = "triggered" // Used if message body is empty
|
||||||
newMessageBody = "New message" // Used in poll requests as generic message
|
newMessageBody = "New message" // Used in poll requests as generic message
|
||||||
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
|
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
|
||||||
encodingBase64 = "base64"
|
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
|
||||||
|
jsonBodyBytesLimit = 16384
|
||||||
)
|
)
|
||||||
|
|
||||||
// WebSocket constants
|
// WebSocket constants
|
||||||
@ -201,7 +199,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
topics: topics,
|
topics: topics,
|
||||||
userManager: userManager,
|
userManager: userManager,
|
||||||
visitors: make(map[string]*visitor),
|
visitors: make(map[string]*visitor),
|
||||||
priceCache: make(map[string]string),
|
priceCache: util.NewLookupCache(fetchStripePrices, conf.StripePriceCacheDuration),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -454,22 +452,14 @@ func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
|
||||||
_, err := io.WriteString(w, `{"success":true}`+"\n")
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||||
response := &apiHealthResponse{
|
response := &apiHealthResponse{
|
||||||
Healthy: true,
|
Healthy: true,
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "text/json")
|
return s.writeJSON(w, response)
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||||
@ -620,12 +610,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, m)
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
|
||||||
if err := json.NewEncoder(w).Encode(m); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
@ -1175,7 +1160,7 @@ func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
|
|||||||
|
|
||||||
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
|
w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -1482,7 +1467,7 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
|
|||||||
// before passing it on to the next handler. This is meant to be used in combination with handlePublish.
|
// before passing it on to the next handler. This is meant to be used in combination with handlePublish.
|
||||||
func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
|
func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit)
|
m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit*2) // 2x to account for JSON format overhead
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1650,3 +1635,12 @@ func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
|
|||||||
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
|
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
|
||||||
return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user)
|
return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
|
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
jsonBodyBytesLimit = 4096
|
|
||||||
subscriptionIDLength = 16
|
subscriptionIDLength = 16
|
||||||
createdByAPI = "api"
|
createdByAPI = "api"
|
||||||
syncTopicAccountSyncEvent = "sync"
|
syncTopicAccountSyncEvent = "sync"
|
||||||
@ -38,9 +37,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
|||||||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, createdByAPI); err != nil { // TODO this should return a User
|
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, createdByAPI); err != nil { // TODO this should return a User
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||||
@ -118,21 +115,14 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
|
|||||||
response.Username = user.Everyone
|
response.Username = user.Everyone
|
||||||
response.Role = string(user.RoleAnonymous)
|
response.Role = string(user.RoleAnonymous)
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, response)
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||||
if err := s.userManager.RemoveUser(v.user.Name); err != nil {
|
if err := s.userManager.RemoveUser(v.user.Name); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
@ -143,9 +133,7 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ
|
|||||||
if err := s.userManager.ChangePassword(v.user.Name, newPassword.Password); err != nil {
|
if err := s.userManager.ChangePassword(v.user.Name, newPassword.Password); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountTokenIssue(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
func (s *Server) handleAccountTokenIssue(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||||
@ -154,16 +142,11 @@ func (s *Server) handleAccountTokenIssue(w http.ResponseWriter, _ *http.Request,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
response := &apiAccountTokenResponse{
|
response := &apiAccountTokenResponse{
|
||||||
Token: token.Value,
|
Token: token.Value,
|
||||||
Expires: token.Expires.Unix(),
|
Expires: token.Expires.Unix(),
|
||||||
}
|
}
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
return s.writeJSON(w, response)
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountTokenExtend(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
func (s *Server) handleAccountTokenExtend(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||||
@ -177,16 +160,11 @@ func (s *Server) handleAccountTokenExtend(w http.ResponseWriter, _ *http.Request
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
response := &apiAccountTokenResponse{
|
response := &apiAccountTokenResponse{
|
||||||
Token: token.Value,
|
Token: token.Value,
|
||||||
Expires: token.Expires.Unix(),
|
Expires: token.Expires.Unix(),
|
||||||
}
|
}
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
return s.writeJSON(w, response)
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||||
@ -197,8 +175,7 @@ func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, _ *http.Request
|
|||||||
if err := s.userManager.RemoveToken(v.user); err != nil {
|
if err := s.userManager.RemoveToken(v.user); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
@ -230,9 +207,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
|
|||||||
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
@ -257,12 +232,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSubscription)
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
if err := json.NewEncoder(w).Encode(newSubscription); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
@ -292,12 +262,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.
|
|||||||
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, subscription)
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
if err := json.NewEncoder(w).Encode(subscription); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
@ -321,9 +286,7 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
@ -366,9 +329,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
|||||||
if err := s.userManager.AllowAccess(owner, user.Everyone, req.Topic, everyone.IsRead(), everyone.IsWrite()); err != nil {
|
if err := s.userManager.AllowAccess(owner, user.Everyone, req.Topic, everyone.IsRead(), everyone.IsWrite()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
@ -392,9 +353,7 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
|||||||
if err := s.userManager.ResetAccess(user.Everyone, topic); err != nil {
|
if err := s.userManager.ResetAccess(user.Everyone, topic); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) publishSyncEvent(v *visitor) error {
|
func (s *Server) publishSyncEvent(v *visitor) error {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -11,19 +12,15 @@ import (
|
|||||||
"github.com/stripe/stripe-go/v74/price"
|
"github.com/stripe/stripe-go/v74/price"
|
||||||
"github.com/stripe/stripe-go/v74/subscription"
|
"github.com/stripe/stripe-go/v74/subscription"
|
||||||
"github.com/stripe/stripe-go/v74/webhook"
|
"github.com/stripe/stripe-go/v74/webhook"
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"heckel.io/ntfy/log"
|
"heckel.io/ntfy/log"
|
||||||
"heckel.io/ntfy/user"
|
"heckel.io/ntfy/user"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
stripeBodyBytesLimit = 16384
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errNotAPaidTier = errors.New("tier does not have billing price identifier")
|
errNotAPaidTier = errors.New("tier does not have billing price identifier")
|
||||||
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
|
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
|
||||||
@ -52,22 +49,14 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tier := range tiers {
|
prices, err := s.priceCache.Value()
|
||||||
if tier.StripePriceID == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
priceStr, ok := s.priceCache[tier.StripePriceID]
|
|
||||||
if !ok {
|
|
||||||
p, err := price.Get(tier.StripePriceID, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if p.UnitAmount%100 == 0 {
|
for _, tier := range tiers {
|
||||||
priceStr = fmt.Sprintf("$%d", p.UnitAmount/100)
|
priceStr, ok := prices[tier.StripePriceID]
|
||||||
} else {
|
if tier.StripePriceID == "" || !ok {
|
||||||
priceStr = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
continue
|
||||||
}
|
|
||||||
s.priceCache[tier.StripePriceID] = priceStr // FIXME race, make this sync.Map or something
|
|
||||||
}
|
}
|
||||||
response = append(response, &apiAccountBillingTier{
|
response = append(response, &apiAccountBillingTier{
|
||||||
Code: tier.Code,
|
Code: tier.Code,
|
||||||
@ -84,12 +73,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, response)
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
|
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
|
||||||
@ -143,12 +127,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||||||
response := &apiAccountBillingSubscriptionCreateResponse{
|
response := &apiAccountBillingSubscriptionCreateResponse{
|
||||||
RedirectURL: sess.URL,
|
RedirectURL: sess.URL,
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, response)
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||||
@ -219,12 +198,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
||||||
@ -239,12 +213,7 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
@ -262,12 +231,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
|||||||
response := &apiAccountBillingPortalRedirectResponse{
|
response := &apiAccountBillingPortalRedirectResponse{
|
||||||
RedirectURL: ps.URL,
|
RedirectURL: ps.URL,
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
return s.writeJSON(w, response)
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
||||||
@ -278,7 +242,7 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|||||||
if stripeSignature == "" {
|
if stripeSignature == "" {
|
||||||
return errHTTPBadRequestBillingRequestInvalid
|
return errHTTPBadRequestBillingRequestInvalid
|
||||||
}
|
}
|
||||||
body, err := util.Peek(r.Body, stripeBodyBytesLimit)
|
body, err := util.Peek(r.Body, jsonBodyBytesLimit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if body.LimitReached {
|
} else if body.LimitReached {
|
||||||
@ -302,25 +266,23 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
|
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
|
||||||
subscriptionID := gjson.GetBytes(event, "id")
|
r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||||
customerID := gjson.GetBytes(event, "customer")
|
if err != nil {
|
||||||
status := gjson.GetBytes(event, "status")
|
return err
|
||||||
currentPeriodEnd := gjson.GetBytes(event, "current_period_end")
|
} else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
|
||||||
cancelAt := gjson.GetBytes(event, "cancel_at")
|
|
||||||
priceID := gjson.GetBytes(event, "items.data.0.price.id")
|
|
||||||
if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
|
|
||||||
return errHTTPBadRequestBillingRequestInvalid
|
return errHTTPBadRequestBillingRequestInvalid
|
||||||
}
|
}
|
||||||
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID)
|
subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
|
||||||
u, err := s.userManager.UserByStripeCustomer(customerID.String())
|
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
|
||||||
|
u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tier, err := s.userManager.TierByStripePrice(priceID.String())
|
tier, err := s.userManager.TierByStripePrice(priceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := s.updateSubscriptionAndTier(u, customerID.String(), subscriptionID.String(), status.String(), currentPeriodEnd.Int(), cancelAt.Int(), tier.Code); err != nil {
|
if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||||
@ -328,16 +290,18 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
|
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
|
||||||
customerID := gjson.GetBytes(event, "customer")
|
r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||||
if !customerID.Exists() {
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if r.Customer == "" {
|
||||||
return errHTTPBadRequestBillingRequestInvalid
|
return errHTTPBadRequestBillingRequestInvalid
|
||||||
}
|
}
|
||||||
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String())
|
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
|
||||||
u, err := s.userManager.UserByStripeCustomer(customerID.String())
|
u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := s.updateSubscriptionAndTier(u, customerID.String(), "", "", 0, 0, ""); err != nil {
|
if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||||
@ -364,3 +328,27 @@ func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptio
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
|
||||||
|
// in memory, and ultimately for the web app to display the price table.
|
||||||
|
func fetchStripePrices() (map[string]string, error) {
|
||||||
|
log.Debug("Caching prices from Stripe API")
|
||||||
|
prices := make(map[string]string)
|
||||||
|
iter := price.List(&stripe.PriceListParams{
|
||||||
|
Active: stripe.Bool(true),
|
||||||
|
})
|
||||||
|
for iter.Next() {
|
||||||
|
p := iter.Price()
|
||||||
|
if p.UnitAmount%100 == 0 {
|
||||||
|
prices[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
|
||||||
|
} else {
|
||||||
|
prices[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
||||||
|
}
|
||||||
|
log.Trace("- Caching price %s = %v", p.ID, prices[p.ID])
|
||||||
|
}
|
||||||
|
if iter.Err() != nil {
|
||||||
|
log.Warn("Fetching Stripe prices failed: %s", iter.Err().Error())
|
||||||
|
return nil, iter.Err()
|
||||||
|
}
|
||||||
|
return prices, nil
|
||||||
|
}
|
||||||
|
@ -1463,7 +1463,7 @@ func TestServer_PublishAttachmentBandwidthLimit(t *testing.T) {
|
|||||||
msg := toMessage(t, response.Body.String())
|
msg := toMessage(t, response.Body.String())
|
||||||
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
|
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
|
||||||
|
|
||||||
// Get it 4 times successfully
|
// Value it 4 times successfully
|
||||||
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
|
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
|
||||||
for i := 1; i <= 4; i++ { // 4 successful downloads
|
for i := 1; i <= 4; i++ { // 4 successful downloads
|
||||||
response = request(t, s, "GET", path, "", nil)
|
response = request(t, s, "GET", path, "", nil)
|
||||||
|
@ -336,3 +336,22 @@ func newSuccessResponse() *apiSuccessResponse {
|
|||||||
Success: true,
|
Success: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type apiStripeSubscriptionUpdatedEvent struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Customer string `json:"customer"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
CurrentPeriodEnd int64 `json:"current_period_end"`
|
||||||
|
CancelAt int64 `json:"cancel_at"`
|
||||||
|
Items *struct {
|
||||||
|
Data []*struct {
|
||||||
|
Price *struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
} `json:"price"`
|
||||||
|
} `json:"data"`
|
||||||
|
} `json:"items"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiStripeSubscriptionDeletedEvent struct {
|
||||||
|
Customer string `json:"customer"`
|
||||||
|
}
|
||||||
|
@ -66,7 +66,6 @@ const (
|
|||||||
stripe_subscription_cancel_at INT,
|
stripe_subscription_cancel_at INT,
|
||||||
created_by TEXT NOT NULL,
|
created_by TEXT NOT NULL,
|
||||||
created_at INT NOT NULL,
|
created_at INT NOT NULL,
|
||||||
last_seen INT NOT NULL,
|
|
||||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||||
);
|
);
|
||||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||||
@ -93,8 +92,8 @@ const (
|
|||||||
id INT PRIMARY KEY,
|
id INT PRIMARY KEY,
|
||||||
version INT NOT NULL
|
version INT NOT NULL
|
||||||
);
|
);
|
||||||
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at, last_seen)
|
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
|
||||||
VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH(), 0)
|
VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH())
|
||||||
ON CONFLICT (id) DO NOTHING;
|
ON CONFLICT (id) DO NOTHING;
|
||||||
`
|
`
|
||||||
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
|
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
|
||||||
@ -130,8 +129,8 @@ const (
|
|||||||
`
|
`
|
||||||
|
|
||||||
insertUserQuery = `
|
insertUserQuery = `
|
||||||
INSERT INTO user (user, pass, role, sync_topic, created_by, created_at, last_seen)
|
INSERT INTO user (user, pass, role, sync_topic, created_by, created_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
selectUsernamesQuery = `
|
selectUsernamesQuery = `
|
||||||
SELECT user
|
SELECT user
|
||||||
@ -257,8 +256,8 @@ const (
|
|||||||
ALTER TABLE user RENAME TO user_old;
|
ALTER TABLE user RENAME TO user_old;
|
||||||
`
|
`
|
||||||
migrate1To2InsertFromOldTablesAndDropNoTx = `
|
migrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||||
INSERT INTO user (user, pass, role, sync_topic, created_by, created_at, last_seen)
|
INSERT INTO user (user, pass, role, sync_topic, created_by, created_at)
|
||||||
SELECT user, pass, role, '', 'admin', UNIXEPOCH(), UNIXEPOCH() FROM user_old;
|
SELECT user, pass, role, '', 'admin', UNIXEPOCH() FROM user_old;
|
||||||
|
|
||||||
INSERT INTO user_access (user_id, topic, read, write)
|
INSERT INTO user_access (user_id, topic, read, write)
|
||||||
SELECT u.id, a.topic, a.read, a.write
|
SELECT u.id, a.topic, a.read, a.write
|
||||||
@ -531,7 +530,7 @@ func (a *Manager) AddUser(username, password string, role Role, createdBy string
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
syncTopic, now := util.RandomString(syncTopicLength), time.Now().Unix()
|
syncTopic, now := util.RandomString(syncTopicLength), time.Now().Unix()
|
||||||
if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now, now); err != nil {
|
if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -589,6 +588,7 @@ func (a *Manager) User(username string) (*User, error) {
|
|||||||
return a.readUser(rows)
|
return a.readUser(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
|
||||||
func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
|
func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
|
||||||
rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
|
rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -878,6 +878,7 @@ func (a *Manager) CreateTier(tier *Tier) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
|
||||||
func (a *Manager) ChangeBilling(user *User) error {
|
func (a *Manager) ChangeBilling(user *User) error {
|
||||||
if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(user.Billing.StripeSubscriptionCancelAt.Unix()), user.Name); err != nil {
|
if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(user.Billing.StripeSubscriptionCancelAt.Unix()), user.Name); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -885,6 +886,7 @@ func (a *Manager) ChangeBilling(user *User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tiers returns a list of all Tier structs
|
||||||
func (a *Manager) Tiers() ([]*Tier, error) {
|
func (a *Manager) Tiers() ([]*Tier, error) {
|
||||||
rows, err := a.db.Query(selectTiersQuery)
|
rows, err := a.db.Query(selectTiersQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -904,6 +906,7 @@ func (a *Manager) Tiers() ([]*Tier, error) {
|
|||||||
return tiers, nil
|
return tiers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
||||||
func (a *Manager) Tier(code string) (*Tier, error) {
|
func (a *Manager) Tier(code string) (*Tier, error) {
|
||||||
rows, err := a.db.Query(selectTierByCodeQuery, code)
|
rows, err := a.db.Query(selectTierByCodeQuery, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -913,6 +916,7 @@ func (a *Manager) Tier(code string) (*Tier, error) {
|
|||||||
return a.readTier(rows)
|
return a.readTier(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
||||||
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
|
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
|
||||||
rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
|
rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
52
util/lookup_cache.go
Normal file
52
util/lookup_cache.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LookupCache is a single-value cache with a time-to-live (TTL). The cache has a lookup function
|
||||||
|
// to retrieve the value and stores it until TTL is reached.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// lookup := func() (string, error) {
|
||||||
|
// r, _ := http.Get("...")
|
||||||
|
// s, _ := io.ReadAll(r.Body)
|
||||||
|
// return string(s), nil
|
||||||
|
// }
|
||||||
|
// c := NewLookupCache[string](lookup, time.Hour)
|
||||||
|
// fmt.Println(c.Get()) // Fetches the string via HTTP
|
||||||
|
// fmt.Println(c.Get()) // Uses cached value
|
||||||
|
type LookupCache[T any] struct {
|
||||||
|
value *T
|
||||||
|
lookup func() (T, error)
|
||||||
|
ttl time.Duration
|
||||||
|
updated time.Time
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLookupCache creates a new LookupCache with a given time-to-live (TTL)
|
||||||
|
func NewLookupCache[T any](lookup func() (T, error), ttl time.Duration) *LookupCache[T] {
|
||||||
|
return &LookupCache[T]{
|
||||||
|
value: nil,
|
||||||
|
lookup: lookup,
|
||||||
|
ttl: ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value returns the cached value, or retrieves it via the lookup function
|
||||||
|
func (c *LookupCache[T]) Value() (T, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.value == nil || (c.ttl > 0 && time.Since(c.updated) > c.ttl) {
|
||||||
|
value, err := c.lookup()
|
||||||
|
if err != nil {
|
||||||
|
var t T
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
c.value = &value
|
||||||
|
c.updated = time.Now()
|
||||||
|
}
|
||||||
|
return *c.value, nil
|
||||||
|
}
|
63
util/lookup_cache_test.go
Normal file
63
util/lookup_cache_test.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLookupCache_Success(t *testing.T) {
|
||||||
|
values, i := []string{"first", "second"}, 0
|
||||||
|
c := NewLookupCache[string](func() (string, error) {
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
v := values[i]
|
||||||
|
i++
|
||||||
|
return v, nil
|
||||||
|
}, 500*time.Millisecond)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
v, err := c.Value()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, values[0], v)
|
||||||
|
require.True(t, time.Since(start) >= 300*time.Millisecond)
|
||||||
|
|
||||||
|
start = time.Now()
|
||||||
|
v, err = c.Value()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, values[0], v)
|
||||||
|
require.True(t, time.Since(start) < 200*time.Millisecond)
|
||||||
|
|
||||||
|
time.Sleep(550 * time.Millisecond)
|
||||||
|
|
||||||
|
start = time.Now()
|
||||||
|
v, err = c.Value()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, values[1], v)
|
||||||
|
require.True(t, time.Since(start) >= 300*time.Millisecond)
|
||||||
|
|
||||||
|
start = time.Now()
|
||||||
|
v, err = c.Value()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, values[1], v)
|
||||||
|
require.True(t, time.Since(start) < 200*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupCache_Error(t *testing.T) {
|
||||||
|
c := NewLookupCache[string](func() (string, error) {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
return "", errors.New("some error")
|
||||||
|
}, 500*time.Millisecond)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
v, err := c.Value()
|
||||||
|
require.NotNil(t, err)
|
||||||
|
require.Equal(t, "", v)
|
||||||
|
require.True(t, time.Since(start) >= 200*time.Millisecond)
|
||||||
|
|
||||||
|
start = time.Now()
|
||||||
|
v, err = c.Value()
|
||||||
|
require.NotNil(t, err)
|
||||||
|
require.Equal(t, "", v)
|
||||||
|
require.True(t, time.Since(start) >= 200*time.Millisecond)
|
||||||
|
}
|
@ -24,11 +24,6 @@ class AccountApi {
|
|||||||
constructor() {
|
constructor() {
|
||||||
this.timer = null;
|
this.timer = null;
|
||||||
this.listener = null; // Fired when account is fetched from remote
|
this.listener = null; // Fired when account is fetched from remote
|
||||||
|
|
||||||
// Random ID used to identify this client when sending/receiving "sync" events
|
|
||||||
// to the sync topic of an account. This ID doesn't matter much, but it will prevent
|
|
||||||
// a client from reacting to its own message.
|
|
||||||
this.identity = Math.floor(Math.random() * 2586000);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
registerListener(listener) {
|
registerListener(listener) {
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
const config = window.config;
|
const config = window.config;
|
||||||
|
|
||||||
if (config.base_url === "") {
|
// The backend returns an empty base_url for the config struct,
|
||||||
|
// so the frontend (hey, that's us!) can use the current location.
|
||||||
|
if (!config.base_url || config.base_url === "") {
|
||||||
config.base_url = window.location.origin;
|
config.base_url = window.location.origin;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ import session from "./Session";
|
|||||||
// Notes:
|
// Notes:
|
||||||
// - As per docs, we only declare the indexable columns, not all columns
|
// - As per docs, we only declare the indexable columns, not all columns
|
||||||
|
|
||||||
|
// The IndexedDB database name is based on the logged-in user
|
||||||
const dbName = (session.username()) ? `ntfy-${session.username()}` : "ntfy";
|
const dbName = (session.username()) ? `ntfy-${session.username()}` : "ntfy";
|
||||||
const db = new Dexie(dbName);
|
const db = new Dexie(dbName);
|
||||||
|
|
||||||
|
@ -35,12 +35,8 @@ export const useConnectionListeners = (subscriptions, users) => {
|
|||||||
try {
|
try {
|
||||||
const data = JSON.parse(message.message);
|
const data = JSON.parse(message.message);
|
||||||
if (data.event === "sync") {
|
if (data.event === "sync") {
|
||||||
if (data.source !== accountApi.identity) {
|
|
||||||
console.log(`[ConnectionListener] Triggering account sync`);
|
console.log(`[ConnectionListener] Triggering account sync`);
|
||||||
await accountApi.sync();
|
await accountApi.sync();
|
||||||
} else {
|
|
||||||
console.log(`[ConnectionListener] I triggered the account sync, ignoring message`);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
console.log(`[ConnectionListener] Unknown message type. Doing nothing.`);
|
console.log(`[ConnectionListener] Unknown message type. Doing nothing.`);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user