lvwerra HF staff committed on
Commit
68132d8
·
1 Parent(s): 20fed5e

Update Space (evaluate main: c447fc8e)

Browse files
Files changed (2) hide show
  1. code_eval.py +5 -21
  2. requirements.txt +1 -1
code_eval.py CHANGED
@@ -20,8 +20,6 @@ import itertools
20
  import os
21
  from collections import Counter, defaultdict
22
  from concurrent.futures import ThreadPoolExecutor, as_completed
23
- from dataclasses import dataclass, field
24
- from typing import List
25
 
26
  import datasets
27
  import numpy as np
@@ -133,28 +131,14 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
133
  THE SOFTWARE."""
134
 
135
 
136
- @dataclass
137
- class CodeEvalConfig(evaluate.info.Config):
138
-
139
- name: str = "default"
140
-
141
- k: List[int] = field(default_factory=lambda: [1, 10, 100])
142
- num_workers: int = 4
143
- timeout: float = 3.0
144
-
145
-
146
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
147
  class CodeEval(evaluate.Metric):
148
- CONFIG_CLASS = CodeEvalConfig
149
- ALLOWED_CONFIG_NAMES = ["default"]
150
-
151
- def _info(self, config):
152
  return evaluate.MetricInfo(
153
  # This is the description that will appear on the metrics page.
154
  description=_DESCRIPTION,
155
  citation=_CITATION,
156
  inputs_description=_KWARGS_DESCRIPTION,
157
- config=config,
158
  # This defines the format of each prediction and reference
159
  features=datasets.Features(
160
  {
@@ -168,7 +152,7 @@ class CodeEval(evaluate.Metric):
168
  license=_LICENSE,
169
  )
170
 
171
- def _compute(self, predictions, references):
172
  """Returns the scores"""
173
 
174
  if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
@@ -177,7 +161,7 @@ class CodeEval(evaluate.Metric):
177
  if os.name == "nt":
178
  raise NotImplementedError("This metric is currently not supported on Windows.")
179
 
180
- with ThreadPoolExecutor(max_workers=self.config.num_workers) as executor:
181
  futures = []
182
  completion_id = Counter()
183
  n_samples = 0
@@ -186,7 +170,7 @@ class CodeEval(evaluate.Metric):
186
  for task_id, (candidates, test_case) in enumerate(zip(predictions, references)):
187
  for candidate in candidates:
188
  test_program = candidate + "\n" + test_case
189
- args = (test_program, self.config.timeout, task_id, completion_id[task_id])
190
  future = executor.submit(check_correctness, *args)
191
  futures.append(future)
192
  completion_id[task_id] += 1
@@ -205,7 +189,7 @@ class CodeEval(evaluate.Metric):
205
  total = np.array(total)
206
  correct = np.array(correct)
207
 
208
- ks = self.config.k
209
  pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()}
210
 
211
  return pass_at_k, results
 
20
  import os
21
  from collections import Counter, defaultdict
22
  from concurrent.futures import ThreadPoolExecutor, as_completed
 
 
23
 
24
  import datasets
25
  import numpy as np
 
131
  THE SOFTWARE."""
132
 
133
 
 
 
 
 
 
 
 
 
 
 
134
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
135
  class CodeEval(evaluate.Metric):
136
+ def _info(self):
 
 
 
137
  return evaluate.MetricInfo(
138
  # This is the description that will appear on the metrics page.
139
  description=_DESCRIPTION,
140
  citation=_CITATION,
141
  inputs_description=_KWARGS_DESCRIPTION,
 
142
  # This defines the format of each prediction and reference
143
  features=datasets.Features(
144
  {
 
152
  license=_LICENSE,
153
  )
154
 
155
+ def _compute(self, predictions, references, k=[1, 10, 100], num_workers=4, timeout=3.0):
156
  """Returns the scores"""
157
 
158
  if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
 
161
  if os.name == "nt":
162
  raise NotImplementedError("This metric is currently not supported on Windows.")
163
 
164
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
165
  futures = []
166
  completion_id = Counter()
167
  n_samples = 0
 
170
  for task_id, (candidates, test_case) in enumerate(zip(predictions, references)):
171
  for candidate in candidates:
172
  test_program = candidate + "\n" + test_case
173
+ args = (test_program, timeout, task_id, completion_id[task_id])
174
  future = executor.submit(check_correctness, *args)
175
  futures.append(future)
176
  completion_id[task_id] += 1
 
189
  total = np.array(total)
190
  correct = np.array(correct)
191
 
192
+ ks = k
193
  pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()}
194
 
195
  return pass_at_k, results
requirements.txt CHANGED
@@ -1 +1 @@
1
- git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
 
1
+ git+https://github.com/huggingface/evaluate@c447fc8eda9c62af501bfdc6988919571050d950