|
"""The program includes several functions: setting a random seed, |
|
loading data from a JSON file, batching data, and extracting answers from generated text. |
|
""" |
|
|
|
import random |
|
import numpy as np |
|
import torch |
|
import json |
|
import re |
|
def set_random_seed(seed: int):
    """
    Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU generator; when CUDA is available, the generators of all CUDA
    devices are seeded as well.

    Parameters
    ------------
    seed : int
        The seed value handed to each generator.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
|
|
|
def load_data(file_name: str):
    """
    Load a dataset stored as a JSON file.

    The file must contain an object with an ``"instances"`` list whose items
    each carry ``"input"`` and ``"output"`` fields. An optional ``"type"``
    field is only used in the log message.

    Parameters
    ------------
    file_name : str.
        Path of the JSON dataset file (read as UTF-8).

    Returns
    ------------
    inputs : list.
        The ``"input"`` value of every instance, in file order.
    outputs : list.
        The ``"output"`` value of every instance, in file order.
    len : int.
        The number of instances in the dataset.
    """
    with open(file_name, encoding='utf-8') as f:
        json_data = json.load(f)

    # "type" is informational only, so a missing key must not abort loading.
    # Named ``data_type`` to avoid shadowing the builtin ``type``.
    data_type = json_data.get("type", "")
    instances = json_data["instances"]
    inputs = [instance["input"] for instance in instances]
    outputs = [instance["output"] for instance in instances]

    print(f"load dataset {file_name} success.\n")
    print(f"Type : {data_type}, datasize : {len(outputs)}")

    return inputs, outputs, len(outputs)
|
|
|
def batchlize(examples: list, batch_size: int, random_shuffle: bool):
    """
    Split examples into consecutive batches.

    Parameters
    ------------
    examples : list.
        Data list. It is never modified; when shuffling is requested a copy
        is shuffled instead.
    batch_size : int.
        Maximum number of examples per batch; must be a positive integer.
        The last batch holds the remainder and may be smaller.
    random_shuffle : bool
        If true, the examples are shuffled (using the global ``random``
        state) before being split into batches.

    Returns
    ------------
    dataloader: list
        List of batches, each a slice of the (possibly shuffled) examples.

    Raises
    ------------
    ValueError
        If ``batch_size`` is smaller than 1 (the old loop never terminated).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if random_shuffle:
        # Shuffle a copy so the caller's list is left untouched; the random
        # stream consumed (and therefore the resulting order) is unchanged.
        examples = list(examples)
        random.shuffle(examples)

    return [
        examples[start:start + batch_size]
        for start in range(0, len(examples), batch_size)
    ]
|
|
|
|
|
|
|
# Answer types whose candidates are all numbers found in the response.
_NUMERIC_TYPES = ("gsm8k", "svamp", "asdiv", "addsub", "singleeq", "multiarith", "math")
# Numeric answer types whose final answer is rounded to an integer string.
_ROUNDED_TYPES = ("gsm8k", "svamp")
# Multiple-choice types whose candidates are capital letters A-E.
_CHOICE_TYPES = ("aqua", "csqa", "multiple_choice")
# Types whose candidates are the words "yes"/"no".
_YES_NO_TYPES = ("strategyqa", "coin_flip")
# Types whose single candidate is the response with quotes/periods/whitespace removed.
_LAST_LETTERS_TYPES = ("last_letters",)
# Types answered with "yes"/"no"/"maybe" (or "N/A").
_BINARY_CHOICE_TYPES = ("pubmedqa", "binary_choice")

_BINARY_PRIMARY_PATTERN = r"(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)"
_BINARY_FALLBACK_PATTERN = r"(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)(\.|\s)"
_MEDMCQA_PRIMARY_PATTERN = r"(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(A|B|C|D|a|b|c|d)"
_USMLE_PRIMARY_PATTERN = r"(Answer|Output|A): \(*(A|B|C|D|a|b|c|d)"
_OPTION_FALLBACK_PATTERN = r"\(*(A|B|C|D|a|b|c|d)\)*(\.|\s)"


def _extract_binary_choice(text):
    """Return the lower-cased yes/no/maybe answer found in *text*, or "N/A"."""
    match = re.search(_BINARY_PRIMARY_PATTERN, text)
    if match is not None:
        # group(2) is the bare word, so "Answer: (yes)" yields "yes", not "(yes".
        return match.group(2).lower()
    match = re.search(_BINARY_FALLBACK_PATTERN, text)
    if match is not None:
        return match.group(1).lower()
    return "N/A"


def _extract_option_letter(text, primary_pattern):
    """Return the lower-cased option letter (a-d) found in *text*, or "N/A"."""
    match = re.search(primary_pattern, text)
    if match is not None:
        return match.group(2).lower()
    match = re.search(_OPTION_FALLBACK_PATTERN, text)
    if match is not None:
        # group(1) is the letter regardless of how many parentheses surround it.
        return match.group(1).lower()
    return "N/A"


def _candidate_answers(text, answer_type):
    """Return the candidate answers for list-based types, or None if unsupported."""
    if answer_type in _NUMERIC_TYPES:
        # Drop thousands separators so "1,234" is read as one number.
        return re.findall(r"-?\d+\.?\d*", text.replace(",", ""))
    if answer_type in _CHOICE_TYPES:
        return re.findall(r"A|B|C|D|E", text)
    if answer_type in _YES_NO_TYPES:
        words = re.sub(r"[\"'.:,\s]", " ", text.lower()).split(" ")
        return [word for word in words if word in ("yes", "no")]
    if answer_type in _LAST_LETTERS_TYPES:
        return [re.sub(r"[\"'.\s]", "", text)]
    return None


def answer_extraction(response, answer_type=None, concat_length=None):
    """
    Extract the final answer from a model-generated response.

    Parameters
    ------------
    response : str
        Plain-text response produced by the model.
    answer_type : str
        Dataset / answer format. Supported values:
        "gsm8k", "svamp" -> last number, rounded to an integer string;
        "asdiv", "addsub", "singleeq", "multiarith", "math" -> last number
        with a trailing period removed;
        "aqua", "csqa", "multiple_choice" -> last capital letter A-E;
        "strategyqa", "coin_flip" -> last "yes"/"no";
        "last_letters" -> response with quotes, periods and whitespace removed;
        "pubmedqa", "binary_choice" -> "yes"/"no"/"maybe", or "N/A";
        "medmcqa", "usmle" -> lower-case option letter a-d, or "N/A";
        "text" -> the response unchanged.
    concat_length : int, optional
        Only used for "last_letters": keep just the last ``concat_length``
        characters of the cleaned answer. ``None`` keeps the whole answer.

    Returns
    ------------
    answer:
        Decoded answer (such as A, B, C, D, E for multiple-choice QA). List-based
        types return "" when no candidate is found or a number cannot be rounded.

    Raises
    ------------
    NotImplementedError
        If ``answer_type`` is not one of the supported values.
    """
    if answer_type == "text":
        return response
    if answer_type in _BINARY_CHOICE_TYPES:
        return _extract_binary_choice(response)
    if answer_type == "medmcqa":
        return _extract_option_letter(response, _MEDMCQA_PRIMARY_PATTERN)
    if answer_type == "usmle":
        return _extract_option_letter(response, _USMLE_PRIMARY_PATTERN)

    candidates = _candidate_answers(response, answer_type)
    if candidates is None:
        raise NotImplementedError(f"Unsupported answer type: {answer_type}")
    if not candidates:
        return ""

    # The last candidate is taken as the final answer; drop a trailing period
    # (e.g. "64." at the end of a sentence).
    answer = candidates[-1]
    if answer.endswith("."):
        answer = answer[:-1]

    if answer_type in _ROUNDED_TYPES:
        try:
            answer = str(round(float(answer)))
        except (ValueError, OverflowError):
            # No solution, or a value that cannot be rounded (e.g. inf).
            answer = ""
    elif answer_type in _LAST_LETTERS_TYPES and concat_length is not None:
        answer = answer[-concat_length:]
    return answer
|
|