diff --git a/requirements.txt b/requirements.txt index 6f69981..3247e3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -56,11 +56,11 @@ timm==0.6.7 tqdm==4.64.0 tensorboard==2.10.1 - # Other retry==0.9.2 # used by sd_utils python-slugify==6.1.2 # used by sd_utils piexif==1.1.3 # used by sd_utils +pywebview==3.6.3 # used by streamlit_webview.py accelerate==0.12.0 albumentations==0.4.3 diff --git a/scripts/ModelManager.py b/scripts/ModelManager.py index 07abc07..3f86ca1 100644 --- a/scripts/ModelManager.py +++ b/scripts/ModelManager.py @@ -45,12 +45,12 @@ def download_file(file_name, file_path, file_url): raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.") try: - with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token), stream=True) as r: + with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token) if "huggingface.co" in file_url else None, stream=True) as r: r.raise_for_status() with open(os.path.join(file_path, file_name), 'wb') as f: for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"): f.write(chunk) - except HTTPError: + except HTTPError as e: if "huggingface.co" in file_url: if "resolve"in file_url: repo_url = file_url.split("resolve")[0] @@ -59,9 +59,12 @@ def download_file(file_name, file_path, file_url): f"You need to accept the license for the model in order to be able to download it. " f"Please visit {repo_url} and accept the lincense there, then try again to download the model.") + logger.error(e) + else: print(file_name + ' already exists.') + def download_model(models, model_name): """ Download all files from model_list[model_name] """ for file in models[model_name]: diff --git a/scripts/img2txt.py b/scripts/img2txt.py index f9a7f44..72f260a 100644 --- a/scripts/img2txt.py +++ b/scripts/img2txt.py @@ -66,6 +66,9 @@ st.session_state["log"] = [] def load_blip_model(): logger.info("Loading BLIP Model") + if "log" not in st.session_state: + st.session_state["log"] = [] + st.session_state["log"].append("Loading BLIP Model") st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') @@ -232,7 +235,7 @@ def interrogate(image, models): for best in bests: best.sort(key=lambda x: x[1], reverse=True) - # prune to 3 + # prune to 3 best = best[:3] row = [model_name] @@ -326,7 +329,7 @@ def img2txt(): def layout(): #set_page_title("Image-to-Text - Stable Diffusion WebUI") #st.info("Under Construction. :construction_worker:") - # + # if "clip_models" not in server_state: server_state["clip_models"] = {} if "preprocesses" not in server_state: @@ -397,7 +400,9 @@ def layout(): with col2: st.subheader("Image") - refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image instead of the default placeholder.') + image_col1, image_col2 = st.columns([10,25]) + with image_col1: + refresh = st.form_submit_button("Update Preview Image", help='Refresh the image preview to show your uploaded image instead of the default placeholder.') if st.session_state["uploaded_image"]: #print (type(st.session_state["uploaded_image"])) @@ -436,11 +441,12 @@ def layout(): #st.session_state["input_image_preview"].code('', language="") st.image("images/streamlit/img2txt_placeholder.png", clamp=True) - # - # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. 
Needs to be done in CSS or some other way. - # generate_col1.title("") - # generate_col1.title("") - generate_button = st.form_submit_button("Generate!") + with image_col2: + # + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. + # generate_col1.title("") + # generate_col1.title("") + generate_button = st.form_submit_button("Generate!", help="Start interrogating the images to generate a prompt from each of the selected images") if generate_button: # if model, pipe, RealESRGAN or GFPGAN is in st.session_state remove the model and pipe form session_state so that they are reloaded. diff --git a/scripts/nataili/__init__.py b/scripts/nataili/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/nataili/inference/__init__.py b/scripts/nataili/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/nataili/inference/compvis/__init__.py b/scripts/nataili/inference/compvis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/nataili/inference/compvis/img2img.py b/scripts/nataili/inference/compvis/img2img.py new file mode 100644 index 0000000..d55474b --- /dev/null +++ b/scripts/nataili/inference/compvis/img2img.py @@ -0,0 +1,551 @@ +import os +import re +import sys +import k_diffusion as K +import tqdm +from contextlib import contextmanager, nullcontext +import skimage +import numpy as np +import PIL +import torch +from einops import rearrange +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.kdiffusion import CFGMaskedDenoiser, KDiffusionSampler +from ldm.models.diffusion.plms import PLMSSampler +from nataili.util.cache import torch_gc +from nataili.util.check_prompt_length import check_prompt_length +from nataili.util.get_next_sequence_number import get_next_sequence_number +from nataili.util.image_grid import image_grid +from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip +from nataili.util.save_sample import save_sample +from nataili.util.seed_to_int import seed_to_int +from slugify import slugify +import PIL + + +class img2img: + def __init__(self, model, device, output_dir, save_extension='jpg', + output_file_path=False, load_concepts=False, concepts_dir=None, + verify_input=True, auto_cast=True): + self.model = model + self.output_dir = output_dir + self.output_file_path = output_file_path + self.save_extension = save_extension + self.load_concepts = load_concepts + self.concepts_dir = concepts_dir + self.verify_input = verify_input + self.auto_cast = auto_cast + self.device = device + self.comments = [] + self.output_images = [] + self.info = '' + self.stats = '' + self.images = [] + + def create_random_tensors(self, shape, seeds): + xs = [] + for seed in seeds: + torch.manual_seed(seed) + + # randn results depend on device; gpu and cpu get different results for same seed; + # the way I see it, it's better to do this on CPU, so that everyone gets same result; + # but the original script had it like this so i do not dare change it for now because + # it will break everyone's seeds. 
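+            # Seeding each image separately and stacking the per-image tensors afterwards
+            # keeps every image in a batch reproducible from its own seed, independent of
+            # the batch size.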
+ xs.append(torch.randn(shape, device=self.device)) + x = torch.stack(xs) + return x + + def process_prompt_tokens(self, prompt_tokens): + # compviz codebase + tokenizer = self.model.cond_stage_model.tokenizer + text_encoder = self.model.cond_stage_model.transformer + + # diffusers codebase + #tokenizer = pipe.tokenizer + #text_encoder = pipe.text_encoder + + ext = ('.pt', '.bin') + for token_name in prompt_tokens: + embedding_path = os.path.join(self.concepts_dir, token_name) + if os.path.exists(embedding_path): + for files in os.listdir(embedding_path): + if files.endswith(ext): + load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>") + else: + print(f"Concept {token_name} not found in {self.concepts_dir}") + del tokenizer, text_encoder + return + del tokenizer, text_encoder + + def resize_image(self, resize_mode, im, width, height): + LANCZOS = (PIL.Image.Resampling.LANCZOS if hasattr(PIL.Image, 'Resampling') else PIL.Image.LANCZOS) + if resize_mode == "resize": + res = im.resize((width, height), resample=LANCZOS) + elif resize_mode == "crop": + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio > src_ratio else im.width * height // im.height + src_h = height if ratio <= src_ratio else im.height * width // im.width + + resized = im.resize((src_w, src_h), resample=LANCZOS) + res = PIL.Image.new("RGBA", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + else: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio < src_ratio else im.width * height // im.height + src_h = height if ratio >= src_ratio else im.height * width // im.width + + resized = im.resize((src_w, src_h), resample=LANCZOS) + res = PIL.Image.new("RGBA", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + if ratio < src_ratio: + fill_height = height // 2 - src_h // 2 + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) + elif ratio > src_ratio: + fill_width = width // 2 - src_w // 2 + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) + + return res + + # + # helper fft routines that keep ortho normalization and auto-shift before and after fft + def _fft2(self, data): + if data.ndim > 2: # has channels + out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:,:,c] + out_fft[:,:,c] = np.fft.fft2(np.fft.fftshift(c_data),norm="ortho") + out_fft[:,:,c] = np.fft.ifftshift(out_fft[:,:,c]) + else: # one channel + out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_fft[:,:] = np.fft.fft2(np.fft.fftshift(data),norm="ortho") + out_fft[:,:] = np.fft.ifftshift(out_fft[:,:]) + + return out_fft + + def _ifft2(self, data): + if data.ndim > 2: # has channels + out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:,:,c] + out_ifft[:,:,c] = np.fft.ifft2(np.fft.fftshift(c_data),norm="ortho") + out_ifft[:,:,c] = np.fft.ifftshift(out_ifft[:,:,c]) + else: # one channel + out_ifft = np.zeros((data.shape[0], data.shape[1]), 
dtype=np.complex128) + out_ifft[:,:] = np.fft.ifft2(np.fft.fftshift(data),norm="ortho") + out_ifft[:,:] = np.fft.ifftshift(out_ifft[:,:]) + + return out_ifft + + def _get_gaussian_window(self, width, height, std=3.14, mode=0): + + window_scale_x = float(width / min(width, height)) + window_scale_y = float(height / min(width, height)) + + window = np.zeros((width, height)) + x = (np.arange(width) / width * 2. - 1.) * window_scale_x + for y in range(height): + fy = (y / height * 2. - 1.) * window_scale_y + if mode == 0: + window[:, y] = np.exp(-(x**2+fy**2) * std) + else: + window[:, y] = (1/((x**2+1.) * (fy**2+1.))) ** (std/3.14) # hey wait a minute that's not gaussian + + return window + + def _get_masked_window_rgb(self, np_mask_grey, hardness=1.): + np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3)) + if hardness != 1.: + hardened = np_mask_grey[:] ** hardness + else: + hardened = np_mask_grey[:] + for c in range(3): + np_mask_rgb[:,:,c] = hardened[:] + return np_mask_rgb + + def get_matched_noise(self, _np_src_image, np_mask_rgb, noise_q, color_variation): + """ + Explanation: + Getting good results in/out-painting with stable diffusion can be challenging. + Although there are simpler effective solutions for in-painting, out-painting can be especially challenging because there is no color data + in the masked area to help prompt the generator. Ideally, even for in-painting we'd like work effectively without that data as well. + Provided here is my take on a potential solution to this problem. + + By taking a fourier transform of the masked src img we get a function that tells us the presence and orientation of each feature scale in the unmasked src. + Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations, and positions increases output coherence + by helping keep features aligned. This technique is applicable to any continuous generation task such as audio or video, each of which can + be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased. For multi-channel data such as color + or stereo sound the "color tone" or histogram of the seed noise can be matched to improve quality (using scikit-image currently) + This method is quite robust and has the added benefit of being fast independently of the size of the out-painted area. + The effects of this method include things like helping the generator integrate the pre-existing view distance and camera angle. + + Carefully managing color and brightness with histogram matching is also essential to achieving good coherence. + + noise_q controls the exponent in the fall-off of the distribution can be any positive number, lower values means higher detail (range > 0, default 1.) + color_variation controls how much freedom is allowed for the colors/palette of the out-painted area (range 0..1, default 0.01) + This code is provided as is under the Unlicense (https://unlicense.org/) + Although you have no obligation to do so, if you found this code helpful please find it in your heart to credit me [parlance-zz]. 
+ + Questions or comments can be sent to parlance@fifth-harmonic.com (https://github.com/parlance-zz/) + This code is part of a new branch of a discord bot I am working on integrating with diffusers (https://github.com/parlance-zz/g-diffuser-bot) + + """ + + global DEBUG_MODE + global TMP_ROOT_PATH + + width = _np_src_image.shape[0] + height = _np_src_image.shape[1] + num_channels = _np_src_image.shape[2] + + np_src_image = _np_src_image[:] * (1. - np_mask_rgb) + np_mask_grey = (np.sum(np_mask_rgb, axis=2)/3.) + np_src_grey = (np.sum(np_src_image, axis=2)/3.) + all_mask = np.ones((width, height), dtype=bool) + img_mask = np_mask_grey > 1e-6 + ref_mask = np_mask_grey < 1e-3 + + windowed_image = _np_src_image * (1.-self._get_masked_window_rgb(np_mask_grey)) + windowed_image /= np.max(windowed_image) + windowed_image += np.average(_np_src_image) * np_mask_rgb# / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color + #windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) / (1.-np.average(np_mask_rgb)) # compensate for darkening across the mask transition area + #_save_debug_img(windowed_image, "windowed_src_img") + + src_fft = self._fft2(windowed_image) # get feature statistics from masked src img + src_dist = np.absolute(src_fft) + src_phase = src_fft / src_dist + #_save_debug_img(src_dist, "windowed_src_dist") + + noise_window = self._get_gaussian_window(width, height, mode=1) # start with simple gaussian noise + noise_rgb = np.random.random_sample((width, height, num_channels)) + noise_grey = (np.sum(noise_rgb, axis=2)/3.) + noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter + for c in range(num_channels): + noise_rgb[:,:,c] += (1. - color_variation) * noise_grey + + noise_fft = self._fft2(noise_rgb) + for c in range(num_channels): + noise_fft[:,:,c] *= noise_window + noise_rgb = np.real(self._ifft2(noise_fft)) + shaped_noise_fft = self._fft2(noise_rgb) + shaped_noise_fft[:,:,:] = np.absolute(shaped_noise_fft[:,:,:])**2 * (src_dist ** noise_q) * src_phase # perform the actual shaping + + brightness_variation = 0.#color_variation # todo: temporarily tieing brightness variation to color variation for now + contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2. + + # scikit-image is used for histogram matching, very convenient! + shaped_noise = np.real(self._ifft2(shaped_noise_fft)) + shaped_noise -= np.min(shaped_noise) + shaped_noise /= np.max(shaped_noise) + shaped_noise[img_mask,:] = skimage.exposure.match_histograms(shaped_noise[img_mask,:]**1., contrast_adjusted_np_src[ref_mask,:], channel_axis=1) + shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb + #_save_debug_img(shaped_noise, "shaped_noise") + + matched_noise = np.zeros((width, height, num_channels)) + matched_noise = shaped_noise[:] + #matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:], _np_src_image[ref_mask,:], channel_axis=1) + #matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb + + #_save_debug_img(matched_noise, "matched_noise") + + """ + todo: + color_variation doesnt have to be a single number, the overall color tone of the out-painted area could be param controlled + """ + + return np.clip(matched_noise, 0., 1.) 
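+    # Usage note for get_matched_noise: generate() below calls it with float64 RGB arrays
+    # scaled to [0, 1] (np_init from the init image, np_mask_rgb derived from the mask),
+    # blends the returned noise with the source image, and converts the result back to a
+    # PIL image via PIL.Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB").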
+ + def find_noise_for_image(self, model, device, init_image, prompt, steps=200, cond_scale=2.0, verbose=False, normalize=False, generation_callback=None): + image = np.array(init_image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + image = 2. * image - 1. + image = image.to(device) + x = model.get_first_stage_encoding(model.encode_first_stage(image)) + + uncond = model.get_learned_conditioning(['']) + cond = model.get_learned_conditioning([prompt]) + + s_in = x.new_ones([x.shape[0]]) + dnw = K.external.CompVisDenoiser(model) + sigmas = dnw.get_sigmas(steps).flip(0) + + if verbose: + print(sigmas) + + for i in tqdm.trange(1, len(sigmas)): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2) + cond_in = torch.cat([uncond, cond]) + + c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)] + + if i == 1: + t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2)) + else: + t = dnw.sigma_to_t(sigma_in) + + eps = model.apply_model(x_in * c_in, t, cond=cond_in) + denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2) + + denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale + + if i == 1: + d = (x - denoised) / (2 * sigmas[i]) + else: + d = (x - denoised) / sigmas[i - 1] + + dt = sigmas[i] - sigmas[i - 1] + x = x + d * dt + + return x / sigmas[-1] + + def generate(self, prompt: str, init_img=None, init_mask=None, mask_mode='mask', resize_mode='resize', noise_mode='seed', + denoising_strength:float=0.8, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None, + height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta:float = 0.0): + seed = seed_to_int(seed) + image_dict = { + "seed": seed + } + # Init image is assumed to be a PIL image + init_img = self.resize_image('resize', init_img, width, height) + if sampler_name == 'PLMS': + sampler = PLMSSampler(self.model) + elif sampler_name == 'DDIM': + sampler = DDIMSampler(self.model) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(self.model,'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(self.model,'dpm_2') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(self.model,'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(self.model,'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(self.model,'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(self.model,'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + torch_gc() + def process_init_mask(init_mask: PIL.Image): + if init_mask.mode == "RGBA": + init_mask = init_mask.convert('RGBA') + background = PIL.Image.new('RGBA', init_mask.size, (0, 0, 0)) + init_mask = PIL.Image.alpha_composite(background, init_mask) + init_mask = init_mask.convert('RGB') + return init_mask + + if mask_mode == "mask": + if init_mask: + init_mask = process_init_mask(init_mask) + elif mask_mode == "invert": + if init_mask: + init_mask = process_init_mask(init_mask) + init_mask = PIL.ImageOps.invert(init_mask) + elif mask_mode == "alpha": + init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1') + init_mask = init_img_transparency + init_mask = init_mask.convert("RGB") + init_mask = self.resize_image(resize_mode, init_mask, width, height) + init_mask = init_mask.convert("RGB") + + assert 0. 
<= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(denoising_strength * ddim_steps) + + if init_mask is not None and (noise_mode == "matched" or noise_mode == "find_and_matched") and init_img is not None: + noise_q = 0.99 + color_variation = 0.0 + mask_blend_factor = 1.0 + + np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing + np_mask_rgb = 1. - (np.asarray(PIL.ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64) + np_mask_rgb -= np.min(np_mask_rgb) + np_mask_rgb /= np.max(np_mask_rgb) + np_mask_rgb = 1. - np_mask_rgb + np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64) + blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) + blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) + #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants + #np_mask_rgb = np_mask_rgb + blurred + np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.) + np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.) + + noise_rgb = self.get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation) + blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor) + noised = noise_rgb[:] + blend_mask_rgb **= (2.) + noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb + + np_mask_grey = np.sum(np_mask_rgb, axis=2)/3. + ref_mask = np_mask_grey < 1e-3 + + all_mask = np.ones((height, width), dtype=bool) + noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1) + + init_img = PIL.Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB") + + def init(): + image = init_img.convert('RGB') + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + + mask_channel = None + if init_mask: + alpha = self.resize_image(resize_mode, init_mask, width // 8, height // 8) + mask_channel = alpha.split()[-1] + + mask = None + if mask_channel is not None: + mask = np.array(mask_channel).astype(np.float32) / 255.0 + mask = (1 - mask) + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) + mask = torch.from_numpy(mask).to(self.model.device) + + init_image = 2. * image - 1. 
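+            # pixel values are now mapped from [0, 1] to [-1, 1], the input range expected
+            # by the first-stage (VAE) encoder below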
+ init_image = init_image.to(self.model.device) + init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space + + return init_latent, mask, + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + t_enc_steps = t_enc + obliterate = False + if ddim_steps == t_enc_steps: + t_enc_steps = t_enc_steps - 1 + obliterate = True + + if sampler_name != 'DDIM': + x0, z_mask = init_data + + sigmas = sampler.model_wrap.get_sigmas(ddim_steps) + noise = x * sigmas[ddim_steps - t_enc_steps - 1] + + xi = x0 + noise + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=xi.device) + xi = (z_mask * noise) + ((1-z_mask) * xi) + + sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] + model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, + extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, + 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False) + else: + + x0, z_mask = init_data + + sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) + z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(self.model.device)) + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=z_enc.device) + z_enc = (z_mask * random) + ((1-z_mask) * z_enc) + + # decode it + samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, + z_mask=z_mask, x0=x0) + return samples_ddim + + torch_gc() + + if self.load_concepts and self.concepts_dir is not None: + prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt) + if prompt_tokens: + self.process_prompt_tokens(prompt_tokens) + + os.makedirs(self.output_dir, exist_ok=True) + + sample_path = os.path.join(self.output_dir, "samples") + os.makedirs(sample_path, exist_ok=True) + + if self.verify_input: + try: + check_prompt_length(self.model, prompt, self.comments) + except: + import traceback + print("Error verifying input:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + all_prompts = batch_size * n_iter * [prompt] + all_seeds = [seed + x for x in range(len(all_prompts))] + + precision_scope = torch.autocast if self.auto_cast else nullcontext + + with torch.no_grad(), precision_scope("cuda"): + for n in range(n_iter): + print(f"Iteration: {n+1}/{n_iter}") + prompts = all_prompts[n * batch_size:(n + 1) * batch_size] + seeds = all_seeds[n * batch_size:(n + 1) * batch_size] + + uc = self.model.get_learned_conditioning(len(prompts) * ['']) + + if isinstance(prompts, tuple): + prompts = list(prompts) + + c = self.model.get_learned_conditioning(prompts) + + opt_C = 4 + opt_f = 8 + shape = [opt_C, height // opt_f, width // opt_f] + + x = self.create_random_tensors(shape, seeds=seeds) + init_data = init() + samples_ddim = sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) + + x_samples_ddim = self.model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for i, x_sample in enumerate(x_samples_ddim): + sanitized_prompt = slugify(prompts[i]) + full_path = os.path.join(os.getcwd(), sample_path) + sample_path_i = sample_path + base_count = 
get_next_sequence_number(sample_path_i) + filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)] + + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = x_sample.astype(np.uint8) + image = PIL.Image.fromarray(x_sample) + image_dict['image'] = image + self.images.append(image_dict) + + if save_individual_images: + path = os.path.join(sample_path, filename + '.' + self.save_extension) + success = save_sample(image, filename, sample_path_i, self.save_extension) + if success: + if self.output_file_path: + self.output_images.append(path) + else: + self.output_images.append(image) + else: + return + + self.info = f""" + {prompt} + Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed} + """.strip() + self.stats = f''' + ''' + + for comment in self.comments: + self.info += "\n\n" + comment + + torch_gc() + + del sampler + + return diff --git a/scripts/nataili/inference/compvis/txt2img.py b/scripts/nataili/inference/compvis/txt2img.py new file mode 100644 index 0000000..cb11c67 --- /dev/null +++ b/scripts/nataili/inference/compvis/txt2img.py @@ -0,0 +1,201 @@ +import os +import re +import sys +from contextlib import contextmanager, nullcontext + +import numpy as np +import PIL +import torch +from einops import rearrange +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.kdiffusion import KDiffusionSampler +from ldm.models.diffusion.plms import PLMSSampler +from nataili.util.cache import torch_gc +from nataili.util.check_prompt_length import check_prompt_length +from nataili.util.get_next_sequence_number import get_next_sequence_number +from nataili.util.image_grid import image_grid +from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip +from nataili.util.save_sample import save_sample +from nataili.util.seed_to_int import seed_to_int +from slugify import slugify + + +class txt2img: + def __init__(self, model, device, output_dir, save_extension='jpg', + output_file_path=False, load_concepts=False, concepts_dir=None, + verify_input=True, auto_cast=True): + self.model = model + self.output_dir = output_dir + self.output_file_path = output_file_path + self.save_extension = save_extension + self.load_concepts = load_concepts + self.concepts_dir = concepts_dir + self.verify_input = verify_input + self.auto_cast = auto_cast + self.device = device + self.comments = [] + self.output_images = [] + self.info = '' + self.stats = '' + self.images = [] + + def create_random_tensors(self, shape, seeds): + xs = [] + for seed in seeds: + torch.manual_seed(seed) + + # randn results depend on device; gpu and cpu get different results for same seed; + # the way I see it, it's better to do this on CPU, so that everyone gets same result; + # but the original script had it like this so i do not dare change it for now because + # it will break everyone's seeds. 
+ xs.append(torch.randn(shape, device=self.device)) + x = torch.stack(xs) + return x + + def process_prompt_tokens(self, prompt_tokens): + # compviz codebase + tokenizer = self.model.cond_stage_model.tokenizer + text_encoder = self.model.cond_stage_model.transformer + + # diffusers codebase + #tokenizer = pipe.tokenizer + #text_encoder = pipe.text_encoder + + ext = ('.pt', '.bin') + for token_name in prompt_tokens: + embedding_path = os.path.join(self.concepts_dir, token_name) + if os.path.exists(embedding_path): + for files in os.listdir(embedding_path): + if files.endswith(ext): + load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>") + else: + print(f"Concept {token_name} not found in {self.concepts_dir}") + del tokenizer, text_encoder + return + del tokenizer, text_encoder + + def generate(self, prompt: str, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None, + height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta:float = 0.0): + seed = seed_to_int(seed) + + image_dict = { + "seed": seed + } + negprompt = '' + if '###' in prompt: + prompt, negprompt = prompt.split('###', 1) + prompt = prompt.strip() + negprompt = negprompt.strip() + + if sampler_name == 'PLMS': + sampler = PLMSSampler(self.model) + elif sampler_name == 'DDIM': + sampler = DDIMSampler(self.model) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(self.model,'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(self.model,'dpm_2') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(self.model,'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(self.model,'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(self.model,'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(self.model,'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, x_T=x) + return samples_ddim + + torch_gc() + + if self.load_concepts and self.concepts_dir is not None: + prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt) + if prompt_tokens: + self.process_prompt_tokens(prompt_tokens) + + os.makedirs(self.output_dir, exist_ok=True) + + sample_path = os.path.join(self.output_dir, "samples") + os.makedirs(sample_path, exist_ok=True) + + if self.verify_input: + try: + check_prompt_length(self.model, prompt, self.comments) + except: + import traceback + print("Error verifying input:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + all_prompts = batch_size * n_iter * [prompt] + all_seeds = [seed + x for x in range(len(all_prompts))] + + precision_scope = torch.autocast if self.auto_cast else nullcontext + + with torch.no_grad(), precision_scope("cuda"): + for n in range(n_iter): + print(f"Iteration: {n+1}/{n_iter}") + prompts = all_prompts[n * batch_size:(n + 1) * batch_size] + seeds = all_seeds[n * batch_size:(n + 1) * batch_size] + + uc = self.model.get_learned_conditioning(len(prompts) * [negprompt]) + + if isinstance(prompts, tuple): + prompts = list(prompts) + + c = self.model.get_learned_conditioning(prompts) + + opt_C = 4 + opt_f = 8 + shape = [opt_C, height // opt_f, width // opt_f] + + x = 
self.create_random_tensors(shape, seeds=seeds) + + samples_ddim = sample(init_data=None, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) + + x_samples_ddim = self.model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for i, x_sample in enumerate(x_samples_ddim): + sanitized_prompt = slugify(prompts[i]) + full_path = os.path.join(os.getcwd(), sample_path) + sample_path_i = sample_path + base_count = get_next_sequence_number(sample_path_i) + filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)] + + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = x_sample.astype(np.uint8) + image = PIL.Image.fromarray(x_sample) + image_dict['image'] = image + self.images.append(image_dict) + + if save_individual_images: + path = os.path.join(sample_path, filename + '.' + self.save_extension) + success = save_sample(image, filename, sample_path_i, self.save_extension) + if success: + if self.output_file_path: + self.output_images.append(path) + else: + self.output_images.append(image) + else: + return + + self.info = f""" + {prompt} + Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed} + """.strip() + self.stats = f''' + ''' + + for comment in self.comments: + self.info += "\n\n" + comment + + torch_gc() + + del sampler + + return diff --git a/scripts/nataili/inference/diffusers/__init__.py b/scripts/nataili/inference/diffusers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/nataili/model_manager.py b/scripts/nataili/model_manager.py new file mode 100644 index 0000000..ee3bae1 --- /dev/null +++ b/scripts/nataili/model_manager.py @@ -0,0 +1,458 @@ +import os +import json +import shutil +import zipfile +import requests +import git +import torch +import hashlib +from ldm.util import instantiate_from_config +from omegaconf import OmegaConf +from transformers import logging + +from basicsr.archs.rrdbnet_arch import RRDBNet +from gfpgan import GFPGANer +from realesrgan import RealESRGANer +from ldm.models.blip import blip_decoder +from tqdm import tqdm +import open_clip +import clip + +from nataili.util.cache import torch_gc +from nataili.util import logger + +logging.set_verbosity_error() + +models = json.load(open('./db.json')) +dependencies = json.load(open('./db_dep.json')) +remote_models = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db.json" +remote_dependencies = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db_dep.json" + +class ModelManager(): + def __init__(self, hf_auth=None, download=True): + if download: + try: + logger.init("Model Reference", status="Downloading") + r = requests.get(remote_models) + self.models = r.json() + r = requests.get(remote_dependencies) + self.dependencies = json.load(open('./db_dep.json')) + logger.init_ok("Model Reference", status="OK") + except: + logger.init_err("Model Reference", status="Download Error") + self.models = json.load(open('./db.json')) + self.dependencies = json.load(open('./db_dep.json')) + logger.init_warn("Model Reference", status="Local") + self.available_models = [] + self.tainted_models = [] + self.available_dependencies = [] + self.loaded_models = {} + self.hf_auth = None + self.set_authentication(hf_auth) + + def init(self): + dependencies_available = [] + for dependency in self.dependencies: + if 
self.check_available(self.get_dependency_files(dependency)): + dependencies_available.append(dependency) + self.available_dependencies = dependencies_available + + models_available = [] + for model in self.models: + if self.check_available(self.get_model_files(model)): + models_available.append(model) + self.available_models = models_available + + if self.hf_auth is not None: + if 'username' not in self.hf_auth and 'password' not in self.hf_auth: + raise ValueError('hf_auth must contain username and password') + else: + if self.hf_auth['username'] == '' or self.hf_auth['password'] == '': + raise ValueError('hf_auth must contain username and password') + return True + + def set_authentication(self, hf_auth=None): + # We do not let No authentication override previously set auth + if not hf_auth and self.hf_auth: + return + self.hf_auth = hf_auth + + def get_model(self, model_name): + return self.models.get(model_name) + + def get_filtered_models(self, **kwargs): + '''Get all model names. + Can filter based on metadata of the model reference db + ''' + filtered_models = self.models + for keyword in kwargs: + iterating_models = filtered_models.copy() + filtered_models = {} + for model in iterating_models: + # logger.debug([keyword,iterating_models[model].get(keyword),kwargs[keyword]]) + if iterating_models[model].get(keyword) == kwargs[keyword]: + filtered_models[model] = iterating_models[model] + return filtered_models + + def get_filtered_model_names(self, **kwargs): + filtered_models = self.get_filtered_models(**kwargs) + return list(filtered_models.keys()) + + def get_dependency(self, dependency_name): + return self.dependencies[dependency_name] + + def get_model_files(self, model_name): + return self.models[model_name]['config']['files'] + + def get_dependency_files(self, dependency_name): + return self.dependencies[dependency_name]['config']['files'] + + def get_model_download(self, model_name): + return self.models[model_name]['config']['download'] + + def get_dependency_download(self, dependency_name): + return self.dependencies[dependency_name]['config']['download'] + + def get_available_models(self): + return self.available_models + + def get_available_dependencies(self): + return self.available_dependencies + + def get_loaded_models(self): + return self.loaded_models + + def get_loaded_models_names(self): + return list(self.loaded_models.keys()) + + def get_loaded_model(self, model_name): + return self.loaded_models[model_name] + + def unload_model(self, model_name): + if model_name in self.loaded_models: + del self.loaded_models[model_name] + return True + return False + + def unload_all_models(self): + for model in self.loaded_models: + del self.loaded_models[model] + return True + + def taint_model(self,model_name): + '''Marks a model as not valid by remiving it from available_models''' + if model_name in self.available_models: + self.available_models.remove(model_name) + self.tainted_models.append(model_name) + + def taint_models(self, models): + for model in models: + self.taint_model(model) + + def load_model_from_config(self, model_path='', config_path='', map_location="cpu"): + config = OmegaConf.load(config_path) + pl_sd = torch.load(model_path, map_location=map_location) + if "global_step" in pl_sd: + logger.info(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + model = model.eval() + del pl_sd, sd, m, u + return model + + def load_ckpt(self, model_name='', 
precision='half', gpu_id=0): + ckpt_path = self.get_model_files(model_name)[0]['path'] + config_path = self.get_model_files(model_name)[1]['path'] + model = self.load_model_from_config(model_path=ckpt_path, config_path=config_path) + device = torch.device(f"cuda:{gpu_id}") + model = (model if precision=='full' else model.half()).to(device) + torch_gc() + return {'model': model, 'device': device} + + def load_realesrgan(self, model_name='', precision='half', gpu_id=0): + + RealESRGAN_models = { + 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), + 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + } + + model_path = self.get_model_files(model_name)[0]['path'] + device = torch.device(f"cuda:{gpu_id}") + model = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[models[model_name]['name']], + pre_pad=0, half=True if precision == 'half' else False, device=device) + return {'model': model, 'device': device} + + def load_gfpgan(self, model_name='', gpu_id=0): + + model_path = self.get_model_files(model_name)[0]['path'] + device = torch.device(f"cuda:{gpu_id}") + model = GFPGANer(model_path=model_path, upscale=1, arch='clean', + channel_multiplier=2, bg_upsampler=None, device=device) + return {'model': model, 'device': device} + + def load_blip(self, model_name='', precision='half', gpu_id=0, blip_image_eval_size=512, vit='base'): + # vit = 'base' or 'large' + model_path = self.get_model_files(model_name)[0]['path'] + device = torch.device(f"cuda:{gpu_id}") + model = blip_decoder(pretrained=model_path, + med_config="configs/blip/med_config.json", + image_size=blip_image_eval_size, vit=vit) + model = model.eval() + model = (model if precision=='full' else model.half()).to(device) + return {'model': model, 'device': device} + + def load_open_clip(self, model_name='', precision='half', gpu_id=0): + pretrained = self.get_model(model_name)['pretrained_name'] + device = torch.device(f"cuda:{gpu_id}") + model, _, preprocesses = open_clip.create_model_and_transforms(model_name, pretrained=pretrained, cache_dir='models/clip') + model = model.eval() + model = (model if precision=='full' else model.half()).to(device) + return {'model': model, 'device': device, 'preprocesses': preprocesses} + + def load_clip(self, model_name='', precision='half', gpu_id=0): + device = torch.device(f"cuda:{gpu_id}") + model, preprocesses = clip.load(model_name, device=device, download_root='models/clip') + model = model.eval() + model = (model if precision=='full' else model.half()).to(device) + return {'model': model, 'device': device, 'preprocesses': preprocesses} + + def load_model(self, model_name='', precision='half', gpu_id=0): + if model_name not in self.available_models: + return False + if self.models[model_name]['type'] == 'ckpt': + self.loaded_models[model_name] = self.load_ckpt(model_name, precision, gpu_id) + return True + elif self.models[model_name]['type'] == 'realesrgan': + self.loaded_models[model_name] = self.load_realesrgan(model_name, precision, gpu_id) + return True + elif self.models[model_name]['type'] == 'gfpgan': + self.loaded_models[model_name] = self.load_gfpgan(model_name, gpu_id) + return True + elif self.models[model_name]['type'] == 'blip': + self.loaded_models[model_name] = self.load_blip(model_name, precision, gpu_id, 512, 'base') + return True + elif self.models[model_name]['type'] == 'open_clip': + self.loaded_models[model_name] = 
self.load_open_clip(model_name, precision, gpu_id) + return True + elif self.models[model_name]['type'] == 'clip': + self.loaded_models[model_name] = self.load_clip(model_name, precision, gpu_id) + return True + else: + return False + + def validate_model(self, model_name): + files = self.get_model_files(model_name) + all_ok = True + for file_details in files: + if not self.check_file_available(file_details['path']): + return False + if not self.validate_file(file_details): + return False + return True + + def validate_file(self, file_details): + if 'md5sum' in file_details: + file_name = file_details['path'] + logger.debug(f"Getting md5sum of {file_name}") + with open(file_name, 'rb') as file_to_check: + file_hash = hashlib.md5() + while chunk := file_to_check.read(8192): + file_hash.update(chunk) + if file_details['md5sum'] != file_hash.hexdigest(): + return False + return True + + def check_file_available(self, file_path): + return os.path.exists(file_path) + + def check_available(self, files): + available = True + for file in files: + if not self.check_file_available(file['path']): + available = False + return available + + def download_file(self, url, file_path): + # make directory + os.makedirs(os.path.dirname(file_path), exist_ok=True) + pbar_desc = file_path.split('/')[-1] + r = requests.get(url, stream=True) + with open(file_path, 'wb') as f: + with tqdm( + # all optional kwargs + unit='B', unit_scale=True, unit_divisor=1024, miniters=1, + desc=pbar_desc, total=int(r.headers.get('content-length', 0)) + ) as pbar: + for chunk in r.iter_content(chunk_size=16*1024): + if chunk: + f.write(chunk) + pbar.update(len(chunk)) + + def download_model(self, model_name): + if model_name in self.available_models: + logger.info(f"{model_name} is already available.") + return True + download = self.get_model_download(model_name) + files = self.get_model_files(model_name) + for i in range(len(download)): + file_path = f"{download[i]['file_path']}/{download[i]['file_name']}" if 'file_path' in download[i] else files[i]['path'] + + if 'file_url' in download[i]: + download_url = download[i]['file_url'] + if 'hf_auth' in download[i]: + username = self.hf_auth['username'] + password = self.hf_auth['password'] + download_url = download_url.format(username=username, password=password) + if 'file_name' in download[i]: + download_name = download[i]['file_name'] + if 'file_path' in download[i]: + download_path = download[i]['file_path'] + + if 'manual' in download[i]: + logger.warning(f"The model {model_name} requires manual download from {download_url}. 
Please place it in {download_path}/{download_name} then press ENTER to continue...") + input('') + continue + # TODO: simplify + if "file_content" in download[i]: + file_content = download[i]['file_content'] + logger.info(f"writing {file_content} to {file_path}") + # make directory download_path + os.makedirs(download_path, exist_ok=True) + # write file_content to download_path/download_name + with open(os.path.join(download_path, download_name), 'w') as f: + f.write(file_content) + elif 'symlink' in download[i]: + logger.info(f"symlink {file_path} to {download[i]['symlink']}") + symlink = download[i]['symlink'] + # make directory symlink + os.makedirs(download_path, exist_ok=True) + # make symlink from download_path/download_name to symlink + os.symlink(symlink, os.path.join(download_path, download_name)) + elif 'git' in download[i]: + logger.info(f"git clone {download_url} to {file_path}") + # make directory download_path + os.makedirs(file_path, exist_ok=True) + git.Git(file_path).clone(download_url) + if 'post_process' in download[i]: + for post_process in download[i]['post_process']: + if 'delete' in post_process: + # delete folder post_process['delete'] + logger.info(f"delete {post_process['delete']}") + try: + shutil.rmtree(post_process['delete']) + except PermissionError as e: + logger.error(f"[!] Something went wrong while deleting the `{post_process['delete']}`. Please delete it manually.") + logger.error("PermissionError: ", e) + else: + if not self.check_file_available(file_path) or model_name in self.tainted_models: + logger.debug(f'Downloading {download_url} to {file_path}') + self.download_file(download_url, file_path) + if not self.validate_model(model_name): + return False + if model_name in self.tainted_models: + self.tainted_models.remove(model_name) + self.init() + return True + + def download_dependency(self, dependency_name): + if dependency_name in self.available_dependencies: + logger.info(f"{dependency_name} is already installed.") + return True + download = self.get_dependency_download(dependency_name) + files = self.get_dependency_files(dependency_name) + for i in range(len(download)): + if "git" in download[i]: + logger.warning("git download not implemented yet") + break + + file_path = files[i]['path'] + if 'file_url' in download[i]: + download_url = download[i]['file_url'] + if 'file_name' in download[i]: + download_name = download[i]['file_name'] + if 'file_path' in download[i]: + download_path = download[i]['file_path'] + logger.debug(download_name) + if "unzip" in download[i]: + zip_path = f'temp/{download_name}.zip' + # os dirname zip_path + # mkdir temp + os.makedirs("temp", exist_ok=True) + + self.download_file(download_url, zip_path) + logger.info(f"unzip {zip_path}") + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall('temp/') + # move temp/sd-concepts-library-main/sd-concepts-library to download_path + logger.info(f"move temp/{download_name}-main/{download_name} to {download_path}") + shutil.move(f"temp/{download_name}-main/{download_name}", download_path) + logger.info(f"delete {zip_path}") + os.remove(zip_path) + logger.info(f"delete temp/{download_name}-main/") + shutil.rmtree(f"temp/{download_name}-main") + else: + if not self.check_file_available(file_path): + logger.init(f'{file_path}', status="Downloading") + self.download_file(download_url, file_path) + self.init() + return True + + def download_all_models(self): + for model in self.get_filtered_model_names(download_all = True): + if not self.check_model_available(model): + 
logger.init(f"{model}", status="Downloading") + self.download_model(model) + else: + logger.info(f"{model} is already downloaded.") + return True + + def download_all_dependencies(self): + for dependency in self.dependencies: + if not self.check_dependency_available(dependency): + logger.init(f"{dependency}",status="Downloading") + self.download_dependency(dependency) + else: + logger.info(f"{dependency} is already installed.") + return True + + def download_all(self): + self.download_all_dependencies() + self.download_all_models() + return True + + def check_all_available(self): + for model in self.models: + if not self.check_available(self.get_model_files(model)): + return False + for dependency in self.dependencies: + if not self.check_available(self.get_dependency_files(dependency)): + return False + return True + + def check_model_available(self, model_name): + if model_name not in self.models: + return False + return self.check_available(self.get_model_files(model_name)) + + def check_dependency_available(self, dependency_name): + if dependency_name not in self.dependencies: + return False + return self.check_available(self.get_dependency_files(dependency_name)) + + def check_all_available(self): + for model in self.models: + if not self.check_model_available(model): + return False + for dependency in self.dependencies: + if not self.check_dependency_available(dependency): + return False + return True + + + + + + + + diff --git a/scripts/nataili/postprocess/__init__.py b/scripts/nataili/postprocess/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/nataili/postprocess/upscaler.py b/scripts/nataili/postprocess/upscaler.py new file mode 100644 index 0000000..529b438 --- /dev/null +++ b/scripts/nataili/postprocess/upscaler.py @@ -0,0 +1,48 @@ +# Class realesrgan +# Inputs: +# - model +# - device +# - output_dir +# - output_ext +# outupts: +# - output_images +import PIL +from torchvision import transforms +import numpy as np +import os +import cv2 + +from nataili.util.save_sample import save_sample + +class realesrgan: + def __init__(self, model, device, output_dir, output_ext='jpg'): + self.model = model + self.device = device + self.output_dir = output_dir + self.output_ext = output_ext + self.output_images = [] + + def generate(self, input_image): + # load image + img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED) + if len(img.shape) == 3 and img.shape[2] == 4: + img_mode = 'RGBA' + else: + img_mode = None + # upscale + output, _ = self.model.enhance(img) + if img_mode == 'RGBA': # RGBA images should be saved in png format + self.output_ext = 'png' + + esrgan_sample = output[:,:,::-1] + esrgan_image = PIL.Image.fromarray(esrgan_sample) + # append model name to output image name + filename = os.path.basename(input_image) + filename = os.path.splitext(filename)[0] + filename = f'{filename}_esrgan' + filename_with_ext = f'{filename}.{self.output_ext}' + output_image = os.path.join(self.output_dir, filename_with_ext) + save_sample(esrgan_image, filename, self.output_dir, self.output_ext) + self.output_images.append(output_image) + return + diff --git a/scripts/nataili/upscalers/__init__.py b/scripts/nataili/upscalers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/nataili/upscalers/realesrgan.py b/scripts/nataili/upscalers/realesrgan.py new file mode 100644 index 0000000..529b438 --- /dev/null +++ b/scripts/nataili/upscalers/realesrgan.py @@ -0,0 +1,48 @@ +# Class realesrgan +# Inputs: +# - model +# - device +# - output_dir +# - 
output_ext +# outupts: +# - output_images +import PIL +from torchvision import transforms +import numpy as np +import os +import cv2 + +from nataili.util.save_sample import save_sample + +class realesrgan: + def __init__(self, model, device, output_dir, output_ext='jpg'): + self.model = model + self.device = device + self.output_dir = output_dir + self.output_ext = output_ext + self.output_images = [] + + def generate(self, input_image): + # load image + img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED) + if len(img.shape) == 3 and img.shape[2] == 4: + img_mode = 'RGBA' + else: + img_mode = None + # upscale + output, _ = self.model.enhance(img) + if img_mode == 'RGBA': # RGBA images should be saved in png format + self.output_ext = 'png' + + esrgan_sample = output[:,:,::-1] + esrgan_image = PIL.Image.fromarray(esrgan_sample) + # append model name to output image name + filename = os.path.basename(input_image) + filename = os.path.splitext(filename)[0] + filename = f'{filename}_esrgan' + filename_with_ext = f'{filename}.{self.output_ext}' + output_image = os.path.join(self.output_dir, filename_with_ext) + save_sample(esrgan_image, filename, self.output_dir, self.output_ext) + self.output_images.append(output_image) + return + diff --git a/scripts/nataili/util/__init__.py b/scripts/nataili/util/__init__.py new file mode 100644 index 0000000..766cb27 --- /dev/null +++ b/scripts/nataili/util/__init__.py @@ -0,0 +1 @@ +from nataili.util.logger import logger,set_logger_verbosity, quiesce_logger, test_logger diff --git a/scripts/nataili/util/cache.py b/scripts/nataili/util/cache.py new file mode 100644 index 0000000..ef7711c --- /dev/null +++ b/scripts/nataili/util/cache.py @@ -0,0 +1,16 @@ +import gc + +import torch +import threading +import pynvml +import time + +with torch.no_grad(): + def torch_gc(): + for _ in range(2): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() diff --git a/scripts/nataili/util/check_prompt_length.py b/scripts/nataili/util/check_prompt_length.py new file mode 100644 index 0000000..4953eac --- /dev/null +++ b/scripts/nataili/util/check_prompt_length.py @@ -0,0 +1,18 @@ +def check_prompt_length(model, prompt, comments): + """this function tests if prompt is too long, and if so, adds a message to comments""" + + tokenizer = model.cond_stage_model.tokenizer + max_length = model.cond_stage_model.max_length + + info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, + return_overflowing_tokens=True, padding="max_length", return_tensors="pt") + ovf = info['overflowing_tokens'][0] + overflowing_count = ovf.shape[0] + if overflowing_count == 0: + return + + vocab = {v: k for k, v in tokenizer.get_vocab().items()} + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + del tokenizer \ No newline at end of file diff --git a/scripts/nataili/util/get_next_sequence_number.py b/scripts/nataili/util/get_next_sequence_number.py new file mode 100644 index 0000000..ede028f --- /dev/null +++ b/scripts/nataili/util/get_next_sequence_number.py @@ -0,0 +1,22 @@ +from pathlib import Path + +def get_next_sequence_number(path, prefix=''): + """ + Determines and returns the next sequence number to use when saving an 
+ image in the specified directory. + + If a prefix is given, only consider files whose names start with that + prefix, and strip the prefix from filenames before extracting their + sequence number. + + The sequence starts at 0. + """ + result = -1 + for p in Path(path).iterdir(): + if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix): + tmp = p.name[len(prefix):] + try: + result = max(int(tmp.split('-')[0]), result) + except ValueError: + pass + return result + 1 \ No newline at end of file diff --git a/scripts/nataili/util/image_grid.py b/scripts/nataili/util/image_grid.py new file mode 100644 index 0000000..7ea85eb --- /dev/null +++ b/scripts/nataili/util/image_grid.py @@ -0,0 +1,21 @@ +import math + +import PIL + + +def image_grid(imgs, n_rows=None): + if n_rows is not None: + rows = n_rows + else: + rows = math.sqrt(len(imgs)) + rows = round(rows) + + cols = math.ceil(len(imgs) / rows) + + w, h = imgs[0].size + grid = PIL.Image.new('RGB', size=(cols * w, rows * h), color='black') + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + + return grid diff --git a/scripts/nataili/util/load_learned_embed_in_clip.py b/scripts/nataili/util/load_learned_embed_in_clip.py new file mode 100644 index 0000000..9507e58 --- /dev/null +++ b/scripts/nataili/util/load_learned_embed_in_clip.py @@ -0,0 +1,40 @@ +import os + +import torch + + +def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None): + loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") + # separate token and the embeds + if learned_embeds_path.endswith('.pt'): + # old format + # token = * so replace with file directory name when converting + trained_token = os.path.basename(learned_embeds_path) + params_dict = { + trained_token: torch.tensor(list(loaded_learned_embeds['string_to_param'].items())[0][1]) + } + learned_embeds_path = os.path.splitext(learned_embeds_path)[0] + '.bin' + torch.save(params_dict, learned_embeds_path) + loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") + trained_token = list(loaded_learned_embeds.keys())[0] + embeds = loaded_learned_embeds[trained_token] + elif learned_embeds_path.endswith('.bin'): + trained_token = list(loaded_learned_embeds.keys())[0] + embeds = loaded_learned_embeds[trained_token] + + embeds = loaded_learned_embeds[trained_token] + # cast to dtype of text_encoder + dtype = text_encoder.get_input_embeddings().weight.dtype + embeds.to(dtype) + + # add the token in tokenizer + token = token if token is not None else trained_token + num_added_tokens = tokenizer.add_tokens(token) + + # resize the token embeddings + text_encoder.resize_token_embeddings(len(tokenizer)) + + # get the id for the token and assign the embeds + token_id = tokenizer.convert_tokens_to_ids(token) + text_encoder.get_input_embeddings().weight.data[token_id] = embeds + return token diff --git a/scripts/nataili/util/logger.py b/scripts/nataili/util/logger.py new file mode 100644 index 0000000..d012c83 --- /dev/null +++ b/scripts/nataili/util/logger.py @@ -0,0 +1,102 @@ +import sys +from functools import partialmethod +from loguru import logger + +STDOUT_LEVELS = ["GENERATION", "PROMPT"] +INIT_LEVELS = ["INIT", "INIT_OK", "INIT_WARN", "INIT_ERR"] +MESSAGE_LEVELS = ["MESSAGE"] +# By default we're at error level or higher +verbosity = 20 +quiet = 0 + +def set_logger_verbosity(count): + global verbosity + # The count comes reversed. 
diff --git a/scripts/nataili/util/logger.py b/scripts/nataili/util/logger.py
new file mode 100644
index 0000000..d012c83
--- /dev/null
+++ b/scripts/nataili/util/logger.py
@@ -0,0 +1,102 @@
+import sys
+from functools import partialmethod
+from loguru import logger
+
+STDOUT_LEVELS = ["GENERATION", "PROMPT"]
+INIT_LEVELS = ["INIT", "INIT_OK", "INIT_WARN", "INIT_ERR"]
+MESSAGE_LEVELS = ["MESSAGE"]
+# By default we're at error level or higher
+verbosity = 20
+quiet = 0
+
+def set_logger_verbosity(count):
+    global verbosity
+    # The count comes reversed. So count = 0 means minimum verbosity
+    # While count 5 means maximum verbosity
+    # So the higher the count, the lower we drop the verbosity threshold
+    verbosity = 20 - (count * 10)
+
+def quiesce_logger(count):
+    global quiet
+    # The bigger the count, the more silent we want our logger
+    quiet = count * 10
+
+def is_stdout_log(record):
+    if record["level"].name not in STDOUT_LEVELS:
+        return(False)
+    if record["level"].no < verbosity + quiet:
+        return(False)
+    return(True)
+
+def is_init_log(record):
+    if record["level"].name not in INIT_LEVELS:
+        return(False)
+    if record["level"].no < verbosity + quiet:
+        return(False)
+    return(True)
+
+def is_msg_log(record):
+    if record["level"].name not in MESSAGE_LEVELS:
+        return(False)
+    if record["level"].no < verbosity + quiet:
+        return(False)
+    return(True)
+
+def is_stderr_log(record):
+    if record["level"].name in STDOUT_LEVELS + INIT_LEVELS + MESSAGE_LEVELS:
+        return(False)
+    if record["level"].no < verbosity + quiet:
+        return(False)
+    return(True)
+
+def test_logger():
+    logger.generation("This is a generation message\nIt is typically multiline\nThree Lines".encode("unicode_escape").decode("utf-8"))
+    logger.prompt("This is a prompt message")
+    logger.debug("Debug Message")
+    logger.info("Info Message")
+    logger.warning("Warning Message")
+    logger.error("Error Message")
+    logger.critical("Critical Message")
+    logger.init("This is an init message", status="Starting")
+    logger.init_ok("This is an init message", status="OK")
+    logger.init_warn("This is an init message", status="Warning")
+    logger.init_err("This is an init message", status="Error")
+    logger.message("This is user message")
+    sys.exit()
+
+
+logfmt = "{level: <10} | {time:YYYY-MM-DD HH:mm:ss} | {name}:{function}:{line} - {message}"
+genfmt = "{level: <10} @ {time:YYYY-MM-DD HH:mm:ss} | {message}"
+initfmt = "INIT | {extra[status]: <11} | {message}"
+msgfmt = "{level: <10} | {message}"
+
+try:
+    logger.level("GENERATION", no=24, color="")
+    logger.level("PROMPT", no=23, color="")
+    logger.level("INIT", no=31, color="")
+    logger.level("INIT_OK", no=31, color="")
+    logger.level("INIT_WARN", no=31, color="")
+    logger.level("INIT_ERR", no=31, color="")
+    # Messages contain important information without which this application might not be able to be used
+    # As such, they have the highest priority
+    logger.level("MESSAGE", no=61, color="")
+except TypeError:
+    pass
+
+logger.__class__.generation = partialmethod(logger.__class__.log, "GENERATION")
+logger.__class__.prompt = partialmethod(logger.__class__.log, "PROMPT")
+logger.__class__.init = partialmethod(logger.__class__.log, "INIT")
+logger.__class__.init_ok = partialmethod(logger.__class__.log, "INIT_OK")
+logger.__class__.init_warn = partialmethod(logger.__class__.log, "INIT_WARN")
+logger.__class__.init_err = partialmethod(logger.__class__.log, "INIT_ERR")
+logger.__class__.message = partialmethod(logger.__class__.log, "MESSAGE")
+
+config = {
+    "handlers": [
+        {"sink": sys.stderr, "format": logfmt, "colorize":True, "filter": is_stderr_log},
+        {"sink": sys.stdout, "format": genfmt, "level": "PROMPT", "colorize":True, "filter": is_stdout_log},
+        {"sink": sys.stdout, "format": initfmt, "level": "INIT", "colorize":True, "filter": is_init_log},
+        {"sink": sys.stdout, "format": msgfmt, "level": "MESSAGE", "colorize":True, "filter": is_msg_log}
+    ],
+}
+logger.configure(**config)
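Reviewer note (not part of the patch): a rough sketch exercising the custom loguru levels and verbosity controls defined in logger.py above, assuming loguru is installed and the scripts directory is on PYTHONPATH.

from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger

set_logger_verbosity(3)   # verbosity = 20 - 3*10 = -10, so every record passes the filters
quiesce_logger(0)         # quiet = 0, no additional silencing

logger.init("Loading model", status="Starting")          # INIT level (no=31), routed to stdout
logger.init_ok("Loading model", status="OK")
logger.prompt("a photo of an astronaut riding a horse")  # PROMPT level (no=23)
logger.generation("seed: 42, steps: 30")                 # GENERATION level (no=24)
logger.message("MESSAGE has the highest level (61), so it is always shown")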
diff --git a/scripts/nataili/util/save_sample.py b/scripts/nataili/util/save_sample.py
new file mode 100644
index 0000000..5c791d3
--- /dev/null
+++ b/scripts/nataili/util/save_sample.py
@@ -0,0 +1,20 @@
+import os
+
+def save_sample(image, filename, sample_path, extension='png', jpg_quality=95, webp_quality=95, webp_lossless=True, png_compression=9):
+    path = os.path.join(sample_path, filename + '.' + extension)
+    if os.path.exists(path):
+        return False
+    if not os.path.exists(sample_path):
+        os.makedirs(sample_path)
+    if extension == 'png':
+        image.save(path, format='PNG', compress_level=png_compression)
+    elif extension == 'jpg':
+        image.save(path, quality=jpg_quality, optimize=True)
+    elif extension == 'webp':
+        image.save(path, quality=webp_quality, lossless=webp_lossless)
+    else:
+        return False
+    if os.path.exists(path):
+        return True
+    else:
+        return False
diff --git a/scripts/nataili/util/seed_to_int.py b/scripts/nataili/util/seed_to_int.py
new file mode 100644
index 0000000..61cc9fd
--- /dev/null
+++ b/scripts/nataili/util/seed_to_int.py
@@ -0,0 +1,22 @@
+import random
+
+def seed_to_int(s):
+    if type(s) is int:
+        return s
+    if s is None or s == '':
+        return random.randint(0, 2**32 - 1)
+
+    if type(s) is list:
+        seed_list = []
+        for seed in s:
+            if seed is None or seed == '':
+                seed_list.append(random.randint(0, 2**32 - 1))
+            else:
+                seed_list = s
+
+        return seed_list
+
+    n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1))
+    while n >= 2**32:
+        n = n >> 32
+    return n
\ No newline at end of file
diff --git a/streamlit_webview.py b/streamlit_webview.py
new file mode 100644
index 0000000..9853c4e
--- /dev/null
+++ b/streamlit_webview.py
@@ -0,0 +1,15 @@
+import os, webview
+from streamlit.web import bootstrap
+from streamlit import config as _config
+
+webview.create_window('Sygil', 'http://localhost:8501', width=1000, height=800, min_size=(500, 500))
+webview.start()
+
+dirname = os.path.dirname(__file__)
+filename = os.path.join(dirname, 'scripts/webui_streamlit.py')
+
+_config.set_option("server.headless", True)
+args = []
+
+#streamlit.cli.main_run(filename, args)
+bootstrap.run(filename,'',args, flag_options={})
\ No newline at end of file
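Reviewer note (not part of the patch): a rough sketch of how seed_to_int and save_sample might be used together. The filenames and output directory are made up for illustration; Pillow is assumed to be installed and the scripts directory on PYTHONPATH.

import PIL.Image

from nataili.util.seed_to_int import seed_to_int
from nataili.util.save_sample import save_sample

seed = seed_to_int("my favourite seed")   # the same string always maps to the same integer
print(seed)

image = PIL.Image.new("RGB", (64, 64), color="black")
# save_sample creates the directory if needed and returns False instead of overwriting
saved = save_sample(image, f"demo_{seed}", "outputs/demo", extension="png")
print("saved:", saved)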
diff --git a/webui.cmd b/webui.cmd
index e1248a9..f54fe8b 100644
--- a/webui.cmd
+++ b/webui.cmd
@@ -1,17 +1,17 @@
 @echo off
 :: This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
-
+::
 :: Copyright 2022 Sygil-Dev team.
 :: This program is free software: you can redistribute it and/or modify
 :: it under the terms of the GNU Affero General Public License as published by
 :: the Free Software Foundation, either version 3 of the License, or
 :: (at your option) any later version.
-
+::
 :: This program is distributed in the hope that it will be useful,
 :: but WITHOUT ANY WARRANTY; without even the implied warranty of
 :: MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 :: GNU Affero General Public License for more details.
-
+::
 :: You should have received a copy of the GNU Affero General Public License
 :: along with this program. If not, see .
 :: Run all commands using this script's directory as the working directory
@@ -102,12 +102,11 @@ call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%"
 
 :PROMPT
 set SETUPTOOLS_USE_DISTUTILS=stdlib
-IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" (
-  set "PYTHONPATH=%~dp0"
-  python scripts\relauncher.py %*
+IF EXIST "models\ldm\stable-diffusion-v1\Stable Diffusion v1.5.ckpt" (
+  python -m streamlit run scripts\webui_streamlit.py --theme.base dark --server.address localhost
 ) ELSE (
-  echo Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'.
-  pause
+  echo Your model file does not exist! Once the WebUI launches please visit the Model Manager page and download the models by using the Download button for each model.
+  python -m streamlit run scripts\webui_streamlit.py --theme.base dark --server.address localhost
 )
 
 ::cmd /k
diff --git a/webui-streamlit.cmd b/webui_legacy.cmd
similarity index 97%
rename from webui-streamlit.cmd
rename to webui_legacy.cmd
index 45414be..654e21c 100644
--- a/webui-streamlit.cmd
+++ b/webui_legacy.cmd
@@ -1,17 +1,17 @@
 @echo off
 :: This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
-::
+
 :: Copyright 2022 Sygil-Dev team.
 :: This program is free software: you can redistribute it and/or modify
 :: it under the terms of the GNU Affero General Public License as published by
 :: the Free Software Foundation, either version 3 of the License, or
 :: (at your option) any later version.
-::
+
 :: This program is distributed in the hope that it will be useful,
 :: but WITHOUT ANY WARRANTY; without even the implied warranty of
 :: MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 :: GNU Affero General Public License for more details.
-::
+
 :: You should have received a copy of the GNU Affero General Public License
 :: along with this program. If not, see .
 :: Run all commands using this script's directory as the working directory
@@ -99,7 +99,8 @@ call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%"
 :PROMPT
 set SETUPTOOLS_USE_DISTUTILS=stdlib
 IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" (
-  python -m streamlit run scripts\webui_streamlit.py --theme.base dark --server.address localhost
+  set "PYTHONPATH=%~dp0"
+  python scripts\relauncher.py %*
 ) ELSE (
   echo Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'.
   pause