diff --git a/README.md b/README.md index 63f8d000..a3d722c8 100644 --- a/README.md +++ b/README.md @@ -194,3 +194,14 @@ Using `()` in prompt decreases model's attention to enclosed words, and `[]` inc multiple modifiers: ![](images/attention-3.jpg) + +### SD upscale +Upscale image using RealESRGAN and then go through tiles of the result, improving them with img2img. + +Original idea by: https://github.com/jquesnelle/txt2imghd. This is an independent implementation. + +To use this feature, tick a checkbox in the img2img interface. Original +image will be upscaled to twice the original width and height, while width and height sliders +will specify the size of individual tiles. At the moment this method does not support batch size. + +![](images/sd-upscale.jpg) diff --git a/images/sd-upscale.jpg b/images/sd-upscale.jpg new file mode 100644 index 00000000..f7b4ad9e Binary files /dev/null and b/images/sd-upscale.jpg differ diff --git a/webui.py b/webui.py index a0fa23c4..13e5112a 100644 --- a/webui.py +++ b/webui.py @@ -85,11 +85,6 @@ try: from realesrgan.archs.srvgg_arch import SRVGGNetCompact realesrgan_models = [ - RealesrganModelInfo( - name="Real-ESRGAN 2x plus", - location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", - netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) - ), RealesrganModelInfo( name="Real-ESRGAN 4x plus", location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", @@ -100,6 +95,11 @@ try: location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) ), + RealesrganModelInfo( + name="Real-ESRGAN 2x plus", + location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + netscale=2, model=lambda: RRDBNet(num_in_ch=3, 
# Description of an image cut into overlapping tiles.
#   tiles:   list of rows, each row being [y, tile_h, row_tiles], where
#            row_tiles is a list of [x, tile_w, tile_image] entries.
#            Lists (not tuples) are used so callers can replace a tile in place.
#   tile_w / tile_h:   size of every tile in pixels
#   image_w / image_h: size of the full image in pixels
#   overlap: number of pixels shared by neighbouring tiles
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    """Cut `image` into a grid of overlapping tiles of size tile_w x tile_h.

    `image` only needs `width`, `height` and `crop((left, upper, right, lower))`.
    Neighbouring tiles share `overlap` pixels; the last column/row is shifted
    back so that it ends exactly at the image border instead of running past it.

    Returns a Grid whose `tiles` field is described next to the Grid definition.

    Raises ValueError if `overlap` is negative or not smaller than both tile
    dimensions (the tile stride would be zero or negative).
    """
    if overlap < 0 or overlap >= tile_w or overlap >= tile_h:
        raise ValueError(
            f"overlap ({overlap}) must be non-negative and smaller than the tile size ({tile_w}x{tile_h})"
        )

    w = image.width
    h = image.height

    stride_x = tile_w - overlap  # horizontal distance between tile origins
    stride_y = tile_h - overlap  # vertical distance between tile origins

    # Always produce at least one tile per axis, even for images that are
    # not larger than the overlap itself.
    cols = max(1, math.ceil((w - overlap) / stride_x))
    rows = max(1, math.ceil((h - overlap) / stride_y))

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images = []

        y = row * stride_y
        if y + tile_h >= h:
            # Align the last row with the bottom edge; clamp to 0 so an image
            # shorter than one tile is still treated as the first row.
            y = max(0, h - tile_h)

        for col in range(cols):
            x = col * stride_x
            if x + tile_w >= w:
                # Same alignment/clamping for the last column.
                x = max(0, w - tile_w)

            tile = image.crop((x, y, x + tile_w, y + tile_h))
            row_images.append([x, tile_w, tile])

        grid.tiles.append([y, tile_h, row_images])

    return grid


def combine_grid(grid):
    """Reassemble the tiles of `grid` (as produced by split_grid) into one RGB image.

    Tiles are first pasted left to right into a row image, then the rows are
    pasted top to bottom. For every tile/row that is not at offset 0, the first
    `grid.overlap` pixels are pasted through a linear 0..255 ramp mask so they
    fade in over what is already there; the remainder is pasted opaquely.
    With an overlap of 0 there is nothing to blend and tiles are pasted as is.
    """
    overlap = grid.overlap

    def make_mask_image(r):
        r = r * 255 / overlap
        r = r.astype(np.uint8)
        return Image.fromarray(r, 'L')

    if overlap > 0:
        # np.float64 is what the removed alias `np.float` used to resolve to.
        ramp = np.arange(overlap, dtype=np.float64)
        mask_w = make_mask_image(ramp.reshape((1, overlap)).repeat(grid.tile_h, axis=0))
        mask_h = make_mask_image(ramp.reshape((overlap, 1)).repeat(grid.image_w, axis=1))
    else:
        # A zero-sized mask image cannot be built, and no blending is needed.
        mask_w = mask_h = None

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0 or overlap == 0:
                combined_row.paste(tile, (x, 0))
                continue

            # Blend the overlapping strip, then paste the rest of the tile opaquely.
            combined_row.paste(tile.crop((0, 0, overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((overlap, 0, w, h)), (x + overlap, 0))

        if y == 0 or overlap == 0:
            combined_image.paste(combined_row, (0, y))
            continue

        combined_image.paste(combined_row.crop((0, 0, combined_row.width, overlap)), (0, y), mask=mask_h)
        combined_image.paste(combined_row.crop((0, overlap, combined_row.width, h)), (0, y + overlap))

    return combined_image
samplers_for_img2img[sampler_index].constructor(model) @@ -894,7 +961,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG func_sample=sample, prompt=prompt, seed=seed, - sampler_index=0, + sampler_index=sampler_index, batch_size=1, n_iter=1, steps=ddim_steps, @@ -923,6 +990,59 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG output_images = history seed = initial_seed + elif sd_upscale: + initial_seed = None + initial_info = None + + img = upscale_with_realesrgan(init_img, RealESRGAN_upscaling=2, RealESRGAN_model_index=0) + + torch_gc() + + grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap) + + + print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.") + + for y, h, row in grid.tiles: + for tiledata in row: + init_img = tiledata[2] + + output_images, seed, info = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_index=sampler_index, + batch_size=1, # since process_images can't work with multiple different images we have to do this for now + n_iter=1, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=prompt_matrix, + use_GFPGAN=use_GFPGAN, + do_not_save_grid=True, + extra_generation_params={"Denoising Strength": denoising_strength}, + ) + + if initial_seed is None: + initial_seed = seed + initial_info = info + + seed += 1 + + tiledata[2] = output_images[0] + + combined_image = combine_grid(grid) + + grid_count = len(os.listdir(outpath)) - 1 + save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename) + + output_images = [combined_image] + seed = initial_seed + info = initial_info + else: output_images, seed, info = process_images( outpath=outpath, @@ -930,7 +1050,7 @@ def img2img(prompt: str, init_img, ddim_steps: 
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
    """Upscale `image` with one of the configured Real-ESRGAN models.

    `RealESRGAN_model_index` selects an entry of the module-level
    `realesrgan_models` list, and `RealESRGAN_upscaling` is forwarded to the
    upsampler as its `outscale` factor. The image is converted to a numpy
    array for the upsampler and the result is wrapped back into an image.
    """
    model_info = realesrgan_models[RealESRGAN_model_index]

    # A fresh upsampler is constructed on every call from the selected entry.
    upsampler = RealESRGANer(
        scale=model_info.netscale,
        model_path=model_info.location,
        model=model_info.model(),
        half=True,
    )

    pixels = np.array(image)
    # Only the first element of the enhance() result is the upscaled array.
    enhanced = upsampler.enhance(pixels, outscale=RealESRGAN_upscaling)
    return Image.fromarray(enhanced[0])
model=model, - half=True - ) - - upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0] - - image = Image.fromarray(upsampled) + image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index) os.makedirs(outpath, exist_ok=True) base_count = len(os.listdir(outpath)) @@ -1058,7 +1186,9 @@ def create_setting_component(key): if t == str: item = gr.Textbox(label=label, value=fun, lines=1) elif t == int: - if len(labelinfo) == 4: + if len(labelinfo) == 5: + item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun) + elif len(labelinfo) == 4: item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun) else: item = gr.Number(label=label, value=fun)