Linhz committed on
Commit c44c507 · verified · Parent(s): 6e80c15

Update Model/NER/VLSP2021/Predict_Ner.py

Files changed (1)
  1. Model/NER/VLSP2021/Predict_Ner.py +210 -210
Model/NER/VLSP2021/Predict_Ner.py CHANGED
@@ -1,210 +1,210 @@
 from vncorenlp import VnCoreNLP

 from typing import Union
 from transformers import AutoConfig, AutoTokenizer
 from Model.NER.VLSP2021.Ner_CRF import PhoBertCrf, PhoBertSoftmax, PhoBertLstmCrf
 import re
 import os
 import torch
 import itertools
 import numpy as np

 MODEL_MAPPING = {
     'vinai/phobert-base': {
         'softmax': PhoBertSoftmax,
         'crf': PhoBertCrf,
         'lstm_crf': PhoBertLstmCrf
     },
 }


 def normalize_text(txt: str) -> str:
     # Remove invisible characters (soft hyphen, zero-width space, BOM)
     txt = re.sub("\xad|\u200b|\ufeff", "", txt)
     # Normalize Vietnamese accent placement
     txt = re.sub(r"òa", "oà", txt)
     txt = re.sub(r"óa", "oá", txt)
     txt = re.sub(r"ỏa", "oả", txt)
     txt = re.sub(r"õa", "oã", txt)
     txt = re.sub(r"ọa", "oạ", txt)
     txt = re.sub(r"òe", "oè", txt)
     txt = re.sub(r"óe", "oé", txt)
     txt = re.sub(r"ỏe", "oẻ", txt)
     txt = re.sub(r"õe", "oẽ", txt)
     txt = re.sub(r"ọe", "oẹ", txt)
     txt = re.sub(r"ùy", "uỳ", txt)
     txt = re.sub(r"úy", "uý", txt)
     txt = re.sub(r"ủy", "uỷ", txt)
     txt = re.sub(r"ũy", "uỹ", txt)
     txt = re.sub(r"ụy", "uỵ", txt)
     txt = re.sub(r"Ủy", "Uỷ", txt)

     txt = re.sub(r'"', '”', txt)

     # Collapse runs of spaces
     txt = re.sub(" +", " ", txt)
     return txt.strip()


 class ViTagger(object):
     def __init__(self, model_path: Union[str, os.PathLike], no_cuda=False):
         self.device = 'cuda' if not no_cuda and torch.cuda.is_available() else 'cpu'
         print("[ViTagger] VnCoreNLP loading ...")
-        self.rdrsegmenter = VnCoreNLP("E:/demo_datn/pythonProject1/VnCoreNLP/VnCoreNLP-1.1.1.jar", annotators="wseg", max_heap_size='-Xmx500m')
+        self.rdrsegmenter = VnCoreNLP("/VnCoreNLP/VnCoreNLP-1.1.1.jar", annotators="wseg", max_heap_size='-Xmx500m')
         print("[ViTagger] Model loading ...")
         self.model, self.tokenizer, self.max_seq_len, self.label2id, self.use_crf = self.load_model(model_path, device=self.device)
         self.id2label = {idx: label for idx, label in enumerate(self.label2id)}
         print("[ViTagger] All ready!")

     @staticmethod
     def load_model(model_path: Union[str, os.PathLike], device='cpu'):
         if device == 'cpu':
             checkpoint_data = torch.load(model_path, map_location='cpu')
         else:
             checkpoint_data = torch.load(model_path)
         args = checkpoint_data["args"]
         max_seq_len = args.max_seq_length
         use_crf = 'crf' in args.model_arch
         tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)
         config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=len(args.label2id))
         model_cls = MODEL_MAPPING[args.model_name_or_path][args.model_arch]
         model = model_cls(config=config)
         model.load_state_dict(checkpoint_data['model'], strict=False)
         model.to(device)
         model.eval()

         return model, tokenizer, max_seq_len, args.label2id, use_crf

     def preprocess(self, in_raw: str):
         norm_text = normalize_text(in_raw)
         sents = []
         sentences = self.rdrsegmenter.tokenize(norm_text)
         for sentence in sentences:
             sents.append(sentence)
         return sents

     def convert_tensor(self, tokens):
         seq_len = len(tokens)
         encoding = self.tokenizer(tokens,
                                   padding='max_length',
                                   truncation=True,
                                   is_split_into_words=True,
                                   max_length=self.max_seq_len)
         if 'vinai/phobert' in self.tokenizer.name_or_path:
             print(' '.join(tokens))
             subwords = self.tokenizer.tokenize(' '.join(tokens))
             valid_ids = np.zeros(len(encoding.input_ids), dtype=int)
             label_marks = np.zeros(len(encoding.input_ids), dtype=int)
             i = 1
             for idx, subword in enumerate(subwords[:self.max_seq_len - 2]):
                 # "@@" marks a BPE continuation piece: only the first subword of a word is valid
                 if idx != 0 and subwords[idx - 1].endswith("@@"):
                     continue
                 if self.use_crf:
                     valid_ids[i - 1] = idx + 1
                 else:
                     valid_ids[idx + 1] = 1
                 i += 1
         else:
             valid_ids = np.zeros(len(encoding.input_ids), dtype=int)
             label_marks = np.zeros(len(encoding.input_ids), dtype=int)
             i = 1
             word_ids = encoding.word_ids()
             for idx in range(1, len(word_ids)):
                 if word_ids[idx] is not None and word_ids[idx] != word_ids[idx - 1]:
                     if self.use_crf:
                         valid_ids[i - 1] = idx
                     else:
                         valid_ids[idx] = 1
                     i += 1
         if self.max_seq_len >= seq_len + 2:
             label_marks[:seq_len] = [1] * seq_len
         else:
             label_marks[:-2] = [1] * (self.max_seq_len - 2)
         if self.use_crf and label_marks[0] == 0:
             raise ValueError(f"{tokens} have mark == 0 at index 0!")
         item = {key: torch.as_tensor([val]).to(self.device, dtype=torch.long) for key, val in encoding.items()}
         item['valid_ids'] = torch.as_tensor([valid_ids]).to(self.device, dtype=torch.long)
         item['label_masks'] = torch.as_tensor([valid_ids]).to(self.device, dtype=torch.long)
         return item

     def extract_entity_doc(self, in_raw: str):
         sents = self.preprocess(in_raw)
         print(sents)
         entities_doc = []
         for sent in sents:
             item = self.convert_tensor(sent)
             with torch.no_grad():
                 outputs = self.model(**item)
             entity = None
             if isinstance(outputs.tags[0], list):
                 tags = list(itertools.chain(*outputs.tags))
             else:
                 tags = outputs.tags
             for w, l in list(zip(sent, tags)):
                 w = w.replace("_", " ")
                 tag = self.id2label[l]
                 if not tag == 'O':
                     parts = tag.split('-', 1)
                     prefix = parts[0]
                     tag = parts[1] if len(parts) > 1 else ""
                     if entity is None:
                         entity = (w, tag)
                     else:
                         if entity[-1] == tag:
                             if prefix == 'I':
                                 entity = (entity[0] + f' {w}', tag)
                             else:
                                 entities_doc.append(entity)
                                 entity = (w, tag)
                         else:
                             entities_doc.append(entity)
                             entity = (w, tag)
                 elif entity is not None:
                     entities_doc.append(entity)
                     if w != ' ':
                         entities_doc.append((w, 'O'))
                     entity = None
                 elif w != ' ':
                     entities_doc.append((w, 'O'))
                     entity = None
         return entities_doc

     def __call__(self, in_raw: str):
         sents = self.preprocess(in_raw)
         entities = []
         for sent in sents:
             item = self.convert_tensor(sent)
             with torch.no_grad():
                 outputs = self.model(**item)
             entity = None
             if isinstance(outputs.tags[0], list):
                 tags = list(itertools.chain(*outputs.tags))
             else:
                 tags = outputs.tags
             for w, l in list(zip(sent, tags)):
                 w = w.replace("_", " ")
                 tag = self.id2label[l]
                 if not tag == 'O':
                     prefix, tag = tag.split('-', 1)
                     if entity is None:
                         entity = (w, tag)
                     else:
                         if entity[-1] == tag:
                             if prefix == 'I':
                                 entity = (entity[0] + f' {w}', tag)
                             else:
                                 entities.append(entity)
                                 entity = (w, tag)
                         else:
                             entities.append(entity)
                             entity = (w, tag)
                 elif entity is not None:
                     entities.append(entity)
                     entity = None
                 else:
                     entity = None
         return entities
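For reference, a minimal usage sketch of the tagger after this change. The checkpoint path and input sentence below are placeholders (nothing in this commit specifies them), and it assumes the VnCoreNLP jar actually exists at the hard-coded /VnCoreNLP/VnCoreNLP-1.1.1.jar location introduced above, together with a trained VLSP2021 checkpoint that stores its training args:

# Hypothetical usage sketch; checkpoint path and input text are placeholders.
from Model.NER.VLSP2021.Predict_Ner import ViTagger

# The checkpoint must contain 'args' and 'model' keys, as load_model expects.
tagger = ViTagger("path/to/phobert_crf_checkpoint.pt", no_cuda=True)

# __call__ segments the text with VnCoreNLP, tags each sentence, and returns
# merged (entity_text, label) pairs; tokens tagged 'O' are not returned.
for entity, label in tagger("Uỷ ban nhân dân thành phố Hà Nội họp sáng nay."):
    print(entity, label)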