mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-14 22:53:25 +03:00
Added TI training optimizations
option to use xattention optimizations when training option to unload vae when training
This commit is contained in:
parent
700162a603
commit
006756f9cd
@ -256,11 +256,12 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
|
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||||
|
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
|
@ -214,6 +214,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
|
|
||||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
||||||
|
unload = shared.opts.unload_models_when_training
|
||||||
|
|
||||||
if save_embedding_every > 0:
|
if save_embedding_every > 0:
|
||||||
embedding_dir = os.path.join(log_directory, "embeddings")
|
embedding_dir = os.path.join(log_directory, "embeddings")
|
||||||
@ -238,6 +239,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
hijack = sd_hijack.model_hijack
|
hijack = sd_hijack.model_hijack
|
||||||
|
|
||||||
@ -303,6 +306,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
if images_dir is not None and steps_done % create_image_every == 0:
|
if images_dir is not None and steps_done % create_image_every == 0:
|
||||||
forced_filename = f'{embedding_name}-{steps_done}'
|
forced_filename = f'{embedding_name}-{steps_done}'
|
||||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||||
|
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
do_not_save_grid=True,
|
do_not_save_grid=True,
|
||||||
@ -330,6 +336,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
image = processed.images[0]
|
image = processed.images[0]
|
||||||
|
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
shared.state.current_image = image
|
shared.state.current_image = image
|
||||||
|
|
||||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||||
|
@ -25,8 +25,10 @@ def train_embedding(*args):
|
|||||||
|
|
||||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
||||||
|
|
||||||
|
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||||
try:
|
try:
|
||||||
sd_hijack.undo_optimizations()
|
if not apply_optimizations:
|
||||||
|
sd_hijack.undo_optimizations()
|
||||||
|
|
||||||
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
||||||
|
|
||||||
@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
sd_hijack.apply_optimizations()
|
if not apply_optimizations:
|
||||||
|
sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user