ZeroCool940711 2022-09-30 15:20:02 -07:00
parent 20a89a1233
commit 857608c5f6
2 changed files with 133 additions and 120 deletions

View File

@@ -294,7 +294,7 @@ img2img:
img2txt:
batch_size: 100
blip_image_eval_size: 512
concepts_library:
concepts_per_page: 12
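For context, a minimal sketch (not part of the commit) of how these img2txt keys map onto the script below: batch_size is read through the defaults object in batch_rank()'s signature, while blip_image_eval_size currently remains a hardcoded module constant.

# Sketch only: how the img2txt config keys surface in the script below.
batch_size = st.session_state["defaults"].img2txt.batch_size   # used as batch_rank()'s default (100 texts per CLIP pass)
blip_image_eval_size = 512                                      # the module hardcodes this value rather than reading the key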

View File

@@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#---------------------------------------------------------------------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------------------------------------------------------------------
"""
CLIP Interrogator made by @pharmapsychotic modified to work with our WebUI.
@@ -31,20 +31,20 @@ Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychoti
And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).
"""
#---------------------------------------------------------------------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------------------------------------------------------------------
# base webui import and utils.
from ldm.util import default
from sd_utils import *
# streamlit imports
#streamlit components section
# streamlit components section
import streamlit_nested_layout
#other imports
# other imports
import clip, open_clip
import clip
import open_clip
import gc
import os
import pandas as pd
@@ -56,53 +56,59 @@ from torchvision.transforms.functional import InterpolationMode
from ldm.models.blip import blip_decoder
# end of imports
#---------------------------------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
blip_image_eval_size = 512
blip_model = None
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
def load_blip_model():
print ("Loading BLIP Model")
print("Loading BLIP Model")
st.session_state["log_message"].code("Loading BLIP Model", language='')
with server_state_lock['blip_model']:
if "blip_model" not in server_state:
blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth",
with server_state_lock['blip_model']:
server_state["blip_model"] = blip_decoder(pretrained="models/blip/model__base_caption.pth",
image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
blip_model.eval()
blip_model = blip_model.to(device).half()
print ("BLIP Model Loaded")
server_state["blip_model"] = server_state["blip_model"].eval()
#if not st.session_state["defaults"].general.optimized:
server_state["blip_model"] = server_state["blip_model"].to(device).half()
print("BLIP Model Loaded")
st.session_state["log_message"].code("BLIP Model Loaded", language='')
else:
print ("BLIP Model already loaded")
print("BLIP Model already loaded")
st.session_state["log_message"].code("BLIP Model Already Loaded", language='')
return blip_model
#return server_state["blip_model"]
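The substance of this hunk is moving the BLIP model out of a module-level global and into Streamlit's shared server_state, so it is built once per server process and reused across sessions. Below is a minimal sketch of that lazy-load-under-lock pattern; get_blip_model is a hypothetical name, and the ordering of the lock and the membership check follows the conventional safe pattern rather than being a verbatim copy of the diff.

# Hedged sketch of the caching pattern this commit adopts (server_state and
# server_state_lock come from the Streamlit server-state helpers imported via sd_utils).
def get_blip_model():  # hypothetical helper, not defined in the commit
    with server_state_lock["blip_model"]:        # serialize loaders across sessions
        if "blip_model" not in server_state:     # build the model only once per server
            model = blip_decoder(pretrained="models/blip/model__base_caption.pth",
                                 image_size=blip_image_eval_size, vit='base',
                                 med_config="configs/blip/med_config.json")
            server_state["blip_model"] = model.eval().to(device).half()
    return server_state["blip_model"]            # every session shares this instance

In the commit itself load_blip_model() stops returning the model; callers such as generate_caption() and interrogate() read server_state["blip_model"] directly.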
def generate_caption(pil_image):
global blip_model
#width, height = pil_image.size
gpu_image = transforms.Compose([
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
load_blip_model()
gpu_image = transforms.Compose([ # type: ignore
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), # type: ignore
transforms.ToTensor(), # type: ignore
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) # type: ignore
])(pil_image).unsqueeze(0).to(device).half()
with torch.no_grad():
caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
#print (caption)
return caption[0]
def load_list(filename):
with open(filename, 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
return items
def rank(model, image_features, text_array, top_count=1):
top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array]).cuda()
@@ -118,10 +124,12 @@ def rank(model, image_features, text_array, top_count=1):
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
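The hunk elides rank()'s middle; as a reading aid, here is a sketch of the standard CLIP zero-shot ranking the surrounding lines imply (encode and L2-normalize the candidate texts, softmax the scaled cosine similarities against the image features, take the top-k). rank_sketch is a stand-in name and the elided body is assumed, not copied from the file.

# Hedged reconstruction of the step between clip.tokenize() and topk() above.
def rank_sketch(model, image_features, text_array, top_count=1):
    top_count = min(top_count, len(text_array))
    text_tokens = clip.tokenize(text_array).cuda()
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
    # scaled cosine similarity between the (already normalized) image features
    # and every candidate text, turned into probabilities with a softmax
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
    return [(text_array[top_labels[0][i].numpy()], top_probs[0][i].numpy() * 100)
            for i in range(top_count)]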
def clear_cuda():
torch.cuda.empty_cache()
gc.collect()
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
batch_count = len(text_array) // batch_size
batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
@@ -132,18 +140,19 @@ def batch_rank(model, image_features, text_array, batch_size=st.session_state["d
return ranks
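batch_rank()'s body is likewise mostly elided; the sketch below shows what the visible lines imply, namely chunking the text list into config-sized batches, ranking each chunk, and concatenating the results. The remainder handling and the empty-batch guard are assumptions.

# Hedged sketch of batch_rank(): chunking keeps each tokenize/encode_text call
# small enough for GPU memory on long flavor/artist lists.
def batch_rank_sketch(model, image_features, text_array, batch_size=100):
    batch_count = len(text_array) // batch_size
    batches = [text_array[i * batch_size:(i + 1) * batch_size] for i in range(batch_count)]
    batches.append(text_array[batch_count * batch_size:])    # assumed remainder handling
    ranks = []
    for batch in batches:
        if not batch:                                         # skip an empty remainder
            continue
        ranks += rank_sketch(model, image_features, batch)    # rank() in the diff; sketched above
    return ranks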
def interrogate(image, models):
global blip_model
blip_model = load_blip_model()
print ("Generating Caption")
#server_state["blip_model"] =
load_blip_model()
print("Generating Caption")
st.session_state["log_message"].code("Generating Caption", language='')
caption = generate_caption(image)
if st.session_state["defaults"].general.optimized:
del blip_model
del server_state["blip_model"]
clear_cuda()
print ("Caption Generated")
print("Caption Generated")
st.session_state["log_message"].code("Caption Generated", language='')
if len(models) == 0:
@@ -151,44 +160,48 @@ def interrogate(image, models):
return
table = []
bests = [[('',0)]]*5
bests = [[('', 0)]]*5
print ("Ranking Text")
print("Ranking Text")
for model_name in models:
print(f"Interrogating with {model_name}...")
st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
if "clip_model" not in server_state:
#with server_state_lock[server_state["clip_model"]]:
if model_name == 'ViT-H-14':
model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k')
server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k')
elif model_name == 'ViT-g-14':
model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k')
server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k')
else:
model, preprocess = clip.load(model_name, device=device)
server_state["clip_model"], server_state["preprocess"] = clip.load(model_name, device=device)
model.cuda().eval()
server_state["clip_model"] = server_state["clip_model"].cuda().eval()
images = server_state["preprocess"](image).unsqueeze(0).cuda()
images = preprocess(image).unsqueeze(0).cuda()
with torch.no_grad():
image_features = model.encode_image(images).float()
image_features = server_state["clip_model"].encode_image(images).float()
image_features /= image_features.norm(dim=-1, keepdim=True)
if st.session_state["defaults"].general.optimized:
clear_cuda()
ranks = []
ranks.append(batch_rank(model, image_features, server_state["mediums"]))
ranks.append(batch_rank(model, image_features, ["by "+artist for artist in server_state["artists"]]))
ranks.append(batch_rank(model, image_features, server_state["trending_list"]))
ranks.append(batch_rank(model, image_features, server_state["movements"]))
ranks.append(batch_rank(model, image_features, server_state["flavors"]))
# ranks.append(batch_rank(model, image_features, server_state["genres"]))
# ranks.append(batch_rank(model, image_features, server_state["styles"]))
# ranks.append(batch_rank(model, image_features, server_state["techniques"]))
# ranks.append(batch_rank(model, image_features, server_state["subjects"]))
# ranks.append(batch_rank(model, image_features, server_state["colors"]))
# ranks.append(batch_rank(model, image_features, server_state["moods"]))
# ranks.append(batch_rank(model, image_features, server_state["themes"]))
# ranks.append(batch_rank(model, image_features, server_state["keywords"]))
ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["mediums"]))
ranks.append(batch_rank(server_state["clip_model"], image_features, ["by "+artist for artist in server_state["artists"]]))
ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["trending_list"]))
ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["movements"]))
ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["flavors"]))
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["genres"]))
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["styles"]))
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["techniques"]))
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["subjects"]))
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["colors"]))
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["moods"]))
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["themes"]))
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["keywords"]))
for i in range(len(ranks)):
confidence_sum = 0
@@ -204,10 +217,10 @@ def interrogate(image, models):
table.append(row)
if st.session_state["defaults"].general.optimized:
del model
del server_state["clip_model"]
gc.collect()
#for i in range(len(st.session_state["uploaded_image"])):
# for i in range(len(st.session_state["uploaded_image"])):
st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
@@ -221,10 +234,11 @@ def interrogate(image, models):
f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")
#
print ("Finished Interrogating.")
print("Finished Interrogating.")
st.session_state["log_message"].code("Finished Interrogating.", language="")
#
def img2txt():
data_path = "data/"
@@ -270,14 +284,14 @@ def img2txt():
if st.session_state["RN50x64"]:
models.append('RN50x64')
#if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
# if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
#image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
#else:
# else:
#image = Image.open(image_path_or_url).convert('RGB')
#thumb = st.session_state["uploaded_image"].image.copy()
#thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
#display(thumb)
# display(thumb)
st.session_state["processed_image_count"] = 0
@@ -287,6 +301,8 @@ def img2txt():
# increase counter.
st.session_state["processed_image_count"] += 1
#
def layout():
#set_page_title("Image-to-Text - Stable Diffusion WebUI")
#st.info("Under Construction. :construction_worker:")
@@ -294,9 +310,9 @@ def layout():
with st.form("img2txt-inputs"):
st.session_state["generation_mode"] = "img2txt"
#st.write("---")
# st.write("---")
# creating the page layout using columns
col1, col2 = st.columns([1,4], gap="large")
col1, col2 = st.columns([1, 4], gap="large")
with col1:
#url = st.text_area("Input Text","")
@@ -322,12 +338,11 @@ def layout():
st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
#
#st.subheader("Logs:")
# st.subheader("Logs:")
st.session_state["log_message"] = st.empty()
st.session_state["log_message"].code('', language="")
with col2:
st.subheader("Image")
@@ -335,7 +350,7 @@ def layout():
if st.session_state["uploaded_image"]:
#print (type(st.session_state["uploaded_image"]))
#if len(st.session_state["uploaded_image"]) == 1:
# if len(st.session_state["uploaded_image"]) == 1:
st.session_state["input_image_preview"] = []
st.session_state["input_image_preview_container"] = []
st.session_state["prediction_table"] = []
@@ -343,13 +358,13 @@ def layout():
for i in range(len(st.session_state["uploaded_image"])):
st.session_state["input_image_preview_container"].append(i)
st.session_state["input_image_preview_container"][i]= st.empty()
st.session_state["input_image_preview_container"][i] = st.empty()
with st.session_state["input_image_preview_container"][i].container():
col1_output, col2_output = st.columns([2,10], gap="medium")
col1_output, col2_output = st.columns([2, 10], gap="medium")
with col1_output:
st.session_state["input_image_preview"].append(i)
st.session_state["input_image_preview"][i]= st.empty()
st.session_state["input_image_preview"][i] = st.empty()
st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
@@ -363,18 +378,17 @@ def layout():
st.session_state["prediction_table"][i].table()
st.session_state["text_result"].append(i)
st.session_state["text_result"][i]= st.empty()
st.session_state["text_result"][i] = st.empty()
st.session_state["text_result"][i].code("", language="")
else:
#st.session_state["input_image_preview"].code('', language="")
st.image("images/streamlit/img2txt_placeholder.png", clamp=True)
#
# Every form must have a submit button; the extra blank spaces are a temporary way to align it with the input field. This needs to be done in CSS or some other way.
#generate_col1.title("")
#generate_col1.title("")
# generate_col1.title("")
# generate_col1.title("")
generate_button = st.form_submit_button("Generate!")
if generate_button:
@@ -388,6 +402,5 @@ def layout():
if "GFPGAN" in st.session_state and st.session_state["defaults"].general.optimized:
del st.session_state["GFPGAN"]
# run clip interrogator
img2txt()