from typing import List, Optional, Tuple, Union, get_args

import PIL
import PIL.Jpeg2KImagePlugin
import PIL.JpegImagePlugin
import PIL.PngImagePlugin
import PIL.TiffImagePlugin
import torch
from diffusers.configuration_utils import ConfigMixin
from diffusers.image_processor import PipelineImageInput
from diffusers.models.modeling_utils import ModelMixin
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

StrInput = Union[str, List[str]]
# File-backed image classes returned by ``PIL.Image.open`` for these formats.
# Kept for backward compatibility; MaterialPromptEncoder accepts any
# ``PIL.Image.Image`` (all of these classes subclass it).
ImageInput = Union[
    PIL.JpegImagePlugin.JpegImageFile,
    PIL.Jpeg2KImagePlugin.Jpeg2KImageFile,
    PIL.PngImagePlugin.PngImageFile,
    PIL.TiffImagePlugin.TiffImageFile,
]

# Default Hugging Face checkpoint used for the processor, tokenizer and both
# CLIP towers.
_CLIP_MODEL_ID = "openai/clip-vit-large-patch14"


class BasePromptEncoder(ModelMixin, ConfigMixin):
    """Interface for modules that turn text/image prompts into embeddings.

    Every method here raises ``NotImplementedError``; subclasses provide
    ``encode_text``, ``encode_image`` and ``forward``.
    """

    def __init__(self):
        super().__init__()

    def encode_text(self, text):
        """Encode text into embeddings (subclasses must override)."""
        raise NotImplementedError

    def encode_image(self, image):
        """Encode an image into embeddings (subclasses must override)."""
        raise NotImplementedError

    def forward(
        self,
        prompt,
        negative_prompt=None,
    ):
        """Encode a prompt and optional negative prompt (subclasses must override)."""
        raise NotImplementedError


class MaterialPromptEncoder(BasePromptEncoder):
    """Encode text and/or image prompts with a CLIP text and vision tower.

    Text is encoded with ``CLIPTextModelWithProjection`` and images with
    ``CLIPVisionModelWithProjection``. Both return projected embeddings
    (``text_embeds`` / ``image_embeds``). Each result gets a singleton axis
    at dim 1 and is concatenated along dim 0 by :meth:`encode_prompt`.

    Args:
        model_name_or_path: Hub id or local path of the CLIP checkpoint to
            load. Defaults to ``"openai/clip-vit-large-patch14"``.
    """

    def __init__(self, model_name_or_path: str = _CLIP_MODEL_ID):
        super().__init__()
        self.processor = AutoProcessor.from_pretrained(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.clip_vision = CLIPVisionModelWithProjection.from_pretrained(
            model_name_or_path
        )
        self.clip_text = CLIPTextModelWithProjection.from_pretrained(
            model_name_or_path
        )

    def encode_text(self, text) -> torch.Tensor:
        """Encode a string (or list of strings) with the CLIP text tower.

        Args:
            text: A string or list of strings.

        Returns:
            ``text_embeds`` with a singleton axis inserted at dim 1, i.e.
            ``(batch, 1, projection_dim)``.
        """
        # truncation=True: inputs longer than the tokenizer's max length
        # (77 tokens for the default CLIP checkpoint) would otherwise overflow
        # the text model's position embeddings.
        inputs = self.tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        outputs = self.clip_text(**inputs)
        return outputs.text_embeds.unsqueeze(1)

    def encode_image(self, image) -> torch.Tensor:
        """Encode an image (or list of images) with the CLIP vision tower.

        Args:
            image: Anything the CLIP processor accepts as ``images``.

        Returns:
            ``image_embeds`` with a singleton axis inserted at dim 1, i.e.
            ``(batch, 1, projection_dim)``.
        """
        inputs = self.processor(images=image, return_tensors="pt")
        # The processor emits float32; match the module's dtype so fp16/bf16
        # weights do not fail on a dtype mismatch.
        pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.dtype)
        outputs = self.clip_vision(pixel_values=pixel_values)
        return outputs.image_embeds.unsqueeze(1)

    def encode_prompt(
        self,
        prompt,
    ) -> torch.Tensor:
        """Encode a prompt made of strings and/or PIL images into one tensor.

        Args:
            prompt: A single ``str`` or ``PIL.Image.Image``, or a list/tuple
                mixing both. Each entry is encoded on its own and the results
                are concatenated along dim 0 in input order.

        Returns:
            Tensor of shape ``(len(entries), 1, projection_dim)``.

        Raises:
            ValueError: If ``prompt`` is an empty list/tuple.
            NotImplementedError: If an entry is neither a ``str`` nor a PIL image.
        """
        items = prompt if isinstance(prompt, (list, tuple)) else [prompt]
        if not items:
            raise ValueError("prompt must contain at least one text or image entry")

        embeddings = []
        for item in items:
            if isinstance(item, str):
                embeddings.append(self.encode_text(item))
            # Any PIL image is accepted, not only the file-backed classes in
            # ImageInput: converted/resized images are plain Image.Image.
            elif isinstance(item, Image.Image):
                embeddings.append(self.encode_image(item))
            else:
                raise NotImplementedError(
                    f"Unsupported prompt entry of type {type(item).__name__}; "
                    "expected str or PIL.Image.Image"
                )
        return torch.cat(embeddings, dim=0)

    def forward(
        self,
        prompt,
        negative_prompt=None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode ``prompt`` and, when given, ``negative_prompt``.

        Args:
            prompt: See :meth:`encode_prompt`.
            negative_prompt: Same accepted forms as ``prompt``, or ``None``.

        Returns:
            ``(prompt_embeds, negative_prompt_embeds)``. The second element is
            ``None`` when ``negative_prompt`` is ``None``.
        """
        prompt_embeds = self.encode_prompt(prompt)
        # encode_prompt rejects None, so the (default) absent negative prompt
        # must be handled here rather than encoded.
        negative_prompt_embeds = (
            None if negative_prompt is None else self.encode_prompt(negative_prompt)
        )
        return prompt_embeds, negative_prompt_embeds