diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f12a9696..4a2d2153 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -282,14 +282,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - tmp = -opts.CLIP_ignore_last_layers - if (opts.CLIP_ignore_last_layers == 0): - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) - z = outputs.last_hidden_state - else: - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) - z = outputs.hidden_states[tmp] - z = self.wrapped.transformer.text_model.final_layer_norm(z) + tmp = -opts.CLIP_stop_at_last_layers + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]