Upload 8 files
Browse files- bert_tokenizer.py +206 -0
- config.json +29 -0
- modeling_glycebert.py +713 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +7 -0
- tokenizer.json +0 -0
- tokenizer_config.json +21 -0
- vocab.txt +0 -0
bert_tokenizer.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import time
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import List, Union, Optional
|
7 |
+
|
8 |
+
import tokenizers
|
9 |
+
import torch
|
10 |
+
from torch import NoneType
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
from huggingface_hub.file_download import http_user_agent
|
13 |
+
from pypinyin import pinyin, Style
|
14 |
+
from transformers.tokenization_utils_base import TruncationStrategy
|
15 |
+
from transformers.utils import PaddingStrategy
|
16 |
+
from transformers.utils.generic import TensorType
|
17 |
+
|
18 |
+
try:
|
19 |
+
from tokenizers import BertWordPieceTokenizer
|
20 |
+
except:
|
21 |
+
from tokenizers.implementations import BertWordPieceTokenizer
|
22 |
+
|
23 |
+
from transformers import BertTokenizerFast, BatchEncoding
|
24 |
+
|
25 |
+
cache_path = Path(os.path.abspath(__file__)).parent
|
26 |
+
|
27 |
+
|
28 |
+
def download_file(filename: str, path: Path):
|
29 |
+
if os.path.exists(cache_path / filename):
|
30 |
+
return
|
31 |
+
|
32 |
+
if os.path.exists(path / filename):
|
33 |
+
shutil.copyfile(path / filename, cache_path / filename)
|
34 |
+
return
|
35 |
+
|
36 |
+
hf_hub_download(
|
37 |
+
"iioSnail/ChineseBERT-base",
|
38 |
+
filename,
|
39 |
+
local_dir=cache_path,
|
40 |
+
user_agent=http_user_agent(),
|
41 |
+
)
|
42 |
+
time.sleep(0.2)
|
43 |
+
|
44 |
+
|
45 |
+
class ChineseBertTokenizer(BertTokenizerFast):
|
46 |
+
|
47 |
+
def __init__(self, **kwargs):
|
48 |
+
super(ChineseBertTokenizer, self).__init__(**kwargs)
|
49 |
+
|
50 |
+
self.path = Path(kwargs['name_or_path'])
|
51 |
+
vocab_file = cache_path / 'vocab.txt'
|
52 |
+
config_path = cache_path / 'config'
|
53 |
+
if not os.path.exists(config_path):
|
54 |
+
os.makedirs(config_path)
|
55 |
+
|
56 |
+
self.max_length = 512
|
57 |
+
|
58 |
+
download_file('vocab.txt', self.path)
|
59 |
+
self.tokenizer = BertWordPieceTokenizer(str(vocab_file))
|
60 |
+
|
61 |
+
# load pinyin map dict
|
62 |
+
download_file('config/pinyin_map.json', self.path)
|
63 |
+
with open(config_path / 'pinyin_map.json', encoding='utf8') as fin:
|
64 |
+
self.pinyin_dict = json.load(fin)
|
65 |
+
|
66 |
+
# load char id map tensor
|
67 |
+
download_file('config/id2pinyin.json', self.path)
|
68 |
+
with open(config_path / 'id2pinyin.json', encoding='utf8') as fin:
|
69 |
+
self.id2pinyin = json.load(fin)
|
70 |
+
|
71 |
+
# load pinyin map tensor
|
72 |
+
download_file('config/pinyin2tensor.json', self.path)
|
73 |
+
with open(config_path / 'pinyin2tensor.json', encoding='utf8') as fin:
|
74 |
+
self.pinyin2tensor = json.load(fin)
|
75 |
+
|
76 |
+
def __call__(self,
|
77 |
+
text: Union[str, List[str], List[List[str]]] = None,
|
78 |
+
text_pair: Union[str, List[str], List[List[str]], NoneType] = None,
|
79 |
+
text_target: Union[str, List[str], List[List[str]]] = None,
|
80 |
+
text_pair_target: Union[str, List[str], List[List[str]], NoneType] = None,
|
81 |
+
add_special_tokens: bool = True,
|
82 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
83 |
+
truncation: Union[bool, str, TruncationStrategy] = None,
|
84 |
+
max_length: Optional[int] = None,
|
85 |
+
stride: int = 0,
|
86 |
+
is_split_into_words: bool = False,
|
87 |
+
pad_to_multiple_of: Optional[int] = None,
|
88 |
+
return_tensors: Union[str, TensorType, NoneType] = None,
|
89 |
+
return_token_type_ids: Optional[bool] = None,
|
90 |
+
return_attention_mask: Optional[bool] = None,
|
91 |
+
return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False,
|
92 |
+
return_offsets_mapping: bool = False,
|
93 |
+
return_length: bool = False,
|
94 |
+
verbose: bool = True, **kwargs) -> BatchEncoding:
|
95 |
+
encoding = super(ChineseBertTokenizer, self).__call__(
|
96 |
+
text=text,
|
97 |
+
text_pair=text_pair,
|
98 |
+
text_target=text_target,
|
99 |
+
text_pair_target=text_pair_target,
|
100 |
+
add_special_tokens=add_special_tokens,
|
101 |
+
padding=padding,
|
102 |
+
truncation=truncation,
|
103 |
+
max_length=max_length,
|
104 |
+
stride=stride,
|
105 |
+
is_split_into_words=is_split_into_words,
|
106 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
107 |
+
return_tensors=return_tensors,
|
108 |
+
return_token_type_ids=return_token_type_ids,
|
109 |
+
return_attention_mask=return_attention_mask,
|
110 |
+
return_overflowing_tokens=return_overflowing_tokens,
|
111 |
+
return_offsets_mapping=return_offsets_mapping,
|
112 |
+
return_length=return_length,
|
113 |
+
verbose=verbose,
|
114 |
+
)
|
115 |
+
|
116 |
+
input_ids = encoding.input_ids
|
117 |
+
|
118 |
+
pinyin_ids = None
|
119 |
+
if type(text) == str:
|
120 |
+
pinyin_ids = self.convert_ids_to_pinyin_ids(input_ids)
|
121 |
+
|
122 |
+
if type(text) == list:
|
123 |
+
pinyin_ids = []
|
124 |
+
for ids in input_ids:
|
125 |
+
pinyin_ids.append(self.convert_ids_to_pinyin_ids(ids))
|
126 |
+
|
127 |
+
if torch.is_tensor(encoding.input_ids):
|
128 |
+
pinyin_ids = torch.LongTensor(pinyin_ids)
|
129 |
+
|
130 |
+
encoding['pinyin_ids'] = pinyin_ids
|
131 |
+
|
132 |
+
return encoding
|
133 |
+
|
134 |
+
def tokenize_sentence(self, sentence):
|
135 |
+
# convert sentence to ids
|
136 |
+
tokenizer_output = self.tokenizer.encode(sentence)
|
137 |
+
bert_tokens = tokenizer_output.ids
|
138 |
+
pinyin_tokens = self.convert_sentence_to_pinyin_ids(sentence, tokenizer_output)
|
139 |
+
# assert,token nums should be same as pinyin token nums
|
140 |
+
assert len(bert_tokens) <= self.max_length
|
141 |
+
assert len(bert_tokens) == len(pinyin_tokens)
|
142 |
+
# convert list to tensor
|
143 |
+
input_ids = torch.LongTensor(bert_tokens)
|
144 |
+
pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
|
145 |
+
return input_ids, pinyin_ids
|
146 |
+
|
147 |
+
def convert_ids_to_pinyin_ids(self, ids: List[int]):
|
148 |
+
pinyin_ids = []
|
149 |
+
tokens = self.convert_ids_to_tokens(ids)
|
150 |
+
for token in tokens:
|
151 |
+
if len(token) > 1:
|
152 |
+
pinyin_ids.append([0] * 8)
|
153 |
+
continue
|
154 |
+
|
155 |
+
pinyin_string = pinyin(token, style=Style.TONE3, errors=lambda x: [['not chinese'] for _ in x])[0][0]
|
156 |
+
|
157 |
+
if pinyin_string == "not chinese":
|
158 |
+
pinyin_ids.append([0] * 8)
|
159 |
+
continue
|
160 |
+
|
161 |
+
if pinyin_string in self.pinyin2tensor:
|
162 |
+
pinyin_ids.append(self.pinyin2tensor[pinyin_string])
|
163 |
+
else:
|
164 |
+
ids = [0] * 8
|
165 |
+
for i, p in enumerate(pinyin_string):
|
166 |
+
if p not in self.pinyin_dict["char2idx"]:
|
167 |
+
ids = [0] * 8
|
168 |
+
break
|
169 |
+
ids[i] = self.pinyin_dict["char2idx"][p]
|
170 |
+
pinyin_ids.append(pinyin_ids)
|
171 |
+
|
172 |
+
return pinyin_ids
|
173 |
+
|
174 |
+
def convert_sentence_to_pinyin_ids(self, sentence: str, tokenizer_output: tokenizers.Encoding) -> List[List[int]]:
|
175 |
+
# get pinyin of a sentence
|
176 |
+
pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])
|
177 |
+
pinyin_locs = {}
|
178 |
+
# get pinyin of each location
|
179 |
+
for index, item in enumerate(pinyin_list):
|
180 |
+
pinyin_string = item[0]
|
181 |
+
# not a Chinese character, pass
|
182 |
+
if pinyin_string == "not chinese":
|
183 |
+
continue
|
184 |
+
if pinyin_string in self.pinyin2tensor:
|
185 |
+
pinyin_locs[index] = self.pinyin2tensor[pinyin_string]
|
186 |
+
else:
|
187 |
+
ids = [0] * 8
|
188 |
+
for i, p in enumerate(pinyin_string):
|
189 |
+
if p not in self.pinyin_dict["char2idx"]:
|
190 |
+
ids = [0] * 8
|
191 |
+
break
|
192 |
+
ids[i] = self.pinyin_dict["char2idx"][p]
|
193 |
+
pinyin_locs[index] = ids
|
194 |
+
|
195 |
+
# find chinese character location, and generate pinyin ids
|
196 |
+
pinyin_ids = []
|
197 |
+
for idx, (token, offset) in enumerate(zip(tokenizer_output.tokens, tokenizer_output.offsets)):
|
198 |
+
if offset[1] - offset[0] != 1:
|
199 |
+
pinyin_ids.append([0] * 8)
|
200 |
+
continue
|
201 |
+
if offset[0] in pinyin_locs:
|
202 |
+
pinyin_ids.append(pinyin_locs[offset[0]])
|
203 |
+
else:
|
204 |
+
pinyin_ids.append([0] * 8)
|
205 |
+
|
206 |
+
return pinyin_ids
|
config.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "./ChineseBERT-large",
|
3 |
+
"architectures": [
|
4 |
+
"GlyceBertForMaskedLM"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"auto_map": {
|
8 |
+
"AutoModel": "modeling_glycebert.GlyceBertForMaskedLM"
|
9 |
+
},
|
10 |
+
"classifier_dropout": null,
|
11 |
+
"gradient_checkpointing": false,
|
12 |
+
"hidden_act": "gelu",
|
13 |
+
"hidden_dropout_prob": 0.1,
|
14 |
+
"hidden_size": 1024,
|
15 |
+
"initializer_range": 0.02,
|
16 |
+
"intermediate_size": 4096,
|
17 |
+
"layer_norm_eps": 1e-12,
|
18 |
+
"max_position_embeddings": 512,
|
19 |
+
"model_type": "bert",
|
20 |
+
"num_attention_heads": 16,
|
21 |
+
"num_hidden_layers": 24,
|
22 |
+
"pad_token_id": 0,
|
23 |
+
"position_embedding_type": "absolute",
|
24 |
+
"torch_dtype": "float32",
|
25 |
+
"transformers_version": "4.27.1",
|
26 |
+
"type_vocab_size": 2,
|
27 |
+
"use_cache": true,
|
28 |
+
"vocab_size": 23236
|
29 |
+
}
|
modeling_glycebert.py
ADDED
@@ -0,0 +1,713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@file : modeling_glycebert.py
|
5 |
+
@author: zijun
|
6 |
+
@contact : [email protected]
|
7 |
+
@date : 2020/9/6 18:50
|
8 |
+
@version: 1.0
|
9 |
+
@desc : ChineseBert Model
|
10 |
+
"""
|
11 |
+
import json
|
12 |
+
import os
|
13 |
+
import shutil
|
14 |
+
import time
|
15 |
+
import warnings
|
16 |
+
from pathlib import Path
|
17 |
+
from typing import List
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from huggingface_hub import hf_hub_download
|
22 |
+
from huggingface_hub.file_download import http_user_agent
|
23 |
+
from torch import nn
|
24 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
25 |
+
from torch.nn import functional as F
|
26 |
+
|
27 |
+
try:
|
28 |
+
from transformers.modeling_bert import BertEncoder, BertPooler, BertOnlyMLMHead, BertPreTrainedModel, BertModel
|
29 |
+
except:
|
30 |
+
from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertOnlyMLMHead, BertPreTrainedModel, \
|
31 |
+
BertModel
|
32 |
+
|
33 |
+
from transformers.modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput, \
|
34 |
+
QuestionAnsweringModelOutput, TokenClassifierOutput
|
35 |
+
|
36 |
+
cache_path = Path(os.path.abspath(__file__)).parent
|
37 |
+
|
38 |
+
|
39 |
+
def download_file(filename: str, path: Path):
|
40 |
+
if os.path.exists(cache_path / filename):
|
41 |
+
return
|
42 |
+
|
43 |
+
if os.path.exists(path / filename):
|
44 |
+
shutil.copyfile(path / filename, cache_path / filename)
|
45 |
+
return
|
46 |
+
|
47 |
+
hf_hub_download(
|
48 |
+
"iioSnail/ChineseBERT-base",
|
49 |
+
filename,
|
50 |
+
local_dir=cache_path,
|
51 |
+
user_agent=http_user_agent(),
|
52 |
+
)
|
53 |
+
time.sleep(0.2)
|
54 |
+
|
55 |
+
|
56 |
+
class GlyceBertModel(BertModel):
|
57 |
+
r"""
|
58 |
+
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
59 |
+
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
60 |
+
Sequence of hidden-states at the output of the last layer of the models.
|
61 |
+
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
|
62 |
+
Last layer hidden-state of the first token of the sequence (classification token)
|
63 |
+
further processed by a Linear layer and a Tanh activation function. The Linear
|
64 |
+
layer weights are trained from the next sentence prediction (classification)
|
65 |
+
objective during Bert pretraining. This output is usually *not* a good summary
|
66 |
+
of the semantic content of the input, you're often better with averaging or pooling
|
67 |
+
the sequence of hidden-states for the whole input sequence.
|
68 |
+
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
69 |
+
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
70 |
+
of shape ``(batch_size, sequence_length, hidden_size)``:
|
71 |
+
Hidden-states of the models at the output of each layer plus the initial embedding outputs.
|
72 |
+
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
73 |
+
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
74 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
75 |
+
|
76 |
+
Examples::
|
77 |
+
|
78 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
79 |
+
models = BertModel.from_pretrained('bert-base-uncased')
|
80 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
81 |
+
outputs = models(input_ids)
|
82 |
+
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
83 |
+
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, config):
|
87 |
+
super(GlyceBertModel, self).__init__(config)
|
88 |
+
self.config = config
|
89 |
+
|
90 |
+
self.embeddings = FusionBertEmbeddings(config)
|
91 |
+
self.encoder = BertEncoder(config)
|
92 |
+
self.pooler = BertPooler(config)
|
93 |
+
|
94 |
+
self.init_weights()
|
95 |
+
|
96 |
+
def forward(
|
97 |
+
self,
|
98 |
+
input_ids=None,
|
99 |
+
pinyin_ids=None,
|
100 |
+
attention_mask=None,
|
101 |
+
token_type_ids=None,
|
102 |
+
position_ids=None,
|
103 |
+
head_mask=None,
|
104 |
+
inputs_embeds=None,
|
105 |
+
encoder_hidden_states=None,
|
106 |
+
encoder_attention_mask=None,
|
107 |
+
output_attentions=None,
|
108 |
+
output_hidden_states=None,
|
109 |
+
return_dict=None,
|
110 |
+
):
|
111 |
+
r"""
|
112 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
113 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
114 |
+
if the models is configured as a decoder.
|
115 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
116 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
117 |
+
is used in the cross-attention if the models is configured as a decoder.
|
118 |
+
Mask values selected in ``[0, 1]``:
|
119 |
+
|
120 |
+
- 1 for tokens that are **not masked**,
|
121 |
+
- 0 for tokens that are **masked**.
|
122 |
+
"""
|
123 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
124 |
+
output_hidden_states = (
|
125 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
126 |
+
)
|
127 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
128 |
+
|
129 |
+
if input_ids is not None and inputs_embeds is not None:
|
130 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
131 |
+
elif input_ids is not None:
|
132 |
+
input_shape = input_ids.size()
|
133 |
+
elif inputs_embeds is not None:
|
134 |
+
input_shape = inputs_embeds.size()[:-1]
|
135 |
+
else:
|
136 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
137 |
+
|
138 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
139 |
+
|
140 |
+
if attention_mask is None:
|
141 |
+
attention_mask = torch.ones(input_shape, device=device)
|
142 |
+
if token_type_ids is None:
|
143 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
144 |
+
|
145 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
146 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
147 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
148 |
+
|
149 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
150 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
151 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
152 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
153 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
154 |
+
if encoder_attention_mask is None:
|
155 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
156 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
157 |
+
else:
|
158 |
+
encoder_extended_attention_mask = None
|
159 |
+
|
160 |
+
# Prepare head mask if needed
|
161 |
+
# 1.0 in head_mask indicate we keep the head
|
162 |
+
# attention_probs has shape bsz x n_heads x N x N
|
163 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
164 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
165 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
166 |
+
|
167 |
+
embedding_output = self.embeddings(
|
168 |
+
input_ids=input_ids, pinyin_ids=pinyin_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
169 |
+
inputs_embeds=inputs_embeds
|
170 |
+
)
|
171 |
+
encoder_outputs = self.encoder(
|
172 |
+
embedding_output,
|
173 |
+
attention_mask=extended_attention_mask,
|
174 |
+
head_mask=head_mask,
|
175 |
+
encoder_hidden_states=encoder_hidden_states,
|
176 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
177 |
+
output_attentions=output_attentions,
|
178 |
+
output_hidden_states=output_hidden_states,
|
179 |
+
return_dict=return_dict,
|
180 |
+
)
|
181 |
+
sequence_output = encoder_outputs[0]
|
182 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
183 |
+
|
184 |
+
if not return_dict:
|
185 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
186 |
+
|
187 |
+
return BaseModelOutputWithPooling(
|
188 |
+
last_hidden_state=sequence_output,
|
189 |
+
pooler_output=pooled_output,
|
190 |
+
hidden_states=encoder_outputs.hidden_states,
|
191 |
+
attentions=encoder_outputs.attentions,
|
192 |
+
)
|
193 |
+
|
194 |
+
|
195 |
+
class GlyceBertForMaskedLM(BertPreTrainedModel):
|
196 |
+
def __init__(self, config):
|
197 |
+
super(GlyceBertForMaskedLM, self).__init__(config)
|
198 |
+
|
199 |
+
self.bert = GlyceBertModel(config)
|
200 |
+
self.cls = BertOnlyMLMHead(config)
|
201 |
+
|
202 |
+
self.init_weights()
|
203 |
+
|
204 |
+
def get_output_embeddings(self):
|
205 |
+
return self.cls.predictions.decoder
|
206 |
+
|
207 |
+
def forward(
|
208 |
+
self,
|
209 |
+
input_ids=None,
|
210 |
+
pinyin_ids=None,
|
211 |
+
attention_mask=None,
|
212 |
+
token_type_ids=None,
|
213 |
+
position_ids=None,
|
214 |
+
head_mask=None,
|
215 |
+
inputs_embeds=None,
|
216 |
+
encoder_hidden_states=None,
|
217 |
+
encoder_attention_mask=None,
|
218 |
+
labels=None,
|
219 |
+
output_attentions=None,
|
220 |
+
output_hidden_states=None,
|
221 |
+
return_dict=None,
|
222 |
+
**kwargs
|
223 |
+
):
|
224 |
+
r"""
|
225 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
226 |
+
Labels for computing the masked language modeling loss.
|
227 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
228 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
229 |
+
in ``[0, ..., config.vocab_size]``
|
230 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
231 |
+
Used to hide legacy arguments that have been deprecated.
|
232 |
+
"""
|
233 |
+
if "masked_lm_labels" in kwargs:
|
234 |
+
warnings.warn(
|
235 |
+
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
236 |
+
FutureWarning,
|
237 |
+
)
|
238 |
+
labels = kwargs.pop("masked_lm_labels")
|
239 |
+
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
|
240 |
+
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
241 |
+
|
242 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
243 |
+
|
244 |
+
outputs = self.bert(
|
245 |
+
input_ids,
|
246 |
+
pinyin_ids,
|
247 |
+
attention_mask=attention_mask,
|
248 |
+
token_type_ids=token_type_ids,
|
249 |
+
position_ids=position_ids,
|
250 |
+
head_mask=head_mask,
|
251 |
+
inputs_embeds=inputs_embeds,
|
252 |
+
encoder_hidden_states=encoder_hidden_states,
|
253 |
+
encoder_attention_mask=encoder_attention_mask,
|
254 |
+
output_attentions=output_attentions,
|
255 |
+
output_hidden_states=output_hidden_states,
|
256 |
+
return_dict=return_dict,
|
257 |
+
)
|
258 |
+
|
259 |
+
sequence_output = outputs[0]
|
260 |
+
prediction_scores = self.cls(sequence_output)
|
261 |
+
|
262 |
+
masked_lm_loss = None
|
263 |
+
if labels is not None:
|
264 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
265 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
266 |
+
|
267 |
+
if not return_dict:
|
268 |
+
output = (prediction_scores,) + outputs[2:]
|
269 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
270 |
+
|
271 |
+
return MaskedLMOutput(
|
272 |
+
loss=masked_lm_loss,
|
273 |
+
logits=prediction_scores,
|
274 |
+
hidden_states=outputs.hidden_states,
|
275 |
+
attentions=outputs.attentions,
|
276 |
+
)
|
277 |
+
|
278 |
+
|
279 |
+
class GlyceBertForSequenceClassification(BertPreTrainedModel):
|
280 |
+
def __init__(self, config):
|
281 |
+
super().__init__(config)
|
282 |
+
self.num_labels = config.num_labels
|
283 |
+
|
284 |
+
self.bert = GlyceBertModel(config)
|
285 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
286 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
287 |
+
|
288 |
+
self.init_weights()
|
289 |
+
|
290 |
+
def forward(
|
291 |
+
self,
|
292 |
+
input_ids=None,
|
293 |
+
pinyin_ids=None,
|
294 |
+
attention_mask=None,
|
295 |
+
token_type_ids=None,
|
296 |
+
position_ids=None,
|
297 |
+
head_mask=None,
|
298 |
+
inputs_embeds=None,
|
299 |
+
labels=None,
|
300 |
+
output_attentions=None,
|
301 |
+
output_hidden_states=None,
|
302 |
+
return_dict=None,
|
303 |
+
):
|
304 |
+
r"""
|
305 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
306 |
+
Labels for computing the sequence classification/regression loss.
|
307 |
+
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
308 |
+
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
309 |
+
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
310 |
+
"""
|
311 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
312 |
+
|
313 |
+
outputs = self.bert(
|
314 |
+
input_ids,
|
315 |
+
pinyin_ids,
|
316 |
+
attention_mask=attention_mask,
|
317 |
+
token_type_ids=token_type_ids,
|
318 |
+
position_ids=position_ids,
|
319 |
+
head_mask=head_mask,
|
320 |
+
inputs_embeds=inputs_embeds,
|
321 |
+
output_attentions=output_attentions,
|
322 |
+
output_hidden_states=output_hidden_states,
|
323 |
+
return_dict=return_dict,
|
324 |
+
)
|
325 |
+
|
326 |
+
pooled_output = outputs[1]
|
327 |
+
|
328 |
+
pooled_output = self.dropout(pooled_output)
|
329 |
+
logits = self.classifier(pooled_output)
|
330 |
+
|
331 |
+
loss = None
|
332 |
+
if labels is not None:
|
333 |
+
if self.num_labels == 1:
|
334 |
+
# We are doing regression
|
335 |
+
loss_fct = MSELoss()
|
336 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
337 |
+
else:
|
338 |
+
loss_fct = CrossEntropyLoss()
|
339 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
340 |
+
|
341 |
+
if not return_dict:
|
342 |
+
output = (logits,) + outputs[2:]
|
343 |
+
return ((loss,) + output) if loss is not None else output
|
344 |
+
|
345 |
+
return SequenceClassifierOutput(
|
346 |
+
loss=loss,
|
347 |
+
logits=logits,
|
348 |
+
hidden_states=outputs.hidden_states,
|
349 |
+
attentions=outputs.attentions,
|
350 |
+
)
|
351 |
+
|
352 |
+
|
353 |
+
class GlyceBertForQuestionAnswering(BertPreTrainedModel):
|
354 |
+
"""BERT model for Question Answering (span extraction).
|
355 |
+
This module is composed of the BERT model with a linear layer on top of
|
356 |
+
the sequence output that computes start_logits and end_logits
|
357 |
+
|
358 |
+
Params:
|
359 |
+
`config`: a BertConfig class instance with the configuration to build a new model.
|
360 |
+
|
361 |
+
Inputs:
|
362 |
+
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
363 |
+
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
364 |
+
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
365 |
+
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
366 |
+
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
367 |
+
a `sentence B` token (see BERT paper for more details).
|
368 |
+
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
|
369 |
+
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
370 |
+
input sequence length in the current batch. It's the mask that we typically use for attention when
|
371 |
+
a batch has varying length sentences.
|
372 |
+
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
|
373 |
+
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
|
374 |
+
into account for computing the loss.
|
375 |
+
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
|
376 |
+
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
|
377 |
+
into account for computing the loss.
|
378 |
+
|
379 |
+
Outputs:
|
380 |
+
if `start_positions` and `end_positions` are not `None`:
|
381 |
+
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
|
382 |
+
if `start_positions` or `end_positions` is `None`:
|
383 |
+
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
|
384 |
+
position tokens of shape [batch_size, sequence_length].
|
385 |
+
|
386 |
+
Example usage:
|
387 |
+
```python
|
388 |
+
# Already been converted into WordPiece token ids
|
389 |
+
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
390 |
+
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
391 |
+
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
392 |
+
|
393 |
+
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
394 |
+
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
395 |
+
|
396 |
+
model = BertForQuestionAnswering(config)
|
397 |
+
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
398 |
+
```
|
399 |
+
"""
|
400 |
+
|
401 |
+
def __init__(self, config):
|
402 |
+
super().__init__(config)
|
403 |
+
self.num_labels = config.num_labels
|
404 |
+
|
405 |
+
self.bert = GlyceBertModel(config)
|
406 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
407 |
+
|
408 |
+
self.init_weights()
|
409 |
+
|
410 |
+
def forward(
|
411 |
+
self,
|
412 |
+
input_ids=None,
|
413 |
+
pinyin_ids=None,
|
414 |
+
attention_mask=None,
|
415 |
+
token_type_ids=None,
|
416 |
+
position_ids=None,
|
417 |
+
head_mask=None,
|
418 |
+
inputs_embeds=None,
|
419 |
+
start_positions=None,
|
420 |
+
end_positions=None,
|
421 |
+
output_attentions=None,
|
422 |
+
output_hidden_states=None,
|
423 |
+
return_dict=None,
|
424 |
+
):
|
425 |
+
r"""
|
426 |
+
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
427 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
428 |
+
Positions are clamped to the length of the sequence (:obj:`sequence_length`).
|
429 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
430 |
+
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
431 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
432 |
+
Positions are clamped to the length of the sequence (:obj:`sequence_length`).
|
433 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
434 |
+
"""
|
435 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
436 |
+
|
437 |
+
outputs = self.bert(
|
438 |
+
input_ids,
|
439 |
+
pinyin_ids,
|
440 |
+
attention_mask=attention_mask,
|
441 |
+
token_type_ids=token_type_ids,
|
442 |
+
position_ids=position_ids,
|
443 |
+
head_mask=head_mask,
|
444 |
+
inputs_embeds=inputs_embeds,
|
445 |
+
output_attentions=output_attentions,
|
446 |
+
output_hidden_states=output_hidden_states,
|
447 |
+
return_dict=return_dict,
|
448 |
+
)
|
449 |
+
|
450 |
+
sequence_output = outputs[0]
|
451 |
+
|
452 |
+
logits = self.qa_outputs(sequence_output)
|
453 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
454 |
+
start_logits = start_logits.squeeze(-1)
|
455 |
+
end_logits = end_logits.squeeze(-1)
|
456 |
+
|
457 |
+
total_loss = None
|
458 |
+
if start_positions is not None and end_positions is not None:
|
459 |
+
# If we are on multi-GPU, split add a dimension
|
460 |
+
if len(start_positions.size()) > 1:
|
461 |
+
start_positions = start_positions.squeeze(-1)
|
462 |
+
if len(end_positions.size()) > 1:
|
463 |
+
end_positions = end_positions.squeeze(-1)
|
464 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
465 |
+
ignored_index = start_logits.size(1)
|
466 |
+
start_positions.clamp_(0, ignored_index)
|
467 |
+
end_positions.clamp_(0, ignored_index)
|
468 |
+
|
469 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
470 |
+
start_loss = loss_fct(start_logits, start_positions)
|
471 |
+
end_loss = loss_fct(end_logits, end_positions)
|
472 |
+
total_loss = (start_loss + end_loss) / 2
|
473 |
+
|
474 |
+
if not return_dict:
|
475 |
+
output = (start_logits, end_logits) + outputs[2:]
|
476 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
477 |
+
|
478 |
+
return QuestionAnsweringModelOutput(
|
479 |
+
loss=total_loss,
|
480 |
+
start_logits=start_logits,
|
481 |
+
end_logits=end_logits,
|
482 |
+
hidden_states=outputs.hidden_states,
|
483 |
+
attentions=outputs.attentions,
|
484 |
+
)
|
485 |
+
|
486 |
+
|
487 |
+
class GlyceBertForTokenClassification(BertPreTrainedModel):
|
488 |
+
def __init__(self, config, mlp=False):
|
489 |
+
super().__init__(config)
|
490 |
+
self.num_labels = config.num_labels
|
491 |
+
|
492 |
+
self.bert = GlyceBertModel(config)
|
493 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
494 |
+
if mlp:
|
495 |
+
self.classifier = BertMLP(config)
|
496 |
+
else:
|
497 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
498 |
+
|
499 |
+
self.init_weights()
|
500 |
+
|
501 |
+
def forward(self,
|
502 |
+
input_ids=None,
|
503 |
+
pinyin_ids=None,
|
504 |
+
attention_mask=None,
|
505 |
+
token_type_ids=None,
|
506 |
+
position_ids=None,
|
507 |
+
head_mask=None,
|
508 |
+
inputs_embeds=None,
|
509 |
+
labels=None,
|
510 |
+
output_attentions=None,
|
511 |
+
output_hidden_states=None,
|
512 |
+
return_dict=None,
|
513 |
+
):
|
514 |
+
r"""
|
515 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
516 |
+
Labels for computing the token classification loss.
|
517 |
+
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
518 |
+
"""
|
519 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
520 |
+
|
521 |
+
outputs = self.bert(
|
522 |
+
input_ids,
|
523 |
+
pinyin_ids,
|
524 |
+
attention_mask=attention_mask,
|
525 |
+
token_type_ids=token_type_ids,
|
526 |
+
position_ids=position_ids,
|
527 |
+
head_mask=head_mask,
|
528 |
+
inputs_embeds=inputs_embeds,
|
529 |
+
output_attentions=output_attentions,
|
530 |
+
output_hidden_states=output_hidden_states,
|
531 |
+
return_dict=return_dict,
|
532 |
+
)
|
533 |
+
|
534 |
+
sequence_output = outputs[0]
|
535 |
+
|
536 |
+
sequence_output = self.dropout(sequence_output)
|
537 |
+
logits = self.classifier(sequence_output)
|
538 |
+
|
539 |
+
loss = None
|
540 |
+
if labels is not None:
|
541 |
+
loss_fct = CrossEntropyLoss()
|
542 |
+
# Only keep the active parts of the loss
|
543 |
+
if attention_mask is not None:
|
544 |
+
active_loss = attention_mask.view(-1) == 1
|
545 |
+
active_logits = logits.view(-1, self.num_labels)
|
546 |
+
active_labels = torch.where(
|
547 |
+
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
548 |
+
)
|
549 |
+
loss = loss_fct(active_logits, active_labels)
|
550 |
+
else:
|
551 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
552 |
+
|
553 |
+
if not return_dict:
|
554 |
+
output = (logits,) + outputs[2:]
|
555 |
+
return ((loss,) + output) if loss is not None else output
|
556 |
+
|
557 |
+
return TokenClassifierOutput(
|
558 |
+
loss=loss,
|
559 |
+
logits=logits,
|
560 |
+
hidden_states=outputs.hidden_states,
|
561 |
+
attentions=outputs.attentions,
|
562 |
+
)
|
563 |
+
|
564 |
+
|
565 |
+
class FusionBertEmbeddings(nn.Module):
|
566 |
+
"""
|
567 |
+
Construct the embeddings from word, position, glyph, pinyin and token_type embeddings.
|
568 |
+
"""
|
569 |
+
|
570 |
+
def __init__(self, config):
|
571 |
+
super(FusionBertEmbeddings, self).__init__()
|
572 |
+
self.path = Path(config._name_or_path)
|
573 |
+
config_path = cache_path / 'config'
|
574 |
+
if not os.path.exists(config_path):
|
575 |
+
os.makedirs(config_path)
|
576 |
+
|
577 |
+
font_files = []
|
578 |
+
download_file("config/STFANGSO.TTF24.npy", self.path)
|
579 |
+
download_file("config/STXINGKA.TTF24.npy", self.path)
|
580 |
+
download_file("config/方正古隶繁体.ttf24.npy", self.path)
|
581 |
+
for file in os.listdir(config_path):
|
582 |
+
if file.endswith(".npy"):
|
583 |
+
font_files.append(str(config_path / file))
|
584 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
585 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
586 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
587 |
+
self.pinyin_embeddings = PinyinEmbedding(embedding_size=128, pinyin_out_dim=config.hidden_size, config=config)
|
588 |
+
self.glyph_embeddings = GlyphEmbedding(font_npy_files=font_files)
|
589 |
+
|
590 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow models variable name and be able to load
|
591 |
+
# any TensorFlow checkpoint file
|
592 |
+
self.glyph_map = nn.Linear(1728, config.hidden_size)
|
593 |
+
self.map_fc = nn.Linear(config.hidden_size * 3, config.hidden_size)
|
594 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
595 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
596 |
+
|
597 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
598 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
599 |
+
|
600 |
+
def forward(self, input_ids=None, pinyin_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
601 |
+
if input_ids is not None:
|
602 |
+
input_shape = input_ids.size()
|
603 |
+
else:
|
604 |
+
input_shape = inputs_embeds.size()[:-1]
|
605 |
+
|
606 |
+
seq_length = input_shape[1]
|
607 |
+
|
608 |
+
if position_ids is None:
|
609 |
+
position_ids = self.position_ids[:, :seq_length]
|
610 |
+
|
611 |
+
if token_type_ids is None:
|
612 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
613 |
+
|
614 |
+
if inputs_embeds is None:
|
615 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
616 |
+
|
617 |
+
# get char embedding, pinyin embedding and glyph embedding
|
618 |
+
word_embeddings = inputs_embeds # [bs,l,hidden_size]
|
619 |
+
pinyin_embeddings = self.pinyin_embeddings(pinyin_ids) # [bs,l,hidden_size]
|
620 |
+
glyph_embeddings = self.glyph_map(self.glyph_embeddings(input_ids)) # [bs,l,hidden_size]
|
621 |
+
# fusion layer
|
622 |
+
concat_embeddings = torch.cat((word_embeddings, pinyin_embeddings, glyph_embeddings), 2)
|
623 |
+
inputs_embeds = self.map_fc(concat_embeddings)
|
624 |
+
|
625 |
+
position_embeddings = self.position_embeddings(position_ids)
|
626 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
627 |
+
|
628 |
+
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
629 |
+
embeddings = self.LayerNorm(embeddings)
|
630 |
+
embeddings = self.dropout(embeddings)
|
631 |
+
return embeddings
|
632 |
+
|
633 |
+
|
634 |
+
class PinyinEmbedding(nn.Module):
|
635 |
+
|
636 |
+
def __init__(self, embedding_size: int, pinyin_out_dim: int, config):
|
637 |
+
"""
|
638 |
+
Pinyin Embedding Module
|
639 |
+
Args:
|
640 |
+
embedding_size: the size of each embedding vector
|
641 |
+
pinyin_out_dim: kernel number of conv
|
642 |
+
"""
|
643 |
+
super(PinyinEmbedding, self).__init__()
|
644 |
+
download_file('config/pinyin_map.json', Path(config._name_or_path))
|
645 |
+
with open(cache_path / 'config' / 'pinyin_map.json') as fin:
|
646 |
+
pinyin_dict = json.load(fin)
|
647 |
+
self.pinyin_out_dim = pinyin_out_dim
|
648 |
+
self.embedding = nn.Embedding(len(pinyin_dict['idx2char']), embedding_size)
|
649 |
+
self.conv = nn.Conv1d(in_channels=embedding_size, out_channels=self.pinyin_out_dim, kernel_size=2,
|
650 |
+
stride=1, padding=0)
|
651 |
+
|
652 |
+
def forward(self, pinyin_ids):
|
653 |
+
"""
|
654 |
+
Args:
|
655 |
+
pinyin_ids: (bs*sentence_length*pinyin_locs)
|
656 |
+
|
657 |
+
Returns:
|
658 |
+
pinyin_embed: (bs,sentence_length,pinyin_out_dim)
|
659 |
+
"""
|
660 |
+
# input pinyin ids for 1-D conv
|
661 |
+
embed = self.embedding(pinyin_ids) # [bs,sentence_length,pinyin_locs,embed_size]
|
662 |
+
bs, sentence_length, pinyin_locs, embed_size = embed.shape
|
663 |
+
view_embed = embed.view(-1, pinyin_locs, embed_size) # [(bs*sentence_length),pinyin_locs,embed_size]
|
664 |
+
input_embed = view_embed.permute(0, 2, 1) # [(bs*sentence_length), embed_size, pinyin_locs]
|
665 |
+
# conv + max_pooling
|
666 |
+
pinyin_conv = self.conv(input_embed) # [(bs*sentence_length),pinyin_out_dim,H]
|
667 |
+
pinyin_embed = F.max_pool1d(pinyin_conv, pinyin_conv.shape[-1]) # [(bs*sentence_length),pinyin_out_dim,1]
|
668 |
+
return pinyin_embed.view(bs, sentence_length, self.pinyin_out_dim) # [bs,sentence_length,pinyin_out_dim]
|
669 |
+
|
670 |
+
|
671 |
+
class BertMLP(nn.Module):
|
672 |
+
def __init__(self, config, ):
|
673 |
+
super().__init__()
|
674 |
+
self.dense_layer = nn.Linear(config.hidden_size, config.hidden_size)
|
675 |
+
self.dense_to_labels_layer = nn.Linear(config.hidden_size, config.num_labels)
|
676 |
+
self.activation = nn.Tanh()
|
677 |
+
|
678 |
+
def forward(self, sequence_hidden_states):
|
679 |
+
sequence_output = self.dense_layer(sequence_hidden_states)
|
680 |
+
sequence_output = self.activation(sequence_output)
|
681 |
+
sequence_output = self.dense_to_labels_layer(sequence_output)
|
682 |
+
return sequence_output
|
683 |
+
|
684 |
+
|
685 |
+
class GlyphEmbedding(nn.Module):
|
686 |
+
"""Glyph2Image Embedding"""
|
687 |
+
|
688 |
+
def __init__(self, font_npy_files: List[str]):
|
689 |
+
super(GlyphEmbedding, self).__init__()
|
690 |
+
font_arrays = [
|
691 |
+
np.load(np_file).astype(np.float32) for np_file in font_npy_files
|
692 |
+
]
|
693 |
+
self.vocab_size = font_arrays[0].shape[0]
|
694 |
+
self.font_num = len(font_arrays)
|
695 |
+
self.font_size = font_arrays[0].shape[-1]
|
696 |
+
# N, C, H, W
|
697 |
+
font_array = np.stack(font_arrays, axis=1)
|
698 |
+
self.embedding = nn.Embedding(
|
699 |
+
num_embeddings=self.vocab_size,
|
700 |
+
embedding_dim=self.font_size ** 2 * self.font_num,
|
701 |
+
_weight=torch.from_numpy(font_array.reshape([self.vocab_size, -1]))
|
702 |
+
)
|
703 |
+
|
704 |
+
def forward(self, input_ids):
|
705 |
+
"""
|
706 |
+
get glyph images for batch inputs
|
707 |
+
Args:
|
708 |
+
input_ids: [batch, sentence_length]
|
709 |
+
Returns:
|
710 |
+
images: [batch, sentence_length, self.font_num*self.font_size*self.font_size]
|
711 |
+
"""
|
712 |
+
# return self.embedding(input_ids).view([-1, self.font_num, self.font_size, self.font_size])
|
713 |
+
return self.embedding(input_ids)
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:87c712d04ec75a023b5689e5f09a2fd48e1e94e8fabc4adbaa0e83afdf3c5f47
|
3 |
+
size 1496502999
|
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"mask_token": "[MASK]",
|
4 |
+
"pad_token": "[PAD]",
|
5 |
+
"sep_token": "[SEP]",
|
6 |
+
"unk_token": "[UNK]"
|
7 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoTokenizer": [
|
4 |
+
"bert_tokenizer.ChineseBertTokenizer",
|
5 |
+
null
|
6 |
+
]
|
7 |
+
},
|
8 |
+
"cls_token": "[CLS]",
|
9 |
+
"do_basic_tokenize": true,
|
10 |
+
"do_lower_case": true,
|
11 |
+
"mask_token": "[MASK]",
|
12 |
+
"model_max_length": 1000000000000000019884624838656,
|
13 |
+
"never_split": null,
|
14 |
+
"pad_token": "[PAD]",
|
15 |
+
"sep_token": "[SEP]",
|
16 |
+
"special_tokens_map_file": null,
|
17 |
+
"strip_accents": null,
|
18 |
+
"tokenize_chinese_chars": true,
|
19 |
+
"tokenizer_class": "ChineseBertTokenizer",
|
20 |
+
"unk_token": "[UNK]"
|
21 |
+
}
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|