Spaces:
Paused
Paused
File size: 2,061 Bytes
d389c0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import numpy as np
import torch
import transformers
# Cross-entropy with reduction="none": the loss is returned per element instead
# of being averaged, so callers can apply their own masking and aggregation.
ce_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
# Softmax over the last dimension of its input.
softmax_fn = torch.nn.Softmax(dim=-1)
def perplexity(encoding: transformers.BatchEncoding,
               logits: torch.Tensor,
               median: bool = False,
               temperature: float = 1.0):
    """Aggregate the per-token next-token cross-entropy of each sequence.

    The logits at position ``t`` are scored against the token at position
    ``t + 1``. Positions whose shifted attention mask is zero are excluded
    from the aggregate. The value returned is the aggregated cross-entropy
    itself; it is not exponentiated.

    Args:
        encoding: Provides ``input_ids`` and ``attention_mask``.
        logits: Model scores; assumed to be (batch, seq_len, vocab) since the
            code calls ``transpose(1, 2)`` on them -- TODO confirm with caller.
        median: If True, take the NaN-ignoring median over positions instead
            of the mask-weighted mean.
        temperature: Divisor applied to the logits before the loss.

    Returns:
        A NumPy array with one aggregated value per row.
    """
    # Drop the last logit and the first label/mask entry so that each
    # remaining logit lines up with the token that follows it.
    next_token_logits = logits[..., :-1, :].contiguous() / temperature
    targets = encoding.input_ids[..., 1:].contiguous()
    mask = encoding.attention_mask[..., 1:].contiguous()

    # The loss expects the class dimension at index 1, hence the transpose.
    token_ce = ce_loss_fn(next_token_logits.transpose(1, 2), targets)

    if not median:
        # Mask-weighted mean over positions.
        # NOTE(review): a row whose shifted mask sums to zero divides by zero.
        mean_ce = (token_ce * mask).sum(1) / mask.sum(1)
        return mean_ce.to("cpu").float().numpy()

    # Masked-out positions become NaN so that nanmedian skips them.
    masked_ce = token_ce.masked_fill(~mask.bool(), float("nan"))
    return np.nanmedian(masked_ce.cpu().float().numpy(), 1)
def entropy(p_logits: torch.Tensor,
            q_logits: torch.Tensor,
            encoding: transformers.BatchEncoding,
            pad_token_id: int,
            median: bool = False,
            sample_p: bool = False,
            temperature: float = 1.0):
    """Aggregate the per-token cross-entropy of ``q_logits`` against ``p``.

    The target at each position is either the full softmax distribution of
    ``p_logits`` or, when ``sample_p`` is set, a single class index sampled
    from that distribution. Positions whose input id equals ``pad_token_id``
    are excluded from the aggregate.

    Args:
        p_logits: Scores defining the target distribution; the last dimension
            is used as the vocabulary size.
        q_logits: Scores being evaluated; the second-to-last dimension is used
            as the number of token positions per row.
        encoding: Provides ``input_ids`` for locating padding.
        pad_token_id: Token id treated as padding.
        median: If True, take the NaN-ignoring median over positions instead
            of the mask-weighted mean.
        sample_p: If True, replace the soft targets with one sampled class
            index per position.
        temperature: Divisor applied to both sets of logits.

    Returns:
        A NumPy array with one aggregated value per row.
    """
    vocab_size = p_logits.shape[-1]
    positions_per_row = q_logits.shape[-2]

    p_scores = p_logits / temperature
    q_scores = q_logits / temperature

    # Flatten to (rows * positions, vocab) so the loss sees one row per token.
    targets = softmax_fn(p_scores).view(-1, vocab_size)
    if sample_p:
        # Draw one class index per flattened position from p's distribution.
        targets = torch.multinomial(targets, replacement=True, num_samples=1).view(-1)

    flat_q_scores = q_scores.view(-1, vocab_size)
    token_ce = ce_loss_fn(input=flat_q_scores, target=targets).view(-1, positions_per_row)

    # 1 where the input id is a real token, 0 where it is padding.
    token_mask = (encoding.input_ids != pad_token_id).type(torch.uint8)

    if median:
        # Padding positions become NaN so that nanmedian skips them.
        masked_ce = token_ce.masked_fill(~token_mask.bool(), float("nan"))
        return np.nanmedian(masked_ce.cpu().float().numpy(), 1)

    # Mask-weighted mean over positions.
    # NOTE(review): a row consisting only of padding divides by zero.
    mean_ce = (token_ce * token_mask).sum(1) / token_mask.sum(1)
    return mean_ce.to("cpu").float().numpy()
|