From aa858e01149d9f32a7dcdf7c766f4355664287de Mon Sep 17 00:00:00 2001
From: hlky <106811348+hlky@users.noreply.github.com>
Date: Fri, 30 Sep 2022 19:52:14 +0100
Subject: [PATCH 1/2] Update img2txt.py

---
 scripts/img2txt.py | 103 +++++++++++++++++++++++++++------------------
 1 file changed, 62 insertions(+), 41 deletions(-)

diff --git a/scripts/img2txt.py b/scripts/img2txt.py
index 4111001..1b06aed 100644
--- a/scripts/img2txt.py
+++ b/scripts/img2txt.py
@@ -42,6 +42,7 @@ import streamlit_nested_layout
 
 #other imports
 import clip
+import open_clip
 import gc
 import os
 import pandas as pd
@@ -60,14 +61,14 @@ from ldm.models.blip import blip_decoder
 #---------------------------------------------------------------------------------------------------------------
 
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-blip_image_eval_size = 256
+blip_image_eval_size = 512
 blip_model = None
 #blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
 
 def load_blip_model():
     blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
     blip_model.eval()
-    blip_model = blip_model.to(device).half()
+    blip_model = blip_model.to('cpu')
 
     return blip_model
 
@@ -77,7 +78,7 @@ def generate_caption(pil_image):
         transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
         transforms.ToTensor(),
         transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-    ])(pil_image).unsqueeze(0).to(device).half()
+    ])(pil_image).unsqueeze(0).to('cpu')
 
     with torch.no_grad():
         caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
@@ -107,6 +108,15 @@ def clear_cuda():
     torch.cuda.empty_cache()
     gc.collect()
 
+def batch_rank(model, image_features, text_array, batch_size=42):
+    batch_count = len(text_array) // batch_size
+    batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
+    batches.append(text_array[batch_count*batch_size:])
+    ranks = []
+    for batch in batches:
+        ranks += rank(model, image_features, batch)
+    return ranks
+
 def interrogate(image, models):
     global blip_model
     blip_model = load_blip_model()
@@ -126,47 +136,51 @@ def interrogate(image, models):
     for model_name in models:
         print(f"Interrogating {model_name}")
         st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
-        model, preprocess = clip.load(model_name)
-        model.cuda().eval()
+        if model_name == 'ViT-H-14':
+            model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k')
+        elif model_name == 'ViT-g-14':
+            model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k')
+        else:
+            model, preprocess = clip.load(model_name, device=device)
+        model.to(device).half()
 
-        images = preprocess(image).unsqueeze(0).cuda()
-        with torch.no_grad():
-            image_features = model.encode_image(images).float()
-        image_features /= image_features.norm(dim=-1, keepdim=True)
-        clear_cuda()
-
-        ranks = []
-        ranks.append(rank(model, image_features, server_state["mediums"]))
-        clear_cuda()
-        artists = []
-        for batch in range(int(len(server_state["artists"])/1000)):
-            artist_rank = rank(model, image_features, server_state["artists"][batch*1000:(batch+1)*1000])
-            artists.extend(artist_rank)
+        model.eval()
+        with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
+            image_features = model.encode_image(preprocess(image).unsqueeze(0).to(device)).float()
+            image_features /= image_features.norm(dim=-1, keepdim=True)
             clear_cuda()
-        ranks.append(artists)
-        ranks.append(rank(model, image_features, server_state["trending_list"]))
-        clear_cuda()
-        ranks.append(rank(model, image_features, server_state["movements"]))
-        clear_cuda()
-        ranks.append(rank(model, image_features, server_state["flavors"], top_count=3))
-        clear_cuda()
+
+            ranks = []
+            ranks.append(batch_rank(model, image_features, server_state["mediums"]))
+            ranks.append(batch_rank(model, image_features, ["by "+artist for artist in server_state["artists"]]))
+            ranks.append(batch_rank(model, image_features, server_state["trending_list"]))
+            ranks.append(batch_rank(model, image_features, server_state["movements"]))
+            ranks.append(batch_rank(model, image_features, server_state["flavors"]))
+            # ranks.append(batch_rank(model, image_features, server_state["genres"]))
+            # ranks.append(batch_rank(model, image_features, server_state["styles"]))
+            # ranks.append(batch_rank(model, image_features, server_state["techniques"]))
+            # ranks.append(batch_rank(model, image_features, server_state["subjects"]))
+            # ranks.append(batch_rank(model, image_features, server_state["colors"]))
+            # ranks.append(batch_rank(model, image_features, server_state["moods"]))
+            # ranks.append(batch_rank(model, image_features, server_state["themes"]))
+            # ranks.append(batch_rank(model, image_features, server_state["keywords"]))
+            # ranks.append(batch_rank(model, image_features, server_state["artists"]))
+            for i in range(len(ranks)):
+                confidence_sum = 0
+                for ci in range(len(ranks[i])):
+                    confidence_sum += ranks[i][ci][1]
+                if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
+                    bests[i] = ranks[i]
 
-        for i in range(len(ranks)):
-            confidence_sum = 0
-            for ci in range(len(ranks[i])):
-                confidence_sum += ranks[i][ci][1]
-            if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
-                bests[i] = ranks[i]
+            row = [model_name]
+            for r in ranks:
+                row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
 
-        row = [model_name]
-        for r in ranks:
-            row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
+            table.append(row)
 
-        table.append(row)
-
-        del model
-        gc.collect()
+            del model
+            gc.collect()
 
     st.session_state["prediction_table"].dataframe(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
 
@@ -204,6 +218,10 @@ def img2txt():
         models.append('ViT-B/16')
     if st.session_state["ViTL14"]:
         models.append('ViT-L/14')
+    if st.session_state["ViT-H-14"]:
+        models.append('ViT-H-14')
+    if st.session_state["ViT-g-14"]:
+        models.append('ViT-g-14')
     if st.session_state["ViTL14_336px"]:
         models.append('ViT-L/14@336px')
     if st.session_state["RN101"]:
@@ -226,7 +244,7 @@ def img2txt():
     #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
     #display(thumb)
 
-    interrogate(st.session_state["uploaded_image"].pil_image, models=models)
+    interrogate(st.session_state["uploaded_image"].pil_image.convert('RGB'), models)
 
 #
 
 def layout():
@@ -259,7 +277,10 @@ def layout():
                 st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
                 st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
                 st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
-                st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
+                st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
+                st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
+                st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
+
 
 
             with col2:
@@ -270,7 +291,7 @@ def layout():
             st.session_state["input_image_preview"] = st.empty()
 
             if st.session_state["uploaded_image"]:
-                st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"])#.convert('RGBA')
+                st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"])
                 #new_img = image.resize((width, height))
                 st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, clamp=True)
             else:
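
Note on PATCH 1/2: the interrogation rewrite comes down to three moves. ViT-H-14 and ViT-g-14 load through open_clip with their LAION-2B checkpoints while every other architecture still goes through the original clip package; the model itself (never the transform tuple that create_model_and_transforms returns) is moved to the GPU in half precision; and each prompt list is scored in fixed-size chunks instead of one giant tokenized batch. The sketch below restates that flow in isolation. load_interrogator and rank_texts are illustrative stand-ins, not helpers from this repository, and the stride-based loop avoids the empty trailing slice that the patch's batch_rank hands to rank() whenever len(text_array) is an exact multiple of batch_size.

import torch
import clip        # OpenAI CLIP, used for the ViT-B/L and RN checkpoints
import open_clip   # OpenCLIP, used for the LAION ViT-H-14 / ViT-g-14 weights

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def load_interrogator(model_name):
    # create_model_and_transforms returns (model, train_tf, eval_tf), so the
    # device/precision move has to target the model, never the tuple.
    if model_name == 'ViT-H-14':
        model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k')
    elif model_name == 'ViT-g-14':
        model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k')
    else:
        model, preprocess = clip.load(model_name, device=device)
    return model.to(device).half().eval(), preprocess

def rank_texts(model, image_features, texts):
    # Cosine similarity between one normalized image embedding and a batch of
    # prompts; clip.tokenize works for both packages here because these
    # OpenCLIP models use the same 77-token BPE vocabulary as OpenAI CLIP.
    tokens = clip.tokenize(texts).to(device)
    with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
        text_features = model.encode_text(tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (image_features @ text_features.T)[0]
    best = similarity.argmax().item()
    return [(texts[best], 100.0 * similarity[best].item())]

def batch_rank(model, image_features, text_array, batch_size=42):
    # Chunk the prompt list so tokenization and encode_text fit in VRAM.
    ranks = []
    for i in range(0, len(text_array), batch_size):
        ranks += rank_texts(model, image_features, text_array[i:i + batch_size])
    return ranks

The batch size of 42 matches the patch default; any value that keeps the tokenized batch within the card's VRAM works equally well.
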
st.checkbox("RN101", value=False, help="RN101 model.") + st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.") + st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.") + with col2: @@ -270,7 +291,7 @@ def layout(): st.session_state["input_image_preview"] = st.empty() if st.session_state["uploaded_image"]: - st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"])#.convert('RGBA') + st.session_state["uploaded_image"].pil_image = Image.open(st.session_state["uploaded_image"]) #new_img = image.resize((width, height)) st.session_state["input_image_preview"].image(st.session_state["uploaded_image"].pil_image, clamp=True) else: From f1a98997d6f2409f533f461dfb589e0204df31de Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Fri, 30 Sep 2022 19:54:01 +0100 Subject: [PATCH 2/2] Update img2txt.py --- scripts/img2txt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/img2txt.py b/scripts/img2txt.py index 1b06aed..39cd786 100644 --- a/scripts/img2txt.py +++ b/scripts/img2txt.py @@ -68,7 +68,7 @@ blip_model = None def load_blip_model(): blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth", image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json") blip_model.eval() - blip_model = blip_model.to('cpu') + blip_model = blip_model.to(device).half() return blip_model @@ -78,7 +78,7 @@ def generate_caption(pil_image): transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ])(pil_image).unsqueeze(0).to('cpu') + ])(pil_image).unsqueeze(0).to(device).half() with torch.no_grad(): caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5) @@ -123,7 +123,7 @@ def interrogate(image, models): print ("Generating Caption") st.session_state["log_message"].code("Generating Caption", language='') caption = generate_caption(image) - del blip_model + blip_model.to('cpu') clear_cuda() print ("Caption Generated")