mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2025-01-06 03:20:16 +03:00
Added more tags to the flavors.txt file for img2txt. (#1585)
- Changed the default batch size for img2txt from 420 to 2000. - Added logging back to the img2txt tab. - Added more tags to the flavors.txt file for img2txt.
This commit is contained in:
commit
bb7fce1a87
@ -304,7 +304,7 @@ img2img:
|
||||
write_info_files: True
|
||||
|
||||
img2txt:
|
||||
batch_size: 420
|
||||
batch_size: 2000
|
||||
blip_image_eval_size: 512
|
||||
keep_all_models_loaded: False
|
||||
|
||||
|
37025
data/img2txt/flavors.txt
37025
data/img2txt/flavors.txt
File diff suppressed because it is too large
Load Diff
@ -62,8 +62,12 @@ from ldm.models.blip import blip_decoder
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
blip_image_eval_size = 512
|
||||
|
||||
st.session_state["log"] = []
|
||||
|
||||
def load_blip_model():
|
||||
logger.info("Loading BLIP Model")
|
||||
st.session_state["log"].append("Loading BLIP Model")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
if "blip_model" not in server_state:
|
||||
with server_state_lock['blip_model']:
|
||||
@ -75,8 +79,12 @@ def load_blip_model():
|
||||
server_state["blip_model"] = server_state["blip_model"].to(device).half()
|
||||
|
||||
logger.info("BLIP Model Loaded")
|
||||
st.session_state["log"].append("BLIP Model Loaded")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
else:
|
||||
logger.info("BLIP Model already loaded")
|
||||
st.session_state["log"].append("BLIP Model already loaded")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
|
||||
def generate_caption(pil_image):
|
||||
@ -133,6 +141,8 @@ def interrogate(image, models):
|
||||
load_blip_model()
|
||||
|
||||
logger.info("Generating Caption")
|
||||
st.session_state["log"].append("Generating Caption")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
caption = generate_caption(image)
|
||||
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
@ -140,6 +150,8 @@ def interrogate(image, models):
|
||||
clear_cuda()
|
||||
|
||||
logger.info("Caption Generated")
|
||||
st.session_state["log"].append("Caption Generated")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
if len(models) == 0:
|
||||
logger.info(f"\n\n{caption}")
|
||||
@ -149,10 +161,14 @@ def interrogate(image, models):
|
||||
bests = [[('', 0)]]*7
|
||||
|
||||
logger.info("Ranking Text")
|
||||
st.session_state["log"].append("Ranking Text")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
for model_name in models:
|
||||
with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
|
||||
logger.info(f"Interrogating with {model_name}...")
|
||||
st.session_state["log"].append(f"Interrogating with {model_name}...")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
if model_name not in server_state["clip_models"]:
|
||||
if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
|
||||
@ -220,9 +236,14 @@ def interrogate(image, models):
|
||||
best = best[:3]
|
||||
|
||||
row = [model_name]
|
||||
for rank in ranks:
|
||||
rank.sort(key=lambda x: x[1], reverse=True)
|
||||
row.append(f'{rank[0][0]} {rank[0][1]:.2f}%')
|
||||
|
||||
for r in ranks:
|
||||
row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
|
||||
|
||||
#for rank in ranks:
|
||||
# rank.sort(key=lambda x: x[1], reverse=True)
|
||||
# row.append(f'{rank[0][0]} {rank[0][1]:.2f}%')
|
||||
|
||||
table.append(row)
|
||||
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
@ -251,6 +272,9 @@ def interrogate(image, models):
|
||||
f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
|
||||
|
||||
logger.info("Finished Interrogating.")
|
||||
st.session_state["log"].append("Finished Interrogating.")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
|
||||
def img2txt():
|
||||
models = []
|
||||
|
Loading…
Reference in New Issue
Block a user