From 8dcd4326749508823b76b39c1eb29a81f383845c Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Mon, 24 Oct 2022 11:48:45 +0100 Subject: [PATCH] Update img2txt.py --- scripts/img2txt.py | 105 ++++++++++++++++----------------------------- 1 file changed, 36 insertions(+), 69 deletions(-) diff --git a/scripts/img2txt.py b/scripts/img2txt.py index 7bc41dc..c6050f1 100644 --- a/scripts/img2txt.py +++ b/scripts/img2txt.py @@ -61,16 +61,9 @@ from ldm.models.blip import blip_decoder device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') blip_image_eval_size = 512 -#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' -server_state["clip_models"] = {} -server_state["preprocesses"] = {} - -st.session_state["log"] = [] def load_blip_model(): logger.info("Loading BLIP Model") - st.session_state["log"].append("Loading BLIP Model") - st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') if "blip_model" not in server_state: with server_state_lock['blip_model']: @@ -79,18 +72,12 @@ def load_blip_model(): server_state["blip_model"] = server_state["blip_model"].eval() - #if not st.session_state["defaults"].general.optimized: server_state["blip_model"] = server_state["blip_model"].to(device).half() logger.info("BLIP Model Loaded") - st.session_state["log"].append("BLIP Model Loaded") - st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') else: logger.info("BLIP Model already loaded") - st.session_state["log"].append("BLIP Model already loaded") - st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') - #return server_state["blip_model"] def generate_caption(pil_image): @@ -105,7 +92,6 @@ def generate_caption(pil_image): with torch.no_grad(): caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5) - #print (caption) return caption[0] def load_list(filename): @@ -135,10 +121,6 @@ def clear_cuda(): def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size): - #logger.info("Ranking") - #st.session_state["log"].append("Ranking") - #st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') - batch_size = min(batch_size, len(text_array)) batch_count = int(len(text_array) / batch_size) batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)] @@ -148,13 +130,9 @@ def batch_rank(model, image_features, text_array, batch_size=st.session_state["d return ranks def interrogate(image, models): - - #server_state["blip_model"] = load_blip_model() logger.info("Generating Caption") - st.session_state["log"].append("Generating Caption") - st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') caption = generate_caption(image) if st.session_state["defaults"].general.optimized: @@ -162,8 +140,6 @@ def interrogate(image, models): clear_cuda() logger.info("Caption Generated") - st.session_state["log"].append("Caption Generated") - st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') if len(models) == 0: logger.info(f"\n\n{caption}") @@ -174,16 +150,9 @@ def interrogate(image, models): logger.info("Ranking Text") - #if "clip_model" in server_state: - #print (server_state["clip_model"]) - - #print (st.session_state["log_message"]) - for model_name in models: with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16): logger.info(f"Interrogating with {model_name}...") - st.session_state["log"].append(f"Interrogating with {model_name}...") - st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') if model_name not in server_state["clip_models"]: if not st.session_state["defaults"].img2txt.keep_all_models_loaded: @@ -196,15 +165,14 @@ def interrogate(image, models): del server_state["preprocesses"][model] clear_cuda() if model_name == 'ViT-H-14': - server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, - pretrained='laion2b_s32b_b79k', - cache_dir='models/clip') + server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \ + open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='models/clip') elif model_name == 'ViT-g-14': - server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, - pretrained='laion2b_s12b_b42k', - cache_dir='models/clip') + server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \ + open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='models/clip') else: - server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = clip.load(model_name, device=device, download_root='models/clip') + server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = \ + clip.load(model_name, device=device, download_root='models/clip') server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval() images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda() @@ -261,7 +229,6 @@ def interrogate(image, models): del server_state["clip_models"][model_name] gc.collect() - # for i in range(len(st.session_state["uploaded_image"])): st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame( table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"])) @@ -283,38 +250,9 @@ def interrogate(image, models): st.session_state["text_result"][st.session_state["processed_image_count"]].code( f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="") - # logger.info("Finished Interrogating.") - st.session_state["log"].append("Finished Interrogating.") - st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') - - st.session_state["log"] = [] -# - def img2txt(): - data_path = "data/" - - server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt')) - server_state["flavors"] = load_list(os.path.join(data_path, 'img2txt', 'flavors.txt')) - server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt')) - server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt')) - server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt')) - #server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt')) - #server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt')) - server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt')) - server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt')) - #server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt')) - # server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt')) - # server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt')) - - server_state["trending_list"] = [site for site in server_state["sites"]] - server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]]) - server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]]) - server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]]) - - #image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg" - models = [] if st.session_state["ViT-L/14"]: @@ -364,7 +302,36 @@ def img2txt(): def layout(): #set_page_title("Image-to-Text - Stable Diffusion WebUI") #st.info("Under Construction. :construction_worker:") - + # + if "clip_models" not in server_state: + server_state["clip_models"] = {} + if "preprocesses" not in server_state: + server_state["preprocesses"] = {} + data_path = "data/" + if "artists" not in server_state: + server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt')) + if "flavors" not in server_state: + server_state["flavors"] = random.choices(load_list(os.path.join(data_path, 'img2txt', 'flavors.txt')), k=2000) + if "mediums" not in server_state: + server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt')) + if "movements" not in server_state: + server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt')) + if "sites" not in server_state: + server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt')) + #server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt')) + #server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt')) + if "techniques" not in server_state: + server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt')) + if "tags" not in server_state: + server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt')) + #server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt')) + # server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt')) + # server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt')) + if "trending_list" not in server_state: + server_state["trending_list"] = [site for site in server_state["sites"]] + server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]]) + server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]]) + server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]]) with st.form("img2txt-inputs"): st.session_state["generation_mode"] = "img2txt"