DialogueChecker / app.py
kiriishi's picture
Upload 61 files
5c167d2 verified
raw
history blame
27.9 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import random
import torch
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from matplotlib.font_manager import FontProperties
import pandas as pd
import seaborn as sns
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
# モデルとトークナイザーのロード
import time
from transformers import AutoModel
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import streamlit as st
import torch
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.ticker as mticker
import numpy as np
import time
from transformers import BertJapaneseTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
import re
import os
# モデルとトークナイザーのパス
model_path = 'use14/bert-base-japanese-v3/2024-0208-0323/model'
tokenizer_path = 'use14/bert-base-japanese-v3/2024-0208-0323/tokenizer'
# トークナイザーとモデルのロード
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=14)
st.set_page_config(
page_title="台詞校正ツール",
layout="wide")
font_path = 'NotoSansCJKjp-Bold.otf' # 実際のフォントのパスに置き換えてください
font_prop = FontProperties(fname=font_path)
# Streamlitアプリのタイトル
st.header("入力台詞のチェック")
# セッション状態にキーが存在しない場合は、初期値を設定
if 'button_clicked' not in st.session_state:
st.session_state.button_clicked = False
def on_button_click():
# ボタンがクリックされた時の処理
st.session_state.button_clicked = True
# 吹き出し風表示用のカスタムCSS
custom_css = """
<style>
.bubble {
position: relative;
background: #EDF2ED;
border-radius: .4em;
padding: 10px;
max-width: 95%; /* 吹き出しの最大幅を90%に設定 */
word-wrap: break-word; /* 長い単語でも折り返しを保証 */
color: #555; /* フォントの色を灰色に設定 */
}
.bubble::after {
content: '';
position: absolute;
top: 10px;
left: -10px;
width: 0;
height: 0;
border: 10px solid transparent;
border-right-color: #EDF2ED;
border-left: 0;
margin-top: 5px;
margin-left: 0;
}
</style>
"""
# CSSを使ってプログレスバーの色を変更
st.markdown("""
<style>
/* プログレスバーの色を変更 */
.stProgress > div > div > div > div {
background-color: #8CAB8B;
}
</style>
""", unsafe_allow_html=True)
# カテゴリリスト
category_list = [
'5_紫上鏡一[教師,大人,タメ口]', '20_見嶋千里[怖め,大人,粗雑な言葉,特徴的な笑い]', '29_白鳥王子[キザ,大人,調子いい]',
'30_白城院素子[お嬢様,少女]', '32_水陰那月[ネガティブ,少年,敬語]', '50_緋崎平一郎[元気,少年,やんちゃ]',
'76_百知瑠璃[元気,少女,です!,写真]', '91_桜結衣[元気,少女,タメ口,アイドル]', '101_御伽美夜子[落ち着いている,大人,女性口調]',
'121_荊棘従道[執事,大人,敬語(丁寧語)]', '133_司馬萌香[姉御,少女,粗雑(ヤンキー)]', '134_菜野花[落ち着いてる,少女,敬語]',
'139_黒冬和馬[少しそっけない,少年,タメ口]', '142_四涼礼子[ダウナー,少女,語尾~,めんどくさがり]'
]
# キャラクターの名前と属性を抽出
character_names = [category.split('_')[1].split('[')[0] for category in category_list]
character_attributes = ['[' + category.split('[')[1] for category in category_list]
# カテゴリ選択用のセレクトボックス
selected_category = st.selectbox("1.目標キャラクターを選択", category_list)
# 選択されたカテゴリに対応する画像ファイル名の決定
# カテゴリリストのインデックスを取得し、それに1を加えることで1から始まる画像ファイル番号を作成
image_file_number = category_list.index(selected_category) + 1
image_path = f"img/{image_file_number}.png"
image_width=300
##----------------------------------------------------
judge_text = st.text_input("2.セリフを入力 //例: 貴方達も迷ったんですか?, よし、変身完了だ。, 勿論幽霊は抜きにして…でしょうね。//ですわ。,だろ!,ですね。","生徒たちの安全を守りたい。")
st.button("🔍 チェックする", on_click=on_button_click)
st.divider()
##----------------------------------------------------
# 画面を2つの列に分割
col1, col2 = st.columns([1, 4])
# 左側の列に画像を表示
with col1:
st.image(image_path, width=120)
# 右側の列にテキストボックスを配置
with col2:
if judge_text: # ユーザーが何か入力した場合のみ表示
st.markdown(custom_css, unsafe_allow_html=True) # カスタムCSSの適用
st.markdown(f'<div class="bubble">{judge_text}</div>', unsafe_allow_html=True)
selected_character=selected_category
##----------------------------------------------------
tab1, tab2, tab3 = st.tabs(["📈 話体", "💭 類似話題", "🔑 キーワード話題"])
data = np.random.randn(10, 1)
# 処理ステップ数に応じてプログレスバーを更新する関数
def update_progress(step, total_steps):
progress = int((step / total_steps) * 100)
progress_bar.progress(progress)
##----------------------------------------------------
total_steps = 5 # 処理を行う総ステップ数
if st.session_state.button_clicked:
if judge_text:
with tab1:
# プログレスバーの初期化
progress_bar = st.progress(0)
# トークナイズとテンソル化
words = tokenizer.tokenize(judge_text)
word_ids = tokenizer.convert_tokens_to_ids(words)
word_tensor = torch.tensor([word_ids[:512]]) # 最大長を512に制限
update_progress(1, total_steps)
# デバイスの自動選択
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
word_tensor = word_tensor.to(device)
model = model.to(device)
update_progress(2, total_steps)
# 推論
with torch.no_grad():
y = model(word_tensor)
update_progress(3, total_steps)
# 最も近いカテゴリの決定
pred = y.logits.argmax(-1)
update_progress(4, total_steps)
# 各クラスの確率計算
probabilities = torch.softmax(y.logits, dim=-1)
top_prob, top_cat_indices = probabilities.topk(len(category_list))
update_progress(5, total_steps)
top_probabilities = top_prob.cpu().numpy()[0]
top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]]
selected_character_index = top_categories.index(selected_character) if selected_character in top_categories else -1
# 選択したキャラクターがリスト内のどの位置にあるかを判定
if selected_character in top_categories:
selected_character_index = top_categories.index(selected_character)
if selected_character_index == 0:
result_text = "OKです!"
elif selected_character_index in [1, 2]:
result_text = "OKです"
else:
result_text = "違うかも?"
else:
# リストにキャラクターがない場合
result_text = "違うかも?"
st.session_state.result_text = result_text
# 結果の表示
st.write(result_text)
# st.write(f"最も近い: {category_list[pred.item()]}")
# 確率とカテゴリ名の準備
top_probabilities = top_prob.cpu().numpy()[0]
top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]]
# 新しい図と軸オブジェクトの作成
fig, ax = plt.subplots(figsize=(10, 6))
# すべての確率が0.2を下回っているかどうかをチェック
all_below_0_2 = all(probability < 0.2 for probability in top_probabilities)
# 確率とインデックスのタプルのリストを作成
# 棒グラフの作成、条件に応じて色を変更
for i, (probability, category) in enumerate(zip(top_probabilities, top_categories)):
# 確率とインデックスのタプルのリストを作成
probability_index_tuples = list(enumerate(top_probabilities))
# 確率でソートして上位3つを取得
sorted_tuples = sorted(probability_index_tuples, key=lambda x: x[1], reverse=True)
top_3_indices = [t[0] for t in sorted_tuples[:3]]
top_3_probabilities = [t[1] for t in sorted_tuples[:3]]
# 1位と2位の確率の差が大きいかどうかを評価
# ここでは例として、1位の確率が2位の確率よりも27.6%以上大きい場合を「大きい」と判断
is_first_place_significantly_higher = top_3_probabilities[0] - top_3_probabilities[1] > 0.276
# 棒グラフの作成、条件に応じて色を変更
for i, (probability, category) in enumerate(zip(top_probabilities, top_categories)):
# 1位が顕著に大きい場合、1位の棒をオレンジに設定
if is_first_place_significantly_higher and i == top_3_indices[0]:
color = 'orange'
elif probability < 0.1:
color = 'grey'
elif all_below_0_2:
color = 'grey'
else:
color = 'skyblue' # それ以外の場合の色
ax.bar(i, probability, color=color)
# 棒の上部または画像の上に数値を表示(パーセント表示に変更)
text_y = probability if probability <= 1 else 1
ax.text(i, text_y, f'{probability * 100:.1f}%', ha='center', va='bottom' if probability <= 1 else 'top')
# カテゴリリスト内でのカテゴリ名のインデックスを探し、その位置に基づいて画像ファイルを参照
for i, category in enumerate(top_categories):
# カテゴリリスト内の位置(インデックス+1)を使って画像ファイルパスを指定
position = category_list.index(category) + 1
img_path = f'img/{position}.png'
# 画像の読み込みと配置
image = plt.imread(img_path)
imagebox = OffsetImage(image, zoom=0.1)
ab = AnnotationBbox(imagebox, (i, 0), frameon=False, box_alignment=(0.5, -0.2))
ax.add_artist(ab)
# y軸の範囲設定を調整
# 縦軸の範囲設定と横線の描画を調整
max_probability = max(top_probabilities)
# 最大確率が0.3以上ならば、それに合わせてy軸の上限を設定
y_max = max(0.3, np.ceil(max_probability / 0.1) * 0.1)
ax.set_ylim(0, y_max)
ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1))
# 0.1刻みで横線を引く、0.3, 0.6, 0.9 は太くする
for y in np.arange(0.1, y_max + 0.1, 0.1):
if y in [0.3, 0.6, 0.9]:
ax.axhline(y=y, color='blue', linestyle='-', linewidth=2) # 太い横線
else:
ax.axhline(y=y, color='grey', linestyle='--', linewidth=0.5) # 通常の横線
ax.set_xlabel('', fontproperties=font_prop)
ax.set_ylabel('確率', fontproperties=font_prop)
ax.set_xticks(range(len(top_categories)))
ax.set_xticklabels(top_categories, rotation=45, ha="right", fontproperties=font_prop)
# ax.set_title('カテゴリ別確率', fontproperties=font_prop)
# Streamlitでグラフを表示
st.pyplot(fig)
progress_bar.progress(100)
time.sleep(0.5) # 1秒待機
progress_bar.empty() # プログレスバーを削除
##----------------------------------------------------
with tab2:
with st.spinner('処理中...'):
# モデルとトークナイザーの初期化
ruiji_model_name = 'cl-tohoku/bert-base-japanese-v3'
ruiji_tokenizer = BertJapaneseTokenizer.from_pretrained(ruiji_model_name)
ruiji_model = BertModel.from_pretrained(ruiji_model_name)
def extract_keywords(sentence, tokenizer, num_keywords=10, min_length=2):
# トークナイズして品詞タグを取得
tokens = tokenizer.tokenize(sentence)
# 文字数が min_length 以上のトークンのみを選択
# 一文字かつひらがなの単語を除外するフィルタリング条件を追加
filtered_tokens = [token for token in tokens if len(token) >= min_length and not re.match(r'^[ぁ-ん、。]$', token)]
return filtered_tokens[:num_keywords]
def find_sentences_with_specific_keywords(sentences, keywords):
results = {}
sentence_to_id = {} # 文とそのIDをマッピングする辞書
id_to_sentence = {} # IDと文をマッピングする辞書
first_reference = {} # 各IDに対して最初に参照されたキーワードを記録
next_id = 1 # 次に割り当てるID
for sentence in sentences:
for keyword in keywords:
if keyword in sentence:
if sentence not in sentence_to_id:
# 文に新しいIDを割り当て、最初の参照として記録
sentence_to_id[sentence] = next_id
id_to_sentence[next_id] = sentence
first_reference[next_id] = keyword # このIDが最初に参照されたキーワード
next_id += 1
# 結果にIDと共に文を追加
results.setdefault(keyword, []).append(sentence_to_id[sentence])
# IDを参照して文を取得し、重複を示す情報を付加して返す
formatted_results = {}
seen_sentences = set() # 既に出力された文のIDを記録
for keyword, ids in results.items():
formatted_sentences = []
for id in ids:
if id in seen_sentences:
# 既に出力された文は、"と同じ"で参照
formatted_sentence = f"({id})と同じ"
else:
# 初めて出力される文は通常通り表示し、seen_sentencesに追加
formatted_sentence = f"({id}) {id_to_sentence[id]}"
seen_sentences.add(id)
formatted_sentences.append(formatted_sentence)
formatted_results[keyword] = formatted_sentences
return formatted_results
def get_sentence_embedding(sentence):
inputs = ruiji_tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
outputs = ruiji_model(**inputs)
sentence_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().detach().numpy()
return sentence_embedding
def load_sentences(file_path):
with open(file_path, 'r', encoding='utf-8') as file:
sentences = file.readlines()
return [sentence.strip() for sentence in sentences]
def process_input(character_index, input_sentence, characters_list):
file_name = characters_list[character_index] + ".txt"
file_path = os.path.join("use14", file_name)
sentences = load_sentences(file_path)
input_embedding = get_sentence_embedding(input_sentence)
similarities = []
for sentence in sentences:
sentence_embedding = get_sentence_embedding(sentence)
similarity = cosine_similarity([input_embedding], [sentence_embedding])[0]
similarities.append((sentence, similarity[0]))
similar_sentences = sorted(similarities, key=lambda x: x[1], reverse=True)[:10]
# キーワードの抽出
keywords = extract_keywords(input_sentence, tokenizer)
# 各キーワードに関連する台詞の抽出
results = find_sentences_with_specific_keywords(sentences, keywords)
return similar_sentences, results # mentioned_sentences を results に置き換え
character_index = category_list.index(selected_category)
input_sentence = judge_text
characters_list=['5_紫上鏡一_教師大人タメ', '20_見嶋千里_怖め大人粗雑', '29_白鳥王子_キザ大人調子', '30_白城院素子_お嬢様少女丁寧', '32_水陰那月_ネガ少年敬語', '50_緋崎平一郎_元気少年やんちゃ', '76_百知瑠璃_元気少女です', '91_桜結衣_元気少女タメ', '101_御伽美夜子_落着大人女性口調', '121_荊棘従道_執事大人丁寧', '133_司馬萌香_姉御少女粗雑', '134_菜野花_落着少女敬語', '139_黒冬和馬_普通少年タメ', '142_四涼礼子_ダウナー少女語尾伸']
similar_sentences, mentioned_sentences = process_input(character_index, input_sentence, characters_list)
with st.container():
st.subheader("類似度話題:")
# 類似度の閾値
similarity_threshold = 0.85
# 類似度の高い台詞をDataFrameに変換
similar_df = pd.DataFrame(similar_sentences, columns=['台詞', '類似度'])
# 閾値より上のデータが存在するかどうかをチェック
if not similar_df[similar_df['類似度'] >= similarity_threshold].empty:
# 閾値より上のデータが存在する場合の処理
st.dataframe(similar_df[similar_df['類似度'] >= similarity_threshold].reset_index(drop=True).style.format({'類似度': "{:.2f}"}))
# 閾値より下のデータが存在するかどうかをチェック
if not similar_df[similar_df['類似度'] < similarity_threshold].empty:
# 閾値より下のデータが存在する場合、警告メッセージを表示
st.markdown("<span style='color:red;'>以下、信頼度が低い可能性があります</span>", unsafe_allow_html=True)
# 閾値より下のデータを表示
st.dataframe(similar_df[similar_df['類似度'] < similarity_threshold].reset_index(drop=True).style.format({'類似度': "{:.2f}"}))
with tab3:
with st.container():
# キーワードのリストを取得
keywords_list = list(mentioned_sentences.keys())
# キーワードを文字列として結合し、表示
st.subheader(f"キーワード話題: [{', '.join(keywords_list)}]")
# キーワードに関連する台詞を表示
for keyword, sentences in mentioned_sentences.items():
st.text(f"キーワード '{keyword}' に関連する既存台詞:")
# 台詞からIDを抽出し、元の文と一緒にリストに格納
sentences_with_id = [(int(re.search(r"\((\d+)\)", sentence).group(1)), sentence) for sentence in sentences]
# IDで並び替え
sorted_sentences_with_id = sorted(sentences_with_id, key=lambda x: x[0])
# 並び替えた後の台詞リストを作成(IDを除外)
sorted_sentences = [sentence for _, sentence in sorted_sentences_with_id]
# 台詞をリストに変換してDataFrameを作成
keyword_df = pd.DataFrame(sorted_sentences, columns=['台詞'])
# DataFrameを表示
st.dataframe(keyword_df)
##----------------------------------------------------
else:
st.write("入力してください。")
##----------------------------------------------------
# 選択されたキャラクターのインデックスに基づいてデータを取得する関数
def get_character_data(selected_index):
data = []
with open("result_ninsho.txt", "r", encoding="utf-8") as file:
for line in file:
parts = line.strip().split(',')
if int(parts[0]) == selected_index + 1:
data.append(parts)
return data
# 頻出単語データを整形してテーブルに表示する関数
def display_word_frequency_table(data):
hinshi_categories = ['一人称代名詞', '二人称代名詞', '三人称代名詞', 'その他の代名詞', '名詞', '動詞', '形容詞・形容動詞', '副詞', '助動詞', '感動詞']
df = pd.DataFrame(columns=['品詞', '頻出単語'])
rows_list = [] # 新しい行を一時的に格納するリスト
for row in data:
category_index = int(row[1]) - 1 # 品詞のカテゴリインデックス
words = row[2:]
word_counts = {}
for word in words:
if word not in word_counts:
word_counts[word] = 1
else:
word_counts[word] += 1
# 単語を出現頻度順に並べ替え
sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
formatted_words = [f"**{word}({count})**" if i == 0 else f"{word}({count})" for i, (word, count) in enumerate(sorted_words)]
# 新しい行をリストに追加
rows_list.append({'品詞': hinshi_categories[category_index], '頻出単語': ', '.join(formatted_words)})
# DataFrameに行を追加
df = pd.concat([df, pd.DataFrame(rows_list)], ignore_index=True)
return df
st.markdown(
"""
<style>
section[data-testid="stSidebar"] {
width: 500px !important;
}
</style>
""",
unsafe_allow_html=True
)
def display_df_as_html_table_with_bold_text(df):
# HTMLテーブルのヘッダーを定義
html = "<table>"
html += "<tr>"
for col in df.columns:
html += f"<th>{col}</th>"
html += "</tr>"
# DataFrameの各行データをHTMLテーブルの行として追加
for index, row in df.iterrows():
html += "<tr>"
for col in df.columns:
# 特定の文字列を太字にしたい場合、ここで条件をチェック
# ここでは簡単な例として、文字列内の数字が含まれていれば太字にします
if any(char.isdigit() for char in str(row[col])):
cell_value = f"<strong>{row[col]}</strong>"
else:
cell_value = row[col]
html += f"<td>{cell_value}</td>"
html += "</tr>"
html += "</table>"
# 完成したHTMLテーブルをStreamlitで表示
st.markdown(html, unsafe_allow_html=True)
def display_df_with_markdown(df):
for index, row in df.iterrows():
# 品詞列を表示
st.markdown(f"**品詞**: {row['品詞']}")
# 頻出単語列をMarkdown形式で表示(太字を含む)
st.markdown(f"**頻出単語**: {row['頻出単語']}", unsafe_allow_html=True)
def display_df_with_markdown_table(df):
# テーブルのヘッダー
markdown_table = "品詞 | 頻出単語\n-|-|\n"
# DataFrameの各行をループしてテーブルの行を作成
for _, row in df.iterrows():
# 頻出単語列のデータを取得し、カンマで分割
words = row['頻出単語'].split(', ')
# 単語のリストを処理
formatted_words = []
for word in words:
# 出現頻度を取得
freq = int(word[word.find("(")+1:word.find(")")])
if freq == 1:
# 出現頻度が1の場合、薄い文字で表示
formatted_word = f"<span style='color: #bbb;'>{word}</span>"
else:
formatted_word = word
formatted_words.append(formatted_word)
# 最も頻出する単語(リストの最初の要素)を太字に
if formatted_words:
formatted_words[0] = f"**{formatted_words[0]}**"
# 処理した単語のリストをカンマで結合
formatted_words_str = ', '.join(formatted_words)
# Markdownテーブルの行を追加
markdown_table += f"{row['品詞']} | {formatted_words_str}\n"
# 生成したMarkdownテーブルをStreamlitで表示
st.markdown(markdown_table, unsafe_allow_html=True)
##----------------------------------------------------
with st.sidebar:
st.title("台詞校正ツール")
st.header('キャラクター辞典')
selected_index = st.selectbox("キャラクターを選択", range(len(character_names)), format_func=lambda x: character_names[x])
# 選択されたキャラクターの情報をメイン画面に表示
character_name = character_names[selected_index]
character_attribute = character_attributes[selected_index]
image_path = f"img/full{selected_index + 1}.png" # img/1.pngから始まる
st.image(image_path, caption=f"{character_name}", width=250)
st.write(f"{character_name} の特徴: {character_attribute}")
character_names = [category.split('_')[1].split('[')[0] for category in category_list]
character_data = get_character_data(selected_index)
df = display_word_frequency_table(character_data)
# DataFrameを取得(この部分は既にあるコードを使用)
character_data = get_character_data(selected_index)
df = display_word_frequency_table(character_data)
# "頻出単語テーブル"としてセクションを表示
st.write("頻出単語テーブル")
# カスタムMarkdownテーブルでDataFrameの内容を表示
display_df_with_markdown_table(df)
import streamlit as st
# 画像データの例
image_path = "img/1.png"
# Expanderで画像を表示
expander = st.expander("ここをクリックして画像を表示")
with expander:
st.image(image_path, caption="画像のキャプション")