krishaamer commited on
Commit
2167a2b
·
1 Parent(s): fc7d0e7

Move data loading up the tree

Browse files
Files changed (3) hide show
  1. app.py +16 -4
  2. page_demographics.py +2 -7
  3. page_likert.py +12 -13
app.py CHANGED
@@ -3,6 +3,7 @@ import page_home
3
  import page_likert
4
  import page_demographics
5
  from urllib.parse import quote, unquote
 
6
 
7
  # Default to wide mode
8
  st.set_page_config(layout="wide")
@@ -14,6 +15,15 @@ if 'page' not in st.session_state:
14
  initial_page = unquote(query_params.get("page", ["Home"])[0])
15
  st.session_state['page'] = initial_page
16
 
 
 
 
 
 
 
 
 
 
17
  # Sidebar navigation using buttons
18
  st.sidebar.title("Navigation")
19
  if st.sidebar.button("Introduction"):
@@ -27,13 +37,15 @@ if st.sidebar.button("Student Attitudes"):
27
  st.experimental_set_query_params(page=quote(st.session_state['page']))
28
 
29
  # Page contents based on session state
 
 
30
  if st.session_state['page'] == 'Home':
31
- page_home.show()
32
  elif st.session_state['page'] == 'Likert':
33
- page_likert.show()
34
  elif st.session_state['page'] == 'Demographics':
35
- page_demographics.show()
36
 
37
  # Rerun Calculations
38
  if st.sidebar.button("Rerun Calculations"):
39
- st.rerun()
 
3
  import page_likert
4
  import page_demographics
5
  from urllib.parse import quote, unquote
6
+ from datasets import load_dataset
7
 
8
  # Default to wide mode
9
  st.set_page_config(layout="wide")
 
15
  initial_page = unquote(query_params.get("page", ["Home"])[0])
16
  st.session_state['page'] = initial_page
17
 
18
+
19
+ @st.cache_data
20
+ def load_data():
21
+ # Load data from Huggingface
22
+ dataset = load_dataset(
23
+ "krishaamer/taiwanese-college-students", data_files={'train': 'clean.csv'})
24
+ return dataset['train'].to_pandas()
25
+
26
+
27
  # Sidebar navigation using buttons
28
  st.sidebar.title("Navigation")
29
  if st.sidebar.button("Introduction"):
 
37
  st.experimental_set_query_params(page=quote(st.session_state['page']))
38
 
39
  # Page contents based on session state
40
+ df = load_data()
41
+
42
  if st.session_state['page'] == 'Home':
43
+ page_home.show(df)
44
  elif st.session_state['page'] == 'Likert':
45
+ page_likert.show(df)
46
  elif st.session_state['page'] == 'Demographics':
47
+ page_demographics.show(df)
48
 
49
  # Rerun Calculations
50
  if st.sidebar.button("Rerun Calculations"):
51
+ st.rerun()
page_demographics.py CHANGED
@@ -2,14 +2,9 @@ import streamlit as st
2
  import pandas as pd
3
  from datasets import load_dataset
4
 
5
- @st.cache_data
6
- def show():
7
- # Load data from Huggingface
8
- dataset = load_dataset(
9
- "krishaamer/taiwanese-college-students", data_files={'train': 'clean.csv'})
10
 
11
- # Convert the loaded dataset to a pandas DataFrame
12
- df = dataset['train'].to_pandas()
13
 
14
  # Generate and display the university ranking table
15
  generate_university_ranking_table(df)
 
2
  import pandas as pd
3
  from datasets import load_dataset
4
 
 
 
 
 
 
5
 
6
+ @st.cache_data
7
+ def show(df):
8
 
9
  # Generate and display the university ranking table
10
  generate_university_ranking_table(df)
page_likert.py CHANGED
@@ -5,19 +5,14 @@ import seaborn as sns
5
  from datasets import load_dataset
6
  from matplotlib.font_manager import FontProperties
7
 
 
8
  @st.cache_data
9
- def show():
10
 
11
  # Chinese font
12
  chinese_font = FontProperties(fname='mingliu.ttf')
13
 
14
- # Get Data
15
- dataset = load_dataset("krishaamer/taiwanese-college-students", data_files={'train': 'clean.csv'})
16
-
17
- if dataset is not None:
18
-
19
- # Load the CSV file into a DataFrame
20
- df = dataset['train'].to_pandas()
21
 
22
  likert_fields = {
23
  '購物習慣': [
@@ -146,11 +141,13 @@ def show():
146
 
147
  # Loop through each category in likert_fields to create visualizations
148
  for category, fields in likert_fields.items():
149
- st.subheader(f'Distribution of Responses for {translation_mapping[category]}')
 
150
 
151
  # Calculate the number of rows needed for this category
152
  num_fields = len(fields)
153
- num_rows = -(-num_fields // 2) # Equivalent to ceil(num_fields / 2)
 
154
 
155
  # Create subplots with 2 columns for this category
156
  fig, axs = plt.subplots(num_rows, 2, figsize=(15, 5 * num_rows))
@@ -162,12 +159,14 @@ def show():
162
  # Loop through each field in the category to create individual bar plots
163
  for i, field in enumerate(fields):
164
  # Create the bar plot
165
- sns.countplot(x=f"{field} ({field_translation_mapping[category][i]})", data=df_translated, ax=axs[i])
 
166
 
167
  # Add title and labels
168
  title_chinese = field
169
  title_english = field_translation_mapping[category][i]
170
- axs[i].set_title(f"{title_chinese}\n{title_english}", fontproperties=chinese_font)
 
171
  axs[i].set_xlabel('Likert Scale')
172
  axs[i].set_ylabel('Frequency')
173
 
@@ -176,4 +175,4 @@ def show():
176
  fig.delaxes(axs[i])
177
 
178
  # Show the plot
179
- st.pyplot(fig)
 
5
  from datasets import load_dataset
6
  from matplotlib.font_manager import FontProperties
7
 
8
+
9
  @st.cache_data
10
+ def show(df):
11
 
12
  # Chinese font
13
  chinese_font = FontProperties(fname='mingliu.ttf')
14
 
15
+ if df is not None:
 
 
 
 
 
 
16
 
17
  likert_fields = {
18
  '購物習慣': [
 
141
 
142
  # Loop through each category in likert_fields to create visualizations
143
  for category, fields in likert_fields.items():
144
+ st.subheader(
145
+ f'Distribution of Responses for {translation_mapping[category]}')
146
 
147
  # Calculate the number of rows needed for this category
148
  num_fields = len(fields)
149
+ # Equivalent to ceil(num_fields / 2)
150
+ num_rows = -(-num_fields // 2)
151
 
152
  # Create subplots with 2 columns for this category
153
  fig, axs = plt.subplots(num_rows, 2, figsize=(15, 5 * num_rows))
 
159
  # Loop through each field in the category to create individual bar plots
160
  for i, field in enumerate(fields):
161
  # Create the bar plot
162
+ sns.countplot(
163
+ x=f"{field} ({field_translation_mapping[category][i]})", data=df_translated, ax=axs[i])
164
 
165
  # Add title and labels
166
  title_chinese = field
167
  title_english = field_translation_mapping[category][i]
168
+ axs[i].set_title(
169
+ f"{title_chinese}\n{title_english}", fontproperties=chinese_font)
170
  axs[i].set_xlabel('Likert Scale')
171
  axs[i].set_ylabel('Frequency')
172
 
 
175
  fig.delaxes(axs[i])
176
 
177
  # Show the plot
178
+ st.pyplot(fig)