From 45b97c705451930a44dae5c51f0d61bf84e2e898 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Thu, 19 Jan 2023 14:03:39 -0500 Subject: [PATCH] Deleting account deletes subscription --- server/server.go | 4 ++- server/server_account.go | 10 ++++++++ server/server_payments.go | 5 ++++ server/server_payments_test.go | 47 ++++++++++++++++++++++++++++++++++ server/visitor.go | 2 ++ 5 files changed, 67 insertions(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index 66a46d8e..36bb9583 100644 --- a/server/server.go +++ b/server/server.go @@ -36,8 +36,10 @@ import ( /* TODO + races: + - v.user --> see publishSyncEventAsync() test + payments: - - delete subscription when account deleted - delete messages + reserved topics on ResetTier Limits & rate limiting: diff --git a/server/server_account.go b/server/server_account.go index 8414c9aa..5bc36016 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -119,6 +119,16 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis } func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error { + if v.user.Billing.StripeCustomerID != "" { + log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID) + if v.user.Billing.StripeSubscriptionID != "" { + if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil { + return err + } + } + } else { + log.Info("Deleting user %s", v.user.Name) + } if err := s.userManager.RemoveUser(v.user.Name); err != nil { return err } diff --git a/server/server_payments.go b/server/server_payments.go index ee1bb1a2..45ef82e8 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -359,6 +359,7 @@ type stripeAPI interface { GetSession(id string) (*stripe.CheckoutSession, error) GetSubscription(id string) (*stripe.Subscription, error) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) + CancelSubscription(id string) (*stripe.Subscription, error) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) } @@ -407,6 +408,10 @@ func (s *realStripeAPI) UpdateSubscription(id string, params *stripe.Subscriptio return subscription.Update(id, params) } +func (s *realStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) { + return subscription.Cancel(id, nil) +} + func (s *realStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) { return webhook.ConstructEvent(payload, header, secret) } diff --git a/server/server_payments_test.go b/server/server_payments_test.go index 43375d62..2f2c60f0 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -83,6 +83,48 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) { require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL) } +func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) { + stripeMock := &testStripeAPI{} + defer stripeMock.AssertExpectations(t) + + c := newTestConfigWithAuthFile(t) + c.EnableSignup = true + c.StripeSecretKey = "secret key" + c.StripeWebhookKey = "webhook key" + s := newTestServer(t, c) + s.stripe = stripeMock + + // Define how the mock should react + stripeMock. + On("CancelSubscription", "sub_123"). + Return(&stripe.Subscription{}, nil) + + // Create tier and user + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + Code: "pro", + StripePriceID: "price_123", + })) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + + u, err := s.userManager.User("phil") + require.Nil(t, err) + + u.Billing.StripeCustomerID = "acct_123" + u.Billing.StripeSubscriptionID = "sub_123" + require.Nil(t, s.userManager.ChangeBilling(u)) + + // Delete account + rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "mypass"), + }) + require.Equal(t, 401, rr.Code) +} + type testStripeAPI struct { mock.Mock } @@ -122,6 +164,11 @@ func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.Subscriptio return args.Get(0).(*stripe.Subscription), args.Error(1) } +func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) { + args := s.Called(id) + return args.Get(0).(*stripe.Subscription), args.Error(1) +} + func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) { args := s.Called(payload, header, secret) return args.Get(0).(stripe.Event), args.Error(1) diff --git a/server/visitor.go b/server/visitor.go index 0075b6ba..5fd89ffa 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -213,6 +213,8 @@ func (v *visitor) ResetStats() { } func (v *visitor) Limits() *visitorLimits { + v.mu.Lock() + defer v.mu.Unlock() limits := defaultVisitorLimits(v.config) if v.user != nil && v.user.Tier != nil { limits.Basis = visitorLimitBasisTier