gvecchio's picture
Update prompt_encoder/encoder.py
7e6c409 verified
raw
history blame
2.95 kB
from typing import List, Union, get_args
import PIL
import PIL.Jpeg2KImagePlugin
import PIL.JpegImagePlugin
import PIL.PngImagePlugin
import PIL.TiffImagePlugin
import torch
from diffusers.configuration_utils import ConfigMixin
from diffusers.image_processor import PipelineImageInput
from diffusers.models.modeling_utils import ModelMixin
from PIL import Image
from transformers import (
AutoProcessor,
AutoTokenizer,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
)
# Text prompt alias: either a single string or a list of strings.
StrInput = Union[str, List[str]]
# Image prompt alias: union of the PIL file-format image classes for
# JPEG, JPEG 2000, PNG and TIFF.
ImageInput = Union[
PIL.JpegImagePlugin.JpegImageFile,
PIL.Jpeg2KImagePlugin.Jpeg2KImageFile,
PIL.PngImagePlugin.PngImageFile,
PIL.TiffImagePlugin.TiffImageFile,
]
class BasePromptEncoder(ModelMixin, ConfigMixin):
    """Abstract interface for prompt encoders.

    Every method below raises ``NotImplementedError``; concrete encoders are
    expected to override ``encode_text``, ``encode_image`` and ``forward``.
    """

    def __init__(self):
        super(BasePromptEncoder, self).__init__()

    def encode_text(self, text):
        """Encode a text prompt. Must be overridden by subclasses."""
        raise NotImplementedError()

    def encode_image(self, image):
        """Encode an image prompt. Must be overridden by subclasses."""
        raise NotImplementedError()

    def forward(self, prompt, negative_prompt=None):
        """Encode ``prompt`` and ``negative_prompt``. Must be overridden."""
        raise NotImplementedError()
class MaterialPromptEncoder(BasePromptEncoder):
    """Prompt encoder backed by the ``openai/clip-vit-large-patch14`` checkpoint.

    Text prompts go through the CLIP tokenizer and projected text model; image
    prompts go through the CLIP processor and projected vision model. Each
    prompt is encoded on its own and the per-prompt embeddings are concatenated
    along dimension 0.
    """

    def __init__(self):
        super().__init__()
        # All four components come from the same checkpoint; name it once so
        # they cannot drift apart.
        checkpoint = "openai/clip-vit-large-patch14"
        self.processor = AutoProcessor.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.clip_vision = CLIPVisionModelWithProjection.from_pretrained(checkpoint)
        self.clip_text = CLIPTextModelWithProjection.from_pretrained(checkpoint)

    def encode_text(self, text):
        """Embed ``text`` with the CLIP text model.

        Args:
            text: A string (or list of strings) accepted by the tokenizer.

        Returns:
            ``text_embeds`` from the text model with a singleton axis inserted
            at position 1.
        """
        # truncation=True clips over-long prompts to the tokenizer's maximum
        # length instead of letting the text model fail on them.
        # NOTE(review): assumes the checkpoint's tokenizer config defines
        # model_max_length (77 for CLIP) - confirm.
        inputs = self.tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        )
        inputs["input_ids"] = inputs["input_ids"].to(self.device)
        inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
        outputs = self.clip_text(**inputs)
        return outputs.text_embeds.unsqueeze(1)

    def encode_image(self, image):
        """Embed ``image`` with the CLIP vision model.

        Args:
            image: An image (or batch of images) accepted by the CLIP processor.

        Returns:
            ``image_embeds`` from the vision model with a singleton axis
            inserted at position 1.
        """
        inputs = self.processor(images=image, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
        outputs = self.clip_vision(**inputs)
        return outputs.image_embeds.unsqueeze(1)

    def encode_prompt(
        self,
        prompt,
    ):
        """Encode one prompt or a sequence of prompts.

        Args:
            prompt: A string, a PIL image, or a list/tuple mixing both.

        Returns:
            The per-prompt embeddings concatenated along dimension 0.

        Raises:
            NotImplementedError: If an item is neither a ``str`` nor a
                ``PIL.Image.Image``.
        """
        items = list(prompt) if isinstance(prompt, (list, tuple)) else [prompt]

        embs = []
        for item in items:
            if isinstance(item, str):
                embs.append(self.encode_text(item))
            # Check against the PIL base class so in-memory images (e.g. the
            # result of Image.new / convert / resize) are accepted, not only
            # the file-format subclasses listed in ImageInput.
            elif isinstance(item, Image.Image):
                embs.append(self.encode_image(item))
            else:
                raise NotImplementedError(
                    f"Unsupported prompt type: {type(item).__name__}"
                )
        return torch.cat(embs, dim=0)

    def forward(
        self,
        prompt,
        negative_prompt=None,
    ):
        """Encode the prompt and, when given, the negative prompt.

        Args:
            prompt: Prompt(s) accepted by ``encode_prompt``.
            negative_prompt: Optional prompt(s) accepted by ``encode_prompt``.

        Returns:
            A ``(prompt_embeds, negative_prompt_embeds)`` tuple. The second
            element is ``None`` when ``negative_prompt`` is ``None``.
        """
        prompt_embeds = self.encode_prompt(prompt)
        # The default None used to be forwarded to encode_prompt, which
        # rejected it; an omitted negative prompt now yields None.
        if negative_prompt is None:
            return prompt_embeds, None
        return prompt_embeds, self.encode_prompt(negative_prompt)