Better support for large batches in optimized mode

This commit is contained in:
Soul-Burn 2022-09-07 21:00:04 +03:00 committed by hlky
parent 53aacef732
commit 3e9cdb1dcb
2 changed files with 8 additions and 9 deletions

View File

@ -38,8 +38,11 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
value=txt2img_defaults['cfg_scale'], elem_id='cfg_slider') value=txt2img_defaults['cfg_scale'], elem_id='cfg_slider')
txt2img_seed = gr.Textbox(label="Seed (blank to randomize)", lines=1, max_lines=1, txt2img_seed = gr.Textbox(label="Seed (blank to randomize)", lines=1, max_lines=1,
value=txt2img_defaults["seed"]) value=txt2img_defaults["seed"])
txt2img_batch_size = gr.Slider(minimum=1, maximum=50, step=1,
label='Images per batch',
value=txt2img_defaults['batch_size'])
txt2img_batch_count = gr.Slider(minimum=1, maximum=50, step=1, txt2img_batch_count = gr.Slider(minimum=1, maximum=50, step=1,
label='Number of images to generate', label='Number of batches to generate',
value=txt2img_defaults['n_iter']) value=txt2img_defaults['n_iter'])
txt2img_job_ui = job_manager.draw_gradio_ui() if job_manager else None txt2img_job_ui = job_manager.draw_gradio_ui() if job_manager else None
@ -93,9 +96,6 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
with gr.TabItem('Advanced'): with gr.TabItem('Advanced'):
txt2img_toggles = gr.CheckboxGroup(label='', choices=txt2img_toggles, txt2img_toggles = gr.CheckboxGroup(label='', choices=txt2img_toggles,
value=txt2img_toggle_defaults, type="index") value=txt2img_toggle_defaults, type="index")
txt2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1,
label='Batch size (how many images are in a batch; memory-hungry)',
value=txt2img_defaults['batch_size'])
txt2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', txt2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model',
choices=['RealESRGAN_x4plus', choices=['RealESRGAN_x4plus',
'RealESRGAN_x4plus_anime_6B'], 'RealESRGAN_x4plus_anime_6B'],

View File

@ -969,11 +969,10 @@ def process_images(
if opt.optimized: if opt.optimized:
modelFS.to(device) modelFS.to(device)
for i in range(len(samples_ddim)):
x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for i, x_sample in enumerate(x_samples_ddim):
sanitized_prompt = prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars}) sanitized_prompt = prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})
if variant_seed != None and variant_seed != '': if variant_seed != None and variant_seed != '':
if variant_amount == 0.0: if variant_amount == 0.0:
@ -1005,7 +1004,7 @@ def process_images(
filename = filename.replace("[SEED]", seed_used) filename = filename.replace("[SEED]", seed_used)
filename = filename.replace("[VARIANT_AMOUNT]", f"{cur_variant_amount:.2f}") filename = filename.replace("[VARIANT_AMOUNT]", f"{cur_variant_amount:.2f}")
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
metadata = ImageMetadata(prompt=prompts[i], seed=seeds[i], height=height, width=width, steps=steps, metadata = ImageMetadata(prompt=prompts[i], seed=seeds[i], height=height, width=width, steps=steps,
cfg_scale=cfg_scale, normalize_prompt_weights=normalize_prompt_weights, denoising_strength=denoising_strength, cfg_scale=cfg_scale, normalize_prompt_weights=normalize_prompt_weights, denoising_strength=denoising_strength,