feat: improve histories query for forked session (#7414)

This commit is contained in:
darkskygit 2024-07-03 04:49:19 +00:00
parent cc7740d8d3
commit e8285289fe
No known key found for this signature in database
GPG Key ID: 97B7D036B1566E9D
8 changed files with 210 additions and 21 deletions

View File

@ -273,12 +273,7 @@ export class CopilotResolver {
@Parent() copilot: CopilotType, @Parent() copilot: CopilotType,
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@Args('docId', { nullable: true }) docId?: string, @Args('docId', { nullable: true }) docId?: string,
@Args({ @Args('options', { nullable: true }) options?: QueryChatHistoriesInput
name: 'options',
type: () => QueryChatHistoriesInput,
nullable: true,
})
options?: QueryChatHistoriesInput
) { ) {
const workspaceId = copilot.workspaceId; const workspaceId = copilot.workspaceId;
if (!workspaceId) { if (!workspaceId) {

View File

@ -254,6 +254,7 @@ export class ChatSessionService {
// connect // connect
userId: state.userId, userId: state.userId,
promptName: state.prompt.name, promptName: state.prompt.name,
parentSessionId: state.parentSessionId,
}, },
}); });
} }
@ -384,17 +385,36 @@ export class ChatSessionService {
return await this.db.aiSession return await this.db.aiSession
.findMany({ .findMany({
where: { where: {
OR: [
{
userId, userId,
workspaceId: workspaceId, workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId, docId: workspaceId === docId ? undefined : docId,
prompt: { id: options?.sessionId
action: options?.action ? { not: null } : null, ? { equals: options.sessionId }
}, : undefined,
id: options?.sessionId ? { equals: options.sessionId } : undefined,
deletedAt: null, deletedAt: null,
}, },
...(options?.action
? []
: [
{
userId: { not: userId },
workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId,
id: options?.sessionId
? { equals: options.sessionId }
: undefined,
// should only find forked session
parentSessionId: { not: null },
deletedAt: null,
},
]),
],
},
select: { select: {
id: true, id: true,
userId: true,
promptName: true, promptName: true,
tokenCost: true, tokenCost: true,
createdAt: true, createdAt: true,
@ -419,15 +439,30 @@ export class ChatSessionService {
.then(sessions => .then(sessions =>
Promise.all( Promise.all(
sessions.map( sessions.map(
async ({ id, promptName, tokenCost, messages, createdAt }) => { async ({
id,
userId: uid,
promptName,
tokenCost,
messages,
createdAt,
}) => {
try { try {
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
const prompt = await this.prompt.get(promptName); const prompt = await this.prompt.get(promptName);
if (!prompt) { if (!prompt) {
throw new CopilotPromptNotFound({ name: promptName }); throw new CopilotPromptNotFound({ name: promptName });
} }
if (
// filter out the user's session that not match the action option
(uid === userId && !!options?.action !== !!prompt.action) ||
// filter out the non chat session from other user
(uid !== userId && !!prompt.action)
) {
return undefined;
}
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
// render system prompt // render system prompt
const preload = withPrompt const preload = withPrompt
? prompt ? prompt

View File

@ -36,6 +36,7 @@ import {
chatWithWorkflow, chatWithWorkflow,
createCopilotMessage, createCopilotMessage,
createCopilotSession, createCopilotSession,
forkCopilotSession,
getHistories, getHistories,
MockCopilotTestProvider, MockCopilotTestProvider,
sse2array, sse2array,
@ -164,6 +165,123 @@ test('should create session correctly', async t => {
} }
}); });
test('should fork session correctly', async t => {
const { app } = t.context;
const assertForkSession = async (
token: string,
workspaceId: string,
sessionId: string,
lastMessageId: string,
error: string,
asserter = async (x: any) => {
const forkedSessionId = await x;
t.truthy(forkedSessionId, error);
return forkedSessionId;
}
) =>
await asserter(
forkCopilotSession(
app,
token,
workspaceId,
randomUUID(),
sessionId,
lastMessageId
)
);
// prepare session
const { id } = await createWorkspace(app, token);
const sessionId = await createCopilotSession(
app,
token,
id,
randomUUID(),
promptName
);
let forkedSessionId: string;
// should be able to fork session
{
for (let i = 0; i < 3; i++) {
const messageId = await createCopilotMessage(app, token, sessionId);
await chatWithText(app, token, sessionId, messageId);
}
const histories = await getHistories(app, token, { workspaceId: id });
const latestMessageId = histories[0].messages.findLast(
m => m.role === 'assistant'
)?.id;
t.truthy(latestMessageId, 'should find last message id');
// should be able to fork session
forkedSessionId = await assertForkSession(
token,
id,
sessionId,
latestMessageId!,
'should be able to fork session with cloud workspace that user can access'
);
}
{
const {
token: { token: newToken },
} = await signUp(app, 'test', 'test@affine.pro', '123456');
await assertForkSession(
newToken,
id,
sessionId,
randomUUID(),
'',
async x => {
await t.throwsAsync(
x,
{ instanceOf: Error },
'should not able to fork session with cloud workspace that user cannot access'
);
}
);
const inviteId = await inviteUser(
app,
token,
id,
'test@affine.pro',
'Admin'
);
await acceptInviteById(app, id, inviteId, false);
await assertForkSession(
newToken,
id,
sessionId,
randomUUID(),
'',
async x => {
await t.throwsAsync(
x,
{ instanceOf: Error },
'should not able to fork a root session from other user'
);
}
);
const histories = await getHistories(app, token, { workspaceId: id });
const latestMessageId = histories
.find(h => h.sessionId === forkedSessionId)
?.messages.findLast(m => m.role === 'assistant')?.id;
t.truthy(latestMessageId, 'should find latest message id');
await assertForkSession(
newToken,
id,
forkedSessionId,
latestMessageId!,
'should able to fork a forked session created by other user'
);
}
});
test('should be able to use test provider', async t => { test('should be able to use test provider', async t => {
const { app } = t.context; const { app } = t.context;

View File

@ -174,6 +174,35 @@ export async function createCopilotSession(
return res.body.data.createCopilotSession; return res.body.data.createCopilotSession;
} }
export async function forkCopilotSession(
app: INestApplication,
userToken: string,
workspaceId: string,
docId: string,
sessionId: string,
latestMessageId: string
): Promise<string> {
const res = await request(app.getHttpServer())
.post(gql)
.auth(userToken, { type: 'bearer' })
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
.send({
query: `
mutation forkCopilotSession($options: ForkChatSessionInput!) {
forkCopilotSession(options: $options)
}
`,
variables: {
options: { workspaceId, docId, sessionId, latestMessageId },
},
})
.expect(200);
handleGraphQLError(res);
return res.body.data.forkCopilotSession;
}
export async function createCopilotMessage( export async function createCopilotMessage(
app: INestApplication, app: INestApplication,
userToken: string, userToken: string,
@ -286,6 +315,7 @@ export function textToEventStream(
} }
type ChatMessage = { type ChatMessage = {
id?: string;
role: string; role: string;
content: string; content: string;
attachments: string[] | null; attachments: string[] | null;
@ -333,6 +363,7 @@ export async function getHistories(
action action
createdAt createdAt
messages { messages {
id
role role
content content
attachments attachments

View File

@ -145,7 +145,14 @@ export function handleGraphQLError(resp: Response) {
if (errors) { if (errors) {
const cause = errors[0]; const cause = errors[0];
const stacktrace = cause.extensions?.stacktrace; const stacktrace = cause.extensions?.stacktrace;
throw new Error(stacktrace ? stacktrace.join('\n') : cause.message, cause); throw new Error(
stacktrace
? Array.isArray(stacktrace)
? stacktrace.join('\n')
: String(stacktrace)
: cause.message,
cause
);
} }
} }

View File

@ -11,6 +11,7 @@ query getCopilotHistories(
action action
createdAt createdAt
messages { messages {
id
role role
content content
attachments attachments

View File

@ -267,6 +267,7 @@ query getCopilotHistories($workspaceId: String!, $docId: String, $options: Query
action action
createdAt createdAt
messages { messages {
id
role role
content content
attachments attachments

View File

@ -1405,6 +1405,7 @@ export type GetCopilotHistoriesQuery = {
createdAt: string; createdAt: string;
messages: Array<{ messages: Array<{
__typename?: 'ChatMessage'; __typename?: 'ChatMessage';
id: string | null;
role: string; role: string;
content: string; content: string;
attachments: Array<string> | null; attachments: Array<string> | null;