# amp/inference.py
import os
import re
from pathlib import Path
from typing import Generator
from unicodedata import normalize
import numpy as np
import streamlit as st
import tomotopy as tp # type: ignore
import torch
import torch.nn as nn
import transformers as T # type: ignore
from huggingface_hub import PyTorchModelHubMixin # type: ignore
from scipy import stats # type: ignore
from sudachipy import dictionary, tokenizer # type: ignore
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
MODELS_PATH = Path(__file__).parent / "saved_model"
# model_base_path = MODELS_PATH / "two_class"
MODEL_BASE = "awarefy/awarefy-two_class-trained-"
topic_model_trained = MODELS_PATH / "topic" / "trained_model.bin"
japanese_selection_path = MODELS_PATH / "stop_words" / "Japanese_selection.txt"

# GPU selection: use GPU 0 if available, otherwise fall back to CPU (-1)
if torch.cuda.is_available():
    gpu = 0
    # gpu = -1  # For debugging
else:
    gpu = -1  # use -1 to run on CPU when no GPU is available

# cls_num = 3
max_length = 512
k_folds = 10
bert_model_name = "cl-tohoku/bert-base-japanese-v3"
device = torch.device(f"cuda:{gpu}" if gpu >= 0 else "cpu")

# BERT classifier definition
class BertClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, cls_num: int):
        super().__init__()
        self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
        self.fc = nn.Linear(768, cls_num, bias=True)
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.normal_(self.fc.bias, 0)

    def forward(self, input_ids, masks):
        result = self.bert(input_ids, masks)
        vec = result[0]          # last hidden states
        _ = result[1]            # pooled output (unused)
        attentions = result[2]   # attention weights from all layers
        vec = vec[:, 0, :]       # take the [CLS] token representation
        vec = vec.view(-1, 768)
        output = self.fc(vec)
        return output, _, attentions
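# Note: forward() returns raw logits (sigmoid is applied downstream in
# classify_ma) together with the per-layer attention weights; only the last
# layer (attentions[-1]) is used by f_a below.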

# Load the Japanese stopword list
def load_stopwords() -> set[str]:
    with open(japanese_selection_path, "r", encoding="utf-8") as f:
        # stopwords = [w.strip() for w in f]
        # stopwords = set(stopwords)
        stopwords = {w.strip() for w in f if w.strip()}
    return stopwords

class SudachiTokenizer:
    def __init__(self, split_mode="C"):
        self.tokenizer_obj = dictionary.Dictionary(dict_type="full").create()
        # Japanese stopwords
        self.stopwords = load_stopwords()
        if split_mode == "A":
            self.mode = tokenizer.Tokenizer.SplitMode.A
        elif split_mode == "B":
            self.mode = tokenizer.Tokenizer.SplitMode.B
        else:
            self.mode = tokenizer.Tokenizer.SplitMode.C
        # Regular expression matching hiragana-only strings
        self.kana_re = re.compile("^[ぁ-ゖ]+$")

    def get_wakati(self, text: str) -> list[str]:
        wakati_list = []
        normalized_wakati_list = []
        pos_list = []
        normalized_text = normalize("NFKC", text)
        # Remove digits (NFKC normalization above already folds full-width digits to ASCII)
        tmp = re.sub(r'[0-9]', '', normalized_text)
        # Remove punctuation and symbols
        tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
        # Remove alphabetic characters
        tmp = re.sub(r'[a-zA-Z]', '', tmp)
        # Remove emoji
        tmp = re.sub(r'[❓]', "", tmp)
        for m in self.tokenizer_obj.tokenize(tmp, self.mode):
            word = m.surface()
            pos = m.part_of_speech()[0]
            normalized_word = m.normalized_form()
            wakati_list.append(word)
            normalized_wakati_list.append(normalized_word)
            pos_list.append(pos)
        # Keep only nouns, verbs, and adjectives
        target_pos = ["名詞", "動詞", "形容詞"]
        # target_pos = ["名詞", "形容詞"]
        token_list = [t for t, p in zip(wakati_list, pos_list) if p in target_pos]
        # Lowercase any remaining alphabetic characters
        token_list = [t.lower() for t in token_list]
        # Exclude hiragana-only words
        # token_list = [t for t in token_list if not self.kana_re.match(t)]
        # Remove stopwords
        token_list = [t for t in token_list if t not in self.stopwords]
        return token_list
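
# Illustrative usage (sketch): get_wakati produces the bag of content words fed
# to the topic model, e.g.
#   words = tokenizer_sudachi.get_wakati("最近よく眠れない")
#   # -> list of noun/verb/adjective surface forms, lowercased, stopwords removed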

def make_trained_model():
    trained_models = []
    for k in range(1, k_folds + 1):
        # model_path = model_base_path / f"trained_model{k}.pt"
        # trained_model = copy.deepcopy(bert_model)
        # trained_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
        # trained_models.append(trained_model)
        model_name = MODEL_BASE + str(k)
        trained_model = BertClassifier.from_pretrained(model_name, token=HF_AUTH_TOKEN).to(device)
        trained_model.eval()  # make sure the model runs in inference mode
        print(f"Got model {model_name}")
        trained_models.append(trained_model)
    return trained_models

@st.cache_resource
def init_models():
    # bert_model = BertClassifier(cls_num=1)  # single output node
    # bert_model.eval()
    # bert_model.to(device)
    tokenizer_sudachi = SudachiTokenizer(split_mode="C")
    # Set up the BERT tokenizer (named tokenizer_c2 here)
    tokenizer_c2 = T.BertJapaneseTokenizer.from_pretrained(bert_model_name)
    # trained_models = make_trained_model(bert_model)
    trained_models = make_trained_model()
    return tokenizer_sudachi, tokenizer_c2, trained_models


tokenizer_sudachi, tokenizer_c2, trained_models = init_models()

# Compute attention maps (and logits) for a batch of sentences
def f_a(sentences: list[str], tokenizer_c2, model, device):
    encoded = tokenizer_c2.batch_encode_plus(
        sentences,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
    )
    input_ids = torch.tensor(encoded["input_ids"]).to(device)
    attention_mask = torch.tensor(encoded["attention_mask"]).to(device)
    with torch.no_grad():
        outputs, _, attentions = model(input_ids, attention_mask)
    # return input_ids.detach().cpu(), attentions[-1].detach().cpu()
    return input_ids.detach().cpu(), attentions[-1].detach().cpu(), outputs.detach().cpu()
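# attentions[-1] is the last layer's attention tensor with shape
# [batch_size, num_heads (=12 for bert-base), seq_len, seq_len];
# outputs are the raw logits before sigmoid.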

def get_word_attn(input_ids, attention_weight) -> Generator[tuple[str, float], None, None]:
    # Declare a zero tensor as long as the sequence
    seq_len = attention_weight.size()[2]
    all_attens = torch.zeros(seq_len)
    # Sum the attention of the 12 heads.
    # The first 0 index: input_ids is assumed to hold a single sentence;
    # the second 0 index: take the attention from the [CLS] token.
    for i in range(12):
        all_attens += attention_weight[0, i, 0, :]
    for word, attn in zip(input_ids.flatten(), all_attens):
        if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[CLS]":
            continue
        if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[SEP]":
            break
        converted_word = tokenizer_c2.convert_ids_to_tokens([word.numpy().tolist()])[0]
        yield converted_word, attn
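
# Illustrative usage (sketch, assuming a single input sentence):
#   ids, attn_map, _ = f_a(["..."], tokenizer_c2, trained_models[0], device)
#   for token, weight in get_word_attn(ids, attn_map):
#       ...  # e.g. shade each token by its [CLS] attention weight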

def classify_ma(sentence: str) -> tuple[int, torch.Tensor, torch.Tensor]:
    normalized_sentence = normalize("NFKC", sentence)
    # Remove digits (NFKC normalization above already folds full-width digits to ASCII)
    tmp = re.sub(r'[0-9]', '', normalized_sentence)
    # Remove punctuation and symbols
    tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
    # Remove alphabetic characters
    tmp = re.sub(r'[a-zA-Z]', '', tmp)
    # Remove emoji
    tmp = re.sub(r'[❓]', "", tmp)
    attention_list, output_list = [], []
    for trained_model in trained_models:
        input_ids, attention, output = f_a([tmp], tokenizer_c2, trained_model, device)
        attention_list.append(attention)
        output_list.append(output)
    # Majority vote over the 10 fold models' predictions
    outputs = np.concatenate(output_list)
    prob_column = torch.sigmoid(torch.tensor(outputs))
    pred_column = torch.ge(prob_column, 0.5).float()
    ensemble_pred, count = stats.mode(pred_column)
    # Average of the 10 fold models' attention maps
    attentions = torch.concat(attention_list)
    mean_attention = torch.mean(attentions, dim=0).unsqueeze(dim=0)
    return ensemble_pred.item(), input_ids, mean_attention

# Load the trained topic model and run inference on new text
def infer_topic(new_text: str) -> tuple[np.ndarray, float]:
    model_trained = tp.CTModel.load(str(topic_model_trained))
    new_word_list = tokenizer_sudachi.get_wakati(new_text)
    new_doc = model_trained.make_doc(new_word_list)
    topic_dist, ll = model_trained.infer(new_doc)
    return topic_dist, ll
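
# Minimal smoke-test sketch (illustrative, not used by the Streamlit app);
# assumes HF_AUTH_TOKEN grants access to the awarefy model repos and that
# saved_model/topic/trained_model.bin is present. The sample text is made up.
if __name__ == "__main__":
    sample = "最近は仕事のことを考えると眠れない日が続いています。"
    pred, ids, attn = classify_ma(sample)
    print("ensemble prediction:", pred)
    topic_dist, ll = infer_topic(sample)
    print("topic distribution:", topic_dist, "log-likelihood:", ll)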