import torch
import torch.nn.functional as F


def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering |
|
Args: |
|
logits: logits distribution shape (batch size, vocabulary size) |
|
if top_k > 0: keep only top k tokens with highest probability (top-k filtering). |
|
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). |
|
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) |
|
Make sure we keep at least min_tokens_to_keep per batch example in the output |
|
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 |
|
""" |
|
    if top_k > 0:
        # Safety check: keep at least min_tokens_to_keep tokens and never ask
        # for more than the vocabulary size.
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )

        # Remove all tokens with a probability less than the last token of the top-k.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Mask tokens whose cumulative probability exceeds the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Never filter out the first min_tokens_to_keep tokens.
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the mask to the right to also keep the first token above the threshold.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
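# Illustrative usage sketch (not part of the original gist): with top_k=2 and
# the default top_p=1.0, only the two largest logits in each row survive; the
# rest are masked to -inf. Note that the function modifies its input in place,
# hence the .clone():
#
#   logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
#   top_k_top_p_filtering(logits.clone(), top_k=2)
#   # -> tensor([[-inf, -inf, 3., 4.]])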
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token id per batch example: scale the logits by temperature,
    apply top-k / top-p filtering, then draw from the resulting distribution."""
    # Temperature (higher temperature => more likely to sample low-probability tokens).
    if temperature != 1.0:
        logits = logits / temperature
    # Top-k / top-p filtering.
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # Sample one token id per batch row from the filtered distribution.
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    return token
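# Minimal usage sketch (illustrative, not from the original source): draw one
# token per batch row from random logits of shape (batch size, vocabulary size).
if __name__ == "__main__":
    torch.manual_seed(0)
    dummy_logits = torch.randn(2, 100)  # (batch size, vocabulary size)
    sampled = topk_sampling(dummy_logits, top_k=10, top_p=0.9, temperature=1.0)
    print(sampled.shape)  # torch.Size([2, 1]): one sampled token id per row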