publishSyncEvent, Stripe endpoint changes

This commit is contained in:
binwiederhier 2023-01-16 16:35:37 -05:00
parent 7faed3ee1e
commit 83de879894
14 changed files with 424 additions and 262 deletions

View File

@ -80,8 +80,8 @@ var flagsServe = append(
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-key", Aliases: []string{"stripe_key"}, EnvVars: []string{"NTFY_STRIPE_KEY"}, Value: "", Usage: "xxxxxxxxxxxxx"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "xxxxxxxxxxxx"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-secret-key", Aliases: []string{"stripe_secret_key"}, EnvVars: []string{"NTFY_STRIPE_SECRET_KEY"}, Value: "", Usage: "key used for the Stripe API communication, this enables payments"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "key required to validate the authenticity of incoming webhooks from Stripe"}),
)
var cmdServe = &cli.Command{
@ -153,7 +153,7 @@ func execServe(c *cli.Context) error {
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
behindProxy := c.Bool("behind-proxy")
stripeKey := c.String("stripe-key")
stripeSecretKey := c.String("stripe-secret-key")
stripeWebhookKey := c.String("stripe-webhook-key")
// Check values
@ -191,17 +191,17 @@ func execServe(c *cli.Context) error {
return errors.New("if upstream-base-url is set, base-url must also be set")
} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
} else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeKey != "") {
return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-key if auth-file is not set")
} else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeSecretKey != "") {
return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-secret-key if auth-file is not set")
} else if enableSignup && !enableLogin {
return errors.New("cannot set enable-signup without also setting enable-login")
} else if stripeKey != "" && (stripeWebhookKey == "" || baseURL == "") {
return errors.New("if stripe-key is set, stripe-webhook-key and base-url must also be set")
} else if stripeSecretKey != "" && (stripeWebhookKey == "" || baseURL == "") {
return errors.New("if stripe-secret-key is set, stripe-webhook-key and base-url must also be set")
}
webRootIsApp := webRoot == "app"
enableWeb := webRoot != "disable"
enablePayments := stripeKey != ""
enablePayments := stripeSecretKey != ""
// Default auth permissions
authDefault, err := user.ParsePermission(authDefaultAccess)
@ -246,8 +246,8 @@ func execServe(c *cli.Context) error {
}
// Stripe things
if stripeKey != "" {
stripe.Key = stripeKey
if stripeSecretKey != "" {
stripe.Key = stripeSecretKey
}
// Run server
@ -293,7 +293,7 @@ func execServe(c *cli.Context) error {
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
conf.BehindProxy = behindProxy
conf.StripeKey = stripeKey
conf.StripeSecretKey = stripeSecretKey
conf.StripeWebhookKey = stripeWebhookKey
conf.EnableWeb = enableWeb
conf.EnableSignup = enableSignup

View File

@ -110,7 +110,7 @@ type Config struct {
VisitorAccountCreateLimitReplenish time.Duration
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
BehindProxy bool
StripeKey string
StripeSecretKey string
StripeWebhookKey string
EnableWeb bool
EnableSignup bool // Enable creation of accounts via API and UI

View File

@ -40,12 +40,10 @@ import (
- send dunning emails when overdue
- payment methods
- unmarshal to stripe.Subscription instead of gjson
- Make ResetTier reset the stripe fields
- delete subscription when account deleted
- remove tier.paid
- add tier.visible
- fix tier selection boxes
- account sync after switching tiers
- delete messages + reserved topics on ResetTier
Limits & rate limiting:
users without tier: should the stats be persisted? are they meaningful?
@ -360,7 +358,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
} else if r.Method == http.MethodGet && r.URL.Path == accountPath {
return s.handleAccountGet(w, r, v) // Allowed by anonymous
} else if r.Method == http.MethodDelete && r.URL.Path == accountPath {
return s.ensureUser(s.handleAccountDelete)(w, r, v)
return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath {
return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
} else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath {
@ -368,27 +366,29 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
} else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath {
return s.ensureUser(s.handleAccountTokenDelete)(w, r, v)
} else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath {
return s.ensureUser(s.handleAccountSettingsChange)(w, r, v)
return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath {
return s.ensureUser(s.handleAccountSubscriptionAdd)(w, r, v)
return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v)
} else if r.Method == http.MethodPatch && accountSubscriptionSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.handleAccountSubscriptionChange)(w, r, v)
return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v)
} else if r.Method == http.MethodDelete && accountSubscriptionSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.handleAccountSubscriptionDelete)(w, r, v)
return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountReservationPath {
return s.ensureUser(s.handleAccountReservationAdd)(w, r, v)
return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
} else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.handleAccountReservationDelete)(w, r, v)
return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingSubscriptionPath {
return s.ensureUser(s.handleAccountBillingSubscriptionChange)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath {
return s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete)(w, r, v)
return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodGet && accountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context!
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
} else if r.Method == http.MethodPut && r.URL.Path == accountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath {
return s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate)(w, r, v)
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath {
return s.ensureUserManager(s.handleAccountBillingWebhook)(w, r, v)
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
@ -1423,12 +1423,12 @@ func (s *Server) sendDelayedMessages() error {
for _, m := range messages {
var v *visitor
if s.userManager != nil && m.User != "" {
user, err := s.userManager.User(m.User)
u, err := s.userManager.User(m.User)
if err != nil {
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
continue
}
v = s.visitorFromUser(user, m.Sender)
v = s.visitorFromUser(u, m.Sender)
} else {
v = s.visitorFromIP(m.Sender)
}
@ -1475,42 +1475,6 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
}
}
func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if !s.config.EnableWeb {
return errHTTPNotFound
}
return next(w, r, v)
}
}
func (s *Server) ensureUserManager(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.userManager == nil {
return errHTTPNotFound
}
return next(w, r, v)
}
}
func (s *Server) ensureUser(next handleFunc) handleFunc {
return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user == nil {
return errHTTPUnauthorized
}
return next(w, r, v)
})
}
func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
return next(w, r, v)
})
}
// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
// before passing it on to the next handler. This is meant to be used in combination with handlePublish.
func (s *Server) transformBodyJSON(next handleFunc) handleFunc {

View File

@ -164,12 +164,10 @@
# - enable-signup allows users to sign up via the web app, or API
# - enable-login allows users to log in via the web app, or API
# - enable-reservations allows users to reserve topics (if their tier allows it)
# - enable-payments enables payments integration [preliminary option, may change]
#
# enable-signup: false
# enable-login: false
# enable-reservations: false
# enable-payments: false
# Server URL of a Firebase/APNS-connected ntfy server (likely "https://ntfy.sh").
#
@ -216,6 +214,16 @@
# visitor-attachment-total-size-limit: "100M"
# visitor-attachment-daily-bandwidth-limit: "500M"
# Payments integration via Stripe
#
# - stripe-secret-key is the key used for the Stripe API communication. Setting this values
# enables payments in the ntfy web app (e.g. Upgrade dialog). See https://dashboard.stripe.com/apikeys.
# - stripe-webhook-key is the key required to validate the authenticity of incoming webhooks from Stripe.
# Webhooks are essential up keep the local database in sync with the payment provider. See https://dashboard.stripe.com/webhooks.
#
# stripe-secret-key:
# stripe-webhook-key:
# Log level, can be TRACE, DEBUG, INFO, WARN or ERROR
# This option can be hot-reloaded by calling "kill -HUP $pid" or "systemctl reload ntfy".
#

View File

@ -2,6 +2,8 @@ package server
import (
"encoding/json"
"errors"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"net/http"
@ -11,6 +13,7 @@ const (
jsonBodyBytesLimit = 4096
subscriptionIDLength = 16
createdByAPI = "api"
syncTopicAccountSyncEvent = "sync"
)
func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
@ -395,3 +398,37 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
return nil
}
func (s *Server) publishSyncEvent(v *visitor) error {
if v.user == nil || v.user.SyncTopic == "" {
return nil
}
log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic)
topics, err := s.topicsFromIDs(v.user.SyncTopic)
if err != nil {
return err
} else if len(topics) == 0 {
return errors.New("cannot retrieve sync topic")
}
syncTopic := topics[0]
messageBytes, err := json.Marshal(&apiAccountSyncTopicResponse{Event: syncTopicAccountSyncEvent})
if err != nil {
return err
}
m := newDefaultMessage(syncTopic.ID, string(messageBytes))
if err := syncTopic.Publish(v, m); err != nil {
return err
}
return nil
}
func (s *Server) publishSyncEventAsync(v *visitor) {
go func() {
if v.user == nil || v.user.SyncTopic == "" {
return
}
if err := s.publishSyncEvent(v); err != nil {
log.Trace("Error publishing to user %s's sync topic %s: %s", v.user.Name, v.user.SyncTopic, err.Error())
}
}()
}

View File

@ -0,0 +1,63 @@
package server
import (
"net/http"
)
func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if !s.config.EnableWeb {
return errHTTPNotFound
}
return next(w, r, v)
}
}
func (s *Server) ensureUserManager(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.userManager == nil {
return errHTTPNotFound
}
return next(w, r, v)
}
}
func (s *Server) ensureUser(next handleFunc) handleFunc {
return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user == nil {
return errHTTPUnauthorized
}
return next(w, r, v)
})
}
func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if !s.config.EnablePayments {
return errHTTPNotFound
}
return next(w, r, v)
}
}
func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
return next(w, r, v)
})
}
func (s *Server) withAccountSync(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user == nil {
return next(w, r, v)
}
err := next(w, r, v)
if err == nil {
s.publishSyncEventAsync(v)
}
return err
}
}

View File

@ -6,13 +6,14 @@ import (
"github.com/stripe/stripe-go/v74"
portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
"github.com/stripe/stripe-go/v74/checkout/session"
"github.com/stripe/stripe-go/v74/customer"
"github.com/stripe/stripe-go/v74/subscription"
"github.com/stripe/stripe-go/v74/webhook"
"github.com/tidwall/gjson"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"net/http"
"net/netip"
"time"
)
@ -20,15 +21,13 @@ const (
stripeBodyBytesLimit = 16384
)
// handleAccountBillingSubscriptionChange facilitates all subscription/tier changes, including payment flows.
//
// FIXME this should be two functions!
//
// It handles two cases:
// - Create subscription: Transition from a user without Stripe subscription to a paid subscription (Checkout flow)
// - Change subscription: Switching between Stripe prices (& tiers) by changing the Stripe subscription
func (s *Server) handleAccountBillingSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit)
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
// will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID != "" {
return errors.New("subscription already exists") //FIXME
}
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
if err != nil {
return err
}
@ -36,46 +35,21 @@ func (s *Server) handleAccountBillingSubscriptionChange(w http.ResponseWriter, r
if err != nil {
return err
}
if v.user.Billing.StripeSubscriptionID == "" && tier.StripePriceID != "" {
return s.handleAccountBillingSubscriptionAdd(w, v, tier)
} else if v.user.Billing.StripeSubscriptionID != "" {
return s.handleAccountBillingSubscriptionUpdate(w, v, tier)
if tier.StripePriceID == "" {
return errors.New("invalid tier") //FIXME
}
return errors.New("invalid state")
}
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
// and cancelling the Stripe subscription entirely
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
if v.user.Billing.StripeSubscriptionID != "" {
_, err := subscription.Cancel(v.user.Billing.StripeSubscriptionID, nil)
if err != nil {
return err
}
}
if err := s.userManager.ResetTier(v.user.Name); err != nil {
return err
}
v.user.Billing.StripeSubscriptionID = ""
v.user.Billing.StripeSubscriptionStatus = ""
v.user.Billing.StripeSubscriptionPaidUntil = time.Unix(0, 0)
v.user.Billing.StripeSubscriptionCancelAt = time.Unix(0, 0)
if err := s.userManager.ChangeBilling(v.user); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountBillingSubscriptionAdd(w http.ResponseWriter, v *visitor, tier *user.Tier) error {
log.Info("Stripe: No existing subscription, creating checkout flow")
var stripeCustomerID *string
if v.user.Billing.StripeCustomerID != "" {
stripeCustomerID = &v.user.Billing.StripeCustomerID
stripeCustomer, err := customer.Get(v.user.Billing.StripeCustomerID, nil)
if err != nil {
return err
} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
return errors.New("customer cannot have more than one subscription") //FIXME
}
successURL := s.config.BaseURL + accountBillingSubscriptionCheckoutSuccessTemplate
}
successURL := s.config.BaseURL + "/account" //+ accountBillingSubscriptionCheckoutSuccessTemplate
params := &stripe.CheckoutSessionParams{
Customer: stripeCustomerID, // A user may have previously deleted their subscription
ClientReferenceID: &v.user.Name, // FIXME Should be user ID
@ -106,36 +80,7 @@ func (s *Server) handleAccountBillingSubscriptionAdd(w http.ResponseWriter, v *v
return nil
}
func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, v *visitor, tier *user.Tier) error {
log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
if err != nil {
return err
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(tier.StripePriceID),
},
},
}
_, err = subscription.Update(sub.ID, params)
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
matches := accountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
@ -183,6 +128,66 @@ func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r
return nil
}
// handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
// a user's tier accordingly. This endpoint only works if there is an existing subscription.
func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID != "" {
return errors.New("no existing subscription for user")
}
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
if err != nil {
return err
}
tier, err := s.userManager.Tier(req.Tier)
if err != nil {
return err
}
log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
if err != nil {
return err
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(tier.StripePriceID),
},
},
}
_, err = subscription.Update(sub.ID, params)
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{} // FIXME
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
// and cancelling the Stripe subscription entirely
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
if v.user.Billing.StripeSubscriptionID != "" {
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(true),
}
_, err := subscription.Update(v.user.Billing.StripeSubscriptionID, params)
if err != nil {
return err
}
}
return nil
}
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
@ -206,8 +211,8 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
return nil
}
func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error {
// Note that the visitor (v) in this endpoint is the Stripe API, so we don't have v.user available
stripeSignature := r.Header.Get("Stripe-Signature")
if stripeSignature == "" {
return errHTTPBadRequestInvalidStripeRequest
@ -225,30 +230,27 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: webhook event %s received", event.Type)
stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer")
if !stripeCustomerID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
switch event.Type {
case "customer.subscription.updated":
return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw)
return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
case "customer.subscription.deleted":
return s.handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID.String(), event.Data.Raw)
return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
default:
return nil
}
}
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error {
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
subscriptionID := gjson.GetBytes(event, "id")
customerID := gjson.GetBytes(event, "customer")
status := gjson.GetBytes(event, "status")
currentPeriodEnd := gjson.GetBytes(event, "current_period_end")
cancelAt := gjson.GetBytes(event, "cancel_at")
priceID := gjson.GetBytes(event, "items.data.0.price.id")
if !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID)
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
u, err := s.userManager.UserByStripeCustomer(customerID.String())
if err != nil {
return err
}
@ -259,22 +261,25 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
u.Billing.StripeSubscriptionID = subscriptionID.String()
u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status.String())
u.Billing.StripeSubscriptionPaidUntil = time.Unix(currentPeriodEnd.Int(), 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt.Int(), 0)
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
log.Info("Stripe: customer %s: subscription updated to %s, with price %s", customerID.String(), status, priceID)
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil
}
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID string, event json.RawMessage) error {
status := gjson.GetBytes(event, "status")
if !status.Exists() {
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
stripeCustomerID := gjson.GetBytes(event, "customer")
if !stripeCustomerID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID)
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID.String())
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID.String())
if err != nil {
return err
}
@ -288,5 +293,6 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil
}

View File

@ -305,7 +305,7 @@ type apiConfigResponse struct {
DisallowedTopics []string `json:"disallowed_topics"`
}
type apiAccountTierChangeRequest struct {
type apiAccountBillingSubscriptionChangeRequest struct {
Tier string `json:"tier"`
}
@ -316,3 +316,7 @@ type apiAccountCheckoutResponse struct {
type apiAccountBillingPortalRedirectResponse struct {
RedirectURL string `json:"redirect_url"`
}
type apiAccountSyncTopicResponse struct {
Event string `json:"event"`
}

View File

@ -38,7 +38,6 @@ const (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL,
name TEXT NOT NULL,
paid INT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
@ -104,20 +103,20 @@ const (
`
selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at , p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
JOIN user_token t on u.id = t.user_id
LEFT JOIN tier p on p.id = u.tier_id
WHERE t.token = ? AND t.expires >= ?
`
selectUserByStripeCustomerIDQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at , p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
WHERE u.stripe_customer_id = ?
@ -218,17 +217,17 @@ const (
`
insertTierQuery = `
INSERT INTO tier (code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
`
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
selectTierByCodeQuery = `
SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
WHERE code = ?
`
selectTierByPriceIDQuery = `
SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
WHERE stripe_price_id = ?
`
@ -606,13 +605,12 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var paid sql.NullBool
var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -643,7 +641,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
user.Tier = &Tier{
Code: tierCode.String,
Name: tierName.String,
Paid: paid.Bool,
Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
MessagesLimit: messagesLimit.Int64,
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailsLimit: emailsLimit.Int64,
@ -870,7 +868,7 @@ func (a *Manager) DefaultAccess() Permission {
// CreateTier creates a new tier in the database
func (a *Manager) CreateTier(tier *Tier) error {
if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.Paid, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds())); err != nil {
if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds())); err != nil {
return err
}
return nil
@ -903,12 +901,11 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
defer rows.Close()
var code, name string
var stripePriceID sql.NullString
var paid bool
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() {
return nil, ErrTierNotFound
}
if err := rows.Scan(&code, &name, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -917,7 +914,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
return &Tier{
Code: code,
Name: name,
Paid: paid,
Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
MessagesLimit: messagesLimit.Int64,
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailsLimit: emailsLimit.Int64,

View File

@ -179,8 +179,10 @@
"account_usage_unlimited": "Unlimited",
"account_usage_limits_reset_daily": "Usage limits are reset daily at midnight (UTC)",
"account_usage_tier_title": "Account type",
"account_usage_tier_description": "Your account's power level",
"account_usage_tier_admin": "Admin",
"account_usage_tier_none": "Basic",
"account_usage_tier_basic": "Basic",
"account_usage_tier_free": "Free",
"account_usage_tier_upgrade_button": "Upgrade to Pro",
"account_usage_tier_change_button": "Change",
"account_usage_tier_paid_until": "Subscription paid until {{date}}, and will auto-renew",
@ -199,6 +201,8 @@
"account_delete_dialog_label": "Type '{{username}}' to delete account",
"account_delete_dialog_button_cancel": "Cancel",
"account_delete_dialog_button_submit": "Permanently delete account",
"account_upgrade_dialog_title": "Change billing plan",
"account_upgrade_dialog_cancel_warning": "This will cancel your subscription, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server will be deleted.",
"prefs_notifications_title": "Notifications",
"prefs_notifications_sound_title": "Notification sound",
"prefs_notifications_sound_description_none": "Notifications do not play any sound when they arrive",

View File

@ -264,11 +264,20 @@ class AccountApi {
this.triggerChange(); // Dangle!
}
async createBillingSubscription(tier) {
console.log(`[AccountApi] Creating billing subscription with ${tier}`);
return await this.upsertBillingSubscription("POST", tier)
}
async updateBillingSubscription(tier) {
console.log(`[AccountApi] Updating billing subscription with ${tier}`);
return await this.upsertBillingSubscription("PUT", tier)
}
async upsertBillingSubscription(method, tier) {
const url = accountBillingSubscriptionUrl(config.base_url);
console.log(`[AccountApi] Requesting tier change to ${tier}`);
const response = await fetch(url, {
method: "POST",
method: method,
headers: withBearerAuth({}, session.token()),
body: JSON.stringify({
tier: tier
@ -284,7 +293,7 @@ class AccountApi {
async deleteBillingSubscription() {
const url = accountBillingSubscriptionUrl(config.base_url);
console.log(`[AccountApi] Cancelling paid subscription`);
console.log(`[AccountApi] Cancelling billing subscription`);
const response = await fetch(url, {
method: "DELETE",
headers: withBearerAuth({}, session.token())
@ -345,6 +354,7 @@ class AccountApi {
}
async triggerChange() {
return null;
const account = await this.get();
if (!account.sync_topic) {
return;

View File

@ -56,6 +56,7 @@ const Basics = () => {
<PrefGroup>
<Username/>
<ChangePassword/>
<AccountType/>
</PrefGroup>
</Card>
);
@ -168,18 +169,20 @@ const ChangePasswordDialog = (props) => {
);
};
const Stats = () => {
const AccountType = () => {
const { t } = useTranslation();
const { account } = useContext(AccountContext);
const [upgradeDialogKey, setUpgradeDialogKey] = useState(0);
const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
if (!account) {
return <></>;
}
const normalize = (value, max) => {
return Math.min(value / max * 100, 100);
};
const handleUpgradeClick = () => {
setUpgradeDialogKey(k => k + 1);
setUpgradeDialogOpen(true);
}
const handleManageBilling = async () => {
try {
@ -194,44 +197,43 @@ const Stats = () => {
}
};
let accountType;
if (account.role === "admin") {
const tierSuffix = (account.tier) ? `(with ${account.tier.name} tier)` : `(no tier)`;
accountType = `${t("account_usage_tier_admin")} ${tierSuffix}`;
} else if (!account.tier) {
accountType = (config.enable_payments) ? t("account_usage_tier_free") : t("account_usage_tier_basic");
} else {
accountType = account.tier.name;
}
return (
<Card sx={{p: 3}} aria-label={t("account_usage_title")}>
<Typography variant="h5" sx={{marginBottom: 2}}>
{t("account_usage_title")}
</Typography>
<PrefGroup>
<Pref
alignTop={account.billing?.status === "past_due" || account.billing?.cancel_at > 0}
title={t("account_usage_tier_title")}
description={t("account_usage_tier_description")}
>
<div>
{account.role === "admin" &&
<>
{t("account_usage_tier_admin")}
{" "}{account.tier ? `(with ${account.tier.name} tier)` : `(no tier)`}
</>
}
{account.role === "user" && account.tier && account.tier.name}
{account.role === "user" && !account.tier && t("account_usage_tier_none")}
{account.billing?.paid_until &&
{accountType}
{account.billing?.paid_until && !account.billing?.cancel_at &&
<Tooltip title={t("account_usage_tier_paid_until", { date: formatShortDate(account.billing?.paid_until) })}>
<span><InfoIcon/></span>
</Tooltip>
}
{config.enable_payments && account.role === "user" && (!account.tier || !account.tier.paid) &&
{config.enable_payments && account.role === "user" && !account.billing?.subscription &&
<Button
variant="outlined"
size="small"
startIcon={<CelebrationIcon sx={{ color: "#55b86e" }}/>}
onClick={() => setUpgradeDialogOpen(true)}
onClick={handleUpgradeClick}
sx={{ml: 1}}
>{t("account_usage_tier_upgrade_button")}</Button>
}
{config.enable_payments && account.role === "user" && account.tier?.paid &&
{config.enable_payments && account.role === "user" && account.billing?.subscription &&
<Button
variant="outlined"
size="small"
onClick={() => setUpgradeDialogOpen(true)}
onClick={handleUpgradeClick}
sx={{ml: 1}}
>{t("account_usage_tier_change_button")}</Button>
}
@ -244,6 +246,7 @@ const Stats = () => {
>{t("account_usage_manage_billing_button")}</Button>
}
<UpgradeDialog
key={`upgradeDialogFromAccount${upgradeDialogKey}`}
open={upgradeDialogOpen}
onCancel={() => setUpgradeDialogOpen(false)}
/>
@ -252,9 +255,31 @@ const Stats = () => {
<Alert severity="error" sx={{mt: 1}}>{t("account_usage_tier_payment_overdue")}</Alert>
}
{account.billing?.cancel_at > 0 &&
<Alert severity="info" sx={{mt: 1}}>{t("account_usage_tier_canceled_subscription", { date: formatShortDate(account.billing.cancel_at) })}</Alert>
<Alert severity="warning" sx={{mt: 1}}>{t("account_usage_tier_canceled_subscription", { date: formatShortDate(account.billing.cancel_at) })}</Alert>
}
</Pref>
)
};
const Stats = () => {
const { t } = useTranslation();
const { account } = useContext(AccountContext);
const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
if (!account) {
return <></>;
}
const normalize = (value, max) => {
return Math.min(value / max * 100, 100);
};
return (
<Card sx={{p: 3}} aria-label={t("account_usage_title")}>
<Typography variant="h5" sx={{marginBottom: 2}}>
{t("account_usage_title")}
</Typography>
<PrefGroup>
{account.role !== "admin" &&
<Pref title={t("account_usage_reservations_title")}>
{account.limits.reservations > 0 &&

View File

@ -103,8 +103,8 @@ const NavList = (props) => {
};
const isAdmin = account?.role === "admin";
const isPaid = account?.tier?.paid;
const showUpgradeBanner = config.enable_payments && !isAdmin && !isPaid;// && (!props.account || !props.account.tier || !props.account.tier.paid || props.account);
const isPaid = account?.billing?.subscription;
const showUpgradeBanner = config.enable_payments && !isAdmin && !isPaid;
const showSubscriptionsList = props.subscriptions?.length > 0;
const showNotificationBrowserNotSupportedBox = !notifier.browserSupported();
const showNotificationContextNotSupportedBox = notifier.browserSupported() && !notifier.contextSupported(); // Only show if notifications are generally supported in the browser
@ -174,7 +174,14 @@ const NavList = (props) => {
};
const UpgradeBanner = () => {
const [dialogKey, setDialogKey] = useState(0);
const [dialogOpen, setDialogOpen] = useState(false);
const handleClick = () => {
setDialogKey(k => k + 1);
setDialogOpen(true);
};
return (
<Box sx={{
position: "fixed",
@ -184,7 +191,7 @@ const UpgradeBanner = () => {
background: "linear-gradient(150deg, rgba(196, 228, 221, 0.46) 0%, rgb(255, 255, 255) 100%)",
}}>
<Divider/>
<ListItemButton onClick={() => setDialogOpen(true)} sx={{pt: 2, pb: 2}}>
<ListItemButton onClick={handleClick} sx={{pt: 2, pb: 2}}>
<ListItemIcon><CelebrationIcon sx={{ color: "#55b86e" }} fontSize="large"/></ListItemIcon>
<ListItemText
sx={{ ml: 1 }}
@ -207,6 +214,7 @@ const UpgradeBanner = () => {
/>
</ListItemButton>
<UpgradeDialog
key={`upgradeDialog${dialogKey}`}
open={dialogOpen}
onCancel={() => setDialogOpen(false)}
/>

View File

@ -2,7 +2,7 @@ import * as React from 'react';
import Dialog from '@mui/material/Dialog';
import DialogContent from '@mui/material/DialogContent';
import DialogTitle from '@mui/material/DialogTitle';
import {CardActionArea, CardContent, useMediaQuery} from "@mui/material";
import {Alert, CardActionArea, CardContent, useMediaQuery} from "@mui/material";
import theme from "./theme";
import DialogFooter from "./DialogFooter";
import Button from "@mui/material/Button";
@ -13,28 +13,53 @@ import {useContext, useState} from "react";
import Card from "@mui/material/Card";
import Typography from "@mui/material/Typography";
import {AccountContext} from "./App";
import {formatShortDate} from "../app/utils";
import {useTranslation} from "react-i18next";
const UpgradeDialog = (props) => {
const { t } = useTranslation();
const { account } = useContext(AccountContext);
const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
const [newTier, setNewTier] = useState(account?.tier?.code || null);
const [errorText, setErrorText] = useState("");
const handleCheckout = async () => {
try {
if (newTier == null) {
await accountApi.deleteBillingSubscription();
} else {
const response = await accountApi.updateBillingSubscription(newTier);
if (response.redirect_url) {
window.location.href = response.redirect_url;
} else {
await accountApi.sync();
}
if (!account) {
return <></>;
}
const currentTier = account.tier?.code || null;
let action, submitButtonLabel, submitButtonEnabled;
if (currentTier === newTier) {
submitButtonLabel = "Update subscription";
submitButtonEnabled = false;
action = null;
} else if (currentTier === null) {
submitButtonLabel = "Pay $5 now and subscribe";
submitButtonEnabled = true;
action = Action.CREATE;
} else if (newTier === null) {
submitButtonLabel = "Cancel subscription";
submitButtonEnabled = true;
action = Action.CANCEL;
} else {
submitButtonLabel = "Update subscription";
submitButtonEnabled = true;
action = Action.UPDATE;
}
const handleSubmit = async () => {
try {
if (action === Action.CREATE) {
const response = await accountApi.createBillingSubscription(newTier);
window.location.href = response.redirect_url;
} else if (action === Action.UPDATE) {
await accountApi.updateBillingSubscription(newTier);
} else if (action === Action.CANCEL) {
await accountApi.deleteBillingSubscription();
}
props.onCancel();
} catch (e) {
console.log(`[UpgradeDialog] Error creating checkout session`, e);
console.log(`[UpgradeDialog] Error changing billing subscription`, e);
if ((e instanceof UnauthorizedError)) {
session.resetAndRedirect(routes.login);
}
@ -44,7 +69,7 @@ const UpgradeDialog = (props) => {
return (
<Dialog open={props.open} onClose={props.onCancel} maxWidth="md" fullScreen={fullScreen}>
<DialogTitle>Upgrade to Pro</DialogTitle>
<DialogTitle>Change billing plan</DialogTitle>
<DialogContent>
<div style={{
display: "flex",
@ -55,9 +80,15 @@ const UpgradeDialog = (props) => {
<TierCard code="pro" name={"Pro"} selected={newTier === "pro"} onClick={() => setNewTier("pro")}/>
<TierCard code="business" name={"Business"} selected={newTier === "business"} onClick={() => setNewTier("business")}/>
</div>
{action === Action.CANCEL &&
<Alert severity="warning">
{t("account_upgrade_dialog_cancel_warning", { date: formatShortDate(account.billing.paid_until) })}
</Alert>
}
</DialogContent>
<DialogFooter status={errorText}>
<Button onClick={handleCheckout}>Checkout</Button>
<Button onClick={props.onCancel}>Cancel</Button>
<Button onClick={handleSubmit} disabled={!submitButtonEnabled}>{submitButtonLabel}</Button>
</DialogFooter>
</Dialog>
);
@ -65,8 +96,7 @@ const UpgradeDialog = (props) => {
const TierCard = (props) => {
const cardStyle = (props.selected) ? {
border: "1px solid red",
background: "#eee"
} : {};
return (
<Card sx={{ m: 1, maxWidth: 345 }}>
@ -85,4 +115,10 @@ const TierCard = (props) => {
);
}
const Action = {
CREATE: 1,
UPDATE: 2,
CANCEL: 3
};
export default UpgradeDialog;