from typing import List, Optional

from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)


class BasePromptEncoder(ModelMixin, ConfigMixin):
    """Abstract interface for encoders that map text or image prompts to embeddings."""

    def __init__(self):
        super().__init__()

    def encode_text(self, text):
        raise NotImplementedError

    def encode_image(self, image):
        raise NotImplementedError

    def forward(
        self,
        prompt,
        negative_prompt=None,
    ):
        raise NotImplementedError


class MaterialPromptEncoder(BasePromptEncoder):
    """Encodes material prompts (strings or PIL images) with CLIP text/vision encoders."""

    def __init__(self):
        super().__init__()
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_vision = CLIPVisionModelWithProjection.from_pretrained(
            "openai/clip-vit-large-patch14"
        )
        self.clip_text = CLIPTextModelWithProjection.from_pretrained(
            "openai/clip-vit-large-patch14"
        )

    def encode_text(self, text):
        # Tokenize the prompt(s) and project them with the CLIP text encoder.
        inputs = self.tokenizer(text, padding=True, return_tensors="pt")
        inputs["input_ids"] = inputs["input_ids"].to(self.device)
        inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
        outputs = self.clip_text(**inputs)
        # Returns embeddings of shape (batch, 1, embed_dim).
        return outputs.text_embeds.unsqueeze(1)

    def encode_image(self, image):
        # Preprocess the image(s) and project them with the CLIP vision encoder.
        inputs = self.processor(images=image, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
        outputs = self.clip_vision(**inputs)
        # Returns embeddings of shape (batch, 1, embed_dim).
        return outputs.image_embeds.unsqueeze(1)

    def encode_prompt(
        self,
        prompt,
    ):
        # Dispatch on the prompt type; for lists, inspect the first element.
        sample = prompt[0] if isinstance(prompt, list) else prompt
        if isinstance(sample, str):
            return self.encode_text(prompt)
        elif isinstance(sample, Image.Image):
            # isinstance also accepts the PIL subclasses returned by Image.open.
            return self.encode_image(prompt)
        else:
            raise NotImplementedError(f"Unsupported prompt type: {type(sample)}")

    def forward(
        self,
        prompt,
        negative_prompt=None,
    ):
        prompt = self.encode_prompt(prompt)
        # Encode the negative prompt only when one is provided.
        if negative_prompt is not None:
            negative_prompt = self.encode_prompt(negative_prompt)
        return prompt, negative_prompt
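

# Usage sketch (illustrative only, not part of the original module). It assumes the
# encoder runs on CPU with the default dtype; the example prompts are made up.
if __name__ == "__main__":
    import torch

    encoder = MaterialPromptEncoder()
    with torch.no_grad():
        # Both outputs have shape (batch, 1, embed_dim).
        cond, uncond = encoder(["brushed metal"], negative_prompt=["blurry, low quality"])
    print(cond.shape, uncond.shape)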