From bd68e35de3b7cf7547ed97d8bdf60147402133cc Mon Sep 17 00:00:00 2001 From: flamelaw Date: Sun, 20 Nov 2022 12:35:26 +0900 Subject: [PATCH] Gradient accumulation, autocast fix, new latent sampling method, etc --- modules/hypernetworks/hypernetwork.py | 251 ++++++++-------- modules/sd_hijack.py | 9 +- modules/sd_hijack_checkpoint.py | 10 + modules/shared.py | 3 +- modules/textual_inversion/dataset.py | 120 +++++--- .../textual_inversion/textual_inversion.py | 272 ++++++++++-------- modules/ui.py | 16 +- 7 files changed, 408 insertions(+), 273 deletions(-) create mode 100644 modules/sd_hijack_checkpoint.py diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index fbb87dd1..3d3301b0 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -367,13 +367,13 @@ def report_statistics(loss_info:dict): -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 - textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() @@ -403,28 +403,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log hypernetwork = shared.loaded_hypernetwork checkpoint = sd_models.select_checkpoint() - ititial_step = hypernetwork.step or 0 - if ititial_step >= steps: + initial_step = hypernetwork.step or 0 + if initial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return hypernetwork, filename - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + + pin_memory = shared.opts.pin_memory + + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=pin_memory) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - - size = len(ds.indexes) - loss_dict = defaultdict(lambda : deque(maxlen = 1024)) - losses = torch.zeros((size,)) - previous_mean_losses = [0] - previous_mean_loss = 0 - print("Mean loss of {} elements".format(size)) weights = hypernetwork.weights() for weight in weights: @@ -436,8 +432,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log optimizer_name = hypernetwork.optimizer_name else: print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") - optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) - optimizer_name = 'AdamW' + optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) + optimizer_name = 'AdamW' if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. try: @@ -446,131 +442,155 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log print("Cannot resume from saved optimizer!") print(e) + scaler = torch.cuda.amp.GradScaler() + + batch_size = ds.batch_size + gradient_step = ds.gradient_step + # n steps = batch_size * gradient_step * n image processed + steps_per_epoch = len(ds) // batch_size // gradient_step + max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step + loss_step = 0 + _loss_step = 0 #internal + # size = len(ds.indexes) + # loss_dict = defaultdict(lambda : deque(maxlen = 1024)) + # losses = torch.zeros((size,)) + # previous_mean_losses = [0] + # previous_mean_loss = 0 + # print("Mean loss of {} elements".format(size)) + steps_without_grad = 0 last_saved_file = "" last_saved_image = "" forced_filename = "" - pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, entries in pbar: - hypernetwork.step = i + ititial_step - if len(loss_dict) > 0: - previous_mean_losses = [i[-1] for i in loss_dict.values()] - previous_mean_loss = mean(previous_mean_losses) - - scheduler.apply(optimizer, hypernetwork.step) - if scheduler.finished: - break + pbar = tqdm.tqdm(total=steps - initial_step) + try: + for i in range((steps-initial_step) * gradient_step): + if scheduler.finished: + break + if shared.state.interrupted: + break + for j, batch in enumerate(dl): + # works as a drop_last=True for gradient accumulation + if j == max_steps_per_epoch: + break + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break + if shared.state.interrupted: + break - if shared.state.interrupted: - break + with torch.autocast("cuda"): + x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) + if tag_drop_out != 0 or shuffle_tags: + shared.sd_model.cond_stage_model.to(devices.device) + c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory) + shared.sd_model.cond_stage_model.to(devices.cpu) + else: + c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) + loss = shared.sd_model(x, c)[0] / gradient_step + del x + del c - with torch.autocast("cuda"): - c = stack_conds([entry.cond for entry in entries]).to(devices.device) - # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - del c + _loss_step += loss.item() + scaler.scale(loss).backward() + # go back until we reach gradient accumulation steps + if (j + 1) % gradient_step != 0: + continue + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}") + # scaler.unscale_(optimizer) + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") + # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") + scaler.step(optimizer) + scaler.update() + hypernetwork.step += 1 + pbar.update() + optimizer.zero_grad(set_to_none=True) + loss_step = _loss_step + _loss_step = 0 - losses[hypernetwork.step % losses.shape[0]] = loss.item() - for entry in entries: - loss_dict[entry.filename].append(loss.item()) + steps_done = hypernetwork.step + 1 - optimizer.zero_grad() - weights[0].grad = None - loss.backward() + epoch_num = hypernetwork.step // steps_per_epoch + epoch_step = hypernetwork.step % steps_per_epoch - if weights[0].grad is None: - steps_without_grad += 1 - else: - steps_without_grad = 0 - assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: + # Before saving, change name to match current checkpoint. + hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') + hypernetwork.optimizer_name = optimizer_name + if shared.opts.save_optimizer_state: + hypernetwork.optimizer_state_dict = optimizer.state_dict() + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) + hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. - optimizer.step() + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, { + "loss": f"{loss_step:.7f}", + "learn_rate": scheduler.learn_rate + }) - steps_done = hypernetwork.step + 1 + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{hypernetwork_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) - if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): - raise RuntimeError("Loss diverged.") - - if len(previous_mean_losses) > 1: - std = stdev(previous_mean_losses) - else: - std = 0 - dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" - pbar.set_description(dataset_loss_info) + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) - if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: - # Before saving, change name to match current checkpoint. - hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') - hypernetwork.optimizer_name = optimizer_name - if shared.opts.save_optimizer_state: - hypernetwork.optimizer_state_dict = optimizer.state_dict() - save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) - hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) - textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{previous_mean_loss:.7f}", - "learn_rate": scheduler.learn_rate - }) + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = batch.cond_text[0] + p.steps = 20 + p.width = training_width + p.height = training_height - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{hypernetwork_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) + preview_text = p.prompt - optimizer.zero_grad() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - preview_text = p.prompt + shared.state.job_no = hypernetwork.step - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None - - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) - - if image is not None: - shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" - - shared.state.job_no = hypernetwork.step - - shared.state.textinfo = f""" + shared.state.textinfo = f"""

-Loss: {previous_mean_loss:.7f}
+Loss: {loss_step:.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(entries[0].cond_text)}
+Last prompt: {html.escape(batch.cond_text[0])}
Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - - report_statistics(loss_dict) + except Exception: + print(traceback.format_exc(), file=sys.stderr) + finally: + pbar.leave = False + pbar.close() + #report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') hypernetwork.optimizer_name = optimizer_name @@ -579,6 +599,9 @@ Last saved image: {html.escape(last_saved_image)}
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) del optimizer hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + return hypernetwork, filename def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..29c8b561 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,7 +8,7 @@ from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared +from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.shared import opts, device, cmd_opts from modules.sd_hijack_optimizations import invokeAI_mps_available @@ -59,6 +59,10 @@ def undo_optimizations(): def get_target_prompt_token_count(token_count): return math.ceil(max(token_count, 1) / 75) * 75 +def fix_checkpoint(): + ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward class StableDiffusionModelHijack: fixes = None @@ -78,6 +82,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model apply_optimizations() + fix_checkpoint() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -303,7 +308,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - + self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py new file mode 100644 index 00000000..5712972f --- /dev/null +++ b/modules/sd_hijack_checkpoint.py @@ -0,0 +1,10 @@ +from torch.utils.checkpoint import checkpoint + +def BasicTransformerBlock_forward(self, x, context=None): + return checkpoint(self._forward, x, context) + +def AttentionBlock_forward(self, x): + return checkpoint(self._forward, x) + +def ResBlock_forward(self, x, emb): + return checkpoint(self._forward, x, emb) \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index a4457305..3704ce23 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -322,8 +322,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), - "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."), - "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}), + "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index eb75c376..d594b49d 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -3,7 +3,7 @@ import numpy as np import PIL import torch from PIL import Image -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader from torchvision import transforms import random @@ -11,25 +11,28 @@ import tqdm from modules import devices, shared import re +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, latent=None, filename_text=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): self.filename = filename - self.latent = latent self.filename_text = filename_text - self.cond = None - self.cond_text = None + self.latent_dist = latent_dist + self.latent_sample = latent_sample + self.cond = cond + self.cond_text = cond_text + self.pixel_values = pixel_values class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None - + self.placeholder_token = placeholder_token - self.batch_size = batch_size self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) @@ -45,11 +48,16 @@ class PersonalizedBase(Dataset): assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - cond_model = shared.sd_model.cond_stage_model - self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + + + self.shuffle_tags = shuffle_tags + self.tag_drop_out = tag_drop_out + print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): + if shared.state.interrupted: + raise Exception("inturrupted") try: image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) except Exception: @@ -71,37 +79,58 @@ class PersonalizedBase(Dataset): npimage = np.array(image).astype(np.uint8) npimage = (npimage / 127.5 - 1.0).astype(np.float32) - torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) - torchdata = torch.moveaxis(torchdata, 2, 0) + torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) + latent_sample = None - init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() - init_latent = init_latent.to(devices.cpu) + with torch.autocast("cuda"): + latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) + if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + latent_sampling_method = "once" + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + elif latent_sampling_method == "deterministic": + # Works only for DiagonalGaussianDistribution + latent_dist.std = 0 + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + elif latent_sampling_method == "random": + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) - if include_cond: + if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) - entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) + + if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): + with torch.autocast("cuda"): + entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) + # elif not include_cond: + # _, _, _, _, hijack_fixes, token_count = cond_model.process_text([entry.cond_text]) + # max_n = token_count // 75 + # index_list = [ [] for _ in range(max_n + 1) ] + # for n, (z, _) in hijack_fixes[0]: + # index_list[n].append(z) + # with torch.autocast("cuda"): + # entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) + # entry.emb_index = index_list self.dataset.append(entry) + del torchdata + del latent_dist + del latent_sample - assert len(self.dataset) > 0, "No images have been found in the dataset." - self.length = len(self.dataset) * repeats // batch_size - - self.dataset_length = len(self.dataset) - self.indexes = None - self.shuffle() - - def shuffle(self): - self.indexes = np.random.permutation(self.dataset_length) + self.length = len(self.dataset) + assert self.length > 0, "No images have been found in the dataset." + self.batch_size = min(batch_size, self.length) + self.gradient_step = min(gradient_step, self.length // self.batch_size) + self.latent_sampling_method = latent_sampling_method def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) tags = filename_text.split(',') - if shared.opts.tag_drop_out != 0: - tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] - if shared.opts.shuffle_tags: + if self.tag_drop_out != 0: + tags = [t for t in tags if random.random() > self.tag_drop_out] + if self.shuffle_tags: random.shuffle(tags) text = text.replace("[filewords]", ','.join(tags)) return text @@ -110,19 +139,28 @@ class PersonalizedBase(Dataset): return self.length def __getitem__(self, i): - res = [] + entry = self.dataset[i] + if self.tag_drop_out != 0 or self.shuffle_tags: + entry.cond_text = self.create_text(entry.filename_text) + if self.latent_sampling_method == "random": + entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist) + return entry - for j in range(self.batch_size): - position = i * self.batch_size + j - if position % len(self.indexes) == 0: - self.shuffle() +class PersonalizedDataLoader(DataLoader): + def __init__(self, *args, **kwargs): + super(PersonalizedDataLoader, self).__init__(shuffle=True, drop_last=True, *args, **kwargs) + self.collate_fn = collate_wrapper + - index = self.indexes[position % len(self.indexes)] - entry = self.dataset[index] +class BatchLoader: + def __init__(self, data): + self.cond_text = [entry.cond_text for entry in data] + self.cond = [entry.cond for entry in data] + self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) - if entry.cond is None: - entry.cond_text = self.create_text(entry.filename_text) + def pin_memory(self): + self.latent_sample = self.latent_sample.pin_memory() + return self - res.append(entry) - - return res +def collate_wrapper(batch): + return BatchLoader(batch) \ No newline at end of file diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5e4d8688..1d5e3a32 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -184,7 +184,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return - if (step + 1) % shared.opts.training_write_csv_every != 0: + if step % shared.opts.training_write_csv_every != 0: return write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True @@ -194,21 +194,23 @@ def write_loss(log_directory, filename, step, epoch_len, values): if write_csv_header: csv_writer.writeheader() - epoch = step // epoch_len - epoch_step = step % epoch_len + epoch = (step - 1) // epoch_len + epoch_step = (step - 1) % epoch_len csv_writer.writerow({ - "step": step + 1, + "step": step, "epoch": epoch, - "epoch_step": epoch_step + 1, + "epoch_step": epoch_step, **values, }) -def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): +def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" assert isinstance(batch_size, int), "Batch size must be integer" assert batch_size > 0, "Batch size must be positive" + assert isinstance(gradient_step, int), "Gradient accumulation step must be integer" + assert gradient_step > 0, "Gradient accumulation step must be positive" assert data_root, "Dataset directory is empty" assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" @@ -224,10 +226,10 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 - validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps @@ -255,161 +257,205 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc else: images_embeds_dir = None - cond_model = shared.sd_model.cond_stage_model - hijack = sd_hijack.model_hijack embedding = hijack.embedding_db.word_embeddings[embedding_name] checkpoint = sd_models.select_checkpoint() - ititial_step = embedding.step or 0 - if ititial_step >= steps: + initial_step = embedding.step or 0 + if initial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return embedding, filename + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - - # dataset loading may take a while, so input validations and early returns should be done before this + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + + pin_memory = shared.opts.pin_memory + + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + + latent_sampling_method = ds.latent_sampling_method + + dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=False) + if unload: shared.sd_model.first_stage_model.to(devices.cpu) embedding.vec.requires_grad = True optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) + scaler = torch.cuda.amp.GradScaler() - losses = torch.zeros((32,)) + batch_size = ds.batch_size + gradient_step = ds.gradient_step + # n steps = batch_size * gradient_step * n image processed + steps_per_epoch = len(ds) // batch_size // gradient_step + max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step + loss_step = 0 + _loss_step = 0 #internal + last_saved_file = "" last_saved_image = "" forced_filename = "" embedding_yet_to_be_embedded = False + + pbar = tqdm.tqdm(total=steps - initial_step) + try: + for i in range((steps-initial_step) * gradient_step): + if scheduler.finished: + break + if shared.state.interrupted: + break + for j, batch in enumerate(dl): + # works as a drop_last=True for gradient accumulation + if j == max_steps_per_epoch: + break + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break + if shared.state.interrupted: + break - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, entries in pbar: - embedding.step = i + ititial_step + with torch.autocast("cuda"): + # c = stack_conds(batch.cond).to(devices.device) + # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory) + # print(mask) + # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory) + x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) + c = shared.sd_model.cond_stage_model(batch.cond_text) + loss = shared.sd_model(x, c)[0] / gradient_step + del x + + _loss_step += loss.item() + scaler.scale(loss).backward() + + # go back until we reach gradient accumulation steps + if (j + 1) % gradient_step != 0: + continue + #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") + #scaler.unscale_(optimizer) + #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") + #torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=1.0) + #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") + scaler.step(optimizer) + scaler.update() + embedding.step += 1 + pbar.update() + optimizer.zero_grad(set_to_none=True) + loss_step = _loss_step + _loss_step = 0 - scheduler.apply(optimizer, embedding.step) - if scheduler.finished: - break + steps_done = embedding.step + 1 - if shared.state.interrupted: - break + epoch_num = embedding.step // steps_per_epoch + epoch_step = embedding.step % steps_per_epoch - with torch.autocast("cuda"): - c = cond_model([entry.cond_text for entry in entries]) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding_name_every = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') + #if shared.opts.save_optimizer_state: + #embedding.optimizer_state_dict = optimizer.state_dict() + save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) + embedding_yet_to_be_embedded = True - losses[embedding.step % losses.shape[0]] = loss.item() + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { + "loss": f"{loss_step:.7f}", + "learn_rate": scheduler.learn_rate + }) - optimizer.zero_grad() - loss.backward() - optimizer.step() + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) - steps_done = embedding.step + 1 + shared.sd_model.first_stage_model.to(devices.device) - epoch_num = embedding.step // len(ds) - epoch_step = embedding.step % len(ds) + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = batch.cond_text[0] + p.steps = 20 + p.width = training_width + p.height = training_height - if embedding_dir is not None and steps_done % save_embedding_every == 0: - # Before saving, change name to match current checkpoint. - embedding_name_every = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') - save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) - embedding_yet_to_be_embedded = True + preview_text = p.prompt - write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate - }) + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{embedding_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) + if unload: + shared.sd_model.first_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.device) + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - do_not_reload_embeddings=True, - ) + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - p.width = training_width - p.height = training_height + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') - preview_text = p.prompt + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) - processed = processing.process_images(p) - image = processed.images[0] + title = "<{}>".format(data.get('name', '???')) - if unload: - shared.sd_model.first_stage_model.to(devices.cpu) + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' - shared.state.current_image = image + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) - if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False - info = PngImagePlugin.PngInfo() - data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embedding_to_b64(data)) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - title = "<{}>".format(data.get('name', '???')) + shared.state.job_no = embedding.step - try: - vectorSize = list(data['string_to_param'].values())[0].shape[0] - except Exception as e: - vectorSize = '?' - - checkpoint = sd_models.select_checkpoint() - footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, steps_done) - - captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) - captioned_image = insert_image_data_embed(captioned_image, data) - - captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - embedding_yet_to_be_embedded = False - - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" - - shared.state.job_no = embedding.step - - shared.state.textinfo = f""" + shared.state.textinfo = f"""

-Loss: {losses.mean():.7f}
+Loss: {loss_step:.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entries[0].cond_text)}
+Last prompt: {html.escape(batch.cond_text[0])}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - - filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') - save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) - shared.sd_model.first_stage_model.to(devices.device) + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) + except Exception: + print(traceback.format_exc(), file=sys.stderr) + pass + finally: + pbar.leave = False + pbar.close() + shared.sd_model.first_stage_model.to(devices.device) return embedding, filename diff --git a/modules/ui.py b/modules/ui.py index a5953fce..9d2a1cbf 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1262,7 +1262,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): with gr.Row(): interrupt_preprocessing = gr.Button("Interrupt") - run_preprocess = gr.Button(value="Preprocess", variant='primary') + run_preprocess = gr.Button(value="Preprocess", variant='primary') process_split.change( fn=lambda show: gr_show(show), @@ -1289,6 +1289,7 @@ def create_ui(wrap_gradio_gpu_call): hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") batch_size = gr.Number(label='Batch size', value=1, precision=0) + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) @@ -1299,6 +1300,11 @@ def create_ui(wrap_gradio_gpu_call): save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + with gr.Row(): + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) + with gr.Row(): + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) with gr.Row(): interrupt_training = gr.Button(value="Interrupt") @@ -1387,11 +1393,15 @@ def create_ui(wrap_gradio_gpu_call): train_embedding_name, embedding_learn_rate, batch_size, + gradient_step, dataset_directory, log_directory, training_width, training_height, steps, + shuffle_tags, + tag_drop_out, + latent_sampling_method, create_image_every, save_embedding_every, template_file, @@ -1412,11 +1422,15 @@ def create_ui(wrap_gradio_gpu_call): train_hypernetwork_name, hypernetwork_learn_rate, batch_size, + gradient_step, dataset_directory, log_directory, training_width, training_height, steps, + shuffle_tags, + tag_drop_out, + latent_sampling_method, create_image_every, save_embedding_every, template_file,