emanuelaboros committed
Commit 5a14ece · 1 Parent(s): a886816

Let's try to change the pipeline

Files changed (1)
  1. modeling_stacked.py  +180 -115
modeling_stacked.py CHANGED
@@ -28,136 +28,201 @@ def get_info(label_map):
 
 
 class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
-
     config_class = ImpressoConfig
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def __init__(self, config):
         super().__init__(config)
-        # self.num_token_labels_dict = get_info(config.label_map)
-        # self.config = config
-        # # print(f"I dont think it arrives here: {self.config}")
-        # self.bert = AutoModel.from_pretrained(
-        #     config.pretrained_config["_name_or_path"], config=config.pretrained_config
-        # )
+        self.config = config
+
+        # Load floret model
         self.model_floret = floret.load_model(self.config.filename)
-        # print(f"Model loaded: {self.model_floret}")
-        # if "classifier_dropout" not in config.__dict__:
-        #     classifier_dropout = 0.1
-        # else:
-        #     classifier_dropout = (
-        #         config.classifier_dropout
-        #         if config.classifier_dropout is not None
-        #         else config.hidden_dropout_prob
-        #     )
-        # self.dropout = nn.Dropout(classifier_dropout)
-        #
-        # # Additional transformer layers
-        # self.transformer_encoder = nn.TransformerEncoder(
-        #     nn.TransformerEncoderLayer(
-        #         d_model=config.hidden_size, nhead=config.num_attention_heads
-        #     ),
-        #     num_layers=2,
-        # )
-
-        # For token classification, create a classifier for each task
-        # self.token_classifiers = nn.ModuleDict(
-        #     {
-        #         task: nn.Linear(config.hidden_size, num_labels)
-        #         for task, num_labels in self.num_token_labels_dict.items()
-        #     }
-        # )
-        #
-        # # Initialize weights and apply final processing
-        # self.post_init()
+
+    def forward(self, input_ids, attention_mask=None, **kwargs):
+        # Convert input_ids to strings using tokenizer
+        if input_ids is not None:
+            tokenizer = kwargs.get("tokenizer")
+            texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+        else:
+            texts = kwargs.get("text", None)
+
+        if texts:
+            # Floret expects strings, not tensors
+            predictions = [self.model_floret(text) for text in texts]
+            # Convert predictions to tensors for Hugging Face compatibility
+            return torch.tensor(predictions)
+        else:
+            # If no text is found, return dummy output
+            return torch.zeros(
+                (1, 2)
+            )  # Dummy tensor with shape (batch_size, num_classes)
+
+    def state_dict(self, *args, **kwargs):
+        # Return an empty state dictionary
+        return {}
+
+    def load_state_dict(self, state_dict, strict=True):
+        # Ignore loading since there are no parameters
+        print("Ignoring state_dict since model has no parameters.")
 
     def get_floret_model(self):
         return self.model_floret
 
+    def get_extended_attention_mask(
+        self, attention_mask, input_shape, device=None, dtype=torch.float
+    ):
+        if attention_mask is None:
+            attention_mask = torch.ones(input_shape, device=device)
+        extended_attention_mask = attention_mask[:, None, None, :]
+        extended_attention_mask = extended_attention_mask.to(dtype=dtype)
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        return extended_attention_mask
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         print("Ignoring weights and using custom initialization.")
-
         # Manually create the config
-        config = ImpressoConfig()
-
+        config = ImpressoConfig(**kwargs)
         # Pass the manually created config to the class
         model = cls(config)
         return model
 
-    # def forward(
-    #     self,
-    #     input_ids: Optional[torch.Tensor] = None,
-    #     attention_mask: Optional[torch.Tensor] = None,
-    #     token_type_ids: Optional[torch.Tensor] = None,
-    #     position_ids: Optional[torch.Tensor] = None,
-    #     head_mask: Optional[torch.Tensor] = None,
-    #     inputs_embeds: Optional[torch.Tensor] = None,
-    #     labels: Optional[torch.Tensor] = None,
-    #     token_labels: Optional[dict] = None,
-    #     output_attentions: Optional[bool] = None,
-    #     output_hidden_states: Optional[bool] = None,
-    #     return_dict: Optional[bool] = None,
-    # ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
-    #     r"""
-    #     token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
-    #         Labels for computing the token classification loss. Keys should match the tasks.
-    #     """
-    #     return_dict = (
-    #         return_dict if return_dict is not None else self.config.use_return_dict
-    #     )
-    #
-    #     bert_kwargs = {
-    #         "input_ids": input_ids,
-    #         "attention_mask": attention_mask,
-    #         "token_type_ids": token_type_ids,
-    #         "position_ids": position_ids,
-    #         "head_mask": head_mask,
-    #         "inputs_embeds": inputs_embeds,
-    #         "output_attentions": output_attentions,
-    #         "output_hidden_states": output_hidden_states,
-    #         "return_dict": return_dict,
-    #     }
-    #
-    #     if any(
-    #         keyword in self.config.name_or_path.lower()
-    #         for keyword in ["llama", "deberta"]
-    #     ):
-    #         bert_kwargs.pop("token_type_ids")
-    #         bert_kwargs.pop("head_mask")
-    #
-    #     outputs = self.bert(**bert_kwargs)
-    #
-    #     # For token classification
-    #     token_output = outputs[0]
-    #     token_output = self.dropout(token_output)
-    #
-    #     # Pass through additional transformer layers
-    #     token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
-    #         0, 1
-    #     )
-    #
-    #     # Collect the logits and compute the loss for each task
-    #     task_logits = {}
-    #     total_loss = 0
-    #     for task, classifier in self.token_classifiers.items():
-    #         logits = classifier(token_output)
-    #         task_logits[task] = logits
-    #         if token_labels and task in token_labels:
-    #             loss_fct = CrossEntropyLoss()
-    #             loss = loss_fct(
-    #                 logits.view(-1, self.num_token_labels_dict[task]),
-    #                 token_labels[task].view(-1),
-    #             )
-    #             total_loss += loss
-    #
-    #     if not return_dict:
-    #         output = (task_logits,) + outputs[2:]
-    #         return ((total_loss,) + output) if total_loss != 0 else output
-    #     print(f"Is there anobidy coming here?")
-    #     return TokenClassifierOutput(
-    #         loss=total_loss,
-    #         logits=task_logits,
-    #         hidden_states=outputs.hidden_states,
-    #         attentions=outputs.attentions,
-    #     )
+
+# class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
+#
+#     config_class = ImpressoConfig
+#     _keys_to_ignore_on_load_missing = [r"position_ids"]
+#
+#     def __init__(self, config):
+#         super().__init__(config)
+#         # self.num_token_labels_dict = get_info(config.label_map)
+#         # self.config = config
+#         # # print(f"I dont think it arrives here: {self.config}")
+#         # self.bert = AutoModel.from_pretrained(
+#         #     config.pretrained_config["_name_or_path"], config=config.pretrained_config
+#         # )
+#         self.model_floret = floret.load_model(self.config.filename)
+#         # print(f"Model loaded: {self.model_floret}")
+#         # if "classifier_dropout" not in config.__dict__:
+#         #     classifier_dropout = 0.1
+#         # else:
+#         #     classifier_dropout = (
+#         #         config.classifier_dropout
+#         #         if config.classifier_dropout is not None
+#         #         else config.hidden_dropout_prob
+#         #     )
+#         # self.dropout = nn.Dropout(classifier_dropout)
+#         #
+#         # # Additional transformer layers
+#         # self.transformer_encoder = nn.TransformerEncoder(
+#         #     nn.TransformerEncoderLayer(
+#         #         d_model=config.hidden_size, nhead=config.num_attention_heads
+#         #     ),
+#         #     num_layers=2,
+#         # )
+#
+#         # For token classification, create a classifier for each task
+#         # self.token_classifiers = nn.ModuleDict(
+#         #     {
+#         #         task: nn.Linear(config.hidden_size, num_labels)
+#         #         for task, num_labels in self.num_token_labels_dict.items()
+#         #     }
+#         # )
+#         #
+#         # # Initialize weights and apply final processing
+#         # self.post_init()
+#
+#     def get_floret_model(self):
+#         return self.model_floret
+#
+#     @classmethod
+#     def from_pretrained(cls, *args, **kwargs):
+#         print("Ignoring weights and using custom initialization.")
+#
+#         # Manually create the config
+#         config = ImpressoConfig()
+#
+#         # Pass the manually created config to the class
+#         model = cls(config)
+#         return model
+#
+#     # def forward(
+#     #     self,
+#     #     input_ids: Optional[torch.Tensor] = None,
+#     #     attention_mask: Optional[torch.Tensor] = None,
+#     #     token_type_ids: Optional[torch.Tensor] = None,
+#     #     position_ids: Optional[torch.Tensor] = None,
+#     #     head_mask: Optional[torch.Tensor] = None,
+#     #     inputs_embeds: Optional[torch.Tensor] = None,
+#     #     labels: Optional[torch.Tensor] = None,
+#     #     token_labels: Optional[dict] = None,
+#     #     output_attentions: Optional[bool] = None,
+#     #     output_hidden_states: Optional[bool] = None,
+#     #     return_dict: Optional[bool] = None,
+#     # ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+#     #     r"""
+#     #     token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
+#     #         Labels for computing the token classification loss. Keys should match the tasks.
+#     #     """
+#     #     return_dict = (
+#     #         return_dict if return_dict is not None else self.config.use_return_dict
+#     #     )
+#     #
+#     #     bert_kwargs = {
+#     #         "input_ids": input_ids,
+#     #         "attention_mask": attention_mask,
+#     #         "token_type_ids": token_type_ids,
+#     #         "position_ids": position_ids,
+#     #         "head_mask": head_mask,
+#     #         "inputs_embeds": inputs_embeds,
+#     #         "output_attentions": output_attentions,
+#     #         "output_hidden_states": output_hidden_states,
+#     #         "return_dict": return_dict,
+#     #     }
+#     #
+#     #     if any(
+#     #         keyword in self.config.name_or_path.lower()
+#     #         for keyword in ["llama", "deberta"]
+#     #     ):
+#     #         bert_kwargs.pop("token_type_ids")
+#     #         bert_kwargs.pop("head_mask")
+#     #
+#     #     outputs = self.bert(**bert_kwargs)
+#     #
+#     #     # For token classification
+#     #     token_output = outputs[0]
+#     #     token_output = self.dropout(token_output)
+#     #
+#     #     # Pass through additional transformer layers
+#     #     token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
+#     #         0, 1
+#     #     )
+#     #
+#     #     # Collect the logits and compute the loss for each task
+#     #     task_logits = {}
+#     #     total_loss = 0
+#     #     for task, classifier in self.token_classifiers.items():
+#     #         logits = classifier(token_output)
+#     #         task_logits[task] = logits
+#     #         if token_labels and task in token_labels:
+#     #             loss_fct = CrossEntropyLoss()
+#     #             loss = loss_fct(
+#     #                 logits.view(-1, self.num_token_labels_dict[task]),
+#     #                 token_labels[task].view(-1),
+#     #             )
+#     #             total_loss += loss
+#     #
+#     #     if not return_dict:
+#     #         output = (task_logits,) + outputs[2:]
+#     #         return ((total_loss,) + output) if total_loss != 0 else output
+#     #     print(f"Is there anobidy coming here?")
+#     #     return TokenClassifierOutput(
+#     #         loss=total_loss,
+#     #         logits=task_logits,
+#     #         hidden_states=outputs.hidden_states,
+#     #         attentions=outputs.attentions,
+#     #     )
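
Usage note (editorial, not part of the commit): the sketch below shows how the reworked floret-backed forward pass appears meant to be driven. The tokenizer checkpoint, the example sentence, and the module-level import are illustrative assumptions; it also presumes that ImpressoConfig() resolves `filename` to a model file that floret.load_model can open, and that the raw floret predictions are something torch.tensor can wrap.

# Hypothetical usage sketch (assumptions noted above; not part of this commit).
from transformers import AutoTokenizer

from modeling_stacked import ExtendedMultitaskModelForTokenClassification

# from_pretrained() now ignores any stored weights and builds the model
# directly from ImpressoConfig(**kwargs); state_dict()/load_state_dict() are no-ops.
model = ExtendedMultitaskModelForTokenClassification.from_pretrained()

# Any Hugging Face tokenizer works here; it is only used to turn input_ids
# back into plain strings inside forward().
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
batch = tokenizer(["Le journal est paru à Genève en 1848."], return_tensors="pt")

# forward() decodes the ids to text and hands each string to the floret model,
# so the tokenizer must be passed alongside the tensors.
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    tokenizer=tokenizer,
)
print(outputs)  # tensor built from the per-text floret predictions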