|
|
|
import math

import torch
import torch.nn.functional as F
|
|
|
|
|
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a batch of logits with top-k and/or nucleus (top-p) filtering.

    Logits outside the top-k set, or outside the smallest set whose cumulative
    probability exceeds top_p, are set to filter_value so that they receive
    zero probability after a subsequent softmax. Filtering is done in place.
    """
    if top_k > 0:
        # Remove every logit smaller than the k-th largest logit of its row.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
|
    if top_p > 0.0:
        # Sort in descending order so cumulative probabilities can be computed.
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens whose cumulative probability exceeds the threshold ...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ... but shift the mask right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original vocabulary order, row by row.
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value
|
return logits |
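
# A minimal usage sketch (illustrative only; the vocabulary size and the
# sampling step below are assumptions, not part of this module):
#
#     logits = torch.randn(2, 50257)                    # [batch, vocab]
#     filtered = top_k_logits(logits.clone(), top_k=40, top_p=0.9)
#     probs = F.softmax(filtered, dim=-1)
#     next_tokens = torch.multinomial(probs, num_samples=1)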
|
|
|
|
|
def enforce_repetition_penalty(lprobs, prev_output_tokens, repetition_penalty=1.5):
    """Repetition penalty (from the CTRL paper https://arxiv.org/abs/1909.05858).

    Scores of previously generated tokens are penalised in place: negative
    scores (log-probabilities) are multiplied by the penalty, non-negative
    scores (plain probabilities) are divided by it.
    """
    # tolist() so that set() deduplicates by value rather than by tensor identity.
    for previous_token in set(prev_output_tokens.tolist()):
        if lprobs[previous_token] < 0:
            lprobs[previous_token] *= repetition_penalty
        else:
            lprobs[previous_token] /= repetition_penalty
|
|
|
|
|
def switch(next_value, init, is_update):
    """Element-wise select: return next_value where is_update is true, init elsewhere."""
    is_update = is_update.type_as(next_value)
    return (1 - is_update) * init + is_update * next_value
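
# For example, switch(prev, context_tokens_tensor[:, index], context_length_tensor <= index)
# in sample_sequence_batch below keeps the forced context token for samples whose
# context extends past `index`, and the freshly sampled token for the rest.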
|
|
|
|
|
def get_atten_mask(batch_size, seq_length, memory_length=0):
    """Build an attention mask of shape [batch_size, 1, seq_length, seq_length + memory_length].

    With memory_length == 0 this is a plain lower-triangular (causal) mask; with
    memory, each query position attends to a window of seq_length keys ending at
    its own position.
    """
    memory_attention_mask = torch.ones(
        (batch_size, 1, seq_length, seq_length + memory_length), dtype=torch.int16)
    memory_attention_mask = torch.tril(
        torch.triu(memory_attention_mask, 1 - seq_length + memory_length), memory_length)

    return memory_attention_mask
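
# Quick shape check (illustrative values, chosen here only as an example):
#
#     mask = get_atten_mask(batch_size=2, seq_length=4, memory_length=3)
#     assert mask.shape == (2, 1, 4, 7)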
|
|
|
|
|
def get_masks_and_position_ids(data, mem_length=None):
    """Build the attention mask and position ids for a batch of token ids `data`."""
    batch_size, seq_length = data.size()

    if mem_length is None:
        mem_length = 0
    attention_mask = torch.ones((1, seq_length, seq_length + mem_length), device=data.device)
    attention_mask = torch.tril(torch.triu(attention_mask, 1 - seq_length + mem_length), mem_length)
    attention_mask = attention_mask.unsqueeze(1)
|
|
|
position_ids = torch.arange(seq_length, dtype=torch.long, |
|
device=data.device) |
|
position_ids = position_ids.unsqueeze(0).expand_as(data) |
|
return attention_mask, position_ids |
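
# Usage sketch (the `input_ids` tensor below is an assumed example):
#
#     input_ids = torch.zeros(2, 8, dtype=torch.long)
#     mask, pos = get_masks_and_position_ids(input_ids, mem_length=0)
#     assert mask.shape == (1, 1, 8, 8) and pos.shape == (2, 8)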
|
|
|
|
|
def sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, max_out_seq=None, mems=None, |
|
end_token_id=None, repetition_penalty=1.0, temperature=1.0, top_k=0, top_p=0.0): |
|
"""_summary_ |
|
|
|
Args: |
|
model (_type_): _description_ |
|
context_tokens_tensor (Tensor): [bs, seq_len] |
|
context_length_tensor (Tensor): [bs, ] |
|
max_out_seq (_type_, optional): _description_. Defaults to None. |
|
mems (_type_, optional): _description_. Defaults to None. |
|
end_token_id (_type_, optional): _description_. Defaults to None. |
|
repetition_penalty (float, optional): _description_. Defaults to 1.0. |
|
temperature (float, optional): _description_. Defaults to 1.0. |
|
top_k (int, optional): _description_. Defaults to 0. |
|
top_p (float, optional): _description_. Defaults to 0.0. |
|
|
|
Returns: |
|
_type_: _description_ |
|
""" |
|
|
|
model_dtype = next(model.parameters()).dtype |
|
org_context_length = torch.min(context_length_tensor).item() |
|
batch_size = context_tokens_tensor.shape[0] |
|
tokens = context_tokens_tensor[:, :org_context_length] |
|
attention_mask = get_atten_mask(batch_size, org_context_length).cuda(context_tokens_tensor.device).to(model_dtype) |
|
position_ids = torch.arange(org_context_length, dtype=torch.long, |
|
device=tokens.device) |
|
position_ids = position_ids.unsqueeze(0).expand_as(tokens) |
|
|
|
counter, mem_length = 0, 0 |
|
if mems is None: |
|
mems = [] |
|
if end_token_id is None: |
|
end_token_id = 50000 |
|
if max_out_seq is None: |
|
max_out_seq = 512 |
|
|
|
output_tokens_lists = [] |
|
|
|
|
|
origin_order = torch.tensor(range(batch_size), device=tokens.device) |
|
output_order = [] |
|
|
|
|
|
log_probs_tensor = torch.tensor([0.0] * batch_size, device=tokens.device) |
|
log_probs_list = [] |
|
|
|
with torch.no_grad(): |
|
|
|
while counter < max_out_seq: |
|
index = org_context_length + counter |
|
if counter == 0: |
|
output = model.forward(input_ids=tokens, position_ids=position_ids, |
|
attention_mask=attention_mask, hidden_states=mems) |
|
logits, mems = output.logits, output.hidden_states |
|
else: |
|
output = model.forward(input_ids=tokens[:, index - 1: index], position_ids=tokens.new_ones((1, 1)) * (index - 1), |
|
attention_mask=tokens.new_ones(batch_size, 1, 1, mem_length + 1).to(model_dtype), hidden_states=mems) |
|
logits, mems = output.logits, output.hidden_states |
|
logits = logits[:, -1] |
|
logits /= temperature |
|
logits = top_k_logits(logits, top_k=top_k, top_p=top_p) |
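
            # Note: despite its name, `log_probs` below holds plain probabilities
            # (softmax output); torch.multinomial samples from it directly, and a
            # true log is only taken later when accumulating log_probs_tensor.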
|
|
|
|
|
|
|
log_probs = F.softmax(logits, dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
prev = torch.multinomial(log_probs, num_samples=1).view(-1) |
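
            # Samples whose context extends beyond `index` keep their real context
            # token; the rest keep the freshly sampled one.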
|
|
|
if index < torch.max(context_length_tensor).item(): |
|
prev = switch( |
|
prev, context_tokens_tensor[:, index], context_length_tensor <= index) |
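
            # Accumulate each still-active sample's running log-probability; once its
            # end token appears, the total is divided by (context_length - index).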
|
|
|
for i in range(batch_size): |
|
if index > context_length_tensor[i] and prev[i] != end_token_id: |
|
log_probs_tensor[i] += math.log(log_probs[i][prev[i]] + 1e-6) |
|
if prev[i] == end_token_id: |
|
log_probs_tensor[i] /= (context_length_tensor[i].cpu() - index) |
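
            # Move samples that just emitted the end token into the output lists,
            # then shrink the batch (tokens, contexts, mems, ...) to those that continue.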
|
|
|
|
|
stop_idx = prev == end_token_id |
|
if torch.all(stop_idx).item(): |
|
output_order.extend(origin_order[stop_idx].tolist()) |
|
break |
|
|
|
finished = tokens[stop_idx] |
|
output_tokens_lists.extend(finished.detach().cpu().tolist()) |
|
log_probs_list.extend(log_probs_tensor[stop_idx].tolist()) |
|
output_order.extend(origin_order[stop_idx].tolist()) |
|
|
|
|
|
conti_idx = (prev != end_token_id) |
|
origin_order = origin_order[conti_idx] |
|
tokens, prev = tokens[conti_idx], prev[conti_idx] |
|
context_tokens_tensor = context_tokens_tensor[conti_idx] |
|
context_length_tensor = context_length_tensor[conti_idx] |
|
log_probs_tensor = log_probs_tensor[conti_idx] |
|
batch_size = tokens.shape[0] |
|
for im in range(len(mems)): |
|
mems[im] = mems[im][conti_idx, :, :] |
|
|
|
tokens = torch.cat((tokens, prev.view(batch_size, 1)), dim=-1) |
|
|
|
counter += 1 |
|
|
|
output_tokens_lists.extend(tokens.detach().cpu().tolist()) |
|
log_probs_list.extend(log_probs_tensor.tolist()) |
|
output_order.extend(origin_order.tolist()) |
|
    output_tokens_lists = [toks[:toks.index(end_token_id)] if end_token_id in toks else toks
                           for toks in output_tokens_lists]
|
|
|
    output_tokens_lists = [tokens for _, tokens in sorted(zip(output_order, output_tokens_lists))]
    output_log_probs = [prob for _, prob in sorted(zip(output_order, log_probs_list))]

    return output_tokens_lists, output_log_probs
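
# A call sketch for sample_sequence_batch (everything here is an assumption about
# the caller's setup, not part of this module; the model must return an object
# exposing `.logits` and `.hidden_states`, as used above):
#
#     context = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).cuda()   # [1, seq_len]
#     lengths = torch.tensor([context.size(1)]).cuda()
#     out_tokens, out_log_probs = sample_sequence_batch(
#         model, context, lengths,
#         max_out_seq=128, end_token_id=50000,
#         temperature=0.9, top_k=40, top_p=0.9)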
|
|
|
|
|
def sample_sequence(model, tokens, attention_mask, do_sampling=True, |
|
repetition_penalty=1.0, max_out_seq=None, mems=None, end_token_id=None, |
|
mem_length=0, temperature=1.0, top_k=0, top_p=0.0): |
|
"""_summary_ |
|
|
|
Args: |
|
model (_type_): _description_ |
|
tokens (Tensor): [1, seq_len] |
|
attention_mask (Tensor): [1, 1, seq_len, seq_len] |
|
do_sampling (bool, optional): _description_. Defaults to True. |
|
repetition_penalty (float, optional): _description_. Defaults to 1.0. |
|
max_out_seq (_type_, optional): _description_. Defaults to None. |
|
mems (_type_, optional): _description_. Defaults to None. |
|
end_token (_type_, optional): _description_. Defaults to None. |
|
mem_length (int, optional): _description_. Defaults to 0. |
|
temperature (float, optional): _description_. Defaults to 1.0. |
|
top_k (int, optional): _description_. Defaults to 0. |
|
top_p (float, optional): _description_. Defaults to 0.0. |
|
|
|
Returns: |
|
_type_: _description_ |
|
""" |
|
counter = 0 |
|
if mems is None: |
|
mems = [] |
|
if end_token_id is None: |
|
end_token_id = 50000 |
|
if max_out_seq is None: |
|
max_out_seq = 512 |
|
org_context_length = tokens.size(1) |
|
with torch.no_grad(): |
|
|
|
while counter < max_out_seq: |
|
if counter == 0: |
|
logits, *mems = model(input_ids=tokens, position_ids=None, |
|
attention_mask=attention_mask, mems=mems) |
|
else: |
|
index = org_context_length + counter |
|
logits, *mems = model(input_ids=tokens[:, index - 1: index], position_ids=None, |
|
attention_mask=tokens.new_ones(1, 1, 1, mem_length + 1), mems=mems) |
|
logits = logits[:, -1] |
|
logits /= temperature |
|
if do_sampling: |
|
logits = top_k_logits(logits, top_k=top_k, top_p=top_p) |
|
log_probs = F.softmax(logits, dim=-1) |
|
|
|
if repetition_penalty != 1.0: |
|
enforce_repetition_penalty( |
|
log_probs[0, :], tokens[0, :], repetition_penalty) |
|
prev = torch.multinomial(log_probs, num_samples=1)[0] |
|
is_end = (prev == end_token_id) |
|
if is_end: |
|
break |
|
tokens = torch.cat((tokens, prev.view(1, 1)), dim=1) |
|
counter += 1 |
|
|
|
    output_tokens_list = tokens.detach().cpu().tolist()[0]
    if end_token_id in output_tokens_list:
        output_tokens_list = output_tokens_list[:output_tokens_list.index(end_token_id)]

    return output_tokens_list, mems
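
# A call sketch for the single-sample path (names are illustrative assumptions;
# note that this function expects the `logits, *mems = model(...)` convention,
# which differs from the `.logits` / `.hidden_states` API used by
# sample_sequence_batch above):
#
#     tokens = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).cuda()   # [1, seq_len]
#     mask, _ = get_masks_and_position_ids(tokens, mem_length=0)
#     output_ids, mems = sample_sequence(model, tokens, mask,
#                                        temperature=0.9, top_k=40, top_p=0.9)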
|
|