terapyon committed on
Commit 5a7a866 · 1 Parent(s): 720d81f

Ported the model code and refactored it

Files changed (1)
  1. inference.py +220 -0
inference.py ADDED
@@ -0,0 +1,220 @@
+ import copy
+ import re
+ from pathlib import Path
+ from typing import Generator
+ from unicodedata import normalize
+
+ import numpy as np
+ import streamlit as st
+ import tomotopy as tp  # type: ignore
+ import torch
+ import torch.nn as nn
+ import transformers as T  # type: ignore
+ from scipy import stats  # type: ignore
+ from sudachipy import dictionary, tokenizer  # type: ignore
+
+ MODELS_PATH = Path("/home/terapyon/dev/awarefy/models/for-amp")
+ model_base_path = MODELS_PATH / "2class_1019/"
+ topic_model_trained = MODELS_PATH / "trained_model.bin"
+ japanese_selection_path = MODELS_PATH / "Japanese_selection.txt"
+
+ # GPU selection
+ if torch.cuda.is_available():
+     gpu = 0
+     # gpu = -1  # For debugging
+ else:
+     gpu = -1  # Use -1 (CPU) when no GPU is available
+
+
+ cls_num = 3
+ max_length = 512
+ k_folds = 10
+ bert_model_name = "cl-tohoku/bert-base-japanese-v3"
+ device = torch.device(f"cuda:{gpu}" if gpu >= 0 else "cpu")
+
+
+ # BERT model definition
+ class BertClassifier(nn.Module):
+     def __init__(self, model_name, cls_num=3):
+         super(BertClassifier, self).__init__()
+         # model_name = "cl-tohoku/bert-base-japanese"
+         self.bert = T.BertModel.from_pretrained(model_name, output_attentions=True)
+         self.fc = nn.Linear(768, cls_num, bias=True)
+
+         nn.init.normal_(self.fc.weight, std=0.02)
+         nn.init.normal_(self.fc.bias, 0)
+
+     def forward(self, input_ids, masks):
+         result = self.bert(input_ids, masks)
+
+         vec = result[0]  # last hidden state
+         _ = result[1]  # pooled output (unused)
+         attentions = result[2]  # per-layer attention weights
+
+         vec = vec[:, 0, :]  # take the [CLS] token vector
+         vec = vec.view(-1, 768)
+         output = self.fc(vec)
+         return output, _, attentions
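+ # Shape sketch for forward, assuming the BERT-base config used here
+ # (12 layers, 12 heads, hidden size 768; B = batch size, L = sequence length):
+ #   result[0]: (B, L, 768); result[2]: tuple of 12 tensors, each (B, 12, L, L)
+ #   output: (B, cls_num) raw logits; the sigmoid is applied later by the caller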
+
+
+ # Load the Japanese stopword list
+ def load_stopwords() -> set[str]:
+     with open(japanese_selection_path, "r", encoding="utf-8") as f:
+         stopwords = {w.strip() for w in f if w.strip()}
+     return stopwords
+
+
+ class SudachiTokenizer:
+     def __init__(self, split_mode="C"):
+         self.tokenizer_obj = dictionary.Dictionary(dict_type="full").create()
+         if split_mode == "A":
+             self.mode = tokenizer.Tokenizer.SplitMode.A
+         elif split_mode == "B":
+             self.mode = tokenizer.Tokenizer.SplitMode.B
+         else:
+             self.mode = tokenizer.Tokenizer.SplitMode.C
+         # Regex matching hiragana-only strings
+         self.kana_re = re.compile("^[ぁ-ゖ]+$")
+         # Stopwords
+         self.stopwords = load_stopwords()
+
+     def get_wakati(self, text: str) -> list[str]:
+         wakati_list = []
+         normalized_wakati_list = []
+         pos_list = []
+         normalized_text = normalize("NFKC", text)
+         # NFKC folds full-width digits to ASCII, so one pass removes both
+         tmp = re.sub(r'[0-9]', '', normalized_text)
+         tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
+         tmp = re.sub(r'[a-zA-Z]', '', tmp)
+         # Remove emoji
+         tmp = re.sub(r'[❓]', '', tmp)
+         for m in self.tokenizer_obj.tokenize(tmp, self.mode):
+             word = m.surface()
+             pos = m.part_of_speech()[0]
+             normalized_word = m.normalized_form()
+             wakati_list.append(word)
+             normalized_wakati_list.append(normalized_word)
+             pos_list.append(pos)
+         # Keep only nouns, verbs, and adjectives
+         target_pos = ["名詞", "動詞", "形容詞"]
+         # target_pos = ["名詞", "形容詞"]
+         token_list = [t for t, p in zip(wakati_list, pos_list) if p in target_pos]
+         # Lowercase alphabetic tokens
+         token_list = [t.lower() for t in token_list]
+         # Drop hiragana-only tokens
+         # token_list = [t for t in token_list if not self.kana_re.match(t)]
+         # Remove stopwords
+         # token_list = [t for t in token_list if t not in self.stopwords]
+         return token_list
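+ # Usage sketch (hypothetical input; the exact tokens depend on the installed
+ # Sudachi dictionary):
+ #   tokenizer_sudachi.get_wakati("発達障害の特集を見た")
+ #   -> surface forms of the nouns/verbs/adjectives, lowercased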
+
+
+ def make_trained_model(bert_model):
+     trained_models = []
+     for k in range(1, k_folds + 1):
+         model_path = model_base_path / f"trained_model{k}.pt"
+         # Each fold gets its own copy of the base architecture
+         trained_model = copy.deepcopy(bert_model)
+         trained_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
+         trained_models.append(trained_model)
+     return trained_models
+
+
+ @st.cache_resource  # cache across Streamlit reruns so the models load only once
+ def init_models():
+     bert_model = BertClassifier(bert_model_name, cls_num=1)  # single output node
+     bert_model.eval()
+     bert_model.to(device)
+
+     tokenizer_sudachi = SudachiTokenizer(split_mode="C")
+     # Tokenizer setup (the BERT tokenizer is named tokenizer_c2 here)
+     tokenizer_c2 = T.BertJapaneseTokenizer.from_pretrained(bert_model_name)
+     trained_models = make_trained_model(bert_model)
+     return tokenizer_sudachi, tokenizer_c2, trained_models
+
+
+ tokenizer_sudachi, tokenizer_c2, trained_models = init_models()
+
+
+ # Define a function that computes the attention map
+ def f_a(sentences: list[str], tokenizer_c2, model, device):
+     encoded = tokenizer_c2.batch_encode_plus(
+         sentences,
+         padding="max_length",
+         max_length=max_length,
+         truncation=True,
+         return_attention_mask=True,
+     )
+
+     input_ids = torch.tensor(encoded["input_ids"]).to(device)
+     attention_mask = torch.tensor(encoded["attention_mask"]).to(device)
+
+     with torch.no_grad():
+         outputs, _, attentions = model(input_ids, attention_mask)
+     return input_ids.detach().cpu(), attentions[-1].detach().cpu(), outputs.detach().cpu()
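+ # attentions[-1] is the final encoder layer's attention, shaped
+ # (batch, heads, seq, seq), i.e. (1, 12, 512, 512) for a single sentence here
+ # (assuming BERT-base's 12 heads and the max_length of 512 set above).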
+
+
+ def get_word_attn(input_ids, attention_weight) -> Generator[tuple[str, float], None, None]:
+     # Declare a zero tensor with one slot per token in the sentence
+     seq_len = attention_weight.size()[2]
+     all_attens = torch.zeros(seq_len)
+
+     # Sum the results of the 12 attention heads. The first 0 index assumes
+     # input_ids holds a single sentence; the second 0 selects the attention
+     # paid by the [CLS] token.
+     for i in range(12):
+         all_attens += attention_weight[0, i, 0, :]
+
+     for word, attn in zip(input_ids.flatten(), all_attens):
+         if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[CLS]":
+             continue
+         if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[SEP]":
+             break
+         converted_word = tokenizer_c2.convert_ids_to_tokens([word.numpy().tolist()])[0]
+         yield converted_word, attn
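+ # Usage sketch (a pairing suggested by the signatures, not exercised below):
+ #   pred, input_ids, attn = classify_ma(text)
+ #   for token, weight in get_word_attn(input_ids, attn):
+ #       ...  # e.g. highlight `token` in proportion to `weight`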
+
+
+ def classify_ma(sentence: str) -> tuple[int, torch.Tensor, torch.Tensor]:
+     normalized_sentence = normalize("NFKC", sentence)
+
+     attention_list, output_list = [], []
+     for trained_model in trained_models:
+         input_ids, attention, output = f_a([normalized_sentence], tokenizer_c2, trained_model, device)
+         attention_list.append(attention)
+         output_list.append(output)
+
+     # Majority vote over the 10 fold models' predictions
+     outputs = np.concatenate(output_list)
+     prob_column = torch.sigmoid(torch.tensor(outputs))
+     pred_column = torch.ge(prob_column, 0.5).float()
+     ensemble_pred, count = stats.mode(pred_column)
+
+     # Average the 10 attention maps
+     attentions = torch.concat(attention_list)
+     mean_attention = torch.mean(attentions, dim=0).unsqueeze(dim=0)
+     return ensemble_pred.item(), input_ids, mean_attention
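+ # For example, if 7 of the 10 fold models output a probability >= 0.5,
+ # stats.mode picks 1.0 and the text is reported as a microaggression.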
+
+
+ # Load the topic model and wrap inference in a function
+ def infer_topic(new_text: str) -> tuple[np.ndarray, float]:
+     model_trained = tp.CTModel.load(str(topic_model_trained))
+     new_word_list = tokenizer_sudachi.get_wakati(new_text)
+     new_doc = model_trained.make_doc(new_word_list)
+     topic_dist, ll = model_trained.infer(new_doc)
+     return topic_dist, ll
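+ # tomotopy's infer returns the document's topic distribution (one weight per
+ # topic, summing to 1) together with its log-likelihood under the trained CTModel.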
+
+
+ if __name__ == "__main__":
+     text = "NHKの番組を見ていると,発達障害者の才能を特集されることが多い。それを見ていると自分もそのような才能を期待されているように感じる"
+     result_classify = classify_ma(text)
+     if result_classify[0] == 0:
+         print("マイクロアグレッションではない")  # "Not a microaggression"
+     else:
+         print("マイクロアグレッションである")  # "A microaggression"
+     res = infer_topic(text)
+     print(res[0])
+     print(res[1])