Added a progress bar and updated the Image-to-image tab. (#652)

* Added a progress bar as well as some extra info to know how the generation is going without having to check the console every time.

* - Updated the Image-to-Image tab; it now works at a basic level.
- Disabled RealESRGAN by default for the Image-to-Image tab, as it is not working right now.
This commit is contained in:
ZeroCool 2022-09-04 20:53:52 -07:00 committed by GitHub
parent 75a7ef77f0
commit 78ad3c3445
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 157 additions and 103 deletions

View File

@ -88,6 +88,8 @@ os.environ["CUDA_VISIBLE_DEVICES"] = str(defaults.general.gpu)
def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus"):
"""Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
print ("Loading models.")
# Generate random run ID
# Used to link runs linked w/ continue_prev_run which is not yet implemented
# Use URL and filesystem safe version just in case.
@ -180,8 +182,13 @@ def load_sd_from_config(ckpt, verbose=False):
def generation_callback(img, i=0):
try:
if i == 0:
if img['i']: i = img['i']
except TypeError:
pass
if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview:
#print (img)
@ -200,8 +207,21 @@ def generation_callback(img, i=0):
pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0))
# update image on the UI so we can see the progress
st.session_state["preview_image"].image(pil_image, width=512)
# Show a progress bar so we can keep track of the progress even when the image preview is not being shown.
# Don't worry, it doesn't affect the performance.
if st.session_state["generation_mode"] == "txt2img":
percent = int(100 * float(i+1)/float(st.session_state.sampling_steps))
st.session_state["progress_bar_text"].text(f"Running step: {i+1}/{st.session_state.sampling_steps} {percent}%")
else:
percent = int(100 * float(i+1 )/float(round(st.session_state.sampling_steps * st.session_state["denoising_strength"])))
st.session_state["progress_bar_text"].text(f"""Running step: {i+1}/{round(st.session_state.sampling_steps * st.session_state["denoising_strength"])} {percent}%""")
st.session_state["progress_bar"] = st.session_state["progress_bar"].progress(percent)
class MemUsageMonitor(threading.Thread):
stop_flag = False
@ -401,7 +421,7 @@ def load_RealESRGAN(model_name: str):
}
model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path):
if not os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")):
raise Exception(model_name+".pth not found at path "+model_path)
sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir))
@ -505,27 +525,6 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
torch_gc()
def run_GFPGAN(image, strength):
    """Run ``image`` through ``GFPGAN.enhance`` and return a PIL image.

    The input is converted to RGB and handed to the enhancer as a uint8
    array. When ``strength`` is below 1.0, the enhanced image is blended
    with the RGB input using ``strength`` as the blend factor; otherwise
    the enhanced image is returned as-is.
    """
    rgb_image = image.convert("RGB")
    pixels = np.array(rgb_image, dtype=np.uint8)

    # Only the pasted-back full image is used; the per-face outputs are ignored.
    _cropped_faces, _restored_faces, restored_img = GFPGAN.enhance(
        pixels,
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    restored = Image.fromarray(restored_img)

    if strength < 1.0:
        return Image.blend(rgb_image, restored, strength)
    return restored
def run_RealESRGAN(image, model_name: str):
    """Run ``image`` through ``RealESRGAN.enhance`` and return a PIL image.

    If the name of the currently held model differs from ``model_name``,
    ``try_loading_RealESRGAN`` is called with the requested name before
    the image is processed.
    """
    if RealESRGAN.model.name != model_name:
        try_loading_RealESRGAN(model_name)

    # The enhancer receives the RGB-converted image as a uint8 array.
    pixels = np.array(image.convert("RGB"), dtype=np.uint8)
    enhanced, _img_mode = RealESRGAN.enhance(pixels)
    return Image.fromarray(enhanced)
def get_font(fontsize):
fonts = ["arial.ttf", "DejaVuSans.ttf"]
@ -957,8 +956,11 @@ def process_images(
if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN:
skip_save = True # #287 >_>
torch_gc()
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name)
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1])
esrgan_filename = original_filename + '-esrgan4x'
esrgan_sample = output[:,:,::-1]
@ -980,8 +982,11 @@ def process_images(
torch_gc()
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name)
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
gfpgan_esrgan_sample = output[:,:,::-1]
@ -1102,7 +1107,8 @@ def img2img(prompt: str = '', init_info: any = None, ddim_steps: int = 50, sampl
write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
separate_prompts:bool = False, normalize_prompt_weights:bool = True,
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True):
save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = True
):
outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples"
err = False
@ -1196,7 +1202,10 @@ def img2img(prompt: str = '', init_info: any = None, ddim_steps: int = 50, sampl
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
callback=generation_callback)
else:
x0, z_mask = init_data
@ -1448,7 +1457,7 @@ def layout():
else:
GFPGAN_available = False
if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", "RealESRGAN_x4plus.pth")):
if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{defaults.general.RealESRGAN_model}.pth")):
RealESRGAN_available = True
else:
RealESRGAN_available = False
@ -1468,6 +1477,8 @@ def layout():
txt2img_tab, img2img_tab, postprocessing_tab = st.tabs(["Stable Diffusion Text-to-Image Unified", "Stable Diffusion Image-to-Image Unified", "Post-Processing"])
with txt2img_tab:
st.session_state["generation_mode"] = "txt2img"
with st.form("txt2img-inputs"):
input_col1, generate_col1 = st.columns([10,1])
@ -1484,8 +1495,8 @@ def layout():
col1, col2, col3 = st.columns([1,2,1], gap="large")
with col1:
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.txt2img.height, step=64)
width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.txt2img.width, step=64)
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.txt2img.height, step=64)
cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
batch_count = st.slider("Batch count.", min_value=1, max_value=500, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
@ -1498,26 +1509,33 @@ def layout():
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
with preview_tab:
st.write("Image")
#st.write("Image")
#Image for testing
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw)
#new_image = image.resize((175, 240))
#preview_image = st.image(image)
# create an empty container for the image and use session_state to hold it globally.
preview_image = st.empty()
st.session_state["preview_image"] = preview_image
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
st.session_state["preview_image"] = st.empty()
st.session_state["loading"] = st.empty()
st.session_state["progress_bar_text"] = st.empty()
st.session_state["progress_bar"] = st.empty()
message = st.empty()
with gallery_tab:
st.write('Here should be the image gallery, if I could make a grid in streamlit.')
with col3:
sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250)
st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250)
sampler_name = st.selectbox("Sampling method",
["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"],
index=0, help="Sampling method to use. Default: k_lms")
#basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
#with basic_tab:
@ -1556,10 +1574,13 @@ def layout():
load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model)
try:
output_images, seed, info, stats = txt2img(prompt, sampling_steps, sampler_name, RealESRGAN_model, batch_count, batch_size,
output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, batch_size,
cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp,
variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files)
message.success('Done!', icon="")
except (StopException, KeyError):
print(f"Received Streamlit StopException")
@ -1568,6 +1589,8 @@ def layout():
#preview_image.image(output_images, width=750)
with img2img_tab:
st.session_state["generation_mode"] = "img2img"
with st.form("img2img-inputs"):
img2img_input_col, img2img_generate_col = st.columns([10,1])
@ -1593,8 +1616,8 @@ def layout():
help="Upload an image which will be used for the image to image generation."
)
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.img2img.height, step=64)
width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.img2img.width, step=64)
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.img2img.height, step=64)
seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
batch_count = st.slider("Batch count.", min_value=1, max_value=500, value=defaults.img2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
@ -1633,27 +1656,43 @@ def layout():
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
Default: 1")
st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=defaults.img2img.denoising_strength, min_value=0.0, max_value=1.0, step=0.01)
with col2_img2img_layout:
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
editor_tab = st.tabs(["Editor"])
with preview_tab:
st.write("Image")
#Image for testing
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw)
#new_image = image.resize((175, 240))
#preview_image = st.image(image)
editor_image = st.empty()
st.session_state["editor_image"] = editor_image
# create an empty container for the image and use session_state to hold it globally.
if uploaded_images:
image = Image.open(uploaded_images)
#img_array = np.array(image) # if you want to pass it to OpenCV
new_img = image.resize((width, height))
st.image(new_img)
with col3_img2img_layout:
result_tab = st.tabs(["Result"])
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
preview_image = st.empty()
st.session_state["preview_image"] = preview_image
with gallery_tab:
st.write('Here should be the image gallery, if I could make a grid in streamlit.')
#st.session_state["loading"] = st.empty()
st.session_state["progress_bar_text"] = st.empty()
st.session_state["progress_bar"] = st.empty()
message = st.empty()
#if uploaded_images:
#image = Image.open(uploaded_images)
# img_array = np.array(image) # if you want to pass it to OpenCV
# st.image(image, use_column_width=True)
##img_array = np.array(image) # if you want to pass it to OpenCV
#new_img = image.resize((width, height))
#st.image(new_img, use_column_width=True)
if generate_button:
#print("Loading models")
@ -1661,9 +1700,24 @@ def layout():
load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model)
if uploaded_images:
image = Image.open(uploaded_images)
img_array = np.array(image) # if you want to pass it to OpenCV
new_img = image.resize((width, height))
#img_array = np.array(image) # if you want to pass it to OpenCV
try:
output_images, seed, info, stats = img2img(prompt=prompt, init_info=image, ddim_steps=sampling_steps, sampler_name=sampler_name, n_iter=batch_count)
#output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, ddim_steps=sampling_steps, sampler_name=sampler_name, n_iter=batch_count)
output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, ddim_steps=sampling_steps, sampler_name=sampler_name, n_iter=batch_count,
cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
seed=seed, width=width, height=height, fp=defaults.general.fp, variant_amount=variant_amount,
ddim_eta=0.0, write_info_files=True, RealESRGAN_model="RealESRGAN_x4plus_anime_6B",
separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
save_individual_images=save_individual_images, save_grid=save_grid,
group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN,
use_RealESRGAN=use_RealESRGAN
)
#show a message when the generation is complete.
message.success('Done!', icon="")
except (StopException, KeyError):
print(f"Received Streamlit StopException")

View File

@ -6,8 +6,8 @@ general:
ckpt: "models/ldm/stable-diffusion-v1/model.ckpt"
fp:
name: 'embeddings/alex/embeddings_gs-11000.pt'
GFPGAN_dir: "./src/GFPGAN"
RealESRGAN_dir: "./src/RealESRGAN"
GFPGAN_dir: "./src/gfpgan"
RealESRGAN_dir: "./src/realesrgan"
RealESRGAN_model: "RealESRGAN_x4plus"
outdir_txt2img: outputs/txt2img-samples
outdir_img2img: outputs/img2img-samples
@ -91,7 +91,7 @@ img2img:
group_by_prompt: True
save_as_jpg: False
use_GFPGAN: True
use_RealESRGAN: True
use_RealESRGAN: False
RealESRGAN_model: "RealESRGAN_x4plus"
variant_amount: 0.0