mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-14 06:35:14 +03:00
Added some missing files from the ldm folder.
This commit is contained in:
parent
2e2b35ff71
commit
e9592b6568
202
ldm/data/personalized.py
Normal file
202
ldm/data/personalized.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
imagenet_templates_smallest = [
|
||||||
|
'a photo of a {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
imagenet_templates_small = [
|
||||||
|
'a photo of a {}',
|
||||||
|
'a rendering of a {}',
|
||||||
|
'a cropped photo of the {}',
|
||||||
|
'the photo of a {}',
|
||||||
|
'a photo of a clean {}',
|
||||||
|
'a photo of a dirty {}',
|
||||||
|
'a dark photo of the {}',
|
||||||
|
'a photo of my {}',
|
||||||
|
'a photo of the cool {}',
|
||||||
|
'a close-up photo of a {}',
|
||||||
|
'a bright photo of the {}',
|
||||||
|
'a cropped photo of a {}',
|
||||||
|
'a photo of the {}',
|
||||||
|
'a good photo of the {}',
|
||||||
|
'a photo of one {}',
|
||||||
|
'a close-up photo of the {}',
|
||||||
|
'a rendition of the {}',
|
||||||
|
'a photo of the clean {}',
|
||||||
|
'a rendition of a {}',
|
||||||
|
'a photo of a nice {}',
|
||||||
|
'a good photo of a {}',
|
||||||
|
'a photo of the nice {}',
|
||||||
|
'a photo of the small {}',
|
||||||
|
'a photo of the weird {}',
|
||||||
|
'a photo of the large {}',
|
||||||
|
'a photo of a cool {}',
|
||||||
|
'a photo of a small {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
imagenet_dual_templates_small = [
|
||||||
|
'a photo of a {} with {}',
|
||||||
|
'a rendering of a {} with {}',
|
||||||
|
'a cropped photo of the {} with {}',
|
||||||
|
'the photo of a {} with {}',
|
||||||
|
'a photo of a clean {} with {}',
|
||||||
|
'a photo of a dirty {} with {}',
|
||||||
|
'a dark photo of the {} with {}',
|
||||||
|
'a photo of my {} with {}',
|
||||||
|
'a photo of the cool {} with {}',
|
||||||
|
'a close-up photo of a {} with {}',
|
||||||
|
'a bright photo of the {} with {}',
|
||||||
|
'a cropped photo of a {} with {}',
|
||||||
|
'a photo of the {} with {}',
|
||||||
|
'a good photo of the {} with {}',
|
||||||
|
'a photo of one {} with {}',
|
||||||
|
'a close-up photo of the {} with {}',
|
||||||
|
'a rendition of the {} with {}',
|
||||||
|
'a photo of the clean {} with {}',
|
||||||
|
'a rendition of a {} with {}',
|
||||||
|
'a photo of a nice {} with {}',
|
||||||
|
'a good photo of a {} with {}',
|
||||||
|
'a photo of the nice {} with {}',
|
||||||
|
'a photo of the small {} with {}',
|
||||||
|
'a photo of the weird {} with {}',
|
||||||
|
'a photo of the large {} with {}',
|
||||||
|
'a photo of a cool {} with {}',
|
||||||
|
'a photo of a small {} with {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
per_img_token_list = [
|
||||||
|
'א',
|
||||||
|
'ב',
|
||||||
|
'ג',
|
||||||
|
'ד',
|
||||||
|
'ה',
|
||||||
|
'ו',
|
||||||
|
'ז',
|
||||||
|
'ח',
|
||||||
|
'ט',
|
||||||
|
'י',
|
||||||
|
'כ',
|
||||||
|
'ל',
|
||||||
|
'מ',
|
||||||
|
'נ',
|
||||||
|
'ס',
|
||||||
|
'ע',
|
||||||
|
'פ',
|
||||||
|
'צ',
|
||||||
|
'ק',
|
||||||
|
'ר',
|
||||||
|
'ש',
|
||||||
|
'ת',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBase(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_root,
|
||||||
|
size=None,
|
||||||
|
repeats=100,
|
||||||
|
interpolation='bicubic',
|
||||||
|
flip_p=0.5,
|
||||||
|
set='train',
|
||||||
|
placeholder_token='*',
|
||||||
|
per_image_tokens=False,
|
||||||
|
center_crop=False,
|
||||||
|
mixing_prob=0.25,
|
||||||
|
coarse_class_text=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
self.data_root = data_root
|
||||||
|
|
||||||
|
self.image_paths = [
|
||||||
|
os.path.join(self.data_root, file_path)
|
||||||
|
for file_path in os.listdir(self.data_root)
|
||||||
|
]
|
||||||
|
|
||||||
|
# self._length = len(self.image_paths)
|
||||||
|
self.num_images = len(self.image_paths)
|
||||||
|
self._length = self.num_images
|
||||||
|
|
||||||
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
|
self.per_image_tokens = per_image_tokens
|
||||||
|
self.center_crop = center_crop
|
||||||
|
self.mixing_prob = mixing_prob
|
||||||
|
|
||||||
|
self.coarse_class_text = coarse_class_text
|
||||||
|
|
||||||
|
if per_image_tokens:
|
||||||
|
assert self.num_images < len(
|
||||||
|
per_img_token_list
|
||||||
|
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
|
||||||
|
|
||||||
|
if set == 'train':
|
||||||
|
self._length = self.num_images * repeats
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self.interpolation = {
|
||||||
|
'linear': PIL.Image.LINEAR,
|
||||||
|
'bilinear': PIL.Image.BILINEAR,
|
||||||
|
'bicubic': PIL.Image.BICUBIC,
|
||||||
|
'lanczos': PIL.Image.LANCZOS,
|
||||||
|
}[interpolation]
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
example = {}
|
||||||
|
image = Image.open(self.image_paths[i % self.num_images])
|
||||||
|
|
||||||
|
if not image.mode == 'RGB':
|
||||||
|
image = image.convert('RGB')
|
||||||
|
|
||||||
|
placeholder_string = self.placeholder_token
|
||||||
|
if self.coarse_class_text:
|
||||||
|
placeholder_string = (
|
||||||
|
f'{self.coarse_class_text} {placeholder_string}'
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
|
||||||
|
text = random.choice(imagenet_dual_templates_small).format(
|
||||||
|
placeholder_string, per_img_token_list[i % self.num_images]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text = random.choice(imagenet_templates_small).format(
|
||||||
|
placeholder_string
|
||||||
|
)
|
||||||
|
|
||||||
|
example['caption'] = text
|
||||||
|
|
||||||
|
# default to score-sde preprocessing
|
||||||
|
img = np.array(image).astype(np.uint8)
|
||||||
|
|
||||||
|
if self.center_crop:
|
||||||
|
crop = min(img.shape[0], img.shape[1])
|
||||||
|
h, w, = (
|
||||||
|
img.shape[0],
|
||||||
|
img.shape[1],
|
||||||
|
)
|
||||||
|
img = img[
|
||||||
|
(h - crop) // 2 : (h + crop) // 2,
|
||||||
|
(w - crop) // 2 : (w + crop) // 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
if self.size is not None:
|
||||||
|
image = image.resize(
|
||||||
|
(self.size, self.size), resample=self.interpolation
|
||||||
|
)
|
||||||
|
|
||||||
|
image = self.flip(image)
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
||||||
|
return example
|
169
ldm/data/personalized_file.py
Normal file
169
ldm/data/personalized_file.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
imagenet_templates_small = [
|
||||||
|
'a painting in the style of {}',
|
||||||
|
'a rendering in the style of {}',
|
||||||
|
'a cropped painting in the style of {}',
|
||||||
|
'the painting in the style of {}',
|
||||||
|
'a clean painting in the style of {}',
|
||||||
|
'a dirty painting in the style of {}',
|
||||||
|
'a dark painting in the style of {}',
|
||||||
|
'a picture in the style of {}',
|
||||||
|
'a cool painting in the style of {}',
|
||||||
|
'a close-up painting in the style of {}',
|
||||||
|
'a bright painting in the style of {}',
|
||||||
|
'a cropped painting in the style of {}',
|
||||||
|
'a good painting in the style of {}',
|
||||||
|
'a close-up painting in the style of {}',
|
||||||
|
'a rendition in the style of {}',
|
||||||
|
'a nice painting in the style of {}',
|
||||||
|
'a small painting in the style of {}',
|
||||||
|
'a weird painting in the style of {}',
|
||||||
|
'a large painting in the style of {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
imagenet_dual_templates_small = [
|
||||||
|
'a painting in the style of {} with {}',
|
||||||
|
'a rendering in the style of {} with {}',
|
||||||
|
'a cropped painting in the style of {} with {}',
|
||||||
|
'the painting in the style of {} with {}',
|
||||||
|
'a clean painting in the style of {} with {}',
|
||||||
|
'a dirty painting in the style of {} with {}',
|
||||||
|
'a dark painting in the style of {} with {}',
|
||||||
|
'a cool painting in the style of {} with {}',
|
||||||
|
'a close-up painting in the style of {} with {}',
|
||||||
|
'a bright painting in the style of {} with {}',
|
||||||
|
'a cropped painting in the style of {} with {}',
|
||||||
|
'a good painting in the style of {} with {}',
|
||||||
|
'a painting of one {} in the style of {}',
|
||||||
|
'a nice painting in the style of {} with {}',
|
||||||
|
'a small painting in the style of {} with {}',
|
||||||
|
'a weird painting in the style of {} with {}',
|
||||||
|
'a large painting in the style of {} with {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
per_img_token_list = [
|
||||||
|
'א',
|
||||||
|
'ב',
|
||||||
|
'ג',
|
||||||
|
'ד',
|
||||||
|
'ה',
|
||||||
|
'ו',
|
||||||
|
'ז',
|
||||||
|
'ח',
|
||||||
|
'ט',
|
||||||
|
'י',
|
||||||
|
'כ',
|
||||||
|
'ל',
|
||||||
|
'מ',
|
||||||
|
'נ',
|
||||||
|
'ס',
|
||||||
|
'ע',
|
||||||
|
'פ',
|
||||||
|
'צ',
|
||||||
|
'ק',
|
||||||
|
'ר',
|
||||||
|
'ש',
|
||||||
|
'ת',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBase(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_root,
|
||||||
|
size=None,
|
||||||
|
repeats=100,
|
||||||
|
interpolation='bicubic',
|
||||||
|
flip_p=0.5,
|
||||||
|
set='train',
|
||||||
|
placeholder_token='*',
|
||||||
|
per_image_tokens=False,
|
||||||
|
center_crop=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
self.data_root = data_root
|
||||||
|
|
||||||
|
self.image_paths = [
|
||||||
|
os.path.join(self.data_root, file_path)
|
||||||
|
for file_path in os.listdir(self.data_root)
|
||||||
|
]
|
||||||
|
|
||||||
|
# self._length = len(self.image_paths)
|
||||||
|
self.num_images = len(self.image_paths)
|
||||||
|
self._length = self.num_images
|
||||||
|
|
||||||
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
|
self.per_image_tokens = per_image_tokens
|
||||||
|
self.center_crop = center_crop
|
||||||
|
|
||||||
|
if per_image_tokens:
|
||||||
|
assert self.num_images < len(
|
||||||
|
per_img_token_list
|
||||||
|
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
|
||||||
|
|
||||||
|
if set == 'train':
|
||||||
|
self._length = self.num_images * repeats
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self.interpolation = {
|
||||||
|
'linear': PIL.Image.LINEAR,
|
||||||
|
'bilinear': PIL.Image.BILINEAR,
|
||||||
|
'bicubic': PIL.Image.BICUBIC,
|
||||||
|
'lanczos': PIL.Image.LANCZOS,
|
||||||
|
}[interpolation]
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
example = {}
|
||||||
|
image = Image.open(self.image_paths[i % self.num_images])
|
||||||
|
|
||||||
|
if not image.mode == 'RGB':
|
||||||
|
image = image.convert('RGB')
|
||||||
|
|
||||||
|
if self.per_image_tokens and np.random.uniform() < 0.25:
|
||||||
|
text = random.choice(imagenet_dual_templates_small).format(
|
||||||
|
self.placeholder_token, per_img_token_list[i % self.num_images]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text = random.choice(imagenet_templates_small).format(
|
||||||
|
self.placeholder_token
|
||||||
|
)
|
||||||
|
|
||||||
|
example['caption'] = text
|
||||||
|
|
||||||
|
# default to score-sde preprocessing
|
||||||
|
img = np.array(image).astype(np.uint8)
|
||||||
|
|
||||||
|
if self.center_crop:
|
||||||
|
crop = min(img.shape[0], img.shape[1])
|
||||||
|
h, w, = (
|
||||||
|
img.shape[0],
|
||||||
|
img.shape[1],
|
||||||
|
)
|
||||||
|
img = img[
|
||||||
|
(h - crop) // 2 : (h + crop) // 2,
|
||||||
|
(w - crop) // 2 : (w + crop) // 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
if self.size is not None:
|
||||||
|
image = image.resize(
|
||||||
|
(self.size, self.size), resample=self.interpolation
|
||||||
|
)
|
||||||
|
|
||||||
|
image = self.flip(image)
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
||||||
|
return example
|
1
ldm/devices/__init__.py
Normal file
1
ldm/devices/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from ldm.devices.devices import choose_autocast_device, choose_torch_device
|
24
ldm/devices/devices.py
Normal file
24
ldm/devices/devices.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
from torch import autocast
|
||||||
|
from contextlib import contextmanager, nullcontext
|
||||||
|
|
||||||
|
def choose_torch_device() -> str:
|
||||||
|
'''Convenience routine for guessing which GPU device to run model on'''
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return 'cuda'
|
||||||
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
|
return 'mps'
|
||||||
|
return 'cpu'
|
||||||
|
|
||||||
|
def choose_autocast_device(device):
|
||||||
|
'''Returns an autocast compatible device from a torch device'''
|
||||||
|
device_type = device.type # this returns 'mps' on M1
|
||||||
|
# autocast only for cuda, but GTX 16xx have issues with it
|
||||||
|
if device_type == 'cuda':
|
||||||
|
device_name = torch.cuda.get_device_name()
|
||||||
|
if 'GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name:
|
||||||
|
return device_type,nullcontext
|
||||||
|
else:
|
||||||
|
return device_type,autocast
|
||||||
|
else:
|
||||||
|
return 'cpu',nullcontext
|
0
ldm/modules/__init__.py
Normal file
0
ldm/modules/__init__.py
Normal file
273
ldm/modules/embedding_manager.py
Normal file
273
ldm/modules/embedding_manager.py
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
from cmath import log
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from ldm.data.personalized import per_img_token_list
|
||||||
|
from transformers import CLIPTokenizer
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
DEFAULT_PLACEHOLDER_TOKEN = ['*']
|
||||||
|
|
||||||
|
PROGRESSIVE_SCALE = 2000
|
||||||
|
|
||||||
|
|
||||||
|
def get_clip_token_for_string(tokenizer, string):
|
||||||
|
batch_encoding = tokenizer(
|
||||||
|
string,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_length=True,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
padding='max_length',
|
||||||
|
return_tensors='pt',
|
||||||
|
)
|
||||||
|
tokens = batch_encoding['input_ids']
|
||||||
|
""" assert (
|
||||||
|
torch.count_nonzero(tokens - 49407) == 2
|
||||||
|
), f"String '{string}' maps to more than a single token. Please use another string" """
|
||||||
|
|
||||||
|
return tokens[0, 1]
|
||||||
|
|
||||||
|
|
||||||
|
def get_bert_token_for_string(tokenizer, string):
|
||||||
|
token = tokenizer(string)
|
||||||
|
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
|
||||||
|
|
||||||
|
token = token[0, 1]
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_for_clip_token(embedder, token):
|
||||||
|
return embedder(token.unsqueeze(0))[0, 0]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingManager(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedder,
|
||||||
|
placeholder_strings=None,
|
||||||
|
initializer_words=None,
|
||||||
|
per_image_tokens=False,
|
||||||
|
num_vectors_per_token=1,
|
||||||
|
progressive_words=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.embedder = embedder
|
||||||
|
device = embedder.device
|
||||||
|
|
||||||
|
self.string_to_token_dict = {}
|
||||||
|
self.string_to_param_dict = nn.ParameterDict()
|
||||||
|
|
||||||
|
self.initial_embeddings = (
|
||||||
|
nn.ParameterDict()
|
||||||
|
) # These should not be optimized
|
||||||
|
|
||||||
|
self.progressive_words = progressive_words
|
||||||
|
self.progressive_counter = 0
|
||||||
|
|
||||||
|
self.max_vectors_per_token = num_vectors_per_token
|
||||||
|
|
||||||
|
if hasattr(
|
||||||
|
embedder, 'tokenizer'
|
||||||
|
): # using Stable Diffusion's CLIP encoder
|
||||||
|
self.is_clip = True
|
||||||
|
get_token_for_string = partial(
|
||||||
|
get_clip_token_for_string, embedder.tokenizer
|
||||||
|
)
|
||||||
|
get_embedding_for_tkn = partial(
|
||||||
|
get_embedding_for_clip_token,
|
||||||
|
embedder.transformer.text_model.embeddings,
|
||||||
|
)
|
||||||
|
# per bug report #572
|
||||||
|
#token_dim = 1280
|
||||||
|
token_dim = 768
|
||||||
|
else: # using LDM's BERT encoder
|
||||||
|
self.is_clip = False
|
||||||
|
get_token_for_string = partial(
|
||||||
|
get_bert_token_for_string, embedder.tknz_fn
|
||||||
|
)
|
||||||
|
get_embedding_for_tkn = embedder.transformer.token_emb
|
||||||
|
token_dim = 1280
|
||||||
|
|
||||||
|
if per_image_tokens:
|
||||||
|
placeholder_strings.extend(per_img_token_list)
|
||||||
|
|
||||||
|
for idx, placeholder_string in enumerate(placeholder_strings):
|
||||||
|
|
||||||
|
token = get_token_for_string(placeholder_string)
|
||||||
|
|
||||||
|
if initializer_words and idx < len(initializer_words):
|
||||||
|
init_word_token = get_token_for_string(initializer_words[idx])
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
init_word_embedding = get_embedding_for_tkn(
|
||||||
|
init_word_token.to(device)
|
||||||
|
)
|
||||||
|
|
||||||
|
token_params = torch.nn.Parameter(
|
||||||
|
init_word_embedding.unsqueeze(0).repeat(
|
||||||
|
num_vectors_per_token, 1
|
||||||
|
),
|
||||||
|
requires_grad=True,
|
||||||
|
)
|
||||||
|
self.initial_embeddings[
|
||||||
|
placeholder_string
|
||||||
|
] = torch.nn.Parameter(
|
||||||
|
init_word_embedding.unsqueeze(0).repeat(
|
||||||
|
num_vectors_per_token, 1
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
token_params = torch.nn.Parameter(
|
||||||
|
torch.rand(
|
||||||
|
size=(num_vectors_per_token, token_dim),
|
||||||
|
requires_grad=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.string_to_token_dict[placeholder_string] = token
|
||||||
|
self.string_to_param_dict[placeholder_string] = token_params
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tokenized_text,
|
||||||
|
embedded_text,
|
||||||
|
):
|
||||||
|
b, n, device = *tokenized_text.shape, tokenized_text.device
|
||||||
|
|
||||||
|
for (
|
||||||
|
placeholder_string,
|
||||||
|
placeholder_token,
|
||||||
|
) in self.string_to_token_dict.items():
|
||||||
|
|
||||||
|
placeholder_embedding = self.string_to_param_dict[
|
||||||
|
placeholder_string
|
||||||
|
].to(device)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.max_vectors_per_token == 1
|
||||||
|
): # If there's only one vector per token, we can do a simple replacement
|
||||||
|
placeholder_idx = torch.where(
|
||||||
|
tokenized_text == placeholder_token.to(device)
|
||||||
|
)
|
||||||
|
embedded_text[placeholder_idx] = placeholder_embedding
|
||||||
|
else: # otherwise, need to insert and keep track of changing indices
|
||||||
|
if self.progressive_words:
|
||||||
|
self.progressive_counter += 1
|
||||||
|
max_step_tokens = (
|
||||||
|
1 + self.progressive_counter // PROGRESSIVE_SCALE
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
max_step_tokens = self.max_vectors_per_token
|
||||||
|
|
||||||
|
num_vectors_for_token = min(
|
||||||
|
placeholder_embedding.shape[0], max_step_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
placeholder_rows, placeholder_cols = torch.where(
|
||||||
|
tokenized_text == placeholder_token.to(device)
|
||||||
|
)
|
||||||
|
|
||||||
|
if placeholder_rows.nelement() == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
sorted_cols, sort_idx = torch.sort(
|
||||||
|
placeholder_cols, descending=True
|
||||||
|
)
|
||||||
|
sorted_rows = placeholder_rows[sort_idx]
|
||||||
|
|
||||||
|
for idx in range(len(sorted_rows)):
|
||||||
|
row = sorted_rows[idx]
|
||||||
|
col = sorted_cols[idx]
|
||||||
|
|
||||||
|
new_token_row = torch.cat(
|
||||||
|
[
|
||||||
|
tokenized_text[row][:col],
|
||||||
|
placeholder_token.repeat(num_vectors_for_token).to(
|
||||||
|
device
|
||||||
|
),
|
||||||
|
tokenized_text[row][col + 1 :],
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)[:n]
|
||||||
|
new_embed_row = torch.cat(
|
||||||
|
[
|
||||||
|
embedded_text[row][:col],
|
||||||
|
placeholder_embedding[:num_vectors_for_token],
|
||||||
|
embedded_text[row][col + 1 :],
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)[:n]
|
||||||
|
|
||||||
|
embedded_text[row] = new_embed_row
|
||||||
|
tokenized_text[row] = new_token_row
|
||||||
|
|
||||||
|
return embedded_text
|
||||||
|
|
||||||
|
def save(self, ckpt_path):
|
||||||
|
torch.save(
|
||||||
|
{
|
||||||
|
'string_to_token': self.string_to_token_dict,
|
||||||
|
'string_to_param': self.string_to_param_dict,
|
||||||
|
},
|
||||||
|
ckpt_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load(self, ckpt_path, full=True):
|
||||||
|
ckpt = torch.load(ckpt_path, map_location='cpu')
|
||||||
|
|
||||||
|
# Handle .pt textual inversion files
|
||||||
|
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
|
||||||
|
self.string_to_token_dict = ckpt["string_to_token"]
|
||||||
|
self.string_to_param_dict = ckpt["string_to_param"]
|
||||||
|
|
||||||
|
# Handle .bin textual inversion files from Huggingface Concepts
|
||||||
|
# https://huggingface.co/sd-concepts-library
|
||||||
|
else:
|
||||||
|
for token_str in list(ckpt.keys()):
|
||||||
|
token = get_clip_token_for_string(self.embedder.tokenizer, token_str)
|
||||||
|
self.string_to_token_dict[token_str] = token
|
||||||
|
ckpt[token_str] = torch.nn.Parameter(ckpt[token_str])
|
||||||
|
|
||||||
|
self.string_to_param_dict.update(ckpt)
|
||||||
|
|
||||||
|
if not full:
|
||||||
|
for key, value in self.string_to_param_dict.items():
|
||||||
|
self.string_to_param_dict[key] = torch.nn.Parameter(value.half())
|
||||||
|
|
||||||
|
def get_embedding_norms_squared(self):
|
||||||
|
all_params = torch.cat(
|
||||||
|
list(self.string_to_param_dict.values()), axis=0
|
||||||
|
) # num_placeholders x embedding_dim
|
||||||
|
param_norm_squared = (all_params * all_params).sum(
|
||||||
|
axis=-1
|
||||||
|
) # num_placeholders
|
||||||
|
|
||||||
|
return param_norm_squared
|
||||||
|
|
||||||
|
def embedding_parameters(self):
|
||||||
|
return self.string_to_param_dict.parameters()
|
||||||
|
|
||||||
|
def embedding_to_coarse_loss(self):
|
||||||
|
|
||||||
|
loss = 0.0
|
||||||
|
num_embeddings = len(self.initial_embeddings)
|
||||||
|
|
||||||
|
for key in self.initial_embeddings:
|
||||||
|
optimized = self.string_to_param_dict[key]
|
||||||
|
coarse = self.initial_embeddings[key].clone().to(optimized.device)
|
||||||
|
|
||||||
|
loss = (
|
||||||
|
loss
|
||||||
|
+ (optimized - coarse)
|
||||||
|
@ (optimized - coarse).T
|
||||||
|
/ num_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
Loading…
Reference in New Issue
Block a user