From 857608c5f6c2d4f67cf4f869717788d6c1427137 Mon Sep 17 00:00:00 2001 From: ZeroCool940711 Date: Fri, 30 Sep 2022 15:20:02 -0700 Subject: [PATCH] ... --- configs/webui/webui_streamlit.yaml | 2 +- scripts/img2txt.py | 251 +++++++++++++++-------------- 2 files changed, 133 insertions(+), 120 deletions(-) diff --git a/configs/webui/webui_streamlit.yaml b/configs/webui/webui_streamlit.yaml index aac5bf1..63dfbe6 100644 --- a/configs/webui/webui_streamlit.yaml +++ b/configs/webui/webui_streamlit.yaml @@ -294,7 +294,7 @@ img2img: img2txt: batch_size: 100 - + blip_image_eval_size: 512 concepts_library: concepts_per_page: 12 diff --git a/scripts/img2txt.py b/scripts/img2txt.py index 0c0bcea..5a399a9 100644 --- a/scripts/img2txt.py +++ b/scripts/img2txt.py @@ -12,9 +12,9 @@ # GNU Affero General Public License for more details. # You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . +# along with this program. If not, see . -#--------------------------------------------------------------------------------------------------------------------------------------------------- +# --------------------------------------------------------------------------------------------------------------------------------------------------- """ CLIP Interrogator made by @pharmapsychotic modified to work with our WebUI. @@ -31,20 +31,20 @@ Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychoti And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html). """ -#--------------------------------------------------------------------------------------------------------------------------------------------------- +# --------------------------------------------------------------------------------------------------------------------------------------------------- # base webui import and utils. 
-from ldm.util import default from sd_utils import * # streamlit imports -#streamlit components section +# streamlit components section import streamlit_nested_layout -#other imports +# other imports -import clip, open_clip +import clip +import open_clip import gc import os import pandas as pd @@ -56,53 +56,59 @@ from torchvision.transforms.functional import InterpolationMode from ldm.models.blip import blip_decoder # end of imports -#--------------------------------------------------------------------------------------------------------------- +# --------------------------------------------------------------------------------------------------------------- device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') blip_image_eval_size = 512 -blip_model = None -#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' +#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' + def load_blip_model(): - print ("Loading BLIP Model") + print("Loading BLIP Model") st.session_state["log_message"].code("Loading BLIP Model", language='') - with server_state_lock['blip_model']: - if "blip_model" not in server_state: - blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", - image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json") - blip_model.eval() - blip_model = blip_model.to(device).half() - - print ("BLIP Model Loaded") + if "blip_model" not in server_state: + with server_state_lock['blip_model']: + server_state["blip_model"] = blip_decoder(pretrained="models/blip/model__base_caption.pth", + image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json") + + server_state["blip_model"] = server_state["blip_model"].eval() + + #if not st.session_state["defaults"].general.optimized: + server_state["blip_model"] = server_state["blip_model"].to(device).half() + + print("BLIP Model Loaded") st.session_state["log_message"].code("BLIP Model Loaded", language='') - else: - print ("BLIP Model already loaded") - st.session_state["log_message"].code("BLIP Model Already Loaded", language='') + else: + print("BLIP Model already loaded") + st.session_state["log_message"].code("BLIP Model Already Loaded", language='') + + #return server_state["blip_model"] - return blip_model def generate_caption(pil_image): - global blip_model - #width, height = pil_image.size - gpu_image = transforms.Compose([ - transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ])(pil_image).unsqueeze(0).to(device).half() + load_blip_model() + + gpu_image = transforms.Compose([ # type: ignore + transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), # type: ignore + transforms.ToTensor(), # type: ignore + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) # type: ignore + ])(pil_image).unsqueeze(0).to(device).half() with torch.no_grad(): - caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5) + caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5) #print (caption) return caption[0] + def load_list(filename): with open(filename, 'r', encoding='utf-8', errors='replace') 
as f: items = [line.strip() for line in f.readlines()] return items - + + def rank(model, image_features, text_array, top_count=1): top_count = min(top_count, len(text_array)) text_tokens = clip.tokenize([text for text in text_array]).cuda() @@ -115,13 +121,15 @@ def rank(model, image_features, text_array, top_count=1): similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) similarity /= image_features.shape[0] - top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1) + top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1) return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)] + def clear_cuda(): torch.cuda.empty_cache() gc.collect() + def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size): batch_count = len(text_array) // batch_size batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)] @@ -132,18 +140,19 @@ def batch_rank(model, image_features, text_array, batch_size=st.session_state["d return ranks def interrogate(image, models): - global blip_model + + #server_state["blip_model"] = + load_blip_model() - blip_model = load_blip_model() - print ("Generating Caption") + print("Generating Caption") st.session_state["log_message"].code("Generating Caption", language='') caption = generate_caption(image) - + if st.session_state["defaults"].general.optimized: - del blip_model + del server_state["blip_model"] clear_cuda() - print ("Caption Generated") + print("Caption Generated") st.session_state["log_message"].code("Caption Generated", language='') if len(models) == 0: @@ -151,44 +160,48 @@ def interrogate(image, models): return table = [] - bests = [[('',0)]]*5 - - print ("Ranking Text") + bests = [[('', 0)]]*5 + + print("Ranking Text") for model_name in models: print(f"Interrogating with {model_name}...") st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='') - if model_name == 'ViT-H-14': - model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k') - elif model_name == 'ViT-g-14': - model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k') - else: - model, preprocess = clip.load(model_name, device=device) - - model.cuda().eval() - - images = preprocess(image).unsqueeze(0).cuda() - with torch.no_grad(): - image_features = model.encode_image(images).float() - image_features /= image_features.norm(dim=-1, keepdim=True) + if "clip_model" not in server_state: + #with server_state_lock[server_state["clip_model"]]: + if model_name == 'ViT-H-14': + server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k') + elif model_name == 'ViT-g-14': + server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k') + else: + server_state["clip_model"], server_state["preprocess"] = clip.load(model_name, device=device) + + server_state["clip_model"] = server_state["clip_model"].cuda().eval() + + images = server_state["preprocess"](image).unsqueeze(0).cuda() + + with torch.no_grad(): + image_features = server_state["clip_model"].encode_image(images).float() + + image_features /= image_features.norm(dim=-1, keepdim=True) + if st.session_state["defaults"].general.optimized: clear_cuda() - + ranks = [] - ranks.append(batch_rank(model, image_features, 
server_state["mediums"])) - ranks.append(batch_rank(model, image_features, ["by "+artist for artist in server_state["artists"]])) - ranks.append(batch_rank(model, image_features, server_state["trending_list"])) - ranks.append(batch_rank(model, image_features, server_state["movements"])) - ranks.append(batch_rank(model, image_features, server_state["flavors"])) - # ranks.append(batch_rank(model, image_features, server_state["genres"])) - # ranks.append(batch_rank(model, image_features, server_state["styles"])) - # ranks.append(batch_rank(model, image_features, server_state["techniques"])) - # ranks.append(batch_rank(model, image_features, server_state["subjects"])) - # ranks.append(batch_rank(model, image_features, server_state["colors"])) - # ranks.append(batch_rank(model, image_features, server_state["moods"])) - # ranks.append(batch_rank(model, image_features, server_state["themes"])) - # ranks.append(batch_rank(model, image_features, server_state["keywords"])) - + ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["mediums"])) + ranks.append(batch_rank(server_state["clip_model"], image_features, ["by "+artist for artist in server_state["artists"]])) + ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["trending_list"])) + ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["movements"])) + ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["flavors"])) + # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["genres"])) + # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["styles"])) + # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["techniques"])) + # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["subjects"])) + # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["colors"])) + # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["moods"])) + # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["themes"])) + # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["keywords"])) for i in range(len(ranks)): confidence_sum = 0 @@ -204,27 +217,28 @@ def interrogate(image, models): table.append(row) if st.session_state["defaults"].general.optimized: - del model + del server_state["clip_model"] gc.collect() - - #for i in range(len(st.session_state["uploaded_image"])): + + # for i in range(len(st.session_state["uploaded_image"])): st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame( table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"])) flaves = ', '.join([f"{x[0]}" for x in bests[4]]) medium = bests[0][0][0] - if caption.startswith(medium): + if caption.startswith(medium): st.session_state["text_result"][st.session_state["processed_image_count"]].code( f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="") else: st.session_state["text_result"][st.session_state["processed_image_count"]].code( f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="") - - # - print ("Finished Interrogating.") + + # + print("Finished Interrogating.") st.session_state["log_message"].code("Finished Interrogating.", language="") # + def img2txt(): data_path = "data/" @@ -251,11 +265,11 @@ def 
img2txt(): models.append('ViT-B/32') if st.session_state['ViTB16']: models.append('ViT-B/16') - if st.session_state["ViTL14"]: + if st.session_state["ViTL14"]: models.append('ViT-L/14') - if st.session_state["ViT-H-14"]: + if st.session_state["ViT-H-14"]: models.append('ViT-H-14') - if st.session_state["ViT-g-14"]: + if st.session_state["ViT-g-14"]: models.append('ViT-g-14') if st.session_state["ViTL14_336px"]: models.append('ViT-L/14@336px') @@ -270,33 +284,35 @@ def img2txt(): if st.session_state["RN50x64"]: models.append('RN50x64') - #if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'): + # if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'): #image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB') - #else: + # else: #image = Image.open(image_path_or_url).convert('RGB') #thumb = st.session_state["uploaded_image"].image.copy() #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size]) - #display(thumb) + # display(thumb) st.session_state["processed_image_count"] = 0 - + for i in range(len(st.session_state["uploaded_image"])): - + interrogate(st.session_state["uploaded_image"][i].pil_image, models=models) # increase counter. st.session_state["processed_image_count"] += 1 # + + def layout(): #set_page_title("Image-to-Text - Stable Diffusion WebUI") - #st.info("Under Construction. :construction_worker:") + #st.info("Under Construction. :construction_worker:") with st.form("img2txt-inputs"): st.session_state["generation_mode"] = "img2txt" - #st.write("---") + # st.write("---") # creating the page layout using columns - col1, col2 = st.columns([1,4], gap="large") + col1, col2 = st.columns([1, 4], gap="large") with col1: #url = st.text_area("Input Text","") @@ -304,68 +320,66 @@ def layout(): #st.subheader("Input Image") st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg'], accept_multiple_files=True) - st.subheader("CLIP models") + st.subheader("CLIP models") with st.expander("Stable Diffusion", expanded=True): st.session_state["ViTL14"] = st.checkbox("ViTL14", value=True, help="For StableDiffusion you can just use ViTL14.") with st.expander("Others"): st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.") st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.") - st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.") - st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.") - st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.") + st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.") + st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.") + st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.") st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.") - st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.") - st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.") - st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.") + st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.") + st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 
model.") + st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.") st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.") st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.") # - #st.subheader("Logs:") + # st.subheader("Logs:") st.session_state["log_message"] = st.empty() st.session_state["log_message"].code('', language="") - with col2: st.subheader("Image") refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image instead of the default placeholder.') - + if st.session_state["uploaded_image"]: #print (type(st.session_state["uploaded_image"])) - #if len(st.session_state["uploaded_image"]) == 1: + # if len(st.session_state["uploaded_image"]) == 1: st.session_state["input_image_preview"] = [] st.session_state["input_image_preview_container"] = [] st.session_state["prediction_table"] = [] st.session_state["text_result"] = [] - + for i in range(len(st.session_state["uploaded_image"])): st.session_state["input_image_preview_container"].append(i) - st.session_state["input_image_preview_container"][i]= st.empty() - + st.session_state["input_image_preview_container"][i] = st.empty() + with st.session_state["input_image_preview_container"][i].container(): - col1_output, col2_output = st.columns([2,10], gap="medium") + col1_output, col2_output = st.columns([2, 10], gap="medium") with col1_output: st.session_state["input_image_preview"].append(i) - st.session_state["input_image_preview"][i]= st.empty() + st.session_state["input_image_preview"][i] = st.empty() st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB') - + st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True) - - with st.session_state["input_image_preview_container"][i].container(): - + + with st.session_state["input_image_preview_container"][i].container(): + with col2_output: st.session_state["prediction_table"].append(i) st.session_state["prediction_table"][i] = st.empty() st.session_state["prediction_table"][i].table() - + st.session_state["text_result"].append(i) - st.session_state["text_result"][i]= st.empty() + st.session_state["text_result"][i] = st.empty() st.session_state["text_result"][i].code("", language="") - else: #st.session_state["input_image_preview"].code('', language="") @@ -373,13 +387,13 @@ def layout(): # # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. - #generate_col1.title("") - #generate_col1.title("") - generate_button = st.form_submit_button("Generate!") + # generate_col1.title("") + # generate_col1.title("") + generate_button = st.form_submit_button("Generate!") - if generate_button: + if generate_button: # if model, pipe, RealESRGAN or GFPGAN is in st.session_state remove the model and pipe form session_state so that they are reloaded. 
- if "model" in st.session_state and st.session_state["defaults"].general.optimized: + if "model" in st.session_state and st.session_state["defaults"].general.optimized: del st.session_state["model"] if "pipe" in st.session_state and st.session_state["defaults"].general.optimized: del st.session_state["pipe"] @@ -387,7 +401,6 @@ def layout(): del st.session_state["RealESRGAN"] if "GFPGAN" in st.session_state and st.session_state["defaults"].general.optimized: del st.session_state["GFPGAN"] - - + # run clip interrogator - img2txt() \ No newline at end of file + img2txt()