from enum import Enum

import torch


class TASK_TYPE(Enum):
    GENERATE_RESPONSE = 'generate_response'
    GENERATE_PERSONA = 'generate_persona'
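

# Task identifiers, presumably consumed elsewhere to select one of the formatters below.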
def format_personachat_input(batch, left_tokenizer, right_tokenizer, config, for_test=False, find_batch=False):
    batch_size = len(batch['context_input'])
    pad_token_id = left_tokenizer.pad_token_id
    targets = [t.strip() for t in batch['target']]
    bos_token = left_tokenizer.bos_token
    eos_token = left_tokenizer.eos_token
    concat_context = [' '.join(context) for context in batch['context_input']]
    concat_persona = [' '.join(persona) for persona in batch['persona_list']]
    concat_input = [f'#persona#{persona}#context#{context}' for persona, context in
                    zip(concat_persona, concat_context)]
    inference_tokenized = None
    if for_test:
        # inference prompts end with BOS so generation starts right after the context;
        # the shorter max_length leaves 16 tokens of headroom, presumably for the response
        inference_input = [f'#persona#{persona}#context#{context}{bos_token}' for persona, context in
                           zip(concat_persona, concat_context)]
        inference_tokenized = left_tokenizer(inference_input, add_special_tokens=False, return_tensors='pt',
                                             padding='max_length', truncation=True,
                                             max_length=config.dataset.max_token_length - 16)
    # processing target: wrap each target as '<bos>target<eos>', then build the version
    # shifted left by one token (drop BOS, append one pad) to serve as labels
    _target_with_bos = [f'{bos_token}{target}{eos_token}' for target in targets]
    _target_with_bos_pt = right_tokenizer(_target_with_bos,
                                          add_special_tokens=False, return_tensors='pt',
                                          padding=True)
    _target_pt = _target_with_bos_pt.copy()
    _target_pt['input_ids'] = torch.cat((_target_pt['input_ids'][:, 1:],
                                         _target_pt['input_ids'].new_ones(batch_size, 1) * pad_token_id), dim=1)
    _target_pt['attention_mask'] = torch.cat((_target_pt['attention_mask'][:, 1:],
                                              _target_pt['attention_mask'].new_zeros(batch_size, 1)), dim=1)
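    # Illustration (token strings are schematic): if a tokenized target row is
    #   [<bos>, i, like, dogs, <eos>]
    # the shifted label row becomes
    #   [i, like, dogs, <eos>, <pad>]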
    # processing concat: tokenize persona+context (left_tokenizer presumably left-pads),
    # append the BOS-prefixed target, and keep the rightmost max_token_length positions
    # so the response always sits at the end of the window
    context_pt = left_tokenizer(concat_input, add_special_tokens=False, return_tensors='pt',
                                padding='max_length', truncation=True,
                                max_length=config.dataset.max_token_length)
    input_pt = torch.cat((context_pt['input_ids'], _target_with_bos_pt['input_ids']),
                         dim=1)[:, -config.dataset.max_token_length:]
    input_attn = torch.cat((context_pt['attention_mask'], _target_with_bos_pt['attention_mask']),
                           dim=1)[:, -config.dataset.max_token_length:]
    lm_input = {'input_ids': input_pt, 'attention_mask': input_attn}
    if find_batch:
        # keep the context ids in the labels
        lm_target = torch.cat((context_pt['input_ids'],
                               _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]
    else:
        # overwrite every context position with -1 so only response tokens count in the loss
        lm_target = torch.cat((context_pt['input_ids'] * 0 - 1,
                               _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]
    if for_test:
        return lm_input, lm_target, inference_tokenized
    return lm_input, lm_target
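

# Usage sketch (illustrative, not from the original file): assumes GPT-2-style
# tokenizers where the left/right variants differ only in padding side, a pad token
# set to EOS, and a `config` exposing config.dataset.max_token_length.
#
#   from transformers import AutoTokenizer
#   left_tokenizer = AutoTokenizer.from_pretrained('gpt2', padding_side='left')
#   right_tokenizer = AutoTokenizer.from_pretrained('gpt2', padding_side='right')
#   left_tokenizer.pad_token = left_tokenizer.eos_token
#   right_tokenizer.pad_token = right_tokenizer.eos_token
#   batch = {'context_input': [['hi !', 'hello , how are you ?']],
#            'persona_list': [['i like dogs .']],
#            'target': ['i am great , thanks !']}
#   lm_input, lm_target = format_personachat_input(batch, left_tokenizer, right_tokenizer, config)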


# Template types for format_causal_personachat_input:
#   0: plain 'R:' prompt
#   1: '[COMPLETE]'-style prompt
def format_causal_personachat_input(batch, left_tokenizer, right_tokenizer, config, for_test=False,
                                    find_batch=False, template_type=0):
    template_types = [
        '{cinput} R: {target}',
        '{cinput} R: [COMPLETE] the answer for [COMPLETE] is {target}'
    ]
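    # e.g. template 0 with persona 'i like dogs .', context 'hi !' and target 'hello !'
    # renders to: 'given persona: i like dogs .; context: hi ! R: hello !'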
    bos_token = left_tokenizer.bos_token
    eos_token = left_tokenizer.eos_token
    batch_size = len(batch['context_input'])
    pad_token_id = right_tokenizer.pad_token_id
    targets = [t.strip() for t in batch['target']]
    concat_context = [' '.join(context) for context in batch['context_input']]
    concat_persona = [' '.join(persona) for persona in batch['persona_list']]
    concat_input = [f'given persona: {persona}; context: {context}' for persona, context in
                    zip(concat_persona, concat_context)]
    concat_input_target = [template_types[template_type].format(cinput=cinput, target=target) for cinput, target in
                           zip(concat_input, targets)]
    bos_concat_input = [f'{bos_token}{cinput}{eos_token}' for cinput in concat_input_target]
    lm_input = right_tokenizer(bos_concat_input, add_special_tokens=False, return_tensors='pt',
                               padding='max_length', truncation=True,
                               max_length=config.dataset.max_token_length)
    # labels are the input ids shifted left by one position, with a pad token appended
    lm_target = torch.cat((lm_input['input_ids'][:, 1:],
                           lm_input['input_ids'].new_full((batch_size, 1), pad_token_id)), dim=1)
    # lm_target['attention_mask'] = torch.cat(
    #     (lm_target['attention_mask'][:, 1:], lm_target['attention_mask'].new_full(
    #         (batch_size, 1), 0)), dim=1)
    # freeze persona: overwrite every label before the 'context :' marker with the pad
    # id so the persona prefix does not contribute to the loss
    if isinstance(config.training.freeze_persona, bool) and config.training.freeze_persona:
        for _lm_target in lm_target:
            if 'given persona:' not in left_tokenizer.decode(_lm_target):
                continue
            _tokens = left_tokenizer.convert_ids_to_tokens(_lm_target)
            _token_idx = None
            for idx in range(0, len(_tokens) - 1):
                if _tokens[idx].endswith('context') and _tokens[idx + 1].endswith(':'):
                    _token_idx = idx
                    break
            if _token_idx is None:
                continue
            for idx in range(0, _token_idx):
                # in-place edit of this row of lm_target
                _lm_target[idx] = left_tokenizer.pad_token_id
    # freeze context: mask the span from the 'context :' marker up to and including
    # the 'R :' response marker, again with the pad id
    if isinstance(config.training.freeze_context, bool) and config.training.freeze_context:
        for _lm_target in lm_target:
            _tokens = left_tokenizer.convert_ids_to_tokens(_lm_target)
            _start_idx = None
            _end_idx = None
            for idx in range(0, len(_tokens) - 1):
                if _tokens[idx].endswith('context') and _tokens[idx + 1].endswith(':'):
                    _start_idx = idx
                if _tokens[idx].endswith('R') and _tokens[idx + 1].endswith(':'):
                    _end_idx = idx + 2
            if _start_idx is None or _end_idx is None:
                continue
            for idx in range(_start_idx, _end_idx):
                _lm_target[idx] = left_tokenizer.pad_token_id
    if for_test:
        # inference prompts keep the template prefix but drop the target and the EOS
        inference_input = [template_types[template_type].format(cinput=cinput, target='') for cinput in concat_input]
        bos_concat_input = [f'{bos_token}{cinput}' for cinput in inference_input]
        inference_tokenized = left_tokenizer(bos_concat_input, add_special_tokens=False,
                                             return_tensors='pt',
                                             padding=True, truncation=True,
                                             max_length=config.dataset.max_token_length)
        return lm_input, lm_target, inference_tokenized
    return lm_input, lm_target
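

# Usage sketch (illustrative; same assumed tokenizer/config setup as in the sketch
# above, and `model` is an assumed Hugging Face causal LM):
#
#   lm_input, lm_target, inference_tokenized = format_causal_personachat_input(
#       batch, left_tokenizer, right_tokenizer, config, for_test=True, template_type=0)
#   generated = model.generate(**inference_tokenized)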


def format_generate_persona_input(batch, left_tokenizer, right_tokenizer, config, for_test=False, find_batch=False):
    # same pipeline as format_personachat_input, but the persona strings become the
    # generation target and the prompt carries only the context
    batch_size = len(batch['context_input'])
    pad_token_id = left_tokenizer.pad_token_id
    targets = [' '.join(persona) for persona in batch['persona_list']]
    bos_token = left_tokenizer.bos_token
    eos_token = left_tokenizer.eos_token
    concat_context = [' '.join(context) for context in batch['context_input']]
    concat_input = [f'#context#{context}' for context in concat_context]
    inference_tokenized = None
    if for_test:
        inference_input = [f'#context#{context}{bos_token}' for context in concat_context]
        inference_tokenized = left_tokenizer(inference_input, add_special_tokens=False, return_tensors='pt',
                                             padding='max_length', truncation=True,
                                             max_length=config.dataset.max_token_length - 16)
    # processing target
    _target_with_bos = [f'{bos_token}{target}{eos_token}' for target in targets]
    _target_with_bos_pt = right_tokenizer(_target_with_bos,
                                          add_special_tokens=False, return_tensors='pt',
                                          padding=True)
    _target_pt = _target_with_bos_pt.copy()
    _target_pt['input_ids'] = torch.cat((_target_pt['input_ids'][:, 1:],
                                         _target_pt['input_ids'].new_ones(batch_size, 1) * pad_token_id), dim=1)
    _target_pt['attention_mask'] = torch.cat((_target_pt['attention_mask'][:, 1:],
                                              _target_pt['attention_mask'].new_zeros(batch_size, 1)), dim=1)
    # processing concat
    context_pt = left_tokenizer(concat_input, add_special_tokens=False, return_tensors='pt',
                                padding='max_length', truncation=True,
                                max_length=config.dataset.max_token_length)
    input_pt = torch.cat((context_pt['input_ids'], _target_with_bos_pt['input_ids']),
                         dim=1)[:, -config.dataset.max_token_length:]
    input_attn = torch.cat((context_pt['attention_mask'], _target_with_bos_pt['attention_mask']),
                           dim=1)[:, -config.dataset.max_token_length:]
    lm_input = {'input_ids': input_pt, 'attention_mask': input_attn}
    if find_batch:
        lm_target = torch.cat((context_pt['input_ids'],
                               _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]
    else:
        lm_target = torch.cat((context_pt['input_ids'] * 0 - 1,
                               _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]
    if for_test:
        return lm_input, lm_target, inference_tokenized
    return lm_input, lm_target
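

# Hypothetical dispatch helper (not in the original module) showing how TASK_TYPE
# presumably selects a formatter; a sketch, assuming the training loop passes the
# same arguments the formatters above expect:
#
#   def format_batch(task, batch, left_tokenizer, right_tokenizer, config, **kwargs):
#       if task is TASK_TYPE.GENERATE_PERSONA:
#           return format_generate_persona_input(batch, left_tokenizer, right_tokenizer, config, **kwargs)
#       return format_personachat_input(batch, left_tokenizer, right_tokenizer, config, **kwargs)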