feat(server): support socketio auth field (#8595)

fix AF-1531
This commit is contained in:
forehalo 2024-10-25 03:27:02 +00:00
parent 10963da706
commit 08319bc560
No known key found for this signature in database
GPG Key ID: 56709255DC7EC728
5 changed files with 45 additions and 29 deletions

View File

@ -7,12 +7,12 @@ import type {
import { Injectable, SetMetadata } from '@nestjs/common';
import { ModuleRef, Reflector } from '@nestjs/core';
import type { Request, Response } from 'express';
import { Socket } from 'socket.io';
import {
AuthenticationRequired,
Config,
getRequestResponseFromContext,
mapAnyError,
parseCookies,
} from '../../fundamentals';
import { WEBSOCKET_OPTIONS } from '../../fundamentals/websocket';
@ -64,9 +64,6 @@ export class AuthGuard implements CanActivate, OnModuleInit {
return req.session;
}
// compatibility with websocket request
parseCookies(req);
// TODO(@forehalo): a cache for user session
const userSession = await this.auth.getUserSessionFromRequest(req, res);
@ -93,27 +90,22 @@ export const AuthWebsocketOptionsProvider: FactoryProvider = {
useFactory: (config: Config, guard: AuthGuard) => {
return {
...config.websocket,
allowRequest: async (
req: any,
pass: (err: string | null | undefined, success: boolean) => void
) => {
if (!config.websocket.requireAuthentication) {
return pass(null, true);
}
canActivate: async (socket: Socket) => {
const upgradeReq = socket.client.request as Request;
const handshake = socket.handshake;
try {
const authentication = await guard.signIn(req);
// compatibility with websocket request
parseCookies(upgradeReq);
if (authentication) {
return pass(null, true);
} else {
return pass('unauthenticated', false);
}
} catch (e) {
const error = mapAnyError(e);
error.log('Websocket');
return pass('unauthenticated', false);
}
upgradeReq.cookies = {
[AuthService.sessionCookieName]: handshake.auth.token,
[AuthService.userCookieName]: handshake.auth.userId,
...upgradeReq.cookies,
};
const session = await guard.signIn(upgradeReq);
return !!session;
},
};
},

View File

@ -298,7 +298,7 @@ export class AuthService implements OnApplicationBootstrap {
const userId: string | undefined =
req.cookies[AuthService.userCookieName] ||
req.headers[AuthService.userCookieName];
req.headers[AuthService.userCookieName.replaceAll('_', '-')];
return {
sessionId,

View File

@ -26,7 +26,7 @@ export function getRequestResponseFromHost(host: ArgumentsHost) {
}
case 'ws': {
const ws = host.switchToWs();
const req = ws.getClient<Socket>().client.conn.request as Request;
const req = ws.getClient<Socket>().request as Request;
parseCookies(req);
return { req };
}

View File

@ -1,4 +1,5 @@
import { GatewayMetadata } from '@nestjs/websockets';
import { Socket } from 'socket.io';
import { defineStartupConfig, ModuleConfig } from '../config';
@ -6,7 +7,7 @@ declare module '../config' {
interface AppConfig {
websocket: ModuleConfig<
GatewayMetadata & {
requireAuthentication?: boolean;
canActivate?: (socket: Socket) => Promise<boolean>;
}
>;
}
@ -16,5 +17,4 @@ defineStartupConfig('websocket', {
// see: https://socket.io/docs/v4/server-options/#maxhttpbuffersize
transports: ['websocket'],
maxHttpBufferSize: 1e8, // 100 MB
requireAuthentication: true,
});

View File

@ -10,6 +10,7 @@ import { IoAdapter } from '@nestjs/platform-socket.io';
import { Server } from 'socket.io';
import { Config } from '../config';
import { AuthenticationRequired } from '../error';
export const SocketIoAdapterImpl = Symbol('SocketIoAdapterImpl');
@ -19,8 +20,31 @@ export class SocketIoAdapter extends IoAdapter {
}
override createIOServer(port: number, options?: any): Server {
const config = this.app.get(WEBSOCKET_OPTIONS);
return super.createIOServer(port, { ...config, ...options });
const config = this.app.get(WEBSOCKET_OPTIONS) as Config['websocket'];
const server: Server = super.createIOServer(port, {
...config,
...options,
});
if (config.canActivate) {
server.use((socket, next) => {
// @ts-expect-error checked
config
.canActivate(socket)
.then(pass => {
if (pass) {
next();
} else {
throw new AuthenticationRequired();
}
})
.catch(e => {
next(e);
});
});
}
return server;
}
}