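"""Factory helpers for building the datasets, distributed samplers, and
DataLoaders used by the captioning, retrieval, VQA, and NLVR pipelines."""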
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from data.coco_karpathy_dataset import (
    coco_karpathy_train,
    coco_karpathy_caption_eval,
    coco_karpathy_retrieval_eval,
)
from data.nocaps_dataset import nocaps_eval
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
from data.vqa_dataset import vqa_dataset
from data.nlvr_dataset import nlvr_dataset
from data.pretrain_dataset import pretrain_dataset
from transform.randaugment import RandomAugment
def create_dataset(dataset, config, min_scale=0.5):
|
|
normalize = transforms.Normalize(
|
|
(0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
|
|
)
|
|
|
|
transform_train = transforms.Compose(
|
|
[
|
|
transforms.RandomResizedCrop(
|
|
config["image_size"],
|
|
scale=(min_scale, 1.0),
|
|
interpolation=InterpolationMode.BICUBIC,
|
|
),
|
|
transforms.RandomHorizontalFlip(),
|
|
RandomAugment(
|
|
2,
|
|
5,
|
|
isPIL=True,
|
|
augs=[
|
|
"Identity",
|
|
"AutoContrast",
|
|
"Brightness",
|
|
"Sharpness",
|
|
"Equalize",
|
|
"ShearX",
|
|
"ShearY",
|
|
"TranslateX",
|
|
"TranslateY",
|
|
"Rotate",
|
|
],
|
|
),
|
|
transforms.ToTensor(),
|
|
normalize,
|
|
]
|
|
)
|
|
transform_test = transforms.Compose(
|
|
[
|
|
transforms.Resize(
|
|
(config["image_size"], config["image_size"]),
|
|
interpolation=InterpolationMode.BICUBIC,
|
|
),
|
|
transforms.ToTensor(),
|
|
normalize,
|
|
]
|
|
)
|
|
|
|
if dataset == "pretrain":
|
|
dataset = pretrain_dataset(
|
|
config["train_file"], config["laion_path"], transform_train
|
|
)
|
|
return dataset
|
|
|
|
elif dataset == "caption_coco":
|
|
train_dataset = coco_karpathy_train(
|
|
transform_train,
|
|
config["image_root"],
|
|
config["ann_root"],
|
|
prompt=config["prompt"],
|
|
)
|
|
val_dataset = coco_karpathy_caption_eval(
|
|
transform_test, config["image_root"], config["ann_root"], "val"
|
|
)
|
|
test_dataset = coco_karpathy_caption_eval(
|
|
transform_test, config["image_root"], config["ann_root"], "test"
|
|
)
|
|
return train_dataset, val_dataset, test_dataset
|
|
|
|
elif dataset == "nocaps":
|
|
val_dataset = nocaps_eval(
|
|
transform_test, config["image_root"], config["ann_root"], "val"
|
|
)
|
|
test_dataset = nocaps_eval(
|
|
transform_test, config["image_root"], config["ann_root"], "test"
|
|
)
|
|
return val_dataset, test_dataset
|
|
|
|
elif dataset == "retrieval_coco":
|
|
train_dataset = coco_karpathy_train(
|
|
transform_train, config["image_root"], config["ann_root"]
|
|
)
|
|
val_dataset = coco_karpathy_retrieval_eval(
|
|
transform_test, config["image_root"], config["ann_root"], "val"
|
|
)
|
|
test_dataset = coco_karpathy_retrieval_eval(
|
|
transform_test, config["image_root"], config["ann_root"], "test"
|
|
)
|
|
return train_dataset, val_dataset, test_dataset
|
|
|
|
elif dataset == "retrieval_flickr":
|
|
train_dataset = flickr30k_train(
|
|
transform_train, config["image_root"], config["ann_root"]
|
|
)
|
|
val_dataset = flickr30k_retrieval_eval(
|
|
transform_test, config["image_root"], config["ann_root"], "val"
|
|
)
|
|
test_dataset = flickr30k_retrieval_eval(
|
|
transform_test, config["image_root"], config["ann_root"], "test"
|
|
)
|
|
return train_dataset, val_dataset, test_dataset
|
|
|
|
elif dataset == "vqa":
|
|
train_dataset = vqa_dataset(
|
|
transform_train,
|
|
config["ann_root"],
|
|
config["vqa_root"],
|
|
config["vg_root"],
|
|
train_files=config["train_files"],
|
|
split="train",
|
|
)
|
|
test_dataset = vqa_dataset(
|
|
transform_test,
|
|
config["ann_root"],
|
|
config["vqa_root"],
|
|
config["vg_root"],
|
|
split="test",
|
|
)
|
|
return train_dataset, test_dataset
|
|
|
|
elif dataset == "nlvr":
|
|
train_dataset = nlvr_dataset(
|
|
transform_train, config["image_root"], config["ann_root"], "train"
|
|
)
|
|
val_dataset = nlvr_dataset(
|
|
transform_test, config["image_root"], config["ann_root"], "val"
|
|
)
|
|
test_dataset = nlvr_dataset(
|
|
transform_test, config["image_root"], config["ann_root"], "test"
|
|
)
|
|
return train_dataset, val_dataset, test_dataset
|
|
|
|
|
|
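
# Illustrative usage (added note, not part of the original file): the function
# dispatches on the dataset name, and each branch reads its own config keys.
# For example, "caption_coco" needs "image_size", "image_root", "ann_root",
# and "prompt", and returns three splits:
#
#   train_ds, val_ds, test_ds = create_dataset("caption_coco", config)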


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one DistributedSampler per dataset so each rank sees a distinct shard."""
    samplers = []
    for dataset, shuffle in zip(datasets, shuffles):
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
        )
        samplers.append(sampler)
    return samplers
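
# Note (added, not in the original file): DistributedSampler shuffles
# deterministically per epoch, so a training loop using these samplers is
# expected to call sampler.set_epoch(epoch) before each epoch, e.g.:
#
#   for epoch in range(num_epochs):
#       train_sampler.set_epoch(epoch)
#       for batch in train_loader:
#           ...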


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Wrap each dataset in a DataLoader with its own batch size, sampler,
    worker count, and collate function (the argument lists run in parallel)."""
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            # Shuffle only when no distributed sampler shards the data, and
            # drop the ragged final batch during training.
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )
        loaders.append(loader)
    return loaders
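

# End-to-end sketch (illustrative, not part of the original module): a
# single-process run, so samplers are None and the DataLoader shuffles the
# training set itself. The config values below are hypothetical and assume
# the COCO images and annotations are available locally.
if __name__ == "__main__":
    config = {
        "image_size": 384,
        "image_root": "/data/coco",
        "ann_root": "annotation",
        "prompt": "a picture of ",
    }
    train_ds, val_ds, test_ds = create_dataset("caption_coco", config)
    train_loader, val_loader, test_loader = create_loader(
        [train_ds, val_ds, test_ds],
        samplers=[None, None, None],
        batch_size=[32, 64, 64],
        num_workers=[4, 4, 4],
        is_trains=[True, False, False],
        collate_fns=[None, None, None],
    )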