2022-10-02 15:03:39 +03:00
|
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
import PIL
|
|
|
|
import torch
|
|
|
|
from PIL import Image
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from torchvision import transforms
|
|
|
|
|
|
|
|
import random
|
|
|
|
import tqdm
|
2022-10-11 19:03:08 +03:00
|
|
|
from modules import devices, shared
|
2022-10-04 08:52:11 +03:00
|
|
|
import re
|
|
|
|
|
2022-10-12 20:49:47 +03:00
|
|
|
# Leading run of digits and hyphens plus any whitespace after it, e.g. the
# "00012-" prefix in "00012-a photo of a cat". Removed from captions that are
# derived from image file names.
re_numbers_at_start = re.compile(pattern=r"^[-\d]+\s*")
|
|
|
|
|
|
|
|
|
|
|
|
class DatasetEntry:
    """A single prepared training sample.

    Holds the source path, the cached latent and the caption text derived for
    the image. ``cond`` / ``cond_text`` start out empty and are filled in later
    by the dataset (either eagerly at construction or lazily per item).
    """

    def __init__(self, filename=None, latent=None, filename_text=None):
        # Values provided by whoever builds the entry.
        self.filename, self.latent, self.filename_text = filename, latent, filename_text
        # Conditioning is not known yet; both start as None.
        self.cond = self.cond_text = None
|
2022-10-02 15:03:39 +03:00
|
|
|
|
|
|
|
|
|
|
|
class PersonalizedBase(Dataset):
    """Dataset of pre-encoded image latents plus caption text for textual-inversion training.

    Every readable image in ``data_root`` is resized, encoded once with ``model``
    and cached on the CPU as a ``DatasetEntry``. Each ``__getitem__`` call returns
    a list of ``batch_size`` entries.
    """

    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
        """Load prompt templates and encode every image found in ``data_root``.

        Args:
            data_root: directory with training images and, optionally, ``.txt``
                caption files that share an image's base name.
            width, height: size every image is resized to before encoding.
            repeats: how many times each image is repeated per epoch; combined
                with ``batch_size`` to compute ``len(self)``.
            flip_p: probability passed to the horizontal-flip transform stored in
                ``self.flip`` (not applied anywhere in this class).
            placeholder_token: text substituted for ``[name]`` in templates.
            model: model providing ``encode_first_stage`` and
                ``get_first_stage_encoding``.
            device: device the image tensor is moved to before encoding.
            template_file: path to a text file with one prompt template per line.
            include_cond: if true, compute and cache conditioning for every entry now.
            batch_size: number of entries returned per ``__getitem__`` call.

        Raises:
            AssertionError: if ``data_root`` is empty or no image could be loaded.
        """
        re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None

        self.placeholder_token = placeholder_token

        self.batch_size = batch_size
        self.width = width
        self.height = height

        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        # utf8 for consistency with how caption .txt files are read; without it the
        # platform locale encoding is used and non-ASCII templates can fail to decode.
        with open(template_file, "r", encoding="utf8") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        assert data_root, 'dataset directory not specified'

        cond_model = shared.sd_model.cond_stage_model

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            image = self._load_image(path)
            if image is None:
                # Not a readable image (e.g. a caption .txt file); skip it.
                continue

            filename_text = self._read_filename_text(path, re_word)

            npimage = np.array(image).astype(np.uint8)
            # Scale pixel values from [0, 255] to [-1, 1].
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
            # Channels-last -> channels-first, as the encoder is called with a CHW image.
            torchdata = torch.moveaxis(torchdata, 2, 0)

            init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
            # Cache the latent on the CPU.
            init_latent = init_latent.to(devices.cpu)

            entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)

            if include_cond:
                entry.cond_text = self.create_text(filename_text)
                entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)

            self.dataset.append(entry)

        # A single image is a valid dataset; only an empty one is an error.
        assert len(self.dataset) > 0, "No images have been found in the dataset."
        self.length = len(self.dataset) * repeats // batch_size

        self.initial_indexes = np.arange(len(self.dataset))
        self.indexes = None
        self.shuffle()

    def _load_image(self, path):
        """Return ``path`` as an RGB image resized to (width, height), or None if unreadable.

        ``data_root`` is listed wholesale and also contains non-image files such as
        caption ``.txt`` files, so a read failure is an expected skip, not an error.
        The source file handle is closed once the resized copy has been made.
        """
        try:
            with Image.open(path) as img:
                return img.convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
        except Exception:
            return None

    @staticmethod
    def _read_filename_text(path, re_word):
        """Return caption text for the image at ``path``.

        Uses the contents of a sibling ``<name>.txt`` file when one exists.
        Otherwise derives the text from the file name: it drops the extension,
        strips a leading number prefix and, when ``re_word`` is given, joins its
        matches with ``shared.opts.dataset_filename_join_string``.
        """
        text_filename = os.path.splitext(path)[0] + ".txt"
        if os.path.exists(text_filename):
            with open(text_filename, "r", encoding="utf8") as file:
                return file.read()

        filename_text = os.path.splitext(os.path.basename(path))[0]
        filename_text = re.sub(re_numbers_at_start, '', filename_text)
        if re_word:
            tokens = re_word.findall(filename_text)
            filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
        return filename_text

    def shuffle(self):
        """Set ``self.indexes`` to a random permutation of ``self.initial_indexes``."""
        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]

    def create_text(self, filename_text):
        """Build a prompt from a randomly chosen template line.

        ``[name]`` is replaced with the placeholder token and ``[filewords]`` with
        ``filename_text``.
        """
        text = random.choice(self.lines)
        text = text.replace("[name]", self.placeholder_token)
        text = text.replace("[filewords]", filename_text)
        return text

    def __len__(self):
        """Return the number of batches per epoch (``len(dataset) * repeats // batch_size``)."""
        return self.length

    def __getitem__(self, i):
        """Return batch ``i`` as a list of ``batch_size`` dataset entries.

        Positions wrap around the dataset. Whenever a position lands on a multiple
        of the dataset size, the order is reshuffled first. Entries without
        precomputed ``cond`` receive a freshly generated ``cond_text``.
        """
        res = []
        for j in range(self.batch_size):
            position = i * self.batch_size + j
            if position % len(self.indexes) == 0:
                self.shuffle()

            index = self.indexes[position % len(self.indexes)]
            entry = self.dataset[index]

            if entry.cond is None:
                entry.cond_text = self.create_text(entry.filename_text)

            res.append(entry)

        return res
|