j.gilyazev
add personalized-chat-bot
c1c5bd9
raw
history blame contribute delete
915 Bytes
import numpy as np
import scipy
import torch
def _perplexity(logits, labels, pad_token=3):
for i in range(len(labels)-1, -1, -1):
if labels[i] != pad_token:
last_not_pad_id = i
break
logits = logits[:last_not_pad_id + 1]
labels = labels[:last_not_pad_id + 1]
log_probas = scipy.special.log_softmax(logits, axis=1).astype(np.float32)
log_probas = [log_probas[i][labels[i]] for i in range(len(labels))]
l = np.mean(log_probas)
return 2 ** (-l)
def perplexity(logits, labels, pad_token=3):
pp = []
if isinstance(logits, torch.Tensor):
logits = logits.detach().cpu().numpy()
if isinstance(labels, torch.Tensor):
labels = labels.detach().cpu().numpy()
for cur_logits, cur_labels in zip(logits, labels):
pp.append(_perplexity(np.array(cur_logits), np.array(cur_labels).astype(int), pad_token))
return np.mean(pp)