facat committed on
Commit
72dba58
1 Parent(s): 5ca9a91
Files changed (2) hide show
  1. tasks.py +3 -3
  2. tlem.py +33 -13
tasks.py CHANGED
@@ -59,7 +59,7 @@ class Task:
59
  dataset_name: str | tuple[str, str] = ("gsm8k", "main")
60
  split: str = "test"
61
  # metrics: list[str] = field(default_factory=list)
62
- metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
63
  input_column: str = "question"
64
  label_column: str = ""
65
  prompt: Optional[Callable | str] = None
@@ -147,12 +147,12 @@ class Task:
147
  if isinstance(self.metric_name, str)
148
  else load(*self.metric_name)
149
  )
150
- return metric
151
 
152
  @cached_property
153
  def result(self) -> dict:
154
  assert self.outputs, "Please run the task first."
155
- results = self.metric._compute(
156
  responses=self.outputs, references=self.dataset[self.label_column]
157
  )
158
  # logging.info(f"{self.name}:{results}")
 
59
  dataset_name: str | tuple[str, str] = ("gsm8k", "main")
60
  split: str = "test"
61
  # metrics: list[str] = field(default_factory=list)
62
+ metric_name: str | tuple[str, str] = ("sustech/tlem", "mmlu")
63
  input_column: str = "question"
64
  label_column: str = ""
65
  prompt: Optional[Callable | str] = None
 
147
  if isinstance(self.metric_name, str)
148
  else load(*self.metric_name)
149
  )
150
+ return metric._compute
151
 
152
  @cached_property
153
  def result(self) -> dict:
154
  assert self.outputs, "Please run the task first."
155
+ results = self.metric(
156
  responses=self.outputs, references=self.dataset[self.label_column]
157
  )
158
  # logging.info(f"{self.name}:{results}")
tlem.py CHANGED
@@ -72,6 +72,17 @@ class ReasoningMetric(evaluate.Metric):
72
  class Suite(EvaluationSuite):
73
  task_class = Task
74
  utils = utils
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  def __getitem__(self, key) -> Task:
77
  match key:
@@ -171,19 +182,26 @@ class Suite(EvaluationSuite):
171
  ]:
172
  suite[name] = self.get_suite(name)
173
  case "tlem":
174
- suite = {}
175
- for name in [
176
- "arc",
177
- "hellaswag",
178
- "mmlu-chat",
179
- "winogrande",
180
- "gsm8k",
181
- "cmmlu-chat",
182
- "ceval-chat",
183
- # "truthful_qa",
184
- "drop",
185
- ]:
186
- suite[name] = self.get_suite(name)
 
 
 
 
 
 
 
187
  if isinstance(suite, Task):
188
  suite = [suite]
189
  if isinstance(suite, list):
@@ -195,6 +213,7 @@ class Suite(EvaluationSuite):
195
  try:
196
  return self.tasks[self.tasks.index(task)]
197
  except ValueError:
 
198
  self.tasks.append(task)
199
  return self.tasks[-1]
200
 
@@ -212,6 +231,7 @@ class Suite(EvaluationSuite):
212
  def load(self, name):
213
  self.suite.update(self.get_suite(name))
214
  self.suite = self.drop_duplicates(self.suite)
 
215
 
216
  def __init__(self, name="tlem"):
217
  super().__init__(name)
 
72
  class Suite(EvaluationSuite):
73
  task_class = Task
74
  utils = utils
75
+ supported_datasets = [
76
+ "arc",
77
+ "hellaswag",
78
+ "mmlu-chat",
79
+ "winogrande",
80
+ "gsm8k",
81
+ "cmmlu-chat",
82
+ "ceval-chat",
83
+ "bbh",
84
+ "drop",
85
+ ]
86
 
87
  def __getitem__(self, key) -> Task:
88
  match key:
 
182
  ]:
183
  suite[name] = self.get_suite(name)
184
  case "tlem":
185
+ suite = {
186
+ name: self.get_suite(name)
187
+ for name in [
188
+ "arc",
189
+ "hellaswag",
190
+ "mmlu-chat",
191
+ "winogrande",
192
+ "gsm8k",
193
+ "cmmlu-chat",
194
+ "ceval-chat",
195
+ "bbh",
196
+ ]
197
+ }
198
+ case "all":
199
+ suite = {name: self.get_suite(name) for name in self.supported_datasets}
200
+ case _:
201
+ raise NotImplementedError(
202
+ f"{name} is not supported in {self.supported_datasets}"
203
+ )
204
+
205
  if isinstance(suite, Task):
206
  suite = [suite]
207
  if isinstance(suite, list):
 
213
  try:
214
  return self.tasks[self.tasks.index(task)]
215
  except ValueError:
216
+ logging.debug(f"add {task.name} to suite.")
217
  self.tasks.append(task)
218
  return self.tasks[-1]
219
 
 
231
  def load(self, name):
232
  self.suite.update(self.get_suite(name))
233
  self.suite = self.drop_duplicates(self.suite)
234
+ # return self
235
 
236
  def __init__(self, name="tlem"):
237
  super().__init__(name)