jjkim
commited on
Commit
•
cb0919a
1
Parent(s):
fe7364e
add early termination
Browse files- code_eval.py +27 -10
code_eval.py
CHANGED
@@ -19,7 +19,7 @@ described in the paper "Evaluating Large Language Models Trained on Code"
|
|
19 |
import itertools
|
20 |
import os
|
21 |
from collections import Counter, defaultdict
|
22 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
23 |
|
24 |
import datasets
|
25 |
import evaluate
|
@@ -171,6 +171,7 @@ class CodeEval(evaluate.Metric):
|
|
171 |
|
172 |
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
173 |
futures = []
|
|
|
174 |
completion_id = Counter()
|
175 |
results = defaultdict(list)
|
176 |
|
@@ -189,31 +190,47 @@ class CodeEval(evaluate.Metric):
|
|
189 |
)
|
190 |
future = executor.submit(check_correctness, *args)
|
191 |
futures.append(future)
|
|
|
192 |
completion_id[task_id] += 1
|
193 |
|
194 |
pbar = tqdm(total=len(futures))
|
195 |
for future in as_completed(futures):
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
197 |
results[result["task_id"]].append((result["completion_id"], result))
|
198 |
pbar.update(1)
|
199 |
|
200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
new_result = []
|
202 |
for completion_id, group in itertools.groupby(result, key=lambda x: x[0]):
|
203 |
group = list(group)
|
204 |
new_result.append(
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
210 |
)
|
211 |
)
|
212 |
-
|
|
|
213 |
|
214 |
total, correct = [], []
|
215 |
for result in results.values():
|
216 |
-
result.sort()
|
217 |
passed = [r[1]["passed"] for r in result]
|
218 |
total.append(len(passed))
|
219 |
correct.append(sum(passed))
|
|
|
19 |
import itertools
|
20 |
import os
|
21 |
from collections import Counter, defaultdict
|
22 |
+
from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
|
23 |
|
24 |
import datasets
|
25 |
import evaluate
|
|
|
171 |
|
172 |
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
173 |
futures = []
|
174 |
+
future_dict = defaultdict(lambda: defaultdict(list))
|
175 |
completion_id = Counter()
|
176 |
results = defaultdict(list)
|
177 |
|
|
|
190 |
)
|
191 |
future = executor.submit(check_correctness, *args)
|
192 |
futures.append(future)
|
193 |
+
future_dict[task_id][completion_id[task_id]].append(future)
|
194 |
completion_id[task_id] += 1
|
195 |
|
196 |
pbar = tqdm(total=len(futures))
|
197 |
for future in as_completed(futures):
|
198 |
+
try:
|
199 |
+
result = future.result()
|
200 |
+
except CancelledError:
|
201 |
+
pbar.update(1)
|
202 |
+
continue
|
203 |
+
|
204 |
results[result["task_id"]].append((result["completion_id"], result))
|
205 |
pbar.update(1)
|
206 |
|
207 |
+
if not result["passed"]:
|
208 |
+
future_list = future_dict[result["task_id"]][result["completion_id"]]
|
209 |
+
for future in future_list:
|
210 |
+
future.cancel()
|
211 |
+
|
212 |
+
new_results = {}
|
213 |
+
for key, result in results.items():
|
214 |
new_result = []
|
215 |
for completion_id, group in itertools.groupby(result, key=lambda x: x[0]):
|
216 |
group = list(group)
|
217 |
new_result.append(
|
218 |
+
(
|
219 |
+
group[0][0],
|
220 |
+
dict(
|
221 |
+
task_id=group[0][0],
|
222 |
+
passed=all(r[1]["passed"] for r in group),
|
223 |
+
result=[r[1]["result"] for r in group],
|
224 |
+
completion_id=completion_id,
|
225 |
+
),
|
226 |
)
|
227 |
)
|
228 |
+
new_results[key] = new_result
|
229 |
+
results = new_results
|
230 |
|
231 |
total, correct = [], []
|
232 |
for result in results.values():
|
233 |
+
result.sort(key=lambda x: x[0])
|
234 |
passed = [r[1]["passed"] for r in result]
|
235 |
total.append(len(passed))
|
236 |
correct.append(sum(passed))
|