Crystalcareai
committed on
Update modeling_gemmoe.py
modeling_gemmoe.py CHANGED (+76, -80)
@@ -65,9 +65,82 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "GemmoeConfig"
 
-
-
-
+def load_balancing_loss_func(
+    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
+) -> float:
+    r"""
+    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+            shape [batch_size X sequence_length, num_experts].
+        attention_mask (`torch.Tensor`, None):
+            The attention_mask used in forward function
+            shape [batch_size X sequence_length] if not None.
+        num_experts (`int`, *optional*):
+            Number of experts
+
+    Returns:
+        The auxiliary loss.
+    """
+    if gate_logits is None or not isinstance(gate_logits, tuple):
+        return 0
+
+    if isinstance(gate_logits, tuple):
+        compute_device = gate_logits[0].device
+        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+    if attention_mask is None:
+        # Compute the percentage of tokens routed to each experts
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .reshape(-1, top_k, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the percentage of tokens routed to each experts
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+            expert_attention_mask, dim=0
+        )
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+            router_per_expert_attention_mask, dim=0
+        )
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss * num_experts
+
+
 
 def approx_gelu(x):
     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
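The module-level `load_balancing_loss_func` added above expects the per-layer router logits as a tuple and returns a scalar auxiliary loss scaled by `num_experts`. A minimal usage sketch follows; the import path and the `aux_loss_coef` weighting are illustrative assumptions, not part of this file:

import torch
# Assumes modeling_gemmoe.py is importable as a module on the Python path.
from modeling_gemmoe import load_balancing_loss_func

num_layers, batch_size, seq_len, num_experts = 2, 4, 16, 8

# One [batch_size * seq_len, num_experts] logit tensor per decoder layer,
# as produced by each layer's router/gate.
gate_logits = tuple(torch.randn(batch_size * seq_len, num_experts) for _ in range(num_layers))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)  # 1 = real token, 0 = padding

aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=2, attention_mask=attention_mask)

# Hypothetical weighting into the training objective; the coefficient name is illustrative.
aux_loss_coef = 0.01
# total_loss = lm_loss + aux_loss_coef * aux_loss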
@@ -164,76 +237,6 @@ class GemmoeMLP(nn.Module):
     def forward(self, x):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
-def load_balancing_loss_func(
-    self,
-    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
-) -> float:
-    r"""
-    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
-
-    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
-    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
-    experts is too unbalanced.
-
-    Args:
-        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
-            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
-            shape [batch_size X sequence_length, num_experts].
-        attention_mask (`torch.Tensor`, None):
-            The attention_mask used in forward function
-            shape [batch_size X sequence_length] if not None.
-        num_experts (`int`, *optional*):
-            Number of experts
-
-    Returns:
-        The auxiliary loss.
-    """
-    if gate_logits is None or not isinstance(gate_logits, tuple):
-        return 0
-
-    if isinstance(gate_logits, tuple):
-        compute_device = gate_logits[0].device
-        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
-
-    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
-    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
-    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
-
-    if attention_mask is None:
-        # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
-        # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.mean(routing_weights, dim=0)
-    else:
-        batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
-
-        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
-        expert_attention_mask = (
-            attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, 2, num_experts))
-            .reshape(-1, 2, num_experts)
-            .to(compute_device)
-        )
-        # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
-        )
-
-        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
-        router_per_expert_attention_mask = (
-            attention_mask[None, :, :, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
-            .reshape(-1, num_experts)
-            .to(compute_device)
-        )
-        # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )
-
-    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
-    return overall_loss * num_experts
 
 def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
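The masked branch of this loss relies on some non-obvious shape bookkeeping: `one_hot` over the top-k expert indices yields a `[num_tokens, top_k, num_experts]` mask, so the per-token padding mask has to be broadcast to those same trailing dimensions before the weighted averages are taken (the `expand`/`reshape` chain does this over `[num_hidden_layers, batch_size, sequence_length, ...]`). A small standalone sketch of just that step, with arbitrarily chosen sizes for illustration:

import torch

num_tokens, top_k, num_experts = 6, 2, 4  # illustrative sizes

routing_weights = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)            # [num_tokens, top_k]
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)    # [num_tokens, top_k, num_experts]

# A per-token keep/pad indicator broadcast to the same shape as expert_mask.
padding = torch.ones(num_tokens, 1, 1).expand(num_tokens, top_k, num_experts)

# Fraction of non-padded tokens routed to each (top-k slot, expert) pair.
tokens_per_expert = (expert_mask.float() * padding).sum(dim=0) / padding.sum(dim=0)
print(tokens_per_expert.shape)  # torch.Size([2, 4])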
@@ -1153,13 +1156,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def parallelize(self, device_map=None):
-        self.model = GemmoeDistributedDataParallel(
-            self.model,
-            device_ids=[torch.cuda.current_device()],
-            output_device=torch.cuda.current_device(),
-        )
-
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
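The removed `parallelize` helper wrapped `self.model` in a `GemmoeDistributedDataParallel` bound to the current CUDA device; that class is not shown in this diff. A minimal sketch of the equivalent wrapping with PyTorch's built-in `torch.nn.parallel.DistributedDataParallel`, assuming a process group has already been initialized (for example via `torchrun`):

import torch
from torch.nn.parallel import DistributedDataParallel

def wrap_for_ddp(model: torch.nn.Module) -> DistributedDataParallel:
    # Bind the module to the current CUDA device, mirroring the removed helper,
    # so gradients are synchronized across ranks by DDP.
    device = torch.cuda.current_device()
    model.to(device)
    return DistributedDataParallel(
        model,
        device_ids=[device],
        output_device=device,
    )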