Spaces:
Build error
Build error
import os | |
import torch | |
import numpy as np | |
import time | |
import pickle | |
from scripts import tabular_metrics | |
from scripts.tabular_metrics import calculate_score_per_method | |
from scripts.tabular_evaluation import evaluate | |
from priors.differentiable_prior import draw_random_style | |
from tqdm import tqdm | |
from pathlib import Path | |
import random | |
from model_builder import load_model | |
from scripts.transformer_prediction_interface import get_params_from_config | |
""" | |
=============================== | |
PUBLIC FUNCTIONS FOR EVALUATION | |
=============================== | |
""" | |
def eval_model_range(i_range, *args, **kwargs):
    """
    Evaluate every model index in ``i_range`` by calling :func:`eval_model` once per index.

    All extra positional and keyword arguments are forwarded unchanged to every call.
    The per-model return values are discarded; this function returns ``None``.
    """
    for model_index in i_range:
        eval_model(model_index, *args, **kwargs)
def load_model_workflow(i, e, add_name, base_path, device='cpu', eval_addition='', max_epoch=100):
    """
    Workflow for loading a model and setting appropriate parameters for diffable hparam tuning.

    Checkpoints are looked up at
    ``<base_path>/models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt``.

    :param i: Model index (the ``n_{i}`` part of the checkpoint file name).
    :param e: Epoch to load, or ``-1`` to load the highest existing epoch in ``[0, max_epoch]``.
    :param add_name: Name suffix inserted after ``prior_diff_real_checkpoint`` / ``prior_diff_real_results``.
    :param base_path: Directory that contains the ``models_diff`` folder.
    :param device: Device passed on to ``load_model``.
    :param eval_addition: Suffix appended to the results file name.
    :param max_epoch: Highest epoch probed when ``e == -1`` (default 100, the previous hard-coded bound).
    :return: ``(model, config, results_file)`` -- the two values returned by ``load_model`` plus the
        path where evaluation results should be written -- or ``None`` if no matching checkpoint exists.
    """
    def check_file(epoch):
        """Return (model_file, model_path, results_file) for ``epoch``, or three Nones if the checkpoint is missing."""
        model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{epoch}.cpkt'
        model_path = os.path.join(base_path, model_file)
        results_file = os.path.join(base_path,
                                    f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{epoch}_{eval_addition}.pkl')
        if not Path(model_path).is_file():
            return None, None, None
        return model_file, model_path, results_file

    if e == -1:
        # Probe from the newest epoch downwards; the first existing checkpoint wins.
        candidate_epochs = range(max_epoch, -1, -1)
    else:
        candidate_epochs = [e]

    model_file, results_file = None, None
    for epoch in candidate_epochs:
        model_file, _, results_file = check_file(epoch)
        if model_file is not None:
            break

    if model_file is None:
        print('No checkpoint found')
        return None

    print(f'Loading {model_file}')
    model, c = load_model(base_path, model_file, device, eval_positions=[], verbose=False)
    return model, c, results_file
def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positions_valid, eval_positions_test,
               bptt_valid, bptt_test, add_name, base_path, device='cpu', eval_addition='', **extra_tuning_args):
    """
    Differentiable model evaluation workflow. Evaluates and saves results to disk.

    Loads checkpoint ``i``/``e`` via :func:`load_model_workflow`, runs
    :func:`evaluate_differentiable_model` with the given datasets plus the parameters derived from
    the checkpoint config, and pickles the results to the workflow's results file.

    :param i: Model index of the checkpoint.
    :param e: Epoch of the checkpoint, or ``-1`` for the latest available one.
    :param valid_datasets: Datasets used to select the best style.
    :param test_datasets: Datasets used for the final evaluation.
    :param train_datasets: Datasets used for gradient-based style optimization.
    :param eval_positions_valid: Evaluation positions used during selection (``eval_positions``).
    :param eval_positions_test: Evaluation positions used in the final evaluation.
    :param bptt_valid: ``bptt`` used during selection.
    :param bptt_test: ``bptt`` used in the final evaluation (``bptt_final``).
    :param add_name: Checkpoint name suffix, see :func:`load_model_workflow`.
    :param base_path: Directory containing the ``models_diff`` folder.
    :param device: Device to load and evaluate the model on.
    :param eval_addition: Suffix for the results file name.
    :param extra_tuning_args: Extra keyword arguments forwarded to :func:`evaluate_differentiable_model`.
    :return: ``(r, model)`` where ``r`` is the pickled list
        ``[config, metrics, metrics_valid, style, temperature, optimization_route]``.
    :raises FileNotFoundError: if no matching checkpoint exists (previously a cryptic unpacking TypeError).
    """
    loaded = load_model_workflow(i, e, add_name, base_path, device, eval_addition)
    if loaded is None:
        raise FileNotFoundError(f'No checkpoint found for model {i} (epoch {e}) under {base_path}')
    model, c, results_file = loaded

    params = {'bptt': bptt_valid,
              'bptt_final': bptt_test,
              'eval_positions': eval_positions_valid,
              'eval_positions_test': eval_positions_test,
              'valid_datasets': valid_datasets,
              'test_datasets': test_datasets,
              'train_datasets': train_datasets,
              'verbose': True,
              'device': device}
    params.update(get_params_from_config(c))

    start = time.time()
    metrics, metrics_valid, style, temperature, optimization_route = evaluate_differentiable_model(
        model, **params, **extra_tuning_args)
    print('Evaluation time: ', time.time() - start)
    print(results_file)

    # These entries are dropped before pickling -- presumably they hold objects that cannot be
    # pickled (TODO confirm). Done before opening the file so a failure cannot leave an empty,
    # truncated results file behind, and tolerant of configs that lack the keys.
    saved_config = c.copy()
    saved_config.pop('num_features_used', None)
    saved_config.pop('categorical_features_sampler', None)

    r = [saved_config, metrics, metrics_valid, style.to('cpu'), temperature.to('cpu'), optimization_route]
    with open(results_file, 'wb') as output:
        pickle.dump(r, output)
    return r, model
""" | |
=============================== | |
INTERNAL HELPER FUNCTIONS | |
=============================== | |
""" | |
def evaluate_differentiable_model(model,
                                  valid_datasets,
                                  test_datasets,
                                  train_datasets,
                                  N_draws=100,
                                  N_grad_steps=10,
                                  eval_positions=None,
                                  eval_positions_test=None,
                                  bptt=100,
                                  bptt_final=200,
                                  style=None,
                                  n_parallel_configurations=1,
                                  device='cpu',
                                  selection_metric='auc',
                                  final_splits=(1, 2, 3, 4, 5),
                                  N_ensemble_configurations_list=(1, 5, 10, 20, 50, 100),
                                  **kwargs):
    """
    Evaluation function for diffable model evaluation. Returns a list of results.

    Draws ``N_draws`` random styles, optionally refines each with ``N_grad_steps`` steps of
    :func:`gradient_optimize_style`, and keeps the (style, softmax temperature) pair with the best
    mean validation selection metric. Both the raw draw and its optimized version are candidates.
    The winner is then evaluated on ``test_datasets`` once per split in ``final_splits`` and per entry
    of ``N_ensemble_configurations_list``, and once more on ``valid_datasets``.

    :param model: Indexable model bundle. ``model[2]`` is the network that is evaluated, and
        ``model[3]`` is passed to ``draw_random_style`` (presumably the prior data loader -- TODO confirm).
    :param valid_datasets: Datasets used to select the best configuration.
    :param test_datasets: Datasets used in the final evaluation.
    :param train_datasets: Datasets used for gradient-based optimization.
    :param N_draws: Number of random style initializations.
    :param N_grad_steps: Gradient steps per initialization (0 disables optimization).
    :param eval_positions: Evaluation positions during selection.
    :param eval_positions_test: Evaluation positions in the final evaluation.
    :param bptt: ``bptt`` during selection.
    :param bptt_final: ``bptt`` in the final evaluation.
    :param style: Unused -- styles are always drawn with ``draw_random_style``; kept for interface compatibility.
    :param n_parallel_configurations: Styles/temperatures drawn per initialization (concatenated on dim 0).
    :param device: Device for the model and the drawn tensors.
    :param selection_metric: ``'auc'``/``'roc'`` (maximized) or ``'ce'``/``'selection_metric'``
        (cross entropy, minimized).
    :param final_splits: Split numbers evaluated in the final test evaluation.
    :param N_ensemble_configurations_list: ``N_ensemble_configurations`` values for the final test evaluation.
    :param kwargs: Forwarded to :func:`gradient_optimize_style`.
    :return: ``(result_test, result_valid, best_style, best_softmax_temperature, optimization_routes)``.
    :raises ValueError: if ``selection_metric`` is not a supported name.
    """
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    diffable_metric = tabular_metrics.cross_entropy
    if selection_metric in ('auc', 'roc'):
        selection_metric_min_max = 'max'
        selection_metric = tabular_metrics.auc_metric
    elif selection_metric in ('ce', 'selection_metric'):
        selection_metric_min_max = 'min'
        selection_metric = tabular_metrics.cross_entropy
    else:
        # Previously selection_metric_min_max was left unbound here and the run crashed much later.
        raise ValueError(f"Unsupported selection_metric {selection_metric!r}; "
                         f"expected one of 'auc', 'roc', 'ce', 'selection_metric'")
    evaluation_metric = selection_metric
    print('Diffable metric', diffable_metric, ' Selection metric', selection_metric, ' Evaluation metric',
          evaluation_metric)
    print('N PARALLEL CONFIGURATIONS', n_parallel_configurations)
    print('eval_positions', eval_positions)

    def draw_styles():
        """Draw one random style per parallel configuration, concatenated along dim 0."""
        return torch.cat([draw_random_style(model[3], device).detach() for _ in range(n_parallel_configurations)], 0)

    def zero_temperatures():
        """Return a zero softmax temperature per parallel configuration."""
        return torch.cat([torch.tensor([0.0]).to(device) for _ in range(n_parallel_configurations)], 0)

    def evaluate_valid(style, softmax_temperature, results, results_tracked):
        """Score a configuration on the validation datasets and append per-position and mean scores."""
        result_valid = eval_step(valid_datasets, style, softmax_temperature=softmax_temperature,
                                 return_tensor=False, inference_mode=True, selection_metric=selection_metric,
                                 evaluation_metric=evaluation_metric, eval_positions=eval_positions, bptt=bptt,
                                 model=model[2])
        result_valid = [float(result_valid[f'mean_select_at_{pos}']) for pos in eval_positions]
        results += [result_valid]
        results_tracked += [np.nanmean(result_valid)]

    model[2].to(device)
    model[2].eval()

    results_on_valid, results_on_valid_tracked = [], []
    optimization_routes = []
    # Fallback returned when no draw is ever evaluated (e.g. N_draws == 0).
    best_style = draw_styles()
    best_softmax_temperature = zero_temperatures()

    def track_candidate(candidate_style, candidate_temperature):
        """Evaluate a candidate on validation data and keep a copy if it is the best so far."""
        nonlocal best_style, best_softmax_temperature
        evaluate_valid(candidate_style, candidate_temperature, results_on_valid, results_on_valid_tracked)
        latest = results_on_valid_tracked[-1]
        # nanmin/nanmax so a single NaN score cannot block every later update.
        if selection_metric_min_max == 'min':
            is_best = latest <= np.nanmin(results_on_valid_tracked)
        else:
            is_best = latest >= np.nanmax(results_on_valid_tracked)
        if is_best:
            best_style = candidate_style.clone()
            best_softmax_temperature = candidate_temperature.clone()

    for _ in tqdm(range(0, N_draws), desc='Iterate over Optimization initializations'):  # Evaluates N hparam draws
        style = draw_styles()
        softmax_temperature = zero_temperatures()

        # The raw draw is a candidate too; previously it was scored but could never become best_style.
        track_candidate(style, softmax_temperature)
        print(f'Draw --> Valid Selection metric: {results_on_valid[-1]}')

        if N_grad_steps > 0:
            # `model` is passed only positionally; the extra model=model[2] used to raise
            # "got multiple values for argument 'model'". eval_step's required arguments are
            # forwarded through gradient_optimize_style's **kwargs.
            # NOTE(review): the gradient signal is result['mean_metric'], which evaluate computes with
            # metric_used=selection_metric; with AUC it may carry no gradient -- confirm whether
            # diffable_metric was intended here.
            gradient_optimize_result = gradient_optimize_style(model, style, N_grad_steps,
                                                               softmax_temperature=softmax_temperature,
                                                               train_datasets=train_datasets,
                                                               valid_datasets=valid_datasets,
                                                               selection_metric_min_max=selection_metric_min_max,
                                                               selection_metric=selection_metric,
                                                               evaluation_metric=evaluation_metric,
                                                               eval_positions=eval_positions,
                                                               bptt=bptt,
                                                               **kwargs)
            optimization_routes.append(gradient_optimize_result['optimization_route'])
            track_candidate(gradient_optimize_result['best_style'], gradient_optimize_result['best_temperature'])
            print(f'After diff --> Valid Selection metric: {results_on_valid[-1]}')

        torch.cuda.empty_cache()

    def final_evaluation():
        """Evaluate the best configuration on the test datasets (per ensemble size and split) and on validation."""
        print('Running eval dataset with final params (no gradients)..')
        print(best_style, best_softmax_temperature)
        result_test = []
        for N_ensemble_configurations in N_ensemble_configurations_list:
            print(f'Running with {N_ensemble_configurations} ensemble_configurations')
            splits = []
            for split in final_splits:
                # N_ensemble_configurations is now actually passed on; it used to be written into
                # kwargs, which were never forwarded, so every ensemble size evaluated identically.
                splits += [eval_step(test_datasets, best_style, softmax_temperature=best_softmax_temperature,
                                     return_tensor=False, eval_positions=eval_positions_test,
                                     bptt=bptt_final, inference_mode=True, split_number=split, model=model[2],
                                     selection_metric=selection_metric, evaluation_metric=evaluation_metric,
                                     N_ensemble_configurations=N_ensemble_configurations)]
            result_test += [splits]

        print('Running valid dataset with final params (no gradients)..')
        result_valid = eval_step(valid_datasets, best_style, softmax_temperature=best_softmax_temperature,
                                 return_tensor=False, eval_positions=eval_positions_test,
                                 bptt=bptt_final, inference_mode=True, model=model[2],
                                 selection_metric=selection_metric, evaluation_metric=evaluation_metric)
        return result_test, result_valid

    result_test, result_valid = final_evaluation()
    return result_test, result_valid, best_style, best_softmax_temperature, optimization_routes
def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
    """
    Run ``evaluate`` with the transformer method on ``ds`` and attach aggregated scores.

    After evaluation, ``calculate_score_per_method`` is called with ``selection_metric`` under the
    name ``'select'`` and with ``evaluation_metric`` under ``'eval'`` (both mean-aggregated). It
    presumably adds its scores to the result in place; other code in this module reads
    ``mean_select`` / ``mean_select_at_{pos}`` from the returned dict.

    :param ds: Datasets to evaluate.
    :param used_style: Style tensor passed to ``evaluate`` as ``style``.
    :param selection_metric: Metric passed to ``evaluate`` as ``metric_used`` and scored as ``'select'``.
    :param evaluation_metric: Metric scored as ``'eval'``.
    :param eval_positions: Evaluation positions for ``evaluate`` and the scoring calls.
    :param return_tensor: When False, evaluation runs under ``torch.no_grad()``.
    :param kwargs: Extra keyword arguments for ``evaluate``. They override the defaults
        ``method``, ``overwrite``, ``save``, ``path_interfix``, ``base_path`` and ``verbose``.
        Previously a caller-supplied ``verbose`` (eval_model puts one into its params) raised
        "got multiple values for keyword argument".
    :return: The result returned by ``evaluate`` after scoring.
    """
    evaluate_options = {'method': 'transformer',
                        'overwrite': True,
                        'save': False,
                        'path_interfix': None,
                        'base_path': None,
                        'verbose': True}
    evaluate_options.update(kwargs)

    def step():
        return evaluate(datasets=ds,
                        style=used_style,
                        eval_positions=eval_positions,
                        metric_used=selection_metric,
                        **evaluate_options)

    if return_tensor:
        r = step()
    else:
        with torch.no_grad():
            r = step()

    calculate_score_per_method(selection_metric, 'select', r, ds, eval_positions, aggregator='mean')
    calculate_score_per_method(evaluation_metric, 'eval', r, ds, eval_positions, aggregator='mean')
    return r
def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets,
                            learning_rate=0.03, optimize_all=False, limit_style=True, N_datasets_sampled=90,
                            optimize_softmax_temperature=True, selection_metric_min_max='max', **kwargs):
    """
    Uses gradient based methods to optimize 'style' on the 'train_datasets' and uses stopping with 'valid_datasets'.

    Each step does four things:
    1. Back-propagates the diffable metric (``result['mean_metric']`` from :func:`eval_step`) over a
       random subset of ``train_datasets``.
    2. Scores the current style on ``valid_datasets``.
    3. Remembers the style/temperature with the best validation selection metric.
    4. Takes one Adam step.

    :param model: Model bundle. ``model[2]`` is the network passed to :func:`eval_step`, and it is optimized
        instead of the style when ``optimize_all`` is set.
    :param init_style: Starting style tensor. It is copied, so the caller's tensor is not modified.
    :param steps: Number of optimization steps.
    :param softmax_temperature: Starting softmax temperature tensor (also copied).
    :param train_datasets: Sequence of datasets sampled for the gradient signal.
    :param valid_datasets: Datasets used to pick the best step (no gradient propagation).
    :param learning_rate: Adam learning rate.
    :param optimize_all: Optimize ``model[2].parameters()`` instead of style and temperature.
    :param limit_style: Clamp the style to ``[-1.74, 1.74]`` after every step.
    :param N_datasets_sampled: Training datasets sampled per step (capped at ``len(train_datasets)``).
    :param optimize_softmax_temperature: Whether the softmax temperature receives gradients.
    :param selection_metric_min_max: ``'min'`` or ``'max'`` -- the direction in which the validation
        selection metric improves.
    :param kwargs: Forwarded to :func:`eval_step`. They must supply ``selection_metric``,
        ``evaluation_metric`` and ``eval_positions``.
    :return: dict with ``best_style``, ``best_temperature`` and ``optimization_route``
        (per-step train/valid selection and diffable metrics).
    """
    # Clone so the Parameters do not share storage with (and Adam does not mutate) the caller's tensors.
    grad_style = torch.nn.Parameter(init_style.detach().clone(), requires_grad=True)
    softmax_temperature = torch.nn.Parameter(softmax_temperature.detach().clone(),
                                             requires_grad=optimize_softmax_temperature)
    best_style, best_temperature = grad_style.detach().clone(), softmax_temperature.detach().clone()
    best_selection_metric = None

    variables_to_optimize = model[2].parameters() if optimize_all else [grad_style, softmax_temperature]
    optimizer = torch.optim.Adam(variables_to_optimize, lr=learning_rate)

    optimization_route_selection, optimization_route_diffable = [], []
    optimization_route_selection_valid, optimization_route_diffable_valid = [], []

    def eval_opt(ds, return_tensor=True, inference_mode=False):
        """Evaluate the current style on ``ds``; return (diffable metric, selection metric)."""
        result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor,
                           inference_mode=inference_mode, model=model[2], **kwargs)
        return result['mean_metric'], result['mean_select']

    def eval_all_datasets(datasets, propagate=True):
        """Evaluate each dataset on its own, optionally accumulating gradients; return nan-means of both metrics."""
        selection_metrics_this_step, diffable_metrics_this_step = [], []
        for ds in datasets:
            diffable_metric_train, selection_metric_train = eval_opt([ds], inference_mode=(not propagate))
            if not torch.isnan(diffable_metric_train).any():
                if propagate and diffable_metric_train.requires_grad:
                    diffable_metric_train.backward()
                selection_metrics_this_step += [selection_metric_train]
                diffable_metrics_this_step += [float(diffable_metric_train.detach().cpu().numpy())]
        return np.nanmean(diffable_metrics_this_step), np.nanmean(selection_metrics_this_step)

    for t in tqdm(range(steps), desc='Iterate over Optimization steps'):
        optimizer.zero_grad()

        # Reproducible subset per step. A local RNG yields the same sample as random.seed(t) followed by
        # random.sample, without resetting the module-global RNG seeded by the caller. The sample size is
        # capped so that having fewer datasets than N_datasets_sampled does not raise ValueError.
        n_sampled = min(N_datasets_sampled, len(train_datasets))
        train_datasets_ = random.Random(t).sample(train_datasets, n_sampled)

        # Get score on train
        diffable_metric_train, selection_metric_train = eval_all_datasets(train_datasets_, propagate=True)
        optimization_route_selection += [float(selection_metric_train)]
        optimization_route_diffable += [float(diffable_metric_train)]

        # Get score on valid
        diffable_metric_valid, selection_metric_valid = eval_all_datasets(valid_datasets, propagate=False)
        optimization_route_selection_valid += [float(selection_metric_valid)]
        optimization_route_diffable_valid += [float(diffable_metric_valid)]

        # None is checked before any comparison: `None > x` raised TypeError on the first step.
        # NaN scores are never recorded as best, so they cannot block later improvements.
        if np.isnan(selection_metric_valid):
            is_best = False
        elif best_selection_metric is None:
            is_best = True
        elif selection_metric_min_max == 'min':
            is_best = selection_metric_valid < best_selection_metric
        else:
            is_best = selection_metric_min_max == 'max' and selection_metric_valid > best_selection_metric

        if is_best:
            print('New best', best_selection_metric, selection_metric_valid)
            best_style = grad_style.detach().clone()
            best_temperature = softmax_temperature.detach().clone()
            best_selection_metric = selection_metric_valid

        optimizer.step()
        if limit_style:
            # Clamp in place so grad_style remains the Parameter the optimizer updates. Rebinding it to a
            # detached copy silently stopped style optimization after the first step.
            with torch.no_grad():
                grad_style.clamp_(-1.74, 1.74)

        print(f'Valid: Diffable metric={diffable_metric_valid} Selection metric={selection_metric_valid};' +
              f'Train: Diffable metric={diffable_metric_train} Selection metric={selection_metric_train}')

    print(f'Return best:{best_style} {best_selection_metric}')
    return {'best_style': best_style, 'best_temperature': best_temperature,
            'optimization_route': {'select': optimization_route_selection, 'loss': optimization_route_diffable,
                                   'test_select': optimization_route_selection_valid,
                                   'test_loss': optimization_route_diffable_valid}}