from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)


class BasePromptEncoder(ModelMixin, ConfigMixin):
    """Abstract base class for prompt encoders used with diffusers pipelines."""

    def __init__(self):
        super().__init__()

    def encode_text(self, text):
        raise NotImplementedError

    def encode_image(self, image):
        raise NotImplementedError

    def forward(self, prompt, negative_prompt=None):
        raise NotImplementedError


class MaterialPromptEncoder(BasePromptEncoder):
    """Encodes text or image prompts into CLIP embeddings (openai/clip-vit-large-patch14)."""

    def __init__(self):
        super().__init__()
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_vision = CLIPVisionModelWithProjection.from_pretrained(
            "openai/clip-vit-large-patch14"
        )
        self.clip_text = CLIPTextModelWithProjection.from_pretrained(
            "openai/clip-vit-large-patch14"
        )

    def encode_text(self, text):
        # Tokenize and project the text into the shared CLIP embedding space.
        inputs = self.tokenizer(text, padding=True, return_tensors="pt")
        inputs["input_ids"] = inputs["input_ids"].to(self.device)
        inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
        outputs = self.clip_text(**inputs)
        # Shape: (batch, 1, embed_dim), i.e. a sequence of length 1 per prompt.
        return outputs.text_embeds.unsqueeze(1)

    def encode_image(self, image):
        # Preprocess and project the image(s) into the shared CLIP embedding space.
        inputs = self.processor(images=image, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
        outputs = self.clip_vision(**inputs)
        # Shape: (batch, 1, embed_dim).
        return outputs.image_embeds.unsqueeze(1)

    def encode_prompt(self, prompt):
        # Dispatch on the prompt type: strings go through the text encoder,
        # PIL images through the vision encoder. For a list, dispatch on the
        # type of its first element (the batch is assumed to be homogeneous).
        sample = prompt[0] if isinstance(prompt, list) else prompt
        if isinstance(sample, str):
            return self.encode_text(prompt)
        elif isinstance(sample, Image.Image):
            return self.encode_image(prompt)
        else:
            raise NotImplementedError(f"Unsupported prompt type: {type(sample)}")

    def forward(self, prompt, negative_prompt=None):
        prompt_embeds = self.encode_prompt(prompt)
        # Only encode the negative prompt when one is provided; otherwise pass None through.
        negative_embeds = (
            self.encode_prompt(negative_prompt) if negative_prompt is not None else None
        )
        return prompt_embeds, negative_embeds
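

# --- Minimal usage sketch (illustrative only, not part of the module's API) ---
# Assumes the four pretrained CLIP components above can be downloaded, and uses
# made-up example prompts. For openai/clip-vit-large-patch14, each returned
# tensor has shape (batch, 1, 768).
if __name__ == "__main__":
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = MaterialPromptEncoder().to(device)

    with torch.no_grad():
        prompt_embeds, negative_embeds = encoder(
            ["brushed metal", "rough oak wood"],
            negative_prompt=["blurry texture", "blurry texture"],
        )

    print(prompt_embeds.shape, negative_embeds.shape)  # e.g. torch.Size([2, 1, 768]) each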