make StableDiffusionProcessing class not hold a reference to shared.sd_model object

This commit is contained in:
AUTOMATIC 2023-01-16 23:09:08 +03:00
parent 9991967f40
commit e0e8005009
2 changed files with 5 additions and 5 deletions

View File

@ -94,7 +94,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return image_conditioning return image_conditioning
class StableDiffusionProcessing(): class StableDiffusionProcessing:
""" """
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
""" """
@ -102,7 +102,6 @@ class StableDiffusionProcessing():
if sampler_index is not None: if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
self.prompt: str = prompt self.prompt: str = prompt
@ -156,6 +155,10 @@ class StableDiffusionProcessing():
self.all_subseeds = None self.all_subseeds = None
self.iteration = 0 self.iteration = 0
@property
def sd_model(self):
return shared.sd_model
def txt2img_image_conditioning(self, x, width=None, height=None): def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'} self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
@ -236,7 +239,6 @@ class StableDiffusionProcessing():
raise NotImplementedError() raise NotImplementedError()
def close(self): def close(self):
self.sd_model = None
self.sampler = None self.sampler = None
@ -471,7 +473,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_model_checkpoint': if k == 'sd_model_checkpoint':
sd_models.reload_model_weights() # make onchange call for changing SD model sd_models.reload_model_weights() # make onchange call for changing SD model
p.sd_model = shared.sd_model
if k == 'sd_vae': if k == 'sd_vae':
sd_vae.reload_vae_weights() # make onchange call for changing VAE sd_vae.reload_vae_weights() # make onchange call for changing VAE

View File

@ -86,7 +86,6 @@ def apply_checkpoint(p, x, xs):
if info is None: if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}") raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info) modules.sd_models.reload_model_weights(shared.sd_model, info)
p.sd_model = shared.sd_model
def confirm_checkpoints(p, xs): def confirm_checkpoints(p, xs):