img2txt speed + vram issues

Commit c1e97c8bcc (parent 7985c97c69), mirror of https://github.com/Sygil-Dev/sygil-webui.git
@@ -302,6 +302,7 @@ img2img:
img2txt:
    batch_size: 420
    blip_image_eval_size: 512
    keep_all_models_loaded: False

concepts_library:
    concepts_per_page: 12
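For context, these keys surface in the Python hunks below as attributes of st.session_state["defaults"] (for example st.session_state["defaults"].img2txt.batch_size). A minimal sketch of that wiring, assuming the webui loads its YAML defaults with OmegaConf; the file path here is illustrative, not taken from the commit:

from omegaconf import OmegaConf
import streamlit as st

# Illustrative path: load the YAML defaults so nested keys become attribute access.
defaults = OmegaConf.load("configs/webui/webui_streamlit.yaml")
st.session_state["defaults"] = defaults

batch_size = st.session_state["defaults"].img2txt.batch_size               # 420
keep_loaded = st.session_state["defaults"].img2txt.keep_all_models_loaded  # False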
@@ -152,13 +152,11 @@ def generate_caption(pil_image):
    #print (caption)
    return caption[0]


def load_list(filename):
    with open(filename, 'r', encoding='utf-8', errors='replace') as f:
        items = [line.strip() for line in f.readlines()]
    return items


def rank(model, image_features, text_array, top_count=1):
    top_count = min(top_count, len(text_array))
    text_tokens = clip.tokenize([text for text in text_array]).cuda()
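The hunk cuts off inside rank(), so the scoring of the candidate texts is not shown here. For orientation, a typical CLIP text-ranking body looks roughly like the sketch below (encode, normalize, cosine similarity, softmax, top-k); this is a generic sketch, not the project's exact code:

import torch
import clip

def rank_sketch(model, image_features, text_array, top_count=1):
    top_count = min(top_count, len(text_array))
    text_tokens = clip.tokenize(text_array).cuda()
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Cosine similarity of the (already normalized) image features against each text,
    # softmaxed so the scores across text_array sum to 1.
    similarity = (image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_indices = similarity[0].topk(top_count)
    return [(text_array[int(i)], float(p) * 100) for p, i in zip(top_probs, top_indices)]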
@@ -181,9 +179,9 @@ def clear_cuda():


def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
    batch_count = len(text_array)
    batch_size = min(batch_size, len(text_array))
    batch_count = int(len(text_array) / batch_size)
    batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
    batches.append(text_array[batch_count*batch_size:])
    ranks = []
    for batch in batches:
        ranks += rank(model, image_features, batch)
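One edge case worth noting in batch_rank() above: when len(text_array) is an exact multiple of batch_size, the final batches.append(...) adds an empty list. A hedged sketch of an equivalent loop that never hands rank() an empty batch (the function name is illustrative, not from the commit):

def batch_rank_sketch(model, image_features, text_array, batch_size=420):
    if not text_array:
        return []
    # Clamp the batch size, then walk the candidate list in fixed-size slices.
    batch_size = min(batch_size, len(text_array))
    ranks = []
    for start in range(0, len(text_array), batch_size):
        ranks += rank(model, image_features, text_array[start:start + batch_size])
    return ranks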
@@ -225,6 +223,15 @@ def interrogate(image, models):
        st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')

        if model_name not in server_state["clip_models"]:
            if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
                model_to_delete = []
                for model in server_state["clip_models"]:
                    if model != model_name:
                        model_to_delete.append(model)
                for model in model_to_delete:
                    del server_state["clip_models"][model]
                    del server_state["preprocesses"][model]
                clear_cuda()
            if model_name == 'ViT-H-14':
                server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='models/clip')
            elif model_name == 'ViT-g-14':
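The new block above is the VRAM side of this commit: when keep_all_models_loaded is False, every CLIP model other than the one about to run is dropped from server_state before the next model is created. clear_cuda() itself sits outside this hunk; a typical implementation is a short sketch like the following, assuming PyTorch:

import gc
import torch

def clear_cuda_sketch():
    # Release Python references first, then ask PyTorch to return cached CUDA blocks.
    gc.collect()
    torch.cuda.empty_cache()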
@@ -316,28 +323,12 @@ def img2txt():

    models = []

    if st.session_state["ViTB32"]:
        models.append('ViT-B/32')
    if st.session_state['ViTB16']:
        models.append('ViT-B/16')
    if st.session_state["ViTL14"]:
    if st.session_state["ViT-L/14"]:
        models.append('ViT-L/14')
    if st.session_state["ViT-H-14"]:
        models.append('ViT-H-14')
    if st.session_state["ViT-g-14"]:
        models.append('ViT-g-14')
    if st.session_state["ViTL14_336px"]:
        models.append('ViT-L/14@336px')
    if st.session_state["RN101"]:
        models.append('RN101')
    if st.session_state["RN50"]:
        models.append('RN50')
    if st.session_state["RN50x4"]:
        models.append('RN50x4')
    if st.session_state["RN50x16"]:
        models.append('RN50x16')
    if st.session_state["RN50x64"]:
        models.append('RN50x64')

    # if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
    #image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
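The repeated if/append chain above could also be driven by a table keyed on the same checkbox names used in layout(); a sketch of that refactor (the dict below is illustrative, not part of the commit):

# Checkbox key in st.session_state -> CLIP model name passed to interrogate().
MODEL_CHECKBOXES = {
    "ViTB32": "ViT-B/32",
    "ViTB16": "ViT-B/16",
    "ViT-L/14": "ViT-L/14",
    "ViT-H-14": "ViT-H-14",
    "ViT-g-14": "ViT-g-14",
    "ViTL14_336px": "ViT-L/14@336px",
    "RN101": "RN101",
    "RN50": "RN50",
    "RN50x4": "RN50x4",
    "RN50x16": "RN50x16",
    "RN50x64": "RN50x64",
}

models = [name for key, name in MODEL_CHECKBOXES.items() if st.session_state.get(key)]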
@@ -375,22 +366,10 @@ def layout():
    #st.subheader("Input Image")
    st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)

    st.subheader("CLIP models")
    with st.expander("Stable Diffusion", expanded=True):
        st.session_state["ViTL14"] = st.checkbox("ViTL14", value=True, help="For StableDiffusion you can just use ViTL14.")

    with st.expander("Others"):
        st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.")
    with st.expander("CLIP models", expanded=True):
        st.session_state["ViT-L/14"] = st.checkbox("ViT-L/14", value=True, help="ViT-L/14 model.")
        st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
        st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
        st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.")
        st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.")
        st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.")
        st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.")
        st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
        st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
        st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
        st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")

    #
    # st.subheader("Logs:")
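Side note on the pattern above: each st.checkbox return value is written into st.session_state by hand. Streamlit can also store the widget state itself when a key= is passed; a one-line sketch of that alternative, not part of the commit:

# Equivalent alternative: Streamlit stores the widget state under the given key.
st.checkbox("ViT-L/14", value=True, key="ViT-L/14", help="ViT-L/14 model.")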
@@ -448,14 +427,14 @@ def layout():

    if generate_button:
        # if model, pipe, RealESRGAN or GFPGAN is in st.session_state remove the model and pipe from session_state so that they are reloaded.
        if "model" in st.session_state and st.session_state["defaults"].general.optimized:
            del st.session_state["model"]
        if "pipe" in st.session_state and st.session_state["defaults"].general.optimized:
            del st.session_state["pipe"]
        if "RealESRGAN" in st.session_state and st.session_state["defaults"].general.optimized:
            del st.session_state["RealESRGAN"]
        if "GFPGAN" in st.session_state and st.session_state["defaults"].general.optimized:
            del st.session_state["GFPGAN"]
        if "model" in server_state and st.session_state["defaults"].general.optimized:
            del server_state["model"]
        if "pipe" in server_state and st.session_state["defaults"].general.optimized:
            del server_state["pipe"]
        if "RealESRGAN" in server_state and st.session_state["defaults"].general.optimized:
            del server_state["RealESRGAN"]
        if "GFPGAN" in server_state and st.session_state["defaults"].general.optimized:
            del server_state["GFPGAN"]

        # run clip interrogator
        img2txt()
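The optimized-mode cleanup above repeats the same membership test for each model in each of the two stores; functionally it could be collapsed into one helper, sketched below (the helper name is illustrative, not from the commit):

def free_loaded_models_sketch():
    # In optimized mode, drop any generation/upscaler models held in either store
    # so they are reloaded later and the CLIP interrogator gets the VRAM now.
    if not st.session_state["defaults"].general.optimized:
        return
    for store in (st.session_state, server_state):
        for name in ("model", "pipe", "RealESRGAN", "GFPGAN"):
            if name in store:
                del store[name]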