Reverted txt2vid to use the StableDiffusionPipeline instead of StableDiffusionWalkPipeline when loading the model.

ZeroCool940711 2022-10-31 06:50:21 -07:00
parent 25acf77260
commit 2ea2606e83


@@ -377,8 +377,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
             removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            print(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
+            print("The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
         text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
@@ -613,7 +612,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
     def walk(
         self,
-        prompts: Optional[List[str]] = None,
+        prompt: Optional[List[str]] = None,
         seeds: Optional[List[int]] = None,
         num_interpolation_steps: Optional[Union[int, List[int]]] = 5, # int or list of int
         output_dir: Optional[str] = "./dreams",
@@ -1108,7 +1107,7 @@ def load_diffusers_model(weights_path,torch_device):
     model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
     if not os.path.exists(model_path + "/model_index.json"):
-        server_state["pipe"] = StableDiffusionWalkPipeline.from_pretrained(
+        server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
             weights_path,
             use_local_file=True,
             use_auth_token=st.session_state["defaults"].general.huggingface_token,
@@ -1116,11 +1115,12 @@ def load_diffusers_model(weights_path,torch_device):
             revision="fp16" if not st.session_state['defaults'].general.no_half else None,
             safety_checker=None,  # Very important for videos...lots of false positives while interpolating
             #custom_pipeline="interpolate_stable_diffusion",
         )
-        StableDiffusionWalkPipeline.save_pretrained(server_state["pipe"], model_path)
+        StableDiffusionPipeline.save_pretrained(server_state["pipe"], model_path)
     else:
-        server_state["pipe"] = StableDiffusionWalkPipeline.from_pretrained(
+        server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
             model_path,
             use_local_file=True,
             torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
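
For context, the load path these two hunks revert to is the standard diffusers cache-or-download pattern: if no model_index.json exists under the local model path, fetch the weights, save a local copy, and load from that copy on later runs. A minimal standalone sketch of that pattern follows; the model id, device, and fp16 choice are illustrative stand-ins for the repo's session-state config, not values taken from this commit:

import os
import torch
from diffusers import StableDiffusionPipeline

model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")

if not os.path.exists(os.path.join(model_path, "model_index.json")):
    # First run: download from the hub, then cache a local copy for reuse.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # illustrative model id
        revision="fp16",
        torch_dtype=torch.float16,
        safety_checker=None,  # important for videos: many false positives while interpolating
    )
    pipe.save_pretrained(model_path)
else:
    # Later runs: load directly from the cached local copy.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        safety_checker=None,
    )

pipe = pipe.to("cuda")  # illustrative device choice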
@@ -1436,8 +1436,9 @@ def txt2vid(
     # works correctly generating all frames but do not show the preview image
     # we also do not have control over the generation and cant stop it until the end of it.
     #with torch.autocast("cuda"):
+        #print (prompts)
     #video_path = server_state["pipe"].walk(
-            #prompts=prompts,
+            #prompt=prompts,
             #seeds=seeds,
             #num_interpolation_steps=num_steps,
             #height=height,  # use multiples of 64 if > 512. Multiples of 8 if < 512.
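
Note that the .walk() call above stays commented out, which is consistent with this revert: the plain StableDiffusionPipeline loaded above does not implement walk(); only the StableDiffusionWalkPipeline defined earlier in this file does. If the call were ever re-enabled against a walk-capable pipeline, it would follow the renamed signature from the hunk at line 613, roughly like this hypothetical sketch (the prompt and seed values are illustrative):

# Hypothetical sketch only: re-enabling the interpolation call requires a
# pipeline that actually implements walk(), i.e. the StableDiffusionWalkPipeline
# defined in this file, not the plain StableDiffusionPipeline this commit loads.
prompts = ["a red rose", "a blue rose"]  # illustrative prompts
seeds = [42, 1337]                       # one seed per prompt
num_steps = 5

if hasattr(pipe, "walk"):
    video_path = pipe.walk(
        prompt=prompts,  # keyword renamed from prompts= in this commit
        seeds=seeds,
        num_interpolation_steps=num_steps,
        height=512,  # use multiples of 64 if > 512, multiples of 8 if < 512
        output_dir="./dreams",
    )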