From ac90cf38c6b55d57d37923aa1fe86c7374e32d0b Mon Sep 17 00:00:00 2001 From: Tim Patton <38817597+pattontim@users.noreply.github.com> Date: Tue, 22 Nov 2022 10:13:07 -0500 Subject: [PATCH] safetensors optional for now --- modules/sd_models.py | 9 ++++++++- requirements.txt | 1 - 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 2bbb3bf5..75f7ab09 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,6 @@ import sys import gc from collections import namedtuple import torch -from safetensors.torch import load_file, save_file import re from omegaconf import OmegaConf @@ -149,6 +148,10 @@ def torch_load(model_filename, model_info, map_override=None): # safely load weights # TODO: safetensors supports zero copy fast load to gpu, see issue #684. # GPU only for now, see https://github.com/huggingface/safetensors/issues/95 + try: + from safetensors.torch import load_file + except ImportError as e: + raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}") return load_file(model_filename, device='cuda') else: return torch.load(model_filename, map_location=map_override) @@ -157,6 +160,10 @@ def torch_save(model, output_filename): basename, exttype = os.path.splitext(output_filename) if(checkpoint_types[exttype] == 'safetensors'): # [===== >] Reticulating brines... + try: + from safetensors.torch import save_file + except ImportError as e: + raise ImportError(f"Export as safetensors selected, yet it is not installed, use `pip install safetensors`: {e}") save_file(model, output_filename, metadata={"format": "pt"}) else: torch.save(model, output_filename) diff --git a/requirements.txt b/requirements.txt index f7de9f70..762db4f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,3 @@ kornia lark inflection GitPython -safetensors