refactor(server): auth (#5895)

Remove `next-auth` and implement our own Authorization/Authentication system from scratch.

## Server

- [x] tokens
  - [x] function
  - [x] encryption

- [x] AuthController
  - [x] /api/auth/sign-in
  - [x] /api/auth/sign-out
  - [x] /api/auth/session
  - [x] /api/auth/session (WE SUPPORT MULTI-ACCOUNT!)

- [x] OAuthPlugin
  - [x] OAuthController
  - [x] /oauth/login
  - [x] /oauth/callback
  - [x] Providers
    - [x] Google
    - [x] GitHub

## Client

- [x] useSession
- [x] cloudSignIn
- [x] cloudSignOut

## NOTE:

Tests will be adding in the future
This commit is contained in:
liuyi 2024-03-12 10:00:09 +00:00
parent af49e8cc41
commit fb3a0e7b8f
No known key found for this signature in database
GPG Key ID: 56709255DC7EC728
148 changed files with 3407 additions and 2851 deletions

View File

@ -31,17 +31,6 @@ const createPattern = packageName => [
message: 'Use `useNavigateHelper` instead',
importNames: ['useNavigate'],
},
{
group: ['next-auth/react'],
message: "Import hooks from 'use-current-user.tsx'",
// useSession is type unsafe
importNames: ['useSession'],
},
{
group: ['next-auth/react'],
message: "Import hooks from 'cloud-utils.ts'",
importNames: ['signIn', 'signOut'],
},
{
group: ['yjs'],
message: 'Do not use this API because it has a bug',
@ -179,17 +168,6 @@ const config = {
message: 'Use `useNavigateHelper` instead',
importNames: ['useNavigate'],
},
{
group: ['next-auth/react'],
message: "Import hooks from 'use-current-user.tsx'",
// useSession is type unsafe
importNames: ['useSession'],
},
{
group: ['next-auth/react'],
message: "Import hooks from 'cloud-utils.ts'",
importNames: ['signIn', 'signOut'],
},
{
group: ['yjs'],
message: 'Do not use this API because it has a bug',

View File

@ -336,17 +336,11 @@ jobs:
env:
PGPASSWORD: affine
- name: Generate prisma client
- name: Run init-db script
run: |
yarn workspace @affine/server exec prisma generate
yarn workspace @affine/server exec prisma db push
env:
DATABASE_URL: postgresql://affine:affine@localhost:5432/affine
- name: Run init-db script
run: |
yarn workspace @affine/server data-migration run
yarn workspace @affine/server exec node --loader ts-node/esm/transpile-only ./scripts/init-db.ts
env:
DATABASE_URL: postgresql://affine:affine@localhost:5432/affine
@ -435,17 +429,11 @@ jobs:
env:
PGPASSWORD: affine
- name: Generate prisma client
- name: Run init-db script
run: |
yarn workspace @affine/server exec prisma generate
yarn workspace @affine/server exec prisma db push
env:
DATABASE_URL: postgresql://affine:affine@localhost:5432/affine
- name: Run init-db script
run: |
yarn workspace @affine/server data-migration run
yarn workspace @affine/server exec node --loader ts-node/esm/transpile-only ./scripts/init-db.ts
- name: ${{ matrix.tests.name }}
run: |

View File

@ -167,7 +167,6 @@
"unbox-primitive": "npm:@nolyfill/unbox-primitive@latest",
"which-boxed-primitive": "npm:@nolyfill/which-boxed-primitive@latest",
"which-typed-array": "npm:@nolyfill/which-typed-array@latest",
"next-auth@^4.24.5": "patch:next-auth@npm%3A4.24.5#~/.yarn/patches/next-auth-npm-4.24.5-8428e11927.patch",
"@reforged/maker-appimage/@electron-forge/maker-base": "7.3.0",
"macos-alias": "npm:@napi-rs/macos-alias@latest",
"fs-xattr": "npm:@napi-rs/xattr@latest",

View File

@ -0,0 +1,70 @@
-- DropForeignKey
ALTER TABLE "accounts" DROP CONSTRAINT "accounts_user_id_fkey";
-- DropForeignKey
ALTER TABLE "sessions" DROP CONSTRAINT "sessions_user_id_fkey";
-- CreateTable
CREATE TABLE "user_connected_accounts" (
"id" VARCHAR(36) NOT NULL,
"user_id" VARCHAR(36) NOT NULL,
"provider" VARCHAR NOT NULL,
"provider_account_id" VARCHAR NOT NULL,
"scope" TEXT,
"access_token" TEXT,
"refresh_token" TEXT,
"expires_at" TIMESTAMPTZ(6),
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(6) NOT NULL,
CONSTRAINT "user_connected_accounts_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "multiple_users_sessions" (
"id" VARCHAR(36) NOT NULL,
"expires_at" TIMESTAMPTZ(6),
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "multiple_users_sessions_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "user_sessions" (
"id" VARCHAR(36) NOT NULL,
"session_id" VARCHAR(36) NOT NULL,
"user_id" VARCHAR(36) NOT NULL,
"expires_at" TIMESTAMPTZ(6),
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "user_sessions_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "verification_tokens" (
"token" VARCHAR(36) NOT NULL,
"type" SMALLINT NOT NULL,
"credential" TEXT,
"expiresAt" TIMESTAMPTZ(6) NOT NULL
);
-- CreateIndex
CREATE INDEX "user_connected_accounts_user_id_idx" ON "user_connected_accounts"("user_id");
-- CreateIndex
CREATE INDEX "user_connected_accounts_provider_account_id_idx" ON "user_connected_accounts"("provider_account_id");
-- CreateIndex
CREATE UNIQUE INDEX "user_sessions_session_id_user_id_key" ON "user_sessions"("session_id", "user_id");
-- CreateIndex
CREATE UNIQUE INDEX "verification_tokens_type_token_key" ON "verification_tokens"("type", "token");
-- AddForeignKey
ALTER TABLE "user_connected_accounts" ADD CONSTRAINT "user_connected_accounts_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "user_sessions" ADD CONSTRAINT "user_sessions_session_id_fkey" FOREIGN KEY ("session_id") REFERENCES "multiple_users_sessions"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "user_sessions" ADD CONSTRAINT "user_sessions_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@ -74,7 +74,6 @@
"nanoid": "^5.0.6",
"nest-commander": "^3.12.5",
"nestjs-throttler-storage-redis": "^0.4.1",
"next-auth": "^4.24.5",
"nodemailer": "^6.9.10",
"on-headers": "^1.0.2",
"parse-duration": "^1.1.0",
@ -143,7 +142,8 @@
"MAILER_USER": "noreply@toeverything.info",
"MAILER_PASSWORD": "affine",
"MAILER_SENDER": "noreply@toeverything.info",
"FEATURES_EARLY_ACCESS_PREVIEW": "false"
"FEATURES_EARLY_ACCESS_PREVIEW": "false",
"DEPLOYMENT_TYPE": "affine"
}
},
"nodemonConfig": {

View File

@ -13,25 +13,77 @@ model User {
id String @id @default(uuid()) @db.VarChar
name String
email String @unique
emailVerified DateTime? @map("email_verified")
// image field is for the next-auth
emailVerifiedAt DateTime? @map("email_verified")
avatarUrl String? @map("avatar_url") @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
/// Not available if user signed up through OAuth providers
password String? @db.VarChar
accounts Account[]
sessions Session[]
features UserFeatures[]
customer UserStripeCustomer?
subscription UserSubscription?
invoices UserInvoice[]
workspacePermissions WorkspaceUserPermission[]
pagePermissions WorkspacePageUserPermission[]
connectedAccounts ConnectedAccount[]
sessions UserSession[]
@@map("users")
}
model ConnectedAccount {
id String @id @default(uuid()) @db.VarChar(36)
userId String @map("user_id") @db.VarChar(36)
provider String @db.VarChar
providerAccountId String @map("provider_account_id") @db.VarChar
scope String? @db.Text
accessToken String? @map("access_token") @db.Text
refreshToken String? @map("refresh_token") @db.Text
expiresAt DateTime? @map("expires_at") @db.Timestamptz(6)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@index([userId])
@@index([providerAccountId])
@@map("user_connected_accounts")
}
model Session {
id String @id @default(uuid()) @db.VarChar(36)
expiresAt DateTime? @map("expires_at") @db.Timestamptz(6)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
userSessions UserSession[]
@@map("multiple_users_sessions")
}
model UserSession {
id String @id @default(uuid()) @db.VarChar(36)
sessionId String @map("session_id") @db.VarChar(36)
userId String @map("user_id") @db.VarChar(36)
expiresAt DateTime? @map("expires_at") @db.Timestamptz(6)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
session Session @relation(fields: [sessionId], references: [id], onDelete: Cascade)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@unique([sessionId, userId])
@@map("user_sessions")
}
model VerificationToken {
token String @db.VarChar(36)
type Int @db.SmallInt
credential String? @db.Text
expiresAt DateTime @db.Timestamptz(6)
@@unique([type, token])
@@map("verification_tokens")
}
model Workspace {
id String @id @default(uuid()) @db.VarChar
public Boolean
@ -186,7 +238,7 @@ model Features {
@@map("features")
}
model Account {
model DeprecatedNextAuthAccount {
id String @id @default(cuid())
userId String @map("user_id")
type String
@ -200,23 +252,20 @@ model Account {
id_token String? @db.Text
session_state String?
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@unique([provider, providerAccountId])
@@map("accounts")
}
model Session {
model DeprecatedNextAuthSession {
id String @id @default(cuid())
sessionToken String @unique @map("session_token")
userId String @map("user_id")
expires DateTime
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@map("sessions")
}
model VerificationToken {
model DeprecatedNextAuthVerificationToken {
identifier String
token String @unique
expires DateTime

View File

@ -1,37 +0,0 @@
import userA from '@affine-test/fixtures/userA.json' assert { type: 'json' };
import { hash } from '@node-rs/argon2';
import { PrismaClient } from '@prisma/client';
const prisma = new PrismaClient();
async function main() {
await prisma.user.create({
data: {
...userA,
password: await hash(userA.password),
features: {
create: {
reason: 'created by api sign up',
activated: true,
feature: {
connect: {
feature_version: {
feature: 'free_plan_v1',
version: 1,
},
},
},
},
},
},
});
}
main()
.then(async () => {
await prisma.$disconnect();
})
.catch(async e => {
console.error(e);
await prisma.$disconnect();
process.exit(1);
});

View File

@ -1,20 +1,20 @@
import { join } from 'node:path';
import { Logger, Module } from '@nestjs/common';
import { APP_INTERCEPTOR } from '@nestjs/core';
import { APP_GUARD, APP_INTERCEPTOR } from '@nestjs/core';
import { ScheduleModule } from '@nestjs/schedule';
import { ServeStaticModule } from '@nestjs/serve-static';
import { get } from 'lodash-es';
import { AppController } from './app.controller';
import { AuthModule } from './core/auth';
import { AuthGuard, AuthModule } from './core/auth';
import { ADD_ENABLED_FEATURES, ServerConfigModule } from './core/config';
import { DocModule } from './core/doc';
import { FeatureModule } from './core/features';
import { QuotaModule } from './core/quota';
import { StorageModule } from './core/storage';
import { SyncModule } from './core/sync';
import { UsersModule } from './core/users';
import { UserModule } from './core/user';
import { WorkspaceModule } from './core/workspaces';
import { getOptionalModuleMetadata } from './fundamentals';
import { CacheInterceptor, CacheModule } from './fundamentals/cache';
@ -25,14 +25,14 @@ import {
} from './fundamentals/config';
import { EventModule } from './fundamentals/event';
import { GqlModule } from './fundamentals/graphql';
import { HelpersModule } from './fundamentals/helpers';
import { MailModule } from './fundamentals/mailer';
import { MetricsModule } from './fundamentals/metrics';
import { PrismaModule } from './fundamentals/prisma';
import { SessionModule } from './fundamentals/session';
import { StorageProviderModule } from './fundamentals/storage';
import { RateLimiterModule } from './fundamentals/throttler';
import { WebSocketModule } from './fundamentals/websocket';
import { pluginsMap } from './plugins';
import { REGISTERED_PLUGINS } from './plugins';
export const FunctionalityModules = [
ConfigModule.forRoot(),
@ -42,9 +42,9 @@ export const FunctionalityModules = [
PrismaModule,
MetricsModule,
RateLimiterModule,
SessionModule,
MailModule,
StorageProviderModule,
HelpersModule,
];
export class AppModuleBuilder {
@ -109,6 +109,10 @@ export class AppModuleBuilder {
provide: APP_INTERCEPTOR,
useClass: CacheInterceptor,
},
{
provide: APP_GUARD,
useClass: AuthGuard,
},
],
imports: this.modules,
controllers: this.config.isSelfhosted ? [] : [AppController],
@ -141,7 +145,7 @@ function buildAppModule() {
WebSocketModule,
GqlModule,
StorageModule,
UsersModule,
UserModule,
WorkspaceModule,
FeatureModule,
QuotaModule
@ -157,7 +161,7 @@ function buildAppModule() {
// plugin modules
AFFiNE.plugins.enabled.forEach(name => {
const plugin = pluginsMap.get(name as AvailablePlugins);
const plugin = REGISTERED_PLUGINS.get(name as AvailablePlugins);
if (!plugin) {
throw new Error(`Unknown plugin ${name}`);
}

View File

@ -7,12 +7,10 @@ AFFiNE.ENV_MAP = {
DATABASE_URL: 'db.url',
ENABLE_CAPTCHA: ['auth.captcha.enable', 'boolean'],
CAPTCHA_TURNSTILE_SECRET: ['auth.captcha.turnstile.secret', 'string'],
OAUTH_GOOGLE_ENABLED: ['auth.oauthProviders.google.enabled', 'boolean'],
OAUTH_GOOGLE_CLIENT_ID: 'auth.oauthProviders.google.clientId',
OAUTH_GOOGLE_CLIENT_SECRET: 'auth.oauthProviders.google.clientSecret',
OAUTH_GITHUB_ENABLED: ['auth.oauthProviders.github.enabled', 'boolean'],
OAUTH_GITHUB_CLIENT_ID: 'auth.oauthProviders.github.clientId',
OAUTH_GITHUB_CLIENT_SECRET: 'auth.oauthProviders.github.clientSecret',
OAUTH_GOOGLE_CLIENT_ID: 'plugins.oauth.providers.google.clientId',
OAUTH_GOOGLE_CLIENT_SECRET: 'plugins.oauth.providers.google.clientSecret',
OAUTH_GITHUB_CLIENT_ID: 'plugins.oauth.providers.github.clientId',
OAUTH_GITHUB_CLIENT_SECRET: 'plugins.oauth.providers.github.clientSecret',
MAILER_HOST: 'mailer.host',
MAILER_PORT: ['mailer.port', 'int'],
MAILER_USER: 'mailer.auth.user',

View File

@ -40,6 +40,7 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) {
AFFiNE.plugins.use('redis');
AFFiNE.plugins.use('payment');
AFFiNE.plugins.use('oauth');
if (AFFiNE.deploy) {
AFFiNE.mailer = {

View File

@ -115,3 +115,27 @@ AFFiNE.plugins.use('payment', {
// /* Update the provider of storages */
// AFFiNE.storage.storages.blob.provider = 'r2';
// AFFiNE.storage.storages.avatar.provider = 'r2';
//
// /* OAuth Plugin */
// AFFiNE.plugins.use('oauth', {
// providers: {
// github: {
// clientId: '',
// clientSecret: '',
// // See https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps
// args: {
// scope: 'user',
// },
// },
// google: {
// clientId: '',
// clientSecret: '',
// args: {
// // See https://developers.google.com/identity/protocols/oauth2
// scope: 'openid email profile',
// promot: 'select_account',
// access_type: 'offline',
// },
// },
// },
// });

View File

@ -0,0 +1,212 @@
import { randomUUID } from 'node:crypto';
import {
BadRequestException,
Body,
Controller,
Get,
Header,
Post,
Query,
Req,
Res,
} from '@nestjs/common';
import type { Request, Response } from 'express';
import {
Config,
PaymentRequiredException,
URLHelper,
} from '../../fundamentals';
import { UserService } from '../user';
import { validators } from '../utils/validators';
import { CurrentUser } from './current-user';
import { Public } from './guard';
import { AuthService, parseAuthUserSeqNum } from './service';
import { TokenService, TokenType } from './token';
class SignInCredential {
email!: string;
password?: string;
}
@Controller('/api/auth')
export class AuthController {
constructor(
private readonly config: Config,
private readonly url: URLHelper,
private readonly auth: AuthService,
private readonly user: UserService,
private readonly token: TokenService
) {}
@Public()
@Post('/sign-in')
@Header('content-type', 'application/json')
async signIn(
@Req() req: Request,
@Res() res: Response,
@Body() credential: SignInCredential,
@Query('redirect_uri') redirectUri = this.url.home
) {
validators.assertValidEmail(credential.email);
const canSignIn = await this.auth.canSignIn(credential.email);
if (!canSignIn) {
throw new PaymentRequiredException(
`You don't have early access permission\nVisit https://community.affine.pro/c/insider-general/ for more information`
);
}
if (credential.password) {
validators.assertValidPassword(credential.password);
const user = await this.auth.signIn(
credential.email,
credential.password
);
await this.auth.setCookie(req, res, user);
res.send(user);
} else {
// send email magic link
const user = await this.user.findUserByEmail(credential.email);
const result = await this.sendSignInEmail(
{ email: credential.email, signUp: !user },
redirectUri
);
if (result.rejected.length) {
throw new Error('Failed to send sign-in email.');
}
res.send({
email: credential.email,
});
}
}
async sendSignInEmail(
{ email, signUp }: { email: string; signUp: boolean },
redirectUri: string
) {
const token = await this.token.createToken(TokenType.SignIn, email);
const magicLink = this.url.link('/api/auth/magic-link', {
token,
email,
redirect_uri: redirectUri,
});
const result = await this.auth.sendSignInEmail(email, magicLink, signUp);
return result;
}
@Get('/sign-out')
async signOut(
@Req() req: Request,
@Res() res: Response,
@Query('redirect_uri') redirectUri?: string
) {
const session = await this.auth.signOut(
req.cookies[AuthService.sessionCookieName],
parseAuthUserSeqNum(req.headers[AuthService.authUserSeqHeaderName])
);
if (session) {
res.cookie(AuthService.sessionCookieName, session.id, {
expires: session.expiresAt ?? void 0, // expiredAt is `string | null`
...this.auth.cookieOptions,
});
} else {
res.clearCookie(AuthService.sessionCookieName);
}
if (redirectUri) {
return this.url.safeRedirect(res, redirectUri);
} else {
return res.send(null);
}
}
@Public()
@Get('/magic-link')
async magicLinkSignIn(
@Req() req: Request,
@Res() res: Response,
@Query('token') token?: string,
@Query('email') email?: string,
@Query('redirect_uri') redirectUri = this.url.home
) {
if (!token || !email) {
throw new BadRequestException('Invalid Sign-in mail Token');
}
email = decodeURIComponent(email);
validators.assertValidEmail(email);
const valid = await this.token.verifyToken(TokenType.SignIn, token, {
credential: email,
});
if (!valid) {
throw new BadRequestException('Invalid Sign-in mail Token');
}
const user = await this.user.findOrCreateUser(email, {
emailVerifiedAt: new Date(),
});
await this.auth.setCookie(req, res, user);
return this.url.safeRedirect(res, redirectUri);
}
@Get('/authorize')
async authorize(
@CurrentUser() user: CurrentUser,
@Query('redirect_uri') redirect_uri?: string
) {
const session = await this.auth.createUserSession(
user,
undefined,
this.config.auth.accessToken.ttl
);
this.url.link(redirect_uri ?? '/open-app/redirect', {
token: session.sessionId,
});
}
@Public()
@Get('/session')
async currentSessionUser(@CurrentUser() user?: CurrentUser) {
return {
user,
};
}
@Public()
@Get('/sessions')
async currentSessionUsers(@Req() req: Request) {
const token = req.cookies[AuthService.sessionCookieName];
if (!token) {
return {
users: [],
};
}
return {
users: await this.auth.getUserList(token),
};
}
@Public()
@Get('/challenge')
async challenge() {
// TODO: impl in following PR
return {
challenge: randomUUID(),
resource: randomUUID(),
};
}
}

View File

@ -0,0 +1,55 @@
import type { ExecutionContext } from '@nestjs/common';
import { createParamDecorator } from '@nestjs/common';
import { User } from '@prisma/client';
import { getRequestResponseFromContext } from '../../fundamentals';
function getUserFromContext(context: ExecutionContext) {
return getRequestResponseFromContext(context).req.user;
}
/**
* Used to fetch current user from the request context.
*
* > The user may be undefined if authorization token or session cookie is not provided.
*
* @example
*
* ```typescript
* // Graphql Query
* \@Query(() => UserType)
* user(@CurrentUser() user: CurrentUser) {
* return user;
* }
* ```
*
* ```typescript
* // HTTP Controller
* \@Get('/user')
* user(@CurrentUser() user: CurrentUser) {
* return user;
* }
* ```
*
* ```typescript
* // for public apis
* \@Public()
* \@Get('/session')
* session(@currentUser() user?: CurrentUser) {
* return user
* }
* ```
*/
// interface and variable don't conflict
// eslint-disable-next-line no-redeclare
export const CurrentUser = createParamDecorator(
(_: unknown, context: ExecutionContext) => {
return getUserFromContext(context);
}
);
export interface CurrentUser
extends Omit<User, 'password' | 'createdAt' | 'emailVerifiedAt'> {
hasPassword: boolean | null;
emailVerified: boolean;
}

View File

@ -1,67 +1,74 @@
import type { CanActivate, ExecutionContext } from '@nestjs/common';
import type {
CanActivate,
ExecutionContext,
OnModuleInit,
} from '@nestjs/common';
import {
createParamDecorator,
Inject,
Injectable,
SetMetadata,
UnauthorizedException,
UseGuards,
} from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import { PrismaClient } from '@prisma/client';
import type { NextAuthOptions } from 'next-auth';
import { AuthHandler } from 'next-auth/core';
import { ModuleRef, Reflector } from '@nestjs/core';
import { getRequestResponseFromContext } from '../../fundamentals';
import { NextAuthOptionsProvide } from './next-auth-options';
import { AuthService } from './service';
import { Config, getRequestResponseFromContext } from '../../fundamentals';
import { AuthService, parseAuthUserSeqNum } from './service';
export function getUserFromContext(context: ExecutionContext) {
return getRequestResponseFromContext(context).req.user;
function extractTokenFromHeader(authorization: string) {
if (!/^Bearer\s/i.test(authorization)) {
return;
}
return authorization.substring(7);
}
/**
* Used to fetch current user from the request context.
*
* > The user may be undefined if authorization token is not provided.
*
* @example
*
* ```typescript
* // Graphql Query
* \@Query(() => UserType)
* user(@CurrentUser() user?: User) {
* return user;
* }
* ```
*
* ```typescript
* // HTTP Controller
* \@Get('/user)
* user(@CurrentUser() user?: User) {
* return user;
* }
* ```
*/
export const CurrentUser = createParamDecorator(
(_: unknown, context: ExecutionContext) => {
return getUserFromContext(context);
}
);
@Injectable()
class AuthGuard implements CanActivate {
export class AuthGuard implements CanActivate, OnModuleInit {
private auth!: AuthService;
constructor(
@Inject(NextAuthOptionsProvide)
private readonly nextAuthOptions: NextAuthOptions,
private readonly auth: AuthService,
private readonly prisma: PrismaClient,
private readonly config: Config,
private readonly ref: ModuleRef,
private readonly reflector: Reflector
) {}
onModuleInit() {
this.auth = this.ref.get(AuthService, { strict: false });
}
async canActivate(context: ExecutionContext) {
const { req, res } = getRequestResponseFromContext(context);
const token = req.headers.authorization;
const { req } = getRequestResponseFromContext(context);
// check cookie
let sessionToken: string | undefined =
req.cookies[AuthService.sessionCookieName];
// backward compatibility for client older then 0.12
// TODO: remove
if (!sessionToken) {
sessionToken =
req.cookies[
this.config.https
? '__Secure-next-auth.session-token'
: 'next-auth.session-token'
];
}
if (!sessionToken && req.headers.authorization) {
sessionToken = extractTokenFromHeader(req.headers.authorization);
}
if (sessionToken) {
const userSeq = parseAuthUserSeqNum(
req.headers[AuthService.authUserSeqHeaderName]
);
const user = await this.auth.getUser(sessionToken, userSeq);
if (user) {
req.user = user;
}
}
// api is public
const isPublic = this.reflector.get<boolean>(
@ -69,63 +76,15 @@ class AuthGuard implements CanActivate {
context.getHandler()
);
// FIXME(@forehalo): @Publicable() is duplicated with @CurrentUser() user?: User
// ^ optional
// we can prefetch user session in each request even before this `Guard`
// api can be public, but if user is logged in, we can get user info
const isPublicable = this.reflector.get<boolean>(
'isPublicable',
context.getHandler()
);
if (isPublic) {
return true;
} else if (!token) {
if (!req.cookies) {
return isPublicable;
}
const session = await AuthHandler({
req: {
cookies: req.cookies,
action: 'session',
method: 'GET',
headers: req.headers,
},
options: this.nextAuthOptions,
});
const { body = {}, cookies, status = 200 } = session;
if (!body && !isPublicable) {
if (!req.user) {
throw new UnauthorizedException('You are not signed in.');
}
// @ts-expect-error body is user here
req.user = body.user;
if (cookies && res) {
for (const cookie of cookies) {
res.cookie(cookie.name, cookie.value, cookie.options);
}
}
return Boolean(
status === 200 &&
typeof body !== 'string' &&
// ignore body if api is publicable
(Object.keys(body).length || isPublicable)
);
} else {
const [type, jwt] = token.split(' ') ?? [];
if (type === 'Bearer') {
const claims = await this.auth.verify(jwt);
req.user = await this.prisma.user.findUnique({
where: { id: claims.id },
});
return !!req.user;
}
}
return false;
return true;
}
}
@ -140,7 +99,7 @@ class AuthGuard implements CanActivate {
* ```typescript
* \@Auth()
* \@Query(() => UserType)
* user(@CurrentUser() user: User) {
* user(@CurrentUser() user: CurrentUser) {
* return user;
* }
* ```
@ -151,5 +110,3 @@ export const Auth = () => {
// api is public accessible
export const Public = () => SetMetadata('isPublic', true);
// api is public accessible, but if user is logged in, we can get user info
export const Publicable = () => SetMetadata('isPublicable', true);

View File

@ -1,18 +1,21 @@
import { Global, Module } from '@nestjs/common';
import { Module } from '@nestjs/common';
import { NextAuthController } from './next-auth.controller';
import { NextAuthOptionsProvider } from './next-auth-options';
import { FeatureModule } from '../features';
import { UserModule } from '../user';
import { AuthController } from './controller';
import { AuthResolver } from './resolver';
import { AuthService } from './service';
import { TokenService } from './token';
@Global()
@Module({
providers: [AuthService, AuthResolver, NextAuthOptionsProvider],
exports: [AuthService, NextAuthOptionsProvider],
controllers: [NextAuthController],
imports: [FeatureModule, UserModule],
providers: [AuthService, AuthResolver, TokenService],
exports: [AuthService],
controllers: [AuthController],
})
export class AuthModule {}
export * from './guard';
export { TokenType } from './resolver';
export { ClientTokenType } from './resolver';
export { AuthService };
export * from './current-user';

View File

@ -1,286 +0,0 @@
import { PrismaAdapter } from '@auth/prisma-adapter';
import { FactoryProvider, Logger } from '@nestjs/common';
import { verify } from '@node-rs/argon2';
import { PrismaClient } from '@prisma/client';
import { assign, omit } from 'lodash-es';
import { NextAuthOptions } from 'next-auth';
import Credentials from 'next-auth/providers/credentials';
import Email from 'next-auth/providers/email';
import Github from 'next-auth/providers/github';
import Google from 'next-auth/providers/google';
import { Config, MailService, SessionService } from '../../fundamentals';
import { FeatureType } from '../features';
import { Quota_FreePlanV1_1 } from '../quota';
import {
decode,
encode,
sendVerificationRequest,
SendVerificationRequestParams,
} from './utils';
export const NextAuthOptionsProvide = Symbol('NextAuthOptions');
const TrustedProviders = ['google'];
export const NextAuthOptionsProvider: FactoryProvider<NextAuthOptions> = {
provide: NextAuthOptionsProvide,
useFactory(
config: Config,
prisma: PrismaClient,
mailer: MailService,
session: SessionService
) {
const logger = new Logger('NextAuth');
const prismaAdapter = PrismaAdapter(prisma);
// createUser exists in the adapter
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const createUser = prismaAdapter.createUser!.bind(prismaAdapter);
prismaAdapter.createUser = async data => {
const userData = {
name: data.name,
email: data.email,
avatarUrl: '',
emailVerified: data.emailVerified,
features: {
create: {
reason: 'created by email sign up',
activated: true,
feature: {
connect: {
feature_version: Quota_FreePlanV1_1,
},
},
},
},
};
if (data.email && !data.name) {
userData.name = data.email.split('@')[0];
}
if (data.image) {
userData.avatarUrl = data.image;
}
// @ts-expect-error third part library type mismatch
return createUser(userData);
};
// linkAccount exists in the adapter
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const linkAccount = prismaAdapter.linkAccount!.bind(prismaAdapter);
prismaAdapter.linkAccount = async account => {
// google account must be a verified email
if (TrustedProviders.includes(account.provider)) {
await prisma.user.update({
where: {
id: account.userId,
},
data: {
emailVerified: new Date(),
},
});
}
return linkAccount(account) as Promise<void>;
};
// getUser exists in the adapter
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const getUser = prismaAdapter.getUser!.bind(prismaAdapter)!;
prismaAdapter.getUser = async id => {
const result = await getUser(id);
if (result) {
// @ts-expect-error Third part library type mismatch
result.image = result.avatarUrl;
// @ts-expect-error Third part library type mismatch
result.hasPassword = Boolean(result.password);
}
return result;
};
prismaAdapter.createVerificationToken = async data => {
await session.set(
`${data.identifier}:${data.token}`,
Date.now() + session.sessionTtl
);
return data;
};
prismaAdapter.useVerificationToken = async ({ identifier, token }) => {
const expires = await session.get(`${identifier}:${token}`);
if (expires) {
return { identifier, token, expires: new Date(expires) };
} else {
return null;
}
};
const nextAuthOptions: NextAuthOptions = {
providers: [],
// @ts-expect-error Third part library type mismatch
adapter: prismaAdapter,
debug: !config.node.prod,
logger: {
debug(code, metadata) {
logger.debug(`${code}: ${JSON.stringify(metadata)}`);
},
error(code, metadata) {
if (metadata instanceof Error) {
// @ts-expect-error assign code to error
metadata.code = code;
logger.error(metadata);
} else if (metadata.error instanceof Error) {
assign(metadata.error, omit(metadata, 'error'), { code });
logger.error(metadata.error);
}
},
warn(code) {
logger.warn(code);
},
},
};
nextAuthOptions.providers.push(
// @ts-expect-error esm interop issue
Credentials.default({
name: 'Password',
credentials: {
email: {
label: 'Email',
type: 'text',
placeholder: 'torvalds@osdl.org',
},
password: { label: 'Password', type: 'password' },
},
async authorize(
credentials:
| Record<'email' | 'password' | 'hashedPassword', string>
| undefined
) {
if (!credentials) {
return null;
}
const { password, hashedPassword } = credentials;
if (!password || !hashedPassword) {
return null;
}
if (!(await verify(hashedPassword, password))) {
return null;
}
return credentials;
},
})
);
if (config.mailer && mailer) {
nextAuthOptions.providers.push(
// @ts-expect-error esm interop issue
Email.default({
sendVerificationRequest: (params: SendVerificationRequestParams) =>
sendVerificationRequest(config, logger, mailer, session, params),
})
);
}
if (config.auth.oauthProviders.github) {
nextAuthOptions.providers.push(
// @ts-expect-error esm interop issue
Github.default({
clientId: config.auth.oauthProviders.github.clientId,
clientSecret: config.auth.oauthProviders.github.clientSecret,
allowDangerousEmailAccountLinking: true,
})
);
}
if (config.auth.oauthProviders.google?.enabled) {
nextAuthOptions.providers.push(
// @ts-expect-error esm interop issue
Google.default({
clientId: config.auth.oauthProviders.google.clientId,
clientSecret: config.auth.oauthProviders.google.clientSecret,
checks: 'nonce',
allowDangerousEmailAccountLinking: true,
authorization: {
params: { scope: 'openid email profile', prompt: 'select_account' },
},
})
);
}
if (nextAuthOptions.providers.length > 1) {
// not only credentials provider
nextAuthOptions.session = { strategy: 'database' };
}
nextAuthOptions.jwt = {
encode: async ({ token, maxAge }) =>
encode(config, prisma, token, maxAge),
decode: async ({ token }) => decode(config, token),
};
nextAuthOptions.secret ??= config.auth.nextAuthSecret;
nextAuthOptions.callbacks = {
session: async ({ session, user, token }) => {
if (session.user) {
if (user) {
// @ts-expect-error Third part library type mismatch
session.user.id = user.id;
// @ts-expect-error Third part library type mismatch
session.user.image = user.image ?? user.avatarUrl;
// @ts-expect-error Third part library type mismatch
session.user.emailVerified = user.emailVerified;
// @ts-expect-error Third part library type mismatch
session.user.hasPassword = Boolean(user.password);
} else {
// technically the sub should be the same as id
// @ts-expect-error Third part library type mismatch
session.user.id = token.sub;
// @ts-expect-error Third part library type mismatch
session.user.emailVerified = token.emailVerified;
// @ts-expect-error Third part library type mismatch
session.user.hasPassword = token.hasPassword;
}
if (token && token.picture) {
session.user.image = token.picture;
}
}
return session;
},
signIn: async ({ profile, user }) => {
if (!config.featureFlags.earlyAccessPreview) {
return true;
}
const email = profile?.email ?? user.email;
if (email) {
// FIXME: cannot inject FeatureManagementService here
// it will cause prisma.account to be undefined
// then prismaAdapter.getUserByAccount will throw error
if (email.endsWith('@toeverything.info')) return true;
return prisma.userFeatures
.count({
where: {
user: {
email: {
equals: email,
mode: 'insensitive',
},
},
feature: {
feature: FeatureType.EarlyAccess,
},
activated: true,
},
})
.then(count => count > 0);
}
return false;
},
redirect({ url }) {
return url;
},
};
nextAuthOptions.pages = {
newUser: '/auth/onboarding',
};
return nextAuthOptions;
},
inject: [Config, PrismaClient, MailService, SessionService],
};

View File

@ -1,411 +0,0 @@
import { URLSearchParams } from 'node:url';
import {
All,
BadRequestException,
Controller,
Get,
Inject,
Logger,
Next,
NotFoundException,
Query,
Req,
Res,
UseGuards,
} from '@nestjs/common';
import { hash, verify } from '@node-rs/argon2';
import { PrismaClient, type User } from '@prisma/client';
import type { NextFunction, Request, Response } from 'express';
import { pick } from 'lodash-es';
import { nanoid } from 'nanoid';
import type { AuthAction, CookieOption, NextAuthOptions } from 'next-auth';
import { AuthHandler } from 'next-auth/core';
import {
AuthThrottlerGuard,
Config,
metrics,
SessionService,
Throttle,
} from '../../fundamentals';
import { NextAuthOptionsProvide } from './next-auth-options';
import { AuthService } from './service';
const BASE_URL = '/api/auth/';
const DEFAULT_SESSION_EXPIRE_DATE = 2592000 * 1000; // 30 days
@Controller(BASE_URL)
export class NextAuthController {
private readonly callbackSession;
private readonly logger = new Logger('NextAuthController');
constructor(
readonly config: Config,
readonly prisma: PrismaClient,
private readonly authService: AuthService,
@Inject(NextAuthOptionsProvide)
private readonly nextAuthOptions: NextAuthOptions,
private readonly session: SessionService
) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
this.callbackSession = nextAuthOptions.callbacks!.session;
}
@UseGuards(AuthThrottlerGuard)
@Throttle({
default: {
limit: 60,
ttl: 60,
},
})
@Get('/challenge')
async getChallenge(@Res() res: Response) {
const challenge = nanoid();
const resource = nanoid();
await this.session.set(challenge, resource, 5 * 60 * 1000);
res.json({ challenge, resource });
}
@UseGuards(AuthThrottlerGuard)
@Throttle({
default: {
limit: 60,
ttl: 60,
},
})
@All('*')
async auth(
@Req() req: Request,
@Res() res: Response,
@Query() query: Record<string, any>,
@Next() next: NextFunction
) {
if (req.path === '/api/auth/signin' && req.method === 'GET') {
const query = req.query
? // @ts-expect-error req.query is satisfy with the Record<string, any>
`?${new URLSearchParams(req.query).toString()}`
: '';
res.redirect(`/signin${query}`);
return;
}
const [action, providerId] = req.url // start with request url
.slice(BASE_URL.length) // make relative to baseUrl
.replace(/\?.*/, '') // remove query part, use only path part
.split('/') as [AuthAction, string]; // as array of strings;
metrics.auth.counter('call_counter').add(1, { action, providerId });
const credentialsSignIn =
req.method === 'POST' && providerId === 'credentials';
let userId: string | undefined;
if (credentialsSignIn) {
const { email } = req.body;
if (email) {
const user = await this.prisma.user.findFirst({
where: {
email: {
equals: email,
mode: 'insensitive',
},
},
});
if (!user) {
req.statusCode = 401;
req.statusMessage = 'User not found';
req.body = null;
throw new NotFoundException(`User not found`);
} else {
userId = user.id;
req.body = {
...req.body,
name: user.name,
email: user.email,
image: user.avatarUrl,
hashedPassword: user.password,
};
}
}
}
const options = this.nextAuthOptions;
if (req.method === 'POST' && action === 'session') {
if (typeof req.body !== 'object' || typeof req.body.data !== 'object') {
metrics.auth
.counter('call_fails_counter')
.add(1, { reason: 'invalid_session_data' });
throw new BadRequestException(`Invalid new session data`);
}
const user = await this.updateSession(req, req.body.data);
// callbacks.session existed
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
options.callbacks!.session = ({ session }) => {
return {
user: {
...pick(user, 'id', 'name', 'email'),
image: user.avatarUrl,
hasPassword: !!user.password,
},
expires: session.expires,
};
};
} else {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
options.callbacks!.session = this.callbackSession;
}
if (
this.config.auth.captcha.enable &&
req.method === 'POST' &&
action === 'signin' &&
// TODO: add credentials support in frontend
['email'].includes(providerId)
) {
const isVerified = await this.verifyChallenge(req, res);
if (!isVerified) return;
}
const { status, headers, body, redirect, cookies } = await AuthHandler({
req: {
body: req.body,
query: query,
method: req.method,
action,
providerId,
error: query.error ?? providerId,
cookies: req.cookies,
},
options,
});
if (headers) {
for (const { key, value } of headers) {
res.setHeader(key, value);
}
}
if (cookies) {
for (const cookie of cookies) {
res.cookie(cookie.name, cookie.value, cookie.options);
}
}
let nextAuthTokenCookie: (CookieOption & { value: string }) | undefined;
const secureCookiePrefix = '__Secure-';
const sessionCookieName = `next-auth.session-token`;
// next-auth credentials login only support JWT strategy
// https://next-auth.js.org/configuration/providers/credentials
// let's store the session token in the database
if (
credentialsSignIn &&
(nextAuthTokenCookie = cookies?.find(
({ name }) =>
name === sessionCookieName ||
name === `${secureCookiePrefix}${sessionCookieName}`
))
) {
const cookieExpires = new Date();
cookieExpires.setTime(
cookieExpires.getTime() + DEFAULT_SESSION_EXPIRE_DATE
);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
await this.nextAuthOptions.adapter!.createSession!({
sessionToken: nextAuthTokenCookie.value,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
userId: userId!,
expires: cookieExpires,
});
}
if (redirect?.endsWith('api/auth/error?error=AccessDenied')) {
this.logger.log(`Early access redirect headers: ${req.headers}`);
metrics.auth
.counter('call_fails_counter')
.add(1, { reason: 'no_early_access_permission' });
if (
!req.headers?.referer ||
checkUrlOrigin(req.headers.referer, 'https://accounts.google.com')
) {
res.redirect('https://community.affine.pro/c/insider-general/');
} else {
res.status(403);
res.json({
url: 'https://community.affine.pro/c/insider-general/',
error: `You don't have early access permission`,
});
}
return;
}
if (status) {
res.status(status);
}
if (redirect) {
if (providerId === 'credentials') {
res.send(JSON.stringify({ ok: true, url: redirect }));
} else if (
action === 'callback' ||
action === 'error' ||
(providerId !== 'credentials' &&
// login in the next-auth page, /api/auth/signin, auto redirect.
// otherwise, return the json value to allow frontend to handle the redirect.
req.headers?.referer?.includes?.('/api/auth/signin'))
) {
res.redirect(redirect);
} else {
res.json({ url: redirect });
}
} else if (typeof body === 'string') {
res.send(body);
} else if (body && typeof body === 'object') {
res.json(body);
} else {
next();
}
}
private async updateSession(
req: Request,
newSession: Partial<Omit<User, 'id'>> & { oldPassword?: string }
): Promise<User> {
const { name, email, password, oldPassword } = newSession;
if (!name && !email && !password) {
throw new BadRequestException(`Invalid new session data`);
}
if (password) {
const user = await this.verifyUserFromRequest(req);
const { password: userPassword } = user;
if (!oldPassword) {
if (userPassword) {
throw new BadRequestException(
`Old password is required to update password`
);
}
} else {
if (!userPassword) {
throw new BadRequestException(`No existed password`);
}
if (await verify(userPassword, oldPassword)) {
await this.prisma.user.update({
where: {
id: user.id,
},
data: {
...pick(newSession, 'email', 'name'),
password: await hash(password),
},
});
}
}
return user;
} else {
const user = await this.verifyUserFromRequest(req);
return this.prisma.user.update({
where: {
id: user.id,
},
data: pick(newSession, 'name', 'email'),
});
}
}
private async verifyChallenge(req: Request, res: Response): Promise<boolean> {
const challenge = req.query?.challenge;
if (typeof challenge === 'string' && challenge) {
const resource = await this.session.get(challenge);
if (!resource) {
this.rejectResponse(res, 'Invalid Challenge');
return false;
}
const isChallengeVerified =
await this.authService.verifyChallengeResponse(
req.query?.token,
resource
);
this.logger.debug(
`Challenge: ${challenge}, Resource: ${resource}, Response: ${req.query?.token}, isChallengeVerified: ${isChallengeVerified}`
);
if (!isChallengeVerified) {
this.rejectResponse(res, 'Invalid Challenge Response');
return false;
}
} else {
const isTokenVerified = await this.authService.verifyCaptchaToken(
req.query?.token,
req.headers['CF-Connecting-IP'] as string
);
if (!isTokenVerified) {
this.rejectResponse(res, 'Invalid Captcha Response');
return false;
}
}
return true;
}
private async verifyUserFromRequest(req: Request): Promise<User> {
const token = req.headers.authorization;
if (!token) {
const session = await AuthHandler({
req: {
cookies: req.cookies,
action: 'session',
method: 'GET',
headers: req.headers,
},
options: this.nextAuthOptions,
});
const { body } = session;
// @ts-expect-error check if body.user exists
if (body && body.user && body.user.id) {
const user = await this.prisma.user.findUnique({
where: {
// @ts-expect-error body.user.id exists
id: body.user.id,
},
});
if (user) {
return user;
}
}
} else {
const [type, jwt] = token.split(' ') ?? [];
if (type === 'Bearer') {
const claims = await this.authService.verify(jwt);
const user = await this.prisma.user.findUnique({
where: { id: claims.id },
});
if (user) {
return user;
}
}
}
throw new BadRequestException(`User not found`);
}
rejectResponse(res: Response, error: string, status = 400) {
res.status(status);
res.json({
url: `${this.config.baseUrl}/api/auth/error?${new URLSearchParams({
error,
}).toString()}`,
error,
});
}
}
const checkUrlOrigin = (url: string, origin: string) => {
try {
return new URL(url).origin === origin;
} catch (e) {
return false;
}
};

View File

@ -10,24 +10,22 @@ import {
Mutation,
ObjectType,
Parent,
Query,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import type { Request } from 'express';
import { nanoid } from 'nanoid';
import type { Request, Response } from 'express';
import {
CloudThrottlerGuard,
Config,
SessionService,
Throttle,
} from '../../fundamentals';
import { UserType } from '../users';
import { Auth, CurrentUser } from './guard';
import { CloudThrottlerGuard, Config, Throttle } from '../../fundamentals';
import { UserType } from '../user/types';
import { validators } from '../utils/validators';
import { CurrentUser } from './current-user';
import { Public } from './guard';
import { AuthService } from './service';
import { TokenService, TokenType } from './token';
@ObjectType()
export class TokenType {
@ObjectType('tokenType')
export class ClientTokenType {
@Field()
token!: string;
@ -50,46 +48,57 @@ export class AuthResolver {
constructor(
private readonly config: Config,
private readonly auth: AuthService,
private readonly session: SessionService
private readonly token: TokenService
) {}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Public()
@Query(() => UserType, {
name: 'currentUser',
description: 'Get current user',
nullable: true,
})
currentUser(@CurrentUser() user?: CurrentUser): UserType | undefined {
return user;
}
@Throttle({
default: {
limit: 20,
ttl: 60,
},
})
@ResolveField(() => TokenType)
async token(
@Context() ctx: { req: Request },
@CurrentUser() currentUser: UserType,
@ResolveField(() => ClientTokenType, {
name: 'token',
deprecationReason: 'use [/api/auth/authorize]',
})
async clientToken(
@CurrentUser() currentUser: CurrentUser,
@Parent() user: UserType
) {
): Promise<ClientTokenType> {
if (user.id !== currentUser.id) {
throw new BadRequestException('Invalid user');
throw new ForbiddenException('Invalid user');
}
let sessionToken: string | undefined;
// only return session if the request is from the same origin & path == /open-app
if (
ctx.req.headers.referer &&
ctx.req.headers.host &&
new URL(ctx.req.headers.referer).pathname.startsWith('/open-app') &&
ctx.req.headers.host === new URL(this.config.origin).host
) {
const cookiePrefix = this.config.node.prod ? '__Secure-' : '';
const sessionCookieName = `${cookiePrefix}next-auth.session-token`;
sessionToken = ctx.req.cookies?.[sessionCookieName];
}
const session = await this.auth.createUserSession(
user,
undefined,
this.config.auth.accessToken.ttl
);
return {
sessionToken,
token: this.auth.sign(user),
refresh: this.auth.refresh(user),
sessionToken: session.sessionId,
token: session.sessionId,
refresh: '',
};
}
@Public()
@Throttle({
default: {
limit: 10,
@ -98,16 +107,19 @@ export class AuthResolver {
})
@Mutation(() => UserType)
async signUp(
@Context() ctx: { req: Request },
@Context() ctx: { req: Request; res: Response },
@Args('name') name: string,
@Args('email') email: string,
@Args('password') password: string
) {
validators.assertValidCredential({ email, password });
const user = await this.auth.signUp(name, email, password);
await this.auth.setCookie(ctx.req, ctx.res, user);
ctx.req.user = user;
return user;
}
@Public()
@Throttle({
default: {
limit: 10,
@ -116,11 +128,13 @@ export class AuthResolver {
})
@Mutation(() => UserType)
async signIn(
@Context() ctx: { req: Request },
@Context() ctx: { req: Request; res: Response },
@Args('email') email: string,
@Args('password') password: string
) {
validators.assertValidCredential({ email, password });
const user = await this.auth.signIn(email, password);
await this.auth.setCookie(ctx.req, ctx.res, user);
ctx.req.user = user;
return user;
}
@ -132,28 +146,26 @@ export class AuthResolver {
},
})
@Mutation(() => UserType)
@Auth()
async changePassword(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('token') token: string,
@Args('newPassword') newPassword: string
) {
const id = await this.session.get(token);
if (!user.emailVerified) {
throw new ForbiddenException('Please verify the email first');
validators.assertValidPassword(newPassword);
// NOTE: Set & Change password are using the same token type.
const valid = await this.token.verifyToken(
TokenType.ChangePassword,
token,
{
credential: user.id,
}
if (
!id ||
(id !== user.id &&
// change password after sign in with email link
// we only create user account after user sign in with email link
id !== user.email)
) {
);
if (!valid) {
throw new ForbiddenException('Invalid token');
}
await this.auth.changePassword(user.email, newPassword);
await this.session.delete(token);
return user;
}
@ -165,25 +177,24 @@ export class AuthResolver {
},
})
@Mutation(() => UserType)
@Auth()
async changeEmail(
@CurrentUser() user: UserType,
@Args('token') token: string
@CurrentUser() user: CurrentUser,
@Args('token') token: string,
@Args('email') email: string
) {
const key = await this.session.get(token);
if (!key) {
validators.assertValidEmail(email);
// @see [sendChangeEmail]
const valid = await this.token.verifyToken(TokenType.VerifyEmail, token, {
credential: user.id,
});
if (!valid) {
throw new ForbiddenException('Invalid token');
}
// email has set token in `sendVerifyChangeEmail`
const [id, email] = key.split(',');
if (!id || id !== user.id || !email) {
throw new ForbiddenException('Invalid token');
}
await this.auth.changeEmail(id, email);
await this.session.delete(token);
email = decodeURIComponent(email);
await this.auth.changeEmail(user.id, email);
await this.auth.sendNotificationChangeEmail(email);
return user;
@ -196,19 +207,29 @@ export class AuthResolver {
},
})
@Mutation(() => Boolean)
@Auth()
async sendChangePasswordEmail(
@CurrentUser() user: UserType,
@Args('email') email: string,
@Args('callbackUrl') callbackUrl: string
@CurrentUser() user: CurrentUser,
@Args('callbackUrl') callbackUrl: string,
// @deprecated
@Args('email', { nullable: true }) _email?: string
) {
const token = nanoid();
await this.session.set(token, user.id);
if (!user.emailVerified) {
throw new ForbiddenException('Please verify your email first.');
}
const token = await this.token.createToken(
TokenType.ChangePassword,
user.id
);
const url = new URL(callbackUrl, this.config.baseUrl);
url.searchParams.set('token', token);
const res = await this.auth.sendChangePasswordEmail(email, url.toString());
const res = await this.auth.sendChangePasswordEmail(
user.email,
url.toString()
);
return !res.rejected.length;
}
@ -219,19 +240,27 @@ export class AuthResolver {
},
})
@Mutation(() => Boolean)
@Auth()
async sendSetPasswordEmail(
@CurrentUser() user: UserType,
@Args('email') email: string,
@Args('callbackUrl') callbackUrl: string
@CurrentUser() user: CurrentUser,
@Args('callbackUrl') callbackUrl: string,
@Args('email', { nullable: true }) _email?: string
) {
const token = nanoid();
await this.session.set(token, user.id);
if (!user.emailVerified) {
throw new ForbiddenException('Please verify your email first.');
}
const token = await this.token.createToken(
TokenType.ChangePassword,
user.id
);
const url = new URL(callbackUrl, this.config.baseUrl);
url.searchParams.set('token', token);
const res = await this.auth.sendSetPasswordEmail(email, url.toString());
const res = await this.auth.sendSetPasswordEmail(
user.email,
url.toString()
);
return !res.rejected.length;
}
@ -249,19 +278,22 @@ export class AuthResolver {
},
})
@Mutation(() => Boolean)
@Auth()
async sendChangeEmail(
@CurrentUser() user: UserType,
@Args('email') email: string,
@Args('callbackUrl') callbackUrl: string
@CurrentUser() user: CurrentUser,
@Args('callbackUrl') callbackUrl: string,
// @deprecated
@Args('email', { nullable: true }) _email?: string
) {
const token = nanoid();
await this.session.set(token, user.id);
if (!user.emailVerified) {
throw new ForbiddenException('Please verify your email first.');
}
const token = await this.token.createToken(TokenType.ChangeEmail, user.id);
const url = new URL(callbackUrl, this.config.baseUrl);
url.searchParams.set('token', token);
const res = await this.auth.sendChangeEmail(email, url.toString());
const res = await this.auth.sendChangeEmail(user.email, url.toString());
return !res.rejected.length;
}
@ -272,34 +304,92 @@ export class AuthResolver {
},
})
@Mutation(() => Boolean)
@Auth()
async sendVerifyChangeEmail(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('token') token: string,
@Args('email') email: string,
@Args('callbackUrl') callbackUrl: string
) {
const id = await this.session.get(token);
if (!id || id !== user.id) {
validators.assertValidEmail(email);
const valid = await this.token.verifyToken(TokenType.ChangeEmail, token, {
credential: user.id,
});
if (!valid) {
throw new ForbiddenException('Invalid token');
}
const hasRegistered = await this.auth.getUserByEmail(email);
if (hasRegistered) {
throw new BadRequestException(`Invalid user email`);
if (hasRegistered.id !== user.id) {
throw new BadRequestException(`The email provided has been taken.`);
} else {
throw new BadRequestException(
`The email provided is the same as the current email.`
);
}
}
const withEmailToken = nanoid();
await this.session.set(withEmailToken, `${user.id},${email}`);
const verifyEmailToken = await this.token.createToken(
TokenType.VerifyEmail,
user.id
);
const url = new URL(callbackUrl, this.config.baseUrl);
url.searchParams.set('token', withEmailToken);
url.searchParams.set('token', verifyEmailToken);
url.searchParams.set('email', email);
const res = await this.auth.sendVerifyChangeEmail(email, url.toString());
await this.session.delete(token);
return !res.rejected.length;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
async sendVerifyEmail(
@CurrentUser() user: CurrentUser,
@Args('callbackUrl') callbackUrl: string
) {
const token = await this.token.createToken(TokenType.VerifyEmail, user.id);
const url = new URL(callbackUrl, this.config.baseUrl);
url.searchParams.set('token', token);
const res = await this.auth.sendVerifyEmail(user.email, url.toString());
return !res.rejected.length;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
async verifyEmail(
@CurrentUser() user: CurrentUser,
@Args('token') token: string
) {
if (!token) {
throw new BadRequestException('Invalid token');
}
const valid = await this.token.verifyToken(TokenType.VerifyEmail, token, {
credential: user.id,
});
if (!valid) {
throw new ForbiddenException('Invalid token');
}
const { emailVerifiedAt } = await this.auth.setEmailVerified(user.id);
return emailVerifiedAt !== null;
}
}

View File

@ -1,299 +1,327 @@
import { randomUUID } from 'node:crypto';
import {
BadRequestException,
Injectable,
InternalServerErrorException,
UnauthorizedException,
NotAcceptableException,
NotFoundException,
OnApplicationBootstrap,
} from '@nestjs/common';
import { hash, verify } from '@node-rs/argon2';
import { Algorithm, sign, verify as jwtVerify } from '@node-rs/jsonwebtoken';
import { PrismaClient, type User } from '@prisma/client';
import { nanoid } from 'nanoid';
import type { CookieOptions, Request, Response } from 'express';
import { assign, omit } from 'lodash-es';
import {
Config,
CryptoHelper,
MailService,
verifyChallengeResponse,
SessionCache,
} from '../../fundamentals';
import { Quota_FreePlanV1_1 } from '../quota';
import { FeatureManagementService } from '../features/management';
import { UserService } from '../user/service';
import type { CurrentUser } from './current-user';
export type UserClaim = Pick<
User,
'id' | 'name' | 'email' | 'emailVerified' | 'createdAt' | 'avatarUrl'
> & {
hasPassword?: boolean;
};
export function parseAuthUserSeqNum(value: any) {
switch (typeof value) {
case 'number': {
return value;
}
case 'string': {
value = Number.parseInt(value);
return Number.isNaN(value) ? 0 : value;
}
export const getUtcTimestamp = () => Math.floor(Date.now() / 1000);
default: {
return 0;
}
}
}
export function sessionUser(
user: Omit<User, 'password'> & { password?: string | null }
): CurrentUser {
return assign(omit(user, 'password', 'emailVerifiedAt', 'createdAt'), {
hasPassword: user.password !== null,
emailVerified: user.emailVerifiedAt !== null,
});
}
@Injectable()
export class AuthService {
export class AuthService implements OnApplicationBootstrap {
readonly cookieOptions: CookieOptions = {
sameSite: 'lax',
httpOnly: true,
path: '/',
domain: this.config.host,
secure: this.config.https,
};
static readonly sessionCookieName = 'sid';
static readonly authUserSeqHeaderName = 'x-auth-user';
constructor(
private readonly config: Config,
private readonly prisma: PrismaClient,
private readonly mailer: MailService
private readonly db: PrismaClient,
private readonly mailer: MailService,
private readonly feature: FeatureManagementService,
private readonly user: UserService,
private readonly crypto: CryptoHelper,
private readonly cache: SessionCache
) {}
sign(user: UserClaim) {
const now = getUtcTimestamp();
return sign(
{
data: {
id: user.id,
name: user.name,
email: user.email,
emailVerified: user.emailVerified?.toISOString(),
image: user.avatarUrl,
hasPassword: Boolean(user.hasPassword),
createdAt: user.createdAt.toISOString(),
},
iat: now,
exp: now + this.config.auth.accessTokenExpiresIn,
iss: this.config.serverId,
sub: user.id,
aud: 'https://affine.pro',
jti: randomUUID({
disableEntropyCache: true,
}),
},
this.config.auth.privateKey,
{
algorithm: Algorithm.ES256,
async onApplicationBootstrap() {
if (this.config.node.dev) {
await this.signUp('Dev User', 'dev@affine.pro', 'dev').catch(() => {
// ignore
});
}
);
}
refresh(user: UserClaim) {
const now = getUtcTimestamp();
return sign(
{
data: {
id: user.id,
name: user.name,
email: user.email,
emailVerified: user.emailVerified?.toISOString(),
image: user.avatarUrl,
hasPassword: Boolean(user.hasPassword),
createdAt: user.createdAt.toISOString(),
},
exp: now + this.config.auth.refreshTokenExpiresIn,
iat: now,
iss: this.config.serverId,
sub: user.id,
aud: 'https://affine.pro',
jti: randomUUID({
disableEntropyCache: true,
}),
},
this.config.auth.privateKey,
{
algorithm: Algorithm.ES256,
}
);
canSignIn(email: string) {
return this.feature.canEarlyAccess(email);
}
async verify(token: string) {
try {
const data = (
await jwtVerify(token, this.config.auth.publicKey, {
algorithms: [Algorithm.ES256],
iss: [this.config.serverId],
leeway: this.config.auth.leeway,
requiredSpecClaims: ['exp', 'iat', 'iss', 'sub'],
aud: ['https://affine.pro'],
async signUp(
name: string,
email: string,
password: string
): Promise<CurrentUser> {
const user = await this.getUserByEmail(email);
if (user) {
throw new BadRequestException('Email was taken');
}
const hashedPassword = await this.crypto.encryptPassword(password);
return this.user
.createUser({
name,
email,
password: hashedPassword,
})
).data as UserClaim;
return {
...data,
emailVerified: data.emailVerified ? new Date(data.emailVerified) : null,
createdAt: new Date(data.createdAt),
};
} catch (e) {
throw new UnauthorizedException('Invalid token');
}
.then(sessionUser);
}
async verifyCaptchaToken(token: any, ip: string) {
if (typeof token !== 'string' || !token) return false;
const formData = new FormData();
formData.append('secret', this.config.auth.captcha.turnstile.secret);
formData.append('response', token);
formData.append('remoteip', ip);
// prevent replay attack
formData.append('idempotency_key', nanoid());
const url = 'https://challenges.cloudflare.com/turnstile/v0/siteverify';
const result = await fetch(url, {
body: formData,
method: 'POST',
});
const outcome = await result.json();
return (
!!outcome.success &&
// skip hostname check in dev mode
(this.config.node.dev || outcome.hostname === this.config.host)
);
}
async verifyChallengeResponse(response: any, resource: string) {
return verifyChallengeResponse(
response,
this.config.auth.captcha.challenge.bits,
resource
);
}
async signIn(email: string, password: string): Promise<User> {
const user = await this.prisma.user.findFirst({
where: {
email: {
equals: email,
mode: 'insensitive',
},
},
});
async signIn(email: string, password: string) {
const user = await this.user.findUserWithHashedPasswordByEmail(email);
if (!user) {
throw new BadRequestException('Invalid email');
throw new NotFoundException('User Not Found');
}
if (!user.password) {
throw new BadRequestException('User has no password');
throw new NotAcceptableException(
'User Password is not set. Should login throw email link.'
);
}
let equal = false;
try {
equal = await verify(user.password, password);
} catch (e) {
console.error(e);
throw new InternalServerErrorException(e, 'Verify password failed');
const passwordMatches = await this.crypto.verifyPassword(
password,
user.password
);
if (!passwordMatches) {
throw new NotAcceptableException('Incorrect Password');
}
if (!equal) {
throw new UnauthorizedException('Invalid password');
return sessionUser(user);
}
async getUserWithCache(token: string, seq = 0) {
const cacheKey = `session:${token}:${seq}`;
let user = await this.cache.get<CurrentUser | null>(cacheKey);
if (user) {
return user;
}
user = await this.getUser(token, seq);
if (user) {
await this.cache.set(cacheKey, user);
}
return user;
}
async signUp(name: string, email: string, password: string): Promise<User> {
const user = await this.prisma.user.findFirst({
where: {
email: {
equals: email,
mode: 'insensitive',
},
},
});
async getUser(token: string, seq = 0): Promise<CurrentUser | null> {
const session = await this.getSession(token);
if (user) {
throw new BadRequestException('Email already exists');
// no such session
if (!session) {
return null;
}
const hashedPassword = await hash(password);
const userSession = session.userSessions.at(seq);
return this.prisma.user.create({
data: {
name,
email,
password: hashedPassword,
// TODO(@forehalo): handle in event system
features: {
create: {
reason: 'created by api sign up',
activated: true,
feature: {
connect: {
feature_version: Quota_FreePlanV1_1,
},
},
},
},
},
});
// no such user session
if (!userSession) {
return null;
}
async createAnonymousUser(email: string): Promise<User> {
const user = await this.prisma.user.findFirst({
where: {
email: {
equals: email,
mode: 'insensitive',
},
},
});
if (user) {
throw new BadRequestException('Email already exists');
// user session expired
if (userSession.expiresAt && userSession.expiresAt <= new Date()) {
return null;
}
return this.prisma.user.create({
data: {
name: 'Unnamed',
email,
features: {
create: {
reason: 'created by invite sign up',
activated: true,
feature: {
connect: {
feature_version: Quota_FreePlanV1_1,
},
},
},
},
},
const user = await this.db.user.findUnique({
where: { id: userSession.userId },
});
}
async getUserByEmail(email: string): Promise<User | null> {
return this.prisma.user.findFirst({
where: {
email: {
equals: email,
mode: 'insensitive',
},
},
});
}
async isUserHasPassword(email: string): Promise<boolean> {
const user = await this.prisma.user.findFirst({
where: {
email: {
equals: email,
mode: 'insensitive',
},
},
});
if (!user) {
throw new BadRequestException('Invalid email');
return null;
}
return Boolean(user.password);
return sessionUser(user);
}
async getUserList(token: string) {
const session = await this.getSession(token);
if (!session || !session.userSessions.length) {
return [];
}
const users = await this.db.user.findMany({
where: {
id: {
in: session.userSessions.map(({ userId }) => userId),
},
},
});
// TODO(@forehalo): need to separate expired session, same for [getUser]
// Session
// | { user: LimitedUser { email, avatarUrl }, expired: true }
// | { user: User, expired: false }
return users.map(sessionUser);
}
async signOut(token: string, seq = 0) {
const session = await this.getSession(token);
if (session) {
// overflow the logged in user
if (session.userSessions.length <= seq) {
return session;
}
await this.db.userSession.deleteMany({
where: { id: session.userSessions[seq].id },
});
// no more user session active, delete the whole session
if (session.userSessions.length === 1) {
await this.db.session.delete({ where: { id: session.id } });
return null;
}
return session;
}
return null;
}
async getSession(token: string) {
return this.db.$transaction(async tx => {
const session = await tx.session.findUnique({
where: {
id: token,
},
include: {
userSessions: {
orderBy: {
createdAt: 'asc',
},
},
},
});
if (!session) {
return null;
}
if (session.expiresAt && session.expiresAt <= new Date()) {
await tx.session.delete({
where: {
id: session.id,
},
});
return null;
}
return session;
});
}
async createUserSession(
user: { id: string },
existingSession?: string,
ttl = this.config.auth.session.ttl
) {
const session = existingSession
? await this.getSession(existingSession)
: null;
const expiresAt = new Date(Date.now() + ttl * 1000);
if (session) {
return this.db.userSession.upsert({
where: {
sessionId_userId: {
sessionId: session.id,
userId: user.id,
},
},
update: {
expiresAt,
},
create: {
sessionId: session.id,
userId: user.id,
expiresAt,
},
});
} else {
return this.db.userSession.create({
data: {
expiresAt,
session: {
create: {},
},
user: {
connect: {
id: user.id,
},
},
},
});
}
}
async setCookie(req: Request, res: Response, user: { id: string }) {
const session = await this.createUserSession(
user,
req.cookies[AuthService.sessionCookieName]
);
res.cookie(AuthService.sessionCookieName, session.sessionId, {
expires: session.expiresAt ?? void 0,
...this.cookieOptions,
});
}
async getUserByEmail(email: string) {
return this.user.findUserByEmail(email);
}
async changePassword(email: string, newPassword: string): Promise<User> {
const user = await this.prisma.user.findFirst({
where: {
email: {
equals: email,
mode: 'insensitive',
},
emailVerified: {
not: null,
},
},
});
const user = await this.getUserByEmail(email);
if (!user) {
throw new BadRequestException('Invalid email');
}
const hashedPassword = await hash(newPassword);
const hashedPassword = await this.crypto.encryptPassword(newPassword);
return this.prisma.user.update({
return this.db.user.update({
where: {
id: user.id,
},
@ -304,7 +332,7 @@ export class AuthService {
}
async changeEmail(id: string, newEmail: string): Promise<User> {
const user = await this.prisma.user.findUnique({
const user = await this.db.user.findUnique({
where: {
id,
},
@ -314,12 +342,27 @@ export class AuthService {
throw new BadRequestException('Invalid email');
}
return this.prisma.user.update({
return this.db.user.update({
where: {
id,
},
data: {
email: newEmail,
emailVerifiedAt: new Date(),
},
});
}
async setEmailVerified(id: string) {
return await this.db.user.update({
where: {
id,
},
data: {
emailVerifiedAt: new Date(),
},
select: {
emailVerifiedAt: true,
},
});
}
@ -336,7 +379,20 @@ export class AuthService {
async sendVerifyChangeEmail(email: string, callbackUrl: string) {
return this.mailer.sendVerifyChangeEmail(email, callbackUrl);
}
async sendVerifyEmail(email: string, callbackUrl: string) {
return this.mailer.sendVerifyEmail(email, callbackUrl);
}
async sendNotificationChangeEmail(email: string) {
return this.mailer.sendNotificationChangeEmail(email);
}
async sendSignInEmail(email: string, link: string, signUp: boolean) {
return signUp
? await this.mailer.sendSignUpMail(link.toString(), {
to: email,
})
: await this.mailer.sendSignInMail(link.toString(), {
to: email,
});
}
}

View File

@ -0,0 +1,84 @@
import { randomUUID } from 'node:crypto';
import { Injectable } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { CryptoHelper } from '../../fundamentals/helpers';
export enum TokenType {
SignIn,
VerifyEmail,
ChangeEmail,
ChangePassword,
Challenge,
}
@Injectable()
export class TokenService {
constructor(
private readonly db: PrismaClient,
private readonly crypto: CryptoHelper
) {}
async createToken(
type: TokenType,
credential?: string,
ttlInSec: number = 30 * 60
) {
const plaintextToken = randomUUID();
const { token } = await this.db.verificationToken.create({
data: {
type,
token: plaintextToken,
credential,
expiresAt: new Date(Date.now() + ttlInSec * 1000),
},
});
return this.crypto.encrypt(token);
}
async verifyToken(
type: TokenType,
token: string,
{
credential,
keep,
}: {
credential?: string;
keep?: boolean;
} = {}
) {
token = this.crypto.decrypt(token);
const record = await this.db.verificationToken.findUnique({
where: {
type_token: {
token,
type,
},
},
});
if (!record) {
return null;
}
const expired = record.expiresAt <= new Date();
const valid =
!expired && (!record.credential || record.credential === credential);
if ((expired || valid) && !keep) {
await this.db.verificationToken.delete({
where: {
type_token: {
token,
type,
},
},
});
}
return valid ? record : null;
}
}

View File

@ -1,3 +0,0 @@
export { jwtDecode as decode, jwtEncode as encode } from './jwt';
export { sendVerificationRequest } from './send-mail';
export type { SendVerificationRequestParams } from 'next-auth/providers/email';

View File

@ -1,76 +0,0 @@
import { randomUUID } from 'node:crypto';
import { BadRequestException } from '@nestjs/common';
import { Algorithm, sign, verify as jwtVerify } from '@node-rs/jsonwebtoken';
import { PrismaClient } from '@prisma/client';
import { JWT } from 'next-auth/jwt';
import { Config } from '../../../fundamentals';
import { getUtcTimestamp, UserClaim } from '../service';
export const jwtEncode = async (
config: Config,
prisma: PrismaClient,
token: JWT | undefined,
maxAge: number | undefined
) => {
if (!token?.email) {
throw new BadRequestException('Missing email in jwt token');
}
const user = await prisma.user.findFirstOrThrow({
where: {
email: token.email,
},
});
const now = getUtcTimestamp();
return sign(
{
data: {
id: user.id,
name: user.name,
email: user.email,
emailVerified: user.emailVerified?.toISOString(),
picture: user.avatarUrl,
createdAt: user.createdAt.toISOString(),
hasPassword: Boolean(user.password),
},
iat: now,
exp: now + (maxAge ?? config.auth.accessTokenExpiresIn),
iss: config.serverId,
sub: user.id,
aud: 'https://affine.pro',
jti: randomUUID({
disableEntropyCache: true,
}),
},
config.auth.privateKey,
{
algorithm: Algorithm.ES256,
}
);
};
export const jwtDecode = async (config: Config, token: string | undefined) => {
if (!token) {
return null;
}
const { name, email, emailVerified, id, picture, hasPassword } = (
await jwtVerify(token, config.auth.publicKey, {
algorithms: [Algorithm.ES256],
iss: [config.serverId],
leeway: config.auth.leeway,
requiredSpecClaims: ['exp', 'iat', 'iss', 'sub'],
})
).data as Omit<UserClaim, 'avatarUrl'> & {
picture: string | undefined;
};
return {
name,
email,
emailVerified,
picture,
sub: id,
id,
hasPassword,
};
};

View File

@ -1,38 +0,0 @@
import { Logger } from '@nestjs/common';
import { nanoid } from 'nanoid';
import type { SendVerificationRequestParams } from 'next-auth/providers/email';
import { Config, MailService, SessionService } from '../../../fundamentals';
export async function sendVerificationRequest(
config: Config,
logger: Logger,
mailer: MailService,
session: SessionService,
params: SendVerificationRequestParams
) {
const { identifier, url } = params;
const urlWithToken = new URL(url);
const callbackUrl = urlWithToken.searchParams.get('callbackUrl') || '';
if (!callbackUrl) {
throw new Error('callbackUrl is not set');
} else {
const newCallbackUrl = new URL(callbackUrl, config.origin);
const token = nanoid();
await session.set(token, identifier);
newCallbackUrl.searchParams.set('token', token);
urlWithToken.searchParams.set('callbackUrl', newCallbackUrl.toString());
}
const result = await mailer.sendSignInEmail(urlWithToken.toString(), {
to: identifier,
});
logger.log(`send verification email success: ${result.accepted.join(', ')}`);
const failed = result.rejected.concat(result.pending).filter(Boolean);
if (failed.length) {
throw new Error(`Email (${failed.join(', ')}) could not be sent`);
}
}

View File

@ -2,9 +2,11 @@ import { Module } from '@nestjs/common';
import { Field, ObjectType, Query, registerEnumType } from '@nestjs/graphql';
import { DeploymentType } from '../fundamentals';
import { Public } from './auth';
export enum ServerFeature {
Payment = 'payment',
OAuth = 'oauth',
}
registerEnumType(ServerFeature, {
@ -15,9 +17,9 @@ registerEnumType(DeploymentType, {
name: 'ServerDeploymentType',
});
const ENABLED_FEATURES: ServerFeature[] = [];
const ENABLED_FEATURES: Set<ServerFeature> = new Set();
export function ADD_ENABLED_FEATURES(feature: ServerFeature) {
ENABLED_FEATURES.push(feature);
ENABLED_FEATURES.add(feature);
}
@ObjectType()
@ -48,6 +50,7 @@ export class ServerConfigType {
}
export class ServerConfigResolver {
@Public()
@Query(() => ServerConfigType, {
description: 'server config',
})
@ -61,7 +64,7 @@ export class ServerConfigResolver {
// the old flavors contains `selfhosted` but it actually not flavor but deployment type
// this field should be removed after frontend feature flags implemented
flavor: AFFiNE.type,
features: ENABLED_FEATURES,
features: Array.from(ENABLED_FEATURES),
};
}
}

View File

@ -1,7 +1,6 @@
import { Injectable } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { UserType } from '../users/types';
import { WorkspaceType } from '../workspaces/types';
import { FeatureConfigType, getFeature } from './feature';
import { FeatureKind, FeatureType } from './types';
@ -158,7 +157,7 @@ export class FeatureService {
return configs.filter(feature => !!feature.feature);
}
async listFeatureUsers(feature: FeatureType): Promise<UserType[]> {
async listFeatureUsers(feature: FeatureType) {
return this.prisma.userFeatures
.findMany({
where: {
@ -175,7 +174,7 @@ export class FeatureService {
name: true,
avatarUrl: true,
email: true,
emailVerified: true,
emailVerifiedAt: true,
createdAt: true,
},
},

View File

@ -1,4 +1,4 @@
import { FeatureKind } from '../features';
import { FeatureKind } from '../features/types';
import { OneDay, OneGB, OneMB } from './constant';
import { Quota, QuotaType } from './types';

View File

@ -2,7 +2,7 @@ import { Field, ObjectType } from '@nestjs/graphql';
import { SafeIntResolver } from 'graphql-scalars';
import { z } from 'zod';
import { commonFeatureSchema, FeatureKind } from '../features';
import { commonFeatureSchema, FeatureKind } from '../features/types';
import { ByteUnit, OneDay, OneKB } from './constant';
/// ======== quota define ========

View File

@ -14,7 +14,6 @@ import { encodeStateAsUpdate, encodeStateVector } from 'yjs';
import { CallTimer, metrics } from '../../../fundamentals';
import { Auth, CurrentUser } from '../../auth';
import { DocManager } from '../../doc';
import { UserType } from '../../users';
import { DocID } from '../../utils/doc';
import { PermissionService } from '../../workspaces/permission';
import { Permission } from '../../workspaces/types';
@ -53,6 +52,7 @@ export const GatewayErrorWrapper = (): MethodDecorator => {
if (result instanceof Promise) {
return result.catch(e => {
metrics.socketio.counter('unhandled_errors').add(1);
new Logger('EventsGateway').error(e, e.stack);
return {
error: new InternalError(e),
};
@ -139,7 +139,7 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
@Auth()
@SubscribeMessage('client-handshake-sync')
async handleClientHandshakeSync(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@MessageBody('workspaceId') workspaceId: string,
@MessageBody('version') version: string | undefined,
@ConnectedSocket() client: Socket
@ -172,7 +172,7 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
@Auth()
@SubscribeMessage('client-handshake-awareness')
async handleClientHandshakeAwareness(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@MessageBody('workspaceId') workspaceId: string,
@MessageBody('version') version: string | undefined,
@ConnectedSocket() client: Socket
@ -290,7 +290,7 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
@SubscribeMessage('doc-load-v2')
async loadDocV2(
@ConnectedSocket() client: Socket,
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@MessageBody()
{
workspaceId,
@ -339,6 +339,7 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
};
}
@Auth()
@SubscribeMessage('awareness-init')
async handleInitAwareness(
@MessageBody() workspaceId: string,

View File

@ -6,15 +6,15 @@ import { StorageModule } from '../storage';
import { UserAvatarController } from './controller';
import { UserManagementResolver } from './management';
import { UserResolver } from './resolver';
import { UsersService } from './users';
import { UserService } from './service';
@Module({
imports: [StorageModule, FeatureModule, QuotaModule],
providers: [UserResolver, UserManagementResolver, UsersService],
providers: [UserResolver, UserManagementResolver, UserService],
controllers: [UserAvatarController],
exports: [UsersService],
exports: [UserService],
})
export class UsersModule {}
export class UserModule {}
export { UserService } from './service';
export { UserType } from './types';
export { UsersService } from './users';

View File

@ -6,23 +6,21 @@ import {
import { Args, Context, Int, Mutation, Query, Resolver } from '@nestjs/graphql';
import { CloudThrottlerGuard, Throttle } from '../../fundamentals';
import { Auth, CurrentUser } from '../auth/guard';
import { AuthService } from '../auth/service';
import { CurrentUser } from '../auth/current-user';
import { sessionUser } from '../auth/service';
import { FeatureManagementService } from '../features';
import { UserService } from './service';
import { UserType } from './types';
import { UsersService } from './users';
/**
* User resolver
* All op rate limit: 10 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Auth()
@Resolver(() => UserType)
export class UserManagementResolver {
constructor(
private readonly auth: AuthService,
private readonly users: UsersService,
private readonly users: UserService,
private readonly feature: FeatureManagementService
) {}
@ -34,7 +32,7 @@ export class UserManagementResolver {
})
@Mutation(() => Int)
async addToEarlyAccess(
@CurrentUser() currentUser: UserType,
@CurrentUser() currentUser: CurrentUser,
@Args('email') email: string
): Promise<number> {
if (!this.feature.isStaff(currentUser.email)) {
@ -44,7 +42,7 @@ export class UserManagementResolver {
if (user) {
return this.feature.addEarlyAccess(user.id);
} else {
const user = await this.auth.createAnonymousUser(email);
const user = await this.users.createAnonymousUser(email);
return this.feature.addEarlyAccess(user.id);
}
}
@ -57,7 +55,7 @@ export class UserManagementResolver {
})
@Mutation(() => Int)
async removeEarlyAccess(
@CurrentUser() currentUser: UserType,
@CurrentUser() currentUser: CurrentUser,
@Args('email') email: string
): Promise<number> {
if (!this.feature.isStaff(currentUser.email)) {
@ -79,13 +77,15 @@ export class UserManagementResolver {
@Query(() => [UserType])
async earlyAccessUsers(
@Context() ctx: { isAdminQuery: boolean },
@CurrentUser() user: UserType
@CurrentUser() user: CurrentUser
): Promise<UserType[]> {
if (!this.feature.isStaff(user.email)) {
throw new ForbiddenException('You are not allowed to do this');
}
// allow query other user's subscription
ctx.isAdminQuery = true;
return this.feature.listEarlyAccess();
return this.feature.listEarlyAccess().then(users => {
return users.map(sessionUser);
});
}
}

View File

@ -9,6 +9,7 @@ import {
} from '@nestjs/graphql';
import { PrismaClient, type User } from '@prisma/client';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import { isNil, omitBy } from 'lodash-es';
import {
CloudThrottlerGuard,
@ -17,68 +18,38 @@ import {
PaymentRequiredException,
Throttle,
} from '../../fundamentals';
import { Auth, CurrentUser, Public, Publicable } from '../auth/guard';
import { CurrentUser } from '../auth/current-user';
import { Public } from '../auth/guard';
import { sessionUser } from '../auth/service';
import { FeatureManagementService } from '../features';
import { QuotaService } from '../quota';
import { AvatarStorage } from '../storage';
import { UserService } from './service';
import {
DeleteAccount,
RemoveAvatar,
UpdateUserInput,
UserOrLimitedUser,
UserQuotaType,
UserType,
} from './types';
import { UsersService } from './users';
/**
* User resolver
* All op rate limit: 10 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Auth()
@Resolver(() => UserType)
export class UserResolver {
constructor(
private readonly prisma: PrismaClient,
private readonly storage: AvatarStorage,
private readonly users: UsersService,
private readonly users: UserService,
private readonly feature: FeatureManagementService,
private readonly quota: QuotaService,
private readonly event: EventEmitter
) {}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Publicable()
@Query(() => UserType, {
name: 'currentUser',
description: 'Get current user',
nullable: true,
})
async currentUser(@CurrentUser() user?: UserType) {
if (!user) {
return null;
}
const storedUser = await this.users.findUserById(user.id);
if (!storedUser) {
throw new BadRequestException(`User ${user.id} not found in db`);
}
return {
id: storedUser.id,
name: storedUser.name,
email: storedUser.email,
emailVerified: storedUser.emailVerified,
avatarUrl: storedUser.avatarUrl,
createdAt: storedUser.createdAt,
hasPassword: !!storedUser.password,
};
}
@Throttle({
default: {
limit: 10,
@ -92,9 +63,9 @@ export class UserResolver {
})
@Public()
async user(
@CurrentUser() currentUser?: UserType,
@CurrentUser() currentUser?: CurrentUser,
@Args('email') email?: string
) {
): Promise<typeof UserOrLimitedUser | null> {
if (!email || !(await this.feature.canEarlyAccess(email))) {
throw new PaymentRequiredException(
`You don't have early access permission\nVisit https://community.affine.pro/c/insider-general/ for more information`
@ -102,16 +73,19 @@ export class UserResolver {
}
// TODO: need to limit a user can only get another user witch is in the same workspace
const user = await this.users.findUserByEmail(email);
if (currentUser) return user;
const user = await this.users.findUserWithHashedPasswordByEmail(email);
// return empty response when user not exists
if (!user) return null;
if (currentUser) {
return sessionUser(user);
}
// only return limited info when not logged in
return {
email: user?.email,
hasPassword: !!user?.password,
email: user.email,
hasPassword: !!user.password,
};
}
@ -128,7 +102,7 @@ export class UserResolver {
name: 'invoiceCount',
description: 'Get user invoice count',
})
async invoiceCount(@CurrentUser() user: UserType) {
async invoiceCount(@CurrentUser() user: CurrentUser) {
return this.prisma.userInvoice.count({
where: { userId: user.id },
});
@ -145,7 +119,7 @@ export class UserResolver {
description: 'Upload user avatar',
})
async uploadAvatar(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args({ name: 'avatar', type: () => GraphQLUpload })
avatar: FileUpload
) {
@ -169,6 +143,33 @@ export class UserResolver {
});
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType, {
name: 'updateProfile',
})
async updateUserProfile(
@CurrentUser() user: CurrentUser,
@Args('input', { type: () => UpdateUserInput }) input: UpdateUserInput
): Promise<UserType> {
input = omitBy(input, isNil);
if (Object.keys(input).length === 0) {
return user;
}
return sessionUser(
await this.prisma.user.update({
where: { id: user.id },
data: input,
})
);
}
@Throttle({
default: {
limit: 10,
@ -179,7 +180,7 @@ export class UserResolver {
name: 'removeAvatar',
description: 'Remove user avatar',
})
async removeAvatar(@CurrentUser() user: UserType) {
async removeAvatar(@CurrentUser() user: CurrentUser) {
if (!user) {
throw new BadRequestException(`User not found`);
}
@ -197,7 +198,9 @@ export class UserResolver {
},
})
@Mutation(() => DeleteAccount)
async deleteAccount(@CurrentUser() user: UserType): Promise<DeleteAccount> {
async deleteAccount(
@CurrentUser() user: CurrentUser
): Promise<DeleteAccount> {
const deletedUser = await this.users.deleteUser(user.id);
this.event.emit('user.deleted', deletedUser);
return { success: true };

View File

@ -0,0 +1,112 @@
import { BadRequestException, Injectable } from '@nestjs/common';
import { Prisma, PrismaClient } from '@prisma/client';
import { Quota_FreePlanV1_1 } from '../quota/schema';
@Injectable()
export class UserService {
defaultUserSelect = {
id: true,
name: true,
email: true,
emailVerifiedAt: true,
avatarUrl: true,
} satisfies Prisma.UserSelect;
constructor(private readonly prisma: PrismaClient) {}
get userCreatingData(): Partial<Prisma.UserCreateInput> {
return {
name: 'Unnamed',
features: {
create: {
reason: 'created by invite sign up',
activated: true,
feature: {
connect: {
feature_version: Quota_FreePlanV1_1,
},
},
},
},
};
}
async createUser(data: Prisma.UserCreateInput) {
return this.prisma.user.create({
data: {
...this.userCreatingData,
...data,
},
});
}
async createAnonymousUser(
email: string,
data?: Partial<Prisma.UserCreateInput>
) {
const user = await this.findUserByEmail(email);
if (user) {
throw new BadRequestException('Email already exists');
}
return this.createUser({
email,
name: 'Unnamed',
...data,
});
}
async findUserById(id: string) {
return this.prisma.user
.findUnique({
where: { id },
select: this.defaultUserSelect,
})
.catch(() => {
return null;
});
}
async findUserByEmail(email: string) {
return this.prisma.user.findFirst({
where: {
email: {
equals: email,
mode: 'insensitive',
},
},
select: this.defaultUserSelect,
});
}
/**
* supposed to be used only for `Credential SignIn`
*/
async findUserWithHashedPasswordByEmail(email: string) {
return this.prisma.user.findFirst({
where: {
email: {
equals: email,
mode: 'insensitive',
},
},
});
}
async findOrCreateUser(
email: string,
data?: Partial<Prisma.UserCreateInput>
) {
const user = await this.findUserByEmail(email);
if (user) {
return user;
}
return this.createAnonymousUser(email, data);
}
async deleteUser(id: string) {
return this.prisma.user.delete({ where: { id } });
}
}

View File

@ -1,7 +1,15 @@
import { createUnionType, Field, ID, ObjectType } from '@nestjs/graphql';
import {
createUnionType,
Field,
ID,
InputType,
ObjectType,
} from '@nestjs/graphql';
import type { User } from '@prisma/client';
import { SafeIntResolver } from 'graphql-scalars';
import { CurrentUser } from '../auth/current-user';
@ObjectType('UserQuotaHumanReadable')
export class UserQuotaHumanReadableType {
@Field({ name: 'name' })
@ -42,7 +50,7 @@ export class UserQuotaType {
}
@ObjectType()
export class UserType implements Partial<User> {
export class UserType implements CurrentUser {
@Field(() => ID)
id!: string;
@ -53,19 +61,25 @@ export class UserType implements Partial<User> {
email!: string;
@Field(() => String, { description: 'User avatar url', nullable: true })
avatarUrl: string | null = null;
avatarUrl!: string | null;
@Field(() => Date, { description: 'User email verified', nullable: true })
emailVerified: Date | null = null;
@Field({ description: 'User created date', nullable: true })
createdAt!: Date;
@Field(() => Boolean, {
description: 'User email verified',
})
emailVerified!: boolean;
@Field(() => Boolean, {
description: 'User password has been set',
nullable: true,
})
hasPassword?: boolean;
hasPassword!: boolean | null;
@Field(() => Date, {
deprecationReason: 'useless',
description: 'User email verified',
nullable: true,
})
createdAt?: Date | null;
}
@ObjectType()
@ -77,7 +91,7 @@ export class LimitedUserType implements Partial<User> {
description: 'User password has been set',
nullable: true,
})
hasPassword?: boolean;
hasPassword!: boolean | null;
}
export const UserOrLimitedUser = createUnionType({
@ -101,3 +115,9 @@ export class RemoveAvatar {
@Field()
success!: boolean;
}
@InputType()
export class UpdateUserInput implements Partial<User> {
@Field({ description: 'User name', nullable: true })
name?: string;
}

View File

@ -1,32 +0,0 @@
import { Injectable } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
@Injectable()
export class UsersService {
constructor(private readonly prisma: PrismaClient) {}
async findUserByEmail(email: string) {
return this.prisma.user.findFirst({
where: {
email: {
equals: email,
mode: 'insensitive',
},
},
});
}
async findUserById(id: string) {
return this.prisma.user
.findUnique({
where: { id },
})
.catch(() => {
return null;
});
}
async deleteUser(id: string) {
return this.prisma.user.delete({ where: { id } });
}
}

View File

@ -0,0 +1,55 @@
import { BadRequestException } from '@nestjs/common';
import z from 'zod';
function getAuthCredentialValidator() {
const email = z.string().email({ message: 'Invalid email address' });
let password = z.string();
const minPasswordLength = AFFiNE.node.prod ? 8 : 1;
password = password
.min(minPasswordLength, {
message: `Password must be ${minPasswordLength} or more charactors long`,
})
.max(20, { message: 'Password must be 20 or fewer charactors long' });
return z
.object({
email,
password,
})
.required();
}
function assertValid<T>(z: z.ZodType<T>, value: unknown) {
const result = z.safeParse(value);
if (!result.success) {
const firstIssue = result.error.issues.at(0);
if (firstIssue) {
throw new BadRequestException(firstIssue.message);
} else {
throw new BadRequestException('Invalid credential');
}
}
}
export function assertValidEmail(email: string) {
assertValid(getAuthCredentialValidator().shape.email, email);
}
export function assertValidPassword(password: string) {
assertValid(getAuthCredentialValidator().shape.password, password);
}
export function assertValidCredential(credential: {
email: string;
password: string;
}) {
assertValid(getAuthCredentialValidator(), credential);
}
export const validators = {
assertValidEmail,
assertValidPassword,
assertValidCredential,
};

View File

@ -11,10 +11,9 @@ import { PrismaClient } from '@prisma/client';
import type { Response } from 'express';
import { CallTimer } from '../../fundamentals';
import { Auth, CurrentUser, Publicable } from '../auth';
import { CurrentUser, Public } from '../auth';
import { DocHistoryManager, DocManager } from '../doc';
import { WorkspaceBlobStorage } from '../storage';
import { UserType } from '../users';
import { DocID } from '../utils/doc';
import { PermissionService, PublicPageMode } from './permission';
import { Permission } from './types';
@ -63,11 +62,10 @@ export class WorkspacesController {
// get doc binary
@Get('/:id/docs/:guid')
@Auth()
@Publicable()
@Public()
@CallTimer('controllers', 'workspace_get_doc')
async doc(
@CurrentUser() user: UserType | undefined,
@CurrentUser() user: CurrentUser | undefined,
@Param('id') ws: string,
@Param('guid') guid: string,
@Res() res: Response
@ -112,10 +110,9 @@ export class WorkspacesController {
}
@Get('/:id/docs/:guid/histories/:timestamp')
@Auth()
@CallTimer('controllers', 'workspace_get_history')
async history(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Param('id') ws: string,
@Param('guid') guid: string,
@Param('timestamp') timestamp: string,

View File

@ -4,7 +4,7 @@ import { DocModule } from '../doc';
import { FeatureModule } from '../features';
import { QuotaModule } from '../quota';
import { StorageModule } from '../storage';
import { UsersService } from '../users';
import { UserModule } from '../user';
import { WorkspacesController } from './controller';
import { WorkspaceManagementResolver } from './management';
import { PermissionService } from './permission';
@ -16,13 +16,12 @@ import {
} from './resolvers';
@Module({
imports: [DocModule, FeatureModule, QuotaModule, StorageModule],
imports: [DocModule, FeatureModule, QuotaModule, StorageModule, UserModule],
controllers: [WorkspacesController],
providers: [
WorkspaceResolver,
WorkspaceManagementResolver,
PermissionService,
UsersService,
PagePermissionResolver,
DocHistoryResolver,
WorkspaceBlobResolver,

View File

@ -10,14 +10,12 @@ import {
} from '@nestjs/graphql';
import { CloudThrottlerGuard, Throttle } from '../../fundamentals';
import { Auth, CurrentUser } from '../auth';
import { CurrentUser } from '../auth';
import { FeatureManagementService, FeatureType } from '../features';
import { UserType } from '../users';
import { PermissionService } from './permission';
import { WorkspaceType } from './types';
@UseGuards(CloudThrottlerGuard)
@Auth()
@Resolver(() => WorkspaceType)
export class WorkspaceManagementResolver {
constructor(
@ -33,7 +31,7 @@ export class WorkspaceManagementResolver {
})
@Mutation(() => Int)
async addWorkspaceFeature(
@CurrentUser() currentUser: UserType,
@CurrentUser() currentUser: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('feature', { type: () => FeatureType }) feature: FeatureType
): Promise<number> {
@ -52,7 +50,7 @@ export class WorkspaceManagementResolver {
})
@Mutation(() => Int)
async removeWorkspaceFeature(
@CurrentUser() currentUser: UserType,
@CurrentUser() currentUser: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('feature', { type: () => FeatureType }) feature: FeatureType
): Promise<boolean> {
@ -71,7 +69,7 @@ export class WorkspaceManagementResolver {
})
@Query(() => [WorkspaceType])
async listWorkspaceFeatures(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('feature', { type: () => FeatureType }) feature: FeatureType
): Promise<WorkspaceType[]> {
if (!this.feature.isStaff(user.email)) {
@ -83,7 +81,7 @@ export class WorkspaceManagementResolver {
@Mutation(() => Boolean)
async setWorkspaceExperimentalFeature(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('feature', { type: () => FeatureType }) feature: FeatureType,
@Args('enable') enable: boolean
@ -117,7 +115,7 @@ export class WorkspaceManagementResolver {
complexity: 2,
})
async availableFeatures(
@CurrentUser() user: UserType
@CurrentUser() user: CurrentUser
): Promise<FeatureType[]> {
const isEarlyAccessUser = await this.feature.isEarlyAccessUser(user.email);
if (isEarlyAccessUser) {

View File

@ -22,16 +22,14 @@ import {
MakeCache,
PreventCache,
} from '../../../fundamentals';
import { Auth, CurrentUser } from '../../auth';
import { CurrentUser } from '../../auth';
import { FeatureManagementService, FeatureType } from '../../features';
import { QuotaManagementService } from '../../quota';
import { WorkspaceBlobStorage } from '../../storage';
import { UserType } from '../../users';
import { PermissionService } from '../permission';
import { Permission, WorkspaceBlobSizes, WorkspaceType } from '../types';
@UseGuards(CloudThrottlerGuard)
@Auth()
@Resolver(() => WorkspaceType)
export class WorkspaceBlobResolver {
logger = new Logger(WorkspaceBlobResolver.name);
@ -47,7 +45,7 @@ export class WorkspaceBlobResolver {
complexity: 2,
})
async blobs(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Parent() workspace: WorkspaceType
) {
await this.permissions.checkWorkspace(workspace.id, user.id);
@ -74,7 +72,7 @@ export class WorkspaceBlobResolver {
})
@MakeCache(['blobs'], ['workspaceId'])
async listBlobs(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string
) {
await this.permissions.checkWorkspace(workspaceId, user.id);
@ -90,7 +88,7 @@ export class WorkspaceBlobResolver {
@Query(() => WorkspaceBlobSizes, {
deprecationReason: 'use `user.storageUsage` instead',
})
async collectAllBlobSizes(@CurrentUser() user: UserType) {
async collectAllBlobSizes(@CurrentUser() user: CurrentUser) {
const size = await this.quota.getUserUsage(user.id);
return { size };
}
@ -102,7 +100,7 @@ export class WorkspaceBlobResolver {
deprecationReason: 'no more needed',
})
async checkBlobSize(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('size', { type: () => SafeIntResolver }) blobSize: number
) {
@ -121,7 +119,7 @@ export class WorkspaceBlobResolver {
@Mutation(() => String)
@PreventCache(['blobs'], ['workspaceId'])
async setBlob(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args({ name: 'blob', type: () => GraphQLUpload })
blob: FileUpload
@ -199,7 +197,7 @@ export class WorkspaceBlobResolver {
@Mutation(() => Boolean)
@PreventCache(['blobs'], ['workspaceId'])
async deleteBlob(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('hash') name: string
) {

View File

@ -13,9 +13,8 @@ import {
import type { SnapshotHistory } from '@prisma/client';
import { CloudThrottlerGuard } from '../../../fundamentals';
import { Auth, CurrentUser } from '../../auth';
import { CurrentUser } from '../../auth';
import { DocHistoryManager } from '../../doc';
import { UserType } from '../../users';
import { DocID } from '../../utils/doc';
import { PermissionService } from '../permission';
import { Permission, WorkspaceType } from '../types';
@ -68,10 +67,9 @@ export class DocHistoryResolver {
);
}
@Auth()
@Mutation(() => Date)
async recoverDoc(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('guid') guid: string,
@Args({ name: 'timestamp', type: () => GraphQLISODateTime }) timestamp: Date

View File

@ -15,8 +15,7 @@ import {
} from '@prisma/client';
import { CloudThrottlerGuard } from '../../../fundamentals';
import { Auth, CurrentUser } from '../../auth';
import { UserType } from '../../users';
import { CurrentUser } from '../../auth';
import { DocID } from '../../utils/doc';
import { PermissionService, PublicPageMode } from '../permission';
import { Permission, WorkspaceType } from '../types';
@ -42,7 +41,6 @@ class WorkspacePage implements Partial<PrismaWorkspacePage> {
}
@UseGuards(CloudThrottlerGuard)
@Auth()
@Resolver(() => WorkspaceType)
export class PagePermissionResolver {
constructor(
@ -90,7 +88,7 @@ export class PagePermissionResolver {
deprecationReason: 'renamed to publicPage',
})
async deprecatedSharePage(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('pageId') pageId: string
) {
@ -100,7 +98,7 @@ export class PagePermissionResolver {
@Mutation(() => WorkspacePage)
async publishPage(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('pageId') pageId: string,
@Args({
@ -134,7 +132,7 @@ export class PagePermissionResolver {
deprecationReason: 'use revokePublicPage',
})
async deprecatedRevokePage(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('pageId') pageId: string
) {
@ -144,7 +142,7 @@ export class PagePermissionResolver {
@Mutation(() => WorkspacePage)
async revokePublicPage(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('pageId') pageId: string
) {

View File

@ -15,7 +15,7 @@ import {
ResolveField,
Resolver,
} from '@nestjs/graphql';
import { PrismaClient, type User } from '@prisma/client';
import { PrismaClient } from '@prisma/client';
import { getStreamAsBuffer } from 'get-stream';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import { applyUpdate, Doc } from 'yjs';
@ -27,11 +27,10 @@ import {
MailService,
Throttle,
} from '../../../fundamentals';
import { Auth, CurrentUser, Public } from '../../auth';
import { AuthService } from '../../auth/service';
import { CurrentUser, Public } from '../../auth';
import { QuotaManagementService, QuotaQueryType } from '../../quota';
import { WorkspaceBlobStorage } from '../../storage';
import { UsersService, UserType } from '../../users';
import { UserService, UserType } from '../../user';
import { PermissionService } from '../permission';
import {
InvitationType,
@ -48,18 +47,16 @@ import { defaultWorkspaceAvatar } from '../utils';
* Other rate limit: 120 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Auth()
@Resolver(() => WorkspaceType)
export class WorkspaceResolver {
private readonly logger = new Logger(WorkspaceResolver.name);
constructor(
private readonly auth: AuthService,
private readonly mailer: MailService,
private readonly prisma: PrismaClient,
private readonly permissions: PermissionService,
private readonly quota: QuotaManagementService,
private readonly users: UsersService,
private readonly users: UserService,
private readonly event: EventEmitter,
private readonly blobStorage: WorkspaceBlobStorage
) {}
@ -69,7 +66,7 @@ export class WorkspaceResolver {
complexity: 2,
})
async permission(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Parent() workspace: WorkspaceType
) {
// may applied in workspaces query
@ -160,7 +157,7 @@ export class WorkspaceResolver {
complexity: 2,
})
async isOwner(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string
) {
const data = await this.permissions.tryGetWorkspaceOwner(workspaceId);
@ -172,7 +169,7 @@ export class WorkspaceResolver {
description: 'Get all accessible workspaces for current user',
complexity: 2,
})
async workspaces(@CurrentUser() user: User) {
async workspaces(@CurrentUser() user: CurrentUser) {
const data = await this.prisma.workspaceUserPermission.findMany({
where: {
userId: user.id,
@ -216,7 +213,7 @@ export class WorkspaceResolver {
@Query(() => WorkspaceType, {
description: 'Get workspace by id',
})
async workspace(@CurrentUser() user: UserType, @Args('id') id: string) {
async workspace(@CurrentUser() user: CurrentUser, @Args('id') id: string) {
await this.permissions.checkWorkspace(id, user.id);
const workspace = await this.prisma.workspace.findUnique({ where: { id } });
@ -231,7 +228,7 @@ export class WorkspaceResolver {
description: 'Create a new workspace',
})
async createWorkspace(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
// we no longer support init workspace with a preload file
// use sync system to uploading them once created
@Args({ name: 'init', type: () => GraphQLUpload, nullable: true })
@ -289,7 +286,7 @@ export class WorkspaceResolver {
description: 'Update workspace',
})
async updateWorkspace(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args({ name: 'input', type: () => UpdateWorkspaceInput })
{ id, ...updates }: UpdateWorkspaceInput
) {
@ -304,7 +301,10 @@ export class WorkspaceResolver {
}
@Mutation(() => Boolean)
async deleteWorkspace(@CurrentUser() user: UserType, @Args('id') id: string) {
async deleteWorkspace(
@CurrentUser() user: CurrentUser,
@Args('id') id: string
) {
await this.permissions.checkWorkspace(id, user.id, Permission.Owner);
await this.prisma.workspace.delete({
@ -320,7 +320,7 @@ export class WorkspaceResolver {
@Mutation(() => String)
async invite(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('email') email: string,
@Args('permission', { type: () => Permission }) permission: Permission,
@ -358,7 +358,7 @@ export class WorkspaceResolver {
// only invite if the user is not already in the workspace
if (originRecord) return originRecord.id;
} else {
target = await this.auth.createAnonymousUser(email);
target = await this.users.createAnonymousUser(email);
}
const inviteId = await this.permissions.grant(
@ -470,7 +470,7 @@ export class WorkspaceResolver {
@Mutation(() => Boolean)
async revoke(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('userId') userId: string
) {
@ -514,7 +514,7 @@ export class WorkspaceResolver {
@Mutation(() => Boolean)
async leaveWorkspace(
@CurrentUser() user: UserType,
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('workspaceName') workspaceName: string,
@Args('sendLeaveMail', { nullable: true }) sendLeaveMail: boolean

View File

@ -11,7 +11,7 @@ import {
import type { Workspace } from '@prisma/client';
import { SafeIntResolver } from 'graphql-scalars';
import { UserType } from '../users/types';
import { UserType } from '../user/types';
export enum Permission {
Read = 0,

View File

@ -1,13 +1,15 @@
import { ModuleRef } from '@nestjs/core';
import { hash } from '@node-rs/argon2';
import { PrismaClient } from '@prisma/client';
import { Config } from '../../fundamentals';
import { UserService } from '../../core/user';
import { Config, CryptoHelper } from '../../fundamentals';
export class SelfHostAdmin1605053000403 {
// do the migration
static async up(db: PrismaClient, ref: ModuleRef) {
static async up(_db: PrismaClient, ref: ModuleRef) {
const config = ref.get(Config, { strict: false });
const crypto = ref.get(CryptoHelper, { strict: false });
const user = ref.get(UserService, { strict: false });
if (config.isSelfhosted) {
if (
!process.env.AFFINE_ADMIN_EMAIL ||
@ -17,13 +19,12 @@ export class SelfHostAdmin1605053000403 {
'You have to set AFFINE_ADMIN_EMAIL and AFFINE_ADMIN_PASSWORD environment variables to generate the initial user for self-hosted AFFiNE Server.'
);
}
await db.user.create({
data: {
await user.findOrCreateUser(process.env.AFFINE_ADMIN_EMAIL, {
name: 'AFFINE First User',
email: process.env.AFFINE_ADMIN_EMAIL,
emailVerified: new Date(),
password: await hash(process.env.AFFINE_ADMIN_PASSWORD),
},
emailVerifiedAt: new Date(),
password: await crypto.encryptPassword(
process.env.AFFINE_ADMIN_PASSWORD
),
});
}
}

View File

@ -87,6 +87,22 @@ export interface AFFiNEConfig {
sync: boolean;
};
/**
* Application secrets for authentication and data encryption
*/
secrets: {
/**
* Application public key
*
*/
publicKey: string;
/**
* Application private key
*
*/
privateKey: string;
};
/**
* Deployment environment
*/
@ -204,67 +220,32 @@ export interface AFFiNEConfig {
* authentication config
*/
auth: {
session: {
/**
* Application access token expiration time
*/
readonly accessTokenExpiresIn: number;
/**
* Application refresh token expiration time
*/
readonly refreshTokenExpiresIn: number;
/**
* Add some leeway (in seconds) to the exp and nbf validation to account for clock skew.
* Defaults to 60 if omitted.
*/
readonly leeway: number;
/**
* Application public key
* Application auth expiration time in seconds
*
* @default 15 days
*/
readonly publicKey: string;
ttl: number;
};
/**
* Application private key
* Application access token config
*/
accessToken: {
/**
* Application access token expiration time in seconds
*
* @default 7 days
*/
readonly privateKey: string;
ttl: number;
/**
* whether allow user to signup with email directly
* Application refresh token expiration time in seconds
*
* @default 30 days
*/
enableSignup: boolean;
/**
* whether allow user to signup by oauth providers
*/
enableOauth: boolean;
/**
* NEXTAUTH_SECRET
*/
nextAuthSecret: string;
/**
* all available oauth providers
*/
oauthProviders: Partial<
Record<
ExternalAccount,
{
enabled: boolean;
clientId: string;
clientSecret: string;
/**
* uri to start oauth flow
*/
authorizationUri?: string;
/**
* uri to authenticate `access_token` when user is redirected back from oauth provider with `code`
*/
accessTokenUri?: string;
/**
* uri to get user info with authenticated `access_token`
*/
userInfoUri?: string;
args?: Record<string, any>;
}
>
>;
refreshTokenTtl: number;
};
captcha: {
/**
* whether to enable captcha

View File

@ -3,7 +3,6 @@
import { createPrivateKey, createPublicKey } from 'node:crypto';
import { merge } from 'lodash-es';
import parse from 'parse-duration';
import pkg from '../../../package.json' assert { type: 'json' };
import {
@ -23,7 +22,9 @@ AwEHoUQDQgAEF3U/0wIeJ3jRKXeFKqQyBKlr9F7xaAUScRrAuSP33rajm3cdfihI
3JvMxVNsS2lE8PSGQrvDrJZaDo0L+Lq9Gg==
-----END EC PRIVATE KEY-----`;
const jwtKeyPair = (function () {
const ONE_DAY_IN_SEC = 60 * 60 * 24;
const keyPair = (function () {
const AUTH_PRIVATE_KEY = process.env.AUTH_PRIVATE_KEY ?? examplePrivateKey;
const privateKey = createPrivateKey({
key: Buffer.from(AUTH_PRIVATE_KEY),
@ -114,6 +115,10 @@ export const getDefaultAFFiNEConfig: () => AFFiNEConfig = () => {
get deploy() {
return !this.node.dev && !this.node.test;
},
secrets: {
privateKey: keyPair.privateKey,
publicKey: keyPair.publicKey,
},
featureFlags: {
earlyAccessPreview: false,
syncClientVersionCheck: false,
@ -145,11 +150,13 @@ export const getDefaultAFFiNEConfig: () => AFFiNEConfig = () => {
playground: true,
},
auth: {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
accessTokenExpiresIn: parse('1h')! / 1000,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
refreshTokenExpiresIn: parse('7d')! / 1000,
leeway: 60,
session: {
ttl: 15 * ONE_DAY_IN_SEC,
},
accessToken: {
ttl: 7 * ONE_DAY_IN_SEC,
refreshTokenTtl: 30 * ONE_DAY_IN_SEC,
},
captcha: {
enable: false,
turnstile: {
@ -159,14 +166,6 @@ export const getDefaultAFFiNEConfig: () => AFFiNEConfig = () => {
bits: 20,
},
},
privateKey: jwtKeyPair.privateKey,
publicKey: jwtKeyPair.publicKey,
enableSignup: true,
enableOauth: false,
get nextAuthSecret() {
return this.privateKey;
},
oauthProviders: {},
},
storage: getDefaultAFFiNEStorageConfig(),
rateLimiter: {
@ -188,10 +187,10 @@ export const getDefaultAFFiNEConfig: () => AFFiNEConfig = () => {
enabled: false,
},
plugins: {
enabled: [],
enabled: new Set(),
use(plugin, config) {
this[plugin] = merge(this[plugin], config || {});
this.enabled.push(plugin);
this.enabled.add(plugin);
},
},
} satisfies AFFiNEConfig;

View File

@ -0,0 +1,105 @@
import { createPrivateKey, createPublicKey } from 'node:crypto';
import { Test } from '@nestjs/testing';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { ConfigModule } from '../../config';
import { CryptoHelper } from '../crypto';
const test = ava as TestFn<{
crypto: CryptoHelper;
}>;
const key = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEtyAJLIULkphVhqXqxk4Nr8Ggty3XLwUJWBxzAWCWTMoAoGCCqGSM49
AwEHoUQDQgAEF3U/0wIeJ3jRKXeFKqQyBKlr9F7xaAUScRrAuSP33rajm3cdfihI
3JvMxVNsS2lE8PSGQrvDrJZaDo0L+Lq9Gg==
-----END EC PRIVATE KEY-----`;
const privateKey = createPrivateKey({
key,
format: 'pem',
type: 'sec1',
})
.export({
type: 'pkcs8',
format: 'pem',
})
.toString('utf8');
const publicKey = createPublicKey({
key,
format: 'pem',
type: 'spki',
})
.export({
format: 'pem',
type: 'spki',
})
.toString('utf8');
test.beforeEach(async t => {
const module = await Test.createTestingModule({
imports: [
ConfigModule.forRoot({
secrets: {
publicKey,
privateKey,
},
}),
],
providers: [CryptoHelper],
}).compile();
t.context.crypto = module.get(CryptoHelper);
});
test('should be able to sign and verify', t => {
const data = 'hello world';
const signature = t.context.crypto.sign(data);
t.true(t.context.crypto.verify(data, signature));
t.false(t.context.crypto.verify(data, 'fake-signature'));
});
test('should be able to encrypt and decrypt', t => {
const data = 'top secret';
const stub = Sinon.stub(t.context.crypto, 'randomBytes').returns(
Buffer.alloc(12, 0)
);
const encrypted = t.context.crypto.encrypt(data);
const decrypted = t.context.crypto.decrypt(encrypted);
// we are using a stub to make sure the iv is always 0,
// the encrypted result will always be the same
t.is(encrypted, 'AAAAAAAAAAAAAAAAWUDlJRhzP+SZ3avvmLcgnou+q4E11w==');
t.is(decrypted, data);
stub.restore();
});
test('should be able to get random bytes', t => {
const bytes = t.context.crypto.randomBytes();
t.is(bytes.length, 12);
const bytes2 = t.context.crypto.randomBytes();
t.notDeepEqual(bytes, bytes2);
});
test('should be able to digest', t => {
const data = 'hello world';
const hash = t.context.crypto.sha256(data).toString('base64');
t.is(hash, 'uU0nuZNNPgilLlLX2n2r+sSE7+N6U4DukIj3rOLvzek=');
});
test('should be able to safe compare', t => {
t.true(t.context.crypto.compare('abc', 'abc'));
t.false(t.context.crypto.compare('abc', 'def'));
});
test('should be able to hash and verify password', async t => {
const password = 'mySecurePassword';
const hash = await t.context.crypto.encryptPassword(password);
t.true(await t.context.crypto.verifyPassword(password, hash));
t.false(await t.context.crypto.verifyPassword('wrong-password', hash));
});

View File

@ -0,0 +1,72 @@
import { Test } from '@nestjs/testing';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { ConfigModule } from '../../config';
import { URLHelper } from '../url';
const test = ava as TestFn<{
url: URLHelper;
}>;
test.beforeEach(async t => {
const module = await Test.createTestingModule({
imports: [
ConfigModule.forRoot({
host: 'app.affine.local',
port: 3010,
https: true,
}),
],
providers: [URLHelper],
}).compile();
t.context.url = module.get(URLHelper);
});
test('can get home page', t => {
t.is(t.context.url.home, 'https://app.affine.local');
});
test('can stringify query', t => {
t.is(t.context.url.stringify({ a: 1, b: 2 }), 'a=1&b=2');
t.is(t.context.url.stringify({ a: 1, b: '/path' }), 'a=1&b=%2Fpath');
});
test('can create link', t => {
t.is(t.context.url.link('/path'), 'https://app.affine.local/path');
t.is(
t.context.url.link('/path', { a: 1, b: 2 }),
'https://app.affine.local/path?a=1&b=2'
);
t.is(
t.context.url.link('/path', { a: 1, b: '/path' }),
'https://app.affine.local/path?a=1&b=%2Fpath'
);
});
test('can safe redirect', t => {
const res = {
redirect: (to: string) => to,
} as any;
const spy = Sinon.spy(res, 'redirect');
function allow(to: string) {
t.context.url.safeRedirect(res, to);
t.true(spy.calledOnceWith(to));
spy.resetHistory();
}
function deny(to: string) {
t.context.url.safeRedirect(res, to);
t.true(spy.calledOnceWith(t.context.url.home));
spy.resetHistory();
}
[
'https://app.affine.local',
'https://app.affine.local/path',
'https://app.affine.local/path?query=1',
].forEach(allow);
['https://other.domain.com', 'a://invalid.uri'].forEach(deny);
});

View File

@ -0,0 +1,115 @@
import {
createCipheriv,
createDecipheriv,
createHash,
createSign,
createVerify,
randomBytes,
timingSafeEqual,
} from 'node:crypto';
import { Injectable } from '@nestjs/common';
import {
hash as hashPassword,
verify as verifyPassword,
} from '@node-rs/argon2';
import { Config } from '../config';
const NONCE_LENGTH = 12;
const AUTH_TAG_LENGTH = 12;
@Injectable()
export class CryptoHelper {
keyPair: {
publicKey: Buffer;
privateKey: Buffer;
sha256: {
publicKey: Buffer;
privateKey: Buffer;
};
};
constructor(config: Config) {
this.keyPair = {
publicKey: Buffer.from(config.secrets.publicKey, 'utf8'),
privateKey: Buffer.from(config.secrets.privateKey, 'utf8'),
sha256: {
publicKey: this.sha256(config.secrets.publicKey),
privateKey: this.sha256(config.secrets.privateKey),
},
};
}
sign(data: string) {
const sign = createSign('rsa-sha256');
sign.update(data, 'utf-8');
sign.end();
return sign.sign(this.keyPair.privateKey, 'base64');
}
verify(data: string, signature: string) {
const verify = createVerify('rsa-sha256');
verify.update(data, 'utf-8');
verify.end();
return verify.verify(this.keyPair.privateKey, signature, 'base64');
}
encrypt(data: string) {
const iv = this.randomBytes();
const cipher = createCipheriv(
'aes-256-gcm',
this.keyPair.sha256.privateKey,
iv,
{
authTagLength: AUTH_TAG_LENGTH,
}
);
const encrypted = Buffer.concat([
cipher.update(data, 'utf-8'),
cipher.final(),
]);
const authTag = cipher.getAuthTag();
return Buffer.concat([iv, authTag, encrypted]).toString('base64');
}
decrypt(encrypted: string) {
const buf = Buffer.from(encrypted, 'base64');
const iv = buf.subarray(0, NONCE_LENGTH);
const authTag = buf.subarray(NONCE_LENGTH, NONCE_LENGTH + AUTH_TAG_LENGTH);
const encryptedToken = buf.subarray(NONCE_LENGTH + AUTH_TAG_LENGTH);
const decipher = createDecipheriv(
'aes-256-gcm',
this.keyPair.sha256.privateKey,
iv,
{ authTagLength: AUTH_TAG_LENGTH }
);
decipher.setAuthTag(authTag);
const decrepted = decipher.update(encryptedToken, void 0, 'utf8');
return decrepted + decipher.final('utf8');
}
encryptPassword(password: string) {
return hashPassword(password);
}
verifyPassword(password: string, hash: string) {
return verifyPassword(hash, password);
}
compare(lhs: string, rhs: string) {
if (lhs.length !== rhs.length) {
return false;
}
return timingSafeEqual(Buffer.from(lhs), Buffer.from(rhs));
}
randomBytes(length = NONCE_LENGTH) {
return randomBytes(length);
}
sha256(data: string) {
return createHash('sha256').update(data).digest();
}
}

View File

@ -0,0 +1,13 @@
import { Global, Module } from '@nestjs/common';
import { CryptoHelper } from './crypto';
import { URLHelper } from './url';
@Global()
@Module({
providers: [URLHelper, CryptoHelper],
exports: [URLHelper, CryptoHelper],
})
export class HelpersModule {}
export { CryptoHelper, URLHelper };

View File

@ -0,0 +1,54 @@
import { Injectable } from '@nestjs/common';
import { type Response } from 'express';
import { Config } from '../config';
@Injectable()
export class URLHelper {
redirectAllowHosts: string[];
constructor(private readonly config: Config) {
this.redirectAllowHosts = [this.config.baseUrl];
}
get home() {
return this.config.baseUrl;
}
stringify(query: Record<string, any>) {
return new URLSearchParams(query).toString();
}
link(path: string, query: Record<string, any> = {}) {
const url = new URL(
this.config.baseUrl + (path.startsWith('/') ? path : '/' + path)
);
for (const key in query) {
url.searchParams.set(key, query[key]);
}
return url.toString();
}
safeRedirect(res: Response, to: string) {
try {
const finalTo = new URL(decodeURIComponent(to), this.config.baseUrl);
for (const host of this.redirectAllowHosts) {
const hostURL = new URL(host);
if (
hostURL.origin === finalTo.origin &&
finalTo.pathname.startsWith(hostURL.pathname)
) {
return res.redirect(finalTo.toString().replace(/\/$/, ''));
}
}
} catch {
// just ignore invalid url
}
// redirect to home if the url is invalid
return res.redirect(this.home);
}
}

View File

@ -14,6 +14,7 @@ export {
} from './config';
export * from './error';
export { EventEmitter, type EventPayload, OnEvent } from './event';
export { CryptoHelper, URLHelper } from './helpers';
export { MailService } from './mailer';
export { CallCounter, CallTimer, metrics } from './metrics';
export {
@ -21,7 +22,6 @@ export {
GlobalExceptionFilter,
OptionalModule,
} from './nestjs';
export { SessionService } from './session';
export * from './storage';
export { type StorageProvider, StorageProviderFactory } from './storage';
export { AuthThrottlerGuard, CloudThrottlerGuard, Throttle } from './throttler';

View File

@ -1,12 +1,14 @@
import { Inject, Injectable, Optional } from '@nestjs/common';
import { Config } from '../config';
import { URLHelper } from '../helpers';
import { MAILER_SERVICE, type MailerService, type Options } from './mailer';
import { emailTemplate } from './template';
@Injectable()
export class MailService {
constructor(
private readonly config: Config,
private readonly url: URLHelper,
@Optional() @Inject(MAILER_SERVICE) private readonly mailer?: MailerService
) {}
@ -41,7 +43,7 @@ export class MailService {
}
) {
// TODO: use callback url when need support desktop app
const buttonUrl = `${this.config.origin}/invite/${inviteId}`;
const buttonUrl = this.url.link(`/invite/${inviteId}`);
const workspaceAvatar = invitationInfo.workspace.avatar;
const content = `<p style="margin:0">${
@ -92,7 +94,23 @@ export class MailService {
});
}
async sendSignInEmail(url: string, options: Options) {
async sendSignUpMail(url: string, options: Options) {
const html = emailTemplate({
title: 'Create AFFiNE Account',
content:
'Click the button below to complete your account creation and sign in. This magic link will expire in 30 minutes.',
buttonContent: ' Create account and sign in',
buttonUrl: url,
});
return this.sendMail({
html,
subject: 'Your AFFiNE account is waiting for you!',
...options,
});
}
async sendSignInMail(url: string, options: Options) {
const html = emailTemplate({
title: 'Sign in to AFFiNE',
content:
@ -164,6 +182,20 @@ export class MailService {
html,
});
}
async sendVerifyEmail(to: string, url: string) {
const html = emailTemplate({
title: 'Verify your email address',
content:
'You recently requested to verify the email address associated with your AFFiNE account. To complete this process, please click on the verification link below. This magic link will expire in 30 minutes.',
buttonContent: 'Verify your email address',
buttonUrl: url,
});
return this.sendMail({
to,
subject: `Verify your email for AFFiNE`,
html,
});
}
async sendNotificationChangeEmail(to: string) {
const html = emailTemplate({
title: 'Email change successful',

View File

@ -9,7 +9,7 @@ import { omit } from 'lodash-es';
import { Config, ConfigPaths } from '../config';
interface OptionalModuleMetadata extends ModuleMetadata {
export interface OptionalModuleMetadata extends ModuleMetadata {
/**
* Only install module if given config paths are defined in AFFiNE config.
*/

View File

@ -1,44 +0,0 @@
import { Global, Injectable, Module } from '@nestjs/common';
import { SessionCache } from '../cache';
@Injectable()
export class SessionService {
private readonly prefix = 'session:';
public readonly sessionTtl = 30 * 60 * 1000; // 30 min
constructor(private readonly cache: SessionCache) {}
/**
* get session
* @param key session key
* @returns
*/
async get(key: string) {
return this.cache.get<string>(this.prefix + key);
}
/**
* set session
* @param key session key
* @param value session value
* @param sessionTtl session ttl (ms), default 30 min
* @returns return true if success
*/
async set(key: string, value?: any, sessionTtl = this.sessionTtl) {
return this.cache.set<string>(this.prefix + key, value, {
ttl: sessionTtl,
});
}
async delete(key: string) {
return this.cache.delete(this.prefix + key);
}
}
@Global()
@Module({
providers: [SessionService],
exports: [SessionService],
})
export class SessionModule {}

View File

@ -1,53 +1,8 @@
import type { ArgumentsHost, ExecutionContext } from '@nestjs/common';
import type { GqlContextType } from '@nestjs/graphql';
import { GqlArgumentsHost, GqlExecutionContext } from '@nestjs/graphql';
import { GqlArgumentsHost } from '@nestjs/graphql';
import type { Request, Response } from 'express';
export function getRequestResponseFromContext(context: ExecutionContext) {
switch (context.getType<GqlContextType>()) {
case 'graphql': {
const gqlContext = GqlExecutionContext.create(context).getContext<{
req: Request;
}>();
return {
req: gqlContext.req,
res: gqlContext.req.res,
};
}
case 'http': {
const http = context.switchToHttp();
return {
req: http.getRequest<Request>(),
res: http.getResponse<Response>(),
};
}
case 'ws': {
const ws = context.switchToWs();
const req = ws.getClient().handshake;
const cookies = req?.headers?.cookie;
// patch cookies to match auth guard logic
if (typeof cookies === 'string') {
req.cookies = cookies
.split(';')
.map(v => v.split('='))
.reduce(
(acc, v) => {
acc[decodeURIComponent(v[0].trim())] = decodeURIComponent(
v[1].trim()
);
return acc;
},
{} as Record<string, string>
);
}
return { req };
}
default:
throw new Error('Unknown context type for getting request and response');
}
}
import type { Socket } from 'socket.io';
export function getRequestResponseFromHost(host: ArgumentsHost) {
switch (host.getType<GqlContextType>()) {
@ -67,11 +22,47 @@ export function getRequestResponseFromHost(host: ArgumentsHost) {
res: http.getResponse<Response>(),
};
}
default:
throw new Error('Unknown host type for getting request and response');
case 'ws': {
const ws = host.switchToWs();
const req = ws.getClient<Socket>().client.conn.request as Request;
const cookieStr = req?.headers?.cookie;
// patch cookies to match auth guard logic
if (typeof cookieStr === 'string') {
req.cookies = cookieStr.split(';').reduce(
(cookies, cookie) => {
const [key, val] = cookie.split('=');
if (key) {
cookies[decodeURIComponent(key.trim())] = val
? decodeURIComponent(val.trim())
: val;
}
return cookies;
},
{} as Record<string, string>
);
}
return { req };
}
case 'rpc': {
const rpc = host.switchToRpc();
const { req } = rpc.getContext<{ req: Request }>();
return {
req,
res: req.res,
};
}
}
}
export function getRequestFromHost(host: ArgumentsHost) {
return getRequestResponseFromHost(host).req;
}
export function getRequestResponseFromContext(ctx: ExecutionContext) {
return getRequestResponseFromHost(ctx);
}

View File

@ -1,6 +1,6 @@
declare namespace Express {
interface Request {
user?: import('@prisma/client').User | null;
user?: import('./core/auth/current-user').CurrentUser;
}
}

View File

@ -1,4 +1,5 @@
import { GCloudConfig } from './gcloud/config';
import { OAuthConfig } from './oauth';
import { PaymentConfig } from './payment';
import { RedisOptions } from './redis';
import { R2StorageConfig, S3StorageConfig } from './storage';
@ -10,13 +11,14 @@ declare module '../fundamentals/config' {
readonly gcloud: GCloudConfig;
readonly 'cloudflare-r2': R2StorageConfig;
readonly 'aws-s3': S3StorageConfig;
readonly oauth: OAuthConfig;
}
export type AvailablePlugins = keyof PluginsConfig;
interface AFFiNEConfig {
readonly plugins: {
enabled: AvailablePlugins[];
enabled: Set<AvailablePlugins>;
use<Plugin extends AvailablePlugins>(
plugin: Plugin,
config?: DeepPartial<PluginsConfig[Plugin]>

View File

@ -1,10 +1,11 @@
import { Global } from '@nestjs/common';
import { OptionalModule } from '../../fundamentals';
import { Plugin } from '../registry';
import { GCloudMetrics } from './metrics';
@Global()
@OptionalModule({
@Plugin({
name: 'gcloud',
imports: [GCloudMetrics],
})
export class GCloudModule {}

View File

@ -1,13 +1,7 @@
import type { AvailablePlugins } from '../fundamentals/config';
import { GCloudModule } from './gcloud';
import { PaymentModule } from './payment';
import { RedisModule } from './redis';
import { AwsS3Module, CloudflareR2Module } from './storage';
import './gcloud';
import './oauth';
import './payment';
import './redis';
import './storage';
export const pluginsMap = new Map<AvailablePlugins, AFFiNEModule>([
['payment', PaymentModule],
['redis', RedisModule],
['gcloud', GCloudModule],
['cloudflare-r2', CloudflareR2Module],
['aws-s3', AwsS3Module],
]);
export { REGISTERED_PLUGINS } from './registry';

View File

@ -0,0 +1,230 @@
import {
BadRequestException,
Controller,
Get,
Query,
Req,
Res,
} from '@nestjs/common';
import { ConnectedAccount, PrismaClient } from '@prisma/client';
import type { Request, Response } from 'express';
import { AuthService, Public } from '../../core/auth';
import { UserService } from '../../core/user';
import { URLHelper } from '../../fundamentals';
import { OAuthAccount, Tokens } from './providers/def';
import { OAuthProviderFactory } from './register';
import { OAuthService } from './service';
import { OAuthProviderName } from './types';
@Controller('/oauth')
export class OAuthController {
constructor(
private readonly auth: AuthService,
private readonly oauth: OAuthService,
private readonly user: UserService,
private readonly providerFactory: OAuthProviderFactory,
private readonly url: URLHelper,
private readonly db: PrismaClient
) {}
@Public()
@Get('/login')
async login(
@Res() res: Response,
@Query('provider') unknownProviderName: string,
@Query('redirect_uri') redirectUri?: string
) {
// @ts-expect-error safe
const providerName = OAuthProviderName[unknownProviderName];
const provider = this.providerFactory.get(providerName);
if (!provider) {
throw new BadRequestException('Invalid provider');
}
const state = await this.oauth.saveOAuthState({
redirectUri: redirectUri ?? this.url.home,
provider: providerName,
});
return res.redirect(provider.getAuthUrl(state));
}
@Public()
@Get('/callback')
async callback(
@Req() req: Request,
@Res() res: Response,
@Query('code') code?: string,
@Query('state') stateStr?: string
) {
if (!code) {
throw new BadRequestException('Missing query parameter `code`');
}
if (!stateStr) {
throw new BadRequestException('Invalid callback state parameter');
}
const state = await this.oauth.getOAuthState(stateStr);
if (!state) {
throw new BadRequestException('OAuth state expired, please try again.');
}
if (!state.provider) {
throw new BadRequestException(
'Missing callback state parameter `provider`'
);
}
const provider = this.providerFactory.get(state.provider);
if (!provider) {
throw new BadRequestException('Invalid provider');
}
const tokens = await provider.getToken(code);
const externAccount = await provider.getUser(tokens.accessToken);
const user = req.user;
try {
if (!user) {
// if user not found, login
const user = await this.loginFromOauth(
state.provider,
externAccount,
tokens
);
const session = await this.auth.createUserSession(
user,
req.cookies[AuthService.sessionCookieName]
);
res.cookie(AuthService.sessionCookieName, session.sessionId, {
expires: session.expiresAt ?? void 0, // expiredAt is `string | null`
...this.auth.cookieOptions,
});
} else {
// if user is found, connect the account to this user
await this.connectAccountFromOauth(
user,
state.provider,
externAccount,
tokens
);
}
} catch (e: any) {
return res.redirect(
this.url.link('/signIn', {
redirect_uri: state.redirectUri,
error: e.message,
})
);
}
this.url.safeRedirect(res, state.redirectUri);
}
private async loginFromOauth(
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
const connectedUser = await this.db.connectedAccount.findFirst({
where: {
provider,
providerAccountId: externalAccount.id,
},
include: {
user: true,
},
});
if (connectedUser) {
// already connected
await this.updateConnectedAccount(connectedUser, tokens);
return connectedUser.user;
}
let user = await this.user.findUserByEmail(externalAccount.email);
if (user) {
// we can't directly connect the external account with given email in sign in scenario for safety concern.
// let user manually connect in account sessions instead.
throw new BadRequestException(
'The account with provided email is not register in the same way.'
);
} else {
user = await this.createUserWithConnectedAccount(
provider,
externalAccount,
tokens
);
}
return user;
}
updateConnectedAccount(connectedUser: ConnectedAccount, tokens: Tokens) {
return this.db.connectedAccount.update({
where: {
id: connectedUser.id,
},
data: tokens,
});
}
async createUserWithConnectedAccount(
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
return this.user.createUser({
email: externalAccount.email,
name: 'Unnamed',
avatarUrl: externalAccount.avatarUrl,
emailVerifiedAt: new Date(),
connectedAccounts: {
create: {
provider,
providerAccountId: externalAccount.id,
...tokens,
},
},
});
}
private async connectAccountFromOauth(
user: { id: string },
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
const connectedUser = await this.db.connectedAccount.findFirst({
where: {
provider,
providerAccountId: externalAccount.id,
},
});
if (connectedUser) {
if (connectedUser.id !== user.id) {
throw new BadRequestException(
'The third-party account has already been connected to another user.'
);
}
} else {
await this.db.connectedAccount.create({
data: {
userId: user.id,
provider,
providerAccountId: externalAccount.id,
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
},
});
}
}
}

View File

@ -0,0 +1,25 @@
import { AuthModule } from '../../core/auth';
import { ServerFeature } from '../../core/config';
import { UserModule } from '../../core/user';
import { Plugin } from '../registry';
import { OAuthController } from './controller';
import { OAuthProviders } from './providers';
import { OAuthProviderFactory } from './register';
import { OAuthResolver } from './resolver';
import { OAuthService } from './service';
@Plugin({
name: 'oauth',
imports: [AuthModule, UserModule],
providers: [
OAuthProviderFactory,
OAuthService,
OAuthResolver,
...OAuthProviders,
],
controllers: [OAuthController],
contributesTo: ServerFeature.OAuth,
if: config => !!config.plugins.oauth,
})
export class OAuthModule {}
export type { OAuthConfig } from './types';

View File

@ -0,0 +1,21 @@
import { OAuthProviderName } from '../types';
export interface OAuthAccount {
id: string;
email: string;
avatarUrl?: string;
}
export interface Tokens {
accessToken: string;
scope?: string;
refreshToken?: string;
expiresAt?: Date;
}
export abstract class OAuthProvider {
abstract provider: OAuthProviderName;
abstract getAuthUrl(state?: string): string;
abstract getToken(code: string): Promise<Tokens>;
abstract getUser(token: string): Promise<OAuthAccount>;
}

View File

@ -0,0 +1,113 @@
import { HttpException, HttpStatus, Injectable } from '@nestjs/common';
import { Config, URLHelper } from '../../../fundamentals';
import { AutoRegisteredOAuthProvider } from '../register';
import { OAuthProviderName } from '../types';
interface AuthTokenResponse {
access_token: string;
scope: string;
token_type: string;
}
export interface UserInfo {
login: string;
email: string;
avatar_url: string;
name: string;
}
@Injectable()
export class GithubOAuthProvider extends AutoRegisteredOAuthProvider {
provider = OAuthProviderName.GitHub;
constructor(
protected readonly AFFiNEConfig: Config,
private readonly url: URLHelper
) {
super();
}
getAuthUrl(state: string) {
return `https://github.com/login/oauth/authorize?${this.url.stringify({
client_id: this.config.clientId,
redirect_uri: this.url.link('/oauth/callback'),
scope: 'user',
...this.config.args,
state,
})}`;
}
async getToken(code: string) {
try {
const response = await fetch(
'https://github.com/login/oauth/access_token',
{
method: 'POST',
body: this.url.stringify({
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: this.url.link('/oauth/callback'),
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
}
);
if (response.ok) {
const ghToken = (await response.json()) as AuthTokenResponse;
return {
accessToken: ghToken.access_token,
scope: ghToken.scope,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
}, ${JSON.stringify(await response.json())}`
);
}
} catch (e) {
throw new HttpException(
`Failed to get access_token, err: ${(e as Error).message}`,
HttpStatus.BAD_REQUEST
);
}
}
async getUser(token: string) {
try {
const response = await fetch('https://api.github.com/user', {
method: 'GET',
headers: {
Authorization: `Bearer ${token}`,
},
});
if (response.ok) {
const user = (await response.json()) as UserInfo;
return {
id: user.login,
avatarUrl: user.avatar_url,
email: user.email,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
} ${await response.text()}`
);
}
} catch (e) {
throw new HttpException(
`Failed to get user information, err: ${(e as Error).stack}`,
HttpStatus.BAD_REQUEST
);
}
}
}

View File

@ -0,0 +1,121 @@
import { HttpException, HttpStatus, Injectable } from '@nestjs/common';
import { Config, URLHelper } from '../../../fundamentals';
import { AutoRegisteredOAuthProvider } from '../register';
import { OAuthProviderName } from '../types';
interface GoogleOAuthTokenResponse {
access_token: string;
expires_in: number;
refresh_token: string;
scope: string;
token_type: string;
}
export interface UserInfo {
id: string;
email: string;
picture: string;
name: string;
}
@Injectable()
export class GoogleOAuthProvider extends AutoRegisteredOAuthProvider {
override provider = OAuthProviderName.Google;
constructor(
protected readonly AFFiNEConfig: Config,
private readonly url: URLHelper
) {
super();
}
getAuthUrl(state: string) {
return `https://accounts.google.com/o/oauth2/v2/auth?${this.url.stringify({
client_id: this.config.clientId,
redirect_uri: this.url.link('/oauth/callback'),
response_type: 'code',
scope: 'openid email profile',
promot: 'select_account',
access_type: 'offline',
...this.config.args,
state,
})}`;
}
async getToken(code: string) {
try {
const response = await fetch('https://oauth2.googleapis.com/token', {
method: 'POST',
body: this.url.stringify({
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: this.url.link('/oauth/callback'),
grant_type: 'authorization_code',
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
});
if (response.ok) {
const ghToken = (await response.json()) as GoogleOAuthTokenResponse;
return {
accessToken: ghToken.access_token,
refreshToken: ghToken.refresh_token,
expiresAt: new Date(Date.now() + ghToken.expires_in * 1000),
scope: ghToken.scope,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
}, ${JSON.stringify(await response.json())}`
);
}
} catch (e) {
throw new HttpException(
`Failed to get access_token, err: ${(e as Error).message}`,
HttpStatus.BAD_REQUEST
);
}
}
async getUser(token: string) {
try {
const response = await fetch(
'https://www.googleapis.com/oauth2/v2/userinfo',
{
method: 'GET',
headers: {
Authorization: `Bearer ${token}`,
},
}
);
if (response.ok) {
const user = (await response.json()) as UserInfo;
return {
id: user.id,
avatarUrl: user.picture,
email: user.email,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
} ${await response.text()}`
);
}
} catch (e) {
throw new HttpException(
`Failed to get user information, err: ${(e as Error).stack}`,
HttpStatus.BAD_REQUEST
);
}
}
}

View File

@ -0,0 +1,4 @@
import { GithubOAuthProvider } from './github';
import { GoogleOAuthProvider } from './google';
export const OAuthProviders = [GoogleOAuthProvider, GithubOAuthProvider];

View File

@ -0,0 +1,58 @@
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
import { Config } from '../../fundamentals';
import { OAuthProvider } from './providers/def';
import { OAuthProviderName } from './types';
const PROVIDERS: Map<OAuthProviderName, OAuthProvider> = new Map();
export function registerOAuthProvider(
name: OAuthProviderName,
provider: OAuthProvider
) {
PROVIDERS.set(name, provider);
}
@Injectable()
export class OAuthProviderFactory {
get providers() {
return PROVIDERS.keys();
}
get(name: OAuthProviderName): OAuthProvider | undefined {
return PROVIDERS.get(name);
}
}
export abstract class AutoRegisteredOAuthProvider
extends OAuthProvider
implements OnModuleInit
{
protected abstract AFFiNEConfig: Config;
get optionalConfig() {
return this.AFFiNEConfig.plugins.oauth?.providers?.[this.provider];
}
get config() {
const config = this.optionalConfig;
if (!config) {
throw new Error(
`OAuthProvider Config should not be used before registered`
);
}
return config;
}
onModuleInit() {
const config = this.optionalConfig;
if (config && config.clientId && config.clientSecret) {
registerOAuthProvider(this.provider, this);
new Logger(`OAuthProvider:${this.provider}`).log(
'OAuth provider registered.'
);
}
}
}

View File

@ -0,0 +1,17 @@
import { registerEnumType, ResolveField, Resolver } from '@nestjs/graphql';
import { ServerConfigType } from '../../core/config';
import { OAuthProviderFactory } from './register';
import { OAuthProviderName } from './types';
registerEnumType(OAuthProviderName, { name: 'OAuthProviderType' });
@Resolver(() => ServerConfigType)
export class OAuthResolver {
constructor(private readonly factory: OAuthProviderFactory) {}
@ResolveField(() => [OAuthProviderName])
oauthProviders() {
return this.factory.providers;
}
}

View File

@ -0,0 +1,39 @@
import { randomUUID } from 'node:crypto';
import { Injectable } from '@nestjs/common';
import { SessionCache } from '../../fundamentals';
import { OAuthProviderFactory } from './register';
import { OAuthProviderName } from './types';
const OAUTH_STATE_KEY = 'OAUTH_STATE';
interface OAuthState {
redirectUri: string;
provider: OAuthProviderName;
}
@Injectable()
export class OAuthService {
constructor(
private readonly providerFactory: OAuthProviderFactory,
private readonly cache: SessionCache
) {}
async saveOAuthState(state: OAuthState) {
const token = randomUUID();
await this.cache.set(`${OAUTH_STATE_KEY}:${token}`, state, {
ttl: 3600 * 3 * 1000 /* 3 hours */,
});
return token;
}
async getOAuthState(token: string) {
return this.cache.get<OAuthState>(`${OAUTH_STATE_KEY}:${token}`);
}
availableOAuthProviders() {
return this.providerFactory.providers;
}
}

View File

@ -0,0 +1,15 @@
export interface OAuthProviderConfig {
clientId: string;
clientSecret: string;
args?: Record<string, string>;
}
export enum OAuthProviderName {
Google = 'google',
GitHub = 'github',
}
export interface OAuthConfig {
enabled: boolean;
providers: Partial<{ [key in OAuthProviderName]: OAuthProviderConfig }>;
}

View File

@ -1,13 +1,14 @@
import { ServerFeature } from '../../core/config';
import { FeatureModule } from '../../core/features';
import { OptionalModule } from '../../fundamentals';
import { Plugin } from '../registry';
import { SubscriptionResolver, UserSubscriptionResolver } from './resolver';
import { ScheduleManager } from './schedule';
import { SubscriptionService } from './service';
import { StripeProvider } from './stripe';
import { StripeWebhook } from './webhook';
@OptionalModule({
@Plugin({
name: 'payment',
imports: [FeatureModule],
providers: [
ScheduleManager,

View File

@ -21,8 +21,8 @@ import type { User, UserInvoice, UserSubscription } from '@prisma/client';
import { PrismaClient } from '@prisma/client';
import { groupBy } from 'lodash-es';
import { Auth, CurrentUser, Public } from '../../core/auth';
import { UserType } from '../../core/users';
import { CurrentUser, Public } from '../../core/auth';
import { UserType } from '../../core/user';
import { Config } from '../../fundamentals';
import { decodeLookupKey, SubscriptionService } from './service';
import {
@ -155,7 +155,6 @@ class CreateCheckoutSessionInput {
idempotencyKey!: string;
}
@Auth()
@Resolver(() => UserSubscriptionType)
export class SubscriptionResolver {
constructor(
@ -217,7 +216,7 @@ export class SubscriptionResolver {
description: 'Create a subscription checkout link of stripe',
})
async checkout(
@CurrentUser() user: User,
@CurrentUser() user: CurrentUser,
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
recurring: SubscriptionRecurring,
@Args('idempotencyKey') idempotencyKey: string
@ -241,7 +240,7 @@ export class SubscriptionResolver {
description: 'Create a subscription checkout link of stripe',
})
async createCheckoutSession(
@CurrentUser() user: User,
@CurrentUser() user: CurrentUser,
@Args({ name: 'input', type: () => CreateCheckoutSessionInput })
input: CreateCheckoutSessionInput
) {
@ -265,13 +264,13 @@ export class SubscriptionResolver {
@Mutation(() => String, {
description: 'Create a stripe customer portal to manage payment methods',
})
async createCustomerPortal(@CurrentUser() user: User) {
async createCustomerPortal(@CurrentUser() user: CurrentUser) {
return this.service.createCustomerPortal(user.id);
}
@Mutation(() => UserSubscriptionType)
async cancelSubscription(
@CurrentUser() user: User,
@CurrentUser() user: CurrentUser,
@Args('idempotencyKey') idempotencyKey: string
) {
return this.service.cancelSubscription(idempotencyKey, user.id);
@ -279,7 +278,7 @@ export class SubscriptionResolver {
@Mutation(() => UserSubscriptionType)
async resumeSubscription(
@CurrentUser() user: User,
@CurrentUser() user: CurrentUser,
@Args('idempotencyKey') idempotencyKey: string
) {
return this.service.resumeCanceledSubscription(idempotencyKey, user.id);
@ -287,7 +286,7 @@ export class SubscriptionResolver {
@Mutation(() => UserSubscriptionType)
async updateSubscriptionRecurring(
@CurrentUser() user: User,
@CurrentUser() user: CurrentUser,
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
recurring: SubscriptionRecurring,
@Args('idempotencyKey') idempotencyKey: string

View File

@ -10,6 +10,7 @@ import type {
import { PrismaClient } from '@prisma/client';
import Stripe from 'stripe';
import { CurrentUser } from '../../core/auth';
import { FeatureManagementService } from '../../core/features';
import { EventEmitter } from '../../fundamentals';
import { ScheduleManager } from './schedule';
@ -75,7 +76,7 @@ export class SubscriptionService {
redirectUrl,
idempotencyKey,
}: {
user: User;
user: CurrentUser;
recurring: SubscriptionRecurring;
plan: SubscriptionPlan;
promotionCode?: string | null;
@ -549,7 +550,7 @@ export class SubscriptionService {
private async getOrCreateCustomer(
idempotencyKey: string,
user: User
user: CurrentUser
): Promise<UserStripeCustomer> {
const customer = await this.db.userStripeCustomer.findUnique({
where: {
@ -649,7 +650,7 @@ export class SubscriptionService {
}
private async getAvailableCoupon(
user: User,
user: CurrentUser,
couponType: CouponType
): Promise<string | null> {
const earlyAccess = await this.features.isEarlyAccessUser(user.email);

View File

@ -2,9 +2,10 @@ import { Global, Provider, Type } from '@nestjs/common';
import { Redis, type RedisOptions } from 'ioredis';
import { ThrottlerStorageRedisService } from 'nestjs-throttler-storage-redis';
import { Cache, OptionalModule, SessionCache } from '../../fundamentals';
import { Cache, SessionCache } from '../../fundamentals';
import { ThrottlerStorage } from '../../fundamentals/throttler';
import { SocketIoAdapterImpl } from '../../fundamentals/websocket';
import { Plugin } from '../registry';
import { RedisCache } from './cache';
import {
CacheRedis,
@ -47,7 +48,8 @@ const socketIoRedisAdapterProvider: Provider = {
};
@Global()
@OptionalModule({
@Plugin({
name: 'redis',
providers: [CacheRedis, SessionRedis, ThrottlerRedis, SocketIoRedis],
overrides: [
cacheProvider,

View File

@ -0,0 +1,22 @@
import { omit } from 'lodash-es';
import { AvailablePlugins } from '../fundamentals/config';
import { OptionalModule, OptionalModuleMetadata } from '../fundamentals/nestjs';
export const REGISTERED_PLUGINS = new Map<AvailablePlugins, AFFiNEModule>();
function register(plugin: AvailablePlugins, module: AFFiNEModule) {
REGISTERED_PLUGINS.set(plugin, module);
}
interface PluginModuleMetadata extends OptionalModuleMetadata {
name: AvailablePlugins;
}
export const Plugin = (options: PluginModuleMetadata) => {
return (target: any) => {
register(options.name, target);
return OptionalModule(omit(options, 'name'))(target);
};
};

View File

@ -1,5 +1,5 @@
import { OptionalModule } from '../../fundamentals';
import { registerStorageProvider } from '../../fundamentals/storage';
import { Plugin } from '../registry';
import { R2StorageProvider } from './providers/r2';
import { S3StorageProvider } from './providers/s3';
@ -18,7 +18,8 @@ registerStorageProvider('aws-s3', (config, bucket) => {
return new S3StorageProvider(config.plugins['aws-s3'], bucket);
});
@OptionalModule({
@Plugin({
name: 'cloudflare-r2',
requires: [
'plugins.cloudflare-r2.accountId',
'plugins.cloudflare-r2.credentials.accessKeyId',
@ -28,7 +29,8 @@ registerStorageProvider('aws-s3', (config, bucket) => {
})
export class CloudflareR2Module {}
@OptionalModule({
@Plugin({
name: 'aws-s3',
requires: [
'plugins.aws-s3.credentials.accessKeyId',
'plugins.aws-s3.credentials.secretAccessKey',

View File

@ -67,14 +67,14 @@ type InviteUserType {
"""User avatar url"""
avatarUrl: String
"""User created date"""
createdAt: DateTime
"""User email verified"""
createdAt: DateTime @deprecated(reason: "useless")
"""User email"""
email: String
"""User email verified"""
emailVerified: DateTime
emailVerified: Boolean
"""User password has been set"""
hasPassword: Boolean
@ -111,7 +111,7 @@ type Mutation {
addToEarlyAccess(email: String!): Int!
addWorkspaceFeature(feature: FeatureType!, workspaceId: String!): Int!
cancelSubscription(idempotencyKey: String!): UserSubscription!
changeEmail(token: String!): UserType!
changeEmail(email: String!, token: String!): UserType!
changePassword(newPassword: String!, token: String!): UserType!
"""Create a subscription checkout link of stripe"""
@ -141,15 +141,17 @@ type Mutation {
revoke(userId: String!, workspaceId: String!): Boolean!
revokePage(pageId: String!, workspaceId: String!): Boolean! @deprecated(reason: "use revokePublicPage")
revokePublicPage(pageId: String!, workspaceId: String!): WorkspacePage!
sendChangeEmail(callbackUrl: String!, email: String!): Boolean!
sendChangePasswordEmail(callbackUrl: String!, email: String!): Boolean!
sendSetPasswordEmail(callbackUrl: String!, email: String!): Boolean!
sendChangeEmail(callbackUrl: String!, email: String): Boolean!
sendChangePasswordEmail(callbackUrl: String!, email: String): Boolean!
sendSetPasswordEmail(callbackUrl: String!, email: String): Boolean!
sendVerifyChangeEmail(callbackUrl: String!, email: String!, token: String!): Boolean!
sendVerifyEmail(callbackUrl: String!): Boolean!
setBlob(blob: Upload!, workspaceId: String!): String!
setWorkspaceExperimentalFeature(enable: Boolean!, feature: FeatureType!, workspaceId: String!): Boolean!
sharePage(pageId: String!, workspaceId: String!): Boolean! @deprecated(reason: "renamed to publicPage")
signIn(email: String!, password: String!): UserType!
signUp(email: String!, name: String!, password: String!): UserType!
updateProfile(input: UpdateUserInput!): UserType!
updateSubscriptionRecurring(idempotencyKey: String!, recurring: SubscriptionRecurring!): UserSubscription!
"""Update workspace"""
@ -157,6 +159,12 @@ type Mutation {
"""Upload user avatar"""
uploadAvatar(avatar: Upload!): UserType!
verifyEmail(token: String!): Boolean!
}
enum OAuthProviderType {
GitHub
Google
}
"""User permission in workspace"""
@ -239,6 +247,7 @@ type ServerConfigType {
"""server identical name could be shown as badge on user interface"""
name: String!
oauthProviders: [OAuthProviderType!]!
"""server type"""
type: ServerDeploymentType!
@ -253,6 +262,7 @@ enum ServerDeploymentType {
}
enum ServerFeature {
OAuth
Payment
}
@ -288,10 +298,9 @@ enum SubscriptionStatus {
Unpaid
}
type TokenType {
refresh: String!
sessionToken: String
token: String!
input UpdateUserInput {
"""User name"""
name: String
}
input UpdateWorkspaceInput {
@ -356,14 +365,14 @@ type UserType {
"""User avatar url"""
avatarUrl: String
"""User created date"""
createdAt: DateTime
"""User email verified"""
createdAt: DateTime @deprecated(reason: "useless")
"""User email"""
email: String!
"""User email verified"""
emailVerified: DateTime
emailVerified: Boolean!
"""User password has been set"""
hasPassword: Boolean
@ -377,7 +386,7 @@ type UserType {
name: String!
quota: UserQuota
subscription: UserSubscription
token: TokenType!
token: tokenType! @deprecated(reason: "use [/api/auth/authorize]")
}
type WorkspaceBlobSizes {
@ -433,3 +442,9 @@ type WorkspaceType {
"""Shared pages of workspace"""
sharedPages: [String!]! @deprecated(reason: "use WorkspaceType.publicPages")
}
type tokenType {
refresh: String!
sessionToken: String
token: String!
}

View File

@ -1,16 +1,8 @@
import { ok } from 'node:assert';
import { randomUUID } from 'node:crypto';
import { Transformer } from '@napi-rs/image';
import type { INestApplication } from '@nestjs/common';
import { hashSync } from '@node-rs/argon2';
import { PrismaClient, type User } from '@prisma/client';
import ava, { type TestFn } from 'ava';
import type { Express } from 'express';
import request from 'supertest';
import { AppModule } from '../src/app.module';
import { FeatureManagementService } from '../src/core/features';
import { createTestingApp } from './utils';
const gql = '/graphql';
@ -19,43 +11,9 @@ const test = ava as TestFn<{
app: INestApplication;
}>;
class FakePrisma {
fakeUser: User = {
id: randomUUID(),
name: 'Alex Yang',
avatarUrl: '',
email: 'alex.yang@example.org',
password: hashSync('123456'),
emailVerified: new Date(),
createdAt: new Date(),
};
get user() {
// eslint-disable-next-line @typescript-eslint/no-this-alias
const prisma = this;
return {
async findFirst() {
return prisma.fakeUser;
},
async findUnique() {
return this.findFirst();
},
async update() {
return this.findFirst();
},
};
}
}
test.beforeEach(async t => {
const { app } = await createTestingApp({
imports: [AppModule],
tapModule(builder) {
builder
.overrideProvider(PrismaClient)
.useClass(FakePrisma)
.overrideProvider(FeatureManagementService)
.useValue({ canEarlyAccess: () => true });
},
});
t.context.app = app;
@ -66,7 +24,6 @@ test.afterEach.always(async t => {
});
test('should init app', async t => {
t.is(typeof t.context.app, 'object');
await request(t.context.app.getHttpServer())
.post(gql)
.send({
@ -78,130 +35,22 @@ test('should init app', async t => {
})
.expect(400);
const { token } = await createToken(t.context.app);
await request(t.context.app.getHttpServer())
const response = await request(t.context.app.getHttpServer())
.post(gql)
.auth(token, { type: 'bearer' })
.send({
query: `
query {
__typename
}
`,
})
.expect(200)
.expect(res => {
t.is(res.body.data.__typename, 'Query');
});
});
test('should find default user', async t => {
const { token } = await createToken(t.context.app);
await request(t.context.app.getHttpServer())
.post(gql)
.auth(token, { type: 'bearer' })
.send({
query: `
query {
user(email: "alex.yang@example.org") {
... on UserType {
email
}
... on LimitedUserType {
email
}
}
}
`,
})
.expect(200)
.expect(res => {
t.is(res.body.data.user.email, 'alex.yang@example.org');
});
});
test('should be able to upload avatar and remove it', async t => {
const { token, id } = await createToken(t.context.app);
const png = await Transformer.fromRgbaPixels(
Buffer.alloc(400 * 400 * 4).fill(255),
400,
400
).png();
await request(t.context.app.getHttpServer())
.post(gql)
.auth(token, { type: 'bearer' })
.field(
'operations',
JSON.stringify({
name: 'uploadAvatar',
query: `mutation uploadAvatar($avatar: Upload!) {
uploadAvatar(avatar: $avatar) {
id
query: `query {
serverConfig {
name
avatarUrl
email
version
type
features
}
}
`,
variables: { id, avatar: null },
}`,
})
)
.field('map', JSON.stringify({ '0': ['variables.avatar'] }))
.attach('0', png, 'avatar.png')
.expect(200)
.expect(res => {
t.is(res.body.data.uploadAvatar.id, id);
});
.expect(200);
await request(t.context.app.getHttpServer())
.post(gql)
.auth(token, { type: 'bearer' })
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
.send({
query: `
mutation removeAvatar {
removeAvatar {
success
}
}
`,
})
.expect(200)
.expect(res => {
t.is(res.body.data.removeAvatar.success, true);
});
const config = response.body.data.serverConfig;
t.is(config.type, 'Affine');
t.true(Array.isArray(config.features));
});
async function createToken(app: INestApplication<Express>): Promise<{
id: string;
token: string;
}> {
let token;
let id;
await request(app.getHttpServer())
.post(gql)
.send({
query: `
mutation {
signIn(email: "alex.yang@example.org", password: "123456") {
id
token {
token
}
}
}
`,
})
.expect(200)
.expect(res => {
id = res.body.data.signIn.id;
ok(
typeof res.body.data.signIn.token.token === 'string',
'res.body.data.signIn.token.token is not a string'
);
token = res.body.data.signIn.token.token;
});
return { token: token!, id: id! };
}

View File

@ -39,7 +39,7 @@ test('change email', async t => {
if (mail.hasConfigured()) {
const u1Email = 'u1@affine.pro';
const u2Email = 'u2@affine.pro';
const tokenRegex = /token=3D([^"&\s]+)/;
const tokenRegex = /token=3D([^"&]+)/;
const u1 = await signUp(app, 'u1', u1Email, '1');
@ -57,7 +57,7 @@ test('change email', async t => {
const changeTokenMatch = changeEmailContent.Content.Body.match(tokenRegex);
const changeEmailToken = changeTokenMatch
? decodeURIComponent(changeTokenMatch[1].replace(/=3D/g, '='))
? decodeURIComponent(changeTokenMatch[1].replace(/=\r\n/, ''))
: null;
t.not(
@ -85,7 +85,7 @@ test('change email', async t => {
const verifyTokenMatch = verifyEmailContent.Content.Body.match(tokenRegex);
const verifyEmailToken = verifyTokenMatch
? decodeURIComponent(verifyTokenMatch[1].replace(/=3D/g, '='))
? decodeURIComponent(verifyTokenMatch[1].replace(/=\r\n/, ''))
: null;
t.not(
@ -94,7 +94,7 @@ test('change email', async t => {
'fail to get verify change email token from email content'
);
await changeEmail(app, u1.token.token, verifyEmailToken as string);
await changeEmail(app, u1.token.token, verifyEmailToken as string, u2Email);
const afterNotificationMailCount = await getCurrentMailMessageCount();

View File

@ -1,172 +0,0 @@
/// <reference types="../src/global.d.ts" />
import { TestingModule } from '@nestjs/testing';
import test from 'ava';
import { AuthResolver } from '../src/core/auth/resolver';
import { AuthService } from '../src/core/auth/service';
import { ConfigModule } from '../src/fundamentals/config';
import {
mintChallengeResponse,
verifyChallengeResponse,
} from '../src/fundamentals/storage';
import { createTestingModule } from './utils';
let authService: AuthService;
let authResolver: AuthResolver;
let module: TestingModule;
test.beforeEach(async () => {
module = await createTestingModule({
imports: [
ConfigModule.forRoot({
auth: {
accessTokenExpiresIn: 1,
refreshTokenExpiresIn: 1,
leeway: 1,
},
host: 'example.org',
https: true,
}),
],
});
authService = module.get(AuthService);
authResolver = module.get(AuthResolver);
});
test.afterEach.always(async () => {
await module.close();
});
test('should be able to register and signIn', async t => {
await authService.signUp('Alex Yang', 'alexyang@example.org', '123456');
await authService.signIn('alexyang@example.org', '123456');
t.pass();
});
test('should be able to verify', async t => {
await authService.signUp('Alex Yang', 'alexyang@example.org', '123456');
await authService.signIn('alexyang@example.org', '123456');
const date = new Date();
const user = {
id: '1',
name: 'Alex Yang',
email: 'alexyang@example.org',
emailVerified: date,
createdAt: date,
avatarUrl: '',
};
{
const token = await authService.sign(user);
const claim = await authService.verify(token);
t.is(claim.id, '1');
t.is(claim.name, 'Alex Yang');
t.is(claim.email, 'alexyang@example.org');
t.is(claim.emailVerified?.toISOString(), date.toISOString());
t.is(claim.createdAt.toISOString(), date.toISOString());
}
{
const token = await authService.refresh(user);
const claim = await authService.verify(token);
t.is(claim.id, '1');
t.is(claim.name, 'Alex Yang');
t.is(claim.email, 'alexyang@example.org');
t.is(claim.emailVerified?.toISOString(), date.toISOString());
t.is(claim.createdAt.toISOString(), date.toISOString());
}
});
test('should not be able to return token if user is invalid', async t => {
const date = new Date();
const user = {
id: '1',
name: 'Alex Yang',
email: 'alexyang@example.org',
emailVerified: date,
createdAt: date,
avatarUrl: '',
};
const anotherUser = {
id: '2',
name: 'Alex Yang 2',
email: 'alexyang@example.org',
emailVerified: date,
createdAt: date,
avatarUrl: '',
};
await t.throwsAsync(
authResolver.token(
{
req: {
headers: {
referer: 'https://example.org',
host: 'example.org',
},
} as any,
},
user,
anotherUser
),
{
message: 'Invalid user',
}
);
});
test('should not return sessionToken if request headers is invalid', async t => {
const date = new Date();
const user = {
id: '1',
name: 'Alex Yang',
email: 'alexyang@example.org',
emailVerified: date,
createdAt: date,
avatarUrl: '',
};
const result = await authResolver.token(
{
req: {
headers: {},
} as any,
},
user,
user
);
t.is(result.sessionToken, undefined);
});
test('should return valid sessionToken if request headers valid', async t => {
const date = new Date();
const user = {
id: '1',
name: 'Alex Yang',
email: 'alexyang@example.org',
emailVerified: date,
createdAt: date,
avatarUrl: '',
};
const result = await authResolver.token(
{
req: {
headers: {
referer: 'https://example.org/open-app/test',
host: 'example.org',
},
cookies: {
'next-auth.session-token': '123456',
},
} as any,
},
user,
user
);
t.is(result.sessionToken, '123456');
});
test('verify challenge', async t => {
const resource = 'xp8D3rcXV9bMhWrb6abxl';
const response = await mintChallengeResponse(resource, 20);
const success = await verifyChallengeResponse(response, 20, resource);
t.true(success);
});

View File

@ -11,7 +11,7 @@ import {
FeatureService,
FeatureType,
} from '../src/core/features';
import { UserType } from '../src/core/users/types';
import { UserType } from '../src/core/user/types';
import { WorkspaceResolver } from '../src/core/workspaces/resolvers';
import { Permission } from '../src/core/workspaces/types';
import { ConfigModule } from '../src/fundamentals/config';
@ -54,11 +54,6 @@ test.beforeEach(async t => {
const { app } = await createTestingApp({
imports: [
ConfigModule.forRoot({
auth: {
accessTokenExpiresIn: 1,
refreshTokenExpiresIn: 1,
leeway: 1,
},
host: 'example.org',
https: true,
featureFlags: {

View File

@ -21,15 +21,7 @@ const test = ava as TestFn<{
test.beforeEach(async t => {
t.context.module = await createTestingModule({
imports: [
ConfigModule.forRoot({
auth: {
accessTokenExpiresIn: 1,
refreshTokenExpiresIn: 1,
leeway: 1,
},
}),
],
imports: [ConfigModule.forRoot({})],
});
t.context.auth = t.context.module.get(AuthService);
});

View File

@ -1,40 +0,0 @@
/// <reference types="../src/global.d.ts" />
import { TestingModule } from '@nestjs/testing';
import ava, { type TestFn } from 'ava';
import { CacheModule } from '../src/fundamentals/cache';
import { SessionModule, SessionService } from '../src/fundamentals/session';
import { createTestingModule } from './utils';
const test = ava as TestFn<{
session: SessionService;
module: TestingModule;
}>;
test.beforeEach(async t => {
const module = await createTestingModule({
imports: [CacheModule, SessionModule],
});
const session = module.get(SessionService);
t.context.module = module;
t.context.session = session;
});
test.afterEach.always(async t => {
await t.context.module.close();
});
test('should be able to set session', async t => {
const { session } = t.context;
await session.set('test', 'value');
t.is(await session.get('test'), 'value');
});
test('should be expired by ttl', async t => {
const { session } = t.context;
await session.set('test', 'value', 100);
t.is(await session.get('test'), 'value');
await new Promise(resolve => setTimeout(resolve, 500));
t.is(await session.get('test'), undefined);
});

View File

@ -1,16 +1,18 @@
import type { INestApplication } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import request from 'supertest';
import type { TokenType } from '../../src/core/auth';
import type { UserType } from '../../src/core/users';
import type { ClientTokenType } from '../../src/core/auth';
import type { UserType } from '../../src/core/user';
import { gql } from './common';
export async function signUp(
app: INestApplication,
name: string,
email: string,
password: string
): Promise<UserType & { token: TokenType }> {
password: string,
autoVerifyEmail = true
): Promise<UserType & { token: ClientTokenType }> {
const res = await request(app.getHttpServer())
.post(gql)
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
@ -24,9 +26,23 @@ export async function signUp(
`,
})
.expect(200);
if (autoVerifyEmail) {
await setEmailVerified(app, email);
}
return res.body.data.signUp;
}
async function setEmailVerified(app: INestApplication, email: string) {
await app.get(PrismaClient).user.update({
where: { email },
data: {
emailVerifiedAt: new Date(),
},
});
}
export async function currentUser(app: INestApplication, token: string) {
const res = await request(app.getHttpServer())
.post(gql)
@ -36,7 +52,7 @@ export async function currentUser(app: INestApplication, token: string) {
query: `
query {
currentUser {
id, name, email, emailVerified, avatarUrl, createdAt, hasPassword,
id, name, email, emailVerified, avatarUrl, hasPassword,
token { token }
}
}
@ -94,8 +110,9 @@ export async function sendVerifyChangeEmail(
export async function changeEmail(
app: INestApplication,
userToken: string,
token: string
): Promise<UserType & { token: TokenType }> {
token: string,
email: string
): Promise<UserType & { token: ClientTokenType }> {
const res = await request(app.getHttpServer())
.post(gql)
.auth(userToken, { type: 'bearer' })
@ -103,7 +120,7 @@ export async function changeEmail(
.send({
query: `
mutation {
changeEmail(token: "${token}") {
changeEmail(token: "${token}", email: "${email}") {
id
name
avatarUrl

View File

@ -1,11 +1,13 @@
import { INestApplication, ModuleMetadata } from '@nestjs/common';
import { APP_GUARD } from '@nestjs/core';
import { Query, Resolver } from '@nestjs/graphql';
import { Test, TestingModuleBuilder } from '@nestjs/testing';
import { PrismaClient } from '@prisma/client';
import cookieParser from 'cookie-parser';
import graphqlUploadExpress from 'graphql-upload/graphqlUploadExpress.mjs';
import { AppModule, FunctionalityModules } from '../../src/app.module';
import { AuthModule } from '../../src/core/auth';
import { AuthGuard, AuthModule } from '../../src/core/auth';
import { UserFeaturesInit1698652531198 } from '../../src/data/migrations/1698652531198-user-features-init';
import { GqlModule } from '../../src/fundamentals/graphql';
@ -78,7 +80,14 @@ export async function createTestingModule(
const builder = Test.createTestingModule({
imports,
providers: [MockResolver, ...(moduleDef.providers ?? [])],
providers: [
{
provide: APP_GUARD,
useClass: AuthGuard,
},
MockResolver,
...(moduleDef.providers ?? []),
],
controllers: moduleDef.controllers,
});
@ -113,6 +122,8 @@ export async function createTestingApp(moduleDef: TestingModuleMeatdata = {}) {
})
);
app.use(cookieParser());
if (moduleDef.tapApp) {
moduleDef.tapApp(app);
}

View File

@ -0,0 +1,22 @@
import { useAFFiNEI18N } from '@affine/i18n/hooks';
import type { FC } from 'react';
import { Button } from '../../ui/button';
import { AuthPageContainer } from './auth-page-container';
export const ConfirmChangeEmail: FC<{
onOpenAffine: () => void;
}> = ({ onOpenAffine }) => {
const t = useAFFiNEI18N();
return (
<AuthPageContainer
title={t['com.affine.auth.change.email.page.success.title']()}
subtitle={t['com.affine.auth.change.email.page.success.subtitle']()}
>
<Button type="primary" size="large" onClick={onOpenAffine}>
{t['com.affine.auth.open.affine']()}
</Button>
</AuthPageContainer>
);
};

View File

@ -37,7 +37,7 @@ function getCallbackUrl(location: Location) {
try {
const url =
location.state?.callbackURL ||
new URLSearchParams(location.search).get('callbackUrl');
new URLSearchParams(location.search).get('redirect_uri');
if (typeof url === 'string' && url) {
if (!url.startsWith('http:') && !url.startsWith('https:')) {
return url;

View File

@ -3,4 +3,5 @@ export interface User {
name: string;
email: string;
image?: string | null;
avatarUrl: string | null;
}

View File

@ -4,6 +4,7 @@ import { SignOutIcon } from '@blocksuite/icons';
import { Avatar } from '../../ui/avatar';
import { Button, IconButton } from '../../ui/button';
import { Tooltip } from '../../ui/tooltip';
import type { User } from '../auth-components';
import { NotFoundPattern } from './not-found-pattern';
import {
largeButtonEffect,
@ -12,11 +13,7 @@ import {
} from './styles.css';
export interface NotFoundPageProps {
user: {
name: string;
email: string;
avatar: string;
} | null;
user?: User | null;
onBack: () => void;
onSignOut: () => void;
}
@ -47,7 +44,7 @@ export const NotFoundPage = ({
{user ? (
<div className={wrapper}>
<Avatar url={user.avatar} name={user.name} />
<Avatar url={user.avatarUrl ?? user.image} name={user.name} />
<span style={{ margin: '0 12px' }}>{user.email}</span>
<Tooltip content={t['404.signOut']()}>
<IconButton onClick={onSignOut}>

View File

@ -384,6 +384,7 @@ export const createConfiguration: (
{ context: '/api', target: 'http://localhost:3010' },
{ context: '/socket.io', target: 'http://localhost:3010', ws: true },
{ context: '/graphql', target: 'http://localhost:3010' },
{ context: '/oauth', target: 'http://localhost:3010' },
],
} as DevServerConfiguration,
} satisfies webpack.Configuration;

View File

@ -78,7 +78,6 @@
"lottie-web": "^5.12.2",
"mini-css-extract-plugin": "^2.8.0",
"nanoid": "^5.0.6",
"next-auth": "^4.24.5",
"next-themes": "^0.2.1",
"postcss-loader": "^8.1.0",
"react": "18.2.0",

View File

@ -1,4 +0,0 @@
import { atom } from 'jotai';
import type { SessionContextValue } from 'next-auth/react';
export const sessionAtom = atom<SessionContextValue<true> | null>(null);

View File

@ -24,7 +24,7 @@ export type AuthProps = {
setAuthEmail: (state: AuthProps['email']) => void;
setEmailType: (state: AuthProps['emailType']) => void;
email: string;
emailType: 'setPassword' | 'changePassword' | 'changeEmail';
emailType: 'setPassword' | 'changePassword' | 'changeEmail' | 'verifyEmail';
onSignedIn?: () => void;
};
@ -59,8 +59,10 @@ export const AuthModal: FC<AuthModalBaseProps & AuthProps> = ({
emailType,
}) => {
const onSignedIn = useCallback(() => {
setAuthState('signIn');
setAuthEmail('');
setOpen(false);
}, [setOpen]);
}, [setAuthState, setAuthEmail, setOpen]);
return (
<AuthModalBase open={open} setOpen={setOpen}>

View File

@ -0,0 +1,66 @@
import { Button } from '@affine/component/ui/button';
import {
useOAuthProviders,
useServerFeatures,
} from '@affine/core/hooks/affine/use-server-config';
import { OAuthProviderType } from '@affine/graphql';
import { GithubIcon, GoogleDuotoneIcon } from '@blocksuite/icons';
import { type ReactElement, useCallback } from 'react';
import { useAuth } from './use-auth';
const OAuthProviderMap: Record<
OAuthProviderType,
{
icon: ReactElement;
}
> = {
[OAuthProviderType.Google]: {
icon: <GoogleDuotoneIcon />,
},
[OAuthProviderType.GitHub]: {
icon: <GithubIcon />,
},
};
export function OAuth() {
const { oauth } = useServerFeatures();
if (!oauth) {
return null;
}
return <OAuthProviders />;
}
function OAuthProviders() {
const providers = useOAuthProviders();
return providers.map(provider => (
<OAuthProvider key={provider} provider={provider} />
));
}
function OAuthProvider({ provider }: { provider: OAuthProviderType }) {
const { icon } = OAuthProviderMap[provider];
const { oauthSignIn } = useAuth();
const onClick = useCallback(() => {
oauthSignIn(provider);
}, [provider, oauthSignIn]);
return (
<Button
key={provider}
type="primary"
block
size="extraLarge"
style={{ marginTop: 30 }}
icon={icon}
onClick={onClick}
>
Continue with {provider}
</Button>
);
}

View File

@ -12,6 +12,7 @@ import {
sendChangeEmailMutation,
sendChangePasswordEmailMutation,
sendSetPasswordEmailMutation,
sendVerifyEmailMutation,
} from '@affine/graphql';
import { useAFFiNEI18N } from '@affine/i18n/hooks';
import { useSetAtom } from 'jotai/react';
@ -29,7 +30,9 @@ const useEmailTitle = (emailType: AuthPanelProps['emailType']) => {
case 'changePassword':
return t['com.affine.auth.reset.password']();
case 'changeEmail':
return t['com.affine.settings.email.action']();
return t['com.affine.settings.email.action.change']();
case 'verifyEmail':
return t['com.affine.settings.email.action.verify']();
}
};
const useContent = (emailType: AuthPanelProps['emailType'], email: string) => {
@ -41,7 +44,8 @@ const useContent = (emailType: AuthPanelProps['emailType'], email: string) => {
case 'changePassword':
return t['com.affine.auth.reset.password.message']();
case 'changeEmail':
return t['com.affine.auth.change.email.message']({
case 'verifyEmail':
return t['com.affine.auth.verify.email.message']({
email,
});
}
@ -56,7 +60,8 @@ const useNotificationHint = (emailType: AuthPanelProps['emailType']) => {
case 'changePassword':
return t['com.affine.auth.sent.change.password.hint']();
case 'changeEmail':
return t['com.affine.auth.sent.change.email.hint']();
case 'verifyEmail':
return t['com.affine.auth.sent.verify.email.hint']();
}
};
const useButtonContent = (emailType: AuthPanelProps['emailType']) => {
@ -68,7 +73,8 @@ const useButtonContent = (emailType: AuthPanelProps['emailType']) => {
case 'changePassword':
return t['com.affine.auth.send.reset.password.link']();
case 'changeEmail':
return t['com.affine.auth.send.change.email.link']();
case 'verifyEmail':
return t['com.affine.auth.send.verify.email.hint']();
}
};
@ -87,12 +93,17 @@ const useSendEmail = (emailType: AuthPanelProps['emailType']) => {
useMutation({
mutation: sendChangeEmailMutation,
});
const { trigger: sendVerifyEmail, isMutating: isVerifyEmailMutation } =
useMutation({
mutation: sendVerifyEmailMutation,
});
return {
loading:
isChangePasswordMutating ||
isSetPasswordMutating ||
isChangeEmailMutating,
isChangeEmailMutating ||
isVerifyEmailMutation,
sendEmail: useCallback(
(email: string) => {
let trigger: (args: {
@ -113,6 +124,10 @@ const useSendEmail = (emailType: AuthPanelProps['emailType']) => {
trigger = sendChangeEmail;
callbackUrl = 'changeEmail';
break;
case 'verifyEmail':
trigger = sendVerifyEmail;
callbackUrl = 'verify-email';
break;
}
// TODO: add error handler
return trigger({
@ -127,6 +142,7 @@ const useSendEmail = (emailType: AuthPanelProps['emailType']) => {
sendChangeEmail,
sendChangePasswordEmail,
sendSetPasswordEmail,
sendVerifyEmail,
]
),
};

View File

@ -5,10 +5,9 @@ import {
ModalHeader,
} from '@affine/component/auth-components';
import { Button } from '@affine/component/ui/button';
import { useSession } from '@affine/core/hooks/affine/use-current-user';
import { useAsyncCallback } from '@affine/core/hooks/affine-async-hooks';
import { useAFFiNEI18N } from '@affine/i18n/hooks';
// eslint-disable-next-line @typescript-eslint/no-restricted-imports
import { useSession } from 'next-auth/react';
import type { FC } from 'react';
import { useCallback, useState } from 'react';
@ -25,7 +24,7 @@ export const SignInWithPassword: FC<AuthPanelProps> = ({
onSignedIn,
}) => {
const t = useAFFiNEI18N();
const { update } = useSession();
const { reload } = useSession();
const [password, setPassword] = useState('');
const [passwordError, setPasswordError] = useState(false);
@ -39,7 +38,6 @@ export const SignInWithPassword: FC<AuthPanelProps> = ({
const onSignIn = useAsyncCallback(async () => {
const res = await signInCloud('credentials', {
redirect: false,
email,
password,
}).catch(console.error);
@ -48,9 +46,9 @@ export const SignInWithPassword: FC<AuthPanelProps> = ({
return setPasswordError(true);
}
await update();
await reload();
onSignedIn?.();
}, [email, password, onSignedIn, update]);
}, [email, password, onSignedIn, reload]);
const sendMagicLink = useAsyncCallback(async () => {
if (allowSendEmail && verifyToken && !sendingEmail) {

View File

@ -12,7 +12,7 @@ import {
} from '@affine/graphql';
import { Trans } from '@affine/i18n';
import { useAFFiNEI18N } from '@affine/i18n/hooks';
import { ArrowDownBigIcon, GoogleDuotoneIcon } from '@blocksuite/icons';
import { ArrowDownBigIcon } from '@blocksuite/icons';
import { type FC, useState } from 'react';
import { useCallback } from 'react';
@ -20,6 +20,7 @@ import { useCurrentLoginStatus } from '../../../hooks/affine/use-current-login-s
import { useMutation } from '../../../hooks/use-mutation';
import { emailRegex } from '../../../utils/email-regex';
import type { AuthPanelProps } from './index';
import { OAuth } from './oauth';
import * as style from './style.css';
import { INTERNAL_BETA_URL, useAuth } from './use-auth';
import { Captcha, useCaptcha } from './use-captcha';
@ -46,7 +47,6 @@ export const SignIn: FC<AuthPanelProps> = ({
allowSendEmail,
signIn,
signUp,
signInWithGoogle,
} = useAuth();
const { trigger: verifyUser, isMutating } = useMutation({
@ -59,6 +59,10 @@ export const SignIn: FC<AuthPanelProps> = ({
}
const onContinue = useAsyncCallback(async () => {
if (!allowSendEmail) {
return;
}
if (!validateEmail(email)) {
setIsValidEmail(false);
return;
@ -99,13 +103,14 @@ export const SignIn: FC<AuthPanelProps> = ({
const res = await signUp(email, verifyToken, challenge);
if (res?.status === 403 && res?.url === INTERNAL_BETA_URL) {
return setAuthState('noAccess');
} else if (!res || res.status >= 400 || res.error) {
} else if (!res || res.status >= 400) {
return;
}
setAuthState('afterSignUpSendEmail');
}
}
}, [
allowSendEmail,
subscriptionData,
challenge,
email,
@ -124,20 +129,7 @@ export const SignIn: FC<AuthPanelProps> = ({
subTitle={t['com.affine.brand.affineCloud']()}
/>
<Button
type="primary"
block
size="extraLarge"
style={{
marginTop: 30,
}}
icon={<GoogleDuotoneIcon />}
onClick={useCallback(() => {
signInWithGoogle();
}, [signInWithGoogle])}
>
{t['Continue with Google']()}
</Button>
<OAuth />
<div className={style.authModalContent}>
<AuthInput

View File

@ -1,7 +1,7 @@
import { pushNotificationAtom } from '@affine/component/notification-center';
import type { Notification } from '@affine/component/notification-center/index.jotai';
import type { OAuthProviderType } from '@affine/graphql';
import { atom, useAtom, useSetAtom } from 'jotai';
import { type SignInResponse } from 'next-auth/react';
import { useCallback } from 'react';
import { signInCloud } from '../../../utils/cloud-utils';
@ -11,10 +11,10 @@ const COUNT_DOWN_TIME = 60;
export const INTERNAL_BETA_URL = `https://community.affine.pro/c/insider-general/`;
function handleSendEmailError(
res: SignInResponse | undefined | void,
res: Response | undefined | void,
pushNotification: (notification: Notification) => void
) {
if (res?.error) {
if (!res?.ok) {
pushNotification({
title: 'Send email error',
message: 'Please back to home and try again',
@ -64,8 +64,13 @@ export const useAuth = () => {
const [authStore, setAuthStore] = useAtom(authStoreAtom);
const startResendCountDown = useSetAtom(countDownAtom);
const signIn = useCallback(
async (email: string, verifyToken: string, challenge?: string) => {
const sendEmailMagicLink = useCallback(
async (
signUp: boolean,
email: string,
verifyToken: string,
challenge?: string
) => {
setAuthStore(prev => {
return {
...prev,
@ -76,18 +81,19 @@ export const useAuth = () => {
const res = await signInCloud(
'email',
{
email: email,
callbackUrl: subscriptionData
? subscriptionData.getRedirectUrl(false)
: '/auth/signIn',
redirect: false,
email,
},
challenge
{
...(challenge
? {
challenge,
token: verifyToken,
}
: { token: verifyToken }
: { token: verifyToken }),
callbackUrl: subscriptionData
? subscriptionData.getRedirectUrl(signUp)
: '/auth/signIn',
}
).catch(console.error);
handleSendEmailError(res, pushNotification);
@ -107,47 +113,24 @@ export const useAuth = () => {
const signUp = useCallback(
async (email: string, verifyToken: string, challenge?: string) => {
setAuthStore(prev => {
return {
...prev,
isMutating: true,
};
});
const res = await signInCloud(
'email',
{
email: email,
callbackUrl: subscriptionData
? subscriptionData.getRedirectUrl(true)
: '/auth/signUp',
redirect: false,
return sendEmailMagicLink(true, email, verifyToken, challenge).catch(
console.error
);
},
challenge
? {
challenge,
token: verifyToken,
}
: { token: verifyToken }
).catch(console.error);
handleSendEmailError(res, pushNotification);
setAuthStore({
isMutating: false,
allowSendEmail: false,
resendCountDown: COUNT_DOWN_TIME,
});
startResendCountDown();
return res;
},
[pushNotification, setAuthStore, startResendCountDown, subscriptionData]
[sendEmailMagicLink]
);
const signInWithGoogle = useCallback(() => {
signInCloud('google').catch(console.error);
const signIn = useCallback(
async (email: string, verifyToken: string, challenge?: string) => {
return sendEmailMagicLink(false, email, verifyToken, challenge).catch(
console.error
);
},
[sendEmailMagicLink]
);
const oauthSignIn = useCallback((provider: OAuthProviderType) => {
signInCloud(provider).catch(console.error);
}, []);
const resetCountDown = useCallback(() => {
@ -165,6 +148,6 @@ export const useAuth = () => {
isMutating: authStore.isMutating,
signUp,
signIn,
signInWithGoogle,
oauthSignIn,
};
};

View File

@ -3,21 +3,21 @@ import { useLiveData } from '@toeverything/infra/livedata';
import { Suspense, useEffect } from 'react';
import { useCurrentLoginStatus } from '../../../hooks/affine/use-current-login-status';
import { useCurrentUser } from '../../../hooks/affine/use-current-user';
import { useSession } from '../../../hooks/affine/use-current-user';
import { CurrentWorkspaceService } from '../../../modules/workspace/current-workspace';
const SyncAwarenessInnerLoggedIn = () => {
const currentUser = useCurrentUser();
const { user } = useSession();
const currentWorkspace = useLiveData(
useService(CurrentWorkspaceService).currentWorkspace
);
useEffect(() => {
if (currentUser && currentWorkspace) {
if (user && currentWorkspace) {
currentWorkspace.blockSuiteWorkspace.awarenessStore.awareness.setLocalStateField(
'user',
{
name: currentUser.name,
name: user.name,
// todo: add avatar?
}
);
@ -30,7 +30,7 @@ const SyncAwarenessInnerLoggedIn = () => {
};
}
return;
}, [currentUser, currentWorkspace]);
}, [user, currentWorkspace]);
return null;
};

Some files were not shown because too many files have changed in this diff Show More