Use less RAM when creating models

This commit is contained in:
AUTOMATIC1111 2023-07-24 22:08:08 +03:00
parent f451994053
commit 0a89cd1a58
4 changed files with 114 additions and 13 deletions

View File

@ -67,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)

View File

@ -3,8 +3,31 @@ import open_clip
import torch
import transformers.utils.hub
from modules import shared
class DisableInitialization:
class ReplaceHelper:
def __init__(self):
self.replaced = []
def replace(self, obj, field, func):
original = getattr(obj, field, None)
if original is None:
return None
self.replaced.append((obj, field, original))
setattr(obj, field, func)
return original
def restore(self):
for obj, field, original in self.replaced:
setattr(obj, field, original)
self.replaced.clear()
class DisableInitialization(ReplaceHelper):
"""
When an object of this class enters a `with` block, it starts:
- preventing torch's layer initialization functions from working
@ -21,7 +44,7 @@ class DisableInitialization:
"""
def __init__(self, disable_clip=True):
self.replaced = []
super().__init__()
self.disable_clip = disable_clip
def replace(self, obj, field, func):
@ -86,8 +109,81 @@ class DisableInitialization:
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
for obj, field, original in self.replaced:
setattr(obj, field, original)
self.restore()
self.replaced.clear()
class InitializeOnMeta(ReplaceHelper):
"""
Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
which results in those parameters having no values and taking no memory. model.to() will be broken and
will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
Usage:
```
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
```
"""
def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
return
def set_device(x):
x["device"] = "meta"
return x
linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()
class LoadStateDictOnMeta(ReplaceHelper):
"""
Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
Meant to be used together with InitializeOnMeta above.
Usage:
```
with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
model.load_state_dict(state_dict, strict=False)
```
"""
def __init__(self, state_dict, device):
super().__init__()
self.state_dict = state_dict
self.device = device
def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
return
sd = self.state_dict
device = self.device
def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
for name, param in params:
if param.is_meta:
self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
original(self, state_dict, prefix, *args, **kwargs)
for name, _ in params:
key = prefix + name
if key in sd:
del sd[key]
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()

View File

@ -460,7 +460,6 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""])
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
sd_model = instantiate_from_config(sd_config.model)
except Exception:
pass
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
except Exception as e:
errors.display(e, "creating model quickly", full_traceback=True)
if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
sd_model = instantiate_from_config(sd_config.model)
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config
timer.record("create model")
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)

View File

@ -320,9 +320,9 @@ def initialize_rest(*, reload_script_modules=False):
if modules.sd_hijack.current_optimizer is None:
modules.sd_hijack.apply_optimizations()
Thread(target=load_model).start()
devices.first_time_calculation()
Thread(target=devices.first_time_calculation).start()
Thread(target=load_model).start()
shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")