diff --git a/webui.py b/webui.py index b356784..95be9f6 100644 --- a/webui.py +++ b/webui.py @@ -56,6 +56,8 @@ parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-i parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",) parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go +parser.add_argument("--realesrgan-dir", type=str, help="RealESRGAN directory", default=('./src/realesrgan' if os.path.exists('./src/realesrgan') else './RealESRGAN')) +parser.add_argument("--realesrgan-model", type=str, help="Upscaling model for RealESRGAN", default=('RealESRGAN_x4plus')) parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)") @@ -63,6 +65,7 @@ parser.add_argument("--cli", type=str, help="don't launch web server, take Pytho opt = parser.parse_args() GFPGAN_dir = opt.gfpgan_dir +RealESRGAN_dir = opt.realesrgan_dir css_hide_progressbar = """ .wrap .m-12 svg { display:none!important; } @@ -228,6 +231,23 @@ def load_GFPGAN(): return GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) +def load_RealESRGAN(model_name: str): + from basicsr.archs.rrdbnet_arch import RRDBNet + RealESRGAN_models = { + 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), + 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + } + + model_path = os.path.join(RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') + if not os.path.isfile(model_path): + raise Exception(model_name+".pth not found at path "+model_path) + + sys.path.append(os.path.abspath(RealESRGAN_dir)) + from realesrgan import RealESRGANer + + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=True) + instance.model.name = model_name + return instance GFPGAN = None if os.path.exists(GFPGAN_dir): @@ -239,6 +259,19 @@ if os.path.exists(GFPGAN_dir): print("Error loading GFPGAN:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) +RealESRGAN = None +def try_loading_RealESRGAN(model_name: str): + global RealESRGAN + if os.path.exists(RealESRGAN_dir): + try: + RealESRGAN = load_RealESRGAN(model_name) # TODO: Should try to load both models before giving up + print("Loaded RealESRGAN with model "+RealESRGAN.model.name) + except Exception: + import traceback + print("Error loading RealESRGAN:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) +try_loading_RealESRGAN('RealESRGAN_x4plus') + config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml") model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt") @@ -410,7 +443,7 @@ def check_prompt_length(prompt, comments): comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") -def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=False): +def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=False): """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" assert prompt is not None torch_gc() @@ -514,6 +547,13 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) x_sample = restored_img[:,:,::-1] + if use_RealESRGAN and RealESRGAN is not None: + if RealESRGAN.model.name != realesrgan_model_name: + try_loading_RealESRGAN(realesrgan_model_name) + + output, img_mode = RealESRGAN.enhance(x_sample[:,:,::-1]) + x_sample = output[:,:,::-1] + image = Image.fromarray(x_sample) if init_mask: #init_mask = init_mask if keep_mask else ImageOps.invert(init_mask) @@ -521,6 +561,18 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, init_mask = init_mask.convert('L') init_img = init_img.convert('RGB') image = image.convert('RGB') + + if use_RealESRGAN and RealESRGAN is not None: + if RealESRGAN.model.name != realesrgan_model_name: + try_loading_RealESRGAN(realesrgan_model_name) + output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8)) + init_img = Image.fromarray(output) + init_img = init_img.convert('RGB') + + output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8)) + init_mask = Image.fromarray(output) + init_mask = init_mask.convert('L') + image = Image.composite(init_img, image, init_mask) filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png" @@ -554,7 +606,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, info = f""" {prompt} -Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip() +Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}{', RealESRGAN' if use_RealESRGAN and RealESRGAN is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip() stats = f''' Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image) Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' @@ -569,7 +621,7 @@ Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_0 return output_images, seed, info, stats -def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], ddim_eta: float, n_iter: int, +def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], realesrgan_model_name: str, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None], height: int, width: int, fp): outpath = opt.outdir_txt2img or opt.outdir or "outputs/txt2img-samples" err = False @@ -580,6 +632,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], skip_save = 2 not in toggles skip_grid = 3 not in toggles use_GFPGAN = 4 in toggles + use_RealESRGAN = 5 in toggles if sampler_name == 'PLMS': sampler = PLMSSampler(model) @@ -625,6 +678,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], height=height, prompt_matrix=prompt_matrix, use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, + realesrgan_model_name=realesrgan_model_name, fp=fp, normalize_prompt_weights=normalize_prompt_weights ) @@ -685,7 +740,7 @@ class Flagging(gr.FlaggingCallback): def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask, mask_mode: str, ddim_steps: int, sampler_name: str, - toggles: List[int], n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, + toggles: List[int], realesrgan_model_name: str, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, fp): outpath = opt.outdir_img2img or opt.outdir or "outputs/img2img-samples" err = False @@ -698,6 +753,7 @@ def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask, skip_save = 4 not in toggles skip_grid = 5 not in toggles use_GFPGAN = 6 in toggles + use_RealESRGAN = 7 in toggles if sampler_name == 'DDIM': sampler = DDIMSampler(model) @@ -793,6 +849,8 @@ def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask, height=height, prompt_matrix=prompt_matrix, use_GFPGAN=use_GFPGAN, + use_RealESRGAN=False, # Forcefully disable upscaling when using loopback + realesrgan_model_name=realesrgan_model_name, fp=fp, do_not_save_grid=True, normalize_prompt_weights=normalize_prompt_weights, @@ -840,6 +898,8 @@ def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask, height=height, prompt_matrix=prompt_matrix, use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, + realesrgan_model_name=realesrgan_model_name, fp=fp, normalize_prompt_weights=normalize_prompt_weights, init_img=init_img, @@ -916,6 +976,17 @@ def run_GFPGAN(image, strength): return res +def run_RealESRGAN(image, model_name: str): + if RealESRGAN.model.name != model_name: + try_loading_RealESRGAN(model_name) + + image = image.convert("RGB") + + output, img_mode = RealESRGAN.enhance(np.array(image, dtype=np.uint8)) + res = Image.fromarray(output) + + return res + css = "" if opt.no_progressbar_hiding else css_hide_progressbar css = css + '[data-testid="image"] {min-height: 512px !important}' @@ -931,6 +1002,8 @@ txt2img_toggles = [ ] if GFPGAN is not None: txt2img_toggles.append('Fix faces using GFPGAN') +if RealESRGAN is not None: + txt2img_toggles.append('Upscale images using RealESRGAN') txt2img_toggle_defaults = [ 'Normalize Prompt Weights (ensure sum of weights add up to 1.0)', @@ -949,6 +1022,8 @@ img2img_toggles = [ ] if GFPGAN is not None: img2img_toggles.append('Fix faces using GFPGAN') +if RealESRGAN is not None: + img2img_toggles.append('Upscale images using RealESRGAN') img2img_toggle_defaults = [ 'Normalize Prompt Weights (ensure sum of weights add up to 1.0)', @@ -985,6 +1060,7 @@ with gr.Blocks(css=css) as demo: txt2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps", value=50) txt2img_sampling = gr.Radio(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", "PLMS", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value="k_lms") txt2img_toggles = gr.CheckboxGroup(label='', choices=txt2img_toggles, value=txt2img_toggle_defaults, type="index") + txt2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], value='RealESRGAN_x4plus', visible=RealESRGAN is not None) # TODO: Feels like I shouldnt slot it in here. txt2img_ddim_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False) txt2img_batch_count = gr.Slider(minimum=1, maximum=250, step=1, label='Batch count (how many batches of images to generate)', value=1) txt2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1) @@ -1004,7 +1080,7 @@ with gr.Blocks(css=css) as demo: txt2img_btn.click( txt2img, - [txt2img_prompt, txt2img_steps, txt2img_sampling, txt2img_toggles, txt2img_ddim_eta, txt2img_batch_count, txt2img_batch_size, txt2img_cfg, txt2img_seed, txt2img_height, txt2img_width, txt2img_embeddings], + [txt2img_prompt, txt2img_steps, txt2img_sampling, txt2img_toggles, txt2img_realesrgan_model_name, txt2img_ddim_eta, txt2img_batch_count, txt2img_batch_size, txt2img_cfg, txt2img_seed, txt2img_height, txt2img_width, txt2img_embeddings], [output_txt2img_gallery, output_txt2img_seed, output_txt2img_params, output_txt2img_stats] ) @@ -1021,6 +1097,7 @@ with gr.Blocks(css=css) as demo: img2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps", value=50) img2img_sampling = gr.Radio(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value="k_lms") img2img_toggles = gr.CheckboxGroup(label='', choices=img2img_toggles, value=img2img_toggle_defaults, type="index") + img2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], value='RealESRGAN_x4plus', visible=RealESRGAN is not None) # TODO: Feels like I shouldnt slot it in here. img2img_batch_count = gr.Slider(minimum=1, maximum=250, step=1, label='Batch count (how many batches of images to generate)', value=1) img2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1) img2img_cfg = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=5.0) @@ -1060,7 +1137,7 @@ with gr.Blocks(css=css) as demo: img2img_btn.click( img2img, - [img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_image_mask, img2img_mask, img2img_steps, img2img_sampling, img2img_toggles, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings], + [img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_image_mask, img2img_mask, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings], [output_img2img_gallery, output_img2img_seed, output_img2img_params, output_img2img_stats] ) if GFPGAN is not None: @@ -1078,6 +1155,21 @@ with gr.Blocks(css=css) as demo: [gfpgan_source, gfpgan_strength], [gfpgan_output] ) + if RealESRGAN is not None: + with gr.TabItem("RealESRGAN"): + gr.Markdown("Upscale images") + with gr.Row(): + with gr.Column(): + realesrgan_source = gr.Image(label="Source", source="upload", interactive=True, type="pil") + realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], value='RealESRGAN_x4plus') + realesrgan_btn = gr.Button("Generate") + with gr.Column(): + realesrgan_output = gr.Image(label="Output") + realesrgan_btn.click( + run_RealESRGAN, + [realesrgan_source, realesrgan_model_name], + [realesrgan_output] + ) output_txt2img_copy_to_input_btn.click( copy_img_to_input,