# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import argparse
import json
from data_builder import load_data
from model import load_tokenizer, load_model
from metrics import get_roc_metrics, get_precision_recall_metrics
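
# Draw nsamples alternative tokens per position from the given (reference)
# distribution; these Monte Carlo samples back the sampling-based
# discrepancy criterion below.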
def get_samples(logits, labels):
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1
    nsamples = 10000
    lprobs = torch.log_softmax(logits, dim=-1)
    distrib = torch.distributions.categorical.Categorical(logits=lprobs)
    samples = distrib.sample([nsamples]).permute([1, 2, 0])
    return samples
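
# Mean per-token log-likelihood of `labels` (the observed tokens, or sampled
# alternatives carrying an extra sample dimension) under the given logits.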
def get_likelihood(logits, labels):
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1
    labels = labels.unsqueeze(-1) if labels.ndim == logits.ndim - 1 else labels
    lprobs = torch.log_softmax(logits, dim=-1)
    log_likelihood = lprobs.gather(dim=-1, index=labels)
    return log_likelihood.mean(dim=1)
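
# Sampling-based estimate of the detection criterion: standardize the passage's
# log-likelihood under the scoring model against the mean/std of the
# log-likelihoods of tokens sampled from the reference model (a z-score).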
def get_sampling_discrepancy(logits_ref, logits_score, labels):
    assert logits_ref.shape[0] == 1
    assert logits_score.shape[0] == 1
    assert labels.shape[0] == 1
    if logits_ref.size(-1) != logits_score.size(-1):
        # print(f"WARNING: vocabulary size mismatch {logits_ref.size(-1)} vs {logits_score.size(-1)}.")
        vocab_size = min(logits_ref.size(-1), logits_score.size(-1))
        logits_ref = logits_ref[:, :, :vocab_size]
        logits_score = logits_score[:, :, :vocab_size]
    samples = get_samples(logits_ref, labels)
    log_likelihood_x = get_likelihood(logits_score, labels)
    log_likelihood_x_tilde = get_likelihood(logits_score, samples)
    miu_tilde = log_likelihood_x_tilde.mean(dim=-1)
    sigma_tilde = log_likelihood_x_tilde.std(dim=-1)
    discrepancy = (log_likelihood_x.squeeze(-1) - miu_tilde) / sigma_tilde
    return discrepancy.item()
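
# Closed-form variant: the mean and variance of the per-token log-likelihood
# are computed analytically from the reference distribution instead of being
# estimated by Monte Carlo sampling.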
def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):
    assert logits_ref.shape[0] == 1
    assert logits_score.shape[0] == 1
    assert labels.shape[0] == 1
    if logits_ref.size(-1) != logits_score.size(-1):
        # print(f"WARNING: vocabulary size mismatch {logits_ref.size(-1)} vs {logits_score.size(-1)}.")
        vocab_size = min(logits_ref.size(-1), logits_score.size(-1))
        logits_ref = logits_ref[:, :, :vocab_size]
        logits_score = logits_score[:, :, :vocab_size]
    labels = labels.unsqueeze(-1) if labels.ndim == logits_score.ndim - 1 else labels
    lprobs_score = torch.log_softmax(logits_score, dim=-1)
    probs_ref = torch.softmax(logits_ref, dim=-1)
    log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
    mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
    var_ref = (probs_ref * torch.square(lprobs_score)).sum(dim=-1) - torch.square(mean_ref)
    discrepancy = (log_likelihood.sum(dim=-1) - mean_ref.sum(dim=-1)) / var_ref.sum(dim=-1).sqrt()
    discrepancy = discrepancy.mean()
    return discrepancy.item()
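
# Score every original/sampled pair in the dataset with the selected criterion,
# then report ROC and precision-recall AUCs and write the raw results to JSON.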
def experiment(args):
    # load model
    scoring_tokenizer = load_tokenizer(args.scoring_model_name, args.dataset, args.cache_dir)
    scoring_model = load_model(args.scoring_model_name, args.device, args.cache_dir)
    scoring_model.eval()
    if args.reference_model_name != args.scoring_model_name:
        reference_tokenizer = load_tokenizer(args.reference_model_name, args.dataset, args.cache_dir)
        reference_model = load_model(args.reference_model_name, args.device, args.cache_dir)
        reference_model.eval()
    # load data
    data = load_data(args.dataset_file)
    n_samples = len(data["sampled"])
    # evaluate criterion
    if args.discrepancy_analytic:
        name = "sampling_discrepancy_analytic"
        criterion_fn = get_sampling_discrepancy_analytic
    else:
        name = "sampling_discrepancy"
        criterion_fn = get_sampling_discrepancy
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    results = []
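    # one pass per passage pair: score the human-written ("original") text and
    # its model-generated ("sampled") counterpart with the same criterion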
    for idx in tqdm.tqdm(range(n_samples), desc=f"Computing {name} criterion"):
        original_text = data["original"][idx]
        sampled_text = data["sampled"][idx]
        # original text
        tokenized = scoring_tokenizer(original_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
        labels = tokenized.input_ids[:, 1:]
        with torch.no_grad():
            logits_score = scoring_model(**tokenized).logits[:, :-1]
            if args.reference_model_name == args.scoring_model_name:
                logits_ref = logits_score
            else:
                tokenized = reference_tokenizer(original_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
                assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer mismatch between reference and scoring models."
                logits_ref = reference_model(**tokenized).logits[:, :-1]
            original_crit = criterion_fn(logits_ref, logits_score, labels)
        # sampled text
        tokenized = scoring_tokenizer(sampled_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
        labels = tokenized.input_ids[:, 1:]
        with torch.no_grad():
            logits_score = scoring_model(**tokenized).logits[:, :-1]
            if args.reference_model_name == args.scoring_model_name:
                logits_ref = logits_score
            else:
                tokenized = reference_tokenizer(sampled_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
                assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer mismatch between reference and scoring models."
                logits_ref = reference_model(**tokenized).logits[:, :-1]
            sampled_crit = criterion_fn(logits_ref, logits_score, labels)
        # result
        results.append({"original": original_text,
                        "original_crit": original_crit,
                        "sampled": sampled_text,
                        "sampled_crit": sampled_crit})
    # compute prediction scores for real/sampled passages
    predictions = {'real': [x["original_crit"] for x in results],
                   'samples': [x["sampled_crit"] for x in results]}
    print(f"Real mean/std: {np.mean(predictions['real']):.2f}/{np.std(predictions['real']):.2f}, Samples mean/std: {np.mean(predictions['samples']):.2f}/{np.std(predictions['samples']):.2f}")
    fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])
    p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples'])
    print(f"Criterion {name}_threshold ROC AUC: {roc_auc:.4f}, PR AUC: {pr_auc:.4f}")
    # results
    results_file = f'{args.output_file}.{name}.json'
    results = {'name': f'{name}_threshold',
               'info': {'n_samples': n_samples},
               'predictions': predictions,
               'raw_results': results,
               'metrics': {'roc_auc': roc_auc, 'fpr': fpr, 'tpr': tpr},
               'pr_metrics': {'pr_auc': pr_auc, 'precision': p, 'recall': r},
               'loss': 1 - pr_auc}
    with open(results_file, 'w') as fout:
        json.dump(results, fout)
    print(f'Results written into {results_file}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_file', type=str, default="./exp_test/results/xsum_gpt2")
    parser.add_argument('--dataset', type=str, default="xsum")
    parser.add_argument('--dataset_file', type=str, default="./exp_test/data/xsum_gpt2")
    parser.add_argument('--reference_model_name', type=str, default="gpt2")
    parser.add_argument('--scoring_model_name', type=str, default="gpt2")
    parser.add_argument('--discrepancy_analytic', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()
    experiment(args)
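
# Example invocation (a sketch: the script filename "detect.py" is an assumption;
# paths and model names mirror the argparse defaults above):
#   python detect.py --reference_model_name gpt2 --scoring_model_name gpt2 \
#       --dataset xsum --dataset_file ./exp_test/data/xsum_gpt2 \
#       --output_file ./exp_test/results/xsum_gpt2 --discrepancy_analytic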