Added the ability to save the video mid-generation when the stop button is pressed.

- Fixed GFPGAN not working on txt2vid.
This commit is contained in:
ZeroCool940711 2022-10-07 23:46:23 -07:00 committed by Alejandro Gil
parent 97bbf089ea
commit 63b2ff22c6

View File

@ -113,88 +113,93 @@ def diffuse(
if "update_preview_frequency_list" not in st.session_state:
st.session_state["update_preview_frequency_list"] = [0]
st.session_state["update_preview_frequency_list"].append(st.session_state['defaults'].txt2vid.update_preview_frequency)
st.session_state["update_preview_frequency_list"].append(st.session_state["update_preview_frequency"])
# diffuse!
for i, t in enumerate(pipe.scheduler.timesteps):
start = timeit.default_timer()
try:
# diffuse!
for i, t in enumerate(pipe.scheduler.timesteps):
start = timeit.default_timer()
#status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}")
#status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}")
# expand the latents for classifier free guidance
latent_model_input = torch.cat([cond_latents] * 2)
if isinstance(pipe.scheduler, LMSDiscreteScheduler):
sigma = pipe.scheduler.sigmas[i]
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# expand the latents for classifier free guidance
latent_model_input = torch.cat([cond_latents] * 2)
if isinstance(pipe.scheduler, LMSDiscreteScheduler):
sigma = pipe.scheduler.sigmas[i]
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# predict the noise residual
noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
# predict the noise residual
noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
# cfg
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
# cfg
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if isinstance(pipe.scheduler, LMSDiscreteScheduler):
cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"]
else:
cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"]
# compute the previous noisy sample x_t -> x_t-1
if isinstance(pipe.scheduler, LMSDiscreteScheduler):
cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"]
else:
cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"]
#print (st.session_state["update_preview_frequency"])
#update the preview image if it is enabled and the frequency matches the step_counter
if st.session_state['defaults'].txt2vid.update_preview:
step_counter += 1
if st.session_state['defaults'].txt2vid.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps:
if st.session_state.dynamic_preview_frequency:
st.session_state["current_chunk_speed"],
st.session_state["previous_chunk_speed_list"],
st.session_state['defaults'].txt2vid.update_preview_frequency,
st.session_state["avg_update_preview_frequency"] = optimize_update_preview_frequency(st.session_state["current_chunk_speed"],
st.session_state["previous_chunk_speed_list"],
st.session_state['defaults'].txt2vid.update_preview_frequency,
st.session_state["update_preview_frequency_list"])
#update the preview image if it is enabled and the frequency matches the step_counter
if st.session_state["update_preview"]:
step_counter += 1
#scale and decode the image latents with vae
cond_latents_2 = 1 / 0.18215 * cond_latents
image = pipe.vae.decode(cond_latents_2)
if st.session_state["update_preview_frequency"] == step_counter or step_counter == st.session_state.sampling_steps:
if st.session_state.dynamic_preview_frequency:
st.session_state["current_chunk_speed"],
st.session_state["previous_chunk_speed_list"],
st.session_state["update_preview_frequency"],
st.session_state["avg_update_preview_frequency"] = optimize_update_preview_frequency(st.session_state["current_chunk_speed"],
st.session_state["previous_chunk_speed_list"],
st.session_state["update_preview_frequency"],
st.session_state["update_preview_frequency_list"])
# generate output numpy image as uint8
image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
image2 = transforms.ToPILImage()(image.squeeze_(0))
#scale and decode the image latents with vae
cond_latents_2 = 1 / 0.18215 * cond_latents
image = pipe.vae.decode(cond_latents_2)
st.session_state["preview_image"].image(image2)
# generate output numpy image as uint8
image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
image2 = transforms.ToPILImage()(image.squeeze_(0))
step_counter = 0
st.session_state["preview_image"].image(image2)
duration = timeit.default_timer() - start
step_counter = 0
st.session_state["current_chunk_speed"] = duration
duration = timeit.default_timer() - start
if duration >= 1:
speed = "s/it"
else:
speed = "it/s"
duration = 1 / duration
st.session_state["current_chunk_speed"] = duration
if i > st.session_state.sampling_steps:
inference_counter += 1
inference_percent = int(100 * float(inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps))
inference_progress = f"{inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% "
else:
inference_progress = ""
if duration >= 1:
speed = "s/it"
else:
speed = "it/s"
duration = 1 / duration
percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(st.session_state.max_frames))
if i > st.session_state.sampling_steps:
inference_counter += 1
inference_percent = int(100 * float(inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps))
inference_progress = f"{inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% "
else:
inference_progress = ""
st.session_state["progress_bar_text"].text(
f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} "
f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | "
f"Frame: {st.session_state.current_frame + 1 if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} "
f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}"
)
st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(
st.session_state.max_frames))
st.session_state["progress_bar_text"].text(
f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} "
f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | "
f"Frame: {st.session_state.current_frame + 1 if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} "
f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}"
)
st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
except KeyError:
raise StopException
#scale and decode the image latents with vae
cond_latents_2 = 1 / 0.18215 * cond_latents
@ -262,7 +267,23 @@ def load_diffusers_model(weights_path,torch_device):
"You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."
)
raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")
#
def save_video_to_disk(frames, seeds, sanitized_prompt, fps=6,save_video=True, outdir='outputs'):
    """Write ``frames`` to an mp4 file and return the path of that file.

    The video is written to ``<cwd>/<outdir>/txt2vid/<seeds>_<sanitized_prompt>.mp4``.

    Args:
        frames: Iterable of frames, each passed unchanged to the imageio
            writer's ``append_data`` (the caller in this file appends
            ``np.asarray(...)`` images to this list).
        seeds: Value interpolated into the file name.
        sanitized_prompt: Filesystem-safe prompt text interpolated into the
            file name.
        fps: Frame rate handed to ``imageio.get_writer``.
        save_video: When false nothing is written and ``None`` is returned.
        outdir: Output directory, joined onto the current working directory.

    Returns:
        The path of the written video, or ``None`` when ``save_video`` is
        false.
    """
    # Return explicitly so callers can rely on a falsy result instead of
    # hitting an unbound ``video_path`` when saving is disabled.
    if not save_video:
        return None

    video_dir = os.path.join(os.getcwd(), outdir, "txt2vid")
    # The writer does not create missing parent directories, and nothing else
    # guarantees this folder exists for the ``outdir`` given here.
    os.makedirs(video_dir, exist_ok=True)
    video_path = os.path.join(video_dir, f"{seeds}_{sanitized_prompt}.mp4")

    writer = imageio.get_writer(video_path, fps=fps)
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # Always release the encoder/file handle, even if a frame is rejected.
        writer.close()

    return video_path
#
def txt2vid(
# --------------------------------------
@ -275,6 +296,9 @@ def txt2vid(
max_frames:int = 10000, # number of frames to write and then exit the script
num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images
cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good
save_video = True,
save_video_on_stop = False,
outdir='outputs',
do_loop = False,
use_lerp_for_text = False,
seeds = None,
@ -332,11 +356,11 @@ def txt2vid(
# init the output dir
sanitized_prompt = slugify(prompts)
full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt)
full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid", "samples", sanitized_prompt)
if len(full_path) > 220:
sanitized_prompt = sanitized_prompt[:220-len(full_path)]
full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt)
full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid", "samples", sanitized_prompt)
os.makedirs(full_path, exist_ok=True)
@ -512,11 +536,12 @@ def txt2vid(
#append the frames to the frames list so we can use them later.
frames.append(np.asarray(gfpgan_image))
st.session_state["preview_image"].image(gfpgan_image)
#except AttributeError:
try:
st.session_state["preview_image"].image(gfpgan_image)
except KeyError:
print ("Cant get session_state, skipping image preview.")
#except (AttributeError, KeyError):
#print("Cant perform GFPGAN, skipping.")
#pass
#increase frame_index counter.
frame_index += 1
@ -536,23 +561,18 @@ def txt2vid(
init1 = init2
# save the video after the generation is done.
video_path = save_video_to_disk(frames, seeds, sanitized_prompt, save_video=save_video, outdir=outdir)
except StopException:
pass
if save_video_on_stop:
print ("Streamlit Stop Exception Received. Saving video")
video_path = save_video_to_disk(frames, seeds, sanitized_prompt, save_video=save_video, outdir=outdir)
else:
video_path = None
if st.session_state['save_video']:
# write video to memory
#output = io.BytesIO()
#writer = imageio.get_writer(os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples"), im, extension=".mp4", fps=30)
try:
video_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples",f"{seeds}_{sanitized_prompt}.mp4")
writer = imageio.get_writer(video_path, fps=6)
for frame in frames:
writer.append_data(frame)
writer.close()
except:
print("Can't save video, skipping.")
if video_path and "preview_video" in st.session_state:
# show video preview on the UI
st.session_state["preview_video"].video(open(video_path, 'rb').read())
@ -620,6 +640,11 @@ def layout():
help="Frequency in steps at which the the preview image is updated. By default the frequency \
is set to 1 step.")
st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=st.session_state['defaults'].txt2vid.dynamic_preview_frequency,
help="This option tries to find the best value at which we can update \
the preview image during generation while minimizing the impact it has in performance. Default: True")
#
@ -644,6 +669,7 @@ def layout():
#generate_video = st.empty()
st.session_state["preview_video"] = st.empty()
preview_video = st.session_state["preview_video"]
message = st.empty()
@ -702,19 +728,23 @@ def layout():
help="Separate multiple prompts using the `|` character, and get all combinations of them.")
st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.",
value=st.session_state['defaults'].txt2vid.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0")
st.session_state["save_individual_images"] = st.checkbox("Save individual images.",
value=st.session_state['defaults'].txt2vid.save_individual_images,
help="Save each image generated before any filter or enhancement is applied.")
st.session_state["save_video"] = st.checkbox("Save video",value=st.session_state['defaults'].txt2vid.save_video,
help="Save a video with all the images generated as frames at the end of the generation.")
save_video_on_stop = st.checkbox("Save video on Stop",value=st.session_state['defaults'].txt2vid.save_video_on_stop,
help="Save a video with all the images generated as frames when we hit the stop button during a generation.")
st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt,
help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
help="Saves all the images with the same prompt into the same folder. When using a prompt \
matrix each prompt combination will have its own folder.")
st.session_state["write_info_files"] = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2vid.write_info_files,
help="Save a file next to the image with informartion about the generation.")
st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=st.session_state['defaults'].txt2vid.dynamic_preview_frequency,
help="This option tries to find the best value at which we can update \
the preview image during generation while minimizing the impact it has in performance. Default: True")
st.session_state["do_loop"] = st.checkbox("Do Loop", value=st.session_state['defaults'].txt2vid.do_loop,
help="Do loop")
st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
@ -830,7 +860,7 @@ def layout():
#load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
if st.session_state["use_GFPGAN"]:
if "GFPGAN" in st.session_state:
if "GFPGAN" in server_state:
print("GFPGAN already loaded")
else:
with col2:
@ -838,28 +868,35 @@ def layout():
# Load GFPGAN
if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
try:
server_state["GFPGAN"] = load_GFPGAN()
load_GFPGAN()
print("Loaded GFPGAN")
except Exception:
import traceback
print("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
else:
if "GFPGAN" in st.session_state:
if "GFPGAN" in server_state:
del server_state["GFPGAN"]
#try:
# run video generation
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
num_steps=st.session_state.sampling_steps, max_frames=st.session_state.max_frames,
num_inference_steps=st.session_state.num_inference_steps,
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
cfg_scale=cfg_scale, save_video_on_stop=save_video_on_stop,
outdir=st.session_state["defaults"].general.outdir,
do_loop=st.session_state["do_loop"],
seeds=seed, quality=100, eta=0.0, width=width,
height=height, weights_path=custom_model, scheduler=scheduler_name,
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
beta_schedule=beta_scheduler_type, starting_image=None)
if video and save_video_on_stop:
# show video preview on the UI after we hit the stop button
# currently not working as session_state is cleared on StopException
preview_video.video(open(video, 'rb').read())
#message.success('Done!', icon="✅")
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="")