2022-09-20 01:13:12 +03:00
|
|
|
import os
|
|
|
|
|
2022-09-11 11:31:16 +03:00
|
|
|
import numpy as np
|
|
|
|
from PIL import Image
|
|
|
|
|
2022-09-26 02:22:12 +03:00
|
|
|
import torch
|
2022-09-27 10:44:00 +03:00
|
|
|
import tqdm
|
2022-09-26 02:22:12 +03:00
|
|
|
|
2022-09-11 23:24:24 +03:00
|
|
|
from modules import processing, shared, images, devices
|
2022-09-11 11:31:16 +03:00
|
|
|
from modules.shared import opts
|
|
|
|
import modules.gfpgan_model
|
|
|
|
from modules.ui import plaintext_to_html
|
|
|
|
import modules.codeformer_model
|
2022-09-13 19:23:55 +03:00
|
|
|
import piexif
|
2022-09-14 15:20:05 +03:00
|
|
|
import piexif.helper
|
2022-09-13 19:23:55 +03:00
|
|
|
|
2022-09-11 11:31:16 +03:00
|
|
|
|
|
|
|
# Module-level memo of upscaled images used by run_extras(). Keys combine the
# upscale settings with a small pixel sample of the input; run_extras evicts
# the oldest entries so that at most two results are kept.
cached_images = dict()
|
|
|
|
|
|
|
|
|
2022-09-22 12:11:48 +03:00
|
|
|
def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
    """Apply face restoration and/or upscaling to one image or a batch of uploads.

    Each processed image is saved via ``images.save_image`` and collected in the
    returned list.

    Args:
        extras_mode: ``1`` processes every file in ``image_folder``; any other
            value processes the single ``image``.
        image: PIL image used when ``extras_mode != 1``. ``None`` yields no output.
        image_folder: iterable of uploaded file objects (each must expose
            ``orig_name``) used when ``extras_mode == 1``. ``None`` is treated as
            an empty batch.
        gfpgan_visibility: blend factor for the GFPGAN result; ``0`` skips GFPGAN.
        codeformer_visibility: blend factor for the CodeFormer result; ``0``
            skips CodeFormer.
        codeformer_weight: ``w`` argument passed to CodeFormer.
        upscaling_resize: scale factor; ``1.0`` skips upscaling.
        extras_upscaler_1: index into ``shared.sd_upscalers`` of the main upscaler.
        extras_upscaler_2: index of an optional second upscaler; ``0`` disables it.
        extras_upscaler_2_visibility: blend factor of the second upscaler's result.

    Returns:
        ``(outputs, info_html, '')``. ``outputs`` is the list of processed images.
        ``info_html`` describes the operations applied to the last image and is
        empty when nothing was processed.
    """
    devices.torch_gc()

    source_images = []
    # Original file names (without extension), parallel to source_images.
    # None in single-image mode, where there is no file name.
    source_names = []

    if extras_mode == 1:
        # image_folder may be None when nothing was uploaded; treat it as empty.
        for upload in image_folder or []:
            # Open the upload directly instead of round-tripping through numpy.
            # np.array() of a palette ("P") image yields raw palette indices,
            # which Image.fromarray() then treats as greyscale. The round trip
            # also drops image.info, which is forwarded to save_image below.
            batch_image = Image.open(upload)
            batch_image.load()
            source_images.append(batch_image)
            source_names.append(os.path.splitext(upload.orig_name)[0])
    elif image is not None:
        source_images.append(image)
        source_names.append(None)

    outpath = opts.outdir_samples or opts.outdir_extras_samples

    def upscale(src, scaler_index, resize):
        """Upscale ``src`` with ``shared.sd_upscalers[scaler_index]``, memoized in ``cached_images``."""
        # The cache key is a 10x10 patch from the image centre plus the
        # settings that produced the image. A repeat call with the same image
        # and settings therefore reuses the stored result.
        small = src.crop((src.width // 2, src.height // 2, src.width // 2 + 10, src.height // 2 + 10))
        pixels = tuple(np.array(small).flatten().tolist())
        key = (resize, scaler_index, src.width, src.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels

        cached = cached_images.get(key)
        if cached is None:
            upscaler = shared.sd_upscalers[scaler_index]
            cached = upscaler.upscale(src, src.width * resize, src.height * resize)
            cached_images[key] = cached

        return cached

    outputs = []
    # Initialised before the loop so an empty batch still returns valid info
    # HTML instead of failing on an unbound name.
    info = ""
    for source_image, image_name in zip(source_images, source_names):
        existing_pnginfo = source_image.info or {}
        image = source_image.convert("RGB")
        info = ""

        if gfpgan_visibility > 0:
            restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
            res = Image.fromarray(restored_img)

            if gfpgan_visibility < 1.0:
                res = Image.blend(image, res, gfpgan_visibility)

            info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
            image = res

        if codeformer_visibility > 0:
            restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
            res = Image.fromarray(restored_img)

            if codeformer_visibility < 1.0:
                res = Image.blend(image, res, codeformer_visibility)

            info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
            image = res

        if upscaling_resize != 1.0:
            info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
            res = upscale(image, extras_upscaler_1, upscaling_resize)

            if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
                res2 = upscale(image, extras_upscaler_2, upscaling_resize)
                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
                res = Image.blend(res, res2, extras_upscaler_2_visibility)

            image = res

        # Keep the module-level cache small by evicting the oldest (first-inserted) entries.
        while len(cached_images) > 2:
            del cached_images[next(iter(cached_images))]

        images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
                          forced_filename=image_name if opts.use_original_name_batch else None)

        outputs.append(image)

    return outputs, plaintext_to_html(info), ''
|
2022-09-11 11:31:16 +03:00
|
|
|
|
|
|
|
|
2022-09-17 09:07:07 +03:00
|
|
|
def run_pnginfo(image):
    """Collect the metadata stored in ``image`` for display.

    Args:
        image: a PIL image, or ``None``.

    Returns:
        A tuple ``('', geninfo, html)``:
        - ``geninfo`` is the ``parameters`` info entry if present. Otherwise it
          is the decoded EXIF UserComment, or ``''`` when neither exists.
        - ``html`` renders every remaining metadata entry, or a "Nothing found"
          message when there are none.
        ``None`` input returns ``('', '', '')``.

    The caller's ``image.info`` is not modified.
    """
    if image is None:
        return '', '', ''

    # Work on a copy. Popping and adding keys directly on image.info would
    # silently strip EXIF and other fields from the caller's image.
    items = dict(image.info or {})
    geninfo = ''

    if "exif" in items:
        exif = piexif.load(items["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            # UserComment.load rejects data it cannot decode (e.g. too short or
            # an unrecognised charset prefix); fall back to lenient UTF-8.
            exif_comment = exif_comment.decode('utf8', errors="ignore")

        items['exif comment'] = exif_comment
        geninfo = exif_comment

    # Remove fields that are not shown in the metadata listing.
    for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
                  'loop', 'background', 'timestamp', 'duration']:
        items.pop(field, None)

    # A "parameters" entry takes precedence over the EXIF comment.
    geninfo = items.get('parameters', geninfo)

    info = ''
    for key, text in items.items():
        info += f"""
<div>
<p><b>{plaintext_to_html(str(key))}</b></p>
<p>{plaintext_to_html(str(text))}</p>
</div>
""".strip()+"\n"

    if not info:
        message = "Nothing found in the image."
        info = f"<div><p>{message}</p></div>"

    return '', geninfo, info
|
2022-09-26 02:22:12 +03:00
|
|
|
|
|
|
|
|
2022-09-28 04:09:28 +03:00
|
|
|
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount):
    """Interpolate the weights of two checkpoints and save the merged checkpoint.

    Args:
        primary_model_name: path to an existing checkpoint file, or a model name
            resolved to ``models/<name>.ckpt``.
        secondary_model_name: resolved the same way as ``primary_model_name``.
        interp_method: ``"Weighted Sum"`` (linear) or ``"Sigmoid"`` (smoothstep).
        interp_amount: interpolation factor toward the primary model. ``0`` keeps
            the secondary model's weights and ``1`` takes the primary's.

    Returns:
        A status message containing the path of the saved checkpoint.

    Raises:
        KeyError: if ``interp_method`` is unsupported, or a checkpoint has no
            ``'state_dict'`` entry.
    """
    # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
    def sigmoid(theta0, theta1, alpha):
        alpha = alpha * alpha * (3 - (2 * alpha))
        return theta0 + ((theta1 - theta0) * alpha)

    def resolve(name):
        """Return ``(checkpoint path, display name)`` for a file path or a bare model name."""
        if os.path.exists(name):
            return name, os.path.splitext(os.path.basename(name))[0]
        return os.path.join('models', name + '.ckpt'), name

    def format_amount(value):
        # Round away float noise (1.0 - 0.7 == 0.30000000000000004) so it
        # does not leak into the output file name.
        return str(round(value, 4))

    theta_funcs = {
        "Weighted Sum": weighted_sum,
        "Sigmoid": sigmoid,
    }
    theta_func = theta_funcs[interp_method]

    secondary_model_filename, secondary_model_name = resolve(secondary_model_name)
    primary_model_filename, primary_model_name = resolve(primary_model_name)

    # NOTE: torch.load unpickles the file, so only merge checkpoints from trusted sources.
    print(f"Loading {secondary_model_filename}...")
    model_0 = torch.load(secondary_model_filename, map_location='cpu')

    print(f"Loading {primary_model_filename}...")
    model_1 = torch.load(primary_model_filename, map_location='cpu')

    # The merged weights are written into the secondary model's state dict,
    # and model_0 is what gets saved.
    theta_0 = model_0['state_dict']
    theta_1 = model_1['state_dict']

    print("Merging...")
    for key in tqdm.tqdm(theta_0.keys()):
        if 'model' in key and key in theta_1:
            theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount)

    # Copy 'model' entries that only the primary checkpoint has.
    for key in theta_1.keys():
        if 'model' in key and key not in theta_0:
            theta_0[key] = theta_1[key]

    output_modelname = os.path.join(
        'models',
        primary_model_name + '_' + format_amount(interp_amount) + '-'
        + secondary_model_name + '_' + format_amount(float(1.0) - interp_amount) + '-'
        + interp_method.replace(" ", "_") + '-merged.ckpt',
    )

    print(f"Saving to {output_modelname}...")
    torch.save(model_0, output_modelname)

    print("Checkpoint saved.")
    return "Checkpoint saved to " + output_modelname
|