import os
import re
from pathlib import Path
from typing import Generator
from unicodedata import normalize

import numpy as np
import streamlit as st
import tomotopy as tp  # type: ignore
import torch
import torch.nn as nn
import transformers as T  # type: ignore
from huggingface_hub import PyTorchModelHubMixin  # type: ignore
from scipy import stats  # type: ignore
from sudachipy import dictionary, tokenizer  # type: ignore
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
MODELS_PATH = Path(__file__).parent / "saved_model"
# model_base_path = MODELS_PATH / "two_class"
MODEL_BASE = "awarefy/awarefy-two_class-trained-"
topic_model_trained = MODELS_PATH / "topic" / "trained_model.bin"
japanese_selection_path = MODELS_PATH / "stop_words" / "Japanese_selection.txt"
# GPU selection: use device 0 when CUDA is available, otherwise fall back to CPU (-1)
if torch.cuda.is_available():
    gpu = 0
    # gpu = -1  # For debugging
else:
    gpu = -1

# cls_num = 3
max_length = 512
k_folds = 10
bert_model_name = "cl-tohoku/bert-base-japanese-v3"
device = torch.device(f"cuda:{gpu}" if gpu >= 0 else "cpu")
# BERT classifier: pretrained Japanese BERT encoder plus a single linear head
class BertClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, cls_num: int):
        super().__init__()
        self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
        self.fc = nn.Linear(768, cls_num, bias=True)
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.normal_(self.fc.bias, 0)

    def forward(self, input_ids, masks):
        result = self.bert(input_ids, masks)
        vec = result[0]          # last hidden states
        _ = result[1]            # pooler output (unused)
        attentions = result[2]   # attention weights from every layer
        vec = vec[:, 0, :]       # take the [CLS] token representation
        vec = vec.view(-1, 768)
        output = self.fc(vec)
        return output, _, attentions
# Load the Japanese stopword list used to filter tokens
def load_stopwords() -> set[str]:
    with open(japanese_selection_path, "r", encoding="utf-8") as f:
        stopwords = {w.strip() for w in f if w.strip()}
    return stopwords
class SudachiTokenizer:
    def __init__(self, split_mode="C"):
        self.tokenizer_obj = dictionary.Dictionary(dict_type="full").create()
        self.stopwords = load_stopwords()
        if split_mode == "A":
            self.mode = tokenizer.Tokenizer.SplitMode.A
        elif split_mode == "B":
            self.mode = tokenizer.Tokenizer.SplitMode.B
        else:
            self.mode = tokenizer.Tokenizer.SplitMode.C
        # Regex matching strings made up of hiragana only
        self.kana_re = re.compile("^[ぁ-ゖ]+$")
    def get_wakati(self, text: str) -> list[str]:
        wakati_list = []
        normalized_wakati_list = []
        pos_list = []
        normalized_text = normalize("NFKC", text)
        # Strip digits, punctuation/symbols, Latin letters and emoji before tokenizing
        tmp = re.sub(r'[0-9]', '', normalized_text)
        tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
        tmp = re.sub(r'[a-zA-Z]', '', tmp)
        tmp = re.sub(r'[❓]', '', tmp)
        for m in self.tokenizer_obj.tokenize(tmp, self.mode):
            word = m.surface()
            pos = m.part_of_speech()[0]
            normalized_word = m.normalized_form()
            wakati_list.append(word)
            normalized_wakati_list.append(normalized_word)
            pos_list.append(pos)
        # Keep only nouns, verbs and adjectives
        target_pos = ["名詞", "動詞", "形容詞"]
        # target_pos = ["名詞", "形容詞"]
        token_list = [t for t, p in zip(wakati_list, pos_list) if p in target_pos]
        # Lower-case any remaining alphabetic characters
        token_list = [t.lower() for t in token_list]
        # Optionally drop hiragana-only tokens
        # token_list = [t for t in token_list if not self.kana_re.match(t)]
        # Remove stopwords
        token_list = [t for t in token_list if t not in self.stopwords]
        return token_list
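# Illustrative usage sketch (the sentence below is a made-up example, not taken from the app):
#   tok = SudachiTokenizer(split_mode="C")
#   tok.get_wakati("今日は会議で緊張した")
#   # -> surface forms of nouns/verbs/adjectives with stopwords removed, e.g. ["会議", "緊張"]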
def make_trained_model():
    # Load the k fold-ensemble classifiers from the Hugging Face Hub
    trained_models = []
    for k in range(1, k_folds + 1):
        # model_path = model_base_path / f"trained_model{k}.pt"
        # trained_model = copy.deepcopy(bert_model)
        # trained_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
        # trained_models.append(trained_model)
        model_name = MODEL_BASE + str(k)
        trained_model = BertClassifier.from_pretrained(model_name, token=HF_AUTH_TOKEN).to(device)
        print(f"Got model {model_name}")
        trained_models.append(trained_model)
    return trained_models
def init_models():
    # bert_model = BertClassifier(cls_num=1)  # single output node
    # bert_model.eval()
    # bert_model.to(device)
    tokenizer_sudachi = SudachiTokenizer(split_mode="C")
    # BERT tokenizer (referred to as tokenizer_c2 throughout)
    tokenizer_c2 = T.BertJapaneseTokenizer.from_pretrained(bert_model_name)
    # trained_models = make_trained_model(bert_model)
    trained_models = make_trained_model()
    return tokenizer_sudachi, tokenizer_c2, trained_models


tokenizer_sudachi, tokenizer_c2, trained_models = init_models()
# Compute attention maps and logits for a batch of sentences
def f_a(sentences: list[str], tokenizer_c2, model, device):
    encoded = tokenizer_c2.batch_encode_plus(
        sentences,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
    )
    input_ids = torch.tensor(encoded["input_ids"]).to(device)
    attention_mask = torch.tensor(encoded["attention_mask"]).to(device)
    with torch.no_grad():
        outputs, _, attentions = model(input_ids, attention_mask)
    # Return the last layer's attention weights together with the ids and logits
    return input_ids.detach().cpu(), attentions[-1].detach().cpu(), outputs.detach().cpu()
def get_word_attn(input_ids, attention_weight) -> Generator[tuple[str, float], None, None]:
    # Accumulator the length of the (padded) sequence
    seq_len = attention_weight.size()[2]
    all_attens = torch.zeros(seq_len)
    # Sum the attention of all 12 heads.
    # The first index 0 assumes input_ids holds a single sentence;
    # the second index 0 selects the attention paid by the [CLS] token.
    for i in range(12):
        all_attens += attention_weight[0, i, 0, :]
    for word, attn in zip(input_ids.flatten(), all_attens):
        token = tokenizer_c2.convert_ids_to_tokens(word.item())
        if token == "[CLS]":
            continue
        if token == "[SEP]":
            break
        yield token, attn
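# Illustrative usage sketch (assumes input_ids and mean_attention come from classify_ma() below;
# the actual rendering is done elsewhere in the Streamlit UI):
#   for token, weight in get_word_attn(input_ids, mean_attention):
#       ...  # e.g. highlight `token` in proportion to `weight`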
def classify_ma(sentence: str) -> tuple[int, torch.Tensor, torch.Tensor]:
    normalized_sentence = normalize("NFKC", sentence)
    # Strip digits, punctuation/symbols, Latin letters and emoji (same cleaning as get_wakati)
    tmp = re.sub(r'[0-9]', '', normalized_sentence)
    tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
    tmp = re.sub(r'[a-zA-Z]', '', tmp)
    tmp = re.sub(r'[❓]', '', tmp)
    attention_list, output_list = [], []
    for trained_model in trained_models:
        input_ids, attention, output = f_a([tmp], tokenizer_c2, trained_model, device)
        attention_list.append(attention)
        output_list.append(output)
    # Majority vote over the k fold models' predictions
    outputs = np.concatenate(output_list)
    prob_column = torch.sigmoid(torch.tensor(outputs))
    pred_column = torch.ge(prob_column, 0.5).float()
    ensemble_pred, count = stats.mode(pred_column)
    # Average the k attention maps
    attentions = torch.concat(attention_list)
    mean_attention = torch.mean(attentions, dim=0).unsqueeze(dim=0)
    return ensemble_pred.item(), input_ids, mean_attention
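# Illustrative usage sketch (hypothetical input text; the real call sites are in the UI code):
#   pred, input_ids, mean_attention = classify_ma("最近よく眠れず、気分が沈みがちです。")
#   # pred is the majority-vote label (0.0 or 1.0); input_ids and mean_attention
#   # can be passed to get_word_attn() to highlight influential tokens.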
# Load the trained topic model and run inference on new text
def infer_topic(new_text: str) -> tuple[np.ndarray, float]:
    model_trained = tp.CTModel.load(str(topic_model_trained))
    new_word_list = tokenizer_sudachi.get_wakati(new_text)
    new_doc = model_trained.make_doc(new_word_list)
    topic_dist, ll = model_trained.infer(new_doc)
    return topic_dist, ll
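# Illustrative usage sketch (hypothetical text; topic_dist is a probability vector over the
# trained CTM's topics and ll is the inference log-likelihood):
#   topic_dist, ll = infer_topic("仕事のストレスで眠れない日が続いています。")
#   top_topic = int(np.argmax(topic_dist))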