Mirror of https://github.com/sd-webui/stable-diffusion-webui.git (synced 2024-12-14 14:52:31 +03:00)
- Added the ability to save the in-progress video when the Stop button is pressed mid-generation.
- Fixed GFPGAN not working on txt2vid.
This commit is contained in: parent 97bbf089ea, commit 63b2ff22c6
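For context before the diff: the save-on-stop flow works by running the diffusion loop inside a try block, converting the KeyError that occurs when Streamlit clears st.session_state on Stop into StopException, and having the handler write the frames collected so far to disk. Below is a minimal, self-contained sketch of that flow, assuming imageio with its ffmpeg plugin is installed. generate, stop_after, and the placeholder frames are illustrative stand-ins only; save_video_to_disk is a slightly hardened variant of the helper added in the diff (it pre-binds video_path and creates the output directory).

import os
import numpy as np
import imageio


class StopException(Exception):
    """Stands in for the exception raised when the user presses Stop."""


def save_video_to_disk(frames, seeds, sanitized_prompt, fps=6, save_video=True, outdir="outputs"):
    video_path = None
    if save_video and frames:
        os.makedirs(os.path.join(outdir, "txt2vid"), exist_ok=True)
        video_path = os.path.join(outdir, "txt2vid", f"{seeds}_{sanitized_prompt}.mp4")
        writer = imageio.get_writer(video_path, fps=fps)
        for frame in frames:  # each frame: HxWx3 uint8 numpy array
            writer.append_data(frame)
        writer.close()
    return video_path


def generate(max_frames=10, stop_after=4, save_video_on_stop=True):
    frames = []
    try:
        for i in range(max_frames):
            frames.append(np.zeros((64, 64, 3), dtype=np.uint8))  # placeholder frame
            if i == stop_after:
                raise StopException  # simulate the Stop button mid-generation
        return save_video_to_disk(frames, seeds=42, sanitized_prompt="demo")
    except StopException:
        # The frames gathered so far are still in `frames`, so they can be
        # written out instead of being discarded with the partial run.
        if save_video_on_stop:
            return save_video_to_disk(frames, seeds=42, sanitized_prompt="demo")
        return None


print(generate())  # e.g. outputs/txt2vid/42_demo.mp4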
@@ -113,88 +113,93 @@ def diffuse(
     if "update_preview_frequency_list" not in st.session_state:
         st.session_state["update_preview_frequency_list"] = [0]
-        st.session_state["update_preview_frequency_list"].append(st.session_state['defaults'].txt2vid.update_preview_frequency)
+        st.session_state["update_preview_frequency_list"].append(st.session_state["update_preview_frequency"])

-    # diffuse!
-    for i, t in enumerate(pipe.scheduler.timesteps):
-        start = timeit.default_timer()
+    try:
+        # diffuse!
+        for i, t in enumerate(pipe.scheduler.timesteps):
+            start = timeit.default_timer()

-        #status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}")
+            #status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}")

-        # expand the latents for classifier free guidance
-        latent_model_input = torch.cat([cond_latents] * 2)
-        if isinstance(pipe.scheduler, LMSDiscreteScheduler):
-            sigma = pipe.scheduler.sigmas[i]
-            latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            # expand the latents for classifier free guidance
+            latent_model_input = torch.cat([cond_latents] * 2)
+            if isinstance(pipe.scheduler, LMSDiscreteScheduler):
+                sigma = pipe.scheduler.sigmas[i]
+                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

-        # predict the noise residual
-        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+            # predict the noise residual
+            noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

-        # cfg
-        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-        noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
+            # cfg
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)

-        # compute the previous noisy sample x_t -> x_t-1
-        if isinstance(pipe.scheduler, LMSDiscreteScheduler):
-            cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"]
-        else:
-            cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"]
+            # compute the previous noisy sample x_t -> x_t-1
+            if isinstance(pipe.scheduler, LMSDiscreteScheduler):
+                cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"]
+            else:
+                cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"]

-        #print (st.session_state["update_preview_frequency"])
-        #update the preview image if it is enabled and the frequency matches the step_counter
-        if st.session_state['defaults'].txt2vid.update_preview:
-            step_counter += 1
-
-            if st.session_state['defaults'].txt2vid.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps:
-                if st.session_state.dynamic_preview_frequency:
-                    st.session_state["current_chunk_speed"],
-                    st.session_state["previous_chunk_speed_list"],
-                    st.session_state['defaults'].txt2vid.update_preview_frequency,
-                    st.session_state["avg_update_preview_frequency"] = optimize_update_preview_frequency(st.session_state["current_chunk_speed"],
-                                                                                                         st.session_state["previous_chunk_speed_list"],
-                                                                                                         st.session_state['defaults'].txt2vid.update_preview_frequency,
-                                                                                                         st.session_state["update_preview_frequency_list"])
-
-                #scale and decode the image latents with vae
-                cond_latents_2 = 1 / 0.18215 * cond_latents
-                image = pipe.vae.decode(cond_latents_2)
-
-                # generate output numpy image as uint8
-                image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
-                image2 = transforms.ToPILImage()(image.squeeze_(0))
-
-                st.session_state["preview_image"].image(image2)
-
-                step_counter = 0
+            #update the preview image if it is enabled and the frequency matches the step_counter
+            if st.session_state["update_preview"]:
+                step_counter += 1
+
+                if st.session_state["update_preview_frequency"] == step_counter or step_counter == st.session_state.sampling_steps:
+                    if st.session_state.dynamic_preview_frequency:
+                        st.session_state["current_chunk_speed"],
+                        st.session_state["previous_chunk_speed_list"],
+                        st.session_state["update_preview_frequency"],
+                        st.session_state["avg_update_preview_frequency"] = optimize_update_preview_frequency(st.session_state["current_chunk_speed"],
+                                                                                                             st.session_state["previous_chunk_speed_list"],
+                                                                                                             st.session_state["update_preview_frequency"],
+                                                                                                             st.session_state["update_preview_frequency_list"])
+
+                    #scale and decode the image latents with vae
+                    cond_latents_2 = 1 / 0.18215 * cond_latents
+                    image = pipe.vae.decode(cond_latents_2)
+
+                    # generate output numpy image as uint8
+                    image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
+                    image2 = transforms.ToPILImage()(image.squeeze_(0))
+
+                    st.session_state["preview_image"].image(image2)
+
+                    step_counter = 0

-        duration = timeit.default_timer() - start
-        st.session_state["current_chunk_speed"] = duration
+            duration = timeit.default_timer() - start
+            st.session_state["current_chunk_speed"] = duration

-        if duration >= 1:
-            speed = "s/it"
-        else:
-            speed = "it/s"
-            duration = 1 / duration
+            if duration >= 1:
+                speed = "s/it"
+            else:
+                speed = "it/s"
+                duration = 1 / duration

-        if i > st.session_state.sampling_steps:
-            inference_counter += 1
-            inference_percent = int(100 * float(inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps))
-            inference_progress = f"{inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% "
-        else:
-            inference_progress = ""
+            if i > st.session_state.sampling_steps:
+                inference_counter += 1
+                inference_percent = int(100 * float(inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps))
+                inference_progress = f"{inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% "
+            else:
+                inference_progress = ""

-        percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
-        frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(st.session_state.max_frames))
+            percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
+            frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(
+                st.session_state.max_frames))

-        st.session_state["progress_bar_text"].text(
-            f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} "
-            f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | "
-            f"Frame: {st.session_state.current_frame + 1 if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} "
-            f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}"
-        )
-        st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
+            st.session_state["progress_bar_text"].text(
+                f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} "
+                f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | "
+                f"Frame: {st.session_state.current_frame + 1 if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} "
+                f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}"
+            )
+            st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
+
+    except KeyError:
+        raise StopException

     #scale and decode the image latents with vae
     cond_latents_2 = 1 / 0.18215 * cond_latents
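The dynamic-preview branch above unpacks a 4-tuple from optimize_update_preview_frequency, whose body is not part of this diff. The following is a hypothetical sketch only, inferred from the call site (current step duration, recent-duration history, current frequency, frequency history in; updated values out); the balancing heuristic shown is an assumption, not the project's actual implementation.

# Hypothetical sketch: preview less often when steps run slower than the
# recent average (previews cost time), more often when they run faster.
def optimize_update_preview_frequency(current_chunk_speed,
                                      previous_chunk_speed_list,
                                      update_preview_frequency,
                                      update_preview_frequency_list):
    previous_chunk_speed_list.append(current_chunk_speed)
    avg_speed = sum(previous_chunk_speed_list) / len(previous_chunk_speed_list)

    if current_chunk_speed > avg_speed:
        update_preview_frequency += 1          # slower than average: back off
    elif update_preview_frequency > 1:
        update_preview_frequency -= 1          # faster than average: preview more

    update_preview_frequency_list.append(update_preview_frequency)
    avg_update_preview_frequency = int(
        sum(update_preview_frequency_list) / len(update_preview_frequency_list))
    return (current_chunk_speed, previous_chunk_speed_list,
            update_preview_frequency, avg_update_preview_frequency)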
@@ -262,7 +267,23 @@ def load_diffusers_model(weights_path,torch_device):
             "You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."
         )
         raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")

+#
+def save_video_to_disk(frames, seeds, sanitized_prompt, fps=6,save_video=True, outdir='outputs'):
+    if save_video:
+        # write video to memory
+        #output = io.BytesIO()
+        #writer = imageio.get_writer(os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid"), im, extension=".mp4", fps=30)
+        #try:
+        video_path = os.path.join(os.getcwd(), outdir, "txt2vid",f"{seeds}_{sanitized_prompt}.mp4")
+        writer = imageio.get_writer(video_path, fps=fps)
+        for frame in frames:
+            writer.append_data(frame)
+
+        writer.close()
+        #except:
+        #    print("Can't save video, skipping.")
+
+    return video_path
+
 #
 def txt2vid(
 # --------------------------------------
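A hedged usage sketch for the save_video_to_disk helper added above; frames would normally be the numpy frames accumulated by the generation loop, and the outputs/txt2vid directory is assumed to already exist (the webui creates its output folders elsewhere). Note that, as committed, video_path is bound only inside the if save_video: branch, so a call with save_video=False would hit the return with an unbound local and raise UnboundLocalError.

import numpy as np

# Twelve black placeholder frames standing in for generated images.
frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(12)]
path = save_video_to_disk(frames, seeds=12345, sanitized_prompt="a-city-at-night",
                          fps=6, save_video=True, outdir="outputs")
print(path)  # outputs/txt2vid/12345_a-city-at-night.mp4, joined onto os.getcwd()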
@@ -275,6 +296,9 @@ def txt2vid(
         max_frames:int = 10000, # number of frames to write and then exit the script
         num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images
         cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good
+        save_video = True,
+        save_video_on_stop = False,
+        outdir='outputs',
         do_loop = False,
         use_lerp_for_text = False,
         seeds = None,
@@ -332,11 +356,11 @@ def txt2vid(
     # init the output dir
     sanitized_prompt = slugify(prompts)

-    full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt)
+    full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid", "samples", sanitized_prompt)

     if len(full_path) > 220:
         sanitized_prompt = sanitized_prompt[:220-len(full_path)]
-        full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt)
+        full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid", "samples", sanitized_prompt)

     os.makedirs(full_path, exist_ok=True)
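The truncation in this hunk relies on Python's negative slice indices: when the joined path exceeds 220 characters, 220 - len(full_path) is negative, which trims exactly the overshoot off the end of the prompt so the rebuilt path lands at about 220 characters. A quick illustration with made-up values:

# When full_path is too long, the negative stop index removes the overshoot.
sanitized_prompt = "x" * 300
full_path = "/some/outdir/txt2vid/samples/" + sanitized_prompt  # 29 + 300 = 329 chars
sanitized_prompt = sanitized_prompt[:220 - len(full_path)]      # 220 - 329 = -109
print(len(sanitized_prompt))  # 191, so the rebuilt path is 29 + 191 = 220 chars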
@@ -512,11 +536,12 @@ def txt2vid(
                     #append the frames to the frames list so we can use them later.
                     frames.append(np.asarray(gfpgan_image))

-                    st.session_state["preview_image"].image(gfpgan_image)
-                    #except AttributeError:
+                    try:
+                        st.session_state["preview_image"].image(gfpgan_image)
+                    except KeyError:
+                        print ("Cant get session_state, skipping image preview.")
+                    #except (AttributeError, KeyError):
+                        #print("Cant perform GFPGAN, skipping.")
+                        #pass

                 #increase frame_index counter.
                 frame_index += 1
@@ -536,23 +561,18 @@ def txt2vid(

             init1 = init2

+        # save the video after the generation is done.
+        video_path = save_video_to_disk(frames, seeds, sanitized_prompt, save_video=save_video, outdir=outdir)
+
     except StopException:
-        pass
-
-    if st.session_state['save_video']:
-        # write video to memory
-        #output = io.BytesIO()
-        #writer = imageio.get_writer(os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples"), im, extension=".mp4", fps=30)
-        try:
-            video_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples",f"{seeds}_{sanitized_prompt}.mp4")
-            writer = imageio.get_writer(video_path, fps=6)
-            for frame in frames:
-                writer.append_data(frame)
-            writer.close()
-        except:
-            print("Can't save video, skipping.")
+        if save_video_on_stop:
+            print ("Streamlit Stop Exception Received. Saving video")
+            video_path = save_video_to_disk(frames, seeds, sanitized_prompt, save_video=save_video, outdir=outdir)
+        else:
+            video_path = None

     if video_path and "preview_video" in st.session_state:
         # show video preview on the UI
         st.session_state["preview_video"].video(open(video_path, 'rb').read())
@@ -620,6 +640,11 @@ def layout():
                             help="Frequency in steps at which the the preview image is updated. By default the frequency \
                             is set to 1 step.")
+
+        st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=st.session_state['defaults'].txt2vid.dynamic_preview_frequency,
+                            help="This option tries to find the best value at which we can update \
+                            the preview image during generation while minimizing the impact it has in performance. Default: True")
+

         #
@@ -644,6 +669,7 @@ def layout():

         #generate_video = st.empty()
         st.session_state["preview_video"] = st.empty()
+        preview_video = st.session_state["preview_video"]

         message = st.empty()
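The new preview_video local keeps a plain Python reference to the st.empty() placeholder alongside the session_state entry, which matters because, per the comment in the last hunk, st.session_state is cleared when a StopException fires. A minimal illustration of the placeholder pattern, assuming it runs inside a Streamlit app:

import streamlit as st

placeholder = st.empty()                      # reserve a slot in the page
st.session_state["preview_video"] = placeholder
preview_video = placeholder                   # plain local survives session_state clearing

# Later, even if st.session_state was reset, the local still works:
# preview_video.video(open("out.mp4", "rb").read())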
@@ -702,19 +728,23 @@ def layout():
                                 help="Separate multiple prompts using the `|` character, and get all combinations of them.")
             st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.",
                                 value=st.session_state['defaults'].txt2vid.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0")

             st.session_state["save_individual_images"] = st.checkbox("Save individual images.",
                                 value=st.session_state['defaults'].txt2vid.save_individual_images,
                                 help="Save each image generated before any filter or enhancement is applied.")

             st.session_state["save_video"] = st.checkbox("Save video",value=st.session_state['defaults'].txt2vid.save_video,
                                 help="Save a video with all the images generated as frames at the end of the generation.")

+            save_video_on_stop = st.checkbox("Save video on Stop",value=st.session_state['defaults'].txt2vid.save_video_on_stop,
+                                help="Save a video with all the images generated as frames when we hit the stop button during a generation.")
+
             st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt,
-                                help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
+                                help="Saves all the images with the same prompt into the same folder. When using a prompt \
+                                matrix each prompt combination will have its own folder.")

             st.session_state["write_info_files"] = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2vid.write_info_files,
                                 help="Save a file next to the image with informartion about the generation.")
-            st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=st.session_state['defaults'].txt2vid.dynamic_preview_frequency,
-                                help="This option tries to find the best value at which we can update \
-                                the preview image during generation while minimizing the impact it has in performance. Default: True")
             st.session_state["do_loop"] = st.checkbox("Do Loop", value=st.session_state['defaults'].txt2vid.do_loop,
                                 help="Do loop")
             st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
@@ -830,7 +860,7 @@ def layout():
         #load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])

         if st.session_state["use_GFPGAN"]:
-            if "GFPGAN" in st.session_state:
+            if "GFPGAN" in server_state:
                 print("GFPGAN already loaded")
             else:
                 with col2:
@@ -838,28 +868,35 @@ def layout():
                     # Load GFPGAN
                     if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
                         try:
-                            server_state["GFPGAN"] = load_GFPGAN()
+                            load_GFPGAN()
                             print("Loaded GFPGAN")
                         except Exception:
                             import traceback
                             print("Error loading GFPGAN:", file=sys.stderr)
                             print(traceback.format_exc(), file=sys.stderr)
         else:
-            if "GFPGAN" in st.session_state:
+            if "GFPGAN" in server_state:
                 del server_state["GFPGAN"]

         #try:
         # run video generation
         video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
-                                           num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
+                                           num_steps=st.session_state.sampling_steps, max_frames=st.session_state.max_frames,
                                            num_inference_steps=st.session_state.num_inference_steps,
-                                           cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
+                                           cfg_scale=cfg_scale, save_video_on_stop=save_video_on_stop,
+                                           outdir=st.session_state["defaults"].general.outdir,
+                                           do_loop=st.session_state["do_loop"],
                                            seeds=seed, quality=100, eta=0.0, width=width,
                                            height=height, weights_path=custom_model, scheduler=scheduler_name,
                                            disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
                                            beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
                                            beta_schedule=beta_scheduler_type, starting_image=None)

+        if video and save_video_on_stop:
+            # show video preview on the UI after we hit the stop button
+            # currently not working as session_state is cleared on StopException
+            preview_video.video(open(video, 'rb').read())

         #message.success('Done!', icon="✅")
         message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")