Img2txt dependencies and necessary files. (#1354)

This commit is contained in:
Alejandro Gil 2022-09-28 12:37:15 -07:00 committed by GitHub
parent 510d103a83
commit f678efb4e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 4803 additions and 28 deletions

View File

@ -0,0 +1,21 @@
{
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30522,
"encoder_width": 768,
"add_cross_attention": true
}

View File

@ -0,0 +1,33 @@
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
coco_gt_root: 'annotation/coco_gt'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0
batch_size: 32
init_lr: 1e-5
# vit: 'large'
# vit_grad_ckpt: True
# vit_ckpt_layer: 5
# batch_size: 16
# init_lr: 2e-6
image_size: 384
# generation configs
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 5

View File

@ -0,0 +1,21 @@
{
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30524,
"encoder_width": 768,
"add_cross_attention": true
}

21
configs/blip/nlvr.yaml Normal file
View File

@ -0,0 +1,21 @@
image_root: '/export/share/datasets/vision/NLVR2/'
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
#size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 64
vit_grad_ckpt: False
vit_ckpt_layer: 0
max_epoch: 15
image_size: 384
# optimizer
weight_decay: 0.05
init_lr: 3e-5
min_lr: 0

15
configs/blip/nocaps.yaml Normal file
View File

@ -0,0 +1,15 @@
image_root: '/export/share/datasets/vision/nocaps/'
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
vit: 'base'
batch_size: 32
image_size: 384
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '

View File

@ -0,0 +1,27 @@
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
'/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
]
laion_path: ''
# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0
image_size: 224
batch_size: 75
queue_size: 57600
alpha: 0.4
# optimizer
weight_decay: 0.05
init_lr: 3e-4
min_lr: 1e-6
warmup_lr: 1e-6
lr_decay_rate: 0.9
max_epoch: 20
warmup_steps: 3000

View File

@ -0,0 +1,34 @@
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
dataset: 'coco'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
# size of vit model; base or large
vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5
# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 12
# init_lr: 5e-6
image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 256
negative_all_rank: True
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6

View File

@ -0,0 +1,34 @@
image_root: '/export/share/datasets/vision/flickr30k/'
ann_root: 'annotation'
dataset: 'flickr'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
# size of vit model; base or large
vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5
# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 10
# init_lr: 5e-6
image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 128
negative_all_rank: False
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6

View File

@ -0,0 +1,12 @@
video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
# size of vit model; base or large
vit: 'base'
batch_size: 64
k_test: 128
image_size: 384
num_frm_test: 8

25
configs/blip/vqa.yaml Normal file
View File

@ -0,0 +1,25 @@
vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
train_files: ['vqa_train','vqa_val','vg_qa']
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
# size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 32
vit_grad_ckpt: False
vit_ckpt_layer: 0
init_lr: 2e-5
image_size: 480
k_test: 128
inference: 'rank'
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 10

View File

@ -41,6 +41,8 @@ dependencies:
- diffusers==0.3.0
- einops==0.3.0
- facexlib>=0.2.3
- ftfy==6.1.1
- fairscale==0.4.4
- gradio==3.1.6
- hydralit==1.0.14
- hydralit_components==1.0.10
@ -51,11 +53,14 @@ dependencies:
- opencv-python-headless==4.6.0.66
- pandas==1.4.3
- piexif==1.1.3
- pycocotools==2.0.5
- pycocoevalcap==1.2
- pudb==2019.2
- pynvml==11.4.1
- python-slugify>=6.1.2
- pytorch-lightning==1.4.2
- retry>=0.9.2
- regex
- realesrgan==0.3.0
- streamlit==1.13.0
- streamlit-on-Hover-tabs==1.0.1
@ -65,7 +70,9 @@ dependencies:
- streamlit-tensorboard==0.0.2
- test-tube>=0.7.5
- tensorboard==2.10.1
- timm==0.4.12
- torch-fidelity==0.3.0
- torchmetrics==0.6.0
- transformers==4.19.2
- tqdm==4.64.0

View File

@ -0,0 +1,101 @@
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
from data.nocaps_dataset import nocaps_eval
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
from data.vqa_dataset import vqa_dataset
from data.nlvr_dataset import nlvr_dataset
from data.pretrain_dataset import pretrain_dataset
from transform.randaugment import RandomAugment
def create_dataset(dataset, config, min_scale=0.5):
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
transform_train = transforms.Compose([
transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
transforms.ToTensor(),
normalize,
])
transform_test = transforms.Compose([
transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
normalize,
])
if dataset=='pretrain':
dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
return dataset
elif dataset=='caption_coco':
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='nocaps':
val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return val_dataset, test_dataset
elif dataset=='retrieval_coco':
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='retrieval_flickr':
train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='vqa':
train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
train_files = config['train_files'], split='train')
test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
return train_dataset, test_dataset
elif dataset=='nlvr':
train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
return train_dataset, val_dataset, test_dataset
def create_sampler(datasets, shuffles, num_tasks, global_rank):
samplers = []
for dataset,shuffle in zip(datasets,shuffles):
sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
samplers.append(sampler)
return samplers
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
loaders = []
for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
if is_train:
shuffle = (sampler is None)
drop_last = True
else:
shuffle = False
drop_last = False
loader = DataLoader(
dataset,
batch_size=bs,
num_workers=n_worker,
pin_memory=True,
sampler=sampler,
shuffle=shuffle,
collate_fn=collate_fn,
drop_last=drop_last,
)
loaders.append(loader)
return loaders

View File

@ -0,0 +1,126 @@
import os
import json
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
from data.utils import pre_caption
class coco_karpathy_train(Dataset):
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
'''
image_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
filename = 'coco_karpathy_train.json'
download_url(url,ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
self.transform = transform
self.image_root = image_root
self.max_words = max_words
self.prompt = prompt
self.img_ids = {}
n = 0
for ann in self.annotation:
img_id = ann['image_id']
if img_id not in self.img_ids.keys():
self.img_ids[img_id] = n
n += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
return image, caption, self.img_ids[ann['image_id']]
class coco_karpathy_caption_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split):
'''
image_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
split (string): val or test
'''
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
return image, int(img_id)
class coco_karpathy_retrieval_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split, max_words=30):
'''
image_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
split (string): val or test
'''
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
self.text = []
self.image = []
self.txt2img = {}
self.img2txt = {}
txt_id = 0
for img_id, ann in enumerate(self.annotation):
self.image.append(ann['image'])
self.img2txt[img_id] = []
for i, caption in enumerate(ann['caption']):
self.text.append(pre_caption(caption,max_words))
self.img2txt[img_id].append(txt_id)
self.txt2img[txt_id] = img_id
txt_id += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
image_path = os.path.join(self.image_root, self.annotation[index]['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, index

View File

@ -0,0 +1,93 @@
import os
import json
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
from data.utils import pre_caption
class flickr30k_train(Dataset):
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
'''
image_root (string): Root directory of images (e.g. flickr30k/)
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
filename = 'flickr30k_train.json'
download_url(url,ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
self.transform = transform
self.image_root = image_root
self.max_words = max_words
self.prompt = prompt
self.img_ids = {}
n = 0
for ann in self.annotation:
img_id = ann['image_id']
if img_id not in self.img_ids.keys():
self.img_ids[img_id] = n
n += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
return image, caption, self.img_ids[ann['image_id']]
class flickr30k_retrieval_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split, max_words=30):
'''
image_root (string): Root directory of images (e.g. flickr30k/)
ann_root (string): directory to store the annotation file
split (string): val or test
'''
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
self.text = []
self.image = []
self.txt2img = {}
self.img2txt = {}
txt_id = 0
for img_id, ann in enumerate(self.annotation):
self.image.append(ann['image'])
self.img2txt[img_id] = []
for i, caption in enumerate(ann['caption']):
self.text.append(pre_caption(caption,max_words))
self.img2txt[img_id].append(txt_id)
self.txt2img[txt_id] = img_id
txt_id += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
image_path = os.path.join(self.image_root, self.annotation[index]['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, index

78
ldm/data/nlvr_dataset.py Normal file
View File

@ -0,0 +1,78 @@
import os
import json
import random
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
from data.utils import pre_caption
class nlvr_dataset(Dataset):
def __init__(self, transform, image_root, ann_root, split):
'''
image_root (string): Root directory of images
ann_root (string): directory to store the annotation file
split (string): train, val or test
'''
urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image0_path = os.path.join(self.image_root,ann['images'][0])
image0 = Image.open(image0_path).convert('RGB')
image0 = self.transform(image0)
image1_path = os.path.join(self.image_root,ann['images'][1])
image1 = Image.open(image1_path).convert('RGB')
image1 = self.transform(image1)
sentence = pre_caption(ann['sentence'], 40)
if ann['label']=='True':
label = 1
else:
label = 0
words = sentence.split(' ')
if 'left' not in words and 'right' not in words:
if random.random()<0.5:
return image0, image1, sentence, label
else:
return image1, image0, sentence, label
else:
if random.random()<0.5:
return image0, image1, sentence, label
else:
new_words = []
for word in words:
if word=='left':
new_words.append('right')
elif word=='right':
new_words.append('left')
else:
new_words.append(word)
sentence = ' '.join(new_words)
return image1, image0, sentence, label

View File

@ -0,0 +1,32 @@
import os
import json
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
class nocaps_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split):
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, int(ann['img_id'])

View File

@ -0,0 +1,59 @@
import json
import os
import random
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
from data.utils import pre_caption
import os,glob
class pretrain_dataset(Dataset):
def __init__(self, ann_file, laion_path, transform):
self.ann_pretrain = []
for f in ann_file:
print('loading '+f)
ann = json.load(open(f,'r'))
self.ann_pretrain += ann
self.laion_path = laion_path
if self.laion_path:
self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
print('loading '+self.laion_files[0])
with open(self.laion_files[0],'r') as f:
self.ann_laion = json.load(f)
self.annotation = self.ann_pretrain + self.ann_laion
else:
self.annotation = self.ann_pretrain
self.transform = transform
def reload_laion(self, epoch):
n = epoch%len(self.laion_files)
print('loading '+self.laion_files[n])
with open(self.laion_files[n],'r') as f:
self.ann_laion = json.load(f)
self.annotation = self.ann_pretrain + self.ann_laion
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image = Image.open(ann['image']).convert('RGB')
image = self.transform(image)
caption = pre_caption(ann['caption'],30)
return image, caption

112
ldm/data/utils.py Normal file
View File

@ -0,0 +1,112 @@
import re
import json
import os
import torch
import torch.distributed as dist
import utils
def pre_caption(caption,max_words=50):
caption = re.sub(
r"([.!\"()*#:;~])",
' ',
caption.lower(),
)
caption = re.sub(
r"\s{2,}",
' ',
caption,
)
caption = caption.rstrip('\n')
caption = caption.strip(' ')
#truncate caption
caption_words = caption.split(' ')
if len(caption_words)>max_words:
caption = ' '.join(caption_words[:max_words])
return caption
def pre_question(question,max_ques_words=50):
question = re.sub(
r"([.!\"()*#:;~])",
'',
question.lower(),
)
question = question.rstrip(' ')
#truncate question
question_words = question.split(' ')
if len(question_words)>max_ques_words:
question = ' '.join(question_words[:max_ques_words])
return question
def save_result(result, result_dir, filename, remove_duplicate=''):
result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
final_result_file = os.path.join(result_dir, '%s.json'%filename)
json.dump(result,open(result_file,'w'))
dist.barrier()
if utils.is_main_process():
# combine results from all processes
result = []
for rank in range(utils.get_world_size()):
result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
res = json.load(open(result_file,'r'))
result += res
if remove_duplicate:
result_new = []
id_list = []
for res in result:
if res[remove_duplicate] not in id_list:
id_list.append(res[remove_duplicate])
result_new.append(res)
result = result_new
json.dump(result,open(final_result_file,'w'))
print('result file saved to %s'%final_result_file)
return final_result_file
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from torchvision.datasets.utils import download_url
def coco_caption_eval(coco_gt_root, results_file, split):
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
download_url(urls[split],coco_gt_root)
annotation_file = os.path.join(coco_gt_root,filenames[split])
# create coco object and coco_result object
coco = COCO(annotation_file)
coco_result = coco.loadRes(results_file)
# create coco_eval object by taking coco and coco_result
coco_eval = COCOEvalCap(coco, coco_result)
# evaluate on a subset of images by setting
# coco_eval.params['image_id'] = coco_result.getImgIds()
# please remove this line when evaluating the full validation set
# coco_eval.params['image_id'] = coco_result.getImgIds()
# evaluate results
# SPICE will take a few minutes the first time, but speeds up due to caching
coco_eval.evaluate()
# print output evaluation scores
for metric, score in coco_eval.eval.items():
print(f'{metric}: {score:.3f}')
return coco_eval

110
ldm/data/video_dataset.py Normal file
View File

@ -0,0 +1,110 @@
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
import torch
import numpy as np
import random
import decord
from decord import VideoReader
import json
import os
from data.utils import pre_caption
decord.bridge.set_bridge("torch")
class ImageNorm(object):
"""Apply Normalization to Image Pixels on GPU
"""
def __init__(self, mean, std):
self.mean = torch.tensor(mean).view(1, 3, 1, 1)
self.std = torch.tensor(std).view(1, 3, 1, 1)
def __call__(self, img):
if torch.max(img) > 1 and self.mean.max() <= 1:
img.div_(255.)
return img.sub_(self.mean).div_(self.std)
def load_jsonl(filename):
with open(filename, "r") as f:
return [json.loads(l.strip("\n")) for l in f.readlines()]
class VideoDataset(Dataset):
def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
'''
image_root (string): Root directory of video
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
filename = 'msrvtt_test.jsonl'
download_url(url,ann_root)
self.annotation = load_jsonl(os.path.join(ann_root,filename))
self.num_frm = num_frm
self.frm_sampling_strategy = frm_sampling_strategy
self.max_img_size = max_img_size
self.video_root = video_root
self.video_fmt = video_fmt
self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
self.txt2video = [i for i in range(len(self.annotation))]
self.video2txt = self.txt2video
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
video = self.img_norm(vid_frm_array.float())
return video, ann['clip_name']
def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
try:
if not height or not width:
vr = VideoReader(video_path)
else:
vr = VideoReader(video_path, width=width, height=height)
vlen = len(vr)
if start_time or end_time:
assert fps > 0, 'must provide video fps if specifying start and end time.'
start_idx = min(int(start_time * fps), vlen)
end_idx = min(int(end_time * fps), vlen)
else:
start_idx, end_idx = 0, vlen
if self.frm_sampling_strategy == 'uniform':
frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
elif self.frm_sampling_strategy == 'rand':
frame_indices = sorted(random.sample(range(vlen), self.num_frm))
elif self.frm_sampling_strategy == 'headtail':
frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
frame_indices = frame_indices_head + frame_indices_tail
else:
raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
raw_sample_frms = vr.get_batch(frame_indices)
except Exception as e:
return None
raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
return raw_sample_frms

88
ldm/data/vqa_dataset.py Normal file
View File

@ -0,0 +1,88 @@
import os
import json
import random
from PIL import Image
import torch
from torch.utils.data import Dataset
from data.utils import pre_question
from torchvision.datasets.utils import download_url
class vqa_dataset(Dataset):
def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
self.split = split
self.transform = transform
self.vqa_root = vqa_root
self.vg_root = vg_root
if split=='train':
urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
self.annotation = []
for f in train_files:
download_url(urls[f],ann_root)
self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
else:
download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
if ann['dataset']=='vqa':
image_path = os.path.join(self.vqa_root,ann['image'])
elif ann['dataset']=='vg':
image_path = os.path.join(self.vg_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
if self.split == 'test':
question = pre_question(ann['question'])
question_id = ann['question_id']
return image, question, question_id
elif self.split=='train':
question = pre_question(ann['question'])
if ann['dataset']=='vqa':
answer_weight = {}
for answer in ann['answer']:
if answer in answer_weight.keys():
answer_weight[answer] += 1/len(ann['answer'])
else:
answer_weight[answer] = 1/len(ann['answer'])
answers = list(answer_weight.keys())
weights = list(answer_weight.values())
elif ann['dataset']=='vg':
answers = [ann['answer']]
weights = [0.2]
return image, question, answers, weights
def vqa_collate_fn(batch):
image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
for image, question, answer, weights in batch:
image_list.append(image)
question_list.append(question)
weight_list += weights
answer_list += answer
n.append(len(answer))
return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n

0
ldm/models/__init__.py Normal file
View File

238
ldm/models/blip.py Normal file
View File

@ -0,0 +1,238 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
'''
import warnings
warnings.filterwarnings("ignore")
from .vit import VisionTransformer, interpolate_pos_embed
from .med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer
import torch
from torch import nn
#import torch.nn.functional as F
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
class BLIP_Base(nn.Module):
def __init__(self,
med_config = 'configs/blip/med_config.json',
image_size = 224,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
def forward(self, image, caption, mode):
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
if mode=='image':
# return image features
image_embeds = self.visual_encoder(image)
return image_embeds
elif mode=='text':
# return text features
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
return text_output.last_hidden_state
elif mode=='multimodal':
# return multimodel features
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text.input_ids[:,0] = self.tokenizer.enc_token_id
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
return output.last_hidden_state
class BLIP_Decoder(nn.Module):
def __init__(self,
med_config = 'configs/blip/med_config.json',
image_size = 384,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
prompt = 'a picture of ',
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_decoder = BertLMHeadModel(config=med_config)
self.prompt = prompt
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
def forward(self, image, caption):
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
text.input_ids[:,0] = self.tokenizer.bos_token_id
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
decoder_targets[:,:self.prompt_length] = -100
decoder_output = self.text_decoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
labels = decoder_targets,
return_dict = True,
)
loss_lm = decoder_output.loss
return loss_lm
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
image_embeds = self.visual_encoder(image)
if not sample:
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
prompt = [self.prompt] * image.size(0)
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
input_ids[:,0] = self.tokenizer.bos_token_id
input_ids = input_ids[:, :-1]
if sample:
#nucleus sampling
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
do_sample=True,
top_p=top_p,
num_return_sequences=1,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=1.1,
**model_kwargs)
else:
#beam search
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=repetition_penalty,
**model_kwargs)
captions = []
for output in outputs:
caption = self.tokenizer.decode(output, skip_special_tokens=True)
captions.append(caption[len(self.prompt):])
return captions
def blip_decoder(pretrained='',**kwargs):
model = BLIP_Decoder(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model
def blip_feature_extractor(pretrained='',**kwargs):
model = BLIP_Base(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model
def init_tokenizer():
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
return tokenizer
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
assert vit in ['base', 'large'], "vit parameter must be base or large"
if vit=='base':
vision_width = 768
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0 or drop_path_rate
)
elif vit=='large':
vision_width = 1024
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0.1 or drop_path_rate
)
return visual_encoder, vision_width
def is_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
def load_checkpoint(model,url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
checkpoint = torch.load(cached_file, map_location='cpu')
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location='cpu')
else:
raise RuntimeError('checkpoint url or path is invalid')
state_dict = checkpoint['model']
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
model.visual_encoder_m)
for key in model.state_dict().keys():
if key in state_dict.keys():
if state_dict[key].shape!=model.state_dict()[key].shape:
del state_dict[key]
msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
return model,msg

76
ldm/models/blip_itm.py Normal file
View File

@ -0,0 +1,76 @@
from models.med import BertConfig, BertModel
from transformers import BertTokenizer
import torch
from torch import nn
import torch.nn.functional as F
from models.blip import create_vit, init_tokenizer, load_checkpoint
class BLIP_ITM(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 384,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
embed_dim = 256,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
text_width = self.text_encoder.config.hidden_size
self.vision_proj = nn.Linear(vision_width, embed_dim)
self.text_proj = nn.Linear(text_width, embed_dim)
self.itm_head = nn.Linear(text_width, 2)
def forward(self, image, caption, match_head='itm'):
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
return_tensors="pt").to(image.device)
if match_head=='itm':
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
itm_output = self.itm_head(output.last_hidden_state[:,0,:])
return itm_output
elif match_head=='itc':
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
sim = image_feat @ text_feat.t()
return sim
def blip_itm(pretrained='',**kwargs):
model = BLIP_ITM(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model

103
ldm/models/blip_nlvr.py Normal file
View File

@ -0,0 +1,103 @@
from models.med import BertConfig
from models.nlvr_encoder import BertModel
from models.vit import interpolate_pos_embed
from models.blip import create_vit, init_tokenizer, is_url
from timm.models.hub import download_cached_file
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np
class BLIP_NLVR(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 480,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
self.cls_head = nn.Sequential(
nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
nn.ReLU(),
nn.Linear(self.text_encoder.config.hidden_size, 2)
)
def forward(self, image, text, targets, train=True):
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
text.input_ids[:,0] = self.tokenizer.enc_token_id
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = [image0_embeds,image1_embeds],
encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
image_atts[image0_embeds.size(0):]],
return_dict = True,
)
hidden_state = output.last_hidden_state[:,0,:]
prediction = self.cls_head(hidden_state)
if train:
loss = F.cross_entropy(prediction, targets)
return loss
else:
return prediction
def blip_nlvr(pretrained='',**kwargs):
model = BLIP_NLVR(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
print("missing keys:")
print(msg.missing_keys)
return model
def load_checkpoint(model,url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
checkpoint = torch.load(cached_file, map_location='cpu')
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location='cpu')
else:
raise RuntimeError('checkpoint url or path is invalid')
state_dict = checkpoint['model']
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
for key in list(state_dict.keys()):
if 'crossattention.self.' in key:
new_key0 = key.replace('self','self0')
new_key1 = key.replace('self','self1')
state_dict[new_key0] = state_dict[key]
state_dict[new_key1] = state_dict[key]
elif 'crossattention.output.dense.' in key:
new_key0 = key.replace('dense','dense0')
new_key1 = key.replace('dense','dense1')
state_dict[new_key0] = state_dict[key]
state_dict[new_key1] = state_dict[key]
msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
return model,msg

339
ldm/models/blip_pretrain.py Normal file
View File

@ -0,0 +1,339 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
'''
from models.med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer
import transformers
transformers.logging.set_verbosity_error()
import torch
from torch import nn
import torch.nn.functional as F
from models.blip import create_vit, init_tokenizer, load_checkpoint
class BLIP_Pretrain(nn.Module):
def __init__(self,
med_config = 'configs/bert_config.json',
image_size = 224,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
embed_dim = 256,
queue_size = 57600,
momentum = 0.995,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
if vit=='base':
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
map_location="cpu", check_hash=True)
state_dict = checkpoint["model"]
msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
elif vit=='large':
from timm.models.helpers import load_custom_pretrained
from timm.models.vision_transformer import default_cfgs
load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
self.tokenizer = init_tokenizer()
encoder_config = BertConfig.from_json_file(med_config)
encoder_config.encoder_width = vision_width
self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
text_width = self.text_encoder.config.hidden_size
self.vision_proj = nn.Linear(vision_width, embed_dim)
self.text_proj = nn.Linear(text_width, embed_dim)
self.itm_head = nn.Linear(text_width, 2)
# create momentum encoders
self.visual_encoder_m, vision_width = create_vit(vit,image_size)
self.vision_proj_m = nn.Linear(vision_width, embed_dim)
self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
self.text_proj_m = nn.Linear(text_width, embed_dim)
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
[self.vision_proj,self.vision_proj_m],
[self.text_encoder,self.text_encoder_m],
[self.text_proj,self.text_proj_m],
]
self.copy_params()
# create the queue
self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
self.queue_size = queue_size
self.momentum = momentum
self.temp = nn.Parameter(0.07*torch.ones([]))
# create the decoder
decoder_config = BertConfig.from_json_file(med_config)
decoder_config.encoder_width = vision_width
self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
self.text_decoder.resize_token_embeddings(len(self.tokenizer))
tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
def forward(self, image, caption, alpha):
with torch.no_grad():
self.temp.clamp_(0.001,0.5)
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
return_tensors="pt").to(image.device)
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
# get momentum features
with torch.no_grad():
self._momentum_update()
image_embeds_m = self.visual_encoder_m(image)
image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
sim_i2t_m = image_feat_m @ text_feat_all / self.temp
sim_t2i_m = text_feat_m @ image_feat_all / self.temp
sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
sim_targets.fill_diagonal_(1)
sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
sim_i2t = image_feat @ text_feat_all / self.temp
sim_t2i = text_feat @ image_feat_all / self.temp
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
loss_ita = (loss_i2t+loss_t2i)/2
self._dequeue_and_enqueue(image_feat_m, text_feat_m)
###============== Image-text Matching ===================###
encoder_input_ids = text.input_ids.clone()
encoder_input_ids[:,0] = self.tokenizer.enc_token_id
# forward the positve image-text pair
bs = image.size(0)
output_pos = self.text_encoder(encoder_input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
with torch.no_grad():
weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
weights_t2i.fill_diagonal_(0)
weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
weights_i2t.fill_diagonal_(0)
# select a negative image for each text
image_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
image_embeds_neg.append(image_embeds[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
# select a negative text for each image
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
text_ids_neg.append(encoder_input_ids[neg_idx])
text_atts_neg.append(text.attention_mask[neg_idx])
text_ids_neg = torch.stack(text_ids_neg,dim=0)
text_atts_neg = torch.stack(text_atts_neg,dim=0)
text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
image_atts_all = torch.cat([image_atts,image_atts],dim=0)
output_neg = self.text_encoder(text_ids_all,
attention_mask = text_atts_all,
encoder_hidden_states = image_embeds_all,
encoder_attention_mask = image_atts_all,
return_dict = True,
)
vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
vl_output = self.itm_head(vl_embeddings)
itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
dim=0).to(image.device)
loss_itm = F.cross_entropy(vl_output, itm_labels)
##================= LM ========================##
decoder_input_ids = text.input_ids.clone()
decoder_input_ids[:,0] = self.tokenizer.bos_token_id
decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
decoder_output = self.text_decoder(decoder_input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
labels = decoder_targets,
return_dict = True,
)
loss_lm = decoder_output.loss
return loss_ita, loss_itm, loss_lm
@torch.no_grad()
def copy_params(self):
for model_pair in self.model_pairs:
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
param_m.data.copy_(param.data) # initialize
param_m.requires_grad = False # not update by gradient
@torch.no_grad()
def _momentum_update(self):
for model_pair in self.model_pairs:
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
@torch.no_grad()
def _dequeue_and_enqueue(self, image_feat, text_feat):
# gather keys before updating queue
image_feats = concat_all_gather(image_feat)
text_feats = concat_all_gather(text_feat)
batch_size = image_feats.shape[0]
ptr = int(self.queue_ptr)
assert self.queue_size % batch_size == 0 # for simplicity
# replace the keys at ptr (dequeue and enqueue)
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
ptr = (ptr + batch_size) % self.queue_size # move pointer
self.queue_ptr[0] = ptr
def blip_pretrain(**kwargs):
model = BLIP_Pretrain(**kwargs)
return model
@torch.no_grad()
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
tensors_gather = [torch.ones_like(tensor)
for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output
from typing import List
def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
uninitialized_encoder_weights: List[str] = []
if decoder.__class__ != encoder.__class__:
logger.info(
f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
)
def tie_encoder_to_decoder_recursively(
decoder_pointer: nn.Module,
encoder_pointer: nn.Module,
module_name: str,
uninitialized_encoder_weights: List[str],
skip_key: str,
depth=0,
):
assert isinstance(decoder_pointer, nn.Module) and isinstance(
encoder_pointer, nn.Module
), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
assert hasattr(encoder_pointer, "weight")
encoder_pointer.weight = decoder_pointer.weight
if hasattr(decoder_pointer, "bias"):
assert hasattr(encoder_pointer, "bias")
encoder_pointer.bias = decoder_pointer.bias
print(module_name+' is tied')
return
encoder_modules = encoder_pointer._modules
decoder_modules = decoder_pointer._modules
if len(decoder_modules) > 0:
assert (
len(encoder_modules) > 0
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
encoder_layer_pos = 0
for name, module in decoder_modules.items():
if name.isdigit():
encoder_name = str(int(name) + encoder_layer_pos)
decoder_name = name
if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
encoder_modules
) != len(decoder_modules):
# this can happen if the name corresponds to the position in a list module list of layers
# in this case the decoder has added a cross-attention that the encoder does not have
# thus skip this step and subtract one layer pos from encoder
encoder_layer_pos -= 1
continue
elif name not in encoder_modules:
continue
elif depth > 500:
raise ValueError(
"Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
)
else:
decoder_name = encoder_name = name
tie_encoder_to_decoder_recursively(
decoder_modules[decoder_name],
encoder_modules[encoder_name],
module_name + "/" + name,
uninitialized_encoder_weights,
skip_key,
depth=depth + 1,
)
all_encoder_weights.remove(module_name + "/" + encoder_name)
uninitialized_encoder_weights += list(all_encoder_weights)
# tie weights recursively
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)

View File

@ -0,0 +1,319 @@
from models.med import BertConfig, BertModel
from transformers import BertTokenizer
import torch
from torch import nn
import torch.nn.functional as F
from models.blip import create_vit, init_tokenizer, load_checkpoint
class BLIP_Retrieval(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 384,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
embed_dim = 256,
queue_size = 57600,
momentum = 0.995,
negative_all_rank = False,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
text_width = self.text_encoder.config.hidden_size
self.vision_proj = nn.Linear(vision_width, embed_dim)
self.text_proj = nn.Linear(text_width, embed_dim)
self.itm_head = nn.Linear(text_width, 2)
# create momentum encoders
self.visual_encoder_m, vision_width = create_vit(vit,image_size)
self.vision_proj_m = nn.Linear(vision_width, embed_dim)
self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
self.text_proj_m = nn.Linear(text_width, embed_dim)
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
[self.vision_proj,self.vision_proj_m],
[self.text_encoder,self.text_encoder_m],
[self.text_proj,self.text_proj_m],
]
self.copy_params()
# create the queue
self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
self.queue_size = queue_size
self.momentum = momentum
self.temp = nn.Parameter(0.07*torch.ones([]))
self.negative_all_rank = negative_all_rank
def forward(self, image, caption, alpha, idx):
with torch.no_grad():
self.temp.clamp_(0.001,0.5)
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
return_tensors="pt").to(image.device)
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
###============== Image-text Contrastive Learning ===================###
idx = idx.view(-1,1)
idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
pos_idx = torch.eq(idx, idx_all).float()
sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
# get momentum features
with torch.no_grad():
self._momentum_update()
image_embeds_m = self.visual_encoder_m(image)
image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
sim_i2t = image_feat @ text_feat_m_all / self.temp
sim_t2i = text_feat @ image_feat_m_all / self.temp
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
loss_ita = (loss_i2t+loss_t2i)/2
idxs = concat_all_gather(idx)
self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
###============== Image-text Matching ===================###
encoder_input_ids = text.input_ids.clone()
encoder_input_ids[:,0] = self.tokenizer.enc_token_id
# forward the positve image-text pair
bs = image.size(0)
output_pos = self.text_encoder(encoder_input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
if self.negative_all_rank:
# compute sample similarity
with torch.no_grad():
mask = torch.eq(idx, idxs.t())
image_feat_world = concat_all_gather(image_feat)
text_feat_world = concat_all_gather(text_feat)
sim_i2t = image_feat @ text_feat_world.t() / self.temp
sim_t2i = text_feat @ image_feat_world.t() / self.temp
weights_i2t = F.softmax(sim_i2t,dim=1)
weights_i2t.masked_fill_(mask, 0)
weights_t2i = F.softmax(sim_t2i,dim=1)
weights_t2i.masked_fill_(mask, 0)
image_embeds_world = all_gather_with_grad(image_embeds)
# select a negative image (from all ranks) for each text
image_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
image_embeds_neg.append(image_embeds_world[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
# select a negative text (from all ranks) for each image
input_ids_world = concat_all_gather(encoder_input_ids)
att_mask_world = concat_all_gather(text.attention_mask)
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
text_ids_neg.append(input_ids_world[neg_idx])
text_atts_neg.append(att_mask_world[neg_idx])
else:
with torch.no_grad():
mask = torch.eq(idx, idx.t())
sim_i2t = image_feat @ text_feat.t() / self.temp
sim_t2i = text_feat @ image_feat.t() / self.temp
weights_i2t = F.softmax(sim_i2t,dim=1)
weights_i2t.masked_fill_(mask, 0)
weights_t2i = F.softmax(sim_t2i,dim=1)
weights_t2i.masked_fill_(mask, 0)
# select a negative image (from same rank) for each text
image_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
image_embeds_neg.append(image_embeds[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
# select a negative text (from same rank) for each image
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
text_ids_neg.append(encoder_input_ids[neg_idx])
text_atts_neg.append(text.attention_mask[neg_idx])
text_ids_neg = torch.stack(text_ids_neg,dim=0)
text_atts_neg = torch.stack(text_atts_neg,dim=0)
text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
image_atts_all = torch.cat([image_atts,image_atts],dim=0)
output_neg = self.text_encoder(text_ids_all,
attention_mask = text_atts_all,
encoder_hidden_states = image_embeds_all,
encoder_attention_mask = image_atts_all,
return_dict = True,
)
vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
vl_output = self.itm_head(vl_embeddings)
itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
dim=0).to(image.device)
loss_itm = F.cross_entropy(vl_output, itm_labels)
return loss_ita, loss_itm
@torch.no_grad()
def copy_params(self):
for model_pair in self.model_pairs:
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
param_m.data.copy_(param.data) # initialize
param_m.requires_grad = False # not update by gradient
@torch.no_grad()
def _momentum_update(self):
for model_pair in self.model_pairs:
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
@torch.no_grad()
def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
# gather keys before updating queue
image_feats = concat_all_gather(image_feat)
text_feats = concat_all_gather(text_feat)
batch_size = image_feats.shape[0]
ptr = int(self.ptr_queue)
assert self.queue_size % batch_size == 0 # for simplicity
# replace the keys at ptr (dequeue and enqueue)
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
ptr = (ptr + batch_size) % self.queue_size # move pointer
self.ptr_queue[0] = ptr
def blip_retrieval(pretrained='',**kwargs):
model = BLIP_Retrieval(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
print("missing keys:")
print(msg.missing_keys)
return model
@torch.no_grad()
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
tensors_gather = [torch.ones_like(tensor)
for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output
class GatherLayer(torch.autograd.Function):
"""
Gather tensors from all workers with support for backward propagation:
This implementation does not cut the gradients as torch.distributed.all_gather does.
"""
@staticmethod
def forward(ctx, x):
output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output, x)
return tuple(output)
@staticmethod
def backward(ctx, *grads):
all_gradients = torch.stack(grads)
torch.distributed.all_reduce(all_gradients)
return all_gradients[torch.distributed.get_rank()]
def all_gather_with_grad(tensors):
"""
Performs all_gather operation on the provided tensors.
Graph remains connected for backward grad computation.
"""
# Queue the gathered tensors
world_size = torch.distributed.get_world_size()
# There is no need for reduction in the single-proc case
if world_size == 1:
return tensors
tensor_all = GatherLayer.apply(tensors)
return torch.cat(tensor_all, dim=0)

186
ldm/models/blip_vqa.py Normal file
View File

@ -0,0 +1,186 @@
from models.med import BertConfig, BertModel, BertLMHeadModel
from models.blip import create_vit, init_tokenizer, load_checkpoint
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np
class BLIP_VQA(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 480,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
self.tokenizer = init_tokenizer()
encoder_config = BertConfig.from_json_file(med_config)
encoder_config.encoder_width = vision_width
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
decoder_config = BertConfig.from_json_file(med_config)
self.text_decoder = BertLMHeadModel(config=decoder_config)
def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
return_tensors="pt").to(image.device)
question.input_ids[:,0] = self.tokenizer.enc_token_id
if train:
'''
n: number of answers for each question
weights: weight for each answer
'''
answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
answer.input_ids[:,0] = self.tokenizer.bos_token_id
answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
question_output = self.text_encoder(question.input_ids,
attention_mask = question.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True)
question_states = []
question_atts = []
for b, n in enumerate(n):
question_states += [question_output.last_hidden_state[b]]*n
question_atts += [question.attention_mask[b]]*n
question_states = torch.stack(question_states,0)
question_atts = torch.stack(question_atts,0)
answer_output = self.text_decoder(answer.input_ids,
attention_mask = answer.attention_mask,
encoder_hidden_states = question_states,
encoder_attention_mask = question_atts,
labels = answer_targets,
return_dict = True,
reduction = 'none',
)
loss = weights * answer_output.loss
loss = loss.sum()/image.size(0)
return loss
else:
question_output = self.text_encoder(question.input_ids,
attention_mask = question.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True)
if inference=='generate':
num_beams = 3
question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
outputs = self.text_decoder.generate(input_ids=bos_ids,
max_length=10,
min_length=1,
num_beams=num_beams,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
**model_kwargs)
answers = []
for output in outputs:
answer = self.tokenizer.decode(output, skip_special_tokens=True)
answers.append(answer)
return answers
elif inference=='rank':
max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
answer.input_ids, answer.attention_mask, k_test)
return max_ids
def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
num_ques = question_states.size(0)
start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
start_output = self.text_decoder(start_ids,
encoder_hidden_states = question_states,
encoder_attention_mask = question_atts,
return_dict = True,
reduction = 'none')
logits = start_output.logits[:,0,:] # first token's logit
# topk_probs: top-k probability
# topk_ids: [num_question, k]
answer_first_token = answer_ids[:,1]
prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
# answer input: [num_question*k, answer_len]
input_ids = []
input_atts = []
for b, topk_id in enumerate(topk_ids):
input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
input_ids = torch.cat(input_ids,dim=0)
input_atts = torch.cat(input_atts,dim=0)
targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
# repeat encoder's output for top-k answers
question_states = tile(question_states, 0, k)
question_atts = tile(question_atts, 0, k)
output = self.text_decoder(input_ids,
attention_mask = input_atts,
encoder_hidden_states = question_states,
encoder_attention_mask = question_atts,
labels = targets_ids,
return_dict = True,
reduction = 'none')
log_probs_sum = -output.loss
log_probs_sum = log_probs_sum.view(num_ques,k)
max_topk_ids = log_probs_sum.argmax(dim=1)
max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
return max_ids
def blip_vqa(pretrained='',**kwargs):
model = BLIP_VQA(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
# assert(len(msg.missing_keys)==0)
return model
def tile(x, dim, n_tile):
init_dim = x.size(dim)
repeat_idx = [1] * x.dim()
repeat_idx[dim] = n_tile
x = x.repeat(*(repeat_idx))
order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
return torch.index_select(x, dim, order_index.to(x.device))

955
ldm/models/med.py Normal file
View File

@ -0,0 +1,955 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
* Based on huggingface code base
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
'''
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.file_utils import (
ModelOutput,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig
logger = logging.get_logger(__name__)
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
def forward(
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config, is_cross_attention):
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
if is_cross_attention:
self.key = nn.Linear(config.encoder_width, self.all_head_size)
self.value = nn.Linear(config.encoder_width, self.all_head_size)
else:
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.save_attention = False
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
if is_cross_attention and self.save_attention:
self.save_attention_map(attention_probs)
attention_probs.register_hook(self.save_attn_gradients)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs_dropped = attention_probs_dropped * head_mask
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
outputs = outputs + (past_key_value,)
return outputs
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config, is_cross_attention=False):
super().__init__()
self.self = BertSelfAttention(config, is_cross_attention)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config, layer_num):
super().__init__()
self.config = config
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)
self.layer_num = layer_num
if self.config.add_cross_attention:
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
mode=None,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
if mode=='multimodal':
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
mode='multimodal',
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warn(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
mode=mode,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
mode=mode,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BertConfig
base_model_prefix = "bert"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
class BertModel(BertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
input to the forward pass.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (:obj:`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (:obj:`Tuple[int]`):
The shape of the input to the model.
device: (:obj:`torch.device`):
The device of the input to the model.
Returns:
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
causal_mask = torch.cat(
[
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
causal_mask,
],
axis=-1,
)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
is_decoder=False,
mode='multimodal',
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
device = input_ids.device
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = inputs_embeds.device
elif encoder_embeds is not None:
input_shape = encoder_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = encoder_embeds.device
else:
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
device, is_decoder)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if type(encoder_hidden_states) == list:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
else:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if type(encoder_attention_mask) == list:
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
elif encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if encoder_embeds is None:
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
else:
embedding_output = encoder_embeds
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
mode=mode,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
return_logits=False,
is_decoder=True,
reduction='mean',
mode='multimodal',
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
Returns:
Example::
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
>>> config = BertConfig.from_pretrained("bert-base-cased")
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_decoder=is_decoder,
mode=mode,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
if return_logits:
return prediction_scores[:, :-1, :].contiguous()
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if reduction=='none':
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
"is_decoder": True,
}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past

843
ldm/models/nlvr_encoder.py Normal file
View File

@ -0,0 +1,843 @@
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.file_utils import (
ModelOutput,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig
logger = logging.get_logger(__name__)
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
def forward(
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config, is_cross_attention):
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
if is_cross_attention:
self.key = nn.Linear(config.encoder_width, self.all_head_size)
self.value = nn.Linear(config.encoder_width, self.all_head_size)
else:
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.save_attention = False
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
if is_cross_attention and self.save_attention:
self.save_attention_map(attention_probs)
attention_probs.register_hook(self.save_attn_gradients)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs_dropped = attention_probs_dropped * head_mask
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
outputs = outputs + (past_key_value,)
return outputs
class BertSelfOutput(nn.Module):
def __init__(self, config, twin=False, merge=False):
super().__init__()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if twin:
self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)
self.dense1 = nn.Linear(config.hidden_size, config.hidden_size)
else:
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if merge:
self.act = ACT2FN[config.hidden_act]
self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.merge = True
else:
self.merge = False
def forward(self, hidden_states, input_tensor):
if type(hidden_states) == list:
hidden_states0 = self.dense0(hidden_states[0])
hidden_states1 = self.dense1(hidden_states[1])
if self.merge:
#hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1)))
hidden_states = self.merge_layer(torch.cat([hidden_states0,hidden_states1],dim=-1))
else:
hidden_states = (hidden_states0+hidden_states1)/2
else:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_num=-1):
super().__init__()
if is_cross_attention:
self.self0 = BertSelfAttention(config, is_cross_attention)
self.self1 = BertSelfAttention(config, is_cross_attention)
else:
self.self = BertSelfAttention(config, is_cross_attention)
self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num>=6))
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
if type(encoder_hidden_states)==list:
self_outputs0 = self.self0(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states[0],
encoder_attention_mask[0],
past_key_value,
output_attentions,
)
self_outputs1 = self.self1(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states[1],
encoder_attention_mask[1],
past_key_value,
output_attentions,
)
attention_output = self.output([self_outputs0[0],self_outputs1[0]], hidden_states)
outputs = (attention_output,) + self_outputs0[1:] # add attentions if we output them
else:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config, layer_num):
super().__init__()
self.config = config
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)
self.layer_num = layer_num
if self.config.add_cross_attention:
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
mode=None,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
if mode=='multimodal':
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
mode='multimodal',
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warn(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
mode=mode,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
mode=mode,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BertConfig
base_model_prefix = "bert"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
class BertModel(BertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
input to the forward pass.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (:obj:`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (:obj:`Tuple[int]`):
The shape of the input to the model.
device: (:obj:`torch.device`):
The device of the input to the model.
Returns:
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
causal_mask = torch.cat(
[
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
causal_mask,
],
axis=-1,
)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
is_decoder=False,
mode='multimodal',
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
device = input_ids.device
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = inputs_embeds.device
elif encoder_embeds is not None:
input_shape = encoder_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = encoder_embeds.device
else:
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
device, is_decoder)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if type(encoder_hidden_states) == list:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
else:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if type(encoder_attention_mask) == list:
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
elif encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if encoder_embeds is None:
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
else:
embedding_output = encoder_embeds
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
mode=mode,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)

305
ldm/models/vit.py Normal file
View File

@ -0,0 +1,305 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
* Based on timm code base
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from timm.models.helpers import named_apply, adapt_input_conv
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.attn_gradients = None
self.attention_map = None
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def forward(self, x, register_hook=False):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
if register_hook:
self.save_attention_map(attn)
attn.register_hook(self.save_attn_gradients)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if use_grad_checkpointing:
self.attn = checkpoint_wrapper(self.attn)
self.mlp = checkpoint_wrapper(self.mlp)
def forward(self, x, register_hook=False):
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
https://arxiv.org/abs/2010.11929
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
use_grad_checkpointing=False, ckpt_layer=0):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
norm_layer: (nn.Module): normalization layer
"""
super().__init__()
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward(self, x, register_blk=-1):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed[:,:x.size(1),:]
x = self.pos_drop(x)
for i,blk in enumerate(self.blocks):
x = blk(x, register_blk==i)
x = self.norm(x)
return x
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path, prefix=''):
_load_weights(self, checkpoint_path, prefix)
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
"""
import numpy as np
def _n2p(w, t=True):
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
if w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)
w = np.load(checkpoint_path)
if not prefix and 'opt/target/embedding/kernel' in w:
prefix = 'opt/target/'
if hasattr(model.patch_embed, 'backbone'):
# hybrid
backbone = model.patch_embed.backbone
stem_only = not hasattr(backbone, 'stem')
stem = backbone if stem_only else backbone.stem
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
if not stem_only:
for i, stage in enumerate(backbone.stages):
for j, block in enumerate(stage.blocks):
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
for r in range(3):
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
if block.downsample is not None:
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
else:
embed_conv_w = adapt_input_conv(
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
model.patch_embed.proj.weight.copy_(embed_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
# interpolate position embedding
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = visual_encoder.patch_embed.num_patches
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
if orig_size!=new_size:
# class_token and dist_token are kept unchanged
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
return new_pos_embed
else:
return pos_embed_checkpoint

View File

@ -40,7 +40,8 @@ def layout():
,Waifu Diffusion v1.2 , ./models/custom , https://huggingface.co/hakurei/waifu-diffusion
,Waifu Diffusion v1.2 Pruned , ./models/custom , https://huggingface.co/crumb/pruned-waifu-diffusion
,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2
,Stable Diffusion Concept Library , ./models/custom/sd-concepts-library , https://github.com/sd-webui/sd-concepts-library
,Stable Diffusion Concept Library , ./models/custom/sd-concepts-library , https://github.com/sd-webui/sd-concepts-library
,Blip Model , ./models/custom/blip , https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth
"""
colms = st.columns((1, 3, 5, 5))
columns = ["",'Model Name','Save Location','Download Link']

View File

@ -0,0 +1,175 @@
#@title Setup
#!pip3 install ftfy regex tqdm transformers==4.15.0 timm==0.4.12 fairscale==0.4.4
#!pip3 install git+https://github.com/openai/CLIP.git
#!git clone https://github.com/pharmapsychotic/clip-interrogator.git
#!git clone https://github.com/salesforce/BLIP
#%cd /content/BLIP
import clip
import gc
#import numpy as np
import os
import pandas as pd
import requests
import torch
#import torchvision.transforms as T
#import torchvision.transforms.functional as TF
from IPython.display import display
from PIL import Image
#from torch import nn
#from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from ldm.models.blip import blip_decoder
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base')
blip_model.eval()
blip_model = blip_model.to(device)
def generate_caption(pil_image):
gpu_image = transforms.Compose([
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(image).unsqueeze(0).to(device)
with torch.no_grad():
caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
return caption[0]
def load_list(filename):
with open(filename, 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
return items
def rank(model, image_features, text_array, top_count=1):
top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array]).cuda()
with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array))).to(device)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
def interrogate(image, models):
caption = generate_caption(image)
if len(models) == 0:
print(f"\n\n{caption}")
return
table = []
bests = [[('',0)]]*5
for model_name in models:
print(f"Interrogating with {model_name}...")
model, preprocess = clip.load(model_name)
model.cuda().eval()
images = preprocess(image).unsqueeze(0).cuda()
with torch.no_grad():
image_features = model.encode_image(images).float()
image_features /= image_features.norm(dim=-1, keepdim=True)
ranks = [
rank(model, image_features, mediums),
rank(model, image_features, ["by "+artist for artist in artists]),
rank(model, image_features, trending_list),
rank(model, image_features, movements),
rank(model, image_features, flavors, top_count=3)
]
for i in range(len(ranks)):
confidence_sum = 0
for ci in range(len(ranks[i])):
confidence_sum += ranks[i][ci][1]
if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
bests[i] = ranks[i]
row = [model_name]
for r in ranks:
row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
table.append(row)
del model
gc.collect()
display(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
flaves = ', '.join([f"{x[0]}" for x in bests[4]])
medium = bests[0][0][0]
if caption.startswith(medium):
print(f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}")
else:
print(f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}")
data_path = "../clip-interrogator/data/"
artists = load_list(os.path.join(data_path, 'artists.txt'))
flavors = load_list(os.path.join(data_path, 'flavors.txt'))
mediums = load_list(os.path.join(data_path, 'mediums.txt'))
movements = load_list(os.path.join(data_path, 'movements.txt'))
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
trending_list = [site for site in sites]
trending_list.extend(["trending on "+site for site in sites])
trending_list.extend(["featured on "+site for site in sites])
trending_list.extend([site+" contest winner" for site in sites])
#@title Interrogate!
#@markdown
#@markdown #####**Image:**
image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg" #@param {type:"string"}
#@markdown
#@markdown #####**CLIP models:**
#@markdown For [StableDiffusion](https://stability.ai/blog/stable-diffusion-announcement) you can just use ViTL14<br>
#@markdown For [DiscoDiffusion](https://colab.research.google.com/github/alembics/disco-diffusion/blob/main/Disco_Diffusion.ipynb) and
#@markdown [JAX](https://colab.research.google.com/github/huemin-art/jax-guided-diffusion/blob/v2.7/Huemin_Jax_Diffusion_2_7.ipynb) enable all the same models here as you intend to use when generating your images
ViTB32 = True #@param{type:"boolean"}
ViTB16 = True #@param{type:"boolean"}
ViTL14 = False #@param{type:"boolean"}
ViTL14_336px = False #@param{type:"boolean"}
RN101 = False #@param{type:"boolean"}
RN50 = True #@param{type:"boolean"}
RN50x4 = False #@param{type:"boolean"}
RN50x16 = False #@param{type:"boolean"}
RN50x64 = False #@param{type:"boolean"}
models = []
if ViTB32: models.append('ViT-B/32')
if ViTB16: models.append('ViT-B/16')
if ViTL14: models.append('ViT-L/14')
if ViTL14_336px: models.append('ViT-L/14@336px')
if RN101: models.append('RN101')
if RN50: models.append('RN50')
if RN50x4: models.append('RN50x4')
if RN50x16: models.append('RN50x16')
if RN50x64: models.append('RN50x64')
if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
else:
image = Image.open(image_path_or_url).convert('RGB')
thumb = image.copy()
thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
display(thumb)
interrogate(image, models=models)

View File

@ -39,44 +39,230 @@ from sd_utils import *
import streamlit_nested_layout
#streamlit components section
from streamlit_server_state import server_state, server_state_lock
#other imports
import hydralit_components as hc
import clip
import gc
import os
import pandas as pd
import requests
import torch
from IPython.display import display
from PIL import Image
#from torch import nn
#from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from ldm.models.blip import blip_decoder
# end of imports
#---------------------------------------------------------------------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
blip_image_eval_size = 384
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
def generate_caption(pil_image):
blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
blip_model.eval()
blip_model = blip_model.to(device)
gpu_image = transforms.Compose([
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(image).unsqueeze(0).to(device)
with torch.no_grad():
caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
return caption[0]
def load_list(filename):
with open(filename, 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
return items
def rank(model, image_features, text_array, top_count=1):
top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array]).cuda()
with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array))).to(device)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
def interrogate(image, models):
caption = generate_caption(image)
if len(models) == 0:
print(f"\n\n{caption}")
return
table = []
bests = [[('',0)]]*5
for model_name in models:
print(f"Interrogating with {model_name}...")
model, preprocess = clip.load(model_name)
model.cuda().eval()
images = preprocess(image).unsqueeze(0).cuda()
with torch.no_grad():
image_features = model.encode_image(images).float()
image_features /= image_features.norm(dim=-1, keepdim=True)
ranks = [
rank(model, image_features, mediums),
rank(model, image_features, ["by "+artist for artist in artists]),
rank(model, image_features, trending_list),
rank(model, image_features, movements),
rank(model, image_features, flavors, top_count=3)
]
for i in range(len(ranks)):
confidence_sum = 0
for ci in range(len(ranks[i])):
confidence_sum += ranks[i][ci][1]
if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
bests[i] = ranks[i]
row = [model_name]
for r in ranks:
row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
table.append(row)
del model
gc.collect()
display(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
flaves = ', '.join([f"{x[0]}" for x in bests[4]])
medium = bests[0][0][0]
if caption.startswith(medium):
print(f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}")
else:
print(f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}")
#
def img2txt():
data_path = "data/"
artists = load_list(os.path.join(data_path, 'artists.txt'))
flavors = load_list(os.path.join(data_path, 'flavors.txt'))
mediums = load_list(os.path.join(data_path, 'mediums.txt'))
movements = load_list(os.path.join(data_path, 'movements.txt'))
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
trending_list = [site for site in sites]
trending_list.extend(["trending on "+site for site in sites])
trending_list.extend(["featured on "+site for site in sites])
trending_list.extend([site+" contest winner" for site in sites])
image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg" #@param {type:"string"}
models = []
if st.session_state["ViTB32"]:
models.append('ViT-B/32')
if st.session_state['ViTB16']:
models.append('ViT-B/16')
if st.session_state["ViTL14"]:
models.append('ViT-L/14')
if st.session_state["ViTL14_336px"]:
models.append('ViT-L/14@336px')
if st.session_state["RN101"]:
models.append('RN101')
if st.session_state["RN50"]:
models.append('RN50')
if st.session_state["RN50x4"]:
models.append('RN50x4')
if st.session_state["RN50x16"]:
models.append('RN50x16')
if st.session_state["RN50x64"]:
models.append('RN50x64')
if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
else:
image = Image.open(image_path_or_url).convert('RGB')
thumb = image.copy()
thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
#display(thumb)
interrogate(image, models=models)
#
def layout():
#set_page_title("Image-to-Text - Stable Diffusion WebUI")
st.info("Under Construction. :construction_worker:")
#st.info("Under Construction. :construction_worker:")
#theme_neutral = {'bgcolor': '#f9f9f9','title_color': 'black','content_color': 'black','icon_color': 'orange', 'icon': 'fa fa-question-circle'}
#hc.info_card(title='Some heading GOOD', content='All good!', sentiment='good',bar_value=77)
#hc.nav_bar([{'icon': "far fa-copy", 'label':"Left End"}, {'id':'Copy','icon':"🐙",'label':"Copy"},
#{'icon': "fa-solid fa-radar",'label':"Dropdown1",
#' submenu':[{'id':' subid11','icon': "fa fa-paperclip", 'label':"Sub-item 1"},
#{'id':'subid12','icon': "💀", 'label':"Sub-item 2"},
#{'id':'subid13','icon': "fa fa-database", 'label':"Sub-item 3"}]}],
#override_theme=theme_neutral, hide_streamlit_markers=False)
#with st.form("img2txt-inputs"):
#st.session_state["generation_mode"] = "txt2img"
with st.form("img2txt-inputs"):
st.session_state["generation_mode"] = "img2txt"
#input_col1, generate_col1 = st.columns([10,1])
input_col1, generate_col1 = st.columns([10,1])
#with input_col1:
##prompt = st.text_area("Input Text","")
with input_col1:
#prompt = st.text_area("Input Text","")
#prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
uploaded_image = st.file_uploader('Input Image')
## Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
#generate_col1.write("")
#generate_col1.write("")
#generate_button = generate_col1.form_submit_button("Generate")
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
generate_col1.write("")
generate_col1.write("")
generate_button = generate_col1.form_submit_button("Generate")
st.session_state["text_result"] = st.empty()
## creating the page layout using columns
#col1, col2, col3 = st.columns([1,2,1], gap="large")
# creating the page layout using columns
col1, col2, col3 = st.columns([1,2,1], gap="large")
with col1:
"""
CLIP models:
For StableDiffusion you can just use ViTL14
For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images
ViTB32:
ViTB16:
ViTL14:
ViTL14_336px:
RN101:
RN50:
RN50x4:
RN50x16:
RN50x64:
"""
st.title("CLIP models")
st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.")
st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.")
st.session_state["ViTL14"] = st.checkbox("ViTL14", value=True, help="ViTL14 model.")
st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.")
st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.")
st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
with col2:
st.session_state["input_image_preview"] = st.empty()
with col3:
st.session_state["text_message"] = st.empty()