import time

import gradio as gr
import pandas as pd
import plotly.express as px
from datasets import load_dataset

# --- Constants ---
TOP_K_CHOICES = list(range(5, 51, 5))
HF_DATASET_ID = "evijit/dataverse_daily_data"
TAG_FILTER_CHOICES = [
    "None", "Audio & Speech", "Time series", "Robotics", "Music",
    "Video", "Images", "Text", "Biomedical", "Sciences",
]

# Tag-filter label -> dataset column used as the row mask for that tag.
# ("None" is intentionally absent: it means "no filtering".)
TAG_COLUMN_MAP = {
    "Audio & Speech": "is_audio_speech",
    "Music": "has_music",
    "Robotics": "has_robot",
    "Biomedical": "is_biomed",
    "Time series": "has_series",
    "Sciences": "has_science",
    "Video": "has_video",
    "Images": "has_image",
    "Text": "has_text",
}

# Metric column -> human-readable label, shared by the dropdown and the chart title.
METRIC_LABELS = {
    "downloads": "Downloads (last 30 days)",
    "downloadsAllTime": "Downloads (All Time)",
    "likes": "Likes",
}

# Default control values, shared by the UI widgets and the initial plot on startup
# so the first render always matches what the controls display.
DEFAULT_METRIC = "downloads"
DEFAULT_TAG = "None"
DEFAULT_TOP_K = 25
DEFAULT_SKIP_CATS = "Other"


def load_datasets_data():
    """Load the processed datasets data from the Hugging Face Hub.

    Returns:
        A ``(df, success, message)`` tuple: ``df`` is the first split of
        ``HF_DATASET_ID`` converted to a DataFrame (empty on failure),
        ``success`` is a bool, and ``message`` describes the outcome for the UI.
    """
    start_time = time.time()
    print(f"Attempting to load dataset from Hugging Face Hub: {HF_DATASET_ID}")
    try:
        dataset_dict = load_dataset(HF_DATASET_ID)
        # Take whichever split comes first (presumably the dataset has a single split).
        df = next(iter(dataset_dict.values())).to_pandas()
    except Exception as e:  # UI boundary: report the failure instead of crashing the app
        err_msg = f"Failed to load dataset. Error: {e}"
        print(err_msg)
        return pd.DataFrame(), False, err_msg
    msg = f"Successfully loaded dataset in {time.time() - start_time:.2f}s."
    print(msg)
    return df, True, msg


def make_treemap_data(df, count_by, top_k=DEFAULT_TOP_K, tag_filter=None, skip_cats=None):
    """Filter data and prepare it for a multi-level treemap.

    - Preserves individual datasets for the top K organizations.
    - Groups all other organizations into a single "Other" category.

    Args:
        df: Per-dataset rows with at least ``organization`` and ``id`` columns.
        count_by: Metric column that sizes the treemap. It is created as zeros
            when missing and coerced to numeric (unparseable values become 0).
        top_k: Number of organizations, ranked by summed ``count_by``, to keep.
        tag_filter: A ``TAG_FILTER_CHOICES`` label; ``None``, ``"None"`` or an
            unknown label disables tag filtering.
        skip_cats: Organization names (possibly including ``"Other"``) to drop
            from the result.

    Returns:
        A DataFrame with an added ``root`` column, or an empty DataFrame when no
        rows survive the tag filter.
    """
    if df is None or df.empty:
        return pd.DataFrame()

    tag_col = TAG_COLUMN_MAP.get(tag_filter)
    if tag_col is not None and tag_col in df.columns:
        # Compare with True rather than masking with the raw column: a column
        # holding missing values cannot be used directly as a boolean mask.
        mask = df[tag_col].eq(True).fillna(False).astype(bool)
        filtered_df = df.loc[mask].copy()
    else:
        filtered_df = df.copy()
    if filtered_df.empty:
        return pd.DataFrame()

    if count_by not in filtered_df.columns:
        filtered_df[count_by] = 0.0
    filtered_df[count_by] = pd.to_numeric(filtered_df[count_by], errors="coerce").fillna(0.0)

    org_totals = filtered_df.groupby("organization")[count_by].sum()
    top_org_names = org_totals.nlargest(top_k, keep="first").index.tolist()

    final_df = filtered_df[filtered_df["organization"].isin(top_org_names)].copy()
    other_total = org_totals[~org_totals.index.isin(top_org_names)].sum()
    if other_total > 0:
        # Collapse every non-top organization into one leaf so its share stays visible.
        other_row = pd.DataFrame([{"organization": "Other", "id": "Other", count_by: other_total}])
        final_df = pd.concat([final_df, other_row], ignore_index=True)

    if skip_cats:
        final_df = final_df[~final_df["organization"].isin(skip_cats)]

    # `assign` returns a new frame, avoiding chained assignment on a filtered slice.
    return final_df.assign(root="datasets")


def create_treemap(treemap_data, count_by, title=None):
    """Generate the Plotly treemap figure from the prepared data.

    Args:
        treemap_data: Output of ``make_treemap_data``; when non-empty it needs
            ``root``, ``organization``, ``id`` and ``count_by`` columns.
        count_by: Column used for the box sizes and the hover text.
        title: Optional figure title.

    Returns:
        A Plotly figure; a single placeholder box when there is nothing to plot.
    """
    margin = dict(t=50, l=25, r=25, b=25)
    if treemap_data.empty or treemap_data[count_by].sum() <= 0:
        fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
        fig.update_layout(title="No data matches the selected filters", margin=margin)
        return fig

    fig = px.treemap(
        treemap_data,
        path=["root", "organization", "id"],
        values=count_by,
        title=title,
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    fig.update_layout(margin=margin)
    fig.update_traces(
        textinfo="label+value+percent root",
        # Plotly hover templates use HTML-like markup: line breaks must be <br>.
        hovertemplate=(
            "%{label}<br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total"
        ),
    )
    return fig


# --- Gradio UI Blocks ---
with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
    datasets_data_state = gr.State(pd.DataFrame())
    loading_complete_state = gr.State(False)

    with gr.Row():
        gr.Markdown("# 🤗 Dataverse Explorer")

    with gr.Row():
        with gr.Column(scale=1):
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=[(label, metric) for metric, label in METRIC_LABELS.items()],
                value=DEFAULT_METRIC,
            )
            tag_filter_dropdown = gr.Dropdown(
                label="Filter by Tag", choices=TAG_FILTER_CHOICES, value=DEFAULT_TAG
            )
            top_k_dropdown = gr.Dropdown(
                label="Number of Top Organizations", choices=TOP_K_CHOICES, value=DEFAULT_TOP_K
            )
            skip_cats_textbox = gr.Textbox(
                label="Organizations to Skip from the plot", value=DEFAULT_SKIP_CATS
            )
            generate_plot_button = gr.Button(
                value="Generate Plot", variant="primary", interactive=False
            )
        with gr.Column(scale=3):
            plot_output = gr.Plot()
            status_message_md = gr.Markdown("Initializing...")
            data_info_md = gr.Markdown("")

    def _update_button_interactivity(is_loaded_flag):
        """Enable the "Generate Plot" button only once the data has loaded."""
        return gr.update(interactive=is_loaded_flag)

    def ui_generate_plot_controller(metric_choice, tag_choice, k_orgs, skip_cats_input,
                                    df_current_datasets, progress=gr.Progress()):
        """Build the treemap figure and a Markdown stats summary for the current controls.

        Args:
            metric_choice: Metric column to plot (a ``METRIC_LABELS`` key).
            tag_choice: Tag-filter label from ``TAG_FILTER_CHOICES``.
            k_orgs: Number of top organizations to show individually.
            skip_cats_input: Comma-separated organization names to hide.
            df_current_datasets: The loaded datasets DataFrame (may be empty).
            progress: Gradio progress tracker.

        Returns:
            ``(figure, status_markdown)``.
        """
        if df_current_datasets is None or df_current_datasets.empty:
            return create_treemap(pd.DataFrame(), metric_choice), "Dataset data is not loaded. Cannot generate plot."

        progress(0.1, desc="Aggregating data...")
        # `or ""` guards against a None textbox value.
        cats_to_skip = [cat.strip() for cat in (skip_cats_input or "").split(",") if cat.strip()]
        treemap_df = make_treemap_data(df_current_datasets, metric_choice, k_orgs, tag_choice, cats_to_skip)

        progress(0.7, desc="Generating plot...")
        chart_title = f"HuggingFace Datasets - {METRIC_LABELS.get(metric_choice, metric_choice)} by Organization"
        plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)

        if treemap_df.empty:
            plot_stats_md = "No data matches the selected filters. Please try different options."
        else:
            total_value_in_plot = treemap_df[metric_choice].sum()
            # The synthetic "Other" bucket is not a real dataset, so exclude it from the count.
            total_datasets_in_plot = treemap_df.loc[treemap_df["id"] != "Other", "id"].nunique()
            plot_stats_md = (
                f"## Plot Statistics\n- **Organizations/Categories Shown**: {treemap_df['organization'].nunique():,}\n"
                f"- **Individual Datasets Shown**: {total_datasets_in_plot:,}\n"
                f"- **Total {metric_choice} in plot**: {int(total_value_in_plot):,}"
            )
        return plotly_fig, plot_stats_md

    def load_and_generate_initial_plot(progress=gr.Progress()):
        """Load the dataset on page load and render the plot for the default controls.

        Returns:
            ``(df, loaded_flag, data_info_markdown, status_markdown, figure)``, in the
            order of the ``outputs`` list passed to ``demo.load``.
        """
        progress(0, desc=f"Loading dataset '{HF_DATASET_ID}'...")

        # --- Part 1: Data Loading ---
        try:
            current_df, load_success_flag, status_msg_from_load = load_datasets_data()
            if load_success_flag:
                progress(0.5, desc="Processing data...")
                date_display = "Pre-processed (date unavailable)"
                # Check for rows first: `.iloc[0]` on an empty frame would raise and
                # turn an otherwise successful load into a "Critical Error".
                if (
                    not current_df.empty
                    and "data_download_timestamp" in current_df.columns
                    and pd.notna(current_df["data_download_timestamp"].iloc[0])
                ):
                    ts = pd.to_datetime(current_df["data_download_timestamp"].iloc[0], utc=True)
                    date_display = ts.strftime("%B %d, %Y, %H:%M:%S %Z")
                data_info_text = (
                    f"### Data Information\n- Source: `{HF_DATASET_ID}`\n"
                    f"- Status: {status_msg_from_load}\n"
                    f"- Total datasets loaded: {len(current_df):,}\n"
                    f"- Data as of: {date_display}\n"
                )
            else:
                data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
        except Exception as e:  # UI boundary: surface the error in the page instead of crashing
            status_msg_from_load = f"An unexpected error occurred: {str(e)}"
            data_info_text = f"### Critical Error\n- {status_msg_from_load}"
            load_success_flag = False
            current_df = pd.DataFrame()  # Ensure df is empty on failure
            print(f"Critical error in load_and_generate_initial_plot: {e}")

        # --- Part 2: Generate Initial Plot ---
        progress(0.6, desc="Generating initial plot...")
        # Reuse the button's controller with the same defaults the widgets display.
        initial_plot, initial_status = ui_generate_plot_controller(
            DEFAULT_METRIC, DEFAULT_TAG, DEFAULT_TOP_K, DEFAULT_SKIP_CATS, current_df, progress
        )
        return current_df, load_success_flag, data_info_text, initial_status, initial_plot

    # --- Event Wiring ---
    # On page load: fetch the data, then render the default plot in one step.
    demo.load(
        fn=load_and_generate_initial_plot,
        inputs=[],
        outputs=[datasets_data_state, loading_complete_state, data_info_md, status_message_md, plot_output],
    )
    loading_complete_state.change(
        fn=_update_button_interactivity,
        inputs=loading_complete_state,
        outputs=generate_plot_button,
    )
    generate_plot_button.click(
        fn=ui_generate_plot_controller,
        inputs=[count_by_dropdown, tag_filter_dropdown, top_k_dropdown, skip_cats_textbox, datasets_data_state],
        outputs=[plot_output, status_message_md],
    )

if __name__ == "__main__":
    print("Application starting...")
    demo.queue().launch()