send weights to target device instead of CPU memory

AUTOMATIC1111 2023-08-16 12:11:01 +03:00
parent 57e59c14c8
commit eaba3d7349
2 changed files with 31 additions and 10 deletions

modules/sd_disable_initialization.py

````diff
@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
     ```
     """

-    def __init__(self, state_dict, device):
+    def __init__(self, state_dict, device, weight_dtype_conversion=None):
         super().__init__()
         self.state_dict = state_dict
         self.device = device
+        self.weight_dtype_conversion = weight_dtype_conversion or {}
+        self.default_dtype = self.weight_dtype_conversion.get('')
+
+    def get_weight_dtype(self, key):
+        key_first_term, _ = key.split('.', 1)
+        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)

     def __enter__(self):
         if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -167,24 +173,24 @@ class LoadStateDictOnMeta(ReplaceHelper):
         sd = self.state_dict
         device = self.device

-        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
             used_param_keys = []

-            for name, param in self._parameters.items():
+            for name, param in module._parameters.items():
                 if param is None:
                     continue

                 key = prefix + name
                 sd_param = sd.pop(key, None)
                 if sd_param is not None:
-                    state_dict[key] = sd_param
+                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                     used_param_keys.append(key)

                 if param.is_meta:
                     dtype = sd_param.dtype if sd_param is not None else param.dtype
-                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

-            for name in self._buffers:
+            for name in module._buffers:
                 key = prefix + name

                 sd_param = sd.pop(key, None)
@@ -192,12 +198,12 @@ class LoadStateDictOnMeta(ReplaceHelper):
                     state_dict[key] = sd_param
                     used_param_keys.append(key)

-            original(self, state_dict, prefix, *args, **kwargs)
+            original(module, state_dict, prefix, *args, **kwargs)

             for key in used_param_keys:
                 state_dict.pop(key, None)

-        def load_state_dict(original, self, state_dict, strict=True):
+        def load_state_dict(original, module, state_dict, strict=True):
             """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
             because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
             all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
@@ -212,7 +218,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
             if state_dict == sd:
                 state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

-            original(self, state_dict, strict=strict)
+            original(module, state_dict, strict=strict)

         module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
         module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
````
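Note the rename of the hook's first parameter from `self` to `module`: inside `__enter__`, `self` now binds via closure to the `LoadStateDictOnMeta` instance, which is what lets the patched `_load_from_state_dict` call the new `self.get_weight_dtype(key)` while converting each weight as it is loaded.

The docstring's meta-device trick is easy to see in isolation. A minimal sketch using only stock PyTorch (the dict contents and layer are illustrative, not webui code):

```python
import torch

# Moving a tensor to the "meta" device keeps its shape and dtype but drops
# the storage, so however many internal copies of the dict torch makes,
# no weight memory is duplicated.
sd = {'weight': torch.randn(1024, 1024)}
meta_sd = {k: v.to(device="meta", dtype=v.dtype) for k, v in sd.items()}
print(meta_sd['weight'].is_meta)   # True: shape/dtype known, no data allocated

# The materialization step the patched hook performs for meta parameters:
# allocate a real zero tensor of the same shape directly on the target device.
layer = torch.nn.Linear(8, 8, device="meta")
layer._parameters['weight'] = torch.nn.parameter.Parameter(
    torch.zeros_like(layer.weight, device="cpu", dtype=layer.weight.dtype),
    requires_grad=layer.weight.requires_grad,
)
print(layer.weight.is_meta)  # False: ready to receive the real weight
```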

modules/sd_models.py

````diff
@@ -518,6 +518,13 @@ def send_model_to_cpu(m):
     devices.torch_gc()


+def model_target_device():
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        return devices.cpu
+    else:
+        return devices.device
+
+
 def send_model_to_device(m):
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
@@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("create model")

-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+    if shared.cmd_opts.no_half:
+        weight_dtype_conversion = None
+    else:
+        weight_dtype_conversion = {
+            'first_stage_model': None,
+            '': torch.float16,
+        }
+
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)

     timer.record("load weights from state dict")
````
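The conversion map built here is keyed by the first dot-separated component of each state-dict key, with the `''` entry as the fallback: the VAE (`first_stage_model`) keeps the checkpoint's stored dtype, everything else is converted to float16 while loading, and with `--no-half` no map is passed at all. A small sketch of how `get_weight_dtype` resolves keys against this map (the state-dict key names are illustrative); `Tensor.to(dtype=None)` is a no-op, which is why a `None` entry means "no conversion":

```python
import torch

# The map from load_model() when --no-half is not set.
weight_dtype_conversion = {
    'first_stage_model': None,   # VAE: keep the stored dtype
    '': torch.float16,           # default: halve while loading
}
default_dtype = weight_dtype_conversion.get('')

def get_weight_dtype(key):
    # Same lookup as LoadStateDictOnMeta.get_weight_dtype.
    key_first_term, _ = key.split('.', 1)
    return weight_dtype_conversion.get(key_first_term, default_dtype)

w = torch.ones(2, dtype=torch.float32)
print(get_weight_dtype('model.diffusion_model.input_blocks.0.0.weight'))  # torch.float16
print(get_weight_dtype('first_stage_model.encoder.conv_in.weight'))       # None
print(w.to(dtype=get_weight_dtype('first_stage_model.a.b')).dtype)        # torch.float32 (unchanged)
print(w.to(dtype=get_weight_dtype('model.a.b')).dtype)                    # torch.float16
```

Combined with `model_target_device()`, weights now land on the GPU in their final dtype as they are read, instead of being staged at full precision in CPU RAM first; low/medvram setups still stage on CPU, since lowvram moves modules to the GPU on demand.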