mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-14 22:13:41 +03:00
Improved img2txt layout and performance.
This commit is contained in:
parent
3bb432162e
commit
d8a8687852
@ -292,6 +292,9 @@ img2img:
|
||||
variant_seed: ""
|
||||
write_info_files: True
|
||||
|
||||
img2txt:
|
||||
batch_size: 100
|
||||
|
||||
concepts_library:
|
||||
concepts_per_page: 12
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
#---------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
"""
|
||||
CLIP Interrogator made by @pharmapsychotic modified to work with our WebUI.
|
||||
|
||||
@ -30,28 +31,26 @@ Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychoti
|
||||
And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).
|
||||
|
||||
"""
|
||||
#---------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
#
|
||||
# base webui import and utils.
|
||||
from ldm.util import default
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
import streamlit_nested_layout
|
||||
|
||||
#streamlit components section
|
||||
import streamlit_nested_layout
|
||||
|
||||
#other imports
|
||||
|
||||
import clip, open_clip
|
||||
import gc
|
||||
import os
|
||||
import pandas as pd
|
||||
#import requests
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from ldm.models.blip import blip_decoder
|
||||
@ -66,9 +65,11 @@ blip_model = None
|
||||
|
||||
def load_blip_model():
|
||||
st.session_state["log_message"].code("Loading BLIP Model", language='')
|
||||
|
||||
with server_state_lock['blip_model']:
|
||||
if "blip_model" not in server_state:
|
||||
blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
|
||||
blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth",
|
||||
image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
|
||||
blip_model.eval()
|
||||
blip_model = blip_model.to(device).half()
|
||||
|
||||
@ -80,10 +81,10 @@ def load_blip_model():
|
||||
|
||||
def generate_caption(pil_image):
|
||||
global blip_model
|
||||
width, height = pil_image.size
|
||||
#width, height = pil_image.size
|
||||
|
||||
gpu_image = transforms.Compose([
|
||||
transforms.Resize((width, height), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
])(pil_image).unsqueeze(0).to(device).half()
|
||||
@ -118,6 +119,15 @@ def clear_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
|
||||
batch_count = len(text_array) // batch_size
|
||||
batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
|
||||
batches.append(text_array[batch_count*batch_size:])
|
||||
ranks = []
|
||||
for batch in batches:
|
||||
ranks += rank(model, image_features, batch)
|
||||
return ranks
|
||||
|
||||
def interrogate(image, models):
|
||||
global blip_model
|
||||
blip_model = load_blip_model()
|
||||
@ -125,9 +135,12 @@ def interrogate(image, models):
|
||||
st.session_state["log_message"].code("Generating Caption", language='')
|
||||
caption = generate_caption(image)
|
||||
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
del blip_model
|
||||
clear_cuda()
|
||||
|
||||
print ("Caption Generated")
|
||||
st.session_state["log_message"].code("Caption Generated", language='')
|
||||
|
||||
if len(models) == 0:
|
||||
print(f"\n\n{caption}")
|
||||
@ -135,10 +148,10 @@ def interrogate(image, models):
|
||||
|
||||
table = []
|
||||
bests = [[('',0)]]*5
|
||||
for model_name in models:
|
||||
print(f"Interrogating with {model_name}")
|
||||
st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
|
||||
|
||||
for model_name in models:
|
||||
print(f"Interrogating with {model_name}...")
|
||||
st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
|
||||
if model_name == 'ViT-H-14':
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k')
|
||||
elif model_name == 'ViT-g-14':
|
||||
@ -146,31 +159,30 @@ def interrogate(image, models):
|
||||
else:
|
||||
model, preprocess = clip.load(model_name, device=device)
|
||||
|
||||
#model, preprocess = clip.load(model_name)
|
||||
|
||||
model.cuda().eval()
|
||||
|
||||
images = preprocess(image).unsqueeze(0).cuda()
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(images).float()
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
clear_cuda()
|
||||
|
||||
ranks = []
|
||||
ranks.append(rank(model, image_features, server_state["mediums"]))
|
||||
clear_cuda()
|
||||
artists = []
|
||||
for batch in range(int(len(server_state["artists"])/1000)):
|
||||
artist_rank = rank(model, image_features, server_state["artists"][batch*1000:(batch+1)*1000])
|
||||
artists.extend(artist_rank)
|
||||
clear_cuda()
|
||||
ranks.append(artists)
|
||||
ranks.append(rank(model, image_features, server_state["trending_list"]))
|
||||
clear_cuda()
|
||||
ranks.append(rank(model, image_features, server_state["movements"]))
|
||||
clear_cuda()
|
||||
ranks.append(rank(model, image_features, server_state["flavors"], top_count=3))
|
||||
clear_cuda()
|
||||
ranks.append(batch_rank(model, image_features, server_state["mediums"]))
|
||||
ranks.append(batch_rank(model, image_features, ["by "+artist for artist in server_state["artists"]]))
|
||||
ranks.append(batch_rank(model, image_features, server_state["trending_list"]))
|
||||
ranks.append(batch_rank(model, image_features, server_state["movements"]))
|
||||
ranks.append(batch_rank(model, image_features, server_state["flavors"]))
|
||||
# ranks.append(batch_rank(model, image_features, server_state["genres"]))
|
||||
# ranks.append(batch_rank(model, image_features, server_state["styles"]))
|
||||
# ranks.append(batch_rank(model, image_features, server_state["techniques"]))
|
||||
# ranks.append(batch_rank(model, image_features, server_state["subjects"]))
|
||||
# ranks.append(batch_rank(model, image_features, server_state["colors"]))
|
||||
# ranks.append(batch_rank(model, image_features, server_state["moods"]))
|
||||
# ranks.append(batch_rank(model, image_features, server_state["themes"]))
|
||||
# ranks.append(batch_rank(model, image_features, server_state["keywords"]))
|
||||
|
||||
|
||||
for i in range(len(ranks)):
|
||||
@ -186,20 +198,25 @@ def interrogate(image, models):
|
||||
|
||||
table.append(row)
|
||||
|
||||
#del model
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
del model
|
||||
gc.collect()
|
||||
|
||||
st.session_state["prediction_table"].dataframe(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
|
||||
#for i in range(len(st.session_state["uploaded_image"])):
|
||||
st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
|
||||
table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
|
||||
|
||||
flaves = ', '.join([f"{x[0]}" for x in bests[4]])
|
||||
medium = bests[0][0][0]
|
||||
|
||||
for items in caption:
|
||||
if items.startswith(medium):
|
||||
st.session_state["text_result"].code(f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")
|
||||
if caption.startswith(medium):
|
||||
st.session_state["text_result"][st.session_state["processed_image_count"]].code(
|
||||
f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")
|
||||
else:
|
||||
st.session_state["text_result"].code(f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")
|
||||
st.session_state["text_result"][st.session_state["processed_image_count"]].code(
|
||||
f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")
|
||||
|
||||
#
|
||||
print ("Finished Interrogating.")
|
||||
st.session_state["log_message"].code("Finished Interrogating.", language="")
|
||||
#
|
||||
|
||||
@ -211,6 +228,10 @@ def img2txt():
|
||||
server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
|
||||
server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
|
||||
server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
|
||||
# server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))
|
||||
# server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt'))
|
||||
# server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
|
||||
# server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt'))
|
||||
|
||||
server_state["trending_list"] = [site for site in server_state["sites"]]
|
||||
server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
|
||||
@ -227,6 +248,10 @@ def img2txt():
|
||||
models.append('ViT-B/16')
|
||||
if st.session_state["ViTL14"]:
|
||||
models.append('ViT-L/14')
|
||||
if st.session_state["ViT-H-14"]:
|
||||
models.append('ViT-H-14')
|
||||
if st.session_state["ViT-g-14"]:
|
||||
models.append('ViT-g-14')
|
||||
if st.session_state["ViTL14_336px"]:
|
||||
models.append('ViT-L/14@336px')
|
||||
if st.session_state["RN101"]:
|
||||
@ -249,9 +274,13 @@ def img2txt():
|
||||
#thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
|
||||
#display(thumb)
|
||||
|
||||
st.session_state["processed_image_count"] = 0
|
||||
|
||||
interrogate(st.session_state["uploaded_image"].pil_image, models=models)
|
||||
for i in range(len(st.session_state["uploaded_image"])):
|
||||
|
||||
interrogate(st.session_state["uploaded_image"][i].pil_image, models=models)
|
||||
# increase counter.
|
||||
st.session_state["processed_image_count"] += 1
|
||||
#
|
||||
def layout():
|
||||
#set_page_title("Image-to-Text - Stable Diffusion WebUI")
|
||||
@ -276,6 +305,8 @@ def layout():
|
||||
|
||||
with st.expander("Others"):
|
||||
st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.")
|
||||
st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
|
||||
st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
|
||||
st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.")
|
||||
st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.")
|
||||
st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.")
|
||||
@ -288,7 +319,7 @@ def layout():
|
||||
#
|
||||
#st.subheader("Logs:")
|
||||
|
||||
st.session_state["log_message"] = st.empty() if not st.session_state["log_message"] else st.session_state["log_message"]
|
||||
st.session_state["log_message"] = st.empty()
|
||||
st.session_state["log_message"].code('', language="")
|
||||
|
||||
|
||||
@ -297,50 +328,39 @@ def layout():
|
||||
|
||||
refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')
|
||||
|
||||
col1_output, col2_output = st.columns([2,10], gap="medium")
|
||||
|
||||
if st.session_state["uploaded_image"]:
|
||||
if type(st.session_state["uploaded_image"]) != list:
|
||||
#print (type(st.session_state["uploaded_image"]))
|
||||
#if len(st.session_state["uploaded_image"]) == 1:
|
||||
st.session_state["input_image_preview"] = []
|
||||
st.session_state["input_image_preview_container"] = []
|
||||
st.session_state["prediction_table"] = []
|
||||
st.session_state["text_result"] = []
|
||||
|
||||
for i in range(len(st.session_state["uploaded_image"])):
|
||||
st.session_state["input_image_preview_container"].append(i)
|
||||
st.session_state["input_image_preview_container"][i]= st.empty()
|
||||
|
||||
with st.session_state["input_image_preview_container"][i].container():
|
||||
col1_output, col2_output = st.columns([2,10], gap="medium")
|
||||
with col1_output:
|
||||
st.session_state["input_image_preview"] = st.empty()
|
||||
st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"]).convert('RGB')
|
||||
st.session_state["input_image_preview"].append(i)
|
||||
st.session_state["input_image_preview"][i]= st.empty()
|
||||
st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
|
||||
|
||||
st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, use_column_width=True, clamp=True)
|
||||
st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
|
||||
|
||||
with st.session_state["input_image_preview_container"][i].container():
|
||||
|
||||
with col2_output:
|
||||
#with st.container():
|
||||
##st.subheader("Image To Text Result")
|
||||
|
||||
st.session_state["prediction_table"] = st.empty() if not st.session_state["prediction_table"] or refresh else st.session_state["prediction_table"]
|
||||
st.session_state["prediction_table"].table() if not st.session_state["prediction_table"].table() or refresh else st.session_state["prediction_table"].table()
|
||||
st.session_state["prediction_table"].append(i)
|
||||
st.session_state["prediction_table"][i] = st.empty()
|
||||
st.session_state["prediction_table"][i].table()
|
||||
|
||||
st.session_state["text_result"] = st.empty() if not st.session_state["text_result"] or refresh else st.session_state["text_result"]
|
||||
st.session_state["text_result"].code('', language="") if not st.session_state["text_result"].code('', language=""
|
||||
) or refresh else st.session_state["text_result"].code('', language="")
|
||||
st.session_state["text_result"].append(i)
|
||||
st.session_state["text_result"][i]= st.empty()
|
||||
st.session_state["text_result"][i].code("", language="")
|
||||
|
||||
else:
|
||||
for i in range(st.session_state["uploaded_image"]):
|
||||
#for image in st.session_state["uploaded_image"]:
|
||||
#st.session_state["uploaded_image"].pil_image[i] = []
|
||||
st.session_state["uploaded_image"].pil_image[i] = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
|
||||
#
|
||||
#st.write("---")
|
||||
with col1_output:
|
||||
st.session_state["input_image_preview"] = st.empty()
|
||||
st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"]).convert('RGB')
|
||||
st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, use_column_width=True, clamp=True)
|
||||
|
||||
with col2_output:
|
||||
#with st.container():
|
||||
##st.subheader("Image To Text Result")
|
||||
|
||||
st.session_state["prediction_table"] = st.empty() if not st.session_state["prediction_table"] or refresh else st.session_state["prediction_table"]
|
||||
st.session_state["prediction_table"].table() if not st.session_state["prediction_table"].table() or refresh else st.session_state["prediction_table"].table()
|
||||
|
||||
st.session_state["text_result"] = st.empty() if not st.session_state["text_result"] or refresh else st.session_state["text_result"]
|
||||
st.session_state["text_result"].code('', language="") if not st.session_state["text_result"].code('', language=""
|
||||
) or refresh else st.session_state["text_result"].code('', language="")
|
||||
|
||||
else:
|
||||
#st.session_state["input_image_preview"].code('', language="")
|
||||
@ -353,5 +373,16 @@ def layout():
|
||||
generate_button = st.form_submit_button("Generate!")
|
||||
|
||||
if generate_button:
|
||||
# if model, pipe, RealESRGAN or GFPGAN is in st.session_state remove the model and pipe form session_state so that they are reloaded.
|
||||
if "model" in st.session_state and st.session_state["defaults"].general.optimized:
|
||||
del st.session_state["model"]
|
||||
if "pipe" in st.session_state and st.session_state["defaults"].general.optimized:
|
||||
del st.session_state["pipe"]
|
||||
if "RealESRGAN" in st.session_state and st.session_state["defaults"].general.optimized:
|
||||
del st.session_state["RealESRGAN"]
|
||||
if "GFPGAN" in st.session_state and st.session_state["defaults"].general.optimized:
|
||||
del st.session_state["GFPGAN"]
|
||||
|
||||
|
||||
# run clip interrogator
|
||||
img2txt()
|
@ -870,7 +870,7 @@ def try_loading_LDSR(model_name: str,checking=False):
|
||||
|
||||
# Loads Stable Diffusion model by name
|
||||
#@retry(tries=5)
|
||||
def load_sd_model(model_name: str) -> [any, any, any, any, any]:
|
||||
def load_sd_model(model_name: str):
|
||||
ckpt_path = st.session_state.defaults.general.default_model_path
|
||||
|
||||
if model_name != st.session_state.defaults.general.default_model:
|
||||
|
Loading…
Reference in New Issue
Block a user