sync code
- configuration_aria.py +7 -2
- modeling_aria.py +8 -0
- moe_lm.py +4 -2
configuration_aria.py
CHANGED
@@ -66,14 +66,19 @@ class AriaConfig(PretrainedConfig):
         },
         ignore_index=-100,
         image_token_index=32000,
+        tie_word_embeddings=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.ignore_index = ignore_index
         self.image_token_index = image_token_index
-
+        self.tie_word_embeddings = tie_word_embeddings
         attn_implementation = kwargs.pop("attn_implementation", None)
-
+
+        # Set the default attention implementation to flash_attention_2 if not specified
+        self._attn_implementation = (
+            "flash_attention_2" if attn_implementation is None else attn_implementation
+        )

         # Convert the keys and values of projector_patch_to_query_dict to integers
         # This ensures consistency even if they were provided as strings
modeling_aria.py
CHANGED
@@ -165,6 +165,14 @@ class AriaForConditionalGeneration(AriaPretrainedModel, GenerationMixin):
         """Set the input embeddings for the language model."""
         self.language_model.set_input_embeddings(value)

+    def get_output_embeddings(self):
+        """Retrieve the output embeddings from the language model."""
+        return self.language_model.get_output_embeddings()
+
+    def set_output_embeddings(self, value):
+        """Set the output embeddings for the language model."""
+        self.language_model.set_output_embeddings(value)
+
     def set_moe_z_loss_coeff(self, value):
         """
         Set the z-loss coefficient for Mixture of Experts (MoE) models.
moe_lm.py
CHANGED
@@ -255,7 +255,8 @@ class TopKRouter(nn.Module):
             - top_indices: Indices of top-k experts for each token.
             - tokens_per_expert: Number of tokens assigned to each expert.
         """
-
+        if self.training:
+            logits = self.apply_z_loss(logits)

         top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
         scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
@@ -267,7 +268,8 @@ class TopKRouter(nn.Module):
             max=self.config.moe_num_experts - 1,
         )

-
+        if self.training:
+            scores = self.apply_aux_loss(logits, tokens_per_expert, scores)
         return scores, top_indices, tokens_per_expert

     def forward(
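Note: both hunks gate the router's auxiliary terms behind self.training, so z-loss and the load-balancing auxiliary loss only affect training while inference routing is left untouched. A minimal standalone sketch of that routing path follows, with hypothetical stand-ins for apply_z_loss / apply_aux_loss (their exact implementations are not shown in this diff):

import torch

def z_loss_term(logits):
    # Generic router z-loss (mean of squared logsumexp); a stand-in for
    # apply_z_loss, whose exact form in moe_lm.py is not part of this hunk.
    return torch.logsumexp(logits, dim=-1).pow(2).mean()

def topk_route(logits, moe_topk, num_experts, training=False):
    # Sketch of the routing path after this change: auxiliary terms are
    # computed only when training; inference routing is unchanged.
    if training:
        z_loss = z_loss_term(logits)  # would normally be scaled and recorded

    top_logits, top_indices = torch.topk(logits, k=moe_topk, dim=1)
    scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)

    # Count how many tokens were routed to each expert (one common way to do it).
    tokens_per_expert = torch.histc(
        top_indices.flatten().float(), bins=num_experts, min=0, max=num_experts - 1
    )

    if training:
        # A load-balancing auxiliary loss would be folded into scores here,
        # mirroring apply_aux_loss(logits, tokens_per_expert, scores).
        pass
    return scores, top_indices, tokens_per_expert

# Example: route 5 tokens across 4 experts, picking the top 2 per token.
scores, idx, counts = topk_route(torch.randn(5, 4), moe_topk=2, num_experts=4)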