kiriishi commited on
Commit
de90a66
·
verified ·
1 Parent(s): def91b4

Upload 15 files

Browse files
Files changed (15) hide show
  1. app.py +113 -20
  2. img/1.png +0 -0
  3. img/10.png +0 -0
  4. img/11.png +0 -0
  5. img/12.png +0 -0
  6. img/13.png +0 -0
  7. img/14.png +0 -0
  8. img/2.png +0 -0
  9. img/3.png +0 -0
  10. img/4.png +0 -0
  11. img/5.png +0 -0
  12. img/6.png +0 -0
  13. img/7.png +0 -0
  14. img/8.png +0 -0
  15. img/9.png +0 -0
app.py CHANGED
@@ -11,8 +11,10 @@ from collections import Counter
11
  import matplotlib.pyplot as plt
12
  import numpy as np
13
  # モデルとトークナイザーのロード
 
14
  from transformers import AutoModel
15
-
 
16
  # モデルとトークナイザーのパス
17
  model_path = 'use14/bert-base-japanese-v3/2024-0208-0323/model'
18
  tokenizer_path = 'use14/bert-base-japanese-v3/2024-0208-0323/tokenizer'
@@ -24,58 +26,149 @@ model = AutoModelForSequenceClassification.from_pretrained(model_path, num_label
24
  font_path = 'NotoSansCJKjp-Bold.otf' # 実際のフォントのパスに置き換えてください
25
  font_prop = FontProperties(fname=font_path)
26
  # Streamlitアプリのタイトル
27
- st.title('テキストファイル分析アプリ')
 
 
 
28
 
29
- # カテゴリリスト
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  category_list = [
32
- '5_紫上鏡一_教師大人タメ', '20_見嶋千里_怖め大人粗雑', '29_白鳥王子_キザ大人調子',
33
- '30_白城院素子_お嬢様少女丁寧', '32_水陰那月_ネガ少年敬語', '50_緋崎平一郎_元気少年やんちゃ',
34
- '76_百知瑠璃_元気少女です', '91_桜結衣_元気少女タメ', '101_御伽美夜子_落着大人女性口調',
35
- '121_荊棘従道_執事大人丁寧', '133_司馬萌香_姉御少女粗雑', '134_菜野花_落着少女敬語',
36
- '139_黒冬和馬_普通少年タメ', '142_四涼礼子_ダウナー少女語尾伸'
37
  ]
38
- # Streamlitのユーザーインターフェイス
39
- st.title("テキストカテゴリ推論アプリ")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # 入力フォーム
42
- judge_text = st.text_input("テキストを入力してください。# 貴方達も迷ったんですか?# よし、変身完了だ。勿論幽霊は抜きにして…でしょうね。")
 
 
43
 
44
- if st.button("チェックする"):
 
45
  if judge_text:
 
 
 
46
  # トークナイズとテンソル化
47
  words = tokenizer.tokenize(judge_text)
48
  word_ids = tokenizer.convert_tokens_to_ids(words)
49
  word_tensor = torch.tensor([word_ids[:512]]) # 最大長を512に制限
 
50
 
51
- # デバイスの自動選択 (GPUが利用可能ならGPUを使用)
52
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
  word_tensor = word_tensor.to(device)
54
  model = model.to(device)
 
55
 
56
  # 推論
57
  with torch.no_grad():
58
  y = model(word_tensor)
 
59
 
60
  # 最も近いカテゴリの決定
61
  pred = y.logits.argmax(-1)
62
- st.write(f"最も近いカテゴリ: {category_list[pred.item()]}")
 
63
 
64
  # 各クラスの確率計算
65
  probabilities = torch.softmax(y.logits, dim=-1)
66
  top_prob, top_cat_indices = probabilities.topk(len(category_list))
 
67
 
68
  # 確率とカテゴリ名の準備
69
- top_probabilities = top_prob.cpu().numpy()[0] # CPUに移動してからnumpy配列に変換
70
  top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]]
71
 
 
 
72
  # 棒グラフの作成
73
  plt.figure(figsize=(10, 6))
74
- plt.bar(top_categories, top_probabilities, color='skyblue')
 
 
 
 
 
 
 
 
 
75
  plt.xlabel('カテゴリ', fontproperties=font_prop)
76
  plt.ylabel('確率', fontproperties=font_prop)
77
- plt.xticks(rotation=45, ha="right", fontproperties=font_prop)
 
78
  plt.title('カテゴリ別確率', fontproperties=font_prop)
79
- st.pyplot(plt)
 
 
 
 
 
 
80
  else:
81
- st.write("にゅうりょくしてー")
 
11
  import matplotlib.pyplot as plt
12
  import numpy as np
13
  # モデルとトークナイザーのロード
14
+ import time
15
  from transformers import AutoModel
16
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox
17
+ import matplotlib.pyplot as plt
18
  # モデルとトークナイザーのパス
19
  model_path = 'use14/bert-base-japanese-v3/2024-0208-0323/model'
20
  tokenizer_path = 'use14/bert-base-japanese-v3/2024-0208-0323/tokenizer'
 
26
  font_path = 'NotoSansCJKjp-Bold.otf' # 実際のフォントのパスに置き換えてください
27
  font_prop = FontProperties(fname=font_path)
28
  # Streamlitアプリのタイトル
29
+ st.title("セリフチェッカー")
30
+ # セッション状態にキーが存在しない場合は、初期値を設定
31
+ if 'button_clicked' not in st.session_state:
32
+ st.session_state.button_clicked = False
33
 
34
+ def on_button_click():
35
+ # ボタンがクリックされた時の処理
36
+ st.session_state.button_clicked = True
37
+ # 吹き出し風表示用のカスタムCSS
38
+
39
+ custom_css = """
40
+ <style>
41
+ .bubble {
42
+ position: relative;
43
+ background: #F0F0F0;
44
+ border-radius: .4em;
45
+ padding: 10px;
46
+ max-width: 95%; /* 吹き出しの最大幅を90%に設定 */
47
+ word-wrap: break-word; /* 長い単語でも折り返しを保証 */
48
+ }
49
 
50
+ .bubble::after {
51
+ content: '';
52
+ position: absolute;
53
+ top: 10px;
54
+ left: -10px;
55
+ width: 0;
56
+ height: 0;
57
+ border: 10px solid transparent;
58
+ border-right-color: #F0F0F0;
59
+ border-left: 0;
60
+ margin-top: 5px;
61
+ margin-left: 0;
62
+ }
63
+ </style>
64
+ """
65
+ # CSSを使ってプログレスバーの色を変更
66
+ st.markdown("""
67
+ <style>
68
+ /* プログレスバーの色を変更 */
69
+ .stProgress > div > div > div > div {
70
+ background-color: #008000;
71
+ }
72
+ </style>
73
+ """, unsafe_allow_html=True)
74
+ # カテゴリリスト
75
  category_list = [
76
+ '5_紫上鏡一[教師,大人,タメ口]', '20_見嶋千里[怖め,大人,粗雑な言葉,特徴的な笑い]', '29_白鳥王子[キザ,大人,調子いい]',
77
+ '30_白城院素子[お嬢様,少女]', '32_水陰那月[ネガティブ,少年,敬語]', '50_緋崎平一郎[元気,少年,やんちゃ]',
78
+ '76_百知瑠璃[元気,少女,です!,写真]', '91_桜結衣[元気,少女,タメ口,アイドル]', '101_御伽美夜子[落ち着いている,大人,女性口調]',
79
+ '121_荊棘従道[執事,大人,敬語(丁寧語)]', '133_司馬萌香[姉御,少女,粗雑(ヤンキー)]', '134_菜野花[落ち着いてる,少女,敬語]',
80
+ '139_黒冬和馬[少しそっけない,少年,タメ口]', '142_四涼礼子[ダウナー,少女,語尾~,めんどくさがり]'
81
  ]
82
+ # カテゴリ選択用のセレクトボックス
83
+ selected_category = st.selectbox("1.目標キャラクターを選択", category_list)
84
+
85
+ # 選択されたカテゴリに対応する画像ファイル名の決定
86
+ # カテゴリリストのインデックスを取得し、それに1を加えることで1から始まる画像ファイル番号を作成
87
+ image_file_number = category_list.index(selected_category) + 1
88
+ image_path = f"img/{image_file_number}.png"
89
+ image_width=300
90
+ judge_text = st.text_input("2.セリフを入力 //例: 貴方達も迷ったんですか?, よし、変身完了だ。, 勿論幽霊は抜きにして…でしょうね。,[ですわ。,だろ!,ですね。]")
91
+
92
+ st.button("🔍 チェックする", on_click=on_button_click)
93
+ st.divider()
94
+ # 画面を2つの列に分割
95
+ col1, col2 = st.columns([1, 4])
96
+ # 左側の列に画像を表示
97
+ with col1:
98
+ st.image(image_path, caption=selected_category, width=120)
99
+ # 右側の列にテキストボックスを配置
100
+ with col2:
101
+ if judge_text: # ユーザーが何か入力した場合のみ表示
102
+ st.markdown(custom_css, unsafe_allow_html=True) # カスタムCSSの適用
103
+ st.markdown(f'<div class="bubble">{judge_text}</div>', unsafe_allow_html=True)
104
 
105
+ # 処理ステップ数に応じてプロ��レスバーを更新する関数
106
+ def update_progress(step, total_steps):
107
+ progress = int((step / total_steps) * 100)
108
+ progress_bar.progress(progress)
109
 
110
+ total_steps = 5 # 処理を行う総ステップ数
111
+ if st.session_state.button_clicked:
112
  if judge_text:
113
+ # プログレスバーの初期化
114
+ progress_bar = st.progress(0)
115
+
116
  # トークナイズとテンソル化
117
  words = tokenizer.tokenize(judge_text)
118
  word_ids = tokenizer.convert_tokens_to_ids(words)
119
  word_tensor = torch.tensor([word_ids[:512]]) # 最大長を512に制限
120
+ update_progress(1, total_steps)
121
 
122
+ # デバイスの自動選択
123
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
  word_tensor = word_tensor.to(device)
125
  model = model.to(device)
126
+ update_progress(2, total_steps)
127
 
128
  # 推論
129
  with torch.no_grad():
130
  y = model(word_tensor)
131
+ update_progress(3, total_steps)
132
 
133
  # 最も近いカテゴリの決定
134
  pred = y.logits.argmax(-1)
135
+ st.write(f"最もらしい: {category_list[pred.item()]}")
136
+ update_progress(4, total_steps)
137
 
138
  # 各クラスの確率計算
139
  probabilities = torch.softmax(y.logits, dim=-1)
140
  top_prob, top_cat_indices = probabilities.topk(len(category_list))
141
+ update_progress(5, total_steps)
142
 
143
  # 確率とカテゴリ名の準備
144
+ top_probabilities = top_prob.cpu().numpy()[0]
145
  top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]]
146
 
147
+ # 棒グラフの作成
148
+
149
  # 棒グラフの作成
150
  plt.figure(figsize=(10, 6))
151
+ bars = plt.bar(range(len(top_categories)), top_probabilities, color='skyblue')
152
+
153
+ # 画像の読み込みと配置
154
+ for i, (bar, category) in enumerate(zip(bars, top_categories)):
155
+ img_path = f'img/{i+1}.png' # ファイル名はcategory_listの順番+1の番号.png
156
+ image = plt.imread(img_path)
157
+ imagebox = OffsetImage(image, zoom=0.5) # zoomで画像のサイズを調整
158
+ ab = AnnotationBbox(imagebox, (bar.get_x() + bar.get_width() / 2, bar.get_height()), frameon=False, box_alignment=(0.5, -0.2))
159
+ plt.gca().add_artist(ab)
160
+
161
  plt.xlabel('カテゴリ', fontproperties=font_prop)
162
  plt.ylabel('確率', fontproperties=font_prop)
163
+ plt.ylim(0, 0.5) # y軸の範囲設定
164
+ plt.xticks(range(len(top_categories)), top_categories, rotation=45, ha="right", fontproperties=font_prop)
165
  plt.title('カテゴリ別確率', fontproperties=font_prop)
166
+ plt.show()
167
+
168
+
169
+ progress_bar.progress(100)
170
+ time.sleep(1) # 1秒待機
171
+ progress_bar.empty() # プログレスバーを削除
172
+
173
  else:
174
+ st.write("入力してください。")
img/1.png ADDED
img/10.png ADDED
img/11.png ADDED
img/12.png ADDED
img/13.png ADDED
img/14.png ADDED
img/2.png ADDED
img/3.png ADDED
img/4.png ADDED
img/5.png ADDED
img/6.png ADDED
img/7.png ADDED
img/8.png ADDED
img/9.png ADDED