# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).

# Copyright 2022 Sygil-Dev team.

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
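
"""Textual inversion training script for Stable Diffusion.

It learns an embedding for a new placeholder token from a folder of example
images while the VAE, the UNet, and the rest of the text encoder stay frozen;
only the token embeddings of the text encoder are optimized.

Example invocation (a minimal sketch; the script filename and all paths below
are illustrative placeholders, not values shipped with this repository):

    python textual_inversion.py \
        --pretrained_model_name_or_path /path/to/stable-diffusion \
        --train_data_dir /path/to/concept/images \
        --placeholder_token "<my-concept>" \
        --initializer_token toy \
        --output_dir logs
"""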
import argparse
import itertools
import math
import os
import random
import datetime
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    LMSDiscreteScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from pipelines.stable_diffusion.no_check import NoCheck
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from slugify import slugify
import json

logger = get_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help="A folder containing the training data.",
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        default=None,
        help="A token to use as initializer word.",
    )
    parser.add_argument(
        "--learnable_property",
        type=str,
        default="object",
        help="Choose between 'object' and 'style'",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=100,
        help="How many times to repeat the training data.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=500,
        help="How often to save a checkpoint and sample image",
    )
    parser.add_argument(
        "--stable_sample_batches",
        type=int,
        default=0,
        help="Number of fixed-seed sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--random_sample_batches",
        type=int,
        default=1,
        help="Number of random-seed sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=50,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--custom_templates",
        type=str,
        default=None,
        help=(
            "A semicolon-delimited list of custom templates to use for samples, using {} as a placeholder for the concept."
        ),
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to a directory to resume training from (e.g., logs/token_name/2022-09-22T23-36-27)",
    )
    parser.add_argument(
        "--resume_checkpoint",
        type=str,
        default=None,
        help="Path to a specific checkpoint to resume training from (e.g., logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin).",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this.",
    )

    args = parser.parse_args()

    if args.resume_from is not None:
        with open(f"{args.resume_from}/resume.json", "rt") as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"])
            )
    elif args.config is not None:
        with open(args.config, "rt") as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"])
            )

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify --train_data_dir")
    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")
    if args.placeholder_token is None:
        raise ValueError("You must specify --placeholder_token")
    if args.initializer_token is None:
        raise ValueError("You must specify --initializer_token")
    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
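
# These built-in templates are used unless --custom_templates is given. A custom
# list is a single semicolon-delimited string; an illustrative (hypothetical) value:
#   --custom_templates "a photo of a {};a cropped photo of a {};a rendition of a {}"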


class TextualInversionDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        set="train",
        placeholder_token="*",
        center_crop=False,
        templates=None,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop

        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
            if file_path.lower().endswith((".png", ".jpg", ".jpeg"))
        ]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]

        self.templates = templates
        self.cache = {}
        self.tokenized_templates = [
            self.tokenizer(
                text.format(self.placeholder_token),
                padding="max_length",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids[0]
            for text in self.templates
        ]

    def __len__(self):
        return self._length

    def get_example(self, image_path, flipped):
        # Note: the cache is keyed by image path only, so the flip state chosen
        # on the first access of an image is reused on later accesses.
        if image_path in self.cache:
            return self.cache[image_path]

        example = {}
        image = Image.open(image_path)

        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            (
                h,
                w,
            ) = (
                img.shape[0],
                img.shape[1],
            )
            img = img[
                (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
            ]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)
        image = transforms.RandomHorizontalFlip(p=1 if flipped else 0)(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["key"] = "-".join([image_path, "-", str(flipped)])
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)

        self.cache[image_path] = example
        return example

    def __getitem__(self, i):
        flipped = random.choice([False, True])
        example = self.get_example(self.image_paths[i % self.num_images], flipped)
        example["input_ids"] = random.choice(self.tokenized_templates)
        return example
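
# Each dataset example is a dict with:
#   "key"          -- a string identifying the source image (and its flip state),
#   "pixel_values" -- a float32 tensor of shape (3, size, size) scaled to [-1, 1],
#   "input_ids"    -- token ids for a randomly chosen template with the placeholder filled in.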


def freeze_params(params):
    for param in params:
        param.requires_grad = False


def save_resume_file(basepath, args, extra={}):
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(f"{basepath}/resume.json", "w") as f:
        json.dump(info, f, indent=4)
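
# resume.json layout written above (values shown are illustrative): the training
# loop later merges "global_step" and "resume_checkpoint" into the stored args, e.g.
#   {"args": {"placeholder_token": "<my-concept>", ..., "global_step": 1500,
#             "resume_checkpoint": "<output>/checkpoints/last.bin"}}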


class Checkpointer:
    def __init__(
        self,
        accelerator,
        vae,
        unet,
        tokenizer,
        placeholder_token,
        placeholder_token_id,
        templates,
        output_dir,
        random_sample_batches,
        sample_batch_size,
        stable_sample_batches,
        seed,
    ):
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.templates = templates
        self.output_dir = output_dir
        self.seed = seed
        self.random_sample_batches = random_sample_batches
        self.sample_batch_size = sample_batch_size
        self.stable_sample_batches = stable_sample_batches

    @torch.no_grad()
    def checkpoint(self, step, text_encoder, save_samples=True, path=None):
        print("Saving checkpoint for step %d..." % step)
        with torch.autocast("cuda"):
            if path is None:
                checkpoints_path = f"{self.output_dir}/checkpoints"
                os.makedirs(checkpoints_path, exist_ok=True)

            unwrapped = self.accelerator.unwrap_model(text_encoder)

            # Save a checkpoint
            learned_embeds = unwrapped.get_input_embeddings().weight[
                self.placeholder_token_id
            ]
            learned_embeds_dict = {
                self.placeholder_token: learned_embeds.detach().cpu()
            }

            filename = "%s_%d.bin" % (slugify(self.placeholder_token), step)
            if path is not None:
                torch.save(learned_embeds_dict, path)
            else:
                torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
                torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
            del unwrapped
            del learned_embeds

    @torch.no_grad()
    def save_samples(
        self,
        step,
        text_encoder,
        height,
        width,
        guidance_scale,
        eta,
        num_inference_steps,
    ):
        samples_path = f"{self.output_dir}/samples"
        os.makedirs(samples_path, exist_ok=True)
        checker = NoCheck()

        unwrapped = self.accelerator.unwrap_model(text_encoder)

        # Save a sample image
        pipeline = StableDiffusionPipeline(
            text_encoder=unwrapped,
            vae=self.vae,
            unet=self.unet,
            tokenizer=self.tokenizer,
            scheduler=LMSDiscreteScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            ),
            safety_checker=NoCheck(),
            feature_extractor=CLIPFeatureExtractor.from_pretrained(
                "openai/clip-vit-base-patch32"
            ),
        ).to("cuda")
        pipeline.enable_attention_slicing()

        if self.stable_sample_batches > 0:
            stable_latents = torch.randn(
                (
                    self.sample_batch_size,
                    pipeline.unet.in_channels,
                    height // 8,
                    width // 8,
                ),
                device=pipeline.device,
                generator=torch.Generator(device=pipeline.device).manual_seed(
                    self.seed
                ),
            )

            stable_prompts = [
                choice.format(self.placeholder_token)
                for choice in (self.templates * self.sample_batch_size)[
                    : self.sample_batch_size
                ]
            ]

            # Generate and save stable samples
            for i in range(0, self.stable_sample_batches):
                samples = pipeline(
                    prompt=stable_prompts,
                    height=384,
                    latents=stable_latents,
                    width=384,
                    guidance_scale=guidance_scale,
                    eta=eta,
                    num_inference_steps=num_inference_steps,
                    output_type="pil",
                )["sample"]

                for idx, im in enumerate(samples):
                    filename = "stable_sample_%d_%d_step_%d.png" % (
                        i + 1,
                        idx + 1,
                        step,
                    )
                    im.save(f"{samples_path}/{filename}")
                del samples
            del stable_latents

        prompts = [
            choice.format(self.placeholder_token)
            for choice in random.choices(self.templates, k=self.sample_batch_size)
        ]

        # Generate and save random samples
        for i in range(0, self.random_sample_batches):
            samples = pipeline(
                prompt=prompts,
                height=384,
                width=384,
                guidance_scale=guidance_scale,
                eta=eta,
                num_inference_steps=num_inference_steps,
                output_type="pil",
            )["sample"]

            for idx, im in enumerate(samples):
                filename = "step_%d_sample_%d_%d.png" % (step, i + 1, idx + 1)
                im.save(f"{samples_path}/{filename}")
            del samples

        del checker
        del unwrapped
        del pipeline
        torch.cuda.empty_cache()


def main():
    args = parse_args()

    global_step_offset = 0
    if args.resume_from is not None:
        basepath = f"{args.resume_from}"
        print("Resuming state from %s" % args.resume_from)
        with open(f"{basepath}/resume.json", "r") as f:
            state = json.load(f)
        global_step_offset = state["args"].get("global_step", 0)
        print("We've trained %d steps so far" % global_step_offset)
    else:
        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}"
        os.makedirs(basepath, exist_ok=True)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Load the tokenizer and add the placeholder token as an additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path + "/tokenizer"
        )

    # Add the placeholder token in tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token and placeholder_token to ids
    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if len(token_ids) > 1:
        raise ValueError("The initializer token must be a single token.")

    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path + "/text_encoder",
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path + "/vae",
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path + "/unet",
    )

    base_templates = (
        imagenet_style_templates_small
        if args.learnable_property == "style"
        else imagenet_templates_small
    )
    if args.custom_templates:
        templates = args.custom_templates.split(";")
    else:
        templates = base_templates

    slice_size = unet.config.attention_head_dim // 2
    unet.set_attention_slice(slice_size)

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data

    if args.resume_checkpoint is not None:
        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
            args.placeholder_token
        ]
    else:
        token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    # Freeze vae and unet
    freeze_params(vae.parameters())
    freeze_params(unet.parameters())

    # Freeze all parameters except for the token embeddings in text encoder
    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    freeze_params(params_to_freeze)

    checkpointer = Checkpointer(
        accelerator=accelerator,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        placeholder_token=args.placeholder_token,
        placeholder_token_id=placeholder_token_id,
        templates=templates,
        output_dir=basepath,
        sample_batch_size=args.sample_batch_size,
        random_sample_batches=args.random_sample_batches,
        stable_sample_batches=args.stable_sample_batches,
        seed=args.seed,
    )

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # TODO (patil-suraj): load scheduler using args
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        tensor_format="pt",
    )

    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
        templates=templates,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # Move vae and unet to device
    vae.to(accelerator.device)
    unet.to(accelerator.device)

    # Keep vae and unet in eval mode as we don't train these
    vae.eval()
    unet.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = (
        args.train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"Num examples = {len(train_dataset)}")
    logger.info(f"Num Epochs = {args.num_train_epochs}")
    logger.info(f"Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(
        f"Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )
    progress_bar.set_description("Steps")
    global_step = 0
    encoded_pixel_values_cache = {}

    try:
        for epoch in range(args.num_train_epochs):
            text_encoder.train()
            for step, batch in enumerate(train_dataloader):
                with accelerator.accumulate(text_encoder):
                    # Convert images to latent space
                    key = "|".join(batch["key"])
                    if encoded_pixel_values_cache.get(key, None) is None:
                        encoded_pixel_values_cache[key] = vae.encode(
                            batch["pixel_values"]
                        ).latent_dist
                    latents = (
                        encoded_pixel_values_cache[key].sample().detach().half()
                        * 0.18215
                    )

                    # Sample noise that we'll add to the latents
                    noise = torch.randn(latents.shape).to(latents.device)
                    bsz = latents.shape[0]
                    # Sample a random timestep for each image
                    timesteps = torch.randint(
                        0,
                        noise_scheduler.num_train_timesteps,
                        (bsz,),
                        device=latents.device,
                    ).long()

                    # Add noise to the latents according to the noise magnitude at each timestep
                    # (this is the forward diffusion process)
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Get the text embedding for conditioning
                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                    # Predict the noise residual
                    noise_pred = unet(
                        noisy_latents, timesteps, encoder_hidden_states
                    ).sample
                    loss = (
                        F.mse_loss(noise_pred, noise, reduction="none")
                        .mean([1, 2, 3])
                        .mean()
                    )

                    accelerator.backward(loss)

                    # Zero out the gradients for all token embeddings except the newly added
                    # embeddings for the concept, as we only want to optimize the concept embeddings
                    if accelerator.num_processes > 1:
                        grads = text_encoder.module.get_input_embeddings().weight.grad
                    else:
                        grads = text_encoder.get_input_embeddings().weight.grad
                    # Get the index for tokens that we want to zero the grads for
                    index_grads_to_zero = (
                        torch.arange(len(tokenizer)) != placeholder_token_id
                    )
                    grads.data[index_grads_to_zero, :] = grads.data[
                        index_grads_to_zero, :
                    ].fill_(0)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                    if (
                        global_step % args.checkpoint_frequency == 0
                        and global_step > 0
                        and accelerator.is_main_process
                    ):
                        checkpointer.checkpoint(
                            global_step + global_step_offset, text_encoder
                        )
                        save_resume_file(
                            basepath,
                            args,
                            {
                                "global_step": global_step + global_step_offset,
                                "resume_checkpoint": f"{basepath}/checkpoints/last.bin",
                            },
                        )
                        checkpointer.save_samples(
                            global_step + global_step_offset,
                            text_encoder,
                            args.resolution,
                            args.resolution,
                            7.5,
                            0.0,
                            args.sample_steps,
                        )

                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                }
                progress_bar.set_postfix(**logs)
                # accelerator.log(logs, step=global_step)

                if global_step >= args.max_train_steps:
                    break

        accelerator.wait_for_everyone()

        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished! Saving final checkpoint and resume state.")
            checkpointer.checkpoint(
                global_step + global_step_offset,
                text_encoder,
                path=f"{basepath}/learned_embeds.bin",
            )
            save_resume_file(
                basepath,
                args,
                {
                    "global_step": global_step + global_step_offset,
                    "resume_checkpoint": f"{basepath}/checkpoints/last.bin",
                },
            )
            accelerator.end_training()

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted, saving checkpoint and resume state...")
            checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
            save_resume_file(
                basepath,
                args,
                {
                    "global_step": global_step + global_step_offset,
                    "resume_checkpoint": f"{basepath}/checkpoints/last.bin",
                },
            )
        quit()


if __name__ == "__main__":
    main()