Robin-7b / lmflow /utils /data_utils.py
NingKanae's picture
Duplicate from OptimalScale/Robin-7b
98f2419
raw
history blame
6.59 kB
"""The program includes several functions: setting a random seed,
loading data from a JSON file, batching data, and extracting answers from generated text.
"""
import random
import numpy as np
import torch
import json
import re
def set_random_seed(seed: int):
    """
    Seed every random number generator this module relies on.

    The same value is passed to `random`, `numpy` and `torch`; when CUDA is
    available it is also applied to all CUDA devices.

    Parameters
    ------------
    seed : int
        The seed value shared by all generators.
    """
    # Each of these callables takes the seed as its only argument.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed_all(seed)
def load_data(file_name: str):
    """
    Load a dataset from a JSON file.

    The file must contain a JSON object with a "type" entry and an
    "instances" list; every instance must provide an "input" and an
    "output" entry. A short summary is printed once loading succeeds.

    Parameters
    ------------
    file_name : str.
        The dataset file name.

    Returns
    ------------
    inputs : list.
        The "input" value of every instance, in file order.
    outputs : list.
        The "output" value of every instance, in file order.
    len : int.
        The number of instances in the dataset.

    Raises
    ------------
    KeyError
        If "type", "instances", "input" or "output" is missing.
    """
    with open(file_name, encoding='utf-8') as f:
        json_data = json.load(f)

    # Named `dataset_type` so the builtin `type` is not shadowed.
    dataset_type = json_data["type"]
    instances = json_data["instances"]
    inputs = [instance["input"] for instance in instances]
    outputs = [instance["output"] for instance in instances]

    print(f"load dataset {file_name} success.\n")
    print(f"Type : {dataset_type}, datasize : {len(outputs)}")
    return inputs, outputs, len(outputs)
def batchlize(examples: list, batch_size: int, random_shuffle: bool):
    """
    Split examples into consecutive batches.

    Parameters
    ------------
    examples : list.
        Data list. When `random_shuffle` is true this list is shuffled
        in place, so the caller's list is reordered as well.
    batch_size : int.
        Maximum number of examples per batch; must be positive.
    random_shuffle : bool
        If true, shuffle the examples before batching.

    Returns
    ------------
    dataloader:
        List of batches (slices of `examples`). Every batch holds
        `batch_size` examples except possibly the last one, which holds
        the remainder. An empty input yields an empty list.

    Raises
    ------------
    ValueError
        If `batch_size` is not positive (batching could never finish).
    """
    if batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    if random_shuffle:
        random.shuffle(examples)

    # Slicing clamps at the end of the list, so the last batch is simply
    # whatever is left over.
    return [
        examples[start:start + batch_size]
        for start in range(0, len(examples), batch_size)
    ]
def _search_yes_no_maybe(text):
    """Return "yes", "no" or "maybe" found in `text`, or "N/A" if absent."""
    verdict = r"(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)"

    # Preferred form: a labelled answer such as "Answer: (Yes".
    labelled = re.search(
        r"(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*" + verdict, text
    )
    if labelled is not None:
        # Use the captured verdict so optional "(" never leaks into the result.
        return labelled.group(2).lower()

    # Fallback: the first verdict word followed by "." or whitespace.
    bare = re.search(verdict + r"(\.|\s)", text)
    if bare is not None:
        return bare.group(1).lower()
    return "N/A"


def _search_option_letter(text, labels):
    """Return the lower-cased option letter (a-d) found in `text`, or "N/A".

    `labels` is a regex alternation of the prefixes accepted before ": ".
    """
    letter = r"(A|B|C|D|a|b|c|d)"

    # Preferred form: a labelled answer such as "Answer: (B".
    labelled = re.search(r"(" + labels + r"): \(*" + letter, text)
    if labelled is not None:
        return labelled.group(2).lower()

    # Fallback: the first a-d letter (optionally parenthesised) followed by
    # "." or whitespace. This also matches such letters at the end of words.
    bare = re.search(r"\(*" + letter + r"\)*(\.|\s)", text)
    if bare is not None:
        # The capture group is the letter itself, however many "(" precede it.
        return bare.group(1).lower()
    return "N/A"


def answer_extraction(response, answer_type=None, concat_length=None):
    """
    Use this function to extract answers from generated text.

    Parameters
    ------------
    response : str
        Plain string response.
    answer_type : str
        Name of the task, which selects the extraction rule:
        "gsm8k", "svamp", "asdiv", "addsub", "singleeq", "multiarith",
        "math" (last number); "aqua", "csqa", "multiple_choice" (last
        capital letter A-E); "strategyqa", "coin_flip" (last "yes"/"no"
        word); "last_letters"; "pubmedqa", "binary_choice"
        (yes/no/maybe); "medmcqa", "usmle" (option letter a-d); "text"
        (response returned unchanged).
    concat_length : int, optional
        Only used for "last_letters": keep just the last `concat_length`
        characters of the cleaned response. When omitted, the whole
        cleaned response is returned.

    Returns
    ------------
    answer:
        Decoded answer (such as A, B, C, D, E for multiple-choice QA).
        An empty string is returned when nothing could be extracted;
        "pubmedqa", "binary_choice", "medmcqa" and "usmle" return "N/A"
        instead.

    Raises
    ------------
    NotImplementedError
        If `answer_type` is not one of the supported names (including
        the default `None`).
    """
    if answer_type == "text":
        return response
    if answer_type in ("pubmedqa", "binary_choice"):
        return _search_yes_no_maybe(response)
    if answer_type == "medmcqa":
        return _search_option_letter(
            response, "answer|Answer|ANSWER|output|Output|OUTPUT|A"
        )
    if answer_type == "usmle":
        return _search_option_letter(response, "Answer|Output|A")

    # The remaining types collect candidates and keep the last one.
    if answer_type in ("gsm8k", "svamp", "asdiv", "addsub", "singleeq", "multiarith", "math"):
        # Drop thousands separators so "1,234" is read as one number.
        candidates = re.findall(r"-?\d+\.?\d*", response.replace(",", ""))
    elif answer_type in ("aqua", "csqa", "multiple_choice"):
        candidates = re.findall(r"A|B|C|D|E", response)
    elif answer_type in ("strategyqa", "coin_flip"):
        cleaned = re.sub(r"[\"'\n.\s:,]", " ", response.lower())
        candidates = [word for word in cleaned.split(" ") if word in ("yes", "no")]
    elif answer_type == "last_letters":
        candidates = [re.sub(r"[\"'\n.\s]", "", response)]
    else:
        raise NotImplementedError(f"Unsupported answer type: {answer_type}")

    if not candidates:
        return ""

    answer = candidates[-1]
    # If there is a "." at the end of the answer, remove it (e.g. "64.").
    if answer.endswith("."):
        answer = answer[:-1]

    if answer_type in ("gsm8k", "svamp"):
        # Round the answer to the nearest integer.
        try:
            answer = str(round(float(answer)))
        except (ValueError, OverflowError):
            answer = ""  # no solution, or not a valid finite number
    elif answer_type == "last_letters" and concat_length is not None:
        # A non-positive length keeps nothing; `answer[-0:]` would
        # otherwise return the whole string.
        answer = answer[-concat_length:] if concat_length > 0 else ""
    return answer