stable-diffusion-webui/scripts/img2txt.py

461 lines
22 KiB
Python
Raw Normal View History

# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# ---------------------------------------------------------------------------------------------------------------------------------------------------
"""
CLIP Interrogator by @pharmapsychotic, modified to work with our WebUI.
# CLIP Interrogator by @pharmapsychotic
Twitter: https://twitter.com/pharmapsychotic
Github: https://github.com/pharmapsychotic/clip-interrogator
Description:
What do the different OpenAI CLIP models see in an image? What might be a good text prompt to create similar images using CLIP guided diffusion
or another text to image model? The CLIP Interrogator is here to get you answers!
Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following him on [twitter](https://twitter.com/pharmapsychotic).
And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).
"""
2022-10-01 01:20:02 +03:00
# ---------------------------------------------------------------------------------------------------------------------------------------------------
# base webui import and utils.
from sd_utils import *
# streamlit imports
2022-10-01 01:20:02 +03:00
# streamlit components section
import streamlit_nested_layout
2022-10-01 01:20:02 +03:00
# other imports
2022-10-01 01:20:02 +03:00
import clip
import open_clip
import gc
import os
import pandas as pd
2022-09-29 17:27:56 +03:00
#import requests
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from ldm.models.blip import blip_decoder
2022-10-18 01:59:29 +03:00
#import hashlib
# end of imports
2022-10-01 01:20:02 +03:00
# ---------------------------------------------------------------------------------------------------------------
# Use the first CUDA device when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Edge length, in pixels, of the square image given to the BLIP captioning model.
blip_image_eval_size = 512

# Start with an empty list of status lines for the on-page log.
st.session_state["log"] = []
def load_blip_model():
    """Make sure the BLIP captioning model is available in ``server_state``.

    When ``server_state`` has no ``"blip_model"`` entry the model is built
    with ``blip_decoder`` while holding ``server_state_lock['blip_model']``,
    switched to eval mode, moved to ``device`` and converted to half
    precision. Each step is reported through ``logger`` and through the
    ``st.session_state["log_message"]`` widget.
    """
    def _publish(line):
        # Append the line to the session log and redraw the log widget.
        st.session_state["log"].append(line)
        st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')

    logger.info("Loading BLIP Model")
    if "log" not in st.session_state:
        st.session_state["log"] = []
    _publish("Loading BLIP Model")

    # Nothing to do when the model is already cached.
    if "blip_model" in server_state:
        logger.info("BLIP Model already loaded")
        _publish("BLIP Model already loaded")
        return

    with server_state_lock['blip_model']:
        server_state["blip_model"] = blip_decoder(
            pretrained="models/blip/model__base_caption.pth",
            image_size=blip_image_eval_size,
            vit='base',
            med_config="configs/blip/med_config.json",
        )
        server_state["blip_model"] = server_state["blip_model"].eval()
        server_state["blip_model"] = server_state["blip_model"].to(device).half()

    logger.info("BLIP Model Loaded")
    _publish("BLIP Model Loaded")
2022-10-01 01:20:02 +03:00
def generate_caption(pil_image):
    """Return a BLIP caption for ``pil_image``.

    The image is resized to a ``blip_image_eval_size`` square with bicubic
    interpolation, converted to a tensor, normalised with fixed per-channel
    mean and standard deviation, and sent to ``device`` in half precision as
    a batch of one. The caption is produced with beam search (3 beams,
    5 to 20 tokens) and the first result is returned.
    """
    load_blip_model()

    target_size = (blip_image_eval_size, blip_image_eval_size)
    preprocess = transforms.Compose([  # type: ignore
        transforms.Resize(target_size, interpolation=InterpolationMode.BICUBIC),  # type: ignore
        transforms.ToTensor(),  # type: ignore
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # type: ignore
    ])

    # Add the batch dimension, then match the model's device and precision.
    batch = preprocess(pil_image).unsqueeze(0)
    batch = batch.to(device).half()

    with torch.no_grad():
        captions = server_state["blip_model"].generate(batch, sample=False, num_beams=3, max_length=20, min_length=5)

    return captions[0]
def load_list(filename):
    """Read ``filename`` as UTF-8 text and return its lines, whitespace-stripped.

    Undecodable bytes are replaced rather than raising. Blank lines are kept
    as empty strings.
    """
    with open(filename, 'r', encoding='utf-8', errors='replace') as handle:
        return [entry.strip() for entry in handle]
2022-10-01 01:20:02 +03:00
def rank(model, image_features, text_array, top_count=1):
    """Score candidate phrases against image embeddings and return the best.

    Args:
        model: Model whose ``encode_text`` method embeds the tokenised
            candidates.
        image_features: 2-D tensor with one embedding per row; the softmax
            scores of all rows are averaged. (Assumed to be normalised by
            the caller - TODO confirm.)
        text_array: List of candidate strings.
        top_count: How many results to return; clamped to
            ``len(text_array)``.

    Returns:
        A list of ``(text, confidence)`` tuples, highest confidence first,
        where confidence is the averaged softmax probability times 100.
        An empty list is returned when there is nothing to rank.
    """
    top_count = min(top_count, len(text_array))
    if top_count <= 0:
        # Nothing to tokenise or rank.
        return []

    # Use the module-wide ``device`` instead of a hard-coded ``.cuda()`` so
    # the tokens end up on the same device as ``similarity`` below.
    text_tokens = clip.tokenize(list(text_array)).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

    similarity = torch.zeros((1, len(text_array)), device=device)
    for features in image_features:
        similarity += (100.0 * features.unsqueeze(0) @ text_features.T).softmax(dim=-1)
    similarity /= image_features.shape[0]

    top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
    return [
        (text_array[int(label)], prob.numpy() * 100)
        for prob, label in zip(top_probs[0], top_labels[0])
    ]
2022-10-01 01:20:02 +03:00
2022-09-30 12:40:02 +03:00
def clear_cuda():
    """Run the garbage collector, then release PyTorch's cached CUDA memory.

    Garbage is collected first so that tensors which are only kept alive by
    reference cycles are destroyed before the caching allocator is asked to
    give its unused blocks back.
    """
    gc.collect()
    torch.cuda.empty_cache()
2022-10-01 01:20:02 +03:00
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
    """Rank ``text_array`` in chunks and return the top hit of every chunk.

    Args:
        model: Passed through to ``rank``.
        image_features: Passed through to ``rank``.
        text_array: List of candidate strings.
        batch_size: Maximum number of candidates handed to ``rank`` at once.
            The default is read from the session defaults when this module
            is imported.

    Returns:
        The concatenated results of ``rank`` for each chunk, in chunk order.
        An empty list is returned for an empty ``text_array``.
    """
    if not text_array:
        return []

    # Keep the chunk size within [1, len(text_array)] so a zero or negative
    # setting cannot cause a division by zero or an endless loop.
    batch_size = max(1, min(batch_size, len(text_array)))

    ranks = []
    # Stepping by batch_size also covers the final, shorter chunk, which the
    # previous floor division silently dropped.
    for start in range(0, len(text_array), batch_size):
        ranks += rank(model, image_features, text_array[start:start + batch_size])
    return ranks
# Pretrained weight tags for the CLIP variants that are loaded through open_clip.
_OPEN_CLIP_PRETRAINED = {
    'ViT-H-14': 'laion2b_s32b_b79k',
    'ViT-g-14': 'laion2b_s12b_b42k',
}


def _log_progress(message):
    """Send ``message`` to the logger and show it in the on-page log widget."""
    logger.info(message)
    st.session_state["log"].append(message)
    st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')


def _ensure_clip_model(model_name):
    """Load ``model_name`` and its preprocess into ``server_state`` if missing."""
    if model_name in server_state["clip_models"]:
        return

    if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
        # Drop every other cached model before loading the new one.
        for stale_name in list(server_state["clip_models"]):
            del server_state["clip_models"][stale_name]
            server_state["preprocesses"].pop(stale_name, None)
        clear_cuda()

    pretrained = _OPEN_CLIP_PRETRAINED.get(model_name)
    if pretrained is not None:
        clip_model, _, preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained, cache_dir='models/clip')
    else:
        clip_model, preprocess = clip.load(model_name, device=device, download_root='models/clip')

    server_state["clip_models"][model_name] = clip_model.to(device).eval()
    server_state["preprocesses"][model_name] = preprocess


def interrogate(image, models):
    """Caption ``image`` with BLIP and extend the caption using CLIP rankings.

    For every model in ``models`` the image is embedded and ranked against
    the phrase lists held in ``server_state`` (mediums, artists, trending,
    movements, flavors, techniques, tags). One table row per model is shown
    in ``st.session_state["prediction_table"]`` and the final prompt is
    written to ``st.session_state["text_result"]``, both indexed by
    ``st.session_state["processed_image_count"]``.

    Args:
        image: Image handed to ``generate_caption`` and to each model's
            preprocess function (presumably a PIL image - confirm with the
            caller).
        models: List of CLIP model names. When empty, only the caption is
            logged and nothing is written to the page.

    Returns:
        None.
    """
    load_blip_model()

    _log_progress("Generating Caption")
    caption = generate_caption(image)

    optimized = st.session_state["defaults"].general.optimized
    if optimized:
        # Free the captioning model before the CLIP models are loaded.
        del server_state["blip_model"]
        clear_cuda()

    _log_progress("Caption Generated")

    if not models:
        logger.info(f"\n\n{caption}")
        return

    # Phrase lists in the same order as the table columns below. The other
    # lists (domains, subreddits, genres, styles, subjects, colors, moods,
    # themes, keywords) are not ranked.
    category_lists = [
        server_state["mediums"],
        ["by " + artist for artist in server_state["artists"]],
        server_state["trending_list"],
        server_state["movements"],
        server_state["flavors"],
        server_state["techniques"],
        server_state["tags"],
    ]

    table = []
    # One independent placeholder per category. ``[[...]] * 7`` would make
    # all seven entries refer to the same inner list.
    bests = [[('', 0)] for _ in category_lists]

    _log_progress("Ranking Text")

    for model_name in models:
        with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
            _log_progress(f"Interrogating with {model_name}...")

            _ensure_clip_model(model_name)
            clip_model = server_state["clip_models"][model_name]

            images = server_state["preprocesses"][model_name](image).unsqueeze(0).to(device)
            image_features = clip_model.encode_image(images).float()
            image_features /= image_features.norm(dim=-1, keepdim=True)

            if optimized:
                clear_cuda()

            # Sort every category best-first once, so the table shows the
            # same ordering for every model and ``bests`` never has to be
            # sorted in place afterwards.
            ranks = [
                sorted(batch_rank(clip_model, image_features, candidates),
                       key=lambda x: x[1], reverse=True)
                for candidates in category_lists
            ]

            # Keep, per category, the ranking with the highest total
            # confidence seen so far. Only its first entry is used below.
            for index, ranking in enumerate(ranks):
                if sum(score for _, score in ranking) > sum(score for _, score in bests[index]):
                    bests[index] = ranking

            row = [model_name]
            for r in ranks:
                row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
            table.append(row)

            if optimized:
                # Drop the local reference as well, otherwise the model
                # stays alive until the next loop iteration.
                del clip_model
                del server_state["clip_models"][model_name]
                clear_cuda()

    st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
        table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"]))

    medium = bests[0][0][0]
    artist = bests[1][0][0]
    trending = bests[2][0][0]
    movement = bests[3][0][0]
    flavors = bests[4][0][0]
    techniques = bests[5][0][0]
    tags = bests[6][0][0]

    result_slot = st.session_state["text_result"][st.session_state["processed_image_count"]]
    if caption.startswith(medium):
        # The caption already names the medium, so it is not repeated.
        result_slot.code(
            f"\n\n{caption} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
    else:
        result_slot.code(
            f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")

    _log_progress("Finished Interrogating.")
2022-10-01 01:20:02 +03:00
def img2txt():
    """Interrogate every uploaded image with the CLIP models ticked in the UI.

    The model list is built from the checkbox values stored in
    ``st.session_state``. ``st.session_state["processed_image_count"]`` is
    reset to 0 and incremented after each image.
    """
    # (session-state key, CLIP model name) pairs, in the order the models run.
    checkbox_to_model = (
        ("ViT-L/14", 'ViT-L/14'),
        ("ViT-H-14", 'ViT-H-14'),
        ("ViT-g-14", 'ViT-g-14'),
        ("ViTB32", 'ViT-B/32'),
        ('ViTB16', 'ViT-B/16'),
        ("ViTL14_336px", 'ViT-L/14@336px'),
        ("RN101", 'RN101'),
        ("RN50", 'RN50'),
        ("RN50x4", 'RN50x4'),
        ("RN50x16", 'RN50x16'),
        ("RN50x64", 'RN50x64'),
    )
    models = [model_name for key, model_name in checkbox_to_model if st.session_state[key]]

    st.session_state["processed_image_count"] = 0

    for upload in st.session_state["uploaded_image"]:
        interrogate(upload.pil_image, models=models)
        # Move on to the output slots of the next image.
        st.session_state["processed_image_count"] += 1
#
2022-10-01 01:20:02 +03:00
def layout():
    """Build the Image-to-Text page and run the interrogator on "Generate!".

    The function first fills ``server_state`` with the model caches and the
    phrase lists read from ``data/img2txt``, then draws a form with the
    image uploader, the CLIP model checkboxes and a log area on the left,
    and the image previews with their output slots on the right. When the
    "Generate!" button is pressed it calls ``img2txt()``.
    """
    #set_page_title("Image-to-Text - Stable Diffusion WebUI")
    #st.info("Under Construction. :construction_worker:")
    #
    # Create the model caches on first use. Each entry is only created when
    # its key is missing from server_state.
    if "clip_models" not in server_state:
        server_state["clip_models"] = {}
    if "preprocesses" not in server_state:
        server_state["preprocesses"] = {}
    data_path = "data/"
    # Phrase lists used for ranking; each file is read only when its key is
    # not yet present in server_state.
    if "artists" not in server_state:
        server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt'))
    if "flavors" not in server_state:
        # random.choices draws with replacement, so the 2000 entries may
        # contain duplicates.
        # NOTE(review): `random` is not imported in this file; presumably it
        # comes from the `sd_utils` star import - confirm.
        server_state["flavors"] = random.choices(load_list(os.path.join(data_path, 'img2txt', 'flavors.txt')), k=2000)
    if "mediums" not in server_state:
        server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
    if "movements" not in server_state:
        server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
    if "sites" not in server_state:
        server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
    #server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
    #server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
    if "techniques" not in server_state:
        server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
    if "tags" not in server_state:
        server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
    #server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))
    # server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt'))
    # server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt'))
    if "trending_list" not in server_state:
        # The plain site names plus three phrase variants of every site.
        server_state["trending_list"] = [site for site in server_state["sites"]]
        server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
        server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]])
        server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]])
    with st.form("img2txt-inputs"):
        st.session_state["generation_mode"] = "img2txt"
        # st.write("---")
        # creating the page layout using columns
        col1, col2 = st.columns([1, 4], gap="large")
        with col1:
            # Several images can be uploaded at once; the widget's return
            # value is stored as-is.
            st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg', 'jfif', 'webp'], accept_multiple_files=True)
            # The checkbox values are stored under the keys that img2txt()
            # reads when it builds the model list.
            with st.expander("CLIP models", expanded=True):
                st.session_state["ViT-L/14"] = st.checkbox("ViT-L/14", value=True, help="ViT-L/14 model.")
                st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
                st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
            with st.expander("Others"):
                st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.")
                st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.")
                st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.")
                st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.")
                st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.")
                st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
                st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
                st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
                st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
            #
            # st.subheader("Logs:")
            # Placeholder that load_blip_model() and interrogate() redraw
            # with the accumulated log lines.
            st.session_state["log_message"] = st.empty()
            st.session_state["log_message"].code('', language="")
        with col2:
            st.subheader("Image")
            image_col1, image_col2 = st.columns([10,25])
            with image_col1:
                # The return value is not used in this function.
                refresh = st.form_submit_button("Update Preview Image", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')
            if st.session_state["uploaded_image"]:
                #print (type(st.session_state["uploaded_image"]))
                # if len(st.session_state["uploaded_image"]) == 1:
                # Per-image placeholders, rebuilt on every run; index i
                # belongs to the i-th uploaded image.
                st.session_state["input_image_preview"] = []
                st.session_state["input_image_preview_container"] = []
                st.session_state["prediction_table"] = []
                st.session_state["text_result"] = []
                for i in range(len(st.session_state["uploaded_image"])):
                    # The appended index only reserves slot i; it is
                    # replaced by the placeholder on the next line.
                    st.session_state["input_image_preview_container"].append(i)
                    st.session_state["input_image_preview_container"][i] = st.empty()
                    with st.session_state["input_image_preview_container"][i].container():
                        col1_output, col2_output = st.columns([2, 10], gap="medium")
                        with col1_output:
                            st.session_state["input_image_preview"].append(i)
                            st.session_state["input_image_preview"][i] = st.empty()
                            # The decoded RGB image is attached to the upload
                            # object; img2txt() reads it back as .pil_image.
                            st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
                            st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
                    with st.session_state["input_image_preview_container"][i].container():
                        with col2_output:
                            # Empty slots that interrogate() fills with the
                            # ranking table and the generated prompt.
                            st.session_state["prediction_table"].append(i)
                            st.session_state["prediction_table"][i] = st.empty()
                            st.session_state["prediction_table"][i].table()
                            st.session_state["text_result"].append(i)
                            st.session_state["text_result"][i] = st.empty()
                            st.session_state["text_result"][i].code("", language="")
            else:
                #st.session_state["input_image_preview"].code('', language="")
                st.image("images/streamlit/img2txt_placeholder.png", clamp=True)
            with image_col2:
                #
                # Every form must have a submit button. The extra blank space is a temporary way to align it with the input field; it should be done in CSS or some other way.
                # generate_col1.title("")
                # generate_col1.title("")
                generate_button = st.form_submit_button("Generate!", help="Start interrogating the images to generate a prompt from each of the selected images")
        if generate_button:
            # In optimized mode, remove model, pipe, RealESRGAN and GFPGAN from server_state before running the interrogator.
            if "model" in server_state and st.session_state["defaults"].general.optimized:
                del server_state["model"]
            if "pipe" in server_state and st.session_state["defaults"].general.optimized:
                del server_state["pipe"]
            if "RealESRGAN" in server_state and st.session_state["defaults"].general.optimized:
                del server_state["RealESRGAN"]
            if "GFPGAN" in server_state and st.session_state["defaults"].general.optimized:
                del server_state["GFPGAN"]
            # run clip interrogator
            img2txt()