From 0076359d6a0dc456a44b768031e084e388737487 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Wed, 15 May 2024 11:02:37 +0000 Subject: [PATCH] feat: add retry support for copilot (#6947) --- .../server/src/plugins/copilot/controller.ts | 23 ++++++-- .../server/src/plugins/copilot/session.ts | 30 ++++++++++ packages/backend/server/tests/copilot.e2e.ts | 55 +++++++++++++++++++ packages/backend/server/tests/copilot.spec.ts | 53 ++++++++++++++++++ .../backend/server/tests/utils/copilot.ts | 9 +-- 5 files changed, 161 insertions(+), 9 deletions(-) diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 042996df14..62e3d14056 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -82,14 +82,21 @@ export class CopilotController { private async appendSessionMessage( sessionId: string, - messageId: string + messageId?: string ): Promise { const session = await this.chatSession.get(sessionId); if (!session) { throw new BadRequestException('Session not found'); } - await session.pushByMessageId(messageId); + if (messageId) { + await session.pushByMessageId(messageId); + } else { + // revert the latest message generated by the assistant + // if messageId is not provided, then we can retry the action + await this.chatSession.revertLatestMessage(sessionId); + session.revertLatestMessage(); + } return session; } @@ -129,7 +136,6 @@ export class CopilotController { @CurrentUser() user: CurrentUser, @Req() req: Request, @Param('sessionId') sessionId: string, - @Query('messageId') messageId: string, @Query() params: Record ): Promise { const { model } = await this.checkRequest(user.id, sessionId); @@ -141,6 +147,9 @@ export class CopilotController { throw new InternalServerErrorException('No provider available'); } + const messageId = Array.isArray(params.messageId) + ? params.messageId[0] + : params.messageId; const session = await this.appendSessionMessage(sessionId, messageId); try { @@ -174,7 +183,6 @@ export class CopilotController { @CurrentUser() user: CurrentUser, @Req() req: Request, @Param('sessionId') sessionId: string, - @Query('messageId') messageId: string, @Query() params: Record ): Promise> { try { @@ -187,6 +195,9 @@ export class CopilotController { throw new InternalServerErrorException('No provider available'); } + const messageId = Array.isArray(params.messageId) + ? params.messageId[0] + : params.messageId; const session = await this.appendSessionMessage(sessionId, messageId); delete params.messageId; @@ -237,10 +248,12 @@ export class CopilotController { @CurrentUser() user: CurrentUser, @Req() req: Request, @Param('sessionId') sessionId: string, - @Query('messageId') messageId: string, @Query() params: Record ): Promise> { try { + const messageId = Array.isArray(params.messageId) + ? params.messageId[0] + : params.messageId; const { model, hasAttachment } = await this.checkRequest( user.id, sessionId, diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index d313e31a34..9bf96b3319 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -64,6 +64,13 @@ export class ChatSession implements AsyncDisposable { this.stashMessageCount += 1; } + revertLatestMessage() { + const messages = this.state.messages; + messages.splice( + messages.findLastIndex(({ role }) => role === AiPromptRole.user) + 1 + ); + } + async getMessageById(messageId: string) { const message = await this.messageCache.get(messageId); if (!message || message.sessionId !== this.state.sessionId) { @@ -287,6 +294,29 @@ export class ChatSessionService { }); } + // revert the latest messages not generate by user + // after revert, we can retry the action + async revertLatestMessage(sessionId: string) { + await this.db.$transaction(async tx => { + const ids = await tx.aiSessionMessage + .findMany({ + where: { sessionId }, + select: { id: true, role: true }, + orderBy: { createdAt: 'asc' }, + }) + .then(roles => + roles + .slice( + roles.findLastIndex(({ role }) => role === AiPromptRole.user) + 1 + ) + .map(({ id }) => id) + ); + if (ids.length) { + await tx.aiSessionMessage.deleteMany({ where: { id: { in: ids } } }); + } + }); + } + private calculateTokenSize( messages: PromptMessage[], model: AvailableModel diff --git a/packages/backend/server/tests/copilot.e2e.ts b/packages/backend/server/tests/copilot.e2e.ts index cd1d9e5437..4653330fa5 100644 --- a/packages/backend/server/tests/copilot.e2e.ts +++ b/packages/backend/server/tests/copilot.e2e.ts @@ -228,6 +228,61 @@ test('should be able to chat with api', async t => { Sinon.restore(); }); +test('should be able to retry with api', async t => { + const { app, storage } = t.context; + + Sinon.stub(storage, 'handleRemoteLink').resolvesArg(2); + + // normal chat + { + const { id } = await createWorkspace(app, token); + const sessionId = await createCopilotSession( + app, + token, + id, + randomUUID(), + promptName + ); + const messageId = await createCopilotMessage(app, token, sessionId); + // chat 2 times + await chatWithText(app, token, sessionId, messageId); + await chatWithText(app, token, sessionId, messageId); + + const histories = await getHistories(app, token, { workspaceId: id }); + t.deepEqual( + histories.map(h => h.messages.map(m => m.content)), + [['generate text to text', 'generate text to text']], + 'should be able to list history' + ); + } + + // retry chat + { + const { id } = await createWorkspace(app, token); + const sessionId = await createCopilotSession( + app, + token, + id, + randomUUID(), + promptName + ); + const messageId = await createCopilotMessage(app, token, sessionId); + await chatWithText(app, token, sessionId, messageId); + // retry without message id + await chatWithText(app, token, sessionId); + + // should only have 1 message + const histories = await getHistories(app, token, { workspaceId: id }); + t.deepEqual( + histories.map(h => h.messages.map(m => m.content)), + [['generate text to text']], + 'should be able to list history' + ); + } + + Sinon.restore(); +}); + test('should reject message from different session', async t => { const { app } = t.context; diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 6faf4fce1d..6d0cf31eae 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -362,6 +362,59 @@ test('should save message correctly', async t => { t.is(s.stashMessages.length, 0, 'should empty stash messages after save'); }); +test('should revert message correctly', async t => { + const { prompt, session } = t.context; + + // init session + let sessionId: string; + { + await prompt.set('prompt', 'model', [ + { role: 'system', content: 'hello {{word}}' }, + ]); + + sessionId = await session.create({ + docId: 'test', + workspaceId: 'test', + userId, + promptName: 'prompt', + }); + const s = (await session.get(sessionId))!; + + const message = (await session.createMessage({ + sessionId, + content: 'hello', + }))!; + + await s.pushByMessageId(message); + await s.save(); + } + + // check ChatSession behavior + { + const s = (await session.get(sessionId))!; + s.push({ role: 'assistant', content: 'hi', createdAt: new Date() }); + await s.save(); + const beforeRevert = s.finish({ word: 'world' }); + t.is(beforeRevert.length, 3, 'should have three messages before revert'); + + s.revertLatestMessage(); + const afterRevert = s.finish({ word: 'world' }); + t.is(afterRevert.length, 2, 'should remove assistant message after revert'); + } + + // check database behavior + { + let s = (await session.get(sessionId))!; + const beforeRevert = s.finish({ word: 'world' }); + t.is(beforeRevert.length, 3, 'should have three messages before revert'); + + await session.revertLatestMessage(sessionId); + s = (await session.get(sessionId))!; + const afterRevert = s.finish({ word: 'world' }); + t.is(afterRevert.length, 2, 'should remove assistant message after revert'); + } +}); + // ==================== provider ==================== test('should be able to get provider', async t => { diff --git a/packages/backend/server/tests/utils/copilot.ts b/packages/backend/server/tests/utils/copilot.ts index 52b65f69ad..b28bb2bcbd 100644 --- a/packages/backend/server/tests/utils/copilot.ts +++ b/packages/backend/server/tests/utils/copilot.ts @@ -196,11 +196,12 @@ export async function chatWithText( app: INestApplication, userToken: string, sessionId: string, - messageId: string, + messageId?: string, prefix = '' ): Promise { + const query = messageId ? `?messageId=${messageId}` : ''; const res = await request(app.getHttpServer()) - .get(`/api/copilot/chat/${sessionId}${prefix}?messageId=${messageId}`) + .get(`/api/copilot/chat/${sessionId}${prefix}${query}`) .auth(userToken, { type: 'bearer' }) .expect(200); @@ -211,7 +212,7 @@ export async function chatWithTextStream( app: INestApplication, userToken: string, sessionId: string, - messageId: string + messageId?: string ) { return chatWithText(app, userToken, sessionId, messageId, '/stream'); } @@ -220,7 +221,7 @@ export async function chatWithImages( app: INestApplication, userToken: string, sessionId: string, - messageId: string + messageId?: string ) { return chatWithText(app, userToken, sessionId, messageId, '/images'); }