File size: 7,267 Bytes
def91b4
 
 
 
 
 
 
 
 
 
 
 
 
de90a66
def91b4
de90a66
 
def91b4
 
 
 
 
 
 
 
 
 
 
de90a66
 
 
 
def91b4
de90a66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def91b4
de90a66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def91b4
de90a66
 
 
 
 
def91b4
de90a66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def91b4
de90a66
 
 
 
def91b4
de90a66
 
def91b4
de90a66
 
 
def91b4
 
 
 
de90a66
def91b4
de90a66
def91b4
 
 
de90a66
def91b4
 
 
 
de90a66
def91b4
 
 
de90a66
 
def91b4
 
 
 
de90a66
def91b4
 
de90a66
def91b4
 
de90a66
 
def91b4
 
de90a66
 
 
 
 
 
 
 
 
 
def91b4
 
de90a66
 
def91b4
de90a66
 
 
 
 
 
 
def91b4
de90a66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import random
import torch
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from matplotlib.font_manager import FontProperties
import pandas as pd
import seaborn as sns
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
# モデルとトークナイザーのロード
import time
from transformers import AutoModel
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.pyplot as plt
# モデルとトークナイザーのパス
model_path = 'use14/bert-base-japanese-v3/2024-0208-0323/model'
tokenizer_path = 'use14/bert-base-japanese-v3/2024-0208-0323/tokenizer'

# トークナイザーとモデルのロード
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=14)

font_path = 'NotoSansCJKjp-Bold.otf'  # 実際のフォントのパスに置き換えてください
font_prop = FontProperties(fname=font_path)
# Streamlitアプリのタイトル
st.title("セリフチェッカー")
# セッション状態にキーが存在しない場合は、初期値を設定
if 'button_clicked' not in st.session_state:
    st.session_state.button_clicked = False

def on_button_click():
    # ボタンがクリックされた時の処理
    st.session_state.button_clicked = True
# 吹き出し風表示用のカスタムCSS

custom_css = """
<style>
.bubble {
    position: relative;
    background: #F0F0F0;
    border-radius: .4em;
    padding: 10px;
    max-width: 95%; /* 吹き出しの最大幅を90%に設定 */
    word-wrap: break-word; /* 長い単語でも折り返しを保証 */
}

.bubble::after {
    content: '';
    position: absolute;
    top: 10px;
    left: -10px;
    width: 0;
    height: 0;
    border: 10px solid transparent;
    border-right-color: #F0F0F0;
    border-left: 0;
    margin-top: 5px;
    margin-left: 0;
}
</style>
"""
# CSSを使ってプログレスバーの色を変更
st.markdown("""
    <style>
    /* プログレスバーの色を変更 */
    .stProgress > div > div > div > div {
        background-color: #008000;
    }
    </style>
    """, unsafe_allow_html=True)
# カテゴリリスト
category_list = [
    '5_紫上鏡一[教師,大人,タメ口]', '20_見嶋千里[怖め,大人,粗雑な言葉,特徴的な笑い]', '29_白鳥王子[キザ,大人,調子いい]',
    '30_白城院素子[お嬢様,少女]', '32_水陰那月[ネガティブ,少年,敬語]', '50_緋崎平一郎[元気,少年,やんちゃ]',
    '76_百知瑠璃[元気,少女,です!,写真]', '91_桜結衣[元気,少女,タメ口,アイドル]', '101_御伽美夜子[落ち着いている,大人,女性口調]',
    '121_荊棘従道[執事,大人,敬語(丁寧語)]', '133_司馬萌香[姉御,少女,粗雑(ヤンキー)]', '134_菜野花[落ち着いてる,少女,敬語]',
    '139_黒冬和馬[少しそっけない,少年,タメ口]', '142_四涼礼子[ダウナー,少女,語尾~,めんどくさがり]'
]
# カテゴリ選択用のセレクトボックス
selected_category = st.selectbox("1.目標キャラクターを選択", category_list)

# 選択されたカテゴリに対応する画像ファイル名の決定
# カテゴリリストのインデックスを取得し、それに1を加えることで1から始まる画像ファイル番号を作成
image_file_number = category_list.index(selected_category) + 1
image_path = f"img/{image_file_number}.png"
image_width=300
judge_text = st.text_input("2.セリフを入力  //例: 貴方達も迷ったんですか?, よし、変身完了だ。, 勿論幽霊は抜きにして…でしょうね。,[ですわ。,だろ!,ですね。]")

st.button("🔍 チェックする", on_click=on_button_click)
st.divider()
# 画面を2つの列に分割
col1, col2 = st.columns([1, 4])
# 左側の列に画像を表示
with col1:
    st.image(image_path, caption=selected_category, width=120)
# 右側の列にテキストボックスを配置
with col2:
    if judge_text:  # ユーザーが何か入力した場合のみ表示
        st.markdown(custom_css, unsafe_allow_html=True)  # カスタムCSSの適用
        st.markdown(f'<div class="bubble">{judge_text}</div>', unsafe_allow_html=True)

# 処理ステップ数に応じてプログレスバーを更新する関数
def update_progress(step, total_steps):
    progress = int((step / total_steps) * 100)
    progress_bar.progress(progress)

total_steps = 5  # 処理を行う総ステップ数
if st.session_state.button_clicked:
    if judge_text:
        # プログレスバーの初期化
        progress_bar = st.progress(0)
        
        # トークナイズとテンソル化
        words = tokenizer.tokenize(judge_text)
        word_ids = tokenizer.convert_tokens_to_ids(words)
        word_tensor = torch.tensor([word_ids[:512]])  # 最大長を512に制限
        update_progress(1, total_steps)

        # デバイスの自動選択
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        word_tensor = word_tensor.to(device)
        model = model.to(device)
        update_progress(2, total_steps)

        # 推論
        with torch.no_grad():
            y = model(word_tensor)
        update_progress(3, total_steps)

        # 最も近いカテゴリの決定
        pred = y.logits.argmax(-1)
        st.write(f"最もらしい: {category_list[pred.item()]}")
        update_progress(4, total_steps)

        # 各クラスの確率計算
        probabilities = torch.softmax(y.logits, dim=-1)
        top_prob, top_cat_indices = probabilities.topk(len(category_list))
        update_progress(5, total_steps)

        # 確率とカテゴリ名の準備
        top_probabilities = top_prob.cpu().numpy()[0]
        top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]]

        # 棒グラフの作成

        # 棒グラフの作成
        plt.figure(figsize=(10, 6))
        bars = plt.bar(range(len(top_categories)), top_probabilities, color='skyblue')

        # 画像の読み込みと配置
        for i, (bar, category) in enumerate(zip(bars, top_categories)):
            img_path = f'img/{i+1}.png'  # ファイル名はcategory_listの順番+1の番号.png
            image = plt.imread(img_path)
            imagebox = OffsetImage(image, zoom=0.5)  # zoomで画像のサイズを調整
            ab = AnnotationBbox(imagebox, (bar.get_x() + bar.get_width() / 2, bar.get_height()), frameon=False, box_alignment=(0.5, -0.2))
            plt.gca().add_artist(ab)

        plt.xlabel('カテゴリ', fontproperties=font_prop)
        plt.ylabel('確率', fontproperties=font_prop)
        plt.ylim(0, 0.5)  # y軸の範囲設定
        plt.xticks(range(len(top_categories)), top_categories, rotation=45, ha="right", fontproperties=font_prop)
        plt.title('カテゴリ別確率', fontproperties=font_prop)
        plt.show()


        progress_bar.progress(100)
        time.sleep(1)  # 1秒待機
        progress_bar.empty()  # プログレスバーを削除

    else:
        st.write("入力してください。")