From 466cfde3fa198b1d8d14e48fbe6d8054fe4eb54f Mon Sep 17 00:00:00 2001 From: ZeroCool940711 Date: Thu, 29 Sep 2022 07:27:56 -0700 Subject: [PATCH 1/2] More fixes for img2txt --- scripts/img2txt.py | 50 ++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/scripts/img2txt.py b/scripts/img2txt.py index 5940cb7..6d3167b 100644 --- a/scripts/img2txt.py +++ b/scripts/img2txt.py @@ -45,9 +45,8 @@ import clip import gc import os import pandas as pd -import requests +#import requests import torch -from IPython.display import display from PIL import Image #from torch import nn #from torch.nn import functional as F @@ -60,7 +59,7 @@ from ldm.models.blip import blip_decoder device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') -blip_image_eval_size = 384 +blip_image_eval_size = 256 #blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' def generate_caption(pil_image): @@ -99,7 +98,10 @@ def rank(model, image_features, text_array, top_count=1): return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)] def interrogate(image, models): + print ("Generating Caption") + st.session_state["log_message"].code("Generating Caption", language='') caption = generate_caption(image) + if len(models) == 0: print(f"\n\n{caption}") return @@ -140,7 +142,7 @@ def interrogate(image, models): del model gc.collect() - display(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"])) + st.session_state["prediction_table"].dataframe(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"])) flaves = ', '.join([f"{x[0]}" for x in bests[4]]) medium = bests[0][0][0] @@ -165,39 +167,39 @@ def img2txt(): trending_list.extend(["featured on "+site for site in server_state["sites"]]) trending_list.extend([site+" contest winner" for site in 
server_state["sites"]]) - image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg" + #image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg" - st.session_state["models"] = [] + models = [] if st.session_state["ViTB32"]: - st.session_state["models"].append('ViT-B/32') + models.append('ViT-B/32') if st.session_state['ViTB16']: - st.session_state["models"].append('ViT-B/16') + models.append('ViT-B/16') if st.session_state["ViTL14"]: - st.session_state["models"].append('ViT-L/14') + models.append('ViT-L/14') if st.session_state["ViTL14_336px"]: - st.session_state["models"].append('ViT-L/14@336px') + models.append('ViT-L/14@336px') if st.session_state["RN101"]: - st.session_state["models"].append('RN101') + models.append('RN101') if st.session_state["RN50"]: - st.session_state["models"].append('RN50') + models.append('RN50') if st.session_state["RN50x4"]: - st.session_state["models"].append('RN50x4') + models.append('RN50x4') if st.session_state["RN50x16"]: - st.session_state["models"].append('RN50x16') + models.append('RN50x16') if st.session_state["RN50x64"]: - st.session_state["models"].append('RN50x64') + models.append('RN50x64') - if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'): - image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB') - else: - image = Image.open(image_path_or_url).convert('RGB') + #if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'): + #image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB') + #else: + #image = Image.open(image_path_or_url).convert('RGB') - thumb = image.copy() - thumb.thumbnail([blip_image_eval_size, blip_image_eval_size]) + #thumb = st.session_state["uploaded_image"].image.copy() + #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size]) #display(thumb) - interrogate(image, models=st.session_state["models"]) + interrogate(st.session_state["uploaded_image"].pil_image, 
models=models) # def layout(): @@ -241,9 +243,9 @@ def layout(): st.session_state["input_image_preview"] = st.empty() if st.session_state["uploaded_image"]: - st.session_state["uploaded_image"].image = Image.open(st.session_state["uploaded_image"]).convert('RGBA') + st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"])#.convert('RGBA') #new_img = image.resize((width, height)) - st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].image, clamp=True) + st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, clamp=True) else: #st.session_state["input_image_preview"].code('', language="") st.image("images/streamlit/img2txt_placeholder.png", clamp=True) From 5445287bcf5184d7d225695a01d69ab101c80ae7 Mon Sep 17 00:00:00 2001 From: ZeroCool940711 Date: Thu, 29 Sep 2022 08:52:46 -0700 Subject: [PATCH 2/2] Img2txt working now but needs more than 8GB of VRAM to work. Will be trying to improve it as the next step. 
--- scripts/img2txt.py | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/scripts/img2txt.py b/scripts/img2txt.py index 6d3167b..a40eb6e 100644 --- a/scripts/img2txt.py +++ b/scripts/img2txt.py @@ -60,7 +60,22 @@ from ldm.models.blip import blip_decoder device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') blip_image_eval_size = 256 -#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' +#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' + +def load_blip_model(): + blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json") + blip_model.eval() + + return blip_model + +def load_clip_model(clip_model_name): + import clip + + model, preprocess = clip.load(clip_model_name) + model.eval() + model = model.to(device) + + return model, preprocess def generate_caption(pil_image): blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json") @@ -109,8 +124,8 @@ def interrogate(image, models): table = [] bests = [[('',0)]]*5 for model_name in models: - print(f"Interrogating with {model_name}...") - model, preprocess = clip.load(model_name) + st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='') + model, preprocess = load_clip_model(model_name) model.cuda().eval() images = preprocess(image).unsqueeze(0).cuda() @@ -147,10 +162,11 @@ def interrogate(image, models): flaves = ', '.join([f"{x[0]}" for x in bests[4]]) medium = bests[0][0][0] if caption.startswith(medium): - print(f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}") + st.session_state["text_result"].code(f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, 
{bests[3][0][0]}, {flaves}", language="") else: - print(f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}") - + st.session_state["text_result"].code(f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="") + + st.session_state["log_message"].code("Finished Interrogating.", language="") # def img2txt(): @@ -162,10 +178,10 @@ def img2txt(): server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt')) server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt')) - trending_list = [site for site in server_state["sites"]] - trending_list.extend(["trending on "+site for site in server_state["sites"]]) - trending_list.extend(["featured on "+site for site in server_state["sites"]]) - trending_list.extend([site+" contest winner" for site in server_state["sites"]]) + server_state["trending_list"] = [site for site in server_state["sites"]] + server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]]) + server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]]) + server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]]) #image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg"