mirror of
https://github.com/toeverything/AFFiNE.git
synced 2024-12-23 21:55:02 +03:00
feat: allow undefined new model (#6933)
This commit is contained in:
parent
b036f1b5c9
commit
98e218af93
1
.github/workflows/build-test.yml
vendored
1
.github/workflows/build-test.yml
vendored
@ -351,6 +351,7 @@ jobs:
|
||||
env:
|
||||
CARGO_TARGET_DIR: '${{ github.workspace }}/target'
|
||||
DATABASE_URL: postgresql://affine:affine@localhost:5432/affine
|
||||
COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }}
|
||||
|
||||
- name: Upload server test coverage results
|
||||
uses: codecov/codecov-action@v4
|
||||
|
@ -133,7 +133,7 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const { model } = await this.checkRequest(user.id, sessionId);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
const provider = await this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
@ -179,7 +179,7 @@ export class CopilotController {
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
const { model } = await this.checkRequest(user.id, sessionId);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
const provider = await this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
@ -246,7 +246,7 @@ export class CopilotController {
|
||||
sessionId,
|
||||
messageId
|
||||
);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
const provider = await this.provider.getProviderByCapability(
|
||||
hasAttachment
|
||||
? CopilotCapability.ImageToImage
|
||||
: CopilotCapability.TextToImage,
|
||||
|
@ -50,7 +50,7 @@ export class FalProvider
|
||||
return FalProvider.capabilities;
|
||||
}
|
||||
|
||||
isModelAvailable(model: string): boolean {
|
||||
async isModelAvailable(model: string): Promise<boolean> {
|
||||
return this.availableModels.includes(model);
|
||||
}
|
||||
|
||||
|
@ -48,7 +48,7 @@ export function registerCopilotProvider<
|
||||
const providerConfig = config.plugins.copilot?.[type];
|
||||
if (!provider.assetsConfig(providerConfig as C)) {
|
||||
throw new Error(
|
||||
`Invalid configuration for copilot provider ${type}: ${providerConfig}`
|
||||
`Invalid configuration for copilot provider ${type}: ${JSON.stringify(providerConfig)}`
|
||||
);
|
||||
}
|
||||
const instance = new provider(providerConfig as C);
|
||||
@ -116,11 +116,11 @@ export class CopilotProviderService {
|
||||
return this.cachedProviders.get(provider)!;
|
||||
}
|
||||
|
||||
getProviderByCapability<C extends CopilotCapability>(
|
||||
async getProviderByCapability<C extends CopilotCapability>(
|
||||
capability: C,
|
||||
model?: string,
|
||||
prefer?: CopilotProviderType
|
||||
): CapabilityToCopilotProvider[C] | null {
|
||||
): Promise<CapabilityToCopilotProvider[C] | null> {
|
||||
const providers = PROVIDER_CAPABILITY_MAP.get(capability);
|
||||
if (Array.isArray(providers) && providers.length) {
|
||||
let selectedProvider: CopilotProviderType | undefined = prefer;
|
||||
@ -137,7 +137,7 @@ export class CopilotProviderService {
|
||||
const provider = this.getProvider(selectedProvider);
|
||||
if (provider.getCapabilities().includes(capability)) {
|
||||
if (model) {
|
||||
if (provider.isModelAvailable(model)) {
|
||||
if (await provider.isModelAvailable(model)) {
|
||||
return provider as CapabilityToCopilotProvider[C];
|
||||
}
|
||||
} else {
|
||||
|
@ -1,5 +1,6 @@
|
||||
import assert from 'node:assert';
|
||||
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { ClientOptions, OpenAI } from 'openai';
|
||||
|
||||
import {
|
||||
@ -51,7 +52,9 @@ export class OpenAIProvider
|
||||
'dall-e-3',
|
||||
];
|
||||
|
||||
private readonly logger = new Logger(OpenAIProvider.type);
|
||||
private readonly instance: OpenAI;
|
||||
private existsModels: string[] | undefined;
|
||||
|
||||
constructor(config: ClientOptions) {
|
||||
assert(OpenAIProvider.assetsConfig(config));
|
||||
@ -70,8 +73,20 @@ export class OpenAIProvider
|
||||
return OpenAIProvider.capabilities;
|
||||
}
|
||||
|
||||
isModelAvailable(model: string): boolean {
|
||||
return this.availableModels.includes(model);
|
||||
async isModelAvailable(model: string): Promise<boolean> {
|
||||
const knownModels = this.availableModels.includes(model);
|
||||
if (knownModels) return true;
|
||||
|
||||
if (!this.existsModels) {
|
||||
try {
|
||||
this.existsModels = await this.instance.models
|
||||
.list()
|
||||
.then(({ data }) => data.map(m => m.id));
|
||||
} catch (e) {
|
||||
this.logger.error('Failed to fetch online model list', e);
|
||||
}
|
||||
}
|
||||
return !!this.existsModels?.includes(model);
|
||||
}
|
||||
|
||||
protected chatToGPTMessage(
|
||||
|
@ -172,7 +172,7 @@ export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;
|
||||
export interface CopilotProvider {
|
||||
readonly type: CopilotProviderType;
|
||||
getCapabilities(): CopilotCapability[];
|
||||
isModelAvailable(model: string): boolean;
|
||||
isModelAvailable(model: string): Promise<boolean>;
|
||||
}
|
||||
|
||||
export interface CopilotTextToTextProvider extends CopilotProvider {
|
||||
|
@ -36,7 +36,7 @@ test.beforeEach(async t => {
|
||||
plugins: {
|
||||
copilot: {
|
||||
openai: {
|
||||
apiKey: '1',
|
||||
apiKey: process.env.COPILOT_OPENAI_API_KEY ?? '1',
|
||||
},
|
||||
fal: {
|
||||
apiKey: '1',
|
||||
@ -368,7 +368,9 @@ test('should be able to get provider', async t => {
|
||||
const { provider } = t.context;
|
||||
|
||||
{
|
||||
const p = provider.getProviderByCapability(CopilotCapability.TextToText);
|
||||
const p = await provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText
|
||||
);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'openai',
|
||||
@ -377,7 +379,7 @@ test('should be able to get provider', async t => {
|
||||
}
|
||||
|
||||
{
|
||||
const p = provider.getProviderByCapability(
|
||||
const p = await provider.getProviderByCapability(
|
||||
CopilotCapability.TextToEmbedding
|
||||
);
|
||||
t.is(
|
||||
@ -388,7 +390,9 @@ test('should be able to get provider', async t => {
|
||||
}
|
||||
|
||||
{
|
||||
const p = provider.getProviderByCapability(CopilotCapability.TextToImage);
|
||||
const p = await provider.getProviderByCapability(
|
||||
CopilotCapability.TextToImage
|
||||
);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'fal',
|
||||
@ -397,7 +401,9 @@ test('should be able to get provider', async t => {
|
||||
}
|
||||
|
||||
{
|
||||
const p = provider.getProviderByCapability(CopilotCapability.ImageToImage);
|
||||
const p = await provider.getProviderByCapability(
|
||||
CopilotCapability.ImageToImage
|
||||
);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'fal',
|
||||
@ -406,7 +412,9 @@ test('should be able to get provider', async t => {
|
||||
}
|
||||
|
||||
{
|
||||
const p = provider.getProviderByCapability(CopilotCapability.ImageToText);
|
||||
const p = await provider.getProviderByCapability(
|
||||
CopilotCapability.ImageToText
|
||||
);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'openai',
|
||||
@ -417,7 +425,7 @@ test('should be able to get provider', async t => {
|
||||
// text-to-image use fal by default, but this case can use
|
||||
// model dall-e-3 to select openai provider
|
||||
{
|
||||
const p = provider.getProviderByCapability(
|
||||
const p = await provider.getProviderByCapability(
|
||||
CopilotCapability.TextToImage,
|
||||
'dall-e-3'
|
||||
);
|
||||
@ -427,14 +435,38 @@ test('should be able to get provider', async t => {
|
||||
'should get provider support text-to-image and model'
|
||||
);
|
||||
}
|
||||
|
||||
// gpt4o is not defined now, but it already published by openai
|
||||
// we should check from online api if it is available
|
||||
{
|
||||
const p = await provider.getProviderByCapability(
|
||||
CopilotCapability.ImageToText,
|
||||
'gpt-4o'
|
||||
);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'openai',
|
||||
'should get provider support text-to-image and model'
|
||||
);
|
||||
}
|
||||
|
||||
// if a model is not defined and not available in online api
|
||||
// it should return null
|
||||
{
|
||||
const p = await provider.getProviderByCapability(
|
||||
CopilotCapability.ImageToText,
|
||||
'gpt-4-not-exist'
|
||||
);
|
||||
t.falsy(p, 'should not get provider');
|
||||
}
|
||||
});
|
||||
|
||||
test('should be able to register test provider', async t => {
|
||||
const { provider } = t.context;
|
||||
registerCopilotProvider(MockCopilotTestProvider);
|
||||
|
||||
const assertProvider = (cap: CopilotCapability) => {
|
||||
const p = provider.getProviderByCapability(cap, 'test');
|
||||
const assertProvider = async (cap: CopilotCapability) => {
|
||||
const p = await provider.getProviderByCapability(cap, 'test');
|
||||
t.is(
|
||||
p?.type,
|
||||
CopilotProviderType.Test,
|
||||
@ -442,9 +474,9 @@ test('should be able to register test provider', async t => {
|
||||
);
|
||||
};
|
||||
|
||||
assertProvider(CopilotCapability.TextToText);
|
||||
assertProvider(CopilotCapability.TextToEmbedding);
|
||||
assertProvider(CopilotCapability.TextToImage);
|
||||
assertProvider(CopilotCapability.ImageToImage);
|
||||
assertProvider(CopilotCapability.ImageToText);
|
||||
await assertProvider(CopilotCapability.TextToText);
|
||||
await assertProvider(CopilotCapability.TextToEmbedding);
|
||||
await assertProvider(CopilotCapability.TextToImage);
|
||||
await assertProvider(CopilotCapability.ImageToImage);
|
||||
await assertProvider(CopilotCapability.ImageToText);
|
||||
});
|
||||
|
@ -46,7 +46,7 @@ export class MockCopilotTestProvider
|
||||
return MockCopilotTestProvider.capabilities;
|
||||
}
|
||||
|
||||
override isModelAvailable(model: string): boolean {
|
||||
override async isModelAvailable(model: string): Promise<boolean> {
|
||||
return this.availableModels.includes(model);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user