yangwang825 committed
Commit d656ed5 · verified · 1 Parent(s): 008bba7

Upload PureRobertaForSequenceClassification

Files changed (3)
  1. config.json +5 -2
  2. model.safetensors +3 -0
  3. modeling_pure_roberta.py +445 -0
config.json CHANGED
@@ -1,11 +1,13 @@
  {
+ "_name_or_path": "roberta-base",
  "alpha": 1,
  "architectures": [
- "RobertaForMaskedLM"
+ "PureRobertaForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "auto_map": {
- "AutoConfig": "configuration_pure_roberta.PureRobertaConfig"
+ "AutoConfig": "configuration_pure_roberta.PureRobertaConfig",
+ "AutoModelForSequenceClassification": "modeling_pure_roberta.PureRobertaForSequenceClassification"
  },
  "bos_token_id": 0,
  "center": false,
@@ -29,6 +31,7 @@
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "svd_rank": 5,
+ "torch_dtype": "float32",
  "transformers_version": "4.44.2",
  "type_vocab_size": 1,
  "use_cache": true,
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34e1a941c91a240017e269af15d41dfb8fdb7f510f51f4c820fc0bb96e8827c4
+ size 498612824
modeling_pure_roberta.py ADDED
@@ -0,0 +1,445 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from torch.autograd import Function
+ from transformers import PreTrainedModel
+ from transformers.models.roberta.modeling_roberta import (
+     RobertaModel,
+     RobertaClassificationHead,
+ )
+ from typing import Union, Tuple, Optional
+ from transformers.modeling_outputs import (
+     SequenceClassifierOutput,
+     MultipleChoiceModelOutput,
+     QuestionAnsweringModelOutput
+ )
+ from transformers.utils import ModelOutput
+
+ from .configuration_pure_roberta import PureRobertaConfig
+
+
+ class CovarianceFunction(Function):
+
+     @staticmethod
+     def forward(ctx, inputs):
+         x = inputs
+         b, c, h, w = x.data.shape
+         m = h * w
+         x = x.view(b, c, m)
+         I_hat = (-1.0 / m / m) * torch.ones(m, m, device=x.device) + (
+             1.0 / m
+         ) * torch.eye(m, m, device=x.device)
+         I_hat = I_hat.view(1, m, m).repeat(b, 1, 1).type(x.dtype)
+         y = x @ I_hat @ x.transpose(-1, -2)
+         ctx.save_for_backward(inputs, I_hat)
+         return y
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         inputs, I_hat = ctx.saved_tensors
+         x = inputs
+         b, c, h, w = x.data.shape
+         m = h * w
+         x = x.view(b, c, m)
+         grad_input = grad_output + grad_output.transpose(1, 2)
+         grad_input = grad_input @ x @ I_hat
+         grad_input = grad_input.reshape(b, c, h, w)
+         return grad_input
+
+
+ class Covariance(nn.Module):
+
+     def __init__(self):
+         super(Covariance, self).__init__()
+
+     def _covariance(self, x):
+         return CovarianceFunction.apply(x)
+
+     def forward(self, x):
+         # x should be [batch_size, seq_len, embed_dim]
+         if x.dim() == 2:
+             x = x.transpose(-1, -2)
+         C = self._covariance(x[None, :, :, None])
+         C = C.squeeze(dim=0)
+         return C
+
+
+ class PFSA(torch.nn.Module):
+     """
+     https://openreview.net/pdf?id=isodM5jTA7h
+     """
+     def __init__(self, input_dim, alpha=1):
+         super(PFSA, self).__init__()
+         self.input_dim = input_dim
+         self.alpha = alpha
+
+     def forward_one_sample(self, x):
+         x = x.transpose(1, 2)[..., None]
+         k = torch.mean(x, dim=[-1, -2], keepdim=True)
+         kd = torch.sqrt((k - k.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True))  # [B, 1, 1, 1]
+         qd = torch.sqrt((x - x.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True))  # [B, 1, T, 1]
+         C_qk = (((x - x.mean(dim=1, keepdim=True)) * (k - k.mean(dim=1, keepdim=True))).sum(dim=1, keepdim=True)) / (qd * kd)
+         A = (1 - torch.sigmoid(C_qk)) ** self.alpha
+         out = x * A
+         out = out.squeeze(dim=-1).transpose(1, 2)
+         return out
+
+     def forward(self, input_values, attention_mask=None):
+         """
+         x: [B, T, F]
+         """
+         out = []
+         b, t, f = input_values.shape
+         for x, mask in zip(input_values, attention_mask):
+             x = x.view(1, t, f)
+             # x_in = x[:, :sum(mask), :]
+             x_in = x[:, :int(mask.sum().item()), :]
+             x_out = self.forward_one_sample(x_in)
+             x_expanded = torch.zeros_like(x, device=x.device)
+             x_expanded[:, :x_out.shape[-2], :x_out.shape[-1]] = x_out
+             out.append(x_expanded)
+         out = torch.vstack(out)
+         out = out.view(b, t, f)
+         return out
+
+
+ class PURE(torch.nn.Module):
+
+     def __init__(
+         self,
+         in_dim,
+         svd_rank=16,
+         num_pc_to_remove=1,
+         center=False,
+         num_iters=2,
+         alpha=1,
+         disable_pcr=False,
+         disable_pfsa=False,
+         disable_covariance=True,
+         *args, **kwargs
+     ):
+         super().__init__()
+         self.in_dim = in_dim
+         self.svd_rank = svd_rank
+         self.num_pc_to_remove = num_pc_to_remove
+         self.center = center
+         self.num_iters = num_iters
+         self.do_pcr = not disable_pcr
+         self.do_pfsa = not disable_pfsa
+         self.do_covariance = not disable_covariance
+         self.attention = PFSA(in_dim, alpha=alpha)
+
+     def _compute_pc(self, X, attention_mask):
+         """
+         x: (B, T, F)
+         """
+         pcs = []
+         bs, seqlen, dim = X.shape
+         for x, mask in zip(X, attention_mask):
+             rank = int(mask.sum().item())
+             x = x[:rank, :]
+             if self.do_covariance:
+                 x = Covariance()(x)
+                 q = self.svd_rank
+             else:
+                 q = min(self.svd_rank, rank)
+             _, _, V = torch.pca_lowrank(x, q=q, center=self.center, niter=self.num_iters)
+             # _, _, Vh = torch.linalg.svd(x_, full_matrices=False)
+             # V = Vh.mH
+             pc = V.transpose(0, 1)[:self.num_pc_to_remove, :]  # pc: [K, F]
+             pcs.append(pc)
+         # pcs = torch.vstack(pcs)
+         # pcs = pcs.view(bs, self.num_pc_to_remove, dim)
+         return pcs
+
+     def _remove_pc(self, X, pcs):
+         """
+         [B, T, F], [B, ..., F]
+         """
+         b, t, f = X.shape
+         out = []
+         for i, (x, pc) in enumerate(zip(X, pcs)):
+             # v = []
+             # for j, t in enumerate(x):
+             #     t_ = t
+             #     for c_ in c:
+             #         t_ = t_.view(f, 1) - c_.view(f, 1) @ c_.view(1, f) @ t.view(f, 1)
+             #     v.append(t_.transpose(-1, -2))
+             # v = torch.vstack(v)
+             v = x - x @ pc.transpose(0, 1) @ pc
+             out.append(v[None, ...])
+         out = torch.vstack(out)
+         return out
+
+     def forward(self, input_values, attention_mask=None, *args, **kwargs):
+         """
+         PCR -> Attention
+         x: (B, T, F)
+         """
+         x = input_values
+         if self.do_pcr:
+             pc = self._compute_pc(x, attention_mask)  # pc: [B, K, F]
+             xx = self._remove_pc(x, pc)
+             # xx = xt - xt @ pc.transpose(1, 2) @ pc  # [B, T, F] * [B, F, K] * [B, K, F] = [B, T, F]
+         else:
+             xx = x
+         if self.do_pfsa:
+             xx = self.attention(xx, attention_mask)
+         return xx
+
+
+ class StatisticsPooling(torch.nn.Module):
+
+     def __init__(self, return_mean=True, return_std=True):
+         super().__init__()
+
+         # Small value for GaussNoise
+         self.eps = 1e-5
+         self.return_mean = return_mean
+         self.return_std = return_std
+         if not (self.return_mean or self.return_std):
+             raise ValueError(
+                 "both of statistics are equal to False \n"
+                 "consider enabling mean and/or std statistic pooling"
+             )
+
+     def forward(self, input_values, attention_mask=None):
+         """Calculates mean and std for a batch (input tensor).
+
+         Arguments
+         ---------
+         x : torch.Tensor
+             It represents a tensor for a mini-batch.
+         """
+         x = input_values
+         if attention_mask is None:
+             if self.return_mean:
+                 mean = x.mean(dim=1)
+             if self.return_std:
+                 std = x.std(dim=1)
+         else:
+             mean = []
+             std = []
+             for snt_id in range(x.shape[0]):
+                 # Avoiding padded time steps
+                 lengths = torch.sum(attention_mask, dim=1)
+                 relative_lengths = lengths / torch.max(lengths)
+                 actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()
+                 # actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
+
+                 # computing statistics
+                 if self.return_mean:
+                     mean.append(
+                         torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
+                     )
+                 if self.return_std:
+                     std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
+             if self.return_mean:
+                 mean = torch.stack(mean)
+             if self.return_std:
+                 std = torch.stack(std)
+
+         if self.return_mean:
+             gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
+             gnoise = gnoise
+             mean += gnoise
+         if self.return_std:
+             std = std + self.eps
+
+         # Append mean and std of the batch
+         if self.return_mean and self.return_std:
+             pooled_stats = torch.cat((mean, std), dim=1)
+             pooled_stats = pooled_stats.unsqueeze(1)
+         elif self.return_mean:
+             pooled_stats = mean.unsqueeze(1)
+         elif self.return_std:
+             pooled_stats = std.unsqueeze(1)
+
+         return pooled_stats
+
+     def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
+         """Returns a tensor of epsilon Gaussian noise.
+
+         Arguments
+         ---------
+         shape_of_tensor : tensor
+             It represents the size of tensor for generating Gaussian noise.
+         """
+         gnoise = torch.randn(shape_of_tensor, device=device)
+         gnoise -= torch.min(gnoise)
+         gnoise /= torch.max(gnoise)
+         gnoise = self.eps * ((1 - 9) * gnoise + 9)
+
+         return gnoise
+
+
+ class PureRobertaPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = PureRobertaConfig
+     base_model_prefix = "pure_roberta"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention", "RobertaSdpaSelfAttention"]
+     _supports_sdpa = True
+
+     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Linear):
+             # Slightly different from the TF version which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+
+ class PureRobertaForSequenceClassification(PureRobertaPreTrainedModel):
+
+     def __init__(
+         self,
+         config,
+         label_smoothing=0.0,
+     ):
+         super().__init__(config)
+         self.label_smoothing = label_smoothing
+         self.num_labels = config.num_labels
+         self.config = config
+
+         self.roberta = RobertaModel(config, add_pooling_layer=False)
+         self.pure = PURE(
+             in_dim=config.hidden_size,
+             svd_rank=config.svd_rank,
+             num_pc_to_remove=config.num_pc_to_remove,
+             center=config.center,
+             num_iters=config.num_iters,
+             alpha=config.alpha,
+             disable_pcr=config.disable_pcr,
+             disable_pfsa=config.disable_pfsa,
+             disable_covariance=config.disable_covariance
+         )
+         self.mean = StatisticsPooling(return_mean=True, return_std=False)
+         self.classifier = RobertaClassificationHead(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward_pure_embeddings(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.roberta(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         token_embeddings = outputs.last_hidden_state
+         token_embeddings = self.pure(token_embeddings, attention_mask)
+
+         return ModelOutput(
+             last_hidden_state=token_embeddings,
+         )
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.roberta(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         token_embeddings = outputs.last_hidden_state
+         token_embeddings = self.pure(token_embeddings, attention_mask)
+         pooled_output = self.mean(token_embeddings).squeeze(1)
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = nn.MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = nn.BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
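
The core of the uploaded module is the principal-component-removal (PCR) step implemented in PURE._compute_pc and PURE._remove_pc: estimate the top singular directions of the unpadded token embeddings and project them out before pooling and classification. Below is a standalone sketch of that step for a single sequence; the tensor shapes and seed are illustrative only, svd_rank mirrors "svd_rank": 5 from config.json, and num_pc_to_remove=1, center=False, niter=2 follow the PURE defaults with covariance disabled.

import torch

torch.manual_seed(0)
x = torch.randn(12, 768)            # one sequence of 12 unpadded token embeddings, [T, F]
svd_rank, num_pc_to_remove = 5, 1   # "svd_rank": 5 from config.json; remove one component

# Top right-singular vectors of the token matrix, as in PURE._compute_pc.
_, _, V = torch.pca_lowrank(x, q=svd_rank, center=False, niter=2)  # V: [F, q]
pc = V.transpose(0, 1)[:num_pc_to_remove, :]                       # directions to remove, [K, F]

# Same projection as PURE._remove_pc: subtract each token's component along those directions.
x_clean = x - x @ pc.transpose(0, 1) @ pc                          # [T, F]
print(x_clean.shape)                                               # torch.Size([12, 768])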