diff --git a/packages/backend/server/src/app.module.ts b/packages/backend/server/src/app.module.ts index f5f44aaeb0..3e03605c18 100644 --- a/packages/backend/server/src/app.module.ts +++ b/packages/backend/server/src/app.module.ts @@ -28,6 +28,7 @@ import { GqlModule } from './fundamentals/graphql'; import { HelpersModule } from './fundamentals/helpers'; import { MailModule } from './fundamentals/mailer'; import { MetricsModule } from './fundamentals/metrics'; +import { MutexModule } from './fundamentals/mutex'; import { PrismaModule } from './fundamentals/prisma'; import { StorageProviderModule } from './fundamentals/storage'; import { RateLimiterModule } from './fundamentals/throttler'; @@ -39,6 +40,7 @@ export const FunctionalityModules = [ ScheduleModule.forRoot(), EventModule, CacheModule, + MutexModule, PrismaModule, MetricsModule, RateLimiterModule, diff --git a/packages/backend/server/src/core/features/feature.ts b/packages/backend/server/src/core/features/feature.ts index 48d89d53b8..61a99aa1af 100644 --- a/packages/backend/server/src/core/features/feature.ts +++ b/packages/backend/server/src/core/features/feature.ts @@ -1,5 +1,4 @@ -import { PrismaClient } from '@prisma/client'; - +import { PrismaTransaction } from '../../fundamentals'; import { Feature, FeatureSchema, FeatureType } from './types'; class FeatureConfig { @@ -67,7 +66,7 @@ export type FeatureConfigType = InstanceType< const FeatureCache = new Map>(); -export async function getFeature(prisma: PrismaClient, featureId: number) { +export async function getFeature(prisma: PrismaTransaction, featureId: number) { const cachedQuota = FeatureCache.get(featureId); if (cachedQuota) { diff --git a/packages/backend/server/src/core/quota/quota.ts b/packages/backend/server/src/core/quota/quota.ts index d6d7657c01..3f481de06d 100644 --- a/packages/backend/server/src/core/quota/quota.ts +++ b/packages/backend/server/src/core/quota/quota.ts @@ -1,5 +1,4 @@ -import { PrismaClient } from '@prisma/client'; - +import { PrismaTransaction } from '../../fundamentals'; import { formatDate, formatSize, Quota, QuotaSchema } from './types'; const QuotaCache = new Map(); @@ -7,14 +6,14 @@ const QuotaCache = new Map(); export class QuotaConfig { readonly config: Quota; - static async get(prisma: PrismaClient, featureId: number) { + static async get(tx: PrismaTransaction, featureId: number) { const cachedQuota = QuotaCache.get(featureId); if (cachedQuota) { return cachedQuota; } - const quota = await prisma.features.findFirst({ + const quota = await tx.features.findFirst({ where: { id: featureId, }, diff --git a/packages/backend/server/src/core/quota/service.ts b/packages/backend/server/src/core/quota/service.ts index d8b2b65fe7..fcff2d7da4 100644 --- a/packages/backend/server/src/core/quota/service.ts +++ b/packages/backend/server/src/core/quota/service.ts @@ -1,13 +1,15 @@ import { Injectable } from '@nestjs/common'; import { PrismaClient } from '@prisma/client'; -import { type EventPayload, OnEvent } from '../../fundamentals'; +import { + type EventPayload, + OnEvent, + PrismaTransaction, +} from '../../fundamentals'; import { FeatureKind } from '../features'; import { QuotaConfig } from './quota'; import { QuotaType } from './types'; -type Transaction = Parameters[0]>[0]; - @Injectable() export class QuotaService { constructor(private readonly prisma: PrismaClient) {} @@ -140,8 +142,8 @@ export class QuotaService { }); } - async hasQuota(userId: string, quota: QuotaType, transaction?: Transaction) { - const executor = transaction ?? this.prisma; + async hasQuota(userId: string, quota: QuotaType, tx?: PrismaTransaction) { + const executor = tx ?? this.prisma; return executor.userFeatures .count({ diff --git a/packages/backend/server/src/core/user/service.ts b/packages/backend/server/src/core/user/service.ts index 4c60f00975..a85aaac37f 100644 --- a/packages/backend/server/src/core/user/service.ts +++ b/packages/backend/server/src/core/user/service.ts @@ -54,7 +54,7 @@ export class UserService { return this.createUser({ email, - name: 'Unnamed', + name: email.split('@')[0], ...data, }); } diff --git a/packages/backend/server/src/core/workspaces/resolvers/workspace.ts b/packages/backend/server/src/core/workspaces/resolvers/workspace.ts index d322a69c1b..fd3ca1ecd8 100644 --- a/packages/backend/server/src/core/workspaces/resolvers/workspace.ts +++ b/packages/backend/server/src/core/workspaces/resolvers/workspace.ts @@ -25,7 +25,9 @@ import { EventEmitter, type FileUpload, MailService, + MutexService, Throttle, + TooManyRequestsException, } from '../../../fundamentals'; import { CurrentUser, Public } from '../../auth'; import { QuotaManagementService, QuotaQueryType } from '../../quota'; @@ -58,7 +60,8 @@ export class WorkspaceResolver { private readonly quota: QuotaManagementService, private readonly users: UserService, private readonly event: EventEmitter, - private readonly blobStorage: WorkspaceBlobStorage + private readonly blobStorage: WorkspaceBlobStorage, + private readonly mutex: MutexService ) {} @ResolveField(() => Permission, { @@ -336,74 +339,87 @@ export class WorkspaceResolver { throw new ForbiddenException('Cannot change owner'); } - // member limit check - const [memberCount, quota] = await Promise.all([ - this.prisma.workspaceUserPermission.count({ - where: { workspaceId }, - }), - this.quota.getWorkspaceUsage(workspaceId), - ]); - if (memberCount >= quota.memberLimit) { - throw new PayloadTooLargeException('Workspace member limit reached.'); - } + try { + // lock to prevent concurrent invite + const lockFlag = `invite:${workspaceId}`; + await using lock = await this.mutex.lock(lockFlag); + if (!lock) { + return new TooManyRequestsException('Server is busy'); + } - let target = await this.users.findUserByEmail(email); - if (target) { - const originRecord = await this.prisma.workspaceUserPermission.findFirst({ - where: { - workspaceId, - userId: target.id, - }, - }); - // only invite if the user is not already in the workspace - if (originRecord) return originRecord.id; - } else { - target = await this.users.createAnonymousUser(email, { - registered: false, - }); - } + // member limit check + const [memberCount, quota] = await Promise.all([ + this.prisma.workspaceUserPermission.count({ + where: { workspaceId }, + }), + this.quota.getWorkspaceUsage(workspaceId), + ]); + if (memberCount >= quota.memberLimit) { + return new PayloadTooLargeException('Workspace member limit reached.'); + } - const inviteId = await this.permissions.grant( - workspaceId, - target.id, - permission - ); - if (sendInviteMail) { - const inviteInfo = await this.getInviteInfo(inviteId); - - try { - await this.mailer.sendInviteEmail(email, inviteId, { - workspace: { - id: inviteInfo.workspace.id, - name: inviteInfo.workspace.name, - avatar: inviteInfo.workspace.avatar, - }, - user: { - avatar: inviteInfo.user?.avatarUrl || '', - name: inviteInfo.user?.name || '', - }, + let target = await this.users.findUserByEmail(email); + if (target) { + const originRecord = + await this.prisma.workspaceUserPermission.findFirst({ + where: { + workspaceId, + userId: target.id, + }, + }); + // only invite if the user is not already in the workspace + if (originRecord) return originRecord.id; + } else { + target = await this.users.createAnonymousUser(email, { + registered: false, }); - } catch (e) { - const ret = await this.permissions.revokeWorkspace( - workspaceId, - target.id - ); + } - if (!ret) { - this.logger.fatal( - `failed to send ${workspaceId} invite email to ${email} and failed to revoke permission: ${inviteId}, ${e}` + const inviteId = await this.permissions.grant( + workspaceId, + target.id, + permission + ); + if (sendInviteMail) { + const inviteInfo = await this.getInviteInfo(inviteId); + + try { + await this.mailer.sendInviteEmail(email, inviteId, { + workspace: { + id: inviteInfo.workspace.id, + name: inviteInfo.workspace.name, + avatar: inviteInfo.workspace.avatar, + }, + user: { + avatar: inviteInfo.user?.avatarUrl || '', + name: inviteInfo.user?.name || '', + }, + }); + } catch (e) { + const ret = await this.permissions.revokeWorkspace( + workspaceId, + target.id ); - } else { - this.logger.warn( - `failed to send ${workspaceId} invite email to ${email}, but successfully revoked permission: ${e}` + + if (!ret) { + this.logger.fatal( + `failed to send ${workspaceId} invite email to ${email} and failed to revoke permission: ${inviteId}, ${e}` + ); + } else { + this.logger.warn( + `failed to send ${workspaceId} invite email to ${email}, but successfully revoked permission: ${e}` + ); + } + return new InternalServerErrorException( + 'Failed to send invite email. Please try again.' ); } - return new InternalServerErrorException( - 'Failed to send invite email. Please try again.' - ); } + return inviteId; + } catch (e) { + this.logger.error('failed to invite user', e); + return new TooManyRequestsException('Server is busy'); } - return inviteId; } @Throttle({ diff --git a/packages/backend/server/src/fundamentals/error/index.ts b/packages/backend/server/src/fundamentals/error/index.ts index 0681702e4c..71ecd22a46 100644 --- a/packages/backend/server/src/fundamentals/error/index.ts +++ b/packages/backend/server/src/fundamentals/error/index.ts @@ -1 +1,2 @@ export * from './payment-required'; +export * from './too-many-requests'; diff --git a/packages/backend/server/src/fundamentals/error/too-many-requests.ts b/packages/backend/server/src/fundamentals/error/too-many-requests.ts new file mode 100644 index 0000000000..3a4a96130d --- /dev/null +++ b/packages/backend/server/src/fundamentals/error/too-many-requests.ts @@ -0,0 +1,14 @@ +import { HttpException, HttpStatus } from '@nestjs/common'; + +export class TooManyRequestsException extends HttpException { + constructor(desc?: string, code: string = 'Too Many Requests') { + super( + HttpException.createBody( + desc ?? code, + code, + HttpStatus.TOO_MANY_REQUESTS + ), + HttpStatus.TOO_MANY_REQUESTS + ); + } +} diff --git a/packages/backend/server/src/fundamentals/graphql/index.ts b/packages/backend/server/src/fundamentals/graphql/index.ts index 3f53ea05dd..04b5c6a1d2 100644 --- a/packages/backend/server/src/fundamentals/graphql/index.ts +++ b/packages/backend/server/src/fundamentals/graphql/index.ts @@ -11,6 +11,12 @@ import { GraphQLError } from 'graphql'; import { Config } from '../config'; import { GQLLoggerPlugin } from './logger-plugin'; +export type GraphqlContext = { + req: Request; + res: Response; + isAdminQuery: boolean; +}; + @Global() @Module({ imports: [ @@ -30,7 +36,13 @@ import { GQLLoggerPlugin } from './logger-plugin'; : '../../../schema.gql' ), sortSchema: true, - context: ({ req, res }: { req: Request; res: Response }) => ({ + context: ({ + req, + res, + }: { + req: Request; + res: Response; + }): GraphqlContext => ({ req, res, isAdminQuery: false, diff --git a/packages/backend/server/src/fundamentals/index.ts b/packages/backend/server/src/fundamentals/index.ts index 02c5515e16..9c125c7cca 100644 --- a/packages/backend/server/src/fundamentals/index.ts +++ b/packages/backend/server/src/fundamentals/index.ts @@ -14,14 +14,23 @@ export { } from './config'; export * from './error'; export { EventEmitter, type EventPayload, OnEvent } from './event'; +export type { GraphqlContext } from './graphql'; export { CryptoHelper, URLHelper } from './helpers'; export { MailService } from './mailer'; export { CallCounter, CallTimer, metrics } from './metrics'; +export { + BucketService, + LockGuard, + MUTEX_RETRY, + MUTEX_WAIT, + MutexService, +} from './mutex'; export { getOptionalModuleMetadata, GlobalExceptionFilter, OptionalModule, } from './nestjs'; +export type { PrismaTransaction } from './prisma'; export * from './storage'; export { type StorageProvider, StorageProviderFactory } from './storage'; export { AuthThrottlerGuard, CloudThrottlerGuard, Throttle } from './throttler'; diff --git a/packages/backend/server/src/fundamentals/mutex/bucket.ts b/packages/backend/server/src/fundamentals/mutex/bucket.ts new file mode 100644 index 0000000000..446676cb39 --- /dev/null +++ b/packages/backend/server/src/fundamentals/mutex/bucket.ts @@ -0,0 +1,15 @@ +export class BucketService { + private readonly bucket = new Map(); + + get(key: string) { + return this.bucket.get(key); + } + + set(key: string, value: string) { + this.bucket.set(key, value); + } + + delete(key: string) { + this.bucket.delete(key); + } +} diff --git a/packages/backend/server/src/fundamentals/mutex/index.ts b/packages/backend/server/src/fundamentals/mutex/index.ts new file mode 100644 index 0000000000..5c50fba547 --- /dev/null +++ b/packages/backend/server/src/fundamentals/mutex/index.ts @@ -0,0 +1,14 @@ +import { Global, Module } from '@nestjs/common'; + +import { BucketService } from './bucket'; +import { MutexService } from './mutex'; + +@Global() +@Module({ + providers: [BucketService, MutexService], + exports: [BucketService, MutexService], +}) +export class MutexModule {} + +export { BucketService, MutexService }; +export { LockGuard, MUTEX_RETRY, MUTEX_WAIT } from './mutex'; diff --git a/packages/backend/server/src/fundamentals/mutex/mutex.ts b/packages/backend/server/src/fundamentals/mutex/mutex.ts new file mode 100644 index 0000000000..ffdd8eb889 --- /dev/null +++ b/packages/backend/server/src/fundamentals/mutex/mutex.ts @@ -0,0 +1,96 @@ +import { randomUUID } from 'node:crypto'; +import { setTimeout } from 'node:timers/promises'; + +import { Inject, Injectable, Logger, Scope } from '@nestjs/common'; +import { CONTEXT } from '@nestjs/graphql'; + +import type { GraphqlContext } from '../graphql'; +import { BucketService } from './bucket'; + +export class LockGuard + implements AsyncDisposable +{ + constructor( + private readonly mutex: M, + private readonly key: string + ) {} + + async [Symbol.asyncDispose]() { + return this.mutex.unlock(this.key); + } +} + +export const MUTEX_RETRY = 5; +export const MUTEX_WAIT = 100; + +@Injectable({ scope: Scope.REQUEST }) +export class MutexService { + protected logger = new Logger(MutexService.name); + + constructor( + @Inject(CONTEXT) private readonly context: GraphqlContext, + private readonly bucket: BucketService + ) {} + + protected getId() { + let id = this.context.req.headers['x-transaction-id'] as string; + + if (!id) { + id = randomUUID(); + this.context.req.headers['x-transaction-id'] = id; + } + + return id; + } + + /** + * lock an resource and return a lock guard, which will release the lock when disposed + * + * if the lock is not available, it will retry for [MUTEX_RETRY] times + * + * usage: + * ```typescript + * { + * // lock is acquired here + * await using lock = await mutex.lock('resource-key'); + * if (lock) { + * // do something + * } else { + * // failed to lock + * } + * } + * // lock is released here + * ``` + * @param key resource key + * @returns LockGuard + */ + async lock(key: string): Promise { + const id = this.getId(); + const fetchLock = async (retry: number): Promise => { + if (retry === 0) { + this.logger.error( + `Failed to fetch lock ${key} after ${MUTEX_RETRY} retry` + ); + return undefined; + } + const current = this.bucket.get(key); + if (current && current !== id) { + this.logger.warn( + `Failed to fetch lock ${key}, retrying in ${MUTEX_WAIT} ms` + ); + await setTimeout(MUTEX_WAIT * (MUTEX_RETRY - retry + 1)); + return fetchLock(retry - 1); + } + this.bucket.set(key, id); + return new LockGuard(this, key); + }; + + return fetchLock(MUTEX_RETRY); + } + + async unlock(key: string): Promise { + if (this.bucket.get(key) === this.getId()) { + this.bucket.delete(key); + } + } +} diff --git a/packages/backend/server/src/fundamentals/prisma/index.ts b/packages/backend/server/src/fundamentals/prisma/index.ts index 535e238122..517997fec5 100644 --- a/packages/backend/server/src/fundamentals/prisma/index.ts +++ b/packages/backend/server/src/fundamentals/prisma/index.ts @@ -16,3 +16,7 @@ const clientProvider: Provider = { }) export class PrismaModule {} export { PrismaService } from './service'; + +export type PrismaTransaction = Parameters< + Parameters[0] +>[0]; diff --git a/packages/backend/server/src/plugins/redis/index.ts b/packages/backend/server/src/plugins/redis/index.ts index 58d82c4642..1ca721147a 100644 --- a/packages/backend/server/src/plugins/redis/index.ts +++ b/packages/backend/server/src/plugins/redis/index.ts @@ -1,18 +1,27 @@ import { Global, Provider, Type } from '@nestjs/common'; +import { CONTEXT } from '@nestjs/graphql'; import { Redis, type RedisOptions } from 'ioredis'; import { ThrottlerStorageRedisService } from 'nestjs-throttler-storage-redis'; -import { Cache, SessionCache } from '../../fundamentals'; +import { + BucketService, + Cache, + type GraphqlContext, + MutexService, + SessionCache, +} from '../../fundamentals'; import { ThrottlerStorage } from '../../fundamentals/throttler'; import { SocketIoAdapterImpl } from '../../fundamentals/websocket'; import { Plugin } from '../registry'; import { RedisCache } from './cache'; import { CacheRedis, + MutexRedis, SessionRedis, SocketIoRedis, ThrottlerRedis, } from './instances'; +import { MutexRedisService } from './mutex'; import { createSockerIoAdapterImpl } from './ws-adapter'; function makeProvider(token: Type, impl: Type): Provider { @@ -47,15 +56,31 @@ const socketIoRedisAdapterProvider: Provider = { inject: [SocketIoRedis], }; +// mutex +const mutexRedisAdapterProvider: Provider = { + provide: MutexService, + useFactory: (redis: Redis, ctx: GraphqlContext, bucket: BucketService) => { + return new MutexRedisService(redis, ctx, bucket); + }, + inject: [MutexRedis, CONTEXT, BucketService], +}; + @Global() @Plugin({ name: 'redis', - providers: [CacheRedis, SessionRedis, ThrottlerRedis, SocketIoRedis], + providers: [ + CacheRedis, + SessionRedis, + ThrottlerRedis, + SocketIoRedis, + MutexRedis, + ], overrides: [ cacheProvider, sessionCacheProvider, socketIoRedisAdapterProvider, throttlerStorageProvider, + mutexRedisAdapterProvider, ], requires: ['plugins.redis.host'], }) diff --git a/packages/backend/server/src/plugins/redis/instances.ts b/packages/backend/server/src/plugins/redis/instances.ts index 1e85dec622..8fbd13b0c6 100644 --- a/packages/backend/server/src/plugins/redis/instances.ts +++ b/packages/backend/server/src/plugins/redis/instances.ts @@ -54,3 +54,10 @@ export class SocketIoRedis extends Redis { super({ ...config.plugins.redis, db: (config.plugins.redis?.db ?? 0) + 3 }); } } + +@Injectable() +export class MutexRedis extends Redis { + constructor(config: Config) { + super({ ...config.plugins.redis, db: (config.plugins.redis?.db ?? 0) + 4 }); + } +} diff --git a/packages/backend/server/src/plugins/redis/mutex.ts b/packages/backend/server/src/plugins/redis/mutex.ts new file mode 100644 index 0000000000..9006507f08 --- /dev/null +++ b/packages/backend/server/src/plugins/redis/mutex.ts @@ -0,0 +1,96 @@ +import { setTimeout } from 'node:timers/promises'; + +import { Injectable, Logger } from '@nestjs/common'; +import Redis, { Command } from 'ioredis'; + +import { + BucketService, + type GraphqlContext, + LockGuard, + MUTEX_RETRY, + MUTEX_WAIT, + MutexService, +} from '../../fundamentals'; + +const lockScript = `local key = KEYS[1] +local clientId = ARGV[1] +local releaseTime = ARGV[2] + +if redis.call("get", key) == clientId or redis.call("set", key, clientId, "NX", "PX", releaseTime) then + return 1 +else + return 0 +end`; +const unlockScript = `local key = KEYS[1] +local clientId = ARGV[1] + +if redis.call("get", key) == clientId then + return redis.call("del", key) +else + return 0 +end`; + +@Injectable() +export class MutexRedisService extends MutexService { + constructor( + private readonly redis: Redis, + context: GraphqlContext, + bucket: BucketService + ) { + super(context, bucket); + this.logger = new Logger(MutexRedisService.name); + } + + override async lock( + key: string, + releaseTimeInMS: number = 200 + ): Promise { + const clientId = this.getId(); + this.logger.debug(`Client ${clientId} lock try to lock ${key}`); + const releaseTime = releaseTimeInMS.toString(); + + const fetchLock = async (retry: number): Promise => { + if (retry === 0) { + this.logger.error( + `Failed to fetch lock ${key} after ${MUTEX_RETRY} retry` + ); + return undefined; + } + try { + const success = await this.redis.sendCommand( + new Command('EVAL', [lockScript, '1', key, clientId, releaseTime]) + ); + if (success === 1) { + return new LockGuard(this, key); + } else { + this.logger.warn( + `Failed to fetch lock ${key}, retrying in ${MUTEX_WAIT} ms` + ); + await setTimeout(MUTEX_WAIT * (MUTEX_RETRY - retry + 1)); + return fetchLock(retry - 1); + } + } catch (error: any) { + this.logger.error( + `Unexpected error when fetch lock ${key}: ${error.message}` + ); + return undefined; + } + }; + + return fetchLock(MUTEX_RETRY); + } + + override async unlock(key: string, ignoreUnlockFail = false): Promise { + const clientId = this.getId(); + const result = await this.redis.sendCommand( + new Command('EVAL', [unlockScript, '1', key, clientId]) + ); + if (result === 0) { + if (!ignoreUnlockFail) { + throw new Error(`Failed to release lock ${key}`); + } else { + this.logger.warn(`Failed to release lock ${key}`); + } + } + } +} diff --git a/packages/backend/server/tests/tsconfig.json b/packages/backend/server/tests/tsconfig.json index b67a10b39d..445549efc4 100644 --- a/packages/backend/server/tests/tsconfig.json +++ b/packages/backend/server/tests/tsconfig.json @@ -2,7 +2,7 @@ "extends": "../../../../tsconfig.json", "compilerOptions": { "composite": true, - "target": "ESNext", + "target": "ES2022", "emitDecoratorMetadata": true, "experimentalDecorators": true, "rootDir": ".", diff --git a/packages/backend/server/tests/workspace-invite.e2e.ts b/packages/backend/server/tests/workspace-invite.e2e.ts index 767e518e2f..6a0c12018b 100644 --- a/packages/backend/server/tests/workspace-invite.e2e.ts +++ b/packages/backend/server/tests/workspace-invite.e2e.ts @@ -104,7 +104,7 @@ test('should create user if not exist', async t => { const user = await auth.getUserByEmail('u2@affine.pro'); t.not(user, undefined, 'failed to create user'); - t.is(user?.name, 'Unnamed', 'failed to create user'); + t.is(user?.name, 'u2', 'failed to create user'); }); test('should invite a user by link', async t => { @@ -255,3 +255,25 @@ test('should support pagination for member', async t => { ); t.is(secondPageWorkspace.members.length, 1, 'failed to check invite id'); }); + +test('should limit member count correctly', async t => { + const { app } = t.context; + const u1 = await signUp(app, 'u1', 'u1@affine.pro', '1'); + for (let i = 0; i < 10; i++) { + const workspace = await createWorkspace(app, u1.token.token); + await Promise.allSettled( + Array.from({ length: 10 }).map(async (_, i) => + inviteUser( + app, + u1.token.token, + workspace.id, + `u${i}@affine.pro`, + 'Admin' + ) + ) + ); + + const ws = await getWorkspace(app, u1.token.token, workspace.id); + t.assert(ws.members.length <= 3, 'failed to check member list'); + } +}); diff --git a/packages/backend/server/tsconfig.json b/packages/backend/server/tsconfig.json index 4fc8005390..ec754fecdb 100644 --- a/packages/backend/server/tsconfig.json +++ b/packages/backend/server/tsconfig.json @@ -2,7 +2,7 @@ "extends": "../../../tsconfig.json", "compilerOptions": { "composite": true, - "target": "ESNext", + "target": "ES2022", "module": "ESNext", "emitDecoratorMetadata": true, "experimentalDecorators": true,