import torch
import time
import re
import csv

import pandas as pd
import evaluate
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from sklearn.metrics import accuracy_score
from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns
from dataclasses import dataclass

# Load the F1-Score metric (accuracy is computed with sklearn's accuracy_score below)
f1_metric = evaluate.load("f1")
# Define Model Paths
MODEL_PATHS = {
    "MindBERT": "DrSyedFaizan/mindBERT",
    "BERT-base": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "DistilBERT": "distilbert-base-uncased",
}
# Label Mapping
LABEL_MAPPING = {
    0: "Stress",
    1: "Depression",
    2: "Bipolar disorder",
    3: "Personality disorder",
    4: "Anxiety",
}
# Function to clean text using regular expressions
def clean_text(text):
    text = text.lower()
    text = re.sub(r'http\S+', '', text)        # Remove URLs
    text = re.sub(r'\s+', ' ', text)           # Collapse excessive whitespace
    text = re.sub(r'[^a-zA-Z0-9 ]', '', text)  # Remove special characters
    return text.strip()
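# Quick example of the cleaning pipeline on a raw post:
# clean_text("Check https://example.com  NOW!!")  ->  "check now"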
# Load and preprocess the Reddit Mental Health Dataset
def load_reddit_data(file_path, sample_size=100):
    df = pd.read_csv(file_path, sep=",", encoding="utf-8", quotechar='"', on_bad_lines="skip", engine="python")
    df.columns = df.columns.str.strip()  # Remove extra spaces from column names
    print("Columns in dataset:", df.columns)  # Debugging check
    if "text" not in df.columns or "target" not in df.columns:
        raise ValueError("Dataset does not contain required 'text' and 'target' columns.")
    df = df.dropna(subset=["text", "target"])  # Drop rows missing either required field
    df["text"] = df["text"].apply(clean_text)  # Clean text column
    # Sample a subset, capped at the dataset size so small files don't raise
    df_sample = df.sample(n=min(sample_size, len(df)), random_state=42)
    test_texts = df_sample["text"].tolist()
    # Cast to int in case pandas parsed the labels as floats
    test_labels = df_sample["target"].astype(int).tolist()
    return test_texts, test_labels
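# Assumed layout of rmhd.csv (illustrative only, not verified against the hosted file):
#
#   text,target
#   "i cant sleep and everything feels overwhelming",0
#   "some days are high energy, others i cant get out of bed",2
#
# where "target" holds the integer class indices from LABEL_MAPPING above.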
# Function to evaluate models
def evaluate_models(dataset_path):
    test_texts, test_labels = load_reddit_data(dataset_path)
    results = []
    model_metadata = {
        "MindBERT": {"model_type": "BERT", "precision": "float16", "params": 0.11, "license": "MIT"},
        "BERT-base": {"model_type": "BERT", "precision": "float16", "params": 0.11, "license": "Apache-2.0"},
        "RoBERTa": {"model_type": "RoBERTa", "precision": "float16", "params": 0.125, "license": "MIT"},
        "DistilBERT": {"model_type": "DistilBERT", "precision": "float16", "params": 0.067, "license": "Apache-2.0"},
    }
    for model_name, model_path in MODEL_PATHS.items():
        print(f"Evaluating {model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # num_labels must match the 5 classes in LABEL_MAPPING; for the base
        # checkpoints the classification head is freshly initialized, so they
        # serve as untrained baselines against the fine-tuned MindBERT.
        model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=len(LABEL_MAPPING))
        model.eval()
        inputs = tokenizer(test_texts, padding=True, truncation=True, return_tensors="pt")
        start_time = time.time()
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=1).numpy()
        end_time = time.time()
        accuracy = accuracy_score(test_labels, predictions)
        f1_score = f1_metric.compute(predictions=predictions, references=test_labels, average="macro")["f1"]
        inference_time = round(end_time - start_time, 4)
        result = {
            "model": model_name,
            "model_type": model_metadata[model_name]["model_type"],
            "precision": model_metadata[model_name]["precision"],
            "params": model_metadata[model_name]["params"],
            "accuracy": round(accuracy, 4),
            "f1_score": round(f1_score, 4),
            "inference_time": inference_time,
            "license": model_metadata[model_name]["license"],
        }
        results.append(result)
    return pd.DataFrame(results)
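# Note: evaluate_models tokenizes and scores all sampled texts in one forward
# pass, which is fine for sample_size=100 but can exhaust memory on larger
# samples. A batched variant (sketch; the batch size of 16 is an arbitrary
# assumption) could replace the single call:
#
#     for i in range(0, len(test_texts), 16):
#         batch = tokenizer(test_texts[i:i + 16], padding=True, truncation=True, return_tensors="pt")
#         with torch.no_grad():
#             preds = torch.argmax(model(**batch).logits, dim=1)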
# Load and evaluate (use the "resolve/main" URL so pandas downloads the raw CSV,
# not the HTML page that "blob/main" serves)
DATASET_PATH = "https://huggingface.co/spaces/DrSyedFaizan/mindBERTevaluation/resolve/main/rmhd.csv"
df_results = evaluate_models(DATASET_PATH)
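# init_leaderboard below expects a ModelEvalColumn class and a fields() helper.
# A minimal sketch following the Hugging Face leaderboard-template pattern: each
# class attribute describes one leaderboard column, and fields() collects them.
# Column names mirror the keys produced by evaluate_models; the display flags
# and datatypes here are assumptions, not part of the original Space.
@dataclass
class ColumnContent:
    name: str
    type: str = "str"
    displayed_by_default: bool = True
    hidden: bool = False
    never_hidden: bool = False

class ModelEvalColumn:
    model = ColumnContent("model", "markdown", never_hidden=True)
    model_type = ColumnContent("model_type")
    precision = ColumnContent("precision")
    params = ColumnContent("params", "number")
    accuracy = ColumnContent("accuracy", "number")
    f1_score = ColumnContent("f1_score", "number")
    inference_time = ColumnContent("inference_time", "number")
    license = ColumnContent("license")

def fields(raw_class):
    # Collect the ColumnContent attributes defined on the class above
    return [v for k, v in raw_class.__dict__.items() if not k.startswith("__")]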
# Initialize leaderboard with custom columns
def init_leaderboard(dataframe):
    if dataframe is None or dataframe.empty:
        raise ValueError("Leaderboard DataFrame is empty or None.")
    columns = fields(ModelEvalColumn)
    return Leaderboard(
        value=dataframe,
        datatype=[c.type for c in columns],
        select_columns=SelectColumns(
            default_selection=[c.name for c in columns if c.displayed_by_default],
            cant_deselect=[c.name for c in columns if c.never_hidden],
            label="Select Columns to Display:",
        ),
        search_columns=["model", "license"],
        hide_columns=[c.name for c in columns if c.hidden],
        filter_columns=[
            ColumnFilter("model_type", type="checkboxgroup", label="Model types"),
            ColumnFilter("precision", type="checkboxgroup", label="Precision"),
            ColumnFilter(
                "params",
                type="slider",
                min=0.01,
                max=0.5,
                label="Select the number of parameters (B)",
            ),
        ],
        interactive=False,
    )
# Custom CSS similar to the original leaderboard styling
custom_css = """
.markdown-text {
    padding: 0 20px;
}
.tab-buttons button.selected {
    background-color: #FF9C00 !important;
    color: white !important;
}
"""
# Create Gradio Interface
demo = gr.Blocks(css=custom_css)

with demo:
    gr.HTML("<h1>Mental Health Model Evaluation Benchmark</h1>")
    gr.Markdown(
        "This benchmark evaluates various transformer models on mental health classification tasks.",
        elem_classes="markdown-text",
    )

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("📊 Model Benchmark", elem_id="model-benchmark-tab", id=0):
            # Reuse the evaluation results computed at startup
            leaderboard = init_leaderboard(df_results)
with gr.TabItem("π About", elem_id="about-tab", id=1): | |
gr.Markdown(""" | |
## About This Benchmark | |
This leaderboard compares various transformer models on mental health text classification tasks. | |
The benchmark uses a test set from Reddit Mental Health datasets with examples covering anxiety, | |
depression, bipolar disorder, suicidal ideation, stress, and normal emotional states. | |
Models are evaluated on: | |
- Accuracy | |
- F1-Score (Macro) | |
- Inference Time | |
### Model Types | |
- BERT-based models | |
- RoBERTa models | |
- DistilBERT models | |
- Specialized mental health models (MindBERT) | |
""", elem_classes="markdown-text") | |
with gr.TabItem("π Submit Model", elem_id="submit-tab", id=2): | |
gr.Markdown("# βοΈβ¨ Submit your model here!", elem_classes="markdown-text") | |
with gr.Row(): | |
with gr.Column(): | |
model_name_textbox = gr.Textbox(label="Model name") | |
model_path_textbox = gr.Textbox(label="Model path (HF repo ID)") | |
model_type = gr.Dropdown( | |
choices=["BERT", "RoBERTa", "DistilBERT", "GPT", "T5", "Other"], | |
label="Model type", | |
multiselect=False, | |
value=None, | |
interactive=True, | |
) | |
with gr.Column(): | |
precision = gr.Dropdown( | |
choices=["float16", "float32", "int8", "int4"], | |
label="Precision", | |
multiselect=False, | |
value="float16", | |
interactive=True, | |
) | |
params = gr.Number(label="Parameters (billions)", value=0.11) | |
license = gr.Textbox(label="License", value="Apache-2.0") | |
submit_button = gr.Button("Submit Model for Evaluation") | |
submission_result = gr.Markdown() | |
# This would typically connect to a submission system | |
def handle_submission(model_name, model_path, model_type, precision, params, license): | |
return f"Model {model_name} successfully submitted for evaluation. It will appear in the leaderboard once processing is complete." | |
submit_button.click( | |
handle_submission, | |
[model_name_textbox, model_path_textbox, model_type, precision, params, license], | |
submission_result, | |
) | |
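            # A minimal sketch of what "connect to a submission system" could look
            # like: append each request to a local CSV queue. The file name and
            # column order are assumptions, not part of the original Space.
            def log_submission(model_name, model_path, model_type, precision, params, model_license):
                with open("submission_queue.csv", "a", newline="", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow([time.time(), model_name, model_path, model_type, precision, params, model_license])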
    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            citation_text = """
@misc{mental-health-model-benchmark,
    author = {Syed Faizan},
    title = {Mental Health Model Benchmark},
    year = {2025},
    publisher = {GitHub},
    url = {https://github.com/SYEDFAIZAN1987/mindBERT}
}
"""
            citation_button = gr.Textbox(
                value=citation_text,
                label="Citation",
                lines=10,
                elem_id="citation-button",
                show_copy_button=True,
            )

demo.launch()