Reverted txt2vid to use the StableDiffusionPipeline instead of StableDiffusionWalkPipeline when loading the model. (#1630)

Alejandro Gil 2022-10-31 07:15:19 -07:00 committed by GitHub
commit e13132c78c
3 changed files with 47 additions and 69 deletions
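
For context, the change below makes txt2vid load the stock diffusers StableDiffusionPipeline instead of the repo's custom StableDiffusionWalkPipeline. A minimal sketch of the load pattern being reverted to, mirroring the options visible in the diff (the model id and hf_token are placeholders, not the repo's exact code):

import torch
from diffusers import StableDiffusionPipeline

hf_token = "hf_..."  # placeholder for the token the Settings page stores
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed model id; the repo resolves its own weights_path
    revision="fp16",                   # half-precision weights, matching the diff's no_half toggle
    torch_dtype=torch.float16,
    safety_checker=None,               # disabled below: lots of false positives while interpolating
    use_auth_token=hf_token,
).to("cuda")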

View File

@@ -654,40 +654,6 @@ def layout():
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
-#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
-#if 'latestImages' in st.session_state:
-#for i in output_images:
-##push the new image to the list of latest images and remove the oldest one
-##remove the last index from the list\
-#st.session_state['latestImages'].pop()
-##add the new image to the start of the list
-#st.session_state['latestImages'].insert(0, i)
-#PlaceHolder.empty()
-#with PlaceHolder.container():
-#col1, col2, col3 = st.columns(3)
-#col1_cont = st.container()
-#col2_cont = st.container()
-#col3_cont = st.container()
-#images = st.session_state['latestImages']
-#with col1_cont:
-#with col1:
-#[st.image(images[index]) for index in [0, 3, 6] if index < len(images)]
-#with col2_cont:
-#with col2:
-#[st.image(images[index]) for index in [1, 4, 7] if index < len(images)]
-#with col3_cont:
-#with col3:
-#[st.image(images[index]) for index in [2, 5, 8] if index < len(images)]
-#historyGallery = st.empty()
-## check if output_images length is the same as seeds length
-#with gallery_tab:
-#st.markdown(createHTMLGallery(output_images,seeds), unsafe_allow_html=True)
-#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
with gallery_tab:
logger.info(seeds)
st.session_state["gallery"].text = ""

View File

@@ -377,8 +377,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-print(
-"The following part of your input was truncated because CLIP can only handle sequences up to"
+print("The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
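
The hunk above only reflows a warning, but the logic around it is the standard CLIP truncation guard: the text encoder accepts at most tokenizer.model_max_length (77) tokens, so the overflow is decoded back for the warning and then cut off. A standalone sketch of the same pattern, assuming the transformers CLIPTokenizer rather than the pipeline's own instance:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompt = "a very detailed prompt " * 30  # long enough to overflow the limit

# Tokenize without truncation so the overflow stays visible.
text_input_ids = tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if text_input_ids.shape[-1] > tokenizer.model_max_length:
    removed_text = tokenizer.batch_decode(text_input_ids[:, tokenizer.model_max_length:])
    print("The following part of your input was truncated because CLIP can only handle sequences up to"
          f" {tokenizer.model_max_length} tokens: {removed_text}")
    text_input_ids = text_input_ids[:, :tokenizer.model_max_length]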
@@ -613,7 +612,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
def walk(
self,
-prompts: Optional[List[str]] = None,
+prompt: Optional[List[str]] = None,
seeds: Optional[List[int]] = None,
num_interpolation_steps: Optional[Union[int, List[int]]] = 5, # int or list of int
output_dir: Optional[str] = "./dreams",
@@ -1108,7 +1107,7 @@ def load_diffusers_model(weights_path,torch_device):
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
if not os.path.exists(model_path + "/model_index.json"):
server_state["pipe"] = StableDiffusionWalkPipeline.from_pretrained(
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=st.session_state["defaults"].general.huggingface_token,
@@ -1116,11 +1115,12 @@ def load_diffusers_model(weights_path,torch_device):
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
)
-StableDiffusionWalkPipeline.save_pretrained(server_state["pipe"], model_path)
+StableDiffusionPipeline.save_pretrained(server_state["pipe"], model_path)
else:
server_state["pipe"] = StableDiffusionWalkPipeline.from_pretrained(
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
model_path,
use_local_file=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
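
Both branches of this hunk implement a download-once cache: if the local folder lacks model_index.json, pull the weights from the hub and persist them with save_pretrained; every later run loads straight from disk. A condensed sketch of that pattern (model id and token are placeholders):

import os
import torch
from diffusers import StableDiffusionPipeline

model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
if not os.path.exists(os.path.join(model_path, "model_index.json")):
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # placeholder for the repo's weights_path
        use_auth_token="hf_...",           # placeholder token
        safety_checker=None,
    )
    pipe.save_pretrained(model_path)  # later runs skip the download entirely
else:
    pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)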
@@ -1166,9 +1166,11 @@ def load_diffusers_model(weights_path,torch_device):
if "huggingface_token" not in st.session_state or st.session_state["defaults"].general.huggingface_token == "None":
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].error(
"You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."
"You need a huggingface token in order to use the Text to Video tab. Use the Settings page to add your token under the Huggingface section. "
"Make sure you save your settings after adding it."
)
-raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")
+raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page to add your token under the Huggingface section. "
+"Make sure you save your settings after adding it.")
else:
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].error(e)
@@ -1434,8 +1436,9 @@ def txt2vid(
# works correctly generating all frames but do not show the preview image
# we also do not have control over the generation and cant stop it until the end of it.
#with torch.autocast("cuda"):
+#print (prompts)
#video_path = server_state["pipe"].walk(
-#prompts=prompts,
+#prompt=prompts,
#seeds=seeds,
#num_interpolation_steps=num_steps,
#height=height, # use multiples of 64 if > 512. Multiples of 8 if < 512.
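
The commented-out block above shows how the custom pipeline's walk() would drive the interpolation. For reference, a hedged, uncommented version of that call using the renamed prompt parameter (all values are made up, and the custom pipeline is no longer what this code loads after the revert):

video_path = server_state["pipe"].walk(
    prompt=["a forest in spring", "the same forest in winter"],  # hypothetical prompts
    seeds=[42, 1337],            # one seed per prompt
    num_interpolation_steps=30,  # frames generated between each pair of prompts
    height=512,  # use multiples of 64 if > 512, multiples of 8 if < 512
    width=512,
    output_dir="./dreams",       # default from the walk() signature above
)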

View File

@@ -123,7 +123,10 @@ def layout():
# specify the primary menu definition
menu_data = [
{'id': 'Stable Diffusion', 'label': 'Stable Diffusion', 'icon': 'bi bi-grid-1x2-fill'},
+{'id': 'Train','label':"Train", 'icon': "bi bi-lightbulb-fill", 'submenu':[
+{'id': 'Textual Inversion', 'label': 'Textual Inversion', 'icon': 'bi bi-lightbulb-fill'},
+{'id': 'Fine Tunning', 'label': 'Fine Tunning', 'icon': 'bi bi-lightbulb-fill'},
+]},
{'id': 'Model Manager', 'label': 'Model Manager', 'icon': 'bi bi-cloud-arrow-down-fill'},
{'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
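
menu_data here is a hydralit_components navigation definition; entries carrying a 'submenu' key, like the new Train item, render as dropdowns. A minimal sketch of how such a definition is consumed (assuming the hydralit_components package this UI is built on):

import hydralit_components as hc

menu_id = hc.nav_bar(menu_definition=menu_data)  # returns the 'id' of the selected entry
# e.g. menu_id == 'Fine Tunning' routes to the stub page added further down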
@@ -188,6 +191,7 @@ def layout():
st.experimental_rerun()
txt2img_tab, img2img_tab, txt2vid_tab, img2txt_tab, concept_library_tab = st.tabs(["Text-to-Image", "Image-to-Image",
#"Inpainting",
"Text-to-Video", "Image-To-Text",
"Concept Library"])
#with home_tab:
@@ -229,6 +233,11 @@ def layout():
from textual_inversion import layout
layout()
+elif menu_id == 'Fine Tunning':
+#from textual_inversion import layout
+#layout()
+st.info("Under Construction. :construction_worker:")
elif menu_id == 'API Server':
set_page_title("API Server - Stable Diffusion Playground")
from APIServer import layout