mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-14 06:35:14 +03:00
Add support for RealESRGAN upscaling
This commit is contained in:
parent
ab2770ef12
commit
601b176a18
104
webui.py
104
webui.py
@ -56,6 +56,8 @@ parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-i
|
||||
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
|
||||
parser.add_argument("--realesrgan-dir", type=str, help="RealESRGAN directory", default=('./src/realesrgan' if os.path.exists('./src/realesrgan') else './RealESRGAN'))
|
||||
parser.add_argument("--realesrgan-model", type=str, help="Upscaling model for RealESRGAN", default=('RealESRGAN_x4plus'))
|
||||
parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
|
||||
@ -63,6 +65,7 @@ parser.add_argument("--cli", type=str, help="don't launch web server, take Pytho
|
||||
opt = parser.parse_args()
|
||||
|
||||
GFPGAN_dir = opt.gfpgan_dir
|
||||
RealESRGAN_dir = opt.realesrgan_dir
|
||||
|
||||
css_hide_progressbar = """
|
||||
.wrap .m-12 svg { display:none!important; }
|
||||
@ -228,6 +231,23 @@ def load_GFPGAN():
|
||||
|
||||
return GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
|
||||
def load_RealESRGAN(model_name: str):
    """Build a RealESRGAN upscaler for one of the supported pretrained models.

    Args:
        model_name: Key of a supported model. The weights file is expected at
            ``<RealESRGAN_dir>/experiments/pretrained_models/<model_name>.pth``.

    Returns:
        A ``RealESRGANer`` instance. Its ``model.name`` attribute is set to
        ``model_name`` so callers can tell which weights are currently loaded.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
        FileNotFoundError: If the weights file does not exist.
    """
    # Store only the RRDBNet keyword arguments so that just the requested
    # network is instantiated, rather than every architecture on each call.
    RealESRGAN_models = {
        'RealESRGAN_x4plus': dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
        'RealESRGAN_x4plus_anime_6B': dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4),
    }

    # Reject unknown names up front instead of failing with a bare KeyError
    # after the imports and sys.path changes below have already happened.
    if model_name not in RealESRGAN_models:
        raise ValueError(
            "Unknown RealESRGAN model " + repr(model_name)
            + "; supported models: " + ", ".join(RealESRGAN_models)
        )

    model_path = os.path.join(RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
    if not os.path.isfile(model_path):
        raise FileNotFoundError(model_name+".pth not found at path "+model_path)

    from basicsr.archs.rrdbnet_arch import RRDBNet

    # Make the RealESRGAN checkout importable, but do not add a duplicate
    # sys.path entry every time the model is switched.
    realesrgan_path = os.path.abspath(RealESRGAN_dir)
    if realesrgan_path not in sys.path:
        sys.path.append(realesrgan_path)
    from realesrgan import RealESRGANer

    # NOTE(review): scale=2 is passed here although both networks above are
    # built with scale=4 -- confirm this is intended.
    network = RRDBNet(**RealESRGAN_models[model_name])
    instance = RealESRGANer(scale=2, model_path=model_path, model=network, pre_pad=0, half=True)
    instance.model.name = model_name
    return instance
|
||||
|
||||
GFPGAN = None
|
||||
if os.path.exists(GFPGAN_dir):
|
||||
@ -239,6 +259,19 @@ if os.path.exists(GFPGAN_dir):
|
||||
print("Error loading GFPGAN:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
# Currently loaded RealESRGAN upscaler, or None when none could be loaded.
RealESRGAN = None


def try_loading_RealESRGAN(model_name: str):
    """Best-effort (re)load of the module-level ``RealESRGAN`` upscaler.

    If ``RealESRGAN_dir`` does not exist, nothing happens. Otherwise the model
    is loaded via ``load_RealESRGAN``; on success the global ``RealESRGAN`` is
    replaced. On any failure the traceback is printed to stderr and the global
    keeps its previous value.

    Args:
        model_name: Name of the RealESRGAN model to load.
    """
    global RealESRGAN

    if not os.path.exists(RealESRGAN_dir):
        return

    try:
        RealESRGAN = load_RealESRGAN(model_name) # TODO: Should try to load both models before giving up
        print("Loaded RealESRGAN with model "+RealESRGAN.model.name)
    except Exception:
        # Deliberately broad: a missing or broken RealESRGAN install must not
        # prevent the rest of the application from starting.
        import traceback
        print("Error loading RealESRGAN:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)


# Honour the --realesrgan-model command line option at startup; its default
# is 'RealESRGAN_x4plus', so the default behaviour is unchanged.
try_loading_RealESRGAN(opt.realesrgan_model)
|
||||
|
||||
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
|
||||
model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt")
|
||||
|
||||
@ -410,7 +443,7 @@ def check_prompt_length(prompt, comments):
|
||||
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
|
||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=False):
|
||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=False):
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
assert prompt is not None
|
||||
torch_gc()
|
||||
@ -514,6 +547,13 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||
x_sample = restored_img[:,:,::-1]
|
||||
|
||||
if use_RealESRGAN and RealESRGAN is not None:
|
||||
if RealESRGAN.model.name != realesrgan_model_name:
|
||||
try_loading_RealESRGAN(realesrgan_model_name)
|
||||
|
||||
output, img_mode = RealESRGAN.enhance(x_sample[:,:,::-1])
|
||||
x_sample = output[:,:,::-1]
|
||||
|
||||
image = Image.fromarray(x_sample)
|
||||
if init_mask:
|
||||
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
|
||||
@ -521,6 +561,18 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||
init_mask = init_mask.convert('L')
|
||||
init_img = init_img.convert('RGB')
|
||||
image = image.convert('RGB')
|
||||
|
||||
if use_RealESRGAN and RealESRGAN is not None:
|
||||
if RealESRGAN.model.name != realesrgan_model_name:
|
||||
try_loading_RealESRGAN(realesrgan_model_name)
|
||||
output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
|
||||
init_img = Image.fromarray(output)
|
||||
init_img = init_img.convert('RGB')
|
||||
|
||||
output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
|
||||
init_mask = Image.fromarray(output)
|
||||
init_mask = init_mask.convert('L')
|
||||
|
||||
image = Image.composite(init_img, image, init_mask)
|
||||
|
||||
filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
|
||||
@ -554,7 +606,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||
|
||||
info = f"""
|
||||
{prompt}
|
||||
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
|
||||
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}{', RealESRGAN' if use_RealESRGAN and RealESRGAN is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
|
||||
stats = f'''
|
||||
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
|
||||
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
|
||||
@ -569,7 +621,7 @@ Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_0
|
||||
return output_images, seed, info, stats
|
||||
|
||||
|
||||
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], ddim_eta: float, n_iter: int,
|
||||
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], realesrgan_model_name: str, ddim_eta: float, n_iter: int,
|
||||
batch_size: int, cfg_scale: float, seed: Union[int, str, None], height: int, width: int, fp):
|
||||
outpath = opt.outdir_txt2img or opt.outdir or "outputs/txt2img-samples"
|
||||
err = False
|
||||
@ -580,6 +632,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
|
||||
skip_save = 2 not in toggles
|
||||
skip_grid = 3 not in toggles
|
||||
use_GFPGAN = 4 in toggles
|
||||
use_RealESRGAN = 5 in toggles
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(model)
|
||||
@ -625,6 +678,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
|
||||
height=height,
|
||||
prompt_matrix=prompt_matrix,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
use_RealESRGAN=use_RealESRGAN,
|
||||
realesrgan_model_name=realesrgan_model_name,
|
||||
fp=fp,
|
||||
normalize_prompt_weights=normalize_prompt_weights
|
||||
)
|
||||
@ -685,7 +740,7 @@ class Flagging(gr.FlaggingCallback):
|
||||
|
||||
|
||||
def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask, mask_mode: str, ddim_steps: int, sampler_name: str,
|
||||
toggles: List[int], n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float,
|
||||
toggles: List[int], realesrgan_model_name: str, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float,
|
||||
seed: int, height: int, width: int, resize_mode: int, fp):
|
||||
outpath = opt.outdir_img2img or opt.outdir or "outputs/img2img-samples"
|
||||
err = False
|
||||
@ -698,6 +753,7 @@ def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask,
|
||||
skip_save = 4 not in toggles
|
||||
skip_grid = 5 not in toggles
|
||||
use_GFPGAN = 6 in toggles
|
||||
use_RealESRGAN = 7 in toggles
|
||||
|
||||
if sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(model)
|
||||
@ -793,6 +849,8 @@ def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask,
|
||||
height=height,
|
||||
prompt_matrix=prompt_matrix,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
use_RealESRGAN=False, # Forcefully disable upscaling when using loopback
|
||||
realesrgan_model_name=realesrgan_model_name,
|
||||
fp=fp,
|
||||
do_not_save_grid=True,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
@ -840,6 +898,8 @@ def img2img(prompt: str, image_editor_mode: str, cropped_image, image_with_mask,
|
||||
height=height,
|
||||
prompt_matrix=prompt_matrix,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
use_RealESRGAN=use_RealESRGAN,
|
||||
realesrgan_model_name=realesrgan_model_name,
|
||||
fp=fp,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
init_img=init_img,
|
||||
@ -916,6 +976,17 @@ def run_GFPGAN(image, strength):
|
||||
|
||||
return res
|
||||
|
||||
def run_RealESRGAN(image, model_name: str):
    """Upscale a single image with the requested RealESRGAN model.

    Args:
        image: Source image (an object with ``convert``, e.g. a PIL image), or
            ``None`` when no image was provided.
        model_name: RealESRGAN model to use. If it differs from the currently
            loaded one, a reload is attempted first.

    Returns:
        The upscaled image created with ``Image.fromarray``, or ``None`` when
        ``image`` is ``None``.
    """
    # Nothing was supplied: return nothing instead of raising AttributeError.
    if image is None:
        return None

    if RealESRGAN.model.name != model_name:
        try_loading_RealESRGAN(model_name)

    rgb_pixels = np.array(image.convert("RGB"), dtype=np.uint8)

    # Reverse the channel axis on the way in and out, exactly as
    # process_images does around RealESRGAN.enhance, so both code paths feed
    # the upscaler the same channel order.
    output, img_mode = RealESRGAN.enhance(rgb_pixels[:,:,::-1])
    return Image.fromarray(output[:,:,::-1])
|
||||
|
||||
css = "" if opt.no_progressbar_hiding else css_hide_progressbar
|
||||
css = css + '[data-testid="image"] {min-height: 512px !important}'
|
||||
|
||||
@ -931,6 +1002,8 @@ txt2img_toggles = [
|
||||
]
|
||||
if GFPGAN is not None:
|
||||
txt2img_toggles.append('Fix faces using GFPGAN')
|
||||
if RealESRGAN is not None:
|
||||
txt2img_toggles.append('Upscale images using RealESRGAN')
|
||||
|
||||
txt2img_toggle_defaults = [
|
||||
'Normalize Prompt Weights (ensure sum of weights add up to 1.0)',
|
||||
@ -949,6 +1022,8 @@ img2img_toggles = [
|
||||
]
|
||||
if GFPGAN is not None:
|
||||
img2img_toggles.append('Fix faces using GFPGAN')
|
||||
if RealESRGAN is not None:
|
||||
img2img_toggles.append('Upscale images using RealESRGAN')
|
||||
|
||||
img2img_toggle_defaults = [
|
||||
'Normalize Prompt Weights (ensure sum of weights add up to 1.0)',
|
||||
@ -985,6 +1060,7 @@ with gr.Blocks(css=css) as demo:
|
||||
txt2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps", value=50)
|
||||
txt2img_sampling = gr.Radio(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", "PLMS", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value="k_lms")
|
||||
txt2img_toggles = gr.CheckboxGroup(label='', choices=txt2img_toggles, value=txt2img_toggle_defaults, type="index")
|
||||
txt2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], value='RealESRGAN_x4plus', visible=RealESRGAN is not None) # TODO: Feels like I shouldnt slot it in here.
|
||||
txt2img_ddim_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False)
|
||||
txt2img_batch_count = gr.Slider(minimum=1, maximum=250, step=1, label='Batch count (how many batches of images to generate)', value=1)
|
||||
txt2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1)
|
||||
@ -1004,7 +1080,7 @@ with gr.Blocks(css=css) as demo:
|
||||
|
||||
txt2img_btn.click(
|
||||
txt2img,
|
||||
[txt2img_prompt, txt2img_steps, txt2img_sampling, txt2img_toggles, txt2img_ddim_eta, txt2img_batch_count, txt2img_batch_size, txt2img_cfg, txt2img_seed, txt2img_height, txt2img_width, txt2img_embeddings],
|
||||
[txt2img_prompt, txt2img_steps, txt2img_sampling, txt2img_toggles, txt2img_realesrgan_model_name, txt2img_ddim_eta, txt2img_batch_count, txt2img_batch_size, txt2img_cfg, txt2img_seed, txt2img_height, txt2img_width, txt2img_embeddings],
|
||||
[output_txt2img_gallery, output_txt2img_seed, output_txt2img_params, output_txt2img_stats]
|
||||
)
|
||||
|
||||
@ -1021,6 +1097,7 @@ with gr.Blocks(css=css) as demo:
|
||||
img2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps", value=50)
|
||||
img2img_sampling = gr.Radio(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value="k_lms")
|
||||
img2img_toggles = gr.CheckboxGroup(label='', choices=img2img_toggles, value=img2img_toggle_defaults, type="index")
|
||||
img2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], value='RealESRGAN_x4plus', visible=RealESRGAN is not None) # TODO: Feels like I shouldnt slot it in here.
|
||||
img2img_batch_count = gr.Slider(minimum=1, maximum=250, step=1, label='Batch count (how many batches of images to generate)', value=1)
|
||||
img2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1)
|
||||
img2img_cfg = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=5.0)
|
||||
@ -1060,7 +1137,7 @@ with gr.Blocks(css=css) as demo:
|
||||
|
||||
img2img_btn.click(
|
||||
img2img,
|
||||
[img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_image_mask, img2img_mask, img2img_steps, img2img_sampling, img2img_toggles, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings],
|
||||
[img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_image_mask, img2img_mask, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings],
|
||||
[output_img2img_gallery, output_img2img_seed, output_img2img_params, output_img2img_stats]
|
||||
)
|
||||
if GFPGAN is not None:
|
||||
@ -1078,6 +1155,21 @@ with gr.Blocks(css=css) as demo:
|
||||
[gfpgan_source, gfpgan_strength],
|
||||
[gfpgan_output]
|
||||
)
|
||||
if RealESRGAN is not None:
|
||||
with gr.TabItem("RealESRGAN"):
|
||||
gr.Markdown("Upscale images")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
realesrgan_source = gr.Image(label="Source", source="upload", interactive=True, type="pil")
|
||||
realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], value='RealESRGAN_x4plus')
|
||||
realesrgan_btn = gr.Button("Generate")
|
||||
with gr.Column():
|
||||
realesrgan_output = gr.Image(label="Output")
|
||||
realesrgan_btn.click(
|
||||
run_RealESRGAN,
|
||||
[realesrgan_source, realesrgan_model_name],
|
||||
[realesrgan_output]
|
||||
)
|
||||
|
||||
output_txt2img_copy_to_input_btn.click(
|
||||
copy_img_to_input,
|
||||
|
Loading…
Reference in New Issue
Block a user