feat: add lora support (#6977)

This commit is contained in:
darkskygit 2024-05-20 05:05:33 +00:00
parent 53ee1801e6
commit f2866f57c9
No known key found for this signature in database
GPG Key ID: 97B7D036B1566E9D
3 changed files with 118 additions and 13 deletions

View File

@ -0,0 +1,13 @@
import { PrismaClient } from '@prisma/client';
import { refreshPrompts } from './utils/prompts';
export class UpdatePrompts1715936358947 {
// do the migration
static async up(db: PrismaClient) {
await refreshPrompts(db);
}
// revert the migration
static async down(_db: PrismaClient) {}
}

View File

@ -83,6 +83,68 @@ export const prompts: Prompt[] = [
model: 'imageutils/rembg',
messages: [],
},
{
name: 'debug:action:fal-sdturbo-clay',
action: 'image',
model: 'fast-turbo-diffusion',
messages: [
{
role: 'user',
content: 'claymation, clay, {{content}}',
params: {
lora: [
'https://models.affine.pro/fal/Clay_AFFiNEAI_SDXL1_CLAYMATION.safetensors',
],
},
},
],
},
{
name: 'debug:action:fal-sdturbo-pixel',
action: 'image',
model: 'fast-turbo-diffusion',
messages: [
{
role: 'user',
content: 'pixel art, very high detail, masterpiece, {{content}}',
params: {
lora: ['https://models.affine.pro/fal/pixel-art-xl-v1.1.safetensors'],
},
},
],
},
{
name: 'debug:action:fal-sdturbo-sketch',
action: 'image',
model: 'fast-turbo-diffusion',
messages: [
{
role: 'user',
content: 'sketch for art examination, {{content}}',
params: {
lora: [
'https://models.affine.pro/fal/sketch_for_art_examination.safetensors',
],
},
},
],
},
{
name: 'debug:action:fal-sdturbo-fantasy',
action: 'image',
model: 'fast-turbo-diffusion',
messages: [
{
role: 'user',
content: 'fansty world, {{content}}',
params: {
lora: [
'https://models.affine.pro/fal/fansty%20world-000020.safetensors',
],
},
},
],
},
{
name: 'Summary',
action: 'Summary',

View File

@ -18,6 +18,12 @@ export type FalResponse = {
images: Array<{ url: string }>;
};
type FalPrompt = {
image_url?: string;
prompt?: string;
lora?: string[];
};
export class FalProvider
implements CopilotTextToImageProvider, CopilotImageToImageProvider
{
@ -56,21 +62,50 @@ export class FalProvider
return this.availableModels.includes(model);
}
private extractError(resp: FalResponse): string {
return Array.isArray(resp.detail)
? resp.detail[0]?.msg
: typeof resp.detail === 'string'
? resp.detail
: '';
}
private extractPrompt(message?: PromptMessage): FalPrompt {
if (!message) throw new Error('Prompt is empty');
const { content, attachments, params } = message;
// prompt attachments require at least one
if (!content && (!Array.isArray(attachments) || !attachments.length)) {
throw new Error('Prompt or Attachments is empty');
}
if (Array.isArray(attachments) && attachments.length > 1) {
throw new Error('Only one attachment is allowed');
}
const lora = (
params?.lora
? Array.isArray(params.lora)
? params.lora
: [params.lora]
: []
).filter(v => typeof v === 'string' && v.length);
return {
image_url: attachments?.[0],
prompt: content || undefined,
lora: lora.length ? lora : undefined,
};
}
// ====== image to image ======
async generateImages(
messages: PromptMessage[],
model: string = this.availableModels[0],
options: CopilotImageOptions = {}
): Promise<Array<string>> {
const { content, attachments } = messages.pop() || {};
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
}
// prompt attachments require at least one
if (!content && (!Array.isArray(attachments) || !attachments.length)) {
throw new Error('Prompt or Attachments is empty');
}
// by default, image prompt assumes there is only one message
const prompt = this.extractPrompt(messages.pop());
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
@ -79,8 +114,7 @@ export class FalProvider
'Content-Type': 'application/json',
},
body: JSON.stringify({
image_url: attachments?.[0],
prompt: content,
...prompt,
sync_mode: true,
seed: options.seed || 42,
enable_safety_checks: false,
@ -89,13 +123,9 @@ export class FalProvider
}).then(res => res.json())) as FalResponse;
if (!data.images?.length) {
const error = Array.isArray(data.detail)
? data.detail[0]?.msg
: typeof data.detail === 'string'
? data.detail
: '';
const error = this.extractError(data);
throw new Error(
error ? `Invalid message: ${error}` : 'No images generated'
error ? `Failed to generate image: ${error}` : 'No images generated'
);
}
return data.images?.map(image => image.url) || [];