From d8a86878522baccb02b75a8122c1f80f78b413fe Mon Sep 17 00:00:00 2001 From: ZeroCool940711 Date: Fri, 30 Sep 2022 12:40:52 -0700 Subject: [PATCH] Improved img2txt layout and performance. --- configs/webui/webui_streamlit.yaml | 3 + scripts/img2txt.py | 241 ++++++++++++++++------------- scripts/sd_utils.py | 2 +- 3 files changed, 140 insertions(+), 106 deletions(-) diff --git a/configs/webui/webui_streamlit.yaml b/configs/webui/webui_streamlit.yaml index 77f8c37..aac5bf1 100644 --- a/configs/webui/webui_streamlit.yaml +++ b/configs/webui/webui_streamlit.yaml @@ -292,6 +292,9 @@ img2img: variant_seed: "" write_info_files: True +img2txt: + batch_size: 100 + concepts_library: concepts_per_page: 12 diff --git a/scripts/img2txt.py b/scripts/img2txt.py index d9fe8fa..82d98f9 100644 --- a/scripts/img2txt.py +++ b/scripts/img2txt.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +#--------------------------------------------------------------------------------------------------------------------------------------------------- """ CLIP Interrogator made by @pharmapsychotic modified to work with our WebUI. @@ -30,28 +31,26 @@ Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychoti And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html). """ +#--------------------------------------------------------------------------------------------------------------------------------------------------- -# # base webui import and utils. +from ldm.util import default from sd_utils import * # streamlit imports -import streamlit_nested_layout #streamlit components section +import streamlit_nested_layout #other imports + import clip, open_clip import gc import os import pandas as pd #import requests import torch -import torchvision.transforms as T -import torchvision.transforms.functional as TF from PIL import Image -from torch import nn -from torch.nn import functional as F from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from ldm.models.blip import blip_decoder @@ -66,39 +65,41 @@ blip_model = None def load_blip_model(): st.session_state["log_message"].code("Loading BLIP Model", language='') + with server_state_lock['blip_model']: if "blip_model" not in server_state: - blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json") + blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", + image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json") blip_model.eval() - blip_model = blip_model.to(device).half() - + blip_model = blip_model.to(device).half() + st.session_state["log_message"].code("BLIP Model Loaded", language='') else: st.session_state["log_message"].code("BLIP Model Already Loaded", language='') - + return blip_model def generate_caption(pil_image): global blip_model - width, height = pil_image.size - + #width, height = pil_image.size + gpu_image = transforms.Compose([ - transforms.Resize((width, height), interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ])(pil_image).unsqueeze(0).to(device).half() + transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + ])(pil_image).unsqueeze(0).to(device).half() with torch.no_grad(): caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5) - + #print (caption) return caption[0] def load_list(filename): with open(filename, 'r', encoding='utf-8', errors='replace') as f: items = [line.strip() for line in f.readlines()] - return items - + return items + def rank(model, image_features, text_array, top_count=1): top_count = min(top_count, len(text_array)) text_tokens = clip.tokenize([text for text in text_array]).cuda() @@ -118,16 +119,28 @@ def clear_cuda(): torch.cuda.empty_cache() gc.collect() +def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size): + batch_count = len(text_array) // batch_size + batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)] + batches.append(text_array[batch_count*batch_size:]) + ranks = [] + for batch in batches: + ranks += rank(model, image_features, batch) + return ranks + def interrogate(image, models): global blip_model blip_model = load_blip_model() print ("Generating Caption") st.session_state["log_message"].code("Generating Caption", language='') caption = generate_caption(image) - - del blip_model - clear_cuda() + + if st.session_state["defaults"].general.optimized: + del blip_model + clear_cuda() + print ("Caption Generated") + st.session_state["log_message"].code("Caption Generated", language='') if len(models) == 0: print(f"\n\n{caption}") @@ -135,42 +148,41 @@ def interrogate(image, models): table = [] bests = [[('',0)]]*5 + for model_name in models: - print(f"Interrogating with {model_name}") + print(f"Interrogating with {model_name}...") st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='') - if model_name == 'ViT-H-14': - model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k') - elif model_name == 'ViT-g-14': - model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k') - else: - model, preprocess = clip.load(model_name, device=device) - - #model, preprocess = clip.load(model_name) - + model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k') + elif model_name == 'ViT-g-14': + model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k') + else: + model, preprocess = clip.load(model_name, device=device) + model.cuda().eval() images = preprocess(image).unsqueeze(0).cuda() with torch.no_grad(): image_features = model.encode_image(images).float() image_features /= image_features.norm(dim=-1, keepdim=True) - clear_cuda() + + if st.session_state["defaults"].general.optimized: + clear_cuda() ranks = [] - ranks.append(rank(model, image_features, server_state["mediums"])) - clear_cuda() - artists = [] - for batch in range(int(len(server_state["artists"])/1000)): - artist_rank = rank(model, image_features, server_state["artists"][batch*1000:(batch+1)*1000]) - artists.extend(artist_rank) - clear_cuda() - ranks.append(artists) - ranks.append(rank(model, image_features, server_state["trending_list"])) - clear_cuda() - ranks.append(rank(model, image_features, server_state["movements"])) - clear_cuda() - ranks.append(rank(model, image_features, server_state["flavors"], top_count=3)) - clear_cuda() + ranks.append(batch_rank(model, image_features, server_state["mediums"])) + ranks.append(batch_rank(model, image_features, ["by "+artist for artist in server_state["artists"]])) + ranks.append(batch_rank(model, image_features, server_state["trending_list"])) + ranks.append(batch_rank(model, image_features, server_state["movements"])) + ranks.append(batch_rank(model, image_features, server_state["flavors"])) + # ranks.append(batch_rank(model, image_features, server_state["genres"])) + # ranks.append(batch_rank(model, image_features, server_state["styles"])) + # ranks.append(batch_rank(model, image_features, server_state["techniques"])) + # ranks.append(batch_rank(model, image_features, server_state["subjects"])) + # ranks.append(batch_rank(model, image_features, server_state["colors"])) + # ranks.append(batch_rank(model, image_features, server_state["moods"])) + # ranks.append(batch_rank(model, image_features, server_state["themes"])) + # ranks.append(batch_rank(model, image_features, server_state["keywords"])) for i in range(len(ranks)): @@ -186,20 +198,25 @@ def interrogate(image, models): table.append(row) - #del model - gc.collect() - - st.session_state["prediction_table"].dataframe(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"])) + if st.session_state["defaults"].general.optimized: + del model + gc.collect() + + #for i in range(len(st.session_state["uploaded_image"])): + st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame( + table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"])) flaves = ', '.join([f"{x[0]}" for x in bests[4]]) medium = bests[0][0][0] - - for items in caption: - if items.startswith(medium): - st.session_state["text_result"].code(f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="") - else: - st.session_state["text_result"].code(f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="") - + if caption.startswith(medium): + st.session_state["text_result"][st.session_state["processed_image_count"]].code( + f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="") + else: + st.session_state["text_result"][st.session_state["processed_image_count"]].code( + f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="") + + # + print ("Finished Interrogating.") st.session_state["log_message"].code("Finished Interrogating.", language="") # @@ -211,6 +228,10 @@ def img2txt(): server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt')) server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt')) server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt')) + # server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt')) + # server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt')) + # server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt')) + # server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt')) server_state["trending_list"] = [site for site in server_state["sites"]] server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]]) @@ -227,6 +248,10 @@ def img2txt(): models.append('ViT-B/16') if st.session_state["ViTL14"]: models.append('ViT-L/14') + if st.session_state["ViT-H-14"]: + models.append('ViT-H-14') + if st.session_state["ViT-g-14"]: + models.append('ViT-g-14') if st.session_state["ViTL14_336px"]: models.append('ViT-L/14@336px') if st.session_state["RN101"]: @@ -249,9 +274,13 @@ def img2txt(): #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size]) #display(thumb) + st.session_state["processed_image_count"] = 0 - interrogate(st.session_state["uploaded_image"].pil_image, models=models) - + for i in range(len(st.session_state["uploaded_image"])): + + interrogate(st.session_state["uploaded_image"][i].pil_image, models=models) + # increase counter. + st.session_state["processed_image_count"] += 1 # def layout(): #set_page_title("Image-to-Text - Stable Diffusion WebUI") @@ -270,12 +299,14 @@ def layout(): #st.subheader("Input Image") st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg'], accept_multiple_files=True) - st.subheader("CLIP models") + st.subheader("CLIP models") with st.expander("Stable Diffusion", expanded=True): st.session_state["ViTL14"] = st.checkbox("ViTL14", value=True, help="For StableDiffusion you can just use ViTL14.") with st.expander("Others"): st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.") + st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.") + st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.") st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.") st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.") st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.") @@ -283,13 +314,13 @@ def layout(): st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.") st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.") st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.") - st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.") + st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.") # #st.subheader("Logs:") - - st.session_state["log_message"] = st.empty() if not st.session_state["log_message"] else st.session_state["log_message"] - st.session_state["log_message"].code('', language="") + + st.session_state["log_message"] = st.empty() + st.session_state["log_message"].code('', language="") with col2: @@ -297,51 +328,40 @@ def layout(): refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image instead of the default placeholder.') - col1_output, col2_output = st.columns([2,10], gap="medium") - if st.session_state["uploaded_image"]: - if type(st.session_state["uploaded_image"]) != list: + #print (type(st.session_state["uploaded_image"])) + #if len(st.session_state["uploaded_image"]) == 1: + st.session_state["input_image_preview"] = [] + st.session_state["input_image_preview_container"] = [] + st.session_state["prediction_table"] = [] + st.session_state["text_result"] = [] + + for i in range(len(st.session_state["uploaded_image"])): + st.session_state["input_image_preview_container"].append(i) + st.session_state["input_image_preview_container"][i]= st.empty() - with col1_output: - st.session_state["input_image_preview"] = st.empty() - st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"]).convert('RGB') - - st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, use_column_width=True, clamp=True) - - with col2_output: - #with st.container(): - ##st.subheader("Image To Text Result") - - st.session_state["prediction_table"] = st.empty() if not st.session_state["prediction_table"] or refresh else st.session_state["prediction_table"] - st.session_state["prediction_table"].table() if not st.session_state["prediction_table"].table() or refresh else st.session_state["prediction_table"].table() - - st.session_state["text_result"] = st.empty() if not st.session_state["text_result"] or refresh else st.session_state["text_result"] - st.session_state["text_result"].code('', language="") if not st.session_state["text_result"].code('', language="" - ) or refresh else st.session_state["text_result"].code('', language="") - - else: - for i in range(st.session_state["uploaded_image"]): - #for image in st.session_state["uploaded_image"]: - #st.session_state["uploaded_image"].pil_image[i] = [] - st.session_state["uploaded_image"].pil_image[i] = Image.open(st.session_state["uploaded_image"][i]).convert('RGB') - # - #st.write("---") + with st.session_state["input_image_preview_container"][i].container(): + col1_output, col2_output = st.columns([2,10], gap="medium") with col1_output: - st.session_state["input_image_preview"] = st.empty() - st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"]).convert('RGB') - st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, use_column_width=True, clamp=True) + st.session_state["input_image_preview"].append(i) + st.session_state["input_image_preview"][i]= st.empty() + st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB') + + st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True) + + with st.session_state["input_image_preview_container"][i].container(): with col2_output: - #with st.container(): - ##st.subheader("Image To Text Result") + + st.session_state["prediction_table"].append(i) + st.session_state["prediction_table"][i] = st.empty() + st.session_state["prediction_table"][i].table() - st.session_state["prediction_table"] = st.empty() if not st.session_state["prediction_table"] or refresh else st.session_state["prediction_table"] - st.session_state["prediction_table"].table() if not st.session_state["prediction_table"].table() or refresh else st.session_state["prediction_table"].table() - - st.session_state["text_result"] = st.empty() if not st.session_state["text_result"] or refresh else st.session_state["text_result"] - st.session_state["text_result"].code('', language="") if not st.session_state["text_result"].code('', language="" - ) or refresh else st.session_state["text_result"].code('', language="") - + st.session_state["text_result"].append(i) + st.session_state["text_result"][i]= st.empty() + st.session_state["text_result"][i].code("", language="") + + else: #st.session_state["input_image_preview"].code('', language="") st.image("images/streamlit/img2txt_placeholder.png", clamp=True) @@ -352,6 +372,17 @@ def layout(): #generate_col1.title("") generate_button = st.form_submit_button("Generate!") - if generate_button: + if generate_button: + # if model, pipe, RealESRGAN or GFPGAN is in st.session_state remove the model and pipe form session_state so that they are reloaded. + if "model" in st.session_state and st.session_state["defaults"].general.optimized: + del st.session_state["model"] + if "pipe" in st.session_state and st.session_state["defaults"].general.optimized: + del st.session_state["pipe"] + if "RealESRGAN" in st.session_state and st.session_state["defaults"].general.optimized: + del st.session_state["RealESRGAN"] + if "GFPGAN" in st.session_state and st.session_state["defaults"].general.optimized: + del st.session_state["GFPGAN"] + + # run clip interrogator img2txt() \ No newline at end of file diff --git a/scripts/sd_utils.py b/scripts/sd_utils.py index 9ee93ba..6a5c8ea 100644 --- a/scripts/sd_utils.py +++ b/scripts/sd_utils.py @@ -870,7 +870,7 @@ def try_loading_LDSR(model_name: str,checking=False): # Loads Stable Diffusion model by name #@retry(tries=5) -def load_sd_model(model_name: str) -> [any, any, any, any, any]: +def load_sd_model(model_name: str): ckpt_path = st.session_state.defaults.general.default_model_path if model_name != st.session_state.defaults.general.default_model: