2022-10-24 03:31:41 +03:00
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
2022-09-28 19:33:54 +03:00
2022-10-24 03:17:50 +03:00
# Copyright 2022 Sygil-Dev team.
2022-09-28 19:33:54 +03:00
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
2022-10-01 01:20:02 +03:00
# along with this program. If not, see <http://www.gnu.org/licenses/>.
2022-09-28 19:33:54 +03:00
2022-10-01 01:20:02 +03:00
# ---------------------------------------------------------------------------------------------------------------------------------------------------
2022-09-28 19:33:54 +03:00
"""
CLIP Interrogator made by @pharmapsychotic modified to work with our WebUI.

# CLIP Interrogator by @pharmapsychotic
Twitter: https://twitter.com/pharmapsychotic
Github: https://github.com/pharmapsychotic/clip-interrogator

Description:
What do the different OpenAI CLIP models see in an image? What might be a good text prompt to create similar images using CLIP guided diffusion
or another text to image model? The CLIP Interrogator is here to get you answers!

Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following him on [twitter](https://twitter.com/pharmapsychotic).

And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).
"""
2022-10-01 01:20:02 +03:00
# ---------------------------------------------------------------------------------------------------------------------------------------------------
2022-09-28 19:33:54 +03:00
# base webui import and utils.
from sd_utils import *
# streamlit imports
2022-10-01 01:20:02 +03:00
# streamlit components section
2022-09-30 22:40:52 +03:00
import streamlit_nested_layout
2022-09-28 19:33:54 +03:00
2022-10-01 01:20:02 +03:00
# other imports
2022-09-30 22:40:52 +03:00
2022-10-01 01:20:02 +03:00
import clip
import open_clip
2022-09-28 22:37:15 +03:00
import gc
import os
import pandas as pd
2022-09-29 17:27:56 +03:00
#import requests
2022-09-28 22:37:15 +03:00
import torch
from PIL import Image
from torchvision import transforms
from torchvision . transforms . functional import InterpolationMode
from ldm . models . blip import blip_decoder
2022-10-18 01:59:29 +03:00
#import hashlib
2022-09-28 19:33:54 +03:00
# end of imports
2022-10-01 01:20:02 +03:00
# ---------------------------------------------------------------------------------------------------------------
2022-09-28 19:33:54 +03:00
2022-09-28 22:37:15 +03:00
# Torch device used by the BLIP model and its input tensors: the first CUDA GPU
# when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Edge length (in pixels) images are resized to before BLIP captioning; the same
# value is passed to blip_decoder as image_size.
blip_image_eval_size = 512

# Start with an empty on-page log whenever this module-level code runs; the
# functions below append progress messages to it.
st.session_state["log"] = []
2022-09-29 18:52:46 +03:00
def load_blip_model():
    """Make sure the BLIP caption model is stored in ``server_state["blip_model"]``.

    When the model is missing it is built with ``blip_decoder``, switched to
    eval mode, moved to ``device`` in half precision and only then stored in
    the shared server state, so no other session can pick up a partially
    configured model. Progress is reported to the logger and to the on-page
    log panel (``st.session_state["log_message"]``).
    """

    def report(message):
        # Mirror every status line to the application log and the UI log panel.
        logger.info(message)
        st.session_state["log"].append(message)
        st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')

    if "log" not in st.session_state:
        st.session_state["log"] = []

    report("Loading BLIP Model")

    if "blip_model" in server_state:
        report("BLIP Model already loaded")
        return

    with server_state_lock['blip_model']:
        blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth",
                                  image_size=blip_image_eval_size, vit='base',
                                  med_config="configs/blip/med_config.json")

        # Configure first, publish once: eval mode -> target device -> fp16.
        server_state["blip_model"] = blip_model.eval().to(device).half()

        report("BLIP Model Loaded")
2022-10-01 01:20:02 +03:00
2022-09-30 22:40:52 +03:00
2022-09-28 22:37:15 +03:00
def generate_caption(pil_image):
    """Return the first caption BLIP generates for *pil_image*.

    The image is resized to ``blip_image_eval_size`` squared (bicubic),
    converted to a normalised tensor, batched, moved to ``device`` and cast to
    half precision before being handed to ``server_state["blip_model"]``.
    """
    load_blip_model()

    side = blip_image_eval_size
    to_blip_input = transforms.Compose([  # type: ignore
        transforms.Resize((side, side), interpolation=InterpolationMode.BICUBIC),  # type: ignore
        transforms.ToTensor(),  # type: ignore
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # type: ignore
    ])

    # Add the batch dimension, then match the model's device and precision.
    batch = to_blip_input(pil_image).unsqueeze(0).to(device).half()

    with torch.no_grad():
        captions = server_state["blip_model"].generate(
            batch, sample=False, num_beams=3, max_length=20, min_length=5)

    return captions[0]
2022-09-28 22:37:15 +03:00
def load_list(filename):
    """Read *filename* as UTF-8 and return its lines with surrounding whitespace stripped.

    Undecodable bytes are replaced rather than raising; blank lines are kept
    as empty strings.
    """
    with open(filename, 'r', encoding='utf-8', errors='replace') as handle:
        return [entry.strip() for entry in handle]
2022-10-01 01:20:02 +03:00
2022-09-28 22:37:15 +03:00
def rank(model, image_features, text_array, top_count=1):
    """Return the *top_count* texts that best match the image features.

    Args:
        model: CLIP model exposing ``encode_text``; assumed to live on the same
            device as *image_features* (TODO confirm for other callers).
        image_features: normalised image embeddings, one row per image.
        text_array: list of candidate text prompts.
        top_count: how many of the best matches to return (capped at
            ``len(text_array)``).

    Returns:
        A list of ``(text, percent)`` tuples ordered best first; an empty list
        when *text_array* is empty.
    """
    if not text_array:
        return []

    top_count = min(top_count, len(text_array))

    # Keep every tensor on the device the image features already live on
    # instead of mixing a hard-coded .cuda() with the module-level device.
    target_device = image_features.device

    text_tokens = clip.tokenize(list(text_array)).to(target_device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Average the softmax-ed similarity of each image row against all texts.
    similarity = torch.zeros((1, len(text_array)), device=target_device)
    for i in range(image_features.shape[0]):
        similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
    similarity /= image_features.shape[0]

    top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
    return [(text_array[int(top_labels[0][i])], (top_probs[0][i].numpy() * 100)) for i in range(top_count)]
2022-10-01 01:20:02 +03:00
2022-09-30 12:40:02 +03:00
def clear_cuda():
    """Release as much GPU memory as possible.

    Python garbage is collected first so that tensors kept alive only by
    reference cycles are returned to PyTorch's caching allocator; the cache is
    emptied afterwards so that memory is actually given back.
    """
    gc.collect()
    torch.cuda.empty_cache()
2022-09-28 22:37:15 +03:00
2022-10-01 01:20:02 +03:00
2022-09-30 22:40:52 +03:00
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
    """Rank *text_array* against the image in chunks and collect each chunk's best match.

    Args:
        model: CLIP model forwarded to ``rank``.
        image_features: normalised image embeddings forwarded to ``rank``.
        text_array: list of candidate text prompts.
        batch_size: maximum number of prompts ranked per ``rank`` call.

    Returns:
        The concatenated results of ``rank`` for every chunk, including the
        final, shorter chunk when ``len(text_array)`` is not a multiple of
        *batch_size*. An empty list when *text_array* is empty.
    """
    if not text_array:
        return []

    # Never larger than the list, never smaller than one item per chunk.
    batch_size = max(1, min(batch_size, len(text_array)))

    ranks = []
    # Stepping by batch_size also covers the trailing partial chunk.
    for start in range(0, len(text_array), batch_size):
        ranks += rank(model, image_features, text_array[start:start + batch_size])
    return ranks
2022-09-28 22:37:15 +03:00
def _log_status(message):
    """Write *message* to the application logger and to the on-page log panel."""
    logger.info(message)
    st.session_state["log"].append(message)
    st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')


def _ensure_clip_model(model_name):
    """Load *model_name* and its preprocessing transform into server_state if absent."""
    if model_name in server_state["clip_models"]:
        return

    if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
        # Only one CLIP model stays resident: evict the others before loading.
        for loaded_name in [name for name in server_state["clip_models"] if name != model_name]:
            del server_state["clip_models"][loaded_name]
            del server_state["preprocesses"][loaded_name]
        clear_cuda()

    # These two models are created through open_clip with a fixed pretrained
    # tag; every other name is loaded through clip.load.
    open_clip_weights = {'ViT-H-14': 'laion2b_s32b_b79k', 'ViT-g-14': 'laion2b_s12b_b42k'}
    if model_name in open_clip_weights:
        model, _, preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=open_clip_weights[model_name], cache_dir='models/clip')
    else:
        model, preprocess = clip.load(model_name, device=device, download_root='models/clip')

    server_state["clip_models"][model_name] = model.to(device).eval()
    server_state["preprocesses"][model_name] = preprocess


def _rank_image_with_model(image, model_name):
    """Return, per prompt category, the (text, percent) matches sorted best first."""
    _ensure_clip_model(model_name)
    clip_model = server_state["clip_models"][model_name]

    images = server_state["preprocesses"][model_name](image).unsqueeze(0).to(device)
    image_features = clip_model.encode_image(images).float()
    image_features /= image_features.norm(dim=-1, keepdim=True)

    if st.session_state["defaults"].general.optimized:
        clear_cuda()

    # Order must match the table columns and the bests[...] indices in interrogate().
    candidate_lists = [
        server_state["mediums"],
        ["by " + artist for artist in server_state["artists"]],
        server_state["trending_list"],
        server_state["movements"],
        server_state["flavors"],
        server_state["techniques"],
        server_state["tags"],
    ]

    # Sort every category (not only the ones that end up "best") so the result
    # table is ordered consistently and index 0 is always the strongest match.
    return [
        sorted(batch_rank(clip_model, image_features, candidates), key=lambda pair: pair[1], reverse=True)
        for candidates in candidate_lists
    ]


def interrogate(image, models):
    """Caption *image* with BLIP and extend the caption with CLIP-ranked prompt terms.

    Args:
        image: the image passed to ``generate_caption`` and to each CLIP
            model's preprocessing transform.
        models: list of CLIP model names to interrogate with. When empty, only
            the caption is generated and logged.

    Returns:
        None. Results are written to the Streamlit placeholders
        ``st.session_state["prediction_table"]`` and
        ``st.session_state["text_result"]`` at index
        ``st.session_state["processed_image_count"]``.
    """
    load_blip_model()

    _log_status("Generating Caption")
    caption = generate_caption(image)

    optimized = st.session_state["defaults"].general.optimized
    if optimized:
        # BLIP is no longer needed; free it before the CLIP models are loaded.
        del server_state["blip_model"]
        clear_cuda()

    _log_status("Caption Generated")

    if not models:
        logger.info(f"\n\n{caption}")
        return

    columns = ["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"]
    table = []
    # One independent placeholder per category (every column except "Model").
    bests = [[('', 0)] for _ in columns[1:]]

    _log_status("Ranking Text")

    for model_name in models:
        with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
            _log_status(f"Interrogating with {model_name}...")

            ranks = _rank_image_with_model(image, model_name)

            # Per category, keep the result list of whichever model reports
            # the highest total confidence.
            for i, ranked in enumerate(ranks):
                if sum(score for _, score in ranked) > sum(score for _, score in bests[i]):
                    bests[i] = ranked

            row = [model_name]
            for ranked in ranks:
                row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in ranked]))
            table.append(row)

            if optimized:
                del server_state["clip_models"][model_name]
                gc.collect()

    slot = st.session_state["processed_image_count"]
    st.session_state["prediction_table"][slot].dataframe(pd.DataFrame(table, columns=columns))

    # Strongest match of each category, in column order.
    medium, artist, trending, movement, flavors, techniques, tags = (best[0][0] for best in bests)

    if caption.startswith(medium):
        # The caption already names the medium; do not repeat it.
        prompt = f"\n\n{caption} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}"
    else:
        prompt = f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}"
    st.session_state["text_result"][slot].code(prompt, language="")

    _log_status("Finished Interrogating.")
2022-10-01 01:20:02 +03:00
2022-09-28 22:37:15 +03:00
def img2txt():
    """Interrogate every uploaded image with the CLIP models ticked in the UI.

    Reads the model checkboxes and ``st.session_state["uploaded_image"]``, and
    keeps ``st.session_state["processed_image_count"]`` equal to the index of
    the image currently being processed so ``interrogate`` writes into the
    matching output placeholders.
    """
    # (session_state checkbox key, CLIP model name), in the order models run.
    checkbox_to_model = (
        ("ViT-L/14", 'ViT-L/14'),
        ("ViT-H-14", 'ViT-H-14'),
        ("ViT-g-14", 'ViT-g-14'),
        ("ViTB32", 'ViT-B/32'),
        ('ViTB16', 'ViT-B/16'),
        ("ViTL14_336px", 'ViT-L/14@336px'),
        ("RN101", 'RN101'),
        ("RN50", 'RN50'),
        ("RN50x4", 'RN50x4'),
        ("RN50x16", 'RN50x16'),
        ("RN50x64", 'RN50x64'),
    )
    models = [model_name for key, model_name in checkbox_to_model if st.session_state[key]]

    st.session_state["processed_image_count"] = 0

    for uploaded in st.session_state["uploaded_image"]:
        interrogate(uploaded.pil_image, models=models)
        # increase counter.
        st.session_state["processed_image_count"] += 1
2022-09-28 22:37:15 +03:00
#
2022-10-01 01:20:02 +03:00
2022-09-28 22:37:15 +03:00
def layout():
    """Build the Image-to-Text page and run the interrogation when requested.

    On first use this fills ``server_state`` with the CLIP model/preprocess
    dictionaries and the prompt word lists read from ``data/img2txt``. It then
    renders a form with the image uploader, the CLIP model checkboxes, a log
    panel and one preview/result area per uploaded image. When the
    "Generate!" button is pressed it optionally drops other loaded models
    (optimized mode) and calls ``img2txt()``.
    """
    #set_page_title("Image-to-Text - Stable Diffusion WebUI")
    #st.info("Under Construction. :construction_worker:")
    #
    # Shared per-server caches: loaded CLIP models and their preprocess transforms.
    if "clip_models" not in server_state:
        server_state["clip_models"] = {}
    if "preprocesses" not in server_state:
        server_state["preprocesses"] = {}
    data_path = "data/"
    # Word lists are read once and kept in server_state.
    if "artists" not in server_state:
        server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt'))
    if "flavors" not in server_state:
        # random.choices samples with replacement, so the 2000 entries may repeat.
        # NOTE(review): `random` is not imported in this file's visible imports;
        # presumably it comes from `from sd_utils import *` - confirm.
        server_state["flavors"] = random.choices(load_list(os.path.join(data_path, 'img2txt', 'flavors.txt')), k=2000)
    if "mediums" not in server_state:
        server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
    if "movements" not in server_state:
        server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
    if "sites" not in server_state:
        server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
    #server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
    #server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
    if "techniques" not in server_state:
        server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
    if "tags" not in server_state:
        server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
    #server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))
    # server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt'))
    # server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt'))
    if "trending_list" not in server_state:
        # Each site name plus three phrasings built from it.
        server_state["trending_list"] = [site for site in server_state["sites"]]
        server_state["trending_list"].extend(["trending on " + site for site in server_state["sites"]])
        server_state["trending_list"].extend(["featured on " + site for site in server_state["sites"]])
        server_state["trending_list"].extend([site + " contest winner" for site in server_state["sites"]])

    with st.form("img2txt-inputs"):
        st.session_state["generation_mode"] = "img2txt"

        # st.write("---")
        # creating the page layout using columns
        col1, col2 = st.columns([1, 4], gap="large")

        with col1:
            # Left column: uploader, model selection and the log panel.
            st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg', 'jfif', 'webp'], accept_multiple_files=True)

            # The session_state keys written here are the ones img2txt() reads.
            with st.expander("CLIP models", expanded=True):
                st.session_state["ViT-L/14"] = st.checkbox("ViT-L/14", value=True, help="ViT-L/14 model.")
                st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
                st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")

            with st.expander("Others"):
                st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.")
                st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.")
                st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.")
                st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.")
                st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.")
                st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
                st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
                st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
                st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")

            #
            # st.subheader("Logs:")
            # Placeholder the loading/interrogation functions write their log into.
            st.session_state["log_message"] = st.empty()
            st.session_state["log_message"].code('', language="")

        with col2:
            # Right column: buttons plus one preview/result row per uploaded image.
            st.subheader("Image")

            image_col1, image_col2 = st.columns([10, 25])
            with image_col1:
                # The return value is not used below; the button is only a
                # second submit button for the form.
                refresh = st.form_submit_button("Update Preview Image", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')

            if st.session_state["uploaded_image"]:
                #print (type(st.session_state["uploaded_image"]))
                # if len(st.session_state["uploaded_image"]) == 1:
                # Parallel lists of placeholders, indexed by image position.
                st.session_state["input_image_preview"] = []
                st.session_state["input_image_preview_container"] = []
                st.session_state["prediction_table"] = []
                st.session_state["text_result"] = []

                for i in range(len(st.session_state["uploaded_image"])):
                    # Each list slot is appended and then immediately replaced
                    # by the st.empty() placeholder for image i.
                    st.session_state["input_image_preview_container"].append(i)
                    st.session_state["input_image_preview_container"][i] = st.empty()

                    with st.session_state["input_image_preview_container"][i].container():
                        col1_output, col2_output = st.columns([2, 10], gap="medium")
                        with col1_output:
                            st.session_state["input_image_preview"].append(i)
                            st.session_state["input_image_preview"][i] = st.empty()
                            # Decode the upload once and keep the RGB image on
                            # the upload object; img2txt() reads .pil_image.
                            st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')

                            st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)

                    with st.session_state["input_image_preview_container"][i].container():
                        with col2_output:
                            # Placeholders interrogate() fills with the ranking
                            # table and the generated prompt.
                            st.session_state["prediction_table"].append(i)
                            st.session_state["prediction_table"][i] = st.empty()
                            st.session_state["prediction_table"][i].table()

                            st.session_state["text_result"].append(i)
                            st.session_state["text_result"][i] = st.empty()
                            st.session_state["text_result"][i].code("", language="")

            else:
                #st.session_state["input_image_preview"].code('', language="")
                st.image("images/streamlit/img2txt_placeholder.png", clamp=True)

            with image_col2:
                #
                # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
                # generate_col1.title("")
                # generate_col1.title("")
                generate_button = st.form_submit_button("Generate!", help="Start interrogating the images to generate a prompt from each of the selected images")

        if generate_button:
            # In optimized mode, drop the model, pipe, RealESRGAN and GFPGAN
            # entries from server_state before interrogating.
            if "model" in server_state and st.session_state["defaults"].general.optimized:
                del server_state["model"]
            if "pipe" in server_state and st.session_state["defaults"].general.optimized:
                del server_state["pipe"]
            if "RealESRGAN" in server_state and st.session_state["defaults"].general.optimized:
                del server_state["RealESRGAN"]
            if "GFPGAN" in server_state and st.session_state["defaults"].general.optimized:
                del server_state["GFPGAN"]

            # run clip interrogator
            img2txt()