# ziran / page_attitudes.py
# Last commit by krishaamer: 5d29fa6 "Add titles" (file size 2.52 kB)
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties
from fields.likert_fields import likert_fields
from fields.field_translation_mapping import field_translation_mapping
from fields.translation_mapping import translation_mapping
@st.cache_data
def show(df):
    """Render the "Students Attitudes (Overall)" page.

    For every category in ``likert_fields`` this writes a centred HTML
    heading, then one matplotlib figure laid out in two columns. The figure
    holds one seaborn count plot per Likert field in that category. Each
    subplot is titled with the original field name above its English
    translation from ``field_translation_mapping``.

    Args:
        df: Survey responses whose columns include the field names listed
            in ``likert_fields``. It is presumably a pandas DataFrame,
            since ``.rename(columns=...)`` is called on it (verify against
            the caller). When ``None``, only the page title and description
            are rendered.

    Returns:
        None. All output goes through Streamlit calls.

    NOTE(review): the function is wrapped in ``st.cache_data`` but returns
    nothing. It therefore relies on Streamlit replaying the elements
    emitted inside cached functions, and ``df`` must be hashable by
    Streamlit. Confirm this is intended.
    """
    st.title("Students Attitudes (Overall)")
    st.write("Students Attitudes across all Likert fields without clustering")

    # Chinese-capable font for the subplot titles. The path is relative, so
    # it resolves against the directory the app is launched from.
    chinese_font = FontProperties(fname='mingliu.ttf')

    if df is None:
        return

    # Map each original field name to its display label
    # "<field> (<English translation>)". Built once so the rename and the
    # count plots below are guaranteed to use identical column names.
    column_labels = {
        field: f"{field} ({field_translation_mapping[category][i]})"
        for category, fields in likert_fields.items()
        for i, field in enumerate(fields)
    }
    df_translated = df.rename(columns=column_labels)

    # One heading + one figure per category.
    for category, fields in likert_fields.items():
        num_fields = len(fields)
        if num_fields == 0:
            # plt.subplots(0, 2) raises ValueError, and there is nothing to plot.
            continue

        st.markdown(
            f"<h2 style='text-align: center;'>{translation_mapping[category]}</h2>", unsafe_allow_html=True)

        # Two plots per row, so ceil(num_fields / 2) rows.
        num_rows = -(-num_fields // 2)
        # squeeze=False always returns a 2-D array, so flatten() behaves the
        # same regardless of the row count.
        fig, axs = plt.subplots(
            num_rows, 2, figsize=(15, 5 * num_rows), squeeze=False)
        axs = axs.flatten()
        # Extra vertical padding so the two-line (Chinese + English) titles fit.
        fig.subplots_adjust(hspace=0.4)

        for i, field in enumerate(fields):
            ax = axs[i]
            sns.countplot(
                x=column_labels[field], data=df_translated, ax=ax, palette="coolwarm")
            title_english = field_translation_mapping[category][i]
            ax.set_title(f"{field}\n{title_english}", fontproperties=chinese_font)
            ax.set_xlabel('Likert Scale')
            ax.set_ylabel('Frequency')

        # An odd field count leaves the last grid cell empty; remove it.
        for ax in axs[num_fields:]:
            fig.delaxes(ax)

        st.pyplot(fig)
        # pyplot keeps every figure created by plt.subplots alive until it
        # is closed explicitly; release it once Streamlit has rendered it.
        plt.close(fig)