import gradio as gr
import pandas as pd
import plotly.express as px
import time
from datasets import load_dataset
# --- Constants ---
# Selectable values for "number of top organizations": 5, 10, ..., 50.
TOP_K_CHOICES = [*range(5, 51, 5)]

# Dataset identifier passed to `load_dataset` when fetching the data.
HF_DATASET_ID = "evijit/dataverse_daily_data"

# Options offered by the tag filter; "None" means no tag filtering.
TAG_FILTER_CHOICES = [
    "None",
    "Audio & Speech",
    "Time series",
    "Robotics",
    "Music",
    "Video",
    "Images",
    "Text",
    "Biomedical",
    "Sciences",
]
def load_datasets_data():
    """Fetch the processed datasets table from the Hugging Face Hub.

    Returns:
        A ``(dataframe, success, message)`` tuple. On success the dataframe
        is the first split of ``HF_DATASET_ID`` converted to pandas. On any
        failure the dataframe is empty, ``success`` is ``False`` and
        ``message`` contains the error text.
    """
    started_at = time.time()
    print(f"Attempting to load dataset from Hugging Face Hub: {HF_DATASET_ID}")
    try:
        splits = load_dataset(HF_DATASET_ID)
        # Only the first split of the returned mapping is used.
        first_split = list(splits.keys())[0]
        frame = splits[first_split].to_pandas()
        elapsed = time.time() - started_at
        status = f"Successfully loaded dataset in {elapsed:.2f}s."
        print(status)
    except Exception as exc:
        # Boundary handler: report the failure to the caller instead of raising.
        status = f"Failed to load dataset. Error: {exc}"
        print(status)
        return pd.DataFrame(), False, status
    return frame, True, status
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None):
    """Filter data and prepare it for a multi-level treemap.

    - Preserves individual datasets for the top K organizations.
    - Groups all other organizations into a single "Other" row.

    The input frame is never modified, and no column is assigned on a
    filtered slice, so pandas chained-assignment warnings are avoided.

    Args:
        df: Source table; must contain an ``organization`` column. ``None``
            or an empty frame yields an empty result.
        count_by: Name of the metric column to aggregate. A missing column is
            treated as all zeros; non-numeric values are coerced to 0.
        top_k: Number of highest-total organizations kept in detail.
        tag_filter: Optional tag name (e.g. ``"Text"``). ``None``, ``"None"``,
            unknown tags, and tags whose flag column is absent apply no
            filter.
        skip_cats: Optional list of organization names (``"Other"`` included)
            to drop from the result.

    Returns:
        A new ``DataFrame`` holding the rows of the top organizations, an
        optional ``"Other"`` row, and a constant ``root`` column equal to
        ``"datasets"``. An empty ``DataFrame`` is returned when no rows
        survive the tag filter.
    """
    if df is None or df.empty:
        return pd.DataFrame()

    col_map = {
        "Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot",
        "Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science",
        "Video": "has_video", "Images": "has_image", "Text": "has_text",
    }

    filtered_df = df
    # "None" is not a key of col_map, so it resolves to "no filter" as well.
    flag_col = col_map.get(tag_filter) if tag_filter else None
    if flag_col is not None and flag_col in filtered_df.columns:
        # `.eq(True)` yields a clean boolean mask even when the flag column
        # holds None/NaN, which plain boolean indexing rejects with ValueError.
        filtered_df = filtered_df[filtered_df[flag_col].eq(True)]
    if filtered_df.empty:
        return pd.DataFrame()

    # Work on an explicit copy: the caller's frame stays untouched and the
    # column assignments below do not hit a slice of another frame.
    filtered_df = filtered_df.copy()
    if count_by not in filtered_df.columns:
        filtered_df[count_by] = 0.0
    filtered_df[count_by] = pd.to_numeric(filtered_df[count_by], errors="coerce").fillna(0.0)

    all_org_totals = filtered_df.groupby("organization")[count_by].sum()
    top_org_names = all_org_totals.nlargest(top_k, keep="first").index
    final_df_for_plot = filtered_df[filtered_df["organization"].isin(top_org_names)]

    # Everything outside the top K collapses into one aggregate row.
    other_total = all_org_totals[~all_org_totals.index.isin(top_org_names)].sum()
    if other_total > 0:
        other_row = pd.DataFrame([{"organization": "Other", "id": "Other", count_by: other_total}])
        final_df_for_plot = pd.concat([final_df_for_plot, other_row], ignore_index=True)

    if skip_cats:
        final_df_for_plot = final_df_for_plot[~final_df_for_plot["organization"].isin(skip_cats)]

    # `assign` returns a new frame, so the root level is never written onto a
    # filtered slice.
    return final_df_for_plot.assign(root="datasets")
def create_treemap(treemap_data, count_by, title=None):
    """Generate the Plotly treemap figure from the prepared data.

    Args:
        treemap_data: Frame with ``root``, ``organization`` and ``id``
            columns (used as the treemap path) plus the ``count_by`` column.
        count_by: Name of the metric column that sizes the tiles; it is also
            shown as the unit in the hover text.
        title: Optional figure title.

    Returns:
        A Plotly figure. A single-tile placeholder figure is returned when
        the frame is empty, lacks the metric column, or the metric total is
        not positive.
    """
    no_data = (
        treemap_data.empty
        or count_by not in treemap_data.columns
        or treemap_data[count_by].sum() <= 0
    )
    if no_data:
        fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
        fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
        return fig

    fig = px.treemap(
        treemap_data,
        path=["root", "organization", "id"],
        values=count_by,
        title=title,
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))

    # Hover text is three lines; Plotly templates use <br> for line breaks
    # (raw newlines inside a string literal are a syntax error).
    hover_lines = [
        "%{label}",
        "%{value:,} " + count_by,
        "%{percentRoot:.2%} of total",
    ]
    fig.update_traces(
        textinfo="label+value+percent root",
        hovertemplate="<br>".join(hover_lines),
    )
    return fig
# --- Gradio UI Blocks ---
with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
    # Initial control values. They seed the widgets below AND the plot drawn
    # on startup, so the first figure always matches what the UI shows.
    default_metric = "downloads"
    default_tag = "None"
    default_k = 25
    default_skip_cats = "Other"

    # State holders for the loaded table and the "loading finished" flag.
    datasets_data_state = gr.State(pd.DataFrame())
    loading_complete_state = gr.State(False)

    with gr.Row():
        gr.Markdown("# 🤗 Dataverse Explorer")
    with gr.Row():
        with gr.Column(scale=1):
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=[
                    ("Downloads (last 30 days)", "downloads"),
                    ("Downloads (All Time)", "downloadsAllTime"),
                    ("Likes", "likes"),
                ],
                value=default_metric,
            )
            tag_filter_dropdown = gr.Dropdown(
                label="Filter by Tag", choices=TAG_FILTER_CHOICES, value=default_tag
            )
            top_k_dropdown = gr.Dropdown(
                label="Number of Top Organizations", choices=TOP_K_CHOICES, value=default_k
            )
            skip_cats_textbox = gr.Textbox(
                label="Organizations to Skip from the plot", value=default_skip_cats
            )
            # Disabled until the dataset has been loaded successfully.
            generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)
        with gr.Column(scale=3):
            plot_output = gr.Plot()
            status_message_md = gr.Markdown("Initializing...")
            data_info_md = gr.Markdown("")

    def _update_button_interactivity(is_loaded_flag):
        """Return an update that makes the button interactive iff data is loaded."""
        return gr.update(interactive=is_loaded_flag)

    def load_and_generate_initial_plot(progress=gr.Progress()):
        """Load the dataset and build the startup plot in one step.

        Returns:
            ``(dataframe, load_succeeded, data_info_markdown,
            plot_status_markdown, figure)`` in the order expected by the
            ``demo.load`` outputs.
        """
        progress(0, desc=f"Loading dataset '{HF_DATASET_ID}'...")
        # --- Part 1: Data Loading ---
        try:
            current_df, load_success_flag, status_msg_from_load = load_datasets_data()
            if load_success_flag:
                progress(0.5, desc="Processing data...")
                date_display = "Pre-processed (date unavailable)"
                # Guard against an empty frame: `.iloc[0]` would raise and a
                # successful load would be reported as a critical error.
                has_timestamp = (
                    not current_df.empty
                    and 'data_download_timestamp' in current_df.columns
                    and pd.notna(current_df['data_download_timestamp'].iloc[0])
                )
                if has_timestamp:
                    ts = pd.to_datetime(current_df['data_download_timestamp'].iloc[0], utc=True)
                    date_display = ts.strftime('%B %d, %Y, %H:%M:%S %Z')
                data_info_text = (f"### Data Information\n- Source: `{HF_DATASET_ID}`\n"
                                  f"- Status: {status_msg_from_load}\n"
                                  f"- Total datasets loaded: {len(current_df):,}\n"
                                  f"- Data as of: {date_display}\n")
            else:
                data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
        except Exception as e:
            # Boundary handler: surface the problem in the UI rather than
            # failing the page load.
            status_msg_from_load = f"An unexpected error occurred: {str(e)}"
            data_info_text = f"### Critical Error\n- {status_msg_from_load}"
            load_success_flag = False
            current_df = pd.DataFrame()  # Ensure df is empty on failure
            print(f"Critical error in load_and_generate_initial_plot: {e}")

        # --- Part 2: Generate Initial Plot ---
        progress(0.6, desc="Generating initial plot...")
        # Reuse the plotting controller with the same defaults the widgets
        # were created with.
        initial_plot, initial_status = ui_generate_plot_controller(
            default_metric, default_tag, default_k, default_skip_cats, current_df, progress
        )
        return current_df, load_success_flag, data_info_text, initial_status, initial_plot

    def ui_generate_plot_controller(metric_choice, tag_choice, k_orgs,
                                    skip_cats_input, df_current_datasets, progress=gr.Progress()):
        """Build the treemap figure and its statistics text for the given controls.

        Args:
            metric_choice: Metric column name ("downloads", "downloadsAllTime", "likes").
            tag_choice: Tag filter selection, or "None".
            k_orgs: Number of top organizations to show.
            skip_cats_input: Comma-separated organization names to hide;
                ``None`` or an empty string hides nothing.
            df_current_datasets: The loaded datasets table.
            progress: Gradio progress tracker.

        Returns:
            ``(figure, markdown_text)``.
        """
        if df_current_datasets is None or df_current_datasets.empty:
            return create_treemap(pd.DataFrame(), metric_choice), "Dataset data is not loaded. Cannot generate plot."

        progress(0.1, desc="Aggregating data...")
        # Treat a missing textbox value like an empty one instead of crashing.
        raw_names = (skip_cats_input or "").split(',')
        cats_to_skip = [cat.strip() for cat in raw_names if cat.strip()]
        treemap_df = make_treemap_data(df_current_datasets, metric_choice, k_orgs, tag_choice, cats_to_skip)

        progress(0.7, desc="Generating plot...")
        title_labels = {"downloads": "Downloads (last 30 days)", "downloadsAllTime": "Downloads (All Time)", "likes": "Likes"}
        chart_title = f"HuggingFace Datasets - {title_labels.get(metric_choice, metric_choice)} by Organization"
        plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)

        if treemap_df.empty:
            plot_stats_md = "No data matches the selected filters. Please try different options."
        else:
            total_value_in_plot = treemap_df[metric_choice].sum()
            # The aggregate "Other" row is not an individual dataset.
            total_datasets_in_plot = treemap_df[treemap_df['id'] != 'Other']['id'].nunique()
            plot_stats_md = (f"## Plot Statistics\n- **Organizations/Categories Shown**: {treemap_df['organization'].nunique():,}\n"
                             f"- **Individual Datasets Shown**: {total_datasets_in_plot:,}\n"
                             f"- **Total {metric_choice} in plot**: {int(total_value_in_plot):,}")
        return plotly_fig, plot_stats_md

    # --- Event Wiring ---
    # On page load: fetch the data and render the initial plot.
    demo.load(
        fn=load_and_generate_initial_plot,
        inputs=[],
        outputs=[datasets_data_state, loading_complete_state, data_info_md, status_message_md, plot_output]
    )
    # Toggle the button whenever the load flag changes.
    loading_complete_state.change(
        fn=_update_button_interactivity,
        inputs=loading_complete_state,
        outputs=generate_plot_button
    )
    generate_plot_button.click(
        fn=ui_generate_plot_controller,
        inputs=[count_by_dropdown, tag_filter_dropdown, top_k_dropdown,
                skip_cats_textbox, datasets_data_state],
        outputs=[plot_output, status_message_md]
    )
if __name__ == "__main__":
    # Script entry point: announce startup, enable the queue, then launch.
    print("Application starting...")
    queued_app = demo.queue()
    queued_app.launch()