# SwiftSage / data_loader.py
# yuchenlin's picture
# Upload 14 files
# 1a0cf07 verified
# raw
# history blame
# No virus
# 2.7 kB
import json
import os
import re
import random
from typing import Any, Iterable, Union
from datasets import Dataset, concatenate_datasets, load_dataset
from data_utils import (
lower_keys,
parse_question,
parse_ground_truth,
)
def load_jsonl(file) -> Iterable[Any]:
    """Lazily yield one decoded JSON value per line of a JSON Lines file.

    Args:
        file: Path of the UTF-8 encoded JSONL file to read.

    Yields:
        The value decoded from each non-blank line, in file order.

    Raises:
        SystemExit: If a non-blank line is not valid JSON. The offending
            line is printed first, then the process is asked to exit with
            status 1 so the failure is visible to the calling shell.
    """
    with open(file, "r", encoding="utf-8") as f:
        for line in f:
            # A blank line (e.g. a trailing newline at EOF) carries no record.
            if not line.strip():
                continue
            # Keep the try body to the decode call only: a `yield` inside a
            # broad handler would intercept GeneratorExit / KeyboardInterrupt
            # and misreport closing the generator as a load error.
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print("Error in loading:", line)
                raise SystemExit(1)
            yield record
def load_data(
    data_name: str,
    split: str = 'test',
    data_dir: str = './data',
    num_test_sample: int = -1,
) -> list:
    """Load a benchmark's examples from the local JSONL cache, downloading on a miss.

    The cache file is ``{data_dir}/{data_name}/{split}.jsonl``. When it exists
    its records are returned as stored. Otherwise the dataset is fetched with
    ``load_dataset``, every example is passed through ``lower_keys``
    (presumably lower-casing the field names -- confirm in ``data_utils``),
    and the result is written to the cache file before being returned.

    Note that ``split`` only selects the cache file name: the downloads below
    use fixed splits ("mmlu_stem" -> 'test', "gpqa" -> "train").

    Args:
        data_name: Benchmark name. Any casing of "math" is mapped to 'MATH';
            all other names are matched exactly.
        split: Split name used to build the cache file path.
        data_dir: Root directory of the local dataset cache.
        num_test_sample: If positive, keep only the first this-many examples
            after sorting; otherwise keep all of them.

    Returns:
        A list of example dicts sorted by their 'idx' field. An empty list is
        returned when the source holds no examples.

    Raises:
        NotImplementedError: If there is no cache file and ``data_name`` is
            not one of the downloadable datasets ("mmlu_stem", "gpqa").
    """
    if data_name.lower() == "math":
        data_name = 'MATH'  # we use 500 problem test split in "Let's Verify Step-by-Step"
    data_file = f"{data_dir}/{data_name}/{split}.jsonl"

    if os.path.exists(data_file):
        examples = list(load_jsonl(data_file))
    else:
        if data_name == "mmlu_stem":
            dataset = load_dataset("hails/mmlu_no_train", 'all', split='test')
            # only keep stem subjects
            stem_subjects = ['abstract_algebra', 'astronomy', 'college_biology', 'college_chemistry',
                             'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security',
                             'conceptual_physics', 'electrical_engineering', 'elementary_mathematics', 'high_school_biology',
                             'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics',
                             'high_school_physics', 'high_school_statistics', 'machine_learning']
            dataset = dataset.rename_column("subject", "type")
            dataset = dataset.filter(lambda x: x['type'] in stem_subjects)
        elif data_name == "gpqa":
            dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")
        else:
            # Also covers the known-but-unsupported "mathvista" and "codeforces".
            raise NotImplementedError(data_name)

        examples = [lower_keys(example) for example in dataset]
        # Persist the normalised examples so later calls hit the cache branch.
        os.makedirs(f"{data_dir}/{data_name}", exist_ok=True)
        Dataset.from_list(examples).to_json(data_file)

    # Add 'idx' as the first key when the data does not provide one. Only the
    # first example is inspected; an empty list has nothing to index.
    if examples and 'idx' not in examples[0]:
        examples = [{'idx': i, **example} for i, example in enumerate(examples)]

    # Sort by 'idx' (no de-duplication is performed).
    examples = sorted(examples, key=lambda x: x['idx'])

    if num_test_sample > 0:
        examples = examples[:num_test_sample]
    return examples
if __name__ == "__main__":
    # Manual smoke run: load (or download and cache) the GPQA examples under
    # the "test" cache name, then print a marker once loading has finished.
    examples = load_data(data_name="gpqa", split="test")
    print('test')