Added a progress bar and updated the Image-to-image tab. (#652)

* Added a progress bar and some extra status info, so you can see how the generation is going without having to check the console every time.

* - Updated the Image-to-Image tab; it is now working at a basic level.
- Disabled RealESRGAN by default for the Image-to-Image tab as it is not working right now.
This commit is contained in:
ZeroCool 2022-09-04 20:53:52 -07:00 committed by GitHub
parent 75a7ef77f0
commit 78ad3c3445
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 157 additions and 103 deletions

View File

@ -88,6 +88,8 @@ os.environ["CUDA_VISIBLE_DEVICES"] = str(defaults.general.gpu)
def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus"): def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus"):
"""Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """ """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
print ("Loading models.")
# Generate random run ID # Generate random run ID
# Used to link runs linked w/ continue_prev_run which is not yet implemented # Used to link runs linked w/ continue_prev_run which is not yet implemented
# Use URL and filesystem safe version just in case. # Use URL and filesystem safe version just in case.
@ -180,8 +182,13 @@ def load_sd_from_config(ckpt, verbose=False):
def generation_callback(img, i=0): def generation_callback(img, i=0):
try:
if i == 0: if i == 0:
if img['i']: i = img['i'] if img['i']: i = img['i']
except TypeError:
pass
if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview: if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview:
#print (img) #print (img)
@ -200,8 +207,21 @@ def generation_callback(img, i=0):
pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0))
# update image on the UI so we can see the progress
st.session_state["preview_image"].image(pil_image, width=512) st.session_state["preview_image"].image(pil_image, width=512)
# Show a progress bar so we can keep track of the progress even when the image progress is not being shown.
# Don't worry, it doesn't affect the performance.
if st.session_state["generation_mode"] == "txt2img":
percent = int(100 * float(i+1)/float(st.session_state.sampling_steps))
st.session_state["progress_bar_text"].text(f"Running step: {i+1}/{st.session_state.sampling_steps} {percent}%")
else:
percent = int(100 * float(i+1 )/float(round(st.session_state.sampling_steps * st.session_state["denoising_strength"])))
st.session_state["progress_bar_text"].text(f"""Running step: {i+1}/{round(st.session_state.sampling_steps * st.session_state["denoising_strength"])} {percent}%""")
st.session_state["progress_bar"] = st.session_state["progress_bar"].progress(percent)
class MemUsageMonitor(threading.Thread): class MemUsageMonitor(threading.Thread):
stop_flag = False stop_flag = False
@ -401,7 +421,7 @@ def load_RealESRGAN(model_name: str):
} }
model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path): if not os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")):
raise Exception(model_name+".pth not found at path "+model_path) raise Exception(model_name+".pth not found at path "+model_path)
sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir)) sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir))
@ -505,27 +525,6 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
torch_gc() torch_gc()
def run_GFPGAN(image, strength):
    """Run the module-level ``GFPGAN`` enhancer on ``image`` and return the result.

    The input is converted to RGB, handed to ``GFPGAN.enhance`` as a uint8
    array, and the restored array is turned back into an image.

    Args:
        image: Object exposing ``convert("RGB")`` (presumably a PIL image --
            confirm against callers).
        strength: Blend factor. Values below 1.0 mix the restored image with
            the RGB input via ``Image.blend``; otherwise the restored image
            is returned unblended.

    Returns:
        The restored (and possibly blended) image.
    """
    rgb_input = image.convert("RGB")
    pixels = np.array(rgb_input, dtype=np.uint8)

    # enhance() yields three values; only the full restored image is used here.
    _, _, restored_pixels = GFPGAN.enhance(
        pixels,
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    restored = Image.fromarray(restored_pixels)

    if strength < 1.0:
        return Image.blend(rgb_input, restored, strength)
    return restored
def run_RealESRGAN(image, model_name: str):
    """Run the module-level ``RealESRGAN`` enhancer on ``image`` and return the result.

    If the loaded model's name differs from ``model_name``,
    ``try_loading_RealESRGAN`` is called with the requested name first.

    Args:
        image: Object exposing ``convert("RGB")`` (presumably a PIL image --
            confirm against callers).
        model_name: Name of the RealESRGAN model that should be active.

    Returns:
        The image built from the first value returned by ``RealESRGAN.enhance``.
    """
    needs_reload = RealESRGAN.model.name != model_name
    if needs_reload:
        try_loading_RealESRGAN(model_name)

    pixels = np.array(image.convert("RGB"), dtype=np.uint8)

    # enhance() yields (output, img_mode); the image mode is not needed here.
    upscaled, _ = RealESRGAN.enhance(pixels)
    return Image.fromarray(upscaled)
def get_font(fontsize): def get_font(fontsize):
fonts = ["arial.ttf", "DejaVuSans.ttf"] fonts = ["arial.ttf", "DejaVuSans.ttf"]
@ -957,8 +956,11 @@ def process_images(
if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN: if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN:
skip_save = True # #287 >_> skip_save = True # #287 >_>
torch_gc() torch_gc()
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name) #try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1]) output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1])
esrgan_filename = original_filename + '-esrgan4x' esrgan_filename = original_filename + '-esrgan4x'
esrgan_sample = output[:,:,::-1] esrgan_sample = output[:,:,::-1]
@ -980,8 +982,11 @@ def process_images(
torch_gc() torch_gc()
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1] gfpgan_sample = restored_img[:,:,::-1]
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name) #try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1]) output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
gfpgan_esrgan_sample = output[:,:,::-1] gfpgan_esrgan_sample = output[:,:,::-1]
@ -1102,7 +1107,8 @@ def img2img(prompt: str = '', init_info: any = None, ddim_steps: int = 50, sampl
write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
separate_prompts:bool = False, normalize_prompt_weights:bool = True, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True): save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = True
):
outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples" outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples"
err = False err = False
@ -1196,7 +1202,10 @@ def img2img(prompt: str = '', init_info: any = None, ddim_steps: int = 50, sampl
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False) samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
callback=generation_callback)
else: else:
x0, z_mask = init_data x0, z_mask = init_data
@ -1448,7 +1457,7 @@ def layout():
else: else:
GFPGAN_available = False GFPGAN_available = False
if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", "RealESRGAN_x4plus.pth")): if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{defaults.general.RealESRGAN_model}.pth")):
RealESRGAN_available = True RealESRGAN_available = True
else: else:
RealESRGAN_available = False RealESRGAN_available = False
@ -1468,6 +1477,8 @@ def layout():
txt2img_tab, img2img_tab, postprocessing_tab = st.tabs(["Stable Diffusion Text-to-Image Unified", "Stable Diffusion Image-to-Image Unified", "Post-Processing"]) txt2img_tab, img2img_tab, postprocessing_tab = st.tabs(["Stable Diffusion Text-to-Image Unified", "Stable Diffusion Image-to-Image Unified", "Post-Processing"])
with txt2img_tab: with txt2img_tab:
st.session_state["generation_mode"] = "txt2img"
with st.form("txt2img-inputs"): with st.form("txt2img-inputs"):
input_col1, generate_col1 = st.columns([10,1]) input_col1, generate_col1 = st.columns([10,1])
@ -1484,8 +1495,8 @@ def layout():
col1, col2, col3 = st.columns([1,2,1], gap="large") col1, col2, col3 = st.columns([1,2,1], gap="large")
with col1: with col1:
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.txt2img.height, step=64)
width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.txt2img.width, step=64) width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.txt2img.width, step=64)
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.txt2img.height, step=64)
cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.") seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
batch_count = st.slider("Batch count.", min_value=1, max_value=500, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") batch_count = st.slider("Batch count.", min_value=1, max_value=500, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
@ -1498,26 +1509,33 @@ def layout():
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
with preview_tab: with preview_tab:
st.write("Image") #st.write("Image")
#Image for testing #Image for testing
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw) #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw)
#new_image = image.resize((175, 240)) #new_image = image.resize((175, 240))
#preview_image = st.image(image) #preview_image = st.image(image)
# create an empty container for the image and use session_state to hold it globally. # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
preview_image = st.empty() st.session_state["preview_image"] = st.empty()
st.session_state["preview_image"] = preview_image
st.session_state["loading"] = st.empty()
st.session_state["progress_bar_text"] = st.empty()
st.session_state["progress_bar"] = st.empty()
message = st.empty()
with gallery_tab: with gallery_tab:
st.write('Here should be the image gallery, if I could make a grid in streamlit.') st.write('Here should be the image gallery, if I could make a grid in streamlit.')
with col3: with col3:
sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250) st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250)
sampler_name = st.selectbox("Sampling method", sampler_name = st.selectbox("Sampling method",
["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"], ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"],
index=0, help="Sampling method to use. Default: k_lms") index=0, help="Sampling method to use. Default: k_lms")
#basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
#with basic_tab: #with basic_tab:
@ -1556,10 +1574,13 @@ def layout():
load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model) load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model)
try: try:
output_images, seed, info, stats = txt2img(prompt, sampling_steps, sampler_name, RealESRGAN_model, batch_count, batch_size, output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, batch_size,
cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images, cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp, save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp,
variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files) variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files)
message.success('Done!', icon="")
except (StopException, KeyError): except (StopException, KeyError):
print(f"Received Streamlit StopException") print(f"Received Streamlit StopException")
@ -1568,6 +1589,8 @@ def layout():
#preview_image.image(output_images, width=750) #preview_image.image(output_images, width=750)
with img2img_tab: with img2img_tab:
st.session_state["generation_mode"] = "img2img"
with st.form("img2img-inputs"): with st.form("img2img-inputs"):
img2img_input_col, img2img_generate_col = st.columns([10,1]) img2img_input_col, img2img_generate_col = st.columns([10,1])
@ -1593,8 +1616,8 @@ def layout():
help="Upload an image which will be used for the image to image generation." help="Upload an image which will be used for the image to image generation."
) )
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.img2img.height, step=64)
width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.img2img.width, step=64) width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.img2img.width, step=64)
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.img2img.height, step=64)
seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.") seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
batch_count = st.slider("Batch count.", min_value=1, max_value=500, value=defaults.img2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") batch_count = st.slider("Batch count.", min_value=1, max_value=500, value=defaults.img2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
@ -1633,27 +1656,43 @@ def layout():
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
Default: 1") Default: 1")
st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=defaults.img2img.denoising_strength, min_value=0.0, max_value=1.0, step=0.01)
with col2_img2img_layout: with col2_img2img_layout:
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) editor_tab = st.tabs(["Editor"])
with preview_tab: editor_image = st.empty()
st.write("Image") st.session_state["editor_image"] = editor_image
#Image for testing
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw)
#new_image = image.resize((175, 240))
#preview_image = st.image(image)
# create an empty container for the image and use session_state to hold it globally. if uploaded_images:
image = Image.open(uploaded_images)
#img_array = np.array(image) # if you want to pass it to OpenCV
new_img = image.resize((width, height))
st.image(new_img)
with col3_img2img_layout:
result_tab = st.tabs(["Result"])
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
preview_image = st.empty() preview_image = st.empty()
st.session_state["preview_image"] = preview_image st.session_state["preview_image"] = preview_image
with gallery_tab: #st.session_state["loading"] = st.empty()
st.write('Here should be the image gallery, if I could make a grid in streamlit.')
st.session_state["progress_bar_text"] = st.empty()
st.session_state["progress_bar"] = st.empty()
message = st.empty()
#if uploaded_images: #if uploaded_images:
#image = Image.open(uploaded_images) #image = Image.open(uploaded_images)
# img_array = np.array(image) # if you want to pass it to OpenCV ##img_array = np.array(image) # if you want to pass it to OpenCV
# st.image(image, use_column_width=True) #new_img = image.resize((width, height))
#st.image(new_img, use_column_width=True)
if generate_button: if generate_button:
#print("Loading models") #print("Loading models")
@ -1661,9 +1700,24 @@ def layout():
load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model) load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model)
if uploaded_images: if uploaded_images:
image = Image.open(uploaded_images) image = Image.open(uploaded_images)
img_array = np.array(image) # if you want to pass it to OpenCV new_img = image.resize((width, height))
#img_array = np.array(image) # if you want to pass it to OpenCV
try: try:
output_images, seed, info, stats = img2img(prompt=prompt, init_info=image, ddim_steps=sampling_steps, sampler_name=sampler_name, n_iter=batch_count) #output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, ddim_steps=sampling_steps, sampler_name=sampler_name, n_iter=batch_count)
output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, ddim_steps=sampling_steps, sampler_name=sampler_name, n_iter=batch_count,
cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
seed=seed, width=width, height=height, fp=defaults.general.fp, variant_amount=variant_amount,
ddim_eta=0.0, write_info_files=True, RealESRGAN_model="RealESRGAN_x4plus_anime_6B",
separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
save_individual_images=save_individual_images, save_grid=save_grid,
group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN,
use_RealESRGAN=use_RealESRGAN
)
#show a message when the generation is complete.
message.success('Done!', icon="")
except (StopException, KeyError): except (StopException, KeyError):
print(f"Received Streamlit StopException") print(f"Received Streamlit StopException")

View File

@ -6,8 +6,8 @@ general:
ckpt: "models/ldm/stable-diffusion-v1/model.ckpt" ckpt: "models/ldm/stable-diffusion-v1/model.ckpt"
fp: fp:
name: 'embeddings/alex/embeddings_gs-11000.pt' name: 'embeddings/alex/embeddings_gs-11000.pt'
GFPGAN_dir: "./src/GFPGAN" GFPGAN_dir: "./src/gfpgan"
RealESRGAN_dir: "./src/RealESRGAN" RealESRGAN_dir: "./src/realesrgan"
RealESRGAN_model: "RealESRGAN_x4plus" RealESRGAN_model: "RealESRGAN_x4plus"
outdir_txt2img: outputs/txt2img-samples outdir_txt2img: outputs/txt2img-samples
outdir_img2img: outputs/img2img-samples outdir_img2img: outputs/img2img-samples
@ -91,7 +91,7 @@ img2img:
group_by_prompt: True group_by_prompt: True
save_as_jpg: False save_as_jpg: False
use_GFPGAN: True use_GFPGAN: True
use_RealESRGAN: True use_RealESRGAN: False
RealESRGAN_model: "RealESRGAN_x4plus" RealESRGAN_model: "RealESRGAN_x4plus"
variant_amount: 0.0 variant_amount: 0.0