change caption method

This commit is contained in:
DepFA 2022-10-10 00:07:52 +01:00 committed by GitHub
parent 0ac3a07eec
commit d6a599ef9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,7 +8,7 @@ import html
import datetime import datetime
from PIL import Image,PngImagePlugin from PIL import Image,PngImagePlugin
from ..images import captionImge from ..images import captionImageOverlay
import numpy as np import numpy as np
import base64 import base64
import json import json
@ -212,6 +212,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
else: else:
images_dir = None images_dir = None
if create_image_every > 0 and save_image_with_stored_embedding:
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@ -279,18 +285,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.current_image = image shared.state.current_image = image
if save_image_with_stored_embedding: if save_image_with_stored_embedding and os.path.exists(last_saved_file):
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
info = PngImagePlugin.PngInfo() info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file) data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embeddingToB64(data)) info.add_text("sd-ti-embedding", embeddingToB64(data))
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] title = "<{}>".format(data.get('name','???'))
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash, footer_left = checkpoint.model_name
embedding.step))] footer_mid = '[{}]'.format(checkpoint.hash)
captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines) footer_right = '[{}]'.format(embedding.step)
captioned_image.save(last_saved_image, "PNG", pnginfo=info)
else: captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
image.save(last_saved_image) image.save(last_saved_image)
last_saved_image += f", prompt: {text}" last_saved_image += f", prompt: {text}"