diff --git a/webui.py b/webui.py index 1e53eb2..6a3f085 100644 --- a/webui.py +++ b/webui.py @@ -11,6 +11,7 @@ parser.add_argument("--n_rows", type=int, default=-1, help="rows in the grid; us parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",) parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--optimized", action='store_true', help="load the model onto the device piecemeal instead of all at once to reduce VRAM usage at the cost of performance") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go parser.add_argument("--realesrgan-dir", type=str, help="RealESRGAN directory", default=('./src/realesrgan' if os.path.exists('./src/realesrgan') else './RealESRGAN')) parser.add_argument("--realesrgan-model", type=str, help="Upscaling model for RealESRGAN", default=('RealESRGAN_x4plus')) @@ -110,6 +111,14 @@ def load_model_from_config(config, ckpt, verbose=False): model.eval() return model +def load_sd_from_config(ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + return sd + def crash(e, s): global model global device @@ -244,7 +253,7 @@ def load_RealESRGAN(model_name: str): instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not opt.no_half) instance.model.name = model_name instance.device = torch.device(f'cuda:{opt.gpu}') # another way to set gpu device - + return instance GFPGAN = None @@ -270,11 +279,49 @@ def try_loading_RealESRGAN(model_name: str): print(traceback.format_exc(), file=sys.stderr) try_loading_RealESRGAN('RealESRGAN_x4plus') -config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml") -model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt") +if opt.optimized: + config = OmegaConf.load("optimizedSD/v1-inference.yaml") + sd = load_sd_from_config("models/ldm/stable-diffusion-v1/model.ckpt") + li, lo = [], [] + for key, v_ in sd.items(): + sp = key.split('.') + if(sp[0]) == 'model': + if('input_blocks' in sp): + li.append(key) + elif('middle_block' in sp): + li.append(key) + elif('time_embed' in sp): + li.append(key) + else: + lo.append(key) + for key in li: + sd['model1.' + key[6:]] = sd.pop(key) + for key in lo: + sd['model2.' + key[6:]] = sd.pop(key) -device = torch.device(f"cuda") if torch.cuda.is_available() else torch.device("cpu") -model = (model if opt.no_half else model.half()).to(device) + config.modelUNet.params.small_batch = False + + model = instantiate_from_config(config.modelUNet) + _, _ = model.load_state_dict(sd, strict=False) + model.cuda() + model.eval() + + modelCS = instantiate_from_config(config.modelCondStage) + _, _ = modelCS.load_state_dict(sd, strict=False) + modelCS.eval() + + modelFS = instantiate_from_config(config.modelFirstStage) + _, _ = modelFS.load_state_dict(sd, strict=False) + modelFS.eval() + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model if opt.no_half else model.half() + modelCS = modelCS if opt.no_half else modelCS.half() +else: + config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml") + model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = (model if opt.no_half else model.half()).to(device) def load_embeddings(fp): if fp is not None and hasattr(model, "embedding_manager"): @@ -305,9 +352,9 @@ def seed_to_int(s): if type(s) is int: return s if s is None or s == '': - return random.randint(0, 2**32 - 1) - n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1)) - while n >= 2**32: + return random.randint(0,2**32) + n = abs(int(s) if s.isdigit() else hash(s)) + while n > 2**32: n = n >> 32 return n @@ -425,10 +472,10 @@ def resize_image(resize_mode, im, width, height): def check_prompt_length(prompt, comments): """this function tests if prompt is too long, and if so, adds a message to comments""" - tokenizer = model.cond_stage_model.tokenizer - max_length = model.cond_stage_model.max_length + tokenizer = (model if not opt.optimized else modelCS).cond_stage_model.tokenizer + max_length = (model if not opt.optimized else modelCS).cond_stage_model.max_length - info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt") + info = (model if not opt.optimized else modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt") ovf = info['overflowing_tokens'][0] overflowing_count = ovf.shape[0] if overflowing_count == 0: @@ -445,7 +492,7 @@ def process_images( outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, - keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, + keep_mask=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False, uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False): """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" assert prompt is not None @@ -463,7 +510,7 @@ def process_images( sample_path = os.path.join(outpath, "samples") os.makedirs(sample_path, exist_ok=True) - grid_count = len([_ for x in os.listdir() if x.endswith(('.png', '.jpg'))]) - 1 # start at 0 + grid_count = len(os.listdir(outpath)) - 1 comments = [] @@ -501,7 +548,7 @@ def process_images( precision_scope = autocast if opt.precision == "autocast" else nullcontext output_images = [] stats = [] - with torch.no_grad(), precision_scope("cuda"), model.ema_scope(): + with torch.no_grad(), precision_scope("cuda"), (model.ema_scope() if not opt.optimized else nullcontext()): init_data = func_init() tic = time.time() @@ -509,7 +556,9 @@ def process_images( prompts = all_prompts[n * batch_size:(n + 1) * batch_size] seeds = all_seeds[n * batch_size:(n + 1) * batch_size] - uc = model.get_learned_conditioning(len(prompts) * [""]) + if opt.optimized: + modelCS.to(device) + uc = (model if not opt.optimized else modelCS).get_learned_conditioning(len(prompts) * [""]) if isinstance(prompts, tuple): prompts = list(prompts) @@ -528,18 +577,26 @@ def process_images( weight = weight / totalPromptWeight #print(f"{subprompts[i]} {weight*100.0}%") # note if alpha negative, it functions same as torch.sub - c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight) + c = torch.add(c, (model if not opt.optimized else modelCS).get_learned_conditioning(subprompts[i]), alpha=weight) else: # just behave like usual - c = model.get_learned_conditioning(prompts) + c = (model if not opt.optimized else modelCS).get_learned_conditioning(prompts) shape = [opt_C, height // opt_f, width // opt_f] + if opt.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelCS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + # we manually generate all input noises because each one should have a specific seed x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds) samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) + if opt.optimized: + modelFS.to(device) - x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) for i, x_sample in enumerate(x_samples_ddim): x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') @@ -561,7 +618,7 @@ def process_images( image = Image.fromarray(x_sample) if init_mask: #init_mask = init_mask if keep_mask else ImageOps.invert(init_mask) - init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) + init_mask = init_mask.filter(ImageFilter.GaussianBlur(3)) init_mask = init_mask.convert('L') init_img = init_img.convert('RGB') image = image.convert('RGB') @@ -584,11 +641,11 @@ def process_images( sanitized_prompt = sanitized_prompt[:128] #200 is too long sample_path_i = os.path.join(sample_path, sanitized_prompt) os.makedirs(sample_path_i, exist_ok=True) - base_count = len([_ for x in os.listdir() if x.endswith(('.png', '.jpg'))]) - 1 # start at 0 + base_count = len(os.listdir(sample_path_i)) filename = f"{base_count:05}-{seeds[i]}" else: sample_path_i = sample_path - base_count = len([_ for x in os.listdir() if x.endswith(('.png', '.jpg'))]) - 1 # start at 0 + base_count = len(os.listdir(sample_path_i)) sanitized_prompt = sanitized_prompt filename = f"{base_count:05}-{seeds[i]}_{sanitized_prompt}"[:128] #same as before if not skip_save: @@ -667,6 +724,13 @@ def process_images( grid_file = f"grid-{grid_count:05}-{seed}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.jpg" grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=100, optimize=True) grid_count += 1 + + if opt.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelFS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + toc = time.time() mem_max_used, mem_total = mem_mon.read_and_stop() @@ -815,7 +879,7 @@ class Flagging(gr.FlaggingCallback): print("Logged:", filenames[0]) -def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str, +def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, ddim_steps: int, sampler_name: str, toggles: List[int], realesrgan_model_name: str, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, fp): outpath = opt.outdir_img2img or opt.outdir or "outputs/img2img-samples" @@ -875,10 +939,19 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) + if opt.optimized: + modelFS.to(device) + init_image = 2. * image - 1. init_image = init_image.to(device) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) - init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + init_latent = (model if not opt.optimized else modelFS).get_first_stage_encoding((model if not opt.optimized else modelFS).encode_first_stage(init_image)) # move to latent space + + if opt.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelFS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) return init_latent, @@ -936,7 +1009,6 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask init_img=init_img, init_mask=init_mask, keep_mask=keep_mask, - mask_blur_strength=mask_blur_strength, denoising_strength=denoising_strength, resize_mode=resize_mode, uses_loopback=loopback, @@ -958,7 +1030,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask history.append(init_img) if not skip_grid: - grid_count = len([_ for x in os.listdir() if x.endswith(('.png', '.jpg'))]) - 1 # start at 0 + grid_count = len(os.listdir(outpath)) - 1 grid = image_grid(history, batch_size, force_n_rows=1) grid_file = f"grid-{grid_count:05}-{seed}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.jpg" grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=100, optimize=True) @@ -992,7 +1064,6 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask init_img=init_img, init_mask=init_mask, keep_mask=keep_mask, - mask_blur_strength=mask_blur_strength, denoising_strength=denoising_strength, resize_mode=resize_mode, uses_loopback=loopback, @@ -1187,8 +1258,8 @@ img2img_image_mode = 'sketch' def change_image_editor_mode(choice, cropped_image, resize_mode, width, height): if choice == "Mask": - return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)] - return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)] + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)] + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)] def update_image_mask(cropped_image, resize_mode, width, height): resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None @@ -1203,28 +1274,6 @@ def copy_img_to_input(selected=1, imgs = []): except IndexError: return [None, None] -help_text = """ - ## Mask/Crop - * The masking/cropping is very temperamental. - * It may take some time for the image to show when switching from Crop to Mask. - * If the image doesn't appear after switching to Mask, switch back to Crop and then back again to Mask - * If the mask appears distorted (the brush is weirdly shaped instead of round), switch back to Crop and then back again to Mask. - - ## Advanced Editor - * For now the button needs to be clicked twice the first time. - * Once you have edited your image, you _need_ to click the save button for the next step to work. - * Clear the image from the crop editor (click the x) - * Click "Get Image from Advanced Editor" to get the image you saved. If it doesn't work, try opening the editor and saving again. - - If it keeps not working, try switching modes again, switch tabs, clear the image or reload. -""" - -def show_help(): - return [gr.update(visible=False), gr.update(visible=True), gr.update(value=help_text)] - -def hide_help(): - return [gr.update(visible=True), gr.update(visible=False), gr.update(value="")] - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI") as demo: with gr.Tabs(): with gr.TabItem("Stable Diffusion Text-to-Image Unified"): @@ -1265,16 +1314,19 @@ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI") gr.Markdown("Generate images from images with Stable Diffusion") img2img_prompt = gr.Textbox(label="Prompt", placeholder="A fantasy landscape, trending on artstation.", lines=1, value=img2img_defaults['prompt']) img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"], label="Image Editor Mode", value="Crop") - img2img_show_help_btn = gr.Button("Show Hints") - img2img_hide_help_btn = gr.Button("Hide Hints", visible=False) - img2img_help = gr.Markdown(visible=False, value="") - with gr.Row(): - img2img_painterro_btn = gr.Button("Advanced Editor") - img2img_copy_from_painterro_btn = gr.Button(value="Get Image from Advanced Editor") + gr.Markdown( + """ + The masking/cropping is very temperamental. + It may take some time for the image to show when switching from Crop to Mask. + * If the image doesn't appear after switching to Mask, switch back to Crop and then back again to Mask + * If the mask appears distorted (the brush is weirdly shaped instead of round), switch back to Crop and then back again to Mask. + + If it keeps not working, try switching modes again, switch tabs, clear the image or reload. + """ + ) img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="select") img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="sketch", visible=False) - img2img_mask = gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", type="index", value=img2img_mask_modes[img2img_defaults['mask_mode']], visible=False) - img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1, label="How much blurry should the mask be? (to avoid hard edges)", value=3, visible=False) + img2img_mask = gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", type="index", value=img2img_mask_modes[img2img_defaults['mask_mode']]) img2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps", value=img2img_defaults['ddim_steps']) img2img_sampling = gr.Radio(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value=img2img_defaults['sampler_name']) img2img_toggles = gr.CheckboxGroup(label='', choices=img2img_toggles, value=img2img_toggle_defaults, type="index") @@ -1288,8 +1340,8 @@ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI") img2img_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=img2img_defaults["width"]) img2img_resize = gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value=img2img_resize_modes[img2img_defaults['resize_mode']]) img2img_embeddings = gr.File(label = "Embeddings file for textual inversion", visible=hasattr(model, "embedding_manager")) - img2img_btn_mask = gr.Button("Generate", visible=False).style(full_width=True) - img2img_btn_editor = gr.Button("Generate").style(full_width=True) + img2img_btn_mask = gr.Button("Generate", visible=False) + img2img_btn_editor = gr.Button("Generate") with gr.Column(): output_img2img_gallery = gr.Gallery(label="Images") output_img2img_select_image = gr.Number(label='Select image number from results for copying', value=1, precision=None) @@ -1302,7 +1354,7 @@ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI") img2img_image_editor_mode.change( change_image_editor_mode, [img2img_image_editor_mode, img2img_image_editor, img2img_resize, img2img_width, img2img_height], - [img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask, img2img_painterro_btn, img2img_copy_from_painterro_btn, img2img_mask, img2img_mask_blur_strength] + [img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask] ) img2img_image_editor.edit( @@ -1311,18 +1363,6 @@ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI") img2img_image_mask ) - img2img_show_help_btn.click( - show_help, - None, - [img2img_show_help_btn, img2img_hide_help_btn, img2img_help] - ) - - img2img_hide_help_btn.click( - hide_help, - None, - [img2img_show_help_btn, img2img_hide_help_btn, img2img_help] - ) - output_img2img_copy_to_input_btn.click( copy_img_to_input, [output_img2img_select_image, output_img2img_gallery], @@ -1337,41 +1377,16 @@ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI") img2img_btn_mask.click( img2img, - [img2img_prompt, img2img_image_editor_mode, img2img_image_mask, img2img_mask, img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings], + [img2img_prompt, img2img_image_editor_mode, img2img_image_mask, img2img_mask, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings], [output_img2img_gallery, output_img2img_seed, output_img2img_params, output_img2img_stats] ) img2img_btn_editor.click( img2img, - [img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_mask, img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings], + [img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_mask, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings], [output_img2img_gallery, output_img2img_seed, output_img2img_params, output_img2img_stats] ) - img2img_painterro_btn.click(None, [img2img_image_editor], None, _js="""(img) => { - try { - Painterro({ - hiddenTools: ['arrow'], - saveHandler: function (image, done) { - localStorage.setItem('painterro-image', image.asDataURL()); - done(true); - }, - }).show(Array.isArray(img) ? img[0] : img); - } catch(e) { - const script = document.createElement('script'); - script.src = 'https://unpkg.com/painterro@1.2.78/build/painterro.min.js'; - document.head.appendChild(script); - const style = document.createElement('style'); - style.appendChild(document.createTextNode('.ptro-holder-wrapper { z-index: 9999 !important; }')); - document.head.appendChild(style); - } - return []; - }""") - - img2img_copy_from_painterro_btn.click(None, None, [img2img_image_editor, img2img_image_mask], _js="""() => { - const image = localStorage.getItem('painterro-image') - return [image, image]; - }""") - if GFPGAN is not None: gfpgan_defaults = { 'strength': 100,