Add support for RealESRGAN upscaling

This commit is contained in:
chanoc 2022-08-26 02:37:35 -07:00
parent ab2770ef12
commit 601b176a18

104
webui.py
View File

@ -56,6 +56,8 @@ parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-i
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
parser.add_argument("--realesrgan-dir", type=str, help="RealESRGAN directory", default=('./src/realesrgan' if os.path.exists('./src/realesrgan') else './RealESRGAN'))
parser.add_argument("--realesrgan-model", type=str, help="Upscaling model for RealESRGAN", default=('RealESRGAN_x4plus'))
parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
@ -63,6 +65,7 @@ parser.add_argument("--cli", type=str, help="don't launch web server, take Pytho
opt = parser.parse_args()
GFPGAN_dir = opt.gfpgan_dir
RealESRGAN_dir = opt.realesrgan_dir
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
@ -228,6 +231,23 @@ def load_GFPGAN():
return GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
def load_RealESRGAN(model_name: str):
from basicsr.archs.rrdbnet_arch import RRDBNet
RealESRGAN_models = {
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
}
model_path = os.path.join(RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path):
raise Exception(model_name+".pth not found at path "+model_path)
sys.path.append(os.path.abspath(RealESRGAN_dir))
from realesrgan import RealESRGANer
instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=True)
instance.model.name = model_name
return instance
GFPGAN = None
if os.path.exists(GFPGAN_dir):
@ -239,6 +259,19 @@ if os.path.exists(GFPGAN_dir):
print("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
RealESRGAN = None
def try_loading_RealESRGAN(model_name: str):
global RealESRGAN
if os.path.exists(RealESRGAN_dir):
try:
RealESRGAN = load_RealESRGAN(model_name) # TODO: Should try to load both models before giving up
print("Loaded RealESRGAN with model "+RealESRGAN.model.name)
except Exception:
import traceback
print("Error loading RealESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
try_loading_RealESRGAN('RealESRGAN_x4plus')
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt")
@ -410,7 +443,7 @@ def check_prompt_length(prompt, comments):
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=False):
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=False):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None
torch_gc()
@ -514,6 +547,13 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
x_sample = restored_img[:,:,::-1]
if use_RealESRGAN and RealESRGAN is not None:
if RealESRGAN.model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name)
output, img_mode = RealESRGAN.enhance(x_sample[:,:,::-1])
x_sample = output[:,:,::-1]
image = Image.fromarray(x_sample)
if init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
@ -521,6 +561,18 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')
if use_RealESRGAN and RealESRGAN is not None:
if RealESRGAN.model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name)
output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')
output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')
image = Image.composite(init_img, image, init_mask)
filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
@ -554,7 +606,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
info = f"""
{prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}{', RealESRGAN' if use_RealESRGAN and RealESRGAN is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
stats = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
@ -569,7 +621,7 @@ Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_0
return output_images, seed, info, stats
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], ddim_eta: float, n_iter: int,
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], realesrgan_model_name: str, ddim_eta: float, n_iter: int,
batch_size: int, cfg_scale: float, seed: Union[int, str, None], height: int, width: int, fp):
outpath = opt.outdir_txt2img or opt.outdir or "outputs/txt2img-samples"
err = False
@ -580,6 +632,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
skip_save = 2 not in toggles
skip_grid = 3 not in toggles
use_GFPGAN = 4 in toggles
use_RealESRGAN = 5 in toggles
if sampler_name == 'PLMS':
sampler = PLMSSampler(model)
@ -625,6 +678,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
use_RealESRGAN=use_RealESRGAN,
realesrgan_model_name=realesrgan_model_name,
fp=fp,
normalize_prompt_weights=normalize_prompt_weights
)
@ -685,7 +740,7 @@ class Flagging(gr.FlaggingCallback):
def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask, mask_mode: str, ddim_steps: int, sampler_name: str,
toggles: List[int], n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float,
toggles: List[int], realesrgan_model_name: str, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float,
seed: int, height: int, width: int, resize_mode: int, fp):
outpath = opt.outdir_img2img or opt.outdir or "outputs/img2img-samples"
err = False
@ -698,6 +753,7 @@ def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask,
skip_save = 4 not in toggles
skip_grid = 5 not in toggles
use_GFPGAN = 6 in toggles
use_RealESRGAN = 7 in toggles
if sampler_name == 'DDIM':
sampler = DDIMSampler(model)
@ -793,6 +849,8 @@ def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
use_RealESRGAN=False, # Forcefully disable upscaling when using loopback
realesrgan_model_name=realesrgan_model_name,
fp=fp,
do_not_save_grid=True,
normalize_prompt_weights=normalize_prompt_weights,
@ -840,6 +898,8 @@ def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
use_RealESRGAN=use_RealESRGAN,
realesrgan_model_name=realesrgan_model_name,
fp=fp,
normalize_prompt_weights=normalize_prompt_weights,
init_img=init_img,
@ -916,6 +976,17 @@ def run_GFPGAN(image, strength):
return res
def run_RealESRGAN(image, model_name: str):
if RealESRGAN.model.name != model_name:
try_loading_RealESRGAN(model_name)
image = image.convert("RGB")
output, img_mode = RealESRGAN.enhance(np.array(image, dtype=np.uint8))
res = Image.fromarray(output)
return res
css = "" if opt.no_progressbar_hiding else css_hide_progressbar
css = css + '[data-testid="image"] {min-height: 512px !important}'
@ -931,6 +1002,8 @@ txt2img_toggles = [
]
if GFPGAN is not None:
txt2img_toggles.append('Fix faces using GFPGAN')
if RealESRGAN is not None:
txt2img_toggles.append('Upscale images using RealESRGAN')
txt2img_toggle_defaults = [
'Normalize Prompt Weights (ensure sum of weights add up to 1.0)',
@ -949,6 +1022,8 @@ img2img_toggles = [
]
if GFPGAN is not None:
img2img_toggles.append('Fix faces using GFPGAN')
if RealESRGAN is not None:
img2img_toggles.append('Upscale images using RealESRGAN')
img2img_toggle_defaults = [
'Normalize Prompt Weights (ensure sum of weights add up to 1.0)',
@ -985,6 +1060,7 @@ with gr.Blocks(css=css) as demo:
txt2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps", value=50)
txt2img_sampling = gr.Radio(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", "PLMS", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value="k_lms")
txt2img_toggles = gr.CheckboxGroup(label='', choices=txt2img_toggles, value=txt2img_toggle_defaults, type="index")
txt2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], value='RealESRGAN_x4plus', visible=RealESRGAN is not None) # TODO: Feels like I shouldnt slot it in here.
txt2img_ddim_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False)
txt2img_batch_count = gr.Slider(minimum=1, maximum=250, step=1, label='Batch count (how many batches of images to generate)', value=1)
txt2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1)
@ -1004,7 +1080,7 @@ with gr.Blocks(css=css) as demo:
txt2img_btn.click(
txt2img,
[txt2img_prompt, txt2img_steps, txt2img_sampling, txt2img_toggles, txt2img_ddim_eta, txt2img_batch_count, txt2img_batch_size, txt2img_cfg, txt2img_seed, txt2img_height, txt2img_width, txt2img_embeddings],
[txt2img_prompt, txt2img_steps, txt2img_sampling, txt2img_toggles, txt2img_realesrgan_model_name, txt2img_ddim_eta, txt2img_batch_count, txt2img_batch_size, txt2img_cfg, txt2img_seed, txt2img_height, txt2img_width, txt2img_embeddings],
[output_txt2img_gallery, output_txt2img_seed, output_txt2img_params, output_txt2img_stats]
)
@ -1021,6 +1097,7 @@ with gr.Blocks(css=css) as demo:
img2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps", value=50)
img2img_sampling = gr.Radio(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value="k_lms")
img2img_toggles = gr.CheckboxGroup(label='', choices=img2img_toggles, value=img2img_toggle_defaults, type="index")
img2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], value='RealESRGAN_x4plus', visible=RealESRGAN is not None) # TODO: Feels like I shouldnt slot it in here.
img2img_batch_count = gr.Slider(minimum=1, maximum=250, step=1, label='Batch count (how many batches of images to generate)', value=1)
img2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1)
img2img_cfg = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=5.0)
@ -1060,7 +1137,7 @@ with gr.Blocks(css=css) as demo:
img2img_btn.click(
img2img,
[img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_image_mask, img2img_mask, img2img_steps, img2img_sampling, img2img_toggles, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings],
[img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_image_mask, img2img_mask, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings],
[output_img2img_gallery, output_img2img_seed, output_img2img_params, output_img2img_stats]
)
if GFPGAN is not None:
@ -1078,6 +1155,21 @@ with gr.Blocks(css=css) as demo:
[gfpgan_source, gfpgan_strength],
[gfpgan_output]
)
if RealESRGAN is not None:
with gr.TabItem("RealESRGAN"):
gr.Markdown("Upscale images")
with gr.Row():
with gr.Column():
realesrgan_source = gr.Image(label="Source", source="upload", interactive=True, type="pil")
realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], value='RealESRGAN_x4plus')
realesrgan_btn = gr.Button("Generate")
with gr.Column():
realesrgan_output = gr.Image(label="Output")
realesrgan_btn.click(
run_RealESRGAN,
[realesrgan_source, realesrgan_model_name],
[realesrgan_output]
)
output_txt2img_copy_to_input_btn.click(
copy_img_to_input,