facat committed on
Commit
d13c0d8
1 Parent(s): 25e4875
Files changed (2)
  1. tasks.py +268 -223
  2. tlem.py +9 -0
tasks.py CHANGED
@@ -1,8 +1,9 @@
  from dataclasses import dataclass, field
+
  from datasets import load_dataset, Dataset
  from functools import cached_property
  from tqdm.auto import tqdm
- from typing import Any, Optional, Protocol, Iterable, Callable
+ from typing import Any, Optional, Callable
  import logging
  import pandas as pd
  from functools import partial
@@ -187,71 +188,57 @@ def multichoice_zh(responses: Any, references: list[str]):
  class Metrics:
      cmmlu = multichoice_zh
      mmlu = multichoice
-
+
      def ceval(responses: list[str], answers: list[str | int]):
          responses = [extract_choice_zh(pred) for pred in responses]
          return responses, answers
-
+
      def winogrande(responses: list[str], answers: list[str | int]):
          responses = [first_option_postprocess(pred, options="AB") for pred in responses]
          return responses, answers
-
+
      def arc(responses: list[str], answers: list[str | int]):
          if len(responses) != len(answers):
-             return {
-                 'error': 'predictions and references have different '
-                 'length'
-             }
-         responses = [first_option_postprocess(pred, options="ABCD") for pred in responses]
+             return {"error": "predictions and references have different " "length"}
+         responses = [
+             first_option_postprocess(pred, options="ABCD") for pred in responses
+         ]

          return responses, answers
-
+
      def hellaswag(responses: list[str], answers: list[str | int]):
          if len(responses) != len(answers):
-             return {
-                 'error': 'predictions and references have different '
-                 'length'
-             }
-         responses = [first_option_postprocess(pred, options="ABCD") for pred in responses]
-         answers = ['ABCD'[int(ans)] for ans in answers]
+             return {"error": "predictions and references have different " "length"}
+         responses = [
+             first_option_postprocess(pred, options="ABCD") for pred in responses
+         ]
+         answers = ["ABCD"[int(ans)] for ans in answers]
          return responses, answers
-
+
      def drop(responses: list[str], answers: list[list]):
          if len(responses) != len(answers):
-             return {
-                 'error': 'predictions and references have different '
-                 'length'
-             }
+             return {"error": "predictions and references have different " "length"}
          responses = [general_postprocess(pred) for pred in responses]
-         processed_answers = [[general_postprocess(j) for j in i]
-                              for i in answers]
+         processed_answers = [[general_postprocess(j) for j in i] for i in answers]
          matched_answers = []
-         for pred, ans, origin_ans in zip(responses, processed_answers,
-                                          answers):
-
+         for pred, ans, origin_ans in zip(responses, processed_answers, answers):
              if pred in ans or pred in origin_ans:
                  matched_answers.append(pred)
              else:
                  matched_answers.append(ans[0])
-
+
          return responses, matched_answers
-
+
      def bbh_mcq(responses: list[str], answers: list[str | int]):
          if len(responses) != len(answers):
-             return {
-                 'error': 'predictions and references have different '
-                 'length'
-             }
+             return {"error": "predictions and references have different " "length"}
          responses = [bbh_mcq_postprocess(pred) for pred in responses]

          return responses, answers
-
+
      def bbh_freefrom(responses: list[str], answers: list[str | int]):
          if len(responses) != len(answers):
-             return {
-                 'error': 'predictions and references have different '
-                 'length'
-             }
+             return {"error": "predictions and references have different " "length"}

          responses = [bbh_freeform_postprocess(pred) for pred in responses]
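Note on the reshaped metrics above: they now return (responses, answers) pairs for downstream exact-match scoring instead of returning scores inline. A minimal standalone sketch of the hellaswag conversion, with a hypothetical pick_option standing in for first_option_postprocess (which is defined elsewhere in tasks.py):

# Sketch only: `pick_option` is a hypothetical stand-in for
# first_option_postprocess, which lives elsewhere in tasks.py.
def pick_option(pred: str, options: str = "ABCD") -> str:
    # return the first character of `pred` that is a valid option letter
    return next((char for char in pred if char in options), "")

responses = ["B makes the most sense", "The answer is C"]
answers = [1, 2]  # HellaSwag labels are indices into the endings
assert len(responses) == len(answers)  # otherwise the metric returns an error dict
extracted = [pick_option(r) for r in responses]
golds = ["ABCD"[int(a)] for a in answers]  # 1 -> "B", 2 -> "C"
print(extracted, golds)  # ['B', 'C'] ['B', 'C']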
@@ -272,27 +259,16 @@ class Metrics:
          return responses, answers

      def MATH(responses: list[str], answers: list[str]):
-         scores = []
-
-         for response, answer in zip(responses, answers):
+         extract_responses = []
+         for response in responses:
              indices = [pos for pos, char in enumerate(response) if char == "$"]
              if len(indices) <= 2:
-                 scores.append(0)
-                 continue
+                 ans = ""
              else:
-                 result = response[indices[-2] + 1 : indices[-1]]
-                 gold = get_answer(answer)
-                 scores.append(1.0 * is_equiv(result, gold))
-
-         return scores
-
-     def math23k(responses: list[str], answers: list[str]):
-         scores = []
-         for response, answer in zip(responses, answers):
-             pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
-             gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
-             scores.append(1.0 * (pred == gold))
-         return scores
+                 ans = response[indices[-2] + 1 : indices[-1]]
+             extract_responses.append(strip_string(ans))
+         extract_answers = [strip_string(get_answer(answer)) for answer in answers]
+         return extract_responses, extract_answers


  class CMMLU:
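The rewritten MATH metric above no longer scores in place; it takes the text between the last two "$" delimiters of each response and defers comparison to the suite. A standalone sketch of just that extraction rule (strip_string and get_answer come from elsewhere in tasks.py and are omitted here):

# Sketch of the new MATH extraction: take the text between the last two "$"
# delimiters, or "" when fewer than three "$" appear (per the `<= 2` guard).
def extract_last_math_span(response: str) -> str:
    indices = [pos for pos, char in enumerate(response) if char == "$"]
    if len(indices) <= 2:
        return ""
    return response[indices[-2] + 1 : indices[-1]]

print(extract_last_math_span("The area is $x^2$, so the answer is $4$."))  # "4"
print(extract_last_math_span("No dollar signs here"))  # ""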
@@ -570,7 +546,7 @@ class MMLU:
  class Winogrande:
      input_column = "input"
      label_column = "answer"
-
+
      categories = [
          "winogrande_debiased",
          "winogrande_l",
@@ -579,24 +555,24 @@ class Winogrande:
          "winogrande_xl",
          "winogrande_xs",
      ]
-
+
      @classmethod
      def prompt_winogrande(cls, example):
-         option1 = example["sentence"].replace("_", example['option1'])
-         option2 = example["sentence"].replace("_", example['option2'])
+         option1 = example["sentence"].replace("_", example["option1"])
+         option2 = example["sentence"].replace("_", example["option2"])
          answer = example[cls.label_column]
          prompt = f"Which of the following is a good sentence:\nA. {option1}\nB. {option2}\nAnswer:"
-
+
          return {
              cls.input_column: prompt,
-             cls.label_column: ' AB'[int(answer)] if answer != '' else ''
+             cls.label_column: " AB"[int(answer)] if answer != "" else "",
          }
-
+
      @classmethod
-     def suite(cls,):
-         subcategories = {
-             item: [item] for item in cls.categories
-         }
+     def suite(
+         cls,
+     ):
+         subcategories = {item: [item] for item in cls.categories}
          finer_categories = (
              pd.Series(subcategories) # noqa # type: ignore
              .explode()
@@ -618,19 +594,18 @@ class Winogrande:
                      label_column=cls.label_column,
                      prompt=partial(cls.prompt_winogrande),
                      few_shot=0,
-                     split="validation"
+                     split="validation",
                  )
              )
-
+
          return suite
-


  class DROP:
      input_column = "input"
      label_column = "answers"
-
-     icl_prompt = '''\
+
+     icl_prompt = """\
  Text: In the county, the population was spread out with 23.50% under the age of 18, 8.70% from 18 to 24, 29.70% from 25 to 44, 24.70% from 45 to 64, and 13.30% who were 65 years of age or older.
  Question: How many more percent are under the age of 18 compared to the 18 to 24 group?
  Anawer: According to the text, 23.5% are under the age of 18, and 8.7% are from ages 18 to 24. 23.5%-8.7%=14.8%. So the answer is 14.8.
@@ -641,15 +616,16 @@ Anawer: According to the text, Stafford threw 5 TD passes, 3 of which were to Jo

  Text: [PROMPT]
  Question: [QUESTION]
- Anawer:'''
-
+ Anawer:"""
+
      categories = ["validation"]
-
+
      @classmethod
      def prompt_drop(cls, example):
-
-         prompt = cls.icl_prompt.replace("[PROMPT]", example["passage"]).replace("[QUESTION]", example["question"])
-
+         prompt = cls.icl_prompt.replace("[PROMPT]", example["passage"]).replace(
+             "[QUESTION]", example["question"]
+         )
+
          validated_answers = example["answers_spans"]["spans"]
          validated_types = example["answers_spans"]["types"]
          answers = []
@@ -661,18 +637,16 @@ Anawer:'''
          # answers.append(' '.join(d).strip())
          # else:
          # for span in answer_item['spans']:
-         # answers.append(span)
+         # answers.append(span)
              answers.append(answer_item)
          answers = list(set(answers))
-
-         return {
-             cls.input_column: prompt,
-             cls.label_column: answers
-         }
-
+
+         return {cls.input_column: prompt, cls.label_column: answers}
+
      @classmethod
-     def suite(cls,):
-
+     def suite(
+         cls,
+     ):
          finer_categories = (
              pd.Series(cls.categories) # noqa # type: ignore
              .explode()
@@ -693,33 +667,34 @@ Anawer:'''
                      label_column=cls.label_column,
                      prompt=partial(cls.prompt_drop),
                      few_shot=0,
-                     split="validation"
+                     split="validation",
                  )
              )
-
+
          return suite


  class HellaSwag:
      input_column = "input"
      label_column = "label"
-
+
      categories = ["validation"]
-
+
      @classmethod
      def prompt_hellaswag(cls, example):
-
          prompt = f"{example['ctx']}\nQuestion: Which ending makes the most sense?\n"
          prompt += f"A. {example['endings'][0]}\n"
          prompt += f"B. {example['endings'][1]}\n"
          prompt += f"C. {example['endings'][2]}\n"
          prompt += f"D. {example['endings'][3]}\n"
          prompt += "You may choose from 'A', 'B', 'C', 'D'.\nAnswer:"
-
+
          return {cls.input_column: prompt}
-
+
      @classmethod
-     def suite(cls,):
+     def suite(
+         cls,
+     ):
          finer_categories = (
              pd.Series(cls.categories) # noqa # type: ignore
              .explode()
@@ -740,21 +715,22 @@ class HellaSwag:
                      label_column=cls.label_column,
                      prompt=partial(cls.prompt_hellaswag),
                      few_shot=0,
-                     split="validation"
+                     split="validation",
                  )
              )
-
+
          return suite

+
  class ARC:
      input_column = "input"
      label_column = "answerKey"
-
+
      categories = [
          "ARC-Challenge",
          "ARC-Easy",
      ]
-
+
      @classmethod
      def prompt_arc(cls, example):
          choices = example["choices"]
@@ -762,10 +738,8 @@ class ARC:
          for label, choice in zip(choices["label"], choices["text"]):
              prompt += f"\n{label}. {choice}"
          prompt += "\nAnswer:"
-         return {
-             cls.input_column: prompt
-         }
-
+         return {cls.input_column: prompt}
+
      @classmethod
      def suite(cls):
          finer_categories = (
@@ -790,62 +764,71 @@ class ARC:
                      few_shot=0,
                  )
              )
-
+
          return suite


  class BBH:
      input_column = "input"
      label_column = "target"
-
+
      multiple_choice_prefix = "Follow the given examples and answer the question.\n[HINT]\n\nQ: [INPUT]\nA: Let's think step by step."
      free_form_prefix = "Follow the given examples and answer the question.\n[HINT]\n\nQ: [INPUT]\nA: Let's think step by step."
-
+
      bbh_multiple_choice_sets = [
-         'temporal_sequences',
-         'disambiguation_qa',
-         'date_understanding',
-         'tracking_shuffled_objects_three_objects',
-         'penguins_in_a_table',
-         'geometric_shapes',
-         'snarks',
-         'ruin_names',
-         'tracking_shuffled_objects_seven_objects',
-         'tracking_shuffled_objects_five_objects',
-         'logical_deduction_three_objects',
-         'hyperbaton',
-         'logical_deduction_five_objects',
-         'logical_deduction_seven_objects',
-         'movie_recommendation',
-         'salient_translation_error_detection',
-         'reasoning_about_colored_objects',
+         "temporal_sequences",
+         "disambiguation_qa",
+         "date_understanding",
+         "tracking_shuffled_objects_three_objects",
+         "penguins_in_a_table",
+         "geometric_shapes",
+         "snarks",
+         "ruin_names",
+         "tracking_shuffled_objects_seven_objects",
+         "tracking_shuffled_objects_five_objects",
+         "logical_deduction_three_objects",
+         "hyperbaton",
+         "logical_deduction_five_objects",
+         "logical_deduction_seven_objects",
+         "movie_recommendation",
+         "salient_translation_error_detection",
+         "reasoning_about_colored_objects",
      ]
-
+
      bbh_free_form_sets = [
-         'multistep_arithmetic_two',
-         'navigate',
-         'dyck_languages',
-         'word_sorting',
-         'sports_understanding',
-         'boolean_expressions',
-         'object_counting',
-         'formal_fallacies',
-         'causal_judgement',
-         'web_of_lies',
+         "multistep_arithmetic_two",
+         "navigate",
+         "dyck_languages",
+         "word_sorting",
+         "sports_understanding",
+         "boolean_expressions",
+         "object_counting",
+         "formal_fallacies",
+         "causal_judgement",
+         "web_of_lies",
      ]
-
+
      @classmethod
-     def prompt_bbh(cls, example, category:str):
-
-         meta_prompt = cls.multiple_choice_prefix if category in cls.bbh_multiple_choice_sets else cls.free_form_prefix
-         prompt = meta_prompt.replace("[HINT]", bbh_lib_prompt(category=category)).replace("[INPUT]", example[cls.input_column])
-
+     def prompt_bbh(cls, example, category: str):
+         meta_prompt = (
+             cls.multiple_choice_prefix
+             if category in cls.bbh_multiple_choice_sets
+             else cls.free_form_prefix
+         )
+         prompt = meta_prompt.replace(
+             "[HINT]", bbh_lib_prompt(category=category)
+         ).replace("[INPUT]", example[cls.input_column])
+
          return {"input": prompt}
-
+
      @classmethod
-     def suite(cls,):
+     def suite(
+         cls,
+     ):
          finer_categories = (
-             pd.Series(cls.bbh_free_form_sets + cls.bbh_multiple_choice_sets) # noqa # type: ignore
+             pd.Series(
+                 cls.bbh_free_form_sets + cls.bbh_multiple_choice_sets
+             ) # noqa # type: ignore
              .explode()
              .reset_index()
              .set_index(0)
@@ -878,167 +861,229 @@ class BBH:
                      few_shot=0,
                  )
              )
-
+
          return suite
-

-
+
  class CEVAL:
      input_column = "input"
      label_column = "answer"
-
+
      @classmethod
-     def prompt_ceval(cls, example, cate:str, chat=False):
+     def prompt_ceval(cls, example, cate: str, chat=False):
          _ch_name = cls.ceval_subject_mapping[cate][1]
-         prefix = (
-             f"以下是中国关于{_ch_name}考试的单项选择题,请选出其中的正确答案。\n"
-             if chat
-             else "问题:"
-         )
-
+         prefix = f"以下是中国关于{_ch_name}考试的单项选择题,请选出其中的正确答案。\n" if chat else "问题:"
+
          prompt = prefix + f'{example["question"]}'
          for choice in list("ABCD"):
              prompt += f"\n{choice}. {example[choice]}"

          prompt += "\n答案:"
          return {"input": prompt}
-
+
      ceval_subject_mapping = {
-         "computer_network":
-         ["Computer Network", "\u8ba1\u7b97\u673a\u7f51\u7edc", "STEM"],
-         "operating_system":
-         ["Operating System", "\u64cd\u4f5c\u7cfb\u7edf", "STEM"],
-         "computer_architecture":
-         ["Computer Architecture", "\u8ba1\u7b97\u673a\u7ec4\u6210", "STEM"],
-         "college_programming":
-         ["College Programming", "\u5927\u5b66\u7f16\u7a0b", "STEM"],
+         "computer_network": [
+             "Computer Network",
+             "\u8ba1\u7b97\u673a\u7f51\u7edc",
+             "STEM",
+         ],
+         "operating_system": ["Operating System", "\u64cd\u4f5c\u7cfb\u7edf", "STEM"],
+         "computer_architecture": [
+             "Computer Architecture",
+             "\u8ba1\u7b97\u673a\u7ec4\u6210",
+             "STEM",
+         ],
+         "college_programming": [
+             "College Programming",
+             "\u5927\u5b66\u7f16\u7a0b",
+             "STEM",
+         ],
          "college_physics": ["College Physics", "\u5927\u5b66\u7269\u7406", "STEM"],
-         "college_chemistry":
-         ["College Chemistry", "\u5927\u5b66\u5316\u5b66", "STEM"],
-         "advanced_mathematics":
-         ["Advanced Mathematics", "\u9ad8\u7b49\u6570\u5b66", "STEM"],
-         "probability_and_statistics":
-         ["Probability and Statistics", "\u6982\u7387\u7edf\u8ba1", "STEM"],
-         "discrete_mathematics":
-         ["Discrete Mathematics", "\u79bb\u6563\u6570\u5b66", "STEM"],
+         "college_chemistry": ["College Chemistry", "\u5927\u5b66\u5316\u5b66", "STEM"],
+         "advanced_mathematics": [
+             "Advanced Mathematics",
+             "\u9ad8\u7b49\u6570\u5b66",
+             "STEM",
+         ],
+         "probability_and_statistics": [
+             "Probability and Statistics",
+             "\u6982\u7387\u7edf\u8ba1",
+             "STEM",
+         ],
+         "discrete_mathematics": [
+             "Discrete Mathematics",
+             "\u79bb\u6563\u6570\u5b66",
+             "STEM",
+         ],
          "electrical_engineer": [
-             "Electrical Engineer", "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08",
-             "STEM"
+             "Electrical Engineer",
+             "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08",
+             "STEM",
+         ],
+         "metrology_engineer": [
+             "Metrology Engineer",
+             "\u6ce8\u518c\u8ba1\u91cf\u5e08",
+             "STEM",
+         ],
+         "high_school_mathematics": [
+             "High School Mathematics",
+             "\u9ad8\u4e2d\u6570\u5b66",
+             "STEM",
+         ],
+         "high_school_physics": [
+             "High School Physics",
+             "\u9ad8\u4e2d\u7269\u7406",
+             "STEM",
+         ],
+         "high_school_chemistry": [
+             "High School Chemistry",
+             "\u9ad8\u4e2d\u5316\u5b66",
+             "STEM",
          ],
-         "metrology_engineer":
-         ["Metrology Engineer", "\u6ce8\u518c\u8ba1\u91cf\u5e08", "STEM"],
-         "high_school_mathematics":
-         ["High School Mathematics", "\u9ad8\u4e2d\u6570\u5b66", "STEM"],
-         "high_school_physics":
-         ["High School Physics", "\u9ad8\u4e2d\u7269\u7406", "STEM"],
-         "high_school_chemistry":
-         ["High School Chemistry", "\u9ad8\u4e2d\u5316\u5b66", "STEM"],
          "high_school_biology": [
-             "High School Biology", "\u9ad8\u4e2d\u751f\u7269", "STEM"
+             "High School Biology",
+             "\u9ad8\u4e2d\u751f\u7269",
+             "STEM",
          ],
          "middle_school_mathematics": [
-             "Middle School Mathematics", "\u521d\u4e2d\u6570\u5b66", "STEM"
+             "Middle School Mathematics",
+             "\u521d\u4e2d\u6570\u5b66",
+             "STEM",
          ],
          "middle_school_biology": [
-             "Middle School Biology", "\u521d\u4e2d\u751f\u7269", "STEM"
+             "Middle School Biology",
+             "\u521d\u4e2d\u751f\u7269",
+             "STEM",
          ],
          "middle_school_physics": [
-             "Middle School Physics", "\u521d\u4e2d\u7269\u7406", "STEM"
+             "Middle School Physics",
+             "\u521d\u4e2d\u7269\u7406",
+             "STEM",
          ],
          "middle_school_chemistry": [
-             "Middle School Chemistry", "\u521d\u4e2d\u5316\u5b66", "STEM"
-         ],
-         "veterinary_medicine": [
-             "Veterinary Medicine", "\u517d\u533b\u5b66", "STEM"
+             "Middle School Chemistry",
+             "\u521d\u4e2d\u5316\u5b66",
+             "STEM",
          ],
+         "veterinary_medicine": ["Veterinary Medicine", "\u517d\u533b\u5b66", "STEM"],
          "college_economics": [
-             "College Economics", "\u5927\u5b66\u7ecf\u6d4e\u5b66", "Social Science"
+             "College Economics",
+             "\u5927\u5b66\u7ecf\u6d4e\u5b66",
+             "Social Science",
          ],
          "business_administration": [
-             "Business Administration", "\u5de5\u5546\u7ba1\u7406", "Social Science"
+             "Business Administration",
+             "\u5de5\u5546\u7ba1\u7406",
+             "Social Science",
          ],
          "marxism": [
-             "Marxism", "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406",
-             "Social Science"
+             "Marxism",
+             "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406",
+             "Social Science",
          ],
          "mao_zedong_thought": [
              "Mao Zedong Thought",
              "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba",
-             "Social Science"
+             "Social Science",
          ],
          "education_science": [
-             "Education Science", "\u6559\u80b2\u5b66", "Social Science"
+             "Education Science",
+             "\u6559\u80b2\u5b66",
+             "Social Science",
          ],
          "teacher_qualification": [
-             "Teacher Qualification", "\u6559\u5e08\u8d44\u683c", "Social Science"
+             "Teacher Qualification",
+             "\u6559\u5e08\u8d44\u683c",
+             "Social Science",
          ],
          "high_school_politics": [
-             "High School Politics", "\u9ad8\u4e2d\u653f\u6cbb", "Social Science"
+             "High School Politics",
+             "\u9ad8\u4e2d\u653f\u6cbb",
+             "Social Science",
          ],
          "high_school_geography": [
-             "High School Geography", "\u9ad8\u4e2d\u5730\u7406", "Social Science"
+             "High School Geography",
+             "\u9ad8\u4e2d\u5730\u7406",
+             "Social Science",
          ],
          "middle_school_politics": [
-             "Middle School Politics", "\u521d\u4e2d\u653f\u6cbb", "Social Science"
+             "Middle School Politics",
+             "\u521d\u4e2d\u653f\u6cbb",
+             "Social Science",
          ],
          "middle_school_geography": [
-             "Middle School Geography", "\u521d\u4e2d\u5730\u7406", "Social Science"
+             "Middle School Geography",
+             "\u521d\u4e2d\u5730\u7406",
+             "Social Science",
+         ],
+         "modern_chinese_history": [
+             "Modern Chinese History",
+             "\u8fd1\u4ee3\u53f2\u7eb2\u8981",
+             "Humanities",
          ],
-         "modern_chinese_history":
-         ["Modern Chinese History", "\u8fd1\u4ee3\u53f2\u7eb2\u8981", "Humanities"],
          "ideological_and_moral_cultivation": [
              "Ideological and Moral Cultivation",
              "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840",
-             "Humanities"
+             "Humanities",
          ],
          "logic": ["Logic", "\u903b\u8f91\u5b66", "Humanities"],
          "law": ["Law", "\u6cd5\u5b66", "Humanities"],
          "chinese_language_and_literature": [
              "Chinese Language and Literature",
-             "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66", "Humanities"
+             "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66",
+             "Humanities",
          ],
          "art_studies": ["Art Studies", "\u827a\u672f\u5b66", "Humanities"],
          "professional_tour_guide": [
-             "Professional Tour Guide", "\u5bfc\u6e38\u8d44\u683c", "Humanities"
+             "Professional Tour Guide",
+             "\u5bfc\u6e38\u8d44\u683c",
+             "Humanities",
          ],
          "legal_professional": [
-             "Legal Professional", "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c",
-             "Humanities"
+             "Legal Professional",
+             "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c",
+             "Humanities",
          ],
          "high_school_chinese": [
-             "High School Chinese", "\u9ad8\u4e2d\u8bed\u6587", "Humanities"
+             "High School Chinese",
+             "\u9ad8\u4e2d\u8bed\u6587",
+             "Humanities",
          ],
          "high_school_history": [
-             "High School History", "\u9ad8\u4e2d\u5386\u53f2", "Humanities"
+             "High School History",
+             "\u9ad8\u4e2d\u5386\u53f2",
+             "Humanities",
          ],
          "middle_school_history": [
-             "Middle School History", "\u521d\u4e2d\u5386\u53f2", "Humanities"
+             "Middle School History",
+             "\u521d\u4e2d\u5386\u53f2",
+             "Humanities",
          ],
          "civil_servant": ["Civil Servant", "\u516c\u52a1\u5458", "Other"],
          "sports_science": ["Sports Science", "\u4f53\u80b2\u5b66", "Other"],
-         "plant_protection": [
-             "Plant Protection", "\u690d\u7269\u4fdd\u62a4", "Other"
-         ],
+         "plant_protection": ["Plant Protection", "\u690d\u7269\u4fdd\u62a4", "Other"],
          "basic_medicine": ["Basic Medicine", "\u57fa\u7840\u533b\u5b66", "Other"],
-         "clinical_medicine": [
-             "Clinical Medicine", "\u4e34\u5e8a\u533b\u5b66", "Other"
-         ],
+         "clinical_medicine": ["Clinical Medicine", "\u4e34\u5e8a\u533b\u5b66", "Other"],
          "urban_and_rural_planner": [
              "Urban and Rural Planner",
-             "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08", "Other"
+             "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08",
+             "Other",
          ],
          "accountant": ["Accountant", "\u6ce8\u518c\u4f1a\u8ba1\u5e08", "Other"],
          "fire_engineer": [
-             "Fire Engineer", "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08", "Other"
+             "Fire Engineer",
+             "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08",
+             "Other",
          ],
          "environmental_impact_assessment_engineer": [
              "Environmental Impact Assessment Engineer",
-             "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08", "Other"
+             "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08",
+             "Other",
          ],
          "tax_accountant": ["Tax Accountant", "\u7a0e\u52a1\u5e08", "Other"],
-         "physician": ["Physician", "\u533b\u5e08\u8d44\u683c", "Other"]
+         "physician": ["Physician", "\u533b\u5e08\u8d44\u683c", "Other"],
      }
-
+
      @classmethod
      def suite(cls, chat: bool):
          suite = defaultdict(list)
@@ -1058,8 +1103,8 @@ class CEVAL:
                      prompt=partial(cls.prompt_ceval, cate=subject, chat=chat),
                      few_shot=0 if chat else 5,
                      few_shot_from="dev",
-                     split="val"
+                     split="val",
                  )
              )
-
-         return suite
+
+         return suite

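For reference, prompt_ceval above assembles a single-choice exam prompt; a standalone sketch of the chat-mode rendering with a made-up example row (the real subject names come from ceval_subject_mapping):

# Standalone sketch of the prompt that prompt_ceval builds in chat mode;
# the example row and subject name are made up for illustration.
example = {"question": "1+1等于几?", "A": "1", "B": "2", "C": "3", "D": "4"}
_ch_name = "高等数学"  # e.g. ceval_subject_mapping["advanced_mathematics"][1]
prefix = f"以下是中国关于{_ch_name}考试的单项选择题,请选出其中的正确答案。\n"
prompt = prefix + f'{example["question"]}'
for choice in list("ABCD"):
    prompt += f"\n{choice}. {example[choice]}"
prompt += "\n答案:"
print(prompt)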
tlem.py CHANGED
@@ -135,6 +135,15 @@ class Suite(EvaluationSuite):
                      prompt=mt_bench_prompt
                      # metric_name=("sustech/tlem", "gsm8k"),
                  )
+             case "MATH" | "competition_math":
+                 suite = Task(
+                     dataset_name="hendrycks/competition_math",
+                     split="test",
+                     prompt="This is a math problem, please think step by step and slove it: {input_column}",
+                     metric_name=("sustech/tlem", "MATH"),
+                     input_column="problem",
+                     label_column="solution",
+                 )
          match name:
              case _ if "test" in name:
                  suite = suite["Test"]
 
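With the new branch in place, tlem.py can route the MATH suite by name; a hedged sketch of loading the suite through the standard evaluate entry point (how the task name reaches the match statement above is repo-specific and assumed here):

# Hedged sketch: EvaluationSuite.load is the standard `evaluate` entry point;
# everything after loading is an assumption about this repo's interface.
from evaluate import EvaluationSuite

suite = EvaluationSuite.load("sustech/tlem")
# The new case above means a name of "MATH" or "competition_math" resolves to
# a Task over hendrycks/competition_math (split="test", input_column="problem",
# label_column="solution") scored by the ("sustech/tlem", "MATH") metric.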