diff --git a/scripts/sd_utils.py b/scripts/sd_utils.py index f04d0b5..e2e8989 100644 --- a/scripts/sd_utils.py +++ b/scripts/sd_utils.py @@ -973,52 +973,6 @@ def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, tok text_encoder.get_input_embeddings().weight.data[token_id] = embeds return token -def concepts_library(): - - html_gallery = ''' -
- ''' - for model in models: - html_gallery = html_gallery+f''' -
- - -
- ''' - html_gallery = html_gallery+''' -
- ''' - def image_grid(imgs, batch_size, force_n_rows=None, captions=None): #print (len(imgs)) if force_n_rows is not None: @@ -1377,6 +1331,9 @@ def process_images( for files in os.listdir(embedding_path): if files.endswith(ext): load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>") + + # + os.makedirs(outpath, exist_ok=True) diff --git a/scripts/textual_inversion.py b/scripts/textual_inversion.py index d6917a4..3e5cc3e 100644 --- a/scripts/textual_inversion.py +++ b/scripts/textual_inversion.py @@ -8,46 +8,50 @@ from sd_utils import * #other imports #from transformers import CLIPTextModel, CLIPTokenizer -# Temp imports +# Temp imports # end of imports #--------------------------------------------------------------------------------------------------------------- -def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None): - loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") +#def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None): + + #loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") - # separate token and the embeds - trained_token = list(loaded_learned_embeds.keys())[0] - embeds = loaded_learned_embeds[trained_token] + ## separate token and the embeds + #print (loaded_learned_embeds) + #trained_token = list(loaded_learned_embeds.keys())[0] + #embeds = loaded_learned_embeds[trained_token] - # cast to dtype of text_encoder - dtype = text_encoder.get_input_embeddings().weight.dtype - embeds.to(dtype) + ## cast to dtype of text_encoder + #dtype = text_encoder.get_input_embeddings().weight.dtype + #embeds.to(dtype) - # add the token in tokenizer - token = token if token is not None else trained_token - num_added_tokens = tokenizer.add_tokens(token) - i = 1 - while(num_added_tokens == 0): - print(f"The tokenizer already contains the token {token}.") - token = f"{token[:-1]}-{i}>" - print(f"Attempting to add the token {token}.") - num_added_tokens = tokenizer.add_tokens(token) - i+=1 + ## add the token in tokenizer + #token = token if token is not None else trained_token + #num_added_tokens = tokenizer.add_tokens(token) + #i = 1 + #while(num_added_tokens == 0): + #print(f"The tokenizer already contains the token {token}.") + #token = f"{token[:-1]}-{i}>" + #print(f"Attempting to add the token {token}.") + #num_added_tokens = tokenizer.add_tokens(token) + #i+=1 - # resize the token embeddings - text_encoder.resize_token_embeddings(len(tokenizer)) + ## resize the token embeddings + #text_encoder.resize_token_embeddings(len(tokenizer)) - # get the id for the token and assign the embeds - token_id = tokenizer.convert_tokens_to_ids(token) - text_encoder.get_input_embeddings().weight.data[token_id] = embeds - return token + ## get the id for the token and assign the embeds + #token_id = tokenizer.convert_tokens_to_ids(token) + #text_encoder.get_input_embeddings().weight.data[token_id] = embeds + #return token -#def token_loader() -learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", pipe.text_encoder, pipe.tokenizer, "*") +##def token_loader() +#learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", st.session_state.pipe.text_encoder, st.session_state.pipe.tokenizer, "*") #model_content["token"] = learned_token #models.append(model_content) +model_id = "./models/custom/embeddings/" + def layout(): - st.write("Textual Inversion") \ No newline at end of file + st.write("Textual Inversion") \ No newline at end of file diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 1586b88..63360d4 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -4,6 +4,8 @@ from sd_utils import * # streamlit imports from streamlit import StopException +from streamlit.runtime.in_memory_file_manager import in_memory_file_manager +from streamlit.elements import image as STImage #other imports import os @@ -12,8 +14,6 @@ from io import BytesIO from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler -from streamlit.runtime.in_memory_file_manager import in_memory_file_manager -from streamlit.elements import image as STImage # Temp imports @@ -185,12 +185,7 @@ def layout(): st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2img.update_preview_frequency, help="Frequency in steps at which the the preview image is updated. By default the frequency \ is set to 1 step.") - # - #if st.session_state.defaults.general.use_sd_concepts_library: - #with st.expander("Concept Library"): - #st.write("test") - - + with col2: preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) diff --git a/scripts/webui_streamlit.py b/scripts/webui_streamlit.py index d034003..669237c 100644 --- a/scripts/webui_streamlit.py +++ b/scripts/webui_streamlit.py @@ -100,8 +100,8 @@ def layout(): iconName=['dashboard','model_training' ,'cloud_download', 'settings'], default_choice=0) if tabs =='Stable Diffusion': - txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified", - "Text-to-Video","Post-Processing"]) + txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab, concept_library_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified", + "Text-to-Video","Post-Processing", "Concept Library"]) #with home_tab: #from home import layout #layout() @@ -117,7 +117,10 @@ def layout(): with txt2vid_tab: from txt2vid import layout layout() - + + with concept_library_tab: + from sd_concept_library import layout + layout() # elif tabs == 'Model Manager':