Improved the img2txt tab by having more tags.

This commit is contained in:
ZeroCool940711 2022-10-24 00:22:36 -07:00
parent 0a3b6761e2
commit 9caef25390
4 changed files with 179 additions and 798 deletions

View File

@ -34,7 +34,6 @@ maxMessageSize = 200
enableWebsocketCompression = false enableWebsocketCompression = false
[browser] [browser]
serverAddress = "localhost"
gatherUsageStats = false gatherUsageStats = false
serverPort = 8501 serverPort = 8501

File diff suppressed because it is too large Load Diff

View File

@ -3,7 +3,8 @@ Diagrammatic
Geometric Geometric
Architectural Architectural
Analytic Analytic
3D / Anamorphic 3D
Anamorphic
Pencil Pencil
Color Pencil Color Pencil
Charcoal Charcoal
@ -38,7 +39,6 @@ Photorealism
Photo realistic Photo realistic
Doodling Doodling
Wordtoons Wordtoons
Tattoo
Cartoon Cartoon
Anime Anime
Manga Manga
@ -48,21 +48,16 @@ Calligraphy
Mosaic Mosaic
Figurative Figurative
Anatomy Anatomy
Caricature
Life Life
Still life Still life
Portrait Portrait
Landscape Landscape
Perspective Perspective
Cartoon
Funny Funny
Surreal Surreal
Wall Mural Wall Mural
Street Street
3D
Anamorphic
Realistic Realistic
Photo Realistic Photo Realistic
Hyper Realistic Hyper Realistic
Doodle Doodle
Scribble

View File

@ -135,6 +135,10 @@ def clear_cuda():
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size): def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
#logger.info("Ranking")
#st.session_state["log"].append("Ranking")
#st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
batch_size = min(batch_size, len(text_array)) batch_size = min(batch_size, len(text_array))
batch_count = int(len(text_array) / batch_size) batch_count = int(len(text_array) / batch_size)
batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)] batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
@ -166,7 +170,7 @@ def interrogate(image, models):
return return
table = [] table = []
bests = [[('', 0)]]*5 bests = [[('', 0)]]*7
logger.info("Ranking Text") logger.info("Ranking Text")
@ -219,8 +223,8 @@ def interrogate(image, models):
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["trending_list"])) ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["trending_list"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["movements"])) ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["movements"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["flavors"])) ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["flavors"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["domains"])) #ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["domains"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subreddits"])) #ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subreddits"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["techniques"])) ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["techniques"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["tags"])) ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["tags"]))
@ -254,7 +258,7 @@ def interrogate(image, models):
# for i in range(len(st.session_state["uploaded_image"])): # for i in range(len(st.session_state["uploaded_image"])):
st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame( st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Domains", "Subreddits", "Techniques", "Tags"])) table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"]))
flaves = ', '.join([f"{x[0]}" for x in bests[4]]) flaves = ', '.join([f"{x[0]}" for x in bests[4]])
medium = bests[0][0][0] medium = bests[0][0][0]
@ -283,8 +287,8 @@ def img2txt():
server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt')) server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt')) server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt')) server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt')) #server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt')) #server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt')) server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt')) server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
#server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt')) #server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))