fix: choose provider correctly (#7081)

fix no provider error in caption generate action
This commit is contained in:
darkskygit 2024-05-27 09:57:39 +00:00
parent 50dcce891b
commit 5ba9e2e9b1
No known key found for this signature in database
GPG Key ID: 97B7D036B1566E9D
4 changed files with 57 additions and 23 deletions

View File

@ -0,0 +1,8 @@
/*
Warnings:
- Made the column `model` on table `ai_prompts_metadata` required. This step will fail if there are existing NULL values in that column.
*/
-- AlterTable
ALTER TABLE "ai_prompts_metadata" ALTER COLUMN "model" SET NOT NULL;

View File

@ -455,7 +455,7 @@ model AiPrompt {
// an mark identifying which view to use to display the session
// it is only used in the frontend and does not affect the backend
action String? @db.VarChar
model String? @db.VarChar
model String @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
messages AiPromptMessage[]

View File

@ -34,7 +34,11 @@ import { Config } from '../../fundamentals';
import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotCapability } from './types';
import {
CopilotCapability,
CopilotImageToTextProvider,
CopilotTextToTextProvider,
} from './types';
export interface ChatEvent {
type: 'attachment' | 'message' | 'error';
@ -71,7 +75,7 @@ export class CopilotController {
const ret: CheckResult = { model: session.model };
if (messageId) {
if (messageId && typeof messageId === 'string') {
const message = await session.getMessageById(messageId);
ret.hasAttachment =
Array.isArray(message.attachments) && !!message.attachments.length;
@ -80,6 +84,34 @@ export class CopilotController {
return ret;
}
private async chooseTextProvider(
userId: string,
sessionId: string,
messageId?: string
): Promise<CopilotTextToTextProvider | CopilotImageToTextProvider> {
const { hasAttachment, model } = await this.checkRequest(
userId,
sessionId,
messageId
);
let provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
// fallback to image to text if text to text is not available
if (!provider && hasAttachment) {
provider = await this.provider.getProviderByCapability(
CopilotCapability.ImageToText,
model
);
}
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
return provider;
}
private async appendSessionMessage(
sessionId: string,
messageId?: string
@ -139,18 +171,15 @@ export class CopilotController {
@Param('sessionId') sessionId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const provider = await this.chooseTextProvider(
user.id,
sessionId,
messageId
);
const session = await this.appendSessionMessage(sessionId, messageId);
try {
@ -187,18 +216,15 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const provider = await this.chooseTextProvider(
user.id,
sessionId,
messageId
);
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;

View File

@ -42,7 +42,7 @@ export class ChatPrompt {
return new ChatPrompt(
options.name,
options.action || undefined,
options.model || undefined,
options.model,
options.messages
);
}
@ -50,7 +50,7 @@ export class ChatPrompt {
constructor(
public readonly name: string,
public readonly action: string | undefined,
public readonly model: string | undefined,
public readonly model: string,
private readonly messages: PromptMessage[]
) {
this.encoder = getTokenEncoder(model);