diff --git a/scripts/textual_inversion.py b/scripts/textual_inversion.py index 97201df..46dc866 100644 --- a/scripts/textual_inversion.py +++ b/scripts/textual_inversion.py @@ -30,7 +30,7 @@ from accelerate.logging import get_logger from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler -from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +#from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from huggingface_hub import HfFolder, Repository, whoami from PIL import Image from torchvision import transforms @@ -48,7 +48,6 @@ def parse_args(): "--pretrained_model_name_or_path", type=str, default=None, - required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( @@ -58,17 +57,16 @@ def parse_args(): help="Pretrained tokenizer name or path if not the same as model_name", ) parser.add_argument( - "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." + "--train_data_dir", type=str, default=None, help="A folder containing the training data." ) parser.add_argument( "--placeholder_token", type=str, default=None, - required=True, help="A token to use as a placeholder for the concept.", ) parser.add_argument( - "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word." + "--initializer_token", type=str, default=None, help="A token to use as initializer word." ) parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") @@ -172,8 +170,56 @@ def parse_args(): ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpoint_frequency", + type=int, + default=500, + help="How often to save a checkpoint and sample image", + ) + parser.add_argument( + "--stable_sample_batches", + type=int, + default=0, + help="Number of fixed seed sample batches to generate per checkpoint", + ) + parser.add_argument( + "--random_sample_batches", + type=int, + default=1, + help="Number of random seed sample batches to generate per checkpoint", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=1, + help="Number of samples to generate per batch", + ) + parser.add_argument( + "--custom_templates", + type=str, + default=None, + help=( + "A comma-delimited list of custom template to use for samples, using {} as a placeholder for the concept." + ), + ) + parser.add_argument( + "--resume_from", + type=str, + default=None, + help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" + ) + parser.add_argument( + "--resume_checkpoint", + type=str, + default=None, + help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)." + ) args = parser.parse_args() + if args.resume_from is not None: + with open(Path(args.resume_from) / "resume.json", 'rt') as f: + args = parser.parse_args(namespace=argparse.Namespace(**json.load(f)["args"])) + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -184,6 +230,59 @@ def parse_args(): return args +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + class TextualInversionDataset(Dataset): def __init__( self, @@ -197,6 +296,7 @@ class TextualInversionDataset(Dataset): set="train", placeholder_token="*", center_crop=False, + templates=None ): self.data_root = data_root @@ -207,7 +307,7 @@ class TextualInversionDataset(Dataset): self.center_crop = center_crop self.flip_p = flip_p - self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root) if file_path.lower().endswith(('.png', '.jpg', '.jpeg'))] self.num_images = len(self.image_paths) self._length = self.num_images @@ -221,62 +321,8 @@ class TextualInversionDataset(Dataset): "bicubic": PIL.Image.BICUBIC, "lanczos": PIL.Image.LANCZOS, }[interpolation] - - imagenet_templates_small = [ - "a photo of a {}", - "a rendering of a {}", - "a cropped photo of the {}", - "the photo of a {}", - "a photo of a clean {}", - "a photo of a dirty {}", - "a dark photo of the {}", - "a photo of my {}", - "a photo of the cool {}", - "a close-up photo of a {}", - "a bright photo of the {}", - "a cropped photo of a {}", - "a photo of the {}", - "a good photo of the {}", - "a photo of one {}", - "a close-up photo of the {}", - "a rendition of the {}", - "a photo of the clean {}", - "a rendition of a {}", - "a photo of a nice {}", - "a good photo of a {}", - "a photo of the nice {}", - "a photo of the small {}", - "a photo of the weird {}", - "a photo of the large {}", - "a photo of a cool {}", - "a photo of a small {}", - ] - - imagenet_style_templates_small = [ - "a painting in the style of {}", - "a rendering in the style of {}", - "a cropped painting in the style of {}", - "the painting in the style of {}", - "a clean painting in the style of {}", - "a dirty painting in the style of {}", - "a dark painting in the style of {}", - "a picture in the style of {}", - "a cool painting in the style of {}", - "a close-up painting in the style of {}", - "a bright painting in the style of {}", - "a cropped painting in the style of {}", - "a good painting in the style of {}", - "a close-up painting in the style of {}", - "a rendition in the style of {}", - "a nice painting in the style of {}", - "a small painting in the style of {}", - "a weird painting in the style of {}", - "a large painting in the style of {}", - ] - - - self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small + self.templates = templates self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) def __len__(self): @@ -337,9 +383,146 @@ def freeze_params(params): param.requires_grad = False +def save_resume_file(basepath, args, extra = {}): + info = {"args": vars(args)} + info["args"].update(extra) + with open(Path(basepath) / "resume.json", "w") as f: + json.dump(info, f, indent=4) + + +class Checkpointer: + def __init__( + self, + accelerator, + vae, + unet, + tokenizer, + placeholder_token, + placeholder_token_id, + templates, + output_dir, + random_sample_batches, + sample_batch_size, + stable_sample_batches, + seed + ): + self.accelerator = accelerator + self.vae = vae + self.unet = unet + self.tokenizer = tokenizer + self.placeholder_token = placeholder_token + self.placeholder_token_id = placeholder_token_id + self.templates = templates + self.output_dir = output_dir + self.random_sample_batches = random_sample_batches + self.sample_batch_size = sample_batch_size + self.stable_sample_batches = stable_sample_batches + self.seed = seed + + + def checkpoint(self, step, text_encoder, save_samples=True): + print("Saving checkpoint for step %d..." % step) + with torch.autocast("cuda"): + checkpoints_path = self.output_dir / "checkpoints" + checkpoints_path.mkdir(exist_ok=True, parents=True) + + unwrapped = self.accelerator.unwrap_model(text_encoder) + + # Save a checkpoint + learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] + learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} + + filename = f"learned_embeds_%s_%d.bin" % (slugify(self.placeholder_token), step) + torch.save(learned_embeds_dict, checkpoints_path / filename) + torch.save(learned_embeds_dict, checkpoints_path / "last.bin") + del unwrapped + return checkpoints_path / "last.bin" + + + def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): + samples_path = self.output_dir / "samples" + samples_path.mkdir(exist_ok=True, parents=True) + checker = NoCheck() + + with torch.autocast("cuda"): + unwrapped = self.accelerator.unwrap_model(text_encoder) + # Save a sample image + pipeline = StableDiffusionPipeline( + text_encoder=unwrapped, + vae=self.vae, + unet=self.unet, + tokenizer=self.tokenizer, + scheduler=PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ), + safety_checker=NoCheck(), + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + ).to('cuda') + pipeline.enable_attention_slicing() + + if self.stable_sample_batches > 0: + stable_latents = torch.randn( + (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), + device=pipeline.device, + generator=torch.Generator(device=pipeline.device).manual_seed(self.seed), + ) + + stable_prompts = [choice.format(self.placeholder_token) for choice in (self.templates * self.sample_batch_size)[:self.sample_batch_size]] + + # Generate and save stable samples + for i in range(0, self.stable_sample_batches): + samples = pipeline( + prompt=stable_prompts, + height=max(512, height), + latents=stable_latents, + width=max(512, width), + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + output_type='pil' + )["sample"] + for idx, im in enumerate(samples): + filename = f"stable_sample_%d_%d_step_%d.png" % (i+1, idx+1, step) + im.save(samples_path / filename) + + prompts = [choice.format(self.placeholder_token) for choice in random.choices(self.templates, k=self.sample_batch_size)] + # Generate and save random samples + for i in range(0, self.random_sample_batches): + samples = pipeline( + prompt=prompts, + height=max(512, height), + width=max(512, width), + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + output_type='pil' + )["sample"] + for idx, im in enumerate(samples): + filename = f"step_%d_sample_%d_%d.png" % (step, i+1, idx+1) + im.save(samples_path / filename) + + del im + del pipeline + del unwrapped + + def main(): args = parse_args() - #logging_dir = os.path.join(args.output_dir, args.logging_dir) + + global_step_offset = 0 + if args.resume_from is not None: + basepath = Path(args.resume_from) + print("Resuming state from %s" % args.resume_from) + with open(basepath / "resume.json", 'r') as f: + state = json.load(f) + global_step_offset = state["args"]["global_step"] + + print("We've trained %d steps so far" % global_step_offset) + else: + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + basepath = Path(args.logging_dir) / slugify(args.placeholder_token) / now + basepath.mkdir(exist_ok=True, parents=True) + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -394,24 +577,35 @@ def main(): # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path + '/text_encoder', + args.pretrained_model_name_or_path + '/text_encoder', ) vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path + '/vae', + args.pretrained_model_name_or_path + '/vae', ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path + '/unet', + args.pretrained_model_name_or_path + '/unet', ) + + base_templates = imagenet_style_templates_small if args.learnable_property == "style" else imagenet_templates_small + if args.custom_templates: + templates = args.custom_templates.split(",") + else: + templates = base_templates + slice_size = unet.config.attention_head_dim // 2 unet.set_attention_slice(slice_size) - #vae = vae.to("cuda").half() - #unet = unet.to("cuda").half() +# vae = vae.to("cuda").half() +#unet = unet.to("cuda").half() # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) # Initialise the newly added placeholder token with the embeddings of the initializer token token_embeds = text_encoder.get_input_embeddings().weight.data - token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] + + if args.resume_checkpoint is not None: + token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token] + else: + token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] # Freeze vae and unet freeze_params(vae.parameters()) @@ -424,6 +618,21 @@ def main(): ) freeze_params(params_to_freeze) + checkpointer = Checkpointer( + accelerator=accelerator, + vae=vae, + unet=unet, + tokenizer=tokenizer, + placeholder_token=args.placeholder_token, + placeholder_token_id=placeholder_token_id, + templates=templates, + output_dir=basepath, + sample_batch_size=args.sample_batch_size, + random_sample_batches=args.random_sample_batches, + stable_sample_batches=args.stable_sample_batches, + seed=args.seed + ) + if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes @@ -452,6 +661,7 @@ def main(): learnable_property=args.learnable_property, center_crop=args.center_crop, set="train", + templates=base_templates ) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) @@ -508,87 +718,116 @@ def main(): progress_bar.set_description("Steps") global_step = 0 - for epoch in range(args.num_train_epochs): - text_encoder.train() - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(text_encoder): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach().half() - latents = latents * 0.18215 + try: + for epoch in range(args.num_train_epochs): + text_encoder.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach().half() + latents = latents * 0.18215 - # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long() + # Sample noise that we'll add to the latents + noise = torch.randn(latents.shape).to(latents.device) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() - accelerator.backward(loss) + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + accelerator.backward(loss) - # Zero out the gradients for all token embeddings except the newly added - # embeddings for the concept, as we only want to optimize the concept embeddings - if accelerator.num_processes > 1: - grads = text_encoder.module.get_input_embeddings().weight.grad - else: - grads = text_encoder.get_input_embeddings().weight.grad - # Get the index for tokens that we want to zero the grads for - index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id - grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) + # Zero out the gradients for all token embeddings except the newly added + # embeddings for the concept, as we only want to optimize the concept embeddings + if accelerator.num_processes > 1: + grads = text_encoder.module.get_input_embeddings().weight.grad + else: + grads = text_encoder.get_input_embeddings().weight.grad + # Get the index for tokens that we want to zero the grads for + index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id + grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - #accelerator.log(logs, step=global_step) + if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: + checkpointer.checkpoint(global_step + global_step_offset, text_encoder) + save_resume_file(basepath, args, { + "global_step": global_step + global_step_offset, + "resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin") + }) + checkpointer.save_samples(global_step + global_step_offset, text_encoder, + args.resolution, args.resolution, 7.5, 0.0, 25) - if global_step >= args.max_train_steps: - break + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + #accelerator.log(logs, step=global_step) - accelerator.wait_for_everyone() + if global_step >= args.max_train_steps: + break - # Create the pipeline using using the trained modules and save it. - if accelerator.is_main_process: - pipeline = StableDiffusionPipeline( - text_encoder=accelerator.unwrap_model(text_encoder), - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ), - safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), - ) - #pipeline.save_pretrained(args.output_dir) - # Also save the newly trained embeddings - learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] - learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} - torch.save(learned_embeds_dict, os.path.join(args.train_data_dir, f"learned_embeds.bin")) + accelerator.wait_for_everyone() - if args.push_to_hub: - repo.push_to_hub( - args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True + # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + pipeline = StableDiffusionPipeline( + text_encoder=accelerator.unwrap_model(text_encoder), + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ), + safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), ) + #pipeline.save_pretrained(args.output_dir) + # Also save the newly trained embeddings + learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] + learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} + torch.save(learned_embeds_dict, basepath / f"learned_embeds.bin") + if global_step % args.checkpoint_frequency != 0: + checkpointer.save_samples(global_step + global_step_offset, text_encoder, + args.resolution, args.resolution, 7.5, 0.0, 25) - accelerator.end_training() + print("Saving resume state") + save_resume_file(basepath, args, { + "global_step": global_step + global_step_offset, + "resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin") + }) + + if args.push_to_hub: + repo.push_to_hub( + args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True) + + accelerator.end_training() + + except KeyboardInterrupt: + if accelerator.is_main_process: + print("Interrupted, saving checkpoint and resume state...") + checkpointer.checkpoint(global_step + global_step_offset, text_encoder) + save_resume_file(basepath, args, { + "global_step": global_step + global_step_offset, + "resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin") + }) + quit() def layout(): - st.write("Textual Inversion") \ No newline at end of file + st.write("Textual Inversion") + st.info("Under Construction. :construction_worker:") \ No newline at end of file