mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-15 15:13:45 +03:00
change caption method
This commit is contained in:
parent
0ac3a07eec
commit
d6a599ef9b
@ -8,7 +8,7 @@ import html
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from PIL import Image,PngImagePlugin
|
from PIL import Image,PngImagePlugin
|
||||||
from ..images import captionImge
|
from ..images import captionImageOverlay
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
@ -212,6 +212,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
|||||||
else:
|
else:
|
||||||
images_dir = None
|
images_dir = None
|
||||||
|
|
||||||
|
if create_image_every > 0 and save_image_with_stored_embedding:
|
||||||
|
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
|
||||||
|
os.makedirs(images_embeds_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
images_embeds_dir = None
|
||||||
|
|
||||||
cond_model = shared.sd_model.cond_stage_model
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
@ -279,19 +285,25 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
|||||||
|
|
||||||
shared.state.current_image = image
|
shared.state.current_image = image
|
||||||
|
|
||||||
if save_image_with_stored_embedding:
|
if save_image_with_stored_embedding and os.path.exists(last_saved_file):
|
||||||
|
|
||||||
|
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
|
||||||
|
|
||||||
info = PngImagePlugin.PngInfo()
|
info = PngImagePlugin.PngInfo()
|
||||||
data = torch.load(last_saved_file)
|
data = torch.load(last_saved_file)
|
||||||
info.add_text("sd-ti-embedding", embeddingToB64(data))
|
info.add_text("sd-ti-embedding", embeddingToB64(data))
|
||||||
|
|
||||||
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
|
title = "<{}>".format(data.get('name','???'))
|
||||||
checkpoint = sd_models.select_checkpoint()
|
checkpoint = sd_models.select_checkpoint()
|
||||||
post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash,
|
footer_left = checkpoint.model_name
|
||||||
embedding.step))]
|
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||||
captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
|
footer_right = '[{}]'.format(embedding.step)
|
||||||
captioned_image.save(last_saved_image, "PNG", pnginfo=info)
|
|
||||||
else:
|
captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
|
||||||
image.save(last_saved_image)
|
|
||||||
|
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||||
|
|
||||||
|
image.save(last_saved_image)
|
||||||
|
|
||||||
last_saved_image += f", prompt: {text}"
|
last_saved_image += f", prompt: {text}"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user