Add simple templating

This commit is contained in:
Oxlamon 2022-08-27 18:27:13 +03:00 committed by hlky
parent 54a24088aa
commit 82770bacae

View File

@ -1,4 +1,4 @@
import argparse, os, sys, glob
import argparse, os, sys, glob, re
parser = argparse.ArgumentParser()
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
@ -345,7 +345,7 @@ def load_embeddings(fp):
if fp is not None and hasattr(model, "embedding_manager"):
model.embedding_manager.load(fp.name)
def image_grid(imgs, batch_size, force_n_rows=None):
def image_grid(imgs, batch_size, force_n_rows=None, captions=None):
if force_n_rows is not None:
rows = force_n_rows
elif opt.n_rows > 0:
@ -361,8 +361,14 @@ def image_grid(imgs, batch_size, force_n_rows=None):
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
fnt = ImageFont.truetype("arial.ttf", 30)
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
if captions:
d = ImageDraw.Draw( grid )
size = d.textbbox( (0,0), captions[i], font=fnt, stroke_width=2, align="center" )
d.multiline_text((i % cols * w + w/2, i // cols * h + h - size[3]), captions[i], font=fnt, fill=(255,0,255), stroke_width=2, stroke_fill=(0,0,0), anchor="mm", align="center")
return grid
@ -585,6 +591,57 @@ def get_next_sequence_number(path, prefix=''):
pass
return result + 1
def oxlamon_matrix(prompt, seed, batch_size):
pattern = re.compile(r'(,\s){2,}')
class PromptItem:
def __init__(self, text, parts, item):
self.text = text
self.parts = parts
if item:
self.parts.append( item )
def clean(txt):
return re.sub(pattern, ', ', txt)
def repliter( txt ):
for data in re.finditer( ".*?\\((.*?)\\).*", txt ):
if data:
r = data.span(1)
for item in data.group(1).split("|"):
yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip())
break
def iterlist( items ):
outitems = []
for item in items:
for newitem, newpart in repliter(item.text):
outitems.append( PromptItem(newitem, item.parts.copy(), newpart) )
return outitems
def getmatrix( prompt ):
dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ]
while True:
newdataitems = iterlist( dataitems )
if len( newdataitems ) == 0:
return dataitems
dataitems = newdataitems
def classToArrays( items ):
texts = []
parts = []
for item in items:
texts.append( item.text )
parts.append( "\n".join(item.parts) )
return texts, parts
all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ))
n_iter = math.ceil(len(all_prompts) / batch_size)
all_seeds = len(all_prompts) * [seed]
return all_seeds, n_iter, prompt_matrix_parts, all_prompts
def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
@ -613,20 +670,23 @@ def process_images(
prompt_matrix_parts = []
if prompt_matrix:
all_prompts = []
prompt_matrix_parts = prompt.split("|")
combination_count = 2 ** (len(prompt_matrix_parts) - 1)
for combination_num in range(combination_count):
current = prompt_matrix_parts[0]
if prompt.startswith("@"):
all_seeds, n_iter, prompt_matrix_parts, all_prompts = oxlamon_matrix(prompt, seed, batch_size)
else:
all_prompts = []
prompt_matrix_parts = prompt.split("|")
combination_count = 2 ** (len(prompt_matrix_parts) - 1)
for combination_num in range(combination_count):
current = prompt_matrix_parts[0]
for n, text in enumerate(prompt_matrix_parts[1:]):
if combination_num & (2 ** n) > 0:
current += ("" if text.strip().startswith(",") else ", ") + text
for n, text in enumerate(prompt_matrix_parts[1:]):
if combination_num & (2 ** n) > 0:
current += ("" if text.strip().startswith(",") else ", ") + text
all_prompts.append(current)
all_prompts.append(current)
n_iter = math.ceil(len(all_prompts) / batch_size)
all_seeds = len(all_prompts) * [seed]
n_iter = math.ceil(len(all_prompts) / batch_size)
all_seeds = len(all_prompts) * [seed]
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
else:
@ -650,6 +710,7 @@ def process_images(
tic = time.time()
for n in range(n_iter):
print(f"Iteration: {n+1}/{n_iter}")
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
@ -773,11 +834,11 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
if (prompt_matrix or not skip_grid) and not do_not_save_grid:
if prompt_matrix:
grid = image_grid(output_images, batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
grid = image_grid(output_images, batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2), captions=prompt_matrix_parts)
else:
grid = image_grid(output_images, batch_size)
if prompt_matrix:
if prompt_matrix and not prompt.startswith("@"):
try:
grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
except:
@ -785,9 +846,9 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
print("Error creating prompt_matrix text:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
output_images.insert(0, grid)
else:
grid = image_grid(output_images, batch_size)
output_images.insert(0, grid)
#else:
# grid = image_grid(output_images, batch_size)
grid_count = get_next_sequence_number(outpath, 'grid-')
grid_file = f"grid-{grid_count:05}-{seed}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.{grid_ext}"