Add stat handling to img2img

This commit is contained in:
netcavy 2022-08-24 23:59:44 +10:00
parent cd46fb54b5
commit 8cd893f279

View File

@ -354,7 +354,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
start_time = time.time()
assert prompt is not None
torch_gc()
torch.cuda.empty_cache()
if seed == -1:
seed = random.randrange(4294967294)
@ -598,6 +598,7 @@ txt2img_interface = gr.Interface(
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opt.outdir or "outputs/img2img-samples"
err = False
sampler = KDiffusionSampler(model)
@ -630,26 +631,33 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
return samples_ddim
output_images, seed, info = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name='k-diffusion',
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN
)
try:
output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name='k-diffusion',
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN
)
del sampler
del sampler
return output_images, seed, info
return output_images, seed, info, stats
except RuntimeError as e:
err = e
return [], f'CRASHED:<br><textarea rows="5" style="background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
finally:
if err:
crash(err, '!!Runtime error (dream)!!')
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"