Variations (#307)

* variation support

WIP

* removed extra )

* fixed warning

Co-authored-by: xra <mail@xra.dev>
Co-authored-by: hlky <106811348+hlky@users.noreply.github.com>
Authored by xraxra on 2022-08-30 03:18:00 -07:00; committed by GitHub
parent 123b21e452
commit 96aba4b36d
2 changed files with 61 additions and 8 deletions


@@ -92,6 +92,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, txt2img_defaul
visible=RealESRGAN is not None) # TODO: Feels like I shouldnt slot it in here.
txt2img_ddim_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA",
value=txt2img_defaults['ddim_eta'], visible=False)
+txt2img_variant_amount = gr.Slider(minimum=0.0, maximum=1.0, label='Variation Amount',value=txt2img_defaults['variant_amount'])
+txt2img_variant_seed = gr.Textbox(label="Variant Seed (blank to randomize)", lines=1, max_lines=1,value=txt2img_defaults["variant_seed"])
txt2img_embeddings = gr.File(label="Embeddings file for textual inversion",
visible=show_embeddings)
@@ -99,14 +101,14 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, txt2img_defaul
txt2img,
[txt2img_prompt, txt2img_steps, txt2img_sampling, txt2img_toggles, txt2img_realesrgan_model_name,
txt2img_ddim_eta, txt2img_batch_count, txt2img_batch_size, txt2img_cfg, txt2img_seed,
-txt2img_height, txt2img_width, txt2img_embeddings],
+txt2img_height, txt2img_width, txt2img_embeddings, txt2img_variant_amount, txt2img_variant_seed],
[output_txt2img_gallery, output_txt2img_seed, output_txt2img_params, output_txt2img_stats]
)
txt2img_prompt.submit(
txt2img,
[txt2img_prompt, txt2img_steps, txt2img_sampling, txt2img_toggles, txt2img_realesrgan_model_name,
txt2img_ddim_eta, txt2img_batch_count, txt2img_batch_size, txt2img_cfg, txt2img_seed,
-txt2img_height, txt2img_width, txt2img_embeddings],
+txt2img_height, txt2img_width, txt2img_embeddings, txt2img_variant_amount, txt2img_variant_seed],
[output_txt2img_gallery, output_txt2img_seed, output_txt2img_params, output_txt2img_stats]
)
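
Note: the two new controls only need to be appended to the inputs lists of the click and submit bindings; Gradio then passes their current values to txt2img as the last two positional arguments. A minimal, self-contained sketch of that wiring pattern (the stub handler and component names such as generate_btn are illustrative, not from this commit):

    import gradio as gr

    def my_txt2img(prompt, variant_amount, variant_seed):
        # stand-in for the real txt2img handler; extra inputs arrive
        # positionally, in the order they appear in the inputs list
        return f"{prompt!r} amount={variant_amount} seed={variant_seed!r}"

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        variant_amount = gr.Slider(minimum=0.0, maximum=1.0, label="Variation Amount", value=0.0)
        variant_seed = gr.Textbox(label="Variant Seed (blank to randomize)", lines=1, max_lines=1, value="")
        out = gr.Textbox(label="Output")
        generate_btn = gr.Button("Generate")
        # appending the new components to the inputs list is all that is needed
        # for their values to reach the handler on every click
        generate_btn.click(my_txt2img, [prompt, variant_amount, variant_seed], [out])

    demo.launch()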


@@ -664,7 +664,8 @@ def process_images(
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
-uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False):
+uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False,
+variant_amount=0.0, variant_seed=None):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None
torch_gc()
@@ -728,6 +729,19 @@ def process_images(
init_data = func_init()
tic = time.time()
+# if variant_amount > 0.0 create noise from base seed
+base_x = None
+if variant_amount > 0.0:
+    target_seed_randomizer = seed_to_int('') # random seed
+    torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration)
+    base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed])
+    # we don't want all_seeds to be sequential from starting seed with variants,
+    # since that makes the same variants each time,
+    # so we add target_seed_randomizer as a random offset
+    for si in range(len(all_seeds)):
+        all_seeds[si] += target_seed_randomizer
for n in range(n_iter):
print(f"Iteration: {n+1}/{n_iter}")
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
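
Note: with variants enabled, the base noise is derived once from the starting seed, and every per-image seed is shifted by one shared random offset so re-running with the same starting seed does not reproduce the exact same set of variants. A standalone sketch of that seed handling, assuming the usual sequential per-image seed assignment and using plain torch/random stand-ins for the repo's seed_to_int and create_random_tensors helpers:

    import random
    import torch

    seed = 1234                     # the single starting seed
    batch = 4
    latent_shape = (4, 64, 64)      # opt_C x (height // opt_f) x (width // opt_f) at 512x512

    # base noise: always a function of the starting seed alone
    torch.manual_seed(seed)
    base_x = torch.randn((1, *latent_shape))

    # per-image seeds are normally sequential from the starting seed...
    all_seeds = [seed + i for i in range(batch)]
    # ...so a shared random offset is added; otherwise the same starting
    # seed would always produce the same set of variants
    target_seed_randomizer = random.randint(0, 2**32 - 1)
    all_seeds = [s + target_seed_randomizer for s in all_seeds]
    print(base_x.shape, all_seeds)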
@@ -766,8 +780,21 @@ def process_images(
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
-# we manually generate all input noises because each one should have a specific seed
-x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
+if variant_amount == 0.0:
+    # we manually generate all input noises because each one should have a specific seed
+    x = create_random_tensors(shape, seeds=seeds)
+else: # we are making variants
+    # using variant_seed as sneaky toggle,
+    # when not None or '' use the variant_seed
+    # otherwise use seeds
+    if variant_seed != None and variant_seed != '':
+        specified_variant_seed = seed_to_int(variant_seed)
+        torch.manual_seed(specified_variant_seed)
+        seeds = [specified_variant_seed]
+    target_x = create_random_tensors(shape, seeds=seeds)
+    # finally, slerp base_x noise to target_x noise for creating a variant
+    x = slerp(device, max(0.0, min(1.0, variant_amount)), base_x, target_x)
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
if opt.optimized:
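
Note: per batch, the target noise comes either from the per-image seeds or, when a variant seed is supplied, from that single seed; the strength is clamped to [0, 1] before interpolating from base_x toward target_x. A self-contained sketch of that selection logic; the helper name make_variant_noise is invented here, and torch.lerp stands in for the commit's slerp (shown further down) purely to keep the snippet short:

    import torch

    def make_variant_noise(base_x, seeds, variant_seed, variant_amount, shape):
        if variant_seed is not None and variant_seed != '':
            # an explicit variant seed acts as a toggle and replaces the per-image seeds
            seeds = [int(variant_seed)]
        noises = []
        for s in seeds:
            torch.manual_seed(s)
            noises.append(torch.randn((1, *shape)))
        target_x = torch.cat(noises)
        t = max(0.0, min(1.0, variant_amount))  # clamp the variation strength
        return torch.lerp(base_x, target_x, t)  # the real code slerps instead of lerping

    x = make_variant_noise(torch.randn(1, 4, 64, 64), [42], '', 0.3, (4, 64, 64))
    print(x.shape)  # torch.Size([1, 4, 64, 64])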
@@ -925,7 +952,7 @@ Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_0
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], realesrgan_model_name: str,
ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None],
-height: int, width: int, fp):
+height: int, width: int, fp, variant_amount: float = None, variant_seed: int = None):
outpath = opt.outdir_txt2img or opt.outdir or "outputs/txt2img-samples"
err = False
seed = seed_to_int(seed)
@@ -992,6 +1019,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
sort_samples=sort_samples,
write_info_files=write_info_files,
jpg_sample=jpg_sample,
+variant_amount=variant_amount,
+variant_seed=variant_seed,
)
del sampler
@@ -1018,7 +1047,7 @@ class Flagging(gr.FlaggingCallback):
os.makedirs("log/images", exist_ok=True)
# those must match the "txt2img" function !! + images, seed, comment, stats !! NOTE: changes to UI output must be reflected here too
-prompt, ddim_steps, sampler_name, toggles, ddim_eta, n_iter, batch_size, cfg_scale, seed, height, width, fp, images, seed, comment, stats = flag_data
+prompt, ddim_steps, sampler_name, toggles, ddim_eta, n_iter, batch_size, cfg_scale, seed, height, width, fp, variant_amount, variant_seed, images, seed, comment, stats = flag_data
filenames = []
@@ -1302,6 +1331,26 @@ def split_weighted_subprompts(text):
remaining = 0
return prompts, weights
+def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995):
+    v0 = v0.detach().cpu().numpy()
+    v1 = v1.detach().cpu().numpy()
+    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+    if np.abs(dot) > DOT_THRESHOLD:
+        v2 = (1 - t) * v0 + t * v1
+    else:
+        theta_0 = np.arccos(dot)
+        sin_theta_0 = np.sin(theta_0)
+        theta_t = theta_0 * t
+        sin_theta_t = np.sin(theta_t)
+        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+        s1 = sin_theta_t / sin_theta_0
+        v2 = s0 * v0 + s1 * v1
+    v2 = torch.from_numpy(v2).to(device)
+    return v2
def run_GFPGAN(image, strength):
image = image.convert("RGB")
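
Note: this slerp is the standard spherical linear interpolation. For near-parallel inputs (|dot| > DOT_THRESHOLD) it falls back to plain lerp; otherwise it returns sin((1-t)*theta)/sin(theta) * v0 + sin(t*theta)/sin(theta) * v1, which follows the arc between the two noise tensors rather than the straight line, so the result keeps roughly the magnitude of Gaussian noise that plain lerp would tend to shrink at intermediate t. A quick sanity check, assuming the slerp definition above (and its numpy/torch imports) is in scope:

    import torch

    device = torch.device('cpu')
    v0 = torch.randn(4, 64, 64)
    v1 = torch.randn(4, 64, 64)

    for t in (0.0, 0.25, 0.5, 1.0):
        v2 = slerp(device, t, v0, v1)
        # t=0 reproduces v0 and t=1 reproduces v1 (up to the numpy round-trip);
        # intermediate t values blend along the arc between them
        print(t, torch.allclose(v2, v0, atol=1e-5), torch.allclose(v2, v1, atol=1e-5))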
@@ -1364,7 +1413,9 @@ txt2img_defaults = {
'height': 512,
'width': 512,
'fp': None,
-'submit_on_enter': 'Yes'
+'variant_amount': 0.0,
+'variant_seed': '',
+'submit_on_enter': 'Yes',
}
if 'txt2img' in user_defaults: