mirror of
https://github.com/toeverything/AFFiNE.git
synced 2024-11-27 06:33:32 +03:00
feat: fetch fal stream correctly (#7141)
This commit is contained in:
parent
01fc1ea835
commit
db0837936a
@ -20,6 +20,7 @@
|
||||
"dependencies": {
|
||||
"@apollo/server": "^4.10.2",
|
||||
"@aws-sdk/client-s3": "^3.552.0",
|
||||
"@fal-ai/serverless-client": "^0.10.2",
|
||||
"@google-cloud/opentelemetry-cloud-monitoring-exporter": "^0.18.0",
|
||||
"@google-cloud/opentelemetry-cloud-trace-exporter": "^2.2.0",
|
||||
"@google-cloud/opentelemetry-resource-util": "^2.2.0",
|
||||
|
@ -0,0 +1,13 @@
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
import { refreshPrompts } from './utils/prompts';
|
||||
|
||||
export class UpdatePrompts1717490700326 {
|
||||
// do the migration
|
||||
static async up(db: PrismaClient) {
|
||||
await refreshPrompts(db);
|
||||
}
|
||||
|
||||
// revert the migration
|
||||
static async down(_db: PrismaClient) {}
|
||||
}
|
@ -86,64 +86,26 @@ export const prompts: Prompt[] = [
|
||||
{
|
||||
name: 'debug:action:fal-sdturbo-clay',
|
||||
action: 'AI image filter clay style',
|
||||
model: 'fast-sdxl/image-to-image',
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'claymation, clay, {{content}}',
|
||||
params: {
|
||||
lora: [
|
||||
'https://models.affine.pro/fal/Clay_AFFiNEAI_SDXL1_CLAYMATION.safetensors',
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
model: 'workflows/darkskygit/clay',
|
||||
messages: [],
|
||||
},
|
||||
{
|
||||
name: 'debug:action:fal-sdturbo-pixel',
|
||||
action: 'AI image filter pixel style',
|
||||
model: 'fast-sdxl/image-to-image',
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'pixel art, very high detail, masterpiece, {{content}}',
|
||||
params: {
|
||||
lora: ['https://models.affine.pro/fal/pixel-art-xl-v1.1.safetensors'],
|
||||
},
|
||||
},
|
||||
],
|
||||
model: 'workflows/darkskygit/pixel-art',
|
||||
messages: [],
|
||||
},
|
||||
{
|
||||
name: 'debug:action:fal-sdturbo-sketch',
|
||||
action: 'AI image filter sketch style',
|
||||
model: 'fast-sdxl/image-to-image',
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'sketch for art examination, {{content}}',
|
||||
params: {
|
||||
lora: [
|
||||
'https://models.affine.pro/fal/sketch_for_art_examination.safetensors',
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
model: 'workflows/darkskygit/sketch',
|
||||
messages: [],
|
||||
},
|
||||
{
|
||||
name: 'debug:action:fal-sdturbo-fantasy',
|
||||
action: 'AI image filter anime style',
|
||||
model: 'fast-sdxl/image-to-image',
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'fansty world, {{content}}',
|
||||
params: {
|
||||
lora: [
|
||||
'https://models.affine.pro/fal/fansty%20world-000020.safetensors',
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
model: 'workflows/darkskygit/animie',
|
||||
messages: [],
|
||||
},
|
||||
{
|
||||
name: 'debug:action:fal-face-to-sticker',
|
||||
|
@ -1,5 +1,12 @@
|
||||
import assert from 'node:assert';
|
||||
|
||||
import {
|
||||
config as falConfig,
|
||||
stream as falStream,
|
||||
} from '@fal-ai/serverless-client';
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotChatOptions,
|
||||
@ -14,21 +21,35 @@ export type FalConfig = {
|
||||
apiKey: string;
|
||||
};
|
||||
|
||||
export type FalImage = {
|
||||
url: string;
|
||||
seed: number;
|
||||
file_name: string;
|
||||
};
|
||||
const FalImageSchema = z
|
||||
.object({
|
||||
url: z.string(),
|
||||
seed: z.number().optional(),
|
||||
content_type: z.string(),
|
||||
file_name: z.string(),
|
||||
file_size: z.number(),
|
||||
width: z.number(),
|
||||
height: z.number(),
|
||||
})
|
||||
.optional();
|
||||
|
||||
export type FalResponse = {
|
||||
detail: Array<{ msg: string }> | string;
|
||||
// normal sd/sdxl response
|
||||
images?: Array<FalImage>;
|
||||
// special i2i model response
|
||||
image?: FalImage;
|
||||
// image2text response
|
||||
output: string;
|
||||
};
|
||||
type FalImage = z.infer<typeof FalImageSchema>;
|
||||
|
||||
const FalResponseSchema = z.object({
|
||||
detail: z
|
||||
.union([z.array(z.object({ msg: z.string() })), z.string()])
|
||||
.optional(),
|
||||
images: z.array(FalImageSchema).optional(),
|
||||
image: FalImageSchema.optional(),
|
||||
output: z.string().optional(),
|
||||
});
|
||||
|
||||
type FalResponse = z.infer<typeof FalResponseSchema>;
|
||||
|
||||
const FalStreamOutputSchema = z.object({
|
||||
type: z.literal('output'),
|
||||
output: FalResponseSchema,
|
||||
});
|
||||
|
||||
type FalPrompt = {
|
||||
image_url?: string;
|
||||
@ -55,12 +76,19 @@ export class FalProvider
|
||||
'face-to-sticker',
|
||||
'imageutils/rembg',
|
||||
'fast-sdxl/image-to-image',
|
||||
'workflows/darkskygit/animie',
|
||||
'workflows/darkskygit/clay',
|
||||
'workflows/darkskygit/pixel-art',
|
||||
'workflows/darkskygit/sketch',
|
||||
// image to text
|
||||
'llava-next',
|
||||
];
|
||||
|
||||
private readonly logger = new Logger(FalProvider.name);
|
||||
|
||||
constructor(private readonly config: FalConfig) {
|
||||
assert(FalProvider.assetsConfig(config));
|
||||
falConfig({ credentials: this.config.apiKey });
|
||||
}
|
||||
|
||||
static assetsConfig(config: FalConfig) {
|
||||
@ -162,6 +190,37 @@ export class FalProvider
|
||||
}
|
||||
}
|
||||
|
||||
private async buildResponse(
|
||||
messages: PromptMessage[],
|
||||
model: string = this.availableModels[0],
|
||||
options: CopilotImageOptions = {}
|
||||
) {
|
||||
// by default, image prompt assumes there is only one message
|
||||
const prompt = this.extractPrompt(messages.pop());
|
||||
if (model.startsWith('workflows/')) {
|
||||
const stream = await falStream(model, { input: prompt });
|
||||
|
||||
const result = FalStreamOutputSchema.parse(await stream.done());
|
||||
return result.output;
|
||||
} else {
|
||||
const response = await fetch(`https://fal.run/fal-ai/${model}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `key ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...prompt,
|
||||
sync_mode: true,
|
||||
seed: options.seed || 42,
|
||||
enable_safety_checks: false,
|
||||
}),
|
||||
signal: options.signal,
|
||||
});
|
||||
return FalResponseSchema.parse(await response.json());
|
||||
}
|
||||
}
|
||||
|
||||
// ====== image to image ======
|
||||
async generateImages(
|
||||
messages: PromptMessage[],
|
||||
@ -172,35 +231,32 @@ export class FalProvider
|
||||
throw new Error(`Invalid model: ${model}`);
|
||||
}
|
||||
|
||||
// by default, image prompt assumes there is only one message
|
||||
const prompt = this.extractPrompt(messages.pop());
|
||||
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `key ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...prompt,
|
||||
sync_mode: true,
|
||||
seed: options.seed || 42,
|
||||
enable_safety_checks: false,
|
||||
}),
|
||||
signal: options.signal,
|
||||
}).then(res => res.json())) as FalResponse;
|
||||
try {
|
||||
const data = await this.buildResponse(messages, model, options);
|
||||
|
||||
if (!data.images?.length && !data.image?.url) {
|
||||
const error = this.extractError(data);
|
||||
throw new Error(
|
||||
error ? `Failed to generate image: ${error}` : 'No images generated'
|
||||
if (!data.images?.length && !data.image?.url) {
|
||||
const error = this.extractError(data);
|
||||
const finalError = error
|
||||
? `Failed to generate image: ${error}`
|
||||
: 'No images generated';
|
||||
this.logger.error(finalError);
|
||||
throw new Error(finalError);
|
||||
}
|
||||
|
||||
if (data.image?.url) {
|
||||
return [data.image.url];
|
||||
}
|
||||
|
||||
return (
|
||||
data.images
|
||||
?.filter((image): image is NonNullable<FalImage> => !!image)
|
||||
.map(image => image.url) || []
|
||||
);
|
||||
} catch (e: any) {
|
||||
const error = `Failed to generate image: ${e.message}`;
|
||||
this.logger.error(error, e.stack);
|
||||
throw new Error(error);
|
||||
}
|
||||
|
||||
if (data.image?.url) {
|
||||
return [data.image.url];
|
||||
}
|
||||
|
||||
return data.images?.map(image => image.url) || [];
|
||||
}
|
||||
|
||||
async *generateImagesStream(
|
||||
|
@ -658,6 +658,7 @@ __metadata:
|
||||
"@affine/server-native": "workspace:*"
|
||||
"@apollo/server": "npm:^4.10.2"
|
||||
"@aws-sdk/client-s3": "npm:^3.552.0"
|
||||
"@fal-ai/serverless-client": "npm:^0.10.2"
|
||||
"@google-cloud/opentelemetry-cloud-monitoring-exporter": "npm:^0.18.0"
|
||||
"@google-cloud/opentelemetry-cloud-trace-exporter": "npm:^2.2.0"
|
||||
"@google-cloud/opentelemetry-resource-util": "npm:^2.2.0"
|
||||
@ -5427,15 +5428,15 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@fal-ai/serverless-client@npm:^0.10.0":
|
||||
version: 0.10.0
|
||||
resolution: "@fal-ai/serverless-client@npm:0.10.0"
|
||||
"@fal-ai/serverless-client@npm:^0.10.0, @fal-ai/serverless-client@npm:^0.10.2":
|
||||
version: 0.10.2
|
||||
resolution: "@fal-ai/serverless-client@npm:0.10.2"
|
||||
dependencies:
|
||||
"@msgpack/msgpack": "npm:^3.0.0-beta2"
|
||||
eventsource-parser: "npm:^1.1.2"
|
||||
robot3: "npm:^0.4.1"
|
||||
uuid-random: "npm:^1.3.2"
|
||||
checksum: 10/46bf17fa08523ad6847c063535458b2f132e2baa0e40c70f09b881112d8aa3fa8d3be085e4f915cfe5106f8ad6abe31e7a8236e05acf7a884f17a78ae24a705b
|
||||
checksum: 10/d96951b606179ed06d5d14cc31db7c1e55372bfbef34c1bc894c76e338d5e3dde3686848d866e273e033b0190aa730f48fcbcac72449f7047c50319f552d2423
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user