Update img2txt.py

hlky 2022-09-30 19:52:14 +01:00
parent 084fa9732f
commit aa858e0114


@@ -42,6 +42,7 @@ import streamlit_nested_layout
#other imports
import clip
import open_clip
import gc
import os
import pandas as pd
@@ -60,14 +61,14 @@ from ldm.models.blip import blip_decoder
#---------------------------------------------------------------------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
blip_image_eval_size = 256
blip_image_eval_size = 512
blip_model = None
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
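# Note: the freshly loaded BLIP model is now kept on the CPU (see the .to('cpu') below), presumably so VRAM stays free for the CLIP models used during interrogation.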
def load_blip_model():
    blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
    blip_model.eval()
    blip_model = blip_model.to(device).half()
    blip_model = blip_model.to('cpu')
    return blip_model
@@ -77,7 +78,7 @@ def generate_caption(pil_image):
        transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])(pil_image).unsqueeze(0).to(device).half()
    ])(pil_image).unsqueeze(0).to('cpu')
    with torch.no_grad():
        caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
@@ -107,6 +108,15 @@ def clear_cuda():
    torch.cuda.empty_cache()
    gc.collect()
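# batch_rank() runs rank() over text_array in chunks of batch_size and concatenates the results,
# presumably so a single rank() call never has to score the full artist/flavor lists at once.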
def batch_rank(model, image_features, text_array, batch_size=42):
    batch_count = len(text_array) // batch_size
    batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
    batches.append(text_array[batch_count*batch_size:])
    ranks = []
    for batch in batches:
        ranks += rank(model, image_features, batch)
    return ranks
def interrogate(image, models):
    global blip_model
    blip_model = load_blip_model()
@@ -126,47 +136,51 @@ def interrogate(image, models):
    for model_name in models:
        print(f"Interrogating {model_name}")
        st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
        model, preprocess = clip.load(model_name)
        model.cuda().eval()
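        # ViT-H-14 and ViT-g-14 are only available as open_clip checkpoints (LAION-2B weights);
        # every other model name still goes through the original clip.load() path.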
        if model_name == 'ViT-H-14':
            model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k').to(device).half()
        elif model_name == 'ViT-g-14':
            model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k').to(device).half()
        else:
            model, preprocess = clip.load(model_name, device=device)
            model.to(device).half()
        images = preprocess(image).unsqueeze(0).cuda()
        with torch.no_grad():
            image_features = model.encode_image(images).float()
            image_features /= image_features.norm(dim=-1, keepdim=True)
        clear_cuda()
        ranks = []
        ranks.append(rank(model, image_features, server_state["mediums"]))
        clear_cuda()
        artists = []
        for batch in range(int(len(server_state["artists"])/1000)):
            artist_rank = rank(model, image_features, server_state["artists"][batch*1000:(batch+1)*1000])
            artists.extend(artist_rank)
        model.eval()
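        # Encode the image once under CUDA fp16 autocast and L2-normalise the features so the
        # rank()/batch_rank() calls below can compare them against the candidate text embeddings.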
        with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
            image_features = model.encode_image(preprocess(image).unsqueeze(0).to(device)).float()
            image_features /= image_features.norm(dim=-1, keepdim=True)
        clear_cuda()
        ranks.append(artists)
        ranks.append(rank(model, image_features, server_state["trending_list"]))
        clear_cuda()
        ranks.append(rank(model, image_features, server_state["movements"]))
        clear_cuda()
        ranks.append(rank(model, image_features, server_state["flavors"], top_count=3))
        clear_cuda()
        ranks = []
        ranks.append(batch_rank(model, image_features, server_state["mediums"]))
        ranks.append(batch_rank(model, image_features, ["by "+artist for artist in server_state["artists"]]))
        ranks.append(batch_rank(model, image_features, server_state["trending_list"]))
        ranks.append(batch_rank(model, image_features, server_state["movements"]))
        ranks.append(batch_rank(model, image_features, server_state["flavors"]))
        # ranks.append(batch_rank(model, image_features, server_state["genres"]))
        # ranks.append(batch_rank(model, image_features, server_state["styles"]))
        # ranks.append(batch_rank(model, image_features, server_state["techniques"]))
        # ranks.append(batch_rank(model, image_features, server_state["subjects"]))
        # ranks.append(batch_rank(model, image_features, server_state["colors"]))
        # ranks.append(batch_rank(model, image_features, server_state["moods"]))
        # ranks.append(batch_rank(model, image_features, server_state["themes"]))
        # ranks.append(batch_rank(model, image_features, server_state["keywords"]))
        # ranks.append(batch_rank(model, image_features, server_state["artists"]))
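        # For each category, keep the ranking whose confidences sum highest across the models
        # interrogated so far (bests[i] holds the current winner for category i).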
        for i in range(len(ranks)):
            confidence_sum = 0
            for ci in range(len(ranks[i])):
                confidence_sum += ranks[i][ci][1]
            if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
                bests[i] = ranks[i]
        row = [model_name]
        for r in ranks:
            row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
        table.append(row)
        del model
        gc.collect()
        st.session_state["prediction_table"].dataframe(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
@@ -204,6 +218,10 @@ def img2txt():
        models.append('ViT-B/16')
    if st.session_state["ViTL14"]:
        models.append('ViT-L/14')
    if st.session_state["ViT-H-14"]:
        models.append('ViT-H-14')
    if st.session_state["ViT-g-14"]:
        models.append('ViT-g-14')
    if st.session_state["ViTL14_336px"]:
        models.append('ViT-L/14@336px')
    if st.session_state["RN101"]:
@@ -226,7 +244,7 @@ def img2txt():
    #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
    #display(thumb)
    interrogate(st.session_state["uploaded_image"].pil_image, models=models)
    interrogate(st.session_state["uploaded_image"].pil_image.convert('RGB'), models)
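    # The added .convert('RGB') presumably guards against RGBA or paletted uploads tripping up the
    # CLIP/BLIP preprocessing, which expects three-channel input.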
#
def layout():
@@ -259,7 +277,10 @@ def layout():
        st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
        st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
        st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
        st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
        st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
        st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
    with col2:
@@ -270,7 +291,7 @@ def layout():
        st.session_state["input_image_preview"] = st.empty()
        if st.session_state["uploaded_image"]:
            st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"])#.convert('RGBA')
            st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"])
            #new_img = image.resize((width, height))
            st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, clamp=True)
        else: