# Author: mpatel57
# Commit: 27c8593 "Upload CustomTextEncoderOnly" (verified)
from transformers import AutoConfig, AutoModel, PretrainedConfig, CLIPTextConfig, CLIPVisionConfig, PreTrainedModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
from transformers.utils import ModelOutput
import torch
import open_clip
from dataclasses import dataclass
import safetensors.torch
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
import os
# File name under which OpenCLIPVisionEncoderOnly.save_pretrained writes the vision-tower weights.
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"
# File name intended for the prior model's weights; in this file it only appears in
# CustomPriorModel's commented-out save code.
HF_SAFE_WEIGHTS_NAME_PRIOR = "prior_model.safetensors"
@dataclass
class PriorTransformerOutput(ModelOutput):
    """
    Hugging Face style output container for prior models (e.g. ``CustomPriorModel``).

    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            Image embedding predicted by the prior from its text-side input features.
    """

    predicted_image_embedding: torch.FloatTensor
@dataclass
class TextEncoderOutput(ModelOutput):
    """
    Hugging Face style output container returned by the text-encoder wrappers in this module.

    Attributes:
        text_embeds (`torch.FloatTensor`, *optional*):
            Sentence-level text embedding produced by the encoder (projected/pooled output).
        last_hidden_state (`torch.FloatTensor`, *optional*):
            Per-token hidden states of the encoder's final layer, optionally projected.
    """

    text_embeds: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
class CLIPTextEncoderOnlyConfig(CLIPTextConfig):
    """
    Configuration for ``CLIPTextEncoderOnly``.

    Adds the wrapper-specific settings on top of ``CLIPTextConfig``; every other
    keyword argument is forwarded unchanged to ``CLIPTextConfig``.

    :param model_name: Hub id or local path of the CLIP checkpoint to wrap.
    :param pretrained: Load the checkpoint's weights (True) or only its architecture (False).
    :param frozen: Whether the wrapped backbone should be frozen.
    :param lora: Optional LoRA settings (``lora_r``, ``lora_alpha``, ``lora_dropout``).
    """

    model_type = "clip_custom_text_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, lora: dict = None, **kwargs):
        self.model_name, self.pretrained = model_name, pretrained
        self.frozen, self.lora = frozen, lora
        super().__init__(**kwargs)
class CLIPTextEncoderOnly(PreTrainedModel):
    """
    Hugging Face ``PreTrainedModel`` exposing only the text tower of a CLIP checkpoint.

    Wraps ``CLIPTextModelWithProjection``. Settings come from a
    ``CLIPTextEncoderOnlyConfig``: ``model_name``, ``pretrained``, ``frozen``
    and an optional ``lora`` spec.
    """

    config_class = CLIPTextEncoderOnlyConfig

    # Module names LoRA adapters are attached to when ``config.lora`` is set.
    _LORA_TARGET_MODULES = [
        "k_proj",
        "v_proj",
        "q_proj",
        "out_proj",
        "fc1",
        "fc2",
        "visual_projection",
        "text_projection",
    ]

    def __init__(self, config):
        """
        Build the wrapped CLIP text model from ``config``.

        :param config: ``CLIPTextEncoderOnlyConfig``.
            - ``model_name``: hub id or path of the CLIP checkpoint.
            - ``pretrained``: load the checkpoint's weights, or build a randomly initialised model from its config.
            - ``frozen``: disable gradients of the pretrained weights.
            - ``lora``: a dict or attribute-style object with ``lora_r``, ``lora_alpha`` and ``lora_dropout``; enables LoRA adapters.
        """
        super().__init__(config)
        if config.pretrained:
            self.model = CLIPTextModelWithProjection.from_pretrained(config.model_name)
            if config.frozen:
                # Same behaviour as CustomTextEncoderOnly: freeze the pretrained backbone
                # before any LoRA wrapping.
                for param in self.model.parameters():
                    param.requires_grad = False
        else:
            base_cfg = CLIPTextConfig.from_pretrained(config.model_name)
            self.model = CLIPTextModelWithProjection(base_cfg)
        if config.lora:
            self.model = get_peft_model(self.model, self._make_lora_config(config.lora))

    @classmethod
    def _make_lora_config(cls, lora):
        """Translate a LoRA spec (dict or attribute-style object) into a peft ``LoraConfig``."""

        def option(key):
            # The config annotates ``lora`` as a dict, so support item access as well as attributes.
            return lora[key] if isinstance(lora, dict) else getattr(lora, key)

        return LoraConfig(
            r=option("lora_r"),
            lora_alpha=option("lora_alpha"),
            target_modules=list(cls._LORA_TARGET_MODULES),
            lora_dropout=option("lora_dropout"),
            bias="lora_only",
        )

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """
        Encode token ids with the wrapped CLIP text model.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param position_ids: Optional position indices, passed through to the CLIP text model.
        :return: TextEncoderOutput with the projected ``text_embeds`` and the model's ``last_hidden_state``.
        """
        # Per-layer hidden states are not part of the returned output, so they are not requested.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
        return TextEncoderOutput(text_embeds=outputs.text_embeds, last_hidden_state=outputs.last_hidden_state)
class CustomTextEncoderOnlyConfig(CLIPTextConfig):
    """
    Configuration for ``CustomTextEncoderOnly``.

    Extra keyword arguments are handed to ``CLIPTextConfig``.

    :param model_name: Hub id or local path of the ``AutoModel`` backbone.
    :param pretrained: Load the backbone's weights (True) or only its architecture (False).
    :param frozen: Whether the pretrained backbone should be frozen.
    :param output_hidden_size: Width of the linear projection head(s).
    :param last_hidden_state: Whether per-token hidden states are also projected.
    :param lora: Optional LoRA settings (``lora_r``, ``lora_alpha``, ``lora_dropout``).
    """

    model_type = "whole_custom_text_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, output_hidden_size: int = 512, last_hidden_state: bool = False, lora: dict = None, **kwargs):
        self.model_name, self.pretrained, self.frozen = model_name, pretrained, frozen
        self.output_hidden_size, self.last_hidden_state = output_hidden_size, last_hidden_state
        self.lora = lora
        super().__init__(**kwargs)
class CustomTextEncoderOnly(PreTrainedModel):
    """
    Generic ``AutoModel`` text backbone followed by linear projection heads.

    ``fc1`` projects the backbone's pooled output (``outputs[1]``) to
    ``config.output_hidden_size``. When ``config.last_hidden_state`` is set,
    ``fc2`` also projects the per-token states (``outputs[0]``).
    """

    config_class = CustomTextEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the backbone and projection heads from ``config``.

        :param config: ``CustomTextEncoderOnlyConfig``.
            - ``model_name``: hub id or path of the backbone.
            - ``pretrained``: load its weights, or build a randomly initialised model of the same architecture.
            - ``frozen``: disable gradients of the pretrained backbone.
            - ``output_hidden_size``: projection width.
            - ``last_hidden_state``: also project per-token states.
            - ``lora``: a dict or attribute-style object with ``lora_r``, ``lora_alpha`` and ``lora_dropout``.
        """
        super().__init__(config)
        self.last_hidden_state = config.last_hidden_state
        if config.pretrained:
            self.model = AutoModel.from_pretrained(config.model_name)
            if config.frozen:
                for param in self.model.parameters():
                    param.requires_grad = False
        else:
            # AutoModel cannot be instantiated directly, and ``config`` is this wrapper's
            # own config. Build the backbone named by model_name with random weights instead.
            backbone_cfg = AutoConfig.from_pretrained(config.model_name)
            self.model = AutoModel.from_config(backbone_cfg)
        hidden_size = self.model.config.hidden_size
        self.fc1 = torch.nn.Linear(hidden_size, config.output_hidden_size)
        if config.last_hidden_state:
            self.fc2 = torch.nn.Linear(hidden_size, config.output_hidden_size)
        if config.lora:
            self.model = get_peft_model(self.model, self._make_lora_config(config.lora))

    @staticmethod
    def _make_lora_config(lora):
        """Translate a LoRA spec (dict or attribute-style object) into a feature-extraction ``LoraConfig``."""

        def option(key):
            # The config annotates ``lora`` as a dict, so support item access as well as attributes.
            return lora[key] if isinstance(lora, dict) else getattr(lora, key)

        return LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=option("lora_r"),
            lora_alpha=option("lora_alpha"),
            lora_dropout=option("lora_dropout"),
            bias="lora_only",
        )

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Encode token ids and project the backbone outputs.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        :return: TextEncoderOutput.
            - ``text_embeds``: ``fc1`` applied to the pooled output.
            - ``last_hidden_state``: the per-token states, projected by ``fc2`` when enabled.
        """
        # Per-layer hidden states are never used below, so they are not requested.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # assumes a BERT-style output: index 0 = last hidden state, index 1 = pooled output — confirm for other backbones
        text_embeds = self.fc1(outputs[1])
        last_hidden_state = self.fc2(outputs[0]) if self.last_hidden_state else outputs[0]
        return TextEncoderOutput(text_embeds=text_embeds, last_hidden_state=last_hidden_state)
class CLIPVisionEncoderOnlyConfig(PretrainedConfig):
    """
    Configuration for ``CLIPVisionEncoderOnly``.

    Remaining keyword arguments go to ``PretrainedConfig``.

    :param model_name: Hub id or local path of the CLIP checkpoint to wrap.
    :param pretrained: Load the checkpoint's weights (True) or only its architecture (False).
    :param frozen: Whether the wrapped backbone should be frozen.
    :param lora: Optional LoRA settings (``lora_r``, ``lora_alpha``, ``lora_dropout``).
    """

    model_type = "clip_custom_vision_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, lora: dict = None, **kwargs):
        self.model_name, self.pretrained = model_name, pretrained
        self.frozen, self.lora = frozen, lora
        super().__init__(**kwargs)
class CLIPVisionEncoderOnly(PreTrainedModel):
    """
    Hugging Face ``PreTrainedModel`` exposing only the vision tower of a CLIP checkpoint.

    Wraps ``CLIPVisionModelWithProjection``. Settings come from a
    ``CLIPVisionEncoderOnlyConfig``: ``model_name``, ``pretrained``, ``frozen``
    and an optional ``lora`` spec.
    """

    config_class = CLIPVisionEncoderOnlyConfig

    # Module names LoRA adapters are attached to when ``config.lora`` is set.
    _LORA_TARGET_MODULES = [
        "k_proj",
        "v_proj",
        "q_proj",
        "out_proj",
        "fc1",
        "fc2",
        "visual_projection",
        "text_projection",
    ]

    def __init__(self, config):
        """
        Build the wrapped CLIP vision model from ``config``.

        :param config: ``CLIPVisionEncoderOnlyConfig``.
            - ``model_name``: hub id or path of the CLIP checkpoint.
            - ``pretrained``: load its weights, or build a randomly initialised model from its config.
            - ``frozen``: disable gradients of the pretrained weights.
            - ``lora``: a dict or attribute-style object with ``lora_r``, ``lora_alpha`` and ``lora_dropout``.
        """
        super().__init__(config)
        if config.pretrained:
            self.model = CLIPVisionModelWithProjection.from_pretrained(config.model_name)
            if config.frozen:
                # Same behaviour as CustomTextEncoderOnly: freeze the pretrained backbone
                # before any LoRA wrapping.
                for param in self.model.parameters():
                    param.requires_grad = False
        else:
            base_cfg = CLIPVisionConfig.from_pretrained(config.model_name)
            self.model = CLIPVisionModelWithProjection(base_cfg)
        if config.lora:
            self.model = get_peft_model(self.model, self._make_lora_config(config.lora))
        # No ``parameters`` override: ``self.model`` is the only submodule, so the inherited
        # ``nn.Module.parameters(recurse=True)`` yields the same tensors and keeps its full signature.

    @classmethod
    def _make_lora_config(cls, lora):
        """Translate a LoRA spec (dict or attribute-style object) into a peft ``LoraConfig``."""

        def option(key):
            # The config annotates ``lora`` as a dict, so support item access as well as attributes.
            return lora[key] if isinstance(lora, dict) else getattr(lora, key)

        return LoraConfig(
            r=option("lora_r"),
            lora_alpha=option("lora_alpha"),
            target_modules=list(cls._LORA_TARGET_MODULES),
            lora_dropout=option("lora_dropout"),
            bias="lora_only",
        )

    def forward(self, data):
        """
        Encode images with the wrapped CLIP vision model.

        :param data: Mapping of keyword arguments unpacked into the wrapped model
            (presumably processor output such as ``pixel_values`` — confirm against callers).
        :return: The model's ``image_embeds``.
        """
        return self.model(**data).image_embeds
class OpenCLIPVisionEncoderOnly(torch.nn.Module):
    """Vision tower (``.visual``) of an OpenCLIP checkpoint from the Hugging Face Hub, optionally LoRA-wrapped."""

    # Module names LoRA adapters are attached to when ``lora`` is given.
    _LORA_TARGET_MODULES = [
        "k_proj",
        "v_proj",
        "q_proj",
        "out_proj",
        "fc1",
        "fc2",
        "visual_projection",
        "text_projection",
    ]

    def __init__(self, model_name: str, pretrained: bool = True, frozen: bool = False, lora: dict = None):
        """
        Load the OpenCLIP vision tower.

        :param model_name: Hub repo id of the OpenCLIP checkpoint (loaded as ``hf-hub:<model_name>``).
        :param pretrained: Must be True; building an untrained OpenCLIP vision tower is not implemented.
        :param frozen: If True, disable gradients of the loaded vision tower before any LoRA wrapping.
        :param lora: Optional LoRA settings with ``lora_r``, ``lora_alpha`` and ``lora_dropout``,
            given as a dict or as an attribute-style object.
        :raises NotImplementedError: If ``pretrained`` is False.
        """
        super().__init__()
        if not pretrained:
            # Use NotImplementedError, not NotImplemented: raising NotImplemented is itself a TypeError.
            raise NotImplementedError("OpenCLIPVisionEncoderOnly only supports pretrained=True")
        model, _ = open_clip.create_model_from_pretrained(f"hf-hub:{model_name}")
        self.model = model.visual
        if frozen:
            for param in self.model.parameters():
                param.requires_grad = False
        if lora:
            self.model = get_peft_model(self.model, self._make_lora_config(lora))

    @classmethod
    def _make_lora_config(cls, lora):
        """Translate a LoRA spec (dict or attribute-style object) into a peft ``LoraConfig``."""

        def option(key):
            # ``lora`` is annotated as a dict, so support item access as well as attributes.
            return lora[key] if isinstance(lora, dict) else getattr(lora, key)

        return LoraConfig(
            r=option("lora_r"),
            lora_alpha=option("lora_alpha"),
            target_modules=list(cls._LORA_TARGET_MODULES),
            lora_dropout=option("lora_dropout"),
            bias="lora_only",
        )

    def forward(self, image):
        """
        Encode ``image`` with the (possibly LoRA-wrapped) vision tower.

        :param image: Image tensor accepted by the OpenCLIP vision tower.
        :return: Whatever the vision tower returns for ``image``.
        """
        return self.model(image)

    def save_pretrained(self, save_dir):
        """
        Write the wrapped model's state dict as safetensors to ``save_dir``/``HF_SAFE_WEIGHTS_NAME``.

        :param save_dir: Target directory as ``str`` or ``os.PathLike``; created if it does not exist.
        """
        # os.path.join accepts both str and Path; the previous ``save_dir / name`` failed for str.
        os.makedirs(save_dir, exist_ok=True)
        tensors = self.model.state_dict()
        safetensors.torch.save_file(tensors, os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME))
class CustomPriorModel(torch.nn.Module):
    """
    Two-layer MLP prior that maps input features to a predicted image embedding.

    ``fc1`` consumes ``2 * in_hidden_state`` features (presumably two concatenated
    ``in_hidden_state``-wide embeddings — confirm against callers). The hidden
    width is ``max(in_hidden_state, out_hidden_state)``.
    """

    def __init__(self, in_hidden_state, out_hidden_state):
        """
        Create the MLP layers.

        :param in_hidden_state: Base input width; ``fc1`` takes ``2 * in_hidden_state`` features.
        :param out_hidden_state: Width of the predicted embedding produced by ``fc2``.
        """
        super().__init__()
        mid_hidden_state = max(in_hidden_state, out_hidden_state)
        self.fc1 = torch.nn.Linear(in_hidden_state * 2, mid_hidden_state)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(mid_hidden_state, out_hidden_state)

    def reinitialize_model(self):
        """
        Re-initialize every trainable parameter in place; frozen parameters are left untouched.

        - Tensors with more than one dimension get Xavier-uniform init.
        - 1-D tensors whose name contains ``weight`` are drawn from a standard normal.
        - All other 1-D tensors (e.g. biases) are zeroed.
        """
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() > 1:
                torch.nn.init.xavier_uniform_(param)
            elif "weight" in name:
                torch.nn.init.normal_(param)
            else:
                torch.nn.init.zeros_(param)

    def forward(self, feats):
        """
        Predict an image embedding.

        :param feats: Input features whose last dimension is ``2 * in_hidden_state``.
        :return: PriorTransformerOutput holding ``fc2(relu(fc1(feats)))``.
        """
        return PriorTransformerOutput(predicted_image_embedding=self.fc2(self.relu(self.fc1(feats))))

    def save_pretrained(self, save_dir):
        """
        Intentionally a no-op: nothing is written to ``save_dir``.

        NOTE(review): saving is disabled. The commented-out lines below show a safetensors
        export that is currently turned off, so callers expecting persisted weights get nothing.
        """
        # tensors = self.state_dict()
        # safetensors.torch.save_file(tensors, os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME_PRIOR))
def test_text_model(register=False, upload=False):
    """
    Manual Hub round-trip check for CLIPTextEncoderOnly.

    :param register: Register the config/model pair with AutoConfig/AutoModel under
        ``"clip_custom_text_model"`` and mark both classes for auto-class export.
    :param upload: Build a pretrained ``openai/clip-vit-base-patch32`` text encoder, then:
        - push it to ``test-text-hf-upload``;
        - re-download ``mpatel57/test-text-hf-upload`` with ``force_download=True``.
    """
    if register:
        AutoConfig.register("clip_custom_text_model", CLIPTextEncoderOnlyConfig)
        AutoModel.register(CLIPTextEncoderOnlyConfig, CLIPTextEncoderOnly)
        CLIPTextEncoderOnlyConfig.register_for_auto_class()
        CLIPTextEncoderOnly.register_for_auto_class("AutoModel")
    if not upload:
        return
    encoder_cfg = CLIPTextEncoderOnlyConfig(
        model_name="openai/clip-vit-base-patch32",
        pretrained=True,
        lora=None,
    )
    encoder = CLIPTextEncoderOnly(encoder_cfg)
    encoder.push_to_hub("test-text-hf-upload")
    # NOTE(review): the reload is pinned to the mpatel57 namespace; presumably only
    # matches the push above when run as that user.
    encoder = CLIPTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)
def test_custom_text_model(register=False, upload=False):
    """
    Manual Hub round-trip check for CustomTextEncoderOnly.

    :param register: Register the config/model pair with AutoConfig/AutoModel under
        ``"whole_custom_text_model"`` and mark both classes for auto-class export.
    :param upload: Build a ``google-bert/bert-base-uncased`` encoder with a 512-wide head, then:
        - push it to ``test-text-hf-upload``;
        - re-download ``mpatel57/test-text-hf-upload`` with ``force_download=True``.
    """
    if register:
        AutoConfig.register("whole_custom_text_model", CustomTextEncoderOnlyConfig)
        AutoModel.register(CustomTextEncoderOnlyConfig, CustomTextEncoderOnly)
        CustomTextEncoderOnlyConfig.register_for_auto_class()
        CustomTextEncoderOnly.register_for_auto_class("AutoModel")
    if not upload:
        return
    settings = dict(
        model_name="google-bert/bert-base-uncased",
        pretrained=True,
        frozen=False,
        output_hidden_size=512,
        last_hidden_state=False,
        lora=None,
    )
    encoder = CustomTextEncoderOnly(CustomTextEncoderOnlyConfig(**settings))
    # NOTE(review): same repo name as test_text_model; presumably the two runs overwrite each other's upload.
    encoder.push_to_hub("test-text-hf-upload")
    # NOTE(review): the reload is pinned to the mpatel57 namespace.
    encoder = CustomTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)
def test_vision_model(register=False, upload=False):
    """
    Manual Hub round-trip check for CLIPVisionEncoderOnly.

    :param register: Register the config/model pair with AutoConfig/AutoModel under
        ``"clip_custom_vision_model"`` and mark both classes for auto-class export.
    :param upload: Build a pretrained ``openai/clip-vit-base-patch32`` vision encoder, then:
        - push it to ``test-vision-hf-upload``;
        - re-download ``mpatel57/test-vision-hf-upload`` with ``force_download=True``.
    """
    if register:
        AutoConfig.register("clip_custom_vision_model", CLIPVisionEncoderOnlyConfig)
        AutoModel.register(CLIPVisionEncoderOnlyConfig, CLIPVisionEncoderOnly)
        CLIPVisionEncoderOnlyConfig.register_for_auto_class()
        CLIPVisionEncoderOnly.register_for_auto_class("AutoModel")
    if not upload:
        return
    encoder_cfg = CLIPVisionEncoderOnlyConfig(
        model_name="openai/clip-vit-base-patch32",
        pretrained=True,
        lora=None,
    )
    encoder = CLIPVisionEncoderOnly(encoder_cfg)
    encoder.push_to_hub("test-vision-hf-upload")
    # NOTE(review): the reload is pinned to the mpatel57 namespace.
    encoder = CLIPVisionEncoderOnly.from_pretrained("mpatel57/test-vision-hf-upload", force_download=True)
if __name__ == "__main__":
    # Manual entry point: runs the Hub upload/reload check for CustomTextEncoderOnly only.
    # Uncomment a line below to exercise the other wrappers instead.
    test_custom_text_model(register=False, upload=True)
    # test_text_model(register=False, upload=True)
    # test_vision_model(register=False, upload=True)