# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os.path
import numpy as np
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import re
import torch
import tqdm
import argparse
import json
from data_builder import load_data, save_data
from metrics import get_roc_metrics, get_precision_recall_metrics
from model import load_tokenizer, load_model, get_model_fullname, from_pretrained
import custom_datasets
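# PrefixSampler wraps the scoring model: it truncates each text to a prefix and
# regenerates continuations from that prefix, which get_dna_gpt later compares
# against the original text's log-likelihood (a DNA-GPT style criterion).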
class PrefixSampler:
    def __init__(self, args):
        self.args = args
        self.base_tokenizer = load_tokenizer(args.base_model_name, args.dataset, args.cache_dir)
        self.base_model = load_model(args.base_model_name, args.device, args.cache_dir)
    def _sample_from_model(self, texts, min_words=55, truncate_ratio=0.5):
        # encode each text as a list of token ids
        if self.args.dataset == 'pubmed':
            pubmed_sep = ' Answer:'
            texts = [t[:t.index(pubmed_sep) + len(pubmed_sep)] for t in texts]
            all_encoded = self.base_tokenizer(texts, return_tensors="pt", padding=True).to(self.args.device)
        else:
            texts = [t.split(' ') for t in texts]
            texts = [' '.join(t[: int(len(t) * truncate_ratio)]) for t in texts]
            all_encoded = self.base_tokenizer(texts, return_tensors="pt", padding=True).to(self.args.device)

        self.base_model.eval()
        decoded = ['' for _ in range(len(texts))]

        # sample from the model until we get a sample with at least min_words words for each example
        # this is an inefficient way to do this (since we regenerate for all inputs if just one is too short), but it works
        tries = 0
        m = 0
        while m < min_words:
            if tries != 0:
                print()
                print(f"min words: {m}, needed {min_words}, regenerating (try {tries})")

            sampling_kwargs = {'temperature': self.args.temperature}
            if self.args.do_top_p:
                sampling_kwargs['top_p'] = self.args.top_p
            elif self.args.do_top_k:
                sampling_kwargs['top_k'] = self.args.top_k
            min_length = 50 if self.args.dataset in ['pubmed'] else 150
            outputs = self.base_model.generate(**all_encoded, min_length=min_length, max_length=200, do_sample=True,
                                               **sampling_kwargs, pad_token_id=self.base_tokenizer.eos_token_id,
                                               eos_token_id=self.base_tokenizer.eos_token_id)
            decoded = self.base_tokenizer.batch_decode(outputs, skip_special_tokens=True)
            m = min(len(x.split()) for x in decoded)
            tries += 1

        return decoded
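    # For each original text, draw a regeneration from its truncated prefix and trim
    # both texts to the shorter word count, so every original/sampled pair is length-matched.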
    def generate_samples(self, raw_data, batch_size):
        # trim to shorter length
        def _trim_to_shorter_length(texta, textb):
            # truncate to shorter of o and s
            shorter_length = min(len(texta.split(' ')), len(textb.split(' ')))
            texta = ' '.join(texta.split(' ')[:shorter_length])
            textb = ' '.join(textb.split(' ')[:shorter_length])
            return texta, textb

        def _truncate_to_substring(text, substring, idx_occurrence):
            # truncate everything after the idx_occurrence occurrence of substring
            assert idx_occurrence > 0, 'idx_occurrence must be > 0'
            idx = -1
            for _ in range(idx_occurrence):
                idx = text.find(substring, idx + 1)
                if idx == -1:
                    return text
            return text[:idx]

        data = {
            "original": [],
            "sampled": [],
        }

        assert len(raw_data) % batch_size == 0
        for batch in range(len(raw_data) // batch_size):
            print('Generating samples for batch', batch, 'of', len(raw_data) // batch_size)
            original_text = raw_data[batch * batch_size:(batch + 1) * batch_size]
            sampled_text = self._sample_from_model(original_text,
                                                   min_words=30 if self.args.dataset in ['pubmed'] else 55,
                                                   truncate_ratio=self.args.truncate_ratio)
            for o, s in zip(original_text, sampled_text):
                if self.args.dataset == 'pubmed':
                    s = _truncate_to_substring(s, 'Question:', 2)
                    o = o.replace(custom_datasets.SEPARATOR, ' ')
                o, s = _trim_to_shorter_length(o, s)
                # add to the data
                data["original"].append(o)
                data["sampled"].append(s)

        return data
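# Mean per-token log-likelihood of `labels` under `logits`, with pad positions
# masked out of both the sum and the normalizer.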
def get_likelihood(logits, labels, pad_index):
    labels = labels.unsqueeze(-1) if labels.ndim == logits.ndim - 1 else labels
    lprobs = torch.log_softmax(logits, dim=-1)
    log_likelihood = lprobs.gather(dim=-1, index=labels)
    mask = labels != pad_index
    log_likelihood = (log_likelihood * mask).sum(dim=1) / mask.sum(dim=1)
    return log_likelihood.squeeze(-1)
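# Score a single text: labels drop the first token and logits drop the last position,
# so each token is scored from its preceding context.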
def get_log_prob(sampler, text):
    tokenized = sampler.base_tokenizer(text, return_tensors="pt", padding=True).to(sampler.args.device)
    labels = tokenized.input_ids[:, 1:]
    with torch.no_grad():
        logits_score = sampler.base_model(**tokenized).logits[:, :-1]
        return get_likelihood(logits_score, labels, sampler.base_tokenizer.pad_token_id)
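# Batched variant of get_log_prob; assumes len(texts) is a multiple of batch_size
# (a trailing partial batch would be silently dropped by the integer division).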
def get_log_probs(sampler, texts):
    batch_size = sampler.args.batch_size
    batch_lprobs = []
    for batch in range(len(texts) // batch_size):
        tokenized = sampler.base_tokenizer(texts[batch * batch_size:(batch + 1) * batch_size],
                                           return_tensors="pt", padding=True).to(sampler.args.device)
        labels = tokenized.input_ids[:, 1:]
        with torch.no_grad():
            logits_score = sampler.base_model(**tokenized).logits[:, :-1]
            lprobs = get_likelihood(logits_score, labels, sampler.base_tokenizer.pad_token_id)
            batch_lprobs.append(lprobs)
    return torch.cat(batch_lprobs, dim=0)
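# Draw args.regen_number regenerations of a single text from its truncated prefix.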
def get_regen_samples(sampler, text):
    data = [text] * sampler.args.regen_number
    data = sampler.generate_samples(data, batch_size=sampler.args.batch_size)
    return data['sampled']
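# DNA-GPT style white-box score: the text's mean log-likelihood under the base model
# minus the average log-likelihood of its prefix-conditioned regenerations.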
def get_dna_gpt(sampler, text):
    lprob = get_log_prob(sampler, text)
    regens = get_regen_samples(sampler, text)
    lprob_regens = get_log_probs(sampler, regens)
    wscore = lprob[0] - lprob_regens.mean()
    return wscore.item()
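# End-to-end evaluation: score every original/sampled pair with the criterion above,
# report ROC AUC and PR AUC over the two score populations, and dump raw results to JSON.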
def experiment(args):
    sampler = PrefixSampler(args)
    # load data
    data = load_data(args.dataset_file)
    n_samples = len(data["sampled"])
    # evaluate criterion
    name = "dna_gpt"
    criterion_fn = get_dna_gpt

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    results = []
    for idx in tqdm.tqdm(range(n_samples), desc=f"Computing {name} criterion"):
        original_text = data["original"][idx]
        sampled_text = data["sampled"][idx]
        # original text
        original_crit = criterion_fn(sampler, original_text)
        # sampled text
        sampled_crit = criterion_fn(sampler, sampled_text)
        # result
        results.append({"original": original_text,
                        "original_crit": original_crit,
                        "sampled": sampled_text,
                        "sampled_crit": sampled_crit})

    # compute prediction scores for real/sampled passages
    predictions = {'real': [x["original_crit"] for x in results],
                   'samples': [x["sampled_crit"] for x in results]}
    fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])
    p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples'])
    print(f"Criterion {name}_threshold ROC AUC: {roc_auc:.4f}, PR AUC: {pr_auc:.4f}")

    # results
    results_file = f'{args.output_file}.{name}.json'
    results = {'name': f'{name}_threshold',
               'info': {'n_samples': n_samples},
               'predictions': predictions,
               'raw_results': results,
               'metrics': {'roc_auc': roc_auc, 'fpr': fpr, 'tpr': tpr},
               'pr_metrics': {'pr_auc': pr_auc, 'precision': p, 'recall': r},
               'loss': 1 - pr_auc}
    with open(results_file, 'w') as fout:
        json.dump(results, fout)
        print(f'Results written into {results_file}')
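# Illustrative invocation (assuming this script is saved as dna_gpt.py; paths follow
# the argparse defaults below and should be adapted to your setup):
#   python dna_gpt.py --dataset pubmed --dataset_file ./exp_test/data/pubmed_davinci \
#       --output_file ./exp_test/results/pubmed_davinci --base_model_name gpt2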
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_file', type=str, default="./exp_test/results/pubmed_davinci")
    parser.add_argument('--dataset', type=str, default="pubmed")
    parser.add_argument('--dataset_file', type=str, default="./exp_test/data/pubmed_davinci")
    parser.add_argument('--truncate_ratio', type=float, default=0.5)
    parser.add_argument('--regen_number', type=int, default=10)
    parser.add_argument('--base_model_name', type=str, default="gpt2")
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--do_top_k', action='store_true')
    parser.add_argument('--top_k', type=int, default=40)
    parser.add_argument('--do_top_p', action='store_true')
    parser.add_argument('--top_p', type=float, default=0.96)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()

    experiment(args)