From cdac5ace1456ba779d5a0171ff8757f31955bfee Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 16 May 2023 11:54:02 +0300 Subject: [PATCH] suppress ENSD infotext for samplers that don't use it --- modules/processing.py | 11 +++++++---- modules/sd_samplers.py | 8 +++++++- modules/sd_samplers_common.py | 21 ++++++++++++++++++++- modules/sd_samplers_compvis.py | 8 ++++++-- modules/sd_samplers_kdiffusion.py | 16 ++++++++-------- 5 files changed, 48 insertions(+), 16 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 94fe2625..15806f78 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -480,6 +480,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) enable_hr = getattr(p, 'enable_hr', False) + uses_ensd = opts.eta_noise_seed_delta != 0 + if uses_ensd: + uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p) + generation_params = { "Steps": p.steps, "Sampler": p.sampler_name, @@ -496,17 +500,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Denoising strength": getattr(p, 'denoising_strength', None), "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, - "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, + "ENSD": opts.eta_noise_seed_delta if uses_ensd else None, "Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio, "Token merging ratio hr": None if not enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr, "Init image hash": getattr(p, 'init_img_hash', None), "RNG": opts.randn_source if opts.randn_source != "GPU" else None, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, + **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, } - generation_params.update(p.extra_generation_params) - generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else "" diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 4f1bf21d..f22aad8f 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -14,12 +14,18 @@ samplers_for_img2img = [] samplers_map = {} -def create_sampler(name, model): +def find_sampler_config(name): if name is not None: config = all_samplers_map.get(name, None) else: config = all_samplers[0] + return config + + +def create_sampler(name, model): + config = find_sampler_config(name) + assert config is not None, f'bad sampler name: {name}' sampler = config.constructor(model) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index bc074238..92880caf 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,7 +2,7 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, processing, images, sd_vae_approx +from modules import devices, processing, images, sd_vae_approx, sd_samplers from modules.shared import opts, state import modules.shared as shared @@ -58,6 +58,25 @@ def store_latent(decoded): shared.state.assign_current_image(sample_to_image(decoded)) +def is_sampler_using_eta_noise_seed_delta(p): + """returns whether sampler from config will use eta noise seed delta for image creation""" + + sampler_config = sd_samplers.find_sampler_config(p.sampler_name) + + eta = p.eta + + if eta is None and p.sampler is not None: + eta = p.sampler.eta + + if eta is None and sampler_config is not None: + eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0 + + if eta == 0: + return False + + return sampler_config.options.get("uses_ensd", False) + + class InterruptedException(BaseException): pass diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index b1ee3be7..bdae8b40 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -11,7 +11,7 @@ import modules.models.diffusion.uni_pc samplers_data_compvis = [ - sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), + sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}), sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}), ] @@ -134,7 +134,11 @@ class VanillaStableDiffusionSampler: self.update_step(x) def initialize(self, p): - self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim + if self.is_ddim: + self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim + else: + self.eta = 0.0 + if self.eta != 0.0: p.extra_generation_params["Eta DDIM"] = self.eta diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 61f23ad7..5455561a 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -11,21 +11,21 @@ from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), + ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), ('Euler', 'sample_euler', ['k_euler'], {}), ('LMS', 'sample_lms', ['k_lms'], {}), ('Heun', 'sample_heun', ['k_heun'], {}), ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}), + ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True}), ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), + ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), + ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), + ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True}), + ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True}), + ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True}), ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}), ]