mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-15 06:21:34 +03:00
Update img2txt.py
This commit is contained in:
parent
70ef109d17
commit
8dcd432674
@ -61,16 +61,9 @@ from ldm.models.blip import blip_decoder
|
||||
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
blip_image_eval_size = 512
|
||||
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
|
||||
server_state["clip_models"] = {}
|
||||
server_state["preprocesses"] = {}
|
||||
|
||||
st.session_state["log"] = []
|
||||
|
||||
def load_blip_model():
|
||||
logger.info("Loading BLIP Model")
|
||||
st.session_state["log"].append("Loading BLIP Model")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
if "blip_model" not in server_state:
|
||||
with server_state_lock['blip_model']:
|
||||
@ -79,18 +72,12 @@ def load_blip_model():
|
||||
|
||||
server_state["blip_model"] = server_state["blip_model"].eval()
|
||||
|
||||
#if not st.session_state["defaults"].general.optimized:
|
||||
server_state["blip_model"] = server_state["blip_model"].to(device).half()
|
||||
|
||||
logger.info("BLIP Model Loaded")
|
||||
st.session_state["log"].append("BLIP Model Loaded")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
else:
|
||||
logger.info("BLIP Model already loaded")
|
||||
st.session_state["log"].append("BLIP Model already loaded")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
#return server_state["blip_model"]
|
||||
|
||||
def generate_caption(pil_image):
|
||||
|
||||
@ -105,7 +92,6 @@ def generate_caption(pil_image):
|
||||
with torch.no_grad():
|
||||
caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
|
||||
|
||||
#print (caption)
|
||||
return caption[0]
|
||||
|
||||
def load_list(filename):
|
||||
@ -135,10 +121,6 @@ def clear_cuda():
|
||||
|
||||
|
||||
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
|
||||
#logger.info("Ranking")
|
||||
#st.session_state["log"].append("Ranking")
|
||||
#st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
batch_size = min(batch_size, len(text_array))
|
||||
batch_count = int(len(text_array) / batch_size)
|
||||
batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
|
||||
@ -148,13 +130,9 @@ def batch_rank(model, image_features, text_array, batch_size=st.session_state["d
|
||||
return ranks
|
||||
|
||||
def interrogate(image, models):
|
||||
|
||||
#server_state["blip_model"] =
|
||||
load_blip_model()
|
||||
|
||||
logger.info("Generating Caption")
|
||||
st.session_state["log"].append("Generating Caption")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
caption = generate_caption(image)
|
||||
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
@ -162,8 +140,6 @@ def interrogate(image, models):
|
||||
clear_cuda()
|
||||
|
||||
logger.info("Caption Generated")
|
||||
st.session_state["log"].append("Caption Generated")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
if len(models) == 0:
|
||||
logger.info(f"\n\n{caption}")
|
||||
@ -174,16 +150,9 @@ def interrogate(image, models):
|
||||
|
||||
logger.info("Ranking Text")
|
||||
|
||||
#if "clip_model" in server_state:
|
||||
#print (server_state["clip_model"])
|
||||
|
||||
#print (st.session_state["log_message"])
|
||||
|
||||
for model_name in models:
|
||||
with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
|
||||
logger.info(f"Interrogating with {model_name}...")
|
||||
st.session_state["log"].append(f"Interrogating with {model_name}...")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
if model_name not in server_state["clip_models"]:
|
||||
if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
|
||||
@ -196,15 +165,14 @@ def interrogate(image, models):
|
||||
del server_state["preprocesses"][model]
|
||||
clear_cuda()
|
||||
if model_name == 'ViT-H-14':
|
||||
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name,
|
||||
pretrained='laion2b_s32b_b79k',
|
||||
cache_dir='models/clip')
|
||||
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \
|
||||
open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='models/clip')
|
||||
elif model_name == 'ViT-g-14':
|
||||
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name,
|
||||
pretrained='laion2b_s12b_b42k',
|
||||
cache_dir='models/clip')
|
||||
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \
|
||||
open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='models/clip')
|
||||
else:
|
||||
server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = clip.load(model_name, device=device, download_root='models/clip')
|
||||
server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = \
|
||||
clip.load(model_name, device=device, download_root='models/clip')
|
||||
server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval()
|
||||
|
||||
images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda()
|
||||
@ -261,7 +229,6 @@ def interrogate(image, models):
|
||||
del server_state["clip_models"][model_name]
|
||||
gc.collect()
|
||||
|
||||
# for i in range(len(st.session_state["uploaded_image"])):
|
||||
st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
|
||||
table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"]))
|
||||
|
||||
@ -283,38 +250,9 @@ def interrogate(image, models):
|
||||
st.session_state["text_result"][st.session_state["processed_image_count"]].code(
|
||||
f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
|
||||
|
||||
#
|
||||
logger.info("Finished Interrogating.")
|
||||
st.session_state["log"].append("Finished Interrogating.")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
st.session_state["log"] = []
|
||||
#
|
||||
|
||||
|
||||
def img2txt():
|
||||
data_path = "data/"
|
||||
|
||||
server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt'))
|
||||
server_state["flavors"] = load_list(os.path.join(data_path, 'img2txt', 'flavors.txt'))
|
||||
server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
|
||||
server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
|
||||
server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
|
||||
#server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
|
||||
#server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
|
||||
server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
|
||||
server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
|
||||
#server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))
|
||||
# server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt'))
|
||||
# server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt'))
|
||||
|
||||
server_state["trending_list"] = [site for site in server_state["sites"]]
|
||||
server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
|
||||
server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]])
|
||||
server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]])
|
||||
|
||||
#image_path_or_url = "https://i.redd.it/e2e8gimigjq91.jpg"
|
||||
|
||||
models = []
|
||||
|
||||
if st.session_state["ViT-L/14"]:
|
||||
@ -364,7 +302,36 @@ def img2txt():
|
||||
def layout():
|
||||
#set_page_title("Image-to-Text - Stable Diffusion WebUI")
|
||||
#st.info("Under Construction. :construction_worker:")
|
||||
|
||||
#
|
||||
if "clip_models" not in server_state:
|
||||
server_state["clip_models"] = {}
|
||||
if "preprocesses" not in server_state:
|
||||
server_state["preprocesses"] = {}
|
||||
data_path = "data/"
|
||||
if "artists" not in server_state:
|
||||
server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt'))
|
||||
if "flavors" not in server_state:
|
||||
server_state["flavors"] = random.choices(load_list(os.path.join(data_path, 'img2txt', 'flavors.txt')), k=2000)
|
||||
if "mediums" not in server_state:
|
||||
server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
|
||||
if "movements" not in server_state:
|
||||
server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
|
||||
if "sites" not in server_state:
|
||||
server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
|
||||
#server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
|
||||
#server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
|
||||
if "techniques" not in server_state:
|
||||
server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
|
||||
if "tags" not in server_state:
|
||||
server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
|
||||
#server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))
|
||||
# server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt'))
|
||||
# server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt'))
|
||||
if "trending_list" not in server_state:
|
||||
server_state["trending_list"] = [site for site in server_state["sites"]]
|
||||
server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
|
||||
server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]])
|
||||
server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]])
|
||||
with st.form("img2txt-inputs"):
|
||||
st.session_state["generation_mode"] = "img2txt"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user