2022-10-09 17:58:43 +03:00
|
|
|
# this code is adapted from the script contributed by anon from /h/
|
|
|
|
|
|
|
|
import io
|
|
|
|
import pickle
|
|
|
|
import collections
|
|
|
|
import sys
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import numpy
|
|
|
|
import _codecs
|
|
|
|
import zipfile
|
2022-10-11 17:03:00 +03:00
|
|
|
import re
|
2022-10-09 17:58:43 +03:00
|
|
|
|
|
|
|
|
2022-10-10 07:38:55 +03:00
|
|
|
# PyTorch 1.13 renamed the private storage class _TypedStorage to the public
# TypedStorage; prefer the new name and fall back to the old one on older torch.
try:
    TypedStorage = torch.storage.TypedStorage
except AttributeError:
    TypedStorage = torch.storage._TypedStorage
|
|
|
|
|
|
|
|
|
2022-10-09 17:58:43 +03:00
|
|
|
def encode(*args):
    """Forward all arguments to ``_codecs.encode`` and return its result.

    RestrictedUnpickler.find_class hands out this wrapper when a pickle asks
    for the global ``_codecs.encode``.
    """
    return _codecs.encode(*args)
|
|
|
|
|
|
|
|
|
|
|
|
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves only an allowlist of globals.

    check_pt runs it over a checkpoint's pickle stream as a safety scan before
    the real torch.load is called. Any global that neither extra_handler nor
    the allowlist in find_class accepts raises an exception instead of being
    imported, so a crafted pickle cannot reach arbitrary callables.
    """

    # Optional callable (module, name) -> object or None, consulted before the
    # built-in allowlist. Returning None falls through to the allowlist.
    extra_handler = None

    def persistent_load(self, saved_id):
        """Return a placeholder storage for torch's out-of-band tensor data.

        Only persistent ids that are tuples tagged 'storage' are accepted; the
        scan never needs the real tensor contents, so an empty TypedStorage
        stands in for them.

        Raises:
            Exception: if saved_id is not a tuple whose first item is 'storage'.
        """
        # Explicit check instead of `assert`: asserts are stripped when Python
        # runs with -O, which would silently disable this validation.
        if not (isinstance(saved_id, tuple) and saved_id and saved_id[0] == 'storage'):
            raise Exception(f"persistent id {saved_id!r:.100} is forbidden")
        return TypedStorage()

    def find_class(self, module, name):
        """Resolve a pickled global reference, permitting only an allowlist.

        extra_handler (if set) is asked first and a non-None result is returned
        as-is. Otherwise only the specific module/name pairs below resolve.

        Raises:
            Exception: for any global that is not allowed.
        """
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name == 'scalar':
            return numpy.core.multiarray.scalar
        if module == 'numpy' and name == 'dtype':
            return numpy.dtype
        if module == '_codecs' and name == 'encode':
            # Hand out this module's wrapper rather than _codecs.encode itself.
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            # Imported only when such a checkpoint is seen, so pytorch_lightning
            # is not required otherwise.
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            # Python 2 module name, which is what pickle protocols < 3 record.
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")
|
2022-10-09 17:58:43 +03:00
|
|
|
|
|
|
|
|
2022-10-11 17:03:00 +03:00
|
|
|
# Entry names a torch zip checkpoint may contain. Only "<dir>/data.pkl" is
# unpickled by the scan below.
# Torch does not always use "archive" as the top-level directory (it may be
# derived from the saved file's name), so any single path component is accepted.
# allowed_zip_names is kept for backward compatibility; the regex covers it.
allowed_zip_names = ["archive/data.pkl", "archive/version"]
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|\.data/serialization_id|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")


def check_zip_filenames(filename, names):
    """Raise unless every entry in *names* looks like a torch checkpoint record.

    Args:
        filename: path of the zip file, used only in the error message.
        names: entry names, e.g. from ZipFile.namelist().

    Raises:
        Exception: naming the first unexpected entry.
    """
    for name in names:
        # fullmatch rather than match: with match, "$" also matches just before
        # a trailing newline, so a name like "archive/data/0\n" would slip through.
        if name in allowed_zip_names or allowed_zip_names_re.fullmatch(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")


def _find_data_pkl(filename, names):
    """Return the single '<dir>/data.pkl' entry in *names*; raise unless exactly one exists."""
    # namelist() keeps duplicate entries. Requiring exactly one data.pkl stops a
    # zip from carrying a second pickle that the scan would skip but torch might load.
    candidates = [name for name in names if data_pkl_re.fullmatch(name)]
    if not candidates:
        raise Exception(f"data.pkl not found in {filename}")
    if len(candidates) > 1:
        raise Exception(f"Multiple data.pkl found in {filename}")
    return candidates[0]


def _check_legacy_pt(filename, extra_handler):
    """Scan the old pytorch format: five objects written to pickle back to back."""
    with open(filename, "rb") as file:
        unpickler = RestrictedUnpickler(file)
        unpickler.extra_handler = extra_handler
        for _ in range(5):
            unpickler.load()


def check_pt(filename, extra_handler):
    """Scan a torch checkpoint's pickle data with RestrictedUnpickler.

    Handles both the zip-based format and the old format of consecutive
    pickles. Returns nothing; anything the scan does not allow raises.

    Args:
        filename: path of the checkpoint file.
        extra_handler: optional callable (module, name) -> object or None that
            can allow globals beyond RestrictedUnpickler's built-in allowlist.

    Raises:
        pickle.UnpicklingError: if the pickle data is malformed.
        Exception: for a forbidden global, persistent id, or zip entry.
    """
    # Only the open is guarded: a BadZipFile raised while reading a real zip
    # must not be mistaken for "this is a legacy file".
    try:
        z = zipfile.ZipFile(filename)
    except zipfile.BadZipFile:
        z = None

    if z is None:
        # if it's not a zip file, it's an old pytorch format
        _check_legacy_pt(filename, extra_handler)
        return

    # new pytorch format is a zip file
    with z:
        names = z.namelist()
        check_zip_filenames(filename, names)

        with z.open(_find_data_pkl(filename, names)) as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            unpickler.load()
|
|
|
|
|
|
|
|
|
|
|
|
def load(filename, *args, **kwargs):
    """Scanning replacement for torch.load (installed as torch.load at import).

    Positional and keyword arguments after filename are forwarded unchanged to
    the real torch.load. An ``extra_handler`` keyword, which is not a torch.load
    parameter, is routed to load_with_extra's extra_handler instead.
    """
    # load_with_extra takes extra_handler as its second positional parameter.
    # Passing *args straight through would bind a positional torch.load argument
    # (e.g. map_location in torch.load(path, "cpu")) to extra_handler. The scan
    # would then fail, and the argument would never reach torch.load.
    extra_handler = kwargs.pop('extra_handler', None)
    return load_with_extra(filename, extra_handler, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    Load a torch checkpoint after a safety scan, optionally allowing extra globals.

    This function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value (or None to defer to the default allowlist):

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    Any further positional/keyword arguments are passed on to the real torch.load.
    Unless --disable-safe-unpickle is set, the file is scanned first. If the scan
    fails, the reason is printed to stderr and None is returned.

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.
    """
    from modules import shared

    def report_failure(*hints):
        # Called from inside an except clause: traceback.format_exc() describes
        # the exception currently being handled.
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        for hint in hints:
            print(hint, file=sys.stderr)

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename, extra_handler)
    except pickle.UnpicklingError:
        # Malformed pickle data: report it as corruption, not as an attack.
        report_failure(
            "-----> !!!! The file is most likely corrupted !!!! <-----",
            "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
        )
        return None
    except Exception:
        report_failure(
            "\nThe file may be malicious, so the program is not going to read it.",
            "You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
        )
        return None

    return unsafe_torch_load(filename, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
# Keep a handle on the genuine torch.load, which is used for the actual load
# after a successful scan and for explicitly trusted loads. Then make
# subsequent torch.load lookups go through the scanning wrapper `load`.
# The right-hand side is evaluated before either name is rebound.
unsafe_torch_load, torch.load = torch.load, load
|