mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-15 15:22:55 +03:00
59 lines
1.6 KiB
Python
59 lines
1.6 KiB
Python
import json
|
|
import os
|
|
import random
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from PIL import Image
|
|
from PIL import ImageFile
|
|
# Process-wide PIL settings, applied once at import time (they affect every
# PIL user in this process, not just this dataset):
#   - MAX_IMAGE_PIXELS = None disables PIL's decompression-bomb pixel limit so
#     very large web-scraped images can be opened;
#   - LOAD_TRUNCATED_IMAGES = True lets partially-downloaded/truncated image
#     files load instead of raising during decode.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
from data.utils import pre_caption
|
|
import os,glob
|
|
|
|
class pretrain_dataset(Dataset):
    """Image-caption pre-training dataset built from JSON annotation files.

    Every annotation file (and every LAION shard) is loaded with ``json.load``
    and concatenated as a list; ``__getitem__`` reads each entry's ``'image'``
    (a path passed to ``PIL.Image.open``) and ``'caption'`` keys. The
    ``ann_file`` annotations are always present; when ``laion_path`` is set,
    one LAION shard at a time is appended and can be swapped per epoch with
    :meth:`reload_laion`.

    Args:
        ann_file: Path, or iterable of paths, to JSON annotation files.
        laion_path: Directory containing ``*.json`` LAION shards, or a falsy
            value to disable LAION data.
        transform: Callable applied to each RGB ``PIL.Image``.
        max_words: Second argument forwarded to ``pre_caption`` (defaults to
            30, the previously hard-coded value; presumably a word-count cap
            on the caption -- confirm in ``data.utils``).

    Raises:
        FileNotFoundError: If ``laion_path`` is set but contains no ``*.json``
            files.
    """

    def __init__(self, ann_file, laion_path, transform, max_words=30):
        # Accept a single path too; iterating a bare string would otherwise
        # try to open each character as a file.
        if isinstance(ann_file, str):
            ann_file = [ann_file]

        self.ann_pretrain = []
        for f in ann_file:
            print('loading '+f)
            # Context manager so each annotation file handle is closed.
            with open(f, 'r') as fh:
                self.ann_pretrain += json.load(fh)

        self.laion_path = laion_path
        self.ann_laion = []
        self.laion_files = []
        if self.laion_path:
            # glob returns files in arbitrary filesystem order; sort so the
            # shard chosen for a given epoch is the same across runs and
            # across processes that each build their own dataset.
            self.laion_files = sorted(glob.glob(os.path.join(laion_path, '*.json')))
            if not self.laion_files:
                raise FileNotFoundError(
                    'no *.json LAION annotation files found in ' + str(laion_path))
            self._load_laion(0)
        else:
            self.annotation = self.ann_pretrain

        self.transform = transform
        self.max_words = max_words

    def _load_laion(self, index):
        """Load LAION shard ``self.laion_files[index]`` and rebuild ``self.annotation``."""
        path = self.laion_files[index]
        print('loading '+path)
        with open(path, 'r') as f:
            self.ann_laion = json.load(f)

        # Build a new list so ann_pretrain itself is never mutated.
        self.annotation = self.ann_pretrain + self.ann_laion

    def reload_laion(self, epoch):
        """Swap in the LAION shard for ``epoch``, cycling through the sorted shards.

        A no-op when the dataset was built without ``laion_path``.
        """
        if not self.laion_files:
            return
        self._load_laion(epoch % len(self.laion_files))

    def __len__(self):
        """Return the number of annotations currently loaded."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return ``(transformed_image, processed_caption)`` for entry ``index``."""
        ann = self.annotation[index]

        # Close the source file promptly; convert() returns a new image that
        # does not depend on the opened file (PIL copies even when already RGB).
        with Image.open(ann['image']) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        caption = pre_caption(ann['caption'], self.max_words)

        return image, caption