amp / app.py
ogawaal's picture
Update app.py
eb54619 verified
import streamlit as st
from inference import classify_ma, get_word_attn, infer_topic
from visualization import heatmap, html_hext
ID2CAT = {
0: "マイクロアグレッションではない可能性が高い",
1: "マイクロアグレッションの可能性が高い",
}
explanation_text = """
このマイクロアグレッションチェッカーは、機械学習(AI技術のようなもの)によって、マイクロアグレッションらしい表現を検出できるように設計されています。
"""
attention_text = """
【結果を見る際の注意点】
この技術は「文中にマイクロアグレッションに結びつく要素が含まれているかどうか」を判定するモデルであり、
必ずしも「この文章の書き手がマイクロアグレッションをしている」ことを明確に示すものではありません。
判定結果を元に、改めて人間同士で「なぜ/どのようにしてマイクロアグレッションたりうるか」議論をするために利用してください。
"""
provide_by = """提供元: オールマイノリティプロジェクト
[https://all-minorities.com/](https://all-minorities.com/)
"""
st.title("マイクロアグレッション判別モデル")
st.markdown(explanation_text)
user_input = st.text_input("文章を入力してください:", key="user_input")
if st.button("判定", key="run"):
if not user_input:
st.warning("入力が空です。何か入力してください。")
else:
pred_class, input_ids, attention_list = classify_ma(user_input)
st.markdown(f"判定結果: **{ID2CAT[pred_class]}**")
if pred_class == 1:
topic_dist, ll = infer_topic(user_input)
words_atten = get_word_attn(input_ids, attention_list)
html_hext_result = html_hext(((word, attn) for word, attn in words_atten))
st.markdown(html_hext_result, unsafe_allow_html=True)
data = topic_dist.reshape(-1, 1)
st.plotly_chart(heatmap(data), use_container_width=True)
st.divider()
st.markdown(attention_text)
st.divider()
st.markdown(provide_by)