mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-14 22:53:25 +03:00
fix for incorrect embedding token length calculation (will break seeds that use embeddings, you're welcome!)
add option to input initialization text for embeddings
This commit is contained in:
parent
53a3dc601f
commit
88ec0cf557
@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
if embedding is None:
|
||||
remade_tokens.append(token)
|
||||
@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [weight] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += emb_len
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||
if mult_change is not None:
|
||||
@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [mult] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += emb_len
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
|
@ -117,24 +117,21 @@ class EmbeddingDatabase:
|
||||
possible_matches = self.ids_lookup.get(token, None)
|
||||
|
||||
if possible_matches is None:
|
||||
return None
|
||||
return None, None
|
||||
|
||||
for ids, embedding in possible_matches:
|
||||
if tokens[offset:offset + len(ids)] == ids:
|
||||
return embedding
|
||||
return embedding, len(ids)
|
||||
|
||||
return None
|
||||
return None, None
|
||||
|
||||
|
||||
|
||||
def create_embedding(name, num_vectors_per_token):
|
||||
init_text = '*'
|
||||
|
||||
def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||
|
||||
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||
|
||||
for i in range(num_vectors_per_token):
|
||||
|
@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti
|
||||
from modules import sd_hijack, shared
|
||||
|
||||
|
||||
def create_embedding(name, nvpt):
|
||||
filename = ti.create_embedding(name, nvpt)
|
||||
def create_embedding(name, initialization_text, nvpt):
|
||||
filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
|
@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
|
||||
|
||||
new_embedding_name = gr.Textbox(label="Name")
|
||||
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||
|
||||
with gr.Row():
|
||||
@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
fn=modules.textual_inversion.ui.create_embedding,
|
||||
inputs=[
|
||||
new_embedding_name,
|
||||
initialization_text,
|
||||
nvpt,
|
||||
],
|
||||
outputs=[
|
||||
|
Loading…
Reference in New Issue
Block a user