Spaces:
Sleeping
Sleeping
fix _temp_run can't be pickled; pass indices to allow evaluation on subset
Browse files- apps_metric.py +2 -2
- tests.py +20 -10
- utils.py +17 -10
apps_metric.py
CHANGED
@@ -76,7 +76,7 @@ class apps_metric(evaluate.EvaluationModule):
|
|
76 |
|
77 |
|
78 |
|
79 |
-
def _compute(self, predictions, k_list=[1, 10, 100], count_errors=True, level="all", debug=False):
|
80 |
"""Returns the scores"""
|
81 |
-
metrics = compute_metrics(predictions, k_list=k_list, count_errors=count_errors, level=level, debug=debug)
|
82 |
return metrics
|
|
|
76 |
|
77 |
|
78 |
|
79 |
+
def _compute(self, predictions, indices=None, k_list=[1, 10, 100], count_errors=True, level="all", debug=False):
|
80 |
"""Returns the scores"""
|
81 |
+
metrics = compute_metrics(predictions, indices=indices, k_list=k_list, count_errors=count_errors, level=level, debug=debug)
|
82 |
return metrics
|
tests.py
CHANGED
@@ -1,14 +1,24 @@
|
|
1 |
import json
|
2 |
-
from
|
3 |
|
4 |
-
|
5 |
-
solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
|
6 |
-
single_solutions = [solution_sample1[:1], solution_sample2[:1]]
|
7 |
-
multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
|
8 |
|
9 |
-
metric = load("codeparrot/apps_metric")
|
10 |
-
result_1 = metric.compute(predictions=single_solutions, level="all")
|
11 |
-
result_2 = metric.compute(predictions=multiple_solutions, level="all", k_list=[1, 2, 3])
|
12 |
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import json
|
2 |
+
from multiprocessing import freeze_support
|
3 |
|
4 |
+
from apps_metric import apps_metric
|
|
|
|
|
|
|
5 |
|
|
|
|
|
|
|
6 |
|
7 |
+
if __name__ == '__main__':
|
8 |
+
"""
|
9 |
+
Verify by checking if reference solutions pass all test cases (with strict accuracy == 1).
|
10 |
+
Note that some reference solutions may not pass all test cases. So only throw a warning.
|
11 |
+
"""
|
12 |
+
freeze_support()
|
13 |
+
|
14 |
+
solution_sample1 = json.load(open("test_examples/solutions_problem_1.json", "r"))
|
15 |
+
solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
|
16 |
+
single_solutions = [solution_sample1[:1], solution_sample2[:1]]
|
17 |
+
multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
|
18 |
+
|
19 |
+
metric = apps_metric()
|
20 |
+
result_1 = metric.compute(predictions=single_solutions, level="all")
|
21 |
+
result_2 = metric.compute(predictions=multiple_solutions, level="all", k_list=[1, 2, 3])
|
22 |
+
|
23 |
+
assert result_1 == {'avg_accuracy': 1.0, 'strict_accuracy': 1.0, 'pass_at_k': None}
|
24 |
+
assert result_2 == {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
|
utils.py
CHANGED
@@ -9,13 +9,14 @@ from .testing_util import run_test
|
|
9 |
DATASET = "codeparrot/apps"
|
10 |
TIMEOUT = 10
|
11 |
|
|
|
|
|
|
|
|
|
12 |
def check_correctness(sample, generation, timeout, debug=True):
|
13 |
"""Check correctness of code generation with a global timeout.
|
14 |
The global timeout is to catch some extreme/rare cases not handled by the timeouts
|
15 |
inside `run_test`"""
|
16 |
-
def _temp_run(sample, generation, debug, result):
|
17 |
-
result.append(run_test(sample, test=generation, debug=debug))
|
18 |
-
|
19 |
manager = multiprocessing.Manager()
|
20 |
result = manager.list()
|
21 |
p = multiprocessing.Process(target=_temp_run, args=(sample, generation, debug, result))
|
@@ -32,12 +33,13 @@ def check_correctness(sample, generation, timeout, debug=True):
|
|
32 |
return result[0]
|
33 |
|
34 |
|
35 |
-
def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
|
36 |
"""We take the list of code generations and try to compile them
|
37 |
and the run their corresponding unit tests which are retrieved from the APPS dataset.
|
38 |
|
39 |
Args:
|
40 |
generations: list of code generations (same order as samples in APPS dataset)
|
|
|
41 |
level: difficulty level used in the generation, can be "all", "introductory", "interview" or "competition"
|
42 |
|
43 |
Returns:
|
@@ -47,10 +49,14 @@ def evaluate_generations(generations: list, level: str = "all", debug: bool = Fa
|
|
47 |
|
48 |
# generations are code generations in the same order of the dataset
|
49 |
apps_eval = load_dataset(DATASET, split="test", difficulties=[level])
|
|
|
|
|
|
|
|
|
50 |
results = {}
|
51 |
-
for index in
|
52 |
# code generations for problem (index)
|
53 |
-
problem_generations =
|
54 |
# get corresponding samples from APPS dataset
|
55 |
sample = apps_eval[index]
|
56 |
res = []
|
@@ -74,7 +80,7 @@ def evaluate_generations(generations: list, level: str = "all", debug: bool = Fa
|
|
74 |
print(f"Results were not True for all test cases")
|
75 |
except Exception as e:
|
76 |
if debug:
|
77 |
-
print(f"Compilation failed, test framework exception = {repr(e)}
|
78 |
break
|
79 |
finally:
|
80 |
assert isinstance(curr_res, list)
|
@@ -125,7 +131,7 @@ def get_results(results: Dict[int, list], count_errors: bool = False, k_list: li
|
|
125 |
|
126 |
metrics = {"avg_accuracy": None, "strict_accuracy": None, "pass_at_k": None}
|
127 |
|
128 |
-
if len(results[0]) == 1:
|
129 |
# for single generations we compute average accuracy and stric accuracy: original APPS metrics
|
130 |
print("Computing accuracy metrics...")
|
131 |
res = []
|
@@ -173,10 +179,11 @@ def get_results(results: Dict[int, list], count_errors: bool = False, k_list: li
|
|
173 |
metrics["pass_at_k"] = pass_at_k
|
174 |
return metrics
|
175 |
|
176 |
-
def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=True, debug=False):
|
177 |
"""Return metrics for the given generations.
|
178 |
Args:
|
179 |
generations: list of code generations for each problem (each generation is a list of generations)
|
|
|
180 |
k_list: list of k values to compute pass@k when using multiple generations
|
181 |
count_errors: whether to count compilation and runtime errors when using single generations
|
182 |
level: difficulty level in APPS dataset that was used for the given generations (from: "all", "introductory", "interview", "competition")
|
@@ -204,7 +211,7 @@ def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=
|
|
204 |
{'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}
|
205 |
{'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
|
206 |
"""
|
207 |
-
results = evaluate_generations(generations, level=level, debug=debug)
|
208 |
metrics = get_results(results, count_errors=count_errors, k_list=k_list)
|
209 |
return metrics
|
210 |
|
|
|
9 |
DATASET = "codeparrot/apps"
|
10 |
TIMEOUT = 10
|
11 |
|
12 |
+
|
13 |
+
def _temp_run(sample, generation, debug, result):
|
14 |
+
result.append(run_test(sample, test=generation, debug=debug))
|
15 |
+
|
16 |
def check_correctness(sample, generation, timeout, debug=True):
|
17 |
"""Check correctness of code generation with a global timeout.
|
18 |
The global timeout is to catch some extreme/rare cases not handled by the timeouts
|
19 |
inside `run_test`"""
|
|
|
|
|
|
|
20 |
manager = multiprocessing.Manager()
|
21 |
result = manager.list()
|
22 |
p = multiprocessing.Process(target=_temp_run, args=(sample, generation, debug, result))
|
|
|
33 |
return result[0]
|
34 |
|
35 |
|
36 |
+
def evaluate_generations(generations: list, indices: list = [], level: str = "all", debug: bool = False):
|
37 |
"""We take the list of code generations and try to compile them
|
38 |
and the run their corresponding unit tests which are retrieved from the APPS dataset.
|
39 |
|
40 |
Args:
|
41 |
generations: list of code generations (same order as samples in APPS dataset)
|
42 |
+
indices: list of indicies of problems to evaluate, if empty, evaluate all problems
|
43 |
level: difficulty level used in the generation, can be "all", "introductory", "interview" or "competition"
|
44 |
|
45 |
Returns:
|
|
|
49 |
|
50 |
# generations are code generations in the same order of the dataset
|
51 |
apps_eval = load_dataset(DATASET, split="test", difficulties=[level])
|
52 |
+
|
53 |
+
if indices is None:
|
54 |
+
indices = range(len(generations))
|
55 |
+
|
56 |
results = {}
|
57 |
+
for index, generation in zip(indices, generations):
|
58 |
# code generations for problem (index)
|
59 |
+
problem_generations = generation
|
60 |
# get corresponding samples from APPS dataset
|
61 |
sample = apps_eval[index]
|
62 |
res = []
|
|
|
80 |
print(f"Results were not True for all test cases")
|
81 |
except Exception as e:
|
82 |
if debug:
|
83 |
+
print(f"Compilation failed, test framework exception = {repr(e)}\n")
|
84 |
break
|
85 |
finally:
|
86 |
assert isinstance(curr_res, list)
|
|
|
131 |
|
132 |
metrics = {"avg_accuracy": None, "strict_accuracy": None, "pass_at_k": None}
|
133 |
|
134 |
+
if len(list(results.values())[0]) == 1:
|
135 |
# for single generations we compute average accuracy and stric accuracy: original APPS metrics
|
136 |
print("Computing accuracy metrics...")
|
137 |
res = []
|
|
|
179 |
metrics["pass_at_k"] = pass_at_k
|
180 |
return metrics
|
181 |
|
182 |
+
def compute_metrics(generations, indices=None, level="all", k_list=[1, 10, 100], count_errors=True, debug=False):
|
183 |
"""Return metrics for the given generations.
|
184 |
Args:
|
185 |
generations: list of code generations for each problem (each generation is a list of generations)
|
186 |
+
indices: list of indices of problems (if None, generations are all problems)
|
187 |
k_list: list of k values to compute pass@k when using multiple generations
|
188 |
count_errors: whether to count compilation and runtime errors when using single generations
|
189 |
level: difficulty level in APPS dataset that was used for the given generations (from: "all", "introductory", "interview", "competition")
|
|
|
211 |
{'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}
|
212 |
{'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
|
213 |
"""
|
214 |
+
results = evaluate_generations(generations, indices=indices, level=level, debug=debug)
|
215 |
metrics = get_results(results, count_errors=count_errors, k_list=k_list)
|
216 |
return metrics
|
217 |
|