evijit HF Staff committed on
Commit
813c7cf
·
verified ·
1 Parent(s): 9bf5a46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -72
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # --- app.py (Dataverse Explorer - Corrected with drill-down) ---
2
-
3
  import gradio as gr
4
  import pandas as pd
5
  import plotly.express as px
@@ -29,7 +27,6 @@ def load_datasets_data():
29
  print(err_msg)
30
  return pd.DataFrame(), False, err_msg
31
 
32
- # --- CORRECTED: This function now preserves individual datasets for top orgs ---
33
  def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None):
34
  """
35
  Filter data and prepare it for a multi-level treemap.
@@ -58,36 +55,24 @@ def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None):
58
  filtered_df[count_by] = 0.0
59
  filtered_df[count_by] = pd.to_numeric(filtered_df[count_by], errors='coerce').fillna(0.0)
60
 
61
- # 1. Get total for every organization to determine the top K
62
  all_org_totals = filtered_df.groupby("organization")[count_by].sum()
63
  top_org_names = all_org_totals.nlargest(top_k, keep='first').index.tolist()
64
 
65
- # 2. Get the full data for the individual datasets belonging to the top organizations
66
  top_orgs_df = filtered_df[filtered_df['organization'].isin(top_org_names)].copy()
67
-
68
- # 3. Calculate the total for the "Other" category
69
  other_total = all_org_totals[~all_org_totals.index.isin(top_org_names)].sum()
70
 
71
- # 4. Create the final DataFrame for the plot
72
  final_df_for_plot = top_orgs_df
73
 
74
- # 5. Add the "Other" row as a single entry if its value is greater than zero
75
  if other_total > 0:
76
- other_row = pd.DataFrame([{
77
- 'organization': 'Other',
78
- 'id': 'Other', # The 'id' for the "Other" category must be defined for the path
79
- count_by: other_total
80
- }])
81
  final_df_for_plot = pd.concat([final_df_for_plot, other_row], ignore_index=True)
82
 
83
- # 6. Apply the skip filter to the organization/category level
84
  if skip_cats and len(skip_cats) > 0:
85
  final_df_for_plot = final_df_for_plot[~final_df_for_plot['organization'].isin(skip_cats)]
86
 
87
  final_df_for_plot["root"] = "datasets"
88
  return final_df_for_plot
89
 
90
- # --- CORRECTED: The path is now restored to allow drill-down ---
91
  def create_treemap(treemap_data, count_by, title=None):
92
  """Generate the Plotly treemap figure from the prepared data."""
93
  if treemap_data.empty or treemap_data[count_by].sum() <= 0:
@@ -95,8 +80,6 @@ def create_treemap(treemap_data, count_by, title=None):
95
  fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
96
  return fig
97
 
98
- # The path is restored to `["root", "organization", "id"]` to enable drill-down.
99
- # The "Other" row with id='Other' will correctly be displayed as a single block.
100
  fig = px.treemap(treemap_data, path=["root", "organization", "id"], values=count_by,
101
  title=title, color_discrete_sequence=px.colors.qualitative.Plotly)
102
  fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
@@ -106,7 +89,7 @@ def create_treemap(treemap_data, count_by, title=None):
106
  )
107
  return fig
108
 
109
- # --- Gradio UI Blocks (no changes needed here) ---
110
  with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
111
  datasets_data_state = gr.State(pd.DataFrame())
112
  loading_complete_state = gr.State(False)
@@ -116,34 +99,11 @@ with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
116
 
117
  with gr.Row():
118
  with gr.Column(scale=1):
119
- count_by_dropdown = gr.Dropdown(
120
- label="Metric",
121
- choices=[("Downloads (last 30 days)", "downloads"), ("Downloads (All Time)", "downloadsAllTime"), ("Likes", "likes")],
122
- value="downloads"
123
- )
124
-
125
- tag_filter_dropdown = gr.Dropdown(
126
- label="Filter by Tag",
127
- choices=TAG_FILTER_CHOICES,
128
- value="None"
129
- )
130
-
131
- top_k_dropdown = gr.Dropdown(
132
- label="Number of Top Organizations",
133
- choices=TOP_K_CHOICES,
134
- value=25
135
- )
136
-
137
- skip_cats_textbox = gr.Textbox(
138
- label="Organizations to Skip from the plot",
139
- value="Other"
140
- )
141
-
142
- generate_plot_button = gr.Button(
143
- value="Generate Plot",
144
- variant="primary",
145
- interactive=False
146
- )
147
 
148
  with gr.Column(scale=3):
149
  plot_output = gr.Plot()
@@ -153,40 +113,51 @@ with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
153
  def _update_button_interactivity(is_loaded_flag):
154
  return gr.update(interactive=is_loaded_flag)
155
 
156
- def ui_load_data_controller(progress=gr.Progress()):
 
157
  progress(0, desc=f"Loading dataset '{HF_DATASET_ID}'...")
 
158
  try:
159
  current_df, load_success_flag, status_msg_from_load = load_datasets_data()
160
  if load_success_flag:
161
- progress(0.9, desc="Processing data...")
162
  date_display = "Pre-processed (date unavailable)"
163
  if 'data_download_timestamp' in current_df.columns and pd.notna(current_df['data_download_timestamp'].iloc[0]):
164
  ts = pd.to_datetime(current_df['data_download_timestamp'].iloc[0], utc=True)
165
  date_display = ts.strftime('%B %d, %Y, %H:%M:%S %Z')
166
 
167
- data_info_text = (
168
- f"### Data Information\n- Source: `{HF_DATASET_ID}`\n"
169
- f"- Status: {status_msg_from_load}\n"
170
- f"- Total datasets loaded: {len(current_df):,}\n"
171
- f"- Data as of: {date_display}\n"
172
- )
173
- status_msg_ui = "Data loaded. Ready to generate plot."
174
  else:
175
  data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
176
- status_msg_ui = status_msg_from_load
177
  except Exception as e:
178
- status_msg_ui = f"An unexpected error occurred: {str(e)}"
179
- data_info_text = f"### Critical Error\n- {status_msg_ui}"
180
  load_success_flag = False
181
- print(f"Critical error in ui_load_data_controller: {e}")
 
182
 
183
- return current_df, load_success_flag, data_info_text, status_msg_ui
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
- # --- CORRECTED: Updated stats to reflect the new plot structure ---
186
  def ui_generate_plot_controller(metric_choice, tag_choice, k_orgs,
187
  skip_cats_input, df_current_datasets, progress=gr.Progress()):
188
  if df_current_datasets is None or df_current_datasets.empty:
189
- return create_treemap(pd.DataFrame(), metric_choice), "Dataset data is not loaded."
190
 
191
  progress(0.1, desc="Aggregating data...")
192
  cats_to_skip = [cat.strip() for cat in skip_cats_input.split(',') if cat.strip()]
@@ -202,21 +173,20 @@ with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
202
  plot_stats_md = "No data matches the selected filters. Please try different options."
203
  else:
204
  total_value_in_plot = treemap_df[metric_choice].sum()
205
- # Count datasets, excluding our placeholder "Other" id
206
  total_datasets_in_plot = treemap_df[treemap_df['id'] != 'Other']['id'].nunique()
207
- plot_stats_md = (
208
- f"## Plot Statistics\n- **Organizations/Categories Shown**: {treemap_df['organization'].nunique():,}\n"
209
- f"- **Individual Datasets Shown**: {total_datasets_in_plot:,}\n"
210
- f"- **Total {metric_choice} in plot**: {int(total_value_in_plot):,}"
211
- )
212
 
213
  return plotly_fig, plot_stats_md
214
 
215
- # --- Event Wiring (no changes needed) ---
 
 
216
  demo.load(
217
- fn=ui_load_data_controller,
218
  inputs=[],
219
- outputs=[datasets_data_state, loading_complete_state, data_info_md, status_message_md]
220
  )
221
 
222
  loading_complete_state.change(
 
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import plotly.express as px
 
27
  print(err_msg)
28
  return pd.DataFrame(), False, err_msg
29
 
 
30
  def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None):
31
  """
32
  Filter data and prepare it for a multi-level treemap.
 
55
  filtered_df[count_by] = 0.0
56
  filtered_df[count_by] = pd.to_numeric(filtered_df[count_by], errors='coerce').fillna(0.0)
57
 
 
58
  all_org_totals = filtered_df.groupby("organization")[count_by].sum()
59
  top_org_names = all_org_totals.nlargest(top_k, keep='first').index.tolist()
60
 
 
61
  top_orgs_df = filtered_df[filtered_df['organization'].isin(top_org_names)].copy()
 
 
62
  other_total = all_org_totals[~all_org_totals.index.isin(top_org_names)].sum()
63
 
 
64
  final_df_for_plot = top_orgs_df
65
 
 
66
  if other_total > 0:
67
+ other_row = pd.DataFrame([{'organization': 'Other', 'id': 'Other', count_by: other_total}])
 
 
 
 
68
  final_df_for_plot = pd.concat([final_df_for_plot, other_row], ignore_index=True)
69
 
 
70
  if skip_cats and len(skip_cats) > 0:
71
  final_df_for_plot = final_df_for_plot[~final_df_for_plot['organization'].isin(skip_cats)]
72
 
73
  final_df_for_plot["root"] = "datasets"
74
  return final_df_for_plot
75
 
 
76
  def create_treemap(treemap_data, count_by, title=None):
77
  """Generate the Plotly treemap figure from the prepared data."""
78
  if treemap_data.empty or treemap_data[count_by].sum() <= 0:
 
80
  fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
81
  return fig
82
 
 
 
83
  fig = px.treemap(treemap_data, path=["root", "organization", "id"], values=count_by,
84
  title=title, color_discrete_sequence=px.colors.qualitative.Plotly)
85
  fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
 
89
  )
90
  return fig
91
 
92
+ # --- Gradio UI Blocks ---
93
  with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
94
  datasets_data_state = gr.State(pd.DataFrame())
95
  loading_complete_state = gr.State(False)
 
99
 
100
  with gr.Row():
101
  with gr.Column(scale=1):
102
+ count_by_dropdown = gr.Dropdown(label="Metric", choices=[("Downloads (last 30 days)", "downloads"), ("Downloads (All Time)", "downloadsAllTime"), ("Likes", "likes")], value="downloads")
103
+ tag_filter_dropdown = gr.Dropdown(label="Filter by Tag", choices=TAG_FILTER_CHOICES, value="None")
104
+ top_k_dropdown = gr.Dropdown(label="Number of Top Organizations", choices=TOP_K_CHOICES, value=25)
105
+ skip_cats_textbox = gr.Textbox(label="Organizations to Skip from the plot", value="Other")
106
+ generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  with gr.Column(scale=3):
109
  plot_output = gr.Plot()
 
113
  def _update_button_interactivity(is_loaded_flag):
114
  return gr.update(interactive=is_loaded_flag)
115
 
116
+ ## CHANGE: New combined function to load data and generate the initial plot on startup.
117
+ def load_and_generate_initial_plot(progress=gr.Progress()):
118
  progress(0, desc=f"Loading dataset '{HF_DATASET_ID}'...")
119
+ # --- Part 1: Data Loading ---
120
  try:
121
  current_df, load_success_flag, status_msg_from_load = load_datasets_data()
122
  if load_success_flag:
123
+ progress(0.5, desc="Processing data...")
124
  date_display = "Pre-processed (date unavailable)"
125
  if 'data_download_timestamp' in current_df.columns and pd.notna(current_df['data_download_timestamp'].iloc[0]):
126
  ts = pd.to_datetime(current_df['data_download_timestamp'].iloc[0], utc=True)
127
  date_display = ts.strftime('%B %d, %Y, %H:%M:%S %Z')
128
 
129
+ data_info_text = (f"### Data Information\n- Source: `{HF_DATASET_ID}`\n"
130
+ f"- Status: {status_msg_from_load}\n"
131
+ f"- Total datasets loaded: {len(current_df):,}\n"
132
+ f"- Data as of: {date_display}\n")
 
 
 
133
  else:
134
  data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
 
135
  except Exception as e:
136
+ status_msg_from_load = f"An unexpected error occurred: {str(e)}"
137
+ data_info_text = f"### Critical Error\n- {status_msg_from_load}"
138
  load_success_flag = False
139
+ current_df = pd.DataFrame() # Ensure df is empty on failure
140
+ print(f"Critical error in load_and_generate_initial_plot: {e}")
141
 
142
+ # --- Part 2: Generate Initial Plot ---
143
+ progress(0.6, desc="Generating initial plot...")
144
+ # Defaults are hard-coded here to mirror the `value=` arguments of the UI components above; keep the two in sync
145
+ default_metric = "downloads"
146
+ default_tag = "None"
147
+ default_k = 25
148
+ default_skip_cats = "Other"
149
+
150
+ # Reuse the existing controller function for plotting
151
+ initial_plot, initial_status = ui_generate_plot_controller(
152
+ default_metric, default_tag, default_k, default_skip_cats, current_df, progress
153
+ )
154
+
155
+ return current_df, load_success_flag, data_info_text, initial_status, initial_plot
156
 
 
157
  def ui_generate_plot_controller(metric_choice, tag_choice, k_orgs,
158
  skip_cats_input, df_current_datasets, progress=gr.Progress()):
159
  if df_current_datasets is None or df_current_datasets.empty:
160
+ return create_treemap(pd.DataFrame(), metric_choice), "Dataset data is not loaded. Cannot generate plot."
161
 
162
  progress(0.1, desc="Aggregating data...")
163
  cats_to_skip = [cat.strip() for cat in skip_cats_input.split(',') if cat.strip()]
 
173
  plot_stats_md = "No data matches the selected filters. Please try different options."
174
  else:
175
  total_value_in_plot = treemap_df[metric_choice].sum()
 
176
  total_datasets_in_plot = treemap_df[treemap_df['id'] != 'Other']['id'].nunique()
177
+ plot_stats_md = (f"## Plot Statistics\n- **Organizations/Categories Shown**: {treemap_df['organization'].nunique():,}\n"
178
+ f"- **Individual Datasets Shown**: {total_datasets_in_plot:,}\n"
179
+ f"- **Total {metric_choice} in plot**: {int(total_value_in_plot):,}")
 
 
180
 
181
  return plotly_fig, plot_stats_md
182
 
183
+ # --- Event Wiring ---
184
+
185
+ ## CHANGE: Updated demo.load to call the new function and to add plot_output to the outputs list.
186
  demo.load(
187
+ fn=load_and_generate_initial_plot,
188
  inputs=[],
189
+ outputs=[datasets_data_state, loading_complete_state, data_info_md, status_message_md, plot_output]
190
  )
191
 
192
  loading_complete_state.change(