|
from transformers import AutoConfig, AutoModel, PretrainedConfig, CLIPTextConfig, CLIPVisionConfig, PreTrainedModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection |
|
from transformers.utils import ModelOutput |
|
import torch |
|
import open_clip |
|
from dataclasses import dataclass |
|
import safetensors.torch |
|
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType |
|
import os |
|
|
|
# Safetensors file name under which OpenCLIPVisionEncoderOnly.save_pretrained writes the visual tower.
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"

# Safetensors file name reserved for prior-model weights (presumably CustomPriorModel — confirm).
HF_SAFE_WEIGHTS_NAME_PRIOR = "prior_model.safetensors"
|
|
|
@dataclass
class PriorTransformerOutput(ModelOutput):
    """
    Output container for prior models (mirrors the output of [`PriorTransformer`]).

    In this file it is what ``CustomPriorModel.forward`` returns.

    Args:
        predicted_image_embedding (`torch.FloatTensor`):
            The predicted CLIP image embedding conditioned on the text-side input.
            Presumably of shape `(batch_size, embedding_dim)` — confirm against the producing model.
    """

    # Single required field; it has no default, so it must always be supplied.
    predicted_image_embedding: torch.FloatTensor
|
|
|
@dataclass
class TextEncoderOutput(ModelOutput):
    """
    Output container shared by the text encoders in this file, in Hugging Face ``ModelOutput`` style.

    Attributes:
        text_embeds (torch.FloatTensor, optional): Sentence-level text embedding produced by the
            encoder (the CLIP projection, or a linear head over the backbone's pooled output).
        last_hidden_state (torch.FloatTensor, optional): Per-token states from the encoder's final
            layer, possibly passed through a linear projection depending on the encoder.

    Both fields default to ``None``.
    """

    text_embeds: torch.FloatTensor = None

    last_hidden_state: torch.FloatTensor = None
|
|
|
class CLIPTextEncoderOnlyConfig(CLIPTextConfig):
    """
    Configuration for ``CLIPTextEncoderOnly``.

    Extends ``CLIPTextConfig`` with the wrapper's own options; every other keyword
    argument is handed to ``CLIPTextConfig`` unchanged.

    :param model_name: Hub id or local path of the CLIP checkpoint to wrap.
    :param pretrained: Load pretrained weights (True) or build only the architecture (False).
    :param frozen: Stored on the config; whether it is honored is up to the model class.
    :param lora: Optional LoRA settings (``lora_r``, ``lora_alpha``, ``lora_dropout``).
    """

    model_type = "clip_custom_text_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, lora: dict = None, **kwargs):
        # Wrapper-specific fields are assigned (in this order) before the parent initializer runs.
        self.model_name, self.pretrained = model_name, pretrained
        self.frozen, self.lora = frozen, lora
        super().__init__(**kwargs)
|
|
|
class CLIPTextEncoderOnly(PreTrainedModel):
    """
    Hugging Face ``PreTrainedModel`` wrapper around ``CLIPTextModelWithProjection``.

    ``forward`` returns a ``TextEncoderOutput`` holding the projected text embedding
    and the final-layer token states.
    """

    config_class = CLIPTextEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the CLIP text tower described by ``config``.

        :param config: A ``CLIPTextEncoderOnlyConfig``. Fields used: ``model_name``
            (checkpoint id/path), ``pretrained`` (load weights or only the architecture),
            ``frozen`` (disable gradients for the CLIP backbone) and ``lora`` (optional LoRA
            settings, given either as a dict or as an attribute-style object exposing
            ``lora_r``, ``lora_alpha`` and ``lora_dropout``).
        """
        super().__init__(config)

        if config.pretrained:
            self.model = CLIPTextModelWithProjection.from_pretrained(config.model_name)
        else:
            base_cfg = CLIPTextConfig.from_pretrained(config.model_name)
            self.model = CLIPTextModelWithProjection(base_cfg)

        # Honor the ``frozen`` flag by disabling gradients on the backbone's parameters.
        if config.frozen:
            self.model.requires_grad_(False)

        if config.lora:
            lora = config.lora
            l_config = LoraConfig(
                r=self._lora_option(lora, "lora_r"),
                lora_alpha=self._lora_option(lora, "lora_alpha"),
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                    "text_projection"
                ],
                lora_dropout=self._lora_option(lora, "lora_dropout"),
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """Read ``key`` from LoRA settings given as a dict (e.g. after a JSON config round-trip)
        or as an attribute-style object (e.g. a namespace)."""
        return lora[key] if isinstance(lora, dict) else getattr(lora, key)

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """
        Encode token ids with the CLIP text tower.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param position_ids: Optional position indices forwarded to the CLIP text model.
        :return: ``TextEncoderOutput`` with ``text_embeds`` (projected text embedding) and
            ``last_hidden_state`` (final-layer token states).
        """
        # Only the final layer's states are returned, so all-layer hidden states are not requested.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
        return TextEncoderOutput(text_embeds=outputs.text_embeds, last_hidden_state=outputs.last_hidden_state)
|
|
|
|
|
class CustomTextEncoderOnlyConfig(CLIPTextConfig):
    """
    Configuration for ``CustomTextEncoderOnly`` (an ``AutoModel`` backbone plus linear heads).

    Unrecognized keyword arguments are passed through to ``CLIPTextConfig``.

    :param model_name: Hub id or local path of the backbone checkpoint.
    :param pretrained: Load pretrained backbone weights (True) or only its architecture (False).
    :param frozen: Whether the model should keep the backbone's parameters fixed.
    :param output_hidden_size: Output width of the linear projection head(s).
    :param last_hidden_state: Whether the model should also project the per-token states.
    :param lora: Optional LoRA settings (``lora_r``, ``lora_alpha``, ``lora_dropout``).
    """

    model_type = "whole_custom_text_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, output_hidden_size: int = 512, last_hidden_state: bool = False, lora: dict = None, **kwargs):
        # Custom fields first (same order as declared), then the parent initializer.
        self.model_name, self.pretrained, self.frozen = model_name, pretrained, frozen
        self.output_hidden_size, self.last_hidden_state = output_hidden_size, last_hidden_state
        self.lora = lora
        super().__init__(**kwargs)
|
|
|
class CustomTextEncoderOnly(PreTrainedModel):
    """
    Text encoder built on an arbitrary Hugging Face ``AutoModel`` backbone with linear output heads.

    ``fc1`` projects the backbone's second output to ``config.output_hidden_size``.
    When ``config.last_hidden_state`` is set, ``fc2`` also projects the final-layer token
    states to the same width.
    """

    config_class = CustomTextEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the backbone and projection heads described by ``config``.

        :param config: A ``CustomTextEncoderOnlyConfig``. Fields used:
            - ``model_name``: backbone checkpoint id/path.
            - ``pretrained``: load weights or only the architecture.
            - ``frozen``: disable gradients for the backbone.
            - ``output_hidden_size``: width of the projection heads.
            - ``last_hidden_state``: also project the token states.
            - ``lora``: optional LoRA settings, as a dict or an attribute-style object with
              ``lora_r``, ``lora_alpha`` and ``lora_dropout``.
        """
        super().__init__(config)

        # Flag read by forward(): project the token states with fc2, or return them unchanged.
        self.last_hidden_state = config.last_hidden_state

        if config.pretrained:
            self.model = AutoModel.from_pretrained(config.model_name)
        else:
            # ``AutoModel(...)`` always raises because Auto classes cannot be instantiated directly.
            # This wrapper's own config also does not describe the backbone. Build a randomly
            # initialized backbone from the checkpoint's architecture config instead.
            self.model = AutoModel.from_config(AutoConfig.from_pretrained(config.model_name))

        # Freeze only the backbone; the projection heads created below remain trainable.
        if config.frozen:
            for param in self.model.parameters():
                param.requires_grad = False

        hidden_size = self.model.config.hidden_size
        self.fc1 = torch.nn.Linear(hidden_size, config.output_hidden_size)
        if config.last_hidden_state:
            self.fc2 = torch.nn.Linear(hidden_size, config.output_hidden_size)

        if config.lora:
            lora = config.lora
            l_config = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                r=self._lora_option(lora, "lora_r"),
                lora_alpha=self._lora_option(lora, "lora_alpha"),
                lora_dropout=self._lora_option(lora, "lora_dropout"),
                bias="lora_only",
            )
            # Only the backbone is wrapped; fc1/fc2 stay plain layers outside the PEFT wrapper.
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """Read ``key`` from LoRA settings given as a dict (e.g. after a JSON config round-trip)
        or as an attribute-style object (e.g. a namespace)."""
        return lora[key] if isinstance(lora, dict) else getattr(lora, key)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Encode token ids and apply the projection head(s).

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param token_type_ids: Segment token indices; forwarded only when provided.
        :return: ``TextEncoderOutput`` with ``text_embeds = fc1(outputs[1])``. Its
            ``last_hidden_state`` is either ``fc2(outputs[0])`` or the raw ``outputs[0]``.
        """
        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        # Not every backbone's forward accepts token_type_ids, so only pass it when given.
        if token_type_ids is not None:
            model_inputs["token_type_ids"] = token_type_ids
        outputs = self.model(**model_inputs)

        # Index 1 is presumably the pooled output (e.g. BERT's pooler_output); index 0 the token states.
        text_embeds = self.fc1(outputs[1])
        last_hidden_state = self.fc2(outputs[0]) if self.last_hidden_state else outputs[0]
        return TextEncoderOutput(text_embeds=text_embeds, last_hidden_state=last_hidden_state)
|
|
|
class CLIPVisionEncoderOnlyConfig(PretrainedConfig):
    """
    Configuration for ``CLIPVisionEncoderOnly``.

    Remaining keyword arguments go to ``PretrainedConfig``.

    :param model_name: Hub id or local path of the CLIP checkpoint whose vision tower is wrapped.
    :param pretrained: Load pretrained weights (True) or build only the architecture (False).
    :param frozen: Stored on the config; whether it is honored is up to the model class.
    :param lora: Optional LoRA settings (``lora_r``, ``lora_alpha``, ``lora_dropout``).
    """

    model_type = "clip_custom_vision_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, lora: dict = None, **kwargs):
        # Store the wrapper options (same order as the signature) before PretrainedConfig initializes.
        self.model_name, self.pretrained = model_name, pretrained
        self.frozen, self.lora = frozen, lora
        super().__init__(**kwargs)
|
|
|
class CLIPVisionEncoderOnly(PreTrainedModel):
    """
    Hugging Face ``PreTrainedModel`` wrapper around ``CLIPVisionModelWithProjection``.

    ``forward`` returns only the projected image embedding.
    """

    config_class = CLIPVisionEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the CLIP vision tower described by ``config``.

        :param config: A ``CLIPVisionEncoderOnlyConfig``. Fields used: ``model_name``
            (checkpoint id/path), ``pretrained`` (load weights or only the architecture),
            ``frozen`` (disable gradients for the CLIP backbone) and ``lora`` (optional LoRA
            settings, given either as a dict or as an attribute-style object exposing
            ``lora_r``, ``lora_alpha`` and ``lora_dropout``).
        """
        super().__init__(config)

        if config.pretrained:
            self.model = CLIPVisionModelWithProjection.from_pretrained(config.model_name)
        else:
            base_cfg = CLIPVisionConfig.from_pretrained(config.model_name)
            self.model = CLIPVisionModelWithProjection(base_cfg)

        # Honor the ``frozen`` flag by disabling gradients on the backbone's parameters.
        if config.frozen:
            self.model.requires_grad_(False)

        if config.lora:
            lora = config.lora
            l_config = LoraConfig(
                r=self._lora_option(lora, "lora_r"),
                lora_alpha=self._lora_option(lora, "lora_alpha"),
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                    "text_projection"
                ],
                lora_dropout=self._lora_option(lora, "lora_dropout"),
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """Read ``key`` from LoRA settings given as a dict (e.g. after a JSON config round-trip)
        or as an attribute-style object (e.g. a namespace)."""
        return lora[key] if isinstance(lora, dict) else getattr(lora, key)

    def forward(self, data):
        """
        Run the vision tower and return its projected image embedding.

        :param data: Mapping of keyword inputs unpacked into the wrapped model
            (presumably containing ``pixel_values`` — confirm with the caller).
        :return: The ``image_embeds`` field of the wrapped model's output.
        """
        return self.model(**data).image_embeds

    def parameters(self, recurse: bool = True):
        """
        Return the wrapped model's parameters (this wrapper registers no others).

        Keeps ``torch.nn.Module.parameters``' ``recurse`` argument so callers using it still work.
        """
        return self.model.parameters(recurse=recurse)
|
|
|
|
|
class OpenCLIPVisionEncoderOnly(torch.nn.Module):
    """Visual tower of an OpenCLIP model loaded from the Hugging Face Hub."""

    def __init__(self, model_name: str, pretrained: bool = True, frozen: bool = False, lora: dict = None):
        """
        Load the OpenCLIP model ``hf-hub:<model_name>`` and keep only its ``visual`` tower.

        :param model_name: Hub repo id of the OpenCLIP model, without the ``hf-hub:`` prefix.
        :param pretrained: Must be True; building an untrained tower is not implemented.
        :param frozen: Disable gradients for the visual tower's parameters.
        :param lora: Optional LoRA settings, as a dict or an attribute-style object exposing
            ``lora_r``, ``lora_alpha`` and ``lora_dropout``.
        :raises NotImplementedError: If ``pretrained`` is False.
        """
        super().__init__()
        if not pretrained:
            # ``raise NotImplemented`` (the constant, not the exception) would itself fail with TypeError.
            raise NotImplementedError("OpenCLIPVisionEncoderOnly only supports pretrained=True.")

        model, _ = open_clip.create_model_from_pretrained(f"hf-hub:{model_name}")
        self.model = model.visual

        # Honor the ``frozen`` flag by disabling gradients on the visual tower.
        if frozen:
            self.model.requires_grad_(False)

        if lora:
            l_config = LoraConfig(
                r=self._lora_option(lora, "lora_r"),
                lora_alpha=self._lora_option(lora, "lora_alpha"),
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                    "text_projection"
                ],
                lora_dropout=self._lora_option(lora, "lora_dropout"),
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """Read ``key`` from LoRA settings given as a dict or as an attribute-style object."""
        return lora[key] if isinstance(lora, dict) else getattr(lora, key)

    def forward(self, image):
        """
        Run the visual tower on ``image`` and return its output unchanged.
        """
        return self.model(image)

    def save_pretrained(self, save_dir):
        """
        Write the visual tower's state dict as safetensors to ``save_dir/HF_SAFE_WEIGHTS_NAME``.

        :param save_dir: Target directory as ``str`` or ``os.PathLike``; created if missing.
        """
        os.makedirs(save_dir, exist_ok=True)
        # os.path.join accepts both str and Path, unlike the ``/`` operator on a plain str.
        safetensors.torch.save_file(self.model.state_dict(), os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME))
|
|
|
class CustomPriorModel(torch.nn.Module):
    """
    Two-layer MLP prior: ``fc2(relu(fc1(feats)))``, wrapped in a ``PriorTransformerOutput``.

    ``fc1`` takes ``2 * in_hidden_state`` input features. The hidden width is
    ``max(in_hidden_state, out_hidden_state)`` and the output width is ``out_hidden_state``.
    """

    def __init__(self, in_hidden_state, out_hidden_state):
        """
        Create the MLP layers.

        :param in_hidden_state: Half of ``fc1``'s input width. The input is presumably two
            ``in_hidden_state``-sized features concatenated — confirm with the caller.
        :param out_hidden_state: Width of the predicted embedding.
        """
        super().__init__()
        mid_hidden_state = max(in_hidden_state, out_hidden_state)

        self.fc1 = torch.nn.Linear(in_hidden_state * 2, mid_hidden_state)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(mid_hidden_state, out_hidden_state)

    def reinitialize_model(self):
        """
        Re-initialize every trainable parameter in place.

        Matrices get Xavier-uniform initialization. 1-D tensors whose name contains ``'weight'``
        are drawn from a standard normal; other 1-D tensors (biases) are zeroed.
        """
        for name, param in self.named_parameters():
            if param.requires_grad:
                if len(param.shape) > 1:
                    torch.nn.init.xavier_uniform_(param)
                else:
                    if 'weight' in name:
                        torch.nn.init.normal_(param)
                    else:
                        torch.nn.init.zeros_(param)

    def forward(self, feats):
        """
        Map ``feats`` (last dimension ``2 * in_hidden_state``) to a predicted image embedding.
        """
        return PriorTransformerOutput(predicted_image_embedding=self.fc2(self.relu(self.fc1(feats))))

    def save_pretrained(self, save_dir):
        """
        Write this module's state dict as safetensors to ``save_dir/HF_SAFE_WEIGHTS_NAME_PRIOR``.

        This used to be a silent no-op, so trained weights were discarded without warning.

        :param save_dir: Target directory as ``str`` or ``os.PathLike``; created if missing.
        """
        os.makedirs(save_dir, exist_ok=True)
        safetensors.torch.save_file(self.state_dict(), os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME_PRIOR))
|
|
|
|
|
|
|
|
|
def test_text_model(register=False, upload=False):
    """
    Smoke-test a Hugging Face Hub round-trip of ``CLIPTextEncoderOnly``.

    :param register: First register the config/model with ``AutoConfig``/``AutoModel`` and
        mark them for auto-class export.
    :param upload: Build the encoder from ``openai/clip-vit-base-patch32`` and push it to the
        ``test-text-hf-upload`` repo (presumably under the authenticated user's namespace).

    The model is then always re-downloaded from ``mpatel57/test-text-hf-upload``.
    """
    if register:
        AutoConfig.register("clip_custom_text_model", CLIPTextEncoderOnlyConfig)
        AutoModel.register(CLIPTextEncoderOnlyConfig, CLIPTextEncoderOnly)
        CLIPTextEncoderOnlyConfig.register_for_auto_class()
        CLIPTextEncoderOnly.register_for_auto_class("AutoModel")

    if upload:
        encoder = CLIPTextEncoderOnly(
            CLIPTextEncoderOnlyConfig(
                model_name="openai/clip-vit-base-patch32",
                pretrained=True,
                lora=None,
            )
        )
        encoder.push_to_hub("test-text-hf-upload")

    # force_download bypasses the local cache so the latest pushed revision is fetched.
    CLIPTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)
|
|
|
def test_custom_text_model(register=False, upload=False):
    """
    Smoke-test a Hugging Face Hub round-trip of ``CustomTextEncoderOnly`` on a BERT backbone.

    :param register: First register the config/model with ``AutoConfig``/``AutoModel`` and
        mark them for auto-class export.
    :param upload: Build an unfrozen, LoRA-free encoder over ``google-bert/bert-base-uncased``
        with a 512-wide head and push it to the ``test-text-hf-upload`` repo.

    The model is then always re-downloaded from ``mpatel57/test-text-hf-upload``.
    NOTE(review): this is the same repo name ``test_text_model`` uses, so the two tests
    overwrite each other's upload — consider a distinct repo.
    """
    if register:
        AutoConfig.register("whole_custom_text_model", CustomTextEncoderOnlyConfig)
        AutoModel.register(CustomTextEncoderOnlyConfig, CustomTextEncoderOnly)
        CustomTextEncoderOnlyConfig.register_for_auto_class()
        CustomTextEncoderOnly.register_for_auto_class("AutoModel")

    if upload:
        encoder = CustomTextEncoderOnly(
            CustomTextEncoderOnlyConfig(
                model_name="google-bert/bert-base-uncased",
                pretrained=True,
                frozen=False,
                output_hidden_size=512,
                last_hidden_state=False,
                lora=None,
            )
        )
        encoder.push_to_hub("test-text-hf-upload")

    # force_download bypasses the local cache so the latest pushed revision is fetched.
    CustomTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)
|
|
|
def test_vision_model(register=False, upload=False):
    """
    Smoke-test a Hugging Face Hub round-trip of ``CLIPVisionEncoderOnly``.

    :param register: First register the config/model with ``AutoConfig``/``AutoModel`` and
        mark them for auto-class export.
    :param upload: Build the encoder from ``openai/clip-vit-base-patch32`` and push it to the
        ``test-vision-hf-upload`` repo (presumably under the authenticated user's namespace).

    The model is then always re-downloaded from ``mpatel57/test-vision-hf-upload``.
    """
    if register:
        AutoConfig.register("clip_custom_vision_model", CLIPVisionEncoderOnlyConfig)
        AutoModel.register(CLIPVisionEncoderOnlyConfig, CLIPVisionEncoderOnly)
        CLIPVisionEncoderOnlyConfig.register_for_auto_class()
        CLIPVisionEncoderOnly.register_for_auto_class("AutoModel")

    if upload:
        encoder = CLIPVisionEncoderOnly(
            CLIPVisionEncoderOnlyConfig(
                model_name="openai/clip-vit-base-patch32",
                pretrained=True,
                lora=None,
            )
        )
        encoder.push_to_hub("test-vision-hf-upload")

    # force_download bypasses the local cache so the latest pushed revision is fetched.
    CLIPVisionEncoderOnly.from_pretrained("mpatel57/test-vision-hf-upload", force_download=True)
|
|
|
|
|
if __name__ == "__main__":
    # Push the generic text encoder to the Hub (without Auto* registration), then reload it.
    test_custom_text_model(upload=True, register=False)
|
|
|
|
|
|