Add stat handling to img2img

This commit is contained in:
netcavy 2022-08-24 23:59:44 +10:00
parent cd46fb54b5
commit 8cd893f279

View File

@ -354,7 +354,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
start_time = time.time() start_time = time.time()
assert prompt is not None assert prompt is not None
torch_gc() torch.cuda.empty_cache()
if seed == -1: if seed == -1:
seed = random.randrange(4294967294) seed = random.randrange(4294967294)
@ -598,6 +598,7 @@ txt2img_interface = gr.Interface(
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int): def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opt.outdir or "outputs/img2img-samples" outpath = opt.outdir or "outputs/img2img-samples"
err = False
sampler = KDiffusionSampler(model) sampler = KDiffusionSampler(model)
@ -630,7 +631,8 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False) samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
return samples_ddim return samples_ddim
output_images, seed, info = process_images( try:
output_images, seed, info, stats = process_images(
outpath=outpath, outpath=outpath,
func_init=init, func_init=init,
func_sample=sample, func_sample=sample,
@ -649,7 +651,13 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
del sampler del sampler
return output_images, seed, info return output_images, seed, info, stats
except RuntimeError as e:
err = e
return [], f'CRASHED:<br><textarea rows="5" style="background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
finally:
if err:
crash(err, '!!Runtime error (dream)!!')
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"