# stable-diffusion-webui/ldm/data/__init__.py

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from data.coco_karpathy_dataset import (
    coco_karpathy_train,
    coco_karpathy_caption_eval,
    coco_karpathy_retrieval_eval,
)
from data.nocaps_dataset import nocaps_eval
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
from data.vqa_dataset import vqa_dataset
from data.nlvr_dataset import nlvr_dataset
from data.pretrain_dataset import pretrain_dataset
from transform.randaugment import RandomAugment


def create_dataset(dataset, config, min_scale=0.5):
    """Build the dataset(s) for the given task name.

    Returns a single dataset for 'pretrain', (val, test) for 'nocaps',
    (train, test) for 'vqa', and (train, val, test) for the other tasks.
    """
    # CLIP image normalization statistics.
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
    )

    # Training transform: random crop covering at least `min_scale` of the
    # image, horizontal flip, and RandAugment (2 ops at magnitude 5).
    transform_train = transforms.Compose(
        [
            transforms.RandomResizedCrop(
                config["image_size"],
                scale=(min_scale, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(),
            RandomAugment(
                2,
                5,
                isPIL=True,
                augs=[
                    "Identity",
                    "AutoContrast",
                    "Brightness",
                    "Sharpness",
                    "Equalize",
                    "ShearX",
                    "ShearY",
                    "TranslateX",
                    "TranslateY",
                    "Rotate",
                ],
            ),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # Evaluation transform: deterministic resize only.
    transform_test = transforms.Compose(
        [
            transforms.Resize(
                (config["image_size"], config["image_size"]),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            normalize,
        ]
    )

    if dataset == "pretrain":
        dataset = pretrain_dataset(
            config["train_file"], config["laion_path"], transform_train
        )
        return dataset
    elif dataset == "caption_coco":
        train_dataset = coco_karpathy_train(
            transform_train,
            config["image_root"],
            config["ann_root"],
            prompt=config["prompt"],
        )
        val_dataset = coco_karpathy_caption_eval(
            transform_test, config["image_root"], config["ann_root"], "val"
        )
        test_dataset = coco_karpathy_caption_eval(
            transform_test, config["image_root"], config["ann_root"], "test"
        )
        return train_dataset, val_dataset, test_dataset
    elif dataset == "nocaps":
        val_dataset = nocaps_eval(
            transform_test, config["image_root"], config["ann_root"], "val"
        )
        test_dataset = nocaps_eval(
            transform_test, config["image_root"], config["ann_root"], "test"
        )
        return val_dataset, test_dataset
    elif dataset == "retrieval_coco":
        train_dataset = coco_karpathy_train(
            transform_train, config["image_root"], config["ann_root"]
        )
        val_dataset = coco_karpathy_retrieval_eval(
            transform_test, config["image_root"], config["ann_root"], "val"
        )
        test_dataset = coco_karpathy_retrieval_eval(
            transform_test, config["image_root"], config["ann_root"], "test"
        )
        return train_dataset, val_dataset, test_dataset
    elif dataset == "retrieval_flickr":
        train_dataset = flickr30k_train(
            transform_train, config["image_root"], config["ann_root"]
        )
        val_dataset = flickr30k_retrieval_eval(
            transform_test, config["image_root"], config["ann_root"], "val"
        )
        test_dataset = flickr30k_retrieval_eval(
            transform_test, config["image_root"], config["ann_root"], "test"
        )
        return train_dataset, val_dataset, test_dataset
    elif dataset == "vqa":
        train_dataset = vqa_dataset(
            transform_train,
            config["ann_root"],
            config["vqa_root"],
            config["vg_root"],
            train_files=config["train_files"],
            split="train",
        )
        test_dataset = vqa_dataset(
            transform_test,
            config["ann_root"],
            config["vqa_root"],
            config["vg_root"],
            split="test",
        )
        return train_dataset, test_dataset
    elif dataset == "nlvr":
        train_dataset = nlvr_dataset(
            transform_train, config["image_root"], config["ann_root"], "train"
        )
        val_dataset = nlvr_dataset(
            transform_test, config["image_root"], config["ann_root"], "val"
        )
        test_dataset = nlvr_dataset(
            transform_test, config["image_root"], config["ann_root"], "test"
        )
        return train_dataset, val_dataset, test_dataset
    else:
        # Fail loudly instead of silently returning None for unknown names.
        raise ValueError(f"unknown dataset: {dataset}")
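
# A minimal usage sketch, assuming a config dict that provides the keys read
# above (the paths and values here are illustrative, not from this repo):
#
#   config = {
#       "image_size": 384,
#       "image_root": "/data/coco/images",
#       "ann_root": "annotation",
#       "prompt": "a picture of ",
#   }
#   train_ds, val_ds, test_ds = create_dataset("caption_coco", config)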


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one DistributedSampler per dataset for multi-process training."""
    samplers = []
    for dataset, shuffle in zip(datasets, shuffles):
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
        )
        samplers.append(sampler)
    return samplers
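
# Sketch of distributed use, assuming torch.distributed has already been
# initialized (e.g. via torchrun); `train_ds` is whatever create_dataset
# returned above:
#
#   num_tasks = torch.distributed.get_world_size()
#   global_rank = torch.distributed.get_rank()
#   samplers = create_sampler([train_ds], [True], num_tasks, global_rank)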


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Wrap each dataset in a DataLoader; all arguments are parallel lists."""
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            # Shuffle only when no sampler is given (a DistributedSampler
            # shuffles on its own); drop the last partial batch in training.
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )
        loaders.append(loader)
    return loaders
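
# End-to-end sketch for single-process training (sampler=None lets the
# DataLoader shuffle by itself; batch sizes and worker counts below are
# illustrative, not prescribed by this file):
#
#   train_ds, val_ds, test_ds = create_dataset("caption_coco", config)
#   train_loader, val_loader, test_loader = create_loader(
#       [train_ds, val_ds, test_ds],
#       samplers=[None, None, None],
#       batch_size=[32, 64, 64],
#       num_workers=[4, 4, 4],
#       is_trains=[True, False, False],
#       collate_fns=[None, None, None],
#   )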