img2txt speed + vram issues

hlky 2022-10-05 05:31:20 +01:00
parent 7985c97c69
commit c1e97c8bcc
2 changed files with 23 additions and 43 deletions


@@ -302,6 +302,7 @@ img2img:
img2txt:
    batch_size: 420
    blip_image_eval_size: 512
    keep_all_models_loaded: False
concepts_library:
    concepts_per_page: 12
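For reference, a minimal sketch of how the new flag would be read at runtime. This assumes an OmegaConf-style defaults object (an assumption; the diff only shows the YAML keys and the later st.session_state["defaults"] lookups):

# Minimal sketch (assumption: OmegaConf-style defaults mirroring the keys above).
from omegaconf import OmegaConf

defaults = OmegaConf.create("""
img2txt:
  batch_size: 420
  blip_image_eval_size: 512
  keep_all_models_loaded: False
""")

# The new flag decides whether unused CLIP models stay resident between runs.
if not defaults.img2txt.keep_all_models_loaded:
    print("unused CLIP models will be unloaded to free VRAM")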


@@ -152,13 +152,11 @@ def generate_caption(pil_image):
    #print (caption)
    return caption[0]
def load_list(filename):
    with open(filename, 'r', encoding='utf-8', errors='replace') as f:
        items = [line.strip() for line in f.readlines()]
    return items
def rank(model, image_features, text_array, top_count=1):
    top_count = min(top_count, len(text_array))
    text_tokens = clip.tokenize([text for text in text_array]).cuda()
@@ -181,9 +179,9 @@ def clear_cuda():
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
    batch_count = len(text_array)
    batch_size = min(batch_size, len(text_array))
    batch_count = int(len(text_array) / batch_size)
    batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
    batches.append(text_array[batch_count*batch_size:])
    ranks = []
    for batch in batches:
        ranks += rank(model, image_features, batch)
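As a standalone illustration of the batching scheme this hunk introduces, here is a sketch decoupled from Streamlit and CLIP. Assumptions: rank is reduced to a plain callable, and an extra guard skips the remainder batch when the split is exact (the diff itself does not add that guard):

# Sketch of the new batching logic: cap the batch size, split into full
# batches, and keep the remainder as a final partial batch.
def batch_rank_sketch(rank, text_array, batch_size=420):
    batch_size = min(batch_size, len(text_array))   # never ask for more than we have
    batch_count = len(text_array) // batch_size     # number of full batches
    batches = [text_array[i * batch_size:(i + 1) * batch_size] for i in range(batch_count)]
    remainder = text_array[batch_count * batch_size:]
    if remainder:                                    # guard added in this sketch, not in the diff
        batches.append(remainder)
    ranks = []
    for batch in batches:
        ranks += rank(batch)
    return ranks

# Example with a dummy scorer: three items, batch size two -> one full batch plus a remainder.
print(batch_rank_sketch(lambda batch: [len(t) for t in batch], ["a", "bb", "ccc"], batch_size=2))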
@@ -225,6 +223,15 @@ def interrogate(image, models):
    st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
    if model_name not in server_state["clip_models"]:
        if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
            model_to_delete = []
            for model in server_state["clip_models"]:
                if model != model_name:
                    model_to_delete.append(model)
            for model in model_to_delete:
                del server_state["clip_models"][model]
                del server_state["preprocesses"][model]
                clear_cuda()
        if model_name == 'ViT-H-14':
            server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='models/clip')
        elif model_name == 'ViT-g-14':
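The block added above is what frees VRAM between models. A self-contained sketch of the same idea, with plain dicts standing in for server_state and a placeholder clear_cuda (both are assumptions; the real script uses Streamlit server state and torch):

# Sketch of the unload-before-load behaviour added by this commit.
def clear_cuda():
    # The real helper would also free GPU memory (e.g. torch.cuda.empty_cache());
    # kept as a no-op so the sketch runs without a GPU.
    pass

def drop_other_models(clip_models, preprocesses, keep, keep_all_models_loaded):
    """Delete every cached CLIP model except `keep`, unless the config keeps them all."""
    if keep_all_models_loaded:
        return
    for name in [m for m in clip_models if m != keep]:
        del clip_models[name]
        del preprocesses[name]
        clear_cuda()

clip_models = {"ViT-L/14": object(), "ViT-H-14": object()}
preprocesses = {"ViT-L/14": object(), "ViT-H-14": object()}
drop_other_models(clip_models, preprocesses, keep="ViT-H-14", keep_all_models_loaded=False)
print(list(clip_models))  # ['ViT-H-14']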
@@ -316,28 +323,12 @@ def img2txt():
    models = []
    if st.session_state["ViTB32"]:
        models.append('ViT-B/32')
    if st.session_state['ViTB16']:
        models.append('ViT-B/16')
    if st.session_state["ViTL14"]:
    if st.session_state["ViT-L/14"]:
        models.append('ViT-L/14')
    if st.session_state["ViT-H-14"]:
        models.append('ViT-H-14')
    if st.session_state["ViT-g-14"]:
        models.append('ViT-g-14')
    if st.session_state["ViTL14_336px"]:
        models.append('ViT-L/14@336px')
    if st.session_state["RN101"]:
        models.append('RN101')
    if st.session_state["RN50"]:
        models.append('RN50')
    if st.session_state["RN50x4"]:
        models.append('RN50x4')
    if st.session_state["RN50x16"]:
        models.append('RN50x16')
    if st.session_state["RN50x64"]:
        models.append('RN50x64')
    # if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
    #image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
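The per-checkbox if/append chain in this hunk can be read as a key-to-model mapping. A compact equivalent, shown only as an illustrative refactor (not something this commit does), limited to the checkbox keys the new UI keeps:

# Illustrative refactor: build the model list from a mapping of checkbox keys.
CHECKBOX_TO_MODEL = {
    "ViTB32": "ViT-B/32",
    "ViTB16": "ViT-B/16",
    "ViT-L/14": "ViT-L/14",
    "ViT-H-14": "ViT-H-14",
    "ViT-g-14": "ViT-g-14",
}

def selected_models(session_state):
    return [model for key, model in CHECKBOX_TO_MODEL.items() if session_state.get(key)]

print(selected_models({"ViTB32": True, "ViT-L/14": True, "ViT-H-14": False}))
# ['ViT-B/32', 'ViT-L/14']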
@@ -375,22 +366,10 @@ def layout():
    #st.subheader("Input Image")
    st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
    st.subheader("CLIP models")
    with st.expander("Stable Diffusion", expanded=True):
        st.session_state["ViTL14"] = st.checkbox("ViTL14", value=True, help="For StableDiffusion you can just use ViTL14.")
    with st.expander("Others"):
        st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.")
    with st.expander("CLIP models", expanded=True):
        st.session_state["ViT-L/14"] = st.checkbox("ViT-L/14", value=True, help="ViT-L/14 model.")
        st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
        st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
        st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.")
        st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.")
        st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.")
        st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.")
        st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
        st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
        st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
        st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
#
# st.subheader("Logs:")
@@ -448,14 +427,14 @@ def layout():
    if generate_button:
        # if model, pipe, RealESRGAN or GFPGAN is in st.session_state remove the model and pipe from session_state so that they are reloaded.
        if "model" in st.session_state and st.session_state["defaults"].general.optimized:
            del st.session_state["model"]
        if "pipe" in st.session_state and st.session_state["defaults"].general.optimized:
            del st.session_state["pipe"]
        if "RealESRGAN" in st.session_state and st.session_state["defaults"].general.optimized:
            del st.session_state["RealESRGAN"]
        if "GFPGAN" in st.session_state and st.session_state["defaults"].general.optimized:
            del st.session_state["GFPGAN"]
        if "model" in server_state and st.session_state["defaults"].general.optimized:
            del server_state["model"]
        if "pipe" in server_state and st.session_state["defaults"].general.optimized:
            del server_state["pipe"]
        if "RealESRGAN" in server_state and st.session_state["defaults"].general.optimized:
            del server_state["RealESRGAN"]
        if "GFPGAN" in server_state and st.session_state["defaults"].general.optimized:
            del server_state["GFPGAN"]
        # run clip interrogator
        img2txt()
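The last hunk moves the optimized-mode cleanup from st.session_state to server_state, presumably because the loaded models now live in the shared server-wide store rather than per-session state. A minimal sketch of the same cleanup, with a plain dict standing in for server_state (an assumption, for illustration only):

# Sketch of the optimized-mode cleanup applied to a shared state dict.
def free_heavy_models(state, optimized):
    """Drop cached models so they are reloaded fresh when optimized mode is on."""
    if not optimized:
        return
    for key in ("model", "pipe", "RealESRGAN", "GFPGAN"):
        if key in state:
            del state[key]

server_state = {"model": object(), "GFPGAN": object(), "unrelated": 1}
free_heavy_models(server_state, optimized=True)
print(sorted(server_state))  # ['unrelated']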