GFPGAN can now be used on the txt2vid tab.

This commit is contained in:
ZeroCool940711 2022-09-18 15:17:02 -07:00
parent 300e6865d6
commit a7d7955721

View File

@ -151,9 +151,9 @@ def diffuse(
# generate output numpy image as uint8
image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
image = transforms.ToPILImage()(image.squeeze_(0))
image2 = transforms.ToPILImage()(image.squeeze_(0))
st.session_state["preview_image"].image(image)
st.session_state["preview_image"].image(image2)
step_counter = 0
@ -185,7 +185,7 @@ def diffuse(
)
st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
return image
return image2
#
def txt2vid(
@ -376,6 +376,13 @@ def txt2vid(
st.session_state["pipe"].scheduler = SCHEDULERS[scheduler]
if do_loop:
prompts = str([prompts, prompts])
seeds = [seeds, seeds]
#first_seed, *seeds = seeds
#prompts.append(prompts)
#seeds.append(first_seed)
# get the conditional text embeddings based on the prompt
text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt")
cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768]
@ -414,13 +421,6 @@ def txt2vid(
# sample a source
init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
if do_loop:
prompts = [prompts, prompts]
seeds = [seeds, seeds]
#first_seed, *seeds = seeds
#prompts.append(prompts)
#seeds.append(first_seed)
# iterate the loop
frames = []
@ -452,15 +452,40 @@ def txt2vid(
with autocast("cuda"):
image = diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
#im = Image.fromarray(image)
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
image.save(outpath, quality=quality)
if st.session_state["save_individual_images"] and not st.session_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
#im = Image.fromarray(image)
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
image.save(outpath, quality=quality)
# send the image to the UI to update it
#st.session_state["preview_image"].image(im)
# send the image to the UI to update it
#st.session_state["preview_image"].image(im)
#append the frames to the frames list so we can use them later.
frames.append(np.asarray(image))
#append the frames to the frames list so we can use them later.
frames.append(np.asarray(image))
#
#try:
#if st.session_state["use_GFPGAN"] and st.session_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
if st.session_state["use_GFPGAN"] and st.session_state["GFPGAN"] is not None:
#print("Running GFPGAN on image ...")
st.session_state["progress_bar_text"].text("Running GFPGAN on image ...")
#skip_save = True # #287 >_>
torch_gc()
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(np.array(image)[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample)
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
gfpgan_image.save(outpath, quality=quality)
#append the frames to the frames list so we can use them later.
frames.append(np.asarray(gfpgan_image))
st.session_state["preview_image"].image(gfpgan_image)
#except AttributeError:
#print("Cant perform GFPGAN, skipping.")
#pass
#increase frame_index counter.
frame_index += 1
@ -661,7 +686,7 @@ def layout():
st.session_state.num_inference_steps = st.slider("Inference Steps:", value=st.session_state['defaults'].txt2vid.num_inference_steps.value,
min_value=st.session_state['defaults'].txt2vid.num_inference_steps.min_value,
step=st.session_state['defaults'].txt2vid.num_inference_steps.max_value,
step=st.session_state['defaults'].txt2vid.num_inference_steps.step,
max_value=st.session_state['defaults'].txt2vid.num_inference_steps.max_value,
help="Higher values (e.g. 100, 200 etc) can create better images.")
@ -731,65 +756,82 @@ def layout():
if generate_button:
#print("Loading models")
# Load the models when we hit the generate button for the first time; they won't be loaded again after that, so don't worry.
#load_models(False, False, False, st.session_state["RealESRGAN_model"], CustomModel_available=st.session_state["CustomModel_available"], custom_model=custom_model)
#load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
try:
# run video generation
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
num_inference_steps=st.session_state.num_inference_steps,
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
seeds=seed, quality=100, eta=0.0, width=width,
height=height, weights_path=custom_model, scheduler=scheduler_name,
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
beta_schedule=beta_scheduler_type, starting_image=None)
if st.session_state["use_GFPGAN"]:
if "GFPGAN" in st.session_state:
print("GFPGAN already loaded")
else:
# Load GFPGAN
if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
try:
st.session_state["GFPGAN"] = load_GFPGAN()
print("Loaded GFPGAN")
except Exception:
import traceback
print("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
else:
if "GFPGAN" in st.session_state:
del st.session_state["GFPGAN"]
#message.success('Done!', icon="✅")
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="")
#try:
# run video generation
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
num_inference_steps=st.session_state.num_inference_steps,
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
seeds=seed, quality=100, eta=0.0, width=width,
height=height, weights_path=custom_model, scheduler=scheduler_name,
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
beta_schedule=beta_scheduler_type, starting_image=None)
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
#message.success('Done!', icon="✅")
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="")
#if 'latestVideos' in st.session_state:
#for i in video:
##push the new image to the list of latest images and remove the oldest one
##remove the last index from the list\
#st.session_state['latestVideos'].pop()
##add the new image to the start of the list
#st.session_state['latestVideos'].insert(0, i)
#PlaceHolder.empty()
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
#with PlaceHolder.container():
#col1, col2, col3 = st.columns(3)
#col1_cont = st.container()
#col2_cont = st.container()
#col3_cont = st.container()
#if 'latestVideos' in st.session_state:
#for i in video:
##push the new image to the list of latest images and remove the oldest one
##remove the last index from the list\
#st.session_state['latestVideos'].pop()
##add the new image to the start of the list
#st.session_state['latestVideos'].insert(0, i)
#PlaceHolder.empty()
#with col1_cont:
#with col1:
#st.image(st.session_state['latestVideos'][0])
#st.image(st.session_state['latestVideos'][3])
#st.image(st.session_state['latestVideos'][6])
#with col2_cont:
#with col2:
#st.image(st.session_state['latestVideos'][1])
#st.image(st.session_state['latestVideos'][4])
#st.image(st.session_state['latestVideos'][7])
#with col3_cont:
#with col3:
#st.image(st.session_state['latestVideos'][2])
#st.image(st.session_state['latestVideos'][5])
#st.image(st.session_state['latestVideos'][8])
#historyGallery = st.empty()
#with PlaceHolder.container():
#col1, col2, col3 = st.columns(3)
#col1_cont = st.container()
#col2_cont = st.container()
#col3_cont = st.container()
## check if output_images length is the same as seeds length
#with gallery_tab:
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
#with col1_cont:
#with col1:
#st.image(st.session_state['latestVideos'][0])
#st.image(st.session_state['latestVideos'][3])
#st.image(st.session_state['latestVideos'][6])
#with col2_cont:
#with col2:
#st.image(st.session_state['latestVideos'][1])
#st.image(st.session_state['latestVideos'][4])
#st.image(st.session_state['latestVideos'][7])
#with col3_cont:
#with col3:
#st.image(st.session_state['latestVideos'][2])
#st.image(st.session_state['latestVideos'][5])
#st.image(st.session_state['latestVideos'][8])
#historyGallery = st.empty()
## check if output_images length is the same as seeds length
#with gallery_tab:
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
except (StopException, KeyError):
print(f"Received Streamlit StopException")
#except (StopException, KeyError):
#print(f"Received Streamlit StopException")