mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-14 06:35:14 +03:00
Basic implementation for the Concept Library tab made by cloning the Home tab.
This commit is contained in:
parent
5111541bcd
commit
997eb12733
@ -973,52 +973,6 @@ def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, tok
|
||||
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
||||
return token
|
||||
|
||||
def concepts_library():
|
||||
|
||||
html_gallery = '''
|
||||
<div class="flex gr-gap gr-form-gap row gap-4 w-full flex-wrap" id="main_row">
|
||||
'''
|
||||
for model in models:
|
||||
html_gallery = html_gallery+f'''
|
||||
<div class="gr-block gr-box relative w-full overflow-hidden border-solid border border-gray-200 gr-panel">
|
||||
<div class="output-markdown gr-prose" style="max-width: 100%;">
|
||||
<h3>
|
||||
<a href="https://huggingface.co/{model["id"]}" target="_blank">
|
||||
<code>{html.escape(model["token"])}</code>
|
||||
</a>
|
||||
</h3>
|
||||
</div>
|
||||
<div id="gallery" class="gr-block gr-box relative w-full overflow-hidden border-solid border border-gray-200">
|
||||
<div class="wrap svelte-17ttdjv opacity-0"></div>
|
||||
<div class="absolute left-0 top-0 py-1 px-2 rounded-br-lg shadow-sm text-xs text-gray-500 flex items-center pointer-events-none bg-white z-20 border-b border-r border-gray-100 dark:bg-gray-900">
|
||||
<span class="mr-2 h-[12px] w-[12px] opacity-80">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"></circle>
|
||||
<polyline points="21 15 16 10 5 21"></polyline>
|
||||
</svg>
|
||||
</span> {model["concept_type"]}
|
||||
</div>
|
||||
<div class="overflow-y-auto h-full p-2" style="position: relative;">
|
||||
<div class="grid gap-2 grid-cols-2 sm:grid-cols-2 md:grid-cols-2 lg:grid-cols-2 xl:grid-cols-2 2xl:grid-cols-2 svelte-1g9btlg pt-6">
|
||||
'''
|
||||
for image in model["images"]:
|
||||
html_gallery = html_gallery + f'''
|
||||
<button class="gallery-item svelte-1g9btlg">
|
||||
<img alt="" loading="lazy" class="h-full w-full overflow-hidden object-contain" src="file/{image}">
|
||||
</button>
|
||||
'''
|
||||
html_gallery = html_gallery+'''
|
||||
</div>
|
||||
<iframe style="display: block; position: absolute; top: 0; left: 0; width: 100%; height: 100%; overflow: hidden; border: 0; opacity: 0; pointer-events: none; z-index: -1;" aria-hidden="true" tabindex="-1" src="about:blank"></iframe>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
'''
|
||||
html_gallery = html_gallery+'''
|
||||
</div>
|
||||
'''
|
||||
|
||||
def image_grid(imgs, batch_size, force_n_rows=None, captions=None):
|
||||
#print (len(imgs))
|
||||
if force_n_rows is not None:
|
||||
@ -1377,6 +1331,9 @@ def process_images(
|
||||
for files in os.listdir(embedding_path):
|
||||
if files.endswith(ext):
|
||||
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>")
|
||||
|
||||
#
|
||||
|
||||
|
||||
os.makedirs(outpath, exist_ok=True)
|
||||
|
||||
|
@ -8,46 +8,50 @@ from sd_utils import *
|
||||
#other imports
|
||||
#from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# Temp imports
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
|
||||
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
|
||||
#def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
|
||||
|
||||
#loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
|
||||
|
||||
# separate token and the embeds
|
||||
trained_token = list(loaded_learned_embeds.keys())[0]
|
||||
embeds = loaded_learned_embeds[trained_token]
|
||||
## separate token and the embeds
|
||||
#print (loaded_learned_embeds)
|
||||
#trained_token = list(loaded_learned_embeds.keys())[0]
|
||||
#embeds = loaded_learned_embeds[trained_token]
|
||||
|
||||
# cast to dtype of text_encoder
|
||||
dtype = text_encoder.get_input_embeddings().weight.dtype
|
||||
embeds.to(dtype)
|
||||
## cast to dtype of text_encoder
|
||||
#dtype = text_encoder.get_input_embeddings().weight.dtype
|
||||
#embeds.to(dtype)
|
||||
|
||||
# add the token in tokenizer
|
||||
token = token if token is not None else trained_token
|
||||
num_added_tokens = tokenizer.add_tokens(token)
|
||||
i = 1
|
||||
while(num_added_tokens == 0):
|
||||
print(f"The tokenizer already contains the token {token}.")
|
||||
token = f"{token[:-1]}-{i}>"
|
||||
print(f"Attempting to add the token {token}.")
|
||||
num_added_tokens = tokenizer.add_tokens(token)
|
||||
i+=1
|
||||
## add the token in tokenizer
|
||||
#token = token if token is not None else trained_token
|
||||
#num_added_tokens = tokenizer.add_tokens(token)
|
||||
#i = 1
|
||||
#while(num_added_tokens == 0):
|
||||
#print(f"The tokenizer already contains the token {token}.")
|
||||
#token = f"{token[:-1]}-{i}>"
|
||||
#print(f"Attempting to add the token {token}.")
|
||||
#num_added_tokens = tokenizer.add_tokens(token)
|
||||
#i+=1
|
||||
|
||||
# resize the token embeddings
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
## resize the token embeddings
|
||||
#text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# get the id for the token and assign the embeds
|
||||
token_id = tokenizer.convert_tokens_to_ids(token)
|
||||
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
||||
return token
|
||||
## get the id for the token and assign the embeds
|
||||
#token_id = tokenizer.convert_tokens_to_ids(token)
|
||||
#text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
||||
#return token
|
||||
|
||||
#def token_loader()
|
||||
learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", pipe.text_encoder, pipe.tokenizer, "*")
|
||||
##def token_loader()
|
||||
#learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", st.session_state.pipe.text_encoder, st.session_state.pipe.tokenizer, "*")
|
||||
#model_content["token"] = learned_token
|
||||
#models.append(model_content)
|
||||
|
||||
model_id = "./models/custom/embeddings/"
|
||||
|
||||
def layout():
|
||||
st.write("Textual Inversion")
|
||||
st.write("Textual Inversion")
|
@ -4,6 +4,8 @@ from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
from streamlit import StopException
|
||||
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
|
||||
from streamlit.elements import image as STImage
|
||||
|
||||
#other imports
|
||||
import os
|
||||
@ -12,8 +14,6 @@ from io import BytesIO
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
|
||||
from streamlit.elements import image as STImage
|
||||
# Temp imports
|
||||
|
||||
|
||||
@ -185,12 +185,7 @@ def layout():
|
||||
st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2img.update_preview_frequency,
|
||||
help="Frequency in steps at which the the preview image is updated. By default the frequency \
|
||||
is set to 1 step.")
|
||||
#
|
||||
#if st.session_state.defaults.general.use_sd_concepts_library:
|
||||
#with st.expander("Concept Library"):
|
||||
#st.write("test")
|
||||
|
||||
|
||||
|
||||
with col2:
|
||||
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
|
||||
|
||||
|
@ -100,8 +100,8 @@ def layout():
|
||||
iconName=['dashboard','model_training' ,'cloud_download', 'settings'], default_choice=0)
|
||||
|
||||
if tabs =='Stable Diffusion':
|
||||
txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified",
|
||||
"Text-to-Video","Post-Processing"])
|
||||
txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab, concept_library_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified",
|
||||
"Text-to-Video","Post-Processing", "Concept Library"])
|
||||
#with home_tab:
|
||||
#from home import layout
|
||||
#layout()
|
||||
@ -117,7 +117,10 @@ def layout():
|
||||
with txt2vid_tab:
|
||||
from txt2vid import layout
|
||||
layout()
|
||||
|
||||
|
||||
with concept_library_tab:
|
||||
from sd_concept_library import layout
|
||||
layout()
|
||||
|
||||
#
|
||||
elif tabs == 'Model Manager':
|
||||
|
Loading…
Reference in New Issue
Block a user