shizue commited on
Commit
c9ecc6b
·
0 Parent(s):
Files changed (7) hide show
  1. .gitattributes +34 -0
  2. .gitignore +3 -0
  3. README.md +15 -0
  4. app.py +317 -0
  5. content.py +38 -0
  6. requirements.txt +5 -0
  7. scorer.py +101 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ test_scripts/
2
+ **/__pycache__/
3
+ scored/
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: STEM Leaderboard
3
+ emoji: 🔬💻🛠️➕
4
+ colorFrom: yellow
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 4.3.0
8
+ app_file: app.py
9
+ pinned: true
10
+ license: apache-2.0
11
+ tags:
12
+ - leaderboard
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import datetime
4
+ from email.utils import parseaddr
5
+
6
+ import gradio as gr
7
+ import pandas as pd
8
+ import numpy as np
9
+
10
+ from datasets import load_dataset, DatasetDict
11
+ from apscheduler.schedulers.background import BackgroundScheduler
12
+ from huggingface_hub import HfApi
13
+
14
+ # InfoStrings
15
+ from scorer import question_scorer
16
+ from content import (
17
+ format_error,
18
+ format_warning,
19
+ format_log,
20
+ TITLE,
21
+ INTRODUCTION_TEXT,
22
+ CITATION_BUTTON_LABEL,
23
+ CITATION_BUTTON_TEXT,
24
+ model_hyperlink,
25
+ )
26
+
27
+ TOKEN = os.environ.get("TOKEN", None)
28
+
29
+ OWNER = "stemdataset"
30
+ INTERNAL_DATA_DATASET = f"{OWNER}/STEM-Labels-Private"
31
+ SUBMISSION_DATASET = f"{OWNER}/submissions_internal"
32
+ CONTACT_DATASET = f"{OWNER}/contact_info"
33
+ RESULTS_DATASET = f"{OWNER}/results"
34
+ LEADERBOARD_PATH = f"{OWNER}/stem-leaderboard"
35
+ api = HfApi()
36
+
37
+ os.makedirs("scored", exist_ok=True)
38
+
39
+ # Display the results
40
+ eval_results = load_dataset(
41
+ RESULTS_DATASET,
42
+ token=TOKEN,
43
+ download_mode="force_redownload",
44
+ verification_mode="no_checks",
45
+ )
46
+ contact_infos = load_dataset(
47
+ CONTACT_DATASET,
48
+ token=TOKEN,
49
+ download_mode="force_redownload",
50
+ verification_mode="no_checks",
51
+ )
52
+
53
+
54
+ def get_dataframe_from_results(eval_results: DatasetDict, split):
55
+ local_df = eval_results[split]
56
+ local_df = local_df.map(
57
+ lambda row: {"model": model_hyperlink(row["url"], row["model"])}
58
+ )
59
+ local_df = local_df.remove_columns(["url"])
60
+ local_df = local_df.rename_column("model", "Model Name")
61
+ local_df = local_df.rename_column("model_family", "Model Family")
62
+ local_df = local_df.rename_column("average", "Average")
63
+ local_df = local_df.rename_column("science", "Science")
64
+ local_df = local_df.rename_column("technology", "Technology")
65
+ local_df = local_df.rename_column("engineering", "Engineering")
66
+ local_df = local_df.rename_column("math", "Math")
67
+ local_df = local_df.rename_column("organisation", "Organisation")
68
+ local_df = local_df.rename_column("submit_date", "Submit Date")
69
+ df = pd.DataFrame(local_df)
70
+ df = df[[
71
+ "Model Name",
72
+ "Model Family",
73
+ "Science",
74
+ "Technology",
75
+ "Engineering",
76
+ "Math",
77
+ "Average",
78
+ "Organisation",
79
+ "Submit Date",
80
+ ]]
81
+ df = df.sort_values(by=["Average"], ascending=False)
82
+
83
+ numeric_cols = ["Science", "Technology", "Engineering", "Math", "Average"]
84
+ df[numeric_cols] = df[numeric_cols].round(decimals=1)
85
+ for col in numeric_cols:
86
+ df[col] = df[col].apply(lambda x: f"{x:.1f}")
87
+ return df
88
+
89
+
90
+ eval_dataframe_test = get_dataframe_from_results(
91
+ eval_results=eval_results, split="basic"
92
+ )
93
+
94
+ # Gold answers
95
+ gold_dataset = load_dataset(INTERNAL_DATA_DATASET, token=TOKEN)["labels"]
96
+
97
+
98
+ def restart_space():
99
+ api.restart_space(repo_id=LEADERBOARD_PATH, token=TOKEN)
100
+
101
+
102
+ TYPES = ["markdown", "number", "number", "number", "number", "str", "str"]
103
+
104
+
105
+ def calc_test_acc(preds: list[int]) -> dict[str, float]:
106
+ tmp_accs = {
107
+ "science": [0, 0],
108
+ "technology": [0, 0],
109
+ "engineer": [0, 0],
110
+ "math": [0, 0],
111
+ }
112
+ labels = gold_dataset
113
+ for pred, label in zip(preds, labels):
114
+ subject = label["subject"]
115
+ tmp_accs[subject][1] += 1
116
+ if pred == label["answer_idx"]:
117
+ tmp_accs[subject][0] += 1
118
+ accs = {k: v[0] / v[1] for k, v in tmp_accs.items()}
119
+ accs["average"] = np.mean(list(accs.values()))
120
+ accs = {k: round(v * 100, 1) for k, v in accs.items()}
121
+ return accs
122
+
123
+
124
+ def add_new_eval(
125
+ val_or_test: str,
126
+ model: str,
127
+ model_family: str,
128
+ url: str,
129
+ path_to_file: gr.File,
130
+ organisation: str,
131
+ mail: str,
132
+ ):
133
+ curr_timestamp = datetime.datetime.today()
134
+ # Very basic email parsing
135
+ _, parsed_mail = parseaddr(mail)
136
+ if not "@" in parsed_mail:
137
+ return format_warning("Please provide a valid email adress.")
138
+ if model == "":
139
+ return format_warning("Please provide a model name.")
140
+ if model_family == "":
141
+ return format_warning("Please provide a model family.")
142
+ print(
143
+ json.dumps(
144
+ {
145
+ "val_or_test": val_or_test,
146
+ "model": model,
147
+ "model_family": model_family,
148
+ "url": url,
149
+ "path_to_file": path_to_file,
150
+ "organisation": organisation,
151
+ "mail": mail,
152
+ },
153
+ indent=2,
154
+ )
155
+ )
156
+
157
+ print("Adding new eval")
158
+
159
+ # Check if the combination model/org already exists and prints a warning message if yes
160
+ if model.lower() in set(
161
+ [m.lower() for m in eval_results["basic"]["model"]]
162
+ ) and organisation.lower() in set(
163
+ [l.lower() for l in eval_results["basic"]["organisation"]]
164
+ ):
165
+ return format_warning("This model has been already submitted.")
166
+
167
+ if path_to_file is None:
168
+ return format_warning("Please attach a file.")
169
+
170
+ # Save submitted file
171
+ api.upload_file(
172
+ repo_id=SUBMISSION_DATASET,
173
+ path_or_fileobj=path_to_file.name,
174
+ path_in_repo=f"{organisation}/{model}/{val_or_test}_raw_{curr_timestamp}.txt",
175
+ repo_type="dataset",
176
+ token=TOKEN,
177
+ )
178
+
179
+ # Compute score
180
+ file_path = path_to_file.name
181
+ with open(f"scored/{organisation}_{model}.json", "w") as scored_file:
182
+ with open(file_path, "r") as f:
183
+ preds = []
184
+ for ix, line in enumerate(f):
185
+ try:
186
+ pred_idx = int(line.strip())
187
+ except Exception:
188
+ return format_error(
189
+ f"Line {ix} is incorrectly formatted. Please fix it and resubmit your file."
190
+ )
191
+ preds.append(pred_idx)
192
+ stem_scores = calc_test_acc(preds)
193
+ scored_file.write(json.dumps(stem_scores, indent=2))
194
+
195
+ # Save scored file
196
+ api.upload_file(
197
+ repo_id=SUBMISSION_DATASET,
198
+ path_or_fileobj=f"scored/{organisation}_{model}.json",
199
+ path_in_repo=f"{organisation}/{model}/{val_or_test}_scored_{curr_timestamp}.json",
200
+ repo_type="dataset",
201
+ token=TOKEN,
202
+ )
203
+
204
+ # Actual submission
205
+ eval_entry = {
206
+ "model": model,
207
+ "model_family": model_family,
208
+ "url": url,
209
+ "organisation": organisation,
210
+ "submit_date": "\n".join(str(curr_timestamp).split(" ")),
211
+ "science": stem_scores["science"],
212
+ "technology": stem_scores["technology"],
213
+ "engineering": stem_scores["engineer"],
214
+ "math": stem_scores["math"],
215
+ "average": stem_scores["average"],
216
+ }
217
+ eval_results["basic"] = eval_results["basic"].add_item(eval_entry)
218
+ print(eval_results)
219
+ eval_results.push_to_hub(RESULTS_DATASET, token=TOKEN)
220
+
221
+ contact_info = {
222
+ "model": model,
223
+ "model_family": model_family,
224
+ "url": url,
225
+ "organisation": organisation,
226
+ "mail": mail,
227
+ "submit_date": "\n".join(str(curr_timestamp).split(" ")),
228
+ }
229
+ contact_infos["basic"] = contact_infos["basic"].add_item(contact_info)
230
+ contact_infos.push_to_hub(CONTACT_DATASET, token=TOKEN)
231
+
232
+ return format_log(
233
+ f"Model {model} submitted by {organisation} successfully. \nPlease refresh the leaderboard, and wait a bit to see the score displayed"
234
+ )
235
+
236
+
237
+ def refresh():
238
+ eval_results = load_dataset(
239
+ RESULTS_DATASET,
240
+ token=TOKEN,
241
+ download_mode="force_redownload",
242
+ verification_mode="no_checks",
243
+ )
244
+ eval_dataframe_test = get_dataframe_from_results(
245
+ eval_results=eval_results, split="basic"
246
+ )
247
+ return eval_dataframe_test
248
+
249
+
250
+ def upload_file(files):
251
+ file_paths = [file.name for file in files]
252
+ return file_paths
253
+
254
+
255
+ demo = gr.Blocks()
256
+ with demo:
257
+ gr.HTML(TITLE)
258
+ gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
259
+
260
+ with gr.Row():
261
+ with gr.Accordion("📙 Citation", open=False):
262
+ citation_button = gr.Textbox(
263
+ value=CITATION_BUTTON_TEXT,
264
+ label=CITATION_BUTTON_LABEL,
265
+ elem_id="citation-button",
266
+ )
267
+
268
+ with gr.Tab("Results: Test"):
269
+ leaderboard_table_test = gr.components.Dataframe(
270
+ value=eval_dataframe_test,
271
+ datatype=TYPES,
272
+ interactive=False,
273
+ wrap=True,
274
+ )
275
+
276
+ refresh_button = gr.Button("Refresh")
277
+ refresh_button.click(
278
+ refresh,
279
+ inputs=[],
280
+ outputs=[
281
+ leaderboard_table_test,
282
+ ],
283
+ )
284
+ with gr.Accordion("Submit a new model for evaluation"):
285
+ with gr.Row():
286
+ with gr.Column():
287
+ level_of_test = gr.Radio(["test"], value="test", label="Split")
288
+ model_name_textbox = gr.Textbox(label="Model name")
289
+ model_family_textbox = gr.Textbox(label="Model family")
290
+ url_textbox = gr.Textbox(label="Url to model information")
291
+ with gr.Column():
292
+ organisation = gr.Textbox(label="Organisation")
293
+ mail = gr.Textbox(
294
+ label="Contact email (will be stored privately, & used if there is an issue with your submission)"
295
+ )
296
+ file_output = gr.File()
297
+
298
+ submit_button = gr.Button("Submit Eval")
299
+ submission_result = gr.Markdown()
300
+ submit_button.click(
301
+ add_new_eval,
302
+ [
303
+ level_of_test,
304
+ model_name_textbox,
305
+ model_family_textbox,
306
+ url_textbox,
307
+ file_output,
308
+ organisation,
309
+ mail,
310
+ ],
311
+ submission_result,
312
+ )
313
+
314
+ scheduler = BackgroundScheduler()
315
+ scheduler.add_job(restart_space, "interval", seconds=3600)
316
+ scheduler.start()
317
+ demo.launch(debug=True, server_name="0.0.0.0")
content.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TITLE = """<h1 align="center" id="space-title">STEM Leaderboard</h1>"""
2
+
3
+ INTRODUCTION_TEXT = """
4
+ <p align="center">
5
+ 📃 <a href="https://arxiv.org/abs/2402.17205" target="_blank">[Paper]</a> • 💻 <a href="https://github.com/stemdataset/STEM" target="_blank">[Github]</a> • 🤗 <a href="https://huggingface.co/datasets/stemdataset/STEM" target="_blank">[Dataset]</a> • 🏆 <a href="https://huggingface.co/spaces/stemdataset/stem-leaderboard" target="_blank">[Leaderboard]</a> • 📽 <a href="" target="_blank">[Slides]</a> • 📋 <a href="" target="_blank">[Poster]</a>
6
+ </p>
7
+
8
+ ## Overview
9
+
10
+ This dataset is proposed in the ICLR 2024 paper: [Measuring Vision-Language STEM Skills of Neural Models](https://arxiv.org/abs/2402.17205). The problems in the real world often require solutions, combining knowledge from STEM (science, technology, engineering, and math). Unlike existing datasets, our dataset requires the understanding of multimodal vision-language information of STEM. Our dataset features one of the largest and most comprehensive datasets for the challenge. It includes 448 skills and 1,073,146 questions spanning all STEM subjects. Compared to existing datasets that often focus on examining expert-level ability, our dataset includes fundamental skills and questions designed based on the K-12 curriculum. We also add state-of-the-art foundation models such as CLIP and GPT-3.5-Turbo to our benchmark. Results show that the recent model advances only help master a very limited number of lower grade-level skills (2.5% in the third grade) in our dataset. In fact, these models are still well below (averaging 54.7%) the performance of elementary students, not to mention near expert-level performance. To understand and increase the performance on our dataset, we teach the models on a training split of our dataset. Even though we observe improved performance, the model performance remains relatively low compared to average elementary students. To solve STEM problems, we will need novel algorithmic innovations from the community.
11
+
12
+ ## Submissions
13
+ Results can be submitted for test set, where the labels are hidden. Scores are expressed as the percentage of correct answers for a given split.
14
+
15
+ We expect submissions to contain a single `answer_idx` in a line for each question in the test set. Note that the order should be the same with the `test` split in the dataset.
16
+ """
17
+
18
+ CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"
19
+ CITATION_BUTTON_TEXT = r"""@article{shen2024measuring,
20
+ title={Measuring Vision-Language STEM Skills of Neural Models},
21
+ author={Shen, Jianhao and Yuan, Ye and Mirzoyan, Srbuhi and Zhang, Ming and Wang, Chenguang},
22
+ journal={ICLR 2024},
23
+ year={2024}
24
+ }"""
25
+
26
+
27
+ def format_error(msg):
28
+ return f"<p style='color: red; font-size: 20px; text-align: center;'>{msg}</p>"
29
+
30
+ def format_warning(msg):
31
+ return f"<p style='color: orange; font-size: 20px; text-align: center;'>{msg}</p>"
32
+
33
+ def format_log(msg):
34
+ return f"<p style='color: green; font-size: 20px; text-align: center;'>{msg}</p>"
35
+
36
+ def model_hyperlink(link, model_name):
37
+ return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
38
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets==2.14.5
2
+ gradio==4.3.0
3
+ huggingface-hub==0.18.0
4
+ numpy==1.24.2
5
+ APScheduler==3.10.1
scorer.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import string
4
+ import warnings
5
+
6
+ import numpy as np
7
+
8
+
9
+ def normalize_number_str(number_str: str) -> float:
10
+ # we replace these common units and commas to allow
11
+ # conversion to float
12
+ for char in ["$", "%", ","]:
13
+ number_str = number_str.replace(char, "")
14
+ try:
15
+ return float(number_str)
16
+ except ValueError:
17
+ print(f"String {number_str} cannot be normalized to number str.")
18
+ return float("inf")
19
+
20
+
21
+ def split_string(
22
+ s: str,
23
+ char_list: list[str] = [",", ";"],
24
+ ) -> list[str]:
25
+ pattern = f"[{''.join(char_list)}]"
26
+ return re.split(pattern, s)
27
+
28
+
29
+ def question_scorer(
30
+ model_answer: str,
31
+ ground_truth: str,
32
+ ) -> bool:
33
+ def is_float(element: any) -> bool:
34
+ try:
35
+ float(element)
36
+ return True
37
+ except ValueError:
38
+ return False
39
+
40
+ # if gt is a number
41
+ if is_float(ground_truth):
42
+ print(f"Evaluating {model_answer} as a number.")
43
+ normalized_answer = normalize_number_str(model_answer)
44
+ return normalized_answer == float(ground_truth)
45
+
46
+ # if gt is a list
47
+ elif any(char in ground_truth for char in [",", ";"]):
48
+ print(f"Evaluating {model_answer} as a comma separated list.")
49
+ # question with the fish: normalization removes punct
50
+
51
+ gt_elems = split_string(ground_truth)
52
+ ma_elems = split_string(model_answer)
53
+
54
+ # check length is the same
55
+ if len(gt_elems) != len(ma_elems):
56
+ warnings.warn(
57
+ "Answer lists have different lengths, returning False.", UserWarning
58
+ )
59
+ return False
60
+
61
+ # compare each element as float or str
62
+ comparisons = []
63
+ for ma_elem, gt_elem in zip(ma_elems, gt_elems):
64
+ if is_float(gt_elem):
65
+ normalized_ma_elem = normalize_number_str(ma_elem)
66
+ comparisons.append(normalized_ma_elem == float(gt_elem))
67
+ else:
68
+ # we do not remove punct since comparisons can include punct
69
+ comparisons.append(
70
+ normalize_str(ma_elem, remove_punct=False)
71
+ == normalize_str(gt_elem, remove_punct=False)
72
+ )
73
+ return all(comparisons)
74
+
75
+ # if gt is a str
76
+ else:
77
+ print(f"Evaluating {model_answer} as a string.")
78
+ return normalize_str(model_answer) == normalize_str(ground_truth)
79
+
80
+
81
+ def normalize_str(input_str, remove_punct=True) -> str:
82
+ """
83
+ Normalize a string by:
84
+ - Removing all white spaces
85
+ - Optionally removing punctuation (if remove_punct is True)
86
+ - Converting to lowercase
87
+ Parameters:
88
+ - input_str: str, the string to normalize
89
+ - remove_punct: bool, whether to remove punctuation (default: True)
90
+ Returns:
91
+ - str, the normalized string
92
+ """
93
+ # Remove all white spaces. Required e.g for seagull vs. sea gull
94
+ no_spaces = re.sub(r"\s", "", input_str)
95
+
96
+ # Remove punctuation, if specified.
97
+ if remove_punct:
98
+ translator = str.maketrans("", "", string.punctuation)
99
+ return no_spaces.lower().translate(translator)
100
+ else:
101
+ return no_spaces.lower()