Img2txt is working now (#1360)

Img2txt is working and ready to use, but for now it needs more than 8 GB of VRAM to run. I will be trying to improve that as the next step.
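Most of that VRAM goes to keeping BLIP and the selected CLIP models resident on the GPU at the same time. One possible direction for reducing it (a sketch only, not what this commit implements; the helper name is made up for illustration) is to hold the weights in system RAM and move each model to the GPU only while it is in use:

import gc
import torch

def run_on_gpu_then_release(model, fn, *args):
    """Move a model to the GPU, run `fn` with it, then free the VRAM.
    `fn` is any callable that uses the model, e.g. a CLIP encode or a
    BLIP caption call."""
    model = model.to('cuda')
    try:
        with torch.no_grad():
            return fn(model, *args)
    finally:
        model.to('cpu')           # keep the weights in system RAM for reuse
        gc.collect()              # drop dangling Python references
        torch.cuda.empty_cache()  # hand freed blocks back to the driver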
commit 4cefd5463c
Author: Alejandro Gil
Date: 2022-09-29 08:58:56 -07:00 (committed by GitHub)

@@ -45,9 +45,8 @@ import clip
import gc
import os
import pandas as pd
import requests
#import requests
import torch
from IPython.display import display
from PIL import Image
#from torch import nn
#from torch.nn import functional as F
@@ -60,8 +59,23 @@ from ldm.models.blip import blip_decoder
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
blip_image_eval_size = 384
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
blip_image_eval_size = 256
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
def load_blip_model():
    blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
    blip_model.eval()
    return blip_model
def load_clip_model(clip_model_name):
    import clip
    model, preprocess = clip.load(clip_model_name)
    model.eval()
    model = model.to(device)
    return model, preprocess
def generate_caption(pil_image):
    blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
@@ -99,7 +113,10 @@ def rank(model, image_features, text_array, top_count=1):
    return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
def interrogate(image, models):
    print ("Generating Caption")
    st.session_state["log_message"].code("Generating Caption", language='')
    caption = generate_caption(image)
    if len(models) == 0:
        print(f"\n\n{caption}")
        return
@@ -107,8 +124,8 @@ def interrogate(image, models):
    table = []
    bests = [[('',0)]]*5
    for model_name in models:
        print(f"Interrogating with {model_name}...")
        model, preprocess = clip.load(model_name)
        st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
        model, preprocess = load_clip_model(model_name)
        model.cuda().eval()
        images = preprocess(image).unsqueeze(0).cuda()
@@ -140,15 +157,16 @@ def interrogate(image, models):
        del model
        gc.collect()
    display(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
    st.session_state["prediction_table"].dataframe(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
    flaves = ', '.join([f"{x[0]}" for x in bests[4]])
    medium = bests[0][0][0]
    if caption.startswith(medium):
        print(f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}")
        st.session_state["text_result"].code(f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")
    else:
        print(f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}")
        st.session_state["text_result"].code(f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")
    st.session_state["log_message"].code("Finished Interrogating.", language="")
#
def img2txt():
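The recurring change in the interrogate() hunks above swaps console print()/display() output for Streamlit placeholders kept in st.session_state. A minimal sketch of that pattern (the placeholder names mirror the ones used above; the surrounding app setup is assumed):

import pandas as pd
import streamlit as st

# Create empty slots once; later writes replace their contents in place
# instead of appending new elements to the page.
st.session_state["log_message"] = st.empty()
st.session_state["prediction_table"] = st.empty()

st.session_state["log_message"].code("Generating Caption", language='')
# ... interrogation work happens here ...
st.session_state["prediction_table"].dataframe(
    pd.DataFrame([["ViT-B/32", "painting"]], columns=["Model", "Medium"]))
st.session_state["log_message"].code("Finished Interrogating.", language='')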
@@ -160,44 +178,44 @@ def img2txt():
    server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
    server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
    trending_list = [site for site in server_state["sites"]]
    trending_list.extend(["trending on "+site for site in server_state["sites"]])
    trending_list.extend(["featured on "+site for site in server_state["sites"]])
    trending_list.extend([site+" contest winner" for site in server_state["sites"]])
    server_state["trending_list"] = [site for site in server_state["sites"]]
    server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
    server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]])
    server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]])
    image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg"
    #image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg"
    st.session_state["models"] = []
    models = []
    if st.session_state["ViTB32"]:
        st.session_state["models"].append('ViT-B/32')
        models.append('ViT-B/32')
    if st.session_state['ViTB16']:
        st.session_state["models"].append('ViT-B/16')
        models.append('ViT-B/16')
    if st.session_state["ViTL14"]:
        st.session_state["models"].append('ViT-L/14')
        models.append('ViT-L/14')
    if st.session_state["ViTL14_336px"]:
        st.session_state["models"].append('ViT-L/14@336px')
        models.append('ViT-L/14@336px')
    if st.session_state["RN101"]:
        st.session_state["models"].append('RN101')
        models.append('RN101')
    if st.session_state["RN50"]:
        st.session_state["models"].append('RN50')
        models.append('RN50')
    if st.session_state["RN50x4"]:
        st.session_state["models"].append('RN50x4')
        models.append('RN50x4')
    if st.session_state["RN50x16"]:
        st.session_state["models"].append('RN50x16')
        models.append('RN50x16')
    if st.session_state["RN50x64"]:
        st.session_state["models"].append('RN50x64')
        models.append('RN50x64')
    if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
        image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
    else:
        image = Image.open(image_path_or_url).convert('RGB')
    #if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
        #image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
    #else:
        #image = Image.open(image_path_or_url).convert('RGB')
    thumb = image.copy()
    thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
    #thumb = st.session_state["uploaded_image"].image.copy()
    #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
    #display(thumb)
    interrogate(image, models=st.session_state["models"])
    interrogate(st.session_state["uploaded_image"].pil_image, models=models)
#
def layout():
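Note the switch above from a local trending_list to server_state["trending_list"]: values stored in server_state (from the streamlit-server-state package this app appears to use) are shared across sessions, so the word lists can be built once per server rather than once per page run. A rough sketch of that caching pattern, assuming that package; get_trending_list is a hypothetical helper:

from streamlit_server_state import server_state, server_state_lock

def get_trending_list(sites):
    # Build the list once per server process; later sessions reuse it.
    with server_state_lock["trending_list"]:
        if "trending_list" not in server_state:
            trending = list(sites)
            trending.extend("trending on " + s for s in sites)
            trending.extend("featured on " + s for s in sites)
            trending.extend(s + " contest winner" for s in sites)
            server_state["trending_list"] = trending
    return server_state["trending_list"]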
@@ -241,9 +259,9 @@ def layout():
        st.session_state["input_image_preview"] = st.empty()
        if st.session_state["uploaded_image"]:
            st.session_state["uploaded_image"].image = Image.open(st.session_state["uploaded_image"]).convert('RGBA')
            st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"])#.convert('RGBA')
            #new_img = image.resize((width, height))
            st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].image, clamp=True)
            st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, clamp=True)
        else:
            #st.session_state["input_image_preview"].code('', language="")
            st.image("images/streamlit/img2txt_placeholder.png", clamp=True)