import streamlit as st from transformers import AutoTokenizer, AutoModelForSequenceClassification import random import torch from sklearn.preprocessing import StandardScaler from sklearn.cluster import KMeans from matplotlib.font_manager import FontProperties import pandas as pd import seaborn as sns from collections import Counter import matplotlib.pyplot as plt import numpy as np # モデルとトークナイザーのロード import time from transformers import AutoModel from matplotlib.offsetbox import OffsetImage, AnnotationBbox import matplotlib.pyplot as plt # モデルとトークナイザーのパス model_path = 'use14/bert-base-japanese-v3/2024-0208-0323/model' tokenizer_path = 'use14/bert-base-japanese-v3/2024-0208-0323/tokenizer' # トークナイザーとモデルのロード tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=14) font_path = 'NotoSansCJKjp-Bold.otf' # 実際のフォントのパスに置き換えてください font_prop = FontProperties(fname=font_path) # Streamlitアプリのタイトル st.title("セリフチェッカー") # セッション状態にキーが存在しない場合は、初期値を設定 if 'button_clicked' not in st.session_state: st.session_state.button_clicked = False def on_button_click(): # ボタンがクリックされた時の処理 st.session_state.button_clicked = True # 吹き出し風表示用のカスタムCSS custom_css = """ """ # CSSを使ってプログレスバーの色を変更 st.markdown(""" """, unsafe_allow_html=True) # カテゴリリスト category_list = [ '5_紫上鏡一[教師,大人,タメ口]', '20_見嶋千里[怖め,大人,粗雑な言葉,特徴的な笑い]', '29_白鳥王子[キザ,大人,調子いい]', '30_白城院素子[お嬢様,少女]', '32_水陰那月[ネガティブ,少年,敬語]', '50_緋崎平一郎[元気,少年,やんちゃ]', '76_百知瑠璃[元気,少女,です!,写真]', '91_桜結衣[元気,少女,タメ口,アイドル]', '101_御伽美夜子[落ち着いている,大人,女性口調]', '121_荊棘従道[執事,大人,敬語(丁寧語)]', '133_司馬萌香[姉御,少女,粗雑(ヤンキー)]', '134_菜野花[落ち着いてる,少女,敬語]', '139_黒冬和馬[少しそっけない,少年,タメ口]', '142_四涼礼子[ダウナー,少女,語尾~,めんどくさがり]' ] # カテゴリ選択用のセレクトボックス selected_category = st.selectbox("1.目標キャラクターを選択", category_list) # 選択されたカテゴリに対応する画像ファイル名の決定 # カテゴリリストのインデックスを取得し、それに1を加えることで1から始まる画像ファイル番号を作成 image_file_number = category_list.index(selected_category) + 1 image_path = f"img/{image_file_number}.png" image_width=300 judge_text = st.text_input("2.セリフを入力 //例: 貴方達も迷ったんですか?, よし、変身完了だ。, 勿論幽霊は抜きにして…でしょうね。,[ですわ。,だろ!,ですね。]") st.button("🔍 チェックする", on_click=on_button_click) st.divider() # 画面を2つの列に分割 col1, col2 = st.columns([1, 4]) # 左側の列に画像を表示 with col1: st.image(image_path, caption=selected_category, width=120) # 右側の列にテキストボックスを配置 with col2: if judge_text: # ユーザーが何か入力した場合のみ表示 st.markdown(custom_css, unsafe_allow_html=True) # カスタムCSSの適用 st.markdown(f'
{judge_text}
', unsafe_allow_html=True) # 処理ステップ数に応じてプログレスバーを更新する関数 def update_progress(step, total_steps): progress = int((step / total_steps) * 100) progress_bar.progress(progress) total_steps = 5 # 処理を行う総ステップ数 if st.session_state.button_clicked: if judge_text: # プログレスバーの初期化 progress_bar = st.progress(0) # トークナイズとテンソル化 words = tokenizer.tokenize(judge_text) word_ids = tokenizer.convert_tokens_to_ids(words) word_tensor = torch.tensor([word_ids[:512]]) # 最大長を512に制限 update_progress(1, total_steps) # デバイスの自動選択 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") word_tensor = word_tensor.to(device) model = model.to(device) update_progress(2, total_steps) # 推論 with torch.no_grad(): y = model(word_tensor) update_progress(3, total_steps) # 最も近いカテゴリの決定 pred = y.logits.argmax(-1) st.write(f"最もらしい: {category_list[pred.item()]}") update_progress(4, total_steps) # 各クラスの確率計算 probabilities = torch.softmax(y.logits, dim=-1) top_prob, top_cat_indices = probabilities.topk(len(category_list)) update_progress(5, total_steps) # 確率とカテゴリ名の準備 top_probabilities = top_prob.cpu().numpy()[0] top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]] # 棒グラフの作成 # 棒グラフの作成 plt.figure(figsize=(10, 6)) bars = plt.bar(range(len(top_categories)), top_probabilities, color='skyblue') # 画像の読み込みと配置 for i, (bar, category) in enumerate(zip(bars, top_categories)): img_path = f'img/{i+1}.png' # ファイル名はcategory_listの順番+1の番号.png image = plt.imread(img_path) imagebox = OffsetImage(image, zoom=0.5) # zoomで画像のサイズを調整 ab = AnnotationBbox(imagebox, (bar.get_x() + bar.get_width() / 2, bar.get_height()), frameon=False, box_alignment=(0.5, -0.2)) plt.gca().add_artist(ab) plt.xlabel('カテゴリ', fontproperties=font_prop) plt.ylabel('確率', fontproperties=font_prop) plt.ylim(0, 0.5) # y軸の範囲設定 plt.xticks(range(len(top_categories)), top_categories, rotation=45, ha="right", fontproperties=font_prop) plt.title('カテゴリ別確率', fontproperties=font_prop) plt.show() progress_bar.progress(100) time.sleep(1) # 1秒待機 progress_bar.empty() # プログレスバーを削除 else: st.write("入力してください。")