feat: fetch fal stream correctly (#7141)

This commit is contained in:
darkskygit 2024-06-04 09:30:11 +00:00
parent 01fc1ea835
commit db0837936a
No known key found for this signature in database
GPG Key ID: 97B7D036B1566E9D
5 changed files with 123 additions and 90 deletions

View File

@ -20,6 +20,7 @@
"dependencies": {
"@apollo/server": "^4.10.2",
"@aws-sdk/client-s3": "^3.552.0",
"@fal-ai/serverless-client": "^0.10.2",
"@google-cloud/opentelemetry-cloud-monitoring-exporter": "^0.18.0",
"@google-cloud/opentelemetry-cloud-trace-exporter": "^2.2.0",
"@google-cloud/opentelemetry-resource-util": "^2.2.0",

View File

@ -0,0 +1,13 @@
import { PrismaClient } from '@prisma/client';
import { refreshPrompts } from './utils/prompts';
export class UpdatePrompts1717490700326 {
// do the migration
static async up(db: PrismaClient) {
await refreshPrompts(db);
}
// revert the migration
static async down(_db: PrismaClient) {}
}

View File

@ -86,64 +86,26 @@ export const prompts: Prompt[] = [
{
name: 'debug:action:fal-sdturbo-clay',
action: 'AI image filter clay style',
model: 'fast-sdxl/image-to-image',
messages: [
{
role: 'user',
content: 'claymation, clay, {{content}}',
params: {
lora: [
'https://models.affine.pro/fal/Clay_AFFiNEAI_SDXL1_CLAYMATION.safetensors',
],
},
},
],
model: 'workflows/darkskygit/clay',
messages: [],
},
{
name: 'debug:action:fal-sdturbo-pixel',
action: 'AI image filter pixel style',
model: 'fast-sdxl/image-to-image',
messages: [
{
role: 'user',
content: 'pixel art, very high detail, masterpiece, {{content}}',
params: {
lora: ['https://models.affine.pro/fal/pixel-art-xl-v1.1.safetensors'],
},
},
],
model: 'workflows/darkskygit/pixel-art',
messages: [],
},
{
name: 'debug:action:fal-sdturbo-sketch',
action: 'AI image filter sketch style',
model: 'fast-sdxl/image-to-image',
messages: [
{
role: 'user',
content: 'sketch for art examination, {{content}}',
params: {
lora: [
'https://models.affine.pro/fal/sketch_for_art_examination.safetensors',
],
},
},
],
model: 'workflows/darkskygit/sketch',
messages: [],
},
{
name: 'debug:action:fal-sdturbo-fantasy',
action: 'AI image filter anime style',
model: 'fast-sdxl/image-to-image',
messages: [
{
role: 'user',
content: 'fansty world, {{content}}',
params: {
lora: [
'https://models.affine.pro/fal/fansty%20world-000020.safetensors',
],
},
},
],
model: 'workflows/darkskygit/animie',
messages: [],
},
{
name: 'debug:action:fal-face-to-sticker',

View File

@ -1,5 +1,12 @@
import assert from 'node:assert';
import {
config as falConfig,
stream as falStream,
} from '@fal-ai/serverless-client';
import { Logger } from '@nestjs/common';
import { z } from 'zod';
import {
CopilotCapability,
CopilotChatOptions,
@ -14,21 +21,35 @@ export type FalConfig = {
apiKey: string;
};
export type FalImage = {
url: string;
seed: number;
file_name: string;
};
const FalImageSchema = z
.object({
url: z.string(),
seed: z.number().optional(),
content_type: z.string(),
file_name: z.string(),
file_size: z.number(),
width: z.number(),
height: z.number(),
})
.optional();
export type FalResponse = {
detail: Array<{ msg: string }> | string;
// normal sd/sdxl response
images?: Array<FalImage>;
// special i2i model response
image?: FalImage;
// image2text response
output: string;
};
type FalImage = z.infer<typeof FalImageSchema>;
const FalResponseSchema = z.object({
detail: z
.union([z.array(z.object({ msg: z.string() })), z.string()])
.optional(),
images: z.array(FalImageSchema).optional(),
image: FalImageSchema.optional(),
output: z.string().optional(),
});
type FalResponse = z.infer<typeof FalResponseSchema>;
const FalStreamOutputSchema = z.object({
type: z.literal('output'),
output: FalResponseSchema,
});
type FalPrompt = {
image_url?: string;
@ -55,12 +76,19 @@ export class FalProvider
'face-to-sticker',
'imageutils/rembg',
'fast-sdxl/image-to-image',
'workflows/darkskygit/animie',
'workflows/darkskygit/clay',
'workflows/darkskygit/pixel-art',
'workflows/darkskygit/sketch',
// image to text
'llava-next',
];
private readonly logger = new Logger(FalProvider.name);
constructor(private readonly config: FalConfig) {
assert(FalProvider.assetsConfig(config));
falConfig({ credentials: this.config.apiKey });
}
static assetsConfig(config: FalConfig) {
@ -162,6 +190,37 @@ export class FalProvider
}
}
private async buildResponse(
messages: PromptMessage[],
model: string = this.availableModels[0],
options: CopilotImageOptions = {}
) {
// by default, image prompt assumes there is only one message
const prompt = this.extractPrompt(messages.pop());
if (model.startsWith('workflows/')) {
const stream = await falStream(model, { input: prompt });
const result = FalStreamOutputSchema.parse(await stream.done());
return result.output;
} else {
const response = await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
...prompt,
sync_mode: true,
seed: options.seed || 42,
enable_safety_checks: false,
}),
signal: options.signal,
});
return FalResponseSchema.parse(await response.json());
}
}
// ====== image to image ======
async generateImages(
messages: PromptMessage[],
@ -172,35 +231,32 @@ export class FalProvider
throw new Error(`Invalid model: ${model}`);
}
// by default, image prompt assumes there is only one message
const prompt = this.extractPrompt(messages.pop());
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
...prompt,
sync_mode: true,
seed: options.seed || 42,
enable_safety_checks: false,
}),
signal: options.signal,
}).then(res => res.json())) as FalResponse;
try {
const data = await this.buildResponse(messages, model, options);
if (!data.images?.length && !data.image?.url) {
const error = this.extractError(data);
throw new Error(
error ? `Failed to generate image: ${error}` : 'No images generated'
if (!data.images?.length && !data.image?.url) {
const error = this.extractError(data);
const finalError = error
? `Failed to generate image: ${error}`
: 'No images generated';
this.logger.error(finalError);
throw new Error(finalError);
}
if (data.image?.url) {
return [data.image.url];
}
return (
data.images
?.filter((image): image is NonNullable<FalImage> => !!image)
.map(image => image.url) || []
);
} catch (e: any) {
const error = `Failed to generate image: ${e.message}`;
this.logger.error(error, e.stack);
throw new Error(error);
}
if (data.image?.url) {
return [data.image.url];
}
return data.images?.map(image => image.url) || [];
}
async *generateImagesStream(

View File

@ -658,6 +658,7 @@ __metadata:
"@affine/server-native": "workspace:*"
"@apollo/server": "npm:^4.10.2"
"@aws-sdk/client-s3": "npm:^3.552.0"
"@fal-ai/serverless-client": "npm:^0.10.2"
"@google-cloud/opentelemetry-cloud-monitoring-exporter": "npm:^0.18.0"
"@google-cloud/opentelemetry-cloud-trace-exporter": "npm:^2.2.0"
"@google-cloud/opentelemetry-resource-util": "npm:^2.2.0"
@ -5427,15 +5428,15 @@ __metadata:
languageName: node
linkType: hard
"@fal-ai/serverless-client@npm:^0.10.0":
version: 0.10.0
resolution: "@fal-ai/serverless-client@npm:0.10.0"
"@fal-ai/serverless-client@npm:^0.10.0, @fal-ai/serverless-client@npm:^0.10.2":
version: 0.10.2
resolution: "@fal-ai/serverless-client@npm:0.10.2"
dependencies:
"@msgpack/msgpack": "npm:^3.0.0-beta2"
eventsource-parser: "npm:^1.1.2"
robot3: "npm:^0.4.1"
uuid-random: "npm:^1.3.2"
checksum: 10/46bf17fa08523ad6847c063535458b2f132e2baa0e40c70f09b881112d8aa3fa8d3be085e4f915cfe5106f8ad6abe31e7a8236e05acf7a884f17a78ae24a705b
checksum: 10/d96951b606179ed06d5d14cc31db7c1e55372bfbef34c1bc894c76e338d5e3dde3686848d866e273e033b0190aa730f48fcbcac72449f7047c50319f552d2423
languageName: node
linkType: hard