import os
import re
from pathlib import Path
from typing import Generator
from unicodedata import normalize

import numpy as np
import streamlit as st
import tomotopy as tp  # type: ignore
import torch
import torch.nn as nn
import transformers as T  # type: ignore
from huggingface_hub import PyTorchModelHubMixin  # type: ignore
from scipy import stats  # type: ignore
from sudachipy import dictionary, tokenizer  # type: ignore

HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")

MODELS_PATH = Path(__file__).parent / "saved_model"
# model_base_path = MODELS_PATH / "two_class"
MODEL_BASE = "awarefy/awarefy-two_class-trained-"
topic_model_trained = MODELS_PATH / "topic" / "trained_model.bin"
japanese_selection_path = MODELS_PATH / "stop_words" / "Japanese_selection.txt"

# GPU selection
if torch.cuda.is_available():
    gpu = 0
    # gpu = -1  # For debugging
else:
    gpu = -1  # Fall back to CPU (-1) when no GPU is available


# cls_num = 3
max_length = 512
k_folds = 10
bert_model_name = "cl-tohoku/bert-base-japanese-v3"
device = torch.device(f"cuda:{gpu}" if gpu >= 0 else "cpu")


# BERT model definition
class BertClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, cls_num: int):
        super().__init__()
        self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
        self.fc = nn.Linear(768, cls_num, bias=True)

        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.normal_(self.fc.bias, 0)

    def forward(self, input_ids, masks):
        result = self.bert(input_ids, masks)

        vec = result[0]          # last hidden state
        _ = result[1]            # pooler output (unused)
        attentions = result[2]   # attention weights (output_attentions=True)

        vec = vec[:, 0, :]       # [CLS] token representation
        vec = vec.view(-1, 768)
        output = self.fc(vec)
        return output, _, attentions
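# Example (illustrative): the classifier can also be instantiated locally instead of being
# pulled from the Hub, mirroring the commented-out code in init_models() below.
# model = BertClassifier(cls_num=1).to(device)  # a single logit for binary classification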


# Japanese stopword loader
def load_stopwords() -> set[str]:
    with open(japanese_selection_path, "r", encoding="utf-8") as f:
        stopwords = {w.strip() for w in f if w.strip()}
    return stopwords


class SudachiTokenizer:
    def __init__(self, split_mode="C"):
        self.tokenizer_obj = dictionary.Dictionary(dict_type="full").create()
        self.stopwords = load_stopwords()
        if split_mode == "A":
            self.mode = tokenizer.Tokenizer.SplitMode.A
        elif split_mode == "B":
            self.mode = tokenizer.Tokenizer.SplitMode.B
        else:
            self.mode = tokenizer.Tokenizer.SplitMode.C
        # Regex matching strings made up only of hiragana
        self.kana_re = re.compile("^[ぁ-ゖ]+$")

    def get_wakati(self, text: str) -> list[str]:
        wakati_list = []
        normalized_wakati_list = []
        pos_list = []
        normalized_text = normalize("NFKC", text)
        # Strip digits, punctuation/symbols, ASCII letters, and emoji before tokenizing
        tmp = re.sub(r'[0-9]', '', normalized_text)
        tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
        tmp = re.sub(r'[a-zA-Z]', '', tmp)
        tmp = re.sub(r'[❓]', '', tmp)
        for m in self.tokenizer_obj.tokenize(tmp, self.mode):
            word = m.surface()
            pos = m.part_of_speech()[0]
            normalized_word = m.normalized_form()
            wakati_list.append(word)
            normalized_wakati_list.append(normalized_word)
            pos_list.append(pos)
        # Keep only nouns, verbs, and adjectives
        target_pos = ["名詞", "動詞", "形容詞"]
        # target_pos = ["名詞", "形容詞"]
        token_list = [t for t, p in zip(wakati_list, pos_list) if p in target_pos]
        # Lowercase any remaining alphabetic tokens
        token_list = [t.lower() for t in token_list]
        # token_list = [t for t in token_list if not self.kana_re.match(t)]  # drop hiragana-only tokens
        # Remove stopwords
        token_list = [t for t in token_list if t not in self.stopwords]
        return token_list
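# Example (illustrative): SudachiTokenizer(split_mode="C").get_wakati("今日は気分が落ち込んでいます")
# yields the surface forms of the nouns, verbs, and adjectives in the sentence, lowercased and
# with stopwords removed, ready to be passed to the topic model's make_doc() below.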
    

def make_trained_model():
    trained_models = []
    for k in range(1, k_folds + 1):
        # model_path = model_base_path / f"trained_model{k}.pt"
        # trained_model = copy.deepcopy(bert_model)
        # trained_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
        # trained_models.append(trained_model)
        model_name = MODEL_BASE + str(k)
        trained_model = BertClassifier.from_pretrained(model_name, token=HF_AUTH_TOKEN).to(device)
        print(f"Got model {model_name}")
        trained_models.append(trained_model)
    return trained_models


@st.cache_resource
def init_models():
    # bert_model = BertClassifier(cls_num=1)  # single output node
    # bert_model.eval()
    # bert_model.to(device)

    tokenizer_sudachi = SudachiTokenizer(split_mode="C")
    # Subword tokenizer for BERT (referred to as tokenizer_c2 here)
    tokenizer_c2 = T.BertJapaneseTokenizer.from_pretrained(bert_model_name)
    # trained_models = make_trained_model(bert_model)
    trained_models = make_trained_model()
    return tokenizer_sudachi, tokenizer_c2, trained_models


tokenizer_sudachi, tokenizer_c2, trained_models = init_models()


# Compute the attention maps and logits for a list of sentences
def f_a(sentences: list[str], tokenizer_c2, model, device):
    encoded = tokenizer_c2.batch_encode_plus(
        sentences,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
    )

    input_ids = torch.tensor(encoded["input_ids"]).to(device)
    attention_mask = torch.tensor(encoded["attention_mask"]).to(device)

    with torch.no_grad():
        outputs, _, attentions = model(input_ids, attention_mask)
    # Return the attention weights of the last layer along with the logits
    return input_ids.detach().cpu(), attentions[-1].detach().cpu(), outputs.detach().cpu()


def get_word_attn(input_ids, attention_weight) -> Generator[tuple[str, float], None, None]:
    # Zero tensor with one slot per token position in the sequence
    seq_len = attention_weight.size()[2]
    all_attens = torch.zeros(seq_len)

    # Sum the attention weights of all 12 heads.
    # The first 0 index assumes input_ids contains a single sentence;
    # the second 0 selects the attention paid by the [CLS] token.
    for i in range(12):
        all_attens += attention_weight[0, i, 0, :]

    for word, attn in zip(input_ids.flatten(), all_attens):
        if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[CLS]":
            continue
        if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[SEP]":
            break
        converted_word = tokenizer_c2.convert_ids_to_tokens([word.numpy().tolist()])[0]
        yield converted_word, attn
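# Illustrative only: one way the (token, attention) pairs from get_word_attn could be rendered
# as highlighted text in Streamlit. `highlight_text` is a hypothetical helper, not part of the
# original app; it min-max scales the attention weights and emits inline-styled HTML spans.
def highlight_text(input_ids, attention_weight) -> str:
    pairs = list(get_word_attn(input_ids, attention_weight))
    if not pairs:
        return ""
    attns = torch.tensor([float(a) for _, a in pairs])
    # Min-max scale so the most attended token gets full opacity
    scaled = (attns - attns.min()) / (attns.max() - attns.min() + 1e-8)
    spans = [
        f'<span style="background-color: rgba(255, 0, 0, {w:.2f})">{tok.replace("##", "")}</span>'
        for (tok, _), w in zip(pairs, scaled.tolist())
    ]
    return "".join(spans)
# Usage (hypothetical): st.markdown(highlight_text(ids, attn), unsafe_allow_html=True)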


def classify_ma(sentence: str) -> tuple[int, torch.Tensor, torch.Tensor]:
    # Same normalization and cleanup as SudachiTokenizer.get_wakati
    normalized_sentence = normalize("NFKC", sentence)
    tmp = re.sub(r'[0-9]', '', normalized_sentence)
    tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
    tmp = re.sub(r'[a-zA-Z]', '', tmp)
    tmp = re.sub(r'[❓]', '', tmp)  # remove emoji
    
    attention_list, output_list = [], []
    for trained_model in trained_models:
        input_ids, attention, output = f_a([tmp], tokenizer_c2, trained_model, device)
        attention_list.append(attention)
        output_list.append(output)

    # Majority vote over the 10 fold models' predictions
    outputs = np.concatenate(output_list)
    prob_column = torch.sigmoid(torch.tensor(outputs))
    pred_column = torch.ge(prob_column, 0.5).float()
    ensemble_pred, count = stats.mode(pred_column)

    # Average the 10 attention maps
    attentions = torch.concat(attention_list)
    mean_attention = torch.mean(attentions, dim=0).unsqueeze(dim=0)
    return ensemble_pred.item(), input_ids, mean_attention
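# Worked example (illustrative): with k_folds = 10, sigmoid probabilities such as
#   [0.81, 0.64, 0.32, 0.91, 0.55, 0.48, 0.73, 0.60, 0.29, 0.77]
# are thresholded at 0.5 to give the votes [1, 1, 0, 1, 1, 0, 1, 1, 0, 1], and
# scipy.stats.mode returns 1.0 as the ensemble label (7 of the 10 folds vote positive).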


# Load the trained topic model and infer the topic distribution of a new text
def infer_topic(new_text: str) -> tuple[np.ndarray, float]:
    model_trained = tp.CTModel.load(str(topic_model_trained))
    new_word_list = tokenizer_sudachi.get_wakati(new_text)
    new_doc = model_trained.make_doc(new_word_list)
    topic_dist, ll = model_trained.infer(new_doc)
    return topic_dist, ll
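# Illustrative only: a minimal sketch of how the functions above could be wired into the
# Streamlit UI. `run_demo` and the widget labels are hypothetical additions, not part of the
# original app, and the function is defined here without being invoked.
def run_demo() -> None:
    text = st.text_area("Input text")
    if st.button("Analyze") and text:
        label, input_ids, mean_attention = classify_ma(text)
        topic_dist, ll = infer_topic(text)
        st.write(f"Predicted class: {int(label)}")
        st.write(f"Most probable topic: {int(np.argmax(topic_dist))} (log-likelihood {ll:.2f})")
        st.markdown(highlight_text(input_ids, mean_attention), unsafe_allow_html=True)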