diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index b7ddba8501..067105df72 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -273,12 +273,7 @@ export class CopilotResolver { @Parent() copilot: CopilotType, @CurrentUser() user: CurrentUser, @Args('docId', { nullable: true }) docId?: string, - @Args({ - name: 'options', - type: () => QueryChatHistoriesInput, - nullable: true, - }) - options?: QueryChatHistoriesInput + @Args('options', { nullable: true }) options?: QueryChatHistoriesInput ) { const workspaceId = copilot.workspaceId; if (!workspaceId) { diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index cb121f5b81..f4f0496ac9 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -254,6 +254,7 @@ export class ChatSessionService { // connect userId: state.userId, promptName: state.prompt.name, + parentSessionId: state.parentSessionId, }, }); } @@ -384,17 +385,36 @@ export class ChatSessionService { return await this.db.aiSession .findMany({ where: { - userId, - workspaceId: workspaceId, - docId: workspaceId === docId ? undefined : docId, - prompt: { - action: options?.action ? { not: null } : null, - }, - id: options?.sessionId ? { equals: options.sessionId } : undefined, - deletedAt: null, + OR: [ + { + userId, + workspaceId: workspaceId, + docId: workspaceId === docId ? undefined : docId, + id: options?.sessionId + ? { equals: options.sessionId } + : undefined, + deletedAt: null, + }, + ...(options?.action + ? [] + : [ + { + userId: { not: userId }, + workspaceId: workspaceId, + docId: workspaceId === docId ? undefined : docId, + id: options?.sessionId + ? { equals: options.sessionId } + : undefined, + // should only find forked session + parentSessionId: { not: null }, + deletedAt: null, + }, + ]), + ], }, select: { id: true, + userId: true, promptName: true, tokenCost: true, createdAt: true, @@ -419,15 +439,30 @@ export class ChatSessionService { .then(sessions => Promise.all( sessions.map( - async ({ id, promptName, tokenCost, messages, createdAt }) => { + async ({ + id, + userId: uid, + promptName, + tokenCost, + messages, + createdAt, + }) => { try { + const prompt = await this.prompt.get(promptName); + if (!prompt) { + throw new CopilotPromptNotFound({ name: promptName }); + } + if ( + // filter out the user's session that not match the action option + (uid === userId && !!options?.action !== !!prompt.action) || + // filter out the non chat session from other user + (uid !== userId && !!prompt.action) + ) { + return undefined; + } + const ret = ChatMessageSchema.array().safeParse(messages); if (ret.success) { - const prompt = await this.prompt.get(promptName); - if (!prompt) { - throw new CopilotPromptNotFound({ name: promptName }); - } - // render system prompt const preload = withPrompt ? prompt diff --git a/packages/backend/server/tests/copilot.e2e.ts b/packages/backend/server/tests/copilot.e2e.ts index b2dd5ab222..cff25dff76 100644 --- a/packages/backend/server/tests/copilot.e2e.ts +++ b/packages/backend/server/tests/copilot.e2e.ts @@ -36,6 +36,7 @@ import { chatWithWorkflow, createCopilotMessage, createCopilotSession, + forkCopilotSession, getHistories, MockCopilotTestProvider, sse2array, @@ -164,6 +165,123 @@ test('should create session correctly', async t => { } }); +test('should fork session correctly', async t => { + const { app } = t.context; + + const assertForkSession = async ( + token: string, + workspaceId: string, + sessionId: string, + lastMessageId: string, + error: string, + asserter = async (x: any) => { + const forkedSessionId = await x; + t.truthy(forkedSessionId, error); + return forkedSessionId; + } + ) => + await asserter( + forkCopilotSession( + app, + token, + workspaceId, + randomUUID(), + sessionId, + lastMessageId + ) + ); + + // prepare session + const { id } = await createWorkspace(app, token); + const sessionId = await createCopilotSession( + app, + token, + id, + randomUUID(), + promptName + ); + + let forkedSessionId: string; + // should be able to fork session + { + for (let i = 0; i < 3; i++) { + const messageId = await createCopilotMessage(app, token, sessionId); + await chatWithText(app, token, sessionId, messageId); + } + const histories = await getHistories(app, token, { workspaceId: id }); + const latestMessageId = histories[0].messages.findLast( + m => m.role === 'assistant' + )?.id; + t.truthy(latestMessageId, 'should find last message id'); + + // should be able to fork session + forkedSessionId = await assertForkSession( + token, + id, + sessionId, + latestMessageId!, + 'should be able to fork session with cloud workspace that user can access' + ); + } + + { + const { + token: { token: newToken }, + } = await signUp(app, 'test', 'test@affine.pro', '123456'); + await assertForkSession( + newToken, + id, + sessionId, + randomUUID(), + '', + async x => { + await t.throwsAsync( + x, + { instanceOf: Error }, + 'should not able to fork session with cloud workspace that user cannot access' + ); + } + ); + + const inviteId = await inviteUser( + app, + token, + id, + 'test@affine.pro', + 'Admin' + ); + await acceptInviteById(app, id, inviteId, false); + await assertForkSession( + newToken, + id, + sessionId, + randomUUID(), + '', + async x => { + await t.throwsAsync( + x, + { instanceOf: Error }, + 'should not able to fork a root session from other user' + ); + } + ); + + const histories = await getHistories(app, token, { workspaceId: id }); + const latestMessageId = histories + .find(h => h.sessionId === forkedSessionId) + ?.messages.findLast(m => m.role === 'assistant')?.id; + t.truthy(latestMessageId, 'should find latest message id'); + + await assertForkSession( + newToken, + id, + forkedSessionId, + latestMessageId!, + 'should able to fork a forked session created by other user' + ); + } +}); + test('should be able to use test provider', async t => { const { app } = t.context; diff --git a/packages/backend/server/tests/utils/copilot.ts b/packages/backend/server/tests/utils/copilot.ts index c3aed677fb..d6e32a2472 100644 --- a/packages/backend/server/tests/utils/copilot.ts +++ b/packages/backend/server/tests/utils/copilot.ts @@ -174,6 +174,35 @@ export async function createCopilotSession( return res.body.data.createCopilotSession; } +export async function forkCopilotSession( + app: INestApplication, + userToken: string, + workspaceId: string, + docId: string, + sessionId: string, + latestMessageId: string +): Promise { + const res = await request(app.getHttpServer()) + .post(gql) + .auth(userToken, { type: 'bearer' }) + .set({ 'x-request-id': 'test', 'x-operation-name': 'test' }) + .send({ + query: ` + mutation forkCopilotSession($options: ForkChatSessionInput!) { + forkCopilotSession(options: $options) + } + `, + variables: { + options: { workspaceId, docId, sessionId, latestMessageId }, + }, + }) + .expect(200); + + handleGraphQLError(res); + + return res.body.data.forkCopilotSession; +} + export async function createCopilotMessage( app: INestApplication, userToken: string, @@ -286,6 +315,7 @@ export function textToEventStream( } type ChatMessage = { + id?: string; role: string; content: string; attachments: string[] | null; @@ -333,6 +363,7 @@ export async function getHistories( action createdAt messages { + id role content attachments diff --git a/packages/backend/server/tests/utils/utils.ts b/packages/backend/server/tests/utils/utils.ts index b8bd9910bf..1db32b7984 100644 --- a/packages/backend/server/tests/utils/utils.ts +++ b/packages/backend/server/tests/utils/utils.ts @@ -145,7 +145,14 @@ export function handleGraphQLError(resp: Response) { if (errors) { const cause = errors[0]; const stacktrace = cause.extensions?.stacktrace; - throw new Error(stacktrace ? stacktrace.join('\n') : cause.message, cause); + throw new Error( + stacktrace + ? Array.isArray(stacktrace) + ? stacktrace.join('\n') + : String(stacktrace) + : cause.message, + cause + ); } } diff --git a/packages/frontend/graphql/src/graphql/get-copilot-histories.gql b/packages/frontend/graphql/src/graphql/get-copilot-histories.gql index 3779afd3d9..2b3af6c1da 100644 --- a/packages/frontend/graphql/src/graphql/get-copilot-histories.gql +++ b/packages/frontend/graphql/src/graphql/get-copilot-histories.gql @@ -11,6 +11,7 @@ query getCopilotHistories( action createdAt messages { + id role content attachments diff --git a/packages/frontend/graphql/src/graphql/index.ts b/packages/frontend/graphql/src/graphql/index.ts index b17ca4ee0f..a52dde498d 100644 --- a/packages/frontend/graphql/src/graphql/index.ts +++ b/packages/frontend/graphql/src/graphql/index.ts @@ -267,6 +267,7 @@ query getCopilotHistories($workspaceId: String!, $docId: String, $options: Query action createdAt messages { + id role content attachments diff --git a/packages/frontend/graphql/src/schema.ts b/packages/frontend/graphql/src/schema.ts index c72980b334..2b1ed8dee9 100644 --- a/packages/frontend/graphql/src/schema.ts +++ b/packages/frontend/graphql/src/schema.ts @@ -1405,6 +1405,7 @@ export type GetCopilotHistoriesQuery = { createdAt: string; messages: Array<{ __typename?: 'ChatMessage'; + id: string | null; role: string; content: string; attachments: Array | null;