mirror of
https://github.com/toeverything/AFFiNE.git
synced 2025-01-01 23:02:58 +03:00
fix: choose provider correctly (#7081)
fix no provider error in caption generate action
This commit is contained in:
parent
50dcce891b
commit
5ba9e2e9b1
@ -0,0 +1,8 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- Made the column `model` on table `ai_prompts_metadata` required. This step will fail if there are existing NULL values in that column.
|
||||
|
||||
*/
|
||||
-- AlterTable
|
||||
ALTER TABLE "ai_prompts_metadata" ALTER COLUMN "model" SET NOT NULL;
|
@ -455,7 +455,7 @@ model AiPrompt {
|
||||
// an mark identifying which view to use to display the session
|
||||
// it is only used in the frontend and does not affect the backend
|
||||
action String? @db.VarChar
|
||||
model String? @db.VarChar
|
||||
model String @db.VarChar
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
|
||||
|
||||
messages AiPromptMessage[]
|
||||
|
@ -34,7 +34,11 @@ import { Config } from '../../fundamentals';
|
||||
import { CopilotProviderService } from './providers';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import { CopilotCapability } from './types';
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotImageToTextProvider,
|
||||
CopilotTextToTextProvider,
|
||||
} from './types';
|
||||
|
||||
export interface ChatEvent {
|
||||
type: 'attachment' | 'message' | 'error';
|
||||
@ -71,7 +75,7 @@ export class CopilotController {
|
||||
|
||||
const ret: CheckResult = { model: session.model };
|
||||
|
||||
if (messageId) {
|
||||
if (messageId && typeof messageId === 'string') {
|
||||
const message = await session.getMessageById(messageId);
|
||||
ret.hasAttachment =
|
||||
Array.isArray(message.attachments) && !!message.attachments.length;
|
||||
@ -80,6 +84,34 @@ export class CopilotController {
|
||||
return ret;
|
||||
}
|
||||
|
||||
private async chooseTextProvider(
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
messageId?: string
|
||||
): Promise<CopilotTextToTextProvider | CopilotImageToTextProvider> {
|
||||
const { hasAttachment, model } = await this.checkRequest(
|
||||
userId,
|
||||
sessionId,
|
||||
messageId
|
||||
);
|
||||
let provider = await this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
// fallback to image to text if text to text is not available
|
||||
if (!provider && hasAttachment) {
|
||||
provider = await this.provider.getProviderByCapability(
|
||||
CopilotCapability.ImageToText,
|
||||
model
|
||||
);
|
||||
}
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
return provider;
|
||||
}
|
||||
|
||||
private async appendSessionMessage(
|
||||
sessionId: string,
|
||||
messageId?: string
|
||||
@ -139,18 +171,15 @@ export class CopilotController {
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const { model } = await this.checkRequest(user.id, sessionId);
|
||||
const provider = await this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const provider = await this.chooseTextProvider(
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId
|
||||
);
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
|
||||
try {
|
||||
@ -187,18 +216,15 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
const { model } = await this.checkRequest(user.id, sessionId);
|
||||
const provider = await this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const provider = await this.chooseTextProvider(
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId
|
||||
);
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
|
@ -42,7 +42,7 @@ export class ChatPrompt {
|
||||
return new ChatPrompt(
|
||||
options.name,
|
||||
options.action || undefined,
|
||||
options.model || undefined,
|
||||
options.model,
|
||||
options.messages
|
||||
);
|
||||
}
|
||||
@ -50,7 +50,7 @@ export class ChatPrompt {
|
||||
constructor(
|
||||
public readonly name: string,
|
||||
public readonly action: string | undefined,
|
||||
public readonly model: string | undefined,
|
||||
public readonly model: string,
|
||||
private readonly messages: PromptMessage[]
|
||||
) {
|
||||
this.encoder = getTokenEncoder(model);
|
||||
|
Loading…
Reference in New Issue
Block a user