feat: allow undefined new model (#6933)

This commit is contained in:
darkskygit 2024-05-14 13:05:07 +00:00
parent b036f1b5c9
commit 98e218af93
No known key found for this signature in database
GPG Key ID: 97B7D036B1566E9D
8 changed files with 74 additions and 26 deletions

View File

@ -351,6 +351,7 @@ jobs:
env:
CARGO_TARGET_DIR: '${{ github.workspace }}/target'
DATABASE_URL: postgresql://affine:affine@localhost:5432/affine
COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }}
- name: Upload server test coverage results
uses: codecov/codecov-action@v4

View File

@ -133,7 +133,7 @@ export class CopilotController {
@Query() params: Record<string, string | string[]>
): Promise<string> {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
@ -179,7 +179,7 @@ export class CopilotController {
): Promise<Observable<ChatEvent>> {
try {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
@ -246,7 +246,7 @@ export class CopilotController {
sessionId,
messageId
);
const provider = this.provider.getProviderByCapability(
const provider = await this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,

View File

@ -50,7 +50,7 @@ export class FalProvider
return FalProvider.capabilities;
}
isModelAvailable(model: string): boolean {
async isModelAvailable(model: string): Promise<boolean> {
return this.availableModels.includes(model);
}

View File

@ -48,7 +48,7 @@ export function registerCopilotProvider<
const providerConfig = config.plugins.copilot?.[type];
if (!provider.assetsConfig(providerConfig as C)) {
throw new Error(
`Invalid configuration for copilot provider ${type}: ${providerConfig}`
`Invalid configuration for copilot provider ${type}: ${JSON.stringify(providerConfig)}`
);
}
const instance = new provider(providerConfig as C);
@ -116,11 +116,11 @@ export class CopilotProviderService {
return this.cachedProviders.get(provider)!;
}
getProviderByCapability<C extends CopilotCapability>(
async getProviderByCapability<C extends CopilotCapability>(
capability: C,
model?: string,
prefer?: CopilotProviderType
): CapabilityToCopilotProvider[C] | null {
): Promise<CapabilityToCopilotProvider[C] | null> {
const providers = PROVIDER_CAPABILITY_MAP.get(capability);
if (Array.isArray(providers) && providers.length) {
let selectedProvider: CopilotProviderType | undefined = prefer;
@ -137,7 +137,7 @@ export class CopilotProviderService {
const provider = this.getProvider(selectedProvider);
if (provider.getCapabilities().includes(capability)) {
if (model) {
if (provider.isModelAvailable(model)) {
if (await provider.isModelAvailable(model)) {
return provider as CapabilityToCopilotProvider[C];
}
} else {

View File

@ -1,5 +1,6 @@
import assert from 'node:assert';
import { Logger } from '@nestjs/common';
import { ClientOptions, OpenAI } from 'openai';
import {
@ -51,7 +52,9 @@ export class OpenAIProvider
'dall-e-3',
];
private readonly logger = new Logger(OpenAIProvider.type);
private readonly instance: OpenAI;
private existsModels: string[] | undefined;
constructor(config: ClientOptions) {
assert(OpenAIProvider.assetsConfig(config));
@ -70,8 +73,20 @@ export class OpenAIProvider
return OpenAIProvider.capabilities;
}
isModelAvailable(model: string): boolean {
return this.availableModels.includes(model);
async isModelAvailable(model: string): Promise<boolean> {
const knownModels = this.availableModels.includes(model);
if (knownModels) return true;
if (!this.existsModels) {
try {
this.existsModels = await this.instance.models
.list()
.then(({ data }) => data.map(m => m.id));
} catch (e) {
this.logger.error('Failed to fetch online model list', e);
}
}
return !!this.existsModels?.includes(model);
}
protected chatToGPTMessage(

View File

@ -172,7 +172,7 @@ export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;
export interface CopilotProvider {
readonly type: CopilotProviderType;
getCapabilities(): CopilotCapability[];
isModelAvailable(model: string): boolean;
isModelAvailable(model: string): Promise<boolean>;
}
export interface CopilotTextToTextProvider extends CopilotProvider {

View File

@ -36,7 +36,7 @@ test.beforeEach(async t => {
plugins: {
copilot: {
openai: {
apiKey: '1',
apiKey: process.env.COPILOT_OPENAI_API_KEY ?? '1',
},
fal: {
apiKey: '1',
@ -368,7 +368,9 @@ test('should be able to get provider', async t => {
const { provider } = t.context;
{
const p = provider.getProviderByCapability(CopilotCapability.TextToText);
const p = await provider.getProviderByCapability(
CopilotCapability.TextToText
);
t.is(
p?.type.toString(),
'openai',
@ -377,7 +379,7 @@ test('should be able to get provider', async t => {
}
{
const p = provider.getProviderByCapability(
const p = await provider.getProviderByCapability(
CopilotCapability.TextToEmbedding
);
t.is(
@ -388,7 +390,9 @@ test('should be able to get provider', async t => {
}
{
const p = provider.getProviderByCapability(CopilotCapability.TextToImage);
const p = await provider.getProviderByCapability(
CopilotCapability.TextToImage
);
t.is(
p?.type.toString(),
'fal',
@ -397,7 +401,9 @@ test('should be able to get provider', async t => {
}
{
const p = provider.getProviderByCapability(CopilotCapability.ImageToImage);
const p = await provider.getProviderByCapability(
CopilotCapability.ImageToImage
);
t.is(
p?.type.toString(),
'fal',
@ -406,7 +412,9 @@ test('should be able to get provider', async t => {
}
{
const p = provider.getProviderByCapability(CopilotCapability.ImageToText);
const p = await provider.getProviderByCapability(
CopilotCapability.ImageToText
);
t.is(
p?.type.toString(),
'openai',
@ -417,7 +425,7 @@ test('should be able to get provider', async t => {
// text-to-image use fal by default, but this case can use
// model dall-e-3 to select openai provider
{
const p = provider.getProviderByCapability(
const p = await provider.getProviderByCapability(
CopilotCapability.TextToImage,
'dall-e-3'
);
@ -427,14 +435,38 @@ test('should be able to get provider', async t => {
'should get provider support text-to-image and model'
);
}
// gpt4o is not defined now, but it already published by openai
// we should check from online api if it is available
{
const p = await provider.getProviderByCapability(
CopilotCapability.ImageToText,
'gpt-4o'
);
t.is(
p?.type.toString(),
'openai',
'should get provider support text-to-image and model'
);
}
// if a model is not defined and not available in online api
// it should return null
{
const p = await provider.getProviderByCapability(
CopilotCapability.ImageToText,
'gpt-4-not-exist'
);
t.falsy(p, 'should not get provider');
}
});
test('should be able to register test provider', async t => {
const { provider } = t.context;
registerCopilotProvider(MockCopilotTestProvider);
const assertProvider = (cap: CopilotCapability) => {
const p = provider.getProviderByCapability(cap, 'test');
const assertProvider = async (cap: CopilotCapability) => {
const p = await provider.getProviderByCapability(cap, 'test');
t.is(
p?.type,
CopilotProviderType.Test,
@ -442,9 +474,9 @@ test('should be able to register test provider', async t => {
);
};
assertProvider(CopilotCapability.TextToText);
assertProvider(CopilotCapability.TextToEmbedding);
assertProvider(CopilotCapability.TextToImage);
assertProvider(CopilotCapability.ImageToImage);
assertProvider(CopilotCapability.ImageToText);
await assertProvider(CopilotCapability.TextToText);
await assertProvider(CopilotCapability.TextToEmbedding);
await assertProvider(CopilotCapability.TextToImage);
await assertProvider(CopilotCapability.ImageToImage);
await assertProvider(CopilotCapability.ImageToText);
});

View File

@ -46,7 +46,7 @@ export class MockCopilotTestProvider
return MockCopilotTestProvider.capabilities;
}
override isModelAvailable(model: string): boolean {
override async isModelAvailable(model: string): Promise<boolean> {
return this.availableModels.includes(model);
}