# SPT/dataset/dataset_helper.py
import re
from tqdm import tqdm


def read_personachat_split(split_dir, only_longest=False):
    """Read one PersonaChat split file into a list of example dicts.

    Each example holds the persona facts seen so far, the dialogue context,
    the gold response, and the candidate responses. With only_longest=True,
    only the final (full-history) example of each dialogue is kept.
    """
    results = []
    try:
        with open(split_dir, 'r') as file:
            lines = file.readlines()
        persona = []
        context = []
        response = None
        candidates = []
        new_dialogue = False
        for line in tqdm(lines, desc='loading {}'.format(split_dir)):
            # A line numbered "1" carrying a persona fact starts a new dialogue.
            if line.startswith('1 your persona:'):
                new_dialogue = True
            if new_dialogue and only_longest:
                # Flush the previous dialogue's longest context first.
                if response is not None:
                    results.append({'persona': persona.copy(), 'context': context.copy(),
                                    'response': response, 'candidates': candidates.copy()})
                new_dialogue = False
                persona = []
                context = []
            if 'persona:' in line:
                # e.g. "2 your persona: i like to ski."
                persona.append(line.split(':', 1)[1].strip())
            else:
                # e.g. "5 <query>\t<response>\t\t<cand_1>|...|<cand_n>"
                context.append(re.sub(r"^\d+ ", "", line.split("\t")[0].strip()))
                response = line.split("\t")[1].strip()
                if len(line.split("\t\t")) == 1:
                    candidates = []
                else:
                    candidates = line.split("\t\t")[1].strip().split("|")
                if not only_longest:
                    results.append({'persona': persona.copy(), 'context': context.copy(),
                                    'response': response, 'candidates': candidates.copy()})
                context.append(response)
        # The final dialogue has no following "1 your persona:" line, so flush
        # it explicitly.
        if only_longest and response is not None:
            results.append({'persona': persona.copy(), 'context': context.copy(),
                            'response': response, 'candidates': candidates.copy()})
    except FileNotFoundError:
        print(f"Sorry! The file {split_dir} can't be found.")
    return results
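

# A hedged usage sketch for read_personachat_split. The file name below is an
# assumption (PersonaChat splits are commonly named like
# "train_self_original.txt"); point it at the actual split on disk.
def _demo_read_personachat_split(path='train_self_original.txt'):
    # Set only_longest=True to keep one full-history example per dialogue
    # instead of one example per turn.
    examples = read_personachat_split(path, only_longest=False)
    for ex in examples[:2]:
        print(ex['persona'])     # persona facts seen so far
        print(ex['context'])     # alternating history utterances
        print(ex['response'])    # gold response for this turn
        print(ex['candidates'])  # distractor responses, possibly empty
    return examples
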
def combine_persona_query_response(persona, query, response, candidates):
    """Group flat per-turn lists into one record per persona string."""
    assert len(persona) == len(query) == len(response) == len(candidates), \
        'persona, query, response and candidates must have the same length'
    data = {}
    for index, psn in enumerate(persona):
        # Persona facts arrive as one tab-joined string; keep both the split
        # list and a space-joined key.
        split_persona = psn.strip().split("\t")
        psn = psn.replace("\t", " ").strip()
        if psn not in data:
            data[psn] = {'persona': psn, 'query': [], 'response': [], 'dialog': [],
                         'response_turns': 0, 'persona_list': split_persona,
                         'candidates': []}
        data[psn]['query'].append(query[index])
        data[psn]['response'].append(response[index])
        # 'dialog' interleaves queries and responses in turn order.
        data[psn]['dialog'].append(query[index])
        data[psn]['dialog'].append(response[index])
        data[psn]['candidates'].append(candidates[index])
        data[psn]['response_turns'] += 1
    return data
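

# A minimal, self-contained sketch of combine_persona_query_response. The toy
# strings are invented for illustration; real inputs come from a PersonaChat
# split, with each persona entry a tab-joined string of facts.
def _demo_combine_persona_query_response():
    persona = ['i like tea.\ti have a dog.'] * 2
    query = ['hi !', 'what do you drink ?']
    response = ['hello there .', 'i like tea .']
    candidates = [['nope .'], ['coffee .', 'i like tea .']]
    grouped = combine_persona_query_response(persona, query, response, candidates)
    record = grouped['i like tea. i have a dog.']
    assert record['response_turns'] == 2
    assert record['dialog'][:2] == ['hi !', 'hello there .']
    return grouped
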
def preprocess_text(text):
    """Pad '.', ',' and '?' with spaces, then collapse repeated whitespace."""
    punctuations = '.,?'
    for punc in punctuations:
        text = text.replace(punc, ' {} '.format(punc))
    text = re.sub(' +', ' ', text).strip()
    return text

def preprocess_texts(text_array):
return [preprocess_text(t) for t in text_array]
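

# Small self-check of the punctuation padding (outputs verified by hand),
# e.g. 'hi, how are you?' -> 'hi , how are you ?'.
def _demo_preprocess():
    assert preprocess_text('hi, how are you?') == 'hi , how are you ?'
    assert preprocess_texts(['ok.', 'sure!']) == ['ok .', 'sure!']
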
# "turns" means we need at least how many turns
# "max_context_turns" means how many history turns should be kept
def get_chat_by_turns(combined_data, turns=1,
sep_token='[SEP]', add_role_indicator=True,
add_persona_indicator=True, max_context_turns=-1):
assert turns > 0, 'turns must be large than 0'
all_persona = list(combined_data.keys())
filtered_persona = list(filter(lambda p: combined_data[p]['response_turns'] >= turns, all_persona))
data = []
for single_persona in filtered_persona:
single_persona_data = combined_data[single_persona]
persona_list = single_persona_data['persona_list']
context = []
for index, (query, response) in enumerate(
zip(single_persona_data['query'], single_persona_data['response'])
):
            # Keep only the last max_context_turns turns before the target;
            # the example is built from the first `turns` turns (see the
            # break below), so the window is measured against `turns`.
            if max_context_turns != -1 and \
                    index + 1 < turns - max_context_turns:
                continue
            if add_role_indicator:
                query = "Q: {}".format(query)
                # The final response is the prediction target, so it gets no
                # role prefix.
                if index + 1 < turns:
                    response = "R: {}".format(response)
            context += [query, response]
            if index + 1 >= turns:
                break
        # The last utterance appended is the target; everything before it
        # forms the input context.
        response = context[-1]
        context = context[:-1]
input_x_str = " {} ".format(sep_token).join(context)
input_x_str = re.sub(" +", " ", input_x_str)
if add_persona_indicator:
single_persona = "P: {}".format(single_persona)
        data.append({'input': preprocess_texts(context),
                     'input_str': preprocess_text(input_x_str),
                     'target': preprocess_text(response),
                     'persona': preprocess_text(single_persona),
                     'persona_list': preprocess_texts(persona_list),
                     # Candidates belong to the target turn, i.e. turn `turns`.
                     'candidates': preprocess_texts(single_persona_data['candidates'][turns - 1])})
return data
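

# A hedged end-to-end sketch, gated so it only runs when this module is
# executed directly. The split path and the flattening of read results into
# per-turn lists are assumptions about the upstream caller, not part of this
# module's contract.
if __name__ == '__main__':
    split_path = 'train_self_original.txt'  # hypothetical location
    raw = read_personachat_split(split_path)
    if raw:
        # Flatten per-example dicts into the parallel lists expected by
        # combine_persona_query_response; the current query is the last
        # utterance of each example's context.
        personas = ['\t'.join(ex['persona']) for ex in raw]
        queries = [ex['context'][-1] for ex in raw]
        responses = [ex['response'] for ex in raw]
        cands = [ex['candidates'] for ex in raw]
        grouped = combine_persona_query_response(personas, queries, responses, cands)
        chats = get_chat_by_turns(grouped, turns=2, max_context_turns=2)
        print('{} examples'.format(len(chats)))
        if chats:
            print(chats[0]['input_str'])
            print(chats[0]['target'])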