appledora committed
Commit 4098668 · verified · 1 Parent(s): f68b757

Upload modeling_recast_llama.py with huggingface_hub

Files changed (1): modeling_recast_llama.py (+670, -0)
modeling_recast_llama.py ADDED
@@ -0,0 +1,670 @@
# filename: recastmlp_llama_model.py
from .configuration_recast_llama import RECAST1B_llama
from transformers import PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List
from transformers import AutoConfig
from transformers.utils import logging
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

logger = logging.get_logger(__name__)


class MLPTemplateBank(nn.Module):
    def __init__(self, config, num_templates):
        super().__init__()
        self.num_templates = num_templates
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Store templates in a more efficient layout
        self.up_templates = nn.Parameter(
            torch.empty(num_templates, self.intermediate_size * self.hidden_size)
        )
        self.gate_templates = nn.Parameter(
            torch.empty(num_templates, self.intermediate_size * self.hidden_size)
        )
        self.down_templates = nn.Parameter(
            torch.empty(num_templates, self.hidden_size * self.intermediate_size)
        )

        nn.init.kaiming_normal_(self.up_templates)
        nn.init.kaiming_normal_(self.gate_templates)
        nn.init.kaiming_normal_(self.down_templates)

    def forward(self, up_coeffs, gate_coeffs, down_coeffs):
        # Simple matrix multiplication instead of broadcasting
        up_weights = torch.mm(up_coeffs, self.up_templates)
        gate_weights = torch.mm(gate_coeffs, self.gate_templates)
        down_weights = torch.mm(down_coeffs, self.down_templates)
        up_weights = up_weights.view(self.intermediate_size, self.hidden_size)
        gate_weights = gate_weights.view(self.intermediate_size, self.hidden_size)
        down_weights = down_weights.view(self.hidden_size, self.intermediate_size)
        return gate_weights, up_weights, down_weights

class SharedLlamaMLP(nn.Module):
    def __init__(self, config, bank):
        super().__init__()
        self.config = config
        self.bank = bank
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_coefficients = nn.Parameter(torch.zeros(1, config.num_templates))
        self.gate_coefficients = nn.Parameter(torch.zeros(1, config.num_templates))
        self.down_coefficients = nn.Parameter(torch.zeros(1, config.num_templates))

        nn.init.normal_(self.up_coefficients, mean=0.0, std=1.0)
        nn.init.normal_(self.gate_coefficients, mean=0.0, std=1.0)
        nn.init.normal_(self.down_coefficients, mean=0.0, std=1.0)
        if config.mlp_bias:
            self.gate_bias = nn.Parameter(torch.zeros(self.intermediate_size))
            self.up_bias = nn.Parameter(torch.zeros(self.intermediate_size))
            self.down_bias = nn.Parameter(torch.zeros(self.hidden_size))
        else:
            self.register_parameter("gate_bias", None)
            self.register_parameter("up_bias", None)
            self.register_parameter("down_bias", None)

        self.act_fn = F.silu

    def forward(self, x):
        # Generate weights with minimal operations
        gate_weights, up_weights, down_weights = self.bank(
            self.up_coefficients, self.gate_coefficients, self.down_coefficients
        )

        # Standard MLP operations
        gate_output = F.linear(x, gate_weights, self.gate_bias)
        up_output = F.linear(x, up_weights, self.up_bias)

        hidden_states = self.act_fn(gate_output) * up_output
        output = F.linear(hidden_states, down_weights, self.down_bias)

        return output

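# Illustrative sketch (not part of the original module): how one template bank is
# shared by several SharedLlamaMLP instances while each keeps its own mixing
# coefficients. The config object below is a hypothetical stand-in that only
# carries the fields these two classes actually read.
#
#   import types, torch
#   cfg = types.SimpleNamespace(hidden_size=64, intermediate_size=128,
#                               num_templates=4, mlp_bias=False)
#   bank = MLPTemplateBank(cfg, cfg.num_templates)    # one bank per layer group
#   mlp_a = SharedLlamaMLP(cfg, bank)                 # layers in the group reuse
#   mlp_b = SharedLlamaMLP(cfg, bank)                 # the same templates
#   y = mlp_a(torch.randn(2, 10, cfg.hidden_size))    # -> shape (2, 10, 64)
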
def fixed_cross_entropy(
    source,
    target,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
):
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(
        source, target, ignore_index=ignore_index, reduction=reduction
    )
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss

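# Illustrative note (not part of the original module): when num_items_in_batch is
# given, the summed loss is renormalized by that caller-supplied count, e.g.
#   fixed_cross_entropy(logits.view(-1, V), labels.view(-1), num_items_in_batch=n)
# which keeps the loss scale consistent when it is accumulated over micro-batches.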

from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRotaryEmbedding,
    LlamaRMSNorm,
)
from transformers.modeling_outputs import BaseModelOutputWithPast


class RECAST1B_llamaModel(PreTrainedModel):
    config_class = RECAST1B_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        # Initialize rotary embeddings
        rope_config = config.rope_scaling
        if rope_config:
            rope_type = rope_config.get("rope_type", "default")
            scaling_factor = rope_config.get("factor", 1.0)
        else:
            rope_type = "default"
            scaling_factor = None
        # Note: rope_type / scaling_factor are currently unused; the rotary embedding
        # below takes its settings from the base Llama config loaded here.
        original_config = AutoConfig.from_pretrained(
            "meta-llama/Llama-3.2-1b", trust_remote_code=True
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            config=original_config,
        )

        # Create template banks first. They are kept in a plain list: each bank is
        # still registered as a submodule through the `mlp.bank` attribute of the
        # layers that share it, so its parameters appear in the state_dict.
        self.banks = []
        layers_per_group = config.num_hidden_layers // config.num_groups
        for _ in range(config.num_groups):
            bank = MLPTemplateBank(config, config.num_templates)
            self.banks.append(bank)

        # Create layers using LlamaDecoderLayer but replace MLPs
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_hidden_layers):
            # Create standard LlamaDecoderLayer
            decoder_layer = LlamaDecoderLayer(config, layer_idx)

            # Replace its MLP with our SharedLlamaMLP
            group_idx = layer_idx // layers_per_group
            group_bank = self.banks[group_idx]
            decoder_layer.mlp = SharedLlamaMLP(config, bank=group_bank)

            self.layers.append(decoder_layer)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # Set up cache position if not provided
        if cache_position is None:
            past_seen_tokens = (
                0
                if past_key_values is None
                else (
                    past_key_values.get_seq_length()
                    if isinstance(past_key_values, Cache)
                    else past_key_values[0][0].size(-2) if past_key_values else 0
                )
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        # Create position embeddings to be shared across the decoder layers
        # Set up position IDs if not provided
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        # Get updated causal mask
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Initialize outputs
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # Process through layers
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt"):
            print("Loading from local checkpoint")
            # Load from local checkpoint
            config = kwargs.get("config", None)
            if config is None:
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=True
                )

            model = cls(config)
            checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
            state_dict = checkpoint["model_state_dict"]
            logger.info(
                f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
            )

            missing_keys, unexpected_keys = model.load_state_dict(
                state_dict, strict=False
            )

            if len(missing_keys) > 0:
                logger.warning(f"Missing keys: {missing_keys}")
            if len(unexpected_keys) > 0:
                logger.warning(f"Unexpected keys: {unexpected_keys}")

            return model
        else:
            print("Loading from hub")
            # Load from hub using parent's from_pretrained
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not using_static_cache
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)

        return causal_mask


class RECAST1B_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    config_class = RECAST1B_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.model = RECAST1B_llamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def loss_function(
        self,
        logits,
        labels,
        vocab_size: int,
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        **kwargs,
    ):
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        shift_logits = shift_logits.view(-1, vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = fixed_cross_entropy(
            shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
        )
        return loss

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should be in
                `[0, ..., config.vocab_size]` or -100 (masked tokens).
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            # Calculate batch size for loss function
            num_items_in_batch = (
                input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
            )
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt"):
            print("Loading from local checkpoint")
            config = kwargs.get("config", None)
            if config is None:
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=True
                )

            model = cls(config)
            checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
            state_dict = checkpoint["model_state_dict"]

            missing_keys, unexpected_keys = model.load_state_dict(
                state_dict, strict=False
            )

            if len(missing_keys) > 0:
                logger.warning(f"Missing keys: {missing_keys}")
            if len(unexpected_keys) > 0:
                logger.warning(f"Unexpected keys: {unexpected_keys}")

            return model
        else:
            print("Loading from hub")
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )