TabPFNEvaluationDemo / TabPFN /scripts /tabular_metrics.py
Samuel Mueller
init
e487255
"""
===============================
Metrics calculation
===============================
Includes a few metrics as well as functions for composing metrics from results files.
"""
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score, average_precision_score
from scipy.stats import rankdata
import pandas as pd
"""
===============================
Metrics calculation
===============================
"""
def auc_metric(target, pred, multi_class='ovo', numpy=False):
    """
    Computes the ROC AUC of 'pred' against 'target' via 'roc_auc_score'.

    :param target: Ground-truth labels. Converted to a tensor when 'numpy' is False.
    :param pred: Predicted scores. When 'target' holds at most two distinct values and
        'pred' is 2-D, only column 1 of 'pred' is scored.
    :param multi_class: Forwarded to 'roc_auc_score' when 'target' holds more than two
        distinct values.
    :param numpy: If True, numpy is used to count the distinct targets and the plain score
        is returned; otherwise the score is wrapped in a 'torch.tensor'.
    :return: The score, or 'np.nan' when a ValueError was raised (the error is printed).
        The NaN fallback is a plain float in both modes.
    """
    try:
        if numpy:
            n_distinct = len(np.unique(target))
        else:
            if not torch.is_tensor(target):
                target = torch.tensor(target)
            if not torch.is_tensor(pred):
                pred = torch.tensor(pred)
            n_distinct = len(torch.unique(target))

        if n_distinct > 2:
            score = roc_auc_score(target, pred, multi_class=multi_class)
        else:
            # A 2-D prediction matrix is reduced to its second column.
            if len(pred.shape) == 2:
                pred = pred[:, 1]
            score = roc_auc_score(target, pred)

        return score if numpy else torch.tensor(score)
    except ValueError as e:
        print(e)
        return np.nan
def accuracy_metric(target, pred):
    """
    Computes 'accuracy_score' of hard predictions derived from 'pred'.

    :param target: Ground-truth labels (converted to a tensor if needed).
    :param pred: Per-class predictions (converted to a tensor if needed).
    :return: The accuracy wrapped in a 'torch.tensor'.
    """
    if not torch.is_tensor(target):
        target = torch.tensor(target)
    if not torch.is_tensor(pred):
        pred = torch.tensor(pred)

    more_than_two_classes = len(torch.unique(target)) > 2
    if more_than_two_classes:
        # Hard label = index of the largest entry along the last axis.
        hard_pred = torch.argmax(pred, -1)
    else:
        # Hard label = whether column 1 exceeds 0.5.
        hard_pred = pred[:, 1] > 0.5
    return torch.tensor(accuracy_score(target, hard_pred))
def average_precision_metric(target, pred):
    """
    Computes 'average_precision_score' of hard predictions derived from 'pred'.

    NOTE(review): the score function receives hard labels (argmax, or column 1 thresholded
    at 0.5) rather than the raw scores in 'pred' -- confirm this is intended.

    :param target: Ground-truth labels (converted to a tensor if needed).
    :param pred: Per-class predictions (converted to a tensor if needed).
    :return: The average precision wrapped in a 'torch.tensor'.
    """
    if not torch.is_tensor(target):
        target = torch.tensor(target)
    if not torch.is_tensor(pred):
        pred = torch.tensor(pred)

    more_than_two_classes = len(torch.unique(target)) > 2
    if more_than_two_classes:
        hard_pred = torch.argmax(pred, -1)
    else:
        hard_pred = pred[:, 1] > 0.5
    return torch.tensor(average_precision_score(target, hard_pred))
def balanced_accuracy_metric(target, pred):
    """
    Computes 'balanced_accuracy_score' of hard predictions derived from 'pred'.

    :param target: Ground-truth labels (converted to a tensor if needed).
    :param pred: Per-class predictions (converted to a tensor if needed).
    :return: The balanced accuracy wrapped in a 'torch.tensor'.
    """
    if not torch.is_tensor(target):
        target = torch.tensor(target)
    if not torch.is_tensor(pred):
        pred = torch.tensor(pred)

    more_than_two_classes = len(torch.unique(target)) > 2
    if more_than_two_classes:
        # Hard label = index of the largest entry along the last axis.
        hard_pred = torch.argmax(pred, -1)
    else:
        # Hard label = whether column 1 exceeds 0.5.
        hard_pred = pred[:, 1] > 0.5
    return torch.tensor(balanced_accuracy_score(target, hard_pred))
def cross_entropy(target, pred):
    """
    Computes a cross-entropy loss between 'pred' and 'target'.

    With more than two distinct target values, 'torch.nn.CrossEntropyLoss' is applied to
    'pred' as given; otherwise 'torch.nn.BCELoss' is applied to column 1 of 'pred'.

    NOTE(review): per the PyTorch docs CrossEntropyLoss expects unnormalised logits while
    BCELoss expects probabilities -- confirm what 'pred' contains in the multi-class case.

    :param target: Ground-truth labels (converted to a tensor if needed).
    :param pred: Per-class predictions (converted to a tensor if needed).
    :return: The loss as a 0-d tensor.
    """
    if not torch.is_tensor(target):
        target = torch.tensor(target)
    if not torch.is_tensor(pred):
        pred = torch.tensor(pred)

    if len(torch.unique(target)) > 2:
        multiclass_loss = torch.nn.CrossEntropyLoss()
        return multiclass_loss(pred.float(), target.long())

    binary_loss = torch.nn.BCELoss()
    return binary_loss(pred[:, 1].float(), target.float())
def time_metric():
    """
    Placeholder with no computation of its own; it only serves as a marker.
    'calculate_score_per_method' compares its 'metric' argument against this function and,
    on a match, copies the stored '<dataset>_time_at_<pos>' entry instead of calling a metric.
    """
    return None
def count_metric(x, y):
    """
    Constant metric: ignores both arguments and always returns 1 (one count per call).
    """
    return 1
"""
===============================
Metrics composition
===============================
"""
def calculate_score_per_method(metric, name:str, global_results:dict, ds:list, eval_positions:list, aggregator:str='mean'):
    """
    Calculates the metric given by 'metric' and saves it under 'name' in 'global_results'.

    Reads '<dataset>_outputs_at_<pos>' and '<dataset>_ys_at_<pos>' (or '<dataset>_time_at_<pos>'
    when 'metric' is 'time_metric') and writes, in place:
      - '<dataset>_<name>_at_<pos>': score per dataset and position (NaN if missing or failed)
      - '<aggregator>_<name>_at_<pos>': aggregate over datasets per position
      - '<dataset>_<aggregator>_<name>': aggregate over positions per dataset
      - '<aggregator>_<name>': aggregate over the per-position aggregates
    where '<dataset>' is the first element of each entry in 'ds'.

    :param metric: Metric function, called as metric(y[split], preds[split]) for every index
        along the first axis of the targets
    :param name: Name of metric in 'global_results'
    :param global_results: Dictionary containing the results for current method for a collection of datasets
    :param ds: Dataset to calculate metrics on, a list of dataset properties
    :param eval_positions: List of positions to calculate metrics on
    :param aggregator: Specifies way to aggregate results across evaluation positions;
        'mean' selects np.nanmean, any other value selects np.nansum
    :return: None, results are stored in 'global_results'
    """
    aggregator_f = np.nansum
    if aggregator == 'mean':
        aggregator_f = np.nanmean

    def aggregate_non_nan(values):
        # Drop NaN entries first; if nothing is left, yield NaN without calling the aggregator.
        kept = [v for v in values if not np.isnan(v)]
        return aggregator_f(kept) if len(kept) > 0 else np.nan

    for pos in eval_positions:
        valid_positions = 0
        for d in ds:
            # No stored outputs for this dataset/position: record NaN and move on.
            if f'{d[0]}_outputs_at_{pos}' not in global_results:
                global_results[f'{d[0]}_{name}_at_{pos}'] = np.nan
                continue

            preds = global_results[f'{d[0]}_outputs_at_{pos}']
            y = global_results[f'{d[0]}_ys_at_{pos}']
            if torch.is_tensor(preds):
                preds = preds.detach().cpu().numpy()
            if torch.is_tensor(y):
                y = y.detach().cpu().numpy()

            try:
                if metric == time_metric:
                    # The time "metric" is not computed, the stored timing is copied over.
                    score = global_results[f'{d[0]}_time_at_{pos}']
                else:
                    score = aggregator_f(
                        [metric(y[split], preds[split]) for split in range(y.shape[0])])
                global_results[f'{d[0]}_{name}_at_{pos}'] = score
                valid_positions += 1
            except Exception as err:
                # Best effort: a failing dataset is reported and recorded as NaN.
                print(f'Error calculating metric with {err}, {type(err)} at {d[0]} {pos} {name}')
                global_results[f'{d[0]}_{name}_at_{pos}'] = np.nan

        if valid_positions > 0:
            per_dataset = [global_results[f'{d[0]}_{name}_at_{pos}'] for d in ds]
            global_results[f'{aggregator}_{name}_at_{pos}'] = aggregator_f(per_dataset)
        else:
            global_results[f'{aggregator}_{name}_at_{pos}'] = np.nan

    for d in ds:
        global_results[f'{d[0]}_{aggregator}_{name}'] = aggregate_non_nan(
            [global_results[f'{d[0]}_{name}_at_{pos}'] for pos in eval_positions])

    global_results[f'{aggregator}_{name}'] = aggregate_non_nan(
        [global_results[f'{aggregator}_{name}_at_{pos}'] for pos in eval_positions])
def calculate_score(metric, name, global_results, ds, eval_positions, aggregator='mean', limit_to=''):
    """
    Calls calculate_score_per_method for a range of methods. See arguments of that method.

    :param global_results: Dictionary mapping a method key to that method's results dictionary
    :param limit_to: Only methods whose key contains this string get metric calculations;
        the default '' matches every key.
    """
    for method_key in global_results:
        if limit_to in method_key:
            calculate_score_per_method(metric, name, global_results[method_key], ds, eval_positions, aggregator=aggregator)
def make_metric_matrix(global_results, methods, pos, name, ds):
    """
    Builds per-dataset mean and standard-deviation tables of a metric, grouped by method.

    :param global_results: Dictionary mapping a result key to a results dictionary that holds
        '<dataset>_<name>_at_<pos>' entries. The last 8 characters of every key are dropped
        to form the column label (NOTE(review): presumably a fixed-length suffix -- confirm).
    :param methods: Method names; a column belongs to a method when the method name is a
        substring of the column label.
    :param pos: Evaluation position to read.
    :param name: Name of the metric to read.
    :param ds: List of dataset properties; the first element of each entry is the dataset name.
    :return: Tuple (matrix_means, matrix_stds) of DataFrames indexed by dataset name with one
        column per method.
    """
    rows = [
        [global_results[m][d[0] + '_' + name + '_at_' + str(pos)] for d in ds]
        for m in global_results
    ]
    result = pd.DataFrame(
        np.array(rows).T,
        index=[d[0] for d in ds],
        columns=[k[:-8] for k in list(global_results.keys())],
    )

    matrix_means, matrix_stds = [], []
    for method in methods:
        # Boolean column mask: every column whose label contains the method name.
        belongs_to_method = [(method) in c for c in result.columns]
        method_columns = result.iloc[:, belongs_to_method]
        matrix_means.append(method_columns.mean(axis=1))
        matrix_stds.append(method_columns.std(axis=1))

    matrix_means = pd.DataFrame(matrix_means, index=methods).T
    matrix_stds = pd.DataFrame(matrix_stds, index=methods).T
    return matrix_means, matrix_stds
def make_ranks_and_wins_table(matrix):
    """
    Ranks the columns of 'matrix' within every row and summarises the ranks per column.

    Values are rounded to 3 decimals and negated before ranking, so the largest value in a
    row receives rank 1 and values equal after rounding share a tied rank.

    Note: 'matrix' is overwritten in place with the ranks; pass a copy to keep the values.

    :param matrix: DataFrame with one row per dataset and one column per method.
    :return: Tuple (ranks_acc, wins_acc): the mean rank per column and the number of rows in
        which the column's rank is exactly 1.
    """
    # Round once up front instead of re-rounding the whole matrix for every row.
    rounded = matrix.round(3)
    for dss in matrix.index:
        matrix.loc[dss] = rankdata(-rounded.loc[dss])

    ranks_acc = matrix.mean()
    wins_acc = (matrix == 1).sum()
    return ranks_acc, wins_acc