sygil-webui/scripts/textual_inversion.py
ZeroCool940711 ef2da42489 - Bumped the version of diffusers used on the txt2vid tab to be now v0.3.0.
- Added initial file for the textual inversion tab.
2022-09-16 11:50:22 -07:00

53 lines
1.7 KiB
Python

# base webui import and utils.
from webui_streamlit import st
from sd_utils import *
# streamlit imports
#other imports
#from transformers import CLIPTextModel, CLIPTokenizer
# Temp imports
# end of imports
#---------------------------------------------------------------------------------------------------------------
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    """Register a learned textual-inversion embedding with a tokenizer/text encoder.

    Args:
        learned_embeds_path: Path to a file loadable by ``torch.load`` whose top
            level is a mapping ``{trained_token: embedding_tensor}``. Only the
            first entry is used.
        text_encoder: Encoder exposing ``get_input_embeddings()`` and
            ``resize_token_embeddings(n)`` (presumably a transformers CLIP text
            model -- confirm against callers).
        tokenizer: Tokenizer exposing ``add_tokens``, ``convert_tokens_to_ids``
            and supporting ``len()``.
        token: Token string to register. Defaults to the key stored in the file.

    Returns:
        The token string actually added. If the requested token already exists,
        a suffixed variant is used instead: ``<name-1>``, ``<name-2>``, ... for
        bracketed placeholders, or ``name-1``, ``name-2``, ... otherwise.

    Raises:
        ValueError: If the file contains no embeddings.
    """
    import torch

    # NOTE(review): torch.load unpickles the file; only load embedding files
    # from trusted sources.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # Separate the trained token from its embedding (first entry only).
    try:
        trained_token, embeds = next(iter(loaded_learned_embeds.items()))
    except StopIteration:
        raise ValueError(f"No embeddings found in {learned_embeds_path}") from None

    # Tensor.to() returns a new tensor; keep the result so the cast takes effect.
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # Add the token to the tokenizer. On collision, derive every candidate from
    # the *base* token so suffixes don't pile up (<tok-1-2-3>), and don't strip
    # the last character of tokens that are not <bracketed>.
    base_token = token if token is not None else trained_token
    token = base_token
    num_added_tokens = tokenizer.add_tokens(token)
    i = 1
    while num_added_tokens == 0:
        print(f"The tokenizer already contains the token {token}.")
        if base_token.endswith(">"):
            token = f"{base_token[:-1]}-{i}>"
        else:
            token = f"{base_token}-{i}"
        print(f"Attempting to add the token {token}.")
        num_added_tokens = tokenizer.add_tokens(token)
        i += 1

    # Grow the embedding matrix so it has a row for the new vocabulary entry.
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Write the learned vector into the new token's row.
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token
# NOTE(review): user-specific default path carried over from the prototype
# call; consider making it configurable.
DEFAULT_EMBEDDING_PATH = "models/custom/embeddings/Custom Ami.pt"


def token_loader(pipe, embeddings_path=DEFAULT_EMBEDDING_PATH, token="*"):
    """Load a textual-inversion embedding into ``pipe``'s tokenizer and text encoder.

    Args:
        pipe: Object exposing ``text_encoder`` and ``tokenizer`` attributes
            (presumably a diffusers pipeline -- confirm against callers).
        embeddings_path: Embedding file passed to ``load_learned_embed_in_clip``.
        token: Token string to register the embedding under.

    Returns:
        The token string actually registered (may be a suffixed variant of
        ``token`` if it already existed in the tokenizer).
    """
    return load_learned_embed_in_clip(embeddings_path, pipe.text_encoder, pipe.tokenizer, token)


# Importing this tab must not crash: an unconditional load raises NameError
# when no module-level ``pipe`` is in scope and FileNotFoundError when the
# embedding file is absent. Only attempt the load when a ``pipe`` global
# exists, and tolerate a missing file.
learned_token = None
_pipe = globals().get("pipe")
if _pipe is not None:
    try:
        learned_token = token_loader(_pipe)
    except FileNotFoundError:
        print(f"Textual inversion embedding not found: {DEFAULT_EMBEDDING_PATH}")
#model_content["token"] = learned_token
#models.append(model_content)
def layout():
    """Draw the Textual Inversion tab; for now it only writes the tab heading."""
    heading = "Textual Inversion"
    st.write(heading)