Added message to tell user that the textual inversion page is under construction instead of having it empty.

This commit is contained in:
ZeroCool940711 2022-09-24 05:28:54 -07:00
parent b319bf8ec7
commit 37b2090844

View File

@ -30,7 +30,7 @@ from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
#from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
@ -48,7 +48,6 @@ def parse_args():
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
@ -58,17 +57,16 @@ def parse_args():
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
"--train_data_dir", type=str, default=None, help="A folder containing the training data."
)
parser.add_argument(
"--placeholder_token",
type=str,
default=None,
required=True,
help="A token to use as a placeholder for the concept.",
)
parser.add_argument(
"--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
"--initializer_token", type=str, default=None, help="A token to use as initializer word."
)
parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
@ -172,8 +170,56 @@ def parse_args():
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--checkpoint_frequency",
type=int,
default=500,
help="How often to save a checkpoint and sample image",
)
parser.add_argument(
"--stable_sample_batches",
type=int,
default=0,
help="Number of fixed seed sample batches to generate per checkpoint",
)
parser.add_argument(
"--random_sample_batches",
type=int,
default=1,
help="Number of random seed sample batches to generate per checkpoint",
)
parser.add_argument(
"--sample_batch_size",
type=int,
default=1,
help="Number of samples to generate per batch",
)
parser.add_argument(
"--custom_templates",
type=str,
default=None,
help=(
"A comma-delimited list of custom template to use for samples, using {} as a placeholder for the concept."
),
)
parser.add_argument(
"--resume_from",
type=str,
default=None,
help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
)
parser.add_argument(
"--resume_checkpoint",
type=str,
default=None,
help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
)
args = parser.parse_args()
if args.resume_from is not None:
with open(Path(args.resume_from) / "resume.json", 'rt') as f:
args = parser.parse_args(namespace=argparse.Namespace(**json.load(f)["args"]))
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@ -184,6 +230,59 @@ def parse_args():
return args
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
class TextualInversionDataset(Dataset):
def __init__(
self,
@ -197,6 +296,7 @@ class TextualInversionDataset(Dataset):
set="train",
placeholder_token="*",
center_crop=False,
templates=None
):
self.data_root = data_root
@ -207,7 +307,7 @@ class TextualInversionDataset(Dataset):
self.center_crop = center_crop
self.flip_p = flip_p
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root) if file_path.lower().endswith(('.png', '.jpg', '.jpeg'))]
self.num_images = len(self.image_paths)
self._length = self.num_images
@ -221,62 +321,8 @@ class TextualInversionDataset(Dataset):
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
self.templates = templates
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
def __len__(self):
@ -337,9 +383,146 @@ def freeze_params(params):
param.requires_grad = False
def save_resume_file(basepath, args, extra = {}):
info = {"args": vars(args)}
info["args"].update(extra)
with open(Path(basepath) / "resume.json", "w") as f:
json.dump(info, f, indent=4)
class Checkpointer:
def __init__(
self,
accelerator,
vae,
unet,
tokenizer,
placeholder_token,
placeholder_token_id,
templates,
output_dir,
random_sample_batches,
sample_batch_size,
stable_sample_batches,
seed
):
self.accelerator = accelerator
self.vae = vae
self.unet = unet
self.tokenizer = tokenizer
self.placeholder_token = placeholder_token
self.placeholder_token_id = placeholder_token_id
self.templates = templates
self.output_dir = output_dir
self.random_sample_batches = random_sample_batches
self.sample_batch_size = sample_batch_size
self.stable_sample_batches = stable_sample_batches
self.seed = seed
def checkpoint(self, step, text_encoder, save_samples=True):
print("Saving checkpoint for step %d..." % step)
with torch.autocast("cuda"):
checkpoints_path = self.output_dir / "checkpoints"
checkpoints_path.mkdir(exist_ok=True, parents=True)
unwrapped = self.accelerator.unwrap_model(text_encoder)
# Save a checkpoint
learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
filename = f"learned_embeds_%s_%d.bin" % (slugify(self.placeholder_token), step)
torch.save(learned_embeds_dict, checkpoints_path / filename)
torch.save(learned_embeds_dict, checkpoints_path / "last.bin")
del unwrapped
return checkpoints_path / "last.bin"
def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
samples_path = self.output_dir / "samples"
samples_path.mkdir(exist_ok=True, parents=True)
checker = NoCheck()
with torch.autocast("cuda"):
unwrapped = self.accelerator.unwrap_model(text_encoder)
# Save a sample image
pipeline = StableDiffusionPipeline(
text_encoder=unwrapped,
vae=self.vae,
unet=self.unet,
tokenizer=self.tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=NoCheck(),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to('cuda')
pipeline.enable_attention_slicing()
if self.stable_sample_batches > 0:
stable_latents = torch.randn(
(self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
device=pipeline.device,
generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
)
stable_prompts = [choice.format(self.placeholder_token) for choice in (self.templates * self.sample_batch_size)[:self.sample_batch_size]]
# Generate and save stable samples
for i in range(0, self.stable_sample_batches):
samples = pipeline(
prompt=stable_prompts,
height=max(512, height),
latents=stable_latents,
width=max(512, width),
guidance_scale=guidance_scale,
eta=eta,
num_inference_steps=num_inference_steps,
output_type='pil'
)["sample"]
for idx, im in enumerate(samples):
filename = f"stable_sample_%d_%d_step_%d.png" % (i+1, idx+1, step)
im.save(samples_path / filename)
prompts = [choice.format(self.placeholder_token) for choice in random.choices(self.templates, k=self.sample_batch_size)]
# Generate and save random samples
for i in range(0, self.random_sample_batches):
samples = pipeline(
prompt=prompts,
height=max(512, height),
width=max(512, width),
guidance_scale=guidance_scale,
eta=eta,
num_inference_steps=num_inference_steps,
output_type='pil'
)["sample"]
for idx, im in enumerate(samples):
filename = f"step_%d_sample_%d_%d.png" % (step, i+1, idx+1)
im.save(samples_path / filename)
del im
del pipeline
del unwrapped
def main():
args = parse_args()
#logging_dir = os.path.join(args.output_dir, args.logging_dir)
global_step_offset = 0
if args.resume_from is not None:
basepath = Path(args.resume_from)
print("Resuming state from %s" % args.resume_from)
with open(basepath / "resume.json", 'r') as f:
state = json.load(f)
global_step_offset = state["args"]["global_step"]
print("We've trained %d steps so far" % global_step_offset)
else:
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
basepath = Path(args.logging_dir) / slugify(args.placeholder_token) / now
basepath.mkdir(exist_ok=True, parents=True)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
@ -394,24 +577,35 @@ def main():
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path + '/text_encoder',
args.pretrained_model_name_or_path + '/text_encoder',
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path + '/vae',
args.pretrained_model_name_or_path + '/vae',
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path + '/unet',
args.pretrained_model_name_or_path + '/unet',
)
base_templates = imagenet_style_templates_small if args.learnable_property == "style" else imagenet_templates_small
if args.custom_templates:
templates = args.custom_templates.split(",")
else:
templates = base_templates
slice_size = unet.config.attention_head_dim // 2
unet.set_attention_slice(slice_size)
#vae = vae.to("cuda").half()
#unet = unet.to("cuda").half()
# vae = vae.to("cuda").half()
#unet = unet.to("cuda").half()
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
if args.resume_checkpoint is not None:
token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
else:
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
# Freeze vae and unet
freeze_params(vae.parameters())
@ -424,6 +618,21 @@ def main():
)
freeze_params(params_to_freeze)
checkpointer = Checkpointer(
accelerator=accelerator,
vae=vae,
unet=unet,
tokenizer=tokenizer,
placeholder_token=args.placeholder_token,
placeholder_token_id=placeholder_token_id,
templates=templates,
output_dir=basepath,
sample_batch_size=args.sample_batch_size,
random_sample_batches=args.random_sample_batches,
stable_sample_batches=args.stable_sample_batches,
seed=args.seed
)
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
@ -452,6 +661,7 @@ def main():
learnable_property=args.learnable_property,
center_crop=args.center_crop,
set="train",
templates=base_templates
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
@ -508,87 +718,116 @@ def main():
progress_bar.set_description("Steps")
global_step = 0
for epoch in range(args.num_train_epochs):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach().half()
latents = latents * 0.18215
try:
for epoch in range(args.num_train_epochs):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach().half()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn(latents.shape).to(latents.device)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()
# Sample noise that we'll add to the latents
noise = torch.randn(latents.shape).to(latents.device)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
# Zero out the gradients for all token embeddings except the newly added
# embeddings for the concept, as we only want to optimize the concept embeddings
if accelerator.num_processes > 1:
grads = text_encoder.module.get_input_embeddings().weight.grad
else:
grads = text_encoder.get_input_embeddings().weight.grad
# Get the index for tokens that we want to zero the grads for
index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
# Zero out the gradients for all token embeddings except the newly added
# embeddings for the concept, as we only want to optimize the concept embeddings
if accelerator.num_processes > 1:
grads = text_encoder.module.get_input_embeddings().weight.grad
else:
grads = text_encoder.get_input_embeddings().weight.grad
# Get the index for tokens that we want to zero the grads for
index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
#accelerator.log(logs, step=global_step)
if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
save_resume_file(basepath, args, {
"global_step": global_step + global_step_offset,
"resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin")
})
checkpointer.save_samples(global_step + global_step_offset, text_encoder,
args.resolution, args.resolution, 7.5, 0.0, 25)
if global_step >= args.max_train_steps:
break
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
#accelerator.log(logs, step=global_step)
accelerator.wait_for_everyone()
if global_step >= args.max_train_steps:
break
# Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline(
text_encoder=accelerator.unwrap_model(text_encoder),
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
#pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, os.path.join(args.train_data_dir, f"learned_embeds.bin"))
accelerator.wait_for_everyone()
if args.push_to_hub:
repo.push_to_hub(
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
# Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline(
text_encoder=accelerator.unwrap_model(text_encoder),
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
#pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, basepath / f"learned_embeds.bin")
if global_step % args.checkpoint_frequency != 0:
checkpointer.save_samples(global_step + global_step_offset, text_encoder,
args.resolution, args.resolution, 7.5, 0.0, 25)
accelerator.end_training()
print("Saving resume state")
save_resume_file(basepath, args, {
"global_step": global_step + global_step_offset,
"resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin")
})
if args.push_to_hub:
repo.push_to_hub(
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()
except KeyboardInterrupt:
if accelerator.is_main_process:
print("Interrupted, saving checkpoint and resume state...")
checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
save_resume_file(basepath, args, {
"global_step": global_step + global_step_offset,
"resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin")
})
quit()
def layout():
st.write("Textual Inversion")
st.write("Textual Inversion")
st.info("Under Construction. :construction_worker:")