feat: integrate i18n error for copilot (#7311)

fix PD-1333 CLOUD-42
This commit is contained in:
darkskygit 2024-06-26 13:36:23 +00:00
parent 6b47c6beda
commit aeb666f95e
No known key found for this signature in database
GPG Key ID: 97B7D036B1566E9D
6 changed files with 247 additions and 124 deletions

View File

@ -447,6 +447,16 @@ export const USER_FRIENDLY_ERRORS = {
args: { name: 'string' },
message: ({ name }) => `Copilot prompt ${name} not found.`,
},
copilot_prompt_invalid: {
type: 'invalid_input',
message: `Copilot prompt is invalid.`,
},
copilot_provider_side_error: {
type: 'internal_server_error',
args: { provider: 'string', kind: 'string', message: 'string' },
message: ({ provider, kind, message }) =>
`Provider ${provider} failed with ${kind} error: ${message || 'unknown'}.`,
},
// Quota & Limit errors
blob_quota_exceeded: {

View File

@ -408,6 +408,24 @@ export class CopilotPromptNotFound extends UserFriendlyError {
}
}
export class CopilotPromptInvalid extends UserFriendlyError {
constructor(message?: string) {
super('invalid_input', 'copilot_prompt_invalid', message);
}
}
@ObjectType()
class CopilotProviderSideErrorDataType {
@Field() provider!: string
@Field() kind!: string
@Field() message!: string
}
export class CopilotProviderSideError extends UserFriendlyError {
constructor(args: CopilotProviderSideErrorDataType, message?: string | ((args: CopilotProviderSideErrorDataType) => string)) {
super('internal_server_error', 'copilot_provider_side_error', message, args);
}
}
export class BlobQuotaExceeded extends UserFriendlyError {
constructor(message?: string) {
super('quota_exceeded', 'blob_quota_exceeded', message);
@ -508,6 +526,8 @@ export enum ErrorNames {
COPILOT_ACTION_TAKEN,
COPILOT_MESSAGE_NOT_FOUND,
COPILOT_PROMPT_NOT_FOUND,
COPILOT_PROMPT_INVALID,
COPILOT_PROVIDER_SIDE_ERROR,
BLOB_QUOTA_EXCEEDED,
MEMBER_QUOTA_EXCEEDED,
COPILOT_QUOTA_EXCEEDED,
@ -522,5 +542,5 @@ registerEnumType(ErrorNames, {
export const ErrorDataUnionType = createUnionType({
name: 'ErrorDataUnion',
types: () =>
[UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidPasswordLengthDataType, WorkspaceNotFoundDataType, NotInWorkspaceDataType, WorkspaceAccessDeniedDataType, WorkspaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotPromptNotFoundDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const,
[UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidPasswordLengthDataType, WorkspaceNotFoundDataType, NotInWorkspaceDataType, WorkspaceAccessDeniedDataType, WorkspaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const,
});

View File

@ -4,9 +4,13 @@ import {
config as falConfig,
stream as falStream,
} from '@fal-ai/serverless-client';
import { Logger } from '@nestjs/common';
import { z } from 'zod';
import { z, ZodType } from 'zod';
import {
CopilotPromptInvalid,
CopilotProviderSideError,
UserFriendlyError,
} from '../../../fundamentals';
import {
CopilotCapability,
CopilotChatOptions,
@ -37,7 +41,10 @@ type FalImage = z.infer<typeof FalImageSchema>;
const FalResponseSchema = z.object({
detail: z
.union([z.array(z.object({ msg: z.string() })), z.string()])
.union([
z.array(z.object({ type: z.string(), msg: z.string() })),
z.string(),
])
.optional(),
images: z.array(FalImageSchema).optional(),
image: FalImageSchema.optional(),
@ -84,8 +91,6 @@ export class FalProvider
'llava-next',
];
private readonly logger = new Logger(FalProvider.name);
constructor(private readonly config: FalConfig) {
assert(FalProvider.assetsConfig(config));
falConfig({ credentials: this.config.apiKey });
@ -107,23 +112,15 @@ export class FalProvider
return this.availableModels.includes(model);
}
private extractError(resp: FalResponse): string {
return Array.isArray(resp.detail)
? resp.detail[0]?.msg
: typeof resp.detail === 'string'
? resp.detail
: '';
}
private extractPrompt(message?: PromptMessage): FalPrompt {
if (!message) throw new Error('Prompt is empty');
if (!message) throw new CopilotPromptInvalid('Prompt is empty');
const { content, attachments, params } = message;
// prompt attachments require at least one
if (!content && (!Array.isArray(attachments) || !attachments.length)) {
throw new Error('Prompt or Attachments is empty');
throw new CopilotPromptInvalid('Prompt or Attachments is empty');
}
if (Array.isArray(attachments) && attachments.length > 1) {
throw new Error('Only one attachment is allowed');
throw new CopilotPromptInvalid('Only one attachment is allowed');
}
const lora = (
params?.lora
@ -139,38 +136,91 @@ export class FalProvider
};
}
private extractFalError(
resp: FalResponse,
message?: string
): CopilotProviderSideError {
if (Array.isArray(resp.detail) && resp.detail.length) {
const error = resp.detail[0].msg;
return new CopilotProviderSideError({
provider: this.type,
kind: resp.detail[0].type,
message: message ? `${message}: ${error}` : error,
});
} else if (typeof resp.detail === 'string') {
const error = resp.detail;
return new CopilotProviderSideError({
provider: this.type,
kind: resp.detail,
message: message ? `${message}: ${error}` : error,
});
}
return new CopilotProviderSideError({
provider: this.type,
kind: 'unknown',
message: 'No content generated',
});
}
private handleError(e: any) {
if (e instanceof UserFriendlyError) {
// pass through user friendly errors
return e;
} else {
const error = new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected fal response',
});
return error;
}
}
private parseSchema<R>(schema: ZodType<R>, data: unknown): R {
const result = schema.safeParse(data);
if (result.success) return result.data;
const errors = JSON.stringify(result.error.errors);
throw new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: `Unexpected fal response: ${errors}`,
});
}
async generateText(
messages: PromptMessage[],
model: string = 'llava-next',
options: CopilotChatOptions = {}
): Promise<string> {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
}
// by default, image prompt assumes there is only one message
const prompt = this.extractPrompt(messages.pop());
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
...prompt,
sync_mode: true,
enable_safety_checks: false,
}),
signal: options.signal,
}).then(res => res.json())) as FalResponse;
try {
const response = await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
...prompt,
sync_mode: true,
enable_safety_checks: false,
}),
signal: options.signal,
});
if (!data.output) {
const error = this.extractError(data);
throw new Error(
error ? `Failed to generate image: ${error}` : 'No images generated'
);
const data = this.parseSchema(FalResponseSchema, await response.json());
if (!data.output) {
throw this.extractFalError(data, 'Failed to generate text');
}
return data.output;
} catch (e: any) {
throw this.handleError(e);
}
return data.output;
}
async *generateTextStream(
@ -199,11 +249,8 @@ export class FalProvider
const prompt = this.extractPrompt(messages.pop());
if (model.startsWith('workflows/')) {
const stream = await falStream(model, { input: prompt });
const result = FalStreamOutputSchema.safeParse(await stream.done());
if (result.success) return result.data.output;
const errors = JSON.stringify(result.error.errors);
throw new Error(`Unexpected fal response: ${errors}`);
return this.parseSchema(FalStreamOutputSchema, await stream.done())
.output;
} else {
const response = await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
@ -219,10 +266,7 @@ export class FalProvider
}),
signal: options.signal,
});
const result = FalResponseSchema.safeParse(await response.json());
if (result.success) return result.data;
const errors = JSON.stringify(result.error.errors);
throw new Error(`Unexpected fal response: ${errors}`);
return this.parseSchema(FalResponseSchema, await response.json());
}
}
@ -233,19 +277,14 @@ export class FalProvider
options: CopilotImageOptions = {}
): Promise<Array<string>> {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
}
try {
const data = await this.buildResponse(messages, model, options);
if (!data.images?.length && !data.image?.url) {
const error = this.extractError(data);
const finalError = error
? `Failed to generate image: ${error}`
: 'No images generated';
this.logger.error(finalError);
throw new Error(finalError);
throw this.extractFalError(data, 'Failed to generate images');
}
if (data.image?.url) {
@ -258,9 +297,7 @@ export class FalProvider
.map(image => image.url) || []
);
} catch (e: any) {
const error = `Failed to generate image: ${e.message}`;
this.logger.error(error, e.stack);
throw new Error(error);
throw this.handleError(e);
}
}

View File

@ -1,6 +1,11 @@
import { Logger } from '@nestjs/common';
import { ClientOptions, OpenAI } from 'openai';
import { APIError, ClientOptions, OpenAI } from 'openai';
import {
CopilotPromptInvalid,
CopilotProviderSideError,
UserFriendlyError,
} from '../../../fundamentals';
import {
ChatMessageRole,
CopilotCapability,
@ -80,8 +85,8 @@ export class OpenAIProvider
this.existsModels = await this.instance.models
.list()
.then(({ data }) => data.map(m => m.id));
} catch (e) {
this.logger.error('Failed to fetch online model list', e);
} catch (e: any) {
this.logger.error('Failed to fetch online model list', e.stack);
}
}
return !!this.existsModels?.includes(model);
@ -147,7 +152,7 @@ export class OpenAIProvider
options: CopilotChatOptions;
}) {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
}
if (Array.isArray(messages) && messages.length > 0) {
this.extractOptionFromMessages(messages, options);
@ -164,7 +169,7 @@ export class OpenAIProvider
(!Array.isArray(m.attachments) || !m.attachments.length))
)
) {
throw new Error('Empty message content');
throw new CopilotPromptInvalid('Empty message content');
}
if (
messages.some(
@ -174,7 +179,7 @@ export class OpenAIProvider
!ChatMessageRole.includes(m.role)
)
) {
throw new Error('Invalid message role');
throw new CopilotPromptInvalid('Invalid message role');
}
// json mode need 'json' keyword in content
// ref: https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format
@ -182,42 +187,62 @@ export class OpenAIProvider
options.jsonMode &&
!messages.some(m => m.content.toLowerCase().includes('json'))
) {
throw new Error('Prompt not support json mode');
throw new CopilotPromptInvalid('Prompt not support json mode');
}
} else if (
Array.isArray(embeddings) &&
embeddings.some(e => typeof e !== 'string' || !e || !e.trim())
) {
throw new Error('Invalid embedding');
throw new CopilotPromptInvalid('Invalid embedding');
}
}
private handleError(e: any) {
if (e instanceof UserFriendlyError) {
return e;
} else if (e instanceof APIError) {
return new CopilotProviderSideError({
provider: this.type,
kind: e.type || 'unknown',
message: e.message,
});
} else {
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected openai response',
});
}
}
// ====== text to text ======
async generateText(
messages: PromptMessage[],
model: string = 'gpt-3.5-turbo',
options: CopilotChatOptions = {}
): Promise<string> {
this.checkParams({ messages, model, options });
const result = await this.instance.chat.completions.create(
{
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
response_format: {
type: options.jsonMode ? 'json_object' : 'text',
try {
const result = await this.instance.chat.completions.create(
{
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
response_format: {
type: options.jsonMode ? 'json_object' : 'text',
},
user: options.user,
},
user: options.user,
},
{ signal: options.signal }
);
const { content } = result.choices[0].message;
if (!content) {
throw new Error('Failed to generate text');
{ signal: options.signal }
);
const { content } = result.choices[0].message;
if (!content) throw new Error('Failed to generate text');
return content.trim();
} catch (e: any) {
throw this.handleError(e);
}
return content;
}
async *generateTextStream(
@ -226,32 +251,36 @@ export class OpenAIProvider
options: CopilotChatOptions = {}
): AsyncIterable<string> {
this.checkParams({ messages, model, options });
const result = await this.instance.chat.completions.create(
{
stream: true,
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
response_format: {
type: options.jsonMode ? 'json_object' : 'text',
try {
const result = await this.instance.chat.completions.create(
{
stream: true,
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
response_format: {
type: options.jsonMode ? 'json_object' : 'text',
},
user: options.user,
},
user: options.user,
},
{
signal: options.signal,
}
);
{
signal: options.signal,
}
);
for await (const message of result) {
const content = message.choices[0].delta.content;
if (content) {
yield content;
if (options.signal?.aborted) {
result.controller.abort();
break;
for await (const message of result) {
const content = message.choices[0].delta.content;
if (content) {
yield content;
if (options.signal?.aborted) {
result.controller.abort();
break;
}
}
}
} catch (e: any) {
throw this.handleError(e);
}
}
@ -265,13 +294,17 @@ export class OpenAIProvider
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model, options });
const result = await this.instance.embeddings.create({
model: model,
input: messages,
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
user: options.user,
});
return result.data.map(e => e.embedding);
try {
const result = await this.instance.embeddings.create({
model: model,
input: messages,
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
user: options.user,
});
return result.data.map(e => e.embedding);
} catch (e: any) {
throw this.handleError(e);
}
}
// ====== text to image ======
@ -281,20 +314,25 @@ export class OpenAIProvider
options: CopilotImageOptions = {}
): Promise<Array<string>> {
const { content: prompt } = messages.pop() || {};
if (!prompt) {
throw new Error('Prompt is required');
}
const result = await this.instance.images.generate(
{
prompt,
model,
response_format: 'url',
user: options.user,
},
{ signal: options.signal }
);
if (!prompt) throw new CopilotPromptInvalid('Prompt is required');
return result.data.map(image => image.url).filter((v): v is string => !!v);
try {
const result = await this.instance.images.generate(
{
prompt,
model,
response_format: 'url',
user: options.user,
},
{ signal: options.signal }
);
return result.data
.map(image => image.url)
.filter((v): v is string => !!v);
} catch (e: any) {
throw this.handleError(e);
}
}
async *generateImagesStream(

View File

@ -81,6 +81,12 @@ type CopilotPromptType {
name: String!
}
type CopilotProviderSideErrorDataType {
kind: String!
message: String!
provider: String!
}
type CopilotQuota {
limit: SafeInt
used: SafeInt!
@ -169,7 +175,7 @@ enum EarlyAccessType {
App
}
union ErrorDataUnion = BlobNotFoundDataType | CopilotPromptNotFoundDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInWorkspaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | VersionRejectedDataType | WorkspaceAccessDeniedDataType | WorkspaceNotFoundDataType | WorkspaceOwnerNotFoundDataType
union ErrorDataUnion = BlobNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInWorkspaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | VersionRejectedDataType | WorkspaceAccessDeniedDataType | WorkspaceNotFoundDataType | WorkspaceOwnerNotFoundDataType
enum ErrorNames {
ACCESS_DENIED
@ -182,7 +188,9 @@ enum ErrorNames {
COPILOT_FAILED_TO_CREATE_MESSAGE
COPILOT_FAILED_TO_GENERATE_TEXT
COPILOT_MESSAGE_NOT_FOUND
COPILOT_PROMPT_INVALID
COPILOT_PROMPT_NOT_FOUND
COPILOT_PROVIDER_SIDE_ERROR
COPILOT_QUOTA_EXCEEDED
COPILOT_SESSION_DELETED
COPILOT_SESSION_NOT_FOUND

View File

@ -126,6 +126,13 @@ export interface CopilotPromptType {
name: Scalars['String']['output'];
}
export interface CopilotProviderSideErrorDataType {
__typename?: 'CopilotProviderSideErrorDataType';
kind: Scalars['String']['output'];
message: Scalars['String']['output'];
provider: Scalars['String']['output'];
}
export interface CopilotQuota {
__typename?: 'CopilotQuota';
limit: Maybe<Scalars['SafeInt']['output']>;
@ -218,6 +225,7 @@ export enum EarlyAccessType {
export type ErrorDataUnion =
| BlobNotFoundDataType
| CopilotPromptNotFoundDataType
| CopilotProviderSideErrorDataType
| DocAccessDeniedDataType
| DocHistoryNotFoundDataType
| DocNotFoundDataType
@ -248,7 +256,9 @@ export enum ErrorNames {
COPILOT_FAILED_TO_CREATE_MESSAGE = 'COPILOT_FAILED_TO_CREATE_MESSAGE',
COPILOT_FAILED_TO_GENERATE_TEXT = 'COPILOT_FAILED_TO_GENERATE_TEXT',
COPILOT_MESSAGE_NOT_FOUND = 'COPILOT_MESSAGE_NOT_FOUND',
COPILOT_PROMPT_INVALID = 'COPILOT_PROMPT_INVALID',
COPILOT_PROMPT_NOT_FOUND = 'COPILOT_PROMPT_NOT_FOUND',
COPILOT_PROVIDER_SIDE_ERROR = 'COPILOT_PROVIDER_SIDE_ERROR',
COPILOT_QUOTA_EXCEEDED = 'COPILOT_QUOTA_EXCEEDED',
COPILOT_SESSION_DELETED = 'COPILOT_SESSION_DELETED',
COPILOT_SESSION_NOT_FOUND = 'COPILOT_SESSION_NOT_FOUND',