2022-09-28 19:33:54 +03:00
# This file is part of stable-diffusion-webui (https://github.com/sd-webui/stable-diffusion-webui/).
# Copyright 2022 sd-webui team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
CLIP Interrogator made by @pharmapsychotic modified to work with our WebUI .
# CLIP Interrogator by @pharmapsychotic
Twitter : https : / / twitter . com / pharmapsychotic
Github : https : / / github . com / pharmapsychotic / clip - interrogator
Description :
What do the different OpenAI CLIP models see in an image ? What might be a good text prompt to create similar images using CLIP guided diffusion
or another text to image model ? The CLIP Interrogator is here to get you answers !
Please consider buying him a coffee via [ ko - fi ] ( https : / / ko - fi . com / pharmapsychotic ) or following him on [ twitter ] ( https : / / twitter . com / pharmapsychotic ) .
And if you ' re looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).
"""
#
# base webui import and utils.
from sd_utils import *
# streamlit imports
import streamlit_nested_layout
#streamlit components section
#other imports
2022-09-28 22:37:15 +03:00
import clip
import gc
import os
import pandas as pd
2022-09-29 17:27:56 +03:00
#import requests
2022-09-28 22:37:15 +03:00
import torch
from PIL import Image
#from torch import nn
#from torch.nn import functional as F
from torchvision import transforms
from torchvision . transforms . functional import InterpolationMode
from ldm . models . blip import blip_decoder
2022-09-28 19:33:54 +03:00
# end of imports
#---------------------------------------------------------------------------------------------------------------
2022-09-28 22:37:15 +03:00
device = torch . device ( ' cuda:0 ' if torch . cuda . is_available ( ) else ' cpu ' )
2022-09-29 17:27:56 +03:00
blip_image_eval_size = 256
2022-09-28 22:37:15 +03:00
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
def generate_caption ( pil_image ) :
blip_model = blip_decoder ( pretrained = " models/blip/model__base_caption.pth " , image_size = blip_image_eval_size , vit = ' base ' , med_config = " configs/blip/med_config.json " )
blip_model . eval ( )
blip_model = blip_model . to ( device )
gpu_image = transforms . Compose ( [
transforms . Resize ( ( blip_image_eval_size , blip_image_eval_size ) , interpolation = InterpolationMode . BICUBIC ) ,
transforms . ToTensor ( ) ,
transforms . Normalize ( ( 0.48145466 , 0.4578275 , 0.40821073 ) , ( 0.26862954 , 0.26130258 , 0.27577711 ) )
2022-09-29 11:06:17 +03:00
] ) ( pil_image ) . unsqueeze ( 0 ) . to ( device )
2022-09-28 22:37:15 +03:00
with torch . no_grad ( ) :
caption = blip_model . generate ( gpu_image , sample = False , num_beams = 3 , max_length = 20 , min_length = 5 )
return caption [ 0 ]
def load_list ( filename ) :
with open ( filename , ' r ' , encoding = ' utf-8 ' , errors = ' replace ' ) as f :
items = [ line . strip ( ) for line in f . readlines ( ) ]
return items
def rank ( model , image_features , text_array , top_count = 1 ) :
top_count = min ( top_count , len ( text_array ) )
text_tokens = clip . tokenize ( [ text for text in text_array ] ) . cuda ( )
with torch . no_grad ( ) :
text_features = model . encode_text ( text_tokens ) . float ( )
text_features / = text_features . norm ( dim = - 1 , keepdim = True )
similarity = torch . zeros ( ( 1 , len ( text_array ) ) ) . to ( device )
for i in range ( image_features . shape [ 0 ] ) :
similarity + = ( 100.0 * image_features [ i ] . unsqueeze ( 0 ) @ text_features . T ) . softmax ( dim = - 1 )
similarity / = image_features . shape [ 0 ]
top_probs , top_labels = similarity . cpu ( ) . topk ( top_count , dim = - 1 )
return [ ( text_array [ top_labels [ 0 ] [ i ] . numpy ( ) ] , ( top_probs [ 0 ] [ i ] . numpy ( ) * 100 ) ) for i in range ( top_count ) ]
def interrogate ( image , models ) :
2022-09-29 17:27:56 +03:00
print ( " Generating Caption " )
st . session_state [ " log_message " ] . code ( " Generating Caption " , language = ' ' )
2022-09-28 22:37:15 +03:00
caption = generate_caption ( image )
2022-09-29 17:27:56 +03:00
2022-09-28 22:37:15 +03:00
if len ( models ) == 0 :
print ( f " \n \n { caption } " )
return
table = [ ]
bests = [ [ ( ' ' , 0 ) ] ] * 5
for model_name in models :
print ( f " Interrogating with { model_name } ... " )
model , preprocess = clip . load ( model_name )
model . cuda ( ) . eval ( )
images = preprocess ( image ) . unsqueeze ( 0 ) . cuda ( )
with torch . no_grad ( ) :
image_features = model . encode_image ( images ) . float ( )
image_features / = image_features . norm ( dim = - 1 , keepdim = True )
ranks = [
2022-09-29 11:06:17 +03:00
rank ( model , image_features , server_state [ " mediums " ] ) ,
rank ( model , image_features , [ " by " + artist for artist in server_state [ " artists " ] ] ) ,
rank ( model , image_features , server_state [ " trending_list " ] ) ,
rank ( model , image_features , server_state [ " movements " ] ) ,
rank ( model , image_features , server_state [ " flavors " ] , top_count = 3 )
2022-09-28 22:37:15 +03:00
]
for i in range ( len ( ranks ) ) :
confidence_sum = 0
for ci in range ( len ( ranks [ i ] ) ) :
confidence_sum + = ranks [ i ] [ ci ] [ 1 ]
if confidence_sum > sum ( bests [ i ] [ t ] [ 1 ] for t in range ( len ( bests [ i ] ) ) ) :
bests [ i ] = ranks [ i ]
row = [ model_name ]
for r in ranks :
row . append ( ' , ' . join ( [ f " { x [ 0 ] } ( { x [ 1 ] : 0.1f } %) " for x in r ] ) )
table . append ( row )
del model
gc . collect ( )
2022-09-29 17:27:56 +03:00
st . session_state [ " prediction_table " ] . dataframe ( pd . DataFrame ( table , columns = [ " Model " , " Medium " , " Artist " , " Trending " , " Movement " , " Flavors " ] ) )
2022-09-28 22:37:15 +03:00
flaves = ' , ' . join ( [ f " { x [ 0 ] } " for x in bests [ 4 ] ] )
medium = bests [ 0 ] [ 0 ] [ 0 ]
if caption . startswith ( medium ) :
print ( f " \n \n { caption } { bests [ 1 ] [ 0 ] [ 0 ] } , { bests [ 2 ] [ 0 ] [ 0 ] } , { bests [ 3 ] [ 0 ] [ 0 ] } , { flaves } " )
else :
print ( f " \n \n { caption } , { medium } { bests [ 1 ] [ 0 ] [ 0 ] } , { bests [ 2 ] [ 0 ] [ 0 ] } , { bests [ 3 ] [ 0 ] [ 0 ] } , { flaves } " )
2022-09-28 19:33:54 +03:00
#
2022-09-28 22:37:15 +03:00
def img2txt ( ) :
data_path = " data/ "
2022-09-29 11:06:17 +03:00
server_state [ " artists " ] = load_list ( os . path . join ( data_path , ' img2txt ' , ' artists.txt ' ) )
server_state [ " flavors " ] = load_list ( os . path . join ( data_path , ' img2txt ' , ' flavors.txt ' ) )
server_state [ " mediums " ] = load_list ( os . path . join ( data_path , ' img2txt ' , ' mediums.txt ' ) )
server_state [ " movements " ] = load_list ( os . path . join ( data_path , ' img2txt ' , ' movements.txt ' ) )
server_state [ " sites " ] = load_list ( os . path . join ( data_path , ' img2txt ' , ' sites.txt ' ) )
2022-09-28 22:37:15 +03:00
2022-09-29 11:06:17 +03:00
trending_list = [ site for site in server_state [ " sites " ] ]
trending_list . extend ( [ " trending on " + site for site in server_state [ " sites " ] ] )
trending_list . extend ( [ " featured on " + site for site in server_state [ " sites " ] ] )
trending_list . extend ( [ site + " contest winner " for site in server_state [ " sites " ] ] )
2022-09-28 19:33:54 +03:00
2022-09-29 17:27:56 +03:00
#image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg"
2022-09-28 19:33:54 +03:00
2022-09-29 17:27:56 +03:00
models = [ ]
2022-09-28 19:33:54 +03:00
2022-09-28 22:37:15 +03:00
if st . session_state [ " ViTB32 " ] :
2022-09-29 17:27:56 +03:00
models . append ( ' ViT-B/32 ' )
2022-09-28 22:37:15 +03:00
if st . session_state [ ' ViTB16 ' ] :
2022-09-29 17:27:56 +03:00
models . append ( ' ViT-B/16 ' )
2022-09-28 22:37:15 +03:00
if st . session_state [ " ViTL14 " ] :
2022-09-29 17:27:56 +03:00
models . append ( ' ViT-L/14 ' )
2022-09-28 22:37:15 +03:00
if st . session_state [ " ViTL14_336px " ] :
2022-09-29 17:27:56 +03:00
models . append ( ' ViT-L/14@336px ' )
2022-09-28 22:37:15 +03:00
if st . session_state [ " RN101 " ] :
2022-09-29 17:27:56 +03:00
models . append ( ' RN101 ' )
2022-09-28 22:37:15 +03:00
if st . session_state [ " RN50 " ] :
2022-09-29 17:27:56 +03:00
models . append ( ' RN50 ' )
2022-09-28 22:37:15 +03:00
if st . session_state [ " RN50x4 " ] :
2022-09-29 17:27:56 +03:00
models . append ( ' RN50x4 ' )
2022-09-28 22:37:15 +03:00
if st . session_state [ " RN50x16 " ] :
2022-09-29 17:27:56 +03:00
models . append ( ' RN50x16 ' )
2022-09-28 22:37:15 +03:00
if st . session_state [ " RN50x64 " ] :
2022-09-29 17:27:56 +03:00
models . append ( ' RN50x64 ' )
2022-09-28 19:33:54 +03:00
2022-09-29 17:27:56 +03:00
#if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
#image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
#else:
#image = Image.open(image_path_or_url).convert('RGB')
2022-09-28 22:37:15 +03:00
2022-09-29 17:27:56 +03:00
#thumb = st.session_state["uploaded_image"].image.copy()
#thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
2022-09-28 22:37:15 +03:00
#display(thumb)
2022-09-29 17:27:56 +03:00
interrogate ( st . session_state [ " uploaded_image " ] . pil_image , models = models )
2022-09-28 22:37:15 +03:00
#
def layout ( ) :
#set_page_title("Image-to-Text - Stable Diffusion WebUI")
#st.info("Under Construction. :construction_worker:")
with st . form ( " img2txt-inputs " ) :
st . session_state [ " generation_mode " ] = " img2txt "
2022-09-28 19:33:54 +03:00
2022-09-29 11:06:17 +03:00
#st.write("---")
2022-09-28 22:37:15 +03:00
# creating the page layout using columns
col1 , col2 , col3 = st . columns ( [ 1 , 2 , 1 ] , gap = " large " )
with col1 :
2022-09-29 11:06:17 +03:00
#url = st.text_area("Input Text","")
#url = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
#st.subheader("Input Image")
st . session_state [ " uploaded_image " ] = st . file_uploader ( ' Input Image ' , type = [ ' png ' , ' jpg ' , ' jpeg ' ] )
st . subheader ( " CLIP models " )
with st . expander ( " Stable Diffusion " , expanded = True ) :
st . session_state [ " ViTL14 " ] = st . checkbox ( " ViTL14 " , value = True , help = " For StableDiffusion you can just use ViTL14. " )
2022-09-28 22:37:15 +03:00
2022-09-29 11:06:17 +03:00
with st . expander ( " Others " ) :
st . info ( " For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images. " )
st . session_state [ " ViTL14_336px " ] = st . checkbox ( " ViTL14_336px " , value = False , help = " ViTL14_336px model. " )
st . session_state [ " ViTB16 " ] = st . checkbox ( " ViTB16 " , value = False , help = " ViTB16 model. " )
st . session_state [ " ViTB32 " ] = st . checkbox ( " ViTB32 " , value = False , help = " ViTB32 model. " )
st . session_state [ " RN50 " ] = st . checkbox ( " RN50 " , value = False , help = " RN50 model. " )
st . session_state [ " RN50x4 " ] = st . checkbox ( " RN50x4 " , value = False , help = " RN50x4 model. " )
st . session_state [ " RN50x16 " ] = st . checkbox ( " RN50x16 " , value = False , help = " RN50x16 model. " )
st . session_state [ " RN50x64 " ] = st . checkbox ( " RN50x64 " , value = False , help = " RN50x64 model. " )
st . session_state [ " RN101 " ] = st . checkbox ( " RN101 " , value = False , help = " RN101 model. " )
2022-09-28 22:37:15 +03:00
with col2 :
2022-09-29 11:06:17 +03:00
st . subheader ( " Image " )
st . form_submit_button ( " Refresh " ,
help = ' Refresh the image preview to show your uploaded image instead of the default placeholder. ' )
2022-09-28 22:37:15 +03:00
st . session_state [ " input_image_preview " ] = st . empty ( )
2022-09-29 11:06:17 +03:00
if st . session_state [ " uploaded_image " ] :
2022-09-29 17:27:56 +03:00
st . session_state [ " uploaded_image " ] . pil_image = Image . open ( st . session_state [ " uploaded_image " ] ) #.convert('RGBA')
2022-09-29 11:06:17 +03:00
#new_img = image.resize((width, height))
2022-09-29 17:27:56 +03:00
st . session_state [ " input_image_preview " ] . image ( st . session_state [ " uploaded_image " ] . pil_image , clamp = True )
2022-09-29 11:06:17 +03:00
else :
#st.session_state["input_image_preview"].code('', language="")
st . image ( " images/streamlit/img2txt_placeholder.png " , clamp = True )
2022-09-28 22:37:15 +03:00
with col3 :
2022-09-29 11:06:17 +03:00
st . subheader ( " Logs: " )
st . session_state [ " log_message " ] = st . empty ( )
st . session_state [ " log_message " ] . code ( ' ' , language = " " )
#
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
#generate_col1.title("")
#generate_col1.title("")
generate_button = st . form_submit_button ( " Generate! " )
#
st . write ( " --- " )
with st . container ( ) :
st . subheader ( " Image To Text Result " )
st . session_state [ " prediction_table " ] = st . empty ( )
st . session_state [ " prediction_table " ] . table ( )
st . session_state [ " text_result " ] = st . empty ( )
st . session_state [ " text_result " ] . code ( ' ' , language = " " )
if generate_button :
# run clip interrogator
img2txt ( )