gvecchio's picture
Upload MatForgerPipeline
1ad8ca8
raw
history blame
2.41 kB
from typing import List, Optional
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from PIL import Image
from transformers import (
AutoProcessor,
AutoTokenizer,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
)
class BasePromptEncoder(ModelMixin, ConfigMixin):
    """Abstract interface for prompt encoders.

    Every method below raises ``NotImplementedError``; a concrete subclass
    is expected to override the ones it supports.
    """

    def __init__(self):
        super().__init__()

    def encode_text(self, text):
        """Encode a text prompt. Must be overridden by subclasses."""
        raise NotImplementedError()

    def encode_image(self, image):
        """Encode an image prompt. Must be overridden by subclasses."""
        raise NotImplementedError()

    def forward(self, prompt, negative_prompt=None):
        """Encode ``prompt`` and an optional ``negative_prompt``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError()
class MaterialPromptEncoder(BasePromptEncoder):
    """Prompt encoder backed by the ``openai/clip-vit-large-patch14`` CLIP model.

    Text prompts go through the CLIP tokenizer and text tower, image prompts
    through the CLIP processor and vision tower. In both cases the projected
    embedding is returned with an extra dimension inserted at index 1.
    """

    def __init__(self):
        super().__init__()
        # All four components are loaded from the same pretrained checkpoint.
        checkpoint = "openai/clip-vit-large-patch14"
        self.processor = AutoProcessor.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.clip_vision = CLIPVisionModelWithProjection.from_pretrained(checkpoint)
        self.clip_text = CLIPTextModelWithProjection.from_pretrained(checkpoint)

    def encode_text(self, text):
        """Return the projected CLIP text embedding for ``text``.

        Args:
            text: A string or a sequence of strings.

        Returns:
            ``text_embeds`` from the CLIP text model, with a singleton
            dimension inserted at index 1.
        """
        inputs = self.tokenizer(text, padding=True, return_tensors="pt")
        # The tokenizer produces CPU tensors; move them next to the model.
        inputs["input_ids"] = inputs["input_ids"].to(self.device)
        inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
        outputs = self.clip_text(**inputs)
        return outputs.text_embeds.unsqueeze(1)

    def encode_image(self, image):
        """Return the projected CLIP image embedding for ``image``.

        Args:
            image: A PIL image or a sequence of PIL images.

        Returns:
            ``image_embeds`` from the CLIP vision model, with a singleton
            dimension inserted at index 1.
        """
        inputs = self.processor(images=image, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
        outputs = self.clip_vision(**inputs)
        return outputs.image_embeds.unsqueeze(1)

    def encode_prompt(
        self,
        prompt,
    ):
        """Dispatch ``prompt`` to the text or image encoder.

        For a list or tuple, the modality is decided from its first element.

        Args:
            prompt: A string, a PIL image, or a list/tuple of either.

        Returns:
            The embedding produced by ``encode_text`` or ``encode_image``.

        Raises:
            NotImplementedError: If the prompt is neither text nor a PIL image.
        """
        sample = prompt[0] if isinstance(prompt, (list, tuple)) else prompt
        if isinstance(sample, str):
            return self.encode_text(prompt)
        # isinstance (not an exact type comparison) is required here:
        # Image.open() returns format-specific subclasses of Image.Image
        # (e.g. PngImageFile), which an equality check on the type rejects.
        if isinstance(sample, Image.Image):
            return self.encode_image(prompt)
        raise NotImplementedError(
            f"Unsupported prompt type: {type(sample).__name__}"
        )

    def forward(
        self,
        prompt,
        negative_prompt=None,
    ):
        """Encode a prompt and, when given, a negative prompt.

        Args:
            prompt: A string, a PIL image, or a list/tuple of either.
            negative_prompt: Optional negative prompt of the same kinds.

        Returns:
            A tuple ``(prompt_embeds, negative_prompt_embeds)``. The second
            item is ``None`` when ``negative_prompt`` is ``None``.
        """
        prompt_embeds = self.encode_prompt(prompt)
        # The default negative_prompt=None must not be sent to encode_prompt,
        # which would reject it with NotImplementedError.
        if negative_prompt is None:
            return prompt_embeds, None
        return prompt_embeds, self.encode_prompt(negative_prompt)