Spaces:
Runtime error
Runtime error
shizue
commited on
Commit
·
c9ecc6b
0
Parent(s):
...
Browse files- .gitattributes +34 -0
- .gitignore +3 -0
- README.md +15 -0
- app.py +317 -0
- content.py +38 -0
- requirements.txt +5 -0
- scorer.py +101 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
test_scripts/
|
2 |
+
**/__pycache__/
|
3 |
+
scored/
|
README.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: STEM Leaderboard
|
3 |
+
emoji: 🔬💻🛠️➕
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: indigo
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.3.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
license: apache-2.0
|
11 |
+
tags:
|
12 |
+
- leaderboard
|
13 |
+
---
|
14 |
+
|
15 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import datetime
|
4 |
+
from email.utils import parseaddr
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import pandas as pd
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from datasets import load_dataset, DatasetDict
|
11 |
+
from apscheduler.schedulers.background import BackgroundScheduler
|
12 |
+
from huggingface_hub import HfApi
|
13 |
+
|
14 |
+
# InfoStrings
|
15 |
+
from scorer import question_scorer
|
16 |
+
from content import (
|
17 |
+
format_error,
|
18 |
+
format_warning,
|
19 |
+
format_log,
|
20 |
+
TITLE,
|
21 |
+
INTRODUCTION_TEXT,
|
22 |
+
CITATION_BUTTON_LABEL,
|
23 |
+
CITATION_BUTTON_TEXT,
|
24 |
+
model_hyperlink,
|
25 |
+
)
|
26 |
+
|
27 |
+
TOKEN = os.environ.get("TOKEN", None)
|
28 |
+
|
29 |
+
OWNER = "stemdataset"
|
30 |
+
INTERNAL_DATA_DATASET = f"{OWNER}/STEM-Labels-Private"
|
31 |
+
SUBMISSION_DATASET = f"{OWNER}/submissions_internal"
|
32 |
+
CONTACT_DATASET = f"{OWNER}/contact_info"
|
33 |
+
RESULTS_DATASET = f"{OWNER}/results"
|
34 |
+
LEADERBOARD_PATH = f"{OWNER}/stem-leaderboard"
|
35 |
+
api = HfApi()
|
36 |
+
|
37 |
+
os.makedirs("scored", exist_ok=True)
|
38 |
+
|
39 |
+
# Display the results
|
40 |
+
eval_results = load_dataset(
|
41 |
+
RESULTS_DATASET,
|
42 |
+
token=TOKEN,
|
43 |
+
download_mode="force_redownload",
|
44 |
+
verification_mode="no_checks",
|
45 |
+
)
|
46 |
+
contact_infos = load_dataset(
|
47 |
+
CONTACT_DATASET,
|
48 |
+
token=TOKEN,
|
49 |
+
download_mode="force_redownload",
|
50 |
+
verification_mode="no_checks",
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def get_dataframe_from_results(eval_results: DatasetDict, split):
|
55 |
+
local_df = eval_results[split]
|
56 |
+
local_df = local_df.map(
|
57 |
+
lambda row: {"model": model_hyperlink(row["url"], row["model"])}
|
58 |
+
)
|
59 |
+
local_df = local_df.remove_columns(["url"])
|
60 |
+
local_df = local_df.rename_column("model", "Model Name")
|
61 |
+
local_df = local_df.rename_column("model_family", "Model Family")
|
62 |
+
local_df = local_df.rename_column("average", "Average")
|
63 |
+
local_df = local_df.rename_column("science", "Science")
|
64 |
+
local_df = local_df.rename_column("technology", "Technology")
|
65 |
+
local_df = local_df.rename_column("engineering", "Engineering")
|
66 |
+
local_df = local_df.rename_column("math", "Math")
|
67 |
+
local_df = local_df.rename_column("organisation", "Organisation")
|
68 |
+
local_df = local_df.rename_column("submit_date", "Submit Date")
|
69 |
+
df = pd.DataFrame(local_df)
|
70 |
+
df = df[[
|
71 |
+
"Model Name",
|
72 |
+
"Model Family",
|
73 |
+
"Science",
|
74 |
+
"Technology",
|
75 |
+
"Engineering",
|
76 |
+
"Math",
|
77 |
+
"Average",
|
78 |
+
"Organisation",
|
79 |
+
"Submit Date",
|
80 |
+
]]
|
81 |
+
df = df.sort_values(by=["Average"], ascending=False)
|
82 |
+
|
83 |
+
numeric_cols = ["Science", "Technology", "Engineering", "Math", "Average"]
|
84 |
+
df[numeric_cols] = df[numeric_cols].round(decimals=1)
|
85 |
+
for col in numeric_cols:
|
86 |
+
df[col] = df[col].apply(lambda x: f"{x:.1f}")
|
87 |
+
return df
|
88 |
+
|
89 |
+
|
90 |
+
eval_dataframe_test = get_dataframe_from_results(
|
91 |
+
eval_results=eval_results, split="basic"
|
92 |
+
)
|
93 |
+
|
94 |
+
# Gold answers
|
95 |
+
gold_dataset = load_dataset(INTERNAL_DATA_DATASET, token=TOKEN)["labels"]
|
96 |
+
|
97 |
+
|
98 |
+
def restart_space():
|
99 |
+
api.restart_space(repo_id=LEADERBOARD_PATH, token=TOKEN)
|
100 |
+
|
101 |
+
|
102 |
+
TYPES = ["markdown", "number", "number", "number", "number", "str", "str"]
|
103 |
+
|
104 |
+
|
105 |
+
def calc_test_acc(preds: list[int]) -> dict[str, float]:
|
106 |
+
tmp_accs = {
|
107 |
+
"science": [0, 0],
|
108 |
+
"technology": [0, 0],
|
109 |
+
"engineer": [0, 0],
|
110 |
+
"math": [0, 0],
|
111 |
+
}
|
112 |
+
labels = gold_dataset
|
113 |
+
for pred, label in zip(preds, labels):
|
114 |
+
subject = label["subject"]
|
115 |
+
tmp_accs[subject][1] += 1
|
116 |
+
if pred == label["answer_idx"]:
|
117 |
+
tmp_accs[subject][0] += 1
|
118 |
+
accs = {k: v[0] / v[1] for k, v in tmp_accs.items()}
|
119 |
+
accs["average"] = np.mean(list(accs.values()))
|
120 |
+
accs = {k: round(v * 100, 1) for k, v in accs.items()}
|
121 |
+
return accs
|
122 |
+
|
123 |
+
|
124 |
+
def add_new_eval(
|
125 |
+
val_or_test: str,
|
126 |
+
model: str,
|
127 |
+
model_family: str,
|
128 |
+
url: str,
|
129 |
+
path_to_file: gr.File,
|
130 |
+
organisation: str,
|
131 |
+
mail: str,
|
132 |
+
):
|
133 |
+
curr_timestamp = datetime.datetime.today()
|
134 |
+
# Very basic email parsing
|
135 |
+
_, parsed_mail = parseaddr(mail)
|
136 |
+
if not "@" in parsed_mail:
|
137 |
+
return format_warning("Please provide a valid email adress.")
|
138 |
+
if model == "":
|
139 |
+
return format_warning("Please provide a model name.")
|
140 |
+
if model_family == "":
|
141 |
+
return format_warning("Please provide a model family.")
|
142 |
+
print(
|
143 |
+
json.dumps(
|
144 |
+
{
|
145 |
+
"val_or_test": val_or_test,
|
146 |
+
"model": model,
|
147 |
+
"model_family": model_family,
|
148 |
+
"url": url,
|
149 |
+
"path_to_file": path_to_file,
|
150 |
+
"organisation": organisation,
|
151 |
+
"mail": mail,
|
152 |
+
},
|
153 |
+
indent=2,
|
154 |
+
)
|
155 |
+
)
|
156 |
+
|
157 |
+
print("Adding new eval")
|
158 |
+
|
159 |
+
# Check if the combination model/org already exists and prints a warning message if yes
|
160 |
+
if model.lower() in set(
|
161 |
+
[m.lower() for m in eval_results["basic"]["model"]]
|
162 |
+
) and organisation.lower() in set(
|
163 |
+
[l.lower() for l in eval_results["basic"]["organisation"]]
|
164 |
+
):
|
165 |
+
return format_warning("This model has been already submitted.")
|
166 |
+
|
167 |
+
if path_to_file is None:
|
168 |
+
return format_warning("Please attach a file.")
|
169 |
+
|
170 |
+
# Save submitted file
|
171 |
+
api.upload_file(
|
172 |
+
repo_id=SUBMISSION_DATASET,
|
173 |
+
path_or_fileobj=path_to_file.name,
|
174 |
+
path_in_repo=f"{organisation}/{model}/{val_or_test}_raw_{curr_timestamp}.txt",
|
175 |
+
repo_type="dataset",
|
176 |
+
token=TOKEN,
|
177 |
+
)
|
178 |
+
|
179 |
+
# Compute score
|
180 |
+
file_path = path_to_file.name
|
181 |
+
with open(f"scored/{organisation}_{model}.json", "w") as scored_file:
|
182 |
+
with open(file_path, "r") as f:
|
183 |
+
preds = []
|
184 |
+
for ix, line in enumerate(f):
|
185 |
+
try:
|
186 |
+
pred_idx = int(line.strip())
|
187 |
+
except Exception:
|
188 |
+
return format_error(
|
189 |
+
f"Line {ix} is incorrectly formatted. Please fix it and resubmit your file."
|
190 |
+
)
|
191 |
+
preds.append(pred_idx)
|
192 |
+
stem_scores = calc_test_acc(preds)
|
193 |
+
scored_file.write(json.dumps(stem_scores, indent=2))
|
194 |
+
|
195 |
+
# Save scored file
|
196 |
+
api.upload_file(
|
197 |
+
repo_id=SUBMISSION_DATASET,
|
198 |
+
path_or_fileobj=f"scored/{organisation}_{model}.json",
|
199 |
+
path_in_repo=f"{organisation}/{model}/{val_or_test}_scored_{curr_timestamp}.json",
|
200 |
+
repo_type="dataset",
|
201 |
+
token=TOKEN,
|
202 |
+
)
|
203 |
+
|
204 |
+
# Actual submission
|
205 |
+
eval_entry = {
|
206 |
+
"model": model,
|
207 |
+
"model_family": model_family,
|
208 |
+
"url": url,
|
209 |
+
"organisation": organisation,
|
210 |
+
"submit_date": "\n".join(str(curr_timestamp).split(" ")),
|
211 |
+
"science": stem_scores["science"],
|
212 |
+
"technology": stem_scores["technology"],
|
213 |
+
"engineering": stem_scores["engineer"],
|
214 |
+
"math": stem_scores["math"],
|
215 |
+
"average": stem_scores["average"],
|
216 |
+
}
|
217 |
+
eval_results["basic"] = eval_results["basic"].add_item(eval_entry)
|
218 |
+
print(eval_results)
|
219 |
+
eval_results.push_to_hub(RESULTS_DATASET, token=TOKEN)
|
220 |
+
|
221 |
+
contact_info = {
|
222 |
+
"model": model,
|
223 |
+
"model_family": model_family,
|
224 |
+
"url": url,
|
225 |
+
"organisation": organisation,
|
226 |
+
"mail": mail,
|
227 |
+
"submit_date": "\n".join(str(curr_timestamp).split(" ")),
|
228 |
+
}
|
229 |
+
contact_infos["basic"] = contact_infos["basic"].add_item(contact_info)
|
230 |
+
contact_infos.push_to_hub(CONTACT_DATASET, token=TOKEN)
|
231 |
+
|
232 |
+
return format_log(
|
233 |
+
f"Model {model} submitted by {organisation} successfully. \nPlease refresh the leaderboard, and wait a bit to see the score displayed"
|
234 |
+
)
|
235 |
+
|
236 |
+
|
237 |
+
def refresh():
|
238 |
+
eval_results = load_dataset(
|
239 |
+
RESULTS_DATASET,
|
240 |
+
token=TOKEN,
|
241 |
+
download_mode="force_redownload",
|
242 |
+
verification_mode="no_checks",
|
243 |
+
)
|
244 |
+
eval_dataframe_test = get_dataframe_from_results(
|
245 |
+
eval_results=eval_results, split="basic"
|
246 |
+
)
|
247 |
+
return eval_dataframe_test
|
248 |
+
|
249 |
+
|
250 |
+
def upload_file(files):
|
251 |
+
file_paths = [file.name for file in files]
|
252 |
+
return file_paths
|
253 |
+
|
254 |
+
|
255 |
+
demo = gr.Blocks()
|
256 |
+
with demo:
|
257 |
+
gr.HTML(TITLE)
|
258 |
+
gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
|
259 |
+
|
260 |
+
with gr.Row():
|
261 |
+
with gr.Accordion("📙 Citation", open=False):
|
262 |
+
citation_button = gr.Textbox(
|
263 |
+
value=CITATION_BUTTON_TEXT,
|
264 |
+
label=CITATION_BUTTON_LABEL,
|
265 |
+
elem_id="citation-button",
|
266 |
+
)
|
267 |
+
|
268 |
+
with gr.Tab("Results: Test"):
|
269 |
+
leaderboard_table_test = gr.components.Dataframe(
|
270 |
+
value=eval_dataframe_test,
|
271 |
+
datatype=TYPES,
|
272 |
+
interactive=False,
|
273 |
+
wrap=True,
|
274 |
+
)
|
275 |
+
|
276 |
+
refresh_button = gr.Button("Refresh")
|
277 |
+
refresh_button.click(
|
278 |
+
refresh,
|
279 |
+
inputs=[],
|
280 |
+
outputs=[
|
281 |
+
leaderboard_table_test,
|
282 |
+
],
|
283 |
+
)
|
284 |
+
with gr.Accordion("Submit a new model for evaluation"):
|
285 |
+
with gr.Row():
|
286 |
+
with gr.Column():
|
287 |
+
level_of_test = gr.Radio(["test"], value="test", label="Split")
|
288 |
+
model_name_textbox = gr.Textbox(label="Model name")
|
289 |
+
model_family_textbox = gr.Textbox(label="Model family")
|
290 |
+
url_textbox = gr.Textbox(label="Url to model information")
|
291 |
+
with gr.Column():
|
292 |
+
organisation = gr.Textbox(label="Organisation")
|
293 |
+
mail = gr.Textbox(
|
294 |
+
label="Contact email (will be stored privately, & used if there is an issue with your submission)"
|
295 |
+
)
|
296 |
+
file_output = gr.File()
|
297 |
+
|
298 |
+
submit_button = gr.Button("Submit Eval")
|
299 |
+
submission_result = gr.Markdown()
|
300 |
+
submit_button.click(
|
301 |
+
add_new_eval,
|
302 |
+
[
|
303 |
+
level_of_test,
|
304 |
+
model_name_textbox,
|
305 |
+
model_family_textbox,
|
306 |
+
url_textbox,
|
307 |
+
file_output,
|
308 |
+
organisation,
|
309 |
+
mail,
|
310 |
+
],
|
311 |
+
submission_result,
|
312 |
+
)
|
313 |
+
|
314 |
+
scheduler = BackgroundScheduler()
|
315 |
+
scheduler.add_job(restart_space, "interval", seconds=3600)
|
316 |
+
scheduler.start()
|
317 |
+
demo.launch(debug=True, server_name="0.0.0.0")
|
content.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TITLE = """<h1 align="center" id="space-title">STEM Leaderboard</h1>"""
|
2 |
+
|
3 |
+
INTRODUCTION_TEXT = """
|
4 |
+
<p align="center">
|
5 |
+
📃 <a href="https://arxiv.org/abs/2402.17205" target="_blank">[Paper]</a> • 💻 <a href="https://github.com/stemdataset/STEM" target="_blank">[Github]</a> • 🤗 <a href="https://huggingface.co/datasets/stemdataset/STEM" target="_blank">[Dataset]</a> • 🏆 <a href="https://huggingface.co/spaces/stemdataset/stem-leaderboard" target="_blank">[Leaderboard]</a> • 📽 <a href="" target="_blank">[Slides]</a> • 📋 <a href="" target="_blank">[Poster]</a>
|
6 |
+
</p>
|
7 |
+
|
8 |
+
## Overview
|
9 |
+
|
10 |
+
This dataset is proposed in the ICLR 2024 paper: [Measuring Vision-Language STEM Skills of Neural Models](https://arxiv.org/abs/2402.17205). The problems in the real world often require solutions, combining knowledge from STEM (science, technology, engineering, and math). Unlike existing datasets, our dataset requires the understanding of multimodal vision-language information of STEM. Our dataset features one of the largest and most comprehensive datasets for the challenge. It includes 448 skills and 1,073,146 questions spanning all STEM subjects. Compared to existing datasets that often focus on examining expert-level ability, our dataset includes fundamental skills and questions designed based on the K-12 curriculum. We also add state-of-the-art foundation models such as CLIP and GPT-3.5-Turbo to our benchmark. Results show that the recent model advances only help master a very limited number of lower grade-level skills (2.5% in the third grade) in our dataset. In fact, these models are still well below (averaging 54.7%) the performance of elementary students, not to mention near expert-level performance. To understand and increase the performance on our dataset, we teach the models on a training split of our dataset. Even though we observe improved performance, the model performance remains relatively low compared to average elementary students. To solve STEM problems, we will need novel algorithmic innovations from the community.
|
11 |
+
|
12 |
+
## Submissions
|
13 |
+
Results can be submitted for test set, where the labels are hidden. Scores are expressed as the percentage of correct answers for a given split.
|
14 |
+
|
15 |
+
We expect submissions to contain a single `answer_idx` in a line for each question in the test set. Note that the order should be the same with the `test` split in the dataset.
|
16 |
+
"""
|
17 |
+
|
18 |
+
CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"
|
19 |
+
CITATION_BUTTON_TEXT = r"""@article{shen2024measuring,
|
20 |
+
title={Measuring Vision-Language STEM Skills of Neural Models},
|
21 |
+
author={Shen, Jianhao and Yuan, Ye and Mirzoyan, Srbuhi and Zhang, Ming and Wang, Chenguang},
|
22 |
+
journal={ICLR 2024},
|
23 |
+
year={2024}
|
24 |
+
}"""
|
25 |
+
|
26 |
+
|
27 |
+
def format_error(msg):
|
28 |
+
return f"<p style='color: red; font-size: 20px; text-align: center;'>{msg}</p>"
|
29 |
+
|
30 |
+
def format_warning(msg):
|
31 |
+
return f"<p style='color: orange; font-size: 20px; text-align: center;'>{msg}</p>"
|
32 |
+
|
33 |
+
def format_log(msg):
|
34 |
+
return f"<p style='color: green; font-size: 20px; text-align: center;'>{msg}</p>"
|
35 |
+
|
36 |
+
def model_hyperlink(link, model_name):
|
37 |
+
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
38 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets==2.14.5
|
2 |
+
gradio==4.3.0
|
3 |
+
huggingface-hub==0.18.0
|
4 |
+
numpy==1.24.2
|
5 |
+
APScheduler==3.10.1
|
scorer.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
import string
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
def normalize_number_str(number_str: str) -> float:
|
10 |
+
# we replace these common units and commas to allow
|
11 |
+
# conversion to float
|
12 |
+
for char in ["$", "%", ","]:
|
13 |
+
number_str = number_str.replace(char, "")
|
14 |
+
try:
|
15 |
+
return float(number_str)
|
16 |
+
except ValueError:
|
17 |
+
print(f"String {number_str} cannot be normalized to number str.")
|
18 |
+
return float("inf")
|
19 |
+
|
20 |
+
|
21 |
+
def split_string(
|
22 |
+
s: str,
|
23 |
+
char_list: list[str] = [",", ";"],
|
24 |
+
) -> list[str]:
|
25 |
+
pattern = f"[{''.join(char_list)}]"
|
26 |
+
return re.split(pattern, s)
|
27 |
+
|
28 |
+
|
29 |
+
def question_scorer(
|
30 |
+
model_answer: str,
|
31 |
+
ground_truth: str,
|
32 |
+
) -> bool:
|
33 |
+
def is_float(element: any) -> bool:
|
34 |
+
try:
|
35 |
+
float(element)
|
36 |
+
return True
|
37 |
+
except ValueError:
|
38 |
+
return False
|
39 |
+
|
40 |
+
# if gt is a number
|
41 |
+
if is_float(ground_truth):
|
42 |
+
print(f"Evaluating {model_answer} as a number.")
|
43 |
+
normalized_answer = normalize_number_str(model_answer)
|
44 |
+
return normalized_answer == float(ground_truth)
|
45 |
+
|
46 |
+
# if gt is a list
|
47 |
+
elif any(char in ground_truth for char in [",", ";"]):
|
48 |
+
print(f"Evaluating {model_answer} as a comma separated list.")
|
49 |
+
# question with the fish: normalization removes punct
|
50 |
+
|
51 |
+
gt_elems = split_string(ground_truth)
|
52 |
+
ma_elems = split_string(model_answer)
|
53 |
+
|
54 |
+
# check length is the same
|
55 |
+
if len(gt_elems) != len(ma_elems):
|
56 |
+
warnings.warn(
|
57 |
+
"Answer lists have different lengths, returning False.", UserWarning
|
58 |
+
)
|
59 |
+
return False
|
60 |
+
|
61 |
+
# compare each element as float or str
|
62 |
+
comparisons = []
|
63 |
+
for ma_elem, gt_elem in zip(ma_elems, gt_elems):
|
64 |
+
if is_float(gt_elem):
|
65 |
+
normalized_ma_elem = normalize_number_str(ma_elem)
|
66 |
+
comparisons.append(normalized_ma_elem == float(gt_elem))
|
67 |
+
else:
|
68 |
+
# we do not remove punct since comparisons can include punct
|
69 |
+
comparisons.append(
|
70 |
+
normalize_str(ma_elem, remove_punct=False)
|
71 |
+
== normalize_str(gt_elem, remove_punct=False)
|
72 |
+
)
|
73 |
+
return all(comparisons)
|
74 |
+
|
75 |
+
# if gt is a str
|
76 |
+
else:
|
77 |
+
print(f"Evaluating {model_answer} as a string.")
|
78 |
+
return normalize_str(model_answer) == normalize_str(ground_truth)
|
79 |
+
|
80 |
+
|
81 |
+
def normalize_str(input_str, remove_punct=True) -> str:
|
82 |
+
"""
|
83 |
+
Normalize a string by:
|
84 |
+
- Removing all white spaces
|
85 |
+
- Optionally removing punctuation (if remove_punct is True)
|
86 |
+
- Converting to lowercase
|
87 |
+
Parameters:
|
88 |
+
- input_str: str, the string to normalize
|
89 |
+
- remove_punct: bool, whether to remove punctuation (default: True)
|
90 |
+
Returns:
|
91 |
+
- str, the normalized string
|
92 |
+
"""
|
93 |
+
# Remove all white spaces. Required e.g for seagull vs. sea gull
|
94 |
+
no_spaces = re.sub(r"\s", "", input_str)
|
95 |
+
|
96 |
+
# Remove punctuation, if specified.
|
97 |
+
if remove_punct:
|
98 |
+
translator = str.maketrans("", "", string.punctuation)
|
99 |
+
return no_spaces.lower().translate(translator)
|
100 |
+
else:
|
101 |
+
return no_spaces.lower()
|