# File size: 15,573 Bytes
# e487255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import os
import torch
import numpy as np
import time
import pickle
from scripts import  tabular_metrics
from scripts.tabular_metrics import calculate_score_per_method
from scripts.tabular_evaluation import evaluate
from priors.differentiable_prior import draw_random_style
from tqdm import tqdm
import random
from scripts.transformer_prediction_interface import get_params_from_config, load_model_workflow

"""
===============================
PUBLIC FUNCTIONS FOR EVALUATION
===============================
"""


def eval_model_range(i_range, *args, **kwargs):
    """
    Runs `eval_model` once for every entry of `i_range`.

    :param i_range: Iterable of values; each one is passed as the first positional argument of `eval_model`.
    :param args: Positional arguments forwarded unchanged to every `eval_model` call.
    :param kwargs: Keyword arguments forwarded unchanged to every `eval_model` call.
    :return: None; the values returned by `eval_model` are discarded.
    """
    for model_index in i_range:
        eval_model(model_index, *args, **kwargs)



def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positions_valid, eval_positions_test,
               bptt_valid,
               bptt_test, add_name, base_path, device='cpu', eval_addition='', **extra_tuning_args):
    """
    Differentiable model evaluation workflow. Evaluates and saves results to disk.

    Loads a model through `load_model_workflow`, runs `evaluate_differentiable_model` on it and pickles
    the outcome to the results file returned by `load_model_workflow`.

    :param i: Forwarded to `load_model_workflow` (presumably the model index -- confirm there).
    :param e: Forwarded to `load_model_workflow` (presumably the epoch/checkpoint -- confirm there).
    :param valid_datasets: Datasets passed on as `valid_datasets`.
    :param test_datasets: Datasets passed on as `test_datasets`.
    :param train_datasets: Datasets passed on as `train_datasets`.
    :param eval_positions_valid: Passed on as `eval_positions` (used during validation).
    :param eval_positions_test: Passed on as `eval_positions_test` (used for the final evaluation).
    :param bptt_valid: Passed on as `bptt`.
    :param bptt_test: Passed on as `bptt_final`.
    :param add_name: Forwarded to `load_model_workflow`.
    :param base_path: Forwarded to `load_model_workflow`.
    :param device: Device the model is evaluated on; also forwarded to `load_model_workflow`.
    :param eval_addition: Forwarded to `load_model_workflow`.
    :param extra_tuning_args: Additional keyword arguments for `evaluate_differentiable_model`.
    :return: Tuple `(r, model)` where `r` is the pickled list
        `[config, metrics, metrics_valid, style, temperature, optimization_route]` (tensors moved to CPU).
    """
    model, c, results_file = load_model_workflow(i, e, add_name, base_path, device, eval_addition)
    params = {
        'bptt': bptt_valid,
        'bptt_final': bptt_test,
        'eval_positions': eval_positions_valid,
        'eval_positions_test': eval_positions_test,
        'valid_datasets': valid_datasets,
        'test_datasets': test_datasets,
        'train_datasets': train_datasets,
        'verbose': True,
        'device': device,
    }

    # Values derived from the model config take precedence over the ones set above.
    params.update(get_params_from_config(c))

    start = time.time()
    metrics, metrics_valid, style, temperature, optimization_route = evaluate_differentiable_model(
        model, **params, **extra_tuning_args)
    print('Evaluation time: ', time.time() - start)

    print(results_file)

    # Strip these entries from a copy of the config before the file is opened, so that a config
    # lacking one of them neither raises nor leaves a truncated results file behind.
    # NOTE(review): presumably they are dropped because they cannot be pickled -- confirm.
    config_to_save = c.copy()
    for key in ('num_features_used', 'categorical_features_sampler'):
        config_to_save.pop(key, None)

    r = [config_to_save, metrics, metrics_valid, style.to('cpu'), temperature.to('cpu'), optimization_route]
    with open(results_file, 'wb') as output:
        pickle.dump(r, output)

    return r, model

"""
===============================
INTERNAL HELPER FUNCTIONS
===============================
"""

def evaluate_differentiable_model(model
                                  , valid_datasets
                                  , test_datasets
                                  , train_datasets
                                  , N_draws=100
                                  , N_grad_steps=10
                                  , eval_positions=None
                                  , eval_positions_test=None
                                  , bptt=100
                                  , bptt_final=200
                                  , style=None
                                  , n_parallel_configurations=1
                                  , device='cpu'
                                  , selection_metric='auc'
                                  , final_splits=[1, 2, 3, 4, 5]
                                  , N_ensemble_configurations_list=[1, 5, 10, 20, 50, 100]
                                  , **kwargs):
    """
    Evaluation function for diffable model evaluation. Returns a list of results.

    Draws `N_draws` random styles, optionally refines each one with `gradient_optimize_style`, keeps the
    style/temperature pair with the best validation selection metric and finally evaluates that pair on
    the test and validation datasets without gradients.

    :param model: Indexable model container; `model[2]` is the network that is moved to `device` and
        evaluated, `model[3]` is handed to `draw_random_style`.
    :param valid_datasets: Datasets used for model selection and the final validation run.
    :param test_datasets: Datasets used for the final test run.
    :param train_datasets: Datasets handed to `gradient_optimize_style`.
    :param N_draws: Number of random style initializations that are tried.
    :param N_grad_steps: Gradient steps per initialization; with 0 the drawn style itself is the candidate.
    :param eval_positions: Evaluation positions used during selection (must be iterable when `N_draws > 0`).
    :param eval_positions_test: Evaluation positions used for the final runs.
    :param bptt: `bptt` value used during selection.
    :param bptt_final: `bptt` value used for the final runs.
    :param style: Currently unused (every style is drawn randomly); kept for interface compatibility.
    :param n_parallel_configurations: Number of styles/temperatures concatenated into one candidate.
    :param device: Device on which the model and the style tensors live.
    :param selection_metric: 'auc'/'roc' (maximized) or 'ce'/'selection_metric' (minimized).
    :param final_splits: Split numbers evaluated on the test datasets (only iterated, never modified).
    :param N_ensemble_configurations_list: Ensemble sizes iterated in the final test evaluation
        (only iterated, never modified).
    :param kwargs: Extra keyword arguments forwarded to `gradient_optimize_style`.
    :return: Tuple `(result_test, result_valid, best_style, best_softmax_temperature, optimization_routes)`.
    :raises ValueError: If `selection_metric` is not recognised although draws have to be ranked.
    """
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    diffable_metric = tabular_metrics.cross_entropy
    evaluation_metric = tabular_metrics.auc_metric
    selection_metric_min_max = None
    if selection_metric in ('auc', 'roc'):
        selection_metric_min_max = 'max'
        selection_metric = tabular_metrics.auc_metric
        evaluation_metric = selection_metric
    elif selection_metric in ('ce', 'selection_metric'):
        selection_metric_min_max = 'min'
        selection_metric = tabular_metrics.cross_entropy
        evaluation_metric = selection_metric

    # The optimization direction is only needed to rank draws; fail with a clear message instead of
    # hitting an unbound variable in the middle of the loop.
    if N_draws > 0 and selection_metric_min_max is None:
        raise ValueError(f"Unknown selection_metric {selection_metric!r}; "
                         f"expected one of 'auc', 'roc', 'ce', 'selection_metric'.")

    print('Diffable metric', diffable_metric, ' Selection metric', selection_metric, ' Evaluation metric',
          evaluation_metric)
    print('N PARALLEL CONFIGURATIONS', n_parallel_configurations)
    print('eval_positions', eval_positions)

    def draw_styles():
        # One freshly drawn (detached) style per parallel configuration, concatenated along dim 0.
        return torch.cat([draw_random_style(model[3], device).detach()
                          for _ in range(n_parallel_configurations)], 0)

    def zero_temperatures():
        # One softmax temperature of 0.0 per parallel configuration.
        return torch.zeros(n_parallel_configurations, device=device)

    def evaluate_valid(style, softmax_temperature, results, results_tracked):
        # Appends the per-position validation scores to `results` and their NaN-aware mean to
        # `results_tracked`.
        result_valid = eval_step(valid_datasets, style, softmax_temperature=softmax_temperature,
                                 return_tensor=False, inference_mode=True, selection_metric=selection_metric,
                                 evaluation_metric=evaluation_metric, eval_positions=eval_positions, bptt=bptt,
                                 model=model[2])
        result_valid = [float(result_valid[f'mean_select_at_{pos}']) for pos in eval_positions]
        results += [result_valid]
        results_tracked += [np.nanmean(result_valid)]

    model[2].to(device)
    model[2].eval()

    results_on_valid, results_on_valid_tracked = [], []
    optimization_routes = []

    # Fallback that is returned when no draw is ever selected (e.g. N_draws == 0).
    best_style = draw_styles()
    best_softmax_temperature = zero_temperatures()

    # `gradient_optimize_style` forwards its extra keyword arguments to `eval_step`, which requires the
    # metrics and evaluation positions; caller supplied kwargs take precedence.
    # NOTE(review): `gradient_optimize_style` reads result['mean_metric'] -- confirm that `evaluate`
    # provides that key for the metric passed here.
    optimize_kwargs = {
        'selection_metric': selection_metric,
        'evaluation_metric': evaluation_metric,
        'eval_positions': eval_positions,
        'bptt': bptt,
        **kwargs,
    }

    for _ in tqdm(range(0, N_draws), desc='Iterate over Optimization initializations'): # Evaluates N hparam draws
        style = draw_styles()
        softmax_temperature = zero_temperatures()

        evaluate_valid(style, softmax_temperature, results_on_valid, results_on_valid_tracked)

        print(f'Draw --> Valid Selection metric: {results_on_valid[-1]}')

        # Without gradient steps the drawn style itself is the candidate for this draw.
        candidate_style, candidate_temperature = style, softmax_temperature

        if N_grad_steps > 0:
            # `model` is passed positionally only; passing it again by keyword raises a TypeError.
            gradient_optimize_result = gradient_optimize_style(model, style, N_grad_steps
                                                               , softmax_temperature=softmax_temperature
                                                               , train_datasets=train_datasets
                                                               , valid_datasets=valid_datasets
                                                               , selection_metric_min_max=selection_metric_min_max
                                                               , **optimize_kwargs)
            optimization_routes += [gradient_optimize_result['optimization_route']]
            candidate_style = gradient_optimize_result['best_style']
            candidate_temperature = gradient_optimize_result['best_temperature']

            evaluate_valid(candidate_style, candidate_temperature, results_on_valid, results_on_valid_tracked)

            print(f'After diff --> Valid Selection metric: {results_on_valid[-1]}')

        # The most recent tracked score belongs to the candidate; it wins if it is at least as good as
        # every score seen so far (including the un-optimized draws).
        if selection_metric_min_max == 'min':
            is_best = (results_on_valid_tracked[-1] <= min(results_on_valid_tracked))
        else:
            is_best = (results_on_valid_tracked[-1] >= max(results_on_valid_tracked))

        if is_best:
            best_style = candidate_style.clone()
            best_softmax_temperature = candidate_temperature.clone()
    torch.cuda.empty_cache()

    def final_evaluation():
        print('Running eval dataset with final params (no gradients)..')
        print(best_style, best_softmax_temperature)
        result_test = []
        for N_ensemble_configurations in N_ensemble_configurations_list:
            print(f'Running with {N_ensemble_configurations} ensemble_configurations')
            # NOTE(review): this value is stored in `kwargs` but `kwargs` is not passed to `eval_step`
            # below, so every ensemble size currently runs the same evaluation -- confirm intent.
            kwargs['N_ensemble_configurations'] = N_ensemble_configurations
            splits = []
            for split in final_splits:
                splits += [eval_step(test_datasets, best_style, softmax_temperature=best_softmax_temperature
                                     , return_tensor=False, eval_positions=eval_positions_test,
                                     bptt=bptt_final, inference_mode=True, split_number=split, model=model[2]
                                     , selection_metric=selection_metric, evaluation_metric=evaluation_metric)]
            result_test += [splits]

        print('Running valid dataset with final params (no gradients)..')
        result_valid = eval_step(valid_datasets, best_style, softmax_temperature=best_softmax_temperature
                                 , return_tensor=False, eval_positions=eval_positions_test,
                                 bptt=bptt_final, inference_mode=True, model=model[2]
                                 , selection_metric=selection_metric, evaluation_metric=evaluation_metric)

        return result_test, result_valid

    result_test, result_valid = final_evaluation()

    return result_test, result_valid, best_style, best_softmax_temperature, optimization_routes


def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
    """
    Runs `evaluate` for the 'transformer' method on `ds` and scores the result with both metrics.

    :param ds: Datasets to evaluate; passed to `evaluate` as `datasets`.
    :param used_style: Style passed to `evaluate` as `style`.
    :param selection_metric: Metric passed to `evaluate` as `metric_used` and scored under the name 'select'.
    :param evaluation_metric: Metric scored under the name 'eval'.
    :param eval_positions: Evaluation positions used by `evaluate` and by the scoring.
    :param return_tensor: If False, `evaluate` is run inside `torch.no_grad()`; if True, the current
        gradient mode is left untouched.
    :param kwargs: Additional keyword arguments forwarded to `evaluate`.
    :return: The object returned by `evaluate`, after both `calculate_score_per_method` calls
        (which appear to add their scores to it in place -- their return values are not used).
    """
    fixed_arguments = dict(datasets=ds,
                           method='transformer',
                           overwrite=True,
                           style=used_style,
                           eval_positions=eval_positions,
                           metric_used=selection_metric,
                           save=False,
                           path_interfix=None,
                           base_path=None,
                           verbose=True)

    if not return_tensor:
        # The caller does not need gradients, so skip building the autograd graph.
        with torch.no_grad():
            r = evaluate(**fixed_arguments, **kwargs)
    else:
        r = evaluate(**fixed_arguments, **kwargs)

    for metric, score_name in ((selection_metric, 'select'), (evaluation_metric, 'eval')):
        calculate_score_per_method(metric, score_name, r, ds, eval_positions, aggregator='mean')

    return r


def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets, learning_rate=0.03, optimize_all=False,
                            limit_style=True, N_datasets_sampled=90, optimize_softmax_temperature=True, selection_metric_min_max='max', **kwargs):
    """
    Uses gradient based methods to optimize 'style' on the 'train_datasets' and uses stopping with 'valid_datasets'.

    Each step samples training datasets, accumulates the gradient of the differentiable metric over
    them, scores the current style on the validation datasets, remembers the best style seen so far
    and then applies one Adam update.

    :param model: Indexable model container; `model[2]` is the network used for evaluation.
    :param init_style: Initial style tensor; it is copied, the caller's tensor is not modified.
    :param steps: Number of optimization steps.
    :param learning_rate: Learning rate of the Adam optimizer.
    :param softmax_temperature: Initial softmax temperature tensor; it is copied as well.
    :param train_datasets: Sequence of datasets from which a subset is sampled at every step.
    :param valid_datasets: Datasets used to select the best style.
    :param optimize_all: If True, the parameters of `model[2]` are optimized instead of style/temperature.
    :param limit_style: If True, the style is clamped to [-1.74, 1.74] after every update.
    :param N_datasets_sampled: Number of training datasets sampled per step (capped at the number available).
    :param optimize_softmax_temperature: Whether the softmax temperature receives gradients.
    :param selection_metric_min_max: 'min' or 'max', the direction in which the selection metric improves.
    :param kwargs: Additional keyword arguments forwarded to `eval_step`.
    :return: Dict with 'best_style', 'best_temperature' and 'optimization_route' (per-step metric lists
        under 'select', 'loss', 'test_select' and 'test_loss').
    """
    # Clone so that the in-place optimizer/clamp updates never write into the caller's tensors.
    grad_style = torch.nn.Parameter(init_style.detach().clone(), requires_grad=True)

    best_style, best_temperature = init_style.detach().clone(), softmax_temperature.detach().clone()
    best_selection_metric = None
    softmax_temperature = torch.nn.Parameter(softmax_temperature.detach().clone(),
                                             requires_grad=optimize_softmax_temperature)
    variables_to_optimize = model[2].parameters() if optimize_all else [grad_style, softmax_temperature]
    optimizer = torch.optim.Adam(variables_to_optimize, lr=learning_rate)

    optimization_route_selection, optimization_route_diffable = [], []
    optimization_route_selection_valid, optimization_route_diffable_valid = [], []

    def eval_opt(ds, return_tensor=True, inference_mode=False):
        result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor
                           , inference_mode=inference_mode, model=model[2], **kwargs)

        diffable_metric = result['mean_metric']
        selection_metric = result['mean_select']

        return diffable_metric, selection_metric

    def eval_all_datasets(datasets, propagate=True):
        # Evaluates every dataset on its own; with `propagate` the gradients of all datasets are
        # accumulated by calling backward() once per dataset.
        selection_metrics_this_step, diffable_metrics_this_step = [], []
        for ds in datasets:
            diffable_metric_train, selection_metric_train = eval_opt([ds], inference_mode=(not propagate))
            if not torch.isnan(diffable_metric_train).any():
                if propagate and diffable_metric_train.requires_grad:
                    diffable_metric_train.backward()
                selection_metrics_this_step += [selection_metric_train]
                diffable_metrics_this_step += [float(diffable_metric_train.detach().cpu().numpy())]
        diffable_metric_train = np.nanmean(diffable_metrics_this_step)
        selection_metric_train = np.nanmean(selection_metrics_this_step)

        return diffable_metric_train, selection_metric_train

    def improves(candidate, incumbent):
        # The first score always becomes the incumbent; afterwards a NaN candidate never wins and a
        # NaN incumbent is replaced by any valid score.
        if incumbent is None:
            return True
        if np.isnan(candidate):
            return False
        if np.isnan(incumbent):
            return True
        if selection_metric_min_max == 'min':
            return candidate < incumbent
        if selection_metric_min_max == 'max':
            return candidate > incumbent
        return False

    for t in tqdm(range(steps), desc='Iterate over Optimization steps'):
        optimizer.zero_grad()

        # Select subset of datasets (seeded by the step index so every run samples the same subsets)
        random.seed(t)
        train_datasets_ = random.sample(train_datasets, min(N_datasets_sampled, len(train_datasets)))

        # Get score on train
        diffable_metric_train, selection_metric_train = eval_all_datasets(train_datasets_, propagate=True)
        optimization_route_selection += [float(selection_metric_train)]
        optimization_route_diffable += [float(diffable_metric_train)]

        # Get score on valid
        diffable_metric_valid, selection_metric_valid = eval_all_datasets(valid_datasets, propagate=False)
        optimization_route_selection_valid += [float(selection_metric_valid)]
        optimization_route_diffable_valid += [float(diffable_metric_valid)]

        # The validation score belongs to the style *before* this step's update, so snapshot it now.
        if improves(selection_metric_valid, best_selection_metric):
            print('New best', best_selection_metric, selection_metric_valid)
            best_style = grad_style.detach().clone()
            best_temperature = softmax_temperature.detach().clone()
            best_selection_metric = selection_metric_valid

        optimizer.step()

        if limit_style:
            # Clamp in place: rebinding `grad_style` to a new tensor would disconnect it from the
            # optimizer and silently stop the style from being optimized after the first step.
            with torch.no_grad():
                grad_style.clamp_(-1.74, 1.74)

        print(f'Valid: Diffable metric={diffable_metric_valid} Selection metric={selection_metric_valid};' +
            f'Train: Diffable metric={diffable_metric_train} Selection metric={selection_metric_train}')

    print(f'Return best:{best_style} {best_selection_metric}')
    return {'best_style': best_style, 'best_temperature': best_temperature
            , 'optimization_route': {'select': optimization_route_selection, 'loss': optimization_route_diffable,
               'test_select': optimization_route_selection_valid, 'test_loss': optimization_route_diffable_valid}}