refactor(server): auth (#7994)

This commit is contained in:
forehalo 2024-09-03 09:03:39 +00:00
parent 821de0a3bb
commit 8b0afd6eeb
No known key found for this signature in database
GPG Key ID: 56709255DC7EC728
39 changed files with 639 additions and 775 deletions

View File

@ -60,13 +60,6 @@ spec:
name: affine-graphql
port:
number: {{ .Values.graphql.service.port }}
- path: /oauth
pathType: Prefix
backend:
service:
name: affine-graphql
port:
number: {{ .Values.graphql.service.port }}
- path: /
pathType: Prefix
backend:

View File

@ -62,12 +62,13 @@ model ConnectedAccount {
}
model Session {
id String @id @default(uuid()) @db.VarChar
expiresAt DateTime? @map("expires_at") @db.Timestamptz(3)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
id String @id @default(uuid()) @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
userSessions UserSession[]
// @deprecated use [UserSession.expiresAt]
deprecated_expiresAt DateTime? @map("expires_at") @db.Timestamptz(3)
@@map("multiple_users_sessions")
}

View File

@ -18,6 +18,7 @@ import {
EarlyAccessRequired,
EmailTokenNotFound,
InternalServerError,
InvalidEmail,
InvalidEmailToken,
SignUpForbidden,
Throttle,
@ -25,19 +26,25 @@ import {
} from '../../fundamentals';
import { UserService } from '../user';
import { validators } from '../utils/validators';
import { CurrentUser } from './current-user';
import { Public } from './guard';
import { AuthService, parseAuthUserSeqNum } from './service';
import { AuthService } from './service';
import { CurrentUser, Session } from './session';
import { TokenService, TokenType } from './token';
class SignInCredential {
email!: string;
password?: string;
interface PreflightResponse {
registered: boolean;
hasPassword: boolean;
}
class MagicLinkCredential {
email!: string;
token!: string;
interface SignInCredential {
email: string;
password?: string;
callbackUrl?: string;
}
interface MagicLinkCredential {
email: string;
token: string;
}
@Throttle('strict')
@ -51,6 +58,33 @@ export class AuthController {
private readonly config: Config
) {}
@Public()
@Post('/preflight')
async preflight(
@Body() params?: { email: string }
): Promise<PreflightResponse> {
if (!params?.email) {
throw new InvalidEmail();
}
validators.assertValidEmail(params.email);
const user = await this.user.findUserWithHashedPasswordByEmail(
params.email
);
if (!user) {
return {
registered: false,
hasPassword: false,
};
}
return {
registered: user.registered,
hasPassword: !!user.password,
};
}
@Public()
@Post('/sign-in')
@Header('content-type', 'application/json')
@ -58,7 +92,10 @@ export class AuthController {
@Req() req: Request,
@Res() res: Response,
@Body() credential: SignInCredential,
@Query('redirect_uri') redirectUri = this.url.home
/**
* @deprecated
*/
@Query('redirect_uri') redirectUri?: string
) {
validators.assertValidEmail(credential.email);
const canSignIn = await this.auth.canSignIn(credential.email);
@ -67,80 +104,83 @@ export class AuthController {
}
if (credential.password) {
const user = await this.auth.signIn(
await this.passwordSignIn(
req,
res,
credential.email,
credential.password
);
await this.auth.setCookie(req, res, user);
res.status(HttpStatus.OK).send(user);
} else {
// send email magic link
const user = await this.user.findUserByEmail(credential.email);
if (!user) {
const allowSignup = await this.config.runtime.fetch('auth/allowSignup');
if (!allowSignup) {
throw new SignUpForbidden();
}
}
const result = await this.sendSignInEmail(
{ email: credential.email, signUp: !user },
await this.sendMagicLink(
req,
res,
credential.email,
credential.callbackUrl,
redirectUri
);
if (result.rejected.length) {
throw new InternalServerError('Failed to send sign-in email.');
}
res.status(HttpStatus.OK).send({
email: credential.email,
});
}
}
async sendSignInEmail(
{ email, signUp }: { email: string; signUp: boolean },
redirectUri: string
async passwordSignIn(
req: Request,
res: Response,
email: string,
password: string
) {
const user = await this.auth.signIn(email, password);
await this.auth.setCookies(req, res, user.id);
res.status(HttpStatus.OK).send(user);
}
async sendMagicLink(
_req: Request,
res: Response,
email: string,
callbackUrl = '/magic-link',
redirectUrl = this.url.home
) {
// send email magic link
const user = await this.user.findUserByEmail(email);
if (!user) {
const allowSignup = await this.config.runtime.fetch('auth/allowSignup');
if (!allowSignup) {
throw new SignUpForbidden();
}
}
const token = await this.token.createToken(TokenType.SignIn, email);
const magicLink = this.url.link('/magic-link', {
const magicLink = this.url.link(callbackUrl, {
token,
email,
redirect_uri: redirectUri,
redirect_uri: redirectUrl,
});
const result = await this.auth.sendSignInEmail(email, magicLink, signUp);
const result = await this.auth.sendSignInEmail(email, magicLink, !user);
return result;
if (result.rejected.length) {
throw new InternalServerError('Failed to send sign-in email.');
}
res.status(HttpStatus.OK).send({
email: email,
});
}
@Get('/sign-out')
async signOut(
@Req() req: Request,
@Res() res: Response,
@Query('redirect_uri') redirectUri?: string
@Session() session: Session,
@Body() { all }: { all: boolean }
) {
const session = await this.auth.signOut(
req.cookies[AuthService.sessionCookieName],
parseAuthUserSeqNum(req.headers[AuthService.authUserSeqHeaderName])
await this.auth.signOut(
session.sessionId,
all ? undefined : session.userId
);
if (session) {
res.cookie(AuthService.sessionCookieName, session.id, {
expires: session.expiresAt ?? void 0, // expiredAt is `string | null`
...this.auth.cookieOptions,
});
} else {
res.clearCookie(AuthService.sessionCookieName);
}
if (redirectUri) {
return this.url.safeRedirect(res, redirectUri);
} else {
return res.send(null);
}
res.status(HttpStatus.OK).send({});
}
@Public()
@ -156,11 +196,11 @@ export class AuthController {
validators.assertValidEmail(email);
const valid = await this.token.verifyToken(TokenType.SignIn, token, {
const tokenRecord = await this.token.verifyToken(TokenType.SignIn, token, {
credential: email,
});
if (!valid) {
if (!tokenRecord) {
throw new InvalidEmailToken();
}
@ -169,9 +209,8 @@ export class AuthController {
registered: true,
});
await this.auth.setCookie(req, res, user);
res.send({ id: user.id, email: user.email, name: user.name });
await this.auth.setCookies(req, res, user.id);
res.send({ id: user.id });
}
@Throttle('default', { limit: 1200 })

View File

@ -4,7 +4,7 @@ import type {
FactoryProvider,
OnModuleInit,
} from '@nestjs/common';
import { Injectable, SetMetadata, UseGuards } from '@nestjs/common';
import { Injectable, SetMetadata } from '@nestjs/common';
import { ModuleRef, Reflector } from '@nestjs/core';
import type { Request } from 'express';
@ -16,16 +16,8 @@ import {
parseCookies,
} from '../../fundamentals';
import { WEBSOCKET_OPTIONS } from '../../fundamentals/websocket';
import { CurrentUser, UserSession } from './current-user';
import { AuthService, parseAuthUserSeqNum } from './service';
function extractTokenFromHeader(authorization: string) {
if (!/^Bearer\s/i.test(authorization)) {
return;
}
return authorization.substring(7);
}
import { AuthService } from './service';
import { Session } from './session';
const PUBLIC_ENTRYPOINT_SYMBOL = Symbol('public');
@ -46,8 +38,8 @@ export class AuthGuard implements CanActivate, OnModuleInit {
const { req, res } = getRequestResponseFromContext(context);
const userSession = await this.signIn(req);
if (res && userSession && userSession.session.expiresAt) {
await this.auth.refreshUserSessionIfNeeded(req, res, userSession.session);
if (res && userSession && userSession.expiresAt) {
await this.auth.refreshUserSessionIfNeeded(res, userSession);
}
// api is public
@ -60,43 +52,31 @@ export class AuthGuard implements CanActivate, OnModuleInit {
return true;
}
if (!req.user) {
if (!userSession) {
throw new AuthenticationRequired();
}
return true;
}
async signIn(
req: Request
): Promise<{ user: CurrentUser; session: UserSession } | null> {
if (req.user && req.session) {
return {
user: req.user,
session: req.session,
};
async signIn(req: Request): Promise<Session | null> {
if (req.session) {
return req.session;
}
// compatibility with websocket request
parseCookies(req);
let sessionToken: string | undefined =
req.cookies[AuthService.sessionCookieName];
if (!sessionToken && req.headers.authorization) {
sessionToken = extractTokenFromHeader(req.headers.authorization);
}
// TODO(@forehalo): a cache for user session
const userSession = await this.auth.getUserSessionFromRequest(req);
if (sessionToken) {
const userSeq = parseAuthUserSeqNum(
req.headers[AuthService.authUserSeqHeaderName]
);
if (userSession) {
req.session = {
...userSession.session,
user: userSession.user,
};
const userSession = await this.auth.getUserSession(sessionToken, userSeq);
if (userSession) {
req.session = userSession.session;
req.user = userSession.user;
}
return userSession;
return req.session;
}
return null;
@ -104,26 +84,8 @@ export class AuthGuard implements CanActivate, OnModuleInit {
}
/**
* This guard is used to protect routes/queries/mutations that require a user to be logged in.
*
* The `@CurrentUser()` parameter decorator used in a `Auth` guarded queries would always give us the user because the `Auth` guard will
* fast throw if user is not logged in.
*
* @example
*
* ```typescript
* \@Auth()
* \@Query(() => UserType)
* user(@CurrentUser() user: CurrentUser) {
* return user;
* }
* ```
* Mark api to be public accessible
*/
export const Auth = () => {
return UseGuards(AuthGuard);
};
// api is public accessible
export const Public = () => SetMetadata(PUBLIC_ENTRYPOINT_SYMBOL, true);
export const AuthWebsocketOptionsProvider: FactoryProvider = {

View File

@ -28,4 +28,4 @@ export class AuthModule {}
export * from './guard';
export { ClientTokenType } from './resolver';
export { AuthService, TokenService, TokenType };
export * from './current-user';
export * from './session';

View File

@ -11,7 +11,6 @@ import {
import {
ActionForbidden,
Config,
EmailAlreadyUsed,
EmailTokenNotFound,
EmailVerificationRequired,
@ -26,9 +25,9 @@ import { Admin } from '../common';
import { UserService } from '../user';
import { UserType } from '../user/types';
import { validators } from '../utils/validators';
import { CurrentUser } from './current-user';
import { Public } from './guard';
import { AuthService } from './service';
import { CurrentUser } from './session';
import { TokenService, TokenType } from './token';
@ObjectType('tokenType')
@ -47,7 +46,6 @@ export class ClientTokenType {
@Resolver(() => UserType)
export class AuthResolver {
constructor(
private readonly config: Config,
private readonly url: URLHelper,
private readonly auth: AuthService,
private readonly user: UserService,
@ -67,7 +65,7 @@ export class AuthResolver {
@ResolveField(() => ClientTokenType, {
name: 'token',
deprecationReason: 'use [/api/auth/authorize]',
deprecationReason: 'use [/api/auth/sign-in?native=true] instead',
})
async clientToken(
@CurrentUser() currentUser: CurrentUser,
@ -77,15 +75,11 @@ export class AuthResolver {
throw new ActionForbidden();
}
const session = await this.auth.createUserSession(
user,
undefined,
this.config.auth.accessToken.ttl
);
const userSession = await this.auth.createUserSession(user.id);
return {
sessionToken: session.sessionId,
token: session.sessionId,
sessionToken: userSession.sessionId,
token: userSession.sessionId,
refresh: '',
};
}
@ -101,14 +95,6 @@ export class AuthResolver {
throw new LinkExpired();
}
const config = await this.config.runtime.fetchAll({
'auth/password.max': true,
'auth/password.min': true,
});
validators.assertValidPassword(newPassword, {
min: config['auth/password.min'],
max: config['auth/password.max'],
});
// NOTE: Set & Change password are using the same token type.
const valid = await this.token.verifyToken(
TokenType.ChangePassword,
@ -134,7 +120,6 @@ export class AuthResolver {
@Args('token') token: string,
@Args('email') email: string
) {
validators.assertValidEmail(email);
// @see [sendChangeEmail]
const valid = await this.token.verifyToken(TokenType.VerifyEmail, token, {
credential: user.id,
@ -157,8 +142,11 @@ export class AuthResolver {
async sendChangePasswordEmail(
@CurrentUser() user: CurrentUser,
@Args('callbackUrl') callbackUrl: string,
// @deprecated
@Args('email', { nullable: true }) _email?: string
@Args('email', {
nullable: true,
deprecationReason: 'fetched from signed in user',
})
_email?: string
) {
if (!user.emailVerified) {
throw new EmailVerificationRequired();
@ -180,7 +168,11 @@ export class AuthResolver {
async sendSetPasswordEmail(
@CurrentUser() user: CurrentUser,
@Args('callbackUrl') callbackUrl: string,
@Args('email', { nullable: true }) _email?: string
@Args('email', {
nullable: true,
deprecationReason: 'fetched from signed in user',
})
_email?: string
) {
return this.sendChangePasswordEmail(user, callbackUrl);
}

View File

@ -5,35 +5,12 @@ import { PrismaClient } from '@prisma/client';
import type { CookieOptions, Request, Response } from 'express';
import { assign, pick } from 'lodash-es';
import { Config, EmailAlreadyUsed, MailService } from '../../fundamentals';
import { Config, MailService, SignUpForbidden } from '../../fundamentals';
import { FeatureManagementService } from '../features/management';
import { QuotaService } from '../quota/service';
import { QuotaType } from '../quota/types';
import { UserService } from '../user/service';
import type { CurrentUser } from './current-user';
export function parseAuthUserSeqNum(value: any) {
let seq: number = 0;
switch (typeof value) {
case 'number': {
seq = value;
break;
}
case 'string': {
const result = value.match(/^([\d{0, 10}])$/);
if (result?.[1]) {
seq = Number(result[1]);
}
break;
}
default: {
seq = 0;
}
}
return Math.max(0, seq);
}
import type { CurrentUser } from './session';
export function sessionUser(
user: Pick<
@ -48,6 +25,14 @@ export function sessionUser(
});
}
function extractTokenFromHeader(authorization: string) {
if (!/^Bearer\s/i.test(authorization)) {
return;
}
return authorization.substring(7);
}
@Injectable()
export class AuthService implements OnApplicationBootstrap {
readonly cookieOptions: CookieOptions = {
@ -57,7 +42,7 @@ export class AuthService implements OnApplicationBootstrap {
secure: this.config.server.https,
};
static readonly sessionCookieName = 'affine_session';
static readonly authUserSeqHeaderName = 'x-auth-user';
static readonly userCookieName = 'affine_user_id';
constructor(
private readonly config: Config,
@ -93,46 +78,69 @@ export class AuthService implements OnApplicationBootstrap {
return this.feature.canEarlyAccess(email);
}
async signUp(
name: string,
email: string,
password: string
): Promise<CurrentUser> {
const user = await this.user.findUserByEmail(email);
if (user) {
throw new EmailAlreadyUsed();
/**
* This is a test only helper to quickly signup a user, do not use in production
*/
async signUp(email: string, password: string): Promise<CurrentUser> {
if (!this.config.node.test) {
throw new SignUpForbidden(
'sign up helper is forbidden for non-test environment'
);
}
return this.user
.createUser({
name,
.createUser_without_verification({
email,
password,
})
.then(sessionUser);
}
async signIn(email: string, password: string) {
const user = await this.user.signIn(email, password);
async signIn(email: string, password: string): Promise<CurrentUser> {
return this.user.signIn(email, password).then(sessionUser);
}
return sessionUser(user);
async signOut(sessionId: string, userId?: string) {
// sign out all users in the session
if (!userId) {
await this.db.session.deleteMany({
where: {
id: sessionId,
},
});
} else {
await this.db.userSession.deleteMany({
where: {
sessionId,
userId,
},
});
}
}
async getUserSession(
token: string,
seq = 0
sessionId: string,
userId?: string
): Promise<{ user: CurrentUser; session: UserSession } | null> {
const session = await this.getSession(token);
const userSession = await this.db.userSession.findFirst({
where: {
sessionId,
userId,
},
select: {
id: true,
sessionId: true,
userId: true,
createdAt: true,
expiresAt: true,
user: true,
},
orderBy: {
createdAt: 'asc',
},
});
// no such session
if (!session) {
return null;
}
const userSession = session.userSessions.at(seq);
// no such user session
if (!userSession) {
return null;
}
@ -142,112 +150,93 @@ export class AuthService implements OnApplicationBootstrap {
return null;
}
const user = await this.db.user.findUnique({
where: { id: userSession.userId },
});
if (!user) {
return null;
}
return { user: sessionUser(user), session: userSession };
return { user: sessionUser(userSession.user), session: userSession };
}
async getUserList(token: string) {
const session = await this.getSession(token);
if (!session || !session.userSessions.length) {
return [];
}
const users = await this.db.user.findMany({
where: {
id: {
in: session.userSessions.map(({ userId }) => userId),
},
},
});
// TODO(@forehalo): need to separate expired session, same for [getUser]
// Session
// | { user: LimitedUser { email, avatarUrl }, expired: true }
// | { user: User, expired: false }
return session.userSessions
.map(userSession => {
// keep users in the same order as userSessions
const user = users.find(({ id }) => id === userSession.userId);
if (!user) {
return null;
}
return sessionUser(user);
})
.filter(Boolean) as CurrentUser[];
}
async signOut(token: string, seq = 0) {
const session = await this.getSession(token);
if (session) {
// overflow the logged in user
if (session.userSessions.length <= seq) {
return session;
}
await this.db.userSession.deleteMany({
where: { id: session.userSessions[seq].id },
});
// no more user session active, delete the whole session
if (session.userSessions.length === 1) {
await this.db.session.delete({ where: { id: session.id } });
return null;
}
return session;
}
return null;
}
async getSession(token: string) {
if (!token) {
return null;
}
return this.db.$transaction(async tx => {
const session = await tx.session.findUnique({
async createUserSession(
userId: string,
sessionId?: string,
ttl = this.config.auth.session.ttl
) {
// check whether given session is valid
if (sessionId) {
const session = await this.db.session.findFirst({
where: {
id: token,
},
include: {
userSessions: {
orderBy: {
createdAt: 'asc',
},
},
id: sessionId,
},
});
if (!session) {
return null;
sessionId = undefined;
}
}
if (session.expiresAt && session.expiresAt <= new Date()) {
await tx.session.delete({
where: {
id: session.id,
if (!sessionId) {
const session = await this.createSession();
sessionId = session.id;
}
const expiresAt = new Date(Date.now() + ttl * 1000);
return this.db.userSession.upsert({
where: {
sessionId_userId: {
sessionId,
userId,
},
},
update: {
expiresAt,
},
create: {
sessionId,
userId,
expiresAt,
},
});
}
async getUserList(sessionId: string) {
const sessions = await this.db.userSession.findMany({
where: {
sessionId,
OR: [
{
expiresAt: null,
},
});
{
expiresAt: {
gt: new Date(),
},
},
],
},
include: {
user: true,
},
orderBy: {
createdAt: 'asc',
},
});
return null;
}
return sessions.map(({ user }) => sessionUser(user));
}
return session;
async createSession() {
return this.db.session.create({
data: {},
});
}
async getSession(sessionId: string) {
return this.db.session.findFirst({
where: {
id: sessionId,
},
});
}
async refreshUserSessionIfNeeded(
_req: Request,
res: Response,
session: UserSession,
ttr = this.config.auth.session.ttr
@ -281,70 +270,63 @@ export class AuthService implements OnApplicationBootstrap {
return true;
}
async createUserSession(
user: { id: string },
existingSession?: string,
ttl = this.config.auth.session.ttl
) {
const session = existingSession
? await this.getSession(existingSession)
: null;
const expiresAt = new Date(Date.now() + ttl * 1000);
if (session) {
return this.db.userSession.upsert({
where: {
sessionId_userId: {
sessionId: session.id,
userId: user.id,
},
},
update: {
expiresAt,
},
create: {
sessionId: session.id,
userId: user.id,
expiresAt,
},
});
} else {
return this.db.userSession.create({
data: {
expiresAt,
session: {
create: {},
},
user: {
connect: {
id: user.id,
},
},
},
});
}
}
async revokeUserSessions(userId: string, sessionId?: string) {
async revokeUserSessions(userId: string) {
return this.db.userSession.deleteMany({
where: {
userId,
sessionId,
},
});
}
async setCookie(_req: Request, res: Response, user: { id: string }) {
const session = await this.createUserSession(
user
// TODO(@forehalo): enable multi user session
// req.cookies[AuthService.sessionCookieName]
);
getSessionOptionsFromRequest(req: Request) {
let sessionId: string | undefined =
req.cookies[AuthService.sessionCookieName];
res.cookie(AuthService.sessionCookieName, session.sessionId, {
expires: session.expiresAt ?? void 0,
if (!sessionId && req.headers.authorization) {
sessionId = extractTokenFromHeader(req.headers.authorization);
}
const userId: string | undefined =
req.cookies[AuthService.userCookieName] ||
req.headers[AuthService.userCookieName];
return {
sessionId,
userId,
};
}
async setCookies(req: Request, res: Response, userId: string) {
const { sessionId } = this.getSessionOptionsFromRequest(req);
const userSession = await this.createUserSession(userId, sessionId);
res.cookie(AuthService.sessionCookieName, userSession.sessionId, {
...this.cookieOptions,
expires: userSession.expiresAt ?? void 0,
});
this.setUserCookie(res, userId);
}
setUserCookie(res: Response, userId: string) {
res.cookie(AuthService.userCookieName, userId, {
...this.cookieOptions,
// user cookie is client readable & writable for fast user switch if there are multiple users in one session
// it safe to be non-secure & non-httpOnly because server will validate it by `cookie[AuthService.sessionCookieName]`
httpOnly: false,
secure: false,
});
}
async getUserSessionFromRequest(req: Request) {
const { sessionId, userId } = this.getSessionOptionsFromRequest(req);
if (!sessionId) {
return null;
}
return this.getUserSession(sessionId, userId);
}
async changePassword(
@ -393,24 +375,16 @@ export class AuthService implements OnApplicationBootstrap {
async sendSignInEmail(email: string, link: string, signUp: boolean) {
return signUp
? await this.mailer.sendSignUpMail(link.toString(), {
? await this.mailer.sendSignUpMail(link, {
to: email,
})
: await this.mailer.sendSignInMail(link.toString(), {
: await this.mailer.sendSignInMail(link, {
to: email,
});
}
@Cron(CronExpression.EVERY_DAY_AT_MIDNIGHT)
async cleanExpiredSessions() {
await this.db.session.deleteMany({
where: {
expiresAt: {
lte: new Date(),
},
},
});
await this.db.userSession.deleteMany({
where: {
expiresAt: {

View File

@ -4,10 +4,6 @@ import { User, UserSession } from '@prisma/client';
import { getRequestResponseFromContext } from '../../fundamentals';
function getUserFromContext(context: ExecutionContext) {
return getRequestResponseFromContext(context).req.user;
}
/**
* Used to fetch current user from the request context.
*
@ -44,7 +40,7 @@ function getUserFromContext(context: ExecutionContext) {
// eslint-disable-next-line no-redeclare
export const CurrentUser = createParamDecorator(
(_: unknown, context: ExecutionContext) => {
return getUserFromContext(context);
return getRequestResponseFromContext(context).req.session?.user;
}
);
@ -54,4 +50,14 @@ export interface CurrentUser
emailVerified: boolean;
}
export { type UserSession };
// interface and variable don't conflict
// eslint-disable-next-line no-redeclare
export const Session = createParamDecorator(
(_: unknown, context: ExecutionContext) => {
return getRequestResponseFromContext(context).req.session;
}
);
export type Session = UserSession & {
user: CurrentUser;
};

View File

@ -25,8 +25,8 @@ export class AdminGuard implements CanActivate, OnModuleInit {
async canActivate(context: ExecutionContext) {
const { req } = getRequestResponseFromContext(context);
let allow = false;
if (req.user) {
allow = await this.feature.isAdmin(req.user.id);
if (req.session) {
allow = await this.feature.isAdmin(req.session.user.id);
}
if (!allow) {

View File

@ -7,7 +7,7 @@ import {
} from '@nestjs/graphql';
import { SafeIntResolver } from 'graphql-scalars';
import { CurrentUser } from '../auth/current-user';
import { CurrentUser } from '../auth/session';
import { EarlyAccessType } from '../features';
import { UserType } from '../user';
import { QuotaService } from './service';

View File

@ -56,7 +56,7 @@ export class CustomSetupController {
try {
await this.event.emitAsync('user.admin.created', user);
await this.auth.setCookie(req, res, user);
await this.auth.setCookies(req, res, user.id);
res.send({ id: user.id, email: user.email, name: user.name });
} catch (e) {
await this.user.deleteUser(user.id);

View File

@ -21,7 +21,7 @@ import {
SpaceAccessDenied,
VersionRejected,
} from '../../fundamentals';
import { Auth, CurrentUser } from '../auth';
import { CurrentUser } from '../auth';
import {
DocStorageAdapter,
PgUserspaceDocStorageAdapter,
@ -203,7 +203,6 @@ export class SpaceSyncGateway
}
// v3
@Auth()
@SubscribeMessage('space:join')
async onJoinSpace(
@CurrentUser() user: CurrentUser,
@ -264,7 +263,6 @@ export class SpaceSyncGateway
};
}
@Auth()
@SubscribeMessage('space:push-doc-updates')
async onReceiveDocUpdates(
@ConnectedSocket() client: Socket,
@ -324,7 +322,6 @@ export class SpaceSyncGateway
};
}
@Auth()
@SubscribeMessage('space:join-awareness')
async onJoinAwareness(
@ConnectedSocket() client: Socket,
@ -410,7 +407,6 @@ export class SpaceSyncGateway
// TODO(@forehalo): remove
// deprecated section
@Auth()
@SubscribeMessage('client-handshake-sync')
async handleClientHandshakeSync(
@CurrentUser() user: CurrentUser,
@ -451,7 +447,6 @@ export class SpaceSyncGateway
});
}
@Auth()
@SubscribeMessage('client-update-v2')
async handleClientUpdateV2(
@CurrentUser() user: CurrentUser,
@ -499,7 +494,6 @@ export class SpaceSyncGateway
});
}
@Auth()
@SubscribeMessage('client-handshake-awareness')
async handleClientHandshakeAwareness(
@ConnectedSocket() client: Socket,

View File

@ -18,9 +18,9 @@ import {
Throttle,
UserNotFound,
} from '../../fundamentals';
import { CurrentUser } from '../auth/current-user';
import { Public } from '../auth/guard';
import { sessionUser } from '../auth/service';
import { CurrentUser } from '../auth/session';
import { Admin } from '../common';
import { AvatarStorage } from '../storage';
import { validators } from '../utils/validators';

View File

@ -56,11 +56,6 @@ export class UserService {
async createUser(data: CreateUserInput) {
validators.assertValidEmail(data.email);
const user = await this.findUserByEmail(data.email);
if (user) {
throw new EmailAlreadyUsed();
}
if (data.password) {
const config = await this.config.runtime.fetchAll({
@ -77,6 +72,12 @@ export class UserService {
}
async createUser_without_verification(data: CreateUserInput) {
const user = await this.findUserByEmail(data.email);
if (user) {
throw new EmailAlreadyUsed();
}
if (data.password) {
data.password = await this.crypto.encryptPassword(data.password);
}
@ -158,9 +159,7 @@ export class UserService {
async fulfillUser(
email: string,
data: Partial<
Pick<Prisma.UserCreateInput, 'emailVerifiedAt' | 'registered'>
>
data: Omit<Partial<Prisma.UserCreateInput>, 'id'>
) {
const user = await this.findUserByEmail(email);
if (!user) {
@ -180,7 +179,6 @@ export class UserService {
if (Object.keys(data).length) {
return await this.prisma.user.update({
select: this.defaultUserSelect,
where: { id: user.id },
data,
});

View File

@ -8,7 +8,7 @@ import {
import type { User } from '@prisma/client';
import type { Payload } from '../../fundamentals/event/def';
import { CurrentUser } from '../auth/current-user';
import { type CurrentUser } from '../auth/session';
@ObjectType()
export class UserType implements CurrentUser {

View File

@ -12,7 +12,7 @@ import {
ThrottlerRequest,
ThrottlerStorageService,
} from '@nestjs/throttler';
import type { Request } from 'express';
import type { Request, Response } from 'express';
import { Config } from '../config';
import { getRequestResponseFromContext } from '../utils/request';
@ -50,7 +50,10 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
super(options, storageService, reflector);
}
override getRequestResponse(context: ExecutionContext) {
override getRequestResponse(context: ExecutionContext): {
req: Request;
res: Response;
} {
return getRequestResponseFromContext(context) as any;
}
@ -153,7 +156,7 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
const throttler = this.getSpecifiedThrottler(context);
// if user is logged in, bypass non-protected handlers
if (!throttler && req.user) {
if (!throttler && req.session?.user) {
return true;
}

View File

@ -1,3 +1,5 @@
import { IncomingMessage } from 'node:http';
import type { ArgumentsHost, ExecutionContext } from '@nestjs/common';
import type { GqlContextType } from '@nestjs/graphql';
import { GqlArgumentsHost } from '@nestjs/graphql';
@ -25,26 +27,7 @@ export function getRequestResponseFromHost(host: ArgumentsHost) {
case 'ws': {
const ws = host.switchToWs();
const req = ws.getClient<Socket>().client.conn.request as Request;
const cookieStr = req?.headers?.cookie ?? '';
// patch cookies to match auth guard logic
if (typeof cookieStr === 'string') {
req.cookies = cookieStr.split(';').reduce(
(cookies, cookie) => {
const [key, val] = cookie.split('=');
if (key) {
cookies[decodeURIComponent(key.trim())] = val
? decodeURIComponent(val.trim())
: val;
}
return cookies;
},
{} as Record<string, string>
);
}
parseCookies(req);
return { req };
}
case 'rpc': {
@ -71,12 +54,14 @@ export function getRequestResponseFromContext(ctx: ExecutionContext) {
* simple patch for request not protected by `cookie-parser`
* only take effect if `req.cookies` is not defined
*/
export function parseCookies(req: Request) {
export function parseCookies(
req: IncomingMessage & { cookies?: Record<string, string> }
) {
if (req.cookies) {
return;
}
const cookieStr = req?.headers?.cookie ?? '';
const cookieStr = req.headers.cookie ?? '';
req.cookies = cookieStr.split(';').reduce(
(cookies, cookie) => {
const [key, val] = cookie.split('=');

View File

@ -1,7 +1,6 @@
declare namespace Express {
interface Request {
user?: import('./core/auth/current-user').CurrentUser;
session?: import('./core/auth/current-user').UserSession;
session?: import('./core/auth/session').Session;
}
}

View File

@ -27,8 +27,7 @@ import {
toArray,
} from 'rxjs';
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { CurrentUser, Public } from '../../core/auth';
import {
BlobNotFound,
Config,

View File

@ -1,4 +1,12 @@
import { Controller, Get, Query, Req, Res } from '@nestjs/common';
import {
Body,
Controller,
HttpCode,
HttpStatus,
Post,
Req,
Res,
} from '@nestjs/common';
import { ConnectedAccount, PrismaClient } from '@prisma/client';
import type { Request, Response } from 'express';
@ -11,34 +19,34 @@ import {
OauthStateExpired,
UnknownOauthProvider,
URLHelper,
WrongSignInMethod,
} from '../../fundamentals';
import { OAuthProviderName } from './config';
import { OAuthAccount, Tokens } from './providers/def';
import { OAuthProviderFactory } from './register';
import { OAuthService } from './service';
@Controller('/oauth')
@Controller('/api/oauth')
export class OAuthController {
constructor(
private readonly auth: AuthService,
private readonly oauth: OAuthService,
private readonly user: UserService,
private readonly providerFactory: OAuthProviderFactory,
private readonly url: URLHelper,
private readonly providerFactory: OAuthProviderFactory,
private readonly db: PrismaClient
) {}
@Public()
@Get('/login')
async login(
@Res() res: Response,
@Query('provider') unknownProviderName: string,
@Query('redirect_uri') redirectUri?: string
@Post('/preflight')
@HttpCode(HttpStatus.OK)
async preflight(
@Body('provider') unknownProviderName?: string,
@Body('redirect_uri') redirectUri: string = this.url.home
) {
if (!unknownProviderName) {
throw new MissingOauthQueryParameter({ name: 'provider' });
}
// @ts-expect-error safe
const providerName = OAuthProviderName[unknownProviderName];
const provider = this.providerFactory.get(providerName);
@ -48,20 +56,23 @@ export class OAuthController {
}
const state = await this.oauth.saveOAuthState({
redirectUri: redirectUri ?? this.url.home,
provider: providerName,
redirectUri,
});
return res.redirect(provider.getAuthUrl(state));
return {
url: provider.getAuthUrl(state),
};
}
@Public()
@Get('/callback')
@Post('/callback')
@HttpCode(HttpStatus.OK)
async callback(
@Req() req: Request,
@Res() res: Response,
@Query('code') code?: string,
@Query('state') stateStr?: string
@Body('code') code?: string,
@Body('state') stateStr?: string
) {
if (!code) {
throw new MissingOauthQueryParameter({ name: 'code' });
@ -93,43 +104,18 @@ export class OAuthController {
const tokens = await provider.getToken(code);
const externAccount = await provider.getUser(tokens.accessToken);
const user = req.user;
const user = await this.loginFromOauth(
state.provider,
externAccount,
tokens
);
try {
if (!user) {
// if user not found, login
const user = await this.loginFromOauth(
state.provider,
externAccount,
tokens
);
const session = await this.auth.createUserSession(
user,
req.cookies[AuthService.sessionCookieName]
);
res.cookie(AuthService.sessionCookieName, session.sessionId, {
expires: session.expiresAt ?? void 0, // expiredAt is `string | null`
...this.auth.cookieOptions,
});
} else {
// if user is found, connect the account to this user
await this.connectAccountFromOauth(
user,
state.provider,
externAccount,
tokens
);
}
} catch (e: any) {
return res.redirect(
this.url.link('/signIn', {
redirect_uri: state.redirectUri,
error: e.message,
})
);
}
this.url.safeRedirect(res, state.redirectUri);
await this.auth.setCookies(req, res, user.id);
res.send({
id: user.id,
/* @deprecated */
redirectUri: state.redirectUri,
});
}
private async loginFromOauth(
@ -154,37 +140,27 @@ export class OAuthController {
return connectedUser.user;
}
let user = await this.user.findUserByEmail(externalAccount.email);
const user = await this.user.fulfillUser(externalAccount.email, {
emailVerifiedAt: new Date(),
registered: true,
avatarUrl: externalAccount.avatarUrl,
});
if (user) {
// we can't directly connect the external account with given email in sign in scenario for safety concern.
// let user manually connect in account sessions instead.
if (user.registered) {
throw new WrongSignInMethod();
}
await this.db.connectedAccount.create({
data: {
userId: user.id,
provider,
providerAccountId: externalAccount.id,
...tokens,
},
});
return user;
} else {
user = await this.createUserWithConnectedAccount(
await this.db.connectedAccount.create({
data: {
userId: user.id,
provider,
externalAccount,
tokens
);
}
providerAccountId: externalAccount.id,
...tokens,
},
});
return user;
}
updateConnectedAccount(connectedUser: ConnectedAccount, tokens: Tokens) {
private async updateConnectedAccount(
connectedUser: ConnectedAccount,
tokens: Tokens
) {
return this.db.connectedAccount.update({
where: {
id: connectedUser.id,
@ -193,27 +169,12 @@ export class OAuthController {
});
}
async createUserWithConnectedAccount(
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
return this.user.createUser({
email: externalAccount.email,
name: externalAccount.email.split('@')[0],
avatarUrl: externalAccount.avatarUrl,
emailVerifiedAt: new Date(),
connectedAccounts: {
create: {
provider,
providerAccountId: externalAccount.id,
...tokens,
},
},
});
}
private async connectAccountFromOauth(
/**
* we currently don't support connect oauth account to existing user
* keep it incase we need it in the future
*/
// @ts-expect-error allow unused
private async _connectAccount(
user: { id: string },
provider: OAuthProviderName,
externalAccount: OAuthAccount,

View File

@ -15,7 +15,7 @@ export interface Tokens {
export abstract class OAuthProvider {
abstract provider: OAuthProviderName;
abstract getAuthUrl(state?: string): string;
abstract getAuthUrl(state: string): string;
abstract getToken(code: string): Promise<Tokens>;
abstract getUser(token: string): Promise<OAuthAccount>;
}

View File

@ -9,7 +9,7 @@ import { OAuthProviderFactory } from './register';
const OAUTH_STATE_KEY = 'OAUTH_STATE';
interface OAuthState {
redirectUri: string;
redirectUri?: string;
provider: OAuthProviderName;
}

View File

@ -474,8 +474,8 @@ type Mutation {
revokePage(pageId: String!, workspaceId: String!): Boolean! @deprecated(reason: "use revokePublicPage")
revokePublicPage(pageId: String!, workspaceId: String!): WorkspacePage!
sendChangeEmail(callbackUrl: String!, email: String): Boolean!
sendChangePasswordEmail(callbackUrl: String!, email: String): Boolean!
sendSetPasswordEmail(callbackUrl: String!, email: String): Boolean!
sendChangePasswordEmail(callbackUrl: String!, email: String @deprecated(reason: "fetched from signed in user")): Boolean!
sendSetPasswordEmail(callbackUrl: String!, email: String @deprecated(reason: "fetched from signed in user")): Boolean!
sendVerifyChangeEmail(callbackUrl: String!, email: String!, token: String!): Boolean!
sendVerifyEmail(callbackUrl: String!): Boolean!
setBlob(blob: Upload!, workspaceId: String!): String!
@ -862,7 +862,7 @@ type UserType {
quota: UserQuota
subscription(plan: SubscriptionPlan = Pro): UserSubscription @deprecated(reason: "use `UserType.subscriptions`")
subscriptions: [UserSubscription!]!
token: tokenType! @deprecated(reason: "use [/api/auth/authorize]")
token: tokenType! @deprecated(reason: "use [/api/auth/sign-in?native=true] instead")
}
type VersionRejectedDataType {

View File

@ -8,8 +8,8 @@ import type { INestApplication } from '@nestjs/common';
import type { TestFn } from 'ava';
import ava from 'ava';
import { AuthService } from '../src/core/auth/service';
import { MailService } from '../src/fundamentals/mailer';
import { AuthService } from '../../src/core/auth/service';
import { MailService } from '../../src/fundamentals/mailer';
import {
changeEmail,
changePassword,
@ -19,7 +19,7 @@ import {
sendSetPasswordEmail,
sendVerifyChangeEmail,
signUp,
} from './utils';
} from '../utils';
const test = ava as TestFn<{
app: INestApplication;

View File

@ -20,7 +20,7 @@ const test = ava as TestFn<{
app: INestApplication;
}>;
test.beforeEach(async t => {
test.before(async t => {
const { app } = await createTestingApp({
imports: [FeatureModule, UserModule, AuthModule],
tapModule: m => {
@ -36,10 +36,14 @@ test.beforeEach(async t => {
t.context.mailer = app.get(MailService);
t.context.app = app;
t.context.u1 = await t.context.auth.signUp('u1', 'u1@affine.pro', '1');
t.context.u1 = await t.context.auth.signUp('u1@affine.pro', '1');
});
test.afterEach.always(async t => {
test.beforeEach(() => {
Sinon.reset();
});
test.after.always(async t => {
await t.context.app.close();
});

View File

@ -1,15 +1,10 @@
import { Controller, Get, HttpStatus, INestApplication } from '@nestjs/common';
import { APP_GUARD } from '@nestjs/core';
import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import request from 'supertest';
import {
AuthGuard,
AuthModule,
CurrentUser,
Public,
} from '../../src/core/auth';
import { AuthModule, CurrentUser, Public, Session } from '../../src/core/auth';
import { AuthService } from '../../src/core/auth/service';
import { createTestingApp } from '../utils';
@ -25,115 +20,123 @@ class TestController {
private(@CurrentUser() user: CurrentUser) {
return { user };
}
@Get('/session')
session(@Session() session: Session) {
return session;
}
}
const test = ava as TestFn<{
app: INestApplication;
auth: Sinon.SinonStubbedInstance<AuthService>;
}>;
test.beforeEach(async t => {
let server!: any;
let auth!: AuthService;
let u1!: CurrentUser;
test.before(async t => {
const { app } = await createTestingApp({
imports: [AuthModule],
providers: [
{
provide: APP_GUARD,
useClass: AuthGuard,
},
],
controllers: [TestController],
tapModule: m => {
m.overrideProvider(AuthService).useValue(
Sinon.createStubInstance(AuthService)
);
},
});
t.context.auth = app.get(AuthService);
auth = app.get(AuthService);
u1 = await auth.signUp('u1@affine.pro', '1');
const db = app.get(PrismaClient);
await db.session.create({
data: {
id: '1',
},
});
await auth.createUserSession(u1.id, '1');
server = app.getHttpServer();
t.context.app = app;
});
test.afterEach.always(async t => {
test.after.always(async t => {
await t.context.app.close();
});
test('should be able to visit public api if not signed in', async t => {
const { app } = t.context;
const res = await request(app.getHttpServer()).get('/public').expect(200);
const res = await request(server).get('/public').expect(200);
t.is(res.body.user, undefined);
});
test('should be able to visit public api if signed in', async t => {
const { app, auth } = t.context;
// @ts-expect-error mock
auth.getUserSession.resolves({ user: { id: '1' }, session: { id: '1' } });
const res = await request(app.getHttpServer())
const res = await request(server)
.get('/public')
.set('Cookie', `${AuthService.sessionCookieName}=1`)
.expect(HttpStatus.OK);
t.is(res.body.user.id, '1');
t.is(res.body.user.id, u1.id);
});
test('should not be able to visit private api if not signed in', async t => {
const { app } = t.context;
await request(app.getHttpServer())
.get('/private')
.expect(HttpStatus.UNAUTHORIZED)
.expect({
status: 401,
code: 'Unauthorized',
type: 'AUTHENTICATION_REQUIRED',
name: 'AUTHENTICATION_REQUIRED',
message: 'You must sign in first to access this resource.',
});
await request(server).get('/private').expect(HttpStatus.UNAUTHORIZED).expect({
status: 401,
code: 'Unauthorized',
type: 'AUTHENTICATION_REQUIRED',
name: 'AUTHENTICATION_REQUIRED',
message: 'You must sign in first to access this resource.',
});
t.assert(true);
});
test('should be able to visit private api if signed in', async t => {
const { app, auth } = t.context;
// @ts-expect-error mock
auth.getUserSession.resolves({ user: { id: '1' }, session: { id: '1' } });
const res = await request(app.getHttpServer())
const res = await request(server)
.get('/private')
.set('Cookie', `${AuthService.sessionCookieName}=1`)
.expect(HttpStatus.OK);
t.is(res.body.user.id, '1');
t.is(res.body.user.id, u1.id);
});
test('should be able to parse session cookie', async t => {
const { app, auth } = t.context;
// @ts-expect-error mock
auth.getUserSession.resolves({ user: { id: '1' }, session: { id: '1' } });
await request(app.getHttpServer())
const spy = Sinon.spy(auth, 'getUserSession');
await request(server)
.get('/public')
.set('cookie', `${AuthService.sessionCookieName}=1`)
.expect(200);
t.deepEqual(auth.getUserSession.firstCall.args, ['1', 0]);
t.deepEqual(spy.firstCall.args, ['1', undefined]);
spy.restore();
});
test('should be able to parse bearer token', async t => {
const { app, auth } = t.context;
const spy = Sinon.spy(auth, 'getUserSession');
// @ts-expect-error mock
auth.getUserSession.resolves({ user: { id: '1' }, session: { id: '1' } });
await request(app.getHttpServer())
await request(server)
.get('/public')
.auth('1', { type: 'bearer' })
.expect(200);
t.deepEqual(auth.getUserSession.firstCall.args, ['1', 0]);
t.deepEqual(spy.firstCall.args, ['1', undefined]);
spy.restore();
});
test('should be able to refresh session if needed', async t => {
await t.context.app.get(PrismaClient).userSession.updateMany({
where: {
sessionId: '1',
},
data: {
expiresAt: new Date(Date.now() + 1000 * 60 * 60 /* expires in 1 hour */),
},
});
const res = await request(server)
.get('/session')
.set('cookie', `${AuthService.sessionCookieName}=1`)
.expect(200);
const cookie = res
.get('Set-Cookie')
?.find(c => c.startsWith(AuthService.sessionCookieName));
t.truthy(cookie);
});

View File

@ -3,11 +3,11 @@ import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import { CurrentUser } from '../../src/core/auth';
import { AuthService, parseAuthUserSeqNum } from '../../src/core/auth/service';
import { AuthService } from '../../src/core/auth/service';
import { FeatureModule } from '../../src/core/features';
import { QuotaModule } from '../../src/core/quota';
import { UserModule, UserService } from '../../src/core/user';
import { createTestingModule } from '../utils';
import { createTestingModule, initTestingDB } from '../utils';
const test = ava as TestFn<{
auth: AuthService;
@ -17,7 +17,7 @@ const test = ava as TestFn<{
m: TestingModule;
}>;
test.beforeEach(async t => {
test.before(async t => {
const m = await createTestingModule({
imports: [QuotaModule, FeatureModule, UserModule],
providers: [AuthService],
@ -27,50 +27,18 @@ test.beforeEach(async t => {
t.context.user = m.get(UserService);
t.context.db = m.get(PrismaClient);
t.context.m = m;
t.context.u1 = await t.context.auth.signUp('u1', 'u1@affine.pro', '1');
});
test.afterEach.always(async t => {
test.beforeEach(async t => {
await initTestingDB(t.context.db);
t.context.u1 = await t.context.auth.signUp('u1@affine.pro', '1');
});
test.after.always(async t => {
await t.context.m.close();
});
test('should be able to parse auth user seq num', t => {
t.deepEqual(
[
'1',
'2',
3,
-3,
'-4',
'1.1',
'str',
'1111111111111111111111111111111111111111111',
].map(parseAuthUserSeqNum),
[1, 2, 3, 0, 0, 0, 0, 0]
);
});
test('should be able to sign up', async t => {
const { auth } = t.context;
const u2 = await auth.signUp('u2', 'u2@affine.pro', '1');
t.is(u2.email, 'u2@affine.pro');
const signedU2 = await auth.signIn(u2.email, '1');
t.is(u2.email, signedU2.email);
});
test('should throw if email duplicated', async t => {
const { auth } = t.context;
await t.throwsAsync(() => auth.signUp('u1', 'u1@affine.pro', '1'), {
message: 'This email has already been registered.',
});
});
test('should be able to sign in', async t => {
test('should be able to sign in by password', async t => {
const { auth } = t.context;
const signedInUser = await auth.signIn('u1@affine.pro', '1');
@ -114,7 +82,7 @@ test('should be able to change password', async t => {
let signedInU1 = await auth.signIn('u1@affine.pro', '1');
t.is(signedInU1.email, u1.email);
await auth.changePassword(u1.id, '2');
await auth.changePassword(u1.id, 'hello world affine');
await t.throwsAsync(
() => auth.signIn('u1@affine.pro', '1' /* old password */),
@ -123,7 +91,7 @@ test('should be able to change password', async t => {
}
);
signedInU1 = await auth.signIn('u1@affine.pro', '2');
signedInU1 = await auth.signIn('u1@affine.pro', 'hello world affine');
t.is(signedInU1.email, u1.email);
});
@ -147,7 +115,7 @@ test('should be able to change email', async t => {
test('should be able to create user session', async t => {
const { auth, u1 } = t.context;
const session = await auth.createUserSession(u1);
const session = await auth.createUserSession(u1.id);
t.is(session.userId, u1.id);
});
@ -155,7 +123,7 @@ test('should be able to create user session', async t => {
test('should be able to get user from session', async t => {
const { auth, u1 } = t.context;
const session = await auth.createUserSession(u1);
const session = await auth.createUserSession(u1.id);
const userSession = await auth.getUserSession(session.sessionId);
@ -166,23 +134,50 @@ test('should be able to get user from session', async t => {
test('should be able to sign out session', async t => {
const { auth, u1 } = t.context;
const session = await auth.createUserSession(u1);
const session = await auth.createUserSession(u1.id);
await auth.signOut(session.sessionId);
const userSession = await auth.getUserSession(session.sessionId);
const signedOutSession = await auth.signOut(session.sessionId);
t.is(userSession, null);
});
t.is(signedOutSession, null);
test('should not return expired session', async t => {
const { auth, u1, db } = t.context;
const session = await auth.createUserSession(u1.id);
await db.userSession.update({
where: { id: session.id },
data: {
expiresAt: new Date(Date.now() - 1000),
},
});
const userSession = await auth.getUserSession(session.sessionId);
t.is(userSession, null);
});
// Tests for Multi-Accounts Session
test('should be able to sign in different user in a same session', async t => {
const { auth, u1 } = t.context;
const u2 = await auth.signUp('u2', 'u2@affine.pro', '1');
const u2 = await auth.signUp('u2@affine.pro', '1');
const session = await auth.createUserSession(u1);
await auth.createUserSession(u2, session.sessionId);
const session = await auth.createSession();
const [signedU1, signedU2] = await auth.getUserList(session.sessionId);
await auth.createUserSession(u1.id, session.id);
let userList = await auth.getUserList(session.id);
t.is(userList.length, 1);
t.is(userList[0]!.id, u1.id);
await auth.createUserSession(u2.id, session.id);
userList = await auth.getUserList(session.id);
t.is(userList.length, 2);
const [signedU1, signedU2] = userList;
t.not(signedU1, null);
t.not(signedU2, null);
@ -193,29 +188,30 @@ test('should be able to sign in different user in a same session', async t => {
test('should be able to signout multi accounts session', async t => {
const { auth, u1 } = t.context;
const u2 = await auth.signUp('u2', 'u2@affine.pro', '1');
const u2 = await auth.signUp('u2@affine.pro', '1');
const session = await auth.createUserSession(u1);
await auth.createUserSession(u2, session.sessionId);
const session = await auth.createSession();
// sign out user at seq(0)
let signedOutSession = await auth.signOut(session.sessionId);
await auth.createUserSession(u1.id, session.id);
await auth.createUserSession(u2.id, session.id);
t.not(signedOutSession, null);
await auth.signOut(session.id, u1.id);
const userSession1 = await auth.getUserSession(session.sessionId, 0);
const userSession2 = await auth.getUserSession(session.sessionId, 1);
let list = await auth.getUserList(session.id);
t.is(userSession2, null);
t.not(userSession1, null);
t.is(list.length, 1);
t.is(list[0]!.id, u2.id);
t.is(userSession1!.user.id, u2.id);
const u1Session = await auth.getUserSession(session.id, u1.id);
// sign out user at seq(0)
signedOutSession = await auth.signOut(session.sessionId);
t.is(u1Session, null);
t.is(signedOutSession, null);
await auth.signOut(session.id, u2.id);
list = await auth.getUserList(session.id);
const userSession3 = await auth.getUserSession(session.sessionId, 0);
t.is(userSession3, null);
t.is(list.length, 0);
const u2Session = await auth.getUserSession(session.id, u2.id);
t.is(u2Session, null);
});

View File

@ -10,7 +10,7 @@ const test = ava as TestFn<{
m: TestingModule;
}>;
test.beforeEach(async t => {
test.before(async t => {
const m = await createTestingModule({
providers: [TokenService],
});
@ -19,7 +19,7 @@ test.beforeEach(async t => {
t.context.m = m;
});
test.afterEach.always(async t => {
test.after.always(async t => {
await t.context.m.close();
});

View File

@ -105,7 +105,7 @@ test.afterEach.always(async t => {
let userId: string;
test.beforeEach(async t => {
const { auth } = t.context;
const user = await auth.signUp('test', 'darksky@affine.pro', '123456');
const user = await auth.signUp('test@affine.pro', '123456');
userId = user.id;
});
@ -308,7 +308,7 @@ test('should be able to fork chat session', async t => {
});
t.not(sessionId, forkedSessionId1, 'should fork a new session');
const newUser = await auth.signUp('test', 'darksky.1@affine.pro', '123456');
const newUser = await auth.signUp('darksky.1@affine.pro', '123456');
const forkedSessionId2 = await session.fork({
userId: newUser.id,
sessionId,

View File

@ -82,7 +82,7 @@ test.afterEach.always(async t => {
test('should be able to set user feature', async t => {
const { auth, feature } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@test.com', '123456');
const f1 = await feature.getUserFeatures(u1.id);
t.is(f1.length, 0, 'should be empty');
@ -96,7 +96,7 @@ test('should be able to set user feature', async t => {
test('should be able to check early access', async t => {
const { auth, feature, management } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@test.com', '123456');
const f1 = await management.canEarlyAccess(u1.email);
t.false(f1, 'should not have early access');
@ -112,7 +112,7 @@ test('should be able to check early access', async t => {
test('should be able revert user feature', async t => {
const { auth, feature, management } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@test.com', '123456');
const f1 = await management.canEarlyAccess(u1.email);
t.false(f1, 'should not have early access');
@ -138,7 +138,7 @@ test('should be able revert user feature', async t => {
test('should be same instance after reset the user feature', async t => {
const { auth, feature, management } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@test.com', '123456');
await management.addEarlyAccess(u1.id);
const f1 = (await feature.getUserFeatures(u1.id))[0];
@ -154,7 +154,7 @@ test('should be same instance after reset the user feature', async t => {
test('should be able to set workspace feature', async t => {
const { auth, feature, workspace } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@test.com', '123456');
const w1 = await workspace.createWorkspace(u1, null);
const f1 = await feature.getWorkspaceFeatures(w1.id);
@ -169,7 +169,7 @@ test('should be able to set workspace feature', async t => {
test('should be able to check workspace feature', async t => {
const { auth, feature, workspace, management } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@test.com', '123456');
const w1 = await workspace.createWorkspace(u1, null);
const f1 = await management.hasWorkspaceFeature(w1.id, FeatureType.Copilot);
@ -186,7 +186,7 @@ test('should be able to check workspace feature', async t => {
test('should be able revert workspace feature', async t => {
const { auth, feature, workspace, management } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@test.com', '123456');
const w1 = await workspace.createWorkspace(u1, null);
const f1 = await management.hasWorkspaceFeature(w1.id, FeatureType.Copilot);

View File

@ -33,7 +33,7 @@ test.afterEach.always(async t => {
test('should include callbackUrl in sending email', async t => {
const { auth } = t.context;
await auth.signUp('Alex Yang', 'alexyang@example.org', '123456');
await auth.signUp('test@affine.pro', '123456');
for (const fn of [
'sendSetPasswordEmail',
'sendChangeEmail',
@ -41,7 +41,7 @@ test('should include callbackUrl in sending email', async t => {
'sendVerifyChangeEmail',
] as const) {
const prev = await getCurrentMailMessageCount();
await auth[fn]('alexyang@example.org', 'https://test.com/callback');
await auth[fn]('test@affine.pro', 'https://test.com/callback');
const current = await getCurrentMailMessageCount();
const mail = await getLatestMailMessage();
t.regex(

View File

@ -7,6 +7,7 @@ import {
INestApplication,
UseGuards,
} from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import request, { type Response } from 'supertest';
@ -20,7 +21,7 @@ import {
Throttle,
ThrottlerStorage,
} from '../../src/fundamentals/throttler';
import { createTestingApp, internalSignIn } from '../utils';
import { createTestingApp, initTestingDB, internalSignIn } from '../utils';
const test = ava as TestFn<{
storage: ThrottlerStorage;
@ -93,7 +94,7 @@ class NonThrottledController {
}
}
test.beforeEach(async t => {
test.before(async t => {
const { app } = await createTestingApp({
imports: [
ConfigModule.forRoot({
@ -111,14 +112,17 @@ test.beforeEach(async t => {
t.context.storage = app.get(ThrottlerStorage);
t.context.app = app;
});
test.beforeEach(async t => {
await initTestingDB(t.context.app.get(PrismaClient));
const { app } = t.context;
const auth = app.get(AuthService);
const u1 = await auth.signUp('u1', 'u1@affine.pro', 'test');
const u1 = await auth.signUp('u1@affine.pro', 'test');
t.context.cookie = await internalSignIn(app, u1.id);
});
test.afterEach.always(async t => {
test.after.always(async t => {
await t.context.app.close();
});

View File

@ -15,7 +15,7 @@ import { ConfigModule } from '../../src/fundamentals/config';
import { OAuthProviderName } from '../../src/plugins/oauth/config';
import { GoogleOAuthProvider } from '../../src/plugins/oauth/providers/google';
import { OAuthService } from '../../src/plugins/oauth/service';
import { createTestingApp, getSession } from '../utils';
import { createTestingApp, getSession, initTestingDB } from '../utils';
const test = ava as TestFn<{
auth: AuthService;
@ -26,7 +26,7 @@ const test = ava as TestFn<{
app: INestApplication;
}>;
test.beforeEach(async t => {
test.before(async t => {
const { app } = await createTestingApp({
imports: [
ConfigModule.forRoot({
@ -50,11 +50,15 @@ test.beforeEach(async t => {
t.context.user = app.get(UserService);
t.context.db = app.get(PrismaClient);
t.context.app = app;
t.context.u1 = await t.context.auth.signUp('u1', 'u1@affine.pro', '1');
});
test.afterEach.always(async t => {
test.beforeEach(async t => {
Sinon.restore();
await initTestingDB(t.context.db);
t.context.u1 = await t.context.auth.signUp('u1@affine.pro', '1');
});
test.after.always(async t => {
await t.context.app.close();
});
@ -62,10 +66,13 @@ test("should be able to redirect to oauth provider's login page", async t => {
const { app } = t.context;
const res = await request(app.getHttpServer())
.get('/oauth/login?provider=Google')
.expect(HttpStatus.FOUND);
.post('/api/oauth/preflight')
.send({ provider: 'Google' })
.expect(HttpStatus.OK);
const redirect = new URL(res.header.location);
const { url } = res.body;
const redirect = new URL(url);
t.is(redirect.origin, 'https://accounts.google.com');
t.is(redirect.pathname, '/o/oauth2/v2/auth');
@ -83,7 +90,8 @@ test('should throw if provider is invalid', async t => {
const { app } = t.context;
await request(app.getHttpServer())
.get('/oauth/login?provider=Invalid')
.post('/api/oauth/preflight')
.send({ provider: 'Invalid' })
.expect(HttpStatus.BAD_REQUEST)
.expect({
status: 400,
@ -101,7 +109,6 @@ test('should be able to save oauth state', async t => {
const { oauth } = t.context;
const id = await oauth.saveOAuthState({
redirectUri: 'https://example.com',
provider: OAuthProviderName.Google,
});
@ -109,7 +116,6 @@ test('should be able to save oauth state', async t => {
t.truthy(state);
t.is(state!.provider, OAuthProviderName.Google);
t.is(state!.redirectUri, 'https://example.com');
});
test('should be able to get registered oauth providers', async t => {
@ -124,7 +130,8 @@ test('should throw if code is missing in callback uri', async t => {
const { app } = t.context;
await request(app.getHttpServer())
.get('/oauth/callback')
.post('/api/oauth/callback')
.send({})
.expect(HttpStatus.BAD_REQUEST)
.expect({
status: 400,
@ -142,7 +149,8 @@ test('should throw if state is missing in callback uri', async t => {
const { app } = t.context;
await request(app.getHttpServer())
.get('/oauth/callback?code=1')
.post('/api/oauth/callback')
.send({ code: '1' })
.expect(HttpStatus.BAD_REQUEST)
.expect({
status: 400,
@ -161,7 +169,8 @@ test('should throw if state is expired', async t => {
Sinon.stub(oauth, 'isValidState').resolves(true);
await request(app.getHttpServer())
.get('/oauth/callback?code=1&state=1')
.post('/api/oauth/callback')
.send({ code: '1', state: '1' })
.expect(HttpStatus.BAD_REQUEST)
.expect({
status: 400,
@ -178,7 +187,8 @@ test('should throw if state is invalid', async t => {
const { app } = t.context;
await request(app.getHttpServer())
.get('/oauth/callback?code=1&state=1')
.post('/api/oauth/callback')
.send({ code: '1', state: '1' })
.expect(HttpStatus.BAD_REQUEST)
.expect({
status: 400,
@ -199,7 +209,8 @@ test('should throw if provider is missing in state', async t => {
Sinon.stub(oauth, 'isValidState').resolves(true);
await request(app.getHttpServer())
.get(`/oauth/callback?code=1&state=1`)
.post('/api/oauth/callback')
.send({ code: '1', state: '1' })
.expect(HttpStatus.BAD_REQUEST)
.expect({
status: 400,
@ -221,7 +232,8 @@ test('should throw if provider is invalid in callback uri', async t => {
Sinon.stub(oauth, 'isValidState').resolves(true);
await request(app.getHttpServer())
.get(`/oauth/callback?code=1&state=1`)
.post('/api/oauth/callback')
.send({ code: '1', state: '1' })
.expect(HttpStatus.BAD_REQUEST)
.expect({
status: 400,
@ -242,7 +254,6 @@ function mockOAuthProvider(app: INestApplication, email: string) {
Sinon.stub(oauth, 'isValidState').resolves(true);
Sinon.stub(oauth, 'getOAuthState').resolves({
provider: OAuthProviderName.Google,
redirectUri: '/',
});
// @ts-expect-error mock
@ -260,8 +271,9 @@ test('should be able to sign up with oauth', async t => {
mockOAuthProvider(app, 'u2@affine.pro');
const res = await request(app.getHttpServer())
.get(`/oauth/callback?code=1&state=1`)
.expect(HttpStatus.FOUND);
.post(`/api/oauth/callback`)
.send({ code: '1', state: '1' })
.expect(HttpStatus.OK);
const session = await getSession(app, res);
@ -283,22 +295,17 @@ test('should be able to sign up with oauth', async t => {
t.is(user!.connectedAccounts[0].providerAccountId, '1');
});
test('should throw if account register in another way', async t => {
test('should not throw if account registered', async t => {
const { app, u1 } = t.context;
mockOAuthProvider(app, u1.email);
const res = await request(app.getHttpServer())
.get(`/oauth/callback?code=1&state=1`)
.expect(HttpStatus.FOUND);
.post(`/api/oauth/callback`)
.send({ code: '1', state: '1' })
.expect(HttpStatus.OK);
const link = new URL(res.headers.location);
t.is(link.pathname, '/signIn');
t.is(
link.searchParams.get('error'),
'You are trying to sign in by a different method than you signed up with.'
);
t.is(res.body.id, u1.id);
});
test('should be able to fullfil user with oauth sign in', async t => {
@ -313,8 +320,9 @@ test('should be able to fullfil user with oauth sign in', async t => {
mockOAuthProvider(app, u3.email);
const res = await request(app.getHttpServer())
.get(`/oauth/callback?code=1&state=1`)
.expect(HttpStatus.FOUND);
.post('/api/oauth/callback')
.send({ code: '1', state: '1' })
.expect(HttpStatus.OK);
const session = await getSession(app, res);
@ -329,60 +337,3 @@ test('should be able to fullfil user with oauth sign in', async t => {
t.truthy(account);
});
test('should throw if oauth account already connected', async t => {
const { app, db, u1, auth } = t.context;
await db.connectedAccount.create({
data: {
userId: u1.id,
provider: OAuthProviderName.Google,
providerAccountId: '1',
},
});
Sinon.stub(auth, 'getUserSession').resolves({
user: { id: 'u2-id' },
session: {},
} as any);
mockOAuthProvider(app, 'u2@affine.pro');
const res = await request(app.getHttpServer())
.get(`/oauth/callback?code=1&state=1`)
.set('cookie', `${AuthService.sessionCookieName}=1`)
.expect(HttpStatus.FOUND);
const link = new URL(res.headers.location);
t.is(link.pathname, '/signIn');
t.is(
link.searchParams.get('error'),
'The third-party account has already been connected to another user.'
);
});
test('should be able to connect oauth account', async t => {
const { app, u1, auth, db } = t.context;
Sinon.stub(auth, 'getUserSession').resolves({
user: { id: u1.id },
session: {},
} as any);
mockOAuthProvider(app, u1.email);
await request(app.getHttpServer())
.get(`/oauth/callback?code=1&state=1`)
.set('cookie', `${AuthService.sessionCookieName}=1`)
.expect(HttpStatus.FOUND);
const account = await db.connectedAccount.findFirst({
where: {
userId: u1.id,
},
});
t.truthy(account);
t.is(account!.userId, u1.id);
});

View File

@ -69,7 +69,7 @@ test.beforeEach(async t => {
t.context.db = app.get(PrismaClient);
t.context.app = app;
t.context.u1 = await app.get(AuthService).signUp('u1', 'u1@affine.pro', '1');
t.context.u1 = await app.get(AuthService).signUp('u1@affine.pro', '1');
await t.context.db.userStripeCustomer.create({
data: {
userId: t.context.u1.id,

View File

@ -44,7 +44,7 @@ test.afterEach.always(async t => {
test('should be able to set quota', async t => {
const { auth, quota } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@affine.pro', '123456');
const q1 = await quota.getUserQuota(u1.id);
t.truthy(q1, 'should have quota');
@ -62,7 +62,7 @@ test('should be able to set quota', async t => {
test('should be able to check storage quota', async t => {
const { auth, quota, quotaManager } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@affine.pro', '123456');
const freePlan = FreePlan.configs;
const proPlan = ProPlan.configs;
@ -78,7 +78,7 @@ test('should be able to check storage quota', async t => {
test('should be able revert quota', async t => {
const { auth, quota, quotaManager } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@affine.pro', '123456');
const freePlan = FreePlan.configs;
const proPlan = ProPlan.configs;
@ -113,7 +113,7 @@ test('should be able revert quota', async t => {
test('should be able to check quota', async t => {
const { auth, quotaManager } = t.context;
const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456');
const u1 = await auth.signUp('test@affine.pro', '123456');
const freePlan = FreePlan.configs;
const q1 = await quotaManager.getUserQuota(u1.id);

View File

@ -17,7 +17,7 @@ test.beforeEach(async t => {
imports: [AppModule],
});
t.context.u1 = await app.get(AuthService).signUp('u1', 'u1@affine.pro', '1');
t.context.u1 = await app.get(AuthService).signUp('u1@affine.pro', '1');
t.context.app = app;
});

View File

@ -13,7 +13,7 @@ import { gql } from './common';
export async function internalSignIn(app: INestApplication, userId: string) {
const auth = app.get(AuthService);
const session = await auth.createUserSession({ id: userId });
const session = await auth.createUserSession(userId);
return `${AuthService.sessionCookieName}=${session.sessionId}`;
}
@ -56,7 +56,7 @@ export async function signUp(
password,
emailVerifiedAt: autoVerifyEmail ? new Date() : null,
});
const { sessionId } = await app.get(AuthService).createUserSession(user);
const { sessionId } = await app.get(AuthService).createUserSession(user.id);
return {
...sessionUser(user),

View File

@ -33,7 +33,7 @@ test.before(async t => {
});
const auth = app.get(AuthService);
t.context.u1 = await auth.signUp('u1', 'u1@affine.pro', '1');
t.context.u1 = await auth.signUp('u1@affine.pro', '1');
const db = app.get(PrismaClient);
t.context.db = db;

View File

@ -1160,7 +1160,7 @@ export interface UserType {
/** @deprecated use `UserType.subscriptions` */
subscription: Maybe<UserSubscription>;
subscriptions: Array<UserSubscription>;
/** @deprecated use [/api/auth/authorize] */
/** @deprecated use [/api/auth/sign-in?native=true] instead */
token: TokenType;
}