chore(server): clean up throttler (#6326)

This commit is contained in:
liuyi 2024-04-17 16:32:26 +08:00 committed by GitHub
parent 5b315bfc81
commit e53d5e2e3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 551 additions and 265 deletions

View File

@ -1,13 +1,12 @@
import { join } from 'node:path';
import { Logger, Module } from '@nestjs/common';
import { APP_GUARD, APP_INTERCEPTOR } from '@nestjs/core';
import { ScheduleModule } from '@nestjs/schedule';
import { ServeStaticModule } from '@nestjs/serve-static';
import { get } from 'lodash-es';
import { AppController } from './app.controller';
import { AuthGuard, AuthModule } from './core/auth';
import { AuthModule } from './core/auth';
import { ADD_ENABLED_FEATURES, ServerConfigModule } from './core/config';
import { DocModule } from './core/doc';
import { FeatureModule } from './core/features';
@ -17,7 +16,7 @@ import { SyncModule } from './core/sync';
import { UserModule } from './core/user';
import { WorkspaceModule } from './core/workspaces';
import { getOptionalModuleMetadata } from './fundamentals';
import { CacheInterceptor, CacheModule } from './fundamentals/cache';
import { CacheModule } from './fundamentals/cache';
import type { AvailablePlugins } from './fundamentals/config';
import { Config, ConfigModule } from './fundamentals/config';
import { EventModule } from './fundamentals/event';
@ -103,16 +102,6 @@ export class AppModuleBuilder {
compile() {
@Module({
providers: [
{
provide: APP_INTERCEPTOR,
useClass: CacheInterceptor,
},
{
provide: APP_GUARD,
useClass: AuthGuard,
},
],
imports: this.modules,
controllers: this.config.isSelfhosted ? [] : [AppController],
})

View File

@ -4,7 +4,12 @@ import type { NestExpressApplication } from '@nestjs/platform-express';
import cookieParser from 'cookie-parser';
import graphqlUploadExpress from 'graphql-upload/graphqlUploadExpress.mjs';
import { GlobalExceptionFilter } from './fundamentals';
import { AuthGuard } from './core/auth';
import {
CacheInterceptor,
CloudThrottlerGuard,
GlobalExceptionFilter,
} from './fundamentals';
import { SocketIoAdapter, SocketIoAdapterImpl } from './fundamentals/websocket';
import { serverTimingAndCache } from './middleware/timing';
@ -28,6 +33,8 @@ export async function createApp() {
})
);
app.useGlobalGuards(app.get(AuthGuard), app.get(CloudThrottlerGuard));
app.useGlobalInterceptors(app.get(CacheInterceptor));
app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter()));
app.use(cookieParser());

View File

@ -14,7 +14,11 @@ import {
} from '@nestjs/common';
import type { Request, Response } from 'express';
import { PaymentRequiredException, URLHelper } from '../../fundamentals';
import {
PaymentRequiredException,
Throttle,
URLHelper,
} from '../../fundamentals';
import { UserService } from '../user';
import { validators } from '../utils/validators';
import { CurrentUser } from './current-user';
@ -27,6 +31,7 @@ class SignInCredential {
password?: string;
}
@Throttle('strict')
@Controller('/api/auth')
export class AuthController {
constructor(
@ -158,6 +163,7 @@ export class AuthController {
return this.url.safeRedirect(res, redirectUri);
}
@Throttle('default', { limit: 1200 })
@Public()
@Get('/session')
async currentSessionUser(@CurrentUser() user?: CurrentUser) {
@ -166,6 +172,7 @@ export class AuthController {
};
}
@Throttle('default', { limit: 1200 })
@Public()
@Get('/sessions')
async currentSessionUsers(@Req() req: Request) {

View File

@ -54,6 +54,7 @@ export class AuthGuard implements CanActivate, OnModuleInit {
const user = await this.auth.getUser(sessionToken, userSeq);
if (user) {
req.sid = sessionToken;
req.user = user;
}
}

View File

@ -1,8 +1,4 @@
import {
BadRequestException,
ForbiddenException,
UseGuards,
} from '@nestjs/common';
import { BadRequestException, ForbiddenException } from '@nestjs/common';
import {
Args,
Context,
@ -16,7 +12,7 @@ import {
} from '@nestjs/graphql';
import type { Request, Response } from 'express';
import { CloudThrottlerGuard, Config, Throttle } from '../../fundamentals';
import { Config, Throttle } from '../../fundamentals';
import { UserService } from '../user';
import { UserType } from '../user/types';
import { validators } from '../utils/validators';
@ -43,7 +39,7 @@ export class ClientTokenType {
* Sign up/in rate limit: 10 req/m
* Other rate limit: 5 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Throttle('strict')
@Resolver(() => UserType)
export class AuthResolver {
constructor(
@ -53,12 +49,6 @@ export class AuthResolver {
private readonly token: TokenService
) {}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Public()
@Query(() => UserType, {
name: 'currentUser',
@ -69,12 +59,6 @@ export class AuthResolver {
return user;
}
@Throttle({
default: {
limit: 20,
ttl: 60,
},
})
@ResolveField(() => ClientTokenType, {
name: 'token',
deprecationReason: 'use [/api/auth/authorize]',
@ -101,12 +85,6 @@ export class AuthResolver {
}
@Public()
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType)
async signUp(
@Context() ctx: { req: Request; res: Response },
@ -122,12 +100,6 @@ export class AuthResolver {
}
@Public()
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType)
async signIn(
@Context() ctx: { req: Request; res: Response },
@ -141,12 +113,6 @@ export class AuthResolver {
return user;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => UserType)
async changePassword(
@CurrentUser() user: CurrentUser,
@ -172,12 +138,6 @@ export class AuthResolver {
return user;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => UserType)
async changeEmail(
@CurrentUser() user: CurrentUser,
@ -202,12 +162,6 @@ export class AuthResolver {
return user;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
async sendChangePasswordEmail(
@CurrentUser() user: CurrentUser,
@ -235,12 +189,6 @@ export class AuthResolver {
return !res.rejected.length;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
async sendSetPasswordEmail(
@CurrentUser() user: CurrentUser,
@ -273,12 +221,6 @@ export class AuthResolver {
// 4. user open confirm email page from new email
// 5. user click confirm button
// 6. send notification email
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
async sendChangeEmail(
@CurrentUser() user: CurrentUser,
@ -299,12 +241,6 @@ export class AuthResolver {
return !res.rejected.length;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
async sendVerifyChangeEmail(
@CurrentUser() user: CurrentUser,
@ -347,12 +283,6 @@ export class AuthResolver {
return !res.rejected.length;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
async sendVerifyEmail(
@CurrentUser() user: CurrentUser,
@ -367,12 +297,6 @@ export class AuthResolver {
return !res.rejected.length;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
async verifyEmail(
@CurrentUser() user: CurrentUser,

View File

@ -1,8 +1,4 @@
import {
BadRequestException,
ForbiddenException,
UseGuards,
} from '@nestjs/common';
import { BadRequestException, ForbiddenException } from '@nestjs/common';
import {
Args,
Context,
@ -13,7 +9,6 @@ import {
Resolver,
} from '@nestjs/graphql';
import { CloudThrottlerGuard, Throttle } from '../../fundamentals';
import { CurrentUser } from '../auth/current-user';
import { sessionUser } from '../auth/service';
import { EarlyAccessType, FeatureManagementService } from '../features';
@ -24,11 +19,6 @@ registerEnumType(EarlyAccessType, {
name: 'EarlyAccessType',
});
/**
* User resolver
* All op rate limit: 10 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Resolver(() => UserType)
export class UserManagementResolver {
constructor(
@ -36,12 +26,6 @@ export class UserManagementResolver {
private readonly feature: FeatureManagementService
) {}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => Int)
async addToEarlyAccess(
@CurrentUser() currentUser: CurrentUser,
@ -62,12 +46,6 @@ export class UserManagementResolver {
}
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => Int)
async removeEarlyAccess(
@CurrentUser() currentUser: CurrentUser,
@ -83,12 +61,6 @@ export class UserManagementResolver {
return this.feature.removeEarlyAccess(user.id);
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Query(() => [UserType])
async earlyAccessUsers(
@Context() ctx: { isAdminQuery: boolean },

View File

@ -1,4 +1,4 @@
import { BadRequestException, UseGuards } from '@nestjs/common';
import { BadRequestException } from '@nestjs/common';
import {
Args,
Int,
@ -14,7 +14,6 @@ import { isNil, omitBy } from 'lodash-es';
import type { FileUpload } from '../../fundamentals';
import {
CloudThrottlerGuard,
EventEmitter,
PaymentRequiredException,
Throttle,
@ -35,11 +34,6 @@ import {
UserType,
} from './types';
/**
* User resolver
* All op rate limit: 10 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Resolver(() => UserType)
export class UserResolver {
constructor(
@ -51,12 +45,7 @@ export class UserResolver {
private readonly event: EventEmitter
) {}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Throttle('strict')
@Query(() => UserOrLimitedUser, {
name: 'user',
description: 'Get user by email',
@ -90,7 +79,6 @@ export class UserResolver {
};
}
@Throttle({ default: { limit: 10, ttl: 60 } })
@ResolveField(() => UserQuotaType, { name: 'quota', nullable: true })
async getQuota(@CurrentUser() me: User) {
const quota = await this.quota.getUserQuota(me.id);
@ -98,7 +86,6 @@ export class UserResolver {
return quota.feature;
}
@Throttle({ default: { limit: 10, ttl: 60 } })
@ResolveField(() => Int, {
name: 'invoiceCount',
description: 'Get user invoice count',
@ -109,7 +96,6 @@ export class UserResolver {
});
}
@Throttle({ default: { limit: 10, ttl: 60 } })
@ResolveField(() => [FeatureType], {
name: 'features',
description: 'Enabled features of a user',
@ -118,12 +104,6 @@ export class UserResolver {
return this.feature.getActivatedUserFeatures(user.id);
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType, {
name: 'uploadAvatar',
description: 'Upload user avatar',
@ -153,12 +133,6 @@ export class UserResolver {
});
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType, {
name: 'updateProfile',
})
@ -180,12 +154,6 @@ export class UserResolver {
);
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => RemoveAvatar, {
name: 'removeAvatar',
description: 'Remove user avatar',
@ -201,12 +169,6 @@ export class UserResolver {
return { success: true };
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => DeleteAccount)
async deleteAccount(
@CurrentUser() user: CurrentUser

View File

@ -1,4 +1,4 @@
import { ForbiddenException, UseGuards } from '@nestjs/common';
import { ForbiddenException } from '@nestjs/common';
import {
Args,
Int,
@ -9,13 +9,11 @@ import {
Resolver,
} from '@nestjs/graphql';
import { CloudThrottlerGuard, Throttle } from '../../fundamentals';
import { CurrentUser } from '../auth';
import { FeatureManagementService, FeatureType } from '../features';
import { PermissionService } from './permission';
import { WorkspaceType } from './types';
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType)
export class WorkspaceManagementResolver {
constructor(
@ -23,12 +21,6 @@ export class WorkspaceManagementResolver {
private readonly permission: PermissionService
) {}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => Int)
async addWorkspaceFeature(
@CurrentUser() currentUser: CurrentUser,
@ -42,12 +34,6 @@ export class WorkspaceManagementResolver {
return this.feature.addWorkspaceFeatures(workspaceId, feature);
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => Int)
async removeWorkspaceFeature(
@CurrentUser() currentUser: CurrentUser,
@ -61,12 +47,6 @@ export class WorkspaceManagementResolver {
return this.feature.removeWorkspaceFeature(workspaceId, feature);
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Query(() => [WorkspaceType])
async listWorkspaceFeatures(
@CurrentUser() user: CurrentUser,

View File

@ -2,7 +2,6 @@ import {
ForbiddenException,
Logger,
PayloadTooLargeException,
UseGuards,
} from '@nestjs/common';
import {
Args,
@ -17,11 +16,7 @@ import { SafeIntResolver } from 'graphql-scalars';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import type { FileUpload } from '../../../fundamentals';
import {
CloudThrottlerGuard,
MakeCache,
PreventCache,
} from '../../../fundamentals';
import { MakeCache, PreventCache } from '../../../fundamentals';
import { CurrentUser } from '../../auth';
import { FeatureManagementService, FeatureType } from '../../features';
import { QuotaManagementService } from '../../quota';
@ -29,7 +24,6 @@ import { WorkspaceBlobStorage } from '../../storage';
import { PermissionService } from '../permission';
import { Permission, WorkspaceBlobSizes, WorkspaceType } from '../types';
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType)
export class WorkspaceBlobResolver {
logger = new Logger(WorkspaceBlobResolver.name);

View File

@ -1,4 +1,3 @@
import { UseGuards } from '@nestjs/common';
import {
Args,
Field,
@ -12,7 +11,6 @@ import {
} from '@nestjs/graphql';
import type { SnapshotHistory } from '@prisma/client';
import { CloudThrottlerGuard } from '../../../fundamentals';
import { CurrentUser } from '../../auth';
import { DocHistoryManager } from '../../doc';
import { DocID } from '../../utils/doc';
@ -31,7 +29,6 @@ class DocHistoryType implements Partial<SnapshotHistory> {
timestamp!: Date;
}
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType)
export class DocHistoryResolver {
constructor(

View File

@ -1,4 +1,4 @@
import { BadRequestException, UseGuards } from '@nestjs/common';
import { BadRequestException } from '@nestjs/common';
import {
Args,
Field,
@ -12,7 +12,6 @@ import {
import type { WorkspacePage as PrismaWorkspacePage } from '@prisma/client';
import { PrismaClient } from '@prisma/client';
import { CloudThrottlerGuard } from '../../../fundamentals';
import { CurrentUser } from '../../auth';
import { DocID } from '../../utils/doc';
import { PermissionService, PublicPageMode } from '../permission';
@ -38,7 +37,6 @@ class WorkspacePage implements Partial<PrismaWorkspacePage> {
public!: boolean;
}
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType)
export class PagePermissionResolver {
constructor(

View File

@ -4,7 +4,6 @@ import {
Logger,
NotFoundException,
PayloadTooLargeException,
UseGuards,
} from '@nestjs/common';
import {
Args,
@ -22,7 +21,6 @@ import { applyUpdate, Doc } from 'yjs';
import type { FileUpload } from '../../../fundamentals';
import {
CloudThrottlerGuard,
EventEmitter,
MailService,
MutexService,
@ -48,7 +46,6 @@ import { defaultWorkspaceAvatar } from '../utils';
* Public apis rate limit: 10 req/m
* Other rate limit: 120 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType)
export class WorkspaceResolver {
private readonly logger = new Logger(WorkspaceResolver.name);
@ -191,12 +188,7 @@ export class WorkspaceResolver {
});
}
@Throttle({
default: {
limit: 10,
ttl: 30,
},
})
@Throttle('strict')
@Public()
@Query(() => WorkspaceType, {
description: 'Get public workspace by id',
@ -422,15 +414,10 @@ export class WorkspaceResolver {
}
}
@Throttle({
default: {
limit: 10,
ttl: 30,
},
})
@Throttle('strict')
@Public()
@Query(() => InvitationType, {
description: 'Update workspace',
description: 'send workspace invitation',
})
async getInviteInfo(@Args('inviteId') inviteId: string) {
const workspaceId = await this.prisma.workspaceUserPermission

View File

@ -1,11 +1,12 @@
import { Global, Module } from '@nestjs/common';
import { Cache, SessionCache } from './instances';
import { CacheInterceptor } from './interceptor';
@Global()
@Module({
providers: [Cache, SessionCache],
exports: [Cache, SessionCache],
providers: [Cache, SessionCache, CacheInterceptor],
exports: [Cache, SessionCache, CacheInterceptor],
})
export class CacheModule {}
export { Cache, SessionCache };

View File

@ -27,7 +27,7 @@ export {
export type { PrismaTransaction } from './prisma';
export * from './storage';
export { type StorageProvider, StorageProviderFactory } from './storage';
export { AuthThrottlerGuard, CloudThrottlerGuard, Throttle } from './throttler';
export { CloudThrottlerGuard, Throttle } from './throttler';
export {
getRequestFromHost,
getRequestResponseFromContext,

View File

@ -0,0 +1,38 @@
import { applyDecorators, SetMetadata } from '@nestjs/common';
import { SkipThrottle, Throttle as RawThrottle } from '@nestjs/throttler';
export type Throttlers = 'default' | 'strict';
export const THROTTLER_PROTECTED = 'affine_throttler:protected';
/**
* Choose what throttler to use
*
* If a Controller or Query do not protected behind a Throttler,
* it will never be rate limited.
*
* - Ease: 120 calls within 60 seconds
* - Strict: 10 calls within 60 seconds
*
* @example
*
* \@Throttle()
* \@Throttle('strict')
*
* // the config call be override by the second parameter,
* // and the call count will be calculated separately
* \@Throttle('default', { limit: 10, ttl: 10 })
*
*/
export function Throttle(
type: Throttlers = 'default',
override: { limit?: number; ttl?: number } = {}
): MethodDecorator & ClassDecorator {
return applyDecorators(
SetMetadata(THROTTLER_PROTECTED, type),
RawThrottle({
[type]: override,
})
);
}
export { SkipThrottle };

View File

@ -1,15 +1,20 @@
import { ExecutionContext, Global, Injectable, Module } from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import {
Throttle,
InjectThrottlerOptions,
InjectThrottlerStorage,
ThrottlerGuard,
ThrottlerModule,
ThrottlerModuleOptions,
type ThrottlerModuleOptions,
ThrottlerOptions,
ThrottlerOptionsFactory,
ThrottlerStorageService,
} from '@nestjs/throttler';
import type { Request } from 'express';
import { Config } from '../config';
import { getRequestResponseFromContext } from '../utils/request';
import { THROTTLER_PROTECTED, Throttlers } from './decorators';
@Injectable()
export class ThrottlerStorage extends ThrottlerStorageService {}
@ -25,13 +30,16 @@ class CustomOptionsFactory implements ThrottlerOptionsFactory {
const options: ThrottlerModuleOptions = {
throttlers: [
{
name: 'default',
ttl: this.config.rateLimiter.ttl * 1000,
limit: this.config.rateLimiter.limit,
},
{
name: 'strict',
ttl: this.config.rateLimiter.ttl * 1000,
limit: 20,
},
],
skipIf: () => {
return !this.config.node.prod || this.config.affine.canary;
},
storage: this.storage,
};
@ -39,6 +47,132 @@ class CustomOptionsFactory implements ThrottlerOptionsFactory {
}
}
@Injectable()
export class CloudThrottlerGuard extends ThrottlerGuard {
constructor(
@InjectThrottlerOptions() options: ThrottlerModuleOptions,
@InjectThrottlerStorage() storageService: ThrottlerStorage,
reflector: Reflector,
private readonly config: Config
) {
super(options, storageService, reflector);
}
override getRequestResponse(context: ExecutionContext) {
return getRequestResponseFromContext(context) as any;
}
override getTracker(req: Request): Promise<string> {
return Promise.resolve(
// ↓ prefer session id if available
`throttler:${req.sid ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}`
// ^ throttler prefix make the key in store recognizable
);
}
override generateKey(
context: ExecutionContext,
tracker: string,
throttler: string
) {
if (tracker.endsWith(';custom')) {
return `${tracker};${throttler}:${context.getClass().name}.${context.getHandler().name}`;
}
return `${tracker};${throttler}`;
}
override async handleRequest(
context: ExecutionContext,
limit: number,
ttl: number,
throttlerOptions: ThrottlerOptions
) {
// give it 'default' if no throttler is specified,
// so the unauthenticated users visits will always hit default throttler
// authenticated users will directly bypass unprotected APIs in [CloudThrottlerGuard.canActivate]
const throttler = this.getSpecifiedThrottler(context) ?? 'default';
// by pass unmatched throttlers
if (throttlerOptions.name !== throttler) {
return true;
}
const { req, res } = this.getRequestResponse(context);
const ignoreUserAgents =
throttlerOptions.ignoreUserAgents ?? this.commonOptions.ignoreUserAgents;
if (Array.isArray(ignoreUserAgents)) {
for (const pattern of ignoreUserAgents) {
const ua = req.headers['user-agent'];
if (ua && pattern.test(ua)) {
return true;
}
}
}
let tracker = await this.getTracker(req);
if (this.config.node.dev) {
limit = Number.MAX_SAFE_INTEGER;
} else {
// custom limit or ttl APIs will be treated standalone
if (limit !== throttlerOptions.limit || ttl !== throttlerOptions.ttl) {
tracker += ';custom';
}
}
const key = this.generateKey(
context,
tracker,
throttlerOptions.name ?? 'default'
);
const { timeToExpire, totalHits } = await this.storageService.increment(
key,
ttl
);
if (totalHits > limit) {
res.header('Retry-After', timeToExpire.toString());
await this.throwThrottlingException(context, {
limit,
ttl,
key,
tracker,
totalHits,
timeToExpire,
});
}
res.header(`${this.headerPrefix}-Limit`, limit.toString());
res.header(
`${this.headerPrefix}-Remaining`,
(limit - totalHits).toString()
);
res.header(`${this.headerPrefix}-Reset`, timeToExpire.toString());
return true;
}
override async canActivate(context: ExecutionContext): Promise<boolean> {
const { req } = this.getRequestResponse(context);
const throttler = this.getSpecifiedThrottler(context);
// if user is logged in, bypass non-protected handlers
if (!throttler && req.user) {
return true;
}
return super.canActivate(context);
}
getSpecifiedThrottler(context: ExecutionContext) {
return this.reflector.getAllAndOverride<Throttlers | undefined>(
THROTTLER_PROTECTED,
[context.getHandler(), context.getClass()]
);
}
}
@Global()
@Module({
imports: [
@ -46,46 +180,9 @@ class CustomOptionsFactory implements ThrottlerOptionsFactory {
useClass: CustomOptionsFactory,
}),
],
providers: [ThrottlerStorage],
exports: [ThrottlerStorage],
providers: [ThrottlerStorage, CloudThrottlerGuard],
exports: [ThrottlerStorage, CloudThrottlerGuard],
})
export class RateLimiterModule {}
@Injectable()
export class CloudThrottlerGuard extends ThrottlerGuard {
override getRequestResponse(context: ExecutionContext) {
return getRequestResponseFromContext(context) as any;
}
protected override getTracker(req: Record<string, any>): Promise<string> {
return Promise.resolve(
req?.get('CF-Connecting-IP') ?? req?.get('CF-ray') ?? req?.ip
);
}
}
@Injectable()
export class AuthThrottlerGuard extends CloudThrottlerGuard {
override async handleRequest(
context: ExecutionContext,
limit: number,
ttl: number
): Promise<boolean> {
const { req } = this.getRequestResponse(context);
if (req?.url === '/api/auth/session') {
// relax throttle for session auto renew
return super.handleRequest(context, limit * 20, ttl, {
ttl: ttl * 20,
limit: limit * 20,
});
}
return super.handleRequest(context, limit, ttl, {
ttl,
limit,
});
}
}
export { Throttle };
export * from './decorators';

View File

@ -1,6 +1,7 @@
declare namespace Express {
interface Request {
user?: import('./core/auth/current-user').CurrentUser;
sid?: string;
}
}

View File

@ -265,7 +265,7 @@ type Query {
currentUser: UserType
earlyAccessUsers: [UserType!]!
"""Update workspace"""
"""send workspace invitation"""
getInviteInfo(inviteId: String!): InvitationType!
"""Get is owner of workspace"""

View File

@ -0,0 +1,331 @@
import '../../src/plugins/config';
import {
Controller,
Get,
HttpStatus,
INestApplication,
UseGuards,
} from '@nestjs/common';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import request, { type Response } from 'supertest';
import { AppModule } from '../../src/app.module';
import { AuthService, Public } from '../../src/core/auth';
import { ConfigModule } from '../../src/fundamentals/config';
import {
CloudThrottlerGuard,
SkipThrottle,
Throttle,
ThrottlerStorage,
} from '../../src/fundamentals/throttler';
import { createTestingApp, sessionCookie } from '../utils';
const test = ava as TestFn<{
storage: ThrottlerStorage;
cookie: string;
app: INestApplication;
}>;
@UseGuards(CloudThrottlerGuard)
@Throttle()
@Controller('/throttled')
class ThrottledController {
@Get('/default')
default() {
return 'default';
}
@Get('/default2')
default2() {
return 'default2';
}
@Get('/default3')
@Throttle('default', { limit: 10 })
default3() {
return 'default3';
}
@Throttle('strict')
@Get('/strict')
strict() {
return 'strict';
}
@Public()
@SkipThrottle()
@Get('/skip')
skip() {
return 'skip';
}
}
@UseGuards(CloudThrottlerGuard)
@Controller('/nonthrottled')
class NonThrottledController {
@Public()
@SkipThrottle()
@Get('/skip')
skip() {
return 'skip';
}
@Public()
@Get('/default')
default() {
return 'default';
}
@Public()
@Throttle('strict')
@Get('/strict')
strict() {
return 'strict';
}
}
test.beforeEach(async t => {
const { app } = await createTestingApp({
imports: [
ConfigModule.forRoot({
rateLimiter: {
ttl: 60,
limit: 120,
},
}),
AppModule,
],
controllers: [ThrottledController, NonThrottledController],
});
t.context.storage = app.get(ThrottlerStorage);
t.context.app = app;
const auth = app.get(AuthService);
const u1 = await auth.signUp('u1', 'u1@affine.pro', 'test');
const res = await request(app.getHttpServer())
.post('/api/auth/sign-in')
.send({ email: u1.email, password: 'test' });
t.context.cookie = sessionCookie(res.headers)!;
});
test.afterEach.always(async t => {
await t.context.app.close();
});
function rateLimitHeaders(res: Response) {
return {
limit: res.header['x-ratelimit-limit'],
remaining: res.header['x-ratelimit-remaining'],
reset: res.header['x-ratelimit-reset'],
retryAfter: res.header['retry-after'],
};
}
test('should be able to prevent requests if limit is reached', async t => {
const { app } = t.context;
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
timeToExpire: 10,
totalHits: 21,
});
const res = await request(app.getHttpServer())
.get('/nonthrottled/strict')
.expect(HttpStatus.TOO_MANY_REQUESTS);
const headers = rateLimitHeaders(res);
t.is(headers.retryAfter, '10');
stub.restore();
});
// ====== unauthenticated user visits ======
test('should use default throttler for unauthenticated user when not specified', async t => {
const { app } = t.context;
const res = await request(app.getHttpServer())
.get('/nonthrottled/default')
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, '120');
t.is(headers.remaining, '119');
t.is(headers.reset, '60');
});
test('should skip throttler for unauthenticated user when specified', async t => {
const { app } = t.context;
let res = await request(app.getHttpServer())
.get('/nonthrottled/skip')
.expect(200);
let headers = rateLimitHeaders(res);
t.is(headers.limit, undefined!);
t.is(headers.remaining, undefined!);
t.is(headers.reset, undefined!);
res = await request(app.getHttpServer()).get('/throttled/skip').expect(200);
headers = rateLimitHeaders(res);
t.is(headers.limit, undefined!);
t.is(headers.remaining, undefined!);
t.is(headers.reset, undefined!);
});
test('should use specified throttler for unauthenticated user', async t => {
const { app } = t.context;
const res = await request(app.getHttpServer())
.get('/nonthrottled/strict')
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, '20');
t.is(headers.remaining, '19');
t.is(headers.reset, '60');
});
// ==== authenticated user visits ====
test('should not protect unspecified routes', async t => {
const { app, cookie } = t.context;
const res = await request(app.getHttpServer())
.get('/nonthrottled/default')
.set('Cookie', cookie)
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, undefined!);
t.is(headers.remaining, undefined!);
t.is(headers.reset, undefined!);
});
test('should use default throttler for authenticated user when not specified', async t => {
const { app, cookie } = t.context;
const res = await request(app.getHttpServer())
.get('/throttled/default')
.set('Cookie', cookie)
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, '120');
t.is(headers.remaining, '119');
t.is(headers.reset, '60');
});
test('should use same throttler for multiple routes', async t => {
const { app, cookie } = t.context;
let res = await request(app.getHttpServer())
.get('/throttled/default')
.set('Cookie', cookie)
.expect(200);
let headers = rateLimitHeaders(res);
t.is(headers.limit, '120');
t.is(headers.remaining, '119');
t.is(headers.reset, '60');
res = await request(app.getHttpServer())
.get('/throttled/default2')
.set('Cookie', cookie)
.expect(200);
headers = rateLimitHeaders(res);
t.is(headers.limit, '120');
t.is(headers.remaining, '118');
});
test('should use different throttler if specified', async t => {
const { app, cookie } = t.context;
let res = await request(app.getHttpServer())
.get('/throttled/default')
.set('Cookie', cookie)
.expect(200);
let headers = rateLimitHeaders(res);
t.is(headers.limit, '120');
t.is(headers.remaining, '119');
t.is(headers.reset, '60');
res = await request(app.getHttpServer())
.get('/throttled/default3')
.set('Cookie', cookie)
.expect(200);
headers = rateLimitHeaders(res);
t.is(headers.limit, '10');
t.is(headers.remaining, '9');
t.is(headers.reset, '60');
});
test('should skip throttler for authenticated user when specified', async t => {
const { app, cookie } = t.context;
const res = await request(app.getHttpServer())
.get('/throttled/skip')
.set('Cookie', cookie)
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, undefined!);
t.is(headers.remaining, undefined!);
t.is(headers.reset, undefined!);
});
test('should use specified throttler for authenticated user', async t => {
const { app, cookie } = t.context;
const res = await request(app.getHttpServer())
.get('/throttled/strict')
.set('Cookie', cookie)
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, '20');
t.is(headers.remaining, '19');
t.is(headers.reset, '60');
});
test('should separate anonymous and authenticated user throttlers', async t => {
const { app, cookie } = t.context;
const authenticatedUserRes = await request(app.getHttpServer())
.get('/throttled/default')
.set('Cookie', cookie)
.expect(200);
const unauthenticatedUserRes = await request(app.getHttpServer())
.get('/nonthrottled/default')
.expect(200);
const authenticatedResHeaders = rateLimitHeaders(authenticatedUserRes);
const unauthenticatedResHeaders = rateLimitHeaders(unauthenticatedUserRes);
t.is(authenticatedResHeaders.limit, '120');
t.is(authenticatedResHeaders.remaining, '119');
t.is(authenticatedResHeaders.reset, '60');
t.is(unauthenticatedResHeaders.limit, '120');
t.is(unauthenticatedResHeaders.remaining, '119');
t.is(unauthenticatedResHeaders.reset, '60');
});

View File

@ -10,13 +10,13 @@ import {
import type { UserType } from '../../src/core/user';
import { gql } from './common';
export function sessionCookie(headers: any) {
export function sessionCookie(headers: any): string {
const cookie = headers['set-cookie']?.find((c: string) =>
c.startsWith(`${AuthService.sessionCookieName}=`)
);
if (!cookie) {
return null;
return '';
}
return cookie.split(';')[0];
@ -29,7 +29,7 @@ export async function getSession(
const cookie = sessionCookie(signInRes.headers);
const res = await request(app.getHttpServer())
.get('/api/auth/session')
.set('cookie', cookie)
.set('cookie', cookie!)
.expect(200);
return res.body;