From 9b73f463cfa2068da6ddb119f650c406345c337a Mon Sep 17 00:00:00 2001
From: ZeroCool940711
Date: Sat, 24 Sep 2022 05:11:33 -0700
Subject: [PATCH] - Added function to convert bytes to a human readable
 string.
 - Added basic implementation of "enable_minimal_memory_usage".

---
 scripts/sd_utils.py | 38 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 36 insertions(+), 2 deletions(-)

diff --git a/scripts/sd_utils.py b/scripts/sd_utils.py
index f712753..acafa58 100644
--- a/scripts/sd_utils.py
+++ b/scripts/sd_utils.py
@@ -144,8 +144,14 @@ def set_page_title(title):
             title.text = '{title}'
             " />
     """)
-
-
+
+def human_readable_size(size, decimal_places=3):
+    """Return a human readable size from bytes."""
+    for unit in ['B','KB','MB','GB','TB']:
+        if size < 1024.0:
+            break
+        size /= 1024.0
+    return f"{size:.{decimal_places}f}{unit}"
 
 @retry(tries=5)
 def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus",
@@ -1130,6 +1136,15 @@ def draw_prompt_matrix(im, width, height, all_prompts):
 
     return result
 
+# Reduce VRAM usage by keeping only the unet on the GPU.
+def enable_minimal_memory_usage(model):
+    """Move only the unet to fp16 and CUDA, while keeping the lighter models on the CPU."""
+    model.unet.to(torch.float16).to(torch.device("cuda"))
+    model.enable_attention_slicing(1)
+
+    torch.cuda.empty_cache()
+    torch_gc()
+
 def check_prompt_length(prompt, comments):
     """this function tests if prompt is too long, and if so, adds a message to comments"""
 
@@ -1149,6 +1164,25 @@ def check_prompt_length(prompt, comments):
 
         comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
 
+# Scan models/custom for user-provided checkpoint files.
+def custom_models_available():
+    # Allow custom models to be used instead of the default one;
+    # an example would be Waifu-Diffusion or any other fine-tune
+    # of Stable Diffusion.
+    st.session_state["custom_models"] = []
+
+    for root, dirs, files in os.walk(os.path.join("models", "custom")):
+        for file in files:
+            if os.path.splitext(file)[1] == '.ckpt':
+                st.session_state["custom_models"].append(os.path.splitext(file)[0])
+
+
+    if len(st.session_state["custom_models"]) > 0:
+        st.session_state["CustomModel_available"] = True
+        st.session_state["custom_models"].append("Stable Diffusion v1.4")
+    else:
+        st.session_state["CustomModel_available"] = False
+
 def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
                 normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_grid, sort_samples,
                 sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, model_name):
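
For reference, the new human_readable_size helper can be exercised on its own. The snippet below is a standalone copy of the function as added by the patch, and the expected outputs in the comments follow directly from its divide-by-1024 loop:

    def human_readable_size(size, decimal_places=3):
        """Return a human readable size from bytes."""
        for unit in ['B','KB','MB','GB','TB']:
            if size < 1024.0:
                break
            size /= 1024.0
        return f"{size:.{decimal_places}f}{unit}"

    print(human_readable_size(1536))         # 1.500KB
    print(human_readable_size(3 * 1024**3))  # 3.000GB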
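
A usage sketch for enable_minimal_memory_usage. The patch does not show what `model` is, so this assumes a diffusers StableDiffusionPipeline, which is the type that exposes both `.unet` and `.enable_attention_slicing()`; the checkpoint ID is only illustrative:

    import torch
    from diffusers import StableDiffusionPipeline

    # Load the pipeline on the CPU first (assumption: a diffusers
    # pipeline; the patch itself leaves the type of `model` open).
    model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

    # Per the function's docstring, only the unet ends up as fp16 on
    # CUDA; the VAE and text encoder stay on the CPU.
    enable_minimal_memory_usage(model)

Note that after this call the unet runs in half precision on the GPU while the other components remain fp32 on the CPU, so any caller has to move tensors between devices (and dtypes) accordingly.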
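
Similarly, a sketch of how custom_models_available might be consumed by the UI; the selectbox wiring below is an assumption for illustration and not part of this patch, which only populates session state:

    import streamlit as st

    custom_models_available()  # fills st.session_state["custom_models"]

    if st.session_state["CustomModel_available"]:
        # Hypothetical widget; any UI that reads the session state works.
        model_name = st.selectbox("Model", st.session_state["custom_models"])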