diff --git a/scripts/txt2vid.py b/scripts/txt2vid.py
index edd4dd5..c6e0744 100644
--- a/scripts/txt2vid.py
+++ b/scripts/txt2vid.py
@@ -151,12 +151,12 @@ def diffuse(
 
             # generate output numpy image as uint8
             image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
-            image = transforms.ToPILImage()(image.squeeze_(0))
+            image2 = transforms.ToPILImage()(image.squeeze_(0))
 
-            st.session_state["preview_image"].image(image)
+            st.session_state["preview_image"].image(image2)
 
             step_counter = 0
-
+
     duration = timeit.default_timer() - start
 
     st.session_state["current_chunk_speed"] = duration
@@ -185,7 +185,7 @@ def diffuse(
         )
         st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
 
-    return image
+    return image2
 
 #
 def txt2vid(
@@ -375,6 +375,13 @@ def txt2vid(
         print("Tx2Vid Model Loaded")
 
     st.session_state["pipe"].scheduler = SCHEDULERS[scheduler]
+
+    if do_loop:
+        prompts = str([prompts, prompts])
+        seeds = [seeds, seeds]
+        #first_seed, *seeds = seeds
+        #prompts.append(prompts)
+        #seeds.append(first_seed)
 
     # get the conditional text embeddings based on the prompt
     text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt")
@@ -414,13 +421,6 @@ def txt2vid(
 
     # sample a source
     init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
 
-    if do_loop:
-        prompts = [prompts, prompts]
-        seeds = [seeds, seeds]
-        #first_seed, *seeds = seeds
-        #prompts.append(prompts)
-        #seeds.append(first_seed)
-
     # iterate the loop
     frames = []
@@ -451,17 +451,42 @@ def txt2vid(
 
         with autocast("cuda"):
            image = diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
-
-        #im = Image.fromarray(image)
-        outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
-        image.save(outpath, quality=quality)
-
-        # send the image to the UI to update it
-        #st.session_state["preview_image"].image(im)
-
-        #append the frames to the frames list so we can use them later.
-        frames.append(np.asarray(image))
-
+
+        if st.session_state["save_individual_images"] and not st.session_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
+            #im = Image.fromarray(image)
+            outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
+            image.save(outpath, quality=quality)
+
+            # send the image to the UI to update it
+            #st.session_state["preview_image"].image(im)
+
+            #append the frames to the frames list so we can use them later.
+            frames.append(np.asarray(image))
+
+
+        #
+        #try:
+        #if st.session_state["use_GFPGAN"] and st.session_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
+        if st.session_state["use_GFPGAN"] and st.session_state["GFPGAN"] is not None:
+            #print("Running GFPGAN on image ...")
+            st.session_state["progress_bar_text"].text("Running GFPGAN on image ...")
+            #skip_save = True # #287 >_>
+            torch_gc()
+            cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(np.array(image)[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
+            gfpgan_sample = restored_img[:,:,::-1]
+            gfpgan_image = Image.fromarray(gfpgan_sample)
+
+            outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
+            gfpgan_image.save(outpath, quality=quality)
+
+            #append the frames to the frames list so we can use them later.
+            frames.append(np.asarray(gfpgan_image))
+
+            st.session_state["preview_image"].image(gfpgan_image)
+        #except AttributeError:
+        #print("Cant perform GFPGAN, skipping.")
+        #pass
+
        #increase frame_index counter.
        frame_index += 1
 
@@ -661,7 +686,7 @@ def layout():
 
            st.session_state.num_inference_steps = st.slider("Inference Steps:", value=st.session_state['defaults'].txt2vid.num_inference_steps.value,
                                 min_value=st.session_state['defaults'].txt2vid.num_inference_steps.min_value,
-                                step=st.session_state['defaults'].txt2vid.num_inference_steps.max_value,
+                                step=st.session_state['defaults'].txt2vid.num_inference_steps.step,
                                 max_value=st.session_state['defaults'].txt2vid.num_inference_steps.max_value,
                                 help="Higher values (e.g. 100, 200 etc) can create better images.")
@@ -731,65 +756,82 @@ def layout():
 
        if generate_button:
            #print("Loading models")
            # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
-           #load_models(False, False, False, st.session_state["RealESRGAN_model"], CustomModel_available=st.session_state["CustomModel_available"], custom_model=custom_model)
+           #load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
 
-           try:
-               # run video generation
-               video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
-                                                  num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
-                                                  num_inference_steps=st.session_state.num_inference_steps,
-                                                  cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
-                                                  seeds=seed, quality=100, eta=0.0, width=width,
-                                                  height=height, weights_path=custom_model, scheduler=scheduler_name,
-                                                  disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
-                                                  beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
-                                                  beta_schedule=beta_scheduler_type, starting_image=None)
+           if st.session_state["use_GFPGAN"]:
+               if "GFPGAN" in st.session_state:
+                   print("GFPGAN already loaded")
+               else:
+                   # Load GFPGAN
+                   if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
+                       try:
+                           st.session_state["GFPGAN"] = load_GFPGAN()
+                           print("Loaded GFPGAN")
+                       except Exception:
+                           import traceback
+                           print("Error loading GFPGAN:", file=sys.stderr)
+                           print(traceback.format_exc(), file=sys.stderr)
+           else:
+               if "GFPGAN" in st.session_state:
+                   del st.session_state["GFPGAN"]
 
-           #message.success('Done!', icon="✅")
-           message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+           #try:
+           # run video generation
+           video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
+                                              num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
+                                              num_inference_steps=st.session_state.num_inference_steps,
+                                              cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
+                                              seeds=seed, quality=100, eta=0.0, width=width,
+                                              height=height, weights_path=custom_model, scheduler=scheduler_name,
+                                              disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
+                                              beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
+                                              beta_schedule=beta_scheduler_type, starting_image=None)
 
-           #history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
+           #message.success('Done!', icon="✅")
+           message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
 
-           #if 'latestVideos' in st.session_state:
-               #for i in video:
-                   ##push the new image to the list of latest images and remove the oldest one
-                   ##remove the last index from the list\
-                   #st.session_state['latestVideos'].pop()
-                   ##add the new image to the start of the list
-                   #st.session_state['latestVideos'].insert(0, i)
-               #PlaceHolder.empty()
+           #history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
 
-               #with PlaceHolder.container():
-                   #col1, col2, col3 = st.columns(3)
-                   #col1_cont = st.container()
-                   #col2_cont = st.container()
-                   #col3_cont = st.container()
+           #if 'latestVideos' in st.session_state:
+               #for i in video:
+                   ##push the new image to the list of latest images and remove the oldest one
+                   ##remove the last index from the list\
+                   #st.session_state['latestVideos'].pop()
+                   ##add the new image to the start of the list
+                   #st.session_state['latestVideos'].insert(0, i)
+               #PlaceHolder.empty()
 
-                   #with col1_cont:
-                       #with col1:
-                           #st.image(st.session_state['latestVideos'][0])
-                           #st.image(st.session_state['latestVideos'][3])
-                           #st.image(st.session_state['latestVideos'][6])
-                   #with col2_cont:
-                       #with col2:
-                           #st.image(st.session_state['latestVideos'][1])
-                           #st.image(st.session_state['latestVideos'][4])
-                           #st.image(st.session_state['latestVideos'][7])
-                   #with col3_cont:
-                       #with col3:
-                           #st.image(st.session_state['latestVideos'][2])
-                           #st.image(st.session_state['latestVideos'][5])
-                           #st.image(st.session_state['latestVideos'][8])
-                   #historyGallery = st.empty()
+               #with PlaceHolder.container():
+                   #col1, col2, col3 = st.columns(3)
+                   #col1_cont = st.container()
+                   #col2_cont = st.container()
+                   #col3_cont = st.container()
 
-               ## check if output_images length is the same as seeds length
-               #with gallery_tab:
-                   #st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
+                   #with col1_cont:
+                       #with col1:
+                           #st.image(st.session_state['latestVideos'][0])
+                           #st.image(st.session_state['latestVideos'][3])
+                           #st.image(st.session_state['latestVideos'][6])
+                   #with col2_cont:
+                       #with col2:
+                           #st.image(st.session_state['latestVideos'][1])
+                           #st.image(st.session_state['latestVideos'][4])
+                           #st.image(st.session_state['latestVideos'][7])
+                   #with col3_cont:
+                       #with col3:
+                           #st.image(st.session_state['latestVideos'][2])
+                           #st.image(st.session_state['latestVideos'][5])
+                           #st.image(st.session_state['latestVideos'][8])
+                   #historyGallery = st.empty()
+
+               ## check if output_images length is the same as seeds length
+               #with gallery_tab:
+                   #st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
 
-           #st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
+           #st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
 
-       except (StopException, KeyError):
-           print(f"Received Streamlit StopException")
+       #except (StopException, KeyError):
+           #print(f"Received Streamlit StopException")
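
For readers unfamiliar with the face-restoration step this patch adds, the per-frame logic reduces to the sketch below. This is a minimal illustration, not part of the patch: it assumes `gfpgan` is the already-loaded restorer kept in st.session_state["GFPGAN"] (anything exposing the same enhance() call the patch uses) and `image` is the PIL frame returned by diffuse(); the helper name restore_frame is hypothetical.

    import numpy as np
    from PIL import Image

    def restore_frame(gfpgan, image):
        # Hypothetical helper mirroring the per-frame GFPGAN step in the patch.
        # GFPGAN works on BGR arrays, so reverse the channel order of the RGB PIL frame.
        bgr = np.array(image)[:, :, ::-1]
        # enhance() returns (cropped_faces, restored_faces, restored_img); only the
        # fully restored frame is kept, as the patch does before appending to frames.
        _, _, restored_img = gfpgan.enhance(bgr, has_aligned=False,
                                            only_center_face=False, paste_back=True)
        # Flip back to RGB and rewrap as a PIL image for saving and the Streamlit preview.
        return Image.fromarray(restored_img[:, :, ::-1])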