From e9592b656854a4cffe671db05dc12d813c8f8fa5 Mon Sep 17 00:00:00 2001
From: ZeroCool940711
Date: Sat, 26 Nov 2022 18:07:59 -0700
Subject: [PATCH] Added some missing files from the ldm folder.

---
 ldm/{modules/midas => }/__init__.py  |   0
 ldm/data/personalized.py             | 202 ++++++++++++++++++++
 ldm/data/personalized_file.py        | 169 +++++++++++++++++
 ldm/devices/__init__.py              |   1 +
 ldm/devices/devices.py               |  24 +++
 ldm/modules/__init__.py              |   0
 ldm/modules/embedding_manager.py     | 273 ++++++++++++++++++++++++++++
 7 files changed, 669 insertions(+)
 rename ldm/{modules/midas => }/__init__.py (100%)
 create mode 100644 ldm/data/personalized.py
 create mode 100644 ldm/data/personalized_file.py
 create mode 100644 ldm/devices/__init__.py
 create mode 100644 ldm/devices/devices.py
 create mode 100644 ldm/modules/__init__.py
 create mode 100644 ldm/modules/embedding_manager.py

diff --git a/ldm/modules/midas/__init__.py b/ldm/__init__.py
similarity index 100%
rename from ldm/modules/midas/__init__.py
rename to ldm/__init__.py
diff --git a/ldm/data/personalized.py b/ldm/data/personalized.py
new file mode 100644
index 0000000..15fc8a8
--- /dev/null
+++ b/ldm/data/personalized.py
@@ -0,0 +1,202 @@
+import os
+import numpy as np
+import PIL
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+import random
+
+imagenet_templates_smallest = [
+    'a photo of a {}',
+]
+
+imagenet_templates_small = [
+    'a photo of a {}',
+    'a rendering of a {}',
+    'a cropped photo of the {}',
+    'the photo of a {}',
+    'a photo of a clean {}',
+    'a photo of a dirty {}',
+    'a dark photo of the {}',
+    'a photo of my {}',
+    'a photo of the cool {}',
+    'a close-up photo of a {}',
+    'a bright photo of the {}',
+    'a cropped photo of a {}',
+    'a photo of the {}',
+    'a good photo of the {}',
+    'a photo of one {}',
+    'a close-up photo of the {}',
+    'a rendition of the {}',
+    'a photo of the clean {}',
+    'a rendition of a {}',
+    'a photo of a nice {}',
+    'a good photo of a {}',
+    'a photo of the nice {}',
+    'a photo of the small {}',
+    'a photo of the weird {}',
+    'a photo of the large {}',
+    'a photo of a cool {}',
+    'a photo of a small {}',
+]
+
+imagenet_dual_templates_small = [
+    'a photo of a {} with {}',
+    'a rendering of a {} with {}',
+    'a cropped photo of the {} with {}',
+    'the photo of a {} with {}',
+    'a photo of a clean {} with {}',
+    'a photo of a dirty {} with {}',
+    'a dark photo of the {} with {}',
+    'a photo of my {} with {}',
+    'a photo of the cool {} with {}',
+    'a close-up photo of a {} with {}',
+    'a bright photo of the {} with {}',
+    'a cropped photo of a {} with {}',
+    'a photo of the {} with {}',
+    'a good photo of the {} with {}',
+    'a photo of one {} with {}',
+    'a close-up photo of the {} with {}',
+    'a rendition of the {} with {}',
+    'a photo of the clean {} with {}',
+    'a rendition of a {} with {}',
+    'a photo of a nice {} with {}',
+    'a good photo of a {} with {}',
+    'a photo of the nice {} with {}',
+    'a photo of the small {} with {}',
+    'a photo of the weird {} with {}',
+    'a photo of the large {} with {}',
+    'a photo of a cool {} with {}',
+    'a photo of a small {} with {}',
+]
+
+per_img_token_list = [
+    'א',
+    'ב',
+    'ג',
+    'ד',
+    'ה',
+    'ו',
+    'ז',
+    'ח',
+    'ט',
+    'י',
+    'כ',
+    'ל',
+    'מ',
+    'נ',
+    'ס',
+    'ע',
+    'פ',
+    'צ',
+    'ק',
+    'ר',
+    'ש',
+    'ת',
+]
+
+
+class PersonalizedBase(Dataset):
+    def __init__(
+        self,
+        data_root,
+        size=None,
+        repeats=100,
+        interpolation='bicubic',
+        flip_p=0.5,
+        set='train',
+        placeholder_token='*',
+        per_image_tokens=False,
+        center_crop=False,
+        mixing_prob=0.25,
+        coarse_class_text=None,
+    ):
+
+        self.data_root = data_root
+
+        self.image_paths = [
+            os.path.join(self.data_root, file_path)
+            for file_path in os.listdir(self.data_root)
+        ]
+
+        # self._length = len(self.image_paths)
+        self.num_images = len(self.image_paths)
+        self._length = self.num_images
+
+        self.placeholder_token = placeholder_token
+
+        self.per_image_tokens = per_image_tokens
+        self.center_crop = center_crop
+        self.mixing_prob = mixing_prob
+
+        self.coarse_class_text = coarse_class_text
+
+        if per_image_tokens:
+            assert self.num_images < len(
+                per_img_token_list
+            ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'."
+
+        if set == 'train':
+            self._length = self.num_images * repeats
+
+        self.size = size
+        self.interpolation = {
+            'linear': PIL.Image.LINEAR,
+            'bilinear': PIL.Image.BILINEAR,
+            'bicubic': PIL.Image.BICUBIC,
+            'lanczos': PIL.Image.LANCZOS,
+        }[interpolation]
+        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+    def __len__(self):
+        return self._length
+
+    def __getitem__(self, i):
+        example = {}
+        image = Image.open(self.image_paths[i % self.num_images])
+
+        if not image.mode == 'RGB':
+            image = image.convert('RGB')
+
+        placeholder_string = self.placeholder_token
+        if self.coarse_class_text:
+            placeholder_string = (
+                f'{self.coarse_class_text} {placeholder_string}'
+            )
+
+        if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
+            text = random.choice(imagenet_dual_templates_small).format(
+                placeholder_string, per_img_token_list[i % self.num_images]
+            )
+        else:
+            text = random.choice(imagenet_templates_small).format(
+                placeholder_string
+            )
+
+        example['caption'] = text
+
+        # default to score-sde preprocessing
+        img = np.array(image).astype(np.uint8)
+
+        if self.center_crop:
+            crop = min(img.shape[0], img.shape[1])
+            h, w, = (
+                img.shape[0],
+                img.shape[1],
+            )
+            img = img[
+                (h - crop) // 2 : (h + crop) // 2,
+                (w - crop) // 2 : (w + crop) // 2,
+            ]
+
+        image = Image.fromarray(img)
+        if self.size is not None:
+            image = image.resize(
+                (self.size, self.size), resample=self.interpolation
+            )
+
+        image = self.flip(image)
+        image = np.array(image).astype(np.uint8)
+        example['image'] = (image / 127.5 - 1.0).astype(np.float32)
+        return example
diff --git a/ldm/data/personalized_file.py b/ldm/data/personalized_file.py
new file mode 100644
index 0000000..56d77d7
--- /dev/null
+++ b/ldm/data/personalized_file.py
@@ -0,0 +1,169 @@
+import os
+import numpy as np
+import PIL
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+import random
+
+imagenet_templates_small = [
+    'a painting in the style of {}',
+    'a rendering in the style of {}',
+    'a cropped painting in the style of {}',
+    'the painting in the style of {}',
+    'a clean painting in the style of {}',
+    'a dirty painting in the style of {}',
+    'a dark painting in the style of {}',
+    'a picture in the style of {}',
+    'a cool painting in the style of {}',
+    'a close-up painting in the style of {}',
+    'a bright painting in the style of {}',
+    'a cropped painting in the style of {}',
+    'a good painting in the style of {}',
+    'a close-up painting in the style of {}',
+    'a rendition in the style of {}',
+    'a nice painting in the style of {}',
+    'a small painting in the style of {}',
+    'a weird painting in the style of {}',
+    'a large painting in the style of {}',
+]
+
+imagenet_dual_templates_small = [
+    'a painting in the style of {} with {}',
+    'a rendering in the style of {} with {}',
+    'a cropped painting in the style of {} with {}',
+    'the painting in the style of {} with {}',
+    'a clean painting in the style of {} with {}',
+    'a dirty painting in the style of {} with {}',
+    'a dark painting in the style of {} with {}',
+    'a cool painting in the style of {} with {}',
+    'a close-up painting in the style of {} with {}',
+    'a bright painting in the style of {} with {}',
+    'a cropped painting in the style of {} with {}',
+    'a good painting in the style of {} with {}',
+    'a painting of one {} in the style of {}',
+    'a nice painting in the style of {} with {}',
+    'a small painting in the style of {} with {}',
+    'a weird painting in the style of {} with {}',
+    'a large painting in the style of {} with {}',
+]
+
+per_img_token_list = [
+    'א',
+    'ב',
+    'ג',
+    'ד',
+    'ה',
+    'ו',
+    'ז',
+    'ח',
+    'ט',
+    'י',
+    'כ',
+    'ל',
+    'מ',
+    'נ',
+    'ס',
+    'ע',
+    'פ',
+    'צ',
+    'ק',
+    'ר',
+    'ש',
+    'ת',
+]
+
+
+class PersonalizedBase(Dataset):
+    def __init__(
+        self,
+        data_root,
+        size=None,
+        repeats=100,
+        interpolation='bicubic',
+        flip_p=0.5,
+        set='train',
+        placeholder_token='*',
+        per_image_tokens=False,
+        center_crop=False,
+    ):
+
+        self.data_root = data_root
+
+        self.image_paths = [
+            os.path.join(self.data_root, file_path)
+            for file_path in os.listdir(self.data_root)
+        ]
+
+        # self._length = len(self.image_paths)
+        self.num_images = len(self.image_paths)
+        self._length = self.num_images
+
+        self.placeholder_token = placeholder_token
+
+        self.per_image_tokens = per_image_tokens
+        self.center_crop = center_crop
+
+        if per_image_tokens:
+            assert self.num_images < len(
+                per_img_token_list
+            ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'."
+
+        if set == 'train':
+            self._length = self.num_images * repeats
+
+        self.size = size
+        self.interpolation = {
+            'linear': PIL.Image.LINEAR,
+            'bilinear': PIL.Image.BILINEAR,
+            'bicubic': PIL.Image.BICUBIC,
+            'lanczos': PIL.Image.LANCZOS,
+        }[interpolation]
+        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+    def __len__(self):
+        return self._length
+
+    def __getitem__(self, i):
+        example = {}
+        image = Image.open(self.image_paths[i % self.num_images])
+
+        if not image.mode == 'RGB':
+            image = image.convert('RGB')
+
+        if self.per_image_tokens and np.random.uniform() < 0.25:
+            text = random.choice(imagenet_dual_templates_small).format(
+                self.placeholder_token, per_img_token_list[i % self.num_images]
+            )
+        else:
+            text = random.choice(imagenet_templates_small).format(
+                self.placeholder_token
+            )
+
+        example['caption'] = text
+
+        # default to score-sde preprocessing
+        img = np.array(image).astype(np.uint8)
+
+        if self.center_crop:
+            crop = min(img.shape[0], img.shape[1])
+            h, w, = (
+                img.shape[0],
+                img.shape[1],
+            )
+            img = img[
+                (h - crop) // 2 : (h + crop) // 2,
+                (w - crop) // 2 : (w + crop) // 2,
+            ]
+
+        image = Image.fromarray(img)
+        if self.size is not None:
+            image = image.resize(
+                (self.size, self.size), resample=self.interpolation
+            )
+
+        image = self.flip(image)
+        image = np.array(image).astype(np.uint8)
+        example['image'] = (image / 127.5 - 1.0).astype(np.float32)
+        return example
diff --git a/ldm/devices/__init__.py b/ldm/devices/__init__.py
new file mode 100644
index 0000000..ad85ea6
--- /dev/null
+++ b/ldm/devices/__init__.py
@@ -0,0 +1 @@
+from ldm.devices.devices import choose_autocast_device, choose_torch_device
diff --git a/ldm/devices/devices.py b/ldm/devices/devices.py
new file mode 100644
index 0000000..a92cfcb
--- /dev/null
+++ b/ldm/devices/devices.py
@@ -0,0 +1,24 @@
+import torch
+from torch import autocast
+from contextlib import contextmanager, nullcontext
+
+def choose_torch_device() -> str:
+    '''Convenience routine for guessing which GPU device to run model on'''
+    if torch.cuda.is_available():
+        return 'cuda'
+    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+        return 'mps'
+    return 'cpu'
+
+def choose_autocast_device(device):
+    '''Returns an autocast compatible device from a torch device'''
+    device_type = device.type  # this returns 'mps' on M1
+    # autocast only for cuda, but GTX 16xx have issues with it
+    if device_type == 'cuda':
+        device_name = torch.cuda.get_device_name()
+        if 'GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name:
+            return device_type,nullcontext
+        else:
+            return device_type,autocast
+    else:
+        return 'cpu',nullcontext
diff --git a/ldm/modules/__init__.py b/ldm/modules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py
new file mode 100644
index 0000000..c86fa6b
--- /dev/null
+++ b/ldm/modules/embedding_manager.py
@@ -0,0 +1,273 @@
+from cmath import log
+import torch
+from torch import nn
+
+import sys
+
+from ldm.data.personalized import per_img_token_list
+from transformers import CLIPTokenizer
+from functools import partial
+
+DEFAULT_PLACEHOLDER_TOKEN = ['*']
+
+PROGRESSIVE_SCALE = 2000
+
+
+def get_clip_token_for_string(tokenizer, string):
+    batch_encoding = tokenizer(
+        string,
+        truncation=True,
+        max_length=77,
+        return_length=True,
+        return_overflowing_tokens=False,
+        padding='max_length',
+        return_tensors='pt',
+    )
+    tokens = batch_encoding['input_ids']
+    """ assert (
+        torch.count_nonzero(tokens - 49407) == 2
+    ), f"String '{string}' maps to more than a single token. Please use another string" """
+
+    return tokens[0, 1]
+
+
+def get_bert_token_for_string(tokenizer, string):
+    token = tokenizer(string)
+    # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
+
+    token = token[0, 1]
+
+    return token
+
+
+def get_embedding_for_clip_token(embedder, token):
+    return embedder(token.unsqueeze(0))[0, 0]
+
+
+class EmbeddingManager(nn.Module):
+    def __init__(
+        self,
+        embedder,
+        placeholder_strings=None,
+        initializer_words=None,
+        per_image_tokens=False,
+        num_vectors_per_token=1,
+        progressive_words=False,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.embedder = embedder
+        device = embedder.device
+
+        self.string_to_token_dict = {}
+        self.string_to_param_dict = nn.ParameterDict()
+
+        self.initial_embeddings = (
+            nn.ParameterDict()
+        )  # These should not be optimized
+
+        self.progressive_words = progressive_words
+        self.progressive_counter = 0
+
+        self.max_vectors_per_token = num_vectors_per_token
+
+        if hasattr(
+            embedder, 'tokenizer'
+        ):  # using Stable Diffusion's CLIP encoder
+            self.is_clip = True
+            get_token_for_string = partial(
+                get_clip_token_for_string, embedder.tokenizer
+            )
+            get_embedding_for_tkn = partial(
+                get_embedding_for_clip_token,
+                embedder.transformer.text_model.embeddings,
+            )
+            # per bug report #572
+            #token_dim = 1280
+            token_dim = 768
+        else:  # using LDM's BERT encoder
+            self.is_clip = False
+            get_token_for_string = partial(
+                get_bert_token_for_string, embedder.tknz_fn
+            )
+            get_embedding_for_tkn = embedder.transformer.token_emb
+            token_dim = 1280
+
+        if per_image_tokens:
+            placeholder_strings.extend(per_img_token_list)
+
+        for idx, placeholder_string in enumerate(placeholder_strings):
+
+            token = get_token_for_string(placeholder_string)
+
+            if initializer_words and idx < len(initializer_words):
+                init_word_token = get_token_for_string(initializer_words[idx])
+
+                with torch.no_grad():
+                    init_word_embedding = get_embedding_for_tkn(
+                        init_word_token.to(device)
+                    )
+
+                token_params = torch.nn.Parameter(
+                    init_word_embedding.unsqueeze(0).repeat(
+                        num_vectors_per_token, 1
+                    ),
+                    requires_grad=True,
+                )
+                self.initial_embeddings[
+                    placeholder_string
+                ] = torch.nn.Parameter(
+                    init_word_embedding.unsqueeze(0).repeat(
+                        num_vectors_per_token, 1
+                    ),
+                    requires_grad=False,
+                )
+            else:
+                token_params = torch.nn.Parameter(
+                    torch.rand(
+                        size=(num_vectors_per_token, token_dim),
+                        requires_grad=True,
+                    )
+                )
+
+            self.string_to_token_dict[placeholder_string] = token
+            self.string_to_param_dict[placeholder_string] = token_params
+
+    def forward(
+        self,
+        tokenized_text,
+        embedded_text,
+    ):
+        b, n, device = *tokenized_text.shape, tokenized_text.device
+
+        for (
+            placeholder_string,
+            placeholder_token,
+        ) in self.string_to_token_dict.items():
+
+            placeholder_embedding = self.string_to_param_dict[
+                placeholder_string
+            ].to(device)
+
+            if (
+                self.max_vectors_per_token == 1
+            ):  # If there's only one vector per token, we can do a simple replacement
+                placeholder_idx = torch.where(
+                    tokenized_text == placeholder_token.to(device)
+                )
+                embedded_text[placeholder_idx] = placeholder_embedding
+            else:  # otherwise, need to insert and keep track of changing indices
+                if self.progressive_words:
+                    self.progressive_counter += 1
+                    max_step_tokens = (
+                        1 + self.progressive_counter // PROGRESSIVE_SCALE
+                    )
+                else:
+                    max_step_tokens = self.max_vectors_per_token
+
+                num_vectors_for_token = min(
+                    placeholder_embedding.shape[0], max_step_tokens
+                )
+
+                placeholder_rows, placeholder_cols = torch.where(
+                    tokenized_text == placeholder_token.to(device)
+                )
+
+                if placeholder_rows.nelement() == 0:
+                    continue
+
+                sorted_cols, sort_idx = torch.sort(
+                    placeholder_cols, descending=True
+                )
+                sorted_rows = placeholder_rows[sort_idx]
+
+                for idx in range(len(sorted_rows)):
+                    row = sorted_rows[idx]
+                    col = sorted_cols[idx]
+
+                    new_token_row = torch.cat(
+                        [
+                            tokenized_text[row][:col],
+                            placeholder_token.repeat(num_vectors_for_token).to(
+                                device
+                            ),
+                            tokenized_text[row][col + 1 :],
+                        ],
+                        axis=0,
+                    )[:n]
+                    new_embed_row = torch.cat(
+                        [
+                            embedded_text[row][:col],
+                            placeholder_embedding[:num_vectors_for_token],
+                            embedded_text[row][col + 1 :],
+                        ],
+                        axis=0,
+                    )[:n]
+
+                    embedded_text[row] = new_embed_row
+                    tokenized_text[row] = new_token_row
+
+        return embedded_text
+
+    def save(self, ckpt_path):
+        torch.save(
+            {
+                'string_to_token': self.string_to_token_dict,
+                'string_to_param': self.string_to_param_dict,
+            },
+            ckpt_path,
+        )
+
+    def load(self, ckpt_path, full=True):
+        ckpt = torch.load(ckpt_path, map_location='cpu')
+
+        # Handle .pt textual inversion files
+        if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
+            self.string_to_token_dict = ckpt["string_to_token"]
+            self.string_to_param_dict = ckpt["string_to_param"]
+
+        # Handle .bin textual inversion files from Huggingface Concepts
+        # https://huggingface.co/sd-concepts-library
+        else:
+            for token_str in list(ckpt.keys()):
+                token = get_clip_token_for_string(self.embedder.tokenizer, token_str)
+                self.string_to_token_dict[token_str] = token
+                ckpt[token_str] = torch.nn.Parameter(ckpt[token_str])
+
+            self.string_to_param_dict.update(ckpt)
+
+        if not full:
+            for key, value in self.string_to_param_dict.items():
+                self.string_to_param_dict[key] = torch.nn.Parameter(value.half())
+
+    def get_embedding_norms_squared(self):
+        all_params = torch.cat(
+            list(self.string_to_param_dict.values()), axis=0
+        )  # num_placeholders x embedding_dim
+        param_norm_squared = (all_params * all_params).sum(
+            axis=-1
+        )  # num_placeholders
+
+        return param_norm_squared
+
+    def embedding_parameters(self):
+        return self.string_to_param_dict.parameters()
+
+    def embedding_to_coarse_loss(self):
+
+        loss = 0.0
+        num_embeddings = len(self.initial_embeddings)
+
+        for key in self.initial_embeddings:
+            optimized = self.string_to_param_dict[key]
+            coarse = self.initial_embeddings[key].clone().to(optimized.device)
+
+            loss = (
+                loss
+                + (optimized - coarse)
+                @ (optimized - coarse).T
+                / num_embeddings
+            )
+
+        return loss
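
Note (not part of the patch, and ignored by `git am`): a minimal usage sketch of the modules added above, assuming the `ldm` package from this patch is importable, Pillow still exposes `PIL.Image.BICUBIC` (pre-10.0 releases), and a hypothetical `./training_images` folder contains at least one RGB image.

# Usage sketch only -- not part of the diff above.
import torch

from ldm.devices import choose_autocast_device, choose_torch_device
from ldm.data.personalized import PersonalizedBase

device = torch.device(choose_torch_device())   # 'cuda', 'mps', or 'cpu'
device_type, autocast_ctx = choose_autocast_device(device)

dataset = PersonalizedBase(
    data_root='./training_images',             # hypothetical folder of training images
    size=512,
    set='val',                                 # 'val' skips the repeats-based length inflation
    placeholder_token='*',
)

example = dataset[0]
print(example['caption'])                      # e.g. "a photo of a *"
print(example['image'].shape)                  # (512, 512, 3), float32 values in [-1.0, 1.0]

with autocast_ctx(device_type):                # nullcontext on CPU and on GTX 16xx cards
    pass                                       # a model forward pass would go here

For GeForce GTX 1660/1650 cards, choose_autocast_device deliberately hands back nullcontext instead of autocast, so the with-block above simply becomes a no-op there.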