ziran / page_ai.py
krishaamer's picture
Show radar chart for Likert clusters; add tabs
b2ce8cc
import streamlit as st
from kmodes.kmodes import KModes
import pandas as pd
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from fields.prod_feat_flat_fields import prod_feat_flat_fields
from fields.feature_translations import feature_translations
from fields.likert_flat_fields import likert_flat_fields
#@st.cache_data
def show(df):
    """Render the "AI Companion" page with two clustering tabs.

    Tab 1 clusters respondents on their Likert answers and radar-plots their
    feature choices. Tab 2 clusters them on AI-assistant feature choices
    (K-modes) and shows the radar chart, the per-cluster bar chart and the
    preferred-AI-roles bar chart.

    Note: the helpers called here modify ``df`` in place (column names are
    stripped of whitespace, and a 'Cluster' column is added).

    Args:
        df: Survey response DataFrame.
    """
    # Font used by every chart so Chinese labels render correctly.
    chinese_font = FontProperties(fname='notosans.ttf', size=12)

    st.title("AI Companion")
    likert_tab, feature_tab = st.tabs(["Likert-Based Clustering", "Feature-Based Clustering"])

    with likert_tab:
        st.write("AI-assistant feature choices per Likert-based Personas")
        likert_cluster_and_visualize(df, likert_flat_fields, chinese_font)

    with feature_tab:
        st.write("Clustering students based on AI-assistant feature choices")
        clusters = perform_kmodes_clustering(df, prod_feat_flat_fields)

        _centered_heading("Feature Preferences (Overall)")
        show_radar_chart(clusters, font_prop=chinese_font)

        _centered_heading("Feature Preferences (By Cluster)")
        plot_feature_preferences(clusters, font_prop=chinese_font)

        _centered_heading("Preferred AI Roles (Overall)")
        visualize_ai_roles(df, chinese_font)


def _centered_heading(text):
    """Render ``text`` as a centred <h2> heading using raw HTML markdown."""
    st.markdown(f"<h2 style='text-align: center;'>{text}</h2>", unsafe_allow_html=True)
def visualize_ai_roles(df, chinese_font, top_n=20):
    """Plot a bar chart of the most frequently requested AI roles.

    Args:
        df: Survey responses. Must contain the column
            "你/妳想要AI扮什麼角色?" ("what role do you want the AI to play?").
            If a free-text "其他" ("other") column exists, its text is appended
            to each role answer before counting. ``df`` is not modified.
        chinese_font: FontProperties applied to the title, axis labels and
            tick labels so Chinese role names render.
        top_n: Number of most frequent answers to plot (default 20).
    """
    role_counts = _count_ai_roles(df, top_n)

    fig, ax = plt.subplots(figsize=(10, 6))
    role_counts.plot(kind='bar', color='skyblue', ax=ax)
    ax.set_title('Preferred AI Roles', fontproperties=chinese_font)
    ax.set_xlabel('Roles', fontproperties=chinese_font)
    ax.set_ylabel('Number of Responses', fontproperties=chinese_font)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontproperties=chinese_font)
    fig.tight_layout()

    # Render this specific figure, then release it. Streamlit reruns the
    # script on every interaction, and unclosed figures accumulate.
    st.pyplot(fig)
    plt.close(fig)


def _count_ai_roles(df, top_n):
    """Return counts of the ``top_n`` most common AI-role answers without mutating ``df``."""
    roles = df["你/妳想要AI扮什麼角色?"]
    if "其他" in df.columns:
        combined = roles.str.cat(df["其他"], na_rep='', sep=' ').str.strip()
        # Joining with sep=' ' leaves trailing blanks. A row with both fields
        # missing becomes an empty string; treat it as missing so it is not
        # counted as a role.
        roles = combined.where(combined != '')
    return roles.value_counts().head(top_n)
def perform_kmodes_clustering(df, feature_columns, n_clusters=3, random_state=42):
    """Cluster respondents on their feature choices with K-modes.

    Args:
        df: Survey responses. Side effect: a 'Cluster' column holding the
            assigned labels is added to this DataFrame in place.
        feature_columns: Columns to cluster on. They are cast with
            ``astype(int)``, presumably from booleans (confirm upstream).
        n_clusters: Number of clusters to form.
        random_state: Seed for K-modes initialisation. It is fixed by default
            (same seed as the Likert KMeans) so the cluster-label -> persona
            mapping used by the charts is reproducible across reruns. Pass
            ``None`` for unseeded behaviour.

    Returns:
        dict mapping each cluster label to the sub-DataFrame of its rows.
    """
    encoded = df[feature_columns].astype(int)

    km = KModes(n_clusters=n_clusters, init='Huang', n_init=5, verbose=1,
                random_state=random_state)
    df['Cluster'] = km.fit_predict(encoded)

    return {label: members for label, members in df.groupby('Cluster')}
def show_radar_chart(clusters, font_prop):
    """Draw a radar chart of mean feature-selection rates for each K-modes cluster.

    Args:
        clusters: Mapping of cluster label (0, 1, 2) to the DataFrame of that
            cluster's rows, as returned by ``perform_kmodes_clustering``. Each
            DataFrame must contain the product-feature columns.
        font_prop: FontProperties used for the spoke labels and the title.
    """
    # NOTE(review): persona names are attached to cluster labels 0/1/2 by
    # position only; confirm the mapping whenever the clustering changes.
    persona_names = ['Conscious', 'Interested', 'Advocate']
    df_dict = {
        f"{name} (n={len(clusters[label])})": clusters[label]
        for label, name in enumerate(persona_names)
    }

    # Keep feature columns and their English labels aligned pairwise.
    feature_translations_dict = dict(zip(prod_feat_flat_fields, feature_translations))
    feature_columns = list(feature_translations_dict)
    feature_labels = list(feature_translations_dict.values())

    # One spoke per feature, evenly spaced. The first angle is repeated only
    # on the plotted polygon to close it; spokes and labels are not repeated,
    # otherwise the first feature would get a second spoke.
    angles = np.linspace(0, 2 * np.pi, len(feature_columns), endpoint=False).tolist()
    closed_angles = angles + angles[:1]

    fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))
    fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
    ax.set_xticks(angles)
    ax.set_xticklabels(feature_labels, color='grey', size=12, fontproperties=font_prop)

    # Radial axis: feature means lie in [0, 1].
    ax.set_rlabel_position(0)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1])
    ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8", "1"], color="grey", size=7)
    ax.set_ylim(0, 1)

    for label, cluster_df in df_dict.items():
        values = cluster_df[feature_columns].mean().tolist()
        values += values[:1]  # close the polygon exactly once
        ax.plot(closed_angles, values, label=label, linewidth=1, linestyle='solid')
        ax.fill(closed_angles, values, alpha=0.25)

    # Title and placement go in a single legend() call. Calling legend()
    # twice replaces the first legend.
    ax.legend(title='Personas', loc='upper right', bbox_to_anchor=(0.1, 0.1))
    ax.set_title('Product Feature Preferences by Persona', size=20, color='grey',
                 y=1.1, fontproperties=font_prop)

    st.pyplot(fig)
    plt.close(fig)  # free the figure; Streamlit reruns create a new one each time
def plot_feature_preferences(clusters, font_prop):
    """Draw a grouped bar chart comparing product-feature preferences by persona.

    NOTE(review): ``clusters`` is currently unused. The chart is drawn from the
    hard-coded comparison table below, not from the clusters passed in, so it
    will not reflect a re-run clustering. The parameter is kept so existing
    callers keep working.

    Args:
        clusters: Unused (see note above).
        font_prop: FontProperties used for the title, axis labels and the
            bilingual (Chinese/English) feature tick labels.
    """
    # Comparative table: one bilingual label per feature and one average
    # score per persona.
    data = {
        'Feature': [
            "買東西先查看產品的運輸距離(是不是當地食品)\nCheck Product Transportation Distance (Whether it is Local Food)",
            "買東西先查看公司生產過程 多環保\nCheck Company's Eco-Friendly Production Process",
            "買東西先查看公司工人的員工 福利多好\nCheck Company Worker Welfare",
            "投資前查看AI摘要的消費者對公司環保的評論\nReview Consumer Eco Comments on Companies before Investing",
            "用社交網絡認識其他環保的同學\nMeet Eco-Friendly Peers on Social Networks",
            "買東西先了解哪些產品污染最嚴重,以便避免它們\nUnderstand Which Products are Most Polluting to Avoid Them",
            "每個月查看我自己的環保分數報告(了解我用錢的方式多環保)\nMonthly Review of My Eco Score Report (Understanding How My Spending is Eco-Friendly)",
            "買東西先尋找有機產品\nSearch for Organic Products Before Purchasing",
            "跟我的AI幫手討論環保問題\nDiscuss Environmental Issues with My AI Assistant",
            "投資前查看公司的環保認證和生態評分\nCheck Company's Environmental Certifications and Eco-Scores Before Investing",
            "如何讓我支持的公司更環保\nHow to Make the Companies I Support More Eco-Friendly",
            "買東西先了解我吃的動物性食品動物的生活環境\nUnderstand the Living Conditions of Animals for the Animal Products I Consume",
            "老實說我對任何環保資訊都沒有太多興趣\nHonestly, I'm Not Very Interested in Any Eco Information",
            "投資前比較公司的環保表現\nCompare Companies' Environmental Performance Before Investing"
        ],
        'Conscious (n=340)': [0.367, 0.415, 0.191, 0.176, 0.079, 1.000, 0.197, 0.265, 0.144, 0.241, 0.144, 0.332, 0.044, 0.188],
        'Interested (n=215)': [0.260, 0.163, 0.153, 0.191, 0.107, 0.000, 0.135, 0.219, 0.172, 0.186, 0.093, 0.214, 0.233, 0.130],
        'Advocate (n=126)': [0.825, 0.881, 0.460, 0.746, 0.230, 0.881, 0.667, 0.690, 0.421, 0.865, 0.468, 0.778, 0.143, 0.738]
    }
    df = pd.DataFrame(data).set_index('Feature')

    fig, ax = plt.subplots(figsize=(14, 8))
    df.plot(kind='bar', width=0.8, ax=ax)
    ax.set_title('Comparison of Product Feature Preferences by Persona', fontproperties=font_prop)
    ax.set_ylabel('Average Score', fontproperties=font_prop)
    ax.set_xlabel('Feature', fontproperties=font_prop)
    # A single call sets font, rotation and alignment of the bilingual labels.
    ax.set_xticklabels(df.index, fontproperties=font_prop, rotation=45, ha='right')
    ax.legend(title='Personas')
    fig.tight_layout()

    st.pyplot(fig)
    plt.close(fig)  # release the figure after Streamlit has rendered it
def likert_cluster_and_visualize(df, likert_flat_fields, chinese_font):
    """Cluster respondents on Likert answers and radar-plot their feature choices.

    Respondents who answered every Likert item are grouped into three clusters
    with k-means (seed 42). For each cluster, the mean of every
    product-feature column is drawn on one shared radar chart.

    Args:
        df: Survey responses. Side effect: its column names are stripped of
            surrounding whitespace in place.
        likert_flat_fields: Likert column names to cluster on. Whitespace is
            stripped before lookup.
        chinese_font: FontProperties used for the spoke labels and the title.
    """
    df.columns = [col.strip() for col in df.columns]
    likert_fields = [field.strip() for field in likert_flat_fields]

    # Only respondents with a complete set of Likert answers can be clustered.
    df_likert_data = df[likert_fields].dropna()
    kmeans = KMeans(n_clusters=3, n_init=10, random_state=42).fit(df_likert_data)

    # Attach labels by index, so rows dropped above get NaN and fall into no
    # cluster. `assign` replaces any existing 'Cluster' column (e.g. if `df`
    # already carries one from an earlier clustering) instead of creating a
    # duplicate column, which would break the equality filter below.
    cluster_labels = pd.Series(kmeans.labels_, index=df_likert_data.index)
    df_clustered = df.assign(Cluster=cluster_labels)

    # NOTE(review): persona names are tied to k-means labels 0/1/2 by
    # position only; confirm the mapping if the data or seed changes.
    persona_names = ['Eco-Friendly', 'Moderate', 'Frugal']

    # Keep feature columns and their English labels aligned pairwise.
    feature_translations_dict = dict(zip(prod_feat_flat_fields, feature_translations))
    feature_columns = list(feature_translations_dict)
    feature_labels = list(feature_translations_dict.values())

    # One spoke per feature. Only the plotted polygon repeats the first
    # angle to close itself; spokes and labels are not repeated.
    angles = np.linspace(0, 2 * np.pi, len(feature_columns), endpoint=False).tolist()
    closed_angles = angles + angles[:1]

    fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))
    fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
    ax.set_xticks(angles)
    ax.set_xticklabels(feature_labels, color='grey', size=12, fontproperties=chinese_font)

    # Radial axis: feature means lie in [0, 1].
    ax.set_rlabel_position(0)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1])
    ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8", "1"], color="grey", size=7)
    ax.set_ylim(0, 1)

    for cluster_label, persona in enumerate(persona_names):
        members = df_clustered[df_clustered['Cluster'] == cluster_label]
        values = members[feature_columns].mean().tolist()
        values += values[:1]  # close the polygon exactly once
        ax.plot(closed_angles, values, label=persona, linewidth=1, linestyle='solid')
        ax.fill(closed_angles, values, alpha=0.25)

    # Title and placement go in a single legend() call. Calling legend()
    # twice replaces the first legend.
    ax.legend(title='Personas', loc='upper right', bbox_to_anchor=(0.1, 0.1))
    ax.set_title('Product Feature Preferences by Persona', size=20, color='grey',
                 y=1.1, fontproperties=chinese_font)

    st.pyplot(fig)
    plt.close(fig)  # free the figure; Streamlit reruns create a new one each time