mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-15 14:31:44 +03:00
GFPGAN can now be used on the txt2vid tab.
This commit is contained in:
parent
300e6865d6
commit
a7d7955721
@ -151,9 +151,9 @@ def diffuse(
|
||||
|
||||
# generate output numpy image as uint8
|
||||
image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
image = transforms.ToPILImage()(image.squeeze_(0))
|
||||
image2 = transforms.ToPILImage()(image.squeeze_(0))
|
||||
|
||||
st.session_state["preview_image"].image(image)
|
||||
st.session_state["preview_image"].image(image2)
|
||||
|
||||
step_counter = 0
|
||||
|
||||
@ -185,7 +185,7 @@ def diffuse(
|
||||
)
|
||||
st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
|
||||
|
||||
return image
|
||||
return image2
|
||||
|
||||
#
|
||||
def txt2vid(
|
||||
@ -376,6 +376,13 @@ def txt2vid(
|
||||
|
||||
st.session_state["pipe"].scheduler = SCHEDULERS[scheduler]
|
||||
|
||||
if do_loop:
|
||||
prompts = str([prompts, prompts])
|
||||
seeds = [seeds, seeds]
|
||||
#first_seed, *seeds = seeds
|
||||
#prompts.append(prompts)
|
||||
#seeds.append(first_seed)
|
||||
|
||||
# get the conditional text embeddings based on the prompt
|
||||
text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
||||
cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768]
|
||||
@ -414,13 +421,6 @@ def txt2vid(
|
||||
# sample a source
|
||||
init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
|
||||
|
||||
if do_loop:
|
||||
prompts = [prompts, prompts]
|
||||
seeds = [seeds, seeds]
|
||||
#first_seed, *seeds = seeds
|
||||
#prompts.append(prompts)
|
||||
#seeds.append(first_seed)
|
||||
|
||||
|
||||
# iterate the loop
|
||||
frames = []
|
||||
@ -452,15 +452,40 @@ def txt2vid(
|
||||
with autocast("cuda"):
|
||||
image = diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
|
||||
|
||||
#im = Image.fromarray(image)
|
||||
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
|
||||
image.save(outpath, quality=quality)
|
||||
if st.session_state["save_individual_images"] and not st.session_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
|
||||
#im = Image.fromarray(image)
|
||||
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
|
||||
image.save(outpath, quality=quality)
|
||||
|
||||
# send the image to the UI to update it
|
||||
#st.session_state["preview_image"].image(im)
|
||||
# send the image to the UI to update it
|
||||
#st.session_state["preview_image"].image(im)
|
||||
|
||||
#append the frames to the frames list so we can use them later.
|
||||
frames.append(np.asarray(image))
|
||||
#append the frames to the frames list so we can use them later.
|
||||
frames.append(np.asarray(image))
|
||||
|
||||
|
||||
#
|
||||
#try:
|
||||
#if st.session_state["use_GFPGAN"] and st.session_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
|
||||
if st.session_state["use_GFPGAN"] and st.session_state["GFPGAN"] is not None:
|
||||
#print("Running GFPGAN on image ...")
|
||||
st.session_state["progress_bar_text"].text("Running GFPGAN on image ...")
|
||||
#skip_save = True # #287 >_>
|
||||
torch_gc()
|
||||
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(np.array(image)[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||
gfpgan_sample = restored_img[:,:,::-1]
|
||||
gfpgan_image = Image.fromarray(gfpgan_sample)
|
||||
|
||||
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
|
||||
gfpgan_image.save(outpath, quality=quality)
|
||||
|
||||
#append the frames to the frames list so we can use them later.
|
||||
frames.append(np.asarray(gfpgan_image))
|
||||
|
||||
st.session_state["preview_image"].image(gfpgan_image)
|
||||
#except AttributeError:
|
||||
#print("Cant perform GFPGAN, skipping.")
|
||||
#pass
|
||||
|
||||
#increase frame_index counter.
|
||||
frame_index += 1
|
||||
@ -661,7 +686,7 @@ def layout():
|
||||
|
||||
st.session_state.num_inference_steps = st.slider("Inference Steps:", value=st.session_state['defaults'].txt2vid.num_inference_steps.value,
|
||||
min_value=st.session_state['defaults'].txt2vid.num_inference_steps.min_value,
|
||||
step=st.session_state['defaults'].txt2vid.num_inference_steps.max_value,
|
||||
step=st.session_state['defaults'].txt2vid.num_inference_steps.step,
|
||||
max_value=st.session_state['defaults'].txt2vid.num_inference_steps.max_value,
|
||||
help="Higher values (e.g. 100, 200 etc) can create better images.")
|
||||
|
||||
@ -731,65 +756,82 @@ def layout():
|
||||
if generate_button:
|
||||
#print("Loading models")
|
||||
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
|
||||
#load_models(False, False, False, st.session_state["RealESRGAN_model"], CustomModel_available=st.session_state["CustomModel_available"], custom_model=custom_model)
|
||||
#load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
|
||||
|
||||
try:
|
||||
# run video generation
|
||||
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
|
||||
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
|
||||
num_inference_steps=st.session_state.num_inference_steps,
|
||||
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
|
||||
seeds=seed, quality=100, eta=0.0, width=width,
|
||||
height=height, weights_path=custom_model, scheduler=scheduler_name,
|
||||
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
|
||||
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
|
||||
beta_schedule=beta_scheduler_type, starting_image=None)
|
||||
if st.session_state["use_GFPGAN"]:
|
||||
if "GFPGAN" in st.session_state:
|
||||
print("GFPGAN already loaded")
|
||||
else:
|
||||
# Load GFPGAN
|
||||
if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
|
||||
try:
|
||||
st.session_state["GFPGAN"] = load_GFPGAN()
|
||||
print("Loaded GFPGAN")
|
||||
except Exception:
|
||||
import traceback
|
||||
print("Error loading GFPGAN:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
else:
|
||||
if "GFPGAN" in st.session_state:
|
||||
del st.session_state["GFPGAN"]
|
||||
|
||||
#message.success('Done!', icon="✅")
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
#try:
|
||||
# run video generation
|
||||
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
|
||||
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
|
||||
num_inference_steps=st.session_state.num_inference_steps,
|
||||
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
|
||||
seeds=seed, quality=100, eta=0.0, width=width,
|
||||
height=height, weights_path=custom_model, scheduler=scheduler_name,
|
||||
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
|
||||
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
|
||||
beta_schedule=beta_scheduler_type, starting_image=None)
|
||||
|
||||
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
|
||||
#message.success('Done!', icon="✅")
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
|
||||
#if 'latestVideos' in st.session_state:
|
||||
#for i in video:
|
||||
##push the new image to the list of latest images and remove the oldest one
|
||||
##remove the last index from the list\
|
||||
#st.session_state['latestVideos'].pop()
|
||||
##add the new image to the start of the list
|
||||
#st.session_state['latestVideos'].insert(0, i)
|
||||
#PlaceHolder.empty()
|
||||
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
|
||||
|
||||
#with PlaceHolder.container():
|
||||
#col1, col2, col3 = st.columns(3)
|
||||
#col1_cont = st.container()
|
||||
#col2_cont = st.container()
|
||||
#col3_cont = st.container()
|
||||
#if 'latestVideos' in st.session_state:
|
||||
#for i in video:
|
||||
##push the new image to the list of latest images and remove the oldest one
|
||||
##remove the last index from the list\
|
||||
#st.session_state['latestVideos'].pop()
|
||||
##add the new image to the start of the list
|
||||
#st.session_state['latestVideos'].insert(0, i)
|
||||
#PlaceHolder.empty()
|
||||
|
||||
#with col1_cont:
|
||||
#with col1:
|
||||
#st.image(st.session_state['latestVideos'][0])
|
||||
#st.image(st.session_state['latestVideos'][3])
|
||||
#st.image(st.session_state['latestVideos'][6])
|
||||
#with col2_cont:
|
||||
#with col2:
|
||||
#st.image(st.session_state['latestVideos'][1])
|
||||
#st.image(st.session_state['latestVideos'][4])
|
||||
#st.image(st.session_state['latestVideos'][7])
|
||||
#with col3_cont:
|
||||
#with col3:
|
||||
#st.image(st.session_state['latestVideos'][2])
|
||||
#st.image(st.session_state['latestVideos'][5])
|
||||
#st.image(st.session_state['latestVideos'][8])
|
||||
#historyGallery = st.empty()
|
||||
#with PlaceHolder.container():
|
||||
#col1, col2, col3 = st.columns(3)
|
||||
#col1_cont = st.container()
|
||||
#col2_cont = st.container()
|
||||
#col3_cont = st.container()
|
||||
|
||||
## check if output_images length is the same as seeds length
|
||||
#with gallery_tab:
|
||||
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
|
||||
#with col1_cont:
|
||||
#with col1:
|
||||
#st.image(st.session_state['latestVideos'][0])
|
||||
#st.image(st.session_state['latestVideos'][3])
|
||||
#st.image(st.session_state['latestVideos'][6])
|
||||
#with col2_cont:
|
||||
#with col2:
|
||||
#st.image(st.session_state['latestVideos'][1])
|
||||
#st.image(st.session_state['latestVideos'][4])
|
||||
#st.image(st.session_state['latestVideos'][7])
|
||||
#with col3_cont:
|
||||
#with col3:
|
||||
#st.image(st.session_state['latestVideos'][2])
|
||||
#st.image(st.session_state['latestVideos'][5])
|
||||
#st.image(st.session_state['latestVideos'][8])
|
||||
#historyGallery = st.empty()
|
||||
|
||||
## check if output_images length is the same as seeds length
|
||||
#with gallery_tab:
|
||||
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
|
||||
|
||||
|
||||
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
|
||||
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
|
||||
|
||||
except (StopException, KeyError):
|
||||
print(f"Received Streamlit StopException")
|
||||
#except (StopException, KeyError):
|
||||
#print(f"Received Streamlit StopException")
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user