From 0a89cd1a584b1584a0609c0ba27fb35c434b0b68 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 24 Jul 2023 22:08:08 +0300 Subject: [PATCH] Use less RAM when creating models --- modules/cmd_args.py | 1 + modules/sd_disable_initialization.py | 106 +++++++++++++++++++++++++-- modules/sd_models.py | 16 ++-- webui.py | 4 +- 4 files changed, 114 insertions(+), 13 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index dd5fadc4..cb4ec5f7 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -67,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 9fc89dc6..695c5736 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -3,8 +3,31 @@ import open_clip import torch import transformers.utils.hub +from modules import shared -class DisableInitialization: + +class ReplaceHelper: + def __init__(self): + self.replaced = [] + + def replace(self, obj, field, func): + original = getattr(obj, field, None) + if original is None: + return None + + self.replaced.append((obj, field, original)) + setattr(obj, field, func) + + return original + + def restore(self): + for obj, field, original in self.replaced: + setattr(obj, field, original) + + self.replaced.clear() + + +class DisableInitialization(ReplaceHelper): """ When an object of this class enters a `with` block, it starts: - preventing torch's layer initialization functions from working @@ -21,7 +44,7 @@ class DisableInitialization: """ def __init__(self, disable_clip=True): - self.replaced = [] + super().__init__() self.disable_clip = disable_clip def replace(self, obj, field, func): @@ -86,8 +109,81 @@ class DisableInitialization: self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): - for obj, field, original in self.replaced: - setattr(obj, field, original) + self.restore() - self.replaced.clear() +class InitializeOnMeta(ReplaceHelper): + """ + Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device, + which results in those parameters having no values and taking no memory. model.to() will be broken and + will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict. + + Usage: + ``` + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + ``` + """ + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + def set_device(x): + x["device"] = "meta" + return x + + linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs))) + conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs))) + mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs))) + self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() + + +class LoadStateDictOnMeta(ReplaceHelper): + """ + Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device. + As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory. + Meant to be used together with InitializeOnMeta above. + + Usage: + ``` + with sd_disable_initialization.LoadStateDictOnMeta(state_dict): + model.load_state_dict(state_dict, strict=False) + ``` + """ + + def __init__(self, state_dict, device): + super().__init__() + self.state_dict = state_dict + self.device = device + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + sd = self.state_dict + device = self.device + + def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs): + params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta] + + for name, param in params: + if param.is_meta: + self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad) + + original(self, state_dict, prefix, *args, **kwargs) + + for name, _ in params: + key = prefix + name + if key in sd: + del sd[key] + + linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) + conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) + mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() diff --git a/modules/sd_models.py b/modules/sd_models.py index fb31a793..acb1e817 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -460,7 +460,6 @@ def get_empty_cond(sd_model): return sd_model.cond_stage_model([""]) - def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = None try: with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): - sd_model = instantiate_from_config(sd_config.model) - except Exception: - pass + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + + except Exception as e: + errors.display(e, "creating model quickly", full_traceback=True) if sd_model is None: print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) - sd_model = instantiate_from_config(sd_config.model) + + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) sd_model.used_config = checkpoint_config timer.record("create model") - load_model_weights(sd_model, checkpoint_info, state_dict, timer) + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): + load_model_weights(sd_model, checkpoint_info, state_dict, timer) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) diff --git a/webui.py b/webui.py index 2314735f..51248c39 100644 --- a/webui.py +++ b/webui.py @@ -320,9 +320,9 @@ def initialize_rest(*, reload_script_modules=False): if modules.sd_hijack.current_optimizer is None: modules.sd_hijack.apply_optimizations() - Thread(target=load_model).start() + devices.first_time_calculation() - Thread(target=devices.first_time_calculation).start() + Thread(target=load_model).start() shared.reload_hypernetworks() startup_timer.record("reload hypernetworks")