sippycoder committed
Commit 3516859
1 Parent(s): ac6e7c1

initial commit

Files changed (2)
  1. configuration_nucleus.py +0 -89
  2. modeling_nucleus.py +0 -155
configuration_nucleus.py DELETED
@@ -1,89 +0,0 @@
- # This config is based on LLaMA.
- """ Nucleus model configuration"""
-
- from transformers import PretrainedConfig
- from transformers.utils import logging
-
-
- logger = logging.get_logger(__name__)
-
- NUCLEUS_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
-
-
- class NucleusConfig(PretrainedConfig):
-     model_type = "nucleus"
-     keys_to_ignore_at_inference = ["past_key_values"]
-
-     def __init__(
-         self,
-         vocab_size=32000,
-         hidden_size=6656,
-         intermediate_size=17920,
-         num_hidden_layers=40,
-         num_attention_heads=52,
-         num_key_value_heads=52,
-         hidden_act="silu",
-         max_position_embeddings=2048,
-         initializer_range=0.02,
-         rms_norm_eps=1e-6,
-         use_cache=True,
-         pad_token_id=0,
-         bos_token_id=1,
-         eos_token_id=2,
-         pretraining_tp=1,
-         tie_word_embeddings=False,
-         rope_theta=10000.0,
-         rope_scaling=None,
-         attention_bias=False,
-         **kwargs,
-     ):
-         self.vocab_size = vocab_size
-         self.max_position_embeddings = max_position_embeddings
-         self.hidden_size = hidden_size
-         self.intermediate_size = intermediate_size
-         self.num_hidden_layers = num_hidden_layers
-         self.num_attention_heads = num_attention_heads
-
-         # for backward compatibility
-         if num_key_value_heads is None:
-             num_key_value_heads = num_attention_heads
-
-         self.num_key_value_heads = num_key_value_heads
-         self.hidden_act = hidden_act
-         self.initializer_range = initializer_range
-         self.rms_norm_eps = rms_norm_eps
-         self.pretraining_tp = pretraining_tp
-         self.use_cache = use_cache
-         self.rope_theta = rope_theta
-         self.rope_scaling = rope_scaling
-         self._rope_scaling_validation()
-         self.attention_bias = attention_bias
-
-         super().__init__(
-             pad_token_id=pad_token_id,
-             bos_token_id=bos_token_id,
-             eos_token_id=eos_token_id,
-             tie_word_embeddings=tie_word_embeddings,
-             **kwargs,
-         )
-
-     def _rope_scaling_validation(self):
-         """
-         Validate the `rope_scaling` configuration.
-         """
-         if self.rope_scaling is None:
-             return
-
-         if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
-             raise ValueError(
-                 "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
-                 f"got {self.rope_scaling}"
-             )
-         rope_scaling_type = self.rope_scaling.get("type", None)
-         rope_scaling_factor = self.rope_scaling.get("factor", None)
-         if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
-             raise ValueError(
-                 f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
-             )
-         if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
-             raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
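For reference, a minimal usage sketch of the configuration class above, assuming configuration_nucleus.py is importable from the working directory; the rope_scaling values are illustrative only:

from configuration_nucleus import NucleusConfig

# Defaults mirror the hyperparameters above: 6656 hidden size, 40 layers, 52 heads.
config = NucleusConfig()
print(config.model_type, config.hidden_size, config.num_attention_heads)

# `rope_scaling` must be a two-key dict: `type` in {"linear", "dynamic"} and a float `factor` > 1.
scaled = NucleusConfig(rope_scaling={"type": "linear", "factor": 2.0}, max_position_embeddings=4096)

# Anything else is rejected by _rope_scaling_validation().
try:
    NucleusConfig(rope_scaling={"type": "linear", "factor": 1.0})
except ValueError as err:
    print(err)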
modeling_nucleus.py DELETED
@@ -1,155 +0,0 @@
- # This code is based on LLaMA
- """ PyTorch Nucleus model."""
- from typing import List, Optional, Tuple, Union
-
- import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint
- from torch import nn
- from torch.nn import CrossEntropyLoss
-
- from transformers.modeling_outputs import CausalLMOutputWithPast
- from .configuration_nucleus import NucleusConfig
-
- from transformers import (
-     LlamaPreTrainedModel,
-     LlamaModel
- )
-
-
- class NucleusForCausalLM(LlamaPreTrainedModel):
-     _tied_weights_keys = ["lm_head.weight"]
-
-     def __init__(self, config: NucleusConfig):
-         super().__init__(config)
-         self.model = LlamaModel(config)
-         self.vocab_size = config.vocab_size
-         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self):
-         return self.model.embed_tokens
-
-     def set_input_embeddings(self, value):
-         self.model.embed_tokens = value
-
-     def get_output_embeddings(self):
-         return self.lm_head
-
-     def set_output_embeddings(self, new_embeddings):
-         self.lm_head = new_embeddings
-
-     def set_decoder(self, decoder):
-         self.model = decoder
-
-     def get_decoder(self):
-         return self.model
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
-         outputs = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         hidden_states = outputs[0]
-         if self.config.pretraining_tp > 1:
-             lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-             logits = torch.cat(logits, dim=-1)
-         else:
-             logits = self.lm_head(hidden_states)
-         logits = logits.float()
-
-         loss = None
-         if labels is not None:
-             # Shift so that tokens < n predict n
-             shift_logits = logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             # Flatten the tokens
-             loss_fct = CrossEntropyLoss()
-             shift_logits = shift_logits.view(-1, self.config.vocab_size)
-             shift_labels = shift_labels.view(-1)
-             # Enable model parallelism
-             shift_labels = shift_labels.to(shift_logits.device)
-             loss = loss_fct(shift_logits, shift_labels)
-
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
-     def prepare_inputs_for_generation(
-         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-     ):
-         if past_key_values:
-             input_ids = input_ids[:, -1:]
-
-         position_ids = kwargs.get("position_ids", None)
-         if attention_mask is not None and position_ids is None:
-             # create position_ids on the fly for batch generation
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             if past_key_values:
-                 position_ids = position_ids[:, -1].unsqueeze(-1)
-
-         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"inputs_embeds": inputs_embeds}
-         else:
-             model_inputs = {"input_ids": input_ids}
-
-         model_inputs.update(
-             {
-                 "position_ids": position_ids,
-                 "past_key_values": past_key_values,
-                 "use_cache": kwargs.get("use_cache"),
-                 "attention_mask": attention_mask,
-             }
-         )
-         return model_inputs
-
-     @staticmethod
-     def _reorder_cache(past_key_values, beam_idx):
-         reordered_past = ()
-         for layer_past in past_key_values:
-             reordered_past += (
-                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-             )
-         return reordered_past
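A rough smoke-test sketch for the model class above. It assumes both files sit inside a hypothetical package named nucleus (the relative import of NucleusConfig needs one) and a transformers release contemporary with this code (roughly 4.34/4.35); newer releases may expect extra config attributes. The tiny hyperparameters are illustrative only, not the released checkpoint:

import torch

# Hypothetical package layout: nucleus/configuration_nucleus.py, nucleus/modeling_nucleus.py
from nucleus.configuration_nucleus import NucleusConfig
from nucleus.modeling_nucleus import NucleusForCausalLM

# Shrunken configuration so a random-weight forward pass runs on CPU;
# the real model uses the defaults defined above (6656 hidden size, 40 layers, 52 heads).
config = NucleusConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    max_position_embeddings=128,
)
model = NucleusForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))

# Passing labels exercises the shifted cross-entropy loss in forward().
with torch.no_grad():
    outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.loss.item(), outputs.logits.shape)  # scalar loss, logits of shape (1, 16, vocab_size)

# generate() goes through prepare_inputs_for_generation() and the KV-cache path.
generated = model.generate(input_ids, max_new_tokens=4, do_sample=False)
print(generated.shape)  # at most 4 new tokens appended to the prompt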