|
import gradio as gr |
|
import pandas as pd |
|
import plotly.express as px |
|
import time |
|
from datasets import load_dataset |
|
|
|
|
|
# Selectable sizes for "Number of Top Organizations": 5, 10, ..., 50.
TOP_K_CHOICES = [step * 5 for step in range(1, 11)]

# Hugging Face Hub dataset that holds the pre-processed datasets metadata.
HF_DATASET_ID = "evijit/dataverse_daily_data"

# Labels offered by the "Filter by Tag" dropdown; "None" disables tag filtering.
TAG_FILTER_CHOICES = [
    "None",
    "Audio & Speech",
    "Time series",
    "Robotics",
    "Music",
    "Video",
    "Images",
    "Text",
    "Biomedical",
    "Sciences",
]
|
|
|
def load_datasets_data():
    """Fetch ``HF_DATASET_ID`` from the Hub and return its first split as pandas.

    Any failure is reported rather than raised, so the UI can show it.

    Returns:
        tuple: ``(df, ok, message)``. On success ``df`` holds the first split
        and ``ok`` is True; on failure ``df`` is an empty DataFrame, ``ok`` is
        False and ``message`` carries the error text.
    """
    t0 = time.time()
    print(f"Attempting to load dataset from Hugging Face Hub: {HF_DATASET_ID}")

    try:
        splits = load_dataset(HF_DATASET_ID)
        # Only the first split (in the order the dict reports them) is used.
        first_split_name = list(splits.keys())[0]
        frame = splits[first_split_name].to_pandas()
        success = f"Successfully loaded dataset in {time.time() - t0:.2f}s."
        print(success)
        return frame, True, success
    except Exception as exc:  # app boundary: surface the error to the UI instead of crashing
        failure = f"Failed to load dataset. Error: {exc}"
        print(failure)
        return pd.DataFrame(), False, failure
|
|
|
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None):
    """
    Filter data and prepare it for a multi-level treemap.

    - Preserves individual datasets for the top K organizations.
    - Groups all other organizations into a single "Other" category.

    Args:
        df: Source rows. Expected to provide ``organization`` and ``id``
            columns; the tag flag columns listed in ``col_map`` are optional.
            Never mutated.
        count_by: Column used to rank organizations and size tiles. A missing
            column is created as zeros; non-numeric values are coerced to 0.
        top_k: How many organizations (by summed ``count_by``) keep their
            individual datasets. Coerced with ``int()``.
        tag_filter: A ``TAG_FILTER_CHOICES`` label. ``None``, ``"None"``, an
            unknown label, or a label whose flag column is absent disables
            filtering.
        skip_cats: Organization names (may include ``"Other"``) removed
            *after* ranking, so skipping a top-K organization shows fewer
            than ``top_k`` organizations.

    Returns:
        A DataFrame with the kept rows, an optional synthetic ``"Other"`` row
        (``organization == id == "Other"``) and a constant ``root`` column set
        to ``"datasets"``; or an empty DataFrame when ``df`` is empty or the
        tag filter removes every row.
    """
    if df is None or df.empty:
        return pd.DataFrame()

    # UI tag label -> flag column in the dataset.
    col_map = {
        "Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot",
        "Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science",
        "Video": "has_video", "Images": "has_image", "Text": "has_text",
    }

    filtered_df = df
    tag_col = col_map.get(tag_filter)
    if tag_col is not None and tag_col in df.columns:
        # Indexing with a raw object / 0-1 / NaN-bearing column raises or
        # selects columns instead of rows, so build a strict boolean mask:
        # only values equal to True pass, missing values count as False.
        tag_mask = df[tag_col].eq(True).fillna(False).astype(bool)
        filtered_df = df[tag_mask]

    if filtered_df.empty:
        return pd.DataFrame()

    # Private copy so the caller's DataFrame (e.g. session state) is never mutated.
    filtered_df = filtered_df.copy()

    if count_by not in filtered_df.columns:
        filtered_df[count_by] = 0.0
    filtered_df[count_by] = pd.to_numeric(filtered_df[count_by], errors="coerce").fillna(0.0)

    all_org_totals = filtered_df.groupby("organization")[count_by].sum()
    top_org_names = all_org_totals.nlargest(int(top_k), keep="first").index.tolist()

    top_orgs_df = filtered_df[filtered_df["organization"].isin(top_org_names)]
    other_total = all_org_totals[~all_org_totals.index.isin(top_org_names)].sum()

    final_df_for_plot = top_orgs_df
    if other_total > 0:
        # Collapse every non-top organization into one synthetic leaf.
        other_row = pd.DataFrame([{"organization": "Other", "id": "Other", count_by: other_total}])
        final_df_for_plot = pd.concat([final_df_for_plot, other_row], ignore_index=True)

    if skip_cats:
        final_df_for_plot = final_df_for_plot[~final_df_for_plot["organization"].isin(skip_cats)]

    # assign() returns a new frame, avoiding SettingWithCopyWarning on the filtered slice.
    return final_df_for_plot.assign(root="datasets")
|
|
|
def create_treemap(treemap_data, count_by, title=None):
    """Render prepared treemap rows as a Plotly figure.

    Args:
        treemap_data: Rows with ``root``, ``organization`` and ``id`` columns
            plus the ``count_by`` column, as produced by ``make_treemap_data``.
        count_by: Column that sizes the tiles; also shown in the hover text.
        title: Optional figure title.

    Returns:
        A Plotly treemap figure, or a single-tile placeholder figure when there
        are no rows or the ``count_by`` total is not positive.
    """
    layout_margin = dict(t=50, l=25, r=25, b=25)

    nothing_to_show = treemap_data.empty or treemap_data[count_by].sum() <= 0
    if nothing_to_show:
        placeholder = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
        placeholder.update_layout(title="No data matches the selected filters", margin=layout_margin)
        return placeholder

    figure = px.treemap(
        treemap_data,
        path=["root", "organization", "id"],
        values=count_by,
        title=title,
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    figure.update_layout(margin=layout_margin)

    hover = (
        "<b>%{label}</b><br>%{value:,} "
        + count_by
        + "<br>%{percentRoot:.2%} of total<extra></extra>"
    )
    figure.update_traces(textinfo="label+value+percent root", hovertemplate=hover)
    return figure
|
|
|
|
|
# Default control values, shared by the widgets and by the initial plot drawn
# on page load so the two cannot drift apart.
DEFAULT_METRIC = "downloads"
DEFAULT_TAG = "None"
DEFAULT_TOP_K = 25
DEFAULT_SKIP_CATS = "Other"

# Metric column -> human-readable label (dropdown choices and chart title).
METRIC_LABELS = {
    "downloads": "Downloads (last 30 days)",
    "downloadsAllTime": "Downloads (All Time)",
    "likes": "Likes",
}


def _describe_data_timestamp(df):
    """Return a display string for the first ``data_download_timestamp`` value.

    Never raises: falls back to a placeholder when the column is missing, the
    frame is empty, the value is missing, or it cannot be parsed as a datetime.
    """
    fallback = "Pre-processed (date unavailable)"
    if df.empty or "data_download_timestamp" not in df.columns:
        return fallback
    raw = df["data_download_timestamp"].iloc[0]
    if pd.isna(raw):
        return fallback
    try:
        ts = pd.to_datetime(raw, utc=True)
    except (ValueError, TypeError, OverflowError):
        return fallback
    return ts.strftime("%B %d, %Y, %H:%M:%S %Z")


with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
    # Per-session state: the loaded DataFrame and whether loading succeeded.
    datasets_data_state = gr.State(pd.DataFrame())
    loading_complete_state = gr.State(False)

    with gr.Row():
        gr.Markdown("# 🤗 Dataverse Explorer")

    with gr.Row():
        with gr.Column(scale=1):
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=[(label, key) for key, label in METRIC_LABELS.items()],
                value=DEFAULT_METRIC,
            )
            tag_filter_dropdown = gr.Dropdown(label="Filter by Tag", choices=TAG_FILTER_CHOICES, value=DEFAULT_TAG)
            top_k_dropdown = gr.Dropdown(label="Number of Top Organizations", choices=TOP_K_CHOICES, value=DEFAULT_TOP_K)
            skip_cats_textbox = gr.Textbox(label="Organizations to Skip from the plot", value=DEFAULT_SKIP_CATS)
            # Starts disabled; enabled via loading_complete_state.change below.
            generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)

        with gr.Column(scale=3):
            plot_output = gr.Plot()
            status_message_md = gr.Markdown("Initializing...")
            data_info_md = gr.Markdown("")

    def _update_button_interactivity(is_loaded_flag):
        """Make the Generate button interactive iff loading succeeded."""
        return gr.update(interactive=is_loaded_flag)

    def load_and_generate_initial_plot(progress=gr.Progress()):
        """Load the dataset on page load and render the default treemap.

        Returns:
            ``(df, load_ok, data_info_markdown, plot_stats_markdown, figure)``,
            matching the outputs wired in ``demo.load`` below.
        """
        progress(0, desc=f"Loading dataset '{HF_DATASET_ID}'...")

        try:
            current_df, load_success_flag, status_msg_from_load = load_datasets_data()
        except Exception as e:  # UI boundary: report instead of breaking page load
            status_msg_from_load = f"An unexpected error occurred: {str(e)}"
            data_info_text = f"### Critical Error\n- {status_msg_from_load}"
            load_success_flag = False
            current_df = pd.DataFrame()
            print(f"Critical error in load_and_generate_initial_plot: {e}")
        else:
            if load_success_flag:
                progress(0.5, desc="Processing data...")
                # A bad/missing timestamp must not turn a successful load into a failure.
                date_display = _describe_data_timestamp(current_df)
                data_info_text = (f"### Data Information\n- Source: `{HF_DATASET_ID}`\n"
                                  f"- Status: {status_msg_from_load}\n"
                                  f"- Total datasets loaded: {len(current_df):,}\n"
                                  f"- Data as of: {date_display}\n")
            else:
                data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"

        progress(0.6, desc="Generating initial plot...")
        initial_plot, initial_status = ui_generate_plot_controller(
            DEFAULT_METRIC, DEFAULT_TAG, DEFAULT_TOP_K, DEFAULT_SKIP_CATS, current_df, progress
        )
        return current_df, load_success_flag, data_info_text, initial_status, initial_plot

    def ui_generate_plot_controller(metric_choice, tag_choice, k_orgs,
                                    skip_cats_input, df_current_datasets, progress=gr.Progress()):
        """Build the treemap figure and a statistics summary for the given controls.

        Args:
            metric_choice: Column to rank and size by (a ``METRIC_LABELS`` key).
            tag_choice: A ``TAG_FILTER_CHOICES`` label.
            k_orgs: Number of top organizations to show individually.
            skip_cats_input: Comma-separated organization names to omit.
            df_current_datasets: The loaded DataFrame from session state.
            progress: Gradio progress tracker.

        Returns:
            ``(figure, markdown)`` for the plot and status outputs.
        """
        if df_current_datasets is None or df_current_datasets.empty:
            return create_treemap(pd.DataFrame(), metric_choice), "Dataset data is not loaded. Cannot generate plot."

        progress(0.1, desc="Aggregating data...")
        # Treat a missing/None textbox value as "skip nothing" instead of crashing on .split.
        cats_to_skip = [cat.strip() for cat in (skip_cats_input or "").split(",") if cat.strip()]

        treemap_df = make_treemap_data(df_current_datasets, metric_choice, k_orgs, tag_choice, cats_to_skip)

        progress(0.7, desc="Generating plot...")
        chart_title = f"HuggingFace Datasets - {METRIC_LABELS.get(metric_choice, metric_choice)} by Organization"
        plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)

        if treemap_df.empty:
            plot_stats_md = "No data matches the selected filters. Please try different options."
        else:
            total_value_in_plot = treemap_df[metric_choice].sum()
            # The synthetic "Other" leaf is not a real dataset.
            total_datasets_in_plot = treemap_df[treemap_df["id"] != "Other"]["id"].nunique()
            plot_stats_md = (f"## Plot Statistics\n- **Organizations/Categories Shown**: {treemap_df['organization'].nunique():,}\n"
                             f"- **Individual Datasets Shown**: {total_datasets_in_plot:,}\n"
                             f"- **Total {metric_choice} in plot**: {int(total_value_in_plot):,}")

        return plotly_fig, plot_stats_md

    demo.load(
        fn=load_and_generate_initial_plot,
        inputs=[],
        outputs=[datasets_data_state, loading_complete_state, data_info_md, status_message_md, plot_output],
    )

    loading_complete_state.change(
        fn=_update_button_interactivity,
        inputs=loading_complete_state,
        outputs=generate_plot_button,
    )

    generate_plot_button.click(
        fn=ui_generate_plot_controller,
        inputs=[count_by_dropdown, tag_filter_dropdown, top_k_dropdown,
                skip_cats_textbox, datasets_data_state],
        outputs=[plot_output, status_message_md],
    )
|
|
|
def _run_app():
    """Announce startup, enable Gradio's request queue, and launch the server."""
    print("Application starting...")
    demo.queue().launch()


if __name__ == "__main__":
    _run_app()