import os
import re
from pathlib import Path
from typing import Generator
from unicodedata import normalize

import numpy as np
import streamlit as st
import tomotopy as tp  # type: ignore
import torch
import torch.nn as nn
import transformers as T  # type: ignore
from huggingface_hub import PyTorchModelHubMixin  # type: ignore
from scipy import stats  # type: ignore
from sudachipy import dictionary, tokenizer  # type: ignore

HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
MODELS_PATH = Path(__file__).parent / "saved_model"
MODEL_BASE = "awarefy/awarefy-two_class-trained-"
topic_model_trained = MODELS_PATH / "topic" / "trained_model.bin"
japanese_selection_path = MODELS_PATH / "stop_words" / "Japanese_selection.txt"

# GPU selection: use GPU 0 if available, otherwise -1 (run on CPU)
if torch.cuda.is_available():
    gpu = 0
else:
    gpu = -1

max_length = 512
k_folds = 10
bert_model_name = "cl-tohoku/bert-base-japanese-v3"
device = torch.device(f"cuda:{gpu}" if gpu >= 0 else "cpu")


# BERT classifier: [CLS] vector fed into a single linear layer
class BertClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, cls_num: int):
        super().__init__()
        self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
        self.fc = nn.Linear(768, cls_num, bias=True)
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.normal_(self.fc.bias, 0)

    def forward(self, input_ids, masks):
        result = self.bert(input_ids, masks)
        vec = result[0]          # last hidden states
        _ = result[1]            # pooler output (unused)
        attentions = result[2]   # attention weights of every layer
        vec = vec[:, 0, :]       # [CLS] token vector
        vec = vec.view(-1, 768)
        output = self.fc(vec)
        return output, _, attentions


# Load the Japanese stopword list
def load_stopwords() -> set[str]:
    with open(japanese_selection_path, "r", encoding="utf-8") as f:
        stopwords = {w.strip() for w in f if w.strip()}
    return stopwords


class SudachiTokenizer:
    def __init__(self, split_mode="C"):
        self.tokenizer_obj = dictionary.Dictionary(dict_type="full").create()
        self.stopwords = load_stopwords()
        if split_mode == "A":
            self.mode = tokenizer.Tokenizer.SplitMode.A
        elif split_mode == "B":
            self.mode = tokenizer.Tokenizer.SplitMode.B
        else:
            self.mode = tokenizer.Tokenizer.SplitMode.C
        # Regex matching strings made up of hiragana only
        self.kana_re = re.compile("^[ぁ-ゖ]+$")

    def get_wakati(self, text: str) -> list[str]:
        wakati_list = []
        normalized_wakati_list = []
        pos_list = []
        normalized_text = normalize("NFKC", text)
        tmp = re.sub(r'[0-9]', '', normalized_text)
        tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
        tmp = re.sub(r'[a-zA-Z]', '', tmp)
        # Remove emoji
        tmp = re.sub(r'[❓]', '', tmp)
        for m in self.tokenizer_obj.tokenize(tmp, self.mode):
            word = m.surface()
            pos = m.part_of_speech()[0]
            normalized_word = m.normalized_form()
            wakati_list.append(word)
            normalized_wakati_list.append(normalized_word)
            pos_list.append(pos)
        # Keep only nouns, verbs, and adjectives
        target_pos = ["名詞", "動詞", "形容詞"]
        token_list = [t for t, p in zip(wakati_list, pos_list) if p in target_pos]
        # Lowercase alphabetic tokens
        token_list = [t.lower() for t in token_list]
        # (Optional) drop tokens that consist of hiragana only
        # token_list = [t for t in token_list if not self.kana_re.match(t)]
        # Remove stopwords
        token_list = [t for t in token_list if t not in self.stopwords]
        return token_list
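# Usage sketch for the tokenizer above (illustrative only, kept as comments so it is
# not executed at import time). Assumes the Sudachi "full" dictionary and the stopword
# file are installed; the sample sentence is hypothetical.
#
#   tok = SudachiTokenizer(split_mode="C")
#   tok.get_wakati("今日は友達と映画を見ました")
#   # -> surface forms of nouns/verbs/adjectives only, lowercased and stopword-filtered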
def make_trained_model():
    # Download the k fold-specific classifiers from the Hugging Face Hub
    trained_models = []
    for k in range(k_folds):
        k = k + 1
        model_name = MODEL_BASE + str(k)
        trained_model = BertClassifier.from_pretrained(model_name, token=HF_AUTH_TOKEN).to(device)
        print(f"Got model {model_name}")
        trained_models.append(trained_model)
    return trained_models


@st.cache_resource
def init_models():
    tokenizer_sudachi = SudachiTokenizer(split_mode="C")
    # BERT tokenizer (named tokenizer_c2 here)
    tokenizer_c2 = T.BertJapaneseTokenizer.from_pretrained(bert_model_name)
    trained_models = make_trained_model()
    return tokenizer_sudachi, tokenizer_c2, trained_models


tokenizer_sudachi, tokenizer_c2, trained_models = init_models()


# Compute the last-layer attention map and raw outputs for a batch of sentences
def f_a(sentences: list[str], tokenizer_c2, model, device):
    encoded = tokenizer_c2.batch_encode_plus(
        sentences,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
    )
    input_ids = torch.tensor(encoded["input_ids"]).to(device)
    attention_mask = torch.tensor(encoded["attention_mask"]).to(device)
    with torch.no_grad():
        outputs, _, attentions = model(input_ids, attention_mask)
    return input_ids.detach().cpu(), attentions[-1].detach().cpu(), outputs.detach().cpu()


def get_word_attn(input_ids, attention_weight) -> Generator[tuple[str, float], None, None]:
    # Zero tensor with one slot per token position
    seq_len = attention_weight.size()[2]
    all_attens = torch.zeros(seq_len)
    # Sum the attention of all 12 heads.
    # The first 0 index assumes input_ids holds a single sentence;
    # the second 0 selects the attention paid by the [CLS] token.
    for i in range(12):
        all_attens += attention_weight[0, i, 0, :]
    for word, attn in zip(input_ids.flatten(), all_attens):
        if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[CLS]":
            continue
        if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[SEP]":
            break
        converted_word = tokenizer_c2.convert_ids_to_tokens([word.numpy().tolist()])[0]
        yield converted_word, attn


def classify_ma(sentence: str) -> tuple[int, torch.Tensor, torch.Tensor]:
    normalized_sentence = normalize("NFKC", sentence)
    tmp = re.sub(r'[0-9]', '', normalized_sentence)
    tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
    tmp = re.sub(r'[a-zA-Z]', '', tmp)
    # Remove emoji
    tmp = re.sub(r'[❓]', '', tmp)
    attention_list, output_list = [], []
    for trained_model in trained_models:
        input_ids, attention, output = f_a([tmp], tokenizer_c2, trained_model, device)
        attention_list.append(attention)
        output_list.append(output)
    # Majority vote over the k fold-specific predictions
    outputs = np.concatenate(output_list)
    prob_column = torch.sigmoid(torch.tensor(outputs))
    pred_column = torch.ge(prob_column, 0.5).float()
    ensemble_pred, count = stats.mode(pred_column)
    # Average the k attention maps
    attentions = torch.concat(attention_list)
    mean_attention = torch.mean(attentions, dim=0).unsqueeze(dim=0)
    return ensemble_pred.item(), input_ids, mean_attention


# Load the trained topic model and infer the topic distribution of new text
def infer_topic(new_text: str) -> tuple[np.ndarray, float]:
    model_trained = tp.CTModel.load(str(topic_model_trained))
    new_word_list = tokenizer_sudachi.get_wakati(new_text)
    new_doc = model_trained.make_doc(new_word_list)
    topic_dist, ll = model_trained.infer(new_doc)
    return topic_dist, ll
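
# Minimal usage sketch (an assumption, not the original app's UI): wiring the helpers
# above into a small Streamlit page. Widget labels and the number of displayed tokens
# are illustrative choices.
if __name__ == "__main__":
    st.title("Japanese text classification and topic inference (demo)")
    user_text = st.text_area("Input text (Japanese)")
    if st.button("Analyze") and user_text.strip():
        # Ensemble prediction plus the averaged attention map
        pred, input_ids, mean_attention = classify_ma(user_text)
        st.write(f"Ensemble prediction: {int(pred)}")
        # Tokens the [CLS] token attends to most strongly
        word_attns = list(get_word_attn(input_ids, mean_attention))
        word_attns.sort(key=lambda pair: float(pair[1]), reverse=True)
        st.write("Top attention tokens:", [w for w, _ in word_attns[:10]])
        # Topic distribution from the pre-trained correlated topic model
        topic_dist, ll = infer_topic(user_text)
        st.write(f"Topic distribution: {np.round(topic_dist, 3)} (log-likelihood {ll:.2f})")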