bozhou committed on
Commit 23fe031
1 Parent(s): acb98db

Upload 22 files

modeling/__init__.py ADDED
@@ -0,0 +1,37 @@
+ #
+ # Zhou Bo
+
+ #
+
+ """ Components for NN
+ """
+
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ from .tokenizers import *
+ from .pooling import *
+ from .mlm import MLMPredictionHead
+ from .nnmodule import NNModule
+ from .deberta import *
+ from .disentangled_attention import *
+ from .ops import *
+ from .bert import *
+ from .config import *
+ from .cache_utils import *
+ from .focal_loss import *
+ # from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
+ from .modeling import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM,
+                        BertForNextSentencePrediction, PreTrainedBertModel,
+                        BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification,
+                        BertForQuestionAnswering, BertForPreTrainingLossMask, BertPreTrainingPairRel,
+                        BertPreTrainingPairTransform, BertPreTrainingHeads, MLMHead)
+ # from .optimization import BertAdam, BertAdamFineTune
+ try:
+   from .optimization_fp16 import FP16_Optimizer_State
+ except:
+   pass
+ from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
+ from .flash import FlashQuadModel
+ from .gat import GatModel
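With these re-exports in place, downstream code only needs the package root. A minimal usage sketch, assuming the uploaded directory is importable as `modeling` and the sibling modules it pulls in (`ops`, `tokenizers`, `flash`, etc.) plus their dependencies are available:

    from modeling import DeBERTa, ModelConfig

    config = ModelConfig()
    config.vocab_size = 128100     # the default is -1, so a real vocabulary size must be set
    model = DeBERTa(config=config) # randomly initialized encoder; no pre-trained weights loaded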
modeling/bert.py ADDED
@@ -0,0 +1,307 @@
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+ # Copyright (c) Microsoft, Inc. 2020
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # This piece of code is modified based on https://github.com/huggingface/transformers
+
+ import copy
+ import torch
+ from torch import nn
+ from collections.abc import Sequence
+ from packaging import version
+ import numpy as np
+ import math
+ import os
+ import pdb
+
+ import json
+ from .ops import *
+ from .disentangled_attention import *
+ from .da_utils import *
+
+ __all__ = ['BertEncoder', 'BertEmbeddings', 'ACT2FN', 'LayerNorm', 'BertLMPredictionHead']
+
+ class BertSelfOutput(nn.Module):
+   def __init__(self, config):
+     super().__init__()
+     self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+     self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+     self.dropout = StableDropout(config.hidden_dropout_prob)
+     self.config = config
+
+   def forward(self, hidden_states, input_states, mask=None):
+     hidden_states = self.dense(hidden_states)
+     hidden_states = self.dropout(hidden_states)
+     hidden_states += input_states
+     hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
+     return hidden_states
+
+ class BertAttention(nn.Module):
+   def __init__(self, config):
+     super().__init__()
+     self.self = DisentangledSelfAttention(config)
+     self.output = BertSelfOutput(config)
+     self.config = config
+
+   def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
+     output = self.self(hidden_states, attention_mask, return_att, query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
+     self_output, att_matrix, att_logits_ = output['hidden_states'], output['attention_probs'], output['attention_logits']
+     if query_states is None:
+       query_states = hidden_states
+     attention_output = self.output(self_output, query_states, attention_mask)
+
+     if return_att:
+       return (attention_output, att_matrix)
+     else:
+       return attention_output
+
+ class BertIntermediate(nn.Module):
+   def __init__(self, config):
+     super().__init__()
+     self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+     self.intermediate_act_fn = ACT2FN[config.hidden_act] \
+       if isinstance(config.hidden_act, str) else config.hidden_act
+
+   def forward(self, hidden_states):
+     hidden_states = self.dense(hidden_states)
+     hidden_states = self.intermediate_act_fn(hidden_states)
+     return hidden_states
+
+ class BertOutput(nn.Module):
+   def __init__(self, config):
+     super(BertOutput, self).__init__()
+     self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+     self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+     self.dropout = StableDropout(config.hidden_dropout_prob)
+     self.config = config
+
+   def forward(self, hidden_states, input_states, mask=None):
+     hidden_states = self.dense(hidden_states)
+     hidden_states = self.dropout(hidden_states)
+     hidden_states += input_states
+     hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
+     return hidden_states
+
+ class BertLayer(nn.Module):
+   def __init__(self, config):
+     super(BertLayer, self).__init__()
+     self.attention = BertAttention(config)
+     self.intermediate = BertIntermediate(config)
+     self.output = BertOutput(config)
+
+   def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
+     attention_output = self.attention(hidden_states, attention_mask, return_att=return_att, \
+       query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
+     if return_att:
+       attention_output, att_matrix = attention_output
+     intermediate_output = self.intermediate(attention_output)
+     layer_output = self.output(intermediate_output, attention_output, attention_mask)
+     if return_att:
+       return (layer_output, att_matrix)
+     else:
+       return layer_output
+
+ class ConvLayer(nn.Module):
+   def __init__(self, config):
+     super().__init__()
+     kernel_size = getattr(config, 'conv_kernel_size', 3)
+     groups = getattr(config, 'conv_groups', 1)
+     self.conv_act = getattr(config, 'conv_act', 'tanh')
+     self.conv = torch.nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size, padding = (kernel_size-1)//2, groups = groups)
+     self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+     self.dropout = StableDropout(config.hidden_dropout_prob)
+     self.config = config
+
+   def forward(self, hidden_states, residual_states, input_mask):
+     out = self.conv(hidden_states.permute(0,2,1).contiguous()).permute(0,2,1).contiguous()
+     if version.Version(torch.__version__) >= version.Version('1.2.0a'):
+       rmask = (1-input_mask).bool()
+     else:
+       rmask = (1-input_mask).byte()
+     out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
+     out = ACT2FN[self.conv_act](self.dropout(out))
+     output_states = MaskedLayerNorm(self.LayerNorm, residual_states + out, input_mask)
+
+     return output_states
+
+ class BertEncoder(nn.Module):
+   """ Modified BertEncoder with relative position bias support
+   """
+   def __init__(self, config):
+     super().__init__()
+     #layer = BertLayer(config)
+     self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+     self.relative_attention = getattr(config, 'relative_attention', False)
+     if self.relative_attention:
+       self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
+       if self.max_relative_positions <1:
+         self.max_relative_positions = config.max_position_embeddings
+       self.position_buckets = getattr(config, 'position_buckets', -1)
+       pos_ebd_size = self.max_relative_positions*2
+       if self.position_buckets>0:
+         pos_ebd_size = self.position_buckets*2
+       self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
+
+     self.norm_rel_ebd = [x.strip() for x in getattr(config, 'norm_rel_ebd', 'none').lower().split('|')]
+     if 'layer_norm' in self.norm_rel_ebd:
+       self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine = True)
+     kernel_size = getattr(config, 'conv_kernel_size', 0)
+     self.with_conv = False
+     if kernel_size > 0:
+       self.with_conv = True
+       self.conv = ConvLayer(config)
+
+   def get_rel_embedding(self):
+     rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
+     if rel_embeddings is not None and ('layer_norm' in self.norm_rel_ebd):
+       rel_embeddings = self.LayerNorm(rel_embeddings)
+     return rel_embeddings
+
+   def get_attention_mask(self, attention_mask):
+     if attention_mask.dim()<=2:
+       extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+       attention_mask = extended_attention_mask*extended_attention_mask.squeeze(-2).unsqueeze(-1)
+       attention_mask = attention_mask.byte()
+     elif attention_mask.dim()==3:
+       attention_mask = attention_mask.unsqueeze(1)
+
+     return attention_mask
+
+   def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
+     if self.relative_attention and relative_pos is None:
+       q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
+       relative_pos = build_relative_position(q, hidden_states.size(-2), bucket_size = self.position_buckets, max_position=self.max_relative_positions)
+     return relative_pos
+
+   def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False, query_states = None, relative_pos=None):
+     if attention_mask.dim()<=2:
+       input_mask = attention_mask
+     else:
+       input_mask = (attention_mask.sum(-2)>0).byte()
+     attention_mask = self.get_attention_mask(attention_mask)
+     relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
+
+     all_encoder_layers = []
+     att_matrices = []
+     if isinstance(hidden_states, Sequence):
+       next_kv = hidden_states[0]
+     else:
+       next_kv = hidden_states
+     rel_embeddings = self.get_rel_embedding()
+     for i, layer_module in enumerate(self.layer):
+       output_states = layer_module(next_kv, attention_mask, return_att, query_states = query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
+       if return_att:
+         output_states, att_m = output_states
+
+       if i == 0 and self.with_conv:
+         prenorm = output_states #output['prenorm_states']
+         output_states = self.conv(hidden_states, prenorm, input_mask)
+
+       if query_states is not None:
+         query_states = output_states
+         if isinstance(hidden_states, Sequence):
+           next_kv = hidden_states[i+1] if i+1 < len(self.layer) else None
+       else:
+         next_kv = output_states
+
+       if output_all_encoded_layers:
+         all_encoder_layers.append(output_states)
+         if return_att:
+           att_matrices.append(att_m)
+     if not output_all_encoded_layers:
+       all_encoder_layers.append(output_states)
+       if return_att:
+         att_matrices.append(att_m)
+     return {
+       'hidden_states': all_encoder_layers,
+       'attention_matrices': att_matrices
+     }
+
+ class BertEmbeddings(nn.Module):
+   """Construct the embeddings from word, position and token_type embeddings.
+   """
+   def __init__(self, config):
+     super(BertEmbeddings, self).__init__()
+     padding_idx = getattr(config, 'padding_idx', 0)
+     self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
+     self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx = padding_idx)
+     self.position_biased_input = getattr(config, 'position_biased_input', True)
+     self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
+
+     if config.type_vocab_size>0:
+       self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
+
+     if self.embedding_size != config.hidden_size:
+       self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
+     self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+     self.dropout = StableDropout(config.hidden_dropout_prob)
+     self.output_to_half = False
+     self.config = config
+
+   def forward(self, input_ids, token_type_ids=None, position_ids=None, mask = None):
+     seq_length = input_ids.size(1)
+     if position_ids is None:
+       position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
+       position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+     if token_type_ids is None:
+       token_type_ids = torch.zeros_like(input_ids)
+
+     words_embeddings = self.word_embeddings(input_ids)
+     position_embeddings = self.position_embeddings(position_ids.long())
+
+     embeddings = words_embeddings
+     if self.config.type_vocab_size>0:
+       token_type_embeddings = self.token_type_embeddings(token_type_ids)
+       embeddings += token_type_embeddings
+
+     if self.position_biased_input:
+       embeddings += position_embeddings
+
+     if self.embedding_size != self.config.hidden_size:
+       embeddings = self.embed_proj(embeddings)
+     embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, mask)
+     embeddings = self.dropout(embeddings)
+     return {
+       'embeddings': embeddings,
+       'position_embeddings': position_embeddings}
+
+ class BertLMPredictionHead(nn.Module):
+   def __init__(self, config, vocab_size):
+     super().__init__()
+     self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
+     self.dense = nn.Linear(config.hidden_size, self.embedding_size)
+     self.transform_act_fn = ACT2FN[config.hidden_act] \
+       if isinstance(config.hidden_act, str) else config.hidden_act
+
+     self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps, elementwise_affine=True)
+
+     self.bias = nn.Parameter(torch.zeros(vocab_size))
+
+   def forward(self, hidden_states, embeding_weight):
+     hidden_states = self.dense(hidden_states)
+     hidden_states = self.transform_act_fn(hidden_states)
+     # b x s x d
+     hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
+
+     # b x s x v
+     logits = torch.matmul(hidden_states, embeding_weight.t().to(hidden_states)) + self.bias
+     return logits
+
+
+ class AR_MASK(object):
+   def get_attention_mask(self, input_ids=None, token_type_ids=None):
+     seq_len = input_ids.size(1)
+     # idxs = torch.arange(0, seq_len)
+     # mask = idxs[None, :] <= idxs[:, None]
+     mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.uint8)).to(input_ids.device)
+     mask = mask.unsqueeze(0).expand(input_ids.size(0), seq_len, seq_len)
+     return mask
+     # torch.diagonal(torch.ones([input_ids.size(1), input_ids.size(1)])).byte().to(input_ids.device)
+
+ class Prefix_MASK(object):
+   def get_attention_mask(self, input_ids=None, token_type_ids=None):
+     idxs = torch.cumsum(token_type_ids, axis=1)
+     mask = idxs[:, None, :] <= idxs[:, :, None]
+     return mask.byte().to(input_ids.device)
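Note that `BertEncoder.forward` returns a dictionary rather than a tuple: `hidden_states` is a list with one entry per layer (or only the last layer when `output_all_encoded_layers=False`), and `attention_matrices` is populated only when `return_att=True`. A minimal sketch of driving the encoder directly, assuming the sibling `ops` module provides `LayerNorm`, `StableDropout`, `MaskedLayerNorm`, `XSoftmax` and `ACT2FN` as used above and the package is importable as `modeling`:

    import torch
    from modeling.bert import BertEncoder
    from modeling.config import ModelConfig

    cfg = ModelConfig()
    cfg.hidden_size = 64            # keep the toy encoder small
    cfg.intermediate_size = 128
    cfg.num_attention_heads = 4
    cfg.num_hidden_layers = 2

    encoder = BertEncoder(cfg)
    hidden = torch.randn(2, 8, cfg.hidden_size)    # [batch, seq_len, hidden]
    mask = torch.ones(2, 8, dtype=torch.long)      # 1 = real token, 0 = padding
    out = encoder(hidden, mask)
    print(len(out['hidden_states']))               # 2: one output per layer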
modeling/cache_utils.py ADDED
@@ -0,0 +1,132 @@
+ # Copyright (c) Microsoft, Inc. 2020
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ #
+ # Zhou Bo
+ # Date: 05/15/2020
+ #
+
+ import pdb
+ import torch
+ import os
+ import requests
+ from .config import ModelConfig
+ import pathlib
+ from ..utils import xtqdm as tqdm
+ from zipfile import ZipFile
+ import loguru
+ # from ..utils import get_logger
+ logger = loguru.logger
+
+ __all__ = ['pretrained_models', 'load_model_state', 'load_vocab']
+
+ class PretrainedModel:
+   def __init__(self, name, vocab, vocab_type, model='pytorch_model.bin', config='config.json', **kwargs):
+     self.__dict__.update(kwargs)
+     host = f'https://huggingface.co/microsoft/{name}/resolve/main/'
+     self.name = name
+     self.model_url = host + model
+     self.config_url = host + config
+     self.vocab_url = host + vocab
+     self.vocab_type = vocab_type
+
+ pretrained_models= {
+   'base': PretrainedModel('deberta-base', 'bpe_encoder.bin', 'gpt2'),
+   'large': PretrainedModel('deberta-large', 'bpe_encoder.bin', 'gpt2'),
+   'xlarge': PretrainedModel('deberta-xlarge', 'bpe_encoder.bin', 'gpt2'),
+   'base-mnli': PretrainedModel('deberta-base-mnli', 'bpe_encoder.bin', 'gpt2'),
+   'large-mnli': PretrainedModel('deberta-large-mnli', 'bpe_encoder.bin', 'gpt2'),
+   'xlarge-mnli': PretrainedModel('deberta-xlarge-mnli', 'bpe_encoder.bin', 'gpt2'),
+   'xlarge-v2': PretrainedModel('deberta-v2-xlarge', 'spm.model', 'spm'),
+   'xxlarge-v2': PretrainedModel('deberta-v2-xxlarge', 'spm.model', 'spm'),
+   'xlarge-v2-mnli': PretrainedModel('deberta-v2-xlarge-mnli', 'spm.model', 'spm'),
+   'xxlarge-v2-mnli': PretrainedModel('deberta-v2-xxlarge-mnli', 'spm.model', 'spm'),
+   'deberta-v3-small': PretrainedModel('deberta-v3-small', 'spm.model', 'spm'),
+   'deberta-v3-base': PretrainedModel('deberta-v3-base', 'spm.model', 'spm'),
+   'deberta-v3-large': PretrainedModel('deberta-v3-large', 'spm.model', 'spm'),
+   'mdeberta-v3-base': PretrainedModel('mdeberta-v3-base', 'spm.model', 'spm'),
+   'deberta-v3-xsmall': PretrainedModel('deberta-v3-xsmall', 'spm.model', 'spm'),
+ }
+
+ def download_asset(url, name, tag=None, no_cache=False, cache_dir=None):
+   _tag = tag
+   if _tag is None:
+     _tag = 'latest'
+   if not cache_dir:
+     cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/')
+   os.makedirs(cache_dir, exist_ok=True)
+   output=os.path.join(cache_dir, name)
+   if os.path.exists(output) and (not no_cache):
+     return output
+
+   #repo=f'https://huggingface.co/microsoft/deberta-{name}/blob/main/bpe_encoder.bin'
+   headers = {}
+   headers['Accept'] = 'application/octet-stream'
+   resp = requests.get(url, stream=True, headers=headers)
+   if resp.status_code != 200:
+     raise Exception(f'Request for {url} return {resp.status_code}, {resp.text}')
+
+   try:
+     with open(output, 'wb') as fs:
+       progress = tqdm(total=int(resp.headers['Content-Length']) if 'Content-Length' in resp.headers else -1, ncols=80, desc=f'Downloading {name}')
+       for c in resp.iter_content(chunk_size=1024*1024):
+         fs.write(c)
+         progress.update(len(c))
+       progress.close()
+   except:
+     os.remove(output)
+     raise
+
+   return output
+
+ def load_model_state(path_or_pretrained_id, tag=None, no_cache=False, cache_dir=None):
+   model_path = path_or_pretrained_id
+   if model_path and (not os.path.exists(model_path)) and (path_or_pretrained_id.lower() in pretrained_models):
+     _tag = tag
+     pretrained = pretrained_models[path_or_pretrained_id.lower()]
+     if _tag is None:
+       _tag = 'latest'
+     if not cache_dir:
+       cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/{pretrained.name}')
+     os.makedirs(cache_dir, exist_ok=True)
+     model_path = os.path.join(cache_dir, 'pytorch_model.bin')
+     if (not os.path.exists(model_path)) or no_cache:
+       asset = download_asset(pretrained.model_url, 'pytorch_model.bin', tag=tag, no_cache=no_cache, cache_dir=cache_dir)
+       asset = download_asset(pretrained.config_url, 'model_config.json', tag=tag, no_cache=no_cache, cache_dir=cache_dir)
+   elif not model_path:
+     return None,None
+
+   config_path = os.path.join(os.path.dirname(model_path), 'model_config.json')
+   model_state = torch.load(model_path, map_location='cpu')
+   logger.info("Loaded pretrained model file {}".format(model_path))
+   if 'config' in model_state:
+     model_config = ModelConfig.from_dict(model_state['config'])
+   elif os.path.exists(config_path):
+     model_config = ModelConfig.from_json_file(config_path)
+   else:
+     model_config = None
+   return model_state, model_config
+
+ def load_vocab(vocab_path=None, vocab_type=None, pretrained_id=None, tag=None, no_cache=False, cache_dir=None):
+   if pretrained_id and (pretrained_id.lower() in pretrained_models):
+     _tag = tag
+     if _tag is None:
+       _tag = 'latest'
+
+     pretrained = pretrained_models[pretrained_id.lower()]
+     if not cache_dir:
+       cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/{pretrained.name}')
+     os.makedirs(cache_dir, exist_ok=True)
+     vocab_type = pretrained.vocab_type
+     url = pretrained.vocab_url
+     outname = os.path.basename(url)
+     vocab_path =os.path.join(cache_dir, outname)
+     if (not os.path.exists(vocab_path)) or no_cache:
+       asset = download_asset(url, outname, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
+   if vocab_type is None:
+     vocab_type = 'spm'
+   return vocab_path, vocab_type
+
+ def test_download():
+   vocab = load_vocab()
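`pretrained_models` maps short names to the corresponding huggingface.co URLs, and `load_model_state`/`load_vocab` download the files on first use and cache them under `~/.~DeBERTa/assets/<tag>/<model>`. An illustrative sketch, assuming network access on the first call and the parent package's `xtqdm` wrapper being importable:

    state_dict, model_config = load_model_state('base')        # checkpoint + ModelConfig for deberta-base
    vocab_path, vocab_type = load_vocab(pretrained_id='base')  # e.g. ('.../bpe_encoder.bin', 'gpt2')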
modeling/config.py ADDED
@@ -0,0 +1,90 @@
+ import json
+ import copy
+
+ __all__=['AbsModelConfig', 'ModelConfig']
+
+ class AbsModelConfig(object):
+   def __init__(self):
+     pass
+
+   @classmethod
+   def from_dict(cls, json_object):
+     """Constructs a `ModelConfig` from a Python dictionary of parameters."""
+     config = cls()
+     for key, value in json_object.items():
+       if isinstance(value, dict):
+         value = AbsModelConfig.from_dict(value)
+       config.__dict__[key] = value
+     return config
+
+   @classmethod
+   def from_json_file(cls, json_file):
+     """Constructs a `ModelConfig` from a json file of parameters."""
+     with open(json_file, "r", encoding='utf-8') as reader:
+       text = reader.read()
+     return cls.from_dict(json.loads(text))
+
+   def __repr__(self):
+     return str(self.to_json_string())
+
+   def to_dict(self):
+     """Serializes this instance to a Python dictionary."""
+     output = copy.deepcopy(self.__dict__)
+     return output
+
+   def to_json_string(self):
+     """Serializes this instance to a JSON string."""
+     def _json_default(obj):
+       if isinstance(obj, AbsModelConfig):
+         return obj.__dict__
+     return json.dumps(self.__dict__, indent=2, sort_keys=True, default=_json_default) + "\n"
+
+ class ModelConfig(AbsModelConfig):
+   """Configuration class to store the configuration of a :class:`~DeBERTa.deberta.DeBERTa` model.
+
+   Attributes:
+     hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`.
+     num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`.
+     num_attention_heads (int): Number of attention heads for each attention layer in
+       the Transformer encoder, default: `12`.
+     intermediate_size (int): The size of the "intermediate" (i.e., feed-forward)
+       layer in the Transformer encoder, default: `3072`.
+     hidden_act (str): The non-linear activation function (function or string) in the
+       encoder and pooler. If string, "gelu", "relu" and "swish" are supported, default: `gelu`.
+     hidden_dropout_prob (float): The dropout probability for all fully connected
+       layers in the embeddings, encoder, and pooler, default: `0.1`.
+     attention_probs_dropout_prob (float): The dropout ratio for the attention
+       probabilities, default: `0.1`.
+     max_position_embeddings (int): The maximum sequence length that this model might
+       ever be used with. Typically set this to something large just in case
+       (e.g., 512 or 1024 or 2048), default: `512`.
+     type_vocab_size (int): The vocabulary size of the `token_type_ids` passed into
+       `DeBERTa` model, default: `-1`.
+     initializer_range (int): The stdev of the _normal_initializer for
+       initializing all weight matrices, default: `0.02`.
+     relative_attention (:obj:`bool`): Whether to use relative position encoding, default: `False`.
+     max_relative_positions (int): The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`], default: -1, use the same value as `max_position_embeddings`.
+     padding_idx (int): The value used to pad input_ids, default: `0`.
+     position_biased_input (:obj:`bool`): Whether to add absolute position embeddings to the content embeddings, default: `True`.
+     pos_att_type (:obj:`str`): The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p", default: "None".
+
+
+   """
+   def __init__(self):
+     """Constructs ModelConfig.
+
+     """
+
+     self.hidden_size = 768
+     self.num_hidden_layers = 12
+     self.num_attention_heads = 12
+     self.hidden_act = "gelu"
+     self.intermediate_size = 3072
+     self.hidden_dropout_prob = 0.1
+     self.attention_probs_dropout_prob = 0.1
+     self.max_position_embeddings = 512
+     self.type_vocab_size = 0
+     self.initializer_range = 0.02
+     self.layer_norm_eps = 1e-7
+     self.padding_idx = 0
+     self.vocab_size = -1
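Because `from_dict` simply copies every key onto the instance, options that are not set in `__init__` (for example `relative_attention` or `pos_att_type`) can still be supplied and are later read with `getattr(config, ..., default)`. A small sketch:

    cfg = ModelConfig.from_dict({
        'hidden_size': 384,
        'num_attention_heads': 6,
        'vocab_size': 128100,
        'relative_attention': True,   # becomes a new attribute on the config
    })
    print(cfg.to_json_string())       # defaults plus the overridden/added keys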
modeling/da_utils.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ import pdb
+ from functools import lru_cache
+ import numpy as np
+
+ __all__=['build_relative_position', 'make_log_bucket_position']
+
+ def make_log_bucket_position(relative_pos, bucket_size, max_position):
+   sign = np.sign(relative_pos)
+   mid = bucket_size//2
+   abs_pos = np.where((relative_pos<mid) & (relative_pos > -mid), mid-1, np.abs(relative_pos))
+   log_pos = np.ceil(np.log(abs_pos/mid)/np.log((max_position-1)/mid) * (mid-1)) + mid
+   bucket_pos = np.where(abs_pos<=mid, relative_pos, log_pos*sign).astype(int)
+   return bucket_pos
+
+ @lru_cache(maxsize=128)
+ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
+   q_ids = np.arange(0, query_size)
+   k_ids = np.arange(0, key_size)
+   rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0],1))
+   if bucket_size>0 and max_position > 0:
+     rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
+   rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+   rel_pos_ids = rel_pos_ids[:query_size, :]
+   rel_pos_ids = rel_pos_ids.unsqueeze(0)
+   return rel_pos_ids
+
+ def test_log_bucket():
+   x=np.arange(-511,511)
+   y=make_log_bucket_position(x, 128, 512)
+   # pdb.set_trace()
+
+
+ if __name__ == '__main__':
+   test_log_bucket()
+   build_relative_position(query_size=16, key_size=16, bucket_size=4, max_position=16)
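`build_relative_position` returns a `[1, query_size, key_size]` tensor whose entry (i, j) is the (optionally log-bucketed) distance i - j. A worked example, assuming the package is importable as `modeling`:

    from modeling.da_utils import build_relative_position

    rel = build_relative_position(4, 4)   # no bucketing
    print(rel.shape)                      # torch.Size([1, 4, 4])
    print(rel[0, 0])                      # tensor([ 0, -1, -2, -3]): query 0 vs keys 0..3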
modeling/deberta.py ADDED
@@ -0,0 +1,149 @@
+ # Copyright (c) Microsoft, Inc. 2020
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ #
+ # Zhou Bo
+ # Date: 01/15/2020
+ #
+
+ import copy
+ import torch
+ import os
+
+ import json
+ from .ops import *
+ from .bert import *
+ from .config import ModelConfig
+ from .cache_utils import load_model_state
+ import pdb
+
+ __all__ = ['DeBERTa']
+
+ class DeBERTa(torch.nn.Module):
+   """ DeBERTa encoder
+   This module is composed of the input embedding layer and stacked transformer layers with disentangled attention.
+
+   Parameters:
+     config:
+       A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \
+       for more details, please refer to :class:`~DeBERTa.deberta.ModelConfig`
+
+     pre_trained:
+       The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or one of the released configurations, \
+       i.e. [**base, large, base_mnli, large_mnli**]
+
+   """
+
+   def __init__(self, config=None, pre_trained=None):
+     super().__init__()
+     state = None
+     if pre_trained is not None:
+       state, model_config = load_model_state(pre_trained)
+       if config is not None and model_config is not None:
+         for k in config.__dict__:
+           if k not in ['hidden_size',
+                        'intermediate_size',
+                        'num_attention_heads',
+                        'num_hidden_layers',
+                        'vocab_size',
+                        'max_position_embeddings']:
+             model_config.__dict__[k] = config.__dict__[k]
+       config = copy.copy(model_config)
+     self.embeddings = BertEmbeddings(config)
+     self.encoder = BertEncoder(config)
+     self.config = config
+     self.pre_trained = pre_trained
+     self.apply_state(state)
+
+   def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_all_encoded_layers=True, position_ids = None, return_att = False):
+     """
+     Args:
+       input_ids:
+         a torch.LongTensor of shape [batch_size, sequence_length] \
+         with the word token indices in the vocabulary
+
+       attention_mask:
+         an optional parameter for input mask or attention mask.
+
+         - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \
+           selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
+           input sequence length in the current batch. It's the mask that we typically use for attention when \
+           a batch has varying length sentences.
+
+         - If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
+           In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.
+
+       token_type_ids:
+         an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
+         types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
+         a `sentence B` token (see BERT paper for more details).
+
+       output_all_encoded_layers:
+         whether to output results of all encoder layers, default: `True`
+
+     Returns:
+
+       - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \
+         the last layer of stacked transformer layers
+
+       - Attention matrix of self-attention layers if `return_att=True`
+
+
+     Example::
+
+       # Batch of wordPiece token ids.
+       # Each sample was padded with zero to the maximum length of the batch
+       input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+       # Mask of valid input ids
+       attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+
+       # DeBERTa model initialized with pretrained base model
+       bert = DeBERTa(pre_trained='base')
+
+       encoder_layers = bert(input_ids, attention_mask=attention_mask)
+
+     """
+
+     if attention_mask is None:
+       attention_mask = torch.ones_like(input_ids)
+     if token_type_ids is None:
+       token_type_ids = torch.zeros_like(input_ids)
+       token_mask = torch.ones_like(input_ids)
+     else:
+       idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), axis=1), [-1])
+       token_mask = idxs > 0
+       token_mask = token_mask.byte()
+     ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, token_mask)
+     embedding_output = ebd_output['embeddings']
+     encoder_output = self.encoder(embedding_output,
+                                   attention_mask,
+                                   output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
+     encoder_output.update(ebd_output)
+     return encoder_output
+
+   def apply_state(self, state = None):
+     """ Load state from a previously loaded model state dictionary.
+
+     Args:
+       state (:obj:`dict`, optional): State dictionary as the state returned by torch.module.state_dict(), default: `None`. \
+         If it's `None`, then will use the pre-trained state loaded via the constructor to re-initialize \
+         the `DeBERTa` model
+     """
+     if self.pre_trained is None and state is None:
+       return
+     if state is None:
+       state, config = load_model_state(self.pre_trained)
+       self.config = config
+
+     prefix = ''
+     for k in state:
+       if 'embeddings.' in k:
+         if not k.startswith('embeddings.'):
+           prefix = k[:k.index('embeddings.')]
+         break
+
+     missing_keys = []
+     unexpected_keys = []
+     error_msgs = []
+     self._load_from_state_dict(state, prefix = prefix, local_metadata=None, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)
modeling/disentangled_attention.py ADDED
@@ -0,0 +1,212 @@
+ # Copyright (c) Microsoft, Inc. 2020
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ #
+ # Zhou Bo
+ # Date: 01/15/2020
+ #
+
+ """
+ Disentangled SelfAttention module
+ """
+
+ import numpy as np
+ import math
+ import torch
+ from torch import nn
+ import functools
+ import pdb
+
+ from .ops import *
+ from .da_utils import build_relative_position
+
+ import loguru
+ logger=loguru.logger
+
+ __all__=['DisentangledSelfAttention']
+ class DisentangledSelfAttention(nn.Module):
+   def __init__(self, config):
+     super().__init__()
+     self.num_attention_heads = config.num_attention_heads
+     _attention_head_size = int(config.hidden_size / config.num_attention_heads)
+     self.attention_head_size = getattr(config, 'attention_head_size', _attention_head_size)
+     self.all_head_size = self.num_attention_heads * self.attention_head_size
+     self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+     self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+     self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+
+     self.share_att_key = getattr(config, 'share_att_key', False)
+     self.pos_att_type = [x.strip() for x in getattr(config, 'pos_att_type', 'c2p').lower().split('|')] # c2p|p2c
+     self.relative_attention = getattr(config, 'relative_attention', False)
+
+     if self.relative_attention:
+       self.position_buckets = getattr(config, 'position_buckets', -1)
+       self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
+       if self.max_relative_positions <1:
+         self.max_relative_positions = config.max_position_embeddings
+       self.pos_ebd_size = self.max_relative_positions
+       if self.position_buckets>0:
+         self.pos_ebd_size = self.position_buckets
+         # For backward compatibility
+
+       self.pos_dropout = StableDropout(config.hidden_dropout_prob)
+
+       if (not self.share_att_key):
+         if 'c2p' in self.pos_att_type or 'p2p' in self.pos_att_type:
+           self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+         if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
+           self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
+
+     self.dropout = StableDropout(config.attention_probs_dropout_prob)
+     self._register_load_state_dict_pre_hook(self._pre_load_hook)
+
+   def transpose_for_scores(self, x, attention_heads):
+     new_x_shape = x.size()[:-1] + (attention_heads, -1)
+     x = x.view(*new_x_shape)
+     return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
+
+   def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
+     if query_states is None:
+       query_states = hidden_states
+     query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads).float()
+     key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads).float()
+     value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
+
+     rel_att = None
+     # Take the dot product between "query" and "key" to get the raw attention scores.
+     scale_factor = 1
+     if 'c2p' in self.pos_att_type:
+       scale_factor += 1
+     if 'p2c' in self.pos_att_type:
+       scale_factor += 1
+     if 'p2p' in self.pos_att_type:
+       scale_factor += 1
+     scale = 1/math.sqrt(query_layer.size(-1)*scale_factor)
+     attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)*scale)
+     if self.relative_attention:
+       rel_embeddings = self.pos_dropout(rel_embeddings)
+       rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
+
+     if rel_att is not None:
+       attention_scores = (attention_scores + rel_att)
+     attention_scores = (attention_scores - attention_scores.max(dim=-1, keepdim=True).values.detach()).to(hidden_states)
+     attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1))
+
+     # bxhxlxd
+     _attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
+     attention_probs = self.dropout(_attention_probs)
+     context_layer = torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer)
+     context_layer = context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)).permute(0, 2, 1, 3).contiguous()
+     new_context_layer_shape = context_layer.size()[:-2] + (-1,)
+     context_layer = context_layer.view(*new_context_layer_shape)
+
+     return {
+       'hidden_states': context_layer,
+       'attention_probs': _attention_probs,
+       'attention_logits': attention_scores
+     }
+
+   def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
+     if relative_pos is None:
+       q = query_layer.size(-2)
+       relative_pos = build_relative_position(q, key_layer.size(-2), bucket_size = self.position_buckets, max_position = self.max_relative_positions)
+     if relative_pos.dim()==2:
+       relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
+     elif relative_pos.dim()==3:
+       relative_pos = relative_pos.unsqueeze(1)
+     # bxhxqxk
+     elif relative_pos.dim()!=4:
+       raise ValueError(f'Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}')
+
+     att_span = self.pos_ebd_size
+     relative_pos = relative_pos.long().to(query_layer.device)
+
+     rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span:self.pos_ebd_size + att_span, :].unsqueeze(0) #.repeat(query_layer.size(0)//self.num_attention_heads, 1, 1)
+     if self.share_att_key:
+       pos_query_layer = self.transpose_for_scores(self.query_proj(rel_embeddings), self.num_attention_heads)\
+         .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)
+       pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads)\
+         .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)
+     else:
+       if 'c2p' in self.pos_att_type or 'p2p' in self.pos_att_type:
+         pos_key_layer = self.transpose_for_scores(self.pos_key_proj(rel_embeddings), self.num_attention_heads)\
+           .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)
+       if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
+         pos_query_layer = self.transpose_for_scores(self.pos_query_proj(rel_embeddings), self.num_attention_heads)\
+           .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)
+
+     score = 0
+     # content->position
+     if 'c2p' in self.pos_att_type:
+       scale = 1/math.sqrt(pos_key_layer.size(-1)*scale_factor)
+       c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2).to(query_layer)*scale)
+       c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span*2-1)
+       c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]))
+       try:
+         score += c2p_att
+       except:
+         print(c2p_att.size())
+
+     # position->content
+     if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
+       scale = 1/math.sqrt(pos_query_layer.size(-1)*scale_factor)
+       if key_layer.size(-2) != query_layer.size(-2):
+         r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), bucket_size = self.position_buckets, max_position = self.max_relative_positions).to(query_layer.device)
+         r_pos = r_pos.unsqueeze(0)
+       else:
+         r_pos = relative_pos
+
+       p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span*2-1)
+       if query_layer.size(-2) != key_layer.size(-2):
+         pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
+
+     if 'p2c' in self.pos_att_type:
+       p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2).to(key_layer)*scale)
+       p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)])).transpose(-1,-2)
+       if query_layer.size(-2) != key_layer.size(-2):
+         p2c_att = torch.gather(p2c_att, dim=-2, index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))))
+       score += p2c_att
+
+     # position->position
+     if 'p2p' in self.pos_att_type:
+       pos_query = pos_query_layer[:,:,att_span:,:]
+       p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
+       p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
+       if query_layer.size(-2) != key_layer.size(-2):
+         p2p_att = torch.gather(p2p_att, dim=-2, index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))))
+       p2p_att = torch.gather(p2p_att, dim=-1, index=c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]))
+       score += p2p_att
+
+     return score
+
+   def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
+       missing_keys, unexpected_keys, error_msgs):
+     self_state = self.state_dict()
+     if ((prefix + 'query_proj.weight') not in state_dict) and ((prefix + 'in_proj.weight') in state_dict):
+       v1_proj = state_dict[prefix+'in_proj.weight']
+       v1_proj = v1_proj.unsqueeze(0).reshape(self.num_attention_heads, -1, v1_proj.size(-1))
+       q,k,v=v1_proj.chunk(3, dim=1)
+       state_dict[prefix + 'query_proj.weight'] = q.reshape(-1, v1_proj.size(-1))
+       state_dict[prefix + 'key_proj.weight'] = k.reshape(-1, v1_proj.size(-1))
+       state_dict[prefix + 'key_proj.bias'] = self_state['key_proj.bias']
+       state_dict[prefix + 'value_proj.weight'] = v.reshape(-1, v1_proj.size(-1))
+       v1_query_bias = state_dict[prefix + 'q_bias']
+       state_dict[prefix + 'query_proj.bias'] = v1_query_bias
+       v1_value_bias = state_dict[prefix +'v_bias']
+       state_dict[prefix + 'value_proj.bias'] = v1_value_bias
+
+       v1_pos_key_proj = state_dict[prefix + 'pos_proj.weight']
+       state_dict[prefix + 'pos_key_proj.weight'] = v1_pos_key_proj
+       v1_pos_query_proj = state_dict[prefix + 'pos_q_proj.weight']
+       state_dict[prefix + 'pos_query_proj.weight'] = v1_pos_query_proj
+       v1_pos_query_proj_bias = state_dict[prefix + 'pos_q_proj.bias']
+       state_dict[prefix + 'pos_query_proj.bias'] = v1_pos_query_proj_bias
+       state_dict[prefix + 'pos_key_proj.bias'] = self_state['pos_key_proj.bias']
+
+       del state_dict[prefix + 'in_proj.weight']
+       del state_dict[prefix + 'q_bias']
+       del state_dict[prefix + 'v_bias']
+       del state_dict[prefix + 'pos_proj.weight']
+       del state_dict[prefix + 'pos_q_proj.weight']
+       del state_dict[prefix + 'pos_q_proj.bias']
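The `scale_factor` bookkeeping above is what keeps the combined score variance comparable to plain attention: each enabled term (`c2p`, `p2c`, `p2p`) adds one more dot product, so every term is scaled by 1/sqrt(d * scale_factor) instead of 1/sqrt(d). A worked example with a 64-dimensional head and `pos_att_type = 'c2p|p2c'` (three terms in total):

    import math
    d = 64
    scale = 1 / math.sqrt(d * 3)
    print(round(scale, 4))        # 0.0722, versus 0.125 for vanilla 1/sqrt(d) scaling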
modeling/file_utils.py ADDED
@@ -0,0 +1,239 @@
+ """
+ Utilities for working with the local dataset cache.
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
+ Copyright by the AllenNLP authors.
+ """
+
+ import os
+ import logging
+ import shutil
+ import tempfile
+ import json
+ from urllib.parse import urlparse
+ from pathlib import Path
+ from typing import Optional, Tuple, Union, IO, Callable, Set
+ from hashlib import sha256
+ from functools import wraps
+
+ from tqdm import tqdm
+
+ import boto3
+ from botocore.exceptions import ClientError
+ import requests
+
+ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+ PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                                Path.home() / '.pytorch_pretrained_bert'))
+
+
+ def url_to_filename(url: str, etag: str = None) -> str:
+   """
+   Convert `url` into a hashed filename in a repeatable way.
+   If `etag` is specified, append its hash to the url's, delimited
+   by a period.
+   """
+   url_bytes = url.encode('utf-8')
+   url_hash = sha256(url_bytes)
+   filename = url_hash.hexdigest()
+
+   if etag:
+     etag_bytes = etag.encode('utf-8')
+     etag_hash = sha256(etag_bytes)
+     filename += '.' + etag_hash.hexdigest()
+
+   return filename
+
+
+ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
+   """
+   Return the url and etag (which may be ``None``) stored for `filename`.
+   Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
+   """
+   if cache_dir is None:
+     cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+   if isinstance(cache_dir, Path):
+     cache_dir = str(cache_dir)
+
+   cache_path = os.path.join(cache_dir, filename)
+   if not os.path.exists(cache_path):
+     raise FileNotFoundError("file {} not found".format(cache_path))
+
+   meta_path = cache_path + '.json'
+   if not os.path.exists(meta_path):
+     raise FileNotFoundError("file {} not found".format(meta_path))
+
+   with open(meta_path) as meta_file:
+     metadata = json.load(meta_file)
+   url = metadata['url']
+   etag = metadata['etag']
+
+   return url, etag
+
+
+ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
+   """
+   Given something that might be a URL (or might be a local path),
+   determine which. If it's a URL, download the file and cache it, and
+   return the path to the cached file. If it's already a local path,
+   make sure the file exists and then return the path.
+   """
+   if cache_dir is None:
+     cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+   if isinstance(url_or_filename, Path):
+     url_or_filename = str(url_or_filename)
+   if isinstance(cache_dir, Path):
+     cache_dir = str(cache_dir)
+
+   parsed = urlparse(url_or_filename)
+
+   if parsed.scheme in ('http', 'https', 's3'):
+     # URL, so get it from the cache (downloading if necessary)
+     return get_from_cache(url_or_filename, cache_dir)
+   elif os.path.exists(url_or_filename):
+     # File, and it exists.
+     return url_or_filename
+   elif parsed.scheme == '':
+     # File, but it doesn't exist.
+     raise FileNotFoundError("file {} not found".format(url_or_filename))
+   else:
+     # Something unknown
+     raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
+
+
+ def split_s3_path(url: str) -> Tuple[str, str]:
+   """Split a full s3 path into the bucket name and path."""
+   parsed = urlparse(url)
+   if not parsed.netloc or not parsed.path:
+     raise ValueError("bad s3 path {}".format(url))
+   bucket_name = parsed.netloc
+   s3_path = parsed.path
+   # Remove '/' at beginning of path.
+   if s3_path.startswith("/"):
+     s3_path = s3_path[1:]
+   return bucket_name, s3_path
+
+
+ def s3_request(func: Callable):
+   """
+   Wrapper function for s3 requests in order to create more helpful error
+   messages.
+   """
+
+   @wraps(func)
+   def wrapper(url: str, *args, **kwargs):
+     try:
+       return func(url, *args, **kwargs)
+     except ClientError as exc:
+       if int(exc.response["Error"]["Code"]) == 404:
+         raise FileNotFoundError("file {} not found".format(url))
+       else:
+         raise
+
+   return wrapper
+
+
+ @s3_request
+ def s3_etag(url: str) -> Optional[str]:
+   """Check ETag on S3 object."""
+   s3_resource = boto3.resource("s3")
+   bucket_name, s3_path = split_s3_path(url)
+   s3_object = s3_resource.Object(bucket_name, s3_path)
+   return s3_object.e_tag
+
+
+ @s3_request
+ def s3_get(url: str, temp_file: IO) -> None:
+   """Pull a file directly from S3."""
+   s3_resource = boto3.resource("s3")
+   bucket_name, s3_path = split_s3_path(url)
+   s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
+
+
+ def http_get(url: str, temp_file: IO) -> None:
+   req = requests.get(url, stream=True)
+   content_length = req.headers.get('Content-Length')
+   total = int(content_length) if content_length is not None else None
+   progress = tqdm(unit="B", total=total)
+   for chunk in req.iter_content(chunk_size=1024):
+     if chunk:  # filter out keep-alive new chunks
+       progress.update(len(chunk))
+       temp_file.write(chunk)
+   progress.close()
+
+
+ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
+   """
+   Given a URL, look for the corresponding dataset in the local cache.
+   If it's not there, download it. Then return the path to the cached file.
+   """
+   if cache_dir is None:
+     cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+   if isinstance(cache_dir, Path):
+     cache_dir = str(cache_dir)
+
+   os.makedirs(cache_dir, exist_ok=True)
+
+   # Get eTag to add to filename, if it exists.
+   if url.startswith("s3://"):
+     etag = s3_etag(url)
+   else:
+     response = requests.head(url, allow_redirects=True)
+     if response.status_code != 200:
+       raise IOError("HEAD request failed for url {} with status code {}"
+                     .format(url, response.status_code))
+     etag = response.headers.get("ETag")
+
+   filename = url_to_filename(url, etag)
+
+   # get cache path to put the file
+   cache_path = os.path.join(cache_dir, filename)
+
+   if not os.path.exists(cache_path):
+     # Download to temporary file, then copy to cache dir once finished.
+     # Otherwise you get corrupt cache entries if the download gets interrupted.
+     with tempfile.NamedTemporaryFile() as temp_file:
+       logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
+
+       # GET file object
+       if url.startswith("s3://"):
+         s3_get(url, temp_file)
+       else:
+         http_get(url, temp_file)
+
+       # we are copying the file before closing it, so flush to avoid truncation
+       temp_file.flush()
+       # shutil.copyfileobj() starts at the current position, so go to the start
+       temp_file.seek(0)
+
+       logger.info("copying %s to cache at %s", temp_file.name, cache_path)
+       with open(cache_path, 'wb') as cache_file:
+         shutil.copyfileobj(temp_file, cache_file)
+
+       logger.info("creating metadata file for %s", cache_path)
+       meta = {'url': url, 'etag': etag}
+       meta_path = cache_path + '.json'
+       with open(meta_path, 'w') as meta_file:
+         json.dump(meta, meta_file)
+
+       logger.info("removing temp file %s", temp_file.name)
+
+   return cache_path
+
+
+ def read_set_from_file(filename: str) -> Set[str]:
+   '''
+   Extract a de-duped collection (set) of text from a file.
+   Expected file format is one item per line.
+   '''
+   collection = set()
+   with open(filename, 'r', encoding='utf-8') as file_:
+     for line in file_:
+       collection.add(line.rstrip())
+   return collection
+
+
+ def get_file_extension(path: str, dot=True, lower: bool = True):
+   ext = os.path.splitext(path)[1]
+   ext = ext if dot else ext[1:]
+   return ext.lower() if lower else ext
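`cached_path` is the generic entry point: http/https/s3 URLs are fetched via `get_from_cache` and stored under `PYTORCH_PRETRAINED_BERT_CACHE` with a filename derived from the SHA-256 of the URL and its ETag, while existing local paths are returned unchanged. An illustrative call (the URL is only an example, and the first call needs network access; boto3 is only exercised for s3:// URLs):

    local = cached_path('https://huggingface.co/microsoft/deberta-base/resolve/main/config.json')
    print(local)   # <cache_dir>/<sha256(url)>.<sha256(etag)>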
modeling/flash.py ADDED
@@ -0,0 +1,794 @@
+ #
+ # Zhoubo
+ #
+ """
+ FLASH: https://arxiv.org/abs/2202.10447
+ """
+ import copy
+ import torch
+ import os
+ from collections.abc import Sequence
+ import json
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from transformers.activations import ACT2FN
+ from .modeling import *
+ from .ops import XSoftmax, sequence_masking
+
+ from .bert import *
+ from .config import ModelConfig
+ from .cache_utils import load_model_state
+ import einops
+
+
+ class ScaleNorm(nn.Module):
+   def __init__(self, eps=1e-5):
+     super().__init__()
+     self.eps = eps
+     self.scala = nn.Parameter(torch.ones(1))
+
+   def forward(self, x):
+     mean_square = (x ** 2).mean(dim=-1, keepdim=True)
+     x = x * torch.rsqrt(mean_square + self.eps) * self.scala
+     return x
+
+
+
+ class OffsetScale(nn.Module):
+   def __init__(self, dim, heads = 1):
+     super().__init__()
+     self.gamma = nn.Parameter(torch.ones(heads, dim))
+     self.beta = nn.Parameter(torch.zeros(heads, dim))
+     # nn.init.normal_(self.gamma, std = 0.02)
+     # nn.init.xavier_uniform_(self.gamma)
+
+   def forward(self, x):
+     out = (x * self.gamma) + self.beta
+     return out
+
+
+ class ScaledSinuEmbedding(nn.Module):
+   def __init__(self, dim):
+     super().__init__()
+     self.scale = nn.Parameter(torch.ones(1,))
+     inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+     self.register_buffer('inv_freq', inv_freq)
+
+   def forward(self, x):
+     n, device = x.shape[1], x.device
+     t = torch.arange(n, device = device).type_as(self.inv_freq)
+     sinu = torch.einsum('i , j -> i j', t, self.inv_freq)
+     emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
+     return emb * self.scale
+
+
+ def RoPE(x, dim):
+   """
+   :param x: input tensor
+   :param dim: operate dimension
+   :return: tensor
+   """
+   shape = x.shape
+   if isinstance(dim, int):
+     dim = [dim]
+
+   spatial_shape = [shape[i] for i in dim]
+   total_len = 1
+   for i in spatial_shape:
+     total_len *= i
+   position = torch.reshape(torch.arange(total_len, dtype=torch.float, device=x.device), spatial_shape)
+
+   for i in range(dim[-1] + 1, len(shape) - 1, 1):
+     position = torch.unsqueeze(position, dim=-1)
+
+   half_size = shape[-1] // 2
+   freq_seq = -torch.arange(half_size, dtype=torch.float, device=x.device) / float(half_size)
+   inv_freq = 10000 ** -freq_seq
+   sinusoid = torch.einsum("...,d->...d", position, inv_freq)
+   sin = torch.sin(sinusoid).repeat_interleave(2, -1)
+   cos = torch.cos(sinusoid).repeat_interleave(2, -1)
+   tensor_cross = torch.stack([-x[..., 1:: 2], x[..., :: 2]], -1).reshape(x.shape)
+   # x1, x2 = torch.chunk(x, 2, dim=-1)
+   # return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
+   return x * cos + tensor_cross * sin
+
+
+ def rel_pos_bias(seq_len, s):
+   a = torch.rand([1, s], dtype=torch.float)
+   b = torch.rand([1, s], dtype=torch.float)
+   w = torch.rand([2 * seq_len - 1], dtype=torch.float)
+   if seq_len <= 512:
+     t = F.pad(w[: 2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
+     t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
+     r = (2 * seq_len - 1) // 2
+     t = t[..., r:-r]
+   else:
+     a = RoPE(a.repeat(seq_len, 1), dim=[0])
+     b = RoPE(b.repeat(seq_len, 1), dim=[0])
+     t = torch.einsum("mk,nk->mn", a, b)
+   return t
+
+ def squared_relu(x, attention_mask, dim=-1):
+   rmask = ~(attention_mask.bool())
+   x = x.masked_fill(rmask, 0)
+   return torch.square(F.relu(x))
+
+
+ def attention_normalize(a, axis=-1, mask=None, fn='softmax'):
+   if fn == 'softmax':
+     return XSoftmax.apply(a, mask, axis)
+   else:
+     mask_ = a > -float('inf') / 10
+     # mask_ = mask_.byte()
+     mask_ = torch.sum(mask_, axis=axis, keepdim=True)
+     l = torch.maximum(mask_, torch.ones_like(mask_))
+     if fn == 'squared_relu':
+       rmask = ~(mask.bool())
+       a = a.masked_fill(rmask, 0)
+       return torch.square(F.relu(a)) / l
+     elif fn == 'softmax_plus':
+       return XSoftmax.apply(a * torch.log(l) / np.log(512), mask, axis)
+   return a
+
+
+ class GAULinear(nn.Linear):
+   def init_weight(self):
+     nn.init.xavier_uniform_(self.weight)
+
+
+ class GatedAttentionUnit(nn.Module):
+   """
+   GAU Block: Gate Attention Unit
+   """
+   def __init__(
+     self,
+     max_seq_length,
+     hidden_size,
+     attention_key_size=128,
+     activation='swish',
+     use_bias=True,
+     attention_norm_type='squared_relu',
+     attention_scale=True,
+     dropout=0.1,
+     pre_norm=False,
+     norm_type="layer_norm",
+     eps=1e-5,
158
+ shift_token=False,
159
+ use_rel_bias=False,
160
+ add_residual=True,
161
+ **kwargs,):
162
+
163
+ super(GatedAttentionUnit, self).__init__(**kwargs)
164
+ self.max_seq_length = max_seq_length
165
+ self.units = hidden_size
166
+ self.intermediate_size = self.units * 2
167
+ self.key_size = attention_key_size
168
+ self.activation = activation
169
+ self.use_bias = use_bias
170
+ self.attention_norm_type = attention_norm_type
171
+ self.attention_scale = attention_scale
172
+ self.dropout = StableDropout(dropout)
173
+ self.i_dense = nn.Sequential(
174
+ nn.Linear(self.units, 2 * self.intermediate_size + self.key_size, bias=self.use_bias),
175
+ nn.SiLU()
176
+ )
177
+ self.o_dense = nn.Sequential(
178
+ nn.Linear(self.intermediate_size, self.units, bias=self.use_bias),
179
+ self.dropout)
180
+ self.q_scaleoffset = OffsetScale(self.key_size)
181
+ self.k_scaleoffset = OffsetScale(self.key_size)
182
+ self.pre_norm = pre_norm
183
+ self.norm = (nn.LayerNorm(hidden_size, eps=eps) if norm_type.lower() == "layer_norm" else ScaleNorm(eps=eps))
184
+ self.add_residual = add_residual
185
+
186
+ def forward(self, x, attention_mask=None, **kwargs):
187
+ shortcut = x
188
+
189
+ if self.pre_norm:
190
+ x = self.norm(x)
191
+
192
+ x = self.i_dense(x)
193
+ u, v, qk = torch.split(x, [self.intermediate_size, self.intermediate_size, self.key_size], dim=-1)
194
+ q, k = self.q_scaleoffset(qk), self.k_scaleoffset(qk)
195
+ qk = RoPE(torch.stack([q, k], 2), dim=1)
196
+ q, k = qk[:, :, 0], qk[:, :, 1]
197
+ a = torch.einsum('bmd,bnd->bmn', q, k)
198
+ if self.attention_scale:
199
+ a = a / self.key_size**0.5
200
+ a = sequence_masking(a, attention_mask, '-inf', -1)
201
+ A = attention_normalize(a, -1, mask=attention_mask, fn=self.attention_norm_type)
202
+ if self.dropout:
203
+ A = self.dropout(A)
204
+ out = self.o_dense(u * torch.einsum('bmn,bnd->bmd', A, v))
205
+
206
+ if self.add_residual:
207
+ out = out + shortcut
208
+ if not self.pre_norm:
209
+ out = self.norm(out)
210
+ return out
211
+ # # apply RoPE
212
+ # if p_bias == 'rotary':
213
+ # qk = K.stack([q, k], 2)
214
+ # qk = apply_rotary_position_embeddings(inputs[n], qk)[0]
215
+ # q, k = qk[:, :, 0], qk[:, :, 1]
216
+ # # Attention
217
+ # a = tf.einsum('bmd,bnd->bmn', q, k)
218
+ # if self.attention_scale:
219
+ # a = a / self.key_size**0.5
220
+ # if a_bias is not None:
221
+ # a = a + a_bias
222
+ # a = sequence_masking(a, mask, '-inf', -1)
223
+ # A = attention_normalize(a, -1, self.normalization)
224
+ # if self.attention_dropout:
225
+ # A = Dropout(self.attention_dropout)(A)
226
+ # # compute the output
227
+ # o = self.o_dense(u * tf.einsum('bmn,bnd->bmd', A, v))
228
+
229
+ # return o
230
+
231
+ class GAU(nn.Module):
232
+ def __init__(self, max_seq_length, hidden_size, expansion_factor=2, s=128, norm_type="layer_norm", eps=1e-5,
233
+ hidden_act="silu", shift_token=False, use_rel_bias=False, attention_norm_type='softmax',
234
+ pre_norm=False, dropout=0, add_residual = True):
235
+ super(GAU, self).__init__()
236
+ self.max_seq_length = max_seq_length
237
+ self.shift_token = shift_token
238
+ hidden_dim = int(expansion_factor * hidden_size)
239
+ self.norm = (nn.LayerNorm(hidden_size, eps=eps) if norm_type == "layer_norm" else ScaleNorm(eps=eps))
240
+ self.use_rel_bias = use_rel_bias
241
+ self.attention_norm_type = attention_norm_type
242
+ # if attention_norm_type == 'relu':
243
+ # self.attention_norm_func = squared_relu
244
+ # else:
245
+ # self.attention_norm_func = XSoftmax.apply
246
+ # self.norm = norm_klass(hidden_size)
247
+
248
+ self.dropout = nn.Dropout(dropout)
249
+
250
+ self.to_hidden = nn.Sequential(
251
+ nn.Linear(hidden_size, hidden_dim * 2),
252
+ nn.SiLU()
253
+ )
254
+
255
+ self.to_qk = nn.Sequential(
256
+ nn.Linear(hidden_size, s),
257
+ nn.SiLU()
258
+ )
259
+
260
+ self.offsetscale = OffsetScale(s, heads = 2)
261
+
262
+ self.to_out = nn.Sequential(
263
+ nn.Linear(hidden_dim, hidden_size),
264
+ nn.Dropout(dropout)
265
+ )
266
+
267
+ self.add_residual = add_residual
268
+ self.act_fn = ACT2FN[hidden_act]
269
+ self.pre_norm = pre_norm
270
+
271
+
272
+ def forward(
273
+ self,
274
+ x,
275
+ relative_pos = None,
276
+ attention_mask = None
277
+ ):
278
+ seq_len, device = x.shape[-2], x.device
279
+ if self.pre_norm:
280
+ normed_x = self.norm(x)
281
+ else:
282
+ normed_x = x
283
+ v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)
284
+
285
+ qk = self.to_qk(normed_x)
286
+ base = self.offsetscale(qk)
287
+ base = RoPE(base, 1)
288
+ q, k = base.unbind(dim = -2)
289
+ sim = torch.einsum('b i d, b j d -> b i j', q, k)
290
+
291
+ if relative_pos is not None:
292
+ sim = sim + relative_pos
293
+ if attention_mask is not None:
294
+ if attention_mask.dim() < 3:
295
+ attention_mask = einops.rearrange(attention_mask, 'b j -> b 1 j')
296
+ # attn = attn.masked_fill(~attention_mask.bool(), 0.)
297
+ attn = attention_normalize(sim, mask=attention_mask, fn=self.attention_norm_type)
298
+ # attn = F.relu(sim) ** 2 / seq_len# / q.size(-1)
299
+ # logger.info(attn.max())
300
+ attn = self.dropout(attn)
301
+ # if self.causal:
302
+ # causal_mask = torch.ones((seq_len, seq_len), dtype = torch.bool, device = device).triu(1)
303
+ # attn = attn.masked_fill(causal_mask, 0.)
304
+
305
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
306
+ out = out * gate
307
+
308
+ out = self.to_out(out)
309
+
310
+ if self.add_residual:
311
+ out = out + x
312
+ if not self.pre_norm:
313
+ out = self.norm(out)
314
+ return out
315
+
316
+
317
+ class GAULayer(nn.Module):
318
+ def __init__(self, config, shift_token=False, use_ffn=False):
319
+ super(GAULayer, self).__init__()
320
+ self.attention = GatedAttentionUnit(config.max_position_embeddings, config.hidden_size,
321
+ shift_token=shift_token, use_rel_bias=config.use_rel_bias,
322
+ norm_type=config.norm_type, attention_norm_type=config.attention_norm_type,
323
+ pre_norm=config.pre_norm, dropout=config.hidden_dropout_prob)
324
+ if use_ffn:
325
+ self.intermediate = BertIntermediate(config)
326
+ self.output = BertOutput(config)
327
+ self.use_ffn = use_ffn
328
+
329
+ def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
330
+ attention_output = self.attention(hidden_states, attention_mask=attention_mask, relative_pos=relative_pos)
331
+ if self.use_ffn:
332
+ intermediate_output = self.intermediate(attention_output)
333
+ layer_output = self.output(intermediate_output, attention_output)
334
+ return layer_output
335
+ else:
336
+ return attention_output
337
+
338
+
339
+ class FlashBlock(nn.Module):
340
+ """
341
+ FLASH Block: Fast Linear Attention with a Single Head
342
+ """
343
+
344
+ def __init__(self, model_size, sequence_length, chunk_size=256, expansion_factor=2, s=128, norm_type="layer_norm", eps=1e-5,
345
+ hidden_act="silu"):
346
+ super(FlashBlock, self).__init__()
347
+ self.s = s
348
+ self.eps = eps
349
+ self.norm_type = norm_type
350
+ self.model_size = model_size
351
+ self.chunk_size = chunk_size
352
+ self.hidden_act = hidden_act
353
+ self.sequence_length = sequence_length
354
+ self.expansion_factor = expansion_factor
355
+ self.e = int(self.model_size * self.expansion_factor)
356
+
357
+ self.dense1 = nn.Linear(self.model_size, 2 * self.e + self.s, bias=True)
358
+ self.gamma = nn.Parameter(torch.rand((4, self.s)))
359
+ self.beta = nn.Parameter(torch.rand((4, self.s)))
360
+ self.dense2 = nn.Linear(self.e, self.model_size)
361
+ self.LayerNorm = (
362
+ nn.LayerNorm(model_size, eps=self.eps) if norm_type == "layer_norm" else ScaleNorm(eps=self.eps))
363
+
364
+ nn.init.xavier_normal_(self.dense1.weight)
365
+ self.act_fn = ACT2FN[self.hidden_act]
366
+
367
+ def global_linear_attention(self, query, key, value, causal):
368
+ if causal:
369
+ kv = torch.einsum("bgcs, bgce->bgse", key, value)
370
+ kv = torch.cumsum(kv, dim=1)
371
+ lin_v = torch.einsum("bgcs, bgse->bgce", query, kv)
372
+ return lin_v
373
+ else:
374
+ kv = torch.einsum("bgcs, bgce->bse", key, value)
375
+ lin_v = torch.einsum("bgcs, bse->bgce", query, kv)
376
+ return lin_v
377
+
378
+ def segment_ids_to_mask(self, segment_ids, causal=False):
379
+ """Generate the segment mask from the segment ids.
380
+ The segment mask is used to remove the attention between tokens in different documents.
381
+ """
382
+ min_ids, max_ids = torch.min(segment_ids, dim=-1).values, torch.max(segment_ids, dim=-1).values
383
+ # 1.0 indicates in the same group and 0.0 otherwise
384
+ mask = torch.logical_and(torch.less_equal(min_ids[:, :, None], max_ids[:, None, :]),
385
+ torch.greater_equal(max_ids[:, :, None], min_ids[:, None, :]))
386
+ mask = mask.to(torch.float32)
387
+ if causal:
388
+ g = segment_ids.size()[1]
389
+ causal_mask = 1.0 - torch.triu(torch.ones([g, g], dtype=torch.float32)) # torch.triu keeps the main diagonal and everything above it
390
+ mask *= causal_mask
391
+ mask = torch.div(mask, torch.sum(mask, dim=-1, keepdim=True))
392
+ return mask
393
+
394
+ def forward(self, x, causal=False, attention_mask=None, sequence_mask=None, segment_ids=None, **kwargs):
395
+ """
396
+ inputs: [batch_size, num_chunk, chunk_length, model_size]
397
+ """
398
+ _, g, n, d = x.size()
399
+ shortcut, x = x, self.LayerNorm(x)
400
+ # obtain Z through a linear projection, see Eq. (4) in the paper
401
+ uv = self.dense1(x)
402
+ # split uv along the last dimension into Ug:[C*e], Vg:[C*e], Zg:[C*s], see Section 3.2 of the paper
403
+ # u:[batch_size, num_chunk, chunk_length, self.e]
404
+ # v:[batch_size, num_chunk, chunk_length, self.e]
405
+ # z:[batch_size, num_chunk, chunk_length, self.s]
406
+ u, v, z = torch.split(self.act_fn(uv), [self.e, self.e, self.s], dim=-1)
407
+
408
+ # generate quad_q, quad_k, lin_q, lin_k
409
+ # first apply a simple offset and scale, then fold in the RoPE position embeddings
410
+ z = torch.einsum("...r, hr->...hr", z, self.gamma) + self.beta
411
+ z = RoPE(z, dim=[1, 2])
412
+ quad_q, quad_k, lin_q, lin_k = torch.unbind(z, dim=-2) # unbind along dim -2 to obtain quad_q, quad_k, lin_q and lin_k
413
+ # compute the global lin_v
414
+ lin_v = self.global_linear_attention(lin_q, lin_k, v, causal)
415
+ if causal:
416
+ # linear attention part
417
+ lin_kv = torch.einsum("bgnk, bgne->bgke", lin_k, lin_v) / torch.tensor(n, dtype=x.dtype) # see Eq. (7) in the paper
418
+ mask = self.segment_ids_to_mask(segment_ids=segment_ids, causal=causal)
419
+ cum_lin_kv = torch.einsum('bhke, bgh->bgke', lin_kv, mask)
420
+ linear = torch.einsum("bgnk, bgke->bgne", lin_kv, cum_lin_kv)
421
+ # quadratic (local) attention
422
+ quad_qk = torch.einsum("bgnk, bgmk->bgnm", quad_q, quad_k) # the "local attention per chunk" part of the paper
423
+ bias = rel_pos_bias(self.sequence_length, self.s)[:, :n, :n]
424
+ kernel = torch.square(F.relu(quad_qk / n + bias)) # the relu**2 term in the paper
425
+ causal_mask = torch.triu(torch.ones([n, n], dtype=x.dtype))
426
+ quadratic = torch.einsum("bgnm, bgme->bgne", kernel * causal_mask, v)
427
+ else:
428
+ lin_kv = torch.einsum("bgnk, bgne->bgke", lin_k, lin_v) / torch.tensor(n, dtype=x.dtype) # see Eq. (7) in the paper
429
+ mask = self.segment_ids_to_mask(segment_ids=segment_ids, causal=causal)
430
+ lin_kv = torch.einsum("bhke, bgh->bgke", lin_kv, mask)
431
+ linear = torch.einsum("bgnk, bgke->bgne", lin_q, lin_kv)
432
+ # quadratic (local) attention
433
+ quad_qk = torch.einsum("bgnk, bgmk->bgnm", quad_q, quad_k) # the "local attention per chunk" part of the paper
434
+ bias = rel_pos_bias(self.sequence_length, self.s)[:, :n, :n]
435
+ kernel = torch.square(F.relu(quad_qk / n + bias)) # the relu**2 term in the paper
436
+ quadratic = torch.einsum("bgnm, bgme->bgne", kernel, v)
437
+ x = u * (quadratic + linear)
438
+ x = self.dense2(x)
439
+ x = x + shortcut
440
+ return x
441
+
442
+ class RelativePositionBias(nn.Module):
443
+ def __init__(
444
+ self,
445
+ scale,
446
+ causal = False,
447
+ num_buckets = 32,
448
+ max_distance = 128
449
+ ):
450
+ super().__init__()
451
+ self.scale = scale
452
+ self.causal = causal
453
+ self.num_buckets = num_buckets
454
+ self.max_distance = max_distance
455
+ self.relative_attention_bias = nn.Embedding(num_buckets, 1)
456
+
457
+ @staticmethod
458
+ def _relative_position_bucket(
459
+ relative_position,
460
+ causal = True,
461
+ num_buckets = 32,
462
+ max_distance = 128
463
+ ):
464
+ ret = 0
465
+ n = -relative_position
466
+ if not causal:
467
+ num_buckets //= 2
468
+ ret += (n < 0).long() * num_buckets
469
+ n = torch.abs(n)
470
+ else:
471
+ n = torch.max(n, torch.zeros_like(n))
472
+
473
+ max_exact = num_buckets // 2
474
+ is_small = n < max_exact
475
+
476
+ val_if_large = max_exact + (
477
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
478
+ ).long()
479
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
480
+
481
+ ret += torch.where(is_small, n, val_if_large)
482
+ return ret
483
+
484
+ def forward(self, x):
485
+ i, j, device = *x.shape[-2:], x.device
486
+ q_pos = torch.arange(i, dtype = torch.long, device = device)
487
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
488
+ rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
489
+ rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
490
+ values = self.relative_attention_bias(rp_bucket)
491
+ bias = einops.rearrange(values, 'i j 1 -> i j')
492
+ return bias * self.scale
493
+
494
+
495
+ class FlashEmbeddings(nn.Module):
496
+ """Construct the embeddings from word, position and token_type embeddings.
497
+ """
498
+ def __init__(self, config, with_position=False):
499
+ super(FlashEmbeddings, self).__init__()
500
+ self.word_embeddings = nn.Embedding(
501
+ config.vocab_size, config.hidden_size)
502
+ self.token_type_embeddings = nn.Embedding(
503
+ config.type_vocab_size, config.hidden_size)
504
+ self.with_position = with_position
505
+ if with_position:
506
+ self.position_embeddings = ScaledSinuEmbedding(config.hidden_size)
507
+
508
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
509
+ # any TensorFlow checkpoint file
510
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
511
+ self.dropout = StableDropout(config.hidden_dropout_prob)
512
+
513
+ def forward(self, input_ids, token_type_ids=None, position_ids=None, token_mask=None):
514
+ seq_length = input_ids.size(1)
515
+ if position_ids is None:
516
+ position_ids = torch.arange(
517
+ seq_length, dtype=torch.long, device=input_ids.device)
518
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
519
+ if token_type_ids is None:
520
+ token_type_ids = torch.zeros_like(input_ids)
521
+
522
+ words_embeddings = self.word_embeddings(input_ids)
523
+ if self.with_position:
524
+ position_embeddings = self.position_embeddings(words_embeddings)
525
+ else:
526
+ position_embeddings = 0
527
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
528
+
529
+ # if self.num_pos_emb > 1:
530
+ # num_batch = position_embeddings.size(0)
531
+ # num_pos = position_embeddings.size(1)
532
+ # position_embeddings = position_embeddings.view(
533
+ # num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
534
+
535
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
536
+ # if self.fp32_embedding:
537
+ # embeddings = embeddings.half()
538
+ embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, token_mask)
539
+ embeddings = self.dropout(embeddings)
540
+ return {
541
+ 'embeddings': embeddings,
542
+ 'position_embeddings': position_embeddings}
543
+
544
+
545
+ class GAUEncoder(nn.Module):
546
+ def __init__(self, config, shift_token=False):
547
+ super().__init__()
548
+ layer = GAULayer(config, shift_token=shift_token)
549
+ self.layer = nn.ModuleList([copy.deepcopy(layer)
550
+ for _ in range(config.num_hidden_layers)])
551
+
552
+ def get_attention_mask(self, attention_mask):
553
+ if attention_mask.dim() <= 2:
554
+ extended_attention_mask = attention_mask.unsqueeze(1)
555
+ attention_mask = extended_attention_mask*extended_attention_mask.squeeze(-2).unsqueeze(-1)
556
+ attention_mask = attention_mask #.byte()
557
+ return attention_mask
558
+
559
+ def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False, query_states = None, relative_pos=None):
560
+ all_encoder_layers = []
561
+ att_matrices = []
562
+ if isinstance(hidden_states, Sequence):
563
+ next_kv = hidden_states[0]
564
+ else:
565
+ next_kv = hidden_states
566
+ # rel_embeddings = self.get_rel_embedding()
567
+ for i, layer_module in enumerate(self.layer):
568
+ output_states = layer_module(next_kv, attention_mask, query_states = query_states, relative_pos=relative_pos)
569
+ if return_att:
570
+ output_states, att_m = output_states
571
+
572
+ # if i == 0 and self.with_conv:
573
+ # prenorm = output_states #output['prenorm_states']
574
+ # output_states = self.conv(hidden_states, prenorm, input_mask)
575
+
576
+ if query_states is not None:
577
+ query_states = output_states
578
+ if isinstance(hidden_states, Sequence):
579
+ next_kv = hidden_states[i+1] if i+1 < len(self.layer) else None
580
+ else:
581
+ next_kv = output_states
582
+
583
+ if output_all_encoded_layers:
584
+ all_encoder_layers.append(output_states)
585
+ if return_att:
586
+ att_matrices.append(att_m)
587
+ if not output_all_encoded_layers:
588
+ all_encoder_layers.append(output_states)
589
+ if return_att:
590
+ att_matrices.append(att_m)
591
+ return {
592
+ 'hidden_states': all_encoder_layers,
593
+ 'attention_matrices': att_matrices
594
+ }
595
+
596
+ class FlashEncoder(nn.Module):
597
+ def __init__(self, config):
598
+ super().__init__()
599
+ layer = GatedAttentionUnit(config.max_position_embeddings, config.hidden_size)
600
+ self.layer = nn.ModuleList([copy.deepcopy(layer)
601
+ for _ in range(config.num_hidden_layers)])
602
+
603
+ def forward(self, hidden_states, attention_mask, token_mask=None,
604
+ output_all_encoded_layers=True,
605
+ prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, seg_ids=None):
606
+ # history embedding and encoded layer must be simultanously given
607
+ assert (prev_embedding is None) == (prev_encoded_layers is None)
608
+
609
+ all_encoder_layers = []
610
+ if (prev_embedding is not None) and (prev_encoded_layers is not None):
611
+ history_states = prev_embedding
612
+ for i, layer_module in enumerate(self.layer):
613
+ hidden_states = layer_module(
614
+ hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids)
615
+ if output_all_encoded_layers:
616
+ all_encoder_layers.append(hidden_states)
617
+ if prev_encoded_layers is not None:
618
+ history_states = prev_encoded_layers[i]
619
+ else:
620
+ for layer_module in self.layer:
621
+ hidden_states = layer_module(
622
+ hidden_states, attention_mask=attention_mask, mask_qkv=mask_qkv, seg_ids=seg_ids)
623
+ if output_all_encoded_layers:
624
+ all_encoder_layers.append(hidden_states)
625
+ if not output_all_encoded_layers:
626
+ all_encoder_layers.append(hidden_states)
627
+ return all_encoder_layers
628
+
629
+ # class FlashQuadModel(BertModel):
630
+ # def __init__(self, config, pooler=False, shift_token=False, causal=False) -> None:
631
+ # super().__init__(config)
632
+ # self.embeddings = FlashEmbeddings(config)
633
+ # self.encoder = GAUEncoder(config, causal=causal, shift_token=shift_token)
634
+ # if not pooler:
635
+ # self.pooler = None
636
+ # self.apply(self.init_bert_weights)
637
+
638
+
639
+ class FlashQuadModel(torch.nn.Module):
640
+ """
641
+ Parameters:
642
+ config:
643
+ A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`,
644
+
645
+ pre_trained:
646
+ The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or a released configurations,
647
+ i.e. [**base, large, base_mnli, large_mnli**]
648
+
649
+ """
650
+
651
+ def __init__(self, config=None, pre_trained=None, pooler=False, shift_token=False, causal=False, **kwargs):
652
+ super().__init__()
653
+ state = None
654
+ if pre_trained is not None:
655
+ state, model_config = load_model_state(pre_trained)
656
+ if config is not None and model_config is not None:
657
+ for k in config.__dict__:
658
+ if k not in ['hidden_size',
659
+ 'intermediate_size',
660
+ 'num_attention_heads',
661
+ 'num_hidden_layers',
662
+ 'vocab_size',
663
+ 'max_position_embeddings']:
664
+ model_config.__dict__[k] = config.__dict__[k]
665
+ config = copy.copy(model_config)
666
+ self.embeddings = FlashEmbeddings(config, with_position=True)
667
+ self.encoder = GAUEncoder(config, shift_token=shift_token)
668
+ if not pooler:
669
+ self.pooler = None
670
+ self.config = config
671
+ self.pre_trained = pre_trained
672
+ self.apply_state(state)
673
+
674
+ def get_attention_mask(self, input_ids=None, token_type_ids=None, attention_mask=None, input_mask=None):
675
+ if attention_mask is None:
676
+ if input_mask is not None:
677
+ return input_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), input_mask.size(1))
678
+ else:
679
+ return torch.ones_like(input_ids, dtype=torch.uint8).unsqueeze(-1).expand(input_ids.size(0), input_ids.size(1), input_ids.size(1))
680
+ else:
681
+ if attention_mask.dim() == 2:
682
+ if input_mask is not None:
683
+ attention_mask = attention_mask * input_mask
684
+ return attention_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), attention_mask.size(-1))
685
+ if attention_mask.dim() == 4:
686
+ attention_mask = attention_mask.squeeze(2)
687
+ if attention_mask.dim() == 3:
688
+ if input_mask is not None:
689
+ return attention_mask * input_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), attention_mask.size(-1))
690
+ else:
691
+ return attention_mask
692
+
693
+
694
+ def forward(self, input_ids, input_mask, attention_mask=None, token_type_ids=None,
695
+ output_all_encoded_layers=True, position_ids=None, return_att=False):
696
+ """
697
+ Args:
698
+ input_ids:
699
+ a torch.LongTensor of shape [batch_size, sequence_length] \
700
+ with the word token indices in the vocabulary
701
+
702
+ attention_mask:
703
+ an optional parameter for input mask or attention mask.
704
+
705
+ - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \
706
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
707
+ input sequence length in the current batch. It's the mask that we typically use for attention when \
708
+ a batch has varying length sentences.
709
+
710
+ - If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
711
+ In this case, it's a mask indicate which tokens in the sequence should be attended by other tokens in the sequence.
712
+
713
+ token_type_ids:
714
+ an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
715
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
716
+ a `sentence B` token (see BERT paper for more details).
717
+
718
+ output_all_encoded_layers:
719
+ whether to output results of all encoder layers, default, True
720
+
721
+ Returns:
722
+
723
+ - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \
724
+ the last layer of stacked transformer layers
725
+
726
+ - Attention matrix of self-attention layers if `return_att=True`
727
+
728
+
729
+ Example::
730
+
731
+ # Batch of wordPiece token ids.
732
+ # Each sample was padded with zero to the maxium length of the batch
733
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
734
+ # Mask of valid input ids
735
+ attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
736
+
737
+ # DeBERTa model initialized with pretrained base model
738
+ bert = DeBERTa(pre_trained='base')
739
+
740
+ encoder_layers = bert(input_ids, attention_mask=attention_mask)
741
+
742
+ """
743
+ if token_type_ids is None:
744
+ token_type_ids = torch.zeros_like(input_ids)
745
+ # input_mask = torch.ones_like(input_ids)
746
+
747
+ if input_mask is None:
748
+ idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), axis=1), [-1])
749
+ input_mask = idxs > 0
750
+ if not torch.any(input_mask):
751
+ input_mask = torch.ones_like(input_ids)
752
+ input_mask = input_mask # .byte()
753
+ attention_mask = self.get_attention_mask(input_ids, token_type_ids, attention_mask, input_mask)
754
+ attention_mask = attention_mask #.byte()
755
+ embedding_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, input_mask)
756
+ encoder_output = self.encoder(embedding_output['embeddings'], attention_mask, output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
757
+ encoder_output.update(embedding_output)
758
+ return encoder_output
759
+
760
+ def apply_state(self, state = None):
761
+ """ Load state from previous loaded model state dictionary.
762
+
763
+ Args:
764
+ state (:obj:`dict`, optional): State dictionary as the state returned by torch.module.state_dict(), default: `None`. \
765
+ If it's `None`, then will use the pre-trained state loaded via the constructor to re-initialize \
766
+ the `DeBERTa` model
767
+ """
768
+ if self.pre_trained is None and state is None:
769
+ return
770
+ if state is None:
771
+ state, config = load_model_state(self.pre_trained)
772
+ self.config = config
773
+
774
+ prefix = ''
775
+ for k in state:
776
+ if 'embeddings.' in k:
777
+ if not k.startswith('embeddings.'):
778
+ prefix = k[:k.index('embeddings.')]
779
+ break
780
+
781
+ missing_keys = []
782
+ unexpected_keys = []
783
+ error_msgs = []
784
+ self._load_from_state_dict(state, prefix = prefix, local_metadata=None, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)
785
+
786
+
787
+ class FlashModel(BertModel):
788
+ def __init__(self, config) -> None:
789
+ super().__init__(config)
790
+ self.encoder = FlashEncoder(config)
791
+ self.apply(self.init_bert_weights)
792
+
793
+ if __name__ == '__main__':
794
+ # NOTE: FlashModel expects a model config object rather than raw sizes; see the hedged GAU sketch below for a runnable example.
+ pass
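As a smaller, self-contained illustration of the gated attention unit in this file, here is a hedged smoke-test sketch that runs a single GAU block on random hidden states. It assumes the package is importable as `modeling`, that `ops.XSoftmax` accepts a 0/1 mask broadcastable to the attention logits (as its call sites above suggest), and it relies on OffsetScale keeping a per-head axis as in the forward above; treat it as a sketch, not a reference implementation.

import torch
from modeling.flash import GAU

# one gated attention unit: hidden size 128, shared qk dim s=128, softmax normalization, no dropout
gau = GAU(max_seq_length=128, hidden_size=128, s=128,
          attention_norm_type='softmax', dropout=0.0).eval()

x = torch.randn(2, 16, 128)                 # [batch, seq_len, hidden]
mask = torch.ones(2, 16, dtype=torch.long)  # 1 = real token, 0 = padding
mask[1, 12:] = 0                            # second sample padded after 12 tokens

with torch.no_grad():
    y = gau(x, attention_mask=mask)         # residual + post-norm output
print(y.shape)                              # expected: torch.Size([2, 16, 128])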
modeling/focal_loss.py ADDED
@@ -0,0 +1,200 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.cuda.amp as amp
5
+
6
+
7
+ ##
8
+ # version 1: use torch.autograd
9
+ class FocalLossV1(nn.Module):
10
+
11
+ def __init__(self,
12
+ alpha=0.25,
13
+ gamma=2,
14
+ reduction='mean',):
15
+ super(FocalLossV1, self).__init__()
16
+ self.alpha = alpha
17
+ self.gamma = gamma
18
+ self.reduction = reduction
19
+ self.crit = nn.BCEWithLogitsLoss(reduction='none')
20
+
21
+ def forward(self, logits, label):
22
+ '''
23
+ Usage is same as nn.BCEWithLogits:
24
+ >>> criteria = FocalLossV1()
25
+ >>> logits = torch.randn(8, 19, 384, 384)
26
+ >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
27
+ >>> loss = criteria(logits, lbs)
28
+ '''
29
+ probs = torch.sigmoid(logits)
30
+ coeff = torch.abs(label - probs).pow(self.gamma).neg()
31
+ log_probs = torch.where(logits >= 0,
32
+ F.softplus(logits, -1, 50),
33
+ logits - F.softplus(logits, 1, 50))
34
+ log_1_probs = torch.where(logits >= 0,
35
+ -logits + F.softplus(logits, -1, 50),
36
+ -F.softplus(logits, 1, 50))
37
+ loss = label * self.alpha * log_probs + (1. - label) * (1. - self.alpha) * log_1_probs
38
+ loss = loss * coeff
39
+
40
+ if self.reduction == 'mean':
41
+ loss = loss.mean()
42
+ if self.reduction == 'sum':
43
+ loss = loss.sum()
44
+ return loss
45
+
46
+
47
+ ##
48
+ # version 2: user derived grad computation
49
+ class FocalSigmoidLossFuncV2(torch.autograd.Function):
50
+ '''
51
+ compute backward directly for better numeric stability
52
+ '''
53
+ @staticmethod
54
+ @amp.custom_fwd(cast_inputs=torch.float32)
55
+ def forward(ctx, logits, label, alpha, gamma):
56
+ # logits = logits.float()
57
+
58
+ probs = torch.sigmoid(logits)
59
+ coeff = (label - probs).abs_().pow_(gamma).neg_()
60
+ log_probs = torch.where(logits >= 0,
61
+ F.softplus(logits, -1, 50),
62
+ logits - F.softplus(logits, 1, 50))
63
+ log_1_probs = torch.where(logits >= 0,
64
+ -logits + F.softplus(logits, -1, 50),
65
+ -F.softplus(logits, 1, 50))
66
+ ce_term1 = log_probs.mul_(label).mul_(alpha)
67
+ ce_term2 = log_1_probs.mul_(1. - label).mul_(1. - alpha)
68
+ ce = ce_term1.add_(ce_term2)
69
+ loss = ce * coeff
70
+
71
+ ctx.vars = (coeff, probs, ce, label, gamma, alpha)
72
+
73
+ return loss
74
+
75
+ @staticmethod
76
+ @amp.custom_bwd
77
+ def backward(ctx, grad_output):
78
+ '''
79
+ compute gradient of focal loss
80
+ '''
81
+ (coeff, probs, ce, label, gamma, alpha) = ctx.vars
82
+
83
+ d_coeff = (label - probs).abs_().pow_(gamma - 1.).mul_(gamma)
84
+ d_coeff.mul_(probs).mul_(1. - probs)
85
+ d_coeff = torch.where(label < probs, d_coeff.neg(), d_coeff)
86
+ term1 = d_coeff.mul_(ce)
87
+
88
+ d_ce = label * alpha
89
+ d_ce.sub_(probs.mul_((label * alpha).mul_(2).add_(1).sub_(label).sub_(alpha)))
90
+ term2 = d_ce.mul(coeff)
91
+
92
+ grads = term1.add_(term2)
93
+ grads.mul_(grad_output)
94
+
95
+ return grads, None, None, None
96
+
97
+
98
+ class FocalLossV2(nn.Module):
99
+
100
+ def __init__(self,
101
+ alpha=0.25,
102
+ gamma=2,
103
+ reduction='mean'):
104
+ super(FocalLossV2, self).__init__()
105
+ self.alpha = alpha
106
+ self.gamma = gamma
107
+ self.reduction = reduction
108
+
109
+ def forward(self, logits, label):
110
+ '''
111
+ Usage is same as nn.BCEWithLogits:
112
+ >>> criteria = FocalLossV2()
113
+ >>> logits = torch.randn(8, 19, 384, 384)
114
+ >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
115
+ >>> loss = criteria(logits, lbs)
116
+ '''
117
+ loss = FocalSigmoidLossFuncV2.apply(logits, label, self.alpha, self.gamma)
118
+ if self.reduction == 'mean':
119
+ loss = loss.mean()
120
+ if self.reduction == 'sum':
121
+ loss = loss.sum()
122
+ return loss
123
+
124
+
125
+ if __name__ == '__main__':
126
+ import torchvision
127
+ import torch
128
+ import numpy as np
129
+ import random
130
+ torch.manual_seed(15)
131
+ random.seed(15)
132
+ np.random.seed(15)
133
+ torch.backends.cudnn.deterministic = True
134
+
135
+ class Model(nn.Module):
136
+ def __init__(self):
137
+ super(Model, self).__init__()
138
+ net = torchvision.models.resnet18(pretrained=False)
139
+ self.conv1 = net.conv1
140
+ self.bn1 = net.bn1
141
+ self.maxpool = net.maxpool
142
+ self.relu = net.relu
143
+ self.layer1 = net.layer1
144
+ self.layer2 = net.layer2
145
+ self.layer3 = net.layer3
146
+ self.layer4 = net.layer4
147
+ self.out = nn.Conv2d(512, 3, 3, 1, 1)
148
+ def forward(self, x):
149
+ feat = self.conv1(x)
150
+ feat = self.bn1(feat)
151
+ feat = self.relu(feat)
152
+ feat = self.maxpool(feat)
153
+ feat = self.layer1(feat)
154
+ feat = self.layer2(feat)
155
+ feat = self.layer3(feat)
156
+ feat = self.layer4(feat)
157
+ feat = self.out(feat)
158
+ out = F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)
159
+ return out
160
+ net1 = Model()
161
+ net2 = Model()
162
+ net2.load_state_dict(net1.state_dict())
163
+
164
+ criteria1 = FocalLossV2()
165
+ # criteria2 = FocalLossV3()
166
+ net1.cuda()
167
+ net2.cuda()
168
+ net1.train()
169
+ net2.train()
170
+ net1.double()
171
+ net2.double()
172
+ criteria1.cuda()
173
+ # criteria2.cuda()
174
+
175
+ optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
176
+ # optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)
177
+
178
+ bs = 16
179
+ for it in range(300000):
180
+ inten = torch.randn(bs, 3, 224, 244).cuda()
181
+ # lbs = torch.randint(0, 2, (bs, 3, 224, 244)).float().cuda()
182
+ lbs = torch.randn(bs, 3, 224, 244).sigmoid().cuda()
183
+ inten = inten.double()
184
+ lbs = lbs.double()
185
+ logits = net1(inten)
186
+ loss1 = criteria1(logits, lbs)
187
+ optim1.zero_grad()
188
+ loss1.backward()
189
+ optim1.step()
190
+ # logits = net2(inten)
191
+ # loss2 = criteria2(logits, lbs)
192
+ # optim2.zero_grad()
193
+ # loss2.backward()
194
+ # optim2.step()
195
+ # with torch.no_grad():
196
+ # if (it+1) % 50 == 0:
197
+ # print('iter: {}, ================='.format(it+1))
198
+ # print('out.weight: ', torch.mean(torch.abs(net1.out.weight - net2.out.weight)).item())
199
+ # print('conv1.weight: ', torch.mean(torch.abs(net1.conv1.weight - net2.conv1.weight)).item())
200
+ # print('loss: ', loss1.item() - loss2.item())
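Since the file ships two focal-loss implementations, a cheap hedged sanity check is to compare FocalLossV1 (plain autograd) against FocalLossV2 (hand-derived backward) on the same random inputs; loss values and gradients should agree to numerical precision. The import path is an assumption based on the package layout.

import torch
from modeling.focal_loss import FocalLossV1, FocalLossV2

torch.manual_seed(0)
logits = torch.randn(4, 10, dtype=torch.double, requires_grad=True)
labels = torch.randint(0, 2, (4, 10)).double()

loss_v1 = FocalLossV1()(logits, labels)
grad_v1, = torch.autograd.grad(loss_v1, logits)

logits2 = logits.detach().clone().requires_grad_(True)
loss_v2 = FocalLossV2()(logits2, labels)
grad_v2, = torch.autograd.grad(loss_v2, logits2)

print(torch.allclose(loss_v1, loss_v2, atol=1e-6))  # expected: True
print(torch.allclose(grad_v1, grad_v2, atol=1e-6))  # expected: True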
modeling/gat.py ADDED
@@ -0,0 +1,665 @@
1
+ #
2
+ # Zhoubo
3
+ #
4
+ """
5
+ FLASH: https://arxiv.org/abs/2202.10447
6
+ """
7
+ import copy
8
+ import torch
9
+ import math
10
+ import os
11
+ from collections.abc import Sequence
12
+ import json
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from transformers.activations import ACT2FN
18
+ from .ops import sequence_masking, XSoftmax, StableDropout, MaskedLayerNorm
19
+ from .config import ModelConfig
20
+ from .cache_utils import load_model_state
21
+ import einops
22
+
23
+
24
+ class ScaleNorm(nn.Module):
25
+ def __init__(self, eps=1e-5):
26
+ super().__init__()
27
+ self.eps = eps
28
+ self.scala = nn.Parameter(torch.ones(1))
29
+
30
+ def forward(self, x):
31
+ mean_square = (x ** 2).mean(dim=-1, keepdim=True)
32
+ x = x * torch.rsqrt(mean_square + self.eps) * self.scala
33
+ return x
34
+
35
+
36
+ class BertLayerNorm(nn.Module):
37
+ def __init__(self, hidden_size, eps=1e-5):
38
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
39
+ """
40
+ super(BertLayerNorm, self).__init__()
41
+ self.weight = nn.Parameter(torch.ones(hidden_size))
42
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
43
+ self.variance_epsilon = eps
44
+
45
+ def forward(self, x):
46
+ u = x.mean(-1, keepdim=True)
47
+ s = (x - u).pow(2).mean(-1, keepdim=True)
48
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
49
+ return self.weight * x + self.bias
50
+
51
+
52
+ class ScaledSinuEmbedding(nn.Module):
53
+ def __init__(self, dim):
54
+ super().__init__()
55
+ self.scale = nn.Parameter(torch.ones(1,))
56
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
57
+ self.register_buffer('inv_freq', inv_freq)
58
+
59
+ def forward(self, x):
60
+ n, device = x.shape[1], x.device
61
+ t = torch.arange(n, device = device).type_as(self.inv_freq)
62
+ sinu = torch.einsum('i , j -> i j', t, self.inv_freq)
63
+ emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
64
+ return emb * self.scale
65
+
66
+
67
+ def RoPE(x, dim):
68
+ """
69
+ :param x: input tensor
70
+ :param dim: oprate dimension
71
+ :return: tensor
72
+ """
73
+ shape = x.shape
74
+ if isinstance(dim, int):
75
+ dim = [dim]
76
+
77
+ spatial_shape = [shape[i] for i in dim]
78
+ total_len = 1
79
+ for i in spatial_shape:
80
+ total_len *= i
81
+ position = torch.reshape(torch.arange(total_len, dtype=torch.float, device=x.device), spatial_shape)
82
+
83
+ for i in range(dim[-1] + 1, len(shape) - 1, 1):
84
+ position = torch.unsqueeze(position, dim=-1)
85
+
86
+ half_size = shape[-1] // 2
87
+ freq_seq = -torch.arange(half_size, dtype=torch.float, device=x.device) / float(half_size)
88
+ inv_freq = 10000 ** -freq_seq
89
+ sinusoid = torch.einsum("...,d->...d", position, inv_freq)
90
+ sin = torch.sin(sinusoid)
91
+ cos = torch.cos(sinusoid)
92
+ x1, x2 = torch.chunk(x, 2, dim=-1)
93
+ return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
94
+
95
+
96
+ def rel_pos_bias(seq_len, s):
97
+ a = torch.rand([1, s], dtype=torch.float)
98
+ b = torch.rand([1, s], dtype=torch.float)
99
+ w = torch.rand([2 * seq_len - 1], dtype=torch.float)
100
+ if seq_len <= 512:
101
+ t = F.pad(w[: 2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
102
+ t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
103
+ r = (2 * seq_len - 1) // 2
104
+ t = t[..., r:-r]
105
+ else:
106
+ a = RoPE(a.repeat(seq_len, 1), dim=[0])
107
+ b = RoPE(b.repeat(seq_len, 1), dim=[0])
108
+ t = torch.einsum("mk,nk->mn", a, b)
109
+ return t
110
+
111
+ def squared_relu(x, attention_mask, dim=-1):
112
+ rmask = ~(attention_mask.bool())
113
+ x = x.masked_fill(rmask, 0)
114
+ return torch.square(F.relu(x))
115
+
116
+
117
+ def attention_normalize(a, axis=-1, mask=None, fn='softmax'):
118
+ if fn == 'softmax':
119
+ return XSoftmax.apply(a, mask, axis)
120
+ else:
121
+ mask_ = a > -float('inf') / 10
122
+ # mask_ = mask_.byte()
123
+ mask_ = torch.sum(mask_, axis=axis, keepdim=True)
124
+ l = torch.maximum(mask_, torch.ones_like(mask_))
125
+ if fn == 'relu':
126
+ rmask = ~(mask.bool())
127
+ a = a.masked_fill(rmask, 0)
128
+ return torch.square(F.relu(a)) / l
129
+ elif fn == 'softmax_plus':
130
+ return XSoftmax.apply(a * torch.log(l) / np.log(512), mask, axis)
131
+ return a
132
+
133
+
134
+ class GAULinear(nn.Linear):
135
+ def init_weight(self):
136
+ nn.init.xavier_uniform_(self.weight)
137
+
138
+
139
+ class GatedAttentionUnit(nn.Module):
140
+ """
141
+ GAU Block: Gate Attention Unit
142
+ """
143
+ def __init__(
144
+ self,
145
+ max_seq_length,
146
+ hidden_size,
147
+ attention_key_size=128,
148
+ activation='swish',
149
+ use_bias=True,
150
+ attention_norm_type='squared_relu',
151
+ attention_scale=True,
152
+ dropout=0.1,
153
+ pre_norm=False,
154
+ norm_type="layer_norm",
155
+ eps=1e-5,
156
+ shift_token=False,
157
+ use_rel_bias=False,
158
+ add_residual=True,
159
+ **kwargs,):
160
+
161
+ super(GatedAttentionUnit, self).__init__(**kwargs)
162
+ self.max_seq_length = max_seq_length
163
+ self.units = hidden_size
164
+ self.intermediate_size = self.units * 2
165
+ self.key_size = attention_key_size
166
+ self.activation = activation
167
+ self.use_bias = use_bias
168
+ self.attention_norm_type = attention_norm_type
169
+ self.attention_scale = attention_scale
170
+ self.dropout = StableDropout(dropout)
171
+ self.i_dense = nn.Sequential(
172
+ nn.Linear(self.units, 2 * self.intermediate_size + self.key_size, bias=self.use_bias),
173
+ nn.SiLU()
174
+ )
175
+ self.o_dense = nn.Sequential(
176
+ nn.Linear(self.intermediate_size, self.units, bias=self.use_bias),
177
+ self.dropout)
178
+ self.q_scaleoffset = OffsetScale(self.key_size)
179
+ self.k_scaleoffset = OffsetScale(self.key_size)
180
+ self.pre_norm = pre_norm
181
+ self.norm = (nn.LayerNorm(hidden_size, eps=eps) if norm_type.lower() == "layer_norm" else ScaleNorm(eps=eps))
182
+ self.add_residual = add_residual
183
+
184
+ def forward(self, x, attention_mask=None, **kwargs):
185
+ shortcut = x
186
+
187
+ if self.pre_norm:
188
+ x = self.norm(x)
189
+
190
+ x = self.i_dense(x)
191
+ u, v, qk = torch.split(x, [self.intermediate_size, self.intermediate_size, self.key_size], dim=-1)
192
+ q, k = self.q_scaleoffset(qk), self.k_scaleoffset(qk)
193
+ qk = RoPE(torch.stack([q, k], 2), dim=1)
194
+ q, k = qk[:, :, 0], qk[:, :, 1]
195
+ a = torch.einsum('bmd,bnd->bmn', q, k)
196
+ if self.attention_scale:
197
+ a = a / self.key_size**0.5
198
+ a = sequence_masking(a, attention_mask, '-inf', -1)
199
+ A = attention_normalize(a, -1, fn=self.attention_norm_type)
200
+ if self.dropout:
201
+ A = self.dropout(A)
202
+ out = self.o_dense(u * torch.einsum('bmn,bnd->bmd', A, v))
203
+
204
+ if self.add_residual:
205
+ out = out + shortcut
206
+ if not self.pre_norm:
207
+ out = self.norm(out)
208
+ return out
209
+
210
+
211
+ class OffsetScale(nn.Module):
212
+ def __init__(self, dim, heads = 1):
213
+ super().__init__()
214
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
215
+ self.beta = nn.Parameter(torch.zeros(heads, dim))
216
+ # nn.init.normal_(self.gamma, std = 0.02)
217
+ nn.init.xavier_uniform_(self.gamma)
218
+
219
+ def forward(self, x):
220
+ out = torch.einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
+ if self.gamma.size(0) == 1:
+ # collapse the singleton head axis so single-head callers get a [..., d] tensor back
+ out = out.squeeze(-2)
221
+ return out
222
+
223
+
224
+ class BertIntermediate(nn.Module):
225
+ def __init__(self, config):
226
+ super().__init__()
227
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
228
+ self.intermediate_act_fn = ACT2FN[config.hidden_act] \
229
+ if isinstance(config.hidden_act, str) else config.hidden_act
230
+
231
+ def forward(self, hidden_states):
232
+ hidden_states = self.dense(hidden_states)
233
+ hidden_states = self.intermediate_act_fn(hidden_states)
234
+ return hidden_states
235
+
236
+
237
+ class BertOutput(nn.Module):
238
+ def __init__(self, config):
239
+ super(BertOutput, self).__init__()
240
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
241
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)
242
+ self.dropout = StableDropout(config.hidden_dropout_prob)
243
+ self.config = config
244
+
245
+ def forward(self, hidden_states, input_states, mask=None):
246
+ hidden_states = self.dense(hidden_states)
247
+ hidden_states = self.dropout(hidden_states)
248
+ hidden_states += input_states
249
+ hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
250
+ return hidden_states
251
+
252
+
253
+ class GAU(nn.Module):
254
+ def __init__(self, max_seq_length, hidden_size, expansion_factor=2, s=128, norm_type="layer_norm", eps=1e-5,
255
+ hidden_act="silu", shift_token=False, use_rel_bias=False, attention_norm_type='softmax',
256
+ pre_norm=False, dropout=0, add_residual = True):
257
+ super(GAU, self).__init__()
258
+ self.max_seq_length = max_seq_length
259
+ self.shift_token = shift_token
260
+ hidden_dim = int(expansion_factor * hidden_size)
261
+ self.norm = (nn.LayerNorm(hidden_size, eps=eps) if norm_type == "layer_norm" else ScaleNorm(eps=eps))
262
+ self.use_rel_bias = use_rel_bias
263
+ self.attention_norm_type = attention_norm_type
264
+ # if attention_norm_type == 'relu':
265
+ # self.attention_norm_func = squared_relu
266
+ # else:
267
+ # self.attention_norm_func = XSoftmax.apply
268
+ # self.norm = norm_klass(hidden_size)
269
+
270
+ self.dropout = nn.Dropout(dropout)
271
+
272
+ self.to_hidden = nn.Sequential(
273
+ nn.Linear(hidden_size, hidden_dim * 2),
274
+ nn.SiLU()
275
+ )
276
+
277
+ self.to_qk = nn.Sequential(
278
+ nn.Linear(hidden_size, s),
279
+ nn.SiLU()
280
+ )
281
+
282
+ self.offsetscale = OffsetScale(s, heads = 2)
283
+
284
+ self.to_out = nn.Sequential(
285
+ nn.Linear(hidden_dim, hidden_size),
286
+ nn.Dropout(dropout)
287
+ )
288
+
289
+ self.add_residual = add_residual
290
+ self.act_fn = ACT2FN[hidden_act]
291
+ self.pre_norm = pre_norm
292
+
293
+
294
+ def forward(
295
+ self,
296
+ x,
297
+ relative_pos = None,
298
+ attention_mask = None
299
+ ):
300
+ seq_len, device = x.shape[-2], x.device
301
+ if self.pre_norm:
302
+ normed_x = self.norm(x)
303
+ else:
304
+ normed_x = x
305
+ v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)
306
+
307
+ qk = self.to_qk(normed_x)
308
+ base = self.offsetscale(qk)
309
+ base = RoPE(base, 1).half()
310
+ q, k = base.unbind(dim = -2)
311
+ sim = torch.einsum('b i d, b j d -> b i j', q, k)
312
+
313
+ if relative_pos is not None:
314
+ sim = sim + relative_pos
315
+ if attention_mask is not None:
316
+ if attention_mask.dim() < 3:
317
+ attention_mask = einops.rearrange(attention_mask, 'b j -> b 1 j')
318
+ # attn = attn.masked_fill(~attention_mask.bool(), 0.)
319
+ attn = attention_normalize(sim, mask=attention_mask, fn=self.attention_norm_type)
320
+ # attn = F.relu(sim) ** 2 / seq_len# / q.size(-1)
321
+ # logger.info(attn.max())
322
+ attn = self.dropout(attn)
323
+ # if self.causal:
324
+ # causal_mask = torch.ones((seq_len, seq_len), dtype = torch.bool, device = device).triu(1)
325
+ # attn = attn.masked_fill(causal_mask, 0.)
326
+
327
+ out = torch.einsum('b i j, b j d -> b i d', attn.half(), v)
328
+ out = out * gate
329
+
330
+ out = self.to_out(out)
331
+
332
+ if self.add_residual:
333
+ out = out + x
334
+ if not self.pre_norm:
335
+ out = self.norm(out)
336
+ return out
337
+
338
+
339
+ class GatLayer(nn.Module):
340
+ def __init__(self, config, shift_token=False, use_ffn=False):
341
+ super(GatLayer, self).__init__()
342
+ self.attention = GatedAttentionUnit(config.max_position_embeddings, config.hidden_size,
343
+ shift_token=shift_token, use_rel_bias=config.use_rel_bias,
344
+ norm_type=config.norm_type, attention_norm_type=config.attention_norm_type,
345
+ pre_norm=config.pre_norm, dropout=config.hidden_dropout_prob)
346
+ if use_ffn:
347
+ self.intermediate = BertIntermediate(config)
348
+ self.output = BertOutput(config)
349
+ self.use_ffn = use_ffn
350
+
351
+ def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
352
+ attention_output = self.attention(hidden_states, attention_mask=attention_mask, relative_pos=relative_pos)
353
+ if self.use_ffn:
354
+ intermediate_output = self.intermediate(attention_output)
355
+ layer_output = self.output(intermediate_output, attention_output)
356
+ return layer_output
357
+ else:
358
+ return attention_output
359
+
360
+
361
+ class RelativePositionBias(nn.Module):
362
+ def __init__(
363
+ self,
364
+ scale,
365
+ causal = False,
366
+ num_buckets = 32,
367
+ max_distance = 128
368
+ ):
369
+ super().__init__()
370
+ self.scale = scale
371
+ self.causal = causal
372
+ self.num_buckets = num_buckets
373
+ self.max_distance = max_distance
374
+ self.relative_attention_bias = nn.Embedding(num_buckets, 1)
375
+
376
+ @staticmethod
377
+ def _relative_position_bucket(
378
+ relative_position,
379
+ causal = True,
380
+ num_buckets = 32,
381
+ max_distance = 128
382
+ ):
383
+ ret = 0
384
+ n = -relative_position
385
+ if not causal:
386
+ num_buckets //= 2
387
+ ret += (n < 0).long() * num_buckets
388
+ n = torch.abs(n)
389
+ else:
390
+ n = torch.max(n, torch.zeros_like(n))
391
+
392
+ max_exact = num_buckets // 2
393
+ is_small = n < max_exact
394
+
395
+ val_if_large = max_exact + (
396
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
397
+ ).long()
398
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
399
+
400
+ ret += torch.where(is_small, n, val_if_large)
401
+ return ret
402
+
403
+ def forward(self, x):
404
+ i, j, device = *x.shape[-2:], x.device
405
+ q_pos = torch.arange(i, dtype = torch.long, device = device)
406
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
407
+ rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
408
+ rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
409
+ values = self.relative_attention_bias(rp_bucket)
410
+ bias = einops.rearrange(values, 'i j 1 -> i j')
411
+ return bias * self.scale
412
+
413
+
414
+ class GatEmbeddings(nn.Module):
415
+ """Construct the embeddings from word, position and token_type embeddings.
416
+ """
417
+ def __init__(self, config, with_position=False):
418
+ super(GatEmbeddings, self).__init__()
419
+ self.word_embeddings = nn.Embedding(
420
+ config.vocab_size, config.hidden_size)
421
+ self.token_type_embeddings = nn.Embedding(
422
+ config.type_vocab_size, config.hidden_size)
423
+ self.with_position = with_position
424
+ if with_position:
425
+ self.position_embeddings = ScaledSinuEmbedding(config.hidden_size)
426
+
427
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
428
+ # any TensorFlow checkpoint file
429
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
430
+ self.dropout = StableDropout(config.hidden_dropout_prob)
431
+
432
+ def forward(self, input_ids, token_type_ids=None, position_ids=None, token_mask=None):
433
+ seq_length = input_ids.size(1)
434
+ if position_ids is None:
435
+ position_ids = torch.arange(
436
+ seq_length, dtype=torch.long, device=input_ids.device)
437
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
438
+ if token_type_ids is None:
439
+ token_type_ids = torch.zeros_like(input_ids)
440
+
441
+ words_embeddings = self.word_embeddings(input_ids)
442
+ if self.with_position:
443
+ position_embeddings = self.position_embeddings(words_embeddings)
444
+ else:
445
+ position_embeddings = 0
446
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
447
+
448
+ # if self.num_pos_emb > 1:
449
+ # num_batch = position_embeddings.size(0)
450
+ # num_pos = position_embeddings.size(1)
451
+ # position_embeddings = position_embeddings.view(
452
+ # num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
453
+
454
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
455
+ # if self.fp32_embedding:
456
+ # embeddings = embeddings.half()
457
+ embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, token_mask)
458
+ embeddings = self.dropout(embeddings)
459
+ return {
460
+ 'embeddings': embeddings,
461
+ 'position_embeddings': position_embeddings}
462
+
463
+
464
+ class GatEncoder(nn.Module):
465
+ def __init__(self, config, shift_token=False):
466
+ super().__init__()
467
+ layer = GatLayer(config, shift_token=shift_token)
468
+ self.layer = nn.ModuleList([copy.deepcopy(layer)
469
+ for _ in range(config.num_hidden_layers)])
470
+
471
+ def get_attention_mask(self, attention_mask):
472
+ if attention_mask.dim() <= 2:
473
+ extended_attention_mask = attention_mask.unsqueeze(1)
474
+ attention_mask = extended_attention_mask*extended_attention_mask.squeeze(-2).unsqueeze(-1)
475
+ attention_mask = attention_mask.byte()
476
+ return attention_mask
477
+
478
+ def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False, query_states = None, relative_pos=None):
479
+ all_encoder_layers = []
480
+ att_matrices = []
481
+ if isinstance(hidden_states, Sequence):
482
+ next_kv = hidden_states[0]
483
+ else:
484
+ next_kv = hidden_states
485
+ # rel_embeddings = self.get_rel_embedding()
486
+ for i, layer_module in enumerate(self.layer):
487
+ output_states = layer_module(next_kv, attention_mask, query_states = query_states, relative_pos=relative_pos)
488
+ if return_att:
489
+ output_states, att_m = output_states
490
+
491
+ # if i == 0 and self.with_conv:
492
+ # prenorm = output_states #output['prenorm_states']
493
+ # output_states = self.conv(hidden_states, prenorm, input_mask)
494
+
495
+ if query_states is not None:
496
+ query_states = output_states
497
+ if isinstance(hidden_states, Sequence):
498
+ next_kv = hidden_states[i+1] if i+1 < len(self.layer) else None
499
+ else:
500
+ next_kv = output_states
501
+
502
+ if output_all_encoded_layers:
503
+ all_encoder_layers.append(output_states)
504
+ if return_att:
505
+ att_matrices.append(att_m)
506
+ if not output_all_encoded_layers:
507
+ all_encoder_layers.append(output_states)
508
+ if return_att:
509
+ att_matrices.append(att_m)
510
+ return {
511
+ 'hidden_states': all_encoder_layers,
512
+ 'attention_matrices': att_matrices
513
+ }
514
+
515
+
516
+ class GatModel(torch.nn.Module):
517
+ """
518
+ Parameters:
519
+ config:
520
+ A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`,
521
+
522
+ pre_trained:
523
+ The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or a released configurations,
524
+ i.e. [**base, large, base_mnli, large_mnli**]
525
+
526
+ """
527
+
528
+ def __init__(self, config=None, pre_trained=None, pooler=False, shift_token=False, causal=False, **kwargs):
529
+ super().__init__()
530
+ state = None
531
+ if pre_trained is not None:
532
+ state, model_config = load_model_state(pre_trained)
533
+ if config is not None and model_config is not None:
534
+ for k in config.__dict__:
535
+ if k not in ['hidden_size',
536
+ 'intermediate_size',
537
+ 'num_attention_heads',
538
+ 'num_hidden_layers',
539
+ 'vocab_size',
540
+ 'max_position_embeddings']:
541
+ model_config.__dict__[k] = config.__dict__[k]
542
+ config = copy.copy(model_config)
543
+ self.embeddings = GatEmbeddings(config, with_position=True)
544
+ self.encoder = GatEncoder(config, shift_token=shift_token)
545
+ if not pooler:
546
+ self.pooler = None
547
+ self.config = config
548
+ self.pre_trained = pre_trained
549
+ self.apply_state(state)
550
+
551
+ def get_attention_mask(self, input_ids=None, token_type_ids=None, attention_mask=None, input_mask=None):
552
+ if attention_mask is None:
553
+ if input_mask is not None:
554
+ return input_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), input_mask.size(1))
555
+ else:
556
+ return torch.ones_like(input_ids, dtype=torch.uint8).unsqueeze(-1).expand(input_ids.size(0), input_ids.size(1), input_ids.size(1))
557
+ else:
558
+ if attention_mask.dim() == 2:
559
+ if input_mask is not None:
560
+ attention_mask = attention_mask * input_mask
561
+ return attention_mask.unsqueeze(-1).expand(attention_mask.size(0), attention_mask.size(1), attention_mask.size(-1))
562
+ if attention_mask.dim() == 4:
563
+ attention_mask = attention_mask.squeeze(2)
564
+ if attention_mask.dim() == 3:
565
+ if input_mask is not None:
566
+ return attention_mask * input_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), attention_mask.size(-1))
567
+ else:
568
+ return attention_mask
569
+
570
+
571
+ def forward(self, input_ids, input_mask, attention_mask=None, token_type_ids=None,
572
+ output_all_encoded_layers=True, position_ids=None, return_att=False):
573
+ """
574
+ Args:
575
+ input_ids:
576
+ a torch.LongTensor of shape [batch_size, sequence_length] \
577
+ with the word token indices in the vocabulary
578
+
579
+ attention_mask:
580
+ an optional parameter for input mask or attention mask.
581
+
582
+ - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \
583
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
584
+ input sequence length in the current batch. It's the mask that we typically use for attention when \
585
+ a batch has varying length sentences.
586
+
587
+ - If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
588
+ In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.
589
+
590
+ token_type_ids:
591
+ an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
592
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
593
+ a `sentence B` token (see BERT paper for more details).
594
+
595
+ output_all_encoded_layers:
596
+ whether to output the results of all encoder layers, default: `True`
597
+
598
+ Returns:
599
+
600
+ - A dictionary whose `hidden_states` entry holds the outputs of all stacked transformer layers if `output_all_encoded_layers=True`, else \
601
+ only the last layer of the stacked transformer layers
602
+
603
+ - The `attention_matrices` entry holds the attention matrices of the self-attention layers if `return_att=True`
604
+
605
+
606
+ Example::
607
+
608
+ # Batch of wordPiece token ids.
609
+ # Each sample was padded with zero to the maximum length of the batch
610
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
611
+ # Mask of valid input ids
612
+ attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
613
+
614
+ # GatModel initialized with a pre-trained base model
615
+ model = GatModel(pre_trained='base')
616
+
617
+ encoder_layers = model(input_ids, input_mask=attention_mask)['hidden_states']
618
+
619
+ """
620
+ if token_type_ids is None:
621
+ token_type_ids = torch.zeros_like(input_ids)
622
+ # input_mask = torch.ones_like(input_ids)
623
+
624
+ if input_mask is None:
625
+ idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), axis=1), [-1])
626
+ input_mask = idxs > 0
627
+ if not torch.any(input_mask):
628
+ input_mask = torch.ones_like(input_ids)
629
+ input_mask = input_mask.byte()
630
+ attention_mask = self.get_attention_mask(input_ids, token_type_ids, attention_mask, input_mask)
631
+ attention_mask = attention_mask.byte()
632
+ embedding_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, input_mask)
633
+ encoder_output = self.encoder(embedding_output['embeddings'], attention_mask, output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
634
+ encoder_output.update(embedding_output)
635
+ return encoder_output
636
+
637
+ def apply_state(self, state = None):
638
+ """ Load state from previous loaded model state dictionary.
639
+
640
+ Args:
641
+ state (:obj:`dict`, optional): State dictionary as the state returned by torch.module.state_dict(), default: `None`. \
642
+ If it's `None`, then will use the pre-trained state loaded via the constructor to re-initialize \
643
+ the `DeBERTa` model
644
+ """
645
+ if self.pre_trained is None and state is None:
646
+ return
647
+ if state is None:
648
+ state, config = load_model_state(self.pre_trained)
649
+ self.config = config
650
+
651
+ prefix = ''
652
+ for k in state:
653
+ if 'embeddings.' in k:
654
+ if not k.startswith('embeddings.'):
655
+ prefix = k[:k.index('embeddings.')]
656
+ break
657
+
658
+ missing_keys = []
659
+ unexpected_keys = []
660
+ error_msgs = []
661
+ self._load_from_state_dict(state, prefix = prefix, local_metadata={}, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)
662
+
663
+
664
+ if __name__ == '__main__':
665
+ model = GatModel(768, 64)
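
A minimal usage sketch for `GatModel`, mirroring the forward docstring above; the `'base'` release name and the import location are assumptions here, not something this commit guarantees:

    import torch
    # from .gat import GatModel  # assumed import location within this package

    # Batch of token ids, zero-padded to the longest sequence in the batch
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    # 1 for real tokens, 0 for padding
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

    model = GatModel(pre_trained='base')        # assumed release name
    outputs = model(input_ids, input_mask)      # forward returns a dict
    hidden_states = outputs['hidden_states']    # list of per-layer states
    embeddings = outputs['embeddings']          # output of GatEmbeddings
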
modeling/gpt2_bpe_utils.py ADDED
@@ -0,0 +1,163 @@
1
+ """
2
+ Byte pair encoding utilities from GPT-2.
3
+
4
+ Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
5
+ Original license: MIT
6
+ """
7
+
8
+ from functools import lru_cache
9
+ import json
10
+ import random
11
+ import unicodedata
12
+
13
+ try:
14
+ import regex as re
15
+ except ImportError:
16
+ raise ImportError('Please install regex with: pip install regex')
17
+
18
+ @lru_cache()
19
+ def bytes_to_unicode():
20
+ """
21
+ Returns a dict mapping utf-8 bytes to unicode strings.
22
+ The reversible bpe codes work on unicode strings.
23
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
24
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
25
+ This is a significant percentage of your normal, say, 32K bpe vocab.
26
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
27
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
28
+ """
29
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
30
+ cs = bs[:]
31
+ n = 0
32
+ for b in range(2**8):
33
+ if b not in bs:
34
+ bs.append(b)
35
+ cs.append(2**8+n)
36
+ n += 1
37
+ cs = [chr(n) for n in cs]
38
+ return dict(zip(bs, cs))
39
+
40
+ def get_pairs(word):
41
+ """Return set of symbol pairs in a word.
42
+ Word is represented as tuple of symbols (symbols being variable-length strings).
43
+ """
44
+ pairs = set()
45
+ prev_char = word[0]
46
+ for char in word[1:]:
47
+ pairs.add((prev_char, char))
48
+ prev_char = char
49
+ return pairs
50
+
51
+ class Encoder:
52
+
53
+ def __init__(self, encoder, bpe_merges, errors='replace'):
54
+ self.encoder = encoder
55
+ self.decoder = {v:k for k,v in self.encoder.items()}
56
+ self.errors = errors # how to handle errors in decoding
57
+ self.byte_encoder = bytes_to_unicode()
58
+ self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
59
+ self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges))))
60
+ self.cache = {}
61
+ self.random = random.Random(0)
62
+
63
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
64
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
65
+
66
+ def bpe(self, token):
67
+ if token in self.cache:
68
+ return self.cache[token]
69
+ word = tuple(token)
70
+ pairs = get_pairs(word)
71
+
72
+ if not pairs:
73
+ return token
74
+
75
+ while True:
76
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
77
+ if bigram not in self.bpe_ranks:
78
+ break
79
+ first, second = bigram
80
+ new_word = []
81
+ i = 0
82
+ while i < len(word):
83
+ try:
84
+ j = word.index(first, i)
85
+ new_word.extend(word[i:j])
86
+ i = j
87
+ except:
88
+ new_word.extend(word[i:])
89
+ break
90
+
91
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
92
+ new_word.append(first+second)
93
+ i += 2
94
+ else:
95
+ new_word.append(word[i])
96
+ i += 1
97
+ new_word = tuple(new_word)
98
+ word = new_word
99
+ if len(word) == 1:
100
+ break
101
+ else:
102
+ pairs = get_pairs(word)
103
+ word = ' '.join(word)
104
+ self.cache[token] = word
105
+ return word
106
+
107
+ def split_to_words(self, text):
108
+ return list(re.findall(self.pat, text))
109
+
110
+ def encode(self, text):
111
+ bpe_tokens = []
112
+ for token in self.split_to_words(text):
113
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
114
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
115
+ return bpe_tokens
116
+
117
+ def decode(self, tokens):
118
+ text = ''.join([self.decoder[token] for token in tokens])
119
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
120
+ return text
121
+
122
+ def get_encoder(encoder, vocab):
123
+ return Encoder(
124
+ encoder=encoder,
125
+ bpe_merges=vocab,
126
+ )
127
+
128
+ def _is_whitespace(char):
129
+ """Checks whether `chars` is a whitespace character."""
130
+ # \t, \n, and \r are technically control characters but we treat them
131
+ # as whitespace since they are generally considered as such.
132
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
133
+ return True
134
+ cat = unicodedata.category(char)
135
+ if cat == "Zs":
136
+ return True
137
+ return False
138
+
139
+ def _is_control(char):
140
+ """Checks whether `chars` is a control character."""
141
+ # These are technically control characters but we count them as whitespace
142
+ # characters.
143
+ if char == "\t" or char == "\n" or char == "\r":
144
+ return False
145
+ cat = unicodedata.category(char)
146
+ if cat.startswith("C"):
147
+ return True
148
+ return False
149
+
150
+ def _is_punctuation(char):
151
+ """Checks whether `chars` is a punctuation character."""
152
+ cp = ord(char)
153
+ # We treat all non-letter/number ASCII as punctuation.
154
+ # Characters such as "^", "$", and "`" are not in the Unicode
155
+ # Punctuation class but we treat them as punctuation anyways, for
156
+ # consistency.
157
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
158
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
159
+ return True
160
+ cat = unicodedata.category(char)
161
+ if cat.startswith("P"):
162
+ return True
163
+ return False
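
The `bytes_to_unicode` map above is what makes the BPE vocabulary byte-level: it is a bijection over all 256 byte values, so any UTF-8 string can be encoded and decoded losslessly. A small standalone sanity check using only the helpers in this file:

    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}

    text = "Hello world!"
    encoded = ''.join(byte_encoder[b] for b in text.encode('utf-8'))
    decoded = bytearray(byte_decoder[c] for c in encoded).decode('utf-8')
    assert len(byte_encoder) == 256 and decoded == text
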
modeling/gpt2_tokenizer.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Microsoft, Inc. 2020
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Zhou Bo
8
+ # Date: 01/15/2020
9
+ #
10
+
11
+ # This piece of code is derived from https://github.com/pytorch/fairseq/blob/master/fairseq/data/encoders/gpt2_bpe.py
12
+
13
+ import torch
14
+ import unicodedata
15
+ import os
16
+ from .gpt2_bpe_utils import get_encoder,_is_control,_is_whitespace,_is_punctuation
17
+ from .cache_utils import load_vocab
18
+
19
+ __all__ = ['GPT2Tokenizer']
20
+
21
+ class GPT2Tokenizer(object):
22
+ """ A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer
23
+
24
+ Args:
25
+
26
+ vocab_file (:obj:`str`, optional):
27
+ The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, \
28
+ e.g. "bpe_encoder", default: `None`.
29
+
30
+ If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file is a \
31
+ state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. \
32
+
33
+ The differences between our wrapped GPT2 tokenizer and the RoBERTa wrapped tokenizer are:
34
+
35
+ - Special tokens, unlike `RoBERTa` which use `<s>`, `</s>` as the `start` token and `end` token of a sentence. We use `[CLS]` and `[SEP]` as the `start` and `end`\
36
+ token of input sentence which is the same as `BERT`.
37
+
38
+ - We remapped the token ids in our dictionary with regard to the new special tokens, `[PAD]` => 0, `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264
39
+
40
+ do_lower_case (:obj:`bool`, optional):
41
+ Whether to convert inputs to lower case. **Not used in GPT2 tokenizer**.
42
+
43
+ special_tokens (:obj:`list`, optional):
44
+ List of special tokens to be added to the end of the vocabulary.
45
+
46
+
47
+ """
48
+ def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None):
49
+ self.pad_token='[PAD]'
50
+ self.sep_token='[SEP]'
51
+ self.unk_token='[UNK]'
52
+ self.cls_token='[CLS]'
53
+
54
+ self.symbols = []
55
+ self.count = []
56
+ self.indices = {}
57
+ self.pad_token_id = self.add_symbol(self.pad_token)
58
+ self.cls_token_id = self.add_symbol(self.cls_token)
59
+ self.sep_token_id = self.add_symbol(self.sep_token)
60
+ self.unk_token_id = self.add_symbol(self.unk_token)
61
+
62
+ self.gpt2_encoder = torch.load(vocab_file)
63
+ self.bpe = get_encoder(self.gpt2_encoder['encoder'], self.gpt2_encoder['vocab'])
64
+ for w,n in self.gpt2_encoder['dict_map']:
65
+ self.add_symbol(w, n)
66
+
67
+ self.mask_token='[MASK]'
68
+ self.mask_id = self.add_symbol(self.mask_token)
69
+ self.special_tokens = ['[MASK]', '[SEP]', '[PAD]', '[UNK]', '[CLS]']
70
+ if special_tokens is not None:
71
+ for t in special_tokens:
72
+ self.add_special_token(t)
73
+
74
+ self.vocab = self.indices
75
+ self.ids_to_tokens = self.symbols
76
+
77
+ def tokenize(self, text):
78
+ """ Convert an input text to tokens.
79
+
80
+ Args:
81
+
82
+ text (:obj:`str`): input text to be tokenized.
83
+
84
+ Returns:
85
+ A list of byte tokens where each token represents the byte id in the GPT2 byte dictionary
86
+
87
+ Example::
88
+
89
+ >>> tokenizer = GPT2Tokenizer()
90
+ >>> text = "Hello world!"
91
+ >>> tokens = tokenizer.tokenize(text)
92
+ >>> print(tokens)
93
+ ['15496', '995', '0']
94
+
95
+ """
96
+ bpe = self._encode(text)
97
+
98
+ return [t for t in bpe.split(' ') if t]
99
+
100
+ def convert_tokens_to_ids(self, tokens):
101
+ """ Convert list of tokens to ids.
102
+
103
+ Args:
104
+
105
+ tokens (:obj:`list<str>`): list of tokens
106
+
107
+ Returns:
108
+
109
+ List of ids
110
+ """
111
+
112
+ return [self.vocab[t] for t in tokens]
113
+
114
+ def convert_ids_to_tokens(self, ids):
115
+ """ Convert list of ids to tokens.
116
+
117
+ Args:
118
+
119
+ ids (:obj:`list<int>`): list of ids
120
+
121
+ Returns:
122
+
123
+ List of tokens
124
+ """
125
+
126
+ tokens = []
127
+ for i in ids:
128
+ tokens.append(self.ids_to_tokens[i])
129
+ return tokens
130
+
131
+ def split_to_words(self, text):
132
+ return self.bpe.split_to_words(text)
133
+
134
+ def decode(self, tokens):
135
+ """ Decode list of tokens to text strings.
136
+
137
+ Args:
138
+
139
+ tokens (:obj:`list<str>`): list of tokens.
140
+
141
+ Returns:
142
+
143
+ The text string corresponding to the input tokens.
144
+
145
+ Example::
146
+
147
+ >>> tokenizer = GPT2Tokenizer()
148
+ >>> text = "Hello world!"
149
+ >>> tokens = tokenizer.tokenize(text)
150
+ >>> print(tokens)
151
+ ['15496', '995', '0']
152
+
153
+ >>> tokenizer.decode(tokens)
154
+ 'Hello world!'
155
+
156
+ """
157
+ return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens])
158
+
159
+ def add_special_token(self, token):
160
+ """Adds a special token to the dictionary.
161
+
162
+ Args:
163
+ token (:obj:`str`): The new token/word to be added to the vocabulary.
164
+
165
+ Returns:
166
+ The id of new token in the vocabulary.
167
+
168
+ """
169
+ self.special_tokens.append(token)
170
+ return self.add_symbol(token)
171
+
172
+ def part_of_whole_word(self, token, is_bos=False):
173
+ if is_bos:
174
+ return True
175
+ s = self._decode(token)
176
+ if (len(s)==1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0]))):
177
+ return False
178
+
179
+ return not s.startswith(' ')
180
+
181
+ def sym(self, id):
182
+ return self.ids_to_tokens[id]
183
+
184
+ def id(self, sym):
185
+ return self.vocab[sym]
186
+
187
+ def _encode(self, x: str) -> str:
188
+ return ' '.join(map(str, self.bpe.encode(x)))
189
+
190
+ def _decode(self, x: str) -> str:
191
+ return self.bpe.decode(map(int, x.split()))
192
+
193
+ def add_symbol(self, word, n=1):
194
+ """Adds a word to the dictionary.
195
+
196
+ Args:
197
+ word (:obj:`str`): The new token/word to be added to the vocabulary.
198
+ n (int, optional): The frequency of the word.
199
+
200
+ Returns:
201
+ The id of the new word.
202
+
203
+ """
204
+ if word in self.indices:
205
+ idx = self.indices[word]
206
+ self.count[idx] = self.count[idx] + n
207
+ return idx
208
+ else:
209
+ idx = len(self.symbols)
210
+ self.indices[word] = idx
211
+ self.symbols.append(word)
212
+ self.count.append(n)
213
+ return idx
214
+
215
+ def save_pretrained(self, path: str):
216
+ torch.save(self.gpt2_encoder, path)
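
A hedged round-trip sketch for the wrapper above; `./bpe_encoder.bin` is a placeholder for the vocabulary package described in the class docstring:

    tokenizer = GPT2Tokenizer(vocab_file='./bpe_encoder.bin')  # placeholder path
    tokens = tokenizer.tokenize("Hello world!")      # byte-level BPE tokens, e.g. ['15496', '995', '0']
    ids = tokenizer.convert_tokens_to_ids(tokens)    # remapped ids ([PAD]=0, [CLS]=1, [SEP]=2, ...)
    text = tokenizer.decode(tokens)                  # 'Hello world!'
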
modeling/mlm.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2
+ # Copyright (c) Microsoft, Inc. 2020
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This piece of code is modified based on https://github.com/huggingface/transformers
8
+
9
+ import torch
10
+ from torch import nn
11
+ import pdb
12
+
13
+ from .bert import LayerNorm,ACT2FN
+ from .ops import MaskedLayerNorm
14
+
15
+ __all__ = ['MLMPredictionHead']
16
+
17
+ class MLMPredictionHead(nn.Module):
18
+ def __init__(self, config, vocab_size):
19
+ super().__init__()
20
+ self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
21
+ self.dense = nn.Linear(config.hidden_size, self.embedding_size)
22
+ self.transform_act_fn = ACT2FN[config.hidden_act] \
23
+ if isinstance(config.hidden_act, str) else config.hidden_act
24
+
25
+ self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps)
26
+ self.bias = nn.Parameter(torch.zeros(vocab_size))
27
+ self.pre_norm = PreLayerNorm(config)
28
+
29
+ def forward(self, hidden_states, embeding_weight):
30
+ hidden_states = self.pre_norm(hidden_states)
31
+ hidden_states = self.dense(hidden_states)
32
+ hidden_states = self.transform_act_fn(hidden_states)
33
+ # b x s x d
34
+ hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
35
+
36
+ # b x s x v
37
+ logits = torch.matmul(hidden_states, embeding_weight.t().to(hidden_states)) + self.bias
38
+ return logits
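
The head projects encoder states back onto the (tied) input embedding matrix. A shape-only sketch of the final projection, with the dense/activation/LayerNorm steps omitted and all sizes arbitrary:

    import torch

    batch, seq_len, hidden, vocab = 2, 8, 16, 100
    hidden_states = torch.randn(batch, seq_len, hidden)    # b x s x d, as in the comments above
    embedding_weight = torch.randn(vocab, hidden)           # shared with the input embeddings
    bias = torch.zeros(vocab)

    logits = torch.matmul(hidden_states, embedding_weight.t()) + bias   # b x s x v
    assert logits.shape == (batch, seq_len, vocab)
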
modeling/modeling.py ADDED
The diff for this file is too large to render. See raw diff
 
modeling/nnmodule.py ADDED
@@ -0,0 +1,184 @@
1
+ import pdb
2
+ import os
3
+ import torch
4
+ import copy
5
+ from torch import nn, tensor
6
+ from .config import ModelConfig
7
+ from ..utils import xtqdm as tqdm
8
+ from .cache_utils import load_model_state
9
+ from .flash import GAULinear
10
+
11
+ from ..utils import get_logger
12
+ logger = get_logger()
13
+
14
+ __all__ = ['NNModule']
15
+
16
+ def truncated_normal_(shape, mean=0, std=0.09):
17
+ with torch.no_grad():
18
+ tensor = torch.zeros(shape)
19
+ tmp = tensor.new_empty(shape + (4,)).normal_()
20
+ valid = (tmp < 2) & (tmp > -2)
21
+ ind = valid.max(-1, keepdim=True)[1]
22
+ tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
23
+ tensor.data.mul_(std).add_(mean)
24
+ return tensor
25
+
26
+ class NNModule(nn.Module):
27
+ """ An abstract class to handle weights initialization and \
28
+ a simple interface for downloading and loading pretrained models.
29
+
30
+ Args:
31
+
32
+ config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config to the module
33
+
34
+ """
35
+
36
+ def __init__(self, config, *inputs, **kwargs):
37
+ super().__init__()
38
+ self.config = config
39
+
40
+ def init_weights(self, module):
41
+ """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.
42
+
43
+ Args:
44
+
45
+ module (:obj:`torch.nn.Module`): The module to apply the initialization.
46
+
47
+ Example::
48
+
49
+ class MyModule(NNModule):
50
+ def __init__(self, config):
51
+ # Add construction instructions
52
+ self.bert = DeBERTa(config)
53
+
54
+ # Add other modules
55
+ ...
56
+
57
+ # Apply initialization
58
+ self.apply(self.init_weights)
59
+
60
+ """
61
+ if isinstance(module, (nn.Linear, nn.Embedding)):
62
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
63
+ if isinstance(module, nn.Linear) and module.bias is not None:
64
+ module.bias.data.zero_()
65
+
66
+ def init_weights_gau(self, module):
67
+ """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.
68
+
69
+ Args:
70
+
71
+ module (:obj:`torch.nn.Module`): The module to apply the initialization.
72
+
73
+ Example::
74
+
75
+ class MyModule(NNModule):
76
+ def __init__(self, config):
77
+ # Add construction instructions
78
+ self.bert = DeBERTa(config)
79
+
80
+ # Add other modules
81
+ ...
82
+
83
+ # Apply initialization
84
+ self.apply(self.init_weights)
85
+
86
+ """
87
+ if isinstance(module, GAULinear):
88
+ module.init_weight()
89
+ else:
90
+ if isinstance(module, (nn.Linear, nn.Embedding)):
91
+ # module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
92
+ module.weight.data.copy_(self.initializer(module.weight.data.shape))
93
+ if isinstance(module, nn.Linear) and module.bias is not None:
94
+ module.bias.data.zero_()
95
+
96
+ def initializer(self, shape, dtype=None, order=3, gain=1.0):
97
+ if shape[1] > 10000 or shape[1] < 10:
98
+ hidden_size = shape[0]
99
+ else:
100
+ hidden_size = shape[1]
101
+ gain *= self.config.num_hidden_layers ** (-1.0 / order)
102
+ stddev = 1.13684723 / hidden_size**0.5 * gain
103
+ return torch.nn.init.trunc_normal_(torch.empty(shape, dtype=dtype), std=stddev)# truncated_normal_(shape, std=stddev)
104
+
105
+ @classmethod
106
+ def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None , *inputs, **kwargs):
107
+ """ Instantiate a sub-class of NNModule from a pre-trained model file.
108
+
109
+ Args:
110
+
111
+ model_path (:obj:`str`): Path or name of the pre-trained model which can be either,
112
+
113
+ - The path of pre-trained model
114
+
115
+ - The pre-trained DeBERTa model name in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].
116
+
117
+ If `model_path` is `None` or `-`, then the method will create a new sub-class without initialing from pre-trained models.
118
+
119
+ model_config (:obj:`str`): The path of the model config file. If it's `None`, then the method will try to find the config in the following order:
120
+
121
+ 1. ['config'] in the model state dictionary.
122
+
123
+ 2. `model_config.json` alongside the `model_path`.
124
+
125
+ If no config can be found, the method will fail.
126
+
127
+ tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.
128
+
129
+ no_cache (:obj:`bool`, optional): Disable local cache of downloaded models, default: `False`.
130
+
131
+ cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa`
132
+
133
+ Return:
134
+
135
+ :obj:`NNModule` : The sub-class object.
136
+
137
+ """
138
+ # Load config
139
+ if model_config:
140
+ config = ModelConfig.from_json_file(model_config)
141
+ else:
142
+ config = None
143
+ model_config = None
144
+ model_state = None
145
+ if (model_path is not None) and (model_path.strip() == '-' or model_path.strip()==''):
146
+ model_path = None
147
+ try:
148
+ model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
149
+ except Exception as exp:
150
+ raise Exception(f'Failed to get model {model_path}. Exception: {exp}')
151
+
152
+ if config is not None and model_config is not None:
153
+ for k in config.__dict__:
154
+ if k not in ['hidden_size',
155
+ 'intermediate_size',
156
+ 'num_attention_heads',
157
+ 'num_hidden_layers',
158
+ 'vocab_size',
159
+ 'max_position_embeddings'] or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
160
+ model_config.__dict__[k] = config.__dict__[k]
161
+ if model_config is not None:
162
+ config = copy.copy(model_config)
163
+ vocab_size = config.vocab_size
164
+ # Instantiate model.
165
+ model = cls(config, *inputs, **kwargs)
166
+ if not model_state:
167
+ return model
168
+ # copy state_dict so _load_from_state_dict can modify it
169
+ state_dict = model_state.copy()
170
+
171
+ missing_keys = []
172
+ unexpected_keys = []
173
+ error_msgs = []
174
+ metadata = getattr(state_dict, '_metadata', None)
175
+ def load(module, prefix=''):
176
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
177
+ module._load_from_state_dict(
178
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
179
+ for name, child in module._modules.items():
180
+ if child is not None:
181
+ load(child, prefix + name + '.')
182
+ load(model)
183
+ logger.warning(f'Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}, error_msgs: {error_msgs}')
184
+ return model
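
A hedged sketch of how a task model would typically subclass `NNModule` and use `load_model`; `TinyClassifier` and both paths are placeholders:

    from torch import nn

    class TinyClassifier(NNModule):                     # placeholder subclass
        def __init__(self, config, num_labels=2):
            super().__init__(config)
            self.classifier = nn.Linear(config.hidden_size, num_labels)
            self.apply(self.init_weights)               # Gaussian init as documented above

    # model = TinyClassifier.load_model('/path/to/pytorch_model.bin',
    #                                   model_config='/path/to/model_config.json')
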
modeling/ops.py ADDED
@@ -0,0 +1,274 @@
1
+ # Copyright (c) Microsoft, Inc. 2020
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Zhou Bo
7
+ # Date: 01/15/2020
8
+ #
9
+
10
+ import pdb
11
+ import math
12
+ from packaging import version
13
+ import torch
14
+ from torch.nn import LayerNorm
15
+ from wywLM.utils.jit_tracing import traceable
16
+
17
+ if version.Version(torch.__version__) >= version.Version('1.0.0'):
18
+ from torch import _softmax_backward_data as _softmax_backward_data
19
+ else:
20
+ from torch import softmax_backward_data as _softmax_backward_data
21
+
22
+ __all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax', 'ACT2FN', 'LayerNorm']
23
+
24
+ @traceable
25
+ class XSoftmax(torch.autograd.Function):
26
+ """ Masked Softmax which is optimized for saving memory
27
+
28
+ Args:
29
+
30
+ input (:obj:`torch.tensor`): The input tensor that will apply softmax.
31
+ mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
32
+ dim (int): The dimension along which to apply softmax.
33
+
34
+ Example::
35
+
36
+ import torch
37
+ from DeBERTa.deberta import XSoftmax
38
+ # Make a tensor
39
+ x = torch.randn([4,20,100])
40
+ # Create a mask
41
+ mask = (x>0).int()
42
+ y = XSoftmax.apply(x, mask, dim=-1)
43
+
44
+ """
45
+
46
+ @staticmethod
47
+ def forward(self, input, mask, dim):
48
+ """
49
+ """
50
+
51
+ self.dim = dim
52
+ if mask is None:
53
+ mask = torch.ones_like(input)
54
+ if version.Version(torch.__version__) >= version.Version('1.2.0a'):
55
+ rmask = ~(mask.bool())
56
+ else:
57
+ rmask = (1-mask).byte() # This line is not supported by Onnx tracing.
58
+
59
+ output = input.masked_fill(rmask, torch.finfo(input.dtype).min) # float('-inf')
60
+ output = torch.softmax(output, self.dim)
61
+ output.masked_fill_(rmask, 0)
62
+ self.save_for_backward(output)
63
+ return output
64
+
65
+ @staticmethod
66
+ def backward(self, grad_output):
67
+ """
68
+ """
69
+
70
+ output, = self.saved_tensors
71
+ if version.Version(torch.__version__) >= version.Version('1.11'):
72
+ inputGrad = _softmax_backward_data(grad_output, output, self.dim, output.dtype)
73
+ else:
74
+ inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
75
+ return inputGrad, None, None
76
+
77
+ @staticmethod
78
+ def symbolic(g, self, mask, dim):
79
+ import torch.onnx.symbolic_helper as sym_help
80
+ from torch.onnx.symbolic_opset9 import masked_fill, softmax
81
+
82
+ mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx['Long'])
83
+ r_mask = g.op("Cast", g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx['Byte'])
84
+ output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float('-inf'))))
85
+ output = softmax(g, output, dim)
86
+ return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
87
+
88
+ class DropoutContext(object):
89
+ def __init__(self):
90
+ self.dropout = 0
91
+ self.mask = None
92
+ self.scale = 1
93
+ self.reuse_mask = True
94
+
95
+ def get_mask(input, local_context):
96
+ if not isinstance(local_context, DropoutContext):
97
+ dropout = local_context
98
+ mask = None
99
+ else:
100
+ dropout = local_context.dropout
101
+ dropout *= local_context.scale
102
+ mask = local_context.mask if local_context.reuse_mask else None
103
+
104
+ if dropout>0 and mask is None:
105
+ if version.Version(torch.__version__) >= version.Version('1.2.0a'):
106
+ mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).bool()
107
+ else:
108
+ mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).byte()
109
+
110
+ if isinstance(local_context, DropoutContext):
111
+ if local_context.mask is None:
112
+ local_context.mask = mask
113
+
114
+ return mask, dropout
115
+
116
+ @traceable
117
+ class XDropout(torch.autograd.Function):
118
+ @staticmethod
119
+ def forward(ctx, input, local_ctx):
120
+ mask, dropout = get_mask(input, local_ctx)
121
+ ctx.scale=1.0/(1-dropout)
122
+ if dropout>0:
123
+ ctx.save_for_backward(mask)
124
+ return input.masked_fill(mask, 0)*ctx.scale
125
+ else:
126
+ return input
127
+
128
+ @staticmethod
129
+ def backward(ctx, grad_output):
130
+ if ctx.scale > 1:
131
+ mask, = ctx.saved_tensors
132
+ return grad_output.masked_fill(mask, 0)*ctx.scale, None
133
+ else:
134
+ return grad_output, None
135
+
136
+ class StableDropout(torch.nn.Module):
137
+ """ Optimized dropout module for stabilizing the training
138
+
139
+ Args:
140
+
141
+ drop_prob (float): the dropout probabilities
142
+
143
+ """
144
+
145
+ def __init__(self, drop_prob):
146
+ super().__init__()
147
+ self.drop_prob = drop_prob
148
+ self.count = 0
149
+ self.context_stack = None
150
+
151
+ def forward(self, x):
152
+ """ Call the module
153
+
154
+ Args:
155
+
156
+ x (:obj:`torch.tensor`): The input tensor to apply dropout
157
+
158
+
159
+ """
160
+ if self.training and self.drop_prob>0:
161
+ return XDropout.apply(x, self.get_context())
162
+ return x
163
+
164
+ def clear_context(self):
165
+ self.count = 0
166
+ self.context_stack = None
167
+
168
+ def init_context(self, reuse_mask=True, scale = 1):
169
+ if self.context_stack is None:
170
+ self.context_stack = []
171
+ self.count = 0
172
+ for c in self.context_stack:
173
+ c.reuse_mask = reuse_mask
174
+ c.scale = scale
175
+
176
+ def get_context(self):
177
+ if self.context_stack is not None:
178
+ if self.count >= len(self.context_stack):
179
+ self.context_stack.append(DropoutContext())
180
+ ctx = self.context_stack[self.count]
181
+ ctx.dropout = self.drop_prob
182
+ self.count += 1
183
+ return ctx
184
+ else:
185
+ return self.drop_prob
186
+
187
+ def MaskedLayerNorm(layerNorm, input, mask = None):
188
+ """ Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm module.
189
+
190
+ Args:
191
+ layernorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function
192
+ input (:obj:`torch.tensor`): The input tensor
193
+ mask (:obj:`torch.IntTensor`): The mask applied to the output of LayerNorm, where `0` indicates that the output of that element will be ignored, i.e. set to `0`
194
+
195
+ Example::
196
+
197
+ # Create a tensor b x n x d
198
+ x = torch.randn([1,10,100])
199
+ m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int)
200
+ LayerNorm = DeBERTa.deberta.LayerNorm(100)
201
+ y = MaskedLayerNorm(LayerNorm, x, m)
202
+
203
+ """
204
+ output = layerNorm(input).to(input)
205
+ if mask is None:
206
+ return output
207
+ if mask.dim()!=input.dim():
208
+ if mask.dim()==4:
209
+ mask=mask.squeeze(1).squeeze(1)
210
+ mask = mask.unsqueeze(2)
211
+ mask = mask.to(output.dtype)
212
+ return output*mask
213
+
214
+ def gelu(x):
215
+ """Implementation of the gelu activation function.
216
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
217
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
218
+ """
219
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
220
+
221
+
222
+ def swish(x):
223
+ return x * torch.sigmoid(x)
224
+
225
+ def linear_act(x):
226
+ return x
227
+
228
+ def sequence_masking(x, mask, value=0, axis=None):
229
+ """为序列条件mask的函数
230
+ mask: 形如(batch_size, seq_len)的0-1矩阵;
231
+ value: mask部分要被替换成的值,可以是'-inf'或'inf';
232
+ axis: 序列所在轴,默认为1;
233
+ """
234
+ if mask is None:
235
+ return x
236
+ else:
237
+ x_dtype = x.dtype
238
+ if x_dtype == torch.bool:
239
+ x = x.to(torch.int32)
240
+ # if mask.dtype != x.dtype:
241
+ # mask = mask.to(x.dtype)
242
+ if value == '-inf':
243
+ value = -float('inf')
244
+ elif value == 'inf':
245
+ value = float('inf')
246
+ if axis is None:
247
+ axis = 1
248
+ elif axis < 0:
249
+ axis = x.dim() + axis
250
+ assert axis > 0, 'axis must be greater than 0'
251
+ if mask.dim() != x.dim():
252
+ mask = align(mask, [0, axis], x.dim())
253
+ # value = value.to(x.dtype)
254
+ x = x.masked_fill_(~mask.bool(), value) # * mask + mask.fill_(value)
255
+ if x_dtype == torch.bool:
256
+ x = x.to(torch.bool)
257
+ return x
258
+
259
+ def align(tensor, axes, ndim=None):
260
+ """重新对齐tensor(批量版expand_dims)
261
+ axes:原来的第i维对齐新tensor的第axes[i]维;
262
+ ndim:新tensor的维度。
263
+ """
264
+ assert len(axes) == tensor.dim()
265
+ assert ndim or min(axes) >= 0
266
+ ndim = ndim or max(axes) + 1
267
+ indices = [None] * ndim
268
+ for i in axes:
269
+ indices[i] = slice(None)
270
+ return tensor[tuple(indices)]
271
+
272
+ ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish, "tanh": torch.tanh, "linear": linear_act, 'sigmoid': torch.sigmoid, 'silu': torch.nn.functional.silu}
273
+
274
+
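
A small standalone check of `align` and `sequence_masking` as documented above; the shapes are arbitrary:

    import torch

    x = torch.randn(2, 5, 4)                        # (batch, seq_len, features)
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 0, 0, 0]])          # (batch, seq_len), 1 = keep

    m = align(mask, [0, 1], x.dim())                # broadcasts to (batch, seq_len, 1)
    assert m.shape == (2, 5, 1)

    y = sequence_masking(x.clone(), mask, value=0, axis=1)
    assert torch.all(y[0, 3:] == 0) and torch.all(y[1, 2:] == 0)
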
modeling/pooling.py ADDED
@@ -0,0 +1,88 @@
1
+ #
2
+ # Zhou Bo
3
+ #
4
+ #
5
+ """
6
+ Pooling functions
7
+ """
8
+
9
+ from torch import nn
10
+ import copy
11
+ import json
12
+ import pdb
13
+ from .bert import ACT2FN
14
+ from .ops import StableDropout
15
+ from .config import AbsModelConfig
16
+
17
+ __all__ = ['PoolConfig', 'ContextPooler']
18
+
19
+ class PoolConfig(AbsModelConfig):
20
+ """Configuration class to store the configuration of `pool layer`.
21
+
22
+ Parameters:
23
+
24
+ config (:class:`~DeBERTa.deberta.ModelConfig`): The model config. The field of pool config will be initialized with the `pooling` field in model config.
25
+
26
+ Attributes:
27
+
28
+ hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`.
29
+
30
+ dropout (float): The dropout rate applied on the output of `[CLS]` token,
31
+
32
+ hidden_act (:obj:`str`): The activation function of the projection layer, it can be one of ['gelu', 'tanh'].
33
+
34
+ Example::
35
+
36
+ # Here is the content of an example model config file in json format
37
+
38
+ {
39
+ "hidden_size": 768,
40
+ "num_hidden_layers" 12,
41
+ "num_attention_heads": 12,
42
+ "intermediate_size": 3072,
43
+ ...
44
+ "pooling": {
45
+ "hidden_size": 768,
46
+ "hidden_act": "gelu",
47
+ "dropout": 0.1
48
+ }
49
+ }
50
+
51
+ """
52
+ def __init__(self, config=None):
53
+ """Constructs PoolConfig.
54
+
55
+ Args:
56
+ `config`: the config of the model. The field of pool config will be initialized with the 'pooling' field in model config.
57
+ """
58
+
59
+ self.hidden_size = 768
60
+ self.dropout = 0
61
+ self.hidden_act = 'gelu'
62
+ if config:
63
+ pool_config = getattr(config, 'pooling', config)
64
+ if isinstance(pool_config, dict):
65
+ pool_config = AbsModelConfig.from_dict(pool_config)
66
+ self.hidden_size = getattr(pool_config, 'hidden_size', config.hidden_size)
67
+ self.dropout = getattr(pool_config, 'dropout', 0)
68
+ self.hidden_act = getattr(pool_config, 'hidden_act', 'gelu')
69
+
70
+ class ContextPooler(nn.Module):
71
+ def __init__(self, config):
72
+ super().__init__()
73
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
74
+ self.dropout = StableDropout(config.dropout)
75
+ self.config = config
76
+
77
+ def forward(self, hidden_states, mask = None):
78
+ # We "pool" the model by simply taking the hidden state corresponding
79
+ # to the first token.
80
+
81
+ context_token = hidden_states[:, 0]
82
+ context_token = self.dropout(context_token)
83
+ pooled_output = self.dense(context_token)
84
+ pooled_output = ACT2FN[self.config.hidden_act](pooled_output)
85
+ return pooled_output
86
+
87
+ def output_dim(self):
88
+ return self.config.hidden_size
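
A minimal sketch of the pooler on dummy encoder output, using the `PoolConfig` defaults above (hidden_size 768, gelu, dropout 0):

    import torch

    pool_config = PoolConfig()
    pooler = ContextPooler(pool_config)

    hidden_states = torch.randn(2, 10, 768)         # (batch, seq_len, hidden)
    pooled = pooler(hidden_states)                  # pools the first ([CLS]) token
    assert pooled.shape == (2, pooler.output_dim())
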
modeling/pretrained_models.py ADDED
@@ -0,0 +1,2 @@
1
+
2
+
modeling/spm_tokenizer.py ADDED
@@ -0,0 +1,322 @@
1
+ # Copyright (c) Microsoft, Inc. 2020
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Zhou Bo
7
+ # Date: 11/15/2020
8
+ #
9
+
10
+
11
+ import sentencepiece as sp
12
+ import six
13
+ import unicodedata
14
+ import os
15
+ import regex as re
16
+ from .cache_utils import load_vocab
17
+ import loguru
18
+ logger=loguru.logger
19
+
20
+
21
+ import pdb
22
+
23
+ __all__ = ['SPMTokenizer']
24
+
25
+ class SPMTokenizer:
26
+ def __init__(self, vocab_file, do_lower_case=False, special_tokens=None, bpe_dropout=0, split_by_punct=False):
27
+ self.split_by_punct = split_by_punct
28
+ spm = sp.SentencePieceProcessor()
29
+ assert os.path.exists(vocab_file)
30
+ spm.load(vocab_file)
31
+ bpe_vocab_size = spm.GetPieceSize()
32
+ # Token map
33
+ # <unk> 0+1
34
+ # <s> 1+1
35
+ # </s> 2+1
36
+ self.vocab = {spm.IdToPiece(i):i for i in range(bpe_vocab_size)}
37
+ self.ids_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)]
38
+ #self.vocab['[PAD]'] = 0
39
+ #self.vocab['[CLS]'] = 1
40
+ #self.vocab['[SEP]'] = 2
41
+ #self.vocab['[UNK]'] = 3
42
+
43
+ _special_tokens = ['[MASK]', '[SEP]', '[PAD]', '[UNK]', '[CLS]']
44
+ self.special_tokens = []
45
+ if special_tokens is not None:
46
+ _special_tokens.extend(special_tokens)
47
+ for t in _special_tokens:
48
+ self.add_special_token(t)
49
+
50
+ self.spm = spm
51
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
52
+
53
+ def tokenize(self, text):
54
+ pieces = self._encode_as_pieces(text)
55
+ def _norm(x):
56
+ if x not in self.vocab or x=='<unk>':
57
+ return '[UNK]'
58
+ else:
59
+ return x
60
+ pieces = [_norm(p) for p in pieces]
61
+ return pieces
62
+
63
+ def convert_tokens_to_ids(self, tokens):
64
+ return [self.vocab[t] if t in self.vocab else 1 for t in tokens]
65
+
66
+ def convert_ids_to_tokens(self, ids):
67
+ tokens = []
68
+ for i in ids:
69
+ tokens.append(self.ids_to_tokens[i])
70
+ return tokens
71
+
72
+ def decode(self, tokens, start=-1, end=-1, raw_text=None):
73
+ if raw_text is None:
74
+ return self.spm.decode_pieces([t for t in tokens if t not in self.special_tokens])
75
+ else:
76
+ words = self.split_to_words(raw_text)
77
+ word_tokens = [self.tokenize(w) for w in words]
78
+ wt = [w for t in word_tokens for w in t]
79
+ #assert tokens == wt, f'{tokens} || {wt}'
80
+ if wt!=tokens:
81
+ for a,b in zip(wt, tokens):
82
+ if a!=b:
83
+ pdb.set_trace()
84
+ token2words = [0]*len(tokens)
85
+ tid = 0
86
+ for i,w in enumerate(word_tokens):
87
+ for k,t in enumerate(w):
88
+ token2words[tid] = i
89
+ tid += 1
90
+ word_start = token2words[start]
91
+ word_end = token2words[end] if end <len(tokens) else len(words)
92
+ text = ''.join(words[word_start:word_end])
93
+ return text
94
+
95
+ def add_special_token(self, token):
96
+ if token not in self.special_tokens:
97
+ self.special_tokens.append(token)
98
+ if token not in self.vocab:
99
+ self.vocab[token] = len(self.vocab)
100
+ self.ids_to_tokens.append(token)
101
+ return self.id(token)
102
+
103
+ def part_of_whole_word(self, token, is_bos=False):
104
+ if is_bos:
105
+ return True
106
+ if (len(token)==1 and (_is_whitespace(list(token)[0]) or _is_control(list(token)[0]) or _is_punctuation(list(token)[0]))) or token in self.special_tokens:
107
+ return False
108
+
109
+ word_start = b'\xe2\x96\x81'.decode('utf-8')
110
+ return not token.startswith(word_start)
111
+
112
+ def pad(self):
113
+ return '[PAD]'
114
+
115
+ def bos(self):
116
+ return '[CLS]'
117
+
118
+ def eos(self):
119
+ return '[SEP]'
120
+
121
+ def unk(self):
122
+ return '[UNK]'
123
+
124
+ def mask(self):
125
+ return '[MASK]'
126
+
127
+ def sym(self, id):
128
+ return self.ids_to_tokens[id]
129
+
130
+ def id(self, sym):
131
+ return self.vocab[sym] if sym in self.vocab else 1
132
+
133
+ def _encode_as_pieces(self, text):
134
+ text = convert_to_unicode(text)
135
+ if self.split_by_punct:
136
+ words = self._run_split_on_punc(text)
137
+ pieces = [self.spm.encode_as_pieces(w) for w in words]
138
+ return [p for w in pieces for p in w]
139
+ else:
140
+ return self.spm.encode_as_pieces(text)
141
+
142
+ def split_to_words(self, text):
143
+ pieces = self._encode_as_pieces(text)
144
+ word_start = b'\xe2\x96\x81'.decode('utf-8')
145
+ words = []
146
+ offset = 0
147
+ prev_end = 0
148
+ for i,p in enumerate(pieces):
149
+ if p.startswith(word_start):
150
+ if offset>prev_end:
151
+ words.append(text[prev_end:offset])
152
+ prev_end = offset
153
+ w = p.replace(word_start, '')
154
+ else:
155
+ w = p
156
+ try:
157
+ s = text.index(w, offset)
158
+ pn = ""
159
+ k = i+1
160
+ while k < len(pieces):
161
+ pn = pieces[k].replace(word_start, '')
162
+ if len(pn)>0:
163
+ break
164
+ k += 1
165
+
166
+ if len(pn)>0 and pn in text[offset:s]:
167
+ offset = offset + 1
168
+ else:
169
+ offset = s + len(w)
170
+ except:
171
+ offset = offset + 1
172
+
173
+ if prev_end< offset:
174
+ words.append(text[prev_end:offset])
175
+
176
+ return words
177
+
178
+ def _run_strip_accents(self, text):
179
+ """Strips accents from a piece of text."""
180
+ text = unicodedata.normalize("NFD", text)
181
+ output = []
182
+ for char in text:
183
+ cat = unicodedata.category(char)
184
+ if cat == "Mn":
185
+ continue
186
+ output.append(char)
187
+ return "".join(output)
188
+
189
+ def _run_split_on_punc(self, text):
190
+ """Splits punctuation on a piece of text."""
191
+ #words = list(re.findall(self.pat, text))
192
+ chars = list(text)
193
+ i = 0
194
+ start_new_word = True
195
+ output = []
196
+ while i < len(chars):
197
+ char = chars[i]
198
+ if _is_punctuation(char):
199
+ output.append([char])
200
+ start_new_word = True
201
+ else:
202
+ if start_new_word:
203
+ output.append([])
204
+ start_new_word = False
205
+ output[-1].append(char)
206
+ i += 1
207
+
208
+ return ["".join(x) for x in output]
209
+
210
+ def _tokenize_chinese_chars(self, text):
211
+ """Adds whitespace around any CJK character."""
212
+ output = []
213
+ for char in text:
214
+ cp = ord(char)
215
+ if self._is_chinese_char(cp):
216
+ output.append(" ")
217
+ output.append(char)
218
+ output.append(" ")
219
+ else:
220
+ output.append(char)
221
+ return "".join(output)
222
+
223
+ def _is_chinese_char(self, cp):
224
+ """Checks whether CP is the codepoint of a CJK character."""
225
+ # This defines a "chinese character" as anything in the CJK Unicode block:
226
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
227
+ #
228
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
229
+ # despite its name. The modern Korean Hangul alphabet is a different block,
230
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
231
+ # space-separated words, so they are not treated specially and handled
232
+ # like all of the other languages.
233
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
234
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
235
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
236
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
237
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
238
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
239
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
240
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
241
+ return True
242
+
243
+ return False
244
+
245
+ def _clean_text(self, text):
246
+ """Performs invalid character removal and whitespace cleanup on text."""
247
+ output = []
248
+ for char in text:
249
+ cp = ord(char)
250
+ if cp == 0 or cp == 0xfffd or _is_control(char):
251
+ continue
252
+ if _is_whitespace(char):
253
+ output.append(" ")
254
+ else:
255
+ output.append(char)
256
+ return "".join(output)
257
+
258
+
259
+ def _is_whitespace(char):
260
+ """Checks whether `chars` is a whitespace character."""
261
+ # \t, \n, and \r are technically control characters but we treat them
262
+ # as whitespace since they are generally considered as such.
263
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
264
+ return True
265
+ cat = unicodedata.category(char)
266
+ if cat == "Zs":
267
+ return True
268
+ return False
269
+
270
+ def _is_control(char):
271
+ """Checks whether `chars` is a control character."""
272
+ # These are technically control characters but we count them as whitespace
273
+ # characters.
274
+ if char == "\t" or char == "\n" or char == "\r":
275
+ return False
276
+ cat = unicodedata.category(char)
277
+ if cat.startswith("C"):
278
+ return True
279
+ return False
280
+
281
+ def _is_punctuation(char):
282
+ """Checks whether `chars` is a punctuation character."""
283
+ cp = ord(char)
284
+ # We treat all non-letter/number ASCII as punctuation.
285
+ # Characters such as "^", "$", and "`" are not in the Unicode
286
+ # Punctuation class but we treat them as punctuation anyways, for
287
+ # consistency.
288
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
289
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
290
+ return True
291
+ cat = unicodedata.category(char)
292
+ if cat.startswith("P"):
293
+ return True
294
+ return False
295
+
296
+ def whitespace_tokenize(text):
297
+ """Runs basic whitespace cleaning and splitting on a peice of text."""
298
+ text = text.strip()
299
+ if not text:
300
+ return []
301
+ tokens = text.split()
302
+ return tokens
303
+
304
+ def convert_to_unicode(text):
305
+ """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
306
+ if six.PY3:
307
+ if isinstance(text, str):
308
+ return text
309
+ elif isinstance(text, bytes):
310
+ return text.decode("utf-8", "ignore")
311
+ else:
312
+ raise ValueError("Unsupported string type: %s" % (type(text)))
313
+ elif six.PY2:
314
+ if isinstance(text, str):
315
+ return text.decode("utf-8", "ignore")
316
+ elif isinstance(text, unicode):
317
+ return text
318
+ else:
319
+ raise ValueError("Unsupported string type: %s" % (type(text)))
320
+ else:
321
+ raise ValueError("Not running on Python2 or Python 3?")
322
+
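
A hedged usage sketch; `spm.model` stands in for a real SentencePiece model file on disk:

    tokenizer = SPMTokenizer('spm.model')           # placeholder path
    pieces = tokenizer.tokenize("Hello world!")     # SentencePiece pieces, '[UNK]' for unknown pieces
    ids = tokenizer.convert_tokens_to_ids(pieces)
    text = tokenizer.decode(pieces)
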
modeling/tokenizers.py ADDED
@@ -0,0 +1,18 @@
1
+ #
2
+ # Zhou Bo
3
+
4
+ #
5
+
6
+ """ tokenizers
7
+ """
8
+
9
+ from .spm_tokenizer import *
10
+ from .gpt2_tokenizer import GPT2Tokenizer
11
+ from wywLM.models import BertTokenizer
12
+
13
+ __all__ = ['tokenizers']
14
+ tokenizers={
15
+ 'gpt2': GPT2Tokenizer,
16
+ 'spm': SPMTokenizer,
17
+ 'bert': BertTokenizer
18
+ }
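
A hedged lookup sketch for the registry above; the vocabulary path is a placeholder:

    tokenizer_cls = tokenizers['spm']               # or 'gpt2' / 'bert'
    # tokenizer = tokenizer_cls('/path/to/vocab')   # construct with the matching vocabulary file
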
modeling/wywlm_modeling.py ADDED
@@ -0,0 +1,446 @@
1
+ # Copyright (c) Microsoft, Inc. 2020
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Zhou Bo
7
+ # Date: 01/15/2020
8
+ #
9
+
10
+ import copy
11
+ import torch
12
+ import os
13
+ import random
14
+
15
+ import json
16
+ from .ops import *
17
+ from .bert import *
18
+ from .bert import BertLayer
19
+ from .config import ModelConfig
20
+ from .cache_utils import load_model_state
21
+ from .nnmodule import NNModule
22
+
23
+ # from ..utils.bad_grad_viz import register_hooks
24
+
25
+ __all__ = ['WywLM']
26
+
27
+ def flatten_states(q_states, mask_index):
28
+ q_states = q_states.reshape((-1, q_states.size(-1)))
29
+ q_states = q_states.index_select(0, mask_index)
30
+ return q_states
31
+
32
+
33
+ class UGDecoder(torch.nn.Module):
34
+ def __init__(self, config, vocab_size):
35
+ super().__init__()
36
+ self.config = config
37
+ self.position_biased_input = getattr(config, 'position_biased_input', True)
38
+ # self.layer = torch.nn.ModuleList([BertLayer(config) for _ in range(2)])
39
+
40
+ # self.causal_mask = torch.tril(torch.ones((input_ids.dim(0), input_ids.dim(1), input_ids.dim(1))), diagonal=0)
41
+
42
+ def forward(self, ctx_layers, word_embedding, input_ids, z_states, attention_mask, \
43
+ encoder, target_ids=None, relative_pos=None, decode=False, s2s_idx=None):
44
+ causal_outputs, lm_outputs = self.emd_context_layer(ctx_layers, z_states, attention_mask,
45
+ encoder, target_ids, input_ids,
46
+ relative_pos=relative_pos, decode=decode,
47
+ word_embedding=word_embedding, s2s_idx=s2s_idx)
48
+ # loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
49
+
50
+ # ctx_layer = mlm_ctx_layers[-1]
51
+
52
+ # lm_logits = lm_logits.view(-1, lm_logits.size(-1))
53
+
54
+ return causal_outputs[-1], lm_outputs[-1]
55
+
56
+ def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder, target_ids, input_ids,\
57
+ relative_pos=None, decode=False, word_embedding=None, s2s_idx=None):
58
+ # if decode:
59
+ # attention_mask = torch.tril(torch.ones((input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])), diagonal=0).to(input_ids.device)
60
+ # else:
61
+ if attention_mask.dim()<=2:
62
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
63
+ att_mask = extended_attention_mask.byte()
64
+ attention_mask = att_mask*att_mask.squeeze(-2).unsqueeze(-1)
65
+ elif attention_mask.dim()==3:
66
+ attention_mask = attention_mask.unsqueeze(1)
67
+
68
+
69
+ if not self.position_biased_input:
70
+
71
+
72
+ lm_outputs = []
73
+ # else:
74
+ hidden_states = encoder_layers[-2]
75
+ layers = [encoder.layer[-1] for _ in range(2)]
76
+ z_states += hidden_states
77
+ query_states = z_states
78
+ query_mask = attention_mask
79
+ rel_embeddings = encoder.get_rel_embedding()
80
+ for layer in layers:
81
+ # TODO: pass relative pos ids
82
+ output = layer(hidden_states, query_mask, return_att=False,
83
+ query_states=query_states, relative_pos=relative_pos,
84
+ rel_embeddings=rel_embeddings)
85
+ query_states = output
86
+ lm_outputs.append(query_states)
87
+
88
+ # if decode:
89
+ attention_mask = torch.tril(torch.ones((input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])),
90
+ diagonal=0).to(input_ids.device)
91
+ causal_outputs = []
92
+ # with torch.no_grad():
93
+ target_embd = word_embedding(target_ids)
94
+
95
+ target_embd += z_states.detach()
96
+ # self attention of target
97
+ output = layers[-2](target_embd, attention_mask, return_att=False,
98
+ query_states=target_embd, relative_pos=relative_pos,
99
+ rel_embeddings=encoder.get_rel_embedding())
100
+ causal_outputs.append(output)
101
+ # cross attention
102
+ output = layers[-1](output, attention_mask, return_att=False,
103
+ query_states=query_states, relative_pos=relative_pos,
104
+ rel_embeddings=encoder.get_rel_embedding())
105
+ causal_outputs.append(output)
106
+
107
+ else:
108
+ causal_outputs = [encoder_layers[-1]]
109
+ lm_outputs = [encoder_layers[-1]]
110
+ return causal_outputs, lm_outputs
111
+
112
+
113
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
114
+ """
115
+ Shift input ids one token to the right.
116
+ """
117
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
118
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
119
+ shifted_input_ids[:, 0] = decoder_start_token_id
120
+
121
+ if pad_token_id is None:
122
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
123
+ # replace possible -100 values in labels by `pad_token_id`
124
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
125
+
126
+ return shifted_input_ids
127
+
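
A worked example of the shift performed by `shift_tokens_right`, with `pad_token_id=0` and `decoder_start_token_id=1` as assumed values:

    import torch

    labels = torch.tensor([[5, 6, 7, -100]])
    decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=1)
    # [5, 6, 7, -100] becomes [1, 5, 6, 7]; any -100 left after the shift is replaced by the pad id
    assert decoder_input_ids.tolist() == [[1, 5, 6, 7]]
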
128
+
129
+ class WywLMLoss(torch.nn.Module):
130
+ def __init__(self, config) -> None:
131
+ super().__init__()
132
+ self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
133
+ hidden_size = getattr(config, 'embedding_size', config.hidden_size)
134
+ self.compare = torch.nn.Linear(hidden_size * 3, 2)
135
+ # self.mlm_head = BertLMPredictionHead(config, config.vocab_size)
136
+ self.lm_head = BertLMPredictionHead(config, config.vocab_size)
137
+
138
+     def forward(self, logits, lm_logits, target_ids, dict_pos, input_ids, target_ids_s2s,
+                 decode=False, ebd_weight=None, task=0):
+         loss_compare = torch.tensor(0).to(logits).float()
+         mlm_loss = torch.tensor(0).to(logits).float()
+         lm_loss = torch.tensor(0).to(logits).float()
+
+         if task == 1:
+             # Comparison loss: for each (entry, description) span pair in `dict_pos`,
+             # build a positive pair and a randomly sampled negative pair.
+             compare_logits = []
+             compare_labels = []
+             for bi, sample_pos in enumerate(dict_pos):
+                 num_pos = int((sample_pos > 0).sum().detach().cpu().numpy() / 4) - 1
+                 if num_pos <= 1:
+                     continue
+                 for pi in range(num_pos):
+                     pos = sample_pos[pi]
+                     entry_logits = logits[bi][pos[0]: pos[1]]
+                     desc_logits = logits[bi][pos[2]: pos[3]]
+                     # Sample a negative description from another entry; retry if the sampled
+                     # entry is the same as (or token-identical to) the positive one.
+                     neg_num = random.randint(0, num_pos)
+                     ids_neg = input_ids[bi][sample_pos[neg_num][0]: sample_pos[neg_num][1]]
+                     ids_pos = input_ids[bi][pos[0]: pos[1]]
+                     if neg_num == pi or (ids_neg.shape == ids_pos.shape and torch.all(ids_neg == ids_pos)):
+                         neg_num = -1
+                         for ni in range(num_pos):
+                             neg_num = random.randint(0, num_pos)
+                             ids_neg = input_ids[bi][sample_pos[neg_num][0]: sample_pos[neg_num][1]]
+                             if neg_num != pi and (ids_neg.shape != ids_pos.shape or not torch.all(ids_neg == ids_pos)):
+                                 break
+                             else:
+                                 neg_num = -1
+                     if neg_num == -1:
+                         continue
+                     neg_desc_logits = logits[bi][sample_pos[neg_num][2]: sample_pos[neg_num][3]]
+                     if torch.any(torch.isnan(neg_desc_logits)):
+                         print('error: NaN in negative description logits')
+                     # Mean-pool each span and build [a, b, a - b] comparison features.
+                     entry_logits = entry_logits.mean(dim=0, keepdim=True).float()
+                     desc_logits = desc_logits.mean(dim=0, keepdim=True).float()
+                     neg_desc_logits = neg_desc_logits.mean(dim=0, keepdim=True).float()
+                     compare_logits.append(torch.concat([entry_logits, desc_logits, entry_logits - desc_logits], dim=1))
+                     compare_logits.append(torch.concat([entry_logits, neg_desc_logits, entry_logits - neg_desc_logits], dim=1))
+                     compare_labels += [1, 0]
+             if len(compare_logits) > 0:
+                 compare_logits = torch.concat(compare_logits, dim=0).to(logits.dtype)
+                 compare_pred = self.compare(compare_logits)
+                 loss_compare = self.loss_fn(compare_pred,
+                                             torch.tensor(compare_labels, dtype=torch.long, device=compare_logits.device)).mean()
+
+         if torch.all(loss_compare == 0):
+             # If no comparison pairs were formed in this batch, fall back to a degenerate
+             # self-pair so the compare head still contributes a loss.
+             entry_logits = logits[0][0].unsqueeze(0)
+             compare_logits = torch.concat([entry_logits, entry_logits, entry_logits - entry_logits], dim=1)
+             compare_pred = self.compare(compare_logits)
+             compare_labels = [1]
+             loss_compare = self.loss_fn(compare_pred,
+                                         torch.tensor(compare_labels, dtype=torch.long, device=compare_logits.device)).mean()
+
+         if task == 0:
+             # Seq2seq LM loss: only positions with a target token (> 0) contribute.
+             _mask_index = (target_ids_s2s > 0).view(-1).nonzero().view(-1)
+             lm_logits_ = flatten_states(lm_logits, _mask_index)
+             lm_pred = self.lm_head(lm_logits_, ebd_weight).float()
+             lm_labels = target_ids_s2s.clone().reshape(-1)
+             lm_labels = lm_labels.index_select(0, _mask_index)
+             lm_loss = self.loss_fn(lm_pred, lm_labels.long())
+
+         # Masked-LM loss over the masked positions (target_ids > 0).
+         _mask_index = (target_ids > 0).view(-1).nonzero().view(-1)
+         mlm_logits = flatten_states(logits, _mask_index)
+         mlm_pred = self.lm_head(mlm_logits, ebd_weight).float()
+         mlm_labels = target_ids.view(-1)
+         mlm_labels = mlm_labels.index_select(0, _mask_index)
+         mlm_loss = self.loss_fn(mlm_pred, mlm_labels.long())
+         return loss_compare, mlm_loss, lm_loss
+
+
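+ # A minimal sketch of the comparison features assembled in WywLMLoss.forward above,
+ # assuming mean-pooled entry/description vectors of shape [1, hidden]; `entry`, `desc`
+ # and `compare_head` are placeholders for the pooled span states and the `compare` layer:
+ #
+ #     feats = torch.concat([entry, desc, entry - desc], dim=1)   # [1, 3 * hidden]
+ #     pred = compare_head(feats)                                 # [1, 2]: matched vs. mismatched pair
+
+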
+ class WywLM(torch.nn.Module):
+     """ DeBERTa-style encoder
+     This module is composed of the input embedding layer and stacked transformer layers with disentangled attention.
+
+     Parameters:
+         config:
+             A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \
+             for more details, please refer to :class:`~DeBERTa.deberta.ModelConfig`
+
+         pre_trained:
+             The pre-trained DeBERTa model. It can be a physical path to a pre-trained DeBERTa model or one of \
+             the released configurations, i.e. [**base, large, base_mnli, large_mnli**]
+
+     """
+
+     def __init__(self, config=None, pre_trained=None):
+         super().__init__()
+         state = None
+         if pre_trained is not None:
+             state, model_config = load_model_state(pre_trained)
+             if config is not None and model_config is not None:
+                 # Keep the architecture-defining fields of the pre-trained model and
+                 # overwrite everything else with the user-supplied config.
+                 for k in config.__dict__:
+                     if k not in ['hidden_size',
+                                  'intermediate_size',
+                                  'num_attention_heads',
+                                  'num_hidden_layers',
+                                  'vocab_size',
+                                  'max_position_embeddings']:
+                         model_config.__dict__[k] = config.__dict__[k]
+             config = copy.copy(model_config)
+         self.embeddings = BertEmbeddings(config)
+         self.encoder = BertEncoder(config)
+         self.config = config
+         self.pre_trained = pre_trained
+         self.apply_state(state)
+
+     def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_all_encoded_layers=True,
+                 position_ids=None, return_att=False):
+         """
+         Args:
+             input_ids:
+                 a torch.LongTensor of shape [batch_size, sequence_length] \
+                 with the word token indices in the vocabulary
+
+             attention_mask:
+                 an optional parameter for input mask or attention mask.
+
+                 - If it's an input mask, then it will be a torch.LongTensor of shape [batch_size, sequence_length] with indices \
+                   selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
+                   input sequence length in the current batch. It's the mask that we typically use for attention when \
+                   a batch has varying length sentences.
+
+                 - If it's an attention mask, then it will be a torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
+                   In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.
+
+             token_type_ids:
+                 an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
+                 type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
+                 a `sentence B` token (see the BERT paper for more details).
+
+             output_all_encoded_layers:
+                 whether to output results of all encoder layers, default: `True`
+
+         Returns:
+             A dictionary with the encoder outputs (`hidden_states` of all layers if `output_all_encoded_layers=True`, \
+             attention matrices if `return_att=True`) merged with the embedding outputs \
+             (e.g. `embeddings`, `position_embeddings`).
+
+         Example::
+
+             # Batch of WordPiece token ids.
+             # Each sample is padded with zero to the maximum length of the batch.
+             input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+             # Mask of valid input ids
+             attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+
+             # Model initialized from a pre-trained base checkpoint
+             model = WywLM(pre_trained='base')
+
+             encoder_output = model(input_ids, attention_mask=attention_mask)
+             encoder_layers = encoder_output['hidden_states']
+
+         """
+
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+         if token_type_ids is None:
+             token_type_ids = torch.zeros_like(input_ids)
+             token_mask = torch.ones_like(input_ids)
+         else:
+             # Mark every position up to and including the last type-1 token.
+             idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), axis=1), [-1])
+             token_mask = idxs > 0
+             token_mask = token_mask.byte()
+         ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, token_mask)
+         embedding_output = ebd_output['embeddings']
+         encoder_output = self.encoder(embedding_output,
+                                       attention_mask,
+                                       output_all_encoded_layers=output_all_encoded_layers, return_att=return_att)
+         encoder_output.update(ebd_output)
+         return encoder_output
+
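+     # A worked example (values are illustrative) of the token_mask computed in forward
+     # above: flip, cumulative-sum, and flip back, which marks every position up to and
+     # including the last type-1 token.
+     #
+     #     token_type_ids = [0, 0, 0, 1, 1, 0]
+     #     flip           -> [0, 1, 1, 0, 0, 0]
+     #     cumsum         -> [0, 1, 2, 2, 2, 2]
+     #     flip back      -> [2, 2, 2, 2, 1, 0]
+     #     > 0            -> [1, 1, 1, 1, 1, 0]
+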
+     def apply_state(self, state=None):
+         """ Load state from a previously loaded model state dictionary.
+
+         Args:
+             state (:obj:`dict`, optional): State dictionary as returned by `torch.nn.Module.state_dict()`, default: `None`. \
+                 If it's `None`, the pre-trained state loaded via the constructor will be used to re-initialize \
+                 the model.
+         """
+         if self.pre_trained is None and state is None:
+             return
+         if state is None:
+             state, config = load_model_state(self.pre_trained)
+             self.config = config
+
+         # Detect an optional key prefix (e.g. when the checkpoint was saved from a wrapper module).
+         prefix = ''
+         for k in state:
+             if 'embeddings.' in k:
+                 if not k.startswith('embeddings.'):
+                     prefix = k[:k.index('embeddings.')]
+                 break
+
+         missing_keys = []
+         unexpected_keys = []
+         error_msgs = []
+         self._load_from_state_dict(state, prefix=prefix, local_metadata=None, strict=True,
+                                    missing_keys=missing_keys, unexpected_keys=unexpected_keys,
+                                    error_msgs=error_msgs)
+
+
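+ # A minimal usage sketch for WywLM (the checkpoint paths are placeholders, not part of
+ # this module):
+ #
+ #     model = WywLM(pre_trained='/path/to/pretrained')   # loads and applies the state
+ #     # or build from a config and load weights afterwards:
+ #     model = WywLM(config)
+ #     model.apply_state(torch.load('/path/to/pytorch_model.bin', map_location='cpu'))
+
+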
+ class MaskedLanguageModel(NNModule):
+     """ Masked language model
+
+     Wraps the WywLM backbone with the UGDecoder prediction head and the WywLMLoss
+     pre-training objectives (masked LM, seq2seq LM, and entry/description comparison).
+     """
+     def __init__(self, config, *args, **kwargs):
+         super().__init__(config)
+         self.backbone = WywLM(config)
+
+         self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
+         self.position_buckets = getattr(config, 'position_buckets', -1)
+         if self.max_relative_positions < 1:
+             self.max_relative_positions = config.max_position_embeddings
+         self.lm_predictions = UGDecoder(self.backbone.config, self.backbone.embeddings.word_embeddings.weight.size(0))
+         self.device = None
+         self.loss = WywLMLoss(config)
+         self.apply(self.init_weights)
+
+     def forward(self, samples, position_ids=None):
+         # task == 0 selects the seq2seq view of the batch; other tasks use the standard MLM view.
+         task = samples['task']
+         if task == 0:
+             input_ids = samples['s2s_input_ids']
+             type_ids = samples['s2s_token_type_ids']
+             attention_mask = samples['s2s_attention_mask']
+             labels = samples['s2s_masked_lm_labels']
+             dict_pos = samples['dict_pos']
+             s2s_label = samples['s2s_label']
+         else:
+             input_ids = samples['input_ids']
+             type_ids = samples['token_type_ids']
+             attention_mask = samples['attention_mask']
+             labels = samples['masked_lm_labels']
+             dict_pos = samples['dict_pos']
+             s2s_label = samples['s2s_label']
+
+         if self.device is None:
+             self.device = next(self.parameters()).device
+
+         input_ids = input_ids.to(self.device)
+         # Token type ids are not passed to the backbone here.
+         type_ids = None
+         lm_labels = labels.to(self.device)
+         s2s_label = s2s_label.to(self.device)
+         attention_mask = attention_mask.to(self.device)
+
+         encoder_output = self.backbone(input_ids, attention_mask, type_ids,
+                                        output_all_encoded_layers=True, position_ids=position_ids)
+         encoder_layers = encoder_output['hidden_states']
+         z_states = encoder_output['position_embeddings']
+         ctx_layer = encoder_layers[-1]
+         mlm_loss = torch.tensor(0).to(ctx_layer).float()
+         lm_loss = torch.tensor(0).to(ctx_layer).float()
+         lm_logits = None
+         loss = torch.tensor(0).to(ctx_layer).float()
+         loss_compare = torch.tensor(0).to(ctx_layer).float()
+
+         ebd_weight = self.backbone.embeddings.word_embeddings.weight
+         lm_logits, mlm_logits = self.lm_predictions(encoder_layers, self.backbone.embeddings.word_embeddings,
+                                                     input_ids, z_states,
+                                                     attention_mask, self.backbone.encoder,
+                                                     target_ids=lm_labels)
+         loss_compare, mlm_loss, lm_loss = self.loss(mlm_logits,
+                                                     lm_logits,
+                                                     lm_labels,
+                                                     dict_pos,
+                                                     target_ids_s2s=s2s_label,
+                                                     decode=False,
+                                                     ebd_weight=ebd_weight,
+                                                     input_ids=input_ids,
+                                                     task=task)
+         # The comparison loss is up-weighted by a factor of 10.
+         loss = loss_compare * 10 + mlm_loss + lm_loss
+
+         return {
+             'logits': lm_logits,
+             'labels': lm_labels,
+             's2s_label': s2s_label,
+             'loss': loss.float(),
+             'loss_compare': loss_compare.float(),
+             'lm_loss': lm_loss.float(),
+             'mlm_loss': mlm_loss.float()
+         }
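+
+
+ # A sketch of the `samples` dict expected by MaskedLanguageModel.forward, inferred from
+ # the key accesses above (the tensor shapes noted here are assumptions for illustration):
+ #
+ #     samples = {
+ #         'task': 0 or 1,                        # 0: seq2seq batch, 1: MLM/dictionary batch
+ #         'input_ids': ...,                      # [batch, seq]
+ #         'token_type_ids': ...,
+ #         'attention_mask': ...,
+ #         'masked_lm_labels': ...,
+ #         'dict_pos': ...,                       # entry/description span boundaries
+ #         's2s_label': ...,
+ #         's2s_input_ids': ..., 's2s_token_type_ids': ...,
+ #         's2s_attention_mask': ..., 's2s_masked_lm_labels': ...,
+ #     }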