Merge pull request #1132 from ZeroCool940711/dev

Fixed max_frames not being used properly; sampling_steps was the variable being used instead.
This commit is contained in:
ZeroCool 2022-09-14 12:02:49 -07:00 committed by GitHub
commit 8bc8b006fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 80 deletions

View File

@ -28,6 +28,7 @@ general:
n_rows: -1
no_verify_input: False
no_half: False
use_float16: False
precision: "autocast"
optimized: False
optimized_turbo: False

View File

@ -882,21 +882,26 @@ def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995):
return v2
#
def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list):
	"""Find the optimal update_preview_frequency value maximizing
	performance while minimizing the time between updates.

	Compares the latest chunk timing against the running average of
	previous chunk timings and nudges the preview interval up when
	generation is slowing down (or down when it is speeding up), then
	smooths the result by averaging the interval history.

	Args:
		current_chunk_speed: timing of the latest chunk (the caller feeds
			in a duration in seconds — larger means slower).
		previous_chunk_speed_list: mutable history of chunk timings;
			``current_chunk_speed`` is appended in place.
		update_preview_frequency: current preview interval in steps.
		update_preview_frequency_list: mutable history of candidate
			intervals; a new candidate is appended in place.

	Returns:
		Tuple of (current_chunk_speed, previous_chunk_speed_list,
		smoothed update_preview_frequency, update_preview_frequency_list).
	"""
	from statistics import mean

	previous_chunk_avg_speed = mean(previous_chunk_speed_list)
	previous_chunk_speed_list.append(current_chunk_speed)
	current_chunk_avg_speed = mean(previous_chunk_speed_list)

	if current_chunk_avg_speed >= previous_chunk_avg_speed:
		# Average timing went up (chunks are taking longer): widen the
		# interval so previews are refreshed less often.
		#print(f"{current_chunk_speed} >= {previous_chunk_speed}")
		update_preview_frequency_list.append(update_preview_frequency + 1)
	else:
		# Average timing went down (chunks are faster): narrow the
		# interval so previews refresh more often.
		#print(f"{current_chunk_speed} <= {previous_chunk_speed}")
		update_preview_frequency_list.append(update_preview_frequency - 1)

	# Smooth over the whole history to avoid oscillating every chunk.
	update_preview_frequency = round(mean(update_preview_frequency_list))

	return current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list
def get_font(fontsize):

View File

@ -179,9 +179,9 @@ def layout():
with st.expander("Preview Settings"):
st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2img.update_preview,
help="If enabled the image preview will be updated during the generation instead of at the end. \
You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
By default this is enabled and the frequency is set to 1 step.")
help="If enabled the image preview will be updated during the generation instead of at the end. \
You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
By default this is enabled and the frequency is set to 1 step.")
st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2img.update_preview_frequency,
help="Frequency in steps at which the the preview image is updated. By default the frequency \

View File

@ -49,7 +49,7 @@ if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir,
else:
GFPGAN_available = False
if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")):
if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].txt2vid.RealESRGAN_model}.pth")):
RealESRGAN_available = True
else:
RealESRGAN_available = False
@ -98,8 +98,18 @@ def diffuse(
step_counter = 0
inference_counter = 0
current_chunk_speed = 0
previous_chunk_speed = 0
if "current_chunk_speed" not in st.session_state:
st.session_state["current_chunk_speed"] = 0
if "previous_chunk_speed_list" not in st.session_state:
st.session_state["previous_chunk_speed_list"] = [0]
st.session_state["previous_chunk_speed_list"].append(st.session_state["current_chunk_speed"])
if "update_preview_frequency_list" not in st.session_state:
st.session_state["update_preview_frequency_list"] = [0]
st.session_state["update_preview_frequency_list"].append(st.session_state['defaults'].txt2vid.update_preview_frequency)
# diffuse!
for i, t in enumerate(pipe.scheduler.timesteps):
@ -128,14 +138,13 @@ def diffuse(
#print (st.session_state["update_preview_frequency"])
#update the preview image if it is enabled and the frequency matches the step_counter
if st.session_state['defaults'].general.update_preview:
if st.session_state['defaults'].txt2vid.update_preview:
step_counter += 1
if st.session_state.dynamic_preview_frequency:
current_chunk_speed, previous_chunk_speed, st.session_state['defaults'].general.update_preview_frequency = optimize_update_preview_frequency(
current_chunk_speed, previous_chunk_speed, st.session_state['defaults'].general.update_preview_frequency)
if st.session_state['defaults'].general.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps:
if st.session_state['defaults'].txt2vid.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps:
if st.session_state.dynamic_preview_frequency:
st.session_state["current_chunk_speed"], st.session_state["previous_chunk_speed_list"], st.session_state['defaults'].txt2vid.update_preview_frequency, st.session_state["avg_update_preview_frequency"] = optimize_update_preview_frequency(st.session_state["current_chunk_speed"], st.session_state["previous_chunk_speed_list"], st.session_state['defaults'].txt2vid.update_preview_frequency, st.session_state["update_preview_frequency_list"])
#scale and decode the image latents with vae
cond_latents_2 = 1 / 0.18215 * cond_latents
image_2 = pipe.vae.decode(cond_latents_2)
@ -151,7 +160,7 @@ def diffuse(
duration = timeit.default_timer() - start
current_chunk_speed = duration
st.session_state["current_chunk_speed"] = duration
if duration >= 1:
speed = "s/it"
@ -161,8 +170,8 @@ def diffuse(
if i > st.session_state.sampling_steps:
inference_counter += 1
inference_percent = int(100 * float(inference_counter if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps))
inference_progress = f"{inference_counter if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% "
inference_percent = int(100 * float(inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps))
inference_progress = f"{inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% "
else:
inference_progress = ""
@ -172,7 +181,7 @@ def diffuse(
st.session_state["progress_bar_text"].text(
f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} "
f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | "
f"Frame: {st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} "
f"Frame: {st.session_state.current_frame + 1 if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} "
f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}"
)
st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
@ -246,7 +255,7 @@ def txt2vid(
# We add an extra frame because most
# of the time the first frame is just the noise.
max_frames +=1
#max_frames +=1
assert torch.cuda.is_available()
assert height % 8 == 0 and width % 8 == 0
@ -329,7 +338,7 @@ def txt2vid(
#print (st.session_state["weights_path"] != weights_path)
try:
if not st.session_state["pipe"] or st.session_state["weights_path"] != weights_path:
if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
if st.session_state["weights_path"] != weights_path:
del st.session_state["weights_path"]
@ -338,7 +347,7 @@ def txt2vid(
weights_path,
use_local_file=True,
use_auth_token=True,
#torch_dtype=torch.float16 if not st.session_state['defaults'].general.no_half else None,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None
)
@ -358,7 +367,7 @@ def txt2vid(
weights_path,
use_local_file=True,
use_auth_token=True,
#torch_dtype=torch.float16 if not st.session_state['defaults'].general.no_half else None,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None
)
@ -388,8 +397,8 @@ def txt2vid(
frames = []
frame_index = 0
st.session_state["frame_total_duration"] = 0
st.session_state["frame_total_speed"] = 0
st.session_state["total_frames_avg_duration"] = []
st.session_state["total_frames_avg_speed"] = []
try:
while frame_index < max_frames:
@ -400,9 +409,9 @@ def txt2vid(
# sample the destination
init2 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
for i, t in enumerate(np.linspace(0, 1, num_steps)):
for i, t in enumerate(np.linspace(0, 1, max_frames)):
start = timeit.default_timer()
print(f"COUNT: {frame_index+1}/{num_steps}")
print(f"COUNT: {frame_index+1}/{max_frames}")
#if use_lerp_for_text:
#init = torch.lerp(init1, init2, float(t))
@ -474,6 +483,52 @@ def txt2vid(
return im, seeds, info, stats
#on import run init
def createHTMLGallery(images, info):
	"""Build an HTML thumbnail gallery for the given images.

	Each image is registered with Streamlit's in-memory file manager and
	rendered as a clickable thumbnail, scaled so neither side exceeds
	150px (aspect ratio preserved), with its seed from ``info`` as a
	caption.

	Args:
		images: list of PIL images to display.
		info: sequence of seeds/captions aligned by position with
			``images``; missing entries fall back to a blank caption.

	Returns:
		str: HTML markup for the gallery.
	"""
	html3 = """
<div class="gallery-history" style="
display: flex;
flex-wrap: wrap;
align-items: flex-start;">
"""
	mkdwn_array = []
	# enumerate() gives each image its own index; the original
	# images.index(i) returned the FIRST matching position, so duplicate
	# images shared one seed and one file-manager id.
	for image_index, image in enumerate(images):
		try:
			seed = info[image_index]
		except (IndexError, KeyError, TypeError):
			# Best-effort caption: fall back to a blank when info is
			# missing or shorter than images.
			seed = ' '
		image_io = BytesIO()
		image.save(image_io, 'PNG')
		width, height = image.size
		# Stable id derived from the image's position in the gallery.
		image_id = "%s" % (str(image_index))
		(data, mimetype) = STImage._normalize_to_bytes(image_io.getvalue(), width, 'auto')
		this_file = in_memory_file_manager.add(data, mimetype, image_id)
		img_str = this_file.url
		#img_str = 'data:image/png;base64,' + b64encode(image_io.getvalue()).decode('ascii')
		# Clamp the thumbnail to at most 150px per side, keeping the
		# aspect ratio.
		if width > 150:
			height = int(height * (150 / width))
			width = 150
		if height > 150:
			width = int(width * (150 / height))
			height = 150
		mkdwn = f'''<div class="gallery" style="margin: 3px;" >
<a href="{img_str}">
<img src="{img_str}" alt="Image" width="{width}" height="{height}">
</a>
<div class="desc" style="text-align: center; opacity: 40%;">{seed}</div>
</div>
'''
		mkdwn_array.append(mkdwn)
	html3 += "".join(mkdwn_array)
	html3 += '</div>'
	return html3
#
def layout():
with st.form("txt2vid-inputs"):
@ -513,7 +568,7 @@ def layout():
st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2vid.update_preview_frequency,
help="Frequency in steps at which the the preview image is updated. By default the frequency \
is set to 1 step.")
is set to 1 step.")
with col2:
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
@ -642,49 +697,3 @@ def layout():
#preview_image.image(output_images)
#on import run init
def createHTMLGallery(images, info):
	"""Render the given images as an HTML thumbnail gallery.

	Every image is stored via Streamlit's in-memory file manager and
	shown as a clickable thumbnail scaled to at most 150px per side,
	with its seed from ``info`` displayed underneath.
	"""
	gallery_parts = []
	for img in images:
		try:
			seed = info[images.index(img)]
		except:
			seed = ' '
		buffer = BytesIO()
		img.save(buffer, 'PNG')
		width, height = img.size
		# Id for the in-memory file manager entry.
		image_id = "%s" % (str(images.index(img)))
		(data, mimetype) = STImage._normalize_to_bytes(buffer.getvalue(), width, 'auto')
		entry = in_memory_file_manager.add(data, mimetype, image_id)
		img_str = entry.url
		#img_str = 'data:image/png;base64,' + b64encode(buffer.getvalue()).decode('ascii')
		# Scale down so neither side exceeds 150px, preserving the
		# aspect ratio.
		if width > 150:
			height = int(height * (150 / width))
			width = 150
		if height > 150:
			width = int(width * (150 / height))
			height = 150
		gallery_parts.append(f'''<div class="gallery" style="margin: 3px;" >
<a href="{img_str}">
<img src="{img_str}" alt="Image" width="{width}" height="{height}">
</a>
<div class="desc" style="text-align: center; opacity: 40%;">{seed}</div>
</div>
''')
	html3 = """
<div class="gallery-history" style="
display: flex;
flex-wrap: wrap;
align-items: flex-start;">
"""
	html3 += "".join(gallery_parts)
	html3 += '</div>'
	return html3