krishaamer commited on
Commit
e89b1a1
·
1 Parent(s): 24b6da8

Add network correlations

Browse files
Files changed (2) hide show
  1. page_shopping.py +38 -0
  2. requirements.txt +2 -1
page_shopping.py CHANGED
@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
4
  import seaborn as sns
5
  import pandas as pd
6
  import numpy as np
 
7
  from fields.likert_flat_fields import likert_flat_fields
8
  #from fields.boolean_fields import boolean_fields
9
 
@@ -16,6 +17,7 @@ def show(df):
16
  f"<h2 style='text-align: center;'>Boycott Count (Overall)</h2>", unsafe_allow_html=True)
17
  show_boycott_count(df, font_prop=chinese_font)
18
  #generate_correlation_chart(df, chinese_font)
 
19
 
20
  def show_boycott_count(df, font_prop):
21
  # Count the number of people who have invested and who have not
@@ -40,6 +42,42 @@ def show_boycott_count(df, font_prop):
40
  # Display the chart in Streamlit
41
  st.pyplot(plt)
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def generate_correlation_chart(df, chinese_font):
45
 
 
4
  import seaborn as sns
5
  import pandas as pd
6
  import numpy as np
7
+ import networkx as nx
8
  from fields.likert_flat_fields import likert_flat_fields
9
  #from fields.boolean_fields import boolean_fields
10
 
 
17
  f"<h2 style='text-align: center;'>Boycott Count (Overall)</h2>", unsafe_allow_html=True)
18
  show_boycott_count(df, font_prop=chinese_font)
19
  #generate_correlation_chart(df, chinese_font)
20
+ create_correlation_network(df, 0.4, chinese_font)
21
 
22
  def show_boycott_count(df, font_prop):
23
  # Count the number of people who have invested and who have not
 
42
  # Display the chart in Streamlit
43
  st.pyplot(plt)
44
 
45
+ def create_correlation_network(df, threshold, chinese_font):
46
+
47
+ filtered_df = df[likert_flat_fields]
48
+ filtered_df = filtered_df.apply(pd.to_numeric, errors='coerce')
49
+
50
+ # Now you can calculate the correlation matrix and create the network
51
+ corr_matrix = filtered_df.corr()
52
+
53
+ # Create a graph
54
+ graph = nx.Graph()
55
+
56
+ # Iterate over the correlation matrix and add edges
57
+ for i in range(len(corr_matrix.columns)):
58
+ for j in range(i):
59
+ if abs(corr_matrix.iloc[i, j]) > threshold: # only consider strong correlations
60
+ graph.add_edge(corr_matrix.columns[i], corr_matrix.columns[j], weight=corr_matrix.iloc[i, j])
61
+
62
+ # Draw the network
63
+ pos = nx.spring_layout(graph, k=0.1, iterations=20)
64
+ edges = graph.edges()
65
+ weights = [graph[u][v]['weight'] for u, v in edges] # Use the weights for edge width
66
+
67
+ plt.figure(figsize=(10, 10))
68
+ nx.draw_networkx_nodes(graph, pos, node_size=500, node_color='lightblue', edgecolors='black')
69
+ nx.draw_networkx_edges(graph, pos, edgelist=edges, width=weights, alpha=0.5, edge_color='gray')
70
+
71
+ # Set Chinese font
72
+ for label in graph.nodes():
73
+ x, y = pos[label]
74
+ plt.text(x, y, label, fontsize=9, fontproperties=chinese_font, ha='center', va='center')
75
+
76
+ plt.title('Correlation Network', fontproperties=chinese_font)
77
+ plt.axis('off') # Turn off the axis
78
+
79
+ # Use Streamlit to render the plot
80
+ st.pyplot(plt)
81
 
82
  def generate_correlation_chart(df, chinese_font):
83
 
requirements.txt CHANGED
@@ -7,4 +7,5 @@ scipy
7
  scikit-learn
8
  squarify
9
  kmodes
10
- adjustText
 
 
7
  scikit-learn
8
  squarify
9
  kmodes
10
+ adjustText
11
+ networkx