Crystalcareai committed: Update modeling_gemmoe.py
Files changed: modeling_gemmoe.py (+45 -25)

modeling_gemmoe.py  CHANGED
@@ -55,6 +55,7 @@ if is_flash_attn_2_available():

# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
+
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx
@@ -166,42 +167,52 @@ class GemmoeRMSNorm(nn.Module):
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
+        # Ensure the entire normalization is done in float32
+        x_float = x.float()  # upcast to float32
+        mean = x_float.pow(2).mean(-1, keepdim=True)
+        normed_x = x_float * torch.rsqrt(mean + self.eps)
+        return normed_x

    def forward(self, x):
+        normed_x = self._norm(x)
+        # Downcast the result to the original dtype at the end
+        normed_x = normed_x.type_as(x)
+        return normed_x * (self.weight + 1)

ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)

class GemmoeRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        freq_exponents = (2.0 / self.dim) * (
+            torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
+        )
+        timescale = self.base ** freq_exponents
+        positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.int64).float()
+        radians_new = positions[..., None] / timescale[None, None, :]
+        radians_new = radians_new.squeeze(0)
+        emb = torch.cat((radians_new, radians_new), dim=-1)
+        cos = emb.cos().to(device=device, non_blocking=True)
+        sin = emb.sin().to(device=device, non_blocking=True)
+        self.register_buffer("cos_cached", cos, persistent=False)
+        self.register_buffer("sin_cached", sin, persistent=False)
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        if seq_len is None:
+            seq_len = x.size(2)
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+        return (
+            self.cos_cached[:seq_len],
+            self.sin_cached[:seq_len],
+        )

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
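
For readers following the RMSNorm change above, here is a minimal self-contained sketch of the same pattern. The class name is illustrative, and `eps`/`dim` are assumed to be set in `GemmoeRMSNorm.__init__`, which is outside this hunk: the squared-mean reduction runs in float32, the result is cast back to the input dtype, and the zero-initialized weight scales the output as `weight + 1`.

import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """Minimal sketch of the float32 RMSNorm pattern above (the name is illustrative)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Zero-initialized weight; forward() scales by (weight + 1), as in the diff.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # Run the whole reduction in float32 so bfloat16/float16 rounding cannot bias the mean.
        x_float = x.float()
        mean = x_float.pow(2).mean(-1, keepdim=True)
        return x_float * torch.rsqrt(mean + self.eps)

    def forward(self, x):
        # Normalize in float32, downcast to the input dtype only at the end.
        normed_x = self._norm(x).type_as(x)
        return normed_x * (self.weight + 1)

# Quick check: the output keeps the input dtype and has roughly unit RMS per position.
x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
norm = RMSNormSketch(dim=8).to(torch.bfloat16)
out = norm(x)
print(out.dtype)                            # torch.bfloat16
print(out.float().pow(2).mean(-1).sqrt())   # values close to 1.0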
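
The rotary-embedding rewrite above only builds and returns the cos/sin tables; how they are consumed lies outside the hunk. The sketch below is an illustrative reconstruction under that assumption: a simplified cache builder in the same style, plus a hypothetical `apply_rope` helper that combines the tables with `rotate_half` the way Llama-style attention layers usually do (Gemmoe's own `apply_rotary_pos_emb` is not shown in this diff).

import torch
import torch.nn as nn

class RotaryEmbeddingSketch(nn.Module):
    """Illustrative re-implementation of the cos/sin caching pattern above (names are not Gemmoe's)."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device)

    def _set_cos_sin_cache(self, seq_len, device):
        self.max_seq_len_cached = seq_len
        # timescale[i] = base ** (2i / dim): the per-dimension RoPE wavelengths
        freq_exponents = (2.0 / self.dim) * torch.arange(self.dim // 2).float()
        timescale = self.base ** freq_exponents
        positions = torch.arange(seq_len).float()
        radians = positions[:, None] / timescale[None, :]    # (seq_len, dim // 2)
        emb = torch.cat((radians, radians), dim=-1)          # (seq_len, dim)
        self.register_buffer("cos_cached", emb.cos().to(device=device), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(device=device), persistent=False)

    def forward(self, x, seq_len=None):
        # x: (batch, num_heads, seq_len, head_dim), matching the x.size(2) convention above
        if seq_len is None:
            seq_len = x.size(2)
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]

def rotate_half(x):
    # Llama-style helper: split the last dim in half and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # Hypothetical helper: broadcast the (seq_len, head_dim) tables over (batch, heads, seq_len, head_dim).
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

# Usage: rotate query/key states before attention.
q = torch.randn(1, 8, 16, 64)    # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 16, 64)
rope = RotaryEmbeddingSketch(dim=64)
cos, sin = rope(q)               # seq_len inferred from q.size(2)
q_rot, k_rot = apply_rope(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 8, 16, 64]) twice
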
@@ -1034,6 +1045,15 @@ class GemmoeModel(GemmoePreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

+        # Scale embeddings
+        # Fix for precision issue when casting to bfloat16
+        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
+        if inputs_embeds.dtype == torch.bfloat16:
+            # Do the scaling in float32 to avoid precision loss, then cast back
+            hidden_states = (inputs_embeds.float() * hidden_size_sqrt).to(torch.bfloat16)
+        else:
+            hidden_states = inputs_embeds * hidden_size_sqrt
+
        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
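
To see why the embedding-scaling hunk above keeps the `sqrt(hidden_size)` factor out of bfloat16, here is a small numeric check; `hidden_size = 3072` is only an illustrative value, not read from the Gemmoe config.

import math
import torch

hidden_size = 3072                                   # illustrative value only
scale = math.sqrt(hidden_size)                       # 55.42562584...

scale_bf16 = torch.tensor(scale, dtype=torch.bfloat16)
scale_fp32 = torch.tensor(scale, dtype=torch.float32)
print(scale_bf16.item())                             # 55.5  (bfloat16 keeps only 8 significand bits)
print(scale_fp32.item())                             # ~55.425625

# Relative error picked up by every scaled embedding value if the factor is rounded to bfloat16:
rel_err = abs(scale_bf16.item() - scale) / scale
print(f"{rel_err:.1e}")                              # ~1.3e-03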