sygil-webui/ldm/data/__init__.py
2022-09-28 12:37:15 -07:00

102 lines
5.1 KiB
Python

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
from data.nocaps_dataset import nocaps_eval
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
from data.vqa_dataset import vqa_dataset
from data.nlvr_dataset import nlvr_dataset
from data.pretrain_dataset import pretrain_dataset
from transform.randaugment import RandomAugment
def create_dataset(dataset, config, min_scale=0.5):
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
transform_train = transforms.Compose([
transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
transforms.ToTensor(),
normalize,
])
transform_test = transforms.Compose([
transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
normalize,
])
if dataset=='pretrain':
dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
return dataset
elif dataset=='caption_coco':
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='nocaps':
val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return val_dataset, test_dataset
elif dataset=='retrieval_coco':
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='retrieval_flickr':
train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='vqa':
train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
train_files = config['train_files'], split='train')
test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
return train_dataset, test_dataset
elif dataset=='nlvr':
train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
return train_dataset, val_dataset, test_dataset
def create_sampler(datasets, shuffles, num_tasks, global_rank):
samplers = []
for dataset,shuffle in zip(datasets,shuffles):
sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
samplers.append(sampler)
return samplers
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
loaders = []
for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
if is_train:
shuffle = (sampler is None)
drop_last = True
else:
shuffle = False
drop_last = False
loader = DataLoader(
dataset,
batch_size=bs,
num_workers=n_worker,
pin_memory=True,
sampler=sampler,
shuffle=shuffle,
collate_fn=collate_fn,
drop_last=drop_last,
)
loaders.append(loader)
return loaders