krishaamer committed on
Commit
e91f4d2
·
1 Parent(s): b43f093

Try to make a spider radar chart

Browse files
fields/feature_translations.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 14 fields: chart labels, one per survey field. Each entry is a short English
+ # tag, a newline, then a Traditional Chinese description. The radar chart
+ # pairs these positionally with prod_feat_flat_fields via zip(), so the order
+ # here has to match the order of that list.
2
+ feature_translations = [
3
+ "Eco Mat.\n在購買前了解產品的環保材料",
4
+ "Eco Pack.\n在購買前了解產品的環保包裝",
5
+ "Local Eco Biz.\n了解當地環保企業",
6
+ "CSR Review.\n投資前了解企業的CSR",
7
+ "Eco Friends\n用社交網絡認識環保朋友",
8
+ "Prod. Origin\n了解產品來源",
9
+ "Carbon Track.\n每月回顧個人碳足跡",
10
+ "Eco Services\n尋找環保服務",
11
+ "Eco News\n個性化環保新聞",
12
+ "Pre-buy CSR\n購買前了解企業的CSR",
13
+ "Eco Initiatives\n支持當地環保計劃",
14
+ "Eco Practices\n了解企業的環保實踐",
15
+ "No Extra Info\n不需要額外的環境資訊",
16
+ "CSR Compare\n投資前比較企業的CSR"
17
+ ]
fields/prod_feat_flat_fields.py CHANGED
@@ -1,3 +1,4 @@
 
1
  prod_feat_flat_fields = [
2
  '想像一下AI陪伴能幫你/妳回答超多問題了。你/妳比較想知道的環保資訊是哪些? (買東西先查看產品的運輸距離(是不是當地食品))',
3
  '想像一下AI陪伴能幫你/妳回答超多問題了。你/妳比較想知道的環保資訊是哪些? (買東西先查看公司生產過程多環保)',
 
1
+ # 14 fields
2
  prod_feat_flat_fields = [
3
  '想像一下AI陪伴能幫你/妳回答超多問題了。你/妳比較想知道的環保資訊是哪些? (買東西先查看產品的運輸距離(是不是當地食品))',
4
  '想像一下AI陪伴能幫你/妳回答超多問題了。你/妳比較想知道的環保資訊是哪些? (買東西先查看公司生產過程多環保)',
page_shopping.py CHANGED
@@ -1,12 +1,17 @@
1
  import streamlit as st
 
 
 
2
  import pandas as pd
 
 
3
 
4
-
5
- @st.cache_data
6
  def show(df):
7
  st.write("Clustering Students based on Product Feature choices")
8
  show_boycott_count(df)
9
-
 
10
 
11
  def show_boycott_count(df):
12
  # Count the number of people who have invested and who have not
@@ -18,3 +23,83 @@ def show_boycott_count(df):
18
 
19
  # Display the DataFrame as a table in Streamlit
20
  st.table(investment_table)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from kmodes.kmodes import KModes
3
+ from matplotlib.font_manager import FontProperties
4
+ import matplotlib.pyplot as plt
5
  import pandas as pd
6
+ from fields.prod_feat_flat_fields import prod_feat_flat_fields
7
+ from fields.feature_translations import feature_translations
8
 
9
# NOTE(review): the @st.cache_data decorator is commented out for show(); the
# reason is not recorded here -- confirm before re-enabling it.
def show(df):
    """Render the shopping page for the survey DataFrame ``df``.

    Writes the page heading, renders the table produced by
    ``show_boycott_count``, then clusters the respondents on the product
    feature columns and draws the persona radar chart.
    """
    st.write("Clustering Students based on Product Feature choices")
    show_boycott_count(df)

    # Cluster on the product-feature answers, then plot one trace per cluster.
    persona_clusters = perform_kmodes_clustering(df, feature_columns=prod_feat_flat_fields)
    show_radar_chart(
        persona_clusters,
        feature_translations=feature_translations,
        font_path='mingliu.ttf',
    )
15
 
16
  def show_boycott_count(df):
17
  # Count the number of people who have invested and who have not
 
23
 
24
  # Display the DataFrame as a table in Streamlit
25
  st.table(investment_table)
26
+
27
+
28
def perform_kmodes_clustering(df, feature_columns, n_clusters=3, random_state=None):
    """Cluster the rows of ``df`` with k-modes and split them by cluster.

    Args:
        df: DataFrame holding the answers to cluster. A ``'Cluster'`` column
            with each row's label is written to it in place (existing
            behaviour, kept for callers that read the column afterwards).
        feature_columns: Names of the columns to cluster on. Their values are
            cast with ``astype(int)`` (presumably boolean flags -- confirm
            against the caller); the cast raises on values it cannot convert.
        n_clusters: Number of clusters requested from k-modes.
        random_state: Optional seed forwarded to ``KModes``. The default
            ``None`` keeps the previous unseeded behaviour; pass an int to
            make the cluster assignment reproducible between runs.

    Returns:
        dict mapping every cluster label present in the result to the
        sub-DataFrame of the rows assigned to that cluster.
    """
    # K-modes operates on categorical codes, so feed it integer-encoded answers.
    encoded_answers = df[feature_columns].astype(int)

    model = KModes(
        n_clusters=n_clusters,
        init='Huang',
        n_init=5,
        verbose=1,
        random_state=random_state,
    )

    # Store the label next to each row so the per-cluster frames carry it too.
    df['Cluster'] = model.fit_predict(encoded_answers)

    # Only labels that actually occur get an entry; an empty cluster is absent.
    return {label: df[df['Cluster'] == label] for label in df['Cluster'].unique()}
51
+
52
+
53
def show_radar_chart(clusters, feature_translations, font_path):
    """Draw a radar (spider) chart of each cluster's mean feature values.

    Args:
        clusters: dict mapping cluster label -> DataFrame with that cluster's
            rows (the shape returned by ``perform_kmodes_clustering``).
        feature_translations: Either a sequence of axis labels listed in the
            same order as ``prod_feat_flat_fields``, or a dict mapping column
            name -> axis label.
        font_path: Path of the font file used for the axis labels and title
            (the labels contain Chinese text).

    The figure is rendered into the Streamlit page; nothing is returned.
    """
    # Local import: numpy is not among this module's top-level imports.
    import numpy as np

    # Accept both a plain label list (paired positionally with the survey
    # columns) and an explicit column -> label mapping.
    if isinstance(feature_translations, dict):
        label_by_column = dict(feature_translations)
    else:
        label_by_column = dict(zip(prod_feat_flat_fields, feature_translations))
    columns = list(label_by_column)
    axis_labels = list(label_by_column.values())

    if not clusters or not columns:
        st.write("No cluster data available for the radar chart.")
        return

    # One spoke per feature. The closing point is appended exactly once, and
    # only to the plotted series, so no duplicate spoke/label is created.
    spoke_angles = np.linspace(0, 2 * np.pi, len(columns), endpoint=False).tolist()
    closed_angles = spoke_angles + spoke_angles[:1]

    # Custom font so the non-Latin label text renders.
    font_properties = FontProperties(fname=font_path, size=12)
    fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))
    fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)

    # Feature labels around the circle.
    ax.set_xticks(spoke_angles)
    ax.set_xticklabels(axis_labels, color='grey', size=12, fontproperties=font_properties)

    # Radial grid; the plotted values are column means clipped to the 0..1 view.
    ax.set_rlabel_position(0)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1])
    ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8", "1"], color="grey", size=7)
    ax.set_ylim(0, 1)

    # One trace per cluster that actually exists, in label order, instead of
    # assuming labels 0, 1 and 2 are always present.
    for position, cluster_label in enumerate(sorted(clusters), start=1):
        averages = clusters[cluster_label][columns].mean().tolist()
        closed_values = averages + averages[:1]
        ax.plot(
            closed_angles,
            closed_values,
            label=f'Persona {position} (Cluster {cluster_label})',
            linewidth=1,
            linestyle='solid',
        )
        ax.fill(closed_angles, closed_values, alpha=0.25)

    ax.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
    ax.set_title(
        'Product Feature Preferences by Persona',
        size=20,
        color='grey',
        y=1.1,
        fontproperties=font_properties,
    )

    # plt.show() does not display anything inside a Streamlit app; hand the
    # figure to Streamlit and then release it so reruns do not pile up figures.
    st.pyplot(fig)
    plt.close(fig)