Safetensors
gLM2
custom_code
andrecornman commited on
Commit
57d7cf6
·
verified ·
1 Parent(s): 421e981

Update modeling_glm2.py

Browse files
Files changed (1) hide show
  1. modeling_glm2.py +99 -197
modeling_glm2.py CHANGED
@@ -1,12 +1,11 @@
1
  """PyTorch gLM2 model.
2
 
3
- Requires flash attention.
4
  Some modules adapted from:
5
  https://github.com/meta-llama/llama/blob/main/llama/model.py
6
  """
7
- import math
8
  import torch
9
- from einops import rearrange
10
  from typing import Optional, Tuple, Union
11
  from torch import nn
12
  from torch.nn import CrossEntropyLoss
@@ -16,30 +15,51 @@ from transformers.modeling_outputs import (
16
  )
17
  from transformers.modeling_utils import PreTrainedModel
18
  from transformers.utils import logging
 
19
 
20
- try:
21
- from flash_attn.ops.activations import swiglu
22
- from flash_attn.layers.rotary import apply_rotary_emb_func
23
- from flash_attn import (
24
- flash_attn_kvpacked_func,
25
- flash_attn_varlen_kvpacked_func,
26
- )
27
- from flash_attn.bert_padding import pad_input, unpad_input
28
- from flash_attn.ops.triton.layer_norm import RMSNorm
29
- except ImportError:
30
- raise ImportError(
31
- "gLM2 requires flash attention: `pip install flash-attn --no-build-isolation`")
32
 
33
- from .configuration_glm2 import gLM2Config
 
 
 
 
 
 
 
 
 
34
 
35
 
36
- logger = logging.get_logger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  class RotaryEmbedding(torch.nn.Module):
40
  """
41
  Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
42
- Changed to only support passing in q or k individually, so that we can use varlen rotary.
43
  """
44
 
45
  def __init__(
@@ -137,92 +157,52 @@ class RotaryEmbedding(torch.nn.Module):
137
 
138
  def forward(
139
  self,
140
- q: torch.Tensor,
141
- k: torch.Tensor,
142
- seqlen_offset: Union[int, torch.Tensor] = 0,
143
- cu_seqlens: Optional[torch.Tensor] = None,
144
  max_seqlen: Optional[int] = None,
145
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
146
  """
147
- q: (batch, seqlen, nheads, headdim). If cu_seqlens is not None,
148
- shape (total_seqlen, nheads, headdim).
149
- k: (batch, seqlen, nheads, headdim). If cu_seqlens is not None,
150
- shape (total_seqlen, nheads, headdim).
151
- seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
152
- Most commonly used in inference when we have KV cache.
153
- If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
154
- should pass in max_seqlen, which will update the cos / sin cache up to that length.
155
- Apply rotary embedding *inplace* to qkv and / or kv.
156
  """
157
- if cu_seqlens is not None:
158
- assert max_seqlen is not None
159
- seqlen = q.shape[1] if max_seqlen is None else max_seqlen
160
- if max_seqlen is not None:
161
  self._update_cos_sin_cache(
162
- max_seqlen, device=q.device, dtype=q.dtype)
163
- elif isinstance(seqlen_offset, int):
164
  self._update_cos_sin_cache(
165
- seqlen + seqlen_offset, device=q.device, dtype=q.dtype
166
- )
167
- q = apply_rotary_emb_func(
168
- q,
169
- self._cos_cached,
170
- self._sin_cached,
171
- interleaved=self.interleaved,
172
- inplace=True,
173
- seqlen_offsets=seqlen_offset,
174
- cu_seqlens=cu_seqlens,
175
- max_seqlen=max_seqlen,
176
  )
177
- if self.scale is None:
178
- k = apply_rotary_emb_func(
179
- k,
180
- self._cos_cached,
181
- self._sin_cached,
182
- interleaved=self.interleaved,
183
- inplace=True,
184
- seqlen_offsets=seqlen_offset,
185
- cu_seqlens=cu_seqlens,
186
- max_seqlen=max_seqlen,
187
- )
188
- else:
189
- k = apply_rotary_emb_func(
190
- k,
191
- self._cos_k_cached,
192
- self._sin_k_cached,
193
- interleaved=self.interleaved,
194
- inplace=True,
195
- seqlen_offsets=seqlen_offset,
196
- cu_seqlens=cu_seqlens,
197
- max_seqlen=max_seqlen,
198
- )
199
- return q, k
200
 
201
 
202
  # @torch.jit.script
203
- # def rmsnorm_func(hidden_states, weight, variance_epsilon):
204
- # """Apply the root mean square normalization."""
205
- # input_dtype = hidden_states.dtype
206
- # hidden_states = hidden_states.to(torch.float32)
207
- # variance = hidden_states.pow(2).mean(-1, keepdim=True)
208
- # hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
209
- # return (weight * hidden_states).to(input_dtype)
210
 
211
 
212
- # class RMSNorm(nn.Module):
213
- # """Root mean square normalization."""
214
 
215
- # def __init__(self, dim, eps=1e-6):
216
- # super().__init__()
217
- # self.weight = nn.Parameter(torch.ones(dim))
218
- # self.register_buffer(
219
- # "variance_epsilon",
220
- # torch.tensor(eps),
221
- # persistent=False,
222
- # )
223
 
224
- # def forward(self, hidden_states):
225
- # return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
226
 
227
 
228
  class Attention(nn.Module):
@@ -240,67 +220,33 @@ class Attention(nn.Module):
240
 
241
  self.rotary_emb = RotaryEmbedding(self.head_dim)
242
 
243
- def _forward_varlen(
244
- self,
245
- x: torch.Tensor,
246
- cu_seqlens: Optional[torch.Tensor] = None,
247
- max_seq_len: Optional[torch.Tensor] = None,
248
- ) -> torch.Tensor:
249
- total_seqlen, h_size = x.shape
250
- qkv = self.wqkv(x)
251
- q, k, v = torch.split(qkv, self.n_heads * self.head_dim, dim=-1)
252
-
253
- q = q.view(total_seqlen, self.n_heads, self.head_dim)
254
- k = k.view(total_seqlen, self.n_heads, self.head_dim)
255
- v = v.view(total_seqlen, self.n_heads, self.head_dim)
256
-
257
- q, k = self.rotary_emb(
258
- q, k, cu_seqlens=cu_seqlens, max_seqlen=max_seq_len)
259
-
260
- # (seqlen, 2, n_heads, head_dim)
261
- kv = torch.stack([k, v], 1)
262
-
263
- # (seqlen, n_heads, head_dim)
264
- output = flash_attn_varlen_kvpacked_func(
265
- q,
266
- kv,
267
- cu_seqlens_q=cu_seqlens,
268
- cu_seqlens_k=cu_seqlens,
269
- max_seqlen_q=max_seq_len,
270
- max_seqlen_k=max_seq_len,
271
- dropout_p=0.0,
272
- causal=False,
273
- )
274
- output = output.view(total_seqlen, h_size)
275
- return self.wo(output)
276
-
277
  def forward(
278
  self,
279
  x: torch.Tensor,
280
- cu_seqlens: Optional[torch.Tensor] = None,
281
- max_seq_len: Optional[torch.Tensor] = None,
282
  ) -> torch.Tensor:
283
- if cu_seqlens is not None:
284
- assert max_seq_len is not None
285
- return self._forward_varlen(x, cu_seqlens, max_seq_len)
286
-
287
  bsz, seqlen, h_size = x.shape
288
  qkv = self.wqkv(x)
289
- q, k, v = torch.split(qkv, self.n_heads * self.head_dim, dim=-1)
290
- q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
291
- k = k.view(bsz, seqlen, self.n_heads, self.head_dim)
292
- v = v.view(bsz, seqlen, self.n_heads, self.head_dim)
293
-
294
- q, k = self.rotary_emb(q, k)
295
- # (bs, seqlen, 2, n_heads, head_dim)
296
- kv = torch.stack([k, v], 2)
297
-
298
- output = flash_attn_kvpacked_func(
299
- q,
300
- kv,
301
- dropout_p=0.0,
302
- causal=False,
 
 
 
303
  )
 
 
304
  output = output.view(bsz, seqlen, h_size)
305
  return self.wo(output)
306
 
@@ -335,7 +281,7 @@ class FeedForward(nn.Module):
335
  self.w3 = nn.Linear(dim, hidden_dim, bias=False)
336
 
337
  def forward(self, x):
338
- return self.w2(swiglu(self.w1(x), self.w3(x)))
339
 
340
 
341
  class TransformerBlock(nn.Module):
@@ -357,12 +303,10 @@ class TransformerBlock(nn.Module):
357
  def forward(
358
  self,
359
  x: torch.Tensor,
360
- cu_seqlens: Optional[torch.Tensor] = None,
361
- max_seq_len: Optional[torch.Tensor] = None,
362
  ) -> torch.Tensor:
363
- r = self.attention(
364
- self.attention_norm(x), cu_seqlens, max_seq_len
365
- )
366
  h = x + r
367
  r = self.feed_forward(self.ffn_norm(h))
368
  out = h + r
@@ -376,19 +320,6 @@ class TransformerLayers(nn.Module):
376
  self.layers = torch.nn.ModuleList(
377
  [TransformerBlock(config=config) for _ in range(config.depth)]
378
  )
379
- self.apply(self._init_weights)
380
- # Apply special scaled init to the residual projections, per GPT-2 paper.
381
- # Weight w2 is output of FeedForward. Weight wo is output of Attention.
382
- for pn, p in self.named_parameters():
383
- if pn.endswith('w2.weight') or pn.endswith('wo.weight'):
384
- torch.nn.init.normal_(
385
- p, mean=0.0, std=0.02/math.sqrt(2 * self.config.depth))
386
-
387
- def _init_weights(self, module):
388
- if isinstance(module, nn.Linear):
389
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
390
- if module.bias is not None:
391
- torch.nn.init.zeros_(module.bias)
392
 
393
  def forward(
394
  self,
@@ -400,26 +331,12 @@ class TransformerLayers(nn.Module):
400
  raise ValueError(
401
  f"Input feature dim should be {self.config.dim}, but input has shape {x.shape}"
402
  )
403
- batch_size, seq_len = x.shape[:2]
404
- should_unpad = attention_mask is not None and not attention_mask.all()
405
- if should_unpad:
406
- x, indices, cu_seqlens, max_seq_len_in_batch = unpad_input(
407
- x, attention_mask
408
- )
409
- else:
410
- indices, cu_seqlens, max_seq_len_in_batch = None, None, None
411
  hiddens = []
412
  for layer in self.layers:
413
- x = layer(x, cu_seqlens, max_seq_len_in_batch)
414
  if return_all_hiddens:
415
  hiddens.append(x)
416
 
417
- if should_unpad:
418
- x = pad_input(x, indices, batch_size, seq_len)
419
- if return_all_hiddens:
420
- hiddens = [pad_input(h, indices, batch_size, seq_len)
421
- for h in hiddens]
422
-
423
  if return_all_hiddens:
424
  return x, hiddens
425
  return x
@@ -454,16 +371,9 @@ class gLM2Model(gLM2PreTrainedModel):
454
  self.config = config
455
 
456
  self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
457
- self._init_weights(self.tok_embeddings)
458
  self.encoder = TransformerLayers(config)
459
-
460
- def _init_weights(self, module):
461
- if isinstance(module, nn.Linear):
462
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
463
- if module.bias is not None:
464
- torch.nn.init.zeros_(module.bias)
465
- elif isinstance(module, nn.Embedding):
466
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
467
 
468
  def forward(
469
  self,
@@ -502,15 +412,7 @@ class gLM2ForMaskedLM(gLM2PreTrainedModel):
502
 
503
  self.glm2 = gLM2Model(config)
504
  self.lm_head = gLM2LMHead(config)
505
- self._init_weights(self.lm_head)
506
-
507
- def _init_weights(self, module):
508
- if isinstance(module, nn.Linear):
509
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
510
- if module.bias is not None:
511
- torch.nn.init.zeros_(module.bias)
512
- elif isinstance(module, nn.Embedding):
513
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
514
 
515
  def forward(
516
  self,
@@ -562,4 +464,4 @@ class gLM2LMHead(nn.Module):
562
  config.dim, config.vocab_size, bias=False)
563
 
564
  def forward(self, features):
565
- return self.proj_output(self.norm(features))
 
1
  """PyTorch gLM2 model.
2
 
 
3
  Some modules adapted from:
4
  https://github.com/meta-llama/llama/blob/main/llama/model.py
5
  """
6
+
7
  import torch
8
+ from einops import rearrange, repeat
9
  from typing import Optional, Tuple, Union
10
  from torch import nn
11
  from torch.nn import CrossEntropyLoss
 
15
  )
16
  from transformers.modeling_utils import PreTrainedModel
17
  from transformers.utils import logging
18
+ from .configuration_glm2 import gLM2Config
19
 
20
+ logger = logging.get_logger(__name__)
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+
23
+ def rotate_half(x, interleaved=False):
24
+ if not interleaved:
25
+ x1, x2 = x.chunk(2, dim=-1)
26
+ return torch.cat((-x2, x1), dim=-1)
27
+ else:
28
+ x1, x2 = x[..., ::2], x[..., 1::2]
29
+ return rearrange(
30
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
31
+ )
32
 
33
 
34
+ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
35
+ """
36
+ x: (batch_size, seqlen, nheads, headdim)
37
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
38
+ """
39
+ ro_dim = cos.shape[-1] * 2
40
+ assert ro_dim <= x.shape[-1]
41
+ seqlen = x.shape[1]
42
+ cos, sin = cos[:seqlen], sin[:seqlen]
43
+ cos = repeat(
44
+ cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
45
+ )
46
+ sin = repeat(
47
+ sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
48
+ )
49
+ return torch.cat(
50
+ [
51
+ x[..., :ro_dim] * cos +
52
+ rotate_half(x[..., :ro_dim], interleaved) * sin,
53
+ x[..., ro_dim:],
54
+ ],
55
+ dim=-1,
56
+ )
57
 
58
 
59
  class RotaryEmbedding(torch.nn.Module):
60
  """
61
  Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
62
+ Changed to use the torch version of apply_rotary_emb_func.
63
  """
64
 
65
  def __init__(
 
157
 
158
  def forward(
159
  self,
160
+ qkv: torch.Tensor,
 
 
 
161
  max_seqlen: Optional[int] = None,
162
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
163
  """
164
+ qkv: (batch, seqlen, 3, nheads, headdim)
 
 
 
 
 
 
 
 
165
  """
166
+ seqlen = qkv.shape[1]
167
+ if seqlen > self._seq_len_cached:
 
 
168
  self._update_cos_sin_cache(
169
+ seqlen, device=qkv.device, dtype=qkv.dtype)
170
+ elif max_seqlen is not None:
171
  self._update_cos_sin_cache(
172
+ max_seqlen, device=qkv.device, dtype=qkv.dtype)
173
+ q_rot = apply_rotary_emb_torch(
174
+ qkv[:, :, 0], self._cos_cached, self._sin_cached, self.interleaved
 
 
 
 
 
 
 
 
175
  )
176
+ k_rot = apply_rotary_emb_torch(
177
+ qkv[:, :, 1], self._cos_cached, self._sin_cached, self.interleaved
178
+ )
179
+ return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
 
182
  # @torch.jit.script
183
+ def rmsnorm_func(hidden_states, weight, variance_epsilon):
184
+ """Apply the root mean square normalization."""
185
+ input_dtype = hidden_states.dtype
186
+ hidden_states = hidden_states.to(torch.float32)
187
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
188
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
189
+ return (weight * hidden_states).to(input_dtype)
190
 
191
 
192
+ class RMSNorm(nn.Module):
193
+ """Root mean square normalization."""
194
 
195
+ def __init__(self, dim, eps=1e-6):
196
+ super().__init__()
197
+ self.weight = nn.Parameter(torch.ones(dim))
198
+ self.register_buffer(
199
+ "variance_epsilon",
200
+ torch.tensor(eps),
201
+ persistent=False,
202
+ )
203
 
204
+ def forward(self, hidden_states):
205
+ return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
206
 
207
 
208
  class Attention(nn.Module):
 
220
 
221
  self.rotary_emb = RotaryEmbedding(self.head_dim)
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  def forward(
224
  self,
225
  x: torch.Tensor,
226
+ attention_mask: Optional[torch.Tensor] = None,
 
227
  ) -> torch.Tensor:
 
 
 
 
228
  bsz, seqlen, h_size = x.shape
229
  qkv = self.wqkv(x)
230
+
231
+ qkv = qkv.view(bsz, seqlen, 3, self.n_heads, self.head_dim)
232
+ qkv = self.rotary_emb(qkv)
233
+
234
+ # (batch, nheads, 3, seqlen, headdim)
235
+ qkv = torch.transpose(qkv, 3, 1)
236
+ q = qkv[:, :, 0]
237
+ k = qkv[:, :, 1]
238
+ v = qkv[:, :, 2]
239
+ if attention_mask is not None:
240
+ attention_mask = attention_mask[:, None, None, :]
241
+ attention_mask = attention_mask.expand(
242
+ bsz, self.n_heads, seqlen, seqlen
243
+ ).bool()
244
+ # [B, heads, seq, D]
245
+ output = torch.nn.functional.scaled_dot_product_attention(
246
+ q, k, v, attn_mask=attention_mask
247
  )
248
+ output = output.permute(0, 2, 1, 3).contiguous()
249
+
250
  output = output.view(bsz, seqlen, h_size)
251
  return self.wo(output)
252
 
 
281
  self.w3 = nn.Linear(dim, hidden_dim, bias=False)
282
 
283
  def forward(self, x):
284
+ return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
285
 
286
 
287
  class TransformerBlock(nn.Module):
 
303
  def forward(
304
  self,
305
  x: torch.Tensor,
306
+ attention_mask: Optional[torch.Tensor] = None,
 
307
  ) -> torch.Tensor:
308
+ r = self.attention(self.attention_norm(
309
+ x), attention_mask=attention_mask)
 
310
  h = x + r
311
  r = self.feed_forward(self.ffn_norm(h))
312
  out = h + r
 
320
  self.layers = torch.nn.ModuleList(
321
  [TransformerBlock(config=config) for _ in range(config.depth)]
322
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
  def forward(
325
  self,
 
331
  raise ValueError(
332
  f"Input feature dim should be {self.config.dim}, but input has shape {x.shape}"
333
  )
 
 
 
 
 
 
 
 
334
  hiddens = []
335
  for layer in self.layers:
336
+ x = layer(x, attention_mask=attention_mask)
337
  if return_all_hiddens:
338
  hiddens.append(x)
339
 
 
 
 
 
 
 
340
  if return_all_hiddens:
341
  return x, hiddens
342
  return x
 
371
  self.config = config
372
 
373
  self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
 
374
  self.encoder = TransformerLayers(config)
375
+ # Initialize weights and apply final processing
376
+ self.post_init()
 
 
 
 
 
 
377
 
378
  def forward(
379
  self,
 
412
 
413
  self.glm2 = gLM2Model(config)
414
  self.lm_head = gLM2LMHead(config)
415
+ self.init_weights()
 
 
 
 
 
 
 
 
416
 
417
  def forward(
418
  self,
 
464
  config.dim, config.vocab_size, bias=False)
465
 
466
  def forward(self, features):
467
+ return self.proj_output(self.norm(features))