loubnabnl HF staff commited on
Commit
e21caa2
1 Parent(s): c6538f5
Files changed (2) hide show
  1. testing_util.py +434 -0
  2. utils.py +159 -0
testing_util.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ import faulthandler
4
+
5
+ # used for debugging to time steps
6
+ from datetime import datetime
7
+
8
+ # to run the solution files we're using a timing based approach
9
+ import signal
10
+
11
+ import numpy as np
12
+ # for capturing the stdout
13
+ from io import StringIO
14
+ # used for testing the code that reads from input
15
+ from unittest.mock import patch, mock_open
16
+
17
+ from pyext import RuntimeModule
18
+
19
+ from enum import Enum
20
+ class CODE_TYPE(Enum):
21
+ call_based = 0
22
+ standard_input = 1
23
+
24
+ # stuff for setting up signal timer
25
+ class TimeoutException(Exception):
26
+ pass
27
+ def timeout_handler(signum, frame):
28
+ print("alarm went off")
29
+ #return
30
+ raise TimeoutException
31
+ signal.signal(signal.SIGALRM, timeout_handler)
32
+ timeout = 4 # seconds
33
+
34
+ # used to capture stdout as a list
35
+ # from https://stackoverflow.com/a/16571630/6416660
36
+ # alternative use redirect_stdout() from contextlib
37
+ class Capturing(list):
38
+ def __enter__(self):
39
+ self._stdout = sys.stdout
40
+ sys.stdout = self._stringio = StringIO()
41
+ # Make closing the StringIO a no-op
42
+ self._stringio.close = lambda x: 1
43
+ return self
44
+ def __exit__(self, *args):
45
+ self.extend(self._stringio.getvalue().splitlines())
46
+ del self._stringio # free up some memory
47
+ sys.stdout = self._stdout
48
+
49
+
50
+ def run_test(sample, test=None, debug=False):
51
+ """
52
+ if test(generated_code) is not None it'll try to run the code.
53
+ otherwise it'll just return an input and output pair.
54
+ """
55
+ if debug:
56
+ print(f"start = {datetime.now().time()}")
57
+
58
+ try:
59
+ in_outs = json.loads(sample["input_output"])
60
+ except ValueError:
61
+ in_outs = None
62
+ if in_outs:
63
+ #if debug:
64
+ # print(f"test cases json = {in_outs['inputs']} {in_outs['outputs']}")
65
+ if in_outs.get("fn_name") is None:
66
+ which_type = CODE_TYPE.standard_input # Standard input
67
+ method_name = None
68
+ else:
69
+ which_type = CODE_TYPE.call_based # Call-based
70
+ method_name = in_outs["fn_name"]
71
+
72
+ if debug:
73
+ print(f"loaded input_output = {datetime.now().time()}")
74
+
75
+ #else:
76
+ # continue
77
+ if test is None:
78
+ return in_outs
79
+ elif test is not None:
80
+ results = []
81
+ sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
82
+ if debug:
83
+ print(f"loading test code = {datetime.now().time()}")
84
+
85
+ if which_type == CODE_TYPE.call_based:
86
+ sol += test
87
+ if debug:
88
+ print(f"sol = {sol}")
89
+ signal.alarm(timeout)
90
+ try:
91
+ tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
92
+ if "class Solution" not in test:
93
+ tmp = tmp_sol
94
+ else:
95
+ tmp = tmp_sol.Solution()
96
+ signal.alarm(0)
97
+ except Exception as e:
98
+ signal.alarm(0)
99
+ print(f"type 0 compilation error = {e}")
100
+ results.append(-2)
101
+ return results
102
+ signal.alarm(0)
103
+
104
+ elif which_type == CODE_TYPE.standard_input:
105
+ # sol
106
+ tmp_test = test.split("\n")
107
+
108
+ new_test = []
109
+ for x in tmp_test:
110
+ if (not x.startswith("from ")) and (not x.startswith("import ")):
111
+ new_test.append("\t" + x + "\n")
112
+ else:
113
+ new_test.append(x + "\n")
114
+ tmp_test = new_test
115
+
116
+ new_test = ""
117
+ started = False
118
+ for i in tmp_test:
119
+ if i.startswith("\t") and not started:
120
+ new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
121
+ new_test += "def code():\n"
122
+ new_test += i
123
+ started = True
124
+ elif started and ((i.startswith("from ")) or (i.startswith("import "))):
125
+ new_test += "\t" + i
126
+ else:
127
+ new_test += i
128
+ tmp_test = new_test
129
+
130
+ sol += tmp_test
131
+ if debug:
132
+ print(f"sol = {sol}")
133
+ method_name = "code"
134
+ signal.alarm(timeout)
135
+ try:
136
+ tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
137
+ tmp = tmp_sol
138
+ signal.alarm(0)
139
+ except Exception as e:
140
+ signal.alarm(0)
141
+ print(f"type 1 compilation error = {e}")
142
+ results.append(-2)
143
+ return results
144
+ signal.alarm(0)
145
+ if debug:
146
+ print(f"get method = {datetime.now().time()}")
147
+
148
+ try:
149
+ method = getattr(tmp, method_name) # get_attr second arg must be str
150
+ except:
151
+ signal.alarm(0)
152
+ e = sys.exc_info()
153
+ print(f"unable to get function error = {e}")
154
+ return results
155
+
156
+ for index, inputs in enumerate(in_outs["inputs"]):
157
+ # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
158
+ try:
159
+ if isinstance(inputs[0], dict):
160
+ inputs = [{int(k): v for k,v in inputs[0].items()}]
161
+ except:
162
+ True
163
+ try:
164
+ if isinstance(in_outs["outputs"][index], dict):
165
+ in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index].items()}]
166
+ except:
167
+ True
168
+ try:
169
+ if isinstance(in_outs["outputs"][index][0], dict):
170
+ in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index][0].items()}]
171
+ except:
172
+ True
173
+
174
+ if debug:
175
+ print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}")
176
+ if which_type == CODE_TYPE.call_based: # Call-based
177
+ signal.alarm(timeout)
178
+ faulthandler.enable()
179
+ try:
180
+ output = method(*inputs)
181
+
182
+ # ground truth sequences are not tuples
183
+ if isinstance(output, tuple):
184
+ output = list(output)
185
+
186
+ tmp_result = output == in_outs["outputs"][index]
187
+ if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
188
+ tmp_result = tmp_result or (output == in_outs["outputs"][index][0])
189
+
190
+ # ground truth sequences are not tuples
191
+ try:
192
+ if isinstance(output[0], tuple):
193
+ tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0])
194
+ except:
195
+ True
196
+ results.append(tmp_result)
197
+
198
+ # reset the alarm
199
+ signal.alarm(0)
200
+ except Exception as e:
201
+ signal.alarm(0)
202
+ faulthandler.disable()
203
+ print(f"Standard input runtime error or time limit exceeded error = {e}")
204
+ results.append(-1)
205
+ continue
206
+ faulthandler.disable()
207
+ signal.alarm(0)
208
+ #if debug:
209
+ #print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
210
+ elif which_type == CODE_TYPE.standard_input: # Standard input
211
+ faulthandler.enable()
212
+ signal.alarm(timeout)
213
+ passed = False
214
+
215
+ if isinstance(inputs, list):
216
+ inputs = "\n".join(inputs)
217
+ if isinstance(in_outs['outputs'][index], list):
218
+ in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index])
219
+ with Capturing() as output:
220
+ try:
221
+ print("doing call")
222
+ call_method(method, inputs)
223
+ print("call done")
224
+ # reset the alarm
225
+ signal.alarm(0)
226
+ passed = True
227
+ except Exception as e:
228
+ print("call not done we are in exception")
229
+ # runtime error or took too long
230
+ signal.alarm(0)
231
+ print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
232
+ results.append(-1)
233
+ signal.alarm(0)
234
+
235
+ if not passed:
236
+ if debug:
237
+ nl = "\n"
238
+ if not isinstance(inputs, list):
239
+ print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
240
+ else:
241
+ print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
242
+ continue
243
+
244
+ if passed and debug:
245
+ print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}")
246
+
247
+ if custom_compare_(output, in_outs['outputs'][index]):
248
+ tmp_result = True
249
+ results.append(tmp_result)
250
+ continue
251
+
252
+ # ground truth sequences are expressed as lists not tuples
253
+ if isinstance(output, tuple):
254
+ output = list(output)
255
+
256
+ tmp_result = False
257
+ try:
258
+ tmp_result = (output == [in_outs["outputs"][index]])
259
+ if isinstance(in_outs["outputs"][index], list):
260
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
261
+ if isinstance(output[0], str):
262
+ tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
263
+ except Exception as e:
264
+ print(f"Failed check1 exception = {e}")
265
+ pass
266
+
267
+ if tmp_result == True:
268
+ results.append(tmp_result)
269
+ continue
270
+
271
+ # try one more time without \n
272
+ if isinstance(in_outs["outputs"][index], list):
273
+ for tmp_index, i in enumerate(in_outs["outputs"][index]):
274
+ in_outs["outputs"][index][tmp_index] = i.split("\n")
275
+ in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x]
276
+ else:
277
+ in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
278
+ in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index]))
279
+ in_outs["outputs"][index] = list(map(lambda x:x.strip(), in_outs["outputs"][index]))
280
+
281
+ try:
282
+ tmp_result = (output == [in_outs["outputs"][index]])
283
+ if isinstance(in_outs["outputs"][index], list):
284
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
285
+ except Exception as e:
286
+ print(f"Failed check2 exception = {e}")
287
+ pass
288
+
289
+ if tmp_result == True:
290
+ results.append(tmp_result)
291
+ continue
292
+
293
+ # try by converting the output into a split up list too
294
+ if isinstance(output, list):
295
+ output = list(filter(len, output))
296
+
297
+ if debug:
298
+ nl = "\n"
299
+ if not isinstance(inputs, list):
300
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
301
+ else:
302
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
303
+
304
+ if tmp_result == True:
305
+ results.append(tmp_result)
306
+ continue
307
+
308
+ try:
309
+ tmp_result = (output == [in_outs["outputs"][index]])
310
+ if isinstance(in_outs["outputs"][index], list):
311
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
312
+ except Exception as e:
313
+ print(f"Failed check3 exception = {e}")
314
+ pass
315
+
316
+ try:
317
+ output_float = [float(e) for e in output]
318
+ gt_float = [float(e) for e in in_outs['outputs'][index]]
319
+ tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
320
+ except Exception as e:
321
+ pass
322
+ try:
323
+ if isinstance(output[0], list):
324
+ output_float = [float(e) for e in output[0]]
325
+ gt_float = [float(e) for e in in_outs['outputs'][index][0]]
326
+ tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
327
+ except Exception as e:
328
+ pass
329
+
330
+ if tmp_result == True:
331
+ results.append(tmp_result)
332
+ continue
333
+
334
+ # try by converting the stuff into split up list
335
+ if isinstance(in_outs["outputs"][index], list):
336
+ for tmp_index, i in enumerate(in_outs["outputs"][index]):
337
+ in_outs["outputs"][index][tmp_index] = set(i.split())
338
+ else:
339
+ in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
340
+
341
+ try:
342
+ tmp_result = (output == in_outs["outputs"][index])
343
+ except Exception as e:
344
+ print(f"Failed check4 exception = {e}")
345
+ continue
346
+
347
+ if tmp_result == True:
348
+ results.append(tmp_result)
349
+ continue
350
+
351
+ # try by converting the output into a split up list too
352
+ if isinstance(output, list):
353
+ for tmp_index, i in enumerate(output):
354
+ output[tmp_index] = i.split()
355
+ output = list(filter(len, output))
356
+ for tmp_index, i in enumerate(output):
357
+ output[tmp_index] = set(i)
358
+ else:
359
+ output = output.split()
360
+ output = list(filter(len, output))
361
+ output = set(output)
362
+
363
+ try:
364
+ tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
365
+ except Exception as e:
366
+ print(f"Failed check5 exception = {e}")
367
+
368
+
369
+ # if they are all numbers, round so that similar numbers are treated as identical
370
+ try:
371
+ tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
372
+ set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
373
+ except Exception as e:
374
+ print(f"Failed check6 exception = {e}")
375
+
376
+ if tmp_result == True and debug:
377
+ print("PASSED")
378
+
379
+ results.append(tmp_result)
380
+
381
+ if debug:
382
+ nl = "\n"
383
+ if not isinstance(inputs, list):
384
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
385
+ else:
386
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
387
+
388
+
389
+ return results
390
+
391
+ def custom_compare_(output, ground_truth):
392
+
393
+ if isinstance(output, list):
394
+ output_1 = "\n".join(output)
395
+ if stripped_string_compare(output_1, ground_truth):
396
+ return True
397
+
398
+ if isinstance(output, list):
399
+ output_2 = [o.lstrip().rstrip() for o in output]
400
+ output_2 = "\n".join(output_2)
401
+ if stripped_string_compare(output_2, ground_truth):
402
+ return True
403
+
404
+ return False
405
+
406
+ def stripped_string_compare(s1, s2):
407
+ s1 = s1.lstrip().rstrip()
408
+ s2 = s2.lstrip().rstrip()
409
+ return s1 == s2
410
+
411
+ def call_method(method, inputs):
412
+
413
+ if isinstance(inputs, list):
414
+ inputs = "\n".join(inputs)
415
+
416
+ inputs_line_iterator = iter(inputs.split("\n"))
417
+
418
+ # sys.setrecursionlimit(10000)
419
+
420
+ # @patch('builtins.input', side_effect=inputs.split("\n"))
421
+ @patch('builtins.open', mock_open(read_data=inputs))
422
+ @patch('sys.stdin', StringIO(inputs))
423
+ @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
424
+ @patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
425
+ @patch('sys.stdin.read', lambda *args: inputs)
426
+ # @patch('sys.stdout.write', print)
427
+ def _inner_call_method(_method):
428
+ try:
429
+ return _method()
430
+ except SystemExit as e:
431
+ pass
432
+ finally:
433
+ pass
434
+ return _inner_call_method(method)
utils.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import numpy as np
3
+ from typing import Dict
4
+ from datasets import load_dataset
5
+ import testing_util as test_util
6
+
7
+
8
+ DATASET = "codeparrot/apps"
9
+
10
+
11
+ def evaluate_generations(generations, level=["all"]):
12
+ """We take the list of code generations and try to compile them
13
+ and the run their corresponding unit tests which are retrieved from the APPS dataset.
14
+
15
+ Args:
16
+ generations: list of code generations, in the same order as APPS dataset samples
17
+ level: list of levels to evaluate, can be "all", "introductory", "interview" or "competition"
18
+
19
+ Returns:
20
+ results: dictionary of results, key is the problem index, value is a list of results for each generation
21
+ [-2] = compile error, [-1] = runtime error [False] = failed test case [True] = passed test case
22
+ """
23
+
24
+ # generations are code generations in the same order of the dataset
25
+ apps_eval = load_dataset(DATASET, split="test", difficulties=level)
26
+ gpt_codes = generations
27
+ results = {}
28
+ for index in range(len(generations)):
29
+ print(f"task {index}")
30
+ generated_code = gpt_codes[index]
31
+ sample = apps_eval[index]
32
+ res = []
33
+ # loop over the generations
34
+ for o_idx, o in enumerate(generated_code):
35
+ curr_res = [-2]
36
+ try:
37
+ print("Run test")
38
+ curr_res = test_util.run_test(sample, test=o, debug=False)
39
+ print("\nSuccessful compilation!")
40
+ fixed = []
41
+ for e in curr_res:
42
+ if isinstance(e, np.ndarray):
43
+ e = e.item(0)
44
+ if isinstance(e, np.bool_):
45
+ e = bool(e)
46
+ fixed.append(e)
47
+ curr_res = fixed
48
+ if not np.all(curr_res):
49
+ print(f"Results were not True for all test cases") #{curr_res}")
50
+ except Exception as e:
51
+ print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
52
+ break
53
+ finally:
54
+ assert isinstance(curr_res, list)
55
+ res.append(curr_res)
56
+ results[index] = res
57
+
58
+ return results
59
+
60
+
61
+ def estimate_pass_at_k(num_samples, num_correct, k):
62
+ """Estimates pass@k of each problem and returns them in an array."""
63
+
64
+ def estimator(n: int, c: int, k: int) -> float:
65
+ """Calculates 1 - comb(n - c, k) / comb(n, k)."""
66
+ if n - c < k:
67
+ return 1.0
68
+ return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
69
+
70
+ if isinstance(num_samples, int):
71
+ num_samples_it = itertools.repeat(num_samples, len(num_correct))
72
+ else:
73
+ assert len(num_samples) == len(num_correct)
74
+ num_samples_it = iter(num_samples)
75
+
76
+ return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
77
+
78
+
79
+ def get_results(results: Dict, count_errors: bool = False, k_list: list = [1, 10, 100]):
80
+ """
81
+ Given the results evaluated against the testcases we output some statistics.
82
+ For single generations:
83
+ >>> example_results = {"0": [[-2]],"1": [[False,False]],"2": [[True,True]],"3": [[False,True,False,True]], "4": [[-1,-1]]}
84
+ >>> get_results(example_results, count_errors=True)
85
+ number of compile errors = 1 avg = 0.2
86
+ number of runtime errors = 1 avg = 0.2
87
+ number of test cases run = 5
88
+ Test Case Average (average accuracy over problems) = 0.3
89
+ Strict Accuracy (all test cases passed / total problems) = 0.2
90
+
91
+ For multiple generations:
92
+ >>> example_results = {"0": [[-2], [True, True, True]],"1": [[-1,-1, -1], [True, False, True]]}
93
+ >>> get_results(example_results k_list=[1, 2])
94
+ {'pass@1': 0.25, 'pass@2': 0.5}
95
+ """
96
+
97
+ metrics = {"avg_accuracy": None, "strict_accuracy": None, "pass_at_k": None}
98
+
99
+ if len(results["0"]) == 1:
100
+ # for single generations we compute average accuracy and stric accuracy: original APPS metrics
101
+ print("Computing accuracy metrics...")
102
+ res = []
103
+ per_prob_res = []
104
+ all_correct = []
105
+ for index in results:
106
+ results[index] = np.array(results[index])
107
+ res.extend(results[index])
108
+ per_prob_res.append(np.mean(results[index]>0))
109
+ all_correct.append(np.all(results[index]>0))
110
+ # we count campilation and runtime errors once per pronlem
111
+ compile_errors = len([e for e in res if -2 in e])
112
+ runtime_errors = len([e for e in res if -1 in e])
113
+ total_testcases = len(res)
114
+ if count_errors:
115
+ print(f"number of compile errors = {compile_errors} avg = {compile_errors / total_testcases}")
116
+ print(f"number of runtime errors = {runtime_errors} avg = {runtime_errors / total_testcases}")
117
+ print(f"number of problems evaluated = {total_testcases}")
118
+
119
+ print(f"Test Case Average Accuracy (ver tests) = {np.mean(per_prob_res)}")
120
+ print(f"Strict Accuracy (over problems that pass all tests) = {np.mean(all_correct)}")
121
+ metrics["avg_accuracy"] = np.mean(per_prob_res)
122
+ metrics["strict_accuracy"] = np.mean(all_correct)
123
+
124
+ else:
125
+ # for multiple generations we use pass@k metric used in the HumanEval benchmark
126
+ # we use strict accuracy, a generation is valid if it has to pass all the tests
127
+ print("Computing pass@k metric for multiple generations...")
128
+ # total is list with nb generations per task (task=index)
129
+ # correct is number of generations that passed all tests per task
130
+ total = []
131
+ correct = []
132
+ for index in results:
133
+ all_correct = []
134
+ for generation in results[index]:
135
+ gen = np.array(generation)
136
+ all_correct.append(np.all(gen>0))
137
+ total.append(len(all_correct))
138
+ correct.append(sum(all_correct))
139
+ total = np.array(total)
140
+ correct = np.array(correct)
141
+ ks = k_list
142
+ pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()}
143
+ print(pass_at_k)
144
+ metrics["pass_at_k"] = pass_at_k
145
+ return metrics
146
+
147
+ def compute_metrics(generations, k_list=[1, 10, 100], count_errors=True, level=["all"]):
148
+ """Return metrics for the given generations.
149
+ Args:
150
+ generations: dict of generations, keyed by problem index
151
+ k_list: list of k values to compute pass@k when using multiple generations
152
+ count_errors: whether to count compilation and runtime errors when using single generations
153
+ level: which level difficulty in APPS dataset was used for the given generations
154
+ Returns:
155
+ metrics: dict of metrics
156
+ """
157
+ results = evaluate_generations(generations, level=level)
158
+ metrics = get_results(results, count_errors=count_errors, k_list=k_list)
159
+ return metrics