diff --git a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/copilot-client.ts b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/copilot-client.ts index 1dcfc99b38..68e692ff9a 100644 --- a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/copilot-client.ts +++ b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/copilot-client.ts @@ -6,6 +6,7 @@ import { getBaseUrl, getCopilotHistoriesQuery, getCopilotSessionsQuery, + GraphQLError, type GraphQLQuery, type QueryOptions, type RequestOptions, @@ -20,24 +21,53 @@ import { getCurrentStore } from '@toeverything/infra'; type OptionsField = RequestOptions['variables'] extends { options: infer U } ? U : never; +function codeToError(code: number) { + switch (code) { + case 401: + return new UnauthorizedError(); + case 402: + return new PaymentRequiredError(); + default: + return new GeneralNetworkError(); + } +} + +type ErrorType = + | GraphQLError[] + | GraphQLError + | { status: number } + | Error + | string; + +export function resolveError(src: ErrorType) { + if (typeof src === 'string') { + return new GeneralNetworkError(src); + } else if (src instanceof GraphQLError || Array.isArray(src)) { + // only resolve the first error + const error = Array.isArray(src) ? src.at(0) : src; + const code = error?.extensions?.code; + return codeToError(code ?? 500); + } else { + return codeToError(src instanceof Error ? 500 : src.status); + } +} + +export function handleError(src: ErrorType) { + const err = resolveError(src); + if (err instanceof UnauthorizedError) { + getCurrentStore().set(showAILoginRequiredAtom, true); + } + return err; +} + const fetcher = async ( options: QueryOptions ) => { try { return await defaultFetcher(options); } catch (_err) { - const error = Array.isArray(_err) ? _err.at(0) : _err; - const code = error.extensions?.code; - - switch (code) { - case 401: - getCurrentStore().set(showAILoginRequiredAtom, true); - throw new UnauthorizedError(); - case 402: - throw new PaymentRequiredError(); - default: - throw new GeneralNetworkError(); - } + const err = _err as GraphQLError | GraphQLError[] | Error | string; + throw handleError(err); } }; diff --git a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/event-source.ts b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/event-source.ts index 27e1b8511c..eec95bd0f0 100644 --- a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/event-source.ts +++ b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/event-source.ts @@ -1,4 +1,4 @@ -import { GeneralNetworkError } from '@blocksuite/blocks'; +import { handleError } from './copilot-client'; export function delay(ms: number) { return new Promise(resolve => setTimeout(resolve, ms)); @@ -15,6 +15,17 @@ type toTextStreamOptions = { timeout?: number; }; +// todo: may need to extend the error type +const safeParseError = (data: string): { status: number } => { + try { + return JSON.parse(data); + } catch { + return { + status: 500, + }; + } +}; + export function toTextStream( eventSource: EventSource, { timeout }: toTextStreamOptions = {} @@ -52,7 +63,9 @@ export function toTextStream( // if there is data in Error event, it means the server sent an error message // otherwise, the stream is finished successfully if (event.type === 'error' && errorMessage) { - rejectMessagePromise(new GeneralNetworkError(errorMessage)); + // try to parse the error message as a JSON object + const error = safeParseError(errorMessage); + rejectMessagePromise(handleError(error)); } else { resolveMessagePromise(); }