feat(server): support registering ai early access users (#6565)

This commit is contained in:
forehalo 2024-04-16 13:54:08 +00:00
parent 677c4711df
commit e1c292b8b5
No known key found for this signature in database
GPG Key ID: 56709255DC7EC728
14 changed files with 331 additions and 139 deletions

View File

@ -61,7 +61,18 @@ export class UnlimitedCopilotFeatureConfig extends FeatureConfig {
super(data);
if (this.config.feature !== FeatureType.UnlimitedCopilot) {
throw new Error('Invalid feature config: type is not UnlimitedWorkspace');
throw new Error('Invalid feature config: type is not AIEarlyAccess');
}
}
}
export class AIEarlyAccessFeatureConfig extends FeatureConfig {
override config!: Feature & { feature: FeatureType.AIEarlyAccess };
constructor(data: any) {
super(data);
if (this.config.feature !== FeatureType.AIEarlyAccess) {
throw new Error('Invalid feature config: type is not AIEarlyAccess');
}
}
}
@ -69,6 +80,7 @@ export class UnlimitedCopilotFeatureConfig extends FeatureConfig {
const FeatureConfigMap = {
[FeatureType.Copilot]: CopilotFeatureConfig,
[FeatureType.EarlyAccess]: EarlyAccessFeatureConfig,
[FeatureType.AIEarlyAccess]: AIEarlyAccessFeatureConfig,
[FeatureType.UnlimitedWorkspace]: UnlimitedWorkspaceFeatureConfig,
[FeatureType.UnlimitedCopilot]: UnlimitedCopilotFeatureConfig,
};

View File

@ -1,6 +1,6 @@
import { Module } from '@nestjs/common';
import { FeatureManagementService } from './management';
import { EarlyAccessType, FeatureManagementService } from './management';
import { FeatureService } from './service';
/**
@ -15,6 +15,11 @@ import { FeatureService } from './service';
})
export class FeatureModule {}
export { type CommonFeature, commonFeatureSchema } from './types';
export { FeatureKind, Features, FeatureType } from './types';
export { FeatureManagementService, FeatureService };
export {
type CommonFeature,
commonFeatureSchema,
FeatureKind,
Features,
FeatureType,
} from './types';
export { EarlyAccessType, FeatureManagementService, FeatureService };

View File

@ -7,6 +7,11 @@ import { FeatureType } from './types';
const STAFF = ['@toeverything.info'];
export enum EarlyAccessType {
App = 'app',
AI = 'ai',
}
@Injectable()
export class FeatureManagementService {
protected logger = new Logger(FeatureManagementService.name);
@ -30,24 +35,43 @@ export class FeatureManagementService {
}
// ======== Early Access ========
async addEarlyAccess(userId: string) {
async addEarlyAccess(
userId: string,
type: EarlyAccessType = EarlyAccessType.App
) {
return this.feature.addUserFeature(
userId,
FeatureType.EarlyAccess,
type === EarlyAccessType.App
? FeatureType.EarlyAccess
: FeatureType.AIEarlyAccess,
'Early access user'
);
}
async removeEarlyAccess(userId: string) {
return this.feature.removeUserFeature(userId, FeatureType.EarlyAccess);
async removeEarlyAccess(
userId: string,
type: EarlyAccessType = EarlyAccessType.App
) {
return this.feature.removeUserFeature(
userId,
type === EarlyAccessType.App
? FeatureType.EarlyAccess
: FeatureType.AIEarlyAccess
);
}
async listEarlyAccess() {
return this.feature.listFeatureUsers(FeatureType.EarlyAccess);
async listEarlyAccess(type: EarlyAccessType = EarlyAccessType.App) {
return this.feature.listFeatureUsers(
type === EarlyAccessType.App
? FeatureType.EarlyAccess
: FeatureType.AIEarlyAccess
);
}
async isEarlyAccessUser(email: string) {
async isEarlyAccessUser(
email: string,
type: EarlyAccessType = EarlyAccessType.App
) {
const user = await this.prisma.user.findFirst({
where: {
email: {
@ -56,9 +80,15 @@ export class FeatureManagementService {
},
},
});
if (user) {
const canEarlyAccess = await this.feature
.hasUserFeature(user.id, FeatureType.EarlyAccess)
.hasUserFeature(
user.id,
type === EarlyAccessType.App
? FeatureType.EarlyAccess
: FeatureType.AIEarlyAccess
)
.catch(() => false);
return canEarlyAccess;
@ -67,9 +97,12 @@ export class FeatureManagementService {
}
/// check early access by email
async canEarlyAccess(email: string) {
async canEarlyAccess(
email: string,
type: EarlyAccessType = EarlyAccessType.App
) {
if (this.config.featureFlags.earlyAccessPreview && !this.isStaff(email)) {
return this.isEarlyAccessUser(email);
return this.isEarlyAccessUser(email, type);
} else {
return true;
}

View File

@ -63,13 +63,6 @@ export class FeatureService {
expiredAt?: Date | string
) {
return this.prisma.$transaction(async tx => {
const latestVersion = await tx.features
.aggregate({
where: { feature },
_max: { version: true },
})
.then(r => r._max.version || 1);
const latestFlag = await tx.userFeatures.findFirst({
where: {
userId,
@ -83,9 +76,21 @@ export class FeatureService {
createdAt: 'desc',
},
});
if (latestFlag) {
return latestFlag.id;
} else {
const latestVersion = await tx.features
.aggregate({
where: { feature },
_max: { version: true },
})
.then(r => r._max.version);
if (!latestVersion) {
throw new Error(`Feature ${feature} not found`);
}
return tx.userFeatures
.create({
data: {

View File

@ -3,6 +3,7 @@ import { registerEnumType } from '@nestjs/graphql';
export enum FeatureType {
// user feature
EarlyAccess = 'early_access',
AIEarlyAccess = 'ai_early_access',
UnlimitedCopilot = 'unlimited_copilot',
// workspace feature
Copilot = 'copilot',

View File

@ -9,3 +9,8 @@ export const featureEarlyAccess = z.object({
whitelist: z.string().array(),
}),
});
export const featureAIEarlyAccess = z.object({
feature: z.literal(FeatureType.AIEarlyAccess),
configs: z.object({}),
});

View File

@ -2,7 +2,7 @@ import { z } from 'zod';
import { FeatureType } from './common';
import { featureCopilot } from './copilot';
import { featureEarlyAccess } from './early-access';
import { featureAIEarlyAccess, featureEarlyAccess } from './early-access';
import { featureUnlimitedCopilot } from './unlimited-copilot';
import { featureUnlimitedWorkspace } from './unlimited-workspace';
@ -59,6 +59,12 @@ export const Features: Feature[] = [
version: 1,
configs: {},
},
{
feature: FeatureType.AIEarlyAccess,
type: FeatureKind.Feature,
version: 1,
configs: {},
},
];
/// ======== schema infer ========
@ -71,6 +77,7 @@ export const FeatureSchema = commonFeatureSchema
z.discriminatedUnion('feature', [
featureCopilot,
featureEarlyAccess,
featureAIEarlyAccess,
featureUnlimitedWorkspace,
featureUnlimitedCopilot,
])

View File

@ -3,15 +3,27 @@ import {
ForbiddenException,
UseGuards,
} from '@nestjs/common';
import { Args, Context, Int, Mutation, Query, Resolver } from '@nestjs/graphql';
import {
Args,
Context,
Int,
Mutation,
Query,
registerEnumType,
Resolver,
} from '@nestjs/graphql';
import { CloudThrottlerGuard, Throttle } from '../../fundamentals';
import { CurrentUser } from '../auth/current-user';
import { sessionUser } from '../auth/service';
import { FeatureManagementService } from '../features';
import { EarlyAccessType, FeatureManagementService } from '../features';
import { UserService } from './service';
import { UserType } from './types';
registerEnumType(EarlyAccessType, {
name: 'EarlyAccessType',
});
/**
* User resolver
* All op rate limit: 10 req/m
@ -33,19 +45,20 @@ export class UserManagementResolver {
@Mutation(() => Int)
async addToEarlyAccess(
@CurrentUser() currentUser: CurrentUser,
@Args('email') email: string
@Args('email') email: string,
@Args({ name: 'type', type: () => EarlyAccessType }) type: EarlyAccessType
): Promise<number> {
if (!this.feature.isStaff(currentUser.email)) {
throw new ForbiddenException('You are not allowed to do this');
}
const user = await this.users.findUserByEmail(email);
if (user) {
return this.feature.addEarlyAccess(user.id);
return this.feature.addEarlyAccess(user.id, type);
} else {
const user = await this.users.createAnonymousUser(email, {
registered: false,
});
return this.feature.addEarlyAccess(user.id);
return this.feature.addEarlyAccess(user.id, type);
}
}

View File

@ -0,0 +1,14 @@
import { PrismaClient } from '@prisma/client';
import { FeatureType } from '../../core/features';
import { upsertLatestFeatureVersion } from './utils/user-features';
export class AiEarlyAccess1713176777814 {
// do the migration
static async up(db: PrismaClient) {
await upsertLatestFeatureVersion(db, FeatureType.AIEarlyAccess);
}
// revert the migration
static async down(_db: PrismaClient) {}
}

View File

@ -160,17 +160,16 @@ export class SubscriptionResolver {
@Public()
@Query(() => [SubscriptionPrice])
async prices(): Promise<SubscriptionPrice[]> {
const prices = await this.service.listPrices();
async prices(
@CurrentUser() user?: CurrentUser
): Promise<SubscriptionPrice[]> {
const prices = await this.service.listPrices(user);
const group = groupBy(
prices.data.filter(price => !!price.lookup_key),
price => {
// @ts-expect-error empty lookup key is filtered out
const [plan] = decodeLookupKey(price.lookup_key);
return plan;
}
);
const group = groupBy(prices, price => {
// @ts-expect-error empty lookup key is filtered out
const [plan] = decodeLookupKey(price.lookup_key);
return plan;
});
function findPrice(plan: SubscriptionPlan) {
const prices = group[plan];

View File

@ -188,7 +188,7 @@ export class ScheduleManager {
});
}
async update(idempotencyKey: string, price: string, coupon?: string) {
async update(idempotencyKey: string, price: string) {
if (!this._schedule) {
throw new Error('No schedule');
}
@ -198,10 +198,7 @@ export class ScheduleManager {
}
// if current phase's plan matches target, and no coupon change, just release the schedule
if (
this.currentPhase.items[0].price === price &&
(!coupon || this.currentPhase.coupon === coupon)
) {
if (this.currentPhase.items[0].price === price) {
await this.stripe.subscriptionSchedules.release(this._schedule.id, {
idempotencyKey,
});
@ -227,7 +224,10 @@ export class ScheduleManager {
quantity: 1,
},
],
coupon,
coupon:
typeof this.currentPhase.coupon === 'string'
? this.currentPhase.coupon
: this.currentPhase.coupon?.id ?? undefined,
},
],
},

View File

@ -1,3 +1,5 @@
import { randomUUID } from 'node:crypto';
import { BadRequestException, Injectable, Logger } from '@nestjs/common';
import { OnEvent as RawOnEvent } from '@nestjs/event-emitter';
import type {
@ -11,12 +13,13 @@ import { PrismaClient } from '@prisma/client';
import Stripe from 'stripe';
import { CurrentUser } from '../../core/auth';
import { FeatureManagementService } from '../../core/features';
import { EarlyAccessType, FeatureManagementService } from '../../core/features';
import { EventEmitter } from '../../fundamentals';
import { ScheduleManager } from './schedule';
import {
InvoiceStatus,
SubscriptionPlan,
SubscriptionPriceVariant,
SubscriptionRecurring,
SubscriptionStatus,
} from './types';
@ -29,17 +32,22 @@ const OnEvent = (
// Plan x Recurring make a stripe price lookup key
export function encodeLookupKey(
plan: SubscriptionPlan,
recurring: SubscriptionRecurring
recurring: SubscriptionRecurring,
variant?: SubscriptionPriceVariant
): string {
return plan + '_' + recurring;
return `${plan}_${recurring}` + (variant ? `_${variant}` : '');
}
export function decodeLookupKey(
key: string
): [SubscriptionPlan, SubscriptionRecurring] {
const [plan, recurring] = key.split('_');
): [SubscriptionPlan, SubscriptionRecurring, SubscriptionPriceVariant?] {
const [plan, recurring, variant] = key.split('_');
return [plan as SubscriptionPlan, recurring as SubscriptionRecurring];
return [
plan as SubscriptionPlan,
recurring as SubscriptionRecurring,
variant as SubscriptionPriceVariant | undefined,
];
}
const SubscriptionActivated: Stripe.Subscription.Status[] = [
@ -48,8 +56,9 @@ const SubscriptionActivated: Stripe.Subscription.Status[] = [
];
export enum CouponType {
EarlyAccess = 'earlyaccess',
EarlyAccessRenew = 'earlyaccessrenew',
ProEarlyAccessOneYearFree = 'pro_ea_one_year_free',
AIEarlyAccessOneYearFree = 'ai_ea_one_year_free',
ProEarlyAccessAIOneYearFree = 'ai_pro_ea_one_year_free',
}
@Injectable()
@ -64,10 +73,70 @@ export class SubscriptionService {
private readonly features: FeatureManagementService
) {}
async listPrices() {
return this.stripe.prices.list({
async listPrices(user?: CurrentUser) {
let canHaveEarlyAccessDiscount = false;
let canHaveAIEarlyAccessDiscount = false;
if (user) {
canHaveEarlyAccessDiscount = await this.features.isEarlyAccessUser(
user.email
);
canHaveAIEarlyAccessDiscount = await this.features.isEarlyAccessUser(
user.email,
EarlyAccessType.AI
);
const customer = await this.getOrCreateCustomer(
'list-price:' + randomUUID(),
user
);
const oldSubscriptions = await this.stripe.subscriptions.list({
customer: customer.stripeCustomerId,
status: 'all',
});
oldSubscriptions.data.forEach(sub => {
if (sub.items.data[0].price.lookup_key) {
const [oldPlan] = decodeLookupKey(sub.items.data[0].price.lookup_key);
if (oldPlan === SubscriptionPlan.Pro) {
canHaveEarlyAccessDiscount = false;
}
if (oldPlan === SubscriptionPlan.AI) {
canHaveAIEarlyAccessDiscount = false;
}
}
});
}
const list = await this.stripe.prices.list({
active: true,
});
return list.data.filter(price => {
if (!price.lookup_key) {
return false;
}
const [plan, recurring, variant] = decodeLookupKey(price.lookup_key);
if (recurring === SubscriptionRecurring.Monthly) {
return !variant;
}
if (plan === SubscriptionPlan.Pro) {
return (
(canHaveEarlyAccessDiscount && variant) ||
(!canHaveEarlyAccessDiscount && !variant)
);
}
if (plan === SubscriptionPlan.AI) {
return (
(canHaveAIEarlyAccessDiscount && variant) ||
(!canHaveAIEarlyAccessDiscount && !variant)
);
}
return false;
});
}
async createCheckoutSession({
@ -99,13 +168,18 @@ export class SubscriptionService {
);
}
const price = await this.getPrice(plan, recurring);
const customer = await this.getOrCreateCustomer(
`${idempotencyKey}-getOrCreateCustomer`,
user
);
let discount: { coupon?: string; promotion_code?: string } | undefined;
const { price, coupon } = await this.getAvailablePrice(
customer,
plan,
recurring
);
let discounts: Stripe.Checkout.SessionCreateParams['discounts'] = [];
if (promotionCode) {
const code = await this.getAvailablePromotionCode(
@ -113,18 +187,10 @@ export class SubscriptionService {
customer.stripeCustomerId
);
if (code) {
discount ??= {};
discount.promotion_code = code;
}
} else if (plan === SubscriptionPlan.Pro) {
const coupon = await this.getAvailableCoupon(
user,
CouponType.EarlyAccess
);
if (coupon) {
discount ??= {};
discount.coupon = coupon;
discounts = [{ promotion_code: code }];
}
} else if (coupon) {
discounts = [{ coupon }];
}
return await this.stripe.checkout.sessions.create(
@ -138,11 +204,7 @@ export class SubscriptionService {
tax_id_collection: {
enabled: true,
},
...(discount
? {
discounts: [discount],
}
: { allow_promotion_codes: true }),
discounts,
mode: 'subscription',
success_url: redirectUrl,
customer: customer.stripeCustomerId,
@ -314,16 +376,7 @@ export class SubscriptionService {
subscriptionInDB.stripeSubscriptionId
);
await manager.update(
`${idempotencyKey}-update`,
price,
// if user is early access user, use early access coupon
manager.currentPhase?.coupon === CouponType.EarlyAccess ||
manager.currentPhase?.coupon === CouponType.EarlyAccessRenew ||
manager.nextPhase?.coupon === CouponType.EarlyAccessRenew
? CouponType.EarlyAccessRenew
: undefined
);
await manager.update(`${idempotencyKey}-update`, price);
return await this.db.userSubscription.update({
where: {
@ -392,20 +445,6 @@ export class SubscriptionService {
if (!line.price || line.price.type !== 'recurring') {
throw new Error('Unknown invoice with no recurring price');
}
// deal with early access user
if (stripeInvoice.discount?.coupon.id === CouponType.EarlyAccess) {
const idempotencyKey = stripeInvoice.id + '_earlyaccess';
const manager = await this.scheduleManager.fromSubscription(
`${idempotencyKey}-fromSubscription`,
line.subscription as string
);
await manager.update(
`${idempotencyKey}-update`,
line.price.id,
CouponType.EarlyAccessRenew
);
}
}
@OnEvent('invoice.created')
@ -591,38 +630,41 @@ export class SubscriptionService {
private async getOrCreateCustomer(
idempotencyKey: string,
user: CurrentUser
): Promise<UserStripeCustomer> {
const customer = await this.db.userStripeCustomer.findUnique({
): Promise<UserStripeCustomer & { email: string }> {
let customer = await this.db.userStripeCustomer.findUnique({
where: {
userId: user.id,
},
});
if (customer) {
return customer;
if (!customer) {
const stripeCustomersList = await this.stripe.customers.list({
email: user.email,
limit: 1,
});
let stripeCustomer: Stripe.Customer | undefined;
if (stripeCustomersList.data.length) {
stripeCustomer = stripeCustomersList.data[0];
} else {
stripeCustomer = await this.stripe.customers.create(
{ email: user.email },
{ idempotencyKey }
);
}
customer = await this.db.userStripeCustomer.create({
data: {
userId: user.id,
stripeCustomerId: stripeCustomer.id,
},
});
}
const stripeCustomersList = await this.stripe.customers.list({
return {
...customer,
email: user.email,
limit: 1,
});
let stripeCustomer: Stripe.Customer | undefined;
if (stripeCustomersList.data.length) {
stripeCustomer = stripeCustomersList.data[0];
} else {
stripeCustomer = await this.stripe.customers.create(
{ email: user.email },
{ idempotencyKey }
);
}
return await this.db.userStripeCustomer.create({
data: {
userId: user.id,
stripeCustomerId: stripeCustomer.id,
},
});
};
}
private async retrieveUserFromCustomer(customerId: string) {
@ -674,10 +716,11 @@ export class SubscriptionService {
private async getPrice(
plan: SubscriptionPlan,
recurring: SubscriptionRecurring
recurring: SubscriptionRecurring,
variant?: SubscriptionPriceVariant
): Promise<string> {
const prices = await this.stripe.prices.list({
lookup_keys: [encodeLookupKey(plan, recurring)],
lookup_keys: [encodeLookupKey(plan, recurring, variant)],
});
if (!prices.data.length) {
@ -689,22 +732,67 @@ export class SubscriptionService {
return prices.data[0].id;
}
private async getAvailableCoupon(
user: CurrentUser,
couponType: CouponType
): Promise<string | null> {
const earlyAccess = await this.features.isEarlyAccessUser(user.email);
if (earlyAccess) {
try {
const coupon = await this.stripe.coupons.retrieve(couponType);
return coupon.valid ? coupon.id : null;
} catch (e) {
this.logger.error('Failed to get early access coupon', e);
return null;
}
}
/**
* Get available for different plans with special early-access price and coupon
*/
private async getAvailablePrice(
customer: UserStripeCustomer & { email: string },
plan: SubscriptionPlan,
recurring: SubscriptionRecurring
): Promise<{ price: string; coupon?: string }> {
const isEaUser = await this.features.isEarlyAccessUser(customer.email);
const oldSubscriptions = await this.stripe.subscriptions.list({
customer: customer.stripeCustomerId,
status: 'all',
});
return null;
const subscribed = oldSubscriptions.data.some(sub => {
if (sub.items.data[0].price.lookup_key) {
const [oldPlan] = decodeLookupKey(sub.items.data[0].price.lookup_key);
return oldPlan === plan;
}
return false;
});
if (plan === SubscriptionPlan.Pro) {
const canHaveEADiscount = isEaUser && !subscribed;
const price = await this.getPrice(
plan,
recurring,
canHaveEADiscount && recurring === SubscriptionRecurring.Yearly
? SubscriptionPriceVariant.EA
: undefined
);
return {
price,
coupon: !subscribed ? CouponType.ProEarlyAccessOneYearFree : undefined,
};
} else {
const isAIEaUser = await this.features.isEarlyAccessUser(
customer.email,
EarlyAccessType.AI
);
const canHaveEADiscount = isAIEaUser && !subscribed;
const price = await this.getPrice(
plan,
recurring,
canHaveEADiscount && recurring === SubscriptionRecurring.Yearly
? SubscriptionPriceVariant.EA
: undefined
);
return {
price,
coupon: !subscribed
? isAIEaUser
? CouponType.AIEarlyAccessOneYearFree
: isEaUser
? CouponType.ProEarlyAccessAIOneYearFree
: undefined
: undefined,
};
}
}
private async getAvailablePromotionCode(

View File

@ -26,6 +26,10 @@ export enum SubscriptionPlan {
SelfHosted = 'selfhosted',
}
export enum SubscriptionPriceVariant {
EA = 'earlyaccess',
}
// see https://stripe.com/docs/api/subscriptions/object#subscription_object-status
export enum SubscriptionStatus {
Active = 'active',

View File

@ -80,8 +80,14 @@ type DocHistoryType {
workspaceId: String!
}
enum EarlyAccessType {
AI
App
}
"""The type of workspace feature"""
enum FeatureType {
AIEarlyAccess
Copilot
EarlyAccess
UnlimitedCopilot
@ -170,7 +176,7 @@ type LimitedUserType {
type Mutation {
acceptInviteById(inviteId: String!, sendAcceptMail: Boolean, workspaceId: String!): Boolean!
addToEarlyAccess(email: String!): Int!
addToEarlyAccess(email: String!, type: EarlyAccessType!): Int!
addWorkspaceFeature(feature: FeatureType!, workspaceId: String!): Int!
cancelSubscription(idempotencyKey: String!, plan: SubscriptionPlan = Pro): UserSubscription!
changeEmail(email: String!, token: String!): UserType!