|
""" |
|
Code for modifying categorical distributions to improve quality of sampling. |
|
|
|
Adapted from: |
|
- https://github.com/e-c-k-e-r/vall-e/blob/master/vall_e/samplers.py |
|
- Microsoft UniLM |
|
- Matthew Baas's typical sampling code. |
|
- https://github.com/LostRuins/koboldcpp |
|
""" |
|
|
|
import math |
|
import torch |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import logging |
|
|
|
from torch import Tensor, nn |
|
|
|
|
|
def freq_rep_penalty(logits: Tensor, previous: Tensor, alpha_frequency: float, alpha_presence: float, penalty_window: int = 100) -> Tensor:
    """Apply frequency and presence penalties following OpenAI's formulation.

    For each batch row, every token id that occurs in the last
    `penalty_window` entries of `previous` has its logit lowered by
    `count * alpha_frequency + alpha_presence`, where `count` is the number of
    occurrences of that id inside the window. Ids that do not occur in the
    window are left untouched.

    Supports batched inference.

    See: https://platform.openai.com/docs/guides/text-generation/parameter-details

    Args:
        logits: scores of shape (bs, vocab_size). Not modified in place.
        previous: previously generated token ids of shape (bs, seq_len).
        alpha_frequency: amount subtracted per occurrence of a token.
        alpha_presence: flat amount subtracted once for any token that occurs.
        penalty_window: number of trailing positions of `previous` to count.
            Note that 0 selects the whole history, because `x[..., -0:]` is
            the full slice.

    Returns:
        A new tensor with the penalised logits. For floating-point inputs the
        result has the same dtype as `logits`.
    """
    bs = logits.shape[0]
    recent = previous[..., -penalty_window:]

    # Per-row occurrence count of every vocabulary id inside the window.
    counts = torch.zeros_like(logits, dtype=torch.long)
    for i in range(bs):
        token_ids, occurrences = recent[i].unique(return_counts=True)
        counts[i, token_ids] = occurrences.to(counts.device)

    penalised = logits - counts * alpha_frequency - (counts > 0).to(logits.dtype) * alpha_presence

    # `counts * alpha_frequency` is float32, which silently promotes
    # half/bfloat16 logits; hand the caller back the dtype they passed in.
    if logits.is_floating_point():
        penalised = penalised.to(logits.dtype)
    return penalised
|
|
|
|
|
def early_eos_penalty(logits: Tensor, n_generated: int, estimated_gen_length: int, decay: float, factor: float = 1, eos_index: int = 0) -> Tensor:
    """Discourage an end-of-sequence token before the expected length is reached.

    While `n_generated` has not exceeded `estimated_gen_length`, the logit at
    `eos_index` of every row of `logits` (bs, vocab_size) is reduced by
    `factor * max(estimated_gen_length - n_generated, 1) ** decay`.

    Good values for `decay` are between 0 and 1: with 0 the penalty is always
    exactly `factor`, with 1 it scales linearly with the remaining distance.
    Setting `factor = 0` disables the penalty; increasing it strengthens the
    penalty.

    Args:
        logits: scores of shape (bs, vocab_size). Modified in place.
        n_generated: number of samples generated so far.
        estimated_gen_length: expected total generation length.
        decay: exponent applied to the remaining distance.
        factor: multiplier for the penalty.
        eos_index: column of `logits` holding the end-of-sequence token.

    Returns:
        The same `logits` tensor that was passed in (penalised in place, or
        untouched once `n_generated` exceeds `estimated_gen_length`).
    """
    if n_generated > estimated_gen_length:
        return logits

    # The distance is floored at 1, so the penalty is still `factor` when
    # n_generated == estimated_gen_length.
    remaining = max(estimated_gen_length - n_generated, 1)
    logits[:, eos_index] -= factor * (remaining ** decay)
    return logits
|
|
|
|
|
|
|
|
|
def top_k_top_p_filtering(logits: Tensor, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1) -> Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Nucleus filtering is described in Holtzman et al.
    (http://arxiv.org/abs/1904.09751).

    Args:
        logits: logits of shape (..., vocabulary size); typically
            (batch size, vocabulary size). Filtering acts on the last
            dimension. The tensor is modified in place.
        top_k: if > 0, keep only the `top_k` highest-scoring tokens (raised to
            at least `min_tokens` and capped at the vocabulary size).
        top_p: if < 1.0, keep the smallest set of top tokens whose cumulative
            probability reaches `top_p`, including the token that crosses the
            threshold.
        filter_value: value written into every filtered-out position.
        min_tokens: minimum number of tokens kept per distribution.

    Returns:
        The same `logits` tensor, with removed entries set to `filter_value`.
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens), logits.size(-1))

        # Everything strictly below the k-th best score is dropped.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens whose cumulative probability lies above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens > 1:
            # Protect the leading tokens; together with the shift below this
            # leaves at least min_tokens + 1 survivors.
            sorted_indices_to_remove[..., :min_tokens] = False

        # Shift right so the first token that crosses the threshold is kept,
        # and never remove the single most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order. Scattering
        # along the last dim (instead of a hard-coded dim 1) also supports
        # 1-D and N-D logits.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits
|
|
|
|
|
def apply_typical_p(logprobs: Tensor, mass: float) -> Tensor:
    """Warp categorical scores with typical sampling so that roughly `mass` probability is kept.

    Tokens are ranked by how close their surprisal (negative log-probability)
    is to the entropy of the distribution. The closest tokens are kept until
    their cumulative probability reaches `mass`; every other token is set to
    `-inf`. The last dimension is the bin (vocabulary) dimension.

    `mass` corresponds to `tau` in the typical-sampling paper.

    Args:
        logprobs: unnormalised or normalised log-scores of shape (..., bins).
            Not modified in place.
        mass: probability mass to retain. Values above 0.999 disable the
            filter and return `logprobs` unchanged.

    Returns:
        A tensor shaped like `logprobs` with the filtered-out bins set to
        `-inf` (or `logprobs` itself when the filter is disabled).
    """
    if mass > 0.999:
        return logprobs

    # Entropy of each distribution; nansum ignores 0 * -inf terms that come
    # from bins with zero probability.
    normalized = F.log_softmax(logprobs, dim=-1)
    p = torch.exp(normalized)
    ent = -(normalized * p).nansum(-1, keepdim=True)

    # Distance of each bin's surprisal from the entropy, sorted most-typical
    # first, with the matching cumulative probability.
    shifted_scores = torch.abs((-normalized) - ent)
    sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
    sorted_logits = logprobs.gather(-1, sorted_indices)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Position (in sorted order) of the bin that crosses `mass`. The count can
    # equal the number of bins when rounding leaves every cumulative value
    # below `mass`, so clamp it to a valid index before gathering.
    last_ind = (cumulative_probs < mass).sum(dim=-1, keepdim=True)
    last_ind.clamp_(max=sorted_scores.shape[-1] - 1)
    sorted_indices_to_remove = sorted_scores > sorted_scores.gather(-1, last_ind)

    # Map the mask from sorted order back to the original bin order.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)

    return logprobs.masked_fill(indices_to_remove, -float("Inf"))