mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-15 15:22:55 +03:00
Added a progress bar and updated the Image-to-image tab. (#652)
* Added a progress bar as well as some extra info to know how the generation is going without having to check the console every time. * - Updated the Image-to-image tab, it is now working at a basic level. - Disabled RealESRGAN by default for the Image-to-Image tab as it is not working right now.
This commit is contained in:
parent
75a7ef77f0
commit
78ad3c3445
@ -88,6 +88,8 @@ os.environ["CUDA_VISIBLE_DEVICES"] = str(defaults.general.gpu)
|
||||
def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus"):
|
||||
"""Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
|
||||
|
||||
print ("Loading models.")
|
||||
|
||||
# Generate random run ID
|
||||
# Used to link runs linked w/ continue_prev_run which is not yet implemented
|
||||
# Use URL and filesystem safe version just in case.
|
||||
@ -180,8 +182,13 @@ def load_sd_from_config(ckpt, verbose=False):
|
||||
|
||||
def generation_callback(img, i=0):
|
||||
|
||||
try:
|
||||
if i == 0:
|
||||
if img['i']: i = img['i']
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview:
|
||||
#print (img)
|
||||
@ -200,8 +207,21 @@ def generation_callback(img, i=0):
|
||||
|
||||
pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0))
|
||||
|
||||
# update image on the UI so we can see the progress
|
||||
st.session_state["preview_image"].image(pil_image, width=512)
|
||||
|
||||
# Show a progress bar so we can keep track of the progress even when the image progress is not been shown,
|
||||
# Dont worry, it doesnt affect the performance.
|
||||
if st.session_state["generation_mode"] == "txt2img":
|
||||
percent = int(100 * float(i+1)/float(st.session_state.sampling_steps))
|
||||
st.session_state["progress_bar_text"].text(f"Running step: {i+1}/{st.session_state.sampling_steps} {percent}%")
|
||||
else:
|
||||
percent = int(100 * float(i+1 )/float(round(st.session_state.sampling_steps * st.session_state["denoising_strength"])))
|
||||
st.session_state["progress_bar_text"].text(f"""Running step: {i+1}/{round(st.session_state.sampling_steps * st.session_state["denoising_strength"])} {percent}%""")
|
||||
|
||||
st.session_state["progress_bar"] = st.session_state["progress_bar"].progress(percent)
|
||||
|
||||
|
||||
|
||||
class MemUsageMonitor(threading.Thread):
|
||||
stop_flag = False
|
||||
@ -401,7 +421,7 @@ def load_RealESRGAN(model_name: str):
|
||||
}
|
||||
|
||||
model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
|
||||
if not os.path.isfile(model_path):
|
||||
if not os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")):
|
||||
raise Exception(model_name+".pth not found at path "+model_path)
|
||||
|
||||
sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir))
|
||||
@ -505,27 +525,6 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
|
||||
torch_gc()
|
||||
|
||||
|
||||
def run_GFPGAN(image, strength):
|
||||
image = image.convert("RGB")
|
||||
|
||||
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if strength < 1.0:
|
||||
res = Image.blend(image, res, strength)
|
||||
|
||||
return res
|
||||
|
||||
def run_RealESRGAN(image, model_name: str):
|
||||
if RealESRGAN.model.name != model_name:
|
||||
try_loading_RealESRGAN(model_name)
|
||||
|
||||
image = image.convert("RGB")
|
||||
|
||||
output, img_mode = RealESRGAN.enhance(np.array(image, dtype=np.uint8))
|
||||
res = Image.fromarray(output)
|
||||
|
||||
return res
|
||||
|
||||
def get_font(fontsize):
|
||||
fonts = ["arial.ttf", "DejaVuSans.ttf"]
|
||||
@ -957,8 +956,11 @@ def process_images(
|
||||
if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN:
|
||||
skip_save = True # #287 >_>
|
||||
torch_gc()
|
||||
|
||||
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
|
||||
try_loading_RealESRGAN(realesrgan_model_name)
|
||||
#try_loading_RealESRGAN(realesrgan_model_name)
|
||||
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
|
||||
|
||||
output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1])
|
||||
esrgan_filename = original_filename + '-esrgan4x'
|
||||
esrgan_sample = output[:,:,::-1]
|
||||
@ -980,8 +982,11 @@ def process_images(
|
||||
torch_gc()
|
||||
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||
gfpgan_sample = restored_img[:,:,::-1]
|
||||
|
||||
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
|
||||
try_loading_RealESRGAN(realesrgan_model_name)
|
||||
#try_loading_RealESRGAN(realesrgan_model_name)
|
||||
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
|
||||
|
||||
output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
|
||||
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
|
||||
gfpgan_esrgan_sample = output[:,:,::-1]
|
||||
@ -1102,7 +1107,8 @@ def img2img(prompt: str = '', init_info: any = None, ddim_steps: int = 50, sampl
|
||||
write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
|
||||
separate_prompts:bool = False, normalize_prompt_weights:bool = True,
|
||||
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
|
||||
save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True):
|
||||
save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = True
|
||||
):
|
||||
|
||||
outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples"
|
||||
err = False
|
||||
@ -1196,7 +1202,10 @@ def img2img(prompt: str = '', init_info: any = None, ddim_steps: int = 50, sampl
|
||||
|
||||
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
|
||||
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
|
||||
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
|
||||
'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
|
||||
callback=generation_callback)
|
||||
else:
|
||||
|
||||
x0, z_mask = init_data
|
||||
@ -1448,7 +1457,7 @@ def layout():
|
||||
else:
|
||||
GFPGAN_available = False
|
||||
|
||||
if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", "RealESRGAN_x4plus.pth")):
|
||||
if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{defaults.general.RealESRGAN_model}.pth")):
|
||||
RealESRGAN_available = True
|
||||
else:
|
||||
RealESRGAN_available = False
|
||||
@ -1468,6 +1477,8 @@ def layout():
|
||||
txt2img_tab, img2img_tab, postprocessing_tab = st.tabs(["Stable Diffusion Text-to-Image Unified", "Stable Diffusion Image-to-Image Unified", "Post-Processing"])
|
||||
|
||||
with txt2img_tab:
|
||||
st.session_state["generation_mode"] = "txt2img"
|
||||
|
||||
with st.form("txt2img-inputs"):
|
||||
|
||||
input_col1, generate_col1 = st.columns([10,1])
|
||||
@ -1484,8 +1495,8 @@ def layout():
|
||||
col1, col2, col3 = st.columns([1,2,1], gap="large")
|
||||
|
||||
with col1:
|
||||
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.txt2img.height, step=64)
|
||||
width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.txt2img.width, step=64)
|
||||
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.txt2img.height, step=64)
|
||||
cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
|
||||
seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
|
||||
batch_count = st.slider("Batch count.", min_value=1, max_value=500, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
|
||||
@ -1498,26 +1509,33 @@ def layout():
|
||||
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
|
||||
|
||||
with preview_tab:
|
||||
st.write("Image")
|
||||
#st.write("Image")
|
||||
#Image for testing
|
||||
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw)
|
||||
#new_image = image.resize((175, 240))
|
||||
#preview_image = st.image(image)
|
||||
|
||||
# create an empty container for the image and use session_state to hold it globally.
|
||||
preview_image = st.empty()
|
||||
st.session_state["preview_image"] = preview_image
|
||||
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
|
||||
st.session_state["preview_image"] = st.empty()
|
||||
|
||||
st.session_state["loading"] = st.empty()
|
||||
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
message = st.empty()
|
||||
|
||||
with gallery_tab:
|
||||
st.write('Here should be the image gallery, if I could make a grid in streamlit.')
|
||||
|
||||
with col3:
|
||||
sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250)
|
||||
st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250)
|
||||
sampler_name = st.selectbox("Sampling method",
|
||||
["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"],
|
||||
index=0, help="Sampling method to use. Default: k_lms")
|
||||
|
||||
|
||||
|
||||
#basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
|
||||
|
||||
#with basic_tab:
|
||||
@ -1556,10 +1574,13 @@ def layout():
|
||||
load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model)
|
||||
|
||||
try:
|
||||
output_images, seed, info, stats = txt2img(prompt, sampling_steps, sampler_name, RealESRGAN_model, batch_count, batch_size,
|
||||
output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, batch_size,
|
||||
cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
|
||||
save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp,
|
||||
variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files)
|
||||
|
||||
message.success('Done!', icon="✅")
|
||||
|
||||
except (StopException, KeyError):
|
||||
print(f"Received Streamlit StopException")
|
||||
|
||||
@ -1568,6 +1589,8 @@ def layout():
|
||||
#preview_image.image(output_images, width=750)
|
||||
|
||||
with img2img_tab:
|
||||
st.session_state["generation_mode"] = "img2img"
|
||||
|
||||
with st.form("img2img-inputs"):
|
||||
|
||||
img2img_input_col, img2img_generate_col = st.columns([10,1])
|
||||
@ -1593,8 +1616,8 @@ def layout():
|
||||
help="Upload an image which will be used for the image to image generation."
|
||||
)
|
||||
|
||||
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.img2img.height, step=64)
|
||||
width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.img2img.width, step=64)
|
||||
height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.img2img.height, step=64)
|
||||
seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
|
||||
batch_count = st.slider("Batch count.", min_value=1, max_value=500, value=defaults.img2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
@ -1633,27 +1656,43 @@ def layout():
|
||||
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
|
||||
Default: 1")
|
||||
|
||||
st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=defaults.img2img.denoising_strength, min_value=0.0, max_value=1.0, step=0.01)
|
||||
|
||||
|
||||
with col2_img2img_layout:
|
||||
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
|
||||
editor_tab = st.tabs(["Editor"])
|
||||
|
||||
with preview_tab:
|
||||
st.write("Image")
|
||||
#Image for testing
|
||||
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw)
|
||||
#new_image = image.resize((175, 240))
|
||||
#preview_image = st.image(image)
|
||||
editor_image = st.empty()
|
||||
st.session_state["editor_image"] = editor_image
|
||||
|
||||
# create an empty container for the image and use session_state to hold it globally.
|
||||
if uploaded_images:
|
||||
image = Image.open(uploaded_images)
|
||||
#img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
new_img = image.resize((width, height))
|
||||
st.image(new_img)
|
||||
|
||||
|
||||
with col3_img2img_layout:
|
||||
result_tab = st.tabs(["Result"])
|
||||
|
||||
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
|
||||
preview_image = st.empty()
|
||||
st.session_state["preview_image"] = preview_image
|
||||
|
||||
with gallery_tab:
|
||||
st.write('Here should be the image gallery, if I could make a grid in streamlit.')
|
||||
#st.session_state["loading"] = st.empty()
|
||||
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
|
||||
message = st.empty()
|
||||
|
||||
#if uploaded_images:
|
||||
#image = Image.open(uploaded_images)
|
||||
# img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
# st.image(image, use_column_width=True)
|
||||
##img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
#new_img = image.resize((width, height))
|
||||
#st.image(new_img, use_column_width=True)
|
||||
|
||||
|
||||
if generate_button:
|
||||
#print("Loading models")
|
||||
@ -1661,9 +1700,24 @@ def layout():
|
||||
load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model)
|
||||
if uploaded_images:
|
||||
image = Image.open(uploaded_images)
|
||||
img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
new_img = image.resize((width, height))
|
||||
#img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
|
||||
try:
|
||||
output_images, seed, info, stats = img2img(prompt=prompt, init_info=image, ddim_steps=sampling_steps, sampler_name=sampler_name, n_iter=batch_count)
|
||||
#output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, ddim_steps=sampling_steps, sampler_name=sampler_name, n_iter=batch_count)
|
||||
output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, ddim_steps=sampling_steps, sampler_name=sampler_name, n_iter=batch_count,
|
||||
cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
|
||||
seed=seed, width=width, height=height, fp=defaults.general.fp, variant_amount=variant_amount,
|
||||
ddim_eta=0.0, write_info_files=True, RealESRGAN_model="RealESRGAN_x4plus_anime_6B",
|
||||
separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images, save_grid=save_grid,
|
||||
group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN,
|
||||
use_RealESRGAN=use_RealESRGAN
|
||||
)
|
||||
|
||||
#show a message when the generation is complete.
|
||||
message.success('Done!', icon="✅")
|
||||
|
||||
except (StopException, KeyError):
|
||||
print(f"Received Streamlit StopException")
|
||||
|
||||
|
@ -6,8 +6,8 @@ general:
|
||||
ckpt: "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
fp:
|
||||
name: 'embeddings/alex/embeddings_gs-11000.pt'
|
||||
GFPGAN_dir: "./src/GFPGAN"
|
||||
RealESRGAN_dir: "./src/RealESRGAN"
|
||||
GFPGAN_dir: "./src/gfpgan"
|
||||
RealESRGAN_dir: "./src/realesrgan"
|
||||
RealESRGAN_model: "RealESRGAN_x4plus"
|
||||
outdir_txt2img: outputs/txt2img-samples
|
||||
outdir_img2img: outputs/img2img-samples
|
||||
@ -91,7 +91,7 @@ img2img:
|
||||
group_by_prompt: True
|
||||
save_as_jpg: False
|
||||
use_GFPGAN: True
|
||||
use_RealESRGAN: True
|
||||
use_RealESRGAN: False
|
||||
RealESRGAN_model: "RealESRGAN_x4plus"
|
||||
variant_amount: 0.0
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user