Merge remote-tracking branch 'Narsil/fix_safetensors_load_speed'

This commit is contained in:
AUTOMATIC 2023-01-04 14:53:03 +03:00
commit 68fbf4558f
2 changed files with 5 additions and 2 deletions

View File

@ -171,7 +171,10 @@ def get_state_dict_from_checkpoint(pl_sd):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file) _, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":
pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location) device = map_location or shared.weight_load_location
if device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else: else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

View File

@ -26,5 +26,5 @@ lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27 GitPython==3.1.27
torchsde==0.2.5 torchsde==0.2.5
safetensors==0.2.5 safetensors==0.2.7
httpcore<=0.15 httpcore<=0.15