Spaces:
Sleeping
Sleeping
import streamlit as st | |
from kmodes.kmodes import KModes | |
import pandas as pd | |
from matplotlib.font_manager import FontProperties | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import numpy as np | |
from sklearn.cluster import KMeans | |
from sklearn.decomposition import PCA | |
from fields.prod_feat_flat_fields import prod_feat_flat_fields | |
from fields.feature_translations import feature_translations | |
from fields.likert_flat_fields import likert_flat_fields | |
#@st.cache_data | |
def show(df):
    """Render the "AI Companion" page with its two clustering tabs.

    The first tab shows AI-assistant feature choices per Likert-based
    persona. The second tab clusters students on their feature choices and
    then shows the overall radar chart, the per-cluster bar chart and the
    preferred-AI-roles chart.

    Args:
        df: Survey responses; passed on as-is to the clustering and plotting
            helpers of this module.
    """
    # Font handed to every chart so Chinese labels can be rendered. The path
    # is relative, so it is resolved against the current working directory.
    chinese_font = FontProperties(fname='notosans.ttf', size=12)

    def centered_heading(text):
        # Section headings are emitted as raw HTML with centred alignment.
        st.markdown(
            f"<h2 style='text-align: center;'>{text}</h2>", unsafe_allow_html=True)

    st.title("AI Companion")
    likert_tab, feature_tab = st.tabs(["Likert-Based Clustering", "Feature-Based Clustering"])

    with likert_tab:
        st.write("AI-assistant feature choices per Likert-based Personas")
        likert_cluster_and_visualize(df, likert_flat_fields, chinese_font)

    with feature_tab:
        st.write("Clustering students based on AI-assistant feature choices")
        clusters = perform_kmodes_clustering(df, prod_feat_flat_fields)

        centered_heading("Feature Preferences (Overall)")
        show_radar_chart(clusters, font_prop=chinese_font)

        centered_heading("Feature Preferences (By Cluster)")
        plot_feature_preferences(clusters, font_prop=chinese_font)

        # NOTE(review): this section is assumed to belong to the second tab
        # because it follows the same heading + chart pattern -- confirm.
        centered_heading("Preferred AI Roles (Overall)")
        visualize_ai_roles(df, chinese_font)
def visualize_ai_roles(df, chinese_font):
    """Plot the 20 most frequent answers to the "AI role" question.

    When the DataFrame has an "其他" (other) column, its free-text answer is
    appended to the role answer before counting. The bar chart is rendered
    into the Streamlit page.

    Args:
        df: Survey responses; must contain the "你/妳想要AI扮什麼角色?" column.
            The DataFrame is only read, never modified.
        chinese_font: Matplotlib ``FontProperties`` used for the title, the
            axis labels and the tick labels.

    Returns:
        None. The chart is drawn as a side effect.
    """
    role_column = "你/妳想要AI扮什麼角色?"
    other_column = "其他"

    # Work on a local Series instead of writing the merged text back into
    # ``df``: writing it back would append the "other" answer again each time
    # this function is called on the same DataFrame.
    roles = df[role_column]
    if other_column in df.columns:
        roles = roles.str.cat(df[other_column], na_rep='', sep=' ')
        # Joining with na_rep='' leaves a stray separator whenever one side is
        # missing; trim it so "teacher " and "teacher" count as one category
        # and rows with no answer at all are not counted as a blank role.
        roles = roles.str.strip()
        roles = roles[roles != '']

    top_roles = roles.value_counts().head(20)
    if top_roles.empty:
        # Plotting an empty Series raises inside pandas; tell the user instead.
        st.info("No AI role responses to display.")
        return

    # Use an explicit figure rather than the global pyplot state so that
    # exactly this chart is the one handed to Streamlit.
    fig, ax = plt.subplots(figsize=(10, 6))
    top_roles.plot(kind='bar', color='skyblue', ax=ax)
    ax.set_title('Preferred AI Roles', fontproperties=chinese_font)
    ax.set_xlabel('Roles', fontproperties=chinese_font)
    ax.set_ylabel('Number of Responses', fontproperties=chinese_font)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontproperties=chinese_font)
    fig.tight_layout()

    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; release it so that
    # figures do not pile up over repeated calls.
    plt.close(fig)
def perform_kmodes_clustering(df, feature_columns, n_clusters=3, random_state=42):
    """Cluster respondents with k-modes and split the data by cluster.

    Args:
        df: Survey responses. A ``'Cluster'`` column holding the assigned
            label is added to (or overwritten in) this DataFrame in place.
        feature_columns: Columns used for clustering. Their values are cast to
            ``int``, so missing values in these columns make the cast raise.
        n_clusters: Number of clusters to fit.
        random_state: Seed passed to ``KModes``. A fixed default keeps the
            cluster numbering reproducible between runs, which matters for
            callers that attach fixed names to cluster ids. Pass ``None`` for
            an unseeded run.

    Returns:
        dict mapping each cluster label that was actually assigned to the
        sub-DataFrame of ``df`` (including the ``'Cluster'`` column) holding
        that cluster's rows.
    """
    # k-modes compares category values; the boolean flags are cast to 0/1.
    encoded_features = df[feature_columns].astype(int)

    km = KModes(n_clusters=n_clusters, init='Huang', n_init=5, verbose=1,
                random_state=random_state)
    labels = km.fit_predict(encoded_features)

    # Kept as an in-place write: the returned frames carry this column.
    df['Cluster'] = labels

    return {label: df[df['Cluster'] == label] for label in df['Cluster'].unique()}
def show_radar_chart(clusters, font_prop):
    """Draw a radar chart of mean feature preference per feature-based persona.

    Each persona is one polygon whose vertices are the column means of the
    product-feature fields within that cluster.

    Args:
        clusters: Mapping of cluster id (0, 1, 2) to the DataFrame holding
            that cluster's rows, as returned by ``perform_kmodes_clustering``.
            Cluster ids missing from the mapping are skipped.
        font_prop: Matplotlib ``FontProperties`` used for the axis labels and
            the title.

    Returns:
        None. The chart is rendered into the Streamlit page.
    """
    persona_names = {0: 'Conscious', 1: 'Interested', 2: 'Advocate'}

    # Pair every data column with its display label once, so that the plotted
    # values and the axis labels cannot get out of step.
    feature_pairs = list(zip(prod_feat_flat_fields, feature_translations))
    feature_columns = [column for column, _ in feature_pairs]
    axis_labels = [label for _, label in feature_pairs]

    # The legend shows the actual size of each cluster rather than a fixed
    # number that goes stale as soon as the data changes.
    persona_series = []
    for cluster_id, persona in persona_names.items():
        if cluster_id not in clusters:
            continue
        cluster_df = clusters[cluster_id]
        label = f'{persona} (n={len(cluster_df)})'
        persona_series.append((label, cluster_df[feature_columns].mean().tolist()))

    # One spoke per feature. Only the polygon is closed by repeating its first
    # point; counting the repeated point as a feature would add an extra spoke
    # that shows the first feature twice.
    num_vars = len(axis_labels)
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
    closed_angles = angles + angles[:1]

    fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))
    fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)

    # Feature labels around the circle.
    ax.set_xticks(angles)
    ax.set_xticklabels(axis_labels, color='grey', size=12, fontproperties=font_prop)

    # Radial grid labels; the means plotted here are read on a 0..1 scale.
    ax.set_rlabel_position(0)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1])
    ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8", "1"], color="grey", size=7)
    ax.set_ylim(0, 1)

    for label, values in persona_series:
        closed_values = values + values[:1]
        ax.plot(closed_angles, closed_values, label=label, linewidth=1, linestyle='solid')
        ax.fill(closed_angles, closed_values, alpha=0.25)

    # A single legend call: a second call would replace the first legend and
    # drop its title.
    ax.legend(title='Personas', loc='upper right', bbox_to_anchor=(0.1, 0.1))

    ax.set_title('Product Feature Preferences by Persona', size=20, color='grey', y=1.1,
                 fontproperties=font_prop)

    st.pyplot(fig)
    # Release the figure; pyplot keeps it alive until it is closed.
    plt.close(fig)
def plot_feature_preferences(clusters, font_prop):
    """Draw a grouped bar chart comparing feature preferences across personas.

    NOTE(review): the plotted values are the fixed table below; the
    ``clusters`` argument is accepted but not read, so this chart does not
    follow the clustering result -- confirm whether that is intended.

    Args:
        clusters: Unused (see note above).
        font_prop: Matplotlib ``FontProperties`` used for the title, the axis
            labels and the bilingual tick labels.

    Returns:
        None. The chart is rendered into the Streamlit page.
    """
    # Comparative table: one row per feature (Chinese label, line break,
    # English label) and one column of average scores per persona.
    data = {
        'Feature': [
            "買東西先查看產品的運輸距離(是不是當地食品)\nCheck Product Transportation Distance (Whether it is Local Food)",
            "買東西先查看公司生產過程 多環保\nCheck Company's Eco-Friendly Production Process",
            "買東西先查看公司工人的員工 福利多好\nCheck Company Worker Welfare",
            "投資前查看AI摘要的消費者對公司環保的評論\nReview Consumer Eco Comments on Companies before Investing",
            "用社交網絡認識其他環保的同學\nMeet Eco-Friendly Peers on Social Networks",
            "買東西先了解哪些產品污染最嚴重,以便避免它們\nUnderstand Which Products are Most Polluting to Avoid Them",
            "每個月查看我自己的環保分數報告(了解我用錢的方式多環保\nMonthly Review of My Eco Score Report (Understanding How My Spending is Eco-Friendly)",
            # The line break belongs between the Chinese and the English text,
            # as in every other label, not in the middle of the English phrase.
            "買東西先尋找有機產品\nSearch for Organic Products Before Purchasing",
            "跟我的AI幫手討論環保問題\nDiscuss Environmental Issues with My AI Assistant",
            "投資前查看公司的環保認證和生態評分\nCheck Company's Environmental Certifications and Eco-Scores Before Investing",
            "如何讓我支持的公司更環保\nHow to Make the Companies I Support More Eco-Friendly",
            "買東西先了解我吃的動物性食品動物的生活環境\nUnderstand the Living Conditions of Animals for the Animal Products I Consume",
            "老實說我對任何環保資訊都沒有太多興趣\nHonestly, I'm Not Very Interested in Any Eco Information",
            "投資前比較公司的環保表現\nCompare Companies' Environmental Performance Before Investing",
        ],
        'Conscious (n=340)': [0.367, 0.415, 0.191, 0.176, 0.079, 1.000, 0.197, 0.265, 0.144, 0.241, 0.144, 0.332, 0.044, 0.188],
        'Interested (n=215)': [0.260, 0.163, 0.153, 0.191, 0.107, 0.000, 0.135, 0.219, 0.172, 0.186, 0.093, 0.214, 0.233, 0.130],
        'Advocate (n=126)': [0.825, 0.881, 0.460, 0.746, 0.230, 0.881, 0.667, 0.690, 0.421, 0.865, 0.468, 0.778, 0.143, 0.738],
    }

    # Features become the index so that each one is a group of bars.
    df = pd.DataFrame(data).set_index('Feature')

    fig, ax = plt.subplots(figsize=(14, 8))
    df.plot(kind='bar', width=0.8, ax=ax)

    ax.set_title('Comparison of Product Feature Preferences by Persona', fontproperties=font_prop)
    ax.set_ylabel('Average Score', fontproperties=font_prop)
    ax.set_xlabel('Feature', fontproperties=font_prop)

    # The tick labels contain Chinese text, so they need the custom font;
    # styling them once is enough.
    ax.set_xticklabels(df.index, fontproperties=font_prop, rotation=45, ha='right')
    ax.legend(title='Personas')

    # Keep the long rotated labels inside the figure.
    fig.tight_layout()

    st.pyplot(fig)
    # Release the figure; pyplot keeps it alive until it is closed.
    plt.close(fig)
def likert_cluster_and_visualize(df, likert_flat_fields, chinese_font):
    """Cluster respondents on their Likert answers and chart feature preferences.

    Rows with a complete set of Likert answers are grouped into three k-means
    clusters. For each cluster the mean of every product-feature column is
    drawn as one polygon of a radar chart.

    Args:
        df: Survey responses. Its column names are stripped of surrounding
            whitespace in place.
        likert_flat_fields: Names of the Likert columns used for clustering
            (stripped of surrounding whitespace before use).
        chinese_font: Matplotlib ``FontProperties`` used for the axis labels
            and the title.

    Returns:
        None. The chart is rendered into the Streamlit page.
    """
    # Cluster ids 0, 1, 2 are shown under these names, in this order.
    persona_names = ['Eco-Friendly', 'Moderate', 'Frugal']

    # Kept as an in-place rename: code running on the same DataFrame
    # afterwards may rely on the cleaned column names.
    df.columns = [col.strip() for col in df.columns]
    likert_columns = [field.strip() for field in likert_flat_fields]

    # k-means cannot handle missing values, so only complete rows are used.
    likert_data = df[likert_columns].dropna()
    if len(likert_data) < len(persona_names):
        # KMeans raises when there are fewer samples than clusters.
        st.warning("Not enough complete Likert responses to form clusters.")
        return

    kmeans = KMeans(n_clusters=len(persona_names), n_init=10, random_state=42).fit(likert_data)

    # Keep the labels in their own Series, aligned on the row index, instead
    # of concatenating a 'Cluster' column onto ``df``: if ``df`` already has a
    # column with that name, the concatenation produces two of them and the
    # per-cluster filter no longer selects rows.
    labels = pd.Series(kmeans.labels_, index=likert_data.index)

    # Pair every data column with its display label once, so that the plotted
    # values and the axis labels cannot get out of step.
    feature_pairs = list(zip(prod_feat_flat_fields, feature_translations))
    feature_columns = [column for column, _ in feature_pairs]
    axis_labels = [label for _, label in feature_pairs]

    persona_series = []
    for cluster_id, persona in enumerate(persona_names):
        member_rows = labels.index[labels == cluster_id]
        means = df.loc[member_rows, feature_columns].mean().tolist()
        persona_series.append((persona, means))

    # One spoke per feature. Only the polygon is closed by repeating its first
    # point; counting the repeated point as a feature would add an extra spoke
    # that shows the first feature twice.
    num_vars = len(axis_labels)
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
    closed_angles = angles + angles[:1]

    fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))
    fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)

    # Feature labels around the circle.
    ax.set_xticks(angles)
    ax.set_xticklabels(axis_labels, color='grey', size=12, fontproperties=chinese_font)

    # Radial grid labels; the means plotted here are read on a 0..1 scale.
    ax.set_rlabel_position(0)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1])
    ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8", "1"], color="grey", size=7)
    ax.set_ylim(0, 1)

    for label, values in persona_series:
        closed_values = values + values[:1]
        ax.plot(closed_angles, closed_values, label=label, linewidth=1, linestyle='solid')
        ax.fill(closed_angles, closed_values, alpha=0.25)

    # A single legend call: a second call would replace the first legend and
    # drop its title.
    ax.legend(title='Personas', loc='upper right', bbox_to_anchor=(0.1, 0.1))

    ax.set_title('Product Feature Preferences by Persona', size=20, color='grey', y=1.1,
                 fontproperties=chinese_font)

    st.pyplot(fig)
    # Release the figure; pyplot keeps it alive until it is closed.
    plt.close(fig)