Merge pull request #1199 from ZeroCool940711/dev

Basic implementation of the Concept Library tab, created by cloning the Home tab.
ZeroCool 2022-09-17 19:10:23 -07:00 committed by GitHub
commit 1f534cd5ac
4 changed files with 44 additions and 85 deletions
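
The new tab is backed by an sd_concept_library module exposing a layout() entry point (both names appear in the webui diff below). A minimal sketch of what the cloned tab module could look like; the body is illustrative, not the code added by this PR:

import streamlit as st

def layout():
    # Entry point called by the webui when the Concept Library tab is selected.
    # The real implementation renders the HTML gallery built alongside the
    # concepts_library() helper; this placeholder only marks the tab.
    st.write("Concept Library")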


@@ -973,52 +973,6 @@ def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, tok
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
return token
def concepts_library():
html_gallery = '''
<div class="flex gr-gap gr-form-gap row gap-4 w-full flex-wrap" id="main_row">
'''
for model in models:
html_gallery = html_gallery+f'''
<div class="gr-block gr-box relative w-full overflow-hidden border-solid border border-gray-200 gr-panel">
<div class="output-markdown gr-prose" style="max-width: 100%;">
<h3>
<a href="https://huggingface.co/{model["id"]}" target="_blank">
<code>{html.escape(model["token"])}</code>
</a>
</h3>
</div>
<div id="gallery" class="gr-block gr-box relative w-full overflow-hidden border-solid border border-gray-200">
<div class="wrap svelte-17ttdjv opacity-0"></div>
<div class="absolute left-0 top-0 py-1 px-2 rounded-br-lg shadow-sm text-xs text-gray-500 flex items-center pointer-events-none bg-white z-20 border-b border-r border-gray-100 dark:bg-gray-900">
<span class="mr-2 h-[12px] w-[12px] opacity-80">
<svg xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect>
<circle cx="8.5" cy="8.5" r="1.5"></circle>
<polyline points="21 15 16 10 5 21"></polyline>
</svg>
</span> {model["concept_type"]}
</div>
<div class="overflow-y-auto h-full p-2" style="position: relative;">
<div class="grid gap-2 grid-cols-2 sm:grid-cols-2 md:grid-cols-2 lg:grid-cols-2 xl:grid-cols-2 2xl:grid-cols-2 svelte-1g9btlg pt-6">
'''
for image in model["images"]:
html_gallery = html_gallery + f'''
<button class="gallery-item svelte-1g9btlg">
<img alt="" loading="lazy" class="h-full w-full overflow-hidden object-contain" src="file/{image}">
</button>
'''
html_gallery = html_gallery+'''
</div>
<iframe style="display: block; position: absolute; top: 0; left: 0; width: 100%; height: 100%; overflow: hidden; border: 0; opacity: 0; pointer-events: none; z-index: -1;" aria-hidden="true" tabindex="-1" src="about:blank"></iframe>
</div>
</div>
</div>
'''
html_gallery = html_gallery+'''
</div>
'''
def image_grid(imgs, batch_size, force_n_rows=None, captions=None):
#print (len(imgs))
if force_n_rows is not None:
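
The concepts_library() helper removed above assembles one HTML card per concept model (title linking to its Hugging Face page, a concept-type badge, and an image grid). A hedged sketch of how such a pre-built string is typically displayed in Streamlit; the st.markdown call is an assumption, since the rendering side is not shown in this hunk:

import streamlit as st

def show_concepts_gallery(html_gallery: str):
    # The gallery string contains raw <div>/<img> markup, so HTML rendering
    # must be explicitly allowed.
    st.markdown(html_gallery, unsafe_allow_html=True)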
@@ -1377,6 +1331,9 @@ def process_images(
for files in os.listdir(embedding_path):
if files.endswith(ext):
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>")
#
os.makedirs(outpath, exist_ok=True)
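
The lines added to process_images() scan embedding_path and register every matching embedding file through load_learned_embed_in_clip. A sketch of the same idea as a standalone helper; every name except load_learned_embed_in_clip (defined earlier in this file) is an assumption:

import os

def load_embeddings_from_dir(embedding_path, text_encoder, tokenizer, token_name, ext=".pt"):
    # Register each learned embedding found in the directory under a
    # bracketed token such as "<token_name>" and collect the tokens used.
    tokens = []
    for file_name in os.listdir(embedding_path):
        if file_name.endswith(ext):
            tokens.append(load_learned_embed_in_clip(
                os.path.join(embedding_path, file_name),
                text_encoder, tokenizer, f"<{token_name}>"))
    return tokens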


@@ -8,46 +8,50 @@ from sd_utils import *
#other imports
#from transformers import CLIPTextModel, CLIPTokenizer
# Temp imports
# Temp imports
# end of imports
#---------------------------------------------------------------------------------------------------------------
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
#def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
#loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
# separate token and the embeds
trained_token = list(loaded_learned_embeds.keys())[0]
embeds = loaded_learned_embeds[trained_token]
## separate token and the embeds
#print (loaded_learned_embeds)
#trained_token = list(loaded_learned_embeds.keys())[0]
#embeds = loaded_learned_embeds[trained_token]
# cast to dtype of text_encoder
dtype = text_encoder.get_input_embeddings().weight.dtype
embeds.to(dtype)
## cast to dtype of text_encoder
#dtype = text_encoder.get_input_embeddings().weight.dtype
#embeds.to(dtype)
# add the token in tokenizer
token = token if token is not None else trained_token
num_added_tokens = tokenizer.add_tokens(token)
i = 1
while(num_added_tokens == 0):
print(f"The tokenizer already contains the token {token}.")
token = f"{token[:-1]}-{i}>"
print(f"Attempting to add the token {token}.")
num_added_tokens = tokenizer.add_tokens(token)
i+=1
## add the token in tokenizer
#token = token if token is not None else trained_token
#num_added_tokens = tokenizer.add_tokens(token)
#i = 1
#while(num_added_tokens == 0):
#print(f"The tokenizer already contains the token {token}.")
#token = f"{token[:-1]}-{i}>"
#print(f"Attempting to add the token {token}.")
#num_added_tokens = tokenizer.add_tokens(token)
#i+=1
# resize the token embeddings
text_encoder.resize_token_embeddings(len(tokenizer))
## resize the token embeddings
#text_encoder.resize_token_embeddings(len(tokenizer))
# get the id for the token and assign the embeds
token_id = tokenizer.convert_tokens_to_ids(token)
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
return token
## get the id for the token and assign the embeds
#token_id = tokenizer.convert_tokens_to_ids(token)
#text_encoder.get_input_embeddings().weight.data[token_id] = embeds
#return token
#def token_loader()
learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", pipe.text_encoder, pipe.tokenizer, "*")
##def token_loader()
#learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", st.session_state.pipe.text_encoder, st.session_state.pipe.tokenizer, "*")
#model_content["token"] = learned_token
#models.append(model_content)
model_id = "./models/custom/embeddings/"
def layout():
st.write("Textual Inversion")
st.write("Textual Inversion")


@@ -4,6 +4,8 @@ from sd_utils import *
# streamlit imports
from streamlit import StopException
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
from streamlit.elements import image as STImage
#other imports
import os
@@ -12,8 +14,6 @@ from io import BytesIO
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
from streamlit.elements import image as STImage
# Temp imports
@@ -185,12 +185,7 @@ def layout():
st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2img.update_preview_frequency,
help="Frequency in steps at which the the preview image is updated. By default the frequency \
is set to 1 step.")
#
#if st.session_state.defaults.general.use_sd_concepts_library:
#with st.expander("Concept Library"):
#st.write("test")
with col2:
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
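
The deleted comments in this hunk sketched a Concept Library expander gated behind a config flag inside the txt2img layout; they are dropped here, presumably because the library now gets its own tab. The gating pattern they hinted at would look roughly like this (the flag path is taken from the removed comments; whether it exists in the defaults config is an assumption):

import streamlit as st

if st.session_state.defaults.general.use_sd_concepts_library:
    with st.expander("Concept Library"):
        st.write("Concept gallery would be rendered here.")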


@@ -100,8 +100,8 @@ def layout():
iconName=['dashboard','model_training' ,'cloud_download', 'settings'], default_choice=0)
if tabs =='Stable Diffusion':
txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified",
"Text-to-Video","Post-Processing"])
txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab, concept_library_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified",
"Text-to-Video","Post-Processing", "Concept Library"])
#with home_tab:
#from home import layout
#layout()
@@ -117,7 +117,10 @@ def layout():
with txt2vid_tab:
from txt2vid import layout
layout()
with concept_library_tab:
from sd_concept_library import layout
layout()
#
elif tabs == 'Model Manager':
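
Condensing the webui changes: the new tab is wired in with the same st.tabs pattern as the existing tabs and mounts the layout() of the sd_concept_library module introduced by this PR. A sketch of the resulting structure:

import streamlit as st

def stable_diffusion_tabs():
    # One tab per workflow; the Concept Library tab is the one added here.
    (txt2img_tab, img2img_tab, txt2vid_tab,
     postprocessing_tab, concept_library_tab) = st.tabs(
        ["Text-to-Image Unified", "Image-to-Image Unified",
         "Text-to-Video", "Post-Processing", "Concept Library"])

    with concept_library_tab:
        from sd_concept_library import layout as concept_library_layout
        concept_library_layout()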