AniMemory-alpha-dpo / text_encoder_2 /animemory_altclip.py
devancao
model release
5d125d1
raw
history blame
4.44 kB
# Copyright 2024 AniMemory Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from safetensors.torch import load_file
from transformers import CLIPTextConfig, CLIPTextModelWithProjection
class AniMemoryAltCLip(torch.nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.model_hf = CLIPTextModelWithProjection(config)
self.linear_proj = torch.nn.Linear(in_features=1280, out_features=1280)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
subfolder="",
linear_proj_name="weights.safetensors",
torch_dtype=torch.float16,
):
cls.dtype = torch_dtype
config = CLIPTextModelWithProjection.config_class.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder
)
model = cls(config=config)
model.model_hf = CLIPTextModelWithProjection.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder
)
linear_proj_state = load_file(
os.path.join(pretrained_model_name_or_path, subfolder, linear_proj_name)
)
model.linear_proj.load_state_dict(linear_proj_state)
return model
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs
)
super(AniMemoryAltCLip, self).to(*args, **kwargs)
self.dtype = dtype if dtype is not None else self.dtype
self.device = device if device is not None else self.device
return self
def expand_mask(self, mask=None, dtype="", tgt_len=None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def make_attn_mask(self, attn_mask):
seq_len = attn_mask.shape[1]
query = attn_mask.unsqueeze(1).float()
attn_mask = (
query.repeat([1, seq_len, 1]).unsqueeze(1).repeat([1, self.num_head, 1, 1])
)
attn_mask = attn_mask.view([-1, seq_len, seq_len])
return attn_mask
def gradient_checkpointing_enable(
self,
):
self.model_hf.gradient_checkpointing_enable()
def forward(self, text, attention_mask):
hidden_states = self.model_hf.text_model.embeddings(
input_ids=text, position_ids=None
)
if attention_mask is None:
print("Warning: attention_mask is None in altclip!")
new_attn_mask = (
self.expand_mask(attention_mask, hidden_states.dtype)
if attention_mask is not None
else None
)
encoder_outputs = self.model_hf.text_model.encoder(
inputs_embeds=hidden_states,
attention_mask=new_attn_mask,
causal_attention_mask=None,
output_attentions=False,
output_hidden_states=True,
return_dict=True,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.model_hf.text_model.final_layer_norm(last_hidden_state)
last_hidden_state = (
last_hidden_state[torch.arange(last_hidden_state.shape[0]), 0]
@ self.model_hf.text_projection.weight
)
pooled_output = self.linear_proj(last_hidden_state)
extra_features = encoder_outputs.hidden_states[-2]
extra_features = self.model_hf.text_model.final_layer_norm(extra_features)
return extra_features, pooled_output