Spaces:
Running
on
Zero
Running
on
Zero
SunderAli17
committed on
Create hf_model.py
Browse files- eva_clip/hf_model.py +247 -0
eva_clip/hf_model.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" huggingface model adapter
|
2 |
+
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import re
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch import TensorType
|
11 |
+
try:
|
12 |
+
import transformers
|
13 |
+
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
|
14 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
15 |
+
BaseModelOutputWithPoolingAndCrossAttentions
|
16 |
+
except ImportError as e:
|
17 |
+
transformers = None
|
18 |
+
|
19 |
+
|
20 |
+
class BaseModelOutput:
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class PretrainedConfig:
|
25 |
+
pass
|
26 |
+
|
27 |
+
from .hf_configs import arch_dict
|
28 |
+
|
29 |
+
# utils
|
30 |
+
def _camel2snake(s):
    """Convert a CamelCase identifier to snake_case (``MeanPooler`` -> ``mean_pooler``)."""
    # Zero-width match in front of every capital letter except one at the very
    # start of the string; an underscore is inserted at each match.
    with_underscores = re.sub(r'(?<!^)(?=[A-Z])', '_', s)
    return with_underscores.lower()
|
32 |
+
|
33 |
+
# TODO: ?last - for gpt-like models
# Registry of pooler classes, keyed by the snake_case form of the class name.
_POOLERS = {}


def register_pooler(cls):
    """Class decorator: store ``cls`` in ``_POOLERS`` under its snake_case name.

    The class itself is returned unchanged so the decorated name still refers to it.
    """
    registry_key = _camel2snake(cls.__name__)
    _POOLERS[registry_key] = cls
    return cls
|
40 |
+
|
41 |
+
|
42 |
+
@register_pooler
class MeanPooler(nn.Module):
    """Pools by averaging the last hidden states, weighted by the attention mask."""

    def forward(self, x:BaseModelOutput, attention_mask:TensorType):
        """Return the mask-weighted mean of ``x.last_hidden_state`` along dim 1."""
        hidden_states = x.last_hidden_state
        # Multiplying by the mask zeroes every position whose mask value is 0,
        # so those positions do not contribute to the sum.
        weights = attention_mask.unsqueeze(-1)
        summed = (hidden_states * weights).sum(dim=1)
        # Denominator: per-sequence sum of the mask values.
        mask_totals = attention_mask.sum(-1, keepdim=True)
        return summed / mask_totals
|
48 |
+
|
49 |
+
@register_pooler
class MaxPooler(nn.Module):
    """Max pooling over the positions selected by the attention mask."""

    def forward(self, x:BaseModelOutput, attention_mask:TensorType):
        """Return the element-wise maximum of ``x.last_hidden_state`` along dim 1.

        Positions whose ``attention_mask`` value is 0 are excluded from the maximum.
        """
        # ``masked_fill`` needs a boolean mask that is True where values are to be
        # overwritten. The attention mask is non-zero at the positions to KEEP, so
        # it has to be inverted (and converted to bool) before use.
        ignored = (attention_mask == 0).unsqueeze(-1)
        masked_output = x.last_hidden_state.masked_fill(ignored, -torch.inf)
        return masked_output.max(1).values
|
55 |
+
|
56 |
+
@register_pooler
class ClsPooler(nn.Module):
    """Pools by taking either the model's ``pooler_output`` or the hidden state at position 0."""

    def __init__(self, use_pooler_output=True):
        """Store whether the model's own ``pooler_output`` should be preferred when present."""
        super().__init__()
        self.use_pooler_output = use_pooler_output
        self.cls_token_position = 0

    def forward(self, x:BaseModelOutput, attention_mask:TensorType):
        """Return ``x.pooler_output`` when enabled and available, else the CLS-position hidden state.

        ``attention_mask`` is accepted for signature parity with the other poolers
        and is not used here.
        """
        if self.use_pooler_output:
            # The output classes are only looked up on this path, mirroring the
            # short-circuit evaluation of a chained ``and``.
            if isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)):
                pooled = x.pooler_output
                if pooled is not None:
                    return pooled

        return x.last_hidden_state[:, self.cls_token_position, :]
|
73 |
+
|
74 |
+
class HFTextEncoder(nn.Module):
    """HuggingFace model adapter

    Bundles a ``transformers`` model with a pooler, an output projection and a
    tokenizer so it can be used as the text tower of a CLIP model.
    """

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            tokenizer_name: str = None,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True,
            masked_language_modeling: bool = False):
        """Build the transformer, pooler, projection and tokenizer.

        Args:
            model_name_or_path: identifier handed to ``AutoConfig.from_pretrained`` and
                ``from_pretrained`` of the model class when ``config`` is None.
            output_dim: target width of the projection applied in ``forward``.
            tokenizer_name: identifier handed to ``AutoTokenizer.from_pretrained``;
                when None, ``model_name_or_path`` is used instead.
            config: ready-made config; when given, the model is created with
                ``from_config(config)`` and ``pretrained`` is not consulted.
            pooler_type: key into ``_POOLERS``; when None, the default pooler listed in
                ``arch_dict`` for the config's ``model_type`` is used.
            proj: ``'linear'``, ``'mlp'`` or None.
            pretrained: when ``config`` is None, choose ``from_pretrained`` (True) or
                ``from_config`` (False).
            masked_language_modeling: use ``AutoModelForMaskedLM`` instead of ``AutoModel``.

        Raises:
            RuntimeError: if the ``transformers`` package could not be imported.
        """
        super().__init__()

        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")

        auto_cls = AutoModelForMaskedLM if masked_language_modeling else AutoModel
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            if pretrained:
                create_func, model_args = auto_cls.from_pretrained, model_name_or_path
            else:
                create_func, model_args = auto_cls.from_config, self.config
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if getattr(self.config, "is_encoder_decoder", False):
                # Only the encoder half of an encoder-decoder model is kept.
                self.transformer = create_func(model_args).encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = auto_cls.from_config(config)

        if pooler_type is None:  # get default arch pooler
            pooler_type = arch_dict[self.config.model_type]["pooler"]
        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )
        # NOTE(review): any other combination (e.g. proj=None with d_model != output_dim,
        # or an unrecognised proj string) leaves ``self.proj`` unset, so ``forward``
        # would fail on attribute lookup -- confirm whether that is intended.

        # self.itm_proj = nn.Linear(d_model, 2, bias=False)
        # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)

        # ``tokenizer_name`` defaults to None, which AutoTokenizer cannot resolve;
        # fall back to the model identifier in that case.
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name if tokenizer_name is not None else model_name_or_path)

    # def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
    #     image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
    #     attn_mask = (x != self.config.pad_token_id).long()
    #     out = self.transformer(
    #         input_ids=x,
    #         attention_mask=attn_mask,
    #         encoder_hidden_states = image_embeds,
    #         encoder_attention_mask = image_atts,
    #         )
    #     pooled_out = self.pooler(out, attn_mask)

    #     return self.itm_proj(pooled_out)

    def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
        """Corrupt ``input_ids`` in place for masked language modelling.

        Args:
            input_ids: token ids; MODIFIED IN PLACE and also returned.
            vocab_size: exclusive upper bound for the randomly drawn replacement ids.
            device: device the sampled masks and random ids are moved to.
            targets: optional label tensor; positions not selected for masking are
                set to -100 in place.
            masked_indices: optional boolean selection; when None it is sampled
                from ``probability_matrix``.
            probability_matrix: per-position Bernoulli probabilities, used only when
                ``masked_indices`` is None.

        Returns:
            ``(input_ids, targets)`` when ``targets`` is given, otherwise ``input_ids``.

        Raises:
            ValueError: if both ``masked_indices`` and ``probability_matrix`` are None.
        """
        if masked_indices is None:
            if probability_matrix is None:
                raise ValueError("Either `masked_indices` or `probability_matrix` must be provided")
            masked_indices = torch.bernoulli(probability_matrix).bool()
        # Masks are sampled wherever their source tensor lives and then moved to
        # ``device`` (same pattern as ``random_words`` below) so they can be
        # combined with / indexed by tensors derived from ``input_ids``.
        masked_indices = masked_indices.to(device)

        # Never mask padding or CLS tokens.
        masked_indices[input_ids == self.tokenizer.pad_token_id] = False
        masked_indices[input_ids == self.tokenizer.cls_token_id] = False

        if targets is not None:
            targets[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool().to(device) & masked_indices
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% that were not replaced by the mask token)
        indices_random = (
            torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool().to(device)
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
        input_ids[indices_random] = random_words[indices_random]
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged

        if targets is not None:
            return input_ids, targets
        else:
            return input_ids

    def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
        """Return the masked-language-modelling loss reported by the transformer.

        Args:
            input_ids: token ids; left untouched (masking is applied to a copy).
            image_embeds: passed to the transformer as ``encoder_hidden_states``.
            mlm_probability: per-token probability of being selected for masking.
        """
        labels = input_ids.clone()
        # The attention mask is derived from the unmasked ids.
        attn_mask = (input_ids != self.config.pad_token_id).long()
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(input_ids.device)
        vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
        probability_matrix = torch.full(labels.shape, mlm_probability)
        # ``mask`` edits its first argument in place; hand it a copy so the
        # caller's batch is not overwritten with [MASK]/random tokens.
        input_ids, labels = self.mask(input_ids.clone(), vocab_size, input_ids.device, targets=labels,
                                      probability_matrix=probability_matrix)
        mlm_output = self.transformer(input_ids,
                                      attention_mask=attn_mask,
                                      encoder_hidden_states=image_embeds,
                                      encoder_attention_mask=image_atts,
                                      return_dict=True,
                                      labels=labels,
                                      )
        return mlm_output.loss
        # mlm_output = self.transformer(input_ids,
        #                 attention_mask = attn_mask,
        #                 encoder_hidden_states = image_embeds,
        #                 encoder_attention_mask = image_atts,
        #                 return_dict = True,
        #             ).last_hidden_state
        # logits = self.mlm_proj(mlm_output)

        # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
        # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
        # labels = labels[:, 1:].contiguous().view(-1)

        # mlm_loss = F.cross_entropy(
        #     logits,
        #     labels,
        #     # label_smoothing=0.1,
        # )
        # return mlm_loss

    def forward(self, x:TensorType) -> TensorType:
        """Encode token ids ``x``: transformer -> pooler -> projection."""
        # Positions equal to the config's pad token id get attention-mask value 0.
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)

        return self.proj(pooled_out)

    def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
        """Freeze transformer parameters.

        With ``unlocked_layers`` falsy, every transformer parameter is visited.
        Otherwise the token embeddings plus all but the last ``unlocked_layers``
        entries of the encoder's layer list are visited. A visited parameter gets
        ``requires_grad = False``, except parameters whose dotted name has a
        ``LayerNorm`` component, which get ``requires_grad = not freeze_layer_norm``.
        """
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Turn gradient checkpointing on the wrapped transformer on or off."""
        if enable:
            self.transformer.gradient_checkpointing_enable()
        else:
            self.transformer.gradient_checkpointing_disable()

    def get_num_layers(self):
        """Return the length of the encoder's layer list (attribute name taken from ``arch_dict``)."""
        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        return len(layer_list)

    def init_parameters(self):
        """No-op; no extra parameter initialisation is performed here."""
        pass
|