Hannes Kuchelmeister commited on
Commit
6693f22
1 Parent(s): e447dcf

add hyperparameter search for convolutional model

Browse files
configs/hparams_search/focusConvMSE_150_hyperparameter_search.yaml ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # example hyperparameter optimization of some experiment with Optuna:
4
+ # python train.py -m hparams_search=mnist_optuna experiment=example
5
+
6
+ defaults:
7
+ - override /datamodule: focus150.yaml
8
+ - override /model: focusConv_150.yaml
9
+ - override /hydra/sweeper: optuna
10
+
11
+ # choose metric which will be optimized by Optuna
12
+ # make sure this is the correct name of some metric logged in lightning module!
13
+ optimized_metric: "val/mae_best"
14
+
15
+ name: "focusConvMSE_150_hyperparameter_search"
16
+
17
+ # here we define Optuna hyperparameter search
18
+ # it optimizes for value returned from function with @hydra.main decorator
19
+ # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
20
+ hydra:
21
+ sweeper:
22
+ _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
23
+
24
+ # storage URL to persist optimization results
25
+ # for example, you can use SQLite if you set 'sqlite:///example.db'
26
+ storage: null
27
+
28
+ # name of the study to persist optimization results
29
+ study_name: focusConvMSE_150_hyperparameter_search
30
+
31
+ # number of parallel workers
32
+ n_jobs: 1
33
+
34
+ # 'minimize' or 'maximize' the objective
35
+ direction: minimize
36
+
37
+ # total number of runs that will be executed
38
+ n_trials: 20
39
+
40
+ # choose Optuna hyperparameter sampler
41
+ # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
42
+ sampler:
43
+ _target_: optuna.samplers.TPESampler
44
+ seed: 12345
45
+ n_startup_trials: 10 # number of random sampling runs before optimization starts
46
+
47
+ # define range of hyperparameters
48
+ search_space:
49
+ datamodule.batch_size:
50
+ type: categorical
51
+ choices: [64, 128]
52
+ model.lr:
53
+ type: float
54
+ low: 0.0001
55
+ high: 0.2
56
+ model.pool_size:
57
+ type: categorical
58
+ choices: [1, 2, 3]
59
+ model.conv1_size:
60
+ type: categorical
61
+ choices: [3, 5, 7, 9]
62
+ model.conv1_channels:
63
+ type: categorical
64
+ choices: [1, 3, 6, 9]
65
+ model.conv2_size:
66
+ type: categorical
67
+ choices: [3, 5, 7, 9]
68
+ model.conv2_channels:
69
+ type: categorical
70
+ choices: [1, 3, 6, 9]
71
+ model.lin1_size:
72
+ type: categorical
73
+ choices: [16, 32, 64, 96, 128]
74
+ model.lin2_size:
75
+ type: categorical
76
+ choices: [16, 32, 64, 96, 128]