Spaces:
Sleeping
Sleeping
Upload 15 files
Browse files- app.py +113 -20
- img/1.png +0 -0
- img/10.png +0 -0
- img/11.png +0 -0
- img/12.png +0 -0
- img/13.png +0 -0
- img/14.png +0 -0
- img/2.png +0 -0
- img/3.png +0 -0
- img/4.png +0 -0
- img/5.png +0 -0
- img/6.png +0 -0
- img/7.png +0 -0
- img/8.png +0 -0
- img/9.png +0 -0
app.py
CHANGED
@@ -11,8 +11,10 @@ from collections import Counter
|
|
11 |
import matplotlib.pyplot as plt
|
12 |
import numpy as np
|
13 |
# モデルとトークナイザーのロード
|
|
|
14 |
from transformers import AutoModel
|
15 |
-
|
|
|
16 |
# モデルとトークナイザーのパス
|
17 |
model_path = 'use14/bert-base-japanese-v3/2024-0208-0323/model'
|
18 |
tokenizer_path = 'use14/bert-base-japanese-v3/2024-0208-0323/tokenizer'
|
@@ -24,58 +26,149 @@ model = AutoModelForSequenceClassification.from_pretrained(model_path, num_label
|
|
24 |
font_path = 'NotoSansCJKjp-Bold.otf' # 実際のフォントのパスに置き換えてください
|
25 |
font_prop = FontProperties(fname=font_path)
|
26 |
# Streamlitアプリのタイトル
|
27 |
-
st.title(
|
|
|
|
|
|
|
28 |
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
category_list = [
|
32 |
-
'5_紫上鏡一
|
33 |
-
'30_白城院素子
|
34 |
-
'76_百知瑠璃
|
35 |
-
'121_荊棘従道
|
36 |
-
'139_黒冬和馬
|
37 |
]
|
38 |
-
#
|
39 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
-
#
|
42 |
-
|
|
|
|
|
43 |
|
44 |
-
|
|
|
45 |
if judge_text:
|
|
|
|
|
|
|
46 |
# トークナイズとテンソル化
|
47 |
words = tokenizer.tokenize(judge_text)
|
48 |
word_ids = tokenizer.convert_tokens_to_ids(words)
|
49 |
word_tensor = torch.tensor([word_ids[:512]]) # 最大長を512に制限
|
|
|
50 |
|
51 |
-
# デバイスの自動選択
|
52 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
53 |
word_tensor = word_tensor.to(device)
|
54 |
model = model.to(device)
|
|
|
55 |
|
56 |
# 推論
|
57 |
with torch.no_grad():
|
58 |
y = model(word_tensor)
|
|
|
59 |
|
60 |
# 最も近いカテゴリの決定
|
61 |
pred = y.logits.argmax(-1)
|
62 |
-
st.write(f"
|
|
|
63 |
|
64 |
# 各クラスの確率計算
|
65 |
probabilities = torch.softmax(y.logits, dim=-1)
|
66 |
top_prob, top_cat_indices = probabilities.topk(len(category_list))
|
|
|
67 |
|
68 |
# 確率とカテゴリ名の準備
|
69 |
-
top_probabilities = top_prob.cpu().numpy()[0]
|
70 |
top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]]
|
71 |
|
|
|
|
|
72 |
# 棒グラフの作成
|
73 |
plt.figure(figsize=(10, 6))
|
74 |
-
plt.bar(top_categories, top_probabilities, color='skyblue')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
plt.xlabel('カテゴリ', fontproperties=font_prop)
|
76 |
plt.ylabel('確率', fontproperties=font_prop)
|
77 |
-
plt.
|
|
|
78 |
plt.title('カテゴリ別確率', fontproperties=font_prop)
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
else:
|
81 |
-
st.write("
|
|
|
11 |
import matplotlib.pyplot as plt
|
12 |
import numpy as np
|
13 |
# モデルとトークナイザーのロード
|
14 |
+
import time
|
15 |
from transformers import AutoModel
|
16 |
+
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
# モデルとトークナイザーのパス
|
19 |
model_path = 'use14/bert-base-japanese-v3/2024-0208-0323/model'
|
20 |
tokenizer_path = 'use14/bert-base-japanese-v3/2024-0208-0323/tokenizer'
|
|
|
26 |
font_path = 'NotoSansCJKjp-Bold.otf' # 実際のフォントのパスに置き換えてください
|
27 |
font_prop = FontProperties(fname=font_path)
|
28 |
# Streamlitアプリのタイトル
|
29 |
+
st.title("セリフチェッカー")
|
30 |
+
# セッション状態にキーが存在しない場合は、初期値を設定
|
31 |
+
if 'button_clicked' not in st.session_state:
|
32 |
+
st.session_state.button_clicked = False
|
33 |
|
34 |
+
def on_button_click():
|
35 |
+
# ボタンがクリックされた時の処理
|
36 |
+
st.session_state.button_clicked = True
|
37 |
+
# 吹き出し風表示用のカスタムCSS
|
38 |
+
|
39 |
+
custom_css = """
|
40 |
+
<style>
|
41 |
+
.bubble {
|
42 |
+
position: relative;
|
43 |
+
background: #F0F0F0;
|
44 |
+
border-radius: .4em;
|
45 |
+
padding: 10px;
|
46 |
+
max-width: 95%; /* 吹き出しの最大幅を90%に設定 */
|
47 |
+
word-wrap: break-word; /* 長い単語でも折り返しを保証 */
|
48 |
+
}
|
49 |
|
50 |
+
.bubble::after {
|
51 |
+
content: '';
|
52 |
+
position: absolute;
|
53 |
+
top: 10px;
|
54 |
+
left: -10px;
|
55 |
+
width: 0;
|
56 |
+
height: 0;
|
57 |
+
border: 10px solid transparent;
|
58 |
+
border-right-color: #F0F0F0;
|
59 |
+
border-left: 0;
|
60 |
+
margin-top: 5px;
|
61 |
+
margin-left: 0;
|
62 |
+
}
|
63 |
+
</style>
|
64 |
+
"""
|
65 |
+
# CSSを使ってプログレスバーの色を変更
|
66 |
+
st.markdown("""
|
67 |
+
<style>
|
68 |
+
/* プログレスバーの色を変更 */
|
69 |
+
.stProgress > div > div > div > div {
|
70 |
+
background-color: #008000;
|
71 |
+
}
|
72 |
+
</style>
|
73 |
+
""", unsafe_allow_html=True)
|
74 |
+
# カテゴリリスト
|
75 |
category_list = [
|
76 |
+
'5_紫上鏡一[教師,大人,タメ口]', '20_見嶋千里[怖め,大人,粗雑な言葉,特徴的な笑い]', '29_白鳥王子[キザ,大人,調子いい]',
|
77 |
+
'30_白城院素子[お嬢様,少女]', '32_水陰那月[ネガティブ,少年,敬語]', '50_緋崎平一郎[元気,少年,やんちゃ]',
|
78 |
+
'76_百知瑠璃[元気,少女,です!,写真]', '91_桜結衣[元気,少女,タメ口,アイドル]', '101_御伽美夜子[落ち着いている,大人,女性口調]',
|
79 |
+
'121_荊棘従道[執事,大人,敬語(丁寧語)]', '133_司馬萌香[姉御,少女,粗雑(ヤンキー)]', '134_菜野花[落ち着いてる,少女,敬語]',
|
80 |
+
'139_黒冬和馬[少しそっけない,少年,タメ口]', '142_四涼礼子[ダウナー,少女,語尾~,めんどくさがり]'
|
81 |
]
|
82 |
+
# カテゴリ選択用のセレクトボックス
|
83 |
+
selected_category = st.selectbox("1.目標キャラクターを選択", category_list)
|
84 |
+
|
85 |
+
# 選択されたカテゴリに対応する画像ファイル名の決定
|
86 |
+
# カテゴリリストのインデックスを取得し、それに1を加えることで1から始まる画像ファイル番号を作成
|
87 |
+
image_file_number = category_list.index(selected_category) + 1
|
88 |
+
image_path = f"img/{image_file_number}.png"
|
89 |
+
image_width=300
|
90 |
+
judge_text = st.text_input("2.セリフを入力 //例: 貴方達も迷ったんですか?, よし、変身完了だ。, 勿論幽霊は抜きにして…でしょうね。,[ですわ。,だろ!,ですね。]")
|
91 |
+
|
92 |
+
st.button("🔍 チェックする", on_click=on_button_click)
|
93 |
+
st.divider()
|
94 |
+
# 画面を2つの列に分割
|
95 |
+
col1, col2 = st.columns([1, 4])
|
96 |
+
# 左側の列に画像を表示
|
97 |
+
with col1:
|
98 |
+
st.image(image_path, caption=selected_category, width=120)
|
99 |
+
# 右側の列にテキストボックスを配置
|
100 |
+
with col2:
|
101 |
+
if judge_text: # ユーザーが何か入力した場合のみ表示
|
102 |
+
st.markdown(custom_css, unsafe_allow_html=True) # カスタムCSSの適用
|
103 |
+
st.markdown(f'<div class="bubble">{judge_text}</div>', unsafe_allow_html=True)
|
104 |
|
105 |
+
# 処理ステップ数に応じてプロ��レスバーを更新する関数
|
106 |
+
def update_progress(step, total_steps):
|
107 |
+
progress = int((step / total_steps) * 100)
|
108 |
+
progress_bar.progress(progress)
|
109 |
|
110 |
+
total_steps = 5 # 処理を行う総ステップ数
|
111 |
+
if st.session_state.button_clicked:
|
112 |
if judge_text:
|
113 |
+
# プログレスバーの初期化
|
114 |
+
progress_bar = st.progress(0)
|
115 |
+
|
116 |
# トークナイズとテンソル化
|
117 |
words = tokenizer.tokenize(judge_text)
|
118 |
word_ids = tokenizer.convert_tokens_to_ids(words)
|
119 |
word_tensor = torch.tensor([word_ids[:512]]) # 最大長を512に制限
|
120 |
+
update_progress(1, total_steps)
|
121 |
|
122 |
+
# デバイスの自動選択
|
123 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
124 |
word_tensor = word_tensor.to(device)
|
125 |
model = model.to(device)
|
126 |
+
update_progress(2, total_steps)
|
127 |
|
128 |
# 推論
|
129 |
with torch.no_grad():
|
130 |
y = model(word_tensor)
|
131 |
+
update_progress(3, total_steps)
|
132 |
|
133 |
# 最も近いカテゴリの決定
|
134 |
pred = y.logits.argmax(-1)
|
135 |
+
st.write(f"最もらしい: {category_list[pred.item()]}")
|
136 |
+
update_progress(4, total_steps)
|
137 |
|
138 |
# 各クラスの確率計算
|
139 |
probabilities = torch.softmax(y.logits, dim=-1)
|
140 |
top_prob, top_cat_indices = probabilities.topk(len(category_list))
|
141 |
+
update_progress(5, total_steps)
|
142 |
|
143 |
# 確率とカテゴリ名の準備
|
144 |
+
top_probabilities = top_prob.cpu().numpy()[0]
|
145 |
top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]]
|
146 |
|
147 |
+
# 棒グラフの作成
|
148 |
+
|
149 |
# 棒グラフの作成
|
150 |
plt.figure(figsize=(10, 6))
|
151 |
+
bars = plt.bar(range(len(top_categories)), top_probabilities, color='skyblue')
|
152 |
+
|
153 |
+
# 画像の読み込みと配置
|
154 |
+
for i, (bar, category) in enumerate(zip(bars, top_categories)):
|
155 |
+
img_path = f'img/{i+1}.png' # ファイル名はcategory_listの順番+1の番号.png
|
156 |
+
image = plt.imread(img_path)
|
157 |
+
imagebox = OffsetImage(image, zoom=0.5) # zoomで画像のサイズを調整
|
158 |
+
ab = AnnotationBbox(imagebox, (bar.get_x() + bar.get_width() / 2, bar.get_height()), frameon=False, box_alignment=(0.5, -0.2))
|
159 |
+
plt.gca().add_artist(ab)
|
160 |
+
|
161 |
plt.xlabel('カテゴリ', fontproperties=font_prop)
|
162 |
plt.ylabel('確率', fontproperties=font_prop)
|
163 |
+
plt.ylim(0, 0.5) # y軸の範囲設定
|
164 |
+
plt.xticks(range(len(top_categories)), top_categories, rotation=45, ha="right", fontproperties=font_prop)
|
165 |
plt.title('カテゴリ別確率', fontproperties=font_prop)
|
166 |
+
plt.show()
|
167 |
+
|
168 |
+
|
169 |
+
progress_bar.progress(100)
|
170 |
+
time.sleep(1) # 1秒待機
|
171 |
+
progress_bar.empty() # プログレスバーを削除
|
172 |
+
|
173 |
else:
|
174 |
+
st.write("入力してください。")
|
img/1.png
ADDED
![]() |
img/10.png
ADDED
![]() |
img/11.png
ADDED
![]() |
img/12.png
ADDED
![]() |
img/13.png
ADDED
![]() |
img/14.png
ADDED
![]() |
img/2.png
ADDED
![]() |
img/3.png
ADDED
![]() |
img/4.png
ADDED
![]() |
img/5.png
ADDED
![]() |
img/6.png
ADDED
![]() |
img/7.png
ADDED
![]() |
img/8.png
ADDED
![]() |
img/9.png
ADDED
![]() |