Mirror of https://github.com/Sygil-Dev/sygil-webui.git (synced 2024-12-14 22:13:41 +03:00)

...

This commit is contained in:
parent 20a89a1233
commit 857608c5f6
@@ -294,7 +294,7 @@ img2img:
 
 img2txt:
     batch_size: 100
-
+    blip_image_eval_size: 512
 concepts_library:
     concepts_per_page: 12
 
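The hunk above adds `blip_image_eval_size: 512` to the `img2txt` block of the config, next to the existing `batch_size: 100`; the remaining hunks are in the img2txt script itself. As a minimal sketch (names taken from the hunks below; only `batch_size` is visibly read from the config in this diff, and the script also keeps its own module-level `blip_image_eval_size = 512` constant), the settings would be consumed roughly like this:

```python
# Sketch only: reading the img2txt defaults inside the Streamlit app.
# st.session_state["defaults"] is the config object the script uses; attribute
# access (OmegaConf-style) is assumed from the usage shown in the hunks below.
import streamlit as st

def get_img2txt_settings():
    defaults = st.session_state["defaults"]
    batch_size = defaults.img2txt.batch_size            # 100: texts scored per CLIP batch
    eval_size = defaults.img2txt.blip_image_eval_size   # 512: BLIP caption input resolution
    return batch_size, eval_size
```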
@@ -12,9 +12,9 @@
 # GNU Affero General Public License for more details.
 
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 
-#---------------------------------------------------------------------------------------------------------------------------------------------------
+# ---------------------------------------------------------------------------------------------------------------------------------------------------
 """
 CLIP Interrogator made by @pharmapsychotic modified to work with our WebUI.
 
@@ -31,20 +31,20 @@ Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychoti
 And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).
 
 """
-#---------------------------------------------------------------------------------------------------------------------------------------------------
+# ---------------------------------------------------------------------------------------------------------------------------------------------------
 
 # base webui import and utils.
-from ldm.util import default
 from sd_utils import *
 
 # streamlit imports
 
-#streamlit components section
+# streamlit components section
 import streamlit_nested_layout
 
-#other imports
+# other imports
 
-import clip, open_clip
+import clip
+import open_clip
 import gc
 import os
 import pandas as pd
@@ -56,53 +56,59 @@ from torchvision.transforms.functional import InterpolationMode
 from ldm.models.blip import blip_decoder
 
 # end of imports
-#---------------------------------------------------------------------------------------------------------------
+# ---------------------------------------------------------------------------------------------------------------
 
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 blip_image_eval_size = 512
-blip_model = None
 #blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
 
+
 def load_blip_model():
-    print ("Loading BLIP Model")
+    print("Loading BLIP Model")
     st.session_state["log_message"].code("Loading BLIP Model", language='')
 
-    with server_state_lock['blip_model']:
-        if "blip_model" not in server_state:
-            blip_model = blip_decoder(pretrained="models/blip/model__base_caption.pth",
-                                      image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
-            blip_model.eval()
-            blip_model = blip_model.to(device).half()
-
-            print ("BLIP Model Loaded")
-            st.session_state["log_message"].code("BLIP Model Loaded", language='')
-        else:
-            print ("BLIP Model already loaded")
-            st.session_state["log_message"].code("BLIP Model Already Loaded", language='')
-
-    return blip_model
+    if "blip_model" not in server_state:
+        with server_state_lock['blip_model']:
+            server_state["blip_model"] = blip_decoder(pretrained="models/blip/model__base_caption.pth",
+                                                      image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
+
+            server_state["blip_model"] = server_state["blip_model"].eval()
+
+            #if not st.session_state["defaults"].general.optimized:
+            server_state["blip_model"] = server_state["blip_model"].to(device).half()
+
+            print("BLIP Model Loaded")
+            st.session_state["log_message"].code("BLIP Model Loaded", language='')
+    else:
+        print("BLIP Model already loaded")
+        st.session_state["log_message"].code("BLIP Model Already Loaded", language='')
+
+    #return server_state["blip_model"]
 
 
 def generate_caption(pil_image):
-    global blip_model
-    #width, height = pil_image.size
+    load_blip_model()
 
-    gpu_image = transforms.Compose([
-        transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
-        transforms.ToTensor(),
-        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+    gpu_image = transforms.Compose([  # type: ignore
+        transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),  # type: ignore
+        transforms.ToTensor(),  # type: ignore
+        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))  # type: ignore
     ])(pil_image).unsqueeze(0).to(device).half()
 
     with torch.no_grad():
-        caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
+        caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
 
     #print (caption)
     return caption[0]
 
 
 def load_list(filename):
     with open(filename, 'r', encoding='utf-8', errors='replace') as f:
         items = [line.strip() for line in f.readlines()]
     return items
 
 
 def rank(model, image_features, text_array, top_count=1):
     top_count = min(top_count, len(text_array))
     text_tokens = clip.tokenize([text for text in text_array]).cuda()
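The main change in this hunk is moving the BLIP model from a module-level global into `server_state`, so a single copy is shared across browser sessions and survives Streamlit reruns; loading is skipped when the key already exists. A minimal, standalone sketch of that caching pattern (the loader lambda is a placeholder, not the project's actual BLIP construction):

```python
# Sketch of the shared-model caching pattern used above, via the
# streamlit-server-state package (server_state is shared by all sessions).
from streamlit_server_state import server_state, server_state_lock

def get_or_load(key, loader):
    # Take the per-key lock before creating the object so concurrent sessions
    # do not each build their own copy, and re-check once the lock is held.
    if key not in server_state:
        with server_state_lock[key]:
            if key not in server_state:
                server_state[key] = loader()
    return server_state[key]

# Hypothetical usage; build_blip() stands in for the blip_decoder(...) call above:
# blip = get_or_load("blip_model", lambda: build_blip().eval().half().to("cuda"))
```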
@@ -115,13 +121,15 @@ def rank(model, image_features, text_array, top_count=1):
         similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
     similarity /= image_features.shape[0]
 
     top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
     return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
 
+
 def clear_cuda():
     torch.cuda.empty_cache()
     gc.collect()
 
+
 def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
     batch_count = len(text_array) // batch_size
     batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
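The ranking helpers shown here boil down to standard zero-shot CLIP scoring: tokenize the candidate strings, encode them, and softmax the scaled cosine similarities against the already-normalized image features; `batch_rank` just applies that over `batch_size`-sized slices of the text list. A condensed sketch of the scoring step (not the exact project code; it assumes a single normalized image-feature row):

```python
# Condensed sketch of the scoring inside rank(); assumes image_features is a
# single, L2-normalized row tensor already on the GPU.
import clip
import torch

def rank_sketch(model, image_features, text_array, top_count=1, device="cuda"):
    top_count = min(top_count, len(text_array))
    tokens = clip.tokenize(text_array).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)
    # Scaled cosine similarity turned into a probability over the candidates.
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
    return [(text_array[int(top_labels[0][i])], float(top_probs[0][i]) * 100)
            for i in range(top_count)]
```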
@@ -132,18 +140,19 @@ def batch_rank(model, image_features, text_array, batch_size=st.session_state["d
     return ranks
 
+
 def interrogate(image, models):
-    global blip_model
-
-    blip_model = load_blip_model()
-    print ("Generating Caption")
+    #server_state["blip_model"] =
+    load_blip_model()
+
+    print("Generating Caption")
     st.session_state["log_message"].code("Generating Caption", language='')
     caption = generate_caption(image)
 
     if st.session_state["defaults"].general.optimized:
-        del blip_model
+        del server_state["blip_model"]
         clear_cuda()
 
-    print ("Caption Generated")
+    print("Caption Generated")
     st.session_state["log_message"].code("Caption Generated", language='')
 
     if len(models) == 0:
@@ -151,44 +160,48 @@ def interrogate(image, models):
         return
 
     table = []
-    bests = [[('',0)]]*5
+    bests = [[('', 0)]]*5
 
-    print ("Ranking Text")
+    print("Ranking Text")
     for model_name in models:
         print(f"Interrogating with {model_name}...")
         st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
-        if model_name == 'ViT-H-14':
-            model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k')
-        elif model_name == 'ViT-g-14':
-            model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k')
-        else:
-            model, preprocess = clip.load(model_name, device=device)
-
-        model.cuda().eval()
-
-        images = preprocess(image).unsqueeze(0).cuda()
-        with torch.no_grad():
-            image_features = model.encode_image(images).float()
-            image_features /= image_features.norm(dim=-1, keepdim=True)
+
+        if "clip_model" not in server_state:
+            #with server_state_lock[server_state["clip_model"]]:
+            if model_name == 'ViT-H-14':
+                server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k')
+            elif model_name == 'ViT-g-14':
+                server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k')
+            else:
+                server_state["clip_model"], server_state["preprocess"] = clip.load(model_name, device=device)
+
+            server_state["clip_model"] = server_state["clip_model"].cuda().eval()
+
+        images = server_state["preprocess"](image).unsqueeze(0).cuda()
+
+        with torch.no_grad():
+            image_features = server_state["clip_model"].encode_image(images).float()
+
+            image_features /= image_features.norm(dim=-1, keepdim=True)
 
         if st.session_state["defaults"].general.optimized:
             clear_cuda()
 
         ranks = []
-        ranks.append(batch_rank(model, image_features, server_state["mediums"]))
-        ranks.append(batch_rank(model, image_features, ["by "+artist for artist in server_state["artists"]]))
-        ranks.append(batch_rank(model, image_features, server_state["trending_list"]))
-        ranks.append(batch_rank(model, image_features, server_state["movements"]))
-        ranks.append(batch_rank(model, image_features, server_state["flavors"]))
-        # ranks.append(batch_rank(model, image_features, server_state["genres"]))
-        # ranks.append(batch_rank(model, image_features, server_state["styles"]))
-        # ranks.append(batch_rank(model, image_features, server_state["techniques"]))
-        # ranks.append(batch_rank(model, image_features, server_state["subjects"]))
-        # ranks.append(batch_rank(model, image_features, server_state["colors"]))
-        # ranks.append(batch_rank(model, image_features, server_state["moods"]))
-        # ranks.append(batch_rank(model, image_features, server_state["themes"]))
-        # ranks.append(batch_rank(model, image_features, server_state["keywords"]))
+        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["mediums"]))
+        ranks.append(batch_rank(server_state["clip_model"], image_features, ["by "+artist for artist in server_state["artists"]]))
+        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["trending_list"]))
+        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["movements"]))
+        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["flavors"]))
+        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["genres"]))
+        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["styles"]))
+        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["techniques"]))
+        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["subjects"]))
+        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["colors"]))
+        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["moods"]))
+        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["themes"]))
+        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["keywords"]))
 
 
         for i in range(len(ranks)):
             confidence_sum = 0
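One detail worth calling out in this hunk: the two LAION-trained checkpoints (ViT-H-14 and ViT-g-14) are created through `open_clip`, while every other variant goes through OpenAI's `clip.load`, and the result is now cached in `server_state["clip_model"]` / `server_state["preprocess"]`. A small sketch of that loader split, with the pretrained tags as they appear above:

```python
# Sketch of the loader split used in the hunk above.
import clip
import open_clip

OPEN_CLIP_WEIGHTS = {
    "ViT-H-14": "laion2b_s32b_b79k",
    "ViT-g-14": "laion2b_s12b_b42k",
}

def load_clip_variant(model_name, device="cuda"):
    if model_name in OPEN_CLIP_WEIGHTS:
        model, _, preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=OPEN_CLIP_WEIGHTS[model_name])
        model = model.to(device)
    else:
        model, preprocess = clip.load(model_name, device=device)
    return model.eval(), preprocess
```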
@@ -204,27 +217,28 @@ def interrogate(image, models):
             table.append(row)
 
         if st.session_state["defaults"].general.optimized:
-            del model
+            del server_state["clip_model"]
             gc.collect()
 
-    #for i in range(len(st.session_state["uploaded_image"])):
+    # for i in range(len(st.session_state["uploaded_image"])):
     st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
         table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
 
     flaves = ', '.join([f"{x[0]}" for x in bests[4]])
     medium = bests[0][0][0]
     if caption.startswith(medium):
         st.session_state["text_result"][st.session_state["processed_image_count"]].code(
             f"\n\n{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")
     else:
         st.session_state["text_result"][st.session_state["processed_image_count"]].code(
             f"\n\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}", language="")
 
     #
-    print ("Finished Interrogating.")
+    print("Finished Interrogating.")
     st.session_state["log_message"].code("Finished Interrogating.", language="")
     #
 
+
 def img2txt():
     data_path = "data/"
 
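When the `optimized` flag is on, the interrogator drops the cached CLIP model after each image instead of keeping it resident; the teardown is just deleting the shared entry and letting Python and CUDA reclaim memory, mirroring `clear_cuda()` above. A small sketch of that cleanup (key name as used in the hunk):

```python
# Sketch of the optimized-mode teardown: drop the shared model, then free memory.
import gc
import torch
from streamlit_server_state import server_state

def unload_cached_model(key="clip_model"):
    if key in server_state:
        del server_state[key]
    gc.collect()               # release the Python-side references
    torch.cuda.empty_cache()   # return cached CUDA blocks to the driver
```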
@@ -251,11 +265,11 @@ def img2txt():
         models.append('ViT-B/32')
     if st.session_state['ViTB16']:
         models.append('ViT-B/16')
     if st.session_state["ViTL14"]:
         models.append('ViT-L/14')
     if st.session_state["ViT-H-14"]:
         models.append('ViT-H-14')
     if st.session_state["ViT-g-14"]:
         models.append('ViT-g-14')
     if st.session_state["ViTL14_336px"]:
         models.append('ViT-L/14@336px')
@@ -270,33 +284,35 @@ def img2txt():
     if st.session_state["RN50x64"]:
         models.append('RN50x64')
 
-    #if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
+    # if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
         #image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
-    #else:
+    # else:
         #image = Image.open(image_path_or_url).convert('RGB')
 
     #thumb = st.session_state["uploaded_image"].image.copy()
     #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
-    #display(thumb)
+    # display(thumb)
 
     st.session_state["processed_image_count"] = 0
 
     for i in range(len(st.session_state["uploaded_image"])):
+
         interrogate(st.session_state["uploaded_image"][i].pil_image, models=models)
         # increase counter.
         st.session_state["processed_image_count"] += 1
     #
 
+
 def layout():
     #set_page_title("Image-to-Text - Stable Diffusion WebUI")
     #st.info("Under Construction. :construction_worker:")
 
     with st.form("img2txt-inputs"):
         st.session_state["generation_mode"] = "img2txt"
 
-        #st.write("---")
+        # st.write("---")
         # creating the page layout using columns
-        col1, col2 = st.columns([1,4], gap="large")
+        col1, col2 = st.columns([1, 4], gap="large")
 
         with col1:
             #url = st.text_area("Input Text","")
@@ -304,68 +320,66 @@ def layout():
             #st.subheader("Input Image")
             st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
 
             st.subheader("CLIP models")
             with st.expander("Stable Diffusion", expanded=True):
                 st.session_state["ViTL14"] = st.checkbox("ViTL14", value=True, help="For StableDiffusion you can just use ViTL14.")
 
             with st.expander("Others"):
                 st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.")
                 st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
                 st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
                 st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.")
                 st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.")
                 st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.")
                 st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.")
                 st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
                 st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
                 st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
                 st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
 
             #
-            #st.subheader("Logs:")
+            # st.subheader("Logs:")
 
             st.session_state["log_message"] = st.empty()
             st.session_state["log_message"].code('', language="")
 
 
         with col2:
             st.subheader("Image")
 
             refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')
 
             if st.session_state["uploaded_image"]:
                 #print (type(st.session_state["uploaded_image"]))
-                #if len(st.session_state["uploaded_image"]) == 1:
+                # if len(st.session_state["uploaded_image"]) == 1:
                 st.session_state["input_image_preview"] = []
                 st.session_state["input_image_preview_container"] = []
                 st.session_state["prediction_table"] = []
                 st.session_state["text_result"] = []
 
                 for i in range(len(st.session_state["uploaded_image"])):
                     st.session_state["input_image_preview_container"].append(i)
-                    st.session_state["input_image_preview_container"][i]= st.empty()
+                    st.session_state["input_image_preview_container"][i] = st.empty()
 
                     with st.session_state["input_image_preview_container"][i].container():
-                        col1_output, col2_output = st.columns([2,10], gap="medium")
+                        col1_output, col2_output = st.columns([2, 10], gap="medium")
                         with col1_output:
                             st.session_state["input_image_preview"].append(i)
-                            st.session_state["input_image_preview"][i]= st.empty()
+                            st.session_state["input_image_preview"][i] = st.empty()
                             st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
 
                             st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
 
                     with st.session_state["input_image_preview_container"][i].container():
-
                         with col2_output:
-
                             st.session_state["prediction_table"].append(i)
                             st.session_state["prediction_table"][i] = st.empty()
                             st.session_state["prediction_table"][i].table()
 
                             st.session_state["text_result"].append(i)
-                            st.session_state["text_result"][i]= st.empty()
+                            st.session_state["text_result"][i] = st.empty()
                             st.session_state["text_result"][i].code("", language="")
 
 
             else:
                 #st.session_state["input_image_preview"].code('', language="")
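The layout hunk is almost entirely a whitespace pass; the underlying pattern it preserves is one set of `st.empty()` placeholders per uploaded file, collected in lists inside `st.session_state`, so the interrogation loop can later fill the matching preview, table, and prompt slot by index. A stripped-down sketch of that placeholder pattern:

```python
# Stripped-down sketch of the per-image placeholder lists built in layout().
import streamlit as st

def make_placeholders(uploaded_files):
    st.session_state["prediction_table"] = []
    st.session_state["text_result"] = []
    for _ in uploaded_files:
        # One table slot and one text slot per image, filled later by interrogate().
        st.session_state["prediction_table"].append(st.empty())
        st.session_state["text_result"].append(st.empty())
```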
@@ -373,13 +387,13 @@ def layout():
 
         #
         # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
-        #generate_col1.title("")
-        #generate_col1.title("")
+        # generate_col1.title("")
+        # generate_col1.title("")
         generate_button = st.form_submit_button("Generate!")
 
         if generate_button:
             # if model, pipe, RealESRGAN or GFPGAN is in st.session_state remove the model and pipe form session_state so that they are reloaded.
             if "model" in st.session_state and st.session_state["defaults"].general.optimized:
                 del st.session_state["model"]
             if "pipe" in st.session_state and st.session_state["defaults"].general.optimized:
                 del st.session_state["pipe"]
@@ -387,7 +401,6 @@ def layout():
                 del st.session_state["RealESRGAN"]
             if "GFPGAN" in st.session_state and st.session_state["defaults"].general.optimized:
                 del st.session_state["GFPGAN"]
 
-
             # run clip interrogator
             img2txt()