from enum import Enum

import torch


class TASK_TYPE(Enum):
    GENERATE_RESPONSE = 'generate_response'
    GENERATE_PERSONA = 'generate_persona'


def format_personachat_input(batch, left_tokenizer, right_tokenizer, config, for_test=False, find_batch=False):
    """Format a PersonaChat batch for response generation.

    The prompt is '#persona#<persona>#context#<context>' and the target is the
    gold response wrapped in BOS/EOS. Returns (lm_input, lm_target); when
    `for_test` is True a tokenized generation prompt is returned as well. With
    `find_batch=True` the context tokens are kept in `lm_target` instead of
    being masked with -1.
    """
    batch_size = len(batch['context_input'])
    pad_token_id = left_tokenizer.pad_token_id
    bos_token = left_tokenizer.bos_token
    eos_token = left_tokenizer.eos_token

    targets = [t.strip() for t in batch['target']]
    concat_context = [' '.join(context) for context in batch['context_input']]
    concat_persona = [' '.join(persona) for persona in batch['persona_list']]
    concat_input = [f'#persona#{persona}#context#{context}' for persona, context in
                    zip(concat_persona, concat_context)]

    inference_tokenized = None
    if for_test:
        # The inference prompt ends with BOS so generation starts right after it;
        # max_length is reduced by 16, presumably to leave headroom for generation.
        inference_input = [f'#persona#{persona}#context#{context}{bos_token}' for persona, context in
                           zip(concat_persona, concat_context)]
        inference_tokenized = left_tokenizer(inference_input, add_special_tokens=False, return_tensors='pt',
                                             padding='max_length', truncation=True,
                                             max_length=config.dataset.max_token_length - 16)

    # Tokenize the BOS/EOS-wrapped targets; the shifted copy (_target_pt) drops
    # the BOS and appends one pad token so it lines up for next-token prediction.
    _target_with_bos = [f'{bos_token}{target}{eos_token}' for target in targets]
    _target_with_bos_pt = right_tokenizer(_target_with_bos, add_special_tokens=False,
                                          return_tensors='pt', padding=True)
    _target_pt = _target_with_bos_pt.copy()
    _target_pt['input_ids'] = torch.cat((_target_pt['input_ids'][:, 1:],
                                         _target_pt['input_ids'].new_ones(batch_size, 1) * pad_token_id), dim=1)
    _target_pt['attention_mask'] = torch.cat((_target_pt['attention_mask'][:, 1:],
                                              _target_pt['attention_mask'].new_zeros(batch_size, 1)), dim=1)

    # Concatenate the tokenized prompt with the BOS-wrapped target and keep only
    # the last max_token_length positions, so the target end survives truncation.
    context_pt = left_tokenizer(concat_input, add_special_tokens=False, return_tensors='pt',
                                padding='max_length', truncation=True,
                                max_length=config.dataset.max_token_length)
    input_pt = torch.cat((context_pt['input_ids'], _target_with_bos_pt['input_ids']),
                         dim=1)[:, -config.dataset.max_token_length:]
    input_attn = torch.cat((context_pt['attention_mask'], _target_with_bos_pt['attention_mask']),
                           dim=1)[:, -config.dataset.max_token_length:]
    lm_input = {'input_ids': input_pt, 'attention_mask': input_attn}

    if find_batch:
        lm_target = torch.cat((context_pt['input_ids'],
                               _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]
    else:
        # Context positions are set to -1 so they can be excluded from the loss.
        lm_target = torch.cat((context_pt['input_ids'] * 0 - 1,
                               _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]

    if for_test:
        return lm_input, lm_target, inference_tokenized
    return lm_input, lm_target
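# Usage sketch (illustrative, not part of the original pipeline): `left_tokenizer`
# and `right_tokenizer` are presumably left- and right-padding tokenizers that
# expose bos/eos/pad tokens, and `config.dataset.max_token_length` caps the
# sequence length. A batch looks like:
#
#   batch = {'context_input': [['hi , how are you ?']],
#            'persona_list': [['i like to run .']],
#            'target': ['i am great , thanks .']}
#   lm_input, lm_target = format_personachat_input(batch, left_tokenizer, right_tokenizer, config)
#   # At test time a generation prompt is returned as well:
#   lm_input, lm_target, prompt = format_personachat_input(
#       batch, left_tokenizer, right_tokenizer, config, for_test=True)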


def format_causal_personachat_input(batch, left_tokenizer, right_tokenizer, config, for_test=False,
                                    find_batch=False, template_type=0):
    """Format a PersonaChat batch for a causal (decoder-only) language model.

    Persona, context and response are joined into a single templated sequence
    and the target is that sequence shifted left by one token. The optional
    `config.training.freeze_persona` / `config.training.freeze_context` flags
    overwrite the persona or context span of the target with the pad token so
    that those positions can be ignored by the loss.
    """
    template_types = [
        '{cinput} R: {target}',
        '{cinput} R: [COMPLETE] the answer for [COMPLETE] is {target}',
    ]
    bos_token = left_tokenizer.bos_token
    eos_token = left_tokenizer.eos_token
    batch_size = len(batch['context_input'])
    pad_token_id = right_tokenizer.pad_token_id

    targets = [t.strip() for t in batch['target']]
    concat_context = [' '.join(context) for context in batch['context_input']]
    concat_persona = [' '.join(persona) for persona in batch['persona_list']]
    concat_input = [f'given persona: {persona}; context: {context}' for persona, context in
                    zip(concat_persona, concat_context)]
    concat_input_target = [template_types[template_type].format(cinput=cinput, target=target)
                           for cinput, target in zip(concat_input, targets)]
    bos_concat_input = [f'{bos_token}{cinput}{eos_token}' for cinput in concat_input_target]

    lm_input = right_tokenizer(bos_concat_input, add_special_tokens=False, return_tensors='pt',
                               padding='max_length', truncation=True,
                               max_length=config.dataset.max_token_length)
    # Next-token targets: shift the input left by one and pad the final position.
    lm_target = lm_input.copy()
    lm_target = torch.cat((lm_target['input_ids'][:, 1:],
                           lm_target['input_ids'].new_full((batch_size, 1), pad_token_id)), dim=1)

    if isinstance(config.training.freeze_persona, bool) and config.training.freeze_persona:
        # Mask the persona span: every target token before the 'context' / ':' token pair.
        for _lm_target in lm_target:
            if 'given persona:' not in left_tokenizer.decode(_lm_target):
                continue
            _tokens = left_tokenizer.convert_ids_to_tokens(_lm_target)
            _token_ids = _lm_target
            for idx in range(0, len(_tokens) - 1):
                if _tokens[idx].endswith('context') and _tokens[idx + 1].endswith(':'):
                    break
                _token_ids[idx] = left_tokenizer.pad_token_id

    if isinstance(config.training.freeze_context, bool) and config.training.freeze_context:
        # Mask the context span: from the 'context' / ':' token pair up to and
        # including the 'R' / ':' response marker.
        for _lm_target in lm_target:
            _tokens = left_tokenizer.convert_ids_to_tokens(_lm_target)
            _token_ids = _lm_target
            _start_idx = None
            _end_idx = None
            for idx in range(0, len(_tokens) - 1):
                if _tokens[idx].endswith('context') and _tokens[idx + 1].endswith(':'):
                    _start_idx = idx
                if _tokens[idx].endswith('R') and _tokens[idx + 1].endswith(':'):
                    _end_idx = idx + 2
            if _start_idx is None or _end_idx is None:
                continue
            for idx in range(_start_idx, _end_idx):
                _token_ids[idx] = left_tokenizer.pad_token_id

    if for_test:
        # Inference prompt: the same template with an empty target and no EOS appended.
        inference_input = [template_types[template_type].format(cinput=cinput, target='') for cinput in concat_input]
        bos_concat_input = [f'{bos_token}{cinput}' for cinput in inference_input]
        inference_tokenized = left_tokenizer(bos_concat_input, add_special_tokens=False, return_tensors='pt',
                                             padding=True, truncation=True,
                                             max_length=config.dataset.max_token_length)
        return lm_input, lm_target, inference_tokenized
    return lm_input, lm_target
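# Usage sketch (illustrative): besides `config.dataset.max_token_length`, this
# variant reads `config.training.freeze_persona` and `config.training.freeze_context`
# (expected to be plain bools). `template_type=1` selects the '[COMPLETE]' template:
#
#   lm_input, lm_target = format_causal_personachat_input(
#       batch, left_tokenizer, right_tokenizer, config, template_type=1)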


def format_generate_persona_input(batch, left_tokenizer, right_tokenizer, config, for_test=False, find_batch=False):
    """Format a PersonaChat batch for persona generation.

    Same layout as `format_personachat_input`, except the prompt is only
    '#context#<context>' and the target is the concatenation of the persona
    sentences rather than the gold response.
    """
    batch_size = len(batch['context_input'])
    pad_token_id = left_tokenizer.pad_token_id
    bos_token = left_tokenizer.bos_token
    eos_token = left_tokenizer.eos_token

    targets = [' '.join(persona) for persona in batch['persona_list']]
    concat_context = [' '.join(context) for context in batch['context_input']]
    concat_input = [f'#context#{context}' for context in concat_context]

    inference_tokenized = None
    if for_test:
        # Inference prompt ends with BOS; max_length is reduced by 16,
        # presumably to leave headroom for generation.
        inference_input = [f'#context#{context}{bos_token}' for context in concat_context]
        inference_tokenized = left_tokenizer(inference_input, add_special_tokens=False, return_tensors='pt',
                                             padding='max_length', truncation=True,
                                             max_length=config.dataset.max_token_length - 16)

    # Tokenize the BOS/EOS-wrapped persona targets; the shifted copy (_target_pt)
    # drops the BOS and appends one pad token for next-token prediction.
    _target_with_bos = [f'{bos_token}{target}{eos_token}' for target in targets]
    _target_with_bos_pt = right_tokenizer(_target_with_bos, add_special_tokens=False,
                                          return_tensors='pt', padding=True)
    _target_pt = _target_with_bos_pt.copy()
    _target_pt['input_ids'] = torch.cat((_target_pt['input_ids'][:, 1:],
                                         _target_pt['input_ids'].new_ones(batch_size, 1) * pad_token_id), dim=1)
    _target_pt['attention_mask'] = torch.cat((_target_pt['attention_mask'][:, 1:],
                                              _target_pt['attention_mask'].new_zeros(batch_size, 1)), dim=1)

    # Concatenate the tokenized context with the BOS-wrapped target and keep only
    # the last max_token_length positions, so the target end survives truncation.
    context_pt = left_tokenizer(concat_input, add_special_tokens=False, return_tensors='pt',
                                padding='max_length', truncation=True,
                                max_length=config.dataset.max_token_length)
    input_pt = torch.cat((context_pt['input_ids'], _target_with_bos_pt['input_ids']),
                         dim=1)[:, -config.dataset.max_token_length:]
    input_attn = torch.cat((context_pt['attention_mask'], _target_with_bos_pt['attention_mask']),
                           dim=1)[:, -config.dataset.max_token_length:]
    lm_input = {'input_ids': input_pt, 'attention_mask': input_attn}

    if find_batch:
        lm_target = torch.cat((context_pt['input_ids'],
                               _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]
    else:
        # Context positions are set to -1 so they can be excluded from the loss.
        lm_target = torch.cat((context_pt['input_ids'] * 0 - 1,
                               _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]

    if for_test:
        return lm_input, lm_target, inference_tokenized
    return lm_input, lm_target
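

if __name__ == '__main__':
    # Minimal smoke test -- an illustrative sketch, not part of the training
    # pipeline. It assumes GPT-2 tokenizers from `transformers` and uses
    # SimpleNamespace objects in place of the project's real config.
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    left_tokenizer = AutoTokenizer.from_pretrained('gpt2')
    right_tokenizer = AutoTokenizer.from_pretrained('gpt2')
    left_tokenizer.padding_side = 'left'
    right_tokenizer.padding_side = 'right'
    for tok in (left_tokenizer, right_tokenizer):
        tok.pad_token = tok.eos_token  # GPT-2 defines no pad token by default

    config = SimpleNamespace(
        dataset=SimpleNamespace(max_token_length=64),
        training=SimpleNamespace(freeze_persona=False, freeze_context=False),
    )
    batch = {
        'context_input': [['hi , how are you ?', 'i am great , just got back from a run .']],
        'persona_list': [['i like to run .', 'i have two dogs .']],
        'target': ['that sounds fun , i usually run with my dogs .'],
    }

    lm_input, lm_target = format_personachat_input(batch, left_tokenizer, right_tokenizer, config)
    print('response task:', lm_input['input_ids'].shape, lm_target.shape)

    lm_input, lm_target = format_causal_personachat_input(batch, left_tokenizer, right_tokenizer, config)
    print('causal task:', lm_input['input_ids'].shape, lm_target.shape)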