mirror of
https://github.com/toeverything/AFFiNE.git
synced 2025-01-07 03:26:39 +03:00
feat: improve histories query for forked session (#7414)
This commit is contained in:
parent
cc7740d8d3
commit
e8285289fe
@ -273,12 +273,7 @@ export class CopilotResolver {
|
|||||||
@Parent() copilot: CopilotType,
|
@Parent() copilot: CopilotType,
|
||||||
@CurrentUser() user: CurrentUser,
|
@CurrentUser() user: CurrentUser,
|
||||||
@Args('docId', { nullable: true }) docId?: string,
|
@Args('docId', { nullable: true }) docId?: string,
|
||||||
@Args({
|
@Args('options', { nullable: true }) options?: QueryChatHistoriesInput
|
||||||
name: 'options',
|
|
||||||
type: () => QueryChatHistoriesInput,
|
|
||||||
nullable: true,
|
|
||||||
})
|
|
||||||
options?: QueryChatHistoriesInput
|
|
||||||
) {
|
) {
|
||||||
const workspaceId = copilot.workspaceId;
|
const workspaceId = copilot.workspaceId;
|
||||||
if (!workspaceId) {
|
if (!workspaceId) {
|
||||||
|
@ -254,6 +254,7 @@ export class ChatSessionService {
|
|||||||
// connect
|
// connect
|
||||||
userId: state.userId,
|
userId: state.userId,
|
||||||
promptName: state.prompt.name,
|
promptName: state.prompt.name,
|
||||||
|
parentSessionId: state.parentSessionId,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -384,17 +385,36 @@ export class ChatSessionService {
|
|||||||
return await this.db.aiSession
|
return await this.db.aiSession
|
||||||
.findMany({
|
.findMany({
|
||||||
where: {
|
where: {
|
||||||
userId,
|
OR: [
|
||||||
workspaceId: workspaceId,
|
{
|
||||||
docId: workspaceId === docId ? undefined : docId,
|
userId,
|
||||||
prompt: {
|
workspaceId: workspaceId,
|
||||||
action: options?.action ? { not: null } : null,
|
docId: workspaceId === docId ? undefined : docId,
|
||||||
},
|
id: options?.sessionId
|
||||||
id: options?.sessionId ? { equals: options.sessionId } : undefined,
|
? { equals: options.sessionId }
|
||||||
deletedAt: null,
|
: undefined,
|
||||||
|
deletedAt: null,
|
||||||
|
},
|
||||||
|
...(options?.action
|
||||||
|
? []
|
||||||
|
: [
|
||||||
|
{
|
||||||
|
userId: { not: userId },
|
||||||
|
workspaceId: workspaceId,
|
||||||
|
docId: workspaceId === docId ? undefined : docId,
|
||||||
|
id: options?.sessionId
|
||||||
|
? { equals: options.sessionId }
|
||||||
|
: undefined,
|
||||||
|
// should only find forked session
|
||||||
|
parentSessionId: { not: null },
|
||||||
|
deletedAt: null,
|
||||||
|
},
|
||||||
|
]),
|
||||||
|
],
|
||||||
},
|
},
|
||||||
select: {
|
select: {
|
||||||
id: true,
|
id: true,
|
||||||
|
userId: true,
|
||||||
promptName: true,
|
promptName: true,
|
||||||
tokenCost: true,
|
tokenCost: true,
|
||||||
createdAt: true,
|
createdAt: true,
|
||||||
@ -419,15 +439,30 @@ export class ChatSessionService {
|
|||||||
.then(sessions =>
|
.then(sessions =>
|
||||||
Promise.all(
|
Promise.all(
|
||||||
sessions.map(
|
sessions.map(
|
||||||
async ({ id, promptName, tokenCost, messages, createdAt }) => {
|
async ({
|
||||||
|
id,
|
||||||
|
userId: uid,
|
||||||
|
promptName,
|
||||||
|
tokenCost,
|
||||||
|
messages,
|
||||||
|
createdAt,
|
||||||
|
}) => {
|
||||||
try {
|
try {
|
||||||
|
const prompt = await this.prompt.get(promptName);
|
||||||
|
if (!prompt) {
|
||||||
|
throw new CopilotPromptNotFound({ name: promptName });
|
||||||
|
}
|
||||||
|
if (
|
||||||
|
// filter out the user's session that not match the action option
|
||||||
|
(uid === userId && !!options?.action !== !!prompt.action) ||
|
||||||
|
// filter out the non chat session from other user
|
||||||
|
(uid !== userId && !!prompt.action)
|
||||||
|
) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
const ret = ChatMessageSchema.array().safeParse(messages);
|
const ret = ChatMessageSchema.array().safeParse(messages);
|
||||||
if (ret.success) {
|
if (ret.success) {
|
||||||
const prompt = await this.prompt.get(promptName);
|
|
||||||
if (!prompt) {
|
|
||||||
throw new CopilotPromptNotFound({ name: promptName });
|
|
||||||
}
|
|
||||||
|
|
||||||
// render system prompt
|
// render system prompt
|
||||||
const preload = withPrompt
|
const preload = withPrompt
|
||||||
? prompt
|
? prompt
|
||||||
|
@ -36,6 +36,7 @@ import {
|
|||||||
chatWithWorkflow,
|
chatWithWorkflow,
|
||||||
createCopilotMessage,
|
createCopilotMessage,
|
||||||
createCopilotSession,
|
createCopilotSession,
|
||||||
|
forkCopilotSession,
|
||||||
getHistories,
|
getHistories,
|
||||||
MockCopilotTestProvider,
|
MockCopilotTestProvider,
|
||||||
sse2array,
|
sse2array,
|
||||||
@ -164,6 +165,123 @@ test('should create session correctly', async t => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('should fork session correctly', async t => {
|
||||||
|
const { app } = t.context;
|
||||||
|
|
||||||
|
const assertForkSession = async (
|
||||||
|
token: string,
|
||||||
|
workspaceId: string,
|
||||||
|
sessionId: string,
|
||||||
|
lastMessageId: string,
|
||||||
|
error: string,
|
||||||
|
asserter = async (x: any) => {
|
||||||
|
const forkedSessionId = await x;
|
||||||
|
t.truthy(forkedSessionId, error);
|
||||||
|
return forkedSessionId;
|
||||||
|
}
|
||||||
|
) =>
|
||||||
|
await asserter(
|
||||||
|
forkCopilotSession(
|
||||||
|
app,
|
||||||
|
token,
|
||||||
|
workspaceId,
|
||||||
|
randomUUID(),
|
||||||
|
sessionId,
|
||||||
|
lastMessageId
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
// prepare session
|
||||||
|
const { id } = await createWorkspace(app, token);
|
||||||
|
const sessionId = await createCopilotSession(
|
||||||
|
app,
|
||||||
|
token,
|
||||||
|
id,
|
||||||
|
randomUUID(),
|
||||||
|
promptName
|
||||||
|
);
|
||||||
|
|
||||||
|
let forkedSessionId: string;
|
||||||
|
// should be able to fork session
|
||||||
|
{
|
||||||
|
for (let i = 0; i < 3; i++) {
|
||||||
|
const messageId = await createCopilotMessage(app, token, sessionId);
|
||||||
|
await chatWithText(app, token, sessionId, messageId);
|
||||||
|
}
|
||||||
|
const histories = await getHistories(app, token, { workspaceId: id });
|
||||||
|
const latestMessageId = histories[0].messages.findLast(
|
||||||
|
m => m.role === 'assistant'
|
||||||
|
)?.id;
|
||||||
|
t.truthy(latestMessageId, 'should find last message id');
|
||||||
|
|
||||||
|
// should be able to fork session
|
||||||
|
forkedSessionId = await assertForkSession(
|
||||||
|
token,
|
||||||
|
id,
|
||||||
|
sessionId,
|
||||||
|
latestMessageId!,
|
||||||
|
'should be able to fork session with cloud workspace that user can access'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const {
|
||||||
|
token: { token: newToken },
|
||||||
|
} = await signUp(app, 'test', 'test@affine.pro', '123456');
|
||||||
|
await assertForkSession(
|
||||||
|
newToken,
|
||||||
|
id,
|
||||||
|
sessionId,
|
||||||
|
randomUUID(),
|
||||||
|
'',
|
||||||
|
async x => {
|
||||||
|
await t.throwsAsync(
|
||||||
|
x,
|
||||||
|
{ instanceOf: Error },
|
||||||
|
'should not able to fork session with cloud workspace that user cannot access'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const inviteId = await inviteUser(
|
||||||
|
app,
|
||||||
|
token,
|
||||||
|
id,
|
||||||
|
'test@affine.pro',
|
||||||
|
'Admin'
|
||||||
|
);
|
||||||
|
await acceptInviteById(app, id, inviteId, false);
|
||||||
|
await assertForkSession(
|
||||||
|
newToken,
|
||||||
|
id,
|
||||||
|
sessionId,
|
||||||
|
randomUUID(),
|
||||||
|
'',
|
||||||
|
async x => {
|
||||||
|
await t.throwsAsync(
|
||||||
|
x,
|
||||||
|
{ instanceOf: Error },
|
||||||
|
'should not able to fork a root session from other user'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const histories = await getHistories(app, token, { workspaceId: id });
|
||||||
|
const latestMessageId = histories
|
||||||
|
.find(h => h.sessionId === forkedSessionId)
|
||||||
|
?.messages.findLast(m => m.role === 'assistant')?.id;
|
||||||
|
t.truthy(latestMessageId, 'should find latest message id');
|
||||||
|
|
||||||
|
await assertForkSession(
|
||||||
|
newToken,
|
||||||
|
id,
|
||||||
|
forkedSessionId,
|
||||||
|
latestMessageId!,
|
||||||
|
'should able to fork a forked session created by other user'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
test('should be able to use test provider', async t => {
|
test('should be able to use test provider', async t => {
|
||||||
const { app } = t.context;
|
const { app } = t.context;
|
||||||
|
|
||||||
|
@ -174,6 +174,35 @@ export async function createCopilotSession(
|
|||||||
return res.body.data.createCopilotSession;
|
return res.body.data.createCopilotSession;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function forkCopilotSession(
|
||||||
|
app: INestApplication,
|
||||||
|
userToken: string,
|
||||||
|
workspaceId: string,
|
||||||
|
docId: string,
|
||||||
|
sessionId: string,
|
||||||
|
latestMessageId: string
|
||||||
|
): Promise<string> {
|
||||||
|
const res = await request(app.getHttpServer())
|
||||||
|
.post(gql)
|
||||||
|
.auth(userToken, { type: 'bearer' })
|
||||||
|
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
|
||||||
|
.send({
|
||||||
|
query: `
|
||||||
|
mutation forkCopilotSession($options: ForkChatSessionInput!) {
|
||||||
|
forkCopilotSession(options: $options)
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
variables: {
|
||||||
|
options: { workspaceId, docId, sessionId, latestMessageId },
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.expect(200);
|
||||||
|
|
||||||
|
handleGraphQLError(res);
|
||||||
|
|
||||||
|
return res.body.data.forkCopilotSession;
|
||||||
|
}
|
||||||
|
|
||||||
export async function createCopilotMessage(
|
export async function createCopilotMessage(
|
||||||
app: INestApplication,
|
app: INestApplication,
|
||||||
userToken: string,
|
userToken: string,
|
||||||
@ -286,6 +315,7 @@ export function textToEventStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ChatMessage = {
|
type ChatMessage = {
|
||||||
|
id?: string;
|
||||||
role: string;
|
role: string;
|
||||||
content: string;
|
content: string;
|
||||||
attachments: string[] | null;
|
attachments: string[] | null;
|
||||||
@ -333,6 +363,7 @@ export async function getHistories(
|
|||||||
action
|
action
|
||||||
createdAt
|
createdAt
|
||||||
messages {
|
messages {
|
||||||
|
id
|
||||||
role
|
role
|
||||||
content
|
content
|
||||||
attachments
|
attachments
|
||||||
|
@ -145,7 +145,14 @@ export function handleGraphQLError(resp: Response) {
|
|||||||
if (errors) {
|
if (errors) {
|
||||||
const cause = errors[0];
|
const cause = errors[0];
|
||||||
const stacktrace = cause.extensions?.stacktrace;
|
const stacktrace = cause.extensions?.stacktrace;
|
||||||
throw new Error(stacktrace ? stacktrace.join('\n') : cause.message, cause);
|
throw new Error(
|
||||||
|
stacktrace
|
||||||
|
? Array.isArray(stacktrace)
|
||||||
|
? stacktrace.join('\n')
|
||||||
|
: String(stacktrace)
|
||||||
|
: cause.message,
|
||||||
|
cause
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ query getCopilotHistories(
|
|||||||
action
|
action
|
||||||
createdAt
|
createdAt
|
||||||
messages {
|
messages {
|
||||||
|
id
|
||||||
role
|
role
|
||||||
content
|
content
|
||||||
attachments
|
attachments
|
||||||
|
@ -267,6 +267,7 @@ query getCopilotHistories($workspaceId: String!, $docId: String, $options: Query
|
|||||||
action
|
action
|
||||||
createdAt
|
createdAt
|
||||||
messages {
|
messages {
|
||||||
|
id
|
||||||
role
|
role
|
||||||
content
|
content
|
||||||
attachments
|
attachments
|
||||||
|
@ -1405,6 +1405,7 @@ export type GetCopilotHistoriesQuery = {
|
|||||||
createdAt: string;
|
createdAt: string;
|
||||||
messages: Array<{
|
messages: Array<{
|
||||||
__typename?: 'ChatMessage';
|
__typename?: 'ChatMessage';
|
||||||
|
id: string | null;
|
||||||
role: string;
|
role: string;
|
||||||
content: string;
|
content: string;
|
||||||
attachments: Array<string> | null;
|
attachments: Array<string> | null;
|
||||||
|
Loading…
Reference in New Issue
Block a user