Update modules.py

This commit is contained in:
hlky 2022-10-02 00:04:08 +01:00
parent 47e340dc2c
commit d017fe7af6
No known key found for this signature in database
GPG Key ID: 55A99F1E80D907D5

View File

@ -5,7 +5,7 @@ import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
import os
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
@ -138,8 +138,12 @@ class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
if os.path.exists("models/clip-vit-large-patch14"):
self.tokenizer = CLIPTokenizer.from_pretrained("models/clip-vit-large-patch14")
self.transformer = CLIPTextModel.from_pretrained("models/clip-vit-large-patch14")
else:
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
self.freeze()