from typing import List, Optional

from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)


class BasePromptEncoder(ModelMixin, ConfigMixin):
    def __init__(self):
        super().__init__()

    def encode_text(self, text):
        raise NotImplementedError

    def encode_image(self, image):
        raise NotImplementedError

    def forward(
        self,
        prompt,
        negative_prompt=None,
    ):
        raise NotImplementedError


class MaterialPromptEncoder(BasePromptEncoder):
    def __init__(self):
        super().__init__()

        # The text and vision towers come from the same CLIP ViT-L/14
        # checkpoint, so text and image embeddings share one projection space.
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_vision = CLIPVisionModelWithProjection.from_pretrained(
            "openai/clip-vit-large-patch14"
        )
        self.clip_text = CLIPTextModelWithProjection.from_pretrained(
            "openai/clip-vit-large-patch14"
        )

    def encode_text(self, text):
        # Tokenize with padding (so batched prompts align) and truncation to
        # CLIP's 77-token context (longer prompts would otherwise error),
        # then move the tensors onto the model's device.
        inputs = self.tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        )
        inputs["input_ids"] = inputs["input_ids"].to(self.device)
        inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
        outputs = self.clip_text(**inputs)
        # (batch, projection_dim) -> (batch, 1, projection_dim)
        return outputs.text_embeds.unsqueeze(1)

    def encode_image(self, image):
        # Resize/normalize to CLIP pixel values and move to the model device.
        inputs = self.processor(images=image, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
        outputs = self.clip_vision(**inputs)
        # (batch, projection_dim) -> (batch, 1, projection_dim)
        return outputs.image_embeds.unsqueeze(1)

    def encode_prompt(
        self,
        prompt,
    ):
        # Dispatch on the element type. isinstance is used instead of an
        # exact type() comparison so PIL subclasses (e.g. the PngImageFile
        # returned by Image.open) are recognized as images.
        sample = prompt[0] if isinstance(prompt, list) else prompt

        if isinstance(sample, str):
            return self.encode_text(prompt)
        elif isinstance(sample, Image.Image):
            return self.encode_image(prompt)
        else:
            raise NotImplementedError(f"Unsupported prompt type: {type(sample)}")

    def forward(
        self,
        prompt,
        negative_prompt=None,
    ):
        prompt = self.encode_prompt(prompt)
        # Only encode the negative prompt when one was given; passing None
        # through encode_prompt would raise.
        if negative_prompt is not None:
            negative_prompt = self.encode_prompt(negative_prompt)
        return prompt, negative_prompt
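

# Minimal usage sketch (illustrative, not part of the module's API): it
# assumes the CLIP checkpoint can be fetched from the Hugging Face Hub and
# that a sample image "material.png" exists; both are assumptions here.
if __name__ == "__main__":
    encoder = MaterialPromptEncoder()

    # A text prompt yields a (1, 1, 768) CLIP text-projection embedding.
    text_embeds, _ = encoder("a brushed aluminum surface")
    print(text_embeds.shape)

    # An image prompt goes through the vision tower instead; text and image
    # prompts can be mixed across the positive/negative slots.
    image_embeds, negative_embeds = encoder(
        Image.open("material.png"),  # hypothetical sample image
        negative_prompt="blurry, low quality",
    )
    print(image_embeds.shape, negative_embeds.shape)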