from typing import List, Union, get_args

import PIL
import PIL.Jpeg2KImagePlugin
import PIL.JpegImagePlugin
import PIL.PngImagePlugin
import PIL.TiffImagePlugin
import torch
from diffusers.configuration_utils import ConfigMixin
from diffusers.image_processor import PipelineImageInput
from diffusers.models.modeling_utils import ModelMixin
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)
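
# Prompts may be plain strings or already-decoded PIL images; ImageInput lists
# the PIL classes that MaterialPromptEncoder.encode_prompt dispatches on.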
StrInput = Union[str, List[str]]

ImageInput = Union[
    PIL.JpegImagePlugin.JpegImageFile,
    PIL.Jpeg2KImagePlugin.Jpeg2KImageFile,
    PIL.PngImagePlugin.PngImageFile,
    PIL.TiffImagePlugin.TiffImageFile,
]


class BasePromptEncoder(ModelMixin, ConfigMixin):
    """Interface for prompt encoders; subclasses implement text and image encoding."""

    def __init__(self):
        super().__init__()

    def encode_text(self, text):
        raise NotImplementedError

    def encode_image(self, image):
        raise NotImplementedError

    def forward(
        self,
        prompt,
        negative_prompt=None,
    ):
        raise NotImplementedError


class MaterialPromptEncoder(BasePromptEncoder):
    """Encodes text and image prompts with CLIP (openai/clip-vit-large-patch14)."""

    def __init__(self):
        super().__init__()

        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_vision = CLIPVisionModelWithProjection.from_pretrained(
            "openai/clip-vit-large-patch14"
        )
        self.clip_text = CLIPTextModelWithProjection.from_pretrained(
            "openai/clip-vit-large-patch14"
        )

    def encode_text(self, text):
        # Tokenize, move the inputs to the module's device, and return the
        # projected CLIP text embedding with shape (batch, 1, embed_dim).
        inputs = self.tokenizer(text, padding=True, return_tensors="pt")
        inputs["input_ids"] = inputs["input_ids"].to(self.device)
        inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
        outputs = self.clip_text(**inputs)
        return outputs.text_embeds.unsqueeze(1)

    def encode_image(self, image):
        # Preprocess the PIL image and return the projected CLIP image embedding,
        # shaped (batch, 1, embed_dim) to match encode_text.
        inputs = self.processor(images=image, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
        outputs = self.clip_vision(**inputs)
        return outputs.image_embeds.unsqueeze(1)

    def encode_prompt(
        self,
        prompt,
    ):
        # Accept a single prompt or a list of prompts; encode each item with the
        # matching CLIP tower and concatenate along the batch dimension.
        if not isinstance(prompt, list):
            prompt = [prompt]

        embs = []
        for p in prompt:
            if isinstance(p, str):
                embs.append(self.encode_text(p))
            elif isinstance(p, get_args(ImageInput)):
                embs.append(self.encode_image(p))
            else:
                raise NotImplementedError

        return torch.cat(embs, dim=0)

    def forward(
        self,
        prompt,
        negative_prompt=None,
    ):
        # Note: negative_prompt defaults to None, but encode_prompt has no None
        # handling, so callers are expected to pass a negative prompt explicitly.
        prompt = self.encode_prompt(prompt)
        negative_prompt = self.encode_prompt(negative_prompt)
        return prompt, negative_prompt
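

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the prompt strings below are
    # placeholders, and running this downloads the CLIP checkpoint.
    encoder = MaterialPromptEncoder()
    positive, negative = encoder(
        prompt=["weathered bronze", "brushed steel"],
        negative_prompt=["blurry", "low quality"],
    )
    # PIL images (JPEG/JPEG 2000/PNG/TIFF) can be mixed into the same prompt list.
    # For CLIP ViT-L/14 each output should have shape (2, 1, 768).
    print(positive.shape, negative.shape)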