Added more tags to the flavors.txt file for img2txt. (#1585)

- Changed the default batch size for img2txt from 420 to 2000.
- Added logging back to the img2txt tab.
- Added more tags to the flavors.txt file for img2txt.
This commit is contained in:
Alejandro Gil 2022-10-24 05:11:04 -07:00 committed by GitHub
commit bb7fce1a87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36792 additions and 265 deletions

View File

@ -304,7 +304,7 @@ img2img:
write_info_files: True write_info_files: True
img2txt: img2txt:
batch_size: 420 batch_size: 2000
blip_image_eval_size: 512 blip_image_eval_size: 512
keep_all_models_loaded: False keep_all_models_loaded: False

File diff suppressed because it is too large Load Diff

View File

@ -62,8 +62,12 @@ from ldm.models.blip import blip_decoder
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
blip_image_eval_size = 512 blip_image_eval_size = 512
st.session_state["log"] = []
def load_blip_model(): def load_blip_model():
logger.info("Loading BLIP Model") logger.info("Loading BLIP Model")
st.session_state["log"].append("Loading BLIP Model")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
if "blip_model" not in server_state: if "blip_model" not in server_state:
with server_state_lock['blip_model']: with server_state_lock['blip_model']:
@ -75,8 +79,12 @@ def load_blip_model():
server_state["blip_model"] = server_state["blip_model"].to(device).half() server_state["blip_model"] = server_state["blip_model"].to(device).half()
logger.info("BLIP Model Loaded") logger.info("BLIP Model Loaded")
st.session_state["log"].append("BLIP Model Loaded")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
else: else:
logger.info("BLIP Model already loaded") logger.info("BLIP Model already loaded")
st.session_state["log"].append("BLIP Model already loaded")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
def generate_caption(pil_image): def generate_caption(pil_image):
@ -133,6 +141,8 @@ def interrogate(image, models):
load_blip_model() load_blip_model()
logger.info("Generating Caption") logger.info("Generating Caption")
st.session_state["log"].append("Generating Caption")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
caption = generate_caption(image) caption = generate_caption(image)
if st.session_state["defaults"].general.optimized: if st.session_state["defaults"].general.optimized:
@ -140,6 +150,8 @@ def interrogate(image, models):
clear_cuda() clear_cuda()
logger.info("Caption Generated") logger.info("Caption Generated")
st.session_state["log"].append("Caption Generated")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
if len(models) == 0: if len(models) == 0:
logger.info(f"\n\n{caption}") logger.info(f"\n\n{caption}")
@ -149,10 +161,14 @@ def interrogate(image, models):
bests = [[('', 0)]]*7 bests = [[('', 0)]]*7
logger.info("Ranking Text") logger.info("Ranking Text")
st.session_state["log"].append("Ranking Text")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
for model_name in models: for model_name in models:
with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16): with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
logger.info(f"Interrogating with {model_name}...") logger.info(f"Interrogating with {model_name}...")
st.session_state["log"].append(f"Interrogating with {model_name}...")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
if model_name not in server_state["clip_models"]: if model_name not in server_state["clip_models"]:
if not st.session_state["defaults"].img2txt.keep_all_models_loaded: if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
@ -220,9 +236,14 @@ def interrogate(image, models):
best = best[:3] best = best[:3]
row = [model_name] row = [model_name]
for rank in ranks:
rank.sort(key=lambda x: x[1], reverse=True) for r in ranks:
row.append(f'{rank[0][0]} {rank[0][1]:.2f}%') row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
#for rank in ranks:
# rank.sort(key=lambda x: x[1], reverse=True)
# row.append(f'{rank[0][0]} {rank[0][1]:.2f}%')
table.append(row) table.append(row)
if st.session_state["defaults"].general.optimized: if st.session_state["defaults"].general.optimized:
@ -251,6 +272,9 @@ def interrogate(image, models):
f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="") f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
logger.info("Finished Interrogating.") logger.info("Finished Interrogating.")
st.session_state["log"].append("Finished Interrogating.")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
def img2txt(): def img2txt():
models = [] models = []