2022-10-24 03:31:41 +03:00
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
2022-09-26 16:02:48 +03:00
2022-10-24 03:17:50 +03:00
# Copyright 2022 Sygil-Dev team.
2022-09-26 16:02:48 +03:00
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
2022-10-06 11:45:24 +03:00
# along with this program. If not, see <http://www.gnu.org/licenses/>.
2022-09-14 12:52:02 +03:00
# base webui import and utils.
2022-10-12 10:08:38 +03:00
"""
Implementation of Text to Video based on the
https://github.com/nateraw/stable-diffusion-videos
repo and the original gist script from
https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
"""
2022-09-14 00:08:40 +03:00
from sd_utils import *
2022-09-14 12:52:02 +03:00
# streamlit imports
2022-09-14 00:08:40 +03:00
from streamlit import StopException
2022-09-14 12:52:02 +03:00
from streamlit . elements import image as STImage
2022-09-25 09:28:02 +03:00
#streamlit components section
from streamlit_server_state import server_state , server_state_lock
2022-09-14 12:52:02 +03:00
#other imports
2022-09-14 00:08:40 +03:00
2022-10-14 22:09:47 +03:00
import os , sys
2022-09-14 00:08:40 +03:00
from PIL import Image
import torch
import numpy as np
2022-09-14 12:52:02 +03:00
import time , inspect , timeit
2022-09-14 00:08:40 +03:00
import torch
from torch import autocast
from io import BytesIO
2022-09-14 14:19:24 +03:00
import imageio
2022-09-14 00:08:40 +03:00
from slugify import slugify
2022-09-14 12:52:02 +03:00
2022-10-23 13:00:46 +03:00
from diffusers import StableDiffusionPipeline , DiffusionPipeline
2022-09-19 04:29:19 +03:00
from diffusers . schedulers import DDIMScheduler , LMSDiscreteScheduler , \
PNDMScheduler
2022-09-14 12:52:02 +03:00
2022-10-19 17:52:08 +03:00
# streamlit components
from custom_components import key_phrase_suggestions
2022-09-25 09:28:02 +03:00
# Temp imports
2022-09-14 12:52:02 +03:00
# end of imports
#---------------------------------------------------------------------------------------------------------------
2022-09-14 00:08:40 +03:00
2022-10-19 17:52:08 +03:00
# Initialise the key phrase suggestions component once, at import time.
key_phrase_suggestions.init()

try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging

    logging.set_verbosity_error()
except Exception:
    # Best effort only: a missing or incompatible `transformers` must not
    # prevent this module from loading. `Exception` (rather than a bare
    # `except`) lets KeyboardInterrupt/SystemExit propagate.
    pass
class plugin_info:
    """Static metadata describing this module as a web UI plugin.

    The attributes are plain class-level constants; nothing in this file
    instantiates the class.
    NOTE(review): presumably read by the plugin/tab loader - confirm against
    the code that consumes `plugin_info`.
    """

    # Short identifier of the plugin.
    plugname = "txt2vid"
    # Human readable description. The original said "Text to Image", which
    # looks copy-pasted from another plugin; this module implements text to video.
    description = "Text to Video"
    # Whether the plugin is shown as its own tab.
    isTab = True
    # Ordering hint for the tab.
    displayPriority = 1
2022-09-14 00:08:40 +03:00
#
# -----------------------------------------------------------------------------
def _safe_percent(done, total):
    """Return ``int(100 * done / total)``, or 100 when ``total`` is zero."""
    if not total:
        return 100
    return int(100 * float(done) / float(total))


def _decode_latents_to_image(pipe, latents):
    """Decode latents with ``pipe.vae`` and convert the result to an image."""
    # scale and decode the image latents with vae
    scaled_latents = 1 / 0.18215 * latents
    decoded = pipe.vae.decode(scaled_latents)

    # map the decoded sample through (x + 1) / 2 and clamp it to [0, 1]
    decoded = torch.clamp((decoded["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
    # NOTE(review): `transforms` is not imported in this file's visible imports;
    # presumably torchvision.transforms re-exported by `sd_utils` - confirm.
    return transforms.ToPILImage()(decoded.squeeze_(0))


@torch.no_grad()
def diffuse(
    pipe,
    cond_embeddings,  # text conditioning, should be (1, 77, 768)
    cond_latents,  # image conditioning, should be (1, 4, 64, 64)
    num_inference_steps,
    cfg_scale,
    eta,
):
    """Run the denoising loop for one frame and return the decoded image.

    Besides the explicit arguments this function reads several values from
    ``st.session_state`` (``sampling_steps``, ``update_preview``,
    ``update_preview_frequency``, ``dynamic_preview_frequency``,
    ``max_duration_in_seconds``, ``num_inference_steps``, ``current_frame``,
    ``frame_duration``, ``frame_speed``) and updates the preview image,
    progress text and progress bar widgets stored there.

    Args:
        pipe: pipeline object providing ``tokenizer``, ``text_encoder``,
            ``unet``, ``vae`` and ``scheduler``.
        cond_embeddings: text conditioning embeddings.
        cond_latents: starting latents.
        num_inference_steps: added to ``st.session_state.sampling_steps`` to
            get the number of scheduler timesteps.
        cfg_scale: classifier free guidance scale.
        eta: DDIM eta, only forwarded when the scheduler's ``step`` accepts it.

    Returns:
        The image produced by ``transforms.ToPILImage()`` from the final latents.

    Raises:
        StopException: when a ``KeyError`` is raised inside the loop (for
            example a missing ``st.session_state`` key).
    """
    # `.device` works for CPU and CUDA tensors alike; `get_device()` returns -1
    # for CPU tensors, which is not a valid argument for `.to()`.
    torch_device = cond_latents.device

    # classifier guidance: add the unconditional embedding
    max_length = cond_embeddings.shape[1]  # 77
    uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    # the scheduler is not replaced inside this function, so check its type once
    uses_lms = isinstance(pipe.scheduler, LMSDiscreteScheduler)

    # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
    if uses_lms:
        cond_latents = cond_latents * pipe.scheduler.sigmas[0]

    # init the scheduler
    extra_set_kwargs = {}
    if "offset" in inspect.signature(pipe.scheduler.set_timesteps).parameters:
        extra_set_kwargs["offset"] = 1

    pipe.scheduler.set_timesteps(num_inference_steps + st.session_state.sampling_steps, **extra_set_kwargs)

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    extra_step_kwargs = {}
    if "eta" in inspect.signature(pipe.scheduler.step).parameters:
        extra_step_kwargs["eta"] = eta

    step_counter = 0
    inference_counter = 0

    if "current_chunk_speed" not in st.session_state:
        st.session_state["current_chunk_speed"] = 0

    if "previous_chunk_speed_list" not in st.session_state:
        st.session_state["previous_chunk_speed_list"] = [0]
        st.session_state["previous_chunk_speed_list"].append(st.session_state["current_chunk_speed"])

    if "update_preview_frequency_list" not in st.session_state:
        st.session_state["update_preview_frequency_list"] = [0]
        st.session_state["update_preview_frequency_list"].append(st.session_state["update_preview_frequency"])

    try:
        # diffuse!
        for i, t in enumerate(pipe.scheduler.timesteps):
            start = timeit.default_timer()

            # expand the latents for classifier free guidance
            latent_model_input = torch.cat([cond_latents] * 2)
            if uses_lms:
                sigma = pipe.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # cfg
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            # (the LMS scheduler is stepped with the loop index, the others with the timestep)
            step_arg = i if uses_lms else t
            cond_latents = pipe.scheduler.step(noise_pred, step_arg, cond_latents, **extra_step_kwargs)["prev_sample"]

            # update the preview image if it is enabled and the frequency matches the step_counter
            if st.session_state["update_preview"]:
                step_counter += 1

                if step_counter == st.session_state["update_preview_frequency"]:
                    if st.session_state.dynamic_preview_frequency:
                        # The targets must be parenthesised to form ONE tuple
                        # assignment. Without the parentheses the first three
                        # lines are discarded expression statements and only the
                        # last key is assigned (with the whole return value).
                        # NOTE(review): assumes the helper returns four values in
                        # this order - confirm against sd_utils.
                        (
                            st.session_state["current_chunk_speed"],
                            st.session_state["previous_chunk_speed_list"],
                            st.session_state["update_preview_frequency"],
                            st.session_state["avg_update_preview_frequency"],
                        ) = optimize_update_preview_frequency(
                            st.session_state["current_chunk_speed"],
                            st.session_state["previous_chunk_speed_list"],
                            st.session_state["update_preview_frequency"],
                            st.session_state["update_preview_frequency_list"],
                        )

                        # step_counter is at least 1 when compared, so a frequency
                        # below 1 would never match again and freeze the preview.
                        if st.session_state["update_preview_frequency"] < 1:
                            st.session_state["update_preview_frequency"] = 1

                    st.session_state["preview_image"].image(_decode_latents_to_image(pipe, cond_latents))

                    step_counter = 0

            duration = timeit.default_timer() - start

            st.session_state["current_chunk_speed"] = duration

            # show seconds per iteration for slow steps, iterations per second otherwise
            if duration >= 1:
                speed = "s/it"
            else:
                speed = "it/s"
                duration = 1 / duration

            sampling_steps = st.session_state.sampling_steps
            total_timesteps = num_inference_steps + sampling_steps
            total_frames = (sampling_steps + st.session_state.num_inference_steps) * st.session_state.max_duration_in_seconds

            if i > sampling_steps:
                inference_counter += 1
                inference_done = min(inference_counter + 1, num_inference_steps)
                inference_percent = _safe_percent(inference_done, num_inference_steps)
                inference_progress = f"{inference_done}/{num_inference_steps} {inference_percent}% "
            else:
                inference_progress = ""

            total_percent = _safe_percent(min(i + 1, total_timesteps), total_timesteps)
            # NOTE(review): the threshold is num_inference_steps but the fallback and
            # the denominator are sampling_steps; kept as in the original - confirm intent.
            percent = _safe_percent(i + 1 if i + 1 < num_inference_steps else sampling_steps, sampling_steps)
            frames_percent = _safe_percent(min(st.session_state.current_frame, total_frames), total_frames)

            if "progress_bar_text" in st.session_state:
                st.session_state["progress_bar_text"].text(
                    f"Running step: {min(i + 1, sampling_steps)}/{sampling_steps} "
                    f"{min(percent, 100)}% {inference_progress}{duration:.2f}{speed} | "
                    f"Frame: {st.session_state.current_frame + 1 if st.session_state.current_frame < total_frames else total_frames}/{total_frames} "
                    f"{min(frames_percent, 100)}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}"
                )

            if "progress_bar" in st.session_state:
                st.session_state["progress_bar"].progress(min(total_percent, 100))

    except KeyError:
        # a KeyError inside the loop is turned into a Streamlit stop request
        raise StopException

    return _decode_latents_to_image(pipe, cond_latents)
2022-09-14 00:08:40 +03:00
#
2022-09-25 09:28:02 +03:00
def load_diffusers_model(weights_path, torch_device):
    """Load the diffusers pipeline into ``server_state["pipe"]``.

    Any ``server_state["model"]`` and ``st.session_state['textual_inversion']``
    entries are removed first. If no pipeline is cached, it is loaded from the
    local ``models/diffusers/<name>`` folder when that folder contains a
    ``model_index.json``; otherwise it is fetched via ``weights_path`` and saved
    to that folder. If a pipeline is cached but the ``use_float16``, ``no_half``
    or ``optimized`` settings differ from the cached flags, the pipeline is
    dropped and this function calls itself to reload it.

    Args:
        weights_path: model id (or path) passed to
            ``DiffusionPipeline.from_pretrained`` when no local copy exists.
        torch_device: device the unet, vae and text encoder are moved to.

    Raises:
        OSError: when loading fails with an ``OSError`` and no Hugging Face
            token appears to be configured. Other ``OSError`` failures are
            logged and shown in the progress text widget, not re-raised.
    """
    with server_state_lock["model"]:
        if "model" in server_state:
            del server_state["model"]

    if "textual_inversion" in st.session_state:
        del st.session_state['textual_inversion']

    general = st.session_state['defaults'].general
    # settings that require a reload of the pipeline when they change
    current_flags = {
        'float16': general.use_float16,
        'no_half': general.no_half,
        'optimized': general.optimized,
    }

    try:
        with server_state_lock["pipe"]:
            if "pipe" not in server_state:
                if "weights_path" in st.session_state and st.session_state["weights_path"] != weights_path:
                    del st.session_state["weights_path"]

                st.session_state["weights_path"] = weights_path
                for flag_name, flag_value in current_flags.items():
                    server_state[flag_name] = flag_value

                # Local cache folder, named after the last component of the id.
                # This yields the same folders as before for
                # "CompVis/stable-diffusion-v1-4" and "runwayml/stable-diffusion-v1-5",
                # and no longer leaves `model_path` undefined for any other id.
                model_name = weights_path.rstrip("/").split("/")[-1]
                model_path = os.path.join("models", "diffusers", model_name)

                # keyword arguments shared by the remote and the local load
                # NOTE(review): `use_local_file` is forwarded unchanged from the
                # original code - confirm that from_pretrained recognises it.
                pipeline_kwargs = dict(
                    use_local_file=True,
                    torch_dtype=torch.float16 if general.use_float16 else None,
                    revision="fp16" if not general.no_half else None,
                    safety_checker=None,  # Very important for videos...lots of false positives while interpolating
                    custom_pipeline="interpolate_stable_diffusion",
                )

                # if the local folder already holds the model, load it from there
                if not os.path.exists(os.path.join(model_path, "model_index.json")):
                    server_state["pipe"] = DiffusionPipeline.from_pretrained(
                        weights_path,
                        use_auth_token=general.huggingface_token,
                        **pipeline_kwargs,
                    )

                    # keep a local copy so that the next load uses the folder
                    DiffusionPipeline.save_pretrained(server_state["pipe"], model_path)
                else:
                    server_state["pipe"] = DiffusionPipeline.from_pretrained(model_path, **pipeline_kwargs)

                server_state["pipe"].unet.to(torch_device)
                server_state["pipe"].vae.to(torch_device)
                server_state["pipe"].text_encoder.to(torch_device)

                if general.enable_attention_slicing:
                    server_state["pipe"].enable_attention_slicing()

                if general.enable_minimal_memory_usage:
                    server_state["pipe"].enable_minimal_memory_usage()

                logger.info("Tx2Vid Model Loaded")
            else:
                # if the float16, no_half or optimized options have changed since the
                # last time the model was loaded then we need to reload the model.
                settings_changed = any(
                    flag_name in server_state and server_state[flag_name] != flag_value
                    for flag_name, flag_value in current_flags.items()
                )

                if settings_changed:
                    # Only delete the flags that exist: a single cached flag is
                    # enough to get here, so the others may be missing.
                    for flag_name in current_flags:
                        if flag_name in server_state:
                            del server_state[flag_name]

                    with server_state_lock["pipe"]:
                        del server_state["pipe"]
                        torch_gc()

                    for flag_name, flag_value in current_flags.items():
                        server_state[flag_name] = flag_value

                    load_diffusers_model(weights_path, torch_device)
                else:
                    logger.info("Tx2Vid Model already Loaded")

    except OSError as e:  # EnvironmentError is an alias of OSError in Python 3
        token_message = "You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."

        # NOTE(review): the token is read from the defaults, while the first test
        # looks for a "huggingface_token" key in session_state - confirm intent.
        if "huggingface_token" not in st.session_state or general.huggingface_token == "None":
            if "progress_bar_text" in st.session_state:
                st.session_state["progress_bar_text"].error(token_message)
            # chain the original error so that the real cause is not lost
            raise OSError(token_message) from e
        else:
            # log as well, so the failure is visible without a progress widget
            logger.error(f"Tx2Vid Model could not be loaded: {e}")
            if "progress_bar_text" in st.session_state:
                st.session_state["progress_bar_text"].error(e)
2022-10-11 06:45:37 +03:00
2022-10-08 09:46:23 +03:00
#
def save_video_to_disk(frames, seeds, sanitized_prompt, fps=6, save_video=True, outdir='outputs'):
    """Write ``frames`` to an mp4 file and return its path.

    The file is written to ``<cwd>/<outdir>/txt2vid`` and named after the seeds,
    the sanitized prompt, a timestamp and the first 8 characters of a uuid4.

    Args:
        frames: iterable of frames passed one by one to the imageio writer.
        seeds: seed value(s) used as the first part of the file name.
        sanitized_prompt: prompt text that is safe to use in a file name.
        fps: frames per second passed to ``imageio.get_writer``.
        save_video: when false nothing is written.
        outdir: output folder, relative to the current working directory.

    Returns:
        The path of the written video, or ``None`` when ``save_video`` is false
        (previously this case failed because ``video_path`` was never assigned).
    """
    if not save_video:
        return None

    video_dir = os.path.join(os.getcwd(), outdir, "txt2vid")
    # the writer cannot create the file when the folder does not exist yet
    os.makedirs(video_dir, exist_ok=True)

    unique_suffix = datetime.now().strftime('%Y%m-%d%H-%M%S-') + str(uuid4())[:8]
    video_path = os.path.join(video_dir, f"{seeds}_{sanitized_prompt}{unique_suffix}.mp4")

    writer = imageio.get_writer(video_path, fps=fps)
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # always close the writer, even when appending a frame fails
        writer.close()

    return video_path
2022-09-25 09:28:02 +03:00
#
2022-09-14 00:08:40 +03:00
def txt2vid (
2022-10-23 13:00:46 +03:00
# --------------------------------------
# args you probably want to change
2022-09-18 01:25:55 +03:00
prompts = [ " blueberry spaghetti " , " strawberry spaghetti " ] , # prompt to dream about
gpu : int = st . session_state [ ' defaults ' ] . general . gpu , # id of the gpu to run on
#name:str = 'test', # name of this project, for the output directory
#rootdir:str = st.session_state['defaults'].general.outdir,
num_steps : int = 200 , # number of steps between each pair of sampled points
2022-10-23 13:00:46 +03:00
max_duration_in_seconds : int = 30 , # number of frames to write and then exit the script
num_inference_steps : int = 50 , # more (e.g. 100, 200 etc) can create slightly better images
cfg_scale : float = 5.0 , # can depend on the prompt. usually somewhere between 3-10 is good
save_video = True ,
save_video_on_stop = False ,
outdir = ' outputs ' ,
do_loop = False ,
use_lerp_for_text = False ,
seeds = None ,
quality : int = 100 , # for jpeg compression of the output images
eta : float = 0.0 ,
width : int = 256 ,
height : int = 256 ,
weights_path = " runwayml/stable-diffusion-v1-5 " ,
scheduler = " klms " , # choices: default, ddim, klms
disable_tqdm = False ,
#-----------------------------------------------
beta_start = 0.0001 ,
beta_end = 0.00012 ,
beta_schedule = " scaled_linear " ,
starting_image = None
) :
2022-09-14 00:08:40 +03:00
"""
prompt = [ " blueberry spaghetti " , " strawberry spaghetti " ] , # prompt to dream about
2022-09-14 16:40:56 +03:00
gpu : int = st . session_state [ ' defaults ' ] . general . gpu , # id of the gpu to run on
2022-09-14 00:08:40 +03:00
#name:str = 'test', # name of this project, for the output directory
2022-09-14 16:40:56 +03:00
#rootdir:str = st.session_state['defaults'].general.outdir,
2022-09-14 00:08:40 +03:00
num_steps : int = 200 , # number of steps between each pair of sampled points
2022-10-23 13:00:46 +03:00
max_duration_in_seconds : int = 10000 , # number of frames to write and then exit the script
2022-09-14 00:08:40 +03:00
num_inference_steps : int = 50 , # more (e.g. 100, 200 etc) can create slightly better images
cfg_scale : float = 5.0 , # can depend on the prompt. usually somewhere between 3-10 is good
do_loop = False ,
use_lerp_for_text = False ,
seed = None ,
quality : int = 100 , # for jpeg compression of the output images
eta : float = 0.0 ,
width : int = 256 ,
height : int = 256 ,
2022-10-21 05:50:40 +03:00
weights_path = " runwayml/stable-diffusion-v1-5 " ,
2022-09-14 00:08:40 +03:00
scheduler = " klms " , # choices: default, ddim, klms
disable_tqdm = False ,
beta_start = 0.0001 ,
beta_end = 0.00012 ,
beta_schedule = " scaled_linear "
"""
mem_mon = MemUsageMonitor ( ' MemMon ' )
2022-09-18 01:25:55 +03:00
mem_mon . start ( )
2022-09-14 00:08:40 +03:00
seeds = seed_to_int ( seeds )
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
# We add an extra frame because most
# of the time the first frame is just the noise.
2022-10-23 13:00:46 +03:00
#max_duration_in_seconds +=1
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
assert torch . cuda . is_available ( )
assert height % 8 == 0 and width % 8 == 0
torch . manual_seed ( seeds )
torch_device = f " cuda: { gpu } "
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
# init the output dir
sanitized_prompt = slugify ( prompts )
2022-09-18 01:25:55 +03:00
2022-10-08 09:46:23 +03:00
full_path = os . path . join ( os . getcwd ( ) , st . session_state [ ' defaults ' ] . general . outdir , " txt2vid " , " samples " , sanitized_prompt )
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
if len ( full_path ) > 220 :
sanitized_prompt = sanitized_prompt [ : 220 - len ( full_path ) ]
2022-10-08 09:46:23 +03:00
full_path = os . path . join ( os . getcwd ( ) , st . session_state [ ' defaults ' ] . general . outdir , " txt2vid " , " samples " , sanitized_prompt )
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
os . makedirs ( full_path , exist_ok = True )
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
# Write prompt info to file in output dir so we can keep track of what we did
if st . session_state . write_info_files :
with open ( os . path . join ( full_path , f ' { slugify ( str ( seeds ) ) } _config.json ' if len ( prompts ) > 1 else " prompts_config.json " ) , " w " ) as outfile :
outfile . write ( json . dumps (
2022-10-08 06:32:02 +03:00
dict (
prompts = prompts ,
gpu = gpu ,
num_steps = num_steps ,
2022-10-23 13:00:46 +03:00
max_duration_in_seconds = max_duration_in_seconds ,
2022-10-08 06:32:02 +03:00
num_inference_steps = num_inference_steps ,
cfg_scale = cfg_scale ,
do_loop = do_loop ,
use_lerp_for_text = use_lerp_for_text ,
seeds = seeds ,
quality = quality ,
eta = eta ,
width = width ,
height = height ,
weights_path = weights_path ,
scheduler = scheduler ,
disable_tqdm = disable_tqdm ,
beta_start = beta_start ,
beta_end = beta_end ,
beta_schedule = beta_schedule
) ,
indent = 2 ,
sort_keys = False ,
2022-09-18 01:25:55 +03:00
) )
2022-09-14 00:08:40 +03:00
#print(scheduler)
default_scheduler = PNDMScheduler (
beta_start = beta_start , beta_end = beta_end , beta_schedule = beta_schedule
)
# ------------------------------------------------------------------------------
#Schedulers
ddim_scheduler = DDIMScheduler (
beta_start = beta_start ,
2022-09-18 01:25:55 +03:00
beta_end = beta_end ,
beta_schedule = beta_schedule ,
clip_sample = False ,
set_alpha_to_one = False ,
2022-09-14 00:08:40 +03:00
)
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
klms_scheduler = LMSDiscreteScheduler (
beta_start = beta_start , beta_end = beta_end , beta_schedule = beta_schedule
)
2022-10-06 11:45:24 +03:00
2022-09-18 01:25:55 +03:00
SCHEDULERS = dict ( default = default_scheduler , ddim = ddim_scheduler , klms = klms_scheduler )
2022-10-11 06:45:37 +03:00
with st . session_state [ " progress_bar_text " ] . container ( ) :
with hc . HyLoader ( ' Loading Models... ' , hc . Loaders . standard_loaders , index = [ 0 ] ) :
load_diffusers_model ( weights_path , torch_device )
2022-10-06 11:45:24 +03:00
2022-10-05 04:38:40 +03:00
if " pipe " not in server_state :
2022-10-15 15:34:07 +03:00
logger . error ( ' wtf ' )
2022-10-05 04:38:40 +03:00
2022-09-25 09:28:02 +03:00
server_state [ " pipe " ] . scheduler = SCHEDULERS [ scheduler ]
2022-10-06 11:45:24 +03:00
2022-09-25 09:28:02 +03:00
server_state [ " pipe " ] . use_multiprocessing_for_evaluation = False
2022-10-06 11:45:24 +03:00
server_state [ " pipe " ] . use_multiprocessed_decoding = False
2022-09-19 01:17:02 +03:00
if do_loop :
prompts = str ( [ prompts , prompts ] )
seeds = [ seeds , seeds ]
#first_seed, *seeds = seeds
#prompts.append(prompts)
2022-10-06 11:45:24 +03:00
#seeds.append(first_seed)
2022-09-18 01:25:55 +03:00
2022-10-08 06:32:02 +03:00
with torch . autocast ( ' cuda ' ) :
# get the conditional text embeddings based on the prompt
text_input = server_state [ " pipe " ] . tokenizer ( prompts , padding = " max_length " , max_length = server_state [ " pipe " ] . tokenizer . model_max_length , truncation = True , return_tensors = " pt " )
cond_embeddings = server_state [ " pipe " ] . text_encoder ( text_input . input_ids . to ( torch_device ) ) [ 0 ]
2022-09-18 01:25:55 +03:00
#
if st . session_state . defaults . general . use_sd_concepts_library :
2022-10-06 11:45:24 +03:00
prompt_tokens = re . findall ( ' <([a-zA-Z0-9-]+)> ' , prompts )
2022-09-18 01:25:55 +03:00
if prompt_tokens :
# compviz
#tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer
#text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer
# diffusers
tokenizer = st . session_state . pipe . tokenizer
text_encoder = st . session_state . pipe . text_encoder
ext = ( ' pt ' , ' bin ' )
#print (prompt_tokens)
2022-10-06 11:45:24 +03:00
if len ( prompt_tokens ) > 1 :
2022-09-18 01:25:55 +03:00
for token_name in prompt_tokens :
2022-10-06 11:45:24 +03:00
embedding_path = os . path . join ( st . session_state [ ' defaults ' ] . general . sd_concepts_library_folder , token_name )
2022-09-18 01:25:55 +03:00
if os . path . exists ( embedding_path ) :
for files in os . listdir ( embedding_path ) :
if files . endswith ( ext ) :
load_learned_embed_in_clip ( f " { os . path . join ( embedding_path , files ) } " , text_encoder , tokenizer , f " < { token_name } > " )
else :
embedding_path = os . path . join ( st . session_state [ ' defaults ' ] . general . sd_concepts_library_folder , prompt_tokens [ 0 ] )
if os . path . exists ( embedding_path ) :
for files in os . listdir ( embedding_path ) :
if files . endswith ( ext ) :
2022-10-06 11:45:24 +03:00
load_learned_embed_in_clip ( f " { os . path . join ( embedding_path , files ) } " , text_encoder , tokenizer , f " < { prompt_tokens [ 0 ] } > " )
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
# sample a source
2022-09-25 09:28:02 +03:00
init1 = torch . randn ( ( 1 , server_state [ " pipe " ] . unet . in_channels , height / / 8 , width / / 8 ) , device = torch_device )
2022-09-14 00:08:40 +03:00
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
# iterate the loop
frames = [ ]
frame_index = 0
2022-09-18 01:25:55 +03:00
2022-10-12 10:08:38 +03:00
second_count = 1
2022-09-14 19:24:54 +03:00
st . session_state [ " total_frames_avg_duration " ] = [ ]
2022-09-18 01:25:55 +03:00
st . session_state [ " total_frames_avg_speed " ] = [ ]
2022-09-14 00:08:40 +03:00
try :
2022-10-23 13:00:46 +03:00
while second_count < max_duration_in_seconds :
2022-09-14 00:08:40 +03:00
st . session_state [ " frame_duration " ] = 0
2022-09-18 01:25:55 +03:00
st . session_state [ " frame_speed " ] = 0
2022-09-14 00:08:40 +03:00
st . session_state [ " current_frame " ] = frame_index
2022-10-23 13:00:46 +03:00
#print(f"Second: {second_count+1}/{max_duration_in_seconds}")
2022-10-12 10:08:38 +03:00
2022-09-14 00:08:40 +03:00
# sample the destination
2022-09-25 09:28:02 +03:00
init2 = torch . randn ( ( 1 , server_state [ " pipe " ] . unet . in_channels , height / / 8 , width / / 8 ) , device = torch_device )
2022-09-14 00:08:40 +03:00
2022-09-20 16:00:08 +03:00
for i , t in enumerate ( np . linspace ( 0 , 1 , num_steps ) ) :
2022-09-14 00:08:40 +03:00
start = timeit . default_timer ( )
2022-10-23 13:29:17 +03:00
logger . info ( f " COUNT: { frame_index + 1 } / { num_steps } " )
2022-09-18 01:25:55 +03:00
2022-10-12 10:08:38 +03:00
if use_lerp_for_text :
init = torch . lerp ( init1 , init2 , float ( t ) )
else :
init = slerp ( gpu , float ( t ) , init1 , init2 )
2022-09-18 01:25:55 +03:00
2022-10-12 10:08:38 +03:00
#init = slerp(gpu, float(t), init1, init2)
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
with autocast ( " cuda " ) :
2022-09-25 09:28:02 +03:00
image = diffuse ( server_state [ " pipe " ] , cond_embeddings , init , num_inference_steps , cfg_scale , eta )
2022-10-06 11:45:24 +03:00
2022-09-26 07:23:32 +03:00
if st . session_state [ " save_individual_images " ] and not st . session_state [ " use_GFPGAN " ] and not st . session_state [ " use_RealESRGAN " ] :
2022-09-19 01:17:02 +03:00
#im = Image.fromarray(image)
outpath = os . path . join ( full_path , ' frame %06d .png ' % frame_index )
image . save ( outpath , quality = quality )
2022-10-06 11:45:24 +03:00
2022-09-19 01:17:02 +03:00
# send the image to the UI to update it
#st.session_state["preview_image"].image(im)
2022-10-06 11:45:24 +03:00
2022-09-19 01:17:02 +03:00
#append the frames to the frames list so we can use them later.
frames . append ( np . asarray ( image ) )
2022-10-06 11:45:24 +03:00
2022-09-19 01:17:02 +03:00
#
#try:
2022-09-26 07:23:32 +03:00
#if st.session_state["use_GFPGAN"] and server_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
if st . session_state [ " use_GFPGAN " ] and server_state [ " GFPGAN " ] is not None :
2022-09-19 01:17:02 +03:00
#print("Running GFPGAN on image ...")
2022-10-14 22:09:47 +03:00
if " progress_bar_text " in st . session_state :
st . session_state [ " progress_bar_text " ] . text ( " Running GFPGAN on image ... " )
2022-09-19 01:17:02 +03:00
#skip_save = True # #287 >_>
torch_gc ( )
2022-09-25 09:28:02 +03:00
cropped_faces , restored_faces , restored_img = server_state [ " GFPGAN " ] . enhance ( np . array ( image ) [ : , : , : : - 1 ] , has_aligned = False , only_center_face = False , paste_back = True )
2022-09-19 01:17:02 +03:00
gfpgan_sample = restored_img [ : , : , : : - 1 ]
gfpgan_image = Image . fromarray ( gfpgan_sample )
2022-10-06 11:45:24 +03:00
2022-09-19 01:17:02 +03:00
outpath = os . path . join ( full_path , ' frame %06d .png ' % frame_index )
gfpgan_image . save ( outpath , quality = quality )
2022-10-06 11:45:24 +03:00
2022-09-19 01:17:02 +03:00
#append the frames to the frames list so we can use them later.
2022-10-06 11:45:24 +03:00
frames . append ( np . asarray ( gfpgan_image ) )
2022-10-08 09:46:23 +03:00
try :
st . session_state [ " preview_image " ] . image ( gfpgan_image )
except KeyError :
2022-10-15 15:34:07 +03:00
logger . error ( " Cant get session_state, skipping image preview. " )
2022-10-08 09:46:23 +03:00
#except (AttributeError, KeyError):
2022-09-19 01:17:02 +03:00
#print("Cant perform GFPGAN, skipping.")
2022-10-06 11:45:24 +03:00
2022-09-14 00:08:40 +03:00
#increase frame_index counter.
frame_index + = 1
st . session_state [ " current_frame " ] = frame_index
duration = timeit . default_timer ( ) - start
if duration > = 1 :
speed = " s/it "
else :
speed = " it/s "
2022-09-18 01:25:55 +03:00
duration = 1 / duration
2022-09-14 00:08:40 +03:00
st . session_state [ " frame_duration " ] = duration
st . session_state [ " frame_speed " ] = speed
init1 = init2
2022-10-08 09:46:23 +03:00
# save the video after the generation is done.
video_path = save_video_to_disk ( frames , seeds , sanitized_prompt , save_video = save_video , outdir = outdir )
2022-09-18 01:25:55 +03:00
2022-10-08 09:46:23 +03:00
except StopException :
if save_video_on_stop :
2022-10-15 15:34:07 +03:00
logger . info ( " Streamlit Stop Exception Received. Saving video " )
2022-10-08 09:46:23 +03:00
video_path = save_video_to_disk ( frames , seeds , sanitized_prompt , save_video = save_video , outdir = outdir )
else :
video_path = None
2022-09-18 01:25:55 +03:00
2022-10-08 09:46:23 +03:00
if video_path and " preview_video " in st . session_state :
2022-09-14 00:08:40 +03:00
# show video preview on the UI
st . session_state [ " preview_video " ] . video ( open ( video_path , ' rb ' ) . read ( ) )
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00
mem_max_used , mem_total = mem_mon . read_and_stop ( )
2022-09-18 01:25:55 +03:00
time_diff = time . time ( ) - start
2022-09-14 00:08:40 +03:00
info = f """
2022-09-18 01:25:55 +03:00
{ prompts }
2022-10-23 13:29:17 +03:00
Sampling Steps : { num_steps } , Sampler : { scheduler } , CFG scale : { cfg_scale } , Seed : { seeds } , Max Duration In Seconds : { max_duration_in_seconds } """ .strip()
2022-09-14 00:08:40 +03:00
stats = f '''
2022-10-23 13:00:46 +03:00
Took { round ( time_diff , 2 ) } s total ( { round ( time_diff / ( max_duration_in_seconds ) , 2 ) } s per image )
2022-09-18 01:25:55 +03:00
Peak memory usage : { - ( mem_max_used / / - 1_048_576 ) } MiB / { - ( mem_total / / - 1_048_576 ) } MiB / { round ( mem_max_used / mem_total * 100 , 3 ) } % '''
2022-09-14 00:08:40 +03:00
2022-09-15 16:29:41 +03:00
return video_path , seeds , info , stats
2022-09-14 00:08:40 +03:00
#
def layout():
    """Render the txt2vid tab and run a generation when the form is submitted.

    The page is a single Streamlit form ("txt2vid-inputs") holding the prompt
    box, the "Generate" button and three columns:

    * col1 - size, CFG scale, seed, max duration and preview settings.
    * col2 - "Preview"/"Gallery" tabs; the preview tab creates the empty
      placeholders (image, progress bar, video, ...) that are stored in
      ``st.session_state`` so they can be updated later.
    * col3 - model, step and scheduler selection plus the "Advanced" options
      (output settings, post-processing, variants).

    Most widget values are written into ``st.session_state``; a few are kept
    as locals and passed straight to ``txt2vid``.  When the form is submitted
    GFPGAN is loaded into (or removed from) ``server_state`` depending on the
    "Use GFPGAN" option, ``txt2vid`` is called, and the resulting info/stats
    are shown in the preview column.

    Returns:
        None.
    """
    # Shorthands for the configuration nodes read throughout this function.
    defaults = st.session_state['defaults']
    txt2vid_cfg = defaults.txt2vid
    general_cfg = defaults.general

    with st.form("txt2vid-inputs"):
        st.session_state["generation_mode"] = "txt2vid"

        input_col1, generate_col1 = st.columns([10, 1])
        with input_col1:
            placeholder = "A corgi wearing a top hat as an oil painting."
            prompt = st.text_area("Input Text", "", placeholder=placeholder, height=54)
            key_phrase_suggestions.suggestion_area(placeholder)

        # Every form must have a submit button, the extra blank spaces is a temp way
        # to align it with the input field. Needs to be done in CSS or some other way.
        generate_col1.write("")
        generate_col1.write("")
        generate_button = generate_col1.form_submit_button("Generate")

        # creating the page layout using columns
        col1, col2, col3 = st.columns([1, 2, 1], gap="large")

        with col1:
            width = st.slider("Width:",
                              min_value=txt2vid_cfg.width.min_value,
                              max_value=txt2vid_cfg.width.max_value,
                              value=txt2vid_cfg.width.value,
                              step=txt2vid_cfg.width.step)
            height = st.slider("Height:",
                               min_value=txt2vid_cfg.height.min_value,
                               max_value=txt2vid_cfg.height.max_value,
                               value=txt2vid_cfg.height.value,
                               step=txt2vid_cfg.height.step)
            cfg_scale = st.number_input("CFG (Classifier Free Guidance Scale):",
                                        min_value=txt2vid_cfg.cfg_scale.min_value,
                                        value=txt2vid_cfg.cfg_scale.value,
                                        step=txt2vid_cfg.cfg_scale.step,
                                        help="How strongly the image should follow the prompt.")

            seed = st.text_input("Seed:", value=txt2vid_cfg.seed,
                                 help="The seed to use, if left blank a random seed will be generated.")

            st.session_state["max_duration_in_seconds"] = st.number_input(
                "Max Duration In Seconds:",
                value=txt2vid_cfg.max_duration_in_seconds,
                help="Specify the max duration in seconds you want your video to be.")

            with st.expander("Preview Settings"):
                # The "Update Image Preview" checkbox is not shown on this tab;
                # the general default is used instead.
                st.session_state["update_preview"] = general_cfg.update_preview
                st.session_state["update_preview_frequency"] = st.number_input(
                    "Update Image Preview Frequency",
                    min_value=0,
                    value=txt2vid_cfg.update_preview_frequency,
                    help="Frequency in steps at which the the preview image is updated. "
                         "By default the frequency is set to 1 step.")

                st.session_state["dynamic_preview_frequency"] = st.checkbox(
                    "Dynamic Preview Frequency",
                    value=txt2vid_cfg.dynamic_preview_frequency,
                    help="This option tries to find the best value at which we can update "
                         "the preview image during generation while minimizing the impact "
                         "it has in performance. Default: True")

        with col2:
            preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])

            with preview_tab:
                # create an empty container for the image, progress bar, etc so we can
                # update it later and use session_state to hold them globally.
                st.session_state["preview_image"] = st.empty()
                st.session_state["loading"] = st.empty()
                st.session_state["progress_bar_text"] = st.empty()
                st.session_state["progress_bar"] = st.empty()
                st.session_state["preview_video"] = st.empty()
                # Keep a local handle too; see the note next to its use below.
                preview_video = st.session_state["preview_video"]

                message = st.empty()

            with gallery_tab:
                st.write('Here should be the image gallery, if I could make a grid in streamlit.')

        with col3:
            # If we have custom models available on the "models/custom" folder then we
            # show a menu to select which model we want to use, otherwise we use the
            # main model for SD.
            custom_models_available()
            if server_state["CustomModel_available"]:
                custom_model = st.selectbox(
                    "Custom Model:", txt2vid_cfg.custom_models_list,
                    index=txt2vid_cfg.custom_models_list.index(txt2vid_cfg.default_model),
                    help="Select the model you want to use. This option is only available "
                         "if you have custom models on your 'models/custom' folder. "
                         "The model name that will be shown here is the same as the name "
                         "the file for the model has on said folder, it is recommended to "
                         "give the .ckpt file a name that will make it easier for you to "
                         "distinguish it from other models. Default: Stable Diffusion v1.5")
            else:
                custom_model = "runwayml/stable-diffusion-v1-5"

            st.session_state.sampling_steps = st.number_input(
                "Sampling Steps",
                value=txt2vid_cfg.sampling_steps.value,
                min_value=txt2vid_cfg.sampling_steps.min_value,
                step=txt2vid_cfg.sampling_steps.step,
                help="Number of steps between each pair of sampled points")

            st.session_state.num_inference_steps = st.number_input(
                "Inference Steps:",
                value=txt2vid_cfg.num_inference_steps.value,
                min_value=txt2vid_cfg.num_inference_steps.min_value,
                step=txt2vid_cfg.num_inference_steps.step,
                help="Higher values (e.g. 100, 200 etc) can create better images.")

            scheduler_name_list = ["klms", "ddim"]
            scheduler_name = st.selectbox(
                "Scheduler:", scheduler_name_list,
                index=scheduler_name_list.index(txt2vid_cfg.scheduler_name),
                help="Scheduler to use. Default: klms")

            beta_scheduler_type_list = ["scaled_linear", "linear"]
            beta_scheduler_type = st.selectbox(
                "Beta Schedule Type:", beta_scheduler_type_list,
                index=beta_scheduler_type_list.index(txt2vid_cfg.beta_scheduler_type),
                help="Schedule Type to use. Default: linear")

            with st.expander("Advanced"):
                with st.expander("Output Settings"):
                    st.session_state["separate_prompts"] = st.checkbox(
                        "Create Prompt Matrix.",
                        value=txt2vid_cfg.separate_prompts,
                        help="Separate multiple prompts using the `|` character, and get all combinations of them.")

                    st.session_state["normalize_prompt_weights"] = st.checkbox(
                        "Normalize Prompt Weights.",
                        value=txt2vid_cfg.normalize_prompt_weights,
                        help="Ensure the sum of all weights add up to 1.0")

                    st.session_state["save_individual_images"] = st.checkbox(
                        "Save individual images.",
                        value=txt2vid_cfg.save_individual_images,
                        help="Save each image generated before any filter or enhancement is applied.")

                    st.session_state["save_video"] = st.checkbox(
                        "Save video",
                        value=txt2vid_cfg.save_video,
                        help="Save a video with all the images generated as frames at the end of the generation.")

                    save_video_on_stop = st.checkbox(
                        "Save video on Stop",
                        value=txt2vid_cfg.save_video_on_stop,
                        help="Save a video with all the images generated as frames when we hit the stop button during a generation.")

                    st.session_state["group_by_prompt"] = st.checkbox(
                        "Group results by prompt",
                        value=txt2vid_cfg.group_by_prompt,
                        help="Saves all the images with the same prompt into the same folder. "
                             "When using a prompt matrix each prompt combination will have its own folder.")

                    st.session_state["write_info_files"] = st.checkbox(
                        "Write Info file",
                        value=txt2vid_cfg.write_info_files,
                        help="Save a file next to the image with informartion about the generation.")

                    st.session_state["do_loop"] = st.checkbox(
                        "Do Loop",
                        value=txt2vid_cfg.do_loop,
                        help="Loop the prompt making two prompts from a single one.")

                    st.session_state["use_lerp_for_text"] = st.checkbox(
                        "Use Lerp Instead of Slerp",
                        value=txt2vid_cfg.use_lerp_for_text,
                        help="Uses torch.lerp() instead of slerp. When interpolating between related prompts. "
                             "e.g. 'a lion in a grassy meadow' -> 'a bear in a grassy meadow' tends to keep the meadow "
                             "the whole way through when lerped, but slerping will often find a path where the meadow "
                             "disappears in the middle")

                    st.session_state["save_as_jpg"] = st.checkbox(
                        "Save samples as jpg",
                        value=txt2vid_cfg.save_as_jpg,
                        help="Saves the images as jpg instead of png.")

                # Run the availability checks only once per session; the flags
                # they are expected to set are read from session_state below.
                if "GFPGAN_available" not in st.session_state:
                    GFPGAN_available()

                if "RealESRGAN_available" not in st.session_state:
                    RealESRGAN_available()

                if "LDSR_available" not in st.session_state:
                    LDSR_available()

                if (st.session_state["GFPGAN_available"]
                        or st.session_state["RealESRGAN_available"]
                        or st.session_state["LDSR_available"]):
                    with st.expander("Post-Processing"):
                        face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])

                        with face_restoration_tab:
                            # GFPGAN used for face restoration
                            if st.session_state["GFPGAN_available"]:
                                st.session_state["use_GFPGAN"] = st.checkbox(
                                    "Use GFPGAN",
                                    value=txt2vid_cfg.use_GFPGAN,
                                    help="Uses the GFPGAN model to improve faces after the generation. "
                                         "This greatly improve the quality and consistency of faces but uses "
                                         "extra VRAM. Disable if you need the extra VRAM.")

                                st.session_state["GFPGAN_model"] = st.selectbox(
                                    "GFPGAN model", st.session_state["GFPGAN_models"],
                                    index=st.session_state["GFPGAN_models"].index(general_cfg.GFPGAN_model))
                            else:
                                st.session_state["use_GFPGAN"] = False

                        with upscaling_tab:
                            # NOTE: the key is spelled 'us_upscaling' (sic). It is kept as-is
                            # because other modules may read it - verify before renaming.
                            st.session_state['us_upscaling'] = st.checkbox(
                                "Use Upscaling", value=txt2vid_cfg.use_upscaling)

                            # RealESRGAN and LDSR used for upscaling.
                            if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:

                                upscaling_method_list = []
                                if st.session_state["RealESRGAN_available"]:
                                    upscaling_method_list.append("RealESRGAN")
                                if st.session_state["LDSR_available"]:
                                    upscaling_method_list.append("LDSR")

                                # The configured default may name an upscaler that is not
                                # installed; fall back to the first available one instead of
                                # letting list.index() raise ValueError.
                                default_method = general_cfg.upscaling_method
                                if default_method in upscaling_method_list:
                                    default_method_index = upscaling_method_list.index(default_method)
                                else:
                                    default_method_index = 0

                                st.session_state["upscaling_method"] = st.selectbox(
                                    "Upscaling Method", upscaling_method_list,
                                    index=default_method_index)

                                if st.session_state["RealESRGAN_available"]:
                                    with st.expander("RealESRGAN"):
                                        st.session_state["use_RealESRGAN"] = bool(
                                            st.session_state["upscaling_method"] == "RealESRGAN"
                                            and st.session_state['us_upscaling'])

                                        st.session_state["RealESRGAN_model"] = st.selectbox(
                                            "RealESRGAN model", st.session_state["RealESRGAN_models"],
                                            index=st.session_state["RealESRGAN_models"].index(general_cfg.RealESRGAN_model))
                                else:
                                    st.session_state["use_RealESRGAN"] = False
                                    st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"

                                if st.session_state["LDSR_available"]:
                                    with st.expander("LDSR"):
                                        st.session_state["use_LDSR"] = bool(
                                            st.session_state["upscaling_method"] == "LDSR"
                                            and st.session_state['us_upscaling'])

                                        st.session_state["LDSR_model"] = st.selectbox(
                                            "LDSR model", st.session_state["LDSR_models"],
                                            index=st.session_state["LDSR_models"].index(general_cfg.LDSR_model))

                                        st.session_state["ldsr_sampling_steps"] = st.number_input(
                                            "Sampling Steps",
                                            value=txt2vid_cfg.LDSR_config.sampling_steps,
                                            help="")

                                        st.session_state["preDownScale"] = st.number_input(
                                            "PreDownScale",
                                            value=txt2vid_cfg.LDSR_config.preDownScale,
                                            help="")

                                        st.session_state["postDownScale"] = st.number_input(
                                            "postDownScale",
                                            value=txt2vid_cfg.LDSR_config.postDownScale,
                                            help="")

                                        downsample_method_list = ['Nearest', 'Lanczos']
                                        st.session_state["downsample_method"] = st.selectbox(
                                            "Downsample Method", downsample_method_list,
                                            index=downsample_method_list.index(txt2vid_cfg.LDSR_config.downsample_method))
                                else:
                                    st.session_state["use_LDSR"] = False
                                    st.session_state["LDSR_model"] = "model"

                with st.expander("Variant"):
                    st.session_state["variant_amount"] = st.number_input(
                        "Variant Amount:",
                        value=txt2vid_cfg.variant_amount.value,
                        min_value=txt2vid_cfg.variant_amount.min_value,
                        max_value=txt2vid_cfg.variant_amount.max_value,
                        step=txt2vid_cfg.variant_amount.step)

                    st.session_state["variant_seed"] = st.text_input(
                        "Variant Seed:", value=txt2vid_cfg.seed,
                        help="The seed to use when generating a variant, if left blank a random seed will be generated.")

                # beta_start / beta_end have no widget on this tab; the configured
                # default values are passed to txt2vid() below.

        if generate_button:
            # Load GFPGAN when we hit the generate button for the first time; it is kept
            # in server_state afterwards so it is not loaded again.
            # "use_GFPGAN" is only written by the Post-Processing widgets above, so
            # read it defensively in case that section was not rendered.
            if st.session_state.get("use_GFPGAN", False):
                if "GFPGAN" in server_state:
                    logger.info("GFPGAN already loaded")
                else:
                    with col2:
                        with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders, index=[0]):
                            # Load GFPGAN
                            if os.path.exists(general_cfg.GFPGAN_dir):
                                try:
                                    load_GFPGAN()
                                    logger.info("Loaded GFPGAN")
                                except Exception:
                                    # Best effort: log the error with its traceback and
                                    # carry on with the generation.
                                    logger.exception("Error loading GFPGAN:")
            else:
                # GFPGAN is not wanted for this run: drop any loaded instance.
                if "GFPGAN" in server_state:
                    del server_state["GFPGAN"]

            # run video generation
            video, seed, info, stats = txt2vid(
                prompts=prompt, gpu=general_cfg.gpu,
                num_steps=st.session_state.sampling_steps,
                max_duration_in_seconds=st.session_state.max_duration_in_seconds,
                num_inference_steps=st.session_state.num_inference_steps,
                cfg_scale=cfg_scale, save_video_on_stop=save_video_on_stop,
                outdir=general_cfg.outdir,
                do_loop=st.session_state["do_loop"],
                use_lerp_for_text=st.session_state["use_lerp_for_text"],
                seeds=seed, quality=100, eta=0.0, width=width,
                height=height, weights_path=custom_model, scheduler=scheduler_name,
                disable_tqdm=False,
                beta_start=txt2vid_cfg.beta_start.value,
                beta_end=txt2vid_cfg.beta_end.value,
                beta_schedule=beta_scheduler_type, starting_image=None)

            if video and save_video_on_stop:
                # show video preview on the UI after we hit the stop button
                # currently not working as session_state is cleared on StopException
                with open(video, 'rb') as video_file:
                    preview_video.video(video_file.read())

            message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")

            # TODO: the history gallery ("latestVideos" / gallery_tab) is not
            # implemented for this tab yet.
2022-09-18 01:25:55 +03:00
2022-09-14 00:08:40 +03:00