Update app.py
app.py
CHANGED
@@ -1,204 +1,241 @@
  import gradio as gr
  from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns
- import …
- from …
- …
- )
- …
  def init_leaderboard(dataframe):
      if dataframe is None or dataframe.empty:
          raise ValueError("Leaderboard DataFrame is empty or None.")
      return Leaderboard(
          value=dataframe,
-         datatype=[c.type for c in …
          select_columns=SelectColumns(
-             default_selection=[c.name for c in …
-             cant_deselect=[c.name for c in …
              label="Select Columns to Display:",
          ),
-         search_columns=[…
-         hide_columns=[c.name for c in …
          filter_columns=[
-             ColumnFilter(…
-             ColumnFilter(…
              ColumnFilter(
-
                  type="slider",
                  min=0.01,
-                 max=…
                  label="Select the number of parameters (B)",
              ),
-             ColumnFilter(
-                 AutoEvalColumn.still_on_hub.name, type="boolean", label="Deleted/incomplete", default=True
-             ),
          ],
-         bool_checkboxgroup_label="Hide models",
          interactive=False,
      )

-
  demo = gr.Blocks(css=custom_css)
- with demo:
-     gr.HTML(TITLE)
-     gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

      with gr.Tabs(elem_classes="tab-buttons") as tabs:
-         with gr.TabItem("🏅 …
- …
-                 running_eval_table = gr.components.Dataframe(
-                     value=running_eval_queue_df,
-                     headers=EVAL_COLS,
-                     datatype=EVAL_TYPES,
-                     row_count=5,
-                 )
-
-             with gr.Accordion(
-                 f"⏳ Pending Evaluation Queue ({len(pending_eval_queue_df)})",
-                 open=False,
-             ):
-                 with gr.Row():
-                     pending_eval_table = gr.components.Dataframe(
-                         value=pending_eval_queue_df,
-                         headers=EVAL_COLS,
-                         datatype=EVAL_TYPES,
-                         row_count=5,
-                     )
-             with gr.Row():
-                 gr.Markdown("# ✉️✨ Submit your model here!", elem_classes="markdown-text")
-
              with gr.Row():
                  with gr.Column():
                      model_name_textbox = gr.Textbox(label="Model name")
-
                      model_type = gr.Dropdown(
-                         choices=[…
                          label="Model type",
                          multiselect=False,
                          value=None,
                          interactive=True,
                      )
-
                  with gr.Column():
                      precision = gr.Dropdown(
-                         choices=[…
                          label="Precision",
                          multiselect=False,
                          value="float16",
                          interactive=True,
                      )
- …
-                         value="Original",
-                         interactive=True,
-                     )
-                     base_model_name_textbox = gr.Textbox(label="Base model (for delta or adapter weights)")
-
-             submit_button = gr.Button("Submit Eval")
              submission_result = gr.Markdown()
              submit_button.click(
-
-                 [
-                     model_name_textbox,
-                     base_model_name_textbox,
-                     revision_name_textbox,
-                     precision,
-                     weight_type,
-                     model_type,
-                 ],
                  submission_result,
              )

      with gr.Row():
          with gr.Accordion("📙 Citation", open=False):
              citation_button = gr.Textbox(
-                 value=…
-                 label=…
-                 lines=…
                  elem_id="citation-button",
                  show_copy_button=True,
              )

-
- scheduler.add_job(restart_space, "interval", seconds=1800)
- scheduler.start()
- demo.queue(default_concurrency_limit=40).launch()
+ import torch
+ import time
+ import numpy as np
+ import pandas as pd
+ import evaluate
  import gradio as gr
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ from sklearn.metrics import accuracy_score, classification_report
  from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns
+ from dataclasses import dataclass, field
+ from typing import List, Optional
+
+ # Load Accuracy and F1-Score Metrics
+ accuracy_metric = evaluate.load("accuracy")
+ f1_metric = evaluate.load("f1")
+
+ # Define Model Paths
+ MODEL_PATHS = {
+     "MindBERT": "DrSyedFaizan/mindBERT",
+     "BERT-base": "bert-base-uncased",
+     "RoBERTa": "roberta-base",
+     "DistilBERT": "distilbert-base-uncased"
+ }
+
+ # Load Test Dataset (Example: Reddit Mental Health)
+ test_texts = [
+     "I feel so anxious and panicked all the time.",
+     "I'm feeling absolutely wonderful today!",
+     "I don't think I can go on anymore, I feel suicidal.",
+     "Lately, I have mood swings that I can't explain.",
+     "I feel so stressed out about everything."
+ ]
+ test_labels = [0, 3, 6, 1, 5]  # Anxiety, Normal, Suicidal, Bipolar, Stress
+
+ # Define column structure for leaderboard
+ @dataclass
+ class ModelEvalColumn:
+     name: str
+     type: str
+     displayed_by_default: bool = True
+     never_hidden: bool = False
+     hidden: bool = False
+
+ # Define the columns for your leaderboard
+ fields = lambda cls: [
+     ModelEvalColumn(name="model", type="str", never_hidden=True),
+     ModelEvalColumn(name="model_type", type="str"),
+     ModelEvalColumn(name="precision", type="str"),
+     ModelEvalColumn(name="params", type="number"),
+     ModelEvalColumn(name="accuracy", type="number"),
+     ModelEvalColumn(name="f1_score", type="number"),
+     ModelEvalColumn(name="inference_time", type="number"),
+     ModelEvalColumn(name="license", type="str", displayed_by_default=False),
+ ]
+
+ # Function to evaluate models and format for leaderboard
+ def evaluate_models():
+     results = []
+
+     # Model metadata (you would normally get this from model card or API)
+     model_metadata = {
+         "MindBERT": {"model_type": "BERT", "precision": "float16", "params": 0.11, "license": "MIT"},
+         "BERT-base": {"model_type": "BERT", "precision": "float16", "params": 0.11, "license": "Apache-2.0"},
+         "RoBERTa": {"model_type": "RoBERTa", "precision": "float16", "params": 0.125, "license": "MIT"},
+         "DistilBERT": {"model_type": "DistilBERT", "precision": "float16", "params": 0.067, "license": "Apache-2.0"}
+     }
+
+     for model_name, model_path in MODEL_PATHS.items():
+         print(f"Evaluating {model_name}...")
+         # Load Tokenizer and Model
+         tokenizer = AutoTokenizer.from_pretrained(model_path)
+         model = AutoModelForSequenceClassification.from_pretrained(model_path)
+         model.eval()
+
+         # Tokenize Test Data
+         inputs = tokenizer(test_texts, padding=True, truncation=True, return_tensors="pt")
+
+         # Measure Inference Time
+         start_time = time.time()
+         with torch.no_grad():
+             outputs = model(**inputs)
+             logits = outputs.logits
+             predictions = torch.argmax(logits, dim=1).numpy()
+         end_time = time.time()
+
+         # Compute Metrics
+         accuracy = accuracy_score(test_labels, predictions)
+         f1_score = f1_metric.compute(predictions=predictions, references=test_labels, average="macro")["f1"]
+         inference_time = round(end_time - start_time, 4)
+
+         # Store Results with additional metadata needed for leaderboard
+         result = {
+             "model": model_name,
+             "model_type": model_metadata[model_name]["model_type"],
+             "precision": model_metadata[model_name]["precision"],
+             "params": model_metadata[model_name]["params"],
+             "accuracy": round(accuracy, 4),
+             "f1_score": round(f1_score, 4),
+             "inference_time": inference_time,
+             "license": model_metadata[model_name]["license"]
+         }
+         results.append(result)
+
+     # Convert to DataFrame
+     df_results = pd.DataFrame(results)
+     return df_results
+
+ # Initialize leaderboard with custom columns
  def init_leaderboard(dataframe):
      if dataframe is None or dataframe.empty:
          raise ValueError("Leaderboard DataFrame is empty or None.")
+
+     columns = fields(ModelEvalColumn)
+
      return Leaderboard(
          value=dataframe,
+         datatype=[c.type for c in columns],
          select_columns=SelectColumns(
+             default_selection=[c.name for c in columns if c.displayed_by_default],
+             cant_deselect=[c.name for c in columns if c.never_hidden],
              label="Select Columns to Display:",
          ),
+         search_columns=["model", "license"],
+         hide_columns=[c.name for c in columns if c.hidden],
          filter_columns=[
+             ColumnFilter("model_type", type="checkboxgroup", label="Model types"),
+             ColumnFilter("precision", type="checkboxgroup", label="Precision"),
              ColumnFilter(
+                 "params",
                  type="slider",
                  min=0.01,
+                 max=0.5,
                  label="Select the number of parameters (B)",
              ),
          ],
          interactive=False,
      )

+ # Custom CSS similar to the original
+ custom_css = """
+ .markdown-text {
+     padding: 0 20px;
+ }
+ .tab-buttons button.selected {
+     background-color: #FF9C00 !important;
+     color: white !important;
+ }
+ """
+
+ # Create Gradio Interface
  demo = gr.Blocks(css=custom_css)

+ with demo:
+     gr.HTML("<h1>Mental Health Model Evaluation Benchmark</h1>")
+     gr.Markdown("This benchmark evaluates various transformer models on mental health classification tasks.", elem_classes="markdown-text")
+
      with gr.Tabs(elem_classes="tab-buttons") as tabs:
+         with gr.TabItem("🏅 Model Benchmark", elem_id="model-benchmark-tab", id=0):
+             # Get evaluation results
+             df_results = evaluate_models()
+             leaderboard = init_leaderboard(df_results)
+
+         with gr.TabItem("📝 About", elem_id="about-tab", id=1):
+             gr.Markdown("""
+             ## About This Benchmark
+
+             This leaderboard compares various transformer models on mental health text classification tasks.
+             The benchmark uses a test set from Reddit Mental Health datasets with examples covering anxiety,
+             depression, bipolar disorder, suicidal ideation, stress, and normal emotional states.
+
+             Models are evaluated on:
+             - Accuracy
+             - F1-Score (Macro)
+             - Inference Time
+
+             ### Model Types
+             - BERT-based models
+             - RoBERTa models
+             - DistilBERT models
+             - Specialized mental health models (MindBERT)
+             """, elem_classes="markdown-text")
+
+         with gr.TabItem("🚀 Submit Model", elem_id="submit-tab", id=2):
+             gr.Markdown("# ✉️✨ Submit your model here!", elem_classes="markdown-text")
+
              with gr.Row():
                  with gr.Column():
                      model_name_textbox = gr.Textbox(label="Model name")
+                     model_path_textbox = gr.Textbox(label="Model path (HF repo ID)")
                      model_type = gr.Dropdown(
+                         choices=["BERT", "RoBERTa", "DistilBERT", "GPT", "T5", "Other"],
                          label="Model type",
                          multiselect=False,
                          value=None,
                          interactive=True,
                      )
+
                  with gr.Column():
                      precision = gr.Dropdown(
+                         choices=["float16", "float32", "int8", "int4"],
                          label="Precision",
                          multiselect=False,
                          value="float16",
                          interactive=True,
                      )
+                     params = gr.Number(label="Parameters (billions)", value=0.11)
+                     license = gr.Textbox(label="License", value="Apache-2.0")
+
+             submit_button = gr.Button("Submit Model for Evaluation")
              submission_result = gr.Markdown()
+
+             # This would typically connect to a submission system
+             def handle_submission(model_name, model_path, model_type, precision, params, license):
+                 return f"Model {model_name} successfully submitted for evaluation. It will appear in the leaderboard once processing is complete."
+
              submit_button.click(
+                 handle_submission,
+                 [model_name_textbox, model_path_textbox, model_type, precision, params, license],
                  submission_result,
              )

      with gr.Row():
          with gr.Accordion("📙 Citation", open=False):
+             citation_text = """
+             @misc{mental-health-model-benchmark,
+               author = {Syed Faizan},
+               title = {Mental Health Model Benchmark},
+               year = {2025},
+               publisher = {GitHub},
+               url = {https://github.com/SYEDFAIZAN1987/mindBERT}
+             }
+             """
              citation_button = gr.Textbox(
+                 value=citation_text,
+                 label="Citation",
+                 lines=10,
                  elem_id="citation-button",
                  show_copy_button=True,
              )

+ demo.launch()
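The metric wiring used inside evaluate_models() can be sanity-checked in isolation, without downloading any of the models above. The snippet below is a minimal sketch, not part of app.py or this commit; it only relies on the evaluate and scikit-learn packages already imported above, and fake_predictions is a made-up output used purely for illustration.

# Standalone sanity check for the metric calls in evaluate_models() (sketch only).
import evaluate
from sklearn.metrics import accuracy_score

test_labels = [0, 3, 6, 1, 5]        # same toy references as in app.py
fake_predictions = [0, 3, 6, 1, 1]   # hypothetical model output, for illustration only

accuracy = accuracy_score(test_labels, fake_predictions)
f1_metric = evaluate.load("f1")
macro_f1 = f1_metric.compute(
    predictions=fake_predictions,
    references=test_labels,
    average="macro",
)["f1"]

print(f"accuracy={accuracy:.4f}, macro_f1={macro_f1:.4f}")
# accuracy is 4/5 here; the macro F1 is lower because the missed class
# (label 5, Stress) contributes an F1 of 0 to the macro average.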
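The leaderboard layout can likewise be smoke-tested with a hand-built DataFrame instead of live evaluation results. This is a sketch under the assumption that init_leaderboard and the ModelEvalColumn helpers above are moved into a hypothetical leaderboard_utils.py module (importing app.py directly would trigger the full model evaluation and launch); every value below is an arbitrary placeholder, not a measured result.

# Smoke test for the leaderboard wiring with stub results (sketch only).
import gradio as gr
import pandas as pd

from leaderboard_utils import init_leaderboard  # hypothetical module holding the helpers above

stub_results = pd.DataFrame([
    {"model": "MindBERT", "model_type": "BERT", "precision": "float16",
     "params": 0.11, "accuracy": 0.80, "f1_score": 0.73,
     "inference_time": 0.42, "license": "MIT"},
    {"model": "DistilBERT", "model_type": "DistilBERT", "precision": "float16",
     "params": 0.067, "accuracy": 0.20, "f1_score": 0.07,
     "inference_time": 0.21, "license": "Apache-2.0"},
])

# Column names mirror the ModelEvalColumn fields, so the select/search/filter
# configuration in init_leaderboard resolves without loading any model.
with gr.Blocks() as smoke_demo:
    init_leaderboard(stub_results)

smoke_demo.launch()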