Improved the img2txt tab by having more tags.

This commit is contained in:
ZeroCool940711 2022-10-24 00:22:36 -07:00
parent 0a3b6761e2
commit 9caef25390
4 changed files with 179 additions and 798 deletions

View File

@ -34,7 +34,6 @@ maxMessageSize = 200
enableWebsocketCompression = false
[browser]
serverAddress = "localhost"
gatherUsageStats = false
serverPort = 8501

File diff suppressed because it is too large Load Diff

View File

@ -3,7 +3,8 @@ Diagrammatic
Geometric
Architectural
Analytic
3D / Anamorphic
3D
Anamorphic
Pencil
Color Pencil
Charcoal
@ -38,7 +39,6 @@ Photorealism
Photo realistic
Doodling
Wordtoons
Tattoo
Cartoon
Anime
Manga
@ -48,21 +48,16 @@ Calligraphy
Mosaic
Figurative
Anatomy
Caricature
Life
Still life
Portrait
Landscape
Perspective
Cartoon
Funny
Surreal
Wall Mural
Street
3D
Anamorphic
Realistic
Photo Realistic
Hyper Realistic
Doodle
Scribble

View File

@ -135,6 +135,10 @@ def clear_cuda():
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
#logger.info("Ranking")
#st.session_state["log"].append("Ranking")
#st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
batch_size = min(batch_size, len(text_array))
batch_count = int(len(text_array) / batch_size)
batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
@ -166,7 +170,7 @@ def interrogate(image, models):
return
table = []
bests = [[('', 0)]]*5
bests = [[('', 0)]]*7
logger.info("Ranking Text")
@ -219,8 +223,8 @@ def interrogate(image, models):
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["trending_list"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["movements"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["flavors"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["domains"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subreddits"]))
#ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["domains"]))
#ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subreddits"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["techniques"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["tags"]))
@ -254,7 +258,7 @@ def interrogate(image, models):
# for i in range(len(st.session_state["uploaded_image"])):
st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Domains", "Subreddits", "Techniques", "Tags"]))
table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"]))
flaves = ', '.join([f"{x[0]}" for x in bests[4]])
medium = bests[0][0][0]
@ -283,8 +287,8 @@ def img2txt():
server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
#server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
#server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
#server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))