Oscar Wang committed (verified)
Commit 31fe5ed · 1 Parent(s): ad5d4bc

Create modelling_llamagloo.py

Files changed (1)
  1. modelling_llamagloo.py (+412, -0)
modelling_llamagloo.py ADDED
from transformers import PreTrainedModel, PretrainedConfig
import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

# -------------------- Configuration --------------------
class LlamaGlooConfig(PretrainedConfig):
    model_type = "llamagloo"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=2560,
        intermediate_size=10240,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=None,
        rope_theta=10000.0,
        use_rms_norm=True,
        rms_norm_eps=1e-6,
        use_gqa=False,
        ffn_type="llama",
        initializer_range=0.02,
        use_cache=True,
        tie_word_embeddings=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
        self.rope_theta = rope_theta
        self.use_rms_norm = use_rms_norm
        self.rms_norm_eps = rms_norm_eps
        self.use_gqa = use_gqa
        self.ffn_type = ffn_type
        self.initializer_range = initializer_range
        # Defined explicitly because the model's forward reads self.config.use_cache.
        self.use_cache = use_cache
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

# -------------------- Rotary Position Embeddings --------------------
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos/sin have shape (seq_len, head_dim) and position_ids has shape (q_len,),
    # so index first, then add batch and head dimensions to broadcast against
    # q/k of shape (bsz, num_heads, q_len, head_dim).
    cos = cos[position_ids].unsqueeze(0).unsqueeze(0)
    sin = sin[position_ids].unsqueeze(0).unsqueeze(0)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class LlamaGlooRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        # Register as a non-persistent buffer so it moves with the module across devices.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[-2]
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        # Cast to the activation dtype so fp16/bf16 inputs are not upcast downstream.
        cos = emb.cos().to(dtype=x.dtype)
        sin = emb.sin().to(dtype=x.dtype)
        return cos, sin

# -------------------- RMS Normalization --------------------
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

# -------------------- Attention Mechanism --------------------
class LlamaGlooAttention(nn.Module):
    def __init__(self, config: LlamaGlooConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.rotary_emb = LlamaGlooRotaryEmbedding(self.head_dim, base=self.rope_theta)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int):
        # Queries use num_heads while keys/values use num_key_value_heads, so the
        # caller passes the head count explicitly.
        return tensor.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)

    def _unshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)

    def forward(self, hidden_states, attention_mask=None, past_key_value=None, output_attentions=False, use_cache=True):
        bsz, seq_len, _ = hidden_states.size()
        q = self._shape(self.q_proj(hidden_states), seq_len, bsz, self.num_heads)
        k = self._shape(self.k_proj(hidden_states), seq_len, bsz, self.num_key_value_heads)
        v = self._shape(self.v_proj(hidden_states), seq_len, bsz, self.num_key_value_heads)

        # Rotary embeddings are applied to the new tokens only, at positions offset by
        # the length of any cached keys/values.
        past_len = past_key_value[0].shape[-2] if past_key_value is not None else 0
        kv_seq_len = past_len + seq_len
        cos, sin = self.rotary_emb(v, seq_len=kv_seq_len)
        position_ids = torch.arange(past_len, kv_seq_len, device=hidden_states.device)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        if past_key_value is not None:
            # Cached tensors have shape (bsz, num_kv_heads, past_len, head_dim);
            # concatenate along the sequence dimension.
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        past_key_value = (k, v) if use_cache else None

        if self.num_key_value_groups > 1:
            # Grouped-query attention: expand keys/values to match the query heads.
            k = k.repeat_interleave(self.num_key_value_groups, dim=1)
            v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1).type_as(v)
        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)

        if use_cache:
            outputs = outputs + (past_key_value,)

        return outputs

# -------------------- Feedforward Network --------------------
class LlamaGlooMLP(nn.Module):
    def __init__(self, config: LlamaGlooConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.ffn_type = config.ffn_type

    def forward(self, x):
        if self.ffn_type == "llama":
            gate = torch.nn.functional.silu(self.gate_proj(x))
            up = self.up_proj(x)
            return self.down_proj(gate * up)
        elif self.ffn_type == "glu":
            return self.down_proj(self.gate_proj(x) * self.up_proj(x))  # Example GLU
        else:
            raise ValueError(f"Unknown ffn_type: {self.ffn_type}")

# -------------------- Transformer Layer --------------------
class LlamaGlooDecoderLayer(nn.Module):
    def __init__(self, config: LlamaGlooConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaGlooAttention(config=config)
        self.mlp = LlamaGlooMLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.use_rms_norm else nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.use_rms_norm else nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states, attention_mask=None, past_key_value=None, output_attentions=False, use_cache=True):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        attn_output = attn_outputs[0]
        attn_weights = attn_outputs[1] if output_attentions else None
        # The attention module returns the *updated* key/value cache as its last element.
        present_key_value = attn_outputs[-1] if use_cache else None

        hidden_states = residual + attn_output

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # Return a fixed-length tuple so callers can index the cache at [1] and the
        # attention weights at [2] regardless of which flags are set.
        return (hidden_states, present_key_value, attn_weights)

# -------------------- LlamaGloo Model --------------------
class LlamaGlooModel(PreTrainedModel):
    config_class = LlamaGlooConfig

    def __init__(self, config: LlamaGlooConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaGlooDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.use_rms_norm else nn.LayerNorm(config.hidden_size)

        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.layers))

        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.to(device)
            if attention_mask.dim() == 3:
                extended_attention_mask = attention_mask[:, None, :, :]
            elif attention_mask.dim() == 2:
                extended_attention_mask = attention_mask[:, None, None, :]
            else:
                raise ValueError(
                    f"Wrong number of dimensions of attention_mask. Expected 2 or 3, but got {attention_mask.dim()}"
                )
            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
        else:
            extended_attention_mask = None

        # A decoder-only model also needs a causal mask; combine it with any padding mask.
        seq_length = input_shape[-1]
        past_length = past_key_values[0][0].shape[-2] if past_key_values[0] is not None else 0
        causal_mask = torch.full(
            (seq_length, past_length + seq_length),
            torch.finfo(self.dtype).min,
            dtype=self.dtype,
            device=device,
        )
        causal_mask = torch.triu(causal_mask, diagonal=past_length + 1)[None, None, :, :]
        extended_attention_mask = causal_mask if extended_attention_mask is None else extended_attention_mask + causal_mask

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx]

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[1],)

            if output_attentions:
                all_self_attns += (layer_outputs[2],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        # The base model returns hidden states, not logits, so use BaseModelOutputWithPast.
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

# -------------------- LlamaGloo For Causal LM --------------------
class LlamaGlooForCausalLM(PreTrainedModel):
    config_class = LlamaGlooConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaGlooModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
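
A minimal usage sketch (not part of the commit): it assumes modelling_llamagloo.py is importable from the working directory, builds a deliberately tiny configuration, runs one forward pass with labels, and optionally registers the classes with the Auto* factories under the "llamagloo" model_type defined above. The small hyperparameter values below are illustrative only.

import torch
from modelling_llamagloo import LlamaGlooConfig, LlamaGlooForCausalLM

# Tiny configuration so the smoke test runs quickly on CPU (values are illustrative).
config = LlamaGlooConfig(
    vocab_size=1000,
    hidden_size=128,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = LlamaGlooForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.logits.shape)  # torch.Size([1, 16, 1000])
print(outputs.loss)          # scalar language-modeling loss

# Optional: register the custom classes so AutoConfig / AutoModelForCausalLM
# can resolve the "llamagloo" model_type.
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("llamagloo", LlamaGlooConfig)
AutoModelForCausalLM.register(LlamaGlooConfig, LlamaGlooForCausalLM)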