# SPT/dataset/dataset.py
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from dataset.dataset_helper import read_personachat_split
from utils.format_inputs import TASK_TYPE
class PersonaChatDataset(torch.utils.data.Dataset):
    """PersonaChat/ConvAI2 dataset yielding (persona, context, response) turns."""

    def __init__(self, data_path, max_context_turns=-1,
                 add_role_indicator=True, only_longest=False, training_ratio=1.0,
                 task_type=TASK_TYPE.GENERATE_RESPONSE):
self.path = data_path
self.add_role_indicator = add_role_indicator
self.max_context_turns = max_context_turns
self.turns_data = read_personachat_split(data_path, only_longest=only_longest)
self.only_longest = only_longest
self.training_ratio = training_ratio
if training_ratio < 1.0:
self.turns_data = self.turns_data[:int(len(self.turns_data) * training_ratio)]
self.task_type = task_type
# # For debug only
# os.makedirs("data_logs", exist_ok=True)
# random_num = random.randint(0, 100000)
# self.file = open(f"data_logs/{random_num}_{data_path.split(os.sep)[-1]}", 'w')
# # add id to turns_data
# self.turns_data = [{'id': idx, **turn} for idx, turn in enumerate(self.turns_data)]
# self.file.write(f"total_turns: {len(self.turns_data)}\n")
    def sort_longest_first(self):
        # Longest examples first, so the automatic batch-size finder hits the
        # worst-case memory footprint immediately.
        self.turns_data = sorted(self.turns_data, key=lambda x: len(
            (' '.join(x['persona']) + ' '.join(x['context']) + x['response']).split(' ')), reverse=True)
def __getitem__(self, idx):
# self.file.write(str(idx) + "\n")
# self.file.flush()
input_data = self.turns_data[idx]
persona_list = input_data['persona']
target = input_data['response']
context_input = input_data['context']
        if self.add_role_indicator:
            # Utterances alternate speakers: even indices are the query side (Q),
            # odd indices are the respondent (R).
            context_input = [['Q: ', 'R: '][c_idx % 2] + context
                             for c_idx, context in enumerate(context_input)]
        if self.max_context_turns != -1:
            # Keep only the last `max_context_turns` exchanges; the odd utterance
            # count ensures the truncated context still ends with Q's latest query.
            context_input = context_input[-(self.max_context_turns * 2 - 1):]
        if self.only_longest:
            context_input = context_input[:-1]
return {
'context_input': context_input,
'persona_list': persona_list,
'target': target
}
def __len__(self):
return len(self.turns_data)
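# Usage sketch (illustrative only; assumes a ConvAI2-format file exists at this path):
#   ds = PersonaChatDataset('data_file/ConvAI2/valid_self_original_no_cands.txt',
#                           max_context_turns=2)
#   item = ds[0]
#   # item -> {'context_input': ['Q: ...', 'R: ...', 'Q: ...'],
#   #          'persona_list': ['i like tea', ...],
#   #          'target': '...'}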
# class HGPersonaChatDataset(PersonaChatDataset):
# def __init__(self, data_path, max_context_turns=-1,
# add_role_indicator=True, only_longest=False, tokenizer=None):
# super().__init__(data_path, max_context_turns, add_role_indicator, only_longest)
# self.tokenizer = tokenizer
#
# def __getitem__(self, idx):
# data = super().__getitem__(idx)
# input = "P: " + ' '.join(data['persona_list']) + " C: " + ' '.join(data['context_input']) + " R: " + data[
# 'target']
# tokenized = self.tokenizer(input)
# return {**data, **tokenized}
def collate_fn(sample_list):
    # Keys whose values stay as Python lists instead of being stacked into tensors.
    dont_be_a_tensor = ['context_input', 'persona_list', 'target']
    to_be_flattened = [*dont_be_a_tensor]
    data = {}
    for key in to_be_flattened:
        if key not in sample_list[0] or sample_list[0][key] is None:
            continue
        flatten_samples = [sample[key] for sample in sample_list]
        if isinstance(flatten_samples[-1], str) or key in dont_be_a_tensor:
            data[key] = flatten_samples
        else:
            data[key] = torch.tensor(flatten_samples)
    return data
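# Example of the collate contract (hypothetical samples): keys listed in
# `dont_be_a_tensor` come back as plain Python lists, one entry per sample:
#   batch = collate_fn([
#       {'context_input': ['Q: hi'], 'persona_list': ['i ski'], 'target': 'hello'},
#       {'context_input': ['Q: hey'], 'persona_list': ['i cook'], 'target': 'hi there'},
#   ])
#   # batch['target'] -> ['hello', 'hi there']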
def collate_fn_straight(sample_list):
    return collate_fn(sample_list)
def collate_fn_straight_with_fn(fn):
    # Wrap `fn` so its outputs (e.g. tokenized tensors) are merged into the batch.
    def build_collate_fn(sample_list):
        batch = collate_fn(sample_list)
        return {**batch, **fn(batch)}
    return build_collate_fn
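# Sketch of plugging in a post-processing step (hypothetical `add_turn_counts`;
# any callable mapping a batch dict to extra fields works the same way):
#   def add_turn_counts(batch):
#       return {'num_turns': torch.tensor([len(c) for c in batch['context_input']])}
#   collate = collate_fn_straight_with_fn(add_turn_counts)
#   # collate(samples) -> original batch keys plus 'num_turns'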
def get_dataloader(dataset, batch_size, shuffle=False, num_workers=None, collate_fn=None, sampler=None):
    if num_workers is None:
        num_workers = batch_size // 4  # heuristic: scale workers with batch size
        # num_workers = min(num_workers, batch_size)
    if collate_fn is None:
        _collate_fn = collate_fn_straight
    else:
        _collate_fn = collate_fn_straight_with_fn(collate_fn)
    return DataLoader(dataset, batch_size=batch_size,
                      collate_fn=_collate_fn,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      sampler=sampler)
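# Usage sketch (illustrative values, reusing `add_turn_counts` from the sketch
# above as the optional post-processing step):
#   loader = get_dataloader(ds, batch_size=16, shuffle=True,
#                           collate_fn=add_turn_counts)
#   for batch in loader:
#       ...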
def get_lightening_dataloader(dataset, batch_size, shuffle=False, num_workers=None):
    # Note: despite the name, this returns a LightningDataModule, not a DataLoader.
    return LitDataModule(batch_size, dataset, shuffle, num_workers)
class LitDataModule(LightningDataModule):
    def __init__(self, batch_size, dataset, shuffle, num_workers):
        super().__init__()
        # Store batch_size/shuffle/num_workers in hparams; the dataset itself is
        # excluded so Lightning does not try to serialize it.
        self.save_hyperparameters(ignore=['dataset'])
        # Also kept as a plain attribute so Lightning's batch-size tuner can scale it.
        self.batch_size = batch_size
        self.dataset = dataset

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size,
                          collate_fn=collate_fn_straight,
                          shuffle=self.hparams.shuffle,
                          num_workers=self.hparams.num_workers)
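# Usage sketch (assumes a LightningModule `model` defined elsewhere; shown for
# illustration only):
#   import pytorch_lightning as pl
#   dm = get_lightening_dataloader(train_ds, batch_size=16, shuffle=True, num_workers=4)
#   trainer = pl.Trainer(max_epochs=1)
#   trainer.fit(model, datamodule=dm)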
if __name__ == '__main__':
    import json
    from tqdm import tqdm

    def jsonify_split(data_path):
        """Convert one ConvAI2 split into instruction-tuning records."""
        dataset = PersonaChatDataset(data_path=data_path)
        records = []
        for data in tqdm(dataset):
            context_input = "\n".join(data['context_input'])
            persona_input = '\n'.join(data['persona_list'])
            records.append({
                "instruction": f"""Given the dialog history between Q and R:
{context_input}
Given the personality of R as:
{persona_input}
Please respond to Q according to both the dialog history and R's personality.
Now, R would say:""",
                "input": "",
                "output": data['target'],
                "answer": "",
            })
        return records

    train_records = jsonify_split('data_file/ConvAI2/train_self_original_no_cands.txt')
    with open('data_file/train.json', 'w') as writer:
        json.dump(train_records, writer)
    del train_records

    valid_records = jsonify_split('data_file/ConvAI2/valid_self_original_no_cands.txt')
    with open('data_file/valid.json', 'w') as writer:
        json.dump(valid_records, writer)
    # The validation split is reused unchanged as the test split.
    with open('data_file/test.json', 'w') as writer:
        json.dump(valid_records, writer)