krishaamer committed on
Commit
a9992ee
·
1 Parent(s): 1a1ff20

Add correlation charts

Browse files
Files changed (1) hide show
  1. page_attitudes.py +88 -0
page_attitudes.py CHANGED
@@ -1,8 +1,12 @@
1
  import streamlit as st
2
  import matplotlib.pyplot as plt
3
  import seaborn as sns
 
 
 
4
  from matplotlib.font_manager import FontProperties
5
  from fields.likert_fields import likert_fields
 
6
  from fields.field_translation_mapping import field_translation_mapping
7
  from fields.translation_mapping import translation_mapping
8
 
@@ -15,6 +19,14 @@ def show(df):
15
  # Chinese font
16
  chinese_font = FontProperties(fname='mingliu.ttf')
17
 
 
 
 
 
 
 
 
 
18
  if df is not None:
19
 
20
  # Rename the columns in the DataFrame for visualization
@@ -61,3 +73,79 @@ def show(df):
61
 
62
  # Show the plot
63
  st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import matplotlib.pyplot as plt
3
  import seaborn as sns
4
+ import networkx as nx
5
+ import pandas as pd
6
+ import numpy as np
7
  from matplotlib.font_manager import FontProperties
8
  from fields.likert_fields import likert_fields
9
+ from fields.likert_flat_fields import likert_flat_fields
10
  from fields.field_translation_mapping import field_translation_mapping
11
  from fields.translation_mapping import translation_mapping
12
 
 
19
  # Chinese font
20
  chinese_font = FontProperties(fname='mingliu.ttf')
21
 
22
+ #generate_correlation_chart(df, chinese_font)
23
+ create_likert_charts(df, chinese_font)
24
+
25
+ st.title("Correlations Between Fields")
26
+ create_correlation_network(df, 0.4, chinese_font)
27
+
28
+
29
+ def create_likert_charts(df, chinese_font):
30
  if df is not None:
31
 
32
  # Rename the columns in the DataFrame for visualization
 
73
 
74
  # Show the plot
75
  st.pyplot(fig)
76
+
77
def create_correlation_network(df, threshold, chinese_font):
    """Render a network graph of strongly correlated Likert fields.

    Each node is a column listed in ``likert_flat_fields``; an edge joins two
    columns whose pairwise correlation has an absolute value greater than
    ``threshold``. The figure is rendered with ``st.pyplot``.

    Args:
        df: Survey DataFrame containing the ``likert_flat_fields`` columns.
            When ``None`` nothing is drawn, matching the guard used by
            ``create_likert_charts``.
        threshold: Minimum absolute correlation for an edge to be drawn.
        chinese_font: ``FontProperties`` used for node labels and the title.
    """
    if df is None:
        return

    # Coerce every column to numbers; values that cannot be parsed become NaN
    # and are ignored by corr() instead of breaking the computation.
    numeric_df = df[likert_flat_fields].apply(pd.to_numeric, errors='coerce')
    corr_matrix = numeric_df.corr()
    columns = corr_matrix.columns

    graph = nx.Graph()

    # Walk the lower triangle only: the matrix is symmetric and the diagonal
    # (self-correlation) is not interesting.
    for i in range(len(columns)):
        for j in range(i):
            weight = corr_matrix.iloc[i, j]
            # NaN compares False here, so undefined correlations are skipped.
            if abs(weight) > threshold:
                graph.add_edge(columns[i], columns[j], weight=weight)

    pos = nx.spring_layout(graph, k=0.1, iterations=20)
    edges = list(graph.edges())
    # Line widths must be non-negative, so use the correlation's magnitude;
    # the signed value stays available on the edge's 'weight' attribute.
    widths = [abs(graph[u][v]['weight']) for u, v in edges]

    # Draw on an explicit figure rather than pyplot's global state, so that
    # Streamlit reruns do not share or leak figures.
    fig, ax = plt.subplots(figsize=(10, 10))

    # Nothing to draw when no pair of fields passes the threshold.
    if graph.number_of_nodes() > 0:
        nx.draw_networkx_nodes(
            graph, pos, ax=ax,
            node_size=500, node_color='lightblue', edgecolors='black',
        )
        nx.draw_networkx_edges(
            graph, pos, ax=ax,
            edgelist=edges, width=widths, alpha=0.5, edge_color='gray',
        )

    # Labels are drawn manually so the Chinese font can be applied.
    for label, (x, y) in pos.items():
        ax.text(x, y, label, fontsize=9, fontproperties=chinese_font,
                ha='center', va='center')

    ax.set_title('Correlation Network', fontproperties=chinese_font)
    ax.axis('off')

    st.pyplot(fig)
    # Release the figure once Streamlit has rendered it.
    plt.close(fig)
113
+
114
def generate_correlation_chart(df, chinese_font, threshold=0.5):
    """Render a heatmap of the fields that have at least one strong correlation.

    The Likert columns from ``likert_flat_fields`` are combined with a 0/1
    encoding of the boolean survey fields, a correlation matrix is computed,
    and only the fields whose absolute correlation with some *other* field
    exceeds ``threshold`` are plotted.

    Args:
        df: Survey DataFrame. It is not modified. When ``None`` nothing is
            drawn, matching the guard used by ``create_likert_charts``.
        chinese_font: ``FontProperties`` used for tick labels and the title.
        threshold: Minimum absolute correlation for a field to be kept.
            Defaults to 0.5, the value that was previously hard-coded.
    """
    if df is None:
        return

    boolean_fields = [
        '你/妳覺得目前有任何投資嗎?',
    ]

    # Build a separate numeric frame so the caller's DataFrame is left
    # untouched; unparseable answers become NaN and are ignored by corr().
    numeric_df = df[likert_flat_fields].apply(pd.to_numeric, errors='coerce')

    # Encode the boolean answers as 1/0; any other answer maps to NaN.
    for field in boolean_fields:
        numeric_df[field + '_encoded'] = df[field].map({'有': 1, '沒有': 0})

    correlation_data = numeric_df.corr()

    # Hide the diagonal before looking for strong correlations: every field
    # correlates perfectly with itself, which would otherwise make the
    # threshold filter keep everything.
    diagonal = np.eye(len(correlation_data), dtype=bool)
    off_diagonal = correlation_data.abs().mask(diagonal)
    is_strong = (off_diagonal.max() > threshold).to_numpy()
    strong_fields = correlation_data.columns[is_strong]

    # A heatmap cannot be drawn from an empty matrix.
    if strong_fields.empty:
        st.info("No correlations above the threshold were found.")
        return

    filtered_correlation_data = correlation_data.loc[strong_fields, strong_fields]

    # Draw on an explicit figure rather than pyplot's global state.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(filtered_correlation_data, annot=True, fmt=".2f",
                cmap="coolwarm", ax=ax)

    # Tick labels are field names in Chinese and need the Chinese font.
    ax.set_xticklabels(ax.get_xticklabels(), fontproperties=chinese_font,
                       rotation=45, ha='right')
    ax.set_yticklabels(ax.get_yticklabels(), fontproperties=chinese_font,
                       rotation=0)

    ax.set_title("強相關分析", fontproperties=chinese_font)

    st.pyplot(fig)
    # Release the figure once Streamlit has rendered it.
    plt.close(fig)