appledora committed (verified)
Commit 0cc47ae · 1 Parent(s): d1fcbb1

Upload 6 files

__init__.py ADDED
@@ -0,0 +1,31 @@
+ from transformers.utils import (
+     OptionalDependencyNotAvailable,
+     _LazyModule,
+     is_torch_available,
+ )
+
+ try:
+     if not is_torch_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     pass
+ else:
+     from .modeling_optrecastmlp_llama import (
+         OPTRECASTMLP_llamaModel,
+         OPTRECASTMLP_LlamaForCausalLM,
+     )
+     from .configuration_optrecastmlp_llama import OPTRECASTMLP_llama
+
+     from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+     # Register your models with Auto classes
+     AutoConfig.register("optrecastmlp_llama", OPTRECASTMLP_llama)
+     AutoModel.register(OPTRECASTMLP_llama, OPTRECASTMLP_llamaModel)
+     AutoModelForCausalLM.register(OPTRECASTMLP_llama, OPTRECASTMLP_LlamaForCausalLM)
+
+ _import_structure = {
+     "configuration_optrecastmlp_llama": ["OPTRECASTMLP_llama"],
+     "modeling_optrecastmlp_llama": ["OPTRECASTMLP_llamaModel", "OPTRECASTMLP_LlamaForCausalLM"],
+ }
+
+ __all__ = ["OPTRECASTMLP_llamaModel", "OPTRECASTMLP_LlamaForCausalLM", "OPTRECASTMLP_llama"]
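For readers skimming the diff, here is a minimal, untested sketch of how these Auto-class registrations are meant to be exercised. It assumes the files above are hosted in the `appledora/RECASTMLP-llama3.1-f8t4` repo referenced in the model card further down, and that remote code loading is enabled:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Repo id taken from the model card below; swap in wherever these files actually live.
repo_id = "appledora/RECASTMLP-llama3.1-f8t4"

# trust_remote_code=True lets transformers fetch configuration_optrecastmlp_llama.py and
# modeling_optrecastmlp_llama.py from the repo and resolve the classes via the auto_map
# entries in config.json (or via the explicit register() calls when the package is imported).
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # expected: OPTRECASTMLP_LlamaForCausalLM
```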
config.json CHANGED
@@ -1,37 +1,98 @@
  {
- "architectures": [
- "OPTRECASTMLP_LlamaForCausalLM"
- ],
- "attention_bias": false,
- "attention_dropout": 0.0,
- "bos_token_id": 128000,
- "eos_token_id": 128001,
- "hidden_act": "silu",
  "hidden_size": 4096,
- "initializer_range": 0.02,
  "intermediate_size": 14336,
- "max_position_embeddings": 131072,
- "mlp_bias": false,
- "model_type": "optrecastmlp_llama",
- "num_attention_heads": 32,
- "num_cf": 1,
- "num_groups": 8,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
- "num_templates": 4,
- "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
  "factor": 8.0,
- "high_freq_factor": 4.0,
  "low_freq_factor": 1.0,
  "original_max_position_embeddings": 8192,
  "rope_type": "llama3"
  },
- "rope_theta": 500000.0,
  "tie_word_embeddings": false,
- "torch_dtype": "float32",
- "transformers_version": "4.46.3",
- "use_cache": true,
- "vocab_size": 128256
- }

  {
+ "vocab_size": 128256,
+ "max_position_embeddings": 131072,
  "hidden_size": 4096,
  "intermediate_size": 14336,
  "num_hidden_layers": 32,
+ "num_attention_heads": 32,
  "num_key_value_heads": 8,
+ "hidden_act": "silu",
+ "initializer_range": 0.02,
  "rms_norm_eps": 1e-05,
+ "pretraining_tp": 1,
+ "use_cache": true,
+ "mlp_bias": false,
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "rope_theta": 500000.0,
  "rope_scaling": {
  "factor": 8.0,
  "low_freq_factor": 1.0,
+ "high_freq_factor": 4.0,
  "original_max_position_embeddings": 8192,
  "rope_type": "llama3"
  },
+ "torch_dtype": null,
+ "num_templates": 4,
+ "num_groups": 8,
+ "num_cf": 1,
+ "return_dict": true,
+ "output_hidden_states": false,
+ "output_attentions": false,
+ "torchscript": false,
+ "use_bfloat16": false,
+ "tf_legacy_loss": false,
+ "pruned_heads": {},
  "tie_word_embeddings": false,
+ "chunk_size_feed_forward": 0,
+ "is_encoder_decoder": false,
+ "is_decoder": false,
+ "cross_attention_hidden_size": null,
+ "add_cross_attention": false,
+ "tie_encoder_decoder": false,
+ "max_length": 20,
+ "min_length": 0,
+ "do_sample": false,
+ "early_stopping": false,
+ "num_beams": 1,
+ "num_beam_groups": 1,
+ "diversity_penalty": 0.0,
+ "temperature": 1.0,
+ "top_k": 50,
+ "top_p": 1.0,
+ "typical_p": 1.0,
+ "repetition_penalty": 1.0,
+ "length_penalty": 1.0,
+ "no_repeat_ngram_size": 0,
+ "encoder_no_repeat_ngram_size": 0,
+ "bad_words_ids": null,
+ "num_return_sequences": 1,
+ "output_scores": false,
+ "return_dict_in_generate": false,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "remove_invalid_values": false,
+ "exponential_decay_length_penalty": null,
+ "suppress_tokens": null,
+ "begin_suppress_tokens": null,
+ "architectures": [
+ "OPTRECASTMLP_LlamaForCausalLM"
+ ],
+ "finetuning_task": null,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "tokenizer_class": null,
+ "prefix": null,
+ "bos_token_id": 128000,
+ "pad_token_id": null,
+ "eos_token_id": 128001,
+ "sep_token_id": null,
+ "decoder_start_token_id": null,
+ "task_specific_params": null,
+ "problem_type": null,
+ "_name_or_path": "",
+ "_attn_implementation_autoset": true,
+ "transformers_version": "4.36.0",
+ "model_type": "optrecastmlp_llama",
+ "auto_map": {
+ "AutoConfig": "configuration_optrecastmlp_llama.OPTRECASTMLP_llama",
+ "AutoModel": "modeling_optrecastmlp_llama.OPTRECASTMLP_llamaModel",
+ "AutoModelForCausalLM": "modeling_optrecastmlp_llama.OPTRECASTMLP_LlamaForCausalLM"
+ }
+ }
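As a quick sanity check of the new config layout, a small hedged sketch (assuming this config.json and the companion .py files sit in a hypothetical local directory `./optrecastmlp_llama`) that loads it and reads the template-sharing fields added on top of the usual Llama settings:

```python
from transformers import AutoConfig

# Hypothetical local directory holding config.json plus the configuration/modeling files.
config = AutoConfig.from_pretrained("./optrecastmlp_llama", trust_remote_code=True)

print(config.model_type)      # "optrecastmlp_llama"
print(config.num_templates)   # 4 -> templates held by each MLP template bank
print(config.num_groups)      # 8 -> groups of decoder layers sharing one bank
print(config.num_cf)          # 1 -> coefficient sets per layer
print(config.rope_scaling)    # llama3 rope scaling block carried over from the base model
```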
configuration_optrecastmlp_llama.py ADDED
@@ -0,0 +1,77 @@
+ from transformers import PretrainedConfig
+
+
+ class OPTRECASTMLP_llama(PretrainedConfig):
+     model_type = "optrecastmlp_llama"
+     attribute_map = {
+         "hidden_size": "hidden_size",
+         "num_attention_heads": "num_attention_heads",
+     }
+
+     def __init__(
+         self,
+         vocab_size=128256,
+         hidden_size=4096,
+         intermediate_size=14336,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=8,
+         hidden_act="silu",
+         max_position_embeddings=131072,
+         initializer_range=0.02,
+         rms_norm_eps=1e-5,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=128000,
+         eos_token_id=128001,
+         pretraining_tp=1,
+         tie_word_embeddings=False,
+         rope_theta=500000.0,
+         rope_scaling={
+             "factor": 8.0,
+             "low_freq_factor": 1.0,
+             "high_freq_factor": 4.0,
+             "original_max_position_embeddings": 8192,
+             "rope_type": "llama3",
+         },
+         attention_bias=False,
+         attention_dropout=0.0,
+         mlp_bias=False,
+         # Template-specific configs
+         num_templates=4,
+         num_groups=8,
+         num_cf=1,
+         torch_dtype="bfloat16",
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.pretraining_tp = pretraining_tp
+         self.use_cache = use_cache
+         self.mlp_bias = mlp_bias
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.torch_dtype = torch_dtype
+
+         # Template-specific configs
+         self.num_templates = num_templates
+         self.num_groups = num_groups
+         self.num_cf = num_cf
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs
+         )
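The config class can also be built directly, without a checkpoint. A short sketch, under the assumption that `configuration_optrecastmlp_llama.py` is importable from the working directory; it mirrors the grouping arithmetic used later in the modeling code:

```python
from configuration_optrecastmlp_llama import OPTRECASTMLP_llama

# The defaults above already describe the Llama 3.1 8B shape; only the sharing knobs are explicit here.
config = OPTRECASTMLP_llama(num_templates=4, num_groups=8, num_cf=1)

# Consecutive decoder layers are split into num_groups groups, each sharing one template bank.
layers_per_group = config.num_hidden_layers // config.num_groups
print(layers_per_group)  # 32 // 8 = 4 layers per bank

config.save_pretrained("./optrecastmlp_llama")  # writes a config.json like the one above
```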
metadata.json ADDED
@@ -0,0 +1 @@
+ {"library_name": "transformers", "model_type": "recastmlp_llama", "architectures": ["RECASTMLP_llamaModel"]}
model_card.md ADDED
@@ -0,0 +1,61 @@
+ ---
+ language: en
+ tags:
+ - llama
+ - template-mlp
+ - parameter-efficient
+ - mlp-modification
+ datasets:
+ - none
+ license: apache-2.0
+ pipeline_tag: text-generation
+ library_name: transformers
+ ---
+
+ # RECASTMLP-LLaMA
+
+ This model implements a parameter-efficient modification of the LLaMA architecture by replacing the standard MLP layers with template-based shared MLPs. The model maintains LLaMA's attention mechanism while reducing the number of parameters in the feed-forward networks.
+
+ ## Model Description
+
+ ### Overview
+ RECASTMLP-LLaMA modifies the original LLaMA architecture by introducing template banks for the MLP layers. Instead of having separate MLP weights for each transformer layer, it uses a shared set of template weights that are combined using learned coefficients.
+
+ ### Architecture Details
+ - **Base Model:** LLaMA 3.1 8B
+ - **Number of Templates:** 4
+ - **Number of Groups:** 8
+ - **Coefficients per Template:** 1
+ - **Coefficients:** 392
+ - **Hidden Size:** 4096
+ - **Intermediate Size:** 14336
+ - **Number of Attention Heads:** 32
+ - **Number of Key-Value Heads:** 8
+ - **Number of Layers:** 32
+ - **Max Position Embeddings:** 131072
+ - **Vocabulary Size:** 128256
+
+
+ ### Key Features
+ 1. **Template Banks:** Uses shared template weights across groups of layers
+ 2. **Parameter Efficiency:** Reduces the total number of parameters by sharing MLP weights
+ 3. **Group-wise Sharing:** Organizes layers into groups that share template banks
+ 4. **Coefficient Learning:** Uses learned coefficients to combine template weights
+
+ ## Usage
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+
+ # Load model and tokenizer
+ model = AutoModel.from_pretrained("appledora/RECASTMLP-llama3.1-f8t4", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8b")
+
+ # Prepare input
+ text = "Hello, how are you?"
+ inputs = tokenizer(text, return_tensors="pt")
+
+ # Generate output
+ outputs = model(**inputs)
+ hidden_states = outputs.last_hidden_state
+ ```
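The usage snippet above returns hidden states through the base model; for text generation, a hedged variant using the causal-LM head (same repo ids as in the usage section, untested here):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the causal-LM wrapper so that .generate() is available.
model = AutoModelForCausalLM.from_pretrained(
    "appledora/RECASTMLP-llama3.1-f8t4", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8b")

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```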
modeling_optrecastmlp_llama.py ADDED
@@ -0,0 +1,668 @@
+ # filename: recastmlp_llama_model.py
+ from .configuration_optrecastmlp_llama import OPTRECASTMLP_llama
+ from transformers import PreTrainedModel
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple, Union, List
+ from transformers import AutoConfig
+ from transformers.utils import logging
+ from transformers.cache_utils import Cache, StaticCache
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+ logger = logging.get_logger(__name__)
+
+
+ class MLPTemplateBank(nn.Module):
+     def __init__(self, config, num_templates):
+         super().__init__()
+         self.num_templates = num_templates
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         # Store templates in a more efficient layout
+         self.up_gate_templates = nn.Parameter(
+             torch.empty(num_templates, 2 * self.intermediate_size * self.hidden_size)
+         )
+         self.down_templates = nn.Parameter(
+             torch.empty(num_templates, self.hidden_size * self.intermediate_size)
+         )
+
+         # Initialize with proper scaling
+         std = 1.0 / (self.hidden_size**0.5)
+         nn.init.normal_(self.up_gate_templates, std=std)
+         nn.init.normal_(self.down_templates, std=std)
+
+     def forward(self, coeffs):
+         # Simple matrix multiplication instead of broadcasting
+         up_gate_weights = torch.mm(coeffs, self.up_gate_templates)
+         down_weights = torch.mm(coeffs, self.down_templates)
+
+         # Reshape to final dimensions
+         up_gate = up_gate_weights.view(2, self.intermediate_size, self.hidden_size)
+         return (
+             up_gate[0],
+             up_gate[1],
+             down_weights.view(self.hidden_size, self.intermediate_size),
+         )
+
+
+ class SharedLlamaMLP(nn.Module):
+     def __init__(self, config, bank):
+         super().__init__()
+         self.config = config
+         self.bank = bank
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         # Use transposed coefficients to avoid unnecessary operations
+         self.coefficients = nn.Parameter(torch.zeros(1, config.num_templates))
+         nn.init.normal_(self.coefficients, std=0.02)
+
+         if config.mlp_bias:
+             self.gate_bias = nn.Parameter(torch.zeros(self.intermediate_size))
+             self.up_bias = nn.Parameter(torch.zeros(self.intermediate_size))
+             self.down_bias = nn.Parameter(torch.zeros(self.hidden_size))
+         else:
+             self.register_parameter("gate_bias", None)
+             self.register_parameter("up_bias", None)
+             self.register_parameter("down_bias", None)
+
+         self.act_fn = F.silu
+
+     def forward(self, x):
+         # Generate weights with minimal operations
+         gate_weights, up_weights, down_weights = self.bank(self.coefficients)
+
+         # Standard MLP operations
+         gate_output = F.linear(x, gate_weights, self.gate_bias)
+         up_output = F.linear(x, up_weights, self.up_bias)
+
+         hidden_states = self.act_fn(gate_output) * up_output
+         output = F.linear(hidden_states, down_weights, self.down_bias)
+
+         return output
+
+
+ def fixed_cross_entropy(
+     source,
+     target,
+     num_items_in_batch: int = None,
+     ignore_index: int = -100,
+     **kwargs,
+ ):
+     reduction = "sum" if num_items_in_batch is not None else "mean"
+     loss = nn.functional.cross_entropy(
+         source, target, ignore_index=ignore_index, reduction=reduction
+     )
+     if reduction == "sum":
+         loss = loss / num_items_in_batch
+     return loss
+
+
+ from transformers.models.llama.modeling_llama import (
+     LlamaDecoderLayer,
+     LlamaRotaryEmbedding,
+     LlamaRMSNorm,
+     apply_rotary_pos_emb,
+ )
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+
+
+ class OPTRECASTMLP_llamaModel(PreTrainedModel):
+     config_class = OPTRECASTMLP_llama
+     base_model_prefix = "llama"
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(
+             config.vocab_size, config.hidden_size, self.padding_idx
+         )
+         # Initialize rotary embeddings
+         rope_config = config.rope_scaling
+         if rope_config:
+             rope_type = rope_config.get("rope_type", "default")
+             scaling_factor = rope_config.get("factor", 1.0)
+         else:
+             rope_type = "default"
+             scaling_factor = None
+         original_config = AutoConfig.from_pretrained(
+             "meta-llama/Llama-3.1-8b", trust_remote_code=True
+         )
+         self.rotary_emb = LlamaRotaryEmbedding(
+             config=original_config,
+         )
+
+         # Create template banks first
+         self.banks = []
+         layers_per_group = config.num_hidden_layers // config.num_groups
+         for _ in range(config.num_groups):
+             bank = MLPTemplateBank(config, config.num_templates)
+             self.banks.append(bank)
+
+         # Create layers using LlamaDecoderLayer but replace MLPs
+         self.layers = nn.ModuleList()
+         for layer_idx in range(config.num_hidden_layers):
+             # Create standard LlamaDecoderLayer
+             decoder_layer = LlamaDecoderLayer(config, layer_idx)
+
+             # Replace its MLP with our SharedLlamaMLP
+             group_idx = layer_idx // layers_per_group
+             group_bank = self.banks[group_idx]
+             decoder_layer.mlp = SharedLlamaMLP(config, bank=group_bank)
+
+             self.layers.append(decoder_layer)
+
+         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **flash_attn_kwargs,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You must specify exactly one of input_ids or inputs_embeds"
+             )
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+         # Set up cache position if not provided
+         if cache_position is None:
+             past_seen_tokens = (
+                 0
+                 if past_key_values is None
+                 else (
+                     past_key_values.get_seq_length()
+                     if isinstance(past_key_values, Cache)
+                     else past_key_values[0][0].size(-2) if past_key_values else 0
+                 )
+             )
+             cache_position = torch.arange(
+                 past_seen_tokens,
+                 past_seen_tokens + inputs_embeds.shape[1],
+                 device=inputs_embeds.device,
+             )
+         # Create position embeddings to be shared across the decoder layers
+         # Set up position IDs if not provided
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+         # Get updated causal mask
+         causal_mask = self._update_causal_mask(
+             attention_mask,
+             inputs_embeds,
+             cache_position,
+             past_key_values,
+             output_attentions,
+         )
+         hidden_states = inputs_embeds
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         # Initialize outputs
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         # Process through layers
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     causal_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                     position_embeddings,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=causal_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     position_embeddings=position_embeddings,
+                     **flash_attn_kwargs,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         # Final layer norm
+         hidden_states = self.norm(hidden_states)
+
+         # Add last hidden state
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                 if v is not None
+             )
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         if isinstance(
+             pretrained_model_name_or_path, str
+         ) and pretrained_model_name_or_path.endswith(".pt"):
+             print("Loading from local checkpoint")
+             # Load from local checkpoint
+             config = kwargs.get("config", None)
+             if config is None:
+                 config = AutoConfig.from_pretrained(
+                     pretrained_model_name_or_path, trust_remote_code=True
+                 )
+
+             model = cls(config)
+             checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
+             state_dict = checkpoint["model_state_dict"]
+             logger.info(
+                 f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
+             )
+
+             missing_keys, unexpected_keys = model.load_state_dict(
+                 state_dict, strict=False
+             )
+
+             if len(missing_keys) > 0:
+                 logger.warning(f"Missing keys: {missing_keys}")
+             if len(unexpected_keys) > 0:
+                 logger.warning(f"Unexpected keys: {unexpected_keys}")
+
+             return model
+         else:
+             print("Loading from hub")
+             # Load from hub using parent's from_pretrained
+             return super().from_pretrained(
+                 pretrained_model_name_or_path, *model_args, **kwargs
+             )
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     def _update_causal_mask(
+         self,
+         attention_mask: torch.Tensor,
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Cache,
+         output_attentions: bool,
+     ):
+         if self.config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and 0.0 in attention_mask:
+                 return attention_mask
+             return None
+
+         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+         # to infer the attention mask.
+         past_seen_tokens = (
+             past_key_values.get_seq_length() if past_key_values is not None else 0
+         )
+         using_static_cache = isinstance(past_key_values, StaticCache)
+
+         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+         if (
+             self.config._attn_implementation == "sdpa"
+             and not using_static_cache
+             and not output_attentions
+         ):
+             if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask,
+                 inputs_embeds=input_tensor,
+                 past_key_values_length=past_seen_tokens,
+                 is_training=self.training,
+             ):
+                 return None
+
+         dtype, device = input_tensor.dtype, input_tensor.device
+         sequence_length = input_tensor.shape[1]
+         if using_static_cache:
+             target_length = past_key_values.get_max_cache_shape()
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+             attention_mask,
+             sequence_length=sequence_length,
+             target_length=target_length,
+             dtype=dtype,
+             device=device,
+             cache_position=cache_position,
+             batch_size=input_tensor.shape[0],
+         )
+
+         if (
+             self.config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type == "cuda"
+             and not output_attentions
+         ):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = AttentionMaskConverter._unmask_unattended(
+                 causal_mask, min_dtype
+             )
+
+         return causal_mask
+
+     @staticmethod
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: torch.Tensor,
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         cache_position: torch.Tensor,
+         batch_size: int,
+         **kwargs,
+     ):
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length),
+                 fill_value=min_dtype,
+                 dtype=dtype,
+                 device=device,
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(
+                 target_length, device=device
+             ) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = (
+                     causal_mask.clone()
+                 )  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = (
+                     causal_mask[:, :, :, :mask_length]
+                     + attention_mask[:, None, None, :]
+                 )
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[
+                     :, :, :, :mask_length
+                 ].masked_fill(padding_mask, min_dtype)
+
+         return causal_mask
+
+
+ class OPTRECASTMLP_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+     _tp_plan = {"lm_head": "colwise_rep"}
+     config_class = OPTRECASTMLP_llama
+     base_model_prefix = "llama"
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = OPTRECASTMLP_llamaModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def loss_function(
+         self,
+         logits,
+         labels,
+         vocab_size: int,
+         num_items_in_batch: int = None,
+         ignore_index: int = -100,
+         **kwargs,
+     ):
+         # Upcast to float if we need to compute the loss to avoid potential precision issues
+         logits = logits.float()
+         # Shift so that tokens < n predict n
+         shift_logits = logits[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+         # Flatten the tokens
+         shift_logits = shift_logits.view(-1, vocab_size)
+         shift_labels = shift_labels.view(-1)
+         # Enable model parallelism
+         shift_labels = shift_labels.to(shift_logits.device)
+         loss = fixed_cross_entropy(
+             shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
+         )
+         return loss
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         num_logits_to_keep: int = 0,
+         **kwargs,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         """
+         Args:
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Labels for computing the masked language modeling loss. Indices should be in
+                 `[0, ..., config.vocab_size]` or -100 (masked tokens).
+             num_logits_to_keep (`int`, *optional*):
+                 Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
+         """
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         hidden_states = outputs[0]
+         # Only compute necessary logits
+         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+         loss = None
+         if labels is not None:
+             # Calculate batch size for loss function
+             num_items_in_batch = (
+                 input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
+             )
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 vocab_size=self.config.vocab_size,
+                 num_items_in_batch=num_items_in_batch,
+                 **kwargs,
+             )
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         **kwargs,
+     ):
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -1].unsqueeze(-1)
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         if isinstance(
+             pretrained_model_name_or_path, str
+         ) and pretrained_model_name_or_path.endswith(".pt"):
+             print("Loading from local checkpoint")
+             config = kwargs.get("config", None)
+             if config is None:
+                 config = AutoConfig.from_pretrained(
+                     pretrained_model_name_or_path, trust_remote_code=True
+                 )
+
+             model = cls(config)
+             checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
+             state_dict = checkpoint["model_state_dict"]
+
+             missing_keys, unexpected_keys = model.load_state_dict(
+                 state_dict, strict=False
+             )
+
+             if len(missing_keys) > 0:
+                 logger.warning(f"Missing keys: {missing_keys}")
+             if len(unexpected_keys) > 0:
+                 logger.warning(f"Unexpected keys: {unexpected_keys}")
+
+             return model
+         else:
+             print("Loading from hub")
+             return super().from_pretrained(
+                 pretrained_model_name_or_path, *model_args, **kwargs
+             )
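To make the weight-sharing mechanics above easier to follow in isolation, here is a reduced, self-contained sketch of what MLPTemplateBank and SharedLlamaMLP compute together, with toy dimensions (hidden_size=8, intermediate_size=16) that are not the real model's:

```python
import torch
import torch.nn.functional as F

# Toy dimensions for illustration only; the real model uses 4096 / 14336.
hidden, inter, num_templates = 8, 16, 4

# A bank stores flattened template weights shared by a group of layers...
up_gate_templates = torch.randn(num_templates, 2 * inter * hidden) * 0.02
down_templates = torch.randn(num_templates, hidden * inter) * 0.02

# ...while each layer owns only a small coefficient vector that mixes the templates.
coeffs = torch.randn(1, num_templates) * 0.02

# Mix templates into concrete gate/up/down projection matrices for this layer.
up_gate = torch.mm(coeffs, up_gate_templates).view(2, inter, hidden)
gate_w, up_w = up_gate[0], up_gate[1]
down_w = torch.mm(coeffs, down_templates).view(hidden, inter)

# Standard SwiGLU-style MLP using the generated weights.
x = torch.randn(2, 5, hidden)  # (batch, seq, hidden)
y = F.linear(F.silu(F.linear(x, gate_w)) * F.linear(x, up_w), down_w)
print(y.shape)  # torch.Size([2, 5, 8])
```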