Spaces:
Sleeping
Sleeping
terapyon
commited on
Commit
·
5a7a866
1
Parent(s):
720d81f
モデル部分を移植し、リファクタリングをした
Browse files- inference.py +220 -0
inference.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import re
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Generator
|
5 |
+
from unicodedata import normalize
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import streamlit as st
|
9 |
+
import tomotopy as tp # type: ignore
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import transformers as T # type: ignore
|
13 |
+
from scipy import stats # type: ignore
|
14 |
+
from sudachipy import dictionary, tokenizer # type: ignore
|
15 |
+
|
16 |
+
MODELS_PATH = Path("/home/terapyon/dev/awarefy/models/for-amp")
|
17 |
+
model_base_path = MODELS_PATH / "2class_1019/"
|
18 |
+
topic_model_trained = MODELS_PATH / "trained_model.bin"
|
19 |
+
japanese_selection_path = MODELS_PATH / "Japanese_selection.txt"
|
20 |
+
|
21 |
+
# GPUの指定
|
22 |
+
if torch.cuda.is_available():
|
23 |
+
gpu = 0
|
24 |
+
# gpu = -1 # For debugging
|
25 |
+
else:
|
26 |
+
gpu = -1 # gpu = -1 # GPUが使用できなければ(CPUで処理)-1を指定
|
27 |
+
|
28 |
+
|
29 |
+
cls_num = 3
|
30 |
+
max_length = 512
|
31 |
+
k_folds = 10
|
32 |
+
bert_model_name = "cl-tohoku/bert-base-japanese-v3"
|
33 |
+
device = torch.device(f"cuda:{gpu}" if gpu>=0 else "cpu")
|
34 |
+
|
35 |
+
|
36 |
+
#BERTモデルの定義
|
37 |
+
class BertClassifier(nn.Module):
|
38 |
+
def __init__(self, model_name, cls_num=3):
|
39 |
+
super(BertClassifier, self).__init__()
|
40 |
+
#model_name = "cl-tohoku/bert-base-japanese"
|
41 |
+
self.bert = T.BertModel.from_pretrained(model_name, output_attentions=True)
|
42 |
+
self.fc = nn.Linear(768, cls_num, bias=True)
|
43 |
+
|
44 |
+
nn.init.normal_(self.fc.weight, std=0.02)
|
45 |
+
nn.init.normal_(self.fc.bias, 0)
|
46 |
+
|
47 |
+
def forward(self, input_ids, masks):
|
48 |
+
result = self.bert(input_ids, masks)
|
49 |
+
|
50 |
+
vec = result[0]
|
51 |
+
_ = result[1]
|
52 |
+
attentions = result[2]
|
53 |
+
|
54 |
+
vec = vec[:, 0, :]
|
55 |
+
vec = vec.view(-1, 768)
|
56 |
+
output = self.fc(vec)
|
57 |
+
return output, _, attentions
|
58 |
+
|
59 |
+
|
60 |
+
#日本語Stopwords除去関数
|
61 |
+
def load_stopwords() -> set[str]:
|
62 |
+
with open(japanese_selection_path, "r", encoding="utf-8") as f:
|
63 |
+
# stopwords = [w.strip() for w in f]
|
64 |
+
# stopwords = set(stopwords)
|
65 |
+
stopwords = {w.strip() for w in f if w.strip()}
|
66 |
+
return stopwords
|
67 |
+
|
68 |
+
|
69 |
+
class SudachiTokenizer:
|
70 |
+
def __init__(self, split_mode="C"):
|
71 |
+
self.tokenizer_obj = dictionary.Dictionary(dict_type="full").create()
|
72 |
+
self.stopwords = load_stopwords()
|
73 |
+
if split_mode == "A":
|
74 |
+
self.mode = tokenizer.Tokenizer.SplitMode.C
|
75 |
+
elif split_mode == "B":
|
76 |
+
self.mode = tokenizer.Tokenizer.SplitMode.B
|
77 |
+
else:
|
78 |
+
self.mode = tokenizer.Tokenizer.SplitMode.C
|
79 |
+
# ひらがなのみの文字列にマッチする正規表現
|
80 |
+
self.kana_re = re.compile("^[ぁ-ゖ]+$")
|
81 |
+
#Stopwords
|
82 |
+
self.stopwords = load_stopwords()
|
83 |
+
|
84 |
+
def get_wakati(self, text: str) -> list[str]:
|
85 |
+
wakati_list = []
|
86 |
+
normalized_wakati_list = []
|
87 |
+
pos_list = []
|
88 |
+
normalized_text = normalize("NFKC", text)
|
89 |
+
tmp = re.sub(r'[0-9]','',normalized_text)
|
90 |
+
tmp = re.sub(r'[0-9]', '', tmp)
|
91 |
+
tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]','',tmp)
|
92 |
+
tmp = re.sub(r'[a-zA-Z]','',tmp)
|
93 |
+
#絵文字除去
|
94 |
+
tmp = re.sub(r'[❓]', "", tmp)
|
95 |
+
for m in self.tokenizer_obj.tokenize(tmp, self.mode):
|
96 |
+
word = m.surface()
|
97 |
+
pos = m.part_of_speech()[0]
|
98 |
+
normalized_word = m.normalized_form()
|
99 |
+
wakati_list.append(word)
|
100 |
+
normalized_wakati_list.append(normalized_word)
|
101 |
+
pos_list.append(pos)
|
102 |
+
#名詞,動詞,形容詞のみに絞り込み
|
103 |
+
target_pos = ["名詞", "動詞", "形容詞"]
|
104 |
+
#target_pos = ["名詞", "形容詞"]
|
105 |
+
token_list = [t for t, p in zip(wakati_list, pos_list) if p in target_pos]
|
106 |
+
#アルファベットを小文字に統一
|
107 |
+
token_list = [t.lower() for t in token_list]
|
108 |
+
#ひらがなのみの単語を除く
|
109 |
+
#token_list = [t for t in token_list if not self.kana_re.match(t)]
|
110 |
+
#ストップワード除去
|
111 |
+
#token_list = [t for t in token_list if t not in self.stopwords]
|
112 |
+
return token_list
|
113 |
+
|
114 |
+
|
115 |
+
def make_traind_model(bert_model):
|
116 |
+
trained_models = []
|
117 |
+
for k in range(k_folds):
|
118 |
+
k = k + 1
|
119 |
+
model_path = model_base_path / f"trained_model{k}.pt"
|
120 |
+
trained_model = copy.deepcopy(bert_model)
|
121 |
+
trained_model.load_state_dict(torch.load(model_path, map_location=device),strict=False)
|
122 |
+
trained_models.append(trained_model)
|
123 |
+
return trained_models
|
124 |
+
|
125 |
+
|
126 |
+
@st.cache_resource
|
127 |
+
def init_models():
|
128 |
+
bert_model = BertClassifier(bert_model_name, cls_num=1) #出力ノードを1に設定
|
129 |
+
bert_model.eval()
|
130 |
+
bert_model.to(device)
|
131 |
+
|
132 |
+
tokenizer_sudachi = SudachiTokenizer(split_mode="C")
|
133 |
+
#Tokenizerの設定(ここではtokenizerをtokenizer_c2にしている)
|
134 |
+
tokenizer_c2 = T.BertJapaneseTokenizer.from_pretrained(bert_model_name)
|
135 |
+
trained_models = make_traind_model(bert_model)
|
136 |
+
return tokenizer_sudachi, tokenizer_c2, trained_models
|
137 |
+
|
138 |
+
|
139 |
+
tokenizer_sudachi, tokenizer_c2, trained_models = init_models()
|
140 |
+
|
141 |
+
|
142 |
+
# Attentionマップ���算出する関数の定義
|
143 |
+
def f_a(sentences: list[str], tokenizer_c2, model, device):
|
144 |
+
encoded = tokenizer_c2.batch_encode_plus(
|
145 |
+
sentences,
|
146 |
+
padding="max_length",
|
147 |
+
max_length=max_length,
|
148 |
+
truncation=True,
|
149 |
+
return_attention_mask=True
|
150 |
+
)
|
151 |
+
|
152 |
+
input_ids = torch.tensor(encoded["input_ids"]).to(device)
|
153 |
+
attention_mask = torch.tensor(encoded["attention_mask"]).to(device)
|
154 |
+
|
155 |
+
with torch.no_grad():
|
156 |
+
outputs, _, attentions = model(input_ids, attention_mask)
|
157 |
+
#return input_ids.detach().cpu(), attentions[-1].detach().cpu()
|
158 |
+
return input_ids.detach().cpu(), attentions[-1].detach().cpu(), outputs.detach().cpu()
|
159 |
+
|
160 |
+
|
161 |
+
def get_word_attn(input_ids, attention_weight) -> Generator[tuple[str, float], None, None]:
|
162 |
+
# 文章の長さ分のzero tensorを宣言
|
163 |
+
seq_len = attention_weight.size()[2]
|
164 |
+
all_attens = torch.zeros(seq_len)
|
165 |
+
|
166 |
+
# 12個のMulti Head Attentionの結果を全部足し合わせる
|
167 |
+
# 最初の0はinput_idsは1文章だけを想定しているため
|
168 |
+
# 次の0はCLSトークンのAttention結果を取得している、という意味です。
|
169 |
+
for i in range(12):
|
170 |
+
all_attens += attention_weight[0, i, 0, :]
|
171 |
+
|
172 |
+
for word, attn in zip(input_ids.flatten(), all_attens):
|
173 |
+
if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[CLS]":
|
174 |
+
continue
|
175 |
+
if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[SEP]":
|
176 |
+
break
|
177 |
+
converted_word = tokenizer_c2.convert_ids_to_tokens([word.numpy().tolist()])[0]
|
178 |
+
yield converted_word, attn
|
179 |
+
|
180 |
+
|
181 |
+
def classify_ma(sentence: str) -> tuple[int, torch.Tensor, torch.Tensor]:
|
182 |
+
normalized_sentence = normalize("NFKC", sentence)
|
183 |
+
|
184 |
+
attention_list, output_list = [], []
|
185 |
+
for trained_model in trained_models:
|
186 |
+
input_ids, attention, output = f_a([normalized_sentence], tokenizer_c2, trained_model, device)
|
187 |
+
attention_list.append(attention)
|
188 |
+
output_list.append(output)
|
189 |
+
|
190 |
+
#出力された10個の予測値の多数決を算出
|
191 |
+
outputs = np.concatenate(output_list)
|
192 |
+
prob_column = torch.sigmoid(torch.tensor(outputs))
|
193 |
+
pred_column = torch.ge(prob_column, 0.5).float()
|
194 |
+
ensemble_pred, count = stats.mode(pred_column)
|
195 |
+
|
196 |
+
#出力された10個のattention mapの平均値を算出
|
197 |
+
attentions = torch.concat(attention_list)
|
198 |
+
mean_attention = torch.mean(attentions, dim=0).unsqueeze(dim=0)
|
199 |
+
return ensemble_pred.item(), input_ids, mean_attention
|
200 |
+
|
201 |
+
|
202 |
+
#モデルのロードとinferの関数化
|
203 |
+
def infer_topic(new_text: str) -> tuple[np.ndarray, float]:
|
204 |
+
model_trained = tp.CTModel.load(str(topic_model_trained))
|
205 |
+
new_word_list = tokenizer_sudachi.get_wakati(new_text)
|
206 |
+
new_doc = model_trained.make_doc(new_word_list)
|
207 |
+
topic_dist, ll = model_trained.infer(new_doc)
|
208 |
+
return topic_dist, ll
|
209 |
+
|
210 |
+
|
211 |
+
if __name__ == "__main__":
|
212 |
+
text = "NHKの番組を見ていると,発達障害者の才能を特集されることが多い。それを見ていると自分もそのような才能を期待されているように感じる"
|
213 |
+
result_classify = classify_ma(text)
|
214 |
+
if result_classify[0] == 0:
|
215 |
+
print("マイクロアグレッションではない")
|
216 |
+
else:
|
217 |
+
print("マイクロアグレッションである")
|
218 |
+
res = infer_topic(text)
|
219 |
+
print(res[0])
|
220 |
+
print(res[1])
|