yangwang825 commited on
Commit
aff5c65
·
1 Parent(s): 4eafb1b

Upload BertForSequenceClassification

Browse files
Files changed (3) hide show
  1. config.json +6 -1
  2. modeling_bert.py +149 -5
  3. pytorch_model.bin +1 -1
config.json CHANGED
@@ -1,9 +1,13 @@
1
  {
2
  "affine": false,
3
  "alpha": 1,
 
 
 
4
  "attention_probs_dropout_prob": 0.1,
5
  "auto_map": {
6
- "AutoConfig": "configuration_bert.BertConfig"
 
7
  },
8
  "center": false,
9
  "classifier_dropout": null,
@@ -27,6 +31,7 @@
27
  "r": 1,
28
  "return_mean": true,
29
  "return_std": true,
 
30
  "transformers_version": "4.33.3",
31
  "type_vocab_size": 2,
32
  "use_cache": true,
 
1
  {
2
  "affine": false,
3
  "alpha": 1,
4
+ "architectures": [
5
+ "BertForSequenceClassification"
6
+ ],
7
  "attention_probs_dropout_prob": 0.1,
8
  "auto_map": {
9
+ "AutoConfig": "configuration_bert.BertConfig",
10
+ "AutoModelForSequenceClassification": "modeling_bert.BertForSequenceClassification"
11
  },
12
  "center": false,
13
  "classifier_dropout": null,
 
31
  "r": 1,
32
  "return_mean": true,
33
  "return_std": true,
34
+ "torch_dtype": "float32",
35
  "transformers_version": "4.33.3",
36
  "type_vocab_size": 2,
37
  "use_cache": true,
modeling_bert.py CHANGED
@@ -1,5 +1,7 @@
1
  import torch
2
  import torch.nn as nn
 
 
3
  from typing import Optional, List, Union, Tuple
4
  from transformers import (
5
  PretrainedConfig,
@@ -46,21 +48,163 @@ class BertPreTrainedModel(PreTrainedModel):
46
  module.weight.data.fill_(1.0)
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  class BertPooler(nn.Module):
50
 
51
  def __init__(self, config):
52
  super().__init__()
53
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  self.activation = nn.Tanh()
 
55
 
56
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
57
  # We "pool" the model by simply taking the hidden state corresponding
58
  # to the first token.
59
- first_token_tensor = hidden_states[:, 0]
60
- pooled_output = self.dense(first_token_tensor)
 
61
  pooled_output = self.activation(pooled_output)
62
  return pooled_output
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  class BertModel(BertPreTrainedModel):
66
 
@@ -180,7 +324,7 @@ class BertModel(BertPreTrainedModel):
180
  return_dict=return_dict,
181
  )
182
  sequence_output = encoder_outputs[0]
183
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
184
 
185
  if not return_dict:
186
  return (sequence_output, pooled_output) + encoder_outputs[1:]
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from collections import OrderedDict
5
  from typing import Optional, List, Union, Tuple
6
  from transformers import (
7
  PretrainedConfig,
 
48
  module.weight.data.fill_(1.0)
49
 
50
 
51
+ class PFSA(nn.Module):
52
+ """
53
+ https://openreview.net/pdf?id=isodM5jTA7h
54
+ """
55
+ def __init__(self, input_dim, alpha=1):
56
+ super(PFSA, self).__init__()
57
+ self.input_dim = input_dim
58
+ self.alpha = alpha
59
+
60
+ def forward(self, x, mask=None):
61
+ """
62
+ x: [B, T, F]
63
+ """
64
+ x = x.transpose(1, 2)[..., None]
65
+ k = torch.mean(x, dim=[-1, -2], keepdim=True)
66
+ kd = torch.sqrt((k - k.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True)) # [B, 1, 1, 1]
67
+ qd = torch.sqrt((x - x.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True)) # [B, 1, T, 1]
68
+ C_qk = (((x - x.mean(dim=1, keepdim=True)) * (k - k.mean(dim=1, keepdim=True))).sum(dim=1, keepdim=True)) / (qd * kd)
69
+ A = (1 - torch.sigmoid(C_qk)) ** self.alpha
70
+ out = x * A
71
+ out = out.squeeze(dim=-1).transpose(1, 2)
72
+ return out
73
+
74
+
75
+ class PURE(nn.Module):
76
+
77
+ def __init__(
78
+ self,
79
+ in_dim,
80
+ q=5,
81
+ r=1,
82
+ center=False,
83
+ num_iters=1,
84
+ return_mean=True,
85
+ return_std=True,
86
+ normalize=False,
87
+ do_pcr=True,
88
+ do_pfsa=True,
89
+ alpha=1,
90
+ *args, **kwargs
91
+ ):
92
+ super().__init__()
93
+ self.in_dim = in_dim
94
+ self.target_rank = q
95
+ self.num_pc_to_remove = r
96
+ self.center = center
97
+ self.num_iters = num_iters
98
+ self.return_mean = return_mean
99
+ self.return_std = return_std
100
+ self.normalize = normalize
101
+ self.do_pcr = do_pcr
102
+ self.do_pfsa = do_pfsa
103
+ # self.attention = SelfAttention(in_dim)
104
+ self.attention = PFSA(in_dim, alpha=alpha)
105
+ self.eps = 1e-5
106
+
107
+ if self.normalize:
108
+ self.norm = nn.Sequential(OrderedDict([
109
+ ('relu', nn.LeakyReLU(inplace=True)),
110
+ ('bn', nn.BatchNorm1d(in_dim)),
111
+ ]))
112
+
113
+ def get_out_dim(self):
114
+ if self.return_mean and self.return_std:
115
+ self.out_dim = self.in_dim * 2
116
+ else:
117
+ self.out_dim = self.in_dim
118
+ return self.out_dim
119
+
120
+ def _compute_pc(self, x):
121
+ """
122
+ x: (B, T, F)
123
+ """
124
+ _, _, V = torch.pca_lowrank(x, q=self.target_rank, center=self.center, niter=self.num_iters)
125
+ pc = V.transpose(1, 2)[:, :self.num_pc_to_remove, :] # pc: [B, K, F]
126
+ return pc
127
+
128
+ def forward(self, x, attention_mask=None, *args, **kwargs):
129
+ """
130
+ PCR -> Attention
131
+ x: (B, F, T)
132
+ """
133
+ if self.normalize:
134
+ x = self.norm(x)
135
+ xt = x.transpose(1, 2)
136
+ if self.do_pcr:
137
+ pc = self._compute_pc(xt) # pc: [B, K, F]
138
+ xx = xt - xt @ pc.transpose(1, 2) @ pc # [B, T, F] * [B, F, K] * [B, K, F] = [B, T, F]
139
+ else:
140
+ xx = xt
141
+ if self.do_pfsa:
142
+ xx = self.attention(xx, attention_mask)
143
+ if self.normalize:
144
+ xx = F.normalize(xx, p=2, dim=2)
145
+ return xx
146
+
147
+
148
  class BertPooler(nn.Module):
149
 
150
  def __init__(self, config):
151
  super().__init__()
152
+ self.pure = PURE(
153
+ config.hidden_size,
154
+ q=config.q,
155
+ r=config.r,
156
+ center=config.center,
157
+ num_iters=config.num_iters,
158
+ return_mean=config.return_mean,
159
+ return_std=config.return_std,
160
+ normalize=config.normalize,
161
+ do_pcr=config.do_pcr,
162
+ do_pfsa=config.do_pfsa,
163
+ alpha=config.alpha
164
+ )
165
+ if config.affine:
166
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
167
+ else:
168
+ self.dense = nn.Identity()
169
  self.activation = nn.Tanh()
170
+ self.eps = 1e-5
171
 
172
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
173
  # We "pool" the model by simply taking the hidden state corresponding
174
  # to the first token.
175
+ hidden_states = self.pure(hidden_states.transpose(1, 2), attention_mask)
176
+ mean_tensor = self.mean_pooling(hidden_states, attention_mask)
177
+ pooled_output = self.dense(mean_tensor)
178
  pooled_output = self.activation(pooled_output)
179
  return pooled_output
180
 
181
+ def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
182
+ """Returns a tensor of epsilon Gaussian noise.
183
+
184
+ Arguments
185
+ ---------
186
+ shape_of_tensor : tensor
187
+ It represents the size of tensor for generating Gaussian noise.
188
+ """
189
+ gnoise = torch.randn(shape_of_tensor, device=device)
190
+ gnoise -= torch.min(gnoise)
191
+ gnoise /= torch.max(gnoise)
192
+ gnoise = self.eps * ((1 - 9) * gnoise + 9)
193
+
194
+ return gnoise
195
+
196
+ def add_noise(self, tensor):
197
+ gnoise = self._get_gauss_noise(tensor.size(), device=tensor.device)
198
+ gnoise = gnoise
199
+ tensor += gnoise
200
+ return tensor
201
+
202
+ def mean_pooling(self, token_embeddings, attention_mask):
203
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
204
+ mean = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
205
+ # mean = self.add_noise(mean)
206
+ return mean
207
+
208
 
209
  class BertModel(BertPreTrainedModel):
210
 
 
324
  return_dict=return_dict,
325
  )
326
  sequence_output = encoder_outputs[0]
327
+ pooled_output = self.pooler(sequence_output, attention_mask) if self.pooler is not None else None
328
 
329
  if not return_dict:
330
  return (sequence_output, pooled_output) + encoder_outputs[1:]
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:31b5cff2bb6cce0d41eceb729d8660438d177910122749eca6916b3f404c0f80
3
  size 438000689
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64dd3354da4b868afe78cc83d9e51ed4ca20cab88015a22a38257b205c9eadd4
3
  size 438000689