feat: add retry support for copilot (#6947)

This commit is contained in:
darkskygit 2024-05-15 11:02:37 +00:00
parent 7e7a4120aa
commit 0076359d6a
No known key found for this signature in database
GPG Key ID: 97B7D036B1566E9D
5 changed files with 161 additions and 9 deletions

View File

@ -82,14 +82,21 @@ export class CopilotController {
private async appendSessionMessage( private async appendSessionMessage(
sessionId: string, sessionId: string,
messageId: string messageId?: string
): Promise<ChatSession> { ): Promise<ChatSession> {
const session = await this.chatSession.get(sessionId); const session = await this.chatSession.get(sessionId);
if (!session) { if (!session) {
throw new BadRequestException('Session not found'); throw new BadRequestException('Session not found');
} }
await session.pushByMessageId(messageId); if (messageId) {
await session.pushByMessageId(messageId);
} else {
// revert the latest message generated by the assistant
// if messageId is not provided, then we can retry the action
await this.chatSession.revertLatestMessage(sessionId);
session.revertLatestMessage();
}
return session; return session;
} }
@ -129,7 +136,6 @@ export class CopilotController {
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@Req() req: Request, @Req() req: Request,
@Param('sessionId') sessionId: string, @Param('sessionId') sessionId: string,
@Query('messageId') messageId: string,
@Query() params: Record<string, string | string[]> @Query() params: Record<string, string | string[]>
): Promise<string> { ): Promise<string> {
const { model } = await this.checkRequest(user.id, sessionId); const { model } = await this.checkRequest(user.id, sessionId);
@ -141,6 +147,9 @@ export class CopilotController {
throw new InternalServerErrorException('No provider available'); throw new InternalServerErrorException('No provider available');
} }
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const session = await this.appendSessionMessage(sessionId, messageId); const session = await this.appendSessionMessage(sessionId, messageId);
try { try {
@ -174,7 +183,6 @@ export class CopilotController {
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@Req() req: Request, @Req() req: Request,
@Param('sessionId') sessionId: string, @Param('sessionId') sessionId: string,
@Query('messageId') messageId: string,
@Query() params: Record<string, string> @Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> { ): Promise<Observable<ChatEvent>> {
try { try {
@ -187,6 +195,9 @@ export class CopilotController {
throw new InternalServerErrorException('No provider available'); throw new InternalServerErrorException('No provider available');
} }
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const session = await this.appendSessionMessage(sessionId, messageId); const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId; delete params.messageId;
@ -237,10 +248,12 @@ export class CopilotController {
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@Req() req: Request, @Req() req: Request,
@Param('sessionId') sessionId: string, @Param('sessionId') sessionId: string,
@Query('messageId') messageId: string,
@Query() params: Record<string, string> @Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> { ): Promise<Observable<ChatEvent>> {
try { try {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const { model, hasAttachment } = await this.checkRequest( const { model, hasAttachment } = await this.checkRequest(
user.id, user.id,
sessionId, sessionId,

View File

@ -64,6 +64,13 @@ export class ChatSession implements AsyncDisposable {
this.stashMessageCount += 1; this.stashMessageCount += 1;
} }
revertLatestMessage() {
const messages = this.state.messages;
messages.splice(
messages.findLastIndex(({ role }) => role === AiPromptRole.user) + 1
);
}
async getMessageById(messageId: string) { async getMessageById(messageId: string) {
const message = await this.messageCache.get(messageId); const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) { if (!message || message.sessionId !== this.state.sessionId) {
@ -287,6 +294,29 @@ export class ChatSessionService {
}); });
} }
// revert the latest messages not generate by user
// after revert, we can retry the action
async revertLatestMessage(sessionId: string) {
await this.db.$transaction(async tx => {
const ids = await tx.aiSessionMessage
.findMany({
where: { sessionId },
select: { id: true, role: true },
orderBy: { createdAt: 'asc' },
})
.then(roles =>
roles
.slice(
roles.findLastIndex(({ role }) => role === AiPromptRole.user) + 1
)
.map(({ id }) => id)
);
if (ids.length) {
await tx.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
}
});
}
private calculateTokenSize( private calculateTokenSize(
messages: PromptMessage[], messages: PromptMessage[],
model: AvailableModel model: AvailableModel

View File

@ -228,6 +228,61 @@ test('should be able to chat with api', async t => {
Sinon.restore(); Sinon.restore();
}); });
test('should be able to retry with api', async t => {
const { app, storage } = t.context;
Sinon.stub(storage, 'handleRemoteLink').resolvesArg(2);
// normal chat
{
const { id } = await createWorkspace(app, token);
const sessionId = await createCopilotSession(
app,
token,
id,
randomUUID(),
promptName
);
const messageId = await createCopilotMessage(app, token, sessionId);
// chat 2 times
await chatWithText(app, token, sessionId, messageId);
await chatWithText(app, token, sessionId, messageId);
const histories = await getHistories(app, token, { workspaceId: id });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text', 'generate text to text']],
'should be able to list history'
);
}
// retry chat
{
const { id } = await createWorkspace(app, token);
const sessionId = await createCopilotSession(
app,
token,
id,
randomUUID(),
promptName
);
const messageId = await createCopilotMessage(app, token, sessionId);
await chatWithText(app, token, sessionId, messageId);
// retry without message id
await chatWithText(app, token, sessionId);
// should only have 1 message
const histories = await getHistories(app, token, { workspaceId: id });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text']],
'should be able to list history'
);
}
Sinon.restore();
});
test('should reject message from different session', async t => { test('should reject message from different session', async t => {
const { app } = t.context; const { app } = t.context;

View File

@ -362,6 +362,59 @@ test('should save message correctly', async t => {
t.is(s.stashMessages.length, 0, 'should empty stash messages after save'); t.is(s.stashMessages.length, 0, 'should empty stash messages after save');
}); });
test('should revert message correctly', async t => {
const { prompt, session } = t.context;
// init session
let sessionId: string;
{
await prompt.set('prompt', 'model', [
{ role: 'system', content: 'hello {{word}}' },
]);
sessionId = await session.create({
docId: 'test',
workspaceId: 'test',
userId,
promptName: 'prompt',
});
const s = (await session.get(sessionId))!;
const message = (await session.createMessage({
sessionId,
content: 'hello',
}))!;
await s.pushByMessageId(message);
await s.save();
}
// check ChatSession behavior
{
const s = (await session.get(sessionId))!;
s.push({ role: 'assistant', content: 'hi', createdAt: new Date() });
await s.save();
const beforeRevert = s.finish({ word: 'world' });
t.is(beforeRevert.length, 3, 'should have three messages before revert');
s.revertLatestMessage();
const afterRevert = s.finish({ word: 'world' });
t.is(afterRevert.length, 2, 'should remove assistant message after revert');
}
// check database behavior
{
let s = (await session.get(sessionId))!;
const beforeRevert = s.finish({ word: 'world' });
t.is(beforeRevert.length, 3, 'should have three messages before revert');
await session.revertLatestMessage(sessionId);
s = (await session.get(sessionId))!;
const afterRevert = s.finish({ word: 'world' });
t.is(afterRevert.length, 2, 'should remove assistant message after revert');
}
});
// ==================== provider ==================== // ==================== provider ====================
test('should be able to get provider', async t => { test('should be able to get provider', async t => {

View File

@ -196,11 +196,12 @@ export async function chatWithText(
app: INestApplication, app: INestApplication,
userToken: string, userToken: string,
sessionId: string, sessionId: string,
messageId: string, messageId?: string,
prefix = '' prefix = ''
): Promise<string> { ): Promise<string> {
const query = messageId ? `?messageId=${messageId}` : '';
const res = await request(app.getHttpServer()) const res = await request(app.getHttpServer())
.get(`/api/copilot/chat/${sessionId}${prefix}?messageId=${messageId}`) .get(`/api/copilot/chat/${sessionId}${prefix}${query}`)
.auth(userToken, { type: 'bearer' }) .auth(userToken, { type: 'bearer' })
.expect(200); .expect(200);
@ -211,7 +212,7 @@ export async function chatWithTextStream(
app: INestApplication, app: INestApplication,
userToken: string, userToken: string,
sessionId: string, sessionId: string,
messageId: string messageId?: string
) { ) {
return chatWithText(app, userToken, sessionId, messageId, '/stream'); return chatWithText(app, userToken, sessionId, messageId, '/stream');
} }
@ -220,7 +221,7 @@ export async function chatWithImages(
app: INestApplication, app: INestApplication,
userToken: string, userToken: string,
sessionId: string, sessionId: string,
messageId: string messageId?: string
) { ) {
return chatWithText(app, userToken, sessionId, messageId, '/images'); return chatWithText(app, userToken, sessionId, messageId, '/images');
} }