mirror of
https://github.com/toeverything/AFFiNE.git
synced 2025-01-03 03:02:06 +03:00
feat: improve histories query for forked session (#7414)
This commit is contained in:
parent
cc7740d8d3
commit
e8285289fe
@ -273,12 +273,7 @@ export class CopilotResolver {
|
||||
@Parent() copilot: CopilotType,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('docId', { nullable: true }) docId?: string,
|
||||
@Args({
|
||||
name: 'options',
|
||||
type: () => QueryChatHistoriesInput,
|
||||
nullable: true,
|
||||
})
|
||||
options?: QueryChatHistoriesInput
|
||||
@Args('options', { nullable: true }) options?: QueryChatHistoriesInput
|
||||
) {
|
||||
const workspaceId = copilot.workspaceId;
|
||||
if (!workspaceId) {
|
||||
|
@ -254,6 +254,7 @@ export class ChatSessionService {
|
||||
// connect
|
||||
userId: state.userId,
|
||||
promptName: state.prompt.name,
|
||||
parentSessionId: state.parentSessionId,
|
||||
},
|
||||
});
|
||||
}
|
||||
@ -384,17 +385,36 @@ export class ChatSessionService {
|
||||
return await this.db.aiSession
|
||||
.findMany({
|
||||
where: {
|
||||
userId,
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
prompt: {
|
||||
action: options?.action ? { not: null } : null,
|
||||
},
|
||||
id: options?.sessionId ? { equals: options.sessionId } : undefined,
|
||||
deletedAt: null,
|
||||
OR: [
|
||||
{
|
||||
userId,
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
id: options?.sessionId
|
||||
? { equals: options.sessionId }
|
||||
: undefined,
|
||||
deletedAt: null,
|
||||
},
|
||||
...(options?.action
|
||||
? []
|
||||
: [
|
||||
{
|
||||
userId: { not: userId },
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
id: options?.sessionId
|
||||
? { equals: options.sessionId }
|
||||
: undefined,
|
||||
// should only find forked session
|
||||
parentSessionId: { not: null },
|
||||
deletedAt: null,
|
||||
},
|
||||
]),
|
||||
],
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
userId: true,
|
||||
promptName: true,
|
||||
tokenCost: true,
|
||||
createdAt: true,
|
||||
@ -419,15 +439,30 @@ export class ChatSessionService {
|
||||
.then(sessions =>
|
||||
Promise.all(
|
||||
sessions.map(
|
||||
async ({ id, promptName, tokenCost, messages, createdAt }) => {
|
||||
async ({
|
||||
id,
|
||||
userId: uid,
|
||||
promptName,
|
||||
tokenCost,
|
||||
messages,
|
||||
createdAt,
|
||||
}) => {
|
||||
try {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
if (
|
||||
// filter out the user's session that not match the action option
|
||||
(uid === userId && !!options?.action !== !!prompt.action) ||
|
||||
// filter out the non chat session from other user
|
||||
(uid !== userId && !!prompt.action)
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const ret = ChatMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
|
||||
// render system prompt
|
||||
const preload = withPrompt
|
||||
? prompt
|
||||
|
@ -36,6 +36,7 @@ import {
|
||||
chatWithWorkflow,
|
||||
createCopilotMessage,
|
||||
createCopilotSession,
|
||||
forkCopilotSession,
|
||||
getHistories,
|
||||
MockCopilotTestProvider,
|
||||
sse2array,
|
||||
@ -164,6 +165,123 @@ test('should create session correctly', async t => {
|
||||
}
|
||||
});
|
||||
|
||||
test('should fork session correctly', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const assertForkSession = async (
|
||||
token: string,
|
||||
workspaceId: string,
|
||||
sessionId: string,
|
||||
lastMessageId: string,
|
||||
error: string,
|
||||
asserter = async (x: any) => {
|
||||
const forkedSessionId = await x;
|
||||
t.truthy(forkedSessionId, error);
|
||||
return forkedSessionId;
|
||||
}
|
||||
) =>
|
||||
await asserter(
|
||||
forkCopilotSession(
|
||||
app,
|
||||
token,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
sessionId,
|
||||
lastMessageId
|
||||
)
|
||||
);
|
||||
|
||||
// prepare session
|
||||
const { id } = await createWorkspace(app, token);
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
|
||||
let forkedSessionId: string;
|
||||
// should be able to fork session
|
||||
{
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const messageId = await createCopilotMessage(app, token, sessionId);
|
||||
await chatWithText(app, token, sessionId, messageId);
|
||||
}
|
||||
const histories = await getHistories(app, token, { workspaceId: id });
|
||||
const latestMessageId = histories[0].messages.findLast(
|
||||
m => m.role === 'assistant'
|
||||
)?.id;
|
||||
t.truthy(latestMessageId, 'should find last message id');
|
||||
|
||||
// should be able to fork session
|
||||
forkedSessionId = await assertForkSession(
|
||||
token,
|
||||
id,
|
||||
sessionId,
|
||||
latestMessageId!,
|
||||
'should be able to fork session with cloud workspace that user can access'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const {
|
||||
token: { token: newToken },
|
||||
} = await signUp(app, 'test', 'test@affine.pro', '123456');
|
||||
await assertForkSession(
|
||||
newToken,
|
||||
id,
|
||||
sessionId,
|
||||
randomUUID(),
|
||||
'',
|
||||
async x => {
|
||||
await t.throwsAsync(
|
||||
x,
|
||||
{ instanceOf: Error },
|
||||
'should not able to fork session with cloud workspace that user cannot access'
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
const inviteId = await inviteUser(
|
||||
app,
|
||||
token,
|
||||
id,
|
||||
'test@affine.pro',
|
||||
'Admin'
|
||||
);
|
||||
await acceptInviteById(app, id, inviteId, false);
|
||||
await assertForkSession(
|
||||
newToken,
|
||||
id,
|
||||
sessionId,
|
||||
randomUUID(),
|
||||
'',
|
||||
async x => {
|
||||
await t.throwsAsync(
|
||||
x,
|
||||
{ instanceOf: Error },
|
||||
'should not able to fork a root session from other user'
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
const histories = await getHistories(app, token, { workspaceId: id });
|
||||
const latestMessageId = histories
|
||||
.find(h => h.sessionId === forkedSessionId)
|
||||
?.messages.findLast(m => m.role === 'assistant')?.id;
|
||||
t.truthy(latestMessageId, 'should find latest message id');
|
||||
|
||||
await assertForkSession(
|
||||
newToken,
|
||||
id,
|
||||
forkedSessionId,
|
||||
latestMessageId!,
|
||||
'should able to fork a forked session created by other user'
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
test('should be able to use test provider', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
|
@ -174,6 +174,35 @@ export async function createCopilotSession(
|
||||
return res.body.data.createCopilotSession;
|
||||
}
|
||||
|
||||
export async function forkCopilotSession(
|
||||
app: INestApplication,
|
||||
userToken: string,
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
sessionId: string,
|
||||
latestMessageId: string
|
||||
): Promise<string> {
|
||||
const res = await request(app.getHttpServer())
|
||||
.post(gql)
|
||||
.auth(userToken, { type: 'bearer' })
|
||||
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
|
||||
.send({
|
||||
query: `
|
||||
mutation forkCopilotSession($options: ForkChatSessionInput!) {
|
||||
forkCopilotSession(options: $options)
|
||||
}
|
||||
`,
|
||||
variables: {
|
||||
options: { workspaceId, docId, sessionId, latestMessageId },
|
||||
},
|
||||
})
|
||||
.expect(200);
|
||||
|
||||
handleGraphQLError(res);
|
||||
|
||||
return res.body.data.forkCopilotSession;
|
||||
}
|
||||
|
||||
export async function createCopilotMessage(
|
||||
app: INestApplication,
|
||||
userToken: string,
|
||||
@ -286,6 +315,7 @@ export function textToEventStream(
|
||||
}
|
||||
|
||||
type ChatMessage = {
|
||||
id?: string;
|
||||
role: string;
|
||||
content: string;
|
||||
attachments: string[] | null;
|
||||
@ -333,6 +363,7 @@ export async function getHistories(
|
||||
action
|
||||
createdAt
|
||||
messages {
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
|
@ -145,7 +145,14 @@ export function handleGraphQLError(resp: Response) {
|
||||
if (errors) {
|
||||
const cause = errors[0];
|
||||
const stacktrace = cause.extensions?.stacktrace;
|
||||
throw new Error(stacktrace ? stacktrace.join('\n') : cause.message, cause);
|
||||
throw new Error(
|
||||
stacktrace
|
||||
? Array.isArray(stacktrace)
|
||||
? stacktrace.join('\n')
|
||||
: String(stacktrace)
|
||||
: cause.message,
|
||||
cause
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,7 @@ query getCopilotHistories(
|
||||
action
|
||||
createdAt
|
||||
messages {
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
|
@ -267,6 +267,7 @@ query getCopilotHistories($workspaceId: String!, $docId: String, $options: Query
|
||||
action
|
||||
createdAt
|
||||
messages {
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
|
@ -1405,6 +1405,7 @@ export type GetCopilotHistoriesQuery = {
|
||||
createdAt: string;
|
||||
messages: Array<{
|
||||
__typename?: 'ChatMessage';
|
||||
id: string | null;
|
||||
role: string;
|
||||
content: string;
|
||||
attachments: Array<string> | null;
|
||||
|
Loading…
Reference in New Issue
Block a user