Mirror of https://github.com/sd-webui/stable-diffusion-webui.git, synced 2024-12-15 07:12:58 +03:00

Update img2txt.py

parent 084fa9732f
commit aa858e0114
@@ -42,6 +42,7 @@ import streamlit_nested_layout

 #other imports
 import clip
+import open_clip
 import gc
 import os
 import pandas as pd
@@ -60,14 +61,14 @@ from ldm.models.blip import blip_decoder

 #---------------------------------------------------------------------------------------------------------------

 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-blip_image_eval_size = 256
+blip_image_eval_size = 512
 blip_model = None
 #blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'

 def load_blip_model():
     blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
     blip_model.eval()
-    blip_model = blip_model.to(device).half()
+    blip_model = blip_model.to('cpu')

     return blip_model
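Note on this hunk: the new load_blip_model() keeps BLIP on the CPU and drops the .half() cast, which matters because many CPU kernels have no float16 implementation. A minimal device-aware sketch of the same loader, under the assumption that blip_decoder keeps the signature used above:

import torch
from ldm.models.blip import blip_decoder

def load_blip_model(device='cpu'):
    # same checkpoint and config paths as in the hunk above
    model = blip_decoder(pretrained="models/blip/model__base_caption.pth",
                         image_size=512, vit='base',
                         med_config="configs/blip/med_config.json")
    model.eval()
    # half precision only pays off on CUDA; CPU inference stays in float32
    return model.half().to(device) if device.startswith('cuda') else model.to(device)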
@@ -77,7 +78,7 @@ def generate_caption(pil_image):
         transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
         transforms.ToTensor(),
         transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-    ])(pil_image).unsqueeze(0).to(device).half()
+    ])(pil_image).unsqueeze(0).to('cpu')

     with torch.no_grad():
         caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
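With this change the input tensor is built on the CPU, so the name gpu_image is now misleading; the call still works only because load_blip_model() moves the model to the CPU as well. A small sketch that keeps input and weights in sync instead of hard-coding the device, where blip_model is the loaded module and image_tensor is a hypothetical name for the output of the Compose pipeline above:

# derive device and dtype from the model so input and weights never diverge
param = next(blip_model.parameters())
gpu_image = image_tensor.unsqueeze(0).to(device=param.device, dtype=param.dtype)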
@@ -107,6 +108,15 @@ def clear_cuda():
     torch.cuda.empty_cache()
     gc.collect()

+def batch_rank(model, image_features, text_array, batch_size=42):
+    batch_count = len(text_array) // batch_size
+    batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
+    batches.append(text_array[batch_count*batch_size:])
+    ranks = []
+    for batch in batches:
+        ranks += rank(model, image_features, batch)
+    return ranks
+
 def interrogate(image, models):
     global blip_model
     blip_model = load_blip_model()
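One quirk in the new helper: when len(text_array) is an exact multiple of batch_size, the final batches.append() adds an empty slice, so rank() must tolerate an empty list. A sketch of an equivalent helper without the empty tail, assuming rank() keeps the signature used in this file:

def batch_rank(model, image_features, text_array, batch_size=42):
    ranks = []
    # slice in fixed-size steps; the last slice is simply shorter,
    # so no empty batch is ever produced
    for i in range(0, len(text_array), batch_size):
        ranks += rank(model, image_features, text_array[i:i + batch_size])
    return ranks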
@@ -126,47 +136,51 @@ def interrogate(image, models):
     for model_name in models:
         print(f"Interrogating {model_name}")
         st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
-        model, preprocess = clip.load(model_name)
-        model.cuda().eval()
+        if model_name == 'ViT-H-14':
+            model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k').to(device).half()
+        elif model_name == 'ViT-g-14':
+            model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k').to(device).half()
+        else:
+            model, preprocess = clip.load(model_name, device=device)
+            model.to(device).half()

-        images = preprocess(image).unsqueeze(0).cuda()
-        with torch.no_grad():
-            image_features = model.encode_image(images).float()
-        image_features /= image_features.norm(dim=-1, keepdim=True)
-        clear_cuda()
+        model.eval()
+        with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
+            image_features = model.encode_image(preprocess(image).unsqueeze(0).to(device)).float()
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+        clear_cuda()

-        ranks = []
-        ranks.append(rank(model, image_features, server_state["mediums"]))
-        clear_cuda()
-        artists = []
-        for batch in range(int(len(server_state["artists"])/1000)):
-            artist_rank = rank(model, image_features, server_state["artists"][batch*1000:(batch+1)*1000])
-            artists.extend(artist_rank)
-        ranks.append(artists)
-        ranks.append(rank(model, image_features, server_state["trending_list"]))
-        clear_cuda()
-        ranks.append(rank(model, image_features, server_state["movements"]))
-        clear_cuda()
-        ranks.append(rank(model, image_features, server_state["flavors"], top_count=3))
-        clear_cuda()
+        ranks = []
+        ranks.append(batch_rank(model, image_features, server_state["mediums"]))
+        ranks.append(batch_rank(model, image_features, ["by "+artist for artist in server_state["artists"]]))
+        ranks.append(batch_rank(model, image_features, server_state["trending_list"]))
+        ranks.append(batch_rank(model, image_features, server_state["movements"]))
+        ranks.append(batch_rank(model, image_features, server_state["flavors"]))
+        # ranks.append(batch_rank(model, image_features, server_state["genres"]))
+        # ranks.append(batch_rank(model, image_features, server_state["styles"]))
+        # ranks.append(batch_rank(model, image_features, server_state["techniques"]))
+        # ranks.append(batch_rank(model, image_features, server_state["subjects"]))
+        # ranks.append(batch_rank(model, image_features, server_state["colors"]))
+        # ranks.append(batch_rank(model, image_features, server_state["moods"]))
+        # ranks.append(batch_rank(model, image_features, server_state["themes"]))
+        # ranks.append(batch_rank(model, image_features, server_state["keywords"]))
+        # ranks.append(batch_rank(model, image_features, server_state["artists"]))

         for i in range(len(ranks)):
             confidence_sum = 0
             for ci in range(len(ranks[i])):
                 confidence_sum += ranks[i][ci][1]
             if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
                 bests[i] = ranks[i]

         row = [model_name]
         for r in ranks:
             row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))

         table.append(row)

         del model
         gc.collect()

     st.session_state["prediction_table"].dataframe(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
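As written, the two open_clip branches call .to(device).half() on the return value of create_model_and_transforms(), which is a (model, train_transform, eval_transform) tuple rather than a module, so the unpacking would fail at runtime. A sketch of the presumably intended pattern, unpacking first and moving only the model:

model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-H-14', pretrained='laion2b_s32b_b79k')
# only the model is a torch module; the preprocess transform stays as-is
model = model.to(device).half()
model.eval()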
@@ -204,6 +218,10 @@ def img2txt():
         models.append('ViT-B/16')
     if st.session_state["ViTL14"]:
         models.append('ViT-L/14')
+    if st.session_state["ViT-H-14"]:
+        models.append('ViT-H-14')
+    if st.session_state["ViT-g-14"]:
+        models.append('ViT-g-14')
     if st.session_state["ViTL14_336px"]:
         models.append('ViT-L/14@336px')
     if st.session_state["RN101"]:
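Worth noting for this hunk: clip.load() names use slashes ('ViT-L/14', 'ViT-L/14@336px'), while the two new open_clip checkpoints use hyphens ('ViT-H-14', 'ViT-g-14'), and interrogate() branches on exactly those hyphenated strings. A hypothetical helper making that contract explicit:

# hyphenated names are open_clip checkpoints; everything else goes
# through clip.load()
OPEN_CLIP_MODELS = {'ViT-H-14', 'ViT-g-14'}

def uses_open_clip(model_name):
    return model_name in OPEN_CLIP_MODELS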
@@ -226,7 +244,7 @@ def img2txt():
     #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
     #display(thumb)

-    interrogate(st.session_state["uploaded_image"].pil_image, models=models)
+    interrogate(st.session_state["uploaded_image"].pil_image.convert('RGB'), models)

 #
 def layout():
@@ -259,7 +277,10 @@ def layout():
         st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
         st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
         st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
         st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
+        st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
+        st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
+

         with col2:
@@ -270,7 +291,7 @@ def layout():
         st.session_state["input_image_preview"] = st.empty()

         if st.session_state["uploaded_image"]:
-            st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"])#.convert('RGBA')
+            st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"])
            #new_img = image.resize((width, height))
             st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, clamp=True)
         else:
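Taken together with the interrogate() hunk above, the conversion now happens once, at the call site, via .convert('RGB'), and the commented-out .convert('RGBA') here is dropped. That is the safer combination: an RGBA or palette PNG would make ToTensor() emit a tensor whose channel count does not match the three-value Normalize() means. A minimal sketch of the idea, with 'uploaded' standing in for the Streamlit file object:

from PIL import Image

pil_image = Image.open(uploaded).convert('RGB')  # guarantee 3 channels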