From 016554e43740e0b7ded75e89255de81270de9d6c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 22 Aug 2023 18:49:08 +0300 Subject: [PATCH] add --medvram-sdxl --- modules/cmd_args.py | 1 + modules/interrogate.py | 5 ++--- modules/lowvram.py | 18 ++++++++++++++++-- modules/sd_models.py | 16 ++++++++-------- modules/sd_unet.py | 2 +- modules/sd_vae.py | 4 ++-- modules/shared.py | 2 +- 7 files changed, 31 insertions(+), 17 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 9f8e5b30..f0f361bd 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -35,6 +35,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_ parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") +parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything") diff --git a/modules/interrogate.py b/modules/interrogate.py index a3ae1dd5..3045560d 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -186,9 +186,8 @@ class InterrogateModels: res = "" shared.state.begin(job="interrogate") try: - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - devices.torch_gc() + lowvram.send_everything_to_cpu() + devices.torch_gc() self.load() diff --git a/modules/lowvram.py b/modules/lowvram.py index 96f52b7b..45701046 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -1,5 +1,5 @@ import torch -from modules import devices +from modules import devices, shared module_in_gpu = None cpu = torch.device("cpu") @@ -14,6 +14,20 @@ def send_everything_to_cpu(): module_in_gpu = None +def is_needed(sd_model): + return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner') + + +def apply(sd_model): + enable = is_needed(sd_model) + shared.parallel_processing_allowed = not enable + + if enable: + setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram) + else: + sd_model.lowvram = False + + def setup_for_low_vram(sd_model, use_medvram): if getattr(sd_model, 'lowvram', False): return @@ -130,4 +144,4 @@ def setup_for_low_vram(sd_model, use_medvram): def is_enabled(sd_model): - return getattr(sd_model, 'lowvram', False) + return sd_model.lowvram diff --git a/modules/sd_models.py b/modules/sd_models.py index 27d15e66..4331853a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -517,7 +517,7 @@ def get_empty_cond(sd_model): def send_model_to_cpu(m): - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + if m.lowvram: lowvram.send_everything_to_cpu() else: m.to(devices.cpu) @@ -525,17 +525,17 @@ def send_model_to_cpu(m): devices.torch_gc() -def model_target_device(): - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: +def model_target_device(m): + if lowvram.is_needed(m): return devices.cpu else: return devices.device def send_model_to_device(m): - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) - else: + lowvram.apply(m) + + if not m.lowvram: m.to(shared.device) @@ -601,7 +601,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): '': torch.float16, } - with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion): + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion): load_model_weights(sd_model, checkpoint_info, state_dict, timer) timer.record("load weights from state dict") @@ -743,7 +743,7 @@ def reload_model_weights(sd_model=None, info=None): script_callbacks.model_loaded_callback(sd_model) timer.record("script callbacks") - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + if not sd_model.lowvram: sd_model.to(devices.device) timer.record("move model to device") diff --git a/modules/sd_unet.py b/modules/sd_unet.py index 6d708ad2..5525cfbc 100644 --- a/modules/sd_unet.py +++ b/modules/sd_unet.py @@ -47,7 +47,7 @@ def apply_unet(option=None): if current_unet_option is None: current_unet = None - if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram): + if not shared.sd_model.lowvram: shared.sd_model.model.diffusion_model.to(devices.device) return diff --git a/modules/sd_vae.py b/modules/sd_vae.py index ee118656..669097da 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -263,7 +263,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): if loaded_vae_file == vae_file: return - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + if sd_model.lowvram: lowvram.send_everything_to_cpu() else: sd_model.to(devices.cpu) @@ -275,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + if not sd_model.lowvram: sd_model.to(devices.device) print("VAE weights loaded.") diff --git a/modules/shared.py b/modules/shared.py index 0c57b712..f321159d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -11,7 +11,7 @@ cmd_opts = shared_cmd_options.cmd_opts parser = shared_cmd_options.parser batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond -parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram +parallel_processing_allowed = True styles_filename = cmd_opts.styles_file config_filename = cmd_opts.ui_settings_file hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}