diff --git a/configs/webui/webui_streamlit.yaml b/configs/webui/webui_streamlit.yaml index 01bcc19..fa58c34 100644 --- a/configs/webui/webui_streamlit.yaml +++ b/configs/webui/webui_streamlit.yaml @@ -13,6 +13,7 @@ general: GFPGAN_dir: "./src/gfpgan" RealESRGAN_dir: "./src/realesrgan" RealESRGAN_model: "RealESRGAN_x4plus" + LDSR_dir: "./src/latent-diffusion" outdir_txt2img: outputs/txt2img-samples outdir_img2img: outputs/img2img-samples gfpgan_cpu: False diff --git a/scripts/sd_utils.py b/scripts/sd_utils.py index 4425a95..8299ac6 100644 --- a/scripts/sd_utils.py +++ b/scripts/sd_utils.py @@ -1,21 +1,20 @@ -import warnings +# base webui import and utils. +from webui_streamlit import st -import piexif -import piexif.helper + +# streamlit imports + + +#other imports + +import warnings import json -import streamlit as st -from streamlit import StopException - -#streamlit components section -from st_on_hover_tabs import on_hover_tabs - -import base64, cv2 -import os, sys, re, random, datetime, timeit -from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps +import base64 +import os, sys, re, random, datetime, time, math +from PIL import Image, ImageFont, ImageDraw, ImageFilter from PIL.PngImagePlugin import PngInfo from scipy import integrate -import pandas as pd import torch from torchdiffeq import odeint import k_diffusion as K @@ -24,36 +23,28 @@ import mimetypes import numpy as np import pynvml import threading -import time, inspect import torch from torch import autocast from torchvision import transforms import torch.nn as nn +from omegaconf import OmegaConf import yaml -from typing import Union from pathlib import Path -#from tqdm import tqdm from contextlib import nullcontext from einops import rearrange -from omegaconf import OmegaConf -from io import StringIO -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler from ldm.util import instantiate_from_config from retry import retry -import find_noise_for_image -import matched_noise - -# these are for testing txt2vid, should be removed and we should use things from our own code. -from diffusers import StableDiffusionPipeline -from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler - -#will be used for saving and reading a video made by the txt2vid function -import imageio, io - -# we use python-slugify to make the filenames safe for windows and linux, its better than doing it manually -# install it with 'pip install python-slugify' from slugify import slugify +import skimage +import piexif +import piexif.helper +from tqdm import trange + +# Temp imports + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -241,58 +232,6 @@ def load_sd_from_config(ckpt, verbose=False): print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] return sd -# -@retry(tries=5) -def generation_callback(img, i=0): - - try: - if i == 0: - if img['i']: i = img['i'] - except TypeError: - pass - - - if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview: - #print (img) - #print (type(img)) - # The following lines will convert the tensor we got on img to an actual image we can render on the UI. - # It can probably be done in a better way for someone who knows what they're doing. I don't. 
- #print (img,isinstance(img, torch.Tensor)) - if isinstance(img, torch.Tensor): - x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img) - else: - # When using the k Diffusion samplers they return a dict instead of a tensor that look like this: - # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised} - x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img["denoised"]) - - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - - pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) - - # update image on the UI so we can see the progress - st.session_state["preview_image"].image(pil_image) - - # Show a progress bar so we can keep track of the progress even when the image progress is not been shown, - # Dont worry, it doesnt affect the performance. - if st.session_state["generation_mode"] == "txt2img": - percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) - st.session_state["progress_bar_text"].text( - f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%") - else: - if st.session_state["generation_mode"] == "img2img": - round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"]) - percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps)) - st.session_state["progress_bar_text"].text( - f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""") - else: - if st.session_state["generation_mode"] == "txt2vid": - percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) - st.session_state["progress_bar_text"].text( - f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps}" - f"{percent if percent < 100 else 100}%") - - st.session_state["progress_bar"].progress(percent if percent < 100 else 100) - class MemUsageMonitor(threading.Thread): @@ -379,6 +318,204 @@ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return append_zero(sigmas).to(device) +# +# helper fft routines that keep ortho normalization and auto-shift before and after fft +def _fft2(data): + if data.ndim > 2: # has channels + out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:,:,c] + out_fft[:,:,c] = np.fft.fft2(np.fft.fftshift(c_data),norm="ortho") + out_fft[:,:,c] = np.fft.ifftshift(out_fft[:,:,c]) + else: # one channel + out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_fft[:,:] = np.fft.fft2(np.fft.fftshift(data),norm="ortho") + out_fft[:,:] = np.fft.ifftshift(out_fft[:,:]) + + return out_fft + +def _ifft2(data): + if data.ndim > 2: # has channels + out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:,:,c] + out_ifft[:,:,c] = np.fft.ifft2(np.fft.fftshift(c_data),norm="ortho") + out_ifft[:,:,c] = 
np.fft.ifftshift(out_ifft[:,:,c]) + else: # one channel + out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_ifft[:,:] = np.fft.ifft2(np.fft.fftshift(data),norm="ortho") + out_ifft[:,:] = np.fft.ifftshift(out_ifft[:,:]) + + return out_ifft + +def _get_gaussian_window(width, height, std=3.14, mode=0): + + window_scale_x = float(width / min(width, height)) + window_scale_y = float(height / min(width, height)) + + window = np.zeros((width, height)) + x = (np.arange(width) / width * 2. - 1.) * window_scale_x + for y in range(height): + fy = (y / height * 2. - 1.) * window_scale_y + if mode == 0: + window[:, y] = np.exp(-(x**2+fy**2) * std) + else: + window[:, y] = (1/((x**2+1.) * (fy**2+1.))) ** (std/3.14) # hey wait a minute that's not gaussian + + return window + +def _get_masked_window_rgb(np_mask_grey, hardness=1.): + np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3)) + if hardness != 1.: + hardened = np_mask_grey[:] ** hardness + else: + hardened = np_mask_grey[:] + for c in range(3): + np_mask_rgb[:,:,c] = hardened[:] + return np_mask_rgb + +def get_matched_noise(_np_src_image, np_mask_rgb, noise_q, color_variation): + """ + Explanation: + Getting good results in/out-painting with stable diffusion can be challenging. + Although there are simpler effective solutions for in-painting, out-painting can be especially challenging because there is no color data + in the masked area to help prompt the generator. Ideally, even for in-painting we'd like work effectively without that data as well. + Provided here is my take on a potential solution to this problem. + + By taking a fourier transform of the masked src img we get a function that tells us the presence and orientation of each feature scale in the unmasked src. + Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations, and positions increases output coherence + by helping keep features aligned. This technique is applicable to any continuous generation task such as audio or video, each of which can + be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased. For multi-channel data such as color + or stereo sound the "color tone" or histogram of the seed noise can be matched to improve quality (using scikit-image currently) + This method is quite robust and has the added benefit of being fast independently of the size of the out-painted area. + The effects of this method include things like helping the generator integrate the pre-existing view distance and camera angle. + + Carefully managing color and brightness with histogram matching is also essential to achieving good coherence. + + noise_q controls the exponent in the fall-off of the distribution can be any positive number, lower values means higher detail (range > 0, default 1.) + color_variation controls how much freedom is allowed for the colors/palette of the out-painted area (range 0..1, default 0.01) + This code is provided as is under the Unlicense (https://unlicense.org/) + Although you have no obligation to do so, if you found this code helpful please find it in your heart to credit me [parlance-zz]. 
+ + Questions or comments can be sent to parlance@fifth-harmonic.com (https://github.com/parlance-zz/) + This code is part of a new branch of a discord bot I am working on integrating with diffusers (https://github.com/parlance-zz/g-diffuser-bot) + + """ + + global DEBUG_MODE + global TMP_ROOT_PATH + + width = _np_src_image.shape[0] + height = _np_src_image.shape[1] + num_channels = _np_src_image.shape[2] + + np_src_image = _np_src_image[:] * (1. - np_mask_rgb) + np_mask_grey = (np.sum(np_mask_rgb, axis=2)/3.) + np_src_grey = (np.sum(np_src_image, axis=2)/3.) + all_mask = np.ones((width, height), dtype=bool) + img_mask = np_mask_grey > 1e-6 + ref_mask = np_mask_grey < 1e-3 + + windowed_image = _np_src_image * (1.-_get_masked_window_rgb(np_mask_grey)) + windowed_image /= np.max(windowed_image) + windowed_image += np.average(_np_src_image) * np_mask_rgb# / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color + #windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) / (1.-np.average(np_mask_rgb)) # compensate for darkening across the mask transition area + #_save_debug_img(windowed_image, "windowed_src_img") + + src_fft = _fft2(windowed_image) # get feature statistics from masked src img + src_dist = np.absolute(src_fft) + src_phase = src_fft / src_dist + #_save_debug_img(src_dist, "windowed_src_dist") + + noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise + noise_rgb = np.random.random_sample((width, height, num_channels)) + noise_grey = (np.sum(noise_rgb, axis=2)/3.) + noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter + for c in range(num_channels): + noise_rgb[:,:,c] += (1. - color_variation) * noise_grey + + noise_fft = _fft2(noise_rgb) + for c in range(num_channels): + noise_fft[:,:,c] *= noise_window + noise_rgb = np.real(_ifft2(noise_fft)) + shaped_noise_fft = _fft2(noise_rgb) + shaped_noise_fft[:,:,:] = np.absolute(shaped_noise_fft[:,:,:])**2 * (src_dist ** noise_q) * src_phase # perform the actual shaping + + brightness_variation = 0.#color_variation # todo: temporarily tieing brightness variation to color variation for now + contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2. + + # scikit-image is used for histogram matching, very convenient! + shaped_noise = np.real(_ifft2(shaped_noise_fft)) + shaped_noise -= np.min(shaped_noise) + shaped_noise /= np.max(shaped_noise) + shaped_noise[img_mask,:] = skimage.exposure.match_histograms(shaped_noise[img_mask,:]**1., contrast_adjusted_np_src[ref_mask,:], channel_axis=1) + shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb + #_save_debug_img(shaped_noise, "shaped_noise") + + matched_noise = np.zeros((width, height, num_channels)) + matched_noise = shaped_noise[:] + #matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:], _np_src_image[ref_mask,:], channel_axis=1) + #matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb + + #_save_debug_img(matched_noise, "matched_noise") + + """ + todo: + color_variation doesnt have to be a single number, the overall color tone of the out-painted area could be param controlled + """ + + return np.clip(matched_noise, 0., 1.) 
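#
# Usage sketch (illustrative only, not part of the patch): how get_matched_noise() would
# typically be called to build an in/out-painting init image. `init_img` and `mask_img`
# are placeholder PIL images, not identifiers from this repository; `np` and `Image` are
# the module-level imports already present in sd_utils.py.
def _example_matched_noise_init(init_img, mask_img, noise_q=1.0, color_variation=0.01):
    # Both inputs become float arrays in 0..1; the mask is 1.0 where content should be generated.
    np_init = np.asarray(init_img.convert("RGB"), dtype=np.float64) / 255.0
    np_mask = np.asarray(mask_img.convert("RGB"), dtype=np.float64) / 255.0
    # Fill the masked region with noise matched to the unmasked region's feature statistics.
    shaped = get_matched_noise(np_init, np_mask, noise_q, color_variation)
    # Convert back to a PIL image that can be fed to the sampler as the init/seed image.
    return Image.fromarray(np.clip(shaped * 255.0, 0.0, 255.0).astype(np.uint8))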
+ + +# +def find_noise_for_image(model, device, init_image, prompt, steps=200, cond_scale=2.0, verbose=False, normalize=False, generation_callback=None): + image = np.array(init_image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + image = 2. * image - 1. + image = image.to(device) + x = model.get_first_stage_encoding(model.encode_first_stage(image)) + + uncond = model.get_learned_conditioning(['']) + cond = model.get_learned_conditioning([prompt]) + + s_in = x.new_ones([x.shape[0]]) + dnw = K.external.CompVisDenoiser(model) + sigmas = dnw.get_sigmas(steps).flip(0) + + if verbose: + print(sigmas) + + for i in trange(1, len(sigmas)): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2) + cond_in = torch.cat([uncond, cond]) + + c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)] + + if i == 1: + t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2)) + else: + t = dnw.sigma_to_t(sigma_in) + + eps = model.apply_model(x_in * c_in, t, cond=cond_in) + denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2) + + denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale + + if i == 1: + d = (x - denoised) / (2 * sigmas[i]) + else: + d = (x - denoised) / sigmas[i - 1] + + if generation_callback is not None: + generation_callback(x, i) + + dt = sigmas[i] - sigmas[i - 1] + x = x + d * dt + + return x / sigmas[-1] + def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): """Constructs an exponential noise schedule.""" @@ -511,6 +648,190 @@ def load_RealESRGAN(model_name: str): return instance +# +def load_LDSR(checking=False): + model_name = 'model' + yaml_name = 'project' + model_path = os.path.join(defaults.general.LDSR_dir, 'experiments/pretrained_models', model_name + '.ckpt') + yaml_path = os.path.join(defaults.general.LDSR_dir, 'experiments/pretrained_models', yaml_name + '.yaml') + if not os.path.isfile(model_path): + raise Exception("LDSR model not found at path "+model_path) + if not os.path.isfile(yaml_path): + raise Exception("LDSR model not found at path "+yaml_path) + if checking == True: + return True + + sys.path.append(os.path.abspath(defaults.general.LDSR_dir)) + from LDSR import LDSR + LDSRObject = LDSR(model_path, yaml_path) + return LDSRObject + +# +LDSR = None +def try_loading_LDSR(model_name: str,checking=False): + global LDSR + if os.path.exists(defaults.general.LDSR_dir): + try: + LDSR = load_LDSR(checking=True) # TODO: Should try to load both models before giving up + if checking == True: + print("Found LDSR") + return True + print("Latent Diffusion Super Sampling (LDSR) model loaded") + except Exception: + import traceback + print("Error loading LDSR:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + print("LDSR not found at path, please make sure you have cloned the LDSR repo to ./src/latent-diffusion/") + +#try_loading_LDSR('model',checking=True) + +def load_SD_model(): + if defaults.general.optimized: + sd = load_sd_from_config(defaults.general.default_model_path) + li, lo = [], [] + for key, v_ in sd.items(): + sp = key.split('.') + if(sp[0]) == 'model': + if('input_blocks' in sp): + li.append(key) + elif('middle_block' in sp): + li.append(key) + elif('time_embed' in sp): + li.append(key) + else: + lo.append(key) + for key in li: + sd['model1.' + key[6:]] = sd.pop(key) + for key in lo: + sd['model2.' 
+ key[6:]] = sd.pop(key) + + config = OmegaConf.load("optimizedSD/v1-inference.yaml") + device = torch.device(f"cuda:{opt.gpu}") if torch.cuda.is_available() else torch.device("cpu") + + model = instantiate_from_config(config.modelUNet) + _, _ = model.load_state_dict(sd, strict=False) + model.cuda() + model.eval() + model.turbo = defaults.general.optimized_turbo + + modelCS = instantiate_from_config(config.modelCondStage) + _, _ = modelCS.load_state_dict(sd, strict=False) + modelCS.cond_stage_model.device = device + modelCS.eval() + + modelFS = instantiate_from_config(config.modelFirstStage) + _, _ = modelFS.load_state_dict(sd, strict=False) + modelFS.eval() + + del sd + + if not defaults.general.no_half: + model = model.half() + modelCS = modelCS.half() + modelFS = modelFS.half() + return model,modelCS,modelFS,device, config + else: + config = OmegaConf.load(defaults.general.default_model_config) + model = load_model_from_config(config, defaults.general.default_model_path) + + device = torch.device(f"cuda:{opt.gpu}") if torch.cuda.is_available() else torch.device("cpu") + model = (model if defaults.general.no_half else model.half()).to(device) + return model, device,config + +# + +# +def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'): + #get global variables + global_vars = globals() + #check if m is in globals + if unload: + for m in models: + if m in global_vars: + #if it is, delete it + del global_vars[m] + if defaults.general.optimized: + if m == 'model': + del global_vars[m+'FS'] + del global_vars[m+'CS'] + if m =='model': + m='Stable Diffusion' + print('Unloaded ' + m) + if load: + for m in models: + if m not in global_vars or m in global_vars and type(global_vars[m]) == bool: + #if it isn't, load it + if m == 'GFPGAN': + global_vars[m] = load_GFPGAN() + elif m == 'model': + sdLoader = load_sd_from_config() + global_vars[m] = sdLoader[0] + if defaults.general.optimized: + global_vars[m+'CS'] = sdLoader[1] + global_vars[m+'FS'] = sdLoader[2] + elif m == 'RealESRGAN': + global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name) + elif m == 'LDSR': + global_vars[m] = load_LDSR() + if m =='model': + m='Stable Diffusion' + print('Loaded ' + m) + torch_gc() + + +# +@retry(tries=5) +def generation_callback(img, i=0): + + try: + if i == 0: + if img['i']: i = img['i'] + except TypeError: + pass + + if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview: + #print (img) + #print (type(img)) + # The following lines will convert the tensor we got on img to an actual image we can render on the UI. + # It can probably be done in a better way for someone who knows what they're doing. I don't. 
+ #print (img,isinstance(img, torch.Tensor)) + if isinstance(img, torch.Tensor): + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img) + else: + # When using the k Diffusion samplers they return a dict instead of a tensor that look like this: + # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised} + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img["denoised"]) + + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) + + # update image on the UI so we can see the progress + st.session_state["preview_image"].image(pil_image) + + # Show a progress bar so we can keep track of the progress even when the image progress is not been shown, + # Dont worry, it doesnt affect the performance. + if st.session_state["generation_mode"] == "txt2img": + percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) + st.session_state["progress_bar_text"].text( + f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%") + else: + if st.session_state["generation_mode"] == "img2img": + round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"]) + percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps)) + st.session_state["progress_bar_text"].text( + f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""") + else: + if st.session_state["generation_mode"] == "txt2vid": + percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) + st.session_state["progress_bar_text"].text( + f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps}" + f"{percent if percent < 100 else 100}%") + + st.session_state["progress_bar"].progress(percent if percent < 100 else 100) + + prompt_parser = re.compile(""" (?P # capture group for 'prompt' [^:]+ # match one or more non ':' characters @@ -574,44 +895,6 @@ def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed, return current_chunk_speed, previous_chunk_speed, update_preview_frequency -# -def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'): - #get global variables - global_vars = globals() - #check if m is in globals - if unload: - for m in models: - if m in global_vars: - #if it is, delete it - del global_vars[m] - if defaults.general.optimized: - if m == 'model': - del global_vars[m+'FS'] - del global_vars[m+'CS'] - if m =='model': - m='Stable Diffusion' - print('Unloaded ' + m) - if load: - for m in models: - if m not in global_vars or m in global_vars and type(global_vars[m]) == bool: - #if it isn't, load it - if m == 'GFPGAN': - global_vars[m] = load_GFPGAN() - elif m == 'model': - sdLoader = load_sd_from_config() - global_vars[m] = sdLoader[0] - if defaults.general.optimized: - global_vars[m+'CS'] = sdLoader[1] - global_vars[m+'FS'] = sdLoader[2] - elif m == 'RealESRGAN': - global_vars[m] = 
load_RealESRGAN(imgproc_realesrgan_model_name) - elif m == 'LDSR': - global_vars[m] = load_LDSR() - if m =='model': - m='Stable Diffusion' - print('Loaded ' + m) - torch_gc() - def get_font(fontsize): @@ -679,6 +962,71 @@ def seed_to_int(s): n = n >> 32 return n +# +def draw_prompt_matrix(im, width, height, all_prompts): + def wrap(text, d, font, line_length): + lines = [''] + for word in text.split(): + line = f'{lines[-1]} {word}'.strip() + if d.textlength(line, font=font) <= line_length: + lines[-1] = line + else: + lines.append(word) + return '\n'.join(lines) + + def draw_texts(pos, x, y, texts, sizes): + for i, (text, size) in enumerate(zip(texts, sizes)): + active = pos & (1 << i) != 0 + + if not active: + text = '\u0336'.join(text) + '\u0336' + + d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center") + + y += size[1] + line_spacing + + fontsize = (width + height) // 25 + line_spacing = fontsize // 2 + fnt = get_font(fontsize) + color_active = (0, 0, 0) + color_inactive = (153, 153, 153) + + pad_top = height // 4 + pad_left = width * 3 // 4 if len(all_prompts) > 2 else 0 + + cols = im.width // width + rows = im.height // height + + prompts = all_prompts[1:] + + result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white") + result.paste(im, (pad_left, pad_top)) + + d = ImageDraw.Draw(result) + + boundary = math.ceil(len(prompts) / 2) + prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]] + prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]] + + sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]] + sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]] + hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing + ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing + + for col in range(cols): + x = pad_left + width * col + width / 2 + y = pad_top / 2 - hor_text_height / 2 + + draw_texts(col, x, y, prompts_horiz, sizes_hor) + + for row in range(rows): + x = pad_left / 2 + y = pad_top + height * row + height / 2 - ver_text_height / 2 + + draw_texts(row, x, y, prompts_vert, sizes_ver) + + return result + def check_prompt_length(prompt, comments): """this function tests if prompt is too long, and if so, adds a message to comments"""
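The new load_LDSR() resolves its checkpoint and config relative to the LDSR_dir key added to webui_streamlit.yaml. A minimal sketch of the layout it expects (paths taken from load_LDSR() above; the check itself is illustrative and not part of the patch):

import os

LDSR_dir = "./src/latent-diffusion"  # default added to webui_streamlit.yaml in this diff
expected_files = [
    os.path.join(LDSR_dir, "experiments/pretrained_models", "model.ckpt"),   # LDSR checkpoint
    os.path.join(LDSR_dir, "experiments/pretrained_models", "project.yaml"), # LDSR config
]
for path in expected_files:
    status = "found" if os.path.isfile(path) else "missing"
    print(f"{status}: {path}")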