2022-09-28 19:33:54 +03:00
# This file is part of stable-diffusion-webui (https://github.com/sd-webui/stable-diffusion-webui/).
# Copyright 2022 sd-webui team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
CLIP Interrogator made by @pharmapsychotic, modified to work with our WebUI.

# CLIP Interrogator by @pharmapsychotic
Twitter: https://twitter.com/pharmapsychotic
Github: https://github.com/pharmapsychotic/clip-interrogator

Description:
What do the different OpenAI CLIP models see in an image? What might be a good text prompt to create similar images using CLIP guided diffusion
or another text to image model? The CLIP Interrogator is here to get you answers!

Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following him on [twitter](https://twitter.com/pharmapsychotic).

And if you're looking for more Ai art tools check out his [Ai generative art tools list](https://pharmapsychotic.com/tools.html).
"""
# base webui import and utils.
from sd_utils import *
# streamlit imports
import streamlit_nested_layout
#streamlit components section
#other imports
2022-09-30 18:47:30 +03:00
import clip , open_clip
2022-09-28 22:37:15 +03:00
import gc
import os
import pandas as pd
2022-09-29 17:27:56 +03:00
#import requests
2022-09-28 22:37:15 +03:00
import torch
2022-09-30 12:40:02 +03:00
import torchvision . transforms as T
import torchvision . transforms . functional as TF
2022-09-28 22:37:15 +03:00
from PIL import Image
2022-09-30 12:40:02 +03:00
from torch import nn
from torch . nn import functional as F
2022-09-28 22:37:15 +03:00
from torchvision import transforms
from torchvision . transforms . functional import InterpolationMode
from ldm . models . blip import blip_decoder
2022-09-28 19:33:54 +03:00
# end of imports
2022-09-28 22:37:15 +03:00
# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Edge length (in pixels) handed to the BLIP decoder as its `image_size`.
blip_image_eval_size = 512

# Module-level slot for the BLIP captioning model; it stays None until a model is loaded.
blip_model = None

#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
def load_blip_model():
    """Return a BLIP captioning model ready for half-precision inference.

    If a model is already stored under ``server_state["blip_model"]`` it is
    returned as is; otherwise a new decoder is built from the local checkpoint,
    switched to eval mode, moved to ``device`` and converted to fp16.

    The freshly built model is deliberately NOT written to ``server_state``:
    this function has never done so, and pinning it there would keep it in
    memory for as long as the server state lives.

    Returns:
        The BLIP decoder model.
    """
    st.session_state["log_message"].code("Loading BLIP Model", language='')

    with server_state_lock['blip_model']:
        if "blip_model" in server_state:
            # The previous code returned an unassigned local in this branch
            # (UnboundLocalError); hand back the stored model instead.
            st.session_state["log_message"].code("BLIP Model Already Loaded", language='')
            return server_state["blip_model"]

        model = blip_decoder(
            pretrained="models/blip/model__base_caption.pth",
            image_size=blip_image_eval_size,
            vit='base',
            med_config="configs/blip/med_config.json",
        )
        model.eval()
        model = model.to(device).half()
        st.session_state["log_message"].code("BLIP Model Loaded", language='')

    return model
2022-09-29 18:52:46 +03:00
2022-09-28 22:37:15 +03:00
def generate_caption(pil_image):
    """Caption an image with the module-level BLIP model.

    Args:
        pil_image: A PIL image. It is normalised with three per-channel
            means/stds, so it is expected to have three channels
            (presumably RGB -- confirm against the caller).

    Returns:
        The first caption returned by ``blip_model.generate``.

    Raises:
        RuntimeError: If no BLIP model has been loaded into ``blip_model``.
    """
    if blip_model is None:
        raise RuntimeError("BLIP model is not loaded; call load_blip_model() first.")

    # Resize to the square resolution the decoder was constructed with.
    # The previous code passed PIL's (width, height) straight to Resize, which
    # expects (height, width), so non-square inputs came out transposed and at
    # a size unrelated to the model's `image_size`.
    to_model_input = transforms.Compose([
        transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    # Add a batch dimension and match the fp16 weights produced by load_blip_model().
    gpu_image = to_model_input(pil_image).unsqueeze(0).to(device).half()

    with torch.no_grad():
        caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)

    return caption[0]
2022-09-28 22:37:15 +03:00
def load_list(filename):
    """Read a UTF-8 text file and return its lines with surrounding whitespace stripped.

    Undecodable bytes are replaced rather than raising, and blank lines are
    kept as empty strings.

    Args:
        filename: Path of the text file to read.

    Returns:
        list[str]: One stripped entry per line of the file.
    """
    with open(filename, 'r', encoding='utf-8', errors='replace') as f:
        # Iterate the file directly instead of materialising readlines() first.
        return [line.strip() for line in f]
2022-09-28 22:37:15 +03:00
def rank(model, image_features, text_array, top_count=1):
    """Return the ``top_count`` texts that best match the given image features.

    Every text is tokenised with ``clip.tokenize`` and encoded with
    ``model.encode_text``. For each row of ``image_features`` the scaled
    similarities are turned into a softmax distribution over the texts, and
    those distributions are averaged across rows.

    Args:
        model: Object exposing ``encode_text`` (a loaded CLIP model).
        image_features: 2-D tensor of image features, one row per image; the
            rows are presumably already L2-normalised by the caller -- confirm.
        text_array: List of candidate strings.
        top_count: Number of matches to return (capped at ``len(text_array)``).

    Returns:
        list[tuple[str, float]]: ``(text, confidence_percent)`` pairs, best first.
    """
    if not text_array:
        # Nothing to rank; avoid tokenising an empty batch.
        return []

    top_count = min(top_count, len(text_array))
    # Use the module-wide `device` rather than a hard-coded .cuda(), so the
    # tokens live on the same device as the similarity accumulator below.
    text_tokens = clip.tokenize(list(text_array)).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

    similarity = torch.zeros((1, len(text_array)), device=device)
    for features in image_features:
        similarity += (100.0 * features.unsqueeze(0) @ text_features.T).softmax(dim=-1)
    similarity /= image_features.shape[0]

    top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
    # Convert to plain Python int/float so the list index and the returned
    # confidence do not depend on 0-d numpy arrays.
    return [
        (text_array[int(top_labels[0][i])], float(top_probs[0][i]) * 100)
        for i in range(top_count)
    ]
def clear_cuda():
    """Run the garbage collector, then release PyTorch's cached CUDA memory.

    Collecting first means tensors that were only kept alive by unreachable
    Python objects are freed before the cache is emptied, so their blocks can
    be released as well (the previous order emptied the cache first).
    """
    gc.collect()
    torch.cuda.empty_cache()
2022-09-28 22:37:15 +03:00
def _rank_in_batches(model, image_features, entries, batch_size=1000):
    """Rank `entries` slice by slice and return each slice's best match, most confident first."""
    results = []
    # Step through the list so the final, shorter slice is ranked too
    # (the previous int(len / 1000) loop silently dropped it).
    for start in range(0, len(entries), batch_size):
        results.extend(rank(model, image_features, entries[start:start + batch_size]))
        clear_cuda()
    # The prompt uses the first entry, so put the most confident batch winner
    # first. NOTE: scores are per-batch softmax values, so this ordering is a
    # heuristic rather than an exact global ranking.
    results.sort(key=lambda match: match[1], reverse=True)
    return results


def interrogate(image, models):
    """Caption `image` with BLIP, rank prompt fragments with each CLIP model, and show the results.

    Writes progress to ``st.session_state["log_message"]``, the per-model
    ranking table to ``st.session_state["prediction_table"]`` and the final
    prompt to ``st.session_state["text_result"]``. The candidate lists are read
    from ``server_state`` ("mediums", "artists", "trending_list", "movements",
    "flavors").

    Args:
        image: Image passed to ``generate_caption`` and to each model's
            preprocess transform (a PIL image in the visible caller).
        models: List of CLIP / open_clip model names. When empty, only the
            caption is produced.
    """
    global blip_model
    blip_model = load_blip_model()

    print("Generating Caption")
    st.session_state["log_message"].code("Generating Caption", language='')
    caption = generate_caption(image)

    # Drop the captioning model before the CLIP models are loaded. Rebinding to
    # None (instead of `del`) keeps the module-level name defined.
    blip_model = None
    clear_cuda()
    print("Caption Generated")

    if not models:
        print(f"\n\n{caption}")

    # Pretrained weights for the model names that are loaded through open_clip.
    open_clip_weights = {
        'ViT-H-14': 'laion2b_s32b_b79k',
        'ViT-g-14': 'laion2b_s12b_b42k',
    }

    table = []
    # One independent "best so far" list per category:
    # medium, artist, trending, movement, flavors.
    bests = [[('', 0)] for _ in range(5)]

    for model_name in models:
        print(f"Interrogating with {model_name}")
        st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')

        if model_name in open_clip_weights:
            model, _, preprocess = open_clip.create_model_and_transforms(
                model_name, pretrained=open_clip_weights[model_name])
        else:
            model, preprocess = clip.load(model_name, device=device)
        model = model.to(device).eval()

        images = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(images).float()
        image_features /= image_features.norm(dim=-1, keepdim=True)
        clear_cuda()

        ranks = [rank(model, image_features, server_state["mediums"])]
        clear_cuda()
        ranks.append(_rank_in_batches(model, image_features, server_state["artists"]))
        ranks.append(rank(model, image_features, server_state["trending_list"]))
        clear_cuda()
        ranks.append(rank(model, image_features, server_state["movements"]))
        clear_cuda()
        ranks.append(rank(model, image_features, server_state["flavors"], top_count=3))
        clear_cuda()

        # Keep, per category, the ranking with the highest total confidence
        # seen across all models so far.
        for i, category in enumerate(ranks):
            if sum(score for _, score in category) > sum(score for _, score in bests[i]):
                bests[i] = category

        row = [model_name]
        row.extend(', '.join(f"{label} ({score:0.1f}%)" for label, score in category) for category in ranks)
        table.append(row)

        # Release this model before the next one is loaded so two CLIP models
        # are never resident at the same time.
        del model, images, image_features
        clear_cuda()

    st.session_state["prediction_table"].dataframe(
        pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))

    if models:
        flaves = ', '.join(f"{label}" for label, _ in bests[4])
        medium = bests[0][0][0]
        # Test the caption as a whole; the previous loop iterated over its
        # individual characters and rewrote the result once per character.
        if caption.startswith(medium):
            prompt = f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}"
        else:
            prompt = f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}"
    else:
        # No CLIP model selected: show the bare caption rather than a prompt
        # padded with empty fragments.
        prompt = f"\n\n{caption}"

    st.session_state["text_result"].code(prompt, language="")
    st.session_state["log_message"].code("Finished Interrogating.", language="")
2022-09-28 19:33:54 +03:00
2022-09-28 22:37:15 +03:00
def img2txt():
    """Load the prompt word lists, collect the selected CLIP models and interrogate the upload(s).

    The lists under ``data/img2txt`` are read into ``server_state`` and the
    "trending" variants are derived from the site list. The models are taken
    from the checkbox values stored in ``st.session_state``. Each uploaded
    image is then passed to ``interrogate``.
    """
    data_path = "data/"
    list_dir = os.path.join(data_path, 'img2txt')
    for key in ("artists", "flavors", "mediums", "movements", "sites"):
        server_state[key] = load_list(os.path.join(list_dir, f'{key}.txt'))

    # Build the list locally and store it once instead of extending the stored value in place.
    sites = server_state["sites"]
    trending_list = list(sites)
    trending_list.extend("trending on " + site for site in sites)
    trending_list.extend("featured on " + site for site in sites)
    trending_list.extend(site + " contest winner" for site in sites)
    server_state["trending_list"] = trending_list

    # (session_state checkbox key, CLIP model name), in the order the models are run.
    model_toggles = (
        ("ViTB32", 'ViT-B/32'),
        ("ViTB16", 'ViT-B/16'),
        ("ViTL14", 'ViT-L/14'),
        ("ViTL14_336px", 'ViT-L/14@336px'),
        ("RN101", 'RN101'),
        ("RN50", 'RN50'),
        ("RN50x4", 'RN50x4'),
        ("RN50x16", 'RN50x16'),
        ("RN50x64", 'RN50x64'),
    )
    models = [model_name for key, model_name in model_toggles if st.session_state[key]]

    uploaded = st.session_state["uploaded_image"]
    if not uploaded:
        # Nothing to interrogate; report it instead of failing on a missing attribute.
        st.session_state["log_message"].code("Please upload an image first.", language="")
        return

    # The uploader is created with accept_multiple_files=True, so this is
    # normally a list; a single upload object is accepted as well.
    files = uploaded if isinstance(uploaded, list) else [uploaded]
    for uploaded_file in files:
        # Reuse the decoded image when one was already attached to the upload,
        # otherwise decode it here.
        pil_image = getattr(uploaded_file, "pil_image", None)
        if pil_image is None:
            pil_image = Image.open(uploaded_file).convert('RGB')
        # NOTE: interrogate() writes to one shared set of placeholders, so with
        # several uploads the last image's result is the one left on screen.
        interrogate(pil_image, models=models)
2022-09-28 22:37:15 +03:00
def _result_placeholder(key, refresh):
    """Return the placeholder stored under `key`, or a fresh ``st.empty()`` when missing or refreshing."""
    existing = st.session_state.get(key)
    if refresh or not existing:
        return st.empty()
    return existing


def _render_result_placeholders(refresh):
    """Create/reuse the ranking-table and prompt placeholders and reset their content."""
    st.session_state["prediction_table"] = _result_placeholder("prediction_table", refresh)
    st.session_state["prediction_table"].table()

    st.session_state["text_result"] = _result_placeholder("text_result", refresh)
    st.session_state["text_result"].code('', language="")


def layout():
    """Render the Image-to-Text page and run the interrogator when "Generate!" is pressed.

    The left column holds the image uploader, the CLIP model checkboxes and the
    log box; the right column holds the image preview(s) and the result
    placeholders. All widgets live in a single form with two submit buttons
    ("Refresh" and "Generate!").
    """
    #set_page_title("Image-to-Text - Stable Diffusion WebUI")
    #st.info("Under Construction. :construction_worker:")

    with st.form("img2txt-inputs"):
        st.session_state["generation_mode"] = "img2txt"

        # creating the page layout using columns
        col1, col2 = st.columns([1, 4], gap="large")

        with col1:
            st.session_state["uploaded_image"] = st.file_uploader(
                'Input Image', type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)

            st.subheader("CLIP models")
            with st.expander("Stable Diffusion", expanded=True):
                st.session_state["ViTL14"] = st.checkbox(
                    "ViTL14", value=True, help="For StableDiffusion you can just use ViTL14.")

            with st.expander("Others"):
                st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.")
                # Label, session key and help text all derive from the same name.
                for name in ("ViTL14_336px", "ViTB16", "ViTB32", "RN50", "RN50x4", "RN50x16", "RN50x64", "RN101"):
                    st.session_state[name] = st.checkbox(name, value=False, help=f"{name} model.")

            # .get() so the first run does not fail when the key has not been created yet.
            st.session_state["log_message"] = st.session_state.get("log_message") or st.empty()
            st.session_state["log_message"].code('', language="")

        with col2:
            st.subheader("Image")

            refresh = st.form_submit_button(
                "Refresh", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')
            col1_output, col2_output = st.columns([2, 10], gap="medium")

            uploaded = st.session_state["uploaded_image"]
            if uploaded:
                # The uploader returns a list (accept_multiple_files=True); a
                # single upload object is handled the same way. The previous
                # list branch called range() on the list and Image.open() on
                # the whole list, both of which raise.
                files = uploaded if isinstance(uploaded, list) else [uploaded]

                with col1_output:
                    for uploaded_file in files:
                        # Attach the decoded image to the upload object for later use.
                        uploaded_file.pil_image = Image.open(uploaded_file).convert('RGB')
                        st.session_state["input_image_preview"] = st.empty()
                        st.session_state["input_image_preview"].image(
                            uploaded_file.pil_image, use_column_width=True, clamp=True)

                with col2_output:
                    _render_result_placeholders(refresh)
            else:
                st.image("images/streamlit/img2txt_placeholder.png", clamp=True)

            # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
            generate_button = st.form_submit_button("Generate!")

    if generate_button:
        # run clip interrogator
        img2txt()