import torch, math |
import torch.nn.functional as F |
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask out logits that fall outside the top-k and/or nucleus (top-p) set.

    ``logits`` is modified in place and the same tensor is returned.

    Args:
        logits (Tensor): scores with the vocabulary on the last dimension;
            any number of leading dimensions is accepted.
        top_k (int, optional): keep only the ``top_k`` highest scores along
            the last dimension. ``0`` disables the filter; values larger than
            the vocabulary size are clamped to it. Defaults to 0.
        top_p (float, optional): keep the highest-scoring tokens up to and
            including the first one whose cumulative softmax probability
            exceeds ``top_p``. ``0.0`` disables the filter. Defaults to 0.0.
        filter_value (float, optional): value written into removed
            positions. Defaults to ``-inf``.

    Returns:
        Tensor: ``logits`` with the removed positions set to ``filter_value``.
    """
    if top_k > 0:
        # torch.topk raises when k exceeds the size of the dimension.
        top_k = min(top_k, logits.size(-1))
        # Smallest score that is still inside the top-k set, kept broadcastable.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the first token crossing the threshold is kept,
        # and never remove the best token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order in one step
        # instead of looping over rows in Python.
        indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter_(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
def enforce_repetition_penalty(lprobs, prev_output_tokens, repetition_penalty=1.5):
    """Apply the CTRL repetition penalty (https://arxiv.org/abs/1909.05858) in place.

    Every distinct token id in ``prev_output_tokens`` has its entry in
    ``lprobs`` scaled exactly once: negative scores are multiplied by
    ``repetition_penalty`` and non-negative scores are divided by it.

    Args:
        lprobs (Tensor): 1-D scores indexed by token id; modified in place.
        prev_output_tokens (Tensor or iterable of int): token ids seen so far.
        repetition_penalty (float, optional): scaling factor. Defaults to 1.5.

    Returns:
        None
    """
    # Convert to plain ints before de-duplicating: the elements obtained by
    # iterating a tensor are 0-d tensors that hash by identity, so a set built
    # from them keeps duplicates and the penalty would be applied repeatedly.
    if torch.is_tensor(prev_output_tokens):
        token_ids = prev_output_tokens.reshape(-1).tolist()
    else:
        token_ids = [int(token) for token in prev_output_tokens]
    for previous_token in set(token_ids):
        # Multiplying a negative score and dividing a non-negative one both
        # move the score downwards for penalties greater than 1.
        if lprobs[previous_token] < 0:
            lprobs[previous_token] *= repetition_penalty
        else:
            lprobs[previous_token] /= repetition_penalty
def switch(next_value, init, is_update):
    """Select between two tensors element-wise using an update mask.

    Args:
        next_value (Tensor): values taken where the mask is 1.
        init (Tensor): values taken where the mask is 0.
        is_update (Tensor): mask; it is cast to the dtype of ``next_value``
            before being used as a multiplicative weight.

    Returns:
        Tensor: ``(1 - mask) * init + mask * next_value``.
    """
    mask = is_update.type_as(next_value)
    keep = 1 - mask
    return keep * init + mask * next_value
def get_atten_mask(batch_size, seq_length, memory_length=0):
    """Build a banded int16 attention mask.

    The result has shape ``(batch_size, 1, seq_length, seq_length +
    memory_length)``. Entry ``[..., i, j]`` is 1 when
    ``1 - seq_length + memory_length <= j - i <= memory_length`` and 0
    otherwise; with ``memory_length=0`` that is a lower-triangular mask.

    Args:
        batch_size (int): size of the leading dimension.
        seq_length (int): number of query positions (rows).
        memory_length (int, optional): extra key columns. Defaults to 0.

    Returns:
        Tensor: the int16 mask, created on the default device.
    """
    key_length = seq_length + memory_length
    lowest_diagonal = 1 - seq_length + memory_length
    ones = torch.ones((batch_size, 1, seq_length, key_length), dtype=torch.int16)
    # Drop everything below the lowest allowed diagonal, then everything above
    # the highest allowed one.
    banded = torch.triu(ones, lowest_diagonal)
    return torch.tril(banded, memory_length)
def get_masks_and_position_ids(data, mem_length=None):
    """Build the attention mask and position ids for a batch of token ids.

    Args:
        data (Tensor): 2-D tensor ``[batch_size, seq_length]``; only its shape
            and device are used.
        mem_length (int, optional): number of extra key columns in the mask.
            ``None`` (the default) is treated as 0.

    Returns:
        tuple: ``(attention_mask, position_ids)`` where ``attention_mask`` has
        shape ``[1, 1, seq_length, seq_length + mem_length]`` in the default
        float dtype, and ``position_ids`` holds ``0..seq_length-1`` expanded to
        the shape of ``data``. Both live on ``data.device``.
    """
    # The default of None used to reach the arithmetic below and raise a
    # TypeError; treat it as "no memory".
    if mem_length is None:
        mem_length = 0
    _, seq_length = data.size()
    attention_mask = torch.ones((1, seq_length, seq_length + mem_length), device=data.device)
    # Keep the band 1 - seq_length + mem_length <= j - i <= mem_length.
    attention_mask = torch.tril(torch.triu(attention_mask, 1 - seq_length + mem_length), mem_length)
    # Insert a singleton dimension at position 1.
    attention_mask = attention_mask.unsqueeze(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    return attention_mask, position_ids
def sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, max_out_seq=None, mems=None,
                          end_token_id=None, repetition_penalty=1.0, temperature=1.0, top_k=0, top_p=0.0):
    """Sample continuations for a batch of prompts of different lengths.

    Decoding starts at the shortest prompt length. While a row's prompt still
    has a token at the current position, that prompt token replaces the sampled
    one. Rows that produce ``end_token_id`` are removed from the running batch
    and their tokens and score are stored. At the end all results are put back
    into the original batch order.

    Args:
        model: object providing ``parameters()`` and
            ``forward(input_ids, position_ids, attention_mask, hidden_states)``
            whose result has ``logits`` and ``hidden_states`` attributes.
        context_tokens_tensor (Tensor): [bs, seq_len] prompt token ids.
        context_length_tensor (Tensor): [bs, ] prompt lengths.
        max_out_seq (int, optional): maximum number of decoding steps.
            Defaults to 512.
        mems (list, optional): hidden states passed to the first forward
            call. Defaults to an empty list.
        end_token_id (int, optional): token id that terminates a row.
            Defaults to 50000.
        repetition_penalty (float, optional): accepted for interface
            compatibility; this function does not use it. Defaults to 1.0.
        temperature (float, optional): divisor applied to the logits.
            Defaults to 1.0.
        top_k (int, optional): forwarded to ``top_k_logits``. Defaults to 0.
        top_p (float, optional): forwarded to ``top_k_logits``. Defaults to 0.0.

    Returns:
        tuple: ``(output_tokens_lists, output_log_probs)``, both ordered like
        the input batch. Each token list is cut before the first occurrence of
        ``end_token_id``. Each score is a sum of ``log(p + 1e-6)`` over sampled
        tokens at steps where ``index > prompt_length``; when the row emits the
        end token the sum is divided by ``prompt_length - index``.
    """
    # NOTE(review): rows that finish get a score divided by a negative length
    # delta, rows that hit max_out_seq keep the raw sum — confirm this
    # asymmetry is intended.
    model_dtype = next(model.parameters()).dtype
    org_context_length = torch.min(context_length_tensor).item()
    batch_size = context_tokens_tensor.shape[0]
    tokens = context_tokens_tensor[:, :org_context_length]
    # Tensor.to handles CPU and CUDA targets alike, whereas .cuda(device)
    # rejects a CPU device.
    attention_mask = get_atten_mask(batch_size, org_context_length).to(
        device=context_tokens_tensor.device, dtype=model_dtype)
    position_ids = torch.arange(org_context_length, dtype=torch.long,
                                device=tokens.device)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)

    # mem_length is never changed below, so the per-step mask has one column.
    counter, mem_length = 0, 0
    if mems is None:
        mems = []
    if end_token_id is None:
        end_token_id = 50000
    if max_out_seq is None:
        max_out_seq = 512

    output_tokens_lists = []
    # Original batch position of every row still being decoded.
    origin_order = torch.arange(batch_size, device=tokens.device)
    output_order = []
    # Running score per active row.
    log_probs_tensor = torch.zeros(batch_size, device=tokens.device)
    log_probs_list = []

    with torch.no_grad():
        while counter < max_out_seq:
            index = org_context_length + counter
            if counter == 0:
                # First step: feed the whole shared prefix.
                output = model.forward(input_ids=tokens, position_ids=position_ids,
                                       attention_mask=attention_mask, hidden_states=mems)
            else:
                # Later steps: feed only the most recent token.
                output = model.forward(
                    input_ids=tokens[:, index - 1: index],
                    position_ids=tokens.new_ones((1, 1)) * (index - 1),
                    attention_mask=tokens.new_ones(batch_size, 1, 1, mem_length + 1).to(model_dtype),
                    hidden_states=mems)
            logits, mems = output.logits, output.hidden_states

            logits = logits[:, -1]
            logits /= temperature
            logits = top_k_logits(logits, top_k=top_k, top_p=top_p)
            # Despite the name these are probabilities, not log-probabilities.
            log_probs = F.softmax(logits, dim=-1)
            prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            # Rows whose prompt extends past this position keep the prompt token.
            if index < torch.max(context_length_tensor).item():
                prev = switch(
                    prev, context_tokens_tensor[:, index], context_length_tensor <= index)

            for i in range(batch_size):
                if index > context_length_tensor[i] and prev[i] != end_token_id:
                    log_probs_tensor[i] += math.log(log_probs[i][prev[i]] + 1e-6)
                if prev[i] == end_token_id:
                    length_delta = context_length_tensor[i].cpu() - index
                    # A zero delta means the end token came at the first
                    # generated position; dividing 0.0 by 0 would store NaN.
                    if length_delta != 0:
                        log_probs_tensor[i] /= length_delta

            stop_idx = prev == end_token_id
            if torch.all(stop_idx).item():
                # Every remaining row finished; they are collected after the loop.
                break

            # Store the rows that just finished (their end token is not appended).
            finished = tokens[stop_idx]
            output_tokens_lists.extend(finished.detach().cpu().tolist())
            log_probs_list.extend(log_probs_tensor[stop_idx].tolist())
            output_order.extend(origin_order[stop_idx].tolist())

            # Shrink every per-row tensor to the rows that continue.
            conti_idx = (prev != end_token_id)
            origin_order = origin_order[conti_idx]
            tokens, prev = tokens[conti_idx], prev[conti_idx]
            context_tokens_tensor = context_tokens_tensor[conti_idx]
            context_length_tensor = context_length_tensor[conti_idx]
            log_probs_tensor = log_probs_tensor[conti_idx]
            batch_size = tokens.shape[0]
            for im in range(len(mems)):
                mems[im] = mems[im][conti_idx, :, :]

            tokens = torch.cat((tokens, prev.view(batch_size, 1)), dim=-1)
            counter += 1

    # Collect whatever is still active (all rows stopped at once, or the step
    # budget ran out).
    output_tokens_lists.extend(tokens.detach().cpu().tolist())
    log_probs_list.extend(log_probs_tensor.tolist())
    output_order.extend(origin_order.tolist())

    output_tokens_lists = [
        seq[:seq.index(end_token_id)] if end_token_id in seq else seq
        for seq in output_tokens_lists]
    # Restore the original batch order.
    output_tokens_lists = [seq for _, seq in sorted(zip(output_order, output_tokens_lists))]
    output_log_porbs = [prob for _, prob in sorted(zip(output_order, log_probs_list))]
    return output_tokens_lists, output_log_porbs
def sample_sequence(model, tokens, attention_mask, do_sampling=True,
                    repetition_penalty=1.0, max_out_seq=None, mems=None, end_token_id=None,
                    mem_length=0, temperature=1.0, top_k=0, top_p=0.0):
    """Sample a continuation for a single prompt, one token per step.

    Args:
        model: callable invoked as ``model(input_ids=..., position_ids=None,
            attention_mask=..., mems=...)``. The first item of its result is
            used as the logits and the remaining items become the new ``mems``.
        tokens (Tensor): [1, seq_len] prompt token ids.
        attention_mask (Tensor): [1, 1, seq_len, seq_len] mask for the first
            forward call.
        do_sampling (bool, optional): when true the logits are filtered with
            ``top_k_logits``. The next token is drawn with
            ``torch.multinomial`` in either case. Defaults to True.
        repetition_penalty (float, optional): when not 1.0, the softmax
            output is rescaled by ``enforce_repetition_penalty`` for every
            distinct token produced so far. Defaults to 1.0.
        max_out_seq (int, optional): maximum number of generated tokens.
            Defaults to 512.
        mems (list, optional): memories passed to the first forward call.
            Defaults to an empty list.
        end_token_id (int, optional): token id that stops generation; it is
            not appended to the output. Defaults to 50000.
        mem_length (int, optional): the per-step attention mask has
            ``mem_length + 1`` columns. Defaults to 0.
        temperature (float, optional): divisor applied to the logits.
            Defaults to 1.0.
        top_k (int, optional): forwarded to ``top_k_logits``. Defaults to 0.
        top_p (float, optional): forwarded to ``top_k_logits``. Defaults to 0.0.

    Returns:
        tuple: ``(token_ids, mems)`` where ``token_ids`` is the prompt plus the
        generated tokens as a flat list of ints, cut before the first
        occurrence of ``end_token_id``, and ``mems`` is the list of memories
        from the last forward call.
    """
    counter = 0
    if mems is None:
        mems = []
    if end_token_id is None:
        end_token_id = 50000
    if max_out_seq is None:
        max_out_seq = 512
    org_context_length = tokens.size(1)

    with torch.no_grad():
        while counter < max_out_seq:
            if counter == 0:
                # First step: feed the whole prompt.
                logits, *mems = model(input_ids=tokens, position_ids=None,
                                      attention_mask=attention_mask, mems=mems)
            else:
                # Later steps: feed only the most recent token.
                index = org_context_length + counter
                logits, *mems = model(input_ids=tokens[:, index - 1: index], position_ids=None,
                                      attention_mask=tokens.new_ones(1, 1, 1, mem_length + 1), mems=mems)
            logits = logits[:, -1]
            logits /= temperature
            if do_sampling:
                logits = top_k_logits(logits, top_k=top_k, top_p=top_p)
            # Despite the name these are probabilities, not log-probabilities.
            log_probs = F.softmax(logits, dim=-1)

            if repetition_penalty != 1.0:
                # Pass plain ints so each distinct token is penalised once;
                # a set built from tensor elements would not de-duplicate.
                enforce_repetition_penalty(
                    log_probs[0, :], tokens[0, :].tolist(), repetition_penalty)
            prev = torch.multinomial(log_probs, num_samples=1)[0]
            if prev == end_token_id:
                break
            tokens = torch.cat((tokens, prev.view(1, 1)), dim=1)
            counter += 1

    # tokens is [1, seq_len]; work on the single row so the end-token
    # truncation sees a flat list of ints.
    output_tokens_list = tokens.detach().cpu().tolist()[0]
    if end_token_id in output_tokens_list:
        output_tokens_list = output_tokens_list[:output_tokens_list.index(end_token_id)]
    return output_tokens_list, mems