|
import torch
|
|
from pytorch_lightning import LightningDataModule
|
|
from torch.utils.data import DataLoader
|
|
|
|
from dataset.dataset_helper import read_personachat_split
|
|
from utils.format_inputs import TASK_TYPE
|
|
|
|
|
|
class PersonaChatDataset(torch.utils.data.Dataset):
    """Map-style dataset over PersonaChat turns loaded by ``read_personachat_split``.

    Each stored turn is a dict with ``'persona'`` (list of strings),
    ``'context'`` (list of utterance strings) and ``'response'`` (string).
    ``__getitem__`` returns a dict with ``'context_input'``, ``'persona_list'``
    and ``'target'``.
    """

    def __init__(self, data_path, max_context_turns=-1,
                 add_role_indicator=True, only_longest=False, training_ratio=1.0,
                 task_type=TASK_TYPE.GENERATE_RESPONSE):
        """Load the split at ``data_path``.

        Args:
            data_path: path handed to ``read_personachat_split``.
            max_context_turns: when ``k >= 0``, keep only the last ``2 * k - 1``
                context utterances (``k == 0`` keeps none). Any negative value
                (default ``-1``) disables truncation.
            add_role_indicator: prefix context utterances alternately with
                ``'Q: '`` / ``'R: '``, starting with ``'Q: '`` on the first one.
            only_longest: forwarded to ``read_personachat_split``; additionally
                makes ``__getitem__`` drop the final context utterance.
            training_ratio: when < 1.0, keep only that leading fraction of turns.
            task_type: stored on the instance; not read by this class.
        """
        self.path = data_path
        self.add_role_indicator = add_role_indicator
        self.max_context_turns = max_context_turns
        self.turns_data = read_personachat_split(data_path, only_longest=only_longest)
        self.only_longest = only_longest
        self.training_ratio = training_ratio
        if training_ratio < 1.0:
            self.turns_data = self.turns_data[:int(len(self.turns_data) * training_ratio)]
        self.task_type = task_type

    def sort_longest_first(self):
        """Reorder ``turns_data`` by descending word count (persona + context + response)."""

        def _num_words(turn):
            # Count words per string so adjacent fields never merge into one token.
            texts = [*turn['persona'], *turn['context'], turn['response']]
            return sum(len(text.split()) for text in texts)

        self.turns_data = sorted(self.turns_data, key=_num_words, reverse=True)

    def __getitem__(self, idx):
        """Return the ``idx``-th turn as model-ready fields.

        Returns:
            dict with ``'context_input'`` (possibly role-prefixed and truncated
            list of utterances), ``'persona_list'`` and ``'target'``.
        """
        input_data = self.turns_data[idx]
        persona_list = input_data['persona']
        target = input_data['response']
        context_input = input_data['context']

        if self.add_role_indicator:
            # Roles alternate from the first utterance; applied before truncation
            # so each utterance keeps the role of its original position.
            context_input = [['Q: ', 'R: '][c_idx % 2] + context
                             for c_idx, context in enumerate(context_input)]

        if self.max_context_turns >= 0:
            # Keep the last 2k-1 utterances. Clamp at 0: unclamped, k == 0 gives
            # the slice [-(-1):] == [1:], which would drop only the first utterance.
            keep = max(self.max_context_turns * 2 - 1, 0)
            context_input = context_input[-keep:] if keep else []

        if self.only_longest:
            context_input = context_input[:-1]

        return {
            'context_input': context_input,
            'persona_list': persona_list,
            'target': target
        }

    def __len__(self):
        """Number of (possibly ratio-trimmed) turns."""
        return len(self.turns_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collate_fn(sample_list):
    """Collate dataset items into a batch dict of per-key lists.

    Only the keys ``context_input``, ``persona_list`` and ``target`` are
    collected. A key is skipped when it is missing from, or ``None`` in, the
    first sample. An empty ``sample_list`` yields an empty dict.

    Every collected key is currently in the keep-as-list set, so values are never
    converted to tensors. The ``torch.tensor`` branch only applies to non-string
    values of keys later added to ``to_be_flattened`` but not to
    ``dont_be_a_tensor``.
    """
    dont_be_a_tensor = ['context_input', 'persona_list', 'target']
    to_be_flattened = list(dont_be_a_tensor)
    data = {}
    if not sample_list:
        return data
    first = sample_list[0]
    for key in to_be_flattened:
        # Covers both "key absent" and "value is None" in the first sample.
        if first.get(key) is None:
            continue
        flatten_samples = [sample[key] for sample in sample_list]
        if key in dont_be_a_tensor or isinstance(flatten_samples[-1], str):
            data[key] = flatten_samples
        else:
            data[key] = torch.tensor(flatten_samples)
    return data
|
|
|
|
|
|
def collate_fn_straight(sample_list):
    """Collate ``sample_list`` with ``collate_fn`` and hand the batch back unchanged."""
    return collate_fn(sample_list)
|
|
|
|
|
|
def collate_fn_straight_with_fn(fn):
    """Build a collate function that post-processes the default batch with ``fn``.

    The returned callable collates via ``collate_fn`` and calls ``fn`` on that
    batch. It then returns a new dict with ``fn``'s mapping merged over the batch;
    ``fn``'s keys win on clashes.
    """
    def build_collate_fn(sample_list):
        base_batch = collate_fn(sample_list)
        extra_fields = fn(base_batch)
        return {**base_batch, **extra_fields}

    return build_collate_fn
|
|
|
|
|
|
def get_dataloader(dataset, batch_size, shuffle=False, num_workers=None, collate_fn=None, sampler=None):
    """Create a ``DataLoader`` over ``dataset`` using this module's collation.

    Args:
        dataset: any dataset accepted by ``torch.utils.data.DataLoader``.
        batch_size: samples per batch.
        shuffle: forwarded to ``DataLoader``. ``DataLoader`` rejects
            ``shuffle=True`` together with a ``sampler``.
        num_workers: worker processes; ``None`` means ``batch_size // 4``.
        collate_fn: optional post-processor. When given, the default batch from
            ``collate_fn_straight`` is passed to it and its returned mapping is
            merged in (see ``collate_fn_straight_with_fn``).
        sampler: forwarded to ``DataLoader``.
    """
    if num_workers is None:
        # Heuristic default: one worker per four samples in a batch.
        num_workers = batch_size // 4

    if collate_fn is None:
        _collate_fn = collate_fn_straight
    else:
        _collate_fn = collate_fn_straight_with_fn(collate_fn)

    return DataLoader(dataset, batch_size=batch_size,
                      collate_fn=_collate_fn,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      sampler=sampler)
|
|
|
|
|
|
def get_lightening_dataloader(dataset, batch_size, shuffle=False, num_workers=None):
    """Wrap ``dataset`` in a ``LitDataModule`` with the given loader settings."""
    data_module = LitDataModule(batch_size=batch_size, dataset=dataset,
                                shuffle=shuffle, num_workers=num_workers)
    return data_module
|
|
|
|
|
|
class LitDataModule(LightningDataModule):
    """Lightning data module that serves one training ``DataLoader`` over ``dataset``."""

    def __init__(self, batch_size, dataset, shuffle, num_workers):
        """
        Args:
            batch_size: samples per batch.
            dataset: dataset to load; excluded from saved hyperparameters.
            shuffle: forwarded to the training ``DataLoader``.
            num_workers: worker processes; ``None`` means ``batch_size // 4``.
        """
        super().__init__()
        # Stores batch_size / shuffle / num_workers on self.hparams; the dataset
        # object itself is not recorded as a hyperparameter.
        self.save_hyperparameters(ignore=['dataset'])

        self.batch_size = batch_size
        self.dataset = dataset

    def train_dataloader(self):
        """Return the training ``DataLoader`` using ``collate_fn_straight``."""
        num_workers = self.hparams.num_workers
        if num_workers is None:
            # DataLoader expects an int; use the same default as get_dataloader.
            # get_lightening_dataloader passes None by default.
            num_workers = self.batch_size // 4
        return DataLoader(self.dataset, batch_size=self.batch_size,
                          collate_fn=collate_fn_straight,
                          shuffle=self.hparams.shuffle,
                          num_workers=num_workers)
|
|
|
|
def _build_instruction(context_input, persona_input):
    """Return the instruction prompt for one sample (dialog history + R's persona)."""
    # NOTE(review): "Please response" is ungrammatical ("respond"); kept verbatim
    # so regenerated files match previously exported data.
    return (
        "Given the dialog history between Q and R is:\n"
        f"{context_input}\n"
        "\n"
        "Given the personality of the R as:\n"
        f"{persona_input}\n"
        "\n"
        "Please response to Q according to both the dialog history and the R's personality.\n"
        "Now, the R would say:"
    )


def _to_instruction_record(sample):
    """Convert one ``PersonaChatDataset`` item into an instruction-tuning record."""
    return {
        "instruction": _build_instruction("\n".join(sample['context_input']),
                                          '\n'.join(sample['persona_list'])),
        "input": "",
        "output": sample['target'],
        "answer": "",
    }


def _convert_split(data_path):
    """Load the split at ``data_path`` and return its instruction records."""
    from tqdm import tqdm

    dataset = PersonaChatDataset(data_path=data_path)
    return [_to_instruction_record(sample) for sample in tqdm(dataset)]


def _dump_json(records, out_path):
    """Write ``records`` to ``out_path`` as a single JSON array."""
    import json

    with open(out_path, 'w', encoding='utf-8') as writer:
        json.dump(records, writer)


if __name__ == '__main__':
    _dump_json(_convert_split('data_file/ConvAI2/train_self_original_no_cands.txt'),
               'data_file/train.json')

    valid_records = _convert_split('data_file/ConvAI2/valid_self_original_no_cands.txt')
    # The same validation records are written to both files, presumably because
    # no separate test split is converted here.
    for out_path in ('data_file/valid.json', 'data_file/test.json'):
        _dump_json(valid_records, out_path)