Update img2txt.py

This commit is contained in:
hlky 2022-10-24 11:48:45 +01:00
parent 70ef109d17
commit 8dcd432674
No known key found for this signature in database
GPG Key ID: 55A99F1E80D907D5

View File

@ -61,16 +61,9 @@ from ldm.models.blip import blip_decoder
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
blip_image_eval_size = 512 blip_image_eval_size = 512
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
server_state["clip_models"] = {}
server_state["preprocesses"] = {}
st.session_state["log"] = []
def load_blip_model(): def load_blip_model():
logger.info("Loading BLIP Model") logger.info("Loading BLIP Model")
st.session_state["log"].append("Loading BLIP Model")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
if "blip_model" not in server_state: if "blip_model" not in server_state:
with server_state_lock['blip_model']: with server_state_lock['blip_model']:
@ -79,18 +72,12 @@ def load_blip_model():
server_state["blip_model"] = server_state["blip_model"].eval() server_state["blip_model"] = server_state["blip_model"].eval()
#if not st.session_state["defaults"].general.optimized:
server_state["blip_model"] = server_state["blip_model"].to(device).half() server_state["blip_model"] = server_state["blip_model"].to(device).half()
logger.info("BLIP Model Loaded") logger.info("BLIP Model Loaded")
st.session_state["log"].append("BLIP Model Loaded")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
else: else:
logger.info("BLIP Model already loaded") logger.info("BLIP Model already loaded")
st.session_state["log"].append("BLIP Model already loaded")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
#return server_state["blip_model"]
def generate_caption(pil_image): def generate_caption(pil_image):
@ -105,7 +92,6 @@ def generate_caption(pil_image):
with torch.no_grad(): with torch.no_grad():
caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5) caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
#print (caption)
return caption[0] return caption[0]
def load_list(filename): def load_list(filename):
@ -135,10 +121,6 @@ def clear_cuda():
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size): def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
#logger.info("Ranking")
#st.session_state["log"].append("Ranking")
#st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
batch_size = min(batch_size, len(text_array)) batch_size = min(batch_size, len(text_array))
batch_count = int(len(text_array) / batch_size) batch_count = int(len(text_array) / batch_size)
batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)] batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
@ -148,13 +130,9 @@ def batch_rank(model, image_features, text_array, batch_size=st.session_state["d
return ranks return ranks
def interrogate(image, models): def interrogate(image, models):
#server_state["blip_model"] =
load_blip_model() load_blip_model()
logger.info("Generating Caption") logger.info("Generating Caption")
st.session_state["log"].append("Generating Caption")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
caption = generate_caption(image) caption = generate_caption(image)
if st.session_state["defaults"].general.optimized: if st.session_state["defaults"].general.optimized:
@ -162,8 +140,6 @@ def interrogate(image, models):
clear_cuda() clear_cuda()
logger.info("Caption Generated") logger.info("Caption Generated")
st.session_state["log"].append("Caption Generated")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
if len(models) == 0: if len(models) == 0:
logger.info(f"\n\n{caption}") logger.info(f"\n\n{caption}")
@ -174,16 +150,9 @@ def interrogate(image, models):
logger.info("Ranking Text") logger.info("Ranking Text")
#if "clip_model" in server_state:
#print (server_state["clip_model"])
#print (st.session_state["log_message"])
for model_name in models: for model_name in models:
with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16): with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
logger.info(f"Interrogating with {model_name}...") logger.info(f"Interrogating with {model_name}...")
st.session_state["log"].append(f"Interrogating with {model_name}...")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
if model_name not in server_state["clip_models"]: if model_name not in server_state["clip_models"]:
if not st.session_state["defaults"].img2txt.keep_all_models_loaded: if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
@ -196,15 +165,14 @@ def interrogate(image, models):
del server_state["preprocesses"][model] del server_state["preprocesses"][model]
clear_cuda() clear_cuda()
if model_name == 'ViT-H-14': if model_name == 'ViT-H-14':
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \
pretrained='laion2b_s32b_b79k', open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='models/clip')
cache_dir='models/clip')
elif model_name == 'ViT-g-14': elif model_name == 'ViT-g-14':
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \
pretrained='laion2b_s12b_b42k', open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='models/clip')
cache_dir='models/clip')
else: else:
server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = clip.load(model_name, device=device, download_root='models/clip') server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = \
clip.load(model_name, device=device, download_root='models/clip')
server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval() server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval()
images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda() images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda()
@ -261,7 +229,6 @@ def interrogate(image, models):
del server_state["clip_models"][model_name] del server_state["clip_models"][model_name]
gc.collect() gc.collect()
# for i in range(len(st.session_state["uploaded_image"])):
st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame( st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"])) table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"]))
@ -283,38 +250,9 @@ def interrogate(image, models):
st.session_state["text_result"][st.session_state["processed_image_count"]].code( st.session_state["text_result"][st.session_state["processed_image_count"]].code(
f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="") f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
#
logger.info("Finished Interrogating.") logger.info("Finished Interrogating.")
st.session_state["log"].append("Finished Interrogating.")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
st.session_state["log"] = []
#
def img2txt(): def img2txt():
data_path = "data/"
server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt'))
server_state["flavors"] = load_list(os.path.join(data_path, 'img2txt', 'flavors.txt'))
server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
#server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
#server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
#server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))
# server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt'))
# server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt'))
server_state["trending_list"] = [site for site in server_state["sites"]]
server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]])
server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]])
#image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg"
models = [] models = []
if st.session_state["ViT-L/14"]: if st.session_state["ViT-L/14"]:
@ -364,7 +302,36 @@ def img2txt():
def layout(): def layout():
#set_page_title("Image-to-Text - Stable Diffusion WebUI") #set_page_title("Image-to-Text - Stable Diffusion WebUI")
#st.info("Under Construction. :construction_worker:") #st.info("Under Construction. :construction_worker:")
#
if "clip_models" not in server_state:
server_state["clip_models"] = {}
if "preprocesses" not in server_state:
server_state["preprocesses"] = {}
data_path = "data/"
if "artists" not in server_state:
server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt'))
if "flavors" not in server_state:
server_state["flavors"] = random.choices(load_list(os.path.join(data_path, 'img2txt', 'flavors.txt')), k=2000)
if "mediums" not in server_state:
server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
if "movements" not in server_state:
server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
if "sites" not in server_state:
server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
#server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
#server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
if "techniques" not in server_state:
server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
if "tags" not in server_state:
server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
#server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))
# server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt'))
# server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt'))
if "trending_list" not in server_state:
server_state["trending_list"] = [site for site in server_state["sites"]]
server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]])
server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]])
with st.form("img2txt-inputs"): with st.form("img2txt-inputs"):
st.session_state["generation_mode"] = "img2txt" st.session_state["generation_mode"] = "img2txt"