kiriishi commited on
Commit
5c167d2
·
verified ·
1 Parent(s): 0eea2d4

Upload 61 files

Browse files
app.py CHANGED
@@ -16,6 +16,17 @@ from transformers import AutoModel
16
  from matplotlib.offsetbox import OffsetImage, AnnotationBbox
17
  import matplotlib.pyplot as plt
18
  import matplotlib.ticker as mticker
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # モデルとトークナイザーのパス
21
  model_path = 'use14/bert-base-japanese-v3/2024-0208-0323/model'
@@ -24,11 +35,13 @@ tokenizer_path = 'use14/bert-base-japanese-v3/2024-0208-0323/tokenizer'
24
  # トークナイザーとモデルのロード
25
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
26
  model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=14)
27
-
 
 
28
  font_path = 'NotoSansCJKjp-Bold.otf' # 実際のフォントのパスに置き換えてください
29
  font_prop = FontProperties(fname=font_path)
30
  # Streamlitアプリのタイトル
31
- st.title("セリフチェッカー")
32
  # セッション状態にキーが存在しない場合は、初期値を設定
33
  if 'button_clicked' not in st.session_state:
34
  st.session_state.button_clicked = False
@@ -83,6 +96,11 @@ category_list = [
83
  '121_荊棘従道[執事,大人,敬語(丁寧語)]', '133_司馬萌香[姉御,少女,粗雑(ヤンキー)]', '134_菜野花[落ち着いてる,少女,敬語]',
84
  '139_黒冬和馬[少しそっけない,少年,タメ口]', '142_四涼礼子[ダウナー,少女,語尾~,めんどくさがり]'
85
  ]
 
 
 
 
 
86
  # カテゴリ選択用のセレクトボックス
87
  selected_category = st.selectbox("1.目標キャラクターを選択", category_list)
88
 
@@ -91,146 +109,455 @@ selected_category = st.selectbox("1.目標キャラクターを選択", cate
91
  image_file_number = category_list.index(selected_category) + 1
92
  image_path = f"img/{image_file_number}.png"
93
  image_width=300
94
- judge_text = st.text_input("2.セリフを入力 //例: 貴方達も迷ったんですか?, よし、変身完了だ。, 勿論幽霊は抜きにして…でしょうね。//ですわ。,だろ!,ですね。")
95
 
 
 
 
96
  st.button("🔍 チェックする", on_click=on_button_click)
97
  st.divider()
 
 
 
 
98
  # 画面を2つの列に分割
99
  col1, col2 = st.columns([1, 4])
100
  # 左側の列に画像を表示
101
  with col1:
102
- st.image(image_path, caption=selected_category, width=120)
103
  # 右側の列にテキストボックスを配置
104
  with col2:
105
  if judge_text: # ユーザーが何か入力した場合のみ表示
106
  st.markdown(custom_css, unsafe_allow_html=True) # カスタムCSSの適用
107
  st.markdown(f'<div class="bubble">{judge_text}</div>', unsafe_allow_html=True)
 
 
 
108
 
 
 
 
109
  # 処理ステップ数に応じてプログレスバーを更新する関数
110
  def update_progress(step, total_steps):
111
  progress = int((step / total_steps) * 100)
112
  progress_bar.progress(progress)
113
 
 
114
  total_steps = 5 # 処理を行う総ステップ数
115
  if st.session_state.button_clicked:
116
  if judge_text:
117
- # プログレスバーの初期化
118
- progress_bar = st.progress(0)
119
-
120
- # トークナイズとテンソル化
121
- words = tokenizer.tokenize(judge_text)
122
- word_ids = tokenizer.convert_tokens_to_ids(words)
123
- word_tensor = torch.tensor([word_ids[:512]]) # 最大長を512に制限
124
- update_progress(1, total_steps)
125
-
126
- # デバイスの自動選択
127
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
- word_tensor = word_tensor.to(device)
129
- model = model.to(device)
130
- update_progress(2, total_steps)
131
-
132
- # 推論
133
- with torch.no_grad():
134
- y = model(word_tensor)
135
- update_progress(3, total_steps)
136
-
137
- # 最も近いカテゴリの決定
138
- pred = y.logits.argmax(-1)
139
- st.write(f"最もらしい: {category_list[pred.item()]}")
140
- update_progress(4, total_steps)
141
-
142
- # 各クラスの確率計算
143
- probabilities = torch.softmax(y.logits, dim=-1)
144
- top_prob, top_cat_indices = probabilities.topk(len(category_list))
145
- update_progress(5, total_steps)
146
-
147
- # 確率とカテゴリ名の準備
148
- top_probabilities = top_prob.cpu().numpy()[0]
149
- top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]]
150
-
151
-
152
- # 新しい図と軸オブジェクトの作成
153
- fig, ax = plt.subplots(figsize=(10, 6))
154
-
155
-
156
- # すべての確率が0.2を下回っているかどうかをチェック
157
- all_below_0_2 = all(probability < 0.2 for probability in top_probabilities)
158
- # 確率とインデックスのタプルのリストを作成
159
-
160
- # 棒グラフの作成、条件に応じて色を変更
161
- for i, (probability, category) in enumerate(zip(top_probabilities, top_categories)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  # 確率とインデックスのタプルのリストを作成
163
- probability_index_tuples = list(enumerate(top_probabilities))
164
-
165
- # 確率でソートして上位3つを取得
166
- sorted_tuples = sorted(probability_index_tuples, key=lambda x: x[1], reverse=True)
167
- top_3_indices = [t[0] for t in sorted_tuples[:3]]
168
- top_3_probabilities = [t[1] for t in sorted_tuples[:3]]
169
-
170
- # 1位と2位の確率の差が大きいかどうかを評価
171
- # ここでは例として、1位の確率が2位の確率よりも27.6%以上大きい場合を「大きい」と判断
172
- is_first_place_significantly_higher = top_3_probabilities[0] - top_3_probabilities[1] > 0.276
173
 
174
  # 棒グラフの作成、条件に応じて色を変更
175
  for i, (probability, category) in enumerate(zip(top_probabilities, top_categories)):
176
- # 1位が顕著に大きい場合、1位の棒をオレンジに設定
177
- if is_first_place_significantly_higher and i == top_3_indices[0]:
178
- color = 'orange'
179
- elif probability < 0.1:
180
- color = 'grey'
181
- elif all_below_0_2:
182
- color = 'grey'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  else:
184
- color = 'skyblue' # それ以外の場合の色
185
-
186
- ax.bar(i, probability, color=color)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
- # 棒の上部または画像の上に数値を表示(パーセント表示に変更)
189
- text_y = probability if probability <= 1 else 1
190
- ax.text(i, text_y, f'{probability * 100:.1f}%', ha='center', va='bottom' if probability <= 1 else 'top')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- # カテゴリリスト内でのカテゴリ名のインデックスを探し、その位置に基づいて画像ファイルを参照
193
- for i, category in enumerate(top_categories):
194
- # カテゴリリスト内の位置(インデックス+1)を使って画像ファイルパスを指定
195
- position = category_list.index(category) + 1
196
- img_path = f'img/{position}.png'
197
 
198
- # 画像の読み込みと配置
199
- image = plt.imread(img_path)
200
- imagebox = OffsetImage(image, zoom=0.1)
201
- ab = AnnotationBbox(imagebox, (i, 0), frameon=False, box_alignment=(0.5, -0.2))
202
- ax.add_artist(ab)
203
 
 
 
 
 
204
 
205
- # y軸の範囲設定を調整
 
 
 
 
206
 
207
- # 縦軸の範囲設定と横線の描画を調整
208
- max_probability = max(top_probabilities)
209
- # 最大確率が0.3以上ならば、それに合わせてy軸の上限を設定
210
- y_max = max(0.3, np.ceil(max_probability / 0.1) * 0.1)
211
 
212
- ax.set_ylim(0, y_max)
213
- ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1))
214
 
215
- # 0.1刻みで横線を引く、0.3, 0.6, 0.9 は太くする
216
- for y in np.arange(0.1, y_max + 0.1, 0.1):
217
- if y in [0.3, 0.6, 0.9]:
218
- ax.axhline(y=y, color='blue', linestyle='-', linewidth=2) # 太い横線
219
- else:
220
- ax.axhline(y=y, color='grey', linestyle='--', linewidth=0.5) # 通常の横線
221
- ax.set_xlabel('', fontproperties=font_prop)
222
- ax.set_ylabel('確率', fontproperties=font_prop)
223
- ax.set_xticks(range(len(top_categories)))
224
- ax.set_xticklabels(top_categories, rotation=45, ha="right", fontproperties=font_prop)
225
- # ax.set_title('カテゴリ別確率', fontproperties=font_prop)
226
 
227
- # Streamlitでグラフを表示
228
- st.pyplot(fig)
229
 
 
 
 
230
 
231
- progress_bar.progress(100)
232
- time.sleep(0.5) # 1秒待機
233
- progress_bar.empty() # プログレスバーを削除
234
 
235
- else:
236
- st.write("入力してください。")
 
 
 
16
  from matplotlib.offsetbox import OffsetImage, AnnotationBbox
17
  import matplotlib.pyplot as plt
18
  import matplotlib.ticker as mticker
19
+ import streamlit as st
20
+ import torch
21
+ import matplotlib.pyplot as plt
22
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox
23
+ import matplotlib.ticker as mticker
24
+ import numpy as np
25
+ import time
26
+ from transformers import BertJapaneseTokenizer, BertModel
27
+ from sklearn.metrics.pairwise import cosine_similarity
28
+ import re
29
+ import os
30
 
31
  # モデルとトークナイザーのパス
32
  model_path = 'use14/bert-base-japanese-v3/2024-0208-0323/model'
 
35
  # トークナイザーとモデルのロード
36
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
37
  model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=14)
38
+ st.set_page_config(
39
+ page_title="台詞校正ツール",
40
+ layout="wide")
41
  font_path = 'NotoSansCJKjp-Bold.otf' # 実際のフォントのパスに置き換えてください
42
  font_prop = FontProperties(fname=font_path)
43
  # Streamlitアプリのタイトル
44
+ st.header("入力台詞のチェック")
45
  # セッション状態にキーが存在しない場合は、初期値を設定
46
  if 'button_clicked' not in st.session_state:
47
  st.session_state.button_clicked = False
 
96
  '121_荊棘従道[執事,大人,敬語(丁寧語)]', '133_司馬萌香[姉御,少女,粗雑(ヤンキー)]', '134_菜野花[落ち着いてる,少女,敬語]',
97
  '139_黒冬和馬[少しそっけない,少年,タメ口]', '142_四涼礼子[ダウナー,少女,語尾~,めんどくさがり]'
98
  ]
99
+
100
+ # キャラクターの名前と属性を抽出
101
+ character_names = [category.split('_')[1].split('[')[0] for category in category_list]
102
+ character_attributes = ['[' + category.split('[')[1] for category in category_list]
103
+
104
  # カテゴリ選択用のセレクトボックス
105
  selected_category = st.selectbox("1.目標キャラクターを選択", category_list)
106
 
 
109
  image_file_number = category_list.index(selected_category) + 1
110
  image_path = f"img/{image_file_number}.png"
111
  image_width=300
 
112
 
113
+ ##----------------------------------------------------
114
+
115
+ judge_text = st.text_input("2.セリフを入力 //例: 貴方達も迷ったんですか?, よし、変身完了だ。, 勿論幽霊は抜きにして…でしょうね。//ですわ。,だろ!,ですね。","生徒たちの安全を守りたい。")
116
  st.button("🔍 チェックする", on_click=on_button_click)
117
  st.divider()
118
+
119
+ ##----------------------------------------------------
120
+
121
+
122
  # 画面を2つの列に分割
123
  col1, col2 = st.columns([1, 4])
124
  # 左側の列に画像を表示
125
  with col1:
126
+ st.image(image_path, width=120)
127
  # 右側の列にテキストボックスを配置
128
  with col2:
129
  if judge_text: # ユーザーが何か入力した場合のみ表示
130
  st.markdown(custom_css, unsafe_allow_html=True) # カスタムCSSの適用
131
  st.markdown(f'<div class="bubble">{judge_text}</div>', unsafe_allow_html=True)
132
+ selected_character=selected_category
133
+
134
+ ##----------------------------------------------------
135
 
136
+
137
+ tab1, tab2, tab3 = st.tabs(["📈 話体", "💭 類似話題", "🔑 キーワード話題"])
138
+ data = np.random.randn(10, 1)
139
  # 処理ステップ数に応じてプログレスバーを更新する関数
140
  def update_progress(step, total_steps):
141
  progress = int((step / total_steps) * 100)
142
  progress_bar.progress(progress)
143
 
144
+ ##----------------------------------------------------
145
  total_steps = 5 # 処理を行う総ステップ数
146
  if st.session_state.button_clicked:
147
  if judge_text:
148
+ with tab1:
149
+ # プログレスバーの初期化
150
+ progress_bar = st.progress(0)
151
+
152
+ # トークナイズとテンソル化
153
+ words = tokenizer.tokenize(judge_text)
154
+ word_ids = tokenizer.convert_tokens_to_ids(words)
155
+ word_tensor = torch.tensor([word_ids[:512]]) # 最大長を512に制限
156
+ update_progress(1, total_steps)
157
+
158
+ # デバイスの自動選択
159
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
160
+ word_tensor = word_tensor.to(device)
161
+ model = model.to(device)
162
+ update_progress(2, total_steps)
163
+
164
+ # 推論
165
+ with torch.no_grad():
166
+ y = model(word_tensor)
167
+ update_progress(3, total_steps)
168
+
169
+ # 最も近いカテゴリの決定
170
+ pred = y.logits.argmax(-1)
171
+ update_progress(4, total_steps)
172
+
173
+ # 各クラスの確率計算
174
+ probabilities = torch.softmax(y.logits, dim=-1)
175
+ top_prob, top_cat_indices = probabilities.topk(len(category_list))
176
+ update_progress(5, total_steps)
177
+ top_probabilities = top_prob.cpu().numpy()[0]
178
+ top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]]
179
+
180
+ selected_character_index = top_categories.index(selected_character) if selected_character in top_categories else -1
181
+
182
+ # 選択したキャラクターがリスト内のどの位置にあるかを判定
183
+ if selected_character in top_categories:
184
+ selected_character_index = top_categories.index(selected_character)
185
+ if selected_character_index == 0:
186
+ result_text = "OKです!"
187
+ elif selected_character_index in [1, 2]:
188
+ result_text = "OKです"
189
+ else:
190
+ result_text = "違うかも?"
191
+ else:
192
+ # リストにキャラクターがない場合
193
+ result_text = "違うかも?"
194
+ st.session_state.result_text = result_text
195
+
196
+ # 結果の表示
197
+ st.write(result_text)
198
+ # st.write(f"最も近い: {category_list[pred.item()]}")
199
+
200
+ # 確率とカテゴリ名の準備
201
+ top_probabilities = top_prob.cpu().numpy()[0]
202
+ top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]]
203
+ # 新しい図と軸オブジェクトの作成
204
+ fig, ax = plt.subplots(figsize=(10, 6))
205
+
206
+ # すべての確率が0.2を下回っているかどうかをチェック
207
+ all_below_0_2 = all(probability < 0.2 for probability in top_probabilities)
208
  # 確率とインデックスのタプルのリストを作成
 
 
 
 
 
 
 
 
 
 
209
 
210
  # 棒グラフの作成、条件に応じて色を変更
211
  for i, (probability, category) in enumerate(zip(top_probabilities, top_categories)):
212
+ # 確率とインデックスのタプルのリストを作成
213
+ probability_index_tuples = list(enumerate(top_probabilities))
214
+
215
+ # 確率でソートして上位3つを取得
216
+ sorted_tuples = sorted(probability_index_tuples, key=lambda x: x[1], reverse=True)
217
+ top_3_indices = [t[0] for t in sorted_tuples[:3]]
218
+ top_3_probabilities = [t[1] for t in sorted_tuples[:3]]
219
+
220
+ # 1位と2位の確率の差が大きいかどうかを評価
221
+ # ここでは例として、1位の確率が2位の確率よりも27.6%以上大きい場合を「大きい」と判断
222
+ is_first_place_significantly_higher = top_3_probabilities[0] - top_3_probabilities[1] > 0.276
223
+
224
+ # 棒グラフの作成、条件に応じて色を変更
225
+ for i, (probability, category) in enumerate(zip(top_probabilities, top_categories)):
226
+ # 1位が顕著に大きい場合、1位の棒をオレンジに設定
227
+ if is_first_place_significantly_higher and i == top_3_indices[0]:
228
+ color = 'orange'
229
+ elif probability < 0.1:
230
+ color = 'grey'
231
+ elif all_below_0_2:
232
+ color = 'grey'
233
+ else:
234
+ color = 'skyblue' # それ以外の場合の色
235
+
236
+ ax.bar(i, probability, color=color)
237
+
238
+ # 棒の上部または画像の上に数値を表示(パーセント表示に変更)
239
+ text_y = probability if probability <= 1 else 1
240
+ ax.text(i, text_y, f'{probability * 100:.1f}%', ha='center', va='bottom' if probability <= 1 else 'top')
241
+
242
+ # カテゴリリスト内でのカテゴリ名のインデックスを探し、その位置に基づいて画像ファイルを参照
243
+ for i, category in enumerate(top_categories):
244
+ # カテゴリリスト内の位置(インデックス+1)を使って画像ファイルパスを指定
245
+ position = category_list.index(category) + 1
246
+ img_path = f'img/{position}.png'
247
+
248
+ # 画像の読み込みと配置
249
+ image = plt.imread(img_path)
250
+ imagebox = OffsetImage(image, zoom=0.1)
251
+ ab = AnnotationBbox(imagebox, (i, 0), frameon=False, box_alignment=(0.5, -0.2))
252
+ ax.add_artist(ab)
253
+
254
+
255
+ # y軸の範囲設定を調整
256
+ # 縦軸の範囲設定と横線の描画を調整
257
+ max_probability = max(top_probabilities)
258
+ # 最大確率が0.3以上ならば、それに合わせてy軸の上限を設定
259
+ y_max = max(0.3, np.ceil(max_probability / 0.1) * 0.1)
260
+
261
+ ax.set_ylim(0, y_max)
262
+ ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1))
263
+
264
+ # 0.1刻みで横線を引く、0.3, 0.6, 0.9 は太くする
265
+ for y in np.arange(0.1, y_max + 0.1, 0.1):
266
+ if y in [0.3, 0.6, 0.9]:
267
+ ax.axhline(y=y, color='blue', linestyle='-', linewidth=2) # 太い横線
268
  else:
269
+ ax.axhline(y=y, color='grey', linestyle='--', linewidth=0.5) # 通常の横線
270
+ ax.set_xlabel('', fontproperties=font_prop)
271
+ ax.set_ylabel('確率', fontproperties=font_prop)
272
+ ax.set_xticks(range(len(top_categories)))
273
+ ax.set_xticklabels(top_categories, rotation=45, ha="right", fontproperties=font_prop)
274
+ # ax.set_title('カテゴリ別確率', fontproperties=font_prop)
275
+
276
+ # Streamlitでグラフを表示
277
+ st.pyplot(fig)
278
+ progress_bar.progress(100)
279
+ time.sleep(0.5) # 1秒待機
280
+ progress_bar.empty() # プログレスバーを削除
281
+ ##----------------------------------------------------
282
+ with tab2:
283
+ with st.spinner('処理中...'):
284
+ # モデルとトークナイザーの初期化
285
+ ruiji_model_name = 'cl-tohoku/bert-base-japanese-v3'
286
+ ruiji_tokenizer = BertJapaneseTokenizer.from_pretrained(ruiji_model_name)
287
+ ruiji_model = BertModel.from_pretrained(ruiji_model_name)
288
+
289
+ def extract_keywords(sentence, tokenizer, num_keywords=10, min_length=2):
290
+ # トークナイズして品詞タグを取得
291
+ tokens = tokenizer.tokenize(sentence)
292
+ # 文字数が min_length 以上のトークンのみを選択
293
+ # 一文字かつひらがなの単語を除外するフィルタリング条件を追加
294
+ filtered_tokens = [token for token in tokens if len(token) >= min_length and not re.match(r'^[ぁ-ん、。]$', token)]
295
+ return filtered_tokens[:num_keywords]
296
+
297
+ def find_sentences_with_specific_keywords(sentences, keywords):
298
+ results = {}
299
+ sentence_to_id = {} # 文とそのIDをマッピングする辞書
300
+ id_to_sentence = {} # IDと文をマッピングする辞書
301
+ first_reference = {} # 各IDに対して最初に参照されたキーワードを記録
302
+ next_id = 1 # 次に割り当てるID
303
+
304
+ for sentence in sentences:
305
+ for keyword in keywords:
306
+ if keyword in sentence:
307
+ if sentence not in sentence_to_id:
308
+ # 文に新しいIDを割り当て、最初の参照として記録
309
+ sentence_to_id[sentence] = next_id
310
+ id_to_sentence[next_id] = sentence
311
+ first_reference[next_id] = keyword # このIDが最初に参照されたキーワード
312
+ next_id += 1
313
+
314
+ # 結果にIDと共に文を追加
315
+ results.setdefault(keyword, []).append(sentence_to_id[sentence])
316
+
317
+ # IDを参照して文を取得し、重複を示す情報を付加して返す
318
+ formatted_results = {}
319
+ seen_sentences = set() # 既に出力された文のIDを記録
320
+ for keyword, ids in results.items():
321
+ formatted_sentences = []
322
+ for id in ids:
323
+ if id in seen_sentences:
324
+ # 既に出力された文は、"と同じ"で参照
325
+ formatted_sentence = f"({id})と同じ"
326
+ else:
327
+ # 初めて出力される文は通常通り表示し、seen_sentencesに追加
328
+ formatted_sentence = f"({id}) {id_to_sentence[id]}"
329
+ seen_sentences.add(id)
330
+ formatted_sentences.append(formatted_sentence)
331
+ formatted_results[keyword] = formatted_sentences
332
+ return formatted_results
333
+
334
+ def get_sentence_embedding(sentence):
335
+ inputs = ruiji_tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
336
+ outputs = ruiji_model(**inputs)
337
+ sentence_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().detach().numpy()
338
+ return sentence_embedding
339
+
340
+ def load_sentences(file_path):
341
+ with open(file_path, 'r', encoding='utf-8') as file:
342
+ sentences = file.readlines()
343
+ return [sentence.strip() for sentence in sentences]
344
+
345
+ def process_input(character_index, input_sentence, characters_list):
346
+ file_name = characters_list[character_index] + ".txt"
347
+ file_path = os.path.join("use14", file_name)
348
+ sentences = load_sentences(file_path)
349
+ input_embedding = get_sentence_embedding(input_sentence)
350
+ similarities = []
351
+ for sentence in sentences:
352
+ sentence_embedding = get_sentence_embedding(sentence)
353
+ similarity = cosine_similarity([input_embedding], [sentence_embedding])[0]
354
+ similarities.append((sentence, similarity[0]))
355
+ similar_sentences = sorted(similarities, key=lambda x: x[1], reverse=True)[:10]
356
+ # キーワードの抽出
357
+ keywords = extract_keywords(input_sentence, tokenizer)
358
+ # 各キーワードに関連する台詞の抽出
359
+ results = find_sentences_with_specific_keywords(sentences, keywords)
360
+ return similar_sentences, results # mentioned_sentences を results に置き換え
361
+
362
+ character_index = category_list.index(selected_category)
363
+ input_sentence = judge_text
364
+ characters_list=['5_紫上鏡一_教師大人タメ', '20_見嶋千里_怖め大人粗雑', '29_白鳥王子_キザ大人調子', '30_白城院素子_お嬢様少女丁寧', '32_水陰那月_ネガ少年敬語', '50_緋崎平一郎_元気少年やんちゃ', '76_百知瑠璃_元気少女です', '91_桜結衣_元気少女タメ', '101_御伽美夜子_落着大人女性口調', '121_荊棘従道_執事大人丁寧', '133_司馬萌香_姉御少女粗雑', '134_菜野花_落着少女敬語', '139_黒冬和馬_普通少年タメ', '142_四涼礼子_ダウナー少女語尾伸']
365
+ similar_sentences, mentioned_sentences = process_input(character_index, input_sentence, characters_list)
366
+
367
+ with st.container():
368
+ st.subheader("類似度話題:")
369
+ # 類似度の閾値
370
+ similarity_threshold = 0.85
371
+ # 類似度の高い台詞をDataFrameに変換
372
+ similar_df = pd.DataFrame(similar_sentences, columns=['台詞', '類似度'])
373
+ # 閾値より上のデータが存在するかどうかをチェック
374
+ if not similar_df[similar_df['類似度'] >= similarity_threshold].empty:
375
+ # 閾値より上のデータが存在する場合の処理
376
+ st.dataframe(similar_df[similar_df['類似度'] >= similarity_threshold].reset_index(drop=True).style.format({'類似度': "{:.2f}"}))
377
+ # 閾値より下のデータが存在するかどうかをチェック
378
+ if not similar_df[similar_df['類似度'] < similarity_threshold].empty:
379
+ # 閾値より下のデータが存在する場合、警告メッセージを表示
380
+ st.markdown("<span style='color:red;'>以下、信頼度が低い可能性があります</span>", unsafe_allow_html=True)
381
+ # 閾値より下のデータを表示
382
+ st.dataframe(similar_df[similar_df['類似度'] < similarity_threshold].reset_index(drop=True).style.format({'類似度': "{:.2f}"}))
383
+ with tab3:
384
+ with st.container():
385
+ # キーワードのリストを取得
386
+ keywords_list = list(mentioned_sentences.keys())
387
+ # キーワードを文字列として結合し、表示
388
+ st.subheader(f"キーワード話題: [{', '.join(keywords_list)}]")
389
+ # キーワードに関連する台詞を表示
390
+ for keyword, sentences in mentioned_sentences.items():
391
+ st.text(f"キーワード '{keyword}' に関連する既存台詞:")
392
+ # 台詞からIDを抽出し、元の文と一緒にリストに格納
393
+ sentences_with_id = [(int(re.search(r"\((\d+)\)", sentence).group(1)), sentence) for sentence in sentences]
394
+ # IDで並び替え
395
+ sorted_sentences_with_id = sorted(sentences_with_id, key=lambda x: x[0])
396
+ # 並び替えた後の台詞リストを作成(IDを除外)
397
+ sorted_sentences = [sentence for _, sentence in sorted_sentences_with_id]
398
+ # 台詞をリストに変換してDataFrameを作成
399
+ keyword_df = pd.DataFrame(sorted_sentences, columns=['台詞'])
400
+ # DataFrameを表示
401
+ st.dataframe(keyword_df)
402
+
403
+ ##----------------------------------------------------
404
 
405
+ else:
406
+ st.write("入力してください。")
407
+
408
+ ##----------------------------------------------------
409
+
410
+ # 選択されたキャラクターのインデックスに基づいてデータを取得する関数
411
+ def get_character_data(selected_index):
412
+ data = []
413
+ with open("result_ninsho.txt", "r", encoding="utf-8") as file:
414
+ for line in file:
415
+ parts = line.strip().split(',')
416
+ if int(parts[0]) == selected_index + 1:
417
+ data.append(parts)
418
+ return data
419
+
420
+ # 頻出単語データを整形してテーブルに表示する関数
421
+ def display_word_frequency_table(data):
422
+ hinshi_categories = ['一人称代名詞', '二人称代名詞', '三人称代名詞', 'その他の代名詞', '名詞', '動詞', '形容詞・形容動詞', '副詞', '助動詞', '感動詞']
423
+ df = pd.DataFrame(columns=['品詞', '頻出単語'])
424
+
425
+ rows_list = [] # 新しい行を一時的に格納するリスト
426
+ for row in data:
427
+ category_index = int(row[1]) - 1 # 品詞のカテゴリインデックス
428
+ words = row[2:]
429
+ word_counts = {}
430
+ for word in words:
431
+ if word not in word_counts:
432
+ word_counts[word] = 1
433
+ else:
434
+ word_counts[word] += 1
435
+
436
+ # 単語を出現頻度順に並べ替え
437
+ sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
438
+ formatted_words = [f"**{word}({count})**" if i == 0 else f"{word}({count})" for i, (word, count) in enumerate(sorted_words)]
439
+
440
+ # 新しい行をリストに追加
441
+ rows_list.append({'品詞': hinshi_categories[category_index], '頻出単語': ', '.join(formatted_words)})
442
+
443
+ # DataFrameに行を追加
444
+ df = pd.concat([df, pd.DataFrame(rows_list)], ignore_index=True)
445
+ return df
446
+
447
+ st.markdown(
448
+ """
449
+ <style>
450
+ section[data-testid="stSidebar"] {
451
+ width: 500px !important;
452
+ }
453
+ </style>
454
+ """,
455
+ unsafe_allow_html=True
456
+ )
457
+
458
+ def display_df_as_html_table_with_bold_text(df):
459
+ # HTMLテーブルのヘッダーを定義
460
+ html = "<table>"
461
+ html += "<tr>"
462
+ for col in df.columns:
463
+ html += f"<th>{col}</th>"
464
+ html += "</tr>"
465
+
466
+ # DataFrameの各行データをHTMLテーブルの行として追加
467
+ for index, row in df.iterrows():
468
+ html += "<tr>"
469
+ for col in df.columns:
470
+ # 特定の文字列を太字にしたい場合、ここで条件をチェック
471
+ # ここでは簡単な例として、文字列内の数字が含まれていれば太字にします
472
+ if any(char.isdigit() for char in str(row[col])):
473
+ cell_value = f"<strong>{row[col]}</strong>"
474
+ else:
475
+ cell_value = row[col]
476
+ html += f"<td>{cell_value}</td>"
477
+ html += "</tr>"
478
+ html += "</table>"
479
+
480
+ # 完成したHTMLテーブルをStreamlitで表示
481
+ st.markdown(html, unsafe_allow_html=True)
482
+
483
+ def display_df_with_markdown(df):
484
+ for index, row in df.iterrows():
485
+ # 品詞列を表示
486
+ st.markdown(f"**品詞**: {row['品詞']}")
487
+ # 頻出単語列をMarkdown形式で表示(太字を含む)
488
+ st.markdown(f"**頻出単語**: {row['頻出単語']}", unsafe_allow_html=True)
489
+
490
+ def display_df_with_markdown_table(df):
491
+ # テーブルのヘッダー
492
+ markdown_table = "品詞 | 頻出単語\n-|-|\n"
493
+
494
+ # DataFrameの各行をループしてテーブルの行を作成
495
+ for _, row in df.iterrows():
496
+ # 頻出単語列のデータを取得し、カンマで分割
497
+ words = row['頻出単語'].split(', ')
498
+
499
+ # 単語のリストを処理
500
+ formatted_words = []
501
+ for word in words:
502
+ # 出現頻度を取得
503
+ freq = int(word[word.find("(")+1:word.find(")")])
504
+ if freq == 1:
505
+ # 出現頻度が1の場合、薄い文字で表示
506
+ formatted_word = f"<span style='color: #bbb;'>{word}</span>"
507
+ else:
508
+ formatted_word = word
509
+ formatted_words.append(formatted_word)
510
+
511
+ # 最も頻出する単語(リストの最初の要素)を太字に
512
+ if formatted_words:
513
+ formatted_words[0] = f"**{formatted_words[0]}**"
514
+
515
+ # 処理した単語のリストをカンマで結合
516
+ formatted_words_str = ', '.join(formatted_words)
517
+
518
+ # Markdownテーブルの行を追加
519
+ markdown_table += f"{row['品詞']} | {formatted_words_str}\n"
520
+
521
+ # 生成したMarkdownテーブルをStreamlitで表示
522
+ st.markdown(markdown_table, unsafe_allow_html=True)
523
 
 
 
 
 
 
524
 
 
 
 
 
 
525
 
526
+ ##----------------------------------------------------
527
+ with st.sidebar:
528
+ st.title("台詞校正ツール")
529
+ st.header('キャラクター辞典')
530
 
531
+ selected_index = st.selectbox("キャラクターを選択", range(len(character_names)), format_func=lambda x: character_names[x])
532
+ # 選択されたキャラクターの情報をメイン画面に表示
533
+ character_name = character_names[selected_index]
534
+ character_attribute = character_attributes[selected_index]
535
+ image_path = f"img/full{selected_index + 1}.png" # img/1.pngから始まる
536
 
537
+ st.image(image_path, caption=f"{character_name}", width=250)
538
+ st.write(f"{character_name} の特徴: {character_attribute}")
 
 
539
 
540
+ character_names = [category.split('_')[1].split('[')[0] for category in category_list]
 
541
 
542
+ character_data = get_character_data(selected_index)
543
+ df = display_word_frequency_table(character_data)
544
+
545
+
546
+ # DataFrameを取得(この部分は既にあるコードを使用)
547
+ character_data = get_character_data(selected_index)
548
+ df = display_word_frequency_table(character_data)
 
 
 
 
549
 
550
+ # "頻出単語テーブル"としてセクションを表示
551
+ st.write("頻出単語テーブル")
552
 
553
+ # カスタムMarkdownテーブルでDataFrameの内容を表示
554
+ display_df_with_markdown_table(df)
555
+ import streamlit as st
556
 
557
+ # 画像データの例
558
+ image_path = "img/1.png"
 
559
 
560
+ # Expanderで画像を表示
561
+ expander = st.expander("ここをクリックして画像を表示")
562
+ with expander:
563
+ st.image(image_path, caption="画像のキャプション")
img/5.png CHANGED
img/full1.png ADDED
img/full10.png ADDED
img/full11.png ADDED
img/full12.png ADDED
img/full13.png ADDED
img/full14.png ADDED
img/full2.png ADDED
img/full3.png ADDED
img/full4.png ADDED
img/full5.png ADDED
img/full7.png ADDED
img/full8.png ADDED
img/full9.png ADDED
result_ninsho.txt ADDED
The diff for this file is too large to render. See raw diff