mirror of
https://github.com/toeverything/AFFiNE.git
synced 2024-11-29 16:03:45 +03:00
feat: add retry support for copilot (#6947)
This commit is contained in:
parent
7e7a4120aa
commit
0076359d6a
@ -82,14 +82,21 @@ export class CopilotController {
|
|||||||
|
|
||||||
private async appendSessionMessage(
|
private async appendSessionMessage(
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
messageId: string
|
messageId?: string
|
||||||
): Promise<ChatSession> {
|
): Promise<ChatSession> {
|
||||||
const session = await this.chatSession.get(sessionId);
|
const session = await this.chatSession.get(sessionId);
|
||||||
if (!session) {
|
if (!session) {
|
||||||
throw new BadRequestException('Session not found');
|
throw new BadRequestException('Session not found');
|
||||||
}
|
}
|
||||||
|
|
||||||
await session.pushByMessageId(messageId);
|
if (messageId) {
|
||||||
|
await session.pushByMessageId(messageId);
|
||||||
|
} else {
|
||||||
|
// revert the latest message generated by the assistant
|
||||||
|
// if messageId is not provided, then we can retry the action
|
||||||
|
await this.chatSession.revertLatestMessage(sessionId);
|
||||||
|
session.revertLatestMessage();
|
||||||
|
}
|
||||||
|
|
||||||
return session;
|
return session;
|
||||||
}
|
}
|
||||||
@ -129,7 +136,6 @@ export class CopilotController {
|
|||||||
@CurrentUser() user: CurrentUser,
|
@CurrentUser() user: CurrentUser,
|
||||||
@Req() req: Request,
|
@Req() req: Request,
|
||||||
@Param('sessionId') sessionId: string,
|
@Param('sessionId') sessionId: string,
|
||||||
@Query('messageId') messageId: string,
|
|
||||||
@Query() params: Record<string, string | string[]>
|
@Query() params: Record<string, string | string[]>
|
||||||
): Promise<string> {
|
): Promise<string> {
|
||||||
const { model } = await this.checkRequest(user.id, sessionId);
|
const { model } = await this.checkRequest(user.id, sessionId);
|
||||||
@ -141,6 +147,9 @@ export class CopilotController {
|
|||||||
throw new InternalServerErrorException('No provider available');
|
throw new InternalServerErrorException('No provider available');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const messageId = Array.isArray(params.messageId)
|
||||||
|
? params.messageId[0]
|
||||||
|
: params.messageId;
|
||||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -174,7 +183,6 @@ export class CopilotController {
|
|||||||
@CurrentUser() user: CurrentUser,
|
@CurrentUser() user: CurrentUser,
|
||||||
@Req() req: Request,
|
@Req() req: Request,
|
||||||
@Param('sessionId') sessionId: string,
|
@Param('sessionId') sessionId: string,
|
||||||
@Query('messageId') messageId: string,
|
|
||||||
@Query() params: Record<string, string>
|
@Query() params: Record<string, string>
|
||||||
): Promise<Observable<ChatEvent>> {
|
): Promise<Observable<ChatEvent>> {
|
||||||
try {
|
try {
|
||||||
@ -187,6 +195,9 @@ export class CopilotController {
|
|||||||
throw new InternalServerErrorException('No provider available');
|
throw new InternalServerErrorException('No provider available');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const messageId = Array.isArray(params.messageId)
|
||||||
|
? params.messageId[0]
|
||||||
|
: params.messageId;
|
||||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||||
delete params.messageId;
|
delete params.messageId;
|
||||||
|
|
||||||
@ -237,10 +248,12 @@ export class CopilotController {
|
|||||||
@CurrentUser() user: CurrentUser,
|
@CurrentUser() user: CurrentUser,
|
||||||
@Req() req: Request,
|
@Req() req: Request,
|
||||||
@Param('sessionId') sessionId: string,
|
@Param('sessionId') sessionId: string,
|
||||||
@Query('messageId') messageId: string,
|
|
||||||
@Query() params: Record<string, string>
|
@Query() params: Record<string, string>
|
||||||
): Promise<Observable<ChatEvent>> {
|
): Promise<Observable<ChatEvent>> {
|
||||||
try {
|
try {
|
||||||
|
const messageId = Array.isArray(params.messageId)
|
||||||
|
? params.messageId[0]
|
||||||
|
: params.messageId;
|
||||||
const { model, hasAttachment } = await this.checkRequest(
|
const { model, hasAttachment } = await this.checkRequest(
|
||||||
user.id,
|
user.id,
|
||||||
sessionId,
|
sessionId,
|
||||||
|
@ -64,6 +64,13 @@ export class ChatSession implements AsyncDisposable {
|
|||||||
this.stashMessageCount += 1;
|
this.stashMessageCount += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
revertLatestMessage() {
|
||||||
|
const messages = this.state.messages;
|
||||||
|
messages.splice(
|
||||||
|
messages.findLastIndex(({ role }) => role === AiPromptRole.user) + 1
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
async getMessageById(messageId: string) {
|
async getMessageById(messageId: string) {
|
||||||
const message = await this.messageCache.get(messageId);
|
const message = await this.messageCache.get(messageId);
|
||||||
if (!message || message.sessionId !== this.state.sessionId) {
|
if (!message || message.sessionId !== this.state.sessionId) {
|
||||||
@ -287,6 +294,29 @@ export class ChatSessionService {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// revert the latest messages not generate by user
|
||||||
|
// after revert, we can retry the action
|
||||||
|
async revertLatestMessage(sessionId: string) {
|
||||||
|
await this.db.$transaction(async tx => {
|
||||||
|
const ids = await tx.aiSessionMessage
|
||||||
|
.findMany({
|
||||||
|
where: { sessionId },
|
||||||
|
select: { id: true, role: true },
|
||||||
|
orderBy: { createdAt: 'asc' },
|
||||||
|
})
|
||||||
|
.then(roles =>
|
||||||
|
roles
|
||||||
|
.slice(
|
||||||
|
roles.findLastIndex(({ role }) => role === AiPromptRole.user) + 1
|
||||||
|
)
|
||||||
|
.map(({ id }) => id)
|
||||||
|
);
|
||||||
|
if (ids.length) {
|
||||||
|
await tx.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
private calculateTokenSize(
|
private calculateTokenSize(
|
||||||
messages: PromptMessage[],
|
messages: PromptMessage[],
|
||||||
model: AvailableModel
|
model: AvailableModel
|
||||||
|
@ -228,6 +228,61 @@ test('should be able to chat with api', async t => {
|
|||||||
Sinon.restore();
|
Sinon.restore();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('should be able to retry with api', async t => {
|
||||||
|
const { app, storage } = t.context;
|
||||||
|
|
||||||
|
Sinon.stub(storage, 'handleRemoteLink').resolvesArg(2);
|
||||||
|
|
||||||
|
// normal chat
|
||||||
|
{
|
||||||
|
const { id } = await createWorkspace(app, token);
|
||||||
|
const sessionId = await createCopilotSession(
|
||||||
|
app,
|
||||||
|
token,
|
||||||
|
id,
|
||||||
|
randomUUID(),
|
||||||
|
promptName
|
||||||
|
);
|
||||||
|
const messageId = await createCopilotMessage(app, token, sessionId);
|
||||||
|
// chat 2 times
|
||||||
|
await chatWithText(app, token, sessionId, messageId);
|
||||||
|
await chatWithText(app, token, sessionId, messageId);
|
||||||
|
|
||||||
|
const histories = await getHistories(app, token, { workspaceId: id });
|
||||||
|
t.deepEqual(
|
||||||
|
histories.map(h => h.messages.map(m => m.content)),
|
||||||
|
[['generate text to text', 'generate text to text']],
|
||||||
|
'should be able to list history'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// retry chat
|
||||||
|
{
|
||||||
|
const { id } = await createWorkspace(app, token);
|
||||||
|
const sessionId = await createCopilotSession(
|
||||||
|
app,
|
||||||
|
token,
|
||||||
|
id,
|
||||||
|
randomUUID(),
|
||||||
|
promptName
|
||||||
|
);
|
||||||
|
const messageId = await createCopilotMessage(app, token, sessionId);
|
||||||
|
await chatWithText(app, token, sessionId, messageId);
|
||||||
|
// retry without message id
|
||||||
|
await chatWithText(app, token, sessionId);
|
||||||
|
|
||||||
|
// should only have 1 message
|
||||||
|
const histories = await getHistories(app, token, { workspaceId: id });
|
||||||
|
t.deepEqual(
|
||||||
|
histories.map(h => h.messages.map(m => m.content)),
|
||||||
|
[['generate text to text']],
|
||||||
|
'should be able to list history'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Sinon.restore();
|
||||||
|
});
|
||||||
|
|
||||||
test('should reject message from different session', async t => {
|
test('should reject message from different session', async t => {
|
||||||
const { app } = t.context;
|
const { app } = t.context;
|
||||||
|
|
||||||
|
@ -362,6 +362,59 @@ test('should save message correctly', async t => {
|
|||||||
t.is(s.stashMessages.length, 0, 'should empty stash messages after save');
|
t.is(s.stashMessages.length, 0, 'should empty stash messages after save');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('should revert message correctly', async t => {
|
||||||
|
const { prompt, session } = t.context;
|
||||||
|
|
||||||
|
// init session
|
||||||
|
let sessionId: string;
|
||||||
|
{
|
||||||
|
await prompt.set('prompt', 'model', [
|
||||||
|
{ role: 'system', content: 'hello {{word}}' },
|
||||||
|
]);
|
||||||
|
|
||||||
|
sessionId = await session.create({
|
||||||
|
docId: 'test',
|
||||||
|
workspaceId: 'test',
|
||||||
|
userId,
|
||||||
|
promptName: 'prompt',
|
||||||
|
});
|
||||||
|
const s = (await session.get(sessionId))!;
|
||||||
|
|
||||||
|
const message = (await session.createMessage({
|
||||||
|
sessionId,
|
||||||
|
content: 'hello',
|
||||||
|
}))!;
|
||||||
|
|
||||||
|
await s.pushByMessageId(message);
|
||||||
|
await s.save();
|
||||||
|
}
|
||||||
|
|
||||||
|
// check ChatSession behavior
|
||||||
|
{
|
||||||
|
const s = (await session.get(sessionId))!;
|
||||||
|
s.push({ role: 'assistant', content: 'hi', createdAt: new Date() });
|
||||||
|
await s.save();
|
||||||
|
const beforeRevert = s.finish({ word: 'world' });
|
||||||
|
t.is(beforeRevert.length, 3, 'should have three messages before revert');
|
||||||
|
|
||||||
|
s.revertLatestMessage();
|
||||||
|
const afterRevert = s.finish({ word: 'world' });
|
||||||
|
t.is(afterRevert.length, 2, 'should remove assistant message after revert');
|
||||||
|
}
|
||||||
|
|
||||||
|
// check database behavior
|
||||||
|
{
|
||||||
|
let s = (await session.get(sessionId))!;
|
||||||
|
const beforeRevert = s.finish({ word: 'world' });
|
||||||
|
t.is(beforeRevert.length, 3, 'should have three messages before revert');
|
||||||
|
|
||||||
|
await session.revertLatestMessage(sessionId);
|
||||||
|
s = (await session.get(sessionId))!;
|
||||||
|
const afterRevert = s.finish({ word: 'world' });
|
||||||
|
t.is(afterRevert.length, 2, 'should remove assistant message after revert');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// ==================== provider ====================
|
// ==================== provider ====================
|
||||||
|
|
||||||
test('should be able to get provider', async t => {
|
test('should be able to get provider', async t => {
|
||||||
|
@ -196,11 +196,12 @@ export async function chatWithText(
|
|||||||
app: INestApplication,
|
app: INestApplication,
|
||||||
userToken: string,
|
userToken: string,
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
messageId: string,
|
messageId?: string,
|
||||||
prefix = ''
|
prefix = ''
|
||||||
): Promise<string> {
|
): Promise<string> {
|
||||||
|
const query = messageId ? `?messageId=${messageId}` : '';
|
||||||
const res = await request(app.getHttpServer())
|
const res = await request(app.getHttpServer())
|
||||||
.get(`/api/copilot/chat/${sessionId}${prefix}?messageId=${messageId}`)
|
.get(`/api/copilot/chat/${sessionId}${prefix}${query}`)
|
||||||
.auth(userToken, { type: 'bearer' })
|
.auth(userToken, { type: 'bearer' })
|
||||||
.expect(200);
|
.expect(200);
|
||||||
|
|
||||||
@ -211,7 +212,7 @@ export async function chatWithTextStream(
|
|||||||
app: INestApplication,
|
app: INestApplication,
|
||||||
userToken: string,
|
userToken: string,
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
messageId: string
|
messageId?: string
|
||||||
) {
|
) {
|
||||||
return chatWithText(app, userToken, sessionId, messageId, '/stream');
|
return chatWithText(app, userToken, sessionId, messageId, '/stream');
|
||||||
}
|
}
|
||||||
@ -220,7 +221,7 @@ export async function chatWithImages(
|
|||||||
app: INestApplication,
|
app: INestApplication,
|
||||||
userToken: string,
|
userToken: string,
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
messageId: string
|
messageId?: string
|
||||||
) {
|
) {
|
||||||
return chatWithText(app, userToken, sessionId, messageId, '/images');
|
return chatWithText(app, userToken, sessionId, messageId, '/images');
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user