diff --git a/packages/backend/server/src/app.ts b/packages/backend/server/src/app.ts index afc5ab7b1a..a1f068f026 100644 --- a/packages/backend/server/src/app.ts +++ b/packages/backend/server/src/app.ts @@ -1,7 +1,8 @@ import { Module } from '@nestjs/common'; +import { APP_INTERCEPTOR } from '@nestjs/core'; import { AppController } from './app.controller'; -import { CacheModule } from './cache'; +import { CacheInterceptor, CacheModule } from './cache'; import { ConfigModule } from './config'; import { EventModule } from './event'; import { BusinessModules } from './modules'; @@ -23,6 +24,12 @@ const BasicModules = [ ]; @Module({ + providers: [ + { + provide: APP_INTERCEPTOR, + useClass: CacheInterceptor, + }, + ], imports: [...BasicModules, ...BusinessModules], controllers: [AppController], }) diff --git a/packages/backend/server/src/cache/index.ts b/packages/backend/server/src/cache/index.ts index 621407f031..3a1dab5310 100644 --- a/packages/backend/server/src/cache/index.ts +++ b/packages/backend/server/src/cache/index.ts @@ -22,3 +22,5 @@ const CacheProvider: FactoryProvider = { }) export class CacheModule {} export { LocalCache as Cache }; + +export { CacheInterceptor, MakeCache, PreventCache } from './interceptor'; diff --git a/packages/backend/server/src/cache/interceptor.ts b/packages/backend/server/src/cache/interceptor.ts new file mode 100644 index 0000000000..8cf078e18f --- /dev/null +++ b/packages/backend/server/src/cache/interceptor.ts @@ -0,0 +1,99 @@ +import { + CallHandler, + ExecutionContext, + Injectable, + Logger, + NestInterceptor, + SetMetadata, +} from '@nestjs/common'; +import { Reflector } from '@nestjs/core'; +import { GqlContextType, GqlExecutionContext } from '@nestjs/graphql'; +import { mergeMap, Observable, of } from 'rxjs'; + +import { LocalCache } from './cache'; + +export const MakeCache = (key: string[], args?: string[]) => + SetMetadata('cacheKey', [key, args]); +export const PreventCache = (key: string[], args?: string[]) => + SetMetadata('preventCache', [key, args]); + +type CacheConfig = [string[], string[]?]; + +@Injectable() +export class CacheInterceptor implements NestInterceptor { + private readonly logger = new Logger(CacheInterceptor.name); + constructor( + private readonly reflector: Reflector, + private readonly cache: LocalCache + ) {} + async intercept( + ctx: ExecutionContext, + next: CallHandler + ): Promise> { + const key = this.reflector.get( + 'cacheKey', + ctx.getHandler() + ); + const preventKey = this.reflector.get( + 'preventCache', + ctx.getHandler() + ); + + if (preventKey) { + this.logger.debug(`prevent cache: ${JSON.stringify(preventKey)}`); + const key = await this.getCacheKey(ctx, preventKey); + if (key) { + await this.cache.delete(key); + } + + return next.handle(); + } else if (!key) { + return next.handle(); + } + + const cacheKey = await this.getCacheKey(ctx, key); + + if (!cacheKey) { + return next.handle(); + } + + const cachedData = await this.cache.get(cacheKey); + + if (cachedData) { + this.logger.debug('cache hit', cacheKey, cachedData); + return of(cachedData); + } else { + return next.handle().pipe( + mergeMap(async result => { + this.logger.debug('cache miss', cacheKey, result); + await this.cache.set(cacheKey, result); + + return result; + }) + ); + } + } + + private async getCacheKey( + ctx: ExecutionContext, + config: CacheConfig + ): Promise { + const [key, params] = config; + + if (!params) { + return key.join(':'); + } else if (ctx.getType() === 'graphql') { + const args = GqlExecutionContext.create(ctx).getArgs(); + const cacheKey = params + .map(name => args[name]) + .filter(v => v) + .join(':'); + if (cacheKey) { + return [...key, cacheKey].join(':'); + } else { + return key.join(':'); + } + } + return null; + } +} diff --git a/packages/backend/server/src/modules/workspaces/resolver.ts b/packages/backend/server/src/modules/workspaces/resolver.ts index f8e30cf30f..1880ea935a 100644 --- a/packages/backend/server/src/modules/workspaces/resolver.ts +++ b/packages/backend/server/src/modules/workspaces/resolver.ts @@ -33,6 +33,7 @@ import type { import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs'; import { applyUpdate, Doc } from 'yjs'; +import { MakeCache, PreventCache } from '../../cache'; import { EventEmitter } from '../../event'; import { PrismaService } from '../../prisma'; import { StorageProvide } from '../../storage'; @@ -656,6 +657,7 @@ export class WorkspaceResolver { @Query(() => [String], { description: 'List blobs of workspace', }) + @MakeCache(['blobs'], ['workspaceId']) async listBlobs( @CurrentUser() user: UserType, @Args('workspaceId') workspaceId: string @@ -690,6 +692,7 @@ export class WorkspaceResolver { } @Mutation(() => String) + @PreventCache(['blobs'], ['workspaceId']) async setBlob( @CurrentUser() user: UserType, @Args('workspaceId') workspaceId: string, @@ -749,6 +752,7 @@ export class WorkspaceResolver { } @Mutation(() => Boolean) + @PreventCache(['blobs'], ['workspaceId']) async deleteBlob( @CurrentUser() user: UserType, @Args('workspaceId') workspaceId: string,