Img2txt is working now (#1360)
Img2txt is working and ready to use, but for now it needs more than 8 GB of VRAM. Improving that will be the next step.
This commit is contained in: commit 4cefd5463c
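The extra VRAM is presumably a consequence of the BLIP captioner and the selected CLIP models being resident on the GPU at the same time. Below is a minimal sketch of one way the footprint might be trimmed; it assumes PyTorch with CUDA and the load_blip_model() helper introduced in this commit, and the release_vram() wrapper itself is hypothetical, not part of the change.

import gc

import torch


def release_vram(model: torch.nn.Module) -> None:
    # Hypothetical helper: drop a model's GPU copy once it is no longer needed,
    # so the BLIP captioner and the CLIP interrogator never have to share VRAM.
    model.to("cpu")            # keep the weights in system RAM in case they are reused
    gc.collect()
    torch.cuda.empty_cache()   # hand PyTorch's cached allocations back to the driver

# Assumed usage: free the BLIP model right after the caption is produced, before
# any CLIP model is moved onto the GPU, e.g.
#   blip_model = load_blip_model()
#   ... generate the caption ...
#   release_vram(blip_model)

The CLIP loop in this commit already does del model and gc.collect() between models; adding torch.cuda.empty_cache() would return the freed blocks to the driver instead of leaving them in PyTorch's caching allocator.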
@@ -45,9 +45,8 @@ import clip
import gc
import os
import pandas as pd
import requests
#import requests
import torch
from IPython.display import display
from PIL import Image
#from torch import nn
#from torch.nn import functional as F
@@ -60,8 +59,23 @@ from ldm.models.blip import blip_decoder

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

blip_image_eval_size = 384
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
blip_image_eval_size = 256
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'

def load_blip_model():
    blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
    blip_model.eval()

    return blip_model

def load_clip_model(clip_model_name):
    import clip

    model, preprocess = clip.load(clip_model_name)
    model.eval()
    model = model.to(device)

    return model, preprocess

def generate_caption(pil_image):
    blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
@@ -99,7 +113,10 @@ def rank(model, image_features, text_array, top_count=1):
    return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]

def interrogate(image, models):
    print ("Generating Caption")
    st.session_state["log_message"].code("Generating Caption", language='')
    caption = generate_caption(image)

    if len(models) == 0:
        print(f"\n\n{caption}")
        return
@@ -107,8 +124,8 @@ def interrogate(image, models):
    table = []
    bests = [[('',0)]]*5
    for model_name in models:
        print(f"Interrogating with {model_name}...")
        model, preprocess = clip.load(model_name)
        st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
        model, preprocess = load_clip_model(model_name)
        model.cuda().eval()

        images = preprocess(image).unsqueeze(0).cuda()
@@ -140,15 +157,16 @@ def interrogate(image, models):
        del model
        gc.collect()

    display(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
    st.session_state["prediction_table"].dataframe(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))

    flaves = ', '.join([f"{x[0]}" for x in bests[4]])
    medium = bests[0][0][0]
    if caption.startswith(medium):
        print(f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}")
        st.session_state["text_result"].code(f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")
    else:
        print(f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}")

        st.session_state["text_result"].code(f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")

    st.session_state["log_message"].code("Finished Interrogating.", language="")
#

def img2txt():
@@ -160,44 +178,44 @@ def img2txt():
    server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
    server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))

    trending_list = [site for site in server_state["sites"]]
    trending_list.extend(["trending on "+site for site in server_state["sites"]])
    trending_list.extend(["featured on "+site for site in server_state["sites"]])
    trending_list.extend([site+" contest winner" for site in server_state["sites"]])
    server_state["trending_list"] = [site for site in server_state["sites"]]
    server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
    server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]])
    server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]])

    image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg"
    #image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg"

    st.session_state["models"] = []
    models = []

    if st.session_state["ViTB32"]:
        st.session_state["models"].append('ViT-B/32')
        models.append('ViT-B/32')
    if st.session_state['ViTB16']:
        st.session_state["models"].append('ViT-B/16')
        models.append('ViT-B/16')
    if st.session_state["ViTL14"]:
        st.session_state["models"].append('ViT-L/14')
        models.append('ViT-L/14')
    if st.session_state["ViTL14_336px"]:
        st.session_state["models"].append('ViT-L/14@336px')
        models.append('ViT-L/14@336px')
    if st.session_state["RN101"]:
        st.session_state["models"].append('RN101')
        models.append('RN101')
    if st.session_state["RN50"]:
        st.session_state["models"].append('RN50')
        models.append('RN50')
    if st.session_state["RN50x4"]:
        st.session_state["models"].append('RN50x4')
        models.append('RN50x4')
    if st.session_state["RN50x16"]:
        st.session_state["models"].append('RN50x16')
        models.append('RN50x16')
    if st.session_state["RN50x64"]:
        st.session_state["models"].append('RN50x64')
        models.append('RN50x64')

    if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
        image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
    else:
        image = Image.open(image_path_or_url).convert('RGB')
    #if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
        #image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
    #else:
        #image = Image.open(image_path_or_url).convert('RGB')

    thumb = image.copy()
    thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
    #thumb = st.session_state["uploaded_image"].image.copy()
    #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
    #display(thumb)

    interrogate(image, models=st.session_state["models"])
    interrogate(st.session_state["uploaded_image"].pil_image, models=models)

#
def layout():
@@ -241,9 +259,9 @@ def layout():
    st.session_state["input_image_preview"] = st.empty()

    if st.session_state["uploaded_image"]:
        st.session_state["uploaded_image"].image = Image.open(st.session_state["uploaded_image"]).convert('RGBA')
        st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"])#.convert('RGBA')
        #new_img = image.resize((width, height))
        st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].image, clamp=True)
        st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, clamp=True)
    else:
        #st.session_state["input_image_preview"].code('', language="")
        st.image("images/streamlit/img2txt_placeholder.png", clamp=True)